Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

types: use flags in types package to handle clip zero case #47543

Merged
merged 6 commits into from
Oct 19, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions br/pkg/lightning/backend/kv/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
"github.com/pingcap/tidb/pkg/parser/model"
"github.com/pingcap/tidb/pkg/sessionctx"
"github.com/pingcap/tidb/pkg/sessionctx/variable"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/topsql/stmtstats"
"go.uber.org/zap"
)
Expand Down Expand Up @@ -313,6 +314,9 @@ func NewSession(options *encode.SessionOptions, logger log.Logger) *Session {
}
}
vars.StmtCtx.SetTimeZone(vars.Location())
vars.StmtCtx.SetTypeFlags(types.StrictFlags.
WithClipNegativeToZero(true),
)
if err := vars.SetSystemVar("timestamp", strconv.FormatInt(options.Timestamp, 10)); err != nil {
logger.Warn("new session: failed to set timestamp",
log.ShortError(err))
Expand Down
16 changes: 6 additions & 10 deletions pkg/ddl/backfilling_scheduler.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ import (
"context"
"fmt"
"sync"
"time"

"github.com/pingcap/errors"
"github.com/pingcap/tidb/pkg/ddl/copr"
Expand All @@ -33,6 +32,7 @@ import (
"github.com/pingcap/tidb/pkg/sessionctx"
"github.com/pingcap/tidb/pkg/sessionctx/variable"
"github.com/pingcap/tidb/pkg/table"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util"
"github.com/pingcap/tidb/pkg/util/dbterror"
"github.com/pingcap/tidb/pkg/util/intest"
Expand Down Expand Up @@ -148,12 +148,6 @@ func initSessCtx(
sqlMode mysql.SQLMode,
tzLocation *model.TimeZoneLocation,
) error {
// Unify the TimeZone settings in newContext.
if sessCtx.GetSessionVars().StmtCtx.TimeZone() == nil {
tz := *time.UTC
sessCtx.GetSessionVars().StmtCtx.SetTimeZone(&tz)
}
sessCtx.GetSessionVars().StmtCtx.IsDDLJobInQueue = true
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this change expected? Though it seems that it's only used in pkg/executor/ddl.go 🤔 .

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think IsDDLJobInQueue is used to make ShouldClipToZeror in the previous code. The verify CI passed after removing it.. @tangenta PTAL

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC, pkg/executor/ddl.go is not related with pkg/ddl/backfilling_scheduler.go 🤔 , so it seems fine for this case.

// Set the row encode format version.
rowFormat := variable.GetDDLReorgRowFormat()
sessCtx.GetSessionVars().RowEncoder.Enable = rowFormat != variable.DefTiDBRowFormatV1
Expand All @@ -162,15 +156,17 @@ func initSessCtx(
if err := setSessCtxLocation(sessCtx, tzLocation); err != nil {
return errors.Trace(err)
}
sessCtx.GetSessionVars().StmtCtx.SetTimeZone(sessCtx.GetSessionVars().Location())
sessCtx.GetSessionVars().StmtCtx.BadNullAsWarning = !sqlMode.HasStrictMode()
sessCtx.GetSessionVars().StmtCtx.OverflowAsWarning = !sqlMode.HasStrictMode()
sessCtx.GetSessionVars().StmtCtx.AllowInvalidDate = sqlMode.HasAllowInvalidDatesMode()
sessCtx.GetSessionVars().StmtCtx.DividedByZeroAsWarning = !sqlMode.HasStrictMode()
sessCtx.GetSessionVars().StmtCtx.IgnoreZeroInDate = !sqlMode.HasStrictMode() || sqlMode.HasAllowInvalidDatesMode()
sessCtx.GetSessionVars().StmtCtx.NoZeroDate = sqlMode.HasStrictMode()

typeFlags := sessCtx.GetSessionVars().StmtCtx.TypeFlags().WithTruncateAsWarning(!sqlMode.HasStrictMode())
sessCtx.GetSessionVars().StmtCtx.SetTypeFlags(typeFlags)
sessCtx.GetSessionVars().StmtCtx.SetTypeFlags(types.StrictFlags.
WithTruncateAsWarning(!sqlMode.HasStrictMode()).
WithClipNegativeToZero(true),
)

// Prevent initializing the mock context in the workers concurrently.
// For details, see https://github.com/pingcap/tidb/issues/40879.
Expand Down
7 changes: 6 additions & 1 deletion pkg/executor/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -2148,7 +2148,12 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) {
sc.SetTypeFlags(sc.TypeFlags().
WithSkipUTF8Check(vars.SkipUTF8Check).
WithSkipSACIICheck(vars.SkipASCIICheck).
WithSkipUTF8MB4Check(!globalConfig.Instance.CheckMb4ValueInUTF8.Load()))
WithSkipUTF8MB4Check(!globalConfig.Instance.CheckMb4ValueInUTF8.Load()).
// WithClipNegativeToZero indicates whether values less than 0 should be clipped to 0 for unsigned integer types.
// This is the case for `insert`, `update`, `alter table`, `create table` and `load data infile` statements, when not in strict SQL mode.
// see https://dev.mysql.com/doc/refman/5.7/en/out-of-range-and-overflow.html
WithClipNegativeToZero(sc.InInsertStmt || sc.InLoadDataStmt || sc.InUpdateStmt || sc.InCreateOrAlterStmt),
)

vars.PlanCacheParams.Reset()
if priority := mysql.PriorityEnum(atomic.LoadInt32(&variable.ForcePriority)); priority != mysql.NoPriority {
Expand Down
6 changes: 4 additions & 2 deletions pkg/expression/builtin_cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,9 @@ var fakeSctx = newFakeSctx()

func newFakeSctx() *stmtctx.StatementContext {
sc := stmtctx.NewStmtCtx()
sc.InInsertStmt = true
sc.SetTypeFlags(types.StrictFlags.
WithClipNegativeToZero(true),
)
return sc
}

Expand Down Expand Up @@ -980,7 +982,7 @@ func (b *builtinCastRealAsIntSig) evalInt(row chunk.Row) (res int64, isNull bool
} else {
var uintVal uint64
sc := b.ctx.GetSessionVars().StmtCtx
uintVal, err = types.ConvertFloatToUint(sc, val, types.IntergerUnsignedUpperBound(mysql.TypeLonglong), mysql.TypeLonglong)
uintVal, err = types.ConvertFloatToUint(sc.TypeFlags(), val, types.IntergerUnsignedUpperBound(mysql.TypeLonglong), mysql.TypeLonglong)
res = int64(uintVal)
}
if types.ErrOverflow.Equal(err) {
Expand Down
2 changes: 1 addition & 1 deletion pkg/expression/builtin_cast_vec.go
Original file line number Diff line number Diff line change
Expand Up @@ -766,7 +766,7 @@ func (b *builtinCastRealAsIntSig) vecEvalInt(input *chunk.Chunk, result *chunk.C
} else {
var uintVal uint64
sc := b.ctx.GetSessionVars().StmtCtx
uintVal, err = types.ConvertFloatToUint(sc, f64s[i], types.IntergerUnsignedUpperBound(mysql.TypeLonglong), mysql.TypeLonglong)
uintVal, err = types.ConvertFloatToUint(sc.TypeFlags(), f64s[i], types.IntergerUnsignedUpperBound(mysql.TypeLonglong), mysql.TypeLonglong)
i64s[i] = int64(uintVal)
}
if types.ErrOverflow.Equal(err) {
Expand Down
22 changes: 4 additions & 18 deletions pkg/sessionctx/stmtctx/stmtctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -468,12 +468,6 @@ func (sc *StatementContext) SetTypeFlags(flags typectx.Flags) {
sc.TypeCtx = sc.TypeCtx.WithFlags(flags)
}

// UpdateTypeFlags updates the flags of the type context
func (sc *StatementContext) UpdateTypeFlags(fn func(typectx.Flags) typectx.Flags) {
flags := fn(sc.TypeCtx.Flags())
sc.TypeCtx = sc.TypeCtx.WithFlags(flags)
}

// StmtHints are SessionVars related sql hints.
type StmtHints struct {
// Hint Information
Expand Down Expand Up @@ -1123,13 +1117,6 @@ func (sc *StatementContext) GetExecDetails() execdetails.ExecDetails {
return details
}

// ShouldClipToZero indicates whether values less than 0 should be clipped to 0 for unsigned integer types.
// This is the case for `insert`, `update`, `alter table`, `create table` and `load data infile` statements, when not in strict SQL mode.
// see https://dev.mysql.com/doc/refman/5.7/en/out-of-range-and-overflow.html
func (sc *StatementContext) ShouldClipToZero() bool {
return sc.InInsertStmt || sc.InLoadDataStmt || sc.InUpdateStmt || sc.InCreateOrAlterStmt || sc.IsDDLJobInQueue
}

// ShouldIgnoreOverflowError indicates whether we should ignore the error when type conversion overflows,
// so we can leave it for further processing like clipping values less than 0 to 0 for unsigned integer types.
func (sc *StatementContext) ShouldIgnoreOverflowError() bool {
Expand Down Expand Up @@ -1226,12 +1213,11 @@ func (sc *StatementContext) InitFromPBFlagAndTz(flags uint64, tz *time.Location)
sc.IgnoreZeroInDate = (flags & model.FlagIgnoreZeroInDate) > 0
sc.DividedByZeroAsWarning = (flags & model.FlagDividedByZeroAsWarning) > 0
sc.SetTimeZone(tz)

typeFlags := sc.TypeCtx.Flags()
typeFlags = typeFlags.
sc.SetTypeFlags(typectx.StrictFlags.
lcwangchao marked this conversation as resolved.
Show resolved Hide resolved
WithIgnoreTruncateErr((flags & model.FlagIgnoreTruncate) > 0).
WithTruncateAsWarning((flags & model.FlagTruncateAsWarning) > 0)
sc.TypeCtx = typectx.NewContext(typeFlags, tz, sc.AppendWarning)
WithTruncateAsWarning((flags & model.FlagTruncateAsWarning) > 0).
WithClipNegativeToZero(sc.InInsertStmt),
)
}

// GetLockWaitStartTime returns the statement pessimistic lock wait start time
Expand Down
6 changes: 0 additions & 6 deletions pkg/sessionctx/stmtctx/stmtctx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -356,12 +356,6 @@ func TestSetStmtCtxTypeFlags(t *testing.T) {
sc.SetTypeFlags(typectx.FlagSkipASCIICheck | typectx.FlagSkipUTF8Check | typectx.FlagInvalidDateAsWarning)
require.Equal(t, typectx.FlagSkipASCIICheck|typectx.FlagSkipUTF8Check|typectx.FlagInvalidDateAsWarning, sc.TypeFlags())
require.Equal(t, sc.TypeFlags(), sc.TypeCtx.Flags())

sc.UpdateTypeFlags(func(flags typectx.Flags) typectx.Flags {
return (flags | typectx.FlagSkipUTF8Check | typectx.FlagClipNegativeToZero) &^ typectx.FlagSkipASCIICheck
})
require.Equal(t, typectx.FlagSkipUTF8Check|typectx.FlagClipNegativeToZero|typectx.FlagInvalidDateAsWarning, sc.TypeFlags())
require.Equal(t, sc.TypeFlags(), sc.TypeCtx.Flags())
}

func TestResetStmtCtx(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion pkg/store/mockstore/mockcopr/cop_handler_dag.go
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,7 @@ func (e *evalContext) decodeRelatedColumnVals(relatedColOffsets []int, value [][

// flagsAndTzToStatementContext creates a StatementContext from a `tipb.SelectRequest.Flags`.
func flagsAndTzToStatementContext(flags uint64, tz *time.Location) *stmtctx.StatementContext {
sc := new(stmtctx.StatementContext)
sc := stmtctx.NewStmtCtx()
sc.InitFromPBFlagAndTz(flags, tz)
return sc
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/store/mockstore/unistore/cophandler/cop_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ func newRowDecoder(columnInfos []*tipb.ColumnInfo, fieldTps []*types.FieldType,

// flagsAndTzToStatementContext creates a StatementContext from a `tipb.SelectRequest.Flags`.
func flagsAndTzToStatementContext(flags uint64, tz *time.Location) *stmtctx.StatementContext {
sc := new(stmtctx.StatementContext)
sc := stmtctx.NewStmtCtx()
sc.InitFromPBFlagAndTz(flags, tz)
return sc
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/table/column.go
Original file line number Diff line number Diff line change
Expand Up @@ -713,7 +713,7 @@ func FillVirtualColumnValue(virtualRetTypes []*types.FieldType, virtualColumnInd
}

// Clip to zero if get negative value after cast to unsigned.
if mysql.HasUnsignedFlag(colInfos[idx].FieldType.GetFlag()) && !castDatum.IsNull() && !sctx.GetSessionVars().StmtCtx.ShouldClipToZero() {
if mysql.HasUnsignedFlag(colInfos[idx].FieldType.GetFlag()) && !castDatum.IsNull() && !sctx.GetSessionVars().StmtCtx.TypeFlags().ClipNegativeToZero() {
switch datum.Kind() {
case types.KindInt64:
if datum.GetInt64() < 0 {
Expand Down
13 changes: 13 additions & 0 deletions pkg/types/context/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,19 @@ const (
FlagSkipUTF8MB4Check
)

// ClipNegativeToZero indicates whether the flag `FlagClipNegativeToZero` is set
func (f Flags) ClipNegativeToZero() bool {
return f&FlagClipNegativeToZero != 0
}

// WithClipNegativeToZero returns a new flags with `FlagClipNegativeToZero` set/unset according to the clip parameter
func (f Flags) WithClipNegativeToZero(clip bool) Flags {
if clip {
return f | FlagClipNegativeToZero
}
return f &^ FlagClipNegativeToZero
}

// SkipASCIICheck indicates whether the flag `FlagSkipASCIICheck` is set
func (f Flags) SkipASCIICheck() bool {
return f&FlagSkipASCIICheck != 0
Expand Down
16 changes: 13 additions & 3 deletions pkg/types/context/context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,23 @@ func TestWithNewFlags(t *testing.T) {
require.Equal(t, time.UTC, ctx2.Location())
}

func TestStringFlags(t *testing.T) {
func TestSimpleOnOffFlags(t *testing.T) {
cases := []struct {
name string
flag Flags
readFn func(f Flags) bool
writeFn func(f Flags, skip bool) Flags
readFn func(Flags) bool
writeFn func(Flags, bool) Flags
}{
{
name: "FlagClipNegativeToZero",
flag: FlagClipNegativeToZero,
readFn: func(f Flags) bool {
return f.ClipNegativeToZero()
},
writeFn: func(f Flags, clip bool) Flags {
return f.WithClipNegativeToZero(clip)
},
},
{
name: "FlagSkipASCIICheck",
flag: FlagSkipASCIICheck,
Expand Down
12 changes: 6 additions & 6 deletions pkg/types/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,8 @@ func ConvertUintToInt(val uint64, upperBound int64, tp byte) (int64, error) {
}

// ConvertIntToUint converts an int value to an uint value.
func ConvertIntToUint(sc *stmtctx.StatementContext, val int64, upperBound uint64, tp byte) (uint64, error) {
if sc.ShouldClipToZero() && val < 0 {
func ConvertIntToUint(flags Flags, val int64, upperBound uint64, tp byte) (uint64, error) {
if val < 0 && flags.ClipNegativeToZero() {
return 0, overflow(val, tp)
}

Expand All @@ -167,10 +167,10 @@ func ConvertUintToUint(val uint64, upperBound uint64, tp byte) (uint64, error) {
}

// ConvertFloatToUint converts a float value to an uint value.
func ConvertFloatToUint(sc *stmtctx.StatementContext, fval float64, upperBound uint64, tp byte) (uint64, error) {
func ConvertFloatToUint(flags Flags, fval float64, upperBound uint64, tp byte) (uint64, error) {
val := RoundFloat(fval)
if val < 0 {
if sc.ShouldClipToZero() {
if flags.ClipNegativeToZero() {
return 0, overflow(val, tp)
}
return uint64(int64(val)), overflow(val, tp)
Expand Down Expand Up @@ -585,7 +585,7 @@ func ConvertJSONToInt(sc *stmtctx.StatementContext, j BinaryJSON, unsigned bool,
i := j.GetInt64()
if unsigned {
uBound := IntergerUnsignedUpperBound(tp)
u, err := ConvertIntToUint(sc, i, uBound, tp)
u, err := ConvertIntToUint(sc.TypeFlags(), i, uBound, tp)
return int64(u), sc.HandleOverflow(err, err)
}

Expand Down Expand Up @@ -613,7 +613,7 @@ func ConvertJSONToInt(sc *stmtctx.StatementContext, j BinaryJSON, unsigned bool,
return u, sc.HandleOverflow(e, e)
}
bound := IntergerUnsignedUpperBound(tp)
u, err := ConvertFloatToUint(sc, f, bound, tp)
u, err := ConvertFloatToUint(sc.TypeFlags(), f, bound, tp)
return int64(u), sc.HandleOverflow(err, err)
case JSONTypeCodeString:
str := string(hack.String(j.GetString()))
Expand Down
10 changes: 5 additions & 5 deletions pkg/types/datum.go
Original file line number Diff line number Diff line change
Expand Up @@ -1190,11 +1190,11 @@ func (d *Datum) convertToUint(sc *stmtctx.StatementContext, target *FieldType) (
)
switch d.k {
case KindInt64:
val, err = ConvertIntToUint(sc, d.GetInt64(), upperBound, tp)
val, err = ConvertIntToUint(sc.TypeFlags(), d.GetInt64(), upperBound, tp)
case KindUint64:
val, err = ConvertUintToUint(d.GetUint64(), upperBound, tp)
case KindFloat32, KindFloat64:
val, err = ConvertFloatToUint(sc, d.GetFloat64(), upperBound, tp)
val, err = ConvertFloatToUint(sc.TypeFlags(), d.GetFloat64(), upperBound, tp)
case KindString, KindBytes:
uval, err1 := StrToUint(sc.TypeCtxOrDefault(), d.GetString(), false)
if err1 != nil && ErrOverflow.Equal(err1) && !sc.ShouldIgnoreOverflowError() {
Expand All @@ -1211,7 +1211,7 @@ func (d *Datum) convertToUint(sc *stmtctx.StatementContext, target *FieldType) (
if err == nil {
err = err1
}
val, err1 = ConvertIntToUint(sc, ival, upperBound, tp)
val, err1 = ConvertIntToUint(sc.TypeFlags(), ival, upperBound, tp)
if err == nil {
err = err1
}
Expand All @@ -1226,9 +1226,9 @@ func (d *Datum) convertToUint(sc *stmtctx.StatementContext, target *FieldType) (
case KindMysqlDecimal:
val, err = ConvertDecimalToUint(sc, d.GetMysqlDecimal(), upperBound, tp)
case KindMysqlEnum:
val, err = ConvertFloatToUint(sc, d.GetMysqlEnum().ToNumber(), upperBound, tp)
val, err = ConvertFloatToUint(sc.TypeFlags(), d.GetMysqlEnum().ToNumber(), upperBound, tp)
case KindMysqlSet:
val, err = ConvertFloatToUint(sc, d.GetMysqlSet().ToNumber(), upperBound, tp)
val, err = ConvertFloatToUint(sc.TypeFlags(), d.GetMysqlSet().ToNumber(), upperBound, tp)
case KindBinaryLiteral, KindMysqlBit:
val, err = d.GetBinaryLiteral().ToInt(sc.TypeCtxOrDefault())
if err == nil {
Expand Down
Loading