@@ -11,11 +11,13 @@ import (
11
11
"github.com/Masterminds/squirrel"
12
12
"github.com/codegangsta/cli"
13
13
14
+ _ "github.com/go-sql-driver/mysql"
14
15
_ "github.com/lib/pq"
15
16
)
16
17
17
18
const version = "DEV"
18
19
20
+ // Usage : exported const Usage
19
21
const Usage = `Read a schema and generate Structable structs.
20
22
21
23
This utility generates Structable structs be reading your database table and
@@ -31,6 +33,7 @@ import (
31
33
32
34
"github.com/Masterminds/squirrel"
33
35
"github.com/Masterminds/structable"
36
+ _ "github.com/go-sql-driver/mysql"
34
37
_ "github.com/lib/pq"
35
38
)
36
39
@@ -252,7 +255,7 @@ func importTables(c *cli.Context) {
252
255
}
253
256
254
257
for _ , t := range tables {
255
- f , err := importTable (t , bldr )
258
+ f , err := importTable (t , bldr , driver ( c ) )
256
259
if err != nil {
257
260
fmt .Fprintf (os .Stderr , "Failed to import table %s: %s" , t , err )
258
261
}
@@ -288,7 +291,7 @@ func publicTables(b squirrel.StatementBuilderType) ([]string, error) {
288
291
// importTable reads a table definition and writes a corresponding struct.
289
292
// SELECT table_name, column_name, data_type, character_maximum_length
290
293
// FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = 'goose_db_version'
291
- func importTable (tbl string , b squirrel.StatementBuilderType ) (* structDesc , error ) {
294
+ func importTable (tbl string , b squirrel.StatementBuilderType , driver string ) (* structDesc , error ) {
292
295
293
296
pks , err := primaryKeyField (tbl , b )
294
297
if err != nil {
@@ -313,7 +316,12 @@ func importTable(tbl string, b squirrel.StatementBuilderType) (*structDesc, erro
313
316
return nil , err
314
317
}
315
318
c .Max = length .Int64
316
- ff = append (ff , structField (c , pks , tbl , b ))
319
+ switch driver {
320
+ case "mysql" :
321
+ ff = append (ff , structFieldMySQL (c , pks , tbl , b ))
322
+ case "postgres" :
323
+ ff = append (ff , structField (c , pks , tbl , b ))
324
+ }
317
325
}
318
326
sd := & structDesc {
319
327
StructName : goName (tbl ),
@@ -345,8 +353,18 @@ func primaryKeyField(tbl string, b squirrel.StatementBuilderType) ([]string, err
345
353
return res , nil
346
354
}
347
355
348
- func sequentialKey (tbl , pk string , b squirrel.StatementBuilderType ) bool {
356
+ func autoincrementKey (tbl , pk string , b squirrel.StatementBuilderType ) bool {
357
+ q := b .Select ("COUNT(*)" ).
358
+ From ("INFORMATION_SCHEMA.COLUMNS" ).
359
+ Where ("TABLE_NAME = ? AND COLUMN_NAME = ? AND EXTRA = 'auto_increment'" , tbl , pk )
360
+ var num int
361
+ if err := q .Scan (& num ); err != nil {
362
+ panic (err )
363
+ }
364
+ return num > 0
365
+ }
349
366
367
+ func sequentialKey (tbl , pk string , b squirrel.StatementBuilderType ) bool {
350
368
tlen := 58
351
369
352
370
stbl := tbl
@@ -372,6 +390,24 @@ func sequentialKey(tbl, pk string, b squirrel.StatementBuilderType) bool {
372
390
return num > 0
373
391
}
374
392
393
+ func structFieldMySQL (c * column , pks []string , tbl string , b squirrel.StatementBuilderType ) string {
394
+ tpl := "%s %s `stbl:\" %s\" `"
395
+ gn := destutter (goName (c .Name ), goName (tbl ))
396
+ tt := goType (c .DataType )
397
+
398
+ tag := c .Name
399
+ for _ , p := range pks {
400
+ if c .Name == p {
401
+ tag += ",PRIMARY_KEY"
402
+ if autoincrementKey (tbl , c .Name , b ) {
403
+ tag += ",AUTO_INCREMENT"
404
+ }
405
+ }
406
+ }
407
+
408
+ return fmt .Sprintf (tpl , gn , tt , tag )
409
+ }
410
+
375
411
func structField (c * column , pks []string , tbl string , b squirrel.StatementBuilderType ) string {
376
412
tpl := "%s %s `stbl:\" %s\" `"
377
413
gn := destutter (goName (c .Name ), goName (tbl ))
0 commit comments