美文网首页
Keras GAN 同时储存模型和优化器信息

Keras GAN 同时储存模型和优化器信息

作者: 寂风如雪 | 来源:发表于2020-12-28 18:33 被阅读0次

    Github 上有一个很好的开源项目,介绍见这里,包含了用 Keras 实现的 GAN 及其变种。但在研究的过程中发现,难以同时存储模型和优化器状态,如果使用 model.save 则发现再次 load 时优化器信息丢失,而 Keras 的 model 又不能 pickle,也尝试了 pickle 优化器,但似乎会出现一些奇怪的问题,比如 Adam 某些层上的动量不变化等,于是查看了 Keras 源码,并实现了解决方法,分享如下:

    import os

    import h5py
    import matplotlib.pyplot as plt
    import numpy as np
    import tensorflow as tf
    from tensorflow.keras.datasets import mnist
    from tensorflow.keras.layers import BatchNormalization, Activation, ZeroPadding2D
    from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout
    from tensorflow.keras.layers import LeakyReLU
    from tensorflow.keras.layers import UpSampling2D, Conv2D
    from tensorflow.keras.models import Sequential, Model
    from tensorflow.keras.optimizers import Adam
    from tensorflow.python.keras.saving import hdf5_format
    
    def get_optimizer_weights(filepath):
        """Open a Keras HDF5 model file and pull out the stored optimizer weights.

        Returns the still-open ``h5py.File`` (the caller is responsible for
        closing it) together with the list of optimizer weight arrays.
        """
        h5file = h5py.File(filepath, mode='r')
        opt_weights = hdf5_format.load_optimizer_weights_from_hdf5_group(h5file)
        return h5file, opt_weights
    
    class GAN():
        """A vanilla GAN on MNIST that can save and restore its models
        *together with* their optimizer state.

        ``model.save`` / ``load_model`` alone loses the optimizer state of the
        combined model on reload, so :meth:`load` rebuilds the combined model,
        re-compiles it, and injects the optimizer weights read straight from
        the saved HDF5 file.
        """

        def __init__(self):
            # --------------------------------- #
            #   28 x 28 x 1 — the MNIST image shape
            # --------------------------------- #
            self.img_rows = 28
            self.img_cols = 28
            self.channels = 1
            self.img_shape = (self.img_rows, self.img_cols, self.channels)
            self.latent_dim = 100
            # Separate Adam optimizers: one for the discriminator, one for the
            # combined (generator-training) model.
            optimizer_d = Adam(0.0002, 0.5)
            optimizer_c = Adam(0.0002, 0.5)

            self.discriminator = self.build_discriminator()
            self.discriminator.compile(loss='binary_crossentropy',
                optimizer=optimizer_d,
                metrics=['accuracy'])

            self.generator = self.build_generator()
            gan_input = Input(shape=(self.latent_dim,))
            img = self.generator(gan_input)
            # Freeze the discriminator while the generator is being trained.
            self.discriminator.trainable = False
            # Score the generated (fake) images.
            validity = self.discriminator(img)
            self.combined = Model(gan_input, validity)
            self.combined.compile(loss='binary_crossentropy', optimizer=optimizer_c)


        def build_generator(self):
            """Build the generator: latent vector -> 28x28x1 image in [-1, 1]."""
            model = Sequential()

            model.add(Dense(256, input_dim=self.latent_dim))
            model.add(LeakyReLU(alpha=0.2))
            model.add(BatchNormalization(momentum=0.8))

            model.add(Dense(512))
            model.add(LeakyReLU(alpha=0.2))
            model.add(BatchNormalization(momentum=0.8))

            model.add(Dense(1024))
            model.add(LeakyReLU(alpha=0.2))
            model.add(BatchNormalization(momentum=0.8))

            # tanh output matches the [-1, 1] normalisation used in train().
            model.add(Dense(np.prod(self.img_shape), activation='tanh'))
            model.add(Reshape(self.img_shape))

            noise = Input(shape=(self.latent_dim,))
            img = model(noise)

            return Model(noise, img)

        def build_discriminator(self):
            """Build the discriminator: image -> probability of being real."""
            model = Sequential()
            # Flatten the input image into a vector.
            model.add(Flatten(input_shape=self.img_shape))
            model.add(Dense(512))
            model.add(LeakyReLU(alpha=0.2))
            model.add(Dense(256))
            model.add(LeakyReLU(alpha=0.2))
            # Single sigmoid unit: real (1) vs. fake (0).
            model.add(Dense(1, activation='sigmoid'))

            img = Input(shape=self.img_shape)
            validity = model(img)

            return Model(img, validity)

        def train(self, epochs, batch_size=128, sample_interval=50):
            """Alternate discriminator and generator updates on MNIST.

            Every ``sample_interval`` epochs a grid of generated samples is
            written to ``images/``.
            """
            (X_train, _), (_, _) = mnist.load_data()

            # Normalise pixel values from [0, 255] to [-1, 1] (tanh range).
            X_train = X_train / 127.5 - 1.
            X_train = np.expand_dims(X_train, axis=3)

            # Target labels: 1 for real images, 0 for generated ones.
            valid = np.ones((batch_size, 1))
            fake = np.zeros((batch_size, 1))

            for epoch in range(epochs):

                # --------------------------- #
                #   Train the discriminator on a random real batch
                #   and a freshly generated fake batch.
                # --------------------------- #
                idx = np.random.randint(0, X_train.shape[0], batch_size)
                imgs = X_train[idx]

                noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

                gen_imgs = self.generator.predict(noise)

                d_loss_real = self.discriminator.train_on_batch(imgs, valid)
                d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
                d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

                # --------------------------- #
                #   Train the generator (via the combined model) to make
                #   the discriminator label its output as real.
                # --------------------------- #
                noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
                g_loss = self.combined.train_on_batch(noise, valid)
                print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

                if epoch % sample_interval == 0:
                    self.sample_images(epoch)

        def sample_images(self, epoch):
            """Save a 5x5 grid of generated digits to images/<epoch>.png."""
            r, c = 5, 5
            noise = np.random.normal(0, 1, (r * c, self.latent_dim))
            gen_imgs = self.generator.predict(noise)

            # Rescale from [-1, 1] (tanh output) back to [0, 1] for plotting.
            gen_imgs = 0.5 * gen_imgs + 0.5

            # Fix: the original crashed if the output directory was missing.
            os.makedirs("images", exist_ok=True)

            fig, axs = plt.subplots(r, c)
            cnt = 0
            for i in range(r):
                for j in range(c):
                    axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
                    axs[i,j].axis('off')
                    cnt += 1
            fig.savefig("images/%d.png" % epoch)
            plt.close()

        def save(self, path):
            """Save discriminator, generator and combined model (with
            optimizer state) to ``<path>_*.h5`` files.
            """
            discriminator_path = path + '_discriminator.h5'
            generator_path = path + '_generator.h5'
            combined_path = path + '_combined.h5'

            # Temporarily mark the discriminator trainable so its optimizer
            # state is included in the saved file, then restore the flag.
            self.discriminator.trainable = True
            self.discriminator.save(discriminator_path)
            self.discriminator.trainable = False
            self.generator.save(generator_path)
            self.combined.save(combined_path)

        def load(self, path):
            """Restore the models saved by :meth:`save` and re-inject the
            combined model's optimizer weights.
            """
            discriminator_path = path + '_discriminator.h5'
            generator_path = path + '_generator.h5'
            combined_path = path + '_combined.h5'

            self.discriminator = tf.keras.models.load_model(discriminator_path)
            # compile=False: the generator is never trained directly, so its
            # optimizer state is irrelevant.
            self.generator = tf.keras.models.load_model(generator_path, compile=False)
            # Read the combined optimizer's weights straight from the HDF5
            # file; keep the handle open until the weights have been applied.
            f, c_weights = get_optimizer_weights(combined_path)

            self.discriminator.trainable = False

            # Rebuild the combined model exactly as in __init__.
            gan_input = Input(shape=(self.latent_dim,))
            img = self.generator(gan_input)
            validity = self.discriminator(img)
            self.combined = Model(gan_input, validity)
            self.combined.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))

            # NOTE: private Keras API — forces the optimizer to allocate its
            # slot variables so set_weights() has matching shapes to fill.
            self.combined.optimizer._create_all_weights(self.combined.trainable_variables)
            self.combined.optimizer.set_weights(c_weights)

            f.close()
    

    存储模型:

    # Build a fresh GAN, run one quick training step, then save all three
    # models (discriminator, generator, combined) including optimizer state.
    gan = GAN()
    gan.train(epochs=1, batch_size=256, sample_interval=1000)
    gan.save('savetest')
    

    读取模型:

    # Recreate the GAN and restore the models plus optimizer weights from disk.
    gan2 = GAN() 
    gan2.load('savetest')
    

    祝大家 造假炼丹,学习愉快。

    相关文章

      网友评论

          本文标题:Keras GAN 同时储存模型和优化器信息

          本文链接:https://www.haomeiwen.com/subject/dlyxiktx.html