Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support CTE clause #1207

Merged
merged 5 commits into from
Apr 15, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 123 additions & 2 deletions ast/dml.go
Original file line number Diff line number Diff line change
Expand Up @@ -991,7 +991,22 @@ func (s *SelectStmtKind) String() string {
return ""
}

// SelectStmt represents a select/table/values query node.
type CommonTableExpression struct {
node

Name model.CIStr
Query *SubqueryExpr
ColNameList []model.CIStr
}

type WithClause struct {
node

IsRecursive bool
CTEs []*CommonTableExpression
}

// SelectStmt represents the select query node.
// See https://dev.mysql.com/doc/refman/5.7/en/select.html
type SelectStmt struct {
dmlNode
Expand Down Expand Up @@ -1032,6 +1047,53 @@ type SelectStmt struct {
Kind SelectStmtKind
// Lists is filled only when Kind == SelectStmtKindValues
Lists []*RowExpr
With *WithClause
}

func (n *WithClause) Restore(ctx *format.RestoreCtx) error {
ctx.WriteKeyWord("WITH ")
if n.IsRecursive {
ctx.WriteKeyWord("RECURSIVE ")
}
for i, cte := range n.CTEs {
if i != 0 {
ctx.WritePlain(", ")
}
ctx.WriteName(cte.Name.String())
if len(cte.ColNameList) > 0 {
ctx.WritePlain(" (")
for j, name := range cte.ColNameList {
if j != 0 {
ctx.WritePlain(", ")
}
ctx.WriteName(name.String())
}
ctx.WritePlain(")")
}
ctx.WriteKeyWord(" AS ")
err := cte.Query.Restore(ctx)
if err != nil {
return err
}
}
ctx.WritePlain(" ")
return nil
}

func (n *WithClause) Accept(v Visitor) (Node, bool) {
newNode, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNode)
}

for _, cte := range n.CTEs {
node, ok := cte.Query.Accept(v)
if !ok {
return n, false
}
cte.Query = node.(*SubqueryExpr)
}
return v.Leave(n)
}

// Restore implements Node interface.
Expand All @@ -1042,6 +1104,13 @@ func (n *SelectStmt) Restore(ctx *format.RestoreCtx) error {
ctx.WritePlain(")")
}()
}
if n.With != nil {
err := n.With.Restore(ctx)
if err != nil {
return err
}
}

ctx.WriteKeyWord(n.Kind.String())
ctx.WritePlain(" ")
switch n.Kind {
Expand Down Expand Up @@ -1204,6 +1273,15 @@ func (n *SelectStmt) Accept(v Visitor) (Node, bool) {
}

n = newNode.(*SelectStmt)

if n.With != nil {
node, ok := n.With.Accept(v)
if !ok {
return n, false
}
n.With = node.(*WithClause)
}

if n.TableHints != nil && len(n.TableHints) != 0 {
newHints := make([]*TableOptimizerHint, len(n.TableHints))
for i, hint := range n.TableHints {
Expand Down Expand Up @@ -1381,10 +1459,17 @@ type SetOprStmt struct {
SelectList *SetOprSelectList
OrderBy *OrderByClause
Limit *Limit
With *WithClause
}

// Restore implements Node interface.
func (n *SetOprStmt) Restore(ctx *format.RestoreCtx) error {
if n.With != nil {
if err := n.With.Restore(ctx); err != nil {
return errors.Annotate(err, "An error occurred while restore UnionStmt.With")
}
}

if err := n.SelectList.Restore(ctx); err != nil {
return errors.Annotate(err, "An error occurred while restore SetOprStmt.SelectList")
}
Expand All @@ -1411,7 +1496,13 @@ func (n *SetOprStmt) Accept(v Visitor) (Node, bool) {
if skipChildren {
return v.Leave(newNode)
}
n = newNode.(*SetOprStmt)
if n.With != nil {
node, ok := n.With.Accept(v)
if !ok {
return n, false
}
n.With = node.(*WithClause)
}
if n.SelectList != nil {
node, ok := n.SelectList.Accept(v)
if !ok {
Expand Down Expand Up @@ -1943,10 +2034,18 @@ type DeleteStmt struct {
BeforeFrom bool
// TableHints represents the table level Optimizer Hint for join type.
TableHints []*TableOptimizerHint
With *WithClause
}

// Restore implements Node interface.
func (n *DeleteStmt) Restore(ctx *format.RestoreCtx) error {
if n.With != nil {
err := n.With.Restore(ctx)
if err != nil {
return err
}
}

ctx.WriteKeyWord("DELETE ")

if n.TableHints != nil && len(n.TableHints) != 0 {
Expand Down Expand Up @@ -2036,6 +2135,13 @@ func (n *DeleteStmt) Accept(v Visitor) (Node, bool) {
}

n = newNode.(*DeleteStmt)
if n.With != nil {
node, ok := n.With.Accept(v)
if !ok {
return n, false
}
n.With = node.(*WithClause)
}
node, ok := n.TableRefs.Accept(v)
if !ok {
return n, false
Expand Down Expand Up @@ -2088,10 +2194,18 @@ type UpdateStmt struct {
IgnoreErr bool
MultipleTable bool
TableHints []*TableOptimizerHint
With *WithClause
}

// Restore implements Node interface.
func (n *UpdateStmt) Restore(ctx *format.RestoreCtx) error {
if n.With != nil {
err := n.With.Restore(ctx)
if err != nil {
return err
}
}

ctx.WriteKeyWord("UPDATE ")

if n.TableHints != nil && len(n.TableHints) != 0 {
Expand Down Expand Up @@ -2169,6 +2283,13 @@ func (n *UpdateStmt) Accept(v Visitor) (Node, bool) {
return v.Leave(newNode)
}
n = newNode.(*UpdateStmt)
if n.With != nil {
node, ok := n.With.Accept(v)
if !ok {
return n, false
}
n.With = node.(*WithClause)
}
node, ok := n.TableRefs.Accept(v)
if !ok {
return n, false
Expand Down
1 change: 1 addition & 0 deletions misc.go
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,7 @@ var tokenMap = map[string]int{
"REBUILD": rebuild,
"RECENT": recent,
"RECOVER": recover,
"RECURSIVE": recursive,
"REDUNDANT": redundant,
"REFERENCES": references,
"REGEXP": regexpKwd,
Expand Down
Loading