diff --git a/executor/batch_checker.go b/executor/batch_checker.go index 3496a306664a6..12d2b7c63a834 100644 --- a/executor/batch_checker.go +++ b/executor/batch_checker.go @@ -16,12 +16,14 @@ package executor import ( "github.com/pingcap/errors" "github.com/pingcap/parser/model" + "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/table/tables" "github.com/pingcap/tidb/tablecodec" "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/util/chunk" ) type keyValue struct { @@ -284,7 +286,8 @@ func (b *batchChecker) deleteDupKeys(ctx sessionctx.Context, t table.Table, rows // getOldRow gets the table record row from storage for batch check. // t could be a normal table or a partition, but it must not be a PartitionedTable. -func (b *batchChecker) getOldRow(ctx sessionctx.Context, t table.Table, handle int64) ([]types.Datum, error) { +func (b *batchChecker) getOldRow(ctx sessionctx.Context, t table.Table, handle int64, + genExprs []expression.Expression) ([]types.Datum, error) { oldValue, ok := b.dupOldRowValues[string(t.RecordKey(handle))] if !ok { return nil, errors.NotFoundf("can not be duplicated row, due to old row not found. handle %d", handle) @@ -295,6 +298,7 @@ func (b *batchChecker) getOldRow(ctx sessionctx.Context, t table.Table, handle i return nil, err } // Fill write-only and write-reorg columns with originDefaultValue if not found in oldValue. + gIdx := 0 for _, col := range cols { if col.State != model.StatePublic && oldRow[col.Offset].IsNull() { _, found := oldRowMap[col.ID] @@ -305,6 +309,20 @@ func (b *batchChecker) getOldRow(ctx sessionctx.Context, t table.Table, handle i } } } + if col.IsGenerated() { + // only the virtual column needs fill back. + if !col.GeneratedStored { + val, err := genExprs[gIdx].Eval(chunk.MutRowFromDatums(oldRow).ToRow()) + if err != nil { + return nil, err + } + oldRow[col.Offset], err = table.CastValue(ctx, val, col.ToInfo()) + if err != nil { + return nil, err + } + } + gIdx++ + } } return oldRow, nil } diff --git a/executor/insert.go b/executor/insert.go index fc43828c25550..e65681157d2ec 100644 --- a/executor/insert.go +++ b/executor/insert.go @@ -169,7 +169,7 @@ func (e *InsertExec) Open(ctx context.Context) error { // updateDupRow updates a duplicate row to a new row. func (e *InsertExec) updateDupRow(row toBeCheckedRow, handle int64, onDuplicate []*expression.Assignment) error { - oldRow, err := e.getOldRow(e.ctx, e.Table, handle) + oldRow, err := e.getOldRow(e.ctx, e.Table, handle, e.GenExprs) if err != nil { logutil.Logger(context.Background()).Error("get old row failed when insert on dup", zap.Int64("handle", handle), zap.String("toBeInsertedRow", types.DatumsToStrNoErr(row.row))) return err diff --git a/executor/replace.go b/executor/replace.go index d900744d25dd1..26d3875eddb4e 100644 --- a/executor/replace.go +++ b/executor/replace.go @@ -55,7 +55,7 @@ func (e *ReplaceExec) Open(ctx context.Context) error { // but if the to-be-removed row equals to the to-be-added row, no remove or add things to do. func (e *ReplaceExec) removeRow(handle int64, r toBeCheckedRow) (bool, error) { newRow := r.row - oldRow, err := e.batchChecker.getOldRow(e.ctx, r.t, handle) + oldRow, err := e.batchChecker.getOldRow(e.ctx, r.t, handle, e.GenExprs) if err != nil { logutil.Logger(context.Background()).Error("get old row failed when replace", zap.Int64("handle", handle), zap.String("toBeInsertedRow", types.DatumsToStrNoErr(r.row))) return false, err diff --git a/executor/write_test.go b/executor/write_test.go index 1c1ba50866b2a..798fec9133615 100644 --- a/executor/write_test.go +++ b/executor/write_test.go @@ -900,6 +900,47 @@ func (s *testSuite4) TestReplace(c *C) { tk.CheckLastMessage("Records: 1 Duplicates: 1 Warnings: 0") } +func (s *testSuite2) TestGeneratedColumnForInsert(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + + // test cases for default behavior + tk.MustExec(`drop table if exists t1;`) + tk.MustExec(`create table t1(id int, id_gen int as(id + 42), b int, unique key id_gen(id_gen));`) + tk.MustExec(`insert into t1 (id, b) values(1,1),(2,2),(3,3),(4,4),(5,5);`) + tk.MustExec(`replace into t1 (id, b) values(1,1);`) + tk.MustExec(`replace into t1 (id, b) values(1,1),(2,2);`) + tk.MustExec(`replace into t1 (id, b) values(6,16),(7,17),(8,18);`) + tk.MustQuery("select * from t1;").Check(testkit.Rows( + "1 43 1", "2 44 2", "3 45 3", "4 46 4", "5 47 5", "6 48 16", "7 49 17", "8 50 18")) + tk.MustExec(`insert into t1 (id, b) values (6,18) on duplicate key update id = -id;`) + tk.MustExec(`insert into t1 (id, b) values (7,28) on duplicate key update b = -values(b);`) + tk.MustQuery("select * from t1;").Check(testkit.Rows( + "1 43 1", "2 44 2", "3 45 3", "4 46 4", "5 47 5", "-6 36 16", "7 49 -28", "8 50 18")) + + // test cases for virtual and stored columns in the same table + tk.MustExec(`drop table if exists t`) + tk.MustExec(`create table t + (i int as(k+1) stored, j int as(k+2) virtual, k int, unique key idx_i(i), unique key idx_j(j))`) + tk.MustExec(`insert into t (k) values (1), (2)`) + tk.MustExec(`replace into t (k) values (1), (2)`) + tk.MustQuery(`select * from t`).Check(testkit.Rows("2 3 1", "3 4 2")) + + tk.MustExec(`drop table if exists t`) + tk.MustExec(`create table t + (i int as(k+1) stored, j int as(k+2) virtual, k int, unique key idx_j(j))`) + tk.MustExec(`insert into t (k) values (1), (2)`) + tk.MustExec(`replace into t (k) values (1), (2)`) + tk.MustQuery(`select * from t`).Check(testkit.Rows("2 3 1", "3 4 2")) + + tk.MustExec(`drop table if exists t`) + tk.MustExec(`create table t + (i int as(k+1) stored, j int as(k+2) virtual, k int, unique key idx_i(i))`) + tk.MustExec(`insert into t (k) values (1), (2)`) + tk.MustExec(`replace into t (k) values (1), (2)`) + tk.MustQuery(`select * from t`).Check(testkit.Rows("2 3 1", "3 4 2")) +} + func (s *testSuite4) TestPartitionedTableReplace(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test")