使用 index_select 切分数据
下面介绍一个index_select
API 对 tensor 数据进行切分。
a.index_select(2,torch.arange(28)).shape
- 2 这里维度也就是我们选择一个维度作为数据切分的依据
-
torch.arange(28)
在指定维度上切分的范围
torch.Size([4, 3, 28, 28])
因为在 2 维度上选取0 -27 维度范围也就是没有进行任何切分效果因为数据在此维度上就是 28 维度
a.index_select(0,torch.tensor([0,2])).shape
- 在 0 维度取 0 到 1 ,遵循左闭右开原则,所以取到第一个维度上前 2 条数据
torch.Size([2, 3, 28, 28])
习题
呵呵这里弄一个习题,给大家习题还是第一次,大家可以自己看看解释一下下面代码输出为什么是torch.Size([4, 3, 8, 28])
a.index_select(2,torch.arange(8)).shape
torch.Size([4, 3, 8, 28])
省略表示默认
a[...].shape
在计算机中编程一些符号有着其在编程中独特含义,不同语言也可能不同。但是合理的 API 设计是不用大家学习,一看就知道怎么用 API,这里省略号表示默认,所以这里省略号就是表示什么也不做,对数据没有切分操作。
torch.Size([4, 3, 28, 28])
我们看第一个,下面切分在第一位是 0 表示确定是第一张图片,后面省略好表示不做任何操作,所以切分出的数据是表示一张 大小 3 通道的图片。
a[0,...].shape
torch.Size([3, 28, 28])
如果在第 2 维度上取 1 表示,表示数据 2 维度确定都是取图片 1 通道,从图片通道 RGB 来看,这里描述就是我们得到 3 张图片都是取 G 通道数据
a[:,1,...].shape
torch.Size([4, 28, 28])
接下来看这段代码
a[0,...,::2].shape
通过遮罩来筛选数据
x = torch.randn(3,4)
通过 randn 随机生产 矩阵,然后通过条件进行筛选得到 mask 矩阵。在 mask 矩阵会根据条件生产一个矩阵,矩阵是由 True 和 False 来表示。
tensor([[-0.0040, 0.3439, 1.3629, 1.7692],
[-0.1891, -2.1325, -0.9377, -0.3534],
[-0.4318, 0.3152, 0.1341, -1.5351]])
mask = x.ge(0.5)
tensor([[False, True, True, False],
[ True, True, False, False],
[False, True, True, True]])
然后使用 masked_select 来利用 mask 进行筛选得到数据切分效果
torch.masked_select(x,mask)
tensor([1.3629, 1.7692])
网友评论