美文网首页
pytorch中label转换

pytorch中label转换

作者: random_walk | 来源:发表于2018-12-31 09:21 被阅读0次

在pytorch中,损失函数计算的时候,经常需要将label转换为one-hot的形式,在pytorch中怎么转换呢,在pytorch中只需要如下即可

#假设我们有10类,batchsize是4,随机生成一组label
class_num = 10
batch_size = 4
label = torch.LongTensor(batch_size, 1).random_() 
#然后
one_hot = torch.zeros(batch_size, class_num).scatter_(1, label, 1)
#这里scatter_(dim, index, src)将src中数据根据index中的索引按照dim的方向填进input中。

值得注意的是,如果使用交叉熵作为损失函数,并不需要我们进行one-hot编码,因为已经替我们执行了这一操作,只需要我们传入的label是longtensor即可。


有的时候我们对多标签进行编码,我们可以用sklearn中提供的工具进行实现

from sklearn.preprocessing import MultiLabelBinarizer
mlb = MultiLabelBinarizer(classes=np.arange(0, num_classes))
label = mlb.fit_transform([tuple(label)])

相关文章

网友评论

      本文标题:pytorch中label转换

      本文链接:https://www.haomeiwen.com/subject/mawtlqtx.html