利用RNN和LSTM生成小说题记

作者: _两只橙_ | 来源:发表于2018-03-16 19:41 被阅读212次

    一、选取素材

    • 语料格式
    • 题记:此情可待成追忆,只是当时已惘然。

    二、开发环境

    三、实战代码

    #!/bash/bin
    # -*-coding:utf-8-*-
    
    import sys
    import os
    import numpy as np
    import collections
    import tensorflow as tf
    import tensorflow.contrib.rnn as rnn
    import tensorflow.contrib.legacy_seq2seq as seq2seq
    
    # Sentinel characters: every corpus line is wrapped as BEGIN_CHAR + text + END_CHAR,
    # and UNKNOWN_CHAR replaces out-of-vocabulary characters (it is also used as batch padding).
    BEGIN_CHAR = '^'
    END_CHAR = '$'
    UNKNOWN_CHAR = '*'
    # Lines longer than MAX_LENGTH are truncated; lines whose length is not greater
    # than MIN_LENGTH are dropped.
    MAX_LENGTH = 100
    MIN_LENGTH = 10
    # Upper bound on the number of distinct characters kept in the vocabulary.
    max_words = 3000
    # Number of passes over the batched corpus during training.
    epochs = 50
    # Corpus file
    poetry_file = 'story.txt'
    # Directory where model checkpoints are stored
    save_dir = 'model'
    
    
    class Data:
        """Load the corpus, build a character vocabulary and pre-compute batches.

        ``load`` sets ``poetrys``, ``words``, ``words_size``, ``char2id_dict``,
        ``id2char_dict``, ``unknow_char``, ``char2id``, ``id2char`` and
        ``poetrys_vector``; ``create_batches`` sets ``n_size``, ``x_batches``
        and ``y_batches``.
        """

        def __init__(self):
            self.batch_size = 64
            self.poetry_file = poetry_file
            self.load()
            self.create_batches()

        def load(self):
            """Read ``self.poetry_file`` and encode every usable line as character ids.

            Each line is stripped, has its spaces removed and is split on ':';
            the second field is taken as the text.  Lines without a ':' (for
            example blank lines) are skipped instead of raising ``IndexError``.

            Raises:
                ValueError: if no line of the corpus survives the filtering.
            """
            def handle(line):
                # Over-long lines are cut just after the last '。' found within
                # the first MAX_LENGTH characters, or after MAX_LENGTH + 1
                # characters when no such full stop exists.
                if len(line) > MAX_LENGTH:
                    index_end = line.rfind('。', 0, MAX_LENGTH)
                    index_end = index_end if index_end > 0 else MAX_LENGTH
                    line = line[:index_end + 1]
                return BEGIN_CHAR + line + END_CHAR

            contents = []
            # The context manager guarantees the corpus file is closed.
            with open(self.poetry_file, encoding='utf-8') as corpus:
                for raw_line in corpus:
                    fields = raw_line.strip().replace(' ', '').split(':')
                    if len(fields) < 2:
                        continue
                    contents.append(fields[1])

            self.poetrys = [handle(line) for line in contents if len(line) > MIN_LENGTH]
            if not self.poetrys:
                raise ValueError(
                    'no usable lines found in {}'.format(self.poetry_file))

            # Count every character and order them by descending frequency.
            counter = collections.Counter()
            for poetry in self.poetrys:
                counter.update(poetry)
            count_pairs = sorted(counter.items(), key=lambda x: -x[1])
            words, _ = zip(*count_pairs)

            # Keep the most frequent characters as the vocabulary; characters
            # outside it are represented by UNKNOWN_CHAR.
            words_size = min(max_words, len(words))
            self.words = words[:words_size] + (UNKNOWN_CHAR,)
            self.words_size = len(self.words)

            # Character <-> id mappings.
            self.char2id_dict = {w: i for i, w in enumerate(self.words)}
            self.id2char_dict = {i: w for i, w in enumerate(self.words)}
            self.unknow_char = self.char2id_dict.get(UNKNOWN_CHAR)
            self.char2id = lambda char: self.char2id_dict.get(char, self.unknow_char)
            self.id2char = lambda num: self.id2char_dict.get(num)
            # Sorting by length keeps lines of similar length in the same batch.
            self.poetrys = sorted(self.poetrys, key=len)
            self.poetrys_vector = [list(map(self.char2id, poetry)) for poetry in self.poetrys]

        def create_batches(self):
            """Group ``poetrys_vector`` into padded input/target batch arrays.

            Lines that do not fill a complete batch are dropped.  Within a batch,
            shorter rows are padded in place with ``self.unknow_char`` up to the
            longest row.  Each target array is the input shifted left by one
            position; its last column is left equal to the input's last column.
            """
            self.n_size = len(self.poetrys_vector) // self.batch_size
            self.poetrys_vector = self.poetrys_vector[:self.n_size * self.batch_size]
            self.x_batches = []
            self.y_batches = []
            for i in range(self.n_size):
                batches = self.poetrys_vector[i * self.batch_size: (i + 1) * self.batch_size]
                length = max(map(len, batches))
                for row in batches:
                    row.extend([self.unknow_char] * (length - len(row)))
                xdata = np.array(batches)
                ydata = np.copy(xdata)
                ydata[:, :-1] = xdata[:, 1:]
                self.x_batches.append(xdata)
                self.y_batches.append(ydata)
    
    
    class Model:
        """Character-level recurrent language model built as a TensorFlow 1.x graph.

        Args:
            data: object providing ``batch_size`` and ``words_size``.
            model: recurrent cell type, one of ``'rnn'``, ``'gru'`` or ``'lstm'``.
            infer: when true the graph is built for a batch size of 1,
                otherwise ``data.batch_size`` is used.

        Raises:
            ValueError: if ``model`` is not a supported cell type.

        The constructor exposes the placeholders ``x_tf`` / ``y_tf``, the
        recurrent state tensors ``initial_state`` / ``final_state``, the
        per-step ``logits`` and ``probs``, the scalar ``cost``, the non-trainable
        ``learning_rate`` variable and the ``train_op``.
        """

        def __init__(self, data, model='lstm', infer=False):
            self.rnn_size = 128
            self.n_layers = 2

            self.batch_size = 1 if infer else data.batch_size

            cell_types = {
                'rnn': rnn.BasicRNNCell,
                'gru': rnn.GRUCell,
                'lstm': rnn.BasicLSTMCell,
            }
            if model not in cell_types:
                raise ValueError('unknown model type {!r}; expected one of {}'
                                 .format(model, sorted(cell_types)))

            def make_cell():
                # Only the LSTM cell takes ``state_is_tuple``; the plain RNN and
                # GRU cells do not accept that argument.
                if model == 'lstm':
                    return rnn.BasicLSTMCell(self.rnn_size, state_is_tuple=False)
                return cell_types[model](self.rnn_size)

            # Every layer needs its own cell object: stacking one instance
            # several times is rejected by newer TF 1.x releases.
            self.cell = rnn.MultiRNNCell(
                [make_cell() for _ in range(self.n_layers)], state_is_tuple=False)

            # Inputs and targets are id matrices with a variable number of columns.
            self.x_tf = tf.placeholder(tf.int32, [self.batch_size, None])
            self.y_tf = tf.placeholder(tf.int32, [self.batch_size, None])

            self.initial_state = self.cell.zero_state(self.batch_size, tf.float32)

            with tf.variable_scope('rnnlm'):
                softmax_w = tf.get_variable("softmax_w", [self.rnn_size, data.words_size])
                softmax_b = tf.get_variable("softmax_b", [data.words_size])
                with tf.device("/cpu:0"):
                    embedding = tf.get_variable(
                        "embedding", [data.words_size, self.rnn_size])
                    inputs = tf.nn.embedding_lookup(embedding, self.x_tf)

            outputs, final_state = tf.nn.dynamic_rnn(
                self.cell, inputs, initial_state=self.initial_state, scope='rnnlm')

            # Flatten all steps into rows so one matmul projects every step onto
            # the vocabulary.
            self.output = tf.reshape(outputs, [-1, self.rnn_size])
            self.logits = tf.matmul(self.output, softmax_w) + softmax_b
            self.probs = tf.nn.softmax(self.logits)
            self.final_state = final_state
            pred = tf.reshape(self.y_tf, [-1])
            # Every target position carries weight 1, padded positions included.
            loss = seq2seq.sequence_loss_by_example([self.logits],
                                                    [pred],
                                                    [tf.ones_like(pred, dtype=tf.float32)], )

            self.cost = tf.reduce_mean(loss)
            # The learning rate is a variable so the training loop can change it.
            self.learning_rate = tf.Variable(0.0, trainable=False)
            tvars = tf.trainable_variables()
            # Clip the gradients to a global norm of 5.
            grads, _ = tf.clip_by_global_norm(tf.gradients(self.cost, tvars), 5)

            optimizer = tf.train.AdamOptimizer(self.learning_rate)
            self.train_op = optimizer.apply_gradients(zip(grads, tvars))
    
    
    def train(data, model):
        """Train ``model`` on the batches in ``data`` and write checkpoints to ``save_dir``.

        The learning rate starts at 0.002 and is multiplied by 0.97 for every
        completed epoch.  A checkpoint is written every 1000 steps and once more
        after the very last step.  Progress is printed on a single console line.

        Args:
            data: provides ``n_size``, ``x_batches`` and ``y_batches``.
            model: provides ``x_tf``, ``y_tf``, ``cost``, ``train_op`` and
                ``learning_rate``.

        Returns:
            None.
        """
        # Make sure the checkpoint directory exists before the first save.
        os.makedirs(save_dir, exist_ok=True)
        checkpoint_path = os.path.join(save_dir, 'model.ckpt')

        # Build the learning-rate update once; creating a new assign op inside
        # the loop would add a node to the graph on every epoch.
        lr_value = tf.placeholder(tf.float32, shape=[])
        assign_lr = tf.assign(model.learning_rate, lr_value)

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            saver = tf.train.Saver(tf.global_variables())
            total_steps = epochs * data.n_size
            n = 0
            for epoch in range(epochs):
                sess.run(assign_lr, feed_dict={lr_value: 0.002 * (0.97 ** epoch)})
                for batche in range(data.n_size):
                    n += 1
                    step = epoch * data.n_size + batche
                    feed_dict = {model.x_tf: data.x_batches[batche],
                                 model.y_tf: data.y_batches[batche]}
                    train_loss, _ = sess.run([model.cost, model.train_op], feed_dict=feed_dict)
                    sys.stdout.write('\r')
                    info = "{}/{} (epoch {}) | train_loss {:.3f}" \
                        .format(step, total_steps, epoch, train_loss)
                    sys.stdout.write(info)
                    sys.stdout.flush()
                    # Save periodically and after the final step.
                    is_last_step = epoch == epochs - 1 and batche == data.n_size - 1
                    if step % 1000 == 0 or is_last_step:
                        saver.save(sess, checkpoint_path, global_step=n)
                        sys.stdout.write('\n')
                        print("model saved to {}".format(checkpoint_path))
                sys.stdout.write('\n')
    
    
    def sample(data, model, head=u''):
        """Generate text from the latest checkpoint found in ``save_dir``.

        Args:
            data: provides ``char2id``, ``id2char`` and ``char2id_dict``.
            model: a model whose placeholders accept a batch of one row.
            head: optional string.  When given, one clause is generated per
                character of ``head``, each clause starting with that character
                and ending at the first sampled ',' or '。'.  When empty, text is
                sampled from BEGIN_CHAR until END_CHAR is drawn.

        Returns:
            The generated string, or a message naming the first character of
            ``head`` that is not in the vocabulary.
        """
        def to_word(weights):
            # Inverse-CDF sampling: draw a point in [0, total) and find the
            # bucket of the cumulative sum it falls into.
            cumulative = np.cumsum(weights)
            draw = np.random.rand() * cumulative[-1]
            index = int(np.searchsorted(cumulative, draw))
            # Guard against floating-point rounding pushing the index past the end.
            index = min(index, len(cumulative) - 1)
            return data.id2char(index)

        def step(session, chars, state):
            # Feed ``chars`` as a single-row batch and sample the character that
            # follows its last position.
            x = np.array([[data.char2id(char) for char in chars]], dtype=np.int32)
            probs, new_state = session.run(
                [model.probs, model.final_state],
                {model.x_tf: x, model.initial_state: state})
            return to_word(probs[-1]), new_state

        for word in head:
            if word not in data.char2id_dict:
                return u'{} 不在字典中'.format(word)

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())

            saver = tf.train.Saver(tf.global_variables())
            model_file = tf.train.latest_checkpoint(save_dir)
            saver.restore(sess, model_file)

            # Evaluate the zero state once instead of building a new op per use.
            zero_state = sess.run(model.cell.zero_state(1, tf.float32))

            if head:
                print('生成题记 ---> ', head)
                poem = BEGIN_CHAR
                for head_word in head:
                    poem += head_word
                    # Re-feed everything generated so far from a zero state, then
                    # continue one character at a time with the carried state.
                    word, state = step(sess, poem, zero_state)
                    while word != u',' and word != u'。':
                        poem += word
                        word, state = step(sess, word, state)
                    poem += word
                # Drop the leading BEGIN_CHAR.
                return poem[1:]

            poem = ''
            word, state = step(sess, BEGIN_CHAR, zero_state)
            while word != END_CHAR:
                poem += word
                word, state = step(sess, word, state)
            return poem
    
    
    if __name__ == '__main__':

        # Train the model.  train() returns nothing, so its result is not printed.
        data = Data()
        model = Model(data=data, infer=False)
        train(data, model)

        # To generate an epigraph from a trained checkpoint, run these lines instead:
        # data = Data()
        # model = Model(data=data, infer=True)
        # print(sample(data, model, head='我为秋香'))
    
    
    输出
    生成题记 --->  我为秋香
    我罢性不行,为德劝仙兴。秋风暝冰始,香巢深器酒。
    
    输出

    相关文章

      网友评论

        本文标题:利用RNN和LSTM生成小说题记

        本文链接:https://www.haomeiwen.com/subject/rdvrqftx.html