class VAE(object):
    """Variational auto-encoder over flattened 7x89 feature maps.

    Feature maps are pulled from a ``Warehouse`` that a ``Build_maps`` producer
    (started in ``__init__``) fills in the background. The network itself is
    built by ``vae.autoencoder`` in TensorFlow 1.x graph mode.
    """

    # Seconds to sleep while waiting for the producer to fill the warehouse,
    # instead of busy-spinning a CPU core.
    _POLL_INTERVAL = 0.1

    def __init__(self, n_hidden=500, dim_z=20, n_epochs=20, batch_size=128, learn_rate=1e-3,
                 model_path=base_dir + "/similarity_k_line/feature_extract/vae_feature_map_model"):
        """Store hyper-parameters, start the map producer and build the graph.

        Args:
            n_hidden: hidden-layer width passed to ``vae.autoencoder``.
            dim_z: latent dimensionality.
            n_epochs: training steps per pass of the outer loop in ``train``.
            batch_size: stored only. NOTE(review): ``train`` ignores it and
                draws a fixed 500 maps per step.
            learn_rate: Adam learning rate.
            model_path: directory holding checkpoints.
        """
        self.model_path = model_path
        # network architecture
        self.n_hidden = n_hidden
        self.dim_img = 7 * 89  # number of pixels in one flattened feature image
        self.dim_z = dim_z
        # training
        self.n_epochs = n_epochs
        self.batch_size = batch_size
        self.learn_rate = learn_rate
        # start the map-producer thread that fills self.warehouse
        self.data_generator()
        self.build_graph()

    def data_generator(self):
        """Create the shared warehouse and start a ``Build_maps`` producer on it."""
        pool = list()
        self.warehouse = Warehouse(pool=pool)
        producer = Build_maps(pool=self.warehouse)
        producer.start()

    def build_graph(self):
        """Create input placeholders, the VAE graph and its loss summaries."""
        # In a denoising auto-encoder x_hat == x + noise; otherwise x_hat == x.
        self.x_hat = tf.placeholder(tf.float32, shape=[None, self.dim_img], name='input_img')
        self.x = tf.placeholder(tf.float32, shape=[None, self.dim_img], name='target_img')
        self.global_steps = tf.Variable(0, trainable=False, name="global_steps")
        # dropout keep probability, fed at run time
        self.keep_prob = tf.placeholder(tf.float32, name='keep_prob')
        # The reconstruction output (first return value) is not used here.
        _, self.z, self.loss, self.neg_marginal_likelihood, self.KL_divergence = \
            vae.autoencoder(self.x_hat, self.x, self.dim_img, self.dim_z, self.n_hidden, self.keep_prob)
        with tf.name_scope('loss'):
            tf.summary.scalar('total_loss', self.loss)
            # Summary names may not contain spaces; TF would otherwise rename
            # them (with a warning) to exactly these underscore forms.
            tf.summary.scalar('KL_divergence', self.KL_divergence)
            tf.summary.scalar('likelihood_loss', self.neg_marginal_likelihood)

    def _restore_checkpoint(self, sess, saver):
        """Restore the latest checkpoint from ``self.model_path`` into ``sess``.

        Returns True if a checkpoint was found and restored, False otherwise.
        """
        ckpt = tf.train.get_checkpoint_state(self.model_path)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, tf.train.latest_checkpoint(checkpoint_dir=self.model_path))
            print('finish loading model!')
            return True
        return False

    def train(self):
        """Train the VAE forever on maps drawn from the warehouse.

        Restores the latest checkpoint if one exists, otherwise initializes all
        variables. Then it loops indefinitely. Whenever the warehouse holds at
        least 10000 maps, it runs ``n_epochs`` Adam steps on 500-map batches.
        Summaries are written to ``logs/`` and a checkpoint is saved every 10
        global steps.
        """
        import time

        self.train_op = tf.train.AdamOptimizer(self.learn_rate).minimize(
            self.loss, global_step=self.global_steps)
        # Created after train_op so Adam's slot variables are saved/restored too.
        saver = tf.train.Saver(max_to_keep=4)
        with tf.Session() as sess:
            writer = tf.summary.FileWriter("logs/", sess.graph)
            try:
                merge_summary = tf.summary.merge_all()
                if not self._restore_checkpoint(sess, saver):
                    print("no checkpoint found...")
                    sess.run(tf.global_variables_initializer())
                while True:
                    # Wait, without spinning, until the producer has buffered enough maps.
                    if self.warehouse.get_length() < 10000:
                        time.sleep(self._POLL_INTERVAL)
                        continue
                    for epoch in range(self.n_epochs):
                        # NOTE(review): batch size is hard-coded to 500; self.batch_size is unused.
                        batch_xs_input = np.array(
                            self.warehouse.get(num_retrived=500)).reshape([500, -1])
                        # Plain (non-denoising) VAE: the target is the input itself.
                        batch_xs_target = batch_xs_input
                        _, tot_loss, loss_likelihood, loss_divergence, train_summary, global_steps = sess.run(
                            (self.train_op, self.loss, self.neg_marginal_likelihood,
                             self.KL_divergence, merge_summary, self.global_steps),
                            feed_dict={self.x_hat: batch_xs_input,
                                       self.x: batch_xs_target,
                                       self.keep_prob: 0.9})
                        writer.add_summary(train_summary, global_steps)
                        print("epoch %d: L_tot %03.2f L_likelihood %03.2f L_divergence %03.2f" %
                              (epoch, tot_loss, loss_likelihood, loss_divergence))
                        if global_steps % 10 == 0:
                            saver.save(sess=sess, save_path=self.model_path + "/vae_model",
                                       global_step=global_steps)
            finally:
                # Flush pending summary events even if training is interrupted.
                writer.close()

    def run_encoder(self):
        """Restore the trained model and print latent codes for incoming maps.

        Loops forever. Whenever the warehouse holds more than 100 maps, it takes
        one, encodes it and prints ``z``.

        Raises:
            IOError: if no checkpoint exists in ``self.model_path``.
        """
        import time

        saver = tf.train.Saver(max_to_keep=4)
        with tf.Session() as sess:
            if not self._restore_checkpoint(sess, saver):
                # Encoding with uninitialized variables would only fail later inside sess.run.
                raise IOError("no checkpoint found in %s" % self.model_path)
            while True:
                if self.warehouse.get_length() <= 100:
                    time.sleep(self._POLL_INTERVAL)
                    continue
                feature_map_input = np.array(self.warehouse.get(num_retrived=1)).reshape([1, -1])
                # keep_prob=1.0 disables dropout so encoding is deterministic at inference.
                z = sess.run(fetches=self.z,
                             feed_dict={self.x_hat: feature_map_input,
                                        self.x: feature_map_input,
                                        self.keep_prob: 1.0})
                print(z)
# 网友评论 ("reader comments") — leftover web-page text, not part of the code.