美文网首页
PyTorch torch.optim 传入两个网络参数

PyTorch torch.optim 传入两个网络参数

作者: 人生一场梦163 | 来源:发表于2019-11-15 15:57 被阅读0次
    CLASS torch.optim.Adam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
    

    以Adam优化器为例,其params定义如下:

    • params (iterable) – iterable of parameters to optimize or dicts defining parameter groups

    所以我们传入的是一个迭代器,可以通过tertools.chain将两个网络参数连接起来。

    import itertools
    ......
    optimizer = torch.optim.Adam(itretools.chain(net1.parameters(), net2.parameters()), 0.001, weight_decay = 1e-5)
    

    相关文章

      网友评论

          本文标题:PyTorch torch.optim 传入两个网络参数

          本文链接:https://www.haomeiwen.com/subject/iydsictx.html