Rasa Core源码之Policy训练

作者: zqh_zy | 来源:发表于2017-12-23 17:15 被阅读784次
    context.png

    上下文的联系与理解是对话系统中重要的一块,直接影响与机器人对话的体验。最近接触了RASA系列,包括自然语言理解的rasa nlu和对话管理的rasa core。简单方便的实现一个任务对话系统的同时,也好奇其内部实现使用的技术。花时间读了Rasa Core关于上下文理解部分的源码,后面有机会再把rasa对话系统的其他模块的实现也做一个源码的分析。
    文章分为以下几部分:

    • Rasa Core的主要模块概念
    • 训练数据准备
    • 对话Policy模型训练和实现方法

    主要概念

    arch.png

    与对话系统的主要模块对应,如图 Rasa Core的实现也有相应的几个模块。从接受用户消息到机器人做出决策的流程大致如下:

    • 接受用户消息,送入Interpreter模块,识别并生成包含消息文本(text)、用户意图(intent)、和实体(entities)的字典。这里Interpreter对意图和实体的识别由上面提到的Rasa NLU实现,不是文章的主题,只要知道其功能即可。
    • Tracker 是对对话状态进行追踪(state tracker)的对象,它接受并记录Interpreter识别的新消息。
    • Policy接受当前的对话状态,选择响应哪一个Action。
    • 被选择的Action别记录在Tracker中,并返回响应给用户。

    上述流程是在Interpreter和Policy模型训练好的基础上对话系统的运行流程,下面主要针对Policy选择Action的模型训练部分的源码进行分析。该部分模型需要考虑历史对话对下一步响应进行决策,是整个对话系统的核心。

    训练数据

    训练Policy之前需要准备两个数据文件:

    • domain.yml : 包括对话系统所适用的领域,其中包括intents(意图集合)、slots(实体槽集合)、actions (机器人相应方式的集合)。
    • story.md:训练数据集合,这里的训练数据比不是原始的对话数据,而是原始的对话在domain中的映射。
      以官方的订餐馆的数据集为例:

    restaurant_domain.yml:

    slots:
      cuisine:
        type: text
      people:
        type: text
      location:
        type: text
      price:
        type: text
      info:
        type: text
      matches:
        type: list
    
    intents:
     - greet
     - affirm
     - deny
     - inform
     - thankyou
     - request_info
    
    entities:
     - location
     - info
     - people
     - price
     - cuisine
    
    templates:
      utter_greet:
        - "hey there!"
      utter_goodbye:
        - "goodbye :("
        - "Bye-bye"
      utter_default:
        - "default message"
      utter_ack_dosearch: 
        - "ok let me see what I can find"
    ...
    ...
    
    actions:
      - utter_greet
      - utter_goodbye
      - utter_default
      - utter_ack_dosearch
      - utter_ack_findalternatives
    ...
    ...
    
    

    babi_stories.md:

    ## story_03812903
    * greet # 用户打招呼
     - utter_ask_howcanhelp # 机器人响应需要什么帮助
    * inform{"location": "paris", "people": "six", "price": "cheap"} # 用户回复想订一下Paris便宜的六人桌
     - utter_on_it # 机器人回复好的
     - utter_ask_cuisine # 机器人继续询问要什么菜系
    * inform{"cuisine": "indian"} # 用户说印度菜
     - utter_ack_dosearch # 机器人回复稍等帮您查找
     - action_search_restaurants # 机器人查库返回结果
    ...
    ... # 省略
    * affirm # 用户确认 
     - utter_ack_makereservation # 机器人回复完成订单,询问手机号
    * request_info{"info": "phone"} # 用户告知手机号
     - action_suggest # 机器人其他推荐
    * thankyou # 用户感谢
     - utter_ask_helpmore # 机器人询问其他帮助
    

    story中对样例对话进行了简单的注释。

    模型训练

    准备好训练数据,下面是模型训练。拿官方的一个经典的KerasPolicy模型为例,该模型用Keras实现了一个简单的LSTM作为Policy模型:

    def model_architecture(self, num_features, num_actions, max_history_len):
            """Build a keras model and return a compiled model.
    
            :param max_history_len: The maximum number of historical
                                    turns used to decide on next action
            """
            from keras.layers import LSTM, Activation, Masking, Dense
            from keras.models import Sequential
    
            n_hidden = 32  # Neural Net and training params
            batch_shape = (None, max_history_len, num_features)
            # Build Model
            model = Sequential()
            model.add(Masking(-1, batch_input_shape=batch_shape))
            model.add(LSTM(n_hidden, batch_input_shape=batch_shape, dropout=0.2))
            model.add(Dense(input_dim=n_hidden, units=num_actions))
            model.add(Activation('softmax'))
    
            model.compile(loss='categorical_crossentropy',
                          optimizer='rmsprop',
                          metrics=['accuracy'])
    
            logger.debug(model.summary())
            return model
    

    模型通过历史对话记录作为输入训练数据,下一个决策Action作为label,进行模型训练。三个参数:

    • max_history_len: 记录的最大历史长度。
    • num_features: 每个记录的特征维度(intent、slot、action等的数目),包括了该记录的状态。
    • num_actions:候选响应数。

    模型本质上是num_actions个类别的多分类。下面详细分析对story.md的编码,生成可以直接输入到模型的训练数据(X,y)。

    状态追踪(state track)

    在搞清楚模型输入的训练数据是什么之前需要了解Rasa Core是如何实现状态追踪。训练阶段,rasa core读入Story,用track记录:在设置最大长度上下文为2时,一条训练数据的会有如下字典的表示:

    [
        {'entity_location': 1.0, 'entity_people': 1.0, 'entity_price': 1.0, 'slot_cuisine_0': 0.0, 'slot_info_0': 0.0, 'slot_location_0': 1.0, 'slot_matches_0': 0.0, 'slot_people_0': 1.0, 'slot_price_0': 1.0, 'intent_inform': 1.0, 'prev_utter_on_it': 1}, 
        {'entity_location': 1.0, 'entity_people': 1.0, 'entity_price': 1.0, 'slot_cuisine_0': 0.0, 'slot_info_0': 0.0, 'slot_location_0': 1.0, 'slot_matches_0': 0.0, 'slot_people_0': 1.0, 'slot_price_0': 1.0, 'intent_inform': 1.0, 'prev_utter_ask_cuisine': 1}, 
        {'entity_cuisine': 1.0, 'slot_cuisine_0': 1.0, 'slot_info_0': 0.0, 'slot_location_0': 1.0, 'slot_matches_0': 0.0, 'slot_people_0': 1.0, 'slot_price_0': 1.0, 'intent_inform': 1.0, 'prev_action_listen': 1}
    ]
    
    

    该部分状态对应上面训练数据的

     - utter_on_it # 机器人回复好的
     - utter_ask_cuisine # 机器人继续询问要什么菜系
    * inform{"cuisine": "indian"} # 用户说印度菜
    
    状态编码

    track列表中第一个字典表示utter_on_it后的状态,此时slot_location、slot_people、slot_price等的均已收集到在之前的对话中,对应value为1。第二个字典表示在utter_ask_cuisine后的状态,此时并没有获取到新的信息,而只是记录上一个机器响应prev_utter_ask_cuisine的value为1,表示该阶段状态;第三个字典表示当前状态,在获取新的cuisine信息后对应key的value置为1,同时上一个action为prev_action_listen表示监听。

    相应的,根据训练数据下一个机器应该采取的action为:

    - utter_ack_dosearch # 机器人回复稍等帮您查找
    

    如此得到一条训练数据(x,y), x经过编码,单条记录为一个二值向量,如果特征出现为1,否则为0,对应上面的第三个字典:

    {'entity_cuisine': 1.0, 'slot_cuisine_0': 1.0, 'slot_info_0': 0.0, 'slot_location_0': 1.0, 'slot_matches_0': 0.0, 'slot_people_0': 1.0, 'slot_price_0': 1.0, 'intent_inform': 1.0, 'prev_action_listen': 1}
    [0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0,0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    

    而对于最大历史信息记录为2的对应单条训练数据:

    [array([[0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0,0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0,0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=int32)]
    

    对应的y为5(utter_ack_dosearch的编号)。相同方法从Story中读取所有可能的数据对,去重和数据增强(打乱拼接),最终生成训练数据X,y。

    • X的维度为:(num_states, max_history, num_features)
    • y的维度为:num_states
    模型训练

    在准备好训练数据之后就可以对LSTM进行训练:

      def train(self, X, y, domain, **kwargs):
          self.model = self.model_architecture(domain.num_features,
                                               domain.num_actions,
                                               X.shape[1])
          y_one_hot = np.zeros((len(y), domain.num_actions))
          y_one_hot[np.arange(len(y)), y] = 1
    
          number_of_samples = X.shape[0]
          idx = np.arange(number_of_samples)
          np.random.shuffle(idx)
          shuffled_X = X[idx, :, :]
          shuffled_y = y_one_hot[idx, :]
    
          validation_split = kwargs.get("validation_split", 0.0)
          logger.info("Fitting model with {} total samples and a validation "
                      "split of {}".format(number_of_samples, validation_split))
          self.model.fit(shuffled_X, shuffled_y, **kwargs)
          self.current_epoch = kwargs.get("epochs", 10)
          logger.info("Done fitting keras policy model")
    

    和一般LSTM网络的训练方法一样,这里先对y进行one hot编码,shuffle训练集,之后进行训练。对于单个训练数据,对比文本的训练,一个状态相当于一个词,而最大上下文长度为2的单条训练数据可类比为2个词的句子。

    而在模型实用的预测阶段,一开始流程也有涉及,显然只要Tracker记录之前的聊天记录,每次拿当前决策的前两个消息作为模型输入,输出即为每个action的概率值,选择最大的响应即可。

    小结

    到此分析了Rasa Core的Policy训练方式,虽然Rasa Core的代码量并不算大,但这里并没有根据源码细节来看,而只是理清其训练方法。通过一个不错的对话系统的源码阅读,可以对对话管理的几个关键技术有进一步的理解,比如状态追踪、上下文理解以及没有讲的意图识别和实体识别。
    相比于高大上的论文的解决方案(如端到端、Memory Network进行上下文理解),Rasa Core显得更加简单可用,同样Rasa Core支持online learning还有点增强学习的意思,感兴趣的可以关注其github。

    原创文章转载注明出处

    相关文章

      网友评论

      • 鱼头三:感谢,终于理清楚core的逻辑了
      • c51b2c80fd64:zqh_zy你好,想借你这号召一下一块研究rasa的同志们,能不能一块加个群,群号758380083。加进来啊朋友们,共同学习,共同进步!!
      • c51b2c80fd64:你好,请问你现在还在跟进最新版本进行研究吗?想请教个问题,可不可以加个QQ(1073521013),希望可以共同学习共同进步。谢谢
      • zhangp365:请教下,这个RASA-NLU系统的潜力如何,如果用它做中文意图分类,最终能达到或接近商业系统的水准不?比如准确率和召回率能达到双95%以上不?
        zhangp365:@zqh_zy 好的,感谢回复。
        zqh_zy:@东海火鸡 指标应该还是看具体问题和数据,rasa只是一个算法模块的pipeline,可以更换任意环节的组件,比如训练算法模型
      • Jonathan丶Wei:你好,rasa core支持在线学习的连接能发我下么?最近也在看这个!
        Jonathan丶Wei:@zqh_zy 3Q,我去看看
        zqh_zy:https://core.rasa.ai/tutorial_interactive_learning.html

      本文标题:Rasa Core源码之Policy训练

      本文链接:https://www.haomeiwen.com/subject/gpyqgxtx.html