美文网首页Java服务器端编程IT技术篇Java
阿里的面试题带你认识ForkJoinPool

阿里的面试题带你认识ForkJoinPool

作者: 狼王编程 | 来源:发表于2020-11-24 14:27 被阅读0次

    我相信大家都用过线程池,比如ExcutorService,比如ThreadPoolExcutor

    今天来讲讲ForkJoinPool,它实现于ExcutorService,但又和我们常用的

    ThreadPoolExcutor原理不同

    前言

    随着在硬件上多核处理器的发展和广泛使用,并发编程成为程序员必须掌握的一门技术,在面试中也经常考查面试者并发相关的知识。

    今天,我们就从一道阿里的面试题来开始

    题目:如何充分利用多核CPU,计算超大数组中所有整数的和?

    解析开始

  1. 1.单线程相加?

  2. 我们最容易想到就是单线程相加,一个for循环搞定。

  3. 2.线程池相加?

  4. 如果进一步优化,我们会自然而然地想到使用线程池来分段相加,最后再把每个段的结果相加。

  5. 3.其它?

  6. Yes,就是我们今天的主角——ForkJoinPool,但是它要怎么实现呢?似乎没怎么用过哈^^

    让我们看看上面是那种方法都如何实现

    /**
     * 计算1亿个整数的和 */
    public class ForkJoinPoolTest01 {
        public static void main(String[] args) throws ExecutionException, InterruptedException {
            // 构造数据
            int length = 100000000;
            long[] arr = new long[length];
            for (int i = 0; i < length; i++) {
                arr[i] = ThreadLocalRandom.current().nextInt(Integer.MAX_VALUE);
            }
            // 单线程
            singleThreadSum(arr);
            // ThreadPoolExecutor线程池
            multiThreadSum(arr);
            // ForkJoinPool线程池
            forkJoinSum(arr);

        }

        private static void singleThreadSum(long[] arr) {
            long start = System.currentTimeMillis();

            long sum = 0;
            for (int i = 0; i < arr.length; i++) {
                // 模拟耗时,本文由公从号“彤哥读源码”原创
                sum += (arr[i]/5*5/5*5/5*5/5*5/5*5);
            }

            System.out.println("sum: " + sum);
            System.out.println("single thread elapse: " + (System.currentTimeMillis() - start));

        }

        private static void multiThreadSum(long[] arr) throws ExecutionException, InterruptedException {
            long start = System.currentTimeMillis();

            int count = 8;
            ExecutorService threadPool = Executors.newFixedThreadPool(count);
            List<Future<Long>> list = new ArrayList<>();
            for (int i = 0; i < count; i++) {
                int num = i;
                // 分段提交任务
                Future<Long> future = threadPool.submit(() -> {
                    long sum = 0;
                    for (int j = arr.length / count * num; j < (arr.length / count * (num + 1)); j++) {
                        try {
                            // 模拟耗时
                            sum += (arr[j]/5*5/5*5/5*5/5*5/5*5);
                        } catch (Exception e) {
                            e.printStackTrace();
                        }
                    }
                    return sum;
                });
                list.add(future);
            }

            // 每个段结果相加
            long sum = 0;
            for (Future<Long> future : list) {
                sum += future.get();
            }

            System.out.println("sum: " + sum);
            System.out.println("multi thread elapse: " + (System.currentTimeMillis() - start));
        }

        private static void forkJoinSum(long[] arr) throws ExecutionException, InterruptedException {
            long start = System.currentTimeMillis();

            ForkJoinPool forkJoinPool = ForkJoinPool.commonPool();
            // 提交任务
            ForkJoinTask<Long> forkJoinTask = forkJoinPool.submit(new SumTask(arr, 0, arr.length));
            // 获取结果
            Long sum = forkJoinTask.get();

            forkJoinPool.shutdown();

            System.out.println("sum: " + sum);
            System.out.println("fork join elapse: " + (System.currentTimeMillis() - start));
        }

        private static class SumTask extends RecursiveTask<Long> {
            private long[] arr;
            private int from;
            private int to;

            public SumTask(long[] arr, int from, int to) {
                this.arr = arr;
                this.from = from;
                this.to = to;
            }

            @Override
            protected Long compute() {
                // 小于1000的时候直接相加,可灵活调整
                if (to - from <= 1000) {
                    long sum = 0;
                    for (int i = from; i < to; i++) {
                        // 模拟耗时
                        sum += (arr[i]/5*5/5*5/5*5/5*5/5*5);
                    }
                    return sum;
                }

                // 分成两段任务,本文由公从号“彤哥读源码”原创
                int middle = (from + to) / 2;
                SumTask left = new SumTask(arr, from, middle);
                SumTask right = new SumTask(arr, middle, to);

                // 提交左边的任务
                left.fork();
                // 右边的任务直接利用当前线程计算,节约开销
                Long rightResult = right.compute();
                // 等待左边计算完毕
                Long leftResult = left.join();
                // 返回结果
                return leftResult + rightResult;
            }
        }
    }

    ~~Garnett偷偷地告诉你,实际上计算1亿个整数相加,单线程是最快的,我的电脑大概是100ms左右,使用线程池反而会变慢。~~

    ~~所以,为了演示ForkJoinPool的牛逼之处,我把每个数都/5*5/5*5/5*5/5*5/5*5了一顿操作,用来模拟计算耗时。~~

    来看结果:

    sum: 107352457433800662
    single thread elapse: 789
    sum: 107352457433800662
    multi thread elapse: 228
    sum: 107352457433800662fork join elapse: 189

    可以看到,ForkJoinPool相对普通线程池还是有很大提升的。

    什么是ForkJoinPool?

    谈到线程池,很多人会想到Executors提供的一些预设的线程池,比如单线程线程池SingleThreadExecutor,固定大小的线程池FixedThreadPool,但是很少有人会注意到其中还提供了一种特殊的线程池:WorkStealingPool,我们点进这个方法,会看到和其他方法不同的是,这种线程池并不是通过ThreadPoolExecutor来创建的,而是ForkJoinPool来创建的

    public static ExecutorService newWorkStealingPool() {
            return new ForkJoinPool
                (Runtime.getRuntime().availableProcessors(),
                 ForkJoinPool.defaultForkJoinWorkerThreadFactory,
                 null, true);
        }

    ThreadPoolExecutor应该都很了解了,就是一个基本的存储线程的线程池,需要执行任务的时候就从线程池中拿一个线程来执行。而ForkJoinPool则不仅仅是这么简单,同样也不是ThreadPoolExecutor的代替品,这种线程池是为了实现“分治法”这一思想而创建的,通过把大任务拆分成小任务,然后再把小任务的结果汇总起来就是最终的结果,和MapReduce的思想很类似

    最核心的思想可以这样描述:

    if(任务很小){
        直接计算得到结果
    }else{
        分拆成N个子任务
        调用子任务的fork()进行计算
        调用子任务的join()合并计算结果
    }
  7. 1.fork()

  8. fork()方法类似于线程的Thread.start()方法,但是它不是真的启动一个线程,而是将任务放入到工作队列中。

  9. 2.join()

  10. join()方法类似于线程的Thread.join()方法,但是它不是简单地阻塞线程,而是利用工作线程运行其它任务。当一个工作线程中调用了join()方法,它将处理其它任务,直到注意到目标子任务已经完成了。

    ForkJoinPool内部原理-工作窃取

    work-stealing(工作窃取)算法

    ForkJoinPool 的另一个特性是它使用了work-stealing(工作窃取)算法

    线程池内的所有工作线程都尝试找到并执行已经提交的任务,或者是被其他活动任务创建的子任务(如果不存在就阻塞等待)。这种特性使得 ForkJoinPool 在运行多个可以产生子任务的任务,或者是提交的许多小任务时效率更高。尤其是构建异步模型的 ForkJoinPool 时,对不需要合并(join)的事件类型任务也非常适用

    在 ForkJoinPool 中,线程池中每个工作线程(ForkJoinWorkerThread)都对应一个任务队列(WorkQueue),工作线程优先处理来自自身队列的任务(LIFO或FIFO顺序,参数 mode 决定),然后以FIFO的顺序随机窃取其他队列中的任务。

    ForkJoinPool中的任务

    ForkJoinPool 中的任务分为两种:

    一种是本地提交的任务(Submission task,如 execute、submit 提交的任务);

    另外一种是 fork 出的子任务(Worker task)。

    两种任务都会存放在 WorkQueue 数组中,但是这两种任务并不会混合在同一个队列里,ForkJoinPool 内部使用了一种随机哈希算法(有点类似 ConcurrentHashMap 的桶随机算法)将工作队列与对应的工作线程关联起来,Submission 任务存放在 WorkQueue 数组的偶数索引位置,Worker 任务存放在奇数索引位。

    实质上,Submission 与 Worker 一样,只不过它被限制只能执行它们提交的本地任务,在后面的源码解析中,我们统一称之为“Worker”。

    任务的分布情况如下图:

    ForkJoinPool原理

    初始化ForkJoinPool

    ForkJoinPool pool = ForkJoinPool.commonPool()

    public static ForkJoinPool commonPool() {
        // assert common != null : "static init error";
        return common;
    }

    获取ForkJoinPool很简单,直接调用commonPool()。注意,这个方法是jdk1.8才加的,也是推荐的方法,满足大部分场景。

    static{
        //...
        common = java.security.AccessController.doPrivileged
                (new java.security.PrivilegedAction<ForkJoinPool>() {
                    public ForkJoinPool run() { return makeCommonPool(); }});
        //...
    }

    private static ForkJoinPool makeCommonPool() {
        //...
        return new ForkJoinPool(parallelism, factory, handler, LIFO_QUEUE,"ForkJoinPool.commonPool-worker-");
    }

    common在static{}里创建,调用的是makeCommonPool(),最终调用ForkJoinPool的构造函数。

    private ForkJoinPool(int parallelism,
                         ForkJoinWorkerThreadFactory factory,
                         UncaughtExceptionHandler handler,
                         int mode,                     String workerNamePrefix) {
        this.workerNamePrefix = workerNamePrefix;
        this.factory = factory;
        this.ueh = handler;
        this.config = (parallelism & SMASK) | mode;
        long np = (long)(-parallelism); // offset ctl counts
        this.ctl = ((np << AC_SHIFT) & AC_MASK) | ((np << TC_SHIFT) & TC_MASK);
    }

    parallelism默认是cpu核心数,ForkJoinPool里线程数量依据于它,但不表示最大线程数,不要等同于ThreadPoolExecutor里的corePoolSize或者maximumPoolSize。

    factory是线程工程,不是新东西了,默认实现是

    DefaultForkJoinWorkerThreadFactory。

    workerNamePrefix是其中线程名称的前缀,默认使用“ForkJoinPool-*”

    config保存不变的参数,包括了parallelism和mode,供后续读取。mode可选FIFO_QUEUELIFO_QUEUE,默认是LIFO_QUEUE,具体用哪种,就要看业务。

    ctl是ForkJoinPool中最重要的控制字段,将下面信息按16bit为一组封装在一个long中。

  11. AC: 活动的worker数量;

  12. TC: 总共的worker数量;

  13. SS: WorkQueue状态,第一位表示active的还是inactive,其余十五位表示版本号(对付ABA);

  14. ID:  这里保存了一个WorkQueue在WorkQueue[]的下标,和其他worker通过字段stackPred组成一个TreiberStack。后文讲的栈顶,指这里下标所在的WorkQueue。

  15. TreiberStack:这个栈的pull和pop使用了CAS,所以支持并发下的无锁操作。

    AC和TC初始化时取的是parallelism负数,后续代码可以直接判断正负,为负代表还没有达到目标数量。另外ctl低32位有个技巧可以直接用sp=(int)ctl取得,为负代表存在空闲worker。

    线程池缺不了状态的变化,记录字段是runState,具体介绍在后面的“ForkJoinPool状态修改”。

    任务ForkJoinTask

    ForkJoinPool执行任务的对象是ForkJoinTask,它是一个抽象类,有两个具体实现类RecursiveAction和RecursiveTask。

    public abstract class RecursiveAction extends ForkJoinTask<Void> {
        protected abstract void compute();

        public final Void getRawResult() { return null; }

        protected final void setRawResult(Void mustBeNull) { }

        protected final boolean exec() {
            compute();
            return true;
        }
    }

    public abstract class RecursiveTask<V> extends ForkJoinTask<V> {
        V result;

        protected abstract V compute();

        public final V getRawResult() {
            return result;
        }

        protected final void setRawResult(V value) {
            result = value;
        }

        protected final boolean exec() {
            result = compute();
            return true;
        }
    }

    ForkJoinTask的抽象方法exec由RecursiveAction和RecursiveTask实现,它被定义为final,具体的执行步骤compute延迟到子类实现。很容易看出RecursiveAction和RecursiveTask的区别,前者没有result,getRawResult返回空,它们对应不需要返回结果和需要返回结果两种场景。

    ForkJoinTask里很重要的字段是它的状态status,默认是0,当得出结果时变更为负数,有三种结果:

  16. NORMAL

  17. CANCELLED

  18. EXCEPTIONAL

  19. 除此之外,在得出结果之前,任务状态能够被设置为SIGNAL,表示有线程等待这个任务的结果,执行完成后需要notify通知,具体看后文的join。

    ForkJoinTask在触发执行后,并不支持其他什么特别操作,只能等待任务执行完成。CountedCompleter是ForkJoinTask的子类,它在子任务协作方面扩展了更多操作。我们聚焦ForkJoinPool主线流程,CountedCompleter相关内容另文再介绍。

    WorkQueue

    WorkQueue是一个双端队列,它定义在ForkJoinPool类里。

    scanState描述WorkQueue当前状态:

  20. 偶数表示RUNNING

  21. 奇数表示SCANNING

  22. 负数表示inactive

  23. stackPred是WorkQueue组成TreiberStack时,保存前者的字段。

    ForkJoinPool状态修改

  24. STARTED

  25. STOP

  26. TERMINATED

  27. SHUTDOWN

  28. RSLOCK‍‍‍‍

  29. RSIGNAL

  30. runState记录了ForkJoinPool的运行状态,除了SHUTDOWN是负数,其他都是正数。前面四种不用说了,线程池标准状态流转。在多线程环境修改runState,不能简单想改就改,需要先获取锁,RSLOCK和RSIGNAL就用在这里。

    private int lockRunState() {
        int rs;
        return ((((rs = runState) & RSLOCK) != 0 ||
                 !U.compareAndSwapInt(this, RUNSTATE, rs, rs |= RSLOCK)) ?
                awaitRunStateLock() : rs);
    }

    修改前调用lockRunState锁定,检查当前状态,尝试一次使用CAS修改runState为RSLOCK。需要状态变化的机会很少,大多数时间一次就能成功,但不能排除少几率的竞争,这时候进入awaitRunStateLock。

    private int awaitRunStateLock() {
        Object lock;
        boolean wasInterrupted = false;
        for (int spins = SPINS, r = 0, rs, ns;;) {
            //1
            if (((rs = runState) & RSLOCK) == 0) {
                if (U.compareAndSwapInt(this, RUNSTATE, rs, ns = rs | RSLOCK)) {
                    if (wasInterrupted) {
                        try {
                            Thread.currentThread().interrupt();
                        } catch (SecurityException ignore) {
                        }
                    }
                    return ns;
                }
            }
            else if (r == 0)
                r = ThreadLocalRandom.nextSecondarySeed();
            else if (spins > 0) {
                r ^= r << 6; r ^= r >>> 21; r ^= r << 7; // xorshift
                if (r >= 0)
                    --spins;
            }
            //2
            else if ((rs & STARTED) == 0 || (lock = stealCounter) == null)
                Thread.yield();   // initialization race
            //3
            else if (U.compareAndSwapInt(this, RUNSTATE, rs, rs | RSIGNAL)) {
                synchronized (lock) {
                    if ((runState & RSIGNAL) != 0) {
                        try {
                            lock.wait();
                        } catch (InterruptedException ie) {
                            if (!(Thread.currentThread() instanceof
                                  ForkJoinWorkerThread))
                                wasInterrupted = true;
                        }
                    }
                    else
                        lock.notifyAll();
                }
            }
        }
    }

    在自旋中,第一步,mark1再次尝试修改runState为RSLOCK,成功直接返回。

    mark2检查ForkJoinPool初始化情况,这里没有额外多写个变量做锁,直接利用了stealCounter这个原子变量。因为初始化时(后文的externalSubmit),才会对stealCounter赋值。所以当状态不是STARTED或者stealCounter为空时,让出线程等待。

    mark3处,线程不会无限制自旋尝试,会利用wait/notify进入阻塞等待。RSIGNAL代替原状态,表示有线程进入了等待,解锁时要处理。在高并发下,这不是一个好的设计,但进入这里的几率很低,作为兜底还是可以的。

    private void unlockRunState(int oldRunState, int newRunState) {
        if (!U.compareAndSwapInt(this, RUNSTATE, oldRunState, newRunState)) {
            Object lock = stealCounter;
            runState = newRunState;              // clears RSIGNAL bit
            if (lock != null)
                synchronized (lock) { lock.notifyAll(); }
        }
    }

    解锁的逻辑就比较简单,如果顺利将状态修改为目标状态,成功大吉。否则表示有别的线程进入了wait,需要调用notifyAll唤醒,重新尝试竞争。

    ForkJoinPool最佳实践

    (1)最适合的是计算密集型任务

    (2)在需要阻塞工作线程时,可以使用ManagedBlocker;

    (3)不应该在RecursiveTask的内部使用ForkJoinPool.invoke()/invokeAll();

    总结

    (1)ForkJoinPool特别适合于“分而治之”算法的实现;

    (2)ForkJoinPool和ThreadPoolExecutor是互补的,不是谁替代谁的关系,二者适用的场景不同;

    (3)ForkJoinTask有两个核心方法——fork()和join(),有三个重要子类——RecursiveAction、RecursiveTask和CountedCompleter;

    (4)ForkjoinPool内部基于“工作窃取”算法实现;

    (5)每个线程有自己的工作队列,它是一个双端队列,自己从队列头存取任务,其它线程从尾部窃取任务;

    (6)ForkJoinPool最适合于计算密集型任务,但也可以使用ManagedBlocker以便用于阻塞型任务;

    (7)RecursiveTask内部可以少调用一次fork(),利用当前线程处理,这是一种技巧;

    Garnett还会不断的分享技术干货的,希望你们是我最好的观众!

    乐于输出干货的Java技术公众号:Garnett的Java之路。公众号内有大量的技术文章、海量视频资源、精美脑图,不妨来关注一下!回复【资料】领取大量学习资源和免费书籍!

    相关文章

      网友评论

        本文标题:阿里的面试题带你认识ForkJoinPool

        本文链接:https://www.haomeiwen.com/subject/gqviiktx.html