torch.cat()与torch.stack()函数

作者: 午字横 | 来源:发表于2022-11-28 14:47 被阅读0次

    torch.cat()

    import torch
    
    x1 = torch.tensor([[11,21,31],[21,31,41]],dtype=torch.int)
    x1.shape # torch.Size([2, 3])
    # x2
    x2 = torch.tensor([[12,22,32],[22,32,42]],dtype=torch.int)
    x2.shape  # torch.Size([2, 3])
    
    inputs = [x1, x2]
    print(inputs)
    
    x=torch.cat(inputs,dim=0)
    print(x.shape)
    

    outputs = torch.cat(inputs, dim=?)
    dim代表在哪个维度上进行堆叠
    inputs : 待连接的张量序列,可以是任意相同Tensor类型的python 序列
    dim : 选择的扩维, 必须在0到len(inputs[0])之间,沿着此维连接张量序列。

    torch.stack()

    T1 = torch.tensor([[1, 2, 3],
                    [4, 5, 6],
                    [7, 8, 9]])
    
    T2 = torch.tensor([[10, 20, 30],
                    [40, 50, 60],
                    [70, 80, 90]])
    print(T1.shape)
    print(T2.shape)
    print(torch.stack((T1,T2),dim=0).shape)
    print(torch.stack((T1,T2),dim=1).shape)
    print(torch.stack((T1,T2),dim=2).shape)
    

    outputs = torch.stack(inputs, dim=?)
    dim代表要生成的维度是哪个

    inputs: 待连接的张量序列。
    注:python的序列数据只有listtuple

    dim : 新的维度, 必须在0len(outputs)之间。
    注:len(outputs)是生成数据的维度大小,也就是outputs的维度值。

    相关文章

      网友评论

        本文标题:torch.cat()与torch.stack()函数

        本文链接:https://www.haomeiwen.com/subject/ffxkfdtx.html