Skip to content

Commit

Permalink
expression: add expression.BuildContext to build expressions (#50662)
Browse files Browse the repository at this point in the history
close #50661
  • Loading branch information
lcwangchao authored Jan 29, 2024
1 parent afeabbb commit ff050bb
Show file tree
Hide file tree
Showing 34 changed files with 415 additions and 421 deletions.
4 changes: 2 additions & 2 deletions pkg/expression/aggregation/base_func.go
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ func (a *baseFuncDesc) WrapCastForAggArgs(ctx sessionctx.Context) {
if _, ok := noNeedCastAggFuncs[a.Name]; ok {
return
}
var castFunc func(ctx sessionctx.Context, expr expression.Expression) expression.Expression
var castFunc func(ctx expression.BuildContext, expr expression.Expression) expression.Expression
switch retTp := a.RetTp; retTp.EvalType() {
case types.ETInt:
castFunc = expression.WrapWithCastAsInt
Expand All @@ -416,7 +416,7 @@ func (a *baseFuncDesc) WrapCastForAggArgs(ctx sessionctx.Context) {
case types.ETDecimal:
castFunc = expression.WrapWithCastAsDecimal
case types.ETDatetime, types.ETTimestamp:
castFunc = func(ctx sessionctx.Context, expr expression.Expression) expression.Expression {
castFunc = func(ctx expression.BuildContext, expr expression.Expression) expression.Expression {
return expression.WrapWithCastAsTime(ctx, expr, retTp)
}
case types.ETDuration:
Expand Down
9 changes: 4 additions & 5 deletions pkg/expression/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ import (
"github.com/pingcap/tidb/pkg/parser/charset"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/parser/opcode"
"github.com/pingcap/tidb/pkg/sessionctx"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/chunk"
"github.com/pingcap/tidb/pkg/util/collate"
Expand Down Expand Up @@ -106,7 +105,7 @@ func adjustNullFlagForReturnType(funcName string, args []Expression, bf baseBuil
}
}

func newBaseBuiltinFunc(ctx sessionctx.Context, funcName string, args []Expression, tp *types.FieldType) (baseBuiltinFunc, error) {
func newBaseBuiltinFunc(ctx BuildContext, funcName string, args []Expression, tp *types.FieldType) (baseBuiltinFunc, error) {
if ctx == nil {
return baseBuiltinFunc{}, errors.New("unexpected nil session ctx")
}
Expand Down Expand Up @@ -164,7 +163,7 @@ func newReturnFieldTypeForBaseBuiltinFunc(funcName string, retType types.EvalTyp
// newBaseBuiltinFuncWithTp creates a built-in function signature with specified types of arguments and the return type of the function.
// argTps indicates the types of the args, retType indicates the return type of the built-in function.
// Every built-in function needs to be determined argTps and retType when we create it.
func newBaseBuiltinFuncWithTp(ctx sessionctx.Context, funcName string, args []Expression, retType types.EvalType, argTps ...types.EvalType) (bf baseBuiltinFunc, err error) {
func newBaseBuiltinFuncWithTp(ctx BuildContext, funcName string, args []Expression, retType types.EvalType, argTps ...types.EvalType) (bf baseBuiltinFunc, err error) {
if len(args) != len(argTps) {
panic("unexpected length of args and argTps")
}
Expand Down Expand Up @@ -222,7 +221,7 @@ func newBaseBuiltinFuncWithTp(ctx sessionctx.Context, funcName string, args []Ex
// argTps indicates the field types of the args, retType indicates the return type of the built-in function.
// newBaseBuiltinFuncWithTp and newBaseBuiltinFuncWithFieldTypes are essentially the same, but newBaseBuiltinFuncWithFieldTypes uses FieldType to cast args.
// If there are specific requirements for decimal/datetime/timestamp, newBaseBuiltinFuncWithFieldTypes should be used, such as if,ifnull and casewhen.
func newBaseBuiltinFuncWithFieldTypes(ctx sessionctx.Context, funcName string, args []Expression, retType types.EvalType, argTps ...*types.FieldType) (bf baseBuiltinFunc, err error) {
func newBaseBuiltinFuncWithFieldTypes(ctx BuildContext, funcName string, args []Expression, retType types.EvalType, argTps ...*types.FieldType) (bf baseBuiltinFunc, err error) {
if len(args) != len(argTps) {
panic("unexpected length of args and argTps")
}
Expand Down Expand Up @@ -549,7 +548,7 @@ func VerifyArgsWrapper(name string, l int) error {
// functionClass is the interface for a function which may contains multiple functions.
type functionClass interface {
// getFunction gets a function signature by the types and the counts of given arguments.
getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error)
getFunction(ctx BuildContext, args []Expression) (builtinFunc, error)
// verifyArgsByCount verifies the count of parameters.
verifyArgsByCount(l int) error
}
Expand Down
27 changes: 13 additions & 14 deletions pkg/expression/builtin_arithmetic.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import (

"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/parser/terror"
"github.com/pingcap/tidb/pkg/sessionctx"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/chunk"
"github.com/pingcap/tidb/pkg/util/mathutil"
Expand Down Expand Up @@ -103,7 +102,7 @@ func numericContextResultType(expr Expression) types.EvalType {

// setFlenDecimal4RealOrDecimal is called to set proper `flen` and `decimal` of return
// type according to the two input parameter's types.
func setFlenDecimal4RealOrDecimal(ctx sessionctx.Context, retTp *types.FieldType, arg0, arg1 Expression, isReal bool, isMultiply bool) {
func setFlenDecimal4RealOrDecimal(retTp *types.FieldType, arg0, arg1 Expression, isReal, isMultiply bool) {
a, b := arg0.GetType(), arg1.GetType()
if a.GetDecimal() != types.UnspecifiedLength && b.GetDecimal() != types.UnspecifiedLength {
retTp.SetDecimalUnderLimit(a.GetDecimal() + b.GetDecimal())
Expand Down Expand Up @@ -167,7 +166,7 @@ type arithmeticPlusFunctionClass struct {
baseFunctionClass
}

func (c *arithmeticPlusFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) {
func (c *arithmeticPlusFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) {
if err := c.verifyArgs(args); err != nil {
return nil, err
}
Expand All @@ -177,7 +176,7 @@ func (c *arithmeticPlusFunctionClass) getFunction(ctx sessionctx.Context, args [
if err != nil {
return nil, err
}
setFlenDecimal4RealOrDecimal(ctx, bf.tp, args[0], args[1], true, false)
setFlenDecimal4RealOrDecimal(bf.tp, args[0], args[1], true, false)
sig := &builtinArithmeticPlusRealSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_PlusReal)
return sig, nil
Expand All @@ -186,7 +185,7 @@ func (c *arithmeticPlusFunctionClass) getFunction(ctx sessionctx.Context, args [
if err != nil {
return nil, err
}
setFlenDecimal4RealOrDecimal(ctx, bf.tp, args[0], args[1], false, false)
setFlenDecimal4RealOrDecimal(bf.tp, args[0], args[1], false, false)
sig := &builtinArithmeticPlusDecimalSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_PlusDecimal)
return sig, nil
Expand Down Expand Up @@ -317,7 +316,7 @@ type arithmeticMinusFunctionClass struct {
baseFunctionClass
}

func (c *arithmeticMinusFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) {
func (c *arithmeticMinusFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) {
if err := c.verifyArgs(args); err != nil {
return nil, err
}
Expand All @@ -327,7 +326,7 @@ func (c *arithmeticMinusFunctionClass) getFunction(ctx sessionctx.Context, args
if err != nil {
return nil, err
}
setFlenDecimal4RealOrDecimal(ctx, bf.tp, args[0], args[1], true, false)
setFlenDecimal4RealOrDecimal(bf.tp, args[0], args[1], true, false)
sig := &builtinArithmeticMinusRealSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_MinusReal)
return sig, nil
Expand All @@ -336,7 +335,7 @@ func (c *arithmeticMinusFunctionClass) getFunction(ctx sessionctx.Context, args
if err != nil {
return nil, err
}
setFlenDecimal4RealOrDecimal(ctx, bf.tp, args[0], args[1], false, false)
setFlenDecimal4RealOrDecimal(bf.tp, args[0], args[1], false, false)
sig := &builtinArithmeticMinusDecimalSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_MinusDecimal)
return sig, nil
Expand Down Expand Up @@ -500,7 +499,7 @@ type arithmeticMultiplyFunctionClass struct {
baseFunctionClass
}

func (c *arithmeticMultiplyFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) {
func (c *arithmeticMultiplyFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) {
if err := c.verifyArgs(args); err != nil {
return nil, err
}
Expand All @@ -511,7 +510,7 @@ func (c *arithmeticMultiplyFunctionClass) getFunction(ctx sessionctx.Context, ar
if err != nil {
return nil, err
}
setFlenDecimal4RealOrDecimal(ctx, bf.tp, args[0], args[1], true, true)
setFlenDecimal4RealOrDecimal(bf.tp, args[0], args[1], true, true)
sig := &builtinArithmeticMultiplyRealSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_MultiplyReal)
return sig, nil
Expand All @@ -520,7 +519,7 @@ func (c *arithmeticMultiplyFunctionClass) getFunction(ctx sessionctx.Context, ar
if err != nil {
return nil, err
}
setFlenDecimal4RealOrDecimal(ctx, bf.tp, args[0], args[1], false, true)
setFlenDecimal4RealOrDecimal(bf.tp, args[0], args[1], false, true)
sig := &builtinArithmeticMultiplyDecimalSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_MultiplyDecimal)
return sig, nil
Expand Down Expand Up @@ -646,7 +645,7 @@ type arithmeticDivideFunctionClass struct {
baseFunctionClass
}

func (c *arithmeticDivideFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) {
func (c *arithmeticDivideFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) {
if err := c.verifyArgs(args); err != nil {
return nil, err
}
Expand Down Expand Up @@ -740,7 +739,7 @@ type arithmeticIntDivideFunctionClass struct {
baseFunctionClass
}

func (c *arithmeticIntDivideFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) {
func (c *arithmeticIntDivideFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) {
if err := c.verifyArgs(args); err != nil {
return nil, err
}
Expand Down Expand Up @@ -899,7 +898,7 @@ func (c *arithmeticModFunctionClass) setType4ModRealOrDecimal(retTp, a, b *types
}
}

func (c *arithmeticModFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) {
func (c *arithmeticModFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) {
if err := c.verifyArgs(args); err != nil {
return nil, err
}
Expand Down
21 changes: 10 additions & 11 deletions pkg/expression/builtin_arithmetic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import (
"github.com/pingcap/tidb/pkg/testkit/testutil"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/chunk"
"github.com/pingcap/tidb/pkg/util/mock"
"github.com/pingcap/tipb/go-tipb"
"github.com/stretchr/testify/require"
)
Expand All @@ -38,25 +37,25 @@ func TestSetFlenDecimal4RealOrDecimal(t *testing.T) {
b := &types.FieldType{}
b.SetDecimal(0)
b.SetFlag(2)
setFlenDecimal4RealOrDecimal(mock.NewContext(), ret, &Constant{RetType: a}, &Constant{RetType: b}, true, false)
setFlenDecimal4RealOrDecimal(ret, &Constant{RetType: a}, &Constant{RetType: b}, true, false)
require.Equal(t, 1, ret.GetDecimal())
require.Equal(t, 4, ret.GetFlen())

b.SetFlen(65)
setFlenDecimal4RealOrDecimal(mock.NewContext(), ret, &Constant{RetType: a}, &Constant{RetType: b}, true, false)
setFlenDecimal4RealOrDecimal(ret, &Constant{RetType: a}, &Constant{RetType: b}, true, false)
require.Equal(t, 1, ret.GetDecimal())
require.Equal(t, mysql.MaxRealWidth, ret.GetFlen())
setFlenDecimal4RealOrDecimal(mock.NewContext(), ret, &Constant{RetType: a}, &Constant{RetType: b}, false, false)
setFlenDecimal4RealOrDecimal(ret, &Constant{RetType: a}, &Constant{RetType: b}, false, false)
require.Equal(t, 1, ret.GetDecimal())
require.Equal(t, mysql.MaxDecimalWidth, ret.GetFlen())

b.SetFlen(types.UnspecifiedLength)
setFlenDecimal4RealOrDecimal(mock.NewContext(), ret, &Constant{RetType: a}, &Constant{RetType: b}, true, false)
setFlenDecimal4RealOrDecimal(ret, &Constant{RetType: a}, &Constant{RetType: b}, true, false)
require.Equal(t, 1, ret.GetDecimal())
require.Equal(t, types.UnspecifiedLength, ret.GetFlen())

b.SetDecimal(types.UnspecifiedLength)
setFlenDecimal4RealOrDecimal(mock.NewContext(), ret, &Constant{RetType: a}, &Constant{RetType: b}, true, false)
setFlenDecimal4RealOrDecimal(ret, &Constant{RetType: a}, &Constant{RetType: b}, true, false)
require.Equal(t, types.UnspecifiedLength, ret.GetDecimal())
require.Equal(t, types.UnspecifiedLength, ret.GetFlen())

Expand All @@ -69,25 +68,25 @@ func TestSetFlenDecimal4RealOrDecimal(t *testing.T) {
b.SetDecimal(0)
b.SetFlen(2)

setFlenDecimal4RealOrDecimal(mock.NewContext(), ret, &Constant{RetType: a}, &Constant{RetType: b}, true, true)
setFlenDecimal4RealOrDecimal(ret, &Constant{RetType: a}, &Constant{RetType: b}, true, true)
require.Equal(t, 1, ret.GetDecimal())
require.Equal(t, 5, ret.GetFlen())

b.SetFlen(65)
setFlenDecimal4RealOrDecimal(mock.NewContext(), ret, &Constant{RetType: a}, &Constant{RetType: b}, true, true)
setFlenDecimal4RealOrDecimal(ret, &Constant{RetType: a}, &Constant{RetType: b}, true, true)
require.Equal(t, 1, ret.GetDecimal())
require.Equal(t, mysql.MaxRealWidth, ret.GetFlen())
setFlenDecimal4RealOrDecimal(mock.NewContext(), ret, &Constant{RetType: a}, &Constant{RetType: b}, false, true)
setFlenDecimal4RealOrDecimal(ret, &Constant{RetType: a}, &Constant{RetType: b}, false, true)
require.Equal(t, 1, ret.GetDecimal())
require.Equal(t, mysql.MaxDecimalWidth, ret.GetFlen())

b.SetFlen(types.UnspecifiedLength)
setFlenDecimal4RealOrDecimal(mock.NewContext(), ret, &Constant{RetType: a}, &Constant{RetType: b}, true, true)
setFlenDecimal4RealOrDecimal(ret, &Constant{RetType: a}, &Constant{RetType: b}, true, true)
require.Equal(t, 1, ret.GetDecimal())
require.Equal(t, types.UnspecifiedLength, ret.GetFlen())

b.SetDecimal(types.UnspecifiedLength)
setFlenDecimal4RealOrDecimal(mock.NewContext(), ret, &Constant{RetType: a}, &Constant{RetType: b}, true, true)
setFlenDecimal4RealOrDecimal(ret, &Constant{RetType: a}, &Constant{RetType: b}, true, true)
require.Equal(t, types.UnspecifiedLength, ret.GetDecimal())
require.Equal(t, types.UnspecifiedLength, ret.GetFlen())
}
Expand Down
Loading

0 comments on commit ff050bb

Please sign in to comment.