美文网首页pytorch学习
pytorch基础学习(五) 数据处理(二)

pytorch基础学习(五) 数据处理(二)

作者: SnowPye | 来源:发表于2020-06-14 01:12 被阅读0次

    本篇主要介绍pytorch中tensor的基本操作如:对tensor进行flatten操作, 对tensor进行拼接,tensor的broadcast(广播)机制.

    1. 对tensor进行flatten操作

    flatten顾名思义就是平展,将一个tensor由高维(rank)变为1维,元素的数量保持不变,这在深度学习中很常用,当输入是一个图像时候,图像的维度(rank)是3,当网络的输入层只能输入一维的数据时(如全连接层),flatten操作就显得非常有用了.

    下面我们说两种flatten的实现方式:使用上篇的squeeze,reshape函数间接实现;使用tensor自带的flatten函数实现.

    1. 使用squeeze,reshape函数间接实现

    实现代码如下,我们可以写一个flatten函数,该函数的功能是输入一个tensor,将它维度变为1输出.

    def flatten(t):
        t = t.reshape(1, -1)
        t = t.squeeze()
        return t
    

    函数的第一行先利用reshape函数将tensor变为2维度,具体可以参考上篇. 此时可以注意到,第一个axis的长度为1,因此第二行使用squeeze函数将长度为1的axis去掉,这时候tensor的维度自然而然就变成1啦!

    举个栗子吧:

    t = torch.tensor([
        [1, 1, 1, 1],
        [2, 2, 2, 2],
        [3, 3, 3, 3]
    ], dtype=torch.float32)
    print(t.shape)
    print(flatten(t))
    print(flatten(t).shape)
    
    output:
    torch.Size([3, 4])
    tensor([1., 1., 1., 1., 2., 2., 2., 2., 3., 3., 3., 3.])
    torch.Size([12])
    

    可以看到,使用我们自己写的flatten函数,将tensor变成了1维,但这很不方便欸 ,有自带的干嘛不用呢,用它!

    2. 直接使用tensor的flatten函数

    tensor是有flatten函数的,话不多说,直接举个栗子:

    t1 = torch.ones(4, 4)
    t2 = torch.ones(4, 4) * 2
    t3 = torch.ones(4, 4) * 3
    t = torch.stack((t1, t2, t3)) 
    t = t.reshape(3, 1, 4, 4) 
    print(t.flatten(start_dim=1))
    print(t.flatten(start_dim=1).shape)
    print(t.reshape(t.shape[0], -1))
    print(t.flatten().shape)
    
    output:
    tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
            [2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.],
            [3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3.]])
    torch.Size([3, 16])
    tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
            [2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.],
            [3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3.]])
    torch.Size([48])
    

    可以看到flatten有个参数start_dim,表示从start_dim到最后一个维度都做平展操作,因此t.flatten(start_dim=1)后,就只有第0维保持不变,其它维度做了flatten,变成了一维,从而shape变为torch.Size([3, 16]),这在深度学习中,向全连接层输入时经常会使用到.

    当flatten没有参数时,默认将整个tensor进行flatten操作,即start_dim=0,因此经过t.flatten(),其shape变为torch.Size([48]).

    2. 对tensor进行拼接

    将2个甚至更多的tensor进行拼接是经常要使用到的功能,pytorch中对tensor拼接常用的函数为cat,举个栗子吧:

    t1 = torch.tensor([
        [1, 2],
        [3, 4]
    ])
    t2 = torch.tensor([
        [5, 6],
        [7, 8]
    ])
    print(torch.cat((t1, t2), dim=0))
    print(torch.cat((t1, t2), dim=0).shape)
    print(torch.cat((t1, t2), dim=1))
    print(torch.cat((t1, t2), dim=1).shape)
    
    output:
    torch.Size([12])
    tensor([[1, 2],
            [3, 4],
            [5, 6],
            [7, 8]])
    torch.Size([4, 2])
    tensor([[1, 2, 5, 6],
            [3, 4, 7, 8]])
    torch.Size([2, 4])
    

    可以看到,cat中的dim参数决定了拼接的维度.
    当dim=0时,在第0维(第1个axis)进行拼接 ,其余维度长度不变,因此shape变为torch.Size([4, 2]).
    当dim=1时,在第1维(第2个axis)进行拼接 ,其余维度长度不变,因此shape变为torch.Size([2, 4]).
    为了达成拼接的目的,很容易我们可以看出,对于要拼接的tensor(可以不止2个),除了需要拼接的维度,其余维度的长度必须保持相同,否则会引起错误.

    3. tensor的broadcast(广播)机制

    在数学运算中,两个形状不同的矩阵进行加减运算显然是不行的,但对于tensor,在某些形况下是完全可以的,这得益于pytorch中tensor的broadcast机制.

    举个栗子:

    t1 = torch.tensor([
        [1, 2],
        [3, 4]
    ])
    t2 = torch.tensor([
        [9, 8],
        [7, 6]
    ])
    print(t1 + t2)
    print(t1 + 2)
    print(t1 + torch.tensor(
        np.broadcast_to(2, t1.shape),
        dtype=torch.int32
    ))  # equal to last line
    
    output:
    tensor([[10, 10],
            [10, 10]])
    tensor([[3, 4],
            [5, 6]])
    tensor([[3, 4],
            [5, 6]])
    

    从倒数第二个print那里,我们惊讶的发现print(t1 + 2)竟然也可以运算!这得益于broadcast机制,等同于最后一个print中的语句,pytorch将2自动扩充成了与t1形状相同的tensor,这样当然就可以运算啦.

    那么有个大胆的想法,除了常数,不同形状的tensor是否也可以这样操作,举个栗子试一试:

    t1 = torch.tensor([
        [1, 2],
        [3, 4]
    ])
    t3 = torch.tensor([2, 4])
    print(t1 + t3)
    print(t1 + torch.tensor(
        np.broadcast_to(t3.numpy(), t1.shape),
        dtype=torch.int32
    ))  # equal to last line
    
    output:
    tensor([[3, 6],
            [5, 8]])
    tensor([[3, 6],
            [5, 8]])
    

    哈哈,果然是可以的,pytorch将t3自动扩充成了与t1形状相同的tensor(将t3又复制了一行)再与之运算. 这个特性很有意思,很方便我们书写代码.

    相关文章

      网友评论

        本文标题:pytorch基础学习(五) 数据处理(二)

        本文链接:https://www.haomeiwen.com/subject/jtxpzhtx.html