介绍CountDownLatch之前,我相信很多人在学习的时候是不清楚这个CountDownLatch的使用场景是啥。为了回答这个问题,简单说个小段子。
老李家有两个熊孩子小A和小B,小A和小B每天放学后自己回家,到家后都需要老李来开门,不要问我为啥不给小A和小B一把钥匙。由于不是一个年级的,放学的时间不同,每天都需要老李开两次门,有一天老李怒了,告诉两个熊孩子,以后到家了必须敲下门,在门口喊一声,老李听到两个孩子的敲门声再去敲门,不要问我小A和小B是亲生的不。
其实,上面这个例子就是CountDownLatch的使用场景,小A和小B到家时间不同相当于两个线程的执行时间不同,小A和小B每次回家必须喊一次相当于线程间的通信,老李只有听到两个孩子的敲门声才会去敲门相当于主线程不再阻塞,向下进行。
再举个最近项目中的使用场景。
最近在做图像识别的一个项目,需要上传图片到华为云的modelart服务来获取图片的识别信息,然后对返回信息进行处理,分析出想要的信息。
由于有些产品是需要同时上传两张图片,然后再根据返回的信息进行处理。上传一张图片等待返回信息这个过程的时间大概是3-5秒,上传两张图片,需要访问两次华为云modelart服务,如果使用串行方式的话,那么需要花费10s左右的时间,这里就想到了可以使用CountDownLatch,等待这两个上传操作的线程结束拿到返回信息后,再调用后面的接口来分析这两个图片的信息。
这里,就简单介绍完了CountDownLatch的使用场景,下面简单说下CountDownLatch的使用,直接给出CountDownLatch源码中的例子。
* class Driver2 { // ...
* void main() throws InterruptedException {
* CountDownLatch doneSignal = new CountDownLatch(N);
* Executor e = ...
*
* for (int i = 0; i < N; ++i) // create and start threads
* e.execute(new WorkerRunnable(doneSignal, i));
*
* doneSignal.await(); // wait for all to finish
* }
* }
*
* class WorkerRunnable implements Runnable {
* private final CountDownLatch doneSignal;
* private final int i;
* WorkerRunnable(CountDownLatch doneSignal, int i) {
* this.doneSignal = doneSignal;
* this.i = i;
* }
* public void run() {
* try {
* doWork(i);
* doneSignal.countDown();
* } catch (InterruptedException ex) {} // return;
* }
*
* void doWork() { ... }
* }}
首先看main方法,一开始根据需要等待的线程数,初始化CountDownLatch,然后启动线程,线程结束后调用CountDownLatch的countDown方法,当调用countDownLatch的counDown次数和初始化CountDownLatch的线程数相同时,主线程中的CountDownLatch的await方法不再阻塞,往下进行。
使用很简单,主要看源码实现。
CountDownLatch的底层实现是使用AQS队列实现,对AQS的不熟悉的同学可以看下方腾飞的《java并发编程的艺术》这本书或者看下这个AQS。
首先看下await方法。
public void await() throws InterruptedException {
sync.acquireSharedInterruptibly(1);
}
sync这个实例是什么类型的呢
public class CountDownLatch {
/**
* Synchronization control For CountDownLatch.
* Uses AQS state to represent count.
*/
private static final class Sync extends AbstractQueuedSynchronizer {
private static final long serialVersionUID = 4982264981922014374L;
Sync(int count) {
setState(count);
}
int getCount() {
return getState();
}
protected int tryAcquireShared(int acquires) {
return (getState() == 0) ? 1 : -1;
}
protected boolean tryReleaseShared(int releases) {
// Decrement count; signal when transition to zero
for (;;) {
int c = getState();
if (c == 0)
return false;
int nextc = c-1;
if (compareAndSetState(c, nextc))
return nextc == 0;
}
}
}
private final Sync sync;
}
Sync类继承了AbstractQueuedSynchronizer(AQS), 通过state值的大小来控制锁的获取。下面根据CountDownLatch的使用来分析下源码。
(1)创建CountDownLatch实例时。
public CountDownLatch(int count) {
if (count < 0) throw new IllegalArgumentException("count < 0");
this.sync = new Sync(count);
}
这里就可以很清楚的看到,这里会初始化AQS队列的state值的大小,state值其实就是需要等待线程数的大小。
(2)主线程调用CountDownLatch的await方法,阻塞主线程,等待其他线程执行结束。
public final void acquireSharedInterruptibly(int arg)
throws InterruptedException {
if (Thread.interrupted())
throw new InterruptedException();
if (tryAcquireShared(arg) < 0)
doAcquireSharedInterruptibly(arg);
}
protected int tryAcquireShared(int acquires) {
return (getState() == 0) ? 1 : -1;
}
首先调用tryAcquireShared获取当前state的值,如果值为0返回1,说明其他线程执行结束,不再阻塞。如果值不为0,则返回-1,说明其他线程还未执行结束,需要调用doAcquireSharedInterruptibly方法阻塞等待。
下面看下这个方法的实现。
private void doAcquireSharedInterruptibly(int arg)
throws InterruptedException {
final Node node = addWaiter(Node.SHARED); ##队列中插入node节点,保存线程信息
boolean failed = true;
try {
for (;;) {
final Node p = node.predecessor(); ##获取node节点的前一个节点
if (p == head) { ## 判断p节点是否是头结点
int r = tryAcquireShared(arg); ##获取state值得大小
if (r >= 0) { ## r>=0 说明state值为0
setHeadAndPropagate(node, r); ##设置头结点并且触发队列中头结点的下一个节点是否是共享节点,如果是的话,下个节点对应的线程也不再阻塞,具有传播特性。
p.next = null; // help GC
failed = false;
return;
}
}
if (shouldParkAfterFailedAcquire(p, node) &&
parkAndCheckInterrupt()) ## 阻塞调用此方法的线程
throw new InterruptedException();
}
} finally {
if (failed)
cancelAcquire(node);
}
}
上面的注释已经说明上面方法中整个的处理过程,其中setHeadAndPropagate和shouldParkAfterFailedAcquire还需要详细分析一下,首先看下setHeadAndPropagate方法。
private void setHeadAndPropagate(Node node, int propagate) {
Node h = head; // Record old head for check below
setHead(node);
if (propagate > 0 || h == null || h.waitStatus < 0 ||
(h = head) == null || h.waitStatus < 0) {
Node s = node.next;
if (s == null || s.isShared())
doReleaseShared();
}
}
执行此方法的前提是node的前一个节点是head节点,并且state值为0。在这个方法里,首先将当前的node节点设置为head节点,然后根据propagate这个值的大小,判断是否获取node节点的下一个节点,然后根据下一个节点是否是共享式类型的节点,来释放下个节点对应的线程,使下个节点的线程也不再阻塞,propagate使线程的释放具有了传播性,从队列的头结点开始,只要头结点不再阻塞,也可以使队列中的其他共享节点也不再阻塞,具有了传播性。
然后看下shouldParkAfterFailedAcquire方法的实现。
private static boolean shouldParkAfterFailedAcquire(Node pred, Node node) {
int ws = pred.waitStatus;
if (ws == Node.SIGNAL)
return true;
if (ws > 0) {
do {
node.prev = pred = pred.prev;
} while (pred.waitStatus > 0);
pred.next = node;
} else {
compareAndSetWaitStatus(pred, ws, Node.SIGNAL);
}
return false;
}
这个方法的目的主要是获取state值不为0时,是否阻塞此线程。如果此方法返回true则会调用parkAndCheckInterrupt这个方法,在这个方法里调用LockSupport的park方法阻塞此线程。那么阻塞后,什么时候唤醒这个线程呢,想要解决这个疑问就需要看下CountDownLatch的countDown方法的处理逻辑了。
(3) 线程执行完,调用CountDownLatch的countDown方法。
public void countDown() {
sync.releaseShared(1);
}
public final boolean releaseShared(int arg) {
if (tryReleaseShared(arg)) {
doReleaseShared();
return true;
}
return false;
}
protected boolean tryReleaseShared(int releases) {
// Decrement count; signal when transition to zero
for (;;) {
int c = getState();
if (c == 0)
return false;
int nextc = c-1;
if (compareAndSetState(c, nextc))
return nextc == 0;
}
}
首先,在tryReleaseShared方法中将state值的大小减一,然后执行doReleaseShared方法,
private void doReleaseShared() {
for (;;) {
Node h = head;
if (h != null && h != tail) {
int ws = h.waitStatus;
if (ws == Node.SIGNAL) {
if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))
continue; // loop to recheck cases
unparkSuccessor(h);
}
else if (ws == 0 &&
!compareAndSetWaitStatus(h, 0, Node.PROPAGATE))
continue; // loop on failed CAS
}
if (h == head) // loop if head changed
break;
}
}
在doReleaseShared方法中通过unparkSuccessor获取head节点的下一个节点的thread信息,然后执行LockSupport的unpark方法,这样的话之前await方法中阻塞的线程就不再阻塞,继续往下执行。
通过研究CountDownLatch的这三个方法,基本理解了底层实现,另外,如果能看懂这几个方法的源码,其实对AQS的源码也已经了解的差不多了,后面可以去看下Lock的源码,也是基于AQS实现的。