From 8ec4cf000c0805867b1cd0474d77a39a083479fd Mon Sep 17 00:00:00 2001 From: Jack Yu Date: Tue, 17 Apr 2018 11:36:49 +0800 Subject: [PATCH] cherry pick https://github.com/pingcap/tidb/pull/6249 (#6280) --- ddl/ddl_db_change_test.go | 8 ++++---- ddl/ddl_db_test.go | 32 ++++++++++++++++++++++++++++++++ executor/write.go | 25 ++++++++++++++++++++++++- table/tables/tables.go | 25 ++++++++++++++++++------- 4 files changed, 78 insertions(+), 12 deletions(-) diff --git a/ddl/ddl_db_change_test.go b/ddl/ddl_db_change_test.go index 3635979aeb6e0..a6cb2a618120b 100644 --- a/ddl/ddl_db_change_test.go +++ b/ddl/ddl_db_change_test.go @@ -375,13 +375,13 @@ func (s *testStateChangeSuite) TestParallelDDL(c *C) { return } var qLen int64 - var err error + var err1 error for { kv.RunInNewTxn(s.store, false, func(txn kv.Transaction) error { m := meta.NewMeta(txn) - qLen, err = m.DDLJobQueueLen() - if err != nil { - return err + qLen, err1 = m.DDLJobQueueLen() + if err1 != nil { + return err1 } return nil }) diff --git a/ddl/ddl_db_test.go b/ddl/ddl_db_test.go index fe36f5642ac82..959085afb1b15 100644 --- a/ddl/ddl_db_test.go +++ b/ddl/ddl_db_test.go @@ -20,6 +20,7 @@ import ( "math/rand" "strconv" "strings" + "sync" "time" "github.com/juju/errors" @@ -1726,3 +1727,34 @@ func (s *testDBSuite) TestRebaseAutoID(c *C) { s.tk.MustExec("create table tidb.test2 (a int);") s.testErrorCode(c, "alter table tidb.test2 add column b int auto_increment key, auto_increment=10;", tmysql.ErrUnknown) } + +func (s *testDBSuite) TestAddNotNullColumnWhileInsertOnDupUpdate(c *C) { + tk1 := testkit.NewTestKit(c, s.store) + tk1.MustExec("use " + s.schemaName) + tk2 := testkit.NewTestKit(c, s.store) + tk2.MustExec("use " + s.schemaName) + closeCh := make(chan bool) + wg := new(sync.WaitGroup) + wg.Add(1) + tk1.MustExec("create table nn (a int primary key, b int)") + tk1.MustExec("insert nn values (1, 1)") + var tk2Err error + go func() { + defer wg.Done() + for { + select { + case <-closeCh: + return + default: + } + _, tk2Err = tk2.Exec("insert nn (a, b) values (1, 1) on duplicate key update a = 1, b = b + 1") + if tk2Err != nil { + return + } + } + }() + tk1.MustExec("alter table nn add column c int not null default 0") + close(closeCh) + wg.Wait() + c.Assert(tk2Err, IsNil) +} diff --git a/executor/write.go b/executor/write.go index af5d5cd589314..30015e8000f53 100644 --- a/executor/write.go +++ b/executor/write.go @@ -24,8 +24,10 @@ import ( "github.com/pingcap/tidb/context" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/model" "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/table" + "github.com/pingcap/tidb/table/tables" "github.com/pingcap/tidb/util/types" ) @@ -1302,10 +1304,31 @@ func (e *InsertValues) adjustAutoIncrementDatum(row []types.Datum, i int, c *tab return nil } +// rowWithCols is **ONLY** used for onDuplicateUpdate to avoid null value inserted to not null column. +func (e *InsertExec) rowWithCols(h int64, cols []*table.Column) ([]types.Datum, error) { + oldRow, oldRowMap, err := tables.RowWithColsForRead(e.ctx, e.Table, h, cols) + if err != nil { + return nil, errors.Trace(err) + } + // Fill write-only and write-reorg columns with originDefaultValue if not found in oldValue. + for _, col := range cols { + if col.State != model.StatePublic && oldRow[col.Offset].IsNull() { + _, found := oldRowMap[col.ID] + if !found { + oldRow[col.Offset], err = table.GetColOriginDefaultValue(e.ctx, col.ToInfo()) + if err != nil { + return nil, errors.Trace(err) + } + } + } + } + return oldRow, nil +} + // onDuplicateUpdate updates the duplicate row. // TODO: Report rows affected and last insert id. func (e *InsertExec) onDuplicateUpdate(row []types.Datum, h int64, cols []*expression.Assignment) error { - data, err := e.Table.RowWithCols(e.ctx, h, e.Table.WritableCols()) + data, err := e.rowWithCols(h, e.Table.WritableCols()) if err != nil { return errors.Trace(err) } diff --git a/table/tables/tables.go b/table/tables/tables.go index 5a524626e469a..dda8ec8b72330 100644 --- a/table/tables/tables.go +++ b/table/tables/tables.go @@ -470,13 +470,15 @@ func (t *Table) addIndices(ctx context.Context, recordID int64, r []types.Datum, return 0, nil } -// RowWithCols implements table.Table RowWithCols interface. -func (t *Table) RowWithCols(ctx context.Context, h int64, cols []*table.Column) ([]types.Datum, error) { +// RowWithColsForRead is used to return a row get from the storage, it does not set the default value for non public columns, +// because non public columns should not be returned as a result. +// It also return a row value map for duplicate update which needs to set a default value for non public columns to do the update. +func RowWithColsForRead(ctx context.Context, t table.Table, h int64, cols []*table.Column) ([]types.Datum, map[int64]types.Datum, error) { // Get raw row data from kv. key := t.RecordKey(h) value, err := ctx.Txn().Get(key) if err != nil { - return nil, errors.Trace(err) + return nil, nil, errors.Trace(err) } // Decode raw row data. v := make([]types.Datum, len(cols)) @@ -485,7 +487,7 @@ func (t *Table) RowWithCols(ctx context.Context, h int64, cols []*table.Column) if col == nil { continue } - if col.IsPKHandleColumn(t.meta) { + if col.IsPKHandleColumn(t.Meta()) { if mysql.HasUnsignedFlag(col.Flag) { v[i].SetUint64(uint64(h)) } else { @@ -497,14 +499,14 @@ func (t *Table) RowWithCols(ctx context.Context, h int64, cols []*table.Column) } rowMap, err := tablecodec.DecodeRow(value, colTps, ctx.GetSessionVars().GetTimeZone()) if err != nil { - return nil, errors.Trace(err) + return nil, nil, errors.Trace(err) } defaultVals := make([]types.Datum, len(cols)) for i, col := range cols { if col == nil { continue } - if col.IsPKHandleColumn(t.meta) { + if col.IsPKHandleColumn(t.Meta()) { continue } ri, ok := rowMap[col.ID] @@ -514,9 +516,18 @@ func (t *Table) RowWithCols(ctx context.Context, h int64, cols []*table.Column) } v[i], err = GetColDefaultValue(ctx, col, defaultVals) if err != nil { - return nil, errors.Trace(err) + return nil, nil, errors.Trace(err) } } + return v, rowMap, nil +} + +// RowWithCols implements table.Table RowWithCols interface. +func (t *Table) RowWithCols(ctx context.Context, h int64, cols []*table.Column) ([]types.Datum, error) { + v, _, err := RowWithColsForRead(ctx, t, h, cols) + if err != nil { + return nil, errors.Trace(err) + } return v, nil }