Approach
To train an image-denoising model, the noisy images serve as the input data and the original images as the target. Since both the input and the output are images, we can extract features with convolutions, downsample with max pooling, and then upsample back to the original resolution. The idea is straightforward and there are plenty of examples online; here it is implemented in PyTorch.
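As a minimal sketch of this input/target pairing (the add_noise helper and the toy tensor below are only illustrative, not part of the final code), the training data consists of (noisy image, clean image) pairs:
import torch

def add_noise(clean, amount=80.0):
    # Hypothetical helper: add uniform noise and keep pixel values within [0, 255]
    return (clean + amount * torch.rand_like(clean)).clamp(0, 255)

clean = torch.randint(0, 256, (1, 1, 28, 28)).float()  # a toy 28x28 grayscale "image"
noisy = add_noise(clean)
# The model takes `noisy` as input and is trained to reconstruct `clean`
print(noisy.shape, clean.shape)  # both torch.Size([1, 1, 28, 28])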
Importing the required modules
Since we use the MNIST dataset, torchvision is needed to load the data. The modules used are:
from torchvision import datasets
import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt
Basic configuration
Set up the GPU (with a CPU fallback) here:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(torch.cuda.is_available())  # check whether CUDA is available
Data preprocessing
The input should be the original images with noise added, and the target should be the originals themselves, both shaped (batch, channel, height, width). After loading the MNIST dataset, we therefore:
- reshape the data into that format
- add noise to the images while keeping pixel values within 0 to 255
- normalize the data
- split it into training and test sets: the first 50,000 samples for training, the remaining 10,000 for testing
The code is as follows (a small sanity check is sketched right after it):
dataset = datasets.MNIST('data/', download=True)
data = dataset.data.reshape(-1, 1, 28, 28).float().to(device)
data_x = (data + 80 * torch.rand(60000, 1, 28, 28).to(device)).clamp(0, 255) / 255.
# Add noise and normalize; data_x holds the noisy images
data_y = data / 255.
# Normalize; data_y holds the original (clean) images
# plt.imshow(data_y[0].cpu().squeeze())
# plt.show()
# plt.imshow(data_x[0].cpu().squeeze())
# plt.show()
# Split into training and test sets
train_x, train_y = data_x[:50000], data_y[:50000]
test_x, test_y = data_x[50000:], data_y[50000:]
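As a quick, optional sanity check on the result (a sketch, not part of the original code), the shapes and value ranges can be verified like this:
# Verify shapes and that both inputs and targets are normalized to [0, 1]
assert data_x.shape == (60000, 1, 28, 28) and data_y.shape == (60000, 1, 28, 28)
assert 0.0 <= data_x.min() and data_x.max() <= 1.0
assert 0.0 <= data_y.min() and data_y.max() <= 1.0
print(train_x.shape, test_x.shape)  # torch.Size([50000, 1, 28, 28]) torch.Size([10000, 1, 28, 28])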
Defining the network
The network stacks a few convolutions with max pooling to extract features and reduce the spatial size, then uses upsampling to restore the original resolution. The code is as follows (a quick shape check with a dummy input is sketched after the model is instantiated):
class net(nn.Module):
    def __init__(self):
        super(net, self).__init__()
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.layer1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        # (-1, 1, 28, 28) -> (-1, 32, 28, 28)
        self.layer2 = nn.MaxPool2d(2, stride=2)
        # (-1, 32, 28, 28) -> (-1, 32, 14, 14)
        self.layer3 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
        self.layer4 = nn.MaxPool2d(2, stride=2)
        # (-1, 32, 14, 14) -> (-1, 32, 7, 7)
        self.layer5 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
        self.layer6 = nn.UpsamplingNearest2d(scale_factor=2)
        # (-1, 32, 7, 7) -> (-1, 32, 14, 14)
        self.layer7 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
        self.layer8 = nn.UpsamplingNearest2d(scale_factor=2)
        # (-1, 32, 14, 14) -> (-1, 32, 28, 28)
        self.layer9 = nn.Conv2d(32, 1, kernel_size=3, stride=1, padding=1)
        # (-1, 32, 28, 28) -> (-1, 1, 28, 28)

    def forward(self, x):
        x = self.sigmoid(self.layer1(x))
        x = self.layer2(x)
        x = self.sigmoid(self.layer3(x))
        x = self.layer4(x)
        x = self.sigmoid(self.layer5(x))
        x = self.layer6(x)
        x = self.sigmoid(self.layer7(x))
        x = self.layer8(x)
        x = self.sigmoid(self.layer9(x))
        return x
model = net().to(device)
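As a quick check (a sketch with a dummy input, not part of the original code), a forward pass should return a tensor of the same shape as the input, confirming that the pooling and upsampling steps line up:
# Run a random batch through the model to verify the output shape
dummy = torch.rand(4, 1, 28, 28).to(device)
with torch.no_grad():
    out = model(dummy)
print(out.shape)  # expected: torch.Size([4, 1, 28, 28])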
Loss function and optimizer
Plain MSE is enough for the loss here, and Adam is used as the optimizer:
loss_fun = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
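For reference (a sketch, not part of the original code), nn.MSELoss with its default 'mean' reduction is simply the average squared difference over all pixels, which can be checked against a manual computation:
# nn.MSELoss (reduction='mean') equals the mean of squared element-wise differences
a = torch.rand(2, 1, 28, 28)
b = torch.rand(2, 1, 28, 28)
print(loss_fun(a, b).item(), ((a - b) ** 2).mean().item())  # the two values should match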
Training the model
batch_size = 1000
# Start training
model.train()
for epoch in range(50):
    for batch in range(0, 50000, batch_size):  # step through all 50,000 training samples in mini-batches
        output = model(train_x[batch: batch+batch_size])
        loss = loss_fun(output, train_y[batch: batch+batch_size])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # Evaluate on the test set once per epoch (no gradients needed)
    with torch.no_grad():
        test_output = model(test_x)
        loss_test = loss_fun(test_output, test_y)
    print('Epoch: {}, Loss: {}, test_loss: {}'.format(epoch, loss.item(), loss_test.item()))
torch.save(model.state_dict(), "autodecode.mdl")
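To reuse the trained weights later without retraining (a sketch based on the save call above; map_location just makes loading work on CPU-only machines), the checkpoint can be restored into a fresh model:
# Restore the saved weights into a new model instance
model = net().to(device)
model.load_state_dict(torch.load("autodecode.mdl", map_location=device))
model.eval()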
Testing the model
model.eval()
test_output = model(test_x[:1000]).cpu()
train_output = model(train_x[:1000]).cpu()
# -----------------------------------
# Show the denoising results side by side
n = 10
plt.figure(figsize=(10, 50))
for i in range(n):
    ax = plt.subplot(n, 3, i*3 + 1)
    plt.imshow((test_x[i*3 + 1].cpu().squeeze().detach().numpy() * 255.).astype(np.uint8))
    ax = plt.subplot(n, 3, i*3 + 2)
    plt.imshow((test_output[i*3 + 1].cpu().squeeze().detach().numpy() * 255.).astype(np.uint8))
    ax = plt.subplot(n, 3, i*3 + 3)
    plt.imshow((test_y[i*3 + 1].cpu().squeeze().detach().numpy() * 255.).astype(np.uint8))
plt.show()
Test results:
From left to right: noisy image, denoised image, original image.
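Beyond the visual comparison, a simple quantitative check (a sketch, not part of the original code) is the PSNR computed from the MSE between the denoised outputs and the clean targets, for images normalized to [0, 1]:
# PSNR in dB over the first 1000 test images; higher is better
mse = ((test_output - test_y[:1000].cpu()) ** 2).mean()
psnr = 10 * torch.log10(1.0 / mse)
print('Test PSNR: {:.2f} dB'.format(psnr.item()))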
Complete code
from torchvision import datasets
import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt
# -----------------------------------
# Basic configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(torch.cuda.is_available())  # check whether CUDA is available
# -----------------------------------
# Data preprocessing
dataset = datasets.MNIST('data/', download=True)
data = dataset.data.reshape(-1, 1, 28, 28).float().to(device)
data_x = (data + 80 * torch.rand(60000, 1, 28, 28).to(device)).clamp(0, 255) / 255.
# Add noise and normalize; data_x holds the noisy images
data_y = data / 255.
# Normalize; data_y holds the original (clean) images
# plt.imshow(data_y[0].cpu().squeeze())
# plt.show()
# plt.imshow(data_x[0].cpu().squeeze())
# plt.show()
# Split into training and test sets
train_x, train_y = data_x[:50000], data_y[:50000]
test_x, test_y = data_x[50000:], data_y[50000:]
# -----------------------------------
# Define the network
class net(nn.Module):
    def __init__(self):
        super(net, self).__init__()
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.layer1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        # (-1, 1, 28, 28) -> (-1, 32, 28, 28)
        self.layer2 = nn.MaxPool2d(2, stride=2)
        # (-1, 32, 28, 28) -> (-1, 32, 14, 14)
        self.layer3 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
        self.layer4 = nn.MaxPool2d(2, stride=2)
        # (-1, 32, 14, 14) -> (-1, 32, 7, 7)
        self.layer5 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
        self.layer6 = nn.UpsamplingNearest2d(scale_factor=2)
        # (-1, 32, 7, 7) -> (-1, 32, 14, 14)
        self.layer7 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
        self.layer8 = nn.UpsamplingNearest2d(scale_factor=2)
        # (-1, 32, 14, 14) -> (-1, 32, 28, 28)
        self.layer9 = nn.Conv2d(32, 1, kernel_size=3, stride=1, padding=1)
        # (-1, 32, 28, 28) -> (-1, 1, 28, 28)

    def forward(self, x):
        x = self.sigmoid(self.layer1(x))
        x = self.layer2(x)
        x = self.sigmoid(self.layer3(x))
        x = self.layer4(x)
        x = self.sigmoid(self.layer5(x))
        x = self.layer6(x)
        x = self.sigmoid(self.layer7(x))
        x = self.layer8(x)
        x = self.sigmoid(self.layer9(x))
        return x
# -----------------------------------
# Instantiate the network, loss function and optimizer
model = net().to(device)
loss_fun = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# -----------------------------------
# Train the model
# model.load_state_dict(torch.load("autodecode.mdl"))  # optionally resume from a previously saved model
batch_size = 1000
# Start training
model.train()
for epoch in range(50):
    for batch in range(0, 50000, batch_size):  # step through all 50,000 training samples in mini-batches
        output = model(train_x[batch: batch+batch_size])
        loss = loss_fun(output, train_y[batch: batch+batch_size])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # Evaluate on the test set once per epoch (no gradients needed)
    with torch.no_grad():
        test_output = model(test_x)
        loss_test = loss_fun(test_output, test_y)
    print('Epoch: {}, Loss: {}, test_loss: {}'.format(epoch, loss.item(), loss_test.item()))
torch.save(model.state_dict(), "autodecode.mdl")
# -----------------------------------
# Test the model
model.eval()
test_output = model(test_x[:1000]).cpu()
train_output = model(train_x[:1000]).cpu()
# -----------------------------------
# Show the denoising results side by side
n = 10
plt.figure(figsize=(10, 50))
for i in range(n):
    ax = plt.subplot(n, 3, i*3 + 1)
    plt.imshow((test_x[i*3 + 1].cpu().squeeze().detach().numpy() * 255.).astype(np.uint8))
    ax = plt.subplot(n, 3, i*3 + 2)
    plt.imshow((test_output[i*3 + 1].cpu().squeeze().detach().numpy() * 255.).astype(np.uint8))
    ax = plt.subplot(n, 3, i*3 + 3)
    plt.imshow((test_y[i*3 + 1].cpu().squeeze().detach().numpy() * 255.).astype(np.uint8))
plt.show()