diff --git a/cdc/sink/common/common.go b/cdc/sink/common/common.go index 844e2f0bedd..469a12dbb40 100644 --- a/cdc/sink/common/common.go +++ b/cdc/sink/common/common.go @@ -16,7 +16,6 @@ package common import ( "sort" "sync" - "sync/atomic" "github.com/pingcap/log" "github.com/pingcap/ticdc/cdc/model" @@ -55,7 +54,6 @@ func (t *txnsWithTheSameCommitTs) Append(row *model.RowChangedEvent) { type UnresolvedTxnCache struct { unresolvedTxnsMu sync.Mutex unresolvedTxns map[model.TableID][]*txnsWithTheSameCommitTs - checkpointTs uint64 } // NewUnresolvedTxnCache returns a new UnresolvedTxnCache @@ -103,32 +101,27 @@ func (c *UnresolvedTxnCache) Append(filter *filter.Filter, rows ...*model.RowCha // Resolved returns resolved txns according to resolvedTs // The returned map contains many txns grouped by tableID. for each table, the each commitTs of txn in txns slice is strictly increasing -func (c *UnresolvedTxnCache) Resolved(resolvedTs uint64) map[model.TableID][]*model.SingleTableTxn { - if resolvedTs <= atomic.LoadUint64(&c.checkpointTs) { - return nil - } - +func (c *UnresolvedTxnCache) Resolved(resolvedTsMap *sync.Map) (map[model.TableID]uint64, map[model.TableID][]*model.SingleTableTxn) { c.unresolvedTxnsMu.Lock() defer c.unresolvedTxnsMu.Unlock() if len(c.unresolvedTxns) == 0 { - return nil + return nil, nil } - _, resolvedTxnsMap := splitResolvedTxn(resolvedTs, c.unresolvedTxns) - return resolvedTxnsMap -} - -// UpdateCheckpoint updates the checkpoint ts -func (c *UnresolvedTxnCache) UpdateCheckpoint(checkpointTs uint64) { - atomic.StoreUint64(&c.checkpointTs, checkpointTs) + return splitResolvedTxn(resolvedTsMap, c.unresolvedTxns) } func splitResolvedTxn( - resolvedTs uint64, unresolvedTxns map[model.TableID][]*txnsWithTheSameCommitTs, -) (minTs uint64, resolvedRowsMap map[model.TableID][]*model.SingleTableTxn) { + resolvedTsMap *sync.Map, unresolvedTxns map[model.TableID][]*txnsWithTheSameCommitTs, +) (flushedResolvedTsMap map[model.TableID]uint64, resolvedRowsMap map[model.TableID][]*model.SingleTableTxn) { resolvedRowsMap = make(map[model.TableID][]*model.SingleTableTxn, len(unresolvedTxns)) - minTs = resolvedTs + flushedResolvedTsMap = make(map[model.TableID]uint64, len(unresolvedTxns)) for tableID, txns := range unresolvedTxns { + v, ok := resolvedTsMap.Load(tableID) + if !ok { + continue + } + resolvedTs := v.(uint64) i := sort.Search(len(txns), func(i int) bool { return txns[i].commitTs > resolvedTs }) @@ -154,9 +147,7 @@ func splitResolvedTxn( } } resolvedRowsMap[tableID] = resolvedTxns - if len(resolvedTxnsWithTheSameCommitTs) > 0 && resolvedTxnsWithTheSameCommitTs[0].commitTs < minTs { - minTs = resolvedTxnsWithTheSameCommitTs[0].commitTs - } + flushedResolvedTsMap[tableID] = resolvedTs } return } diff --git a/cdc/sink/common/common_test.go b/cdc/sink/common/common_test.go index 28a87086337..ebf1ced1b39 100644 --- a/cdc/sink/common/common_test.go +++ b/cdc/sink/common/common_test.go @@ -15,26 +15,22 @@ package common import ( "sort" + "sync" "testing" "github.com/google/go-cmp/cmp" - "github.com/pingcap/check" "github.com/pingcap/ticdc/cdc/model" "github.com/pingcap/ticdc/pkg/util/testleak" + "github.com/stretchr/testify/require" ) -type SinkCommonSuite struct{} +func TestSplitResolvedTxn(test *testing.T) { + defer testleak.AfterTestT(test)() -func Test(t *testing.T) { check.TestingT(t) } - -var _ = check.Suite(&SinkCommonSuite{}) - -func (s SinkCommonSuite) TestSplitResolvedTxn(c *check.C) { - defer testleak.AfterTest(c)() testCases := [][]struct { - input []*model.RowChangedEvent - resolvedTs model.Ts - expected map[model.TableID][]*model.SingleTableTxn + input []*model.RowChangedEvent + resolvedTsMap map[model.TableID]uint64 + expected map[model.TableID][]*model.SingleTableTxn }{{{ // Testing basic transaction collocation, no txns with the same committs input: []*model.RowChangedEvent{ {StartTs: 1, CommitTs: 5, Table: &model.TableName{TableID: 1}}, @@ -45,7 +41,10 @@ func (s SinkCommonSuite) TestSplitResolvedTxn(c *check.C) { {StartTs: 1, CommitTs: 11, Table: &model.TableName{TableID: 1}}, {StartTs: 1, CommitTs: 12, Table: &model.TableName{TableID: 2}}, }, - resolvedTs: 6, + resolvedTsMap: map[model.TableID]uint64{ + 1: uint64(6), + 2: uint64(6), + }, expected: map[model.TableID][]*model.SingleTableTxn{ 1: {{Table: &model.TableName{TableID: 1}, StartTs: 1, CommitTs: 5, Rows: []*model.RowChangedEvent{ {StartTs: 1, CommitTs: 5, Table: &model.TableName{TableID: 1}}, @@ -59,7 +58,11 @@ func (s SinkCommonSuite) TestSplitResolvedTxn(c *check.C) { input: []*model.RowChangedEvent{ {StartTs: 1, CommitTs: 8, Table: &model.TableName{TableID: 3}}, }, - resolvedTs: 13, + resolvedTsMap: map[model.TableID]uint64{ + 1: uint64(13), + 2: uint64(13), + 3: uint64(13), + }, expected: map[model.TableID][]*model.SingleTableTxn{ 1: {{Table: &model.TableName{TableID: 1}, StartTs: 1, CommitTs: 8, Rows: []*model.RowChangedEvent{ {StartTs: 1, CommitTs: 8, Table: &model.TableName{TableID: 1}}, @@ -76,17 +79,24 @@ func (s SinkCommonSuite) TestSplitResolvedTxn(c *check.C) { }}}, }, }}, {{ // Testing the short circuit path - input: []*model.RowChangedEvent{}, - resolvedTs: 6, - expected: nil, + input: []*model.RowChangedEvent{}, + resolvedTsMap: map[model.TableID]uint64{ + 1: uint64(13), + 2: uint64(13), + 3: uint64(13), + }, + expected: nil, }, { input: []*model.RowChangedEvent{ {StartTs: 1, CommitTs: 11, Table: &model.TableName{TableID: 1}}, {StartTs: 1, CommitTs: 12, Table: &model.TableName{TableID: 1}}, {StartTs: 1, CommitTs: 13, Table: &model.TableName{TableID: 2}}, }, - resolvedTs: 6, - expected: map[model.TableID][]*model.SingleTableTxn{}, + resolvedTsMap: map[model.TableID]uint64{ + 1: uint64(6), + 2: uint64(6), + }, + expected: map[model.TableID][]*model.SingleTableTxn{}, }}, {{ // Testing the txns with the same commitTs input: []*model.RowChangedEvent{ {StartTs: 1, CommitTs: 5, Table: &model.TableName{TableID: 1}}, @@ -99,7 +109,10 @@ func (s SinkCommonSuite) TestSplitResolvedTxn(c *check.C) { {StartTs: 1, CommitTs: 6, Table: &model.TableName{TableID: 2}}, {StartTs: 1, CommitTs: 7, Table: &model.TableName{TableID: 2}}, }, - resolvedTs: 6, + resolvedTsMap: map[model.TableID]uint64{ + 1: uint64(6), + 2: uint64(6), + }, expected: map[model.TableID][]*model.SingleTableTxn{ 1: {{Table: &model.TableName{TableID: 1}, StartTs: 1, CommitTs: 5, Rows: []*model.RowChangedEvent{ {StartTs: 1, CommitTs: 5, Table: &model.TableName{TableID: 1}}, @@ -119,7 +132,10 @@ func (s SinkCommonSuite) TestSplitResolvedTxn(c *check.C) { {StartTs: 2, CommitTs: 8, Table: &model.TableName{TableID: 1}}, {StartTs: 1, CommitTs: 9, Table: &model.TableName{TableID: 1}}, }, - resolvedTs: 13, + resolvedTsMap: map[model.TableID]uint64{ + 1: uint64(13), + 2: uint64(13), + }, expected: map[model.TableID][]*model.SingleTableTxn{ 1: {{Table: &model.TableName{TableID: 1}, StartTs: 1, CommitTs: 8, Rows: []*model.RowChangedEvent{ {StartTs: 1, CommitTs: 8, Table: &model.TableName{TableID: 1}}, @@ -144,7 +160,11 @@ func (s SinkCommonSuite) TestSplitResolvedTxn(c *check.C) { cache := NewUnresolvedTxnCache() for _, t := range tc { cache.Append(nil, t.input...) - resolved := cache.Resolved(t.resolvedTs) + resolvedTsMap := sync.Map{} + for tableID, ts := range t.resolvedTsMap { + resolvedTsMap.Store(tableID, ts) + } + _, resolved := cache.Resolved(&resolvedTsMap) for tableID, txns := range resolved { sort.Slice(txns, func(i, j int) bool { if txns[i].CommitTs != txns[j].CommitTs { @@ -154,8 +174,7 @@ func (s SinkCommonSuite) TestSplitResolvedTxn(c *check.C) { }) resolved[tableID] = txns } - c.Assert(resolved, check.DeepEquals, t.expected, - check.Commentf("%s", cmp.Diff(resolved, t.expected))) + require.Equal(test, t.expected, resolved, cmp.Diff(resolved, t.expected)) } } } diff --git a/cdc/sink/mysql.go b/cdc/sink/mysql.go index c038622b72d..d7feed73954 100644 --- a/cdc/sink/mysql.go +++ b/cdc/sink/mysql.go @@ -21,7 +21,6 @@ import ( "strconv" "strings" "sync" - "sync/atomic" "time" dmysql "github.com/go-sql-driver/mysql" @@ -58,10 +57,10 @@ type mysqlSink struct { filter *tifilter.Filter cyclic *cyclic.Cyclic - txnCache *common.UnresolvedTxnCache - workers []*mysqlSinkWorker - resolvedTs uint64 - maxResolvedTs uint64 + txnCache *common.UnresolvedTxnCache + workers []*mysqlSinkWorker + tableCheckpointTs sync.Map + tableMaxResolvedTs sync.Map execWaitNotifier *notify.Notifier resolvedNotifier *notify.Notifier @@ -208,12 +207,10 @@ func (s *mysqlSink) EmitRowChangedEvents(ctx context.Context, rows ...*model.Row // FlushRowChangedEvents will flush all received events, we don't allow mysql // sink to receive events before resolving func (s *mysqlSink) FlushRowChangedEvents(ctx context.Context, tableID model.TableID, resolvedTs uint64) (uint64, error) { - if atomic.LoadUint64(&s.maxResolvedTs) < resolvedTs { - atomic.StoreUint64(&s.maxResolvedTs, resolvedTs) + v, ok := s.tableMaxResolvedTs.Load(tableID) + if !ok || v.(uint64) < resolvedTs { + s.tableMaxResolvedTs.Store(tableID, resolvedTs) } - // resolvedTs can be fallen back, such as a new table is added into this sink - // with a smaller start-ts - atomic.StoreUint64(&s.resolvedTs, resolvedTs) s.resolvedNotifier.Notify() // check and throw error @@ -223,13 +220,7 @@ func (s *mysqlSink) FlushRowChangedEvents(ctx context.Context, tableID model.Tab default: } - checkpointTs := resolvedTs - for _, worker := range s.workers { - workerCheckpointTs := atomic.LoadUint64(&worker.checkpointTs) - if workerCheckpointTs < checkpointTs { - checkpointTs = workerCheckpointTs - } - } + checkpointTs := s.getTableCheckpointTs(tableID) s.statistics.PrintStatus(ctx) return checkpointTs, nil } @@ -246,13 +237,12 @@ func (s *mysqlSink) flushRowChangedEvents(ctx context.Context, receiver *notify. return case <-receiver.C: } - resolvedTs := atomic.LoadUint64(&s.resolvedTs) - resolvedTxnsMap := s.txnCache.Resolved(resolvedTs) + flushedResolvedTsMap, resolvedTxnsMap := s.txnCache.Resolved(&s.tableMaxResolvedTs) if len(resolvedTxnsMap) == 0 { - for _, worker := range s.workers { - atomic.StoreUint64(&worker.checkpointTs, resolvedTs) - } - s.txnCache.UpdateCheckpoint(resolvedTs) + s.tableMaxResolvedTs.Range(func(key, value interface{}) bool { + s.tableCheckpointTs.Store(key, value) + return true + }) continue } @@ -264,10 +254,9 @@ func (s *mysqlSink) flushRowChangedEvents(ctx context.Context, receiver *notify. } s.dispatchAndExecTxns(ctx, resolvedTxnsMap) - for _, worker := range s.workers { - atomic.StoreUint64(&worker.checkpointTs, resolvedTs) + for tableID, resolvedTs := range flushedResolvedTsMap { + s.tableCheckpointTs.Store(tableID, resolvedTs) } - s.txnCache.UpdateCheckpoint(resolvedTs) } } @@ -482,12 +471,20 @@ func (s *mysqlSink) Barrier(ctx context.Context, tableID model.TableID) error { case <-ctx.Done(): return errors.Trace(ctx.Err()) case <-ticker.C: + maxResolvedTs, ok := s.tableMaxResolvedTs.Load(tableID) log.Warn("Barrier doesn't return in time, may be stuck", - zap.Uint64("resolved-ts", atomic.LoadUint64(&s.maxResolvedTs)), - zap.Uint64("checkpoint-ts", s.checkpointTs())) + zap.Int64("tableID", tableID), + zap.Bool("has resolvedTs", ok), + zap.Any("resolvedTs", maxResolvedTs), + zap.Uint64("checkpointTs", s.getTableCheckpointTs(tableID))) default: - maxResolvedTs := atomic.LoadUint64(&s.maxResolvedTs) - if s.checkpointTs() >= maxResolvedTs { + v, ok := s.tableMaxResolvedTs.Load(tableID) + if !ok { + log.Info("No table resolvedTs is found", zap.Int64("table-id", tableID)) + return nil + } + maxResolvedTs := v.(uint64) + if s.getTableCheckpointTs(tableID) >= maxResolvedTs { return nil } checkpointTs, err := s.FlushRowChangedEvents(ctx, tableID, maxResolvedTs) @@ -503,15 +500,12 @@ func (s *mysqlSink) Barrier(ctx context.Context, tableID model.TableID) error { } } -func (s *mysqlSink) checkpointTs() uint64 { - checkpointTs := atomic.LoadUint64(&s.resolvedTs) - for _, worker := range s.workers { - workerCheckpointTs := atomic.LoadUint64(&worker.checkpointTs) - if workerCheckpointTs < checkpointTs { - checkpointTs = workerCheckpointTs - } +func (s *mysqlSink) getTableCheckpointTs(tableID model.TableID) uint64 { + v, ok := s.tableCheckpointTs.Load(tableID) + if ok { + return v.(uint64) } - return checkpointTs + return uint64(0) } func logDMLTxnErr(err error) error { diff --git a/cdc/sink/mysql_params_test.go b/cdc/sink/mysql_params_test.go index 9ca34508be3..06c5fa33803 100644 --- a/cdc/sink/mysql_params_test.go +++ b/cdc/sink/mysql_params_test.go @@ -18,21 +18,22 @@ import ( "database/sql" "net/url" "strings" + "testing" "github.com/DATA-DOG/go-sqlmock" dmysql "github.com/go-sql-driver/mysql" - "github.com/pingcap/check" "github.com/pingcap/ticdc/pkg/util/testleak" + "github.com/stretchr/testify/require" ) -func (s MySQLSinkSuite) TestSinkParamsClone(c *check.C) { - defer testleak.AfterTest(c)() +func TestSinkParamsClone(t *testing.T) { + defer testleak.AfterTestT(t)() param1 := defaultParams.Clone() param2 := param1.Clone() param2.changefeedID = "123" param2.batchReplaceEnabled = false param2.maxTxnRow = 1 - c.Assert(param1, check.DeepEquals, &sinkParams{ + require.Equal(t, &sinkParams{ workerCount: DefaultWorkerCount, maxTxnRow: DefaultMaxTxnRow, tidbTxnMode: defaultTiDBTxnMode, @@ -42,8 +43,8 @@ func (s MySQLSinkSuite) TestSinkParamsClone(c *check.C) { writeTimeout: defaultWriteTimeout, dialTimeout: defaultDialTimeout, safeMode: defaultSafeMode, - }) - c.Assert(param2, check.DeepEquals, &sinkParams{ + }, param1) + require.Equal(t, &sinkParams{ changefeedID: "123", workerCount: DefaultWorkerCount, maxTxnRow: 1, @@ -54,22 +55,22 @@ func (s MySQLSinkSuite) TestSinkParamsClone(c *check.C) { writeTimeout: defaultWriteTimeout, dialTimeout: defaultDialTimeout, safeMode: defaultSafeMode, - }) + }, param2) } -func (s MySQLSinkSuite) TestGenerateDSNByParams(c *check.C) { - defer testleak.AfterTest(c)() +func TestGenerateDSNByParams(t *testing.T) { + defer testleak.AfterTestT(t)() testDefaultParams := func() { db, err := mockTestDB() - c.Assert(err, check.IsNil) + require.Nil(t, err) defer db.Close() dsn, err := dmysql.ParseDSN("root:123456@tcp(127.0.0.1:4000)/") - c.Assert(err, check.IsNil) + require.Nil(t, err) params := defaultParams.Clone() dsnStr, err := generateDSNByParams(context.TODO(), dsn, params, db) - c.Assert(err, check.IsNil) + require.Nil(t, err) expectedParams := []string{ "tidb_txn_mode=optimistic", "readTimeout=2m", @@ -77,45 +78,45 @@ func (s MySQLSinkSuite) TestGenerateDSNByParams(c *check.C) { "allow_auto_random_explicit_insert=1", } for _, param := range expectedParams { - c.Assert(strings.Contains(dsnStr, param), check.IsTrue) + require.True(t, strings.Contains(dsnStr, param)) } - c.Assert(strings.Contains(dsnStr, "time_zone"), check.IsFalse) + require.False(t, strings.Contains(dsnStr, "time_zone")) } testTimezoneParam := func() { db, err := mockTestDB() - c.Assert(err, check.IsNil) + require.Nil(t, err) defer db.Close() dsn, err := dmysql.ParseDSN("root:123456@tcp(127.0.0.1:4000)/") - c.Assert(err, check.IsNil) + require.Nil(t, err) params := defaultParams.Clone() params.timezone = `"UTC"` dsnStr, err := generateDSNByParams(context.TODO(), dsn, params, db) - c.Assert(err, check.IsNil) - c.Assert(strings.Contains(dsnStr, "time_zone=%22UTC%22"), check.IsTrue) + require.Nil(t, err) + require.True(t, strings.Contains(dsnStr, "time_zone=%22UTC%22")) } testTimeoutParams := func() { db, err := mockTestDB() - c.Assert(err, check.IsNil) + require.Nil(t, err) defer db.Close() dsn, err := dmysql.ParseDSN("root:123456@tcp(127.0.0.1:4000)/") - c.Assert(err, check.IsNil) + require.Nil(t, err) uri, err := url.Parse("mysql://127.0.0.1:3306/?read-timeout=4m&write-timeout=5m&timeout=3m") - c.Assert(err, check.IsNil) + require.Nil(t, err) params, err := parseSinkURIToParams(context.TODO(), uri, map[string]string{}) - c.Assert(err, check.IsNil) + require.Nil(t, err) dsnStr, err := generateDSNByParams(context.TODO(), dsn, params, db) - c.Assert(err, check.IsNil) + require.Nil(t, err) expectedParams := []string{ "readTimeout=4m", "writeTimeout=5m", "timeout=3m", } for _, param := range expectedParams { - c.Assert(strings.Contains(dsnStr, param), check.IsTrue) + require.True(t, strings.Contains(dsnStr, param)) } } @@ -124,8 +125,8 @@ func (s MySQLSinkSuite) TestGenerateDSNByParams(c *check.C) { testTimeoutParams() } -func (s MySQLSinkSuite) TestParseSinkURIToParams(c *check.C) { - defer testleak.AfterTest(c)() +func TestParseSinkURIToParams(t *testing.T) { + defer testleak.AfterTestT(t)() expected := defaultParams.Clone() expected.workerCount = 64 expected.maxTxnRow = 20 @@ -144,14 +145,14 @@ func (s MySQLSinkSuite) TestParseSinkURIToParams(c *check.C) { OptCaptureAddr: expected.captureAddr, } uri, err := url.Parse(uriStr) - c.Assert(err, check.IsNil) + require.Nil(t, err) params, err := parseSinkURIToParams(context.TODO(), uri, opts) - c.Assert(err, check.IsNil) - c.Assert(params, check.DeepEquals, expected) + require.Nil(t, err) + require.Equal(t, expected, params) } -func (s MySQLSinkSuite) TestParseSinkURITimezone(c *check.C) { - defer testleak.AfterTest(c)() +func TestParseSinkURITimezone(t *testing.T) { + defer testleak.AfterTestT(t)() uris := []string{ "mysql://127.0.0.1:3306/?time-zone=Asia/Shanghai&worker-count=32", "mysql://127.0.0.1:3306/?time-zone=&worker-count=32", @@ -166,15 +167,15 @@ func (s MySQLSinkSuite) TestParseSinkURITimezone(c *check.C) { opts := map[string]string{} for i, uriStr := range uris { uri, err := url.Parse(uriStr) - c.Assert(err, check.IsNil) + require.Nil(t, err) params, err := parseSinkURIToParams(ctx, uri, opts) - c.Assert(err, check.IsNil) - c.Assert(params.timezone, check.Equals, expected[i]) + require.Nil(t, err) + require.Equal(t, expected[i], params.timezone) } } -func (s MySQLSinkSuite) TestParseSinkURIBadQueryString(c *check.C) { - defer testleak.AfterTest(c)() +func TestParseSinkURIBadQueryString(t *testing.T) { + defer testleak.AfterTestT(t)() uris := []string{ "", "postgre://127.0.0.1:3306", @@ -192,19 +193,19 @@ func (s MySQLSinkSuite) TestParseSinkURIBadQueryString(c *check.C) { for _, uriStr := range uris { if uriStr != "" { uri, err = url.Parse(uriStr) - c.Assert(err, check.IsNil) + require.Nil(t, err) } else { uri = nil } _, err = parseSinkURIToParams(ctx, uri, opts) - c.Assert(err, check.NotNil) + require.NotNil(t, err) } } -func (s MySQLSinkSuite) TestCheckTiDBVariable(c *check.C) { - defer testleak.AfterTest(c)() +func TestCheckTiDBVariable(t *testing.T) { + defer testleak.AfterTestT(t)() db, mock, err := sqlmock.New() - c.Assert(err, check.IsNil) + require.Nil(t, err) defer db.Close() //nolint:errcheck columns := []string{"Variable_name", "Value"} @@ -212,15 +213,16 @@ func (s MySQLSinkSuite) TestCheckTiDBVariable(c *check.C) { sqlmock.NewRows(columns).AddRow("allow_auto_random_explicit_insert", "0"), ) val, err := checkTiDBVariable(context.TODO(), db, "allow_auto_random_explicit_insert", "1") - c.Assert(err, check.IsNil) - c.Assert(val, check.Equals, "1") + require.Nil(t, err) + require.Equal(t, "1", val) mock.ExpectQuery("show session variables like 'no_exist_variable';").WillReturnError(sql.ErrNoRows) val, err = checkTiDBVariable(context.TODO(), db, "no_exist_variable", "0") - c.Assert(err, check.IsNil) - c.Assert(val, check.Equals, "") + require.Nil(t, err) + require.Equal(t, "", val) mock.ExpectQuery("show session variables like 'version';").WillReturnError(sql.ErrConnDone) _, err = checkTiDBVariable(context.TODO(), db, "version", "5.7.25-TiDB-v4.0.0") - c.Assert(err, check.ErrorMatches, ".*"+sql.ErrConnDone.Error()) + require.NotNil(t, err) + require.Regexp(t, ".*"+sql.ErrConnDone.Error(), err.Error()) } diff --git a/cdc/sink/mysql_test.go b/cdc/sink/mysql_test.go index c7328070eef..6f0b6ffa18c 100644 --- a/cdc/sink/mysql_test.go +++ b/cdc/sink/mysql_test.go @@ -23,10 +23,10 @@ import ( "sort" "sync" "testing" + "time" "github.com/DATA-DOG/go-sqlmock" dmysql "github.com/go-sql-driver/mysql" - "github.com/pingcap/check" "github.com/pingcap/errors" "github.com/pingcap/ticdc/cdc/model" "github.com/pingcap/ticdc/cdc/sink/common" @@ -35,21 +35,15 @@ import ( cerror "github.com/pingcap/ticdc/pkg/errors" "github.com/pingcap/ticdc/pkg/filter" "github.com/pingcap/ticdc/pkg/retry" - "github.com/pingcap/ticdc/pkg/util/testleak" "github.com/pingcap/tidb/infoschema" timodel "github.com/pingcap/tidb/parser/model" "github.com/pingcap/tidb/parser/mysql" + "github.com/stretchr/testify/require" ) -type MySQLSinkSuite struct{} - -func Test(t *testing.T) { check.TestingT(t) } - -var _ = check.Suite(&MySQLSinkSuite{}) - -func newMySQLSink4Test(ctx context.Context, c *check.C) *mysqlSink { +func newMySQLSink4Test(ctx context.Context, t *testing.T) *mysqlSink { f, err := filter.NewFilter(config.GetDefaultReplicaConfig()) - c.Assert(err, check.IsNil) + require.Nil(t, err) params := defaultParams.Clone() params.batchReplaceEnabled = false return &mysqlSink{ @@ -60,8 +54,7 @@ func newMySQLSink4Test(ctx context.Context, c *check.C) *mysqlSink { } } -func (s MySQLSinkSuite) TestPrepareDML(c *check.C) { - defer testleak.AfterTest(c)() +func TestPrepareDML(t *testing.T) { testCases := []struct { input []*model.RowChangedEvent expected *preparedDMLs @@ -96,15 +89,14 @@ func (s MySQLSinkSuite) TestPrepareDML(c *check.C) { }} ctx, cancel := context.WithCancel(context.Background()) defer cancel() - ms := newMySQLSink4Test(ctx, c) + ms := newMySQLSink4Test(ctx, t) for i, tc := range testCases { dmls := ms.prepareDMLs(tc.input, 0, 0) - c.Assert(dmls, check.DeepEquals, tc.expected, check.Commentf("%d", i)) + require.Equal(t, tc.expected, dmls, tc.expected, fmt.Sprintf("%d", i)) } } -func (s MySQLSinkSuite) TestPrepareUpdate(c *check.C) { - defer testleak.AfterTest(c)() +func TestPrepareUpdate(t *testing.T) { testCases := []struct { quoteTable string preCols []*model.Column @@ -150,13 +142,12 @@ func (s MySQLSinkSuite) TestPrepareUpdate(c *check.C) { } for _, tc := range testCases { query, args := prepareUpdate(tc.quoteTable, tc.preCols, tc.cols, false) - c.Assert(query, check.Equals, tc.expectedSQL) - c.Assert(args, check.DeepEquals, tc.expectedArgs) + require.Equal(t, tc.expectedSQL, query) + require.Equal(t, tc.expectedArgs, args) } } -func (s MySQLSinkSuite) TestPrepareDelete(c *check.C) { - defer testleak.AfterTest(c)() +func TestPrepareDelete(t *testing.T) { testCases := []struct { quoteTable string preCols []*model.Column @@ -191,13 +182,12 @@ func (s MySQLSinkSuite) TestPrepareDelete(c *check.C) { } for _, tc := range testCases { query, args := prepareDelete(tc.quoteTable, tc.preCols, false) - c.Assert(query, check.Equals, tc.expectedSQL) - c.Assert(args, check.DeepEquals, tc.expectedArgs) + require.Equal(t, tc.expectedSQL, query) + require.Equal(t, tc.expectedArgs, args) } } -func (s MySQLSinkSuite) TestWhereSlice(c *check.C) { - defer testleak.AfterTest(c)() +func TestWhereSlice(t *testing.T) { testCases := []struct { cols []*model.Column forceReplicate bool @@ -276,13 +266,12 @@ func (s MySQLSinkSuite) TestWhereSlice(c *check.C) { } for _, tc := range testCases { colNames, args := whereSlice(tc.cols, tc.forceReplicate) - c.Assert(colNames, check.DeepEquals, tc.expectedColNames) - c.Assert(args, check.DeepEquals, tc.expectedArgs) + require.Equal(t, tc.expectedColNames, colNames) + require.Equal(t, tc.expectedArgs, args) } } -func (s MySQLSinkSuite) TestMapReplace(c *check.C) { - defer testleak.AfterTest(c)() +func TestMapReplace(t *testing.T) { testCases := []struct { quoteTable string cols []*model.Column @@ -316,8 +305,8 @@ func (s MySQLSinkSuite) TestMapReplace(c *check.C) { // multiple times to verify the stability of column sequence in query string for i := 0; i < 10; i++ { query, args := prepareReplace(tc.quoteTable, tc.cols, false, false) - c.Assert(query, check.Equals, tc.expectedQuery) - c.Assert(args, check.DeepEquals, tc.expectedArgs) + require.Equal(t, tc.expectedQuery, query) + require.Equal(t, tc.expectedArgs, args) } } } @@ -328,8 +317,7 @@ func (a sqlArgs) Len() int { return len(a) } func (a sqlArgs) Less(i, j int) bool { return fmt.Sprintf("%s", a[i]) < fmt.Sprintf("%s", a[j]) } func (a sqlArgs) Swap(i, j int) { a[i], a[j] = a[j], a[i] } -func (s MySQLSinkSuite) TestReduceReplace(c *check.C) { - defer testleak.AfterTest(c)() +func TestReduceReplace(t *testing.T) { testCases := []struct { replaces map[string][][]interface{} batchSize int @@ -434,8 +422,8 @@ func (s MySQLSinkSuite) TestReduceReplace(c *check.C) { sort.Strings(sqls) sort.Sort(sqlArgs(args)) } - c.Assert(sqls, check.DeepEquals, tc.expectSQLs) - c.Assert(args, check.DeepEquals, tc.expectArgs) + require.Equal(t, tc.expectSQLs, sqls) + require.Equal(t, tc.expectArgs, args) } } @@ -456,9 +444,7 @@ func mockTestDB() (*sql.DB, error) { return db, nil } -func (s MySQLSinkSuite) TestAdjustSQLMode(c *check.C) { - defer testleak.AfterTest(c)() - +func TestAdjustSQLMode(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -470,12 +456,12 @@ func (s MySQLSinkSuite) TestAdjustSQLMode(c *check.C) { if dbIndex == 0 { // test db db, err := mockTestDB() - c.Assert(err, check.IsNil) + require.Nil(t, err) return db, nil } // normal db db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) - c.Assert(err, check.IsNil) + require.Nil(t, err) mock.ExpectQuery("SELECT @@SESSION.sql_mode;"). WillReturnRows(sqlmock.NewRows([]string{"@@SESSION.sql_mode"}). AddRow("ONLY_FULL_GROUP_BY,STRICT_TRANS_TABLES,NO_ZERO_IN_DATE,NO_ZERO_DATE")) @@ -492,7 +478,8 @@ func (s MySQLSinkSuite) TestAdjustSQLMode(c *check.C) { changefeed := "test-changefeed" sinkURI, err := url.Parse("mysql://127.0.0.1:4000/?time-zone=UTC&worker-count=4") - c.Assert(err, check.IsNil) + require.Nil(t, err) + require.Nil(t, err) rc := config.GetDefaultReplicaConfig() rc.Cyclic = &config.CyclicConfig{ Enable: true, @@ -500,17 +487,17 @@ func (s MySQLSinkSuite) TestAdjustSQLMode(c *check.C) { FilterReplicaID: []uint64{2}, } f, err := filter.NewFilter(rc) - c.Assert(err, check.IsNil) + require.Nil(t, err) cyclicConfig, err := rc.Cyclic.Marshal() - c.Assert(err, check.IsNil) + require.Nil(t, err) opts := map[string]string{ mark.OptCyclicConfig: cyclicConfig, } sink, err := newMySQLSink(ctx, changefeed, sinkURI, f, rc, opts) - c.Assert(err, check.IsNil) + require.Nil(t, err) err = sink.Close(ctx) - c.Assert(err, check.IsNil) + require.Nil(t, err) } type mockUnavailableMySQL struct { @@ -519,19 +506,19 @@ type mockUnavailableMySQL struct { wg sync.WaitGroup } -func newMockUnavailableMySQL(addr string, c *check.C) *mockUnavailableMySQL { +func newMockUnavailableMySQL(addr string, t *testing.T) *mockUnavailableMySQL { s := &mockUnavailableMySQL{ quit: make(chan interface{}), } l, err := net.Listen("tcp", addr) - c.Assert(err, check.IsNil) + require.Nil(t, err) s.listener = l s.wg.Add(1) - go s.serve(c) + go s.serve(t) return s } -func (s *mockUnavailableMySQL) serve(c *check.C) { +func (s *mockUnavailableMySQL) serve(t *testing.T) { defer s.wg.Done() for { @@ -541,7 +528,7 @@ func (s *mockUnavailableMySQL) serve(c *check.C) { case <-s.quit: return default: - c.Error(err) + require.Error(t, err) } } else { s.wg.Add(1) @@ -560,28 +547,24 @@ func (s *mockUnavailableMySQL) Stop() { s.wg.Wait() } -func (s MySQLSinkSuite) TestNewMySQLTimeout(c *check.C) { - defer testleak.AfterTest(c)() - +func TestNewMySQLTimeout(t *testing.T) { addr := "127.0.0.1:33333" - mockMySQL := newMockUnavailableMySQL(addr, c) + mockMySQL := newMockUnavailableMySQL(addr, t) defer mockMySQL.Stop() ctx, cancel := context.WithCancel(context.Background()) defer cancel() changefeed := "test-changefeed" sinkURI, err := url.Parse(fmt.Sprintf("mysql://%s/?read-timeout=2s&timeout=2s", addr)) - c.Assert(err, check.IsNil) + require.Nil(t, err) rc := config.GetDefaultReplicaConfig() f, err := filter.NewFilter(rc) - c.Assert(err, check.IsNil) + require.Nil(t, err) _, err = newMySQLSink(ctx, changefeed, sinkURI, f, rc, map[string]string{}) - c.Assert(errors.Cause(err), check.Equals, driver.ErrBadConn) + require.Equal(t, driver.ErrBadConn, errors.Cause(err)) } -func (s MySQLSinkSuite) TestNewMySQLSinkExecDML(c *check.C) { - defer testleak.AfterTest(c)() - +func TestNewMySQLSinkExecDML(t *testing.T) { dbIndex := 0 mockGetDBConn := func(ctx context.Context, dsnStr string) (*sql.DB, error) { defer func() { @@ -590,12 +573,12 @@ func (s MySQLSinkSuite) TestNewMySQLSinkExecDML(c *check.C) { if dbIndex == 0 { // test db db, err := mockTestDB() - c.Assert(err, check.IsNil) + require.Nil(t, err) return db, nil } // normal db db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) - c.Assert(err, check.IsNil) + require.Nil(t, err) mock.ExpectBegin() mock.ExpectExec("REPLACE INTO `s1`.`t1`(`a`,`b`) VALUES (?,?),(?,?)"). WithArgs(1, "test", 2, "test"). @@ -619,12 +602,12 @@ func (s MySQLSinkSuite) TestNewMySQLSinkExecDML(c *check.C) { defer cancel() changefeed := "test-changefeed" sinkURI, err := url.Parse("mysql://127.0.0.1:4000/?time-zone=UTC&worker-count=4") - c.Assert(err, check.IsNil) + require.Nil(t, err) rc := config.GetDefaultReplicaConfig() f, err := filter.NewFilter(rc) - c.Assert(err, check.IsNil) + require.Nil(t, err) sink, err := newMySQLSink(ctx, changefeed, sinkURI, f, rc, map[string]string{}) - c.Assert(err, check.IsNil) + require.Nil(t, err) rows := []*model.RowChangedEvent{ { @@ -675,39 +658,38 @@ func (s MySQLSinkSuite) TestNewMySQLSinkExecDML(c *check.C) { } err = sink.EmitRowChangedEvents(ctx, rows...) - c.Assert(err, check.IsNil) + require.Nil(t, err) + // retry to make sure event is flushed err = retry.Do(context.Background(), func() error { - ts, err := sink.FlushRowChangedEvents(ctx, 2, uint64(2)) - c.Assert(err, check.IsNil) + ts, err := sink.FlushRowChangedEvents(ctx, 1, uint64(2)) + require.Nil(t, err) if ts < uint64(2) { return errors.Errorf("checkpoint ts %d less than resolved ts %d", ts, 2) } return nil }, retry.WithBackoffBaseDelay(20), retry.WithMaxTries(10), retry.WithIsRetryableErr(cerror.IsRetryableError)) - c.Assert(err, check.IsNil) + require.Nil(t, err) err = retry.Do(context.Background(), func() error { ts, err := sink.FlushRowChangedEvents(ctx, 2, uint64(4)) - c.Assert(err, check.IsNil) + require.Nil(t, err) if ts < uint64(4) { return errors.Errorf("checkpoint ts %d less than resolved ts %d", ts, 4) } return nil }, retry.WithBackoffBaseDelay(20), retry.WithMaxTries(10), retry.WithIsRetryableErr(cerror.IsRetryableError)) - c.Assert(err, check.IsNil) + require.Nil(t, err) err = sink.Barrier(ctx, 2) - c.Assert(err, check.IsNil) + require.Nil(t, err) err = sink.Close(ctx) - c.Assert(err, check.IsNil) + require.Nil(t, err) } -func (s MySQLSinkSuite) TestExecDMLRollbackErrDatabaseNotExists(c *check.C) { - defer testleak.AfterTest(c)() - +func TestExecDMLRollbackErrDatabaseNotExists(t *testing.T) { rows := []*model.RowChangedEvent{ { Table: &model.TableName{Schema: "s1", Table: "t1", TableID: 1}, @@ -735,12 +717,12 @@ func (s MySQLSinkSuite) TestExecDMLRollbackErrDatabaseNotExists(c *check.C) { if dbIndex == 0 { // test db db, err := mockTestDB() - c.Assert(err, check.IsNil) + require.Nil(t, err) return db, nil } // normal db db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) - c.Assert(err, check.IsNil) + require.Nil(t, err) mock.ExpectBegin() mock.ExpectExec("REPLACE INTO `s1`.`t1`(`a`) VALUES (?),(?)"). WithArgs(1, 2). @@ -759,23 +741,21 @@ func (s MySQLSinkSuite) TestExecDMLRollbackErrDatabaseNotExists(c *check.C) { defer cancel() changefeed := "test-changefeed" sinkURI, err := url.Parse("mysql://127.0.0.1:4000/?time-zone=UTC&worker-count=1") - c.Assert(err, check.IsNil) + require.Nil(t, err) rc := config.GetDefaultReplicaConfig() f, err := filter.NewFilter(rc) - c.Assert(err, check.IsNil) + require.Nil(t, err) sink, err := newMySQLSink(ctx, changefeed, sinkURI, f, rc, map[string]string{}) - c.Assert(err, check.IsNil) + require.Nil(t, err) err = sink.(*mysqlSink).execDMLs(ctx, rows, 1 /* replicaID */, 1 /* bucket */) - c.Assert(errors.Cause(err), check.Equals, errDatabaseNotExists) + require.Equal(t, errDatabaseNotExists, errors.Cause(err)) err = sink.Close(ctx) - c.Assert(err, check.IsNil) + require.Nil(t, err) } -func (s MySQLSinkSuite) TestExecDMLRollbackErrTableNotExists(c *check.C) { - defer testleak.AfterTest(c)() - +func TestExecDMLRollbackErrTableNotExists(t *testing.T) { rows := []*model.RowChangedEvent{ { Table: &model.TableName{Schema: "s1", Table: "t1", TableID: 1}, @@ -803,12 +783,12 @@ func (s MySQLSinkSuite) TestExecDMLRollbackErrTableNotExists(c *check.C) { if dbIndex == 0 { // test db db, err := mockTestDB() - c.Assert(err, check.IsNil) + require.Nil(t, err) return db, nil } // normal db db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) - c.Assert(err, check.IsNil) + require.Nil(t, err) mock.ExpectBegin() mock.ExpectExec("REPLACE INTO `s1`.`t1`(`a`) VALUES (?),(?)"). WithArgs(1, 2). @@ -827,23 +807,21 @@ func (s MySQLSinkSuite) TestExecDMLRollbackErrTableNotExists(c *check.C) { defer cancel() changefeed := "test-changefeed" sinkURI, err := url.Parse("mysql://127.0.0.1:4000/?time-zone=UTC&worker-count=1") - c.Assert(err, check.IsNil) + require.Nil(t, err) rc := config.GetDefaultReplicaConfig() f, err := filter.NewFilter(rc) - c.Assert(err, check.IsNil) + require.Nil(t, err) sink, err := newMySQLSink(ctx, changefeed, sinkURI, f, rc, map[string]string{}) - c.Assert(err, check.IsNil) + require.Nil(t, err) err = sink.(*mysqlSink).execDMLs(ctx, rows, 1 /* replicaID */, 1 /* bucket */) - c.Assert(errors.Cause(err), check.Equals, errTableNotExists) + require.Equal(t, errTableNotExists, errors.Cause(err)) err = sink.Close(ctx) - c.Assert(err, check.IsNil) + require.Nil(t, err) } -func (s MySQLSinkSuite) TestExecDMLRollbackErrRetryable(c *check.C) { - defer testleak.AfterTest(c)() - +func TestExecDMLRollbackErrRetryable(t *testing.T) { rows := []*model.RowChangedEvent{ { Table: &model.TableName{Schema: "s1", Table: "t1", TableID: 1}, @@ -871,12 +849,12 @@ func (s MySQLSinkSuite) TestExecDMLRollbackErrRetryable(c *check.C) { if dbIndex == 0 { // test db db, err := mockTestDB() - c.Assert(err, check.IsNil) + require.Nil(t, err) return db, nil } // normal db db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) - c.Assert(err, check.IsNil) + require.Nil(t, err) for i := 0; i < defaultDMLMaxRetryTime; i++ { mock.ExpectBegin() mock.ExpectExec("REPLACE INTO `s1`.`t1`(`a`) VALUES (?),(?)"). @@ -897,23 +875,21 @@ func (s MySQLSinkSuite) TestExecDMLRollbackErrRetryable(c *check.C) { defer cancel() changefeed := "test-changefeed" sinkURI, err := url.Parse("mysql://127.0.0.1:4000/?time-zone=UTC&worker-count=1") - c.Assert(err, check.IsNil) + require.Nil(t, err) rc := config.GetDefaultReplicaConfig() f, err := filter.NewFilter(rc) - c.Assert(err, check.IsNil) + require.Nil(t, err) sink, err := newMySQLSink(ctx, changefeed, sinkURI, f, rc, map[string]string{}) - c.Assert(err, check.IsNil) + require.Nil(t, err) err = sink.(*mysqlSink).execDMLs(ctx, rows, 1 /* replicaID */, 1 /* bucket */) - c.Assert(errors.Cause(err), check.Equals, errLockDeadlock) + require.Equal(t, errLockDeadlock, errors.Cause(err)) err = sink.Close(ctx) - c.Assert(err, check.IsNil) + require.Nil(t, err) } -func (s MySQLSinkSuite) TestNewMySQLSinkExecDDL(c *check.C) { - defer testleak.AfterTest(c)() - +func TestNewMySQLSinkExecDDL(t *testing.T) { dbIndex := 0 mockGetDBConn := func(ctx context.Context, dsnStr string) (*sql.DB, error) { defer func() { @@ -922,12 +898,12 @@ func (s MySQLSinkSuite) TestNewMySQLSinkExecDDL(c *check.C) { if dbIndex == 0 { // test db db, err := mockTestDB() - c.Assert(err, check.IsNil) + require.Nil(t, err) return db, nil } // normal db db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) - c.Assert(err, check.IsNil) + require.Nil(t, err) mock.ExpectBegin() mock.ExpectExec("USE `test`;").WillReturnResult(sqlmock.NewResult(1, 1)) mock.ExpectExec("ALTER TABLE test.t1 ADD COLUMN a int").WillReturnResult(sqlmock.NewResult(1, 1)) @@ -952,15 +928,15 @@ func (s MySQLSinkSuite) TestNewMySQLSinkExecDDL(c *check.C) { defer cancel() changefeed := "test-changefeed" sinkURI, err := url.Parse("mysql://127.0.0.1:4000/?time-zone=UTC&worker-count=4") - c.Assert(err, check.IsNil) + require.Nil(t, err) rc := config.GetDefaultReplicaConfig() rc.Filter = &config.FilterConfig{ Rules: []string{"test.t1"}, } f, err := filter.NewFilter(rc) - c.Assert(err, check.IsNil) + require.Nil(t, err) sink, err := newMySQLSink(ctx, changefeed, sinkURI, f, rc, map[string]string{}) - c.Assert(err, check.IsNil) + require.Nil(t, err) ddl1 := &model.DDLEvent{ StartTs: 1000, @@ -984,19 +960,18 @@ func (s MySQLSinkSuite) TestNewMySQLSinkExecDDL(c *check.C) { } err = sink.EmitDDLEvent(ctx, ddl1) - c.Assert(err, check.IsNil) + require.Nil(t, err) err = sink.EmitDDLEvent(ctx, ddl2) - c.Assert(cerror.ErrDDLEventIgnored.Equal(err), check.IsTrue) + require.True(t, cerror.ErrDDLEventIgnored.Equal(err)) // DDL execute failed, but error can be ignored err = sink.EmitDDLEvent(ctx, ddl1) - c.Assert(err, check.IsNil) + require.Nil(t, err) err = sink.Close(ctx) - c.Assert(err, check.IsNil) + require.Nil(t, err) } -func (s MySQLSinkSuite) TestNeedSwitchDB(c *check.C) { - defer testleak.AfterTest(c)() +func TestNeedSwitchDB(t *testing.T) { testCases := []struct { ddl *model.DDLEvent needSwitch bool @@ -1040,13 +1015,11 @@ func (s MySQLSinkSuite) TestNeedSwitchDB(c *check.C) { } for _, tc := range testCases { - c.Assert(needSwitchDB(tc.ddl), check.Equals, tc.needSwitch) + require.Equal(t, tc.needSwitch, needSwitchDB(tc.ddl)) } } -func (s MySQLSinkSuite) TestNewMySQLSink(c *check.C) { - defer testleak.AfterTest(c)() - +func TestNewMySQLSink(t *testing.T) { dbIndex := 0 mockGetDBConn := func(ctx context.Context, dsnStr string) (*sql.DB, error) { defer func() { @@ -1055,13 +1028,13 @@ func (s MySQLSinkSuite) TestNewMySQLSink(c *check.C) { if dbIndex == 0 { // test db db, err := mockTestDB() - c.Assert(err, check.IsNil) + require.Nil(t, err) return db, nil } // normal db db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) mock.ExpectClose() - c.Assert(err, check.IsNil) + require.Nil(t, err) return db, nil } backupGetDBConn := GetDBConnImpl @@ -1075,19 +1048,17 @@ func (s MySQLSinkSuite) TestNewMySQLSink(c *check.C) { changefeed := "test-changefeed" sinkURI, err := url.Parse("mysql://127.0.0.1:4000/?time-zone=UTC&worker-count=4") - c.Assert(err, check.IsNil) + require.Nil(t, err) rc := config.GetDefaultReplicaConfig() f, err := filter.NewFilter(rc) - c.Assert(err, check.IsNil) + require.Nil(t, err) sink, err := newMySQLSink(ctx, changefeed, sinkURI, f, rc, map[string]string{}) - c.Assert(err, check.IsNil) + require.Nil(t, err) err = sink.Close(ctx) - c.Assert(err, check.IsNil) + require.Nil(t, err) } -func (s MySQLSinkSuite) TestMySQLSinkClose(c *check.C) { - defer testleak.AfterTest(c)() - +func TestMySQLSinkClose(t *testing.T) { dbIndex := 0 mockGetDBConn := func(ctx context.Context, dsnStr string) (*sql.DB, error) { defer func() { @@ -1096,13 +1067,13 @@ func (s MySQLSinkSuite) TestMySQLSinkClose(c *check.C) { if dbIndex == 0 { // test db db, err := mockTestDB() - c.Assert(err, check.IsNil) + require.Nil(t, err) return db, nil } // normal db db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) mock.ExpectClose() - c.Assert(err, check.IsNil) + require.Nil(t, err) return db, nil } backupGetDBConn := GetDBConnImpl @@ -1115,14 +1086,100 @@ func (s MySQLSinkSuite) TestMySQLSinkClose(c *check.C) { changefeed := "test-changefeed" sinkURI, err := url.Parse("mysql://127.0.0.1:4000/?time-zone=UTC&worker-count=4") - c.Assert(err, check.IsNil) + require.Nil(t, err) rc := config.GetDefaultReplicaConfig() f, err := filter.NewFilter(rc) - c.Assert(err, check.IsNil) + require.Nil(t, err) // test sink.Close will work correctly even if the ctx pass in has not been cancel sink, err := newMySQLSink(ctx, changefeed, sinkURI, f, rc, map[string]string{}) - c.Assert(err, check.IsNil) + require.Nil(t, err) + err = sink.Close(ctx) + require.Nil(t, err) +} + +func TestMySQLSinkFlushResovledTs(t *testing.T) { + dbIndex := 0 + mockGetDBConn := func(ctx context.Context, dsnStr string) (*sql.DB, error) { + defer func() { + dbIndex++ + }() + if dbIndex == 0 { + // test db + db, err := mockTestDB() + require.Nil(t, err) + return db, nil + } + // normal db + db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) + mock.ExpectBegin() + mock.ExpectExec("REPLACE INTO `s1`.`t1`(`a`) VALUES (?)"). + WithArgs(1). + WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + mock.ExpectBegin() + mock.ExpectExec("REPLACE INTO `s1`.`t2`(`a`) VALUES (?)"). + WithArgs(1). + WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + mock.ExpectClose() + require.Nil(t, err) + return db, nil + } + backupGetDBConn := GetDBConnImpl + GetDBConnImpl = mockGetDBConn + defer func() { + GetDBConnImpl = backupGetDBConn + }() + + ctx := context.Background() + + changefeed := "test-changefeed" + sinkURI, err := url.Parse("mysql://127.0.0.1:4000/?time-zone=UTC&worker-count=4") + require.Nil(t, err) + rc := config.GetDefaultReplicaConfig() + f, err := filter.NewFilter(rc) + require.Nil(t, err) + + // test sink.Close will work correctly even if the ctx pass in has not been cancel + si, err := newMySQLSink(ctx, changefeed, sinkURI, f, rc, map[string]string{}) + sink := si.(*mysqlSink) + require.Nil(t, err) + checkpoint, err := sink.FlushRowChangedEvents(ctx, model.TableID(1), 1) + require.Nil(t, err) + require.Equal(t, uint64(0), checkpoint) + rows := []*model.RowChangedEvent{ + { + Table: &model.TableName{Schema: "s1", Table: "t1", TableID: 1}, + CommitTs: 5, + Columns: []*model.Column{ + {Name: "a", Type: mysql.TypeLong, Flag: model.HandleKeyFlag | model.PrimaryKeyFlag, Value: 1}, + }, + }, + } + err = sink.EmitRowChangedEvents(ctx, rows...) + require.Nil(t, err) + checkpoint, err = sink.FlushRowChangedEvents(ctx, model.TableID(1), 6) + require.True(t, checkpoint <= 5) + time.Sleep(500 * time.Millisecond) + require.Nil(t, err) + require.Equal(t, uint64(6), sink.getTableCheckpointTs(model.TableID(1))) + rows = []*model.RowChangedEvent{ + { + Table: &model.TableName{Schema: "s1", Table: "t2", TableID: 2}, + CommitTs: 4, + Columns: []*model.Column{ + {Name: "a", Type: mysql.TypeLong, Flag: model.HandleKeyFlag | model.PrimaryKeyFlag, Value: 1}, + }, + }, + } + err = sink.EmitRowChangedEvents(ctx, rows...) + require.Nil(t, err) + checkpoint, err = sink.FlushRowChangedEvents(ctx, model.TableID(2), 5) + require.True(t, checkpoint <= 5) + time.Sleep(500 * time.Millisecond) + require.Nil(t, err) + require.Equal(t, uint64(5), sink.getTableCheckpointTs(model.TableID(2))) err = sink.Close(ctx) - c.Assert(err, check.IsNil) + require.Nil(t, err) } diff --git a/cdc/sink/mysql_worker.go b/cdc/sink/mysql_worker.go index 03630c51353..17e25adedfe 100644 --- a/cdc/sink/mysql_worker.go +++ b/cdc/sink/mysql_worker.go @@ -17,7 +17,6 @@ import ( "context" "runtime" "sync" - "sync/atomic" "github.com/pingcap/errors" "github.com/pingcap/log" @@ -35,7 +34,6 @@ type mysqlSinkWorker struct { execDMLs func(context.Context, []*model.RowChangedEvent, uint64, int) error metricBucketSize prometheus.Counter receiver *notify.Receiver - checkpointTs uint64 closedCh chan struct{} } @@ -78,10 +76,9 @@ func (w *mysqlSinkWorker) appendFinishTxn(wg *sync.WaitGroup) { func (w *mysqlSinkWorker) run(ctx context.Context) (err error) { var ( - toExecRows []*model.RowChangedEvent - replicaID uint64 - txnNum int - lastCommitTs uint64 + toExecRows []*model.RowChangedEvent + replicaID uint64 + txnNum int ) // mark FinishWg before worker exits, all data txns can be omitted. @@ -119,7 +116,6 @@ func (w *mysqlSinkWorker) run(ctx context.Context) (err error) { txnNum = 0 return err } - atomic.StoreUint64(&w.checkpointTs, lastCommitTs) toExecRows = toExecRows[:0] w.metricBucketSize.Add(float64(txnNum)) txnNum = 0 @@ -149,7 +145,6 @@ func (w *mysqlSinkWorker) run(ctx context.Context) (err error) { } replicaID = txn.ReplicaID toExecRows = append(toExecRows, txn.Rows...) - lastCommitTs = txn.CommitTs txnNum++ case <-w.receiver.C: if err := flushRows(); err != nil { diff --git a/cdc/sink/mysql_worker_test.go b/cdc/sink/mysql_worker_test.go index 17a3a766c00..509f3450297 100644 --- a/cdc/sink/mysql_worker_test.go +++ b/cdc/sink/mysql_worker_test.go @@ -15,20 +15,28 @@ package sink import ( "context" + "fmt" "sync" + "testing" "time" "github.com/davecgh/go-spew/spew" - "github.com/pingcap/check" "github.com/pingcap/errors" "github.com/pingcap/ticdc/cdc/model" "github.com/pingcap/ticdc/pkg/notify" "github.com/pingcap/ticdc/pkg/util/testleak" + "github.com/stretchr/testify/require" "golang.org/x/sync/errgroup" ) -func (s MySQLSinkSuite) TestMysqlSinkWorker(c *check.C) { - defer testleak.AfterTest(c)() +func TestMysqlSinkWorker(t *testing.T) { + defer testleak.AfterTestT(t)() + tbl := &model.TableName{ + Schema: "test", + Table: "user", + TableID: 1, + IsPartition: false, + } testCases := []struct { txns []*model.SingleTableTxn expectedOutputRows [][]*model.RowChangedEvent @@ -41,6 +49,7 @@ func (s MySQLSinkSuite) TestMysqlSinkWorker(c *check.C) { }, { txns: []*model.SingleTableTxn{ { + Table: tbl, CommitTs: 1, Rows: []*model.RowChangedEvent{{CommitTs: 1}}, ReplicaID: 1, @@ -52,6 +61,7 @@ func (s MySQLSinkSuite) TestMysqlSinkWorker(c *check.C) { }, { txns: []*model.SingleTableTxn{ { + Table: tbl, CommitTs: 1, Rows: []*model.RowChangedEvent{{CommitTs: 1}, {CommitTs: 1}, {CommitTs: 1}}, ReplicaID: 1, @@ -65,16 +75,19 @@ func (s MySQLSinkSuite) TestMysqlSinkWorker(c *check.C) { }, { txns: []*model.SingleTableTxn{ { + Table: tbl, CommitTs: 1, Rows: []*model.RowChangedEvent{{CommitTs: 1}, {CommitTs: 1}}, ReplicaID: 1, }, { + Table: tbl, CommitTs: 2, Rows: []*model.RowChangedEvent{{CommitTs: 2}}, ReplicaID: 1, }, { + Table: tbl, CommitTs: 3, Rows: []*model.RowChangedEvent{{CommitTs: 3}, {CommitTs: 3}}, ReplicaID: 1, @@ -89,16 +102,19 @@ func (s MySQLSinkSuite) TestMysqlSinkWorker(c *check.C) { }, { txns: []*model.SingleTableTxn{ { + Table: tbl, CommitTs: 1, Rows: []*model.RowChangedEvent{{CommitTs: 1}}, ReplicaID: 1, }, { + Table: tbl, CommitTs: 2, Rows: []*model.RowChangedEvent{{CommitTs: 2}}, ReplicaID: 2, }, { + Table: tbl, CommitTs: 3, Rows: []*model.RowChangedEvent{{CommitTs: 3}}, ReplicaID: 3, @@ -114,21 +130,25 @@ func (s MySQLSinkSuite) TestMysqlSinkWorker(c *check.C) { }, { txns: []*model.SingleTableTxn{ { + Table: tbl, CommitTs: 1, Rows: []*model.RowChangedEvent{{CommitTs: 1}}, ReplicaID: 1, }, { + Table: tbl, CommitTs: 2, Rows: []*model.RowChangedEvent{{CommitTs: 2}, {CommitTs: 2}, {CommitTs: 2}}, ReplicaID: 1, }, { + Table: tbl, CommitTs: 3, Rows: []*model.RowChangedEvent{{CommitTs: 3}}, ReplicaID: 1, }, { + Table: tbl, CommitTs: 4, Rows: []*model.RowChangedEvent{{CommitTs: 4}}, ReplicaID: 1, @@ -151,7 +171,7 @@ func (s MySQLSinkSuite) TestMysqlSinkWorker(c *check.C) { var outputRows [][]*model.RowChangedEvent var outputReplicaIDs []uint64 receiver, err := notifier.NewReceiver(-1) - c.Assert(err, check.IsNil) + require.Nil(t, err) w := newMySQLSinkWorker(tc.maxTxnRow, 1, bucketSizeCounter.WithLabelValues("capture", "changefeed", "1"), receiver, @@ -174,40 +194,52 @@ func (s MySQLSinkSuite) TestMysqlSinkWorker(c *check.C) { notifier.Notify() wg.Wait() cancel() - c.Assert(errors.Cause(errg.Wait()), check.Equals, context.Canceled) - c.Assert(outputRows, check.DeepEquals, tc.expectedOutputRows, - check.Commentf("case %v, %s, %s", i, spew.Sdump(outputRows), spew.Sdump(tc.expectedOutputRows))) - c.Assert(outputReplicaIDs, check.DeepEquals, tc.exportedOutputReplicaIDs, - check.Commentf("case %v, %s, %s", i, spew.Sdump(outputReplicaIDs), spew.Sdump(tc.exportedOutputReplicaIDs))) + require.Equal(t, context.Canceled, errors.Cause(errg.Wait())) + require.Equal(t, tc.expectedOutputRows, outputRows, + fmt.Sprintf("case %v, %s, %s", i, spew.Sdump(outputRows), spew.Sdump(tc.expectedOutputRows))) + require.Equal(t, tc.exportedOutputReplicaIDs, outputReplicaIDs, tc.exportedOutputReplicaIDs, + fmt.Sprintf("case %v, %s, %s", i, spew.Sdump(outputReplicaIDs), spew.Sdump(tc.exportedOutputReplicaIDs))) } } -func (s MySQLSinkSuite) TestMySQLSinkWorkerExitWithError(c *check.C) { - defer testleak.AfterTest(c)() +func TestMySQLSinkWorkerExitWithError(t *testing.T) { + defer testleak.AfterTestT(t)() + tbl := &model.TableName{ + Schema: "test", + Table: "user", + TableID: 1, + IsPartition: false, + } txns1 := []*model.SingleTableTxn{ { + Table: tbl, CommitTs: 1, Rows: []*model.RowChangedEvent{{CommitTs: 1}}, }, { + Table: tbl, CommitTs: 2, Rows: []*model.RowChangedEvent{{CommitTs: 2}}, }, { + Table: tbl, CommitTs: 3, Rows: []*model.RowChangedEvent{{CommitTs: 3}}, }, { + Table: tbl, CommitTs: 4, Rows: []*model.RowChangedEvent{{CommitTs: 4}}, }, } txns2 := []*model.SingleTableTxn{ { + Table: tbl, CommitTs: 5, Rows: []*model.RowChangedEvent{{CommitTs: 5}}, }, { + Table: tbl, CommitTs: 6, Rows: []*model.RowChangedEvent{{CommitTs: 6}}, }, @@ -219,7 +251,7 @@ func (s MySQLSinkSuite) TestMySQLSinkWorkerExitWithError(c *check.C) { notifier := new(notify.Notifier) cctx, cancel := context.WithCancel(ctx) receiver, err := notifier.NewReceiver(-1) - c.Assert(err, check.IsNil) + require.Nil(t, err) w := newMySQLSinkWorker(maxTxnRow, 1, /*bucket*/ bucketSizeCounter.WithLabelValues("capture", "changefeed", "1"), receiver, @@ -253,23 +285,32 @@ func (s MySQLSinkSuite) TestMySQLSinkWorkerExitWithError(c *check.C) { wg.Wait() cancel() - c.Assert(errg.Wait(), check.Equals, errExecFailed) + require.Equal(t, errExecFailed, errg.Wait()) } -func (s MySQLSinkSuite) TestMySQLSinkWorkerExitCleanup(c *check.C) { - defer testleak.AfterTest(c)() +func TestMySQLSinkWorkerExitCleanup(t *testing.T) { + defer testleak.AfterTestT(t)() + tbl := &model.TableName{ + Schema: "test", + Table: "user", + TableID: 1, + IsPartition: false, + } txns1 := []*model.SingleTableTxn{ { + Table: tbl, CommitTs: 1, Rows: []*model.RowChangedEvent{{CommitTs: 1}}, }, { + Table: tbl, CommitTs: 2, Rows: []*model.RowChangedEvent{{CommitTs: 2}}, }, } txns2 := []*model.SingleTableTxn{ { + Table: tbl, CommitTs: 5, Rows: []*model.RowChangedEvent{{CommitTs: 5}}, }, @@ -282,7 +323,7 @@ func (s MySQLSinkSuite) TestMySQLSinkWorkerExitCleanup(c *check.C) { notifier := new(notify.Notifier) cctx, cancel := context.WithCancel(ctx) receiver, err := notifier.NewReceiver(-1) - c.Assert(err, check.IsNil) + require.Nil(t, err) w := newMySQLSinkWorker(maxTxnRow, 1, /*bucket*/ bucketSizeCounter.WithLabelValues("capture", "changefeed", "1"), receiver, @@ -317,5 +358,5 @@ func (s MySQLSinkSuite) TestMySQLSinkWorkerExitCleanup(c *check.C) { wg.Wait() cancel() - c.Assert(errg.Wait(), check.Equals, errExecFailed) + require.Equal(t, errExecFailed, errg.Wait()) } diff --git a/cmd/kafka-consumer/main.go b/cmd/kafka-consumer/main.go index a027406589e..a7b4caa0a71 100644 --- a/cmd/kafka-consumer/main.go +++ b/cmd/kafka-consumer/main.go @@ -290,6 +290,13 @@ func main() { } } +type partitionSink struct { + sink.Sink + resolvedTs uint64 + partitionNo int + tablesMap sync.Map +} + // Consumer represents a Sarama consumer group consumer type Consumer struct { ready chan bool @@ -298,10 +305,7 @@ type Consumer struct { maxDDLReceivedTs uint64 ddlListMu sync.Mutex - sinks []*struct { - sink.Sink - resolvedTs uint64 - } + sinks []*partitionSink sinksMu sync.Mutex ddlSink sink.Sink @@ -326,10 +330,7 @@ func NewConsumer(ctx context.Context) (*Consumer, error) { c.fakeTableIDGenerator = &fakeTableIDGenerator{ tableIDs: make(map[string]int64), } - c.sinks = make([]*struct { - sink.Sink - resolvedTs uint64 - }, kafkaPartitionNum) + c.sinks = make([]*partitionSink, kafkaPartitionNum) ctx, cancel := context.WithCancel(ctx) errCh := make(chan error, 1) opts := map[string]string{} @@ -339,10 +340,7 @@ func NewConsumer(ctx context.Context) (*Consumer, error) { cancel() return nil, errors.Trace(err) } - c.sinks[i] = &struct { - sink.Sink - resolvedTs uint64 - }{Sink: s} + c.sinks[i] = &partitionSink{Sink: s, partitionNo: i} } sink, err := sink.New(ctx, "kafka-consumer", downstreamURIStr, filter, config.GetDefaultReplicaConfig(), opts, errCh) if err != nil { @@ -443,6 +441,10 @@ ClaimMessages: if err != nil { log.Fatal("emit row changed event failed", zap.Error(err)) } + lastCommitTs, ok := sink.tablesMap.Load(row.Table.TableID) + if !ok || lastCommitTs.(uint64) < row.CommitTs { + sink.tablesMap.Store(row.Table.TableID, row.CommitTs) + } case model.MqMessageTypeResolved: ts, err := batchDecoder.NextResolvedEvent() if err != nil { @@ -503,10 +505,7 @@ func (c *Consumer) popDDL() *model.DDLEvent { return nil } -func (c *Consumer) forEachSink(fn func(sink *struct { - sink.Sink - resolvedTs uint64 -}) error) error { +func (c *Consumer) forEachSink(fn func(sink *partitionSink) error) error { c.sinksMu.Lock() defer c.sinksMu.Unlock() for _, sink := range c.sinks { @@ -529,10 +528,7 @@ func (c *Consumer) Run(ctx context.Context) error { time.Sleep(100 * time.Millisecond) // handle ddl globalResolvedTs := uint64(math.MaxUint64) - err := c.forEachSink(func(sink *struct { - sink.Sink - resolvedTs uint64 - }) error { + err := c.forEachSink(func(sink *partitionSink) error { resolvedTs := atomic.LoadUint64(&sink.resolvedTs) if resolvedTs < globalResolvedTs { globalResolvedTs = resolvedTs @@ -545,10 +541,7 @@ func (c *Consumer) Run(ctx context.Context) error { todoDDL := c.getFrontDDL() if todoDDL != nil && globalResolvedTs >= todoDDL.CommitTs { // flush DMLs - err := c.forEachSink(func(sink *struct { - sink.Sink - resolvedTs uint64 - }) error { + err := c.forEachSink(func(sink *partitionSink) error { return syncFlushRowChangedEvents(ctx, sink, todoDDL.CommitTs) }) if err != nil { @@ -574,10 +567,7 @@ func (c *Consumer) Run(ctx context.Context) error { atomic.StoreUint64(&c.globalResolvedTs, globalResolvedTs) log.Info("update globalResolvedTs", zap.Uint64("ts", globalResolvedTs)) - err = c.forEachSink(func(sink *struct { - sink.Sink - resolvedTs uint64 - }) error { + err = c.forEachSink(func(sink *partitionSink) error { return syncFlushRowChangedEvents(ctx, sink, globalResolvedTs) }) if err != nil { @@ -586,19 +576,34 @@ func (c *Consumer) Run(ctx context.Context) error { } } -func syncFlushRowChangedEvents(ctx context.Context, sink sink.Sink, resolvedTs uint64) error { +func syncFlushRowChangedEvents(ctx context.Context, sink *partitionSink, resolvedTs uint64) error { for { select { case <-ctx.Done(): return ctx.Err() default: } - // todo: use real table id - checkpointTs, err := sink.FlushRowChangedEvents(ctx, 0, resolvedTs) + // tables are flushed + var ( + err error + checkpointTs uint64 + ) + flushedResolvedTs := true + sink.tablesMap.Range(func(key, value interface{}) bool { + tableID := key.(int64) + checkpointTs, err = sink.FlushRowChangedEvents(ctx, tableID, resolvedTs) + if err != nil { + return false + } + if checkpointTs < resolvedTs { + flushedResolvedTs = false + } + return true + }) if err != nil { return err } - if checkpointTs >= resolvedTs { + if flushedResolvedTs { return nil } } diff --git a/tests/integration_tests/sink_hang/run.sh b/tests/integration_tests/sink_hang/run.sh index ac386512890..8a3303956e3 100644 --- a/tests/integration_tests/sink_hang/run.sh +++ b/tests/integration_tests/sink_hang/run.sh @@ -42,7 +42,7 @@ function run() { *) SINK_URI="mysql://normal:123456@127.0.0.1:3306/?max-txn-row=1" ;; esac - export GO_FAILPOINTS='github.com/pingcap/ticdc/cdc/sink/MySQLSinkHangLongTime=1*return(true);github.com/pingcap/ticdc/cdc/sink/MySQLSinkExecDMLError=9*return(true)' + export GO_FAILPOINTS='github.com/pingcap/ticdc/cdc/sink/MySQLSinkExecDMLError=2*return(true)' run_cdc_server --workdir $WORK_DIR --binary $CDC_BINARY --addr "127.0.0.1:8300" --pd $pd_addr changefeed_id=$(cdc cli changefeed create --pd=$pd_addr --sink-uri="$SINK_URI" 2>&1 | tail -n2 | head -n1 | awk '{print $2}') if [ "$SINK_TYPE" == "kafka" ]; then @@ -54,8 +54,6 @@ function run() { run_sql "CREATE table sink_hang.t2(id int primary key auto_increment, val int);" ${UP_TIDB_HOST} ${UP_TIDB_PORT} run_sql "BEGIN; INSERT INTO sink_hang.t1 VALUES (),(),(); INSERT INTO sink_hang.t2 VALUES (),(),(); COMMIT" ${UP_TIDB_HOST} ${UP_TIDB_PORT} - ensure $MAX_RETRIES check_changefeed_state $pd_addr $changefeed_id "error" - cdc cli changefeed resume --changefeed-id=$changefeed_id --pd=$pd_addr ensure $MAX_RETRIES check_changefeed_state $pd_addr $changefeed_id "normal" check_table_exists "sink_hang.t1" ${DOWN_TIDB_HOST} ${DOWN_TIDB_PORT}