美文网首页
pytorch 乘法运算汇总与解析

pytorch 乘法运算汇总与解析

作者: 潘旭 | 来源:发表于2020-06-21 23:20 被阅读0次

    pytorch 有多种乘法运算,在这里做一次全面的总结。

    元素一一相乘

    该操作又称作 "哈达玛积", 简单来说就是 tensor 元素逐个相乘。这个操作,是通过 * 也就是常规的乘号操作符定义的操作结果。torch.mul 是等价的。

    import torch
    
    def element_by_element():
        
        x = torch.tensor([1, 2, 3])
        y = torch.tensor([4, 5, 6])
        
        return x * y, torch.mul(x, y)
    
    element_by_element()
    
    (tensor([ 4, 10, 18]), tensor([ 4, 10, 18]))
    

    这个操作是可以 broad cast 的。

    
    def element_by_element_broadcast():
        
        x = torch.tensor([1, 2, 3])
        y = 2
        
        return x * y
    
    element_by_element_broadcast()
    
    tensor([2, 4, 6])
    

    向量点乘

    torch.matmul: If both tensors are 1-dimensional, the dot product (scalar) is returned.

    如果都是1维的,返回的就是 dot product 结果

    def vec_dot_product():
        
        x = torch.tensor([1, 2, 3])
        y = torch.tensor([4, 5, 6])
        
        return torch.matmul(x, y)
    vec_dot_product()
    
    tensor(32)
    

    矩阵乘法

    torch.matmul: If both arguments are 2-dimensional, the matrix-matrix product is returned.

    如果都是2维,那么就是矩阵乘法的结果返回。与 torch.mm 是等价的,torch.mm 仅仅能处理的是矩阵乘法。

    def matrix_multiple():
        
        x = torch.tensor([
            [1, 2, 3],
            [4, 5, 6]
        ])
        y = torch.tensor([
            [7, 8],
            [9, 10],
            [11, 12]
        ])
        
        return torch.matmul(x, y), torch.mm(x, y)
    
    matrix_multiple()
    
    (tensor([[ 58,  64],
             [139, 154]]), tensor([[ 58,  64],
             [139, 154]]))
    

    vector 与 matrix 相乘

    torch.matmul: If the first argument is 1-dimensional and the second argument is 2-dimensional, a 1 is prepended to its dimension for the purpose of the matrix multiply. After the matrix multiply, the prepended dimension is removed.

    如果第一个是 vector, 第二个是 matrix, 会在 vector 中增加一个维度。也就是 vector 变成了 shape: (1 \times N) 与 matrix (N \times M) 相乘之后,变成 (1 \times M), 在结果中将 1 维 再去掉。

    def vec_matrix():
        x = torch.tensor([1, 2, 3])
        y = torch.tensor([
            [7, 8],
            [9, 10],
            [11, 12]
        ])
        
        return torch.matmul(x, y)
    
    vec_matrix()
    
    tensor([58, 64])
    

    matrix 与 vector 相乘

    同样的道理, vector会被扩充一个维度。

    def matrix_vec():
        x = torch.tensor([
            [1, 2, 3],
            [4, 5, 6]
        ])
        y = torch.tensor([
            7, 8, 9
        ])
        
        return torch.matmul(x, y)
    
    matrix_vec()
    
    tensor([ 50, 122])
    

    带有batch_size 的 broad cast乘法

    def batched_matrix_broadcasted_vector():
        x = torch.tensor([
            [
                [1, 2], [3, 4]
            ],
            [
                [5, 6], [7, 8]
            ]
        ])
        
        print(f"x shape: {x.size()} \n {x}")
        y = torch.tensor([1, 3])
        
        return torch.matmul(x, y)
    
    batched_matrix_broadcasted_vector()
    
    x shape: torch.Size([2, 2, 2]) 
     tensor([[[1, 2],
             [3, 4]],
    
            [[5, 6],
             [7, 8]]])
    
    
    
    
    
    tensor([[ 7, 15],
            [23, 31]])
    

    batched matrix x batched matrix

    def batched_matrix_batched_matrix():
        x = torch.tensor([
            [
                [1, 2, 1], [3, 4, 4]
            ],
            [
                [5, 6, 2], [7, 8, 0]
            ]
        ])
        
    
        y = torch.tensor([
            [
                [1, 2], 
                [3, 4], 
                [5, 6]
            ],
            [
                [7, 8], 
                [9, 10], 
                [1, 2]
            ]
        ])
        
        
        print(f"x shape: {x.size()} \n y shape: {y.size()}")
        return torch.matmul(x, y)
    
    xy = batched_matrix_batched_matrix()
    print(f"xy shape: {xy.size()} \n {xy}")
    
    x shape: torch.Size([2, 2, 3]) 
     y shape: torch.Size([2, 3, 2])
    xy shape: torch.Size([2, 2, 2]) 
     tensor([[[ 12,  16],
             [ 35,  46]],
    
            [[ 91, 104],
             [121, 136]]])
    

    上面的效果与 torch.bmm 是一样的。matmulbmm 功能更加强大,但是 bmm 的语义非常明确, bmm 处理的只能是 3维的。

    def batched_matrix_batched_matrix_bmm():
        x = torch.tensor([
            [
                [1, 2, 1], [3, 4, 4]
            ],
            [
                [5, 6, 2], [7, 8, 0]
            ]
        ])
        
    
        y = torch.tensor([
            [
                [1, 2], 
                [3, 4], 
                [5, 6]
            ],
            [
                [7, 8], 
                [9, 10], 
                [1, 2]
            ]
        ])
        
        
        print(f"x shape: {x.size()} \n y shape: {y.size()}")
        return torch.bmm(x, y)
    
    xy = batched_matrix_batched_matrix()
    print(f"xy shape: {xy.size()} \n {xy}")
    
    x shape: torch.Size([2, 2, 3]) 
     y shape: torch.Size([2, 3, 2])
    xy shape: torch.Size([2, 2, 2]) 
     tensor([[[ 12,  16],
             [ 35,  46]],
    
            [[ 91, 104],
             [121, 136]]])
    

    tensordot

    这个函数还没有特别清楚。

    def tesnordot():
        
        x = torch.tensor([
            [1, 2, 1], 
            [3, 4, 4]])
        
    
        y = torch.tensor([
            [7, 8], 
            [9, 10], 
            [1, 2]])
        
        print(f"x shape: {x.size()}, y shape: {y.size()}")
        return torch.tensordot(x, y, dims=([0], [1]))
    
    tesnordot()
    
    x shape: torch.Size([2, 3]), y shape: torch.Size([3, 2])
    
    
    
    
    
    tensor([[31, 39,  7],
            [46, 58, 10],
            [39, 49,  9]])
    
    
    

    相关文章

      网友评论

          本文标题:pytorch 乘法运算汇总与解析

          本文链接:https://www.haomeiwen.com/subject/hsozxktx.html