美文网首页
如何使用GAN生成手写数字识别图像——GAN生成式对抗神经网络

如何使用GAN生成手写数字识别图像——GAN生成式对抗神经网络

作者: 小黄不头秃 | 来源:发表于2023-06-12 01:23 被阅读0次

    GAN的全称叫做:生成式对抗神经网络。它是生成类模型的代表,今天我们来看一下这个网络是如何进行工作的,并且我们写一个GAN,让他来帮我们生成手写数字的图片。

    GAN的核心为,让一个简单的随机的分布,经过网络的计算后生成一个有规律的分布。例如一张图片或者一段声音。

    一、了解GAN

    GAN这个网络主要分成两个部分:generator(生成器)、discriminator(判别器)。他们两个的工作分别为,生成器用于生产图片,判别器用来判断生产出来的图片够不够真实。

    我们举一个简单的例子来说,生成器就像是假币制造工厂,判别器就像是验钞的部门。我们给假币制造厂输入棉花和颜料,让工厂给我们仿制出一批纸币,然后交给验钞部门去检验真假。随着时间的推移,工厂制造的假币就会越来越真,验钞部门见过的假币越多也越来越有经验。最后我们希望工厂制造出来的假币连验钞部门也分辨不出来到底是真是假。这个时候我们的GAN就训练成功了。

    GAN是很难训练出来的,首先generator和discriminator需要相互匹配,其中有一个出问题了,网络就无法训练起来。第二点,网络有可能缺少多样性。输出一张图片看似还不错,但是多输出几张,发现他们都长得一样。就需要另外的方法来解决多样性的问题。

    二、多种多样的GAN

    (1)conditional generation 有条件的生成

    例如:输入一段文字的描述,输出一张对应要求的图片。

    这里与前面的不同,这里可能需要对图片进行标注,打上label。那如果我们仅仅是打上标签作为输入,然后去生成一张图片,网络极有可能就是忽略掉你输入的条件,随便生成一张图片给你。这就需要label和图片同时满足要求才算通过。这里就要设计一个更加严格的discriminator。

    (2)Cycle GAN

    有时候我们的需求是:给真实的人物生成动漫形象。这里的动漫形象那要求和真实的人长得还挺像的。

    那如果用前面所提到的方法似乎都行不通,生成的图片肯定是随机生成,无法达到要求的。那么cycle gan的解决方案就是,我要生成出来的动漫形象还能还原出原来的人的样子。这样生成出来的动漫形象就不会跑偏了。

    三、代码实现GAN

    import argparse
    import os
    import numpy as np
    import math
    
    import torchvision.transforms as transforms
    from torchvision.utils import save_image
    
    from torch.utils.data import DataLoader
    from torchvision import datasets
    from torch.autograd import Variable
    
    import torch.nn as nn
    import torch.nn.functional as F
    import torch
    
    # 创建images文件夹,存在不覆盖
    os.makedirs("images", exist_ok=True)
    
    """
    ################################################################################################################
    (1)                                        超参数的选择和GPU检测
    ################################################################################################################
    """
    # 初始化参数,并将其构建为parser结构
    parser = argparse.ArgumentParser()
    parser.add_argument("--n_epochs", type=int, default=500, help="number of epochs of training")
    parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
    parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
    parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
    parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
    parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
    parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
    parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
    parser.add_argument("--channels", type=int, default=1, help="number of image channels")
    parser.add_argument("--sample_interval", type=int, default=400, help="interval betwen image samples")
    opt = parser.parse_args()
    print(opt) # 打印所有参数信息
    
    # 图像的尺寸:1*28*28 = 764
    img_shape = (opt.channels, opt.img_size, opt.img_size)
    
    # GPU是否可用
    cuda = True if torch.cuda.is_available() else False
    
    """
    ################################################################################################################
    (2)                                   GAN的生成器模型和判别器模型的构造
    ################################################################################################################
    """
    # GAN_生成器
    class Generator(nn.Module):
        # 初始化构造函数
        def __init__(self):
            super(Generator, self).__init__()
    
            # 全连接层块:全连接层+[BN]+LekyReLu
            def block(in_feat, out_feat, normalize=True):
                layers = [nn.Linear(in_feat, out_feat)]
                if normalize:
                    layers.append(nn.BatchNorm1d(out_feat, 0.8))
                layers.append(nn.LeakyReLU(0.2, inplace=True))
                return layers
            
            # 模型结构:全连接块*4 + 全连接层 + Tanh
            self.model = nn.Sequential(
                *block(opt.latent_dim, 128, normalize=False), # 100 -> 128
                *block(128, 256), # 128 -> 256
                *block(256, 512), # 256 -> 512
                *block(512, 1024), # 512 -> 1024
                nn.Linear(1024, int(np.prod(img_shape))), # 1024 -> 764
                nn.Tanh()
            )
    
        # 前向计算
        def forward(self, z):
            img = self.model(z) # torch.Size([64, 784])
            img = img.view(img.size(0), *img_shape) # 还原成单通道,宽高为28的图像 :torch.Size([64, 1, 28, 28])
            return img
    
    # GAN_判别器
    class Discriminator(nn.Module):
        # 初始化构造函数
        def __init__(self):
            super(Discriminator, self).__init__()
            # 判别器模型结构:全连接层 + leakyrelu + 全连接层 + leakyrelu + 全连接层 + sigmoid
            self.model = nn.Sequential(
                nn.Linear(int(np.prod(img_shape)), 512), # 764 -> 512
                nn.LeakyReLU(0.2, inplace=True),
                nn.Linear(512, 256), # 512 -> 256
                nn.LeakyReLU(0.2, inplace=True),
                nn.Linear(256, 1), # 256 -> 1,fake or not?
                nn.Sigmoid(),
            )
    
        # 前向计算
        def forward(self, img):
            img_flat = img.view(img.size(0), -1) # 将图像进行展平:torch.Size([64, 784])
            validity = self.model(img_flat) # 判别结果:torch.Size([64, 1])
            return validity
    
    """
    ################################################################################################################
    (3)                                            MNIST数据集制作
    ################################################################################################################
    """
    # 制作数据集
    os.makedirs("./data/mnist", exist_ok=True) # 创建文件夹
    # 数据构造器
    dataloader = torch.utils.data.DataLoader(
        # 在线下载MNIST数据集
        datasets.MNIST(
            "./data/mnist",
            train=True,
            download=True,
            transform=transforms.Compose(
                [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
            ),
        ),
        batch_size=opt.batch_size,
        shuffle=True,
    )
    
    """
    ################################################################################################################
    (4)                                            模型训练和结果保存
    ################################################################################################################
    """
    # 损失函数,使用BCELoss
    adversarial_loss = torch.nn.BCELoss()
    
    # 初始化生成器和判别器
    generator = Generator()
    discriminator = Discriminator()
    
    # 进行GPU训练
    if cuda:
        generator.cuda()
        discriminator.cuda()
        adversarial_loss.cuda()
    
    # 生成器和判别器的优化器初始化
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
    
    Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
    
    # 迭代开始
    for epoch in range(opt.n_epochs):
        for i, (imgs, _) in enumerate(dataloader):
            # 相当于标签:1为真实图片,0为生成图片
            valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False) # torch.Size([64, 1])
            fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False) # # torch.Size([64, 1])
    
            # 真实输入
            real_imgs = Variable(imgs.type(Tensor)) # torch.Size([64, 1, 28, 28])
    
            # -----------------
            #  训练GAN_生成器
            # -----------------
            optimizer_G.zero_grad() # 手动清除梯度
            # 采样随机噪声作为生成器的输入
            z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))
            gen_imgs = generator(z) # 生成一个批量的图片
            res = discriminator(gen_imgs) # 判别此批量生成图片的真假
            g_loss = adversarial_loss(res, valid) # 计算损失
    
            g_loss.backward() # 反向求导
            optimizer_G.step() # 更新生成器参数,不更新判别器参数
    
            # ---------------------
            #  训练GAN_判别器
            # ---------------------
            optimizer_D.zero_grad() # 手动清除梯度
            # 测量判别器判别真假的能力
            real_loss = adversarial_loss(discriminator(real_imgs), valid) # 判别真实图片是否为真的损失
            fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake) # 判别生成图片是否为假的损失
            d_loss = (real_loss + fake_loss) / 2 # 总损失
    
            d_loss.backward() # 反向求导
            optimizer_D.step() # 更新生成器参数,不更新判别器参数
    
            # 每个300个batch输出信息
            if i % 300 == 0: 
                print(
                    "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                    % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
                )
    
            # 保存生成图片
            batches_done = epoch * len(dataloader) + i
            if batches_done % opt.sample_interval*100 == 0:
                save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)
    
    epoch 0 epoch 1 epoch 240

    大家快动手试试吧!

    相关文章

      网友评论

          本文标题:如何使用GAN生成手写数字识别图像——GAN生成式对抗神经网络

          本文链接:https://www.haomeiwen.com/subject/mttuedtx.html