Skip to content
This repository has been archived by the owner on Jan 28, 2021. It is now read-only.

sql/(parse,expression): implement unary minus #456

Merged
merged 2 commits into from
Oct 19, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,10 @@ var queries = []struct {
"CREATE DATABASE `mydb` /*!40100 DEFAULT CHARACTER SET utf8mb4 COLLATE utf8_bin */",
}},
},
{
`SELECT -1`,
[]sql.Row{{int64(-1)}},
},
}

func TestQueries(t *testing.T) {
Expand Down
79 changes: 79 additions & 0 deletions sql/expression/arithmetic.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package expression

import (
"fmt"
"reflect"

errors "gopkg.in/src-d/go-errors.v1"
"gopkg.in/src-d/go-vitess.v1/vt/sqlparser"
Expand Down Expand Up @@ -407,3 +408,81 @@ func mod(lval, rval interface{}) (interface{}, error) {

return nil, errUnableToCast.New(lval, rval)
}

// UnaryMinus is an unary minus operator.
type UnaryMinus struct {
UnaryExpression
}

// NewUnaryMinus creates a new UnaryMinus expression node.
func NewUnaryMinus(child sql.Expression) *UnaryMinus {
return &UnaryMinus{UnaryExpression{Child: child}}
}

// Eval implements the sql.Expression interface.
func (e *UnaryMinus) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
child, err := e.Child.Eval(ctx, row)
if err != nil {
return nil, err
}

if child == nil {
return nil, nil
}

if !sql.IsNumber(e.Child.Type()) {
child, err = sql.Float64.Convert(child)
if err != nil {
child = 0.0
}
}

switch n := child.(type) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

don't we need a case for regular int?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think for:

var child interface{} = 10

it will go to the default.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there can't be ints in gitbase, either int32 or int64

case float64:
return -n, nil
case float32:
return -n, nil
case int64:
return -n, nil
case uint64:
return -int64(n), nil
case int32:
return -n, nil
case uint32:
return -int32(n), nil
default:
return nil, sql.ErrInvalidType.New(reflect.TypeOf(n))
}
}

// Type implements the sql.Expression interface.
func (e *UnaryMinus) Type() sql.Type {
typ := e.Child.Type()
if !sql.IsNumber(typ) {
return sql.Float64
}

if typ == sql.Uint32 {
return sql.Int32
}

if typ == sql.Uint64 {
return sql.Int64
}

return e.Child.Type()
}

func (e *UnaryMinus) String() string {
return fmt.Sprintf("-%s", e.Child)
}

// TransformUp implements the sql.Expression interface.
func (e *UnaryMinus) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) {
c, err := e.Child.TransformUp(f)
if err != nil {
return nil, err
}

return f(NewUnaryMinus(c))
}
28 changes: 28 additions & 0 deletions sql/expression/arithmetic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -360,3 +360,31 @@ func TestAllInt64(t *testing.T) {
})
}
}

func TestUnaryMinus(t *testing.T) {
testCases := []struct {
name string
input interface{}
typ sql.Type
expected interface{}
}{
{"int32", int32(1), sql.Int32, int32(-1)},
{"uint32", uint32(1), sql.Uint32, int32(-1)},
{"int64", int64(1), sql.Int64, int64(-1)},
{"uint64", uint64(1), sql.Uint64, int64(-1)},
{"float32", float32(1), sql.Float32, float32(-1)},
{"float64", float64(1), sql.Float64, float64(-1)},
{"int text", "1", sql.Text, float64(-1)},
{"float text", "1.2", sql.Text, float64(-1.2)},
{"nil", nil, sql.Text, nil},
}

for _, tt := range testCases {
t.Run(tt.name, func(t *testing.T) {
f := NewUnaryMinus(NewLiteral(tt.input, tt.typ))
result, err := f.Eval(sql.NewEmptyContext(), nil)
require.NoError(t, err)
require.Equal(t, tt.expected, result)
})
}
}
17 changes: 17 additions & 0 deletions sql/parse/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -748,6 +748,8 @@ func exprToExpression(e sqlparser.Expr) (sql.Expression, error) {

case *sqlparser.BinaryExpr:
return binaryExprToExpression(v)
case *sqlparser.UnaryExpr:
return unaryExprToExpression(v)
}
}

Expand Down Expand Up @@ -893,6 +895,21 @@ func selectExprToExpression(se sqlparser.SelectExpr) (sql.Expression, error) {
}
}

func unaryExprToExpression(e *sqlparser.UnaryExpr) (sql.Expression, error) {
switch e.Operator {
case sqlparser.MinusStr:
expr, err := exprToExpression(e.Expr)
if err != nil {
return nil, err
}

return expression.NewUnaryMinus(expr), nil

default:
return nil, ErrUnsupportedFeature.New("unary operator: " + e.Operator)
}
}

func binaryExprToExpression(be *sqlparser.BinaryExpr) (sql.Expression, error) {
switch be.Operator {
case
Expand Down
8 changes: 8 additions & 0 deletions sql/parse/parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -829,6 +829,14 @@ var fixtures = map[string]sql.Node{
`SHOW CREATE SCHEMA foo`: plan.NewShowCreateDatabase(sql.UnresolvedDatabase("foo"), false),
`SHOW CREATE DATABASE IF NOT EXISTS foo`: plan.NewShowCreateDatabase(sql.UnresolvedDatabase("foo"), true),
`SHOW CREATE SCHEMA IF NOT EXISTS foo`: plan.NewShowCreateDatabase(sql.UnresolvedDatabase("foo"), true),
`SELECT -i FROM mytable`: plan.NewProject(
[]sql.Expression{
expression.NewUnaryMinus(
expression.NewUnresolvedColumn("i"),
),
},
plan.NewUnresolvedTable("mytable", ""),
),
}

func TestParse(t *testing.T) {
Expand Down