import torch
from torch import nn
from torchvision.utils import make_grid
from torch.nn import init
import matplotlib.pyplot as plt
class Model(nn.Module):
    """Stack of single-channel 3x3 convolutions whose kernels are all ones.

    Every layer is ``nn.Conv2d(1, 1, 3, padding=1, bias=False)`` with its
    weight filled with 1, so each layer replaces a pixel by the sum of its
    3x3 neighbourhood (zero-padded at the border) and keeps the spatial size.
    Each extra layer widens the receptive field of an output pixel by one
    pixel on every side, i.e. ``2 * num_layers + 1`` pixels across.

    Args:
        num_layers: Number of convolution layers to chain.
    """

    def __init__(self, num_layers=18):
        super(Model, self).__init__()
        convs = [
            nn.Conv2d(1, 1, 3, padding=1, bias=False)
            for _ in range(num_layers)
        ]
        # Constant kernels make the output depend only on how many input
        # pixels contribute, not on a random initialisation.
        for conv in convs:
            init.constant_(conv.weight, 1)
        self.layers = nn.Sequential(*convs)

    def forward(self, x):
        """Apply all convolution layers in order and return the result."""
        return self.layers(x)
# Visualise which input pixels influence a single output pixel of Model:
# run an all-ones image through the network, back-propagate from the output
# pixel at (128, 128) only, and plot input, masked output and input gradient
# side by side.
model = Model()
x = torch.ones(1, 1, 256, 256)
x.requires_grad = True
y = model(x)

# Zero every output pixel except (128, 128) so the backward pass computes
# d y[0, 0, 128, 128] / d x.
mask = torch.zeros_like(y, dtype=torch.float)
mask[0, 0, 128, 128] = 1
y = y * mask

# The second argument of torch.autograd.backward is grad_tensors (the seed
# gradient for y), not the tensor to differentiate with respect to. Seed it
# with explicit ones instead of relying on x happening to be all ones.
torch.autograd.backward(y, grad_tensors=torch.ones_like(y))

# Detach before plotting: tensors that still require grad cannot be
# converted to NumPy arrays.
image = torch.cat([x.detach(), y.detach(), x.grad.detach()], dim=0)
image = make_grid(image)
# The values here are far larger than 1 and imshow clips float RGB data to
# [0, 1], so every non-zero pixel renders white. Dividing by image.max()
# first would show relative magnitudes instead.
image = image.permute(1, 2, 0)  # (C, H, W) -> (H, W, C) for matplotlib
plt.imshow(image.detach().cpu().numpy())
plt.show()
# Reader comments (web-page residue; not part of the script)