Skip to content

Commit

Permalink
types: use flags in types package to handle clip zero case
Browse files Browse the repository at this point in the history
  • Loading branch information
lcwangchao committed Oct 11, 2023
1 parent 1f795d2 commit ca41ccd
Show file tree
Hide file tree
Showing 13 changed files with 69 additions and 46 deletions.
4 changes: 4 additions & 0 deletions br/pkg/lightning/backend/kv/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
"github.com/pingcap/tidb/parser/model"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/mathutil"
"github.com/pingcap/tidb/util/topsql/stmtstats"
"go.uber.org/zap"
Expand Down Expand Up @@ -312,6 +313,9 @@ func NewSession(options *encode.SessionOptions, logger log.Logger) *Session {
}
}
vars.StmtCtx.TimeZone = vars.Location()
vars.StmtCtx.ResetTypeContext(types.StrictFlags.
WithClipNegativeToZero(true),
)
if err := vars.SetSystemVar("timestamp", strconv.FormatInt(options.Timestamp, 10)); err != nil {
logger.Warn("new session: failed to set timestamp",
log.ShortError(err))
Expand Down
1 change: 0 additions & 1 deletion ddl/backfilling_scheduler.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,6 @@ func initSessCtx(
tz := *time.UTC
sessCtx.GetSessionVars().StmtCtx.TimeZone = &tz
}
sessCtx.GetSessionVars().StmtCtx.IsDDLJobInQueue = true
// Set the row encode format version.
rowFormat := variable.GetDDLReorgRowFormat()
sessCtx.GetSessionVars().RowEncoder.Enable = rowFormat != variable.DefTiDBRowFormatV1
Expand Down
11 changes: 7 additions & 4 deletions executor/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -1942,7 +1942,7 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) {
} else {
sc = vars.InitStatementContext()
}
var typeConvFlags types.Flags
var typeFlags types.Flags
sc.TimeZone = vars.Location()
sc.TaskID = stmtctx.AllocateTaskID()
sc.CTEStorageMap = map[int]*CTEStorages{}
Expand Down Expand Up @@ -2139,10 +2139,14 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) {
sc.AllowInvalidDate = vars.SQLMode.HasAllowInvalidDatesMode()
}

typeConvFlags = typeConvFlags.
typeFlags = typeFlags.
WithSkipUTF8Check(vars.SkipUTF8Check).
WithSkipSACIICheck(vars.SkipASCIICheck).
WithSkipUTF8MB4Check(!globalConfig.Instance.CheckMb4ValueInUTF8.Load())
WithSkipUTF8MB4Check(!globalConfig.Instance.CheckMb4ValueInUTF8.Load()).
// WithClipNegativeToZero indicates whether values less than 0 should be clipped to 0 for unsigned integer types.
// This is the case for `insert`, `update`, `alter table`, `create table` and `load data infile` statements, when not in strict SQL mode.
// see https://dev.mysql.com/doc/refman/5.7/en/out-of-range-and-overflow.html
WithClipNegativeToZero(sc.InInsertStmt || sc.InLoadDataStmt || sc.InUpdateStmt || sc.InCreateOrAlterStmt)

vars.PlanCacheParams.Reset()
if priority := mysql.PriorityEnum(atomic.LoadInt32(&variable.ForcePriority)); priority != mysql.NoPriority {
Expand All @@ -2169,7 +2173,6 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) {
sc.RuntimeStatsColl = execdetails.NewRuntimeStatsColl(reuseObj)
}

sc.TypeConvContext = types.NewContext(typeConvFlags, sc.TimeZone, sc.AppendWarning)
sc.TblInfo2UnionScan = make(map[*model.TableInfo]bool)
errCount, warnCount := vars.StmtCtx.NumErrorWarnings()
vars.SysErrorCount = errCount
Expand Down
2 changes: 1 addition & 1 deletion expression/builtin_cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -974,7 +974,7 @@ func (b *builtinCastRealAsIntSig) evalInt(row chunk.Row) (res int64, isNull bool
} else {
var uintVal uint64
sc := b.ctx.GetSessionVars().StmtCtx
uintVal, err = types.ConvertFloatToUint(sc, val, types.IntergerUnsignedUpperBound(mysql.TypeLonglong), mysql.TypeLonglong)
uintVal, err = types.ConvertFloatToUint(sc.TypeFlags(), val, types.IntergerUnsignedUpperBound(mysql.TypeLonglong), mysql.TypeLonglong)
res = int64(uintVal)
}
if types.ErrOverflow.Equal(err) {
Expand Down
2 changes: 1 addition & 1 deletion expression/builtin_cast_vec.go
Original file line number Diff line number Diff line change
Expand Up @@ -766,7 +766,7 @@ func (b *builtinCastRealAsIntSig) vecEvalInt(input *chunk.Chunk, result *chunk.C
} else {
var uintVal uint64
sc := b.ctx.GetSessionVars().StmtCtx
uintVal, err = types.ConvertFloatToUint(sc, f64s[i], types.IntergerUnsignedUpperBound(mysql.TypeLonglong), mysql.TypeLonglong)
uintVal, err = types.ConvertFloatToUint(sc.TypeFlags(), f64s[i], types.IntergerUnsignedUpperBound(mysql.TypeLonglong), mysql.TypeLonglong)
i64s[i] = int64(uintVal)
}
if types.ErrOverflow.Equal(err) {
Expand Down
20 changes: 13 additions & 7 deletions sessionctx/stmtctx/stmtctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -1102,13 +1102,6 @@ func (sc *StatementContext) GetExecDetails() execdetails.ExecDetails {
return details
}

// ShouldClipToZero indicates whether values less than 0 should be clipped to 0 for unsigned integer types.
// This is the case for `insert`, `update`, `alter table`, `create table` and `load data infile` statements, when not in strict SQL mode.
// see https://dev.mysql.com/doc/refman/5.7/en/out-of-range-and-overflow.html
func (sc *StatementContext) ShouldClipToZero() bool {
return sc.InInsertStmt || sc.InLoadDataStmt || sc.InUpdateStmt || sc.InCreateOrAlterStmt || sc.IsDDLJobInQueue
}

// ShouldIgnoreOverflowError indicates whether we should ignore the error when type conversion overflows,
// so we can leave it for further processing like clipping values less than 0 to 0 for unsigned integer types.
func (sc *StatementContext) ShouldIgnoreOverflowError() bool {
Expand Down Expand Up @@ -1201,9 +1194,22 @@ func (sc *StatementContext) SetFlagsFromPBFlag(flags uint64) {
sc.TruncateAsWarning = (flags & model.FlagTruncateAsWarning) > 0
sc.InInsertStmt = (flags & model.FlagInInsertStmt) > 0
sc.InSelectStmt = (flags & model.FlagInSelectStmt) > 0
sc.InDeleteStmt = (flags & model.FlagInUpdateOrDeleteStmt) > 0
sc.OverflowAsWarning = (flags & model.FlagOverflowAsWarning) > 0
sc.IgnoreZeroInDate = (flags & model.FlagIgnoreZeroInDate) > 0
sc.DividedByZeroAsWarning = (flags & model.FlagDividedByZeroAsWarning) > 0
typeFlags := typectx.StrictFlags.WithClipNegativeToZero(sc.InInsertStmt)
sc.ResetTypeContext(typeFlags)
}

// TypeFlags returns the flags used by types package.
func (sc *StatementContext) TypeFlags() typectx.Flags {
return sc.TypeConvContext.Flags()
}

// ResetTypeContext resets the inner type context.
func (sc *StatementContext) ResetTypeContext(flags typectx.Flags) {
sc.TypeConvContext = typectx.NewContext(flags, sc.TimeZone, sc.AppendWarning)
}

// GetLockWaitStartTime returns the statement pessimistic lock wait start time
Expand Down
10 changes: 1 addition & 9 deletions store/mockstore/mockcopr/cop_handler_dag.go
Original file line number Diff line number Diff line change
Expand Up @@ -468,15 +468,7 @@ func (e *evalContext) decodeRelatedColumnVals(relatedColOffsets []int, value [][
// flagsToStatementContext creates a StatementContext from a `tipb.SelectRequest.Flags`.
func flagsToStatementContext(flags uint64) *stmtctx.StatementContext {
sc := new(stmtctx.StatementContext)
sc.IgnoreTruncate.Store((flags & model.FlagIgnoreTruncate) > 0)
sc.TruncateAsWarning = (flags & model.FlagTruncateAsWarning) > 0
sc.InInsertStmt = (flags & model.FlagInInsertStmt) > 0
sc.InSelectStmt = (flags & model.FlagInSelectStmt) > 0
sc.InDeleteStmt = (flags & model.FlagInUpdateOrDeleteStmt) > 0
sc.OverflowAsWarning = (flags & model.FlagOverflowAsWarning) > 0
sc.IgnoreZeroInDate = (flags & model.FlagIgnoreZeroInDate) > 0
sc.DividedByZeroAsWarning = (flags & model.FlagDividedByZeroAsWarning) > 0
// TODO set FlagInSetOprStmt,
sc.SetFlagsFromPBFlag(flags)
return sc
}

Expand Down
9 changes: 1 addition & 8 deletions store/mockstore/unistore/cophandler/cop_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -423,14 +423,7 @@ func newRowDecoder(columnInfos []*tipb.ColumnInfo, fieldTps []*types.FieldType,
// flagsToStatementContext creates a StatementContext from a `tipb.SelectRequest.Flags`.
func flagsToStatementContext(flags uint64) *stmtctx.StatementContext {
sc := new(stmtctx.StatementContext)
sc.IgnoreTruncate.Store((flags & model.FlagIgnoreTruncate) > 0)
sc.TruncateAsWarning = (flags & model.FlagTruncateAsWarning) > 0
sc.InInsertStmt = (flags & model.FlagInInsertStmt) > 0
sc.InSelectStmt = (flags & model.FlagInSelectStmt) > 0
sc.InDeleteStmt = (flags & model.FlagInUpdateOrDeleteStmt) > 0
sc.OverflowAsWarning = (flags & model.FlagOverflowAsWarning) > 0
sc.IgnoreZeroInDate = (flags & model.FlagIgnoreZeroInDate) > 0
sc.DividedByZeroAsWarning = (flags & model.FlagDividedByZeroAsWarning) > 0
sc.SetFlagsFromPBFlag(flags)
return sc
}

Expand Down
2 changes: 1 addition & 1 deletion table/column.go
Original file line number Diff line number Diff line change
Expand Up @@ -713,7 +713,7 @@ func FillVirtualColumnValue(virtualRetTypes []*types.FieldType, virtualColumnInd
}

// Clip to zero if get negative value after cast to unsigned.
if mysql.HasUnsignedFlag(colInfos[idx].FieldType.GetFlag()) && !castDatum.IsNull() && !sctx.GetSessionVars().StmtCtx.ShouldClipToZero() {
if mysql.HasUnsignedFlag(colInfos[idx].FieldType.GetFlag()) && !castDatum.IsNull() && !sctx.GetSessionVars().StmtCtx.TypeFlags().ClipNegativeToZero() {
switch datum.Kind() {
case types.KindInt64:
if datum.GetInt64() < 0 {
Expand Down
16 changes: 16 additions & 0 deletions types/context/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,19 @@ const (
FlagSkipUTF8MB4Check
)

// ClipNegativeToZero indicates whether the flag `FlagClipNegativeToZero` is set
func (f Flags) ClipNegativeToZero() bool {
return f&FlagClipNegativeToZero != 0
}

// WithClipNegativeToZero returns a new flags with `FlagClipNegativeToZero` set/unset according to the clip parameter
func (f Flags) WithClipNegativeToZero(clip bool) Flags {
if clip {
return f | FlagClipNegativeToZero
}
return f &^ FlagClipNegativeToZero
}

// SkipASCIICheck indicates whether the flag `FlagSkipASCIICheck` is set
func (f Flags) SkipASCIICheck() bool {
return f&FlagSkipASCIICheck != 0
Expand Down Expand Up @@ -130,6 +143,9 @@ func (c *Context) WithFlags(f Flags) Context {

// Location returns the location of the context
func (c *Context) Location() *time.Location {
if c.loc == nil {
return time.UTC
}
return c.loc
}

Expand Down
16 changes: 13 additions & 3 deletions types/context/context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,23 @@ func TestWithNewFlags(t *testing.T) {
require.Equal(t, time.UTC, ctx2.Location())
}

func TestStringFlags(t *testing.T) {
func TestSimpleOnOffFlags(t *testing.T) {
cases := []struct {
name string
flag Flags
readFn func(f Flags) bool
writeFn func(f Flags, skip bool) Flags
readFn func(Flags) bool
writeFn func(Flags, bool) Flags
}{
{
name: "FlagClipNegativeToZero",
flag: FlagClipNegativeToZero,
readFn: func(f Flags) bool {
return f.ClipNegativeToZero()
},
writeFn: func(f Flags, clip bool) Flags {
return f.WithClipNegativeToZero(clip)
},
},
{
name: "FlagSkipASCIICheck",
flag: FlagSkipASCIICheck,
Expand Down
12 changes: 6 additions & 6 deletions types/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,8 @@ func ConvertUintToInt(val uint64, upperBound int64, tp byte) (int64, error) {
}

// ConvertIntToUint converts an int value to an uint value.
func ConvertIntToUint(sc *stmtctx.StatementContext, val int64, upperBound uint64, tp byte) (uint64, error) {
if sc.ShouldClipToZero() && val < 0 {
func ConvertIntToUint(flags Flags, val int64, upperBound uint64, tp byte) (uint64, error) {
if flags.ClipNegativeToZero() && val < 0 {
return 0, overflow(val, tp)
}

Expand All @@ -167,10 +167,10 @@ func ConvertUintToUint(val uint64, upperBound uint64, tp byte) (uint64, error) {
}

// ConvertFloatToUint converts a float value to an uint value.
func ConvertFloatToUint(sc *stmtctx.StatementContext, fval float64, upperBound uint64, tp byte) (uint64, error) {
func ConvertFloatToUint(flags Flags, fval float64, upperBound uint64, tp byte) (uint64, error) {
val := RoundFloat(fval)
if val < 0 {
if sc.ShouldClipToZero() {
if flags.ClipNegativeToZero() {
return 0, overflow(val, tp)
}
return uint64(int64(val)), overflow(val, tp)
Expand Down Expand Up @@ -585,7 +585,7 @@ func ConvertJSONToInt(sc *stmtctx.StatementContext, j BinaryJSON, unsigned bool,
i := j.GetInt64()
if unsigned {
uBound := IntergerUnsignedUpperBound(tp)
u, err := ConvertIntToUint(sc, i, uBound, tp)
u, err := ConvertIntToUint(sc.TypeFlags(), i, uBound, tp)
return int64(u), sc.HandleOverflow(err, err)
}

Expand Down Expand Up @@ -613,7 +613,7 @@ func ConvertJSONToInt(sc *stmtctx.StatementContext, j BinaryJSON, unsigned bool,
return u, sc.HandleOverflow(e, e)
}
bound := IntergerUnsignedUpperBound(tp)
u, err := ConvertFloatToUint(sc, f, bound, tp)
u, err := ConvertFloatToUint(sc.TypeFlags(), f, bound, tp)
return int64(u), sc.HandleOverflow(err, err)
case JSONTypeCodeString:
str := string(hack.String(j.GetString()))
Expand Down
10 changes: 5 additions & 5 deletions types/datum.go
Original file line number Diff line number Diff line change
Expand Up @@ -1186,11 +1186,11 @@ func (d *Datum) convertToUint(sc *stmtctx.StatementContext, target *FieldType) (
)
switch d.k {
case KindInt64:
val, err = ConvertIntToUint(sc, d.GetInt64(), upperBound, tp)
val, err = ConvertIntToUint(sc.TypeFlags(), d.GetInt64(), upperBound, tp)
case KindUint64:
val, err = ConvertUintToUint(d.GetUint64(), upperBound, tp)
case KindFloat32, KindFloat64:
val, err = ConvertFloatToUint(sc, d.GetFloat64(), upperBound, tp)
val, err = ConvertFloatToUint(sc.TypeFlags(), d.GetFloat64(), upperBound, tp)
case KindString, KindBytes:
uval, err1 := StrToUint(sc, d.GetString(), false)
if err1 != nil && ErrOverflow.Equal(err1) && !sc.ShouldIgnoreOverflowError() {
Expand All @@ -1207,7 +1207,7 @@ func (d *Datum) convertToUint(sc *stmtctx.StatementContext, target *FieldType) (
if err == nil {
err = err1
}
val, err1 = ConvertIntToUint(sc, ival, upperBound, tp)
val, err1 = ConvertIntToUint(sc.TypeFlags(), ival, upperBound, tp)
if err == nil {
err = err1
}
Expand All @@ -1222,9 +1222,9 @@ func (d *Datum) convertToUint(sc *stmtctx.StatementContext, target *FieldType) (
case KindMysqlDecimal:
val, err = ConvertDecimalToUint(sc, d.GetMysqlDecimal(), upperBound, tp)
case KindMysqlEnum:
val, err = ConvertFloatToUint(sc, d.GetMysqlEnum().ToNumber(), upperBound, tp)
val, err = ConvertFloatToUint(sc.TypeFlags(), d.GetMysqlEnum().ToNumber(), upperBound, tp)
case KindMysqlSet:
val, err = ConvertFloatToUint(sc, d.GetMysqlSet().ToNumber(), upperBound, tp)
val, err = ConvertFloatToUint(sc.TypeFlags(), d.GetMysqlSet().ToNumber(), upperBound, tp)
case KindBinaryLiteral, KindMysqlBit:
val, err = d.GetBinaryLiteral().ToInt(sc)
if err == nil {
Expand Down

0 comments on commit ca41ccd

Please sign in to comment.