Coroutine Semaphores: A Brief Look at the Source

Author: 小城哇哇 | Published 2023-05-24 15:01

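In "Coroutine Semaphores" we briefly introduced Semaphore and how to use it. We mentioned its default implementation, SemaphoreImpl, but did not go into detail. Today we take a closer look at it. As the saying goes, reading source code is good medicine for learning, wouldn't you agree?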

SemaphoreImpl source analysis

First, let's revisit the Semaphore interface:

public interface Semaphore {
    public val availablePermits: Int
    public suspend fun acquire()
    public fun tryAcquire(): Boolean
    public fun release()
}

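There are four members in total, one of which is just a property getter. To analyze the implementation, these four are the natural place to start.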

The source of SemaphoreImpl is as follows:

private class SemaphoreImpl(private val permits: Int, acquiredPermits: Int) : Semaphore {
    private val head: AtomicRef<SemaphoreSegment>
    private val deqIdx = atomic(0L)
    private val tail: AtomicRef<SemaphoreSegment>
    private val enqIdx = atomic(0L)

    init {
        // The total number of permits must be greater than 0
        require(permits > 0) { "Semaphore should have at least 1 permit, but had $permits" }
        // The number of already acquired permits must lie in 0..permits, both ends included
        require(acquiredPermits in 0..permits) { "The number of acquired permits should be in 0..$permits" }
        val s = SemaphoreSegment(0, null, 2)
        head = atomic(s)
        tail = atomic(s)
    }

    /**
     * This counter indicates a number of available permits if it is non-negative,
     * or the size with minus sign otherwise. Note, that 32-bit counter is enough here
     * since the maximal number of available permits is [permits] which is [Int],
     * and the maximum number of waiting acquirers cannot be greater than 2^31 in any
     * real application.
     */
    private val _availablePermits = atomic(permits - acquiredPermits)
    override val availablePermits: Int get() = max(_availablePermits.value, 0)

    private val onCancellationRelease = { _: Throwable -> release() }

    override fun tryAcquire(): Boolean {
        _availablePermits.loop { p ->
            if (p <= 0) return false
            if (_availablePermits.compareAndSet(p, p - 1)) return true
        }
    }

    override suspend fun acquire() {
        val p = _availablePermits.getAndDecrement()
        if (p > 0) return // permit acquired
        // While it looks better when the following function is inlined,
        // it is important to make `suspend` function invocations in a way
        // so that the tail-call optimization can be applied.
        acquireSlowPath()
    }

    private suspend fun acquireSlowPath() = suspendCancellableCoroutineReusable<Unit> sc@ { cont ->
        while (true) {
            if (addAcquireToQueue(cont)) return@sc
            val p = _availablePermits.getAndDecrement()
            if (p > 0) { // permit acquired
                cont.resume(Unit, onCancellationRelease)
                return@sc
            }
        }
    }

    override fun release() {
        while (true) {
            val p = _availablePermits.getAndUpdate { cur ->
                check(cur < permits) { "The number of released permits cannot be greater than $permits" }
                cur + 1
            }
            if (p >= 0) return
            if (tryResumeNextFromQueue()) return
        }
    }

    /**
     * Returns `false` if the received permit cannot be used and the calling operation should restart.
     */
    private fun addAcquireToQueue(cont: CancellableContinuation<Unit>): Boolean {
        val curTail = this.tail.value
        val enqIdx = enqIdx.getAndIncrement()
        val segment = this.tail.findSegmentAndMoveForward(id = enqIdx / SEGMENT_SIZE, startFrom = curTail,
            createNewSegment = ::createSegment).segment // cannot be closed
        val i = (enqIdx % SEGMENT_SIZE).toInt()
        // the regular (fast) path -- if the cell is empty, try to install continuation
        if (segment.cas(i, null, cont)) { // installed continuation successfully
            cont.invokeOnCancellation(CancelSemaphoreAcquisitionHandler(segment, i).asHandler)
            return true
        }
        // On CAS failure -- the cell must be either PERMIT or BROKEN
        // If the cell already has PERMIT from tryResumeNextFromQueue, try to grab it
        if (segment.cas(i, PERMIT, TAKEN)) { // took permit thus eliminating acquire/release pair
            /// This continuation is not yet published, but still can be cancelled via outer job
            cont.resume(Unit, onCancellationRelease)
            return true
        }
        assert { segment.get(i) === BROKEN } // it must be broken in this case, no other way around it
        return false // broken cell, need to retry on a different cell
    }

    @Suppress("UNCHECKED_CAST")
    private fun tryResumeNextFromQueue(): Boolean {
        val curHead = this.head.value
        val deqIdx = deqIdx.getAndIncrement()
        val id = deqIdx / SEGMENT_SIZE
        val segment = this.head.findSegmentAndMoveForward(id, startFrom = curHead,
            createNewSegment = ::createSegment).segment // cannot be closed
        segment.cleanPrev()
        if (segment.id > id) return false
        val i = (deqIdx % SEGMENT_SIZE).toInt()
        val cellState = segment.getAndSet(i, PERMIT) // set PERMIT and retrieve the prev cell state
        when {
            cellState === null -> {
                // Acquire has not touched this cell yet, wait until it comes for a bounded time
                // The cell state can only transition from PERMIT to TAKEN by addAcquireToQueue
                repeat(MAX_SPIN_CYCLES) {
                    if (segment.get(i) === TAKEN) return true
                }
                // Try to break the slot in order not to wait
                return !segment.cas(i, PERMIT, BROKEN)
            }
            cellState === CANCELLED -> return false // the acquire was already cancelled
            else -> return (cellState as CancellableContinuation<Unit>).tryResumeAcquire()
        }
    }

    private fun CancellableContinuation<Unit>.tryResumeAcquire(): Boolean {
        val token = tryResume(Unit, null, onCancellationRelease) ?: return false
        completeResume(token)
        return true
    }
}

Let's start with the simplest one, the getter.

1. availablePermits

val availablePermits: Int is the getter mentioned above.

private val _availablePermits = atomic(permits - acquiredPermits)
override val availablePermits: Int get() = max(_availablePermits.value, 0)

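The private variable _availablePermits holds the actual counter. Its initial value is the difference between the total permits and the already acquired permits, i.e. what is left. The public property reads this value and clamps it with max so that it is never negative. The clamp is needed because, as the KDoc in the source says, the internal counter goes negative when acquirers are waiting: its absolute value is then the size of the waiting queue.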

Note that _availablePermits is an AtomicInt, which guarantees that updates to it are atomic.

2. tryAcquire()

tryAcquire makes a non-suspending attempt: when no permit is available, it gives up and returns false instead of suspending.

override fun tryAcquire(): Boolean {
    _availablePermits.loop { p ->
        if (p <= 0) return false
        if (_availablePermits.compareAndSet(p, p - 1)) return true
    }
}

Here, loop is an extension function on AtomicInt. It takes a lambda and keeps invoking it with the current value. Its definition looks like this:

public class AtomicInt internal constructor(
    /** Get/set of this property maps to read/write of volatile variable */
    @Volatile var value: Int
) {
// ...
}

inline fun AtomicInt.loop(block: (Int) -> Unit): Nothing {
    while (true) { // infinite loop
        block(value)
    }
}

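If _availablePermits is less than 1, no permit is available, so the attempt fails and false is returned. Otherwise the attempt succeeds: _availablePermits is decremented by 1 and true is returned.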

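The "acquire" part of this attempt is really just a call to AtomicInt.compareAndSet(). That CAS can fail when another thread changes the value in between, in which case the lambda finishes without a return and loop runs it again with the fresh value. This is why loop is an infinite loop internally.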
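
The retry behaviour is easier to see in a standalone sketch. The snippet below is not the library code: it swaps atomicfu's AtomicInt for the JDK's AtomicInteger so that it runs without extra dependencies, and the names loopForever and tryTake are made up for illustration.

```kotlin
import java.util.concurrent.atomic.AtomicInteger

// Hypothetical helper mirroring the `loop` shown above: re-run the block forever
// with the latest value; the block leaves via a non-local return.
inline fun AtomicInteger.loopForever(block: (Int) -> Unit): Nothing {
    while (true) block(get())
}

// Take one permit without ever suspending or blocking.
fun tryTake(permits: AtomicInteger): Boolean {
    permits.loopForever { p ->
        if (p <= 0) return false                          // nothing left: give up
        if (permits.compareAndSet(p, p - 1)) return true  // nobody raced us: done
        // CAS lost a race with another thread: fall through and retry
    }
}
```

Because loopForever is inline, the return statements inside the lambda return from tryTake itself, which is the same trick the original tryAcquire relies on.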

3. acquire()

acquire() is the ordinary permit request: if no permit is currently available, the caller suspends and waits until one becomes available, then resumes.

override suspend fun acquire() {
    val p = _availablePermits.getAndDecrement()
    if (p > 0) return // permit acquired
    // While it looks better when the following function is inlined,
    // it is important to make `suspend` function invocations in a way
    // so that the tail-call optimization can be applied.
    acquireSlowPath()
}

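It calls getAndDecrement() on the counter directly: this returns the current value and consumes one permit by decrementing. If the value it returned was greater than 0, the acquisition has succeeded; otherwise the caller has to wait.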

The waiting is done by acquireSlowPath().
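
To make the counter's behaviour concrete, here is a minimal model of just the counting part, assuming nothing beyond the JDK. PermitCounter, decrement and waiters are hypothetical names, and no coroutine is actually suspended; the point is only that the counter is decremented unconditionally and may therefore go below zero.

```kotlin
import java.util.concurrent.atomic.AtomicInteger
import kotlin.math.max

// Hypothetical model of the counter only; it does not suspend anything.
class PermitCounter(permits: Int) {
    private val raw = AtomicInteger(permits)

    // Clamped view, like availablePermits.
    val available: Int get() = max(raw.get(), 0)

    // Number of callers that found no permit and would have to wait.
    val waiters: Int get() = max(-raw.get(), 0)

    // Returns true on the fast path (permit taken), false when the caller
    // would have to go down the slow path and wait.
    fun decrement(): Boolean = raw.getAndDecrement() > 0
}
```

With one permit and three callers, the first call takes the fast path and the other two leave the raw counter at -2, while the clamped view still reports 0.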

SemaphoreSegment

Before continuing with acquireSlowPath(), let's take a look at SemaphoreSegment:

private class SemaphoreSegment(id: Long, prev: SemaphoreSegment?, pointers: Int) : Segment<SemaphoreSegment>(id, prev, pointers) {
    val acquirers = atomicArrayOfNulls<Any?>(SEGMENT_SIZE)
    override val maxSlots: Int get() = SEGMENT_SIZE // SEGMENT_SIZE defaults to 16, i.e. 16 slots

    @Suppress("NOTHING_TO_INLINE")
    inline fun get(index: Int): Any? = acquirers[index].value

    @Suppress("NOTHING_TO_INLINE")
    inline fun set(index: Int, value: Any?) {
        acquirers[index].value = value
    }

    @Suppress("NOTHING_TO_INLINE")
    inline fun cas(index: Int, expected: Any?, value: Any?): Boolean = acquirers[index].compareAndSet(expected, value)

    @Suppress("NOTHING_TO_INLINE")
    inline fun getAndSet(index: Int, value: Any?) = acquirers[index].getAndSet(value)

    // Cleans the acquirer slot located by the specified index
    // and removes this segment physically if all slots are cleaned.
    fun cancel(index: Int) {
        // Clean the slot
        set(index, CANCELLED)
        // Remove this segment if needed
        onSlotCleaned()
    }

    override fun toString() = "SemaphoreSegment[id=$id, hashCode=${hashCode()}]"
}

SemaphoreSegment is an implementation of Segment, with the inheritance chain SemaphoreSegment --> Segment --> ConcurrentLinkedListNode<S>. Put plainly, it is a linked-list node type.

The node also provides atomic-style methods such as getAndSet, which simply delegate to the getAndSet of the underlying atomic array element.

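Internally it maintains an array, acquirers, whose length defaults to 16; set, get and the other operations all act on the element at the given index. How is it used? Let's see what acquireSlowPath does.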

acquireSlowPath

private suspend fun acquireSlowPath() = suspendCancellableCoroutineReusable<Unit> sc@ { cont ->
        // Loop until the continuation is either enqueued or resumed
        while (true) {
            if (addAcquireToQueue(cont)) return@sc
            val p = _availablePermits.getAndDecrement() // decrement
            if (p > 0) { // permit acquired
                cont.resume(Unit, onCancellationRelease)
                return@sc
            }
        }
    }

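acquireSlowPath is a suspend function. It uses suspendCancellableCoroutineReusable to suspend the current coroutine and obtain its continuation, cont, as a CancellableContinuation.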

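Ignore the first if for now. After it, the decrement returns the current permit count p. If p is greater than 0, a permit is available, so the suspended continuation is resumed and the acquisition is complete. If p is 0 or even negative, the loop goes around again.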

Next, back to the addAcquireToQueue in the first if.

addAcquireToQueue

As the name suggests, this function adds the acquire request to a queue. What does that mean exactly? Let's walk through the source.

private val tail: AtomicRef<SemaphoreSegment> // atomic reference to the tail segment of the queue
private val enqIdx = atomic(0L) // index of the next cell to enqueue into

private fun addAcquireToQueue(cont: CancellableContinuation<Unit>): Boolean {
        val curTail = this.tail.value
        val enqIdx = enqIdx.getAndIncrement() // claim the current enqueue index, then advance it
        // Find the segment that holds this index, creating it if it does not exist yet
        val segment = this.tail.findSegmentAndMoveForward(id = enqIdx / SEGMENT_SIZE, startFrom = curTail,
            createNewSegment = ::createSegment).segment // cannot be closed
        val i = (enqIdx % SEGMENT_SIZE).toInt()
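        // Try to install the continuation: if the cell is still null the CAS succeeds, the cell now holds cont, and we return true.
        // The caller, acquireSlowPath, then returns without resuming: cont stays suspended until a later release() resumes it, or until it is cancelled.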
        if (segment.cas(i, null, cont)) {
            cont.invokeOnCancellation(CancelSemaphoreAcquisitionHandler(segment, i).asHandler)
            return true
        }
        // The cell already holds PERMIT: if it can be switched to TAKEN, return true
        if (segment.cas(i, PERMIT, TAKEN)) {
            // The permit was taken and the cell marked as used (TAKEN), so cont can be resumed right away
            cont.resume(Unit, onCancellationRelease)
            return true
        }
        assert { segment.get(i) === BROKEN } // it must be broken in this case, no other way around it
        return false // broken cell, need to retry on a different cell
    }

// Creates a semaphore segment
private fun createSegment(id: Long, prev: SemaphoreSegment?) = SemaphoreSegment(id, prev, 0)

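In short: a linked list maintains a family of semaphore segments, and each cell in a segment moves through different states (PERMIT, TAKEN and so on). Together they form the queue through which waiting acquirers and released permits are matched up, which is how permits are handed back and reused.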
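
The acquirer's side of this can be sketched without any coroutine machinery. The snippet below is a simplified model, not the library code: it uses a plain AtomicReferenceArray for one segment's cells, marker objects in place of the library's private symbols, an arbitrary object in place of the continuation, and the hypothetical names locate, enqueue and EnqueueResult. SEGMENT_SIZE is set to the default of 16 mentioned earlier.

```kotlin
import java.util.concurrent.atomic.AtomicReferenceArray

const val SEGMENT_SIZE = 16 // assumed default, see SemaphoreSegment above

// Map a global queue index to (segment id, slot inside that segment).
fun locate(index: Long): Pair<Long, Int> =
    Pair(index / SEGMENT_SIZE, (index % SEGMENT_SIZE).toInt())

// Marker objects standing in for the library's private cell states.
object PERMIT
object TAKEN
object BROKEN

enum class EnqueueResult { WAITING, GOT_PERMIT, RETRY }

// Acquirer's view of one cell; `waiter` stands in for the continuation.
fun enqueue(cells: AtomicReferenceArray<Any?>, i: Int, waiter: Any): EnqueueResult = when {
    cells.compareAndSet(i, null, waiter) -> EnqueueResult.WAITING      // parked until a release wakes it
    cells.compareAndSet(i, PERMIT, TAKEN) -> EnqueueResult.GOT_PERMIT  // a release got to the cell first
    else -> EnqueueResult.RETRY                                        // BROKEN cell: try another one
}
```

WAITING and GOT_PERMIT correspond to the two return true branches of addAcquireToQueue, and RETRY to the return false that sends acquireSlowPath around its loop again.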

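Once the continuation is resumed, whether directly inside acquireSlowPath or by a later release(), the suspension ends and the permit has been acquired. acquire is complete and the task can run.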

4. release()

The fourth method releases a permit. Again, let's walk through the source directly:

private val head: AtomicRef<SemaphoreSegment>
private val deqIdx = atomic(0L)

override fun release() {
    while (true) {
        val p = _availablePermits.getAndUpdate { cur ->
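            // release requires that the semaphore is not already full
            check(cur < permits) { "The number of released permits cannot be greater than $permits" }
            cur + 1 // one more permit becomes available
        }
        if (p >= 0) return // p is the value before the update; non-negative means nobody was waiting, so the release is done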
        if (tryResumeNextFromQueue()) return
    }
}

If the previous value was negative, at least one acquirer is waiting, and we reach the last if: tryResumeNextFromQueue.
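
The counter half of release() can be modelled on its own. This is a sketch under the same assumptions as before: the JDK's AtomicInteger replaces atomicfu's AtomicInt, and releaseOne is a hypothetical name that only reports whether a waiter would have to be woken.

```kotlin
import java.util.concurrent.atomic.AtomicInteger

// Hypothetical model of release()'s counter update.
// Returns true when a waiter must be woken, i.e. the previous value was negative.
fun releaseOne(counter: AtomicInteger, maxPermits: Int): Boolean {
    val previous = counter.getAndUpdate { cur ->
        // Releasing more permits than the semaphore owns is a programming error.
        check(cur < maxPermits) { "The number of released permits cannot be greater than $maxPermits" }
        cur + 1
    }
    return previous < 0
}
```

Note that getAndUpdate returns the value before the update, which is why the original compares p with 0 rather than with 1.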

tryResumeNextFromQueue

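Functionally, tryResumeNextFromQueue is the counterpart of addAcquireToQueue on the acquire side: the latter adds to the queue, the former removes from it, taking the next waiting cell out of the queue.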

private fun tryResumeNextFromQueue(): Boolean {
        val curHead = this.head.value
        val deqIdx = deqIdx.getAndIncrement()
        val id = deqIdx / SEGMENT_SIZE
        val segment = this.head.findSegmentAndMoveForward(id, startFrom = curHead,
            createNewSegment = ::createSegment).segment // cannot be closed
        segment.cleanPrev()
        // The segment found is past the one we asked for, so the requested segment is gone: return false, i.e. failure
        if (segment.id > id) return false
        val i = (deqIdx % SEGMENT_SIZE).toInt()
        val cellState = segment.getAndSet(i, PERMIT) // read the previous state and set PERMIT at the same time
        when {
            cellState === null -> {
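                // Previously null, i.e. no acquirer has touched this cell yet: re-check a bounded number of times whether it has become TAKEN, and return true (success) if so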
                repeat(MAX_SPIN_CYCLES) {
                    if (segment.get(i) === TAKEN) return true
                }
                // If the spinning never saw TAKEN, try to mark the cell BROKEN; if that CAS fails, the acquirer took the permit after all
                return !segment.cas(i, PERMIT, BROKEN)
            }
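            cellState === CANCELLED -> return false // the acquire waiting in this cell was already cancelled: return false and go around the release loop again
            else -> return (cellState as CancellableContinuation<Unit>).tryResumeAcquire() // otherwise the cell holds a continuation: the release succeeds if resuming it succeeds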
        }
    }
    
private fun CancellableContinuation<Unit>.tryResumeAcquire(): Boolean {
        val token = tryResume(Unit, null, onCancellationRelease) ?: return false
        completeResume(token)
        return true
    }
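
The releaser's side of a cell can be modelled the same way as the acquirer's. The following is a simplified sketch rather than the library code: cells live in a plain AtomicReferenceArray, the marker objects are redefined so the snippet stands alone, handOver and its resume callback are hypothetical, and MAX_SPIN_CYCLES is set to 100 as an assumed default.

```kotlin
import java.util.concurrent.atomic.AtomicReferenceArray

// Marker objects standing in for the library's private cell states.
object PERMIT
object TAKEN
object BROKEN
object CANCELLED

const val MAX_SPIN_CYCLES = 100 // assumed default

// Releaser's view of one cell. Returns true if the permit was handed over,
// false if release() has to retry with the next cell.
// `resume` stands in for waking the stored continuation.
fun handOver(cells: AtomicReferenceArray<Any?>, i: Int, resume: (Any) -> Boolean): Boolean {
    val previous: Any? = cells.getAndSet(i, PERMIT)
    if (previous == null) {
        // No acquirer has arrived yet: wait briefly for it to take the permit...
        repeat(MAX_SPIN_CYCLES) {
            if (cells.get(i) === TAKEN) return true
        }
        // ...then break the cell; if the CAS fails, the acquirer took it after all.
        return !cells.compareAndSet(i, PERMIT, BROKEN)
    }
    if (previous === CANCELLED) return false // the waiter is gone
    return resume(previous)                  // wake the stored waiter
}
```

The bounded spin is a trade-off: a release never waits indefinitely for a slow acquirer, at the cost of occasionally breaking a cell and making both sides retry on the next one.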

Summary

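Having gone through all this, we find that handling the number of permits takes more than a single number. The currently available permits are indeed tracked by one number (availablePermits), but permits eventually run out, and every request made after that has to suspend and wait for a permit to become available. That suspension needs extra logic (SemaphoreSegment), and this part turns out to be far from simple: it involves many atomic operations as well as coroutine handling.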

This article is on the long side and may well contain mistakes or omissions. Comments and corrections are welcome!
