torch.narrow()
函数是用来返回Tensor
的切片的,它的使用方法如下:
torch.narrow(input, dim, start, length)
- input– 待处理的tensor
- dim – 维度,当为0时以行为单位进行切片,当为1时以列为单位进行切片
- start – 切片开始的索引
- length – 切片的长度
下面用一个例子来加以说明:
>>> a = torch.tensor([[1,2,3],[4,5,6],[7,8,9]])
>>> torch.narrow(a, 0, 0, 2)
tensor([[1, 2, 3],
[4, 5, 6]])
>>> torch.narrow(a, 1, 1, 2)
tensor([[2, 3],
[5, 6],
[8, 9]])
如果不传入input
,也可以直接对tensor进行操作:
>>> a.narrow(0,0,2)
tensor([[1, 2, 3],
[4, 5, 6]])
同时,Numpy上的快速切片方法在Pytorch上也同样适用:
>>> a[:,0:2]
tensor([[1, 2],
[4, 5],
[7, 8]])
参考:
pytorch官网
网友评论