中文文档:https://pytorch-cn.readthedocs.io/zh/latest/package_references/torch-nn/#modules
英文文档:https://pytorch.org/docs/stable/index.html
看源码的时候遇到,查一下。
中文文档的解释:
返回一个包含 当前模型 所有模块的迭代器。
import torch.nn as nn
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
submodule = nn.Conv2d(10, 20, 4)
self.add_module("conv", submodule)
self.add_module("conv1", submodule)
model = Model()
for module in model.modules():
print(module)
输出:
Model (
(conv): Conv2d(10, 20, kernel_size=(4, 4), stride=(1, 1))
(conv1): Conv2d(10, 20, kernel_size=(4, 4), stride=(1, 1))
)
Conv2d(10, 20, kernel_size=(4, 4), stride=(1, 1))
注意自己写网络的时候,要继承nn.Module类。
class Model(nn.Module):
这几篇文章讲了具体的例子:
https://blog.csdn.net/u012609509/article/details/81203436
网友评论