
PyTorch Quick Start

Author: 此番风景 | Published 2019-05-11 19:50

    torch.Tensor data types

    torch.Tensor is a multi-dimensional matrix containing elements of a single data type.

    Data type                  CPU tensor           GPU tensor
    32-bit floating point      torch.FloatTensor    torch.cuda.FloatTensor
    64-bit floating point      torch.DoubleTensor   torch.cuda.DoubleTensor
    16-bit floating point      N/A                  torch.cuda.HalfTensor
    8-bit integer (unsigned)   torch.ByteTensor     torch.cuda.ByteTensor
    8-bit integer (signed)     torch.CharTensor     torch.cuda.CharTensor
    16-bit integer (signed)    torch.ShortTensor    torch.cuda.ShortTensor
    32-bit integer (signed)    torch.IntTensor      torch.cuda.IntTensor
    64-bit integer (signed)    torch.LongTensor     torch.cuda.LongTensor
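
    The sketch below (not from the original post) shows how to check and convert between these types in practice; the GPU line is guarded because it assumes a CUDA-capable device.

    import torch

    x = torch.tensor([1.0, 2.0, 3.0])   # default 32-bit float
    print(x.type())                     # 'torch.FloatTensor'
    print(x.double().type())            # 'torch.DoubleTensor'
    print(x.long().type())              # 'torch.LongTensor'
    if torch.cuda.is_available():
        print(x.cuda().type())          # 'torch.cuda.FloatTensor'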

    torch.Tensor is shorthand for the default tensor type (torch.FloatTensor).

    Operations that mutate a tensor are marked with a trailing underscore. For example, torch.FloatTensor.abs_() computes the absolute value in place and returns the modified tensor, while torch.FloatTensor.abs() returns the result in a new tensor.
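
    A small illustration of the underscore convention (a sketch, not from the original post):

    import torch

    t = torch.tensor([-1.0, 2.0, -3.0])
    print(t.abs())   # out-of-place: a new tensor([1., 2., 3.]), t unchanged
    print(t)         # still tensor([-1.,  2., -3.])
    t.abs_()         # in-place: modifies t itself and also returns it
    print(t)         # tensor([1., 2., 3.])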

    Creating Tensors

    import torch

    # uninitialized (contents are whatever happens to be in memory)
    torch.empty(2, 3)
    torch.FloatTensor(2, 3)
    torch.IntTensor(2, 3, 4)  # the arguments are the desired dimensions
    
    torch.tensor([1.2, 3]).type()  # 'torch.FloatTensor'
    # set the default floating-point tensor type
    torch.set_default_tensor_type(torch.DoubleTensor)
    
    # random initialization
    a = torch.rand(3, 3)  # uniform on [0, 1)
    torch.rand_like(a)
    torch.randint(1, 10, [3, 3])  # integers in [low, high), i.e. 1..9
    # normal distribution
    torch.randn(3, 3)  # N(0, 1)
    torch.normal(mean=torch.full([10], 0.), std=torch.arange(1, 0, -0.1))
    
    torch.full([2, 3], 7)  # every element set to 7
    torch.full([], 7)  # a scalar
    torch.arange(0, 10)
    
    # linspace/logspace
    torch.linspace(0, 10, steps=4)
    torch.logspace(0, -1, steps=10)
    
    # ones/zeros/eye/*_like
    torch.ones(3, 3)
    torch.zeros(3, 3)
    torch.eye(3, 4)
    
    # randperm: random permutation, analogous to random.shuffle over indices
    torch.randperm(10)
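
    A quick sanity check of a few of the calls above (a sketch, not from the original post; the shapes and bounds are the point, the actual values are random):

    import torch

    print(torch.empty(2, 3).shape)                  # torch.Size([2, 3]), contents uninitialized
    print(torch.rand(3, 3).min() >= 0)              # True: rand samples from [0, 1)
    print(torch.randint(1, 10, [3, 3]).max() < 10)  # True: the upper bound is exclusive
    print(torch.full([2, 3], 7))                    # every element is 7
    print(torch.linspace(0, 10, steps=4))           # tensor([ 0.0000,  3.3333,  6.6667, 10.0000])
    print(torch.randperm(10))                       # a random permutation of 0..9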
    
    

    Tensor Slicing

    Slicing works much like NumPy, e.g. a[1:10, :], a[:10:2, :].

    a = torch.randn(4, 3, 28, 28)
    a[:2]
    a[:2, 1:, :, :].shape  # torch.Size([2, 2, 28, 28])
    
    # select by specific index
    a.index_select(0, torch.tensor([0, 2]))  # take entries 0 and 2 along dim 0
    a[...].shape  # '...' stands for all remaining dimensions
    a[..., :2]  # like *args unpacking, '...' expands to as many ':' as needed
    
    # select by mask
    x = torch.randn(3, 4)
    mask = x.ge(0.5)  # element-wise x >= 0.5
    torch.masked_select(x, mask)  # returns the selected elements as a 1-D tensor
    
    # select by flattened index
    src = torch.tensor([[4, 3, 5], [6, 7, 8]])
    torch.take(src, torch.tensor([0, 2]))  # tensor([4, 5])
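
    A worked example with fixed values (a sketch, not from the original post), so the outputs of masked_select and take are predictable:

    import torch

    x = torch.tensor([[0.1, 0.7, 0.4],
                      [0.9, 0.2, 0.6]])
    mask = x.ge(0.5)                              # element-wise x >= 0.5
    print(torch.masked_select(x, mask))           # tensor([0.7000, 0.9000, 0.6000]), always 1-D
    src = torch.tensor([[4, 3, 5], [6, 7, 8]])
    print(torch.take(src, torch.tensor([0, 2])))  # tensor([4, 5]), indices into the flattened tensor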
    

    Tensor Dimension Transformations

    • view/reshape
    • squeeze/unsqueeze
    • transpose/t/permute
    • expand/repeat
    # view/reshape (merging dims loses the original dimension information)
    In [41]: a = torch.rand(4,1 ,28, 28)
    
    In [42]: a.shape
    Out[42]: torch.Size([4, 1, 28, 28])
    
    In [43]: a.view(4, 28*28)
    Out[43]: 
    tensor([[0.6006, 0.8933, 0.1474,  ..., 0.5848, 0.9790, 0.6479],
            [0.1824, 0.8874, 0.1635,  ..., 0.3386, 0.3563, 0.0075],
            [0.8867, 0.9460, 0.1208,  ..., 0.1569, 0.2614, 0.7639],
            [0.1437, 0.5749, 0.2275,  ..., 0.5167, 0.6074, 0.5263]])
    In [44]: a.view(4, 28*28).shape
    Out[44]: torch.Size([4, 784])
    
    # unsqueeze (insert a new dimension of size 1)
    In [50]: b = torch.rand(32)
    In [51]: f = torch.rand(4,32, 14,14)
    In [52]: b = b.unsqueeze(1).unsqueeze(2).unsqueeze(0)
    In [53]: b.shape
    Out[53]: torch.Size([1, 32, 1, 1])
    # expand/repeat
    b.expand([4,32,14,14]) # [1,32,1,1] -> [4,32,14,14]
    b.repeat(4, 1, 32, 32)  # repeats (copies) data: [1,32,1,1] -> [4,32,32,32]
    
    # a.t() transposes a 2-D tensor
    # transpose
    a.transpose(1, 3)  # swap the two specified dims
    a.transpose(1, 3).contiguous()  # copy into contiguous memory (needed before .view)
    
    # permute: reorder all dims at once
    # [b c h w] -> [b h w c]
    f.permute(0, 2, 3, 1)  # f is [4,32,14,14] -> [4,14,14,32], i.e. [b h w c]
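
    Two points worth verifying here (a sketch, not from the original post): view fails on the non-contiguous result of transpose until .contiguous() is called, and expand broadcasts without copying while repeat actually copies data.

    import torch

    a = torch.rand(4, 3, 28, 28)
    at = a.transpose(1, 3)                    # shape [4, 28, 28, 3], memory no longer contiguous
    # at.view(4, -1) would raise a RuntimeError here; make a contiguous copy first
    print(at.contiguous().view(4, -1).shape)  # torch.Size([4, 2352])

    b = torch.rand(1, 32, 1, 1)
    print(b.expand(4, 32, 14, 14).shape)      # torch.Size([4, 32, 14, 14]), no data copied
    print(b.repeat(4, 1, 32, 32).shape)       # torch.Size([4, 32, 32, 32]), data copied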
    

Source: https://www.haomeiwen.com/subject/gblfaqtx.html