Skip to content

Commit

Permalink
Merge pull request #136 from hanchuanchuan/patch-update-multi-tables
Browse files Browse the repository at this point in the history
feature: 实现update set多表时的回滚支持
  • Loading branch information
hanchuanchuan authored Dec 16, 2019
2 parents 01ba2a2 + b174832 commit e575b00
Show file tree
Hide file tree
Showing 8 changed files with 255 additions and 239 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ before_install:
script:
- rm -f go.sum
# - docker ps
- travis_wait make dev upload-coverage
- travis_wait 30 make dev upload-coverage

after_failure:
- netstat -nltp
Expand Down
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -159,15 +159,15 @@ ifeq ("$(TRAVIS_COVERAGE)", "1")
else
@echo "Running in native mode."
@export log_level=error; \
$(GOTEST) -timeout 20m -ldflags '$(TEST_LDFLAGS)' -cover $(PACKAGES) || { $(GOFAIL_DISABLE); exit 1; }
$(GOTEST) -timeout 30m -ldflags '$(TEST_LDFLAGS)' -cover $(PACKAGES) || { $(GOFAIL_DISABLE); exit 1; }
endif
@$(GOFAIL_DISABLE)

race: parserlib
$(GO) get github.com/etcd-io/gofail@v0.0.0-20180808172546-51ce9a71510a
@$(GOFAIL_ENABLE)
@export log_level=debug; \
$(GOTEST) -timeout 20m -race $(PACKAGES) || { $(GOFAIL_DISABLE); exit 1; }
$(GOTEST) -timeout 30m -race $(PACKAGES) || { $(GOFAIL_DISABLE); exit 1; }
@$(GOFAIL_DISABLE)

leak: parserlib
Expand Down
6 changes: 6 additions & 0 deletions session/inception_result.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ type Record struct {
DBName string
TableName string
TableInfo *TableInfo

// ddl回滚
DDLRollback string
OPID string
Expand All @@ -88,6 +89,11 @@ type Record struct {

// 是否开启OSC
useOsc bool

// update多表时,记录多余的表
// update多表时,默认set第一列的表为主表,其余表才会记录到该处
// 仅在发现多表操作时,初始化该参数
MultiTables map[string]*TableInfo
}

type recordSet struct {
Expand Down
80 changes: 72 additions & 8 deletions session/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,13 @@ func (s *session) GetNextBackupRecord() *Record {
if r.StageStatus != StatusExecFail {
r.StageStatus = StatusBackupFail
}
// 清理已删除的列
clearDeleteColumns(r.TableInfo)
if r.MultiTables != nil {
for _, t := range r.MultiTables {
clearDeleteColumns(t)
}
}

return r
}
Expand Down Expand Up @@ -146,6 +152,7 @@ func configPrimaryKey(t *TableInfo) {
}
}

// clearDeleteColumns 清理已删除的列,方便解析binlog
func clearDeleteColumns(t *TableInfo) {
if t == nil || t.IsClear {
return
Expand Down Expand Up @@ -291,7 +298,9 @@ func (s *session) Parser(ctx context.Context) {
if s.checkFilter(event, record, currentThreadID) {
changeRows += len(event.Rows)
_, err = s.generateDeleteSql(record.TableInfo, event, e)
s.checkError(err)
if err != nil {
log.Error(err)
}
} else {
goto ENDCHECK
}
Expand All @@ -303,18 +312,26 @@ func (s *session) Parser(ctx context.Context) {
if s.checkFilter(event, record, currentThreadID) {
changeRows += len(event.Rows)
_, err = s.generateInsertSql(record.TableInfo, event, e)
s.checkError(err)
if err != nil {
log.Error(err)
}
} else {
goto ENDCHECK
}
}

case replication.UPDATE_ROWS_EVENTv1, replication.UPDATE_ROWS_EVENTv2:
if event, ok := e.Event.(*replication.RowsEvent); ok {
if s.checkFilter(event, record, currentThreadID) {
if ok, t := s.checkUpdateFilter(event, record, currentThreadID); ok {
changeRows += len(event.Rows) / 2
_, err = s.generateUpdateSql(record.TableInfo, event, e)
s.checkError(err)
if t != nil {
_, err = s.generateUpdateSql(t, event, e)
} else {
_, err = s.generateUpdateSql(record.TableInfo, event, e)
}
if err != nil {
log.Error(err)
}
} else {
goto ENDCHECK
}
Expand Down Expand Up @@ -404,6 +421,47 @@ func (s *session) checkFilter(event *replication.RowsEvent,
return true
}

// checkUpdateFilter 检查update的筛选条件
// update会涉及多表更新问题,所以需要把匹配到的表返回
func (s *session) checkUpdateFilter(event *replication.RowsEvent,
record *Record, currentThreadID uint32) (bool, *TableInfo) {
var multiTable *TableInfo
if record.MultiTables == nil {
if !strings.EqualFold(string(event.Table.Schema), record.TableInfo.Schema) ||
!strings.EqualFold(string(event.Table.Table), record.TableInfo.Name) {
return false, nil
}
} else {
found := false
if strings.EqualFold(string(event.Table.Schema), record.TableInfo.Schema) &&
strings.EqualFold(string(event.Table.Table), record.TableInfo.Name) {
found = true
} else {
for _, t := range record.MultiTables {
if strings.EqualFold(string(event.Table.Schema), t.Schema) &&
strings.EqualFold(string(event.Table.Table), t.Name) {
multiTable = t
found = true
break
}
}
}
if !found {
return false, nil
}
}

if currentThreadID == 0 && s.DBType == DBTypeMariaDB {
if record.ErrLevel != 1 {
record.AppendErrorNo(ErrNotFoundThreadId, s.DBVersion)
}
return true, multiTable
} else if record.ThreadId != currentThreadID {
return false, nil
}
return true, multiTable
}

// 解析的sql写入缓存,并定期入库
func (s *session) myWrite(b []byte, binEvent *replication.BinlogEvent,
opid string, table string, record *Record) {
Expand Down Expand Up @@ -517,7 +575,9 @@ func (s *session) generateInsertSql(t *TableInfo, e *replication.RowsEvent,
}

r, err := InterpolateParams(sql, vv, s.Inc.HexBlob)
s.checkError(err)
if err != nil {
log.Error(err)
}

s.write(r, binEvent)
}
Expand Down Expand Up @@ -580,7 +640,9 @@ func (s *session) generateDeleteSql(t *TableInfo, e *replication.RowsEvent,
newSql := strings.Join([]string{sql, strings.Join(columnNames, " AND")}, "")

r, err := InterpolateParams(newSql, vv, s.Inc.HexBlob)
s.checkError(err)
if err != nil {
log.Error(err)
}

s.write(r, binEvent)

Expand Down Expand Up @@ -760,7 +822,9 @@ func (s *session) generateUpdateSql(t *TableInfo, e *replication.RowsEvent,
newSql = strings.Join([]string{sql, strings.Join(columnNames, " AND")}, "")
newValues = append(newValues, oldValues...)
r, err := InterpolateParams(newSql, newValues, s.Inc.HexBlob)
s.checkError(err)
if err != nil {
log.Error(err)
}

s.write(r, binEvent)

Expand Down
69 changes: 60 additions & 9 deletions session/session_inception.go
Original file line number Diff line number Diff line change
Expand Up @@ -6686,14 +6686,20 @@ func (s *session) checkChangeDB(node *ast.UseStmt, sql string) {
s.DBName = node.DBName

// 新建库跳过use 切换
if s.checkDBExists(node.DBName, true) && !s.dbCacheList[strings.ToLower(node.DBName)].IsNew {
_, err := s.Exec(fmt.Sprintf("USE `%s`", node.DBName), true)
if err != nil {
log.Errorf("con:%d %v", s.sessionVars.ConnectionID, err)
if myErr, ok := err.(*mysqlDriver.MySQLError); ok {
s.AppendErrorMessage(myErr.Message)
} else {
s.AppendErrorMessage(err.Error())
if s.checkDBExists(node.DBName, true) {
key := node.DBName
if s.IgnoreCase() {
key = strings.ToLower(key)
}
if v, ok := s.dbCacheList[key]; ok && !v.IsNew {
_, err := s.Exec(fmt.Sprintf("USE `%s`", node.DBName), true)
if err != nil {
log.Errorf("con:%d %v", s.sessionVars.ConnectionID, err)
if myErr, ok := err.(*mysqlDriver.MySQLError); ok {
s.AppendErrorMessage(myErr.Message)
} else {
s.AppendErrorMessage(err.Error())
}
}
}
}
Expand Down Expand Up @@ -6987,6 +6993,7 @@ func (s *session) checkUpdate(node *ast.UpdateStmt, sql string) {
}
}
}

}
}

Expand Down Expand Up @@ -7068,7 +7075,28 @@ func (s *session) checkUpdate(node *ast.UpdateStmt, sql string) {
l.Column.Table = model.NewCIStr(s.myRecord.TableInfo.Name)
}

s.checkFieldItem(l.Column, tableInfoList)
if s.checkFieldItem(l.Column, tableInfoList) {

// update多表操作
// set不同的表
// 存储其他表到MultiTables对象
if len(tableInfoList) > 1 {
if t := getFieldWithTableInfo(l.Column, tableInfoList); t != nil {
if !strings.EqualFold(t.Schema, s.myRecord.TableInfo.Schema) ||
!strings.EqualFold(t.Name, s.myRecord.TableInfo.Name) {
key := fmt.Sprintf("%s.%s", t.Schema, t.Name)
key = strings.ToLower(key)

if s.myRecord.MultiTables == nil {
s.myRecord.MultiTables = make(map[string]*TableInfo, 0)
s.myRecord.MultiTables[key] = t
} else if _, ok := s.myRecord.MultiTables[key]; !ok {
s.myRecord.MultiTables[key] = t
}
}
}
}
}

// 多表update情况时,下面的判断会有问题
// found := false
Expand Down Expand Up @@ -7320,6 +7348,29 @@ func (s *session) checkFieldItem(name *ast.ColumnName, tables []*TableInfo) bool
}
}

// getFieldWithTableInfo 获取字段对应的表信息
func getFieldWithTableInfo(name *ast.ColumnName, tables []*TableInfo) *TableInfo {
db := name.Schema.L
for _, t := range tables {
var tName string
if t.AsName != "" {
tName = t.AsName
} else {
tName = t.Name
}
if name.Table.L != "" && (db == "" || strings.EqualFold(t.Schema, db)) &&
(strings.EqualFold(tName, name.Table.L)) ||
name.Table.L == "" {
for _, field := range t.Fields {
if strings.EqualFold(field.Field, name.Name.L) && !field.IsDeleted {
return t
}
}
}
}
return nil
}

// getFieldItem 获取字段信息
func getFieldInfo(name *ast.ColumnName, tables []*TableInfo) (*FieldInfo, string) {
db := name.Schema.L
Expand Down
84 changes: 80 additions & 4 deletions session/session_inception_backup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -431,19 +431,95 @@ func (s *testSessionIncBackupSuite) TestUpdate(c *C) {
"UPDATE `test_inc`.`t1` SET `c1`=123456789012.1234 WHERE `id`=1;", Commentf("%v", res.Rows()))

// -------------------- 多表update -------------------
config.GetGlobalConfig().Inc.EnableMinimalRollback = false
sql = `drop table if exists table1;drop table if exists table2;
create table table1(id1 int primary key,c1 int,c2 int);
create table table2(id2 int primary key,c1 int,c2 int,c22 int);
insert into table1 values(1,1,1),(2,1,1);
insert into table2 values(1,1,1,null),(2,1,null,null);
update table1 t1,table2 t2 set t1.c1=10,t2.c22=20 where t1.id1=t2.id2 and t2.c1=1;`
// res = s.mustRunBackup(c, sql)
// row = res.Rows()[int(s.tk.Se.AffectedRows())-1]
// backup = s.query("table1", row[7].(string))
// c.Assert(backup, Equals, "UPDATE `test_inc`.`t1` SET `c1`=123456789012.1234 WHERE `id`=1;", Commentf("%v", res.Rows()))

res = s.mustRunBackup(c, sql)
s.assertRows(c, res.Rows()[int(s.tk.Se.AffectedRows())-1:],
"UPDATE `test_inc`.`table2` SET `id2`=1, `c1`=1, `c2`=1, `c22`=NULL WHERE `id2`=1;",
"UPDATE `test_inc`.`table2` SET `id2`=2, `c1`=1, `c2`=NULL, `c22`=NULL WHERE `id2`=2;",
"UPDATE `test_inc`.`table1` SET `id1`=1, `c1`=1, `c2`=1 WHERE `id1`=1;",
"UPDATE `test_inc`.`table1` SET `id1`=2, `c1`=1, `c2`=1 WHERE `id1`=2;",
)

sql = `drop table if exists table1;drop table if exists table2;
create table table1(id1 int primary key,c1 int,c2 int);
create table table2(id2 int primary key,c1 int,c2 int,c22 int);
insert into table1 values(1,1,1),(2,1,1);
insert into table2 values(1,1,1,null),(2,1,null,null);
update table1 t1,table2 t2 set t2.c22=20,t1.c1=10 where t1.id1=t2.id2 and t2.c1=1;`

res = s.mustRunBackup(c, sql)
s.assertRows(c, res.Rows()[int(s.tk.Se.AffectedRows())-1:],
"UPDATE `test_inc`.`table2` SET `id2`=1, `c1`=1, `c2`=1, `c22`=NULL WHERE `id2`=1;",
"UPDATE `test_inc`.`table2` SET `id2`=2, `c1`=1, `c2`=NULL, `c22`=NULL WHERE `id2`=2;",
"UPDATE `test_inc`.`table1` SET `id1`=1, `c1`=1, `c2`=1 WHERE `id1`=1;",
"UPDATE `test_inc`.`table1` SET `id1`=2, `c1`=1, `c2`=1 WHERE `id1`=2;",
)

sql = `drop table if exists table1;drop table if exists table2;
create table table1(id1 int primary key,c1 int,c2 int);
create table table2(id2 int primary key,c1 int,c2 int,c22 int);
insert into table1 values(1,1,1),(2,1,1);
insert into table2 values(1,1,1,null),(2,1,null,null);
update table1 t1,table2 t2 set c22=20,t1.c1=10 where t1.id1=t2.id2 and t2.c1=1;`

res = s.mustRunBackup(c, sql)
s.assertRows(c, res.Rows()[int(s.tk.Se.AffectedRows())-1:],
"UPDATE `test_inc`.`table2` SET `id2`=1, `c1`=1, `c2`=1, `c22`=NULL WHERE `id2`=1;",
"UPDATE `test_inc`.`table2` SET `id2`=2, `c1`=1, `c2`=NULL, `c22`=NULL WHERE `id2`=2;",
"UPDATE `test_inc`.`table1` SET `id1`=1, `c1`=1, `c2`=1 WHERE `id1`=1;",
"UPDATE `test_inc`.`table1` SET `id1`=2, `c1`=1, `c2`=1 WHERE `id1`=2;",
)

config.GetGlobalConfig().Inc.EnableMinimalRollback = true

sql = `drop table if exists table1;drop table if exists table2;
create table table1(id1 int primary key,c1 int,c2 int);
create table table2(id2 int primary key,c1 int,c2 int,c22 int);
insert into table1 values(1,1,1),(2,1,1);
insert into table2 values(1,1,1,null),(2,1,null,null);
update table1 t1,table2 t2 set t1.c1=10,t2.c22=20 where t1.id1=t2.id2 and t2.c1=1;`

res = s.mustRunBackup(c, sql)
s.assertRows(c, res.Rows()[int(s.tk.Se.AffectedRows())-1:],
"UPDATE `test_inc`.`table2` SET `c22`=NULL WHERE `id2`=1;",
"UPDATE `test_inc`.`table2` SET `c22`=NULL WHERE `id2`=2;",
"UPDATE `test_inc`.`table1` SET `c1`=1 WHERE `id1`=1;",
"UPDATE `test_inc`.`table1` SET `c1`=1 WHERE `id1`=2;",
)

sql = `drop table if exists table1;drop table if exists table2;
create table table1(id1 int primary key,c1 int,c2 int);
create table table2(id2 int primary key,c1 int,c2 int,c22 int);
insert into table1 values(1,1,1),(2,1,1);
insert into table2 values(1,1,1,null),(2,1,null,null);
update table1 t1,table2 t2 set t2.c22=20,t1.c1=10 where t1.id1=t2.id2 and t2.c1=1;`

res = s.mustRunBackup(c, sql)
s.assertRows(c, res.Rows()[int(s.tk.Se.AffectedRows())-1:],
"UPDATE `test_inc`.`table2` SET `c22`=NULL WHERE `id2`=1;",
"UPDATE `test_inc`.`table2` SET `c22`=NULL WHERE `id2`=2;",
"UPDATE `test_inc`.`table1` SET `c1`=1 WHERE `id1`=1;",
"UPDATE `test_inc`.`table1` SET `c1`=1 WHERE `id1`=2;",
)

sql = `drop table if exists table1;drop table if exists table2;
create table table1(id1 int primary key,c1 int,c2 int);
create table table2(id2 int primary key,c1 int,c2 int,c22 int);
insert into table1 values(1,1,1),(2,1,1);
insert into table2 values(1,1,1,null),(2,1,null,null);
update table1 t1,table2 t2 set c22=20,t1.c1=10 where t1.id1=t2.id2 and t2.c1=1;`

res = s.mustRunBackup(c, sql)
s.assertRows(c, res.Rows()[int(s.tk.Se.AffectedRows())-1:],
"UPDATE `test_inc`.`table2` SET `c22`=NULL WHERE `id2`=1;",
"UPDATE `test_inc`.`table2` SET `c22`=NULL WHERE `id2`=2;",
"UPDATE `test_inc`.`table1` SET `c1`=1 WHERE `id1`=1;",
"UPDATE `test_inc`.`table1` SET `c1`=1 WHERE `id1`=2;",
)
Expand Down
Loading

0 comments on commit e575b00

Please sign in to comment.