Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

expression: fix wrong result type for greatest/least (#29408) #29914

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 141 additions & 32 deletions expression/builtin_compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,16 @@ var (
_ builtinFunc = &builtinGreatestRealSig{}
_ builtinFunc = &builtinGreatestDecimalSig{}
_ builtinFunc = &builtinGreatestStringSig{}
_ builtinFunc = &builtinGreatestDurationSig{}
_ builtinFunc = &builtinGreatestTimeSig{}
_ builtinFunc = &builtinGreatestCmpStringAsTimeSig{}
_ builtinFunc = &builtinLeastIntSig{}
_ builtinFunc = &builtinLeastRealSig{}
_ builtinFunc = &builtinLeastDecimalSig{}
_ builtinFunc = &builtinLeastStringSig{}
_ builtinFunc = &builtinLeastTimeSig{}
_ builtinFunc = &builtinLeastDurationSig{}
_ builtinFunc = &builtinLeastCmpStringAsTimeSig{}
_ builtinFunc = &builtinIntervalIntSig{}
_ builtinFunc = &builtinIntervalRealSig{}

Expand Down Expand Up @@ -404,7 +408,7 @@ func ResolveType4Between(args [3]Expression) types.EvalType {
}

// resolveType4Extremum gets compare type for GREATEST and LEAST and BETWEEN (mainly for datetime).
func resolveType4Extremum(args []Expression) types.EvalType {
func resolveType4Extremum(args []Expression) (_ types.EvalType, cmpStringAsDatetime bool) {
aggType := aggregateType(args)

var temporalItem *types.FieldType
Expand All @@ -421,10 +425,11 @@ func resolveType4Extremum(args []Expression) types.EvalType {

if !types.IsTypeTemporal(aggType.Tp) && temporalItem != nil {
aggType.Tp = temporalItem.Tp
cmpStringAsDatetime = true
}
// TODO: String charset, collation checking are needed.
}
return aggType.EvalType()
return aggType.EvalType(), cmpStringAsDatetime
}

// unsupportedJSONComparison reports warnings while there is a JSON type in least/greatest function's arguments
Expand All @@ -446,12 +451,9 @@ func (c *greatestFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
if err = c.verifyArgs(args); err != nil {
return nil, err
}
tp := resolveType4Extremum(args)
cmpAsDatetime := false
if tp == types.ETDatetime || tp == types.ETTimestamp {
cmpAsDatetime = true
tp = types.ETString
} else if tp == types.ETDuration {
tp, cmpStringAsDatetime := resolveType4Extremum(args)
if cmpStringAsDatetime {
// Args are temporal and string mixed, we cast all args as string and parse it to temporal mannualy to compare.
tp = types.ETString
} else if tp == types.ETJson {
unsupportedJSONComparison(ctx, args)
Expand All @@ -465,9 +467,6 @@ func (c *greatestFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
if err != nil {
return nil, err
}
if cmpAsDatetime {
tp = types.ETDatetime
}
switch tp {
case types.ETInt:
sig = &builtinGreatestIntSig{bf}
Expand All @@ -479,8 +478,16 @@ func (c *greatestFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
sig = &builtinGreatestDecimalSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_GreatestDecimal)
case types.ETString:
sig = &builtinGreatestStringSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_GreatestString)
if cmpStringAsDatetime {
sig = &builtinGreatestCmpStringAsTimeSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_GreatestCmpStringAsTime)
} else {
sig = &builtinGreatestStringSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_GreatestString)
}
case types.ETDuration:
sig = &builtinGreatestDurationSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_GreatestDuration)
case types.ETDatetime, types.ETTimestamp:
sig = &builtinGreatestTimeSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_GreatestTime)
Expand Down Expand Up @@ -622,23 +629,27 @@ func (b *builtinGreatestStringSig) evalString(row chunk.Row) (max string, isNull
return
}

type builtinGreatestTimeSig struct {
type builtinGreatestCmpStringAsTimeSig struct {
baseBuiltinFunc
}

func (b *builtinGreatestTimeSig) Clone() builtinFunc {
newSig := &builtinGreatestTimeSig{}
func (b *builtinGreatestCmpStringAsTimeSig) Clone() builtinFunc {
newSig := &builtinGreatestCmpStringAsTimeSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}

// evalString evals a builtinGreatestTimeSig.
// evalString evals a builtinGreatestCmpStringAsTimeSig.
// See http://dev.mysql.com/doc/refman/5.7/en/comparison-operators.html#function_greatest
<<<<<<< HEAD
func (b *builtinGreatestTimeSig) evalString(row chunk.Row) (res string, isNull bool, err error) {
var (
strRes string
timeRes types.Time
)
=======
func (b *builtinGreatestCmpStringAsTimeSig) evalString(row chunk.Row) (strRes string, isNull bool, err error) {
>>>>>>> 4b110036e... expression: fix wrong result type for greatest/least (#29408)
sc := b.ctx.GetSessionVars().StmtCtx
for i := 0; i < len(b.args); i++ {
v, isNull, err := b.args[i].EvalString(b.ctx, row)
Expand Down Expand Up @@ -669,6 +680,52 @@ func (b *builtinGreatestTimeSig) evalString(row chunk.Row) (res string, isNull b
return res, false, nil
}

type builtinGreatestTimeSig struct {
baseBuiltinFunc
}

func (b *builtinGreatestTimeSig) Clone() builtinFunc {
newSig := &builtinGreatestTimeSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}

func (b *builtinGreatestTimeSig) evalTime(row chunk.Row) (res types.Time, isNull bool, err error) {
for i := 0; i < len(b.args); i++ {
v, isNull, err := b.args[i].EvalTime(b.ctx, row)
if isNull || err != nil {
return types.ZeroTime, true, err
}
if i == 0 || v.Compare(res) > 0 {
res = v
}
}
return res, false, nil
}

type builtinGreatestDurationSig struct {
baseBuiltinFunc
}

func (b *builtinGreatestDurationSig) Clone() builtinFunc {
newSig := &builtinGreatestDurationSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}

func (b *builtinGreatestDurationSig) evalDuration(row chunk.Row) (res types.Duration, isNull bool, err error) {
for i := 0; i < len(b.args); i++ {
v, isNull, err := b.args[i].EvalDuration(b.ctx, row)
if isNull || err != nil {
return types.Duration{}, true, err
}
if i == 0 || v.Compare(res) > 0 {
res = v
}
}
return res, false, nil
}

type leastFunctionClass struct {
baseFunctionClass
}
Expand All @@ -677,12 +734,9 @@ func (c *leastFunctionClass) getFunction(ctx sessionctx.Context, args []Expressi
if err = c.verifyArgs(args); err != nil {
return nil, err
}
tp := resolveType4Extremum(args)
cmpAsDatetime := false
if tp == types.ETDatetime || tp == types.ETTimestamp {
cmpAsDatetime = true
tp = types.ETString
} else if tp == types.ETDuration {
tp, cmpStringAsDatetime := resolveType4Extremum(args)
if cmpStringAsDatetime {
// Args are temporal and string mixed, we cast all args as string and parse it to temporal mannualy to compare.
tp = types.ETString
} else if tp == types.ETJson {
unsupportedJSONComparison(ctx, args)
Expand All @@ -696,9 +750,6 @@ func (c *leastFunctionClass) getFunction(ctx sessionctx.Context, args []Expressi
if err != nil {
return nil, err
}
if cmpAsDatetime {
tp = types.ETDatetime
}
switch tp {
case types.ETInt:
sig = &builtinLeastIntSig{bf}
Expand All @@ -710,8 +761,16 @@ func (c *leastFunctionClass) getFunction(ctx sessionctx.Context, args []Expressi
sig = &builtinLeastDecimalSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_LeastDecimal)
case types.ETString:
sig = &builtinLeastStringSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_LeastString)
if cmpStringAsDatetime {
sig = &builtinLeastCmpStringAsTimeSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_LeastCmpStringAsTime)
} else {
sig = &builtinLeastStringSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_LeastString)
}
case types.ETDuration:
sig = &builtinLeastDurationSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_LeastDuration)
case types.ETDatetime, types.ETTimestamp:
sig = &builtinLeastTimeSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_LeastTime)
Expand Down Expand Up @@ -840,24 +899,28 @@ func (b *builtinLeastStringSig) evalString(row chunk.Row) (min string, isNull bo
return
}

type builtinLeastTimeSig struct {
type builtinLeastCmpStringAsTimeSig struct {
baseBuiltinFunc
}

func (b *builtinLeastTimeSig) Clone() builtinFunc {
newSig := &builtinLeastTimeSig{}
func (b *builtinLeastCmpStringAsTimeSig) Clone() builtinFunc {
newSig := &builtinLeastCmpStringAsTimeSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}

// evalString evals a builtinLeastTimeSig.
// evalString evals a builtinLeastCmpStringAsTimeSig.
// See http://dev.mysql.com/doc/refman/5.7/en/comparison-operators.html#functionleast
<<<<<<< HEAD
func (b *builtinLeastTimeSig) evalString(row chunk.Row) (res string, isNull bool, err error) {
var (
// timeRes will be converted to a strRes only when the arguments is a valid datetime value.
strRes string // Record the strRes of each arguments.
timeRes types.Time // Record the time representation of a valid arguments.
)
=======
func (b *builtinLeastCmpStringAsTimeSig) evalString(row chunk.Row) (strRes string, isNull bool, err error) {
>>>>>>> 4b110036e... expression: fix wrong result type for greatest/least (#29408)
sc := b.ctx.GetSessionVars().StmtCtx
for i := 0; i < len(b.args); i++ {
v, isNull, err := b.args[i].EvalString(b.ctx, row)
Expand Down Expand Up @@ -888,6 +951,52 @@ func (b *builtinLeastTimeSig) evalString(row chunk.Row) (res string, isNull bool
return res, false, nil
}

type builtinLeastTimeSig struct {
baseBuiltinFunc
}

func (b *builtinLeastTimeSig) Clone() builtinFunc {
newSig := &builtinLeastTimeSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}

func (b *builtinLeastTimeSig) evalTime(row chunk.Row) (res types.Time, isNull bool, err error) {
for i := 0; i < len(b.args); i++ {
v, isNull, err := b.args[i].EvalTime(b.ctx, row)
if isNull || err != nil {
return types.ZeroTime, true, err
}
if i == 0 || v.Compare(res) < 0 {
res = v
}
}
return res, false, nil
}

type builtinLeastDurationSig struct {
baseBuiltinFunc
}

func (b *builtinLeastDurationSig) Clone() builtinFunc {
newSig := &builtinLeastDurationSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}

func (b *builtinLeastDurationSig) evalDuration(row chunk.Row) (res types.Duration, isNull bool, err error) {
for i := 0; i < len(b.args); i++ {
v, isNull, err := b.args[i].EvalDuration(b.ctx, row)
if isNull || err != nil {
return types.Duration{}, true, err
}
if i == 0 || v.Compare(res) < 0 {
res = v
}
}
return res, false, nil
}

type intervalFunctionClass struct {
baseFunctionClass
}
Expand Down
2 changes: 1 addition & 1 deletion expression/builtin_compare_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ func (s *testEvaluatorSuite) TestGreatestLeastFunc(c *C) {
},
{
[]interface{}{duration, duration},
"12:59:59", "12:59:59", false, false,
duration, duration, false, false,
},
{
[]interface{}{"123", nil, "123"},
Expand Down
Loading