Machine Learning: Training a Chatbot with an RNN


Author: 西5d | Published 2019-03-02 15:01 | Read 22 times

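    Preface

    A previous post described an experiment in which a machine learning model wrote Tang poetry. This time we try something a little more complex: a chatbot. It is again based on a Tencent Cloud lab tutorial, with some parts modified; most of my time went into adapting the environment. I started on a local Mac, training on the CPU alone, which was slow. Later I found a better-equipped machine (6 cores, 12 threads), and things improved. In short, two things matter most in machine learning work. The first is the training data, including cleaning and normalizing the raw data; the models themselves do not differ much from one experiment to the next. The second is good hardware; resources were limited, so no GPU was used. In this experiment the local machine reached about 4,000 steps after roughly half a day of training and the bot was still just a parrot. On the better machine it reached about 300,000 steps in roughly one day and 700,000 in two days, by which point the loss had essentially stabilized (300,000 steps is enough).
    The local machine's configuration is listed below; unfortunately it was not up to the task.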

    MacBook Pro (13-inch, 2017, Four Thunderbolt 3 Ports)
    macOS 10.13.6 (17G65)
    16 GB 2133 MHz LPDDR3
    3.1 GHz Intel Core i5

    Notes
    I strongly recommend using virtualenv to set up Python; it is simple and does not affect the local runtime environment.
    TensorFlow also needs to be installed.
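
As a quick sanity check before training, the following sketch reports whether the interpreter is running inside a virtual environment. The helper name `in_virtualenv` is made up for this illustration; it relies only on `sys.prefix`, `sys.base_prefix`, and the `sys.real_prefix` attribute that older virtualenv releases set.

```python
import sys


def in_virtualenv():
    """Return True when the interpreter appears to run inside a virtualenv/venv."""
    # Older virtualenv releases set sys.real_prefix; venv and newer virtualenv
    # releases make sys.prefix differ from sys.base_prefix instead.
    if hasattr(sys, "real_prefix"):
        return True
    return sys.prefix != getattr(sys, "base_prefix", sys.prefix)


if __name__ == "__main__":
    print("python:", sys.version.split()[0])
    print("inside a virtual environment:", in_virtualenv())
```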

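    Steps

    Experiment overview

    1. First, clean and process the data: extract the ask and answer data, build the dictionary, and vectorize the text. For training data you can use the corpus from the lab link or other corpora available online, such as the Xiaohuangji (小黄鸡) corpus. Note that, according to material I had read earlier, a dictionary of about 3,000 commonly used Chinese characters is sufficient; these are the frequently occurring characters found in the corpus.

    2. Model training.
      This part uses the seq2seq code, with a few modifications of my own. I first downloaded the model provided by the lab, trained for about 300,000 steps, to chat with directly, but it kept failing with errors locally. In the end I trained my own model and saved the complete checkpoint files, so the program can start. As the figure shows, training ended at about 710,000 steps; in fact the loss had basically stopped changing at around 300,000. A better-curated corpus would probably give better results. A link at the end provides all the materials for direct download.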


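      Figure: the trained model
    3. Simulated conversation. This is the final result: after starting the local dependencies and loading the trained model, you can chat with it. See the figure for the results; some replies do match the questions in a question-and-answer fashion, though they are somewhat childish.


      Figure: a simulated conversation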
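
Step 1 mentions that a dictionary of roughly 3,000 common characters is enough. The sketch below shows one way such a frequency-capped dictionary could be built and used for vectorization. It is an illustration under assumptions, not the code from the experiment: `build_vocab`, `vectorize`, and the `max_chars` parameter are hypothetical names, and the toy corpus is invented.

```python
from collections import Counter

START_VOCAB = ["PAD", "GO", "EOS", "UNK"]  # same special tokens as generate.py
UNK_ID = 3


def build_vocab(sentences, max_chars=3000):
    """Return the special tokens followed by the most frequent characters."""
    counts = Counter(char for sentence in sentences for char in sentence)
    # most_common orders characters by descending frequency.
    common = [char for char, _ in counts.most_common(max_chars)]
    return START_VOCAB + common


def vectorize(sentence, word_to_id):
    """Map each character to its ID, falling back to UNK_ID."""
    return [word_to_id.get(char, UNK_ID) for char in sentence]


sentences = ["你好", "你好吗", "好的"]  # toy corpus, for illustration only
vocab = build_vocab(sentences, max_chars=2)
word_to_id = {word: i for i, word in enumerate(vocab)}
print(vocab)
print(vectorize("你好吗", word_to_id))
```

With `max_chars=2` only the two most frequent characters are kept, so the rarer character in the last call falls back to `UNK_ID`.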

    Code

    1. Data preparation and vectorization: generate.py
    # -*- coding:utf-8 -*-
    from io import open
    import random
    import tensorflow as tf
    
    # version tf 1.12 2018-12-08 22:22:08
    PAD = "PAD"
    GO = "GO"
    EOS = "EOS"
    UNK = "UNK"
    START_VOCAB = [PAD, GO, EOS, UNK]
    
    PAD_ID = 0  # padding
    GO_ID = 1  # start-of-sequence marker
    EOS_ID = 2  # end-of-sequence marker
    UNK_ID = 3  # unknown character
    _buckets = [(10, 15), (20, 25), (40, 50), (80, 100)]
    units_num = 256
    num_layers = 3
    max_gradient_norm = 5.0
    batch_size = 50
    learning_rate = 0.5
    learning_rate_decay_factor = 0.97
    
    train_encode_file = "data/train_encode"
    train_decode_file = "data/train_decode"
    test_encode_file = "data/test_encode"
    test_decode_file = "data/test_decode"
    vocab_encode_file = "data/vocab_encode"
    vocab_decode_file = "data/vocab_decode"
    train_encode_vec_file = "data/train_encode_vec"
    train_decode_vec_file = "data/train_decode_vec"
    test_encode_vec_file = "data/test_encode_vec"
    test_decode_vec_file = "data/test_decode_vec"
    
    
    def is_chinese(sentence):
        """Return True if the sentence has at least two characters, all of them
        CJK ideographs (U+4E00..U+9FFF) or one of the allowed punctuation marks."""
        if len(sentence) < 2:
            return False
        allowed_punctuation = {',', '。', '~', '?', '!'}
        return all(uchar in allowed_punctuation or '\u4e00' <= uchar <= '\u9fff'
                   for uchar in sentence)
    
    
    def get_chatbot():
        """Extract ask/answer pairs from data/chat.conv, split them 80/20 into
        train/test files, and write one vocabulary file per side."""
        encode = list()
        decode = list()
        # Sets mirroring the lists keep the duplicate checks fast.
        seen_encode = set()
        seen_decode = set()

        chat = list()
        print("start load source data...")
        step = 0
        with open("data/chat.conv", "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                if line[0] == "E":  # an "E" line separates conversations
                    if step % 1000 == 0:
                        print("step:%d" % step)
                    step += 1
                    if (len(chat) == 2 and is_chinese(chat[0]) and is_chinese(chat[1]) and
                            chat[0] not in seen_encode and chat[1] not in seen_decode):
                        encode.append(chat[0])
                        decode.append(chat[1])
                        seen_encode.add(chat[0])
                        seen_decode.add(chat[1])
                    chat = list()
                elif line[0] == "M":  # an "M" line holds one utterance
                    L = line.split(' ')
                    if len(L) > 1:
                        chat.append(L[1])
        encode_size = len(encode)
        if encode_size != len(decode):
            raise ValueError("encode size not equal to decode size")
        # Hold out a random 20% of the pairs as the test set.
        test_index = set(random.sample(range(encode_size), int(encode_size * 0.2)))
        print("divide source into two...")
        with open(train_encode_file, "w", encoding="utf-8") as train_encode, \
                open(train_decode_file, "w", encoding="utf-8") as train_decode, \
                open(test_encode_file, "w", encoding="utf-8") as test_encode, \
                open(test_decode_file, "w", encoding="utf-8") as test_decode:
            for i in range(encode_size):
                if i % 1000 == 0:
                    print("%d" % i)
                if i in test_index:
                    test_encode.write(encode[i] + "\n")
                    test_decode.write(decode[i] + "\n")
                else:
                    train_encode.write(encode[i] + "\n")
                    train_decode.write(decode[i] + "\n")

        # The special tokens go first so that their line numbers match
        # PAD_ID, GO_ID, EOS_ID and UNK_ID.
        for name, sentences, vocab_file in (("vocab_encode", encode, vocab_encode_file),
                                            ("vocab_decode", decode, vocab_decode_file)):
            print("get %s..." % name)
            with open(vocab_file, "w", encoding="utf-8") as vocab_f:
                for word in START_VOCAB + sorted(set(''.join(sentences))):
                    vocab_f.write(word + "\n")
    
    
    def gen_chatbot_vectors(input_file, vocab_file, output_file):
        """Convert each line of input_file into space-separated character IDs."""
        with open(vocab_file, "r", encoding="utf-8") as vocab_f:
            words = [word.strip() for word in vocab_f]
        word_to_id = {word: i for i, word in enumerate(words)}
        print("get %s vectors" % input_file)
        with open(input_file, "r", encoding="utf-8") as input_f, \
                open(output_file, "w", encoding="utf-8") as output_f:
            for step, line in enumerate(input_f):
                if step % 1000 == 0:
                    print("step:%d" % step)
                # Characters missing from the vocabulary map to UNK_ID.
                vec = [word_to_id.get(char, UNK_ID) for char in line.strip()]
                output_f.write(' '.join(str(n) for n in vec) + "\n")
    
    
    def get_vectors():
        gen_chatbot_vectors(train_encode_file, vocab_encode_file, train_encode_vec_file)
        gen_chatbot_vectors(train_decode_file, vocab_decode_file, train_decode_vec_file)
        gen_chatbot_vectors(test_encode_file, vocab_encode_file, test_encode_vec_file)
        gen_chatbot_vectors(test_decode_file, vocab_decode_file, test_decode_vec_file)
    
    
    def get_vocabs(vocab_file):
        words = list()
        with open(vocab_file, "r", encoding="utf-8") as vocab_f:
            for word in vocab_f:
                words.append(word.strip('\n').strip())
        id_to_word = {i: word for i, word in enumerate(words)}
        word_to_id = {v: k for k, v in id_to_word.items()}
        vocab_size = len(id_to_word)
        return id_to_word, word_to_id, vocab_size
    
    
    def read_data(source_path, target_path, max_size=None):
        data_set = [[] for _ in _buckets]
        with tf.gfile.GFile(source_path, mode="r") as source_file:
            with tf.gfile.GFile(target_path, mode="r") as target_file:
                source, target = source_file.readline(), target_file.readline()
                counter = 0
                while source and target and (not max_size or counter < max_size):
                    counter += 1
                    source_ids = [int(x) for x in source.split()]
                    target_ids = [int(x) for x in target.split()]
                    target_ids.append(EOS_ID)
                    for bucket_id, (source_size, target_size) in enumerate(_buckets):
                        if len(source_ids) < source_size and len(target_ids) < target_size:
                            data_set[bucket_id].append([source_ids, target_ids])
                            break
                    source, target = source_file.readline(), target_file.readline()
        return data_set
    
    
    # run
    # Extract the ask/answer data and build the vocabularies
    # get_chatbot()
    # Convert the training data to numeric IDs
    # get_vectors()
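
To make the bucketing in `read_data` concrete, here is a small standalone sketch of the same rule: a pair goes into the first bucket whose source and target limits are both strictly larger than the pair's lengths (after `EOS_ID` is appended to the target), and pairs that fit no bucket are dropped. The helper name `assign_bucket` and the sample ID lists are hypothetical.

```python
BUCKETS = [(10, 15), (20, 25), (40, 50), (80, 100)]  # same values as _buckets
EOS_ID = 2


def assign_bucket(source_ids, target_ids, buckets=BUCKETS):
    """Return the index of the smallest bucket that fits the pair, or None."""
    target_ids = target_ids + [EOS_ID]  # the target always ends with EOS
    for bucket_id, (source_size, target_size) in enumerate(buckets):
        if len(source_ids) < source_size and len(target_ids) < target_size:
            return bucket_id
    return None  # too long for every bucket; read_data skips such pairs


print(assign_bucket([5] * 3, [6] * 4))    # short pair
print(assign_bucket([5] * 10, [6] * 4))   # source length equals the first limit
print(assign_bucket([5] * 80, [6]))       # source too long for any bucket
```

Because the comparison is strict, a source of exactly 10 IDs does not fit the `(10, 15)` bucket and moves up to `(20, 25)`.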
    
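    2. Model

    Jianshu's length limit prevents publishing these two files in the post, so they are only available from the link at the end:
    seq2seq.py
    seq2seq_model.py

    3. Training module
      train_chat.py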
    # -*- coding:utf-8 -*-
    import generate as generate_chat
    import seq2seq_model as seq2seq_model
    import tensorflow as tf
    import numpy as np
    import logging
    import logging.handlers
    
    if __name__ == '__main__':
    
        _, _, source_vocab_size = generate_chat.get_vocabs(generate_chat.vocab_encode_file)
        _, _, target_vocab_size = generate_chat.get_vocabs(generate_chat.vocab_decode_file)
        train_set = generate_chat.read_data(generate_chat.train_encode_vec_file, generate_chat.train_decode_vec_file)
        test_set = generate_chat.read_data(generate_chat.test_encode_vec_file, generate_chat.test_decode_vec_file)
        train_bucket_sizes = [len(train_set[i]) for i in range(len(generate_chat._buckets))]
        train_total_size = float(sum(train_bucket_sizes))
        train_buckets_scale = [sum(train_bucket_sizes[:i + 1]) / train_total_size for i in range(len(train_bucket_sizes))]
        cpu_config = tf.ConfigProto(intra_op_parallelism_threads=6,inter_op_parallelism_threads=6,device_count={'CPU':6})
        with tf.Session(config=cpu_config) as sess:
            model = seq2seq_model.Seq2SeqModel(source_vocab_size,
                                               target_vocab_size,
                                               generate_chat._buckets,
                                               generate_chat.units_num,
                                               generate_chat.num_layers,
                                               generate_chat.max_gradient_norm,
                                               generate_chat.batch_size,
                                               generate_chat.learning_rate,
                                               generate_chat.learning_rate_decay_factor,
                                               use_lstm=True)
    
            ckpt = tf.train.get_checkpoint_state('./mytrain')
    
            if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
                print("Reading model parameters from %s" % ckpt.model_checkpoint_path)
                model.saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                print("Created model with fresh parameters.")
                sess.run(tf.global_variables_initializer())
            loss = 0.0
            step = 0
            previous_losses = []
            while True:
                random_number_01 = np.random.random_sample()
                bucket_id = min([i for i in range(len(train_buckets_scale)) if train_buckets_scale[i] > random_number_01])
                encoder_inputs, decoder_inputs, target_weights = model.get_batch(train_set, bucket_id)
                _, step_loss, _ = model.step(sess, encoder_inputs, decoder_inputs, target_weights, bucket_id, False)
                print("step:%d,loss:%f" % (step, step_loss))
                # Average over the 1000 steps between checkpoints.
                loss += step_loss / 1000
                step += 1
                if step % 1000 == 0:
                    print("step:%d,per_loss:%f" % (step, loss))
                    if len(previous_losses) > 2 and loss > max(previous_losses[-3:]):
                        sess.run(model.learning_rate_decay_op)
                    previous_losses.append(loss)
                    model.saver.save(sess, "mytrain/chatbot.ckpt", global_step=model.global_step)
                    loss = 0.0
                if step % 5000 == 0:
                    # Evaluate one batch from every non-empty test bucket.
                    for bucket_id in range(len(generate_chat._buckets)):
                        if len(test_set[bucket_id]) == 0:
                            continue
                        encoder_inputs, decoder_inputs, target_weights = model.get_batch(test_set, bucket_id)
                        _, eval_loss, _ = model.step(sess, encoder_inputs, decoder_inputs, target_weights, bucket_id,
                                                     True)
                        print("bucket_id:%d,eval_loss:%f" % (bucket_id, eval_loss))
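
The training loop picks a bucket for each step using `train_buckets_scale`, a list of cumulative shares, so that each bucket is chosen with probability proportional to the number of pairs it holds. The standalone sketch below reproduces that selection with the standard library only; the function names and the bucket sizes are hypothetical, and `random.random()` stands in for `np.random.random_sample()`.

```python
import random


def bucket_scale(bucket_sizes):
    """Cumulative share of the training pairs held by buckets 0..i."""
    total = float(sum(bucket_sizes))
    return [sum(bucket_sizes[:i + 1]) / total for i in range(len(bucket_sizes))]


def pick_bucket(scale, r):
    """Return the first bucket whose cumulative share exceeds r (0 <= r < 1)."""
    return min(i for i in range(len(scale)) if scale[i] > r)


sizes = [50, 30, 20]  # hypothetical number of pairs per bucket
scale = bucket_scale(sizes)
picks = [pick_bucket(scale, random.random()) for _ in range(10000)]
print(scale)
# The empirical frequencies should be close to the shares 0.5, 0.3 and 0.2.
print([picks.count(i) / len(picks) for i in range(len(sizes))])
```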
    
    
    4. Chat module
      chat.py
    # -*- coding:utf-8 -*-
    import generate as generate_chat
    import seq2seq_model as seq2seq_model
    import tensorflow as tf
    import numpy as np
    import sys
    
    if __name__ == '__main__':
        source_id_to_word, source_word_to_id, source_vocab_size = generate_chat.get_vocabs(generate_chat.vocab_encode_file)
        target_id_to_word, target_word_to_id, target_vocab_size = generate_chat.get_vocabs(generate_chat.vocab_decode_file)
        to_id = lambda word: source_word_to_id.get(word, generate_chat.UNK_ID)
        cpu_config = tf.ConfigProto(intra_op_parallelism_threads=6,inter_op_parallelism_threads=6,device_count={'CPU':6})
        with tf.Session(config=cpu_config) as sess:
            model = seq2seq_model.Seq2SeqModel(source_vocab_size,
                                               target_vocab_size,
                                               generate_chat._buckets,
                                               generate_chat.units_num,
                                               generate_chat.num_layers,
                                               generate_chat.max_gradient_norm,
                                               1,
                                               generate_chat.learning_rate,
                                               generate_chat.learning_rate_decay_factor,
                                               forward_only=True,
                                               use_lstm=True)
            #model.saver.restore(sess, "model/chatbot.ckpt-317000")
            model.saver.restore(sess, "mytrain/chatbot.ckpt-717000")
            while True:
                sys.stdout.write("ask > ")
                sys.stdout.flush()
                sentence = sys.stdin.readline().strip('\n')
                flag = generate_chat.is_chinese(sentence)
                if not sentence or not flag:
                    print("Please enter Chinese text only")
                    continue
                sentence_vec = list(map(to_id, sentence))
                bucket_id = len(generate_chat._buckets) - 1
                if len(sentence_vec) > generate_chat._buckets[bucket_id][0]:
                    print("sentence too long max:%d" % generate_chat._buckets[bucket_id][0])
                    exit(0)
                for i, bucket in enumerate(generate_chat._buckets):
                    if bucket[0] >= len(sentence_vec):
                        bucket_id = i
                        break
                encoder_inputs, decoder_inputs, target_weights = model.get_batch({bucket_id: [(sentence_vec, [])]},
                                                                                 bucket_id)
                _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs, target_weights, bucket_id, True)
                outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
                if generate_chat.EOS_ID in outputs:
                    outputs = outputs[:outputs.index(generate_chat.EOS_ID)]
                answer = "".join([tf.compat.as_str(target_id_to_word[output]) for output in outputs])
                print("answer > " + answer)
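
The last few lines of chat.py turn the model output into text: take the highest-scoring ID at each output position, cut the sequence at the first `EOS_ID`, and join the corresponding characters. The sketch below isolates that step with plain Python lists instead of NumPy arrays. The name `greedy_decode`, the toy vocabulary, and the scores are invented for illustration.

```python
EOS_ID = 2


def greedy_decode(output_logits, id_to_word, eos_id=EOS_ID):
    """Pick the best-scoring ID per position and stop at the first EOS."""
    ids = [max(range(len(row)), key=row.__getitem__) for row in output_logits]
    if eos_id in ids:
        ids = ids[:ids.index(eos_id)]
    return "".join(id_to_word[i] for i in ids)


# Toy vocabulary and scores, one row of scores per output position.
id_to_word = {0: "PAD", 1: "GO", 2: "EOS", 3: "UNK", 4: "你", 5: "好"}
logits = [
    [0.0, 0.0, 0.0, 0.0, 0.9, 0.1],  # best ID: 4
    [0.0, 0.0, 0.0, 0.0, 0.2, 0.7],  # best ID: 5
    [0.0, 0.0, 0.8, 0.0, 0.1, 0.1],  # best ID: 2 (EOS), decoding stops here
    [0.0, 0.0, 0.0, 0.0, 0.9, 0.0],  # ignored, it comes after EOS
]
print(greedy_decode(logits, id_to_word))
```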
    
    

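    Note
    In both train_chat.py and chat.py the tf.Session configuration was changed to limit the number of CPUs used. On Ubuntu, without this limit, TF takes all CPU resources and the system freezes. Set the actual values according to your CPU core count.
    The code is as follows:

    cpu_config = tf.ConfigProto(intra_op_parallelism_threads=6, inter_op_parallelism_threads=6, device_count={'CPU': 6})
    with tf.Session(config=cpu_config) as sess:
        ...  # build and run the model here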
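
Instead of hard-coding 6, the thread counts could be derived from the machine. The following is only a sketch of one possible policy, which is to leave one logical CPU free for the rest of the system; the helper `cpu_thread_settings` and its `reserve` parameter are assumptions of this example, not part of the experiment. Note that `os.cpu_count()` counts logical CPUs, so on a 6-core, 12-thread machine it reports 12.

```python
import os


def cpu_thread_settings(reserve=1):
    """Keyword arguments for tf.ConfigProto that leave `reserve` logical CPUs free."""
    logical_cpus = os.cpu_count() or 1  # cpu_count() may return None
    threads = max(1, logical_cpus - reserve)
    return {
        "intra_op_parallelism_threads": threads,
        "inter_op_parallelism_threads": threads,
        "device_count": {"CPU": threads},
    }


settings = cpu_thread_settings()
print(settings)
# With TensorFlow 1.x installed, these could be passed on as:
#   cpu_config = tf.ConfigProto(**settings)
```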
    

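    Conclusion

    Thanks for reading. Finally, here are the link to the original lab and all the resources from my own training. The experiment ran correctly on macOS with TF 1.12.0 and Python 3.6.7, and on Ubuntu with TF 1.12.0 and Python 3.5; again, I recommend working in a virtualenv.
    Lab link (may no longer work after some time): https://cloud.tencent.com/developer/labs/lab/10406
    Local experiment resources: https://iss.igosh.com/share/201903/tencent-me.tar.gz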
