Skip to content

Commit

Permalink
The status of the AWS batch job should become failed once the retry limit exceeded (flyteorg#291)
Browse files Browse the repository at this point in the history

* Turn PhaseRetryableFailure into PhaseRetryLimitExceededFailure

Signed-off-by: Kevin Su <pingsutw@apache.org>

* nit

Signed-off-by: Kevin Su <pingsutw@apache.org>

* update

Signed-off-by: Kevin Su <pingsutw@apache.org>

* update test

Signed-off-by: Kevin Su <pingsutw@apache.org>

* lint

Signed-off-by: Kevin Su <pingsutw@apache.org>

* update

Signed-off-by: Kevin Su <pingsutw@apache.org>

* update tests

Signed-off-by: Kevin Su <pingsutw@apache.org>

* lint

Signed-off-by: Kevin Su <pingsutw@apache.org>

* wip

Signed-off-by: Kevin Su <pingsutw@apache.org>

* update

Signed-off-by: Kevin Su <pingsutw@apache.org>

* address comment

Signed-off-by: Kevin Su <pingsutw@apache.org>

* nit

Signed-off-by: Kevin Su <pingsutw@apache.org>

* fix tests

Signed-off-by: Kevin Su <pingsutw@apache.org>

* nit

Signed-off-by: Kevin Su <pingsutw@apache.org>

Signed-off-by: Kevin Su <pingsutw@apache.org>
  • Loading branch information
pingsutw authored Dec 1, 2022
1 parent b0f20e8 commit d5295d2
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 19 deletions.
4 changes: 1 addition & 3 deletions go/tasks/plugins/array/awsbatch/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,7 @@ func (e Executor) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (c
pluginState, err = LaunchSubTasks(ctx, tCtx, e.jobStore, pluginConfig, pluginState, e.metrics)

case arrayCore.PhaseCheckingSubTaskExecutions:
pluginState, err = CheckSubTasksState(ctx, tCtx.TaskExecutionMetadata(),
tCtx.OutputWriter().GetOutputPrefixPath(), tCtx.OutputWriter().GetRawOutputPrefix(),
e.jobStore, tCtx.DataStore(), pluginConfig, pluginState, e.metrics)
pluginState, err = CheckSubTasksState(ctx, tCtx, e.jobStore, pluginConfig, pluginState, e.metrics)

case arrayCore.PhaseAssembleFinalOutput:
pluginState.State, err = array.AssembleFinalOutputs(ctx, e.outputAssembler, tCtx, arrayCore.PhaseSuccess, version, pluginState.State)
Expand Down
27 changes: 22 additions & 5 deletions go/tasks/plugins/array/awsbatch/monitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ import (
"context"

core2 "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flytestdlib/storage"

"github.com/flyteorg/flyteplugins/go/tasks/errors"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/ioutils"
"github.com/flyteorg/flyteplugins/go/tasks/plugins/array"
Expand Down Expand Up @@ -34,19 +34,32 @@ func createSubJobList(count int) []*Job {
return res
}

func CheckSubTasksState(ctx context.Context, taskMeta core.TaskExecutionMetadata, outputPrefix, baseOutputSandbox storage.DataReference, jobStore *JobStore,
dataStore *storage.DataStore, cfg *config.Config, currentState *State, metrics ExecutorMetrics) (newState *State, err error) {
func CheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionContext, jobStore *JobStore,
cfg *config.Config, currentState *State, metrics ExecutorMetrics) (newState *State, err error) {
newState = currentState
parentState := currentState.State
jobName := taskMeta.GetTaskExecutionID().GetGeneratedName()
jobName := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName()
job := jobStore.Get(jobName)
outputPrefix := tCtx.OutputWriter().GetOutputPrefixPath()
baseOutputSandbox := tCtx.OutputWriter().GetRawOutputPrefix()
dataStore := tCtx.DataStore()
// Check that the taskTemplate is valid
var taskTemplate *core2.TaskTemplate
taskTemplate, err = tCtx.TaskReader().Read(ctx)
if err != nil {
return nil, errors.Wrapf(errors.CorruptedPluginState, err, "Failed to read task template")
} else if taskTemplate == nil {
return nil, errors.Errorf(errors.BadTaskSpecification, "Required value not set, taskTemplate is nil")
}
retry := toRetryStrategy(ctx, toBackoffLimit(taskTemplate.Metadata), cfg.MinRetries, cfg.MaxRetries)

// If job isn't currently being monitored (recovering from a restart?), add it to the sync-cache and return
if job == nil {
logger.Info(ctx, "Job not found in cache, adding it. [%v]", jobName)

_, err = jobStore.GetOrCreate(jobName, &Job{
ID: *currentState.ExternalJobID,
OwnerReference: taskMeta.GetOwnerID(),
OwnerReference: tCtx.TaskExecutionMetadata().GetOwnerID(),
SubJobs: createSubJobList(currentState.GetExecutionArraySize()),
})

Expand Down Expand Up @@ -108,6 +121,10 @@ func CheckSubTasksState(ctx context.Context, taskMeta core.TaskExecutionMetadata
} else {
msg.Collect(childIdx, "Job failed")
}

if subJob.Status.Phase == core.PhaseRetryableFailure && *retry.Attempts == int64(len(subJob.Attempts)) {
actualPhase = core.PhasePermanentFailure
}
} else if subJob.Status.Phase.IsSuccess() {
actualPhase, err = array.CheckTaskOutput(ctx, dataStore, outputPrefix, baseOutputSandbox, childIdx, originalIdx)
if err != nil {
Expand Down
87 changes: 76 additions & 11 deletions go/tasks/plugins/array/awsbatch/monitor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package awsbatch
import (
"testing"

"github.com/stretchr/testify/mock"

"github.com/flyteorg/flytestdlib/contextutils"
"github.com/flyteorg/flytestdlib/promutils/labeled"

Expand All @@ -11,6 +13,7 @@ import (

"github.com/flyteorg/flyteplugins/go/tasks/plugins/array/arraystatus"

flyteIdl "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"

"github.com/aws/aws-sdk-go/aws/request"
Expand All @@ -19,6 +22,7 @@ import (
arrayCore "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/core"

"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks"
ioMocks "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io/mocks"
"github.com/flyteorg/flyteplugins/go/tasks/plugins/array/awsbatch/config"
batchMocks "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/awsbatch/mocks"
"github.com/flyteorg/flytestdlib/utils"
Expand All @@ -35,15 +39,39 @@ func init() {

func TestCheckSubTasksState(t *testing.T) {
ctx := context.Background()
tCtx := &mocks.TaskExecutionContext{}
tID := &mocks.TaskExecutionID{}
tID.OnGetGeneratedName().Return("generated-name")

tMeta := &mocks.TaskExecutionMetadata{}
tMeta.OnGetOwnerID().Return(types.NamespacedName{
Namespace: "domain",
Name: "name",
})
tMeta.OnGetTaskExecutionID().Return(tID)
inMemDatastore, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope())
assert.NoError(t, err)

outputWriter := &ioMocks.OutputWriter{}
outputWriter.OnGetOutputPrefixPath().Return("")
outputWriter.OnGetRawOutputPrefix().Return("")

taskReader := &mocks.TaskReader{}
task := &flyteIdl.TaskTemplate{
Type: "test",
Target: &flyteIdl.TaskTemplate_Container{
Container: &flyteIdl.Container{
Command: []string{"command"},
Args: []string{"{{.Input}}"},
},
},
Metadata: &flyteIdl.TaskMetadata{Retries: &flyteIdl.RetryStrategy{Retries: 3}},
}
taskReader.On("Read", mock.Anything).Return(task, nil)

tCtx.OnOutputWriter().Return(outputWriter)
tCtx.OnTaskReader().Return(taskReader)
tCtx.OnDataStore().Return(inMemDatastore)
tCtx.OnTaskExecutionMetadata().Return(tMeta)

t.Run("Not in cache", func(t *testing.T) {
mBatchClient := batchMocks.NewMockAwsBatchClient()
Expand All @@ -52,7 +80,7 @@ func TestCheckSubTasksState(t *testing.T) {
utils.NewRateLimiter("", 10, 20))

jobStore := newJobsStore(t, batchClient)
newState, err := CheckSubTasksState(ctx, tMeta, "", "", jobStore, nil, &config.Config{}, &State{
newState, err := CheckSubTasksState(ctx, tCtx, jobStore, &config.Config{}, &State{
State: &arrayCore.State{
CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions,
ExecutionArraySize: 5,
Expand Down Expand Up @@ -98,7 +126,7 @@ func TestCheckSubTasksState(t *testing.T) {

assert.NoError(t, err)

newState, err := CheckSubTasksState(ctx, tMeta, "", "", jobStore, nil, &config.Config{}, &State{
newState, err := CheckSubTasksState(ctx, tCtx, jobStore, &config.Config{}, &State{
State: &arrayCore.State{
CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions,
ExecutionArraySize: 5,
Expand Down Expand Up @@ -133,13 +161,10 @@ func TestCheckSubTasksState(t *testing.T) {

assert.NoError(t, err)

inMemDatastore, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope())
assert.NoError(t, err)

retryAttemptsArray, err := bitarray.NewCompactArray(1, bitarray.Item(1))
assert.NoError(t, err)

newState, err := CheckSubTasksState(ctx, tMeta, "", "", jobStore, inMemDatastore, &config.Config{}, &State{
newState, err := CheckSubTasksState(ctx, tCtx, jobStore, &config.Config{}, &State{
State: &arrayCore.State{
CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions,
ExecutionArraySize: 1,
Expand Down Expand Up @@ -181,13 +206,10 @@ func TestCheckSubTasksState(t *testing.T) {

assert.NoError(t, err)

inMemDatastore, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope())
assert.NoError(t, err)

retryAttemptsArray, err := bitarray.NewCompactArray(2, bitarray.Item(1))
assert.NoError(t, err)

newState, err := CheckSubTasksState(ctx, tMeta, "", "", jobStore, inMemDatastore, &config.Config{}, &State{
newState, err := CheckSubTasksState(ctx, tCtx, jobStore, &config.Config{}, &State{
State: &arrayCore.State{
CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions,
ExecutionArraySize: 2,
Expand All @@ -206,6 +228,49 @@ func TestCheckSubTasksState(t *testing.T) {
assert.NoError(t, err)
p, _ := newState.GetPhase()
assert.Equal(t, arrayCore.PhaseCheckingSubTaskExecutions.String(), p.String())
})

t.Run("retry limit exceeded", func(t *testing.T) {
mBatchClient := batchMocks.NewMockAwsBatchClient()
batchClient := NewCustomBatchClient(mBatchClient, "", "",
utils.NewRateLimiter("", 10, 20),
utils.NewRateLimiter("", 10, 20))

jobStore := newJobsStore(t, batchClient)
_, err := jobStore.GetOrCreate(tID.GetGeneratedName(), &Job{
ID: "job-id",
Status: JobStatus{
Phase: core.PhaseRunning,
},
SubJobs: []*Job{
{Status: JobStatus{Phase: core.PhaseRetryableFailure}, Attempts: []Attempt{{LogStream: "failed"}}},
{Status: JobStatus{Phase: core.PhaseSuccess}},
},
})

assert.NoError(t, err)

retryAttemptsArray, err := bitarray.NewCompactArray(2, bitarray.Item(1))
assert.NoError(t, err)

newState, err := CheckSubTasksState(ctx, tCtx, jobStore, &config.Config{}, &State{
State: &arrayCore.State{
CurrentPhase: arrayCore.PhaseWriteToDiscoveryThenFail,
ExecutionArraySize: 2,
OriginalArraySize: 2,
OriginalMinSuccesses: 2,
ArrayStatus: arraystatus.ArrayStatus{
Detailed: arrayCore.NewPhasesCompactArray(2),
},
IndexesToCache: bitarray.NewBitSet(2),
RetryAttempts: retryAttemptsArray,
},
ExternalJobID: refStr("job-id"),
JobDefinitionArn: "",
}, getAwsBatchExecutorMetrics(promutils.NewTestScope()))

assert.NoError(t, err)
p, _ := newState.GetPhase()
assert.Equal(t, arrayCore.PhaseWriteToDiscoveryThenFail, p)
})
}
16 changes: 16 additions & 0 deletions go/tasks/plugins/array/core/state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,22 @@ func TestSummaryToPhase(t *testing.T) {
core.PhaseSuccess: 10,
},
},
{
"FailedToRetry",
PhaseWriteToDiscoveryThenFail,
map[core.Phase]int64{
core.PhaseSuccess: 5,
core.PhasePermanentFailure: 5,
},
},
{
"Retrying",
PhaseCheckingSubTaskExecutions,
map[core.Phase]int64{
core.PhaseSuccess: 5,
core.PhaseRetryableFailure: 5,
},
},
}

for _, tt := range tests {
Expand Down

0 comments on commit d5295d2

Please sign in to comment.