美文网首页
ConcurrentLinkedQueue源码分析(Java8)

ConcurrentLinkedQueue源码分析(Java8)

作者: 超有为青年 | 来源:发表于2017-06-22 11:27 被阅读0次

    最近一段时间在看一本书《Java并发编程的艺术》,在P164讲到了关于ConcurrentLinkedQueue的源码分析,但是这部分源码非常复杂,于是我又顺手看了一下IDEA的Java源码,发现在Java8中,该部分的源码已经被更新过了,正好读一读顺带做个笔记。

    基本介绍

    ConcurrentLinkedQueue是一个列表实现,包括一个head和tail引用,该类的初始化过程中,头尾引用都被初始化成一个空的Node,下面我们可以看到相关代码:

    public class ConcurrentLinkedQueue<E> extends AbstractQueue<E>
            implements Queue<E>, java.io.Serializable {
    
            private static class Node<E> {
                volatile E item;
                volatile Node<E> next;   
            }
    
            private transient volatile Node<E> head;
            
            private transient volatile Node<E> tail;
            
            public ConcurrentLinkedQueue() {
                head = tail = new Node<E>(null);
            }
    }
    

    入队流程

    单线程下的入队流程为:

    1. 将新节点加入到tail引用的next中
    2. 将新节点赋值给tail引用

    但是在多线程环境中,需要保障其他线程入队和出队不受影响,ConcurrentLinkedQueue由CAS算法实现了无锁入队,下面是加入节点的关键代码:

    public boolean offer(E e) {
        checkNotNull(e);
        final Node<E> newNode = new Node<E>(e);
    
        // 循环开始,p和t都指向tail,q指向tail的next
        for (Node<E> t = tail, p = t;;) {
            Node<E> q = p.next;
            if (q == null) {
                // q为null代表目前tail后面没有其他线程插入的节点,即p确实是最后的节点
                if (p.casNext(null, newNode)) {
                    // 这里casNext函数的作用是当p的next节点为null时,用newNode更新p的next节点,更新成功返回true
    
                    // 如果casNext更新成功,证明newNode已经成功插入到队尾
                    if (p != t)
                        // 这一步判断表明,t即tail已经不是真正的队尾引用,这是减少cas操作的一步优化
    
                        // 这里casTail函数的作用是当tail与t相等时,用newNode更新tail,在这里CAS失败也没有关系
                        casTail(t, newNode);
                    return true;
                }
                // 如果casNext更新失败,则重新将p的next赋值给q
            }
            else if (p == q)
                // 当p==q只有一种情况,即p==p.next,在这种情况下就表明当前节点已经离队,因为在出队操作之后,ConcurrentLinkedQueue会将出队节点的next设为它本身
    
                // 在遇到当前节点已经是出队节点的情况下,表明当前节点已经在head之前,因此根据如下逻辑进行更新当前节点:1、如果tail已经更新,那么将当前节点设为tail;2、否则,将当前节点设为head,因为不能保证tail指向的节点是否已经离队
                p = (t != (t = tail)) ? t : head;
            else
                // 当tail更新且p不在tail时,用tail更新p,否则用q更新p
                p = (p != t && t != (t = tail)) ? t : q;
        }
    }
    

    如果觉得上述方法过于复杂,我们可以用一种更简单的方案来进行结果相同的操作:

    public boolean offer(E e) {
        checkNotNull(e);
        final Node<E> newNode = new Node<E>(e);
    
        for (; ; ) {
            Node<E> t = tail;
            if (t.casNext(null, newNode)) {
                // 参照单线程的入队流程,casNext成功表明newNode已经成功插入到了队列里
    
                // 如果casTail失败了也没有关系,失败了证明有其他的线程在进行casTail,至少有一根线程可以成功
                casTail(t, newNode);
                return true;
            }
        }
    }
    

    而在JDK源码中,加入了一步优化,这步优化是:在插入一个新节点时,不着急将tail指向这个新节点,而是在插入第二个新节点的时候,才对tail进行cas操作。
    这样做会导致两个问题:

    1. tail并不在保持原有的一定指向队尾的性质;
    2. 从tail开始需要进过几步查找next才能寻找到真正的队尾;

    但是这样做有一个好处:减少了至少一半的cas操作,虽然增加了普通的赋值操作,但是在多线程情况下cas操作的耗时要远远大于一般赋值操作的耗时,因此这部分优化可以增大该容器类的并发量。而剩下部分的判断都是为了在进行这一步优化的情况下,保证程序的正确性所做的。

    出队流程

    单线程情况下的出队流程为:

    1. 如果head==tail,证明队列为空,返回null
    2. 将队首元素的值取出,作为返回值
    3. 将head指向head.next

    如果按照这种思路,我们可以直接写出一个简单写法的无锁出队方案:

    public E poll() {
        for (; ; ) {
            Node<E> h = head;
            if (h.next == null) {
                return null;
            } else {
                if (casHead(h, h.next)) {
                    if (h.next != null)
                        return h.next.item;
                }
            }
        }
    }
    

    我们再来看JDK源码中的poll函数实现,在这个poll函数中,使用了和offer函数中类似的优化方式,在出队的时候并不着急更新head的值,而是缓慢更新,然后用一部分操作来保证出队的正确性:

    public E poll() {
        restartFromHead:
        for (; ; ) {
            for (Node<E> h = head, p = h, q; ; ) {
                E item = p.item;
    
                if (item != null && p.casItem(item, null)) {
                    if (p != h)
                        updateHead(h, ((q = p.next) != null) ? q : p);
                    return item;
                } else if ((q = p.next) == null) {
                    updateHead(h, p);
                    return null;
                } else if (p == q)
                    continue restartFromHead;
                else
                    p = q;
            }
        }
    }
    

    性能测试

    这里不光是性能测试,同样有针对上述两种简单的无锁入队和出队的正确性测试。我分别开了2根入队线程和2根出队线程,每根入队线程循环入队1000W的数据,下面展示了测试结果(因为我的电脑是4核i5,比较弱鸡,如果线程开多了那么大量的时间都在线程切换上,测试结果就不准确了):

    使用JDK源码

    Test Started: 11:15 25:839
    Get thread finished, Total: 10809006
    Get thread finished, Total: 9190994
    Test Finished: 11:15 31:487
    Total Time Cost: 5s 648ms

    使用自定义的offer函数

    Test Started: 11:17 36:963
    Get thread finished, Total: 9335745
    Get thread finished, Total: 10664255
    Test Finished: 11:17 41:627
    Total Time Cost: 4s 664ms

    使用自定义的poll函数

    Test Started: 11:18 17:412
    Get thread finished, Total: 9714954
    Get thread finished, Total: 10285046
    Test Finished: 11:18 21:669
    Total Time Cost: 4s 257ms

    同时使用自定义的offer和poll函数

    Test Started: 11:18 51:663
    Get thread finished, Total: 10219132
    Get thread finished, Total: 9780868
    Test Finished: 11:18 56:602
    Total Time Cost: 4s 939ms

    有点尴尬的是好像优化过的源码是跑的最慢的,应该和我只有2根读写线程有关,争抢的情况比较少,争抢情况越严重,线程越多,源码的速度应该是更快的。如果谁有更好的机器可以拿代码试一下,下面是我的测试代码:

    public class TestQueue {
    
        private static int TOTAL_COUNT = 10000000;
        private static int TOTAL_WRITE = 2;
        private static int TOTAL_READ = 2;
        private static SimpleDateFormat DATE_FORMAT = new SimpleDateFormat("HH:mm ss:SSS");
    
        public static void main(String[] args) {
            AtomicInteger flag = new AtomicInteger(0);
            ConcurrentHashMap<Integer, AtomicInteger> total = new ConcurrentHashMap<>(TOTAL_COUNT);
            for (int i = 0; i != TOTAL_COUNT; i++) {
                total.put(i, new AtomicInteger(0));
            }
            CustomQueue<Integer> customQueue = new CustomQueue<>();
            ExecutorService executor = Executors.newCachedThreadPool();
    
            Date startTime = new Date();
            System.out.println("Test  Started: " + DATE_FORMAT.format(startTime));
    
            for (int i = 0; i != TOTAL_WRITE; i++) {
                executor.execute(new Runnable() {
                    @Override
                    public void run() {
                        for (int i = 0; i != TOTAL_COUNT; i++) {
                            customQueue.add(i);
                        }
                    }
                });
            }
    
            for (int i = 0; i != TOTAL_READ; i++) {
                executor.execute(new Runnable() {
                    @Override
                    public void run() {
    
                        int sum = 0;
    
                        while (flag.get() != TOTAL_WRITE * TOTAL_COUNT) {
                            Integer num = customQueue.poll();
                            if (num != null) {
                                sum++;
                                flag.incrementAndGet();
                                total.get(num).incrementAndGet();
                            }
                        }
    
                        System.out.println("Get thread finished, Total: " + sum);
                    }
                });
            }
    
            executor.shutdown();
    
            try {
                executor.awaitTermination(Long.MAX_VALUE, TimeUnit.SECONDS);
                Date endTime = new Date();
                long totalTime = endTime.getTime() - startTime.getTime();
                for (int i = 0; i != TOTAL_COUNT; i++) {
                    if (total.get(i).get() != TOTAL_WRITE) {
                        System.out.println("Test Failed: " + i + " " + total.get(i));
                        break;
                    }
                }
                System.out.println("Test Finished: " + DATE_FORMAT.format(endTime));
                System.out.printf("Total Time Cost: %ds %dms", totalTime / 1000, totalTime % 1000);
            } catch (InterruptedException e) {
                System.out.println("Failure: " + flag.get());
                e.printStackTrace();
            }
    
        }
    
    }
    

    相关文章

      网友评论

          本文标题:ConcurrentLinkedQueue源码分析(Java8)

          本文链接:https://www.haomeiwen.com/subject/quzxcxtx.html