Author: yoyo9999 | Published 2021-03-24 00:55

    3. torch.stack()/vstack()/hstack() and torch.cat() in PyTorch

    1.torch.stack()

    torch.stack(tensors, dim=0, *, out=None) → Tensor

    Purpose:

    Concatenates a sequence of tensors along a new dimension. All tensors need to be of the same size.

    Stacks a sequence of tensors along a new dimension. All inputs must have exactly the same size, and the result has one more dimension than each input. The default is dim=0.

    x = torch.arange(9).view(3,3)
    print(x)
    print("---")
    new_x = torch.stack([x, x, x])
    print(new_x.shape)
    print(new_x)
    
    ================================================================
    tensor([[0, 1, 2],
            [3, 4, 5],
            [6, 7, 8]])
    ---
    torch.Size([3, 3, 3])
    tensor([[[0, 1, 2],
             [3, 4, 5],
             [6, 7, 8]],
    
            [[0, 1, 2],
             [3, 4, 5],
             [6, 7, 8]],
    
            [[0, 1, 2],
             [3, 4, 5],
             [6, 7, 8]]])
    
    
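    To make the size requirement concrete, here is a small sketch (assuming PyTorch is installed) of the shape rule: stacking n equal-sized tensors at dim=d inserts n at position d of the shape, and inputs of different sizes are rejected.

```python
import torch

x = torch.arange(9).view(3, 3)

# Stacking n tensors of shape S along dim=d gives shape S with n inserted at d.
s0 = torch.stack([x, x, x], dim=0)
s2 = torch.stack([x, x, x], dim=2)
print(s0.shape)  # torch.Size([3, 3, 3])
print(s2.shape)  # torch.Size([3, 3, 3]) -- same shape here, different layout

# Inputs of different sizes are rejected with a RuntimeError.
y = torch.arange(6).view(2, 3)
try:
    torch.stack([x, y])
except RuntimeError as err:
    print("stack failed:", err)
```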

    If you specify dim explicitly, then for the two 3×2 tensors a and b below:

    • dim=0: the new dimension comes first, giving shape (2, 3, 2); a and b are stacked whole, one after the other.
    • dim=1: the new dimension is inserted second, giving shape (3, 2, 2); row i of a is paired with row i of b.
    a = torch.arange(0, 6).view((3, 2))
    b = torch.arange(6, 12).view((3, 2))
    print('a:', a)
    print('b:', b)
    ab0= torch.stack((a, b), dim=0)
    ab1 = torch.stack((a, b), dim=1)
    print(ab0, '\n', ab1)
    
    +++++++++++++++++++++++++++++++++++++++++++
    a: tensor([[0, 1],
            [2, 3],
            [4, 5]])
    b: tensor([[ 6,  7],
            [ 8,  9],
            [10, 11]])
    tensor([[[ 0,  1],
             [ 2,  3],
             [ 4,  5]],
    
            [[ 6,  7],
             [ 8,  9],
             [10, 11]]]) 
     tensor([[[ 0,  1],
             [ 6,  7]],
    
            [[ 2,  3],
             [ 8,  9]],
    
            [[ 4,  5],
             [10, 11]]])
    
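    The shapes make the choice of dim easy to read off. A quick check, reusing the same a and b as above:

```python
import torch

a = torch.arange(0, 6).view(3, 2)
b = torch.arange(6, 12).view(3, 2)

ab0 = torch.stack((a, b), dim=0)  # shape (2, 3, 2): a and b stacked whole
ab1 = torch.stack((a, b), dim=1)  # shape (3, 2, 2): row i of a paired with row i of b
print(ab0.shape, ab1.shape)

# Slicing recovers the pairing: ab1[i] is a[i] stacked on b[i].
print(torch.equal(ab1[1], torch.stack((a[1], b[1]))))  # True
```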
    

    2. torch.vstack() and torch.hstack()

    torch.vstack(tensors, *, out=None) → Tensor
    torch.hstack(tensors, *, out=None) → Tensor

    Purpose: these two functions are only supported from PyTorch 1.8.0 onward; if they are unavailable, use torch.cat() instead.

    They stack tensors in the vertical and horizontal directions, respectively.

    ab_vstack = torch.vstack((a, b))
    ab_hstack = torch.hstack((a, b))
    print('ab_vstack :', ab_vstack)
    print('ab_hstack :', ab_hstack)
    print(torch.__version__)
    
    ++++++++++++++++++++++++++++++++++++++++++++
    ab_vstack : tensor([[ 0,  1],
            [ 2,  3],
            [ 4,  5],
            [ 6,  7],
            [ 8,  9],
            [10, 11]])
    ab_hstack : tensor([[ 0,  1,  6,  7],
            [ 2,  3,  8,  9],
            [ 4,  5, 10, 11]])
    1.8.0+cu101
    
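    On 2-D inputs, vstack and hstack behave exactly like cat along dim 0 and dim 1, which is why torch.cat() is a drop-in replacement before 1.8.0. The difference shows up for 1-D inputs, where vstack first promotes each vector to a row. A sketch:

```python
import torch

a = torch.arange(0, 6).view(3, 2)
b = torch.arange(6, 12).view(3, 2)

# For 2-D inputs, vstack is cat along dim 0 and hstack is cat along dim 1.
print(torch.equal(torch.vstack((a, b)), torch.cat((a, b), dim=0)))  # True
print(torch.equal(torch.hstack((a, b)), torch.cat((a, b), dim=1)))  # True

# For 1-D inputs the two differ: vstack promotes each vector to a row
# (shape (1, n)) before concatenating, while hstack simply concatenates.
u = torch.tensor([1, 2, 3])
v = torch.tensor([4, 5, 6])
print(torch.vstack((u, v)).shape)  # torch.Size([2, 3])
print(torch.hstack((u, v)).shape)  # torch.Size([6])
```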

    3. torch.cat()

    torch.cat(tensors, dim=0, *, out=None) → Tensor

    The difference from torch.stack() is that cat does not add a new dimension.

    Purpose:

    Concatenates the given sequence of seq tensors in the given dimension. All tensors must either have the same shape (except in the concatenating dimension) or be empty.

    Joins tensors along the given dimension; all tensors must have the same shape except in the concatenating dimension, or be empty.

    ab_cat_0 = torch.cat((a, b), dim=0)
    ab_cat_1 = torch.cat((a, b), dim=1)
    print('ab_cat_0 :', ab_cat_0 )
    print('ab_cat_1 :', ab_cat_1 )
    
    ++++++++++++++++++++++++++++++++++++
    ab_cat_0 : tensor([[ 0,  1],
            [ 2,  3],
            [ 4,  5],
            [ 6,  7],
            [ 8,  9],
            [10, 11]])
    ab_cat_1 : tensor([[ 0,  1,  6,  7],
            [ 2,  3,  8,  9],
            [ 4,  5, 10, 11]])
    
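    The relationship between the two can be stated precisely: stack(tensors, dim) gives the same result as cat after unsqueezing each input at dim. A small check:

```python
import torch

a = torch.arange(0, 6).view(3, 2)
b = torch.arange(6, 12).view(3, 2)

# cat joins along an existing dimension; stack first adds a new one.
stacked = torch.stack((a, b), dim=0)                         # shape (2, 3, 2)
catted = torch.cat((a.unsqueeze(0), b.unsqueeze(0)), dim=0)  # same result
print(torch.equal(stacked, catted))  # True

print(torch.cat((a, b), dim=0).shape)  # torch.Size([6, 2]): no new dimension
```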
    
