美文网首页
PyTorch复现SRGAN算法核心代码(带注释)

PyTorch复现SRGAN算法核心代码(带注释)

作者: 上弦同学 | 来源:发表于2019-03-09 09:26 被阅读0次

    GitHub地址 : https://github.com/SummerChaser/SRGAN-pytorch

    train.py

    import argparse
    import os
    from math import log10
    import pandas as pd
    import torch.optim as optim
    import torch.utils.data
    import torchvision.utils as utils
    from torch.autograd import Variable
    from torch.utils.data import DataLoader
    from tqdm import tqdm  # 进度条
    import pytorch_ssim
    from data_utils import TrainDatasetFromFolder, ValDatasetFromFolder, display_transform
    from loss import GeneratorLoss
    from model import Generator, Discriminator
    
    
    # ----- Command-line options: crop size, SR upscale factor, number of epochs -----
    parser = argparse.ArgumentParser(description='Train Super Resolution Models')
    parser.add_argument('--crop_size', default=88, type=int, help='training images crop size')
    parser.add_argument('--upscale_factor', default=4, type=int, choices=[2, 4, 8],
                        help='super resolution upscale factor')
    parser.add_argument('--num_epochs', default=100, type=int, help='train epoch number')

    # Parse the options declared above into a namespace.
    opt = parser.parse_args()

    # Promote parsed options to module-level constants.
    CROP_SIZE = opt.crop_size
    UPSCALE_FACTOR = opt.upscale_factor
    NUM_EPOCHS = opt.num_epochs

    # Datasets: training pairs are cropped/downscaled on the fly; validation keeps full images.
    train_set = TrainDatasetFromFolder('data/VOC2012/train', crop_size=CROP_SIZE, upscale_factor=UPSCALE_FACTOR)
    val_set = ValDatasetFromFolder('data/VOC2012/val', upscale_factor=UPSCALE_FACTOR)

    # Batch loaders (validation uses batch_size=1 because image sizes vary).
    train_loader = DataLoader(dataset=train_set, num_workers=4, batch_size=64, shuffle=True)
    val_loader = DataLoader(dataset=val_set, num_workers=4, batch_size=1, shuffle=False)

    # Build generator / discriminator and report their parameter counts.
    netG = Generator(UPSCALE_FACTOR)
    print('# generator parameters:', sum(param.numel() for param in netG.parameters()))
    netD = Discriminator()
    print('# discriminator parameters:', sum(param.numel() for param in netD.parameters()))

    # Combined generator loss (image + adversarial + perceptual + TV).
    generator_criterion = GeneratorLoss()

    # Move networks and the loss (its VGG feature extractor) to the GPU when available.
    if torch.cuda.is_available():
        netG.cuda()
        netD.cuda()
        generator_criterion.cuda()

    # Adam optimizers with library-default hyper-parameters; step() applies one update.
    optimizerG = optim.Adam(netG.parameters())
    optimizerD = optim.Adam(netD.parameters())

    # Per-epoch history: losses, discriminator scores, PSNR and SSIM.
    results = {'d_loss': [], 'g_loss': [], 'd_score': [], 'g_score': [], 'psnr': [], 'ssim': []}

    # One epoch = one full pass over the training set.
    for epoch in range(1, NUM_EPOCHS + 1):
        train_bar = tqdm(train_loader)
        running_results = {'batch_sizes': 0, 'd_loss': 0, 'g_loss': 0, 'd_score': 0, 'g_score': 0}

        # Switch to train mode (enables BatchNorm statistics updates).
        netG.train()
        netD.train()

        for data, target in train_bar:
            batch_size = data.size(0)
            running_results['batch_sizes'] += batch_size

            ############################
            # (1) Update D network: maximize D(x)-1-D(G(z))
            ###########################
            real_img = target
            z = data
            if torch.cuda.is_available():
                real_img = real_img.cuda()
                z = z.cuda()
            fake_img = netG(z)

            netD.zero_grad()
            real_out = netD(real_img).mean()
            fake_out = netD(fake_img).mean()
            d_loss = 1 - real_out + fake_out
            # retain_graph: the generator update below reuses the same forward graph.
            d_loss.backward(retain_graph=True)
            optimizerD.step()

            ############################
            # (2) Update G network: minimize 1-D(G(z)) + Perception Loss + Image Loss + TV Loss
            ###########################
            netG.zero_grad()
            g_loss = generator_criterion(fake_out, fake_img, real_img)
            g_loss.backward()
            optimizerG.step()
            # Re-run the forward passes so the logged stats reflect the just-updated generator.
            fake_img = netG(z)
            fake_out = netD(fake_img).mean()

            g_loss = generator_criterion(fake_out, fake_img, real_img)
            running_results['g_loss'] += g_loss.item() * batch_size
            d_loss = 1 - real_out + fake_out

            running_results['d_loss'] += d_loss.item() * batch_size  # gap between real/fake through D
            running_results['d_score'] += real_out.item() * batch_size  # D's score on real images
            running_results['g_score'] += fake_out.item() * batch_size  # D's score on generated images

            # Progress bar shows running averages of losses and scores.
            train_bar.set_description(desc='[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f' % (
                epoch, NUM_EPOCHS, running_results['d_loss'] / running_results['batch_sizes'],
                running_results['g_loss'] / running_results['batch_sizes'],
                running_results['d_score'] / running_results['batch_sizes'],
                running_results['g_score'] / running_results['batch_sizes']))

        # ----- Validation: eval mode, no gradients, forward passes only -----
        netG.eval()
        out_path = 'training_results/SRF_' + str(UPSCALE_FACTOR) + '/'
        if not os.path.exists(out_path):
            os.makedirs(out_path)
        val_bar = tqdm(val_loader)
        valing_results = {'mse': 0, 'ssims': 0, 'psnr': 0, 'ssim': 0, 'batch_sizes': 0}
        val_images = []
        with torch.no_grad():
            for val_lr, val_hr_restore, val_hr in val_bar:
                batch_size = val_lr.size(0)
                # Number of validation samples processed so far.
                valing_results['batch_sizes'] += batch_size
                lr = val_lr
                hr = val_hr
                if torch.cuda.is_available():
                    lr = lr.cuda()
                    hr = hr.cuda()
                sr = netG(lr)

                batch_mse = ((sr - hr) ** 2).mean().item()
                valing_results['mse'] += batch_mse * batch_size
                batch_ssim = pytorch_ssim.ssim(sr, hr).item()
                valing_results['ssims'] += batch_ssim * batch_size
                # Images are in [0, 1], so the PSNR peak value is 1.
                valing_results['psnr'] = 10 * log10(1 / (valing_results['mse'] / valing_results['batch_sizes']))
                valing_results['ssim'] = valing_results['ssims'] / valing_results['batch_sizes']
                val_bar.set_description(
                    desc='[converting LR images to SR images] PSNR: %.4f dB SSIM: %.4f' % (
                        valing_results['psnr'], valing_results['ssim']))
                # Collect (bicubic restore, ground-truth HR, SR) triples for the comparison grid.
                val_images.extend(
                    [display_transform()(val_hr_restore.squeeze(0)), display_transform()(hr.data.cpu().squeeze(0)),
                     display_transform()(sr.data.cpu().squeeze(0))])
        val_images = torch.stack(val_images)
        # Split into chunks of 15 images (5 rows x 3 columns per saved grid); max(..., 1)
        # guards against a ZeroDivisionError when the validation set has fewer than 15 images.
        val_images = torch.chunk(val_images, max(val_images.size(0) // 15, 1))
        val_save_bar = tqdm(val_images, desc='[saving training results]')
        index = 1
        for image in val_save_bar:
            # Three images per row: bicubic restore | ground truth | super-resolved.
            image = utils.make_grid(image, nrow=3, padding=5)
            utils.save_image(image, out_path + 'epoch_%d_index_%d.png' % (epoch, index), padding=5)
            index += 1

        # Save model parameters (create the checkpoint directory on first use).
        os.makedirs('epochs', exist_ok=True)
        torch.save(netG.state_dict(), 'epochs/netG_epoch_%d_%d.pth' % (UPSCALE_FACTOR, epoch))
        torch.save(netD.state_dict(), 'epochs/netD_epoch_%d_%d.pth' % (UPSCALE_FACTOR, epoch))
        # Save loss / scores / psnr / ssim history.
        results['d_loss'].append(running_results['d_loss'] / running_results['batch_sizes'])
        results['g_loss'].append(running_results['g_loss'] / running_results['batch_sizes'])
        results['d_score'].append(running_results['d_score'] / running_results['batch_sizes'])
        results['g_score'].append(running_results['g_score'] / running_results['batch_sizes'])
        results['psnr'].append(valing_results['psnr'])
        results['ssim'].append(valing_results['ssim'])

        # Dump accumulated statistics to CSV every 10 epochs (epoch starts at 1,
        # so the old `and epoch != 0` check was redundant).
        if epoch % 10 == 0:
            out_path = 'statistics/'
            os.makedirs(out_path, exist_ok=True)
            data_frame = pd.DataFrame(
                data={'Loss_D': results['d_loss'], 'Loss_G': results['g_loss'], 'Score_D': results['d_score'],
                      'Score_G': results['g_score'], 'PSNR': results['psnr'], 'SSIM': results['ssim']},
                index=range(1, epoch + 1))
            data_frame.to_csv(out_path + 'srf_' + str(UPSCALE_FACTOR) + '_train_results.csv', index_label='Epoch')
    

    loss.py

    import torch
    from torch import nn
    from torchvision.models.vgg import vgg16
    
    
    class GeneratorLoss(nn.Module):
        """SRGAN generator loss: pixel MSE + adversarial + VGG perceptual + total variation."""

        def __init__(self):
            super(GeneratorLoss, self).__init__()
            # Frozen VGG16 feature extractor (first 31 feature layers) used as the
            # perceptual "loss network"; eval mode + no gradients keeps it fixed.
            vgg = vgg16(pretrained=True)
            network = nn.Sequential(*list(vgg.features)[:31]).eval()
            for weight in network.parameters():
                weight.requires_grad = False
            self.loss_network = network
            self.mse_loss = nn.MSELoss()
            self.tv_loss = TVLoss()

        def forward(self, out_labels, out_images, target_images):
            # Adversarial term: push D's output on generated images toward 1.
            adversarial = torch.mean(1 - out_labels)
            # Perceptual term: MSE between VGG feature maps of fake vs real.
            fake_features = self.loss_network(out_images)
            real_features = self.loss_network(target_images)
            perceptual = self.mse_loss(fake_features, real_features)
            # Pixel-space term.
            pixel = self.mse_loss(out_images, target_images)
            # Smoothness term.
            tv = self.tv_loss(out_images)
            # Weights follow the SRGAN reference implementation.
            return pixel + 0.001 * adversarial + 0.006 * perceptual + 2e-8 * tv
    
    
    class TVLoss(nn.Module):
        """Anisotropic total-variation loss, normalized per element and per batch sample."""

        def __init__(self, tv_loss_weight=1):
            super(TVLoss, self).__init__()
            self.tv_loss_weight = tv_loss_weight

        def forward(self, x):
            batch_size, _, height, width = x.size()
            # Squared differences between vertically / horizontally adjacent pixels.
            vert = (x[:, :, 1:, :] - x[:, :, :height - 1, :]).pow(2).sum()
            horiz = (x[:, :, :, 1:] - x[:, :, :, :width - 1]).pow(2).sum()
            # Element counts of the shifted slices, used for normalization.
            count_v = self.tensor_size(x[:, :, 1:, :])
            count_h = self.tensor_size(x[:, :, :, 1:])
            return self.tv_loss_weight * 2 * (vert / count_v + horiz / count_h) / batch_size

        @staticmethod
        def tensor_size(t):
            # Number of elements per batch sample (C * H * W).
            return t.size()[1] * t.size()[2] * t.size()[3]
    
    
    if __name__ == "__main__":
        # Smoke test: constructing GeneratorLoss downloads pretrained VGG16 weights
        # on first run, then prints the module tree.
        g_loss = GeneratorLoss()
        print(g_loss)
    
    

    model.py

    import math
    
    import torch.nn.functional as F
    from torch import nn
    
    # Generator: SRResNet-style network (residual body + sub-pixel upsampling head).
    class Generator(nn.Module):
        """SRGAN generator. `scale_factor` must be a power of two: each UpsampleBLock
        doubles the spatial resolution, so log2(scale_factor) blocks are stacked."""

        def __init__(self, scale_factor):
            upsample_block_num = int(math.log(scale_factor, 2))

            super(Generator, self).__init__()
            self.block1 = nn.Sequential(
                nn.Conv2d(3, 64, kernel_size=9, padding=4),
                nn.PReLU()
            )
            self.block2 = ResidualBlock(64)
            self.block3 = ResidualBlock(64)
            self.block4 = ResidualBlock(64)
            self.block5 = ResidualBlock(64)
            self.block6 = ResidualBlock(64)
            self.block7 = nn.Sequential(
                nn.Conv2d(64, 64, kernel_size=3, padding=1),
                nn.PReLU()
            )
            # One x2 sub-pixel block per factor of two, then a final 9x9 conv back to RGB.
            block8 = [UpsampleBLock(64, 2) for _ in range(upsample_block_num)]
            block8.append(nn.Conv2d(64, 3, kernel_size=9, padding=4))
            self.block8 = nn.Sequential(*block8)

        def forward(self, x):
            block1 = self.block1(x)
            block2 = self.block2(block1)
            block3 = self.block3(block2)
            block4 = self.block4(block3)
            block5 = self.block5(block4)
            block6 = self.block6(block5)
            block7 = self.block7(block6)
            # Long skip connection from block1 around the residual body.
            block8 = self.block8(block1 + block7)

            # Tensor.tanh() replaces F.tanh, which is removed in modern PyTorch;
            # rescale the tanh output from [-1, 1] to the image range [0, 1].
            return (block8.tanh() + 1) / 2
    
    
    class Discriminator(nn.Module):
        """SRGAN discriminator: VGG-style conv stack with stride-2 downsampling,
        global average pooling, and 1x1 convs acting as a classifier head.
        Returns one real/fake probability in [0, 1] per batch sample."""

        def __init__(self):
            super(Discriminator, self).__init__()
            self.net = nn.Sequential(
                nn.Conv2d(3, 64, kernel_size=3, padding=1),
                nn.LeakyReLU(0.2),

                nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
                nn.BatchNorm2d(64),
                nn.LeakyReLU(0.2),

                nn.Conv2d(64, 128, kernel_size=3, padding=1),
                nn.BatchNorm2d(128),
                nn.LeakyReLU(0.2),

                nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1),
                nn.BatchNorm2d(128),
                nn.LeakyReLU(0.2),

                nn.Conv2d(128, 256, kernel_size=3, padding=1),
                nn.BatchNorm2d(256),
                nn.LeakyReLU(0.2),

                nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1),
                nn.BatchNorm2d(256),
                nn.LeakyReLU(0.2),

                nn.Conv2d(256, 512, kernel_size=3, padding=1),
                nn.BatchNorm2d(512),
                nn.LeakyReLU(0.2),

                nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
                nn.BatchNorm2d(512),
                nn.LeakyReLU(0.2),

                # Pool to 1x1 so the head works for any input resolution.
                nn.AdaptiveAvgPool2d(1),
                nn.Conv2d(512, 1024, kernel_size=1),
                nn.LeakyReLU(0.2),
                nn.Conv2d(1024, 1, kernel_size=1)
            )

        def forward(self, x):
            batch_size = x.size(0)
            # Tensor.sigmoid() replaces the deprecated/removed F.sigmoid.
            return self.net(x).view(batch_size).sigmoid()
    
    # Residual block with identity skip connection.
    class ResidualBlock(nn.Module):
        """Computes x + BN(Conv(PReLU(BN(Conv(x))))); channel count is preserved."""

        def __init__(self, channels):
            super(ResidualBlock, self).__init__()
            self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
            self.bn1 = nn.BatchNorm2d(channels)
            self.prelu = nn.PReLU()
            self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
            self.bn2 = nn.BatchNorm2d(channels)

        def forward(self, x):
            # Conv -> BN -> PReLU -> Conv -> BN, then add the input back.
            branch = self.bn2(self.conv2(self.prelu(self.bn1(self.conv1(x)))))
            return x + branch
    
    
    class UpsampleBLock(nn.Module):
        """Sub-pixel upsampling: conv expands channels by up_scale^2, PixelShuffle
        rearranges them into an up_scale-times larger spatial grid, then PReLU."""

        def __init__(self, in_channels, up_scale):
            super(UpsampleBLock, self).__init__()
            self.conv = nn.Conv2d(in_channels, in_channels * up_scale ** 2, kernel_size=3, padding=1)
            self.pixel_shuffle = nn.PixelShuffle(up_scale)
            self.prelu = nn.PReLU()

        def forward(self, x):
            return self.prelu(self.pixel_shuffle(self.conv(x)))
    
    

    data_utils.py

    from os import listdir
    from os.path import join
    
    from PIL import Image
    from torch.utils.data.dataset import Dataset
    # torchvision.transforms - 图像预处理包
    # Compose - 把多个步骤整合一起
    from torchvision.transforms import Compose, RandomCrop, ToTensor, ToPILImage, CenterCrop, Resize
    
    
    # Check by extension whether a filename refers to an image file.
    def is_image_file(filename):
        """Return True when *filename* ends with one of the recognized image extensions."""
        return filename.endswith(('.png', '.jpg', '.jpeg', '.PNG', '.JPG', '.JPEG'))
    
    # Largest crop size that the upscale factor divides evenly.
    def calculate_valid_crop_size(crop_size, upscale_factor):
        """Round *crop_size* down to the nearest multiple of *upscale_factor*."""
        return (crop_size // upscale_factor) * upscale_factor
    
    
    def train_hr_transform(crop_size):
        """Pipeline producing an HR training patch: random crop, then PIL -> tensor."""
        steps = [
            RandomCrop(crop_size),   # crop at a random position
            ToTensor(),              # convert the PIL image to a tensor
        ]
        return Compose(steps)
    
    
    def train_lr_transform(crop_size, upscale_factor):
        """Pipeline deriving the LR input from an HR tensor via bicubic downscaling."""
        steps = [
            ToPILImage(),  # tensor back to PIL for resizing
            Resize(crop_size // upscale_factor, interpolation=Image.BICUBIC),
            ToTensor(),
        ]
        return Compose(steps)
    
    
    def display_transform():
        """Normalize any image to a 400x400 tensor for side-by-side display grids."""
        steps = [
            ToPILImage(),
            Resize(400),      # shorter side scaled to 400
            CenterCrop(400),  # then center-cropped to a square
            ToTensor(),
        ]
        return Compose(steps)
    
    
    # Training dataset built from a folder of images.
    class TrainDatasetFromFolder(Dataset):
        """Yields (lr_tensor, hr_tensor) pairs for every image file in *dataset_dir*."""

        def __init__(self, dataset_dir, crop_size, upscale_factor):
            super(TrainDatasetFromFolder, self).__init__()
            # Absolute paths of all image files in the directory.
            self.image_filenames = [join(dataset_dir, name)
                                    for name in listdir(dataset_dir) if is_image_file(name)]
            valid_crop = calculate_valid_crop_size(crop_size, upscale_factor)
            # HR and LR conversion pipelines share the validated crop size.
            self.hr_transform = train_hr_transform(valid_crop)
            self.lr_transform = train_lr_transform(valid_crop, upscale_factor)

        def __getitem__(self, index):
            # Build the HR patch first, then derive the matching LR patch from it.
            hr_image = self.hr_transform(Image.open(self.image_filenames[index]))
            return self.lr_transform(hr_image), hr_image

        def __len__(self):
            return len(self.image_filenames)
    
    # Validation dataset.
    class ValDatasetFromFolder(Dataset):
        """Yields (lr, bicubic-restored hr, hr) tensor triples, one per source image."""

        def __init__(self, dataset_dir, upscale_factor):
            super(ValDatasetFromFolder, self).__init__()
            self.upscale_factor = upscale_factor
            self.image_filenames = [join(dataset_dir, name)
                                    for name in listdir(dataset_dir) if is_image_file(name)]

        def __getitem__(self, index):
            # The original file is treated as the HR ground truth.
            hr_image = Image.open(self.image_filenames[index])
            w, h = hr_image.size
            # Largest centered square whose side the upscale factor divides evenly.
            crop_size = calculate_valid_crop_size(min(w, h), self.upscale_factor)
            hr_image = CenterCrop(crop_size)(hr_image)
            # Bicubic downscale to LR, then bicubic upscale back as the naive baseline.
            lr_image = Resize(crop_size // self.upscale_factor, interpolation=Image.BICUBIC)(hr_image)
            hr_restore_img = Resize(crop_size, interpolation=Image.BICUBIC)(lr_image)
            return ToTensor()(lr_image), ToTensor()(hr_restore_img), ToTensor()(hr_image)

        def __len__(self):
            return len(self.image_filenames)
    
    # Test dataset.
    class TestDatasetFromFolder(Dataset):
        """Test set with separate LR (data/) and HR (target/) folders per scale factor.
        Yields (name, lr, bicubic-restored hr, hr) per image pair."""

        def __init__(self, dataset_dir, upscale_factor):
            super(TestDatasetFromFolder, self).__init__()
            # LR inputs and HR targets live in sibling directories.
            self.lr_path = dataset_dir + '/SRF_' + str(upscale_factor) + '/data/'
            self.hr_path = dataset_dir + '/SRF_' + str(upscale_factor) + '/target/'
            self.upscale_factor = upscale_factor
            self.lr_filenames = [join(self.lr_path, name)
                                 for name in listdir(self.lr_path) if is_image_file(name)]
            self.hr_filenames = [join(self.hr_path, name)
                                 for name in listdir(self.hr_path) if is_image_file(name)]

        def __getitem__(self, index):
            lr_file = self.lr_filenames[index]
            image_name = lr_file.split('/')[-1]
            lr_image = Image.open(lr_file)
            hr_image = Image.open(self.hr_filenames[index])
            w, h = lr_image.size
            # Bicubic-upscale the LR image as the naive restoration baseline.
            restore = Resize((self.upscale_factor * h, self.upscale_factor * w), interpolation=Image.BICUBIC)
            hr_restore_img = restore(lr_image)
            return image_name, ToTensor()(lr_image), ToTensor()(hr_restore_img), ToTensor()(hr_image)

        def __len__(self):
            return len(self.lr_filenames)
    
    

    没有GPU,在自己mac电脑上测试了一下,训练集删减为10多张图片,跑了40个epoch,虽然高糊但是有个大概雏形出来了

    image.png image.png

    相关文章

      网友评论

          本文标题:PyTorch复现SRGAN算法核心代码(带注释)

          本文链接:https://www.haomeiwen.com/subject/heojpqtx.html