美文网首页
用 go 实现一个分布式限流器

用 go 实现一个分布式限流器

作者: CocoAdapter | 来源:发表于2019-09-25 15:59 被阅读0次

    项目中需要对 api 的接口进行限流,但是麻烦的是,api 可能有多个节点,传统的本地限流无法处理这个问题。限流的算法有很多,比如计数器法,漏斗法,令牌桶法,等等。各有利弊,相关博文网上很多,这里不再赘述。

    项目的要求主要有以下几点:

    1. 支持本地/分布式限流,接口统一
    2. 支持多种限流算法的切换
    3. 方便配置,配置方式不确定

    go 语言不是很支持 OOP,我在实现的时候是按 Java 的思路走的,所以看起来有点不伦不类,希望能抛砖引玉。

    1. 接口定义

    package ratelimit
    
    import "time"
    
    // 限流器接口
    type Limiter interface {
        Acquire() error
        TryAcquire() bool
    }
    
    // 限流定义接口
    type Limit interface {
        Name() string
        Key() string
        Period() time.Duration
        Count() int32
        LimitType() LimitType
    }
    
    // 支持 burst
    type BurstLimit interface {
        Limit
        BurstCount() int32
    }
    
    // 分布式定义的 burst
    type DistLimit interface {
        Limit
        ClusterNum() int32
    }
    
    type LimitType int32
    const (
        CUSTOM LimitType = iota
        IP
    )
    

    Limiter 接口参考了 Google 的 guava 包里的 Limiter 实现。Acquire 接口是阻塞接口,其实还需要加上 context 来保证调用链安全,因为实际项目中并没有用到 Acquire 接口,所以没有实现完善;同理,超时时间的支持也可以通过添加新接口继承自 Limiter 接口来实现。TryAcquire 会立即返回。

    Limit 抽象了一个限流定义,Key() 方法返回这个 Limit 的唯一标识,Name() 仅作辅助,Period() 表示周期,单位是秒,Count() 表示周期内的最大次数,LimitType()表示根据什么来做区分,如 IP,默认是 CUSTOM.
    BurstLimit 提供突发的能力,一般是配合令牌桶算法。DistLimit 新增 ClusterNum() 方法,因为 mentor 要求分布式遇到错误的时候,需要退化为单机版本,退化的策略即是:2 节点总共 100QPS,如果出现分区,每个节点需要调整为各 50QPS

    2. LocalCounterLimiter

    package ratelimit
    
    import (
        "errors"
        "fmt"
        "math"
        "sync"
        "sync/atomic"
        "time"
    )
    
    // todo timer 需要 stop
    type localCounterLimiter struct {
        limit Limit
    
        limitCount int32 // 内部使用,对 limit.count 做了 <0 时的转换
    
        ticker *time.Ticker
        quit chan bool
    
        lock sync.Mutex
        newTerm *sync.Cond
        count int32
    }
    
    func (lim *localCounterLimiter) init() {
        lim.newTerm = sync.NewCond(&lim.lock)
        lim.limitCount = lim.limit.Count()
    
        if lim.limitCount < 0 {
            lim.limitCount = math.MaxInt32 // count 永远不会大于 limitCount,后面的写法保证溢出也没问题
        } else if lim.limitCount == 0  {
            // 禁止访问, 会无限阻塞
        } else {
            lim.ticker = time.NewTicker(lim.limit.Period())
            lim.quit = make(chan bool, 1)
    
            go func() {
                for {
                    select {
                    case <- lim.ticker.C:
                        fmt.Println("ticker .")
                        atomic.StoreInt32(&lim.count, 0)
                        lim.newTerm.Broadcast()
    
                        //lim.newTerm.L.Unlock()
                    case <- lim.quit:
                        fmt.Println("work well .")
                        lim.ticker.Stop()
                        return
                    }
                }
            }()
        }
    }
    
    // todo 需要机制来防止无限阻塞, 不超时也应该有个极限时间
    func (lim *localCounterLimiter) Acquire() error {
        if lim.limitCount == 0 {
            return errors.New("rate limit is 0, infinity wait")
        }
    
        lim.newTerm.L.Lock()
        for lim.count >= lim.limitCount {
            // block instead of spinning
            lim.newTerm.Wait()
            //fmt.Println(count, lim.limitCount)
        }
        lim.count++
        lim.newTerm.L.Unlock()
    
        return nil
    }
    
    func (lim *localCounterLimiter) TryAcquire() bool {
        count := atomic.AddInt32(&lim.count, 1)
        if count > lim.limitCount {
            return false
        } else {
            return true
        }
    }
    

    代码很简单,就不多说了

    3. LocalTokenBucketLimiter

    golang 的官方库里提供了一个 ratelimiter,就是采用令牌桶的算法。所以这里并没有重复造轮子,直接代理了 ratelimiter。

    package ratelimit
    
    import (
        "context"
        "golang.org/x/time/rate"
        "math"
    )
    
    type localTokenBucketLimiter struct {
        limit Limit
    
        limiter *rate.Limiter // 直接复用令牌桶的
    }
    
    func (lim *localTokenBucketLimiter) init() {
        burstCount := lim.limit.Count()
        if burstLimit, ok := lim.limit.(BurstLimit); ok {
            burstCount = burstLimit.BurstCount()
        }
    
        count := lim.limit.Count()
        if count < 0 {
            count = math.MaxInt32
        }
    
        f := float64(count) / lim.limit.Period().Seconds()
        if f < 0 {
            f = float64(rate.Inf) // 无限
        } else if f == 0 {
            panic("为 0 的时候,底层实现有问题")
        }
    
        lim.limiter = rate.NewLimiter(rate.Limit(f), int(burstCount))
    }
    
    func (lim *localTokenBucketLimiter) Acquire() error {
        err := lim.limiter.Wait(context.TODO())
        return err
    }
    
    func (lim *localTokenBucketLimiter) TryAcquire() bool {
        return lim.limiter.Allow()
    }
    

    4. RedisCounterLimiter

    package ratelimit
    
    import (
        "math"
        "sync"
        "xg-go/log"
        "xg-go/xg/common"
    )
    
    type redisCounterLimiter struct {
        limit      DistLimit
        limitCount int32 // 内部使用,对 limit.count 做了 <0 时的转换
    
        redisClient *common.RedisClient
    
        once sync.Once // 退化为本地计数器的时候使用
        localLim Limiter
    
        //script string
    }
    
    func (lim *redisCounterLimiter) init() {
        lim.limitCount = lim.limit.Count()
        if lim.limitCount < 0 {
            lim.limitCount = math.MaxInt32
        }
    
        //lim.script = buildScript()
    }
    
    //func buildScript() string {
    //  sb := strings.Builder{}
    //
    //  sb.WriteString("local c")
    //  sb.WriteString("\nc = redis.call('get',KEYS[1])")
    //  // 调用不超过最大值,则直接返回
    //  sb.WriteString("\nif c and tonumber(c) > tonumber(ARGV[1]) then")
    //  sb.WriteString("\nreturn c;")
    //  sb.WriteString("\nend")
    //  // 执行计算器自加
    //  sb.WriteString("\nc = redis.call('incr',KEYS[1])")
    //  sb.WriteString("\nif tonumber(c) == 1 then")
    //  sb.WriteString("\nredis.call('expire',KEYS[1],ARGV[2])")
    //  sb.WriteString("\nend")
    //  sb.WriteString("\nif tonumber(c) == 1 then")
    //  sb.WriteString("\nreturn c;")
    //
    //  return sb.String()
    //}
    
    func (lim *redisCounterLimiter) Acquire() error {
        panic("implement me")
    }
    
    func (lim *redisCounterLimiter) TryAcquire() (success bool) {
        defer func() {
            // 一般是 redis 连接断了,会触发空指针
            if err := recover(); err != nil {
                //log.Errorw("TryAcquire err", common.ERR, err)
                //success = lim.degradeTryAcquire()
                //return
                success = true
            }
    
            // 没有错误,判断是否开启了 local 如果开启了,把它停掉
            //if lim.localLim != nil {
            //  // stop 线程安全
            //  lim.localLim.Stop()
            //}
        }()
    
        count, err := lim.redisClient.IncrBy(lim.limit.Key(), 1)
        //panic("模拟 redis 出错")
        if err != nil {
            log.Errorw("TryAcquire err", common.ERR, err)
            panic(err)
        }
    
        // *2 是为了保留久一点,便于观察
        err = lim.redisClient.Expire(lim.limit.Key(), int(2 * lim.limit.Period().Seconds()))
        if err != nil {
            log.Errorw("TryAcquire error", common.ERR, err)
            panic(err)
        }
    
        // 业务正确的情况下 确认超限
        if int32(count) > lim.limitCount {
            return false
        }
    
        return true
    
        //keys := []string{lim.limit.Key()}
        //
        //log.Errorw("TryAcquire ", keys, lim.limit.Count(), lim.limit.Period().Seconds())
        //count, err := lim.redisClient.Eval(lim.script, keys, lim.limit.Count(), lim.limit.Period().Seconds())
        //if err != nil {
        //  log.Errorw("TryAcquire error", common.ERR, err)
        //  return false
        //}
        //
        //
        //typeName := reflect.TypeOf(count).Name()
        //log.Errorw(typeName)
        //
        //if count != nil && count.(int32) <= lim.limitCount {
        //
        //  return true
        //}
        //return false
    }
    
    func (lim *redisCounterLimiter) Stop() {
        // 判断是否开启了 local 如果开启了,把它停掉
        if lim.localLim != nil {
            // stop 线程安全
            lim.localLim.Stop()
        }
    }
    
    func (lim *redisCounterLimiter) degradeTryAcquire() bool {
        lim.once.Do(func() {
            count := lim.limit.Count() / lim.limit.ClusterNum()
            limit := LocalLimit {
                name: lim.limit.Name(),
                key: lim.limit.Key(),
                count: count,
                period: lim.limit.Period(),
                limitType: lim.limit.LimitType(),
            }
    
            lim.localLim = NewLimiter(&limit)
        })
    
        return lim.localLim.TryAcquire()
    }
    

    代码里回退的部分注释了,因为线上为了稳定,实习生的代码毕竟,所以先不跑。
    本来原有的思路是直接用 lua 脚本在 redis 上保证原子操作,但是底层封装的库对于直接调 eval 跑的时候,会抛错,而且 source 是 go-redis 里面,赶 ddl 没有时间去 debug,所以只能用 incrBy + expire 分开来。

    5. RedisTokenBucketLimiter

    令牌桶的状态变量得放在一个 线程安全/一致 的地方,redis 是不二人选。但是令牌桶的算法核心是个延迟计算得到令牌数量,这个是一个很长的临界区,所以要么用分布式锁,要么直接利用 redis 的单线程以原子方式跑。一般业界是后者,即 lua 脚本维护令牌桶的状态变量、计算令牌。代码类似这种

    local tokens_key = KEYS[1]
    local timestamp_key = KEYS[2]
    --redis.log(redis.LOG_WARNING, "tokens_key " .. tokens_key)
    
    local rate = tonumber(ARGV[1])
    local capacity = tonumber(ARGV[2])
    local now = tonumber(ARGV[3])
    local requested = tonumber(ARGV[4])
    local intval = tonumber(ARGV[5])
    
    local fill_time = capacity/rate
    local ttl = math.floor(fill_time*2) * intval
    
    local last_tokens = tonumber(redis.call("get", tokens_key))
    if last_tokens == nil then
      last_tokens = capacity
    end
    
    local last_refreshed = tonumber(redis.call("get", timestamp_key))
    if last_refreshed == nil then
      last_refreshed = 0
    end
    
    local delta = math.max(0, now-last_refreshed)
    local filled_tokens = math.min(capacity, last_tokens+(delta*rate))
    local allowed = filled_tokens >= requested
    local new_tokens = filled_tokens
    if allowed then
      new_tokens = filled_tokens - requested
    end
    
    redis.call("setex", tokens_key, ttl, new_tokens)
    redis.call("setex", timestamp_key, ttl, now)
    
    return { allowed, new_tokens }
    

    相关文章

      网友评论

          本文标题:用 go 实现一个分布式限流器

          本文链接:https://www.haomeiwen.com/subject/iiziuctx.html