From 1bca15efde32b2af586b4e1d22acca246805c47e Mon Sep 17 00:00:00 2001 From: Weizhen Wang Date: Tue, 5 Sep 2023 11:23:10 +0800 Subject: [PATCH 1/6] *: use cmp.Compare to replace CompareDuration Signed-off-by: Weizhen Wang --- expression/builtin_compare_vec_generated.go | 16 +++++++++------- expression/builtin_other_vec_generated.go | 4 +++- expression/generator/compare_vec.go | 6 ++++-- expression/generator/other_vec.go | 4 +++- types/compare.go | 12 ------------ 5 files changed, 19 insertions(+), 23 deletions(-) diff --git a/expression/builtin_compare_vec_generated.go b/expression/builtin_compare_vec_generated.go index e4dca22a50acc..6a9c3a325764e 100644 --- a/expression/builtin_compare_vec_generated.go +++ b/expression/builtin_compare_vec_generated.go @@ -17,6 +17,8 @@ package expression import ( + "cmp" + "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" ) @@ -199,7 +201,7 @@ func (b *builtinLTDurationSig) vecEvalInt(input *chunk.Chunk, result *chunk.Colu if result.IsNull(i) { continue } - val := types.CompareDuration(arg0[i], arg1[i]) + val := cmp.Compare(arg0[i], arg1[i]) i64s[i] = boolToInt64(val < 0) } return nil @@ -423,7 +425,7 @@ func (b *builtinLEDurationSig) vecEvalInt(input *chunk.Chunk, result *chunk.Colu if result.IsNull(i) { continue } - val := types.CompareDuration(arg0[i], arg1[i]) + val := cmp.Compare(arg0[i], arg1[i]) i64s[i] = boolToInt64(val <= 0) } return nil @@ -647,7 +649,7 @@ func (b *builtinGTDurationSig) vecEvalInt(input *chunk.Chunk, result *chunk.Colu if result.IsNull(i) { continue } - val := types.CompareDuration(arg0[i], arg1[i]) + val := cmp.Compare(arg0[i], arg1[i]) i64s[i] = boolToInt64(val > 0) } return nil @@ -871,7 +873,7 @@ func (b *builtinGEDurationSig) vecEvalInt(input *chunk.Chunk, result *chunk.Colu if result.IsNull(i) { continue } - val := types.CompareDuration(arg0[i], arg1[i]) + val := cmp.Compare(arg0[i], arg1[i]) i64s[i] = boolToInt64(val >= 0) } return nil @@ -1095,7 +1097,7 @@ func (b *builtinEQDurationSig) vecEvalInt(input *chunk.Chunk, result *chunk.Colu if result.IsNull(i) { continue } - val := types.CompareDuration(arg0[i], arg1[i]) + val := cmp.Compare(arg0[i], arg1[i]) i64s[i] = boolToInt64(val == 0) } return nil @@ -1319,7 +1321,7 @@ func (b *builtinNEDurationSig) vecEvalInt(input *chunk.Chunk, result *chunk.Colu if result.IsNull(i) { continue } - val := types.CompareDuration(arg0[i], arg1[i]) + val := cmp.Compare(arg0[i], arg1[i]) i64s[i] = boolToInt64(val != 0) } return nil @@ -1562,7 +1564,7 @@ func (b *builtinNullEQDurationSig) vecEvalInt(input *chunk.Chunk, result *chunk. i64s[i] = 1 case isNull0 != isNull1: i64s[i] = 0 - case types.CompareDuration(arg0[i], arg1[i]) == 0: + case cmp.Compare(arg0[i], arg1[i]) == 0: i64s[i] = 1 } } diff --git a/expression/builtin_other_vec_generated.go b/expression/builtin_other_vec_generated.go index 0033e4f4baeb5..a4aa8ccd38c55 100644 --- a/expression/builtin_other_vec_generated.go +++ b/expression/builtin_other_vec_generated.go @@ -17,6 +17,8 @@ package expression import ( + "cmp" + "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" @@ -529,7 +531,7 @@ func (b *builtinInDurationSig) vecEvalInt(input *chunk.Chunk, result *chunk.Colu } arg0 := args0[i] arg1 := args1[i] - compareResult = types.CompareDuration(arg0, arg1) + compareResult = cmp.Compare(arg0, arg1) if compareResult == 0 { result.SetNull(i, false) r64s[i] = 1 diff --git a/expression/generator/compare_vec.go b/expression/generator/compare_vec.go index 423a0949deea5..6ee607e6a0f97 100644 --- a/expression/generator/compare_vec.go +++ b/expression/generator/compare_vec.go @@ -51,6 +51,8 @@ package expression const newLine = "\n" const builtinCompareImports = `import ( + "cmp" + "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" ) @@ -94,7 +96,7 @@ func (b *builtin{{ .compare.CompareName }}{{ .type.TypeName }}Sig) vecEvalInt(in {{- else if eq .type.ETName "String" }} val := types.CompareString(buf0.GetString(i), buf1.GetString(i), b.collation) {{- else if eq .type.ETName "Duration" }} - val := types.CompareDuration(arg0[i], arg1[i]) + val := cmp.Compare(arg0[i], arg1[i]) {{- else if eq .type.ETName "Datetime" }} val := arg0[i].Compare(arg1[i]) {{- else if eq .type.ETName "Decimal" }} @@ -151,7 +153,7 @@ func (b *builtin{{ .compare.CompareName }}{{ .type.TypeName }}Sig) vecEvalInt(in {{- else if eq .type.ETName "String" }} case types.CompareString(buf0.GetString(i), buf1.GetString(i), b.collation) == 0: {{- else if eq .type.ETName "Duration" }} - case types.CompareDuration(arg0[i], arg1[i]) == 0: + case cmp.Compare(arg0[i], arg1[i]) == 0: {{- else if eq .type.ETName "Datetime" }} case arg0[i].Compare(arg1[i]) == 0: {{- else if eq .type.ETName "Decimal" }} diff --git a/expression/generator/other_vec.go b/expression/generator/other_vec.go index 6a120b4c810ea..6f8c1679feb10 100644 --- a/expression/generator/other_vec.go +++ b/expression/generator/other_vec.go @@ -50,6 +50,8 @@ package expression const newLine = "\n" const builtinOtherImports = `import ( + "cmp" + "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" @@ -106,7 +108,7 @@ var builtinInTmpl = template.Must(template.New("builtinInTmpl").Parse(` {{- else if eq .Input.TypeName "Time" -}} compareResult = arg0.Compare(arg1) {{- else if eq .Input.TypeName "Duration" -}} - compareResult = types.CompareDuration(arg0, arg1) + compareResult = cmp.Compare(arg0, arg1) {{- else if eq .Input.TypeName "JSON" -}} compareResult = types.CompareBinaryJSON(arg0, arg1) {{- else if eq .Input.TypeName "String" -}} diff --git a/types/compare.go b/types/compare.go index f43314c9185d3..4fcbfc767026b 100644 --- a/types/compare.go +++ b/types/compare.go @@ -16,7 +16,6 @@ package types import ( "math" - "time" "github.com/pingcap/tidb/util/collate" ) @@ -118,14 +117,3 @@ func CompareFloat64(x, y float64) int { func CompareString(x, y, collation string) int { return collate.GetCollator(collation).Compare(x, y) } - -// CompareDuration returns an integer comparing the duration x to y. -func CompareDuration(x, y time.Duration) int { - if x < y { - return -1 - } else if x == y { - return 0 - } - - return 1 -} From d56f20c6812e4c9de80e234a7b0c482d157da403 Mon Sep 17 00:00:00 2001 From: Weizhen Wang Date: Tue, 5 Sep 2023 11:29:41 +0800 Subject: [PATCH 2/6] *: use cmp.Compare to replace CompareDuration Signed-off-by: Weizhen Wang --- executor/aggfuncs/func_max_min.go | 5 +++-- expression/builtin_compare.go | 4 ++-- expression/builtin_compare_vec_generated.go | 14 +++++++------- expression/builtin_other_vec_generated.go | 2 +- expression/generator/compare_vec.go | 4 ++-- types/compare.go | 11 ----------- util/chunk/compare.go | 9 +++++---- 7 files changed, 20 insertions(+), 29 deletions(-) diff --git a/executor/aggfuncs/func_max_min.go b/executor/aggfuncs/func_max_min.go index 7f7f302bbeda2..20f5a8970a8cb 100644 --- a/executor/aggfuncs/func_max_min.go +++ b/executor/aggfuncs/func_max_min.go @@ -15,6 +15,7 @@ package aggfuncs import ( + "cmp" "unsafe" "github.com/pingcap/errors" @@ -589,7 +590,7 @@ type maxMin4Float32Sliding struct { func (e *maxMin4Float32Sliding) AllocPartialResult() (pr PartialResult, memDelta int64) { p, memDelta := e.maxMin4Float32.AllocPartialResult() (*partialResult4MaxMinFloat32)(p).deque = NewDeque(e.isMax, func(i, j interface{}) int { - return types.CompareFloat64(float64(i.(float32)), float64(j.(float32))) + return cmp.Compare(float64(i.(float32)), float64(j.(float32))) }) return p, memDelta + DefMaxMinDequeSize } @@ -726,7 +727,7 @@ type maxMin4Float64Sliding struct { func (e *maxMin4Float64Sliding) AllocPartialResult() (pr PartialResult, memDelta int64) { p, memDelta := e.maxMin4Float64.AllocPartialResult() (*partialResult4MaxMinFloat64)(p).deque = NewDeque(e.isMax, func(i, j interface{}) int { - return types.CompareFloat64(i.(float64), j.(float64)) + return cmp.Compare(i.(float64), j.(float64)) }) return p, memDelta + DefMaxMinDequeSize } diff --git a/expression/builtin_compare.go b/expression/builtin_compare.go index 0f894ee728f3d..8af25b064f50c 100644 --- a/expression/builtin_compare.go +++ b/expression/builtin_compare.go @@ -2674,7 +2674,7 @@ func (b *builtinNullEQRealSig) evalInt(row chunk.Row) (val int64, isNull bool, e res = 1 case isNull0 != isNull1: return res, false, nil - case types.CompareFloat64(arg0, arg1) == 0: + case cmp.Compare(arg0, arg1) == 0: res = 1 } return res, false, nil @@ -3006,7 +3006,7 @@ func CompareReal(sctx sessionctx.Context, lhsArg, rhsArg Expression, lhsRow, rhs if isNull0 || isNull1 { return compareNull(isNull0, isNull1), true, nil } - return int64(types.CompareFloat64(arg0, arg1)), false, nil + return int64(cmp.Compare(arg0, arg1)), false, nil } // CompareDecimal compares two decimals. diff --git a/expression/builtin_compare_vec_generated.go b/expression/builtin_compare_vec_generated.go index 6a9c3a325764e..5b53f65229c31 100644 --- a/expression/builtin_compare_vec_generated.go +++ b/expression/builtin_compare_vec_generated.go @@ -51,7 +51,7 @@ func (b *builtinLTRealSig) vecEvalInt(input *chunk.Chunk, result *chunk.Column) if result.IsNull(i) { continue } - val := types.CompareFloat64(arg0[i], arg1[i]) + val := cmp.Compare(arg0[i], arg1[i]) i64s[i] = boolToInt64(val < 0) } return nil @@ -275,7 +275,7 @@ func (b *builtinLERealSig) vecEvalInt(input *chunk.Chunk, result *chunk.Column) if result.IsNull(i) { continue } - val := types.CompareFloat64(arg0[i], arg1[i]) + val := cmp.Compare(arg0[i], arg1[i]) i64s[i] = boolToInt64(val <= 0) } return nil @@ -499,7 +499,7 @@ func (b *builtinGTRealSig) vecEvalInt(input *chunk.Chunk, result *chunk.Column) if result.IsNull(i) { continue } - val := types.CompareFloat64(arg0[i], arg1[i]) + val := cmp.Compare(arg0[i], arg1[i]) i64s[i] = boolToInt64(val > 0) } return nil @@ -723,7 +723,7 @@ func (b *builtinGERealSig) vecEvalInt(input *chunk.Chunk, result *chunk.Column) if result.IsNull(i) { continue } - val := types.CompareFloat64(arg0[i], arg1[i]) + val := cmp.Compare(arg0[i], arg1[i]) i64s[i] = boolToInt64(val >= 0) } return nil @@ -947,7 +947,7 @@ func (b *builtinEQRealSig) vecEvalInt(input *chunk.Chunk, result *chunk.Column) if result.IsNull(i) { continue } - val := types.CompareFloat64(arg0[i], arg1[i]) + val := cmp.Compare(arg0[i], arg1[i]) i64s[i] = boolToInt64(val == 0) } return nil @@ -1171,7 +1171,7 @@ func (b *builtinNERealSig) vecEvalInt(input *chunk.Chunk, result *chunk.Column) if result.IsNull(i) { continue } - val := types.CompareFloat64(arg0[i], arg1[i]) + val := cmp.Compare(arg0[i], arg1[i]) i64s[i] = boolToInt64(val != 0) } return nil @@ -1398,7 +1398,7 @@ func (b *builtinNullEQRealSig) vecEvalInt(input *chunk.Chunk, result *chunk.Colu i64s[i] = 1 case isNull0 != isNull1: i64s[i] = 0 - case types.CompareFloat64(arg0[i], arg1[i]) == 0: + case cmp.Compare(arg0[i], arg1[i]) == 0: i64s[i] = 1 } } diff --git a/expression/builtin_other_vec_generated.go b/expression/builtin_other_vec_generated.go index a4aa8ccd38c55..25f872b26e40e 100644 --- a/expression/builtin_other_vec_generated.go +++ b/expression/builtin_other_vec_generated.go @@ -365,7 +365,7 @@ func (b *builtinInRealSig) vecEvalInt(input *chunk.Chunk, result *chunk.Column) } arg0 := args0[i] arg1 := args1[i] - compareResult = types.CompareFloat64(arg0, arg1) + compareResult = cmp.Compare(arg0, arg1) if compareResult == 0 { result.SetNull(i, false) r64s[i] = 1 diff --git a/expression/generator/compare_vec.go b/expression/generator/compare_vec.go index 6ee607e6a0f97..4f433330a12a5 100644 --- a/expression/generator/compare_vec.go +++ b/expression/generator/compare_vec.go @@ -92,7 +92,7 @@ func (b *builtin{{ .compare.CompareName }}{{ .type.TypeName }}Sig) vecEvalInt(in {{- if eq .type.ETName "Json" }} val := types.CompareBinaryJSON(buf0.GetJSON(i), buf1.GetJSON(i)) {{- else if eq .type.ETName "Real" }} - val := types.CompareFloat64(arg0[i], arg1[i]) + val := cmp.Compare(arg0[i], arg1[i]) {{- else if eq .type.ETName "String" }} val := types.CompareString(buf0.GetString(i), buf1.GetString(i), b.collation) {{- else if eq .type.ETName "Duration" }} @@ -149,7 +149,7 @@ func (b *builtin{{ .compare.CompareName }}{{ .type.TypeName }}Sig) vecEvalInt(in {{- if eq .type.ETName "Json" }} case types.CompareBinaryJSON(buf0.GetJSON(i), buf1.GetJSON(i)) == 0: {{- else if eq .type.ETName "Real" }} - case types.CompareFloat64(arg0[i], arg1[i]) == 0: + case cmp.Compare(arg0[i], arg1[i]) == 0: {{- else if eq .type.ETName "String" }} case types.CompareString(buf0.GetString(i), buf1.GetString(i), b.collation) == 0: {{- else if eq .type.ETName "Duration" }} diff --git a/types/compare.go b/types/compare.go index 4fcbfc767026b..3760f7c6ef074 100644 --- a/types/compare.go +++ b/types/compare.go @@ -102,17 +102,6 @@ func VecCompareIU(x []int64, y []uint64, res []int64) { } } -// CompareFloat64 returns an integer comparing the float64 x to y. -func CompareFloat64(x, y float64) int { - if x < y { - return -1 - } else if x == y { - return 0 - } - - return 1 -} - // CompareString returns an integer comparing the string x to y with the specified collation and length. func CompareString(x, y, collation string) int { return collate.GetCollator(collation).Compare(x, y) diff --git a/util/chunk/compare.go b/util/chunk/compare.go index 546f73f8f1a1b..9e255a14e080a 100644 --- a/util/chunk/compare.go +++ b/util/chunk/compare.go @@ -16,6 +16,7 @@ package chunk import ( "bytes" + "cmp" "sort" "github.com/pingcap/tidb/parser/mysql" @@ -101,7 +102,7 @@ func cmpFloat32(l Row, lCol int, r Row, rCol int) int { if lNull || rNull { return cmpNull(lNull, rNull) } - return types.CompareFloat64(float64(l.GetFloat32(lCol)), float64(r.GetFloat32(rCol))) + return cmp.Compare(float64(l.GetFloat32(lCol)), float64(r.GetFloat32(rCol))) } func cmpFloat64(l Row, lCol int, r Row, rCol int) int { @@ -109,7 +110,7 @@ func cmpFloat64(l Row, lCol int, r Row, rCol int) int { if lNull || rNull { return cmpNull(lNull, rNull) } - return types.CompareFloat64(l.GetFloat64(lCol), r.GetFloat64(rCol)) + return cmp.Compare(l.GetFloat64(lCol), r.GetFloat64(rCol)) } func cmpMyDecimal(l Row, lCol int, r Row, rCol int) int { @@ -189,9 +190,9 @@ func Compare(row Row, colIdx int, ad *types.Datum) int { case types.KindUint64: return types.CompareUint64(row.GetUint64(colIdx), ad.GetUint64()) case types.KindFloat32: - return types.CompareFloat64(float64(row.GetFloat32(colIdx)), float64(ad.GetFloat32())) + return cmp.Compare(float64(row.GetFloat32(colIdx)), float64(ad.GetFloat32())) case types.KindFloat64: - return types.CompareFloat64(row.GetFloat64(colIdx), ad.GetFloat64()) + return cmp.Compare(row.GetFloat64(colIdx), ad.GetFloat64()) case types.KindString: return types.CompareString(row.GetString(colIdx), ad.GetString(), ad.Collation()) case types.KindBytes, types.KindBinaryLiteral, types.KindMysqlBit: From f3bd4b4cbd6183a042557bba8b29afc198fd873d Mon Sep 17 00:00:00 2001 From: Weizhen Wang Date: Tue, 5 Sep 2023 11:41:42 +0800 Subject: [PATCH 3/6] *: use cmp.Compare to replace CompareDuration Signed-off-by: Weizhen Wang --- executor/aggfuncs/func_max_min.go | 4 +-- executor/aggfuncs/func_max_min_test.go | 4 +-- expression/builtin_compare.go | 16 ++++++------ store/mockstore/unistore/cophandler/topn.go | 2 +- types/compare.go | 22 ---------------- types/datum.go | 29 +++++++++++---------- util/chunk/compare.go | 18 ++++++------- util/ranger/points.go | 2 +- 8 files changed, 38 insertions(+), 59 deletions(-) diff --git a/executor/aggfuncs/func_max_min.go b/executor/aggfuncs/func_max_min.go index 20f5a8970a8cb..a9bc2b6b69692 100644 --- a/executor/aggfuncs/func_max_min.go +++ b/executor/aggfuncs/func_max_min.go @@ -314,7 +314,7 @@ func (e *maxMin4IntSliding) ResetPartialResult(pr PartialResult) { func (e *maxMin4IntSliding) AllocPartialResult() (pr PartialResult, memDelta int64) { p, memDelta := e.maxMin4Int.AllocPartialResult() (*partialResult4MaxMinInt)(p).deque = NewDeque(e.isMax, func(i, j interface{}) int { - return types.CompareInt64(i.(int64), j.(int64)) + return cmp.Compare(i.(int64), j.(int64)) }) return p, memDelta + DefMaxMinDequeSize } @@ -451,7 +451,7 @@ type maxMin4UintSliding struct { func (e *maxMin4UintSliding) AllocPartialResult() (pr PartialResult, memDelta int64) { p, memDelta := e.maxMin4Uint.AllocPartialResult() (*partialResult4MaxMinUint)(p).deque = NewDeque(e.isMax, func(i, j interface{}) int { - return types.CompareUint64(i.(uint64), j.(uint64)) + return cmp.Compare(i.(uint64), j.(uint64)) }) return p, memDelta + DefMaxMinDequeSize } diff --git a/executor/aggfuncs/func_max_min_test.go b/executor/aggfuncs/func_max_min_test.go index 2505b186fe61b..319c69517d4f2 100644 --- a/executor/aggfuncs/func_max_min_test.go +++ b/executor/aggfuncs/func_max_min_test.go @@ -336,7 +336,7 @@ func TestMaxSlidingWindow(t *testing.T) { func TestDequeReset(t *testing.T) { deque := aggfuncs.NewDeque(true, func(i, j interface{}) int { - return types.CompareInt64(i.(int64), j.(int64)) + return cmp.Compare(i.(int64), j.(int64)) }) deque.PushBack(0, 12) deque.Reset() @@ -346,7 +346,7 @@ func TestDequeReset(t *testing.T) { func TestDequePushPop(t *testing.T) { deque := aggfuncs.NewDeque(true, func(i, j interface{}) int { - return types.CompareInt64(i.(int64), j.(int64)) + return cmp.Compare(i.(int64), j.(int64)) }) times := 15 // pushes element from back of deque diff --git a/expression/builtin_compare.go b/expression/builtin_compare.go index 8af25b064f50c..3a072af14ee2d 100644 --- a/expression/builtin_compare.go +++ b/expression/builtin_compare.go @@ -2627,22 +2627,22 @@ func (b *builtinNullEQIntSig) evalInt(row chunk.Row) (val int64, isNull bool, er res = 1 case isNull0 != isNull1: return res, false, nil - case isUnsigned0 && isUnsigned1 && types.CompareUint64(uint64(arg0), uint64(arg1)) == 0: + case isUnsigned0 && isUnsigned1 && cmp.Compare(uint64(arg0), uint64(arg1)) == 0: res = 1 - case !isUnsigned0 && !isUnsigned1 && types.CompareInt64(arg0, arg1) == 0: + case !isUnsigned0 && !isUnsigned1 && cmp.Compare(arg0, arg1) == 0: res = 1 case isUnsigned0 && !isUnsigned1: if arg1 < 0 { return res, false, nil } - if types.CompareInt64(arg0, arg1) == 0 { + if cmp.Compare(arg0, arg1) == 0 { res = 1 } case !isUnsigned0 && isUnsigned1: if arg0 < 0 { return res, false, nil } - if types.CompareInt64(arg0, arg1) == 0 { + if cmp.Compare(arg0, arg1) == 0 { res = 1 } } @@ -2948,21 +2948,21 @@ func CompareInt(sctx sessionctx.Context, lhsArg, rhsArg Expression, lhsRow, rhsR var res int switch { case isUnsigned0 && isUnsigned1: - res = types.CompareUint64(uint64(arg0), uint64(arg1)) + res = cmp.Compare(uint64(arg0), uint64(arg1)) case isUnsigned0 && !isUnsigned1: if arg1 < 0 || uint64(arg0) > math.MaxInt64 { res = 1 } else { - res = types.CompareInt64(arg0, arg1) + res = cmp.Compare(arg0, arg1) } case !isUnsigned0 && isUnsigned1: if arg0 < 0 || uint64(arg1) > math.MaxInt64 { res = -1 } else { - res = types.CompareInt64(arg0, arg1) + res = cmp.Compare(arg0, arg1) } case !isUnsigned0 && !isUnsigned1: - res = types.CompareInt64(arg0, arg1) + res = cmp.Compare(arg0, arg1) } return int64(res), false, nil } diff --git a/store/mockstore/unistore/cophandler/topn.go b/store/mockstore/unistore/cophandler/topn.go index bffa889ad4743..aee5e40f1a8c5 100644 --- a/store/mockstore/unistore/cophandler/topn.go +++ b/store/mockstore/unistore/cophandler/topn.go @@ -104,7 +104,7 @@ func (t *topNHeap) Less(i, j int) bool { var ret int var err error if expression.FieldTypeFromPB(by.GetExpr().GetFieldType()).GetType() == mysql.TypeEnum { - ret = types.CompareUint64(v1.GetUint64(), v2.GetUint64()) + ret = cmp.Compare(v1.GetUint64(), v2.GetUint64()) } else { ret, err = v1.Compare(t.sc, &v2, collate.GetCollator(collate.ProtoToCollation(by.Expr.FieldType.Collate))) if err != nil { diff --git a/types/compare.go b/types/compare.go index 3760f7c6ef074..c7a097241cf5e 100644 --- a/types/compare.go +++ b/types/compare.go @@ -20,28 +20,6 @@ import ( "github.com/pingcap/tidb/util/collate" ) -// CompareInt64 returns an integer comparing the int64 x to y. -func CompareInt64(x, y int64) int { - if x < y { - return -1 - } else if x == y { - return 0 - } - - return 1 -} - -// CompareUint64 returns an integer comparing the uint64 x to y. -func CompareUint64(x, y uint64) int { - if x < y { - return -1 - } else if x == y { - return 0 - } - - return 1 -} - // VecCompareUU returns []int64 comparing the []uint64 x to []uint64 y func VecCompareUU(x, y []uint64, res []int64) { n := len(x) diff --git a/types/datum.go b/types/datum.go index 3c33f7aa24fdb..e4978293f79d2 100644 --- a/types/datum.go +++ b/types/datum.go @@ -15,6 +15,7 @@ package types import ( + "cmp" gjson "encoding/json" "fmt" "math" @@ -689,12 +690,12 @@ func (d *Datum) compareInt64(sc *stmtctx.StatementContext, i int64) (int, error) case KindMaxValue: return 1, nil case KindInt64: - return CompareInt64(d.i, i), nil + return cmp.Compare(d.i, i), nil case KindUint64: if i < 0 || d.GetUint64() > math.MaxInt64 { return 1, nil } - return CompareInt64(d.i, i), nil + return cmp.Compare(d.i, i), nil default: return d.compareFloat64(sc, float64(i)) } @@ -708,9 +709,9 @@ func (d *Datum) compareUint64(sc *stmtctx.StatementContext, u uint64) (int, erro if d.i < 0 || u > math.MaxInt64 { return -1, nil } - return CompareInt64(d.i, int64(u)), nil + return cmp.Compare(d.i, int64(u)), nil case KindUint64: - return CompareUint64(d.GetUint64(), u), nil + return cmp.Compare(d.GetUint64(), u), nil default: return d.compareFloat64(sc, float64(u)) } @@ -723,33 +724,33 @@ func (d *Datum) compareFloat64(sc *stmtctx.StatementContext, f float64) (int, er case KindMaxValue: return 1, nil case KindInt64: - return CompareFloat64(float64(d.i), f), nil + return cmp.Compare(float64(d.i), f), nil case KindUint64: - return CompareFloat64(float64(d.GetUint64()), f), nil + return cmp.Compare(float64(d.GetUint64()), f), nil case KindFloat32, KindFloat64: - return CompareFloat64(d.GetFloat64(), f), nil + return cmp.Compare(d.GetFloat64(), f), nil case KindString, KindBytes: fVal, err := StrToFloat(sc, d.GetString(), false) - return CompareFloat64(fVal, f), errors.Trace(err) + return cmp.Compare(fVal, f), errors.Trace(err) case KindMysqlDecimal: fVal, err := d.GetMysqlDecimal().ToFloat64() - return CompareFloat64(fVal, f), errors.Trace(err) + return cmp.Compare(fVal, f), errors.Trace(err) case KindMysqlDuration: fVal := d.GetMysqlDuration().Seconds() - return CompareFloat64(fVal, f), nil + return cmp.Compare(fVal, f), nil case KindMysqlEnum: fVal := d.GetMysqlEnum().ToNumber() - return CompareFloat64(fVal, f), nil + return cmp.Compare(fVal, f), nil case KindBinaryLiteral, KindMysqlBit: val, err := d.GetBinaryLiteral4Cmp().ToInt(sc) fVal := float64(val) - return CompareFloat64(fVal, f), errors.Trace(err) + return cmp.Compare(fVal, f), errors.Trace(err) case KindMysqlSet: fVal := d.GetMysqlSet().ToNumber() - return CompareFloat64(fVal, f), nil + return cmp.Compare(fVal, f), nil case KindMysqlTime: fVal, err := d.GetMysqlTime().ToNumber().ToFloat64() - return CompareFloat64(fVal, f), errors.Trace(err) + return cmp.Compare(fVal, f), errors.Trace(err) default: return -1, nil } diff --git a/util/chunk/compare.go b/util/chunk/compare.go index 9e255a14e080a..734953071ffbe 100644 --- a/util/chunk/compare.go +++ b/util/chunk/compare.go @@ -72,7 +72,7 @@ func cmpInt64(l Row, lCol int, r Row, rCol int) int { if lNull || rNull { return cmpNull(lNull, rNull) } - return types.CompareInt64(l.GetInt64(lCol), r.GetInt64(rCol)) + return cmp.Compare(l.GetInt64(lCol), r.GetInt64(rCol)) } func cmpUint64(l Row, lCol int, r Row, rCol int) int { @@ -80,7 +80,7 @@ func cmpUint64(l Row, lCol int, r Row, rCol int) int { if lNull || rNull { return cmpNull(lNull, rNull) } - return types.CompareUint64(l.GetUint64(lCol), r.GetUint64(rCol)) + return cmp.Compare(l.GetUint64(lCol), r.GetUint64(rCol)) } func genCmpStringFunc(collation string) func(l Row, lCol int, r Row, rCol int) int { @@ -137,7 +137,7 @@ func cmpDuration(l Row, lCol int, r Row, rCol int) int { return cmpNull(lNull, rNull) } lDur, rDur := l.GetDuration(lCol, 0).Duration, r.GetDuration(rCol, 0).Duration - return types.CompareInt64(int64(lDur), int64(rDur)) + return cmp.Compare(int64(lDur), int64(rDur)) } func cmpNameValue(l Row, lCol int, r Row, rCol int) int { @@ -147,7 +147,7 @@ func cmpNameValue(l Row, lCol int, r Row, rCol int) int { } _, lVal := l.getNameValue(lCol) _, rVal := r.getNameValue(rCol) - return types.CompareUint64(lVal, rVal) + return cmp.Compare(lVal, rVal) } func cmpBit(l Row, lCol int, r Row, rCol int) int { @@ -186,9 +186,9 @@ func Compare(row Row, colIdx int, ad *types.Datum) int { case types.KindMaxValue: return -1 case types.KindInt64: - return types.CompareInt64(row.GetInt64(colIdx), ad.GetInt64()) + return cmp.Compare(row.GetInt64(colIdx), ad.GetInt64()) case types.KindUint64: - return types.CompareUint64(row.GetUint64(colIdx), ad.GetUint64()) + return cmp.Compare(row.GetUint64(colIdx), ad.GetUint64()) case types.KindFloat32: return cmp.Compare(float64(row.GetFloat32(colIdx)), float64(ad.GetFloat32())) case types.KindFloat64: @@ -202,13 +202,13 @@ func Compare(row Row, colIdx int, ad *types.Datum) int { return l.Compare(r) case types.KindMysqlDuration: l, r := row.GetDuration(colIdx, 0).Duration, ad.GetMysqlDuration().Duration - return types.CompareInt64(int64(l), int64(r)) + return cmp.Compare(int64(l), int64(r)) case types.KindMysqlEnum: l, r := row.GetEnum(colIdx).Value, ad.GetMysqlEnum().Value - return types.CompareUint64(l, r) + return cmp.Compare(l, r) case types.KindMysqlSet: l, r := row.GetSet(colIdx).Value, ad.GetMysqlSet().Value - return types.CompareUint64(l, r) + return cmp.Compare(l, r) case types.KindMysqlJSON: l, r := row.GetJSON(colIdx), ad.GetMysqlJSON() return types.CompareBinaryJSON(l, r) diff --git a/util/ranger/points.go b/util/ranger/points.go index e5061caa43446..0f163008b04c5 100644 --- a/util/ranger/points.go +++ b/util/ranger/points.go @@ -115,7 +115,7 @@ func rangePointLess(sc *stmtctx.StatementContext, a, b *point, collator collate. } func rangePointEnumLess(_ *stmtctx.StatementContext, a, b *point) (bool, error) { - cmp := types.CompareInt64(a.value.GetInt64(), b.value.GetInt64()) + cmp := cmp.Compare(a.value.GetInt64(), b.value.GetInt64()) if cmp != 0 { return cmp < 0, nil } From f48395b426a36a95e63d1508738ab5ffe725c2eb Mon Sep 17 00:00:00 2001 From: Weizhen Wang Date: Tue, 5 Sep 2023 11:52:07 +0800 Subject: [PATCH 4/6] update Signed-off-by: Weizhen Wang --- expression/builtin_compare.go | 1 + expression/generator/other_vec.go | 2 ++ store/mockstore/unistore/cophandler/topn.go | 1 + util/ranger/points.go | 1 + 4 files changed, 5 insertions(+) diff --git a/expression/builtin_compare.go b/expression/builtin_compare.go index 3a072af14ee2d..52f845bc35370 100644 --- a/expression/builtin_compare.go +++ b/expression/builtin_compare.go @@ -15,6 +15,7 @@ package expression import ( + "cmp" "math" "strings" diff --git a/expression/generator/other_vec.go b/expression/generator/other_vec.go index 6f8c1679feb10..6c7943682b56c 100644 --- a/expression/generator/other_vec.go +++ b/expression/generator/other_vec.go @@ -113,6 +113,8 @@ var builtinInTmpl = template.Must(template.New("builtinInTmpl").Parse(` compareResult = types.CompareBinaryJSON(arg0, arg1) {{- else if eq .Input.TypeName "String" -}} compareResult = types.CompareString(arg0, arg1, b.collation) + {{- else if eq .Input.TypeNameInColumn "Float64" -}} + compareResult = cmp.Compare(arg0, arg1) {{- else -}} compareResult = types.Compare{{ .Input.TypeNameInColumn }}(arg0, arg1) {{- end -}} diff --git a/store/mockstore/unistore/cophandler/topn.go b/store/mockstore/unistore/cophandler/topn.go index aee5e40f1a8c5..5992af0b270ec 100644 --- a/store/mockstore/unistore/cophandler/topn.go +++ b/store/mockstore/unistore/cophandler/topn.go @@ -15,6 +15,7 @@ package cophandler import ( + "cmp" "container/heap" "github.com/pingcap/errors" diff --git a/util/ranger/points.go b/util/ranger/points.go index 0f163008b04c5..75c5bfa363c8c 100644 --- a/util/ranger/points.go +++ b/util/ranger/points.go @@ -15,6 +15,7 @@ package ranger import ( + "cmp" "fmt" "math" "sort" From 54bd645a81d363ffa43a2c59ceb596f5beb717b7 Mon Sep 17 00:00:00 2001 From: Weizhen Wang Date: Tue, 5 Sep 2023 12:10:35 +0800 Subject: [PATCH 5/6] update Signed-off-by: Weizhen Wang --- executor/aggfuncs/func_max_min_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/executor/aggfuncs/func_max_min_test.go b/executor/aggfuncs/func_max_min_test.go index 319c69517d4f2..4eff13e3ceabf 100644 --- a/executor/aggfuncs/func_max_min_test.go +++ b/executor/aggfuncs/func_max_min_test.go @@ -15,6 +15,7 @@ package aggfuncs_test import ( + "cmp" "fmt" "testing" "time" From 7e2f78141f91c5525e8f6e723f520cf41cf8cf09 Mon Sep 17 00:00:00 2001 From: Weizhen Wang Date: Tue, 5 Sep 2023 12:26:10 +0800 Subject: [PATCH 6/6] update Signed-off-by: Weizhen Wang --- WORKSPACE | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/WORKSPACE b/WORKSPACE index a83727b55825b..a360d9d997bc6 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -110,3 +110,14 @@ http_archive( "https://github.com/bazelbuild/java_tools/releases/download/java_v12.6/java_tools_linux-v12.6.zip", ], ) + +http_archive( + name = "rules_proto", + strip_prefix = "rules_proto-40298556293ae502c66579620a7ce867d5f57311", + urls = [ + "http://bazel-cache.pingcap.net:8080/bazelbuild/rules_proto/archive/40298556293ae502c66579620a7ce867d5f57311.tar.gz", + "http://ats.apps.svc/bazelbuild/rules_proto/archive/40298556293ae502c66579620a7ce867d5f57311.tar.gz", + "https://mirror.bazel.build/github.com/bazelbuild/rules_proto/archive/40298556293ae502c66579620a7ce867d5f57311.tar.gz", + "https://github.com/bazelbuild/rules_proto/archive/40298556293ae502c66579620a7ce867d5f57311.tar.gz", + ], +)