作者
第一完成单位为悉尼科技大学
论文:https://arxiv.org/pdf/1804.06872.pdf
代码: https://github.com/bhanML/Co-teaching
背景
co-teaching是nips2018上的文章。主要解决的是noisy label的问题。
基础观点
- DNN在训练的时候,一开始是从clean label中获取信息,训练后期倾向于拟合噪声。
- 训练前期,loss低的倾向于是clean的label。
方法
最右为co-teachingM-Net 利用DNN本身进行噪声修正,或造成噪声积累。
Co-teaching的思想相当于把当前网络认为clean的样本交给另外的网络训练,避免了噪声累积。网络1把自己认为干净的一部分样本交给网络2,网络2把自己认为干净一部分的样本交给网络1.
随着训练的进行(epoch变大),互相认为干净的这部分比例逐步减少 ,也就是一开始给网络的数据很多(不管是含有噪声还是没有噪声都扔给网络训练),最后给网络的训练数据逐步减少,避免了网络对于噪声数据的拟合,对应基础观点1.
算法具体流程如下:
算法流程.png代码
核心代码为梯度互相更新的部分, 相当于先对每个网络各求了loss,再对loss进行排序,选取最低的一定比例交给另一个网络进行反向传播。
loss部分代码如下:
# Loss functions
def loss_coteaching(y_1, y_2, t, forget_rate, ind, noise_or_not):
loss_1 = F.cross_entropy(y_1, t, reduce = False)
ind_1_sorted = np.argsort(loss_1.data).cuda()
loss_1_sorted = loss_1[ind_1_sorted]
loss_2 = F.cross_entropy(y_2, t, reduce = False)
ind_2_sorted = np.argsort(loss_2.data).cuda()
loss_2_sorted = loss_2[ind_2_sorted]
remember_rate = 1 - forget_rate
num_remember = int(remember_rate * len(loss_1_sorted))
pure_ratio_1 = np.sum(noise_or_not[ind[ind_1_sorted[:num_remember]]])/float(num_remember)
pure_ratio_2 = np.sum(noise_or_not[ind[ind_2_sorted[:num_remember]]])/float(num_remember)
ind_1_update=ind_1_sorted[:num_remember]
ind_2_update=ind_2_sorted[:num_remember]
# exchange
loss_1_update = F.cross_entropy(y_1[ind_2_update], t[ind_2_update])
loss_2_update = F.cross_entropy(y_2[ind_1_update], t[ind_1_update])
return torch.sum(loss_1_update)/num_remember, torch.sum(loss_2_update)/num_remember, pure_ratio_1, pure_ratio_2
forget rate更新代码:
# define drop rate schedule
rate_schedule = np.ones(args.n_epoch)*forget_rate
rate_schedule[:args.num_gradual] = np.linspace(0, forget_rate**args.exponent, args.num_gradual)
调用:
loss_1, loss_2, pure_ratio_1, pure_ratio_2 = loss_coteaching(logits1, logits2, labels, rate_schedule[epoch], ind, noise_or_not)
网友评论