From ad931db7882efdf22ce3d631aaef269abdc5ec5a Mon Sep 17 00:00:00 2001 From: Yang Keao Date: Mon, 24 Jun 2024 16:37:40 +0800 Subject: [PATCH 1/2] add a linter to check extraprivate field Signed-off-by: Yang Keao --- build/BUILD.bazel | 1 + build/linter/extraprivate/BUILD.bazel | 26 ++++ build/linter/extraprivate/analyzer.go | 138 ++++++++++++++++++ build/linter/extraprivate/analyzer_test.go | 33 +++++ build/linter/extraprivate/cmd/BUILD.bazel | 18 +++ build/linter/extraprivate/cmd/main.go | 24 +++ .../extraprivate/testdata/src/t/BUILD.bazel | 8 + .../testdata/src/t/extraprivate.go | 51 +++++++ build/nogo_config.json | 9 ++ 9 files changed, 308 insertions(+) create mode 100644 build/linter/extraprivate/BUILD.bazel create mode 100644 build/linter/extraprivate/analyzer.go create mode 100644 build/linter/extraprivate/analyzer_test.go create mode 100644 build/linter/extraprivate/cmd/BUILD.bazel create mode 100644 build/linter/extraprivate/cmd/main.go create mode 100644 build/linter/extraprivate/testdata/src/t/BUILD.bazel create mode 100644 build/linter/extraprivate/testdata/src/t/extraprivate.go diff --git a/build/BUILD.bazel b/build/BUILD.bazel index b4e6d1c408cbc..fbef11268317b 100644 --- a/build/BUILD.bazel +++ b/build/BUILD.bazel @@ -152,6 +152,7 @@ nogo( "//build/linter/durationcheck", "//build/linter/etcdconfig", "//build/linter/exportloopref", + "//build/linter/extraprivate", "//build/linter/forcetypeassert", "//build/linter/gofmt", "//build/linter/gci", diff --git a/build/linter/extraprivate/BUILD.bazel b/build/linter/extraprivate/BUILD.bazel new file mode 100644 index 0000000000000..87b640b6fe2d7 --- /dev/null +++ b/build/linter/extraprivate/BUILD.bazel @@ -0,0 +1,26 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "extraprivate", + srcs = ["analyzer.go"], + importpath = "github.com/pingcap/tidb/build/linter/extraprivate", + visibility = ["//visibility:public"], + deps = [ + "//build/linter/util", + "@com_github_fatih_structtag//:structtag", + "@org_golang_x_tools//go/analysis", + "@org_golang_x_tools//go/analysis/passes/inspect", + "@org_golang_x_tools//go/ast/inspector", + ], +) + +go_test( + name = "extraprivate_test", + timeout = "short", + srcs = ["analyzer_test.go"], + flaky = True, + deps = [ + ":extraprivate", + "@org_golang_x_tools//go/analysis/analysistest", + ], +) diff --git a/build/linter/extraprivate/analyzer.go b/build/linter/extraprivate/analyzer.go new file mode 100644 index 0000000000000..bc5a83c47a305 --- /dev/null +++ b/build/linter/extraprivate/analyzer.go @@ -0,0 +1,138 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package extraprivate + +import ( + "go/ast" + "go/types" + + "github.com/fatih/structtag" + "github.com/pingcap/tidb/build/linter/util" + "golang.org/x/tools/go/analysis" + "golang.org/x/tools/go/analysis/passes/inspect" + "golang.org/x/tools/go/ast/inspector" +) + +// Analyzer is the analyzer of extraprivate. +// Access to fields with `extraprivate` tag outside its struct's methods is not allowed. However, this +// analyzer allows to construct the struct manually with the field with `extraprivate` tag. +var Analyzer = &analysis.Analyzer{ + Name: "extraprivate", + Doc: "Check developers don't read or write fields with extraprivate tag outside the method", + Requires: []*analysis.Analyzer{inspect.Analyzer}, + Run: run, +} + +func run(pass *analysis.Pass) (any, error) { + inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector) + + nodeFilter := []ast.Node{ + (*ast.FuncDecl)(nil), + (*ast.SelectorExpr)(nil), + } + + inspect.WithStack(nodeFilter, func(n ast.Node, push bool, stack []ast.Node) (proceed bool) { + if push { + return true + } + + se, ok := n.(*ast.SelectorExpr) + if !ok { + return + } + + xType := pass.TypesInfo.Types[se.X].Type + if !isExtraPrivateField(xType, se.Sel) { + return + } + + if isAccessWithinStructMethod(pass, stack, xType) { + return + } + + pass.Reportf(se.Pos(), "access to extraprivate field outside its struct's methods") + return + }) + + return nil, nil +} + +func isExtraPrivateField(xAnyType types.Type, fieldName *ast.Ident) bool { + switch xType := xAnyType.(type) { + case *types.Named: + underlyingType := xType.Underlying() + structType, ok := underlyingType.(*types.Struct) + if !ok { + return false + } + for i := 0; i < structType.NumFields(); i++ { + field := structType.Field(i) + if field.Name() == fieldName.Name { + tags, err := structtag.Parse(structType.Tag(i)) + if err != nil { + continue + } + _, err = tags.Get("extraprivate") + if err != nil { + continue + } + return true + } + } + + return false + case *types.Pointer: + return isExtraPrivateField(xType.Elem(), fieldName) + default: + return false + } +} + +func isAccessWithinStructMethod(pass *analysis.Pass, stack []ast.Node, se types.Type) bool { + for i := len(stack) - 1; i >= 0; i-- { + funcDecl, ok := stack[i].(*ast.FuncDecl) + if !ok { + continue + } + + if funcDecl.Recv == nil { + continue + } + + recvType := pass.TypesInfo.TypeOf(funcDecl.Recv.List[0].Type) + if recvType == nil { + continue + } + + if resolvePointer(recvType) == resolvePointer(se) { + return true + } + } + + return false +} + +func resolvePointer(xType types.Type) types.Type { + switch xType := xType.(type) { + case *types.Pointer: + return xType.Elem() + default: + return xType + } +} + +func init() { + util.SkipAnalyzerByConfig(Analyzer) +} diff --git a/build/linter/extraprivate/analyzer_test.go b/build/linter/extraprivate/analyzer_test.go new file mode 100644 index 0000000000000..316801240f3e7 --- /dev/null +++ b/build/linter/extraprivate/analyzer_test.go @@ -0,0 +1,33 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build !intest + +package extraprivate_test + +import ( + "testing" + + "github.com/pingcap/tidb/build/linter/extraprivate" + "golang.org/x/tools/go/analysis/analysistest" +) + +// TODO: investigate the CI environment and check how to run this test in CI. +// The CI environment doesn't have `go` executable in $PATH. + +func Test(t *testing.T) { + testdata := analysistest.TestData() + pkgs := []string{"t"} + analysistest.Run(t, testdata, extraprivate.Analyzer, pkgs...) +} diff --git a/build/linter/extraprivate/cmd/BUILD.bazel b/build/linter/extraprivate/cmd/BUILD.bazel new file mode 100644 index 0000000000000..c449b62ed24f8 --- /dev/null +++ b/build/linter/extraprivate/cmd/BUILD.bazel @@ -0,0 +1,18 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_library") + +go_library( + name = "cmd_lib", + srcs = ["main.go"], + importpath = "github.com/pingcap/tidb/build/linter/extraprivate/cmd", + visibility = ["//visibility:private"], + deps = [ + "//build/linter/extraprivate", + "@org_golang_x_tools//go/analysis/singlechecker", + ], +) + +go_binary( + name = "cmd", + embed = [":cmd_lib"], + visibility = ["//visibility:public"], +) diff --git a/build/linter/extraprivate/cmd/main.go b/build/linter/extraprivate/cmd/main.go new file mode 100644 index 0000000000000..8b7c605917741 --- /dev/null +++ b/build/linter/extraprivate/cmd/main.go @@ -0,0 +1,24 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "github.com/pingcap/tidb/build/linter/extraprivate" + "golang.org/x/tools/go/analysis/singlechecker" +) + +func main() { + singlechecker.Main(extraprivate.Analyzer) +} diff --git a/build/linter/extraprivate/testdata/src/t/BUILD.bazel b/build/linter/extraprivate/testdata/src/t/BUILD.bazel new file mode 100644 index 0000000000000..705422dd74266 --- /dev/null +++ b/build/linter/extraprivate/testdata/src/t/BUILD.bazel @@ -0,0 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "t", + srcs = ["extraprivate.go"], + importpath = "github.com/pingcap/tidb/build/linter/extraprivate/testdata/src/t", + visibility = ["//visibility:public"], +) diff --git a/build/linter/extraprivate/testdata/src/t/extraprivate.go b/build/linter/extraprivate/testdata/src/t/extraprivate.go new file mode 100644 index 0000000000000..d5f02cfedc682 --- /dev/null +++ b/build/linter/extraprivate/testdata/src/t/extraprivate.go @@ -0,0 +1,51 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package t + +type structWithExtraPrivateField struct { + field1 int `extraprivate:""` + field2 int +} + +func readField1(s structWithExtraPrivateField) int { + return s.field1 // want `access to extraprivate field outside its struct's methods` +} + +func readField1Ptr(s *structWithExtraPrivateField) int { + return s.field1 // want `access to extraprivate field outside its struct's methods` +} + +func readField2(s structWithExtraPrivateField) int { + return s.field2 +} +func readField2Ptr(s *structWithExtraPrivateField) int { + return s.field2 +} + +func (s structWithExtraPrivateField) Field1() int { + return s.field1 +} + +func (s *structWithExtraPrivateField) Field1Ptr() int { + return s.field1 +} + +func (s structWithExtraPrivateField) Field2() int { + return s.field2 +} + +func (s *structWithExtraPrivateField) Field2Ptr() int { + return s.field2 +} diff --git a/build/nogo_config.json b/build/nogo_config.json index 58d37523fe211..7b8027001b829 100644 --- a/build/nogo_config.json +++ b/build/nogo_config.json @@ -1362,5 +1362,14 @@ "external/": "no need to vet third party code", ".*_generated\\.go$": "ignore generated code" } + }, + "extraprivate": { + "exclude_files": { + "pkg/parser/parser.go": "parser/parser.go code", + "external/": "no need to vet third party code", + ".*_generated\\.go$": "ignore generated code", + "build/linter/extraprivate/testdata/": "no need to vet the test inside the linter", + "_test.go": "ignore all test code" + } } } From 8a1cd5701bd21960ff7237558929150e3d109105 Mon Sep 17 00:00:00 2001 From: Yang Keao Date: Wed, 26 Jun 2024 19:13:27 +0800 Subject: [PATCH 2/2] add extraprivate tag for Constant.Value Signed-off-by: Yang Keao --- pkg/executor/aggfuncs/builder.go | 11 ++-- pkg/executor/set.go | 3 +- pkg/expression/aggregation/descriptor.go | 42 +++++++++++----- pkg/expression/builtin_arithmetic.go | 6 ++- pkg/expression/builtin_compare.go | 33 ++++++++---- pkg/expression/builtin_info.go | 9 ++-- pkg/expression/builtin_json.go | 30 ++++++----- pkg/expression/builtin_math.go | 6 ++- pkg/expression/builtin_op.go | 9 +++- pkg/expression/builtin_other.go | 14 +++--- pkg/expression/builtin_string.go | 22 ++++++-- pkg/expression/collation.go | 3 +- pkg/expression/constant.go | 36 ++++++++++++- pkg/expression/constant_fold.go | 11 ++-- pkg/expression/constant_propagation.go | 12 +++-- pkg/expression/explain.go | 10 ++-- pkg/expression/expression.go | 46 +++++++++++------ pkg/expression/scalar_function.go | 10 +++- pkg/expression/util.go | 36 +++++++++---- pkg/planner/cardinality/selectivity.go | 8 ++- pkg/planner/core/exhaust_physical_plans.go | 5 +- pkg/planner/core/expression_rewriter.go | 6 ++- pkg/planner/core/logical_aggregation.go | 10 +++- pkg/planner/core/logical_plans.go | 8 ++- .../core/memtable_predicate_extractor.go | 4 +- pkg/planner/core/optimizer.go | 6 ++- pkg/planner/core/rule_column_pruning.go | 3 +- .../core/rule_derive_topn_from_window.go | 2 +- pkg/planner/core/rule_partition_processor.go | 50 +++++++++++++------ pkg/planner/core/rule_predicate_push_down.go | 18 ++++--- .../core/rule_predicate_simplification.go | 10 ++-- pkg/planner/core/scalar_subq_expression.go | 5 +- pkg/planner/util/null_misc.go | 18 ++++--- pkg/util/ranger/checker.go | 14 ++++-- pkg/util/ranger/detacher.go | 8 +-- .../r/planner/core/integration.result | 19 ++++--- .../t/planner/core/integration.test | 11 +++- 37 files changed, 396 insertions(+), 158 deletions(-) diff --git a/pkg/executor/aggfuncs/builder.go b/pkg/executor/aggfuncs/builder.go index abbff30bae6fb..79124bd41d0ee 100644 --- a/pkg/executor/aggfuncs/builder.go +++ b/pkg/executor/aggfuncs/builder.go @@ -720,10 +720,13 @@ func buildLeadLag(ctx AggFuncBuildContext, aggFuncDesc *aggregation.AggFuncDesc, if len(aggFuncDesc.Args) == 3 { defaultExpr = aggFuncDesc.Args[2] if et, ok := defaultExpr.(*expression.Constant); ok { - evalCtx := ctx.GetEvalCtx() - res, err1 := et.Value.ConvertTo(evalCtx.TypeCtx(), aggFuncDesc.RetTp) - if err1 == nil { - defaultExpr = &expression.Constant{Value: res, RetType: aggFuncDesc.RetTp} + etVal, ok := et.GetValueWithoutOverOptimization(ctx) + if ok { + evalCtx := ctx.GetEvalCtx() + res, err1 := etVal.ConvertTo(evalCtx.TypeCtx(), aggFuncDesc.RetTp) + if err1 == nil { + defaultExpr = &expression.Constant{Value: res, RetType: aggFuncDesc.RetTp} + } } } } diff --git a/pkg/executor/set.go b/pkg/executor/set.go index 89a65bb77a2e4..bb648bdbf2755 100644 --- a/pkg/executor/set.go +++ b/pkg/executor/set.go @@ -78,7 +78,8 @@ func (e *SetExecutor) Next(ctx context.Context, req *chunk.Chunk) error { cs := dt.GetString() var co string if v.ExtendValue != nil { - co = v.ExtendValue.Value.GetString() + val, _ := v.ExtendValue.GetValue() + co = val.GetString() } err = e.setCharset(cs, co, v.Name == ast.SetNames) if err != nil { diff --git a/pkg/expression/aggregation/descriptor.go b/pkg/expression/aggregation/descriptor.go index c56c95befc025..4a8ff37bfdc4a 100644 --- a/pkg/expression/aggregation/descriptor.go +++ b/pkg/expression/aggregation/descriptor.go @@ -16,7 +16,6 @@ package aggregation import ( "bytes" - "math" "github.com/pingcap/errors" "github.com/pingcap/tidb/pkg/expression" @@ -246,8 +245,15 @@ func (a *AggFuncDesc) evalNullValueInOuterJoin4Count(ctx expression.BuildContext for _, arg := range a.Args { result := expression.EvaluateExprWithNull(ctx, schema, arg) con, ok := result.(*expression.Constant) - if !ok || con.Value.IsNull() { - return types.Datum{}, ok + if !ok { + return types.Datum{}, false + } + conVal, ok := con.GetValueWithoutOverOptimization(ctx) + if !ok { + return types.Datum{}, false + } + if conVal.IsNull() { + return types.Datum{}, true } } return types.NewDatum(1), true @@ -256,28 +262,40 @@ func (a *AggFuncDesc) evalNullValueInOuterJoin4Count(ctx expression.BuildContext func (a *AggFuncDesc) evalNullValueInOuterJoin4Sum(ctx expression.BuildContext, schema *expression.Schema) (types.Datum, bool) { result := expression.EvaluateExprWithNull(ctx, schema, a.Args[0]) con, ok := result.(*expression.Constant) - if !ok || con.Value.IsNull() { - return types.Datum{}, ok + if !ok { + return types.Datum{}, false } - return con.Value, true + conVal, ok := con.GetValueWithoutOverOptimization(ctx) + if !ok { + return types.Datum{}, false + } + return conVal, true } func (a *AggFuncDesc) evalNullValueInOuterJoin4BitAnd(ctx expression.BuildContext, schema *expression.Schema) (types.Datum, bool) { result := expression.EvaluateExprWithNull(ctx, schema, a.Args[0]) con, ok := result.(*expression.Constant) - if !ok || con.Value.IsNull() { - return types.NewDatum(uint64(math.MaxUint64)), true + if !ok { + return types.Datum{}, false } - return con.Value, true + conVal, ok := con.GetValueWithoutOverOptimization(ctx) + if !ok { + return types.Datum{}, false + } + return conVal, true } func (a *AggFuncDesc) evalNullValueInOuterJoin4BitOr(ctx expression.BuildContext, schema *expression.Schema) (types.Datum, bool) { result := expression.EvaluateExprWithNull(ctx, schema, a.Args[0]) con, ok := result.(*expression.Constant) - if !ok || con.Value.IsNull() { - return types.NewDatum(0), true + if !ok { + return types.Datum{}, false + } + conVal, ok := con.GetValue() + if !ok { + return types.Datum{}, false } - return con.Value, true + return conVal, true } // UpdateNotNullFlag4RetType checks if we should remove the NotNull flag for the return type of the agg. diff --git a/pkg/expression/builtin_arithmetic.go b/pkg/expression/builtin_arithmetic.go index 3e9e14af32fb4..7d8a9c2620a47 100644 --- a/pkg/expression/builtin_arithmetic.go +++ b/pkg/expression/builtin_arithmetic.go @@ -62,7 +62,11 @@ var ( func isConstantBinaryLiteral(ctx EvalContext, expr Expression) bool { if types.IsBinaryStr(expr.GetType(ctx)) { if v, ok := expr.(*Constant); ok { - if k := v.Value.Kind(); k == types.KindBinaryLiteral { + val, ok := v.GetValue() + if !ok { + return false + } + if k := val.Kind(); k == types.KindBinaryLiteral { return true } } diff --git a/pkg/expression/builtin_compare.go b/pkg/expression/builtin_compare.go index 442ae50912883..f5eb96d1bcfa4 100644 --- a/pkg/expression/builtin_compare.go +++ b/pkg/expression/builtin_compare.go @@ -1443,7 +1443,7 @@ func RefineComparedConstant(ctx BuildContext, targetFieldType types.FieldType, c } return con, false } - c, err := intDatum.Compare(evalCtx.TypeCtx(), &con.Value, collate.GetBinaryCollator()) + c, err := intDatum.Compare(evalCtx.TypeCtx(), &dt, collate.GetBinaryCollator()) if err != nil { return con, false } @@ -1580,11 +1580,20 @@ func (c *compareFunctionClass) refineArgs(ctx BuildContext, args []Expression) ( if !allowCmpArgsRefining4PlanCache(ctx, args) { return args, nil } - // We should remove the mutable constant for correctness, because its value may be changed. + if err := RemoveMutableConst(ctx, args); err != nil { return nil, err } + // After `RemoveMutableConst`, params and deferred functions are all constant value now. + var arg0Val, arg1Val types.Datum + if arg0IsCon { + arg0Val, _ = arg0.GetValue() + } + if arg1IsCon { + arg1Val, _ = arg1.GetValue() + } + if arg0IsCon && !arg1IsCon && matchRefineRule3Pattern(arg0EvalType, arg1Type) { return c.refineNumericConstantCmpDatetime(ctx, args, arg0, 0), nil } @@ -1613,7 +1622,7 @@ func (c *compareFunctionClass) refineArgs(ctx BuildContext, args []Expression) ( // For uint: // inf: 11111111 & 1 == 1 // -inf: 00000000 & 1 == 0 - if arg1.Value.GetInt64()&1 == 1 { + if arg1Val.GetInt64()&1 == 1 { isPositiveInfinite = true } else { isNegativeInfinite = true @@ -1630,7 +1639,7 @@ func (c *compareFunctionClass) refineArgs(ctx BuildContext, args []Expression) ( // to check the NotNullFlag, then more optimizations can be enabled. isExceptional = isExceptional && mysql.HasNotNullFlag(arg1Type.GetFlag()) if isExceptional && arg0.GetType(ctx.GetEvalCtx()).EvalType() == types.ETInt { - if arg0.Value.GetInt64()&1 == 1 { + if arg0Val.GetInt64()&1 == 1 { isNegativeInfinite = true } else { isPositiveInfinite = true @@ -1639,18 +1648,22 @@ func (c *compareFunctionClass) refineArgs(ctx BuildContext, args []Expression) ( } // int constant [cmp] year type - if arg0IsCon && arg0IsInt && arg1Type.GetType() == mysql.TypeYear && !arg0.Value.IsNull() { - adjusted, failed := types.AdjustYear(arg0.Value.GetInt64(), false) + if arg0IsCon && arg0IsInt && arg1Type.GetType() == mysql.TypeYear && !arg0Val.IsNull() { + adjusted, failed := types.AdjustYear(arg0Val.GetInt64(), false) if failed == nil { - arg0.Value.SetInt64(adjusted) + newVal := types.Datum{} + newVal.SetInt64(adjusted) + *arg0 = Constant{Value: newVal, RetType: arg0.RetType} finalArg0 = arg0 } } // year type [cmp] int constant - if arg1IsCon && arg1IsInt && arg0Type.GetType() == mysql.TypeYear && !arg1.Value.IsNull() { - adjusted, failed := types.AdjustYear(arg1.Value.GetInt64(), false) + if arg1IsCon && arg1IsInt && arg0Type.GetType() == mysql.TypeYear && !arg1Val.IsNull() { + adjusted, failed := types.AdjustYear(arg1Val.GetInt64(), false) if failed == nil { - arg1.Value.SetInt64(adjusted) + newVal := types.Datum{} + newVal.SetInt64(adjusted) + *arg1 = Constant{Value: newVal, RetType: arg1.RetType} finalArg1 = arg1 } } diff --git a/pkg/expression/builtin_info.go b/pkg/expression/builtin_info.go index e0dc23eb125dc..abc3de8abf366 100644 --- a/pkg/expression/builtin_info.go +++ b/pkg/expression/builtin_info.go @@ -633,9 +633,12 @@ func (c *benchmarkFunctionClass) getFunction(ctx BuildContext, args []Expression // since non-constant loop count would be different between rows, and cannot be vectorized. var constLoopCount int64 con, ok := args[0].(*Constant) - if ok && con.Value.Kind() == types.KindInt64 { - if lc, isNull, err := con.EvalInt(ctx.GetEvalCtx(), chunk.Row{}); err == nil && !isNull { - constLoopCount = lc + if ok { + conVal, ok := con.GetValue() + if ok && conVal.Kind() == types.KindInt64 { + if lc, isNull, err := con.EvalInt(ctx.GetEvalCtx(), chunk.Row{}); err == nil && !isNull { + constLoopCount = lc + } } } bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETInt, sameEvalType) diff --git a/pkg/expression/builtin_json.go b/pkg/expression/builtin_json.go index 4cf7045904565..ba465618ed1d3 100644 --- a/pkg/expression/builtin_json.go +++ b/pkg/expression/builtin_json.go @@ -1822,19 +1822,25 @@ func (c *jsonSchemaValidFunctionClass) verifyArgs(ctx EvalContext, args []Expres return ErrInvalidTypeForJSON.GenWithStackByArgs(2, "json_schema_valid") } if c, ok := args[0].(*Constant); ok { - // If args[0] is NULL, then don't check the length of *both* arguments. - // JSON_SCHEMA_VALID(NULL,NULL) -> NULL - // JSON_SCHEMA_VALID(NULL,'') -> NULL - // JSON_SCHEMA_VALID('',NULL) -> ErrInvalidJSONTextInParam - if !c.Value.IsNull() { - if len(c.Value.GetBytes()) == 0 { - return types.ErrInvalidJSONTextInParam.GenWithStackByArgs( - 1, "json_schema_valid", "The document is empty.", 0) - } - if c1, ok := args[1].(*Constant); ok { - if !c1.Value.IsNull() && len(c1.Value.GetBytes()) == 0 { + if c, ok := c.GetValue(); ok { + // If args[0] is NULL, then don't check the length of *both* arguments. + // JSON_SCHEMA_VALID(NULL,NULL) -> NULL + // JSON_SCHEMA_VALID(NULL,'') -> NULL + // JSON_SCHEMA_VALID('',NULL) -> ErrInvalidJSONTextInParam + if !c.IsNull() { + if len(c.GetBytes()) == 0 { return types.ErrInvalidJSONTextInParam.GenWithStackByArgs( - 2, "json_schema_valid", "The document is empty.", 0) + 1, "json_schema_valid", "The document is empty.", 0) + } + if c1, ok := args[1].(*Constant); ok { + c1Val, err := c1.Eval(ctx, chunk.Row{}) + if err != nil { + return err + } + if !c1Val.IsNull() && len(c1Val.GetBytes()) == 0 { + return types.ErrInvalidJSONTextInParam.GenWithStackByArgs( + 2, "json_schema_valid", "The document is empty.", 0) + } } } } diff --git a/pkg/expression/builtin_math.go b/pkg/expression/builtin_math.go index 9282b6bfcd53e..fcc727b904d2b 100644 --- a/pkg/expression/builtin_math.go +++ b/pkg/expression/builtin_math.go @@ -1188,7 +1188,11 @@ func (b *builtinConvSig) evalString(ctx EvalContext, row chunk.Row) (res string, var str string switch x := b.args[0].(type) { case *Constant: - if x.Value.Kind() == types.KindBinaryLiteral { + xVal, err := x.Eval(ctx, row) + if err != nil { + return res, isNull, err + } + if xVal.Kind() == types.KindBinaryLiteral { datum, err := x.Eval(ctx, row) if err != nil { return "", false, err diff --git a/pkg/expression/builtin_op.go b/pkg/expression/builtin_op.go index 8f09c0f89ac80..1f492b4c09019 100644 --- a/pkg/expression/builtin_op.go +++ b/pkg/expression/builtin_op.go @@ -838,15 +838,20 @@ type unaryMinusFunctionClass struct { } func (c *unaryMinusFunctionClass) handleIntOverflow(ctx EvalContext, arg *Constant) (overflow bool) { + argVal, ok := arg.GetValue() + if !ok { + // ignore parameter and deferred functions + return false + } if mysql.HasUnsignedFlag(arg.GetType(ctx).GetFlag()) { - uval := arg.Value.GetUint64() + uval := argVal.GetUint64() // -math.MinInt64 is 9223372036854775808, so if uval is more than 9223372036854775808, like // 9223372036854775809, -9223372036854775809 is less than math.MinInt64, overflow occurs. if uval > uint64(-math.MinInt64) { return true } } else { - val := arg.Value.GetInt64() + val := argVal.GetInt64() // The math.MinInt64 is -9223372036854775808, the math.MaxInt64 is 9223372036854775807, // which is less than abs(-9223372036854775808). When val == math.MinInt64, overflow occurs. if val == math.MinInt64 { diff --git a/pkg/expression/builtin_other.go b/pkg/expression/builtin_other.go index 14e0b7600ef19..07fd854a5027e 100644 --- a/pkg/expression/builtin_other.go +++ b/pkg/expression/builtin_other.go @@ -162,13 +162,15 @@ func (c *inFunctionClass) verifyArgs(ctx BuildContext, args []Expression) ([]Exp validatedArgs := make([]Expression, 0, len(args)) for _, arg := range args { if constant, ok := arg.(*Constant); ok { - switch { - case columnType.GetType() == mysql.TypeBit && constant.Value.Kind() == types.KindInt64: - if constant.Value.GetInt64() < 0 { - if MaybeOverOptimized4PlanCache(ctx, args) { - ctx.SetSkipPlanCache(fmt.Sprintf("Bit Column in (%v)", constant.Value.GetInt64())) + if constant, ok := constant.GetValue(); ok { + switch { + case columnType.GetType() == mysql.TypeBit && constant.Kind() == types.KindInt64: + if constant.GetInt64() < 0 { + if MaybeOverOptimized4PlanCache(ctx, args) { + ctx.SetSkipPlanCache(fmt.Sprintf("Bit Column in (%v)", constant.GetInt64())) + } + continue } - continue } } } diff --git a/pkg/expression/builtin_string.go b/pkg/expression/builtin_string.go index 678291b636af0..4185ae2f0d8fa 100644 --- a/pkg/expression/builtin_string.go +++ b/pkg/expression/builtin_string.go @@ -1055,13 +1055,19 @@ func (c *convertFunctionClass) getFunction(ctx BuildContext, args []Expression) return nil, err } - charsetArg, ok := args[1].(*Constant) + charsetArgCon, ok := args[1].(*Constant) if !ok { // `args[1]` is limited by parser to be a constant string, // should never go into here. return nil, errIncorrectArgs.GenWithStackByArgs("charset") } - transcodingName := charsetArg.Value.GetString() + charsetArg, ok := charsetArgCon.GetValue() + if !ok { + // `args[1]` is limited by parser to be a constant string, + // should never go into here. + return nil, errIncorrectArgs.GenWithStackByArgs("charset") + } + transcodingName := charsetArg.GetString() bf.tp.SetCharset(strings.ToLower(transcodingName)) // Quoted about the behavior of syntax CONVERT(expr, type) to CHAR(): // In all cases, the string has the default collation for the character set. @@ -3873,7 +3879,11 @@ func (c *weightStringFunctionClass) verifyArgs(ctx EvalContext, args []Expressio if !ok { return weightStringPaddingNone, 0, ErrIncorrectType.GenWithStackByArgs(args[1].String(), c.funcName) } - switch x := c1.Value.GetString(); x { + c1Val, ok := c1.GetValue() + if !ok { + return weightStringPaddingNone, 0, ErrIncorrectType.GenWithStackByArgs(args[1].String(), c.funcName) + } + switch x := c1Val.GetString(); x { case "CHAR": if padding == weightStringPaddingNone { padding = weightStringPaddingAsChar @@ -3890,7 +3900,11 @@ func (c *weightStringFunctionClass) verifyArgs(ctx EvalContext, args []Expressio if !ok { return weightStringPaddingNone, 0, ErrIncorrectType.GenWithStackByArgs(args[1].String(), c.funcName) } - length = int(c2.Value.GetInt64()) + c2Val, ok := c2.GetValue() + if !ok { + return weightStringPaddingNone, 0, ErrIncorrectType.GenWithStackByArgs(args[1].String(), c.funcName) + } + length = int(c2Val.GetInt64()) if length == 0 { return weightStringPaddingNone, 0, ErrIncorrectType.GenWithStackByArgs(args[2].String(), c.funcName) } diff --git a/pkg/expression/collation.go b/pkg/expression/collation.go index dfbf3aac695db..520e5f30bbb74 100644 --- a/pkg/expression/collation.go +++ b/pkg/expression/collation.go @@ -164,7 +164,8 @@ func deriveCoercibilityForScalarFunc(sf *ScalarFunction) Coercibility { } func deriveCoercibilityForConstant(c *Constant) Coercibility { - if c.Value.IsNull() { + // For a parameter or deferred constant, it's coercibility is always `CoercibilityCoercible`, even when it's NULL. + if cVal, ok := c.GetValue(); ok && cVal.IsNull() { return CoercibilityIgnorable } else if c.RetType.EvalType() != types.ETString { return CoercibilityNumeric diff --git a/pkg/expression/constant.go b/pkg/expression/constant.go index 1643956abce93..7347f794f7700 100644 --- a/pkg/expression/constant.go +++ b/pkg/expression/constant.go @@ -114,7 +114,9 @@ func NewNullWithFieldType(fieldType *types.FieldType) *Constant { // Constant stands for a constant value. type Constant struct { - Value types.Datum + // Value is the datum of the constant. Don't use `Constant.Value` directly, because it'll be empty if the constant + // is a param marker or a deferred expression. Use `Constant.Eval` instead. + Value types.Datum `extraprivate:""` RetType *types.FieldType // DeferredExpr holds deferred function in PlanCache cached plan. // it's only used to represent non-deterministic functions(see expression.DeferredFunctions) @@ -539,3 +541,35 @@ func (c *Constant) MemoryUsage() (sum int64) { } return } + +// GetValue returns the reference to the value of the constant if the constant is not a parameter or a deferred expression. +// Otherwise, it'll return nil and false. +// If you are using this function in planner, or `getFunction() / verifyArg()` in `expression`, please consider whether +// it's more appropriate to use `GetValueWithoutOverOptimization`. +func (c *Constant) GetValue() (types.Datum, bool) { + if c.DeferredExpr != nil || c.ParamMarker != nil { + return types.Datum{}, false + } + return c.Value, true +} + +// GetValueWithoutOverOptimization returns the value of the constant if: +// 1. It's not a parameter or a deferred expression. +// 2. The context is not in plan cache. +// This function avoids using `Constant.Value` unexpectedly in plan cache, which may cause wrong result. +func (c *Constant) GetValueWithoutOverOptimization(ctx BuildContext) (types.Datum, bool) { + val, ok := c.GetValue() + + if ctx.IsUseCache() { + if ok { + return val, true + } + return types.Datum{}, false + } + + datum, err := c.Eval(ctx.GetEvalCtx(), chunk.Row{}) + if err != nil { + return types.Datum{}, false + } + return datum, true +} diff --git a/pkg/expression/constant_fold.go b/pkg/expression/constant_fold.go index dd742477564c6..b82c8d494313e 100644 --- a/pkg/expression/constant_fold.go +++ b/pkg/expression/constant_fold.go @@ -99,7 +99,7 @@ func ifNullFoldHandler(ctx BuildContext, expr *ScalarFunction) (Expression, bool // Only check constArg.Value here. Because deferred expression is // evaluated to constArg.Value after foldConstant(args[0]), it's not // needed to be checked. - if constArg.Value.IsNull() { + if constArgVal, ok := constArg.GetValueWithoutOverOptimization(ctx); ok && constArgVal.IsNull() { foldedExpr, isConstant := foldConstant(ctx, args[1]) // See https://github.com/pingcap/tidb/issues/51765. If the first argument can @@ -180,9 +180,12 @@ func foldConstant(ctx BuildContext, expr Expression) (Expression, bool) { for i := 0; i < len(args); i++ { switch x := args[i].(type) { case *Constant: - isDeferredConst = isDeferredConst || x.DeferredExpr != nil || x.ParamMarker != nil - argIsConst[i] = true - hasNullArg = hasNullArg || x.Value.IsNull() + xVal, ok := x.GetValueWithoutOverOptimization(ctx) + if ok { + isDeferredConst = isDeferredConst || x.DeferredExpr != nil || x.ParamMarker != nil + argIsConst[i] = true + hasNullArg = hasNullArg || xVal.IsNull() + } default: allConstArg = false } diff --git a/pkg/expression/constant_propagation.go b/pkg/expression/constant_propagation.go index fca6c3f514705..c5216ab5a9a60 100644 --- a/pkg/expression/constant_propagation.go +++ b/pkg/expression/constant_propagation.go @@ -53,15 +53,19 @@ func (s *basePropConstSolver) insertCol(col *Column) { // tryToUpdateEQList tries to update the eqList. When the eqList has store this column with a different constant, like // a = 1 and a = 2, we set the second return value to false. func (s *basePropConstSolver) tryToUpdateEQList(col *Column, con *Constant) (bool, bool) { - if con.Value.IsNull() && ConstExprConsiderPlanCache(con, s.ctx.IsUseCache()) { + conVal, ok := con.GetValueWithoutOverOptimization(s.ctx) + if ok && conVal.IsNull() { return false, true } id := s.getColID(col) oldCon := s.eqList[id] if oldCon != nil { - evalCtx := s.ctx.GetEvalCtx() - res, err := oldCon.Value.Compare(evalCtx.TypeCtx(), &con.Value, collate.GetCollator(col.GetType(s.ctx.GetEvalCtx()).GetCollate())) - return false, res != 0 || err != nil + oldConVal, oldOK := oldCon.GetValueWithoutOverOptimization(s.ctx) + if oldOK && ok { + evalCtx := s.ctx.GetEvalCtx() + res, err := oldConVal.Compare(evalCtx.TypeCtx(), &conVal, collate.GetCollator(col.GetType(s.ctx.GetEvalCtx()).GetCollate())) + return false, res != 0 || err != nil + } } s.eqList[id] = con return true, false diff --git a/pkg/expression/explain.go b/pkg/expression/explain.go index 2162e84144448..2d5c96adf51d1 100644 --- a/pkg/expression/explain.go +++ b/pkg/expression/explain.go @@ -40,9 +40,13 @@ func (expr *ScalarFunction) explainInfo(ctx EvalContext, normalized bool) string // convert `in(_tidb_tid, -1)` to `in(_tidb_tid, dual)` whether normalized equals to true or false. if expr.FuncName.L == ast.In { args := expr.GetArgs() - if len(args) == 2 && strings.HasSuffix(args[0].ExplainNormalizedInfo(), model.ExtraPhysTblIdName.L) && args[1].(*Constant).Value.GetInt64() == -1 { - buffer.WriteString(args[0].ExplainNormalizedInfo() + ", dual)") - return buffer.String() + if len(args) == 2 && strings.HasSuffix(args[0].ExplainNormalizedInfo(), model.ExtraPhysTblIdName.L) { + if arg1Con, ok := args[1].(*Constant); ok { + if arg1Val, ok := arg1Con.GetValue(); ok && arg1Val.GetInt64() == -1 { + buffer.WriteString(args[0].ExplainNormalizedInfo() + ", dual)") + return buffer.String() + } + } } } switch expr.FuncName.L { diff --git a/pkg/expression/expression.go b/pkg/expression/expression.go index 7f97583365719..fb056ab485c89 100644 --- a/pkg/expression/expression.go +++ b/pkg/expression/expression.go @@ -291,7 +291,10 @@ func IsEQCondFromIn(expr Expression) bool { // ExprNotNull checks if an expression is possible to be null. func ExprNotNull(ctx EvalContext, expr Expression) bool { if c, ok := expr.(*Constant); ok { - return !c.Value.IsNull() + // For parameter / deferred function, we can't determine if it's null or not. + if cVal, ok := c.GetValue(); ok { + return !cVal.IsNull() + } } // For ScalarFunction, the result would not be correct until we support maintaining // NotNull flag for it. @@ -910,9 +913,11 @@ func evaluateExprWithNullInNullRejectCheck(ctx BuildContext, schema *Schema, exp } allArgsNullFromSet := true for i := range args { - if cons, ok := args[i].(*Constant); ok && cons.Value.IsNull() && !nullFromSets[i] { - allArgsNullFromSet = false - break + if cons, ok := args[i].(*Constant); ok { + if consVal, ok := cons.GetValue(); ok && consVal.IsNull() && !nullFromSets[i] { + allArgsNullFromSet = false + break + } } } @@ -929,14 +934,16 @@ func evaluateExprWithNullInNullRejectCheck(ctx BuildContext, schema *Schema, exp } if hasNonConstantArg { for i := range args { - if cons, ok := args[i].(*Constant); ok && cons.Value.IsNull() && nullFromSets[i] { - if x.FuncName.L == ast.LogicAnd { - args[i] = NewOne() - break - } - if x.FuncName.L == ast.LogicOr { - args[i] = NewZero() - break + if cons, ok := args[i].(*Constant); ok { + if consVal, ok := cons.GetValue(); ok && consVal.IsNull() && !nullFromSets[i] { + if x.FuncName.L == ast.LogicAnd { + args[i] = NewOne() + break + } + if x.FuncName.L == ast.LogicOr { + args[i] = NewZero() + break + } } } } @@ -945,9 +952,12 @@ func evaluateExprWithNullInNullRejectCheck(ctx BuildContext, schema *Schema, exp c := NewFunctionInternal(ctx, x.FuncName.L, x.RetType.Clone(), args...) cons, ok := c.(*Constant) - // If the return expr is Null Constant, and all the Null Constant arguments are affected by column schema, - // then we think the result Null Constant is also affected by the column schema - return c, ok && cons.Value.IsNull() && allArgsNullFromSet + if ok { + consVal, ok := cons.GetValue() + // If the return expr is Null Constant, and all the Null Constant arguments are affected by column schema, + // then we think the result Null Constant is also affected by the column schema + return c, ok && consVal.IsNull() && allArgsNullFromSet + } case *Column: if !schema.Contains(x) { return x, false @@ -1081,7 +1091,11 @@ func NewValuesFunc(ctx BuildContext, offset int, retTp *types.FieldType) *Scalar // IsBinaryLiteral checks whether an expression is a binary literal func IsBinaryLiteral(expr Expression) bool { con, ok := expr.(*Constant) - return ok && con.Value.Kind() == types.KindBinaryLiteral + if !ok { + return false + } + conVal, ok := con.GetValue() + return ok && conVal.Kind() == types.KindBinaryLiteral } // wrapWithIsTrue wraps `arg` with istrue function if the return type of expr is not diff --git a/pkg/expression/scalar_function.go b/pkg/expression/scalar_function.go index f007e43b6cc85..26ae78c2c895b 100644 --- a/pkg/expression/scalar_function.go +++ b/pkg/expression/scalar_function.go @@ -154,7 +154,15 @@ func typeInferForNull(ctx EvalContext, args []Expression) { } var isNull = func(expr Expression) bool { cons, ok := expr.(*Constant) - return ok && cons.RetType.GetType() == mysql.TypeNull && cons.Value.IsNull() + if !ok { + return false + } + consVal, ok := cons.GetValue() + if !ok { + return false + } + + return cons.RetType.GetType() == mysql.TypeNull && consVal.IsNull() } // Infer the actual field type of the NULL constant. var retFieldTp *types.FieldType diff --git a/pkg/expression/util.go b/pkg/expression/util.go index cfca9c7b6bedc..2473bc250d28d 100644 --- a/pkg/expression/util.go +++ b/pkg/expression/util.go @@ -181,15 +181,19 @@ func ExtractEquivalenceColumns(result [][]Expression, exprs []Expression) [][]Ex // and constant. It return nil, 0 if the expression is not of this form. // It is used by derived Top N pattern and it is put here since it looks like // a general purpose routine. Similar routines can be added to find lower bound as well. -func FindUpperBound(expr Expression) (*Column, int64) { +func FindUpperBound(ctx BuildContext, expr Expression) (*Column, int64) { scalarFunction, scalarFunctionOk := expr.(*ScalarFunction) if scalarFunctionOk { args := scalarFunction.GetArgs() if len(args) == 2 { col, colOk := args[0].(*Column) constant, constantOk := args[1].(*Constant) + if !constantOk { + return nil, 0 + } + constantVal, constantOk := constant.GetValueWithoutOverOptimization(ctx) if colOk && constantOk && (scalarFunction.FuncName.L == ast.LT || scalarFunction.FuncName.L == ast.LE) { - value, valueOk := constant.Value.GetValue().(int64) + value, valueOk := constantVal.GetValue().(int64) if valueOk { if scalarFunction.FuncName.L == ast.LT { return col, value - 1 @@ -628,8 +632,16 @@ func SubstituteCorCol2Constant(ctx BuildContext, expr Expression) (Expression, e return &Constant{Value: *x.Data, RetType: x.GetType(ctx.GetEvalCtx())}, nil case *Constant: if x.DeferredExpr != nil { - newExpr := FoldConstant(ctx, x) - return &Constant{Value: newExpr.(*Constant).Value, RetType: x.GetType(ctx.GetEvalCtx())}, nil + newExpr := FoldConstant(ctx, x).(*Constant) + // if the `DeferredExpr` evaluated failed, the `newExpr.GetValue()` may return a `nil`. In this + // case we'll return an empty datum. This behavior is weird, but will keep the behavior not changed. + val, ok := newExpr.GetValue() + if !ok { + logutil.BgLogger().Warn("substitute correlation column to constant, but get empty value", + zap.String("expression", x.ExplainInfo(ctx.GetEvalCtx()))) + return &Constant{RetType: x.GetType(ctx.GetEvalCtx())}, nil + } + return &Constant{Value: val, RetType: x.GetType(ctx.GetEvalCtx())}, nil } } return expr, nil @@ -1401,7 +1413,7 @@ func GetUint64FromConstant(ctx EvalContext, expr Expression) (uint64, bool, bool logutil.BgLogger().Warn("not a constant expression", zap.String("expression", expr.ExplainInfo(ctx))) return 0, false, false } - dt := con.Value + dt, _ := con.GetValue() if con.ParamMarker != nil { dt = con.ParamMarker.GetUserVar(ctx) } else if con.DeferredExpr != nil { @@ -1552,16 +1564,22 @@ func RemoveMutableConst(ctx BuildContext, exprs []Expression) (err error) { for _, expr := range exprs { switch v := expr.(type) { case *Constant: - v.ParamMarker = nil + val, ok := v.GetValue() + if ok { + return nil + } + if v.DeferredExpr != nil { // evaluate and update v.Value to convert v to a complete immutable constant. // TODO: remove or hide DeferredExpr since it's too dangerous (hard to be consistent with v.Value all the time). - v.Value, err = v.DeferredExpr.Eval(ctx.GetEvalCtx(), chunk.Row{}) + val, err = v.DeferredExpr.Eval(ctx.GetEvalCtx(), chunk.Row{}) if err != nil { return err } - v.DeferredExpr = nil } - v.DeferredExpr = nil // do nothing since v.Value has already been evaluated in this case. + *v = Constant{ + Value: val, + RetType: v.RetType, + } case *ScalarFunction: return RemoveMutableConst(ctx, v.GetArgs()) } diff --git a/pkg/planner/cardinality/selectivity.go b/pkg/planner/cardinality/selectivity.go index 789436451322b..b530140e1b231 100644 --- a/pkg/planner/cardinality/selectivity.go +++ b/pkg/planner/cardinality/selectivity.go @@ -297,12 +297,16 @@ func Selectivity( if expression.MaybeOverOptimized4PlanCache(ctx.GetExprCtx(), []expression.Expression{c}) { continue } - if c.Value.IsNull() { + val, ok := c.GetValueWithoutOverOptimization(ctx.GetExprCtx()) + if !ok { + continue + } + if val.IsNull() { // c is null ret *= 0 mask &^= 1 << uint64(i) delete(notCoveredConstants, i) - } else if isTrue, err := c.Value.ToBool(sc.TypeCtx()); err == nil { + } else if isTrue, err := val.ToBool(sc.TypeCtx()); err == nil { if isTrue == 0 { // c is false ret *= 0 diff --git a/pkg/planner/core/exhaust_physical_plans.go b/pkg/planner/core/exhaust_physical_plans.go index 608a8707783b4..704bd5e8b5773 100644 --- a/pkg/planner/core/exhaust_physical_plans.go +++ b/pkg/planner/core/exhaust_physical_plans.go @@ -1562,7 +1562,10 @@ func (cwc *ColWithCmpFuncManager) BuildRangesByRow(ctx *rangerctx.RangerContext, if err != nil { return nil, err } - cwc.TmpConstant[i].Value = constantArg + *cwc.TmpConstant[i] = expression.Constant{ + Value: constantArg, + RetType: cwc.TmpConstant[i].RetType, + } newExpr, err := expression.NewFunction(exprCtx, opType, types.NewFieldType(mysql.TypeTiny), cwc.TargetCol, cwc.TmpConstant[i]) if err != nil { return nil, err diff --git a/pkg/planner/core/expression_rewriter.go b/pkg/planner/core/expression_rewriter.go index 7bff40645734a..7039b027367ea 100644 --- a/pkg/planner/core/expression_rewriter.go +++ b/pkg/planner/core/expression_rewriter.go @@ -1388,7 +1388,11 @@ func hasCTEConsumerInSubPlan(p base.LogicalPlan) bool { func initConstantRepertoire(ctx expression.EvalContext, c *expression.Constant) { c.SetRepertoire(expression.ASCII) if c.GetType(ctx).EvalType() == types.ETString { - for _, b := range c.Value.GetBytes() { + val, ok := c.GetValue() + if !ok { + return + } + for _, b := range val.GetBytes() { // if any character in constant is not ascii, set the repertoire to UNICODE. if b >= 0x80 { c.SetRepertoire(expression.UNICODE) diff --git a/pkg/planner/core/logical_aggregation.go b/pkg/planner/core/logical_aggregation.go index e548c536b09cb..3f49b8f6e0c7e 100644 --- a/pkg/planner/core/logical_aggregation.go +++ b/pkg/planner/core/logical_aggregation.go @@ -739,7 +739,15 @@ func (la *LogicalAggregation) canPullUp() bool { for _, f := range la.AggFuncs { for _, arg := range f.Args { expr := expression.EvaluateExprWithNull(la.SCtx().GetExprCtx(), la.Children()[0].Schema(), arg) - if con, ok := expr.(*expression.Constant); !ok || !con.Value.IsNull() { + con, ok := expr.(*expression.Constant) + if !ok { + return false + } + val, ok := con.GetValueWithoutOverOptimization(la.SCtx().GetExprCtx()) + if !ok { + return false + } + if !val.IsNull() { return false } } diff --git a/pkg/planner/core/logical_plans.go b/pkg/planner/core/logical_plans.go index 7fdfb0926371a..8ad7de925800e 100644 --- a/pkg/planner/core/logical_plans.go +++ b/pkg/planner/core/logical_plans.go @@ -337,8 +337,12 @@ func (p *LogicalJoin) extractFDForOuterJoin(filtersFromApply []expression.Expres if opt.OnlyInnerFilter { // if one of the inner condition is constant false, the inner side are all null, left make constant all of that. for _, one := range innerCondition { - if c, ok := one.(*expression.Constant); ok && c.DeferredExpr == nil && c.ParamMarker == nil { - if isTrue, err := c.Value.ToBool(p.SCtx().GetSessionVars().StmtCtx.TypeCtx()); err == nil { + if c, ok := one.(*expression.Constant); ok { + cVal, ok := c.GetValueWithoutOverOptimization(p.SCtx().GetExprCtx()) + if !ok { + continue + } + if isTrue, err := cVal.ToBool(p.SCtx().GetSessionVars().StmtCtx.TypeCtx()); err == nil { if isTrue == 0 { // c is false opt.InnerIsFalse = true diff --git a/pkg/planner/core/memtable_predicate_extractor.go b/pkg/planner/core/memtable_predicate_extractor.go index bdaa9f6b97912..2a87129519290 100644 --- a/pkg/planner/core/memtable_predicate_extractor.go +++ b/pkg/planner/core/memtable_predicate_extractor.go @@ -73,7 +73,7 @@ func (extractHelper) extractColInConsExpr(ctx base.PlanContext, extractCols map[ if !ok || constant.DeferredExpr != nil { return "", nil } - v := constant.Value + v, _ := constant.GetValue() if constant.ParamMarker != nil { v = constant.ParamMarker.GetUserVar(ctx.GetExprCtx().GetEvalCtx()) } @@ -155,7 +155,7 @@ func (helper *extractHelper) extractColBinaryOpConsExpr(ctx base.PlanContext, ex if !ok || constant.DeferredExpr != nil { return "", nil } - v := constant.Value + v, _ := constant.GetValue() if constant.ParamMarker != nil { v = constant.ParamMarker.GetUserVar(ctx.GetExprCtx().GetEvalCtx()) } diff --git a/pkg/planner/core/optimizer.go b/pkg/planner/core/optimizer.go index 3816116e9052f..0016be97c225a 100644 --- a/pkg/planner/core/optimizer.go +++ b/pkg/planner/core/optimizer.go @@ -576,8 +576,12 @@ func rewriteTableScanAndAggArgs(physicalTableScan *PhysicalTableScan, aggFuncs [ if !ok { return } + constExprVal, ok := constExpr.GetValueWithoutOverOptimization(physicalTableScan.SCtx().GetExprCtx()) + if !ok { + return + } // count(null) shouldn't be rewritten - if constExpr.Value.IsNull() { + if constExprVal.IsNull() { continue } aggFunc.Args[0] = arg diff --git a/pkg/planner/core/rule_column_pruning.go b/pkg/planner/core/rule_column_pruning.go index 1dfbd7cdc375c..c92aa2facc4bd 100644 --- a/pkg/planner/core/rule_column_pruning.go +++ b/pkg/planner/core/rule_column_pruning.go @@ -28,6 +28,7 @@ import ( "github.com/pingcap/tidb/pkg/planner/util/fixcontrol" "github.com/pingcap/tidb/pkg/planner/util/optimizetrace" "github.com/pingcap/tidb/pkg/planner/util/optimizetrace/logicaltrace" + "github.com/pingcap/tidb/pkg/types" ) type columnPruner struct { @@ -550,7 +551,7 @@ func addConstOneForEmptyProjection(p base.LogicalPlan) { RetType: constOne.GetType(p.SCtx().GetExprCtx().GetEvalCtx()), }) proj.Exprs = append(proj.Exprs, &expression.Constant{ - Value: constOne.Value, + Value: types.NewDatum(1), RetType: constOne.GetType(p.SCtx().GetExprCtx().GetEvalCtx()), }) } diff --git a/pkg/planner/core/rule_derive_topn_from_window.go b/pkg/planner/core/rule_derive_topn_from_window.go index b336dae1ad7a6..f824e0e750b01 100644 --- a/pkg/planner/core/rule_derive_topn_from_window.go +++ b/pkg/planner/core/rule_derive_topn_from_window.go @@ -86,7 +86,7 @@ func windowIsTopN(p *LogicalSelection) (bool, uint64) { } // Check if filter is column < constant or column <= constant. If it is in this form find column and constant. - column, limitValue := expression.FindUpperBound(p.Conditions[0]) + column, limitValue := expression.FindUpperBound(p.SCtx().GetExprCtx(), p.Conditions[0]) if column == nil || limitValue <= 0 { return false, 0 } diff --git a/pkg/planner/core/rule_partition_processor.go b/pkg/planner/core/rule_partition_processor.go index 0350591926d9f..bc0421f0fce69 100644 --- a/pkg/planner/core/rule_partition_processor.go +++ b/pkg/planner/core/rule_partition_processor.go @@ -558,10 +558,12 @@ func newListPartitionPruner(ctx base.PlanContext, tbl table.Table, partitionName func (l *listPartitionPruner) locatePartition(cond expression.Expression) (tables.ListPartitionLocation, bool, error) { switch sf := cond.(type) { case *expression.Constant: - b, err := sf.Value.ToBool(l.ctx.GetSessionVars().StmtCtx.TypeCtx()) - if err == nil && b == 0 { - // A constant false expression. - return nil, false, nil + if val, ok := sf.GetValueWithoutOverOptimization(l.ctx.GetExprCtx()); ok { + b, err := val.ToBool(l.ctx.GetSessionVars().StmtCtx.TypeCtx()) + if err == nil && b == 0 { + // A constant false expression. + return nil, false, nil + } } case *expression.ScalarFunction: switch sf.FuncName.L { @@ -1107,8 +1109,9 @@ func minCmp(ctx base.PlanContext, lowVal []types.Datum, columnsPruner *rangeColu // Not a constant, pruning not possible, so value is considered less than all partitions return true } + conVal, _ := con.GetValueWithoutOverOptimization(ctx.GetExprCtx()) // Add Null as point here? - cmp, err := con.Value.Compare(ctx.GetSessionVars().StmtCtx.TypeCtx(), &lowVal[j], comparer[j]) + cmp, err := conVal.Compare(ctx.GetSessionVars().StmtCtx.TypeCtx(), &lowVal[j], comparer[j]) if err != nil { *gotError = true } @@ -1145,23 +1148,25 @@ func minCmp(ctx base.PlanContext, lowVal []types.Datum, columnsPruner *rangeColu return true } if con, ok := (*conExpr).(*expression.Constant); ok && col != nil { + // `conExpr` comes from table definition, so it's always fine to get the value. + conVal, _ := con.GetValue() switch col.RetType.EvalType() { case types.ETInt: if mysql.HasUnsignedFlag(col.RetType.GetFlag()) { - if con.Value.GetUint64() == 0 { + if conVal.GetUint64() == 0 { return false } } else { - if con.Value.GetInt64() == types.IntergerSignedLowerBound(col.GetStaticType().GetType()) { + if conVal.GetInt64() == types.IntergerSignedLowerBound(col.GetStaticType().GetType()) { return false } } case types.ETDatetime: - if con.Value.GetMysqlTime().IsZero() { + if conVal.GetMysqlTime().IsZero() { return false } case types.ETString: - if len(con.Value.GetString()) == 0 { + if len(conVal.GetString()) == 0 { return false } } @@ -1186,8 +1191,13 @@ func maxCmp(ctx base.PlanContext, hiVal []types.Datum, columnsPruner *rangeColum // Not a constant, include every partition, i.e. value is not less than any partition return false } + conVal, ok := con.GetValueWithoutOverOptimization(ctx.GetExprCtx()) + if !ok { + // Not a constant. Maybe it's parameter or deferred function. + return false + } // Add Null as point here? - cmp, err := con.Value.Compare(ctx.GetSessionVars().StmtCtx.TypeCtx(), &hiVal[j], comparer[j]) + cmp, err := conVal.Compare(ctx.GetSessionVars().StmtCtx.TypeCtx(), &hiVal[j], comparer[j]) if err != nil { *gotError = true // error pushed, we will still use the cmp value @@ -1324,9 +1334,11 @@ type rangePruner struct { func (p *rangePruner) partitionRangeForExpr(sctx base.PlanContext, expr expression.Expression) (start int, end int, ok bool) { if constExpr, ok := expr.(*expression.Constant); ok { - if b, err := constExpr.Value.ToBool(sctx.GetSessionVars().StmtCtx.TypeCtx()); err == nil && b == 0 { - // A constant false expression. - return 0, 0, true + if constVal, ok := constExpr.GetValueWithoutOverOptimization(sctx.GetExprCtx()); ok { + if b, err := constVal.ToBool(sctx.GetSessionVars().StmtCtx.TypeCtx()); err == nil && b == 0 { + // A constant false expression. + return 0, 0, true + } } } @@ -1364,7 +1376,11 @@ func partitionRangeColumnForInExpr(sctx base.PlanContext, args []expression.Expr if !ok { return pruner.fullRange() } - switch constExpr.Value.Kind() { + constExprVal, ok := constExpr.GetValueWithoutOverOptimization(sctx.GetExprCtx()) + if !ok { + return pruner.fullRange() + } + switch constExprVal.Kind() { case types.KindInt64, types.KindUint64, types.KindMysqlTime, types.KindString: // for safety, only support string,int and datetime now case types.KindNull: result = append(result, partitionRange{0, 1}) @@ -1401,7 +1417,11 @@ func partitionRangeForInExpr(sctx base.PlanContext, args []expression.Expression if !ok { return pruner.fullRange() } - if constExpr.Value.Kind() == types.KindNull { + constExprVal, ok := constExpr.GetValueWithoutOverOptimization(sctx.GetExprCtx()) + if !ok { + return pruner.fullRange() + } + if constExprVal.Kind() == types.KindNull { result = append(result, partitionRange{0, 1}) continue } diff --git a/pkg/planner/core/rule_predicate_push_down.go b/pkg/planner/core/rule_predicate_push_down.go index 034e2bd91a591..1b4f24e2db599 100644 --- a/pkg/planner/core/rule_predicate_push_down.go +++ b/pkg/planner/core/rule_predicate_push_down.go @@ -525,10 +525,13 @@ func Conds2TableDual(p base.LogicalPlan, conds []expression.Expression) base.Log if expression.MaybeOverOptimized4PlanCache(p.SCtx().GetExprCtx(), []expression.Expression{con}) { return nil } - if isTrue, err := con.Value.ToBool(sc.TypeCtxOrDefault()); (err == nil && isTrue == 0) || con.Value.IsNull() { - dual := LogicalTableDual{}.Init(p.SCtx(), p.QueryBlockOffset()) - dual.SetSchema(p.Schema()) - return dual + conVal, ok := con.GetValueWithoutOverOptimization(p.SCtx().GetExprCtx()) + if ok { + if isTrue, err := conVal.ToBool(sc.TypeCtxOrDefault()); (err == nil && isTrue == 0) || conVal.IsNull() { + dual := LogicalTableDual{}.Init(p.SCtx(), p.QueryBlockOffset()) + dual.SetSchema(p.Schema()) + return dual + } } return nil } @@ -547,8 +550,11 @@ func DeleteTrueExprs(p base.LogicalPlan, conds []expression.Expression) []expres continue } sc := p.SCtx().GetSessionVars().StmtCtx - if isTrue, err := con.Value.ToBool(sc.TypeCtx()); err == nil && isTrue == 1 { - continue + conVal, ok := con.GetValueWithoutOverOptimization(p.SCtx().GetExprCtx()) + if ok { + if isTrue, err := conVal.ToBool(sc.TypeCtx()); err == nil && isTrue == 1 { + continue + } } newConds = append(newConds, cond) } diff --git a/pkg/planner/core/rule_predicate_simplification.go b/pkg/planner/core/rule_predicate_simplification.go index 3f4c948c034fc..625b5bd4dfce1 100644 --- a/pkg/planner/core/rule_predicate_simplification.go +++ b/pkg/planner/core/rule_predicate_simplification.go @@ -79,16 +79,20 @@ func updateInPredicate(ctx base.PlanContext, inPredicate expression.Expression, return inPredicate, true } v := inPredicate.(*expression.ScalarFunction) - notEQValue := notEQPredicate.(*expression.ScalarFunction).GetArgs()[1].(*expression.Constant) + notEQCon := notEQPredicate.(*expression.ScalarFunction).GetArgs()[1].(*expression.Constant) + notEQConVal, ok := notEQCon.GetValueWithoutOverOptimization(ctx.GetExprCtx()) // do not simplify != NULL since it is always false. - if notEQValue.Value.IsNull() { + if notEQConVal.IsNull() { + return inPredicate, true + } + if !ok { return inPredicate, true } newValues := make([]expression.Expression, 0, len(v.GetArgs())) var lastValue *expression.Constant for _, element := range v.GetArgs() { value, valueOK := element.(*expression.Constant) - redundantValue := valueOK && value.Equal(ctx.GetExprCtx().GetEvalCtx(), notEQValue) + redundantValue := valueOK && value.Equal(ctx.GetExprCtx().GetEvalCtx(), notEQCon) if !redundantValue { newValues = append(newValues, element) } diff --git a/pkg/planner/core/scalar_subq_expression.go b/pkg/planner/core/scalar_subq_expression.go index bf27085ced695..ce45d1acdc57a 100644 --- a/pkg/planner/core/scalar_subq_expression.go +++ b/pkg/planner/core/scalar_subq_expression.go @@ -98,7 +98,10 @@ func (s *ScalarSubQueryExpr) selfEvaluate() error { s.Constant = *expression.NewNull() return err } - s.Constant.Value = *colVal + s.Constant = expression.Constant{ + Value: *colVal, + RetType: s.Constant.RetType, + } s.evaled = true return nil } diff --git a/pkg/planner/util/null_misc.go b/pkg/planner/util/null_misc.go index 2722244826a82..4168df5996d1e 100644 --- a/pkg/planner/util/null_misc.go +++ b/pkg/planner/util/null_misc.go @@ -107,12 +107,18 @@ func isNullRejectedSimpleExpr(ctx context.PlanContext, schema *expression.Schema sc := ctx.GetSessionVars().StmtCtx result := expression.EvaluateExprWithNull(exprCtx, schema, expr) x, ok := result.(*expression.Constant) - if ok { - if x.Value.IsNull() { - return true - } else if isTrue, err := x.Value.ToBool(sc.TypeCtxOrDefault()); err == nil && isTrue == 0 { - return true - } + if !ok { + return false + } + xVal, ok := x.GetValueWithoutOverOptimization(ctx.GetExprCtx()) + if !ok { + return false + } + if xVal.IsNull() { + return true + } + if isTrue, err := xVal.ToBool(sc.TypeCtxOrDefault()); err == nil && isTrue == 0 { + return true } return false } diff --git a/pkg/util/ranger/checker.go b/pkg/util/ranger/checker.go index 5301ba18b3845..ec76de982ccbf 100644 --- a/pkg/util/ranger/checker.go +++ b/pkg/util/ranger/checker.go @@ -150,10 +150,14 @@ func (c *conditionChecker) checkLikeFunc(scalar *expression.ScalarFunction) (isA if !ok { return false, true } - if pattern.Value.IsNull() { + patternVal, ok := pattern.GetValue() + if !ok { + return false, true + } + if patternVal.IsNull() { return false, true } - patternStr, err := pattern.Value.ToString() + patternStr, err := patternVal.ToString() if err != nil { return false, true } @@ -170,7 +174,11 @@ func (c *conditionChecker) checkLikeFunc(scalar *expression.ScalarFunction) (isA if len(patternStr) == 0 { return true, likeFuncReserve } - escape := byte(scalar.GetArgs()[2].(*expression.Constant).Value.GetInt64()) + escapeVal, ok := scalar.GetArgs()[2].(*expression.Constant).GetValue() + if !ok { + return false, true + } + escape := byte(escapeVal.GetInt64()) for i := 0; i < len(patternStr); i++ { if patternStr[i] == escape { i++ diff --git a/pkg/util/ranger/detacher.go b/pkg/util/ranger/detacher.go index 2272612badf79..7bbd417dce9bb 100644 --- a/pkg/util/ranger/detacher.go +++ b/pkg/util/ranger/detacher.go @@ -599,12 +599,8 @@ func allEqOrIn(expr expression.Expression) bool { func extractValueInfo(expr expression.Expression) *valueInfo { if f, ok := expr.(*expression.ScalarFunction); ok && (f.FuncName.L == ast.EQ || f.FuncName.L == ast.NullEQ) { getValueInfo := func(c *expression.Constant) *valueInfo { - mutable := c.ParamMarker != nil || c.DeferredExpr != nil - var value *types.Datum - if !mutable { - value = &c.Value - } - return &valueInfo{value, mutable} + value, ok := c.GetValue() + return &valueInfo{&value, !ok} } if c, ok := f.GetArgs()[0].(*expression.Constant); ok { return getValueInfo(c) diff --git a/tests/integrationtest/r/planner/core/integration.result b/tests/integrationtest/r/planner/core/integration.result index c642d678c258f..0ca56b3a86f86 100644 --- a/tests/integrationtest/r/planner/core/integration.result +++ b/tests/integrationtest/r/planner/core/integration.result @@ -2758,20 +2758,16 @@ prepare stmt from 'select id from t1_no_idx where col_bit = ?'; set @a = 0x3135; execute stmt using @a; id -1 set @a = 0x0F; execute stmt using @a; id -2 prepare stmt from 'select id from t1_no_idx where col_bit in (?)'; set @a = 0x3135; execute stmt using @a; id -1 set @a = 0x0F; execute stmt using @a; id -2 drop table if exists t2_idx; create table t2_idx(id int, col_bit bit(16), key(col_bit)); insert into t2_idx values(1, 0x3135); @@ -2780,20 +2776,16 @@ prepare stmt from 'select id from t2_idx where col_bit = ?'; set @a = 0x3135; execute stmt using @a; id -1 set @a = 0x0F; execute stmt using @a; id -2 prepare stmt from 'select id from t2_idx where col_bit in (?)'; set @a = 0x3135; execute stmt using @a; id -1 set @a = 0x0F; execute stmt using @a; id -2 drop table if exists t_varchar; create table t_varchar(id int, col_varchar varchar(100), key(col_varchar)); insert into t_varchar values(1, '15'); @@ -2810,12 +2802,10 @@ prepare stmt from 'select id from t1_no_idx where col_bit = ?'; set @a = 0b11000100110101; execute stmt using @a; id -1 prepare stmt from 'select id from t1_no_idx where col_bit in (?)'; set @a = 0b11000100110101; execute stmt using @a; id -1 drop table if exists t; create table t (c1 float, c2 int, c3 int, primary key (c1) /*T![clustered_index] CLUSTERED */, key idx_1 (c2), key idx_2 (c3)); insert into t values(1.0,1,2),(2.0,2,1),(3.0,1,1),(4.0,2,2); @@ -4312,3 +4302,12 @@ case when ( t.c0 in (t.c0, cast((cast(1 as unsigned) - cast(t.c1 as signed)) as char)) ) then 1 else 2 end; 1 +drop table if exists t; +create table t (v bigint); +prepare stmt5 from 'select * from t where v = -?;'; +set @arg=1; +execute stmt5 using @arg; +v +set @arg=-9223372036854775808; +execute stmt5 using @arg; +v diff --git a/tests/integrationtest/t/planner/core/integration.test b/tests/integrationtest/t/planner/core/integration.test index 6283261333f44..18d125746c1f0 100644 --- a/tests/integrationtest/t/planner/core/integration.test +++ b/tests/integrationtest/t/planner/core/integration.test @@ -2376,4 +2376,13 @@ create table t (col TEXT); select 1 from (select t.col as c0, 46578369 as c1 from t) as t where case when ( t.c0 in (t.c0, cast((cast(1 as unsigned) - cast(t.c1 as signed)) as char)) - ) then 1 else 2 end; \ No newline at end of file + ) then 1 else 2 end; + +# TestIssue53504 +drop table if exists t; +create table t (v bigint); +prepare stmt5 from 'select * from t where v = -?;'; +set @arg=1; +execute stmt5 using @arg; +set @arg=-9223372036854775808; +execute stmt5 using @arg;