美文网首页
Golang自定义基于gin框架的Session中间件

Golang自定义基于gin框架的Session中间件

作者: FredricZhu | 来源:发表于2020-11-05 13:34 被阅读0次

    工程结构如下：


    image.png

    原理是利用 cookie 保存 sessionID，服务端再根据 sessionID 取回该用户对应的 Session（Session 数据存放在内存或 Redis 中）。
    main.go 测试代码如下：

    package main
    
    import (
        "fmt"
        "log"
        "net/http"
    
        "github.com/gin-gonic/gin"
        "github.com/zhuge20100104/gin_session/gsession"
    )
    
    // main starts a demo gin server whose /incr endpoint keeps a per-user
    // counter in a Redis-backed session.
    func main() {
        r := gin.Default()
        mgrObj, err := gsession.CreateSessionMgr(gsession.Redis, "localhost:6379")
        if err != nil {
            // log.Fatalf exits the process, so nothing after it would run.
            log.Fatalf("Create manager obj failed, err: %v\n", err)
        }
        sm := gsession.SessionMiddleware(mgrObj, gsession.Options{
            Path:     "/",
            Domain:   "127.0.0.1",
            MaxAge:   120,
            Secure:   false,
            HttpOnly: true,
        })
        r.Use(sm)
        r.GET("/incr", func(c *gin.Context) {
            // The middleware stores the session under SessionContextName.
            session := c.MustGet(gsession.SessionContextName).(gsession.Session)
            fmt.Printf("%#v\n", session)
            count := 0
            if v, err := session.Get("count"); err != nil {
                log.Printf("get count from session failed, err: %v\n", err)
            } else if prev, ok := v.(int); ok {
                count = prev + 1
            } else {
                // A value of another type must not panic the handler; start over.
                log.Printf("unexpected type %T for count in session, resetting\n", v)
            }
            session.Set("count", count)
            session.Save()
            c.String(http.StatusOK, "count:%v", count)
        })
        if err := r.Run(); err != nil {
            log.Fatalf("run server failed, err: %v\n", err)
        }
    }
    

    session.go

    package gsession
    
    import (
        "fmt"
        "log"
    
        "github.com/gin-gonic/gin"
    )
    
    // SessionMgrType names a session storage backend accepted by CreateSessionMgr.
    type SessionMgrType string

    // Backend kinds understood by CreateSessionMgr.
    const (
        // Memory selects the in-process map backend.
        Memory SessionMgrType = "memory"
        // Redis selects the Redis-backed store.
        Redis SessionMgrType = "redis"
    )

    // Names used to carry the session through a request.
    const (
        // SessionCookieName is the cookie that holds the session ID.
        SessionCookieName = "session_id"
        // SessionContextName is the gin.Context key under which the Session is stored.
        SessionContextName = "session"
    )
    
    // Session is a per-user key/value store handed to request handlers.
    // Backends that keep data outside the process use Load and Save to move it
    // in and out; the in-memory backend implements both as no-ops.
    type Session interface {
        // ID returns the session identifier carried in the cookie.
        ID() string
        // SetExpired sets the lifetime, in seconds, applied when persisting
        // (ignored by the in-memory backend).
        SetExpired(seconds int)
        // Load pulls persisted data (e.g. from Redis) into the session.
        Load() error
        // Save persists the session's data (e.g. to Redis).
        Save()
        // Get returns the value stored under key, or an error if it is absent.
        Get(key string) (interface{}, error)
        // Set stores value under key.
        Set(key string, value interface{})
        // Del removes key.
        Del(key string)
    }
    
    // SessionMgr creates, looks up and discards Session objects for one backend.
    type SessionMgr interface {
        // Init prepares the backend. addr and options are backend specific
        // (the Redis backend takes a server address, then an optional password
        // and an optional DB number).
        Init(addr string, options ...string) error
        // CreateSession makes a new Session with a fresh ID.
        CreateSession() Session
        // GetSession returns the existing Session identified by sessionID.
        GetSession(sessionID string) (Session, error)
        // Clear removes the Session identified by sessionID from the manager.
        Clear(sessionID string)
    }
    
    // Options configures the session cookie written by SessionMiddleware.
    type Options struct {
        // Path and Domain scope the cookie.
        Path, Domain string
        // MaxAge is the lifetime, in seconds, of the session ID cookie:
        //   MaxAge == 0: no 'Max-Age' attribute is specified;
        //   MaxAge <  0: the cookie is deleted now, equivalent to 'Max-Age: 0';
        //   MaxAge >  0: 'Max-Age' is present and given in seconds.
        MaxAge int
        // Secure and HttpOnly set the cookie flags of the same name.
        Secure, HttpOnly bool
    }
    
    // CreateSessionMgr builds and initializes the session manager for backend
    // name (Memory or Redis). addr and options are forwarded to SessionMgr.Init.
    //
    // It returns an error for an unknown backend or when Init fails; in both
    // cases the returned manager is nil so callers never use a half-built one.
    func CreateSessionMgr(name SessionMgrType, addr string, options ...string) (sm SessionMgr, err error) {
        switch name {
        case Memory:
            sm = NewMemSessionMgr()
        case Redis:
            sm = NewRedisSessionMgr()
        default:
            // Error strings carry no trailing newline; callers add their own formatting.
            return nil, fmt.Errorf("unsupported session manager type %q", name)
        }
        if err = sm.Init(addr, options...); err != nil {
            return nil, fmt.Errorf("init %s session manager: %w", name, err)
        }
        return sm, nil
    }
    
    // SessionMiddleware returns a gin middleware that attaches a Session to each
    // request under SessionContextName.
    //
    // The session ID is read from the SessionCookieName cookie. When the cookie
    // is missing, or sm.GetSession rejects the ID, a brand-new session is
    // created instead. The session's expiry and the re-issued cookie both use
    // options.MaxAge. Once the remaining handlers have run, sm.Clear is called
    // with the session ID.
    func SessionMiddleware(sm SessionMgr, options Options) gin.HandlerFunc {
        // fresh creates a new session and reports its ID.
        fresh := func() (Session, string) {
            created := sm.CreateSession()
            return created, created.ID()
        }

        // resolve picks the session for this request and the ID to send back in the cookie.
        resolve := func(c *gin.Context) (Session, string) {
            cookieID, cookieErr := c.Cookie(SessionCookieName)
            if cookieErr != nil {
                log.Printf("get session_id from cookie failed, err:%v\n", cookieErr)
                return fresh()
            }
            log.Printf("SessionId: %v\n", cookieID)
            found, lookupErr := sm.GetSession(cookieID)
            if lookupErr != nil {
                log.Printf("Get session by %s failed, err: %v\n", cookieID, lookupErr)
                return fresh()
            }
            return found, cookieID
        }

        return func(c *gin.Context) {
            session, sessionID := resolve(c)
            session.SetExpired(options.MaxAge)
            c.Set(SessionContextName, session)
            c.SetCookie(SessionCookieName, sessionID, options.MaxAge, options.Path, options.Domain, options.Secure, options.HttpOnly)
            // NOTE(review): with MemSessionMgr, Clear deletes the session from the
            // only place it is stored, so memory-backed sessions do not survive
            // past a single request. With the Redis manager it only drops the
            // manager's local map entry; the data stays in Redis.
            defer sm.Clear(sessionID)
            c.Next()
        }
    }
    
    

    memory.go

    package gsession
    
    import (
        "fmt"
        "sync"
    
        uuid "github.com/satori/go.uuid"
    )
    
    // memSession is a Session whose data lives only in process memory.
    type memSession struct {
        // id is the unique session identifier.
        id string
        // data holds the session's key/value pairs.
        data map[string]interface{}
        // expired records the value passed to SetExpired; nothing in the
        // memory backend acts on it.
        expired int
        // rwLock guards data and expired so the session can be shared across goroutines.
        rwLock sync.RWMutex
    }

    // NewMemSession returns an empty in-memory session with the given ID.
    func NewMemSession(id string) *memSession {
        return &memSession{
            id:   id,
            data: make(map[string]interface{}, 8),
        }
    }

    // ID returns the session identifier.
    func (m *memSession) ID() string {
        return m.id
    }

    // Load is a no-op: in-memory data needs no loading. It always returns nil.
    func (m *memSession) Load() error {
        return nil
    }

    // Get returns the value stored under key, or an "Invalid key" error if the
    // key is absent.
    func (m *memSession) Get(key string) (interface{}, error) {
        m.rwLock.RLock()
        defer m.rwLock.RUnlock()
        if value, ok := m.data[key]; ok {
            return value, nil
        }
        return nil, fmt.Errorf("Invalid key")
    }

    // Set stores value under key.
    func (m *memSession) Set(key string, value interface{}) {
        m.rwLock.Lock()
        defer m.rwLock.Unlock()
        m.data[key] = value
    }

    // Del removes key; removing an absent key is a no-op.
    func (m *memSession) Del(key string) {
        m.rwLock.Lock()
        defer m.rwLock.Unlock()
        delete(m.data, key)
    }

    // Save is a no-op: there is nowhere else to persist in-memory data.
    func (m *memSession) Save() {}

    // SetExpired records the expiry value. It takes the same lock as the other
    // mutators so concurrent callers sharing this session do not race.
    func (m *memSession) SetExpired(expired int) {
        m.rwLock.Lock()
        defer m.rwLock.Unlock()
        m.expired = expired
    }
    
    // MemSessionMgr keeps every session in an in-process map keyed by session ID.
    type MemSessionMgr struct {
        // session maps session IDs to sessions; guarded by rwLock because
        // request handlers run concurrently.
        session map[string]Session
        rwLock  sync.RWMutex
    }

    // NewMemSessionMgr returns an empty in-memory session manager.
    func NewMemSessionMgr() *MemSessionMgr {
        return &MemSessionMgr{
            session: make(map[string]Session, 1024),
        }
    }

    // Init is a no-op for the memory backend; addr and options are ignored.
    func (m *MemSessionMgr) Init(addr string, options ...string) (err error) {
        return nil
    }

    // GetSession returns the session registered under sessionID, or a nil
    // Session and an "Invalid session id" error if there is none.
    func (m *MemSessionMgr) GetSession(sessionID string) (sd Session, err error) {
        m.rwLock.RLock()
        defer m.rwLock.RUnlock()
        sd, ok := m.session[sessionID]
        if !ok {
            return nil, fmt.Errorf("Invalid session id")
        }
        return sd, nil
    }

    // CreateSession creates a session with a fresh UUID and registers it.
    func (m *MemSessionMgr) CreateSession() (sd Session) {
        sd = NewMemSession(uuid.NewV4().String())
        // The map is shared by all requests; an unguarded write here races with
        // GetSession/Clear and can abort the process with "concurrent map writes".
        m.rwLock.Lock()
        defer m.rwLock.Unlock()
        m.session[sd.ID()] = sd
        return sd
    }

    // Clear removes the session registered under sessionID, if any.
    func (m *MemSessionMgr) Clear(sessionID string) {
        m.rwLock.Lock()
        defer m.rwLock.Unlock()
        delete(m.session, sessionID)
    }
    
    

    redis.go

    package gsession
    
    import (
        "bytes"
        "encoding/gob"
        "fmt"
        "log"
        "strconv"
        "sync"
        "time"
    
        "github.com/go-redis/redis"
        uuid "github.com/satori/go.uuid"
    )
    
    // redisSession is a Session whose data is gob-encoded and stored in Redis
    // under the session ID. Changes are kept in data until Save writes them.
    type redisSession struct {
        // id is the session identifier, also used as the Redis key.
        id string
        // data is the local copy of the session's key/value pairs.
        data map[string]interface{}
        // modifyFlag reports whether data changed since the last successful Save.
        modifyFlag bool
        // expired is the TTL, in seconds, passed to Redis on Save.
        expired int
        // rwLock guards data, modifyFlag and expired.
        rwLock sync.RWMutex
        // client is the Redis connection, supplied by the caller.
        client *redis.Client
    }

    // NewRedisSession returns an empty session with the given ID that talks to
    // Redis through client.
    func NewRedisSession(id string, client *redis.Client) (session Session) {
        session = &redisSession{
            id:     id,
            data:   make(map[string]interface{}, 8),
            client: client,
        }
        return
    }

    // ID returns the session identifier.
    func (r *redisSession) ID() string {
        return r.id
    }

    // Load fetches the value stored in Redis under the session ID and
    // gob-decodes it into data. It returns the Redis or decode error, if any.
    func (r *redisSession) Load() error {
        raw, err := r.client.Get(r.id).Bytes()
        if err != nil {
            log.Printf("get session data from redis by %s failed, err: %v\n", r.id, err)
            return err
        }

        r.rwLock.Lock()
        defer r.rwLock.Unlock()
        if err := gob.NewDecoder(bytes.NewBuffer(raw)).Decode(&r.data); err != nil {
            log.Printf("gob decode session data failed, err: %v\n", err)
            return err
        }
        return nil
    }

    // Get returns the value stored under key, or an "invalid key" error if the
    // key is absent.
    func (r *redisSession) Get(key string) (value interface{}, err error) {
        r.rwLock.RLock()
        defer r.rwLock.RUnlock()
        value, ok := r.data[key]
        if !ok {
            return nil, fmt.Errorf("invalid key")
        }
        return value, nil
    }

    // Set stores value under key and marks the session dirty.
    func (r *redisSession) Set(key string, value interface{}) {
        r.rwLock.Lock()
        defer r.rwLock.Unlock()
        r.data[key] = value
        r.modifyFlag = true
    }

    // Del removes key and marks the session dirty.
    func (r *redisSession) Del(key string) {
        r.rwLock.Lock()
        defer r.rwLock.Unlock()
        delete(r.data, key)
        r.modifyFlag = true
    }

    // SetExpired sets the TTL, in seconds, used by the next Save. It takes the
    // lock because Save reads expired under the same lock.
    func (r *redisSession) SetExpired(expired int) {
        r.rwLock.Lock()
        defer r.rwLock.Unlock()
        r.expired = expired
    }

    // Save writes data to Redis if it changed since the last successful Save.
    // Failures are logged and leave the session marked dirty, so a later Save
    // can retry; they never terminate the process.
    func (r *redisSession) Save() {
        r.rwLock.Lock()
        defer r.rwLock.Unlock()
        if !r.modifyFlag {
            return
        }
        var buf bytes.Buffer
        if err := gob.NewEncoder(&buf).Encode(r.data); err != nil {
            // E.g. a value whose concrete type was never gob.Register-ed. A bad
            // value in one session must not bring down the whole server.
            log.Printf("gob encode r.data failed, err: %v\n", err)
            return
        }

        if err := r.client.Set(r.id, buf.Bytes(), time.Second*time.Duration(r.expired)).Err(); err != nil {
            log.Printf("set session data to redis by %s failed, err: %v\n", r.id, err)
            return
        }
        log.Printf("set data %v to redis.\n", buf.Bytes())
        r.modifyFlag = false
    }
    
    // redisSessionMgr hands out redisSession objects that share one Redis client.
    type redisSessionMgr struct {
        // session tracks sessions handed out and not yet cleared; guarded by rwLock.
        session map[string]Session
        rwLock  sync.RWMutex
        // client is the Redis connection created by Init.
        client *redis.Client
    }

    // NewRedisSessionMgr returns a manager with no Redis connection yet; call
    // Init before use.
    func NewRedisSessionMgr() *redisSessionMgr {
        return &redisSessionMgr{
            session: make(map[string]Session, 1024),
        }
    }

    // Init connects to the Redis server at addr and checks it with PING.
    // options[0], if given, is the password; options[1], if given, is the
    // numeric database index. Any further options are ignored. A malformed DB
    // index or a failed PING is returned as an error.
    func (r *redisSessionMgr) Init(addr string, options ...string) (err error) {
        var (
            password string
            db       int
        )
        if len(options) >= 1 {
            password = options[0]
        }
        if len(options) >= 2 {
            if db, err = strconv.Atoi(options[1]); err != nil {
                // Report a bad DB index to the caller instead of exiting the process.
                return fmt.Errorf("invalid redis DB param %q: %w", options[1], err)
            }
        }

        r.client = redis.NewClient(&redis.Options{
            Addr:     addr,
            Password: password,
            DB:       db,
        })

        if _, err = r.client.Ping().Result(); err != nil {
            return err
        }
        return nil
    }

    // GetSession loads the session stored in Redis under sessionID and
    // registers it with the manager. On failure it returns a nil Session and
    // the load error.
    func (r *redisSessionMgr) GetSession(sessionID string) (sd Session, err error) {
        sd = NewRedisSession(sessionID, r.client)
        if err = sd.Load(); err != nil {
            return nil, err
        }

        // Writing the shared map needs the exclusive lock; RLock allows
        // concurrent writers and is a data race.
        r.rwLock.Lock()
        defer r.rwLock.Unlock()
        r.session[sessionID] = sd
        return sd, nil
    }

    // CreateSession creates an empty session with a fresh UUID and registers it.
    func (r *redisSessionMgr) CreateSession() (sd Session) {
        sd = NewRedisSession(uuid.NewV4().String(), r.client)
        r.rwLock.Lock()
        defer r.rwLock.Unlock()
        r.session[sd.ID()] = sd
        return sd
    }

    // Clear drops sessionID from the manager's map. Data already saved in
    // Redis is not touched.
    func (r *redisSessionMgr) Clear(sessionID string) {
        r.rwLock.Lock()
        defer r.rwLock.Unlock()
        delete(r.session, sessionID)
    }
    
    

    程序输出如下：


    image.png

    相关文章

      网友评论

          本文标题:Golang自定义基于gin框架的Session中间件

          本文链接:https://www.haomeiwen.com/subject/fbwovktx.html