Skip to content

Commit

Permalink
disttask: fix failed step is taken as success (#49971) (#49982)
Browse files Browse the repository at this point in the history
close #49950
  • Loading branch information
ti-chi-bot authored Feb 7, 2024
1 parent 75d0380 commit 523a313
Show file tree
Hide file tree
Showing 8 changed files with 139 additions and 15 deletions.
3 changes: 2 additions & 1 deletion pkg/disttask/framework/dispatcher/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,14 @@ go_test(
timeout = "short",
srcs = [
"dispatcher_manager_test.go",
"dispatcher_nokit_test.go",
"dispatcher_test.go",
"main_test.go",
],
embed = [":dispatcher"],
flaky = True,
race = "off",
shard_count = 16,
shard_count = 17,
deps = [
"//pkg/disttask/framework/mock",
"//pkg/disttask/framework/proto",
Expand Down
30 changes: 17 additions & 13 deletions pkg/disttask/framework/dispatcher/dispatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -329,23 +329,22 @@ func (d *BaseDispatcher) onPending() error {
// If subtasks finished, run into the next stage.
func (d *BaseDispatcher) onRunning() error {
logutil.Logger(d.logCtx).Debug("on running state", zap.Stringer("state", d.Task.State), zap.Int64("stage", int64(d.Task.Step)))
subTaskErrs, err := d.taskMgr.CollectSubTaskError(d.ctx, d.Task.ID)
if err != nil {
logutil.Logger(d.logCtx).Warn("collect subtask error failed", zap.Error(err))
return err
}
if len(subTaskErrs) > 0 {
logutil.Logger(d.logCtx).Warn("subtasks encounter errors")
return d.onErrHandlingStage(subTaskErrs)
}
// check current stage finished.
cnt, err := d.taskMgr.GetSubtaskInStatesCnt(d.ctx, d.Task.ID, proto.TaskStatePending, proto.TaskStateRunning)
cntByStates, err := d.taskMgr.GetSubtaskCntGroupByStates(d.ctx, d.Task.ID, d.Task.Step)
if err != nil {
logutil.Logger(d.logCtx).Warn("check task failed", zap.Error(err))
return err
}

if cnt == 0 {
if cntByStates[proto.TaskStateFailed] > 0 || cntByStates[proto.TaskStateCanceled] > 0 {
subTaskErrs, err := d.taskMgr.CollectSubTaskError(d.ctx, d.Task.ID)
if err != nil {
logutil.Logger(d.logCtx).Warn("collect subtask error failed", zap.Error(err))
return err
}
if len(subTaskErrs) > 0 {
logutil.Logger(d.logCtx).Warn("subtasks encounter errors")
return d.onErrHandlingStage(subTaskErrs)
}
} else if d.isStepSucceed(cntByStates) {
return d.onNextStage()
}
// Check if any node are down.
Expand Down Expand Up @@ -745,6 +744,11 @@ func (d *BaseDispatcher) WithNewTxn(ctx context.Context, fn func(se sessionctx.C
return d.taskMgr.WithNewTxn(ctx, fn)
}

func (*BaseDispatcher) isStepSucceed(cntByStates map[proto.TaskState]int64) bool {
_, ok := cntByStates[proto.TaskStateSucceed]
return len(cntByStates) == 0 || (len(cntByStates) == 1 && ok)
}

// IsCancelledErr checks if the error is a cancelled error.
func IsCancelledErr(err error) bool {
return strings.Contains(err.Error(), taskCancelMsg)
Expand Down
40 changes: 40 additions & 0 deletions pkg/disttask/framework/dispatcher/dispatcher_nokit_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// Copyright 2024 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package dispatcher

import (
"testing"

"github.com/pingcap/tidb/pkg/disttask/framework/proto"
"github.com/stretchr/testify/require"
)

func TestSchedulerIsStepSucceed(t *testing.T) {
s := &BaseDispatcher{}
require.True(t, s.isStepSucceed(nil))
require.True(t, s.isStepSucceed(map[proto.TaskState]int64{}))
require.True(t, s.isStepSucceed(map[proto.TaskState]int64{
proto.TaskStateSucceed: 1,
}))
for _, state := range []proto.TaskState{
proto.TaskStateCanceled,
proto.TaskStateFailed,
proto.TaskStateReverting,
} {
require.False(t, s.isStepSucceed(map[proto.TaskState]int64{
state: 1,
}))
}
}
2 changes: 2 additions & 0 deletions pkg/disttask/framework/dispatcher/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ type TaskManager interface {
GetSchedulerIDsByTaskID(ctx context.Context, taskID int64) ([]string, error)
GetSucceedSubtasksByStep(ctx context.Context, taskID int64, step proto.Step) ([]*proto.Subtask, error)
GetSchedulerIDsByTaskIDAndStep(ctx context.Context, taskID int64, step proto.Step) ([]string, error)
// GetSubtaskCntGroupByStates returns the count of subtasks of some step group by state.
GetSubtaskCntGroupByStates(ctx context.Context, taskID int64, step proto.Step) (map[proto.TaskState]int64, error)

WithNewSession(fn func(se sessionctx.Context) error) error
WithNewTxn(ctx context.Context, fn func(se sessionctx.Context) error) error
Expand Down
15 changes: 15 additions & 0 deletions pkg/disttask/framework/mock/dispatcher_mock.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion pkg/disttask/framework/storage/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,11 @@ go_test(
srcs = ["table_test.go"],
flaky = True,
race = "on",
shard_count = 8,
shard_count = 9,
deps = [
":storage",
"//pkg/disttask/framework/proto",
"//pkg/sessionctx",
"//pkg/testkit",
"//pkg/testkit/testsetup",
"@com_github_ngaut_pools//:pools",
Expand Down
40 changes: 40 additions & 0 deletions pkg/disttask/framework/storage/table_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/pingcap/failpoint"
"github.com/pingcap/tidb/pkg/disttask/framework/proto"
"github.com/pingcap/tidb/pkg/disttask/framework/storage"
"github.com/pingcap/tidb/pkg/sessionctx"
"github.com/pingcap/tidb/pkg/testkit"
"github.com/pingcap/tidb/pkg/testkit/testsetup"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -420,6 +421,45 @@ func TestBothGlobalAndSubTaskTable(t *testing.T) {
require.Equal(t, int64(0), cnt)
}

// InsertSubtask adds a new subtask of any state to subtask table.
func insertSubtask(t *testing.T, gm *storage.TaskManager, taskID int64, step proto.Step, execID string, meta []byte, state proto.TaskState, tp proto.TaskType) {
ctx := context.Background()
ctx = util.WithInternalSourceType(ctx, "table_test")
require.NoError(t, gm.WithNewSession(func(se sessionctx.Context) error {
_, err := storage.ExecSQL(ctx, se, `
insert into mysql.tidb_background_subtask(step, task_key, exec_id, meta, state, type, state_update_time, checkpoint, summary) values`+
`(%?, %?, %?, %?, %?, %?, CURRENT_TIMESTAMP(), '{}', '{}')`,
step, taskID, execID, meta, state, proto.Type2Int(tp))
return err
}))
}

func TestGetSubtaskCntByStates(t *testing.T) {
pool := GetResourcePool(t)
sm := GetTaskManager(t, pool)
defer pool.Close()
ctx := context.Background()
ctx = util.WithInternalSourceType(ctx, "table_test")

insertSubtask(t, sm, 1, proto.StepOne, "tidb1", nil, proto.TaskStatePending, "test")
insertSubtask(t, sm, 1, proto.StepOne, "tidb1", nil, proto.TaskStatePending, "test")
insertSubtask(t, sm, 1, proto.StepOne, "tidb1", nil, proto.TaskStateRunning, "test")
insertSubtask(t, sm, 1, proto.StepOne, "tidb1", nil, proto.TaskStateSucceed, "test")
insertSubtask(t, sm, 1, proto.StepOne, "tidb1", nil, proto.TaskStateFailed, "test")
insertSubtask(t, sm, 1, proto.StepTwo, "tidb1", nil, proto.TaskStateFailed, "test")
cntByStates, err := sm.GetSubtaskCntGroupByStates(ctx, 1, proto.StepOne)
require.NoError(t, err)
require.Len(t, cntByStates, 4)
require.Equal(t, int64(2), cntByStates[proto.TaskStatePending])
require.Equal(t, int64(1), cntByStates[proto.TaskStateRunning])
require.Equal(t, int64(1), cntByStates[proto.TaskStateSucceed])
require.Equal(t, int64(1), cntByStates[proto.TaskStateFailed])
cntByStates, err = sm.GetSubtaskCntGroupByStates(ctx, 1, proto.StepTwo)
require.NoError(t, err)
require.Len(t, cntByStates, 1)
require.Equal(t, int64(1), cntByStates[proto.TaskStateFailed])
}

func TestDistFrameworkMeta(t *testing.T) {
pool := GetResourcePool(t)
sm := GetTaskManager(t, pool)
Expand Down
21 changes: 21 additions & 0 deletions pkg/disttask/framework/storage/task_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,27 @@ func (stm *TaskManager) GetSubtaskInStatesCnt(ctx context.Context, taskID int64,
return rs[0].GetInt64(0), nil
}

// GetSubtaskCntGroupByStates gets the subtask count by states.
func (stm *TaskManager) GetSubtaskCntGroupByStates(ctx context.Context, taskID int64, step proto.Step) (map[proto.TaskState]int64, error) {
rs, err := stm.executeSQLWithNewSession(ctx, `
select state, count(*)
from mysql.tidb_background_subtask
where task_key = %? and step = %?
group by state`,
taskID, step)
if err != nil {
return nil, err
}

res := make(map[proto.TaskState]int64, len(rs))
for _, r := range rs {
state := proto.TaskState(r.GetString(0))
res[state] = r.GetInt64(1)
}

return res, nil
}

// CollectSubTaskError collects the subtask error.
func (stm *TaskManager) CollectSubTaskError(ctx context.Context, taskID int64) ([]error, error) {
rs, err := stm.executeSQLWithNewSession(ctx,
Expand Down

0 comments on commit 523a313

Please sign in to comment.