fix(GODT-3124): Race condition in sync task waiter
Fix incorrect use of `sync.WaitGroup` use to wait on sync jobs that fail. After calling `WaitGroup.Wait()` it is advised to call `WaitGroup.Add` until the existing counter has reached 0. The code has been updated with a different mechanism that achieves the same behavior which was previously available.
This commit is contained in:
parent
6d7c21b2c9
commit
7d13c99710
|
@ -210,12 +210,10 @@ func (t *Handler) run(ctx context.Context,
|
|||
stageContext.metadataFetched = syncStatus.NumSyncedMessages
|
||||
stageContext.totalMessageCount = syncStatus.TotalMessageCount
|
||||
|
||||
defer stageContext.Close()
|
||||
|
||||
t.regulator.Sync(ctx, stageContext)
|
||||
|
||||
// Wait on reply
|
||||
if err := stageContext.wait(ctx); err != nil {
|
||||
if err := stageContext.waitAndClose(ctx); err != nil {
|
||||
return fmt.Errorf("failed sync messages: %w", err)
|
||||
}
|
||||
|
||||
|
|
|
@ -19,9 +19,6 @@ package syncservice
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/ProtonMail/gluon/async"
|
||||
"github.com/ProtonMail/go-proton-api"
|
||||
|
@ -48,10 +45,8 @@ type Job struct {
|
|||
updateApplier UpdateApplier
|
||||
syncReporter Reporter
|
||||
|
||||
log *logrus.Entry
|
||||
errorCh *async.QueuedChannel[error]
|
||||
wg sync.WaitGroup
|
||||
once sync.Once
|
||||
log *logrus.Entry
|
||||
jw *jobWaiter
|
||||
|
||||
panicHandler async.PanicHandler
|
||||
downloadCache *DownloadCache
|
||||
|
@ -74,7 +69,7 @@ func NewJob(ctx context.Context,
|
|||
) *Job {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
return &Job{
|
||||
j := &Job{
|
||||
ctx: ctx,
|
||||
client: client,
|
||||
userID: userID,
|
||||
|
@ -85,26 +80,23 @@ func NewJob(ctx context.Context,
|
|||
messageBuilder: messageBuilder,
|
||||
updateApplier: updateApplier,
|
||||
syncReporter: syncReporter,
|
||||
errorCh: async.NewQueuedChannel[error](4, 8, panicHandler, fmt.Sprintf("sync-job-error-%v", userID)),
|
||||
panicHandler: panicHandler,
|
||||
downloadCache: cache,
|
||||
jw: newJobWaiter(log.WithField("sync-job", "waiter"), panicHandler),
|
||||
}
|
||||
|
||||
j.jw.begin()
|
||||
|
||||
return j
|
||||
}
|
||||
|
||||
func (j *Job) Close() {
|
||||
j.errorCh.CloseAndDiscardQueued()
|
||||
j.wg.Wait()
|
||||
func (j *Job) close() {
|
||||
j.jw.close()
|
||||
}
|
||||
|
||||
func (j *Job) onError(err error) {
|
||||
defer j.wg.Done()
|
||||
defer j.jw.onTaskFinished(err)
|
||||
|
||||
// context cancelled is caught & handled in a different location.
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return
|
||||
}
|
||||
|
||||
j.errorCh.Enqueue(err)
|
||||
j.cancel()
|
||||
}
|
||||
|
||||
|
@ -119,55 +111,42 @@ func (j *Job) onJobFinished(ctx context.Context, lastMessageID string, count int
|
|||
return
|
||||
}
|
||||
|
||||
// j.onError() also calls j.wg.Done().
|
||||
j.wg.Done()
|
||||
// j.onError() also calls j.jw.onTaskFinished().
|
||||
defer j.jw.onTaskFinished(nil)
|
||||
j.syncReporter.OnProgress(ctx, count)
|
||||
}
|
||||
|
||||
// begin is expected to be called once the job enters the pipeline.
|
||||
func (j *Job) begin() {
|
||||
j.log.Info("Job started")
|
||||
j.wg.Add(1)
|
||||
j.startChildWaiter()
|
||||
j.jw.onTaskCreated()
|
||||
}
|
||||
|
||||
// end is expected to be called once the job has no further work left.
|
||||
func (j *Job) end() {
|
||||
j.log.Info("Job finished")
|
||||
j.wg.Done()
|
||||
j.jw.onTaskFinished(nil)
|
||||
}
|
||||
|
||||
// wait waits until the job has finished, the context got cancelled or an error occurred.
|
||||
func (j *Job) wait(ctx context.Context) error {
|
||||
defer j.wg.Wait()
|
||||
|
||||
// waitAndClose waits until the job has finished, the context got cancelled or an error occurred.
|
||||
func (j *Job) waitAndClose(ctx context.Context) error {
|
||||
defer j.close()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
j.cancel()
|
||||
j.jw.onContextCancelled()
|
||||
<-j.jw.doneCh
|
||||
return ctx.Err()
|
||||
case err := <-j.errorCh.GetChannel():
|
||||
return err
|
||||
case e := <-j.jw.doneCh:
|
||||
return e
|
||||
}
|
||||
}
|
||||
|
||||
func (j *Job) newChildJob(messageID string, messageCount int64) childJob {
|
||||
j.log.Infof("Creating new child job")
|
||||
j.wg.Add(1)
|
||||
j.jw.onTaskCreated()
|
||||
return childJob{job: j, lastMessageID: messageID, messageCount: messageCount}
|
||||
}
|
||||
|
||||
func (j *Job) startChildWaiter() {
|
||||
j.once.Do(func() {
|
||||
go func() {
|
||||
defer async.HandlePanic(j.panicHandler)
|
||||
|
||||
j.wg.Wait()
|
||||
j.log.Info("All child jobs succeeded")
|
||||
j.errorCh.Enqueue(j.ctx.Err())
|
||||
}()
|
||||
})
|
||||
}
|
||||
|
||||
// childJob represents a batch of work that goes down the pipeline. It keeps track of the message ID that is in the
|
||||
// batch and the number of messages in the batch.
|
||||
type childJob struct {
|
||||
|
@ -232,7 +211,7 @@ func (s *childJob) checkCancelled() bool {
|
|||
err := s.job.ctx.Err()
|
||||
if err != nil {
|
||||
s.job.log.Infof("Child job exit due to context cancelled")
|
||||
s.job.wg.Done()
|
||||
s.job.jw.onTaskFinished(err)
|
||||
return true
|
||||
}
|
||||
|
||||
|
@ -242,3 +221,102 @@ func (s *childJob) checkCancelled() bool {
|
|||
func (s *childJob) getContext() context.Context {
|
||||
return s.job.ctx
|
||||
}
|
||||
|
||||
type JobWaiterMessage int
|
||||
|
||||
const (
|
||||
JobWaiterMessageCreated JobWaiterMessage = iota
|
||||
JobWaiterMessageFinished
|
||||
JobWaiterMessageCtxErr
|
||||
)
|
||||
|
||||
type jobWaiterMessagePair struct {
|
||||
m JobWaiterMessage
|
||||
err error
|
||||
}
|
||||
|
||||
// jobWaiter is meant to be used to track ongoing sync batches. Once all the child jobs
|
||||
// have completed, the first recorded error (if any) will be written to doneCh and then this
|
||||
// channel will be closed.
|
||||
type jobWaiter struct {
|
||||
ch chan jobWaiterMessagePair
|
||||
doneCh chan error
|
||||
log *logrus.Entry
|
||||
panicHandler async.PanicHandler
|
||||
}
|
||||
|
||||
func newJobWaiter(log *logrus.Entry, panicHandler async.PanicHandler) *jobWaiter {
|
||||
return &jobWaiter{
|
||||
ch: make(chan jobWaiterMessagePair),
|
||||
doneCh: make(chan error),
|
||||
log: log,
|
||||
panicHandler: panicHandler,
|
||||
}
|
||||
}
|
||||
|
||||
func (j *jobWaiter) close() {
|
||||
close(j.ch)
|
||||
}
|
||||
|
||||
func (j *jobWaiter) sendMessage(m JobWaiterMessage, err error) {
|
||||
j.ch <- jobWaiterMessagePair{
|
||||
m: m,
|
||||
err: err,
|
||||
}
|
||||
}
|
||||
|
||||
func (j *jobWaiter) onTaskFinished(err error) {
|
||||
j.sendMessage(JobWaiterMessageFinished, err)
|
||||
}
|
||||
|
||||
func (j *jobWaiter) onTaskCreated() {
|
||||
j.sendMessage(JobWaiterMessageCreated, nil)
|
||||
}
|
||||
|
||||
func (j *jobWaiter) onContextCancelled() {
|
||||
j.sendMessage(JobWaiterMessageCtxErr, nil)
|
||||
}
|
||||
|
||||
func (j *jobWaiter) begin() {
|
||||
go func() {
|
||||
defer async.HandlePanic(j.panicHandler)
|
||||
|
||||
total := 0
|
||||
var err error
|
||||
|
||||
defer func() {
|
||||
j.doneCh <- err
|
||||
close(j.doneCh)
|
||||
}()
|
||||
|
||||
for {
|
||||
m, ok := <-j.ch
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
switch m.m {
|
||||
case JobWaiterMessageCtxErr:
|
||||
// DO nothing
|
||||
case JobWaiterMessageCreated:
|
||||
total++
|
||||
case JobWaiterMessageFinished:
|
||||
total--
|
||||
if m.err != nil && err == nil {
|
||||
err = m.err
|
||||
}
|
||||
default:
|
||||
j.log.Errorf("Unknown message type: %v", m.m)
|
||||
continue
|
||||
}
|
||||
|
||||
if total <= 0 {
|
||||
if total < 0 {
|
||||
logrus.Errorf("Child count less than 0, shouldn't happen...")
|
||||
}
|
||||
j.log.Info("All child jobs completed")
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
|
|
@ -20,6 +20,7 @@ package syncservice
|
|||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/ProtonMail/gluon/async"
|
||||
|
@ -56,8 +57,7 @@ func TestJob_WaitsOnChildren(t *testing.T) {
|
|||
tj.job.end()
|
||||
}()
|
||||
|
||||
require.NoError(t, tj.job.wait(context.Background()))
|
||||
tj.job.Close()
|
||||
require.NoError(t, tj.job.waitAndClose(context.Background()))
|
||||
}
|
||||
|
||||
func TestJob_WaitsOnAllChildrenOnError(t *testing.T) {
|
||||
|
@ -73,18 +73,22 @@ func TestJob_WaitsOnAllChildrenOnError(t *testing.T) {
|
|||
|
||||
jobErr := errors.New("failed")
|
||||
|
||||
startCh := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
job1 := tj.job.newChildJob("1", 0)
|
||||
job2 := tj.job.newChildJob("2", 1)
|
||||
|
||||
<-startCh
|
||||
|
||||
job1.onFinished(context.Background())
|
||||
job2.onError(jobErr)
|
||||
}()
|
||||
|
||||
err := tj.job.wait(context.Background())
|
||||
close(startCh)
|
||||
err := tj.job.waitAndClose(context.Background())
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, jobErr)
|
||||
tj.job.Close()
|
||||
}
|
||||
|
||||
func TestJob_MultipleChildrenReportError(t *testing.T) {
|
||||
|
@ -99,20 +103,22 @@ func TestJob_MultipleChildrenReportError(t *testing.T) {
|
|||
|
||||
startCh := make(chan struct{})
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
for i := 0; i < 10; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
job := tj.job.newChildJob("1", 0)
|
||||
wg.Done()
|
||||
<-startCh
|
||||
|
||||
job.onError(jobErr)
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(startCh)
|
||||
err := tj.job.wait(context.Background())
|
||||
err := tj.job.waitAndClose(context.Background())
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, jobErr)
|
||||
tj.job.Close()
|
||||
}
|
||||
|
||||
func TestJob_ChildFailureCancelsAllOtherChildJobs(t *testing.T) {
|
||||
|
@ -127,8 +133,12 @@ func TestJob_ChildFailureCancelsAllOtherChildJobs(t *testing.T) {
|
|||
|
||||
failJob := tj.job.newChildJob("0", 1)
|
||||
|
||||
tj.job.begin()
|
||||
wg := sync.WaitGroup{}
|
||||
for i := 0; i < 10; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
job := tj.job.newChildJob("1", 0)
|
||||
<-job.getContext().Done()
|
||||
require.ErrorIs(t, job.getContext().Err(), context.Canceled)
|
||||
|
@ -137,12 +147,13 @@ func TestJob_ChildFailureCancelsAllOtherChildJobs(t *testing.T) {
|
|||
}
|
||||
go func() {
|
||||
failJob.onError(jobErr)
|
||||
wg.Wait()
|
||||
tj.job.end()
|
||||
}()
|
||||
|
||||
err := tj.job.wait(context.Background())
|
||||
err := tj.job.waitAndClose(context.Background())
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, jobErr)
|
||||
tj.job.Close()
|
||||
}
|
||||
|
||||
func TestJob_CtxCancelCancelsAllChildren(t *testing.T) {
|
||||
|
@ -154,9 +165,12 @@ func TestJob_CtxCancelCancelsAllChildren(t *testing.T) {
|
|||
ctx, cancel := context.WithCancel(context.Background())
|
||||
tj := newTestJob(ctx, mockCtrl, "u", getTestLabels())
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
for i := 0; i < 10; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
job := tj.job.newChildJob("1", 0)
|
||||
wg.Done()
|
||||
<-job.getContext().Done()
|
||||
require.ErrorIs(t, job.getContext().Err(), context.Canceled)
|
||||
require.True(t, job.checkCancelled())
|
||||
|
@ -164,13 +178,35 @@ func TestJob_CtxCancelCancelsAllChildren(t *testing.T) {
|
|||
}
|
||||
|
||||
go func() {
|
||||
wg.Wait()
|
||||
cancel()
|
||||
}()
|
||||
|
||||
err := tj.job.wait(ctx)
|
||||
err := tj.job.waitAndClose(ctx)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, context.Canceled)
|
||||
}
|
||||
|
||||
func TestJob_CtxCancelBeforeBegin(t *testing.T) {
|
||||
options := setupGoLeak()
|
||||
defer goleak.VerifyNone(t, options)
|
||||
|
||||
mockCtrl := gomock.NewController(t)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
tj := newTestJob(ctx, mockCtrl, "u", getTestLabels())
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
wg.Wait()
|
||||
cancel()
|
||||
}()
|
||||
|
||||
wg.Done()
|
||||
err := tj.job.waitAndClose(ctx)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, context.Canceled)
|
||||
tj.job.Close()
|
||||
}
|
||||
|
||||
func TestJob_WithoutChildJobsCanBeTerminated(t *testing.T) {
|
||||
|
@ -186,9 +222,8 @@ func TestJob_WithoutChildJobsCanBeTerminated(t *testing.T) {
|
|||
tj.job.begin()
|
||||
tj.job.end()
|
||||
}()
|
||||
err := tj.job.wait(ctx)
|
||||
err := tj.job.waitAndClose(context.Background())
|
||||
require.NoError(t, err)
|
||||
tj.job.Close()
|
||||
}
|
||||
|
||||
type tjob struct {
|
||||
|
|
|
@ -55,7 +55,7 @@ func TestApplyStage_CancelledJobIsDiscarded(t *testing.T) {
|
|||
messages: nil,
|
||||
})
|
||||
|
||||
err := tj.job.wait(ctx)
|
||||
err := tj.job.waitAndClose(ctx)
|
||||
require.ErrorIs(t, err, context.Canceled)
|
||||
cancel()
|
||||
}
|
||||
|
@ -89,7 +89,7 @@ func TestApplyStage_JobWithNoMessagesIsFinalized(t *testing.T) {
|
|||
messages: nil,
|
||||
})
|
||||
|
||||
err := tj.job.wait(ctx)
|
||||
err := tj.job.waitAndClose(ctx)
|
||||
cancel()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
@ -132,7 +132,7 @@ func TestApplyStage_ErrorOnApplyIsReportedAndJobFails(t *testing.T) {
|
|||
messages: buildResults,
|
||||
})
|
||||
|
||||
err := tj.job.wait(ctx)
|
||||
err := tj.job.waitAndClose(ctx)
|
||||
cancel()
|
||||
require.ErrorIs(t, err, applyErr)
|
||||
}
|
||||
|
|
|
@ -269,7 +269,7 @@ func TestBuildStage_OtherErrorsFailJob(t *testing.T) {
|
|||
|
||||
input.Produce(ctx, BuildRequest{childJob: childJob, batch: []proton.FullMessage{msg}})
|
||||
|
||||
err := tj.job.wait(ctx)
|
||||
err := tj.job.waitAndClose(ctx)
|
||||
require.Equal(t, expectedErr, err)
|
||||
|
||||
cancel()
|
||||
|
|
|
@ -313,7 +313,7 @@ func TestDownloadStage_JobAbortsOnMessageDownloadError(t *testing.T) {
|
|||
ids: []string{"foo"},
|
||||
})
|
||||
|
||||
err := tj.job.wait(ctx)
|
||||
err := tj.job.waitAndClose(ctx)
|
||||
require.Equal(t, expectedErr, err)
|
||||
|
||||
cancel()
|
||||
|
@ -364,7 +364,7 @@ func TestDownloadStage_JobAbortsOnAttachmentDownloadError(t *testing.T) {
|
|||
ids: []string{"foo"},
|
||||
})
|
||||
|
||||
err := tj.job.wait(ctx)
|
||||
err := tj.job.waitAndClose(ctx)
|
||||
require.Equal(t, expectedErr, err)
|
||||
|
||||
cancel()
|
||||
|
|
|
@ -87,6 +87,10 @@ func (m *MetadataStage) run(ctx context.Context, metadataPageSize int, maxMessag
|
|||
return
|
||||
}
|
||||
|
||||
if job.ctx.Err() != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
job.begin()
|
||||
state, err := newMetadataIterator(job.ctx, job, metadataPageSize, coolDown)
|
||||
if err != nil {
|
||||
|
|
|
@ -114,7 +114,7 @@ func TestMetadataStage_JobCorrectlyFinishesAfterCancel(t *testing.T) {
|
|||
}
|
||||
}()
|
||||
|
||||
err = tj.job.wait(context.Background())
|
||||
err = tj.job.waitAndClose(ctx)
|
||||
require.ErrorIs(t, err, context.Canceled)
|
||||
cancel()
|
||||
}
|
||||
|
@ -165,8 +165,8 @@ func TestMetadataStage_RunInterleaved(t *testing.T) {
|
|||
}
|
||||
}()
|
||||
|
||||
require.NoError(t, tj1.job.wait(ctx))
|
||||
require.NoError(t, tj2.job.wait(ctx))
|
||||
require.NoError(t, tj1.job.waitAndClose(ctx))
|
||||
require.NoError(t, tj2.job.waitAndClose(ctx))
|
||||
cancel()
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue