diff --git a/pkg/dbutil/common.go b/pkg/dbutil/common.go
index 2195f3471..3d4ae2382 100644
--- a/pkg/dbutil/common.go
+++ b/pkg/dbutil/common.go
@@ -23,6 +23,8 @@ import (
 	"github.com/pingcap/errors"
 	"github.com/pingcap/parser/model"
+	"github.com/pingcap/tidb-tools/pkg/utils"
+	"github.com/pingcap/tidb/types"
 	log "github.com/sirupsen/logrus"
 )
@@ -37,6 +39,9 @@ const (
 var (
 	// ErrVersionNotFound means can't get the database's version
 	ErrVersionNotFound = errors.New("can't get the database's version")
+
+	// ErrNoData means no data found in the table
+	ErrNoData = errors.New("no data found")
 )
 
 // DBConfig is database configuration.
@@ -164,48 +169,101 @@ func GetRowCount(ctx context.Context, db *sql.DB, schemaName string, tableName s
 	return cnt.Int64, nil
 }
 
-// GetRandomValues returns some random value of a column.
-func GetRandomValues(ctx context.Context, db *sql.DB, schemaName, table, column string, num int64, min, max interface{}, limitRange string, collation string) ([]interface{}, error) {
+// GetRandomValues returns some random values of a column and the count of each value, like sampling.
+// Note: limitArgs are the arguments bound to the placeholders in limitRange.
+func GetRandomValues(ctx context.Context, db *sql.DB, schemaName, table, column string, num int, limitRange string, limitArgs []interface{}, collation string) ([]string, []int, error) {
 	/*
 		example:
-		mysql> SELECT `id` FROM (SELECT `id` FROM `test`.`test` WHERE `id` COLLATE "latin1_bin" > 0 AND `id` COLLATE "latin1_bin" < 100 AND true ORDER BY RAND() LIMIT 3)rand_tmp ORDER BY `id` COLLATE "latin1_bin";
-		+----------+
-		| rand_tmp |
-		+----------+
-		| 15       |
-		| 58       |
-		| 67       |
-		+----------+
+		mysql> SELECT `id`, COUNT(*) count FROM (SELECT `id` FROM `test`.`test` WHERE `id` COLLATE "latin1_bin" > 0 AND `id` COLLATE "latin1_bin" < 100 ORDER BY RAND() LIMIT 5) rand_tmp GROUP BY `id` ORDER BY `id` COLLATE "latin1_bin";
+		+------+-------+
+		| id   | count |
+		+------+-------+
+		|    1 |     2 |
+		|    2 |     2 |
+		|    3 |     1 |
+		+------+-------+
+
+		FIXME: TiDB does not currently return random values when using `ORDER BY RAND()`.
 	*/
 
-	if limitRange != "" {
-		limitRange = "true"
+	if limitRange == "" {
+		limitRange = "TRUE"
 	}
 
 	if collation != "" {
 		collation = fmt.Sprintf(" COLLATE \"%s\"", collation)
 	}
 
-	randomValue := make([]interface{}, 0, num)
-	query := fmt.Sprintf("SELECT `%s` FROM (SELECT `%s` FROM `%s`.`%s` WHERE `%s`%s > ? AND `%s`%s < ? AND %s ORDER BY RAND() LIMIT %d)rand_tmp ORDER BY `%s`%s",
-		column, column, schemaName, table, column, collation, column, collation, limitRange, num, column, collation)
-	log.Debugf("get random values sql: %s, min: %v, max: %v", query, min, max)
-	rows, err := db.QueryContext(ctx, query, min, max)
+	randomValue := make([]string, 0, num)
+	valueCount := make([]int, 0, num)
+
+	query := fmt.Sprintf("SELECT %[1]s, COUNT(*) count FROM (SELECT %[1]s FROM %[2]s WHERE %[3]s ORDER BY RAND() LIMIT %[4]d)rand_tmp GROUP BY %[1]s ORDER BY %[1]s%[5]s",
+		escapeName(column), TableName(schemaName, table), limitRange, num, collation)
+	log.Debugf("get random values sql: %s, args: %v", query, limitArgs)
+
+	rows, err := db.QueryContext(ctx, query, limitArgs...)
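+	// each row returned by the sampling query is a (value, count) pair; a count
+	// greater than 1 means the same value was drawn several times from the sample.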
 	if err != nil {
-		return nil, errors.Trace(err)
+		return nil, nil, errors.Trace(err)
 	}
 	defer rows.Close()
 
 	for rows.Next() {
-		var value interface{}
-		err = rows.Scan(&value)
+		var value string
+		var count int
+		err = rows.Scan(&value, &count)
 		if err != nil {
-			return nil, errors.Trace(err)
+			return nil, nil, errors.Trace(err)
 		}
 		randomValue = append(randomValue, value)
+		valueCount = append(valueCount, count)
 	}
 
-	return randomValue, nil
+	return randomValue, valueCount, errors.Trace(rows.Err())
+}
+
+// GetMinMaxValue returns the min and max values of the given column under the limitRange condition.
+func GetMinMaxValue(ctx context.Context, db *sql.DB, schema, table, column string, limitRange string, limitArgs []interface{}, collation string) (string, string, error) {
+	/*
+		example:
+		mysql> SELECT MIN(`id`) as MIN, MAX(`id`) as MAX FROM `test`.`testa` WHERE id > 0 AND id < 10;
+		+------+------+
+		| MIN  | MAX  |
+		+------+------+
+		|    1 |    2 |
+		+------+------+
+	*/
+
+	if limitRange == "" {
+		limitRange = "TRUE"
+	}
+
+	if collation != "" {
+		collation = fmt.Sprintf(" COLLATE \"%s\"", collation)
+	}
+
+	query := fmt.Sprintf("SELECT /*!40001 SQL_NO_CACHE */ MIN(`%s`%s) as MIN, MAX(`%s`%s) as MAX FROM `%s`.`%s` WHERE %s",
+		column, collation, column, collation, schema, table, limitRange)
+	log.Debugf("GetMinMaxValue query: %v, args: %v", query, limitArgs)
+
+	var min, max sql.NullString
+	rows, err := db.QueryContext(ctx, query, limitArgs...)
+	if err != nil {
+		return "", "", errors.Trace(err)
+	}
+	defer rows.Close()
+
+	for rows.Next() {
+		err = rows.Scan(&min, &max)
+		if err != nil {
+			return "", "", errors.Trace(err)
+		}
+	}
+
+	if !min.Valid || !max.Valid {
+		// no data found in this range
+		return "", "", ErrNoData
+	}
+
+	return min.String, max.String, errors.Trace(rows.Err())
 }
 
 // GetTables returns name of all tables in the specified schema
@@ -302,8 +360,8 @@ func GetCRC32Checksum(ctx context.Context, db *sql.DB, schemaName, tableName str
 		columnIsNull = append(columnIsNull, fmt.Sprintf("ISNULL(`%s`)", col.Name.O))
 	}
 
-	query := fmt.Sprintf("SELECT BIT_XOR(CAST(CRC32(CONCAT_WS(',', %s, CONCAT(%s)))AS UNSIGNED)) AS checksum FROM `%s`.`%s` WHERE %s;",
-		strings.Join(columnNames, ", "), strings.Join(columnIsNull, ", "), schemaName, tableName, limitRange)
+	query := fmt.Sprintf("SELECT BIT_XOR(CAST(CRC32(CONCAT_WS(',', %s, CONCAT(%s)))AS UNSIGNED)) AS checksum FROM %s WHERE %s;",
+		strings.Join(columnNames, ", "), strings.Join(columnIsNull, ", "), TableName(schemaName, tableName), limitRange)
 	log.Debugf("checksum sql: %s, args: %v", query, args)
 
 	var checksum sql.NullInt64
@@ -320,6 +378,130 @@ func GetCRC32Checksum(ctx context.Context, db *sql.DB, schemaName, tableName str
 	return checksum.Int64, nil
 }
 
+// Bucket saves the bucket information from TiDB.
+type Bucket struct {
+	Count      int64
+	LowerBound string
+	UpperBound string
+}
+
+// GetBucketsInfo runs SHOW STATS_BUCKETS in TiDB.
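+// It returns the buckets keyed by the Column_name field of the output (which is
+// the index name for index rows); the bucketSpliter in pkg/diff uses them to cut
+// a table into roughly chunk-sized ranges without scanning the data.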
+func GetBucketsInfo(ctx context.Context, db *sql.DB, schema, table string, tableInfo *model.TableInfo) (map[string][]Bucket, error) {
+	/*
+		example in tidb:
+		mysql> SHOW STATS_BUCKETS WHERE db_name= "test" AND table_name="testa";
+		+---------+------------+----------------+-------------+----------+-----------+-------+---------+---------------------+---------------------+
+		| Db_name | Table_name | Partition_name | Column_name | Is_index | Bucket_id | Count | Repeats | Lower_Bound         | Upper_Bound         |
+		+---------+------------+----------------+-------------+----------+-----------+-------+---------+---------------------+---------------------+
+		| test    | testa      |                | PRIMARY     |        1 |         0 |    64 |       1 | 1846693550524203008 | 1846838686059069440 |
+		| test    | testa      |                | PRIMARY     |        1 |         1 |   128 |       1 | 1846840885082324992 | 1847056389361369088 |
+		+---------+------------+----------------+-------------+----------+-----------+-------+---------+---------------------+---------------------+
+	*/
+	buckets := make(map[string][]Bucket)
+	query := "SHOW STATS_BUCKETS WHERE db_name= ? AND table_name= ?;"
+	log.Debugf("GetBucketsInfo query: %s", query)
+
+	rows, err := db.QueryContext(ctx, query, schema, table)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+	defer rows.Close()
+
+	cols, err := rows.Columns()
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	for rows.Next() {
+		var dbName, tableName, partitionName, columnName, lowerBound, upperBound sql.NullString
+		var isIndex, bucketID, count, repeats sql.NullInt64
+
+		// newer versions add a partition_name column
+		switch len(cols) {
+		case 9:
+			err = rows.Scan(&dbName, &tableName, &columnName, &isIndex, &bucketID, &count, &repeats, &lowerBound, &upperBound)
+		case 10:
+			err = rows.Scan(&dbName, &tableName, &partitionName, &columnName, &isIndex, &bucketID, &count, &repeats, &lowerBound, &upperBound)
+		default:
+			return nil, errors.New("Unknown struct for buckets info")
+		}
+		if err != nil {
+			return nil, errors.Trace(err)
+		}
+
+		if _, ok := buckets[columnName.String]; !ok {
+			buckets[columnName.String] = make([]Bucket, 0, 100)
+		}
+		buckets[columnName.String] = append(buckets[columnName.String], Bucket{
+			Count:      count.Int64,
+			LowerBound: lowerBound.String,
+			UpperBound: upperBound.String,
+		})
+	}
+
+	// when the primary key is an integer, columnName is the column's name rather than `PRIMARY`, so check and transform here.
+	indices := FindAllIndex(tableInfo)
+	for _, index := range indices {
+		if index.Name.O != "PRIMARY" {
+			continue
+		}
+		_, ok := buckets[index.Name.O]
+		if !ok && len(index.Columns) == 1 {
+			if _, ok := buckets[index.Columns[0].Name.O]; !ok {
+				return nil, errors.NotFoundf("primary key on %s in buckets info", index.Columns[0].Name.O)
+			}
+			buckets[index.Name.O] = buckets[index.Columns[0].Name.O]
+			delete(buckets, index.Columns[0].Name.O)
+		}
+	}
+
+	return buckets, errors.Trace(rows.Err())
+}
+
+// AnalyzeValuesFromBuckets parses a bucket's upperBound or lowerBound string into one value per column.
+// upperBound and lowerBound look like '(123, abc)' for multiple fields, or '123' for a single field.
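+// For example, with columns (a, b) and the bound '(5, abc)' it returns ["5", "abc"];
+// packed time values are decoded via DecodeTimeInBucket.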
+func AnalyzeValuesFromBuckets(valueString string, cols []*model.ColumnInfo) ([]string, error) {
+	// FIXME: some values may contain '(', ')' or ', '
+	vStr := strings.Trim(valueString, "()")
+	values := strings.Split(vStr, ", ")
+	if len(values) != len(cols) {
+		return nil, errors.Errorf("analyze value %s failed", valueString)
+	}
+
+	for i, col := range cols {
+		if IsTimeTypeAndNeedDecode(col.Tp) {
+			value, err := DecodeTimeInBucket(values[i])
+			if err != nil {
+				return nil, errors.Trace(err)
+			}
+
+			values[i] = value
+		}
+	}
+
+	return values, nil
+}
+
+// DecodeTimeInBucket decodes Time from a packed uint64 value.
+func DecodeTimeInBucket(packedStr string) (string, error) {
+	packed, err := strconv.ParseUint(packedStr, 10, 64)
+	if err != nil {
+		return "", err
+	}
+
+	if packed == 0 {
+		return "", nil
+	}
+
+	t := new(types.Time)
+	err = t.FromPackedUint(packed)
+	if err != nil {
+		return "", err
+	}
+
+	return t.String(), nil
+}
+
 // GetTidbLatestTSO returns tidb's current TSO.
 func GetTidbLatestTSO(ctx context.Context, db *sql.DB) (int64, error) {
 	/*
@@ -421,3 +603,15 @@ func TableName(schema, table string) string {
 func escapeName(name string) string {
 	return strings.Replace(name, "`", "``", -1)
 }
+
+// ReplacePlaceholder replaces each '?' in str with the corresponding quoted arg; it is only used for building log output.
+// Note: make sure the number of '?' placeholders equals len(args).
+func ReplacePlaceholder(str string, args []string) string {
+	/*
+		for example:
+		str is "a > ? AND a < ?", args is {'1', '2'},
+		this function will return "a > '1' AND a < '2'"
+	*/
+	newStr := strings.Replace(str, "?", "'%s'", -1)
+	return fmt.Sprintf(newStr, utils.StringsToInterfaces(args)...)
+}
diff --git a/pkg/dbutil/common_test.go b/pkg/dbutil/common_test.go
new file mode 100644
index 000000000..0552020b8
--- /dev/null
+++ b/pkg/dbutil/common_test.go
@@ -0,0 +1,71 @@
+// Copyright 2018 PingCAP, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package dbutil
+
+import (
+	. "github.com/pingcap/check"
+)
+
+func (*testDBSuite) TestReplacePlaceholder(c *C) {
+	testCases := []struct {
+		originStr string
+		args      []string
+		expectStr string
+	}{
+		{
+			"a > ? AND a < ?",
+			[]string{"1", "2"},
+			"a > '1' AND a < '2'",
+		}, {
+			"a = ? AND b = ?",
+			[]string{"1", "2"},
+			"a = '1' AND b = '2'",
+		},
+	}
+
+	for _, testCase := range testCases {
+		str := ReplacePlaceholder(testCase.originStr, testCase.args)
+		c.Assert(str, Equals, testCase.expectStr)
+	}
+}
+
+func (*testDBSuite) TestTableName(c *C) {
+	testCases := []struct {
+		schema          string
+		table           string
+		expectTableName string
+	}{
+		{
+			"test",
+			"testa",
+			"`test`.`testa`",
+		},
+		{
+			"test-1",
+			"test-a",
+			"`test-1`.`test-a`",
+		},
+		{
+			"test",
+			"t`esta",
+			"`test`.`t``esta`",
+		},
+	}
+
+	for _, testCase := range testCases {
+		tableName := TableName(testCase.schema, testCase.table)
+		c.Assert(tableName, Equals, testCase.expectTableName)
+	}
+}
diff --git a/pkg/dbutil/index.go b/pkg/dbutil/index.go
index 13a884e1c..db5ce02a5 100644
--- a/pkg/dbutil/index.go
+++ b/pkg/dbutil/index.go
@@ -4,6 +4,7 @@ import (
 	"context"
 	"database/sql"
 	"fmt"
+	"sort"
 	"strconv"
 
 	"github.com/pingcap/errors"
@@ -67,12 +68,12 @@ func ShowIndex(ctx context.Context, db *sql.DB, schemaName string, table string)
 	return indices, nil
 }
 
-// FindSuitableIndex returns first column of a suitable index.
+// FindSuitableColumnWithIndex returns the first column of a suitable index.
 // The priority is
 // * primary key
 // * unique key
 // * normal index which has max cardinality
-func FindSuitableIndex(ctx context.Context, db *sql.DB, schemaName string, tableInfo *model.TableInfo) (*model.ColumnInfo, error) {
+func FindSuitableColumnWithIndex(ctx context.Context, db *sql.DB, schemaName string, tableInfo *model.TableInfo) (*model.ColumnInfo, error) {
 	// find primary key
 	for _, index := range tableInfo.Indices {
 		if index.Primary {
@@ -113,6 +114,49 @@ func FindSuitableIndex(ctx context.Context, db *sql.DB, schemaName string, table
 	return c, nil
 }
 
+// FindAllIndex returns all indices, ordered as primary key, unique keys, then normal indices.
+func FindAllIndex(tableInfo *model.TableInfo) []*model.IndexInfo {
+	indices := make([]*model.IndexInfo, len(tableInfo.Indices))
+	copy(indices, tableInfo.Indices)
+	sort.SliceStable(indices, func(i, j int) bool {
+		a := indices[i]
+		b := indices[j]
+		switch {
+		case b.Primary:
+			return false
+		case a.Primary:
+			return true
+		case b.Unique:
+			return false
+		case a.Unique:
+			return true
+		default:
+			return false
+		}
+	})
+	return indices
+}
+
+// FindAllColumnWithIndex returns all indexed columns, ordered as primary key, unique keys, then normal indices.
+func FindAllColumnWithIndex(tableInfo *model.TableInfo) []*model.ColumnInfo {
+	colsMap := make(map[string]interface{})
+	cols := make([]*model.ColumnInfo, 0, 2)
+
+	for _, index := range FindAllIndex(tableInfo) {
+		// indices are guaranteed to be visited in order PK -> UK -> normal index
+		for _, indexCol := range index.Columns {
+			col := FindColumnByName(tableInfo.Columns, indexCol.Name.O)
+			if _, ok := colsMap[col.Name.O]; ok {
+				continue
+			}
+			colsMap[col.Name.O] = struct{}{}
+			cols = append(cols, col)
+		}
+	}
+
+	return cols
+}
+
 // SelectUniqueOrderKey returns some columns for order by condition.
 func SelectUniqueOrderKey(tbInfo *model.TableInfo) ([]string, []*model.ColumnInfo) {
 	keys := make([]string, 0, 2)
diff --git a/pkg/dbutil/index_test.go b/pkg/dbutil/index_test.go
new file mode 100644
index 000000000..4f140371c
--- /dev/null
+++ b/pkg/dbutil/index_test.go
@@ -0,0 +1,97 @@
+// Copyright 2018 PingCAP, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package dbutil
+
+import (
+	. "github.com/pingcap/check"
+)
+
+func (*testDBSuite) TestIndex(c *C) {
+	testCases := []struct {
+		sql     string
+		indices []string
+		cols    []string
+	}{
+		{
+			`
+			CREATE TABLE itest (a int(11) NOT NULL,
+			b double NOT NULL DEFAULT '2',
+			c varchar(10) NOT NULL,
+			d time DEFAULT NULL,
+			PRIMARY KEY (a, b),
+			UNIQUE KEY d(d))
+			`,
+			[]string{"PRIMARY", "d"},
+			[]string{"a", "b", "d"},
+		}, {
+			`
+			CREATE TABLE jtest (
+				a int(11) NOT NULL,
+				b varchar(10) DEFAULT NULL,
+				c varchar(255) DEFAULT NULL,
+				KEY c(c),
+				UNIQUE KEY b(b, c),
+				PRIMARY KEY (a)
+			) ENGINE=InnoDB DEFAULT CHARSET=latin1 COLLATE=latin1_bin
+			`,
+			[]string{"PRIMARY", "b", "c"},
+			[]string{"a", "b", "c"},
+		}, {
+			`
+			CREATE TABLE mtest (
+				a int(24),
+				KEY test (a))
+			`,
+			[]string{"test"},
+			[]string{"a"},
+		},
+		{
+			`
+			CREATE TABLE mtest (
+				a int(24),
+				b int(24),
+				KEY test1 (a),
+				KEY test2 (b))
+			`,
+			[]string{"test1", "test2"},
+			[]string{"a", "b"},
+		},
+		{
+			`
+			CREATE TABLE mtest (
+				a int(24),
+				b int(24),
+				UNIQUE KEY test1 (a),
+				UNIQUE KEY test2 (b))
+			`,
+			[]string{"test1", "test2"},
+			[]string{"a", "b"},
+		},
+	}
+
+	for _, testCase := range testCases {
+		tableInfo, err := GetTableInfoBySQL(testCase.sql)
+		c.Assert(err, IsNil)
+
+		indices := FindAllIndex(tableInfo)
+		for i, index := range indices {
+			c.Assert(index.Name.O, Equals, testCase.indices[i])
+		}
+
+		cols := FindAllColumnWithIndex(tableInfo)
+		for j, col := range cols {
+			c.Assert(col.Name.O, Equals, testCase.cols[j])
+		}
+	}
+}
diff --git a/pkg/dbutil/table_test.go b/pkg/dbutil/table_test.go
index da80fbf22..cfdbdd222 100644
--- a/pkg/dbutil/table_test.go
+++ b/pkg/dbutil/table_test.go
@@ -24,9 +24,9 @@ func TestClient(t *testing.T) {
 	TestingT(t)
 }
 
-var _ = Suite(&testTableSuite{})
+var _ = Suite(&testDBSuite{})
 
-type testTableSuite struct{}
+type testDBSuite struct{}
 
 type testCase struct {
 	sql     string
@@ -36,7 +36,7 @@ type testCase struct {
 	fineCol bool
 }
 
-func (*testTableSuite) TestTable(c *C) {
+func (*testDBSuite) TestTable(c *C) {
 	testCases := []*testCase{
 		{
 			`
@@ -93,7 +93,7 @@ func (*testTableSuite) TestTable(c *C) {
 	}
 }
 
-func (*testTableSuite) TestTableStructEqual(c *C) {
+func (*testDBSuite) TestTableStructEqual(c *C) {
 	createTableSQL1 := "CREATE TABLE `test`.`atest` (`id` int(24), `name` varchar(24), `birthday` datetime, `update_time` time, `money` decimal(20,2), primary key(`id`))"
 	tableInfo1, err := GetTableInfoBySQL(createTableSQL1)
 	c.Assert(err, IsNil)
diff --git a/pkg/dbutil/types.go b/pkg/dbutil/types.go
index bf388cfaf..25bc91872 100644
--- a/pkg/dbutil/types.go
+++ b/pkg/dbutil/types.go
@@ -23,3 +23,11 @@ func IsFloatType(tp byte) bool {
 
 	return false
 }
+
+// IsTimeTypeAndNeedDecode returns true if tp is a time type whose values are stored encoded in TiDB's buckets.
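+// Datetime, timestamp and date values appear in SHOW STATS_BUCKETS bounds as
+// packed uint64 numbers, which DecodeTimeInBucket turns back into readable strings.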
+func IsTimeTypeAndNeedDecode(tp byte) bool {
+	if tp == mysql.TypeDatetime || tp == mysql.TypeTimestamp || tp == mysql.TypeDate {
+		return true
+	}
+
+	return false
+}
diff --git a/pkg/diff/chunk.go b/pkg/diff/chunk.go
index 748a07afe..7c6e55bf6 100644
--- a/pkg/diff/chunk.go
+++ b/pkg/diff/chunk.go
@@ -17,301 +17,545 @@ import (
 	"context"
 	"database/sql"
 	"fmt"
-	"reflect"
+	"strings"
 
 	"github.com/pingcap/errors"
 	"github.com/pingcap/parser/model"
 	"github.com/pingcap/tidb-tools/pkg/dbutil"
+	"github.com/pingcap/tidb-tools/pkg/utils"
 	log "github.com/sirupsen/logrus"
 )
 
+var (
+	equal = "="
+	lt    = "<"
+	lte   = "<="
+	gt    = ">"
+	gte   = ">="
+
+	bucketMode = "bucket"
+	normalMode = "normalMode"
+)
+
+type bound struct {
+	column      string
+	lower       string
+	lowerSymbol string
+	upper       string
+	upperSymbol string
+}
+
 // chunkRange represents chunk range
 type chunkRange struct {
-	begin interface{}
-	end   interface{}
-	// for example:
-	// containBegin and containEnd is true, means [begin, end]
-	// containBegin is true, containEnd is false, means [begin, end)
-	containBegin bool
-	containEnd   bool
-
-	// if noBegin is true, means there is no lower limit
-	// if noEnd is true, means there is no upper limit
-	noBegin bool
-	noEnd   bool
+	bounds []*bound
 }
 
-// CheckJob is the struct of job for check
-type CheckJob struct {
-	Schema string
-	Table  string
-	Column *model.ColumnInfo
-	Where  string
-	Args   []interface{}
-	Chunk  chunkRange
+// newChunkRange returns a chunkRange.
+func newChunkRange() *chunkRange {
+	return &chunkRange{
+		bounds: make([]*bound, 0, 2),
+	}
 }
 
-// newChunkRange return a range struct
-func newChunkRange(begin, end interface{}, containBegin, containEnd, noBegin, noEnd bool) chunkRange {
-	return chunkRange{
-		begin:        begin,
-		end:          end,
-		containEnd:   containEnd,
-		containBegin: containBegin,
-		noBegin:      noBegin,
-		noEnd:        noEnd,
+func (c *chunkRange) toString(mode string, collation string) (string, []string) {
+	if collation != "" {
+		collation = fmt.Sprintf(" COLLATE '%s'", collation)
+	}
+
+	if mode != bucketMode {
+		conditions := make([]string, 0, 2)
+		args := make([]string, 0, 2)
+
+		for _, bound := range c.bounds {
+			if len(bound.lower) != 0 {
+				conditions = append(conditions, fmt.Sprintf("`%s`%s %s ?", bound.column, collation, bound.lowerSymbol))
+				args = append(args, bound.lower)
+			}
+			if len(bound.upper) != 0 {
+				conditions = append(conditions, fmt.Sprintf("`%s`%s %s ?", bound.column, collation, bound.upperSymbol))
+				args = append(args, bound.upper)
+			}
+		}
+
+		if len(conditions) == 0 {
+			return "TRUE", nil
+		}
+
+		return strings.Join(conditions, " AND "), args
+	}
+
+	/* for example:
+	   suppose a TiDB bucket on columns (`a`, `b`) has lower bound (v1, v2) and upper bound (v3, v4),
+	   i.e. `a` ranges over (v1, v3) and `b` ranges over (v2, v4).
+	   this bucket's data range is (a > v1 OR (a = v1 AND b >= v2)) AND (a < v3 OR (a = v3 AND b <= v4)),
+	   not (a >= v1 AND a <= v3 AND b >= v2 AND b <= v4)
+	*/
+
+	lowerCondition := make([]string, 0, 1)
+	upperCondition := make([]string, 0, 1)
+	lowerArgs := make([]string, 0, 1)
+	upperArgs := make([]string, 0, 1)
+
+	preConditionForLower := make([]string, 0, 1)
+	preConditionForUpper := make([]string, 0, 1)
+	preConditionArgsForLower := make([]string, 0, 1)
+	preConditionArgsForUpper := make([]string, 0, 1)
+
+	for _, bound := range c.bounds {
+		if len(bound.lower) != 0 {
+			if len(preConditionForLower) > 0 {
+				lowerCondition = append(lowerCondition, fmt.Sprintf("(%s AND `%s`%s %s ?)", strings.Join(preConditionForLower, " AND "), bound.column, collation, bound.lowerSymbol))
+				lowerArgs = append(append(lowerArgs, preConditionArgsForLower...), bound.lower)
+			} else {
+				lowerCondition = append(lowerCondition, fmt.Sprintf("(`%s`%s %s ?)", bound.column, collation, bound.lowerSymbol))
+				lowerArgs = append(lowerArgs, bound.lower)
+			}
+			preConditionForLower = append(preConditionForLower, fmt.Sprintf("`%s` = ?", bound.column))
+			preConditionArgsForLower = append(preConditionArgsForLower, bound.lower)
+		}
+
+		if len(bound.upper) != 0 {
+			if len(preConditionForUpper) > 0 {
+				upperCondition = append(upperCondition, fmt.Sprintf("(%s AND `%s`%s %s ?)", strings.Join(preConditionForUpper, " AND "), bound.column, collation, bound.upperSymbol))
+				upperArgs = append(append(upperArgs, preConditionArgsForUpper...), bound.upper)
+			} else {
+				upperCondition = append(upperCondition, fmt.Sprintf("(`%s`%s %s ?)", bound.column, collation, bound.upperSymbol))
+				upperArgs = append(upperArgs, bound.upper)
+			}
+			preConditionForUpper = append(preConditionForUpper, fmt.Sprintf("`%s` = ?", bound.column))
+			preConditionArgsForUpper = append(preConditionArgsForUpper, bound.upper)
+		}
+	}
+
+	if len(upperCondition) == 0 && len(lowerCondition) == 0 {
+		return "TRUE", nil
+	}
+
+	if len(upperCondition) == 0 {
+		return strings.Join(lowerCondition, " OR "), lowerArgs
+	}
+
+	if len(lowerCondition) == 0 {
+		return strings.Join(upperCondition, " OR "), upperArgs
 	}
+
+	return fmt.Sprintf("(%s) AND (%s)", strings.Join(lowerCondition, " OR "), strings.Join(upperCondition, " OR ")), append(lowerArgs, upperArgs...)
+}
 
-func getChunksForTable(table *TableInstance, column *model.ColumnInfo, chunkSize, sample int, limits string, collation string) ([]chunkRange, error) {
-	if column == nil {
-		log.Warnf("no suitable index found for %s.%s", table.Schema, table.Table)
-		return nil, nil
+
+func (c *chunkRange) update(column, lower, lowerSymbol, upper, upperSymbol string) {
+	newBound := &bound{
+		column:      column,
+		lower:       lower,
+		lowerSymbol: lowerSymbol,
+		upper:       upper,
+		upperSymbol: upperSymbol,
 	}
 
-	// get the chunk count
+	for i, b := range c.bounds {
+		if b.column == column {
+			// update the bound
+			c.bounds[i] = newBound
+			return
+		}
+	}
+
+	// add a new bound
+	c.bounds = append(c.bounds, newBound)
+}
+
+func (c *chunkRange) copy() *chunkRange {
+	newChunk := &chunkRange{
+		bounds: make([]*bound, len(c.bounds)),
+	}
+	copy(newChunk.bounds, c.bounds)
+
+	return newChunk
+}
+
+func (c *chunkRange) copyAndUpdate(column, lower, lowerSymbol, upper, upperSymbol string) *chunkRange {
+	newChunk := c.copy()
+	newChunk.update(column, lower, lowerSymbol, upper, upperSymbol)
+	return newChunk
+}
+
+type spliter interface {
+	// split splits a table's data into several chunks.
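+	// there are two implementations below: randomSpliter samples split points via
+	// dbutil.GetRandomValues, while bucketSpliter reuses TiDB's statistics histogram.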
+	split(table *TableInstance, columns []*model.ColumnInfo, chunkSize int, limits string, collation string) ([]*chunkRange, error)
+}
+
+type randomSpliter struct {
+	table     *TableInstance
+	chunkSize int
+	limits    string
+	collation string
+}
+
+func (s *randomSpliter) split(table *TableInstance, columns []*model.ColumnInfo, chunkSize int, limits string, collation string) ([]*chunkRange, error) {
+	s.table = table
+	s.chunkSize = chunkSize
+	s.limits = limits
+	s.collation = collation
+
+	// get the chunk count by data count and chunk size
 	cnt, err := dbutil.GetRowCount(context.Background(), table.Conn, table.Schema, table.Table, limits)
 	if err != nil {
 		return nil, errors.Trace(err)
 	}
-
 	if cnt == 0 {
 		log.Infof("no data found in %s.%s", table.Schema, table.Table)
 		return nil, nil
 	}
 
-	chunkCnt := (cnt + int64(chunkSize) - 1) / int64(chunkSize)
-	if sample != 100 {
-		// use sampling check, can check more fragmented by split to more chunk
-		chunkCnt *= 10
+	chunkCnt := (int(cnt) + chunkSize - 1) / chunkSize
+	chunks, err := s.splitRange(table.Conn, newChunkRange(), chunkCnt, table.Schema, table.Table, columns)
+	if err != nil {
+		return nil, errors.Trace(err)
 	}
 
-	field := column.Name.O
+	return chunks, nil
+}
 
-	collationStr := ""
-	if collation != "" {
-		collationStr = fmt.Sprintf(" COLLATE \"%s\"", collation)
-	}
+// splitRange splits a chunk into multiple chunks.
+func (s *randomSpliter) splitRange(db *sql.DB, chunk *chunkRange, count int, schema string, table string, columns []*model.ColumnInfo) ([]*chunkRange, error) {
+	var chunks []*chunkRange
 
-	// fetch min, max
-	query := fmt.Sprintf("SELECT /*!40001 SQL_NO_CACHE */ MIN(`%s`%s) as MIN, MAX(`%s`%s) as MAX FROM `%s`.`%s` WHERE %s",
-		field, collationStr, field, collationStr, table.Schema, table.Table, limits)
+	if count <= 1 {
+		chunks = append(chunks, chunk)
+		return chunks, nil
+	}
 
-	var chunk chunkRange
-	if dbutil.IsNumberType(column.Tp) {
-		var min, max sql.NullInt64
-		err := table.Conn.QueryRow(query).Scan(&min, &max)
-		if err != nil {
-			return nil, errors.Trace(err)
-		}
-		if !min.Valid {
-			// min is NULL, means that no table data.
-			return nil, nil
+	var (
+		splitCol, min, max, symbolMin, symbolMax string
+		err                                      error
+		useNewColumn                             bool
+	)
+
+	chunkLimits, args := chunk.toString(normalMode, s.collation)
+	limitRange := fmt.Sprintf("%s AND %s", chunkLimits, s.limits)
+
+	// if the last column's condition is not '=', keep using that column to split the data;
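+	// otherwise every row in this chunk shares the same value on the bounded
+	// columns, so move on to the next column and fetch its min/max first.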
+	colNum := len(chunk.bounds)
+	if colNum != 0 && chunk.bounds[colNum-1].lowerSymbol != equal {
+		splitCol = chunk.bounds[colNum-1].column
+		min = chunk.bounds[colNum-1].lower
+		max = chunk.bounds[colNum-1].upper
+		symbolMin = chunk.bounds[colNum-1].lowerSymbol
+		symbolMax = chunk.bounds[colNum-1].upperSymbol
+	} else {
+		if len(columns) <= colNum {
+			log.Warnf("chunk %v can't be split", chunk)
+			return append(chunks, chunk), nil
+		}
-		chunk = newChunkRange(min.Int64, max.Int64, true, true, false, false)
-	} else if dbutil.IsFloatType(column.Tp) {
-		var min, max sql.NullFloat64
-		err := table.Conn.QueryRow(query).Scan(&min, &max)
+
+		// choose the next column to split data
+		useNewColumn = true
+		splitCol = columns[colNum].Name.O
+
+		min, max, err = dbutil.GetMinMaxValue(context.Background(), db, schema, table, splitCol, limitRange, utils.StringsToInterfaces(args), s.collation)
 		if err != nil {
+			if errors.Cause(err) == dbutil.ErrNoData {
+				log.Infof("no data found in %s.%s range %s, args %v", schema, table, limitRange, args)
+				return append(chunks, chunk), nil
+			}
 			return nil, errors.Trace(err)
 		}
-		if !min.Valid {
-			// min is NULL, means that no table data.
-			return nil, nil
-		}
-		chunk = newChunkRange(min.Float64, max.Float64, true, true, false, false)
+
+		symbolMin = gte
+		symbolMax = lte
+	}
+
+	splitValues := make([]string, 0, count)
+	valueCounts := make([]int, 0, count)
+
+	// get random values as split points
+	randomValues, randomValueCount, err := dbutil.GetRandomValues(context.Background(), db, schema, table, splitCol, count-1, limitRange, utils.StringsToInterfaces(args), s.collation)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+	log.Infof("split chunk %v, get split values from GetRandomValues: %v", chunk, randomValues)
+
+	/*
+	   for example:
+	   the result of GetRandomValues is:
+	   mysql> SELECT `id`, count(*) count FROM (SELECT `id` FROM `test`.`test` ORDER BY RAND() LIMIT 100) rand_tmp GROUP BY `id` ORDER BY `id`;
+	   +------+-------+
+	   | id   | count |
+	   +------+-------+
+	   |    1 |     1 |
+	   |    2 |     1 |
+	   |    3 |    96 |
+	   |    4 |     1 |
+	   |    5 |     1 |
+	   +------+-------+
+
+	   We can assume that about 96% of this table's data is in the range [id = 3], so we should use another column to split the range `id = 3`,
+	   e.g. into [id = 3 AND cid > 10], [id = 3 AND cid >= 5 AND cid <= 10], [id = 3 AND cid < 5]...
+	*/
+
+	if len(randomValues) > 0 && randomValues[0] == min {
+		splitValues = append(splitValues, randomValues...)
+		valueCounts = append(valueCounts, randomValueCount...)
+		valueCounts[0]++
+	} else {
+		splitValues = append(append(splitValues, min), randomValues...)
+		valueCounts = append(append(valueCounts, 1), randomValueCount...)
+	}
+
+	if len(randomValues) > 0 && randomValues[len(randomValues)-1] == max {
+		valueCounts[len(valueCounts)-1]++
+	} else {
+		splitValues = append(splitValues, max)
+		valueCounts = append(valueCounts, 1)
+	}
+
+	/*
+	   for example:
+	   the splitCol is `a`;
+	   the splitValues is [1, 2, 3, 4, 5];
+	   the valueCounts is [1, 3, 1, 1, 1];
+
+	   getting the value 2 three times in the sample suggests that a large share of the rows have `a` = 2,
+	   so we use another column `b` to split the chunk [`a` = 2] into 3 chunks.
+	   and then the splitCol is `b`;
+	   the splitValues is ['w', 'x', 'y', 'z'];
+	   the valueCounts is [1, 1, 1, 1];
+	   the chunk [`a` = 2] will be split into [`a` = 2 AND `b` < 'x'], [`a` = 2 AND `b` >= 'x' AND `b` < 'y'] and [`a` = 2 AND `b` >= 'y']
+	*/
+	var lower, upper, lowerSymbol, upperSymbol string
+	for i := 0; i < len(splitValues); i++ {
+		if valueCounts[i] > 1 {
+			// the value appears several times in the sample, so split [column = value] further
+			newChunk := chunk.copyAndUpdate(splitCol, splitValues[i], equal, "", "")
+			splitChunks, err := s.splitRange(db, newChunk, valueCounts[i], schema, table, columns)
+			if err != nil {
+				return nil, errors.Trace(err)
+			}
+			chunks = append(chunks, splitChunks...)
+
+			// already have the chunk [column = value], so the next chunk should start with column > value
+			lowerSymbol = gt
+		} else {
+			if i == 0 {
+				if useNewColumn {
+					lower = ""
+					lowerSymbol = ""
+				} else {
+					lower = splitValues[i]
+					lowerSymbol = symbolMin
+				}
+			} else {
+				lower = splitValues[i]
+			}
+
+			if i == len(splitValues)-2 {
+				if useNewColumn && valueCounts[len(valueCounts)-1] == 1 {
+					upper = ""
+					upperSymbol = ""
+				} else {
+					upper = splitValues[i+1]
+					upperSymbol = symbolMax
+				}
+			} else {
+				if i == len(splitValues)-1 {
+					continue
+				}
+
+				upper = splitValues[i+1]
+				upperSymbol = lt
+			}
+		}
+
+		newChunk := chunk.copyAndUpdate(splitCol, lower, lowerSymbol, upper, upperSymbol)
+		chunks = append(chunks, newChunk)
+
+		lowerSymbol = gte
+	}
+
+	log.Debugf("getChunksForTable cut table: cnt=%d min=%s max=%s chunk=%d", count, min, max, len(chunks))
+	return chunks, nil
+}
+
+type bucketSpliter struct {
+	table     *TableInstance
+	chunkSize int
+	limits    string
+	collation string
+	buckets   map[string][]dbutil.Bucket
+}
 
-func splitRange(db *sql.DB, chunk *chunkRange, count int64, Schema string, table string, column *model.ColumnInfo, limitRange string, collation string) ([]chunkRange, error) {
-	var chunks []chunkRange
-	// for example, the min and max value in target table is 2-9, but 1-10 in source table. so we need generate chunk for data < 2 and data > 9
-	addOutRangeChunk := func() {
-		chunks = append(chunks, newChunkRange(struct{}{}, chunk.begin, false, false, true, false))
-		chunks = append(chunks, newChunkRange(chunk.end, struct{}{}, false, false, false, true))
-	}
+func (s *bucketSpliter) split(table *TableInstance, columns []*model.ColumnInfo, chunkSize int, limits string, collation string) ([]*chunkRange, error) {
+	s.table = table
+	s.chunkSize = chunkSize
+	s.limits = limits
+	s.collation = collation
 
-	if count <= 1 {
-		chunks = append(chunks, *chunk)
-		addOutRangeChunk()
-		return chunks, nil
+	buckets, err := dbutil.GetBucketsInfo(context.Background(), s.table.Conn, s.table.Schema, s.table.Table, s.table.info)
+	if err != nil {
+		return nil, errors.Trace(err)
 	}
+	s.buckets = buckets
 
-	if reflect.TypeOf(chunk.begin).Kind() == reflect.Int64 {
-		min, ok1 := chunk.begin.(int64)
-		max, ok2 := chunk.end.(int64)
-		if !ok1 || !ok2 {
-			return nil, errors.Errorf("can't parse chunk's begin: %v, end: %v", chunk.begin, chunk.end)
-		}
-		step := (max - min + count - 1) / count
-		cutoff := min
-		for cutoff <= max {
-			r := newChunkRange(cutoff, cutoff+step, true, false, false, false)
-			chunks = append(chunks, r)
-			cutoff += step
-		}
+	return s.getChunksByBuckets()
+}
+
+func (s *bucketSpliter) getChunksByBuckets() ([]*chunkRange, error) {
+	chunks := make([]*chunkRange, 0, 1000)
 
-		log.Debugf("getChunksForTable cut table: cnt=%d min=%v max=%v step=%v chunk=%d", count, min, max, step, len(chunks))
-	} else if reflect.TypeOf(chunk.begin).Kind() == reflect.Float64 {
-		min, ok1 := chunk.begin.(float64)
-		max, ok2 := chunk.end.(float64)
-		if !ok1 || !ok2 {
-			return nil, errors.Errorf("can't parse chunk's begin: %v, end: %v", chunk.begin, chunk.end)
+	indices := dbutil.FindAllIndex(s.table.info)
+	for _, index := range indices {
+		if index == nil {
+			continue
 		}
-		step := (max - min + float64(count-1)) / float64(count)
-		cutoff := min
-		for cutoff <= max {
-			r := newChunkRange(cutoff, cutoff+step, true, false, false, false)
-			chunks = append(chunks, r)
-			cutoff += step
+		buckets, ok := s.buckets[index.Name.O]
+		if !ok {
+			return nil, errors.NotFoundf("index %s in buckets info", index.Name.O)
 		}
-		log.Debugf("getChunksForTable cut table: cnt=%d min=%v max=%v step=%v chunk=%d",
-			count, min, max, step, len(chunks))
-	} else {
-		max, ok1 := chunk.end.(string)
-		min, ok2 := chunk.begin.(string)
-		if !ok1 || !ok2 {
-			return nil, errors.Errorf("can't parse chunk's begin: %v, end: %v", chunk.begin, chunk.end)
-		}
+		var (
+			lowerValues []string
+			upperValues []string
+			latestCount int64
+			err         error
+		)
 
-		// get random value as split value
-		splitValues, err := dbutil.GetRandomValues(context.Background(), db, Schema, table, column.Name.O, count-1, min, max, limitRange, collation)
-		if err != nil {
-			return nil, errors.Trace(err)
-		}
+		indexColumns := getColumnsFromIndex(index, s.table.info)
 
-		var minTmp, maxTmp string
-		var i int64
-		for i = 0; i < int64(len(splitValues)+1); i++ {
-			if i == 0 {
-				minTmp = min
-			} else {
-				minTmp = fmt.Sprintf("%s", splitValues[i-1])
+		for i, bucket := range buckets {
+			upperValues, err = dbutil.AnalyzeValuesFromBuckets(bucket.UpperBound, indexColumns)
+			if err != nil {
+				return nil, errors.Trace(err)
 			}
-			if i == int64(len(splitValues)) {
-				maxTmp = max
-			} else {
-				maxTmp = fmt.Sprintf("%s", splitValues[i])
+
+			if bucket.Count-latestCount > int64(s.chunkSize) || i == len(buckets)-1 {
+				// create a new chunk
+				chunk := newChunkRange()
+				var lower, upper, lowerSymbol, upperSymbol string
+				for j, col := range index.Columns {
+					if len(lowerValues) != 0 {
+						lower = lowerValues[j]
+						lowerSymbol = gt
+					}
+					if i != len(buckets)-1 {
+						upper = upperValues[j]
+						upperSymbol = lte
+					}
+
+					chunk.update(col.Name.O, lower, lowerSymbol, upper, upperSymbol)
+				}
+
+				chunks = append(chunks, chunk)
+				lowerValues = upperValues
+				latestCount = bucket.Count
+			}
-			r := newChunkRange(minTmp, maxTmp, true, false, false, false)
-			chunks = append(chunks, r)
 		}
-		log.Debugf("getChunksForTable cut table: cnt=%d min=%s max=%s chunk=%d", count, min, max, len(chunks))
+
+		if len(chunks) != 0 {
+			break
+		}
 	}
 
-	chunks[len(chunks)-1].end = chunk.end
-	chunks[0].containBegin = chunk.containBegin
-	chunks[len(chunks)-1].containEnd = chunk.containEnd
-
-	addOutRangeChunk()
-
 	return chunks, nil
 }
 
+func getChunksForTable(table *TableInstance, columns []*model.ColumnInfo, chunkSize int, limits string, collation string, useTiDBStatsInfo bool) ([]*chunkRange, string, error) {
+	if useTiDBStatsInfo {
+		s := bucketSpliter{}
+		chunks, err := s.split(table, columns, chunkSize, limits, collation)
+		if err == nil && len(chunks) > 0 {
+			return chunks, bucketMode, nil
+		}
+
+		log.Warnf("failed to get chunks from tidb bucket information, error: %v, chunks num: %d, will split chunks by random instead", errors.Trace(err), len(chunks))
+	}
+
+	// getting chunks from tidb bucket information failed, fall back to random splitting.
+	s := randomSpliter{}
+	chunks, err := s.split(table, columns, chunkSize, limits, collation)
+	return chunks, normalMode, err
+}
+
-func findSuitableField(db *sql.DB, Schema string, table *model.TableInfo) (*model.ColumnInfo, error) {
-	// first select the index, and number type index first
-	column, err := dbutil.FindSuitableIndex(context.Background(), db, Schema, table)
-	if err != nil {
-		return nil, errors.Trace(err)
+// getSplitFields returns the fields used to split chunks, ordered as: pk, uk, index, other columns.
+func getSplitFields(table *model.TableInfo, splitFields []string) ([]*model.ColumnInfo, error) {
+	cols := make([]*model.ColumnInfo, 0, len(table.Columns))
+	colsMap := make(map[string]interface{})
+
+	splitCols := make([]*model.ColumnInfo, 0, 2)
+	for _, splitField := range splitFields {
+		col := dbutil.FindColumnByName(table.Columns, splitField)
+		if col == nil {
+			return nil, errors.NotFoundf("column %s in table %s", splitField, table.Name)
+		}
+		splitCols = append(splitCols, col)
 	}
-	if column != nil {
-		return column, nil
+
+	indexColumns := dbutil.FindAllColumnWithIndex(table)
+
+	// the user's configured fields have the highest priority
+	for _, col := range append(append(splitCols, indexColumns...), table.Columns...) {
+		if _, ok := colsMap[col.Name.O]; ok {
+			continue
+		}
+
+		colsMap[col.Name.O] = struct{}{}
+		cols = append(cols, col)
 	}
 
-	// use the first column
-	log.Infof("%s.%s don't have index, will use the first column as split field", Schema, table.Name.O)
-	return table.Columns[0], nil
+	return cols, nil
+}
+
+// CheckJob is the struct of job for check
+type CheckJob struct {
+	Schema string
+	Table  string
+	Where  string
+	Args   []string
+}
 
 // GenerateCheckJob generates some CheckJobs.
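+// splitFields may contain several comma-separated column names; indexed columns
+// and the remaining table columns are appended as fallback split fields.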
-func GenerateCheckJob(table *TableInstance, splitField, limits string, chunkSize, sample int, collation string) ([]*CheckJob, error) {
+func GenerateCheckJob(table *TableInstance, splitFields, limits string, chunkSize int, collation string, useTiDBStatsInfo bool) ([]*CheckJob, error) {
 	jobBucket := make([]*CheckJob, 0, 10)
 	var jobCnt int
-	var column *model.ColumnInfo
 	var err error
 
-	if splitField == "" {
-		column, err = findSuitableField(table.Conn, table.Schema, table.info)
-		if err != nil {
-			return nil, errors.Trace(err)
-		}
-	} else {
-		column = dbutil.FindColumnByName(table.info.Columns, splitField)
-		if column == nil {
-			return nil, errors.NotFoundf("column %s in table %s", splitField, table.Table)
-		}
+	var splitFieldArr []string
+	if len(splitFields) != 0 {
+		splitFieldArr = strings.Split(splitFields, ",")
 	}
 
-	chunks, err := getChunksForTable(table, column, chunkSize, sample, limits, collation)
+	for i := range splitFieldArr {
+		splitFieldArr[i] = strings.TrimSpace(splitFieldArr[i])
+	}
+
+	fields, err := getSplitFields(table.info, splitFieldArr)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	chunks, mode, err := getChunksForTable(table, fields, chunkSize, limits, collation, useTiDBStatsInfo)
 	if err != nil {
 		return nil, errors.Trace(err)
 	}
 	if chunks == nil {
 		return nil, nil
 	}
-	log.Debugf("chunks: %+v", chunks)
 
 	jobCnt += len(chunks)
 
-	if collation != "" {
-		collation = fmt.Sprintf(" COLLATE \"%s\"", collation)
-	}
-
-	for {
-		length := len(chunks)
-		if length == 0 {
-			break
-		}
-
-		chunk := chunks[0]
-		chunks = chunks[1:]
-
-		args := make([]interface{}, 0, 2)
-		var condition1, condition2 string
-		if !chunk.noBegin {
-			if chunk.containBegin {
-				condition1 = fmt.Sprintf("`%s`%s >= ?", column.Name, collation)
-			} else {
-				condition1 = fmt.Sprintf("`%s`%s > ?", column.Name, collation)
-			}
-			args = append(args, chunk.begin)
-		} else {
-			condition1 = "TRUE"
-		}
-		if !chunk.noEnd {
-			if chunk.containEnd {
-				condition2 = fmt.Sprintf("`%s`%s <= ?", column.Name, collation)
-			} else {
-				condition2 = fmt.Sprintf("`%s`%s < ?", column.Name, collation)
-			}
-			args = append(args, chunk.end)
-		} else {
-			condition2 = "TRUE"
-		}
-		where := fmt.Sprintf("(%s AND %s AND %s)", condition1, condition2, limits)
+	for _, chunk := range chunks {
+		conditions, args := chunk.toString(mode, collation)
+		where := fmt.Sprintf("(%s AND %s)", conditions, limits)
 
-		log.Debugf("%s.%s create dump job, where: %s, begin: %v, end: %v", table.Schema, table.Table, where, chunk.begin, chunk.end)
+		log.Debugf("%s.%s create check job, where: %s, args: %v", table.Schema, table.Table, where, args)
 		jobBucket = append(jobBucket, &CheckJob{
 			Schema: table.Schema,
 			Table:  table.Table,
-			Column: column,
 			Where:  where,
 			Args:   args,
-			Chunk:  chunk,
 		})
 	}
diff --git a/pkg/diff/chunk_test.go b/pkg/diff/chunk_test.go
index ebbae3e36..a43f41a5d 100644
--- a/pkg/diff/chunk_test.go
+++ b/pkg/diff/chunk_test.go
@@ -14,7 +14,11 @@ package diff
 
 import (
+	"context"
+
"github.com/pingcap/check" + "github.com/pingcap/tidb-tools/pkg/dbutil" + "github.com/pingcap/tidb-tools/pkg/importer" ) var _ = Suite(&testChunkSuite{}) @@ -28,121 +32,173 @@ type chunkTestCase struct { } func (*testChunkSuite) TestSplitRange(c *C) { - testCases := []*chunkTestCase{ - { - &chunkRange{ - begin: int64(1), - end: int64(1000), - containBegin: true, - containEnd: true, - }, - 1, - []*chunkRange{ - { - begin: int64(1), - end: int64(1000), - containBegin: true, - containEnd: true, - }, - { - begin: struct{}{}, - end: int64(1), - containBegin: false, - containEnd: false, - noBegin: true, - }, - { - begin: int64(1000), - end: struct{}{}, - containBegin: false, - containEnd: false, - noEnd: true, - }, + dbConn, err := createConn() + c.Assert(err, IsNil) + + _, err = dbConn.Query("CREATE DATABASE IF NOT EXISTS `test`") + c.Assert(err, IsNil) + + _, err = dbConn.Query("DROP TABLE IF EXISTS `test`.`testa`") + c.Assert(err, IsNil) + + createTableSQL := `CREATE TABLE test.testa ( + a date NOT NULL, + b datetime DEFAULT NULL, + c time DEFAULT NULL, + d varchar(10) COLLATE latin1_bin DEFAULT NULL, + e int(10) DEFAULT NULL, + h year(4) DEFAULT NULL, + PRIMARY KEY (a))` + + dataCount := 10000 + cfg := &importer.Config{ + TableSQL: createTableSQL, + WorkerCount: 5, + JobCount: dataCount, + Batch: 100, + DBCfg: dbutil.GetDBConfigFromEnv("test"), + } + + // generate data for test.testa + importer.DoProcess(cfg) + defer dbConn.Query("DROP TABLE IF EXISTS `test`.`testa`") + + // only work on tidb, so don't assert err here + _, _ = dbConn.Query("ANALYZE TABLE `test`.`testa`") + + tableInfo, err := dbutil.GetTableInfoWithRowID(context.Background(), dbConn, "test", "testa", false) + c.Assert(err, IsNil) + + tableInstance := &TableInstance{ + Conn: dbConn, + Schema: "test", + Table: "testa", + info: tableInfo, + } + + // split chunks + fields, err := getSplitFields(tableInstance.info, nil) + c.Assert(err, IsNil) + chunks, mode, err := getChunksForTable(tableInstance, fields, 100, "TRUE", "", false) + c.Assert(err, IsNil) + + // get data count from every chunk, and the sum of them should equal to the table's count. + chunkDataCount := 0 + for _, chunk := range chunks { + conditions, args := chunk.toString(mode, "") + count, err := dbutil.GetRowCount(context.Background(), tableInstance.Conn, tableInstance.Schema, tableInstance.Table, dbutil.ReplacePlaceholder(conditions, args)) + c.Assert(err, IsNil) + chunkDataCount += int(count) + } + c.Assert(chunkDataCount, Equals, dataCount) +} + +func (*testChunkSuite) TestChunkUpdate(c *C) { + chunk := &chunkRange{ + bounds: []*bound{ + { + column: "a", + lower: "1", + lowerSymbol: ">", + upper: "2", + upperSymbol: "<=", + }, { + column: "b", + lower: "3", + lowerSymbol: ">=", + upper: "4", + upperSymbol: "<", }, + }, + } + + testCases := []struct { + boundArgs []string + expectStr string + expectArgs []string + }{ + { + []string{"a", "5", ">=", "6", "<="}, + "`a` >= ? AND `a` <= ? AND `b` >= ? 
+			"`a` >= ? AND `a` <= ? AND `b` >= ? AND `b` < ?",
+			[]string{"5", "6", "3", "4"},
 		}, {
-			&chunkRange{
-				begin:        int64(1),
-				end:          int64(1000),
-				containBegin: true,
-				containEnd:   false,
-			},
-			2,
-			[]*chunkRange{
-				{
-					begin:        int64(1),
-					end:          int64(501),
-					containBegin: true,
-					containEnd:   false,
-				},
-				{
-					begin:        int64(501),
-					end:          int64(1000),
-					containBegin: true,
-					containEnd:   false,
-				},
-				{
-					begin:        struct{}{},
-					end:          int64(1),
-					containBegin: false,
-					containEnd:   false,
-					noBegin:      true,
-				},
-				{
-					begin:        int64(1000),
-					end:          struct{}{},
-					containBegin: false,
-					containEnd:   false,
-					noEnd:        true,
-				},
-			},
+			[]string{"a", "5", ">=", "6", "<"},
+			"`a` >= ? AND `a` < ? AND `b` >= ? AND `b` < ?",
+			[]string{"5", "6", "3", "4"},
 		}, {
-			&chunkRange{
-				begin:        float64(1.1),
-				end:          float64(1000.1),
-				containBegin: false,
-				containEnd:   false,
-			},
-			2,
-			[]*chunkRange{
-				{
-					begin:        float64(1.1),
-					end:          float64(501.1),
-					containBegin: false,
-					containEnd:   false,
-				},
-				{
-					begin:        float64(501.1),
-					end:          float64(1000.1),
-					containBegin: true,
-					containEnd:   false,
-				},
-				{
-					begin:        struct{}{},
-					end:          float64(1.1),
-					containBegin: false,
-					containEnd:   false,
-					noBegin:      true,
-				},
-				{
-					begin:        float64(1000.1),
-					end:          struct{}{},
-					containBegin: false,
-					containEnd:   false,
-					noEnd:        true,
-				},
+			[]string{"c", "7", ">", "8", "<"},
+			"`a` > ? AND `a` <= ? AND `b` >= ? AND `b` < ? AND `c` > ? AND `c` < ?",
+			[]string{"1", "2", "3", "4", "7", "8"},
+		},
+	}
+
+	for _, cs := range testCases {
+		newChunk := chunk.copyAndUpdate(cs.boundArgs[0], cs.boundArgs[1], cs.boundArgs[2], cs.boundArgs[3], cs.boundArgs[4])
+		conditions, args := newChunk.toString(normalMode, "")
+		c.Assert(conditions, Equals, cs.expectStr)
+		c.Assert(args, DeepEquals, cs.expectArgs)
+	}
+
+	// the original chunk is unchanged
+	conditions, args := chunk.toString(normalMode, "")
+	c.Assert(conditions, Equals, "`a` > ? AND `a` <= ? AND `b` >= ? AND `b` < ?")
+	expectArgs := []string{"1", "2", "3", "4"}
+	for i, arg := range args {
+		c.Assert(arg, Equals, expectArgs[i])
+	}
+}
+
+func (*testChunkSuite) TestChunkToString(c *C) {
+	chunk := &chunkRange{
+		bounds: []*bound{
+			{
+				column:      "a",
+				lower:       "1",
+				lowerSymbol: ">",
+				upper:       "2",
+				upperSymbol: "<",
+			}, {
+				column:      "b",
+				lower:       "3",
+				lowerSymbol: ">",
+				upper:       "4",
+				upperSymbol: "<",
+			}, {
+				column:      "c",
+				lower:       "5",
+				lowerSymbol: ">",
+				upper:       "6",
+				upperSymbol: "<",
 			},
 		},
 	}
 
-	for _, testCase := range testCases {
-		chunks, err := splitRange(nil, testCase.chunk, testCase.chunkCnt, "", "", nil, "", "")
-		c.Assert(err, IsNil)
+	conditions, args := chunk.toString(normalMode, "")
+	c.Assert(conditions, Equals, "`a` > ? AND `a` < ? AND `b` > ? AND `b` < ? AND `c` > ? AND `c` < ?")
+	expectArgs := []string{"1", "2", "3", "4", "5", "6"}
+	for i, arg := range args {
+		c.Assert(arg, Equals, expectArgs[i])
+	}
 
-		for i, chunk := range chunks {
-			c.Assert(chunk.begin, Equals, testCase.expectChunks[i].begin)
-			c.Assert(chunk.end, Equals, testCase.expectChunks[i].end)
-			c.Assert(chunk.containBegin, Equals, testCase.expectChunks[i].containBegin)
-			c.Assert(chunk.containEnd, Equals, testCase.expectChunks[i].containEnd)
-		}
+	conditions, args = chunk.toString(normalMode, "latin1")
+	c.Assert(conditions, Equals, "`a` COLLATE 'latin1' > ? AND `a` COLLATE 'latin1' < ? AND `b` COLLATE 'latin1' > ? AND `b` COLLATE 'latin1' < ? AND `c` COLLATE 'latin1' > ? AND `c` COLLATE 'latin1' < ?")
+	expectArgs = []string{"1", "2", "3", "4", "5", "6"}
+	for i, arg := range args {
+		c.Assert(arg, Equals, expectArgs[i])
 	}
+
+	conditions, args = chunk.toString(bucketMode, "")
+	c.Assert(conditions, Equals, "((`a` > ?) OR (`a` = ? AND `b` > ?) OR (`a` = ? AND `b` = ? AND `c` > ?)) AND ((`a` < ?) OR (`a` = ? AND `b` < ?) OR (`a` = ? AND `b` = ? AND `c` < ?))")
+	expectArgs = []string{"1", "1", "3", "1", "3", "5", "2", "2", "4", "2", "4", "6"}
+	for i, arg := range args {
+		c.Assert(arg, Equals, expectArgs[i])
+	}
+
+	conditions, args = chunk.toString(bucketMode, "latin1")
+	c.Assert(conditions, Equals, "((`a` COLLATE 'latin1' > ?) OR (`a` = ? AND `b` COLLATE 'latin1' > ?) OR (`a` = ? AND `b` = ? AND `c` COLLATE 'latin1' > ?)) AND ((`a` COLLATE 'latin1' < ?) OR (`a` = ? AND `b` COLLATE 'latin1' < ?) OR (`a` = ? AND `b` = ? AND `c` COLLATE 'latin1' < ?))")
+	expectArgs = []string{"1", "1", "3", "1", "3", "5", "2", "2", "4", "2", "4", "6"}
+	for i, arg := range args {
+		c.Assert(arg, Equals, expectArgs[i])
+	}
+}
diff --git a/pkg/diff/diff.go b/pkg/diff/diff.go
index 764dd22cd..d135af8c5 100644
--- a/pkg/diff/diff.go
+++ b/pkg/diff/diff.go
@@ -22,10 +22,11 @@ import (
 	"strings"
 	"sync"
 
-	"github.com/ngaut/log"
 	"github.com/pingcap/errors"
 	"github.com/pingcap/parser/model"
 	"github.com/pingcap/tidb-tools/pkg/dbutil"
+	"github.com/pingcap/tidb-tools/pkg/utils"
+	log "github.com/sirupsen/logrus"
 )
 
 // TableInstance record a table instance
@@ -50,7 +51,7 @@ type TableDiff struct {
 	RemoveColumns []string
 
 	// field should be the primary key, unique key or field with index
-	Field string
+	Fields string
 
 	// select range, for example: "age > 10 AND age < 20"
 	Range string
@@ -82,6 +83,9 @@ type TableDiff struct {
 	// ignore check table's data
 	IgnoreDataCheck bool
 
+	// the table instance to read TiDB statistics from; if nil, chunks will be split by random values.
+	TiDBStatsSource *TableInstance
+
 	sqlCh chan string
 
 	wg sync.WaitGroup
@@ -150,7 +154,7 @@ func (t *TableDiff) adjustConfig() {
 	}
 
 	if len(t.Range) == 0 {
-		t.Range = "true"
+		t.Range = "TRUE"
 	}
 	if t.Sample <= 0 {
 		t.Sample = 100
@@ -167,8 +171,19 @@ func (t *TableDiff) CheckTableData(ctx context.Context) (bool, error) {
 }
 
 // EqualTableData checks data is equal or not.
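+// If TiDBStatsSource is set, check jobs are generated from that instance's
+// statistics buckets; otherwise chunks are split by random sampling.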
-func (t *TableDiff) EqualTableData(ctx context.Context) (bool, error) {
-	allJobs, err := GenerateCheckJob(t.TargetTable, t.Field, t.Range, t.ChunkSize, t.Sample, t.Collation)
+func (t *TableDiff) EqualTableData(ctx context.Context) (equal bool, err error) {
+	var allJobs []*CheckJob
+
+	table := t.TargetTable
+	useTiDB := false
+
+	if t.TiDBStatsSource != nil {
+		table = t.TiDBStatsSource
+		useTiDB = true
+	}
+
+	allJobs, err = GenerateCheckJob(table, t.Fields, t.Range, t.ChunkSize, t.Collation, useTiDB)
+
 	if err != nil {
 		return false, errors.Trace(err)
 	}
@@ -198,7 +213,7 @@ func (t *TableDiff) EqualTableData(ctx context.Context) (bool, error) {
 	}
 
 	num := 0
-	equal := true
+	equal = true
 
 CheckResult:
 	for {
@@ -222,7 +237,7 @@ func (t *TableDiff) getSourceTableChecksum(ctx context.Context, job *CheckJob) (
 	var checksum int64
 
 	for _, sourceTable := range t.SourceTables {
-		checksumTmp, err := dbutil.GetCRC32Checksum(ctx, sourceTable.Conn, sourceTable.Schema, sourceTable.Table, t.TargetTable.info, job.Where, job.Args, SliceToMap(t.IgnoreColumns))
+		checksumTmp, err := dbutil.GetCRC32Checksum(ctx, sourceTable.Conn, sourceTable.Schema, sourceTable.Table, t.TargetTable.info, job.Where, utils.StringsToInterfaces(job.Args), utils.SliceToMap(t.IgnoreColumns))
 		if err != nil {
 			return -1, errors.Trace(err)
 		}
@@ -246,7 +261,7 @@ func (t *TableDiff) checkChunkDataEqual(ctx context.Context, checkJobs []*CheckJ
 		return false, errors.Trace(err)
 	}
 
-	targetChecksum, err := dbutil.GetCRC32Checksum(ctx, t.TargetTable.Conn, t.TargetTable.Schema, t.TargetTable.Table, t.TargetTable.info, job.Where, job.Args, SliceToMap(t.IgnoreColumns))
+	targetChecksum, err := dbutil.GetCRC32Checksum(ctx, t.TargetTable.Conn, t.TargetTable.Schema, t.TargetTable.Table, t.TargetTable.info, job.Where, utils.StringsToInterfaces(job.Args), utils.SliceToMap(t.IgnoreColumns))
 	if err != nil {
 		return false, errors.Trace(err)
 	}
@@ -261,14 +276,14 @@ func (t *TableDiff) checkChunkDataEqual(ctx context.Context, checkJobs []*CheckJ
 	// if checksum is not equal or don't need compare checksum, compare the data
 	sourceRows := make(map[string]*sql.Rows)
 	for i, sourceTable := range t.SourceTables {
-		rows, _, err := getChunkRows(ctx, sourceTable.Conn, sourceTable.Schema, sourceTable.Table, sourceTable.info, job.Where, job.Args, SliceToMap(t.IgnoreColumns), t.Collation)
+		rows, _, err := getChunkRows(ctx, sourceTable.Conn, sourceTable.Schema, sourceTable.Table, sourceTable.info, job.Where, utils.StringsToInterfaces(job.Args), utils.SliceToMap(t.IgnoreColumns), t.Collation)
		if err != nil {
 			return false, errors.Trace(err)
 		}
 		sourceRows[fmt.Sprintf("source-%d", i)] = rows
 	}
 
-	targetRows, orderKeyCols, err := getChunkRows(ctx, t.TargetTable.Conn, t.TargetTable.Schema, t.TargetTable.Table, t.TargetTable.info, job.Where, job.Args, SliceToMap(t.IgnoreColumns), t.Collation)
+	targetRows, orderKeyCols, err := getChunkRows(ctx, t.TargetTable.Conn, t.TargetTable.Schema, t.TargetTable.Table, t.TargetTable.info, job.Where, utils.StringsToInterfaces(job.Args), utils.SliceToMap(t.IgnoreColumns), t.Collation)
 	if err != nil {
 		return false, errors.Trace(err)
 	}
@@ -506,30 +521,36 @@ func compareData(map1, map2 map[string][]byte, null1, null2 map[string]bool, ord
 			return false, 0, errors.Errorf("don't have key %s", col.Name.O)
 		}
 		if needQuotes(col.FieldType) {
-			if string(data1) > string(data2) {
-				cmp = 1
-				break
-			} else if string(data1) < string(data2) {
-				cmp = -1
-				break
-			} else {
+
+			strData1 := string(data1)
+			strData2 := string(data2)
+
+			if len(strData1) == len(strData2) && strData1 == strData2 {
 				continue
 			}
+
+			cmp = -1
+			if strData1 > strData2 {
+				cmp = 1
+			}
+			break
+
 		} else {
 			num1, err1 := strconv.ParseFloat(string(data1), 64)
 			num2, err2 := strconv.ParseFloat(string(data2), 64)
 			if err1 != nil || err2 != nil {
 				return false, 0, errors.Errorf("convert %s, %s to float failed, err1: %v, err2: %v", string(data1), string(data2), err1, err2)
 			}
+
+			if num1 == num2 {
+				continue
+			}
+
+			cmp = -1
 			if num1 > num2 {
 				cmp = 1
-				break
-			} else if num1 < num2 {
-				cmp = -1
-				break
-			} else {
-				continue
 			}
+			break
 		}
 	}
diff --git a/pkg/diff/diff_test.go b/pkg/diff/diff_test.go
index 24693a2e2..13dd137ba 100644
--- a/pkg/diff/diff_test.go
+++ b/pkg/diff/diff_test.go
@@ -17,7 +17,6 @@ import (
 	"context"
 	"database/sql"
 	"fmt"
-	"math"
 	"testing"
 
 	_ "github.com/go-sql-driver/mysql"
@@ -89,7 +88,7 @@ func (t *testDiffSuite) TestDiff(c *C) {
 	dbConn, err := createConn()
 	c.Assert(err, IsNil)
 
-	_, err = dbConn.Query("create database if not exists test")
+	_, err = dbConn.Query("CREATE DATABASE IF NOT EXISTS `test`")
 	c.Assert(err, IsNil)
 
 	testStructEqual(dbConn, c)
@@ -162,8 +161,8 @@ func testDataEqual(dbConn *sql.DB, c *C) {
 	targetTable := "testb"
 
 	defer func() {
-		_, _ = dbConn.Query(fmt.Sprintf("drop table test.%s", sourceTable))
-		_, _ = dbConn.Query(fmt.Sprintf("drop table test.%s", targetTable))
+		_, _ = dbConn.Query(fmt.Sprintf("DROP TABLE `test`.`%s`", sourceTable))
+		_, _ = dbConn.Query(fmt.Sprintf("DROP TABLE `test`.`%s`", targetTable))
 	}()
 
 	err := generateData(dbConn, dbutil.GetDBConfigFromEnv("test"), sourceTable, targetTable)
@@ -222,7 +221,7 @@ func createTableDiff(db *sql.DB) *TableDiff {
 }
 
 func createConn() (*sql.DB, error) {
-	return dbutil.OpenDB(dbutil.GetDBConfigFromEnv("test"))
+	return dbutil.OpenDB(dbutil.GetDBConfigFromEnv(""))
 }
 
 func generateData(dbConn *sql.DB, dbCfg dbutil.DBConfig, sourceTable, targetTable string) error {
@@ -237,7 +236,7 @@ func generateData(dbConn *sql.DB, dbCfg dbutil.DBConfig, sourceTable, targetTabl
 	cfg := &importer.Config{
 		TableSQL:    createTableSQL,
-		WorkerCount: 1,
+		WorkerCount: 5,
 		JobCount:    10000,
 		Batch:       100,
 		DBCfg:       dbCfg,
@@ -247,12 +246,12 @@ func generateData(dbConn *sql.DB, dbCfg dbutil.DBConfig, sourceTable, targetTabl
 	importer.DoProcess(cfg)
 
 	// generate data for target table
-	_, err := dbConn.Query(fmt.Sprintf("create table test.%s like test.%s", targetTable, sourceTable))
+	_, err := dbConn.Query(fmt.Sprintf("CREATE TABLE `test`.`%s` LIKE `test`.`%s`", targetTable, sourceTable))
 	if err != nil {
 		return err
 	}
 
-	_, err = dbConn.Query(fmt.Sprintf("insert into test.%s (a, b, c, d, e, h) select a, b, c, d, e, h from test.%s", targetTable, sourceTable))
+	_, err = dbConn.Query(fmt.Sprintf("INSERT INTO `test`.`%s` (`a`, `b`, `c`, `d`, `e`, `h`) SELECT `a`, `b`, `c`, `d`, `e`, `h` FROM `test`.`%s`", targetTable, sourceTable))
 	if err != nil {
 		return err
 	}
@@ -261,22 +260,22 @@ func generateData(dbConn *sql.DB, dbCfg dbutil.DBConfig, sourceTable, targetTabl
 }
 
 func updateData(dbConn *sql.DB, table string) error {
-	values, err := dbutil.GetRandomValues(context.Background(), dbConn, "test", table, "e", 3, math.MinInt64, math.MaxInt64, "true", "")
+	values, _, err := dbutil.GetRandomValues(context.Background(), dbConn, "test", table, "e", 3, "TRUE", nil, "")
 	if err != nil {
 		return err
 	}
 
-	_, err = dbConn.Exec(fmt.Sprintf("update test.%s set e = e+1 where e = %v", table, values[0]))
+	_, err = dbConn.Exec(fmt.Sprintf("UPDATE `test`.`%s` SET `e` = `e`+1 WHERE `e` = %v", table, values[0]))
 	if err != nil {
 		return err
 	}
 
dbConn.Exec(fmt.Sprintf("delete from test.%s where e = %v", table, values[1])) + _, err = dbConn.Exec(fmt.Sprintf("DELETE FROM `test`.`%s` where `e` = %v", table, values[1])) if err != nil { return err } - _, err = dbConn.Exec(fmt.Sprintf("replace into test.%s values('1992-09-27','2018-09-03 16:26:27','14:45:33','i',2048790075,2008)", table)) + _, err = dbConn.Exec(fmt.Sprintf("REPLACE INTO `test`.`%s` VALUES('1992-09-27','2018-09-03 16:26:27','14:45:33','i',2048790075,2008)", table)) if err != nil { return err } diff --git a/pkg/diff/merge.go b/pkg/diff/merge.go index 7c2a540f1..8fa65ecda 100644 --- a/pkg/diff/merge.go +++ b/pkg/diff/merge.go @@ -16,8 +16,8 @@ package diff import ( "strconv" - "github.com/ngaut/log" "github.com/pingcap/parser/model" + log "github.com/sirupsen/logrus" ) // RowData is the struct of rows selected from mysql/tidb @@ -41,11 +41,10 @@ func (r RowDatas) Less(i, j int) bool { data1 = r.Rows[i].Data[col.Name.O] data2 = r.Rows[j].Data[col.Name.O] if needQuotes(col.FieldType) { - if string(data1) > string(data2) { - return false - } else if string(data1) < string(data2) { - return true - } else { + strData1 := string(data1) + strData2 := string(data2) + + if strData1 == strData2 { // `NULL` is less than "" if r.Rows[i].Null[col.Name.O] { return true @@ -55,20 +54,25 @@ func (r RowDatas) Less(i, j int) bool { } continue } - } else { - num1, err1 := strconv.ParseFloat(string(data1), 64) - num2, err2 := strconv.ParseFloat(string(data2), 64) - if err1 != nil || err2 != nil { - log.Fatalf("convert %s, %s to float failed, err1: %v, err2: %v", string(data1), string(data2), err1, err2) - } - if num1 > num2 { + if strData1 > strData2 { return false - } else if num1 < num2 { - return true - } else { - continue } + return true + } + num1, err1 := strconv.ParseFloat(string(data1), 64) + num2, err2 := strconv.ParseFloat(string(data2), 64) + if err1 != nil || err2 != nil { + log.Fatalf("convert %s, %s to float failed, err1: %v, err2: %v", string(data1), string(data2), err1, err2) } + + if num1 == num2 { + continue + } + if num1 > num2 { + return false + } + return true + } return true diff --git a/pkg/diff/util.go b/pkg/diff/util.go index fb92b7f26..3e98762e4 100644 --- a/pkg/diff/util.go +++ b/pkg/diff/util.go @@ -16,10 +16,11 @@ package diff import ( "math/rand" - "github.com/ngaut/log" "github.com/pingcap/parser/model" "github.com/pingcap/tidb-tools/pkg/dbutil" + "github.com/pingcap/tidb-tools/pkg/utils" "github.com/pingcap/tidb/types" + log "github.com/sirupsen/logrus" ) func equalStrings(str1, str2 []string) bool { @@ -39,7 +40,7 @@ func removeColumns(tableInfo *model.TableInfo, columns []string) *model.TableInf return tableInfo } - removeColMap := SliceToMap(columns) + removeColMap := utils.SliceToMap(columns) for i := 0; i < len(tableInfo.Indices); i++ { index := tableInfo.Indices[i] for j := 0; j < len(index.Columns); j++ { @@ -66,6 +67,19 @@ func removeColumns(tableInfo *model.TableInfo, columns []string) *model.TableInf return tableInfo } +func getColumnsFromIndex(index *model.IndexInfo, tableInfo *model.TableInfo) []*model.ColumnInfo { + indexColumns := make([]*model.ColumnInfo, 0, len(index.Columns)) + for _, indexColumn := range index.Columns { + for _, column := range tableInfo.Columns { + if column.Name.O == indexColumn.Name.O { + indexColumns = append(indexColumns, column) + } + } + } + + return indexColumns +} + func getRandomN(total, num int) []int { if num > total { log.Warnf("the num %d is greater than total %d", num, total) @@ -88,12 +102,3 @@ func 
+func getColumnsFromIndex(index *model.IndexInfo, tableInfo *model.TableInfo) []*model.ColumnInfo {
+	indexColumns := make([]*model.ColumnInfo, 0, len(index.Columns))
+	for _, indexColumn := range index.Columns {
+		for _, column := range tableInfo.Columns {
+			if column.Name.O == indexColumn.Name.O {
+				indexColumns = append(indexColumns, column)
+			}
+		}
+	}
+
+	return indexColumns
+}
+
 func getRandomN(total, num int) []int {
 	if num > total {
 		log.Warnf("the num %d is greater than total %d", num, total)
@@ -88,12 +102,3 @@ func getRandomN(total, num int) []int {
 func needQuotes(ft types.FieldType) bool {
 	return !(dbutil.IsNumberType(ft.Tp) || dbutil.IsFloatType(ft.Tp))
 }
-
-// SliceToMap converts slice to map
-func SliceToMap(slice []string) map[string]interface{} {
-	sMap := make(map[string]interface{})
-	for _, str := range slice {
-		sMap[str] = struct{}{}
-	}
-	return sMap
-}
diff --git a/pkg/utils/util.go b/pkg/utils/util.go
new file mode 100644
index 000000000..caa33dc15
--- /dev/null
+++ b/pkg/utils/util.go
@@ -0,0 +1,20 @@
+package utils
+
+// SliceToMap converts a string slice to a map keyed by the slice's elements,
+// which is convenient for membership checks
+func SliceToMap(slice []string) map[string]interface{} {
+	sMap := make(map[string]interface{})
+	for _, str := range slice {
+		sMap[str] = struct{}{}
+	}
+	return sMap
+}
+
+// StringsToInterfaces converts a string slice to an interface slice
+func StringsToInterfaces(strs []string) []interface{} {
+	is := make([]interface{}, 0, len(strs))
+	for _, str := range strs {
+		is = append(is, str)
+	}
+
+	return is
+}
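+
+// example:
+//   SliceToMap([]string{"a", "b"})     -> map[string]interface{}{"a": struct{}{}, "b": struct{}{}}
+//   StringsToInterfaces([]string{"1"}) -> []interface{}{"1"}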
diff --git a/sync_diff_inspector/config.go b/sync_diff_inspector/config.go
index d2d9dc484..301e5561d 100644
--- a/sync_diff_inspector/config.go
+++ b/sync_diff_inspector/config.go
@@ -22,6 +22,7 @@ import (
 	"github.com/pingcap/errors"
 	"github.com/pingcap/parser/model"
 	"github.com/pingcap/tidb-tools/pkg/dbutil"
+	router "github.com/pingcap/tidb-tools/pkg/table-router"
 	log "github.com/sirupsen/logrus"
 )
 
@@ -72,7 +73,7 @@ type TableConfig struct {
 	// columns be removed, will remove these columns from table info, and will not check these columns' data.
 	RemoveColumns []string `toml:"remove-columns"`
-	// field should be the primary key, unique key or field with index
-	Field string `toml:"index-field"`
+	// fields should be the primary key, unique key or fields with index, multiple fields are separated by ','
+	Fields string `toml:"index-fields"`
 	// select range, for example: "age > 10 AND age < 20"
 	Range string `toml:"range"`
 	// set true if comparing sharding tables with target table, should have more than one source tables.
@@ -185,6 +186,9 @@ type Config struct {
 	// the tables to be checked
 	Tables []*CheckTables `toml:"check-tables" json:"check-tables"`
 
+	// TableRules defines how schema and table names are mapped from the source databases to the target database
+	TableRules []*router.TableRule `toml:"table-rules" json:"table-rules"`
+
 	// the config of table
 	TableCfgs []*TableConfig `toml:"table-config" json:"table-config"`
 
@@ -194,6 +198,9 @@ type Config struct {
 	// ignore check table's data
 	IgnoreDataCheck bool `toml:"ignore-data-check" json:"ignore-data-check"`
 
+	// use this TiDB instance's statistics to split chunks
+	TiDBInstanceID string `toml:"tidb-instance-id" json:"tidb-instance-id"`
+
 	// config file
 	ConfigFile string
 
@@ -275,10 +282,6 @@ func (c *Config) checkConfig() bool {
 		return false
 	}
 
-	if c.TargetDBCfg.InstanceID == "" {
-		c.TargetDBCfg.InstanceID = "target"
-	}
-
 	if len(c.SourceDBCfg) == 0 {
 		log.Error("must have at least one source database")
 		return false
@@ -290,6 +293,14 @@ func (c *Config) checkConfig() bool {
 		}
 	}
 
+	if c.TargetDBCfg.InstanceID == "" {
+		c.TargetDBCfg.InstanceID = "target"
+	}
+	if _, ok := sourceInstanceMap[c.TargetDBCfg.InstanceID]; ok {
+		log.Errorf("target instance id %s is already used by a source instance", c.TargetDBCfg.InstanceID)
+		return false
+	}
+
 	if len(c.Tables) == 0 {
 		log.Error("must specify check tables")
 		return false
diff --git a/sync_diff_inspector/config.toml b/sync_diff_inspector/config.toml
index 1d411371d..4cfb8a1d4 100644
--- a/sync_diff_inspector/config.toml
+++ b/sync_diff_inspector/config.toml
@@ -29,6 +29,17 @@ ignore-struct-check = false
 # the name of the file which saves sqls used to fix different data
 fix-sql-file = "fix.sql"
 
+# use this TiDB instance's statistics to split chunks
+# tidb-instance-id = ""
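+# it must be the instance-id of one of the databases being checked, e.g. "target-1"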
range = "age > 10 AND age < 20" @@ -104,5 +118,6 @@ host = "127.0.0.1" port = 4000 user = "root" password = "" +instance-id = "target-1" # remove comment if use tidb's snapshot data # snapshot = "2016-10-08 16:45:26" \ No newline at end of file diff --git a/sync_diff_inspector/diff.go b/sync_diff_inspector/diff.go index 7002bc6c4..6d75c7538 100644 --- a/sync_diff_inspector/diff.go +++ b/sync_diff_inspector/diff.go @@ -22,6 +22,8 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/tidb-tools/pkg/dbutil" "github.com/pingcap/tidb-tools/pkg/diff" + router "github.com/pingcap/tidb-tools/pkg/table-router" + "github.com/pingcap/tidb-tools/pkg/utils" log "github.com/sirupsen/logrus" ) @@ -39,6 +41,8 @@ type Diff struct { tables map[string]map[string]*TableConfig fixSQLFile *os.File report *Report + tidbInstanceID string + tableRouter *router.Table ctx context.Context } @@ -54,6 +58,7 @@ func NewDiff(ctx context.Context, cfg *Config) (diff *Diff, err error) { useChecksum: cfg.UseChecksum, ignoreDataCheck: cfg.IgnoreDataCheck, ignoreStructCheck: cfg.IgnoreStructCheck, + tidbInstanceID: cfg.TiDBInstanceID, tables: make(map[string]map[string]*TableConfig), report: NewReport(), ctx: ctx, @@ -125,18 +130,58 @@ func (df *Diff) CreateDBConn(cfg *Config) (err error) { } // AdjustTableConfig adjusts the table's config by check-tables and table-config. -func (df *Diff) AdjustTableConfig(cfg *Config) error { +func (df *Diff) AdjustTableConfig(cfg *Config) (err error) { + df.tableRouter, err = router.NewTableRouter(false, cfg.TableRules) + if err != nil { + return errors.Trace(err) + } + allTablesMap, err := df.GetAllTables(cfg) if err != nil { return errors.Trace(err) } + // get all source table's matched target table + // target database name => target table name => all matched source table instance + sourceTablesMap := make(map[string]map[string][]TableInstance) + for instanceID, allSchemas := range allTablesMap { + if instanceID == df.targetDB.InstanceID { + continue + } + + for schema, allTables := range allSchemas { + for table := range allTables { + targetSchema, targetTable, err := df.tableRouter.Route(schema, table) + if err != nil { + return errors.Errorf("get route result for %s.%s.%s failed, error %v", instanceID, schema, table, err) + } + + if _, ok := sourceTablesMap[targetSchema]; !ok { + sourceTablesMap[targetSchema] = make(map[string][]TableInstance) + } + + if _, ok := sourceTablesMap[targetSchema][targetTable]; !ok { + sourceTablesMap[targetSchema][targetTable] = make([]TableInstance, 0, 1) + } + + sourceTablesMap[targetSchema][targetTable] = append(sourceTablesMap[targetSchema][targetTable], TableInstance{ + InstanceID: instanceID, + Schema: schema, + Table: table, + }) + } + } + } + // fill the table information. // will add default source information, don't worry, we will use table config's info replace this later. 
+
+
 # tables need to check.
 [[check-tables]]
 # schema name in target database.
@@ -39,7 +50,7 @@ tables = ["test1", "test2", "test3"]
 
 # support regular expression, must start with '~'.
 # for example, this config will check tables with prefix 'test'.
-# tables = ["~test*"]
+# tables = ["~^test.*"]
 
 # schema and table in table-config must be contained in check-tables.
 
@@ -53,7 +64,10 @@ table = "test3"
 
-# field should be the primary key, unique key or field with index.
-# if comment this, diff will find a suitable field.
-index-field = "id"
+# fields should be the primary key, unique key or fields with index.
+# if this is commented out, diff will find a suitable field itself.
+index-fields = "id"
+
+# multiple fields can be set, separated by ','
+# index-fields = "id,age"
 
 # check data's range.
 range = "age > 10 AND age < 20"
@@ -104,5 +118,6 @@
 host = "127.0.0.1"
 port = 4000
 user = "root"
 password = ""
+instance-id = "target-1"
 # remove comment if use tidb's snapshot data
 # snapshot = "2016-10-08 16:45:26"
\ No newline at end of file
diff --git a/sync_diff_inspector/diff.go b/sync_diff_inspector/diff.go
index 7002bc6c4..6d75c7538 100644
--- a/sync_diff_inspector/diff.go
+++ b/sync_diff_inspector/diff.go
@@ -22,6 +22,8 @@ import (
 	"github.com/pingcap/errors"
 	"github.com/pingcap/tidb-tools/pkg/dbutil"
 	"github.com/pingcap/tidb-tools/pkg/diff"
+	router "github.com/pingcap/tidb-tools/pkg/table-router"
+	"github.com/pingcap/tidb-tools/pkg/utils"
 	log "github.com/sirupsen/logrus"
 )
 
@@ -39,6 +41,8 @@ type Diff struct {
 	tables         map[string]map[string]*TableConfig
 	fixSQLFile     *os.File
 	report         *Report
+	tidbInstanceID string
+	tableRouter    *router.Table
 	ctx            context.Context
 }
 
@@ -54,6 +58,7 @@ func NewDiff(ctx context.Context, cfg *Config) (diff *Diff, err error) {
 		useChecksum:       cfg.UseChecksum,
 		ignoreDataCheck:   cfg.IgnoreDataCheck,
 		ignoreStructCheck: cfg.IgnoreStructCheck,
+		tidbInstanceID:    cfg.TiDBInstanceID,
 		tables:            make(map[string]map[string]*TableConfig),
 		report:            NewReport(),
 		ctx:               ctx,
@@ -125,18 +130,58 @@ func (df *Diff) CreateDBConn(cfg *Config) (err error) {
 }
 
 // AdjustTableConfig adjusts the table's config by check-tables and table-config.
-func (df *Diff) AdjustTableConfig(cfg *Config) error {
+func (df *Diff) AdjustTableConfig(cfg *Config) (err error) {
+	df.tableRouter, err = router.NewTableRouter(false, cfg.TableRules)
+	if err != nil {
+		return errors.Trace(err)
+	}
+
 	allTablesMap, err := df.GetAllTables(cfg)
 	if err != nil {
 		return errors.Trace(err)
 	}
 
+	// get every source table's matched target table
+	// target database name => target table name => all matched source table instances
+	sourceTablesMap := make(map[string]map[string][]TableInstance)
+	for instanceID, allSchemas := range allTablesMap {
+		if instanceID == df.targetDB.InstanceID {
+			continue
+		}
+
+		for schema, allTables := range allSchemas {
+			for table := range allTables {
+				targetSchema, targetTable, err := df.tableRouter.Route(schema, table)
+				if err != nil {
+					return errors.Errorf("get route result for %s.%s.%s failed, error %v", instanceID, schema, table, err)
+				}
+
+				if _, ok := sourceTablesMap[targetSchema]; !ok {
+					sourceTablesMap[targetSchema] = make(map[string][]TableInstance)
+				}
+
+				if _, ok := sourceTablesMap[targetSchema][targetTable]; !ok {
+					sourceTablesMap[targetSchema][targetTable] = make([]TableInstance, 0, 1)
+				}
+
+				sourceTablesMap[targetSchema][targetTable] = append(sourceTablesMap[targetSchema][targetTable], TableInstance{
+					InstanceID: instanceID,
+					Schema:     schema,
+					Table:      table,
+				})
+			}
+		}
+	}
+
 	// fill the table information.
 	// will add default source information, don't worry, we will use table config's info replace this later.
 	for _, schemaTables := range cfg.Tables {
 		df.tables[schemaTables.Schema] = make(map[string]*TableConfig)
 		tables := make([]string, 0, len(schemaTables.Tables))
-		allTables := allTablesMap[schemaName(df.targetDB.InstanceID, schemaTables.Schema)]
+		allTables, ok := allTablesMap[df.targetDB.InstanceID][schemaTables.Schema]
+		if !ok {
+			return errors.NotFoundf("schema %s.%s", df.targetDB.InstanceID, schemaTables.Schema)
+		}
 
 		for _, table := range schemaTables.Tables {
 			matchedTables, err := df.GetMatchTable(df.targetDB, schemaTables.Schema, table, allTables)
@@ -149,7 +194,7 @@ func (df *Diff) AdjustTableConfig(cfg *Config) error {
 		for _, tableName := range tables {
 			tableInfo, err := dbutil.GetTableInfoWithRowID(df.ctx, df.targetDB.Conn, schemaTables.Schema, tableName, cfg.UseRowID)
 			if err != nil {
-				return errors.Errorf("get table %s.%s's inforamtion error %v", schemaTables.Schema, tableName, errors.ErrorStack(err))
+				return errors.Errorf("get table %s.%s's information error %s", schemaTables.Schema, tableName, errors.ErrorStack(err))
 			}
 
 			if _, ok := df.tables[schemaTables.Schema][tableName]; ok {
@@ -157,6 +202,18 @@ func (df *Diff) AdjustTableConfig(cfg *Config) error {
 				continue
 			}
 
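+			// prefer source tables matched by the table router; if there is no
+			// match, fall back to the same schema and table name in the first source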
+			sourceTables := make([]TableInstance, 0, 1)
+			if _, ok := sourceTablesMap[schemaTables.Schema][tableName]; ok {
+				sourceTables = sourceTablesMap[schemaTables.Schema][tableName]
+			} else {
+				// use the same schema and table name
+				sourceTables = append(sourceTables, TableInstance{
+					InstanceID: cfg.SourceDBCfg[0].InstanceID,
+					Schema:     schemaTables.Schema,
+					Table:      tableName,
+				})
+			}
+
 			df.tables[schemaTables.Schema][tableName] = &TableConfig{
 				TableInstance: TableInstance{
 					Schema: schemaTables.Schema,
@@ -165,21 +222,17 @@ func (df *Diff) AdjustTableConfig(cfg *Config) error {
 				IgnoreColumns:   make([]string, 0, 1),
 				TargetTableInfo: tableInfo,
 				Range:           "TRUE",
-				SourceTables: []TableInstance{{
-					InstanceID: cfg.SourceDBCfg[0].InstanceID,
-					Schema:     schemaTables.Schema,
-					Table:      tableName,
-				}},
+				SourceTables: sourceTables,
 			}
 		}
 	}
 
 	for _, table := range cfg.TableCfgs {
 		if _, ok := df.tables[table.Schema]; !ok {
-			return errors.Errorf("schema %s not found in check tables", table.Schema)
+			return errors.NotFoundf("schema %s in check tables", table.Schema)
 		}
 		if _, ok := df.tables[table.Schema][table.Table]; !ok {
-			return errors.Errorf("table %s.%s not found in check tables", table.Schema, table.Table)
+			return errors.NotFoundf("table %s.%s in check tables", table.Schema, table.Table)
 		}
 
 		sourceTables := make([]TableInstance, 0, len(table.SourceTables))
@@ -188,7 +241,7 @@ func (df *Diff) AdjustTableConfig(cfg *Config) error {
 			return errors.Errorf("unkonwn database instance id %s", sourceTable.InstanceID)
 		}
 
-		allTables, ok := allTablesMap[schemaName(df.sourceDBs[sourceTable.InstanceID].InstanceID, sourceTable.Schema)]
+		allTables, ok := allTablesMap[df.sourceDBs[sourceTable.InstanceID].InstanceID][sourceTable.Schema]
 		if !ok {
 			return errors.Errorf("unknown schema %s in database %+v", sourceTable.Schema, df.sourceDBs[sourceTable.InstanceID])
 		}
@@ -215,7 +268,7 @@ func (df *Diff) AdjustTableConfig(cfg *Config) error {
 		}
 		df.tables[table.Schema][table.Table].IgnoreColumns = table.IgnoreColumns
 		df.tables[table.Schema][table.Table].RemoveColumns = table.RemoveColumns
-		df.tables[table.Schema][table.Table].Field = table.Field
+		df.tables[table.Schema][table.Table].Fields = table.Fields
 		df.tables[table.Schema][table.Table].Collation = table.Collation
 	}
 
@@ -223,37 +276,36 @@
 }
 
 // GetAllTables get all tables in all databases.
-func (df *Diff) GetAllTables(cfg *Config) (map[string]map[string]interface{}, error) {
-	allTablesMap := make(map[string]map[string]interface{})
+func (df *Diff) GetAllTables(cfg *Config) (map[string]map[string]map[string]interface{}, error) {
+	// instanceID => schema => table
+	allTablesMap := make(map[string]map[string]map[string]interface{})
 
-	for _, schemaTables := range cfg.Tables {
-		if _, ok := allTablesMap[schemaName(cfg.TargetDBCfg.InstanceID, schemaTables.Schema)]; ok {
-			continue
-		}
-
-		allTables, err := dbutil.GetTables(df.ctx, cfg.TargetDBCfg.Conn, schemaTables.Schema)
+	allTablesMap[df.targetDB.InstanceID] = make(map[string]map[string]interface{})
+	targetSchemas, err := dbutil.GetSchemas(df.ctx, df.targetDB.Conn)
+	if err != nil {
+		return nil, errors.Annotatef(err, "get schemas from %s", df.targetDB.InstanceID)
+	}
+	for _, schema := range targetSchemas {
+		allTables, err := dbutil.GetTables(df.ctx, df.targetDB.Conn, schema)
 		if err != nil {
-			return nil, errors.Errorf("get tables from %s.%s error %v", cfg.TargetDBCfg.InstanceID, schemaTables.Schema, errors.Trace(err))
+			return nil, errors.Annotatef(err, "get tables from %s.%s", df.targetDB.InstanceID, schema)
 		}
-		allTablesMap[schemaName(cfg.TargetDBCfg.InstanceID, schemaTables.Schema)] = diff.SliceToMap(allTables)
+		allTablesMap[df.targetDB.InstanceID][schema] = utils.SliceToMap(allTables)
 	}
 
-	for _, table := range cfg.TableCfgs {
-		for _, sourceTable := range table.SourceTables {
-			if _, ok := allTablesMap[schemaName(sourceTable.InstanceID, sourceTable.Schema)]; ok {
-				continue
-			}
-
-			db, ok := df.sourceDBs[sourceTable.InstanceID]
-			if !ok {
-				return nil, errors.Errorf("unknown instance id %s", sourceTable.InstanceID)
-			}
+	for _, source := range df.sourceDBs {
+		allTablesMap[source.InstanceID] = make(map[string]map[string]interface{})
+		sourceSchemas, err := dbutil.GetSchemas(df.ctx, source.Conn)
+		if err != nil {
+			return nil, errors.Annotatef(err, "get schemas from %s", source.InstanceID)
+		}
 
-			allTables, err := dbutil.GetTables(df.ctx, db.Conn, sourceTable.Schema)
+		for _, schema := range sourceSchemas {
+			allTables, err := dbutil.GetTables(df.ctx, source.Conn, schema)
 			if err != nil {
-				return nil, errors.Errorf("get tables from %s.%s error %v", db.InstanceID, sourceTable.Schema, errors.Trace(err))
+				return nil, errors.Annotatef(err, "get tables from %s.%s", source.InstanceID, schema)
 			}
-			allTablesMap[schemaName(db.InstanceID, sourceTable.Schema)] = diff.SliceToMap(allTables)
+			allTablesMap[source.InstanceID][schema] = utils.SliceToMap(allTables)
 		}
 	}
 
@@ -306,27 +358,44 @@ func (df *Diff) Equal() (err error) {
 	for _, schema := range df.tables {
 		for _, table := range schema {
+			var tidbStatsSource *diff.TableInstance
+
 			sourceTables := make([]*diff.TableInstance, 0, len(table.SourceTables))
 			for _, sourceTable := range table.SourceTables {
-				sourceTables = append(sourceTables, &diff.TableInstance{
+				sourceTableInstance := &diff.TableInstance{
 					Conn:   df.sourceDBs[sourceTable.InstanceID].Conn,
 					Schema: sourceTable.Schema,
 					Table:  sourceTable.Table,
-				})
+				}
+				sourceTables = append(sourceTables, sourceTableInstance)
+
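+				// if this source is the TiDB instance named by tidb-instance-id,
+				// its statistics will be used to split the check chunks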
+				if sourceTable.InstanceID == df.tidbInstanceID {
+					tidbStatsSource = sourceTableInstance
+				}
 			}
+
+			targetTableInstance := &diff.TableInstance{
+				Conn:   df.targetDB.Conn,
+				Schema: table.Schema,
+				Table:  table.Table,
+			}
+
+			if df.targetDB.InstanceID == df.tidbInstanceID {
+				tidbStatsSource = targetTableInstance
+			}
+
+			if len(df.tidbInstanceID) != 0 && tidbStatsSource == nil {
+				return errors.NotFoundf("tidb instance id %s", df.tidbInstanceID)
+			}
 
 			td := &diff.TableDiff{
 				SourceTables: sourceTables,
-				TargetTable: &diff.TableInstance{
-					Conn:   df.targetDB.Conn,
-					Schema: table.Schema,
-					Table:  table.Table,
-				},
+				TargetTable:   targetTableInstance,
 				IgnoreColumns: table.IgnoreColumns,
 				RemoveColumns: table.RemoveColumns,
-				Field:     table.Field,
+				Fields:    table.Fields,
 				Range:     table.Range,
 				Collation: table.Collation,
 				ChunkSize: df.chunkSize,
@@ -336,6 +405,7 @@ func (df *Diff) Equal() (err error) {
 				UseChecksum:       df.useChecksum,
 				IgnoreStructCheck: df.ignoreStructCheck,
 				IgnoreDataCheck:   df.ignoreDataCheck,
+				TiDBStatsSource:   tidbStatsSource,
 			}
 
 			structEqual, dataEqual, err := td.Equal(df.ctx, func(dml string) error {