美文网首页
PyTorch中的modules()

PyTorch中的modules()

作者: 0error_ | 来源:发表于2020-03-20 20:22 被阅读0次

    中文文档:https://pytorch-cn.readthedocs.io/zh/latest/package_references/torch-nn/#modules

    英文文档:https://pytorch.org/docs/stable/index.html

    看源码的时候遇到,查一下。

    中文文档的解释:

    返回一个包含 当前模型 所有模块的迭代器。

    import torch.nn as nn
    
    class Model(nn.Module):
        def __init__(self):
            super(Model, self).__init__()
            submodule = nn.Conv2d(10, 20, 4)
            self.add_module("conv", submodule)
            self.add_module("conv1", submodule)
    model = Model()
    
    for module in model.modules():
        print(module)
    

    输出:
    Model (
    (conv): Conv2d(10, 20, kernel_size=(4, 4), stride=(1, 1))
    (conv1): Conv2d(10, 20, kernel_size=(4, 4), stride=(1, 1))
    )
    Conv2d(10, 20, kernel_size=(4, 4), stride=(1, 1))

    注意自己写网络的时候,要继承nn.Module类。
    class Model(nn.Module):

    这几篇文章讲了具体的例子:
    https://blog.csdn.net/u012609509/article/details/81203436

    https://blog.csdn.net/qq_27825451/article/details/90705328

    https://zhuanlan.zhihu.com/p/40350673

    相关文章

      网友评论

          本文标题:PyTorch中的modules()

          本文链接:https://www.haomeiwen.com/subject/gpssyhtx.html