diff --git a/bindinfo/session_handle_test.go b/bindinfo/session_handle_test.go index 1da00a925247c..1ed0ce07e975d 100644 --- a/bindinfo/session_handle_test.go +++ b/bindinfo/session_handle_test.go @@ -16,6 +16,7 @@ package bindinfo_test import ( "context" + "fmt" "strconv" "testing" "time" @@ -415,6 +416,29 @@ func TestLocalTemporaryTable(t *testing.T) { tk.MustGetErrCode("create binding for delete from tmp2 where b = 1 and c > 1 using delete /*+ use_index(t, c) */ from t where b = 1 and c > 1", errno.ErrOptOnTemporaryTable) } +func TestIssue53834(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec(`use test`) + tk.MustExec(`create table t (a varchar(1024))`) + tk.MustExec(`insert into t values (space(1024))`) + for i := 0; i < 12; i++ { + tk.MustExec(`insert into t select * from t`) + } + oomAction := tk.MustQuery(`select @@tidb_mem_oom_action`).Rows()[0][0].(string) + defer func() { + tk.MustExec(fmt.Sprintf(`set global tidb_mem_oom_action='%v'`, oomAction)) + }() + + tk.MustExec(`set global tidb_mem_oom_action='cancel'`) + err := tk.ExecToErr(`replace into t select /*+ memory_quota(1 mb) */ * from t`) + require.ErrorContains(t, err, "cancelled due to exceeding the allowed memory limit") + + tk.MustExec(`create binding for replace into t select * from t using replace into t select /*+ memory_quota(1 mb) */ * from t`) + err = tk.ExecToErr(`replace into t select * from t`) + require.ErrorContains(t, err, "cancelled due to exceeding the allowed memory limit") +} + func TestDropSingleBindings(t *testing.T) { store := testkit.CreateMockStore(t) diff --git a/planner/optimize.go b/planner/optimize.go index 9b78c703ecef8..0c7295ea4e731 100644 --- a/planner/optimize.go +++ b/planner/optimize.go @@ -282,7 +282,7 @@ func Optimize(ctx context.Context, sctx sessionctx.Context, node ast.Node, is in } metrics.BindUsageCounter.WithLabelValues(scope).Inc() hint.BindHint(stmtNode, binding.Hint) - curStmtHints, _, curWarns := handleStmtHints(binding.Hint.GetFirstTableHints()) + curStmtHints, _, curWarns := handleStmtHints(binding.Hint.GetStmtHints()) sessVars.StmtCtx.StmtHints = curStmtHints // update session var by hint /set_var/ for name, val := range sessVars.StmtCtx.StmtHints.SetVars { diff --git a/session/sessiontest/session_test.go b/session/sessiontest/session_test.go index b66fe9fe198ad..743393b8b5f83 100644 --- a/session/sessiontest/session_test.go +++ b/session/sessiontest/session_test.go @@ -1636,11 +1636,11 @@ func TestStmtHints(t *testing.T) { val = int64(1) * 1024 * 1024 require.True(t, tk.Session().GetSessionVars().MemTracker.CheckBytesLimit(val)) - tk.MustExec("insert /*+ MEMORY_QUOTA(1 MB) */ into t1 select /*+ MEMORY_QUOTA(3 MB) */ * from t1;") + tk.MustExec("insert /*+ MEMORY_QUOTA(1 MB) */ into t1 select /*+ MEMORY_QUOTA(1 MB) */ * from t1;") val = int64(1) * 1024 * 1024 require.True(t, tk.Session().GetSessionVars().MemTracker.CheckBytesLimit(val)) - require.Len(t, tk.Session().GetSessionVars().StmtCtx.GetWarnings(), 1) - require.EqualError(t, tk.Session().GetSessionVars().StmtCtx.GetWarnings()[0].Err, "[util:3126]Hint MEMORY_QUOTA(`3145728`) is ignored as conflicting/duplicated.") + require.Len(t, tk.Session().GetSessionVars().StmtCtx.GetWarnings(), 2) + require.EqualError(t, tk.Session().GetSessionVars().StmtCtx.GetWarnings()[0].Err, "[util:3126]Hint MEMORY_QUOTA(`1048576`) is ignored as conflicting/duplicated.") // Test NO_INDEX_MERGE hint tk.Session().GetSessionVars().SetEnableIndexMerge(true) diff --git a/util/hint/hint_processor.go b/util/hint/hint_processor.go index 566d0ec7bc076..b838908a45f3c 100644 --- a/util/hint/hint_processor.go +++ b/util/hint/hint_processor.go @@ -45,12 +45,20 @@ type HintsSet struct { indexHints [][]*ast.IndexHint // Slice offset is the traversal order of `TableName` in the ast. } -// GetFirstTableHints gets the first table hints. -func (hs *HintsSet) GetFirstTableHints() []*ast.TableOptimizerHint { +// GetStmtHints gets all statement-level hints. +func (hs *HintsSet) GetStmtHints() []*ast.TableOptimizerHint { + var result []*ast.TableOptimizerHint if len(hs.tableHints) > 0 { - return hs.tableHints[0] + result = append(result, hs.tableHints[0]...) // keep the same behavior with prior implementation } - return nil + for _, tHints := range hs.tableHints[1:] { + for _, h := range tHints { + if isStmtHint(h) { + result = append(result, h) + } + } + } + return result } // ContainTableHint checks whether the table hint set contains a hint. @@ -89,7 +97,18 @@ func ExtractTableHintsFromStmtNode(node ast.Node, sctx sessionctx.Context) []*as case *ast.InsertStmt: // check duplicated hints checkInsertStmtHintDuplicated(node, sctx) - return x.TableHints + result := make([]*ast.TableOptimizerHint, 0, len(x.TableHints)) + result = append(result, x.TableHints...) + if x.Select != nil { + // support statement-level hint in sub-select: "insert into t select /* ... */ ..." + // TODO: support this for Update and Delete as well + for _, h := range ExtractTableHintsFromStmtNode(x.Select, sctx) { + if isStmtHint(h) { + result = append(result, h) + } + } + } + return result case *ast.ExplainStmt: return ExtractTableHintsFromStmtNode(x.Stmt, sctx) case *ast.SetOprStmt: @@ -688,3 +707,13 @@ func (checker *bindableChecker) Enter(in ast.Node) (out ast.Node, skipChildren b func (checker *bindableChecker) Leave(in ast.Node) (out ast.Node, ok bool) { return in, checker.bindable } + +// isStmtHint checks whether this hint is a statement-level hint. +func isStmtHint(h *ast.TableOptimizerHint) bool { + switch h.HintName.L { + case "max_execution_time", "memory_quota", "resource_group": + return true + default: + return false + } +}