美文网首页
tensorflow生成——lstm生成诗

tensorflow生成——lstm生成诗

作者: small瓜瓜 | 来源:发表于2021-08-10 11:58 被阅读0次

    学习tensorflow的生成,在网上查到生成诗的例子,改了一下,将所有的代码全部放在一个文件中,进行运行。

    #!-*- coding: utf-8 -*-
    
    import math
    import os
    
    import numpy as np
    import tensorflow as tf
    from tensorflow.keras.models import load_model
    from tensorflow.keras.preprocessing.sequence import pad_sequences
    from tensorflow.keras.preprocessing.text import Tokenizer
    
    # 禁用词,包含如下字符的唐诗将被忽略
    disallowed_words = ['(', ')', '(', ')', '__', '《', '》', '【', '】', '[', ']']
    max_len = 64
    # 最小词频
    min_word_frequency = 6
    
    # 共训练多少个epoch
    epochs = 20
    # 训练的batch size
    batch_size = 128
    # 数据集路径
    dataset_path = './poetry.txt'
    random = True
    
    # 加载数据集
    lines = []
    with open(dataset_path, 'r', encoding='utf-8') as f:
        for line in f.readlines():
            # 将冒号统一成相同格式
            lines.append(line.replace(':', ':'))
    
    # 数据集列表
    poetries = []
    # 逐行处理读取到的数据
    for line in lines:
        # 有且只能有一个冒号用来分割标题
        if line.count(':') != 1:
            continue
        # 后半部分不能包含禁止词
        __, last_part = line.split(':')
        ignore_flag = False
        for dis_word in disallowed_words:
            if dis_word in last_part:
                ignore_flag = True
                break
        if ignore_flag:
            continue
        # 长度不能超过最大长度
        if len(last_part) > max_len - 2:
            continue
        # 为了使用tensorflow的text的分词器,这里将分割使用" "
        poetries.append(last_part.replace('\n', '').replace("", ' ').strip())
    
    # 分词器
    tokenize = Tokenizer()
    tokenize.fit_on_texts(poetries)
    
    # 字汇
    words = [word for word, count in tokenize.word_counts.items() if count > min_word_frequency]
    words = ['[PAD]', '[UNK]', '[CLS]', '[SEP]'] + words
    vocab_size = len(words)
    # 创建字典
    word_dict = dict(zip(words, range(vocab_size)))
    index_word = dict((value, key) for key, value in word_dict.items())
    
    # 获取总的训练数据
    total = len(poetries)
    best_model_path = 'my_bast_model.h5'
    
    
    def __epoch_iter__():
        # 是否随机混洗
        if random:
            np.random.shuffle(poetries)
        # 迭代一个epoch,每次yield一个batch
        for start in range(0, total, batch_size):
            end = min(start + batch_size, total)
            batch_data = []
            # 逐一对古诗进行编码
            for single_data in poetries[start:end]:
                single_seq = [word_dict.get(item, word_dict['[UNK]']) for item in single_data if item != ' ']
                single_seq = [word_dict['[CLS]']] + single_seq + [
                    word_dict['[SEP]']]
                batch_data.append(single_seq)
    
            # 填充为相同长度
            batch_data = pad_sequences(batch_data, value=word_dict['[PAD]'])
            # yield x,y
            yield batch_data[:, :-1], tf.one_hot(batch_data[:, 1:], vocab_size)
            del batch_data
    
    
    def get_train():
        while True:
            yield from __epoch_iter__()
    
    
    if os.path.exists(best_model_path):
        # 加载模型
        print(f"===== 加载模型 {best_model_path} ======")
        model = load_model(best_model_path)
    else:
        # 构建模型
        model = tf.keras.Sequential([
            # 不定长度的输入
            tf.keras.layers.Input((None,)),
            # 词嵌入层
            tf.keras.layers.Embedding(input_dim=vocab_size, output_dim=128),
            # 第一个LSTM层,返回序列作为下一层的输入
            tf.keras.layers.LSTM(128, dropout=0.5, return_sequences=True),
            # 第二个LSTM层,返回序列作为下一层的输入
            tf.keras.layers.LSTM(128, dropout=0.5, return_sequences=True),
            # 对每一个时间点的输出都做softmax,预测下一个词的概率
            tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(vocab_size, activation='softmax')),
        ])
    
    # 查看模型结构
    model.summary()
    # 配置优化器和损失函数
    model.compile(optimizer=tf.keras.optimizers.Adam(), loss=tf.keras.losses.categorical_crossentropy)
    
    
    class Evaluate(tf.keras.callbacks.Callback):
        """
        在每个epoch训练完成后,保留最优权重,并随机生成settings.SHOW_NUM首古诗展示
        """
    
        def __init__(self):
            super().__init__()
            # 给loss赋一个较大的初始值
            self.lowest = 1e10
    
        def on_epoch_end(self, epoch, logs=None):
            # 在每个epoch训练完成后调用
            # 如果当前loss更低,就保存当前模型参数
            if logs['loss'] <= self.lowest:
                self.lowest = logs['loss']
                model.save(best_model_path)
            # 随机生成几首古体诗测试,查看训练效果
            token_ids = [word_dict['[CLS]']]
            while len(token_ids) < max_len:
                output = model(np.reshape(token_ids, (1, -1)))
                probas = output.numpy()[0, -1, 3:]
                del output
                if np.argmax(probas) == 0:
                    break
    
                probas = probas[1:]
                # 按照出现概率,对所有token倒序排列
                p_args = probas.argsort()[::-1][:100]
                # 排列后的概率顺序
                p = probas[p_args]
                # 先对概率归一
                p = p / sum(p)
                # 再按照预测出的概率,随机选择一个词作为预测结果
                target_index = np.random.choice(len(p), p=p)
                target = p_args[target_index] + 4
                # 保存
                token_ids.append(target)
    
            out_text = [index_word[item] for item in token_ids]
            print(''.join(out_text[1:]).replace("。", "。\n"))
    
    
    # 开始训练
    model.fit(get_train(), epochs=epochs,
              steps_per_epoch=int(math.floor(total / batch_size)),
              callbacks=[Evaluate()])
    
    运行结果如下:
    ===== 加载模型 my_bast_model.h5 ======
    Model: "sequential"
    _________________________________________________________________
    Layer (type)                 Output Shape              Param #   
    =================================================================
    embedding (Embedding)        (None, None, 128)         456064    
    _________________________________________________________________
    lstm (LSTM)                  (None, None, 128)         131584    
    _________________________________________________________________
    lstm_1 (LSTM)                (None, None, 128)         131584    
    _________________________________________________________________
    time_distributed (TimeDistri (None, None, 3563)        459627    
    =================================================================
    Total params: 1,178,859
    Trainable params: 1,178,859
    Non-trainable params: 0
    _________________________________________________________________
    
    Epoch 1/20
    191/191 [==============================] - 63s 287ms/step - loss: 3.4426
    朝朝不识老,与有君时便。
    此年千地行,长与故阳客。
    清上碧庭色,清溪已故夕。
    无此不同心,时间泪有游。
    
    Epoch 2/20
    191/191 [==============================] - 55s 289ms/step - loss: 3.4255
    山门与古远,尽夜生幽城。
    世坐知书梦,相知道自闲。
    人朝孤户下,红井草还间。
    欲得君亲至,同归独不迷。
    
    Epoch 3/20
    191/191 [==============================] - 55s 287ms/step - loss: 3.4091
    金公八子至,八国未难逢。
    大火从相息,千门不见兵。
    玉山流岸雪,疏响月鸣回。
    却恨不伤到,无来一后游。
    
    Epoch 4/20
    191/191 [==============================] - 55s 286ms/step - loss: 3.4174
    幽来大灵名,所不无同归。
    风前月下水,水发自飞云。
    路起心无住,归林一自斜。
    烟尘千万落枝长,乱鸟初含玉辇烟。
    
    Epoch 5/20
    191/191 [==============================] - 55s 288ms/step - loss: 3.3755
    何闻有所及,不是日光迟。
    吴首向关路,又然清易伤。
    雪连空北尽,孤月待天稀。
    遥似西台去,多年谢病期。
    
    Epoch 6/20
    191/191 [==============================] - 55s 287ms/step - loss: 3.3861
    霜声入风水,月日欲长秋。
    天石烟云落,青窗月满扉。
    不勤经我远,高节向云山。
    远里浮山意,还今应奈亲。
    
    Epoch 7/20
    191/191 [==============================] - 55s 286ms/step - loss: 3.3795
    妾处青云寺,空行满梦情。
    闲风秋雨起,江雪海云寒。
    花色前峰日,春前见马流。
    无乡如思意,一日话吾亲。
    
    Epoch 8/20
    191/191 [==============================] - 59s 310ms/step - loss: 3.3618
    雨出见南园,暮愁江雨稀。
    一高人与客,千古在山时。
    不觉随湖发,谁逢月下云。
    长家无共过,吟去寄乡名。
    
    Epoch 9/20
    191/191 [==============================] - 56s 293ms/step - loss: 3.3708
    江阳三下楚,风雨动征年。
    草雨山阴雪,云霜出岸灯。
    归思无客贵,从见泪离身。
    
    Epoch 10/20
    191/191 [==============================] - 57s 297ms/step - loss: 3.3599
    荷霏香翠遍,风彩郁成尘。
    御玉天霞入,残轩积雨林。
    沙浓声落暗,海水滴生流。
    不要休何处,空明梦复多。
    

    古诗文件是直接使用网上的,也可以将github中的古诗集项目拿来用,这里不做演示。
    项目源码https://gitee.com/MIEAPP/deep-learning/tree/master/example10

    相关文章

      网友评论

          本文标题:tensorflow生成——lstm生成诗

          本文链接:https://www.haomeiwen.com/subject/ulmxbltx.html