美文网首页
torch.roll的用法

torch.roll的用法

作者: 炼己者 | 来源:发表于2021-04-13 17:10 被阅读0次

    作者:炼己者

    本博客所有内容以学习、研究和分享为主,如需转载,请联系本人,标明作者和出处,并且是非商业用途,谢谢!


    最近在看swin transformer的代码,在这里记录一下遇到的torch.roll的用法。

    # swin transformer中移位的操作
    class CyclicShift(nn.Module):
        def __init__(self, displacement):
            super().__init__()
            self.displacement = displacement
    
        def forward(self, x):
            return torch.roll(x, shifts=(self.displacement, self.displacement), dims=(1, 2))
    

    torch.roll(input, shifts, dims=None) 这个函数到底是干啥的,咋用的呢?
    简单说,它就是用来移位的,是顺移。input是咱们要移动的tensor向量,shifts是要移动到的位置,要移动去哪儿,dims是值在什么方向上(维度)去移动。比如2维的数据,那就两个方向,横着或者竖着。对了,关键的一句话,所有操作针对第一行或者第一列,主要是这个"第一",下面举例子给大家做解释,自己慢慢体会

    >>> x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9]).view(3, 3)
    >>> x
    tensor([[1, 2, 3],
            [4, 5, 6],
            [7, 8, 9]])
    
    >>> torch.roll(x, 1, 0)
    tensor([[7, 8, 9],
            [1, 2, 3],
            [4, 5, 6]])
    

    torch.roll(x, 1, 0) 这行代码的意思就是把x的第一行(0维度)移到1这个位置上,其他位置的数据顺移。
    x——咱们要移动的向量
    1——第一行向量要移动到的最终位置
    0——从行的角度去移动

    再来一个列的例子

    >>> x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9]).view(3, 3)
    >>> x
    tensor([[1, 2, 3],
            [4, 5, 6],
            [7, 8, 9]])
    
    >>> torch.roll(x, -1, 1)
    >>> x
    tensor([[2, 3, 1],
            [5, 6, 4],
            [8, 9, 7]])
    

    torch.roll(x, -1, 1) 这行代码的意思就是把x的第一列(1维度)移到-1这个位置(最后一个位置)上,其他位置的数据顺移。

    shifts和dims可以是元组,其实就是分步骤去移动,再举个例子

    >>> x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9]).view(3, 3)
    >>> x
    tensor([[1, 2, 3],
            [4, 5, 6],
            [7, 8, 9]])
    
    >>> torch.roll(x, (0,1), (1,1))
    tensor([[3, 1, 2],
            [6, 4, 5],
            [9, 7, 8]])
    
    

    torch.roll(x, (0,1), (1,1)) 这行代码的意思:
    第一步,把x的第一列(1维度)移到0这个位置(原地不动)上,其他位置的数据顺移。(所有数据原地不动)

    >>> x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9]).view(3, 3)
    >>> x
    tensor([[1, 2, 3],
            [4, 5, 6],
            [7, 8, 9]])
    
    >>> a = torch.roll(x, 0, 1)
    >>> a
    tensor([[1, 2, 3],
            [4, 5, 6],
            [7, 8, 9]])
    

    第二步,把a的第一列(1维度)移到1这个位置上,其他位置的数据顺移。

    >>> a = torch.roll(x, 0, 1)
    >>> a
    tensor([[1, 2, 3],
            [4, 5, 6],
            [7, 8, 9]])
    
    >>> b = torch.roll(a, 1, 1)
    >>> b
    tensor([[3, 1, 2],
            [6, 4, 5],
            [9, 7, 8]])
    

    以上便是torch.roll()函数的用法,大家有疑问可以在底下评论交流,跑跑代码,便可迎刃而解

    以下是我所有文章的目录,大家如果感兴趣,也可以前往查看
    👉戳右边:打开它,也许会看到很多对你有帮助的文章

    相关文章

      网友评论

          本文标题:torch.roll的用法

          本文链接:https://www.haomeiwen.com/subject/tqsclltx.html