概述
ForkJoinPool是Doug Lea在JDK 1.7中加入的。为了充分利用多核CPU的计算能力,它采用分治算法,创建多个线程、多个队列,让不同线程处理不同的队列;线程处理完自己的任务后,还会窃取其他线程的任务,以达到充分利用CPU的目的。ForkJoinPool的使用场景很多,特别是JDK 1.8中添加的并行流(parallel stream)和异步处理类CompletableFuture等都用到了它。该类比较复杂,我们要在战术上重视它:耐下心来看,并暂时放弃一些细枝末节,先通览整个流程。同时要在战略上藐视它:前面介绍了普通线程池和定时调度线程池,我们已经熟悉了套路(最简单的流程:任务提交到线程池->线程池创建线程->启动线程->线程的run方法中再调用任务的run方法)。ForkJoinPool也是线程池,大体逻辑相同。
看一下ForkJoinTask流程图
ForkJoinPool使用例子
例子依然可以在github中找到
/**
 * Demonstrates splitting a summation into {@link RecursiveTask} subtasks that run on a
 * {@link ForkJoinPool}.
 *
 * <p>The sum of 0..1,000,000 is 500,000,500,000, which does not fit in an {@code int}, so the
 * task accumulates and returns {@code long} values.
 */
public class ForkJoinPoolTest {

    /**
     * Segments whose index span ({@code to - from}) is at most this value are summed directly
     * instead of being split further. Kept deliberately tiny so the demo creates many subtasks.
     */
    private static final int THRESHOLD = 2;

    public static void main(String[] args) throws ExecutionException, InterruptedException {
        ForkJoinPool forkJoinPool = new ForkJoinPool();
        try {
            // IntStream.parallel().sum() would give the result more easily (and also runs on a
            // ForkJoinPool); this example exists to exercise ForkJoinTask decomposition.
            int[] numbers = IntStream.rangeClosed(0, 1_000_000).toArray();
            long begin = System.currentTimeMillis();
            ForkJoinTask<Long> submit =
                    forkJoinPool.submit(new SumTask(numbers, 0, numbers.length - 1));
            System.out.println("累加结果为:" + submit.get());
            System.out.println("运算耗时:" + (System.currentTimeMillis() - begin));
        } finally {
            // release the pool's worker threads once the demo is done
            forkJoinPool.shutdown();
        }
    }

    /**
     * Sums all elements of {@code numbers} on {@code pool} using {@link SumTask}.
     *
     * @param pool    the pool to run the computation on
     * @param numbers the values to add; an empty array yields 0
     * @return the exact sum as a {@code long}
     */
    static long parallelSum(ForkJoinPool pool, int[] numbers) {
        return pool.invoke(new SumTask(numbers, 0, numbers.length - 1));
    }

    /** Sums {@code numbers[from..to]} (inclusive) by recursive halving. */
    private static final class SumTask extends RecursiveTask<Long> {
        private final int[] numbers;
        private final int from;
        private final int to;

        SumTask(int[] numbers, int from, int to) {
            this.numbers = numbers;
            this.from = from;
            this.to = to;
        }

        @Override
        protected Long compute() {
            if (to - from <= THRESHOLD) {
                // small enough: add sequentially (a negative span, i.e. an empty range, yields 0)
                long total = 0;
                for (int i = from; i <= to; i++) {
                    total += numbers[i];
                }
                return total;
            }
            // unsigned shift avoids int overflow of (from + to) for very large indices
            int middle = (from + to) >>> 1;
            SumTask taskLeft = new SumTask(numbers, from, middle);
            SumTask taskRight = new SumTask(numbers, middle + 1, to);
            taskLeft.fork();
            taskRight.fork();
            // Join the most recently forked task first: taskRight sits on top of this worker's
            // queue and can be popped and run directly. Writing taskLeft.join() + taskRight.join()
            // is noticeably slower, as discussed in the accompanying text.
            return taskRight.join() + taskLeft.join();
        }
    }
}
结果就不展示了,就是计算累加的和。这里有个注意点:在compute方法中,对子任务执行taskLeft.fork()、taskRight.fork()后,先执行taskRight.join()再加上taskLeft.join()。如果反过来写,会发现耗时将近多出一倍。为什么会这样?我们先留个疑问,后面揭晓。
提交任务
submit方法
// Submits an existing ForkJoinTask to this pool and returns that same task, which doubles as its future.
public <T> ForkJoinTask<T> submit(ForkJoinTask<T> task) {
// null tasks are rejected with NullPointerException
if (task == null)
throw new NullPointerException();
// enqueue from outside the pool (see externalPush below)
externalPush(task);
return task;
}
// Fast-path enqueue of a task submitted from a non-worker thread: tries to lock one existing
// queue chosen by the caller's probe and push onto it; otherwise falls back to externalSubmit.
final void externalPush(ForkJoinTask<?> task) {
WorkQueue[] ws; WorkQueue q; int m;
// probe: per-thread hash value from ThreadLocalRandom; 0 means it has not been initialized yet
int r = ThreadLocalRandom.getProbe();
int rs = runState;
// Fast path: succeeds only if the queue array exists, the slot selected by r is already
// populated, the probe is initialized (r != 0), runState is positive, and the queue lock
// can be taken by CAS. Any failure falls through to the full externalSubmit.
if ((ws = workQueues) != null && (m = (ws.length - 1)) >= 0 &&
(q = ws[m & r & SQMASK]) != null && r != 0 && rs > 0 && // slot chosen by the probe; per the article SQMASK restricts it to even (shared) indices
U.compareAndSwapInt(q, QLOCK, 0, 1)) {// lock the chosen WorkQueue
ForkJoinTask<?>[] a; int am, n, s;
// push only if the array exists and has room (n = queue size before this push)
if ((a = q.array) != null &&
(am = a.length - 1) > (n = (s = q.top) - q.base)) {
int j = ((am & s) << ASHIFT) + ABASE;
U.putOrderedObject(a, j, task); // store the task at slot top
U.putOrderedInt(q, QTOP, s + 1);
U.putIntVolatile(q, QLOCK, 0); // release the queue lock
if (n <= 1) // queue held at most one task before this push: wake an idle worker or create a new one
signalWork(ws, q);
return;
}
// array missing or full: release the lock and take the slow path
U.compareAndSwapInt(q, QLOCK, 1, 0);
}
// Full enqueue. The first submit from a given external thread always lands here,
// because its probe r is still 0.
externalSubmit(task);
}
externalSubmit方法
// Full (slow-path) submission, used when externalPush's fast path fails. Loops until the task is
// enqueued. Along the way it initializes the caller's probe, lazily creates the workQueues array,
// and creates the shared queue for the probe's slot if it is missing. It then pushes under that
// queue's lock. Throws RejectedExecutionException when runState is negative (pool terminating).
private void externalSubmit(ForkJoinTask<?> task) {
int r; // initialize callers probe
// initialize the calling thread's probe if it is still 0
if ((r = ThreadLocalRandom.getProbe()) == 0) {
ThreadLocalRandom.localInit();
r = ThreadLocalRandom.getProbe();
}
// Retry loop: each pass performs at most one step (create the array, create the queue, or push),
// so the array, the queue and finally the enqueued task are all in place before it returns.
for (;;) {
WorkQueue[] ws; WorkQueue q; int rs, m, k;
boolean move = false;
// negative runState: pool is terminating; help terminate and reject the task
if ((rs = runState) < 0) {
tryTerminate(false, false); // help terminate
throw new RejectedExecutionException();
}
else if ((rs & STARTED) == 0 || // initialize: create the WorkQueue[] array
((ws = workQueues) == null || (m = ws.length - 1) < 0)) {
int ns = 0;
rs = lockRunState();
try {
// re-check under the runState lock so only one thread creates the array
if ((rs & STARTED) == 0) {
U.compareAndSwapObject(this, STEALCOUNTER, null,
new AtomicLong());
// create workQueues array with size a power of two
int p = config & SMASK; // ensure at least 2 slots
int n = (p > 1) ? p - 1 : 1;
// round n up to a power of two, then double it
n |= n >>> 1; n |= n >>> 2; n |= n >>> 4;
n |= n >>> 8; n |= n >>> 16; n = (n + 1) << 1;
workQueues = new WorkQueue[n];
ns = STARTED;
}
} finally {
// release the runState lock and, if the array was created here, set STARTED
unlockRunState(rs, (rs & ~RSLOCK) | ns);
}
}
else if ((q = ws[k = r & m & SQMASK]) != null) { // a queue exists at the probe's slot: push the task
if (q.qlock == 0 && U.compareAndSwapInt(q, QLOCK, 0, 1)) {
ForkJoinTask<?>[] a = q.array;
int s = q.top;
boolean submitted = false; // initial submission or resizing
try { // locked version of push
// push if there is room; otherwise grow (or allocate) the array first
if ((a != null && a.length > s + 1 - q.base) ||
(a = q.growArray()) != null) {
int j = (((a.length - 1) & s) << ASHIFT) + ABASE;
U.putOrderedObject(a, j, task);
U.putOrderedInt(q, QTOP, s + 1);
submitted = true;
}
} finally {
U.compareAndSwapInt(q, QLOCK, 1, 0);
}
if (submitted) { // enqueued: wake an idle worker or create a new one to process it
signalWork(ws, q);
return;
}
}
move = true; // move on failure
}
else if (((rs = runState) & RSLOCK) == 0) { // create new queue: no queue at this slot yet
q = new WorkQueue(this, null);
q.hint = r;
q.config = k | SHARED_QUEUE;
q.scanState = INACTIVE;
rs = lockRunState(); // publish index
if (rs > 0 && (ws = workQueues) != null &&
k < ws.length && ws[k] == null)
ws[k] = q; // else terminated
unlockRunState(rs, rs & ~RSLOCK);
}
else
move = true; // move if busy
// lock contended, push failed or runState busy: advance the probe so the next pass may pick another slot
if (move)
r = ThreadLocalRandom.advanceProbe(r);
}
}
可以看到不管是快速入队方法,还是完整入队方法,入队成功后都会调用signalWork方法。
signalWork方法
// Tries to get a worker to process newly enqueued work. While ctl < 0 (too few active workers),
// it activates the worker on top of ctl's idle stack or, if none is idle, tries to add a worker.
// ws: the pool's queue array; q: the queue that just received work (may be null).
final void signalWork(WorkQueue[] ws, WorkQueue q) {
long c; int sp, i; WorkQueue v; Thread p;
// Loop: wake an idle worker if there is one; if there are too few workers, create one.
while ((c = ctl) < 0L) { // too few active
if ((sp = (int)c) == 0) { // no idle workers
if ((c & ADD_WORKER) != 0L) // too few workers
tryAddWorker(c); // too few workers: try to start a new worker thread
break;
}
if (ws == null) // unstarted/terminated
break;
if (ws.length <= (i = sp & SMASK)) // terminated
break;
if ((v = ws[i]) == null) // terminating
break;
// v: idle worker whose index is held in the low 32 bits of ctl (top of the idle stack)
int vs = (sp + SS_SEQ) & ~INACTIVE; // next scanState
int d = sp - v.scanState; // screen CAS
// new ctl: one more active worker, and v's predecessor (stackPred) becomes the new stack top
long nc = (UC_MASK & (c + AC_UNIT)) | (SP_MASK & v.stackPred);
if (d == 0 && U.compareAndSwapLong(this, CTL, c, nc)) {
v.scanState = vs; // activate v
if ((p = v.parker) != null)
U.unpark(p); // wake the parked worker thread
break;
}
// CAS lost: stop if the queue has already been drained, otherwise retry
if (q != null && q.base == q.top) // no more work
break;
}
}
我们看看新建线程方法
// Tries to add one worker. It CASes ctl to increment both the active (AC) and total (TC) counts.
// If that succeeds and the pool is not stopping, it calls createWorker.
// c: the ctl value observed by the caller.
private void tryAddWorker(long c) {
boolean add = false;
// Same pattern as before: first CAS ctl to count the new worker, then create the thread.
do {
long nc = ((AC_MASK & (c + AC_UNIT)) |
(TC_MASK & (c + TC_UNIT)));
if (ctl == c) {
int rs, stop; // check if terminating
if ((stop = (rs = lockRunState()) & STOP) == 0)
add = U.compareAndSwapLong(this, CTL, c, nc);
unlockRunState(rs, rs & ~RSLOCK);
if (stop != 0)
break;
if (add) {
createWorker(); // create and start the new worker thread
break;
}
}
// retry while more workers are still wanted (ADD_WORKER set) and none are idle
} while (((c = ctl) & ADD_WORKER) != 0L && (int)c == 0);
}
createWorker 方法
// Creates and starts a worker thread via the pool's factory and returns true on success.
// On failure (null factory, null thread, or an exception) it rolls back through
// deregisterWorker and returns false.
private boolean createWorker() {
ForkJoinWorkerThreadFactory fac = factory;
Throwable ex = null;
ForkJoinWorkerThread wt = null;
try {
// Same pattern as ThreadPoolExecutor: if the thread is created, start it and return;
// otherwise fall through to deregisterWorker to roll back.
if (fac != null && (wt = fac.newThread(this)) != null) {
wt.start();
return true;
}
} catch (Throwable rex) {
ex = rex;
}
// Undo registration: the counterpart of registerWorker, which fac.newThread triggers.
// Per the article, this decrements ctl, removes the worker's queue, and may call
// tryAddWorker again if workers are too few (deregisterWorker is not shown here).
deregisterWorker(wt, ex);
return false;
}
我们看看ForkJoinWorkerThreadFactory.newThread做了什么?
ForkJoinWorkerThreadFactory.newThread方法
// Factory method quoted in the article: creates a ForkJoinWorkerThread for the given pool.
// The thread's constructor registers its work queue with the pool.
public final ForkJoinWorkerThread newThread(ForkJoinPool pool) {
return new ForkJoinWorkerThread(pool);
}
// ForkJoinWorkerThread constructor: registers this thread's own workQueue in the pool's
// WorkQueue[] array (via registerWorker) and keeps the returned queue.
protected ForkJoinWorkerThread(ForkJoinPool pool) {
// Use a placeholder until a useful name can be set in registerWorker
super("aForkJoinWorkerThread");
this.pool = pool;
this.workQueue = pool.registerWorker(this);
}
// Registers a new worker thread and returns its new WorkQueue. It configures the thread (daemon,
// uncaught-exception handler) and creates the queue. The queue is stored at an odd index of
// workQueues, doubling the array if no free slot is found. The thread is named after that index.
final WorkQueue registerWorker(ForkJoinWorkerThread wt) {
UncaughtExceptionHandler handler;
wt.setDaemon(true); // configure thread
if ((handler = ueh) != null)
wt.setUncaughtExceptionHandler(handler);
// the worker thread's own WorkQueue
WorkQueue w = new WorkQueue(this, wt);
int i = 0; // assign a pool index
int mode = config & MODE_MASK;
int rs = lockRunState();
try {
WorkQueue[] ws; int n; // skip if no array
if ((ws = workQueues) != null && (n = ws.length) > 0) {
int s = indexSeed += SEED_INCREMENT; // unlikely to collide
int m = n - 1;
// derive an odd index from the seed
i = ((s << 1) | 1) & m; // odd-numbered indices
if (ws[i] != null) { // collision
int probes = 0; // step by approx half n
int step = (n <= 4) ? 2 : ((n >>> 1) & EVENMASK) + 2;
while (ws[i = (i + step) & m] != null) {
// after n probes without finding a free slot, double the array and keep probing
if (++probes >= n) {
workQueues = ws = Arrays.copyOf(ws, n <<= 1);
m = n - 1;
probes = 0;
}
}
}
w.hint = s; // use as random seed
w.config = i | mode;
w.scanState = i; // publication fence
// publish the worker's workQueue at the chosen odd index of the pool's array
ws[i] = w;
}
} finally {
unlockRunState(rs, rs & ~RSLOCK);
}
// thread name = pool's name prefix + (index / 2)
wt.setName(workerNamePrefix.concat(Integer.toString(i >>> 1)));
return w;
}
上面我们看到createWorker方法中,线程创建成功后,会进行thread.start,我们照旧看ForkJoinWorkerThread类的run方法吧。
ForkJoinWorkerThread.run 方法
// Worker thread body (ForkJoinWorkerThread.run): calls onStart and runs the pool's worker loop.
// It then calls onTermination and always deregisters from the pool, passing along the first
// exception thrown, if any.
public void run() {
// workQueue.array is still null only before runWorker's growArray() has run, so this body runs at most once
if (workQueue.array == null) { // only run once
Throwable exception = null;
try {
onStart();
pool.runWorker(workQueue);
} catch (Throwable ex) {
exception = ex;
} finally {
try {
onTermination(exception);
} catch (Throwable ex) {
// keep the earlier exception; record this one only if none occurred before
if (exception == null)
exception = ex;
} finally {
pool.deregisterWorker(this, exception);
}
}
}
}
run方法又调用了ForkJoinPool的runWorker方法
// Top-level loop of a worker. It allocates the worker's task array, then repeatedly scans for a
// task and runs it. When nothing is found it waits via awaitWork; the loop ends when awaitWork
// returns false.
final void runWorker(WorkQueue w) {
// allocate the task array
w.growArray(); // allocate queue
int seed = w.hint; // initially holds randomization hint
int r = (seed == 0) ? 1 : seed; // avoid 0 for xorShift
for (ForkJoinTask<?> t;;) {
// scan the queues, starting at a position derived from r, and try to steal a task
if ((t = scan(w, r)) != null)
w.runTask(t); // run the task
else if (!awaitWork(w, r)) // nothing found: wait for work; false ends this worker's loop
break;
r ^= r << 13; r ^= r >>> 17; r ^= r << 5; // xorshift
}
}
// Scans the pool's queues, starting at an index derived from r, and tries to take (steal) the task
// at the base of a non-empty queue. A full sweep may find nothing with the same checksum of base
// indices as the previous sweep; w then tries to inactivate itself by becoming the top of ctl's
// idle stack. It returns null once w is already inactive and a stable sweep still finds nothing.
private ForkJoinTask<?> scan(WorkQueue w, int r) {
WorkQueue[] ws; int m;
// scan only if the queue array exists (with more than one slot) and w is non-null
if ((ws = workQueues) != null && (m = ws.length - 1) > 0 && w != null) {
int ss = w.scanState; // initially non-negative
for (int origin = r & m, k = origin, oldSum = 0, checkSum = 0;;) {
WorkQueue q; ForkJoinTask<?>[] a; ForkJoinTask<?> t;
int b, n; long c;
if ((q = ws[k]) != null) {// queue at index k
if ((n = (b = q.base) - q.top) < 0 &&
(a = q.array) != null) { // non-empty
long i = (((a.length - 1) & b) << ASHIFT) + ABASE;
if ((t = ((ForkJoinTask<?>) // read the task at the base slot
U.getObjectVolatile(a, i))) != null &&
q.base == b) {
if (ss >= 0) {
// w is active: claim the task by CAS-ing its slot to null
if (U.compareAndSwapObject(a, i, t, null)) {
q.base = b + 1; // advance base past the taken task
if (n < -1) // signal others
signalWork(ws, q); // tasks remain: wake an idle worker or create one to help
return t;
}
}
else if (oldSum == 0 && // try to activate
w.scanState < 0)
tryRelease(c = ctl, ws[m & (int)c], AC_UNIT);
}
if (ss < 0) // refresh
ss = w.scanState;
// a non-empty queue was found but no task was taken: pick a new random start and rescan
r ^= r << 1; r ^= r >>> 3; r ^= r << 10;
origin = k = r & m; // move and rescan
oldSum = checkSum = 0;
continue;
}
checkSum += b;
}
// advance to the next index k and keep looking
if ((k = (k + 1) & m) == origin) { // continue until stable
// a full sweep of workQueues completed without taking a task
if ((ss >= 0 || (ss == (ss = w.scanState))) &&
oldSum == (oldSum = checkSum)) {
if (ss < 0 || w.qlock < 0) // already inactive
break;
// inactivate w: set INACTIVE, make w the idle-stack top, decrement the active count
int ns = ss | INACTIVE; // try to inactivate
long nc = ((SP_MASK & ns) |
(UC_MASK & ((c = ctl) - AC_UNIT)));
w.stackPred = (int)c; // hold prev stack top
U.putInt(w, QSCANSTATE, ns);
if (U.compareAndSwapLong(this, CTL, c, nc))
ss = ns;
else
w.scanState = ss; // back out
}
checkSum = 0;
}
}
}
return null;
}
扫描到任务以后,会调用工作队列WorkQueue的runTask方法来执行该任务
// WorkQueue.runTask: runs a task obtained by scan, then calls execLocalTasks, and updates the steal
// count. The SCANNING bit of scanState is cleared while busy and set again afterwards.
final void runTask(ForkJoinTask<?> task) {
if (task != null) {
scanState &= ~SCANNING; // mark as busy
// record the task as currentSteal and run it via its doExec method
(currentSteal = task).doExec();
U.putOrderedObject(this, QCURRENTSTEAL, null); // release for GC
// presumably runs tasks left in this worker's own queue (execLocalTasks is not shown here)
execLocalTasks();
ForkJoinWorkerThread thread = owner;
if (++nsteals < 0) // collect on overflow
transferStealCount(pool);
scanState |= SCANNING;
if (thread != null)
thread.afterTopLevelExec();
}
}
// ForkJoinTask.doExec: if the task is not yet completed (status >= 0), runs exec(). Normal
// completion is recorded when exec() returns true; exceptional completion is recorded if it
// throws. Returns the resulting status.
final int doExec() {
int s; boolean completed;
if ((s = status) >= 0) {
try {
// call exec() and keep its return value in completed
completed = exec();
} catch (Throwable rex) {
return setExceptionalCompletion(rex);
}
if (completed)
s = setCompletion(NORMAL);
}
return s;
}
到了这里,终于快看到测试例子中覆写的compute方法了。我们看下例子中继承的RecursiveTask类的exec方法:
// RecursiveTask.exec: runs the user's compute() (e.g. SumTask.compute in the example) and stores
// its value in result. It returns true so that doExec marks the task as completed normally.
protected final boolean exec() {
result = compute();
return true;
}
小结
上面我们看到:向线程池提交任务时,任务会被放到workQueues数组某个偶数下标的队列中,然后新建一个工作线程;工作线程会初始化自己的workQueue,并放入workQueues数组的奇数下标中。
fork方法
// Schedules this task for asynchronous execution and returns it.
public final ForkJoinTask<V> fork() {
Thread t;
// called from a ForkJoinWorkerThread: push onto that worker's own workQueue
if ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread)
((ForkJoinWorkerThread)t).workQueue.push(this);
else
ForkJoinPool.common.externalPush(this); // called from any other thread: push into the common pool
return this;
}
push 方法
// WorkQueue.push: pushes a task at top without taking QLOCK, unlike externalPush, which locks.
// Presumably only the queue's owning worker calls this (confirm).
// If the queue held at most one task before the push, a worker is signalled; the array is grown
// once it becomes full.
final void push(ForkJoinTask<?> task) {
ForkJoinTask<?>[] a; ForkJoinPool p;
int b = base, s = top, n;
if ((a = array) != null) { // ignore if queue removed
int m = a.length - 1; // fenced write for task visibility
U.putOrderedObject(a, ((m & s) << ASHIFT) + ABASE, task); // store the task at slot top
U.putOrderedInt(this, QTOP, s + 1);
if ((n = s - b) <= 1) {
if ((p = pool) != null)
p.signalWork(p.workQueues, this);
}
else if (n >= m) // array is now full: grow it
growArray();
}
}
compute中调用子任务的fork后,子任务就被压入当前工作线程的队列,然后通过taskRight.join等待子任务处理完成。我们看看join方法的逻辑。
// Waits for the task to complete and returns its result.
// If the completion status is not NORMAL, reportException(s) is called first (presumably it
// rethrows the task's exception or cancellation; reportException is not shown here).
public final V join() {
int s;
if ((s = doJoin() & DONE_MASK) != NORMAL)
reportException(s);
return getRawResult();
}
// Returns the task's status, waiting for or running the task as needed:
// - already completed (status < 0): return it immediately;
// - caller is a ForkJoinWorkerThread: if tryUnpush(this) succeeds (per the article: the task is on
//   top of the worker's own queue), run it right here via doExec. If that path does not produce a
//   completed status, fall back to awaitJoin;
// - any other thread: externalAwaitDone.
private int doJoin() {
int s; Thread t; ForkJoinWorkerThread wt; ForkJoinPool.WorkQueue w;
return (s = status) < 0 ? s :
((t = Thread.currentThread()) instanceof ForkJoinWorkerThread) ?
(w = (wt = (ForkJoinWorkerThread)t).workQueue).
tryUnpush(this) && (s = doExec()) < 0 ? s :
wt.pool.awaitJoin(w, this, 0L) :
externalAwaitDone();
}
awaitJoin方法
// Called by worker w when it joins a task it could not simply pop and run. Records task as
// w.currentJoin and loops until the task completes, or until the deadline passes when deadline != 0
// (doJoin passes 0L, meaning no deadline). On each pass it:
// - helps CountedCompleter tasks;
// - otherwise tries to find and run the task in w's own queue (tryRemoveAndExec) or calls helpStealer;
// - if the task is still not done, calls tryCompensate and, if that returns true, blocks in
//   task.internalWait.
// Finally restores the previous currentJoin and returns the last status read.
final int awaitJoin(WorkQueue w, ForkJoinTask<?> task, long deadline) {
int s = 0;
if (task != null && w != null) {
ForkJoinTask<?> prevJoin = w.currentJoin;
U.putOrderedObject(w, QCURRENTJOIN, task);
CountedCompleter<?> cc = (task instanceof CountedCompleter) ?
(CountedCompleter<?>)task : null;
for (;;) {
if ((s = task.status) < 0)
break;
// CountedCompleter tasks: helpComplete
if (cc != null)
helpComplete(w, cc, 0);
// Key step: if w's queue is non-empty, tryRemoveAndExec looks for task there and runs it.
// helpStealer is called when the queue is empty or tryRemoveAndExec returned true.
else if (w.base == w.top || w.tryRemoveAndExec(task))
helpStealer(w, task);// presumably helps the worker that stole task (helpStealer is not shown here)
if ((s = task.status) < 0)
break;
long ms, ns;
// wait time: 0 when there is no deadline, otherwise the remaining time in ms (at least 1)
if (deadline == 0L)
ms = 0L;
else if ((ns = deadline - System.nanoTime()) <= 0L)
break;
else if ((ms = TimeUnit.NANOSECONDS.toMillis(ns)) <= 0L)
ms = 1L;
// per the article: wake an idle worker or create a new one to compensate for this thread blocking
if (tryCompensate(w)) {
// block this thread waiting for task
task.internalWait(ms);
// presumably restores the active count that tryCompensate released
U.getAndAddLong(this, CTL, AC_UNIT);
}
}
U.putOrderedObject(w, QCURRENTJOIN, prevJoin);
}
return s;
}
tryRemoveAndExec方法
// WorkQueue.tryRemoveAndExec: walks this queue from top toward base looking for task.
// - If found, the task is removed and executed via doExec. It is popped if it is at the top;
//   otherwise it is replaced by an EmptyTask placeholder, provided base is unchanged.
// - Completed tasks (status < 0) met at the top are popped along the way.
// Returns false when:
// - task has completed;
// - every queued element was examined without finding task;
// - a null slot was met below the top.
// Returns true when the queue is or becomes empty while task is still incomplete, or when a null
// slot was met at the top position.
final boolean tryRemoveAndExec(ForkJoinTask<?> task) {
ForkJoinTask<?>[] a; int m, s, b, n;
if ((a = array) != null && (m = a.length - 1) >= 0 &&
task != null) {
while ((n = (s = top) - (b = base)) > 0) {
// walk the queue; if the subtask is found in it, remove it and call its doExec
for (ForkJoinTask<?> t;;) { // traverse from s to b
long j = ((--s & m) << ASHIFT) + ABASE;
if ((t = (ForkJoinTask<?>)U.getObject(a, j)) == null)
return s + 1 == top; // shorter than expected
else if (t == task) {
boolean removed = false;
if (s + 1 == top) { // pop
if (U.compareAndSwapObject(a, j, task, null)) {
U.putOrderedInt(this, QTOP, s);
removed = true;
}
}
else if (base == b) // replace with proxy
removed = U.compareAndSwapObject(
a, j, task, new EmptyTask());
if (removed)
task.doExec();
break;
}
else if (t.status < 0 && s + 1 == top) { // completed task at the top: pop it and rescan
if (U.compareAndSwapObject(a, j, t, null))
U.putOrderedInt(this, QTOP, s);
break; // was cancelled
}
if (--n == 0) // examined every queued element without finding task
return false;
}
// task completed (by this thread or another): stop
if (task.status < 0)
return false;
}
}
return true;
}
至此整个流程就串起来了:例子中SumTask类的compute方法执行后会创建子任务,子任务.fork()会将任务压入当前工作线程的队列;子任务.join()时,如果子任务仍在当前线程的队列中,就由当前线程直接执行子任务的compute方法。
分析完join方法后,我们就可以回答为什么taskRight.join() + taskLeft.join()更高效了。
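因为调用taskLeft.fork()会将taskLeft入队,taskRight.fork()又会将taskRight入队,此时taskRight位于栈顶。接下来执行taskRight.join()时,taskRight正是栈顶任务,tryUnpush可以直接将其出栈执行,不需要再遍历队列。反过来先执行taskLeft.join()时,taskLeft不在栈顶,tryUnpush失败,只能进入awaitJoin,再由tryRemoveAndExec从栈顶向下遍历队列查找并执行它(由于它不在栈顶,还要用EmptyTask占位)。递归的每一层都多了这些开销,因此更慢。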