解决办法参考: Get the mean from a list of tensors
问题背景
最近跑一个Siamese-FC的复现程序,要求配置是python2.7+pytorch0.4,之前安装的是Pytorch1.0,降低版本下载过慢多次失败,最终选择在Pytorch1.0版本下解决这个问题。
问题描述
项目地址:https://github.com/zzwang058/SiamFC-PyTorch
在运行run_Train_SiamFC.py中
print ("Epoch %d training loss: %f, validation loss: %f" % (i+1, np.mean(train_loss), np.mean(val_loss)))
是Pytorch1.0存在的问题,似乎是因为对张量求平均?
解决办法如下:
将 np.mean(a) 替换为 torch.mean(torch.stack(a))
网友评论