Freeze BN in Pytorch

Author: Birdy潇 | Published 2021-03-08 15:35

First, a helper that puts every BatchNorm layer into eval mode:
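```python
def set_bn_eval(m):
    # Matches BatchNorm1d/2d/3d and SyncBatchNorm by class name
    classname = m.__class__.__name__
    if classname.find('BatchNorm') != -1:
        # eval(): normalize with running_mean/running_var and stop updating them.
        # The affine weight/bias still receive gradients.
        m.eval()
```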
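Because `eval()` only freezes the running statistics, the BN scale and shift (`weight`, `bias`) keep training. A minimal sketch of a stricter variant, assuming you also want to freeze those parameters. The helper name `freeze_bn` and its `freeze_affine` flag are hypothetical, and the `isinstance` check relies on the private base class `nn.modules.batchnorm._BatchNorm`, which may change between PyTorch versions:

```python
import torch.nn as nn


def freeze_bn(m, freeze_affine=True):
    """Hypothetical helper: put BN layers in eval mode and optionally freeze gamma/beta."""
    # _BatchNorm is the (private) base class of BatchNorm1d/2d/3d and SyncBatchNorm
    if isinstance(m, nn.modules.batchnorm._BatchNorm):
        m.eval()  # use running_mean/running_var and stop updating them
        if freeze_affine and m.affine:
            # weight/bias are None when affine=False, hence the guard above
            m.weight.requires_grad_(False)
            m.bias.requires_grad_(False)


model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
model.train()
model.apply(freeze_bn)
bn = model[1]
print(bn.training, model[0].training)  # False True
print(bn.weight.requires_grad, model[0].weight.requires_grad)  # False True
```

The conv layer is untouched: only BN modules change mode or lose `requires_grad`.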

Use `model.apply()` to call `set_bn_eval` recursively on every submodule, freezing BN during training:

```python
def train(model, data_loader, criterion, epoch):
    model.train()  # switch to train mode (resets every submodule, BN included)
    model.apply(set_bn_eval)  # must come after model.train(): freezes BN during training
    # ... training code ...
```
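The order of the two calls in `train()` matters, since `model.train()` flips every submodule back to training mode. A small sketch that checks whether BN's `running_mean` moves after one forward pass in each order; the toy model and the helper `stats_changed` are illustrative assumptions, not part of the original post:

```python
import torch
import torch.nn as nn


def set_bn_eval(m):
    classname = m.__class__.__name__
    if classname.find('BatchNorm') != -1:
        m.eval()


def stats_changed(freeze_first):
    """Hypothetical check: does one forward pass update BN's running_mean?"""
    torch.manual_seed(0)
    model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
    bn = model[1]
    before = bn.running_mean.clone()
    if freeze_first:
        model.apply(set_bn_eval)
        model.train()  # wrong order: puts BN back into training mode
    else:
        model.train()
        model.apply(set_bn_eval)  # order used in train() above
    model(torch.randn(16, 4))  # one forward pass
    return not torch.equal(before, bn.running_mean)


print(stats_changed(freeze_first=False))  # False: statistics frozen
print(stats_changed(freeze_first=True))   # True: statistics updated
```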

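Putting it together, a commonly used loop. Since `train()` re-applies `set_bn_eval` right after `model.train()` at the start of every epoch, BN stays frozen even if `test()` switches the model between modes: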

```python
def main():
    # ...
    for epoch in epochs:
        train(model, train_loader, criterion, epoch)
        test(model, eval_loader, epoch)
    # ...
```
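An alternative, assuming you control the model class, is to override `nn.Module.train()` so the freeze survives any call to `model.train()`, not just the one inside the training function. This is a sketch; the class `FrozenBNNet` and its layers are hypothetical:

```python
import torch.nn as nn


def set_bn_eval(m):
    classname = m.__class__.__name__
    if classname.find('BatchNorm') != -1:
        m.eval()


class FrozenBNNet(nn.Module):
    """Hypothetical model whose BN layers stay in eval mode after model.train()."""

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())

    def forward(self, x):
        return self.features(x)

    def train(self, mode=True):
        super().train(mode)  # sets .training on every submodule
        if mode:
            self.apply(set_bn_eval)  # then re-freeze the BN layers
        return self


model = FrozenBNNet()
model.train()
bn = model.features[1]
print(model.training, bn.training)  # True False
model.eval()   # eval() calls train(False), which needs no re-freeze
model.train()
print(model.training, bn.training)  # True False
```

With this override, `train()` no longer needs its own `model.apply(set_bn_eval)` call.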

Article link: https://www.haomeiwen.com/subject/fdipqltx.html