美文网首页
ForkJoin源码解析

ForkJoin源码解析

作者: 海涛_meteor | 来源:发表于2019-07-28 12:11 被阅读0次

    前言

    本文通过Forkjoin实现数据累加的demo来进行源码分析,并且基于jdk8环境,因此与jdk7的情况会略有不同。其具体代码实现如下。

    任务类
    public class ForkJoinSumCalculator extends RecursiveTask<Long> {
    
        private final long[] numbers;
        private final int start;
        private final int end;
        public static final long THRESHOLD = 10000;
    
        public ForkJoinSumCalculator(long[] numbers) {
            this(numbers, 0, numbers.length);
        }
    
        private ForkJoinSumCalculator(long[] numbers, int start, int end) {
            this.numbers = numbers;
            this.start = start;
            this.end = end;
        }
    
        @Override
        protected Long compute() {
            int length = end - start;
            if (length <= THRESHOLD) {
                return computeSequentially();
            }
            ForkJoinSumCalculator leftTask = new ForkJoinSumCalculator(numbers, start, start + length/2);
            leftTask.fork();
            ForkJoinSumCalculator rightTask = new ForkJoinSumCalculator(numbers, start + length/2, end);
            Long rightResult = rightTask.compute();
            Long leftResult = leftTask.join();
            return leftResult + rightResult;
        }
        private long computeSequentially() {
            long sum = 0;
            for (int i = start; i < end; i++) {
                sum += numbers[i];
            }
                return sum;
        }
    }
    

    定义了ForkJoinSumCalculator来实现任务分解和子任务的累加计算。

    测试类
    public class ForkJoinTest {
    
        public static void main(String[] args) {
            long[] numbers = LongStream.rangeClosed(1, 1000000).toArray();
            ForkJoinTask<Long> task = new ForkJoinSumCalculator(numbers);//1
            long result = new ForkJoinPool().invoke(task);//2
            System.out.println("result:"+result);
        }
    }
    

    通过测试类ForkJoinTest启动了ForkJoinPool并计算得到结果,从这里的main方法可以看出实现主要依赖1和2两行,1中ForkJoinSumCalculator类的初始化先不做过多说明,从2开始进入分析。

    源码解析

    首先来看一下new ForkJoinPool()这个线程池初始化操作到底做了什么,源码如下

        public ForkJoinPool() {
            this(Math.min(MAX_CAP, Runtime.getRuntime().availableProcessors()),
                 defaultForkJoinWorkerThreadFactory, null, false);
        }
        
        public ForkJoinPool(int parallelism,
                            ForkJoinWorkerThreadFactory factory,
                            UncaughtExceptionHandler handler,
                            boolean asyncMode) {
            this(checkParallelism(parallelism),
                 checkFactory(factory),
                 handler,
                 asyncMode ? FIFO_QUEUE : LIFO_QUEUE,
                 "ForkJoinPool-" + nextPoolId() + "-worker-");
            checkPermission();
        }
        
        private ForkJoinPool(int parallelism,
                             ForkJoinWorkerThreadFactory factory,
                             UncaughtExceptionHandler handler,
                             int mode,
                             String workerNamePrefix) {
            this.workerNamePrefix = workerNamePrefix;
            this.factory = factory;
            this.ueh = handler;
            this.config = (parallelism & SMASK) | mode;
            long np = (long)(-parallelism); // offset ctl counts
            this.ctl = ((np << AC_SHIFT) & AC_MASK) | ((np << TC_SHIFT) & TC_MASK);
        }
    

    这里就是三个ForkJoinPool的连续调用,最后的作用仅是给workerNamePrefixfactoryuehconfigctl几个属性赋值。顺带提一下ForkJoinPool上包含注解sun.misc.Contended,这个注解jdk8中才引入,是java中避免缓存行伪共享的一种方案,能在并发情况下更好提升性能,此处不展开。
    接着来看一下checkParallelism方法

        private static int checkParallelism(int parallelism) {
            if (parallelism <= 0 || parallelism > MAX_CAP)
                throw new IllegalArgumentException();
            return parallelism;
        }
    

    这里传入的parallelismMath.min(MAX_CAP, Runtime.getRuntime().availableProcessors()),该值取0x7fff和当前核心数的最小值,结合checkParallelism方法可以看出,parallelism值一般就是CPU核数了。由于SMASK = = 0xffffmodeLIFO_QUEUE = 0(从名字可以很明显看出这是个后入先出的队列),因此根据表达式config的值就是核心数。

    然后来看一下ctl这个值,这是一个64位的long变量,根据注释说明,ctl的64位被分成4个16位标识,依次称为ACTCSSID

    • AC: 运行中线程数与目标值checkParallelism的差值,如果ac是负的说明没有足够的活动线程
    • TC: 总线程数与目标值checkParallelism的差值,如果tc是负的说明没有足够的总线程
    • SS: 版本计数和最顶端等待线程的状态
    • ID: 栈中最顶端等待线程的索引

    ctl的低32位称为sp,当sp非0时说明有等待线程。

    然后需要注意的是factory属性传入的值为defaultForkJoinWorkerThreadFactory,该值的初始化在ForkJoinPool类的静态代码块中,源码如下

        static {
            // initialize field offsets for CAS etc
            try {
                U = sun.misc.Unsafe.getUnsafe();
                Class<?> k = ForkJoinPool.class;
                CTL = U.objectFieldOffset
                    (k.getDeclaredField("ctl"));
                RUNSTATE = U.objectFieldOffset
                    (k.getDeclaredField("runState"));
                STEALCOUNTER = U.objectFieldOffset
                    (k.getDeclaredField("stealCounter"));
                Class<?> tk = Thread.class;
                PARKBLOCKER = U.objectFieldOffset
                    (tk.getDeclaredField("parkBlocker"));
                Class<?> wk = WorkQueue.class;
                QTOP = U.objectFieldOffset
                    (wk.getDeclaredField("top"));
                QLOCK = U.objectFieldOffset
                    (wk.getDeclaredField("qlock"));
                QSCANSTATE = U.objectFieldOffset
                    (wk.getDeclaredField("scanState"));
                QPARKER = U.objectFieldOffset
                    (wk.getDeclaredField("parker"));
                QCURRENTSTEAL = U.objectFieldOffset
                    (wk.getDeclaredField("currentSteal"));
                QCURRENTJOIN = U.objectFieldOffset
                    (wk.getDeclaredField("currentJoin"));
                Class<?> ak = ForkJoinTask[].class;
                ABASE = U.arrayBaseOffset(ak);
                int scale = U.arrayIndexScale(ak);
                if ((scale & (scale - 1)) != 0)   //判断scale是否为2的幂次方
                    throw new Error("data type scale not a power of two");
                ASHIFT = 31 - Integer.numberOfLeadingZeros(scale);
            } catch (Exception e) {
                throw new Error(e);
            }
    
            commonMaxSpares = DEFAULT_COMMON_MAX_SPARES;
            defaultForkJoinWorkerThreadFactory =
                new DefaultForkJoinWorkerThreadFactory();
            modifyThreadPermission = new RuntimePermission("modifyThread");
    
            common = java.security.AccessController.doPrivileged
                (new java.security.PrivilegedAction<ForkJoinPool>() {
                    public ForkJoinPool run() { return makeCommonPool(); }});
            int par = common.config & SMASK; // report 1 even if threads disabled
            commonParallelism = par > 0 ? par : 1;
        }
    

    通过defaultForkJoinWorkerThreadFactory = new DefaultForkJoinWorkerThreadFactory();对该常量进行了初始化,DefaultForkJoinWorkerThreadFactoryForkJoinPool的静态内部类,其具体实现为

        static final class DefaultForkJoinWorkerThreadFactory
            implements ForkJoinWorkerThreadFactory {
            public final ForkJoinWorkerThread newThread(ForkJoinPool pool) {
                return new ForkJoinWorkerThread(pool);
            }
        }
    

    到这里ForkJoinPool的初始化就算完成了,接着回到main方法来看一下invoke(task)方法的实现

        public <T> T invoke(ForkJoinTask<T> task) {
            if (task == null)
                throw new NullPointerException();
            externalPush(task);
            return task.join();
        }
    

    这里调用了externalPush(task)方法,接着来看一下

        final void externalPush(ForkJoinTask<?> task) {
            WorkQueue[] ws; WorkQueue q; int m;
            int r = ThreadLocalRandom.getProbe();
            int rs = runState;
            if ((ws = workQueues) != null && (m = (ws.length - 1)) >= 0 &&
                (q = ws[m & r & SQMASK]) != null && r != 0 && rs > 0 &&
                U.compareAndSwapInt(q, QLOCK, 0, 1)) {
                ForkJoinTask<?>[] a; int am, n, s;
                if ((a = q.array) != null &&
                    (am = a.length - 1) > (n = (s = q.top) - q.base)) {
                    int j = ((am & s) << ASHIFT) + ABASE;
                    U.putOrderedObject(a, j, task);  //task加入q的任务队列中
                    U.putOrderedInt(q, QTOP, s + 1);  //修改top的位置
                    U.putIntVolatile(q, QLOCK, 0);
                    if (n <= 1)
                        signalWork(ws, q);
                    return;
                }
                U.compareAndSwapInt(q, QLOCK, 1, 0);
            }
            externalSubmit(task);
        }
    

    首先看到ThreadLocalRandom.getProbe()可以生成一个随机数,ThreadLocalRandom类解决了Random种子竞争的问题,在并发情况下性能更好,这里不做过多分析。

    runState标识pool的运行状态,具体表示如下

        // runState bits: SHUTDOWN must be negative, others arbitrary powers of two
        private static final int  RSLOCK     = 1;
        private static final int  RSIGNAL    = 1 << 1;
        private static final int  STARTED    = 1 << 2;
        private static final int  STOP       = 1 << 29;
        private static final int  TERMINATED = 1 << 30;
        private static final int  SHUTDOWN   = 1 << 31;
    

    看第一个if,需要同时满足5个条件才进入分支,来看一下

    1. (ws = workQueues) != null //workQueues数组非空
    2. (m = (ws.length - 1)) >= 0 //workQueues中至少有一个WorkQueue对象,并赋值m
    3. (q = ws[m & r & SQMASK]) != null //m & r & SQMASK保证随机数为偶数且不大于m,这么做是由于这里有一个隐含的约定,只有线程为空的WorkQueue对象才能出现在ws的偶数位
    4. r != 0 //随机数非0
    5. rs > 0 //runState非0表示线程池没有被关闭
    6. U.compareAndSwapInt(q, QLOCK, 0, 1) //能够成功将对象q的qlock属性从0置为1,这里的qlock=1说明被锁定, < 0表示终止,所以这里显然是一个加锁操作

    上述条件不能全部满足则会跳出if执行externalSubmit(task)方法,否则就接着进入下一个if语句,又需要满足两个条件

    1. (a = q.array) != null //q的队列不为空,这里的array类型为ForkJoinTask<?>[]
    2. (am = a.length - 1) > (n = (s = q.top) - q.base) //top是当前线程即将处理的队列偏移量,base是可以被其他线程“窃取”的队列偏移量,base是被volatile修饰的,所以这个值显然是会存在并发情况的

    可以看到当条件不满足时会通过U.compareAndSwapInt(q, QLOCK, 1, 0)直接释放锁。
    而当满足上述条件时,开始执行int j = ((am & s) << ASHIFT) + ABASE,这其中ASHIFT是数组array中每个元素所占字节长度的二进制位数(去除高位所有0后的位数),ABASE是第一个元素地址相对于数组起始地址的偏移值,根据计算出的偏移量jtask放入array中,利用QTOP偏移量将top值进行+1操作,置qlock为0以释放锁,并执行以下代码块:

        if (n <= 1)
            signalWork(ws, q);
        return;
    

    这里跟进看一下signalWork(ws, q)方法

        final void signalWork(WorkQueue[] ws, WorkQueue q) {
            long c; int sp, i; WorkQueue v; Thread p;
            while ((c = ctl) < 0L) {                       // active线程过少
                if ((sp = (int)c) == 0) {                  // 没有空闲线程
                    if ((c & ADD_WORKER) != 0L)            // 工作线程太少
                        tryAddWorker(c);
                    break;
                }
                if (ws == null)                            // unstarted/terminated
                    break;
                if (ws.length <= (i = sp & SMASK))         // 已终止
                    break;
                if ((v = ws[i]) == null)                   // 正在终止
                    break;
                int vs = (sp + SS_SEQ) & ~INACTIVE;        // next scanState
                int d = sp - v.scanState;                  // screen CAS
                long nc = (UC_MASK & (c + AC_UNIT)) | (SP_MASK & v.stackPred);
                if (d == 0 && U.compareAndSwapLong(this, CTL, c, nc)) {
                    v.scanState = vs;                      
                    if ((p = v.parker) != null)           // 唤醒v的owner
                        U.unpark(p);
                    break;
                }
                if (q != null && q.base == q.top)          // no more work
                    break;
            }
        }
    

    (c = ctl) < 0L判断active线程过少时,会执行while循环,当满足工作线程太少的判断条件时,会执行tryAddWorker(c)方法增加工作线程,来看看具体代码

        private void tryAddWorker(long c) {
            boolean add = false;
            do {
                long nc = ((AC_MASK & (c + AC_UNIT)) |
                           (TC_MASK & (c + TC_UNIT)));//AC、TC分别进行加1操作,表示增加了worker线程
                if (ctl == c) {
                    int rs, stop;                 // check if terminating
                    if ((stop = (rs = lockRunState()) & STOP) == 0)
                        add = U.compareAndSwapLong(this, CTL, c, nc);
                    unlockRunState(rs, rs & ~RSLOCK);
                    if (stop != 0)
                        break;
                    if (add) {
                        createWorker();
                        break;
                    }
                }
            } while (((c = ctl) & ADD_WORKER) != 0L && (int)c == 0);
        }
    

    这里用了do-while循环尝试创建worker线程,当CAS地修改ctl成功时才会执行createWorker()方法并推出,createWorker()方法实现如下

        private boolean createWorker() {
            ForkJoinWorkerThreadFactory fac = factory;
            Throwable ex = null;
            ForkJoinWorkerThread wt = null;
            try {
                if (fac != null && (wt = fac.newThread(this)) != null) {
                    wt.start();
                    return true;
                }
            } catch (Throwable rex) {
                ex = rex;
            }
            deregisterWorker(wt, ex);
            return false;
        }
    

    根据之前的静态代码块可以知道,这里传入的factory是一个DefaultForkJoinWorkerThreadFactory类型对象,

        static final class DefaultForkJoinWorkerThreadFactory
            implements ForkJoinWorkerThreadFactory {
            public final ForkJoinWorkerThread newThread(ForkJoinPool pool) {
                return new ForkJoinWorkerThread(pool);
            }
        }
        
            protected ForkJoinWorkerThread(ForkJoinPool pool) {
            // Use a placeholder until a useful name can be set in registerWorker
            super("aForkJoinWorkerThread");
            this.pool = pool;
            this.workQueue = pool.registerWorker(this);//将当前ForkJoinWorkerThread线程注册到ForkJoinPool中
        }
    

    由此可知这里的createWorker()方法会创建一个ForkJoinWorkerThread线程并启动它。pool.registerWorker(this)会将当前线程注册到pool中,这也就意味着当前线程会成为这个workQueueowner,这里就要说到worker steal算法,大意就是一个线程从自己任务队列的头部取出任务执行,而其他空闲线程可以从其队列的尾部“偷”任务执行,以充分利用空闲的线程资源。这里当线程成为owner之后,才可以从top位置取任务,因此WorkQueue中的top是非volatile类型,base却是volatile的。

    由于在createWorker()中,创建的线程被启动了,那么我们有必要来看看ForkJoinWorkerThreadrun方法里都做了些什么。

        public void run() {
            if (workQueue.array == null) { // only run once
                Throwable exception = null;
                try {
                    onStart(); //没有任何操作
                    pool.runWorker(workQueue);
                } catch (Throwable ex) {
                    exception = ex;
                } finally {
                    try {
                        onTermination(exception);
                    } catch (Throwable ex) {
                        if (exception == null)
                            exception = ex;
                    } finally {
                        pool.deregisterWorker(this, exception);
                    }
                }
            }
        }
    

    可以看到业务都交由了pool.runWorker(workQueue)运行,源码如下

        final void runWorker(WorkQueue w) {
            w.growArray();                   // allocate queue
            int seed = w.hint;               // initially holds randomization hint
            int r = (seed == 0) ? 1 : seed;  // avoid 0 for xorShift
            for (ForkJoinTask<?> t;;) {
                if ((t = scan(w, r)) != null)
                    w.runTask(t);
                else if (!awaitWork(w, r))
                    break;
                r ^= r << 13; r ^= r >>> 17; r ^= r << 5; // xorshift
            }
        }
    

    这里用了一个死循环来执行task,具体涉及scanrunTaskawaitWork几个方法,逐一来看一下。

    首先是scan

        private ForkJoinTask<?> scan(WorkQueue w, int r) {
            WorkQueue[] ws; int m;
            if ((ws = workQueues) != null && (m = ws.length - 1) > 0 && w != null) {
                int ss = w.scanState;                     // initially non-negative
                for (int origin = r & m, k = origin, oldSum = 0, checkSum = 0;;) {
                    WorkQueue q; ForkJoinTask<?>[] a; ForkJoinTask<?> t;
                    int b, n; long c;
                    if ((q = ws[k]) != null) {
                        if ((n = (b = q.base) - q.top) < 0 &&
                            (a = q.array) != null) {      // non-empty
                            long i = (((a.length - 1) & b) << ASHIFT) + ABASE; //base对应的地址
                            if ((t = ((ForkJoinTask<?>)
                                      U.getObjectVolatile(a, i))) != null &&
                                q.base == b) {
                                if (ss >= 0) {
                                    if (U.compareAndSwapObject(a, i, t, null)) {
                                        q.base = b + 1;
                                        if (n < -1)       // signal others
                                            signalWork(ws, q);
                                        return t;
                                    }
                                }
                                else if (oldSum == 0 &&   // try to activate
                                         w.scanState < 0)
                                    tryRelease(c = ctl, ws[m & (int)c], AC_UNIT);
                            }
                            if (ss < 0)                   // refresh
                                ss = w.scanState;
                            r ^= r << 1; r ^= r >>> 3; r ^= r << 10;
                            origin = k = r & m;           // move and rescan
                            oldSum = checkSum = 0;
                            continue;
                        }
                        checkSum += b;
                    }
                    if ((k = (k + 1) & m) == origin) {    // 直到遍历完所有队列才停止
                        if ((ss >= 0 || (ss == (ss = w.scanState))) &&
                            oldSum == (oldSum = checkSum)) {
                            if (ss < 0 || w.qlock < 0)    // already inactive
                                break;
                            int ns = ss | INACTIVE;       // try to inactivate
                            long nc = ((SP_MASK & ns) |
                                       (UC_MASK & ((c = ctl) - AC_UNIT)));
                            w.stackPred = (int)c;         // 记录前一个top值
                            U.putInt(w, QSCANSTATE, ns);
                            if (U.compareAndSwapLong(this, CTL, c, nc))
                                ss = ns;
                            else
                                w.scanState = ss;         // back out
                        }
                        checkSum = 0;
                    }
                }
            }
            return null;
        }
    

    这个方法主要做的一件事就是遍历workQueue,并窃取一个尾部任务,窃取到则立即返回,并执行w.runTask(t),那么接着来看一下runTask方法

            final void runTask(ForkJoinTask<?> task) {
                if (task != null) {
                    scanState &= ~SCANNING; // mark as busy
                    (currentSteal = task).doExec();
                    U.putOrderedObject(this, QCURRENTSTEAL, null); // release for GC
                    execLocalTasks();
                    ForkJoinWorkerThread thread = owner;
                    if (++nsteals < 0)      // collect on overflow
                        transferStealCount(pool);
                    scanState |= SCANNING;
                    if (thread != null)
                        thread.afterTopLevelExec();
                }
            }
    

    可以看到task被提交给(currentSteal = task).doExec()进行处理

        final int doExec() {
            int s; boolean completed;
            if ((s = status) >= 0) {
                try {
                    completed = exec();
                } catch (Throwable rex) {
                    return setExceptionalCompletion(rex);
                }
                if (completed)
                    s = setCompletion(NORMAL);
            }
            return s;
        }
    

    之后又被交由ForkJoinTask<V>的子类RecursiveTask<V>实现的exec()方法进行处理

        protected final boolean exec() {
            result = compute();
            return true;
        }
    

    到这里应该就很清楚了,这个compute()方法就是我们自定义的任务类ForkJoinSumCalculator中实现的方法。也就是说一旦窃取到任务就直接执行了,那么execLocalTasks()方法又是在做什么呢,来看一下

            final void execLocalTasks() {
                int b = base, m, s;
                ForkJoinTask<?>[] a = array;
                if (b - (s = top - 1) <= 0 && a != null &&
                    (m = a.length - 1) >= 0) {
                    if ((config & FIFO_QUEUE) == 0) {
                        for (ForkJoinTask<?> t;;) {
                            if ((t = (ForkJoinTask<?>)U.getAndSetObject
                                 (a, ((m & s) << ASHIFT) + ABASE, null)) == null) //从top位置取出任务 
                                break;
                            U.putOrderedInt(this, QTOP, s);
                            t.doExec();
                            if (base - (s = top - 1) > 0)
                                break;
                        }
                    }
                    else
                        pollAndExecAll();
                }
            }
    

    LIFO模式会执行pollAndExecAll(),否则执行另一个分支。两个分支做的事情其实一样,都是循环执行array中的任务,不同的是一个从top取,一个从base取。

    最后来看一下awaitWork方法

        private boolean awaitWork(WorkQueue w, int r) {
            if (w == null || w.qlock < 0)                 // w is terminating
                return false;
            for (int pred = w.stackPred, spins = SPINS, ss;;) {
                if ((ss = w.scanState) >= 0)
                    break;
                else if (spins > 0) {
                    r ^= r << 6; r ^= r >>> 21; r ^= r << 7;
                    if (r >= 0 && --spins == 0) {         // 进行随机自旋
                        WorkQueue v; WorkQueue[] ws; int s, j; AtomicLong sc;
                        if (pred != 0 && (ws = workQueues) != null &&
                            (j = pred & SMASK) < ws.length &&
                            (v = ws[j]) != null &&        // see if pred parking
                            (v.parker == null || v.scanState >= 0))
                            spins = SPINS;                // continue spinning
                    }
                }
                else if (w.qlock < 0)                     // recheck after spins
                    return false;
                else if (!Thread.interrupted()) {
                    long c, prevctl, parkTime, deadline;
                    int ac = (int)((c = ctl) >> AC_SHIFT) + (config & SMASK);
                    if ((ac <= 0 && tryTerminate(false, false)) ||
                        (runState & STOP) != 0)           // pool terminating
                        return false;
                    if (ac <= 0 && ss == (int)c) {        // is last waiter
                        prevctl = (UC_MASK & (c + AC_UNIT)) | (SP_MASK & pred);
                        int t = (short)(c >>> TC_SHIFT);  // 收缩过剩的线程
                        if (t > 2 && U.compareAndSwapLong(this, CTL, c, prevctl))
                            return false;                 // else use timed wait
                        parkTime = IDLE_TIMEOUT * ((t >= 0) ? 1 : 1 - t);
                        deadline = System.nanoTime() + parkTime - TIMEOUT_SLOP;
                    }
                    else
                        prevctl = parkTime = deadline = 0L;
                    Thread wt = Thread.currentThread();
                    U.putObject(wt, PARKBLOCKER, this);   // emulate LockSupport
                    w.parker = wt;
                    if (w.scanState < 0 && ctl == c)      // recheck before park
                        U.park(false, parkTime);
                    U.putOrderedObject(w, QPARKER, null);
                    U.putObject(wt, PARKBLOCKER, null);
                    if (w.scanState >= 0)
                        break;
                    if (parkTime != 0L && ctl == c &&
                        deadline - System.nanoTime() <= 0L &&
                        U.compareAndSwapLong(this, CTL, c, prevctl))
                        return false;                     // shrink pool
                }
            }
            return true;
        }
    

    scan方法没有窃取到任务时,会进入到这个方法,根据这个方法的返回值判断是继续去执行scan还是退出当前线程。同时判断当前线程是否是过剩线程,如果是的话将退出当前线程以收缩线程池。

    到这里pool.runWorker(workQueue)做的事基本了解了,也知道了最后执行任务的步骤调用的都是我们自定义的compute()方法,那么还是很有必要来具体了解一下这个方法的内容。

        @Override
        protected Long compute() {
            int length = end - start;
            if (length <= THRESHOLD) {//小于阈值开始进行累加
                return computeSequentially();
            }
            ForkJoinSumCalculator leftTask = new ForkJoinSumCalculator(numbers, start, start + length/2);
            leftTask.fork();
            ForkJoinSumCalculator rightTask = new ForkJoinSumCalculator(numbers, start + length/2, end);
            Long rightResult = rightTask.compute();
            Long leftResult = leftTask.join();
            return leftResult + rightResult;
        }
    

    这个方法虽然是自定义的,但其实必须遵守一个大概的实现模板,模板里必定有fork()join()方法,我们依次来看一下它们做了什么。

        public final ForkJoinTask<V> fork() {
            Thread t;
            if ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread)
                ((ForkJoinWorkerThread)t).workQueue.push(this);
            else
                ForkJoinPool.common.externalPush(this);
            return this;
        }
    

    fork()做的事情比较简单,当前线程如果是ForkJoinWorkerThread线程就通过push方法将当前任务加入队列top端,否则执行externalPush方法,这个方法之前已经出现过了,这里就不重复介绍了。

    接着来看join()方法

        public final V join() {
            int s;
            if ((s = doJoin() & DONE_MASK) != NORMAL)
                reportException(s);
            return getRawResult();
        }
    

    这里会根据doJoin()方法的返回值来判断是否抛出异常,那么来看一下doJoin()方法。

        private int doJoin() {
            int s; Thread t; ForkJoinWorkerThread wt; ForkJoinPool.WorkQueue w;
            return (s = status) < 0 ? s :
                ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread) ?
                (w = (wt = (ForkJoinWorkerThread)t).workQueue).
                tryUnpush(this) && (s = doExec()) < 0 ? s :
                wt.pool.awaitJoin(w, this, 0L) :
                externalAwaitDone();
        }
    

    当前线程如果不是ForkJoinWorkerThread线程,则执行externalAwaitDone()方法阻塞当前线程,否则执行另一个判断(w = (wt = (ForkJoinWorkerThread)t).workQueue). tryUnpush(this) && (s = doExec()) < 0 ? s : wt.pool.awaitJoin(w, this, 0L)

            final boolean tryUnpush(ForkJoinTask<?> t) {
                ForkJoinTask<?>[] a; int s;
                if ((a = array) != null && (s = top) != base &&
                    U.compareAndSwapObject
                    (a, (((a.length - 1) & --s) << ASHIFT) + ABASE, t, null)) {
                    U.putOrderedInt(this, QTOP, s);
                    return true;
                }
                return false;
            }
    

    tryUnpush方法判断top端的任务取出是否成功,并且调用doExec()执行,成功则返回状态s,否则执行wt.pool.awaitJoin(w, this, 0L)awaitJoin方法会在指定任务完成或者超时前尝试帮助或阻塞自身,来具体看一下,

        final int awaitJoin(WorkQueue w, ForkJoinTask<?> task, long deadline) {
            int s = 0;
            if (task != null && w != null) {
                ForkJoinTask<?> prevJoin = w.currentJoin;
                U.putOrderedObject(w, QCURRENTJOIN, task);//记录当前等待的任务
                CountedCompleter<?> cc = (task instanceof CountedCompleter) ?
                    (CountedCompleter<?>)task : null;
                for (;;) {
                    if ((s = task.status) < 0) //任务完成则直接退出
                        break;
                    if (cc != null)
                        helpComplete(w, cc, 0);
                    else if (w.base == w.top || w.tryRemoveAndExec(task))
                        helpStealer(w, task);
                    if ((s = task.status) < 0)
                        break;
                    long ms, ns;
                    if (deadline == 0L)
                        ms = 0L;
                    else if ((ns = deadline - System.nanoTime()) <= 0L)
                        break;
                    else if ((ms = TimeUnit.NANOSECONDS.toMillis(ns)) <= 0L)
                        ms = 1L;
                    if (tryCompensate(w)) {
                        task.internalWait(ms);
                        U.getAndAddLong(this, CTL, AC_UNIT);
                    }
                }
                U.putOrderedObject(w, QCURRENTJOIN, prevJoin);
            }
            return s;
        }
    

    这里比较重要的是tryRemoveAndExechelpStealertryCompensate几个方法。
    首先来看tryRemoveAndExec方法,

            final boolean tryRemoveAndExec(ForkJoinTask<?> task) {
                ForkJoinTask<?>[] a; int m, s, b, n;
                if ((a = array) != null && (m = a.length - 1) >= 0 &&
                    task != null) {
                    while ((n = (s = top) - (b = base)) > 0) {
                        for (ForkJoinTask<?> t;;) {      // 从s遍历到b
                            long j = ((--s & m) << ASHIFT) + ABASE;
                            if ((t = (ForkJoinTask<?>)U.getObject(a, j)) == null)
                                return s + 1 == top;     // shorter than expected
                            else if (t == task) {
                                boolean removed = false;
                                if (s + 1 == top) {      // pop
                                    if (U.compareAndSwapObject(a, j, task, null)) {
                                        U.putOrderedInt(this, QTOP, s);
                                        removed = true;
                                    }
                                }
                                else if (base == b)      // replace with proxy
                                    removed = U.compareAndSwapObject(
                                        a, j, task, new EmptyTask());
                                if (removed)
                                    task.doExec();
                                break;
                            }
                            else if (t.status < 0 && s + 1 == top) {
                                if (U.compareAndSwapObject(a, j, t, null))
                                    U.putOrderedInt(this, QTOP, s);
                                break;                  // was cancelled
                            }
                            if (--n == 0)
                                return false;
                        }
                        if (task.status < 0)
                            return false;
                    }
                }
                return true;
            }
    

    该方法主要做的是去自己的队列中进行遍历,看看任务是否在top位置,在的话直接取出执行,若在队列中间,则用new EmptyTask()替换之,并取出任务执行。方法返回时若任务未执行完,则不进行后续的help动作。
    接着来看一下helpStealer方法,

        private void helpStealer(WorkQueue w, ForkJoinTask<?> task) {
            WorkQueue[] ws = workQueues;
            int oldSum = 0, checkSum, m;
            if (ws != null && (m = ws.length - 1) >= 0 && w != null &&
                task != null) {
                do {                                       // restart point
                    checkSum = 0;                          // for stability check
                    ForkJoinTask<?> subtask;
                    WorkQueue j = w, v;                    // v是子任务的stealer
                    descent: for (subtask = task; subtask.status >= 0; ) {
                        for (int h = j.hint | 1, k = 0, i; ; k += 2) {
                            if (k > m)                     // can't find stealer
                                break descent;
                            if ((v = ws[i = (h + k) & m]) != null) {
                                if (v.currentSteal == subtask) {
                                    j.hint = i;
                                    break;
                                }
                                checkSum += v.base;
                            }
                        }
                        for (;;) {                         // 帮助v执行任务
                            ForkJoinTask<?>[] a; int b;
                            checkSum += (b = v.base);
                            ForkJoinTask<?> next = v.currentJoin;
                            if (subtask.status < 0 || j.currentJoin != subtask ||
                                v.currentSteal != subtask) // stale
                                break descent;
                            if (b - v.top >= 0 || (a = v.array) == null) {
                                if ((subtask = next) == null)
                                    break descent;
                                j = v;
                                break;
                            }
                            int i = (((a.length - 1) & b) << ASHIFT) + ABASE;
                            ForkJoinTask<?> t = ((ForkJoinTask<?>)
                                                 U.getObjectVolatile(a, i));
                            if (v.base == b) {
                                if (t == null)             // stale
                                    break descent;
                                if (U.compareAndSwapObject(a, i, t, null)) {
                                    v.base = b + 1;
                                    ForkJoinTask<?> ps = w.currentSteal;
                                    int top = w.top;
                                    do {
                                        U.putOrderedObject(w, QCURRENTSTEAL, t);
                                        t.doExec();        // 清空本地任务
                                    } while (task.status >= 0 &&
                                             w.top != top &&
                                             (t = w.pop()) != null);
                                    U.putOrderedObject(w, QCURRENTSTEAL, ps);
                                    if (w.base != w.top)
                                        return;            // 自己的队列不为空了不再进行help操作
                                }
                            }
                        }
                    }
                } while (task.status >= 0 && oldSum != (oldSum = checkSum));
            }
        }
    

    该方法很长,做的事情主要是找到偷取自己任务的WorkQueue,去偷取它的任务执行。直到自己的队列不为空了,则不再进行help操作。
    最后来看一下tryCompensate方法

        private boolean tryCompensate(WorkQueue w) {
            boolean canBlock;
            WorkQueue[] ws; long c; int m, pc, sp;
            if (w == null || w.qlock < 0 ||           // caller terminating
                (ws = workQueues) == null || (m = ws.length - 1) <= 0 ||
                (pc = config & SMASK) == 0)           // parallelism disabled
                canBlock = false;
            else if ((sp = (int)(c = ctl)) != 0)      // 释放空闲线程
                canBlock = tryRelease(c, ws[sp & m], 0L);
            else {
                int ac = (int)(c >> AC_SHIFT) + pc;
                int tc = (short)(c >> TC_SHIFT) + pc;
                int nbusy = 0;                        // validate saturation
                for (int i = 0; i <= m; ++i) {        // two passes of odd indices
                    WorkQueue v;
                    if ((v = ws[((i << 1) | 1) & m]) != null) {
                        if ((v.scanState & SCANNING) != 0)
                            break;
                        ++nbusy;
                    }
                }
                if (nbusy != (tc << 1) || ctl != c)
                    canBlock = false;                 // unstable or stale
                else if (tc >= pc && ac > 1 && w.isEmpty()) {
                    long nc = ((AC_MASK & (c - AC_UNIT)) |
                               (~AC_MASK & c));       // uncompensated
                    canBlock = U.compareAndSwapLong(this, CTL, c, nc);
                }
                else if (tc >= MAX_CAP ||
                         (this == common && tc >= pc + commonMaxSpares))
                    throw new RejectedExecutionException(
                        "Thread limit exceeded replacing blocked worker");
                else {                                // similar to tryAddWorker
                    boolean add = false; int rs;      // CAS within lock
                    long nc = ((AC_MASK & c) |
                               (TC_MASK & (c + TC_UNIT)));
                    if (((rs = lockRunState()) & STOP) == 0)
                        add = U.compareAndSwapLong(this, CTL, c, nc);
                    unlockRunState(rs, rs & ~RSLOCK);
                    canBlock = add && createWorker(); // throws on exception
                }
            }
            return canBlock;
        }
    

    该方法尝试减少活跃线程,也会由于任务阻塞释放或者创建补偿线程。到此整个流程基本完整了。

    总结

    本文从数据累加的demo开始,将整个执行流程在源码层面进行了一个大概的串联,由于本人能力有限在许多标志位的使用及位运算的细节方面并没有了解的很深入,仔细去推敲的话其实里面还有很多东西可以去挖,看完源码之后不得不感叹Doug Lea大神的厉害。

    相关文章

      网友评论

          本文标题:ForkJoin源码解析

          本文链接:https://www.haomeiwen.com/subject/ayburctx.html