WGAN

作者: scpy | 来源:发表于2018-12-24 13:44 被阅读0次
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.utils import save_image
import os
import numpy as np
from torch.utils.data import DataLoader
from torchvision import datasets


class generator(nn.Module):
    """MLP generator: maps a 100-dim noise vector to a 784-dim (28x28) image.

    The output is squashed to [-1, 1] by tanh so it matches real images
    preprocessed with Normalize (pixel range [-1, 1]).
    """

    def __init__(self):
        super(generator, self).__init__()
        self.fc1 = nn.Linear(100, 128)
        self.fc11 = nn.Linear(128, 256)
        self.fc2 = nn.Linear(256, 512)
        self.fc3 = nn.Linear(512, 1024)
        self.out = nn.Linear(1024, 784)

    def forward(self, x):
        """x: (batch, 100) noise -> (batch, 784) image in [-1, 1]."""
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc11(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        # F.tanh is deprecated (and removed in recent PyTorch);
        # torch.tanh is the supported, numerically identical replacement.
        x = torch.tanh(self.out(x))
        return x


class discriminator(nn.Module):
    """MLP critic for WGAN.

    Flattens the input image and returns a single unbounded score per
    sample. Unlike a vanilla GAN discriminator there is NO sigmoid on the
    last layer: WGAN uses the raw critic score to estimate the
    Wasserstein distance.
    """

    def __init__(self):
        super(discriminator, self).__init__()
        self.fc1 = nn.Linear(784, 512)
        self.fc2 = nn.Linear(512, 256)
        self.out = nn.Linear(256, 1)

    def forward(self, input):
        """input: (batch, ...) image tensor -> (batch, 1) critic score."""
        flat = input.view(input.shape[0], -1)
        hidden = F.relu(self.fc2(F.relu(self.fc1(flat))))
        # No sigmoid here — the critic's output is an unconstrained score.
        return self.out(hidden)


# Instantiate the networks and move them to the selected device.
gen = generator()
dis = discriminator()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
gen.to(device)
dis.to(device)
# Gradient seeds for backward(): +1 / -1 choose the sign with which each
# critic score contributes, so the critic is trained on D(real) - D(fake).
# Fixes two defects in the original: a hard-coded .cuda() that crashed on
# CPU-only machines, and `one.to(device)` called without assignment, which
# is a no-op for tensors (Tensor.to returns a new tensor).
one = torch.ones(1, device=device)
mone = one * -1
# WGAN deliberately avoids momentum-based optimizers to reduce gradient
# drift; the paper recommends RMSprop with a small learning rate.
optimizer_G = optim.RMSprop(gen.parameters(), lr=0.00005)
optimizer_D = optim.RMSprop(dis.parameters(), lr=0.00005)
# Configure the MNIST data loader.
os.makedirs('D:/mnist/', exist_ok=True)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST('D:/mnist/', train=True, download=False,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       # MNIST is single-channel: 3-channel Normalize stats
                       # raise an error in torchvision. One-channel stats
                       # map pixels from [0, 1] to [-1, 1], matching the
                       # generator's tanh output range.
                       transforms.Normalize((0.5,), (0.5,))
                   ])), batch_size=64, shuffle=True)
# Peek at one batch to confirm the pipeline works before training starts.
data_iter = iter(dataloader)
imp = next(data_iter)
print(imp[0].shape)

gen_iterations = 0
print("finish load dataset")
# ---- WGAN training loop ----
# Create the sample-output directory up front; save_image fails otherwise.
os.makedirs('D:/mnist/images-2/', exist_ok=True)
for epoch in range(200):
    data_iter = iter(dataloader)
    i = 0
    while i < len(dataloader):
        # Re-enable critic gradients (they are frozen during the G step below).
        for p in dis.parameters():
            p.requires_grad = True
        # Early on (and periodically every 500 iterations) train the critic
        # many extra times so it roughly tracks the Wasserstein distance
        # before the generator is updated.
        # NOTE(review): gen_iterations is incremented per critic step here,
        # not per generator step as in the reference WGAN code — the
        # schedule therefore fires more often; confirm if intentional.
        if gen_iterations < 25 or gen_iterations % 500 == 0:
            Diters = 100
        else:
            Diters = 5
        batch_size = 64
        # Sample a noise batch on the right device (reused by every critic
        # update below and by the generator update).
        noise = torch.randn(batch_size, 100, device=device)
        """
            update D network
        """
        j = 0
        while j < Diters and i < len(dataloader):
            j += 1
            # Weight clipping enforces the Lipschitz constraint required by
            # the Wasserstein formulation.
            for p in dis.parameters():
                p.data.clamp_(-0.01, 0.01)
            img = next(data_iter)
            i += 1
            dis.zero_grad()
            # Move the real batch to the training device (was a hard .cuda()).
            real_imgs = img[0].to(device)
            # backward(one) / backward(mone) accumulate gradients with +/-
            # sign, so the critic minimizes D(real) - D(fake).
            real_loss = dis(real_imgs).mean(0)
            real_loss.backward(one)
            # detach(): don't backprop into the generator during the D step
            # (the original built useless G gradients here).
            fake_loss = dis(gen(noise).detach()).mean(0)
            fake_loss.backward(mone)
            d_loss = real_loss - fake_loss  # (sign-flipped) Wasserstein distance estimate
            optimizer_D.step()
            gen_iterations += 1
            print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f]" % (epoch, 200, i, len(dataloader),
                                                                             d_loss.item()))
        # ---- update G network ----
        # Freeze the critic so the generator step only updates G's weights.
        for p in dis.parameters():
            p.requires_grad = False
        gen.zero_grad()
        fake = gen(noise)
        g_loss = dis(fake).mean(0)
        g_loss.backward(one)
        optimizer_G.step()
        print("[Epoch %d/%d] [Batch %d/%d] [G loss: %f]" % (epoch, 200, i, len(dataloader), g_loss.item()))

        # Save a 5x5 grid of generated digits roughly every 400 critic steps.
        if gen_iterations % 400 <= 1:
            save_image(fake.data[:25].view(25, 1, 28, 28), 'D:/mnist/images-2/%d.png' % gen_iterations, nrow=5, normalize=True)

相关文章

  • 一个比WGAN更优秀的模型(WGAN-GP)

    WGAN-GP (improved wgan) paper GitHub WGAN-GP是WGAN之后的改进版,主...

  • 用Pytorch实现WGAN

    本文是解读WGAN的实践篇,目标是用pytorch实现能生成人脸图像的WGAN。如果对WGAN、DCGANs和GA...

  • Wasserstein GAN简明版

    涉及WGAN的论文总共三篇:WGAN前作:Towards Principled Methods for Train...

  • WGAN

    GAN-QP 写到一半发现关于 WGAN 以及它相关约束部分之前没有完全读懂,需要重读,那顺手也把笔记给写了吧 W...

  • WGAN

  • 零样本图像识别 | Feature Generating Net

    创新:提出f-GAN 、 f-WGAN 和 f-CLSWGAN、将WGAN的loss和Classfication的...

  • GAN 生成对抗网络

    GAN网络可以说是近来最火的神经网络模型,其变种包括WGAN,WCGAN,WGAN-l,circleGAN等,被广...

  • 从零单排fastai脚本(2)

    这次看下wgan脚本,这里使用fastai来完成wgan的训练和使用。 老三样,我就不加标题了 1 重要的包 其中...

  • 解读Wasserstein GAN

    Generative Adversarial Networks (GANs) 在解读WGAN(Wasserstei...

  • 从GAN到WGAN再到WGAN-GP

    从GAN到WGAN再到WGAN-GP 基本理论知识 KL散度: JS散度:其中性质:满足对称性;当两概率为0时,J...

网友评论

      本文标题:WGAN

      本文链接:https://www.haomeiwen.com/subject/nbxzkqtx.html