FastAPI Source Code Reading (4): Endpoint Parsing


Author: Gascognya | Published: 2020-09-02 13:29

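When routing finds a matching APIRoute, it calls that route's app. The app is where the model and parameter validation happens, which is what we covered in the previous chapter.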

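In that chapter we configured an endpoint's parameter validation and its response model. Now the parameters must be validated before the endpoint runs, and the response model must be applied after it returns.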

get_request_handler

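This function defines app, a closure around the endpoint. Before the endpoint is called, app resolves dependencies and validates parameters. After it returns, app handles the response model.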
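Before the real code, here is a framework-free sketch of the closure's shape. It assumes plain dicts in place of Request and Response objects. `make_handler`, `validate`, `serialize` and `ValidationFailed` are hypothetical names, not FastAPI APIs. `asyncio.to_thread` (Python 3.9+) stands in for Starlette's `run_in_threadpool`. The factory captures the endpoint and its configuration once. The returned `app` then validates before the call and serializes after it.

```python
import asyncio
import inspect
from typing import Any, Callable, Dict, List, Tuple


class ValidationFailed(Exception):
    """Hypothetical stand-in for FastAPI's RequestValidationError."""

    def __init__(self, errors: List[str]) -> None:
        super().__init__(errors)
        self.errors = errors


def make_handler(
    endpoint: Callable[..., Any],
    validate: Callable[[Dict[str, Any]], Tuple[Dict[str, Any], List[str]]],
    serialize: Callable[[Any], Any],
    status_code: int = 200,
) -> Callable[[Dict[str, Any]], Any]:
    # Configuration is captured once, when the route is built
    # (get_request_handler likewise computes is_coroutine up front)
    is_coroutine = inspect.iscoroutinefunction(endpoint)

    async def app(request: Dict[str, Any]) -> Dict[str, Any]:
        # Before the endpoint: turn raw request data into validated arguments
        values, errors = validate(request)
        if errors:
            raise ValidationFailed(errors)
        # Call the endpoint; a sync endpoint runs in a worker thread
        if is_coroutine:
            raw = await endpoint(**values)
        else:
            raw = await asyncio.to_thread(endpoint, **values)
        # After the endpoint: serialize the return value into the response body
        return {"status_code": status_code, "content": serialize(raw)}

    return app


if __name__ == "__main__":
    def validate_item(request: Dict[str, Any]) -> Tuple[Dict[str, Any], List[str]]:
        try:
            return {"item_id": int(request["item_id"])}, []
        except (KeyError, ValueError):
            return {}, ["item_id: value is not a valid integer"]

    def read_item(item_id: int) -> Dict[str, Any]:
        return {"item_id": item_id}

    handler = make_handler(read_item, validate_item, serialize=dict)
    print(asyncio.run(handler({"item_id": "42"})))
    # {'status_code': 200, 'content': {'item_id': 42}}
```

The point of the closure is that per-route settings such as the status code and the coroutine check are fixed once, while `app` itself only deals with one request at a time.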

def get_request_handler(
        dependant: Dependant,
        body_field: Optional[ModelField] = None,
        status_code: int = 200,
        response_class: Type[Response] = JSONResponse,
        response_field: Optional[ModelField] = None,
        response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
        response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
        response_model_by_alias: bool = True,
        response_model_exclude_unset: bool = False,
        response_model_exclude_defaults: bool = False,
        response_model_exclude_none: bool = False,
        dependency_overrides_provider: Optional[Any] = None,
) -> Callable:
    assert dependant.call is not None, "dependant.call must be a function"
    is_coroutine = asyncio.iscoroutinefunction(dependant.call)
    is_body_form = body_field and isinstance(body_field.field_info, params.Form)

    # The app through which the route calls the endpoint. It does much more than
    # Starlette's version: it resolves dependencies before the call and wraps the response after it.

    async def app(request: Request) -> Response:
        try:
            body = None
            if body_field:
                # Read the body: form data or JSON, depending on the declared body field
                if is_body_form:
                    body = await request.form()
                else:
                    body_bytes = await request.body()
                    if body_bytes:
                        body = await request.json()
        except json.JSONDecodeError as e:
            raise RequestValidationError([ErrorWrapper(e, ("body", e.pos))], body=e.doc)
        except Exception as e:
            raise HTTPException(
                status_code=400, detail="There was an error parsing the body"
            ) from e
        # Resolve dependencies and validate the request parameters
        solved_result = await solve_dependencies(
            request=request,
            dependant=dependant,
            body=body,
            dependency_overrides_provider=dependency_overrides_provider,
        )
        # Unpack the result; the last item (the dependency cache) is not needed here
        values, errors, background_tasks, sub_response, _ = solved_result
        # Any errors mean a parameter or dependency failed validation
        if errors:
            raise RequestValidationError(errors, body=body)
        else:
            # Call the endpoint and get its raw return value
            raw_response = await run_endpoint_function(
                dependant=dependant, values=values, is_coroutine=is_coroutine
            )

            if isinstance(raw_response, Response):
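                # The endpoint returned a Response itself (e.g. return Response(...)):
                # attach the background tasks if it has none, and return it directly
                if raw_response.background is None:
                    raw_response.background = background_tasks
                return raw_response
            # Any other return value (e.g. a dict) is serialized;
            # it is validated and filtered only when response_model is set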
            response_data = await serialize_response(
                field=response_field,
                response_content=raw_response,
                include=response_model_include,
                exclude=response_model_exclude,
                by_alias=response_model_by_alias,
                exclude_unset=response_model_exclude_unset,
                exclude_defaults=response_model_exclude_defaults,
                exclude_none=response_model_exclude_none,
                is_coroutine=is_coroutine,
            )
            # Build the response_class instance (JSONResponse by default) from the serialized data
            response = response_class(
                content=response_data,
                status_code=status_code,
                background=background_tasks,
            )
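            # Merge in what dependencies set on the injected Response:
            # its headers are appended, and its status code, if set, overrides the default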
            response.headers.raw.extend(sub_response.headers.raw)
            if sub_response.status_code:
                response.status_code = sub_response.status_code
            return response

    return app
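The body-reading branch at the top of `app` can be isolated as a plain function. This sketch simplifies under several assumptions. It takes the raw bytes plus flags for whether a body field exists and whether it is a form (in FastAPI the latter comes from `isinstance(body_field.field_info, params.Form)`). It parses only urlencoded forms, with `urllib.parse.parse_qs` instead of Starlette's `request.form()`. `read_body`, `BodyValidationError` and `BadRequest` are hypothetical names for the function and the 422 and 400 outcomes.

```python
import json
from typing import Any, Optional
from urllib.parse import parse_qs


class BodyValidationError(Exception):
    """Hypothetical stand-in for RequestValidationError (rendered as 422)."""


class BadRequest(Exception):
    """Hypothetical stand-in for HTTPException(status_code=400)."""


def read_body(body_bytes: bytes, has_body_field: bool, is_body_form: bool) -> Optional[Any]:
    # No body parameter declared: the body is never read
    if not has_body_field:
        return None
    try:
        if is_body_form:
            # urlencoded form only; keep the last value of each key
            return {k: v[-1] for k, v in parse_qs(body_bytes.decode()).items()}
        # An empty body stays None instead of being treated as invalid JSON
        if not body_bytes:
            return None
        return json.loads(body_bytes)
    except json.JSONDecodeError as e:
        # Malformed JSON becomes a validation error located at ("body", position)
        raise BodyValidationError([{"loc": ("body", e.pos), "msg": e.msg}]) from e
    except Exception as e:
        # Anything else (e.g. bytes that are not valid UTF-8) becomes a plain 400
        raise BadRequest("There was an error parsing the body") from e
```

The ordering of the two `except` clauses matters, just as in the original. `JSONDecodeError` is a subclass of `ValueError`, so it must be caught before the generic handler, or malformed JSON would be reported as a 400 instead of a validation error.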
solve_dependencies() & serialize_response()

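solve_dependencies() resolves the Depends() dependencies and validates the parameter types. It collects every failure as an error; app raises RequestValidationError if there are any, so only a request that passes all checks reaches the endpoint.
serialize_response() serializes the endpoint's return value, validating and filtering it against the response model when one is set. If the endpoint directly returns an instance of a Response subclass, app skips this step.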
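The skip just described, together with how `app` assembles the final response, can be sketched with a small stand-in for Starlette's `Response`. `SimpleResponse` and `finalize` are hypothetical names. Headers are modeled as a list of `(name, value)` pairs, like Starlette's `headers.raw`.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, List, Optional, Tuple


@dataclass
class SimpleResponse:
    # Hypothetical stand-in for starlette.responses.Response
    content: Any = None
    status_code: Optional[int] = None
    headers: List[Tuple[str, str]] = field(default_factory=list)
    background: Optional[Any] = None


def finalize(
    raw: Any,
    serialize: Callable[[Any], Any],
    sub_response: SimpleResponse,
    background: Optional[Any],
    status_code: int = 200,
) -> SimpleResponse:
    if isinstance(raw, SimpleResponse):
        # The endpoint built its own response: only attach background tasks if it has none
        if raw.background is None:
            raw.background = background
        return raw
    response = SimpleResponse(
        content=serialize(raw), status_code=status_code, background=background
    )
    # Headers set by dependencies on the injected Response are appended
    response.headers.extend(sub_response.headers)
    # A status code set by a dependency overrides the route's default
    if sub_response.status_code:
        response.status_code = sub_response.status_code
    return response
```

As in the excerpt above, a Response returned by the endpoint goes back untouched apart from background tasks. Headers or a status code that dependencies set on sub_response are only merged into responses that app builds itself.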

async def solve_dependencies(
    *,
    request: Union[Request, WebSocket],
    dependant: Dependant,
    body: Optional[Union[Dict[str, Any], FormData]] = None,
    background_tasks: Optional[BackgroundTasks] = None,
    response: Optional[Response] = None,
    dependency_overrides_provider: Optional[Any] = None,
    dependency_cache: Optional[Dict[Tuple[Callable, Tuple[str]], Any]] = None,
) -> Tuple[
    Dict[str, Any],
    List[ErrorWrapper],
    Optional[BackgroundTasks],
    Response,
    Dict[Tuple[Callable, Tuple[str]], Any],
]:
    """
    处理endpoint的依赖函数,进行类型验证
    """
    values: Dict[str, Any] = {}
    errors: List[ErrorWrapper] = []
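    # Create an empty placeholder Response unless one was passed in;
    # dependencies can set headers or a status code on it
    response = response or Response(
        content=None,
        status_code=None,  # type: ignore
        headers=None,
        media_type=None,
        background=None,
    )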
    dependency_cache = dependency_cache or {}
    sub_dependant: Dependant
    # Process each sub-dependency in turn
    for sub_dependant in dependant.dependencies:
        # cast() is a no-op at runtime; these lines only narrow the types for mypy
        sub_dependant.call = cast(Callable, sub_dependant.call)
        sub_dependant.cache_key = cast(
            Tuple[Callable, Tuple[str]], sub_dependant.cache_key
        )
        # The dependency callable itself
        call = sub_dependant.call
        use_sub_dependant = sub_dependant
        if (
            dependency_overrides_provider
            and dependency_overrides_provider.dependency_overrides
        ):
            # dependency_overrides_provider is the app instance itself,
            # whose dependency_overrides dict is empty by default
            original_call = sub_dependant.call
            call = getattr(
                dependency_overrides_provider, "dependency_overrides", {}
            ).get(original_call, original_call)
            use_path: str = sub_dependant.path  # type: ignore
            use_sub_dependant = get_dependant(
                path=use_path,
                call=call,
                name=sub_dependant.name,
                security_scopes=sub_dependant.security_scopes,
            )
            use_sub_dependant.security_scopes = sub_dependant.security_scopes

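        # Recurse until a sub-dependency has no dependencies of its own.
        # This is a depth-first traversal of the Dependant tree,
        # with each sub-dependency passed in as the new root.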
        solved_result = await solve_dependencies(
            request=request,
            dependant=use_sub_dependant,
            body=body,
            background_tasks=background_tasks,
            response=response,
            dependency_overrides_provider=dependency_overrides_provider,
            dependency_cache=dependency_cache,
        )
        (
            sub_values,
            sub_errors,
            background_tasks,
            _,  # the sub-dependency returns the same Response object we passed in
            sub_dependency_cache,
        ) = solved_result
        # The sub-dependency's results come back here

        # Merge its cache into ours; if it had errors, record them and skip calling it
        dependency_cache.update(sub_dependency_cache)
        if sub_errors:
            errors.extend(sub_errors)
            continue

        # If this dependency already ran during this request and caching is on, reuse its result
        if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache:
            solved = dependency_cache[sub_dependant.cache_key]
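        # Otherwise call it: generators through the request's exit stack,
        # coroutines directly, plain functions in a threadpool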
        elif is_gen_callable(call) or is_async_gen_callable(call):
            stack = request.scope.get("fastapi_astack")
            if stack is None:
                raise RuntimeError(
                    async_contextmanager_dependencies_error
                )  # pragma: no cover
            # sub_values (the results of this dependency's own sub-dependencies)
            # are passed in as its arguments
            solved = await solve_generator(
                call=call, stack=stack, sub_values=sub_values
            )
        elif is_coroutine_callable(call):
            solved = await call(**sub_values)
        else:
            solved = await run_in_threadpool(call, **sub_values)
        if sub_dependant.name is not None:
            values[sub_dependant.name] = solved
        if sub_dependant.cache_key not in dependency_cache:
            dependency_cache[sub_dependant.cache_key] = solved
            # Record the result as {cache_key: result}, where cache_key is (call, security scopes)

    # A dependant with no sub-dependencies skips the loop and arrives here directly

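    # Match the request's parameters against what this dependant declares.
    # Out of everything the request carries, each dependency picks only
    # the path, query, header and cookie parameters it asked for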
    path_values, path_errors = request_params_to_args(
        dependant.path_params, request.path_params
    )
    query_values, query_errors = request_params_to_args(
        dependant.query_params, request.query_params
    )
    header_values, header_errors = request_params_to_args(
        dependant.header_params, request.headers
    )
    cookie_values, cookie_errors = request_params_to_args(
        dependant.cookie_params, request.cookies
    )
    values.update(path_values)
    values.update(query_values)
    values.update(header_values)
    values.update(cookie_values)
    errors += path_errors + query_errors + header_errors + cookie_errors
    # values now holds the request parameters this call needs

    # The following flags were set up in get_dependant(),
    # e.g. whether the call wants the Request object injected
    if dependant.body_params:
        # A body model was declared: validate the body into it and collect its errors
        (
            body_values,
            body_errors,
        ) = await request_body_to_args(  # body_params checked above
            required_params=dependant.body_params, received_body=body
        )
        values.update(body_values)
        errors.extend(body_errors)
    if dependant.http_connection_param_name:
        values[dependant.http_connection_param_name] = request
    if dependant.request_param_name and isinstance(request, Request):
        values[dependant.request_param_name] = request
    elif dependant.websocket_param_name and isinstance(request, WebSocket):
        values[dependant.websocket_param_name] = request
    if dependant.background_tasks_param_name:
        if background_tasks is None:
            background_tasks = BackgroundTasks()
        values[dependant.background_tasks_param_name] = background_tasks
    if dependant.response_param_name:
        values[dependant.response_param_name] = response
    if dependant.security_scopes_param_name:
        values[dependant.security_scopes_param_name] = SecurityScopes(
            scopes=dependant.security_scopes
        )
    return values, errors, background_tasks, response, dependency_cache
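The "each takes what it needs" step (the four `request_params_to_args` calls) reduces to one loop. For every declared parameter, look it up by name in the matching request source and convert it to the declared type. When it is missing or invalid, record an error instead of raising. In this simplified sketch, `Param` and `params_to_args` are hypothetical names. Conversion is a plain call to the declared type rather than pydantic validation, and the error dicts only imitate pydantic's `loc`/`msg` shape.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Mapping, Tuple

_MISSING = object()


@dataclass
class Param:
    # Hypothetical, much-reduced stand-in for a pydantic ModelField
    name: str
    type_: Callable[[Any], Any]
    location: str  # "path", "query", "header" or "cookie"
    default: Any = _MISSING  # _MISSING means the parameter is required


def params_to_args(
    required: List[Param], received: Mapping[str, Any]
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
    values: Dict[str, Any] = {}
    errors: List[Dict[str, Any]] = []
    for param in required:
        raw = received.get(param.name, _MISSING)
        if raw is _MISSING:
            if param.default is _MISSING:
                errors.append({"loc": (param.location, param.name), "msg": "field required"})
            else:
                values[param.name] = param.default
            continue
        try:
            # Convert the raw string to the declared type
            values[param.name] = param.type_(raw)
        except (TypeError, ValueError):
            errors.append(
                {"loc": (param.location, param.name), "msg": f"value is not a valid {param.type_.__name__}"}
            )
    return values, errors
```

Collecting errors rather than raising on the first one is what lets FastAPI report every bad parameter in a single 422 response. The real function also handles aliases and list-valued query parameters, which this sketch leaves out.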

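As mentioned earlier, a Dependant represents its dependencies as a tree, so resolving them means traversing that tree from the root. Each call recurses into a node's children before calling the node itself, so the traversal is depth-first and a dependency's inputs are always ready before it runs.
The logic of this function is like climbing a tree with a small basket to pick apples: each apple (the result of a branch) goes into the basket (values, errors).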
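The apple-picking traversal and the per-request cache can be reduced to a short recursive function. This is a minimal sketch, not FastAPI's code. `Dep` and `resolve` are hypothetical names, and leaf values come from a plain `params` dict instead of the request. Several features are not modeled: dependency overrides, security scopes, and generator dependencies (which FastAPI runs through an exit stack). Sync callables go to `asyncio.to_thread` (Python 3.9+) in place of `run_in_threadpool`.

```python
import asyncio
import inspect
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Tuple


@dataclass
class Dep:
    # Hypothetical node of the dependency tree (a much-reduced Dependant)
    call: Callable[..., Any]
    name: Optional[str] = None  # argument name in the parent's call
    dependencies: List["Dep"] = field(default_factory=list)
    param_names: List[str] = field(default_factory=list)  # taken straight from the request
    use_cache: bool = True


async def resolve(
    dep: Dep, params: Dict[str, Any], cache: Optional[Dict[Callable, Any]] = None
) -> Tuple[Dict[str, Any], List[str], Dict[Callable, Any]]:
    cache = {} if cache is None else cache
    values: Dict[str, Any] = {}  # the "basket" of resolved arguments
    errors: List[str] = []
    for sub in dep.dependencies:
        # Depth first: the sub-dependency's own branches are solved before it is called
        sub_values, sub_errors, _ = await resolve(sub, params, cache)
        if sub_errors:
            errors.extend(sub_errors)
            continue
        if sub.use_cache and sub.call in cache:
            solved = cache[sub.call]  # already ran during this request
        elif inspect.iscoroutinefunction(sub.call):
            solved = await sub.call(**sub_values)
        else:
            solved = await asyncio.to_thread(sub.call, **sub_values)
        if sub.name is not None:
            values[sub.name] = solved
        cache.setdefault(sub.call, solved)
    # Leaf step: pick out the request values this node declares
    for name in dep.param_names:
        if name in params:
            values[name] = params[name]
        else:
            errors.append(f"missing parameter: {name}")
    return values, errors, cache


if __name__ == "__main__":
    calls: List[str] = []

    def get_db() -> str:
        calls.append("db")
        return "db-conn"

    async def get_user(db: str, token: str) -> str:
        return f"user({token}) via {db}"

    user_dep = Dep(get_user, name="user", dependencies=[Dep(get_db, name="db")], param_names=["token"])
    root = Dep(lambda **kw: kw, dependencies=[Dep(get_db, name="db"), user_dep])
    values, errors, _ = asyncio.run(resolve(root, {"token": "abc"}))
    print(values)  # {'db': 'db-conn', 'user': 'user(abc) via db-conn'}
    print(calls)   # ['db']: get_db ran once although two nodes depend on it
```

Because one cache dict is shared by the whole recursion, a dependency that appears in several branches runs once per request. This mirrors the dependency_cache argument threaded through solve_dependencies.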

async def serialize_response(
        *,
        field: Optional[ModelField] = None,
        response_content: Any,
        include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
        exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
        by_alias: bool = True,
        exclude_unset: bool = False,
        exclude_defaults: bool = False,
        exclude_none: bool = False,
        is_coroutine: bool = True,
) -> Any:
    if field:
        errors = []
        # Prepare the content (pydantic model instances are turned into dicts first)
        response_content = _prepare_response_content(
            response_content,
            exclude_unset=exclude_unset,
            exclude_defaults=exclude_defaults,
            exclude_none=exclude_none,
        )
        # Validate against the response field; for a sync endpoint this runs in a threadpool
        if is_coroutine:
            value, errors_ = field.validate(response_content, {}, loc=("response",))
        else:
            value, errors_ = await run_in_threadpool(
                field.validate, response_content, {}, loc=("response",)
            )
        if isinstance(errors_, ErrorWrapper):
            errors.append(errors_)
        elif isinstance(errors_, list):
            errors.extend(errors_)
        if errors:
            raise ValidationError(errors, field.type_)

        # Encode the validated value into JSON-compatible data, applying the include/exclude options
        return jsonable_encoder(
            value,
            include=include,
            exclude=exclude,
            by_alias=by_alias,
            exclude_unset=exclude_unset,
            exclude_defaults=exclude_defaults,
            exclude_none=exclude_none,
        )
    else:
        return jsonable_encoder(response_content)
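The `field` branch does two separate jobs. It validates the return value against the response model, then `jsonable_encoder` shapes the output with `include`, `exclude` and `exclude_none`. This sketch of that order assumes the response model is a plain `{field_name: type}` dict rather than a pydantic model. It uses `isinstance` checks instead of pydantic's coercion and accepts `None` for every field. `serialize` and `ResponseValidationError` are hypothetical names. `exclude_unset`, `exclude_defaults` and nested models are not modeled.

```python
from typing import Any, Dict, List, Optional, Set, Tuple


class ResponseValidationError(Exception):
    """Hypothetical stand-in for the ValidationError raised on a bad return value."""


def serialize(
    content: Any,
    model: Optional[Dict[str, type]] = None,
    include: Optional[Set[str]] = None,
    exclude: Optional[Set[str]] = None,
    exclude_none: bool = False,
) -> Any:
    if model is None:
        # No response_model: pass the content through unvalidated
        return content
    errors: List[Tuple[Tuple[str, str], str]] = []
    value: Dict[str, Any] = {}
    # Step 1: validate; only declared fields are copied, so extra keys are dropped
    for name, type_ in model.items():
        if name not in content:
            errors.append((("response", name), "field required"))
        elif content[name] is not None and not isinstance(content[name], type_):
            errors.append((("response", name), f"expected {type_.__name__}"))
        else:
            value[name] = content[name]
    if errors:
        # Like the original, a bad return value raises rather than being returned
        raise ResponseValidationError(errors)
    # Step 2: shape the validated output
    if include is not None:
        value = {k: v for k, v in value.items() if k in include}
    if exclude is not None:
        value = {k: v for k, v in value.items() if k not in exclude}
    if exclude_none:
        value = {k: v for k, v in value.items() if v is not None}
    return value
```

Validating first and filtering second is why a response_model can hide data: a `password` key the endpoint returns but the model does not declare never reaches the client. Note that the original raises pydantic's ValidationError here, not the RequestValidationError used for bad input.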

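In the simplest case, this is roughly the path a request takes through FastAPI. The following chapters will analyze the details and look at the extended features.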


    Article link: https://www.haomeiwen.com/subject/pcfjsktx.html