美文网首页
pytorch repeat 解析

pytorch repeat 解析

作者: 潘旭 | 来源:发表于2020-06-21 18:43 被阅读0次

    pytorch repeat 解析

    pytorch 中 Tensor.repeat 函数,能够将一个 tensor 从不同的维度上进行重复。这个能力在 Graph Attention Networks 中,有着应用。现在来看下,repeat 的能力是如何工作的?

    repeat(*sizes) → Tensor
    * sizes (torch.Size or int...) – The number of times to repeat this tensor along each dimension

    翻译过来:

    repeat 会将Tensor 在指定的维度方向上进行重复。比如设置参数是 2, 3, 4: 表示在 0 维方向上 重复2次,1 维方向上重复 3次, 2 维方向上重复4次。 注意这里的 2, 3, 4 不是指的维度方向,而是 0:2, 1:3, 2:4 在不同的维度上重复的次数。同时,也会进行维度的扩充。

    import torch
    
    def repeat_1():
    
        x = torch.tensor([1, 2, 3])
        print(f"x shape: {x.size()} : {x}")
        
        print(f"在 0 维上 重复2次 ---")
        xx = x.repeat(2)
        
        print(f"xx shape: {xx.size()}, {xx}")
        
        print(f"在 0 维上重复2次, 1 维上重复3次")
        xx = x.repeat(2, 3)
        print(f"xx shape: {xx.size()}, {xx}")
    
    repeat_1()
    
    
    x shape: torch.Size([3]) : tensor([1, 2, 3])
    在 0 维上 重复2次 ---
    xx shape: torch.Size([6]), tensor([1, 2, 3, 1, 2, 3])
    在 0 维上重复2次, 1 维上重复3次
    xx shape: torch.Size([2, 9]), tensor([[1, 2, 3, 1, 2, 3, 1, 2, 3],
            [1, 2, 3, 1, 2, 3, 1, 2, 3]])
    

    上面演示的是 x 只有一个维度,现在演示 2 个维度的。比如在 "Graph Attention Networds", 需要计算 图的 attention, 那么需要将图上所有的节点进行两两拼接。也就是说:

    Node = \{a_1, a_2, ..., a_n\}, a_i \in \mathbb{R}^{channel}

    经过两两拼接后, 形成的拼接 graph, 如下:

    Graph = \begin{bmatrix} a_{11} & a_{12} & ... & a_{1n}\\ a_{21} & a_{22} & ... & a_{2n}\\ ...\\ a_{n1} & a_{n2} & ... & a_{nn} \end{bmatrix}

    其中 a_{ij} 表示 [a_i || a_j] 表示两个向量拼接,拼接后的维度是 2 \times channel

    现在,当我们有了一个 Node 的矩阵,如何拼接出 Graph 就用到了 repeate.

    没有 batch 的情况

    Node \in \mathbb{R}^{N \times C}, 经过转换后 Graph \in \mathbb{R}^{N \times N \times 2C}

    开始时候:

    Node = \begin{bmatrix} a_1\\ a_{2}\\ \end{bmatrix}

    其中 a_i \in \mathbb{R}^{c}

    在下面的例子中 Node \in \mathbb{R}^{2 \times 3}

    import torch
    
    # node = [a1, a2]
    node = torch.tensor([
                            [1, 2, 3],
                            [4, 5, 6]
                        ])
    n = node.size(0)
    c = node.size(-1)
    print(f"node: {node.size()}, {node}")
    
    node: torch.Size([2, 3]), tensor([[1, 2, 3],
            [4, 5, 6]])
    

    现将 Node变成:

    Node\_repeat\_1 = \begin{bmatrix} a_1\\ a_{2}\\ a_1\\ a_{2}\\ \end{bmatrix}

    沿着 n 的方向上重复 n次, c 的方向上不便
    重复完的 node_repeat_1 shape: (n*n, c) 也就是 (4, 3)

    node_repeat_1 = node.repeat(n, 1)
    
    print(f"node_repeat_1 shape: {node_repeat_1.size()}, {node_repeat_1}")
    
    node_repeat_1 shape: torch.Size([4, 3]), tensor([[1, 2, 3],
            [4, 5, 6],
            [1, 2, 3],
            [4, 5, 6]])
    

    想要链接在一起来不够,还要产生一个 Node\_repeat\_2,

    Node\_repeat\_2 = \begin{bmatrix} a_1\\ a_1\\ a_2\\ a_2\\ \end{bmatrix}

    这样 Node\_repeat\_1Node\_repeat\_2 经过 concat 操作就能够得到我们需要的 graph 了。

    直接做这件事,需要一点技巧,在 c 这个方向上重复 n 次,然后在做一个 view变换。

    node_repeat_2_tmp = node.repeat(1, n)
    print(f"node_repeat_2_tmp: {node_repeat_2_tmp.size()}, {node_repeat_2_tmp}")
    
    node_repeat_2_tmp: torch.Size([2, 6]), tensor([[1, 2, 3, 1, 2, 3],
            [4, 5, 6, 4, 5, 6]])
    
    node_repeat_2 = node_repeat_2_tmp.view(-1, c)
    print(f"node_repeat_2:  {node_repeat_2.size()}, {node_repeat_2}")
    
    node_repeat_2:  torch.Size([4, 3]), tensor([[1, 2, 3],
            [1, 2, 3],
            [4, 5, 6],
            [4, 5, 6]])
    

    最后,将 node_repeat_1 与 node_repeat_2 concat 在一起就是 graph了。 注意: node_repeat_2 在前面, node_repeat_1在后面,因为 graph 第 i 行是:
    graph[i] = [a_{i1}, a_{i2}, ..., a_{in}]
    所以,需要 a_{i} 与其他的所有相连接,所以需要 node_repeat_2 在前面。

    graph = torch.cat((node_repeat_2, node_repeat_1), dim=-1)
    print(f"graph:  {graph.size()}, {graph}")
    
    graph:  torch.Size([4, 6]), tensor([[1, 2, 3, 1, 2, 3],
            [1, 2, 3, 4, 5, 6],
            [4, 5, 6, 1, 2, 3],
            [4, 5, 6, 4, 5, 6]])
    
    graph_pretty = graph.view(n, n, 2 * c)
    print(f"graph_pretty:  {graph_pretty.size()}, {graph_pretty}")
    
    graph_pretty:  torch.Size([2, 2, 6]), tensor([[[1, 2, 3, 1, 2, 3],
             [1, 2, 3, 4, 5, 6]],
    
            [[4, 5, 6, 1, 2, 3],
             [4, 5, 6, 4, 5, 6]]])
    

    上面的变换,用一个函数来表示:

    def single_graph(node: torch.Tensor):
        assert node.dim() == 2
        
        n = node.size(0)
        c = node.size(1)
        
        repeat_1 = node.repeat(1, n).view(-1, c)
        
        assert repeat_1.size(), (n*n, c)
        
        repeat_2 = node.repeat(n, 1)
        assert repeat_2.size() == (n*n, c)
        
        graph = torch.cat((repeat_1, repeat_2), dim=-1)
        
        assert graph.size() == (n*n, 2*c)
        
        graph = graph.view(n, n, 2*c)
        assert graph.size() == (n, n, 2*c)
        return graph
    
    single_graph(node)
    
    tensor([[[1, 2, 3, 1, 2, 3],
             [1, 2, 3, 4, 5, 6]],
    
            [[4, 5, 6, 1, 2, 3],
             [4, 5, 6, 4, 5, 6]]])
    

    带有 batch size 的 Grap 构建

    前面介绍了没有 batch size 的构建方式,但是 很多时候是有 batch size 的那么构建方式就发生了变化。

    batch_node = torch.tensor([
                                [
                                    [1, 2, 3],
                                    [4, 5, 6]
                                ],
                                [
                                    [7, 8, 9],
                                    [10, 11, 12]
                                ]
                            ])
    
    def batch_graph(batch_node: torch.Tensor):
        
        assert batch_node.dim() == 3
        
        batch_size = batch_node.size(0)
        n = batch_node.size(1)
        c = batch_node.size(2)
        
        print(f"batch node shape: {batch_node.size()}")
        
        # 对 node 进行repeat, batch_size 不变, c 进行 n 次重复
        
        repeat_1 = batch_node.repeat(1, 1, n)
        
        print(f"repeat_1 shape: {repeat_1.size()} \n {repeat_1}")
        
        # view 转换回正确的数据
        repeat_1 = repeat_1.view(-1, n*n, c)
        
        print(f"repeat_1 shape: {repeat_1.size()}")
        
        assert repeat_1.size() == (batch_size, n*n, c)
    
        repeat_2 = batch_node.repeat(1, n, 1)
        
        assert repeat_2.size() == (batch_size, n*n, c)
        
        graph = torch.cat((repeat_1, repeat_2), dim=-1)
        
        assert graph.size() == (batch_size, n*n, 2*c)
        
        graph = graph.view(-1, n, n, 2*c)
        
        return graph
    
    graph = batch_graph(batch_node)
    
    print(f"graph size: {graph.size()} \n {graph}")
        
    
    batch node shape: torch.Size([2, 2, 3])
    repeat_1 shape: torch.Size([2, 2, 6]) 
     tensor([[[ 1,  2,  3,  1,  2,  3],
             [ 4,  5,  6,  4,  5,  6]],
    
            [[ 7,  8,  9,  7,  8,  9],
             [10, 11, 12, 10, 11, 12]]])
    repeat_1 shape: torch.Size([2, 4, 3])
    graph size: torch.Size([2, 2, 2, 6]) 
     tensor([[[[ 1,  2,  3,  1,  2,  3],
              [ 1,  2,  3,  4,  5,  6]],
    
             [[ 4,  5,  6,  1,  2,  3],
              [ 4,  5,  6,  4,  5,  6]]],
    
    
            [[[ 7,  8,  9,  7,  8,  9],
              [ 7,  8,  9, 10, 11, 12]],
    
             [[10, 11, 12,  7,  8,  9],
              [10, 11, 12, 10, 11, 12]]]])
    
    
    

    相关文章

      网友评论

          本文标题:pytorch repeat 解析

          本文链接:https://www.haomeiwen.com/subject/puiaxktx.html