pytorch 切片(下)

作者: zidea | 来源:发表于2020-08-18 20:13 被阅读0次
    slice.jpeg

    使用 index_select 切分数据

    下面介绍一个index_selectAPI 对 tensor 数据进行切分。

    a.index_select(2,torch.arange(28)).shape
    
    • 2 这里维度也就是我们选择一个维度作为数据切分的依据
    • torch.arange(28)在指定维度上切分的范围
    torch.Size([4, 3, 28, 28])
    

    因为在 2 维度上选取0 -27 维度范围也就是没有进行任何切分效果因为数据在此维度上就是 28 维度

    a.index_select(0,torch.tensor([0,2])).shape
    
    • 在 0 维度取 0 到 1 ,遵循左闭右开原则,所以取到第一个维度上前 2 条数据
    torch.Size([2, 3, 28, 28])
    

    习题

    呵呵这里弄一个习题,给大家习题还是第一次,大家可以自己看看解释一下下面代码输出为什么是torch.Size([4, 3, 8, 28])

    a.index_select(2,torch.arange(8)).shape
    
    torch.Size([4, 3, 8, 28])
    

    省略表示默认

    a[...].shape
    

    在计算机中编程一些符号有着其在编程中独特含义,不同语言也可能不同。但是合理的 API 设计是不用大家学习,一看就知道怎么用 API,这里省略号表示默认,所以这里省略号就是表示什么也不做,对数据没有切分操作。

    torch.Size([4, 3, 28, 28])
    

    我们看第一个,下面切分在第一位是 0 表示确定是第一张图片,后面省略好表示不做任何操作,所以切分出的数据是表示一张 28 \times 28 大小 3 通道的图片。

    a[0,...].shape
    
    torch.Size([3, 28, 28])
    

    如果在第 2 维度上取 1 表示,表示数据 2 维度确定都是取图片 1 通道,从图片通道 RGB 来看,这里描述就是我们得到 3 张图片都是取 G 通道数据

    a[:,1,...].shape
    
    torch.Size([4, 28, 28])
    

    接下来看这段代码

    a[0,...,::2].shape
    

    通过遮罩来筛选数据

    x = torch.randn(3,4)
    

    通过 randn 随机生产 3 \times 4 矩阵,然后通过条件进行筛选得到 mask 矩阵。在 mask 矩阵会根据条件生产一个矩阵,矩阵是由 True 和 False 来表示。

    tensor([[-0.0040,  0.3439,  1.3629,  1.7692],
            [-0.1891, -2.1325, -0.9377, -0.3534],
            [-0.4318,  0.3152,  0.1341, -1.5351]])
    
    mask = x.ge(0.5)
    
    tensor([[False,  True,  True, False],
            [ True,  True, False, False],
            [False,  True,  True,  True]])
    

    然后使用 masked_select 来利用 mask 进行筛选得到数据切分效果

    torch.masked_select(x,mask)
    
    tensor([1.3629, 1.7692])
    

    相关文章

      网友评论

        本文标题:pytorch 切片(下)

        本文链接:https://www.haomeiwen.com/subject/tivhjktx.html