美文网首页
论model.eval()的重要性

论model.eval()的重要性

作者: 菌子甚毒 | 来源:发表于2022-07-19 14:01 被阅读0次

    https://blog.csdn.net/qq_38410428/article/details/101102075

    model.train()和model.eval()很重要是因为:Batch Normalization和Dropout两层。

    如果模型中有BN层(Batch Normalization)和 Dropout(),需要在训练时添加model.train(),测试时添加model.eval()。以此保证在测试时保留特定的神经连接路径(dropout),以及不再更新全局均值和方差(BN)(全局值不是因为反向传播更新的而是每次有数据由momentum控制更新的)。

    model.train()和model.eval()的添加位置:

    def train(model, optimizer, epoch, train_loader, validation_loader):
        model.train() 
        """
        错误的位置
        """
        for batch_idx, (data, target) in experiment.batch_loop(iterable=train_loader):
            model.train()  
           """
            正确的位置,保证每一个batch都能进入model.train()的模式
           """
            data, target = Variable(data), Variable(target)
            # Inference
            output = model(data)
            ...
    
    def test(model, test_loader):
        model.eval()
        ...
    
    

    相关文章

      网友评论

          本文标题:论model.eval()的重要性

          本文链接:https://www.haomeiwen.com/subject/ezojirtx.html