diff --git a/disttask/framework/dispatcher/dispatcher.go b/disttask/framework/dispatcher/dispatcher.go
index 71970323f9490..cd5086aa5b53c 100644
--- a/disttask/framework/dispatcher/dispatcher.go
+++ b/disttask/framework/dispatcher/dispatcher.go
@@ -16,7 +16,7 @@ package dispatcher
 
 import (
 	"context"
-	"math/rand"
+	"fmt"
 	"time"
 
 	"github.com/pingcap/errors"
@@ -401,9 +401,15 @@ func (d *dispatcher) processNormalFlow(gTask *proto.Task) (err error) {
 		return nil
 	}
 
+	// Generate all available TiDB nodes once, so subtasks can be spread over them round-robin.
+	serverNodes, err1 := GenerateSchedulerNodes(d.ctx)
+	if err1 != nil {
+		return err1
+	}
 	subTasks := make([]*proto.Subtask, 0, len(metas))
-	for _, meta := range metas {
-		instanceID, err := GetEligibleInstance(d.ctx)
+	for i, meta := range metas {
+		pos := i % len(serverNodes)
+		instanceID, err := GetEligibleInstance(serverNodes, pos)
 		if err != nil {
 			logutil.BgLogger().Warn("get a eligible instance failed", zap.Int64("gTask ID", gTask.ID), zap.Error(err))
 			return err
@@ -414,24 +420,29 @@ func (d *dispatcher) processNormalFlow(gTask *proto.Task) (err error) {
 }
 
 // GetEligibleInstance gets an eligible instance.
-func GetEligibleInstance(ctx context.Context) (string, error) {
+func GetEligibleInstance(severNodes []*infosync.ServerInfo, pos int) (string, error) {
+	if pos < 0 || pos >= len(severNodes) {
+		errMsg := fmt.Sprintf("available TiDB nodes range is 0 to %d, but request position: %d", len(severNodes)-1, pos)
+		return "", errors.New(errMsg)
+	}
+	return severNodes[pos].ID, nil
+}
+
+// GenerateSchedulerNodes returns every TiDB node reported by infosync, or an error if there is none.
+func GenerateSchedulerNodes(ctx context.Context) ([]*infosync.ServerInfo, error) {
 	serverInfos, err := infosync.GetAllServerInfo(ctx)
 	if err != nil {
-		return "", err
+		return nil, err
 	}
 	if len(serverInfos) == 0 {
-		return "", errors.New("not found instance")
+		return nil, errors.New("not found instance")
 	}
 
-	// TODO: Consider valid instances, and then consider scheduling strategies.
-	num := rand.Intn(len(serverInfos))
-	for _, info := range serverInfos {
-		if num == 0 {
-			return info.ID, nil
-		}
-		num--
+	serverNodes := make([]*infosync.ServerInfo, 0, len(serverInfos))
+	for _, serverInfo := range serverInfos {
+		serverNodes = append(serverNodes, serverInfo)
 	}
-	return "", errors.New("not found instance")
+	return serverNodes, nil
 }
 
 // GetAllSchedulerIDs gets all the scheduler IDs.
diff --git a/disttask/framework/dispatcher/dispatcher_test.go b/disttask/framework/dispatcher/dispatcher_test.go
index ff78691a40af5..6e6e6253cf327 100644
--- a/disttask/framework/dispatcher/dispatcher_test.go
+++ b/disttask/framework/dispatcher/dispatcher_test.go
@@ -69,7 +69,8 @@ func TestGetInstance(t *testing.T) {
 	// test no server
 	mockedAllServerInfos := map[string]*infosync.ServerInfo{}
 	require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/domain/infosync/mockGetAllServerInfo", makeFailpointRes(mockedAllServerInfos)))
-	instanceID, err := dispatcher.GetEligibleInstance(ctx)
+	serverNodes, err := dispatcher.GenerateSchedulerNodes(ctx)
+	instanceID, _ := dispatcher.GetEligibleInstance(serverNodes, 0)
 	require.Lenf(t, instanceID, 0, "instanceID:%d", instanceID)
 	require.EqualError(t, err, "not found instance")
 	instanceIDs, err := dsp.GetAllSchedulerIDs(ctx, 1)
@@ -89,7 +90,9 @@ func TestGetInstance(t *testing.T) {
 		},
 	}
 	require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/domain/infosync/mockGetAllServerInfo", makeFailpointRes(mockedAllServerInfos)))
-	instanceID, err = dispatcher.GetEligibleInstance(ctx)
+	serverNodes, err = dispatcher.GenerateSchedulerNodes(ctx)
+	require.NoError(t, err)
+	instanceID, err = dispatcher.GetEligibleInstance(serverNodes, 0)
 	require.NoError(t, err)
 	if instanceID != uuids[0] && instanceID != uuids[1] {
 		require.FailNowf(t, "expected uuids:%d,%d, actual uuid:%d", uuids[0], uuids[1], instanceID)