Skip to content

Commit

Permalink
Merge pull request #6964 from planetscale/set-udv-allow-more
Browse files Browse the repository at this point in the history
Set udv allow more expressions
  • Loading branch information
harshit-gangal authored Nov 2, 2020
2 parents b9ee82f + 589f9b2 commit b2433cc
Show file tree
Hide file tree
Showing 10 changed files with 60 additions and 230 deletions.
6 changes: 6 additions & 0 deletions go/sqltypes/value.go
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,12 @@ func (v Value) IsBinary() bool {
return IsBinary(v.typ)
}

// IsDateTime returns true if Value is datetime.
func (v Value) IsDateTime() bool {
dt := int(querypb.Type_DATETIME)
return int(v.typ)&dt == dt
}

// MarshalJSON should only be used for testing.
// It's not a complete implementation.
func (v Value) MarshalJSON() ([]byte, error) {
Expand Down
6 changes: 6 additions & 0 deletions go/test/endtoend/vtgate/reservedconn/udv_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,12 @@ func TestSetUDV(t *testing.T) {
}, {
query: "select id, val1 from test where val1 = @tablet",
expectedRows: `[[INT64(42) VARCHAR("foobar")]]`, rowsAffected: 1,
}, {
query: "set @foo = now(), @bar = now(), @dd = date('2020-10-20'), @tt = time('10:15')",
expectedRows: `[]`, rowsAffected: 0,
}, {
query: "select @foo = @bar, @dd, @tt",
expectedRows: `[[INT64(1) VARCHAR("2020-10-20") VARCHAR("10:15:00")]]`, rowsAffected: 1,
}}

conn, err := mysql.Connect(ctx, &vtParams)
Expand Down
36 changes: 35 additions & 1 deletion go/vt/vtgate/engine/set_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ func TestSetTable(t *testing.T) {
expectedQueryLog []string
expectedWarning []*querypb.QueryWarning
expectedError string
input Primitive
}

tests := []testCase{
Expand All @@ -91,6 +92,36 @@ func TestSetTable(t *testing.T) {
`UDV set with (x,INT64(42))`,
},
},
{
testName: "udv with input",
setOps: []SetOp{
&UserDefinedVariable{
Name: "x",
Expr: evalengine.NewColumn(0),
},
},
qr: []*sqltypes.Result{sqltypes.MakeTestResult(
sqltypes.MakeTestFields(
"col0",
"datetime",
),
"2020-10-28",
)},
expectedQueryLog: []string{
`ResolveDestinations ks [] Destinations:DestinationAnyShard()`,
`ExecuteMultiShard ks.-20: select now() from dual {} false false`,
`UDV set with (x,DATETIME("2020-10-28"))`,
},
input: &Send{
Keyspace: &vindexes.Keyspace{
Name: "ks",
Sharded: true,
},
TargetDestination: key.DestinationAnyShard{},
Query: "select now() from dual",
SingleShardOnly: true,
},
},
{
testName: "sysvar ignore",
setOps: []SetOp{
Expand Down Expand Up @@ -259,9 +290,12 @@ func TestSetTable(t *testing.T) {

for _, tc := range tests {
t.Run(tc.testName, func(t *testing.T) {
if tc.input == nil {
tc.input = &SingleRow{}
}
set := &Set{
Ops: tc.setOps,
Input: &SingleRow{},
Input: tc.input,
}
vc := &loggingVCursor{
shards: []string{"-20", "20-"},
Expand Down
17 changes: 9 additions & 8 deletions go/vt/vtgate/evalengine/arithmetic.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func Add(v1, v2 sqltypes.Value) (sqltypes.Value, error) {
return sqltypes.NULL, err
}

return castFromNumeric(lresult, lresult.typ), nil
return lresult.toSQLValue(lresult.typ), nil
}

// Subtract takes two values and subtracts them
Expand All @@ -79,7 +79,7 @@ func Subtract(v1, v2 sqltypes.Value) (sqltypes.Value, error) {
return sqltypes.NULL, err
}

return castFromNumeric(lresult, lresult.typ), nil
return lresult.toSQLValue(lresult.typ), nil
}

// Multiply takes two values and multiplies it together
Expand All @@ -101,7 +101,7 @@ func Multiply(v1, v2 sqltypes.Value) (sqltypes.Value, error) {
return sqltypes.NULL, err
}

return castFromNumeric(lresult, lresult.typ), nil
return lresult.toSQLValue(lresult.typ), nil
}

// Divide (Float) for MySQL. Replicates behavior of "/" operator
Expand Down Expand Up @@ -132,7 +132,7 @@ func Divide(v1, v2 sqltypes.Value) (sqltypes.Value, error) {
return sqltypes.NULL, err
}

return castFromNumeric(lresult, lresult.typ), nil
return lresult.toSQLValue(lresult.typ), nil
}

// NullsafeAdd adds two Values in a null-safe manner. A null value
Expand Down Expand Up @@ -164,7 +164,7 @@ func NullsafeAdd(v1, v2 sqltypes.Value, resultType querypb.Type) sqltypes.Value
}
lresult := addNumeric(lv1, lv2)

return castFromNumeric(lresult, resultType)
return lresult.toSQLValue(resultType)
}

// NullsafeCompare returns 0 if v1==v2, -1 if v1<v2, and 1 if v1>v2.
Expand Down Expand Up @@ -387,8 +387,9 @@ func newEvalResult(v sqltypes.Value) (EvalResult, error) {
return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%v", err)
}
return EvalResult{fval: fval, typ: sqltypes.Float64}, nil
default:
return EvalResult{typ: v.Type(), bytes: raw}, nil
}
return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "this should not be reached. got %s", v.String())
}

// newIntegralNumeric parses a value and produces an Int64 or Uint64.
Expand Down Expand Up @@ -719,7 +720,7 @@ func anyMinusFloat(v1 EvalResult, v2 float64) EvalResult {
return EvalResult{typ: sqltypes.Float64, fval: v1.fval - v2}
}

func castFromNumeric(v EvalResult, resultType querypb.Type) sqltypes.Value {
func (v EvalResult) toSQLValue(resultType querypb.Type) sqltypes.Value {
switch {
case sqltypes.IsSigned(resultType):
switch v.typ {
Expand Down Expand Up @@ -752,7 +753,7 @@ func castFromNumeric(v EvalResult, resultType querypb.Type) sqltypes.Value {
}
return sqltypes.MakeTrusted(resultType, strconv.AppendFloat(nil, v.fval, format, -1, 64))
}
case resultType == sqltypes.VarChar || resultType == sqltypes.VarBinary || resultType == sqltypes.Binary || resultType == sqltypes.Text:
default:
return sqltypes.MakeTrusted(resultType, v.bytes)
}
return sqltypes.NULL
Expand Down
6 changes: 3 additions & 3 deletions go/vt/vtgate/evalengine/arithmetic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1101,7 +1101,7 @@ func TestPrioritize(t *testing.T) {
}
}

func TestCastFromNumeric(t *testing.T) {
func TestToSqlValue(t *testing.T) {
tcases := []struct {
typ querypb.Type
v EvalResult
Expand Down Expand Up @@ -1158,10 +1158,10 @@ func TestCastFromNumeric(t *testing.T) {
out: sqltypes.TestValue(querypb.Type_DECIMAL, "0.00000000000000012"),
}}
for _, tcase := range tcases {
got := castFromNumeric(tcase.v, tcase.typ)
got := tcase.v.toSQLValue(tcase.typ)

if !reflect.DeepEqual(got, tcase.out) {
t.Errorf("castFromNumeric(%v, %v): %v, want %v", tcase.v, tcase.typ, printValue(got), printValue(tcase.out))
t.Errorf("toSQLValue(%v, %v): %v, want %v", tcase.v, tcase.typ, printValue(got), printValue(tcase.out))
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/evalengine/expressions.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ type (

//Value allows for retrieval of the value we expose for public consumption
func (e EvalResult) Value() sqltypes.Value {
return castFromNumeric(e, e.typ)
return e.toSQLValue(e.typ)
}

//NewLiteralIntFromBytes returns a literal expression
Expand Down
7 changes: 0 additions & 7 deletions go/vt/vtgate/planbuilder/expression_converter.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@ import (
"fmt"
"strings"

vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vterrors"
"vitess.io/vitess/go/vt/vtgate/engine"
"vitess.io/vitess/go/vt/vtgate/evalengine"
)
Expand Down Expand Up @@ -87,11 +85,6 @@ func (ec *expressionConverter) convert(astExpr sqlparser.Expr, boolean, identifi
if err != sqlparser.ErrExprNotSupported {
return nil, err
}
// We have an expression that we can't handle at the vtgate level
if !expressionOkToDelegateToTablet(astExpr) {
// Uh-oh - this expression is not even safe to delegate to the tablet. Give up.
return nil, vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "expression not supported for SET: %s", sqlparser.String(astExpr))
}
evalExpr = &evalengine.Column{Offset: len(ec.tabletExpressions)}
ec.tabletExpressions = append(ec.tabletExpressions, astExpr)
}
Expand Down
154 changes: 0 additions & 154 deletions go/vt/vtgate/planbuilder/set.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,26 +145,6 @@ func planSysVarCheckIgnore(expr *sqlparser.SetExpr, schema ContextVSchema, boole
}, nil
}

func expressionOkToDelegateToTablet(e sqlparser.Expr) bool {
valid := true
sqlparser.Rewrite(e, nil, func(cursor *sqlparser.Cursor) bool {
switch n := cursor.Node().(type) {
case *sqlparser.Subquery, *sqlparser.TimestampFuncExpr, *sqlparser.CurTimeFuncExpr:
valid = false
return false
case *sqlparser.FuncExpr:
_, ok := validFuncs[n.Name.Lowered()]
valid = ok
return ok
case *sqlparser.ColName:
valid = n.Name.AtCount() == 2
return false
}
return true
})
return valid
}

func buildSetOpReservedConn(s setting) planFunc {
return func(expr *sqlparser.SetExpr, vschema ContextVSchema, _ *expressionConverter) (engine.SetOp, error) {
if !vschema.SysVarSetEnabled() {
Expand Down Expand Up @@ -257,137 +237,3 @@ func extractValue(expr *sqlparser.SetExpr, boolean bool) (string, error) {

return sqlparser.String(expr.Expr), nil
}

// whitelist of functions knows to be safe to pass through to mysql for evaluation
// this list tries to not include functions that might return different results on different tablets
var validFuncs = map[string]interface{}{
"if": nil,
"ifnull": nil,
"nullif": nil,
"abs": nil,
"acos": nil,
"asin": nil,
"atan2": nil,
"atan": nil,
"ceil": nil,
"ceiling": nil,
"conv": nil,
"cos": nil,
"cot": nil,
"crc32": nil,
"degrees": nil,
"div": nil,
"exp": nil,
"floor": nil,
"ln": nil,
"log": nil,
"log10": nil,
"log2": nil,
"mod": nil,
"pi": nil,
"pow": nil,
"power": nil,
"radians": nil,
"rand": nil,
"round": nil,
"sign": nil,
"sin": nil,
"sqrt": nil,
"tan": nil,
"truncate": nil,
"adddate": nil,
"addtime": nil,
"convert_tz": nil,
"date": nil,
"date_add": nil,
"date_format": nil,
"date_sub": nil,
"datediff": nil,
"day": nil,
"dayname": nil,
"dayofmonth": nil,
"dayofweek": nil,
"dayofyear": nil,
"extract": nil,
"from_days": nil,
"from_unixtime": nil,
"get_format": nil,
"hour": nil,
"last_day": nil,
"makedate": nil,
"maketime": nil,
"microsecond": nil,
"minute": nil,
"month": nil,
"monthname": nil,
"period_add": nil,
"period_diff": nil,
"quarter": nil,
"sec_to_time": nil,
"second": nil,
"str_to_date": nil,
"subdate": nil,
"subtime": nil,
"time_format": nil,
"time_to_sec": nil,
"timediff": nil,
"timestampadd": nil,
"timestampdiff": nil,
"to_days": nil,
"to_seconds": nil,
"week": nil,
"weekday": nil,
"weekofyear": nil,
"year": nil,
"yearweek": nil,
"ascii": nil,
"bin": nil,
"bit_length": nil,
"char": nil,
"char_length": nil,
"character_length": nil,
"concat": nil,
"concat_ws": nil,
"elt": nil,
"export_set": nil,
"field": nil,
"find_in_set": nil,
"format": nil,
"from_base64": nil,
"hex": nil,
"insert": nil,
"instr": nil,
"lcase": nil,
"left": nil,
"length": nil,
"load_file": nil,
"locate": nil,
"lower": nil,
"lpad": nil,
"ltrim": nil,
"make_set": nil,
"mid": nil,
"oct": nil,
"octet_length": nil,
"ord": nil,
"position": nil,
"quote": nil,
"repeat": nil,
"replace": nil,
"reverse": nil,
"right": nil,
"rpad": nil,
"rtrim": nil,
"soundex": nil,
"space": nil,
"strcmp": nil,
"substr": nil,
"substring": nil,
"substring_index": nil,
"to_base64": nil,
"trim": nil,
"ucase": nil,
"unhex": nil,
"upper": nil,
"weight_string": nil,
}
Loading

0 comments on commit b2433cc

Please sign in to comment.