diff --git a/util/rowcodec/BUILD.bazel b/util/rowcodec/BUILD.bazel index cbb8ec5438bbf..41ea8eadfca89 100644 --- a/util/rowcodec/BUILD.bazel +++ b/util/rowcodec/BUILD.bazel @@ -36,6 +36,7 @@ go_test( flaky = True, deps = [ "//kv", + "//parser/model", "//parser/mysql", "//sessionctx/stmtctx", "//tablecodec", diff --git a/util/rowcodec/bench_test.go b/util/rowcodec/bench_test.go index 5c10fa8065e6c..c91916f3d0a40 100644 --- a/util/rowcodec/bench_test.go +++ b/util/rowcodec/bench_test.go @@ -19,6 +19,7 @@ import ( "time" "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/parser/model" "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/tablecodec" @@ -28,6 +29,26 @@ import ( "github.com/pingcap/tidb/util/rowcodec" ) +func BenchmarkChecksum(b *testing.B) { + b.ReportAllocs() + datums := types.MakeDatums(1, "abc", 1.1) + tp1 := types.NewFieldType(mysql.TypeLong) + tp2 := types.NewFieldType(mysql.TypeVarchar) + tp3 := types.NewFieldType(mysql.TypeDouble) + cols := []rowcodec.ColData{ + {&model.ColumnInfo{ID: 1, FieldType: *tp1}, &datums[0]}, + {&model.ColumnInfo{ID: 2, FieldType: *tp2}, &datums[1]}, + {&model.ColumnInfo{ID: 3, FieldType: *tp3}, &datums[2]}, + } + row := rowcodec.RowData{Cols: cols} + for i := 0; i < b.N; i++ { + _, err := row.Checksum() + if err != nil { + b.Fatal(err) + } + } +} + func BenchmarkEncode(b *testing.B) { b.ReportAllocs() oldRow := types.MakeDatums(1, "abc", 1.1) diff --git a/util/rowcodec/common.go b/util/rowcodec/common.go index 16af35d91faa2..33cb2dee7760b 100644 --- a/util/rowcodec/common.go +++ b/util/rowcodec/common.go @@ -16,12 +16,16 @@ package rowcodec import ( "encoding/binary" + "hash/crc32" + "math" "reflect" "unsafe" "github.com/pingcap/errors" "github.com/pingcap/tidb/parser/model" + "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/parser/types" + data "github.com/pingcap/tidb/types" ) // CodecVer is the constant number that represent the new row format. @@ -30,6 +34,7 @@ const CodecVer = 128 var ( errInvalidCodecVer = errors.New("invalid codec version") errInvalidChecksumVer = errors.New("invalid checksum version") + errInvalidChecksumTyp = errors.New("invalid type for checksum") ) // First byte in the encoded value which specifies the encoding type. @@ -240,3 +245,112 @@ func IsNewFormat(rowData []byte) bool { func FieldTypeFromModelColumn(col *model.ColumnInfo) *types.FieldType { return col.FieldType.Clone() } + +// ColData combines the column info as well as its datum. It's used to calculate checksum. +type ColData struct { + *model.ColumnInfo + Datum *data.Datum +} + +// Encode encodes the column datum into bytes for checksum. If buf provided, append encoded data to it. +func (c ColData) Encode(buf []byte) ([]byte, error) { + return appendDatumForChecksum(buf, c.Datum, c.GetType()) +} + +// RowData is a list of ColData for row checksum calculation. +type RowData struct { + // Cols is a list of ColData which is expected to be sorted by id before calling Encode/Checksum. + Cols []ColData + // Data stores the result of Encode. However, it mostly acts as a buffer for encoding columns on checksum + // calculation. + Data []byte +} + +// Len implements sort.Interface for RowData. +func (r RowData) Len() int { return len(r.Cols) } + +// Less implements sort.Interface for RowData. +func (r RowData) Less(i int, j int) bool { return r.Cols[i].ID < r.Cols[j].ID } + +// Swap implements sort.Interface for RowData. +func (r RowData) Swap(i int, j int) { r.Cols[i], r.Cols[j] = r.Cols[j], r.Cols[i] } + +// Encode encodes all columns into bytes (for test purpose). +func (r *RowData) Encode() ([]byte, error) { + var err error + if len(r.Data) > 0 { + r.Data = r.Data[:0] + } + for _, col := range r.Cols { + r.Data, err = col.Encode(r.Data) + if err != nil { + return nil, err + } + } + return r.Data, nil +} + +// Checksum calculates the checksum of columns. Callers should make sure columns are sorted by id. +func (r *RowData) Checksum() (checksum uint32, err error) { + for _, col := range r.Cols { + if len(r.Data) > 0 { + r.Data = r.Data[:0] + } + r.Data, err = col.Encode(r.Data) + if err != nil { + return 0, err + } + checksum = crc32.Update(checksum, crc32.IEEETable, r.Data) + } + return checksum, nil +} + +func appendDatumForChecksum(buf []byte, dat *data.Datum, typ byte) (out []byte, err error) { + defer func() { + if x := recover(); x != nil { + // catch panic when datum and type mismatch + err = errors.Annotate(x.(error), "encode datum for checksum") + } + }() + if dat.IsNull() { + return buf, nil + } + switch typ { + case mysql.TypeTiny, mysql.TypeShort, mysql.TypeLong, mysql.TypeLonglong, mysql.TypeInt24, mysql.TypeYear: + out = binary.LittleEndian.AppendUint64(buf, dat.GetUint64()) + case mysql.TypeVarchar, mysql.TypeVarString, mysql.TypeString, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob, mysql.TypeBlob: + out = appendLengthValue(buf, dat.GetBytes()) + case mysql.TypeTimestamp, mysql.TypeDatetime, mysql.TypeDate, mysql.TypeNewDate: + out = appendLengthValue(buf, []byte(dat.GetMysqlTime().String())) + case mysql.TypeDuration: + out = appendLengthValue(buf, []byte(dat.GetMysqlDuration().String())) + case mysql.TypeFloat, mysql.TypeDouble: + v := dat.GetFloat64() + if math.IsInf(v, 0) || math.IsNaN(v) { + v = 0 // because ticdc has such a transform + } + out = binary.LittleEndian.AppendUint64(buf, math.Float64bits(v)) + case mysql.TypeNewDecimal: + out = appendLengthValue(buf, []byte(dat.GetMysqlDecimal().String())) + case mysql.TypeEnum: + out = binary.LittleEndian.AppendUint64(buf, dat.GetMysqlEnum().Value) + case mysql.TypeSet: + out = binary.LittleEndian.AppendUint64(buf, dat.GetMysqlSet().Value) + case mysql.TypeBit: + // ticdc transforms a bit value as the following way, no need to handle truncate error here. + v, _ := dat.GetBinaryLiteral().ToInt(nil) + out = binary.LittleEndian.AppendUint64(buf, v) + case mysql.TypeJSON: + out = appendLengthValue(buf, []byte(dat.GetMysqlJSON().String())) + case mysql.TypeNull, mysql.TypeGeometry: + out = buf + default: + return buf, errInvalidChecksumTyp + } + return +} + +func appendLengthValue(buf []byte, val []byte) []byte { + buf = binary.LittleEndian.AppendUint32(buf, uint32(len(val))) + return append(buf, val...) +} diff --git a/util/rowcodec/encoder.go b/util/rowcodec/encoder.go index f59df62ba0323..55fad578d612e 100644 --- a/util/rowcodec/encoder.go +++ b/util/rowcodec/encoder.go @@ -34,52 +34,19 @@ type Encoder struct { values []*types.Datum // Enable indicates whether this encoder should be use. Enable bool - // WithChecksum indicates whether to append checksum to the encoded row data. - WithChecksum bool } // Encode encodes a row from a datums slice. -func (encoder *Encoder) Encode(sc *stmtctx.StatementContext, colIDs []int64, values []types.Datum, buf []byte) ([]byte, error) { +func (encoder *Encoder) Encode(sc *stmtctx.StatementContext, colIDs []int64, values []types.Datum, buf []byte, checksums ...uint32) ([]byte, error) { encoder.reset() - err := encoder.encodeDatums(sc, colIDs, values) - if err != nil { - return nil, err - } - return encoder.row.toBytes(buf[:0]), nil -} - -// EncodeWithExtraChecksum likes Encode but also appends an extra checksum if checksum is enabled. -func (encoder *Encoder) EncodeWithExtraChecksum(sc *stmtctx.StatementContext, colIDs []int64, values []types.Datum, checksum uint32, buf []byte) ([]byte, error) { - encoder.reset() - if encoder.hasChecksum() { - encoder.setExtraChecksum(checksum) - } - err := encoder.encodeDatums(sc, colIDs, values) - if err != nil { - return nil, err - } - return encoder.row.toBytes(buf[:0]), nil -} - -// Checksum caclulates the checksum of datumns. -func (encoder *Encoder) Checksum(sc *stmtctx.StatementContext, colIDs []int64, values []types.Datum) (uint32, error) { - encoder.reset() - encoder.flags |= rowFlagChecksum - err := encoder.encodeDatums(sc, colIDs, values) - if err != nil { - return 0, err - } - return encoder.checksum1, nil -} - -func (encoder *Encoder) encodeDatums(sc *stmtctx.StatementContext, colIDs []int64, values []types.Datum) error { encoder.appendColVals(colIDs, values) numCols, notNullIdx := encoder.reformatCols() err := encoder.encodeRowCols(sc, numCols, notNullIdx) if err != nil { - return err + return nil, err } - return nil + encoder.setChecksums(checksums...) + return encoder.row.toBytes(buf[:0]), nil } func (encoder *Encoder) reset() { @@ -94,9 +61,6 @@ func (encoder *Encoder) reset() { encoder.checksumHeader = 0 encoder.checksum1 = 0 encoder.checksum2 = 0 - if encoder.WithChecksum { - encoder.flags |= rowFlagChecksum - } } func (encoder *Encoder) appendColVals(colIDs []int64, values []types.Datum) { @@ -193,9 +157,6 @@ func (encoder *Encoder) encodeRowCols(sc *stmtctx.StatementContext, numCols, not r.offsets[i] = uint16(len(r.data)) } } - if r.hasChecksum() { - return r.calcChecksum() - } return nil } diff --git a/util/rowcodec/row.go b/util/rowcodec/row.go index fe29f3dceb931..31ad03a9f819e 100644 --- a/util/rowcodec/row.go +++ b/util/rowcodec/row.go @@ -16,7 +16,6 @@ package rowcodec import ( "encoding/binary" - "hash/crc32" ) const ( @@ -90,19 +89,15 @@ func (r *row) hasChecksum() bool { return r.flags&rowFlagChecksum > 0 } func (r *row) hasExtraChecksum() bool { return r.checksumHeader&checksumFlagExtra > 0 } -func (r *row) checksumVersion() int { return int(r.checksumHeader & checksumMaskVersion) } - -func (r *row) calcChecksum() error { - if r.checksumVersion() != 0 { - return errInvalidChecksumVer +func (r *row) setChecksums(checksums ...uint32) { + if len(checksums) > 0 { + r.flags |= rowFlagChecksum + r.checksum1 = checksums[0] + if len(checksums) > 1 { + r.checksumHeader |= checksumFlagExtra + r.checksum2 = checksums[1] + } } - r.checksum1 = crc32.ChecksumIEEE(r.data) - return nil -} - -func (r *row) setExtraChecksum(v uint32) { - r.checksumHeader |= checksumFlagExtra - r.checksum2 = v } func (r *row) getData(i int) []byte { @@ -156,7 +151,7 @@ func (r *row) fromBytes(rowData []byte) error { if r.hasChecksum() { r.checksumHeader = rowData[cursor] - if r.checksumVersion() != 0 { + if r.ChecksumVersion() != 0 { return errInvalidChecksumVer } cursor++ @@ -244,6 +239,10 @@ func (r *row) findColID(colID int64) (idx int, isNil, notFound bool) { return } +// ChecksumVersion returns the version of checksum. Note that it's valid only if checksum has been encoded in the row +// value (callers can check it by `GetChecksum`). +func (r *row) ChecksumVersion() int { return int(r.checksumHeader & checksumMaskVersion) } + // GetChecksum returns the checksum of row data (not null columns). func (r *row) GetChecksum() (uint32, bool) { if !r.hasChecksum() { diff --git a/util/rowcodec/rowcodec_test.go b/util/rowcodec/rowcodec_test.go index be2324c0959aa..dacf576ded5ab 100644 --- a/util/rowcodec/rowcodec_test.go +++ b/util/rowcodec/rowcodec_test.go @@ -15,13 +15,16 @@ package rowcodec_test import ( + "encoding/binary" "hash/crc32" "math" + "sort" "strings" "testing" "time" "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/parser/model" "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/tablecodec" @@ -866,161 +869,422 @@ func Test65535Bug(t *testing.T) { require.Equal(t, text65535, rs.GetString()) } -func TestChecksum(t *testing.T) { - sc := new(stmtctx.StatementContext) - enc := rowcodec.Encoder{WithChecksum: true} - - d0 := types.NewDatum(nil) - d1 := types.NewDatum("foo") - t0 := types.NewFieldType(mysql.TypeNull) - t1 := types.NewFieldType(mysql.TypeString) +func TestColumnEncode(t *testing.T) { + encodeUint64 := func(v uint64) []byte { + return binary.LittleEndian.AppendUint64(nil, v) + } + encodeBytes := func(v []byte) []byte { + return append(binary.LittleEndian.AppendUint32(nil, uint32(len(v))), v...) + } + var ( + buf = make([]byte, 0, 128) + intZero = 0 + intPos = 42 + intNeg = -2 + i8Min = math.MinInt8 + i16Min = math.MinInt16 + i32Min = math.MinInt32 + i64Min = math.MinInt64 + i24Min = -1 << 23 + ct = types.FromDate(2023, 1, 2, 3, 4, 5, 678) + dur = types.Duration{Duration: 123456*time.Microsecond + 7*time.Minute + 8*time.Hour, Fsp: 6} + decZero = types.NewDecFromStringForTest("0.000") + decPos = types.NewDecFromStringForTest("3.14") + decNeg = types.NewDecFromStringForTest("-1.2") + decMin = types.NewMaxOrMinDec(true, 12, 6) + decMax = types.NewMaxOrMinDec(false, 12, 6) + json1 = types.CreateBinaryJSON(nil) + json2 = types.CreateBinaryJSON(int64(42)) + json3 = types.CreateBinaryJSON(map[string]interface{}{"foo": "bar", "a": int64(42)}) + ) for _, tt := range []struct { - name string - ids []int64 - datums []types.Datum - types []*types.FieldType - checksum uint32 + name string + typ *types.FieldType + dat types.Datum + raw []byte + ok bool }{ + {"unspecified", types.NewFieldType(mysql.TypeUnspecified), types.NewDatum(1), nil, false}, + {"wrong", types.NewFieldType(42), types.NewDatum(1), nil, false}, + {"mismatch/timestamp", types.NewFieldType(mysql.TypeTimestamp), types.NewDatum(1), nil, false}, + {"mismatch/datetime", types.NewFieldType(mysql.TypeDatetime), types.NewDatum(1), nil, false}, + {"mismatch/date", types.NewFieldType(mysql.TypeDate), types.NewDatum(1), nil, false}, + {"mismatch/newdate", types.NewFieldType(mysql.TypeNewDate), types.NewDatum(1), nil, false}, + {"mismatch/decimal", types.NewFieldType(mysql.TypeNewDecimal), types.NewDatum(1), nil, false}, + + {"null", types.NewFieldType(mysql.TypeNull), types.NewDatum(1), nil, true}, + {"geometry", types.NewFieldType(mysql.TypeGeometry), types.NewDatum(1), nil, true}, + + {"tinyint/zero", types.NewFieldType(mysql.TypeTiny), types.NewDatum(intZero), encodeUint64(uint64(intZero)), true}, + {"tinyint/pos", types.NewFieldType(mysql.TypeTiny), types.NewDatum(intPos), encodeUint64(uint64(intPos)), true}, + {"tinyint/neg", types.NewFieldType(mysql.TypeTiny), types.NewDatum(intNeg), encodeUint64(uint64(intNeg)), true}, + {"tinyint/min/signed", types.NewFieldType(mysql.TypeTiny), types.NewDatum(i8Min), encodeUint64(uint64(i8Min)), true}, + {"tinyint/max/signed", types.NewFieldType(mysql.TypeTiny), types.NewDatum(math.MaxInt8), encodeUint64(math.MaxInt8), true}, + {"tinyint/max/unsigned", types.NewFieldType(mysql.TypeTiny), types.NewDatum(math.MaxUint8), encodeUint64(math.MaxUint8), true}, + + {"smallint/zero", types.NewFieldType(mysql.TypeShort), types.NewDatum(intZero), encodeUint64(uint64(intZero)), true}, + {"smallint/pos", types.NewFieldType(mysql.TypeShort), types.NewDatum(intPos), encodeUint64(uint64(intPos)), true}, + {"smallint/neg", types.NewFieldType(mysql.TypeShort), types.NewDatum(intNeg), encodeUint64(uint64(intNeg)), true}, + {"smallint/min/signed", types.NewFieldType(mysql.TypeShort), types.NewDatum(i16Min), encodeUint64(uint64(i16Min)), true}, + {"smallint/max/signed", types.NewFieldType(mysql.TypeShort), types.NewDatum(math.MaxInt16), encodeUint64(math.MaxInt16), true}, + {"smallint/max/unsigned", types.NewFieldType(mysql.TypeShort), types.NewDatum(math.MaxUint16), encodeUint64(math.MaxUint16), true}, + + {"int/zero", types.NewFieldType(mysql.TypeLong), types.NewDatum(intZero), encodeUint64(uint64(intZero)), true}, + {"int/pos", types.NewFieldType(mysql.TypeLong), types.NewDatum(intPos), encodeUint64(uint64(intPos)), true}, + {"int/neg", types.NewFieldType(mysql.TypeLong), types.NewDatum(intNeg), encodeUint64(uint64(intNeg)), true}, + {"int/min/signed", types.NewFieldType(mysql.TypeLong), types.NewDatum(i32Min), encodeUint64(uint64(i32Min)), true}, + {"int/max/signed", types.NewFieldType(mysql.TypeLong), types.NewDatum(math.MaxInt32), encodeUint64(math.MaxInt32), true}, + {"int/max/unsigned", types.NewFieldType(mysql.TypeLong), types.NewDatum(math.MaxUint32), encodeUint64(math.MaxUint32), true}, + + {"bigint/zero", types.NewFieldType(mysql.TypeLonglong), types.NewDatum(intZero), encodeUint64(uint64(intZero)), true}, + {"bigint/pos", types.NewFieldType(mysql.TypeLonglong), types.NewDatum(intPos), encodeUint64(uint64(intPos)), true}, + {"bigint/neg", types.NewFieldType(mysql.TypeLonglong), types.NewDatum(intNeg), encodeUint64(uint64(intNeg)), true}, + {"bigint/min/signed", types.NewFieldType(mysql.TypeLonglong), types.NewDatum(i64Min), encodeUint64(uint64(i64Min)), true}, + {"bigint/max/signed", types.NewFieldType(mysql.TypeLonglong), types.NewDatum(math.MaxInt64), encodeUint64(math.MaxInt64), true}, + {"bigint/max/unsigned", types.NewFieldType(mysql.TypeLonglong), types.NewDatum(uint64(math.MaxUint64)), encodeUint64(math.MaxUint64), true}, + + {"mediumint/zero", types.NewFieldType(mysql.TypeInt24), types.NewDatum(intZero), encodeUint64(uint64(intZero)), true}, + {"mediumint/pos", types.NewFieldType(mysql.TypeInt24), types.NewDatum(intPos), encodeUint64(uint64(intPos)), true}, + {"mediumint/neg", types.NewFieldType(mysql.TypeInt24), types.NewDatum(intNeg), encodeUint64(uint64(intNeg)), true}, + {"mediumint/min/signed", types.NewFieldType(mysql.TypeInt24), types.NewDatum(i24Min), encodeUint64(uint64(i24Min)), true}, + {"mediumint/max/signed", types.NewFieldType(mysql.TypeInt24), types.NewDatum(1<<23 - 1), encodeUint64(1<<23 - 1), true}, + {"mediumint/max/unsigned", types.NewFieldType(mysql.TypeInt24), types.NewDatum(1<<24 - 1), encodeUint64(1<<24 - 1), true}, + + {"year", types.NewFieldType(mysql.TypeYear), types.NewDatum(2023), encodeUint64(2023), true}, + + {"varchar", types.NewFieldType(mysql.TypeVarchar), types.NewDatum("foo"), encodeBytes([]byte("foo")), true}, + {"varchar/empty", types.NewFieldType(mysql.TypeVarchar), types.NewDatum(""), encodeBytes([]byte{}), true}, + {"varbinary", types.NewFieldType(mysql.TypeVarString), types.NewDatum([]byte("foo")), encodeBytes([]byte("foo")), true}, + {"varbinary/empty", types.NewFieldType(mysql.TypeVarString), types.NewDatum([]byte("")), encodeBytes([]byte{}), true}, + {"char", types.NewFieldType(mysql.TypeString), types.NewDatum("foo"), encodeBytes([]byte("foo")), true}, + {"char/empty", types.NewFieldType(mysql.TypeString), types.NewDatum(""), encodeBytes([]byte{}), true}, + {"binary", types.NewFieldType(mysql.TypeString), types.NewDatum([]byte("foo")), encodeBytes([]byte("foo")), true}, + {"binary/empty", types.NewFieldType(mysql.TypeString), types.NewDatum([]byte("")), encodeBytes([]byte{}), true}, + {"text", types.NewFieldType(mysql.TypeBlob), types.NewDatum("foo"), encodeBytes([]byte("foo")), true}, + {"text/empty", types.NewFieldType(mysql.TypeBlob), types.NewDatum(""), encodeBytes([]byte{}), true}, + {"blob", types.NewFieldType(mysql.TypeBlob), types.NewDatum([]byte("foo")), encodeBytes([]byte("foo")), true}, + {"blob/empty", types.NewFieldType(mysql.TypeBlob), types.NewDatum([]byte("")), encodeBytes([]byte{}), true}, + {"longtext", types.NewFieldType(mysql.TypeLongBlob), types.NewDatum("foo"), encodeBytes([]byte("foo")), true}, + {"longtext/empty", types.NewFieldType(mysql.TypeLongBlob), types.NewDatum(""), encodeBytes([]byte{}), true}, + {"longblob", types.NewFieldType(mysql.TypeLongBlob), types.NewDatum([]byte("foo")), encodeBytes([]byte("foo")), true}, + {"longblob/empty", types.NewFieldType(mysql.TypeLongBlob), types.NewDatum([]byte("")), encodeBytes([]byte{}), true}, + {"mediumtext", types.NewFieldType(mysql.TypeMediumBlob), types.NewDatum("foo"), encodeBytes([]byte("foo")), true}, + {"mediumtext/empty", types.NewFieldType(mysql.TypeMediumBlob), types.NewDatum(""), encodeBytes([]byte{}), true}, + {"mediumblob", types.NewFieldType(mysql.TypeMediumBlob), types.NewDatum([]byte("foo")), encodeBytes([]byte("foo")), true}, + {"mediumblob/empty", types.NewFieldType(mysql.TypeMediumBlob), types.NewDatum([]byte("")), encodeBytes([]byte{}), true}, + {"tinytext", types.NewFieldType(mysql.TypeTinyBlob), types.NewDatum("foo"), encodeBytes([]byte("foo")), true}, + {"tinytext/empty", types.NewFieldType(mysql.TypeTinyBlob), types.NewDatum(""), encodeBytes([]byte{}), true}, + {"tinyblob", types.NewFieldType(mysql.TypeTinyBlob), types.NewDatum([]byte("foo")), encodeBytes([]byte("foo")), true}, + {"tinyblob/empty", types.NewFieldType(mysql.TypeTinyBlob), types.NewDatum([]byte("")), encodeBytes([]byte{}), true}, + + {"float", types.NewFieldType(mysql.TypeFloat), types.NewDatum(float32(3.14)), encodeUint64(math.Float64bits(float64(float32(3.14)))), true}, + {"float/nan", types.NewFieldType(mysql.TypeFloat), types.NewDatum(float32(math.NaN())), encodeUint64(math.Float64bits(0)), true}, + {"float/+inf", types.NewFieldType(mysql.TypeFloat), types.NewDatum(float32(math.Inf(1))), encodeUint64(math.Float64bits(0)), true}, + {"float/-inf", types.NewFieldType(mysql.TypeFloat), types.NewDatum(float32(math.Inf(-1))), encodeUint64(math.Float64bits(0)), true}, + {"double", types.NewFieldType(mysql.TypeDouble), types.NewDatum(float64(3.14)), encodeUint64(math.Float64bits(3.14)), true}, + {"double/nan", types.NewFieldType(mysql.TypeDouble), types.NewDatum(math.NaN()), encodeUint64(math.Float64bits(0)), true}, + {"double/+inf", types.NewFieldType(mysql.TypeDouble), types.NewDatum(math.Inf(1)), encodeUint64(math.Float64bits(0)), true}, + {"double/-inf", types.NewFieldType(mysql.TypeDouble), types.NewDatum(math.Inf(-1)), encodeUint64(math.Float64bits(0)), true}, + + {"enum", types.NewFieldType(mysql.TypeEnum), types.NewDatum(0b010), encodeUint64(0b010), true}, + {"set", types.NewFieldType(mysql.TypeSet), types.NewDatum(0b101), encodeUint64(0b101), true}, + {"bit", types.NewFieldType(mysql.TypeBit), types.NewBinaryLiteralDatum([]byte{0x12, 0x34}), encodeUint64(0x1234), true}, + {"bit/truncate", types.NewFieldType(mysql.TypeBit), types.NewBinaryLiteralDatum([]byte{0x12, 0x34, 0x12, 0x34, 0x12, 0x34, 0x12, 0x34, 0xff}), encodeUint64(math.MaxUint64), true}, + + { + "timestamp", types.NewFieldType(mysql.TypeTimestamp), + types.NewTimeDatum(types.NewTime(ct, mysql.TypeTimestamp, 3)), + encodeBytes([]byte(types.NewTime(ct, mysql.TypeTimestamp, 3).String())), + true, + }, + { + "timestamp/zero", types.NewFieldType(mysql.TypeTimestamp), + types.NewTimeDatum(types.ZeroTimestamp), + encodeBytes([]byte(types.ZeroTimestamp.String())), + true, + }, + { + "timestamp/min", types.NewFieldType(mysql.TypeTimestamp), + types.NewTimeDatum(types.MinTimestamp), + encodeBytes([]byte(types.MinTimestamp.String())), + true, + }, + { + "timestamp/max", types.NewFieldType(mysql.TypeTimestamp), + types.NewTimeDatum(types.MaxTimestamp), + encodeBytes([]byte(types.MaxTimestamp.String())), + true, + }, + { + "datetime", types.NewFieldType(mysql.TypeDatetime), + types.NewTimeDatum(types.NewTime(ct, mysql.TypeDatetime, 3)), + encodeBytes([]byte(types.NewTime(ct, mysql.TypeDatetime, 3).String())), + true, + }, + { + "datetime/zero", types.NewFieldType(mysql.TypeDatetime), + types.NewTimeDatum(types.ZeroDatetime), + encodeBytes([]byte(types.ZeroTimestamp.String())), + true, + }, + { + "datetime/min", types.NewFieldType(mysql.TypeDatetime), + types.NewTimeDatum(types.NewTime(types.MinDatetime, mysql.TypeDatetime, 6)), + encodeBytes([]byte(types.NewTime(types.MinDatetime, mysql.TypeDatetime, 6).String())), + true, + }, { - "Empty", - []int64{}, - []types.Datum{}, - []*types.FieldType{}, - 0, + "datetime/max", types.NewFieldType(mysql.TypeDatetime), + types.NewTimeDatum(types.NewTime(types.MaxDatetime, mysql.TypeDatetime, 6)), + encodeBytes([]byte(types.NewTime(types.MaxDatetime, mysql.TypeDatetime, 6).String())), + true, }, { - "NullOnly", - []int64{1}, - []types.Datum{d0}, - []*types.FieldType{t0}, - 0, + "date", types.NewFieldType(mysql.TypeDate), + types.NewTimeDatum(types.NewTime(ct, mysql.TypeDate, 3)), + encodeBytes([]byte(types.NewTime(ct, mysql.TypeDate, 3).String())), + true, }, { - "SingleDatum", - []int64{1}, - []types.Datum{d1}, - []*types.FieldType{t1}, - crc32.ChecksumIEEE(d1.GetBytes()), + "date/zero", types.NewFieldType(mysql.TypeDate), + types.NewTimeDatum(types.ZeroDate), + encodeBytes([]byte(types.ZeroDate.String())), + true, }, { - "MultiDatums", - []int64{1, 2}, - []types.Datum{d1, d1}, - []*types.FieldType{t1, t1}, - crc32.ChecksumIEEE(append(d1.GetBytes(), d1.GetBytes()...)), + "date/min", + types.NewFieldType(mysql.TypeDate), + types.NewTimeDatum(types.NewTime(types.MinDatetime, mysql.TypeDate, 6)), + encodeBytes([]byte(types.NewTime(types.MinDatetime, mysql.TypeDate, 6).String())), + true, }, { - "DataAndNull1", - []int64{1, 2}, - []types.Datum{d1, d0}, - []*types.FieldType{t1, t0}, - crc32.ChecksumIEEE(d1.GetBytes()), + "date/max", + types.NewFieldType(mysql.TypeDate), + types.NewTimeDatum(types.NewTime(types.MaxDatetime, mysql.TypeDate, 6)), + encodeBytes([]byte(types.NewTime(types.MaxDatetime, mysql.TypeDate, 6).String())), + true, }, { - "DataAndNull2", - []int64{1, 2}, - []types.Datum{d0, d1}, - []*types.FieldType{t0, t1}, - crc32.ChecksumIEEE(d1.GetBytes()), + "newdate", types.NewFieldType(mysql.TypeNewDate), + types.NewTimeDatum(types.NewTime(ct, mysql.TypeNewDate, 3)), + encodeBytes([]byte(types.NewTime(ct, mysql.TypeNewDate, 3).String())), + true, }, + { + "newdate/zero", types.NewFieldType(mysql.TypeNewDate), + types.NewTimeDatum(types.ZeroDate), + encodeBytes([]byte(types.ZeroDate.String())), + true, + }, + { + "newdate/min", + types.NewFieldType(mysql.TypeNewDate), + types.NewTimeDatum(types.NewTime(types.MinDatetime, mysql.TypeNewDate, 6)), + encodeBytes([]byte(types.NewTime(types.MinDatetime, mysql.TypeNewDate, 6).String())), + true, + }, + { + "newdate/max", + types.NewFieldType(mysql.TypeNewDate), + types.NewTimeDatum(types.NewTime(types.MaxDatetime, mysql.TypeNewDate, 6)), + encodeBytes([]byte(types.NewTime(types.MaxDatetime, mysql.TypeNewDate, 6).String())), + true, + }, + + {"time", types.NewFieldType(mysql.TypeDuration), types.NewDurationDatum(dur), encodeBytes([]byte(dur.String())), true}, + {"time/zero", types.NewFieldType(mysql.TypeDuration), types.NewDurationDatum(types.ZeroDuration), encodeBytes([]byte(types.ZeroDuration.String())), true}, + {"time/max", types.NewFieldType(mysql.TypeDuration), types.NewDurationDatum(types.MaxMySQLDuration(3)), encodeBytes([]byte(types.MaxMySQLDuration(3).String())), true}, + + {"decimal/zero", types.NewFieldType(mysql.TypeNewDecimal), types.NewDecimalDatum(decZero), encodeBytes([]byte(decZero.String())), true}, + {"decimal/pos", types.NewFieldType(mysql.TypeNewDecimal), types.NewDecimalDatum(decPos), encodeBytes([]byte(decPos.String())), true}, + {"decimal/neg", types.NewFieldType(mysql.TypeNewDecimal), types.NewDecimalDatum(decNeg), encodeBytes([]byte(decNeg.String())), true}, + {"decimal/min", types.NewFieldType(mysql.TypeNewDecimal), types.NewDecimalDatum(decMin), encodeBytes([]byte(decMin.String())), true}, + {"decimal/max", types.NewFieldType(mysql.TypeNewDecimal), types.NewDecimalDatum(decMax), encodeBytes([]byte(decMax.String())), true}, + + {"json/1", types.NewFieldType(mysql.TypeJSON), types.NewJSONDatum(json1), encodeBytes([]byte(json1.String())), true}, + {"json/2", types.NewFieldType(mysql.TypeJSON), types.NewJSONDatum(json2), encodeBytes([]byte(json2.String())), true}, + {"json/3", types.NewFieldType(mysql.TypeJSON), types.NewJSONDatum(json3), encodeBytes([]byte(json3.String())), true}, } { t.Run(tt.name, func(t *testing.T) { - for _, enable := range []bool{false, true} { - enc.WithChecksum = enable - checksum, err := enc.Checksum(sc, tt.ids, tt.datums) + col := rowcodec.ColData{&model.ColumnInfo{FieldType: *tt.typ}, &tt.dat} + raw, err := col.Encode(buf[:0]) + if tt.ok { require.NoError(t, err) - require.Equal(t, tt.checksum, checksum) + if len(tt.raw) == 0 { + require.Len(t, raw, 0) + } else { + require.Equal(t, tt.raw, raw) + } + } else { + require.Error(t, err) } + }) + } - // encode - raw, err := enc.Encode(sc, tt.ids, tt.datums, nil) + t.Run("nulldatum", func(t *testing.T) { + for _, typ := range []byte{ + mysql.TypeUnspecified, + mysql.TypeTiny, + mysql.TypeShort, + mysql.TypeLong, + mysql.TypeFloat, + mysql.TypeDouble, + mysql.TypeNull, + mysql.TypeTimestamp, + mysql.TypeLonglong, + mysql.TypeInt24, + mysql.TypeDate, + mysql.TypeDuration, + mysql.TypeDatetime, + mysql.TypeYear, + mysql.TypeNewDate, + mysql.TypeVarchar, + mysql.TypeBit, + mysql.TypeJSON, + mysql.TypeNewDecimal, + mysql.TypeEnum, + mysql.TypeSet, + mysql.TypeTinyBlob, + mysql.TypeMediumBlob, + mysql.TypeLongBlob, + mysql.TypeBlob, + mysql.TypeVarString, + mysql.TypeString, + mysql.TypeGeometry, + 42, // wrong type + } { + ft := types.NewFieldType(typ) + dat := types.NewDatum(nil) + col := rowcodec.ColData{&model.ColumnInfo{FieldType: *ft}, &dat} + raw, err := col.Encode(nil) require.NoError(t, err) - checksum, ok := enc.GetChecksum() - require.True(t, ok) - require.Equal(t, tt.checksum, checksum) - - // decode - cols := make([]rowcodec.ColInfo, len(tt.ids)) - for i, id := range tt.ids { - cols[i] = rowcodec.ColInfo{ID: id, Ft: tt.types[i]} + require.Len(t, raw, 0) + } + }) +} + +func TestRowChecksum(t *testing.T) { + typ1 := types.NewFieldType(mysql.TypeNull) + dat1 := types.NewDatum(nil) + col1 := rowcodec.ColData{&model.ColumnInfo{ID: 1, FieldType: *typ1}, &dat1} + typ2 := types.NewFieldType(mysql.TypeLong) + dat2 := types.NewDatum(42) + col2 := rowcodec.ColData{&model.ColumnInfo{ID: 2, FieldType: *typ2}, &dat2} + typ3 := types.NewFieldType(mysql.TypeVarchar) + dat3 := types.NewDatum("foobar") + col3 := rowcodec.ColData{&model.ColumnInfo{ID: 2, FieldType: *typ3}, &dat3} + buf := make([]byte, 0, 64) + for _, tt := range []struct { + name string + cols []rowcodec.ColData + }{ + {"nil", nil}, + {"empty", []rowcodec.ColData{}}, + {"nullonly", []rowcodec.ColData{col1}}, + {"ordered", []rowcodec.ColData{col1, col2, col3}}, + {"unordered", []rowcodec.ColData{col3, col1, col2}}, + } { + t.Run(tt.name, func(t *testing.T) { + row := rowcodec.RowData{tt.cols, buf} + if !sort.IsSorted(row) { + sort.Sort(row) } - dec := rowcodec.NewDatumMapDecoder(cols, sc.TimeZone) - m, err := dec.DecodeToDatumMap(raw, nil) + checksum, err := row.Checksum() require.NoError(t, err) - for i, id := range tt.ids { - require.Equal(t, m[id], tt.datums[i]) - } - checksum, ok = dec.GetChecksum() - require.True(t, ok) - require.Equal(t, tt.checksum, checksum) + raw, err := row.Encode() + require.NoError(t, err) + require.Equal(t, crc32.ChecksumIEEE(raw), checksum) }) } } -func TestExtraChecksum(t *testing.T) { +func TestEncodeDecodeRowWithChecksum(t *testing.T) { sc := new(stmtctx.StatementContext) enc := rowcodec.Encoder{} - d0 := types.NewDatum(nil) - d1 := types.NewDatum("foo") - t0 := types.NewFieldType(mysql.TypeNull) - t1 := types.NewFieldType(mysql.TypeString) - extraChecksum := uint32(42) + for _, tt := range []struct { + name string + checksums []uint32 + }{ + {"NoChecksum", nil}, + {"OneChecksum", []uint32{1}}, + {"TwoChecksum", []uint32{1, 2}}, + {"ThreeChecksum", []uint32{1, 2, 3}}, + } { + t.Run(tt.name, func(t *testing.T) { + raw, err := enc.Encode(sc, nil, nil, nil, tt.checksums...) + require.NoError(t, err) + dec := rowcodec.NewDatumMapDecoder([]rowcodec.ColInfo{}, sc.TimeZone) + _, err = dec.DecodeToDatumMap(raw, nil) + require.NoError(t, err) + v1, ok1 := enc.GetChecksum() + v2, ok2 := enc.GetExtraChecksum() + v3, ok3 := dec.GetChecksum() + v4, ok4 := dec.GetExtraChecksum() + if len(tt.checksums) == 0 { + require.False(t, ok1) + require.False(t, ok2) + require.False(t, ok3) + require.False(t, ok4) + } else if len(tt.checksums) == 1 { + require.True(t, ok1) + require.False(t, ok2) + require.True(t, ok3) + require.False(t, ok4) + require.Equal(t, tt.checksums[0], v1) + require.Equal(t, tt.checksums[0], v3) + require.Zero(t, v2) + require.Zero(t, v4) + } else { + require.True(t, ok1) + require.True(t, ok2) + require.True(t, ok3) + require.True(t, ok4) + require.Equal(t, tt.checksums[0], v1) + require.Equal(t, tt.checksums[1], v2) + require.Equal(t, tt.checksums[0], v3) + require.Equal(t, tt.checksums[1], v4) + } + }) + } - t.Run("ChecksumDisabled", func(t *testing.T) { - raw1, err := enc.EncodeWithExtraChecksum(sc, []int64{1, 2}, []types.Datum{d0, d1}, extraChecksum, nil) - require.NoError(t, err) - _, ok := enc.GetChecksum() - require.False(t, ok) - _, ok = enc.GetExtraChecksum() - require.False(t, ok) + t.Run("ReuseDecoder", func(t *testing.T) { + dec := rowcodec.NewDatumMapDecoder([]rowcodec.ColInfo{}, sc.TimeZone) - raw2, err := enc.Encode(sc, []int64{1, 2}, []types.Datum{d0, d1}, nil) + raw1, err := enc.Encode(sc, nil, nil, nil) require.NoError(t, err) - require.Equal(t, raw1, raw2) - }) - - enc.WithChecksum = true - - t.Run("EncodeDecode", func(t *testing.T) { - // encode - raw, err := enc.EncodeWithExtraChecksum(sc, []int64{1, 2}, []types.Datum{d0, d1}, extraChecksum, nil) + _, err = dec.DecodeToDatumMap(raw1, nil) + require.NoError(t, err) + v1, ok1 := dec.GetChecksum() + v2, ok2 := dec.GetExtraChecksum() + require.False(t, ok1) + require.False(t, ok2) + require.Zero(t, v1) + require.Zero(t, v2) + + raw2, err := enc.Encode(sc, nil, nil, nil, 1, 2) require.NoError(t, err) - h, ok := enc.GetChecksum() - require.True(t, ok) - require.Equal(t, crc32.ChecksumIEEE(d1.GetBytes()), h) - h, ok = enc.GetExtraChecksum() - require.True(t, ok) - require.Equal(t, extraChecksum, h) - - // decode - dec := rowcodec.NewDatumMapDecoder([]rowcodec.ColInfo{{ID: 1, Ft: t0}, {ID: 2, Ft: t1}}, sc.TimeZone) - m, err := dec.DecodeToDatumMap(raw, nil) + _, err = dec.DecodeToDatumMap(raw2, nil) require.NoError(t, err) - require.Equal(t, m[1], d0) - require.Equal(t, m[2], d1) - h, ok = dec.GetChecksum() - require.True(t, ok) - require.Equal(t, crc32.ChecksumIEEE(d1.GetBytes()), h) - h, ok = dec.GetExtraChecksum() - require.True(t, ok) - require.Equal(t, extraChecksum, h) - - // decode data without checksum by same decoder - enc.WithChecksum = false - raw, err = enc.Encode(sc, []int64{1, 2}, []types.Datum{d0, d1}, nil) + v1, ok1 = dec.GetChecksum() + v2, ok2 = dec.GetExtraChecksum() + require.True(t, ok1) + require.True(t, ok2) + require.Equal(t, uint32(1), v1) + require.Equal(t, uint32(2), v2) + + raw3, err := enc.Encode(sc, nil, nil, nil, 1) require.NoError(t, err) - m, err = dec.DecodeToDatumMap(raw, nil) + _, err = dec.DecodeToDatumMap(raw3, nil) require.NoError(t, err) - require.Equal(t, m[1], d0) - require.Equal(t, m[2], d1) - h, ok = dec.GetChecksum() - require.False(t, ok) - require.Equal(t, uint32(0), h) - h, ok = dec.GetExtraChecksum() - require.False(t, ok) - require.Equal(t, uint32(0), h) + v1, ok1 = dec.GetChecksum() + v2, ok2 = dec.GetExtraChecksum() + require.True(t, ok1) + require.False(t, ok2) + require.Equal(t, uint32(1), v1) + require.Zero(t, v2) }) }