美文网首页
Pytorch中torch.max()函数解析

Pytorch中torch.max()函数解析

作者: 逍遥_yjz | 来源:发表于2022-10-29 10:02 被阅读0次

    一. torch.max()函数解析

    1. 官网链接

    torch.max,如下图所示:

    2. torch.max(input)函数解析

    torch.max(input) → Tensor
    
    

    将输入input张量,无论有几维,首先将其reshape排列成一个一维向量,然后找出这个一维向量里面最大值

    3. 代码举例

    3.1 输入一维张量,返回一维张量里面最大值

    x = torch.randn(4)
    y = torch.max(x)
    x,y
    
    
    输出结果如下:
    (tensor([-0.6223,  0.0043, -0.8753,  1.4240]), tensor(1.4240))
    
    

    3.2 输入二维张量,返回二维张量里面最大值

    x = torch.randn(3,4)
    y = torch.max(x)
    x,y
    
    
    输出结果如下:
    (tensor([[-1.1052,  0.1026,  0.9994, -0.3092],
             [-0.8400,  0.2004,  0.9212,  0.7807],
             [-1.2979, -0.4327,  2.3044,  0.0140]]),
     tensor(2.3044))
    
    

    3.3 输入两个一维张量,输出这两个张量里面相应元素中的最大值

    x = torch.randn(4)
    z = torch.randn(4)
    max = torch.max(x,z)
    x,z,max
    
    
    输出结果如下:
    (tensor([-1.5147, -1.2790, -1.0159, -0.4732]),
     tensor([-0.4547, -2.8545,  0.0554, -0.3548]),
     tensor([-0.4547, -1.2790,  0.0554, -0.3548]))
    
    

    3.4 输入两个张量,一个张量一维,一个张量二维,此时一维张量会进行广播成二维张量,然后再输出这两个张量里面相应元素中的最大值,输出张量为二维。

    x = torch.randn(3,4)
    z = torch.randn(4)
    max = torch.max(x,z)
    x,z,max
    
    
    输出结果如下:
    (tensor([[ 1.1917,  0.6338,  0.7590, -0.9802],
             [ 0.2247,  0.3635,  1.3743,  1.6229],
             [ 1.6165,  0.0634,  0.5259,  0.1285]]),
     tensor([3.4765, 0.4480, 0.1502, 0.3738]),
     tensor([[3.4765, 0.6338, 0.7590, 0.3738],
             [3.4765, 0.4480, 1.3743, 1.6229],
             [3.4765, 0.4480, 0.5259, 0.3738]]))
    
    

    3.5 输入两个二维张量,输出这两个张量里面相应元素中的最大值,输出张量为二维。

    x = torch.randn(3,4)
    z = torch.randn(3,4)
    max = torch.max(x,z)
    x,z,max
    
    
    输出结果如下:
    (tensor([[-0.0835,  0.0718, -1.7404, -0.3218],
             [ 0.0577,  0.6271,  1.4014, -0.6417],
             [ 0.3917,  0.0761,  1.2479, -0.4352]]),
     tensor([[-0.0717,  0.3822,  0.7256,  1.4147],
             [-0.1271,  0.1503,  0.3934,  1.6760],
             [-2.2341,  2.5286, -0.3500, -0.1751]]),
     tensor([[-0.0717,  0.3822,  0.7256,  1.4147],
             [ 0.0577,  0.6271,  1.4014,  1.6760],
             [ 0.3917,  2.5286,  1.2479, -0.1751]]))
    
    

    4. torch.max(input,dim)函数解析

    torch.max(input, dim, keepdim=False, *, out=None)
    
    

    输入input(二维)张量,当dim=0时表示找出每列的最大值,函数会返回两个tensor,第一个tensor是每列的最大值,第二个tensor是每列最大值的索引;当dim=1时表示找出每行的最大值,函数会返回两个tensor,第一个tensor是每行的最大值;第二个tensor是每行最大值的索引。

    5. 代码举例

    5.1 dim=0,找出每列的最大值,函数会返回两个tensor,第一个tensor是每列的最大值,第二个tensor是每列最大值的索引,两个tensor都是一维。

    x = torch.randn(3,4)
    max,indices = torch.max(x,dim=0)
    x,max,indices
    
    
    (tensor([[ 0.1806,  1.0274,  0.5138, -1.4184],
             [ 0.5892, -0.7117, -1.2707,  0.7682],
             [ 0.5152, -0.8803,  1.7604,  0.4852]]),
     torch.return_types.max(
     values=tensor([0.5892, 1.0274, 1.7604, 0.7682]),
     indices=tensor([1, 0, 2, 1])))
    
    
    输出结果如下:
    (tensor([[ 0.0190,  0.8180, -1.0463,  1.7940],
             [ 0.7537, -1.0291, -2.3431,  0.3906],
             [ 0.3715,  1.6940, -1.1200, -0.4580]]),
     tensor([ 0.7537,  1.6940, -1.0463,  1.7940]),
     tensor([1, 2, 0, 0]))
    
    

    5.2 dim=1,找出每行的最大值,函数会返回两个tensor,第一个tensor是每行的最大值;第二个tensor是每行最大值的索引,两个tensor都是一维。

    x = torch.randn(3,4)
    max,indices = torch.max(x,dim=1)
    x,max,indices
    
    
    输出结果如下:
    (tensor([[ 1.4832,  0.1886, -0.3044, -0.6111],
             [-0.8998,  0.0610,  0.3388,  1.7176],
             [ 1.6153,  0.6864,  2.3225,  1.3818]]),
     tensor([1.4832, 1.7176, 2.3225]),
     tensor([0, 3, 2]))
    
    

    参考知识文章

    相关文章

      网友评论

          本文标题:Pytorch中torch.max()函数解析

          本文链接:https://www.haomeiwen.com/subject/nclzzrtx.html