美文网首页
PyTorch中的modules()

PyTorch中的modules()

作者: 0error_ | 来源:发表于2020-03-20 20:22 被阅读0次

中文文档:https://pytorch-cn.readthedocs.io/zh/latest/package_references/torch-nn/#modules

英文文档:https://pytorch.org/docs/stable/index.html

看源码的时候遇到,查一下。

中文文档的解释:

返回一个包含 当前模型 所有模块的迭代器。

import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        submodule = nn.Conv2d(10, 20, 4)
        self.add_module("conv", submodule)
        self.add_module("conv1", submodule)
model = Model()

for module in model.modules():
    print(module)

输出:
Model (
(conv): Conv2d(10, 20, kernel_size=(4, 4), stride=(1, 1))
(conv1): Conv2d(10, 20, kernel_size=(4, 4), stride=(1, 1))
)
Conv2d(10, 20, kernel_size=(4, 4), stride=(1, 1))

注意自己写网络的时候,要继承nn.Module类。
class Model(nn.Module):

这几篇文章讲了具体的例子:
https://blog.csdn.net/u012609509/article/details/81203436

https://blog.csdn.net/qq_27825451/article/details/90705328

https://zhuanlan.zhihu.com/p/40350673

相关文章

网友评论

      本文标题:PyTorch中的modules()

      本文链接:https://www.haomeiwen.com/subject/gpssyhtx.html