88 "fmt"
99 "regexp"
1010 "sort"
11+ "strconv"
1112 "strings"
1213 "text/template"
1314
@@ -32,25 +33,24 @@ type Enum struct {
3233}
3334
3435type Field struct {
36+ ID int
3537 Name string
3638 Type ktType
3739 Comment string
3840}
3941
4042type Struct struct {
41- Table plugin.Identifier
42- Name string
43- Fields []Field
44- JDBCParamBindings []Field
45- Comment string
43+ Table plugin.Identifier
44+ Name string
45+ Fields []Field
46+ Comment string
4647}
4748
4849type QueryValue struct {
49- Emit bool
50- Name string
51- Struct * Struct
52- Typ ktType
53- JDBCParamBindCount int
50+ Emit bool
51+ Name string
52+ Struct * Struct
53+ Typ ktType
5454}
5555
5656func (v QueryValue ) EmitStruct () bool {
@@ -102,7 +102,8 @@ func jdbcSet(t ktType, idx int, name string) string {
102102}
103103
104104type Params struct {
105- Struct * Struct
105+ Struct * Struct
106+ binding []int
106107}
107108
108109func (v Params ) isEmpty () bool {
@@ -114,9 +115,19 @@ func (v Params) Args() string {
114115 return ""
115116 }
116117 var out []string
117- for _ , f := range v .Struct .Fields {
118+ fields := v .Struct .Fields
119+ for _ , f := range fields {
118120 out = append (out , f .Name + ": " + f .Type .String ())
119121 }
122+ if len (v .binding ) > 0 {
123+ lookup := map [int ]int {}
124+ for i , v := range v .binding {
125+ lookup [v ] = i
126+ }
127+ sort .Slice (out , func (i , j int ) bool {
128+ return lookup [fields [i ].ID ] < lookup [fields [j ].ID ]
129+ })
130+ }
120131 if len (out ) < 3 {
121132 return strings .Join (out , ", " )
122133 }
@@ -128,8 +139,15 @@ func (v Params) Bindings() string {
128139 return ""
129140 }
130141 var out []string
131- for i , f := range v .Struct .JDBCParamBindings {
132- out = append (out , jdbcSet (f .Type , i + 1 , f .Name ))
142+ if len (v .binding ) > 0 {
143+ for i , idx := range v .binding {
144+ f := v .Struct .Fields [idx - 1 ]
145+ out = append (out , jdbcSet (f .Type , i + 1 , f .Name ))
146+ }
147+ } else {
148+ for i , f := range v .Struct .Fields {
149+ out = append (out , jdbcSet (f .Type , i + 1 , f .Name ))
150+ }
133151 }
134152 return indent (strings .Join (out , "\n " ), 10 , 0 )
135153}
@@ -387,20 +405,19 @@ func ktColumnsToStruct(req *plugin.CodeGenRequest, name string, columns []goColu
387405 idSeen := map [int ]Field {}
388406 nameSeen := map [string ]int {}
389407 for _ , c := range columns {
390- if binding , ok := idSeen [c .id ]; ok {
391- gs .JDBCParamBindings = append (gs .JDBCParamBindings , binding )
408+ if _ , ok := idSeen [c .id ]; ok {
392409 continue
393410 }
394411 fieldName := memberName (namer (c .Column , c .id ), req .Settings )
395412 if v := nameSeen [c .Name ]; v > 0 {
396413 fieldName = fmt .Sprintf ("%s_%d" , fieldName , v + 1 )
397414 }
398415 field := Field {
416+ ID : c .id ,
399417 Name : fieldName ,
400418 Type : makeType (req , c .Column ),
401419 }
402420 gs .Fields = append (gs .Fields , field )
403- gs .JDBCParamBindings = append (gs .JDBCParamBindings , field )
404421 nameSeen [c .Name ]++
405422 idSeen [c .id ] = field
406423 }
@@ -438,11 +455,31 @@ var postgresPlaceholderRegexp = regexp.MustCompile(`\B\$\d+\b`)
438455// HACK: jdbc doesn't support numbered parameters, so we need to transform them to question marks...
439456// But there's no access to the SQL parser here, so we just do a dumb regexp replace instead. This won't work if
440457// the literal strings contain matching values, but good enough for a prototype.
441- func jdbcSQL (s , engine string ) string {
442- if engine = = "postgresql" {
443- return postgresPlaceholderRegexp . ReplaceAllString ( s , "?" )
458+ func jdbcSQL (s , engine string ) ( string , [] string ) {
459+ if engine ! = "postgresql" {
460+ return s , nil
444461 }
445- return s
462+ var args []string
463+ q := postgresPlaceholderRegexp .ReplaceAllStringFunc (s , func (placeholder string ) string {
464+ args = append (args , placeholder )
465+ return "?"
466+ })
467+ return q , args
468+ }
469+
470+ func parseInts (s []string ) ([]int , error ) {
471+ if len (s ) == 0 {
472+ return nil , nil
473+ }
474+ var refs []int
475+ for _ , v := range s {
476+ i , err := strconv .Atoi (strings .TrimPrefix (v , "$" ))
477+ if err != nil {
478+ return nil , err
479+ }
480+ refs = append (refs , i )
481+ }
482+ return refs , nil
446483}
447484
448485func buildQueries (req * plugin.CodeGenRequest , structs []Struct ) ([]Query , error ) {
@@ -458,14 +495,19 @@ func buildQueries(req *plugin.CodeGenRequest, structs []Struct) ([]Query, error)
458495 return nil , errors .New ("Support for CopyFrom in Kotlin is not implemented" )
459496 }
460497
498+ ql , args := jdbcSQL (query .Text , req .Settings .Engine )
499+ refs , err := parseInts (args )
500+ if err != nil {
501+ return nil , fmt .Errorf ("Invalid parameter reference: %w" , err )
502+ }
461503 gq := Query {
462504 Cmd : query .Cmd ,
463505 ClassName : strings .Title (query .Name ),
464506 ConstantName : sdk .LowerTitle (query .Name ),
465507 FieldName : sdk .LowerTitle (query .Name ) + "Stmt" ,
466508 MethodName : sdk .LowerTitle (query .Name ),
467509 SourceName : query .Filename ,
468- SQL : jdbcSQL ( query . Text , req . Settings . Engine ) ,
510+ SQL : ql ,
469511 Comments : query .Comments ,
470512 }
471513
@@ -478,7 +520,8 @@ func buildQueries(req *plugin.CodeGenRequest, structs []Struct) ([]Query, error)
478520 }
479521 params := ktColumnsToStruct (req , gq .ClassName + "Bindings" , cols , ktParamName )
480522 gq .Arg = Params {
481- Struct : params ,
523+ Struct : params ,
524+ binding : refs ,
482525 }
483526
484527 if len (query .Columns ) == 1 {
0 commit comments