美文网首页pytorch
torch.max 和 torch.argmax的区别

torch.max 和 torch.argmax的区别

作者: 午字横 | 来源:发表于2022-11-17 10:05 被阅读0次

    推荐去看这个文章 https://www.jianshu.com/p/3ed11362b54f

    在分类问题中,通常需要使用max()函数对softmax函数的输出值进行操作,求出预测值索引,然后与标签进行比对,计算准确率。下面讲解一下torch.max()函数的输入及输出值都是什么,便于我们理解该函数。

    1:torch.max(input, dim)

    函数定义:
    torch.max(input, dim, max=None, max_indices=None, keepdim=False) -> (Tensor, LongTensor)
    作用:找出给定tensor的指定维度dim上的上的最大值,并返回最大值在该维度上的值和位置索引。

    输入

    input是softmax函数输出的一个tensor
    dim是max函数索引的维度0/1,0是每列的最大值,1是每行的最大值

    输出
    函数会返回两个tensor,第一个tensor是每行的最大值;第二个tensor是每行最大值的索引。

    应用举例:
    例1:返回相应维度上的最大值,并返回最大值的位置索引

    a=torch.randn(3,4)
    print(a)
    print(a.shape)
    b=torch.max(a,1)
    print(b)
    print(b.indices)
    
    >tensor([[ 0.0092, -0.6736, -1.1466, -2.2001],
            [-0.2323, -0.3589,  1.4158, -0.1154],
            [ 0.7965, -1.3123, -2.2986, -0.8566]])
    torch.Size([3, 4])
    torch.return_types.max(
    values=tensor([0.0092, 1.4158, 0.7965]),
    indices=tensor([0, 2, 0]))
    tensor([0, 2, 0])
    

    例2:如果max的参数只有一个tensor,则返回该tensor里所有值中的最大值。

    a=torch.randn(3,4)
    print(a)
    print(a.shape)
    b=torch.max(a)
    print(b)
    
    >tensor([[ 0.2871,  0.6765, -1.4023,  0.7667],
            [-0.8243, -0.4072, -0.6755,  2.3382],
            [ 0.7859, -0.0375, -0.0800,  1.0330]])
    torch.Size([3, 4])
    tensor(2.3382)
    
    

    例3:如果max的参数是两个相同shape的tensor,则返回两tensor对应的最大值的新tensor。

    ...
    

    1:torch.max

    函数定义
    torch.argmax(input, dim, keepdim=False) → LongTensor
    作用:返回输入张量中指定维度的最大值的索引。

    例2:不指定维度,返回整体上最大值的序号

    a=torch.randn(3,4)
    print(a)
    print(a.shape)
    b=torch.argmax(a)
    print(b)
    >tensor([[ 0.4901,  0.1444,  0.1232, -0.0455],
            [ 0.5649,  0.5250, -0.8006,  1.2138],
            [-0.4935, -0.3818, -0.0039, -0.5936]])
    torch.Size([3, 4])
    tensor(7)
    

    例2:指定维度:返回相应维度最大值的索引

    a=torch.randn(3,4)
    print(a)
    print(a.shape)
    b=torch.argmax(a,dim=1)
    print(b)
    >tensor([[-1.6544,  0.0162,  0.6156,  0.0326],
            [ 1.5378, -0.4793,  0.8182, -1.3668],
            [-0.3072, -0.7020, -0.0641,  0.3797]])
    torch.Size([3, 4])
    tensor([2, 0, 3])
    

    相关文章

      网友评论

        本文标题:torch.max 和 torch.argmax的区别

        本文链接:https://www.haomeiwen.com/subject/kexpxdtx.html