diff --git a/ddl/BUILD.bazel b/ddl/BUILD.bazel index e7130f05fc33c..68044361552e8 100644 --- a/ddl/BUILD.bazel +++ b/ddl/BUILD.bazel @@ -74,6 +74,7 @@ go_library( "//disttask/framework/handle", "//disttask/framework/proto", "//disttask/framework/scheduler", + "//disttask/framework/storage", "//disttask/operator", "//domain/infosync", "//domain/resourcegroup", diff --git a/ddl/backfilling_dispatcher.go b/ddl/backfilling_dispatcher.go index 8e7bf825e9690..7e16672bde7c0 100644 --- a/ddl/backfilling_dispatcher.go +++ b/ddl/backfilling_dispatcher.go @@ -23,6 +23,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/tidb/disttask/framework/dispatcher" "github.com/pingcap/tidb/disttask/framework/proto" + "github.com/pingcap/tidb/disttask/framework/storage" "github.com/pingcap/tidb/domain/infosync" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/meta" @@ -32,28 +33,28 @@ import ( "github.com/tikv/client-go/v2/tikv" ) -type backfillingDispatcher struct { +type backfillingDispatcherExt struct { d *ddl } -var _ dispatcher.Dispatcher = (*backfillingDispatcher)(nil) +var _ dispatcher.Extension = (*backfillingDispatcherExt)(nil) -// NewBackfillingDispatcher creates a new backfillingDispatcher. -func NewBackfillingDispatcher(d DDL) (dispatcher.Dispatcher, error) { +// NewBackfillingDispatcherExt creates a new backfillingDispatcherExt. +func NewBackfillingDispatcherExt(d DDL) (dispatcher.Extension, error) { ddl, ok := d.(*ddl) if !ok { return nil, errors.New("The getDDL result should be the type of *ddl") } - return &backfillingDispatcher{ + return &backfillingDispatcherExt{ d: ddl, }, nil } -func (*backfillingDispatcher) OnTick(_ context.Context, _ *proto.Task) { +func (*backfillingDispatcherExt) OnTick(_ context.Context, _ *proto.Task) { } // OnNextStage generate next stage's plan. -func (h *backfillingDispatcher) OnNextStage(ctx context.Context, _ dispatcher.TaskHandle, gTask *proto.Task) ([][]byte, error) { +func (h *backfillingDispatcherExt) OnNextStage(ctx context.Context, _ dispatcher.TaskHandle, gTask *proto.Task) ([][]byte, error) { var globalTaskMeta BackfillGlobalMeta if err := json.Unmarshal(gTask.Meta, &globalTaskMeta); err != nil { return nil, err @@ -112,7 +113,7 @@ func (h *backfillingDispatcher) OnNextStage(ctx context.Context, _ dispatcher.Ta } // OnErrStage generate error handling stage's plan. -func (*backfillingDispatcher) OnErrStage(_ context.Context, _ dispatcher.TaskHandle, task *proto.Task, receiveErr []error) (meta []byte, err error) { +func (*backfillingDispatcherExt) OnErrStage(_ context.Context, _ dispatcher.TaskHandle, task *proto.Task, receiveErr []error) (meta []byte, err error) { // We do not need extra meta info when rolling back firstErr := receiveErr[0] task.Error = firstErr @@ -120,15 +121,28 @@ func (*backfillingDispatcher) OnErrStage(_ context.Context, _ dispatcher.TaskHan return nil, nil } -func (*backfillingDispatcher) GetEligibleInstances(ctx context.Context, _ *proto.Task) ([]*infosync.ServerInfo, error) { +func (*backfillingDispatcherExt) GetEligibleInstances(ctx context.Context, _ *proto.Task) ([]*infosync.ServerInfo, error) { return dispatcher.GenerateSchedulerNodes(ctx) } // IsRetryableErr implements TaskFlowHandle.IsRetryableErr interface. 
-func (*backfillingDispatcher) IsRetryableErr(error) bool { +func (*backfillingDispatcherExt) IsRetryableErr(error) bool { return true } +type litBackfillDispatcher struct { + *dispatcher.BaseDispatcher +} + +func newLitBackfillDispatcher(ctx context.Context, taskMgr *storage.TaskManager, + serverID string, task *proto.Task, handle dispatcher.Extension) dispatcher.Dispatcher { + dis := litBackfillDispatcher{ + BaseDispatcher: dispatcher.NewBaseDispatcher(ctx, taskMgr, serverID, task), + } + dis.BaseDispatcher.Extension = handle + return &dis +} + func getTblInfo(d *ddl, job *model.Job) (tblInfo *model.TableInfo, err error) { err = kv.RunInNewTxn(d.ctx, d.store, true, func(ctx context.Context, txn kv.Transaction) error { tblInfo, err = meta.NewMeta(txn).GetTable(job.SchemaID, job.TableID) diff --git a/ddl/backfilling_dispatcher_test.go b/ddl/backfilling_dispatcher_test.go index 72381dd64ca5a..94e4ebb487759 100644 --- a/ddl/backfilling_dispatcher_test.go +++ b/ddl/backfilling_dispatcher_test.go @@ -33,7 +33,7 @@ import ( func TestBackfillingDispatcher(t *testing.T) { store, dom := testkit.CreateMockStoreAndDomain(t) - dsp, err := ddl.NewBackfillingDispatcher(dom.DDL()) + dsp, err := ddl.NewBackfillingDispatcherExt(dom.DDL()) require.NoError(t, err) tk := testkit.NewTestKit(t, store) tk.MustExec("use test") diff --git a/ddl/ddl.go b/ddl/ddl.go index ab6486e58439e..7e9a5689517e3 100644 --- a/ddl/ddl.go +++ b/ddl/ddl.go @@ -43,6 +43,7 @@ import ( "github.com/pingcap/tidb/disttask/framework/dispatcher" "github.com/pingcap/tidb/disttask/framework/proto" "github.com/pingcap/tidb/disttask/framework/scheduler" + "github.com/pingcap/tidb/disttask/framework/storage" "github.com/pingcap/tidb/domain/infosync" "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/kv" @@ -690,11 +691,14 @@ func newDDL(ctx context.Context, options ...Option) *ddl { return NewBackfillSchedulerHandle(ctx, taskMeta, d, step == proto.StepTwo) }) - backFillDsp, err := NewBackfillingDispatcher(d) + backFillDsp, err := NewBackfillingDispatcherExt(d) if err != nil { - logutil.BgLogger().Warn("NewBackfillingDispatcher failed", zap.String("category", "ddl"), zap.Error(err)) + logutil.BgLogger().Warn("NewBackfillingDispatcherExt failed", zap.String("category", "ddl"), zap.Error(err)) } else { - dispatcher.RegisterTaskDispatcher(BackfillTaskType, backFillDsp) + dispatcher.RegisterDispatcherFactory(BackfillTaskType, + func(ctx context.Context, taskMgr *storage.TaskManager, serverID string, task *proto.Task) dispatcher.Dispatcher { + return newLitBackfillDispatcher(ctx, taskMgr, serverID, task, backFillDsp) + }) scheduler.RegisterSubtaskExectorConstructor(BackfillTaskType, proto.StepOne, func(proto.MinimalTask, int64) (scheduler.SubtaskExecutor, error) { return &scheduler.EmptyExecutor{}, nil diff --git a/disttask/framework/dispatcher/BUILD.bazel b/disttask/framework/dispatcher/BUILD.bazel index 4ea32cf32dea4..246b7e8568392 100644 --- a/disttask/framework/dispatcher/BUILD.bazel +++ b/disttask/framework/dispatcher/BUILD.bazel @@ -24,7 +24,6 @@ go_library( "//util/syncutil", "@com_github_pingcap_errors//:errors", "@com_github_pingcap_failpoint//:failpoint", - "@org_golang_x_exp//maps", "@org_uber_go_zap//:zap", ], ) diff --git a/disttask/framework/dispatcher/dispatcher.go b/disttask/framework/dispatcher/dispatcher.go index 94844985a8e2b..da001bafdca60 100644 --- a/disttask/framework/dispatcher/dispatcher.go +++ b/disttask/framework/dispatcher/dispatcher.go @@ -60,15 +60,22 @@ type TaskHandle interface { storage.SessionExecutor } 
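Note: the pattern introduced above for index backfill (embed *dispatcher.BaseDispatcher, set its Extension, and register a FactoryFn) is what every task type follows after this change. Below is a minimal sketch of that pattern for a hypothetical task type; the package, type, and task-type names are illustrative only and not part of this diff, but the dispatcher/proto/storage APIs used are the ones added or kept by this patch.

```go
// Package exampletask is a hypothetical task type, shown only to illustrate
// the embed-BaseDispatcher-and-register-factory pattern used above.
package exampletask

import (
	"context"

	"github.com/pingcap/tidb/disttask/framework/dispatcher"
	"github.com/pingcap/tidb/disttask/framework/proto"
	"github.com/pingcap/tidb/disttask/framework/storage"
	"github.com/pingcap/tidb/domain/infosync"
)

// myExtension is a hypothetical Extension implementation; a real task type
// would put its stage-planning logic in OnNextStage/OnErrStage.
type myExtension struct{}

var _ dispatcher.Extension = (*myExtension)(nil)

func (*myExtension) OnTick(context.Context, *proto.Task) {}

func (*myExtension) OnNextStage(context.Context, dispatcher.TaskHandle, *proto.Task) ([][]byte, error) {
	return nil, nil
}

func (*myExtension) OnErrStage(context.Context, dispatcher.TaskHandle, *proto.Task, []error) ([]byte, error) {
	return nil, nil
}

func (*myExtension) GetEligibleInstances(ctx context.Context, _ *proto.Task) ([]*infosync.ServerInfo, error) {
	return dispatcher.GenerateSchedulerNodes(ctx)
}

func (*myExtension) IsRetryableErr(error) bool { return true }

// myDispatcher embeds BaseDispatcher to inherit the generic task lifecycle
// (scheduling, subtask dispatch, error handling) and only supplies the
// task-type-specific hooks through the Extension field.
type myDispatcher struct {
	*dispatcher.BaseDispatcher
}

func newMyDispatcher(ctx context.Context, taskMgr *storage.TaskManager,
	serverID string, task *proto.Task) dispatcher.Dispatcher {
	dis := myDispatcher{
		BaseDispatcher: dispatcher.NewBaseDispatcher(ctx, taskMgr, serverID, task),
	}
	// The factory MUST set the Extension field, as noted on BaseDispatcher.
	dis.BaseDispatcher.Extension = &myExtension{}
	return &dis
}

func init() {
	// "example_task" is an illustrative task-type key.
	dispatcher.RegisterDispatcherFactory("example_task", newMyDispatcher)
}
```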
-// Manage the lifetime of a task. +// Dispatcher manages the lifetime of a task // including submitting subtasks and updating the status of a task. -type dispatcher struct { +type Dispatcher interface { + ExecuteTask() +} + +// BaseDispatcher is the base struct for Dispatcher. +// each task type embed this struct and implement the Extension interface. +type BaseDispatcher struct { ctx context.Context taskMgr *storage.TaskManager task *proto.Task logCtx context.Context serverID string - impl Dispatcher + // when RegisterDispatcherFactory, the factory MUST initialize this field. + Extension // for HA // liveNodes will fetch and store all live nodes every liveNodeInterval ticks. @@ -85,33 +92,25 @@ type dispatcher struct { // MockOwnerChange mock owner change in tests. var MockOwnerChange func() -func newDispatcher(ctx context.Context, taskMgr *storage.TaskManager, serverID string, task *proto.Task) (*dispatcher, error) { +// NewBaseDispatcher creates a new BaseDispatcher. +func NewBaseDispatcher(ctx context.Context, taskMgr *storage.TaskManager, serverID string, task *proto.Task) *BaseDispatcher { logPrefix := fmt.Sprintf("task_id: %d, task_type: %s, server_id: %s", task.ID, task.Type, serverID) - impl := GetTaskDispatcher(task.Type) - dsp := &dispatcher{ + return &BaseDispatcher{ ctx: ctx, taskMgr: taskMgr, task: task, logCtx: logutil.WithKeyValue(context.Background(), "dispatcher", logPrefix), serverID: serverID, - impl: impl, liveNodes: nil, liveNodeFetchInterval: DefaultLiveNodesCheckInterval, liveNodeFetchTick: 0, taskNodes: nil, rand: rand.New(rand.NewSource(time.Now().UnixNano())), } - if dsp.impl == nil { - logutil.BgLogger().Warn("gen dispatcher impl failed, this type impl doesn't register") - dsp.task.Error = errors.New("unsupported task type") - // state transform: pending -> failed. - return nil, dsp.updateTask(proto.TaskStateFailed, nil, retrySQLTimes) - } - return dsp, nil } // ExecuteTask start to schedule a task. -func (d *dispatcher) executeTask() { +func (d *BaseDispatcher) ExecuteTask() { logutil.Logger(d.logCtx).Info("execute one task", zap.String("state", d.task.State), zap.Uint64("concurrency", d.task.Concurrency)) d.scheduleTask() @@ -119,7 +118,7 @@ func (d *dispatcher) executeTask() { } // refreshTask fetch task state from tidb_global_task table. -func (d *dispatcher) refreshTask() (err error) { +func (d *BaseDispatcher) refreshTask() (err error) { d.task, err = d.taskMgr.GetGlobalTaskByID(d.task.ID) if err != nil { logutil.Logger(d.logCtx).Error("refresh task failed", zap.Error(err)) @@ -128,7 +127,7 @@ func (d *dispatcher) refreshTask() (err error) { } // scheduleTask schedule the task execution step by step. -func (d *dispatcher) scheduleTask() { +func (d *BaseDispatcher) scheduleTask() { ticker := time.NewTicker(checkTaskFinishedInterval) defer ticker.Stop() for { @@ -178,14 +177,14 @@ func (d *dispatcher) scheduleTask() { } // handle task in cancelling state, dispatch revert subtasks. -func (d *dispatcher) onCancelling() error { +func (d *BaseDispatcher) onCancelling() error { logutil.Logger(d.logCtx).Debug("on cancelling state", zap.String("state", d.task.State), zap.Int64("stage", d.task.Step)) errs := []error{errors.New("cancel")} return d.onErrHandlingStage(errs) } // handle task in reverting state, check all revert subtasks finished. 
-func (d *dispatcher) onReverting() error { +func (d *BaseDispatcher) onReverting() error { logutil.Logger(d.logCtx).Debug("on reverting state", zap.String("state", d.task.State), zap.Int64("stage", d.task.Step)) cnt, err := d.taskMgr.GetSubtaskInStatesCnt(d.task.ID, proto.TaskStateRevertPending, proto.TaskStateReverting) if err != nil { @@ -199,20 +198,20 @@ func (d *dispatcher) onReverting() error { return d.updateTask(proto.TaskStateReverted, nil, retrySQLTimes) } // Wait all subtasks in this stage finished. - d.impl.OnTick(d.ctx, d.task) + d.OnTick(d.ctx, d.task) logutil.Logger(d.logCtx).Debug("on reverting state, this task keeps current state", zap.String("state", d.task.State)) return nil } // handle task in pending state, dispatch subtasks. -func (d *dispatcher) onPending() error { +func (d *BaseDispatcher) onPending() error { logutil.Logger(d.logCtx).Debug("on pending state", zap.String("state", d.task.State), zap.Int64("stage", d.task.Step)) return d.onNextStage() } // handle task in running state, check all running subtasks finished. // If subtasks finished, run into the next stage. -func (d *dispatcher) onRunning() error { +func (d *BaseDispatcher) onRunning() error { logutil.Logger(d.logCtx).Debug("on running state", zap.String("state", d.task.State), zap.Int64("stage", d.task.Step)) subTaskErrs, err := d.taskMgr.CollectSubTaskError(d.task.ID) if err != nil { @@ -240,12 +239,12 @@ func (d *dispatcher) onRunning() error { return err } // Wait all subtasks in this stage finished. - d.impl.OnTick(d.ctx, d.task) + d.OnTick(d.ctx, d.task) logutil.Logger(d.logCtx).Debug("on running state, this task keeps current state", zap.String("state", d.task.State)) return nil } -func (d *dispatcher) replaceDeadNodesIfAny() error { +func (d *BaseDispatcher) replaceDeadNodesIfAny() error { if len(d.taskNodes) == 0 { return errors.Errorf("len(d.taskNodes) == 0, onNextStage is not invoked before onRunning") } @@ -256,7 +255,7 @@ func (d *dispatcher) replaceDeadNodesIfAny() error { if err != nil { return err } - eligibleServerInfos, err := d.impl.GetEligibleInstances(d.ctx, d.task) + eligibleServerInfos, err := d.GetEligibleInstances(d.ctx, d.task) if err != nil { return err } @@ -299,7 +298,7 @@ func (d *dispatcher) replaceDeadNodesIfAny() error { return nil } -func (d *dispatcher) updateTask(taskState string, newSubTasks []*proto.Subtask, retryTimes int) (err error) { +func (d *BaseDispatcher) updateTask(taskState string, newSubTasks []*proto.Subtask, retryTimes int) (err error) { prevState := d.task.State d.task.State = taskState if !VerifyTaskStateTransform(prevState, taskState) { @@ -331,9 +330,9 @@ func (d *dispatcher) updateTask(taskState string, newSubTasks []*proto.Subtask, return err } -func (d *dispatcher) onErrHandlingStage(receiveErr []error) error { +func (d *BaseDispatcher) onErrHandlingStage(receiveErr []error) error { // 1. generate the needed task meta and subTask meta (dist-plan). - meta, err := d.impl.OnErrStage(d.ctx, d, d.task, receiveErr) + meta, err := d.OnErrStage(d.ctx, d, d.task, receiveErr) if err != nil { // OnErrStage must be retryable, if not, there will have resource leak for tasks. 
logutil.Logger(d.logCtx).Warn("handle error failed", zap.Error(err)) @@ -344,7 +343,7 @@ func (d *dispatcher) onErrHandlingStage(receiveErr []error) error { return d.dispatchSubTask4Revert(d.task, meta) } -func (d *dispatcher) dispatchSubTask4Revert(task *proto.Task, meta []byte) error { +func (d *BaseDispatcher) dispatchSubTask4Revert(task *proto.Task, meta []byte) error { instanceIDs, err := d.GetAllSchedulerIDs(d.ctx, task) if err != nil { logutil.Logger(d.logCtx).Warn("get task's all instances failed", zap.Error(err)) @@ -358,9 +357,9 @@ func (d *dispatcher) dispatchSubTask4Revert(task *proto.Task, meta []byte) error return d.updateTask(proto.TaskStateReverting, subTasks, retrySQLTimes) } -func (d *dispatcher) onNextStage() error { +func (d *BaseDispatcher) onNextStage() error { // 1. generate the needed global task meta and subTask meta (dist-plan). - metas, err := d.impl.OnNextStage(d.ctx, d, d.task) + metas, err := d.OnNextStage(d.ctx, d, d.task) if err != nil { return d.handlePlanErr(err) } @@ -368,7 +367,7 @@ func (d *dispatcher) onNextStage() error { return d.dispatchSubTask(d.task, metas) } -func (d *dispatcher) dispatchSubTask(task *proto.Task, metas [][]byte) error { +func (d *BaseDispatcher) dispatchSubTask(task *proto.Task, metas [][]byte) error { logutil.Logger(d.logCtx).Info("dispatch subtasks", zap.String("state", d.task.State), zap.Uint64("concurrency", d.task.Concurrency), zap.Int("subtasks", len(metas))) // 1. Adjust the global task's concurrency. if task.Concurrency == 0 { @@ -400,7 +399,7 @@ func (d *dispatcher) dispatchSubTask(task *proto.Task, metas [][]byte) error { } // 3. select all available TiDB nodes for task. - serverNodes, err := d.impl.GetEligibleInstances(d.ctx, task) + serverNodes, err := d.GetEligibleInstances(d.ctx, task) logutil.Logger(d.logCtx).Debug("eligible instances", zap.Int("num", len(serverNodes))) if err != nil { @@ -424,9 +423,9 @@ func (d *dispatcher) dispatchSubTask(task *proto.Task, metas [][]byte) error { return d.updateTask(proto.TaskStateRunning, subTasks, retrySQLTimes) } -func (d *dispatcher) handlePlanErr(err error) error { +func (d *BaseDispatcher) handlePlanErr(err error) error { logutil.Logger(d.logCtx).Warn("generate plan failed", zap.Error(err), zap.String("state", d.task.State)) - if d.impl.IsRetryableErr(err) { + if d.IsRetryableErr(err) { return err } d.task.Error = err @@ -458,8 +457,8 @@ func GenerateSchedulerNodes(ctx context.Context) (serverNodes []*infosync.Server } // GetAllSchedulerIDs gets all the scheduler IDs. -func (d *dispatcher) GetAllSchedulerIDs(ctx context.Context, task *proto.Task) ([]string, error) { - serverInfos, err := d.impl.GetEligibleInstances(ctx, task) +func (d *BaseDispatcher) GetAllSchedulerIDs(ctx context.Context, task *proto.Task) ([]string, error) { + serverInfos, err := d.GetEligibleInstances(ctx, task) if err != nil { return nil, err } @@ -481,7 +480,7 @@ func (d *dispatcher) GetAllSchedulerIDs(ctx context.Context, task *proto.Task) ( } // GetPreviousSubtaskMetas get subtask metas from specific step. 
-func (d *dispatcher) GetPreviousSubtaskMetas(taskID int64, step int64) ([][]byte, error) { +func (d *BaseDispatcher) GetPreviousSubtaskMetas(taskID int64, step int64) ([][]byte, error) { previousSubtasks, err := d.taskMgr.GetSucceedSubtasksByStep(taskID, step) if err != nil { logutil.Logger(d.logCtx).Warn("get previous succeed subtask failed", zap.Int64("step", step)) @@ -495,12 +494,12 @@ func (d *dispatcher) GetPreviousSubtaskMetas(taskID int64, step int64) ([][]byte } // WithNewSession executes the function with a new session. -func (d *dispatcher) WithNewSession(fn func(se sessionctx.Context) error) error { +func (d *BaseDispatcher) WithNewSession(fn func(se sessionctx.Context) error) error { return d.taskMgr.WithNewSession(fn) } // WithNewTxn executes the fn in a new transaction. -func (d *dispatcher) WithNewTxn(ctx context.Context, fn func(se sessionctx.Context) error) error { +func (d *BaseDispatcher) WithNewTxn(ctx context.Context, fn func(se sessionctx.Context) error) error { return d.taskMgr.WithNewTxn(ctx, fn) } diff --git a/disttask/framework/dispatcher/dispatcher_manager.go b/disttask/framework/dispatcher/dispatcher_manager.go index 5ab2ab6cf3a24..7a750a21e9adc 100644 --- a/disttask/framework/dispatcher/dispatcher_manager.go +++ b/disttask/framework/dispatcher/dispatcher_manager.go @@ -18,6 +18,7 @@ import ( "context" "time" + "github.com/pingcap/errors" "github.com/pingcap/tidb/disttask/framework/proto" "github.com/pingcap/tidb/disttask/framework/storage" "github.com/pingcap/tidb/resourcemanager/pool/spool" @@ -34,7 +35,7 @@ func (dm *Manager) getRunningTaskCnt() int { return len(dm.runningTasks.taskIDs) } -func (dm *Manager) setRunningTask(task *proto.Task, dispatcher *dispatcher) { +func (dm *Manager) setRunningTask(task *proto.Task, dispatcher Dispatcher) { dm.runningTasks.Lock() defer dm.runningTasks.Unlock() dm.runningTasks.taskIDs[task.ID] = struct{}{} @@ -81,7 +82,7 @@ type Manager struct { runningTasks struct { syncutil.RWMutex taskIDs map[int64]struct{} - dispatchers map[int64]*dispatcher + dispatchers map[int64]Dispatcher } finishedTaskCh chan *proto.Task } @@ -100,7 +101,7 @@ func NewManager(ctx context.Context, taskTable *storage.TaskManager, serverID st dispatcherManager.gPool = gPool dispatcherManager.ctx, dispatcherManager.cancel = context.WithCancel(ctx) dispatcherManager.runningTasks.taskIDs = make(map[int64]struct{}) - dispatcherManager.runningTasks.dispatchers = make(map[int64]*dispatcher) + dispatcherManager.runningTasks.dispatchers = make(map[int64]Dispatcher) return dispatcherManager, nil } @@ -157,6 +158,21 @@ func (dm *Manager) dispatchTaskLoop() { if dm.isRunningTask(task.ID) { continue } + // we check it before start dispatcher, so no need to check it again. + // see startDispatcher. + // this should not happen normally, unless user modify system table + // directly. + if GetDispatcherFactory(task.Type) == nil { + logutil.BgLogger().Warn("unknown task type", zap.Int64("task-id", task.ID), + zap.String("task-type", task.Type)) + prevState := task.State + task.State = proto.TaskStateFailed + task.Error = errors.New("unknown task type") + if _, err2 := dm.taskMgr.UpdateGlobalTaskAndAddSubTasks(task, nil, prevState); err2 != nil { + logutil.BgLogger().Warn("update task state of unknown type failed", zap.Error(err2)) + } + continue + } // the task is not in runningTasks set when: // owner changed or task is cancelled when status is pending. 
if task.State == proto.TaskStateRunning || task.State == proto.TaskStateReverting || task.State == proto.TaskStateCancelling { @@ -186,17 +202,15 @@ func (*Manager) checkConcurrencyOverflow(cnt int) bool { func (dm *Manager) startDispatcher(task *proto.Task) { // Using the pool with block, so it wouldn't return an error. _ = dm.gPool.Run(func() { - dispatcher, err := newDispatcher(dm.ctx, dm.taskMgr, dm.serverID, task) - if err != nil { - return - } + dispatcherFactory := GetDispatcherFactory(task.Type) + dispatcher := dispatcherFactory(dm.ctx, dm.taskMgr, dm.serverID, task) dm.setRunningTask(task, dispatcher) - dispatcher.executeTask() + dispatcher.ExecuteTask() dm.delRunningTask(task.ID) }) } // MockDispatcher mock one dispatcher for one task, only used for tests. -func (dm *Manager) MockDispatcher(task *proto.Task) (*dispatcher, error) { - return newDispatcher(dm.ctx, dm.taskMgr, dm.serverID, task) +func (dm *Manager) MockDispatcher(task *proto.Task) *BaseDispatcher { + return NewBaseDispatcher(dm.ctx, dm.taskMgr, dm.serverID, task) } diff --git a/disttask/framework/dispatcher/dispatcher_test.go b/disttask/framework/dispatcher/dispatcher_test.go index 0e8fff03d53db..15511cb919b17 100644 --- a/disttask/framework/dispatcher/dispatcher_test.go +++ b/disttask/framework/dispatcher/dispatcher_test.go @@ -35,44 +35,43 @@ import ( ) var ( - _ dispatcher.Dispatcher = (*testDispatcher)(nil) - _ dispatcher.Dispatcher = (*numberExampleDispatcher)(nil) + _ dispatcher.Extension = (*testDispatcherExt)(nil) + _ dispatcher.Extension = (*numberExampleDispatcherExt)(nil) ) const ( - taskTypeExample = "task_example" - subtaskCnt = 3 + subtaskCnt = 3 ) -type testDispatcher struct{} +type testDispatcherExt struct{} -func (*testDispatcher) OnTick(_ context.Context, _ *proto.Task) { +func (*testDispatcherExt) OnTick(_ context.Context, _ *proto.Task) { } -func (*testDispatcher) OnNextStage(_ context.Context, _ dispatcher.TaskHandle, task *proto.Task) (metas [][]byte, err error) { +func (*testDispatcherExt) OnNextStage(_ context.Context, _ dispatcher.TaskHandle, task *proto.Task) (metas [][]byte, err error) { return nil, nil } -func (*testDispatcher) OnErrStage(_ context.Context, _ dispatcher.TaskHandle, _ *proto.Task, _ []error) (meta []byte, err error) { +func (*testDispatcherExt) OnErrStage(_ context.Context, _ dispatcher.TaskHandle, _ *proto.Task, _ []error) (meta []byte, err error) { return nil, nil } var mockedAllServerInfos = []*infosync.ServerInfo{} -func (*testDispatcher) GetEligibleInstances(_ context.Context, _ *proto.Task) ([]*infosync.ServerInfo, error) { +func (*testDispatcherExt) GetEligibleInstances(_ context.Context, _ *proto.Task) ([]*infosync.ServerInfo, error) { return mockedAllServerInfos, nil } -func (*testDispatcher) IsRetryableErr(error) bool { +func (*testDispatcherExt) IsRetryableErr(error) bool { return true } -type numberExampleDispatcher struct{} +type numberExampleDispatcherExt struct{} -func (*numberExampleDispatcher) OnTick(_ context.Context, _ *proto.Task) { +func (*numberExampleDispatcherExt) OnTick(_ context.Context, _ *proto.Task) { } -func (n *numberExampleDispatcher) OnNextStage(_ context.Context, _ dispatcher.TaskHandle, task *proto.Task) (metas [][]byte, err error) { +func (n *numberExampleDispatcherExt) OnNextStage(_ context.Context, _ dispatcher.TaskHandle, task *proto.Task) (metas [][]byte, err error) { if task.State == proto.TaskStatePending { task.Step = proto.StepInit } @@ -92,16 +91,16 @@ func (n *numberExampleDispatcher) OnNextStage(_ context.Context, _ 
dispatcher.Ta return metas, nil } -func (n *numberExampleDispatcher) OnErrStage(_ context.Context, _ dispatcher.TaskHandle, _ *proto.Task, _ []error) (meta []byte, err error) { +func (n *numberExampleDispatcherExt) OnErrStage(_ context.Context, _ dispatcher.TaskHandle, _ *proto.Task, _ []error) (meta []byte, err error) { // Don't handle not. return nil, nil } -func (*numberExampleDispatcher) GetEligibleInstances(ctx context.Context, _ *proto.Task) ([]*infosync.ServerInfo, error) { +func (*numberExampleDispatcherExt) GetEligibleInstances(ctx context.Context, _ *proto.Task) ([]*infosync.ServerInfo, error) { return dispatcher.GenerateSchedulerNodes(ctx) } -func (*numberExampleDispatcher) IsRetryableErr(error) bool { +func (*numberExampleDispatcherExt) IsRetryableErr(error) bool { return true } @@ -111,7 +110,12 @@ func MockDispatcherManager(t *testing.T, pool *pools.ResourcePool) (*dispatcher. storage.SetTaskManager(mgr) dsp, err := dispatcher.NewManager(util.WithInternalSourceType(ctx, "dispatcher"), mgr, "host:port") require.NoError(t, err) - dispatcher.RegisterTaskDispatcher(proto.TaskTypeExample, &testDispatcher{}) + dispatcher.RegisterDispatcherFactory(proto.TaskTypeExample, + func(ctx context.Context, taskMgr *storage.TaskManager, serverID string, task *proto.Task) dispatcher.Dispatcher { + mockDispatcher := dsp.MockDispatcher(task) + mockDispatcher.Extension = &testDispatcherExt{} + return mockDispatcher + }) return dsp, mgr } @@ -132,8 +136,8 @@ func TestGetInstance(t *testing.T) { dspManager, mgr := MockDispatcherManager(t, pool) // test no server task := &proto.Task{ID: 1, Type: proto.TaskTypeExample} - dsp, err := dspManager.MockDispatcher(task) - require.NoError(t, err) + dsp := dspManager.MockDispatcher(task) + dsp.Extension = &testDispatcherExt{} instanceIDs, err := dsp.GetAllSchedulerIDs(ctx, task) require.Lenf(t, instanceIDs, 0, "GetAllSchedulerIDs when there's no subtask") require.NoError(t, err) @@ -188,7 +192,6 @@ func TestGetInstance(t *testing.T) { } func checkDispatch(t *testing.T, taskCnt int, isSucc bool, isCancel bool) { - dispatcher.RegisterTaskDispatcher(taskTypeExample, &numberExampleDispatcher{}) require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/domain/MockDisableDistTask", "return(true)")) defer func() { require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/domain/MockDisableDistTask")) @@ -209,6 +212,12 @@ func checkDispatch(t *testing.T, taskCnt int, isSucc bool, isCancel bool) { defer pool.Close() dsp, mgr := MockDispatcherManager(t, pool) + dispatcher.RegisterDispatcherFactory(proto.TaskTypeExample, + func(ctx context.Context, taskMgr *storage.TaskManager, serverID string, task *proto.Task) dispatcher.Dispatcher { + mockDispatcher := dsp.MockDispatcher(task) + mockDispatcher.Extension = &numberExampleDispatcherExt{} + return mockDispatcher + }) dsp.Start() defer func() { dsp.Stop() @@ -239,7 +248,7 @@ func checkDispatch(t *testing.T, taskCnt int, isSucc bool, isCancel bool) { // Mock add tasks. 
taskIDs := make([]int64, 0, taskCnt) for i := 0; i < taskCnt; i++ { - taskID, err := mgr.AddNewGlobalTask(fmt.Sprintf("%d", i), taskTypeExample, 0, nil) + taskID, err := mgr.AddNewGlobalTask(fmt.Sprintf("%d", i), proto.TaskTypeExample, 0, nil) require.NoError(t, err) taskIDs = append(taskIDs, taskID) } @@ -254,7 +263,7 @@ func checkDispatch(t *testing.T, taskCnt int, isSucc bool, isCancel bool) { } // test parallelism control if taskCnt == 1 { - taskID, err := mgr.AddNewGlobalTask(fmt.Sprintf("%d", taskCnt), taskTypeExample, 0, nil) + taskID, err := mgr.AddNewGlobalTask(fmt.Sprintf("%d", taskCnt), proto.TaskTypeExample, 0, nil) require.NoError(t, err) checkGetRunningTaskCnt(taskCnt) // Clean the task. diff --git a/disttask/framework/dispatcher/interface.go b/disttask/framework/dispatcher/interface.go index 95be7cfc4be6a..af021d0daff8d 100644 --- a/disttask/framework/dispatcher/interface.go +++ b/disttask/framework/dispatcher/interface.go @@ -18,13 +18,16 @@ import ( "context" "github.com/pingcap/tidb/disttask/framework/proto" + "github.com/pingcap/tidb/disttask/framework/storage" "github.com/pingcap/tidb/domain/infosync" "github.com/pingcap/tidb/util/syncutil" - "golang.org/x/exp/maps" ) -// Dispatcher is used to control the process operations for each task. -type Dispatcher interface { +// Extension is used to control the process operations for each task. +// it's used to extend functions of BaseDispatcher. +// as golang doesn't support inheritance, we embed this interface in Dispatcher +// to simulate abstract method as in other OO languages. +type Extension interface { // OnTick is used to handle the ticker event, if business impl need to do some periodical work, you can // do it here, but don't do too much work here, because the ticker interval is small, and it will block // the event is generated every checkTaskRunningInterval, and only when the task NOT FINISHED and NO ERROR. @@ -47,32 +50,38 @@ type Dispatcher interface { IsRetryableErr(err error) bool } -var taskDispatcherMap struct { - syncutil.RWMutex - dispatcherMap map[string]Dispatcher -} +// FactoryFn is used to create a dispatcher. +type FactoryFn func(ctx context.Context, taskMgr *storage.TaskManager, serverID string, task *proto.Task) Dispatcher -// RegisterTaskDispatcher is used to register the task Dispatcher. -func RegisterTaskDispatcher(taskType string, dispatcherHandle Dispatcher) { - taskDispatcherMap.Lock() - taskDispatcherMap.dispatcherMap[taskType] = dispatcherHandle - taskDispatcherMap.Unlock() +var dispatcherFactoryMap = struct { + syncutil.RWMutex + m map[string]FactoryFn +}{ + m: make(map[string]FactoryFn), } -// ClearTaskDispatcher is only used in test. -func ClearTaskDispatcher() { - taskDispatcherMap.Lock() - maps.Clear(taskDispatcherMap.dispatcherMap) - taskDispatcherMap.Unlock() +// RegisterDispatcherFactory is used to register the dispatcher factory. +// normally dispatcher ctor should be registered before the server start. +// and should be called in a single routine, such as in init(). +// after the server start, there's should be no write to the map. +// but for index backfill, the register call stack is so deep, not sure +// if it's safe to do so, so we use a lock here. +func RegisterDispatcherFactory(taskType string, ctor FactoryFn) { + dispatcherFactoryMap.Lock() + defer dispatcherFactoryMap.Unlock() + dispatcherFactoryMap.m[taskType] = ctor } -// GetTaskDispatcher is used to get the task Dispatcher. 
-func GetTaskDispatcher(taskType string) Dispatcher { - taskDispatcherMap.Lock() - defer taskDispatcherMap.Unlock() - return taskDispatcherMap.dispatcherMap[taskType] +// GetDispatcherFactory is used to get the dispatcher factory. +func GetDispatcherFactory(taskType string) FactoryFn { + dispatcherFactoryMap.RLock() + defer dispatcherFactoryMap.RUnlock() + return dispatcherFactoryMap.m[taskType] } -func init() { - taskDispatcherMap.dispatcherMap = make(map[string]Dispatcher) +// ClearDispatcherFactory is only used in test +func ClearDispatcherFactory() { + dispatcherFactoryMap.Lock() + defer dispatcherFactoryMap.Unlock() + dispatcherFactoryMap.m = make(map[string]FactoryFn) } diff --git a/disttask/framework/framework_err_handling_test.go b/disttask/framework/framework_err_handling_test.go index 2661a1b56ac12..8738c2969cff4 100644 --- a/disttask/framework/framework_err_handling_test.go +++ b/disttask/framework/framework_err_handling_test.go @@ -27,19 +27,19 @@ import ( "github.com/pingcap/tidb/testkit" ) -type planErrDispatcher struct { +type planErrDispatcherExt struct { callTime int } var ( - _ dispatcher.Dispatcher = (*planErrDispatcher)(nil) - _ dispatcher.Dispatcher = (*planNotRetryableErrDispatcher)(nil) + _ dispatcher.Extension = (*planErrDispatcherExt)(nil) + _ dispatcher.Extension = (*planNotRetryableErrDispatcherExt)(nil) ) -func (*planErrDispatcher) OnTick(_ context.Context, _ *proto.Task) { +func (*planErrDispatcherExt) OnTick(_ context.Context, _ *proto.Task) { } -func (p *planErrDispatcher) OnNextStage(_ context.Context, _ dispatcher.TaskHandle, gTask *proto.Task) (metas [][]byte, err error) { +func (p *planErrDispatcherExt) OnNextStage(_ context.Context, _ dispatcher.TaskHandle, gTask *proto.Task) (metas [][]byte, err error) { if gTask.State == proto.TaskStatePending { if p.callTime == 0 { p.callTime++ @@ -61,7 +61,7 @@ func (p *planErrDispatcher) OnNextStage(_ context.Context, _ dispatcher.TaskHand return nil, nil } -func (p *planErrDispatcher) OnErrStage(_ context.Context, _ dispatcher.TaskHandle, _ *proto.Task, _ []error) (meta []byte, err error) { +func (p *planErrDispatcherExt) OnErrStage(_ context.Context, _ dispatcher.TaskHandle, _ *proto.Task, _ []error) (meta []byte, err error) { if p.callTime == 1 { p.callTime++ return nil, errors.New("not retryable err") @@ -69,64 +69,64 @@ func (p *planErrDispatcher) OnErrStage(_ context.Context, _ dispatcher.TaskHandl return []byte("planErrTask"), nil } -func (*planErrDispatcher) GetEligibleInstances(_ context.Context, _ *proto.Task) ([]*infosync.ServerInfo, error) { +func (*planErrDispatcherExt) GetEligibleInstances(_ context.Context, _ *proto.Task) ([]*infosync.ServerInfo, error) { return generateSchedulerNodes4Test() } -func (*planErrDispatcher) IsRetryableErr(error) bool { +func (*planErrDispatcherExt) IsRetryableErr(error) bool { return true } -type planNotRetryableErrDispatcher struct { +type planNotRetryableErrDispatcherExt struct { } -func (*planNotRetryableErrDispatcher) OnTick(_ context.Context, _ *proto.Task) { +func (*planNotRetryableErrDispatcherExt) OnTick(_ context.Context, _ *proto.Task) { } -func (p *planNotRetryableErrDispatcher) OnNextStage(_ context.Context, _ dispatcher.TaskHandle, gTask *proto.Task) (metas [][]byte, err error) { +func (p *planNotRetryableErrDispatcherExt) OnNextStage(_ context.Context, _ dispatcher.TaskHandle, gTask *proto.Task) (metas [][]byte, err error) { return nil, errors.New("not retryable err") } -func (*planNotRetryableErrDispatcher) OnErrStage(_ context.Context, _ 
dispatcher.TaskHandle, _ *proto.Task, _ []error) (meta []byte, err error) { +func (*planNotRetryableErrDispatcherExt) OnErrStage(_ context.Context, _ dispatcher.TaskHandle, _ *proto.Task, _ []error) (meta []byte, err error) { return nil, errors.New("not retryable err") } -func (*planNotRetryableErrDispatcher) GetEligibleInstances(_ context.Context, _ *proto.Task) ([]*infosync.ServerInfo, error) { +func (*planNotRetryableErrDispatcherExt) GetEligibleInstances(_ context.Context, _ *proto.Task) ([]*infosync.ServerInfo, error) { return generateSchedulerNodes4Test() } -func (*planNotRetryableErrDispatcher) IsRetryableErr(error) bool { +func (*planNotRetryableErrDispatcherExt) IsRetryableErr(error) bool { return false } func TestPlanErr(t *testing.T) { - defer dispatcher.ClearTaskDispatcher() + defer dispatcher.ClearDispatcherFactory() defer scheduler.ClearSchedulers() m := sync.Map{} - RegisterTaskMeta(&m, &planErrDispatcher{0}) + RegisterTaskMeta(&m, &planErrDispatcherExt{0}) distContext := testkit.NewDistExecutionContext(t, 2) DispatchTaskAndCheckSuccess("key1", t, &m) distContext.Close() } func TestRevertPlanErr(t *testing.T) { - defer dispatcher.ClearTaskDispatcher() + defer dispatcher.ClearDispatcherFactory() defer scheduler.ClearSchedulers() m := sync.Map{} - RegisterTaskMeta(&m, &planErrDispatcher{0}) + RegisterTaskMeta(&m, &planErrDispatcherExt{0}) distContext := testkit.NewDistExecutionContext(t, 2) DispatchTaskAndCheckSuccess("key1", t, &m) distContext.Close() } func TestPlanNotRetryableErr(t *testing.T) { - defer dispatcher.ClearTaskDispatcher() + defer dispatcher.ClearDispatcherFactory() defer scheduler.ClearSchedulers() m := sync.Map{} - RegisterTaskMeta(&m, &planNotRetryableErrDispatcher{}) + RegisterTaskMeta(&m, &planNotRetryableErrDispatcherExt{}) distContext := testkit.NewDistExecutionContext(t, 2) DispatchTaskAndCheckState("key1", t, &m, proto.TaskStateFailed) distContext.Close() diff --git a/disttask/framework/framework_ha_test.go b/disttask/framework/framework_ha_test.go index fb655613203f8..d6bc28f2f45c3 100644 --- a/disttask/framework/framework_ha_test.go +++ b/disttask/framework/framework_ha_test.go @@ -23,6 +23,7 @@ import ( "github.com/pingcap/tidb/disttask/framework/dispatcher" "github.com/pingcap/tidb/disttask/framework/proto" "github.com/pingcap/tidb/disttask/framework/scheduler" + "github.com/pingcap/tidb/disttask/framework/storage" "github.com/pingcap/tidb/domain/infosync" "github.com/pingcap/tidb/testkit" "github.com/stretchr/testify/require" @@ -30,7 +31,7 @@ import ( type haTestFlowHandle struct{} -var _ dispatcher.Dispatcher = (*haTestFlowHandle)(nil) +var _ dispatcher.Extension = (*haTestFlowHandle)(nil) func (*haTestFlowHandle) OnTick(_ context.Context, _ *proto.Task) { } @@ -77,8 +78,13 @@ func (*haTestFlowHandle) IsRetryableErr(error) bool { } func RegisterHATaskMeta(m *sync.Map) { - dispatcher.ClearTaskDispatcher() - dispatcher.RegisterTaskDispatcher(proto.TaskTypeExample, &haTestFlowHandle{}) + dispatcher.ClearDispatcherFactory() + dispatcher.RegisterDispatcherFactory(proto.TaskTypeExample, + func(ctx context.Context, taskMgr *storage.TaskManager, serverID string, task *proto.Task) dispatcher.Dispatcher { + baseDispatcher := dispatcher.NewBaseDispatcher(ctx, taskMgr, serverID, task) + baseDispatcher.Extension = &haTestFlowHandle{} + return baseDispatcher + }) scheduler.ClearSchedulers() scheduler.RegisterTaskType(proto.TaskTypeExample) scheduler.RegisterSchedulerConstructor(proto.TaskTypeExample, proto.StepOne, func(_ context.Context, _ int64, _ []byte, _ 
int64) (scheduler.Scheduler, error) { @@ -96,7 +102,7 @@ func RegisterHATaskMeta(m *sync.Map) { } func TestHABasic(t *testing.T) { - defer dispatcher.ClearTaskDispatcher() + defer dispatcher.ClearDispatcherFactory() defer scheduler.ClearSchedulers() var m sync.Map RegisterHATaskMeta(&m) @@ -112,7 +118,7 @@ func TestHABasic(t *testing.T) { } func TestHAManyNodes(t *testing.T) { - defer dispatcher.ClearTaskDispatcher() + defer dispatcher.ClearDispatcherFactory() defer scheduler.ClearSchedulers() var m sync.Map @@ -129,7 +135,7 @@ func TestHAManyNodes(t *testing.T) { } func TestHAFailInDifferentStage(t *testing.T) { - defer dispatcher.ClearTaskDispatcher() + defer dispatcher.ClearDispatcherFactory() defer scheduler.ClearSchedulers() var m sync.Map @@ -151,7 +157,7 @@ func TestHAFailInDifferentStage(t *testing.T) { } func TestHAFailInDifferentStageManyNodes(t *testing.T) { - defer dispatcher.ClearTaskDispatcher() + defer dispatcher.ClearDispatcherFactory() defer scheduler.ClearSchedulers() var m sync.Map @@ -173,7 +179,7 @@ func TestHAFailInDifferentStageManyNodes(t *testing.T) { } func TestHAReplacedButRunning(t *testing.T) { - defer dispatcher.ClearTaskDispatcher() + defer dispatcher.ClearDispatcherFactory() defer scheduler.ClearSchedulers() var m sync.Map @@ -186,7 +192,7 @@ func TestHAReplacedButRunning(t *testing.T) { } func TestHAReplacedButRunningManyNodes(t *testing.T) { - defer dispatcher.ClearTaskDispatcher() + defer dispatcher.ClearDispatcherFactory() defer scheduler.ClearSchedulers() var m sync.Map diff --git a/disttask/framework/framework_rollback_test.go b/disttask/framework/framework_rollback_test.go index 1537b377454cb..d7ad0fa5e45bf 100644 --- a/disttask/framework/framework_rollback_test.go +++ b/disttask/framework/framework_rollback_test.go @@ -24,20 +24,21 @@ import ( "github.com/pingcap/tidb/disttask/framework/dispatcher" "github.com/pingcap/tidb/disttask/framework/proto" "github.com/pingcap/tidb/disttask/framework/scheduler" + "github.com/pingcap/tidb/disttask/framework/storage" "github.com/pingcap/tidb/domain/infosync" "github.com/pingcap/tidb/testkit" "github.com/stretchr/testify/require" ) -type rollbackDispatcher struct{} +type rollbackDispatcherExt struct{} -var _ dispatcher.Dispatcher = (*rollbackDispatcher)(nil) +var _ dispatcher.Extension = (*rollbackDispatcherExt)(nil) var rollbackCnt atomic.Int32 -func (*rollbackDispatcher) OnTick(_ context.Context, _ *proto.Task) { +func (*rollbackDispatcherExt) OnTick(_ context.Context, _ *proto.Task) { } -func (*rollbackDispatcher) OnNextStage(_ context.Context, _ dispatcher.TaskHandle, gTask *proto.Task) (metas [][]byte, err error) { +func (*rollbackDispatcherExt) OnNextStage(_ context.Context, _ dispatcher.TaskHandle, gTask *proto.Task) (metas [][]byte, err error) { if gTask.State == proto.TaskStatePending { gTask.Step = proto.StepOne return [][]byte{ @@ -49,15 +50,15 @@ func (*rollbackDispatcher) OnNextStage(_ context.Context, _ dispatcher.TaskHandl return nil, nil } -func (*rollbackDispatcher) OnErrStage(_ context.Context, _ dispatcher.TaskHandle, _ *proto.Task, _ []error) (meta []byte, err error) { +func (*rollbackDispatcherExt) OnErrStage(_ context.Context, _ dispatcher.TaskHandle, _ *proto.Task, _ []error) (meta []byte, err error) { return []byte("rollbacktask1"), nil } -func (*rollbackDispatcher) GetEligibleInstances(_ context.Context, _ *proto.Task) ([]*infosync.ServerInfo, error) { +func (*rollbackDispatcherExt) GetEligibleInstances(_ context.Context, _ *proto.Task) ([]*infosync.ServerInfo, error) { return 
generateSchedulerNodes4Test() } -func (*rollbackDispatcher) IsRetryableErr(error) bool { +func (*rollbackDispatcherExt) IsRetryableErr(error) bool { return true } @@ -105,8 +106,13 @@ func (e *rollbackSubtaskExecutor) Run(_ context.Context) error { } func RegisterRollbackTaskMeta(m *sync.Map) { - dispatcher.ClearTaskDispatcher() - dispatcher.RegisterTaskDispatcher(proto.TaskTypeExample, &rollbackDispatcher{}) + dispatcher.ClearDispatcherFactory() + dispatcher.RegisterDispatcherFactory(proto.TaskTypeExample, + func(ctx context.Context, taskMgr *storage.TaskManager, serverID string, task *proto.Task) dispatcher.Dispatcher { + baseDispatcher := dispatcher.NewBaseDispatcher(ctx, taskMgr, serverID, task) + baseDispatcher.Extension = &rollbackDispatcherExt{} + return baseDispatcher + }) scheduler.ClearSchedulers() scheduler.RegisterTaskType(proto.TaskTypeExample) scheduler.RegisterSchedulerConstructor(proto.TaskTypeExample, proto.StepOne, func(_ context.Context, _ int64, _ []byte, _ int64) (scheduler.Scheduler, error) { @@ -119,7 +125,7 @@ func RegisterRollbackTaskMeta(m *sync.Map) { } func TestFrameworkRollback(t *testing.T) { - defer dispatcher.ClearTaskDispatcher() + defer dispatcher.ClearDispatcherFactory() defer scheduler.ClearSchedulers() m := sync.Map{} diff --git a/disttask/framework/framework_test.go b/disttask/framework/framework_test.go index ab53fd5beb4a4..e3fa7646bbac9 100644 --- a/disttask/framework/framework_test.go +++ b/disttask/framework/framework_test.go @@ -31,14 +31,14 @@ import ( "github.com/stretchr/testify/require" ) -type testDispatcher struct{} +type testDispatcherExt struct{} -var _ dispatcher.Dispatcher = (*testDispatcher)(nil) +var _ dispatcher.Extension = (*testDispatcherExt)(nil) -func (*testDispatcher) OnTick(_ context.Context, _ *proto.Task) { +func (*testDispatcherExt) OnTick(_ context.Context, _ *proto.Task) { } -func (*testDispatcher) OnNextStage(_ context.Context, _ dispatcher.TaskHandle, gTask *proto.Task) (metas [][]byte, err error) { +func (*testDispatcherExt) OnNextStage(_ context.Context, _ dispatcher.TaskHandle, gTask *proto.Task) (metas [][]byte, err error) { if gTask.State == proto.TaskStatePending { gTask.Step = proto.StepOne return [][]byte{ @@ -56,7 +56,7 @@ func (*testDispatcher) OnNextStage(_ context.Context, _ dispatcher.TaskHandle, g return nil, nil } -func (*testDispatcher) OnErrStage(_ context.Context, _ dispatcher.TaskHandle, _ *proto.Task, _ []error) (meta []byte, err error) { +func (*testDispatcherExt) OnErrStage(_ context.Context, _ dispatcher.TaskHandle, _ *proto.Task, _ []error) (meta []byte, err error) { return nil, nil } @@ -73,11 +73,11 @@ func generateSchedulerNodes4Test() ([]*infosync.ServerInfo, error) { return serverNodes, nil } -func (*testDispatcher) GetEligibleInstances(_ context.Context, _ *proto.Task) ([]*infosync.ServerInfo, error) { +func (*testDispatcherExt) GetEligibleInstances(_ context.Context, _ *proto.Task) ([]*infosync.ServerInfo, error) { return generateSchedulerNodes4Test() } -func (*testDispatcher) IsRetryableErr(error) bool { +func (*testDispatcherExt) IsRetryableErr(error) bool { return true } @@ -128,9 +128,14 @@ func (e *testSubtaskExecutor1) Run(_ context.Context) error { return nil } -func RegisterTaskMeta(m *sync.Map, dispatcherHandle dispatcher.Dispatcher) { - dispatcher.ClearTaskDispatcher() - dispatcher.RegisterTaskDispatcher(proto.TaskTypeExample, dispatcherHandle) +func RegisterTaskMeta(m *sync.Map, dispatcherHandle dispatcher.Extension) { + dispatcher.ClearDispatcherFactory() + 
dispatcher.RegisterDispatcherFactory(proto.TaskTypeExample, + func(ctx context.Context, taskMgr *storage.TaskManager, serverID string, task *proto.Task) dispatcher.Dispatcher { + baseDispatcher := dispatcher.NewBaseDispatcher(ctx, taskMgr, serverID, task) + baseDispatcher.Extension = dispatcherHandle + return baseDispatcher + }) scheduler.ClearSchedulers() scheduler.RegisterTaskType(proto.TaskTypeExample) scheduler.RegisterSchedulerConstructor(proto.TaskTypeExample, proto.StepOne, func(_ context.Context, _ int64, _ []byte, _ int64) (scheduler.Scheduler, error) { @@ -207,10 +212,10 @@ func DispatchTaskAndCheckState(taskKey string, t *testing.T, m *sync.Map, state } func TestFrameworkBasic(t *testing.T) { - defer dispatcher.ClearTaskDispatcher() + defer dispatcher.ClearDispatcherFactory() defer scheduler.ClearSchedulers() var m sync.Map - RegisterTaskMeta(&m, &testDispatcher{}) + RegisterTaskMeta(&m, &testDispatcherExt{}) distContext := testkit.NewDistExecutionContext(t, 2) DispatchTaskAndCheckSuccess("key1", t, &m) DispatchTaskAndCheckSuccess("key2", t, &m) @@ -225,10 +230,10 @@ func TestFrameworkBasic(t *testing.T) { } func TestFramework3Server(t *testing.T) { - defer dispatcher.ClearTaskDispatcher() + defer dispatcher.ClearDispatcherFactory() defer scheduler.ClearSchedulers() var m sync.Map - RegisterTaskMeta(&m, &testDispatcher{}) + RegisterTaskMeta(&m, &testDispatcherExt{}) distContext := testkit.NewDistExecutionContext(t, 3) DispatchTaskAndCheckSuccess("key1", t, &m) DispatchTaskAndCheckSuccess("key2", t, &m) @@ -240,10 +245,10 @@ func TestFramework3Server(t *testing.T) { } func TestFrameworkAddDomain(t *testing.T) { - defer dispatcher.ClearTaskDispatcher() + defer dispatcher.ClearDispatcherFactory() defer scheduler.ClearSchedulers() var m sync.Map - RegisterTaskMeta(&m, &testDispatcher{}) + RegisterTaskMeta(&m, &testDispatcherExt{}) distContext := testkit.NewDistExecutionContext(t, 2) DispatchTaskAndCheckSuccess("key1", t, &m) distContext.AddDomain() @@ -257,10 +262,10 @@ func TestFrameworkAddDomain(t *testing.T) { } func TestFrameworkDeleteDomain(t *testing.T) { - defer dispatcher.ClearTaskDispatcher() + defer dispatcher.ClearDispatcherFactory() defer scheduler.ClearSchedulers() var m sync.Map - RegisterTaskMeta(&m, &testDispatcher{}) + RegisterTaskMeta(&m, &testDispatcherExt{}) distContext := testkit.NewDistExecutionContext(t, 2) DispatchTaskAndCheckSuccess("key1", t, &m) distContext.DeleteDomain(1) @@ -270,10 +275,10 @@ func TestFrameworkDeleteDomain(t *testing.T) { } func TestFrameworkWithQuery(t *testing.T) { - defer dispatcher.ClearTaskDispatcher() + defer dispatcher.ClearDispatcherFactory() defer scheduler.ClearSchedulers() var m sync.Map - RegisterTaskMeta(&m, &testDispatcher{}) + RegisterTaskMeta(&m, &testDispatcherExt{}) distContext := testkit.NewDistExecutionContext(t, 2) DispatchTaskAndCheckSuccess("key1", t, &m) @@ -292,21 +297,21 @@ func TestFrameworkWithQuery(t *testing.T) { } func TestFrameworkCancelGTask(t *testing.T) { - defer dispatcher.ClearTaskDispatcher() + defer dispatcher.ClearDispatcherFactory() defer scheduler.ClearSchedulers() var m sync.Map - RegisterTaskMeta(&m, &testDispatcher{}) + RegisterTaskMeta(&m, &testDispatcherExt{}) distContext := testkit.NewDistExecutionContext(t, 2) DispatchAndCancelTask("key1", t, &m) distContext.Close() } func TestFrameworkSubTaskFailed(t *testing.T) { - defer dispatcher.ClearTaskDispatcher() + defer dispatcher.ClearDispatcherFactory() defer scheduler.ClearSchedulers() var m sync.Map - RegisterTaskMeta(&m, &testDispatcher{}) + 
RegisterTaskMeta(&m, &testDispatcherExt{}) distContext := testkit.NewDistExecutionContext(t, 1) require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/framework/scheduler/MockExecutorRunErr", "1*return(true)")) defer func() { @@ -317,11 +322,11 @@ func TestFrameworkSubTaskFailed(t *testing.T) { } func TestFrameworkSubTaskInitEnvFailed(t *testing.T) { - defer dispatcher.ClearTaskDispatcher() + defer dispatcher.ClearDispatcherFactory() defer scheduler.ClearSchedulers() var m sync.Map - RegisterTaskMeta(&m, &testDispatcher{}) + RegisterTaskMeta(&m, &testDispatcherExt{}) distContext := testkit.NewDistExecutionContext(t, 1) require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/framework/scheduler/mockExecSubtaskInitEnvErr", "return()")) defer func() { @@ -332,10 +337,10 @@ func TestFrameworkSubTaskInitEnvFailed(t *testing.T) { } func TestOwnerChange(t *testing.T) { - defer dispatcher.ClearTaskDispatcher() + defer dispatcher.ClearDispatcherFactory() defer scheduler.ClearSchedulers() var m sync.Map - RegisterTaskMeta(&m, &testDispatcher{}) + RegisterTaskMeta(&m, &testDispatcherExt{}) distContext := testkit.NewDistExecutionContext(t, 3) dispatcher.MockOwnerChange = func() { @@ -348,10 +353,10 @@ func TestOwnerChange(t *testing.T) { } func TestFrameworkCancelThenSubmitSubTask(t *testing.T) { - defer dispatcher.ClearTaskDispatcher() + defer dispatcher.ClearDispatcherFactory() defer scheduler.ClearSchedulers() var m sync.Map - RegisterTaskMeta(&m, &testDispatcher{}) + RegisterTaskMeta(&m, &testDispatcherExt{}) distContext := testkit.NewDistExecutionContext(t, 3) require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/framework/dispatcher/cancelBeforeUpdate", "return()")) DispatchTaskAndCheckState("😊", t, &m, proto.TaskStateReverted) @@ -360,10 +365,10 @@ func TestFrameworkCancelThenSubmitSubTask(t *testing.T) { } func TestSchedulerDownBasic(t *testing.T) { - defer dispatcher.ClearTaskDispatcher() + defer dispatcher.ClearDispatcherFactory() defer scheduler.ClearSchedulers() var m sync.Map - RegisterTaskMeta(&m, &testDispatcher{}) + RegisterTaskMeta(&m, &testDispatcherExt{}) distContext := testkit.NewDistExecutionContext(t, 4) require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/framework/scheduler/mockCleanScheduler", "return()")) @@ -378,10 +383,10 @@ func TestSchedulerDownBasic(t *testing.T) { } func TestSchedulerDownManyNodes(t *testing.T) { - defer dispatcher.ClearTaskDispatcher() + defer dispatcher.ClearDispatcherFactory() defer scheduler.ClearSchedulers() var m sync.Map - RegisterTaskMeta(&m, &testDispatcher{}) + RegisterTaskMeta(&m, &testDispatcherExt{}) distContext := testkit.NewDistExecutionContext(t, 30) require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/framework/scheduler/mockCleanScheduler", "return()")) diff --git a/disttask/importinto/dispatcher.go b/disttask/importinto/dispatcher.go index 50d963adab71c..ddffaebe8e2f0 100644 --- a/disttask/importinto/dispatcher.go +++ b/disttask/importinto/dispatcher.go @@ -117,7 +117,7 @@ func (t *taskInfo) close(ctx context.Context) { } } -type importDispatcher struct { +type importDispatcherExt struct { mu sync.RWMutex // NOTE: there's no need to sync for below 2 fields actually, since we add a restriction that only one // task can be running at a time. but we might support task queuing in the future, leave it for now. 
@@ -133,9 +133,9 @@ type importDispatcher struct { disableTiKVImportMode atomic.Bool } -var _ dispatcher.Dispatcher = (*importDispatcher)(nil) +var _ dispatcher.Extension = (*importDispatcherExt)(nil) -func (dsp *importDispatcher) OnTick(ctx context.Context, task *proto.Task) { +func (dsp *importDispatcherExt) OnTick(ctx context.Context, task *proto.Task) { // only switch TiKV mode or register task when task is running if task.State != proto.TaskStateRunning { return @@ -144,7 +144,7 @@ func (dsp *importDispatcher) OnTick(ctx context.Context, task *proto.Task) { dsp.registerTask(ctx, task) } -func (dsp *importDispatcher) switchTiKVMode(ctx context.Context, task *proto.Task) { +func (dsp *importDispatcherExt) switchTiKVMode(ctx context.Context, task *proto.Task) { dsp.updateCurrentTask(task) // only import step need to switch to IMPORT mode, // If TiKV is in IMPORT mode during checksum, coprocessor will time out. @@ -173,20 +173,20 @@ func (dsp *importDispatcher) switchTiKVMode(ctx context.Context, task *proto.Tas dsp.lastSwitchTime.Store(time.Now()) } -func (dsp *importDispatcher) registerTask(ctx context.Context, task *proto.Task) { +func (dsp *importDispatcherExt) registerTask(ctx context.Context, task *proto.Task) { val, _ := dsp.taskInfoMap.LoadOrStore(task.ID, &taskInfo{taskID: task.ID}) info := val.(*taskInfo) info.register(ctx) } -func (dsp *importDispatcher) unregisterTask(ctx context.Context, task *proto.Task) { +func (dsp *importDispatcherExt) unregisterTask(ctx context.Context, task *proto.Task) { if val, loaded := dsp.taskInfoMap.LoadAndDelete(task.ID); loaded { info := val.(*taskInfo) info.close(ctx) } } -func (dsp *importDispatcher) OnNextStage(ctx context.Context, handle dispatcher.TaskHandle, gTask *proto.Task) ( +func (dsp *importDispatcherExt) OnNextStage(ctx context.Context, handle dispatcher.TaskHandle, gTask *proto.Task) ( resSubtaskMeta [][]byte, err error) { logger := logutil.BgLogger().With( zap.String("type", gTask.Type), @@ -272,7 +272,7 @@ func (dsp *importDispatcher) OnNextStage(ctx context.Context, handle dispatcher. return metaBytes, nil } -func (dsp *importDispatcher) OnErrStage(ctx context.Context, handle dispatcher.TaskHandle, gTask *proto.Task, receiveErrs []error) ([]byte, error) { +func (dsp *importDispatcherExt) OnErrStage(ctx context.Context, handle dispatcher.TaskHandle, gTask *proto.Task, receiveErrs []error) ([]byte, error) { logger := logutil.BgLogger().With( zap.String("type", gTask.Type), zap.Int64("task-id", gTask.ID), @@ -310,7 +310,7 @@ func (dsp *importDispatcher) OnErrStage(ctx context.Context, handle dispatcher.T return nil, err } -func (*importDispatcher) GetEligibleInstances(ctx context.Context, gTask *proto.Task) ([]*infosync.ServerInfo, error) { +func (*importDispatcherExt) GetEligibleInstances(ctx context.Context, gTask *proto.Task) ([]*infosync.ServerInfo, error) { taskMeta := &TaskMeta{} err := json.Unmarshal(gTask.Meta, taskMeta) if err != nil { @@ -322,12 +322,12 @@ func (*importDispatcher) GetEligibleInstances(ctx context.Context, gTask *proto. return dispatcher.GenerateSchedulerNodes(ctx) } -func (*importDispatcher) IsRetryableErr(error) bool { +func (*importDispatcherExt) IsRetryableErr(error) bool { // TODO: check whether the error is retryable. 
return false } -func (dsp *importDispatcher) switchTiKV2NormalMode(ctx context.Context, task *proto.Task, logger *zap.Logger) { +func (dsp *importDispatcherExt) switchTiKV2NormalMode(ctx context.Context, task *proto.Task, logger *zap.Logger) { dsp.updateCurrentTask(task) if dsp.disableTiKVImportMode.Load() { return @@ -348,7 +348,7 @@ func (dsp *importDispatcher) switchTiKV2NormalMode(ctx context.Context, task *pr dsp.lastSwitchTime.Store(time.Time{}) } -func (dsp *importDispatcher) updateCurrentTask(task *proto.Task) { +func (dsp *importDispatcherExt) updateCurrentTask(task *proto.Task) { if dsp.currTaskID.Swap(task.ID) != task.ID { taskMeta := &TaskMeta{} if err := json.Unmarshal(task.Meta, taskMeta); err == nil { @@ -358,6 +358,19 @@ func (dsp *importDispatcher) updateCurrentTask(task *proto.Task) { } } +type importDispatcher struct { + *dispatcher.BaseDispatcher +} + +func newImportDispatcher(ctx context.Context, taskMgr *storage.TaskManager, + serverID string, task *proto.Task) dispatcher.Dispatcher { + dis := importDispatcher{ + BaseDispatcher: dispatcher.NewBaseDispatcher(ctx, taskMgr, serverID, task), + } + dis.BaseDispatcher.Extension = &importDispatcherExt{} + return &dis +} + // preProcess does the pre-processing for the task. func preProcess(_ context.Context, _ dispatcher.TaskHandle, gTask *proto.Task, taskMeta *TaskMeta, logger *zap.Logger) error { logger.Info("pre process") @@ -512,7 +525,7 @@ func job2Step(ctx context.Context, taskMeta *TaskMeta, step string) error { }) } -func (dsp *importDispatcher) finishJob(ctx context.Context, handle dispatcher.TaskHandle, gTask *proto.Task, taskMeta *TaskMeta) error { +func (dsp *importDispatcherExt) finishJob(ctx context.Context, handle dispatcher.TaskHandle, gTask *proto.Task, taskMeta *TaskMeta) error { dsp.unregisterTask(ctx, gTask) redactSensitiveInfo(gTask, taskMeta) summary := &importer.JobSummary{ImportedRows: taskMeta.Result.LoadedRowCnt} @@ -522,7 +535,7 @@ func (dsp *importDispatcher) finishJob(ctx context.Context, handle dispatcher.Ta }) } -func (dsp *importDispatcher) failJob(ctx context.Context, handle dispatcher.TaskHandle, gTask *proto.Task, +func (dsp *importDispatcherExt) failJob(ctx context.Context, handle dispatcher.TaskHandle, gTask *proto.Task, taskMeta *TaskMeta, logger *zap.Logger, errorMsg string) error { dsp.switchTiKV2NormalMode(ctx, gTask, logger) dsp.unregisterTask(ctx, gTask) @@ -583,5 +596,5 @@ func stepStr(step int64) string { } func init() { - dispatcher.RegisterTaskDispatcher(proto.ImportInto, &importDispatcher{}) + dispatcher.RegisterDispatcherFactory(proto.ImportInto, newImportDispatcher) } diff --git a/disttask/importinto/dispatcher_test.go b/disttask/importinto/dispatcher_test.go index 721025926d3dc..3a2fd666f4d92 100644 --- a/disttask/importinto/dispatcher_test.go +++ b/disttask/importinto/dispatcher_test.go @@ -60,7 +60,7 @@ func (s *importIntoSuite) TestDispatcherGetEligibleInstances() { } mockedAllServerInfos := makeFailpointRes(serverInfoMap) - dsp := importDispatcher{} + dsp := importDispatcherExt{} gTask := &proto.Task{Meta: []byte("{}")} ctx := context.WithValue(context.Background(), "etcd", true) s.enableFailPoint("github.com/pingcap/tidb/domain/infosync/mockGetAllServerInfo", mockedAllServerInfos) @@ -88,7 +88,7 @@ func (s *importIntoSuite) TestUpdateCurrentTask() { bs, err := json.Marshal(taskMeta) require.NoError(s.T(), err) - dsp := importDispatcher{} + dsp := importDispatcherExt{} require.Equal(s.T(), int64(0), dsp.currTaskID.Load()) require.False(s.T(), 
dsp.disableTiKVImportMode.Load())
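
With these changes the dispatcher manager no longer constructs a concrete dispatcher itself: startDispatcher looks up the registered FactoryFn by task type and calls ExecuteTask on whatever Dispatcher it returns, while dispatchTaskLoop fails tasks of unknown type up front. The sketch below shows the shape of that flow; it is simplified from startDispatcher/dispatchTaskLoop above (pooling, locking, and error handling omitted), and runTask is an illustrative name, not a function in this patch.

```go
// Package example sketches the post-refactor dispatch flow; the real logic
// lives in dispatcher.Manager.
package example

import (
	"context"

	"github.com/pingcap/tidb/disttask/framework/dispatcher"
	"github.com/pingcap/tidb/disttask/framework/proto"
	"github.com/pingcap/tidb/disttask/framework/storage"
)

// runTask mirrors Manager.startDispatcher after this change: resolve the
// factory by task type, build the Dispatcher, and run it.
func runTask(ctx context.Context, taskMgr *storage.TaskManager, serverID string, task *proto.Task) {
	factory := dispatcher.GetDispatcherFactory(task.Type)
	if factory == nil {
		// dispatchTaskLoop already marks tasks of unknown type as failed
		// before startDispatcher is reached, so this branch is defensive.
		return
	}
	d := factory(ctx, taskMgr, serverID, task) // e.g. newLitBackfillDispatcher or newImportDispatcher
	d.ExecuteTask()                            // BaseDispatcher.ExecuteTask drives scheduleTask
}
```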