https://blog.csdn.net/qq_38410428/article/details/101102075
model.train()和model.eval()很重要是因为:Batch Normalization和Dropout两层。
如果模型中有BN层(Batch Normalization)和 Dropout(),需要在训练时添加model.train(),测试时添加model.eval()。以此保证在测试时保留特定的神经连接路径(dropout),以及不再更新全局均值和方差(BN)(全局值不是因为反向传播更新的而是每次有数据由momentum控制更新的)。
model.train()和model.eval()的添加位置:
def train(model, optimizer, epoch, train_loader, validation_loader):
model.train()
"""
错误的位置
"""
for batch_idx, (data, target) in experiment.batch_loop(iterable=train_loader):
model.train()
"""
正确的位置,保证每一个batch都能进入model.train()的模式
"""
data, target = Variable(data), Variable(target)
# Inference
output = model(data)
...
def test(model, test_loader):
model.eval()
...
网友评论