美文网首页
torch.narrow()函数使用详解

torch.narrow()函数使用详解

作者: 一位学有余力的同学 | 来源:发表于2021-04-12 22:21 被阅读0次

torch.narrow()函数是用来返回Tensor的切片的,它的使用方法如下:

torch.narrow(input, dim, start, length)

  • input– 待处理的tensor
  • dim – 维度,当为0时以行为单位进行切片,当为1时以列为单位进行切片
  • start – 切片开始的索引
  • length – 切片的长度

下面用一个例子来加以说明:

>>> a = torch.tensor([[1,2,3],[4,5,6],[7,8,9]])
>>> torch.narrow(a, 0, 0, 2)
tensor([[1, 2, 3],
        [4, 5, 6]])
>>> torch.narrow(a, 1, 1, 2)
tensor([[2, 3],
        [5, 6],
        [8, 9]])

如果不传入input,也可以直接对tensor进行操作:

>>> a.narrow(0,0,2)
tensor([[1, 2, 3],
        [4, 5, 6]])

同时,Numpy上的快速切片方法在Pytorch上也同样适用:

>>> a[:,0:2]
tensor([[1, 2],
        [4, 5],
        [7, 8]])

参考:
pytorch官网

相关文章

网友评论

      本文标题:torch.narrow()函数使用详解

      本文链接:https://www.haomeiwen.com/subject/fqyflltx.html