Skip to content

Commit 064771b

Browse files
committed
create WaitGroupWithStopOnError option to cancel other tasks and NewWaitGroupWithContext function to create WaitGroup with custom context
1 parent 5ac9a31 commit 064771b

3 files changed

Lines changed: 116 additions & 9 deletions

File tree

README.md

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,52 @@ if err := wg.Wait(); err != nil {
184184
// oh, something bad happened in one of routines above.
185185
}
186186
```
187+
## Stop all tasks on first error
188+
```go
189+
import (
190+
"github.com/mrsoftware/errors"
191+
)
192+
193+
// in this example we are using ants goroutine pool.
194+
wg := errors.NewWaitGroup(errors.WaitGroupWithStopOnError())
195+
196+
ctx := wg.Context()
197+
198+
wg.Do(func() error {
199+
return callingHttpClient(ctx)
200+
})
201+
202+
wg.Do(func() error {
203+
return callingHttpClient(ctx)
204+
})
205+
206+
// if one of above task failed, context will cancel and other task will stop (the task must ba aware of context cancellation like http pkg do)
187207

208+
if err := wg.Wait(); err != nil {
209+
// oh, something bad happened in one of routines above.
210+
}
211+
```
212+
**or you can use NewWaitGroupWithContext method:**
213+
```go
214+
import (
215+
"github.com/mrsoftware/errors"
216+
)
217+
218+
// in this example we are using ants goroutine pool.
219+
ctx, wg := errors.NewWaitGroupWithContext(context.Background(), errors.WaitGroupWithStopOnError())
220+
221+
wg.Do(func() error {
222+
return callingHttpClient(ctx)
223+
})
224+
225+
wg.Do(func() error {
226+
return callingHttpClient(ctx)
227+
})
228+
229+
if err := wg.Wait(); err != nil {
230+
// oh, something bad happened in one of routines above.
231+
}
232+
```
188233

189234
for mode details, check the [documentation](https://godoc.org/github.com/mrsoftware/errors)
190235

waitGroup.go

Lines changed: 46 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,30 @@
11
package errors
22

33
import (
4+
"context"
45
"sync"
56
)
67

78
// WaitGroup is sync.WaitGroup with error support.
89
type WaitGroup struct {
9-
noCopy noCopy
10-
options *WaitGroupOptions
11-
errors MultiError
12-
gch chan struct{}
10+
noCopy noCopy
11+
options *WaitGroupOptions
12+
errors MultiError
13+
gch chan struct{}
14+
ctx context.Context
15+
cancel context.CancelCauseFunc
16+
cancelOnce sync.Once
1317
}
1418

1519
// NewWaitGroup create new WaitGroup.
1620
func NewWaitGroup(options ...WaitGroupOption) *WaitGroup {
21+
_, wg := NewWaitGroupWithContext(context.Background(), options...)
22+
23+
return wg
24+
}
25+
26+
// NewWaitGroupWithContext create new WaitGroup with custom context.
27+
func NewWaitGroupWithContext(ctx context.Context, options ...WaitGroupOption) (context.Context, *WaitGroup) {
1728
ops := &WaitGroupOptions{
1829
Wg: &sync.WaitGroup{},
1930
TaskRunner: func(task func()) { go task() },
@@ -28,13 +39,27 @@ func NewWaitGroup(options ...WaitGroupOption) *WaitGroup {
2839
gch = make(chan struct{}, ops.TaskLimit)
2940
}
3041

31-
return &WaitGroup{options: ops, gch: gch}
42+
ctx, cancel := context.WithCancelCause(context.Background())
43+
44+
return ctx, &WaitGroup{options: ops, gch: gch, ctx: ctx, cancel: cancel}
45+
}
46+
47+
// Context of current waitGroup.
48+
func (g *WaitGroup) Context() context.Context {
49+
return g.ctx
50+
}
51+
52+
// Stop send cancel signal to all tasks.
53+
func (g *WaitGroup) Stop(err error) {
54+
g.cancelOnce.Do(func() { g.cancel(err) })
3255
}
3356

3457
// Wait is sync.WaitGroup.Wait.
35-
func (g *WaitGroup) Wait() error {
58+
func (g *WaitGroup) Wait() (err error) {
3659
g.options.Wg.Wait()
3760

61+
defer func() { g.Stop(err) }()
62+
3863
if g.errors.Len() == 0 {
3964
return nil
4065
}
@@ -55,6 +80,10 @@ func (g *WaitGroup) Done(err error) {
5580
return
5681
}
5782

83+
if g.options.StopOnError {
84+
g.Stop(err)
85+
}
86+
5887
g.errors.Add(err)
5988
}
6089

@@ -84,9 +113,10 @@ type noCopy struct{}
84113

85114
// WaitGroupOptions for WaitGroup.
86115
type WaitGroupOptions struct {
87-
Wg *sync.WaitGroup
88-
TaskLimit int
89-
TaskRunner WaitGroupTaskRunner
116+
Wg *sync.WaitGroup
117+
TaskLimit int
118+
TaskRunner WaitGroupTaskRunner
119+
StopOnError bool
90120
}
91121

92122
type WaitGroupOption func(group *WaitGroupOptions)
@@ -114,3 +144,10 @@ func WaitGroupWithTaskRunner(runner WaitGroupTaskRunner) WaitGroupOption {
114144
g.TaskRunner = runner
115145
}
116146
}
147+
148+
// WaitGroupWithStopOnError used if you want to stop all tasks on first error.
149+
func WaitGroupWithStopOnError() WaitGroupOption {
150+
return func(g *WaitGroupOptions) {
151+
g.StopOnError = true
152+
}
153+
}

waitGroup_test.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package errors
22

33
import (
4+
"context"
45
"errors"
56
"sync"
67
"sync/atomic"
@@ -131,6 +132,30 @@ func TestGroup(t *testing.T) {
131132
assert.Nil(t, err)
132133
assert.True(t, isUsedCustomRunner)
133134
})
135+
136+
t.Run("set StopOnError options", func(t *testing.T) {
137+
error1 := errors.New("error 1")
138+
139+
ctx, wg := NewWaitGroupWithContext(context.Background(), WaitGroupWithStopOnError())
140+
141+
wg.Do(func() error { return error1 })
142+
143+
// sample long-running and context aware task.
144+
wg.Do(func() error {
145+
for {
146+
select {
147+
case <-ctx.Done():
148+
return ctx.Err()
149+
}
150+
}
151+
})
152+
153+
err := wg.Wait()
154+
155+
expected := NewMultiError(error1, context.Canceled)
156+
assert.ElementsMatch(t, expected.errors, err.(*MultiError).errors)
157+
assert.Equal(t, ctx, wg.Context())
158+
})
134159
}
135160

136161
// all below test cases are copied from sync/waitgroup_test.go and transformed to group.

0 commit comments

Comments
 (0)