美文网首页Go
动手实现 Redis 跳表(Go 语言)

动手实现 Redis 跳表(Go 语言)

作者: 0xE8551CCB | 来源:发表于2019-12-11 21:42 被阅读0次

    引言

    image
    读过 Redis 源码的童鞋,想必会知道 zset 实现时,使用了「跳表」(Skiplist)这种数据结构吧。它的原理非常容易理解,如果对链表比较熟悉,那么也会很容易理解「跳表」的工作原理(核心:有序链表 + 分层)。当然,本文并不会详细讲解「跳表」的工作原理,以及对于 Redis 跳表源码的详细分析。因为已经有前辈们产出了非常丰富的文章来讲解 Redis 跳表,需要的话,推荐阅读 这篇文章 了解更多细节。

    总的来说,Redis 的 zset 实现中,选用「跳表」的主要原因如下:

    1. 原理清晰易懂,且容易实现,方便维护:对比下平衡树或者红黑树(可能就像 Raft v.s. Paxos 的感觉一样),不管是原理还是实现都简单了很多。平衡树或者红黑树在实现时,还要时刻维护节点关系,必要时还需要执行树的左旋或者右旋来保持平衡;
    2. 拥有媲美平衡树或者红黑树的查询效率:插入、删除、查找的平均时间复杂度可以达到 O(logN)。

    当然,相对于 William Pugh 在他的论文中所描述的「跳表」算法而言,作者在实现 Redis 中的「跳表」时,给它加了点「料」:

    1. 允许重复的分数存在;
    2. 在进行比较时,不仅会比较 score,还会考虑关联的数据;
    3. 添加了一个回退指针,从而构成了一个双向链表(level[0]),便于倒序遍历链表(ZREVRANGE)使用。

    好了,废话完毕。接下来进入正题,看看如何使用 Go 语言来实现「跳表」吧(贴代码模式开启~)。

    跳表实现

    以下仅仅列出了几个比较有趣且关键的方法实现,即:插入、删除和更新分数。完整的实现源码可以参考 这里 或者 这里,包含了比较详细的单元测试。

    数据结构定义

    需要说明的是,为了简单起见,假设存储的元素是字符串类型(要是使用 interface{} 的话,又得加些代码支持元素之间的比较了)。但是在 Redis 中,实际的 element 类型是 sds

    const (
        MaxLevel = 64 // 足以容纳 2^64 个元素
        P = 0.25
    )
    
    type Node struct {
        elem string
        score float64
        backward *Node
        level []skipLevel
    }
    
    type skipLevel struct {
        // forward 每层都要有指向下一个节点的指针
        forward *Node
        // span 间隔定义为:从当前节点到 forward 指向的下个节点之间间隔的节点数
        span int
    }
    
    type Skiplist struct {
        header, tail *Node
        level int // 记录跳表的实际高度
        length int // 记录跳表的长度(不含头节点)
    }
    

    辅助方法

    考虑到在实现时,经常需要比较 score 和 element,所以这里直接给 Node 实现了一些比较方法,便于使用。

    func (node *Node) Compare(other *Node) int {
        if node.score < other.score || (node.score == other.score && node.elem < other.elem) {
            return -1
        } else if node.score > other.score || (node.score == other.score && node.elem > other.elem) {
            return 1
        } else {
            return 0
        }
    }
    
    func (node *Node) Lt(other *Node) bool {
        return node.Compare(other) < 0
    }
    
    func (node *Node) Lte(other *Node) bool {
        return node.Compare(other) <= 0
    }
    
    func (node *Node) Gt(other *Node) bool {
        return node.Compare(other) > 0
    }
    
    func (node *Node) Eq(other *Node) bool {
        return node.Compare(other) == 0
    }
    

    插入元素

    // Insert 向跳表中插入一个新的元素。
    // 步骤:
    // 1. 查找插入位置
    // 2. 创建新节点,并在目标位置插入节点
    // 3. 调整跳表 backward 指针等
    func (sl *Skiplist) Insert(score float64, elem string) *Node {
        var (
            // update 用于记录每层待更新的节点
            update [MaxLevel]*Node
            // rank 用来记录每层经过的节点记录(可以看成到头节点的距离)
            rank [MaxLevel]int
            // 构建一个新节点,用于下面的大小判断,其 level 在后面设置
            node = &Node{score: score, elem: elem}
        )
        cur := sl.header
        for i := sl.level - 1; i >= 0; i-- {
            if cur == sl.header {
                rank[i] = 0
            } else {
                rank[i] = rank[i+1]
            }
            // 与同层的后一个节点比较,如果后一个比目标值小,则可以继续向后
            // 否则下降到一层查找。注意这里的大小比较是按照 score 和
            // elem 综合计算得到的。
            for cur.level[i].forward != nil && cur.level[i].forward.Lt(node) {
                rank[i] += cur.level[i].span
                // 同层继续往后查找
                cur = cur.level[i].forward
            }
            update[i] = cur
        }
        // 调整跳表高度
        level := sl.randomLevel()
        if level > sl.level {
            // 初始化每层
            for i := level - 1; i >= sl.level; i-- {
                rank[i] = 0
                update[i] = sl.header
                update[i].level[i].span = sl.length
            }
            sl.level = level
        }
        // 更新节点 level,并插入新节点
        node.setLevel(level)
        for i := 0; i < level; i++ {
            // 更新每层的节点指向
            node.level[i].forward = update[i].level[i].forward
            update[i].level[i].forward = node
            // 更新 span 信息
            node.level[i].span = update[i].level[i].span - (rank[0] - rank[i])
            update[i].level[i].span = (rank[0] - rank[i]) + 1
        }
        // 针对新增节点 level < sl.level 的情况,需要更新上面没有扫到的层 span
        for i := level; i < sl.level; i++ {
            update[i].level[i].span++
        }
        // 调整 backward 指针
        // 如果前一个节点是头节点,则 backward 为 nil
        // 否则 backward 指向之前节点
        if update[0] != sl.header {
            // update[0] 就是和新增节点相邻的前一个节点
            node.backward = update[0]
        }
        // 如果新增节点是最后一个,则需要更新 tail 指针
        if node.level[0].forward == nil {
            sl.tail = node
        } else {
            // 中间节点,需要更新后一个节点的回退指针
            node.level[0].forward.backward = node
        }
        sl.length++
        return node
    }
    
    // randomLevel 对于新增节点,返回一个随机的 level
    // 返回的 level 范围为 [1, MaxLevel]。并且,采用的
    // 算法会保证,更大的 level 返回的概率越低。
    // 每个 level 出现的概率计算:(1-p) * p^(level-1)
    func (sl *Skiplist) randomLevel() int {
        level := 1
        for rand.Float64() < P && level < MaxLevel {
            level++
        }
        return level
    }
    

    删除元素

    // Delete 用于删除跳表中指定的节点。
    func (sl *Skiplist) Delete(score float64, elem string) *Node {
        // 第一步,找到需要删除节点
        var (
            update [MaxLevel]*Node
            targetNode = &Node{elem: elem, score: score}
        )
        cur := sl.header
        for i := sl.level - 1; i >= 0; i-- {
            for cur.level[i].forward != nil && cur.level[i].forward.Lt(targetNode) {
                cur = cur.level[i].forward
            }
            update[i] = cur
        }
        // 目标节点找到后,这里需要判断下 elem 是否相等
        // score 可以重复,所以必须要谨慎
        nodeToBeDeleted := update[0].level[0].forward
        if nodeToBeDeleted == nil || !nodeToBeDeleted.Eq(targetNode) {
            return nil
        }
        sl.deleteNode(update, nodeToBeDeleted)
        return nodeToBeDeleted
    }
    
    func (sl *Skiplist) deleteNode(update [64]*Node, nodeToBeDeleted *Node) {
        // 这时我们要删除的节点就是 nodeToBeDeleted
        // 调整每层待更新节点,修改 forward 指向
        for i := 0; i < sl.level; i++ {
            if update[i].level[i].forward == nodeToBeDeleted {
                update[i].level[i].forward = nodeToBeDeleted.level[i].forward
                update[i].level[i].span += nodeToBeDeleted.level[i].span - 1
            } else {
                update[i].level[i].span--
            }
        }
        // 调整回退指针:
        // 1. 如果被删除的节点是最后一个节点,需要更新 sl.tail
        // 2. 如果被删除的节点位于中间,则直接更新后一个节点 backward 即可
        if sl.tail == nodeToBeDeleted {
            sl.tail = nodeToBeDeleted.backward
        } else {
            nodeToBeDeleted.level[0].forward.backward = nodeToBeDeleted.backward
        }
        // 调整层数
        for sl.header.level[sl.level-1].forward == nil {
            sl.level--
        }
        // 减少节点计数
        sl.length--
        nodeToBeDeleted.backward = nil
        nodeToBeDeleted.level[0].forward = nil
    }
    

    更新分数

    // UpdateScore 用于更新节点的分数。该函数会保证更新分数后,
    // 节点的有序性依然可以维持。
    // 策略如下:
    // 1. 快速判断能否原节点修改,如果可以则直接修改并返回;
    // 2. 采用更加昂贵的操作:删除再添加。
    func (sl *Skiplist) UpdateScore(curScore float64, elem string, newScore float64) *Node {
        var (
            update [MaxLevel]*Node
            targetNode = &Node{elem: elem, score: curScore}
        )
        cur := sl.header
        // 第一步,找到符合条件的目标节点
        for i := sl.level - 1; i >= 0; i-- {
            for cur.level[i].forward != nil && cur.level[i].forward.Lt(targetNode) {
                cur = cur.level[i].forward
            }
            update[i] = cur
        }
        node := cur.level[0].forward
        if node == nil || !node.Eq(targetNode) {
            return nil
        }
        if sl.canUpdateScoreFor(node, newScore) {
            node.score = newScore
            return node
        } else {
            // 需要删除旧节点,增加新节点
            sl.deleteNode(update, node)
            return sl.Insert(newScore, node.elem)
        }
    }
    
    // canUpdateScoreFor 确定能否直接在原有的节点上进行修改
    // 什么条件才可以直接原地更新 score 呢?
    // 1. node 是唯一一个数据节点(node.backward == NULL && node->level[0].forward == NULL)
    // 2. node 是第一个数据节点,且新的分数要比 node 之后节点分数要小(这样才能保证有序)
    // 即:node.backward == NULL && node->level[0].forward->score > newScore)
    // 3. node 是最后一个数据节点,且 node 之前节点的分数要比新改的分数小
    // 即:node->backward->score < newScore && node->level[0].forward == NULL
    // 4. node 是修改的后的分数恰好还能保证位于前一个和后一个节点分数之间
    // 即:node->backward->score < newscore && node->level[0].forward->score > newscore
    func (sl *Skiplist) canUpdateScoreFor(node *Node, newScore float64) bool {
        if (node.backward == nil || node.backward.score < newScore) &&
            (node.level[0].forward == nil || node.level[0].forward.score > newScore) {
            return true
        }
    
        return false
    }
    

    总结

    俗话说,「说起来容易,做起来难」。在实现「跳表」的时候感受颇深,似乎看完 Redis 的「跳表」源码和网上诸多前辈编写的文章后,自以为懂得了原理(可能确实懂了),但是在具体实现的时候还是踩了不少坑。比如,空指针引起 panic;i-- 写成了 i++ 导致查找失败;一些边界情况的判断等。总之,细节决定成败,需要在保持思路清晰的同时,更加谨慎一些才能写出足够健壮的代码来。当然,这期间自然少不了单元测试的助攻,否则有很多问题可能都没法暴露出来~

    参考

    声明

    相关文章

      网友评论

        本文标题:动手实现 Redis 跳表(Go 语言)

        本文链接:https://www.haomeiwen.com/subject/mpwagctx.html