美文网首页程序员
django rest farmework 限制接口请求速率

django rest farmework 限制接口请求速率

作者: 天空蓝雨 | 来源:发表于2020-02-19 16:21 被阅读0次

    参考官方的文档:
    https://www.django-rest-framework.org/api-guide/throttling/

    30568817.jpg
    • 全局配置:

    在settings 里面全局配置

    REST_FRAMEWORK = {
        'DEFAULT_THROTTLE_CLASSES': [
            'rest_framework.throttling.AnonRateThrottle',
            'rest_framework.throttling.UserRateThrottle'
        ],
        'DEFAULT_THROTTLE_RATES': {
            'anon': '100/day',
            'user': '1000/day'
        }
    }
    

    anon 和 user分别对应默认的AnonRateThrottle 和UserRateThrottle 类
    每个类都有一个 scope 的属性

    **rest 内置了,这三个限速类 **
    AnonRateThrottle scope 属性的值是 anon
    UserRateThrottle scope 属性的值是user
    每个用户进行简单的全局速率限制,那么 UserRateThrottle 是合适的
    上面两个一个用户的所有请求都是累加的 (一个用户的所有视图请求次数累加到一起,判断是不是超速了)
    同一 scope和 token 组成 唯一的key,对应请求次数
    参考:

    django rest framework throttling per user and per view

    **ScopedRateThrottle 针对 不同api进行统计 视图必须包含 throttle_scope 属性 **
    (其实如果你要对每个视图的每个用户进行速率限制,那么只需要自定义即可 比如 你把自定义的 scope 设置为 请求的 api 路径,文末会讲一下)

    • DEFAULT_THROTTLE_CLASSES :

    指定全局使用的限速类

    • DEFAULT_THROTTLE_RATES :

    指定全局的scope对应的限速字符串参数

    • DEFAULT_THROTTLE_CLASSES:上面的官方文档的配置其实是全局使用了这两个限速类,并且配置了对应的scope (因为这连个内置类只能从 settings 的 user 和anon 读取 对应的rate 参数 )

    如果只想全局配置速率参数,那把限速类 :DEFAULT_THROTTLE_CLASSES 配置去掉即可。

    • 单独对函数配置

    • 不用全局配置的情况下,可以使用 throttle_classes装饰器 单独对某个 views 函数进行作用:
    xxx.views.py 
    @api_view()
    @throttle_classes([UserRateThrottle])
    def xxx(request):
          pass
    

    throttle_classes 装饰器源码

    def throttle_classes(throttle_classes):
        def decorator(func):
            func.throttle_classes = throttle_classes
            return func
        return decorator
    
    • 可以用类的方式 APIView

    因为 APIView 可以把属性throttle_classes 直接影响到类下面的每个视图函数。其实和 单独的装饰器原理是一样的

    from rest_framework.response import Response
    from rest_framework.throttling import UserRateThrottle
    from rest_framework.views import APIView
    
    class ExampleView(APIView):
        throttle_classes = [UserRateThrottle]
    
        def get(self, request, format=None):
            content = {
                'status': 'request was permitted'
            }
            return Response(content)
    

    APIView 部分源码:

    class APIView(View):
    
        # The following policies may be set at either globally, or per-view.
        renderer_classes = api_settings.DEFAULT_RENDERER_CLASSES
        parser_classes = api_settings.DEFAULT_PARSER_CLASSES
        authentication_classes = api_settings.DEFAULT_AUTHENTICATION_CLASSES
        throttle_classes = api_settings.DEFAULT_THROTTLE_CLASSES
        permission_classes = api_settings.DEFAULT_PERMISSION_CLASSES
        content_negotiation_class = api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS
        metadata_class = api_settings.DEFAULT_METADATA_CLASS
        versioning_class = api_settings.DEFAULT_VERSIONING_CLASS
    
        # Allow dependency injection of other settings to make testing easier.
        settings = api_settings
    #在 APIViews 里面怎么调用的限速类呢,看下面
    获取 throttle 实例
        def get_throttles(self):
            """
            Instantiates and returns the list of throttles that this view uses.
            """
            return [throttle() for throttle in self.throttle_classes]
    # 执行检查 throttle 实例
        def check_throttles(self, request):
            """
            Check if request should be throttled.
            Raises an appropriate exception if the request is throttled.
            """
            for throttle in self.get_throttles():
                if not throttle.allow_request(request, self):
                    self.throttled(request, throttle.wait())
    # throttle 不通过速率校验返回这个函数,
        def throttled(self, request, wait):
            """
            If request is throttled, determine what kind of exception to raise.
            """
            raise exceptions.Throttled(wait)
    
    

    最后返回错误消息是在

    class Throttled(APIException) 这个类里面定义的。
    属于 rest_framework/exceptions.py  异常包
    
    *******************************************************

    **可以看出 APIView 的默认属性都是通过api_settings 来获取的 **
    看 api_settings 部分源码:

    class APISettings(object):
        """
        A settings object, that allows API settings to be accessed as properties.
        For example:
    
            from rest_framework.settings import api_settings
            print(api_settings.DEFAULT_RENDERER_CLASSES)
    
        Any setting with string import paths will be automatically resolved
        and return the class, rather than the string literal.
        """
        def __init__(self, user_settings=None, defaults=None, import_strings=None):
            if user_settings:
                self._user_settings = self.__check_user_settings(user_settings)
            self.defaults = defaults or DEFAULTS
            self.import_strings = import_strings or IMPORT_STRINGS
            self._cached_attrs = set()
    
        @property
        def user_settings(self):
            if not hasattr(self, '_user_settings'):
                self._user_settings = getattr(settings, 'REST_FRAMEWORK', {})
            return self._user_settings
    
    

    看一看出,在user_settings 里面通过读取 settings 里面的 名为 REST_FRAMEWORK 配置并且设置到当前 APISettings 实例。
    这就解释了 为什么限速类 最终默认值 都是连接到 settings 里面的
    读取顺序 当前函数属性(包括装饰器),然后是 APIVIew 属性 最后是settings REST_FRAMEWORK 默认配置

    可以看出无论是 throttle_classes 装饰器,还是 使用类视图 继承 APIView
    都是通过指定 throttle_classes 来获取用那个限速类 来限速

    *******************************************************
    • 自定义限速类

    直接继承 SimpleRateThrottle 就可以了,然后指定至少
    scope 或者 rate 属性(当然也可以都指定,那只会生效 rate)
    scope 或rate 格式:

    'nymber/time'      
    time 可以是 second,minute,hour或day
    

    当然偷懒的做法就是直接继承 rest 里面写好限速类

    class BurstRateThrottle(UserRateThrottle):
        scope = 'burst'
    
    class SustainedRateThrottle(UserRateThrottle):
        scope = 'sustained'
    

    除了 scope 和rate 属性,还有其他的一些属性,可以参考 官方 api 手册

    30568818.jpg

    那么限速类内部是怎么工作的呢? 看下面的源码分析一下就知道了

    • SimpleRateThrottle 部分源码

    • 初始化部分
    class SimpleRateThrottle(BaseThrottle):
        """
        A simple cache implementation, that only requires `.get_cache_key()`
        to be overridden.
    
        The rate (requests / seconds) is set by a `rate` attribute on the View
        class.  The attribute is a string of the form 'number_of_requests/period'.
    
        Period should be one of: ('s', 'sec', 'm', 'min', 'h', 'hour', 'd', 'day')
    
        Previous request information used for throttling is stored in the cache.
        """
        cache = default_cache
        timer = time.time
        cache_format = 'throttle_%(scope)s_%(ident)s'
        scope = None
        THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES
    
        def __init__(self):
            if not getattr(self, 'rate', None):
                self.rate = self.get_rate()
            self.num_requests, self.duration = self.parse_rate(self.rate)
    
    • 函数 get_rate() 部分(获取频率的字符串)
    
        def get_rate(self):
            """
            Determine the string representation of the allowed request rate.
            """
            if not getattr(self, 'scope', None):
                msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
                       self.__class__.__name__)
                raise ImproperlyConfigured(msg)
    
            try:
                return self.THROTTLE_RATES[self.scope]
            except KeyError:
                msg = "No default throttle rate set for '%s' scope" % self.scope
                raise ImproperlyConfigured(msg)
    
    
    • parse_rate 函数部分 (解析频率字符串的参数)
        def parse_rate(self, rate):
            """
            Given the request rate string, return a two tuple of:
            <allowed number of requests>, <period of time in seconds>
            """
            if rate is None:
                return (None, None)
            num, period = rate.split('/')
            num_requests = int(num)
            duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
            return (num_requests, duration)
    

    从上面可以看出来,如果你的类有rate 属性,则直接使用,如果没有,就从settings 里面那个全局的 DEFAULT_THROTTLE_RATES 字典里面用scope 的key取出(这个和django 默认的日志系统,的用法有相似之处)

    • 最后 allow_request() 函数
      这个函数就是判断请求是不是符合频率的:
        def allow_request(self, request, view):
            """
            Implement the check to see if the request should be throttled.
    
            On success calls `throttle_success`.
            On failure calls `throttle_failure`.
            """
            if self.rate is None:
                return True
    
            self.key = self.get_cache_key(request, view)
            if self.key is None:
                return True
    
            self.history = self.cache.get(self.key, [])
            self.now = self.timer()
    
            # Drop any requests from the history which have now passed the
            # throttle duration
            while self.history and self.history[-1] <= self.now - self.duration:
                self.history.pop()
            if len(self.history) >= self.num_requests:
                return self.throttle_failure()
            return self.throttle_success()
    
    当然还有 get_cache_key wait 这两个主要的函数,这里就不在细讲了
    30568819.jpg

    然后在看一下 内置的 两个限速类的源码怎么写的:
    AnonRateThrottle

    class AnonRateThrottle(SimpleRateThrottle):
        """
        Limits the rate of API calls that may be made by a anonymous users.
    
        The IP address of the request will be used as the unique cache key.
        """
        scope = 'anon'
    
        def get_cache_key(self, request, view):
            if request.user.is_authenticated:
                return None  # Only throttle unauthenticated requests.
    
            return self.cache_format % {
                'scope': self.scope,
                'ident': self.get_ident(request)
            }
    
    

    UserRateThrottle

    class UserRateThrottle(SimpleRateThrottle):
        """
        Limits the rate of API calls that may be made by a given user.
    
        The user id will be used as a unique cache key if the user is
        authenticated.  For anonymous requests, the IP address of the request will
        be used.
        """
        scope = 'user'
    
        def get_cache_key(self, request, view):
            if request.user.is_authenticated:
                ident = request.user.pk   # 登录用户使用密码作为 计数key
            else:
                ident = self.get_ident(request)  # 没有登录用户使用 ip 作为  计数 key 
    
            return self.cache_format % {
                'scope': self.scope,
                'ident': ident
            }
    
    

    从上面可以看出,主要是 get_cache_key 函数要返回 scope 和 ident 属性,因为他们在父类中会有方法使用到。ident 是给请求计算次数用的 肯可能是匿名用户的ip 或者已认证用户的token

    好了,总结这么多就差不多了 累死宝宝了

    后续理解:
    api_view() 装饰器的工作流程:

    def api_view(http_method_names=None):
        """
        Decorator that converts a function-based view into an APIView subclass.
        Takes a list of allowed methods for the view as an argument.
        """
        def decorator(func):
            # 变成子类
            WrappedAPIView = type(
                six.PY3 and 'WrappedAPIView' or b'WrappedAPIView',
                (APIView,),
                {'__doc__': func.__doc__}
            )
    上面相当于 :
            #     class WrappedAPIView(APIView):
            #         pass
            #     WrappedAPIView.__doc__ = func.doc    <--- Not possible to do this
    
            def handler(self, *args, **kwargs):
                return func(*args, **kwargs)
    
            for method in http_method_names:
                setattr(WrappedAPIView, method.lower(), handler)
    
            WrappedAPIView.__name__ = func.__name__
            WrappedAPIView.__module__ = func.__module__
    
            WrappedAPIView.renderer_classes = getattr(func, 'renderer_classes',
                                                      APIView.renderer_classes)
    
            WrappedAPIView.parser_classes = getattr(func, 'parser_classes',
                                                    APIView.parser_classes)
    
            WrappedAPIView.authentication_classes = getattr(func, 'authentication_classes',
                                                            APIView.authentication_classes)
    
            WrappedAPIView.throttle_classes = getattr(func, 'throttle_classes',
                                                      APIView.throttle_classes)
    
            WrappedAPIView.permission_classes = getattr(func, 'permission_classes',
                                                        APIView.permission_classes)
    
            WrappedAPIView.schema = getattr(func, 'schema',
                                            APIView.schema)
    
            return WrappedAPIView.as_view()
    
        return decorator
    

    从上面的解释可以看出,api_view 函数,其实是把装饰的函数变成了 APIView 的一个子类。然后把被装饰的函数里面的属性进行覆盖掉 APIView 里面的默认属性
    主要有一下属性:
    函数普通的 比如 namedoc
    还有APIView 的内置属性,比如 renderer_classes 、parser_classes 、authentication_classes 、throttle_classes 、 permission_classes 、schema

    注意最后调用 了 as_view( ) 函数,看的出来 rest_framework 的api_view 函数,其实是整合了其他装饰器。并最后返回一个 APIView.as_view( ) 函数(这就和 django里面使用 类视图,包括自动匹配 get post 方法 练联系起来了。)
    还有一点
    rest_farmwork 的视图使用 好像最后都要APIView 那个返回,不然报错?至少我工作项目里面会报错。可能有什么配置,以后再说吧
    参考 rest_farmwork views 的写法(他里面说了 有类视图继承(APIView),和 函数视图(@api_view()) 这样就会保证 对接了 rest 的东西 比如 request 。
    https://www.django-rest-framework.org/api-guide/views/

    最后举个例子,对每个用户的每个接口进行速率的限制:

    
    

    对特定用户,进行设置速率,比如,管理员不限次数,普通用户 五次每分钟:

    class StrictRate(UserRateThrottle):
        """普通用户一分钟五次, 管理员不限次数"""
    
        def allow_request(self, request, view):
            # 管理员不限次数,普通用户 5次 一分钟
            self.rate = None if request.user.is_superuser else "5/m"
            super(StrictRate, self).allow_request(request, view)
    
    

    就是继承,在allow_request 里面 根据user 属性设置 rate 就行了

    相关文章

      网友评论

        本文标题:django rest farmework 限制接口请求速率

        本文链接:https://www.haomeiwen.com/subject/pyeafhtx.html