Skip to content

Commit

Permalink
disttask: refactor dispatcher so we have a chance to add init logics per task type (#46603)
Browse files Browse the repository at this point in the history

ref #46258
  • Loading branch information
D3Hunter committed Sep 5, 2023
1 parent db570ea commit 5cbcc86
Show file tree
Hide file tree
Showing 15 changed files with 280 additions and 201 deletions.
1 change: 1 addition & 0 deletions ddl/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ go_library(
"//disttask/framework/handle",
"//disttask/framework/proto",
"//disttask/framework/scheduler",
"//disttask/framework/storage",
"//disttask/operator",
"//domain/infosync",
"//domain/resourcegroup",
Expand Down
34 changes: 24 additions & 10 deletions ddl/backfilling_dispatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"github.com/pingcap/errors"
"github.com/pingcap/tidb/disttask/framework/dispatcher"
"github.com/pingcap/tidb/disttask/framework/proto"
"github.com/pingcap/tidb/disttask/framework/storage"
"github.com/pingcap/tidb/domain/infosync"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/meta"
Expand All @@ -32,28 +33,28 @@ import (
"github.com/tikv/client-go/v2/tikv"
)

type backfillingDispatcher struct {
type backfillingDispatcherExt struct {
d *ddl
}

var _ dispatcher.Dispatcher = (*backfillingDispatcher)(nil)
var _ dispatcher.Extension = (*backfillingDispatcherExt)(nil)

// NewBackfillingDispatcher creates a new backfillingDispatcher.
func NewBackfillingDispatcher(d DDL) (dispatcher.Dispatcher, error) {
// NewBackfillingDispatcherExt creates a new backfillingDispatcherExt.
func NewBackfillingDispatcherExt(d DDL) (dispatcher.Extension, error) {
ddl, ok := d.(*ddl)
if !ok {
return nil, errors.New("The getDDL result should be the type of *ddl")
}
return &backfillingDispatcher{
return &backfillingDispatcherExt{
d: ddl,
}, nil
}

func (*backfillingDispatcher) OnTick(_ context.Context, _ *proto.Task) {
func (*backfillingDispatcherExt) OnTick(_ context.Context, _ *proto.Task) {
}

// OnNextStage generate next stage's plan.
func (h *backfillingDispatcher) OnNextStage(ctx context.Context, _ dispatcher.TaskHandle, gTask *proto.Task) ([][]byte, error) {
func (h *backfillingDispatcherExt) OnNextStage(ctx context.Context, _ dispatcher.TaskHandle, gTask *proto.Task) ([][]byte, error) {
var globalTaskMeta BackfillGlobalMeta
if err := json.Unmarshal(gTask.Meta, &globalTaskMeta); err != nil {
return nil, err
Expand Down Expand Up @@ -112,23 +113,36 @@ func (h *backfillingDispatcher) OnNextStage(ctx context.Context, _ dispatcher.Ta
}

// OnErrStage generate error handling stage's plan.
func (*backfillingDispatcher) OnErrStage(_ context.Context, _ dispatcher.TaskHandle, task *proto.Task, receiveErr []error) (meta []byte, err error) {
func (*backfillingDispatcherExt) OnErrStage(_ context.Context, _ dispatcher.TaskHandle, task *proto.Task, receiveErr []error) (meta []byte, err error) {
// We do not need extra meta info when rolling back
firstErr := receiveErr[0]
task.Error = firstErr

return nil, nil
}

func (*backfillingDispatcher) GetEligibleInstances(ctx context.Context, _ *proto.Task) ([]*infosync.ServerInfo, error) {
func (*backfillingDispatcherExt) GetEligibleInstances(ctx context.Context, _ *proto.Task) ([]*infosync.ServerInfo, error) {
return dispatcher.GenerateSchedulerNodes(ctx)
}

// IsRetryableErr implements TaskFlowHandle.IsRetryableErr interface.
func (*backfillingDispatcher) IsRetryableErr(error) bool {
func (*backfillingDispatcherExt) IsRetryableErr(error) bool {
return true
}

// litBackfillDispatcher is the dispatcher for the backfill task type.
// It embeds *dispatcher.BaseDispatcher to reuse the generic task-scheduling
// loop; task-type-specific behavior is supplied via the Extension field set
// in newLitBackfillDispatcher.
type litBackfillDispatcher struct {
*dispatcher.BaseDispatcher
}

// newLitBackfillDispatcher builds a Dispatcher for the backfill task type by
// wrapping a BaseDispatcher and plugging in the task-type-specific Extension.
// The Extension MUST be set before the dispatcher runs (see
// RegisterDispatcherFactory), since BaseDispatcher delegates OnTick,
// OnNextStage, etc. to it.
func newLitBackfillDispatcher(ctx context.Context, taskMgr *storage.TaskManager,
	serverID string, task *proto.Task, handle dispatcher.Extension) dispatcher.Dispatcher {
	base := dispatcher.NewBaseDispatcher(ctx, taskMgr, serverID, task)
	base.Extension = handle
	return &litBackfillDispatcher{BaseDispatcher: base}
}

func getTblInfo(d *ddl, job *model.Job) (tblInfo *model.TableInfo, err error) {
err = kv.RunInNewTxn(d.ctx, d.store, true, func(ctx context.Context, txn kv.Transaction) error {
tblInfo, err = meta.NewMeta(txn).GetTable(job.SchemaID, job.TableID)
Expand Down
2 changes: 1 addition & 1 deletion ddl/backfilling_dispatcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import (

func TestBackfillingDispatcher(t *testing.T) {
store, dom := testkit.CreateMockStoreAndDomain(t)
dsp, err := ddl.NewBackfillingDispatcher(dom.DDL())
dsp, err := ddl.NewBackfillingDispatcherExt(dom.DDL())
require.NoError(t, err)
tk := testkit.NewTestKit(t, store)
tk.MustExec("use test")
Expand Down
10 changes: 7 additions & 3 deletions ddl/ddl.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ import (
"github.com/pingcap/tidb/disttask/framework/dispatcher"
"github.com/pingcap/tidb/disttask/framework/proto"
"github.com/pingcap/tidb/disttask/framework/scheduler"
"github.com/pingcap/tidb/disttask/framework/storage"
"github.com/pingcap/tidb/domain/infosync"
"github.com/pingcap/tidb/infoschema"
"github.com/pingcap/tidb/kv"
Expand Down Expand Up @@ -690,11 +691,14 @@ func newDDL(ctx context.Context, options ...Option) *ddl {
return NewBackfillSchedulerHandle(ctx, taskMeta, d, step == proto.StepTwo)
})

backFillDsp, err := NewBackfillingDispatcher(d)
backFillDsp, err := NewBackfillingDispatcherExt(d)
if err != nil {
logutil.BgLogger().Warn("NewBackfillingDispatcher failed", zap.String("category", "ddl"), zap.Error(err))
logutil.BgLogger().Warn("NewBackfillingDispatcherExt failed", zap.String("category", "ddl"), zap.Error(err))
} else {
dispatcher.RegisterTaskDispatcher(BackfillTaskType, backFillDsp)
dispatcher.RegisterDispatcherFactory(BackfillTaskType,
func(ctx context.Context, taskMgr *storage.TaskManager, serverID string, task *proto.Task) dispatcher.Dispatcher {
return newLitBackfillDispatcher(ctx, taskMgr, serverID, task, backFillDsp)
})
scheduler.RegisterSubtaskExectorConstructor(BackfillTaskType, proto.StepOne,
func(proto.MinimalTask, int64) (scheduler.SubtaskExecutor, error) {
return &scheduler.EmptyExecutor{}, nil
Expand Down
1 change: 0 additions & 1 deletion disttask/framework/dispatcher/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ go_library(
"//util/syncutil",
"@com_github_pingcap_errors//:errors",
"@com_github_pingcap_failpoint//:failpoint",
"@org_golang_x_exp//maps",
"@org_uber_go_zap//:zap",
],
)
Expand Down
79 changes: 39 additions & 40 deletions disttask/framework/dispatcher/dispatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,22 @@ type TaskHandle interface {
storage.SessionExecutor
}

// Manage the lifetime of a task.
// Dispatcher manages the lifetime of a task
// including submitting subtasks and updating the status of a task.
type dispatcher struct {
type Dispatcher interface {
ExecuteTask()
}

// BaseDispatcher is the base struct for Dispatcher.
// each task type embed this struct and implement the Extension interface.
type BaseDispatcher struct {
ctx context.Context
taskMgr *storage.TaskManager
task *proto.Task
logCtx context.Context
serverID string
impl Dispatcher
// when RegisterDispatcherFactory, the factory MUST initialize this field.
Extension

// for HA
// liveNodes will fetch and store all live nodes every liveNodeInterval ticks.
Expand All @@ -85,41 +92,33 @@ type dispatcher struct {
// MockOwnerChange mock owner change in tests.
var MockOwnerChange func()

func newDispatcher(ctx context.Context, taskMgr *storage.TaskManager, serverID string, task *proto.Task) (*dispatcher, error) {
// NewBaseDispatcher creates a new BaseDispatcher.
func NewBaseDispatcher(ctx context.Context, taskMgr *storage.TaskManager, serverID string, task *proto.Task) *BaseDispatcher {
logPrefix := fmt.Sprintf("task_id: %d, task_type: %s, server_id: %s", task.ID, task.Type, serverID)
impl := GetTaskDispatcher(task.Type)
dsp := &dispatcher{
return &BaseDispatcher{
ctx: ctx,
taskMgr: taskMgr,
task: task,
logCtx: logutil.WithKeyValue(context.Background(), "dispatcher", logPrefix),
serverID: serverID,
impl: impl,
liveNodes: nil,
liveNodeFetchInterval: DefaultLiveNodesCheckInterval,
liveNodeFetchTick: 0,
taskNodes: nil,
rand: rand.New(rand.NewSource(time.Now().UnixNano())),
}
if dsp.impl == nil {
logutil.BgLogger().Warn("gen dispatcher impl failed, this type impl doesn't register")
dsp.task.Error = errors.New("unsupported task type")
// state transform: pending -> failed.
return nil, dsp.updateTask(proto.TaskStateFailed, nil, retrySQLTimes)
}
return dsp, nil
}

// ExecuteTask start to schedule a task.
func (d *dispatcher) executeTask() {
func (d *BaseDispatcher) ExecuteTask() {
logutil.Logger(d.logCtx).Info("execute one task",
zap.String("state", d.task.State), zap.Uint64("concurrency", d.task.Concurrency))
d.scheduleTask()
// TODO: manage history task table.
}

// refreshTask fetch task state from tidb_global_task table.
func (d *dispatcher) refreshTask() (err error) {
func (d *BaseDispatcher) refreshTask() (err error) {
d.task, err = d.taskMgr.GetGlobalTaskByID(d.task.ID)
if err != nil {
logutil.Logger(d.logCtx).Error("refresh task failed", zap.Error(err))
Expand All @@ -128,7 +127,7 @@ func (d *dispatcher) refreshTask() (err error) {
}

// scheduleTask schedule the task execution step by step.
func (d *dispatcher) scheduleTask() {
func (d *BaseDispatcher) scheduleTask() {
ticker := time.NewTicker(checkTaskFinishedInterval)
defer ticker.Stop()
for {
Expand Down Expand Up @@ -178,14 +177,14 @@ func (d *dispatcher) scheduleTask() {
}

// handle task in cancelling state, dispatch revert subtasks.
func (d *dispatcher) onCancelling() error {
func (d *BaseDispatcher) onCancelling() error {
logutil.Logger(d.logCtx).Debug("on cancelling state", zap.String("state", d.task.State), zap.Int64("stage", d.task.Step))
errs := []error{errors.New("cancel")}
return d.onErrHandlingStage(errs)
}

// handle task in reverting state, check all revert subtasks finished.
func (d *dispatcher) onReverting() error {
func (d *BaseDispatcher) onReverting() error {
logutil.Logger(d.logCtx).Debug("on reverting state", zap.String("state", d.task.State), zap.Int64("stage", d.task.Step))
cnt, err := d.taskMgr.GetSubtaskInStatesCnt(d.task.ID, proto.TaskStateRevertPending, proto.TaskStateReverting)
if err != nil {
Expand All @@ -199,20 +198,20 @@ func (d *dispatcher) onReverting() error {
return d.updateTask(proto.TaskStateReverted, nil, retrySQLTimes)
}
// Wait all subtasks in this stage finished.
d.impl.OnTick(d.ctx, d.task)
d.OnTick(d.ctx, d.task)
logutil.Logger(d.logCtx).Debug("on reverting state, this task keeps current state", zap.String("state", d.task.State))
return nil
}

// handle task in pending state, dispatch subtasks.
func (d *dispatcher) onPending() error {
func (d *BaseDispatcher) onPending() error {
logutil.Logger(d.logCtx).Debug("on pending state", zap.String("state", d.task.State), zap.Int64("stage", d.task.Step))
return d.onNextStage()
}

// handle task in running state, check all running subtasks finished.
// If subtasks finished, run into the next stage.
func (d *dispatcher) onRunning() error {
func (d *BaseDispatcher) onRunning() error {
logutil.Logger(d.logCtx).Debug("on running state", zap.String("state", d.task.State), zap.Int64("stage", d.task.Step))
subTaskErrs, err := d.taskMgr.CollectSubTaskError(d.task.ID)
if err != nil {
Expand Down Expand Up @@ -240,12 +239,12 @@ func (d *dispatcher) onRunning() error {
return err
}
// Wait all subtasks in this stage finished.
d.impl.OnTick(d.ctx, d.task)
d.OnTick(d.ctx, d.task)
logutil.Logger(d.logCtx).Debug("on running state, this task keeps current state", zap.String("state", d.task.State))
return nil
}

func (d *dispatcher) replaceDeadNodesIfAny() error {
func (d *BaseDispatcher) replaceDeadNodesIfAny() error {
if len(d.taskNodes) == 0 {
return errors.Errorf("len(d.taskNodes) == 0, onNextStage is not invoked before onRunning")
}
Expand All @@ -256,7 +255,7 @@ func (d *dispatcher) replaceDeadNodesIfAny() error {
if err != nil {
return err
}
eligibleServerInfos, err := d.impl.GetEligibleInstances(d.ctx, d.task)
eligibleServerInfos, err := d.GetEligibleInstances(d.ctx, d.task)
if err != nil {
return err
}
Expand Down Expand Up @@ -299,7 +298,7 @@ func (d *dispatcher) replaceDeadNodesIfAny() error {
return nil
}

func (d *dispatcher) updateTask(taskState string, newSubTasks []*proto.Subtask, retryTimes int) (err error) {
func (d *BaseDispatcher) updateTask(taskState string, newSubTasks []*proto.Subtask, retryTimes int) (err error) {
prevState := d.task.State
d.task.State = taskState
if !VerifyTaskStateTransform(prevState, taskState) {
Expand Down Expand Up @@ -331,9 +330,9 @@ func (d *dispatcher) updateTask(taskState string, newSubTasks []*proto.Subtask,
return err
}

func (d *dispatcher) onErrHandlingStage(receiveErr []error) error {
func (d *BaseDispatcher) onErrHandlingStage(receiveErr []error) error {
// 1. generate the needed task meta and subTask meta (dist-plan).
meta, err := d.impl.OnErrStage(d.ctx, d, d.task, receiveErr)
meta, err := d.OnErrStage(d.ctx, d, d.task, receiveErr)
if err != nil {
// OnErrStage must be retryable, if not, there will have resource leak for tasks.
logutil.Logger(d.logCtx).Warn("handle error failed", zap.Error(err))
Expand All @@ -344,7 +343,7 @@ func (d *dispatcher) onErrHandlingStage(receiveErr []error) error {
return d.dispatchSubTask4Revert(d.task, meta)
}

func (d *dispatcher) dispatchSubTask4Revert(task *proto.Task, meta []byte) error {
func (d *BaseDispatcher) dispatchSubTask4Revert(task *proto.Task, meta []byte) error {
instanceIDs, err := d.GetAllSchedulerIDs(d.ctx, task)
if err != nil {
logutil.Logger(d.logCtx).Warn("get task's all instances failed", zap.Error(err))
Expand All @@ -358,17 +357,17 @@ func (d *dispatcher) dispatchSubTask4Revert(task *proto.Task, meta []byte) error
return d.updateTask(proto.TaskStateReverting, subTasks, retrySQLTimes)
}

func (d *dispatcher) onNextStage() error {
func (d *BaseDispatcher) onNextStage() error {
// 1. generate the needed global task meta and subTask meta (dist-plan).
metas, err := d.impl.OnNextStage(d.ctx, d, d.task)
metas, err := d.OnNextStage(d.ctx, d, d.task)
if err != nil {
return d.handlePlanErr(err)
}
// 2. dispatch dist-plan to EligibleInstances.
return d.dispatchSubTask(d.task, metas)
}

func (d *dispatcher) dispatchSubTask(task *proto.Task, metas [][]byte) error {
func (d *BaseDispatcher) dispatchSubTask(task *proto.Task, metas [][]byte) error {
logutil.Logger(d.logCtx).Info("dispatch subtasks", zap.String("state", d.task.State), zap.Uint64("concurrency", d.task.Concurrency), zap.Int("subtasks", len(metas)))
// 1. Adjust the global task's concurrency.
if task.Concurrency == 0 {
Expand Down Expand Up @@ -400,7 +399,7 @@ func (d *dispatcher) dispatchSubTask(task *proto.Task, metas [][]byte) error {
}

// 3. select all available TiDB nodes for task.
serverNodes, err := d.impl.GetEligibleInstances(d.ctx, task)
serverNodes, err := d.GetEligibleInstances(d.ctx, task)
logutil.Logger(d.logCtx).Debug("eligible instances", zap.Int("num", len(serverNodes)))

if err != nil {
Expand All @@ -424,9 +423,9 @@ func (d *dispatcher) dispatchSubTask(task *proto.Task, metas [][]byte) error {
return d.updateTask(proto.TaskStateRunning, subTasks, retrySQLTimes)
}

func (d *dispatcher) handlePlanErr(err error) error {
func (d *BaseDispatcher) handlePlanErr(err error) error {
logutil.Logger(d.logCtx).Warn("generate plan failed", zap.Error(err), zap.String("state", d.task.State))
if d.impl.IsRetryableErr(err) {
if d.IsRetryableErr(err) {
return err
}
d.task.Error = err
Expand Down Expand Up @@ -458,8 +457,8 @@ func GenerateSchedulerNodes(ctx context.Context) (serverNodes []*infosync.Server
}

// GetAllSchedulerIDs gets all the scheduler IDs.
func (d *dispatcher) GetAllSchedulerIDs(ctx context.Context, task *proto.Task) ([]string, error) {
serverInfos, err := d.impl.GetEligibleInstances(ctx, task)
func (d *BaseDispatcher) GetAllSchedulerIDs(ctx context.Context, task *proto.Task) ([]string, error) {
serverInfos, err := d.GetEligibleInstances(ctx, task)
if err != nil {
return nil, err
}
Expand All @@ -481,7 +480,7 @@ func (d *dispatcher) GetAllSchedulerIDs(ctx context.Context, task *proto.Task) (
}

// GetPreviousSubtaskMetas get subtask metas from specific step.
func (d *dispatcher) GetPreviousSubtaskMetas(taskID int64, step int64) ([][]byte, error) {
func (d *BaseDispatcher) GetPreviousSubtaskMetas(taskID int64, step int64) ([][]byte, error) {
previousSubtasks, err := d.taskMgr.GetSucceedSubtasksByStep(taskID, step)
if err != nil {
logutil.Logger(d.logCtx).Warn("get previous succeed subtask failed", zap.Int64("step", step))
Expand All @@ -495,12 +494,12 @@ func (d *dispatcher) GetPreviousSubtaskMetas(taskID int64, step int64) ([][]byte
}

// WithNewSession executes the function with a new session.
func (d *dispatcher) WithNewSession(fn func(se sessionctx.Context) error) error {
func (d *BaseDispatcher) WithNewSession(fn func(se sessionctx.Context) error) error {
return d.taskMgr.WithNewSession(fn)
}

// WithNewTxn executes the fn in a new transaction.
func (d *dispatcher) WithNewTxn(ctx context.Context, fn func(se sessionctx.Context) error) error {
func (d *BaseDispatcher) WithNewTxn(ctx context.Context, fn func(se sessionctx.Context) error) error {
return d.taskMgr.WithNewTxn(ctx, fn)
}

Expand Down
Loading

0 comments on commit 5cbcc86

Please sign in to comment.