Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

expression: fix type infer for tidb's builtin compare(least and great… #22562

Merged
merged 5 commits into from
Jan 27, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 107 additions & 71 deletions expression/builtin_compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ package expression

import (
"math"
"strings"

"github.com/pingcap/parser/ast"
"github.com/pingcap/parser/mysql"
Expand Down Expand Up @@ -367,53 +368,67 @@ func (b *builtinCoalesceJSONSig) evalJSON(row chunk.Row) (res json.BinaryJSON, i
return res, isNull, err
}

// temporalWithDateAsNumEvalType makes DATE, DATETIME, TIMESTAMP pretend to be numbers rather than strings.
func temporalWithDateAsNumEvalType(argTp *types.FieldType) (argEvalType types.EvalType, isStr bool, isTemporalWithDate bool) {
argEvalType = argTp.EvalType()
isStr, isTemporalWithDate = argEvalType.IsStringKind(), types.IsTemporalWithDate(argTp.Tp)
if !isTemporalWithDate {
return
func aggregateType(args []Expression) *types.FieldType {
fieldTypes := make([]*types.FieldType, len(args))
for i := range fieldTypes {
fieldTypes[i] = args[i].GetType()
}
if argTp.Decimal > 0 {
argEvalType = types.ETDecimal
} else {
argEvalType = types.ETInt
}
return
return types.AggFieldType(fieldTypes)
}

// GetCmpTp4MinMax gets compare type for GREATEST and LEAST and BETWEEN
func GetCmpTp4MinMax(args []Expression) (argTp types.EvalType) {
datetimeFound, isAllStr := false, true
cmpEvalType, isStr, isTemporalWithDate := temporalWithDateAsNumEvalType(args[0].GetType())
if !isStr {
isAllStr = false
// ResolveType4Between resolves eval type for between expression.
func ResolveType4Between(args [3]Expression) types.EvalType {
cmpTp := args[0].GetType().EvalType()
for i := 1; i < 3; i++ {
cmpTp = getBaseCmpType(cmpTp, args[i].GetType().EvalType(), nil, nil)
}
if isTemporalWithDate {
datetimeFound = true
}
lft := args[0].GetType()
for i := range args {
rft := args[i].GetType()
var tp types.EvalType
tp, isStr, isTemporalWithDate = temporalWithDateAsNumEvalType(rft)
if isTemporalWithDate {
datetimeFound = true

hasTemporal := false
if cmpTp == types.ETString {
for _, arg := range args {
if types.IsTypeTemporal(arg.GetType().Tp) {
hasTemporal = true
break
}
}
if !isStr {
isAllStr = false
if hasTemporal {
cmpTp = types.ETDatetime
}
cmpEvalType = getBaseCmpType(cmpEvalType, tp, lft, rft)
lft = rft
}
argTp = cmpEvalType
if cmpEvalType.IsStringKind() {
argTp = types.ETString

return cmpTp
}

// resolveType4Extremum gets compare type for GREATEST and LEAST and BETWEEN (mainly for datetime).
func resolveType4Extremum(args []Expression) types.EvalType {
aggType := aggregateType(args)

var temporalItem *types.FieldType
if aggType.EvalType().IsStringKind() {
for i := range args {
item := args[i].GetType()
if types.IsTemporalWithDate(item.Tp) {
temporalItem = item
}
}

if !types.IsTemporalWithDate(aggType.Tp) && temporalItem != nil {
aggType.Tp = temporalItem.Tp
}
// TODO: String charset, collation checking are needed.
}
if isAllStr && datetimeFound {
argTp = types.ETDatetime
return aggType.EvalType()
}

// unsupportedJSONComparison reports warnings while there is a JSON type in least/greatest function's arguments
func unsupportedJSONComparison(ctx sessionctx.Context, args []Expression) {
for _, arg := range args {
tp := arg.GetType().Tp
if tp == mysql.TypeJSON {
ctx.GetSessionVars().StmtCtx.AppendWarning(errUnsupportedJSONComparison)
break
}
}
return argTp
}

type greatestFunctionClass struct {
Expand All @@ -424,10 +439,14 @@ func (c *greatestFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
if err = c.verifyArgs(args); err != nil {
return nil, err
}
tp, cmpAsDatetime := GetCmpTp4MinMax(args), false
if tp == types.ETDatetime {
tp := resolveType4Extremum(args)
cmpAsDatetime := false
if tp == types.ETDatetime || tp == types.ETTimestamp {
cmpAsDatetime = true
tp = types.ETString
} else if tp == types.ETJson {
unsupportedJSONComparison(ctx, args)
tp = types.ETString
}
argTps := make([]types.EvalType, len(args))
for i := range args {
Expand All @@ -453,7 +472,7 @@ func (c *greatestFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
case types.ETString:
sig = &builtinGreatestStringSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_GreatestString)
case types.ETDatetime:
case types.ETDatetime, types.ETTimestamp:
sig = &builtinGreatestTimeSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_GreatestTime)
}
Expand Down Expand Up @@ -592,30 +611,39 @@ func (b *builtinGreatestTimeSig) Clone() builtinFunc {

// evalString evals a builtinGreatestTimeSig.
// See http://dev.mysql.com/doc/refman/5.7/en/comparison-operators.html#function_greatest
func (b *builtinGreatestTimeSig) evalString(row chunk.Row) (_ string, isNull bool, err error) {
func (b *builtinGreatestTimeSig) evalString(row chunk.Row) (res string, isNull bool, err error) {
var (
v string
t types.Time
strRes string
timeRes types.Time
)
max := types.ZeroDatetime
sc := b.ctx.GetSessionVars().StmtCtx
for i := 0; i < len(b.args); i++ {
v, isNull, err = b.args[i].EvalString(b.ctx, row)
v, isNull, err := b.args[i].EvalString(b.ctx, row)
if isNull || err != nil {
return "", true, err
}
t, err = types.ParseDatetime(sc, v)
t, err := types.ParseDatetime(sc, v)
if err != nil {
if err = handleInvalidTimeError(b.ctx, err); err != nil {
return v, true, err
}
continue
} else {
v = t.String()
}
// In MySQL, if the compare result is zero, than we will try to use the string comparison result
if i == 0 || strings.Compare(v, strRes) > 0 {
strRes = v
}
if t.Compare(max) > 0 {
max = t
if i == 0 || t.Compare(timeRes) > 0 {
timeRes = t
}
}
return max.String(), false, nil
if timeRes.IsZero() {
res = strRes
} else {
res = timeRes.String()
}
return res, false, nil
}

type leastFunctionClass struct {
Expand All @@ -626,10 +654,14 @@ func (c *leastFunctionClass) getFunction(ctx sessionctx.Context, args []Expressi
if err = c.verifyArgs(args); err != nil {
return nil, err
}
tp, cmpAsDatetime := GetCmpTp4MinMax(args), false
tp := resolveType4Extremum(args)
cmpAsDatetime := false
if tp == types.ETDatetime {
cmpAsDatetime = true
tp = types.ETString
} else if tp == types.ETJson {
unsupportedJSONComparison(ctx, args)
tp = types.ETString
}
argTps := make([]types.EvalType, len(args))
for i := range args {
Expand Down Expand Up @@ -796,32 +828,36 @@ func (b *builtinLeastTimeSig) Clone() builtinFunc {
// See http://dev.mysql.com/doc/refman/5.7/en/comparison-operators.html#functionleast
func (b *builtinLeastTimeSig) evalString(row chunk.Row) (res string, isNull bool, err error) {
var (
v string
t types.Time
// timeRes will be converted to a strRes only when the arguments is a valid datetime value.
strRes string // Record the strRes of each arguments.
timeRes types.Time // Record the time representation of a valid arguments.
)
min := types.NewTime(types.MaxDatetime, mysql.TypeDatetime, types.MaxFsp)
findInvalidTime := false
sc := b.ctx.GetSessionVars().StmtCtx
for i := 0; i < len(b.args); i++ {
v, isNull, err = b.args[i].EvalString(b.ctx, row)
v, isNull, err := b.args[i].EvalString(b.ctx, row)
if isNull || err != nil {
return "", true, err
}
t, err = types.ParseDatetime(sc, v)
t, err := types.ParseDatetime(sc, v)
if err != nil {
if err = handleInvalidTimeError(b.ctx, err); err != nil {
return v, true, err
} else if !findInvalidTime {
res = v
findInvalidTime = true
}
} else {
v = t.String()
}
if i == 0 || strings.Compare(v, strRes) < 0 {
strRes = v
}
if t.Compare(min) < 0 {
min = t
if i == 0 || t.Compare(timeRes) < 0 {
timeRes = t
}
}
if !findInvalidTime {
res = min.String()

if timeRes.IsZero() {
res = strRes
} else {
res = timeRes.String()
}
return res, false, nil
}
Expand Down Expand Up @@ -1042,7 +1078,7 @@ type compareFunctionClass struct {

// getBaseCmpType gets the EvalType that the two args will be treated as when comparing.
func getBaseCmpType(lhs, rhs types.EvalType, lft, rft *types.FieldType) types.EvalType {
if lft.Tp == mysql.TypeUnspecified || rft.Tp == mysql.TypeUnspecified {
if lft != nil && rft != nil && (lft.Tp == mysql.TypeUnspecified || rft.Tp == mysql.TypeUnspecified) {
if lft.Tp == rft.Tp {
return types.ETString
}
Expand All @@ -1054,13 +1090,13 @@ func getBaseCmpType(lhs, rhs types.EvalType, lft, rft *types.FieldType) types.Ev
}
if lhs.IsStringKind() && rhs.IsStringKind() {
return types.ETString
} else if (lhs == types.ETInt || lft.Hybrid()) && (rhs == types.ETInt || rft.Hybrid()) {
} else if (lhs == types.ETInt || (lft != nil && lft.Hybrid())) && (rhs == types.ETInt || (rft != nil && rft.Hybrid())) {
return types.ETInt
} else if ((lhs == types.ETInt || lft.Hybrid()) || lhs == types.ETDecimal) &&
((rhs == types.ETInt || rft.Hybrid()) || rhs == types.ETDecimal) {
} else if ((lhs == types.ETInt || (lft != nil && lft.Hybrid())) || lhs == types.ETDecimal) &&
((rhs == types.ETInt || (rft != nil && rft.Hybrid())) || rhs == types.ETDecimal) {
return types.ETDecimal
} else if types.IsTemporalWithDate(lft.Tp) && rft.Tp == mysql.TypeYear ||
lft.Tp == mysql.TypeYear && types.IsTemporalWithDate(rft.Tp) {
} else if lft != nil && rft != nil && (types.IsTemporalWithDate(lft.Tp) && rft.Tp == mysql.TypeYear ||
lft.Tp == mysql.TypeYear && types.IsTemporalWithDate(rft.Tp)) {
return types.ETDatetime
}
return types.ETReal
Expand Down
19 changes: 14 additions & 5 deletions expression/builtin_compare_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,8 @@ func (s *testEvaluatorSuite) TestIntervalFunc(c *C) {
}
}

func (s *testEvaluatorSuite) TestGreatestLeastFuncs(c *C) {
// greatest/least function is compatible with MySQL 8.0
func (s *testEvaluatorSuite) TestGreatestLeastFunc(c *C) {
sc := s.ctx.GetSessionVars().StmtCtx
originIgnoreTruncate := sc.IgnoreTruncate
sc.IgnoreTruncate = true
Expand All @@ -282,23 +283,23 @@ func (s *testEvaluatorSuite) TestGreatestLeastFuncs(c *C) {
},
{
[]interface{}{"123a", "b", "c", 12},
float64(123), float64(0), false, false,
"c", "12", false, false,
},
{
[]interface{}{tm, "123"},
curTimeString, "123", false, false,
},
{
[]interface{}{tm, 123},
curTimeInt, int64(123), false, false,
curTimeString, "123", false, false,
},
{
[]interface{}{tm, "invalid_time_1", "invalid_time_2", tmWithFsp},
curTimeWithFspString, "invalid_time_1", false, false,
curTimeWithFspString, curTimeString, false, false,
},
{
[]interface{}{tm, "invalid_time_2", "invalid_time_1", tmWithFsp},
curTimeWithFspString, "invalid_time_2", false, false,
curTimeWithFspString, curTimeString, false, false,
},
{
[]interface{}{tm, "invalid_time", nil, tmWithFsp},
Expand All @@ -316,6 +317,14 @@ func (s *testEvaluatorSuite) TestGreatestLeastFuncs(c *C) {
[]interface{}{errors.New("must error"), 123},
nil, nil, false, true,
},
{
[]interface{}{794755072.0, 4556, "2000-01-09"},
"794755072", "2000-01-09", false, false,
},
{
[]interface{}{905969664.0, 4556, "1990-06-16 17:22:56.005534"},
"905969664", "1990-06-16 17:22:56.005534", false, false,
},
} {
f0, err := newFunctionForTest(s.ctx, ast.Greatest, s.primitiveValsToConstants(t.args)...)
c.Assert(err, IsNil)
Expand Down
Loading