rasa_core: A Walkthrough of the nlg Module Source Code

Author: 是风车大渣渣啊 | Published: 2020-02-07 00:00

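    I have recently been learning to build chatbots with rasa, and to implement a somewhat unusual feature I needed to understand the source code. The rasa code base is of high quality: it is thoroughly commented and its function definitions carry type hints, which makes it pleasant to read.
    The rasa_core.nlg module (imported as rasa.core.nlg in the code below) contains five Python files: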

    • __init__.py
    • callback.py
    • generator.py
    • interpolator.py
    • template.py

    Start with __init__.py:

    from rasa.core.nlg.generator import NaturalLanguageGenerator
    from rasa.core.nlg.template import TemplatedNaturalLanguageGenerator
    from rasa.core.nlg.callback import CallbackNaturalLanguageGenerator
    

    As the imports show, the nlg module is built around three classes:

    • NaturalLanguageGenerator(NLG)
    • TemplatedNaturalLanguageGenerator(TNLG)
    • CallbackNaturalLanguageGenerator(CNLG)

    TNLG and CNLG both inherit from NLG, so we start with NLG.

    NaturalLanguageGenerator

    The NLG class has two methods:

    • generate
    • create

    generate is an abstract method with no concrete implementation; create is a static method.

    generate
    async def generate(
        self,
        template_name: Text,
        tracker: "DialogueStateTracker",
        output_channel: Text,
        **kwargs: Any,
    ) -> Optional[Dict[Text, Any]]:
        # No implementation in the base class; subclasses override this.
        raise NotImplementedError
    

    An asynchronous abstract method that produces the reply to the user's input.
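
To make the "async and abstract" part concrete, here is a minimal sketch of how a subclass could override generate and how the coroutine is driven. BaseGenerator, EchoGenerator and run_generate are hypothetical stand-ins written for this illustration, not rasa classes, and the tracker is simply passed as None.

```python
import asyncio
from typing import Any, Dict, Optional, Text


class BaseGenerator:
    """Hypothetical stand-in for NaturalLanguageGenerator."""

    async def generate(
        self, template_name: Text, tracker: Any, output_channel: Text, **kwargs: Any
    ) -> Optional[Dict[Text, Any]]:
        # The base class offers no behaviour of its own.
        raise NotImplementedError


class EchoGenerator(BaseGenerator):
    """Toy subclass: answers with the channel and the template name."""

    async def generate(
        self, template_name: Text, tracker: Any, output_channel: Text, **kwargs: Any
    ) -> Optional[Dict[Text, Any]]:
        return {"text": f"[{output_channel}] {template_name}"}


def run_generate(generator: BaseGenerator, template_name: Text) -> Optional[Dict[Text, Any]]:
    # generate is a coroutine, so it must be awaited or driven by an event loop.
    return asyncio.run(
        generator.generate(template_name, tracker=None, output_channel="cli")
    )


result = run_generate(EchoGenerator(), "utter_greet")
```

Because generate is declared with async def, a caller inside rasa would await it rather than call asyncio.run; the helper above only exists so the sketch can run on its own.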

    create
    @staticmethod
    def create(
        obj: Union["NaturalLanguageGenerator", EndpointConfig, None],
        domain: Optional[Domain],
    ) -> "NaturalLanguageGenerator":
        """Factory to create a generator."""
    
        if isinstance(obj, NaturalLanguageGenerator):
            return obj
        else:
            return _create_from_endpoint_config(obj, domain)
    

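    A static method that produces an NLG instance. obj is expected to be either an NLG instance or an EndpointConfig object (per the type hint it may also be None), and domain is a Domain object. If obj is already an NLG instance it is returned as is; otherwise _create_from_endpoint_config builds one from the EndpointConfig and the Domain.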

    _create_from_endpoint_config

    Next, let's look at _create_from_endpoint_config.

    def _create_from_endpoint_config(
        endpoint_config: Optional[EndpointConfig] = None, domain: Optional[Domain] = None,
    ) -> "NaturalLanguageGenerator":
        """Given an endpoint configuration, create a proper NLG object."""
    
        domain = domain or Domain.empty()
    
        if endpoint_config is None:
            from rasa.core.nlg import (  # pytype: disable=pyi-error
                TemplatedNaturalLanguageGenerator,
            )
    
            # this is the default type if no endpoint config is set
            nlg = TemplatedNaturalLanguageGenerator(domain.templates)
        elif endpoint_config.type is None or endpoint_config.type.lower() == "callback":
            from rasa.core.nlg import (  # pytype: disable=pyi-error
                CallbackNaturalLanguageGenerator,
            )
    
            # this is the default type if no nlg type is set
            nlg = CallbackNaturalLanguageGenerator(endpoint_config=endpoint_config)
        elif endpoint_config.type.lower() == "template":
            from rasa.core.nlg import (  # pytype: disable=pyi-error
                TemplatedNaturalLanguageGenerator,
            )
    
            nlg = TemplatedNaturalLanguageGenerator(domain.templates)
        else:
            nlg = _load_from_module_string(endpoint_config, domain)
    
        logger.debug(f"Instantiated NLG to '{nlg.__class__.__name__}'.")
        return nlg
    

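    _create_from_endpoint_config takes the same EndpointConfig and Domain objects. Its body is an if-else chain that picks the NLG class from the state of the EndpointConfig: with no endpoint config it builds a TNLG from domain.templates; with an endpoint config whose type is unset or "callback" it builds a CNLG; with type "template" it again builds a TNLG; any other type is treated as a module path and handed to _load_from_module_string.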
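
The branch order matters, so here is a small sketch that mirrors it without importing rasa. FakeEndpointConfig and pick_generator_kind are hypothetical names made up for this illustration; the function returns a label instead of constructing a real generator.

```python
from typing import Optional


class FakeEndpointConfig:
    """Hypothetical stand-in for EndpointConfig; only `type` matters here."""

    def __init__(self, type: Optional[str] = None) -> None:
        self.type = type


def pick_generator_kind(endpoint_config: Optional[FakeEndpointConfig]) -> str:
    """Mirror the branch order of _create_from_endpoint_config."""
    if endpoint_config is None:
        return "template"  # nothing configured: templates from the domain
    if endpoint_config.type is None or endpoint_config.type.lower() == "callback":
        return "callback"  # endpoint given without a type: remote NLG service
    if endpoint_config.type.lower() == "template":
        return "template"
    return "custom"  # anything else is treated as a module path


kinds = [
    pick_generator_kind(None),
    pick_generator_kind(FakeEndpointConfig()),
    pick_generator_kind(FakeEndpointConfig("Callback")),
    pick_generator_kind(FakeEndpointConfig("template")),
    pick_generator_kind(FakeEndpointConfig("my_package.MyGenerator")),
]
```

Note the asymmetry: a missing endpoint config falls back to templates, while an endpoint config with a missing type falls back to the callback generator.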

    _load_from_module_string
    def _load_from_module_string(
        endpoint_config: EndpointConfig, domain: Domain
    ) -> "NaturalLanguageGenerator":
        """Initializes a custom natural language generator.
    
        Args:
            domain: defines the universe in which the assistant operates
            endpoint_config: the specific natural language generator
        """
    
        try:
            nlg_class = common.class_from_module_path(endpoint_config.type)
            return nlg_class(endpoint_config=endpoint_config, domain=domain)
        except (AttributeError, ImportError) as e:
            raise Exception(
                f"Could not find a class based on the module path "
                f"'{endpoint_config.type}'. Failed to create a "
                f"`NaturalLanguageGenerator` instance. Error: {e}"
            )
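
_load_from_module_string relies on common.class_from_module_path to turn a dotted path into a class. The sketch below shows one way such a lookup can be written with importlib; class_from_path is a hypothetical helper for illustration and is not rasa's implementation.

```python
import importlib
from typing import Any


def class_from_path(module_path: str) -> Any:
    """Resolve a dotted path such as 'collections.OrderedDict' to the class object."""
    module_name, _, class_name = module_path.rpartition(".")
    if not module_name:
        raise ImportError(f"'{module_path}' is not a dotted module path.")
    # Raises ImportError (ModuleNotFoundError) if the module does not exist.
    module = importlib.import_module(module_name)
    # Raises AttributeError if the module has no such attribute.
    return getattr(module, class_name)


ordered_dict_class = class_from_path("collections.OrderedDict")
```

The two failure modes of this lookup, a missing module and a missing attribute, correspond to the (AttributeError, ImportError) pair caught in _load_from_module_string.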
    

    TemplatedNaturalLanguageGenerator

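    TNLG inherits from NLG. On top of what it inherits, it defines the following methods (generate overrides the abstract one):

    • _templates_for_utter_action
    • _random_template_for
    • generate
    • generate_from_slots
    • _fill_template
    • _template_variables

    Start with the most important one, generate.

    generate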
    async def generate(
        self,
        template_name: Text,
        tracker: DialogueStateTracker,
        output_channel: Text,
        **kwargs: Any,
    ) -> Optional[Dict[Text, Any]]:
        """Generate a response for the requested template."""
    
        filled_slots = tracker.current_slot_values()
        return self.generate_from_slots(
            template_name, filled_slots, output_channel, **kwargs
        )
    

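    The inputs are the template name and the tracker object; the reply is produced by filling the template with the slot values the tracker has recorded. The actual generation is delegated to generate_from_slots.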

    generate_from_slots
    def generate_from_slots(
        self,
        template_name: Text,
        filled_slots: Dict[Text, Any],
        output_channel: Text,
        **kwargs: Any,
    ) -> Optional[Dict[Text, Any]]:
        """Generate a response for the requested template."""
    
        # Fetching a random template for the passed template name
        r = copy.deepcopy(self._random_template_for(template_name, output_channel))
        # Filling the slots in the template and returning the template
        if r is not None:
            return self._fill_template(r, filled_slots, **kwargs)
        else:
            return None
    

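    Here _random_template_for picks a template at random (one action can have several response templates), and _fill_template then fills in its slots. The chosen template is deep-copied first, so filling it does not modify the stored one.
    Let's look at _random_template_for first.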

    _random_template_for
    def _random_template_for(
        self, utter_action: Text, output_channel: Text
    ) -> Optional[Dict[Text, Any]]:
        """Select random template for the utter action from available ones.
    
        If channel-specific templates for the current output channel are given,
        only choose from channel-specific ones.
        """
        import numpy as np
    
        if utter_action in self.templates:
            suitable_templates = self._templates_for_utter_action(
                utter_action, output_channel
            )
    
            if suitable_templates:
                return np.random.choice(suitable_templates)
            else:
                return None
        else:
            return None
    

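    It calls _templates_for_utter_action to get the templates that suit the current action and output channel, then uses np.random.choice to pick one from that list. The input is the action name, and the returned template is a dict.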
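
_templates_for_utter_action is not shown above, so the following is only a sketch of the behaviour its docstring describes: prefer channel-specific templates, otherwise use the generic ones. TEMPLATES, templates_for and random_template_for are hypothetical names, the "channel" key is an assumption about how a template marks its channel, and random.choice from the standard library is swapped in for np.random.choice.

```python
import random
from typing import Any, Dict, List, Optional

# Made-up templates: two generic ones and one tied to a channel.
TEMPLATES: Dict[str, List[Dict[str, Any]]] = {
    "utter_greet": [
        {"text": "Hello!"},
        {"text": "Hi there!"},
        {"text": "Hello from Slack!", "channel": "slack"},
    ]
}


def templates_for(utter_action: str, output_channel: str) -> List[Dict[str, Any]]:
    """Prefer channel-specific templates; otherwise fall back to generic ones."""
    candidates = TEMPLATES.get(utter_action, [])
    channel_specific = [t for t in candidates if t.get("channel") == output_channel]
    if channel_specific:
        return channel_specific
    return [t for t in candidates if t.get("channel") is None]


def random_template_for(utter_action: str, output_channel: str) -> Optional[Dict[str, Any]]:
    suitable = templates_for(utter_action, output_channel)
    # An unknown action or an empty candidate list yields None, as in the original.
    return random.choice(suitable) if suitable else None


slack_reply = random_template_for("utter_greet", "slack")
cli_reply = random_template_for("utter_greet", "cli")
missing = random_template_for("utter_unknown", "cli")
```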

    _fill_template

    _fill_template fills the slots of the selected template.

    def _fill_template(
        self,
        template: Dict[Text, Any],
        filled_slots: Optional[Dict[Text, Any]] = None,
        **kwargs: Any,
    ) -> Dict[Text, Any]:
        """Combine slot values and key word arguments to fill templates."""
    
        # Getting the slot values in the template variables
        template_vars = self._template_variables(filled_slots, kwargs)
    
        keys_to_interpolate = [
            "text",
            "image",
            "custom",
            "button",
            "attachment",
            "quick_replies",
        ]
        if template_vars:
            for key in keys_to_interpolate:
                if key in template:
                    template[key] = interpolate(template[key], template_vars)
        return template
    

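    Both the input template and filled_slots are dicts. I have not seen a concrete example yet, so this is a guess:
    each key in filled_slots is the name of a slot placeholder in the template, each value is what gets substituted for it, and replacing the placeholders in the template yields the reply.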
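
A small example may help check that guess. The sketch below is a simplified stand-in for _fill_template: fill_template is a hypothetical function, it only handles string values under "text" and "image", and it uses plain str.format rather than rasa's interpolate. That keyword arguments override slot values is an assumption here, based on the docstring saying the two are combined.

```python
from typing import Any, Dict, Optional


def fill_template(
    template: Dict[str, Any],
    filled_slots: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
) -> Dict[str, Any]:
    """Simplified stand-in for _fill_template that only interpolates strings."""
    # Assumption: explicit keyword arguments take precedence over slot values.
    template_vars = dict(filled_slots or {})
    template_vars.update(kwargs)

    filled = dict(template)  # shallow copy so the stored template stays untouched
    for key in ("text", "image"):
        if isinstance(filled.get(key), str):
            filled[key] = filled[key].format(**template_vars)
    return filled


template = {"text": "Hello {name}, your table for {people} is booked."}
reply = fill_template(template, {"name": "Alice", "people": 2})
override = fill_template(template, {"name": "Alice", "people": 2}, name="Bob")
```

Under these assumptions, slot names map one-to-one onto the {placeholder} names in the template text, which matches the guess above.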

    interpolator.py

    The reply generation in TNLG relies on two functions from interpolator.py, interpolate and interpolate_text. interpolate_text fills the slots of a text template by combining a regular-expression substitution with str.format():

    def interpolate_text(template: Text, values: Dict[Text, Text]) -> Text:
        # transforming template tags from
        # "{tag_name}" to "{0[tag_name]}"
        # as described here:
        # https://stackoverflow.com/questions/7934620/python-dots-in-the-name-of-variable-in-a-format-string#comment9695339_7934969
        # black list character and make sure to not to allow
        # (a) newline in slot name
        # (b) { or } in slot name
        try:
            text = re.sub(r"{([^\n{}]+?)}", r"{0[\1]}", template)
            text = text.format(values)
            if "0[" in text:
                # regex replaced tag but format did not replace
                # likely cause would be that tag name was enclosed
                # in double curly and format func simply escaped it.
                # we don't want to return {0[SLOTNAME]} thus
                # restoring original value with { being escaped.
                return template.format({})
    
            return text
        except KeyError as e:
            logger.exception(
                "Failed to fill utterance template '{}'. "
                "Tried to replace '{}' but could not find "
                "a value for it. There is no slot with this "
                "name nor did you pass the value explicitly "
                "when calling the template. Return template "
                "without filling the template. "
                "".format(template, e.args[0])
            )
            return template
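
The tag rewrite is the least obvious step, so here it is in isolation. The regular expression and the "{0[...]}" replacement are taken from interpolate_text; the fill helper and the sample values are made up for illustration, and the error handling and double-brace check are left out.

```python
import re
from typing import Dict

TAG = re.compile(r"{([^\n{}]+?)}")


def fill(template: str, values: Dict[str, str]) -> str:
    # "{name}" becomes "{0[name]}", so format() looks the key up in its
    # first positional argument instead of parsing the name itself.
    rewritten = TAG.sub(r"{0[\1]}", template)
    return rewritten.format(values)


rewritten_example = TAG.sub(r"{0[\1]}", "Hello {name}!")
plain = fill("Hello {name}!", {"name": "Alice"})
# A dot in the slot name would be parsed as attribute access by a plain
# "{order.id}".format(...), but works as a dictionary key after the rewrite.
dotted = fill("Order {order.id} shipped.", {"order.id": "A-17"})
```

A slot name that is missing from values raises KeyError in this sketch; that is the case interpolate_text catches in order to log the problem and return the unfilled template.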
    

    CallbackNaturalLanguageGenerator

    Finally, CNLG. Its structure is much simpler: it has just two methods, generate, which produces the reply, and validate_response, which checks that the reply has a valid format.

    generate
    async def generate(
        self,
        template_name: Text,
        tracker: DialogueStateTracker,
        output_channel: Text,
        **kwargs: Any,
    ) -> Dict[Text, Any]:
        """Retrieve a named template from the domain using an endpoint."""
    
        body = nlg_request_format(template_name, tracker, output_channel, **kwargs)
    
        logger.debug(
            "Requesting NLG for {} from {}."
            "".format(template_name, self.nlg_endpoint.url)
        )
    
        response = await self.nlg_endpoint.request(
            method="post", json=body, timeout=DEFAULT_REQUEST_TIMEOUT
        )
    
        if self.validate_response(response):
            return response
        else:
            raise Exception("NLG web endpoint returned an invalid response.")
    

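    The inputs are the name of the action (the template name), the tracker that records the conversation, and output_channel. The method first gets the request body from nlg_request_format, then POSTs it to the service behind the NLG endpoint, which decides what to reply. The response is validated and returned; an invalid response raises an exception.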

    nlg_request_format
    def nlg_request_format(
        template_name: Text,
        tracker: DialogueStateTracker,
        output_channel: Text,
        **kwargs: Any,
    ) -> Dict[Text, Any]:
        """Create the json body for the NLG json body for the request."""
    
        tracker_state = tracker.current_state(EventVerbosity.ALL)
    
        return {
            "template": template_name,
            "arguments": kwargs,
            "tracker": tracker_state,
            "channel": {"name": output_channel},
        }
    

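    This function builds the body of the request: the template name, the extra arguments, the tracker state and the channel name. When writing Actions I had been curious about this: an Action's run function is usually defined as def run(self, dispatcher, tracker, domain), and I later found that the tracker passed in there is not a rasa_core.trackers tracker and carries less information. The same idea shows up here: what gets sent is not the DialogueStateTracker object itself but the dict returned by tracker.current_state(EventVerbosity.ALL), a serialized snapshot of the state.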
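
To show what the service on the other end has to work with, here is a sketch of a request body in the shape nlg_request_format returns, together with a hypothetical handler. The values in example_body, the handle_nlg_request function and the reply format (a dict with a "text" key) are assumptions for illustration; the real tracker state contains many more fields.

```python
from typing import Any, Dict

# Shape produced by nlg_request_format; the values are made up for illustration.
example_body: Dict[str, Any] = {
    "template": "utter_greet",
    "arguments": {"name": "Alice"},
    "tracker": {"sender_id": "user-1", "slots": {"city": "Berlin"}},
    "channel": {"name": "rest"},
}


def handle_nlg_request(body: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical server-side handler: build a reply from the posted body."""
    slots = body["tracker"].get("slots", {})
    name = body["arguments"].get("name", "there")
    if body["template"] == "utter_greet":
        city = slots.get("city", "somewhere")
        return {"text": f"Hello {name}, greetings from {city}!"}
    return {"text": "Sorry, I have no reply for that."}


reply = handle_nlg_request(example_body)
```

Since the tracker arrives as plain JSON, the handler reads slots with ordinary dict access rather than through tracker methods.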
