pix2pix

Author: blackmanba_084b | Published 2020-06-07 21:47 | Original: https://www.haomeiwen.com/subject/dqutzhtx.html

pix2pix, one of the foundational GAN architectures, is an image-to-image translation network in the same family as CycleGAN. Its structure is quite simple, so this post gives a brief summary. The paper is Image-to-Image Translation with Conditional Adversarial Networks. As before, we use a Keras implementation to understand the network: code

I. What the Model Does

pix2pix is essentially an image-translation network: it learns to convert images between two paired domains, A↔B (e.g., edges↔photos, or semantic labels↔street scenes).


II. Network Architecture

(network architecture diagram)

The generator G is a U-Net that maps an input x to G(x). The U-Net is simply an encoder-decoder (downsampling followed by upsampling) with skip connections between mirrored layers. The discriminator is conditional: its fake input is the conditioning image concatenated channel-wise with the generated image, and its real input is the conditioning image concatenated with the ground-truth image. Each concatenated pair is passed through a stack of convolutions, and the discriminator classifies it as real or fake.
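As a minimal sketch of how the pieces are wired together (the names generator, discriminator, and img_shape are placeholders here; the actual keras-GAN code follows in Section III):

    # Conditional-GAN wiring, Keras style (illustrative sketch)
    img_A = Input(shape=img_shape)          # target-domain image
    img_B = Input(shape=img_shape)          # conditioning image
    fake_A = generator(img_B)               # U-Net translates B into a fake A
    discriminator.trainable = False         # freeze D while training G
    valid = discriminator([fake_A, img_B])  # patch-level real/fake verdict
    combined = Model([img_A, img_B], [valid, fake_A])

This combined model is what lets the generator be trained on two signals at once: the discriminator's verdict and the ground-truth image.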

III. Code Walkthrough

1. Generator network

    def build_generator(self):
        """U-Net Generator"""

        def conv2d(layer_input, filters, f_size=4, bn=True):
            """Layers used during downsampling"""
            d = Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
            d = LeakyReLU(alpha=0.2)(d)
            if bn:
                d = BatchNormalization(momentum=0.8)(d)
            return d

        def deconv2d(layer_input, skip_input, filters, f_size=4, dropout_rate=0):
            """Layers used during upsampling"""
            u = UpSampling2D(size=2)(layer_input)
            u = Conv2D(filters, kernel_size=f_size, strides=1, padding='same', activation='relu')(u)
            if dropout_rate:
                u = Dropout(dropout_rate)(u)
            u = BatchNormalization(momentum=0.8)(u)
            u = Concatenate()([u, skip_input])
            return u

        # Image input
        d0 = Input(shape=self.img_shape)

        # Downsampling
        d1 = conv2d(d0, self.gf, bn=False)
        d2 = conv2d(d1, self.gf*2)
        d3 = conv2d(d2, self.gf*4)
        d4 = conv2d(d3, self.gf*8)
        d5 = conv2d(d4, self.gf*8)
        d6 = conv2d(d5, self.gf*8)
        d7 = conv2d(d6, self.gf*8)

        # Upsampling
        u1 = deconv2d(d7, d6, self.gf*8)
        u2 = deconv2d(u1, d5, self.gf*8)
        u3 = deconv2d(u2, d4, self.gf*8)
        u4 = deconv2d(u3, d3, self.gf*4)
        u5 = deconv2d(u4, d2, self.gf*2)
        u6 = deconv2d(u5, d1, self.gf)

        u7 = UpSampling2D(size=2)(u6)
        output_img = Conv2D(self.channels, kernel_size=4, strides=1, padding='same', activation='tanh')(u7)

        return Model(d0, output_img)

d1–d7 form the downsampling (encoder) path and u1–u7 the upsampling (decoder) path; each deconv2d block also concatenates the mirrored encoder feature map, giving the U-Net its skip connections.
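To make the shapes concrete, here is a trace of the feature maps, assuming the keras-GAN defaults of img_shape = (256, 256, 3) and gf = 64 (these values are assumptions, not stated above):

    # Feature-map shapes for a 256x256x3 input with gf = 64 (assumed defaults):
    #   d1: 128x128x64    d2: 64x64x128    d3: 32x32x256    d4: 16x16x512
    #   d5: 8x8x512       d6: 4x4x512      d7: 2x2x512      (bottleneck)
    #   u1: 4x4x1024      u2: 8x8x1024     u3: 16x16x1024   u4: 32x32x512
    #   u5: 64x64x256     u6: 128x128x128  output: 256x256x3 (tanh)
    # Each conv2d halves the spatial size; each deconv2d doubles it and then
    # concatenates the mirrored encoder features, which is why the channel
    # counts on the decoder side are the sum of the conv output and the skip.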

2. Discriminator network

    def build_discriminator(self):

        def d_layer(layer_input, filters, f_size=4, bn=True):
            """Discriminator layer"""
            d = Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
            d = LeakyReLU(alpha=0.2)(d)
            if bn:
                d = BatchNormalization(momentum=0.8)(d)
            return d

        img_A = Input(shape=self.img_shape)
        img_B = Input(shape=self.img_shape)

        # Concatenate image and conditioning image by channels to produce input
        combined_imgs = Concatenate(axis=-1)([img_A, img_B])

        d1 = d_layer(combined_imgs, self.df, bn=False)
        d2 = d_layer(d1, self.df*2)
        d3 = d_layer(d2, self.df*4)
        d4 = d_layer(d3, self.df*8)

        validity = Conv2D(1, kernel_size=4, strides=1, padding='same')(d4)

        return Model([img_A, img_B], validity)

Here the channel-wise concatenation [A, B] is passed through four stride-2 convolutions, and a final 1-channel convolution produces the patch map on which the discriminator makes its real/fake decision.
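A quick shape sketch (again assuming img_shape = (256, 256, 3) and df = 64, the keras-GAN defaults):

    # Four stride-2 convs shrink the input by 2**4 = 16, so for a 256x256
    # image the validity output is a 16x16x1 map (a PatchGAN): each cell
    # scores one overlapping patch of the input instead of the whole image.
    #   combined_imgs: 256x256x6 -> d1: 128x128x64 -> d2: 64x64x128
    #   -> d3: 32x32x256 -> d4: 16x16x512 -> validity: 16x16x1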

3. Loss computation

# ---------------------
#  Train Discriminator
# ---------------------

# Condition on B and generate a translated version of A
fake_A = self.generator.predict(imgs_B)

# Train the discriminator (original images = real / generated = fake)
d_loss_real = self.discriminator.train_on_batch([imgs_A, imgs_B], valid)
d_loss_fake = self.discriminator.train_on_batch([fake_A, imgs_B], fake)
d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

# -----------------
#  Train Generator
# -----------------

# Train the generator to fool the discriminator (vs. valid) while
# staying close to the ground truth imgs_A (the L1 term)
g_loss = self.combined.train_on_batch([imgs_A, imgs_B], [valid, imgs_A])

First, we construct the labels:

# The discriminator downsamples by 2**4, so its output is patch x patch x 1
patch = int(self.img_rows / 2**4)
self.disc_patch = (patch, patch, 1)
valid = np.ones((batch_size,) + self.disc_patch)   # label 1 for real pairs
fake = np.zeros((batch_size,) + self.disc_patch)   # label 0 for fake pairs
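For instance, with img_rows = 256 and batch_size = 1 (assumed values), the labels work out to:

    import numpy as np

    patch = int(256 / 2**4)                 # = 16
    disc_patch = (patch, patch, 1)          # (16, 16, 1)
    valid = np.ones((1,) + disc_patch)      # shape (1, 16, 16, 1), all ones
    fake = np.zeros((1,) + disc_patch)      # shape (1, 16, 16, 1), all zeros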

PatchGAN gives us this patch-sized output: each cell of the patch map has a receptive field covering a region of the original image, so the discriminator judges realism patch by patch rather than for the image as a whole. valid is a (batch_size, patch, patch, 1) array of ones and fake is a (batch_size, patch, patch, 1) array of zeros. The discriminator loss is then the sum of loss(real input, valid) and loss(fake input, fake), each computed with MSE and averaged via the 0.5 factor above. Next comes the generator loss:

# The combined model maps [img_A, img_B] to the discriminator's verdict
# on the generated image and to the generated image itself
self.combined = Model(inputs=[img_A, img_B], outputs=[valid, fake_A])
self.combined.compile(loss=['mse', 'mae'], loss_weights=[1, 100], optimizer=optimizer)
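So the generator objective is the adversarial term plus a heavily weighted L1 term, matching λ = 100 in the pix2pix paper. Schematically (not literal Keras code):

    # What the training code above optimizes, written out:
    #   L_D = 0.5 * ( mse(D([A, B]), 1) + mse(D([G(B), B]), 0) )
    #   L_G = mse(D([G(B), B]), 1) + 100 * mae(G(B), A)
    # Using MSE for the adversarial term makes this an LSGAN-style objective,
    # while the L1 (mae) term pushes G(B) toward the ground truth pixel-wise.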

IV. Results
