From 5dd429688a8928fb7a475d14980613e333932f02 Mon Sep 17 00:00:00 2001 From: Weizhen Wang Date: Tue, 5 Sep 2023 15:47:27 +0800 Subject: [PATCH] *: use cmp.Compare to replace types.Compare (#46657) ref pingcap/tidb#45933 --- executor/aggfuncs/func_max_min.go | 9 +++-- executor/aggfuncs/func_max_min_test.go | 5 ++- expression/builtin_compare.go | 21 +++++----- expression/builtin_compare_vec_generated.go | 30 +++++++------- expression/builtin_other_vec_generated.go | 6 ++- expression/generator/compare_vec.go | 10 +++-- expression/generator/other_vec.go | 6 ++- store/mockstore/unistore/cophandler/topn.go | 3 +- types/compare.go | 45 --------------------- types/datum.go | 29 ++++++------- util/chunk/compare.go | 27 +++++++------ util/ranger/points.go | 3 +- 12 files changed, 83 insertions(+), 111 deletions(-) diff --git a/executor/aggfuncs/func_max_min.go b/executor/aggfuncs/func_max_min.go index 7f7f302bbeda2..a9bc2b6b69692 100644 --- a/executor/aggfuncs/func_max_min.go +++ b/executor/aggfuncs/func_max_min.go @@ -15,6 +15,7 @@ package aggfuncs import ( + "cmp" "unsafe" "github.com/pingcap/errors" @@ -313,7 +314,7 @@ func (e *maxMin4IntSliding) ResetPartialResult(pr PartialResult) { func (e *maxMin4IntSliding) AllocPartialResult() (pr PartialResult, memDelta int64) { p, memDelta := e.maxMin4Int.AllocPartialResult() (*partialResult4MaxMinInt)(p).deque = NewDeque(e.isMax, func(i, j interface{}) int { - return types.CompareInt64(i.(int64), j.(int64)) + return cmp.Compare(i.(int64), j.(int64)) }) return p, memDelta + DefMaxMinDequeSize } @@ -450,7 +451,7 @@ type maxMin4UintSliding struct { func (e *maxMin4UintSliding) AllocPartialResult() (pr PartialResult, memDelta int64) { p, memDelta := e.maxMin4Uint.AllocPartialResult() (*partialResult4MaxMinUint)(p).deque = NewDeque(e.isMax, func(i, j interface{}) int { - return types.CompareUint64(i.(uint64), j.(uint64)) + return cmp.Compare(i.(uint64), j.(uint64)) }) return p, memDelta + DefMaxMinDequeSize } @@ -589,7 +590,7 @@ type maxMin4Float32Sliding struct { func (e *maxMin4Float32Sliding) AllocPartialResult() (pr PartialResult, memDelta int64) { p, memDelta := e.maxMin4Float32.AllocPartialResult() (*partialResult4MaxMinFloat32)(p).deque = NewDeque(e.isMax, func(i, j interface{}) int { - return types.CompareFloat64(float64(i.(float32)), float64(j.(float32))) + return cmp.Compare(float64(i.(float32)), float64(j.(float32))) }) return p, memDelta + DefMaxMinDequeSize } @@ -726,7 +727,7 @@ type maxMin4Float64Sliding struct { func (e *maxMin4Float64Sliding) AllocPartialResult() (pr PartialResult, memDelta int64) { p, memDelta := e.maxMin4Float64.AllocPartialResult() (*partialResult4MaxMinFloat64)(p).deque = NewDeque(e.isMax, func(i, j interface{}) int { - return types.CompareFloat64(i.(float64), j.(float64)) + return cmp.Compare(i.(float64), j.(float64)) }) return p, memDelta + DefMaxMinDequeSize } diff --git a/executor/aggfuncs/func_max_min_test.go b/executor/aggfuncs/func_max_min_test.go index 2505b186fe61b..4eff13e3ceabf 100644 --- a/executor/aggfuncs/func_max_min_test.go +++ b/executor/aggfuncs/func_max_min_test.go @@ -15,6 +15,7 @@ package aggfuncs_test import ( + "cmp" "fmt" "testing" "time" @@ -336,7 +337,7 @@ func TestMaxSlidingWindow(t *testing.T) { func TestDequeReset(t *testing.T) { deque := aggfuncs.NewDeque(true, func(i, j interface{}) int { - return types.CompareInt64(i.(int64), j.(int64)) + return cmp.Compare(i.(int64), j.(int64)) }) deque.PushBack(0, 12) deque.Reset() @@ -346,7 +347,7 @@ func TestDequeReset(t *testing.T) { func TestDequePushPop(t *testing.T) { deque := aggfuncs.NewDeque(true, func(i, j interface{}) int { - return types.CompareInt64(i.(int64), j.(int64)) + return cmp.Compare(i.(int64), j.(int64)) }) times := 15 // pushes element from back of deque diff --git a/expression/builtin_compare.go b/expression/builtin_compare.go index 0f894ee728f3d..52f845bc35370 100644 --- a/expression/builtin_compare.go +++ b/expression/builtin_compare.go @@ -15,6 +15,7 @@ package expression import ( + "cmp" "math" "strings" @@ -2627,22 +2628,22 @@ func (b *builtinNullEQIntSig) evalInt(row chunk.Row) (val int64, isNull bool, er res = 1 case isNull0 != isNull1: return res, false, nil - case isUnsigned0 && isUnsigned1 && types.CompareUint64(uint64(arg0), uint64(arg1)) == 0: + case isUnsigned0 && isUnsigned1 && cmp.Compare(uint64(arg0), uint64(arg1)) == 0: res = 1 - case !isUnsigned0 && !isUnsigned1 && types.CompareInt64(arg0, arg1) == 0: + case !isUnsigned0 && !isUnsigned1 && cmp.Compare(arg0, arg1) == 0: res = 1 case isUnsigned0 && !isUnsigned1: if arg1 < 0 { return res, false, nil } - if types.CompareInt64(arg0, arg1) == 0 { + if cmp.Compare(arg0, arg1) == 0 { res = 1 } case !isUnsigned0 && isUnsigned1: if arg0 < 0 { return res, false, nil } - if types.CompareInt64(arg0, arg1) == 0 { + if cmp.Compare(arg0, arg1) == 0 { res = 1 } } @@ -2674,7 +2675,7 @@ func (b *builtinNullEQRealSig) evalInt(row chunk.Row) (val int64, isNull bool, e res = 1 case isNull0 != isNull1: return res, false, nil - case types.CompareFloat64(arg0, arg1) == 0: + case cmp.Compare(arg0, arg1) == 0: res = 1 } return res, false, nil @@ -2948,21 +2949,21 @@ func CompareInt(sctx sessionctx.Context, lhsArg, rhsArg Expression, lhsRow, rhsR var res int switch { case isUnsigned0 && isUnsigned1: - res = types.CompareUint64(uint64(arg0), uint64(arg1)) + res = cmp.Compare(uint64(arg0), uint64(arg1)) case isUnsigned0 && !isUnsigned1: if arg1 < 0 || uint64(arg0) > math.MaxInt64 { res = 1 } else { - res = types.CompareInt64(arg0, arg1) + res = cmp.Compare(arg0, arg1) } case !isUnsigned0 && isUnsigned1: if arg0 < 0 || uint64(arg1) > math.MaxInt64 { res = -1 } else { - res = types.CompareInt64(arg0, arg1) + res = cmp.Compare(arg0, arg1) } case !isUnsigned0 && !isUnsigned1: - res = types.CompareInt64(arg0, arg1) + res = cmp.Compare(arg0, arg1) } return int64(res), false, nil } @@ -3006,7 +3007,7 @@ func CompareReal(sctx sessionctx.Context, lhsArg, rhsArg Expression, lhsRow, rhs if isNull0 || isNull1 { return compareNull(isNull0, isNull1), true, nil } - return int64(types.CompareFloat64(arg0, arg1)), false, nil + return int64(cmp.Compare(arg0, arg1)), false, nil } // CompareDecimal compares two decimals. diff --git a/expression/builtin_compare_vec_generated.go b/expression/builtin_compare_vec_generated.go index e4dca22a50acc..5b53f65229c31 100644 --- a/expression/builtin_compare_vec_generated.go +++ b/expression/builtin_compare_vec_generated.go @@ -17,6 +17,8 @@ package expression import ( + "cmp" + "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" ) @@ -49,7 +51,7 @@ func (b *builtinLTRealSig) vecEvalInt(input *chunk.Chunk, result *chunk.Column) if result.IsNull(i) { continue } - val := types.CompareFloat64(arg0[i], arg1[i]) + val := cmp.Compare(arg0[i], arg1[i]) i64s[i] = boolToInt64(val < 0) } return nil @@ -199,7 +201,7 @@ func (b *builtinLTDurationSig) vecEvalInt(input *chunk.Chunk, result *chunk.Colu if result.IsNull(i) { continue } - val := types.CompareDuration(arg0[i], arg1[i]) + val := cmp.Compare(arg0[i], arg1[i]) i64s[i] = boolToInt64(val < 0) } return nil @@ -273,7 +275,7 @@ func (b *builtinLERealSig) vecEvalInt(input *chunk.Chunk, result *chunk.Column) if result.IsNull(i) { continue } - val := types.CompareFloat64(arg0[i], arg1[i]) + val := cmp.Compare(arg0[i], arg1[i]) i64s[i] = boolToInt64(val <= 0) } return nil @@ -423,7 +425,7 @@ func (b *builtinLEDurationSig) vecEvalInt(input *chunk.Chunk, result *chunk.Colu if result.IsNull(i) { continue } - val := types.CompareDuration(arg0[i], arg1[i]) + val := cmp.Compare(arg0[i], arg1[i]) i64s[i] = boolToInt64(val <= 0) } return nil @@ -497,7 +499,7 @@ func (b *builtinGTRealSig) vecEvalInt(input *chunk.Chunk, result *chunk.Column) if result.IsNull(i) { continue } - val := types.CompareFloat64(arg0[i], arg1[i]) + val := cmp.Compare(arg0[i], arg1[i]) i64s[i] = boolToInt64(val > 0) } return nil @@ -647,7 +649,7 @@ func (b *builtinGTDurationSig) vecEvalInt(input *chunk.Chunk, result *chunk.Colu if result.IsNull(i) { continue } - val := types.CompareDuration(arg0[i], arg1[i]) + val := cmp.Compare(arg0[i], arg1[i]) i64s[i] = boolToInt64(val > 0) } return nil @@ -721,7 +723,7 @@ func (b *builtinGERealSig) vecEvalInt(input *chunk.Chunk, result *chunk.Column) if result.IsNull(i) { continue } - val := types.CompareFloat64(arg0[i], arg1[i]) + val := cmp.Compare(arg0[i], arg1[i]) i64s[i] = boolToInt64(val >= 0) } return nil @@ -871,7 +873,7 @@ func (b *builtinGEDurationSig) vecEvalInt(input *chunk.Chunk, result *chunk.Colu if result.IsNull(i) { continue } - val := types.CompareDuration(arg0[i], arg1[i]) + val := cmp.Compare(arg0[i], arg1[i]) i64s[i] = boolToInt64(val >= 0) } return nil @@ -945,7 +947,7 @@ func (b *builtinEQRealSig) vecEvalInt(input *chunk.Chunk, result *chunk.Column) if result.IsNull(i) { continue } - val := types.CompareFloat64(arg0[i], arg1[i]) + val := cmp.Compare(arg0[i], arg1[i]) i64s[i] = boolToInt64(val == 0) } return nil @@ -1095,7 +1097,7 @@ func (b *builtinEQDurationSig) vecEvalInt(input *chunk.Chunk, result *chunk.Colu if result.IsNull(i) { continue } - val := types.CompareDuration(arg0[i], arg1[i]) + val := cmp.Compare(arg0[i], arg1[i]) i64s[i] = boolToInt64(val == 0) } return nil @@ -1169,7 +1171,7 @@ func (b *builtinNERealSig) vecEvalInt(input *chunk.Chunk, result *chunk.Column) if result.IsNull(i) { continue } - val := types.CompareFloat64(arg0[i], arg1[i]) + val := cmp.Compare(arg0[i], arg1[i]) i64s[i] = boolToInt64(val != 0) } return nil @@ -1319,7 +1321,7 @@ func (b *builtinNEDurationSig) vecEvalInt(input *chunk.Chunk, result *chunk.Colu if result.IsNull(i) { continue } - val := types.CompareDuration(arg0[i], arg1[i]) + val := cmp.Compare(arg0[i], arg1[i]) i64s[i] = boolToInt64(val != 0) } return nil @@ -1396,7 +1398,7 @@ func (b *builtinNullEQRealSig) vecEvalInt(input *chunk.Chunk, result *chunk.Colu i64s[i] = 1 case isNull0 != isNull1: i64s[i] = 0 - case types.CompareFloat64(arg0[i], arg1[i]) == 0: + case cmp.Compare(arg0[i], arg1[i]) == 0: i64s[i] = 1 } } @@ -1562,7 +1564,7 @@ func (b *builtinNullEQDurationSig) vecEvalInt(input *chunk.Chunk, result *chunk. i64s[i] = 1 case isNull0 != isNull1: i64s[i] = 0 - case types.CompareDuration(arg0[i], arg1[i]) == 0: + case cmp.Compare(arg0[i], arg1[i]) == 0: i64s[i] = 1 } } diff --git a/expression/builtin_other_vec_generated.go b/expression/builtin_other_vec_generated.go index 0033e4f4baeb5..25f872b26e40e 100644 --- a/expression/builtin_other_vec_generated.go +++ b/expression/builtin_other_vec_generated.go @@ -17,6 +17,8 @@ package expression import ( + "cmp" + "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" @@ -363,7 +365,7 @@ func (b *builtinInRealSig) vecEvalInt(input *chunk.Chunk, result *chunk.Column) } arg0 := args0[i] arg1 := args1[i] - compareResult = types.CompareFloat64(arg0, arg1) + compareResult = cmp.Compare(arg0, arg1) if compareResult == 0 { result.SetNull(i, false) r64s[i] = 1 @@ -529,7 +531,7 @@ func (b *builtinInDurationSig) vecEvalInt(input *chunk.Chunk, result *chunk.Colu } arg0 := args0[i] arg1 := args1[i] - compareResult = types.CompareDuration(arg0, arg1) + compareResult = cmp.Compare(arg0, arg1) if compareResult == 0 { result.SetNull(i, false) r64s[i] = 1 diff --git a/expression/generator/compare_vec.go b/expression/generator/compare_vec.go index 423a0949deea5..4f433330a12a5 100644 --- a/expression/generator/compare_vec.go +++ b/expression/generator/compare_vec.go @@ -51,6 +51,8 @@ package expression const newLine = "\n" const builtinCompareImports = `import ( + "cmp" + "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" ) @@ -90,11 +92,11 @@ func (b *builtin{{ .compare.CompareName }}{{ .type.TypeName }}Sig) vecEvalInt(in {{- if eq .type.ETName "Json" }} val := types.CompareBinaryJSON(buf0.GetJSON(i), buf1.GetJSON(i)) {{- else if eq .type.ETName "Real" }} - val := types.CompareFloat64(arg0[i], arg1[i]) + val := cmp.Compare(arg0[i], arg1[i]) {{- else if eq .type.ETName "String" }} val := types.CompareString(buf0.GetString(i), buf1.GetString(i), b.collation) {{- else if eq .type.ETName "Duration" }} - val := types.CompareDuration(arg0[i], arg1[i]) + val := cmp.Compare(arg0[i], arg1[i]) {{- else if eq .type.ETName "Datetime" }} val := arg0[i].Compare(arg1[i]) {{- else if eq .type.ETName "Decimal" }} @@ -147,11 +149,11 @@ func (b *builtin{{ .compare.CompareName }}{{ .type.TypeName }}Sig) vecEvalInt(in {{- if eq .type.ETName "Json" }} case types.CompareBinaryJSON(buf0.GetJSON(i), buf1.GetJSON(i)) == 0: {{- else if eq .type.ETName "Real" }} - case types.CompareFloat64(arg0[i], arg1[i]) == 0: + case cmp.Compare(arg0[i], arg1[i]) == 0: {{- else if eq .type.ETName "String" }} case types.CompareString(buf0.GetString(i), buf1.GetString(i), b.collation) == 0: {{- else if eq .type.ETName "Duration" }} - case types.CompareDuration(arg0[i], arg1[i]) == 0: + case cmp.Compare(arg0[i], arg1[i]) == 0: {{- else if eq .type.ETName "Datetime" }} case arg0[i].Compare(arg1[i]) == 0: {{- else if eq .type.ETName "Decimal" }} diff --git a/expression/generator/other_vec.go b/expression/generator/other_vec.go index 6a120b4c810ea..6c7943682b56c 100644 --- a/expression/generator/other_vec.go +++ b/expression/generator/other_vec.go @@ -50,6 +50,8 @@ package expression const newLine = "\n" const builtinOtherImports = `import ( + "cmp" + "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" @@ -106,11 +108,13 @@ var builtinInTmpl = template.Must(template.New("builtinInTmpl").Parse(` {{- else if eq .Input.TypeName "Time" -}} compareResult = arg0.Compare(arg1) {{- else if eq .Input.TypeName "Duration" -}} - compareResult = types.CompareDuration(arg0, arg1) + compareResult = cmp.Compare(arg0, arg1) {{- else if eq .Input.TypeName "JSON" -}} compareResult = types.CompareBinaryJSON(arg0, arg1) {{- else if eq .Input.TypeName "String" -}} compareResult = types.CompareString(arg0, arg1, b.collation) + {{- else if eq .Input.TypeNameInColumn "Float64" -}} + compareResult = cmp.Compare(arg0, arg1) {{- else -}} compareResult = types.Compare{{ .Input.TypeNameInColumn }}(arg0, arg1) {{- end -}} diff --git a/store/mockstore/unistore/cophandler/topn.go b/store/mockstore/unistore/cophandler/topn.go index bffa889ad4743..5992af0b270ec 100644 --- a/store/mockstore/unistore/cophandler/topn.go +++ b/store/mockstore/unistore/cophandler/topn.go @@ -15,6 +15,7 @@ package cophandler import ( + "cmp" "container/heap" "github.com/pingcap/errors" @@ -104,7 +105,7 @@ func (t *topNHeap) Less(i, j int) bool { var ret int var err error if expression.FieldTypeFromPB(by.GetExpr().GetFieldType()).GetType() == mysql.TypeEnum { - ret = types.CompareUint64(v1.GetUint64(), v2.GetUint64()) + ret = cmp.Compare(v1.GetUint64(), v2.GetUint64()) } else { ret, err = v1.Compare(t.sc, &v2, collate.GetCollator(collate.ProtoToCollation(by.Expr.FieldType.Collate))) if err != nil { diff --git a/types/compare.go b/types/compare.go index f43314c9185d3..c7a097241cf5e 100644 --- a/types/compare.go +++ b/types/compare.go @@ -16,33 +16,10 @@ package types import ( "math" - "time" "github.com/pingcap/tidb/util/collate" ) -// CompareInt64 returns an integer comparing the int64 x to y. -func CompareInt64(x, y int64) int { - if x < y { - return -1 - } else if x == y { - return 0 - } - - return 1 -} - -// CompareUint64 returns an integer comparing the uint64 x to y. -func CompareUint64(x, y uint64) int { - if x < y { - return -1 - } else if x == y { - return 0 - } - - return 1 -} - // VecCompareUU returns []int64 comparing the []uint64 x to []uint64 y func VecCompareUU(x, y []uint64, res []int64) { n := len(x) @@ -103,29 +80,7 @@ func VecCompareIU(x []int64, y []uint64, res []int64) { } } -// CompareFloat64 returns an integer comparing the float64 x to y. -func CompareFloat64(x, y float64) int { - if x < y { - return -1 - } else if x == y { - return 0 - } - - return 1 -} - // CompareString returns an integer comparing the string x to y with the specified collation and length. func CompareString(x, y, collation string) int { return collate.GetCollator(collation).Compare(x, y) } - -// CompareDuration returns an integer comparing the duration x to y. -func CompareDuration(x, y time.Duration) int { - if x < y { - return -1 - } else if x == y { - return 0 - } - - return 1 -} diff --git a/types/datum.go b/types/datum.go index 3c33f7aa24fdb..e4978293f79d2 100644 --- a/types/datum.go +++ b/types/datum.go @@ -15,6 +15,7 @@ package types import ( + "cmp" gjson "encoding/json" "fmt" "math" @@ -689,12 +690,12 @@ func (d *Datum) compareInt64(sc *stmtctx.StatementContext, i int64) (int, error) case KindMaxValue: return 1, nil case KindInt64: - return CompareInt64(d.i, i), nil + return cmp.Compare(d.i, i), nil case KindUint64: if i < 0 || d.GetUint64() > math.MaxInt64 { return 1, nil } - return CompareInt64(d.i, i), nil + return cmp.Compare(d.i, i), nil default: return d.compareFloat64(sc, float64(i)) } @@ -708,9 +709,9 @@ func (d *Datum) compareUint64(sc *stmtctx.StatementContext, u uint64) (int, erro if d.i < 0 || u > math.MaxInt64 { return -1, nil } - return CompareInt64(d.i, int64(u)), nil + return cmp.Compare(d.i, int64(u)), nil case KindUint64: - return CompareUint64(d.GetUint64(), u), nil + return cmp.Compare(d.GetUint64(), u), nil default: return d.compareFloat64(sc, float64(u)) } @@ -723,33 +724,33 @@ func (d *Datum) compareFloat64(sc *stmtctx.StatementContext, f float64) (int, er case KindMaxValue: return 1, nil case KindInt64: - return CompareFloat64(float64(d.i), f), nil + return cmp.Compare(float64(d.i), f), nil case KindUint64: - return CompareFloat64(float64(d.GetUint64()), f), nil + return cmp.Compare(float64(d.GetUint64()), f), nil case KindFloat32, KindFloat64: - return CompareFloat64(d.GetFloat64(), f), nil + return cmp.Compare(d.GetFloat64(), f), nil case KindString, KindBytes: fVal, err := StrToFloat(sc, d.GetString(), false) - return CompareFloat64(fVal, f), errors.Trace(err) + return cmp.Compare(fVal, f), errors.Trace(err) case KindMysqlDecimal: fVal, err := d.GetMysqlDecimal().ToFloat64() - return CompareFloat64(fVal, f), errors.Trace(err) + return cmp.Compare(fVal, f), errors.Trace(err) case KindMysqlDuration: fVal := d.GetMysqlDuration().Seconds() - return CompareFloat64(fVal, f), nil + return cmp.Compare(fVal, f), nil case KindMysqlEnum: fVal := d.GetMysqlEnum().ToNumber() - return CompareFloat64(fVal, f), nil + return cmp.Compare(fVal, f), nil case KindBinaryLiteral, KindMysqlBit: val, err := d.GetBinaryLiteral4Cmp().ToInt(sc) fVal := float64(val) - return CompareFloat64(fVal, f), errors.Trace(err) + return cmp.Compare(fVal, f), errors.Trace(err) case KindMysqlSet: fVal := d.GetMysqlSet().ToNumber() - return CompareFloat64(fVal, f), nil + return cmp.Compare(fVal, f), nil case KindMysqlTime: fVal, err := d.GetMysqlTime().ToNumber().ToFloat64() - return CompareFloat64(fVal, f), errors.Trace(err) + return cmp.Compare(fVal, f), errors.Trace(err) default: return -1, nil } diff --git a/util/chunk/compare.go b/util/chunk/compare.go index 546f73f8f1a1b..734953071ffbe 100644 --- a/util/chunk/compare.go +++ b/util/chunk/compare.go @@ -16,6 +16,7 @@ package chunk import ( "bytes" + "cmp" "sort" "github.com/pingcap/tidb/parser/mysql" @@ -71,7 +72,7 @@ func cmpInt64(l Row, lCol int, r Row, rCol int) int { if lNull || rNull { return cmpNull(lNull, rNull) } - return types.CompareInt64(l.GetInt64(lCol), r.GetInt64(rCol)) + return cmp.Compare(l.GetInt64(lCol), r.GetInt64(rCol)) } func cmpUint64(l Row, lCol int, r Row, rCol int) int { @@ -79,7 +80,7 @@ func cmpUint64(l Row, lCol int, r Row, rCol int) int { if lNull || rNull { return cmpNull(lNull, rNull) } - return types.CompareUint64(l.GetUint64(lCol), r.GetUint64(rCol)) + return cmp.Compare(l.GetUint64(lCol), r.GetUint64(rCol)) } func genCmpStringFunc(collation string) func(l Row, lCol int, r Row, rCol int) int { @@ -101,7 +102,7 @@ func cmpFloat32(l Row, lCol int, r Row, rCol int) int { if lNull || rNull { return cmpNull(lNull, rNull) } - return types.CompareFloat64(float64(l.GetFloat32(lCol)), float64(r.GetFloat32(rCol))) + return cmp.Compare(float64(l.GetFloat32(lCol)), float64(r.GetFloat32(rCol))) } func cmpFloat64(l Row, lCol int, r Row, rCol int) int { @@ -109,7 +110,7 @@ func cmpFloat64(l Row, lCol int, r Row, rCol int) int { if lNull || rNull { return cmpNull(lNull, rNull) } - return types.CompareFloat64(l.GetFloat64(lCol), r.GetFloat64(rCol)) + return cmp.Compare(l.GetFloat64(lCol), r.GetFloat64(rCol)) } func cmpMyDecimal(l Row, lCol int, r Row, rCol int) int { @@ -136,7 +137,7 @@ func cmpDuration(l Row, lCol int, r Row, rCol int) int { return cmpNull(lNull, rNull) } lDur, rDur := l.GetDuration(lCol, 0).Duration, r.GetDuration(rCol, 0).Duration - return types.CompareInt64(int64(lDur), int64(rDur)) + return cmp.Compare(int64(lDur), int64(rDur)) } func cmpNameValue(l Row, lCol int, r Row, rCol int) int { @@ -146,7 +147,7 @@ func cmpNameValue(l Row, lCol int, r Row, rCol int) int { } _, lVal := l.getNameValue(lCol) _, rVal := r.getNameValue(rCol) - return types.CompareUint64(lVal, rVal) + return cmp.Compare(lVal, rVal) } func cmpBit(l Row, lCol int, r Row, rCol int) int { @@ -185,13 +186,13 @@ func Compare(row Row, colIdx int, ad *types.Datum) int { case types.KindMaxValue: return -1 case types.KindInt64: - return types.CompareInt64(row.GetInt64(colIdx), ad.GetInt64()) + return cmp.Compare(row.GetInt64(colIdx), ad.GetInt64()) case types.KindUint64: - return types.CompareUint64(row.GetUint64(colIdx), ad.GetUint64()) + return cmp.Compare(row.GetUint64(colIdx), ad.GetUint64()) case types.KindFloat32: - return types.CompareFloat64(float64(row.GetFloat32(colIdx)), float64(ad.GetFloat32())) + return cmp.Compare(float64(row.GetFloat32(colIdx)), float64(ad.GetFloat32())) case types.KindFloat64: - return types.CompareFloat64(row.GetFloat64(colIdx), ad.GetFloat64()) + return cmp.Compare(row.GetFloat64(colIdx), ad.GetFloat64()) case types.KindString: return types.CompareString(row.GetString(colIdx), ad.GetString(), ad.Collation()) case types.KindBytes, types.KindBinaryLiteral, types.KindMysqlBit: @@ -201,13 +202,13 @@ func Compare(row Row, colIdx int, ad *types.Datum) int { return l.Compare(r) case types.KindMysqlDuration: l, r := row.GetDuration(colIdx, 0).Duration, ad.GetMysqlDuration().Duration - return types.CompareInt64(int64(l), int64(r)) + return cmp.Compare(int64(l), int64(r)) case types.KindMysqlEnum: l, r := row.GetEnum(colIdx).Value, ad.GetMysqlEnum().Value - return types.CompareUint64(l, r) + return cmp.Compare(l, r) case types.KindMysqlSet: l, r := row.GetSet(colIdx).Value, ad.GetMysqlSet().Value - return types.CompareUint64(l, r) + return cmp.Compare(l, r) case types.KindMysqlJSON: l, r := row.GetJSON(colIdx), ad.GetMysqlJSON() return types.CompareBinaryJSON(l, r) diff --git a/util/ranger/points.go b/util/ranger/points.go index e5061caa43446..75c5bfa363c8c 100644 --- a/util/ranger/points.go +++ b/util/ranger/points.go @@ -15,6 +15,7 @@ package ranger import ( + "cmp" "fmt" "math" "sort" @@ -115,7 +116,7 @@ func rangePointLess(sc *stmtctx.StatementContext, a, b *point, collator collate. } func rangePointEnumLess(_ *stmtctx.StatementContext, a, b *point) (bool, error) { - cmp := types.CompareInt64(a.value.GetInt64(), b.value.GetInt64()) + cmp := cmp.Compare(a.value.GetInt64(), b.value.GetInt64()) if cmp != 0 { return cmp < 0, nil }