Skip to content

Commit

Permalink
*: use cmp.Compare to replace types.Compare (#46657)
Browse files Browse the repository at this point in the history
ref #45933
  • Loading branch information
hawkingrei authored Sep 5, 2023
1 parent 5d4cea5 commit 5dd4296
Show file tree
Hide file tree
Showing 12 changed files with 83 additions and 111 deletions.
9 changes: 5 additions & 4 deletions executor/aggfuncs/func_max_min.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package aggfuncs

import (
"cmp"
"unsafe"

"github.com/pingcap/errors"
Expand Down Expand Up @@ -313,7 +314,7 @@ func (e *maxMin4IntSliding) ResetPartialResult(pr PartialResult) {
func (e *maxMin4IntSliding) AllocPartialResult() (pr PartialResult, memDelta int64) {
p, memDelta := e.maxMin4Int.AllocPartialResult()
(*partialResult4MaxMinInt)(p).deque = NewDeque(e.isMax, func(i, j interface{}) int {
return types.CompareInt64(i.(int64), j.(int64))
return cmp.Compare(i.(int64), j.(int64))
})
return p, memDelta + DefMaxMinDequeSize
}
Expand Down Expand Up @@ -450,7 +451,7 @@ type maxMin4UintSliding struct {
func (e *maxMin4UintSliding) AllocPartialResult() (pr PartialResult, memDelta int64) {
p, memDelta := e.maxMin4Uint.AllocPartialResult()
(*partialResult4MaxMinUint)(p).deque = NewDeque(e.isMax, func(i, j interface{}) int {
return types.CompareUint64(i.(uint64), j.(uint64))
return cmp.Compare(i.(uint64), j.(uint64))
})
return p, memDelta + DefMaxMinDequeSize
}
Expand Down Expand Up @@ -589,7 +590,7 @@ type maxMin4Float32Sliding struct {
func (e *maxMin4Float32Sliding) AllocPartialResult() (pr PartialResult, memDelta int64) {
p, memDelta := e.maxMin4Float32.AllocPartialResult()
(*partialResult4MaxMinFloat32)(p).deque = NewDeque(e.isMax, func(i, j interface{}) int {
return types.CompareFloat64(float64(i.(float32)), float64(j.(float32)))
return cmp.Compare(float64(i.(float32)), float64(j.(float32)))
})
return p, memDelta + DefMaxMinDequeSize
}
Expand Down Expand Up @@ -726,7 +727,7 @@ type maxMin4Float64Sliding struct {
func (e *maxMin4Float64Sliding) AllocPartialResult() (pr PartialResult, memDelta int64) {
p, memDelta := e.maxMin4Float64.AllocPartialResult()
(*partialResult4MaxMinFloat64)(p).deque = NewDeque(e.isMax, func(i, j interface{}) int {
return types.CompareFloat64(i.(float64), j.(float64))
return cmp.Compare(i.(float64), j.(float64))
})
return p, memDelta + DefMaxMinDequeSize
}
Expand Down
5 changes: 3 additions & 2 deletions executor/aggfuncs/func_max_min_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package aggfuncs_test

import (
"cmp"
"fmt"
"testing"
"time"
Expand Down Expand Up @@ -336,7 +337,7 @@ func TestMaxSlidingWindow(t *testing.T) {

func TestDequeReset(t *testing.T) {
deque := aggfuncs.NewDeque(true, func(i, j interface{}) int {
return types.CompareInt64(i.(int64), j.(int64))
return cmp.Compare(i.(int64), j.(int64))
})
deque.PushBack(0, 12)
deque.Reset()
Expand All @@ -346,7 +347,7 @@ func TestDequeReset(t *testing.T) {

func TestDequePushPop(t *testing.T) {
deque := aggfuncs.NewDeque(true, func(i, j interface{}) int {
return types.CompareInt64(i.(int64), j.(int64))
return cmp.Compare(i.(int64), j.(int64))
})
times := 15
// pushes element from back of deque
Expand Down
21 changes: 11 additions & 10 deletions expression/builtin_compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package expression

import (
"cmp"
"math"
"strings"

Expand Down Expand Up @@ -2627,22 +2628,22 @@ func (b *builtinNullEQIntSig) evalInt(row chunk.Row) (val int64, isNull bool, er
res = 1
case isNull0 != isNull1:
return res, false, nil
case isUnsigned0 && isUnsigned1 && types.CompareUint64(uint64(arg0), uint64(arg1)) == 0:
case isUnsigned0 && isUnsigned1 && cmp.Compare(uint64(arg0), uint64(arg1)) == 0:
res = 1
case !isUnsigned0 && !isUnsigned1 && types.CompareInt64(arg0, arg1) == 0:
case !isUnsigned0 && !isUnsigned1 && cmp.Compare(arg0, arg1) == 0:
res = 1
case isUnsigned0 && !isUnsigned1:
if arg1 < 0 {
return res, false, nil
}
if types.CompareInt64(arg0, arg1) == 0 {
if cmp.Compare(arg0, arg1) == 0 {
res = 1
}
case !isUnsigned0 && isUnsigned1:
if arg0 < 0 {
return res, false, nil
}
if types.CompareInt64(arg0, arg1) == 0 {
if cmp.Compare(arg0, arg1) == 0 {
res = 1
}
}
Expand Down Expand Up @@ -2674,7 +2675,7 @@ func (b *builtinNullEQRealSig) evalInt(row chunk.Row) (val int64, isNull bool, e
res = 1
case isNull0 != isNull1:
return res, false, nil
case types.CompareFloat64(arg0, arg1) == 0:
case cmp.Compare(arg0, arg1) == 0:
res = 1
}
return res, false, nil
Expand Down Expand Up @@ -2948,21 +2949,21 @@ func CompareInt(sctx sessionctx.Context, lhsArg, rhsArg Expression, lhsRow, rhsR
var res int
switch {
case isUnsigned0 && isUnsigned1:
res = types.CompareUint64(uint64(arg0), uint64(arg1))
res = cmp.Compare(uint64(arg0), uint64(arg1))
case isUnsigned0 && !isUnsigned1:
if arg1 < 0 || uint64(arg0) > math.MaxInt64 {
res = 1
} else {
res = types.CompareInt64(arg0, arg1)
res = cmp.Compare(arg0, arg1)
}
case !isUnsigned0 && isUnsigned1:
if arg0 < 0 || uint64(arg1) > math.MaxInt64 {
res = -1
} else {
res = types.CompareInt64(arg0, arg1)
res = cmp.Compare(arg0, arg1)
}
case !isUnsigned0 && !isUnsigned1:
res = types.CompareInt64(arg0, arg1)
res = cmp.Compare(arg0, arg1)
}
return int64(res), false, nil
}
Expand Down Expand Up @@ -3006,7 +3007,7 @@ func CompareReal(sctx sessionctx.Context, lhsArg, rhsArg Expression, lhsRow, rhs
if isNull0 || isNull1 {
return compareNull(isNull0, isNull1), true, nil
}
return int64(types.CompareFloat64(arg0, arg1)), false, nil
return int64(cmp.Compare(arg0, arg1)), false, nil
}

// CompareDecimal compares two decimals.
Expand Down
30 changes: 16 additions & 14 deletions expression/builtin_compare_vec_generated.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 4 additions & 2 deletions expression/builtin_other_vec_generated.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 6 additions & 4 deletions expression/generator/compare_vec.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ package expression
const newLine = "\n"

const builtinCompareImports = `import (
"cmp"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/chunk"
)
Expand Down Expand Up @@ -90,11 +92,11 @@ func (b *builtin{{ .compare.CompareName }}{{ .type.TypeName }}Sig) vecEvalInt(in
{{- if eq .type.ETName "Json" }}
val := types.CompareBinaryJSON(buf0.GetJSON(i), buf1.GetJSON(i))
{{- else if eq .type.ETName "Real" }}
val := types.CompareFloat64(arg0[i], arg1[i])
val := cmp.Compare(arg0[i], arg1[i])
{{- else if eq .type.ETName "String" }}
val := types.CompareString(buf0.GetString(i), buf1.GetString(i), b.collation)
{{- else if eq .type.ETName "Duration" }}
val := types.CompareDuration(arg0[i], arg1[i])
val := cmp.Compare(arg0[i], arg1[i])
{{- else if eq .type.ETName "Datetime" }}
val := arg0[i].Compare(arg1[i])
{{- else if eq .type.ETName "Decimal" }}
Expand Down Expand Up @@ -147,11 +149,11 @@ func (b *builtin{{ .compare.CompareName }}{{ .type.TypeName }}Sig) vecEvalInt(in
{{- if eq .type.ETName "Json" }}
case types.CompareBinaryJSON(buf0.GetJSON(i), buf1.GetJSON(i)) == 0:
{{- else if eq .type.ETName "Real" }}
case types.CompareFloat64(arg0[i], arg1[i]) == 0:
case cmp.Compare(arg0[i], arg1[i]) == 0:
{{- else if eq .type.ETName "String" }}
case types.CompareString(buf0.GetString(i), buf1.GetString(i), b.collation) == 0:
{{- else if eq .type.ETName "Duration" }}
case types.CompareDuration(arg0[i], arg1[i]) == 0:
case cmp.Compare(arg0[i], arg1[i]) == 0:
{{- else if eq .type.ETName "Datetime" }}
case arg0[i].Compare(arg1[i]) == 0:
{{- else if eq .type.ETName "Decimal" }}
Expand Down
6 changes: 5 additions & 1 deletion expression/generator/other_vec.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ package expression
const newLine = "\n"

const builtinOtherImports = `import (
"cmp"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/chunk"
Expand Down Expand Up @@ -106,11 +108,13 @@ var builtinInTmpl = template.Must(template.New("builtinInTmpl").Parse(`
{{- else if eq .Input.TypeName "Time" -}}
compareResult = arg0.Compare(arg1)
{{- else if eq .Input.TypeName "Duration" -}}
compareResult = types.CompareDuration(arg0, arg1)
compareResult = cmp.Compare(arg0, arg1)
{{- else if eq .Input.TypeName "JSON" -}}
compareResult = types.CompareBinaryJSON(arg0, arg1)
{{- else if eq .Input.TypeName "String" -}}
compareResult = types.CompareString(arg0, arg1, b.collation)
{{- else if eq .Input.TypeNameInColumn "Float64" -}}
compareResult = cmp.Compare(arg0, arg1)
{{- else -}}
compareResult = types.Compare{{ .Input.TypeNameInColumn }}(arg0, arg1)
{{- end -}}
Expand Down
3 changes: 2 additions & 1 deletion store/mockstore/unistore/cophandler/topn.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package cophandler

import (
"cmp"
"container/heap"

"github.com/pingcap/errors"
Expand Down Expand Up @@ -104,7 +105,7 @@ func (t *topNHeap) Less(i, j int) bool {
var ret int
var err error
if expression.FieldTypeFromPB(by.GetExpr().GetFieldType()).GetType() == mysql.TypeEnum {
ret = types.CompareUint64(v1.GetUint64(), v2.GetUint64())
ret = cmp.Compare(v1.GetUint64(), v2.GetUint64())
} else {
ret, err = v1.Compare(t.sc, &v2, collate.GetCollator(collate.ProtoToCollation(by.Expr.FieldType.Collate)))
if err != nil {
Expand Down
Loading

0 comments on commit 5dd4296

Please sign in to comment.