Skip to content

Commit 8785fce

Browse files
committed
add support of mysql driver in schema2struct utility
1 parent 6e13572 commit 8785fce

File tree

1 file changed

+40
-4
lines changed

1 file changed

+40
-4
lines changed

schema2struct/schema2struct.go

+40-4
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,13 @@ import (
1111
"github.com/Masterminds/squirrel"
1212
"github.com/codegangsta/cli"
1313

14+
_ "github.com/go-sql-driver/mysql"
1415
_ "github.com/lib/pq"
1516
)
1617

1718
const version = "DEV"
1819

20+
// Usage : exported const Usage
1921
const Usage = `Read a schema and generate Structable structs.
2022
2123
This utility generates Structable structs be reading your database table and
@@ -31,6 +33,7 @@ import (
3133
3234
"github.com/Masterminds/squirrel"
3335
"github.com/Masterminds/structable"
36+
_ "github.com/go-sql-driver/mysql"
3437
_ "github.com/lib/pq"
3538
)
3639
@@ -252,7 +255,7 @@ func importTables(c *cli.Context) {
252255
}
253256

254257
for _, t := range tables {
255-
f, err := importTable(t, bldr)
258+
f, err := importTable(t, bldr, driver(c))
256259
if err != nil {
257260
fmt.Fprintf(os.Stderr, "Failed to import table %s: %s", t, err)
258261
}
@@ -288,7 +291,7 @@ func publicTables(b squirrel.StatementBuilderType) ([]string, error) {
288291
// importTable reads a table definition and writes a corresponding struct.
289292
// SELECT table_name, column_name, data_type, character_maximum_length
290293
// FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = 'goose_db_version'
291-
func importTable(tbl string, b squirrel.StatementBuilderType) (*structDesc, error) {
294+
func importTable(tbl string, b squirrel.StatementBuilderType, driver string) (*structDesc, error) {
292295

293296
pks, err := primaryKeyField(tbl, b)
294297
if err != nil {
@@ -313,7 +316,12 @@ func importTable(tbl string, b squirrel.StatementBuilderType) (*structDesc, erro
313316
return nil, err
314317
}
315318
c.Max = length.Int64
316-
ff = append(ff, structField(c, pks, tbl, b))
319+
switch driver {
320+
case "mysql":
321+
ff = append(ff, structFieldMySQL(c, pks, tbl, b))
322+
case "postgres":
323+
ff = append(ff, structField(c, pks, tbl, b))
324+
}
317325
}
318326
sd := &structDesc{
319327
StructName: goName(tbl),
@@ -345,8 +353,18 @@ func primaryKeyField(tbl string, b squirrel.StatementBuilderType) ([]string, err
345353
return res, nil
346354
}
347355

348-
func sequentialKey(tbl, pk string, b squirrel.StatementBuilderType) bool {
356+
func autoincrementKey(tbl, pk string, b squirrel.StatementBuilderType) bool {
357+
q := b.Select("COUNT(*)").
358+
From("INFORMATION_SCHEMA.COLUMNS").
359+
Where("TABLE_NAME = ? AND COLUMN_NAME = ? AND EXTRA = 'auto_increment'", tbl, pk)
360+
var num int
361+
if err := q.Scan(&num); err != nil {
362+
panic(err)
363+
}
364+
return num > 0
365+
}
349366

367+
func sequentialKey(tbl, pk string, b squirrel.StatementBuilderType) bool {
350368
tlen := 58
351369

352370
stbl := tbl
@@ -372,6 +390,24 @@ func sequentialKey(tbl, pk string, b squirrel.StatementBuilderType) bool {
372390
return num > 0
373391
}
374392

393+
func structFieldMySQL(c *column, pks []string, tbl string, b squirrel.StatementBuilderType) string {
394+
tpl := "%s %s `stbl:\"%s\"`"
395+
gn := destutter(goName(c.Name), goName(tbl))
396+
tt := goType(c.DataType)
397+
398+
tag := c.Name
399+
for _, p := range pks {
400+
if c.Name == p {
401+
tag += ",PRIMARY_KEY"
402+
if autoincrementKey(tbl, c.Name, b) {
403+
tag += ",AUTO_INCREMENT"
404+
}
405+
}
406+
}
407+
408+
return fmt.Sprintf(tpl, gn, tt, tag)
409+
}
410+
375411
func structField(c *column, pks []string, tbl string, b squirrel.StatementBuilderType) string {
376412
tpl := "%s %s `stbl:\"%s\"`"
377413
gn := destutter(goName(c.Name), goName(tbl))

0 commit comments

Comments
 (0)