diff --git a/cdc/sink/mysql_test.go b/cdc/sink/mysql_test.go index b7dbe1e49f0..a6669c5887a 100644 --- a/cdc/sink/mysql_test.go +++ b/cdc/sink/mysql_test.go @@ -444,13 +444,26 @@ func mockTestDB() (*sql.DB, error) { return db, nil } +func mockDBWithAdjustedSQLMode() (*sql.DB, sqlmock.Sqlmock, error) { + db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) + if err != nil { + return db, mock, err + } + // sql mode is adjust for compatibility. + mock.ExpectQuery("SELECT @@SESSION.sql_mode;"). + WillReturnRows(sqlmock.NewRows([]string{"@@SESSION.sql_mode"}). + AddRow("STRICT_TRANS_TABLES,NO_ZERO_IN_DATE,NO_ZERO_DATE")) + mock.ExpectExec("SET sql_mode = 'ALLOW_INVALID_DATES,IGNORE_SPACE,NO_AUTO_VALUE_ON_ZERO';"). + WillReturnResult(sqlmock.NewResult(0, 0)) + return db, mock, err +} + func TestAdjustSQLMode(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() dbIndex := 0 mockGetDBConn := func(ctx context.Context, dsnStr string) (*sql.DB, error) { - fmt.Printf("mockGetDBConn is called\n") defer func() { dbIndex++ }() @@ -461,16 +474,9 @@ func TestAdjustSQLMode(t *testing.T) { return db, nil } // normal db - db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) + db, mock, err := mockDBWithAdjustedSQLMode() require.Nil(t, err) - // sql mode is adjust for compatibility. - mock.ExpectQuery("SELECT @@SESSION.sql_mode;"). - WillReturnRows(sqlmock.NewRows([]string{"@@SESSION.sql_mode"}). - AddRow("STRICT_TRANS_TABLES,NO_ZERO_IN_DATE,NO_ZERO_DATE")) - mock.ExpectExec("SET sql_mode = 'ALLOW_INVALID_DATES,IGNORE_SPACE,NO_AUTO_VALUE_ON_ZERO';"). - WillReturnResult(sqlmock.NewResult(0, 0)) - // sql mode is adjust for compatibility. mock.ExpectQuery("SELECT @@SESSION.sql_mode;"). WillReturnRows(sqlmock.NewRows([]string{"@@SESSION.sql_mode"}). @@ -587,7 +593,7 @@ func TestNewMySQLSinkExecDML(t *testing.T) { return db, nil } // normal db - db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) + db, mock, err := mockDBWithAdjustedSQLMode() require.Nil(t, err) mock.ExpectBegin() mock.ExpectExec("REPLACE INTO `s1`.`t1`(`a`,`b`) VALUES (?,?),(?,?)"). @@ -731,7 +737,7 @@ func TestExecDMLRollbackErrDatabaseNotExists(t *testing.T) { return db, nil } // normal db - db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) + db, mock, err := mockDBWithAdjustedSQLMode() require.Nil(t, err) mock.ExpectBegin() mock.ExpectExec("REPLACE INTO `s1`.`t1`(`a`) VALUES (?),(?)"). @@ -797,7 +803,7 @@ func TestExecDMLRollbackErrTableNotExists(t *testing.T) { return db, nil } // normal db - db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) + db, mock, err := mockDBWithAdjustedSQLMode() require.Nil(t, err) mock.ExpectBegin() mock.ExpectExec("REPLACE INTO `s1`.`t1`(`a`) VALUES (?),(?)"). @@ -863,7 +869,7 @@ func TestExecDMLRollbackErrRetryable(t *testing.T) { return db, nil } // normal db - db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) + db, mock, err := mockDBWithAdjustedSQLMode() require.Nil(t, err) for i := 0; i < defaultDMLMaxRetryTime; i++ { mock.ExpectBegin() @@ -912,7 +918,7 @@ func TestNewMySQLSinkExecDDL(t *testing.T) { return db, nil } // normal db - db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) + db, mock, err := mockDBWithAdjustedSQLMode() require.Nil(t, err) mock.ExpectBegin() mock.ExpectExec("USE `test`;").WillReturnResult(sqlmock.NewResult(1, 1)) @@ -1042,7 +1048,8 @@ func TestNewMySQLSink(t *testing.T) { return db, nil } // normal db - db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) + db, mock, err := mockDBWithAdjustedSQLMode() + require.Nil(t, err) mock.ExpectClose() require.Nil(t, err) return db, nil @@ -1081,7 +1088,8 @@ func TestMySQLSinkClose(t *testing.T) { return db, nil } // normal db - db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) + db, mock, err := mockDBWithAdjustedSQLMode() + require.Nil(t, err) mock.ExpectClose() require.Nil(t, err) return db, nil @@ -1121,7 +1129,8 @@ func TestMySQLSinkFlushResovledTs(t *testing.T) { return db, nil } // normal db - db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) + db, mock, err := mockDBWithAdjustedSQLMode() + require.Nil(t, err) mock.ExpectBegin() mock.ExpectExec("REPLACE INTO `s1`.`t1`(`a`) VALUES (?)"). WithArgs(1).