Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

All not needed an alias when there is a subquery with the same column name #2639

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 37 additions & 5 deletions internal/compiler/find_params.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
func findParameters(root ast.Node) ([]paramRef, error) {
refs := make([]paramRef, 0)
errors := make([]error, 0)
v := paramSearch{seen: make(map[int]struct{}), refs: &refs, errs: &errors}
v := paramSearch{seen: make(map[int]struct{}), refs: &refs, errs: &errors, rvs: &[]*ast.RangeVar{}}
astutils.Walk(v, root)
if len(*v.errs) > 0 {
problems := *v.errs
Expand All @@ -22,6 +22,7 @@ func findParameters(root ast.Node) ([]paramRef, error) {

type paramRef struct {
parent ast.Node
rvs []*ast.RangeVar
rv *ast.RangeVar
ref *ast.ParamRef
name string // Named parameter support
Expand All @@ -31,6 +32,7 @@ type paramSearch struct {
parent ast.Node
rangeVar *ast.RangeVar
refs *[]paramRef
rvs *[]*ast.RangeVar
seen map[int]struct{}
errs *[]error

Expand Down Expand Up @@ -58,6 +60,7 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor {
return p
}

var reset bool
switch n := node.(type) {

case *ast.A_Expr:
Expand All @@ -70,6 +73,7 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor {
p.parent = n.FuncCall

case *ast.DeleteStmt:
reset = true
if n.LimitCount != nil {
p.limitCount = n.LimitCount
}
Expand All @@ -78,7 +82,13 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor {
p.parent = node

case *ast.InsertStmt:
reset = true
rvs := *p.rvs
if n.Relation != nil {
rvs = append(rvs, n.Relation)
}
if s, ok := n.SelectStmt.(*ast.SelectStmt); ok {
rvs = append(rvs, toTables(s.FromClause)...)
for i, item := range s.TargetList.Items {
target, ok := item.(*ast.ResTarget)
if !ok {
Expand All @@ -92,7 +102,7 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor {
*p.errs = append(*p.errs, fmt.Errorf("INSERT has more expressions than target columns"))
return p
}
*p.refs = append(*p.refs, paramRef{parent: n.Cols.Items[i], ref: ref, rv: n.Relation})
*p.refs = append(*p.refs, paramRef{parent: n.Cols.Items[i], ref: ref, rv: n.Relation, rvs: rvs})
p.seen[ref.Location] = struct{}{}
}
for _, item := range s.ValuesLists.Items {
Expand All @@ -109,13 +119,16 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor {
*p.errs = append(*p.errs, fmt.Errorf("INSERT has more expressions than target columns"))
return p
}
*p.refs = append(*p.refs, paramRef{parent: n.Cols.Items[i], ref: ref, rv: n.Relation})
*p.refs = append(*p.refs, paramRef{parent: n.Cols.Items[i], ref: ref, rv: n.Relation, rvs: rvs})
p.seen[ref.Location] = struct{}{}
}
}
}

case *ast.UpdateStmt:
reset = true
rvs := append(*p.rvs, toTables(n.FromClause)...)
rvs = append(rvs, toTables(n.Relations)...)
for _, item := range n.TargetList.Items {
target, ok := item.(*ast.ResTarget)
if !ok {
Expand All @@ -130,7 +143,7 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor {
if !ok {
continue
}
*p.refs = append(*p.refs, paramRef{parent: target, ref: ref, rv: rv})
*p.refs = append(*p.refs, paramRef{parent: target, ref: ref, rv: rv, rvs: rvs})
}
p.seen[ref.Location] = struct{}{}
}
Expand All @@ -139,12 +152,16 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor {
}

case *ast.RangeVar:
if n != nil {
*p.rvs = append(*p.rvs, n)
}
p.rangeVar = n

case *ast.ResTarget:
p.parent = node

case *ast.SelectStmt:
reset = true
if n.LimitCount != nil {
p.limitCount = n.LimitCount
}
Expand Down Expand Up @@ -191,7 +208,7 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor {
}

if set {
*p.refs = append(*p.refs, paramRef{parent: parent, ref: n, rv: p.rangeVar})
*p.refs = append(*p.refs, paramRef{parent: parent, ref: n, rv: p.rangeVar, rvs: *p.rvs})
p.seen[n.Location] = struct{}{}
}
return nil
Expand All @@ -215,5 +232,20 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor {
p.Visit(n.Expr)
}
}
if reset {
rvs := *p.rvs
return paramSearch{seen: p.seen, refs: p.refs, errs: p.errs, rvs: &rvs, parent: p.parent, rangeVar: p.rangeVar, limitCount: p.limitCount, limitOffset: p.limitOffset}
}
return p
}

func toTables(tbl *ast.List) []*ast.RangeVar {
tables := make([]*ast.RangeVar, len(tbl.Items))
for _, t := range tbl.Items {
item, ok := t.(*ast.RangeVar)
if ok && item != nil {
tables = append(tables, item)
}
}
return tables
}
6 changes: 5 additions & 1 deletion internal/compiler/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,11 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query,
return nil, err
}

params, err := c.resolveCatalogRefs(qc, rvs, refs, namedParams, embeds)
err = c.resolveCatalogEmbeds(qc, rvs, embeds)
if err != nil {
return nil, err
}
params, err := c.resolveCatalogRefs(qc, refs, namedParams)
if err != nil {
return nil, err
}
Expand Down
70 changes: 65 additions & 5 deletions internal/compiler/resolve.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func dataType(n *ast.TypeName) string {
}
}

func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, args []paramRef, params *named.ParamSet, embeds rewrite.EmbedSet) ([]Parameter, error) {
func (comp *Compiler) resolveCatalogEmbeds(qc *QueryCatalog, rvs []*ast.RangeVar, embeds rewrite.EmbedSet) error {
c := comp.catalog

aliasMap := map[string]*ast.TableName{}
Expand Down Expand Up @@ -55,7 +55,7 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar,
}
fqn, err := ParseTableName(rv)
if err != nil {
return nil, err
return err
}
if _, found := aliasMap[fqn.Name]; found {
continue
Expand All @@ -64,13 +64,13 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar,
if err != nil {
// If the table name doesn't exist, fisrt check if it's a CTE
if _, qcerr := qc.GetTable(fqn); qcerr != nil {
return nil, err
return err
}
continue
}
err = indexTable(table)
if err != nil {
return nil, err
return err
}
if rv.Alias != nil {
aliasMap[*rv.Alias.Aliasname] = fqn
Expand All @@ -90,11 +90,71 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar,
continue
}

return nil, fmt.Errorf("unable to resolve table with %q: %w", embed.Orig(), err)
return fmt.Errorf("unable to resolve table with %q: %w", embed.Orig(), err)
}
return nil
}

func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, args []paramRef, params *named.ParamSet) ([]Parameter, error) {
c := comp.catalog

// resolve a table for an embed
var a []Parameter
for _, ref := range args {
aliasMap := map[string]*ast.TableName{}
// TODO: Deprecate defaultTable
var defaultTable *ast.TableName
var tables []*ast.TableName

typeMap := map[string]map[string]map[string]*catalog.Column{}
indexTable := func(table catalog.Table) error {
tables = append(tables, table.Rel)
if defaultTable == nil {
defaultTable = table.Rel
}
schema := table.Rel.Schema
if schema == "" {
schema = c.DefaultSchema
}
if _, exists := typeMap[schema]; !exists {
typeMap[schema] = map[string]map[string]*catalog.Column{}
}
typeMap[schema][table.Rel.Name] = map[string]*catalog.Column{}
for _, c := range table.Columns {
cc := c
typeMap[schema][table.Rel.Name][c.Name] = cc
}
return nil
}

for _, rv := range ref.rvs {
if rv == nil || rv.Relname == nil {
continue
}
fqn, err := ParseTableName(rv)
if err != nil {
return nil, err
}
if _, found := aliasMap[fqn.Name]; found {
continue
}
table, err := c.GetTable(fqn)
if err != nil {
// If the table name doesn't exist, fisrt check if it's a CTE
if _, qcerr := qc.GetTable(fqn); qcerr != nil {
return nil, err
}
continue
}
err = indexTable(table)
if err != nil {
return nil, err
}
if rv.Alias != nil {
aliasMap[*rv.Alias.Aliasname] = fqn
}
}

switch n := ref.parent.(type) {

case *limitOffset:
Expand Down
31 changes: 31 additions & 0 deletions internal/endtoend/testdata/subquery_with_where/go/db.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 19 additions & 0 deletions internal/endtoend/testdata/subquery_with_where/go/models.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

53 changes: 53 additions & 0 deletions internal/endtoend/testdata/subquery_with_where/go/query.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 9 additions & 0 deletions internal/endtoend/testdata/subquery_with_where/query.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
CREATE TABLE foo (a int not null, name text);
CREATE TABLE bar (a int not null, alias text);

-- name: Subquery :many
SELECT
a,
name,
(SELECT alias FROM bar WHERE bar.a=foo.a AND alias = $1 ORDER BY bar.a DESC limit 1) as alias
FROM FOO WHERE a = $2;
12 changes: 12 additions & 0 deletions internal/endtoend/testdata/subquery_with_where/sqlc.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
{
"version": "1",
"packages": [
{
"path": "go",
"engine": "postgresql",
"name": "querytest",
"schema": "query.sql",
"queries": "query.sql"
}
]
}
Loading