Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

expression: Move more methods from SessionVars to BuildContext #52440

Merged
merged 1 commit into from
Apr 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions pkg/expression/builtin_cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ func (c *castAsIntFunctionClass) getFunction(ctx BuildContext, args []Expression
if err != nil {
return nil, err
}
bf := newBaseBuiltinCastFunc(b, ctx.Value(inUnionCastContext) != nil)
bf := newBaseBuiltinCastFunc(b, ctx.IsInUnionCast())
if args[0].GetType().Hybrid() || IsBinaryLiteral(args[0]) {
sig = &builtinCastIntAsIntSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_CastIntAsInt)
Expand Down Expand Up @@ -171,7 +171,7 @@ func (c *castAsRealFunctionClass) getFunction(ctx BuildContext, args []Expressio
if err != nil {
return nil, err
}
bf := newBaseBuiltinCastFunc(b, ctx.Value(inUnionCastContext) != nil)
bf := newBaseBuiltinCastFunc(b, ctx.IsInUnionCast())
if IsBinaryLiteral(args[0]) {
sig = &builtinCastRealAsRealSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_CastRealAsReal)
Expand Down Expand Up @@ -226,7 +226,7 @@ func (c *castAsDecimalFunctionClass) getFunction(ctx BuildContext, args []Expres
if err != nil {
return nil, err
}
bf := newBaseBuiltinCastFunc(b, ctx.Value(inUnionCastContext) != nil)
bf := newBaseBuiltinCastFunc(b, ctx.IsInUnionCast())
if IsBinaryLiteral(args[0]) {
sig = &builtinCastDecimalAsDecimalSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_CastDecimalAsDecimal)
Expand Down Expand Up @@ -2052,10 +2052,10 @@ func CanImplicitEvalReal(expr Expression) bool {
// BuildCastFunction4Union build a implicitly CAST ScalarFunction from the Union
// Expression.
func BuildCastFunction4Union(ctx BuildContext, expr Expression, tp *types.FieldType) (res Expression) {
ctx.SetValue(inUnionCastContext, struct{}{})
defer func() {
ctx.SetValue(inUnionCastContext, nil)
}()
if !ctx.IsInUnionCast() {
ctx.SetInUnionCast(true)
defer ctx.SetInUnionCast(false)
}
return BuildCastFunction(ctx, expr, tp)
}

Expand Down
3 changes: 1 addition & 2 deletions pkg/expression/constant_fold.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,6 @@ func foldConstant(ctx BuildContext, expr Expression) (Expression, bool) {
}

args := x.GetArgs()
sc := ctx.GetSessionVars().StmtCtx
argIsConst := make([]bool, len(args))
hasNullArg := false
allConstArg := true
Expand All @@ -194,7 +193,7 @@ func foldConstant(ctx BuildContext, expr Expression) (Expression, bool) {
//
// NullEQ and ConcatWS are excluded, because they could have different value when the non-constant value is
// 1 or NULL. For example, concat_ws(NULL, NULL) gives NULL, but concat_ws(1, NULL) gives ''
if !hasNullArg || !sc.InNullRejectCheck || x.FuncName.L == ast.NullEQ || x.FuncName.L == ast.ConcatWS {
if !hasNullArg || !ctx.IsInNullRejectCheck() || x.FuncName.L == ast.NullEQ || x.FuncName.L == ast.ConcatWS {
return expr, isDeferredConst
}
constArgs := make([]Expression, len(args))
Expand Down
2 changes: 1 addition & 1 deletion pkg/expression/constant_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ func TestConstantFolding(t *testing.T) {
{
condition: func(ctx BuildContext) Expression {
expr := newFunction(ctx, ast.ConcatWS, newColumn(0), NewNull())
ctx.GetSessionVars().StmtCtx.InNullRejectCheck = true
ctx.SetInNullRejectCheck(true)
return expr
},
result: "concat_ws(cast(Column#0, var_string(20)), <nil>)",
Expand Down
15 changes: 10 additions & 5 deletions pkg/expression/context/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
package context

import (
"fmt"
"time"

"github.com/pingcap/tidb/pkg/errctx"
Expand Down Expand Up @@ -86,12 +85,18 @@ type BuildContext interface {
IsUseCache() bool
// SetSkipPlanCache sets to skip the plan cache and records the reason.
SetSkipPlanCache(reason error)
// AllocPlanColumnID allocates column id for plan.
AllocPlanColumnID() int64
// SetInNullRejectCheck sets the flag to indicate whether the expression is in null reject check.
SetInNullRejectCheck(in bool)
// IsInNullRejectCheck returns the flag to indicate whether the expression is in null reject check.
IsInNullRejectCheck() bool
// SetInUnionCast sets the flag to indicate whether the expression is in union cast.
SetInUnionCast(in bool)
// IsInUnionCast indicates whether executing in special cast context that negative unsigned num will be zero.
IsInUnionCast() bool
// GetSessionVars gets the session variables.
GetSessionVars() *variable.SessionVars
// Value returns the value associated with this context for key.
Value(key fmt.Stringer) any
// SetValue saves a value associated with this context for key.
SetValue(key fmt.Stringer, value any)
}

// ExprContext contains full context for expression building and evaluating.
Expand Down
28 changes: 28 additions & 0 deletions pkg/expression/contextimpl/sessionctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package contextimpl
import (
"context"
"math"
"sync/atomic"
"time"

"github.com/pingcap/tidb/pkg/errctx"
Expand Down Expand Up @@ -51,6 +52,8 @@ var _ exprctx.ExprContext = struct {
type ExprCtxExtendedImpl struct {
sctx sessionctx.Context
*SessionEvalContext
inNullRejectCheck atomic.Bool
inUnionCast atomic.Bool
}

// NewExprExtendedImpl creates a new ExprCtxExtendedImpl.
Expand Down Expand Up @@ -109,6 +112,31 @@ func (ctx *ExprCtxExtendedImpl) SetSkipPlanCache(reason error) {
ctx.sctx.GetSessionVars().StmtCtx.SetSkipPlanCache(reason)
}

// AllocPlanColumnID allocates column id for plan.
func (ctx *ExprCtxExtendedImpl) AllocPlanColumnID() int64 {
return ctx.sctx.GetSessionVars().AllocPlanColumnID()
}

// SetInNullRejectCheck sets whether the expression is in null reject check.
func (ctx *ExprCtxExtendedImpl) SetInNullRejectCheck(in bool) {
ctx.inNullRejectCheck.Store(in)
}

// IsInNullRejectCheck returns whether the expression is in null reject check.
func (ctx *ExprCtxExtendedImpl) IsInNullRejectCheck() bool {
return ctx.inNullRejectCheck.Load()
}

// SetInUnionCast sets the flag to indicate whether the expression is in union cast.
func (ctx *ExprCtxExtendedImpl) SetInUnionCast(in bool) {
ctx.inUnionCast.Store(in)
}

// IsInUnionCast indicates whether executing in special cast context that negative unsigned num will be zero.
func (ctx *ExprCtxExtendedImpl) IsInUnionCast() bool {
return ctx.inUnionCast.Load()
}

// GetWindowingUseHighPrecision determines whether to compute window operations without loss of precision.
// see https://dev.mysql.com/doc/refman/8.0/en/window-function-optimization.html for more details.
func (ctx *ExprCtxExtendedImpl) GetWindowingUseHighPrecision() bool {
Expand Down
29 changes: 29 additions & 0 deletions pkg/expression/contextimpl/sessionctx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ func TestSessionBuildContext(t *testing.T) {
require.True(t, evalCtx.GetOptionalPropSet().IsFull())
require.Same(t, ctx, evalCtx.Sctx())

// charset and collation
vars := ctx.GetSessionVars()
err := vars.SetSystemVar("character_set_connection", "gbk")
require.NoError(t, err)
Expand All @@ -265,17 +266,45 @@ func TestSessionBuildContext(t *testing.T) {
require.Equal(t, "gbk_chinese_ci", collate)
require.Equal(t, "utf8mb4_0900_ai_ci", impl.GetDefaultCollationForUTF8MB4())

// SysdateIsNow
vars.SysdateIsNow = true
require.True(t, impl.GetSysdateIsNow())

// NoopFuncsMode
vars.NoopFuncsMode = 2
require.Equal(t, 2, impl.GetNoopFuncsMode())

// Rng
vars.Rng = mathutil.NewWithSeed(123)
require.Same(t, vars.Rng, impl.Rng())

// PlanCache
vars.StmtCtx.UseCache = true
require.True(t, impl.IsUseCache())
impl.SetSkipPlanCache(errors.New("mockReason"))
require.False(t, impl.IsUseCache())

// Alloc column id
prevID := vars.PlanColumnID.Load()
colID := impl.AllocPlanColumnID()
require.Equal(t, colID, prevID+1)
colID = impl.AllocPlanColumnID()
require.Equal(t, colID, prevID+2)
vars.AllocPlanColumnID()
colID = impl.AllocPlanColumnID()
require.Equal(t, colID, prevID+4)

// InNullRejectCheck
require.False(t, impl.IsInNullRejectCheck())
impl.SetInNullRejectCheck(true)
require.True(t, impl.IsInNullRejectCheck())
impl.SetInNullRejectCheck(false)
require.False(t, impl.IsInNullRejectCheck())

// InUnionCast
require.False(t, impl.IsInUnionCast())
impl.SetInUnionCast(true)
require.True(t, impl.IsInUnionCast())
impl.SetInUnionCast(false)
require.False(t, impl.IsInUnionCast())
}
4 changes: 2 additions & 2 deletions pkg/expression/expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -865,7 +865,7 @@ func EvaluateExprWithNull(ctx BuildContext, schema *Schema, expr Expression) Exp
if MaybeOverOptimized4PlanCache(ctx, []Expression{expr}) {
ctx.SetSkipPlanCache(errors.NewNoStackError("%v affects null check"))
}
if ctx.GetSessionVars().StmtCtx.InNullRejectCheck {
if ctx.IsInNullRejectCheck() {
expr, _ = evaluateExprWithNullInNullRejectCheck(ctx, schema, expr)
return expr
}
Expand Down Expand Up @@ -1022,7 +1022,7 @@ func ColumnInfos2ColumnsAndNames(ctx BuildContext, dbName, tblName model.CIStr,
newCol := &Column{
RetType: col.FieldType.Clone(),
ID: col.ID,
UniqueID: ctx.GetSessionVars().AllocPlanColumnID(),
UniqueID: ctx.AllocPlanColumnID(),
Index: col.Offset,
OrigName: names[i].String(),
IsHidden: col.Hidden,
Expand Down
8 changes: 4 additions & 4 deletions pkg/planner/core/rule_predicate_push_down.go
Original file line number Diff line number Diff line change
Expand Up @@ -428,10 +428,10 @@ func isNullRejected(ctx PlanContext, schema *expression.Schema, expr expression.
return false
}
sc := ctx.GetSessionVars().StmtCtx
sc.InNullRejectCheck = true
defer func() {
sc.InNullRejectCheck = false
}()
if !exprCtx.IsInNullRejectCheck() {
exprCtx.SetInNullRejectCheck(true)
defer exprCtx.SetInNullRejectCheck(false)
}
for _, cond := range expression.SplitCNFItems(expr) {
if isNullRejectedSpecially(ctx, schema, expr) {
return true
Expand Down
1 change: 0 additions & 1 deletion pkg/sessionctx/stmtctx/stmtctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,6 @@ type StatementContext struct {
ForcePlanCache bool // force the optimizer to use plan cache even if there is risky optimization, see #49736.
CacheType PlanCacheType
BatchCheck bool
InNullRejectCheck bool
IgnoreExplainIDSuffix bool
MultiSchemaInfo *model.MultiSchemaInfo
// If the select statement was like 'select * from t as of timestamp ...' or in a stale read transaction
Expand Down