Skip to content

Commit

Permalink
sql/(plan,expression): implement String method to make tree printable
Browse files Browse the repository at this point in the history
Closes src-d#95
Closes src-d#108

- Implements String method on all sql.Nodes.
- Implements String method on all sql.Expressions.
- Remove Name method from sql.Expressions.
- Only leave Name in expressions that will be Nameable.
- Created TreePrinter to print trees with a Node and Children.

Signed-off-by: Miguel Molina <miguel@erizocosmi.co>
  • Loading branch information
erizocosmico committed Mar 15, 2018
1 parent 73d2f54 commit 13d90bc
Show file tree
Hide file tree
Showing 41 changed files with 540 additions and 139 deletions.
25 changes: 25 additions & 0 deletions engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"gopkg.in/src-d/go-mysql-server.v0"
"gopkg.in/src-d/go-mysql-server.v0/mem"
"gopkg.in/src-d/go-mysql-server.v0/sql"
"gopkg.in/src-d/go-mysql-server.v0/sql/parse"

"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -293,3 +294,27 @@ func newEngine(t *testing.T) *sqle.Engine {

return e
}

const expectedTree = `Offset(2)
└─ Limit(5)
└─ Project(t.foo, bar.baz)
└─ Filter(foo > qux)
└─ InnerJoin(foo = baz)
├─ TableAlias(t)
│ └─ UnresolvedTable(tbl)
└─ UnresolvedTable(bar)
`

func TestPrintTree(t *testing.T) {
require := require.New(t)
node, err := parse.Parse(nil, `
SELECT t.foo, bar.baz
FROM tbl t
INNER JOIN bar
ON foo = baz
WHERE foo > qux
LIMIT 5
OFFSET 2`)
require.NoError(err)
require.Equal(expectedTree, node.String())
}
4 changes: 4 additions & 0 deletions mem/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,7 @@ func (t *Table) Insert(row sql.Row) error {
t.data = append(t.data, row.Copy())
return nil
}

func (t Table) String() string {
return fmt.Sprintf("Table(%s)", t.name)
}
6 changes: 3 additions & 3 deletions sql/analyzer/validation_rules.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func validateGroupBy(n sql.Node) error {

var validAggs []string
for _, expr := range n.Grouping {
validAggs = append(validAggs, expr.Name())
validAggs = append(validAggs, expr.String())
}

// TODO: validate columns inside aggregations
Expand All @@ -78,7 +78,7 @@ func validateGroupBy(n sql.Node) error {
for _, expr := range n.Aggregate {
if _, ok := expr.(sql.Aggregation); !ok {
if !isValidAgg(validAggs, expr) {
return ErrValidationGroupBy.New(expr.Name())
return ErrValidationGroupBy.New(expr.String())
}
}
}
Expand All @@ -96,7 +96,7 @@ func isValidAgg(validAggs []string, expr sql.Expression) bool {
case *expression.Alias:
return isValidAgg(validAggs, expr.Child)
default:
return stringContains(validAggs, expr.Name())
return stringContains(validAggs, expr.String())
}
}

Expand Down
1 change: 1 addition & 0 deletions sql/analyzer/validation_rules_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ func TestValidateSchemaSource(t *testing.T) {

type dummyNode struct{ resolved bool }

func (n dummyNode) String() string { return "dummynode" }
func (n dummyNode) Resolved() bool { return n.resolved }
func (dummyNode) Schema() sql.Schema { return sql.Schema{} }
func (dummyNode) Children() []sql.Node { return nil }
Expand Down
6 changes: 4 additions & 2 deletions sql/core.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package sql

import (
"fmt"

"gopkg.in/src-d/go-errors.v1"
)

Expand Down Expand Up @@ -30,10 +32,9 @@ type Transformable interface {
// Expression is a combination of one or more SQL expressions.
type Expression interface {
Resolvable
fmt.Stringer
// Type returns the expression type.
Type() Type
// Name returns the expression name.
Name() string
// IsNullable returns whether the expression can be null.
IsNullable() bool
// Eval evaluates the given row and returns a result.
Expand Down Expand Up @@ -63,6 +64,7 @@ type Aggregation interface {
type Node interface {
Resolvable
Transformable
fmt.Stringer
// Schema of the node.
Schema() Schema
// Children nodes.
Expand Down
20 changes: 8 additions & 12 deletions sql/expression/aggregation.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,8 @@ func (c *Count) Resolved() bool {
return c.Child.Resolved()
}

// Name returns the name of the node.
func (c *Count) Name() string {
return fmt.Sprintf("count(%s)", c.Child.Name())
func (c Count) String() string {
return fmt.Sprintf("COUNT(%s)", c.Child)
}

// TransformUp implements the Expression interface.
Expand Down Expand Up @@ -110,9 +109,8 @@ func (m *Min) Type() sql.Type {
return m.Child.Type()
}

// Name returns the name of the node.
func (m *Min) Name() string {
return fmt.Sprintf("min(%s)", m.Child.Name())
func (m Min) String() string {
return fmt.Sprintf("MIN(%s)", m.Child)
}

// IsNullable returns whether the return value can be null.
Expand Down Expand Up @@ -187,9 +185,8 @@ func (m *Max) Type() sql.Type {
return m.Child.Type()
}

// Name returns the name of the node.
func (m *Max) Name() string {
return fmt.Sprintf("max(%s)", m.Child.Name())
func (m Max) String() string {
return fmt.Sprintf("MAX(%s)", m.Child)
}

// IsNullable returns whether the return value can be null.
Expand Down Expand Up @@ -253,9 +250,8 @@ func NewAvg(e sql.Expression) *Avg {
return &Avg{UnaryExpression{e}}
}

// Name implements Nameable interface.
func (a *Avg) Name() string {
return fmt.Sprintf("avg(%s)", a.Child.Name())
func (a Avg) String() string {
return fmt.Sprintf("AVG(%s)", a.Child)
}

// Resolved implements AggregationExpression interface. (AggregationExpression[Expression[Resolvable]]])
Expand Down
17 changes: 8 additions & 9 deletions sql/expression/aggregation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ import (
"gopkg.in/src-d/go-mysql-server.v0/sql"
)

func TestCount_Name(t *testing.T) {
func TestCount_String(t *testing.T) {
require := require.New(t)

c := NewCount(NewLiteral("foo", sql.Text))
require.Equal("count(literal_TEXT)", c.Name())
require.Equal(`COUNT("foo")`, c.String())
}

func TestCount_Eval_1(t *testing.T) {
Expand Down Expand Up @@ -79,7 +79,7 @@ func TestMin_Name(t *testing.T) {
assert := require.New(t)

m := NewMin(NewGetField(0, sql.Int32, "field", true))
assert.Equal("min(field)", m.Name())
assert.Equal("MIN(field)", m.String())
}

func TestMin_Eval_Int32(t *testing.T) {
Expand Down Expand Up @@ -161,11 +161,10 @@ func TestMin_Eval_Empty(t *testing.T) {
assert.NoError(err)
assert.Equal(nil, v)
}
func TestMax_Name(t *testing.T) {
func TestMax_String(t *testing.T) {
assert := require.New(t)

m := NewMax(NewGetField(0, sql.Int32, "field", true))
assert.Equal("max(field)", m.Name())
assert.Equal("MAX(field)", m.String())
}

func TestMax_Eval_Int32(t *testing.T) {
Expand Down Expand Up @@ -247,11 +246,11 @@ func TestMax_Eval_Empty(t *testing.T) {
assert.Equal(nil, v)
}

func TestAvg_Name(t *testing.T) {
func TestAvg_String(t *testing.T) {
require := require.New(t)

avgNode := NewAvg(NewGetField(0, sql.Int32, "col1", true))
require.Equal("avg(col1)", avgNode.Name())
avg := NewAvg(NewGetField(0, sql.Int32, "col1", true))
require.Equal("AVG(col1)", avg.String())
}

func TestAvg_Eval_INT32(t *testing.T) {
Expand Down
14 changes: 10 additions & 4 deletions sql/expression/alias.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package expression

import "gopkg.in/src-d/go-mysql-server.v0/sql"
import (
"fmt"

"gopkg.in/src-d/go-mysql-server.v0/sql"
)

// Alias is a node that gives a name to an expression.
type Alias struct {
Expand All @@ -23,9 +27,8 @@ func (e *Alias) Eval(session sql.Session, row sql.Row) (interface{}, error) {
return e.Child.Eval(session, row)
}

// Name implements the Expression interface.
func (e *Alias) Name() string {
return e.name
func (e Alias) String() string {
return fmt.Sprintf("%s as %s", e.Child, e.name)
}

// TransformUp implements the Expression interface.
Expand All @@ -36,3 +39,6 @@ func (e *Alias) TransformUp(f func(sql.Expression) (sql.Expression, error)) (sql
}
return f(NewAlias(child, e.name))
}

// Name implements the Nameable interface.
func (e *Alias) Name() string { return e.name }
7 changes: 5 additions & 2 deletions sql/expression/between.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package expression

import (
"fmt"

"gopkg.in/src-d/go-mysql-server.v0/sql"
)

Expand All @@ -16,8 +18,9 @@ func NewBetween(val, lower, upper sql.Expression) *Between {
return &Between{val, lower, upper}
}

// Name implements the Expression interface.
func (Between) Name() string { return "between" }
func (b Between) String() string {
return fmt.Sprintf("BETWEEN(%s, %s, %s)", b.Val, b.Lower, b.Upper)
}

// Type implements the Expression interface.
func (Between) Type() sql.Type { return sql.Boolean }
Expand Down
7 changes: 4 additions & 3 deletions sql/expression/boolean.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package expression

import (
"fmt"

"gopkg.in/src-d/go-mysql-server.v0/sql"
)

Expand Down Expand Up @@ -33,9 +35,8 @@ func (e Not) Eval(session sql.Session, row sql.Row) (interface{}, error) {
return !v.(bool), nil
}

// Name implements the Expression interface.
func (e Not) Name() string {
return "Not(" + e.Child.Name() + ")"
func (e Not) String() string {
return fmt.Sprintf("NOT(%s)", e.Child)
}

// TransformUp implements the Expression interface.
Expand Down
32 changes: 21 additions & 11 deletions sql/expression/comparison.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package expression

import (
"fmt"
"regexp"

"gopkg.in/src-d/go-mysql-server.v0/sql"
Expand Down Expand Up @@ -28,11 +29,6 @@ func (*Comparison) Type() sql.Type {
return sql.Boolean
}

// Name implements the Expression interface.
func (*Comparison) Name() string {
return ""
}

// Equals is a comparison that checks an expression is equal to another.
type Equals struct {
Comparison
Expand Down Expand Up @@ -77,9 +73,8 @@ func (e *Equals) TransformUp(f func(sql.Expression) (sql.Expression, error)) (sq
return f(NewEquals(left, right))
}

// Name implements the Expression interface.
func (e Equals) Name() string {
return e.Left.Name() + "==" + e.Right.Name()
func (e Equals) String() string {
return fmt.Sprintf("%s = %s", e.Left, e.Right)
}

// Regexp is a comparison that checks an expression matches a regexp.
Expand Down Expand Up @@ -137,9 +132,8 @@ func (re *Regexp) TransformUp(f func(sql.Expression) (sql.Expression, error)) (s
return f(NewRegexp(left, right))
}

// Name implements the Expression interface.
func (re Regexp) Name() string {
return re.Left.Name() + " REGEXP " + re.Right.Name()
func (re Regexp) String() string {
return fmt.Sprintf("%s REGEXP %s", re.Left, re.Right)
}

// GreaterThan is a comparison that checks an expression is greater than another.
Expand Down Expand Up @@ -189,6 +183,10 @@ func (gt *GreaterThan) TransformUp(f func(sql.Expression) (sql.Expression, error
return f(NewGreaterThan(left, right))
}

func (gt GreaterThan) String() string {
return fmt.Sprintf("%s > %s", gt.Left, gt.Right)
}

// LessThan is a comparison that checks an expression is less than another.
type LessThan struct {
Comparison
Expand Down Expand Up @@ -233,6 +231,10 @@ func (lt *LessThan) TransformUp(f func(sql.Expression) (sql.Expression, error))
return f(NewLessThan(left, right))
}

func (lt LessThan) String() string {
return fmt.Sprintf("%s < %s", lt.Left, lt.Right)
}

// GreaterThanOrEqual is a comparison that checks an expression is greater or equal to
// another.
type GreaterThanOrEqual struct {
Expand Down Expand Up @@ -281,6 +283,10 @@ func (gte *GreaterThanOrEqual) TransformUp(f func(sql.Expression) (sql.Expressio
return f(NewGreaterThanOrEqual(left, right))
}

func (gte GreaterThanOrEqual) String() string {
return fmt.Sprintf("%s >= %s", gte.Left, gte.Right)
}

// LessThanOrEqual is a comparison that checks an expression is equal or lower than
// another.
type LessThanOrEqual struct {
Expand Down Expand Up @@ -328,3 +334,7 @@ func (lte *LessThanOrEqual) TransformUp(f func(sql.Expression) (sql.Expression,

return f(NewLessThanOrEqual(left, right))
}

func (lte LessThanOrEqual) String() string {
return fmt.Sprintf("%s <= %s", lte.Left, lte.Right)
}
Loading

0 comments on commit 13d90bc

Please sign in to comment.