Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,8 @@ WHERE book_id = ?
data class UpdateBookISBNParams (
val title: String,
val tags: List<String>,
val bookId: Int,
val isbn: String
val isbn: String,
val bookId: Int
)

class QueriesImpl(private val conn: Connection) {
Expand Down Expand Up @@ -282,8 +282,8 @@ class QueriesImpl(private val conn: Connection) {
val stmt = conn.prepareStatement(updateBookISBN)
stmt.setString(1, arg.title)
stmt.setArray(2, conn.createArrayOf("pg_catalog.varchar", arg.tags.toTypedArray()))
stmt.setInt(3, arg.bookId)
stmt.setString(4, arg.isbn)
stmt.setString(3, arg.isbn)
stmt.setInt(4, arg.bookId)

stmt.execute()
stmt.close()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ WHERE slug = ?
"""

data class UpdateCityNameParams (
val slug: String,
val name: String
val name: String,
val slug: String
)

const val updateVenueName = """-- name: updateVenueName :one
Expand All @@ -108,8 +108,8 @@ RETURNING id
"""

data class UpdateVenueNameParams (
val slug: String,
val name: String
val name: String,
val slug: String
)

const val venueCountByCity = """-- name: venueCountByCity :many
Expand Down Expand Up @@ -180,6 +180,7 @@ class QueriesImpl(private val conn: Connection) : Queries {
override fun deleteVenue(slug: String) {
val stmt = conn.prepareStatement(deleteVenue)
stmt.setString(1, slug)
stmt.setString(2, slug)

stmt.execute()
stmt.close()
Expand Down Expand Up @@ -278,8 +279,8 @@ class QueriesImpl(private val conn: Connection) : Queries {
@Throws(SQLException::class)
override fun updateCityName(arg: UpdateCityNameParams) {
val stmt = conn.prepareStatement(updateCityName)
stmt.setString(1, arg.slug)
stmt.setString(2, arg.name)
stmt.setString(1, arg.name)
stmt.setString(2, arg.slug)

stmt.execute()
stmt.close()
Expand All @@ -288,8 +289,8 @@ class QueriesImpl(private val conn: Connection) : Queries {
@Throws(SQLException::class)
override fun updateVenueName(arg: UpdateVenueNameParams): Int {
val stmt = conn.prepareStatement(updateVenueName)
stmt.setString(1, arg.slug)
stmt.setString(2, arg.name)
stmt.setString(1, arg.name)
stmt.setString(2, arg.slug)

return stmt.executeQuery().use { results ->
if (!results.next()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,14 @@ class QueriesImplTest {

// ISBN update fails because parameters are not in sequential order. After changing $N to ?, ordering is lost,
// and the parameters are filled into the wrong slots.
// db.updateBookISBN(
// UpdateBookISBNParams(
// bookId = b3.bookId,
// isbn = "NEW ISBN",
// title = "never ever gonna finish, a quatrain",
// tags = listOf("someother")
// )
// )
db.updateBookISBN(
UpdateBookISBNParams(
bookId = b3.bookId,
isbn = "NEW ISBN",
title = "never ever gonna finish, a quatrain",
tags = listOf("someother")
)
)

val books0 = db.booksByTitleYear(BooksByTitleYearParams("my book title", 2016))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,10 @@ class QueriesImplTest {
assertEquals(listOf(city), q.listCities())
assertEquals(listOf(venue), q.listVenues(city.slug))

// These updates fail because parameters are not in sequential order. After changing $N to ?, ordering is lost,
// and the parameters are filled into the wrong slots.
// q.updateCityName(UpdateCityNameParams(slug = city.slug, name = "SF"))
// q.updateVenueName(UpdateVenueNameParams(slug = venue.slug, name = "Fillmore"))
q.updateCityName(UpdateCityNameParams(slug = city.slug, name = "SF"))
val id = q.updateVenueName(UpdateVenueNameParams(slug = venue.slug, name = "Fillmore"))
assertEquals(venue.id, id)

q.deleteVenue(venue.slug)
}
}
4 changes: 4 additions & 0 deletions internal/dinosql/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ type PackageSettings struct {
EmitJSONTags bool `json:"emit_json_tags"`
EmitPreparedQueries bool `json:"emit_prepared_queries"`
Overrides []Override `json:"overrides"`
// HACK: this is only set in tests, only here till Kotlin support can be merged.
rewriteParams bool
}

type Override struct {
Expand Down Expand Up @@ -206,6 +208,8 @@ func ParseConfig(rd io.Reader) (GenerateSettings, error) {
}
if config.Packages[j].Language == "" {
config.Packages[j].Language = LanguageGo
} else if config.Packages[j].Language == "kotlin" {
config.Packages[j].rewriteParams = true
}
}
return config, nil
Expand Down
43 changes: 34 additions & 9 deletions internal/dinosql/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -719,28 +719,43 @@ func (r Result) goInnerType(col core.Column, settings CombinedSettings) string {
}
}

type goColumn struct {
id int
core.Column
}

// It's possible that this method will generate duplicate JSON tag values
//
// Columns: count, count, count_2
// Fields: Count, Count_2, Count2
// JSON tags: count, count_2, count_2
//
// This is unlikely to happen, so don't fix it yet
func (r Result) columnsToStruct(name string, columns []core.Column, settings CombinedSettings) *GoStruct {
func (r Result) columnsToStruct(name string, columns []goColumn, settings CombinedSettings) *GoStruct {
gs := GoStruct{
Name: name,
}
seen := map[string]int{}
suffixes := map[int]int{}
for i, c := range columns {
tagName := c.Name
fieldName := StructName(columnName(c, i), settings)
if v := seen[c.Name]; v > 0 {
tagName = fmt.Sprintf("%s_%d", tagName, v+1)
fieldName = fmt.Sprintf("%s_%d", fieldName, v+1)
fieldName := StructName(columnName(c.Column, i), settings)
// Track suffixes by the ID of the column, so that columns referring to the same numbered parameter can be
// reused.
suffix := 0
if o, ok := suffixes[c.id]; ok {
suffix = o
} else if v := seen[c.Name]; v > 0 {
suffix = v+1
}
suffixes[c.id] = suffix
if suffix > 0 {
tagName = fmt.Sprintf("%s_%d", tagName, suffix)
fieldName = fmt.Sprintf("%s_%d", fieldName, suffix)
}
gs.Fields = append(gs.Fields, GoField{
Name: fieldName,
Type: r.goType(c, settings),
Type: r.goType(c.Column, settings),
Tags: map[string]string{"json:": tagName},
})
seen[c.Name]++
Expand Down Expand Up @@ -815,9 +830,12 @@ func (r Result) GoQueries(settings CombinedSettings) []GoQuery {
Typ: r.goType(p.Column, settings),
}
} else if len(query.Params) > 1 {
var cols []core.Column
var cols []goColumn
for _, p := range query.Params {
cols = append(cols, p.Column)
cols = append(cols, goColumn{
id: p.Number,
Column: p.Column,
})
}
gq.Arg = GoQueryValue{
Emit: true,
Expand Down Expand Up @@ -858,7 +876,14 @@ func (r Result) GoQueries(settings CombinedSettings) []GoQuery {
}

if gs == nil {
gs = r.columnsToStruct(gq.MethodName+"Row", query.Columns, settings)
var columns []goColumn
for i, c := range query.Columns {
columns = append(columns, goColumn{
id: i,
Column: c,
})
}
gs = r.columnsToStruct(gq.MethodName+"Row", columns, settings)
emit = true
}
gq.Ret = GoQueryValue{
Expand Down
89 changes: 61 additions & 28 deletions internal/dinosql/kotlin/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,17 +37,19 @@ type KtField struct {
}

type KtStruct struct {
Table core.FQN
Name string
Fields []KtField
Comment string
Table core.FQN
Name string
Fields []KtField
JDBCParamBindings []KtField
Comment string
}

type KtQueryValue struct {
Emit bool
Name string
Struct *KtStruct
Typ ktType
Emit bool
Name string
Struct *KtStruct
Typ ktType
JDBCParamBindCount int
}

func (v KtQueryValue) EmitStruct() bool {
Expand Down Expand Up @@ -101,9 +103,15 @@ func (v KtQueryValue) Params() string {
}
var out []string
if v.Struct == nil {
out = append(out, jdbcSet(v.Typ, 1, v.Name))
repeat := 1
if v.JDBCParamBindCount > 0 {
repeat = v.JDBCParamBindCount
}
for i := 1; i <= repeat; i++ {
out = append(out, jdbcSet(v.Typ, i, v.Name))
}
} else {
for i, f := range v.Struct.Fields {
for i, f := range v.Struct.JDBCParamBindings {
out = append(out, jdbcSet(f.Type, i+1, v.Name+"."+f.Name))
}
}
Expand Down Expand Up @@ -595,21 +603,34 @@ func (r Result) ktInnerType(col core.Column, settings dinosql.CombinedSettings)
}
}

func (r Result) ktColumnsToStruct(name string, columns []core.Column, settings dinosql.CombinedSettings) *KtStruct {
type goColumn struct {
id int
core.Column
}

func (r Result) ktColumnsToStruct(name string, columns []goColumn, settings dinosql.CombinedSettings) *KtStruct {
gs := KtStruct{
Name: name,
}
seen := map[string]int{}
idSeen := map[int]KtField{}
nameSeen := map[string]int{}
for i, c := range columns {
fieldName := KtMemberName(ktColumnName(c, i), settings)
if v := seen[c.Name]; v > 0 {
if binding, ok := idSeen[c.id]; ok {
gs.JDBCParamBindings = append(gs.JDBCParamBindings, binding)
continue
}
fieldName := KtMemberName(ktColumnName(c.Column, i), settings)
if v := nameSeen[c.Name]; v > 0 {
fieldName = fmt.Sprintf("%s_%d", fieldName, v+1)
}
gs.Fields = append(gs.Fields, KtField{
field := KtField{
Name: fieldName,
Type: r.ktType(c, settings),
})
seen[c.Name]++
Type: r.ktType(c.Column, settings),
}
gs.Fields = append(gs.Fields, field)
gs.JDBCParamBindings = append(gs.JDBCParamBindings, field)
nameSeen[c.Name]++
idSeen[c.id] = field
}
return &gs
}
Expand Down Expand Up @@ -672,21 +693,26 @@ func (r Result) KtQueries(settings dinosql.CombinedSettings) []KtQuery {
Comments: query.Comments,
}

if len(query.Params) == 1 {
var cols []goColumn
for _, p := range query.Params {
cols = append(cols, goColumn{
id: p.Number,
Column: p.Column,
})
}
params := r.ktColumnsToStruct(gq.ClassName+"Params", cols, settings)
if len(params.Fields) == 1 {
p := query.Params[0]
gq.Arg = KtQueryValue{
Name: ktParamName(p),
Typ: r.ktType(p.Column, settings),
}
} else if len(query.Params) > 1 {
var cols []core.Column
for _, p := range query.Params {
cols = append(cols, p.Column)
Name: ktParamName(p),
Typ: r.ktType(p.Column, settings),
JDBCParamBindCount: len(params.JDBCParamBindings),
}
} else if len(params.Fields) > 1 {
gq.Arg = KtQueryValue{
Emit: true,
Name: "arg",
Struct: r.ktColumnsToStruct(gq.ClassName+"Params", cols, settings),
Struct: params,
}
}

Expand Down Expand Up @@ -722,7 +748,14 @@ func (r Result) KtQueries(settings dinosql.CombinedSettings) []KtQuery {
}

if gs == nil {
gs = r.ktColumnsToStruct(gq.ClassName+"Row", query.Columns, settings)
var columns []goColumn
for i, c := range query.Columns {
columns = append(columns, goColumn{
id: i,
Column: c,
})
}
gs = r.ktColumnsToStruct(gq.ClassName+"Row", columns, settings)
emit = true
}
gq.Ret = KtQueryValue{
Expand Down
Loading