在上一节中可以看到基于”推土距离“的WGAN网络能够有效生成马图片,但是网络构造能力有所不足,因此导致有些图片模糊,甚至有些图片连马的轮廓都没有构建出来,本节我们改进WGAN网络,让它具有更强大的图像生成能力。
在介绍WGAN网络算法时提到,如果把网络看成一个函数,那么网络要想具备好的图像生成能力就必须满足1-Lipshitz条件,也就是要满足公式:
屏幕快照 2020-05-08 上午10.06.59.png根据微积分的中值定理,如果函数f(x)可导,那么对任意x1,x2,可以找到位于(x1,x2)之间的x3,使得如下公式成了:
屏幕快照 2020-05-08 上午10.09.44.png将它带入到上面公式就有:
屏幕快照 2020-05-08 上午10.10.42.png
这意味着如果函数满足1-Lipshitz条件,那么它必须在定义域内的没一点都可导,而且其求倒数后的结果绝对值不能大于1,这是一个相当苛刻的条件。所以上一节描述WGAN网络时,算法作者想不到好的办法让构造的网络满足这个条件,于是”拍脑袋“想出了将网络内部参数的数值全部剪切到(-1,1)之间,这也是造成网络生成图像质量不好的原因。
如果把函数f看做鉴别者网络,把输入的参数x看做是输入网络的图片,那么需要网络对所有输入图片求导后,所得结果求模后不大于1.这里需要进一步解释的是,由于图片含有多个像素点,如果把每一个像素点的值都看成是输入网络的参数,那么网络就是一个多元函数f(x1,x2,....xn),其中x1,x2...xn就是输入图片的像素值,对其求导就是分别针对x1,x2...xn求导,如果使用f1对应与针对x1求导后的结果,那么对所有x1,x2...xn求导后就会得到一个向量(f1,f2....fn),将该向量求模就对应第二个公式中的|f'(xn)|。
问题在于算法要求对所有输入图片都要满足求模后结果不大于1的要求,这点我们无法做到,因为我们不可能拿所有图像输入到网络。例如要让网络生成人脸,我们也不可能拿所有人脸图像来训练网络,因此就要做折中或妥协,我们拿一张真的人脸图像,然后用构造者网络生成一张假的人脸图像,在这两个人脸图像之间取一点,然后让网络对该点求导后结果的绝对值不大于1即可,算法流程如下图所示:
17-12.png由于WGAN-GP算法相对于上一节的WGAN算法,只是针对鉴别者网络的训练过程做了修改,其他都没变,因此这里只给出WGAN-GP的鉴别者网络训练代码:
def train_discriminator(self, image_batch):
'''
训练鉴别师网络,它的训练分两步骤,首先是输入正确图片,让网络有识别正确图片的能力。
然后使用生成者网络构造图片,并告知鉴别师网络图片为假,让网络具有识别生成者网络伪造图片的能力
'''
with tf.GradientTape(persistent=True, watch_accessed_variables=False) as tape: #只修改鉴别者网络的内部参数
tape.watch(self.discriminator.trainable_variables)
noise = tf.random.normal([len(image_batch), self.z_dim])
true_logits = self.discriminator(image_batch, training = True)
gen_imgs = self.generator(noise, training = True) #让生成者网络根据关键向量生成图片
fake_logits = self.discriminator(gen_imgs, training = True)
d_loss_real = tf.multiply(tf.ones_like(true_logits), true_logits)#根据推土距离将真图片的标签设置为1
d_loss_fake = tf.multiply(-tf.ones_like(fake_logits), fake_logits)#将伪造图片的标签设置为-1
with tf.GradientTape(watch_accessed_variables=False) as iterploted_tape:#注意此处是与WGAn的主要差异
t = tf.random.uniform(shape = (len(image_batch), 1, 1, 1)) #生成[0,1]区间的随机数
interploted_imgs = tf.add(tf.multiply(1 - t, image_batch), tf.multiply(t, gen_imgs)) #获得真实图片与虚假图片中间的差值
iterploted_tape.watch(interploted_imgs)
interploted_loss = self.discriminator(interploted_imgs)
interploted_imgs_grads = iterploted_tape.gradient(interploted_loss, interploted_imgs)#针对差值求导
grad_norms = tf.norm(interploted_imgs_grads)
penalty = 10 * tf.reduce_mean((grad_norms - 1) ** 2)#确保差值求导所得的模不超过1
d_loss = d_loss_real + d_loss_fake + penalty #penalty 对应WGAN-GP中的GP
grads = tape.gradient(d_loss , self.discriminator.trainable_variables)
self.discriminator_optimizer.apply_gradients(zip(grads, self.discriminator.trainable_variables)) #改进鉴别者网络内部参数
self.d_loss.append(d_loss)
self.d_loss_real.append(d_loss_real)
self.d_loss_fake.append(d_loss_fake)
这里要注意代码中实现在真假图片中间取数值点,然后让其倒数求模不超过1的实现,也就是interploted_imgs_grads的计算过程,这一小片代码决定了网络最终生成图像的质量,使用WGA-GP算法训练网络后,最终生成的人脸图像如下:
屏幕快照 2020-05-08 上午10.28.11.png可以看到网络生成的人脸图像非常细腻生动,虽然有些人脸图像不够清楚,但绝大多数人脸图像,例如第一行第一章人脸图像,你很难想象它是由神经网络生成的虚拟人脸图像,因为它太逼真了。前段时间流行的deep fake,其原理差不多,只是在实现的技术层面做了更多的优化和处理。
更多技术信息,包括操作系统,编译器,面试算法,机器学习,人工智能,请关照我的公众号:
这里写图片描述
网友评论