想获取网络的中间输出,但是尝试后,发现
-
hook不好用
-
Sequential有时用不了
所以最终决定还是直接使用list保存,代码中的修改部分如下:
def forward(self, x):
all_output = [] #新增
x = x.view(x.size(0), -1)
x = F.relu(self.fc1(x))
all_output.append(x) #新增
x = F.relu(self.fc2(x))
all_output.append(x) #新增
x = self.fc3(x)
all_output.append(x) #新增
return x, all_output
# 训练fc
def FCN_train(lr, epochs, train_loader):
model = FCN(28 * 28, 256, 512, 10)
loss_func = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.5)
step_list = []
loss_list = []
for epoch in range(epochs):
for step, (x, y) in enumerate(train_loader):
output = model(x)[0] #修改,原本是model(x)
#第一层输出 all_output[0]
#第二层输出 all_output[1]
......
网友评论