diff --git a/pkg/datasource/sql/conn/oracle_test.go b/pkg/datasource/sql/conn/oracle_test.go index ef8e4eac0..8daeb5852 100644 --- a/pkg/datasource/sql/conn/oracle_test.go +++ b/pkg/datasource/sql/conn/oracle_test.go @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package conn import ( diff --git a/pkg/datasource/sql/exec/at/at_executor.go b/pkg/datasource/sql/exec/at/at_executor.go index e32c13242..6d0446854 100644 --- a/pkg/datasource/sql/exec/at/at_executor.go +++ b/pkg/datasource/sql/exec/at/at_executor.go @@ -64,8 +64,10 @@ func (e *ATExecutor) ExecWithNamedValue(ctx context.Context, execCtx *types.Exec exec = NewDeleteExecutor(parser, execCtx, e.hooks) case types.SQLTypeSelectForUpdate: exec = NewSelectForUpdateExecutor(parser, execCtx, e.hooks) - //case types.SQLTypeMultiDelete: - //case types.SQLTypeMultiUpdate: + case types.SQLTypeInsertOnUpdate: + exec = NewInsertOnUpdateExecutor(parser, execCtx, e.hooks) + // case types.SQLTypeMultiDelete: + // case types.SQLTypeMultiUpdate: default: exec = NewPlainExecutor(parser, execCtx) } diff --git a/pkg/datasource/sql/exec/at/insert_executor.go b/pkg/datasource/sql/exec/at/insert_executor.go index e27ae8771..872832dc8 100644 --- a/pkg/datasource/sql/exec/at/insert_executor.go +++ b/pkg/datasource/sql/exec/at/insert_executor.go @@ -194,7 +194,7 @@ func (i *insertExecutor) getPkValues(ctx context.Context, execCtx *types.ExecCon pkColumnNameList := meta.GetPrimaryKeyOnlyName() pkValuesMap := make(map[string][]interface{}) var err error - //when there is only one pk in the table + // when there is only one pk in the table if len(pkColumnNameList) == 1 { if i.containsPK(meta, parseCtx) { // the insert sql contain pk value @@ -215,9 +215,9 @@ func (i *insertExecutor) getPkValues(ctx context.Context, execCtx *types.ExecCon } } } else { - //when there is multiple pk in the table - //1,all pk columns are filled value. - //2,the auto increment pk column value is null, and other pk value are not null. + // when there is multiple pk in the table + // 1,all pk columns are filled value. + // 2,the auto increment pk column value is null, and other pk value are not null. 
pkValuesMap, err = i.getPkValuesByColumn(ctx, execCtx) if err != nil { return nil, err @@ -336,7 +336,7 @@ func (i *insertExecutor) parsePkValuesFromStatement(insertStmt *ast.InsertStmt, pkValuesMap := make(map[string][]interface{}) if nameValues != nil && len(nameValues) > 0 { - //use prepared statements + // use prepared statements insertRows, err := getInsertRows(insertStmt, pkIndexArray) if err != nil { return nil, err diff --git a/pkg/datasource/sql/exec/at/insert_executor_test.go b/pkg/datasource/sql/exec/at/insert_executor_test.go index 804327cdb..b7556c5a5 100644 --- a/pkg/datasource/sql/exec/at/insert_executor_test.go +++ b/pkg/datasource/sql/exec/at/insert_executor_test.go @@ -297,7 +297,7 @@ func TestMySQLInsertUndoLogBuilder_containPK(t *testing.T) { executor.(*insertExecutor).businesSQLResult = tt.fields.InsertResult executor.(*insertExecutor).incrementStep = tt.fields.IncrementStep - assert.Equalf(t, tt.want, executor.(*insertExecutor).containPK(tt.args.columnName, tt.args.meta), "containPK(%v, %v)", tt.args.columnName, tt.args.meta) + assert.Equalf(t, tt.want, executor.(*insertExecutor).containPK(tt.args.columnName, tt.args.meta), "isPKColumn(%v, %v)", tt.args.columnName, tt.args.meta) }) } } @@ -407,22 +407,25 @@ func TestMySQLInsertUndoLogBuilder_getPkIndex(t *testing.T) { }, }, }, - meta: types.TableMeta{}}, want: map[string]int{}}, + meta: types.TableMeta{}, + }, want: map[string]int{}}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { executor := NewInsertExecutor(nil, &types.ExecContext{}, []exec.SQLHook{}) executor.(*insertExecutor).businesSQLResult = tt.fields.InsertResult executor.(*insertExecutor).incrementStep = tt.fields.IncrementStep - assert.Equalf(t, tt.want, executor.(*insertExecutor).getPkIndex(tt.args.InsertStmt, tt.args.meta), "getPkIndex(%v, %v)", tt.args.InsertStmt, tt.args.meta) + assert.Equalf(t, tt.want, executor.(*insertExecutor).getPkIndex(tt.args.InsertStmt, tt.args.meta), "getPkIndexArray(%v, %v)", tt.args.InsertStmt, tt.args.meta) }) } } + func genIntDatum(id int64) test_driver.Datum { tmp := test_driver.Datum{} tmp.SetInt64(id) return tmp } + func genStrDatum(str string) test_driver.Datum { tmp := test_driver.Datum{} tmp.SetBytesAsString([]byte(str)) @@ -455,12 +458,13 @@ func TestMySQLInsertUndoLogBuilder_parsePkValuesFromStatement(t *testing.T) { Name: model.CIStr{O: "id", L: "id"}, }, }, - Lists: [][]ast.ExprNode{{ - &test_driver.ValueExpr{ - Datum: genIntDatum(1), + Lists: [][]ast.ExprNode{ + { + &test_driver.ValueExpr{ + Datum: genIntDatum(1), + }, }, }, - }, }, meta: types.TableMeta{ ColumnNames: []string{"id"}, @@ -503,12 +507,13 @@ func TestMySQLInsertUndoLogBuilder_parsePkValuesFromStatement(t *testing.T) { Name: model.CIStr{O: "id", L: "id"}, }, }, - Lists: [][]ast.ExprNode{{ - &test_driver.ValueExpr{ - Datum: genStrDatum("?"), + Lists: [][]ast.ExprNode{ + { + &test_driver.ValueExpr{ + Datum: genStrDatum("?"), + }, }, }, - }, }, meta: types.TableMeta{ ColumnNames: []string{"id"}, @@ -605,15 +610,17 @@ func TestMySQLInsertUndoLogBuilder_getPkValuesByColumn(t *testing.T) { Name: model.CIStr{O: "id", L: "id"}, }, }, - Lists: [][]ast.ExprNode{{ - &test_driver.ValueExpr{ - Datum: genIntDatum(1), + Lists: [][]ast.ExprNode{ + { + &test_driver.ValueExpr{ + Datum: genIntDatum(1), + }, }, }, - }, }, }, - }}, + }, + }, want: map[string][]interface{}{ "id": {int64(1)}, }, @@ -705,15 +712,17 @@ func TestMySQLInsertUndoLogBuilder_getPkValuesByAuto(t *testing.T) { Name: model.CIStr{O: "name", L: "name"}, }, }, - Lists: 
[][]ast.ExprNode{{ - &test_driver.ValueExpr{ - Datum: genStrDatum("Tom"), + Lists: [][]ast.ExprNode{ + { + &test_driver.ValueExpr{ + Datum: genStrDatum("Tom"), + }, }, }, - }, }, }, - }}, + }, + }, want: map[string][]interface{}{ "id": {int64(100)}, }, @@ -795,12 +804,13 @@ func TestMySQLInsertUndoLogBuilder_autoGeneratePks(t *testing.T) { Name: model.CIStr{O: "id", L: "id"}, }, }, - Lists: [][]ast.ExprNode{{ - &test_driver.ValueExpr{ - Datum: genIntDatum(1), + Lists: [][]ast.ExprNode{ + { + &test_driver.ValueExpr{ + Datum: genIntDatum(1), + }, }, }, - }, }, }, }, diff --git a/pkg/datasource/sql/exec/at/insert_on_update_executor.go b/pkg/datasource/sql/exec/at/insert_on_update_executor.go new file mode 100644 index 000000000..1ba0223cb --- /dev/null +++ b/pkg/datasource/sql/exec/at/insert_on_update_executor.go @@ -0,0 +1,413 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package at + +import ( + "context" + "database/sql/driver" + "fmt" + "strings" + + "github.com/arana-db/parser/ast" + "github.com/seata/seata-go/pkg/datasource/sql/datasource" + "github.com/seata/seata-go/pkg/datasource/sql/exec" + "github.com/seata/seata-go/pkg/datasource/sql/types" + "github.com/seata/seata-go/pkg/datasource/sql/util" + "github.com/seata/seata-go/pkg/util/log" +) + +// insertOnUpdateExecutor execute insert on update SQL +type insertOnUpdateExecutor struct { + baseExecutor + parserCtx *types.ParseContext + execContext *types.ExecContext + beforeImageSqlPrimaryKeys map[string]bool + beforeSelectSql string + beforeSelectArgs []driver.NamedValue +} + +// NewInsertOnUpdateExecutor get insert on update executor +func NewInsertOnUpdateExecutor(parserCtx *types.ParseContext, execContent *types.ExecContext, hooks []exec.SQLHook) executor { + return &insertOnUpdateExecutor{ + baseExecutor: baseExecutor{hooks: hooks}, + parserCtx: parserCtx, + execContext: execContent, + beforeImageSqlPrimaryKeys: make(map[string]bool), + } +} + +func (i *insertOnUpdateExecutor) ExecContext(ctx context.Context, f exec.CallbackWithNamedValue) (types.ExecResult, error) { + i.beforeHooks(ctx, i.execContext) + defer func() { + i.afterHooks(ctx, i.execContext) + }() + + beforeImage, err := i.beforeImage(ctx) + if err != nil { + return nil, err + } + + res, err := f(ctx, i.execContext.Query, i.execContext.NamedValues) + if err != nil { + return nil, err + } + + afterImage, err := i.afterImage(ctx, beforeImage) + if err != nil { + return nil, err + } + + i.execContext.TxCtx.RoundImages.AppendBeofreImage(beforeImage) + i.execContext.TxCtx.RoundImages.AppendAfterImage(afterImage) + return res, nil +} + +// beforeImage build before image +func (i *insertOnUpdateExecutor) beforeImage(ctx context.Context) (*types.RecordImage, error) { + if !i.isAstStmtValid() { + log.Errorf("invalid 
insert statement! parser ctx:%+v", i.parserCtx)
+		return nil, fmt.Errorf("invalid insert statement! parser ctx:%+v", i.parserCtx)
+	}
+	tableName, err := i.parserCtx.GetTableName()
+	if err != nil {
+		return nil, err
+	}
+	metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, i.execContext.DBName, tableName)
+	if err != nil {
+		return nil, err
+	}
+	selectSQL, selectArgs, err := i.buildBeforeImageSQL(i.parserCtx.InsertStmt, *metaData, i.execContext.NamedValues)
+	if err != nil {
+		return nil, err
+	}
+	if len(selectArgs) == 0 {
+		log.Errorf("the SQL statement has no primary key or unique index value, it will not hit any row data. "+
+			"recommend converting it to a normal insert statement. db name:%s table name:%s sql:%s", i.execContext.DBName, tableName, i.execContext.Query)
+		return nil, fmt.Errorf("the SQL statement has no primary key or unique index value, it will not hit any row data. "+
+			"recommend converting it to a normal insert statement. db name:%s table name:%s sql:%s", i.execContext.DBName, tableName, i.execContext.Query)
+	}
+	i.beforeSelectSql = selectSQL
+	i.beforeSelectArgs = selectArgs
+
+	var rowsi driver.Rows
+	queryerCtx, queryerCtxExists := i.execContext.Conn.(driver.QueryerContext)
+	var queryer driver.Queryer
+	var queryerExists bool
+
+	if !queryerCtxExists {
+		queryer, queryerExists = i.execContext.Conn.(driver.Queryer)
+	}
+	if !queryerExists && !queryerCtxExists {
+		log.Errorf("target conn should be driver.QueryerContext or driver.Queryer")
+		return nil, fmt.Errorf("invalid conn")
+	}
+	rowsi, err = util.CtxDriverQuery(ctx, queryerCtx, queryer, selectSQL, selectArgs)
+	defer func() {
+		if rowsi != nil {
+			rowsi.Close()
+		}
+	}()
+	if err != nil {
+		log.Errorf("ctx driver query: %+v", err)
+		return nil, err
+	}
+	image, err := i.buildRecordImages(rowsi, metaData)
+	if err != nil {
+		return nil, err
+	}
+	return image, nil
+}
+
+// buildBeforeImageSQL build the SQL to query before image data
+func (i *insertOnUpdateExecutor) buildBeforeImageSQL(insertStmt *ast.InsertStmt, metaData types.TableMeta, args []driver.NamedValue) (string, []driver.NamedValue, error) {
+	if err := checkDuplicateKeyUpdate(insertStmt, metaData); err != nil {
+		return "", nil, err
+	}
+
+	paramMap, insertNum, err := i.buildBeforeImageSQLParameters(insertStmt, args, metaData)
+	if err != nil {
+		return "", nil, err
+	}
+	sql := strings.Builder{}
+	sql.WriteString("SELECT * FROM " + metaData.TableName + " ")
+	isContainWhere := false
+	var selectArgs []driver.NamedValue
+	for j := 0; j < insertNum; j++ {
+		finalJ := j
+		var paramAppenderTempList []driver.NamedValue
+		for _, index := range metaData.Indexs {
+			// unique index
+			if index.NonUnique || isIndexValueNull(index, paramMap, finalJ) {
+				continue
+			}
+			columnIsNull := true
+			var uniqueList []string
+			for _, columnMeta := range index.Columns {
+				columnName := columnMeta.ColumnName
+				imageParameters, ok := paramMap[columnName]
+				if !ok && columnMeta.ColumnDef != nil {
+					if strings.EqualFold("PRIMARY", index.Name) {
+						i.beforeImageSqlPrimaryKeys[columnName] = true
+					}
+					uniqueList = append(uniqueList, columnName+" = DEFAULT("+columnName+") ")
+					columnIsNull = false
+					continue
+				}
+				if strings.EqualFold("PRIMARY", index.Name) {
+					i.beforeImageSqlPrimaryKeys[columnName] = true
+				}
+				columnIsNull = false
+				uniqueList = append(uniqueList, columnName+" = ? ")
+				paramAppenderTempList = append(paramAppenderTempList, imageParameters[finalJ])
+			}
+
+			if !columnIsNull {
+				if isContainWhere {
+					sql.WriteString(" OR (" + strings.Join(uniqueList, " and ") + ") ")
+				} else {
+					sql.WriteString(" WHERE (" + strings.Join(uniqueList, " and ") + ") ")
+					isContainWhere = true
+				}
+			}
+		}
+		selectArgs = append(selectArgs, paramAppenderTempList...)
+	}
+	log.Infof("build select sql by insert on update sourceQuery, sql: %s", sql.String())
+	return sql.String(), selectArgs, nil
+}
+
+// buildBeforeImageSQLParameters build the SQL parameters to query before image data
+func (i *insertOnUpdateExecutor) buildBeforeImageSQLParameters(insertStmt *ast.InsertStmt, args []driver.NamedValue, metaData types.TableMeta) (map[string][]driver.NamedValue, int, error) {
+	pkIndexArray := i.getPkIndexArray(insertStmt, metaData)
+	insertRows, err := getInsertRows(insertStmt, pkIndexArray)
+	if err != nil {
+		return nil, 0, err
+	}
+
+	parameterMap := make(map[string][]driver.NamedValue)
+	insertColumns := getInsertColumns(insertStmt)
+	placeHolderIndex := 0
+	for _, rowColumns := range insertRows {
+		if len(rowColumns) != len(insertColumns) {
+			log.Errorf("insert row's column size not equal to insert column size. row columns:%+v insert columns:%+v", rowColumns, insertColumns)
+			return nil, 0, fmt.Errorf("insert row's column size not equal to insert column size. row columns:%+v insert columns:%+v", rowColumns, insertColumns)
+		}
+		for colIdx, col := range insertColumns {
+			columnName := DelEscape(col, types.DBTypeMySQL)
+			val := rowColumns[colIdx]
+			rStr, ok := val.(string)
+			if ok && strings.EqualFold(rStr, sqlPlaceholder) {
+				objects := args[placeHolderIndex]
+				parameterMap[columnName] = append(parameterMap[columnName], objects)
+				placeHolderIndex++
+			} else {
+				parameterMap[columnName] = append(parameterMap[columnName], driver.NamedValue{
+					Ordinal: colIdx + 1,
+					Name:    columnName,
+					Value:   val,
+				})
+			}
+		}
+	}
+	return parameterMap, len(insertRows), nil
+}
+
+// afterImage build after image
+func (i *insertOnUpdateExecutor) afterImage(ctx context.Context, beforeImages *types.RecordImage) (*types.RecordImage, error) {
+	afterSelectSql, selectArgs := i.buildAfterImageSQL(beforeImages)
+	var rowsi driver.Rows
+	queryerCtx, queryerCtxExists := i.execContext.Conn.(driver.QueryerContext)
+	var queryer driver.Queryer
+	var queryerExists bool
+	if !queryerCtxExists {
+		queryer, queryerExists = i.execContext.Conn.(driver.Queryer)
+	}
+	if !queryerCtxExists && !queryerExists {
+		log.Errorf("target conn should be driver.QueryerContext or driver.Queryer")
+		return nil, fmt.Errorf("invalid conn")
+	}
+	rowsi, err := util.CtxDriverQuery(ctx, queryerCtx, queryer, afterSelectSql, selectArgs)
+	defer func() {
+		if rowsi != nil {
+			rowsi.Close()
+		}
+	}()
+	if err != nil {
+		log.Errorf("ctx driver query: %+v", err)
+		return nil, err
+	}
+	tableName, err := i.parserCtx.GetTableName()
+	if err != nil {
+		return nil, err
+	}
+	metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, i.execContext.DBName, tableName)
+	if err != nil {
+		return nil, err
+	}
+	afterImage, err := i.buildRecordImages(rowsi, metaData)
+	if err != nil {
+		return nil, err
+	}
+	lockKey := i.buildLockKey(afterImage, *metaData)
+	i.execContext.TxCtx.LockKeys[lockKey] = struct{}{}
+	return afterImage, nil
+}
+
+// buildAfterImageSQL build the SQL to query after image data
+func (i *insertOnUpdateExecutor) buildAfterImageSQL(beforeImage *types.RecordImage) (string, []driver.NamedValue) {
+	selectSQL, selectArgs := i.beforeSelectSql, i.beforeSelectArgs
+	primaryValueMap := make(map[string][]interface{})
+
+	for _, row := range beforeImage.Rows {
+		for _, col := range row.Columns {
+			if col.KeyType == types.IndexTypePrimaryKey {
+				primaryValueMap[col.ColumnName] = append(primaryValueMap[col.ColumnName], col.Value)
+			}
+		}
+	}
+
+	var afterImageSql strings.Builder
+	var primaryValues []driver.NamedValue
+	afterImageSql.WriteString(selectSQL)
+	for j := 0; j < len(beforeImage.Rows); j++ {
+		var wherePrimaryList []string
+		for name, value := range primaryValueMap {
+			if !i.beforeImageSqlPrimaryKeys[name] {
+				wherePrimaryList = append(wherePrimaryList, name+" = ? ")
+				primaryValues = append(primaryValues, driver.NamedValue{
+					Name:  name,
+					Value: value[j],
+				})
+			}
+		}
+		if len(wherePrimaryList) != 0 {
+			afterImageSql.WriteString(" OR (" + strings.Join(wherePrimaryList, " and ") + ") ")
+		}
+	}
+	selectArgs = append(selectArgs, primaryValues...)
+	log.Infof("build after select sql by insert on duplicate sourceQuery, sql: %s", afterImageSql.String())
+	return afterImageSql.String(), selectArgs
+}
+
+// isPKColumn check the column name to see if it is a primary key column
+func (i *insertOnUpdateExecutor) isPKColumn(columnName string, meta types.TableMeta) bool {
+	newColumnName := DelEscape(columnName, types.DBTypeMySQL)
+	pkColumnNameList := meta.GetPrimaryKeyOnlyName()
+	if len(pkColumnNameList) == 0 {
+		return false
+	}
+	for _, name := range pkColumnNameList {
+		if strings.EqualFold(name, newColumnName) {
+			return true
+		}
+	}
+	return false
+}
+
+// getPkIndexArray get index of primary key from insert statement
+func (i *insertOnUpdateExecutor) getPkIndexArray(insertStmt *ast.InsertStmt, meta types.TableMeta) []int {
+	var pkIndexArray []int
+	if insertStmt == nil {
+		return pkIndexArray
+	}
+	insertColumnsSize := len(insertStmt.Columns)
+	if insertColumnsSize == 0 {
+		return pkIndexArray
+	}
+	if meta.ColumnNames == nil {
+		return pkIndexArray
+	}
+	if len(meta.Columns) > 0 {
+		for paramIdx := 0; paramIdx < insertColumnsSize; paramIdx++ {
+			sqlColumnName := insertStmt.Columns[paramIdx].Name.O
+			if i.isPKColumn(sqlColumnName, meta) {
+				pkIndexArray = append(pkIndexArray, paramIdx)
+			}
+		}
+		return pkIndexArray
+	}
+
+	pkIndex := -1
+	allColumns := meta.Columns
+	for _, columnMeta := range allColumns {
+		tmpColumnMeta := columnMeta
+		pkIndex++
+		if i.isPKColumn(tmpColumnMeta.ColumnName, meta) {
+			pkIndexArray = append(pkIndexArray, pkIndex)
+		}
+	}
+
+	return pkIndexArray
+}
+
+func (i *insertOnUpdateExecutor) isAstStmtValid() bool {
+	return i.parserCtx != nil && i.parserCtx.InsertStmt != nil
+}
+
+// isIndexValueNull check if the index value is null
+func isIndexValueNull(indexMeta types.IndexMeta, imageParameterMap map[string][]driver.NamedValue, rowIndex int) bool {
+	for _, colMeta := range indexMeta.Columns {
+		columnName := colMeta.ColumnName
+		imageParameters := imageParameterMap[columnName]
+		if imageParameters == nil && colMeta.ColumnDef == nil {
+			return true
+		} else if imageParameters != nil && (rowIndex >= len(imageParameters) || imageParameters[rowIndex].Value == nil) {
+			return true
+		}
+	}
+	return false
+}
+
+// getInsertColumns get insert columns from insert statement
+func getInsertColumns(insertStmt *ast.InsertStmt) []string {
+	if insertStmt == nil {
+		return nil
+	}
+	colList := insertStmt.Columns
+	if len(colList) == 0 {
+		return nil
+	}
+	var list []string
+	for _, col := range colList {
+		list = append(list, col.Name.L)
+	}
+	return list
+}
+
+// checkDuplicateKeyUpdate check 
whether insert on update sql wants to update the duplicate keys +func checkDuplicateKeyUpdate(insert *ast.InsertStmt, metaData types.TableMeta) error { + duplicateColsMap := make(map[string]bool) + for _, v := range insert.OnDuplicate { + duplicateColsMap[strings.ToLower(v.Column.Name.L)] = true + } + if len(duplicateColsMap) == 0 { + return nil + } + for _, index := range metaData.Indexs { + if types.IndexTypePrimaryKey != index.IType { + continue + } + for _, col := range index.Columns { + if duplicateColsMap[strings.ToLower(col.ColumnName)] { + log.Errorf("update pk value is not supported! index name:%s update column name: %s", index.Name, col.ColumnName) + return fmt.Errorf("update pk value is not supported! index name:%s update column name: %s", index.Name, col.ColumnName) + } + } + } + return nil +} diff --git a/pkg/datasource/sql/exec/at/insert_on_update_executor_test.go b/pkg/datasource/sql/exec/at/insert_on_update_executor_test.go new file mode 100644 index 000000000..fab3bf6cf --- /dev/null +++ b/pkg/datasource/sql/exec/at/insert_on_update_executor_test.go @@ -0,0 +1,248 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package at + +import ( + "database/sql/driver" + "testing" + + "github.com/seata/seata-go/pkg/datasource/sql/parser" + "github.com/seata/seata-go/pkg/datasource/sql/types" + "github.com/seata/seata-go/pkg/datasource/sql/util" + "github.com/stretchr/testify/assert" +) + +func TestInsertOnUpdateBeforeImageSQL(t *testing.T) { + var ( + ioe = insertOnUpdateExecutor{ + beforeImageSqlPrimaryKeys: make(map[string]bool), + } + tableMeta1 types.TableMeta + // one index table + tableMeta2 types.TableMeta + columns = make(map[string]types.ColumnMeta) + index = make(map[string]types.IndexMeta) + index2 = make(map[string]types.IndexMeta) + columnMeta1 []types.ColumnMeta + columnMeta2 []types.ColumnMeta + ColumnNames []string + ) + columnId := types.ColumnMeta{ + ColumnDef: nil, + ColumnName: "id", + } + columnName := types.ColumnMeta{ + ColumnDef: nil, + ColumnName: "name", + } + columnAge := types.ColumnMeta{ + ColumnDef: nil, + ColumnName: "age", + } + columns["id"] = columnId + columns["name"] = columnName + columns["age"] = columnAge + columnMeta1 = append(columnMeta1, columnId) + columnMeta2 = append(columnMeta2, columnName, columnAge) + index["id"] = types.IndexMeta{ + Name: "PRIMARY", + IType: types.IndexTypePrimaryKey, + Columns: columnMeta1, + } + index["id_name_age"] = types.IndexMeta{ + Name: "name_age_idx", + IType: types.IndexUnique, + Columns: columnMeta2, + } + + ColumnNames = []string{"id", "name", "age"} + tableMeta1 = types.TableMeta{ + TableName: "t_user", + Columns: columns, + Indexs: index, + ColumnNames: ColumnNames, + } + + index2["id_name_age"] = types.IndexMeta{ + Name: "name_age_idx", + IType: types.IndexUnique, + Columns: columnMeta2, + } + + tableMeta2 = types.TableMeta{ + TableName: "t_user", + Columns: columns, + Indexs: index2, + ColumnNames: ColumnNames, + } + + tests := []struct { + name string + execCtx *types.ExecContext + sourceQueryArgs []driver.Value + expectQuery1 string + expectQueryArgs1 []driver.Value + expectQuery2 string + expectQueryArgs2 []driver.Value + }{ + { + execCtx: &types.ExecContext{ + Query: "insert into t_user(id, name, age) values(?,?,?) on duplicate key update name = ?,age = ?", + MetaDataMap: map[string]types.TableMeta{"t_user": tableMeta1}, + }, + sourceQueryArgs: []driver.Value{1, "Jack1", 81, "Link", 18}, + expectQuery1: "SELECT * FROM t_user WHERE (id = ? ) OR (name = ? and age = ? ) ", + expectQueryArgs1: []driver.Value{1, "Jack1", 81}, + expectQuery2: "SELECT * FROM t_user WHERE (name = ? and age = ? ) OR (id = ? ) ", + expectQueryArgs2: []driver.Value{"Jack1", 81, 1}, + }, + { + execCtx: &types.ExecContext{ + Query: "insert into t_user(id, name, age) values(1,'Jack1',?) on duplicate key update name = 'Michael',age = ?", + MetaDataMap: map[string]types.TableMeta{"t_user": tableMeta1}, + }, + sourceQueryArgs: []driver.Value{81, "Link", 18}, + expectQuery1: "SELECT * FROM t_user WHERE (id = ? ) OR (name = ? and age = ? ) ", + expectQueryArgs1: []driver.Value{int64(1), "Jack1", 81}, + expectQuery2: "SELECT * FROM t_user WHERE (name = ? and age = ? ) OR (id = ? ) ", + expectQueryArgs2: []driver.Value{"Jack1", 81, int64(1)}, + }, + // multi insert one index + { + execCtx: &types.ExecContext{ + Query: "insert into t_user(id, name, age) values(?,?,?),(?,?,?) on duplicate key update name = ?,age = ?", + MetaDataMap: map[string]types.TableMeta{"t_user": tableMeta2}, + }, + sourceQueryArgs: []driver.Value{1, "Jack1", 81, 2, "Michal", 35, "Link", 18}, + expectQuery1: "SELECT * FROM t_user WHERE (name = ? and age = ? ) OR (name = ? 
and age = ? ) ", + expectQueryArgs1: []driver.Value{"Jack1", 81, "Michal", 35}, + }, + { + execCtx: &types.ExecContext{ + Query: "insert into t_user(id, name, age) values(?,'Jack1',?),(?,?,35) on duplicate key update name = 'Faker',age = ?", + MetaDataMap: map[string]types.TableMeta{"t_user": tableMeta2}, + }, + sourceQueryArgs: []driver.Value{1, 81, 2, "Michal", 26}, + expectQuery1: "SELECT * FROM t_user WHERE (name = ? and age = ? ) OR (name = ? and age = ? ) ", + expectQueryArgs1: []driver.Value{"Jack1", 81, "Michal", int64(35)}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c, err := parser.DoParser(tt.execCtx.Query) + assert.Nil(t, err) + tt.execCtx.ParseContext = c + query, args, err := ioe.buildBeforeImageSQL(tt.execCtx.ParseContext.InsertStmt, tt.execCtx.MetaDataMap["t_user"], util.ValueToNamedValue(tt.sourceQueryArgs)) + assert.Nil(t, err) + if query == tt.expectQuery1 { + assert.Equal(t, tt.expectQuery1, query) + assert.Equal(t, tt.expectQueryArgs1, util.NamedValueToValue(args)) + } else { + assert.Equal(t, tt.expectQuery2, query) + assert.Equal(t, tt.expectQueryArgs2, util.NamedValueToValue(args)) + } + }) + } +} + +func TestInsertOnUpdateAfterImageSQL(t *testing.T) { + ioe := insertOnUpdateExecutor{} + tests := []struct { + name string + beforeSelectSql string + BeforeImageSqlPrimaryKeys map[string]bool + beforeSelectArgs []driver.Value + beforeImage *types.RecordImage + expectQuery string + expectQueryArgs []driver.Value + }{ + { + beforeSelectSql: "SELECT * FROM t_user WHERE (id = ? ) OR (name = ? and age = ? ) ", + BeforeImageSqlPrimaryKeys: map[string]bool{"id": true}, + beforeSelectArgs: []driver.Value{1, "Jack1", 81}, + beforeImage: &types.RecordImage{ + TableName: "t_user", + Rows: []types.RowImage{ + { + Columns: []types.ColumnImage{ + { + KeyType: types.IndexTypePrimaryKey, + ColumnName: "id", + Value: 2, + }, + { + KeyType: types.IndexUnique, + ColumnName: "name", + Value: "Jack", + }, + { + KeyType: types.IndexUnique, + ColumnName: "age", + Value: 18, + }, + }, + }, + }, + }, + expectQuery: "SELECT * FROM t_user WHERE (id = ? ) OR (name = ? and age = ? ) ", + expectQueryArgs: []driver.Value{1, "Jack1", 81}, + }, + { + beforeSelectSql: "SELECT * FROM t_user WHERE (id = ? ) OR (name = ? and age = ? ) OR (id = ? ) OR (name = ? and age = ? ) ", + BeforeImageSqlPrimaryKeys: map[string]bool{"id": true}, + beforeSelectArgs: []driver.Value{1, "Jack1", 30, 2, "Michael", 18}, + beforeImage: &types.RecordImage{ + TableName: "t_user", + Rows: []types.RowImage{ + { + Columns: []types.ColumnImage{ + { + KeyType: types.IndexTypePrimaryKey, + ColumnName: "id", + Value: 1, + }, + { + KeyType: types.IndexUnique, + ColumnName: "name", + Value: "Jack", + }, + { + KeyType: types.IndexUnique, + ColumnName: "age", + Value: 18, + }, + }, + }, + }, + }, + expectQuery: "SELECT * FROM t_user WHERE (id = ? ) OR (name = ? and age = ? ) OR (id = ? ) OR (name = ? and age = ? 
) ", + expectQueryArgs: []driver.Value{1, "Jack1", 30, 2, "Michael", 18}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ioe.beforeSelectSql = tt.beforeSelectSql + ioe.beforeImageSqlPrimaryKeys = tt.BeforeImageSqlPrimaryKeys + ioe.beforeSelectArgs = util.ValueToNamedValue(tt.beforeSelectArgs) + query, args := ioe.buildAfterImageSQL(tt.beforeImage) + assert.Equal(t, tt.expectQuery, query) + assert.Equal(t, tt.expectQueryArgs, util.NamedValueToValue(args)) + }) + } +} diff --git a/pkg/datasource/sql/parser/parser_factory.go b/pkg/datasource/sql/parser/parser_factory.go index ab7d9e844..59e127ded 100644 --- a/pkg/datasource/sql/parser/parser_factory.go +++ b/pkg/datasource/sql/parser/parser_factory.go @@ -59,8 +59,8 @@ func parseParseContext(stmtNode ast.StmtNode) *types.ParseContext { if stmt.IsReplace { parserCtx.ExecutorType = types.ReplaceIntoExecutor } - if len(stmt.OnDuplicate) != 0 { + parserCtx.SQLType = types.SQLTypeInsertOnUpdate parserCtx.ExecutorType = types.InsertOnDuplicateExecutor } case *ast.UpdateStmt: @@ -82,6 +82,5 @@ func parseParseContext(stmtNode ast.StmtNode) *types.ParseContext { parserCtx.DeleteStmt = stmt parserCtx.ExecutorType = types.DeleteExecutor } - return parserCtx } diff --git a/pkg/datasource/sql/tx_at.go b/pkg/datasource/sql/tx_at.go index b8e2777a2..deca88b1b 100644 --- a/pkg/datasource/sql/tx_at.go +++ b/pkg/datasource/sql/tx_at.go @@ -18,9 +18,8 @@ package sql import ( - "github.com/seata/seata-go/pkg/datasource/sql/undo" - "github.com/pkg/errors" + "github.com/seata/seata-go/pkg/datasource/sql/undo" ) // ATTx @@ -54,7 +53,6 @@ func (tx *ATTx) Rollback() error { // commitOnAT func (tx *ATTx) commitOnAT() error { originTx := tx.tx - if err := originTx.register(originTx.tranCtx); err != nil { return err } diff --git a/pkg/datasource/sql/types/sql.go b/pkg/datasource/sql/types/sql.go index 3189ce646..c276e29a8 100644 --- a/pkg/datasource/sql/types/sql.go +++ b/pkg/datasource/sql/types/sql.go @@ -28,6 +28,7 @@ const ( SQLTypeUpdate SQLTypeDelete SQLTypeSelectForUpdate + SQLTypeInsertOnUpdate SQLTypeReplace SQLTypeTruncate SQLTypeCreate @@ -67,6 +68,8 @@ func (s SQLType) MarshalText() (text []byte, err error) { return []byte("DELETE"), nil case SQLTypeSelectForUpdate: return []byte("SELECT_FOR_UPDATE"), nil + case SQLTypeInsertOnUpdate: + return []byte("INSERT_ON_UPDATE"), nil case SQLTypeReplace: return []byte("REPLACE"), nil case SQLTypeTruncate: @@ -133,6 +136,8 @@ func (s *SQLType) UnmarshalText(b []byte) error { *s = SQLTypeDelete case "SELECT_FOR_UPDATE": *s = SQLTypeSelectForUpdate + case "INSERT_ON_UPDATE": + *s = SQLTypeInsertOnUpdate case "REPLACE": *s = SQLTypeReplace case "TRUNCATE": diff --git a/pkg/datasource/sql/types/sqltype_string.go b/pkg/datasource/sql/types/sqltype_string.go deleted file mode 100644 index 40a4359ca..000000000 --- a/pkg/datasource/sql/types/sqltype_string.go +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// Code generated by "stringer -type=SQLType"; DO NOT EDIT. - -package types - -import "strconv" - -func _() { - // An "invalid array index" compiler error signifies that the constant values have changed. - // Re-run the stringer command to generate them again. - var x [1]struct{} - _ = x[SQLTypeUnknown-1] - _ = x[SQLTypeSelect-2] - _ = x[SQLTypeInsert-3] - _ = x[SQLTypeUpdate-4] - _ = x[SQLTypeDelete-5] - _ = x[SQLTypeSelectForUpdate-6] - _ = x[SQLTypeReplace-7] - _ = x[SQLTypeTruncate-8] - _ = x[SQLTypeCreate-9] - _ = x[SQLTypeDrop-10] - _ = x[SQLTypeLoad-11] - _ = x[SQLTypeMerge-12] - _ = x[SQLTypeShow-13] - _ = x[SQLTypeAlter-14] - _ = x[SQLTypeRename-15] - _ = x[SQLTypeDump-16] - _ = x[SQLTypeDebug-17] - _ = x[SQLTypeExplain-18] - _ = x[SQLTypeDesc-19] - _ = x[SQLTypeSet-20] - _ = x[SQLTypeReload-21] - _ = x[SQLTypeSelectUnion-22] - _ = x[SQLTypeCreateTable-23] - _ = x[SQLTypeDropTable-24] - _ = x[SQLTypeAlterTable-25] - _ = x[SQLTypeSelectFromUpdate-26] - _ = x[SQLTypeMultiDelete-27] - _ = x[SQLTypeMultiUpdate-28] - _ = x[SQLTypeCreateIndex-29] - _ = x[SQLTypeDropIndex-30] -} - -const _SQLType_name = "SQLTypeUnknownSQLTypeSelectSQLTypeInsertSQLTypeUpdateSQLTypeDeleteSQLTypeSelectForUpdateSQLTypeReplaceSQLTypeTruncateSQLTypeCreateSQLTypeDropSQLTypeLoadSQLTypeMergeSQLTypeShowSQLTypeAlterSQLTypeRenameSQLTypeDumpSQLTypeDebugSQLTypeExplainSQLTypeDescSQLTypeSetSQLTypeReloadSQLTypeSelectUnionSQLTypeCreateTableSQLTypeDropTableSQLTypeAlterTableSQLTypeSelectFromUpdateSQLTypeMultiDeleteSQLTypeMultiUpdateSQLTypeCreateIndexSQLTypeDropIndex" - -var _SQLType_index = [...]uint16{0, 14, 27, 40, 53, 66, 88, 102, 117, 130, 141, 152, 164, 175, 187, 200, 211, 223, 237, 248, 258, 271, 289, 307, 323, 340, 363, 381, 399, 417, 433} - -func (i SQLType) String() string { - i -= 1 - if i < 0 || i >= SQLType(len(_SQLType_index)-1) { - return "SQLType(" + strconv.FormatInt(int64(i+1), 10) + ")" - } - return _SQLType_name[_SQLType_index[i]:_SQLType_index[i+1]] -} diff --git a/pkg/rm/tcc/fence/handler/tcc_fence_wrapper_handler.go b/pkg/rm/tcc/fence/handler/tcc_fence_wrapper_handler.go index 78abd104f..bf49174b7 100644 --- a/pkg/rm/tcc/fence/handler/tcc_fence_wrapper_handler.go +++ b/pkg/rm/tcc/fence/handler/tcc_fence_wrapper_handler.go @@ -23,10 +23,11 @@ import ( "database/sql" "errors" "fmt" - "github.com/go-sql-driver/mysql" "sync" "time" + "github.com/go-sql-driver/mysql" + "github.com/seata/seata-go/pkg/rm/tcc/fence/enum" "github.com/seata/seata-go/pkg/rm/tcc/fence/store/db/dao" "github.com/seata/seata-go/pkg/rm/tcc/fence/store/db/model" diff --git a/pkg/rm/tcc/tcc_service_test.go b/pkg/rm/tcc/tcc_service_test.go index 09be6d21b..792cf12d5 100644 --- a/pkg/rm/tcc/tcc_service_test.go +++ b/pkg/rm/tcc/tcc_service_test.go @@ -35,6 +35,7 @@ import ( "github.com/seata/seata-go/pkg/rm" "github.com/seata/seata-go/pkg/tm" "github.com/seata/seata-go/pkg/util/log" + //"github.com/seata/seata-go/sample/tcc/dubbo/client/service" testdata2 "github.com/seata/seata-go/testdata" )
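
For context, a minimal sketch of how the new classification is intended to behave, using only parser.DoParser and types.SQLTypeInsertOnUpdate from this change; the table name and statement below are illustrative and not taken from the test suite:

package main

import (
	"fmt"

	"github.com/seata/seata-go/pkg/datasource/sql/parser"
	"github.com/seata/seata-go/pkg/datasource/sql/types"
)

func main() {
	// Illustrative statement: any INSERT carrying an ON DUPLICATE KEY UPDATE clause
	// should now be tagged SQLTypeInsertOnUpdate by the parser, which is what makes
	// ATExecutor.ExecWithNamedValue route it to NewInsertOnUpdateExecutor.
	query := "insert into t_user(id, name, age) values(?, ?, ?) on duplicate key update name = ?, age = ?"

	parseCtx, err := parser.DoParser(query)
	if err != nil {
		panic(err)
	}
	fmt.Println(parseCtx.SQLType == types.SQLTypeInsertOnUpdate) // expected: true
}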