美文网首页
2018-12-12-VAE

2018-12-12-VAE

作者: AI_Finance | 来源:发表于2018-12-12 17:30 被阅读0次

    class VAE(object):
        """Variational autoencoder over flattened feature maps (TensorFlow 1.x graph API).

        Construction starts a background producer that fills ``self.warehouse``
        with feature maps and builds the graph via ``vae.autoencoder``.
        ``train`` and ``run_encoder`` pull samples from the warehouse in endless
        loops; they only return if an exception escapes.
        """

        def __init__(self, n_hidden=500, dim_z=20, n_epochs=20, batch_size=128, learn_rate=1e-3,
                     model_path=base_dir + "/similarity_k_line/feature_extract/vae_feature_map_model"):
            """Store hyper-parameters, start the data producer and build the graph.

            Args:
                n_hidden: hidden-layer width passed to ``vae.autoencoder``.
                dim_z: latent dimension passed to ``vae.autoencoder``.
                n_epochs: number of optimizer steps per pass of ``train``'s outer loop
                    (each "epoch" is a single 500-sample batch).
                batch_size: stored on the instance; NOTE(review): not read by ``train``,
                    which uses a fixed batch of 500.
                learn_rate: Adam learning rate.
                model_path: checkpoint directory used for saving and restoring. The default
                    is built from the module-level ``base_dir`` when the class is defined.
            """
            # Parameters
            self.model_path = model_path

            # Network architecture
            self.n_hidden = n_hidden
            self.dim_img = 7 * 89  # number of pixels for a feature image
            self.dim_z = dim_z

            # Training
            self.n_epochs = n_epochs
            self.batch_size = batch_size
            self.learn_rate = learn_rate

            # Start a sub-thread that produces feature maps into self.warehouse.
            self.data_generator()

            # Build the graph.
            self.build_graph()

        def data_generator(self):
            """Create the shared sample buffer and start the map producer.

            ``Build_maps`` is started with ``.start()``; it presumably pushes feature
            maps into ``self.warehouse`` -- confirm against its definition.
            """
            self.warehouse = Warehouse(pool=[])
            # Keep a handle on the producer so callers can inspect or join it.
            self.map_producer = Build_maps(pool=self.warehouse)
            self.map_producer.start()

        def build_graph(self):
            """Create input placeholders, the VAE network and its loss summaries."""
            # In a denoising autoencoder x_hat == x + noise; otherwise x_hat == x.
            self.x_hat = tf.placeholder(tf.float32, shape=[None, self.dim_img], name='input_img')
            self.x = tf.placeholder(tf.float32, shape=[None, self.dim_img], name='target_img')
            self.global_steps = tf.Variable(0, trainable=False, name="global_steps")

            # Dropout keep probability, fed at run time.
            self.keep_prob = tf.placeholder(tf.float32, name='keep_prob')

            # vae.autoencoder returns five values; the first one is not used here.
            _, self.z, self.loss, self.neg_marginal_likelihood, self.KL_divergence = \
                vae.autoencoder(self.x_hat, self.x, self.dim_img, self.dim_z, self.n_hidden, self.keep_prob)

            with tf.name_scope('loss'):
                # Summary names may not contain spaces: TF would rewrite 'KL divergence'
                # to 'KL_divergence' (with a warning), so use the legal names directly.
                tf.summary.scalar('total_loss', self.loss)
                tf.summary.scalar('KL_divergence', self.KL_divergence)
                tf.summary.scalar('likelihood_loss', self.neg_marginal_likelihood)

        def _restore_checkpoint(self, sess, saver):
            """Restore the latest checkpoint found in ``self.model_path`` into ``sess``.

            Returns:
                True if a checkpoint was found and restored, False otherwise.
            """
            ckpt = tf.train.get_checkpoint_state(self.model_path)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, tf.train.latest_checkpoint(checkpoint_dir=self.model_path))
                print('finish loading model!')
                return True
            return False

        def _wait_for_samples(self, min_length, poll_interval=0.1):
            """Block until ``self.warehouse`` holds at least ``min_length`` samples.

            Sleeps between polls instead of spinning, so the wait does not burn a
            CPU core while the producer fills the buffer.
            """
            import time

            while self.warehouse.get_length() < min_length:
                time.sleep(poll_interval)

        def train(self):
            """Train the VAE indefinitely on batches pulled from ``self.warehouse``.

            Restores the latest checkpoint in ``self.model_path`` if one exists,
            otherwise initializes all variables. Each pass of the outer loop waits
            until the warehouse holds at least 10000 samples, then runs
            ``self.n_epochs`` optimizer steps on 500-sample batches. A checkpoint is
            saved whenever the global step is a multiple of 10. Summaries go to
            ``logs/``. Only returns if an exception escapes the loop.
            """
            self.train_op = tf.train.AdamOptimizer(self.learn_rate).minimize(
                self.loss, global_step=self.global_steps)
            # Created after train_op so the optimizer's variables are saved/restored too.
            saver = tf.train.Saver(max_to_keep=4)

            with tf.Session() as sess:
                writer = tf.summary.FileWriter("logs/", sess.graph)
                merge_summary = tf.summary.merge_all()
                try:
                    if not self._restore_checkpoint(sess, saver):
                        print("no checkpoint found...")
                        sess.run(tf.global_variables_initializer())

                    while True:
                        self._wait_for_samples(10000)
                        for epoch in range(self.n_epochs):
                            # NOTE(review): self.batch_size is ignored; the batch size is fixed at 500.
                            batch = np.array(self.warehouse.get(num_retrived=500)).reshape([500, -1])
                            # Non-denoising setup: the input batch is also the reconstruction target.
                            _, tot_loss, loss_likelihood, loss_divergence, train_summary, global_steps = sess.run(
                                (self.train_op, self.loss, self.neg_marginal_likelihood, self.KL_divergence,
                                 merge_summary, self.global_steps),
                                feed_dict={self.x_hat: batch, self.x: batch, self.keep_prob: 0.9})
                            writer.add_summary(train_summary, global_steps)

                            # Print the losses after every step.
                            print("epoch %d: L_tot %03.2f L_likelihood %03.2f L_divergence %03.2f" %
                                  (epoch, tot_loss, loss_likelihood, loss_divergence))

                            if global_steps % 10 == 0:
                                saver.save(sess=sess, save_path=self.model_path + "/vae_model",
                                           global_step=global_steps)
                finally:
                    # Flush pending summaries even when the endless loop is interrupted.
                    writer.close()

        def run_encoder(self):
            """Encode warehouse samples one at a time, indefinitely, printing each code.

            Requires a trained checkpoint in ``self.model_path``.

            Raises:
                FileNotFoundError: if no checkpoint is found; without one the graph
                    variables would be uninitialized.
            """
            saver = tf.train.Saver(max_to_keep=4)
            with tf.Session() as sess:
                if not self._restore_checkpoint(sess, saver):
                    raise FileNotFoundError("no checkpoint found in %s" % self.model_path)

                while True:
                    # Proceed once more than 100 samples are buffered.
                    self._wait_for_samples(101)
                    feature_map_input = np.array(self.warehouse.get(num_retrived=1)).reshape([1, -1])
                    # Inference: keep_prob=1.0 disables dropout. NOTE(review): if
                    # vae.autoencoder samples z (reparameterization), z is still stochastic.
                    z = sess.run(fetches=self.z,
                                 feed_dict={self.x_hat: feature_map_input, self.x: feature_map_input,
                                            self.keep_prob: 1.0})
                    print(z)

    相关文章

      网友评论

          本文标题:2018-12-12-VAE

          本文链接:https://www.haomeiwen.com/subject/kwyphqtx.html