GitHub 上有一个很好的开源项目(介绍见这里),包含了用 Keras 实现的 GAN 及其变种。但在研究的过程中发现,难以同时存储模型和优化器状态:如果使用 model.save,再次 load 时优化器状态会丢失;而 Keras 的 model 又不能直接 pickle。也尝试过单独 pickle 优化器,但会出现一些奇怪的问题,比如 Adam 在某些层上的动量不再更新。于是查看了 Keras 源码,实现了如下解决方法,分享给大家:
import os

import h5py
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import (Activation, BatchNormalization, Conv2D,
                                     Dense, Dropout, Flatten, Input,
                                     LeakyReLU, Reshape, UpSampling2D,
                                     ZeroPadding2D)
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.python.keras.saving import hdf5_format
def get_optimizer_weights(filepath):
    """Read the optimizer weight arrays stored in a Keras HDF5 model file.

    Args:
        filepath: Path to a ``.h5`` file written by ``model.save``.

    Returns:
        A tuple ``(f, weights)`` where ``f`` is the still-open ``h5py.File``
        and ``weights`` is the list of optimizer weight arrays. The caller
        must close ``f`` only after the weights have been consumed (they are
        backed by the open file).

    Raises:
        Whatever ``load_optimizer_weights_from_hdf5_group`` raises, e.g. if
        the model was saved with ``include_optimizer=False``; the file handle
        is closed before re-raising so it does not leak.
    """
    f = h5py.File(filepath, mode='r')
    try:
        weights = hdf5_format.load_optimizer_weights_from_hdf5_group(f)
    except Exception:
        # Don't leak the handle when the file has no optimizer group.
        f.close()
        raise
    return f, weights
class GAN():
    """Vanilla GAN trained on MNIST, with save/load that also restores the
    combined model's optimizer state (slot variables such as Adam moments).
    """

    def __init__(self):
        # --------------------------------- #
        #  28 x 28 x 1: the MNIST image shape
        # --------------------------------- #
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        # Dimensionality of the generator's latent (noise) input.
        self.latent_dim = 100

        # Separate Adam optimizers so discriminator and combined model each
        # keep their own slot variables (lr=0.0002, beta_1=0.5).
        optimizer_d = Adam(0.0002, 0.5)
        optimizer_c = Adam(0.0002, 0.5)

        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy',
                                   optimizer=optimizer_d,
                                   metrics=['accuracy'])

        self.generator = self.build_generator()
        gan_input = Input(shape=(self.latent_dim,))
        img = self.generator(gan_input)
        # Freeze the discriminator inside the combined model: only the
        # generator is updated when training through `combined`. The
        # discriminator is still trained via its own compiled model above.
        self.discriminator.trainable = False
        # Score the generated (fake) images.
        validity = self.discriminator(img)
        self.combined = Model(gan_input, validity)
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer_c)

    def build_generator(self):
        """Build the generator: latent vector -> 28x28x1 image in [-1, 1]."""
        model = Sequential()
        model.add(Dense(256, input_dim=self.latent_dim))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        # tanh keeps the output in [-1, 1], matching the input normalization.
        model.add(Dense(np.prod(self.img_shape), activation='tanh'))
        model.add(Reshape(self.img_shape))

        noise = Input(shape=(self.latent_dim,))
        img = model(noise)
        return Model(noise, img)

    def build_discriminator(self):
        """Build the discriminator: image -> probability that it is real."""
        model = Sequential()
        # Flatten the incoming image to a vector.
        model.add(Flatten(input_shape=self.img_shape))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(256))
        model.add(LeakyReLU(alpha=0.2))
        # Single sigmoid unit: real (1) vs. fake (0).
        model.add(Dense(1, activation='sigmoid'))

        img = Input(shape=self.img_shape)
        validity = model(img)
        return Model(img, validity)

    def train(self, epochs, batch_size=128, sample_interval=50):
        """Alternate discriminator/generator updates for `epochs` batches.

        Args:
            epochs: Number of training iterations (one batch each).
            batch_size: Samples per discriminator/generator update.
            sample_interval: Save a grid of generated images every this
                many iterations.
        """
        # Load MNIST; labels are unused.
        (X_train, _), (_, _) = mnist.load_data()
        # Normalize pixel values from [0, 255] to [-1, 1] (matches tanh).
        X_train = X_train / 127.5 - 1.
        X_train = np.expand_dims(X_train, axis=3)

        # Ground-truth labels: 1 for real images, 0 for generated ones.
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs):
            # --------------------------- #
            #  Train the discriminator on a random real batch plus an
            #  equally sized generated batch.
            # --------------------------- #
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[idx]
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
            gen_imgs = self.generator.predict(noise)
            d_loss_real = self.discriminator.train_on_batch(imgs, valid)
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # --------------------------- #
            #  Train the generator: try to make the (frozen) discriminator
            #  label fresh fakes as real.
            # --------------------------- #
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
            g_loss = self.combined.train_on_batch(noise, valid)

            print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))
            if epoch % sample_interval == 0:
                self.sample_images(epoch)

    def sample_images(self, epoch):
        """Save a 5x5 grid of generated digits to images/<epoch>.png."""
        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        gen_imgs = self.generator.predict(noise)
        # Rescale from [-1, 1] back to [0, 1] for display.
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
                axs[i, j].axis('off')
                cnt += 1
        # Create the output directory on first use instead of crashing
        # with FileNotFoundError.
        os.makedirs("images", exist_ok=True)
        fig.savefig("images/%d.png" % epoch)
        plt.close()

    def save(self, path):
        """Save all three models (with optimizer state) to <path>_*.h5."""
        discriminator_path = path + '_discriminator.h5'
        generator_path = path + '_generator.h5'
        combined_path = path + '_combined.h5'
        # Temporarily re-enable trainable so the discriminator is saved
        # together with its own optimizer state.
        self.discriminator.trainable = True
        self.discriminator.save(discriminator_path)
        self.discriminator.trainable = False
        self.generator.save(generator_path)
        self.combined.save(combined_path)

    def load(self, path):
        """Restore the models saved by `save`, including the combined
        model's optimizer state.

        The combined model cannot simply be loaded with `load_model`
        (generator and discriminator would be duplicated), so it is rebuilt
        from the loaded sub-models and only its optimizer weights are read
        back from the HDF5 file.
        """
        discriminator_path = path + '_discriminator.h5'
        generator_path = path + '_generator.h5'
        combined_path = path + '_combined.h5'
        # The discriminator is loaded compiled, so its optimizer state
        # comes back automatically; the generator needs no optimizer.
        self.discriminator = tf.keras.models.load_model(discriminator_path)
        self.generator = tf.keras.models.load_model(generator_path, compile=False)
        f, c_weights = get_optimizer_weights(combined_path)
        # Rebuild the combined model exactly as in __init__.
        self.discriminator.trainable = False
        gan_input = Input(shape=(self.latent_dim,))
        img = self.generator(gan_input)
        validity = self.discriminator(img)
        self.combined = Model(gan_input, validity)
        self.combined.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))
        # NOTE(review): `_create_all_weights` is a private Keras API; it
        # forces creation of the optimizer's slot variables so that
        # `set_weights` has matching shapes to write into. May break on
        # Keras/TF upgrades — verify against the installed version.
        self.combined.optimizer._create_all_weights(self.combined.trainable_variables)
        self.combined.optimizer.set_weights(c_weights)
        # Close the HDF5 file only after the weights have been applied.
        f.close()
存储模型:
# Train for one batch, then persist generator / discriminator / combined
# (with optimizer state) to savetest_*.h5.
gan = GAN()
gan.train(epochs=1, batch_size=256, sample_interval=1000)
gan.save('savetest')
读取模型:
# Restore the models and the combined model's optimizer state
# from the savetest_*.h5 files written above.
gan2 = GAN()
gan2.load('savetest')
祝大家 造假,炼丹,学习愉快。
网友评论