美文网首页
[论文阅读] Co-teaching

[论文阅读] Co-teaching

作者: thinson | 来源:发表于2021-11-29 20:24 被阅读0次

    作者

    第一完成单位为悉尼科技大学
    论文:https://arxiv.org/pdf/1804.06872.pdf
    代码: https://github.com/bhanML/Co-teaching

    背景

    co-teaching是nips2018上的文章。主要解决的是noisy label的问题。

    基础观点

    1. DNN在训练的时候,一开始是从clean label中获取信息,训练后期倾向于拟合噪声。
    2. 训练前期,loss低的倾向于是clean的label。

    方法

    最右为co-teaching
    M-Net 利用DNN本身进行噪声修正,或造成噪声积累。
    Co-teaching的思想相当于把当前网络认为clean的样本交给另外的网络训练,避免了噪声累积。网络1把自己认为干净的一部分样本交给网络2,网络2把自己认为干净一部分的样本交给网络1.
    随着训练的进行(epoch变大),互相认为干净的这部分比例逐步减少 ,也就是一开始给网络的数据很多(不管是含有噪声还是没有噪声都扔给网络训练),最后给网络的训练数据逐步减少,避免了网络对于噪声数据的拟合,对应基础观点1.

    算法具体流程如下:

    算法流程.png

    代码

    核心代码为梯度互相更新的部分, 相当于先对每个网络各求了loss,再对loss进行排序,选取最低的一定比例交给另一个网络进行反向传播。

    loss部分代码如下:

    # Loss functions
    def loss_coteaching(y_1, y_2, t, forget_rate, ind, noise_or_not):
        loss_1 = F.cross_entropy(y_1, t, reduce = False)
        ind_1_sorted = np.argsort(loss_1.data).cuda()
        loss_1_sorted = loss_1[ind_1_sorted]
    
        loss_2 = F.cross_entropy(y_2, t, reduce = False)
        ind_2_sorted = np.argsort(loss_2.data).cuda()
        loss_2_sorted = loss_2[ind_2_sorted]
    
        remember_rate = 1 - forget_rate
        num_remember = int(remember_rate * len(loss_1_sorted))
    
        pure_ratio_1 = np.sum(noise_or_not[ind[ind_1_sorted[:num_remember]]])/float(num_remember)
        pure_ratio_2 = np.sum(noise_or_not[ind[ind_2_sorted[:num_remember]]])/float(num_remember)
    
        ind_1_update=ind_1_sorted[:num_remember]
        ind_2_update=ind_2_sorted[:num_remember]
        # exchange
        loss_1_update = F.cross_entropy(y_1[ind_2_update], t[ind_2_update])
        loss_2_update = F.cross_entropy(y_2[ind_1_update], t[ind_1_update])
    
        return torch.sum(loss_1_update)/num_remember, torch.sum(loss_2_update)/num_remember, pure_ratio_1, pure_ratio_2
    

    forget rate更新代码:

    # define drop rate schedule
    rate_schedule = np.ones(args.n_epoch)*forget_rate
    rate_schedule[:args.num_gradual] = np.linspace(0, forget_rate**args.exponent, args.num_gradual)
    

    调用:

    loss_1, loss_2, pure_ratio_1, pure_ratio_2 = loss_coteaching(logits1, logits2, labels, rate_schedule[epoch], ind, noise_or_not)
    

    相关文章

      网友评论

          本文标题:[论文阅读] Co-teaching

          本文链接:https://www.haomeiwen.com/subject/fwilxrtx.html