美文网首页
pytorch 中分类网络损失函数

pytorch 中分类网络损失函数

作者: 深度学习努力中 | 来源:发表于2021-02-01 11:27 被阅读0次


1、分类网络搭建

如图搭建简单的分类网络,以二分类为例:

二分类网络

2,10,2分别代表:输入的特征数,隐藏神经元的个数,输出的概率(one-hot编码)

prediction=net(x):概率可以为负数

[0.6,-0.1]:表示预测为0,最大概率的索引为0

[-0.08,6.5]:表示预测为1,最大概率的索引为1

2、分类损失函数

loss_func=torch.nn.CrossEntropyLoss()

3、网络训练,预测

网络训练

loss_func中输入的是直接从网络中输出的量(prediction=net(x))

为了查看其预测的标签,通过打印这句话即可:

print(torch.max(F.softmax(prediction), 1)[1])

1、F.softmax(prediction):将二分类概率都变成0-1之间的数,且相加为1

[0.3,0.7]:预测为1

[0.6,0.4]:预测为0

2、torch.max(F.softmax(prediction), 1):

torch.max(out,1):按行取最大,每行的最大值放到一个矩阵中

torch.max(out,0):按列取最大,每列的最大值放到一个矩阵中

这句话最后返回两个tensor:

values=tensor(最大的概率值)

indices=tensor(最大概率值的索引)

3、torch.max(F.softmax(prediction), 1)[1]:

只保留indices的tensor

实例:

*注:在有的地方我们会看到torch.max(a, 1).data.numpy()的写法,这是因为在早期的pytorch的版本中,variable变量和tenosr是不一样的数据格式,variable可以进行反向传播,tensor不可以,需要将variable转变成tensor再转变成numpy。现在的版本已经将variable和tenosr合并,所以只用torch.max(a,1).numpy()就可以了。

参考链接:https://www.jianshu.com/p/3ed11362b54f

相关文章

网友评论

      本文标题:pytorch 中分类网络损失函数

      本文链接:https://www.haomeiwen.com/subject/agjjtltx.html