Featured image of post Java 并发 - CountDownLatch 实现原理

Java 并发 - CountDownLatch 实现原理

基本实现思路:

  • 利用 共享锁 实现

  • 初始化时,state=count 即已经上了 count 次共享锁

  • await() 即加共享锁,必须 state=0 时才能加锁成功,否则按照 AQS 机制,进入等待队列阻塞

  • countDown() 解一次锁,直到为 0

源码

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
public class CountDownLatch {

    // 内部类实现 AQS
    private static final class Sync extends AbstractQueuedSynchronizer {
        private static final long serialVersionUID = 4982264981922014374L;

        Sync(int count) {  // 使用 AQS 的 state 作为计数器
            setState(count);
        }

        int getCount() {
            return getState();
        }

        // 采用 **共享锁** 机制,因为可以被不同的线程 countdown
        protected int tryAcquireShared(int acquires) {
            return (getState() == 0) ? 1 : -1;
        }

        protected boolean tryReleaseShared(int releases) {
            // 每次执行将 state-1,直到 state=0
            for (;;) {
                int c = getState();
                if (c == 0)
                    return false;
                int nextc = c-1;
                if (compareAndSetState(c, nextc))  // 通过 CAS 改变 state 的值,失败会进行下一轮循环
                    return nextc == 0;
            }
        }
    }

    private final Sync sync;

    public CountDownLatch(int count) {
        if (count < 0) throw new IllegalArgumentException("count < 0");
        this.sync = new Sync(count);
    }

    // 通过 acquireSharedInterruptibly 获取共享锁
    // 若 state!=0 会被持续阻塞
    public void await() throws InterruptedException {
        sync.acquireSharedInterruptibly(1);  // 1 这里是随意的参数,在 coutdownlatch 中无意义
    }

    public boolean await(long timeout, TimeUnit unit)
        throws InterruptedException {
        return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
    }

    // 解锁一次
    public void countDown() {
        sync.releaseShared(1);
    }

    // 获取当前计数
    public long getCount() {
        return sync.getCount();
    }

    public String toString() {
        return super.toString() + "[Count = " + sync.getCount() + "]";
    }
}