diff --git a/internal/dinosql/checks.go b/internal/dinosql/checks.go index 836f25ad94..610acfac7f 100644 --- a/internal/dinosql/checks.go +++ b/internal/dinosql/checks.go @@ -4,16 +4,18 @@ import ( "fmt" "strings" + nodes "github.com/lfittl/pg_query_go/nodes" + "github.com/kyleconroy/sqlc/internal/catalog" "github.com/kyleconroy/sqlc/internal/pg" - nodes "github.com/lfittl/pg_query_go/nodes" + "github.com/kyleconroy/sqlc/internal/postgresql/ast" ) func validateParamRef(n nodes.Node) error { var allrefs []nodes.ParamRef // Find all parameter references - Walk(VisitorFunc(func(node nodes.Node) { + ast.Walk(ast.VisitorFunc(func(node nodes.Node) { switch n := node.(type) { case nodes.ParamRef: allrefs = append(allrefs, n) @@ -41,7 +43,7 @@ type funcCallVisitor struct { err error } -func (v *funcCallVisitor) Visit(node nodes.Node) Visitor { +func (v *funcCallVisitor) Visit(node nodes.Node) ast.Visitor { if v.err != nil { return nil } @@ -91,7 +93,7 @@ func (v *funcCallVisitor) Visit(node nodes.Node) Visitor { func validateFuncCall(c *pg.Catalog, n nodes.Node) error { visitor := funcCallVisitor{catalog: c} - Walk(&visitor, n) + ast.Walk(&visitor, n) return visitor.err } @@ -120,3 +122,29 @@ func validateInsertStmt(stmt nodes.InsertStmt) error { } return nil } + +// A query can use one (and only one) of the following formats: +// - positional parameters $1 +// - named parameter operator @param +// - named parameter function calls sqlc.arg(param) +func validateParamStyle(n nodes.Node) error { + positional := search(n, func(node nodes.Node) bool { + _, ok := node.(nodes.ParamRef) + return ok + }) + namedFunc := search(n, isNamedParamFunc) + namedSign := search(n, isNamedParamSign) + for _, check := range []bool{ + len(positional.Items) > 0 && len(namedSign.Items)+len(namedFunc.Items) > 0, + len(namedFunc.Items) > 0 && len(namedSign.Items)+len(positional.Items) > 0, + len(namedSign.Items) > 0 && len(positional.Items)+len(namedFunc.Items) > 0, + } { + if check { + return pg.Error{ + Code: "", // TODO: Pick a new error code + Message: "query mixes positional parameters ($1) and named parameters (sqlc.arg or @arg)", + } + } + } + return nil +} diff --git a/internal/dinosql/parser.go b/internal/dinosql/parser.go index dd3c9846f9..adf1c897e9 100644 --- a/internal/dinosql/parser.go +++ b/internal/dinosql/parser.go @@ -16,6 +16,7 @@ import ( "github.com/kyleconroy/sqlc/internal/config" core "github.com/kyleconroy/sqlc/internal/pg" "github.com/kyleconroy/sqlc/internal/postgres" + "github.com/kyleconroy/sqlc/internal/postgresql/ast" "github.com/davecgh/go-spew/spew" pg "github.com/lfittl/pg_query_go" @@ -315,13 +316,13 @@ func pluckQuery(source string, n nodes.RawStmt) (string, error) { func rangeVars(root nodes.Node) []nodes.RangeVar { var vars []nodes.RangeVar - find := VisitorFunc(func(node nodes.Node) { + find := ast.VisitorFunc(func(node nodes.Node) { switch n := node.(type) { case nodes.RangeVar: vars = append(vars, n) } }) - Walk(find, root) + ast.Walk(find, root) return vars } @@ -416,6 +417,9 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error) if err := validateParamRef(stmt); err != nil { return nil, err } + if err := validateParamStyle(stmt); err != nil { + return nil, err + } raw, ok := stmt.(nodes.RawStmt) if !ok { return nil, errors.New("node is not a statement") @@ -449,9 +453,12 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error) if err := validateCmd(raw.Stmt, name, cmd); err != nil { return nil, err } + + // Re-write query AST + raw, namedParams, edits := rewriteNamedParameters(raw) rvs := rangeVars(raw.Stmt) refs := findParameters(raw.Stmt) - params, err := resolveCatalogRefs(c, rvs, refs) + params, err := resolveCatalogRefs(c, rvs, refs, namedParams) if err != nil { return nil, err } @@ -464,10 +471,23 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error) if err != nil { return nil, err } - expanded, err := expand(qc, raw, rawSQL) + + expandEdits, err := expand(qc, raw) if err != nil { return nil, err } + edits = append(edits, expandEdits...) + expanded, err := editQuery(rawSQL, edits) + if err != nil { + return nil, err + } + + // If the query string was edited, make sure the syntax is valid + if expanded != rawSQL { + if _, err := pg.Parse(expanded); err != nil { + return nil, fmt.Errorf("edited query syntax is invalid: %w", err) + } + } trimmed, comments, err := stripComments(strings.TrimSpace(expanded)) if err != nil { @@ -506,7 +526,7 @@ type edit struct { New string } -func expand(qc *QueryCatalog, raw nodes.RawStmt, sql string) (string, error) { +func expand(qc *QueryCatalog, raw nodes.RawStmt) ([]edit, error) { list := search(raw, func(node nodes.Node) bool { switch node.(type) { case nodes.DeleteStmt: @@ -519,17 +539,17 @@ func expand(qc *QueryCatalog, raw nodes.RawStmt, sql string) (string, error) { return true }) if len(list.Items) == 0 { - return sql, nil + return nil, nil } var edits []edit for _, item := range list.Items { edit, err := expandStmt(qc, raw, item) if err != nil { - return "", err + return nil, err } edits = append(edits, edit...) } - return editQuery(sql, edits) + return edits, nil } func expandStmt(qc *QueryCatalog, raw nodes.RawStmt, node nodes.Node) ([]edit, error) { @@ -983,6 +1003,7 @@ type paramRef struct { parent nodes.Node rv *nodes.RangeVar ref nodes.ParamRef + name string // Named parameter support } type paramSearch struct { @@ -1014,10 +1035,16 @@ type limitOffset struct { nodeImpl } -func (p paramSearch) Visit(node nodes.Node) Visitor { +func (p paramSearch) Visit(node nodes.Node) ast.Visitor { switch n := node.(type) { case nodes.A_Expr: + if join(n.Name, "-") == "@" && n.Lexpr == nil { + param := nodes.ParamRef{Number: 1} + // TODO: Remove hard-coded slug + p.refs[1] = paramRef{parent: p.parent, rv: p.rangeVar, name: "slug", ref: param} + return nil + } p.parent = node case nodes.FuncCall: @@ -1111,7 +1138,7 @@ func (p paramSearch) Visit(node nodes.Node) Visitor { func findParameters(root nodes.Node) []paramRef { v := paramSearch{refs: map[int]paramRef{}} - Walk(v, root) + ast.Walk(v, root) refs := make([]paramRef, 0) for _, r := range v.refs { refs = append(refs, r) @@ -1125,7 +1152,7 @@ type nodeSearch struct { check func(nodes.Node) bool } -func (s *nodeSearch) Visit(node nodes.Node) Visitor { +func (s *nodeSearch) Visit(node nodes.Node) ast.Visitor { if s.check(node) { s.list.Items = append(s.list.Items, node) } @@ -1134,16 +1161,23 @@ func (s *nodeSearch) Visit(node nodes.Node) Visitor { func search(root nodes.Node, f func(nodes.Node) bool) nodes.List { ns := &nodeSearch{check: f} - Walk(ns, root) + ast.Walk(ns, root) return ns.list } -func resolveCatalogRefs(c core.Catalog, rvs []nodes.RangeVar, args []paramRef) ([]Parameter, error) { +func resolveCatalogRefs(c core.Catalog, rvs []nodes.RangeVar, args []paramRef, names map[int]string) ([]Parameter, error) { aliasMap := map[string]core.FQN{} // TODO: Deprecate defaultTable var defaultTable *core.FQN var tables []core.FQN + parameterName := func(n int, defaultName string) string { + if n, ok := names[n]; ok { + return n + } + return defaultName + } + for _, rv := range rvs { if rv.Relname == nil { continue @@ -1193,7 +1227,7 @@ func resolveCatalogRefs(c core.Catalog, rvs []nodes.RangeVar, args []paramRef) ( a = append(a, Parameter{ Number: ref.ref.Number, Column: core.Column{ - Name: "offset", + Name: parameterName(ref.ref.Number, "offset"), DataType: "integer", NotNull: true, }, @@ -1203,7 +1237,7 @@ func resolveCatalogRefs(c core.Catalog, rvs []nodes.RangeVar, args []paramRef) ( a = append(a, Parameter{ Number: ref.ref.Number, Column: core.Column{ - Name: "limit", + Name: parameterName(ref.ref.Number, "limit"), DataType: "integer", NotNull: true, }, @@ -1256,10 +1290,13 @@ func resolveCatalogRefs(c core.Catalog, rvs []nodes.RangeVar, args []paramRef) ( for _, table := range search { if c, ok := typeMap[table.Schema][table.Rel][key]; ok { found += 1 + if ref.name != "" { + key = ref.name + } a = append(a, Parameter{ Number: ref.ref.Number, Column: core.Column{ - Name: key, + Name: parameterName(ref.ref.Number, key), DataType: c.DataType, NotNull: c.NotNull, IsArray: c.IsArray, @@ -1312,7 +1349,7 @@ func resolveCatalogRefs(c core.Catalog, rvs []nodes.RangeVar, args []paramRef) ( a = append(a, Parameter{ Number: ref.ref.Number, Column: core.Column{ - Name: fun.Name, + Name: parameterName(ref.ref.Number, fun.Name), DataType: "any", }, }) @@ -1329,7 +1366,7 @@ func resolveCatalogRefs(c core.Catalog, rvs []nodes.RangeVar, args []paramRef) ( a = append(a, Parameter{ Number: ref.ref.Number, Column: core.Column{ - Name: name, + Name: parameterName(ref.ref.Number, name), DataType: arg.DataType, NotNull: true, }, @@ -1345,7 +1382,7 @@ func resolveCatalogRefs(c core.Catalog, rvs []nodes.RangeVar, args []paramRef) ( a = append(a, Parameter{ Number: ref.ref.Number, Column: core.Column{ - Name: key, + Name: parameterName(ref.ref.Number, key), DataType: c.DataType, NotNull: c.NotNull, IsArray: c.IsArray, @@ -1364,9 +1401,11 @@ func resolveCatalogRefs(c core.Catalog, rvs []nodes.RangeVar, args []paramRef) ( if n.TypeName == nil { return nil, fmt.Errorf("nodes.TypeCast has nil type name") } + col := catalog.ToColumn(n.TypeName) + col.Name = parameterName(ref.ref.Number, col.Name) a = append(a, Parameter{ Number: ref.ref.Number, - Column: catalog.ToColumn(n.TypeName), + Column: col, }) case nodes.ParamRef: diff --git a/internal/dinosql/rewrite.go b/internal/dinosql/rewrite.go new file mode 100644 index 0000000000..95b5c457a8 --- /dev/null +++ b/internal/dinosql/rewrite.go @@ -0,0 +1,147 @@ +package dinosql + +import ( + "fmt" + + nodes "github.com/lfittl/pg_query_go/nodes" + + "github.com/kyleconroy/sqlc/internal/postgresql/ast" +) + +// Given an AST node, return the string representation of names +func flatten(root nodes.Node) string { + sw := &stringWalker{} + ast.Walk(sw, root) + return sw.String +} + +type stringWalker struct { + String string +} + +func (s *stringWalker) Visit(node nodes.Node) ast.Visitor { + if n, ok := node.(nodes.String); ok { + s.String += n.Str + } + return s +} + +func isNamedParamFunc(node nodes.Node) bool { + fun, ok := node.(nodes.FuncCall) + return ok && ast.Join(fun.Funcname, ".") == "sqlc.arg" +} + +func isNamedParamSign(node nodes.Node) bool { + expr, ok := node.(nodes.A_Expr) + return ok && ast.Join(expr.Name, ".") == "@" +} + +func isNamedParamSignCast(node nodes.Node) bool { + expr, ok := node.(nodes.A_Expr) + if !ok { + return false + } + _, cast := expr.Rexpr.(nodes.TypeCast) + return ast.Join(expr.Name, ".") == "@" && cast +} + +func rewriteNamedParameters(raw nodes.RawStmt) (nodes.RawStmt, map[int]string, []edit) { + foundFunc := search(raw, isNamedParamFunc) + foundSign := search(raw, isNamedParamSign) + if len(foundFunc.Items)+len(foundSign.Items) == 0 { + return raw, map[int]string{}, nil + } + + args := map[string]int{} + argn := 0 + var edits []edit + node := ast.Apply(raw, func(cr *ast.Cursor) bool { + node := cr.Node() + switch { + + case isNamedParamFunc(node): + fun := node.(nodes.FuncCall) + param := flatten(fun.Args) + if num, ok := args[param]; ok { + cr.Replace(nodes.ParamRef{ + Number: num, + Location: fun.Location, + }) + } else { + argn += 1 + args[param] = argn + cr.Replace(nodes.ParamRef{ + Number: argn, + Location: fun.Location, + }) + } + // TODO: This code assumes that sqlc.arg(name) is on a single line + edits = append(edits, edit{ + Location: fun.Location - raw.StmtLocation, + Old: fmt.Sprintf("sqlc.arg(%s)", param), + New: fmt.Sprintf("$%d", args[param]), + }) + return false + + case isNamedParamSignCast(node): + expr := node.(nodes.A_Expr) + cast := expr.Rexpr.(nodes.TypeCast) + param := flatten(cast.Arg) + if num, ok := args[param]; ok { + cast.Arg = nodes.ParamRef{ + Number: num, + Location: expr.Location, + } + cr.Replace(cast) + } else { + argn += 1 + args[param] = argn + cast.Arg = nodes.ParamRef{ + Number: argn, + Location: expr.Location, + } + cr.Replace(cast) + } + // TODO: This code assumes that @foo::bool is on a single line + edits = append(edits, edit{ + Location: expr.Location - raw.StmtLocation, + Old: fmt.Sprintf("@%s", param), + New: fmt.Sprintf("$%d", args[param]), + }) + return false + + case isNamedParamSign(node): + expr := node.(nodes.A_Expr) + param := flatten(expr.Rexpr) + if num, ok := args[param]; ok { + cr.Replace(nodes.ParamRef{ + Number: num, + Location: expr.Location, + }) + } else { + argn += 1 + args[param] = argn + cr.Replace(nodes.ParamRef{ + Number: argn, + Location: expr.Location, + }) + } + // TODO: This code assumes that @foo is on a single line + edits = append(edits, edit{ + Location: expr.Location - raw.StmtLocation, + Old: fmt.Sprintf("@%s", param), + New: fmt.Sprintf("$%d", args[param]), + }) + return false + + default: + return true + } + }, nil) + + named := map[int]string{} + for k, v := range args { + named[v] = k + } + return node.(nodes.RawStmt), named, edits +} diff --git a/internal/endtoend/testdata/invalid_params/query.sql b/internal/endtoend/testdata/invalid_params/query.sql index 2d785e01ed..29b0c42023 100644 --- a/internal/endtoend/testdata/invalid_params/query.sql +++ b/internal/endtoend/testdata/invalid_params/query.sql @@ -9,8 +9,12 @@ SELECT foo FROM bar WHERE baz = $1 AND baz = $3; -- name: foo :one SELECT foo FROM bar; +-- name: Named :many +SELECT id FROM bar WHERE id = $1 AND sqlc.arg(named) = true; + -- stderr -- # package querytest -- query.sql:4:1: could not determine data type of parameter $1 -- query.sql:7:1: could not determine data type of parameter $2 -- query.sql:10:8: column "foo" does not exist +-- query.sql:13:1: query mixes positional parameters ($1) and named parameters (sqlc.arg or @arg) diff --git a/internal/endtoend/testdata/named_param/go/db.go b/internal/endtoend/testdata/named_param/go/db.go new file mode 100644 index 0000000000..6a99519302 --- /dev/null +++ b/internal/endtoend/testdata/named_param/go/db.go @@ -0,0 +1,29 @@ +// Code generated by sqlc. DO NOT EDIT. + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/named_param/go/models.go b/internal/endtoend/testdata/named_param/go/models.go new file mode 100644 index 0000000000..36b40a056b --- /dev/null +++ b/internal/endtoend/testdata/named_param/go/models.go @@ -0,0 +1,10 @@ +// Code generated by sqlc. DO NOT EDIT. + +package querytest + +import () + +type Foo struct { + Name string + Bio string +} diff --git a/internal/endtoend/testdata/named_param/go/query.sql.go b/internal/endtoend/testdata/named_param/go/query.sql.go new file mode 100644 index 0000000000..3b2bcdf45a --- /dev/null +++ b/internal/endtoend/testdata/named_param/go/query.sql.go @@ -0,0 +1,72 @@ +// Code generated by sqlc. DO NOT EDIT. +// source: query.sql + +package querytest + +import ( + "context" +) + +const atParams = `-- name: AtParams :many +SELECT name FROM foo WHERE name = $1 AND $2::bool +` + +type AtParamsParams struct { + Slug string + Filter bool +} + +func (q *Queries) AtParams(ctx context.Context, arg AtParamsParams) ([]string, error) { + rows, err := q.db.QueryContext(ctx, atParams, arg.Slug, arg.Filter) + if err != nil { + return nil, err + } + defer rows.Close() + var items []string + for rows.Next() { + var name string + if err := rows.Scan(&name); err != nil { + return nil, err + } + items = append(items, name) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const funcParams = `-- name: FuncParams :many +SELECT name FROM foo WHERE name = $1 AND $2::bool +` + +type FuncParamsParams struct { + Slug string + Filter bool +} + +func (q *Queries) FuncParams(ctx context.Context, arg FuncParamsParams) ([]string, error) { + rows, err := q.db.QueryContext(ctx, funcParams, arg.Slug, arg.Filter) + if err != nil { + return nil, err + } + defer rows.Close() + var items []string + for rows.Next() { + var name string + if err := rows.Scan(&name); err != nil { + return nil, err + } + items = append(items, name) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/named_param/query.sql b/internal/endtoend/testdata/named_param/query.sql new file mode 100644 index 0000000000..35dba4c476 --- /dev/null +++ b/internal/endtoend/testdata/named_param/query.sql @@ -0,0 +1,7 @@ +CREATE TABLE foo (name text not null, bio text not null); + +-- name: FuncParams :many +SELECT name FROM foo WHERE name = sqlc.arg(slug) AND sqlc.arg(filter)::bool; + +-- name: AtParams :many +SELECT name FROM foo WHERE name = @slug AND @filter::bool; diff --git a/internal/endtoend/testdata/named_param/sqlc.json b/internal/endtoend/testdata/named_param/sqlc.json new file mode 100644 index 0000000000..1161aac713 --- /dev/null +++ b/internal/endtoend/testdata/named_param/sqlc.json @@ -0,0 +1,9 @@ +{ + "version": "1", + "packages": [{ + "path": "go", + "name": "querytest", + "schema": "query.sql", + "queries": "query.sql" + }] +} diff --git a/internal/pg/catalog.go b/internal/pg/catalog.go index bcd2543f76..e0dd17552a 100644 --- a/internal/pg/catalog.go +++ b/internal/pg/catalog.go @@ -5,6 +5,7 @@ func NewCatalog() Catalog { Schemas: map[string]Schema{ "public": NewSchema(), "pg_catalog": pgCatalog(), + "sqlc": internalSchema(), // Likewise, the current session's temporary-table schema, pg_temp_nnn, is // always searched if it exists. It can be explicitly listed in the path by // using the alias pg_temp. If it is not listed in the path then it is diff --git a/internal/pg/sqlc.go b/internal/pg/sqlc.go new file mode 100644 index 0000000000..4c6fc57460 --- /dev/null +++ b/internal/pg/sqlc.go @@ -0,0 +1,24 @@ +package pg + +func internalSchema() Schema { + s := NewSchema() + s.Name = "sqlc" + fs := []Function{ + { + Name: "arg", + Desc: "Named argumented placeholder", + ReturnType: "void", + Arguments: []Argument{ + { + Name: "name", + DataType: "id", + }, + }, + }, + } + s.Funcs = make(map[string][]Function, len(fs)) + for _, f := range fs { + s.Funcs[f.Name] = append(s.Funcs[f.Name], f) + } + return s +} diff --git a/internal/postgresql/ast/astutil.go b/internal/postgresql/ast/astutil.go new file mode 100644 index 0000000000..97eade91e8 --- /dev/null +++ b/internal/postgresql/ast/astutil.go @@ -0,0 +1,1532 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ast + +import ( + "fmt" + "reflect" + + nodes "github.com/lfittl/pg_query_go/nodes" +) + +// An ApplyFunc is invoked by Apply for each node n, even if n is nil, +// before and/or after the node's children, using a Cursor describing +// the current node and providing operations on it. +// +// The return value of ApplyFunc controls the syntax tree traversal. +// See Apply for details. +type ApplyFunc func(*Cursor) bool + +// Apply traverses a syntax tree recursively, starting with root, +// and calling pre and post for each node as described below. +// Apply returns the syntax tree, possibly modified. +// +// If pre is not nil, it is called for each node before the node's +// children are traversed (pre-order). If pre returns false, no +// children are traversed, and post is not called for that node. +// +// If post is not nil, and a prior call of pre didn't return false, +// post is called for each node after its children are traversed +// (post-order). If post returns false, traversal is terminated and +// Apply returns immediately. +// +// Only fields that refer to AST nodes are considered children; +// i.e., token.Pos, Scopes, Objects, and fields of basic types +// (strings, etc.) are ignored. +// +// Children are traversed in the order in which they appear in the +// respective node's struct definition. A package's files are +// traversed in the filenames' alphabetical order. +// +func Apply(root nodes.Node, pre, post ApplyFunc) (result nodes.Node) { + parent := &struct{ nodes.Node }{root} + defer func() { + if r := recover(); r != nil && r != abort { + panic(r) + } + result = parent.Node + }() + a := &application{pre: pre, post: post} + a.apply(parent, "Node", nil, root) + return +} + +var abort = new(int) // singleton, to signal termination of Apply + +// A Cursor describes a node encountered during Apply. +// Information about the node and its parent is available +// from the Node, Parent, Name, and Index methods. +// +// If p is a variable of type and value of the current parent node +// c.Parent(), and f is the field identifier with name c.Name(), +// the following invariants hold: +// +// p.f == c.Node() if c.Index() < 0 +// p.f[c.Index()] == c.Node() if c.Index() >= 0 +// +// The methods Replace, Delete, InsertBefore, and InsertAfter +// can be used to change the AST without disrupting Apply. +type Cursor struct { + parent nodes.Node + name string + iter *iterator // valid if non-nil + node nodes.Node +} + +// Node returns the current Node. +func (c *Cursor) Node() nodes.Node { return c.node } + +// Parent returns the parent of the current Node. +func (c *Cursor) Parent() nodes.Node { return c.parent } + +// Name returns the name of the parent Node field that contains the current Node. +// If the parent is a *ast.Package and the current Node is a *ast.File, Name returns +// the filename for the current Node. +func (c *Cursor) Name() string { return c.name } + +// Index reports the index >= 0 of the current Node in the slice of Nodes that +// contains it, or a value < 0 if the current Node is not part of a slice. +// The index of the current node changes if InsertBefore is called while +// processing the current node. +func (c *Cursor) Index() int { + if c.iter != nil { + return c.iter.index + } + return -1 +} + +// field returns the current node's parent field value. +func (c *Cursor) field() reflect.Value { + return reflect.Indirect(reflect.ValueOf(c.parent)).FieldByName(c.name) +} + +// Replace replaces the current Node with n. +// The replacement node is not walked by Apply. +func (c *Cursor) Replace(n nodes.Node) { + v := c.field() + if i := c.Index(); i >= 0 { + v = v.Index(i) + } + v.Set(reflect.ValueOf(n)) +} + +// Replace replaces the current Node with n. +// The replacement node is not walked by Apply. +func (c *Cursor) set(val nodes.Node, ptr nodes.Node) { + v := c.field() + if i := c.Index(); i >= 0 { + v = v.Index(i) + } + if v.Type().Kind() == reflect.Ptr { + v.Set(reflect.ValueOf(ptr)) + } else { + v.Set(reflect.ValueOf(val)) + } +} + +// application carries all the shared data so we can pass it around cheaply. +type application struct { + pre, post ApplyFunc + cursor Cursor + iter iterator +} + +func (a *application) apply(parent nodes.Node, name string, iter *iterator, node nodes.Node) { + // convert typed nil into untyped nil + if v := reflect.ValueOf(node); v.Kind() == reflect.Ptr && v.IsNil() { + node = nil + } + + // TODO: If node is a pointer, dereference it. This prevents us from having + // to have nil checks in the case statement + + // avoid heap-allocating a new cursor for each apply call; reuse a.cursor instead + saved := a.cursor + a.cursor.parent = parent + a.cursor.name = name + a.cursor.iter = iter + a.cursor.node = node + + if a.pre != nil && !a.pre(&a.cursor) { + a.cursor = saved + return + } + + // walk children + // (the order of the cases matches the order of the corresponding node types in go/ast) + switch n := node.(type) { + case nil: + // nothing to do + + case nodes.A_ArrayExpr: + a.apply(&n, "Elements", nil, n.Elements) + a.cursor.set(n, &n) + + case nodes.A_Const: + a.apply(&n, "Val", nil, n.Val) + a.cursor.set(n, &n) + + case nodes.A_Expr: + a.apply(&n, "Name", nil, n.Name) + a.apply(&n, "Lexpr", nil, n.Lexpr) + a.apply(&n, "Rexpr", nil, n.Rexpr) + a.cursor.set(n, &n) + + case nodes.A_Indices: + a.apply(&n, "Lidx", nil, n.Lidx) + a.apply(&n, "Uidx", nil, n.Uidx) + a.cursor.set(n, &n) + + case nodes.A_Indirection: + a.apply(&n, "Arg", nil, n.Arg) + a.apply(&n, "Indirection", nil, n.Indirection) + a.cursor.set(n, &n) + + case nodes.A_Star: + // pass + + case nodes.AccessPriv: + a.apply(&n, "Cols", nil, n.Cols) + a.cursor.set(n, &n) + + case nodes.Aggref: + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "Aggargtypes", nil, n.Aggargtypes) + a.apply(&n, "Aggdirectargs", nil, n.Aggdirectargs) + a.apply(&n, "Args", nil, n.Args) + a.apply(&n, "Aggorder", nil, n.Aggorder) + a.apply(&n, "Aggdistinct", nil, n.Aggdistinct) + a.apply(&n, "Aggfilter", nil, n.Aggfilter) + a.cursor.set(n, &n) + + case nodes.Alias: + a.apply(&n, "Colnames", nil, n.Colnames) + a.cursor.set(n, &n) + + case nodes.AlterCollationStmt: + a.apply(&n, "Collname", nil, n.Collname) + a.cursor.set(n, &n) + + case nodes.AlterDatabaseSetStmt: + if n.Setstmt != nil { + a.apply(&n, "Setstmt", nil, *n.Setstmt) + } + a.cursor.set(n, &n) + + case nodes.AlterDatabaseStmt: + a.apply(&n, "Options", nil, n.Options) + a.cursor.set(n, &n) + + case nodes.AlterDefaultPrivilegesStmt: + if n.Action != nil { + a.apply(&n, "Action", nil, *n.Action) + } + a.apply(&n, "Options", nil, n.Options) + a.cursor.set(n, &n) + + case nodes.AlterDomainStmt: + a.apply(&n, "TypeName", nil, n.TypeName) + a.apply(&n, "Def", nil, n.Def) + a.cursor.set(n, &n) + + case nodes.AlterEnumStmt: + a.apply(&n, "TypeName", nil, n.TypeName) + a.cursor.set(n, &n) + + case nodes.AlterEventTrigStmt: + // pass + + case nodes.AlterExtensionContentsStmt: + a.apply(&n, "Object", nil, n.Object) + a.cursor.set(n, &n) + + case nodes.AlterExtensionStmt: + a.apply(&n, "Options", nil, n.Options) + a.cursor.set(n, &n) + + case nodes.AlterFdwStmt: + a.apply(&n, "FuncOptions", nil, n.FuncOptions) + a.apply(&n, "Options", nil, n.Options) + a.cursor.set(n, &n) + + case nodes.AlterForeignServerStmt: + a.apply(&n, "Options", nil, n.Options) + a.cursor.set(n, &n) + + case nodes.AlterFunctionStmt: + if n.Func != nil { + a.apply(&n, "Func", nil, n.Func) + } + a.apply(&n, "Actions", nil, n.Actions) + a.cursor.set(n, &n) + + case nodes.AlterObjectDependsStmt: + if n.Relation != nil { + a.apply(&n, "Relation", nil, *n.Relation) + } + a.apply(&n, "Object", nil, n.Object) + a.apply(&n, "Extname", nil, n.Extname) + a.cursor.set(n, &n) + + case nodes.AlterObjectSchemaStmt: + if n.Relation != nil { + a.apply(&n, "Relation", nil, *n.Relation) + } + a.apply(&n, "Object", nil, n.Object) + a.cursor.set(n, &n) + + case nodes.AlterOpFamilyStmt: + a.apply(&n, "Opfamilyname", nil, n.Opfamilyname) + a.apply(&n, "Items", nil, n.Items) + a.cursor.set(n, &n) + + case nodes.AlterOperatorStmt: + if n.Opername != nil { + a.apply(&n, "Opername", nil, *n.Opername) + } + a.apply(&n, "Options", nil, n.Options) + a.cursor.set(n, &n) + + case nodes.AlterOwnerStmt: + if n.Relation != nil { + a.apply(&n, "Relation", nil, *n.Relation) + } + a.apply(&n, "Object", nil, n.Object) + if n.Newowner != nil { + a.apply(&n, "Newowner", nil, *n.Newowner) + } + a.cursor.set(n, &n) + + case nodes.AlterPolicyStmt: + if n.Table != nil { + a.apply(&n, "Table", nil, *n.Table) + } + a.apply(&n, "Roles", nil, n.Roles) + a.apply(&n, "Qual", nil, n.Qual) + a.apply(&n, "WithCheck", nil, n.WithCheck) + a.cursor.set(n, &n) + + case nodes.AlterPublicationStmt: + a.apply(&n, "Options", nil, n.Options) + a.apply(&n, "Table", nil, n.Tables) + a.cursor.set(n, &n) + + case nodes.AlterRoleSetStmt: + if n.Role != nil { + a.apply(&n, "Role", nil, *n.Role) + } + a.apply(&n, "Setstmt", nil, n.Setstmt) + a.cursor.set(n, &n) + + case nodes.AlterRoleStmt: + if n.Role != nil { + a.apply(&n, "Role", nil, *n.Role) + } + a.apply(&n, "Options", nil, n.Options) + a.cursor.set(n, &n) + + case nodes.AlterSeqStmt: + if n.Sequence != nil { + a.apply(&n, "Sequence", nil, *n.Sequence) + } + a.apply(&n, "Options", nil, n.Options) + a.cursor.set(n, &n) + + case nodes.AlterSubscriptionStmt: + a.apply(&n, "Publication", nil, n.Publication) + a.apply(&n, "Options", nil, n.Options) + a.cursor.set(n, &n) + + case nodes.AlterSystemStmt: + a.apply(&n, "Setstmt", nil, n.Setstmt) + a.cursor.set(n, &n) + + case nodes.AlterTSConfigurationStmt: + a.apply(&n, "Cfgname", nil, n.Cfgname) + a.apply(&n, "Tokentype", nil, n.Tokentype) + a.apply(&n, "Dicts", nil, n.Dicts) + a.cursor.set(n, &n) + + case nodes.AlterTSDictionaryStmt: + a.apply(&n, "Dictname", nil, n.Dictname) + a.apply(&n, "Options", nil, n.Options) + a.cursor.set(n, &n) + + case nodes.AlterTableCmd: + if n.Newowner != nil { + a.apply(&n, "Newowner", nil, *n.Newowner) + } + a.apply(&n, "Def", nil, n.Def) + a.cursor.set(n, &n) + + case nodes.AlterTableMoveAllStmt: + a.apply(&n, "Roles", nil, n.Roles) + a.cursor.set(n, &n) + + case nodes.AlterTableSpaceOptionsStmt: + a.apply(&n, "Options", nil, n.Options) + a.cursor.set(n, &n) + + case nodes.AlterTableStmt: + if n.Relation != nil { + a.apply(&n, "Relation", nil, *n.Relation) + } + a.apply(&n, "Cmds", nil, n.Cmds) + a.cursor.set(n, &n) + + case nodes.AlterUserMappingStmt: + if n.User != nil { + a.apply(&n, "User", nil, *n.User) + } + a.apply(&n, "Options", nil, n.Options) + a.cursor.set(n, &n) + + case nodes.AlternativeSubPlan: + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "Subplans", nil, n.Subplans) + a.cursor.set(n, &n) + + case nodes.ArrayCoerceExpr: + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "Arg", nil, n.Arg) + a.cursor.set(n, &n) + + case nodes.ArrayExpr: + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "Elements", nil, n.Elements) + a.cursor.set(n, &n) + + case nodes.ArrayRef: + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "Refupperindexpr", nil, n.Refupperindexpr) + a.apply(&n, "Reflowerindexpr", nil, n.Reflowerindexpr) + a.apply(&n, "Refexpr", nil, n.Refexpr) + a.apply(&n, "Refassgnexpr", nil, n.Refassgnexpr) + a.cursor.set(n, &n) + + case nodes.BitString: + // pass + + case nodes.BlockIdData: + // pass + + case nodes.BoolExpr: + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "Args", nil, n.Args) + a.cursor.set(n, &n) + + case nodes.BooleanTest: + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "Arg", nil, n.Arg) + a.cursor.set(n, &n) + + case nodes.CaseExpr: + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "Arg", nil, n.Arg) + a.apply(&n, "Args", nil, n.Args) + a.apply(&n, "Defresult", nil, n.Defresult) + a.cursor.set(n, &n) + + case nodes.CaseTestExpr: + a.apply(&n, "Xpr", nil, n.Xpr) + a.cursor.set(n, &n) + + case nodes.CaseWhen: + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "Expr", nil, n.Expr) + a.apply(&n, "Result", nil, n.Result) + a.cursor.set(n, &n) + + case nodes.CheckPointStmt: + // pass + + case nodes.ClosePortalStmt: + // pass + + case nodes.ClusterStmt: + if n.Relation != nil { + a.apply(&n, "Relation", nil, *n.Relation) + } + a.cursor.set(n, &n) + + case nodes.CoalesceExpr: + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "Args", nil, n.Args) + a.cursor.set(n, &n) + + case nodes.CoerceToDomain: + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "Arg", nil, n.Arg) + a.cursor.set(n, &n) + + case nodes.CoerceToDomainValue: + a.apply(&n, "Xpr", nil, n.Xpr) + a.cursor.set(n, &n) + + case nodes.CoerceViaIO: + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "Arg", nil, n.Arg) + a.cursor.set(n, &n) + + case nodes.CollateClause: + a.apply(&n, "Arg", nil, n.Arg) + a.apply(&n, "Collname", nil, n.Collname) + a.cursor.set(n, &n) + case nodes.CollateExpr: + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "Arg", nil, n.Arg) + a.cursor.set(n, &n) + + case nodes.ColumnDef: + if n.TypeName != nil { + a.apply(&n, "TypeName", nil, *n.TypeName) + } + a.apply(&n, "RawDefault", nil, n.RawDefault) + a.apply(&n, "CookedDefault", nil, n.CookedDefault) + a.apply(&n, "Constraints", nil, n.Constraints) + a.apply(&n, "Fdwoptions", nil, n.Fdwoptions) + a.cursor.set(n, &n) + + case nodes.ColumnRef: + a.apply(&n, "Fields", nil, n.Fields) + a.cursor.set(n, &n) + + case nodes.CommentStmt: + a.apply(&n, "Object", nil, n.Object) + a.cursor.set(n, &n) + + case nodes.CommonTableExpr: + a.apply(&n, "Aliascolnames", nil, n.Aliascolnames) + a.apply(&n, "Ctequery", nil, n.Ctequery) + a.apply(&n, "Ctecolnames", nil, n.Ctecolnames) + a.apply(&n, "Ctecolcollations", nil, n.Ctecolcollations) + a.cursor.set(n, &n) + + case nodes.CompositeTypeStmt: + if n.Typevar != nil { + a.apply(&n, "Typevar", nil, *n.Typevar) + } + a.apply(&n, "Coldeflist", nil, n.Coldeflist) + a.cursor.set(n, &n) + + case nodes.Const: + a.apply(&n, "Xpr", nil, n.Xpr) + a.cursor.set(n, &n) + + case nodes.Constraint: + a.apply(&n, "RawExpr", nil, n.RawExpr) + a.apply(&n, "Keys", nil, n.Keys) + a.apply(&n, "Exclusions", nil, n.Exclusions) + a.apply(&n, "Options", nil, n.Options) + a.apply(&n, "WhereClause", nil, n.WhereClause) + if n.Pktable != nil { + a.apply(&n, "Pktable", nil, *n.Pktable) + } + a.apply(&n, "FkAttrs", nil, n.FkAttrs) + a.apply(&n, "PkAttrs", nil, n.PkAttrs) + a.apply(&n, "OldConpfeqop", nil, n.OldConpfeqop) + a.cursor.set(n, &n) + + case nodes.ConstraintsSetStmt: + a.apply(&n, "Constraints", nil, n.Constraints) + a.cursor.set(n, &n) + + case nodes.ConvertRowtypeExpr: + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "Arg", nil, n.Arg) + a.cursor.set(n, &n) + + case nodes.CopyStmt: + if n.Relation != nil { + a.apply(&n, "Relation", nil, *n.Relation) + } + a.apply(&n, "Query", nil, n.Query) + a.apply(&n, "Attlist", nil, n.Attlist) + a.apply(&n, "Options", nil, n.Options) + a.cursor.set(n, &n) + + case nodes.CreateAmStmt: + a.apply(&n, "HandlerName", nil, n.HandlerName) + a.cursor.set(n, &n) + + case nodes.CreateCastStmt: + if n.Sourcetype != nil { + a.apply(&n, "Sourcetype", nil, *n.Sourcetype) + } + if n.Targettype != nil { + a.apply(&n, "Targettype", nil, *n.Targettype) + } + a.apply(&n, "Func", nil, n.Func) + a.cursor.set(n, &n) + + case nodes.CreateConversionStmt: + a.apply(&n, "ConversionName", nil, n.ConversionName) + a.apply(&n, "Funcname", nil, n.FuncName) + a.cursor.set(n, &n) + + case nodes.CreateDomainStmt: + a.apply(&n, "Domainname", nil, n.Domainname) + if n.TypeName != nil { + a.apply(&n, "TypeName", nil, *n.TypeName) + } + if n.CollClause != nil { + a.apply(&n, "CollClause", nil, *n.CollClause) + } + a.apply(&n, "Constraints", nil, n.Constraints) + a.cursor.set(n, &n) + + case nodes.CreateEnumStmt: + a.apply(&n, "TypeName", nil, n.TypeName) + a.apply(&n, "Vals", nil, n.Vals) + a.cursor.set(n, &n) + + case nodes.CreateEventTrigStmt: + a.apply(&n, "Whenclause", nil, n.Whenclause) + a.apply(&n, "Funcname", nil, n.Funcname) + a.cursor.set(n, &n) + + case nodes.CreateExtensionStmt: + a.apply(&n, "Options", nil, n.Options) + a.cursor.set(n, &n) + + case nodes.CreateFdwStmt: + a.apply(&n, "FuncOptions", nil, n.FuncOptions) + a.apply(&n, "Options", nil, n.Options) + a.cursor.set(n, &n) + + case nodes.CreateForeignServerStmt: + a.apply(&n, "Options", nil, n.Options) + a.cursor.set(n, &n) + + case nodes.CreateForeignTableStmt: + a.apply(&n, "Base", nil, n.Base) + a.apply(&n, "Options", nil, n.Options) + a.cursor.set(n, &n) + + case nodes.CreateFunctionStmt: + a.apply(&n, "Funcname", nil, n.Funcname) + a.apply(&n, "Parameters", nil, n.Parameters) + if n.ReturnType != nil { + a.apply(&n, "ReturnType", nil, *n.ReturnType) + } + a.apply(&n, "Options", nil, n.Options) + a.apply(&n, "WithClause", nil, n.WithClause) + a.cursor.set(n, &n) + + case nodes.CreateOpClassItem: + a.apply(&n, "Name", nil, n.Name) + a.apply(&n, "OrderFamily", nil, n.OrderFamily) + a.apply(&n, "ClassArgs", nil, n.ClassArgs) + if n.Storedtype != nil { + a.apply(&n, "Storedtype", nil, *n.Storedtype) + } + a.cursor.set(n, &n) + + case nodes.CreateOpClassStmt: + a.apply(&n, "Opclassname", nil, n.Opclassname) + a.apply(&n, "Opfamilyname", nil, n.Opfamilyname) + if n.Datatype != nil { + a.apply(&n, "Datatype", nil, *n.Datatype) + } + a.apply(&n, "Items", nil, n.Items) + a.cursor.set(n, &n) + + case nodes.CreateOpFamilyStmt: + a.apply(&n, "Opfamilyname", nil, n.Opfamilyname) + a.cursor.set(n, &n) + + case nodes.CreatePLangStmt: + a.apply(&n, "Plhandler", nil, n.Plhandler) + a.apply(&n, "Plinline", nil, n.Plinline) + a.apply(&n, "Plvalidator", nil, n.Plvalidator) + a.cursor.set(n, &n) + + case nodes.CreatePolicyStmt: + if n.Table != nil { + a.apply(&n, "Table", nil, *n.Table) + } + a.apply(&n, "Roles", nil, n.Roles) + a.apply(&n, "Qual", nil, n.Qual) + a.apply(&n, "WithCheck", nil, n.WithCheck) + a.cursor.set(n, &n) + + case nodes.CreatePublicationStmt: + a.apply(&n, "Options", nil, n.Options) + a.apply(&n, "Tables", nil, n.Tables) + a.cursor.set(n, &n) + + case nodes.CreateRangeStmt: + a.apply(&n, "TypeName", nil, n.TypeName) + a.apply(&n, "Params", nil, n.Params) + a.cursor.set(n, &n) + + case nodes.CreateRoleStmt: + a.apply(&n, "Options", nil, n.Options) + a.cursor.set(n, &n) + + case nodes.CreateSchemaStmt: + if n.Authrole != nil { + a.apply(&n, "Authrole", nil, *n.Authrole) + } + a.apply(&n, "SchemaElts", nil, n.SchemaElts) + a.cursor.set(n, &n) + + case nodes.CreateSeqStmt: + if n.Sequence != nil { + a.apply(&n, "Sequence", nil, *n.Sequence) + } + a.apply(&n, "Options", nil, n.Options) + a.cursor.set(n, &n) + + case nodes.CreateStatsStmt: + a.apply(&n, "Defnames", nil, n.Defnames) + a.apply(&n, "StatTypes", nil, n.StatTypes) + a.apply(&n, "Exprs", nil, n.Exprs) + a.apply(&n, "Relations", nil, n.Relations) + a.cursor.set(n, &n) + + case nodes.CreateStmt: + if n.Relation != nil { + a.apply(&n, "Relation", nil, *n.Relation) + } + a.apply(&n, "TableElts", nil, n.TableElts) + a.apply(&n, "InhRelations", nil, n.InhRelations) + if n.Partbound != nil { + a.apply(&n, "Partbound", nil, *n.Partbound) + } + if n.Partspec != nil { + a.apply(&n, "Partspec", nil, *n.Partspec) + } + a.apply(&n, "Constraints", nil, n.Constraints) + a.apply(&n, "Options", nil, n.Options) + if n.OfTypename != nil { + a.apply(&n, "OfTypename", nil, *n.OfTypename) + } + a.cursor.set(n, &n) + + case nodes.CreateSubscriptionStmt: + a.apply(&n, "Publication", nil, n.Publication) + a.apply(&n, "Options", nil, n.Options) + a.cursor.set(n, &n) + + case nodes.CreateTableAsStmt: + a.apply(&n, "Query", nil, n.Query) + a.apply(&n, "Into", nil, n.Into) + a.cursor.set(n, &n) + + case nodes.CreateTableSpaceStmt: + if n.Owner != nil { + a.apply(&n, "Owner", nil, *n.Owner) + } + a.apply(&n, "Options", nil, n.Options) + a.cursor.set(n, &n) + + case nodes.CreateTransformStmt: + if n.TypeName != nil { + a.apply(&n, "TypeName", nil, *n.TypeName) + } + if n.Fromsql != nil { + a.apply(&n, "Fromsql", nil, *n.Fromsql) + } + if n.Tosql != nil { + a.apply(&n, "Tosql", nil, *n.Tosql) + } + a.cursor.set(n, &n) + + case nodes.CreateTrigStmt: + if n.Relation != nil { + a.apply(&n, "Relation", nil, *n.Relation) + } + a.apply(&n, "Funcname", nil, n.Funcname) + a.apply(&n, "Args", nil, n.Args) + a.apply(&n, "Columns", nil, n.Columns) + a.apply(&n, "WhenClause", nil, n.WhenClause) + a.apply(&n, "TransitionRels", nil, n.TransitionRels) + if n.Constrrel != nil { + a.apply(&n, "Constrrel", nil, *n.Constrrel) + } + a.cursor.set(n, &n) + + case nodes.CreateUserMappingStmt: + if n.User != nil { + a.apply(&n, "User", nil, *n.User) + } + a.apply(&n, "Options", nil, n.Options) + a.cursor.set(n, &n) + + case nodes.CreatedbStmt: + a.apply(&n, "Options", nil, n.Options) + a.cursor.set(n, &n) + + case nodes.CurrentOfExpr: + a.apply(&n, "Xpr", nil, n.Xpr) + a.cursor.set(n, &n) + + case nodes.DeallocateStmt: + // pass + + case nodes.DeclareCursorStmt: + a.apply(&n, "Query", nil, n.Query) + a.cursor.set(n, &n) + + case nodes.DefElem: + a.apply(&n, "Arg", nil, n.Arg) + a.cursor.set(n, &n) + + case nodes.DefineStmt: + a.apply(&n, "Defnames", nil, n.Defnames) + a.apply(&n, "Args", nil, n.Args) + a.apply(&n, "Definition", nil, n.Definition) + a.cursor.set(n, &n) + + case nodes.DeleteStmt: + if n.Relation != nil { + a.apply(&n, "Relation", nil, *n.Relation) + } + a.apply(&n, "UsingClause", nil, n.UsingClause) + a.apply(&n, "WhereClause", nil, n.WhereClause) + a.apply(&n, "ReturningList", nil, n.ReturningList) + if n.WithClause != nil { + a.apply(&n, "WithClause", nil, *n.WithClause) + } + a.cursor.set(n, &n) + + case nodes.DiscardStmt: + // pass + + case nodes.DoStmt: + a.apply(&n, "Args", nil, n.Args) + a.cursor.set(n, &n) + + case nodes.DropOwnedStmt: + a.apply(&n, "Roles", nil, n.Roles) + a.cursor.set(n, &n) + + case nodes.DropRoleStmt: + a.apply(&n, "Roles", nil, n.Roles) + a.cursor.set(n, &n) + + case nodes.DropStmt: + a.apply(&n, "Objects", nil, n.Objects) + a.cursor.set(n, &n) + + case nodes.DropSubscriptionStmt: + // pass + + case nodes.DropTableSpaceStmt: + // pass + + case nodes.DropUserMappingStmt: + if n.User != nil { + a.apply(&n, "User", nil, *n.User) + } + a.cursor.set(n, &n) + + case nodes.DropdbStmt: + // pass + + case nodes.ExecuteStmt: + a.apply(&n, "Params", nil, n.Params) + a.cursor.set(n, &n) + + case nodes.ExplainStmt: + a.apply(&n, "Query", nil, n.Query) + a.apply(&n, "Options", nil, n.Options) + a.cursor.set(n, &n) + + case nodes.Expr: + // pass + + case nodes.FetchStmt: + // pass + + case nodes.FieldSelect: + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "Arg", nil, n.Arg) + a.cursor.set(n, &n) + + case nodes.FieldStore: + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "Arg", nil, n.Arg) + a.apply(&n, "Newvals", nil, n.Newvals) + a.apply(&n, "Fieldnums", nil, n.Fieldnums) + a.cursor.set(n, &n) + + case nodes.Float: + // pass + + case nodes.FromExpr: + a.apply(&n, "Fromlist", nil, n.Fromlist) + a.apply(&n, "Quals", nil, n.Quals) + a.cursor.set(n, &n) + + case nodes.FuncCall: + a.apply(&n, "Funcname", nil, n.Funcname) + a.apply(&n, "Args", nil, n.Args) + a.apply(&n, "AggOrder", nil, n.AggOrder) + a.apply(&n, "AggFilter", nil, n.AggFilter) + if n.Over != nil { + a.apply(&n, "Over", nil, *n.Over) + } + a.cursor.set(n, &n) + + case nodes.FuncExpr: + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "Args", nil, n.Args) + a.cursor.set(n, &n) + + case nodes.FunctionParameter: + if n.ArgType != nil { + a.apply(&n, "ArgType", nil, *n.ArgType) + } + a.apply(&n, "Defexpr", nil, n.Defexpr) + a.cursor.set(n, &n) + + case nodes.GrantRoleStmt: + a.apply(&n, "GrantedRoles", nil, n.GrantedRoles) + a.apply(&n, "GranteeRoles", nil, n.GranteeRoles) + if n.Grantor != nil { + a.apply(&n, "Grantor", nil, *n.Grantor) + } + a.cursor.set(n, &n) + + case nodes.GrantStmt: + a.apply(&n, "Objects", nil, n.Objects) + a.apply(&n, "Privileges", nil, n.Privileges) + a.apply(&n, "Grantees", nil, n.Grantees) + a.cursor.set(n, &n) + + case nodes.GroupingFunc: + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "Args", nil, n.Args) + a.apply(&n, "Refs", nil, n.Refs) + a.apply(&n, "Cols", nil, n.Cols) + a.cursor.set(n, &n) + + case nodes.GroupingSet: + a.apply(&n, "Content", nil, n.Content) + a.cursor.set(n, &n) + + case nodes.ImportForeignSchemaStmt: + a.apply(&n, "TableList", nil, n.TableList) + a.apply(&n, "Options", nil, n.Options) + a.cursor.set(n, &n) + + case nodes.IndexElem: + a.apply(&n, "Expr", nil, n.Expr) + a.apply(&n, "Collation", nil, n.Collation) + a.apply(&n, "Opclass", nil, n.Opclass) + a.cursor.set(n, &n) + + case nodes.IndexStmt: + if n.Relation != nil { + a.apply(&n, "Relation", nil, *n.Relation) + } + a.apply(&n, "IndexParams", nil, n.IndexParams) + a.apply(&n, "Options", nil, n.Options) + a.apply(&n, "WhereClause", nil, n.WhereClause) + a.apply(&n, "ExcludeOpNames", nil, n.ExcludeOpNames) + a.cursor.set(n, &n) + + case nodes.InferClause: + a.apply(&n, "IndexElems", nil, n.IndexElems) + a.apply(&n, "WhereClause", nil, n.WhereClause) + a.cursor.set(n, &n) + + case nodes.InferenceElem: + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "Expr", nil, n.Expr) + a.cursor.set(n, &n) + + case nodes.InlineCodeBlock: + // pass + + case nodes.InsertStmt: + if n.Relation != nil { + a.apply(&n, "Relation", nil, *n.Relation) + } + a.apply(&n, "Cols", nil, n.Cols) + a.apply(&n, "SelectStmt", nil, n.SelectStmt) + if n.OnConflictClause != nil { + a.apply(&n, "OnConflictClause", nil, *n.OnConflictClause) + } + a.apply(&n, "ReturningList", nil, n.ReturningList) + if n.WithClause != nil { + a.apply(&n, "WithClause", nil, *n.WithClause) + } + a.cursor.set(n, &n) + + case nodes.Integer: + // pass + + case nodes.IntoClause: + if n.Rel != nil { + a.apply(&n, "Rel", nil, *n.Rel) + } + a.apply(&n, "ColNames", nil, n.ColNames) + a.apply(&n, "Options", nil, n.Options) + a.apply(&n, "ViewQuery", nil, n.ViewQuery) + a.cursor.set(n, &n) + + case nodes.JoinExpr: + a.apply(&n, "Larg", nil, n.Larg) + a.apply(&n, "Rarg", nil, n.Rarg) + a.apply(&n, "UsingClause", nil, n.UsingClause) + a.apply(&n, "Quals", nil, n.Quals) + if n.Alias != nil { + a.apply(&n, "Alias", nil, *n.Alias) + } + a.cursor.set(n, &n) + + case nodes.List: + // Since item is a slice + a.applyList(&n, "Items") + + case nodes.ListenStmt: + // pass + + case nodes.LoadStmt: + // pass + + case nodes.LockStmt: + a.apply(&n, "Relations", nil, n.Relations) + a.cursor.set(n, &n) + + case nodes.LockingClause: + a.apply(&n, "LockedRels", nil, n.LockedRels) + a.cursor.set(n, &n) + + case nodes.MinMaxExpr: + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "Args", nil, n.Args) + a.cursor.set(n, &n) + + case nodes.MultiAssignRef: + a.apply(&n, "Source", nil, n.Source) + a.cursor.set(n, &n) + + case nodes.NamedArgExpr: + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "Args", nil, n.Arg) + a.cursor.set(n, &n) + + case nodes.NextValueExpr: + a.apply(&n, "Xpr", nil, n.Xpr) + a.cursor.set(n, &n) + + case nodes.NotifyStmt: + // pass + + case nodes.Null: + // pass + + case nodes.NullTest: + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "Arg", nil, n.Arg) + a.cursor.set(n, &n) + + case nodes.ObjectWithArgs: + a.apply(&n, "Objname", nil, n.Objname) + a.apply(&n, "Objargs", nil, n.Objargs) + a.cursor.set(n, &n) + + case nodes.OnConflictClause: + if n.Infer != nil { + a.apply(&n, "Infer", nil, *n.Infer) + } + a.apply(&n, "TargetList", nil, n.TargetList) + a.apply(&n, "WhereClause", nil, n.WhereClause) + a.cursor.set(n, &n) + + case nodes.OnConflictExpr: + a.apply(&n, "ArbiterElems", nil, n.ArbiterElems) + a.apply(&n, "ArbiterWhere", nil, n.ArbiterWhere) + a.apply(&n, "OnConflictSet", nil, n.OnConflictSet) + a.apply(&n, "OnConflictWhere", nil, n.OnConflictWhere) + a.apply(&n, "ExclRelTlist", nil, n.ExclRelTlist) + a.cursor.set(n, &n) + + case nodes.OpExpr: + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "Args", nil, n.Args) + a.cursor.set(n, &n) + + case nodes.Param: + a.apply(&n, "Xpr", nil, n.Xpr) + a.cursor.set(n, &n) + + case nodes.ParamExecData: + // pass + + case nodes.ParamExternData: + // pass + + case nodes.ParamListInfoData: + // pass + + case nodes.ParamRef: + // pass + + case nodes.PartitionBoundSpec: + a.apply(&n, "Listdatums", nil, n.Listdatums) + a.apply(&n, "Lowerdatums", nil, n.Lowerdatums) + a.apply(&n, "Upperdatums", nil, n.Upperdatums) + a.cursor.set(n, &n) + + case nodes.PartitionCmd: + if n.Name != nil { + a.apply(&n, "Name", nil, *n.Name) + } + if n.Bound != nil { + a.apply(&n, "Bound", nil, *n.Bound) + } + a.cursor.set(n, &n) + + case nodes.PartitionElem: + a.apply(&n, "Expr", nil, n.Expr) + a.apply(&n, "Collation", nil, n.Collation) + a.apply(&n, "Opclass", nil, n.Opclass) + a.cursor.set(n, &n) + + case nodes.PartitionRangeDatum: + a.apply(&n, "Value", nil, n.Value) + a.cursor.set(n, &n) + + case nodes.PartitionSpec: + a.apply(&n, "PartParams", nil, n.PartParams) + a.cursor.set(n, &n) + + case nodes.PrepareStmt: + a.apply(&n, "Argtypes", nil, n.Argtypes) + a.apply(&n, "Query", nil, n.Query) + a.cursor.set(n, &n) + + case nodes.Query: + a.apply(&n, "UtilityStmt", nil, n.UtilityStmt) + a.apply(&n, "CteList", nil, n.CteList) + a.apply(&n, "Jointree", nil, n.Jointree) + a.apply(&n, "TargetList", nil, n.TargetList) + a.apply(&n, "OnConflict", nil, n.OnConflict) + a.apply(&n, "ReturningList", nil, n.ReturningList) + a.apply(&n, "GroupClause", nil, n.GroupClause) + a.apply(&n, "GroupingSets", nil, n.GroupingSets) + a.apply(&n, "HavingQual", nil, n.HavingQual) + a.apply(&n, "WindowClause", nil, n.WindowClause) + a.apply(&n, "DistinctClause", nil, n.DistinctClause) + a.apply(&n, "SortClause", nil, n.SortClause) + a.apply(&n, "LimitCount", nil, n.LimitCount) + a.apply(&n, "RowMarks", nil, n.RowMarks) + a.apply(&n, "SetOperations", nil, n.SetOperations) + a.apply(&n, "ConstraintDeps", nil, n.ConstraintDeps) + a.apply(&n, "WithCheckOptions", nil, n.WithCheckOptions) + a.cursor.set(n, &n) + + case nodes.RangeFunction: + a.apply(&n, "Functions", nil, n.Functions) + if n.Alias != nil { + a.apply(&n, "Alias", nil, *n.Alias) + } + a.apply(&n, "Coldeflist", nil, n.Coldeflist) + a.cursor.set(n, &n) + + case nodes.RangeSubselect: + a.apply(&n, "Subquery", nil, n.Subquery) + if n.Alias != nil { + a.apply(&n, "Alias", nil, *n.Alias) + } + a.cursor.set(n, &n) + + case nodes.RangeTableFunc: + a.apply(&n, "Docexpr", nil, n.Docexpr) + a.apply(&n, "Rowexpr", nil, n.Rowexpr) + a.apply(&n, "Namespaces", nil, n.Namespaces) + a.apply(&n, "Columns", nil, n.Columns) + if n.Alias != nil { + a.apply(&n, "Alias", nil, *n.Alias) + } + a.cursor.set(n, &n) + + case nodes.RangeTableFuncCol: + if n.TypeName != nil { + a.apply(&n, "TypeName", nil, *n.TypeName) + } + a.apply(&n, "Colexpr", nil, n.Colexpr) + a.apply(&n, "Coldefexpr", nil, n.Coldefexpr) + a.cursor.set(n, &n) + + case nodes.RangeTableSample: + a.apply(&n, "Relation", nil, n.Relation) + a.apply(&n, "Method", nil, n.Method) + a.apply(&n, "Args", nil, n.Args) + a.cursor.set(n, &n) + + case nodes.RangeTblEntry: + a.apply(&n, "Tablesample", nil, n.Tablesample) + a.apply(&n, "Subquery", nil, n.Subquery) + a.apply(&n, "Joinaliasvars", nil, n.Joinaliasvars) + a.apply(&n, "Functions", nil, n.Functions) + a.apply(&n, "Tablefund", nil, n.Tablefunc) + a.apply(&n, "ValuesLists", nil, n.ValuesLists) + a.apply(&n, "Coltypes", nil, n.Coltypes) + a.apply(&n, "Colcollations", nil, n.Colcollations) + if n.Alias != nil { + a.apply(&n, "Alias", nil, *n.Alias) + } + a.apply(&n, "Eref", nil, n.Eref) + a.apply(&n, "SecurityQuals", nil, n.SecurityQuals) + a.cursor.set(n, &n) + + case nodes.RangeTblFunction: + a.apply(&n, "Funcexpr", nil, n.Funcexpr) + a.apply(&n, "Funccolnames", nil, n.Funccolnames) + a.apply(&n, "Funccoltypes", nil, n.Funccoltypes) + a.apply(&n, "Funccoltypmods", nil, n.Funccoltypmods) + a.apply(&n, "Funccolcollations", nil, n.Funccolcollations) + a.cursor.set(n, &n) + + case nodes.RangeTblRef: + // pass + + case nodes.RangeVar: + if n.Alias != nil { + a.apply(&n, "Alias", nil, *n.Alias) + } + a.cursor.set(n, &n) + + case nodes.RawStmt: + a.apply(&n, "Stmt", nil, n.Stmt) + a.cursor.set(n, &n) + + case nodes.ReassignOwnedStmt: + a.apply(&n, "Roles", nil, n.Roles) + if n.Newrole != nil { + a.apply(&n, "Newrole", nil, *n.Newrole) + } + a.cursor.set(n, &n) + + case nodes.RefreshMatViewStmt: + if n.Relation != nil { + a.apply(&n, "Relation", nil, *n.Relation) + a.cursor.set(n, &n) + } + + case nodes.ReindexStmt: + if n.Relation != nil { + a.apply(&n, "Relation", nil, *n.Relation) + a.cursor.set(n, &n) + } + + case nodes.RelabelType: + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "Arg", nil, n.Arg) + a.cursor.set(n, &n) + + case nodes.RenameStmt: + if n.Relation != nil { + a.apply(&n, "Relation", nil, *n.Relation) + } + a.apply(&n, "Object", nil, n.Object) + a.cursor.set(n, &n) + + case nodes.ReplicaIdentityStmt: + // pass + + case nodes.ResTarget: + a.apply(&n, "Indirection", nil, n.Indirection) + a.apply(&n, "Val", nil, n.Val) + a.cursor.set(n, &n) + + case nodes.RoleSpec: + // pass + + case nodes.RowCompareExpr: + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "Opnos", nil, n.Opnos) + a.apply(&n, "Opfamilies", nil, n.Opfamilies) + a.apply(&n, "Inputcollids", nil, n.Inputcollids) + a.apply(&n, "Largs", nil, n.Largs) + a.apply(&n, "Rargs", nil, n.Rargs) + a.cursor.set(n, &n) + + case nodes.RowExpr: + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "Args", nil, n.Args) + a.apply(&n, "Colnames", nil, n.Colnames) + a.cursor.set(n, &n) + + case nodes.RowMarkClause: + // pass + + case nodes.RuleStmt: + if n.Relation != nil { + a.apply(&n, "Relation", nil, *n.Relation) + } + a.apply(&n, "WhereClause", nil, n.WhereClause) + a.apply(&n, "Actions", nil, n.Actions) + a.cursor.set(n, &n) + + case nodes.SQLValueFunction: + a.apply(&n, "Xpr", nil, n.Xpr) + a.cursor.set(n, &n) + + case nodes.ScalarArrayOpExpr: + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "Args", nil, n.Args) + a.cursor.set(n, &n) + + case nodes.SecLabelStmt: + a.apply(&n, "Object", nil, n.Object) + a.cursor.set(n, &n) + + case nodes.SelectStmt: + a.apply(&n, "DistinctClause", nil, n.DistinctClause) + if n.IntoClause != nil { + a.apply(&n, "IntoClause", nil, *n.IntoClause) + } + a.apply(&n, "TargetList", nil, n.TargetList) + a.apply(&n, "FromClause", nil, n.FromClause) + a.apply(&n, "WhereClause", nil, n.WhereClause) + a.apply(&n, "GroupClause", nil, n.GroupClause) + a.apply(&n, "HavingClause", nil, n.HavingClause) + a.apply(&n, "WindowClause", nil, n.WindowClause) + // TODO: Not sure how to handle a slice of a slice + // + // for _, vs := range n.ValuesLists { + // for _, v := range vs { + // a.apply(&n, "", nil, v) + // } + // } + a.apply(&n, "SortClause", nil, n.SortClause) + a.apply(&n, "LimitOffset", nil, n.LimitOffset) + a.apply(&n, "LimitCount", nil, n.LimitCount) + a.apply(&n, "LockingClause", nil, n.LockingClause) + if n.WithClause != nil { + a.apply(&n, "WithClause", nil, *n.WithClause) + } + if n.Larg != nil { + a.apply(&n, "Larg", nil, *n.Larg) + } + if n.Rarg != nil { + a.apply(&n, "Rarg", nil, *n.Rarg) + } + a.cursor.set(n, &n) + + case nodes.SetOperationStmt: + a.apply(&n, "Larg", nil, n.Larg) + a.apply(&n, "Rarg", nil, n.Rarg) + a.apply(&n, "ColTypes", nil, n.ColTypes) + a.apply(&n, "ColTypmods", nil, n.ColTypmods) + a.apply(&n, "ColCollations", nil, n.ColCollations) + a.apply(&n, "GroupClauses", nil, n.GroupClauses) + a.cursor.set(n, &n) + + case nodes.SetToDefault: + a.apply(&n, "Xpr", nil, n.Xpr) + a.cursor.set(n, &n) + + case nodes.SortBy: + a.apply(&n, "Node", nil, n.Node) + a.apply(&n, "UseOp", nil, n.UseOp) + a.cursor.set(n, &n) + + case nodes.SortGroupClause: + // pass + + case nodes.String: + // pass + + case nodes.SubLink: + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "Testexpr", nil, n.Testexpr) + a.apply(&n, "Opername", nil, n.OperName) + a.apply(&n, "Subselect", nil, n.Subselect) + a.cursor.set(n, &n) + + case nodes.SubPlan: + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "Testexpr", nil, n.Testexpr) + a.apply(&n, "ParamIds", nil, n.ParamIds) + a.apply(&n, "SetParam", nil, n.SetParam) + a.apply(&n, "ParParam", nil, n.ParParam) + a.apply(&n, "Args", nil, n.Args) + a.cursor.set(n, &n) + + case nodes.TableFunc: + a.apply(&n, "NsUris", nil, n.NsUris) + a.apply(&n, "NsNames", nil, n.NsNames) + a.apply(&n, "Docexpr", nil, n.Docexpr) + a.apply(&n, "Rowexpr", nil, n.Rowexpr) + a.apply(&n, "Colnames", nil, n.Colnames) + a.apply(&n, "Coltypes", nil, n.Coltypes) + a.apply(&n, "ColTypmods", nil, n.Coltypmods) + a.apply(&n, "Colcollations", nil, n.Colcollations) + a.apply(&n, "Colexprs", nil, n.Colexprs) + a.apply(&n, "Coldefexprs", nil, n.Coldefexprs) + a.cursor.set(n, &n) + + case nodes.TableLikeClause: + if n.Relation != nil { + a.apply(&n, "Relation", nil, *n.Relation) + a.cursor.set(n, &n) + } + + case nodes.TableSampleClause: + a.apply(&n, "Args", nil, n.Args) + a.apply(&n, "Repeatable", nil, n.Repeatable) + a.cursor.set(n, &n) + + case nodes.TargetEntry: + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "Expr", nil, n.Expr) + a.cursor.set(n, &n) + + case nodes.TransactionStmt: + a.apply(&n, "Options", nil, n.Options) + a.cursor.set(n, &n) + + case nodes.TriggerTransition: + // pass + + case nodes.TruncateStmt: + a.apply(&n, "Relations", nil, n.Relations) + a.cursor.set(n, &n) + + case nodes.TypeCast: + a.apply(&n, "Arg", nil, n.Arg) + if n.TypeName != nil { + a.apply(&n, "TypeName", nil, *n.TypeName) + } + a.cursor.set(n, &n) + + case nodes.TypeName: + a.apply(&n, "Names", nil, n.Names) + a.apply(&n, "Typmods", nil, n.Typmods) + a.apply(&n, "ArrayBounds", nil, n.ArrayBounds) + a.cursor.set(n, &n) + + case nodes.UnlistenStmt: + // pass + + case nodes.UpdateStmt: + if n.Relation != nil { + a.apply(&n, "Relation", nil, *n.Relation) + } + a.apply(&n, "TargetList", nil, n.TargetList) + a.apply(&n, "WhereClause", nil, n.WhereClause) + a.apply(&n, "FromClause", nil, n.FromClause) + a.apply(&n, "ReturningList", nil, n.ReturningList) + if n.WithClause != nil { + a.apply(&n, "WithClause", nil, *n.WithClause) + } + a.cursor.set(n, &n) + + case nodes.VacuumStmt: + if n.Relation != nil { + a.apply(&n, "Relation", nil, *n.Relation) + } + a.apply(&n, "VaCols", nil, n.VaCols) + a.cursor.set(n, &n) + + case nodes.Var: + a.apply(&n, "Xpr", nil, n.Xpr) + a.cursor.set(n, &n) + + case nodes.VariableSetStmt: + a.apply(&n, "Args", nil, n.Args) + a.cursor.set(n, &n) + + case nodes.VariableShowStmt: + // pass + + case nodes.ViewStmt: + if n.View != nil { + a.apply(&n, "View", nil, *n.View) + } + a.apply(&n, "Aliases", nil, n.Aliases) + a.apply(&n, "Query", nil, n.Query) + a.apply(&n, "Options", nil, n.Options) + a.cursor.set(n, &n) + + case nodes.WindowClause: + a.apply(&n, "PartitionClause", nil, n.PartitionClause) + a.apply(&n, "OrderClause", nil, n.OrderClause) + a.apply(&n, "StartOffset", nil, n.StartOffset) + a.apply(&n, "EndOffset", nil, n.EndOffset) + a.cursor.set(n, &n) + + case nodes.WindowDef: + a.apply(&n, "PartitionClause", nil, n.PartitionClause) + a.apply(&n, "OrderClause", nil, n.OrderClause) + a.apply(&n, "StartOffset", nil, n.StartOffset) + a.apply(&n, "EndOffset", nil, n.EndOffset) + a.cursor.set(n, &n) + + case nodes.WindowFunc: + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "Args", nil, n.Args) + a.apply(&n, "Aggfilter", nil, n.Aggfilter) + a.cursor.set(n, &n) + + case nodes.WithCheckOption: + a.apply(&n, "Qual", nil, n.Qual) + a.cursor.set(n, &n) + + case nodes.WithClause: + a.apply(&n, "Ctes", nil, n.Ctes) + a.cursor.set(n, &n) + + case nodes.XmlExpr: + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "NamedArgs", nil, n.NamedArgs) + a.apply(&n, "ArgNames", nil, n.ArgNames) + a.apply(&n, "Args", nil, n.Args) + a.cursor.set(n, &n) + + case nodes.XmlSerialize: + a.apply(&n, "Expr", nil, n.Expr) + if n.TypeName != nil { + a.apply(&n, "TypeName", nil, *n.TypeName) + } + a.cursor.set(n, &n) + + default: + panic(fmt.Sprintf("Apply: unexpected node type %T", n)) + } + + if a.post != nil && !a.post(&a.cursor) { + panic(abort) + } + + a.cursor = saved +} + +// An iterator controls iteration over a slice of nodes. +type iterator struct { + index, step int +} + +func (a *application) applyList(parent nodes.Node, name string) { + // avoid heap-allocating a new iterator for each applyList call; reuse a.iter instead + saved := a.iter + a.iter.index = 0 + for { + // must reload parent.name each time, since cursor modifications might change it + v := reflect.Indirect(reflect.ValueOf(parent)).FieldByName(name) + if a.iter.index >= v.Len() { + break + } + + // element x may be nil in a bad AST - be cautious + var x nodes.Node + if e := v.Index(a.iter.index); e.IsValid() { + x = e.Interface().(nodes.Node) + } + + a.iter.step = 1 + a.apply(parent, name, &a.iter, x) + a.iter.index += a.iter.step + } + a.iter = saved +} diff --git a/internal/postgresql/ast/astutil_test.go b/internal/postgresql/ast/astutil_test.go new file mode 100644 index 0000000000..2f6f4273e3 --- /dev/null +++ b/internal/postgresql/ast/astutil_test.go @@ -0,0 +1,41 @@ +package ast + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + pg "github.com/lfittl/pg_query_go" + nodes "github.com/lfittl/pg_query_go/nodes" +) + +func TestApply(t *testing.T) { + input, err := pg.Parse("SELECT sqlc.arg(name)") + if err != nil { + t.Fatal(err) + } + output, err := pg.Parse("SELECT $1") + if err != nil { + t.Fatal(err) + } + + expect := output.Statements[0] + actual := Apply(input.Statements[0], func(cr *Cursor) bool { + fun, ok := cr.Node().(nodes.FuncCall) + if !ok { + return true + } + if Join(fun.Funcname, ".") == "sqlc.arg" { + cr.Replace(nodes.ParamRef{ + Number: 1, + Location: fun.Location, + }) + return false + } + + return true + }, nil) + + if diff := cmp.Diff(expect, actual); diff != "" { + t.Errorf("rewrite mismatch:\n%s", diff) + } +} diff --git a/internal/postgresql/ast/join.go b/internal/postgresql/ast/join.go new file mode 100644 index 0000000000..343b58129c --- /dev/null +++ b/internal/postgresql/ast/join.go @@ -0,0 +1,17 @@ +package ast + +import ( + "strings" + + nodes "github.com/lfittl/pg_query_go/nodes" +) + +func Join(list nodes.List, sep string) string { + items := []string{} + for _, item := range list.Items { + if n, ok := item.(nodes.String); ok { + items = append(items, n.Str) + } + } + return strings.Join(items, sep) +} diff --git a/internal/dinosql/soup.go b/internal/postgresql/ast/soup.go similarity index 99% rename from internal/dinosql/soup.go rename to internal/postgresql/ast/soup.go index 17786526eb..a662e6f9a1 100644 --- a/internal/dinosql/soup.go +++ b/internal/postgresql/ast/soup.go @@ -1,4 +1,4 @@ -package dinosql +package ast import ( "fmt"