美文网首页
WGAN-GP 原理和代码分析

WGAN-GP 原理和代码分析

作者: 小潤澤 | 来源:发表于2024-02-26 17:04 被阅读0次

生成对抗模型(GAN)简介可以参考:https://www.jianshu.com/p/34d9d0755f51
这里介绍的WGAN,将损失函数进行了正则化

文章链接:《Improving protein function prediction with synthetic feature samples created by generative adversarial networks

这里作者提出了一种新的损失函数定义模式,对于普通GAN的损失函数定义:


\widetilde{x} 由生成器 G 产生的 fake data,x 代表 real data,那么对于WGAN-GP 它的损失函数为:


其中 \widetilde{x} 由生成器 G 产生的 fake data,x 代表 real data,\widehat{x} 在本研究中代表:
α 代表随机的参数,λ 这一项代表正则项作为梯度约束

代码部分:https://github.com/psipred/FFPredGAN/blob/master/src/Generating_Synthetic_Positive_Samples_FFPred-GAN.py

这里只重点讲讲目标函数约束的代码部分:

ITERS = 100000 
CRITIC_ITERS = 5

# 训练模型
for iteration in range(ITERS):
    for p in netD.parameters():  
        p.requires_grad = True  

    data = inf_train_gen()
    real_data = torch.FloatTensor(data)
    real_data_v = autograd.Variable(real_data)
    
    noise = torch.randn(BATCH_SIZE, 258)
    noisev = autograd.Variable(noise, volatile=True)  
    fake = autograd.Variable(netG(noisev, real_data_v).data)

    fake_output=fake.data.cpu().numpy()
    
    # 训练判别器 netD
    for iter_d in range(CRITIC_ITERS):
        # 梯度清零
        netD.zero_grad()

        D_real, hidden_output_real_1, hidden_output_real_2, hidden_output_real_3 = netD(real_data_v)

        # 高维张量取平均值,变成一个标量
        D_real = D_real.mean()

        noise = torch.randn(BATCH_SIZE, 258)
        noisev = autograd.Variable(noise, volatile=True)  
        fake = autograd.Variable(netG(noisev, real_data_v).data)
        
        inputv = fake
        D_fake, hidden_output_fake_1, hidden_output_fake_2, hidden_output_fake_3 = netD(inputv)
       
        # 高维张量取平均值,变成一个标量
        D_fake = D_fake.mean()
        
        # 计算正则项
        gradient_penalty = calc_gradient_penalty(netD, real_data_v.data, fake.data)
        
        # WGAN-GP 损失函数
        D_cost = D_fake - D_real + gradient_penalty
        Wasserstein_D = D_real - D_fake

        # 反向传播损失函数
        D_cost.backward()
        # 迭代更新
        optimizerD.step()

    # 训练生成器 netG
    for p in netD.parameters():
            p.requires_grad = False

        netG.zero_grad()
        real_data = torch.Tensor(data)
        real_data_v = autograd.Variable(real_data)

        noise = torch.randn(BATCH_SIZE, 258)
        noisev = autograd.Variable(noise)
        fake = netG(noisev, real_data_v)
        G, hidden_output_ignore_1, hidden_output_ignore_2, hidden_output_ignore_3 = netD(fake)

        G = G.mean()
        G_cost = -G
        # 反向传播损失函数
        G_cost.backward()
        # 迭代更新
        optimizerG.step()

计算gradient_penalty的代码为:

def calc_gradient_penalty(netD, real_data, fake_data):
    alpha = torch.rand(BATCH_SIZE, 1)
    alpha = alpha.expand(real_data.size())
    alpha = alpha.cuda() if use_cuda else alpha

    interpolates = alpha * real_data + ((1 - alpha) * fake_data)

    interpolates = autograd.Variable(interpolates, requires_grad=True)

    disc_interpolates, hidden_output_1, hidden_output_2, hidden_output_3 = netD(interpolates) 
    
    # 求梯度
    gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                              grad_outputs=torch.ones(disc_interpolates.size()).cuda() if use_cuda else torch.ones(
                                  disc_interpolates.size()),
                              create_graph=True, retain_graph=True, only_inputs=True)[0]
    
    # 正则项,二阶范数
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA
    return gradient_penalty

相关文章

网友评论

      本文标题:WGAN-GP 原理和代码分析

      本文链接:https://www.haomeiwen.com/subject/yavzadtx.html