Skip to content

Commit

Permalink
*: fix writing null value into not null column in write-only state. (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
coocood authored Apr 10, 2018
1 parent 5200cf1 commit 2b2522b
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 9 deletions.
32 changes: 32 additions & 0 deletions ddl/ddl_db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"math/rand"
"strconv"
"strings"
"sync"
"time"

"github.com/juju/errors"
Expand Down Expand Up @@ -1811,3 +1812,34 @@ func (s *testDBSuite) TestCharacterSetInColumns(c *C) {
s.tk.MustQuery("select count(*) from information_schema.columns where table_schema = 'varchar_test' and character_set_name != 'utf8'").Check(testkit.Rows("0"))
s.tk.MustQuery("select count(*) from information_schema.columns where table_schema = 'varchar_test' and character_set_name = 'utf8'").Check(testkit.Rows("2"))
}

func (s *testDBSuite) TestAddNotNullColumnWhileInsertOnDupUpdate(c *C) {
tk1 := testkit.NewTestKit(c, s.store)
tk1.MustExec("use " + s.schemaName)
tk2 := testkit.NewTestKit(c, s.store)
tk2.MustExec("use " + s.schemaName)
closeCh := make(chan bool)
wg := new(sync.WaitGroup)
wg.Add(1)
tk1.MustExec("create table nn (a int primary key, b int)")
tk1.MustExec("insert nn values (1, 1)")
var tk2Err error
go func() {
defer wg.Done()
for {
select {
case <-closeCh:
return
default:
}
_, tk2Err = tk2.Exec("insert nn (a, b) values (1, 1) on duplicate key update a = 1, b = b + 1")
if tk2Err != nil {
return
}
}
}()
tk1.MustExec("alter table nn add column c int not null default 0")
close(closeCh)
wg.Wait()
c.Assert(tk2Err, IsNil)
}
17 changes: 15 additions & 2 deletions executor/write.go
Original file line number Diff line number Diff line change
Expand Up @@ -921,7 +921,7 @@ func batchGetOldValues(ctx sessionctx.Context, t table.Table, handles []int64) (
func encodeNewRow(ctx sessionctx.Context, t table.Table, row []types.Datum) ([]byte, error) {
colIDs := make([]int64, 0, len(row))
skimmedRow := make([]types.Datum, 0, len(row))
for _, col := range t.WritableCols() {
for _, col := range t.Cols() {
if !tables.CanSkip(t.Meta(), col, row[col.Offset]) {
colIDs = append(colIDs, col.ID)
skimmedRow = append(skimmedRow, row[col.Offset])
Expand Down Expand Up @@ -1107,10 +1107,23 @@ func (e *InsertExec) updateDupRow(keys []keyWithDupError, k keyWithDupError, val
if !ok {
return errors.NotFoundf("can not be duplicated row, due to old row not found. handle %d", oldHandle)
}
oldRow, err := tables.DecodeRawRowData(e.ctx, e.Table.Meta(), oldHandle, e.Table.WritableCols(), oldValue)
cols := e.Table.WritableCols()
oldRow, oldRowMap, err := tables.DecodeRawRowData(e.ctx, e.Table.Meta(), oldHandle, cols, oldValue)
if err != nil {
return errors.Trace(err)
}
// Fill write-only and write-reorg columns with originDefaultValue if not found in oldValue.
for _, col := range cols {
if col.State != model.StatePublic && oldRow[col.Offset].IsNull() {
_, found := oldRowMap[col.ID]
if !found {
oldRow[col.Offset], err = table.GetColOriginDefaultValue(e.ctx, col.ToInfo())
if err != nil {
return errors.Trace(err)
}
}
}
}

// Do update row.
updatedRow, handleChanged, newHandle, err := e.doDupRowUpdate(oldHandle, oldRow, newRow, onDuplicate)
Expand Down
13 changes: 6 additions & 7 deletions table/tables/tables.go
Original file line number Diff line number Diff line change
Expand Up @@ -506,17 +506,16 @@ func (t *Table) RowWithCols(ctx sessionctx.Context, h int64, cols []*table.Colum
if err != nil {
return nil, errors.Trace(err)
}
v, err := DecodeRawRowData(ctx, t.Meta(), h, cols, value)
v, _, err := DecodeRawRowData(ctx, t.Meta(), h, cols, value)
if err != nil {
return nil, errors.Trace(err)
}
return v, nil
}

// DecodeRawRowData decodes raw row data to a datum row.
// DecodeRawRowData decodes raw row data into a datum slice and a (columnID:columnValue) map.
func DecodeRawRowData(ctx sessionctx.Context, meta *model.TableInfo, h int64, cols []*table.Column,
value []byte) ([]types.Datum,
error) {
value []byte) ([]types.Datum, map[int64]types.Datum, error) {
v := make([]types.Datum, len(cols))
colTps := make(map[int64]*types.FieldType, len(cols))
for i, col := range cols {
Expand All @@ -535,7 +534,7 @@ func DecodeRawRowData(ctx sessionctx.Context, meta *model.TableInfo, h int64, co
}
rowMap, err := tablecodec.DecodeRow(value, colTps, ctx.GetSessionVars().GetTimeZone())
if err != nil {
return nil, errors.Trace(err)
return nil, rowMap, errors.Trace(err)
}
defaultVals := make([]types.Datum, len(cols))
for i, col := range cols {
Expand All @@ -552,10 +551,10 @@ func DecodeRawRowData(ctx sessionctx.Context, meta *model.TableInfo, h int64, co
}
v[i], err = GetColDefaultValue(ctx, col, defaultVals)
if err != nil {
return nil, errors.Trace(err)
return nil, rowMap, errors.Trace(err)
}
}
return v, nil
return v, rowMap, nil
}

// Row implements table.Table Row interface.
Expand Down

0 comments on commit 2b2522b

Please sign in to comment.