diff --git a/README.md b/README.md index 6acd579b8..50a316a72 100644 --- a/README.md +++ b/README.md @@ -70,10 +70,11 @@ Table of Contents * [Upsert](#upsert) * [Reload](#reload) * [Exists](#exists) + * [Enums](#enums) * [FAQ](#faq) * [Won't compiling models for a huge database be very slow?](#wont-compiling-models-for-a-huge-database-be-very-slow) * [Missing imports for generated package](#missing-imports-for-generated-package) - * [Benchmarks](#benchmarks) + * [Benchmarks](#benchmarks) ## About SQL Boiler @@ -97,6 +98,7 @@ Table of Contents - Debug logging - Schemas support - 1d arrays, json, hstore & more +- Enum types ### Supported Databases @@ -121,8 +123,8 @@ if err != nil { return err } -// If you don't want to pass in db to all generated methods -// you can use boil.SetDB to set it globally, and then use +// If you don't want to pass in db to all generated methods +// you can use boil.SetDB to set it globally, and then use // the G variant methods like so: boil.SetDB(db) users, err := models.UsersG().All() @@ -178,7 +180,7 @@ fmt.Println(len(users.R.FavoriteMovies)) * Go 1.6 minimum, and Go 1.7 for compatibility tests. * Table names and column names should use `snake_case` format. - * We require `snake_case` table names and column names. This is a recommended default in Postgres, + * We require `snake_case` table names and column names. This is a recommended default in Postgres, and we agree that it's good form, so we're enforcing this format for all drivers for the time being. * Join tables should use a *composite primary key*. * For join tables to be used transparently for relationships your join table must have @@ -1048,6 +1050,43 @@ exists, err := jet.Pilot(db).Exists() exists, err := models.Pilots(db, Where("id=?", 5)).Exists() ``` +### Enums + +If your MySQL or Postgres tables use enums we will generate constants that hold their values +that you can use in your queries. For example: + +``` +CREATE TYPE workday AS ENUM('monday', 'tuesday', 'wednesday', 'thursday', 'friday'); + +CREATE TABLE event_one ( + id serial PRIMARY KEY NOT NULL, + name VARCHAR(255), + day workday NOT NULL +); +``` + +An enum type defined like the above, being used by a table, will generate the following enums: + +```go +const ( + WorkdayMonday = "monday" + WorkdayTuesday = "tuesday" + WorkdayWednesday = "wednesday" + WorkdayThursday = "thursday" + WorkdayFriday = "friday" +) +``` + +For Postgres we use `enum type name + title cased` value to generate the const variable name. +For MySQL we use `table name + column name + title cased value` to generate the const variable name. + +Note: If your enum holds a value we cannot parse correctly due, to non-alphabet characters for example, +it may not be generated. In this event, you will receive errors in your generated tests because +the value randomizer in the test suite does not know how to generate valid enum values. You will +still be able to use your generated library, and it will still work as expected, but the only way +to get the tests to pass in this event is to either use a parsable enum value or use a regular column +instead of an enum. + ## FAQ #### Won't compiling models for a huge database be very slow? diff --git a/bdb/column.go b/bdb/column.go index a760933bb..b5c408e1e 100644 --- a/bdb/column.go +++ b/bdb/column.go @@ -1,6 +1,10 @@ package bdb -import "github.com/vattle/sqlboiler/strmangle" +import ( + "strings" + + "github.com/vattle/sqlboiler/strmangle" +) // Column holds information about a database column. // Types are Go types, converted by TranslateColumnType. @@ -54,3 +58,16 @@ func FilterColumnsByDefault(defaults bool, columns []Column) []Column { return cols } + +// FilterColumnsByEnum generates the list of columns that are enum values. +func FilterColumnsByEnum(columns []Column) []Column { + var cols []Column + + for _, c := range columns { + if strings.HasPrefix(c.DBType, "enum") { + cols = append(cols, c) + } + } + + return cols +} diff --git a/bdb/column_test.go b/bdb/column_test.go index 1d144a673..0b3bb3ba0 100644 --- a/bdb/column_test.go +++ b/bdb/column_test.go @@ -66,3 +66,23 @@ func TestFilterColumnsByDefault(t *testing.T) { t.Errorf("Invalid result: %#v", res) } } + +func TestFilterColumnsByEnum(t *testing.T) { + t.Parallel() + + cols := []Column{ + {Name: "col1", DBType: "enum('hello')"}, + {Name: "col2", DBType: "enum('hello','there')"}, + {Name: "col3", DBType: "enum"}, + {Name: "col4", DBType: ""}, + {Name: "col5", DBType: "int"}, + } + + res := FilterColumnsByEnum(cols) + if res[0].Name != `col1` { + t.Errorf("Invalid result: %#v", res) + } + if res[1].Name != `col2` { + t.Errorf("Invalid result: %#v", res) + } +} diff --git a/bdb/drivers/mysql.go b/bdb/drivers/mysql.go index 38f5fb546..7f1cd4b3e 100644 --- a/bdb/drivers/mysql.go +++ b/bdb/drivers/mysql.go @@ -121,7 +121,11 @@ func (m *MySQLDriver) Columns(schema, tableName string) ([]bdb.Column, error) { var columns []bdb.Column rows, err := m.dbConn.Query(` - select column_name, data_type, if(extra = 'auto_increment','auto_increment', column_default), is_nullable, + select + c.column_name, + if(c.data_type = 'enum', c.column_type, c.data_type), + if(extra = 'auto_increment','auto_increment', c.column_default), + c.is_nullable = 'YES', exists ( select c.column_name from information_schema.table_constraints tc @@ -140,24 +144,23 @@ func (m *MySQLDriver) Columns(schema, tableName string) ([]bdb.Column, error) { defer rows.Close() for rows.Next() { - var colName, colType, colDefault, nullable string - var unique bool - var defaultPtr *string - if err := rows.Scan(&colName, &colType, &defaultPtr, &nullable, &unique); err != nil { + var colName, colType string + var nullable, unique bool + var defaultValue *string + if err := rows.Scan(&colName, &colType, &defaultValue, &nullable, &unique); err != nil { return nil, errors.Wrapf(err, "unable to scan for table %s", tableName) } - if defaultPtr != nil && *defaultPtr != "NULL" { - colDefault = *defaultPtr - } - column := bdb.Column{ Name: colName, DBType: colType, - Default: colDefault, - Nullable: nullable == "YES", + Nullable: nullable, Unique: unique, } + if defaultValue != nil && *defaultValue != "NULL" { + column.Default = *defaultValue + } + columns = append(columns, column) } diff --git a/bdb/drivers/postgres.go b/bdb/drivers/postgres.go index aaa31df1e..d14fb1b5b 100644 --- a/bdb/drivers/postgres.go +++ b/bdb/drivers/postgres.go @@ -6,6 +6,7 @@ import ( "strings" // Side-effect import sql driver + _ "github.com/lib/pq" "github.com/pkg/errors" "github.com/vattle/sqlboiler/bdb" @@ -123,25 +124,55 @@ func (p *PostgresDriver) Columns(schema, tableName string) ([]bdb.Column, error) var columns []bdb.Column rows, err := p.dbConn.Query(` - select column_name, c.data_type, e.data_type, column_default, c.udt_name, is_nullable, - (select exists( - select 1 + select + c.column_name, + ( + case when c.data_type = 'USER-DEFINED' and c.udt_name <> 'hstore' + then + ( + select 'enum.' || c.udt_name || '(''' || string_agg(labels.label, ''',''') || ''')' + from ( + select pg_enum.enumlabel as label + from pg_enum + where pg_enum.enumtypid = + ( + select typelem + from pg_type + where pg_type.typtype = 'b' and pg_type.typname = ('_' || c.udt_name) + limit 1 + ) + order by pg_enum.enumsortorder + ) as labels + ) + else c.data_type + end + ) as column_type, + + c.udt_name, + e.data_type as array_type, + c.column_default, + + c.is_nullable = 'YES' as is_nullable, + (select exists( + select 1 from information_schema.constraint_column_usage as ccu - inner join information_schema.table_constraints tc on ccu.constraint_name = tc.constraint_name - where ccu.table_name = c.table_name and ccu.column_name = c.column_name and tc.constraint_type = 'UNIQUE' + inner join information_schema.table_constraints tc on ccu.constraint_name = tc.constraint_name + where ccu.table_name = c.table_name and ccu.column_name = c.column_name and tc.constraint_type = 'UNIQUE' )) OR (select exists( - select 1 - from - pg_indexes pgix - inner join pg_class pgc on pgix.indexname = pgc.relname and pgc.relkind = 'i' - inner join pg_index pgi on pgi.indexrelid = pgc.oid - inner join pg_attribute pga on pga.attrelid = pgi.indrelid and pga.attnum = ANY(pgi.indkey) - where - pgix.schemaname = $1 and pgix.tablename = c.table_name and pga.attname = c.column_name and pgi.indisunique = true + select 1 + from + pg_indexes pgix + inner join pg_class pgc on pgix.indexname = pgc.relname and pgc.relkind = 'i' + inner join pg_index pgi on pgi.indexrelid = pgc.oid + inner join pg_attribute pga on pga.attrelid = pgi.indrelid and pga.attnum = ANY(pgi.indkey) + where + pgix.schemaname = $1 and pgix.tablename = c.table_name and pga.attname = c.column_name and pgi.indisunique = true )) as is_unique - from information_schema.columns as c LEFT JOIN information_schema.element_types e - ON ((c.table_catalog, c.table_schema, c.table_name, 'TABLE', c.dtd_identifier) - = (e.object_catalog, e.object_schema, e.object_name, e.object_type, e.collection_type_identifier)) + + from information_schema.columns as c + left join information_schema.element_types e + on ((c.table_catalog, c.table_schema, c.table_name, 'TABLE', c.dtd_identifier) + = (e.object_catalog, e.object_schema, e.object_name, e.object_type, e.collection_type_identifier)) where c.table_name=$2 and c.table_schema = $1; `, schema, tableName) @@ -151,29 +182,25 @@ func (p *PostgresDriver) Columns(schema, tableName string) ([]bdb.Column, error) defer rows.Close() for rows.Next() { - var colName, udtName, colType, colDefault, nullable string - var elementType *string - var unique bool - var defaultPtr *string - if err := rows.Scan(&colName, &colType, &elementType, &defaultPtr, &udtName, &nullable, &unique); err != nil { + var colName, colType, udtName string + var defaultValue, arrayType *string + var nullable, unique bool + if err := rows.Scan(&colName, &colType, &udtName, &arrayType, &defaultValue, &nullable, &unique); err != nil { return nil, errors.Wrapf(err, "unable to scan for table %s", tableName) } - if defaultPtr == nil { - colDefault = "" - } else { - colDefault = *defaultPtr - } - column := bdb.Column{ Name: colName, DBType: colType, - ArrType: elementType, + ArrType: arrayType, UDTName: udtName, - Default: colDefault, - Nullable: nullable == "YES", + Nullable: nullable, Unique: unique, } + if defaultValue != nil { + column.Default = *defaultValue + } + columns = append(columns, column) } @@ -290,6 +317,8 @@ func (p *PostgresDriver) TranslateColumnType(c bdb.Column) bdb.Column { c.Type = "null.Float32" case "bit", "interval", "bit varying", "character", "money", "character varying", "cidr", "inet", "macaddr", "text", "uuid", "xml": c.Type = "null.String" + case `"char"`: + c.Type = "null.Byte" case "bytea": c.Type = "null.Bytes" case "json", "jsonb": @@ -330,6 +359,8 @@ func (p *PostgresDriver) TranslateColumnType(c bdb.Column) bdb.Column { c.Type = "float32" case "bit", "interval", "uuint", "bit varying", "character", "money", "character varying", "cidr", "inet", "macaddr", "text", "uuid", "xml": c.Type = "string" + case `"char"`: + c.Type = "types.Byte" case "json", "jsonb": c.Type = "types.JSON" case "bytea": diff --git a/circle.yml b/circle.yml index 5d6727a83..ed75a12bd 100644 --- a/circle.yml +++ b/circle.yml @@ -1,8 +1,7 @@ test: pre: - mkdir -p /home/ubuntu/.go_workspace/src/github.com/jstemmer - - git clone git@github.com:nullbio/go-junit-report.git /home/ubuntu/.go_workspace/src/github.com/jstemmer/go-junit-report - - go install github.com/jstemmer/go-junit-report + - go get -u github.com/jstemmer/go-junit-report - echo -e "[postgres]\nhost=\"localhost\"\nport=5432\nuser=\"ubuntu\"\ndbname=\"sqlboiler\"\n[mysql]\nhost=\"localhost\"\nport=3306\nuser=\"ubuntu\"\ndbname=\"sqlboiler\"\nsslmode=\"false\"" > sqlboiler.toml - createdb -U ubuntu sqlboiler - psql -U ubuntu sqlboiler < ./testdata/postgres_test_schema.sql diff --git a/imports.go b/imports.go index 224e76338..880ba7760 100644 --- a/imports.go +++ b/imports.go @@ -274,55 +274,55 @@ var defaultTestMainImports = map[string]imports{ // TranslateColumnType to see the type assignments. var importsBasedOnType = map[string]imports{ "null.Float32": { - thirdParty: importList{`"gopkg.in/nullbio/null.v5"`}, + thirdParty: importList{`"gopkg.in/nullbio/null.v6"`}, }, "null.Float64": { - thirdParty: importList{`"gopkg.in/nullbio/null.v5"`}, + thirdParty: importList{`"gopkg.in/nullbio/null.v6"`}, }, "null.Int": { - thirdParty: importList{`"gopkg.in/nullbio/null.v5"`}, + thirdParty: importList{`"gopkg.in/nullbio/null.v6"`}, }, "null.Int8": { - thirdParty: importList{`"gopkg.in/nullbio/null.v5"`}, + thirdParty: importList{`"gopkg.in/nullbio/null.v6"`}, }, "null.Int16": { - thirdParty: importList{`"gopkg.in/nullbio/null.v5"`}, + thirdParty: importList{`"gopkg.in/nullbio/null.v6"`}, }, "null.Int32": { - thirdParty: importList{`"gopkg.in/nullbio/null.v5"`}, + thirdParty: importList{`"gopkg.in/nullbio/null.v6"`}, }, "null.Int64": { - thirdParty: importList{`"gopkg.in/nullbio/null.v5"`}, + thirdParty: importList{`"gopkg.in/nullbio/null.v6"`}, }, "null.Uint": { - thirdParty: importList{`"gopkg.in/nullbio/null.v5"`}, + thirdParty: importList{`"gopkg.in/nullbio/null.v6"`}, }, "null.Uint8": { - thirdParty: importList{`"gopkg.in/nullbio/null.v5"`}, + thirdParty: importList{`"gopkg.in/nullbio/null.v6"`}, }, "null.Uint16": { - thirdParty: importList{`"gopkg.in/nullbio/null.v5"`}, + thirdParty: importList{`"gopkg.in/nullbio/null.v6"`}, }, "null.Uint32": { - thirdParty: importList{`"gopkg.in/nullbio/null.v5"`}, + thirdParty: importList{`"gopkg.in/nullbio/null.v6"`}, }, "null.Uint64": { - thirdParty: importList{`"gopkg.in/nullbio/null.v5"`}, + thirdParty: importList{`"gopkg.in/nullbio/null.v6"`}, }, "null.String": { - thirdParty: importList{`"gopkg.in/nullbio/null.v5"`}, + thirdParty: importList{`"gopkg.in/nullbio/null.v6"`}, }, "null.Bool": { - thirdParty: importList{`"gopkg.in/nullbio/null.v5"`}, + thirdParty: importList{`"gopkg.in/nullbio/null.v6"`}, }, "null.Time": { - thirdParty: importList{`"gopkg.in/nullbio/null.v5"`}, + thirdParty: importList{`"gopkg.in/nullbio/null.v6"`}, }, "null.JSON": { - thirdParty: importList{`"gopkg.in/nullbio/null.v5"`}, + thirdParty: importList{`"gopkg.in/nullbio/null.v6"`}, }, "null.Bytes": { - thirdParty: importList{`"gopkg.in/nullbio/null.v5"`}, + thirdParty: importList{`"gopkg.in/nullbio/null.v6"`}, }, "time.Time": { standard: importList{`"time"`}, diff --git a/imports_test.go b/imports_test.go index b07fca79b..628d9558a 100644 --- a/imports_test.go +++ b/imports_test.go @@ -75,7 +75,7 @@ func TestCombineTypeImports(t *testing.T) { }, thirdParty: importList{ `"github.com/vattle/sqlboiler/boil"`, - `"gopkg.in/nullbio/null.v5"`, + `"gopkg.in/nullbio/null.v6"`, }, } @@ -108,7 +108,7 @@ func TestCombineTypeImports(t *testing.T) { }, thirdParty: importList{ `"github.com/vattle/sqlboiler/boil"`, - `"gopkg.in/nullbio/null.v5"`, + `"gopkg.in/nullbio/null.v6"`, }, } @@ -124,7 +124,7 @@ func TestCombineImports(t *testing.T) { a := imports{ standard: importList{"fmt"}, - thirdParty: importList{"github.com/vattle/sqlboiler", "gopkg.in/nullbio/null.v5"}, + thirdParty: importList{"github.com/vattle/sqlboiler", "gopkg.in/nullbio/null.v6"}, } b := imports{ standard: importList{"os"}, @@ -136,8 +136,8 @@ func TestCombineImports(t *testing.T) { if c.standard[0] != "fmt" && c.standard[1] != "os" { t.Errorf("Wanted: fmt, os got: %#v", c.standard) } - if c.thirdParty[0] != "github.com/vattle/sqlboiler" && c.thirdParty[1] != "gopkg.in/nullbio/null.v5" { - t.Errorf("Wanted: github.com/vattle/sqlboiler, gopkg.in/nullbio/null.v5 got: %#v", c.thirdParty) + if c.thirdParty[0] != "github.com/vattle/sqlboiler" && c.thirdParty[1] != "gopkg.in/nullbio/null.v6" { + t.Errorf("Wanted: github.com/vattle/sqlboiler, gopkg.in/nullbio/null.v6 got: %#v", c.thirdParty) } } diff --git a/main.go b/main.go index 77e4b5768..edee69ce9 100644 --- a/main.go +++ b/main.go @@ -125,6 +125,7 @@ func preRun(cmd *cobra.Command, args []string) error { OutFolder: viper.GetString("output"), Schema: viper.GetString("schema"), PkgName: viper.GetString("pkgname"), + BaseDir: viper.GetString("basedir"), Debug: viper.GetBool("debug"), NoTests: viper.GetBool("no-tests"), NoHooks: viper.GetBool("no-hooks"), diff --git a/queries/eager_load.go b/queries/eager_load.go index 95f337de1..6377409cc 100644 --- a/queries/eager_load.go +++ b/queries/eager_load.go @@ -71,7 +71,7 @@ func eagerLoad(exec boil.Executor, toLoad []string, obj interface{}, bkind bindK // - t is not considered here, and is always passed nil. The function exists on a loaded // struct to avoid a circular dependency with boil, and the receiver is ignored. // - exec is used to perform additional queries that might be required for loading the relationships. -// - singular is passed in to identify whether or not this was a single object +// - bkind is passed in to identify whether or not this was a single object // or a slice that must be loaded into. // - obj is the object or slice of objects, always of the type *obj or *[]*obj as per bind. // diff --git a/queries/helpers_test.go b/queries/helpers_test.go index c87bc7602..d37fcd9e5 100644 --- a/queries/helpers_test.go +++ b/queries/helpers_test.go @@ -5,7 +5,7 @@ import ( "testing" "time" - "gopkg.in/nullbio/null.v5" + null "gopkg.in/nullbio/null.v6" ) type testObj struct { diff --git a/randomize/randomize.go b/randomize/randomize.go index a3cb5b228..df5d75e2a 100644 --- a/randomize/randomize.go +++ b/randomize/randomize.go @@ -4,6 +4,7 @@ package randomize import ( "database/sql" "fmt" + "math/rand" "reflect" "regexp" "sort" @@ -12,7 +13,7 @@ import ( "sync/atomic" "time" - "gopkg.in/nullbio/null.v5" + null "gopkg.in/nullbio/null.v6" "github.com/pkg/errors" "github.com/satori/go.uuid" @@ -34,6 +35,7 @@ var ( typeNullUint32 = reflect.TypeOf(null.Uint32{}) typeNullUint64 = reflect.TypeOf(null.Uint64{}) typeNullString = reflect.TypeOf(null.String{}) + typeNullByte = reflect.TypeOf(null.Byte{}) typeNullBool = reflect.TypeOf(null.Bool{}) typeNullTime = reflect.TypeOf(null.Time{}) typeNullBytes = reflect.TypeOf(null.Bytes{}) @@ -156,9 +158,26 @@ func randDate(s *Seed) time.Time { // If canBeNull is true: // The value has the possibility of being null or non-zero at random. func randomizeField(s *Seed, field reflect.Value, fieldType string, canBeNull bool) error { + kind := field.Kind() typ := field.Type() + if strings.HasPrefix(fieldType, "enum") { + enum, err := randEnumValue(fieldType) + if err != nil { + return err + } + + if kind == reflect.Struct { + val := null.NewString(enum, rand.Intn(1) == 0) + field.Set(reflect.ValueOf(val)) + } else { + field.Set(reflect.ValueOf(enum)) + } + + return nil + } + var value interface{} var isNull bool @@ -341,7 +360,7 @@ func randomizeField(s *Seed, field reflect.Value, fieldType string, canBeNull bo // only get zero values for non byte slices // to stop mysql from being a jerk if isNull && kind != reflect.Slice { - value = getVariableZeroValue(s, kind) + value = getVariableZeroValue(s, kind, typ) } else { value = getVariableRandValue(s, kind, typ) } @@ -457,6 +476,8 @@ func getStructNullValue(s *Seed, typ reflect.Type) interface{} { return null.NewUint64(0, false) case typeNullBytes: return null.NewBytes(nil, false) + case typeNullByte: + return null.NewByte(byte(0), false) } return nil @@ -501,13 +522,21 @@ func getStructRandValue(s *Seed, typ reflect.Type) interface{} { return null.NewUint64(uint64(s.nextInt()), true) case typeNullBytes: return null.NewBytes(randByteSlice(s, 1), true) + case typeNullByte: + return null.NewByte(byte(rand.Intn(125-65)+65), true) } return nil } // getVariableZeroValue for the matching type. -func getVariableZeroValue(s *Seed, kind reflect.Kind) interface{} { +func getVariableZeroValue(s *Seed, kind reflect.Kind, typ reflect.Type) interface{} { + switch typ.String() { + case "types.Byte": + // Decimal 65 is 'A'. 0 is not a valid UTF8, so cannot use a zero value here. + return types.Byte(65) + } + switch kind { case reflect.Float32: return float32(0) @@ -548,6 +577,11 @@ func getVariableZeroValue(s *Seed, kind reflect.Kind) interface{} { // The randomness is really an incrementation of the global seed, // this is done to avoid duplicate key violations. func getVariableRandValue(s *Seed, kind reflect.Kind, typ reflect.Type) interface{} { + switch typ.String() { + case "types.Byte": + return types.Byte(rand.Intn(125-65) + 65) + } + switch kind { case reflect.Float32: return float32(float32(s.nextInt()%10)/10.0 + float32(s.nextInt()%10)) @@ -587,3 +621,12 @@ func getVariableRandValue(s *Seed, kind reflect.Kind, typ reflect.Type) interfac return nil } + +func randEnumValue(enum string) (string, error) { + vals := strmangle.ParseEnumVals(enum) + if vals == nil || len(vals) == 0 { + return "", fmt.Errorf("unable to parse enum string: %s", enum) + } + + return vals[rand.Intn(len(vals)-1)], nil +} diff --git a/randomize/randomize_test.go b/randomize/randomize_test.go index a3ba08304..6f117b767 100644 --- a/randomize/randomize_test.go +++ b/randomize/randomize_test.go @@ -5,7 +5,7 @@ import ( "testing" "time" - "gopkg.in/nullbio/null.v5" + null "gopkg.in/nullbio/null.v6" ) func TestRandomizeStruct(t *testing.T) { @@ -144,3 +144,28 @@ func TestRandomizeField(t *testing.T) { } } } + +func TestRandEnumValue(t *testing.T) { + t.Parallel() + + enum1 := "enum.workday('monday','tuesday')" + enum2 := "enum('monday','tuesday')" + + r1, err := randEnumValue(enum1) + if err != nil { + t.Error(err) + } + + if r1 != "monday" && r1 != "tuesday" { + t.Errorf("Expected monday or tueday, got: %q", r1) + } + + r2, err := randEnumValue(enum2) + if err != nil { + t.Error(err) + } + + if r2 != "monday" && r2 != "tuesday" { + t.Errorf("Expected monday or tueday, got: %q", r2) + } +} diff --git a/strmangle/strmangle.go b/strmangle/strmangle.go index ef032a6b0..ab93d265d 100644 --- a/strmangle/strmangle.go +++ b/strmangle/strmangle.go @@ -16,15 +16,31 @@ import ( var ( idAlphabet = []byte("abcdefghijklmnopqrstuvwxyz") smartQuoteRgx = regexp.MustCompile(`^(?i)"?[a-z_][_a-z0-9]*"?(\."?[_a-z][_a-z0-9]*"?)*(\.\*)?$`) + + rgxEnum = regexp.MustCompile(`^enum(\.[a-z_]+)?\((,?'[^']+')+\)$`) + rgxEnumIsOK = regexp.MustCompile(`^(?i)[a-z][a-z0-9_]*$`) + rgxEnumShouldTitle = regexp.MustCompile(`^[a-z][a-z0-9_]*$`) ) var uppercaseWords = map[string]struct{}{ - "guid": {}, - "id": {}, - "ip": {}, - "uid": {}, - "uuid": {}, - "json": {}, + "acl": {}, + "api": {}, + "ascii": {}, + "cpu": {}, + "eof": {}, + "guid": {}, + "id": {}, + "ip": {}, + "json": {}, + "ram": {}, + "sla": {}, + "udp": {}, + "ui": {}, + "uid": {}, + "uuid": {}, + "uri": {}, + "url": {}, + "utf8": {}, } func init() { @@ -364,7 +380,7 @@ func MakeStringMap(types map[string]string) string { c := 0 for _, k := range keys { v := types[k] - buf.WriteString(fmt.Sprintf(`"%s": "%s"`, k, v)) + buf.WriteString(fmt.Sprintf("`%s`: `%s`", k, v)) if c < len(types)-1 { buf.WriteString(", ") } @@ -562,3 +578,55 @@ func GenerateIgnoreTags(tags []string) string { return buf.String() } + +// ParseEnumVals returns the values from an enum string +// +// Postgres and MySQL drivers return different values +// psql: enum.enum_name('values'...) +// mysql: enum('values'...) +func ParseEnumVals(s string) []string { + if !rgxEnum.MatchString(s) { + return nil + } + + startIndex := strings.IndexByte(s, '(') + s = s[startIndex+2 : len(s)-2] + return strings.Split(s, "','") +} + +// ParseEnumName returns the name portion of an enum if it exists +// +// Postgres and MySQL drivers return different values +// psql: enum.enum_name('values'...) +// mysql: enum('values'...) +// In the case of mysql, the name will never return anything +func ParseEnumName(s string) string { + if !rgxEnum.MatchString(s) { + return "" + } + + endIndex := strings.IndexByte(s, '(') + s = s[:endIndex] + startIndex := strings.IndexByte(s, '.') + if startIndex < 0 { + return "" + } + + return s[startIndex+1:] +} + +// IsEnumNormal checks a set of eval values to see if they're "normal" +func IsEnumNormal(values []string) bool { + for _, v := range values { + if !rgxEnumIsOK.MatchString(v) { + return false + } + } + + return true +} + +// ShouldTitleCaseEnum checks a value to see if it's title-case-able +func ShouldTitleCaseEnum(value string) bool { + return rgxEnumShouldTitle.MatchString(value) +} diff --git a/strmangle/strmangle_test.go b/strmangle/strmangle_test.go index 6d802d4c0..2a14af5ac 100644 --- a/strmangle/strmangle_test.go +++ b/strmangle/strmangle_test.go @@ -291,8 +291,8 @@ func TestMakeStringMap(t *testing.T) { r = MakeStringMap(m) - e1 := `"TestOne": "interval", "TestTwo": "integer"` - e2 := `"TestTwo": "integer", "TestOne": "interval"` + e1 := "`TestOne`: `interval`, `TestTwo`: `integer`" + e2 := "`TestTwo`: `integer`, `TestOne`: `interval`" if r != e1 && r != e2 { t.Errorf("Got %s", r) @@ -513,3 +513,70 @@ func TestGenerateIgnoreTags(t *testing.T) { t.Errorf("expected %s, got %s", exp, tags) } } + +func TestParseEnum(t *testing.T) { + t.Parallel() + + tests := []struct { + Enum string + Name string + Vals []string + }{ + {"enum('one')", "", []string{"one"}}, + {"enum('one','two')", "", []string{"one", "two"}}, + {"enum.working('one')", "working", []string{"one"}}, + {"enum.wor_king('one','two')", "wor_king", []string{"one", "two"}}, + } + + for i, test := range tests { + name := ParseEnumName(test.Enum) + vals := ParseEnumVals(test.Enum) + if name != test.Name { + t.Errorf("%d) name was wrong, want: %s got: %s (%s)", i, test.Name, name, test.Enum) + } + for j, v := range test.Vals { + if v != vals[j] { + t.Errorf("%d.%d) value was wrong, want: %s got: %s (%s)", i, j, v, vals[j], test.Enum) + } + } + } +} + +func TestIsEnumNormal(t *testing.T) { + t.Parallel() + + tests := []struct { + Vals []string + Ok bool + }{ + {[]string{"o1ne", "two2"}, true}, + {[]string{"one", "t#wo2"}, false}, + {[]string{"1one", "two2"}, false}, + } + + for i, test := range tests { + if got := IsEnumNormal(test.Vals); got != test.Ok { + t.Errorf("%d) want: %t got: %t, %#v", i, test.Ok, got, test.Vals) + } + } +} + +func TestShouldTitleCaseEnum(t *testing.T) { + t.Parallel() + + tests := []struct { + Val string + Ok bool + }{ + {"hello_there0", true}, + {"hEllo", false}, + {"_hello", false}, + {"0hello", false}, + } + + for i, test := range tests { + if got := ShouldTitleCaseEnum(test.Val); got != test.Ok { + t.Errorf("%d) want: %t got: %t, %v", i, test.Ok, got, test.Val) + } + } +} diff --git a/templates.go b/templates.go index a56c81e4d..9baae196a 100644 --- a/templates.go +++ b/templates.go @@ -121,6 +121,28 @@ func loadTemplate(dir string, filename string) (*template.Template, error) { return tpl.Lookup(filename), err } +// set is to stop duplication from named enums, allowing a template loop +// to keep some state +type once map[string]struct{} + +func newOnce() once { + return make(once) +} + +func (o once) Has(s string) bool { + _, ok := o[s] + return ok +} + +func (o once) Put(s string) bool { + if _, ok := o[s]; ok { + return false + } + + o[s] = struct{}{} + return true +} + // templateStringMappers are placed into the data to make it easy to use the // stringMap function. var templateStringMappers = map[string]func(string) string{ @@ -157,6 +179,15 @@ var templateFunctions = template.FuncMap{ "generateTags": strmangle.GenerateTags, "generateIgnoreTags": strmangle.GenerateIgnoreTags, + // Enum ops + "parseEnumName": strmangle.ParseEnumName, + "parseEnumVals": strmangle.ParseEnumVals, + "isEnumNormal": strmangle.IsEnumNormal, + "shouldTitleCaseEnum": strmangle.ShouldTitleCaseEnum, + "onceNew": newOnce, + "oncePut": once.Put, + "onceHas": once.Has, + // String Map ops "makeStringMap": strmangle.MakeStringMap, @@ -173,6 +204,7 @@ var templateFunctions = template.FuncMap{ // dbdrivers ops "filterColumnsByDefault": bdb.FilterColumnsByDefault, + "filterColumnsByEnum": bdb.FilterColumnsByEnum, "sqlColDefinitions": bdb.SQLColDefinitions, "columnNames": bdb.ColumnNames, "columnDBTypes": bdb.ColumnDBTypes, diff --git a/templates/07_relationship_to_one_eager.tpl b/templates/07_relationship_to_one_eager.tpl index 79790adcc..43392f0e6 100644 --- a/templates/07_relationship_to_one_eager.tpl +++ b/templates/07_relationship_to_one_eager.tpl @@ -22,11 +22,15 @@ func ({{$varNameSingular}}L) Load{{$txt.Function.Name}}(e boil.Executor, singula args := make([]interface{}, count) if singular { - object.R = &{{$varNameSingular}}R{} + if object.R == nil { + object.R = &{{$varNameSingular}}R{} + } args[0] = object.{{$txt.LocalTable.ColumnNameGo}} } else { for i, obj := range slice { - obj.R = &{{$varNameSingular}}R{} + if obj.R == nil { + obj.R = &{{$varNameSingular}}R{} + } args[i] = obj.{{$txt.LocalTable.ColumnNameGo}} } } diff --git a/templates/08_relationship_one_to_one_eager.tpl b/templates/08_relationship_one_to_one_eager.tpl index 1dd9119c9..6603d55b6 100644 --- a/templates/08_relationship_one_to_one_eager.tpl +++ b/templates/08_relationship_one_to_one_eager.tpl @@ -22,11 +22,15 @@ func ({{$varNameSingular}}L) Load{{$txt.Function.Name}}(e boil.Executor, singula args := make([]interface{}, count) if singular { - object.R = &{{$varNameSingular}}R{} + if object.R == nil { + object.R = &{{$varNameSingular}}R{} + } args[0] = object.{{$txt.LocalTable.ColumnNameGo}} } else { for i, obj := range slice { - obj.R = &{{$varNameSingular}}R{} + if obj.R == nil { + obj.R = &{{$varNameSingular}}R{} + } args[i] = obj.{{$txt.LocalTable.ColumnNameGo}} } } diff --git a/templates/09_relationship_to_many_eager.tpl b/templates/09_relationship_to_many_eager.tpl index 9611c022e..f1a7f5d26 100644 --- a/templates/09_relationship_to_many_eager.tpl +++ b/templates/09_relationship_to_many_eager.tpl @@ -23,11 +23,15 @@ func ({{$varNameSingular}}L) Load{{$txt.Function.Name}}(e boil.Executor, singula args := make([]interface{}, count) if singular { - object.R = &{{$varNameSingular}}R{} + if object.R == nil { + object.R = &{{$varNameSingular}}R{} + } args[0] = object.{{.Column | titleCase}} } else { for i, obj := range slice { - obj.R = &{{$varNameSingular}}R{} + if obj.R == nil { + obj.R = &{{$varNameSingular}}R{} + } args[i] = obj.{{.Column | titleCase}} } } diff --git a/templates/singleton/boil_types.tpl b/templates/singleton/boil_types.tpl index ebb88918f..9bf13e8b3 100644 --- a/templates/singleton/boil_types.tpl +++ b/templates/singleton/boil_types.tpl @@ -35,3 +35,52 @@ func makeCacheKey(wl, nzDefaults []string) string { strmangle.PutBuffer(buf) return str } + +{{/* +The following is a little bit of black magic and deserves some explanation + +Because postgres and mysql define enums completely differently (one at the +database level as a custom datatype, and one at the table column level as +a unique thing per table)... There's a chance the enum is named (postgres) +and not (mysql). So we can't do this per table so this code is here. + +We loop through each table and column looking for enums. If it's named, we +then use some disgusting magic to write state during the template compile to +the "once" map. This lets named enums only be defined once if they're referenced +multiple times in many (or even the same) tables. + +Then we check if all it's values are normal, if they are we create the enum +output, if not we output a friendly error message as a comment to aid in +debugging. + +Postgres output looks like: EnumNameEnumValue = "enumvalue" +MySQL output looks like: TableNameColNameEnumValue = "enumvalue" + +It only titlecases the EnumValue portion if it's snake-cased. +*/}} +{{$dot := . -}} +{{$once := onceNew}} +{{- range $table := .Tables -}} + {{- range $col := $table.Columns | filterColumnsByEnum -}} + {{- $name := parseEnumName $col.DBType -}} + {{- $vals := parseEnumVals $col.DBType -}} + {{- $isNamed := ne (len $name) 0}} + {{- if and $isNamed (onceHas $once $name) -}} + {{- else -}} + {{- if $isNamed -}} + {{$_ := oncePut $once $name}} + {{- end -}} +{{- if and (gt (len $vals) 0) (isEnumNormal $vals)}} +// Enum values for {{if $isNamed}}{{$name}}{{else}}{{$table.Name}}.{{$col.Name}}{{end}} +const ( + {{- range $val := $vals -}} + {{- if $isNamed}}{{titleCase $name}}{{else}}{{titleCase $table.Name}}{{titleCase $col.Name}}{{end -}} + {{if shouldTitleCaseEnum $val}}{{titleCase $val}}{{else}}{{$val}}{{end}} = "{{$val}}" + {{end -}} +) +{{- else}} +// Enum values for {{if $isNamed}}{{$name}}{{else}}{{$table.Name}}.{{$col.Name}}{{end}} are not proper Go identifiers, cannot emit constants +{{- end -}} + {{- end -}} + {{- end -}} +{{- end -}} diff --git a/testdata/mysql_test_schema.sql b/testdata/mysql_test_schema.sql index e6c83ad89..423aceb34 100644 --- a/testdata/mysql_test_schema.sql +++ b/testdata/mysql_test_schema.sql @@ -1,3 +1,23 @@ +CREATE TABLE event_one ( + id serial PRIMARY KEY NOT NULL, + name VARCHAR(255), + day enum('monday','tuesday','wednesday') +); + +CREATE TABLE event_two ( + id serial PRIMARY KEY NOT NULL, + name VARCHAR(255), + face enum('happy','sad','bitter') +); + +CREATE TABLE event_three ( + id serial PRIMARY KEY NOT NULL, + name VARCHAR(255), + face enum('happy','sad','bitter'), + mood enum('happy','sad','bitter'), + day enum('monday','tuesday','wednesday') +); + CREATE TABLE magic ( id int PRIMARY KEY NOT NULL AUTO_INCREMENT, id_two int NOT NULL, diff --git a/testdata/postgres_test_schema.sql b/testdata/postgres_test_schema.sql index 2d49edfbb..7a3893e5c 100644 --- a/testdata/postgres_test_schema.sql +++ b/testdata/postgres_test_schema.sql @@ -1,3 +1,33 @@ +CREATE TYPE workday AS ENUM('monday', 'tuesday', 'wednesday', 'thursday', 'friday'); +CREATE TYPE faceyface AS ENUM('angry', 'hungry', 'bitter'); + +CREATE TABLE event_one ( + id serial PRIMARY KEY NOT NULL, + name VARCHAR(255), + day workday NOT NULL +); + +CREATE TABLE event_two ( + id serial PRIMARY KEY NOT NULL, + name VARCHAR(255), + day workday NOT NULL +); + +CREATE TABLE event_three ( + id serial PRIMARY KEY NOT NULL, + name VARCHAR(255), + day workday NOT NULL, + face faceyface NOT NULL, + thing workday NULL, + stuff faceyface NULL +); + +CREATE TABLE facey ( + id serial PRIMARY KEY NOT NULL, + name VARCHAR(255), + face faceyface NOT NULL +); + CREATE TABLE magic ( id serial PRIMARY KEY NOT NULL, id_two serial NOT NULL, @@ -24,6 +54,23 @@ CREATE TABLE magic ( string_ten VARCHAR(1000) NULL DEFAULT '', string_eleven VARCHAR(1000) NOT NULL DEFAULT '', + nonbyte_zero CHAR(1), + nonbyte_one CHAR(1) NULL, + nonbyte_two CHAR(1) NOT NULL, + nonbyte_three CHAR(1) NULL DEFAULT 'a', + nonbyte_four CHAR(1) NOT NULL DEFAULT 'b', + nonbyte_five CHAR(1000), + nonbyte_six CHAR(1000) NULL, + nonbyte_seven CHAR(1000) NOT NULL, + nonbyte_eight CHAR(1000) NULL DEFAULT 'a', + nonbyte_nine CHAR(1000) NOT NULL DEFAULT 'b', + + byte_zero "char", + byte_one "char" NULL, + byte_two "char" NULL DEFAULT 'a', + byte_three "char" NOT NULL, + byte_four "char" NOT NULL DEFAULT 'b', + big_int_zero bigint, big_int_one bigint NULL, big_int_two bigint NOT NULL, diff --git a/types/byte.go b/types/byte.go new file mode 100644 index 000000000..3b4cce7ad --- /dev/null +++ b/types/byte.go @@ -0,0 +1,61 @@ +package types + +import ( + "database/sql/driver" + "encoding/json" + "errors" +) + +// Byte is an alias for byte. +// Byte implements Marshal and Unmarshal. +type Byte byte + +// String output your byte. +func (b Byte) String() string { + return string(b) +} + +// UnmarshalJSON sets *b to a copy of data. +func (b *Byte) UnmarshalJSON(data []byte) error { + if b == nil { + return errors.New("json: unmarshal json on nil pointer to byte") + } + + var x string + if err := json.Unmarshal(data, &x); err != nil { + return err + } + + if len(x) > 1 { + return errors.New("json: cannot convert to byte, text len is greater than one") + } + + *b = Byte(x[0]) + return nil +} + +// MarshalJSON returns the JSON encoding of b. +func (b Byte) MarshalJSON() ([]byte, error) { + return []byte{'"', byte(b), '"'}, nil +} + +// Value returns b as a driver.Value. +func (b Byte) Value() (driver.Value, error) { + return []byte{byte(b)}, nil +} + +// Scan stores the src in *b. +func (b *Byte) Scan(src interface{}) error { + switch src.(type) { + case uint8: + *b = Byte(src.(uint8)) + case string: + *b = Byte(src.(string)[0]) + case []byte: + *b = Byte(src.([]byte)[0]) + default: + return errors.New("incompatible type for byte") + } + + return nil +} diff --git a/types/byte_test.go b/types/byte_test.go new file mode 100644 index 000000000..f3b81c9f2 --- /dev/null +++ b/types/byte_test.go @@ -0,0 +1,74 @@ +package types + +import ( + "bytes" + "encoding/json" + "testing" +) + +func TestByteString(t *testing.T) { + t.Parallel() + + b := Byte('b') + if b.String() != "b" { + t.Errorf("Expected %q, got %s", "b", b.String()) + } +} + +func TestByteUnmarshal(t *testing.T) { + t.Parallel() + + var b Byte + err := json.Unmarshal([]byte(`"b"`), &b) + if err != nil { + t.Error(err) + } + + if b != 'b' { + t.Errorf("Expected %q, got %s", "b", b) + } +} + +func TestByteMarshal(t *testing.T) { + t.Parallel() + + b := Byte('b') + res, err := json.Marshal(&b) + if err != nil { + t.Error(err) + } + + if !bytes.Equal(res, []byte(`"b"`)) { + t.Errorf("expected %s, got %s", `"b"`, b.String()) + } +} + +func TestByteValue(t *testing.T) { + t.Parallel() + + b := Byte('b') + v, err := b.Value() + if err != nil { + t.Error(err) + } + + if !bytes.Equal([]byte{byte(b)}, v.([]byte)) { + t.Errorf("byte mismatch, %v %v", b, v) + } +} + +func TestByteScan(t *testing.T) { + t.Parallel() + + var b Byte + + s := "b" + err := b.Scan(s) + if err != nil { + t.Error(err) + } + + if !bytes.Equal([]byte{byte(b)}, []byte{'b'}) { + t.Errorf("bad []byte: %#v ≠ %#v\n", b, "b") + } +} diff --git a/types/json.go b/types/json.go index b42e694b3..b9ac61600 100644 --- a/types/json.go +++ b/types/json.go @@ -35,7 +35,7 @@ func (j *JSON) Marshal(obj interface{}) error { // UnmarshalJSON sets *j to a copy of data. func (j *JSON) UnmarshalJSON(data []byte) error { if j == nil { - return errors.New("JSON: UnmarshalJSON on nil pointer") + return errors.New("json: unmarshal json on nil pointer to json") } *j = append((*j)[0:0], data...) @@ -68,7 +68,7 @@ func (j *JSON) Scan(src interface{}) error { case []byte: source = src.([]byte) default: - return errors.New("Incompatible type for JSON") + return errors.New("incompatible type for json") } *j = JSON(append((*j)[0:0], source...))