美文网首页
CGAN-监督式GAN

CGAN-监督式GAN

作者: baiyang白杨 | 来源:发表于2019-05-16 22:29 被阅读0次

    1.CGAN的简介

    为了解决带标签的数据生成问题,研究者们提出了条件生成对抗网络(CGAN)的概念。

    CGAN的结构如上图所示,与GAN的主要区别是生成器和判别器的输入数据中都加入类别标签向量(C_vector),生成器的优化目标函数基本上没有变化。

    总的来说CGAN在GAN上的改动并不大,但是普通的GAN所生成的内容是随机的,CGAN实现了根据输入标签生成指定类别的内容。

    2.CGAN的实现

    目前CGAN的实现由多种形式,主要的区别是C_vector的形式,目前主要有以下三种形式:

    第一中形式:

    将输入Generator的C_vector进行One-hot编码,然后与noise进行拼接,此时C_vector为(batch_size, class_num) ,noise为(batch_size, latent_dim),将拼接之后大小为(batch_size, latent_dim+class_num)作为生成器的输入。

    将输入Discrimintor的C_vector首先进行One-hot编码然后通过expand()方法进行维度扩展,此时的C_vector为 (batch_size, class_num, cols, rows) , Real_data 和 Fake_data为(batch_size, channel, cols, rows),最后将转换后的C_vector和Real_data或者Fake_data进行拼接,将拼接之后大小为(batch_size, channel+class_num, cols, rows)的张量作为判别器的输入。

    第二种形式:

    将输入Generator的C_vector通过Embedding方法进词嵌入,并进行Flatten操作,从而将C_vector转换成为与noise大小相同的张量(batch_size, latent_dim), 然后将noise 和 C_vector 进行mulitiply()操作(即对应位置上的元素相乘,该运输不改变张量的大小),将最终得到的(batch_size, latent_dim)的张量作为生成器的输入。

    将输入Discriminator的C_vector通过Embedding方法进行词嵌入,并进行Flatten操作,从而将C_vector转换为(batch_size, channel*rows*cols),接着对Real_data和Fake_data进行Flatten操作,将其转换为(batch_size, channel*rows*cols),然后将转换后的C_vector和Real_data或者Fake_data进行multiply()操作,将最终得到的(batch_size, channel*rows*cols)张量作为判别器的输入。

    第三种形式:

    将输入Generator的C_vector进行One-hot编码,然后与noise进行拼接,此时C_vector为(batch_size, class_num) ,noise为(batch_size, latent_dim),最后将拼接后大小为(batch_size, latent_dim+class_num)作为生成器的输入。

    将输入Discriminator的C_vector进行One-hot编码,然后与经过Flatten()处理之后的Real_data或者Fake_data进行拼接,此时Real_data和Fake_data为(batch_size, channel*rows*cols),C_vector为(batch_size, num_class),最后将拼接之后大小为(batch_size, channel*rows*cols + num_class)的张量作为判别器的输入。

    损失函数:

    在具体实现上,CGAN的损失函数和GAN基本相同。

    相关文章

      网友评论

          本文标题:CGAN-监督式GAN

          本文链接:https://www.haomeiwen.com/subject/iqfwaqtx.html