diff --git a/distsql/distsql.go b/distsql/distsql.go index cc5d33a117037..33b075bb2f34b 100644 --- a/distsql/distsql.go +++ b/distsql/distsql.go @@ -34,7 +34,8 @@ import ( // DispatchMPPTasks dispatches all tasks and returns an iterator. func DispatchMPPTasks(ctx context.Context, sctx sessionctx.Context, tasks []*kv.MPPDispatchRequest, fieldTypes []*types.FieldType, planIDs []int, rootID int) (SelectResult, error) { - resp := sctx.GetMPPClient().DispatchMPPTasks(ctx, sctx.GetSessionVars().KVVars, tasks) + _, allowTiFlashFallback := sctx.GetSessionVars().AllowFallbackToTiKV[kv.TiFlash] + resp := sctx.GetMPPClient().DispatchMPPTasks(ctx, sctx.GetSessionVars().KVVars, tasks, allowTiFlashFallback) if resp == nil { err := errors.New("client returns nil response") return nil, err diff --git a/kv/mpp.go b/kv/mpp.go index 4e3f70532c9d2..8d2754eb4f239 100644 --- a/kv/mpp.go +++ b/kv/mpp.go @@ -81,7 +81,7 @@ type MPPClient interface { ConstructMPPTasks(context.Context, *MPPBuildTasksRequest, map[string]time.Time, time.Duration) ([]MPPTaskMeta, error) // DispatchMPPTasks dispatches ALL mpp requests at once, and returns an iterator that transfers the data. - DispatchMPPTasks(ctx context.Context, vars interface{}, reqs []*MPPDispatchRequest) Response + DispatchMPPTasks(ctx context.Context, vars interface{}, reqs []*MPPDispatchRequest, needTriggerFallback bool) Response } // MPPBuildTasksRequest request the stores allocation for a mpp plan fragment. diff --git a/server/conn_test.go b/server/conn_test.go index 8c22567b6d229..65ae656773a56 100644 --- a/server/conn_test.go +++ b/server/conn_test.go @@ -39,6 +39,7 @@ import ( "github.com/pingcap/tidb/util/arena" "github.com/pingcap/tidb/util/chunk" "github.com/stretchr/testify/require" + tikverr "github.com/tikv/client-go/v2/error" "github.com/tikv/client-go/v2/testutils" ) @@ -858,6 +859,16 @@ func TestTiFlashFallback(t *testing.T) { require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/store/mockstore/unistore/establishMppConnectionErr", "return(true)")) testFallbackWork(t, tk, cc, "select * from t t1 join t t2 on t1.a = t2.a") require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/store/mockstore/unistore/establishMppConnectionErr")) + + // When fallback is not set, TiFlash mpp will return the original error message + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/store/mockstore/unistore/mppDispatchTimeout", "return(true)")) + tk.MustExec("set @@tidb_allow_fallback_to_tikv=''") + tk.MustExec("set @@tidb_allow_mpp=ON") + tk.MustExec("set @@tidb_enforce_mpp=ON") + tk.MustExec("set @@tidb_isolation_read_engines='tiflash,tidb'") + err = cc.handleQuery(ctx, "select count(*) from t") + require.Error(t, err) + require.NotEqual(t, err.Error(), tikverr.ErrTiFlashServerTimeout.Error()) } func testFallbackWork(t *testing.T, tk *testkit.TestKit, cc *clientConn, sql string) { diff --git a/store/copr/mpp.go b/store/copr/mpp.go index db637f0438b24..f8970f57cccc4 100644 --- a/store/copr/mpp.go +++ b/store/copr/mpp.go @@ -141,6 +141,8 @@ type mppIterator struct { vars *tikv.Variables + needTriggerFallback bool + mu sync.Mutex } @@ -236,8 +238,12 @@ func (m *mppIterator) handleDispatchReq(ctx context.Context, bo *Backoffer, req // That's a hard job but we can try it in the future. if sender.GetRPCError() != nil { logutil.BgLogger().Warn("mpp dispatch meet io error", zap.String("error", sender.GetRPCError().Error()), zap.Uint64("timestamp", taskMeta.StartTs), zap.Int64("task", taskMeta.TaskId)) - // we return timeout to trigger tikv's fallback - err = derr.ErrTiFlashServerTimeout + // if needTriggerFallback is true, we return timeout to trigger tikv's fallback + if m.needTriggerFallback { + err = derr.ErrTiFlashServerTimeout + } else { + err = sender.GetRPCError() + } } } else { rpcResp, err = m.store.GetTiKVClient().SendRequest(ctx, req.Meta.GetAddress(), wrappedReq, tikv.ReadTimeoutMedium) @@ -258,8 +264,11 @@ func (m *mppIterator) handleDispatchReq(ctx context.Context, bo *Backoffer, req if err != nil { logutil.BgLogger().Error("mpp dispatch meet error", zap.String("error", err.Error()), zap.Uint64("timestamp", taskMeta.StartTs), zap.Int64("task", taskMeta.TaskId)) - // we return timeout to trigger tikv's fallback - m.sendError(derr.ErrTiFlashServerTimeout) + // if needTriggerFallback is true, we return timeout to trigger tikv's fallback + if m.needTriggerFallback { + err = derr.ErrTiFlashServerTimeout + } + m.sendError(err) return } @@ -345,8 +354,12 @@ func (m *mppIterator) establishMPPConns(bo *Backoffer, req *kv.MPPDispatchReques if err != nil { logutil.BgLogger().Warn("establish mpp connection meet error and cannot retry", zap.String("error", err.Error()), zap.Uint64("timestamp", taskMeta.StartTs), zap.Int64("task", taskMeta.TaskId)) - // we return timeout to trigger tikv's fallback - m.sendError(derr.ErrTiFlashServerTimeout) + // if needTriggerFallback is true, we return timeout to trigger tikv's fallback + if m.needTriggerFallback { + m.sendError(derr.ErrTiFlashServerTimeout) + } else { + m.sendError(err) + } return } @@ -378,7 +391,12 @@ func (m *mppIterator) establishMPPConns(bo *Backoffer, req *kv.MPPDispatchReques logutil.BgLogger().Info("stream unknown error", zap.Error(err), zap.Uint64("timestamp", taskMeta.StartTs), zap.Int64("task", taskMeta.TaskId)) } } - m.sendError(derr.ErrTiFlashServerTimeout) + // if needTriggerFallback is true, we return timeout to trigger tikv's fallback + if m.needTriggerFallback { + m.sendError(derr.ErrTiFlashServerTimeout) + } else { + m.sendError(err) + } return } } @@ -470,17 +488,18 @@ func (m *mppIterator) Next(ctx context.Context) (kv.ResultSubset, error) { } // DispatchMPPTasks dispatches all the mpp task and waits for the responses. -func (c *MPPClient) DispatchMPPTasks(ctx context.Context, variables interface{}, dispatchReqs []*kv.MPPDispatchRequest) kv.Response { +func (c *MPPClient) DispatchMPPTasks(ctx context.Context, variables interface{}, dispatchReqs []*kv.MPPDispatchRequest, needTriggerFallback bool) kv.Response { vars := variables.(*tikv.Variables) ctxChild, cancelFunc := context.WithCancel(ctx) iter := &mppIterator{ - store: c.store, - tasks: dispatchReqs, - finishCh: make(chan struct{}), - cancelFunc: cancelFunc, - respChan: make(chan *mppResponse, 4096), - startTs: dispatchReqs[0].StartTs, - vars: vars, + store: c.store, + tasks: dispatchReqs, + finishCh: make(chan struct{}), + cancelFunc: cancelFunc, + respChan: make(chan *mppResponse, 4096), + startTs: dispatchReqs[0].StartTs, + vars: vars, + needTriggerFallback: needTriggerFallback, } go iter.run(ctxChild) return iter