美文网首页
Golang源码分析之sort

Golang源码分析之sort

作者: vouv | 来源:发表于2020-02-19 13:11 被阅读0次

    排序是工程中必不可少的功能,很多编程语言SDK都提供了排序相关的实现。作为软件工程师,我们在学习各类排序算法的同时,是否有思考过,如何去实现一个工业级的排序算法?如果你是Go语言的作者之一,该如何去实现一种能适应多种情况的排序算法?

    Go SDK中排序相关的实现主要在sort/sort.go中,本文主要基于该文件进行相关实现的分析。

    首先来看看Go对排序接口的定义,利用Go的interface特性可以轻松实现多种数据类型的排序功能。想要调用sort包的排序功能我们需要实现这个排序接口,排序接口主要定义了三个方法:

    • Len() int: 返回传入数据的总数
    • Less(i, j int) bool: 返回数组中下标为i的数据是否小于下标为j的数据
    • Swap(i, j int): 表示执行交换数组中下标为i的数据和下标为j的数据
    // A type, typically a collection, that satisfies sort.Interface can be
    // sorted by the routines in this package. The methods require that the
    // elements of the collection be enumerated by an integer index.
    type Interface interface {
        // Len is the number of elements in the collection.
        Len() int
        // Less reports whether the element with
        // index i should sort before the element with index j.
        Less(i, j int) bool
        // Swap swaps the elements with indexes i and j.
        Swap(i, j int)
    }
    

    了解了包中对sort接口的定义后,再来看看sort包对外提供的主要接口Sort,源码如下:

    // Sort sorts data.
    // It makes one call to data.Len to determine n, and O(n*log(n)) calls to
    // data.Less and data.Swap. The sort is not guaranteed to be stable.
    func Sort(data Interface) {
        n := data.Len()
        quickSort(data, 0, n, maxDepth(n))
    }
    

    如注释所说,当我们调用Sort方法时,该方法会调用一次data.Len(),之后会以O(n*log(n))的时间复杂度调用data.Lessdata.Swap。我们可以看到,Sort内部调用了包私有的quickSort方法,也就是我们熟悉的快排,同时传了4个参数,学过快排的同学都能理解前三个参数的含义,但是我们还看到了一个陌生的函数调用maxDepth(n),这里的depth究竟代表什么呢?所以先探究一下这个函数,代码如下:

    // maxDepth returns a threshold at which quicksort should switch
    // to heapsort. It returns 2*ceil(lg(n+1)).
    func maxDepth(n int) int {
        var depth int
        for i := n; i > 0; i >>= 1 {
            depth++
        }
        return depth * 2
    }
    

    简单来说,maxDepth方法返回的深度表示了数据的量级,qiuckSort方法会根据这个量级选择使用快排还是堆排序,学过堆排序的同学都知道,堆排序的时间复杂度稳定在O(nlogn),有时候比快排还稳定,但是堆排序对数据是跳着访问的,对CPU缓存不友好。

    了解了maxDepth方法以后就可以来看看quickSort的源码了

    func quickSort(data Interface, a, b, maxDepth int) {
        for b-a > 12 { // Use ShellSort for slices <= 12 elements
            if maxDepth == 0 {
                heapSort(data, a, b)
                return
            }
            maxDepth--
            mlo, mhi := doPivot(data, a, b)
            // Avoiding recursion on the larger subproblem guarantees
            // a stack depth of at most lg(b-a).
            if mlo-a < b-mhi {
                quickSort(data, a, mlo, maxDepth)
                a = mhi // i.e., quickSort(data, mhi, b)
            } else {
                quickSort(data, mhi, b, maxDepth)
                b = mlo // i.e., quickSort(data, a, mlo)
            }
        }
        if b-a > 1 {
            // Do ShellSort pass with gap 6
            // It could be written in this simplified form cause b-a <= 12
            for i := a + 6; i < b; i++ {
                if data.Less(i, i-6) {
                    data.Swap(i, i-6)
                }
            }
            insertionSort(data, a, b)
        }
    }
    

    这里代码的实现方式比较好理解,首先对于数组元素大于12个的情况会在快排和堆排之间选择,除此之外的情况会使用希尔排序(间隔为6)和插入排序进行排序。

    包中对于heapSort的实现中规中矩,使用从上往下堆化的方式建堆。这里就不详细介绍,对于快排的实现方式,有的同学就发现不同了,这里调用了一个寻找分区点的函数doPivot,但是doPivot返回了两个值(这里就利用了Go中函数可以有多个返回值的特性)。同时这里可以看到返回mlo,mhi以后并没有继续递归地在左右分区查找,而是做了一个比较,原因也正如注释所说,由于使用了递归的方式实现排序,就必须要考虑到栈溢出的问题,所以对分区的两半,把数量多的放到下一次循环继续切分循环,小的直接递归。这里也表明了调用quickSort的最高栈深度为log(b-a),也就是log(n)。

    接下来可以看看doPivot函数,为什么会返回两个分区点呢?因为mlo到mhi之间的数已经被确定了位置,这里考虑到取中位数的时候数组出现大量重复的数会影响到排序性能的问题,可以发现Go作者对这种情况的解决方式充满着智慧。具体代码如下:

    func doPivot(data Interface, lo, hi int) (midlo, midhi int) {
        m := int(uint(lo+hi) >> 1) // 首先用位运算的方式求中间点,防止溢出
        if hi-lo > 40 {
                    //  多数取中
            // Tukey's ``Ninther,'' median of three medians of three.
            s := (hi - lo) / 8
            medianOfThree(data, lo, lo+s, lo+2*s)
            medianOfThree(data, m, m-s, m+s)
            medianOfThree(data, hi-1, hi-1-s, hi-1-2*s)
        }
        medianOfThree(data, lo, m, hi-1)
    
        // 接下来要对数据达成以下划分结果
        //  data[lo] = pivot (set up by ChoosePivot)
        //  data[lo < i < a] < pivot
        //  data[a <= i < b] <= pivot
        //  data[b <= i < c] unexamined
        //  data[c <= i < hi-1] > pivot
        //  data[hi-1] >= pivot
        pivot := lo
        a, c := lo+1, hi-1
    
        for ; a < c && data.Less(a, pivot); a++ {
        }
        b := a
        for {
            for ; b < c && !data.Less(pivot, b); b++ { // data[b] <= pivot
            }
            for ; b < c && data.Less(pivot, c-1); c-- { // data[c-1] > pivot
            }
            if b >= c {
                break
            }
            // data[b] > pivot; data[c-1] <= pivot
            data.Swap(b, c-1)
            b++
            c--
        }
            // 如果data[c <= i < hi-1] > pivot,hi-c<3 这表明数据中有重复的数,
            // 这里保守一些,认为hi-c<5 为边界,如果重复的数较多,
            // 会以直接扫描跳过的方式把pivot左右两边的区间缩小
        // If hi-c<3 then there are duplicates (by property of median of nine).
        // Let's be a bit more conservative, and set border to 5.
        protect := hi-c < 5
        if !protect && hi-c < (hi-lo)/4 {
            // Lets test some points for equality to pivot
            dups := 0
            if !data.Less(pivot, hi-1) { // data[hi-1] = pivot
                data.Swap(c, hi-1)
                c++
                dups++
            }
            if !data.Less(b-1, pivot) { // data[b-1] = pivot
                b--
                dups++
            }
            // m-lo = (hi-lo)/2 > 6
            // b-lo > (hi-lo)*3/4-1 > 8
            // ==> m < b ==> data[m] <= pivot
            if !data.Less(m, pivot) { // data[m] = pivot
                data.Swap(m, b-1)
                b--
                dups++
            }
            // if at least 2 points are equal to pivot, assume skewed distribution
            protect = dups > 1
        }
        if protect {
            // Protect against a lot of duplicates
            // Add invariant:
            //  data[a <= i < b] unexamined
            //  data[b <= i < c] = pivot
            for {
                for ; a < b && !data.Less(b-1, pivot); b-- { // data[b] == pivot
                }
                for ; a < b && data.Less(a, pivot); a++ { // data[a] < pivot
                }
                if a >= b {
                    break
                }
                // data[a] == pivot; data[b-1] < pivot
                data.Swap(a, b-1)
                a++
                b--
            }
        }
        // Swap pivot into middle
        data.Swap(pivot, b-1)
        return b - 1, c
    }
    
    

    相关文章

      网友评论

          本文标题:Golang源码分析之sort

          本文链接:https://www.haomeiwen.com/subject/szrnfhtx.html