美文网首页FastAPI 解读 by Gascognya
FastAPI 依赖注入详解:生成依赖树

FastAPI 依赖注入详解:生成依赖树

作者: Gascognya | 来源:发表于2020-10-25 16:37 被阅读0次
    class APIRoute(routing.Route):
        def __init__(...):
    
            ......
    
            self.dependant = get_dependant(path=self.path_format, call=self.endpoint)
            for depends in self.dependencies[::-1]:
                self.dependant.dependencies.insert(
                    0,
                    get_parameterless_sub_dependant(depends=depends, path=self.path_format),
                )
    
            ......
    
    

    在添加APIRoute节点时,会对endpoint进行解析,生成 依赖树get_dependant便是解析出endpoint的依赖树的函数。

    这部分在之前源码解析中讲过,但是当时的理解并不深刻。这次让我们来认真剖析这部分

    def get_dependant(
        *,
        path: str,
        call: Callable,
        name: Optional[str] = None,
        security_scopes: Optional[List[str]] = None,
        use_cache: bool = True,
    ) -> Dependant:
        """
        * 该函数为递归函数, 不止会被endpoint调用, 也会被其依赖项调用。
    
        :param path: 路径
        :param call: endpoint/依赖项
        :param name: 被依赖项使用, 为参数名
        :param security_scopes: 被依赖项使用, 为积攒的安全域
        :param use_cache: 缓存
        :return: Dependant对象
        """
        path_param_names = get_path_param_names(path)
        # 捕捉路径参数 e.g. "/user/{id}"
    
        endpoint_signature = get_typed_signature(call)
        signature_params = endpoint_signature.parameters
        # 解析endpoint/依赖项的参数, 通过inspect
    
        if is_gen_callable(call) or is_async_gen_callable(call):
            check_dependency_contextmanagers()
        # 确保异步上下文管理器import成功
    
        dependant = Dependant(call=call, name=name, path=path, use_cache=use_cache)
        # 依赖对象
    
        # 生成依赖树
        for param_name, param in signature_params.items():
    
            if isinstance(param.default, params.Depends):
                # 如果该参数是Depends()时 (因为其写在默认值位置)
                sub_dependant = get_param_sub_dependant(
                    param=param, path=path, security_scopes=security_scopes
                )
                # 生成一个子依赖项
                dependant.dependencies.append(sub_dependant)
                # 加入到父依赖项的节点中
                continue
    
            if add_non_field_param_to_dependency(param=param, dependant=dependant):
                continue
            # 找出Request, WebSocket, HTTPConnection, Response, BackgroundTasks, SecurityScopes等参数。
            # 将其参数名, 在dependant中标注出来
    
            # 既不是Depends依赖项, 也不是特殊参数
            # 就当做普通参数来看待
            param_field = get_param_field(
                param=param, default_field_info=params.Query, param_name=param_name
            )
            # 参数默认当做Query, 获取其ModelField
    
            if param_name in path_param_names:
                # 如果这个参数名在上文解析路径得到的路径参数集合中
                # e.g. "/user/{id}" -> {id, ...} -> param_name = "id"
    
                assert is_scalar_field(
                    field=param_field
                ), "Path params must be of one of the supported types"
                # 判断是否为标准的field类型
    
                if isinstance(param.default, params.Path):
                    ignore_default = False
                else:
                    ignore_default = True
                # path_param = Path(), 设置为不忽略默认, 确保有效
    
                param_field = get_param_field(
                    param=param,
                    param_name=param_name,
                    default_field_info=params.Path,
                    force_type=params.ParamTypes.path,
                    ignore_default=ignore_default,
                )
                # 重新按Path生成参数字段, 获得ModelField
                add_param_to_fields(field=param_field, dependant=dependant)
                # 整合到dependant的path参数列表
    
            elif is_scalar_field(field=param_field):
                # 如果并非path参数, 即默认query参数, 但属于标准field类型
                # 注: cookie属于这类
                add_param_to_fields(field=param_field, dependant=dependant)
                # 整合到dependant的query参数列表
    
            elif isinstance(
                param.default, (params.Query, params.Header)
            ) and is_scalar_sequence_field(param_field):
                # 如果不是path, 也不是标准的query, 但属于包含有Query()或Header()
                # 且为标准序列参数时
                add_param_to_fields(field=param_field, dependant=dependant)
                # 整合到dependant的query或header参数列表
    
            else:
                field_info = param_field.field_info
                assert isinstance(
                    field_info, params.Body
                ), f"Param: {param_field.name} can only be a request body, using Body(...)"
                # 上述条件都不满足, 即不是路径参数、标准查询参数、Query查询参数、Header参数中任何一个
                # 则断言一定是Body参数
                dependant.body_params.append(param_field)
                # 将其整合到Body参数列表
        return dependant
    

    分步解读

    def get_dependant(
        *,
        path: str,
        call: Callable,
        name: Optional[str] = None,
        security_scopes: Optional[List[str]] = None,
        use_cache: bool = True,
    ) -> Dependant:
        path_param_names = get_path_param_names(path)
        # 捕捉路径参数 e.g. "/user/{id}"
    
        endpoint_signature = get_typed_signature(call)
        signature_params = endpoint_signature.parameters
        # 解析endpoint/依赖项的参数, 通过inspect
    
        if is_gen_callable(call) or is_async_gen_callable(call):
            check_dependency_contextmanagers()
        # 确保异步上下文管理器import成功
    
        dependant = Dependant(call=call, name=name, path=path, use_cache=use_cache)
        # 依赖对象
    

    get_dependant不止被endpoint使用,其依赖项和子依赖都会使用,其为递归函数。
    开头生成一个Dependant节点对象,等待下面加工,最终被返回。其形成的是一个树状结构

        for param_name, param in signature_params.items():
    

    接下来把该节点的参数都抓出来,逐个分析。

            if isinstance(param.default, params.Depends):
                # 如果该参数是Depends()时 (因为其写在默认值位置)
                sub_dependant = get_param_sub_dependant(
                    param=param, path=path, security_scopes=security_scopes
                )
                # 生成一个子依赖项
                dependant.dependencies.append(sub_dependant)
                # 加入到父依赖项的节点中
                continue
    

    首先判断是否为Depends()项,如果是,则生成子依赖。下面是生成子依赖的流程。

    def get_param_sub_dependant(
        *, param: inspect.Parameter, path: str, security_scopes: Optional[List[str]] = None
    ) -> Dependant:
        depends: params.Depends = param.default
        # 拿到Depends对象
    
        if depends.dependency:
            dependency = depends.dependency
        else:
            dependency = param.annotation
        # 拿到函数/类, 没有则默认为注解类。
        # 这代表着user: User = Depends() 是被允许的
    
        return get_sub_dependant(
            depends=depends,
            dependency=dependency,
            path=path,
            name=param.name,
            security_scopes=security_scopes,
        )
    

    拿出Depends中的依赖内容,如果没有就用注解来充当。即user: User = Depends()这种形式可以被允许。

    def get_sub_dependant(
        *,
        depends: params.Depends,
        dependency: Callable,
        path: str,
        name: Optional[str] = None,
        security_scopes: Optional[List[str]] = None,
    ) -> Dependant:
        """
        :param depends: 依赖项对象
        :param dependency: 具体依赖内容
        :param path: 路径
        :param name: 参数名
        :param security_scopes: 安全域
        :return:
        """
        security_requirement = None
        # 安全性要求, 先置为None
        security_scopes = security_scopes or []
        # 安全域
    
        if isinstance(depends, params.Security):
            # 判断是否为"安全依赖"
            # 注: Security是Depends的子类
            dependency_scopes = depends.scopes
            security_scopes.extend(dependency_scopes)
            # 将依赖项的安全域整合进来
    
        if isinstance(dependency, SecurityBase):
            # 如果依赖内容是安全认证 e.g. Depends(oauth2_scheme)
            # 注: OAuth2是SecurityBase的子类
    
            use_scopes: List[str] = []
            if isinstance(dependency, (OAuth2, OpenIdConnect)):
                # 注: OAuth2PasswordBearer, OAuth2AuthorizationCodeBearer
                # 两者为OAuth2子类
                use_scopes = security_scopes
                # 如果其为上述两者实例, 则将积攒的安全域, 传入其中。
    
            security_requirement = SecurityRequirement(
                security_scheme=dependency, scopes=use_scopes
            )
            # 安全性需求置为, SecurityRequirement(SecurityBase, [])
            # 或者 SecurityRequirement(OAuth2, security_scopes)
    
        # 上文两个判断组合起来的逻辑是
        # 1. 第一个判断, 将后置依赖中的安全域需求整合起来
        # 2. 当扫描到了前置的OAuth2时, 将这些积攒的安全域需求传入其中
    
        sub_dependant = get_dependant(
            path=path,
            call=dependency,
            name=name,
            security_scopes=security_scopes,
            use_cache=depends.use_cache,
        )
        # 以这个依赖项作为根节点, 继续生产依赖树
    
        if security_requirement:
            sub_dependant.security_requirements.append(security_requirement)
        # 将SecurityRequirement放进这个依赖项中
        # 注意SecurityRequirement存在条件是本依赖项为SecurityBase相关
    
        sub_dependant.security_scopes = security_scopes
        # 将现有的安全域需求放进这个依赖项中
    
        return sub_dependant
    

    接下来是对安全相关的处理。我们可以看到,中间又调用了get_dependant,参数包含了namesecurity_scopes。endpoint的根节点传参不包含这两项。

    回到get_dependant
            if add_non_field_param_to_dependency(param=param, dependant=dependant):
                continue
            # 找出Request, WebSocket, HTTPConnection, Response, BackgroundTasks, SecurityScopes等参数。
            # 将其参数名, 在dependant中标注出来
    
            # 既不是Depends依赖项, 也不是特殊参数
            # 就当做普通参数来看待
            param_field = get_param_field(
                param=param, default_field_info=params.Query, param_name=param_name
            )
            # 参数默认当做Query, 获取其ModelField
    

    如果不是Depends参数,则首先默认当成查询参数query,并生成ModelField字段。

            if param_name in path_param_names:
                # 如果这个参数名在上文解析路径得到的路径参数集合中
                # e.g. "/user/{id}" -> {id, ...} -> param_name = "id"
    
                assert is_scalar_field(
                    field=param_field
                ), "Path params must be of one of the supported types"
                # 判断是否为标准的field类型
    
                if isinstance(param.default, params.Path):
                    ignore_default = False
                else:
                    ignore_default = True
                # path_param = Path(), 设置为不忽略默认, 确保有效
    
                param_field = get_param_field(
                    param=param,
                    param_name=param_name,
                    default_field_info=params.Path,
                    force_type=params.ParamTypes.path,
                    ignore_default=ignore_default,
                )
                # 重新按Path生成参数字段, 获得ModelField
                add_param_to_fields(field=param_field, dependant=dependant)
                # 整合到dependant的path参数列表
    

    如果其为路径参数,则重新生成ModelField字段。再整合到dependant的参数列表中

            elif is_scalar_field(field=param_field):
                # 如果并非path参数, 即默认query参数, 但属于标准field类型
                # 注: cookie属于这类
                add_param_to_fields(field=param_field, dependant=dependant)
                # 整合到dependant的query参数列表
    

    不是路径参数,但是标准的查询参数

            elif isinstance(
                param.default, (params.Query, params.Header)
            ) and is_scalar_sequence_field(param_field):
                # 如果不是path, 也不是标准的query, 但属于包含有Query()或Header()
                # 且为标准序列参数时
                add_param_to_fields(field=param_field, dependant=dependant)
                # 整合到dependant的query或header参数列表
    

    Query()和Header()两种情况

            else:
                field_info = param_field.field_info
                assert isinstance(
                    field_info, params.Body
                ), f"Param: {param_field.name} can only be a request body, using Body(...)"
                # 上述条件都不满足, 即不是路径参数、标准查询参数、Query查询参数、Header参数中任何一个
                # 则断言一定是Body参数
                dependant.body_params.append(param_field)
                # 将其整合到Body参数列表
        return dependant
    

    当上述条件都不满足,则可以断言为Body()字段。

    就此,一个APIRoute的依赖树便生成了
    下章说说如何使用依赖树

    相关文章

      网友评论

        本文标题:FastAPI 依赖注入详解:生成依赖树

        本文链接:https://www.haomeiwen.com/subject/ozqgmktx.html