美文网首页Machine Learning & Recommendation & NLP & DL
自然语言处理N天-Day1103从0搭建一个RNN神经网络作诗(

自然语言处理N天-Day1103从0搭建一个RNN神经网络作诗(

作者: 我的昵称违规了 | 来源:发表于2019-02-19 09:15 被阅读2次

    说明:本文依据Github上面的一个2000星项目完成。项目作者jinfagang项目地址,在这里感谢那些开源的程序员,让我们学到更多。
    我会尽量将项目进行拆解,希望对大家的学习有所帮助吧。

    第十一课 使用RNN生成古诗

    上一节数据预处理和模型构建

    4.模型的训练

    获取数据 batch 的代码位于poem.py的generate_batch方法,作用是用来获取每一个batch的数值。作为接下来模型训练的输入数据集。
    传入参数有batch_size:batch的大小;poems_vec:前面生成的诗文中字ID;word_to_int:前面生成的每一个字ID。

    def generate_batch(batch_size, poems_vec, word_to_int):
        # 每次取batch_size首诗进行训练
        n_chunk = len(poems_vec) // batch_size
        x_batches = []
        y_batches = []
        #使用for循环,生成n_chunk个batch。
        for i in range(n_chunk):
            #每一个batch开始和结束的index
            start_index = i * batch_size
            end_index = start_index + batch_size
            batches = poems_vec[start_index:end_index]
            
            # 找到这个batch中所有poem最长的poem的长度,以这个长度为最大值生成batch中每一行的长度。
            length = max(map(len, batches))
            # 填充一个空batch,空的地方放空格对应的index标号
            x_data = np.full((batch_size, length), word_to_int[' '], np.int32)
    
            for row, batch in enumerate(batches):
                # 每一行就是一首诗,在原本的长度上把诗还原上去
                x_data[row, :len(batch)] = batch
            y_data = np.copy(x_data)
            # y就是x向左边移动一个,最后一位使用倒数第二位的数值填充
            y_data[:, :-1] = x_data[:, 1:]
            x_batches.append(x_data)
            y_batches.append(y_data)
        return x_batches, y_batches
    

    模型的训练代码位于train.py的run_training方法

    # -*- coding: utf-8 -*-
    import tensorflow as tf
    import os
    import poems
    import models
    
    tf.app.flags.DEFINE_integer('batch_size', 64, 'batch size')
    tf.app.flags.DEFINE_float('learning_rate', 0.01, 'learning rate')
    tf.app.flags.DEFINE_string('model_dir', os.path.abspath('./model'), 'model save path')
    tf.app.flags.DEFINE_string('file_path', os.path.abspath('./data/poems.txt'), 'file name of poems')
    tf.app.flags.DEFINE_string('model_prefix', 'poems', 'model save prefix')
    tf.app.flags.DEFINE_integer('epochs', 50, 'train how many epochs')
    
    FLAGS = tf.app.flags.FLAGS
    
    
    def run_training():
        if not os.path.exists(FLAGS.model_dir):
            os.makedirs(FLAGS.model_dir)
        # 读取诗集文件
        # 依次得到数字ID表示的诗句、汉字-ID的映射map、所有的汉字的列表
        poems_vector, word_to_int, vocabularies = poems.process_poems(FLAGS.file_path)
        batches_inputs, batches_outputs = poems.generate_batch(FLAGS.batch_size, poems_vector, word_to_int)
    
        input_data = tf.placeholder(tf.int32, [FLAGS.batch_size, None])
        output_targets = tf.placeholder(tf.int32, [FLAGS.batch_size, None])
    
        # 通过rnn模型得到结果状态集
        end_points = models.rnn_model(model='lstm', input_data=input_data, output_data=output_targets,
                                      vocab_size=len(vocabularies), rnn_size=128, num_layers=2, batch_size=64,
                                      learning_rate=FLAGS.learning_rate)
    
        # 初始化saver和session
        saver = tf.train.Saver(tf.global_variables())
        init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
        with tf.Session() as sess:
            sess.run(init_op)
    
            start_epoch = 0
            checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
            if checkpoint:
                saver.restore(sess, checkpoint)
                print('## restore from the checkpointt {0}'.format(checkpoint))
                start_epoch += int(checkpoint.split('-')[-1])
            print('## strat training...')
    
            try:
                n_chunk = len(poems_vector) // FLAGS.batch_size
                for epoch in range(start_epoch, FLAGS.epoches):
                    n = 0
                    for batch in range(n_chunk):
                        # 训练并计算loss
                        # batches_inputs[n]: 第n个batch的输入数据
                        # batches_outputs[n]: 第n个batch的输出数据
                        loss, _, _ = sess.run([
                            end_points['total_loss'],
                            end_points['last_state'],
                            end_points['train_op']],
                            feed_dict={input_data: batches_inputs[n], output_targets: batches_outputs[n]}
                        )
                        n += 1
                        print('Epoch: %d, batch: %d, training loss: %.6f' % (epoch, batch, loss))
                        # 每训练6个epoch进行一次模型保存
                        if epoch % 6 == 0:
                            saver.save(sess, os.path.join(FLAGS.model_dir, FLAGS.model_prefix), global_step=epoch)
            except KeyboardInterrupt:
                print('## Interrupt manually, try saving checkpoint for now...')
                saver.save(sess, os.path.join(FLAGS.model_dir, FLAGS.model_prefix), global_step=epoch)
                print('## Last epoch were saved, next time will start from epoch {}.'.format(epoch))
    
    
    def main(_):
        run_training()
    
    
    if __name__ == '__main__':
        tf.app.run()
    
    

    相关文章

      网友评论

        本文标题:自然语言处理N天-Day1103从0搭建一个RNN神经网络作诗(

        本文链接:https://www.haomeiwen.com/subject/bmtxyqtx.html