1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
/**
 * A synchronization aid that lets one or more threads wait until a fixed number of
 * {@link #countDown()} calls have been made. The count cannot be reset.
 */
public class CountDownLatch {

    /**
     * Synchronization control built on {@link AbstractQueuedSynchronizer}. The AQS
     * {@code state} field is used directly as the remaining count.
     */
    private static final class Sync extends AbstractQueuedSynchronizer {
        private static final long serialVersionUID = 4982264981922014374L;

        /** Stores the starting count in the AQS state. */
        Sync(int count) {
            setState(count);
        }

        /** Returns the remaining count (the current AQS state). */
        int getCount() {
            return getState();
        }

        /**
         * Shared-mode acquire: succeeds (positive result) only once the count is zero, and
         * fails (-1) otherwise so the calling thread is queued by AQS. Shared mode is used so
         * that, once the count reaches zero, all waiting threads can proceed rather than only
         * one.
         *
         * @param acquires ignored
         * @return {@code 1} if the count is zero, {@code -1} otherwise
         */
        @Override
        protected int tryAcquireShared(int acquires) {
            return (getState() == 0) ? 1 : -1;
        }

        /**
         * Decrements the count by exactly one using a CAS retry loop.
         *
         * @param releases ignored; each call decrements by one regardless of its value
         * @return {@code true} only for the call that moves the count to zero; {@code false}
         *     if the count was already zero or is still positive after the decrement
         */
        @Override
        protected boolean tryReleaseShared(int releases) {
            for (;;) {
                int c = getState();
                if (c == 0) {
                    // Already at zero: extra countDown() calls are no-ops.
                    return false;
                }
                int nextc = c - 1;
                if (compareAndSetState(c, nextc)) {
                    return nextc == 0;
                }
                // CAS lost a race with another thread; re-read the state and retry.
            }
        }
    }

    private final Sync sync;

    /**
     * Creates a latch initialized with the given count.
     *
     * @param count the number of {@link #countDown()} calls needed before waiters proceed
     * @throws IllegalArgumentException if {@code count} is negative
     */
    public CountDownLatch(int count) {
        if (count < 0) {
            throw new IllegalArgumentException("count < 0");
        }
        this.sync = new Sync(count);
    }

    /**
     * Blocks the calling thread until the count reaches zero; returns immediately if it is
     * already zero.
     *
     * @throws InterruptedException if the calling thread is interrupted while waiting
     */
    public void await() throws InterruptedException {
        // The argument is forwarded to tryAcquireShared, which ignores it.
        sync.acquireSharedInterruptibly(1);
    }

    /**
     * Blocks until the count reaches zero or the given timeout elapses.
     *
     * @param timeout the maximum time to wait
     * @param unit the time unit of {@code timeout}
     * @return {@code true} if the count reached zero, {@code false} if the timeout elapsed
     *     first
     * @throws InterruptedException if the calling thread is interrupted while waiting
     */
    public boolean await(long timeout, TimeUnit unit) throws InterruptedException {
        return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
    }

    /**
     * Decrements the count by one. When the count reaches zero, waiting threads are
     * released. Has no effect once the count is already zero.
     */
    public void countDown() {
        sync.releaseShared(1);
    }

    /**
     * Returns the current count.
     *
     * @return the remaining number of {@link #countDown()} calls before release
     */
    public long getCount() {
        return sync.getCount();
    }

    /**
     * Returns the default object identity string followed by the current count.
     *
     * @return a string of the form {@code <Object.toString()>[Count = n]}
     */
    @Override
    public String toString() {
        return super.toString() + "[Count = " + sync.getCount() + "]";
    }
}
|