Skip to content

Commit dd64d25

Browse files
authored
lightining,expression: support user variable for BuildSimpleExpr and remove PlanContext dependency in lightning context (#55617)
ref #53388
1 parent 055716a commit dd64d25

File tree

9 files changed

+153
-64
lines changed

9 files changed

+153
-64
lines changed

pkg/executor/importer/import.go

+37
Original file line numberDiff line numberDiff line change
@@ -1324,6 +1324,43 @@ func (e *LoadDataController) CreateColAssignExprs(planCtx planctx.PlanContext) (
13241324
return res, allWarnings, nil
13251325
}
13261326

1327+
// CreateColAssignSimpleExprs creates the column assignment expressions using `expression.BuildContext`.
1328+
// This method does not support:
1329+
// - Subquery
1330+
// - System Variables (e.g. `@@tidb_enable_async_commit`)
1331+
// - Window functions
1332+
// - Aggregate functions
1333+
// - Other special functions used in some specified queries such as `GROUPING`, `VALUES` ...
1334+
func (e *LoadDataController) CreateColAssignSimpleExprs(ctx expression.BuildContext) (_ []expression.Expression, _ []contextutil.SQLWarn, retErr error) {
1335+
var (
1336+
i int
1337+
assign *ast.Assignment
1338+
)
1339+
1340+
// TODO(lance6716): indeterministic function should also return error
1341+
defer tidbutil.Recover("load-data/import-into", "CreateColAssignExprs", func() {
1342+
retErr = errors.Errorf("can't use function at SET index %d", i)
1343+
}, false)
1344+
1345+
e.colAssignMu.Lock()
1346+
defer e.colAssignMu.Unlock()
1347+
res := make([]expression.Expression, 0, len(e.ColumnAssignments))
1348+
var allWarnings []contextutil.SQLWarn
1349+
for i, assign = range e.ColumnAssignments {
1350+
newExpr, err := expression.BuildSimpleExpr(ctx, assign.Expr)
1351+
// col assign expr warnings is static, we should generate it for each row processed.
1352+
// so we save it and clear it here.
1353+
if ctx.GetEvalCtx().WarningCount() > 0 {
1354+
allWarnings = append(allWarnings, ctx.GetEvalCtx().TruncateWarnings(0)...)
1355+
}
1356+
if err != nil {
1357+
return nil, nil, err
1358+
}
1359+
res = append(res, newExpr)
1360+
}
1361+
return res, allWarnings, nil
1362+
}
1363+
13271364
func (e *LoadDataController) getBackendWorkerConcurrency() int {
13281365
// suppose cpu:mem ratio 1:2(true in most case), and we assign 1G per concurrency,
13291366
// so we can use 2 * threadCnt as concurrency. write&ingest step is mostly

pkg/executor/importer/kv_encode.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ func NewTableKVEncoder(
6464
}
6565
// we need a non-nil TxnCtx to avoid panic when evaluating set clause
6666
baseKVEncoder.SessionCtx.SetTxnCtxNotNil()
67-
colAssignExprs, _, err := ti.CreateColAssignExprs(baseKVEncoder.SessionCtx.GetPlanCtx())
67+
colAssignExprs, _, err := ti.CreateColAssignSimpleExprs(baseKVEncoder.SessionCtx.GetExprCtx())
6868
if err != nil {
6969
return nil, err
7070
}

pkg/expression/expression.go

+2-3
Original file line numberDiff line numberDiff line change
@@ -103,11 +103,10 @@ func WithCastExprTo(targetFt *types.FieldType) BuildOption {
103103
// This function is used to build some "simple" expressions with limited context.
104104
// The below expressions are not supported:
105105
// - Subquery
106-
// - Param marker (e.g. `?`)
107-
// - Variable (e.g. `@a`)
106+
// - System Variables (e.g. `@tidb_enable_async_commit`)
108107
// - Window functions
109108
// - Aggregate functions
110-
// - Other special functions such as `GROUPING`
109+
// - Other special functions used in some specified queries such as `GROUPING`, `VALUES` ...
111110
//
112111
// If you want to build a more complex expression, you should use `EvalAstExprWithPlanCtx` or `RewriteAstExprWithPlanCtx`
113112
// in `github.com/pingcap/tidb/pkg/planner/util`. They are more powerful but need planner context to build expressions.

pkg/lightning/backend/kv/BUILD.bazel

-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ go_library(
2929
"//pkg/parser/model",
3030
"//pkg/parser/mysql",
3131
"//pkg/planner/context",
32-
"//pkg/planner/contextimpl",
3332
"//pkg/sessionctx",
3433
"//pkg/sessionctx/variable",
3534
"//pkg/table",

pkg/lightning/backend/kv/session.go

-21
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ import (
3737
"github.com/pingcap/tidb/pkg/lightning/manual"
3838
"github.com/pingcap/tidb/pkg/parser/model"
3939
planctx "github.com/pingcap/tidb/pkg/planner/context"
40-
planctximpl "github.com/pingcap/tidb/pkg/planner/contextimpl"
4140
"github.com/pingcap/tidb/pkg/sessionctx"
4241
"github.com/pingcap/tidb/pkg/sessionctx/variable"
4342
tbctx "github.com/pingcap/tidb/pkg/table/context"
@@ -285,11 +284,6 @@ func (*transaction) MayFlush() error {
285284
return nil
286285
}
287286

288-
type planCtxImpl struct {
289-
*session
290-
*planctximpl.PlanCtxExtendedImpl
291-
}
292-
293287
// session is a trimmed down Session type which only wraps our own trimmed-down
294288
// transaction type and provides the session variables to the TiDB library
295289
// optimized for Lightning.
@@ -301,7 +295,6 @@ type session struct {
301295
txn transaction
302296
Vars *variable.SessionVars
303297
exprCtx *exprctximpl.SessionExprContext
304-
planctx *planCtxImpl
305298
tblctx *tbctximpl.TableContextImpl
306299
// currently, we only set `CommonAddRecordCtx`
307300
values map[fmt.Stringer]any
@@ -358,10 +351,6 @@ func newSession(options *encode.SessionOptions, logger log.Logger) *session {
358351
vars.TxnCtx = nil
359352
s.Vars = vars
360353
s.exprCtx = exprctximpl.NewSessionExprContext(s)
361-
s.planctx = &planCtxImpl{
362-
session: s,
363-
PlanCtxExtendedImpl: planctximpl.NewPlanCtxExtendedImpl(s),
364-
}
365354
s.tblctx = tbctximpl.NewTableContextImpl(s)
366355
s.txn.kvPairs = &Pairs{}
367356

@@ -378,11 +367,6 @@ func (se *session) GetSessionVars() *variable.SessionVars {
378367
return se.Vars
379368
}
380369

381-
// GetPlanCtx returns the PlanContext.
382-
func (se *session) GetPlanCtx() planctx.PlanContext {
383-
return se.planctx
384-
}
385-
386370
// GetExprCtx returns the expression context of the session.
387371
func (se *session) GetExprCtx() exprctx.ExprContext {
388372
return se.exprCtx
@@ -443,11 +427,6 @@ func (s *Session) GetTableCtx() tbctx.MutateContext {
443427
return s.sctx.tblctx
444428
}
445429

446-
// GetPlanCtx returns the context for planner.
447-
func (s *Session) GetPlanCtx() planctx.PlanContext {
448-
return s.sctx.planctx
449-
}
450-
451430
// TakeKvPairs returns the current Pairs and resets the buffer.
452431
func (s *Session) TakeKvPairs() *Pairs {
453432
memBuf := &s.sctx.txn.MemBuf

pkg/planner/core/BUILD.bazel

+2
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ go_library(
105105
"//pkg/expression",
106106
"//pkg/expression/aggregation",
107107
"//pkg/expression/context",
108+
"//pkg/expression/contextopt",
108109
"//pkg/infoschema",
109110
"//pkg/infoschema/context",
110111
"//pkg/kv",
@@ -272,6 +273,7 @@ go_test(
272273
"//pkg/expression",
273274
"//pkg/expression/aggregation",
274275
"//pkg/expression/context",
276+
"//pkg/expression/contextopt",
275277
"//pkg/expression/contextstatic",
276278
"//pkg/infoschema",
277279
"//pkg/kv",

pkg/planner/core/expression_rewriter.go

+61-37
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ import (
2626
"github.com/pingcap/errors"
2727
"github.com/pingcap/tidb/pkg/expression"
2828
"github.com/pingcap/tidb/pkg/expression/aggregation"
29+
exprctx "github.com/pingcap/tidb/pkg/expression/context"
30+
"github.com/pingcap/tidb/pkg/expression/contextopt"
2931
"github.com/pingcap/tidb/pkg/infoschema"
3032
infoschemactx "github.com/pingcap/tidb/pkg/infoschema/context"
3133
"github.com/pingcap/tidb/pkg/parser/ast"
@@ -505,17 +507,20 @@ func (er *expressionRewriter) buildSubquery(ctx context.Context, planCtx *exprRe
505507
return np, hintFlags, nil
506508
}
507509

508-
func (er *expressionRewriter) requirePlanCtx(inNode ast.Node) (ctx *exprRewriterPlanCtx, err error) {
510+
func (er *expressionRewriter) requirePlanCtx(inNode ast.Node, detail string) (ctx *exprRewriterPlanCtx, err error) {
509511
if ctx = er.planCtx; ctx == nil {
510-
err = errors.Errorf("node '%T' is not allowed when building an expression without planner", inNode)
512+
if detail != "" {
513+
detail = ", " + detail
514+
}
515+
err = errors.Errorf("planCtx is required when rewriting node: '%T'%s", inNode, detail)
511516
}
512517
return
513518
}
514519

515520
// Enter implements Visitor interface.
516521
func (er *expressionRewriter) Enter(inNode ast.Node) (ast.Node, bool) {
517522
enterWithPlanCtx := func(fn func(*exprRewriterPlanCtx) (ast.Node, bool)) (ast.Node, bool) {
518-
planCtx, err := er.requirePlanCtx(inNode)
523+
planCtx, err := er.requirePlanCtx(inNode, "")
519524
if err != nil {
520525
er.err = err
521526
return inNode, true
@@ -1416,8 +1421,8 @@ func (er *expressionRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok
14161421
inNode = er.preprocess(inNode)
14171422
}
14181423

1419-
withPlanCtx := func(fn func(*exprRewriterPlanCtx)) {
1420-
planCtx, err := er.requirePlanCtx(inNode)
1424+
withPlanCtx := func(fn func(*exprRewriterPlanCtx), detail string) {
1425+
planCtx, err := er.requirePlanCtx(inNode, detail)
14211426
if err != nil {
14221427
er.err = err
14231428
return
@@ -1448,15 +1453,19 @@ func (er *expressionRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok
14481453
case *driver.ParamMarkerExpr:
14491454
er.toParamMarker(v)
14501455
case *ast.VariableExpr:
1451-
withPlanCtx(func(planCtx *exprRewriterPlanCtx) {
1452-
er.rewriteVariable(planCtx, v)
1453-
})
1456+
if v.IsSystem {
1457+
withPlanCtx(func(planCtx *exprRewriterPlanCtx) {
1458+
er.rewriteSystemVariable(planCtx, v)
1459+
}, "accessing system variable requires plan context")
1460+
} else {
1461+
er.rewriteUserVariable(v)
1462+
}
14541463
case *ast.FuncCallExpr:
14551464
switch v.FnName.L {
14561465
case ast.Grouping:
14571466
withPlanCtx(func(planCtx *exprRewriterPlanCtx) {
14581467
er.funcCallToExpressionWithPlanCtx(planCtx, v)
1459-
})
1468+
}, "grouping function requires plan context")
14601469
default:
14611470
if _, ok := expression.TryFoldFunctions[v.FnName.L]; ok {
14621471
er.tryFoldCounter--
@@ -1536,7 +1545,7 @@ func (er *expressionRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok
15361545
case *ast.PositionExpr:
15371546
withPlanCtx(func(planCtx *exprRewriterPlanCtx) {
15381547
er.positionToScalarFunc(planCtx, v)
1539-
})
1548+
}, "")
15401549
case *ast.IsNullExpr:
15411550
er.isNullToExpression(v)
15421551
case *ast.IsTruthExpr:
@@ -1651,37 +1660,52 @@ func (er *expressionRewriter) useCache() bool {
16511660
return er.sctx.IsUseCache()
16521661
}
16531662

1654-
func (er *expressionRewriter) rewriteVariable(planCtx *exprRewriterPlanCtx, v *ast.VariableExpr) {
1663+
func (er *expressionRewriter) rewriteUserVariable(v *ast.VariableExpr) {
16551664
stkLen := len(er.ctxStack)
16561665
name := strings.ToLower(v.Name)
1657-
sessionVars := planCtx.builder.ctx.GetSessionVars()
1658-
if !v.IsSystem {
1659-
if v.Value != nil {
1660-
tp := er.ctxStack[stkLen-1].GetType(er.sctx.GetEvalCtx())
1661-
er.ctxStack[stkLen-1], er.err = er.newFunction(ast.SetVar, tp,
1662-
expression.DatumToConstant(types.NewDatum(name), mysql.TypeString, 0),
1663-
er.ctxStack[stkLen-1])
1664-
er.ctxNameStk[stkLen-1] = types.EmptyName
1665-
// Store the field type of the variable into SessionVars.UserVarTypes.
1666-
// Normally we can infer the type from SessionVars.User, but we need SessionVars.UserVarTypes when
1667-
// GetVar has not been executed to fill the SessionVars.Users.
1668-
sessionVars.SetUserVarType(name, tp)
1669-
return
1670-
}
1671-
tp, ok := sessionVars.GetUserVarType(name)
1672-
if !ok {
1673-
tp = types.NewFieldType(mysql.TypeVarString)
1674-
tp.SetFlen(mysql.MaxFieldVarCharLength)
1675-
}
1676-
f, err := er.newFunction(ast.GetVar, tp, expression.DatumToConstant(types.NewStringDatum(name), mysql.TypeString, 0))
1677-
if err != nil {
1678-
er.err = err
1679-
return
1680-
}
1681-
f.SetCoercibility(expression.CoercibilityImplicit)
1682-
er.ctxStackAppend(f, types.EmptyName)
1666+
evalCtx := er.sctx.GetEvalCtx()
1667+
if !evalCtx.GetOptionalPropSet().Contains(exprctx.OptPropSessionVars) {
1668+
er.err = errors.Errorf("rewriting user variable requires '%s' in evalCtx", exprctx.OptPropSessionVars.String())
1669+
return
1670+
}
1671+
1672+
sessionVars, err := contextopt.SessionVarsPropReader{}.GetSessionVars(evalCtx)
1673+
if err != nil {
1674+
er.err = err
1675+
return
1676+
}
1677+
1678+
intest.Assert(er.planCtx == nil || sessionVars == er.planCtx.builder.ctx.GetSessionVars())
1679+
1680+
if v.Value != nil {
1681+
tp := er.ctxStack[stkLen-1].GetType(er.sctx.GetEvalCtx())
1682+
er.ctxStack[stkLen-1], er.err = er.newFunction(ast.SetVar, tp,
1683+
expression.DatumToConstant(types.NewDatum(name), mysql.TypeString, 0),
1684+
er.ctxStack[stkLen-1])
1685+
er.ctxNameStk[stkLen-1] = types.EmptyName
1686+
// Store the field type of the variable into SessionVars.UserVarTypes.
1687+
// Normally we can infer the type from SessionVars.User, but we need SessionVars.UserVarTypes when
1688+
// GetVar has not been executed to fill the SessionVars.Users.
1689+
sessionVars.SetUserVarType(name, tp)
1690+
return
1691+
}
1692+
tp, ok := sessionVars.GetUserVarType(name)
1693+
if !ok {
1694+
tp = types.NewFieldType(mysql.TypeVarString)
1695+
tp.SetFlen(mysql.MaxFieldVarCharLength)
1696+
}
1697+
f, err := er.newFunction(ast.GetVar, tp, expression.DatumToConstant(types.NewStringDatum(name), mysql.TypeString, 0))
1698+
if err != nil {
1699+
er.err = err
16831700
return
16841701
}
1702+
f.SetCoercibility(expression.CoercibilityImplicit)
1703+
er.ctxStackAppend(f, types.EmptyName)
1704+
}
1705+
1706+
func (er *expressionRewriter) rewriteSystemVariable(planCtx *exprRewriterPlanCtx, v *ast.VariableExpr) {
1707+
name := strings.ToLower(v.Name)
1708+
sessionVars := planCtx.builder.ctx.GetSessionVars()
16851709
sysVar := variable.GetSysVar(name)
16861710
if sysVar == nil {
16871711
er.err = variable.ErrUnknownSystemVar.FastGenByArgs(name)

pkg/planner/core/expression_test.go

+42-1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"github.com/pingcap/tidb/pkg/domain"
2222
"github.com/pingcap/tidb/pkg/expression"
2323
"github.com/pingcap/tidb/pkg/expression/context"
24+
"github.com/pingcap/tidb/pkg/expression/contextopt"
2425
"github.com/pingcap/tidb/pkg/expression/contextstatic"
2526
"github.com/pingcap/tidb/pkg/parser"
2627
"github.com/pingcap/tidb/pkg/parser/ast"
@@ -502,11 +503,51 @@ func TestBuildExpression(t *testing.T) {
502503
require.Equal(t, types.KindInt64, v.Kind())
503504
require.Equal(t, int64(7), v.GetInt64())
504505

506+
// user variable needs required option
507+
_, err = buildExpr(t, ctx, "@a")
508+
require.EqualError(t, err, "rewriting user variable requires 'OptPropSessionVars' in evalCtx")
509+
_, err = buildExpr(t, ctx, "@a := 1")
510+
require.EqualError(t, err, "rewriting user variable requires 'OptPropSessionVars' in evalCtx")
511+
512+
// reading user var
513+
vars := variable.NewSessionVars(nil)
514+
vars.TimeZone = evalCtx.Location()
515+
vars.StmtCtx.SetTimeZone(vars.Location())
516+
evalCtx = evalCtx.Apply(contextstatic.WithOptionalProperty(
517+
contextopt.NewSessionVarsProvider(vars),
518+
))
519+
ctx = ctx.Apply(contextstatic.WithEvalCtx(evalCtx))
520+
vars.SetUserVarVal("a", types.NewStringDatum("abc"))
521+
getVarExpr, err := buildExpr(t, ctx, "@a")
522+
require.NoError(t, err)
523+
v, err = getVarExpr.Eval(evalCtx, chunk.Row{})
524+
require.NoError(t, err)
525+
require.Equal(t, types.KindString, v.Kind())
526+
require.Equal(t, "abc", v.GetString())
527+
528+
// writing user var
529+
expr, err = buildExpr(t, ctx, "@a := 'def'")
530+
require.NoError(t, err)
531+
v, err = expr.Eval(evalCtx, chunk.Row{})
532+
require.NoError(t, err)
533+
require.Equal(t, types.KindString, v.Kind())
534+
require.Equal(t, "def", v.GetString())
535+
v, err = getVarExpr.Eval(evalCtx, chunk.Row{})
536+
require.NoError(t, err)
537+
require.Equal(t, types.KindString, v.Kind())
538+
require.Equal(t, "def", v.GetString())
539+
505540
// should report error for default expr when source table not provided
506541
_, err = buildExpr(t, ctx, "default(b)", expression.WithInputSchemaAndNames(schema, names, nil))
507542
require.EqualError(t, err, "Unsupported expr *ast.DefaultExpr when source table not provided")
508543

509544
// subquery not supported
510545
_, err = buildExpr(t, ctx, "a + (select b from t)", expression.WithTableInfo("", tbl))
511-
require.EqualError(t, err, "node '*ast.SubqueryExpr' is not allowed when building an expression without planner")
546+
require.EqualError(t, err, "planCtx is required when rewriting node: '*ast.SubqueryExpr'")
547+
548+
// system variables are not supported
549+
_, err = buildExpr(t, ctx, "@@tidb_enable_async_commit")
550+
require.EqualError(t, err, "planCtx is required when rewriting node: '*ast.VariableExpr', accessing system variable requires plan context")
551+
_, err = buildExpr(t, ctx, "@@global.tidb_enable_async_commit")
552+
require.EqualError(t, err, "planCtx is required when rewriting node: '*ast.VariableExpr', accessing system variable requires plan context")
512553
}

pkg/sessionctx/variable/session.go

+8
Original file line numberDiff line numberDiff line change
@@ -727,6 +727,9 @@ type SessionVarsProvider interface {
727727
GetSessionVars() *SessionVars
728728
}
729729

730+
// SessionVars should implement `SessionVarsProvider`
731+
var _ SessionVarsProvider = &SessionVars{}
732+
730733
// SessionVars is to handle user-defined or global variables in the current session.
731734
type SessionVars struct {
732735
Concurrency
@@ -1657,6 +1660,11 @@ type SessionVars struct {
16571660
SharedLockPromotion bool
16581661
}
16591662

1663+
// GetSessionVars implements the `SessionVarsProvider` interface.
1664+
func (s *SessionVars) GetSessionVars() *SessionVars {
1665+
return s
1666+
}
1667+
16601668
// GetOptimizerFixControlMap returns the specified value of the optimizer fix control.
16611669
func (s *SessionVars) GetOptimizerFixControlMap() map[uint64]string {
16621670
return s.OptimizerFixControl

0 commit comments

Comments
 (0)