From 825bb7c6a326bcc143a7f4d986800467097c56b0 Mon Sep 17 00:00:00 2001 From: siddontang Date: Fri, 11 Sep 2015 11:48:51 +0800 Subject: [PATCH] expressions: update subquery --- expression/expressions/expression_test.go | 6 ++ expression/expressions/subquery.go | 70 +++++++++++++++++++---- expression/expressions/subquery_test.go | 44 ++++++++++++-- parser/parser.y | 2 +- 4 files changed, 104 insertions(+), 18 deletions(-) diff --git a/expression/expressions/expression_test.go b/expression/expressions/expression_test.go index 7350f75d8aa4d..d9d2d2f08ff1e 100644 --- a/expression/expressions/expression_test.go +++ b/expression/expressions/expression_test.go @@ -97,6 +97,11 @@ func newMockRecordset() *mockRecordset { } func (r *mockRecordset) Do(f func(data []interface{}) (more bool, err error)) error { + for i := range r.rows { + if more, err := f(r.rows[i]); !more || err != nil { + return err + } + } return nil } @@ -181,6 +186,7 @@ func newMockPlan(rset *mockRecordset) *mockPlan { func (p *mockPlan) Do(ctx context.Context, f plan.RowIterFunc) error { for _, data := range p.rset.rows { + if more, err := f(nil, data[:p.rset.offset]); !more || err != nil { return err } diff --git a/expression/expressions/subquery.go b/expression/expressions/subquery.go index c53f35e9c40f8..4145439109ca6 100644 --- a/expression/expressions/subquery.go +++ b/expression/expressions/subquery.go @@ -17,46 +17,72 @@ import ( "fmt" "strings" + "github.com/juju/errors" "github.com/pingcap/tidb/context" "github.com/pingcap/tidb/expression" + "github.com/pingcap/tidb/plan" "github.com/pingcap/tidb/stmt" ) +// SubQueryStatement implements stmt.Statement and plan.Planner interface +type SubQueryStatement interface { + stmt.Statement + plan.Planner +} + var _ expression.Expression = (*SubQuery)(nil) // SubQuery expresion holds a select statement. // TODO: complete according to https://dev.mysql.com/doc/refman/5.7/en/subquery-restrictions.html type SubQuery struct { // Stmt is the sub select statement. - Stmt stmt.Statement + Stmt SubQueryStatement // Value holds the sub select result. Value interface{} + + p plan.Plan } // Clone implements the Expression Clone interface. func (sq *SubQuery) Clone() (expression.Expression, error) { - // TODO: Statement does not have Clone interface. So we need to check this - nsq := &SubQuery{Stmt: sq.Stmt, Value: sq.Value} + nsq := &SubQuery{Stmt: sq.Stmt, Value: sq.Value, p: sq.p} return nsq, nil } // Eval implements the Expression Eval interface. +// Eval doesn't support multi rows return, so we can only get a scalar or a row result. +// If you want to get multi rows, use Plan to get a execution plan and run Do() directly. func (sq *SubQuery) Eval(ctx context.Context, args map[interface{}]interface{}) (v interface{}, err error) { if sq.Value != nil { return sq.Value, nil } - rs, err := sq.Stmt.Exec(ctx) + + p, err := sq.Plan(ctx) if err != nil { - return nil, err + return nil, errors.Trace(err) } - // TODO: check row/column number - // Output all the data and let the outer caller check the row/column number - // This simple implementation is used to pass tpc-c - rows, err := rs.Rows(1, 0) - if err != nil || len(rows) == 0 || len(rows[0]) == 0 { - return nil, err + + count := 0 + err = p.Do(ctx, func(id interface{}, data []interface{}) (bool, error) { + if count > 0 { + return false, errors.Errorf("Subquery returns more than 1 row") + } + + if len(p.GetFields()) == 1 { + // a scalar value is a single value + sq.Value = data[0] + } else { + // a row value is []interface{} + sq.Value = data + } + count++ + return true, nil + }) + + if err != nil { + return nil, errors.Trace(err) } - sq.Value = rows[0][0] + return sq.Value, nil } @@ -73,3 +99,23 @@ func (sq *SubQuery) String() string { } return "" } + +// ColumnCount returns column count for the sub query. +func (sq *SubQuery) ColumnCount(ctx context.Context) (int, error) { + p, err := sq.Plan(ctx) + if err != nil { + return 0, errors.Trace(err) + } + return len(p.GetFields()), nil +} + +// Plan implements plan.Planner interface. +func (sq *SubQuery) Plan(ctx context.Context) (plan.Plan, error) { + if sq.p != nil { + return sq.p, nil + } + + var err error + sq.p, err = sq.Stmt.Plan(ctx) + return sq.p, errors.Trace(err) +} diff --git a/expression/expressions/subquery_test.go b/expression/expressions/subquery_test.go index d79871af55861..33d46a0f17ee0 100644 --- a/expression/expressions/subquery_test.go +++ b/expression/expressions/subquery_test.go @@ -13,9 +13,7 @@ package expressions -import ( - . "github.com/pingcap/check" -) +import . "github.com/pingcap/check" var _ = Suite(&testSubQuerySuite{}) @@ -42,13 +40,49 @@ func (s *testSubQuerySuite) TestSubQuery(c *C) { e2, ok := ec.(*SubQuery) c.Assert(ok, IsTrue) - e2.Value = nil - e2.Stmt = newMockStatement() + e2 = newMockSubQuery([][]interface{}{{1}}, []string{"id"}) vv, err := e2.Eval(nil, nil) c.Assert(err, IsNil) c.Assert(vv, Equals, 1) + e2.Value = nil + vv, err = e2.Eval(nil, nil) + c.Assert(err, IsNil) + c.Assert(vv, Equals, 1) + + e2 = newMockSubQuery([][]interface{}{{1, 2}}, []string{"id", "name"}) + + vv, err = e2.Eval(nil, nil) + c.Assert(err, IsNil) + c.Assert(vv, DeepEquals, []interface{}{1, 2}) + + e2 = newMockSubQuery([][]interface{}{{1}, {2}}, []string{"id"}) + + _, err = e2.Eval(nil, nil) + c.Assert(err, NotNil) + str = e2.String() c.Assert(len(str), Greater, 0) + + e2 = newMockSubQuery([][]interface{}{{1, 2}}, []string{"id", "name"}) + + count, err := e2.ColumnCount(nil) + c.Assert(err, IsNil) + c.Assert(count, Equals, 2) + + count, err = e2.ColumnCount(nil) + c.Assert(err, IsNil) + c.Assert(count, Equals, 2) +} + +func newMockSubQuery(rows [][]interface{}, fields []string) *SubQuery { + r := &mockRecordset{ + rows: rows, + fields: fields, + offset: len(fields), + } + ms := &mockStatement{rset: r} + ms.plan = newMockPlan(ms.rset) + return &SubQuery{Stmt: ms} } diff --git a/parser/parser.y b/parser/parser.y index 6d36efb353b1d..58b61abf48e91 100644 --- a/parser/parser.y +++ b/parser/parser.y @@ -2461,7 +2461,7 @@ SelectStmtOrder: SubSelect: '(' SelectStmt ')' { - s := $2.(stmt.Statement) + s := $2.(*stmts.SelectStmt) s.SetText(yylex.(*lexer).src[yyS[yypt - 1].col-1:yyS[yypt].col-1]) $$ = &expressions.SubQuery{Stmt: s} }