Rasa Core源码之Policy训练

作者: zqh_zy | 来源:发表于2017-12-23 17:15 被阅读784次
context.png

上下文的联系与理解是对话系统中重要的一块,直接影响与机器人对话的体验。最近接触了RASA系列,包括自然语言理解的rasa nlu和对话管理的rasa core。简单方便的实现一个任务对话系统的同时,也好奇其内部实现使用的技术。花时间读了Rasa Core关于上下文理解部分的源码,后面有机会再把rasa对话系统的其他模块的实现也做一个源码的分析。
文章分为以下几部分:

  • Rasa Core的主要模块概念
  • 训练数据准备
  • 对话Policy模型训练和实现方法

主要概念

arch.png

与对话系统的主要模块对应,如图 Rasa Core的实现也有相应的几个模块。从接受用户消息到机器人做出决策的流程大致如下:

  • 接受用户消息,送入Interpreter模块,识别并生成包含消息文本(text)、用户意图(intent)、和实体(entities)的字典。这里Interpreter对意图和实体的识别由上面提到的Rasa NLU实现,不是文章的主题,只要知道其功能即可。
  • Tracker 是对对话状态进行追踪(state tracker)的对象,它接受并记录Interpreter识别的新消息。
  • Policy接受当前的对话状态,选择响应哪一个Action。
  • 被选择的Action别记录在Tracker中,并返回响应给用户。

上述流程是在Interpreter和Policy模型训练好的基础上对话系统的运行流程,下面主要针对Policy选择Action的模型训练部分的源码进行分析。该部分模型需要考虑历史对话对下一步响应进行决策,是整个对话系统的核心。

训练数据

训练Policy之前需要准备两个数据文件:

  • domain.yml : 包括对话系统所适用的领域,其中包括intents(意图集合)、slots(实体槽集合)、actions (机器人相应方式的集合)。
  • story.md:训练数据集合,这里的训练数据比不是原始的对话数据,而是原始的对话在domain中的映射。
    以官方的订餐馆的数据集为例:

restaurant_domain.yml:

slots:
  cuisine:
    type: text
  people:
    type: text
  location:
    type: text
  price:
    type: text
  info:
    type: text
  matches:
    type: list

intents:
 - greet
 - affirm
 - deny
 - inform
 - thankyou
 - request_info

entities:
 - location
 - info
 - people
 - price
 - cuisine

templates:
  utter_greet:
    - "hey there!"
  utter_goodbye:
    - "goodbye :("
    - "Bye-bye"
  utter_default:
    - "default message"
  utter_ack_dosearch: 
    - "ok let me see what I can find"
...
...

actions:
  - utter_greet
  - utter_goodbye
  - utter_default
  - utter_ack_dosearch
  - utter_ack_findalternatives
...
...

babi_stories.md:

## story_03812903
* greet # 用户打招呼
 - utter_ask_howcanhelp # 机器人响应需要什么帮助
* inform{"location": "paris", "people": "six", "price": "cheap"} # 用户回复想订一下Paris便宜的六人桌
 - utter_on_it # 机器人回复好的
 - utter_ask_cuisine # 机器人继续询问要什么菜系
* inform{"cuisine": "indian"} # 用户说印度菜
 - utter_ack_dosearch # 机器人回复稍等帮您查找
 - action_search_restaurants # 机器人查库返回结果
...
... # 省略
* affirm # 用户确认 
 - utter_ack_makereservation # 机器人回复完成订单,询问手机号
* request_info{"info": "phone"} # 用户告知手机号
 - action_suggest # 机器人其他推荐
* thankyou # 用户感谢
 - utter_ask_helpmore # 机器人询问其他帮助

story中对样例对话进行了简单的注释。

模型训练

准备好训练数据,下面是模型训练。拿官方的一个经典的KerasPolicy模型为例,该模型用Keras实现了一个简单的LSTM作为Policy模型:

def model_architecture(self, num_features, num_actions, max_history_len):
        """Build a keras model and return a compiled model.

        :param max_history_len: The maximum number of historical
                                turns used to decide on next action
        """
        from keras.layers import LSTM, Activation, Masking, Dense
        from keras.models import Sequential

        n_hidden = 32  # Neural Net and training params
        batch_shape = (None, max_history_len, num_features)
        # Build Model
        model = Sequential()
        model.add(Masking(-1, batch_input_shape=batch_shape))
        model.add(LSTM(n_hidden, batch_input_shape=batch_shape, dropout=0.2))
        model.add(Dense(input_dim=n_hidden, units=num_actions))
        model.add(Activation('softmax'))

        model.compile(loss='categorical_crossentropy',
                      optimizer='rmsprop',
                      metrics=['accuracy'])

        logger.debug(model.summary())
        return model

模型通过历史对话记录作为输入训练数据,下一个决策Action作为label,进行模型训练。三个参数:

  • max_history_len: 记录的最大历史长度。
  • num_features: 每个记录的特征维度(intent、slot、action等的数目),包括了该记录的状态。
  • num_actions:候选响应数。

模型本质上是num_actions个类别的多分类。下面详细分析对story.md的编码,生成可以直接输入到模型的训练数据(X,y)。

状态追踪(state track)

在搞清楚模型输入的训练数据是什么之前需要了解Rasa Core是如何实现状态追踪。训练阶段,rasa core读入Story,用track记录:在设置最大长度上下文为2时,一条训练数据的会有如下字典的表示:

[
    {'entity_location': 1.0, 'entity_people': 1.0, 'entity_price': 1.0, 'slot_cuisine_0': 0.0, 'slot_info_0': 0.0, 'slot_location_0': 1.0, 'slot_matches_0': 0.0, 'slot_people_0': 1.0, 'slot_price_0': 1.0, 'intent_inform': 1.0, 'prev_utter_on_it': 1}, 
    {'entity_location': 1.0, 'entity_people': 1.0, 'entity_price': 1.0, 'slot_cuisine_0': 0.0, 'slot_info_0': 0.0, 'slot_location_0': 1.0, 'slot_matches_0': 0.0, 'slot_people_0': 1.0, 'slot_price_0': 1.0, 'intent_inform': 1.0, 'prev_utter_ask_cuisine': 1}, 
    {'entity_cuisine': 1.0, 'slot_cuisine_0': 1.0, 'slot_info_0': 0.0, 'slot_location_0': 1.0, 'slot_matches_0': 0.0, 'slot_people_0': 1.0, 'slot_price_0': 1.0, 'intent_inform': 1.0, 'prev_action_listen': 1}
]

该部分状态对应上面训练数据的

 - utter_on_it # 机器人回复好的
 - utter_ask_cuisine # 机器人继续询问要什么菜系
* inform{"cuisine": "indian"} # 用户说印度菜
状态编码

track列表中第一个字典表示utter_on_it后的状态,此时slot_location、slot_people、slot_price等的均已收集到在之前的对话中,对应value为1。第二个字典表示在utter_ask_cuisine后的状态,此时并没有获取到新的信息,而只是记录上一个机器响应prev_utter_ask_cuisine的value为1,表示该阶段状态;第三个字典表示当前状态,在获取新的cuisine信息后对应key的value置为1,同时上一个action为prev_action_listen表示监听。

相应的,根据训练数据下一个机器应该采取的action为:

- utter_ack_dosearch # 机器人回复稍等帮您查找

如此得到一条训练数据(x,y), x经过编码,单条记录为一个二值向量,如果特征出现为1,否则为0,对应上面的第三个字典:

{'entity_cuisine': 1.0, 'slot_cuisine_0': 1.0, 'slot_info_0': 0.0, 'slot_location_0': 1.0, 'slot_matches_0': 0.0, 'slot_people_0': 1.0, 'slot_price_0': 1.0, 'intent_inform': 1.0, 'prev_action_listen': 1}
[0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0,0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],

而对于最大历史信息记录为2的对应单条训练数据:

[array([[0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0,0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0,0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=int32)]

对应的y为5(utter_ack_dosearch的编号)。相同方法从Story中读取所有可能的数据对,去重和数据增强(打乱拼接),最终生成训练数据X,y。

  • X的维度为:(num_states, max_history, num_features)
  • y的维度为:num_states
模型训练

在准备好训练数据之后就可以对LSTM进行训练:

  def train(self, X, y, domain, **kwargs):
      self.model = self.model_architecture(domain.num_features,
                                           domain.num_actions,
                                           X.shape[1])
      y_one_hot = np.zeros((len(y), domain.num_actions))
      y_one_hot[np.arange(len(y)), y] = 1

      number_of_samples = X.shape[0]
      idx = np.arange(number_of_samples)
      np.random.shuffle(idx)
      shuffled_X = X[idx, :, :]
      shuffled_y = y_one_hot[idx, :]

      validation_split = kwargs.get("validation_split", 0.0)
      logger.info("Fitting model with {} total samples and a validation "
                  "split of {}".format(number_of_samples, validation_split))
      self.model.fit(shuffled_X, shuffled_y, **kwargs)
      self.current_epoch = kwargs.get("epochs", 10)
      logger.info("Done fitting keras policy model")

和一般LSTM网络的训练方法一样,这里先对y进行one hot编码,shuffle训练集,之后进行训练。对于单个训练数据,对比文本的训练,一个状态相当于一个词,而最大上下文长度为2的单条训练数据可类比为2个词的句子。

而在模型实用的预测阶段,一开始流程也有涉及,显然只要Tracker记录之前的聊天记录,每次拿当前决策的前两个消息作为模型输入,输出即为每个action的概率值,选择最大的响应即可。

小结

到此分析了Rasa Core的Policy训练方式,虽然Rasa Core的代码量并不算大,但这里并没有根据源码细节来看,而只是理清其训练方法。通过一个不错的对话系统的源码阅读,可以对对话管理的几个关键技术有进一步的理解,比如状态追踪、上下文理解以及没有讲的意图识别和实体识别。
相比于高大上的论文的解决方案(如端到端、Memory Network进行上下文理解),Rasa Core显得更加简单可用,同样Rasa Core支持online learning还有点增强学习的意思,感兴趣的可以关注其github。

原创文章转载注明出处

相关文章

网友评论

  • 鱼头三:感谢,终于理清楚core的逻辑了
  • c51b2c80fd64:zqh_zy你好,想借你这号召一下一块研究rasa的同志们,能不能一块加个群,群号758380083。加进来啊朋友们,共同学习,共同进步!!
  • c51b2c80fd64:你好,请问你现在还在跟进最新版本进行研究吗?想请教个问题,可不可以加个QQ(1073521013),希望可以共同学习共同进步。谢谢
  • zhangp365:请教下,这个RASA-NLU系统的潜力如何,如果用它做中文意图分类,最终能达到或接近商业系统的水准不?比如准确率和召回率能达到双95%以上不?
    zhangp365:@zqh_zy 好的,感谢回复。
    zqh_zy:@东海火鸡 指标应该还是看具体问题和数据,rasa只是一个算法模块的pipeline,可以更换任意环节的组件,比如训练算法模型
  • Jonathan丶Wei:你好,rasa core支持在线学习的连接能发我下么?最近也在看这个!
    Jonathan丶Wei:@zqh_zy 3Q,我去看看
    zqh_zy:https://core.rasa.ai/tutorial_interactive_learning.html

本文标题:Rasa Core源码之Policy训练

本文链接:https://www.haomeiwen.com/subject/gpyqgxtx.html