美文网首页
class_weight

class_weight

作者: 菌子甚毒 | 来源:发表于2022-05-22 12:22 被阅读0次
  1. 创建数据。
x = torch.randn(20, 5)  # 20个sample,5个类别各自的output
y = torch.randint(0, 5, (20,))  # 20个sample的真实标签,值为0-4。
print(Counter(y.numpy()))  # 输出各个类别有多少samples。
"""
x是model的output。
y是真实的label。
其中:
Counter({2: 3, 0: 5, 3: 5, 4: 4, 1: 3})
"""
  1. 使用sklearn自动计算class_weight。
import sklearn.utils.class_weight as class_weight
class_weights=class_weight.compute_class_weight(
                                 class_weight='balanced',
                                 classes=np.unique(y),
                                 y=y.numpy())
"""
需要给定3个值。
1. class_weight:计算class weight的方式。如果选择'balanced', 由下式计算:
    n_samples / (n_classes * np.bincount(y))
    即:
    对于类1,一共有20个samples,第1类有3个samples,一共有5类。
    class_weight_1 = 20/(3*5) = 1.3333
2. classes:有哪些类。
    此处classes=tensor([0.8000, 1.3333, 1.3333, 0.8000, 1.0000])
3. y: 真实label。

output:
class_weight:
[0.8000, 1.3333, 1.3333, 0.8000, 1.0000]
"""
  1. 在定义loss_function时将class_weight传入loss function。
class_weights = torch.Tensor(class_weights)
loss = nn.CrossEntropyLoss(weight=class_weights,reduction='mean')
# 此处weight输入必须是tensor格式。
  1. 计算loss。
loss_weighted = loss(x, y)
"""
output:
tensor(2.4286)
"""

总:

# 利用sklearn的class_weight计算weights
import torch
from collections import Counter
import numpy as np
import sklearn.utils.class_weight as class_weight
import torch.nn as nn 

output = torch.randn(20, 5)  # 20个sample,5个类别各自的output
y = torch.randint(0, 5, (20,))  # 20个sample的真实标签,值为0-4。
class_weights=torch.Tensor(class_weight.compute_class_weight(class_weight='balanced',classes=np.unique(y),y=y.numpy()))

loss_func = nn.CrossEntropyLoss(weight=class_weights)
loss = loss_func(output, y)
# 自定义class weight(每个类占的百分比的倒数)
output = torch.randn(20, 5)  # 20个sample,5个类别各自的output
y = torch.randint(0, 5, (20,))  # 20个sample的真实标签,值为0-4。

weights = torch.Tensor(1/(np.bincount(y)/output.shape[0]))  # 每个类占的百分比的倒数
loss_func = nn.CrossEntropyLoss(weight=weights)
loss = loss_func(output,y)

⚠️注意:以上两种办法算出来的weights虽然不一样,但是算出的loss是一样的。
⚠️注意:class weight必须是tensor float32的形式,否则报错!!!!!


相关链接:

  1. doc
  2. 参考:https://androidkt.com/how-to-use-class-weight-in-crossentropyloss-for-an-imbalanced-dataset/
  3. 相关语法

相关文章

网友评论

      本文标题:class_weight

      本文链接:https://www.haomeiwen.com/subject/dwjkprtx.html