美文网首页
TensorFlow代码重构为 Pytorch 代码中要注意的地

TensorFlow代码重构为 Pytorch 代码中要注意的地

作者: LCG22 | 来源:发表于2020-12-01 18:03 被阅读0次

    注:下面的 TensorFlow 简称为 tf,Pytorch 简称为 torch

    1、tf 的 Variable 可替代为 torch 的 Variable (from torch.autograd import Variable),其中前者的参数 trainable=False,可替代为后者的 requires_grad=False 

    2、torch 的 Tensor 方法传入的值不能是数值,而是应该将数值转换为 [val] 这样的列表

    3、tf 跟 torch 的部分函数,例如 tf 的 optimizer 和 torch 的 optim 里的优化器,功能相同,但是具体的默认参数可能会不同。

    此时要不要保持一致呢?如果需要保持一致,那么会增加大量的工作,而且两者的参数会有部分不同的参数,这些不同的参数又该如何处理呢?

    如果不保持一致,会不会导致训练结果有差别呢?

    4、tensorflow 的数据维度默认为 (N, H, W, C),而 pytorch 的数据维度默认为 (N, C, H, W)

    5、pytorch 里 nn 的函数是要进行先实例化,再传入上一个节点的输出

    6、pytorch 的 Tensor 不要随便转换为 numpy 对象,否则可能会丢失反向传播的信息,导致出问题,例如模型的损失不收敛等

    相关文章

      网友评论

          本文标题:TensorFlow代码重构为 Pytorch 代码中要注意的地

          本文链接:https://www.haomeiwen.com/subject/gakxvktx.html