Simple code:
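import torchvision.models as models

def print_model_parm_nums():
    # Build an untrained AlexNet; any nn.Module can be inspected the same way
    model = models.alexnet()
    # nelement() gives the number of scalar elements in each parameter tensor
    total = sum(param.nelement() for param in model.parameters())
    print(' + Number of params: %.2fM' % (total / 1e6))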
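The function above counts every parameter, including frozen ones. If you only want the parameters that will actually be updated (for example after freezing layers for fine-tuning), you can filter on `requires_grad`. This is a minimal sketch under that assumption: it uses `Tensor.numel()`, which is equivalent to `nelement()`, and the helper name `count_params` and the toy model are hypothetical, chosen so the totals can be checked by hand.

```python
import torch.nn as nn


def count_params(model: nn.Module, trainable_only: bool = False) -> int:
    """Return the number of scalar parameters in `model`.

    If trainable_only is True, parameters with requires_grad=False
    (e.g. frozen layers) are skipped.
    """
    return sum(
        p.numel()
        for p in model.parameters()
        if p.requires_grad or not trainable_only
    )


# Hypothetical toy model used only for illustration.
model = nn.Sequential(
    nn.Linear(10, 5),  # 10*5 weights + 5 biases = 55
    nn.ReLU(),         # no parameters
    nn.Linear(5, 2),   # 5*2 weights + 2 biases = 12
)

# Freeze the first layer to show the difference between the two counts.
for p in model[0].parameters():
    p.requires_grad = False

print(count_params(model))                       # 67
print(count_params(model, trainable_only=True))  # 12
```

Dividing either count by 1e6, as in the original function, gives the figure in millions.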
Title: Counting the number of parameters in PyTorch
Link: https://www.haomeiwen.com/subject/qnhzbqtx.html