go 版本基于1.18
结构体
结构体定义如下:
// WaitGroup holds the packed counter/waiter state and the semaphore word
// that the Add, Done and Wait methods operate on through state().
type WaitGroup struct {
// NOTE(review): noCopy is declared elsewhere; presumably it is the marker
// type that lets `go vet` flag copies of a WaitGroup — confirm.
noCopy noCopy
// 64-bit value: high 32 bits are counter, low 32 bits are waiter count.
// 64-bit atomic operations require 64-bit alignment, but 32-bit
// compilers only guarantee that 64-bit fields are 32-bit aligned.
// For this reason on 32 bit architectures we need to check in state()
// if state1 is aligned or not, and dynamically "swap" the field order if
// needed.
// state1 and state2 must stay adjacent and in this order: state() views
// them together as a [3]uint32 when state1 is not 64-bit aligned.
state1 uint64
state2 uint32
}
当我们初始化一个WaitGroup对象时,其counter值、waiter值、semap值均为0
noCopy :
空结构体,它并不会占用内存,编译器也不会对其进行字节填充。它主要是为了配合 go vet
工具做静态检查,防止开发者在使用 WaitGroup 的过程中对其进行复制,从而带来安全隐患。
state1, state2:
主要代表三部分内容:
- 通过Add()设置的子goroutine的计数值counter
- 通过Wait()陷入阻塞的waiter数
- 信号量semap
其中在64位的操作系统中(对齐系数为8),state1 的高32位代表计数器 counter,低32位代表 waiter 数,state2 代表信号量 semap
在32位的操作系统中(对齐系数为4),如果 state1 的地址恰好是8字节对齐的,布局与上面相同;否则通过 unsafe.Pointer() 将 state1 和 state2 整体视为一个 [3]uint32 的 state 数组,其中 state[0] 代表信号量 semap,state[1] 和 state[2] 合起来作为一个 uint64(高32位为 counter,低32位为 waiter)。具体实现的代码如下,state 方法就是返回对应的计数(counter、waiter)和信号量(semap)
func (wg *WaitGroup) state() (statep *uint64, semap *uint32) { if unsafe.Alignof(wg.state1) == 8 || uintptr(unsafe.Pointer(&wg.state1))%8 == 0 { // state1 is 64-bit aligned: nothing to do. return &wg.state1, &wg.state2 } else { // state1 is 32-bit aligned but not 64-bit aligned: this means that // (&state1)+4 is 64-bit aligned. state := (*[3]uint32)(unsafe.Pointer(&wg.state1)) return (*uint64)(unsafe.Pointer(&state[1])), &state[0] } }
方法
1. Add
- 源码实现
// Add adds delta, which may be negative, to the WaitGroup counter.
// If the counter drops below zero, Add panics. If the counter reaches zero
// while waiters are registered, Add resets the state word to zero and
// releases the semaphore once per waiter.
func (wg *WaitGroup) Add(delta int) {
// Fetch the packed state word and the semaphore word.
statep, semap := wg.state()
// Race-detector bookkeeping only; it does not affect the counting logic.
if race.Enabled {
_ = *statep // trigger nil deref early
if delta < 0 {
// Synchronize decrements with Wait.
race.ReleaseMerge(unsafe.Pointer(wg))
}
race.Disable()
defer race.Enable()
}
// Add delta to the counter. The counter lives in the high 32 bits of the
// state word, so delta is shifted left by 32 before the atomic add.
state := atomic.AddUint64(statep, uint64(delta)<<32)
// Counter value after the add.
v := int32(state >> 32)
// Waiter count after the add.
w := uint32(state)
// Race-detector bookkeeping only.
if race.Enabled && delta > 0 && v == int32(delta) {
// The first increment must be synchronized with Wait.
// Need to model this as a read, because there can be
// several concurrent wg.counter transitions from 0.
race.Read(unsafe.Pointer(semap))
}
// A negative counter is a usage error: panic.
if v < 0 {
panic("sync: negative WaitGroup counter")
}
// delta > 0 && v == int32(delta) means the counter was just raised from 0.
// w != 0 means waiters are already registered.
// Together they indicate Add ran concurrently with Wait, which is
// treated as misuse: the counter must be raised before anyone waits.
if w != 0 && delta > 0 && v == int32(delta) {
panic("sync: WaitGroup misuse: Add called concurrently with Wait")
}
// v > 0: the counter is still positive, so nobody should be woken yet.
// w == 0: there are no waiters to wake.
// In either case there is nothing more to do.
if v > 0 || w == 0 {
return
}
// This goroutine has set counter to 0 when waiters > 0.
// Now there can't be concurrent mutations of state:
// - Adds must not happen concurrently with Wait,
// - Wait does not increment waiters if it sees counter == 0.
// Still do a cheap sanity check to detect WaitGroup misuse.
// Re-read the state word and make sure it has not changed underneath us.
if *statep != state {
panic("sync: WaitGroup misuse: Add called concurrently with Wait")
}
// Here the counter is 0 and the waiter count is > 0: reset the state
// word (both counter and waiter count) to 0, then wake every waiter.
*statep = 0
// Release the semaphore once per registered waiter.
for ; w != 0; w-- {
runtime_Semrelease(semap, false, 0)
}
}
2. Done
- 源码
// Done decrements the WaitGroup counter by one by delegating to Add.
func (wg *WaitGroup) Done() {
	const delta = -1
	wg.Add(delta)
}
完成一个任务,将计数值减一,当计数值减为0时,需要唤醒所有的等待者
3.Wait
- 源码
// Wait returns immediately if the counter is zero. Otherwise it registers
// the caller as a waiter, blocks on the semaphore, and returns once it has
// been released. It panics if the state word is non-zero after wake-up.
func (wg *WaitGroup) Wait() {
// Fetch the packed state word and the semaphore word.
statep, semap := wg.state()
// Race-detector bookkeeping only.
if race.Enabled {
_ = *statep // trigger nil deref early
race.Disable()
}
for {
// Atomically load the state word and split it into counter and waiters.
state := atomic.LoadUint64(statep)
v := int32(state >> 32)
w := uint32(state)
// Counter is 0: there is nothing outstanding, so return without blocking.
if v == 0 {
// Counter is 0, no need to wait.
if race.Enabled {
race.Enable()
race.Acquire(unsafe.Pointer(wg))
}
return
}
// Increment the waiter count (low 32 bits) with a compare-and-swap.
// The CAS fails if the state word changed since the load above (for
// example another Wait or an Add ran in between); in that case the
// loop retries with a fresh load.
if atomic.CompareAndSwapUint64(statep, state, state+1) {
if race.Enabled && w == 0 {
// Wait must be synchronized with the first Add.
// Need to model this is as a write to race with the read in Add.
// As a consequence, can do the write only for the first waiter,
// otherwise concurrent Waits will race with each other.
race.Write(unsafe.Pointer(semap))
}
// Registered as a waiter: block on the semaphore until released.
runtime_Semacquire(semap)
// After wake-up the state word is expected to be 0 (Add resets it
// before releasing waiters). A non-zero value means the WaitGroup
// was reused before this Wait returned.
if *statep != 0 {
panic("sync: WaitGroup is reused before previous Wait has returned")
}
if race.Enabled {
race.Enable()
race.Acquire(unsafe.Pointer(wg))
}
return
}
}
}
注意事项
- 保证 Add 在 Wait 前调用: 确保不在子 goroutine 中调用 Add 方法,否则有可能与 Wait 产生竞争冲突,最终导致 panic
- Add 函数不要传入负值,有可能导致panic 或者导致 wait 函数中 信号量P 操作死锁等待
- 不要复制使用 WaitGroup,函数传递时使用指针传递。WaitGroup 不支持复制操作,可用 go vet 检查是否对 WaitGroup 进行了复制
- 尽量不复用 WaitGroup,减少出问题的风险;如需复用,前提是上一轮的 Wait 函数已经返回
使用示例
package main
import (
"sync"
)
// httpPkg is an empty stub type whose Get method does nothing, so the
// example can be compiled and run without any network access.
type httpPkg struct{}

// Get is a no-op placeholder that accepts a URL and returns immediately.
func (httpPkg) Get(url string) {}

// http is the shared httpPkg value that the example's goroutines call.
var http = httpPkg{}
// main fetches a fixed list of URLs concurrently, one goroutine per URL,
// and returns only after every fetch has finished.
func main() {
	targets := []string{
		"http://www.golang.org/",
		"http://www.google.com/",
		"http://www.example.com/",
	}

	var pending sync.WaitGroup

	// fetch retrieves one URL and marks its unit of work as done on exit.
	fetch := func(target string) {
		defer pending.Done()
		http.Get(target)
	}

	for i := range targets {
		// Register the work before starting the goroutine so that Wait
		// cannot observe a zero counter while fetches are still pending.
		pending.Add(1)
		go fetch(targets[i])
	}

	// Block until every goroutine has called Done.
	pending.Wait()
}