美文网首页Spring Cloud 程序员分布式系列
基于redis和lua的分布式限流器设计与实现

基于redis和lua的分布式限流器设计与实现

作者: ro9er | 来源:发表于2018-10-23 17:02 被阅读113次

    前言

    之前这篇文章中,我大致介绍了一下google guava库中的RateLimiter的实现以及它背后的令牌桶算法原理。但是也有新的问题,在分布式的环境中,我们如何针对多机环境做限流呢?在查阅了一些资料和其他人的博客之后,我采用了redis来作为限流器的实现基础。
    原因主要有以下几点:

    • redis作为高性能缓存系统,性能上能够满足多机之间高并发访问的要求
    • redis有比较好的api来支持限流器令牌桶算法的实现
    • 对于我们的系统来说,通过spring data redis来操作比较简单和常见,避免了引入新的中间件带来的风险

    但是我们也知道,限流器在每次请求令牌和放入令牌操作中,存在一个协同的问题,即获取令牌操作要尽可能保证原子性,否则无法保证限流器是否能正常工作。在RateLimiter的实现中使用了mutex作为互斥锁来保证操作的原子性,那么在redis中就需要一个类似于事务的机制来保证获取令牌中多重操作的原子性。
    面对这样的需求,我们有几个选择:

    • 用redis实现分布式锁来保证操作的原子性,这个方案实现起来应该比较简单,分布式锁有现成的例子,然后就是把Rate Limiter的代码套用分布式锁就行了,但是这样的话效率会显得不太高,特别是在大量访问的情况下。
    • 用redis的transaction,在我查阅redis官方文档和stackoverflow之后发现redis的transaction官方并不推荐,并且有可能在未来取消事务,因此不可取。
    • 通过redis分布式锁和本地锁组成一个双层结构,每次分布式获取锁之后可以预支一部分令牌量,然后放到本地通过本地的锁来分配这些令牌,消耗完之后再到请求redis。这样的好处是相比第一个方案,网络访问延迟开销会比较好,但是实现难度和复杂程度比较难估量,而且这样的做法如果在多机不能保证均匀分配流量的情况下并不理想
    • 通过将获取锁封装到lua脚本中,提交给redis进行eval和evalsha操作来完成lua脚本的执行,由于lua脚本在redis中天然的原子性,我们的需求能够比较好的满足,问题是将业务逻辑封装在lua中,对于开发人员自身的能力和调试存在一定的问题。

    经过权衡,我采用了第四种方式,通过redis和lua来编写令牌桶算法来完成分布式限流的需求。

    lua脚本

    话不多说,先贴出lua代码

    
    -- 返回码 1:操作成功 0:未配置 -1: 获取失败 -2:修改错误,建议重新初始化 -500:不支持的操作
    -- redis hashmap 中存放的内容:
    -- last_mill_second 上次放入令牌或者初始化的时间
    -- stored_permits 目前令牌桶中的令牌数量
    -- max_permits 令牌桶容量
    -- interval 放令牌间隔
    -- app 一个标志位,表示对于当前key有没有限流存在
    
    local SUCCESS = 1
    local NO_LIMIT = 0
    local ACQUIRE_FAIL = -1
    local MODIFY_ERROR = -2
    local UNSUPPORT_METHOD = -500
    
    local ratelimit_info = redis.pcall("HMGET",KEYS[1], "last_mill_second", "stored_permits", "max_permits", "interval", "app")
    local last_mill_second = ratelimit_info[1]
    local stored_permits = tonumber(ratelimit_info[2])
    local max_permits = tonumber(ratelimit_info[3])
    local interval = tonumber(ratelimit_info[4])
    local app = ratelimit_info[5]
    
    local method = ARGV[1]
    
    --获取当前毫秒
    --考虑主从策略和脚本回放机制,这个time由客户端获取传入
    --local curr_time_arr = redis.call('TIME')
    --local curr_timestamp = curr_time_arr[1] * 1000 + curr_time_arr[2]/1000
    local curr_timestamp = tonumber(ARGV[2])
    
    
    -- 当前方法为初始化
    if method == 'init' then
        --如果app不为null说明已经初始化过,不要重复初始化
        if(type(app) ~='boolean' and app ~=nil) then
            return SUCCESS
        end
    
        redis.pcall("HMSET", KEYS[1],
            "last_mill_second", curr_timestamp,
            "stored_permits", ARGV[3],
            "max_permits", ARGV[4],
            "interval", ARGV[5],
            "app", ARGV[6])
        --始终返回成功
        return SUCCESS
    end
    
    -- 当前方法为修改配置
    if method == "modify" then
        if(type(app) =='boolean' or app ==nil) then
            return MODIFY_ERROR
        end
        --只能修改max_permits和interval
        redis.pcall("HMSET", KEYS[1],
            "max_permits", ARGV[3],
            "interval", ARGV[4])
    
        return SUCCESS
    
    end
    
    -- 当前方法为删除
    if method == "delete" then
        --已经清除完毕
        if(type(app) =='boolean' or app ==nil) then
            return SUCCESS
        end
        redis.pcall("DEL", KEYS[1])
        return SUCCESS
    end
    
    -- 尝试获取permits
    if method == "acquire" then
        -- 如果app为null说明没有对这个进行任何配置,返回0代表不限流
        if(type(app) =='boolean' or app ==nil) then
            return NO_LIMIT
        end
        --需要获取令牌数量
        local acquire_permits = tonumber(ARGV[3])
        --计算上一次放令牌到现在的时间间隔中,一共应该放入多少令牌
        local reserve_permits = math.max(0, math.floor((curr_timestamp - last_mill_second) / interval))
        
        local new_permits = math.min(max_permits, stored_permits + reserve_permits)
        local result = ACQUIRE_FAIL
        --如果桶中令牌数量够则放行
        if new_permits >= acquire_permits then
            result = SUCCESS
            new_permits = new_permits - acquire_permits
        end
        --更新当前桶中的令牌数量 
        redis.pcall("HSET", KEYS[1], "stored_permits", new_permits)
        --如果这次有放入令牌,则更新时间
        if reserve_permits > 0 then
            redis.pcall("HSET", KEYS[1], "last_mill_second", curr_timestamp)
        end
        return result
    end
    
    
    return UNSUPPORT_METHOD
    

    绝大部分逻辑在注释里面都已经写清楚了(我java客户端用的代码删掉了所有的注释,因为提交上去报编译错误,但是redis-cli调试就没问题,我也没太关注原因)。
    大致上,我在这个脚本中编写了4种函数:

    • init 初始化限流器
    • modify 修改限流器配置(主要针对限流器的桶大小和放令牌间隔,即1/QPS)
    • delete 删除限流器配置
    • acquire 尝试获取制定数目的令牌

    代码基本上仿照了Guava RateLimiter的逻辑,实现了触发式的放令牌策略。
    由于我的需求中不需要像guava RateLimiter那样的预支令牌的逻辑,因此如果当前没有令牌可供服务,我就直接返回获取失败了。
    还有一点需要注意的是,我本来在脚本中写了获取redis服务器当前时间的代码,但是我通过redis-cli执行的时候报错了:

    Write commands not allowed after non deterministic commands.
    

    这个错误的原因大家可以参见这篇文章,大致原因跟redis集群的重放和备份策略有关,相当于我调用TIME操作,会在主从各执行一次,得到的结果肯定会存在差异,这个差异就给最终逻辑正确性带来了不确定性。在redis 4.0之后引入了redis.replicate_commands()来放开限制。但我考虑了几个因素之后,还是采用网上大部分人的做法,在执行前先行获取到redis的时间戳,然后当做参数传上去。

    lua调试

    对lua调试最开始花掉了我不少时间,主要对于redis-cli命令不太熟悉。大家有一样问题的可以参见这篇文章。大致来说就是将写好的脚本放到redis所在文件夹下(我是windows环境),然后在cmd下执行 redis-cli.exe --eval rate_limit.lua test2(key,可重复) , (逗号分隔) init 10101 100 100 10 test2 (后跟参数,空格隔开)。

    java集成

    在完成了lua的调试工作之后,我们就开始java部分的集成代码编写,我们使用的是spring boot来完成开发。
    第一部分是redis配置:

        @Bean("rateLimitLua")
        public DefaultRedisScript<Long> getRateLimitScript() {
            DefaultRedisScript<Long> rateLimitLua = new DefaultRedisScript<>();
            rateLimitLua.setLocation(new ClassPathResource("rate_limit.lua"));
            rateLimitLua.setResultType(Long.class);
            return rateLimitLua;
        }
    

    然后是一些与lua适配的枚举和一些bean:

    /**
     * @author: Yuanqing Luo
     * @date: 2018/10/22
     *
     * 限流的具体方法
     */
    public enum RateLimitMethod {
    
        //initialize rate limiter
        init,
    
        //modify rate limiter parameter
        modify,
    
        //delete rate limiter
        delete,
    
        //acquire permits
        acquire;
    }
    
    /**
     * @author: Yuanqing Luo
     * @date: 2018/10/22
     * rate limite result
     **/
    public enum RateLimitResult {
    
        SUCCESS(1L),
        NO_LIMIT(0L),
        ACQUIRE_FAIL(-1L),
        MODIFY_ERROR(-2L),
        UNSUPPORT_METHOD(-500L),
        ERROR(-505L);
    
        private Long code;
    
        RateLimitResult(Long code){
            this.code = code;
        }
    
        public static RateLimitResult getResult(Long code){
            for(RateLimitResult enums: RateLimitResult.values()){
                if(enums.code.equals(code)){
                    return enums;
                }
            }
            throw new IllegalArgumentException("unknown rate limit return code:" + code);
        }
    }
    
    /**
     * @author: Yuanqing Luo
     * @date: 2018/10/22
     **/
    @Getter
    @Setter
    public class RateLimitVo {
    
        private String url;
    
        private boolean isLimit;
    
        private Double interval;
    
        private Integer maxPermits;
    
        private Integer initialPermits;
    
    }
    

    第三部分就是限流器的调用组装部分:

    /**
     * @author: Yuanqing Luo
     * @date: 2018/10/22
     **/
    @Service
    @Slf4j
    public class RateLimitClient {
    
        private static final String RATE_LIMIT_PREFIX = "ratelimit:";
    
        @Autowired
        StringRedisTemplate redisTemplate;
    
        @Resource
        @Qualifier("rateLimitLua")
        RedisScript<Long> rateLimitScript;
    
        public RateLimitResult init(String key, RateLimitVo rateLimitInfo){
            return exec(key, RateLimitMethod.init,
                    rateLimitInfo.getInitialPermits(),
                    rateLimitInfo.getMaxPermits(),
                    rateLimitInfo.getInterval(),
                    key);
        }
    
        public RateLimitResult modify(String key, RateLimitVo rateLimitInfo){
            return exec(key, RateLimitMethod.modify, key,
                    rateLimitInfo.getMaxPermits(),
                    rateLimitInfo.getInterval());
        }
    
        public RateLimitResult delete(String key){
            return exec(key, RateLimitMethod.delete);
        }
    
        public RateLimitResult acquire(String key){
            return acquire(key, 1);
        }
    
        public RateLimitResult acquire(String key, Integer permits){
            return exec(key, RateLimitMethod.acquire, permits);
        }
    
        /**
         * 执行redis的具体方法,限制method,保证没有其他的东西进来
         * @param key
         * @param method
         * @param params
         * @return
         */
        private RateLimitResult exec(String key, RateLimitMethod method, Object... params){
            try {
                Long timestamp = getRedisTimestamp();
                String[] allParams = new String[params.length + 2];
                allParams[0] = method.name();
                allParams[1] = timestamp.toString();
                for(int index = 0;index < params.length; index++){
                    allParams[2 + index] = params[index].toString();
                }
                Long result = redisTemplate.execute(rateLimitScript,
                        Collections.singletonList(getKey(key)),
                        allParams);
                return RateLimitResult.getResult(result);
            } catch (Exception e){
                log.error("execute redis script fail, key:{}, method:{}",
                        key, method.name(), e);
                return RateLimitResult.ERROR;
            }
        }
    
        private Long getRedisTimestamp(){
            Long currMillSecond = redisTemplate.execute(
                    (RedisCallback<Long>) redisConnection -> redisConnection.time()
            );
            return currMillSecond;
        }
        private String getKey(String key){
            return RATE_LIMIT_PREFIX + key;
        }
    }
    

    java代码这块比较简单了,基本就是封装了之前lua脚本中的4项操作。

    第四部分就是测试代码:

    /**
     * @author: Yuanqing Luo
     * @date: 2018/10/22
     **/
    @RunWith(SpringRunner.class)
    @SpringBootTest(classes = OpenApiGatewayApplication.class)
    public class RateLimitTest {
    
        @Autowired
        private RateLimitClient rateLimitClient;
    
        @Test
        public void testInit(){
            RateLimitVo vo = new RateLimitVo();
            vo.setInitialPermits(500);
            vo.setMaxPermits(500);
            vo.setInterval(2.0);
            rateLimitClient.init("test", vo);
        }
    
        @Test
        public void testAcquire() throws InterruptedException {
            //10个线程
            ExecutorService executorService = Executors.newFixedThreadPool(20);
    
            Subject<RateLimitSummary, RateLimitSummary> writeSubject = new SerializedSubject<RateLimitSummary, RateLimitSummary>(PublishSubject.<RateLimitSummary>create());
            Observable<RateLimitSummary> readSubject = writeSubject.share();
            Observable<RateLimitSummary> bucketStream = Observable.defer(()->{
                return readSubject.window(200, TimeUnit.MILLISECONDS)
                        .flatMap(
                                observable->
                                        observable.reduce(new RateLimitSummary(0,0,0),
                                                (a, b)-> a.reduce(b))
                        );
            });
            Observable<RateLimitSummary> rollingBucketStream = bucketStream.window(5, 1)
                    .flatMap(observable->observable.reduce(new RateLimitSummary(0, 0, 0),
                            (a, b)-> a.reduce(b)));
    
            Runnable acquire = () -> {
                Random random = new Random();
                while(true){
                    try {
                        Thread.sleep(30);
                    } catch (InterruptedException e) {
                        e.printStackTrace();
                    }
                    RateLimitResult result = rateLimitClient.acquire("test");
                    writeSubject.onNext(new RateLimitSummary(result));
                }
            };
            //初始时间
            final long currentMillis = System.currentTimeMillis();
            rollingBucketStream.subscribe(summary->{
                double timestamp = (System.currentTimeMillis() - currentMillis)/1000.0;
                System.out.println("time:"+ timestamp + ", acquired:" + summary.acquire +
                        ", reject " + summary.reject + ", error: " + summary.error);
            });
            for(int i=0;i<20;i++){
                executorService.submit(acquire);
            }
            while(true){
                Thread.sleep(5000);
            }
        }
    
        private static class RateLimitSummary{
            public int acquire;
            public int reject;
            public int error;
    
            public RateLimitSummary(RateLimitResult result){
                this.acquire = result == RateLimitResult.SUCCESS?1:0;
                this.reject = result == RateLimitResult.ACQUIRE_FAIL?1:0;
                this.error = result == RateLimitResult.ERROR?1:0;
            }
    
            public RateLimitSummary(int acquire, int reject, int error){
                this.acquire = acquire;
                this.reject = reject;
                this.error = error;
            }
    
            public RateLimitSummary reduce(RateLimitSummary toAdd){
                return new RateLimitSummary(this.acquire + toAdd.acquire,
                        this.reject + toAdd.reject,
                        this.error + toAdd.error);
    
            }
        }
    
    }
    

    这一段代码我仿照了Hystrix中的熔断统计的代码,通过一个subject来存放获取令牌结果,然后通过第一层bucketStream来将令牌结果按照200ms来分组并且reduce成一个结果。接着通过rollingBucketStream来将200ms的分组组合成一个一秒的时间窗(即5个为一组),并且以200ms为步长滚动。最后统计出来的结果通过subscribe来打印结果。之前的init代码我们看已经初始化了一个大小为500的令牌桶,存放令牌的时间间隔为2.0ms,所以支持的QPS为500。接着我们执行这段代码,并截取一部分输出:

    time:75.857, acquired:460, reject 8, error: 0
    time:76.056, acquired:483, reject 36, error: 0
    time:76.268, acquired:506, reject 52, error: 0
    time:76.454, acquired:503, reject 59, error: 0
    time:76.707, acquired:457, reject 69, error: 0
    time:76.854, acquired:417, reject 66, error: 0
    time:77.054, acquired:454, reject 36, error: 0
    time:77.255, acquired:459, reject 54, error: 0
    time:77.453, acquired:458, reject 77, error: 0
    time:77.658, acquired:474, reject 103, error: 0
    time:77.858, acquired:490, reject 132, error: 0
    

    可以看到,这个结果基本每200ms输出一次,然后一秒钟内的获取了令牌数目最大值跟500接近,并且能够很好地处理reject。有一部分结果一秒钟获取的令牌数与500差距较大,我分析的原因是因为请求重复时间段比较多,很多请求发生在前一个获取了令牌之后的2ms内,产生了reject。

    结语

    通过redis和lua,我实现了一个简单的分布式限流器。通过上述代码,大家能看到一个大致的实现框架,并且通过测试代码完成了验证。如果各位看官有什么问题欢迎留言,希望能跟大家共同学习。

    相关文章

      网友评论

        本文标题:基于redis和lua的分布式限流器设计与实现

        本文链接:https://www.haomeiwen.com/subject/sbhazftx.html