fix(GODT-3124): Handling of sync child jobs

Improve the handling of sync child jobs so that they behave correctly
in all scenarios.

The sync service now uses an isolated context so that the pipeline
stages do not shut down before all the sync tasks have had the chance
to run their course.
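
In outline, the service now owns an async.Group rooted in context.Background()
rather than running its pipeline stages on the bridge's shared task group, and
it is only torn down by an explicit Close(). Below is a minimal sketch of that
pattern, condensed from the Service changes further down (stage wiring and the
other fields are omitted):

package syncservice

import (
	"context"

	"github.com/ProtonMail/gluon/async"
)

// Sketch only: the service owns a group that is deliberately not derived from
// the caller's context, so cancelling the bridge's task group no longer stops
// the pipeline stages mid-sync.
type Service struct {
	group *async.Group
}

func NewService(panicHandler async.PanicHandler) *Service {
	return &Service{
		group: async.NewGroup(context.Background(), panicHandler),
	}
}

func (s *Service) Run() {
	// The metadata, download, build and apply stages are started on s.group
	// here; see the Service hunk below for the full wiring.
}

func (s *Service) Close() {
	// Shutdown is explicit and independent of any caller context.
	s.group.CancelAndWait()
}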

The job waiter now starts with a counter of 1, accounting for the
parent job, and waits until the parent job and all child jobs have
finished before considering the work complete.
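
As a rough model of the accounting, assuming the message names from the
jobWaiter changes below (the real waiter carries its messages in a pair with an
error, and the zero check here only stands in for however it decides it is
done):

package syncservice

// Mirrors the message constants in the diff; the error-carrying pair struct is
// dropped to keep the sketch short.
type JobWaiterMessage int

const (
	JobWaiterMessageCreated JobWaiterMessage = iota
	JobWaiterMessageFinished
)

// waitAll starts at 1 for the parent job, adds one per child created and
// removes one per job that finishes, so it only returns once the parent has
// ended and every child is accounted for.
func waitAll(ch <-chan JobWaiterMessage) {
	total := 1 // the parent job itself
	for m := range ch {
		switch m {
		case JobWaiterMessageCreated:
			total++
		case JobWaiterMessageFinished:
			total--
		}
		if total == 0 {
			return
		}
	}
}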

Finally, we also handle the case where a sync job cannot be queued
because the calling context has already been cancelled.
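
Concretely, queuing now fails fast instead of blocking once the caller's
context is gone. The producer change is reproduced below in condensed form;
the handler hunk further down shows the resulting error from regulator.Sync
being reported on the job and returned:

package syncservice

import "context"

// ChannelConsumerProducer is shown only as far as the sketch needs it.
type ChannelConsumerProducer[T any] struct {
	ch chan T
}

// Produce returns the context error instead of blocking forever when the
// calling context has already been cancelled, so a sync job that can no
// longer be queued fails immediately.
func (c ChannelConsumerProducer[T]) Produce(ctx context.Context, value T) error {
	select {
	case <-ctx.Done():
		return ctx.Err()
	case c.ch <- value:
		return nil
	}
}
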
Leander Beernaert 2023-11-29 13:52:12 +01:00
parent 9449177553
commit 7a1c7e8743
15 changed files with 100 additions and 78 deletions

View File

@ -307,7 +307,7 @@ func newBridge(
bridge.heartbeat.init(bridge, heartbeatManager)
}
bridge.syncService.Run(bridge.tasks)
bridge.syncService.Run()
return bridge, nil
}
@ -451,6 +451,8 @@ func (bridge *Bridge) Close(ctx context.Context) {
logrus.WithError(err).Error("Failed to close servers")
}
bridge.syncService.Close()
// Stop all ongoing tasks.
bridge.tasks.CancelAndWait()

View File

@ -210,7 +210,11 @@ func (t *Handler) run(ctx context.Context,
stageContext.metadataFetched = syncStatus.NumSyncedMessages
stageContext.totalMessageCount = syncStatus.TotalMessageCount
t.regulator.Sync(ctx, stageContext)
if err := t.regulator.Sync(ctx, stageContext); err != nil {
stageContext.onError(err)
_ = stageContext.waitAndClose(ctx)
return fmt.Errorf("failed to start sync job: %w", err)
}
// Wait on reply
if err := stageContext.waitAndClose(ctx); err != nil {

View File

@ -64,7 +64,7 @@ func (s Status) InProgress() bool {
// Regulator is an abstraction for the sync service, since it regulates the number of concurrent sync activities.
type Regulator interface {
Sync(ctx context.Context, stage *Job)
Sync(ctx context.Context, stage *Job) error
}
type BuildResult struct {

View File

@ -119,7 +119,6 @@ func (j *Job) onJobFinished(ctx context.Context, lastMessageID string, count int
// begin is expected to be called once the job enters the pipeline.
func (j *Job) begin() {
j.log.Info("Job started")
j.jw.onTaskCreated()
}
// end is expected to be called once the job has no further work left.
@ -133,7 +132,6 @@ func (j *Job) waitAndClose(ctx context.Context) error {
defer j.close()
select {
case <-ctx.Done():
j.jw.onContextCancelled()
<-j.jw.doneCh
return ctx.Err()
case e := <-j.jw.doneCh:
@ -227,7 +225,6 @@ type JobWaiterMessage int
const (
JobWaiterMessageCreated JobWaiterMessage = iota
JobWaiterMessageFinished
JobWaiterMessageCtxErr
)
type jobWaiterMessagePair struct {
@ -248,7 +245,7 @@ type jobWaiter struct {
func newJobWaiter(log *logrus.Entry, panicHandler async.PanicHandler) *jobWaiter {
return &jobWaiter{
ch: make(chan jobWaiterMessagePair),
doneCh: make(chan error),
doneCh: make(chan error, 2),
log: log,
panicHandler: panicHandler,
}
@ -273,15 +270,11 @@ func (j *jobWaiter) onTaskCreated() {
j.sendMessage(JobWaiterMessageCreated, nil)
}
func (j *jobWaiter) onContextCancelled() {
j.sendMessage(JobWaiterMessageCtxErr, nil)
}
func (j *jobWaiter) begin() {
go func() {
defer async.HandlePanic(j.panicHandler)
total := 0
total := 1
var err error
defer func() {
@ -296,8 +289,6 @@ func (j *jobWaiter) begin() {
}
switch m.m {
case JobWaiterMessageCtxErr:
// DO nothing
case JobWaiterMessageCreated:
total++
case JobWaiterMessageFinished:

View File

@ -83,6 +83,7 @@ func TestJob_WaitsOnAllChildrenOnError(t *testing.T) {
job1.onFinished(context.Background())
job2.onError(jobErr)
tj.job.end()
}()
close(startCh)
@ -115,6 +116,7 @@ func TestJob_MultipleChildrenReportError(t *testing.T) {
}
wg.Wait()
tj.job.end()
close(startCh)
err := tj.job.waitAndClose(context.Background())
require.Error(t, err)
@ -179,6 +181,7 @@ func TestJob_CtxCancelCancelsAllChildren(t *testing.T) {
go func() {
wg.Wait()
tj.job.end()
cancel()
}()
@ -201,6 +204,7 @@ func TestJob_CtxCancelBeforeBegin(t *testing.T) {
go func() {
wg.Wait()
cancel()
tj.job.end()
}()
wg.Done()

View File

@ -127,9 +127,11 @@ func (mr *MockBuildStageOutputMockRecorder) Close() *gomock.Call {
}
// Produce mocks base method.
func (m *MockBuildStageOutput) Produce(arg0 context.Context, arg1 ApplyRequest) {
func (m *MockBuildStageOutput) Produce(arg0 context.Context, arg1 ApplyRequest) error {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Produce", arg0, arg1)
ret := m.ctrl.Call(m, "Produce", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// Produce indicates an expected call of Produce.
@ -212,9 +214,11 @@ func (mr *MockDownloadStageOutputMockRecorder) Close() *gomock.Call {
}
// Produce mocks base method.
func (m *MockDownloadStageOutput) Produce(arg0 context.Context, arg1 BuildRequest) {
func (m *MockDownloadStageOutput) Produce(arg0 context.Context, arg1 BuildRequest) error {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Produce", arg0, arg1)
ret := m.ctrl.Call(m, "Produce", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// Produce indicates an expected call of Produce.
@ -297,9 +301,11 @@ func (mr *MockMetadataStageOutputMockRecorder) Close() *gomock.Call {
}
// Produce mocks base method.
func (m *MockMetadataStageOutput) Produce(arg0 context.Context, arg1 DownloadRequest) {
func (m *MockMetadataStageOutput) Produce(arg0 context.Context, arg1 DownloadRequest) error {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Produce", arg0, arg1)
ret := m.ctrl.Call(m, "Produce", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// Produce indicates an expected call of Produce.
@ -478,9 +484,11 @@ func (m *MockRegulator) EXPECT() *MockRegulatorMockRecorder {
}
// Sync mocks base method.
func (m *MockRegulator) Sync(arg0 context.Context, arg1 *Job) {
func (m *MockRegulator) Sync(arg0 context.Context, arg1 *Job) error {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Sync", arg0, arg1)
ret := m.ctrl.Call(m, "Sync", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// Sync indicates an expected call of Sync.

View File

@ -33,7 +33,7 @@ type Service struct {
applyStage *ApplyStage
limits syncLimits
metaCh *ChannelConsumerProducer[*Job]
panicHandler async.PanicHandler
group *async.Group
}
func NewService(reporter reporter.Reporter,
@ -53,26 +53,22 @@ func NewService(reporter reporter.Reporter,
buildStage: NewBuildStage(buildCh, applyCh, limits.MessageBuildMem, panicHandler, reporter),
applyStage: NewApplyStage(applyCh),
metaCh: metaCh,
panicHandler: panicHandler,
group: async.NewGroup(context.Background(), panicHandler),
}
}
func (s *Service) Run(group *async.Group) {
group.Once(func(ctx context.Context) {
syncGroup := async.NewGroup(ctx, s.panicHandler)
s.metadataStage.Run(syncGroup)
s.downloadStage.Run(syncGroup)
s.buildStage.Run(syncGroup)
s.applyStage.Run(syncGroup)
defer s.metaCh.Close()
defer syncGroup.CancelAndWait()
<-ctx.Done()
})
func (s *Service) Run() {
s.metadataStage.Run(s.group)
s.downloadStage.Run(s.group)
s.buildStage.Run(s.group)
s.applyStage.Run(s.group)
}
func (s *Service) Sync(ctx context.Context, stage *Job) {
s.metaCh.Produce(ctx, stage)
func (s *Service) Sync(ctx context.Context, stage *Job) error {
return s.metaCh.Produce(ctx, stage)
}
func (s *Service) Close() {
s.group.CancelAndWait()
s.metaCh.Close()
}

View File

@ -50,10 +50,10 @@ func TestApplyStage_CancelledJobIsDiscarded(t *testing.T) {
}()
jobCancel()
input.Produce(ctx, ApplyRequest{
require.NoError(t, input.Produce(ctx, ApplyRequest{
childJob: childJob,
messages: nil,
})
}))
err := tj.job.waitAndClose(ctx)
require.ErrorIs(t, err, context.Canceled)
@ -84,10 +84,10 @@ func TestApplyStage_JobWithNoMessagesIsFinalized(t *testing.T) {
stage.run(ctx)
}()
input.Produce(ctx, ApplyRequest{
require.NoError(t, input.Produce(ctx, ApplyRequest{
childJob: childJob,
messages: nil,
})
}))
err := tj.job.waitAndClose(ctx)
cancel()
@ -127,10 +127,10 @@ func TestApplyStage_ErrorOnApplyIsReportedAndJobFails(t *testing.T) {
stage.run(ctx)
}()
input.Produce(ctx, ApplyRequest{
require.NoError(t, input.Produce(ctx, ApplyRequest{
childJob: childJob,
messages: buildResults,
})
}))
err := tj.job.waitAndClose(ctx)
cancel()

View File

@ -21,6 +21,7 @@ import (
"bytes"
"context"
"errors"
"fmt"
"runtime"
"github.com/ProtonMail/gluon/async"
@ -182,10 +183,12 @@ func (b *BuildStage) run(ctx context.Context) {
outJob.onStageCompleted(ctx)
b.output.Produce(ctx, ApplyRequest{
if err := b.output.Produce(ctx, ApplyRequest{
childJob: outJob,
messages: success,
})
}); err != nil {
return fmt.Errorf("failed to produce output for next stage: %w", err)
}
}
return nil

View File

@ -111,7 +111,7 @@ func TestBuildStage_SuccessRemovesFailedMessage(t *testing.T) {
stage.run(ctx)
}()
input.Produce(ctx, BuildRequest{childJob: childJob, batch: []proton.FullMessage{msg}})
require.NoError(t, input.Produce(ctx, BuildRequest{childJob: childJob, batch: []proton.FullMessage{msg}}))
req, err := output.Consume(ctx)
cancel()
@ -170,7 +170,7 @@ func TestBuildStage_BuildFailureIsReportedButDoesNotCancelJob(t *testing.T) {
stage.run(ctx)
}()
input.Produce(ctx, BuildRequest{childJob: childJob, batch: []proton.FullMessage{msg}})
require.NoError(t, input.Produce(ctx, BuildRequest{childJob: childJob, batch: []proton.FullMessage{msg}}))
req, err := output.Consume(ctx)
cancel()
@ -222,7 +222,7 @@ func TestBuildStage_FailedToLocateKeyRingIsReportedButDoesNotFailBuild(t *testin
stage.run(ctx)
}()
input.Produce(ctx, BuildRequest{childJob: childJob, batch: []proton.FullMessage{msg}})
require.NoError(t, input.Produce(ctx, BuildRequest{childJob: childJob, batch: []proton.FullMessage{msg}}))
req, err := output.Consume(ctx)
cancel()
@ -267,7 +267,7 @@ func TestBuildStage_OtherErrorsFailJob(t *testing.T) {
stage.run(ctx)
}()
input.Produce(ctx, BuildRequest{childJob: childJob, batch: []proton.FullMessage{msg}})
require.NoError(t, input.Produce(ctx, BuildRequest{childJob: childJob, batch: []proton.FullMessage{msg}}))
err := tj.job.waitAndClose(ctx)
require.Equal(t, expectedErr, err)
@ -311,10 +311,10 @@ func TestBuildStage_CancelledJobIsDiscarded(t *testing.T) {
}()
jobCancel()
input.Produce(ctx, BuildRequest{
require.NoError(t, input.Produce(ctx, BuildRequest{
childJob: childJob,
batch: []proton.FullMessage{msg},
})
}))
go func() { cancel() }()

View File

@ -21,6 +21,7 @@ import (
"bytes"
"context"
"errors"
"fmt"
"sync/atomic"
"github.com/ProtonMail/gluon/async"
@ -183,10 +184,12 @@ func (d *DownloadStage) run(ctx context.Context) {
// Step 5: Publish result.
request.onStageCompleted(ctx)
d.output.Produce(ctx, BuildRequest{
if err := d.output.Produce(ctx, BuildRequest{
batch: result,
childJob: request.childJob,
})
}); err != nil {
request.job.onError(fmt.Errorf("failed to produce output for next stage: %w", err))
}
}
}

View File

@ -189,10 +189,10 @@ func TestDownloadStage_Run(t *testing.T) {
stage.run(ctx)
}()
input.Produce(ctx, DownloadRequest{
require.NoError(t, input.Produce(ctx, DownloadRequest{
childJob: childJob,
ids: msgIDs,
})
}))
out, err := output.Consume(ctx)
require.NoError(t, err)
@ -232,10 +232,10 @@ func TestDownloadStage_RunWith422(t *testing.T) {
stage.run(ctx)
}()
input.Produce(ctx, DownloadRequest{
require.NoError(t, input.Produce(ctx, DownloadRequest{
childJob: childJob,
ids: msgIDs,
})
}))
out, err := output.Consume(ctx)
require.NoError(t, err)
@ -271,10 +271,11 @@ func TestDownloadStage_CancelledJobIsDiscarded(t *testing.T) {
}()
jobCancel()
input.Produce(ctx, DownloadRequest{
require.NoError(t, input.Produce(ctx, DownloadRequest{
childJob: childJob,
ids: nil,
})
}))
go func() { cancel() }()
@ -308,10 +309,10 @@ func TestDownloadStage_JobAbortsOnMessageDownloadError(t *testing.T) {
stage.run(ctx)
}()
input.Produce(ctx, DownloadRequest{
require.NoError(t, input.Produce(ctx, DownloadRequest{
childJob: childJob,
ids: []string{"foo"},
})
}))
err := tj.job.waitAndClose(ctx)
require.Equal(t, expectedErr, err)
@ -359,10 +360,10 @@ func TestDownloadStage_JobAbortsOnAttachmentDownloadError(t *testing.T) {
stage.run(ctx)
}()
input.Produce(ctx, DownloadRequest{
require.NoError(t, input.Produce(ctx, DownloadRequest{
childJob: childJob,
ids: []string{"foo"},
})
}))
err := tj.job.waitAndClose(ctx)
require.Equal(t, expectedErr, err)

View File

@ -20,6 +20,7 @@ package syncservice
import (
"context"
"errors"
"fmt"
"github.com/ProtonMail/gluon/async"
"github.com/ProtonMail/gluon/logging"
@ -87,10 +88,6 @@ func (m *MetadataStage) run(ctx context.Context, metadataPageSize int, maxMessag
return
}
if job.ctx.Err() != nil {
continue
}
job.begin()
state, err := newMetadataIterator(job.ctx, job, metadataPageSize, coolDown)
if err != nil {
@ -119,7 +116,10 @@ func (m *MetadataStage) run(ctx context.Context, metadataPageSize int, maxMessag
output.onStageCompleted(ctx)
m.output.Produce(ctx, output)
if err := m.output.Produce(ctx, output); err != nil {
job.onError(fmt.Errorf("failed to produce output for next stage: %w", err))
return
}
}
// If this job has no more work left, signal completion.

View File

@ -21,6 +21,7 @@ import (
"context"
"fmt"
"io"
"sync"
"testing"
"github.com/ProtonMail/gluon/async"
@ -59,7 +60,7 @@ func TestMetadataStage_RunFinishesWith429(t *testing.T) {
metadata.run(ctx, TestMetadataPageSize, TestMaxMessages, &network.NoCoolDown{})
}()
input.Produce(ctx, tj.job)
require.NoError(t, input.Produce(ctx, tj.job))
for _, chunk := range xslices.Chunk(msgs, TestMaxMessages) {
tj.syncReporter.EXPECT().OnProgress(gomock.Any(), gomock.Eq(int64(len(chunk))))
@ -93,7 +94,10 @@ func TestMetadataStage_JobCorrectlyFinishesAfterCancel(t *testing.T) {
metadata.run(ctx, TestMetadataPageSize, TestMaxMessages, &network.NoCoolDown{})
}()
input.Produce(ctx, tj.job)
{
err := input.Produce(ctx, tj.job)
require.NoError(t, err)
}
// read one output then cancel
request, err := output.Consume(ctx)
@ -102,8 +106,11 @@ func TestMetadataStage_JobCorrectlyFinishesAfterCancel(t *testing.T) {
// cancel job context
jobCancel()
wg := sync.WaitGroup{}
wg.Add(1)
// The next stages should check whether the job has been cancelled or not. Here we need to do it manually.
go func() {
wg.Done()
for {
req, err := output.Consume(ctx)
if err != nil {
@ -113,8 +120,9 @@ func TestMetadataStage_JobCorrectlyFinishesAfterCancel(t *testing.T) {
req.checkCancelled()
}
}()
wg.Wait()
err = tj.job.waitAndClose(ctx)
require.Error(t, err)
require.ErrorIs(t, err, context.Canceled)
cancel()
}
@ -149,8 +157,8 @@ func TestMetadataStage_RunInterleaved(t *testing.T) {
}()
go func() {
input.Produce(ctx, tj1.job)
input.Produce(ctx, tj2.job)
require.NoError(t, input.Produce(ctx, tj1.job))
require.NoError(t, input.Produce(ctx, tj2.job))
}()
go func() {

View File

@ -23,7 +23,7 @@ import (
)
type StageOutputProducer[T any] interface {
Produce(ctx context.Context, value T)
Produce(ctx context.Context, value T) error
Close()
}
@ -41,10 +41,12 @@ func NewChannelConsumerProducer[T any]() *ChannelConsumerProducer[T] {
return &ChannelConsumerProducer[T]{ch: make(chan T)}
}
func (c ChannelConsumerProducer[T]) Produce(ctx context.Context, value T) {
func (c ChannelConsumerProducer[T]) Produce(ctx context.Context, value T) error {
select {
case <-ctx.Done():
return ctx.Err()
case c.ch <- value:
return nil
}
}