美文网首页
springboot+redis+lua实现简单的限流、黑名单

springboot+redis+lua实现简单的限流、黑名单

作者: BlackR_ | 来源:发表于2021-07-20 16:03 被阅读0次

    简单了解下redis嵌入lua脚本(随便百度扒的):
    Redis支持的LUA脚本与其优势
    redis嵌入lua官方文档
    Redis悲观锁、乐观锁和调用Lua脚本的优缺点

    序言

    本教程本着人和代码其中一个能跑就行的原则。人菜勿喷,只接受技术性建议。
    近期看到了一些关于redis的文章说:哇redis牛皮、redis怎么这么快、哦哟!redis6竟然支持多线程了。。。

    然后问题来了:多线程?线程安全问题?哦哟,百度百度了解了下,原来如此,此多线程非彼多线程(自己百度去,这里不说这个)。然后就看到了嵌入lua脚本这个骚操作,在继续百度了下,哎哟,不错哟!拿来吧你。

    蹭个热度:吴签和某时间管理大师参与多人运动(多线程)时已出现一个很细小(线程安全)的针眼,我的有点大,请你忍耐一下,一夜之间让花季程序猿痛苦流泪。

    正文

    通过redis嵌入lua脚本,实现简单的限流、黑名单功能。别说这也没有那也没有,个性化功能自行开发。
    其中有点小坑,最后再说。
    说那么多,不如直接丢代码。Talk is cheap. Show me the code.

    测试环境

    win11
    jdk8
    Redis server v=5.0.9
    springboot 2.4.7

    Show me the code.

    先来看看lua脚本
    lua脚本存放在项目的resource目录下的lua文件夹下面(路径可以自己改,下面SelfRedisScript .java里面改成对应的就行)

    --- lua脚本:限流、黑名单专用,慎改
    --- 用于高并发情况下保证redis线程安全
    --- 注意:
    --- 1、redis反序列化问题
    --- 2、完成lua脚本后,请在本地测试无误后再提交代码
    --- 3、若lua脚本执行报错,redis不会回滚已经执行的命令
    
    -- 获取传递进来的参数
    local countKey = KEYS[1]
    if countKey == nil then
        return true
    end
    -- 获取传递进来的阈值
    local requestCount = KEYS[2]
    -- 获取传递进来的过期时间ttl
    local requestTtl = KEYS[3]
    -- 获取redis参数
    local countVal = redis.call('GET', countKey)
    -- 如果不是第一次请求
    if countVal then
        -- 由于lua脚本接收到参数都会转为String,所以要转成数字类型才能比较
        local numCountVal = tonumber(countVal)
        -- 如果超过指定阈值,则返回true
        if numCountVal >= tonumber(requestCount) then
            return true
        else
            numCountVal = numCountVal + 1
            redis.call('SETEX', countKey, requestTtl, numCountVal)
        end
    else
        redis.call('SETEX', countKey, requestTtl, 1)
    end
    return false
    

    Java代码(SelfRedisScript .java)注入RedisScript

    @Component
    public class SelfRedisScript {
    
        @Bean("redisScriptBoolean")
        public DefaultRedisScript<Boolean> redisScriptBoolean() {
            DefaultRedisScript<Boolean> redisScript = new DefaultRedisScript<>();
            redisScript.setScriptSource(new ResourceScriptSource(new ClassPathResource("lua/limit_blacklisted.lua")));
            redisScript.setResultType(Boolean.class);
            return redisScript;
        }
    }
    

    Java代码(RedisTemplateConfig.java)简单配置RedisTemplate

    @EnableCaching
    @Configuration
    @AutoConfigureBefore(RedisAutoConfiguration.class)
    public class RedisTemplateConfig {
    
        @Bean
        @Primary
        public RedisTemplate<String, Object> redisTemplate(RedisConnectionFactory redisConnectionFactory) {
            RedisTemplate<String, Object> redisTemplate = new RedisTemplate<>();
            redisTemplate.setKeySerializer(new StringRedisSerializer());
            redisTemplate.setHashKeySerializer(new StringRedisSerializer());
            redisTemplate.setValueSerializer(new StringRedisSerializer());
            redisTemplate.setHashValueSerializer(new StringRedisSerializer());
            redisTemplate.setConnectionFactory(redisConnectionFactory);
            return redisTemplate;
        }
    }
    

    准备工作做完后,开始实现简单的限流、黑名单

    在过滤器里面实现功能
    CommonConstants类中的常量、ApplicationConfig从application.yml从获取值的代码就不贴出来了,
    WebUtils.returnResponse单独列出来了,自行修改补充。
    RedisConstants常量:

       /**
        * 限流机制key前缀 REQUEST_LIMIT:127.0.0.1:/api/test
        */
       String REQUEST_LIMIT = "REQUEST_LIMIT:%s:%s";
    
       /**
        * 黑名单机制key前缀 REQUEST_LIMIT:127.0.0.1:
        */
       String REQUEST_BLACKLISTED = "REQUEST_BLACKLISTED:%s";
    
    @Slf4j
    public class RequestLimitFilter implements Filter {
    
        private final ApplicationConfig applicationConfig;
        private Long limitTimeSeconds;
        private Integer limitCount;
        private Long blacklistedTimeSeconds;
        private Integer blacklistedCount;
        private List<String> limitIgnores;
        private RedisTemplate<String, Object> redisTemplate;
        private DefaultRedisScript<Boolean> script;
    
        @Override
        public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
            HttpServletRequest request = (HttpServletRequest) servletRequest;
            HttpServletResponse response = (HttpServletResponse) servletResponse;
            String requestURI = request.getRequestURI();
            if (!WebUtils.uriMatch(this.limitIgnores, requestURI)) {
                // 获取ip
                String realIp = WebUtils.getIP(request);
                // 黑名单限制
                String blacklistedKey = String.format(RedisConstants.REQUEST_BLACKLISTED, realIp);
                // key为空返回true,超过指定阈值返回true,其他返回false
                Boolean blackPass = getPass(blacklistedKey, blacklistedCount, blacklistedTimeSeconds);
                if (blackPass) {
                    WebUtils.returnResponse(response, JSONUtil.toJsonStr(R.failed(StatusCode.BLACKLISTED)));
                    return;
                }
                // 限流限制
                String limitKey = String.format(RedisConstants.REQUEST_LIMIT, realIp, requestURI);
                Boolean limitPass = getPass(limitKey, limitCount, limitTimeSeconds);
                if (limitPass) {
                    WebUtils.returnResponse(response, JSONUtil.toJsonStr(R.failed(StatusCode.LIMITED)));
                    return;
                }
            }
            filterChain.doFilter(servletRequest, servletResponse);
        }
    
        @Override
        public void destroy() {
            Filter.super.destroy();
        }
    
        @Override
        public void init(FilterConfig filterConfig) throws ServletException {
            // 限流
            this.limitCount = ObjectUtil.isNull(applicationConfig.getLimitCount())
                    ? CommonConstants.REQUEST_LIMIT_COUNT : applicationConfig.getLimitCount();
            this.limitTimeSeconds = ObjectUtil.isNull(applicationConfig.getLimitTimeSeconds())
                    ? CommonConstants.REQUEST_LIMIT_TIME_SECONDS : applicationConfig.getLimitTimeSeconds();
            // 黑名单
            this.blacklistedCount = ObjectUtil.isNull(applicationConfig.getBlacklistedCount())
                    ? CommonConstants.REQUEST_BLACKLISTED_COUNT : applicationConfig.getBlacklistedCount();
            this.blacklistedTimeSeconds = ObjectUtil.isNull(applicationConfig.getBlacklistedTimeSeconds())
                    ? CommonConstants.REQUEST_BLACKLISTED_TIME_SECONDS : applicationConfig.getBlacklistedTimeSeconds();
            // 过滤请求,从application.yml从获取值
            this.limitIgnores = IterUtil.isEmpty(applicationConfig.getLimitIgnores())
                    ? Collections.emptyList() : applicationConfig.getLimitIgnores();
            // lua
            this.redisTemplate = SpringContextHolder.getBean(RedisTemplate.class);
            this.script = SpringContextHolder.getBean("redisScriptBoolean");
            Filter.super.init(filterConfig);
        }
    
        public RequestLimitFilter(ApplicationConfig applicationConfig) {
            this.applicationConfig = applicationConfig;
        }
    
        /**
         *  调用lua脚本,获取执行结果
         * @param key 缓存key
         * @param count 请求阈值
         * @param timeSeconds  拦截时间
         * @return 执行结果
         */
        private Boolean getPass(String key, Integer count, Long timeSeconds) {
            Boolean execute = redisTemplate.execute(script, Arrays.asList(key, String.valueOf(count), String.valueOf(timeSeconds)));
            return execute == null ? true : execute;
        }
    }
    
    // -------------------------------------------WebUtils工具类---------------------------------------------
        public void returnResponse(HttpServletResponse response, String data) {
            response.setCharacterEncoding("UTF-8");
            response.setContentType("text/html; charset=utf-8");
            try (PrintWriter writer = response.getWriter()) {
                // 通过 PrintWriter 将 data 数据直接 print 回去      
                writer.print(data);
            } catch (IOException ignored) {
            }
        }
    
        public String getIP(HttpServletRequest request) {
            Assert.notNull(request, "HttpServletRequest is null");
            String ip = request.getHeader(HEADER_X_REQUESTED_FOR);
            if (StrUtil.isBlank(ip) || UNKNOWN.equalsIgnoreCase(ip)) {
                ip = request.getHeader(HEADER_X_FORWARDED_FOR);
            }
            if (StrUtil.isBlank(ip) || UNKNOWN.equalsIgnoreCase(ip)) {
                ip = request.getHeader(HEADER_PROXY_CLIENT_IP);
            }
            if (StrUtil.isBlank(ip) || UNKNOWN.equalsIgnoreCase(ip)) {
                ip = request.getHeader(HEADER_WL_PROXY_CLIENT_IP);
            }
            if (StrUtil.isBlank(ip) || UNKNOWN.equalsIgnoreCase(ip)) {
                ip = request.getHeader(HEADER_HTTP_CLIENT_IP);
            }
            if (StrUtil.isBlank(ip) || UNKNOWN.equalsIgnoreCase(ip)) {
                ip = request.getHeader(HEADER_HTTP_X_FORWARDED_FOR);
            }
            if (StrUtil.isBlank(ip) || UNKNOWN.equalsIgnoreCase(ip)) {
                ip = request.getRemoteAddr();
            }
            return StrUtil.isBlank(ip) ? null : ip.split(",")[0];
        }
    

    最后注册下RequestLimitFilter.java这个过滤器

    @Component
    @AllArgsConstructor
    public class FilterRegistration {
    
        private final ApplicationConfig applicationConfig;
    
        @Bean
        public FilterRegistrationBean<RequestLimitFilter> requestLimitFilter() {
            FilterRegistrationBean<RequestLimitFilter> registration = new FilterRegistrationBean<>();
            registration.setFilter(new RequestLimitFilter(applicationConfig));
            registration.addUrlPatterns("/*");
            registration.setName("RequestLimitFilter");
            registration.setOrder(1);
            return registration;
        }
    }
    

    展示下成果(计算规则自行调整

    请求即记录


    请求即记录

    时间段内多次请求达到限流指定的请求阈值


    达到限流指定的请求阈值

    时间段内多次请求已被限流后,继续请求达到黑名单指定的请求阈值


    达到黑名单指定的请求阈值

    注意事项

    • RedisTemplate配置的序列化问题

    如果配置的是JdkSerializationRedisSerializer,就需要改成StringRedisSerializer,如果需要两者兼容,那
    就再给spring丢一个名为jdkRedisSerializer的Bean,然后在 @Autowired时,添加@Qualifier("jdkRedisSerializer")指定注入Bean

      @Bean("jdkRedisSerializer")
      public RedisTemplate<String, Object> redisTemplate(RedisConnectionFactory redisConnectionFactory) {
          RedisTemplate<String, Object> redisTemplate = new RedisTemplate<>();
          redisTemplate.setKeySerializer(new StringRedisSerializer());
          redisTemplate.setHashKeySerializer(new StringRedisSerializer());
          redisTemplate.setValueSerializer(new JdkSerializationRedisSerializer());
          redisTemplate.setHashValueSerializer(new JdkSerializationRedisSerializer());
          redisTemplate.setConnectionFactory(redisConnectionFactory);
          return redisTemplate;
      }
    
    • lua脚本执行报错问题

    若lua脚本执行报错,redis不会回滚已经执行的命令,所以在完成lua脚本后,请在本地测试无误后再提交代码

    相关文章

      网友评论

          本文标题:springboot+redis+lua实现简单的限流、黑名单

          本文链接:https://www.haomeiwen.com/subject/fpvlmltx.html