diff --git a/go/vt/sqlparser/expression_converter.go b/go/vt/sqlparser/expression_converter.go index d0bf78229bf..695ad24fe97 100644 --- a/go/vt/sqlparser/expression_converter.go +++ b/go/vt/sqlparser/expression_converter.go @@ -22,7 +22,8 @@ import ( "vitess.io/vitess/go/vt/vtgate/evalengine" ) -var ExprNotSupported = fmt.Errorf("Expr Not Supported") +// ErrExprNotSupported signals that the expression cannot be handled by expression evaluation engine. +var ErrExprNotSupported = fmt.Errorf("Expr Not Supported") //Convert converts between AST expressions and executable expressions func Convert(e Expr) (evalengine.Expr, error) { @@ -30,14 +31,19 @@ func Convert(e Expr) (evalengine.Expr, error) { case *SQLVal: switch node.Type { case IntVal: - return evalengine.NewLiteralInt(node.Val) + return evalengine.NewLiteralIntFromBytes(node.Val) case FloatVal: return evalengine.NewLiteralFloat(node.Val) case ValArg: - return &evalengine.BindVariable{Key: string(node.Val[1:])}, nil + return evalengine.NewBindVar(string(node.Val[1:])), nil case StrVal: - return evalengine.NewLiteralString(node.Val) + return evalengine.NewLiteralString(node.Val), nil } + case BoolVal: + if node { + return evalengine.NewLiteralIntFromBytes([]byte("1")) + } + return evalengine.NewLiteralIntFromBytes([]byte("0")) case *BinaryExpr: var op evalengine.BinaryExpr switch node.Operator { @@ -50,7 +56,7 @@ func Convert(e Expr) (evalengine.Expr, error) { case DivStr: op = &evalengine.Division{} default: - return nil, ExprNotSupported + return nil, ErrExprNotSupported } left, err := Convert(node.Left) if err != nil { @@ -67,5 +73,5 @@ func Convert(e Expr) (evalengine.Expr, error) { }, nil } - return nil, ExprNotSupported + return nil, ErrExprNotSupported } diff --git a/go/vt/sqlparser/set_normalizer.go b/go/vt/sqlparser/set_normalizer.go index 4926708582f..bfd300138e4 100644 --- a/go/vt/sqlparser/set_normalizer.go +++ b/go/vt/sqlparser/set_normalizer.go @@ -50,13 +50,13 @@ func (n *setNormalizer) normalizeSetExpr(in *SetExpr) (*SetExpr, error) { } switch { case strings.HasPrefix(in.Name.Lowered(), "session."): - in.Name = NewColIdent(in.Name.Lowered()[8:]) + in.Name = createColumn(in.Name.Lowered()[8:]) in.Scope = SessionStr case strings.HasPrefix(in.Name.Lowered(), "global."): - in.Name = NewColIdent(in.Name.Lowered()[7:]) + in.Name = createColumn(in.Name.Lowered()[7:]) in.Scope = GlobalStr case strings.HasPrefix(in.Name.Lowered(), "vitess_metadata."): - in.Name = NewColIdent(in.Name.Lowered()[16:]) + in.Name = createColumn(in.Name.Lowered()[16:]) in.Scope = VitessMetadataStr default: in.Name.at = NoAt @@ -79,3 +79,11 @@ func (n *setNormalizer) normalizeSetExpr(in *SetExpr) (*SetExpr, error) { } panic("this should never happen") } + +func createColumn(str string) ColIdent { + size := len(str) + if str[0] == '`' && str[size-1] == '`' { + str = str[1 : size-1] + } + return NewColIdent(str) +} diff --git a/go/vt/sqlparser/set_normalizer_test.go b/go/vt/sqlparser/set_normalizer_test.go index 394ff4f63ee..b115539b547 100644 --- a/go/vt/sqlparser/set_normalizer_test.go +++ b/go/vt/sqlparser/set_normalizer_test.go @@ -26,8 +26,8 @@ func TestNormalizeSetExpr(t *testing.T) { tests := []struct { in, expected, err string }{{ - in: "@@foo = 42", - expected: "session foo = 42", + in: "@@session.x.foo = 42", + expected: "session `x.foo` = 42", }, { in: "@@session.foo = 42", expected: "session foo = 42", @@ -57,8 +57,14 @@ func TestNormalizeSetExpr(t *testing.T) { in: "@@x.foo = 42", expected: "session `x.foo` = 42", }, { - in: "@@session.x.foo = 42", - expected: "session `x.foo` = 42", + in: "@@session.`foo` = 1", + expected: "session foo = 1", + }, { + in: "@@global.`foo` = 1", + expected: "global foo = 1", + }, { + in: "@@vitess_metadata.`foo` = 1", + expected: "vitess_metadata foo = 1", //}, { TODO: we should support local scope as well // in: "local foo = 42", // expected: "session foo = 42", diff --git a/go/vt/vtgate/engine/fake_vcursor_test.go b/go/vt/vtgate/engine/fake_vcursor_test.go index c7888ace99f..9f9464c4ac8 100644 --- a/go/vt/vtgate/engine/fake_vcursor_test.go +++ b/go/vt/vtgate/engine/fake_vcursor_test.go @@ -85,7 +85,31 @@ func (t noopVCursor) Session() SessionActions { return t } -func (t noopVCursor) SetTarget(target string) error { +func (t noopVCursor) SetAutocommit(bool) error { + panic("implement me") +} + +func (t noopVCursor) SetClientFoundRows(bool) { + panic("implement me") +} + +func (t noopVCursor) SetSkipQueryPlanCache(bool) { + panic("implement me") +} + +func (t noopVCursor) SetSQLSelectLimit(int64) { + panic("implement me") +} + +func (t noopVCursor) SetTransactionMode(vtgatepb.TransactionMode) { + panic("implement me") +} + +func (t noopVCursor) SetWorkload(querypb.ExecuteOptions_Workload) { + panic("implement me") +} + +func (t noopVCursor) SetTarget(string) error { panic("implement me") } @@ -192,7 +216,7 @@ func (f *loggingVCursor) ShardSession() []*srvtopo.ResolvedShard { return nil } -func (f *loggingVCursor) ExecuteVSchema(keyspace string, vschemaDDL *sqlparser.DDL) error { +func (f *loggingVCursor) ExecuteVSchema(string, *sqlparser.DDL) error { panic("implement me") } @@ -209,7 +233,7 @@ func (f *loggingVCursor) Context() context.Context { return context.Background() } -func (f *loggingVCursor) SetContextTimeout(timeout time.Duration) context.CancelFunc { +func (f *loggingVCursor) SetContextTimeout(time.Duration) context.CancelFunc { return func() {} } @@ -221,7 +245,7 @@ func (f *loggingVCursor) RecordWarning(warning *querypb.QueryWarning) { f.warnings = append(f.warnings, warning) } -func (f *loggingVCursor) Execute(method string, query string, bindvars map[string]*querypb.BindVariable, rollbackOnError bool, co vtgatepb.CommitOrder) (*sqltypes.Result, error) { +func (f *loggingVCursor) Execute(_ string, query string, bindvars map[string]*querypb.BindVariable, rollbackOnError bool, co vtgatepb.CommitOrder) (*sqltypes.Result, error) { name := "Unknown" switch co { case vtgatepb.CommitOrder_NORMAL: @@ -351,6 +375,30 @@ func (f *loggingVCursor) Rewind() { f.warnings = nil } +func (f *loggingVCursor) SetAutocommit(bool) error { + panic("implement me") +} + +func (f *loggingVCursor) SetClientFoundRows(bool) { + panic("implement me") +} + +func (f *loggingVCursor) SetSkipQueryPlanCache(bool) { + panic("implement me") +} + +func (f *loggingVCursor) SetSQLSelectLimit(int64) { + panic("implement me") +} + +func (f *loggingVCursor) SetTransactionMode(vtgatepb.TransactionMode) { + panic("implement me") +} + +func (f *loggingVCursor) SetWorkload(querypb.ExecuteOptions_Workload) { + panic("implement me") +} + func (f *loggingVCursor) nextResult() (*sqltypes.Result, error) { if f.results == nil || f.curResult >= len(f.results) { return &sqltypes.Result{}, f.resultErr diff --git a/go/vt/vtgate/engine/primitive.go b/go/vt/vtgate/engine/primitive.go index 78f5fa9d693..7c3bac5c932 100644 --- a/go/vt/vtgate/engine/primitive.go +++ b/go/vt/vtgate/engine/primitive.go @@ -107,6 +107,13 @@ type ( // ShardSession returns shard info about open connections ShardSession() []*srvtopo.ResolvedShard + + SetAutocommit(bool) error + SetClientFoundRows(bool) + SetSkipQueryPlanCache(bool) + SetSQLSelectLimit(int64) + SetTransactionMode(vtgatepb.TransactionMode) + SetWorkload(querypb.ExecuteOptions_Workload) } // Plan represents the execution strategy for a given query. diff --git a/go/vt/vtgate/engine/set.go b/go/vt/vtgate/engine/set.go index 042d34528f0..7c8148c75c1 100644 --- a/go/vt/vtgate/engine/set.go +++ b/go/vt/vtgate/engine/set.go @@ -20,6 +20,9 @@ import ( "bytes" "encoding/json" "fmt" + "strings" + + vtgatepb "vitess.io/vitess/go/vt/proto/vtgate" "vitess.io/vitess/go/vt/log" @@ -78,13 +81,11 @@ type ( Expr string } - // SysVarSetSpecial implements the SetOp interface and will write the changes variable into the session + // SysVarSetAware implements the SetOp interface and will write the changes variable into the session // The special part is that these settings change the sessions behaviour in different ways - SysVarSetSpecial struct { - Name string - Keyspace *vindexes.Keyspace - TargetDestination key.Destination `json:",omitempty"` - Expr string + SysVarSetAware struct { + Name string + Expr evalengine.Expr } ) @@ -338,3 +339,119 @@ func (svs *SysVarSet) checkAndUpdateSysVar(vcursor VCursor, res evalengine.Expre vcursor.Session().NeedsReservedConn() return true, nil } + +var _ SetOp = (*SysVarSetAware)(nil) + +// System variables that needs special handling +const ( + Autocommit = "autocommit" + ClientFoundRows = "client_found_rows" + SkipQueryPlanCache = "skip_query_plan_cache" + TxReadOnly = "tx_read_only" + TransactionReadOnly = "transaction_read_only" + SQLSelectLimit = "sql_select_limit" + TransactionMode = "transaction_mode" + Workload = "workload" + Charset = "charset" + Names = "names" +) + +//MarshalJSON provides the type to SetOp for plan json +func (svss *SysVarSetAware) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Name string + Expr string + }{ + Type: "SysVarAware", + Name: svss.Name, + Expr: svss.Expr.String(), + }) +} + +//Execute implements the SetOp interface method +func (svss *SysVarSetAware) Execute(vcursor VCursor, env evalengine.ExpressionEnv) error { + switch svss.Name { + // These are all the boolean values we need to handle + case Autocommit, ClientFoundRows, SkipQueryPlanCache, TxReadOnly, TransactionReadOnly: + value, err := svss.Expr.Evaluate(env) + if err != nil { + return err + } + boolValue, err := value.ToBooleanStrict() + if err != nil { + return vterrors.Wrapf(err, "System setting '%s' can't be set to this value", svss.Name) + } + switch svss.Name { + case Autocommit: + vcursor.Session().SetAutocommit(boolValue) + case ClientFoundRows: + vcursor.Session().SetClientFoundRows(boolValue) + case SkipQueryPlanCache: + vcursor.Session().SetSkipQueryPlanCache(boolValue) + case TxReadOnly, TransactionReadOnly: + // TODO (4127): This is a dangerous NOP. + } + + case SQLSelectLimit: + value, err := svss.Expr.Evaluate(env) + if err != nil { + return err + } + + v := value.Value() + if !v.IsIntegral() { + return vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "unexpected value type for sql_select_limit: %T", value.Value().Type().String()) + } + intValue, err := v.ToInt64() + if err != nil { + return err + } + vcursor.Session().SetSQLSelectLimit(intValue) + + // String settings + case TransactionMode, Workload, Charset, Names: + value, err := svss.Expr.Evaluate(env) + if err != nil { + return err + } + v := value.Value() + if !v.IsText() && !v.IsBinary() { + return vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "unexpected value type for %s: %s", svss.Name, value.Value().Type().String()) + } + + str := v.ToString() + switch svss.Name { + case TransactionMode: + out, ok := vtgatepb.TransactionMode_value[strings.ToUpper(str)] + if !ok { + return vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "invalid transaction_mode: %s", str) + } + vcursor.Session().SetTransactionMode(vtgatepb.TransactionMode(out)) + case Workload: + out, ok := querypb.ExecuteOptions_Workload_value[strings.ToUpper(str)] + if !ok { + return vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "invalid workload: %s", str) + } + vcursor.Session().SetWorkload(querypb.ExecuteOptions_Workload(out)) + case Charset, Names: + switch strings.ToLower(str) { + case "", "utf8", "utf8mb4", "latin1", "default": + // do nothing + break + default: + return vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "unexpected value for charset/names: %v", str) + } + } + + default: + return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "unsupported construct") + } + + return nil +} + +//VariableName implements the SetOp interface method +func (svss *SysVarSetAware) VariableName() string { + return svss.Name +} diff --git a/go/vt/vtgate/engine/set_test.go b/go/vt/vtgate/engine/set_test.go index da4b7ae3ec7..fb23d7ddf0e 100644 --- a/go/vt/vtgate/engine/set_test.go +++ b/go/vt/vtgate/engine/set_test.go @@ -18,7 +18,6 @@ package engine import ( "fmt" - "strconv" "testing" "github.com/stretchr/testify/require" @@ -75,12 +74,6 @@ func TestSetTable(t *testing.T) { expectedError string } - intExpr := func(i int) evalengine.Expr { - s := strconv.FormatInt(int64(i), 10) - e, _ := evalengine.NewLiteralInt([]byte(s)) - return e - } - tests := []testCase{ { testName: "nil set ops", @@ -91,7 +84,7 @@ func TestSetTable(t *testing.T) { setOps: []SetOp{ &UserDefinedVariable{ Name: "x", - Expr: intExpr(42), + Expr: evalengine.NewLiteralInt(42), }, }, expectedQueryLog: []string{ @@ -173,7 +166,7 @@ func TestSetTable(t *testing.T) { setOps: []SetOp{ &UserDefinedVariable{ Name: "x", - Expr: intExpr(1), + Expr: evalengine.NewLiteralInt(1), }, &SysVarIgnore{ Name: "y", diff --git a/go/vt/vtgate/evalengine/arithmetic.go b/go/vt/vtgate/evalengine/arithmetic.go index 115ec7730f1..f40e9d5c456 100644 --- a/go/vt/vtgate/evalengine/arithmetic.go +++ b/go/vt/vtgate/evalengine/arithmetic.go @@ -363,63 +363,63 @@ func ToNative(v sqltypes.Value) (interface{}, error) { return out, err } -// newEvalResult parses a value and produces an evalResult containing the value -func newEvalResult(v sqltypes.Value) (evalResult, error) { +// newEvalResult parses a value and produces an EvalResult containing the value +func newEvalResult(v sqltypes.Value) (EvalResult, error) { raw := v.Raw() switch { case v.IsBinary() || v.IsText(): - return evalResult{bytes: raw, typ: sqltypes.VarBinary}, nil + return EvalResult{bytes: raw, typ: sqltypes.VarBinary}, nil case v.IsSigned(): ival, err := strconv.ParseInt(string(raw), 10, 64) if err != nil { - return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%v", err) + return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%v", err) } - return evalResult{ival: ival, typ: sqltypes.Int64}, nil + return EvalResult{ival: ival, typ: sqltypes.Int64}, nil case v.IsUnsigned(): uval, err := strconv.ParseUint(string(raw), 10, 64) if err != nil { - return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%v", err) + return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%v", err) } - return evalResult{uval: uval, typ: sqltypes.Uint64}, nil + return EvalResult{uval: uval, typ: sqltypes.Uint64}, nil case v.IsFloat() || v.Type() == sqltypes.Decimal: fval, err := strconv.ParseFloat(string(raw), 64) if err != nil { - return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%v", err) + return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%v", err) } - return evalResult{fval: fval, typ: sqltypes.Float64}, nil + return EvalResult{fval: fval, typ: sqltypes.Float64}, nil } - return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "this should not be reached. got %s", v.String()) + return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "this should not be reached. got %s", v.String()) } // newIntegralNumeric parses a value and produces an Int64 or Uint64. -func newIntegralNumeric(v sqltypes.Value) (evalResult, error) { +func newIntegralNumeric(v sqltypes.Value) (EvalResult, error) { str := v.ToString() switch { case v.IsSigned(): ival, err := strconv.ParseInt(str, 10, 64) if err != nil { - return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%v", err) + return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%v", err) } - return evalResult{ival: ival, typ: sqltypes.Int64}, nil + return EvalResult{ival: ival, typ: sqltypes.Int64}, nil case v.IsUnsigned(): uval, err := strconv.ParseUint(str, 10, 64) if err != nil { - return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%v", err) + return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%v", err) } - return evalResult{uval: uval, typ: sqltypes.Uint64}, nil + return EvalResult{uval: uval, typ: sqltypes.Uint64}, nil } // For other types, do best effort. if ival, err := strconv.ParseInt(str, 10, 64); err == nil { - return evalResult{ival: ival, typ: sqltypes.Int64}, nil + return EvalResult{ival: ival, typ: sqltypes.Int64}, nil } if uval, err := strconv.ParseUint(str, 10, 64); err == nil { - return evalResult{uval: uval, typ: sqltypes.Uint64}, nil + return EvalResult{uval: uval, typ: sqltypes.Uint64}, nil } - return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "could not parse value: '%s'", str) + return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "could not parse value: '%s'", str) } -func addNumeric(v1, v2 evalResult) evalResult { +func addNumeric(v1, v2 EvalResult) EvalResult { v1, v2 = makeNumericAndprioritize(v1, v2) switch v1.typ { case sqltypes.Int64: @@ -437,7 +437,7 @@ func addNumeric(v1, v2 evalResult) evalResult { panic("unreachable") } -func addNumericWithError(v1, v2 evalResult) (evalResult, error) { +func addNumericWithError(v1, v2 EvalResult) (EvalResult, error) { v1, v2 = makeNumericAndprioritize(v1, v2) switch v1.typ { case sqltypes.Int64: @@ -452,11 +452,11 @@ func addNumericWithError(v1, v2 evalResult) (evalResult, error) { case sqltypes.Float64: return floatPlusAny(v1.fval, v2), nil } - return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid arithmetic between: %s %s", v1.Value().String(), v2.Value().String()) + return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid arithmetic between: %s %s", v1.Value().String(), v2.Value().String()) } -func subtractNumericWithError(i1, i2 evalResult) (evalResult, error) { +func subtractNumericWithError(i1, i2 EvalResult) (EvalResult, error) { v1 := makeNumeric(i1) v2 := makeNumeric(i2) switch v1.typ { @@ -481,10 +481,10 @@ func subtractNumericWithError(i1, i2 evalResult) (evalResult, error) { case sqltypes.Float64: return floatMinusAny(v1.fval, v2), nil } - return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid arithmetic between: %s %s", v1.Value().String(), v2.Value().String()) + return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid arithmetic between: %s %s", v1.Value().String(), v2.Value().String()) } -func multiplyNumericWithError(v1, v2 evalResult) (evalResult, error) { +func multiplyNumericWithError(v1, v2 EvalResult) (EvalResult, error) { v1, v2 = makeNumericAndprioritize(v1, v2) switch v1.typ { case sqltypes.Int64: @@ -499,11 +499,11 @@ func multiplyNumericWithError(v1, v2 evalResult) (evalResult, error) { case sqltypes.Float64: return floatTimesAny(v1.fval, v2), nil } - return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid arithmetic between: %s %s", v1.Value().String(), v2.Value().String()) + return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid arithmetic between: %s %s", v1.Value().String(), v2.Value().String()) } -func divideNumericWithError(i1, i2 evalResult) (evalResult, error) { +func divideNumericWithError(i1, i2 EvalResult) (EvalResult, error) { v1 := makeNumeric(i1) v2 := makeNumeric(i2) switch v1.typ { @@ -516,12 +516,12 @@ func divideNumericWithError(i1, i2 evalResult) (evalResult, error) { case sqltypes.Float64: return floatDivideAnyWithError(v1.fval, v2) } - return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid arithmetic between: %s %s", v1.Value().String(), v2.Value().String()) + return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid arithmetic between: %s %s", v1.Value().String(), v2.Value().String()) } // makeNumericAndprioritize reorders the input parameters // to be Float64, Uint64, Int64. -func makeNumericAndprioritize(i1, i2 evalResult) (evalResult, evalResult) { +func makeNumericAndprioritize(i1, i2 EvalResult) (EvalResult, EvalResult) { v1 := makeNumeric(i1) v2 := makeNumeric(i2) switch v1.typ { @@ -537,20 +537,20 @@ func makeNumericAndprioritize(i1, i2 evalResult) (evalResult, evalResult) { return v1, v2 } -func makeNumeric(v evalResult) evalResult { +func makeNumeric(v EvalResult) EvalResult { if sqltypes.IsNumber(v.typ) { return v } if ival, err := strconv.ParseInt(string(v.bytes), 10, 64); err == nil { - return evalResult{ival: ival, typ: sqltypes.Int64} + return EvalResult{ival: ival, typ: sqltypes.Int64} } if fval, err := strconv.ParseFloat(string(v.bytes), 64); err == nil { - return evalResult{fval: fval, typ: sqltypes.Float64} + return EvalResult{fval: fval, typ: sqltypes.Float64} } - return evalResult{ival: 0, typ: sqltypes.Int64} + return EvalResult{ival: 0, typ: sqltypes.Int64} } -func intPlusInt(v1, v2 int64) evalResult { +func intPlusInt(v1, v2 int64) EvalResult { result := v1 + v2 if v1 > 0 && v2 > 0 && result < 0 { goto overflow @@ -558,61 +558,61 @@ func intPlusInt(v1, v2 int64) evalResult { if v1 < 0 && v2 < 0 && result > 0 { goto overflow } - return evalResult{typ: sqltypes.Int64, ival: result} + return EvalResult{typ: sqltypes.Int64, ival: result} overflow: - return evalResult{typ: sqltypes.Float64, fval: float64(v1) + float64(v2)} + return EvalResult{typ: sqltypes.Float64, fval: float64(v1) + float64(v2)} } -func intPlusIntWithError(v1, v2 int64) (evalResult, error) { +func intPlusIntWithError(v1, v2 int64) (EvalResult, error) { result := v1 + v2 if (result > v1) != (v2 > 0) { - return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT value is out of range in %v + %v", v1, v2) + return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT value is out of range in %v + %v", v1, v2) } - return evalResult{typ: sqltypes.Int64, ival: result}, nil + return EvalResult{typ: sqltypes.Int64, ival: result}, nil } -func intMinusIntWithError(v1, v2 int64) (evalResult, error) { +func intMinusIntWithError(v1, v2 int64) (EvalResult, error) { result := v1 - v2 if (result < v1) != (v2 > 0) { - return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT value is out of range in %v - %v", v1, v2) + return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT value is out of range in %v - %v", v1, v2) } - return evalResult{typ: sqltypes.Int64, ival: result}, nil + return EvalResult{typ: sqltypes.Int64, ival: result}, nil } -func intTimesIntWithError(v1, v2 int64) (evalResult, error) { +func intTimesIntWithError(v1, v2 int64) (EvalResult, error) { result := v1 * v2 if v1 != 0 && result/v1 != v2 { - return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT value is out of range in %v * %v", v1, v2) + return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT value is out of range in %v * %v", v1, v2) } - return evalResult{typ: sqltypes.Int64, ival: result}, nil + return EvalResult{typ: sqltypes.Int64, ival: result}, nil } -func intMinusUintWithError(v1 int64, v2 uint64) (evalResult, error) { +func intMinusUintWithError(v1 int64, v2 uint64) (EvalResult, error) { if v1 < 0 || v1 < int64(v2) { - return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in %v - %v", v1, v2) + return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in %v - %v", v1, v2) } return uintMinusUintWithError(uint64(v1), v2) } -func uintPlusInt(v1 uint64, v2 int64) evalResult { +func uintPlusInt(v1 uint64, v2 int64) EvalResult { return uintPlusUint(v1, uint64(v2)) } -func uintPlusIntWithError(v1 uint64, v2 int64) (evalResult, error) { +func uintPlusIntWithError(v1 uint64, v2 int64) (EvalResult, error) { if v2 < 0 && v1 < uint64(v2) { - return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in %v + %v", v1, v2) + return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in %v + %v", v1, v2) } // convert to int -> uint is because for numeric operators (such as + or -) // where one of the operands is an unsigned integer, the result is unsigned by default. return uintPlusUintWithError(v1, uint64(v2)) } -func uintMinusIntWithError(v1 uint64, v2 int64) (evalResult, error) { +func uintMinusIntWithError(v1 uint64, v2 int64) (EvalResult, error) { if int64(v1) < v2 && v2 > 0 { - return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in %v - %v", v1, v2) + return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in %v - %v", v1, v2) } // uint - (- int) = uint + int if v2 < 0 { @@ -621,77 +621,77 @@ func uintMinusIntWithError(v1 uint64, v2 int64) (evalResult, error) { return uintMinusUintWithError(v1, uint64(v2)) } -func uintTimesIntWithError(v1 uint64, v2 int64) (evalResult, error) { +func uintTimesIntWithError(v1 uint64, v2 int64) (EvalResult, error) { if v2 < 0 || int64(v1) < 0 { - return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in %v * %v", v1, v2) + return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in %v * %v", v1, v2) } return uintTimesUintWithError(v1, uint64(v2)) } -func uintPlusUint(v1, v2 uint64) evalResult { +func uintPlusUint(v1, v2 uint64) EvalResult { result := v1 + v2 if result < v2 { - return evalResult{typ: sqltypes.Float64, fval: float64(v1) + float64(v2)} + return EvalResult{typ: sqltypes.Float64, fval: float64(v1) + float64(v2)} } - return evalResult{typ: sqltypes.Uint64, uval: result} + return EvalResult{typ: sqltypes.Uint64, uval: result} } -func uintPlusUintWithError(v1, v2 uint64) (evalResult, error) { +func uintPlusUintWithError(v1, v2 uint64) (EvalResult, error) { result := v1 + v2 if result < v2 { - return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in %v + %v", v1, v2) + return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in %v + %v", v1, v2) } - return evalResult{typ: sqltypes.Uint64, uval: result}, nil + return EvalResult{typ: sqltypes.Uint64, uval: result}, nil } -func uintMinusUintWithError(v1, v2 uint64) (evalResult, error) { +func uintMinusUintWithError(v1, v2 uint64) (EvalResult, error) { result := v1 - v2 if v2 > v1 { - return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in %v - %v", v1, v2) + return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in %v - %v", v1, v2) } - return evalResult{typ: sqltypes.Uint64, uval: result}, nil + return EvalResult{typ: sqltypes.Uint64, uval: result}, nil } -func uintTimesUintWithError(v1, v2 uint64) (evalResult, error) { +func uintTimesUintWithError(v1, v2 uint64) (EvalResult, error) { result := v1 * v2 if result < v2 || result < v1 { - return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in %v * %v", v1, v2) + return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in %v * %v", v1, v2) } - return evalResult{typ: sqltypes.Uint64, uval: result}, nil + return EvalResult{typ: sqltypes.Uint64, uval: result}, nil } -func floatPlusAny(v1 float64, v2 evalResult) evalResult { +func floatPlusAny(v1 float64, v2 EvalResult) EvalResult { switch v2.typ { case sqltypes.Int64: v2.fval = float64(v2.ival) case sqltypes.Uint64: v2.fval = float64(v2.uval) } - return evalResult{typ: sqltypes.Float64, fval: v1 + v2.fval} + return EvalResult{typ: sqltypes.Float64, fval: v1 + v2.fval} } -func floatMinusAny(v1 float64, v2 evalResult) evalResult { +func floatMinusAny(v1 float64, v2 EvalResult) EvalResult { switch v2.typ { case sqltypes.Int64: v2.fval = float64(v2.ival) case sqltypes.Uint64: v2.fval = float64(v2.uval) } - return evalResult{typ: sqltypes.Float64, fval: v1 - v2.fval} + return EvalResult{typ: sqltypes.Float64, fval: v1 - v2.fval} } -func floatTimesAny(v1 float64, v2 evalResult) evalResult { +func floatTimesAny(v1 float64, v2 EvalResult) EvalResult { switch v2.typ { case sqltypes.Int64: v2.fval = float64(v2.ival) case sqltypes.Uint64: v2.fval = float64(v2.uval) } - return evalResult{typ: sqltypes.Float64, fval: v1 * v2.fval} + return EvalResult{typ: sqltypes.Float64, fval: v1 * v2.fval} } -func floatDivideAnyWithError(v1 float64, v2 evalResult) (evalResult, error) { +func floatDivideAnyWithError(v1 float64, v2 EvalResult) (EvalResult, error) { switch v2.typ { case sqltypes.Int64: v2.fval = float64(v2.ival) @@ -703,23 +703,23 @@ func floatDivideAnyWithError(v1 float64, v2 evalResult) (evalResult, error) { resultMismatch := v2.fval*result != v1 if divisorLessThanOne && resultMismatch { - return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT is out of range in %v / %v", v1, v2.fval) + return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT is out of range in %v / %v", v1, v2.fval) } - return evalResult{typ: sqltypes.Float64, fval: v1 / v2.fval}, nil + return EvalResult{typ: sqltypes.Float64, fval: v1 / v2.fval}, nil } -func anyMinusFloat(v1 evalResult, v2 float64) evalResult { +func anyMinusFloat(v1 EvalResult, v2 float64) EvalResult { switch v1.typ { case sqltypes.Int64: v1.fval = float64(v1.ival) case sqltypes.Uint64: v1.fval = float64(v1.uval) } - return evalResult{typ: sqltypes.Float64, fval: v1.fval - v2} + return EvalResult{typ: sqltypes.Float64, fval: v1.fval - v2} } -func castFromNumeric(v evalResult, resultType querypb.Type) sqltypes.Value { +func castFromNumeric(v EvalResult, resultType querypb.Type) sqltypes.Value { switch { case sqltypes.IsSigned(resultType): switch v.typ { @@ -758,7 +758,7 @@ func castFromNumeric(v evalResult, resultType querypb.Type) sqltypes.Value { return sqltypes.NULL } -func compareNumeric(v1, v2 evalResult) int { +func compareNumeric(v1, v2 EvalResult) int { // Equalize the types. switch v1.typ { case sqltypes.Int64: @@ -767,9 +767,9 @@ func compareNumeric(v1, v2 evalResult) int { if v1.ival < 0 { return -1 } - v1 = evalResult{typ: sqltypes.Uint64, uval: uint64(v1.ival)} + v1 = EvalResult{typ: sqltypes.Uint64, uval: uint64(v1.ival)} case sqltypes.Float64: - v1 = evalResult{typ: sqltypes.Float64, fval: float64(v1.ival)} + v1 = EvalResult{typ: sqltypes.Float64, fval: float64(v1.ival)} } case sqltypes.Uint64: switch v2.typ { @@ -777,16 +777,16 @@ func compareNumeric(v1, v2 evalResult) int { if v2.ival < 0 { return 1 } - v2 = evalResult{typ: sqltypes.Uint64, uval: uint64(v2.ival)} + v2 = EvalResult{typ: sqltypes.Uint64, uval: uint64(v2.ival)} case sqltypes.Float64: - v1 = evalResult{typ: sqltypes.Float64, fval: float64(v1.uval)} + v1 = EvalResult{typ: sqltypes.Float64, fval: float64(v1.uval)} } case sqltypes.Float64: switch v2.typ { case sqltypes.Int64: - v2 = evalResult{typ: sqltypes.Float64, fval: float64(v2.ival)} + v2 = EvalResult{typ: sqltypes.Float64, fval: float64(v2.ival)} case sqltypes.Uint64: - v2 = evalResult{typ: sqltypes.Float64, fval: float64(v2.uval)} + v2 = EvalResult{typ: sqltypes.Float64, fval: float64(v2.uval)} } } diff --git a/go/vt/vtgate/evalengine/arithmetic_test.go b/go/vt/vtgate/evalengine/arithmetic_test.go index 148588c0de1..cf5f5a099bb 100644 --- a/go/vt/vtgate/evalengine/arithmetic_test.go +++ b/go/vt/vtgate/evalengine/arithmetic_test.go @@ -882,7 +882,7 @@ func TestToNative(t *testing.T) { var mustMatch = utils.MustMatchFn( []interface{}{ // types with unexported fields - evalResult{}, + EvalResult{}, }, []string{}, // ignored fields ) @@ -890,25 +890,25 @@ var mustMatch = utils.MustMatchFn( func TestNewNumeric(t *testing.T) { tcases := []struct { v sqltypes.Value - out evalResult + out EvalResult err error }{{ v: sqltypes.NewInt64(1), - out: evalResult{typ: querypb.Type_INT64, ival: 1}, + out: EvalResult{typ: querypb.Type_INT64, ival: 1}, }, { v: sqltypes.NewUint64(1), - out: evalResult{typ: querypb.Type_UINT64, uval: 1}, + out: EvalResult{typ: querypb.Type_UINT64, uval: 1}, }, { v: sqltypes.NewFloat64(1), - out: evalResult{typ: querypb.Type_FLOAT64, fval: 1}, + out: EvalResult{typ: querypb.Type_FLOAT64, fval: 1}, }, { // For non-number type, Int64 is the default. v: sqltypes.TestValue(querypb.Type_VARCHAR, "1"), - out: evalResult{typ: querypb.Type_INT64, ival: 1}, + out: EvalResult{typ: querypb.Type_INT64, ival: 1}, }, { // If Int64 can't work, we use Float64. v: sqltypes.TestValue(querypb.Type_VARCHAR, "1.2"), - out: evalResult{typ: querypb.Type_FLOAT64, fval: 1.2}, + out: EvalResult{typ: querypb.Type_FLOAT64, fval: 1.2}, }, { // Only valid Int64 allowed if type is Int64. v: sqltypes.TestValue(querypb.Type_INT64, "1.2"), @@ -923,7 +923,7 @@ func TestNewNumeric(t *testing.T) { err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseFloat: parsing \"abcd\": invalid syntax"), }, { v: sqltypes.TestValue(querypb.Type_VARCHAR, "abcd"), - out: evalResult{typ: querypb.Type_FLOAT64, fval: 0}, + out: EvalResult{typ: querypb.Type_FLOAT64, fval: 0}, }} for _, tcase := range tcases { got, err := newEvalResult(tcase.v) @@ -941,25 +941,25 @@ func TestNewNumeric(t *testing.T) { func TestNewIntegralNumeric(t *testing.T) { tcases := []struct { v sqltypes.Value - out evalResult + out EvalResult err error }{{ v: sqltypes.NewInt64(1), - out: evalResult{typ: querypb.Type_INT64, ival: 1}, + out: EvalResult{typ: querypb.Type_INT64, ival: 1}, }, { v: sqltypes.NewUint64(1), - out: evalResult{typ: querypb.Type_UINT64, uval: 1}, + out: EvalResult{typ: querypb.Type_UINT64, uval: 1}, }, { v: sqltypes.NewFloat64(1), - out: evalResult{typ: querypb.Type_INT64, ival: 1}, + out: EvalResult{typ: querypb.Type_INT64, ival: 1}, }, { // For non-number type, Int64 is the default. v: sqltypes.TestValue(querypb.Type_VARCHAR, "1"), - out: evalResult{typ: querypb.Type_INT64, ival: 1}, + out: EvalResult{typ: querypb.Type_INT64, ival: 1}, }, { // If Int64 can't work, we use Uint64. v: sqltypes.TestValue(querypb.Type_VARCHAR, "18446744073709551615"), - out: evalResult{typ: querypb.Type_UINT64, uval: 18446744073709551615}, + out: EvalResult{typ: querypb.Type_UINT64, uval: 18446744073709551615}, }, { // Only valid Int64 allowed if type is Int64. v: sqltypes.TestValue(querypb.Type_INT64, "1.2"), @@ -987,52 +987,52 @@ func TestNewIntegralNumeric(t *testing.T) { func TestAddNumeric(t *testing.T) { tcases := []struct { - v1, v2 evalResult - out evalResult + v1, v2 EvalResult + out EvalResult err error }{{ - v1: evalResult{typ: querypb.Type_INT64, ival: 1}, - v2: evalResult{typ: querypb.Type_INT64, ival: 2}, - out: evalResult{typ: querypb.Type_INT64, ival: 3}, + v1: EvalResult{typ: querypb.Type_INT64, ival: 1}, + v2: EvalResult{typ: querypb.Type_INT64, ival: 2}, + out: EvalResult{typ: querypb.Type_INT64, ival: 3}, }, { - v1: evalResult{typ: querypb.Type_INT64, ival: 1}, - v2: evalResult{typ: querypb.Type_UINT64, uval: 2}, - out: evalResult{typ: querypb.Type_UINT64, uval: 3}, + v1: EvalResult{typ: querypb.Type_INT64, ival: 1}, + v2: EvalResult{typ: querypb.Type_UINT64, uval: 2}, + out: EvalResult{typ: querypb.Type_UINT64, uval: 3}, }, { - v1: evalResult{typ: querypb.Type_INT64, ival: 1}, - v2: evalResult{typ: querypb.Type_FLOAT64, fval: 2}, - out: evalResult{typ: querypb.Type_FLOAT64, fval: 3}, + v1: EvalResult{typ: querypb.Type_INT64, ival: 1}, + v2: EvalResult{typ: querypb.Type_FLOAT64, fval: 2}, + out: EvalResult{typ: querypb.Type_FLOAT64, fval: 3}, }, { - v1: evalResult{typ: querypb.Type_UINT64, uval: 1}, - v2: evalResult{typ: querypb.Type_UINT64, uval: 2}, - out: evalResult{typ: querypb.Type_UINT64, uval: 3}, + v1: EvalResult{typ: querypb.Type_UINT64, uval: 1}, + v2: EvalResult{typ: querypb.Type_UINT64, uval: 2}, + out: EvalResult{typ: querypb.Type_UINT64, uval: 3}, }, { - v1: evalResult{typ: querypb.Type_UINT64, uval: 1}, - v2: evalResult{typ: querypb.Type_FLOAT64, fval: 2}, - out: evalResult{typ: querypb.Type_FLOAT64, fval: 3}, + v1: EvalResult{typ: querypb.Type_UINT64, uval: 1}, + v2: EvalResult{typ: querypb.Type_FLOAT64, fval: 2}, + out: EvalResult{typ: querypb.Type_FLOAT64, fval: 3}, }, { - v1: evalResult{typ: querypb.Type_FLOAT64, fval: 1}, - v2: evalResult{typ: querypb.Type_FLOAT64, fval: 2}, - out: evalResult{typ: querypb.Type_FLOAT64, fval: 3}, + v1: EvalResult{typ: querypb.Type_FLOAT64, fval: 1}, + v2: EvalResult{typ: querypb.Type_FLOAT64, fval: 2}, + out: EvalResult{typ: querypb.Type_FLOAT64, fval: 3}, }, { // Int64 overflow. - v1: evalResult{typ: querypb.Type_INT64, ival: 9223372036854775807}, - v2: evalResult{typ: querypb.Type_INT64, ival: 2}, - out: evalResult{typ: querypb.Type_FLOAT64, fval: 9223372036854775809}, + v1: EvalResult{typ: querypb.Type_INT64, ival: 9223372036854775807}, + v2: EvalResult{typ: querypb.Type_INT64, ival: 2}, + out: EvalResult{typ: querypb.Type_FLOAT64, fval: 9223372036854775809}, }, { // Int64 underflow. - v1: evalResult{typ: querypb.Type_INT64, ival: -9223372036854775807}, - v2: evalResult{typ: querypb.Type_INT64, ival: -2}, - out: evalResult{typ: querypb.Type_FLOAT64, fval: -9223372036854775809}, + v1: EvalResult{typ: querypb.Type_INT64, ival: -9223372036854775807}, + v2: EvalResult{typ: querypb.Type_INT64, ival: -2}, + out: EvalResult{typ: querypb.Type_FLOAT64, fval: -9223372036854775809}, }, { - v1: evalResult{typ: querypb.Type_INT64, ival: -1}, - v2: evalResult{typ: querypb.Type_UINT64, uval: 2}, - out: evalResult{typ: querypb.Type_FLOAT64, fval: 18446744073709551617}, + v1: EvalResult{typ: querypb.Type_INT64, ival: -1}, + v2: EvalResult{typ: querypb.Type_UINT64, uval: 2}, + out: EvalResult{typ: querypb.Type_FLOAT64, fval: 18446744073709551617}, }, { // Uint64 overflow. - v1: evalResult{typ: querypb.Type_UINT64, uval: 18446744073709551615}, - v2: evalResult{typ: querypb.Type_UINT64, uval: 2}, - out: evalResult{typ: querypb.Type_FLOAT64, fval: 18446744073709551617}, + v1: EvalResult{typ: querypb.Type_UINT64, uval: 18446744073709551615}, + v2: EvalResult{typ: querypb.Type_UINT64, uval: 2}, + out: EvalResult{typ: querypb.Type_FLOAT64, fval: 18446744073709551617}, }} for _, tcase := range tcases { got := addNumeric(tcase.v1, tcase.v2) @@ -1042,15 +1042,15 @@ func TestAddNumeric(t *testing.T) { } func TestPrioritize(t *testing.T) { - ival := evalResult{typ: querypb.Type_INT64, ival: -1} - uval := evalResult{typ: querypb.Type_UINT64, uval: 1} - fval := evalResult{typ: querypb.Type_FLOAT64, fval: 1.2} - textIntval := evalResult{typ: querypb.Type_VARBINARY, bytes: []byte("-1")} - textFloatval := evalResult{typ: querypb.Type_VARBINARY, bytes: []byte("1.2")} + ival := EvalResult{typ: querypb.Type_INT64, ival: -1} + uval := EvalResult{typ: querypb.Type_UINT64, uval: 1} + fval := EvalResult{typ: querypb.Type_FLOAT64, fval: 1.2} + textIntval := EvalResult{typ: querypb.Type_VARBINARY, bytes: []byte("-1")} + textFloatval := EvalResult{typ: querypb.Type_VARBINARY, bytes: []byte("1.2")} tcases := []struct { - v1, v2 evalResult - out1, out2 evalResult + v1, v2 EvalResult + out1, out2 EvalResult }{{ v1: ival, v2: uval, @@ -1104,57 +1104,57 @@ func TestPrioritize(t *testing.T) { func TestCastFromNumeric(t *testing.T) { tcases := []struct { typ querypb.Type - v evalResult + v EvalResult out sqltypes.Value err error }{{ typ: querypb.Type_INT64, - v: evalResult{typ: querypb.Type_INT64, ival: 1}, + v: EvalResult{typ: querypb.Type_INT64, ival: 1}, out: sqltypes.NewInt64(1), }, { typ: querypb.Type_INT64, - v: evalResult{typ: querypb.Type_UINT64, uval: 1}, + v: EvalResult{typ: querypb.Type_UINT64, uval: 1}, out: sqltypes.NewInt64(1), }, { typ: querypb.Type_INT64, - v: evalResult{typ: querypb.Type_FLOAT64, fval: 1.2e-16}, + v: EvalResult{typ: querypb.Type_FLOAT64, fval: 1.2e-16}, out: sqltypes.NewInt64(0), }, { typ: querypb.Type_UINT64, - v: evalResult{typ: querypb.Type_INT64, ival: 1}, + v: EvalResult{typ: querypb.Type_INT64, ival: 1}, out: sqltypes.NewUint64(1), }, { typ: querypb.Type_UINT64, - v: evalResult{typ: querypb.Type_UINT64, uval: 1}, + v: EvalResult{typ: querypb.Type_UINT64, uval: 1}, out: sqltypes.NewUint64(1), }, { typ: querypb.Type_UINT64, - v: evalResult{typ: querypb.Type_FLOAT64, fval: 1.2e-16}, + v: EvalResult{typ: querypb.Type_FLOAT64, fval: 1.2e-16}, out: sqltypes.NewUint64(0), }, { typ: querypb.Type_FLOAT64, - v: evalResult{typ: querypb.Type_INT64, ival: 1}, + v: EvalResult{typ: querypb.Type_INT64, ival: 1}, out: sqltypes.TestValue(querypb.Type_FLOAT64, "1"), }, { typ: querypb.Type_FLOAT64, - v: evalResult{typ: querypb.Type_UINT64, uval: 1}, + v: EvalResult{typ: querypb.Type_UINT64, uval: 1}, out: sqltypes.TestValue(querypb.Type_FLOAT64, "1"), }, { typ: querypb.Type_FLOAT64, - v: evalResult{typ: querypb.Type_FLOAT64, fval: 1.2e-16}, + v: EvalResult{typ: querypb.Type_FLOAT64, fval: 1.2e-16}, out: sqltypes.TestValue(querypb.Type_FLOAT64, "1.2e-16"), }, { typ: querypb.Type_DECIMAL, - v: evalResult{typ: querypb.Type_INT64, ival: 1}, + v: EvalResult{typ: querypb.Type_INT64, ival: 1}, out: sqltypes.TestValue(querypb.Type_DECIMAL, "1"), }, { typ: querypb.Type_DECIMAL, - v: evalResult{typ: querypb.Type_UINT64, uval: 1}, + v: EvalResult{typ: querypb.Type_UINT64, uval: 1}, out: sqltypes.TestValue(querypb.Type_DECIMAL, "1"), }, { // For float, we should not use scientific notation. typ: querypb.Type_DECIMAL, - v: evalResult{typ: querypb.Type_FLOAT64, fval: 1.2e-16}, + v: EvalResult{typ: querypb.Type_FLOAT64, fval: 1.2e-16}, out: sqltypes.TestValue(querypb.Type_DECIMAL, "0.00000000000000012"), }} for _, tcase := range tcases { @@ -1168,125 +1168,125 @@ func TestCastFromNumeric(t *testing.T) { func TestCompareNumeric(t *testing.T) { tcases := []struct { - v1, v2 evalResult + v1, v2 EvalResult out int }{{ - v1: evalResult{typ: querypb.Type_INT64, ival: 1}, - v2: evalResult{typ: querypb.Type_INT64, ival: 1}, + v1: EvalResult{typ: querypb.Type_INT64, ival: 1}, + v2: EvalResult{typ: querypb.Type_INT64, ival: 1}, out: 0, }, { - v1: evalResult{typ: querypb.Type_INT64, ival: 1}, - v2: evalResult{typ: querypb.Type_INT64, ival: 2}, + v1: EvalResult{typ: querypb.Type_INT64, ival: 1}, + v2: EvalResult{typ: querypb.Type_INT64, ival: 2}, out: -1, }, { - v1: evalResult{typ: querypb.Type_INT64, ival: 2}, - v2: evalResult{typ: querypb.Type_INT64, ival: 1}, + v1: EvalResult{typ: querypb.Type_INT64, ival: 2}, + v2: EvalResult{typ: querypb.Type_INT64, ival: 1}, out: 1, }, { // Special case. - v1: evalResult{typ: querypb.Type_INT64, ival: -1}, - v2: evalResult{typ: querypb.Type_UINT64, uval: 1}, + v1: EvalResult{typ: querypb.Type_INT64, ival: -1}, + v2: EvalResult{typ: querypb.Type_UINT64, uval: 1}, out: -1, }, { - v1: evalResult{typ: querypb.Type_INT64, ival: 1}, - v2: evalResult{typ: querypb.Type_UINT64, uval: 1}, + v1: EvalResult{typ: querypb.Type_INT64, ival: 1}, + v2: EvalResult{typ: querypb.Type_UINT64, uval: 1}, out: 0, }, { - v1: evalResult{typ: querypb.Type_INT64, ival: 1}, - v2: evalResult{typ: querypb.Type_UINT64, uval: 2}, + v1: EvalResult{typ: querypb.Type_INT64, ival: 1}, + v2: EvalResult{typ: querypb.Type_UINT64, uval: 2}, out: -1, }, { - v1: evalResult{typ: querypb.Type_INT64, ival: 2}, - v2: evalResult{typ: querypb.Type_UINT64, uval: 1}, + v1: EvalResult{typ: querypb.Type_INT64, ival: 2}, + v2: EvalResult{typ: querypb.Type_UINT64, uval: 1}, out: 1, }, { - v1: evalResult{typ: querypb.Type_INT64, ival: 1}, - v2: evalResult{typ: querypb.Type_FLOAT64, fval: 1}, + v1: EvalResult{typ: querypb.Type_INT64, ival: 1}, + v2: EvalResult{typ: querypb.Type_FLOAT64, fval: 1}, out: 0, }, { - v1: evalResult{typ: querypb.Type_INT64, ival: 1}, - v2: evalResult{typ: querypb.Type_FLOAT64, fval: 2}, + v1: EvalResult{typ: querypb.Type_INT64, ival: 1}, + v2: EvalResult{typ: querypb.Type_FLOAT64, fval: 2}, out: -1, }, { - v1: evalResult{typ: querypb.Type_INT64, ival: 2}, - v2: evalResult{typ: querypb.Type_FLOAT64, fval: 1}, + v1: EvalResult{typ: querypb.Type_INT64, ival: 2}, + v2: EvalResult{typ: querypb.Type_FLOAT64, fval: 1}, out: 1, }, { // Special case. - v1: evalResult{typ: querypb.Type_UINT64, uval: 1}, - v2: evalResult{typ: querypb.Type_INT64, ival: -1}, + v1: EvalResult{typ: querypb.Type_UINT64, uval: 1}, + v2: EvalResult{typ: querypb.Type_INT64, ival: -1}, out: 1, }, { - v1: evalResult{typ: querypb.Type_UINT64, uval: 1}, - v2: evalResult{typ: querypb.Type_INT64, ival: 1}, + v1: EvalResult{typ: querypb.Type_UINT64, uval: 1}, + v2: EvalResult{typ: querypb.Type_INT64, ival: 1}, out: 0, }, { - v1: evalResult{typ: querypb.Type_UINT64, uval: 1}, - v2: evalResult{typ: querypb.Type_INT64, ival: 2}, + v1: EvalResult{typ: querypb.Type_UINT64, uval: 1}, + v2: EvalResult{typ: querypb.Type_INT64, ival: 2}, out: -1, }, { - v1: evalResult{typ: querypb.Type_UINT64, uval: 2}, - v2: evalResult{typ: querypb.Type_INT64, ival: 1}, + v1: EvalResult{typ: querypb.Type_UINT64, uval: 2}, + v2: EvalResult{typ: querypb.Type_INT64, ival: 1}, out: 1, }, { - v1: evalResult{typ: querypb.Type_UINT64, uval: 1}, - v2: evalResult{typ: querypb.Type_UINT64, uval: 1}, + v1: EvalResult{typ: querypb.Type_UINT64, uval: 1}, + v2: EvalResult{typ: querypb.Type_UINT64, uval: 1}, out: 0, }, { - v1: evalResult{typ: querypb.Type_UINT64, uval: 1}, - v2: evalResult{typ: querypb.Type_UINT64, uval: 2}, + v1: EvalResult{typ: querypb.Type_UINT64, uval: 1}, + v2: EvalResult{typ: querypb.Type_UINT64, uval: 2}, out: -1, }, { - v1: evalResult{typ: querypb.Type_UINT64, uval: 2}, - v2: evalResult{typ: querypb.Type_UINT64, uval: 1}, + v1: EvalResult{typ: querypb.Type_UINT64, uval: 2}, + v2: EvalResult{typ: querypb.Type_UINT64, uval: 1}, out: 1, }, { - v1: evalResult{typ: querypb.Type_UINT64, uval: 1}, - v2: evalResult{typ: querypb.Type_FLOAT64, fval: 1}, + v1: EvalResult{typ: querypb.Type_UINT64, uval: 1}, + v2: EvalResult{typ: querypb.Type_FLOAT64, fval: 1}, out: 0, }, { - v1: evalResult{typ: querypb.Type_UINT64, uval: 1}, - v2: evalResult{typ: querypb.Type_FLOAT64, fval: 2}, + v1: EvalResult{typ: querypb.Type_UINT64, uval: 1}, + v2: EvalResult{typ: querypb.Type_FLOAT64, fval: 2}, out: -1, }, { - v1: evalResult{typ: querypb.Type_UINT64, uval: 2}, - v2: evalResult{typ: querypb.Type_FLOAT64, fval: 1}, + v1: EvalResult{typ: querypb.Type_UINT64, uval: 2}, + v2: EvalResult{typ: querypb.Type_FLOAT64, fval: 1}, out: 1, }, { - v1: evalResult{typ: querypb.Type_FLOAT64, fval: 1}, - v2: evalResult{typ: querypb.Type_INT64, ival: 1}, + v1: EvalResult{typ: querypb.Type_FLOAT64, fval: 1}, + v2: EvalResult{typ: querypb.Type_INT64, ival: 1}, out: 0, }, { - v1: evalResult{typ: querypb.Type_FLOAT64, fval: 1}, - v2: evalResult{typ: querypb.Type_INT64, ival: 2}, + v1: EvalResult{typ: querypb.Type_FLOAT64, fval: 1}, + v2: EvalResult{typ: querypb.Type_INT64, ival: 2}, out: -1, }, { - v1: evalResult{typ: querypb.Type_FLOAT64, fval: 2}, - v2: evalResult{typ: querypb.Type_INT64, ival: 1}, + v1: EvalResult{typ: querypb.Type_FLOAT64, fval: 2}, + v2: EvalResult{typ: querypb.Type_INT64, ival: 1}, out: 1, }, { - v1: evalResult{typ: querypb.Type_FLOAT64, fval: 1}, - v2: evalResult{typ: querypb.Type_UINT64, uval: 1}, + v1: EvalResult{typ: querypb.Type_FLOAT64, fval: 1}, + v2: EvalResult{typ: querypb.Type_UINT64, uval: 1}, out: 0, }, { - v1: evalResult{typ: querypb.Type_FLOAT64, fval: 1}, - v2: evalResult{typ: querypb.Type_UINT64, uval: 2}, + v1: EvalResult{typ: querypb.Type_FLOAT64, fval: 1}, + v2: EvalResult{typ: querypb.Type_UINT64, uval: 2}, out: -1, }, { - v1: evalResult{typ: querypb.Type_FLOAT64, fval: 2}, - v2: evalResult{typ: querypb.Type_UINT64, uval: 1}, + v1: EvalResult{typ: querypb.Type_FLOAT64, fval: 2}, + v2: EvalResult{typ: querypb.Type_UINT64, uval: 1}, out: 1, }, { - v1: evalResult{typ: querypb.Type_FLOAT64, fval: 1}, - v2: evalResult{typ: querypb.Type_FLOAT64, fval: 1}, + v1: EvalResult{typ: querypb.Type_FLOAT64, fval: 1}, + v2: EvalResult{typ: querypb.Type_FLOAT64, fval: 1}, out: 0, }, { - v1: evalResult{typ: querypb.Type_FLOAT64, fval: 1}, - v2: evalResult{typ: querypb.Type_FLOAT64, fval: 2}, + v1: EvalResult{typ: querypb.Type_FLOAT64, fval: 1}, + v2: EvalResult{typ: querypb.Type_FLOAT64, fval: 2}, out: -1, }, { - v1: evalResult{typ: querypb.Type_FLOAT64, fval: 2}, - v2: evalResult{typ: querypb.Type_FLOAT64, fval: 1}, + v1: EvalResult{typ: querypb.Type_FLOAT64, fval: 2}, + v2: EvalResult{typ: querypb.Type_FLOAT64, fval: 1}, out: 1, }} for _, tcase := range tcases { @@ -1460,8 +1460,8 @@ func BenchmarkAddGoInterface(b *testing.B) { } func BenchmarkAddGoNonInterface(b *testing.B) { - v1 := evalResult{typ: querypb.Type_INT64, ival: 1} - v2 := evalResult{typ: querypb.Type_INT64, ival: 12} + v1 := EvalResult{typ: querypb.Type_INT64, ival: 1} + v2 := EvalResult{typ: querypb.Type_INT64, ival: 12} for i := 0; i < b.N; i++ { if v1.typ != querypb.Type_INT64 { b.Error("type assertion failed") @@ -1469,7 +1469,7 @@ func BenchmarkAddGoNonInterface(b *testing.B) { if v2.typ != querypb.Type_INT64 { b.Error("type assertion failed") } - v1 = evalResult{typ: querypb.Type_INT64, ival: v1.ival + v2.ival} + v1 = EvalResult{typ: querypb.Type_INT64, ival: v1.ival + v2.ival} } } diff --git a/go/vt/vtgate/evalengine/casting.go b/go/vt/vtgate/evalengine/casting.go new file mode 100644 index 00000000000..0ea9bf80bc7 --- /dev/null +++ b/go/vt/vtgate/evalengine/casting.go @@ -0,0 +1,58 @@ +/* +Copyright 2020 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package evalengine + +import ( + "strings" + + "vitess.io/vitess/go/sqltypes" + vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/vterrors" +) + +//ToBooleanStrict is used when the casting to a boolean has to be minimally forgiving, +//such as when assigning to a system variable that is expected to be a boolean +func (e *EvalResult) ToBooleanStrict() (bool, error) { + intToBool := func(i int) (bool, error) { + switch i { + case 0: + return false, nil + case 1: + return true, nil + default: + return false, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%d is not a boolean", i) + } + } + + switch e.typ { + case sqltypes.Int8, sqltypes.Int16, sqltypes.Int32, sqltypes.Int64: + return intToBool(int(e.ival)) + case sqltypes.Uint8, sqltypes.Uint16, sqltypes.Uint32, sqltypes.Uint64: + return intToBool(int(e.uval)) + case sqltypes.VarBinary: + lower := strings.ToLower(string(e.bytes)) + switch lower { + case "on": + return true, nil + case "off": + return false, nil + default: + return false, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "'%s' is not a boolean", lower) + } + } + return false, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "is not a boolean") +} diff --git a/go/vt/vtgate/evalengine/casting_test.go b/go/vt/vtgate/evalengine/casting_test.go new file mode 100644 index 00000000000..d27b773b057 --- /dev/null +++ b/go/vt/vtgate/evalengine/casting_test.go @@ -0,0 +1,84 @@ +/* +Copyright 2020 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package evalengine + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" + "vitess.io/vitess/go/sqltypes" +) + +func TestEvalResultToBooleanStrict(t *testing.T) { + trueValues := []*EvalResult{{ + typ: sqltypes.Int64, + ival: 1, + }, { + typ: sqltypes.Uint64, + uval: 1, + }, { + typ: sqltypes.Int8, + ival: 1, + }} + + falseValues := []*EvalResult{{ + typ: sqltypes.Int64, + ival: 0, + }, { + typ: sqltypes.Uint64, + uval: 0, + }, { + typ: sqltypes.Int8, + uval: 0, + }} + + invalid := []*EvalResult{{ + typ: sqltypes.VarChar, + bytes: []byte("foobar"), + }, { + typ: sqltypes.Float32, + fval: 1, + }, { + typ: sqltypes.Int64, + ival: 12, + }} + + for _, res := range trueValues { + name := res.debugString() + t.Run(fmt.Sprintf("ToBooleanStrict() %s expected true (success)", name), func(t *testing.T) { + result, err := res.ToBooleanStrict() + require.NoError(t, err, name) + require.Equal(t, true, result, name) + }) + } + for _, res := range falseValues { + name := res.debugString() + t.Run(fmt.Sprintf("ToBooleanStrict() %s expected false (success)", name), func(t *testing.T) { + result, err := res.ToBooleanStrict() + require.NoError(t, err, name) + require.Equal(t, false, result, name) + }) + } + for _, res := range invalid { + name := res.debugString() + t.Run(fmt.Sprintf("ToBooleanStrict() %s expected fail", name), func(t *testing.T) { + _, err := res.ToBooleanStrict() + require.Error(t, err) + }) + } +} diff --git a/go/vt/vtgate/evalengine/expressions.go b/go/vt/vtgate/evalengine/expressions.go index 9679a240312..d561b17551f 100644 --- a/go/vt/vtgate/evalengine/expressions.go +++ b/go/vt/vtgate/evalengine/expressions.go @@ -28,7 +28,7 @@ import ( ) type ( - evalResult struct { + EvalResult struct { typ querypb.Type ival int64 uval uint64 @@ -42,9 +42,6 @@ type ( Row []sqltypes.Value } - // EvalResult is used so we don't have to expose all parts of the private struct - EvalResult = evalResult - // Expr is the interface that all evaluating expressions must implement Expr interface { Evaluate(env ExpressionEnv) (EvalResult, error) @@ -80,13 +77,18 @@ func (e EvalResult) Value() sqltypes.Value { return castFromNumeric(e, e.typ) } -//NewLiteralInt returns a literal expression -func NewLiteralInt(val []byte) (Expr, error) { +//NewLiteralIntFromBytes returns a literal expression +func NewLiteralIntFromBytes(val []byte) (Expr, error) { ival, err := strconv.ParseInt(string(val), 10, 64) if err != nil { return nil, err } - return &Literal{evalResult{typ: sqltypes.Int64, ival: ival}}, nil + return NewLiteralInt(ival), nil +} + +//NewLiteralInt returns a literal expression +func NewLiteralInt(i int64) Expr { + return &Literal{EvalResult{typ: sqltypes.Int64, ival: i}} } //NewLiteralFloat returns a literal expression @@ -95,12 +97,24 @@ func NewLiteralFloat(val []byte) (Expr, error) { if err != nil { return nil, err } - return &Literal{evalResult{typ: sqltypes.Float64, fval: fval}}, nil + return &Literal{EvalResult{typ: sqltypes.Float64, fval: fval}}, nil } //NewLiteralFloat returns a literal expression -func NewLiteralString(val []byte) (Expr, error) { - return &Literal{evalResult{typ: sqltypes.VarBinary, bytes: val}}, nil +func NewLiteralString(val []byte) Expr { + return &Literal{EvalResult{typ: sqltypes.VarBinary, bytes: val}} +} + +//NewBindVar returns a bind variable +func NewBindVar(key string) Expr { + return &BindVariable{Key: key} +} + +//NewColumn returns a bind variable +func NewColumn(offset int) Expr { + return &Column{ + Offset: offset, + } } var _ Expr = (*Literal)(nil) @@ -248,7 +262,7 @@ func (l *Literal) String() string { //String implements the Expr interface func (c *Column) String() string { - return fmt.Sprintf("[%d]", c.Offset) + return fmt.Sprintf("column %d from the input", c.Offset) } func mergeNumericalTypes(ltype, rtype querypb.Type) querypb.Type { @@ -272,23 +286,28 @@ func evaluateByType(val *querypb.BindVariable) (EvalResult, error) { if err != nil { ival = 0 } - return evalResult{typ: sqltypes.Int64, ival: ival}, nil + return EvalResult{typ: sqltypes.Int64, ival: ival}, nil case sqltypes.Uint64: uval, err := strconv.ParseUint(string(val.Value), 10, 64) if err != nil { uval = 0 } - return evalResult{typ: sqltypes.Uint64, uval: uval}, nil + return EvalResult{typ: sqltypes.Uint64, uval: uval}, nil case sqltypes.Float64: fval, err := strconv.ParseFloat(string(val.Value), 64) if err != nil { fval = 0 } - return evalResult{typ: sqltypes.Float64, fval: fval}, nil + return EvalResult{typ: sqltypes.Float64, fval: fval}, nil case sqltypes.VarChar, sqltypes.Text, sqltypes.VarBinary: - return evalResult{typ: sqltypes.VarBinary, bytes: val.Value}, nil + return EvalResult{typ: sqltypes.VarBinary, bytes: val.Value}, nil case sqltypes.Null: - return evalResult{typ: sqltypes.Null}, nil + return EvalResult{typ: sqltypes.Null}, nil } - return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "Type is not supported: %s", val.Type.String()) + return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "Type is not supported: %s", val.Type.String()) +} + +// debugString is +func (e *EvalResult) debugString() string { + return fmt.Sprintf("(%s) %d %d %f %s", querypb.Type_name[int32(e.typ)], e.ival, e.uval, e.fval, string(e.bytes)) } diff --git a/go/vt/vtgate/executor.go b/go/vt/vtgate/executor.go index 87e06f36641..22fb205e2dd 100644 --- a/go/vt/vtgate/executor.go +++ b/go/vt/vtgate/executor.go @@ -29,8 +29,6 @@ import ( "sync" "time" - vtgatepb "vitess.io/vitess/go/vt/proto/vtgate" - "golang.org/x/net/context" "vitess.io/vitess/go/trace" @@ -222,7 +220,7 @@ func (e *Executor) legacyExecute(ctx context.Context, safeSession *SafeSession, sqlparser.StmtDelete, sqlparser.StmtDDL, sqlparser.StmtUse, sqlparser.StmtExplain, sqlparser.StmtOther: return 0, nil, vterrors.New(vtrpcpb.Code_INTERNAL, "BUG: not reachable as handled with plan execute") case sqlparser.StmtSet: - qr, err := e.handleSet(ctx, safeSession, sql, logStats) + qr, err := e.handleSet(ctx, sql, logStats) return sqlparser.StmtSet, qr, err case sqlparser.StmtShow: qr, err := e.handleShow(ctx, safeSession, sql, bindVars, dest, destKeyspace, destTabletType, logStats) @@ -272,7 +270,7 @@ func (e *Executor) destinationExec(ctx context.Context, safeSession *SafeSession return e.resolver.Execute(ctx, sql, bindVars, destKeyspace, destTabletType, dest, safeSession, safeSession.Options, logStats, false /* canAutocommit */, ignoreMaxMemoryRows) } -func (e *Executor) handleBegin(ctx context.Context, safeSession *SafeSession, destTabletType topodatapb.TabletType, logStats *LogStats) (*sqltypes.Result, error) { +func (e *Executor) handleBegin(ctx context.Context, safeSession *SafeSession, logStats *LogStats) (*sqltypes.Result, error) { execStart := time.Now() logStats.PlanTime = execStart.Sub(logStats.StartTime) err := e.txConn.Begin(ctx, safeSession) @@ -294,6 +292,11 @@ func (e *Executor) handleCommit(ctx context.Context, safeSession *SafeSession, l return &sqltypes.Result{}, err } +//Commit commits the existing transactions +func (e *Executor) Commit(ctx context.Context, safeSession *SafeSession) error { + return e.txConn.Commit(ctx, safeSession) +} + func (e *Executor) handleRollback(ctx context.Context, safeSession *SafeSession, logStats *LogStats) (*sqltypes.Result, error) { execStart := time.Now() logStats.PlanTime = execStart.Sub(logStats.StartTime) @@ -347,7 +350,7 @@ func (e *Executor) CloseSession(ctx context.Context, safeSession *SafeSession) e return e.txConn.ReleaseAll(ctx, safeSession) } -func (e *Executor) handleSet(ctx context.Context, safeSession *SafeSession, sql string, logStats *LogStats) (*sqltypes.Result, error) { +func (e *Executor) handleSet(ctx context.Context, sql string, logStats *LogStats) (*sqltypes.Result, error) { stmt, err := sqlparser.Parse(sql) if err != nil { return nil, err @@ -382,14 +385,6 @@ func (e *Executor) handleSet(ctx context.Context, safeSession *SafeSession, sql _, out := sqlparser.NewStringTokenizer(expr.Name.Lowered()).Scan() name := string(out) switch expr.Scope { - case sqlparser.GlobalStr: - return nil, vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "unsupported in set: global") - case sqlparser.SessionStr: - value, err = getValueFor(expr) - if err != nil { - return nil, err - } - err = handleSessionSetting(ctx, name, safeSession, value, e.txConn, sql) case sqlparser.VitessMetadataStr: value, err = getValueFor(expr) if err != nil { @@ -400,8 +395,8 @@ func (e *Executor) handleSet(ctx context.Context, safeSession *SafeSession, sql return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "unexpected value type for charset: %v", value) } _, err = e.handleSetVitessMetadata(ctx, name, val) - case "": // we should only get here with UDVs - return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "should have been handled by planning") + default: + return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "should have been handled by planning: %s", sql) } if err != nil { @@ -448,129 +443,6 @@ func getValueFor(expr *sqlparser.SetExpr) (interface{}, error) { default: return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "invalid syntax: %s", sqlparser.String(expr)) } - -} - -func handleSessionSetting(ctx context.Context, name string, session *SafeSession, value interface{}, conn *TxConn, sql string) error { - switch name { - case "autocommit": - val, err := validateSetOnOff(value, name) - if err != nil { - return err - } - - switch val { - case 0: - session.Autocommit = false - case 1: - if session.InTransaction() { - if err := conn.Commit(ctx, session); err != nil { - return err - } - } - session.Autocommit = true - default: - return vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "unexpected value for autocommit: %d", val) - } - case "client_found_rows": - val, ok := value.(int64) - if !ok { - return vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "unexpected value type for client_found_rows: %T", value) - } - if session.Options == nil { - session.Options = &querypb.ExecuteOptions{} - } - switch val { - case 0: - session.Options.ClientFoundRows = false - case 1: - session.Options.ClientFoundRows = true - default: - return vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "unexpected value for client_found_rows: %d", val) - } - case "skip_query_plan_cache": - val, ok := value.(int64) - if !ok { - return vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "unexpected value type for skip_query_plan_cache: %T", value) - } - if session.Options == nil { - session.Options = &querypb.ExecuteOptions{} - } - switch val { - case 0: - session.Options.SkipQueryPlanCache = false - case 1: - session.Options.SkipQueryPlanCache = true - default: - return vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "unexpected value for skip_query_plan_cache: %d", val) - } - case "transaction_mode": - val, ok := value.(string) - if !ok { - return vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "unexpected value type for transaction_mode: %T", value) - } - out, ok := vtgatepb.TransactionMode_value[strings.ToUpper(val)] - if !ok { - return vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "invalid transaction_mode: %s", val) - } - session.TransactionMode = vtgatepb.TransactionMode(out) - case "tx_read_only", "transaction_read_only": // TODO move this to set tx - val, err := validateSetOnOff(value, name) - if err != nil { - return err - } - switch val { - case 0, 1: - // TODO (4127): This is a dangerous NOP. - default: - return vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "unexpected value for %v: %d", name, val) - } - case "workload": - val, ok := value.(string) - if !ok { - return vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "unexpected value type for workload: %T", value) - } - out, ok := querypb.ExecuteOptions_Workload_value[strings.ToUpper(val)] - if !ok { - return vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "invalid workload: %s", val) - } - if session.Options == nil { - session.Options = &querypb.ExecuteOptions{} - } - session.Options.Workload = querypb.ExecuteOptions_Workload(out) - case "sql_select_limit": - var val int64 - - switch cast := value.(type) { - case int64: - val = cast - case string: - if !strings.EqualFold(cast, "default") { - return vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "unexpected string value for sql_select_limit: %v", value) - } - default: - return vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "unexpected value type for sql_select_limit: %T", value) - } - - if session.Options == nil { - session.Options = &querypb.ExecuteOptions{} - } - session.Options.SqlSelectLimit = val - case "charset", "names": - val, ok := value.(string) - if !ok { - return vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "unexpected value type for charset/names: %T", value) - } - switch val { - case "", "utf8", "utf8mb4", "latin1", "default": - break - default: - return vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "unexpected value for charset/names: %v", val) - } - default: - return vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "unsupported construct: %s", sql) - } - return nil } func (e *Executor) handleSetVitessMetadata(ctx context.Context, name, value string) (*sqltypes.Result, error) { @@ -631,26 +503,6 @@ func (e *Executor) handleShowVitessMetadata(ctx context.Context, opt *sqlparser. }, nil } -func validateSetOnOff(v interface{}, typ string) (int64, error) { - var val int64 - switch v := v.(type) { - case int64: - val = v - case string: - lcaseV := strings.ToLower(v) - if lcaseV == "on" { - val = 1 - } else if lcaseV == "off" { - val = 0 - } else { - return -1, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "unexpected value for %s: %s", typ, v) - } - default: - return -1, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "unexpected value type for %s: %T", typ, v) - } - return val, nil -} - func (e *Executor) handleShow(ctx context.Context, safeSession *SafeSession, sql string, bindVars map[string]*querypb.BindVariable, dest key.Destination, destKeyspace string, destTabletType topodatapb.TabletType, logStats *LogStats) (*sqltypes.Result, error) { stmt, err := sqlparser.Parse(sql) if err != nil { diff --git a/go/vt/vtgate/executor_set_test.go b/go/vt/vtgate/executor_set_test.go index 86e6181795c..bde58c4af98 100644 --- a/go/vt/vtgate/executor_set_test.go +++ b/go/vt/vtgate/executor_set_test.go @@ -17,6 +17,7 @@ limitations under the License. package vtgate import ( + "fmt" "testing" querypb "vitess.io/vitess/go/vt/proto/query" @@ -45,23 +46,17 @@ func TestExecutorSet(t *testing.T) { out *vtgatepb.Session err string }{{ - in: "set autocommit = 1, client_found_rows = 1", - out: &vtgatepb.Session{Autocommit: true, Options: &querypb.ExecuteOptions{ClientFoundRows: true}}, - }, { in: "set @@autocommit = true", out: &vtgatepb.Session{Autocommit: true}, + }, { + in: "set autocommit = 1, client_found_rows = 1", + out: &vtgatepb.Session{Autocommit: true, Options: &querypb.ExecuteOptions{ClientFoundRows: true}}, }, { in: "set @@session.autocommit = true", out: &vtgatepb.Session{Autocommit: true}, }, { in: "set @@session.`autocommit` = true", out: &vtgatepb.Session{Autocommit: true}, - }, { - in: "set @@session.'autocommit' = true", - out: &vtgatepb.Session{Autocommit: true}, - }, { - in: "set @@session.\"autocommit\" = true", - out: &vtgatepb.Session{Autocommit: true}, }, { in: "set autocommit = true", out: &vtgatepb.Session{Autocommit: true}, @@ -94,10 +89,10 @@ func TestExecutorSet(t *testing.T) { out: &vtgatepb.Session{}, }, { in: "set AUTOCOMMIT = 'aa'", - err: "unexpected value for autocommit: aa", + err: "System setting 'autocommit' can't be set to this value: 'aa' is not a boolean", }, { in: "set autocommit = 2", - err: "unexpected value for autocommit: 2", + err: "System setting 'autocommit' can't be set to this value: 2 is not a boolean", }, { in: "set client_found_rows = 1", out: &vtgatepb.Session{Autocommit: true, Options: &querypb.ExecuteOptions{ClientFoundRows: true}}, @@ -112,19 +107,19 @@ func TestExecutorSet(t *testing.T) { out: &vtgatepb.Session{Autocommit: true, Options: &querypb.ExecuteOptions{}}, }, { in: "set @@global.client_found_rows = 1", - err: "unsupported in set: global", + err: "unsupported global scope in set: global client_found_rows = 1", }, { in: "set global client_found_rows = 1", - err: "unsupported in set: global", + err: "unsupported global scope in set: global client_found_rows = 1", }, { in: "set global @@session.client_found_rows = 1", err: "cannot use scope and @@", }, { in: "set client_found_rows = 'aa'", - err: "unexpected value type for client_found_rows: string", + err: "System setting 'client_found_rows' can't be set to this value: 'aa' is not a boolean", }, { in: "set client_found_rows = 2", - err: "unexpected value for client_found_rows: 2", + err: "System setting 'client_found_rows' can't be set to this value: 2 is not a boolean", }, { in: "set transaction_mode = 'unspecified'", out: &vtgatepb.Session{Autocommit: true, TransactionMode: vtgatepb.TransactionMode_UNSPECIFIED}, @@ -145,7 +140,7 @@ func TestExecutorSet(t *testing.T) { err: "invalid transaction_mode: aa", }, { in: "set transaction_mode = 1", - err: "unexpected value type for transaction_mode: int64", + err: "unexpected value type for transaction_mode: INT64", }, { in: "set workload = 'unspecified'", out: &vtgatepb.Session{Autocommit: true, Options: &querypb.ExecuteOptions{Workload: querypb.ExecuteOptions_UNSPECIFIED}}, @@ -163,7 +158,7 @@ func TestExecutorSet(t *testing.T) { err: "invalid workload: aa", }, { in: "set workload = 1", - err: "unexpected value type for workload: int64", + err: "unexpected value type for workload: INT64", }, { in: "set transaction_mode = 'twopc', autocommit=1", out: &vtgatepb.Session{Autocommit: true, TransactionMode: vtgatepb.TransactionMode_TWOPC}, @@ -175,13 +170,19 @@ func TestExecutorSet(t *testing.T) { out: &vtgatepb.Session{Autocommit: true, Options: &querypb.ExecuteOptions{SqlSelectLimit: 0}}, }, { in: "set sql_select_limit = 'asdfasfd'", - err: "unexpected string value for sql_select_limit: asdfasfd", + err: "unexpected value type for sql_select_limit: string", }, { in: "set autocommit = 1+1", - err: "invalid syntax: 1 + 1", + err: "System setting 'autocommit' can't be set to this value: 2 is not a boolean", + }, { + in: "set autocommit = 1+0", + out: &vtgatepb.Session{Autocommit: true}, + }, { + in: "set autocommit = default", + out: &vtgatepb.Session{Autocommit: true}, }, { in: "set foo = 1", - err: "unsupported construct: set foo = 1", + err: "unsupported construct in set: session foo = 1", }, { in: "set names utf8", out: &vtgatepb.Session{Autocommit: true}, @@ -205,10 +206,10 @@ func TestExecutorSet(t *testing.T) { out: &vtgatepb.Session{Autocommit: true, Options: &querypb.ExecuteOptions{}}, }, { in: "set tx_read_only = 2", - err: "unexpected value for tx_read_only: 2", + err: "System setting 'tx_read_only' can't be set to this value: 2 is not a boolean", }, { in: "set transaction_read_only = 2", - err: "unexpected value for transaction_read_only: 2", + err: "System setting 'transaction_read_only' can't be set to this value: 2 is not a boolean", }, { in: "set session transaction isolation level repeatable read", out: &vtgatepb.Session{Autocommit: true}, @@ -234,8 +235,8 @@ func TestExecutorSet(t *testing.T) { in: "set session transaction read write", out: &vtgatepb.Session{Autocommit: true}, }} - for _, tcase := range testcases { - t.Run(tcase.in, func(t *testing.T) { + for i, tcase := range testcases { + t.Run(fmt.Sprintf("%d-%s", i, tcase.in), func(t *testing.T) { session := NewSafeSession(&vtgatepb.Session{Autocommit: true}) _, err := executorEnv.Execute(context.Background(), "TestExecute", session, tcase.in, nil) if tcase.err == "" { diff --git a/go/vt/vtgate/plan_execute.go b/go/vt/vtgate/plan_execute.go index f2180a68980..61af9a69286 100644 --- a/go/vt/vtgate/plan_execute.go +++ b/go/vt/vtgate/plan_execute.go @@ -77,7 +77,7 @@ func (e *Executor) newExecute(ctx context.Context, safeSession *SafeSession, sql // will fall through and be handled through planning switch plan.Type { case sqlparser.StmtBegin: - qr, err := e.handleBegin(ctx, safeSession, vcursor.tabletType, logStats) + qr, err := e.handleBegin(ctx, safeSession, logStats) return sqlparser.StmtBegin, qr, err case sqlparser.StmtCommit: qr, err := e.handleCommit(ctx, safeSession, logStats) diff --git a/go/vt/vtgate/planbuilder/expression_converter.go b/go/vt/vtgate/planbuilder/expression_converter.go new file mode 100644 index 00000000000..d882201443e --- /dev/null +++ b/go/vt/vtgate/planbuilder/expression_converter.go @@ -0,0 +1,120 @@ +/* +Copyright 2020 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package planbuilder + +import ( + "fmt" + "strings" + + vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/vtgate/engine" + "vitess.io/vitess/go/vt/vtgate/evalengine" +) + +type expressionConverter struct { + tabletExpressions []sqlparser.Expr +} + +func booleanValues(astExpr sqlparser.Expr) evalengine.Expr { + switch node := astExpr.(type) { + case *sqlparser.SQLVal: + //set autocommit = 'on' + if node.Type == sqlparser.StrVal { + switch strings.ToLower(string(node.Val)) { + case "on": + return ON + case "off": + return OFF + } + } + case *sqlparser.ColName: + //set autocommit = on + if node.Name.AtCount() == sqlparser.NoAt { + switch node.Name.Lowered() { + case "on": + return ON + case "off": + return OFF + } + } + } + return nil +} + +func identifierAsStringValue(astExpr sqlparser.Expr) evalengine.Expr { + colName, isColName := astExpr.(*sqlparser.ColName) + if isColName { + return evalengine.NewLiteralString([]byte(colName.Name.Lowered())) + } + return nil +} + +func (ec *expressionConverter) convert(astExpr sqlparser.Expr, boolean, identifierAsString bool) (evalengine.Expr, error) { + if boolean { + evalExpr := booleanValues(astExpr) + if evalExpr != nil { + return evalExpr, nil + } + } + if identifierAsString { + evalExpr := identifierAsStringValue(astExpr) + if evalExpr != nil { + return evalExpr, nil + } + } + evalExpr, err := sqlparser.Convert(astExpr) + if err != nil { + if err != sqlparser.ErrExprNotSupported { + return nil, err + } + // We have an expression that we can't handle at the vtgate level + if !expressionOkToDelegateToTablet(astExpr) { + // Uh-oh - this expression is not even safe to delegate to the tablet. Give up. + return nil, vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "expression not supported for SET: %s", sqlparser.String(astExpr)) + } + evalExpr = &evalengine.Column{Offset: len(ec.tabletExpressions)} + ec.tabletExpressions = append(ec.tabletExpressions, astExpr) + } + return evalExpr, nil +} + +func (ec *expressionConverter) source(vschema ContextVSchema) (engine.Primitive, error) { + if len(ec.tabletExpressions) == 0 { + return &engine.SingleRow{}, nil + } + ks, dest, err := resolveDestination(vschema) + if err != nil { + return nil, err + } + + var expr []string + for _, e := range ec.tabletExpressions { + expr = append(expr, sqlparser.String(e)) + } + query := fmt.Sprintf("select %s from dual", strings.Join(expr, ",")) + + primitive := &engine.Send{ + Keyspace: ks, + TargetDestination: dest, + Query: query, + IsDML: false, + SingleShardOnly: true, + } + return primitive, nil +} diff --git a/go/vt/vtgate/planbuilder/expression_converter_test.go b/go/vt/vtgate/planbuilder/expression_converter_test.go new file mode 100644 index 00000000000..2ff05d188fa --- /dev/null +++ b/go/vt/vtgate/planbuilder/expression_converter_test.go @@ -0,0 +1,70 @@ +/* +Copyright 2020 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package planbuilder + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vtgate/evalengine" +) + +func e(in ...evalengine.Expr) []evalengine.Expr { + return in +} + +func TestConversion(t *testing.T) { + type queriesWithExpectations struct { + expressionsIn string + expressionsOut []evalengine.Expr + } + + queries := []queriesWithExpectations{{ + expressionsIn: "1", + expressionsOut: e(evalengine.NewLiteralInt(1)), + }, { + expressionsIn: "@@foo", + expressionsOut: e(evalengine.NewColumn(0)), + }} + + for _, tc := range queries { + t.Run(tc.expressionsIn, func(t *testing.T) { + statement, err := sqlparser.Parse("select " + tc.expressionsIn) + require.NoError(t, err) + slct := statement.(*sqlparser.Select) + exprs := extract(slct.SelectExprs) + ec := &expressionConverter{} + var result []evalengine.Expr + for _, expr := range exprs { + evalExpr, err := ec.convert(expr, false, false) + require.NoError(t, err) + result = append(result, evalExpr) + } + require.Equal(t, tc.expressionsOut, result) + }) + } +} + +func extract(in sqlparser.SelectExprs) []sqlparser.Expr { + var result []sqlparser.Expr + for _, expr := range in { + result = append(result, expr.(*sqlparser.AliasedExpr).Expr) + } + return result +} diff --git a/go/vt/vtgate/planbuilder/plan_test.go b/go/vt/vtgate/planbuilder/plan_test.go index bcec982af58..b5d200defec 100644 --- a/go/vt/vtgate/planbuilder/plan_test.go +++ b/go/vt/vtgate/planbuilder/plan_test.go @@ -173,7 +173,6 @@ func TestPlan(t *testing.T) { testFile(t, "memory_sort_cases.txt", testOutputTempDir, vschemaWrapper) testFile(t, "use_cases.txt", testOutputTempDir, vschemaWrapper) testFile(t, "set_cases.txt", testOutputTempDir, vschemaWrapper) - testFile(t, "set_sysvar_cases.txt", testOutputTempDir, vschemaWrapper) testFile(t, "union_cases.txt", testOutputTempDir, vschemaWrapper) testFile(t, "transaction_cases.txt", testOutputTempDir, vschemaWrapper) testFile(t, "lock_cases.txt", testOutputTempDir, vschemaWrapper) diff --git a/go/vt/vtgate/planbuilder/set.go b/go/vt/vtgate/planbuilder/set.go index 952abc77c80..fe66e0d4b0a 100644 --- a/go/vt/vtgate/planbuilder/set.go +++ b/go/vt/vtgate/planbuilder/set.go @@ -20,6 +20,8 @@ import ( "fmt" "strings" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/vt/vtgate/vindexes" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" @@ -28,196 +30,24 @@ import ( "vitess.io/vitess/go/vt/key" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vtgate/engine" - "vitess.io/vitess/go/vt/vtgate/evalengine" ) type ( - planFunc = func(expr *sqlparser.SetExpr, vschema ContextVSchema) (engine.SetOp, error) - - expressionConverter struct { - tabletExpressions []*sqlparser.SetExpr - } + planFunc = func(expr *sqlparser.SetExpr, vschema ContextVSchema, ec *expressionConverter) (engine.SetOp, error) setting struct { - name string - boolean bool + name string + boolean bool + defaultValue evalengine.Expr + + // this allows identifiers (a.k.a. ColName) from the AST to be handled as if they are strings. + // SET transaction_mode = two_pc => SET transaction_mode = 'two_pc' + identifierAsString bool } ) var sysVarPlanningFunc = map[string]planFunc{} -var notSupported = []setting{ - {name: "audit_log_read_buffer_size"}, - {name: "auto_increment_increment"}, - {name: "auto_increment_offset"}, - {name: "binlog_direct_non_transactional_updates"}, - {name: "binlog_row_image"}, - {name: "binlog_rows_query_log_events"}, - {name: "innodb_ft_enable_stopword"}, - {name: "innodb_ft_user_stopword_table"}, - {name: "max_points_in_geometry"}, - {name: "max_sp_recursion_depth"}, - {name: "myisam_repair_threads"}, - {name: "myisam_sort_buffer_size"}, - {name: "myisam_stats_method"}, - {name: "ndb_allow_copying_alter_table"}, - {name: "ndb_autoincrement_prefetch_sz"}, - {name: "ndb_blob_read_batch_bytes"}, - {name: "ndb_blob_write_batch_bytes"}, - {name: "ndb_deferred_constraints"}, - {name: "ndb_force_send"}, - {name: "ndb_fully_replicated"}, - {name: "ndb_index_stat_enable"}, - {name: "ndb_index_stat_option"}, - {name: "ndb_join_pushdown"}, - {name: "ndb_log_bin"}, - {name: "ndb_log_exclusive_reads"}, - {name: "ndb_row_checksum"}, - {name: "ndb_use_exact_count"}, - {name: "ndb_use_transactions"}, - {name: "ndbinfo_max_bytes"}, - {name: "ndbinfo_max_rows"}, - {name: "ndbinfo_show_hidden"}, - {name: "ndbinfo_table_prefix"}, - {name: "old_alter_table"}, - {name: "preload_buffer_size"}, - {name: "rbr_exec_mode"}, - {name: "sql_log_off"}, - {name: "thread_pool_high_priority_connection"}, - {name: "thread_pool_prio_kickup_timer"}, - {name: "transaction_write_set_extraction"}, -} - -var ignoreThese = []setting{ - {name: "big_tables", boolean: true}, - {name: "bulk_insert_buffer_size"}, - {name: "debug"}, - {name: "default_storage_engine"}, - {name: "default_tmp_storage_engine"}, - {name: "innodb_strict_mode", boolean: true}, - {name: "innodb_support_xa", boolean: true}, - {name: "innodb_table_locks", boolean: true}, - {name: "innodb_tmpdir"}, - {name: "join_buffer_size"}, - {name: "keep_files_on_create", boolean: true}, - {name: "lc_messages"}, - {name: "long_query_time"}, - {name: "low_priority_updates", boolean: true}, - {name: "max_delayed_threads"}, - {name: "max_insert_delayed_threads"}, - {name: "multi_range_count"}, - {name: "net_buffer_length"}, - {name: "new", boolean: true}, - {name: "query_cache_type"}, - {name: "query_cache_wlock_invalidate", boolean: true}, - {name: "query_prealloc_size"}, - {name: "sql_buffer_result", boolean: true}, - {name: "transaction_alloc_block_size"}, - {name: "wait_timeout"}, -} - -var useReservedConn = []setting{ - {name: "default_week_format"}, - {name: "end_markers_in_json", boolean: true}, - {name: "eq_range_index_dive_limit"}, - {name: "explicit_defaults_for_timestamp"}, - {name: "foreign_key_checks", boolean: true}, - {name: "group_concat_max_len"}, - {name: "max_heap_table_size"}, - {name: "max_seeks_for_key"}, - {name: "max_tmp_tables"}, - {name: "min_examined_row_limit"}, - {name: "old_passwords"}, - {name: "optimizer_prune_level"}, - {name: "optimizer_search_depth"}, - {name: "optimizer_switch"}, - {name: "optimizer_trace"}, - {name: "optimizer_trace_features"}, - {name: "optimizer_trace_limit"}, - {name: "optimizer_trace_max_mem_size"}, - {name: "transaction_isolation"}, - {name: "tx_isolation"}, - {name: "optimizer_trace_offset"}, - {name: "parser_max_mem_size"}, - {name: "profiling", boolean: true}, - {name: "profiling_history_size"}, - {name: "query_alloc_block_size"}, - {name: "range_alloc_block_size"}, - {name: "range_optimizer_max_mem_size"}, - {name: "read_buffer_size"}, - {name: "read_rnd_buffer_size"}, - {name: "show_create_table_verbosity", boolean: true}, - {name: "show_old_temporals", boolean: true}, - {name: "sort_buffer_size"}, - {name: "sql_big_selects", boolean: true}, - {name: "sql_mode"}, - {name: "sql_notes", boolean: true}, - {name: "sql_quote_show_create", boolean: true}, - {name: "sql_safe_updates", boolean: true}, - {name: "sql_warnings", boolean: true}, - {name: "tmp_table_size"}, - {name: "transaction_prealloc_size"}, - {name: "unique_checks", boolean: true}, - {name: "updatable_views_with_limit", boolean: true}, -} - -// TODO: Most of these settings should be moved into SysSetOpAware, and change Vitess behaviour. -// Until then, SET statements against these settings are allowed -// as long as they have the same value as the underlying database -var checkAndIgnore = []setting{ - {name: "binlog_format"}, - {name: "block_encryption_mode"}, - {name: "character_set_client"}, - {name: "character_set_connection"}, - {name: "character_set_database"}, - {name: "character_set_filesystem"}, - {name: "character_set_results"}, - {name: "character_set_server"}, - {name: "collation_connection"}, - {name: "collation_database"}, - {name: "collation_server"}, - {name: "completion_type"}, - {name: "div_precision_increment"}, - {name: "innodb_lock_wait_timeout"}, - {name: "interactive_timeout"}, - {name: "lc_time_names"}, - {name: "lock_wait_timeout"}, - {name: "max_allowed_packet"}, - {name: "max_error_count"}, - {name: "max_execution_time"}, - {name: "max_join_size"}, - {name: "max_length_for_sort_data"}, - {name: "max_sort_length"}, - {name: "max_user_connections"}, - {name: "net_read_timeout"}, - {name: "net_retry_count"}, - {name: "net_write_timeout"}, - {name: "session_track_gtids"}, - {name: "session_track_schema", boolean: true}, - {name: "session_track_state_change", boolean: true}, - {name: "session_track_system_variables"}, - {name: "session_track_transaction_info"}, - {name: "sql_auto_is_null", boolean: true}, - {name: "time_zone"}, - {name: "version_tokens_session"}, -} - -func init() { - forSettings(ignoreThese, buildSetOpIgnore) - forSettings(useReservedConn, buildSetOpVarSet) - forSettings(checkAndIgnore, buildSetOpCheckAndIgnore) - forSettings(notSupported, buildNotSupported) -} - -func forSettings(settings []setting, f func(bool) planFunc) { - for _, setting := range settings { - if _, alreadyExists := sysVarPlanningFunc[setting.name]; alreadyExists { - panic("bug in set plan init - " + setting.name + " aleady configured") - } - sysVarPlanningFunc[setting.name] = f(setting.boolean) - } -} - func buildSetPlan(stmt *sqlparser.Set, vschema ContextVSchema) (engine.Primitive, error) { var setOps []engine.SetOp var setOp engine.SetOp @@ -228,13 +58,13 @@ func buildSetPlan(stmt *sqlparser.Set, vschema ContextVSchema) (engine.Primitive for _, expr := range stmt.Exprs { switch expr.Scope { case sqlparser.GlobalStr: - return nil, vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "unsupported in set: global") + return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "unsupported global scope in set: %s", sqlparser.String(expr)) // AST struct has been prepared before getting here, so no scope here means that // we have a UDV. If the original query didn't explicitly specify the scope, it // would have been explictly set to sqlparser.SessionStr before reaching this // phase of planning case "": - evalExpr, err := ec.convert(expr) + evalExpr, err := ec.convert(expr.Expr /*boolean*/, false /*identifierAsString*/, false) if err != nil { return nil, err } @@ -247,9 +77,9 @@ func buildSetPlan(stmt *sqlparser.Set, vschema ContextVSchema) (engine.Primitive case sqlparser.SessionStr: planFunc, ok := sysVarPlanningFunc[expr.Name.Lowered()] if !ok { - return nil, ErrPlanNotSupported + return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "unsupported construct in set: %s", sqlparser.String(expr)) } - setOp, err = planFunc(expr, vschema) + setOp, err = planFunc(expr, vschema, ec) if err != nil { return nil, err } @@ -270,67 +100,24 @@ func buildSetPlan(stmt *sqlparser.Set, vschema ContextVSchema) (engine.Primitive }, nil } -func (spb *expressionConverter) convert(setExpr *sqlparser.SetExpr) (evalengine.Expr, error) { - astExpr := setExpr.Expr - evalExpr, err := sqlparser.Convert(astExpr) - if err != nil { - if err != sqlparser.ExprNotSupported { - return nil, err - } - // We have an expression that we can't handle at the vtgate level - if !expressionOkToDelegateToTablet(astExpr) { - // Uh-oh - this expression is not even safe to delegate to the tablet. Give up. - return nil, vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "expression not supported for SET: %s", sqlparser.String(astExpr)) - } - evalExpr = &evalengine.Column{Offset: len(spb.tabletExpressions)} - spb.tabletExpressions = append(spb.tabletExpressions, setExpr) - } - return evalExpr, nil -} - -func (spb *expressionConverter) source(vschema ContextVSchema) (engine.Primitive, error) { - if len(spb.tabletExpressions) == 0 { - return &engine.SingleRow{}, nil - } - ks, dest, err := resolveDestination(vschema) - if err != nil { - return nil, err - } - - var expr []string - for _, e := range spb.tabletExpressions { - expr = append(expr, sqlparser.String(e.Expr)) - } - query := fmt.Sprintf("select %s from dual", strings.Join(expr, ",")) - - primitive := &engine.Send{ - Keyspace: ks, - TargetDestination: dest, - Query: query, - IsDML: false, - SingleShardOnly: true, - } - return primitive, nil -} - -func buildNotSupported(bool) func(*sqlparser.SetExpr, ContextVSchema) (engine.SetOp, error) { - return func(expr *sqlparser.SetExpr, schema ContextVSchema) (engine.SetOp, error) { +func buildNotSupported(setting) planFunc { + return func(expr *sqlparser.SetExpr, schema ContextVSchema, _ *expressionConverter) (engine.SetOp, error) { return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%s: system setting is not supported", expr.Name) } } -func buildSetOpIgnore(boolean bool) func(*sqlparser.SetExpr, ContextVSchema) (engine.SetOp, error) { - return func(expr *sqlparser.SetExpr, _ ContextVSchema) (engine.SetOp, error) { +func buildSetOpIgnore(s setting) planFunc { + return func(expr *sqlparser.SetExpr, vschema ContextVSchema, _ *expressionConverter) (engine.SetOp, error) { return &engine.SysVarIgnore{ Name: expr.Name.Lowered(), - Expr: extractValue(expr, boolean), + Expr: extractValue(expr, s.boolean), }, nil } } -func buildSetOpCheckAndIgnore(boolean bool) func(*sqlparser.SetExpr, ContextVSchema) (engine.SetOp, error) { - return func(expr *sqlparser.SetExpr, schema ContextVSchema) (engine.SetOp, error) { - return planSysVarCheckIgnore(expr, schema, boolean) +func buildSetOpCheckAndIgnore(s setting) planFunc { + return func(expr *sqlparser.SetExpr, schema ContextVSchema, _ *expressionConverter) (engine.SetOp, error) { + return planSysVarCheckIgnore(expr, schema, s.boolean) } } @@ -359,19 +146,20 @@ func expressionOkToDelegateToTablet(e sqlparser.Expr) bool { _, ok := validFuncs[n.Name.Lowered()] valid = ok return ok + case *sqlparser.ColName: + valid = n.Name.AtCount() == 2 + return false } return true }) return valid } -func buildSetOpVarSet(boolean bool) func(*sqlparser.SetExpr, ContextVSchema) (engine.SetOp, error) { - - return func(expr *sqlparser.SetExpr, vschema ContextVSchema) (engine.SetOp, error) { +func buildSetOpVarSet(s setting) planFunc { + return func(expr *sqlparser.SetExpr, vschema ContextVSchema, _ *expressionConverter) (engine.SetOp, error) { if !vschema.SysVarSetEnabled() { - return planSysVarCheckIgnore(expr, vschema, boolean) + return planSysVarCheckIgnore(expr, vschema, s.boolean) } - ks, err := vschema.AnyKeyspace() if err != nil { return nil, err @@ -381,7 +169,32 @@ func buildSetOpVarSet(boolean bool) func(*sqlparser.SetExpr, ContextVSchema) (en Name: expr.Name.Lowered(), Keyspace: ks, TargetDestination: vschema.Destination(), - Expr: extractValue(expr, boolean), + Expr: extractValue(expr, s.boolean), + }, nil + } +} + +func buildSetOpVitessAware(s setting) planFunc { + return func(astExpr *sqlparser.SetExpr, vschema ContextVSchema, ec *expressionConverter) (engine.SetOp, error) { + var err error + var runtimeExpr evalengine.Expr + + _, isDefault := astExpr.Expr.(*sqlparser.Default) + if isDefault { + if s.defaultValue == nil { + return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "don't know default value for %s", astExpr.Name) + } + runtimeExpr = s.defaultValue + } else { + runtimeExpr, err = ec.convert(astExpr.Expr, s.boolean, s.identifierAsString) + if err != nil { + return nil, err + } + } + + return &engine.SysVarSetAware{ + Name: astExpr.Name.Lowered(), + Expr: runtimeExpr, }, nil } } diff --git a/go/vt/vtgate/planbuilder/system_settings.go b/go/vt/vtgate/planbuilder/system_settings.go new file mode 100644 index 00000000000..9cc04698f22 --- /dev/null +++ b/go/vt/vtgate/planbuilder/system_settings.go @@ -0,0 +1,213 @@ +/* +Copyright 2020 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package planbuilder + +import ( + "vitess.io/vitess/go/vt/vtgate/engine" + "vitess.io/vitess/go/vt/vtgate/evalengine" +) + +func init() { + forSettings(ignoreThese, buildSetOpIgnore) + forSettings(useReservedConn, buildSetOpVarSet) + forSettings(checkAndIgnore, buildSetOpCheckAndIgnore) + forSettings(notSupported, buildNotSupported) + forSettings(vitessAware, buildSetOpVitessAware) +} + +func forSettings(settings []setting, f func(s setting) planFunc) { + for _, setting := range settings { + if _, alreadyExists := sysVarPlanningFunc[setting.name]; alreadyExists { + panic("bug in set plan init - " + setting.name + " aleady configured") + } + sysVarPlanningFunc[setting.name] = f(setting) + } +} + +var ( + ON = evalengine.NewLiteralInt(1) + OFF = evalengine.NewLiteralInt(0) + + vitessAware = []setting{ + {name: engine.Autocommit, boolean: true, defaultValue: ON}, + {name: engine.ClientFoundRows, boolean: true, defaultValue: OFF}, + {name: engine.SkipQueryPlanCache, boolean: true, defaultValue: OFF}, + {name: engine.TransactionReadOnly, boolean: true, defaultValue: OFF}, + {name: engine.TxReadOnly, boolean: true, defaultValue: OFF}, + {name: engine.SQLSelectLimit, defaultValue: OFF}, + {name: engine.TransactionMode, identifierAsString: true, defaultValue: evalengine.NewLiteralString([]byte("MULTI"))}, + {name: engine.Workload, identifierAsString: true, defaultValue: evalengine.NewLiteralString([]byte("UNSPECIFIED"))}, + {name: engine.Charset, identifierAsString: true, defaultValue: evalengine.NewLiteralString([]byte("utf8"))}, + {name: engine.Names, identifierAsString: true, defaultValue: evalengine.NewLiteralString([]byte("utf8"))}, + } + + notSupported = []setting{ + {name: "audit_log_read_buffer_size"}, + {name: "auto_increment_increment"}, + {name: "auto_increment_offset"}, + {name: "binlog_direct_non_transactional_updates"}, + {name: "binlog_row_image"}, + {name: "binlog_rows_query_log_events"}, + {name: "innodb_ft_enable_stopword"}, + {name: "innodb_ft_user_stopword_table"}, + {name: "max_points_in_geometry"}, + {name: "max_sp_recursion_depth"}, + {name: "myisam_repair_threads"}, + {name: "myisam_sort_buffer_size"}, + {name: "myisam_stats_method"}, + {name: "ndb_allow_copying_alter_table"}, + {name: "ndb_autoincrement_prefetch_sz"}, + {name: "ndb_blob_read_batch_bytes"}, + {name: "ndb_blob_write_batch_bytes"}, + {name: "ndb_deferred_constraints"}, + {name: "ndb_force_send"}, + {name: "ndb_fully_replicated"}, + {name: "ndb_index_stat_enable"}, + {name: "ndb_index_stat_option"}, + {name: "ndb_join_pushdown"}, + {name: "ndb_log_bin"}, + {name: "ndb_log_exclusive_reads"}, + {name: "ndb_row_checksum"}, + {name: "ndb_use_exact_count"}, + {name: "ndb_use_transactions"}, + {name: "ndbinfo_max_bytes"}, + {name: "ndbinfo_max_rows"}, + {name: "ndbinfo_show_hidden"}, + {name: "ndbinfo_table_prefix"}, + {name: "old_alter_table"}, + {name: "preload_buffer_size"}, + {name: "rbr_exec_mode"}, + {name: "sql_log_off"}, + {name: "thread_pool_high_priority_connection"}, + {name: "thread_pool_prio_kickup_timer"}, + {name: "transaction_write_set_extraction"}, + } + + ignoreThese = []setting{ + {name: "big_tables", boolean: true}, + {name: "bulk_insert_buffer_size"}, + {name: "debug"}, + {name: "default_storage_engine"}, + {name: "default_tmp_storage_engine"}, + {name: "innodb_strict_mode", boolean: true}, + {name: "innodb_support_xa", boolean: true}, + {name: "innodb_table_locks", boolean: true}, + {name: "innodb_tmpdir"}, + {name: "join_buffer_size"}, + {name: "keep_files_on_create", boolean: true}, + {name: "lc_messages"}, + {name: "long_query_time"}, + {name: "low_priority_updates", boolean: true}, + {name: "max_delayed_threads"}, + {name: "max_insert_delayed_threads"}, + {name: "multi_range_count"}, + {name: "net_buffer_length"}, + {name: "new", boolean: true}, + {name: "query_cache_type"}, + {name: "query_cache_wlock_invalidate", boolean: true}, + {name: "query_prealloc_size"}, + {name: "sql_buffer_result", boolean: true}, + {name: "transaction_alloc_block_size"}, + {name: "wait_timeout"}, + } + + useReservedConn = []setting{ + {name: "default_week_format"}, + {name: "end_markers_in_json", boolean: true}, + {name: "eq_range_index_dive_limit"}, + {name: "explicit_defaults_for_timestamp"}, + {name: "foreign_key_checks", boolean: true}, + {name: "group_concat_max_len"}, + {name: "max_heap_table_size"}, + {name: "max_seeks_for_key"}, + {name: "max_tmp_tables"}, + {name: "min_examined_row_limit"}, + {name: "old_passwords"}, + {name: "optimizer_prune_level"}, + {name: "optimizer_search_depth"}, + {name: "optimizer_switch"}, + {name: "optimizer_trace"}, + {name: "optimizer_trace_features"}, + {name: "optimizer_trace_limit"}, + {name: "optimizer_trace_max_mem_size"}, + {name: "transaction_isolation"}, + {name: "tx_isolation"}, + {name: "optimizer_trace_offset"}, + {name: "parser_max_mem_size"}, + {name: "profiling", boolean: true}, + {name: "profiling_history_size"}, + {name: "query_alloc_block_size"}, + {name: "range_alloc_block_size"}, + {name: "range_optimizer_max_mem_size"}, + {name: "read_buffer_size"}, + {name: "read_rnd_buffer_size"}, + {name: "show_create_table_verbosity", boolean: true}, + {name: "show_old_temporals", boolean: true}, + {name: "sort_buffer_size"}, + {name: "sql_big_selects", boolean: true}, + {name: "sql_mode"}, + {name: "sql_notes", boolean: true}, + {name: "sql_quote_show_create", boolean: true}, + {name: "sql_safe_updates", boolean: true}, + {name: "sql_warnings", boolean: true}, + {name: "tmp_table_size"}, + {name: "transaction_prealloc_size"}, + {name: "unique_checks", boolean: true}, + {name: "updatable_views_with_limit", boolean: true}, + } + + // TODO: Most of these settings should be moved into SysSetOpAware, and change Vitess behaviour. + // Until then, SET statements against these settings are allowed + // as long as they have the same value as the underlying database + checkAndIgnore = []setting{ + {name: "binlog_format"}, + {name: "block_encryption_mode"}, + {name: "character_set_client"}, + {name: "character_set_connection"}, + {name: "character_set_database"}, + {name: "character_set_filesystem"}, + {name: "character_set_results"}, + {name: "character_set_server"}, + {name: "collation_connection"}, + {name: "collation_database"}, + {name: "collation_server"}, + {name: "completion_type"}, + {name: "div_precision_increment"}, + {name: "innodb_lock_wait_timeout"}, + {name: "interactive_timeout"}, + {name: "lc_time_names"}, + {name: "lock_wait_timeout"}, + {name: "max_allowed_packet"}, + {name: "max_error_count"}, + {name: "max_execution_time"}, + {name: "max_join_size"}, + {name: "max_length_for_sort_data"}, + {name: "max_sort_length"}, + {name: "max_user_connections"}, + {name: "net_read_timeout"}, + {name: "net_retry_count"}, + {name: "net_write_timeout"}, + {name: "session_track_gtids"}, + {name: "session_track_schema", boolean: true}, + {name: "session_track_state_change", boolean: true}, + {name: "session_track_system_variables"}, + {name: "session_track_transaction_info"}, + {name: "sql_auto_is_null", boolean: true}, + {name: "time_zone"}, + {name: "version_tokens_session"}, + } +) diff --git a/go/vt/vtgate/planbuilder/testdata/set_cases.txt b/go/vt/vtgate/planbuilder/testdata/set_cases.txt index 4b2372b5bcf..6f99d615d57 100644 --- a/go/vt/vtgate/planbuilder/testdata/set_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/set_cases.txt @@ -85,7 +85,7 @@ { "Type": "UserDefinedVariable", "Name": "foo", - "Expr": "[0]" + "Expr": "column 0 from the input" } ], "Inputs": [ @@ -168,3 +168,293 @@ ] } } + +# autocommit case +"SET autocommit = 1, autocommit = on, autocommit = 'on', autocommit = @myudv, autocommit = `on`, autocommit = `off`" +{ + "QueryType": "SET", + "Original": "SET autocommit = 1, autocommit = on, autocommit = 'on', autocommit = @myudv, autocommit = `on`, autocommit = `off`", + "Instructions": { + "OperatorType": "Set", + "Ops": [ + { + "Type": "SysVarAware", + "Name": "autocommit", + "Expr": "INT64(1)" + }, + { + "Type": "SysVarAware", + "Name": "autocommit", + "Expr": "INT64(1)" + }, + { + "Type": "SysVarAware", + "Name": "autocommit", + "Expr": "INT64(1)" + }, + { + "Type": "SysVarAware", + "Name": "autocommit", + "Expr": ":__vtudvmyudv" + }, + { + "Type": "SysVarAware", + "Name": "autocommit", + "Expr": "INT64(1)" + }, + { + "Type": "SysVarAware", + "Name": "autocommit", + "Expr": "INT64(0)" + } + ], + "Inputs": [ + { + "OperatorType": "SingleRow" + } + ] + } +} + +# set ignore plan +"set @@default_storage_engine = 'DONOTCHANGEME'" +{ + "QueryType": "SET", + "Original": "set @@default_storage_engine = 'DONOTCHANGEME'", + "Instructions": { + "OperatorType": "Set", + "Ops": [ + { + "Type": "SysVarIgnore", + "Name": "default_storage_engine", + "Expr": "'DONOTCHANGEME'" + } + ], + "Inputs": [ + { + "OperatorType": "SingleRow" + } + ] + } +} + +# set check and ignore plan +"set @@sql_mode = concat(@@sql_mode, ',NO_AUTO_CREATE_USER')" +{ + "QueryType": "SET", + "Original": "set @@sql_mode = concat(@@sql_mode, ',NO_AUTO_CREATE_USER')", + "Instructions": { + "OperatorType": "Set", + "Ops": [ + { + "Type": "SysVarSet", + "Name": "sql_mode", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "Expr": "concat(@@sql_mode, ',NO_AUTO_CREATE_USER')" + } + ], + "Inputs": [ + { + "OperatorType": "SingleRow" + } + ] + } +} + +# set system settings +"set @@sql_safe_updates = 1" +{ + "QueryType": "SET", + "Original": "set @@sql_safe_updates = 1", + "Instructions": { + "OperatorType": "Set", + "Ops": [ + { + "Type": "SysVarSet", + "Name": "sql_safe_updates", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "Expr": "1" + } + ], + "Inputs": [ + { + "OperatorType": "SingleRow" + } + ] + } +} + +# set plan building with ON/OFF enum +"set @@innodb_strict_mode = OFF" +{ + "QueryType": "SET", + "Original": "set @@innodb_strict_mode = OFF", + "Instructions": { + "OperatorType": "Set", + "Ops": [ + { + "Type": "SysVarIgnore", + "Name": "innodb_strict_mode", + "Expr": "0" + } + ], + "Inputs": [ + { + "OperatorType": "SingleRow" + } + ] + } +} + +# set plan building with string literal +"set @@innodb_strict_mode = 'OFF'" +{ + "QueryType": "SET", + "Original": "set @@innodb_strict_mode = 'OFF'", + "Instructions": { + "OperatorType": "Set", + "Ops": [ + { + "Type": "SysVarIgnore", + "Name": "innodb_strict_mode", + "Expr": "0" + } + ], + "Inputs": [ + { + "OperatorType": "SingleRow" + } + ] + } +} + +# set plan building with string literal +"set @@innodb_tmpdir = 'OFF'" +{ + "QueryType": "SET", + "Original": "set @@innodb_tmpdir = 'OFF'", + "Instructions": { + "OperatorType": "Set", + "Ops": [ + { + "Type": "SysVarIgnore", + "Name": "innodb_tmpdir", + "Expr": "'OFF'" + } + ], + "Inputs": [ + { + "OperatorType": "SingleRow" + } + ] + } +} + +# set system settings +"set @@ndbinfo_max_bytes = 192" +"ndbinfo_max_bytes: system setting is not supported" + +# set autocommit +"set autocommit = 1" +{ + "QueryType": "SET", + "Original": "set autocommit = 1", + "Instructions": { + "OperatorType": "Set", + "Ops": [ + { + "Type": "SysVarAware", + "Name": "autocommit", + "Expr": "INT64(1)" + } + ], + "Inputs": [ + { + "OperatorType": "SingleRow" + } + ] + } +} + +# set autocommit false +"set autocommit = 0" +{ + "QueryType": "SET", + "Original": "set autocommit = 0", + "Instructions": { + "OperatorType": "Set", + "Ops": [ + { + "Type": "SysVarAware", + "Name": "autocommit", + "Expr": "INT64(0)" + } + ], + "Inputs": [ + { + "OperatorType": "SingleRow" + } + ] + } +} + +# set autocommit with backticks +"set @@session.`autocommit` = 0" +{ + "QueryType": "SET", + "Original": "set @@session.`autocommit` = 0", + "Instructions": { + "OperatorType": "Set", + "Ops": [ + { + "Type": "SysVarAware", + "Name": "autocommit", + "Expr": "INT64(0)" + } + ], + "Inputs": [ + { + "OperatorType": "SingleRow" + } + ] + } +} + +# more vitess aware settings +"set client_found_rows = off, skip_query_plan_cache = ON, sql_select_limit=20" +{ + "QueryType": "SET", + "Original": "set client_found_rows = off, skip_query_plan_cache = ON, sql_select_limit=20", + "Instructions": { + "OperatorType": "Set", + "Ops": [ + { + "Type": "SysVarAware", + "Name": "client_found_rows", + "Expr": "INT64(0)" + }, + { + "Type": "SysVarAware", + "Name": "skip_query_plan_cache", + "Expr": "INT64(1)" + }, + { + "Type": "SysVarAware", + "Name": "sql_select_limit", + "Expr": "INT64(20)" + } + ], + "Inputs": [ + { + "OperatorType": "SingleRow" + } + ] + } +} + \ No newline at end of file diff --git a/go/vt/vtgate/planbuilder/testdata/set_sysvar_cases.txt b/go/vt/vtgate/planbuilder/testdata/set_sysvar_cases.txt deleted file mode 100644 index c580cb86f5c..00000000000 --- a/go/vt/vtgate/planbuilder/testdata/set_sysvar_cases.txt +++ /dev/null @@ -1,143 +0,0 @@ -# set ignore plan -"set @@default_storage_engine = 'DONOTCHANGEME'" -{ - "QueryType": "SET", - "Original": "set @@default_storage_engine = 'DONOTCHANGEME'", - "Instructions": { - "OperatorType": "Set", - "Ops": [ - { - "Type": "SysVarIgnore", - "Name": "default_storage_engine", - "Expr": "'DONOTCHANGEME'" - } - ], - "Inputs": [ - { - "OperatorType": "SingleRow" - } - ] - } -} - -# set check and ignore plan -"set @@sql_mode = concat(@@sql_mode, ',NO_AUTO_CREATE_USER')" -{ - "QueryType": "SET", - "Original": "set @@sql_mode = concat(@@sql_mode, ',NO_AUTO_CREATE_USER')", - "Instructions": { - "OperatorType": "Set", - "Ops": [ - { - "Type": "SysVarSet", - "Name": "sql_mode", - "Keyspace": { - "Name": "main", - "Sharded": false - }, - "Expr": "concat(@@sql_mode, ',NO_AUTO_CREATE_USER')" - } - ], - "Inputs": [ - { - "OperatorType": "SingleRow" - } - ] - } -} - -# set system settings -"set @@sql_safe_updates = 1" -{ - "QueryType": "SET", - "Original": "set @@sql_safe_updates = 1", - "Instructions": { - "OperatorType": "Set", - "Ops": [ - { - "Type": "SysVarSet", - "Name": "sql_safe_updates", - "Keyspace": { - "Name": "main", - "Sharded": false - }, - "Expr": "1" - } - ], - "Inputs": [ - { - "OperatorType": "SingleRow" - } - ] - } -} - -# set plan building with ON/OFF enum -"set @@innodb_strict_mode = OFF" -{ - "QueryType": "SET", - "Original": "set @@innodb_strict_mode = OFF", - "Instructions": { - "OperatorType": "Set", - "Ops": [ - { - "Type": "SysVarIgnore", - "Name": "innodb_strict_mode", - "Expr": "0" - } - ], - "Inputs": [ - { - "OperatorType": "SingleRow" - } - ] - } -} - -# set plan building with string literal -"set @@innodb_strict_mode = 'OFF'" -{ - "QueryType": "SET", - "Original": "set @@innodb_strict_mode = 'OFF'", - "Instructions": { - "OperatorType": "Set", - "Ops": [ - { - "Type": "SysVarIgnore", - "Name": "innodb_strict_mode", - "Expr": "0" - } - ], - "Inputs": [ - { - "OperatorType": "SingleRow" - } - ] - } -} - -# set plan building with string literal -"set @@innodb_tmpdir = 'OFF'" -{ - "QueryType": "SET", - "Original": "set @@innodb_tmpdir = 'OFF'", - "Instructions": { - "OperatorType": "Set", - "Ops": [ - { - "Type": "SysVarIgnore", - "Name": "innodb_tmpdir", - "Expr": "'OFF'" - } - ], - "Inputs": [ - { - "OperatorType": "SingleRow" - } - ] - } -} - -# set system settings -"set @@ndbinfo_max_bytes = 192" -"ndbinfo_max_bytes: system setting is not supported" diff --git a/go/vt/vtgate/planbuilder/testdata/unsupported_cases.txt b/go/vt/vtgate/planbuilder/testdata/unsupported_cases.txt index 3dc5f49dfe4..7dd4584e1f5 100644 --- a/go/vt/vtgate/planbuilder/testdata/unsupported_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/unsupported_cases.txt @@ -2,14 +2,6 @@ "select * from user union select * from user_extra" "unsupported: UNION cannot be executed as a single route" -# SET -"set a=1" -"plan building not supported" - -# set user defined and system variable -"set @foo = 42, @bar = @foo, @@xyz = 28800" -"plan building not supported" - # SHOW "show create database" "plan building not supported" diff --git a/go/vt/vtgate/safe_session.go b/go/vt/vtgate/safe_session.go index 3ff802463be..c33d0375b6b 100644 --- a/go/vt/vtgate/safe_session.go +++ b/go/vt/vtgate/safe_session.go @@ -444,3 +444,11 @@ func removeShard(tabletAlias *topodatapb.TabletAlias, sessions []*vtgatepb.Sessi } return append(sessions[:idx], sessions[idx+1:]...), nil } + +//GetOrCreateOptions will return the current options struct, or create one and return it if no-one exists +func (session *SafeSession) GetOrCreateOptions() *querypb.ExecuteOptions { + if session.Session.Options == nil { + session.Session.Options = &querypb.ExecuteOptions{} + } + return session.Session.Options +} diff --git a/go/vt/vtgate/vcursor_impl.go b/go/vt/vtgate/vcursor_impl.go index b96ef944ec2..76eb2ec9275 100644 --- a/go/vt/vtgate/vcursor_impl.go +++ b/go/vt/vtgate/vcursor_impl.go @@ -61,6 +61,7 @@ type iExecute interface { ExecuteMultiShard(ctx context.Context, rss []*srvtopo.ResolvedShard, queries []*querypb.BoundQuery, session *SafeSession, autocommit bool, ignoreMaxMemoryRows bool) (qr *sqltypes.Result, errs []error) StreamExecuteMulti(ctx context.Context, s string, rss []*srvtopo.ResolvedShard, vars []map[string]*querypb.BindVariable, options *querypb.ExecuteOptions, callback func(reply *sqltypes.Result) error) error ExecuteLock(ctx context.Context, rs *srvtopo.ResolvedShard, query *querypb.BoundQuery, session *SafeSession) (*sqltypes.Result, error) + Commit(ctx context.Context, safeSession *SafeSession) error // TODO: remove when resolver is gone ParseDestinationTarget(targetString string) (string, topodatapb.TabletType, key.Destination, error) @@ -473,6 +474,42 @@ func (vc *vcursorImpl) TargetDestination(qualifier string) (key.Destination, *vi return vc.destination, keyspace.Keyspace, vc.tabletType, nil } +//SetAutocommit implementes the SessionActions interface +func (vc *vcursorImpl) SetAutocommit(autocommit bool) error { + if autocommit && vc.safeSession.InTransaction() { + if err := vc.executor.Commit(vc.ctx, vc.safeSession); err != nil { + return err + } + } + vc.safeSession.Autocommit = autocommit + return nil +} + +//SetClientFoundRows implementes the SessionActions interface +func (vc *vcursorImpl) SetClientFoundRows(clientFoundRows bool) { + vc.safeSession.GetOrCreateOptions().ClientFoundRows = clientFoundRows +} + +//SetSkipQueryPlanCache implementes the SessionActions interface +func (vc *vcursorImpl) SetSkipQueryPlanCache(skipQueryPlanCache bool) { + vc.safeSession.GetOrCreateOptions().SkipQueryPlanCache = skipQueryPlanCache +} + +//SetSkipQueryPlanCache implementes the SessionActions interface +func (vc *vcursorImpl) SetSQLSelectLimit(limit int64) { + vc.safeSession.GetOrCreateOptions().SqlSelectLimit = limit +} + +//SetSkipQueryPlanCache implementes the SessionActions interface +func (vc *vcursorImpl) SetTransactionMode(mode vtgatepb.TransactionMode) { + vc.safeSession.TransactionMode = mode +} + +//SetWorkload implementes the SessionActions interface +func (vc *vcursorImpl) SetWorkload(workload querypb.ExecuteOptions_Workload) { + vc.safeSession.GetOrCreateOptions().Workload = workload +} + func (vc *vcursorImpl) SysVarSetEnabled() bool { return *sysVarSetEnabled }