Keras GAN: Saving the Model and Optimizer State Together

Author: 寂风如雪 | Published 2020-12-28 18:33

There is an excellent open-source project on GitHub (introduced here) containing Keras implementations of GANs and their variants. While experimenting with it, I found it hard to save the model and the optimizer state at the same time: with model.save, the optimizer state is lost on the next load; a Keras model cannot be pickled; and pickling the optimizer alone leads to odd behavior, such as Adam's momentum not updating on some layers. After reading the Keras source, I arrived at a workaround, shared below.
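To see why losing the optimizer state hurts, recall that Adam keeps per-parameter moment buffers in addition to the weights. The following is a minimal NumPy sketch of a single Adam update (my own illustration, not Keras's internal implementation; beta1=0.5 matches the Adam(0.0002, 0.5) used below). Continuing training with the saved buffers gives a different step than restarting them from zero:

```python
import numpy as np

def adam_step(w, grad, m, v, t, lr=0.0002, beta1=0.5, beta2=0.999, eps=1e-7):
    """One Adam update. m, v and the step counter t are optimizer state."""
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad ** 2
    m_hat = m / (1 - beta1 ** t)  # bias correction
    v_hat = v / (1 - beta2 ** t)
    w = w - lr * m_hat / (np.sqrt(v_hat) + eps)
    return w, m, v

# Warm up the state with one step, then take a second step twice:
# once with the state restored, once with it discarded.
w, m, v = adam_step(1.0, grad=1.0, m=0.0, v=0.0, t=1)
w_kept, _, _ = adam_step(w, grad=0.1, m=m, v=v, t=2)      # state restored
w_lost, _, _ = adam_step(w, grad=0.1, m=0.0, v=0.0, t=1)  # state discarded
```

The two resulting weights differ, which is exactly the silent drift you get when a reloaded model restarts its optimizer from scratch.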

import numpy as np
import matplotlib.pyplot as plt
import h5py

import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout
from tensorflow.keras.layers import BatchNormalization, Activation, ZeroPadding2D
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.layers import UpSampling2D, Conv2D
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.datasets import mnist
from tensorflow.python.keras.saving import hdf5_format

def get_optimizer_weights(filepath):
    # Read the optimizer weights back out of a saved .h5 model file via
    # Keras's (private) hdf5_format helper. The open file handle is
    # returned so the caller can close it after applying the weights.
    f = h5py.File(filepath, mode='r')
    weights = hdf5_format.load_optimizer_weights_from_hdf5_group(f)
    return f, weights

class GAN():
    def __init__(self):
        # --------------------------------- #
        #   28 rows by 28 columns:
        #   the MNIST image shape
        # --------------------------------- #
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        # 28,28,1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.latent_dim = 100
        # Separate Adam optimizers for the discriminator and combined model
        optimizer_d = Adam(0.0002, 0.5)
        optimizer_c = Adam(0.0002, 0.5)

        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy',
            optimizer=optimizer_d,
            metrics=['accuracy'])

        self.generator = self.build_generator()
        gan_input = Input(shape=(self.latent_dim,))
        img = self.generator(gan_input)
        # Freeze the discriminator while the generator is being trained
        self.discriminator.trainable = False
        # Score the generated (fake) images
        validity = self.discriminator(img)
        self.combined = Model(gan_input, validity)
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer_c)


    def build_generator(self):
        # --------------------------------- #
        #   Generator: maps a random noise
        #   vector to an image
        # --------------------------------- #
        model = Sequential()

        model.add(Dense(256, input_dim=self.latent_dim))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))

        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))

        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))

        model.add(Dense(np.prod(self.img_shape), activation='tanh'))
        model.add(Reshape(self.img_shape))

        noise = Input(shape=(self.latent_dim,))
        img = model(noise)

        return Model(noise, img)

    def build_discriminator(self):
        # ----------------------------------- #
        #   Discriminator: scores an input
        #   image as real or fake
        # ----------------------------------- #
        model = Sequential()
        # Flatten the input image
        model.add(Flatten(input_shape=self.img_shape))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(256))
        model.add(LeakyReLU(alpha=0.2))
        # Real-or-fake probability
        model.add(Dense(1, activation='sigmoid'))

        img = Input(shape=self.img_shape)
        validity = model(img)

        return Model(img, validity)

    def train(self, epochs, batch_size=128, sample_interval=50):
        # Load the MNIST data
        (X_train, _), (_, _) = mnist.load_data()

        # Scale pixel values to [-1, 1]
        X_train = X_train / 127.5 - 1.
        X_train = np.expand_dims(X_train, axis=3)

        # Labels: 1 for real images, 0 for fake ones
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs):

            # --------------------------- #
            #   Sample batch_size images
            #   and train the discriminator
            # --------------------------- #
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[idx]

            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            gen_imgs = self.generator.predict(noise)

            d_loss_real = self.discriminator.train_on_batch(imgs, valid)
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # --------------------------- #
            #   Train the generator
            # --------------------------- #
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
            g_loss = self.combined.train_on_batch(noise, valid)
            print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100 * d_loss[1], g_loss))

            if epoch % sample_interval == 0:
                self.sample_images(epoch)

    def sample_images(self, epoch):

        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        gen_imgs = self.generator.predict(noise)

        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
                axs[i,j].axis('off')
                cnt += 1
        # Note: the images/ directory must already exist
        fig.savefig("images/%d.png" % epoch)
        plt.close()
    
    def save(self, path):
        discriminator_path = path + '_discriminator.h5'
        generator_path = path + '_generator.h5'
        combined_path = path + '_combined.h5'

        # Temporarily re-enable trainable so the discriminator is saved in
        # its standalone compiled configuration, optimizer state included
        self.discriminator.trainable = True
        self.discriminator.save(discriminator_path)
        self.discriminator.trainable = False
        self.generator.save(generator_path)
        self.combined.save(combined_path)
        
    def load(self, path):
        discriminator_path = path + '_discriminator.h5'
        generator_path = path + '_generator.h5'
        combined_path = path + '_combined.h5'
        
        self.discriminator = tf.keras.models.load_model(discriminator_path)
        # The generator is only ever trained through the combined model,
        # so it can be loaded without compiling
        self.generator = tf.keras.models.load_model(generator_path, compile=False)
        f, c_weights = get_optimizer_weights(combined_path)
        
        self.discriminator.trainable = False
        
        gan_input = Input(shape=(self.latent_dim,))
        img = self.generator(gan_input)
        validity = self.discriminator(img)
        self.combined = Model(gan_input, validity)
        self.combined.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))
        
        # _create_all_weights is a private API: it builds the optimizer's
        # slot variables so that set_weights can restore the saved state
        self.combined.optimizer._create_all_weights(self.combined.trainable_variables)
        self.combined.optimizer.set_weights(c_weights)
        
        f.close()
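For reference, the private hdf5_format helper used in get_optimizer_weights can be approximated with plain h5py calls. This sketch assumes the HDF5 layout that model.save() produces (slot arrays stored under an 'optimizer_weights' group, ordered by its 'weight_names' attribute); the private helper remains the safer choice across Keras versions:

```python
import h5py
import numpy as np

def read_optimizer_weights(filepath):
    # Public-h5py approximation of hdf5_format's optimizer-weight loader.
    # Assumed layout: arrays under the 'optimizer_weights' group, in the
    # order given by the group's 'weight_names' attribute.
    with h5py.File(filepath, "r") as f:
        group = f["optimizer_weights"]
        names = [n.decode("utf8") if isinstance(n, bytes) else n
                 for n in group.attrs["weight_names"]]
        return [np.asarray(group[n]) for n in names]
```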

Saving the model:

gan = GAN()
gan.train(epochs=1, batch_size=256, sample_interval=1000)
gan.save('savetest')

Loading the model:

gan2 = GAN() 
gan2.load('savetest')
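After reloading, you can sanity-check that the optimizer state really survived the round trip by comparing the weight lists returned by get_weights(). optimizer_weights_equal is a small helper of my own, not a Keras API:

```python
import numpy as np

def optimizer_weights_equal(ws_a, ws_b):
    # get_weights() returns a list of arrays (the iteration counter plus
    # one slot array per trainable variable); compare them pairwise.
    if len(ws_a) != len(ws_b):
        return False
    return all(np.array_equal(np.asarray(a), np.asarray(b))
               for a, b in zip(ws_a, ws_b))
```

Right after gan2.load('savetest'), comparing gan.combined.optimizer.get_weights() against gan2.combined.optimizer.get_weights() with this helper should return True.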

Happy faking, and happy learning.

Source: https://www.haomeiwen.com/subject/dlyxiktx.html