美文网首页
TensorFlow代码重构为 Pytorch 代码中要注意的地

TensorFlow代码重构为 Pytorch 代码中要注意的地

作者: LCG22 | 来源:发表于2020-12-01 18:03 被阅读0次

注:下面的 TensorFlow 简称为 tf,Pytorch 简称为 torch

1、tf 的 Variable 可替代为 torch 的 Variable (from torch.autograd import Variable),其中前者的参数 trainable=False,可替代为后者的 requires_grad=False 

2、torch 的 Tensor 方法传入的值不能是数值,而是应该将数值转换为 [val] 这样的列表

3、tf 跟 torch 的部分函数,例如 tf 的 optimizer 和 torch 的 optim 里的优化器,功能相同,但是具体的默认参数可能会不同。

此时要不要保持一致呢?如果需要保持一致,那么会增加大量的工作,而且两者的参数会有部分不同的参数,这些不同的参数又该如何处理呢?

如果不保持一致,会不会导致训练结果有差别呢?

4、tensorflow 的数据维度默认为 (N, H, W, C),而 pytorch 的数据维度默认为 (N, C, H, W)

5、pytorch 里 nn 的函数是要进行先实例化,再传入上一个节点的输出

6、pytorch 的 Tensor 不要随便转换为 numpy 对象,否则可能会丢失反向传播的信息,导致出问题,例如模型的损失不收敛等

相关文章

网友评论

      本文标题:TensorFlow代码重构为 Pytorch 代码中要注意的地

      本文链接:https://www.haomeiwen.com/subject/gakxvktx.html