美文网首页
Pytorch 图片降噪实现

Pytorch 图片降噪实现

作者: dawsonenjoy | 来源:发表于2019-11-09 12:30 被阅读0次

思路

要训练一个图片降噪的模型,首先需要将添加噪声的图片作为输入数据,而原图作为输出目标数据,由于是对图像的处理,并且输入输出都是图片,所以可以用卷积提取特征,maxpool降维,再用upsampling升维回图片,原理还是挺简单的,网上示例也很多,这里就用pytorch代码实现一下

导入相关模块

由于这里使用mnist集,所以需要使用到torchvision模块导入数据,下面是使用到的模块:

from torchvision import datasets
import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt

基本配置

这里配置一下gpu环境,代码如下:

# Pick the compute device: CUDA when torch reports it as available, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Bare availability query kept from the original snippet; its result is not stored.
torch.cuda.is_available()

数据预处理

首先我们的输入数据应该是原图添加噪声后的图片,而输出则是原图,并且形状为:(batch, channel, height, width),因此这里在载入mnist集后,执行以下步骤:

  • 将数据reshape成对应格式
  • 给数据添加噪声,并且控制像素值还是在0到255之间
  • 对数据进行归一化
  • 划分训练和测试数据,这里选前50000条数据训练,剩下的10000条拿来测试

代码如下:

# Load the MNIST images and reshape them to (N, 1, 28, 28) float tensors on `device`.
dataset = datasets.MNIST('data/', download=True)
data = dataset.data.reshape(-1, 1, 28, 28).float().to(device)

# Noisy inputs: add uniform noise in [0, 80) to every pixel, clamp the result
# into [0, 255] and divide by 255. `rand_like` takes shape, dtype and device
# from `data`, so the sample count is not hard-coded and the noise is created
# directly on the target device instead of being copied over from the CPU.
data_x = (data + 80 * torch.rand_like(data)).clamp(0, 255) / 255.
# Clean targets: the original images divided by 255.
data_y = data / 255.
# Uncomment to preview a clean sample and its noisy counterpart:
# plt.imshow(data_y[0].cpu().squeeze())
# plt.show()
# plt.imshow(data_x[0].cpu().squeeze())
# plt.show()

# Split: the first 50000 samples are used for training, the rest for testing.
train_x, train_y = data_x[:50000], data_y[:50000]
test_x, test_y = data_x[50000:], data_y[50000:]

定义网络模型

这里的网络就用几次卷积加maxpool提取特征并降维之后,再通过upsampling升维,代码如下:

class net(nn.Module):
    """Small convolutional autoencoder that maps a noisy image to a denoised one.

    Two conv + max-pool stages shrink a (N, 1, 28, 28) input to (N, 32, 7, 7);
    two conv + nearest-neighbour upsampling stages and a final conv bring it
    back to (N, 1, 28, 28). Every convolution is followed by a sigmoid, so the
    output values lie in [0, 1].
    """

    def __init__(self):
        super().__init__()
        # Not used by forward(); kept so existing attribute access keeps working.
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        # (-1, 1, 28, 28) -> (-1, 32, 28, 28)
        self.layer1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        # (-1, 32, 28, 28) -> (-1, 32, 14, 14)
        self.layer2 = nn.MaxPool2d(2, stride=2)
        self.layer3 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
        # (-1, 32, 14, 14) -> (-1, 32, 7, 7)
        self.layer4 = nn.MaxPool2d(2, stride=2)
        self.layer5 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
        # nn.Upsample(mode="nearest") replaces the deprecated
        # nn.UpsamplingNearest2d; it has no parameters, so saved state dicts
        # remain loadable.
        # (-1, 32, 7, 7) -> (-1, 32, 14, 14)
        self.layer6 = nn.Upsample(scale_factor=2, mode="nearest")
        self.layer7 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
        # (-1, 32, 14, 14) -> (-1, 32, 28, 28)
        self.layer8 = nn.Upsample(scale_factor=2, mode="nearest")
        # (-1, 32, 28, 28) -> (-1, 1, 28, 28)
        self.layer9 = nn.Conv2d(32, 1, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Run the network on a batch `x` and return a tensor of the same spatial size."""
        # Down-sampling half: conv -> sigmoid -> max-pool, twice.
        x = self.layer2(self.sigmoid(self.layer1(x)))
        x = self.layer4(self.sigmoid(self.layer3(x)))
        # Up-sampling half: conv -> sigmoid -> upsample, twice.
        x = self.layer6(self.sigmoid(self.layer5(x)))
        x = self.layer8(self.sigmoid(self.layer7(x)))
        # Final conv collapses the 32 channels to 1; sigmoid bounds the output.
        return self.sigmoid(self.layer9(x))

# Build the network, then move its parameters onto `device`.
model = net()
model = model.to(device)

定义损失函数和优化器

这里损失函数就用简单的mse就行,优化器用adam,代码如下:

# Adam updates every model parameter with a 1e-3 learning rate.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Mean-squared error is the training criterion.
loss_fun = nn.MSELoss()

训练模型

batch_size = 1000
num_train = train_x.shape[0]

# Start training.
model.train()
for epoch in range(50):
    # Step through every mini-batch. The stop value is the full sample count,
    # so the final batch is no longer skipped.
    for start in range(0, num_train, batch_size):
        end = start + batch_size
        output = model(train_x[start:end])
        # nn.MSELoss is called as (input, target).
        loss = loss_fun(output, train_y[start:end])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # The test loss is only reported, never back-propagated, so skip the
    # autograd bookkeeping for this forward pass.
    with torch.no_grad():
        test_output = model(test_x)
        loss_test = loss_fun(test_output, test_y)
    # `Loss` is the loss of the last training batch of the epoch.
    print('Epoch: {}, Loss: {}, test_loss: {}'.format(epoch, loss.item(), loss_test.item()))
    # Save the weights after every epoch.
    torch.save(model.state_dict(), "autodecode.mdl")

测试模型

model.eval()
# Inference only: no gradients are needed for the outputs that get plotted.
with torch.no_grad():
    test_output = model(test_x[:1000]).cpu()
    train_output = model(train_x[:1000]).cpu()

# -----------------------------------
# Show the denoising result: one row per sample, columns are
# noisy input / network output / clean original.
n = 10
plt.figure(figsize=(10, 50))
for i in range(n):
    sample = i * 3 + 1
    for col, images in enumerate((test_x, test_output, test_y), start=1):
        plt.subplot(n, 3, i * 3 + col)
        pixels = images[sample].cpu().squeeze().detach().numpy() * 255.
        # `np.int` was removed from NumPy; the builtin `int` is the same dtype.
        plt.imshow(pixels.astype(int))
plt.show()

测试结果:


从左到右依次为:噪声图、降噪图、原图

完整代码

from torchvision import datasets
import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt

# -----------------------------------
# Basic configuration
# Pick the compute device: CUDA when torch reports it as available, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Bare availability query kept from the original script; its result is not stored.
torch.cuda.is_available()

# -----------------------------------
# Data preprocessing
# Load the MNIST images and reshape them to (N, 1, 28, 28) float tensors on `device`.
dataset = datasets.MNIST('data/', download=True)
data = dataset.data.reshape(-1, 1, 28, 28).float().to(device)

# Noisy inputs: add uniform noise in [0, 80) to every pixel, clamp the result
# into [0, 255] and divide by 255. `rand_like` takes shape, dtype and device
# from `data`, so the sample count is not hard-coded and the noise is created
# directly on the target device instead of being copied over from the CPU.
data_x = (data + 80 * torch.rand_like(data)).clamp(0, 255) / 255.
# Clean targets: the original images divided by 255.
data_y = data / 255.
# Uncomment to preview a clean sample and its noisy counterpart:
# plt.imshow(data_y[0].cpu().squeeze())
# plt.show()
# plt.imshow(data_x[0].cpu().squeeze())
# plt.show()

# Split: the first 50000 samples are used for training, the rest for testing.
train_x, train_y = data_x[:50000], data_y[:50000]
test_x, test_y = data_x[50000:], data_y[50000:]

# -----------------------------------
# Network definition
class net(nn.Module):
    """Small convolutional autoencoder that maps a noisy image to a denoised one.

    Two conv + max-pool stages shrink a (N, 1, 28, 28) input to (N, 32, 7, 7);
    two conv + nearest-neighbour upsampling stages and a final conv bring it
    back to (N, 1, 28, 28). Every convolution is followed by a sigmoid, so the
    output values lie in [0, 1].
    """

    def __init__(self):
        super().__init__()
        # Not used by forward(); kept so existing attribute access keeps working.
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        # (-1, 1, 28, 28) -> (-1, 32, 28, 28)
        self.layer1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        # (-1, 32, 28, 28) -> (-1, 32, 14, 14)
        self.layer2 = nn.MaxPool2d(2, stride=2)
        self.layer3 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
        # (-1, 32, 14, 14) -> (-1, 32, 7, 7)
        self.layer4 = nn.MaxPool2d(2, stride=2)
        self.layer5 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
        # nn.Upsample(mode="nearest") replaces the deprecated
        # nn.UpsamplingNearest2d; it has no parameters, so saved state dicts
        # remain loadable.
        # (-1, 32, 7, 7) -> (-1, 32, 14, 14)
        self.layer6 = nn.Upsample(scale_factor=2, mode="nearest")
        self.layer7 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
        # (-1, 32, 14, 14) -> (-1, 32, 28, 28)
        self.layer8 = nn.Upsample(scale_factor=2, mode="nearest")
        # (-1, 32, 28, 28) -> (-1, 1, 28, 28)
        self.layer9 = nn.Conv2d(32, 1, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Run the network on a batch `x` and return a tensor of the same spatial size."""
        # Down-sampling half: conv -> sigmoid -> max-pool, twice.
        x = self.layer2(self.sigmoid(self.layer1(x)))
        x = self.layer4(self.sigmoid(self.layer3(x)))
        # Up-sampling half: conv -> sigmoid -> upsample, twice.
        x = self.layer6(self.sigmoid(self.layer5(x)))
        x = self.layer8(self.sigmoid(self.layer7(x)))
        # Final conv collapses the 32 channels to 1; sigmoid bounds the output.
        return self.sigmoid(self.layer9(x))

# -----------------------------------
# Network, loss function and optimizer
# Build the network, then move its parameters onto `device`.
model = net()
model = model.to(device)
# Adam updates every model parameter with a 1e-3 learning rate.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Mean-squared error is the training criterion.
loss_fun = nn.MSELoss()

# -----------------------------------
# Train the model
# Uncomment to resume from previously saved weights:
# model.load_state_dict(torch.load("autodecode.mdl"))

batch_size = 1000
num_train = train_x.shape[0]

# Start training.
model.train()
for epoch in range(50):
    # Step through every mini-batch. The stop value is the full sample count,
    # so the final batch is no longer skipped.
    for start in range(0, num_train, batch_size):
        end = start + batch_size
        output = model(train_x[start:end])
        # nn.MSELoss is called as (input, target).
        loss = loss_fun(output, train_y[start:end])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # The test loss is only reported, never back-propagated, so skip the
    # autograd bookkeeping for this forward pass.
    with torch.no_grad():
        test_output = model(test_x)
        loss_test = loss_fun(test_output, test_y)
    # `Loss` is the loss of the last training batch of the epoch.
    print('Epoch: {}, Loss: {}, test_loss: {}'.format(epoch, loss.item(), loss_test.item()))
    # Save the weights after every epoch.
    torch.save(model.state_dict(), "autodecode.mdl")

# -----------------------------------
# Evaluate the model
model.eval()
# Inference only: no gradients are needed for the outputs that get plotted.
with torch.no_grad():
    test_output = model(test_x[:1000]).cpu()
    train_output = model(train_x[:1000]).cpu()

# -----------------------------------
# Show the denoising result: one row per sample, columns are
# noisy input / network output / clean original.
n = 10
plt.figure(figsize=(10, 50))
for i in range(n):
    sample = i * 3 + 1
    for col, images in enumerate((test_x, test_output, test_y), start=1):
        plt.subplot(n, 3, i * 3 + col)
        pixels = images[sample].cpu().squeeze().detach().numpy() * 255.
        # `np.int` was removed from NumPy; the builtin `int` is the same dtype.
        plt.imshow(pixels.astype(int))
plt.show()

相关文章

网友评论

      本文标题:Pytorch 图片降噪实现

      本文链接:https://www.haomeiwen.com/subject/jyjdbctx.html