Skip to content

Commit

Permalink
cherry pick pingcap#11389
Browse files Browse the repository at this point in the history
  • Loading branch information
bb7133 committed Jun 11, 2021
1 parent 29f559a commit bd86369
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 9 deletions.
32 changes: 25 additions & 7 deletions executor/insert_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
package executor

import (
"context"
"fmt"
"math"

"github.com/pingcap/errors"
"github.com/pingcap/parser/ast"
Expand All @@ -28,7 +30,6 @@ import (
"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tidb/util/logutil"
"go.uber.org/zap"
"golang.org/x/net/context"
)

// InsertValues is the data to insert.
Expand Down Expand Up @@ -470,12 +471,10 @@ func (e *InsertValues) adjustAutoIncrementDatum(d types.Datum, hasValue bool, c
d.SetNull()
}
if !d.IsNull() {
sc := e.ctx.GetSessionVars().StmtCtx
datum, err1 := d.ConvertTo(sc, &c.FieldType)
if e.filterErr(err1) != nil {
return types.Datum{}, err1
recordID, err = getAutoRecordID(d, &c.FieldType, true)
if err != nil {
return types.Datum{}, err
}
recordID = datum.GetInt64()
}
// Use the value if it's not null and not 0.
if recordID != 0 {
Expand All @@ -485,7 +484,6 @@ func (e *InsertValues) adjustAutoIncrementDatum(d types.Datum, hasValue bool, c
}
e.ctx.GetSessionVars().StmtCtx.InsertID = uint64(recordID)
retryInfo.AddAutoIncrementID(recordID)
d.SetAutoID(recordID, c.Flag)
return d, nil
}

Expand Down Expand Up @@ -513,6 +511,26 @@ func (e *InsertValues) adjustAutoIncrementDatum(d types.Datum, hasValue bool, c
return casted, nil
}

func getAutoRecordID(d types.Datum, target *types.FieldType, isInsert bool) (int64, error) {
var recordID int64

switch target.Tp {
case mysql.TypeFloat, mysql.TypeDouble:
f := d.GetFloat64()
if isInsert {
recordID = int64(math.Round(f))
} else {
recordID = int64(f)
}
case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong:
recordID = d.GetInt64()
default:
return 0, errors.Errorf("unexpected field type [%v]", target.Tp)
}

return recordID, nil
}

func (e *InsertValues) handleWarning(err error, logInfo string) {
sc := e.ctx.GetSessionVars().StmtCtx
sc.AppendWarning(err)
Expand Down
8 changes: 6 additions & 2 deletions executor/write.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,12 @@ func updateRecord(ctx sessionctx.Context, h int64, oldData, newData []types.Datu
modified[i] = true
// Rebase auto increment id if the field is changed.
if mysql.HasAutoIncrementFlag(col.Flag) {
if err = t.RebaseAutoID(ctx, newData[i].GetInt64(), true); err != nil {
return false, false, 0, errors.Trace(err)
recordID, err := getAutoRecordID(newData[i], &col.FieldType, false)
if err != nil {
return false, false, 0, err
}
if err = t.RebaseAutoID(ctx, recordID, true); err != nil {
return false, false, 0, err
}
}
if col.IsPKHandleColumn(t.Meta()) {
Expand Down

0 comments on commit bd86369

Please sign in to comment.