Skip to content

Commit

Permalink
disttask: merge GlobalTaskManager and SubTaskManager (#42786)
Browse files Browse the repository at this point in the history
  • Loading branch information
lcwangchao authored Apr 4, 2023
1 parent 8f64f36 commit a9d7577
Show file tree
Hide file tree
Showing 12 changed files with 180 additions and 241 deletions.
36 changes: 17 additions & 19 deletions disttask/framework/dispatcher/dispatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,11 @@ func (d *dispatcher) delRunningGTask(globalTaskID int64) {
}

type dispatcher struct {
ctx context.Context
cancel context.CancelFunc
gTaskMgr *storage.GlobalTaskManager
subTaskMgr *storage.SubTaskManager
wg tidbutil.WaitGroupWrapper
gPool *spool.Pool
ctx context.Context
cancel context.CancelFunc
taskMgr *storage.TaskManager
wg tidbutil.WaitGroupWrapper
gPool *spool.Pool

runningGTasks struct {
syncutil.RWMutex
Expand All @@ -107,10 +106,9 @@ type dispatcher struct {
}

// NewDispatcher creates a dispatcher struct.
func NewDispatcher(ctx context.Context, globalTaskTable *storage.GlobalTaskManager, subtaskTable *storage.SubTaskManager) (Dispatch, error) {
func NewDispatcher(ctx context.Context, taskTable *storage.TaskManager) (Dispatch, error) {
dispatcher := &dispatcher{
gTaskMgr: globalTaskTable,
subTaskMgr: subtaskTable,
taskMgr: taskTable,
detectPendingGTaskCh: make(chan *proto.Task, DefaultDispatchConcurrency),
}
pool, err := spool.NewPool("dispatch_pool", int32(DefaultDispatchConcurrency), util.DistTask, spool.WithBlocking(true))
Expand Down Expand Up @@ -156,7 +154,7 @@ func (d *dispatcher) DispatchTaskLoop() {
}

// TODO: Consider getting these tasks, in addition to the task being worked on..
gTasks, err := d.gTaskMgr.GetTasksInStates(proto.TaskStatePending, proto.TaskStateRunning, proto.TaskStateReverting)
gTasks, err := d.taskMgr.GetGlobalTasksInStates(proto.TaskStatePending, proto.TaskStateRunning, proto.TaskStateReverting)
if err != nil {
logutil.BgLogger().Warn("get unfinished(pending, running or reverting) tasks failed", zap.Error(err))
break
Expand Down Expand Up @@ -199,7 +197,7 @@ func (d *dispatcher) probeTask(gTask *proto.Task) (isFinished bool, subTaskErr s
// TODO: Consider putting the following operations into a transaction.
// TODO: Consider collect some information about the tasks.
if gTask.State != proto.TaskStateReverting {
cnt, err := d.subTaskMgr.GetSubtaskInStatesCnt(gTask.ID, proto.TaskStateFailed)
cnt, err := d.taskMgr.GetSubtaskInStatesCnt(gTask.ID, proto.TaskStateFailed)
if err != nil {
logutil.BgLogger().Warn("check task failed", zap.Int64("task ID", gTask.ID), zap.Error(err))
return false, ""
Expand All @@ -208,7 +206,7 @@ func (d *dispatcher) probeTask(gTask *proto.Task) (isFinished bool, subTaskErr s
return false, proto.TaskStateFailed
}

cnt, err = d.subTaskMgr.GetSubtaskInStatesCnt(gTask.ID, proto.TaskStatePending, proto.TaskStateRunning)
cnt, err = d.taskMgr.GetSubtaskInStatesCnt(gTask.ID, proto.TaskStatePending, proto.TaskStateRunning)
if err != nil {
logutil.BgLogger().Warn("check task failed", zap.Int64("task ID", gTask.ID), zap.Error(err))
return false, ""
Expand All @@ -220,7 +218,7 @@ func (d *dispatcher) probeTask(gTask *proto.Task) (isFinished bool, subTaskErr s
return true, ""
}

cnt, err := d.subTaskMgr.GetSubtaskInStatesCnt(gTask.ID, proto.TaskStateRevertPending, proto.TaskStateReverting)
cnt, err := d.taskMgr.GetSubtaskInStatesCnt(gTask.ID, proto.TaskStateRevertPending, proto.TaskStateReverting)
if err != nil {
logutil.BgLogger().Warn("check task failed", zap.Int64("task ID", gTask.ID), zap.Error(err))
return false, ""
Expand Down Expand Up @@ -310,7 +308,7 @@ func (d *dispatcher) updateTask(gTask *proto.Task, gTaskState string, retryTimes
for i := 0; i < retryTimes; i++ {
gTask.State = gTaskState
// Write the global task meta into the storage.
err = d.gTaskMgr.UpdateTask(gTask)
err = d.taskMgr.UpdateGlobalTask(gTask)
if err == nil {
break
}
Expand Down Expand Up @@ -347,7 +345,7 @@ func (d *dispatcher) processErrFlow(gTask *proto.Task, receiveErr string) error
return d.updateTask(gTask, proto.TaskStateReverted, retrySQLTimes)
}

// TODO: UpdateTask and AddNewTask in a txn.
// TODO: UpdateGlobalTask and AddNewSubTask in a txn.
// Write the global task meta into the storage.
err = d.updateTask(gTask, proto.TaskStateReverting, retrySQLTimes)
if err != nil {
Expand All @@ -357,7 +355,7 @@ func (d *dispatcher) processErrFlow(gTask *proto.Task, receiveErr string) error
// New rollback subtasks and write into the storage.
for _, id := range instanceIDs {
subtask := proto.NewSubtask(gTask.ID, gTask.Type, id, meta)
err = d.subTaskMgr.AddNewTask(gTask.ID, subtask.SchedulerID, subtask.Meta, subtask.Type, true)
err = d.taskMgr.AddNewSubTask(gTask.ID, subtask.SchedulerID, subtask.Meta, subtask.Type, true)
if err != nil {
logutil.BgLogger().Warn("add subtask failed", zap.Int64("gTask ID", gTask.ID), zap.Error(err))
return err
Expand Down Expand Up @@ -410,7 +408,7 @@ func (d *dispatcher) processNormalFlow(gTask *proto.Task) (err error) {
return nil
}

// TODO: UpdateTask and AddNewTask in a txn.
// TODO: UpdateGlobalTask and AddNewSubTask in a txn.
// Write the global task meta into the storage.
err = d.updateTask(gTask, gTask.State, retryTimes)
if err != nil {
Expand All @@ -428,7 +426,7 @@ func (d *dispatcher) processNormalFlow(gTask *proto.Task) (err error) {

// TODO: Consider batch insert.
// TODO: Synchronization interruption problem, e.g. AddNewTask failed.
err = d.subTaskMgr.AddNewTask(gTask.ID, subtask.SchedulerID, subtask.Meta, subtask.Type, false)
err = d.taskMgr.AddNewSubTask(gTask.ID, subtask.SchedulerID, subtask.Meta, subtask.Type, false)
if err != nil {
logutil.BgLogger().Warn("add subtask failed", zap.Int64("gTask ID", gTask.ID), zap.Error(err))
return err
Expand Down Expand Up @@ -468,7 +466,7 @@ func (d *dispatcher) GetAllSchedulerIDs(ctx context.Context, gTaskID int64) ([]s
return nil, nil
}

schedulerIDs, err := d.subTaskMgr.GetSchedulerIDs(gTaskID)
schedulerIDs, err := d.taskMgr.GetSchedulerIDsByTaskID(gTaskID)
if err != nil {
return nil, err
}
Expand Down
37 changes: 17 additions & 20 deletions disttask/framework/dispatcher/dispatcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,15 @@ import (
"github.com/tikv/client-go/v2/util"
)

func MockDispatcher(t *testing.T) (dispatcher.Dispatch, *storage.GlobalTaskManager, *storage.SubTaskManager, kv.Storage) {
func MockDispatcher(t *testing.T) (dispatcher.Dispatch, *storage.TaskManager, kv.Storage) {
store := testkit.CreateMockStore(t)
gtk := testkit.NewTestKit(t, store)
stk := testkit.NewTestKit(t, store)
ctx := context.Background()
gm := storage.NewGlobalTaskManager(util.WithInternalSourceType(ctx, "globalTaskManager"), gtk.Session())
storage.SetGlobalTaskManager(gm)
sm := storage.NewSubTaskManager(util.WithInternalSourceType(ctx, "subTaskManager"), stk.Session())
storage.SetSubTaskManager(sm)
dsp, err := dispatcher.NewDispatcher(util.WithInternalSourceType(ctx, "dispatcher"), gm, sm)
mgr := storage.NewTaskManager(util.WithInternalSourceType(ctx, "taskManager"), gtk.Session())
storage.SetTaskManager(mgr)
dsp, err := dispatcher.NewDispatcher(util.WithInternalSourceType(ctx, "dispatcher"), mgr)
require.NoError(t, err)
return dsp, gm, sm, store
return dsp, mgr, store
}

func deleteTasks(t *testing.T, store kv.Storage, taskID int64) {
Expand All @@ -55,7 +52,7 @@ func deleteTasks(t *testing.T, store kv.Storage, taskID int64) {

func TestGetInstance(t *testing.T) {
ctx := context.Background()
dsp, _, subTaskMgr, _ := MockDispatcher(t)
dsp, mgr, _ := MockDispatcher(t)

makeFailpointRes := func(v interface{}) string {
bytes, err := json.Marshal(v)
Expand Down Expand Up @@ -103,7 +100,7 @@ func TestGetInstance(t *testing.T) {
TaskID: gTaskID,
SchedulerID: uuids[1],
}
err = subTaskMgr.AddNewTask(gTaskID, subtask.SchedulerID, nil, subtask.Type, true)
err = mgr.AddNewSubTask(gTaskID, subtask.SchedulerID, nil, subtask.Type, true)
require.NoError(t, err)
instanceIDs, err = dsp.GetAllSchedulerIDs(ctx, gTaskID)
require.NoError(t, err)
Expand All @@ -115,7 +112,7 @@ func TestGetInstance(t *testing.T) {
TaskID: gTaskID,
SchedulerID: uuids[0],
}
err = subTaskMgr.AddNewTask(gTaskID, subtask.SchedulerID, nil, subtask.Type, true)
err = mgr.AddNewSubTask(gTaskID, subtask.SchedulerID, nil, subtask.Type, true)
require.NoError(t, err)
instanceIDs, err = dsp.GetAllSchedulerIDs(ctx, gTaskID)
require.NoError(t, err)
Expand All @@ -142,7 +139,7 @@ func checkDispatch(t *testing.T, taskCnt int, isSucc bool) {
dispatcher.DefaultDispatchConcurrency = 1
}

dsp, gTaskMgr, subTaskMgr, store := MockDispatcher(t)
dsp, mgr, store := MockDispatcher(t)
dsp.Start()
defer func() {
dsp.Stop()
Expand Down Expand Up @@ -171,24 +168,24 @@ func checkDispatch(t *testing.T, taskCnt int, isSucc bool) {
// Mock add tasks.
taskIDs := make([]int64, 0, taskCnt)
for i := 0; i < taskCnt; i++ {
taskID, err := gTaskMgr.AddNewTask(fmt.Sprintf("%d", i), taskTypeExample, 0, nil)
taskID, err := mgr.AddNewGlobalTask(fmt.Sprintf("%d", i), taskTypeExample, 0, nil)
require.NoError(t, err)
taskIDs = append(taskIDs, taskID)
}
// test normal flow
checkGetRunningGTaskCnt()
tasks, err := gTaskMgr.GetTasksInStates(proto.TaskStateRunning)
tasks, err := mgr.GetGlobalTasksInStates(proto.TaskStateRunning)
require.NoError(t, err)
require.Len(t, tasks, taskCnt)
for i, taskID := range taskIDs {
require.Equal(t, int64(i+1), tasks[i].ID)
subtasks, err := subTaskMgr.GetSubtaskInStatesCnt(taskID, proto.TaskStatePending)
subtasks, err := mgr.GetSubtaskInStatesCnt(taskID, proto.TaskStatePending)
require.NoError(t, err)
require.Equal(t, int64(subtaskCnt), subtasks, fmt.Sprintf("num:%d", i))
}
// test parallelism control
if taskCnt == 1 {
taskID, err := gTaskMgr.AddNewTask(fmt.Sprintf("%d", taskCnt), taskTypeExample, 0, nil)
taskID, err := mgr.AddNewGlobalTask(fmt.Sprintf("%d", taskCnt), taskTypeExample, 0, nil)
require.NoError(t, err)
checkGetRunningGTaskCnt()
// Clean the task.
Expand All @@ -199,7 +196,7 @@ func checkDispatch(t *testing.T, taskCnt int, isSucc bool) {
// test DetectTaskLoop
checkGetGTaskState := func(expectedState string) {
for i := 0; i < cnt; i++ {
tasks, err = gTaskMgr.GetTasksInStates(expectedState)
tasks, err = mgr.GetGlobalTasksInStates(expectedState)
require.NoError(t, err)
if len(tasks) == taskCnt {
break
Expand All @@ -211,7 +208,7 @@ func checkDispatch(t *testing.T, taskCnt int, isSucc bool) {
if isSucc {
// Mock subtasks succeed.
for i := 1; i <= subtaskCnt*taskCnt; i++ {
err = subTaskMgr.UpdateSubtaskState(int64(i), proto.TaskStateSucceed)
err = mgr.UpdateSubtaskState(int64(i), proto.TaskStateSucceed)
require.NoError(t, err)
}
checkGetGTaskState(proto.TaskStateSucceed)
Expand All @@ -227,15 +224,15 @@ func checkDispatch(t *testing.T, taskCnt int, isSucc bool) {
}()
// Mock a subtask fails.
for i := 1; i <= subtaskCnt*taskCnt; i += subtaskCnt {
err = subTaskMgr.UpdateSubtaskState(int64(i), proto.TaskStateFailed)
err = mgr.UpdateSubtaskState(int64(i), proto.TaskStateFailed)
require.NoError(t, err)
}
checkGetGTaskState(proto.TaskStateReverting)
require.Len(t, tasks, taskCnt)
// Mock all subtask reverted.
start := subtaskCnt * taskCnt
for i := start; i <= start+subtaskCnt*taskCnt; i++ {
err = subTaskMgr.UpdateSubtaskState(int64(i), proto.TaskStateReverted)
err = mgr.UpdateSubtaskState(int64(i), proto.TaskStateReverted)
require.NoError(t, err)
}
checkGetGTaskState(proto.TaskStateReverted)
Expand Down
10 changes: 5 additions & 5 deletions disttask/framework/framework_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,10 @@ func TestFrameworkStartUp(t *testing.T) {
return &testSubtaskExecutor{v: &v}, nil
})

store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)
gm := storage.NewGlobalTaskManager(context.TODO(), tk.Session())
taskID, err := gm.AddNewTask("key1", "type1", 8, nil)
_ = testkit.CreateMockStore(t)
mgr, err := storage.GetTaskManager()
require.NoError(t, err)
taskID, err := mgr.AddNewGlobalTask("key1", "type1", 8, nil)
require.NoError(t, err)
start := time.Now()

Expand All @@ -104,7 +104,7 @@ func TestFrameworkStartUp(t *testing.T) {
}

time.Sleep(time.Second)
task, err = gm.GetTaskByID(taskID)
task, err = mgr.GetGlobalTaskByID(taskID)
require.NoError(t, err)
require.NotNil(t, task)
if task.State != proto.TaskStatePending && task.State != proto.TaskStateRunning {
Expand Down
9 changes: 2 additions & 7 deletions disttask/framework/scheduler/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,11 @@ import (

// TaskTable defines the interface to access task table.
type TaskTable interface {
GetTasksInStates(states ...interface{}) (task []*proto.Task, err error)
GetTaskByID(taskID int64) (task *proto.Task, err error)
}

// SubtaskTable defines the interface to access subtask table.
type SubtaskTable interface {
GetGlobalTasksInStates(states ...interface{}) (task []*proto.Task, err error)
GetGlobalTaskByID(taskID int64) (task *proto.Task, err error)
GetSubtaskInStates(instanceID string, taskID int64, states ...interface{}) (*proto.Subtask, error)
UpdateSubtaskState(id int64, state string) error
HasSubtasksInStates(instanceID string, taskID int64, states ...interface{}) (bool, error)
// UpdateHeartbeat(TiDB string, taskID int64, heartbeat time.Time) error
}

// Pool defines the interface of a pool.
Expand Down
25 changes: 10 additions & 15 deletions disttask/framework/scheduler/interface_mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ type MockTaskTable struct {
mock.Mock
}

// GetTasksInStates implements TaskTable.GetTasksInStates.
func (t *MockTaskTable) GetTasksInStates(states ...interface{}) ([]*proto.Task, error) {
// GetGlobalTasksInStates implements TaskTable.GetTasksInStates.
func (t *MockTaskTable) GetGlobalTasksInStates(states ...interface{}) ([]*proto.Task, error) {
args := t.Called(states...)
if args.Error(1) != nil {
return nil, args.Error(1)
Expand All @@ -39,8 +39,8 @@ func (t *MockTaskTable) GetTasksInStates(states ...interface{}) ([]*proto.Task,
}
}

// GetTaskByID implements TaskTable.GetTaskByID.
func (t *MockTaskTable) GetTaskByID(id int64) (*proto.Task, error) {
// GetGlobalTaskByID implements TaskTable.GetTaskByID.
func (t *MockTaskTable) GetGlobalTaskByID(id int64) (*proto.Task, error) {
args := t.Called(id)
if args.Error(1) != nil {
return nil, args.Error(1)
Expand All @@ -51,14 +51,9 @@ func (t *MockTaskTable) GetTaskByID(id int64) (*proto.Task, error) {
}
}

// MockSubtaskTable is a mock of SubtaskTable.
type MockSubtaskTable struct {
mock.Mock
}

// GetSubtaskInStates implements SubtaskTable.GetSubtaskInStates.
func (m *MockSubtaskTable) GetSubtaskInStates(instanceID string, taskID int64, states ...interface{}) (*proto.Subtask, error) {
args := m.Called(instanceID, taskID, states)
func (t *MockTaskTable) GetSubtaskInStates(instanceID string, taskID int64, states ...interface{}) (*proto.Subtask, error) {
args := t.Called(instanceID, taskID, states)
if args.Error(1) != nil {
return nil, args.Error(1)
} else if args.Get(0) == nil {
Expand All @@ -69,14 +64,14 @@ func (m *MockSubtaskTable) GetSubtaskInStates(instanceID string, taskID int64, s
}

// UpdateSubtaskState implements SubtaskTable.UpdateSubtaskState.
func (m *MockSubtaskTable) UpdateSubtaskState(id int64, state string) error {
args := m.Called(id, state)
func (t *MockTaskTable) UpdateSubtaskState(id int64, state string) error {
args := t.Called(id, state)
return args.Error(0)
}

// HasSubtasksInStates implements SubtaskTable.HasSubtasksInStates.
func (m *MockSubtaskTable) HasSubtasksInStates(instanceID string, taskID int64, states ...interface{}) (bool, error) {
args := m.Called(instanceID, taskID, states)
func (t *MockTaskTable) HasSubtasksInStates(instanceID string, taskID int64, states ...interface{}) (bool, error) {
args := t.Called(instanceID, taskID, states)
return args.Bool(0), args.Error(1)
}

Expand Down
Loading

0 comments on commit a9d7577

Please sign in to comment.