Skip to content

Commit

Permalink
types, util: clean up compareDatum (#30815)
Browse files Browse the repository at this point in the history
  • Loading branch information
wjhuang2016 authored Dec 17, 2021
1 parent 8e11e03 commit 321d307
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 152 deletions.
158 changes: 9 additions & 149 deletions types/datum.go
Original file line number Diff line number Diff line change
Expand Up @@ -552,7 +552,6 @@ func (d *Datum) SetValue(val interface{}, tp *types.FieldType) {

// Compare compares datum to another datum.
// Notes: don't rely on datum.collation to get the collator, it's tend to buggy.
// TODO: use this function to replace CompareDatum. After we remove all of usage of CompareDatum, we can rename this function back to CompareDatum.
func (d *Datum) Compare(sc *stmtctx.StatementContext, ad *Datum, comparer collate.Collator) (int, error) {
if d.k == KindMysqlJSON && ad.k != KindMysqlJSON {
cmp, err := ad.Compare(sc, d, comparer)
Expand Down Expand Up @@ -583,74 +582,19 @@ func (d *Datum) Compare(sc *stmtctx.StatementContext, ad *Datum, comparer collat
case KindFloat32, KindFloat64:
return d.compareFloat64(sc, ad.GetFloat64())
case KindString:
return d.compareStringNew(sc, ad.GetString(), comparer)
return d.compareString(sc, ad.GetString(), comparer)
case KindBytes:
return d.compareStringNew(sc, ad.GetString(), comparer)
return d.compareString(sc, ad.GetString(), comparer)
case KindMysqlDecimal:
return d.compareMysqlDecimal(sc, ad.GetMysqlDecimal())
case KindMysqlDuration:
return d.compareMysqlDuration(sc, ad.GetMysqlDuration())
case KindMysqlEnum:
return d.compareMysqlEnumNew(sc, ad.GetMysqlEnum(), comparer)
return d.compareMysqlEnum(sc, ad.GetMysqlEnum(), comparer)
case KindBinaryLiteral, KindMysqlBit:
return d.compareBinaryLiteralNew(sc, ad.GetBinaryLiteral4Cmp(), comparer)
return d.compareBinaryLiteral(sc, ad.GetBinaryLiteral4Cmp(), comparer)
case KindMysqlSet:
return d.compareMysqlSetNew(sc, ad.GetMysqlSet(), comparer)
case KindMysqlJSON:
return d.compareMysqlJSON(sc, ad.GetMysqlJSON())
case KindMysqlTime:
return d.compareMysqlTime(sc, ad.GetMysqlTime())
default:
return 0, nil
}
}

// CompareDatum compares datum to another datum.
// Deprecated: will be replaced with Compare.
// TODO: return error properly.
func (d *Datum) CompareDatum(sc *stmtctx.StatementContext, ad *Datum) (int, error) {
if d.k == KindMysqlJSON && ad.k != KindMysqlJSON {
cmp, err := ad.CompareDatum(sc, d)
return cmp * -1, errors.Trace(err)
}
switch ad.k {
case KindNull:
if d.k == KindNull {
return 0, nil
}
return 1, nil
case KindMinNotNull:
if d.k == KindNull {
return -1, nil
} else if d.k == KindMinNotNull {
return 0, nil
}
return 1, nil
case KindMaxValue:
if d.k == KindMaxValue {
return 0, nil
}
return -1, nil
case KindInt64:
return d.compareInt64(sc, ad.GetInt64())
case KindUint64:
return d.compareUint64(sc, ad.GetUint64())
case KindFloat32, KindFloat64:
return d.compareFloat64(sc, ad.GetFloat64())
case KindString:
return d.compareString(sc, ad.GetString(), d.collation)
case KindBytes:
return d.compareBytes(sc, ad.GetBytes())
case KindMysqlDecimal:
return d.compareMysqlDecimal(sc, ad.GetMysqlDecimal())
case KindMysqlDuration:
return d.compareMysqlDuration(sc, ad.GetMysqlDuration())
case KindMysqlEnum:
return d.compareMysqlEnum(sc, ad.GetMysqlEnum())
case KindBinaryLiteral, KindMysqlBit:
return d.compareBinaryLiteral(sc, ad.GetBinaryLiteral4Cmp())
case KindMysqlSet:
return d.compareMysqlSet(sc, ad.GetMysqlSet())
return d.compareMysqlSet(sc, ad.GetMysqlSet(), comparer)
case KindMysqlJSON:
return d.compareMysqlJSON(sc, ad.GetMysqlJSON())
case KindMysqlTime:
Expand Down Expand Up @@ -731,7 +675,7 @@ func (d *Datum) compareFloat64(sc *stmtctx.StatementContext, f float64) (int, er
}
}

func (d *Datum) compareStringNew(sc *stmtctx.StatementContext, s string, comparer collate.Collator) (int, error) {
func (d *Datum) compareString(sc *stmtctx.StatementContext, s string, comparer collate.Collator) (int, error) {
switch d.k {
case KindNull, KindMinNotNull:
return -1, nil
Expand Down Expand Up @@ -764,44 +708,6 @@ func (d *Datum) compareStringNew(sc *stmtctx.StatementContext, s string, compare
}
}

func (d *Datum) compareString(sc *stmtctx.StatementContext, s string, retCollation string) (int, error) {
switch d.k {
case KindNull, KindMinNotNull:
return -1, nil
case KindMaxValue:
return 1, nil
case KindString, KindBytes:
return CompareString(d.GetString(), s, d.collation), nil
case KindMysqlDecimal:
dec := new(MyDecimal)
err := sc.HandleTruncate(dec.FromString(hack.Slice(s)))
return d.GetMysqlDecimal().Compare(dec), errors.Trace(err)
case KindMysqlTime:
dt, err := ParseDatetime(sc, s)
return d.GetMysqlTime().Compare(dt), errors.Trace(err)
case KindMysqlDuration:
dur, err := ParseDuration(sc, s, MaxFsp)
return d.GetMysqlDuration().Compare(dur), errors.Trace(err)
case KindMysqlSet:
return CompareString(d.GetMysqlSet().String(), s, d.collation), nil
case KindMysqlEnum:
return CompareString(d.GetMysqlEnum().String(), s, d.collation), nil
case KindBinaryLiteral, KindMysqlBit:
return CompareString(d.GetBinaryLiteral4Cmp().ToString(), s, d.collation), nil
default:
fVal, err := StrToFloat(sc, s, false)
if err != nil {
return 0, errors.Trace(err)
}
return d.compareFloat64(sc, fVal)
}
}

func (d *Datum) compareBytes(sc *stmtctx.StatementContext, b []byte) (int, error) {
str := string(hack.String(b))
return d.compareString(sc, str, d.collation)
}

func (d *Datum) compareMysqlDecimal(sc *stmtctx.StatementContext, dec *MyDecimal) (int, error) {
switch d.k {
case KindNull, KindMinNotNull:
Expand Down Expand Up @@ -839,7 +745,7 @@ func (d *Datum) compareMysqlDuration(sc *stmtctx.StatementContext, dur Duration)
}
}

func (d *Datum) compareMysqlEnumNew(sc *stmtctx.StatementContext, enum Enum, comparer collate.Collator) (int, error) {
func (d *Datum) compareMysqlEnum(sc *stmtctx.StatementContext, enum Enum, comparer collate.Collator) (int, error) {
switch d.k {
case KindNull, KindMinNotNull:
return -1, nil
Expand All @@ -852,7 +758,7 @@ func (d *Datum) compareMysqlEnumNew(sc *stmtctx.StatementContext, enum Enum, com
}
}

func (d *Datum) compareBinaryLiteralNew(sc *stmtctx.StatementContext, b BinaryLiteral, comparer collate.Collator) (int, error) {
func (d *Datum) compareBinaryLiteral(sc *stmtctx.StatementContext, b BinaryLiteral, comparer collate.Collator) (int, error) {
switch d.k {
case KindNull, KindMinNotNull:
return -1, nil
Expand All @@ -872,7 +778,7 @@ func (d *Datum) compareBinaryLiteralNew(sc *stmtctx.StatementContext, b BinaryLi
}
}

func (d *Datum) compareMysqlSetNew(sc *stmtctx.StatementContext, set Set, comparer collate.Collator) (int, error) {
func (d *Datum) compareMysqlSet(sc *stmtctx.StatementContext, set Set, comparer collate.Collator) (int, error) {
switch d.k {
case KindNull, KindMinNotNull:
return -1, nil
Expand All @@ -885,52 +791,6 @@ func (d *Datum) compareMysqlSetNew(sc *stmtctx.StatementContext, set Set, compar
}
}

func (d *Datum) compareMysqlEnum(sc *stmtctx.StatementContext, enum Enum) (int, error) {
switch d.k {
case KindNull, KindMinNotNull:
return -1, nil
case KindMaxValue:
return 1, nil
case KindString, KindBytes, KindMysqlEnum, KindMysqlSet:
return CompareString(d.GetString(), enum.String(), d.collation), nil
default:
return d.compareFloat64(sc, enum.ToNumber())
}
}

func (d *Datum) compareBinaryLiteral(sc *stmtctx.StatementContext, b BinaryLiteral) (int, error) {
switch d.k {
case KindNull, KindMinNotNull:
return -1, nil
case KindMaxValue:
return 1, nil
case KindString, KindBytes:
fallthrough // in this case, d is converted to Binary and then compared with b
case KindBinaryLiteral, KindMysqlBit:
return CompareString(d.GetBinaryLiteral4Cmp().ToString(), b.ToString(), d.collation), nil
default:
val, err := b.ToInt(sc)
if err != nil {
return 0, errors.Trace(err)
}
result, err := d.compareFloat64(sc, float64(val))
return result, errors.Trace(err)
}
}

func (d *Datum) compareMysqlSet(sc *stmtctx.StatementContext, set Set) (int, error) {
switch d.k {
case KindNull, KindMinNotNull:
return -1, nil
case KindMaxValue:
return 1, nil
case KindString, KindBytes, KindMysqlEnum, KindMysqlSet:
return CompareString(d.GetString(), set.String(), d.collation), nil
default:
return d.compareFloat64(sc, set.ToNumber())
}
}

func (d *Datum) compareMysqlJSON(sc *stmtctx.StatementContext, target json.BinaryJSON) (int, error) {
origin, err := d.ToMysqlJSON()
if err != nil {
Expand Down
3 changes: 0 additions & 3 deletions util/ranger/points.go
Original file line number Diff line number Diff line change
Expand Up @@ -568,9 +568,6 @@ func (r *builder) buildFromIn(expr *expression.ScalarFunction) ([]*point, bool)
hasNull = true
continue
}
if dt.Kind() == types.KindString || dt.Kind() == types.KindBinaryLiteral {
dt.SetString(dt.GetString(), colCollate)
}
if expr.GetArgs()[0].GetType().Tp == mysql.TypeEnum {
switch dt.Kind() {
case types.KindString, types.KindBytes, types.KindBinaryLiteral:
Expand Down

0 comments on commit 321d307

Please sign in to comment.