FastAPI Memory Leak Investigation: FastAPI/uvicorn Code Walkthrough

Author: ZackJiang | Published 2020-06-22 22:49

    Background

    When FastAPI and PyTorch are used together, we found that memory leaks if the endpoint is defined without async. Let's walk through the FastAPI code to see where the difference actually lies. The related GitHub issue is https://github.com/tiangolo/fastapi/issues/596
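
    To make the two cases concrete, here is a minimal sketch of the same kind of endpoint defined with and without async (the paths and return values are placeholders of mine, not from the issue):

        from fastapi import FastAPI

        app = FastAPI()

        # Defined with async: awaited directly on the event loop.
        @app.get("/predict-async")
        async def predict_async():
            return {"ok": True}

        # Defined without async: executed in a worker thread via run_in_threadpool.
        @app.get("/predict-sync")
        def predict_sync():
            return {"ok": True}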

    Walking through the FastAPI/uvicorn code

    When a REST endpoint is called, the request reaches the __call__() method of class Router in starlette/routing.py, which does the URL matching. For the normal full-match case, these few lines are all that matter; the rest below is not important.

    starlette.routing.py class Router

        async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
            """
            The main entry point to the Router class.
            """
            assert scope["type"] in ("http", "websocket", "lifespan")
    
            if "router" not in scope:
                scope["router"] = self
    
            # The lifespan protocol controls server startup/shutdown; it is not relevant here
            if scope["type"] == "lifespan":
                await self.lifespan(scope, receive, send)
                return
    
            partial = None
    
            for route in self.routes:
                # Determine if any route matches the incoming scope,
                # and hand over to the matching route if found.
                match, child_scope = route.matches(scope)
                if match == Match.FULL:
                    scope.update(child_scope)
                    # A full match lands here: hand off to route.handle(), which wraps the HTTP request
                    await route.handle(scope, receive, send)
                    return
                elif match == Match.PARTIAL and partial is None:
                    partial = route
                    partial_scope = child_scope
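
    As a small self-contained sketch of what matches() returns (calling starlette's Route directly with a made-up path):

        from starlette.responses import PlainTextResponse
        from starlette.routing import Match, Route

        async def get_user(request):
            return PlainTextResponse("user")

        route = Route("/users/{user_id}", endpoint=get_user, methods=["GET"])

        # A scope is just a dict describing the incoming request.
        match, child_scope = route.matches({"type": "http", "method": "GET", "path": "/users/42"})
        print(match)                       # Match.FULL
        print(child_scope["path_params"])  # {'user_id': '42'}

        # Same path but a method the route does not allow only partially matches.
        match, _ = route.matches({"type": "http", "method": "POST", "path": "/users/42"})
        print(match)                       # Match.PARTIAL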
    

    The route instances here are actually instances of class APIRoute from fastapi/routing.py, but that class does not override __call__(), so self.routes simply holds the starlette Route-style objects that were registered through the decorators when the ASGI app was initialized. The corresponding handle() implementation is as follows.

    starlette.routing.py class Route

        async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
            if self.methods and scope["method"] not in self.methods:
                if "app" in scope:
                    raise HTTPException(status_code=405)
                else:
                    response = PlainTextResponse("Method Not Allowed", status_code=405)
                await response(scope, receive, send)
            else:
                await self.app(scope, receive, send)
    

    In FastAPI the route object is implemented by class APIRoute(routing.Route) in fastapi/routing.py, a subclass of the starlette Route; its app attribute is initialized as follows.

    fastapi.routing.py class APIRoute

    class APIRoute(routing.Route):
      def __init__(self, path: str, endpoint: Callable, **kwargs) -> None:  # signature abridged
        # initialization of the other attributes is omitted
        self.dependant = get_dependant(path=self.path_format, call=self.endpoint)
        self.app = request_response(self.get_route_handler())
      def get_route_handler(self) -> Callable:
        return get_request_handler(
          dependant=self.dependant,
          body_field=self.body_field,
          status_code=self.status_code,
          response_class=self.response_class or JSONResponse,
          response_field=self.secure_cloned_response_field,
          response_model_include=self.response_model_include,
          response_model_exclude=self.response_model_exclude,
          response_model_by_alias=self.response_model_by_alias,
          response_model_exclude_unset=self.response_model_exclude_unset,
          response_model_exclude_defaults=self.response_model_exclude_defaults,
          response_model_exclude_none=self.response_model_exclude_none,
          dependency_overrides_provider=self.dependency_overrides_provider,
        )
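
    To see these pieces on a live app, here is a small sketch that registers a throwaway endpoint and inspects the resulting APIRoute (the /ping path is just an example of mine):

        from fastapi import FastAPI
        from fastapi.routing import APIRoute

        app = FastAPI()

        @app.get("/ping")
        def ping():
            return {"pong": True}

        # Every decorated path becomes an APIRoute; its .app attribute is the ASGI
        # callable built by request_response(get_route_handler(...)).
        route = next(r for r in app.routes if isinstance(r, APIRoute) and r.path == "/ping")
        print(route.endpoint)   # the ping function defined above
        print(route.dependant)  # the Dependant built by get_dependant()
        print(route.app)        # the wrapped ASGI application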
      
      
    

    Everything below is the implementation of HTTP request handling:

    fastapi.routing.py

    def get_request_handler(
        dependant: Dependant,
        body_field: ModelField = None,
        status_code: int = 200,
        response_class: Type[Response] = JSONResponse,
        response_field: ModelField = None,
        response_model_include: Union[SetIntStr, DictIntStrAny] = None,
        response_model_exclude: Union[SetIntStr, DictIntStrAny] = set(),
        response_model_by_alias: bool = True,
        response_model_exclude_unset: bool = False,
        response_model_exclude_defaults: bool = False,
        response_model_exclude_none: bool = False,
        dependency_overrides_provider: Any = None,
    ) -> Callable:
        assert dependant.call is not None, "dependant.call must be a function"
        is_coroutine = asyncio.iscoroutinefunction(dependant.call)
        is_body_form = body_field and isinstance(get_field_info(body_field), params.Form)
    
        async def app(request: Request) -> Response:
            try:
                body = None
                if body_field:
                    if is_body_form:
                        body = await request.form()
                    else:
                        body_bytes = await request.body()
                        if body_bytes:
                            body = await request.json()
            except Exception as e:
                logger.error(f"Error getting request body: {e}")
                raise HTTPException(
                    status_code=400, detail="There was an error parsing the body"
                ) from e
            solved_result = await solve_dependencies(
                request=request,
                dependant=dependant,
                body=body,
                dependency_overrides_provider=dependency_overrides_provider,
            )
            values, errors, background_tasks, sub_response, _ = solved_result
            if errors:
                raise RequestValidationError(errors, body=body)
            else:
                # your REST endpoint implementation is invoked here
                raw_response = await run_endpoint_function(
                    dependant=dependant, values=values, is_coroutine=is_coroutine
                )
    
                if isinstance(raw_response, Response):
                    if raw_response.background is None:
                        raw_response.background = background_tasks
                    return raw_response
                response_data = await serialize_response(
                    field=response_field,
                    response_content=raw_response,
                    include=response_model_include,
                    exclude=response_model_exclude,
                    by_alias=response_model_by_alias,
                    exclude_unset=response_model_exclude_unset,
                    exclude_defaults=response_model_exclude_defaults,
                    exclude_none=response_model_exclude_none,
                    is_coroutine=is_coroutine,
                )
                response = response_class(
                    content=response_data,
                    status_code=status_code,
                    background=background_tasks,
                )
                response.headers.raw.extend(sub_response.headers.raw)
                if sub_response.status_code:
                    response.status_code = sub_response.status_code
                return response
    
        return app
    

    starlette.routing.py

    def request_response(func: typing.Callable) -> ASGIApp:
        """
        Takes a function or coroutine `func(request) -> response`,
        and returns an ASGI application.
        """
        is_coroutine = asyncio.iscoroutinefunction(func)
    
        async def app(scope: Scope, receive: Receive, send: Send) -> None:
            request = Request(scope, receive=receive, send=send)
            # In FastAPI, func is the coroutine function returned by get_request_handler, so is_coroutine is always True.
            if is_coroutine:
                response = await func(request)
            else:
                response = await run_in_threadpool(func, request)
            await response(scope, receive, send)
    
        return app
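
    As a quick sanity check of what request_response() produces, here is a sketch that wraps a plain (non-async) handler of my own into an ASGI callable:

        from starlette.requests import Request
        from starlette.responses import PlainTextResponse
        from starlette.routing import request_response

        # A plain function handler: asyncio.iscoroutinefunction() is False for it,
        # so calls will go through run_in_threadpool when the app is invoked.
        def handler(request: Request) -> PlainTextResponse:
            return PlainTextResponse("hello")

        asgi_app = request_response(handler)  # now an `app(scope, receive, send)` callable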
    

    As we saw above, FastAPI drives the endpoint implementation through the dependant object, so let's look at how the dependant object is initialized.

    fastapi.dependencies.utils.py

    def get_dependant(
        *,
        path: str,
        call: Callable,
        name: str = None,
        security_scopes: List[str] = None,
        use_cache: bool = True,
    ) -> Dependant:
        path_param_names = get_path_param_names(path)
        endpoint_signature = get_typed_signature(call)
        signature_params = endpoint_signature.parameters
        if inspect.isgeneratorfunction(call) or inspect.isasyncgenfunction(call):
            check_dependency_contextmanagers()
        dependant = Dependant(call=call, name=name, path=path, use_cache=use_cache)
        for param_name, param in signature_params.items():
            if isinstance(param.default, params.Depends):
                sub_dependant = get_param_sub_dependant(
                    param=param, path=path, security_scopes=security_scopes
                )
                dependant.dependencies.append(sub_dependant)
        for param_name, param in signature_params.items():
            if isinstance(param.default, params.Depends):
                continue
            if add_non_field_param_to_dependency(param=param, dependant=dependant):
                continue
            param_field = get_param_field(
                param=param, default_field_info=params.Query, param_name=param_name
            )
            if param_name in path_param_names:
                assert is_scalar_field(
                    field=param_field
                ), f"Path params must be of one of the supported types"
                if isinstance(param.default, params.Path):
                    ignore_default = False
                else:
                    ignore_default = True
                param_field = get_param_field(
                    param=param,
                    param_name=param_name,
                    default_field_info=params.Path,
                    force_type=params.ParamTypes.path,
                    ignore_default=ignore_default,
                )
                add_param_to_fields(field=param_field, dependant=dependant)
            elif is_scalar_field(field=param_field):
                add_param_to_fields(field=param_field, dependant=dependant)
            elif isinstance(
                param.default, (params.Query, params.Header)
            ) and is_scalar_sequence_field(param_field):
                add_param_to_fields(field=param_field, dependant=dependant)
            else:
                field_info = get_field_info(param_field)
                assert isinstance(
                    field_info, params.Body
                ), f"Param: {param_field.name} can only be a request body, using Body(...)"
                dependant.body_params.append(param_field)
        return dependant
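
    Roughly speaking, get_dependant() classifies parameters from the endpoint signature. A small illustrative sketch (my own endpoint, not FastAPI internals):

        import inspect

        from fastapi import Query

        async def read_item(item_id: int, q: str = Query(None)):
            return {"item_id": item_id, "q": q}

        # Names that appear in the path template become path params, scalar
        # params default to query params, and params.Body defaults become body fields.
        sig = inspect.signature(read_item)
        for name, param in sig.parameters.items():
            print(name, param.annotation, param.default)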

    As you can see, this just initializes and validates the path parameters and the like; nothing more to it, so let's go straight to the call logic.

    async def run_endpoint_function(
        *, dependant: Dependant, values: Dict[str, Any], is_coroutine: bool
    ) -> Any:
        # Only called by get_request_handler. Has been split into its own function to
        # facilitate profiling endpoints, since inner functions are harder to profile.
        assert dependant.call is not None, "dependant.call must be a function"
    
        if is_coroutine:
            return await dependant.call(**values)
        else:
            return await run_in_threadpool(dependant.call, **values)
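
    To see the two branches side by side, here is a standalone sketch (not FastAPI's code; it uses loop.run_in_executor directly as a stand-in for run_in_threadpool):

        import asyncio
        import time

        def sync_endpoint() -> str:
            time.sleep(0.1)            # stands in for heavy work in a plain `def` endpoint
            return "done (threadpool)"

        async def async_endpoint() -> str:
            await asyncio.sleep(0.1)   # an `async def` endpoint is awaited directly
            return "done (event loop)"

        async def main() -> None:
            loop = asyncio.get_event_loop()
            print(await async_endpoint())                           # is_coroutine branch
            print(await loop.run_in_executor(None, sync_endpoint))  # threadpool branch

        asyncio.run(main())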
    

    OK, now we can see the difference between defining a FastAPI REST endpoint with and without async: with async the coroutine is awaited directly, while without async the call goes through run_in_threadpool.

    async def run_in_threadpool(
        func: typing.Callable[..., T], *args: typing.Any, **kwargs: typing.Any
    ) -> T:
        loop = asyncio.get_event_loop()
        if contextvars is not None:  # pragma: no cover
            # Ensure we run in the same context
            child = functools.partial(func, *args, **kwargs)
            context = contextvars.copy_context()
            func = context.run
            args = (child,)
        elif kwargs:  # pragma: no cover
            # loop.run_in_executor doesn't accept 'kwargs', so bind them in here
            func = functools.partial(func, **kwargs)
        return await loop.run_in_executor(None, func, *args)
    

    Here we can see that at execution time it is still the uvloop event loop doing the work through loop.run_in_executor(None, func, *args). Starting from this step we can investigate whether PyTorch really leaks memory when it runs together with uvloop.

    Current conclusion: when run_in_executor is called on the event loop without specifying an executor, the default executor has cpu_count x 5 workers. Worker threads do not release their resources after a task finishes, but once the pool is full, memory should in theory stop growing.
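
    A standalone sketch of that behaviour (not FastAPI code): on Python < 3.8 every submission to the default executor starts a new worker thread until the cpu_count x 5 limit is reached, even if earlier workers are idle, so the thread count keeps climbing; on 3.8+ idle workers are reused and the count stays small:

        import asyncio
        import threading
        import time

        def work() -> int:
            time.sleep(0.01)
            return threading.active_count()

        async def main() -> None:
            loop = asyncio.get_event_loop()
            for i in range(20):
                # executor=None: the loop lazily creates a default ThreadPoolExecutor,
                # which is exactly what run_in_threadpool relies on.
                alive = await loop.run_in_executor(None, work)
                print(f"iteration {i}: {alive} threads alive")

        asyncio.run(main())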

    Next, here is my test code:

    import asyncio
    
    import cv2 as cv
    import gc
    from pympler import tracker
    from concurrent import futures
    
    # A bounded pool; pass None to run_in_executor below to mimic FastAPI's default behaviour.
    executor = futures.ThreadPoolExecutor(max_workers=1)
    
    memory_tracker = tracker.SummaryTracker()
    
    # Memory-heavy OpenCV work per call, followed by a gc pass and a memory diff report.
    def mm():
        img = cv.imread("cap.jpg", 0)
        detector = cv.AKAZE_create()
        kpts, desc = detector.detectAndCompute(img, None)
        gc.collect()
        memory_tracker.print_diff()
        return None
    
    async def main():
        while True:
            loop = asyncio.get_event_loop()
            await loop.run_in_executor(executor, mm)
    
    
    if __name__=='__main__':
        loop = asyncio.get_event_loop()
        loop.run_until_complete(main())
    
    

    My test machine has 40 CPUs, so in theory the thread pool's worker limit is 200. If you cap the executor size and test (as in the code above), memory stays stable with no leak; but if you do it the way FastAPI does, memory keeps rising for the first 200 iterations and only then stabilizes. If what you run in the thread pool is a very large model, a limit on the order of 200 is far too high and can eat an enormous amount of memory.

    Conclusion: if you serve a very large deep learning model with FastAPI on a machine with many CPUs, it really will consume a lot of memory, but this is not a memory leak and there is still an upper bound. Even so, it would be good if starlette made the thread pool size configurable, otherwise too much memory gets eaten. For now I suggest allocating only a small amount of CPU to the service when packaging it as a container, which works around the problem.

    In addition, Python 3.8 already caps the default thread pool size as shown below, so if you are on Python 3.8 you do not need to worry about this.

            if max_workers is None:
                # ThreadPoolExecutor is often used to:
                # * CPU bound task which releases GIL
                # * I/O bound task (which releases GIL, of course)
                #
                # We use cpu_count + 4 for both types of tasks.
                # But we limit it to 32 to avoid consuming surprisingly large resource
                # on many core machine.
                max_workers = min(32, (os.cpu_count() or 1) + 4)
            if max_workers <= 0:
                raise ValueError("max_workers must be greater than 0")
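
    On older Python versions you can also cap the pool yourself. A minimal sketch, assuming the starlette version shown above (which submits to the loop's default executor); max_workers=4 is only an illustrative value:

        import asyncio
        from concurrent.futures import ThreadPoolExecutor

        from fastapi import FastAPI

        app = FastAPI()

        @app.on_event("startup")
        async def limit_default_threadpool() -> None:
            # Replacing the loop's default executor bounds how many worker
            # threads plain `def` endpoints can occupy.
            loop = asyncio.get_event_loop()
            loop.set_default_executor(ThreadPoolExecutor(max_workers=4))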
    
