diff --git a/cdc/sink/mysql.go b/cdc/sink/mysql.go index f6551edaec0..427c65caf7b 100644 --- a/cdc/sink/mysql.go +++ b/cdc/sink/mysql.go @@ -34,6 +34,7 @@ import ( "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tiflow/cdc/model" "github.com/pingcap/tiflow/cdc/sink/common" + dmutils "github.com/pingcap/tiflow/dm/pkg/utils" "github.com/pingcap/tiflow/pkg/config" "github.com/pingcap/tiflow/pkg/cyclic" "github.com/pingcap/tiflow/pkg/cyclic/mark" @@ -108,6 +109,141 @@ type mysqlSink struct { metricBucketSizeCounters []prometheus.Counter forceReplicate bool +<<<<<<< HEAD +======= + cancel func() +} + +var _ Sink = &mysqlSink{} + +// newMySQLSink creates a new MySQL sink using schema storage +func newMySQLSink( + ctx context.Context, + changefeedID model.ChangeFeedID, + sinkURI *url.URL, + filter *tifilter.Filter, + replicaConfig *config.ReplicaConfig, + opts map[string]string, +) (Sink, error) { + opts[OptChangefeedID] = changefeedID + params, err := parseSinkURIToParams(ctx, sinkURI, opts) + if err != nil { + return nil, err + } + + params.enableOldValue = replicaConfig.EnableOldValue + + // dsn format of the driver: + // [username[:password]@][protocol[(address)]]/dbname[?param1=value1&...¶mN=valueN] + username := sinkURI.User.Username() + password, _ := sinkURI.User.Password() + port := sinkURI.Port() + if username == "" { + username = "root" + } + if port == "" { + port = "4000" + } + + dsnStr := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s", username, password, sinkURI.Hostname(), port, params.tls) + dsn, err := dmysql.ParseDSN(dsnStr) + if err != nil { + return nil, cerror.WrapError(cerror.ErrMySQLInvalidConfig, err) + } + + // create test db used for parameter detection + if dsn.Params == nil { + dsn.Params = make(map[string]string, 1) + } + if params.timezone != "" { + dsn.Params["time_zone"] = params.timezone + } + dsn.Params["readTimeout"] = params.readTimeout + dsn.Params["writeTimeout"] = params.writeTimeout + dsn.Params["timeout"] = params.dialTimeout + testDB, err := GetDBConnImpl(ctx, dsn.FormatDSN()) + if err != nil { + return nil, err + } + defer testDB.Close() + + // Adjust sql_mode for compatibility. + dsn.Params["sql_mode"], err = querySQLMode(ctx, testDB) + if err != nil { + return nil, errors.Trace(err) + } + dsn.Params["sql_mode"], err = dmutils.AdjustSQLModeCompatible(dsn.Params["sql_mode"]) + if err != nil { + return nil, errors.Trace(err) + } + + // Adjust sql_mode for cyclic replication. + var sinkCyclic *cyclic.Cyclic = nil + if val, ok := opts[mark.OptCyclicConfig]; ok { + cfg := new(config.CyclicConfig) + err := cfg.Unmarshal([]byte(val)) + if err != nil { + return nil, cerror.WrapError(cerror.ErrMySQLInvalidConfig, err) + } + sinkCyclic = cyclic.NewCyclic(cfg) + dsn.Params["sql_mode"] = cyclic.RelaxSQLMode(dsn.Params["sql_mode"]) + } + // NOTE: quote the string is necessary to avoid ambiguities. + dsn.Params["sql_mode"] = strconv.Quote(dsn.Params["sql_mode"]) + + dsnStr, err = generateDSNByParams(ctx, dsn, params, testDB) + if err != nil { + return nil, errors.Trace(err) + } + db, err := GetDBConnImpl(ctx, dsnStr) + if err != nil { + return nil, err + } + + log.Info("Start mysql sink") + + db.SetMaxIdleConns(params.workerCount) + db.SetMaxOpenConns(params.workerCount) + + metricConflictDetectDurationHis := conflictDetectDurationHis.WithLabelValues( + params.captureAddr, params.changefeedID) + metricBucketSizeCounters := make([]prometheus.Counter, params.workerCount) + for i := 0; i < params.workerCount; i++ { + metricBucketSizeCounters[i] = bucketSizeCounter.WithLabelValues( + params.captureAddr, params.changefeedID, strconv.Itoa(i)) + } + ctx, cancel := context.WithCancel(ctx) + + sink := &mysqlSink{ + db: db, + params: params, + filter: filter, + cyclic: sinkCyclic, + txnCache: common.NewUnresolvedTxnCache(), + statistics: NewStatistics(ctx, "mysql", opts), + metricConflictDetectDurationHis: metricConflictDetectDurationHis, + metricBucketSizeCounters: metricBucketSizeCounters, + errCh: make(chan error, 1), + forceReplicate: replicaConfig.ForceReplicate, + cancel: cancel, + } + + sink.execWaitNotifier = new(notify.Notifier) + sink.resolvedNotifier = new(notify.Notifier) + + err = sink.createSinkWorkers(ctx) + if err != nil { + return nil, err + } + + receiver, err := sink.resolvedNotifier.NewReceiver(50 * time.Millisecond) + if err != nil { + return nil, err + } + go sink.flushRowChangedEvents(ctx, receiver) + + return sink, nil +>>>>>>> 1df27c666 (sink(ticdc): adjust sql mode compatibility for mysql sink (#3938)) } func (s *mysqlSink) EmitRowChangedEvents(ctx context.Context, rows ...*model.RowChangedEvent) error { @@ -242,6 +378,7 @@ func (s *mysqlSink) execDDL(ctx context.Context, ddl *model.DDLEvent) error { return nil } +<<<<<<< HEAD // adjustSQLMode adjust sql mode according to sink config. func (s *mysqlSink) adjustSQLMode(ctx context.Context) error { // Must relax sql mode to support cyclic replication, as downstream may have @@ -258,10 +395,25 @@ func (s *mysqlSink) adjustSQLMode(ctx context.Context) error { newMode = cyclic.RelaxSQLMode(oldMode) _, err = s.db.ExecContext(ctx, fmt.Sprintf("SET sql_mode = '%s';", newMode)) +======= +func needSwitchDB(ddl *model.DDLEvent) bool { + if len(ddl.TableInfo.Schema) == 0 { + return false + } + if ddl.Type == timodel.ActionCreateSchema || ddl.Type == timodel.ActionDropSchema { + return false + } + return true +} + +func querySQLMode(ctx context.Context, db *sql.DB) (sqlMode string, err error) { + row := db.QueryRowContext(ctx, "SELECT @@SESSION.sql_mode;") + err = row.Scan(&sqlMode) +>>>>>>> 1df27c666 (sink(ticdc): adjust sql mode compatibility for mysql sink (#3938)) if err != nil { - return cerror.WrapError(cerror.ErrMySQLQueryError, err) + err = cerror.WrapError(cerror.ErrMySQLQueryError, err) } - return nil + return } var _ Sink = &mysqlSink{} diff --git a/cdc/sink/mysql_params_test.go b/cdc/sink/mysql_params_test.go new file mode 100644 index 00000000000..63f408f722a --- /dev/null +++ b/cdc/sink/mysql_params_test.go @@ -0,0 +1,228 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package sink + +import ( + "context" + "database/sql" + "net/url" + "strings" + "testing" + + "github.com/DATA-DOG/go-sqlmock" + dmysql "github.com/go-sql-driver/mysql" + "github.com/pingcap/tiflow/pkg/util/testleak" + "github.com/stretchr/testify/require" +) + +func TestSinkParamsClone(t *testing.T) { + defer testleak.AfterTestT(t)() + param1 := defaultParams.Clone() + param2 := param1.Clone() + param2.changefeedID = "123" + param2.batchReplaceEnabled = false + param2.maxTxnRow = 1 + require.Equal(t, &sinkParams{ + workerCount: DefaultWorkerCount, + maxTxnRow: DefaultMaxTxnRow, + tidbTxnMode: defaultTiDBTxnMode, + batchReplaceEnabled: defaultBatchReplaceEnabled, + batchReplaceSize: defaultBatchReplaceSize, + readTimeout: defaultReadTimeout, + writeTimeout: defaultWriteTimeout, + dialTimeout: defaultDialTimeout, + safeMode: defaultSafeMode, + }, param1) + require.Equal(t, &sinkParams{ + changefeedID: "123", + workerCount: DefaultWorkerCount, + maxTxnRow: 1, + tidbTxnMode: defaultTiDBTxnMode, + batchReplaceEnabled: false, + batchReplaceSize: defaultBatchReplaceSize, + readTimeout: defaultReadTimeout, + writeTimeout: defaultWriteTimeout, + dialTimeout: defaultDialTimeout, + safeMode: defaultSafeMode, + }, param2) +} + +func TestGenerateDSNByParams(t *testing.T) { + defer testleak.AfterTestT(t)() + + testDefaultParams := func() { + db, err := mockTestDB(false) + require.Nil(t, err) + defer db.Close() + + dsn, err := dmysql.ParseDSN("root:123456@tcp(127.0.0.1:4000)/") + require.Nil(t, err) + params := defaultParams.Clone() + dsnStr, err := generateDSNByParams(context.TODO(), dsn, params, db) + require.Nil(t, err) + expectedParams := []string{ + "tidb_txn_mode=optimistic", + "readTimeout=2m", + "writeTimeout=2m", + "allow_auto_random_explicit_insert=1", + } + for _, param := range expectedParams { + require.True(t, strings.Contains(dsnStr, param)) + } + require.False(t, strings.Contains(dsnStr, "time_zone")) + } + + testTimezoneParam := func() { + db, err := mockTestDB(false) + require.Nil(t, err) + defer db.Close() + + dsn, err := dmysql.ParseDSN("root:123456@tcp(127.0.0.1:4000)/") + require.Nil(t, err) + params := defaultParams.Clone() + params.timezone = `"UTC"` + dsnStr, err := generateDSNByParams(context.TODO(), dsn, params, db) + require.Nil(t, err) + require.True(t, strings.Contains(dsnStr, "time_zone=%22UTC%22")) + } + + testTimeoutParams := func() { + db, err := mockTestDB(false) + require.Nil(t, err) + defer db.Close() + + dsn, err := dmysql.ParseDSN("root:123456@tcp(127.0.0.1:4000)/") + require.Nil(t, err) + uri, err := url.Parse("mysql://127.0.0.1:3306/?read-timeout=4m&write-timeout=5m&timeout=3m") + require.Nil(t, err) + params, err := parseSinkURIToParams(context.TODO(), uri, map[string]string{}) + require.Nil(t, err) + dsnStr, err := generateDSNByParams(context.TODO(), dsn, params, db) + require.Nil(t, err) + expectedParams := []string{ + "readTimeout=4m", + "writeTimeout=5m", + "timeout=3m", + } + for _, param := range expectedParams { + require.True(t, strings.Contains(dsnStr, param)) + } + } + + testDefaultParams() + testTimezoneParam() + testTimeoutParams() +} + +func TestParseSinkURIToParams(t *testing.T) { + defer testleak.AfterTestT(t)() + expected := defaultParams.Clone() + expected.workerCount = 64 + expected.maxTxnRow = 20 + expected.batchReplaceEnabled = true + expected.batchReplaceSize = 50 + expected.safeMode = true + expected.timezone = `"UTC"` + expected.changefeedID = "cf-id" + expected.captureAddr = "127.0.0.1:8300" + expected.tidbTxnMode = "pessimistic" + uriStr := "mysql://127.0.0.1:3306/?worker-count=64&max-txn-row=20" + + "&batch-replace-enable=true&batch-replace-size=50&safe-mode=true" + + "&tidb-txn-mode=pessimistic" + opts := map[string]string{ + OptChangefeedID: expected.changefeedID, + OptCaptureAddr: expected.captureAddr, + } + uri, err := url.Parse(uriStr) + require.Nil(t, err) + params, err := parseSinkURIToParams(context.TODO(), uri, opts) + require.Nil(t, err) + require.Equal(t, expected, params) +} + +func TestParseSinkURITimezone(t *testing.T) { + defer testleak.AfterTestT(t)() + uris := []string{ + "mysql://127.0.0.1:3306/?time-zone=Asia/Shanghai&worker-count=32", + "mysql://127.0.0.1:3306/?time-zone=&worker-count=32", + "mysql://127.0.0.1:3306/?worker-count=32", + } + expected := []string{ + "\"Asia/Shanghai\"", + "", + "\"UTC\"", + } + ctx := context.TODO() + opts := map[string]string{} + for i, uriStr := range uris { + uri, err := url.Parse(uriStr) + require.Nil(t, err) + params, err := parseSinkURIToParams(ctx, uri, opts) + require.Nil(t, err) + require.Equal(t, expected[i], params.timezone) + } +} + +func TestParseSinkURIBadQueryString(t *testing.T) { + defer testleak.AfterTestT(t)() + uris := []string{ + "", + "postgre://127.0.0.1:3306", + "mysql://127.0.0.1:3306/?worker-count=not-number", + "mysql://127.0.0.1:3306/?max-txn-row=not-number", + "mysql://127.0.0.1:3306/?ssl-ca=only-ca-exists", + "mysql://127.0.0.1:3306/?batch-replace-enable=not-bool", + "mysql://127.0.0.1:3306/?batch-replace-enable=true&batch-replace-size=not-number", + "mysql://127.0.0.1:3306/?safe-mode=not-bool", + } + ctx := context.TODO() + opts := map[string]string{OptChangefeedID: "changefeed-01"} + var uri *url.URL + var err error + for _, uriStr := range uris { + if uriStr != "" { + uri, err = url.Parse(uriStr) + require.Nil(t, err) + } else { + uri = nil + } + _, err = parseSinkURIToParams(ctx, uri, opts) + require.NotNil(t, err) + } +} + +func TestCheckTiDBVariable(t *testing.T) { + defer testleak.AfterTestT(t)() + db, mock, err := sqlmock.New() + require.Nil(t, err) + defer db.Close() //nolint:errcheck + columns := []string{"Variable_name", "Value"} + + mock.ExpectQuery("show session variables like 'allow_auto_random_explicit_insert';").WillReturnRows( + sqlmock.NewRows(columns).AddRow("allow_auto_random_explicit_insert", "0"), + ) + val, err := checkTiDBVariable(context.TODO(), db, "allow_auto_random_explicit_insert", "1") + require.Nil(t, err) + require.Equal(t, "1", val) + + mock.ExpectQuery("show session variables like 'no_exist_variable';").WillReturnError(sql.ErrNoRows) + val, err = checkTiDBVariable(context.TODO(), db, "no_exist_variable", "0") + require.Nil(t, err) + require.Equal(t, "", val) + + mock.ExpectQuery("show session variables like 'version';").WillReturnError(sql.ErrConnDone) + _, err = checkTiDBVariable(context.TODO(), db, "version", "5.7.25-TiDB-v4.0.0") + require.NotNil(t, err) + require.Regexp(t, ".*"+sql.ErrConnDone.Error(), err.Error()) +} diff --git a/cdc/sink/mysql_test.go b/cdc/sink/mysql_test.go index 24463682b27..e7f725b748f 100644 --- a/cdc/sink/mysql_test.go +++ b/cdc/sink/mysql_test.go @@ -428,6 +428,7 @@ func TestReduceReplace(t *testing.T) { } } +<<<<<<< HEAD func TestSinkParamsClone(t *testing.T) { param1 := defaultParams.Clone() param2 := param1.Clone() @@ -599,11 +600,19 @@ func TestParseSinkURIBadQueryString(t *testing.T) { } func mockTestDB() (*sql.DB, error) { +======= +func mockTestDB(adjustSQLMode bool) (*sql.DB, error) { +>>>>>>> 1df27c666 (sink(ticdc): adjust sql mode compatibility for mysql sink (#3938)) // mock for test db, which is used querying TiDB session variable db, mock, err := sqlmock.New() if err != nil { return nil, err } + if adjustSQLMode { + mock.ExpectQuery("SELECT @@SESSION.sql_mode;"). + WillReturnRows(sqlmock.NewRows([]string{"@@SESSION.sql_mode"}). + AddRow("ONLY_FULL_GROUP_BY,STRICT_TRANS_TABLES,NO_ZERO_IN_DATE,NO_ZERO_DATE")) + } columns := []string{"Variable_name", "Value"} mock.ExpectQuery("show session variables like 'allow_auto_random_explicit_insert';").WillReturnRows( sqlmock.NewRows(columns).AddRow("allow_auto_random_explicit_insert", "0"), @@ -626,18 +635,13 @@ func TestAdjustSQLMode(t *testing.T) { }() if dbIndex == 0 { // test db - db, err := mockTestDB() + db, err := mockTestDB(true) require.Nil(t, err) return db, nil } // normal db db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) require.Nil(t, err) - mock.ExpectQuery("SELECT @@SESSION.sql_mode;"). - WillReturnRows(sqlmock.NewRows([]string{"@@SESSION.sql_mode"}). - AddRow("ONLY_FULL_GROUP_BY,STRICT_TRANS_TABLES,NO_ZERO_IN_DATE,NO_ZERO_DATE")) - mock.ExpectExec("SET sql_mode = 'ONLY_FULL_GROUP_BY,NO_ZERO_IN_DATE,NO_ZERO_DATE';"). - WillReturnResult(sqlmock.NewResult(0, 0)) mock.ExpectClose() return db, nil } @@ -743,7 +747,7 @@ func TestNewMySQLSinkExecDML(t *testing.T) { }() if dbIndex == 0 { // test db - db, err := mockTestDB() + db, err := mockTestDB(true) require.Nil(t, err) return db, nil } @@ -887,7 +891,7 @@ func TestExecDMLRollbackErrDatabaseNotExists(t *testing.T) { }() if dbIndex == 0 { // test db - db, err := mockTestDB() + db, err := mockTestDB(true) require.Nil(t, err) return db, nil } @@ -953,7 +957,7 @@ func TestExecDMLRollbackErrTableNotExists(t *testing.T) { }() if dbIndex == 0 { // test db - db, err := mockTestDB() + db, err := mockTestDB(true) require.Nil(t, err) return db, nil } @@ -1019,7 +1023,7 @@ func TestExecDMLRollbackErrRetryable(t *testing.T) { }() if dbIndex == 0 { // test db - db, err := mockTestDB() + db, err := mockTestDB(true) require.Nil(t, err) return db, nil } @@ -1068,7 +1072,7 @@ func TestNewMySQLSinkExecDDL(t *testing.T) { }() if dbIndex == 0 { // test db - db, err := mockTestDB() + db, err := mockTestDB(true) require.Nil(t, err) return db, nil } @@ -1150,7 +1154,7 @@ func TestNewMySQLSink(t *testing.T) { }() if dbIndex == 0 { // test db - db, err := mockTestDB() + db, err := mockTestDB(true) require.Nil(t, err) return db, nil } @@ -1188,7 +1192,7 @@ func TestMySQLSinkClose(t *testing.T) { }() if dbIndex == 0 { // test db - db, err := mockTestDB() + db, err := mockTestDB(true) require.Nil(t, err) return db, nil } @@ -1228,7 +1232,7 @@ func TestMySQLSinkFlushResovledTs(t *testing.T) { }() if dbIndex == 0 { // test db - db, err := mockTestDB() + db, err := mockTestDB(true) require.Nil(t, err) return db, nil } diff --git a/integration/framework/task.go b/integration/framework/task.go index c122af35e04..c6ece257fa1 100644 --- a/integration/framework/task.go +++ b/integration/framework/task.go @@ -16,6 +16,8 @@ package framework import ( "context" "database/sql" + "fmt" + "strconv" "strings" _ "github.com/go-sql-driver/mysql" // imported for side effects @@ -82,20 +84,19 @@ func (c *TaskContext) SQLHelper() *SQLHelper { func (p *CDCProfile) String() string { builder := strings.Builder{} builder.WriteString("cli changefeed create ") + if p.PDUri == "" { p.PDUri = "http://127.0.0.1:2379" } - - builder.WriteString("--pd=" + p.PDUri + " ") + builder.WriteString(fmt.Sprintf("--pd=%s ", strconv.Quote(p.PDUri))) if p.SinkURI == "" { log.Fatal("SinkURI cannot be empty!") } - - builder.WriteString("--sink-uri=\"" + p.SinkURI + "\" ") + builder.WriteString(fmt.Sprintf("--sink-uri=%s ", strconv.Quote(p.SinkURI))) if p.ConfigFile != "" { - builder.WriteString("--config=" + p.ConfigFile + " ") + builder.WriteString(fmt.Sprintf("--config=%s ", strconv.Quote(p.ConfigFile))) } if p.Opts == nil || len(p.Opts) == 0 { diff --git a/pkg/applier/redo_test.go b/pkg/applier/redo_test.go new file mode 100644 index 00000000000..5d2466bd27e --- /dev/null +++ b/pkg/applier/redo_test.go @@ -0,0 +1,243 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package applier + +import ( + "context" + "database/sql" + "fmt" + "testing" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/phayes/freeport" + "github.com/pingcap/tiflow/cdc/model" + "github.com/pingcap/tiflow/cdc/redo" + "github.com/pingcap/tiflow/cdc/redo/reader" + "github.com/pingcap/tiflow/cdc/sink" + "github.com/stretchr/testify/require" +) + +// MockReader is a mock redo log reader that implements LogReader interface +type MockReader struct { + checkpointTs uint64 + resolvedTs uint64 + redoLogCh chan *model.RedoRowChangedEvent + ddlEventCh chan *model.RedoDDLEvent +} + +// NewMockReader creates a new MockReader +func NewMockReader( + checkpointTs uint64, + resolvedTs uint64, + redoLogCh chan *model.RedoRowChangedEvent, + ddlEventCh chan *model.RedoDDLEvent, +) *MockReader { + return &MockReader{ + checkpointTs: checkpointTs, + resolvedTs: resolvedTs, + redoLogCh: redoLogCh, + ddlEventCh: ddlEventCh, + } +} + +// ResetReader implements LogReader.ReadLog +func (br *MockReader) ResetReader(ctx context.Context, startTs, endTs uint64) error { + return nil +} + +// ReadNextLog implements LogReader.ReadNextLog +func (br *MockReader) ReadNextLog(ctx context.Context, maxNumberOfMessages uint64) ([]*model.RedoRowChangedEvent, error) { + cached := make([]*model.RedoRowChangedEvent, 0) + for { + select { + case <-ctx.Done(): + return cached, nil + case redoLog, ok := <-br.redoLogCh: + if !ok { + return cached, nil + } + cached = append(cached, redoLog) + if len(cached) >= int(maxNumberOfMessages) { + return cached, nil + } + } + } +} + +// ReadNextDDL implements LogReader.ReadNextDDL +func (br *MockReader) ReadNextDDL(ctx context.Context, maxNumberOfDDLs uint64) ([]*model.RedoDDLEvent, error) { + cached := make([]*model.RedoDDLEvent, 0) + for { + select { + case <-ctx.Done(): + return cached, nil + case ddl, ok := <-br.ddlEventCh: + if !ok { + return cached, nil + } + cached = append(cached, ddl) + if len(cached) >= int(maxNumberOfDDLs) { + return cached, nil + } + } + } +} + +// ReadMeta implements LogReader.ReadMeta +func (br *MockReader) ReadMeta(ctx context.Context) (checkpointTs, resolvedTs uint64, err error) { + return br.checkpointTs, br.resolvedTs, nil +} + +// Close implements LogReader.Close. +func (br *MockReader) Close() error { + return nil +} + +func TestApplyDMLs(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + checkpointTs := uint64(1000) + resolvedTs := uint64(2000) + redoLogCh := make(chan *model.RedoRowChangedEvent, 1024) + ddlEventCh := make(chan *model.RedoDDLEvent, 1024) + createMockReader := func(ctx context.Context, cfg *RedoApplierConfig) (reader.RedoLogReader, error) { + return NewMockReader(checkpointTs, resolvedTs, redoLogCh, ddlEventCh), nil + } + + dbIndex := 0 + mockGetDBConn := func(ctx context.Context, dsnStr string) (*sql.DB, error) { + defer func() { + dbIndex++ + }() + if dbIndex == 0 { + // mock for test db, which is used querying TiDB session variable + db, mock, err := sqlmock.New() + if err != nil { + return nil, err + } + mock.ExpectQuery("SELECT @@SESSION.sql_mode;"). + WillReturnRows(sqlmock.NewRows([]string{"@@SESSION.sql_mode"}). + AddRow("ONLY_FULL_GROUP_BY,STRICT_TRANS_TABLES,NO_ZERO_IN_DATE,NO_ZERO_DATE")) + columns := []string{"Variable_name", "Value"} + mock.ExpectQuery("show session variables like 'allow_auto_random_explicit_insert';").WillReturnRows( + sqlmock.NewRows(columns).AddRow("allow_auto_random_explicit_insert", "0"), + ) + mock.ExpectQuery("show session variables like 'tidb_txn_mode';").WillReturnRows( + sqlmock.NewRows(columns).AddRow("tidb_txn_mode", "pessimistic"), + ) + mock.ExpectClose() + return db, nil + } + // normal db + db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) + require.Nil(t, err) + mock.ExpectBegin() + mock.ExpectExec("REPLACE INTO `test`.`t1`(`a`,`b`) VALUES (?,?)"). + WithArgs(1, "2"). + WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + + mock.ExpectBegin() + mock.ExpectExec("DELETE FROM `test`.`t1` WHERE `a` = ? LIMIT 1;"). + WithArgs(1). + WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectExec("REPLACE INTO `test`.`t1`(`a`,`b`) VALUES (?,?)"). + WithArgs(2, "3"). + WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + mock.ExpectClose() + return db, nil + } + + getDBConnBak := sink.GetDBConnImpl + sink.GetDBConnImpl = mockGetDBConn + createRedoReaderBak := createRedoReader + createRedoReader = createMockReader + defer func() { + createRedoReader = createRedoReaderBak + sink.GetDBConnImpl = getDBConnBak + }() + + dmls := []*model.RowChangedEvent{ + { + StartTs: 1100, + CommitTs: 1200, + Table: &model.TableName{Schema: "test", Table: "t1"}, + Columns: []*model.Column{ + { + Name: "a", + Value: 1, + Flag: model.HandleKeyFlag, + }, { + Name: "b", + Value: "2", + Flag: 0, + }, + }, + }, + { + StartTs: 1200, + CommitTs: 1300, + Table: &model.TableName{Schema: "test", Table: "t1"}, + PreColumns: []*model.Column{ + { + Name: "a", + Value: 1, + Flag: model.HandleKeyFlag, + }, { + Name: "b", + Value: "2", + Flag: 0, + }, + }, + Columns: []*model.Column{ + { + Name: "a", + Value: 2, + Flag: model.HandleKeyFlag, + }, { + Name: "b", + Value: "3", + Flag: 0, + }, + }, + }, + } + for _, dml := range dmls { + redoLogCh <- redo.RowToRedo(dml) + } + close(redoLogCh) + close(ddlEventCh) + + cfg := &RedoApplierConfig{SinkURI: "mysql://127.0.0.1:4000/?worker-count=1&max-txn-row=1"} + ap := NewRedoApplier(cfg) + err := ap.Apply(ctx) + require.Nil(t, err) +} + +func TestApplyMeetSinkError(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + port, err := freeport.GetFreePort() + require.Nil(t, err) + cfg := &RedoApplierConfig{ + Storage: "blackhole://", + SinkURI: fmt.Sprintf("mysql://127.0.0.1:%d/?read-timeout=1s&timeout=1s", port), + } + ap := NewRedoApplier(cfg) + err = ap.Apply(ctx) + require.Regexp(t, "CDC:ErrMySQLConnectionError", err) +}