美文网首页
基于Redis + Lua的令牌桶限流器的实现

基于Redis + Lua的令牌桶限流器的实现

作者: 桃子是水果 | 来源:发表于2022-03-10 14:51 被阅读0次

开发环境

  • jdk 11.0.10
  • SpringBoot 2.6.2
  • Idea

主要依赖

<dependency>
      <groupId>redis.clients</groupId>
       <artifactId>jedis</artifactId>
</dependency>
<dependency>
      <groupId>org.apache.commons</groupId>
      <artifactId>commons-pool2</artifactId>
</dependency>
 <dependency>
      <groupId>org.springframework.boot</groupId>
      <artifactId>spring-boot-starter-aop</artifactId>
</dependency>
 <dependency>
      <groupId>org.springframework.boot</groupId>
      <artifactId>spring-boot-devtools</artifactId>
      <scope>runtime</scope>
      <optional>true</optional>
  </dependency>
  <dependency>
      <groupId>org.springframework.boot</groupId>
      <artifactId>spring-boot-configuration-processor</artifactId>
      <optional>true</optional>
  </dependency>

核心代码

自定义注解
@Documented
@Retention(RUNTIME)
@Target(value = {ElementType.METHOD})
public @interface RateLimit {

    /**
     * 限流接口名称
     * @return 限流接口名称
     */
    String interfaceName();

    /**
     * 最大令牌数
     * @return 最大令牌数
     */
    long maxPermits();

    /**
     * 每秒生成的令牌数
     * @return
     */
    long tokensPerSeconds();
}
限流器抽象类
public abstract class RateLimiter {

    private static final Logger logger = LoggerFactory.getLogger(RateLimiter.class);

    /**
     * 是否开启限流
     */
    private boolean limited = true;

    /**
     * 开启限流功能
     */
    public void open() {
        if (!this.limited) {
            this.limited = true;
        } else {
            logger.info("the limiter has started...");
        }
    }

    /**
     * 关闭限流功能
     */
    public void close() {
        if (this.limited) {
            this.limited = false;
        } else {
            logger.info("the limiter has stopped...");
        }
    }

    /**
     * 获取令牌(指定接口限流)
     * @param interfaceName 需要限流的接口名
     * @param maxPermits 最大令牌数
     * @param tokensPerSeconds 每秒生成的令牌数
     * @return boolean 是否通过限流(获取到令牌)
     */
    protected abstract boolean acquire(String interfaceName, long maxPermits, long tokensPerSeconds);

    /**
     * 获取令牌(指定接口)
     * @param interfaceName 需要限流的接口名
     * @return boolean 是否通过限流(获取到令牌)
     */
    public boolean tryAcquire(String interfaceName, long maxPermits, long tokensPerSeconds) {
        if (this.limited) {
            return this.acquire(interfaceName, maxPermits, tokensPerSeconds);
        } else {
            return true;
        }
    }
}
令牌桶实现类
public class TokenBucketRateLimiter extends RateLimiter {

    private static final Logger logger = LoggerFactory.getLogger(TokenBucketRateLimiter.class);

    /**
     * redis的lua脚本
     */
    private DefaultRedisScript<Boolean> script;

    /**
     * redisTemplate
     */
    private RedisTemplate<String, Object> redisTemplate;

    public TokenBucketRateLimiter(DefaultRedisScript<Boolean> script, RedisTemplate<String, Object> redisTemplate) {
        this.script = script;
        this.redisTemplate = redisTemplate;
    }

    /**
     * 限流检测(单个接口)
     * @param interfaceName 需要限流的接口名
     * @param maxPermits 最大令牌数
     * @param tokensPerSeconds 每秒生成的令牌数
     * @return 是否通过限流 true: 通过
     */
    @Override
    protected boolean acquire(String interfaceName, long maxPermits, long tokensPerSeconds) {
        // 错误的参数将不起作用
        if (maxPermits <= 0 || tokensPerSeconds <= 0) {
            logger.warn("maxPermits and tokensPerSeconds can not be less than zero...");
            return true;
        }

        // 参数结构: KEYS = [限流的key]   ARGV = [最大令牌数, 每秒生成的令牌数, 本次请求的毫秒数]
        Boolean result = this.redisTemplate.execute(this.script, Collections.singletonList(interfaceName), maxPermits, tokensPerSeconds, System.currentTimeMillis());
        return result!=null && result;
    }
}
具体实现令牌桶的Lua脚本
-- LUA脚本会以单线程执行,不会有并发问题,一个脚本中的执行过程中如果报错,那么已执行的操作不会回滚
-- KEYS和ARGV是外部传入进来需要操作的redis数据库中的key,下标从1开始
-- 参数结构: KEYS = [限流的key]   ARGV = [最大令牌数, 每秒生成的令牌数, 本次请求的毫秒数]
local info = redis.pcall('HMGET', KEYS[1], 'last_time', 'stored_token_nums')
local last_time = info[1] --最后一次通过限流的时间
local stored_token_nums = tonumber(info[2]) -- 剩余的令牌数量
local max_token = tonumber(ARGV[1])
local token_rate = tonumber(ARGV[2])
local current_time = tonumber(ARGV[3])
local past_time = 0
local rateOfperMills = token_rate/1000 -- 每毫秒生产令牌速率

if stored_token_nums == nil then
    -- 第一次请求或者键已经过期
    stored_token_nums = max_token --令牌恢复至最大数量
    last_time = current_time --记录请求时间
else
    -- 处于流量中
    past_time = current_time - last_time --经过了多少时间

    if past_time <= 0 then
        --高并发下每个服务的时间可能不一致
        past_time = 0 -- 强制变成0 此处可能会出现少量误差
    end
    -- 两次请求期间内应该生成多少个token
    local generated_nums = math.floor(past_time * rateOfperMills)  -- 向下取整,多余的认为还没生成完
    stored_token_nums = math.min((stored_token_nums + generated_nums), max_token) -- 合并所有的令牌后不能超过设定的最大令牌数
end

local returnVal = 0 -- 返回值

if stored_token_nums > 0 then
    returnVal = 1 -- 通过限流
    stored_token_nums = stored_token_nums - 1 -- 减少令牌
    -- 必须要在获得令牌后才能重新记录时间。举例: 当每隔2ms请求一次时,只要第一次没有获取到token,那么后续会无法生产token,永远只过去了2ms
    last_time = last_time + past_time
end

-- 更新缓存
redis.call('HMSET', KEYS[1], 'last_time', last_time, 'stored_token_nums', stored_token_nums)
-- 设置超时时间
-- 令牌桶满额的时间(超时时间)(ms) = 空缺的令牌数 * 生成一枚令牌所需要的毫秒数(1 / 每毫秒生产令牌速率)
redis.call('PEXPIRE', KEYS[1], math.ceil((1/rateOfperMills) * (max_token - stored_token_nums)))

return returnVal
切面类
@Aspect
public class RateLimitAspect {

    private static final Logger logger = LoggerFactory.getLogger(RateLimitAspect.class);

    private RateLimiter rateLimiter;

    public RateLimitAspect(RateLimiter rateLimiter) {
        this.rateLimiter = rateLimiter;
    }

    /**
     * 标注切点-所有标识了RateLimit注解的方法
     */
    @Pointcut("@annotation(cn.t.redis.limiter.annotations.RateLimit)")
    public void pointCut(){};

    @Before("pointCut()")
    public void before(JoinPoint joinPoint) {
        Method method = ((MethodSignature)joinPoint.getSignature()).getMethod();
        RateLimit a = method.getAnnotation(RateLimit.class);
        if (a != null) {
            String name = a.interfaceName();
            long maxPermits = a.maxPermits();
            long tokensPerSeconds = a.tokensPerSeconds();
            // 执行限流判断
            var ret = this.rateLimiter.tryAcquire(name, maxPermits, tokensPerSeconds);
            if (!ret) {
                throw new RateLimitException("the interface can not be accessed in the meantime...");
            }
        }
    }
}
自定义异常
public class RateLimitException extends RuntimeException {

    public RateLimitException() {}

    public RateLimitException(String message) {
        super(message);
    }
}
自动配置类
@Configuration
@AutoConfigureBefore(RedisAutoConfiguration.class) // 高优先级,先于自动默认的自动配置生成RedisTemplate
public class LimiterAutoConfiguration {

    @Autowired
    private RedisConnectionFactory connectionFactory;

    /**
     * 配置redisTemplate
     * @return redisTemplate
     */
    @Bean
    @ConditionalOnMissingBean(RedisTemplate.class)
    public RedisTemplate<String, Object> redisTemplate() {

        RedisTemplate<String, Object> redisTemplate = new RedisTemplate<>();
        redisTemplate.setConnectionFactory(this.connectionFactory);

        // 定义Jackson2JsonRedisSerializer序列化对象
        Jackson2JsonRedisSerializer<Object> jackson2JsonRedisSerializer = new Jackson2JsonRedisSerializer<>(Object.class);

        ObjectMapper objectMapper = new ObjectMapper();
        // 指定要序列化的域,ALL:field,get和set等,ANY: 可见性,会将有private修饰符的字段也序列化
        objectMapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.ANY);
        objectMapper.activateDefaultTyping(LaissezFaireSubTypeValidator.instance,ObjectMapper.DefaultTyping.NON_FINAL);
        jackson2JsonRedisSerializer.setObjectMapper(objectMapper);

        // 使用StringRedisSerializer来序列化和反序列化redis的key值
        redisTemplate.setKeySerializer(new StringRedisSerializer());
        redisTemplate.setHashKeySerializer(new StringRedisSerializer());
        // 使用jackson2JsonRedisSerializer序列化和反序列化value
        redisTemplate.setValueSerializer(jackson2JsonRedisSerializer);
        redisTemplate.setHashValueSerializer(jackson2JsonRedisSerializer);
        // 属性设置完成afterPropertiesSet就会被调用,可以对设置不成功的做一些默认处理
        redisTemplate.afterPropertiesSet();
        return redisTemplate;
    }

    /**
     * redis的lua脚本对象
     * @return lua脚本对象
     */
    @Bean
    public DefaultRedisScript<Boolean> redisScript() {
        DefaultRedisScript<Boolean> redisScript = new DefaultRedisScript<>();
        redisScript.setScriptSource(new ResourceScriptSource(new ClassPathResource("RateLimiter.lua")));
        redisScript.setResultType(Boolean.class);
        return redisScript;
    }

    /**
     * 默认限流器的实现-令牌桶
     * @return 默认限流器
     */
    @Bean
    @ConditionalOnMissingBean(RateLimiter.class)
    public RateLimiter rateLimiter() {
        return new TokenBucketRateLimiter(this.redisScript(), this.redisTemplate());
    }

    /**
     * 限流切面
     * @param rateLimiter
     * @return
     */
    @Bean
    @ConditionalOnBean(RateLimiter.class)
    public RateLimitAspect rateLimitAspect(RateLimiter rateLimiter) {
        return new RateLimitAspect(rateLimiter);
    }
}
自动配置类指示文件(src/main/resources/META-INF/spring.factories)
org.springframework.boot.autoconfigure.EnableAutoConfiguration=\
cn.t.redis.limiter.configuration.LimiterAutoConfiguration

打包后在需要使用限流功能的模块中引入即可

使用方法

  1. 引入本jar包

    <dependency>
        <groupId>cn.t.redis.limiter</groupId>
        <artifactId>limiter-spring-boot-starter</artifactId>
        <version>1.0.0</version>
     </dependency>
    
  2. 配置Redis连接信息

spring:
  redis:
    #host: localhost  # 单点连接ip
    #port: 18379 # # 单点连接端口
    timeout: 6000 # 连接超时时间
    password: your password
    client-type: lettuce #指定连接工厂类型
    cluster:
      max-redirects: 3  # 获取失败 最大重定向次数
      nodes: # 集群节点
        - 127.0.0.1:7001
        - 127.0.0.1:7002
        - 127.0.0.1:7003
        - 127.0.0.1:7004
        - 127.0.0.1:7005
        - 127.0.0.1:7006
    lettuce: # lettuce连接池
      pool:
        max-active: 100  # 连接池最大连接数(使用负值表示没有限制)
        max-idle: 20 # 最大空闲连接数
        min-idle: 10  # 最小空闲连接数
        max-wait: 1500 # 连接池最大阻塞等待时间(ms)(使用负值表示没有限制)
  1. 在需要限流的接口处使用注解
@RequestMapping("/index")
@RateLimit(interfaceName = "limit", maxPermits = 5, tokensPerSeconds = 1)
public String ratelimit() {
    return "hello world";
}
  1. 未通过限流的访问会抛出异常,建议在全局异常处理器中捕获处理。

例如:

@RestControllerAdvice
public class GlobalErrorController {
    @ExceptionHandler(RateLimitException.class)
    public String ratelimiteHanler(RateLimitException e) {
        return e.getMessage();
    }
}

代码地址: 基于Redis + Lua的令牌桶限流器的实现

相关文章

网友评论

      本文标题:基于Redis + Lua的令牌桶限流器的实现

      本文链接:https://www.haomeiwen.com/subject/yuifdrtx.html