diff --git a/examples/kotlin/src/main/kotlin/com/example/booktest/postgresql/QueriesImpl.kt b/examples/kotlin/src/main/kotlin/com/example/booktest/postgresql/QueriesImpl.kt index e81fc242e1..281ba0dd66 100644 --- a/examples/kotlin/src/main/kotlin/com/example/booktest/postgresql/QueriesImpl.kt +++ b/examples/kotlin/src/main/kotlin/com/example/booktest/postgresql/QueriesImpl.kt @@ -109,8 +109,8 @@ WHERE book_id = ? data class UpdateBookISBNParams ( val title: String, val tags: List, - val bookId: Int, - val isbn: String + val isbn: String, + val bookId: Int ) class QueriesImpl(private val conn: Connection) { @@ -282,8 +282,8 @@ class QueriesImpl(private val conn: Connection) { val stmt = conn.prepareStatement(updateBookISBN) stmt.setString(1, arg.title) stmt.setArray(2, conn.createArrayOf("pg_catalog.varchar", arg.tags.toTypedArray())) - stmt.setInt(3, arg.bookId) - stmt.setString(4, arg.isbn) + stmt.setString(3, arg.isbn) + stmt.setInt(4, arg.bookId) stmt.execute() stmt.close() diff --git a/examples/kotlin/src/main/kotlin/com/example/ondeck/QueriesImpl.kt b/examples/kotlin/src/main/kotlin/com/example/ondeck/QueriesImpl.kt index 60777d19c9..7b7911af70 100644 --- a/examples/kotlin/src/main/kotlin/com/example/ondeck/QueriesImpl.kt +++ b/examples/kotlin/src/main/kotlin/com/example/ondeck/QueriesImpl.kt @@ -96,8 +96,8 @@ WHERE slug = ? """ data class UpdateCityNameParams ( - val slug: String, - val name: String + val name: String, + val slug: String ) const val updateVenueName = """-- name: updateVenueName :one @@ -108,8 +108,8 @@ RETURNING id """ data class UpdateVenueNameParams ( - val slug: String, - val name: String + val name: String, + val slug: String ) const val venueCountByCity = """-- name: venueCountByCity :many @@ -180,6 +180,7 @@ class QueriesImpl(private val conn: Connection) : Queries { override fun deleteVenue(slug: String) { val stmt = conn.prepareStatement(deleteVenue) stmt.setString(1, slug) + stmt.setString(2, slug) stmt.execute() stmt.close() @@ -278,8 +279,8 @@ class QueriesImpl(private val conn: Connection) : Queries { @Throws(SQLException::class) override fun updateCityName(arg: UpdateCityNameParams) { val stmt = conn.prepareStatement(updateCityName) - stmt.setString(1, arg.slug) - stmt.setString(2, arg.name) + stmt.setString(1, arg.name) + stmt.setString(2, arg.slug) stmt.execute() stmt.close() @@ -288,8 +289,8 @@ class QueriesImpl(private val conn: Connection) : Queries { @Throws(SQLException::class) override fun updateVenueName(arg: UpdateVenueNameParams): Int { val stmt = conn.prepareStatement(updateVenueName) - stmt.setString(1, arg.slug) - stmt.setString(2, arg.name) + stmt.setString(1, arg.name) + stmt.setString(2, arg.slug) return stmt.executeQuery().use { results -> if (!results.next()) { diff --git a/examples/kotlin/src/test/kotlin/com/example/booktest/postgresql/QueriesImplTest.kt b/examples/kotlin/src/test/kotlin/com/example/booktest/postgresql/QueriesImplTest.kt index fb0e48e3d4..f7c54daf7b 100644 --- a/examples/kotlin/src/test/kotlin/com/example/booktest/postgresql/QueriesImplTest.kt +++ b/examples/kotlin/src/test/kotlin/com/example/booktest/postgresql/QueriesImplTest.kt @@ -81,14 +81,14 @@ class QueriesImplTest { // ISBN update fails because parameters are not in sequential order. After changing $N to ?, ordering is lost, // and the parameters are filled into the wrong slots. -// db.updateBookISBN( -// UpdateBookISBNParams( -// bookId = b3.bookId, -// isbn = "NEW ISBN", -// title = "never ever gonna finish, a quatrain", -// tags = listOf("someother") -// ) -// ) + db.updateBookISBN( + UpdateBookISBNParams( + bookId = b3.bookId, + isbn = "NEW ISBN", + title = "never ever gonna finish, a quatrain", + tags = listOf("someother") + ) + ) val books0 = db.booksByTitleYear(BooksByTitleYearParams("my book title", 2016)) diff --git a/examples/kotlin/src/test/kotlin/com/example/ondeck/QueriesImplTest.kt b/examples/kotlin/src/test/kotlin/com/example/ondeck/QueriesImplTest.kt index 8a2ff68dce..fdd1933862 100644 --- a/examples/kotlin/src/test/kotlin/com/example/ondeck/QueriesImplTest.kt +++ b/examples/kotlin/src/test/kotlin/com/example/ondeck/QueriesImplTest.kt @@ -43,9 +43,10 @@ class QueriesImplTest { assertEquals(listOf(city), q.listCities()) assertEquals(listOf(venue), q.listVenues(city.slug)) - // These updates fail because parameters are not in sequential order. After changing $N to ?, ordering is lost, - // and the parameters are filled into the wrong slots. -// q.updateCityName(UpdateCityNameParams(slug = city.slug, name = "SF")) -// q.updateVenueName(UpdateVenueNameParams(slug = venue.slug, name = "Fillmore")) + q.updateCityName(UpdateCityNameParams(slug = city.slug, name = "SF")) + val id = q.updateVenueName(UpdateVenueNameParams(slug = venue.slug, name = "Fillmore")) + assertEquals(venue.id, id) + + q.deleteVenue(venue.slug) } } \ No newline at end of file diff --git a/internal/dinosql/config.go b/internal/dinosql/config.go index b51cd0a6b1..8ae8c6b76f 100644 --- a/internal/dinosql/config.go +++ b/internal/dinosql/config.go @@ -59,6 +59,8 @@ type PackageSettings struct { EmitJSONTags bool `json:"emit_json_tags"` EmitPreparedQueries bool `json:"emit_prepared_queries"` Overrides []Override `json:"overrides"` + // HACK: this is only set in tests, only here till Kotlin support can be merged. + rewriteParams bool } type Override struct { @@ -206,6 +208,8 @@ func ParseConfig(rd io.Reader) (GenerateSettings, error) { } if config.Packages[j].Language == "" { config.Packages[j].Language = LanguageGo + } else if config.Packages[j].Language == "kotlin" { + config.Packages[j].rewriteParams = true } } return config, nil diff --git a/internal/dinosql/gen.go b/internal/dinosql/gen.go index 325eed5dbe..4a20999022 100644 --- a/internal/dinosql/gen.go +++ b/internal/dinosql/gen.go @@ -719,6 +719,11 @@ func (r Result) goInnerType(col core.Column, settings CombinedSettings) string { } } +type goColumn struct { + id int + core.Column +} + // It's possible that this method will generate duplicate JSON tag values // // Columns: count, count, count_2 @@ -726,21 +731,31 @@ func (r Result) goInnerType(col core.Column, settings CombinedSettings) string { // JSON tags: count, count_2, count_2 // // This is unlikely to happen, so don't fix it yet -func (r Result) columnsToStruct(name string, columns []core.Column, settings CombinedSettings) *GoStruct { +func (r Result) columnsToStruct(name string, columns []goColumn, settings CombinedSettings) *GoStruct { gs := GoStruct{ Name: name, } seen := map[string]int{} + suffixes := map[int]int{} for i, c := range columns { tagName := c.Name - fieldName := StructName(columnName(c, i), settings) - if v := seen[c.Name]; v > 0 { - tagName = fmt.Sprintf("%s_%d", tagName, v+1) - fieldName = fmt.Sprintf("%s_%d", fieldName, v+1) + fieldName := StructName(columnName(c.Column, i), settings) + // Track suffixes by the ID of the column, so that columns referring to the same numbered parameter can be + // reused. + suffix := 0 + if o, ok := suffixes[c.id]; ok { + suffix = o + } else if v := seen[c.Name]; v > 0 { + suffix = v+1 + } + suffixes[c.id] = suffix + if suffix > 0 { + tagName = fmt.Sprintf("%s_%d", tagName, suffix) + fieldName = fmt.Sprintf("%s_%d", fieldName, suffix) } gs.Fields = append(gs.Fields, GoField{ Name: fieldName, - Type: r.goType(c, settings), + Type: r.goType(c.Column, settings), Tags: map[string]string{"json:": tagName}, }) seen[c.Name]++ @@ -815,9 +830,12 @@ func (r Result) GoQueries(settings CombinedSettings) []GoQuery { Typ: r.goType(p.Column, settings), } } else if len(query.Params) > 1 { - var cols []core.Column + var cols []goColumn for _, p := range query.Params { - cols = append(cols, p.Column) + cols = append(cols, goColumn{ + id: p.Number, + Column: p.Column, + }) } gq.Arg = GoQueryValue{ Emit: true, @@ -858,7 +876,14 @@ func (r Result) GoQueries(settings CombinedSettings) []GoQuery { } if gs == nil { - gs = r.columnsToStruct(gq.MethodName+"Row", query.Columns, settings) + var columns []goColumn + for i, c := range query.Columns { + columns = append(columns, goColumn{ + id: i, + Column: c, + }) + } + gs = r.columnsToStruct(gq.MethodName+"Row", columns, settings) emit = true } gq.Ret = GoQueryValue{ diff --git a/internal/dinosql/kotlin/gen.go b/internal/dinosql/kotlin/gen.go index a6029ecd44..6107622983 100644 --- a/internal/dinosql/kotlin/gen.go +++ b/internal/dinosql/kotlin/gen.go @@ -37,17 +37,19 @@ type KtField struct { } type KtStruct struct { - Table core.FQN - Name string - Fields []KtField - Comment string + Table core.FQN + Name string + Fields []KtField + JDBCParamBindings []KtField + Comment string } type KtQueryValue struct { - Emit bool - Name string - Struct *KtStruct - Typ ktType + Emit bool + Name string + Struct *KtStruct + Typ ktType + JDBCParamBindCount int } func (v KtQueryValue) EmitStruct() bool { @@ -101,9 +103,15 @@ func (v KtQueryValue) Params() string { } var out []string if v.Struct == nil { - out = append(out, jdbcSet(v.Typ, 1, v.Name)) + repeat := 1 + if v.JDBCParamBindCount > 0 { + repeat = v.JDBCParamBindCount + } + for i := 1; i <= repeat; i++ { + out = append(out, jdbcSet(v.Typ, i, v.Name)) + } } else { - for i, f := range v.Struct.Fields { + for i, f := range v.Struct.JDBCParamBindings { out = append(out, jdbcSet(f.Type, i+1, v.Name+"."+f.Name)) } } @@ -595,21 +603,34 @@ func (r Result) ktInnerType(col core.Column, settings dinosql.CombinedSettings) } } -func (r Result) ktColumnsToStruct(name string, columns []core.Column, settings dinosql.CombinedSettings) *KtStruct { +type goColumn struct { + id int + core.Column +} + +func (r Result) ktColumnsToStruct(name string, columns []goColumn, settings dinosql.CombinedSettings) *KtStruct { gs := KtStruct{ Name: name, } - seen := map[string]int{} + idSeen := map[int]KtField{} + nameSeen := map[string]int{} for i, c := range columns { - fieldName := KtMemberName(ktColumnName(c, i), settings) - if v := seen[c.Name]; v > 0 { + if binding, ok := idSeen[c.id]; ok { + gs.JDBCParamBindings = append(gs.JDBCParamBindings, binding) + continue + } + fieldName := KtMemberName(ktColumnName(c.Column, i), settings) + if v := nameSeen[c.Name]; v > 0 { fieldName = fmt.Sprintf("%s_%d", fieldName, v+1) } - gs.Fields = append(gs.Fields, KtField{ + field := KtField{ Name: fieldName, - Type: r.ktType(c, settings), - }) - seen[c.Name]++ + Type: r.ktType(c.Column, settings), + } + gs.Fields = append(gs.Fields, field) + gs.JDBCParamBindings = append(gs.JDBCParamBindings, field) + nameSeen[c.Name]++ + idSeen[c.id] = field } return &gs } @@ -672,21 +693,26 @@ func (r Result) KtQueries(settings dinosql.CombinedSettings) []KtQuery { Comments: query.Comments, } - if len(query.Params) == 1 { + var cols []goColumn + for _, p := range query.Params { + cols = append(cols, goColumn{ + id: p.Number, + Column: p.Column, + }) + } + params := r.ktColumnsToStruct(gq.ClassName+"Params", cols, settings) + if len(params.Fields) == 1 { p := query.Params[0] gq.Arg = KtQueryValue{ - Name: ktParamName(p), - Typ: r.ktType(p.Column, settings), - } - } else if len(query.Params) > 1 { - var cols []core.Column - for _, p := range query.Params { - cols = append(cols, p.Column) + Name: ktParamName(p), + Typ: r.ktType(p.Column, settings), + JDBCParamBindCount: len(params.JDBCParamBindings), } + } else if len(params.Fields) > 1 { gq.Arg = KtQueryValue{ Emit: true, Name: "arg", - Struct: r.ktColumnsToStruct(gq.ClassName+"Params", cols, settings), + Struct: params, } } @@ -722,7 +748,14 @@ func (r Result) KtQueries(settings dinosql.CombinedSettings) []KtQuery { } if gs == nil { - gs = r.ktColumnsToStruct(gq.ClassName+"Row", query.Columns, settings) + var columns []goColumn + for i, c := range query.Columns { + columns = append(columns, goColumn{ + id: i, + Column: c, + }) + } + gs = r.ktColumnsToStruct(gq.ClassName+"Row", columns, settings) emit = true } gq.Ret = KtQueryValue{ diff --git a/internal/dinosql/parser.go b/internal/dinosql/parser.go index 8154f1a5c6..1106fb4e38 100644 --- a/internal/dinosql/parser.go +++ b/internal/dinosql/parser.go @@ -229,7 +229,8 @@ func ParseQueries(c core.Catalog, pkg PackageSettings) (*Result, error) { continue } for _, stmt := range tree.Statements { - query, err := parseQuery(c, stmt, source) + rewriteParameters := pkg.rewriteParams + query, err := parseQuery(c, stmt, source, rewriteParameters) if err == errUnsupportedStatementType { continue } @@ -407,7 +408,7 @@ func validateCmd(n nodes.Node, name, cmd string) error { var errUnsupportedStatementType = errors.New("parseQuery: unsupported statement type") -func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error) { +func parseQuery(c core.Catalog, stmt nodes.Node, source string, rewriteParameters bool) (*Query, error) { if err := validateParamRef(stmt); err != nil { return nil, err } @@ -443,6 +444,16 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error) } rvs := rangeVars(raw.Stmt) refs := findParameters(raw.Stmt) + var edits []edit + if rewriteParameters { + edits, err = rewriteNumberedParameters(refs, raw, rawSQL) + if err != nil { + return nil, err + } + } else { + refs = uniqueParamRefs(refs) + sort.Slice(refs, func(i, j int) bool { return refs[i].ref.Number < refs[j].ref.Number }) + } params, err := resolveCatalogRefs(c, rvs, refs) if err != nil { return nil, err @@ -452,7 +463,13 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error) if err != nil { return nil, err } - expanded, err := expand(c, raw, rawSQL) + expandEdits, err := expand(c, raw, rawSQL) + if err != nil { + return nil, err + } + edits = append(edits, expandEdits...) + + expanded, err := editQuery(rawSQL, edits) if err != nil { return nil, err } @@ -472,6 +489,18 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error) }, nil } +func rewriteNumberedParameters(refs []paramRef, raw nodes.RawStmt, sql string) ([]edit, error) { + edits := make([]edit, len(refs)) + for i, ref := range refs { + edits[i] = edit{ + Location: ref.ref.Location - raw.StmtLocation, + Old: fmt.Sprintf("$%d", ref.ref.Number), + New: "?", + } + } + return edits, nil +} + func stripComments(sql string) (string, []string, error) { s := bufio.NewScanner(strings.NewReader(sql)) var lines, comments []string @@ -494,7 +523,7 @@ type edit struct { New string } -func expand(c core.Catalog, raw nodes.RawStmt, sql string) (string, error) { +func expand(c core.Catalog, raw nodes.RawStmt, sql string) ([]edit, error) { list := search(raw, func(node nodes.Node) bool { switch node.(type) { case nodes.DeleteStmt: @@ -507,17 +536,17 @@ func expand(c core.Catalog, raw nodes.RawStmt, sql string) (string, error) { return true }) if len(list.Items) == 0 { - return sql, nil + return nil, nil } var edits []edit for _, item := range list.Items { edit, err := expandStmt(c, raw, item) if err != nil { - return "", err + return nil, err } edits = append(edits, edit...) } - return editQuery(sql, edits) + return edits, nil } func expandStmt(c core.Catalog, raw nodes.RawStmt, node nodes.Node) ([]edit, error) { @@ -958,7 +987,8 @@ type paramRef struct { type paramSearch struct { parent nodes.Node rangeVar *nodes.RangeVar - refs map[int]paramRef + refs *[]paramRef + seen map[int]struct{} // XXX: Gross state hack for limit limitCount nodes.Node @@ -1005,7 +1035,8 @@ func (p paramSearch) Visit(node nodes.Node) Visitor { continue } // TODO: Out-of-bounds panic - p.refs[ref.Number] = paramRef{parent: n.Cols.Items[i], ref: ref, rv: p.rangeVar} + *p.refs = append(*p.refs, paramRef{parent: n.Cols.Items[i], ref: ref, rv: p.rangeVar}) + p.seen[ref.Location] = struct{}{} } for _, vl := range s.ValuesLists { for i, v := range vl { @@ -1014,7 +1045,8 @@ func (p paramSearch) Visit(node nodes.Node) Visitor { continue } // TODO: Out-of-bounds panic - p.refs[ref.Number] = paramRef{parent: n.Cols.Items[i], ref: ref, rv: p.rangeVar} + *p.refs = append(*p.refs, paramRef{parent: n.Cols.Items[i], ref: ref, rv: p.rangeVar}) + p.seen[ref.Location] = struct{}{} } } } @@ -1050,7 +1082,7 @@ func (p paramSearch) Visit(node nodes.Node) Visitor { parent = limitOffset{} } } - if _, found := p.refs[n.Number]; found { + if _, found := p.seen[n.Location]; found { break } @@ -1072,7 +1104,8 @@ func (p paramSearch) Visit(node nodes.Node) Visitor { } if set { - p.refs[n.Number] = paramRef{parent: parent, ref: n, rv: p.rangeVar} + *p.refs = append(*p.refs, paramRef{parent: parent, ref: n, rv: p.rangeVar}) + p.seen[n.Location] = struct{}{} } return nil } @@ -1080,13 +1113,9 @@ func (p paramSearch) Visit(node nodes.Node) Visitor { } func findParameters(root nodes.Node) []paramRef { - v := paramSearch{refs: map[int]paramRef{}} - Walk(v, root) refs := make([]paramRef, 0) - for _, r := range v.refs { - refs = append(refs, r) - } - sort.Slice(refs, func(i, j int) bool { return refs[i].ref.Number < refs[j].ref.Number }) + v := paramSearch{seen: make(map[int]struct{}), refs: &refs} + Walk(v, root) return refs } @@ -1348,3 +1377,15 @@ func resolveCatalogRefs(c core.Catalog, rvs []nodes.RangeVar, args []paramRef) ( } return a, nil } + +func uniqueParamRefs(in []paramRef) []paramRef { + m := make(map[int]struct{}, len(in)) + o := make([]paramRef, 0, len(in)) + for _, v := range in { + if _, ok := m[v.ref.Number]; !ok { + m[v.ref.Number] = struct{}{} + o = append(o, v) + } + } + return o +} diff --git a/internal/dinosql/parser_test.go b/internal/dinosql/parser_test.go index e53c59955c..d1741cc6be 100644 --- a/internal/dinosql/parser_test.go +++ b/internal/dinosql/parser_test.go @@ -3,6 +3,7 @@ package dinosql import ( "testing" + "github.com/google/go-cmp/cmp" pg "github.com/lfittl/pg_query_go" nodes "github.com/lfittl/pg_query_go/nodes" ) @@ -87,13 +88,14 @@ func TestLineColumn(t *testing.T) { func TestExtractArgs(t *testing.T) { queries := []struct { - query string - count int + query string + bindNumbers []int }{ - {"SELECT * FROM venue WHERE slug = $1 AND city = $2", 2}, - {"SELECT * FROM venue WHERE slug = $1", 1}, - {"SELECT * FROM venue LIMIT $1", 1}, - {"SELECT * FROM venue OFFSET $1", 1}, + {"SELECT * FROM venue WHERE slug = $1 AND city = $2", []int{1, 2}}, + {"SELECT * FROM venue WHERE slug = $1 AND region = $2 AND city = $3 AND country = $2", []int{1, 2, 3, 2}}, + {"SELECT * FROM venue WHERE slug = $1", []int{1}}, + {"SELECT * FROM venue LIMIT $1", []int{1}}, + {"SELECT * FROM venue OFFSET $1", []int{1}}, } for _, q := range queries { tree, err := pg.Parse(q.query) @@ -105,8 +107,46 @@ func TestExtractArgs(t *testing.T) { if err != nil { t.Error(err) } - if len(refs) != q.count { - t.Errorf("expected %d refs, got %d", q.count, len(refs)) + nums := make([]int, len(refs)) + for i, n := range refs { + nums[i] = n.ref.Number + } + if diff := cmp.Diff(q.bindNumbers, nums); diff != "" { + t.Errorf("expected bindings %v, got %v", q.bindNumbers, nums) + } + } + } +} + +func TestRewriteParameters(t *testing.T) { + queries := []struct { + orig string + new string + }{ + {"SELECT * FROM venue WHERE slug = $1 AND city = $3 AND bar = $2", "SELECT * FROM venue WHERE slug = ? AND city = ? AND bar = ?"}, + {"DELETE FROM venue WHERE slug = $1 AND slug = $1", "DELETE FROM venue WHERE slug = ? AND slug = ?"}, + {"SELECT * FROM venue LIMIT $1", "SELECT * FROM venue LIMIT ?"}, + } + for _, q := range queries { + tree, err := pg.Parse(q.orig) + if err != nil { + t.Fatal(err) + } + for _, stmt := range tree.Statements { + refs := findParameters(stmt) + if err != nil { + t.Error(err) + } + edits, err := rewriteNumberedParameters(refs, stmt.(nodes.RawStmt), q.orig) + if err != nil { + t.Error(err) + } + rewritten, err := editQuery(q.orig, edits) + if err != nil { + t.Error(err) + } + if rewritten != q.new { + t.Errorf("expected %q, got %q", q.new, rewritten) } } } diff --git a/internal/dinosql/query_test.go b/internal/dinosql/query_test.go index c80b1d85e6..8ee13ac451 100644 --- a/internal/dinosql/query_test.go +++ b/internal/dinosql/query_test.go @@ -21,7 +21,7 @@ func parseSQL(in string) (Query, error) { return Query{}, err } - q, err := parseQuery(c, tree.Statements[len(tree.Statements)-1], in) + q, err := parseQuery(c, tree.Statements[len(tree.Statements)-1], in, false) if q == nil { return Query{}, err }