From a9d7577eccd310f148ae602b278af160b843b468 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E8=B6=85?= Date: Tue, 4 Apr 2023 18:47:20 +0800 Subject: [PATCH] disttask: merge `GlobalTaskManager` and `SubTaskManager` (#42786) --- disttask/framework/dispatcher/dispatcher.go | 36 +++--- .../framework/dispatcher/dispatcher_test.go | 37 +++---- disttask/framework/framework_test.go | 10 +- disttask/framework/scheduler/interface.go | 9 +- .../framework/scheduler/interface_mock.go | 25 ++--- disttask/framework/scheduler/manager.go | 28 +++-- disttask/framework/scheduler/manager_test.go | 60 +++++----- disttask/framework/scheduler/scheduler.go | 34 +++--- .../framework/scheduler/scheduler_test.go | 6 +- disttask/framework/storage/table_test.go | 42 +++---- disttask/framework/storage/task_table.go | 103 ++++++------------ domain/domain.go | 31 ++---- 12 files changed, 180 insertions(+), 241 deletions(-) diff --git a/disttask/framework/dispatcher/dispatcher.go b/disttask/framework/dispatcher/dispatcher.go index 55c61881ee1e5..9e1d0453c4928 100644 --- a/disttask/framework/dispatcher/dispatcher.go +++ b/disttask/framework/dispatcher/dispatcher.go @@ -92,12 +92,11 @@ func (d *dispatcher) delRunningGTask(globalTaskID int64) { } type dispatcher struct { - ctx context.Context - cancel context.CancelFunc - gTaskMgr *storage.GlobalTaskManager - subTaskMgr *storage.SubTaskManager - wg tidbutil.WaitGroupWrapper - gPool *spool.Pool + ctx context.Context + cancel context.CancelFunc + taskMgr *storage.TaskManager + wg tidbutil.WaitGroupWrapper + gPool *spool.Pool runningGTasks struct { syncutil.RWMutex @@ -107,10 +106,9 @@ type dispatcher struct { } // NewDispatcher creates a dispatcher struct. -func NewDispatcher(ctx context.Context, globalTaskTable *storage.GlobalTaskManager, subtaskTable *storage.SubTaskManager) (Dispatch, error) { +func NewDispatcher(ctx context.Context, taskTable *storage.TaskManager) (Dispatch, error) { dispatcher := &dispatcher{ - gTaskMgr: globalTaskTable, - subTaskMgr: subtaskTable, + taskMgr: taskTable, detectPendingGTaskCh: make(chan *proto.Task, DefaultDispatchConcurrency), } pool, err := spool.NewPool("dispatch_pool", int32(DefaultDispatchConcurrency), util.DistTask, spool.WithBlocking(true)) @@ -156,7 +154,7 @@ func (d *dispatcher) DispatchTaskLoop() { } // TODO: Consider getting these tasks, in addition to the task being worked on.. - gTasks, err := d.gTaskMgr.GetTasksInStates(proto.TaskStatePending, proto.TaskStateRunning, proto.TaskStateReverting) + gTasks, err := d.taskMgr.GetGlobalTasksInStates(proto.TaskStatePending, proto.TaskStateRunning, proto.TaskStateReverting) if err != nil { logutil.BgLogger().Warn("get unfinished(pending, running or reverting) tasks failed", zap.Error(err)) break @@ -199,7 +197,7 @@ func (d *dispatcher) probeTask(gTask *proto.Task) (isFinished bool, subTaskErr s // TODO: Consider putting the following operations into a transaction. // TODO: Consider collect some information about the tasks. if gTask.State != proto.TaskStateReverting { - cnt, err := d.subTaskMgr.GetSubtaskInStatesCnt(gTask.ID, proto.TaskStateFailed) + cnt, err := d.taskMgr.GetSubtaskInStatesCnt(gTask.ID, proto.TaskStateFailed) if err != nil { logutil.BgLogger().Warn("check task failed", zap.Int64("task ID", gTask.ID), zap.Error(err)) return false, "" @@ -208,7 +206,7 @@ func (d *dispatcher) probeTask(gTask *proto.Task) (isFinished bool, subTaskErr s return false, proto.TaskStateFailed } - cnt, err = d.subTaskMgr.GetSubtaskInStatesCnt(gTask.ID, proto.TaskStatePending, proto.TaskStateRunning) + cnt, err = d.taskMgr.GetSubtaskInStatesCnt(gTask.ID, proto.TaskStatePending, proto.TaskStateRunning) if err != nil { logutil.BgLogger().Warn("check task failed", zap.Int64("task ID", gTask.ID), zap.Error(err)) return false, "" @@ -220,7 +218,7 @@ func (d *dispatcher) probeTask(gTask *proto.Task) (isFinished bool, subTaskErr s return true, "" } - cnt, err := d.subTaskMgr.GetSubtaskInStatesCnt(gTask.ID, proto.TaskStateRevertPending, proto.TaskStateReverting) + cnt, err := d.taskMgr.GetSubtaskInStatesCnt(gTask.ID, proto.TaskStateRevertPending, proto.TaskStateReverting) if err != nil { logutil.BgLogger().Warn("check task failed", zap.Int64("task ID", gTask.ID), zap.Error(err)) return false, "" @@ -310,7 +308,7 @@ func (d *dispatcher) updateTask(gTask *proto.Task, gTaskState string, retryTimes for i := 0; i < retryTimes; i++ { gTask.State = gTaskState // Write the global task meta into the storage. - err = d.gTaskMgr.UpdateTask(gTask) + err = d.taskMgr.UpdateGlobalTask(gTask) if err == nil { break } @@ -347,7 +345,7 @@ func (d *dispatcher) processErrFlow(gTask *proto.Task, receiveErr string) error return d.updateTask(gTask, proto.TaskStateReverted, retrySQLTimes) } - // TODO: UpdateTask and AddNewTask in a txn. + // TODO: UpdateGlobalTask and AddNewSubTask in a txn. // Write the global task meta into the storage. err = d.updateTask(gTask, proto.TaskStateReverting, retrySQLTimes) if err != nil { @@ -357,7 +355,7 @@ func (d *dispatcher) processErrFlow(gTask *proto.Task, receiveErr string) error // New rollback subtasks and write into the storage. for _, id := range instanceIDs { subtask := proto.NewSubtask(gTask.ID, gTask.Type, id, meta) - err = d.subTaskMgr.AddNewTask(gTask.ID, subtask.SchedulerID, subtask.Meta, subtask.Type, true) + err = d.taskMgr.AddNewSubTask(gTask.ID, subtask.SchedulerID, subtask.Meta, subtask.Type, true) if err != nil { logutil.BgLogger().Warn("add subtask failed", zap.Int64("gTask ID", gTask.ID), zap.Error(err)) return err @@ -410,7 +408,7 @@ func (d *dispatcher) processNormalFlow(gTask *proto.Task) (err error) { return nil } - // TODO: UpdateTask and AddNewTask in a txn. + // TODO: UpdateGlobalTask and AddNewSubTask in a txn. // Write the global task meta into the storage. err = d.updateTask(gTask, gTask.State, retryTimes) if err != nil { @@ -428,7 +426,7 @@ func (d *dispatcher) processNormalFlow(gTask *proto.Task) (err error) { // TODO: Consider batch insert. // TODO: Synchronization interruption problem, e.g. AddNewTask failed. - err = d.subTaskMgr.AddNewTask(gTask.ID, subtask.SchedulerID, subtask.Meta, subtask.Type, false) + err = d.taskMgr.AddNewSubTask(gTask.ID, subtask.SchedulerID, subtask.Meta, subtask.Type, false) if err != nil { logutil.BgLogger().Warn("add subtask failed", zap.Int64("gTask ID", gTask.ID), zap.Error(err)) return err @@ -468,7 +466,7 @@ func (d *dispatcher) GetAllSchedulerIDs(ctx context.Context, gTaskID int64) ([]s return nil, nil } - schedulerIDs, err := d.subTaskMgr.GetSchedulerIDs(gTaskID) + schedulerIDs, err := d.taskMgr.GetSchedulerIDsByTaskID(gTaskID) if err != nil { return nil, err } diff --git a/disttask/framework/dispatcher/dispatcher_test.go b/disttask/framework/dispatcher/dispatcher_test.go index db3c1abca155b..7c37b33a86ead 100644 --- a/disttask/framework/dispatcher/dispatcher_test.go +++ b/disttask/framework/dispatcher/dispatcher_test.go @@ -34,18 +34,15 @@ import ( "github.com/tikv/client-go/v2/util" ) -func MockDispatcher(t *testing.T) (dispatcher.Dispatch, *storage.GlobalTaskManager, *storage.SubTaskManager, kv.Storage) { +func MockDispatcher(t *testing.T) (dispatcher.Dispatch, *storage.TaskManager, kv.Storage) { store := testkit.CreateMockStore(t) gtk := testkit.NewTestKit(t, store) - stk := testkit.NewTestKit(t, store) ctx := context.Background() - gm := storage.NewGlobalTaskManager(util.WithInternalSourceType(ctx, "globalTaskManager"), gtk.Session()) - storage.SetGlobalTaskManager(gm) - sm := storage.NewSubTaskManager(util.WithInternalSourceType(ctx, "subTaskManager"), stk.Session()) - storage.SetSubTaskManager(sm) - dsp, err := dispatcher.NewDispatcher(util.WithInternalSourceType(ctx, "dispatcher"), gm, sm) + mgr := storage.NewTaskManager(util.WithInternalSourceType(ctx, "taskManager"), gtk.Session()) + storage.SetTaskManager(mgr) + dsp, err := dispatcher.NewDispatcher(util.WithInternalSourceType(ctx, "dispatcher"), mgr) require.NoError(t, err) - return dsp, gm, sm, store + return dsp, mgr, store } func deleteTasks(t *testing.T, store kv.Storage, taskID int64) { @@ -55,7 +52,7 @@ func deleteTasks(t *testing.T, store kv.Storage, taskID int64) { func TestGetInstance(t *testing.T) { ctx := context.Background() - dsp, _, subTaskMgr, _ := MockDispatcher(t) + dsp, mgr, _ := MockDispatcher(t) makeFailpointRes := func(v interface{}) string { bytes, err := json.Marshal(v) @@ -103,7 +100,7 @@ func TestGetInstance(t *testing.T) { TaskID: gTaskID, SchedulerID: uuids[1], } - err = subTaskMgr.AddNewTask(gTaskID, subtask.SchedulerID, nil, subtask.Type, true) + err = mgr.AddNewSubTask(gTaskID, subtask.SchedulerID, nil, subtask.Type, true) require.NoError(t, err) instanceIDs, err = dsp.GetAllSchedulerIDs(ctx, gTaskID) require.NoError(t, err) @@ -115,7 +112,7 @@ func TestGetInstance(t *testing.T) { TaskID: gTaskID, SchedulerID: uuids[0], } - err = subTaskMgr.AddNewTask(gTaskID, subtask.SchedulerID, nil, subtask.Type, true) + err = mgr.AddNewSubTask(gTaskID, subtask.SchedulerID, nil, subtask.Type, true) require.NoError(t, err) instanceIDs, err = dsp.GetAllSchedulerIDs(ctx, gTaskID) require.NoError(t, err) @@ -142,7 +139,7 @@ func checkDispatch(t *testing.T, taskCnt int, isSucc bool) { dispatcher.DefaultDispatchConcurrency = 1 } - dsp, gTaskMgr, subTaskMgr, store := MockDispatcher(t) + dsp, mgr, store := MockDispatcher(t) dsp.Start() defer func() { dsp.Stop() @@ -171,24 +168,24 @@ func checkDispatch(t *testing.T, taskCnt int, isSucc bool) { // Mock add tasks. taskIDs := make([]int64, 0, taskCnt) for i := 0; i < taskCnt; i++ { - taskID, err := gTaskMgr.AddNewTask(fmt.Sprintf("%d", i), taskTypeExample, 0, nil) + taskID, err := mgr.AddNewGlobalTask(fmt.Sprintf("%d", i), taskTypeExample, 0, nil) require.NoError(t, err) taskIDs = append(taskIDs, taskID) } // test normal flow checkGetRunningGTaskCnt() - tasks, err := gTaskMgr.GetTasksInStates(proto.TaskStateRunning) + tasks, err := mgr.GetGlobalTasksInStates(proto.TaskStateRunning) require.NoError(t, err) require.Len(t, tasks, taskCnt) for i, taskID := range taskIDs { require.Equal(t, int64(i+1), tasks[i].ID) - subtasks, err := subTaskMgr.GetSubtaskInStatesCnt(taskID, proto.TaskStatePending) + subtasks, err := mgr.GetSubtaskInStatesCnt(taskID, proto.TaskStatePending) require.NoError(t, err) require.Equal(t, int64(subtaskCnt), subtasks, fmt.Sprintf("num:%d", i)) } // test parallelism control if taskCnt == 1 { - taskID, err := gTaskMgr.AddNewTask(fmt.Sprintf("%d", taskCnt), taskTypeExample, 0, nil) + taskID, err := mgr.AddNewGlobalTask(fmt.Sprintf("%d", taskCnt), taskTypeExample, 0, nil) require.NoError(t, err) checkGetRunningGTaskCnt() // Clean the task. @@ -199,7 +196,7 @@ func checkDispatch(t *testing.T, taskCnt int, isSucc bool) { // test DetectTaskLoop checkGetGTaskState := func(expectedState string) { for i := 0; i < cnt; i++ { - tasks, err = gTaskMgr.GetTasksInStates(expectedState) + tasks, err = mgr.GetGlobalTasksInStates(expectedState) require.NoError(t, err) if len(tasks) == taskCnt { break @@ -211,7 +208,7 @@ func checkDispatch(t *testing.T, taskCnt int, isSucc bool) { if isSucc { // Mock subtasks succeed. for i := 1; i <= subtaskCnt*taskCnt; i++ { - err = subTaskMgr.UpdateSubtaskState(int64(i), proto.TaskStateSucceed) + err = mgr.UpdateSubtaskState(int64(i), proto.TaskStateSucceed) require.NoError(t, err) } checkGetGTaskState(proto.TaskStateSucceed) @@ -227,7 +224,7 @@ func checkDispatch(t *testing.T, taskCnt int, isSucc bool) { }() // Mock a subtask fails. for i := 1; i <= subtaskCnt*taskCnt; i += subtaskCnt { - err = subTaskMgr.UpdateSubtaskState(int64(i), proto.TaskStateFailed) + err = mgr.UpdateSubtaskState(int64(i), proto.TaskStateFailed) require.NoError(t, err) } checkGetGTaskState(proto.TaskStateReverting) @@ -235,7 +232,7 @@ func checkDispatch(t *testing.T, taskCnt int, isSucc bool) { // Mock all subtask reverted. start := subtaskCnt * taskCnt for i := start; i <= start+subtaskCnt*taskCnt; i++ { - err = subTaskMgr.UpdateSubtaskState(int64(i), proto.TaskStateReverted) + err = mgr.UpdateSubtaskState(int64(i), proto.TaskStateReverted) require.NoError(t, err) } checkGetGTaskState(proto.TaskStateReverted) diff --git a/disttask/framework/framework_test.go b/disttask/framework/framework_test.go index 96beda32d7464..ec51ed8cece3a 100644 --- a/disttask/framework/framework_test.go +++ b/disttask/framework/framework_test.go @@ -90,10 +90,10 @@ func TestFrameworkStartUp(t *testing.T) { return &testSubtaskExecutor{v: &v}, nil }) - store := testkit.CreateMockStore(t) - tk := testkit.NewTestKit(t, store) - gm := storage.NewGlobalTaskManager(context.TODO(), tk.Session()) - taskID, err := gm.AddNewTask("key1", "type1", 8, nil) + _ = testkit.CreateMockStore(t) + mgr, err := storage.GetTaskManager() + require.NoError(t, err) + taskID, err := mgr.AddNewGlobalTask("key1", "type1", 8, nil) require.NoError(t, err) start := time.Now() @@ -104,7 +104,7 @@ func TestFrameworkStartUp(t *testing.T) { } time.Sleep(time.Second) - task, err = gm.GetTaskByID(taskID) + task, err = mgr.GetGlobalTaskByID(taskID) require.NoError(t, err) require.NotNil(t, task) if task.State != proto.TaskStatePending && task.State != proto.TaskStateRunning { diff --git a/disttask/framework/scheduler/interface.go b/disttask/framework/scheduler/interface.go index 6d1bdcb1df093..81db9823b787e 100644 --- a/disttask/framework/scheduler/interface.go +++ b/disttask/framework/scheduler/interface.go @@ -22,16 +22,11 @@ import ( // TaskTable defines the interface to access task table. type TaskTable interface { - GetTasksInStates(states ...interface{}) (task []*proto.Task, err error) - GetTaskByID(taskID int64) (task *proto.Task, err error) -} - -// SubtaskTable defines the interface to access subtask table. -type SubtaskTable interface { + GetGlobalTasksInStates(states ...interface{}) (task []*proto.Task, err error) + GetGlobalTaskByID(taskID int64) (task *proto.Task, err error) GetSubtaskInStates(instanceID string, taskID int64, states ...interface{}) (*proto.Subtask, error) UpdateSubtaskState(id int64, state string) error HasSubtasksInStates(instanceID string, taskID int64, states ...interface{}) (bool, error) - // UpdateHeartbeat(TiDB string, taskID int64, heartbeat time.Time) error } // Pool defines the interface of a pool. diff --git a/disttask/framework/scheduler/interface_mock.go b/disttask/framework/scheduler/interface_mock.go index b618aebb72f69..07e54687c3ab7 100644 --- a/disttask/framework/scheduler/interface_mock.go +++ b/disttask/framework/scheduler/interface_mock.go @@ -27,8 +27,8 @@ type MockTaskTable struct { mock.Mock } -// GetTasksInStates implements TaskTable.GetTasksInStates. -func (t *MockTaskTable) GetTasksInStates(states ...interface{}) ([]*proto.Task, error) { +// GetGlobalTasksInStates implements TaskTable.GetTasksInStates. +func (t *MockTaskTable) GetGlobalTasksInStates(states ...interface{}) ([]*proto.Task, error) { args := t.Called(states...) if args.Error(1) != nil { return nil, args.Error(1) @@ -39,8 +39,8 @@ func (t *MockTaskTable) GetTasksInStates(states ...interface{}) ([]*proto.Task, } } -// GetTaskByID implements TaskTable.GetTaskByID. -func (t *MockTaskTable) GetTaskByID(id int64) (*proto.Task, error) { +// GetGlobalTaskByID implements TaskTable.GetTaskByID. +func (t *MockTaskTable) GetGlobalTaskByID(id int64) (*proto.Task, error) { args := t.Called(id) if args.Error(1) != nil { return nil, args.Error(1) @@ -51,14 +51,9 @@ func (t *MockTaskTable) GetTaskByID(id int64) (*proto.Task, error) { } } -// MockSubtaskTable is a mock of SubtaskTable. -type MockSubtaskTable struct { - mock.Mock -} - // GetSubtaskInStates implements SubtaskTable.GetSubtaskInStates. -func (m *MockSubtaskTable) GetSubtaskInStates(instanceID string, taskID int64, states ...interface{}) (*proto.Subtask, error) { - args := m.Called(instanceID, taskID, states) +func (t *MockTaskTable) GetSubtaskInStates(instanceID string, taskID int64, states ...interface{}) (*proto.Subtask, error) { + args := t.Called(instanceID, taskID, states) if args.Error(1) != nil { return nil, args.Error(1) } else if args.Get(0) == nil { @@ -69,14 +64,14 @@ func (m *MockSubtaskTable) GetSubtaskInStates(instanceID string, taskID int64, s } // UpdateSubtaskState implements SubtaskTable.UpdateSubtaskState. -func (m *MockSubtaskTable) UpdateSubtaskState(id int64, state string) error { - args := m.Called(id, state) +func (t *MockTaskTable) UpdateSubtaskState(id int64, state string) error { + args := t.Called(id, state) return args.Error(0) } // HasSubtasksInStates implements SubtaskTable.HasSubtasksInStates. -func (m *MockSubtaskTable) HasSubtasksInStates(instanceID string, taskID int64, states ...interface{}) (bool, error) { - args := m.Called(instanceID, taskID, states) +func (t *MockTaskTable) HasSubtasksInStates(instanceID string, taskID int64, states ...interface{}) (bool, error) { + args := t.Called(instanceID, taskID, states) return args.Bool(0), args.Error(1) } diff --git a/disttask/framework/scheduler/manager.go b/disttask/framework/scheduler/manager.go index 7cdb6414937b8..fd5723d2684d7 100644 --- a/disttask/framework/scheduler/manager.go +++ b/disttask/framework/scheduler/manager.go @@ -36,7 +36,7 @@ var ( // ManagerBuilder is used to build a Manager. type ManagerBuilder struct { newPool func(name string, size int32, component util.Component, options ...spool.Option) (Pool, error) - newScheduler func(ctx context.Context, id string, taskID int64, subtaskTable SubtaskTable, pool Pool) InternalScheduler + newScheduler func(ctx context.Context, id string, taskID int64, taskTable TaskTable, pool Pool) InternalScheduler } // NewManagerBuilder creates a new ManagerBuilder. @@ -55,15 +55,14 @@ func (b *ManagerBuilder) setPoolFactory(poolFactory func(name string, size int32 } // setSchedulerFactory sets the schedulerFactory to mock the InternalScheduler in unit test. -func (b *ManagerBuilder) setSchedulerFactory(schedulerFactory func(ctx context.Context, id string, taskID int64, subtaskTable SubtaskTable, pool Pool) InternalScheduler) { +func (b *ManagerBuilder) setSchedulerFactory(schedulerFactory func(ctx context.Context, id string, taskID int64, taskTable TaskTable, pool Pool) InternalScheduler) { b.newScheduler = schedulerFactory } // Manager monitors the global task table and manages the schedulers. type Manager struct { - globalTaskTable TaskTable - subtaskTable SubtaskTable - schedulerPool Pool + taskTable TaskTable + schedulerPool Pool // taskType -> subtaskExecutorPool subtaskExecutorPools map[string]Pool mu struct { @@ -78,15 +77,14 @@ type Manager struct { cancel context.CancelFunc logCtx context.Context newPool func(name string, size int32, component util.Component, options ...spool.Option) (Pool, error) - newScheduler func(ctx context.Context, id string, taskID int64, subtaskTable SubtaskTable, pool Pool) InternalScheduler + newScheduler func(ctx context.Context, id string, taskID int64, taskTable TaskTable, pool Pool) InternalScheduler } // BuildManager builds a Manager. -func (b *ManagerBuilder) BuildManager(ctx context.Context, id string, globalTaskTable TaskTable, subtaskTable SubtaskTable) (*Manager, error) { +func (b *ManagerBuilder) BuildManager(ctx context.Context, id string, taskTable TaskTable) (*Manager, error) { m := &Manager{ id: id, - globalTaskTable: globalTaskTable, - subtaskTable: subtaskTable, + taskTable: taskTable, subtaskExecutorPools: make(map[string]Pool), logCtx: logutil.WithKeyValue(context.Background(), "dist_task_manager", id), newPool: b.newPool, @@ -149,7 +147,7 @@ func (m *Manager) fetchAndHandleRunnableTasks(ctx context.Context) { logutil.Logger(m.logCtx).Info("fetchAndHandleRunnableTasks done") return case <-ticker.C: - tasks, err := m.globalTaskTable.GetTasksInStates(proto.TaskStateRunning, proto.TaskStateReverting) + tasks, err := m.taskTable.GetGlobalTasksInStates(proto.TaskStateRunning, proto.TaskStateReverting) if err != nil { m.onError(err) continue @@ -169,7 +167,7 @@ func (m *Manager) fetchAndFastCancelTasks(ctx context.Context) { logutil.Logger(m.logCtx).Info("fetchAndFastCancelTasks done") return case <-ticker.C: - tasks, err := m.globalTaskTable.GetTasksInStates(proto.TaskStateReverting) + tasks, err := m.taskTable.GetGlobalTasksInStates(proto.TaskStateReverting) if err != nil { m.onError(err) continue @@ -188,7 +186,7 @@ func (m *Manager) onRunnableTasks(ctx context.Context, tasks []*proto.Task) { logutil.Logger(m.logCtx).Error("unknown task type", zap.String("type", task.Type)) continue } - exist, err := m.subtaskTable.HasSubtasksInStates(m.id, task.ID, proto.TaskStatePending, proto.TaskStateRevertPending) + exist, err := m.taskTable.HasSubtasksInStates(m.id, task.ID, proto.TaskStatePending, proto.TaskStateRevertPending) if err != nil { m.onError(err) continue @@ -259,7 +257,7 @@ func (m *Manager) onRunnableTask(ctx context.Context, taskID int64, taskType str return } // runCtx only used in scheduler.Run, cancel in m.fetchAndFastCancelTasks - scheduler := m.newScheduler(ctx, m.id, taskID, m.subtaskTable, m.subtaskExecutorPools[taskType]) + scheduler := m.newScheduler(ctx, m.id, taskID, m.taskTable, m.subtaskExecutorPools[taskType]) scheduler.Start() defer scheduler.Stop() for { @@ -268,7 +266,7 @@ func (m *Manager) onRunnableTask(ctx context.Context, taskID int64, taskType str return case <-time.After(checkTime): } - task, err := m.globalTaskTable.GetTaskByID(taskID) + task, err := m.taskTable.GetGlobalTaskByID(taskID) if err != nil { m.onError(err) return @@ -278,7 +276,7 @@ func (m *Manager) onRunnableTask(ctx context.Context, taskID int64, taskType str return } // TODO: intergrate with heartbeat mechanism - if exist, err := m.subtaskTable.HasSubtasksInStates(m.id, task.ID, proto.TaskStatePending, proto.TaskStateRevertPending); err != nil { + if exist, err := m.taskTable.HasSubtasksInStates(m.id, task.ID, proto.TaskStatePending, proto.TaskStateRevertPending); err != nil { m.onError(err) return } else if !exist { diff --git a/disttask/framework/scheduler/manager_test.go b/disttask/framework/scheduler/manager_test.go index a6f1e1836807a..583b8ae2bbd44 100644 --- a/disttask/framework/scheduler/manager_test.go +++ b/disttask/framework/scheduler/manager_test.go @@ -29,7 +29,7 @@ import ( func TestManageTask(t *testing.T) { b := NewManagerBuilder() - m, err := b.BuildManager(context.Background(), "test", nil, nil) + m, err := b.BuildManager(context.Background(), "test", nil) require.NoError(t, err) tasks := []*proto.Task{{ID: 1}, {ID: 2}} newTasks := m.filterAlreadyHandlingTasks(tasks) @@ -67,12 +67,11 @@ func TestManageTask(t *testing.T) { func TestOnRunnableTasks(t *testing.T) { mockTaskTable := &MockTaskTable{} - mockSubtaskTable := &MockSubtaskTable{} mockInternalScheduler := &MockInternalScheduler{} mockPool := &MockPool{} b := NewManagerBuilder() - b.setSchedulerFactory(func(ctx context.Context, id string, taskID int64, subtaskTable SubtaskTable, pool Pool) InternalScheduler { + b.setSchedulerFactory(func(ctx context.Context, id string, taskID int64, taskTable TaskTable, pool Pool) InternalScheduler { return mockInternalScheduler }) b.setPoolFactory(func(name string, size int32, component util.Component, options ...spool.Option) (Pool, error) { @@ -82,7 +81,7 @@ func TestOnRunnableTasks(t *testing.T) { taskID := int64(1) task := &proto.Task{ID: taskID, State: proto.TaskStateRunning, Step: 0, Type: "type"} - m, err := b.BuildManager(context.Background(), id, mockTaskTable, mockSubtaskTable) + m, err := b.BuildManager(context.Background(), id, mockTaskTable) require.NoError(t, err) // no task @@ -94,57 +93,56 @@ func TestOnRunnableTasks(t *testing.T) { m.subtaskExecutorPools["type"] = mockPool // get subtask failed - mockSubtaskTable.On("HasSubtasksInStates", id, taskID, []interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}).Return(false, errors.New("get subtask failed")).Once() + mockTaskTable.On("HasSubtasksInStates", id, taskID, []interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}).Return(false, errors.New("get subtask failed")).Once() m.onRunnableTasks(context.Background(), []*proto.Task{task}) // no subtask - mockSubtaskTable.On("HasSubtasksInStates", id, taskID, []interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}).Return(false, nil).Once() + mockTaskTable.On("HasSubtasksInStates", id, taskID, []interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}).Return(false, nil).Once() m.onRunnableTasks(context.Background(), []*proto.Task{task}) // pool error - mockSubtaskTable.On("HasSubtasksInStates", id, taskID, []interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}).Return(true, nil).Once() + mockTaskTable.On("HasSubtasksInStates", id, taskID, []interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}).Return(true, nil).Once() mockPool.On("Run", mock.Anything).Return(errors.New("pool error")).Once() m.onRunnableTasks(context.Background(), []*proto.Task{task}) // step 0 succeed - mockSubtaskTable.On("HasSubtasksInStates", id, taskID, []interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}).Return(true, nil).Once() + mockTaskTable.On("HasSubtasksInStates", id, taskID, []interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}).Return(true, nil).Once() mockPool.On("Run", mock.Anything).Return(nil).Once() mockInternalScheduler.On("Start").Once() - mockTaskTable.On("GetTaskByID", taskID).Return(task, nil).Once() - mockSubtaskTable.On("HasSubtasksInStates", id, taskID, []interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}).Return(true, nil).Once() + mockTaskTable.On("GetGlobalTaskByID", taskID).Return(task, nil).Once() + mockTaskTable.On("HasSubtasksInStates", id, taskID, []interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}).Return(true, nil).Once() mockInternalScheduler.On("Run", mock.Anything, task).Return(nil).Once() m.onRunnableTasks(context.Background(), []*proto.Task{task}) // step 1 canceled task1 := &proto.Task{ID: taskID, State: proto.TaskStateRunning, Step: 1} - mockTaskTable.On("GetTaskByID", taskID).Return(task1, nil).Once() - mockSubtaskTable.On("HasSubtasksInStates", id, taskID, []interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}).Return(true, nil).Once() + mockTaskTable.On("GetGlobalTaskByID", taskID).Return(task1, nil).Once() + mockTaskTable.On("HasSubtasksInStates", id, taskID, []interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}).Return(true, nil).Once() mockInternalScheduler.On("Run", mock.Anything, task1).Return(errors.New("run errr")).Once() task2 := &proto.Task{ID: taskID, State: proto.TaskStateReverting, Step: 1} - mockTaskTable.On("GetTaskByID", taskID).Return(task2, nil).Once() - mockSubtaskTable.On("HasSubtasksInStates", id, taskID, []interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}).Return(true, nil).Once() + mockTaskTable.On("GetGlobalTaskByID", taskID).Return(task2, nil).Once() + mockTaskTable.On("HasSubtasksInStates", id, taskID, []interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}).Return(true, nil).Once() mockInternalScheduler.On("Rollback", mock.Anything, task2).Return(nil).Once() task3 := &proto.Task{ID: taskID, State: proto.TaskStateReverted, Step: 1} - mockTaskTable.On("GetTaskByID", taskID).Return(task3, nil).Once() + mockTaskTable.On("GetGlobalTaskByID", taskID).Return(task3, nil).Once() mockInternalScheduler.On("Stop").Return(nil).Once() time.Sleep(5 * time.Second) mockTaskTable.AssertExpectations(t) - mockSubtaskTable.AssertExpectations(t) + mockTaskTable.AssertExpectations(t) mockInternalScheduler.AssertExpectations(t) mockPool.AssertExpectations(t) } func TestManager(t *testing.T) { - mockTaskTable := &MockTaskTable{} // TODO(gmhdbjd): use real subtask table instead of mock - mockSubtaskTable := &MockSubtaskTable{} + mockTaskTable := &MockTaskTable{} mockInternalScheduler := &MockInternalScheduler{} mockPool := &MockPool{} b := NewManagerBuilder() - b.setSchedulerFactory(func(ctx context.Context, id string, taskID int64, subtaskTable SubtaskTable, pool Pool) InternalScheduler { + b.setSchedulerFactory(func(ctx context.Context, id string, taskID int64, taskTable TaskTable, pool Pool) InternalScheduler { return mockInternalScheduler }) b.setPoolFactory(func(name string, size int32, component util.Component, options ...spool.Option) (Pool, error) { @@ -161,25 +159,25 @@ func TestManager(t *testing.T) { task1 := &proto.Task{ID: taskID1, State: proto.TaskStateRunning, Step: 0, Type: "type"} task2 := &proto.Task{ID: taskID2, State: proto.TaskStateReverting, Step: 0, Type: "type"} - mockTaskTable.On("GetTasksInStates", proto.TaskStateRunning, proto.TaskStateReverting).Return([]*proto.Task{task1, task2}, nil) - mockTaskTable.On("GetTasksInStates", proto.TaskStateReverting).Return([]*proto.Task{task2}, nil) + mockTaskTable.On("GetGlobalTasksInStates", proto.TaskStateRunning, proto.TaskStateReverting).Return([]*proto.Task{task1, task2}, nil) + mockTaskTable.On("GetGlobalTasksInStates", proto.TaskStateReverting).Return([]*proto.Task{task2}, nil) // task1 - mockSubtaskTable.On("HasSubtasksInStates", id, taskID1, []interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}).Return(true, nil).Once() + mockTaskTable.On("HasSubtasksInStates", id, taskID1, []interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}).Return(true, nil).Once() mockPool.On("Run", mock.Anything).Return(nil).Once() mockInternalScheduler.On("Start").Once() - mockTaskTable.On("GetTaskByID", taskID1).Return(task1, nil) - mockSubtaskTable.On("HasSubtasksInStates", id, taskID1, []interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}).Return(true, nil).Once() + mockTaskTable.On("GetGlobalTaskByID", taskID1).Return(task1, nil) + mockTaskTable.On("HasSubtasksInStates", id, taskID1, []interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}).Return(true, nil).Once() mockInternalScheduler.On("Run", mock.Anything, task1).Return(nil).Once() - mockSubtaskTable.On("HasSubtasksInStates", id, taskID1, []interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}).Return(false, nil) + mockTaskTable.On("HasSubtasksInStates", id, taskID1, []interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}).Return(false, nil) mockInternalScheduler.On("Stop").Once() // task2 - mockSubtaskTable.On("HasSubtasksInStates", id, taskID2, []interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}).Return(true, nil).Once() + mockTaskTable.On("HasSubtasksInStates", id, taskID2, []interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}).Return(true, nil).Once() mockPool.On("Run", mock.Anything).Return(nil).Once() mockInternalScheduler.On("Start").Once() - mockTaskTable.On("GetTaskByID", taskID2).Return(task2, nil) - mockSubtaskTable.On("HasSubtasksInStates", id, taskID2, []interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}).Return(true, nil).Once() + mockTaskTable.On("GetGlobalTaskByID", taskID2).Return(task2, nil) + mockTaskTable.On("HasSubtasksInStates", id, taskID2, []interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}).Return(true, nil).Once() mockInternalScheduler.On("Rollback", mock.Anything, task2).Return(nil).Once() - mockSubtaskTable.On("HasSubtasksInStates", id, taskID2, []interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}).Return(false, nil) + mockTaskTable.On("HasSubtasksInStates", id, taskID2, []interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}).Return(false, nil) mockInternalScheduler.On("Stop").Once() mockPool.On("ReleaseAndWait").Twice() RegisterSubtaskExectorConstructor("type", func(minimalTask proto.MinimalTask, step int64) (SubtaskExecutor, error) { @@ -187,14 +185,14 @@ func TestManager(t *testing.T) { }, func(opts *subtaskExecutorRegisterOptions) { opts.PoolSize = 1 }) - m, err := b.BuildManager(context.Background(), id, mockTaskTable, mockSubtaskTable) + m, err := b.BuildManager(context.Background(), id, mockTaskTable) require.NoError(t, err) m.Start() time.Sleep(5 * time.Second) m.Stop() time.Sleep(5 * time.Second) mockTaskTable.AssertExpectations(t) - mockSubtaskTable.AssertExpectations(t) + mockTaskTable.AssertExpectations(t) mockInternalScheduler.AssertExpectations(t) mockPool.AssertExpectations(t) } diff --git a/disttask/framework/scheduler/scheduler.go b/disttask/framework/scheduler/scheduler.go index 230a9e2bbd370..b479beaed42e1 100644 --- a/disttask/framework/scheduler/scheduler.go +++ b/disttask/framework/scheduler/scheduler.go @@ -27,14 +27,14 @@ import ( // InternalSchedulerImpl is the implementation of InternalScheduler. type InternalSchedulerImpl struct { - ctx context.Context - cancel context.CancelFunc - id string - taskID int64 - subtaskTable SubtaskTable - pool Pool - wg sync.WaitGroup - logCtx context.Context + ctx context.Context + cancel context.CancelFunc + id string + taskID int64 + taskTable TaskTable + pool Pool + wg sync.WaitGroup + logCtx context.Context mu struct { sync.RWMutex @@ -45,14 +45,14 @@ type InternalSchedulerImpl struct { } // NewInternalScheduler creates a new InternalScheduler. -func NewInternalScheduler(ctx context.Context, id string, taskID int64, subtaskTable SubtaskTable, pool Pool) InternalScheduler { +func NewInternalScheduler(ctx context.Context, id string, taskID int64, taskTable TaskTable, pool Pool) InternalScheduler { logPrefix := fmt.Sprintf("id: %s, task_id: %d", id, taskID) schedulerImpl := &InternalSchedulerImpl{ - id: id, - taskID: taskID, - subtaskTable: subtaskTable, - pool: pool, - logCtx: logutil.WithKeyValue(context.Background(), "scheduler", logPrefix), + id: id, + taskID: taskID, + taskTable: taskTable, + pool: pool, + logCtx: logutil.WithKeyValue(context.Background(), "scheduler", logPrefix), } schedulerImpl.ctx, schedulerImpl.cancel = context.WithCancel(ctx) @@ -125,7 +125,7 @@ func (s *InternalSchedulerImpl) Run(ctx context.Context, task *proto.Task) error } for { - subtask, err := s.subtaskTable.GetSubtaskInStates(s.id, task.ID, proto.TaskStatePending) + subtask, err := s.taskTable.GetSubtaskInStates(s.id, task.ID, proto.TaskStatePending) if err != nil { s.onError(err) break @@ -200,7 +200,7 @@ func (s *InternalSchedulerImpl) Rollback(ctx context.Context, task *proto.Task) s.onError(err) return s.getError() } - subtask, err := s.subtaskTable.GetSubtaskInStates(s.id, task.ID, proto.TaskStateRevertPending) + subtask, err := s.taskTable.GetSubtaskInStates(s.id, task.ID, proto.TaskStateRevertPending) if err != nil { s.onError(err) return s.getError() @@ -277,7 +277,7 @@ func (s *InternalSchedulerImpl) resetError() { } func (s *InternalSchedulerImpl) updateSubtaskState(id int64, state string) { - err := s.subtaskTable.UpdateSubtaskState(id, state) + err := s.taskTable.UpdateSubtaskState(id, state) if err != nil { s.onError(err) } diff --git a/disttask/framework/scheduler/scheduler_test.go b/disttask/framework/scheduler/scheduler_test.go index 311b78de09b4f..7ed4c1a44a315 100644 --- a/disttask/framework/scheduler/scheduler_test.go +++ b/disttask/framework/scheduler/scheduler_test.go @@ -32,7 +32,7 @@ func TestSchedulerRun(t *testing.T) { defer cancel() runCtx, runCancel := context.WithCancel(ctx) defer runCancel() - mockSubtaskTable := &MockSubtaskTable{} + mockSubtaskTable := &MockTaskTable{} mockPool := &MockPool{} mockScheduler := &MockScheduler{} mockSubtaskExecutor := &MockSubtaskExecutor{} @@ -159,7 +159,7 @@ func TestSchedulerRollback(t *testing.T) { defer cancel() runCtx, runCancel := context.WithCancel(ctx) defer runCancel() - mockSubtaskTable := &MockSubtaskTable{} + mockSubtaskTable := &MockTaskTable{} mockPool := &MockPool{} mockScheduler := &MockScheduler{} @@ -222,7 +222,7 @@ func TestScheduler(t *testing.T) { defer cancel() runCtx, runCancel := context.WithCancel(ctx) defer runCancel() - mockSubtaskTable := &MockSubtaskTable{} + mockSubtaskTable := &MockTaskTable{} mockPool := &MockPool{} mockScheduler := &MockScheduler{} mockSubtaskExecutor := &MockSubtaskExecutor{} diff --git a/disttask/framework/storage/table_test.go b/disttask/framework/storage/table_test.go index 91d1bd03ca442..ca3058fb0ba89 100644 --- a/disttask/framework/storage/table_test.go +++ b/disttask/framework/storage/table_test.go @@ -42,17 +42,17 @@ func TestGlobalTaskTable(t *testing.T) { tk := testkit.NewTestKit(t, store) - gm := storage.NewGlobalTaskManager(context.Background(), tk.Session()) + gm := storage.NewTaskManager(context.Background(), tk.Session()) - storage.SetGlobalTaskManager(gm) - gm, err := storage.GetGlobalTaskManager() + storage.SetTaskManager(gm) + gm, err := storage.GetTaskManager() require.NoError(t, err) - id, err := gm.AddNewTask("key1", "test", 4, []byte("test")) + id, err := gm.AddNewGlobalTask("key1", "test", 4, []byte("test")) require.NoError(t, err) require.Equal(t, int64(1), id) - task, err := gm.GetNewTask() + task, err := gm.GetNewGlobalTask() require.NoError(t, err) require.Equal(t, int64(1), task.ID) require.Equal(t, "key1", task.Key) @@ -61,35 +61,35 @@ func TestGlobalTaskTable(t *testing.T) { require.Equal(t, uint64(4), task.Concurrency) require.Equal(t, []byte("test"), task.Meta) - task2, err := gm.GetTaskByID(1) + task2, err := gm.GetGlobalTaskByID(1) require.NoError(t, err) require.Equal(t, task, task2) - task3, err := gm.GetTasksInStates(proto.TaskStatePending) + task3, err := gm.GetGlobalTasksInStates(proto.TaskStatePending) require.NoError(t, err) require.Len(t, task3, 1) require.Equal(t, task, task3[0]) - task4, err := gm.GetTasksInStates(proto.TaskStatePending, proto.TaskStateRunning) + task4, err := gm.GetGlobalTasksInStates(proto.TaskStatePending, proto.TaskStateRunning) require.NoError(t, err) require.Len(t, task4, 1) require.Equal(t, task, task4[0]) task.State = proto.TaskStateRunning - err = gm.UpdateTask(task) + err = gm.UpdateGlobalTask(task) require.NoError(t, err) - task5, err := gm.GetTasksInStates(proto.TaskStateRunning) + task5, err := gm.GetGlobalTasksInStates(proto.TaskStateRunning) require.NoError(t, err) require.Len(t, task5, 1) require.Equal(t, task, task5[0]) - task6, err := gm.GetTaskByKey("key1") + task6, err := gm.GetGlobalTaskByKey("key1") require.NoError(t, err) require.Equal(t, task, task6) // test cannot insert task with dup key - _, err = gm.AddNewTask("key1", "test2", 4, []byte("test2")) + _, err = gm.AddNewGlobalTask("key1", "test2", 4, []byte("test2")) require.EqualError(t, err, "[kv:1062]Duplicate entry 'key1' for key 'tidb_global_task.task_key'") } @@ -98,13 +98,13 @@ func TestSubTaskTable(t *testing.T) { tk := testkit.NewTestKit(t, store) - sm := storage.NewSubTaskManager(context.Background(), tk.Session()) + sm := storage.NewTaskManager(context.Background(), tk.Session()) - storage.SetSubTaskManager(sm) - sm, err := storage.GetSubTaskManager() + storage.SetTaskManager(sm) + sm, err := storage.GetTaskManager() require.NoError(t, err) - err = sm.AddNewTask(1, "tidb1", []byte("test"), proto.TaskTypeExample, false) + err = sm.AddNewSubTask(1, "tidb1", []byte("test"), proto.TaskTypeExample, false) require.NoError(t, err) nilTask, err := sm.GetSubtaskInStates("tidb2", 1, proto.TaskStatePending) @@ -123,12 +123,12 @@ func TestSubTaskTable(t *testing.T) { require.NoError(t, err) require.Equal(t, task, task2) - ids, err := sm.GetSchedulerIDs(1) + ids, err := sm.GetSchedulerIDsByTaskID(1) require.NoError(t, err) require.Len(t, ids, 1) require.Equal(t, "tidb1", ids[0]) - ids, err = sm.GetSchedulerIDs(3) + ids, err = sm.GetSchedulerIDsByTaskID(3) require.NoError(t, err) require.Len(t, ids, 0) @@ -144,7 +144,7 @@ func TestSubTaskTable(t *testing.T) { require.NoError(t, err) require.True(t, ok) - err = sm.UpdateHeartbeat("tidb1", 1, time.Now()) + err = sm.UpdateSubtaskHeartbeat("tidb1", 1, time.Now()) require.NoError(t, err) err = sm.UpdateSubtaskState(1, proto.TaskStateRunning) @@ -170,14 +170,14 @@ func TestSubTaskTable(t *testing.T) { require.NoError(t, err) require.False(t, ok) - err = sm.DeleteTasks(1) + err = sm.DeleteSubtasksByTaskID(1) require.NoError(t, err) ok, err = sm.HasSubtasksInStates("tidb1", 1, proto.TaskStatePending, proto.TaskStateRunning) require.NoError(t, err) require.False(t, ok) - err = sm.AddNewTask(2, "tidb1", []byte("test"), proto.TaskTypeExample, true) + err = sm.AddNewSubTask(2, "tidb1", []byte("test"), proto.TaskTypeExample, true) require.NoError(t, err) cnt, err = sm.GetSubtaskInStatesCnt(2, proto.TaskStateRevertPending) diff --git a/disttask/framework/storage/task_table.go b/disttask/framework/storage/task_table.go index 1e2329f87a034..42208a9972b48 100644 --- a/disttask/framework/storage/task_table.go +++ b/disttask/framework/storage/task_table.go @@ -34,60 +34,36 @@ import ( "go.uber.org/zap" ) -// GlobalTaskManager is the manager of global task. -type GlobalTaskManager struct { +// TaskManager is the manager of global/sub task. +type TaskManager struct { ctx context.Context se sessionctx.Context mu sync.Mutex } -var globalTaskManagerInstance atomic.Pointer[GlobalTaskManager] -var subTaskManagerInstance atomic.Pointer[SubTaskManager] +var taskManagerInstance atomic.Pointer[TaskManager] -// NewGlobalTaskManager creates a new global task manager. -func NewGlobalTaskManager(ctx context.Context, se sessionctx.Context) *GlobalTaskManager { +// NewTaskManager creates a new task manager. +func NewTaskManager(ctx context.Context, se sessionctx.Context) *TaskManager { ctx = util.WithInternalSourceType(ctx, kv.InternalDistTask) - return &GlobalTaskManager{ + return &TaskManager{ ctx: ctx, se: se, } } -// NewSubTaskManager creates a new sub task manager. -func NewSubTaskManager(ctx context.Context, se sessionctx.Context) *SubTaskManager { - ctx = util.WithInternalSourceType(ctx, kv.InternalDistTask) - return &SubTaskManager{ - ctx: ctx, - se: se, - } -} - -// GetGlobalTaskManager gets the global task manager. -func GetGlobalTaskManager() (*GlobalTaskManager, error) { - v := globalTaskManagerInstance.Load() +// GetTaskManager gets the task manager. +func GetTaskManager() (*TaskManager, error) { + v := taskManagerInstance.Load() if v == nil { return nil, errors.New("global task manager is not initialized") } return v, nil } -// SetGlobalTaskManager sets the global task manager. -func SetGlobalTaskManager(is *GlobalTaskManager) { - globalTaskManagerInstance.Store(is) -} - -// GetSubTaskManager gets the sub task manager. -func GetSubTaskManager() (*SubTaskManager, error) { - v := subTaskManagerInstance.Load() - if v == nil { - return nil, errors.New("subTask manager is not initialized") - } - return v, nil -} - -// SetSubTaskManager sets the sub task manager. -func SetSubTaskManager(is *SubTaskManager) { - subTaskManagerInstance.Store(is) +// SetTaskManager sets the task manager. +func SetTaskManager(is *TaskManager) { + taskManagerInstance.Store(is) } // execSQL executes the sql and returns the result. @@ -129,8 +105,8 @@ func row2GlobeTask(r chunk.Row) *proto.Task { return task } -// AddNewTask adds a new task to global task table. -func (stm *GlobalTaskManager) AddNewTask(key, tp string, concurrency int, meta []byte) (int64, error) { +// AddNewGlobalTask adds a new task to global task table. +func (stm *TaskManager) AddNewGlobalTask(key, tp string, concurrency int, meta []byte) (int64, error) { stm.mu.Lock() defer stm.mu.Unlock() @@ -147,8 +123,8 @@ func (stm *GlobalTaskManager) AddNewTask(key, tp string, concurrency int, meta [ return strconv.ParseInt(rs[0].GetString(0), 10, 64) } -// GetNewTask get a new task from global task table, it's used by dispatcher only. -func (stm *GlobalTaskManager) GetNewTask() (task *proto.Task, err error) { +// GetNewGlobalTask get a new task from global task table, it's used by dispatcher only. +func (stm *TaskManager) GetNewGlobalTask() (task *proto.Task, err error) { stm.mu.Lock() defer stm.mu.Unlock() @@ -164,8 +140,8 @@ func (stm *GlobalTaskManager) GetNewTask() (task *proto.Task, err error) { return row2GlobeTask(rs[0]), nil } -// UpdateTask updates the global task. -func (stm *GlobalTaskManager) UpdateTask(task *proto.Task) error { +// UpdateGlobalTask updates the global task. +func (stm *TaskManager) UpdateGlobalTask(task *proto.Task) error { failpoint.Inject("MockUpdateTaskErr", func(val failpoint.Value) { if val.(bool) { failpoint.Return(errors.New("updateTaskErr")) @@ -183,8 +159,8 @@ func (stm *GlobalTaskManager) UpdateTask(task *proto.Task) error { return nil } -// GetTasksInStates gets the tasks in the states. -func (stm *GlobalTaskManager) GetTasksInStates(states ...interface{}) (task []*proto.Task, err error) { +// GetGlobalTasksInStates gets the tasks in the states. +func (stm *TaskManager) GetGlobalTasksInStates(states ...interface{}) (task []*proto.Task, err error) { stm.mu.Lock() defer stm.mu.Unlock() @@ -203,8 +179,8 @@ func (stm *GlobalTaskManager) GetTasksInStates(states ...interface{}) (task []*p return task, nil } -// GetTaskByID gets the task by the global task ID. -func (stm *GlobalTaskManager) GetTaskByID(taskID int64) (task *proto.Task, err error) { +// GetGlobalTaskByID gets the task by the global task ID. +func (stm *TaskManager) GetGlobalTaskByID(taskID int64) (task *proto.Task, err error) { stm.mu.Lock() defer stm.mu.Unlock() @@ -219,8 +195,8 @@ func (stm *GlobalTaskManager) GetTaskByID(taskID int64) (task *proto.Task, err e return row2GlobeTask(rs[0]), nil } -// GetTaskByKey gets the task by the task key -func (stm *GlobalTaskManager) GetTaskByKey(key string) (task *proto.Task, err error) { +// GetGlobalTaskByKey gets the task by the task key +func (stm *TaskManager) GetGlobalTaskByKey(key string) (task *proto.Task, err error) { stm.mu.Lock() defer stm.mu.Unlock() @@ -235,13 +211,6 @@ func (stm *GlobalTaskManager) GetTaskByKey(key string) (task *proto.Task, err er return row2GlobeTask(rs[0]), nil } -// SubTaskManager is the manager of subtask. -type SubTaskManager struct { - ctx context.Context - se sessionctx.Context - mu sync.Mutex -} - // row2SubTask converts a row to a subtask. func row2SubTask(r chunk.Row) *proto.Subtask { task := &proto.Subtask{ @@ -260,8 +229,8 @@ func row2SubTask(r chunk.Row) *proto.Subtask { return task } -// AddNewTask adds a new task to subtask table. -func (stm *SubTaskManager) AddNewTask(globalTaskID int64, designatedTiDBID string, meta []byte, tp string, isRevert bool) error { +// AddNewSubTask adds a new task to subtask table. +func (stm *TaskManager) AddNewSubTask(globalTaskID int64, designatedTiDBID string, meta []byte, tp string, isRevert bool) error { stm.mu.Lock() defer stm.mu.Unlock() @@ -279,7 +248,7 @@ func (stm *SubTaskManager) AddNewTask(globalTaskID int64, designatedTiDBID strin } // GetSubtaskInStates gets the subtask in the states. -func (stm *SubTaskManager) GetSubtaskInStates(tidbID string, taskID int64, states ...interface{}) (*proto.Subtask, error) { +func (stm *TaskManager) GetSubtaskInStates(tidbID string, taskID int64, states ...interface{}) (*proto.Subtask, error) { stm.mu.Lock() defer stm.mu.Unlock() @@ -297,7 +266,7 @@ func (stm *SubTaskManager) GetSubtaskInStates(tidbID string, taskID int64, state } // GetSubtaskInStatesCnt gets the subtask count in the states. -func (stm *SubTaskManager) GetSubtaskInStatesCnt(taskID int64, states ...interface{}) (int64, error) { +func (stm *TaskManager) GetSubtaskInStatesCnt(taskID int64, states ...interface{}) (int64, error) { stm.mu.Lock() defer stm.mu.Unlock() @@ -312,7 +281,7 @@ func (stm *SubTaskManager) GetSubtaskInStatesCnt(taskID int64, states ...interfa } // HasSubtasksInStates checks if there are subtasks in the states. -func (stm *SubTaskManager) HasSubtasksInStates(tidbID string, taskID int64, states ...interface{}) (bool, error) { +func (stm *TaskManager) HasSubtasksInStates(tidbID string, taskID int64, states ...interface{}) (bool, error) { stm.mu.Lock() defer stm.mu.Unlock() @@ -327,7 +296,7 @@ func (stm *SubTaskManager) HasSubtasksInStates(tidbID string, taskID int64, stat } // UpdateSubtaskState updates the subtask state. -func (stm *SubTaskManager) UpdateSubtaskState(id int64, state string) error { +func (stm *TaskManager) UpdateSubtaskState(id int64, state string) error { stm.mu.Lock() defer stm.mu.Unlock() @@ -335,8 +304,8 @@ func (stm *SubTaskManager) UpdateSubtaskState(id int64, state string) error { return err } -// UpdateHeartbeat updates the heartbeat of the subtask. -func (stm *SubTaskManager) UpdateHeartbeat(instanceID string, taskID int64, heartbeat time.Time) error { +// UpdateSubtaskHeartbeat updates the heartbeat of the subtask. +func (stm *TaskManager) UpdateSubtaskHeartbeat(instanceID string, taskID int64, heartbeat time.Time) error { stm.mu.Lock() defer stm.mu.Unlock() @@ -344,8 +313,8 @@ func (stm *SubTaskManager) UpdateHeartbeat(instanceID string, taskID int64, hear return err } -// DeleteTasks deletes the subtask of the given global task ID. -func (stm *SubTaskManager) DeleteTasks(taskID int64) error { +// DeleteSubtasksByTaskID deletes the subtask of the given global task ID. +func (stm *TaskManager) DeleteSubtasksByTaskID(taskID int64) error { stm.mu.Lock() defer stm.mu.Unlock() @@ -357,8 +326,8 @@ func (stm *SubTaskManager) DeleteTasks(taskID int64) error { return nil } -// GetSchedulerIDs gets the scheduler IDs of the given global task ID. -func (stm *SubTaskManager) GetSchedulerIDs(taskID int64) ([]string, error) { +// GetSchedulerIDsByTaskID gets the scheduler IDs of the given global task ID. +func (stm *TaskManager) GetSchedulerIDsByTaskID(taskID int64) ([]string, error) { stm.mu.Lock() defer stm.mu.Unlock() diff --git a/domain/domain.go b/domain/domain.go index 37a1b5ca951f8..360f55d9763ec 100644 --- a/domain/domain.go +++ b/domain/domain.go @@ -1351,41 +1351,30 @@ func (do *Domain) initDistTaskLoop(ctx context.Context) error { failpoint.Return(nil) } }) - se1, err := do.sysExecutorFactory(do) + se, err := do.sysExecutorFactory(do) if err != nil { return err } - se2, err := do.sysExecutorFactory(do) + taskManager := storage.NewTaskManager(kv.WithInternalSourceType(ctx, kv.InternalDistTask), se.(sessionctx.Context)) + schedulerManager, err := scheduler.NewManagerBuilder().BuildManager(ctx, do.ddl.GetID(), taskManager) if err != nil { - se1.Close() + se.Close() return err } - gm := storage.NewGlobalTaskManager(kv.WithInternalSourceType(ctx, kv.InternalDistTask), se1.(sessionctx.Context)) - sm := storage.NewSubTaskManager(kv.WithInternalSourceType(ctx, kv.InternalDistTask), se2.(sessionctx.Context)) - schedulerManager, err := scheduler.NewManagerBuilder().BuildManager(ctx, do.ddl.GetID(), gm, sm) - if err != nil { - se1.Close() - se2.Close() - return err - } - - storage.SetGlobalTaskManager(gm) - storage.SetSubTaskManager(sm) + storage.SetTaskManager(taskManager) do.wg.Run(func() { defer func() { - storage.SetGlobalTaskManager(nil) - storage.SetSubTaskManager(nil) - se1.Close() - se2.Close() + storage.SetTaskManager(nil) + se.Close() }() - do.distTaskFrameworkLoop(ctx, gm, sm, schedulerManager) + do.distTaskFrameworkLoop(ctx, taskManager, schedulerManager) }, "distTaskFrameworkLoop") return nil } -func (do *Domain) distTaskFrameworkLoop(ctx context.Context, globalTaskManager *storage.GlobalTaskManager, subtaskManager *storage.SubTaskManager, schedulerManager *scheduler.Manager) { +func (do *Domain) distTaskFrameworkLoop(ctx context.Context, taskManager *storage.TaskManager, schedulerManager *scheduler.Manager) { schedulerManager.Start() logutil.BgLogger().Info("dist task scheduler started") defer func() { @@ -1400,7 +1389,7 @@ func (do *Domain) distTaskFrameworkLoop(ctx context.Context, globalTaskManager * return } - newDispatch, err := dispatcher.NewDispatcher(ctx, globalTaskManager, subtaskManager) + newDispatch, err := dispatcher.NewDispatcher(ctx, taskManager) if err != nil { logutil.BgLogger().Error("failed to create a disttask dispatcher", zap.Error(err)) return