美文网首页
Pytorch函数

Pytorch函数

作者: _Cooper_ | 来源:发表于2018-07-01 23:03 被阅读63次
  • max()
    torch.max(input, dim)
    dim参数指出删去哪一维度,0-行,1-列;输出两个tensor,第一个得到最大值结果,第二个给出相对位置(0-index)
>>> a = torch.randn(4, 4)
>>> a
tensor([[-1.2360, -0.2942, -0.1222,  0.8475],
        [ 1.1949, -1.1127, -2.2379, -0.6702],
        [ 1.5717, -0.9207,  0.1297, -1.8768],
        [-0.6172,  1.0036, -0.6060, -0.2432]])
>>> torch.max(a, 1)
(tensor([ 0.8475,  1.1949,  1.5717,  1.0036]), tensor([ 3,  0,  0,  1]))

dim=1,删除列的维度,只有1列,每一行为该行最大值,第二个tensor给出该最大值所在的列数
等同于a.max(1)
例:在训练网络时

output = net(img)
 _, predicted = output.max(1)

output为对img的预测输出,batch行label列,每行是一个图片的输出,每次输出batch组。所以预测结果需要看每行的最大值,找每行最大值的位置。output.max(1)找到每行最大值,有两个tensor输出,第一个为最大值,第二个为最大值所在位置,所关注的是位置,所以第一个下划线_舍弃掉最大值。

相关文章

网友评论

      本文标题:Pytorch函数

      本文链接:https://www.haomeiwen.com/subject/pznhuftx.html