From 7144d272bdbf9b0d259f99cbf23dfa8903f6db39 Mon Sep 17 00:00:00 2001 From: Patrick O'brien Date: Tue, 6 Sep 2016 00:41:12 +1000 Subject: [PATCH 01/64] Add whitelist feature --- README.md | 2 ++ bdb/drivers/mock.go | 5 ++++- bdb/drivers/postgres.go | 16 ++++++---------- bdb/interface.go | 6 +++--- bdb/interface_test.go | 7 +++++-- config.go | 1 + main.go | 9 +++++++++ sqlboiler.go | 6 +++--- text_helpers_test.go | 6 +++--- 9 files changed, 36 insertions(+), 22 deletions(-) diff --git a/README.md b/README.md index 531ae4e2a..dbadf9ec2 100644 --- a/README.md +++ b/README.md @@ -225,6 +225,7 @@ not to pass them through the command line or environment variables: | basedir | none | | pkgname | "models" | | output | "models" | +| whitelist | [ ] | | exclude | [ ] | | tag | [ ] | | debug | false | @@ -261,6 +262,7 @@ sqlboiler postgres Flags: -b, --basedir string The base directory has the templates and templates_test folders -d, --debug Debug mode prints stack traces on error + -w, --whitelist stringSlice Only include these tables in your generated package -x, --exclude stringSlice Tables to be excluded from the generated package --no-auto-timestamps Disable automatic timestamps for created_at/updated_at --no-hooks Disable hooks feature for your models diff --git a/bdb/drivers/mock.go b/bdb/drivers/mock.go index 692062797..93d18cc45 100644 --- a/bdb/drivers/mock.go +++ b/bdb/drivers/mock.go @@ -9,7 +9,10 @@ import ( type MockDriver struct{} // TableNames returns a list of mock table names -func (m *MockDriver) TableNames(exclude []string) ([]string, error) { +func (m *MockDriver) TableNames(whitelist, exclude []string) ([]string, error) { + if len(whitelist) > 0 { + return whitelist, nil + } tables := []string{"pilots", "jets", "airports", "licenses", "hangars", "languages", "pilot_languages"} return strmangle.SetComplement(tables, exclude), nil } diff --git a/bdb/drivers/postgres.go b/bdb/drivers/postgres.go index c4b9984ac..507eec16f 100644 --- a/bdb/drivers/postgres.go +++ b/bdb/drivers/postgres.go @@ -9,7 +9,6 @@ import ( _ "github.com/lib/pq" "github.com/pkg/errors" "github.com/vattle/sqlboiler/bdb" - "github.com/vattle/sqlboiler/strmangle" ) // PostgresDriver holds the database connection string and a handle @@ -82,18 +81,15 @@ func (p *PostgresDriver) UseLastInsertID() bool { // TableNames connects to the postgres database and // retrieves all table names from the information_schema where the -// table schema is public. It excludes common migration tool tables -// such as gorp_migrations -func (p *PostgresDriver) TableNames(exclude []string) ([]string, error) { +// table schema is public. It uses a whitelist and exclude list. +func (p *PostgresDriver) TableNames(whitelist, exclude []string) ([]string, error) { var names []string query := `select table_name from information_schema.tables where table_schema = 'public'` - if len(exclude) > 0 { - quoteStr := func(x string) string { - return `'` + x + `'` - } - exclude = strmangle.StringMap(quoteStr, exclude) - query = query + fmt.Sprintf("and table_name not in (%s);", strings.Join(exclude, ",")) + if len(whitelist) > 0 { + query = query + fmt.Sprintf("and table_name in ('%s');", strings.Join(whitelist, "','")) + } else if len(exclude) > 0 { + query = query + fmt.Sprintf("and table_name not in ('%s');", strings.Join(exclude, "','")) } rows, err := p.dbConn.Query(query) diff --git a/bdb/interface.go b/bdb/interface.go index 9e4bfc70f..b59dbe1e4 100644 --- a/bdb/interface.go +++ b/bdb/interface.go @@ -6,7 +6,7 @@ import "github.com/pkg/errors" // Interface for a database driver. Functionality required to support a specific // database type (eg, MySQL, Postgres etc.) type Interface interface { - TableNames(exclude []string) ([]string, error) + TableNames(whitelist, exclude []string) ([]string, error) Columns(tableName string) ([]Column, error) PrimaryKeyInfo(tableName string) (*PrimaryKey, error) ForeignKeyInfo(tableName string) ([]ForeignKey, error) @@ -26,10 +26,10 @@ type Interface interface { // Tables returns the metadata for all tables, minus the tables // specified in the exclude slice. -func Tables(db Interface, exclude ...string) ([]Table, error) { +func Tables(db Interface, whitelist, exclude []string) ([]Table, error) { var err error - names, err := db.TableNames(exclude) + names, err := db.TableNames(whitelist, exclude) if err != nil { return nil, errors.Wrap(err, "unable to get table names") } diff --git a/bdb/interface_test.go b/bdb/interface_test.go index fd3e59f60..048f3df5c 100644 --- a/bdb/interface_test.go +++ b/bdb/interface_test.go @@ -13,7 +13,10 @@ func (m mockDriver) UseLastInsertID() bool { return false } func (m mockDriver) Open() error { return nil } func (m mockDriver) Close() {} -func (m mockDriver) TableNames(exclude []string) ([]string, error) { +func (m mockDriver) TableNames(whitelist, exclude []string) ([]string, error) { + if len(whitelist) > 0 { + return whitelist, nil + } tables := []string{"pilots", "jets", "airports", "licenses", "hangars", "languages", "pilot_languages"} return strmangle.SetComplement(tables, exclude), nil } @@ -96,7 +99,7 @@ func (m mockDriver) PrimaryKeyInfo(tableName string) (*PrimaryKey, error) { func TestTables(t *testing.T) { t.Parallel() - tables, err := Tables(mockDriver{}) + tables, err := Tables(mockDriver{}, nil, nil) if err != nil { t.Error(err) } diff --git a/config.go b/config.go index 4af2518c6..662747dfc 100644 --- a/config.go +++ b/config.go @@ -6,6 +6,7 @@ type Config struct { PkgName string OutFolder string BaseDir string + WhitelistTables []string ExcludeTables []string Tags []string Debug bool diff --git a/main.go b/main.go index db3043b6e..058e617e0 100644 --- a/main.go +++ b/main.go @@ -64,6 +64,7 @@ func main() { rootCmd.PersistentFlags().StringP("pkgname", "p", "models", "The name you wish to assign to your generated package") rootCmd.PersistentFlags().StringP("basedir", "b", "", "The base directory has the templates and templates_test folders") rootCmd.PersistentFlags().StringSliceP("exclude", "x", nil, "Tables to be excluded from the generated package") + rootCmd.PersistentFlags().StringSliceP("whitelist", "w", nil, "Only include these tables in your generated package") rootCmd.PersistentFlags().StringSliceP("tag", "t", nil, "Struct tags to be included on your models in addition to json, yaml, toml") rootCmd.PersistentFlags().BoolP("debug", "d", false, "Debug mode prints stack traces on error") rootCmd.PersistentFlags().BoolP("no-tests", "", false, "Disable generated go test files") @@ -126,6 +127,14 @@ func preRun(cmd *cobra.Command, args []string) error { } } + cmdConfig.WhitelistTables = viper.GetStringSlice("whitelist") + if len(cmdConfig.WhitelistTables) == 1 && strings.HasPrefix(cmdConfig.WhitelistTables[0], "[") { + cmdConfig.WhitelistTables, err = cmd.PersistentFlags().GetStringSlice("whitelist") + if err != nil { + return err + } + } + cmdConfig.Tags = viper.GetStringSlice("tag") if len(cmdConfig.Tags) == 1 && strings.HasPrefix(cmdConfig.Tags[0], "[") { cmdConfig.Tags, err = cmd.PersistentFlags().GetStringSlice("tag") diff --git a/sqlboiler.go b/sqlboiler.go index aa4d92402..fcecb0d77 100644 --- a/sqlboiler.go +++ b/sqlboiler.go @@ -59,7 +59,7 @@ func New(config *Config) (*State, error) { return nil, errors.Wrap(err, "unable to connect to the database") } - err = s.initTables(config.ExcludeTables) + err = s.initTables(config.WhitelistTables, config.ExcludeTables) if err != nil { return nil, errors.Wrap(err, "unable to initialize tables") } @@ -239,9 +239,9 @@ func (s *State) initDriver(driverName string) error { } // initTables retrieves all "public" schema table names from the database. -func (s *State) initTables(exclude []string) error { +func (s *State) initTables(whitelist, exclude []string) error { var err error - s.Tables, err = bdb.Tables(s.Driver, exclude...) + s.Tables, err = bdb.Tables(s.Driver, whitelist, exclude) if err != nil { return errors.Wrap(err, "unable to fetch table data") } diff --git a/text_helpers_test.go b/text_helpers_test.go index 65ab7efc9..a85a4e5ae 100644 --- a/text_helpers_test.go +++ b/text_helpers_test.go @@ -12,7 +12,7 @@ import ( func TestTextsFromForeignKey(t *testing.T) { t.Parallel() - tables, err := bdb.Tables(&drivers.MockDriver{}) + tables, err := bdb.Tables(&drivers.MockDriver{}, nil, nil) if err != nil { t.Fatal(err) } @@ -81,7 +81,7 @@ func TestTextsFromForeignKey(t *testing.T) { func TestTextsFromOneToOneRelationship(t *testing.T) { t.Parallel() - tables, err := bdb.Tables(&drivers.MockDriver{}) + tables, err := bdb.Tables(&drivers.MockDriver{}, nil, nil) if err != nil { t.Fatal(err) } @@ -130,7 +130,7 @@ func TestTextsFromOneToOneRelationship(t *testing.T) { func TestTextsFromRelationship(t *testing.T) { t.Parallel() - tables, err := bdb.Tables(&drivers.MockDriver{}) + tables, err := bdb.Tables(&drivers.MockDriver{}, nil, nil) if err != nil { t.Fatal(err) } From 41c36cadf354a6f5d7a3f5d9f7b310f7157d55ad Mon Sep 17 00:00:00 2001 From: Patrick O'brien Date: Tue, 6 Sep 2016 01:24:19 +1000 Subject: [PATCH 02/64] ValuesFromMapping now gets values --- boil/reflect.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/boil/reflect.go b/boil/reflect.go index 27376edc4..d4c4dee99 100644 --- a/boil/reflect.go +++ b/boil/reflect.go @@ -322,8 +322,10 @@ func ptrFromMapping(val reflect.Value, mapping uint64, addressOf bool) reflect.V v := (mapping >> uint(i*8)) & sentinel if v == sentinel { - if val.Kind() != reflect.Ptr { + if addressOf && val.Kind() != reflect.Ptr { return val.Addr() + } else if !addressOf && val.Kind() == reflect.Ptr { + return reflect.Indirect(val) } return val } From ce8573eccd9e5ab2d019489f0476f1096b43d526 Mon Sep 17 00:00:00 2001 From: Patrick O'brien Date: Wed, 7 Sep 2016 23:50:54 +1000 Subject: [PATCH 03/64] Updated to null.v5, update postgres driver types --- bdb/drivers/postgres.go | 8 ++++--- boil/helpers_test.go | 2 +- boil/randomize/randomize.go | 2 +- boil/randomize/randomize_test.go | 2 +- boil/reflect_test.go | 2 +- imports.go | 39 ++++++++++++++++++++------------ imports_test.go | 10 ++++---- 7 files changed, 38 insertions(+), 27 deletions(-) diff --git a/bdb/drivers/postgres.go b/bdb/drivers/postgres.go index 507eec16f..af2cd0b94 100644 --- a/bdb/drivers/postgres.go +++ b/bdb/drivers/postgres.go @@ -279,9 +279,9 @@ func (p *PostgresDriver) TranslateColumnType(c bdb.Column) bdb.Column { c.Type = "null.Float64" case "real": c.Type = "null.Float32" - case "bit", "interval", "bit varying", "character", "character varying", "cidr", "inet", "json", "macaddr", "text", "uuid", "xml": + case "bit", "interval", "bit varying", "character", "character varying", "cidr", "inet", "macaddr", "text", "uuid", "xml": c.Type = "null.String" - case "bytea": + case "bytea", "json", "jsonb": c.Type = "[]byte" case "boolean": c.Type = "null.Bool" @@ -302,8 +302,10 @@ func (p *PostgresDriver) TranslateColumnType(c bdb.Column) bdb.Column { c.Type = "float64" case "real": c.Type = "float32" - case "bit", "interval", "uuint", "bit varying", "character", "character varying", "cidr", "inet", "json", "macaddr", "text", "uuid", "xml": + case "bit", "interval", "uuint", "bit varying", "character", "character varying", "cidr", "inet", "macaddr", "text", "uuid", "xml": c.Type = "string" + case "json", "jsonb": + c.Type = "json.RawMessage" case "bytea": c.Type = "[]byte" case "boolean": diff --git a/boil/helpers_test.go b/boil/helpers_test.go index 756834547..73d284e68 100644 --- a/boil/helpers_test.go +++ b/boil/helpers_test.go @@ -5,7 +5,7 @@ import ( "testing" "time" - "gopkg.in/nullbio/null.v4" + "gopkg.in/nullbio/null.v5" ) type testObj struct { diff --git a/boil/randomize/randomize.go b/boil/randomize/randomize.go index 89e516acb..7fe36108c 100644 --- a/boil/randomize/randomize.go +++ b/boil/randomize/randomize.go @@ -9,7 +9,7 @@ import ( "sync/atomic" "time" - "gopkg.in/nullbio/null.v4" + "gopkg.in/nullbio/null.v5" "github.com/pkg/errors" "github.com/satori/go.uuid" diff --git a/boil/randomize/randomize_test.go b/boil/randomize/randomize_test.go index ee028b9ab..a3ba08304 100644 --- a/boil/randomize/randomize_test.go +++ b/boil/randomize/randomize_test.go @@ -5,7 +5,7 @@ import ( "testing" "time" - "gopkg.in/nullbio/null.v4" + "gopkg.in/nullbio/null.v5" ) func TestRandomizeStruct(t *testing.T) { diff --git a/boil/reflect_test.go b/boil/reflect_test.go index 0a05fd680..dab762a92 100644 --- a/boil/reflect_test.go +++ b/boil/reflect_test.go @@ -9,7 +9,7 @@ import ( "time" "gopkg.in/DATA-DOG/go-sqlmock.v1" - "gopkg.in/nullbio/null.v4" + "gopkg.in/nullbio/null.v5" ) func bin64(i uint64) string { diff --git a/imports.go b/imports.go index b753edaad..599bf560f 100644 --- a/imports.go +++ b/imports.go @@ -246,51 +246,60 @@ var defaultTestMainImports = map[string]imports{ // TranslateColumnType to see the type assignments. var importsBasedOnType = map[string]imports{ "null.Float32": { - thirdParty: importList{`"gopkg.in/nullbio/null.v4"`}, + thirdParty: importList{`"gopkg.in/nullbio/null.v5"`}, }, "null.Float64": { - thirdParty: importList{`"gopkg.in/nullbio/null.v4"`}, + thirdParty: importList{`"gopkg.in/nullbio/null.v5"`}, }, "null.Int": { - thirdParty: importList{`"gopkg.in/nullbio/null.v4"`}, + thirdParty: importList{`"gopkg.in/nullbio/null.v5"`}, }, "null.Int8": { - thirdParty: importList{`"gopkg.in/nullbio/null.v4"`}, + thirdParty: importList{`"gopkg.in/nullbio/null.v5"`}, }, "null.Int16": { - thirdParty: importList{`"gopkg.in/nullbio/null.v4"`}, + thirdParty: importList{`"gopkg.in/nullbio/null.v5"`}, }, "null.Int32": { - thirdParty: importList{`"gopkg.in/nullbio/null.v4"`}, + thirdParty: importList{`"gopkg.in/nullbio/null.v5"`}, }, "null.Int64": { - thirdParty: importList{`"gopkg.in/nullbio/null.v4"`}, + thirdParty: importList{`"gopkg.in/nullbio/null.v5"`}, }, "null.Uint": { - thirdParty: importList{`"gopkg.in/nullbio/null.v4"`}, + thirdParty: importList{`"gopkg.in/nullbio/null.v5"`}, }, "null.Uint8": { - thirdParty: importList{`"gopkg.in/nullbio/null.v4"`}, + thirdParty: importList{`"gopkg.in/nullbio/null.v5"`}, }, "null.Uint16": { - thirdParty: importList{`"gopkg.in/nullbio/null.v4"`}, + thirdParty: importList{`"gopkg.in/nullbio/null.v5"`}, }, "null.Uint32": { - thirdParty: importList{`"gopkg.in/nullbio/null.v4"`}, + thirdParty: importList{`"gopkg.in/nullbio/null.v5"`}, }, "null.Uint64": { - thirdParty: importList{`"gopkg.in/nullbio/null.v4"`}, + thirdParty: importList{`"gopkg.in/nullbio/null.v5"`}, }, "null.String": { - thirdParty: importList{`"gopkg.in/nullbio/null.v4"`}, + thirdParty: importList{`"gopkg.in/nullbio/null.v5"`}, }, "null.Bool": { - thirdParty: importList{`"gopkg.in/nullbio/null.v4"`}, + thirdParty: importList{`"gopkg.in/nullbio/null.v5"`}, }, "null.Time": { - thirdParty: importList{`"gopkg.in/nullbio/null.v4"`}, + thirdParty: importList{`"gopkg.in/nullbio/null.v5"`}, + }, + "null.JSON": { + thirdParty: importList{`"gopkg.in/nullbio/null.v5"`}, + }, + "null.Bytes": { + thirdParty: importList{`"gopkg.in/nullbio/null.v5"`}, }, "time.Time": { standard: importList{`"time"`}, }, + "json.RawBytes": { + standard: importList{`"encoding/json"`}, + }, } diff --git a/imports_test.go b/imports_test.go index 79863fa2c..b07fca79b 100644 --- a/imports_test.go +++ b/imports_test.go @@ -75,7 +75,7 @@ func TestCombineTypeImports(t *testing.T) { }, thirdParty: importList{ `"github.com/vattle/sqlboiler/boil"`, - `"gopkg.in/nullbio/null.v4"`, + `"gopkg.in/nullbio/null.v5"`, }, } @@ -108,7 +108,7 @@ func TestCombineTypeImports(t *testing.T) { }, thirdParty: importList{ `"github.com/vattle/sqlboiler/boil"`, - `"gopkg.in/nullbio/null.v4"`, + `"gopkg.in/nullbio/null.v5"`, }, } @@ -124,7 +124,7 @@ func TestCombineImports(t *testing.T) { a := imports{ standard: importList{"fmt"}, - thirdParty: importList{"github.com/vattle/sqlboiler", "gopkg.in/nullbio/null.v4"}, + thirdParty: importList{"github.com/vattle/sqlboiler", "gopkg.in/nullbio/null.v5"}, } b := imports{ standard: importList{"os"}, @@ -136,8 +136,8 @@ func TestCombineImports(t *testing.T) { if c.standard[0] != "fmt" && c.standard[1] != "os" { t.Errorf("Wanted: fmt, os got: %#v", c.standard) } - if c.thirdParty[0] != "github.com/vattle/sqlboiler" && c.thirdParty[1] != "gopkg.in/nullbio/null.v4" { - t.Errorf("Wanted: github.com/vattle/sqlboiler, gopkg.in/nullbio/null.v4 got: %#v", c.thirdParty) + if c.thirdParty[0] != "github.com/vattle/sqlboiler" && c.thirdParty[1] != "gopkg.in/nullbio/null.v5" { + t.Errorf("Wanted: github.com/vattle/sqlboiler, gopkg.in/nullbio/null.v5 got: %#v", c.thirdParty) } } From 757cbde016889bc6e100293ae3794cde283c1ea3 Mon Sep 17 00:00:00 2001 From: Patrick O'brien Date: Thu, 8 Sep 2016 19:07:33 +1000 Subject: [PATCH 04/64] Add NullJSON and JSON types, fix randomize struct --- bdb/drivers/postgres.go | 8 ++- boil/randomize/randomize.go | 24 +++++++- boil/types/json.go | 77 +++++++++++++++++++++++ boil/types/json_test.go | 119 ++++++++++++++++++++++++++++++++++++ imports.go | 4 +- 5 files changed, 224 insertions(+), 8 deletions(-) create mode 100644 boil/types/json.go create mode 100644 boil/types/json_test.go diff --git a/bdb/drivers/postgres.go b/bdb/drivers/postgres.go index af2cd0b94..ac2bd476d 100644 --- a/bdb/drivers/postgres.go +++ b/bdb/drivers/postgres.go @@ -281,8 +281,10 @@ func (p *PostgresDriver) TranslateColumnType(c bdb.Column) bdb.Column { c.Type = "null.Float32" case "bit", "interval", "bit varying", "character", "character varying", "cidr", "inet", "macaddr", "text", "uuid", "xml": c.Type = "null.String" - case "bytea", "json", "jsonb": - c.Type = "[]byte" + case "bytea": + c.Type = "null.Bytes" + case "json", "jsonb": + c.Type = "null.JSON" case "boolean": c.Type = "null.Bool" case "date", "time", "timestamp without time zone", "timestamp with time zone": @@ -305,7 +307,7 @@ func (p *PostgresDriver) TranslateColumnType(c bdb.Column) bdb.Column { case "bit", "interval", "uuint", "bit varying", "character", "character varying", "cidr", "inet", "macaddr", "text", "uuid", "xml": c.Type = "string" case "json", "jsonb": - c.Type = "json.RawMessage" + c.Type = "types.JSON" case "bytea": c.Type = "[]byte" case "boolean": diff --git a/boil/randomize/randomize.go b/boil/randomize/randomize.go index 7fe36108c..4bfaa2a1d 100644 --- a/boil/randomize/randomize.go +++ b/boil/randomize/randomize.go @@ -2,6 +2,7 @@ package randomize import ( + "fmt" "reflect" "regexp" "sort" @@ -13,6 +14,7 @@ import ( "github.com/pkg/errors" "github.com/satori/go.uuid" + "github.com/vattle/sqlboiler/boil/types" "github.com/vattle/sqlboiler/strmangle" ) @@ -32,11 +34,13 @@ var ( typeNullString = reflect.TypeOf(null.String{}) typeNullBool = reflect.TypeOf(null.Bool{}) typeNullTime = reflect.TypeOf(null.Time{}) + typeNullBytes = reflect.TypeOf(null.Bytes{}) + typeNullJSON = reflect.TypeOf(null.JSON{}) typeTime = reflect.TypeOf(time.Time{}) + typeJSON = reflect.TypeOf(types.JSON{}) + rgxValidTime = regexp.MustCompile(`[2-9]+`) - rgxValidTime = regexp.MustCompile(`[2-9]+`) - - validatedTypes = []string{"uuid", "interval"} + validatedTypes = []string{"uuid", "interval", "json", "jsonb"} ) // Seed is an atomic counter for pseudo-randomization structs. Using full @@ -163,6 +167,10 @@ func randomizeField(s *Seed, field reflect.Value, fieldType string, canBeNull bo field.Set(reflect.ValueOf(value)) return nil } + case typeNullJSON: + value = null.NewJSON([]byte(fmt.Sprintf(`"%s"`, randStr(s, 1))), true) + field.Set(reflect.ValueOf(value)) + return nil } } else { switch kind { @@ -178,6 +186,12 @@ func randomizeField(s *Seed, field reflect.Value, fieldType string, canBeNull bo return nil } } + switch typ { + case typeJSON: + value = []byte(fmt.Sprintf(`"%s"`, randStr(s, 1))) + field.Set(reflect.ValueOf(value)) + return nil + } } } @@ -250,6 +264,8 @@ func getStructNullValue(typ reflect.Type) interface{} { return null.NewUint32(0, false) case typeNullUint64: return null.NewUint64(0, false) + case typeNullBytes: + return null.NewBytes(nil, false) } return nil @@ -292,6 +308,8 @@ func getStructRandValue(s *Seed, typ reflect.Type) interface{} { return null.NewUint32(uint32(s.nextInt()), true) case typeNullUint64: return null.NewUint64(uint64(s.nextInt()), true) + case typeNullBytes: + return null.NewBytes(randByteSlice(s, 16), true) } return nil diff --git a/boil/types/json.go b/boil/types/json.go new file mode 100644 index 000000000..b42e694b3 --- /dev/null +++ b/boil/types/json.go @@ -0,0 +1,77 @@ +package types + +import ( + "database/sql/driver" + "encoding/json" + "errors" +) + +// JSON is an alias for json.RawMessage, which is +// a []byte underneath. +// JSON implements Marshal and Unmarshal. +type JSON json.RawMessage + +// String output your JSON. +func (j JSON) String() string { + return string(j) +} + +// Unmarshal your JSON variable into dest. +func (j JSON) Unmarshal(dest interface{}) error { + return json.Unmarshal(j, dest) +} + +// Marshal obj into your JSON variable. +func (j *JSON) Marshal(obj interface{}) error { + res, err := json.Marshal(obj) + if err != nil { + return err + } + + *j = res + return nil +} + +// UnmarshalJSON sets *j to a copy of data. +func (j *JSON) UnmarshalJSON(data []byte) error { + if j == nil { + return errors.New("JSON: UnmarshalJSON on nil pointer") + } + + *j = append((*j)[0:0], data...) + return nil +} + +// MarshalJSON returns j as the JSON encoding of j. +func (j JSON) MarshalJSON() ([]byte, error) { + return j, nil +} + +// Value returns j as a value. +// Unmarshal into RawMessage for validation. +func (j JSON) Value() (driver.Value, error) { + var r json.RawMessage + if err := j.Unmarshal(&r); err != nil { + return nil, err + } + + return []byte(r), nil +} + +// Scan stores the src in *j. +func (j *JSON) Scan(src interface{}) error { + var source []byte + + switch src.(type) { + case string: + source = []byte(src.(string)) + case []byte: + source = src.([]byte) + default: + return errors.New("Incompatible type for JSON") + } + + *j = JSON(append((*j)[0:0], source...)) + + return nil +} diff --git a/boil/types/json_test.go b/boil/types/json_test.go new file mode 100644 index 000000000..9ba232711 --- /dev/null +++ b/boil/types/json_test.go @@ -0,0 +1,119 @@ +package types + +import ( + "bytes" + "testing" +) + +func TestJSONString(t *testing.T) { + t.Parallel() + + j := JSON("hello") + if j.String() != "hello" { + t.Errorf("Expected %q, got %s", "hello", j.String()) + } +} + +func TestJSONUnmarshal(t *testing.T) { + t.Parallel() + + type JSONTest struct { + Name string + Age int + } + var jt JSONTest + + j := JSON(`{"Name":"hi","Age":15}`) + err := j.Unmarshal(&jt) + if err != nil { + t.Error(err) + } + + if jt.Name != "hi" { + t.Errorf("Expected %q, got %s", "hi", jt.Name) + } + if jt.Age != 15 { + t.Errorf("Expected %v, got %v", 15, jt.Age) + } +} + +func TestJSONMarshal(t *testing.T) { + t.Parallel() + + type JSONTest struct { + Name string + Age int + } + jt := JSONTest{ + Name: "hi", + Age: 15, + } + + var j JSON + err := j.Marshal(jt) + if err != nil { + t.Error(err) + } + + if j.String() != `{"Name":"hi","Age":15}` { + t.Errorf("expected %s, got %s", `{"Name":"hi","Age":15}`, j.String()) + } +} + +func TestJSONUnmarshalJSON(t *testing.T) { + t.Parallel() + + j := JSON(nil) + + err := j.UnmarshalJSON(JSON(`"hi"`)) + if err != nil { + t.Error(err) + } + + if j.String() != `"hi"` { + t.Errorf("Expected %q, got %s", "hi", j.String()) + } +} + +func TestJSONMarshalJSON(t *testing.T) { + t.Parallel() + + j := JSON(`"hi"`) + res, err := j.MarshalJSON() + if err != nil { + t.Error(err) + } + + if !bytes.Equal(res, []byte(`"hi"`)) { + t.Errorf("Expected %q, got %v", `"hi"`, res) + } +} + +func TestJSONValue(t *testing.T) { + t.Parallel() + + j := JSON(`{"Name":"hi","Age":15}`) + v, err := j.Value() + if err != nil { + t.Error(err) + } + + if !bytes.Equal(j, v.([]byte)) { + t.Errorf("byte mismatch, %v %v", j, v) + } +} + +func TestJSONScan(t *testing.T) { + t.Parallel() + + j := JSON{} + + err := j.Scan(`"hello"`) + if err != nil { + t.Error(err) + } + + if !bytes.Equal(j, []byte(`"hello"`)) { + t.Errorf("bad []byte: %#v ≠ %#v\n", j, string([]byte(`"hello"`))) + } +} diff --git a/imports.go b/imports.go index 599bf560f..c6510c23c 100644 --- a/imports.go +++ b/imports.go @@ -299,7 +299,7 @@ var importsBasedOnType = map[string]imports{ "time.Time": { standard: importList{`"time"`}, }, - "json.RawBytes": { - standard: importList{`"encoding/json"`}, + "types.JSON": { + thirdParty: importList{`"github.com/vattle/sqlboiler/boil/types"`}, }, } From 14c8f651c46bdc408ac3998d186108d6ca834dc2 Mon Sep 17 00:00:00 2001 From: Patrick O'brien Date: Thu, 8 Sep 2016 19:52:52 +1000 Subject: [PATCH 05/64] Add all postgres types to test_schema --- testdata/test_schema.sql | 47 ++++++++++++++++++++++++++++++++++------ 1 file changed, 40 insertions(+), 7 deletions(-) diff --git a/testdata/test_schema.sql b/testdata/test_schema.sql index 65b786d73..9dd4953c7 100644 --- a/testdata/test_schema.sql +++ b/testdata/test_schema.sql @@ -93,7 +93,46 @@ CREATE TABLE magic ( strange_three timestamp without time zone default (now() at time zone 'utc'), strange_four timestamp with time zone default (now() at time zone 'utc'), strange_five interval NOT NULL DEFAULT '21 days', - strange_six interval NULL DEFAULT '23 hours' + strange_six interval NULL DEFAULT '23 hours', + + aa json NULL, + bb json NOT NULL, + cc jsonb NULL, + dd jsonb NOT NULL, + ee box NULL, + ff box NOT NULL, + gg cidr NULL, + hh cidr NOT NULL, + ii circle NULL, + jj circle NOT NULL, + kk double precision NULL, + ll double precision NOT NULL, + mm inet NULL, + nn inet NOT NULL, + oo line NULL, + pp line NOT NULL, + qq lseg NULL, + rr lseg NOT NULL, + ss macaddr NULL, + tt macaddr NOT NULL, + uu money NULL, + vv money NOT NULL, + ww path NULL, + xx path NOT NULL, + yy pg_lsn NULL, + zz pg_lsn NOT NULL, + aaa point NULL, + bbb point NOT NULL, + ccc polygon NULL, + ddd polygon NOT NULL, + eee tsquery NULL, + fff tsquery NOT NULL, + ggg tsvector NULL, + hhh tsvector NOT NULL, + iii txid_snapshot NULL, + jjj txid_snapshot NOT NULL, + kkk xml NULL, + lll xml NOT NULL ); create table owner ( @@ -136,12 +175,6 @@ create table spider_toys ( primary key (spider_id) ); -/* - Test: - * Variations of capitalization - * Single value columns - * Primary key as only value -*/ create table pals ( pal character varying, primary key (pal) From 5300a0f6a49fb2dd1e8976030a49235db1807e86 Mon Sep 17 00:00:00 2001 From: Patrick O'brien Date: Fri, 9 Sep 2016 00:35:43 +1000 Subject: [PATCH 06/64] Fix all postgres types, fix all randomize types --- bdb/drivers/postgres.go | 8 +- boil/randomize/randomize.go | 153 +++++++++++++++++++++++++++++++++++- 2 files changed, 156 insertions(+), 5 deletions(-) diff --git a/bdb/drivers/postgres.go b/bdb/drivers/postgres.go index ac2bd476d..be69816a8 100644 --- a/bdb/drivers/postgres.go +++ b/bdb/drivers/postgres.go @@ -275,11 +275,11 @@ func (p *PostgresDriver) TranslateColumnType(c bdb.Column) bdb.Column { c.Type = "null.Int" case "smallint", "smallserial": c.Type = "null.Int16" - case "decimal", "numeric", "double precision", "money": + case "decimal", "numeric", "double precision": c.Type = "null.Float64" case "real": c.Type = "null.Float32" - case "bit", "interval", "bit varying", "character", "character varying", "cidr", "inet", "macaddr", "text", "uuid", "xml": + case "bit", "interval", "bit varying", "character", "money", "character varying", "cidr", "inet", "macaddr", "text", "uuid", "xml": c.Type = "null.String" case "bytea": c.Type = "null.Bytes" @@ -300,11 +300,11 @@ func (p *PostgresDriver) TranslateColumnType(c bdb.Column) bdb.Column { c.Type = "int" case "smallint", "smallserial": c.Type = "int16" - case "decimal", "numeric", "double precision", "money": + case "decimal", "numeric", "double precision": c.Type = "float64" case "real": c.Type = "float32" - case "bit", "interval", "uuint", "bit varying", "character", "character varying", "cidr", "inet", "macaddr", "text", "uuid", "xml": + case "bit", "interval", "uuint", "bit varying", "character", "money", "character varying", "cidr", "inet", "macaddr", "text", "uuid", "xml": c.Type = "string" case "json", "jsonb": c.Type = "types.JSON" diff --git a/boil/randomize/randomize.go b/boil/randomize/randomize.go index 4bfaa2a1d..398d5e82c 100644 --- a/boil/randomize/randomize.go +++ b/boil/randomize/randomize.go @@ -3,6 +3,7 @@ package randomize import ( "fmt" + "math/rand" "reflect" "regexp" "sort" @@ -40,7 +41,12 @@ var ( typeJSON = reflect.TypeOf(types.JSON{}) rgxValidTime = regexp.MustCompile(`[2-9]+`) - validatedTypes = []string{"uuid", "interval", "json", "jsonb"} + validatedTypes = []string{ + "inet", "line", "uuid", "interval", + "json", "jsonb", "box", "cidr", "circle", + "lseg", "macaddr", "path", "pg_lsn", "point", + "polygon", "txid_snapshot", "money", + } ) // Seed is an atomic counter for pseudo-randomization structs. Using full @@ -167,6 +173,47 @@ func randomizeField(s *Seed, field reflect.Value, fieldType string, canBeNull bo field.Set(reflect.ValueOf(value)) return nil } + if fieldType == "box" || fieldType == "line" || fieldType == "lseg" || + fieldType == "path" || fieldType == "polygon" { + value = null.NewString(randBox(), true) + field.Set(reflect.ValueOf(value)) + return nil + } + if fieldType == "cidr" || fieldType == "inet" { + value = null.NewString(randNetAddr(), true) + field.Set(reflect.ValueOf(value)) + return nil + } + if fieldType == "macaddr" { + value = null.NewString(randMacAddr(), true) + field.Set(reflect.ValueOf(value)) + return nil + } + if fieldType == "circle" { + value = null.NewString(randCircle(), true) + field.Set(reflect.ValueOf(value)) + return nil + } + if fieldType == "pg_lsn" { + value = null.NewString(randLsn(), true) + field.Set(reflect.ValueOf(value)) + return nil + } + if fieldType == "point" { + value = null.NewString(randPoint(), true) + field.Set(reflect.ValueOf(value)) + return nil + } + if fieldType == "txid_snapshot" { + value = null.NewString(randTxID(), true) + field.Set(reflect.ValueOf(value)) + return nil + } + if fieldType == "money" { + value = null.NewString(randMoney(s), true) + field.Set(reflect.ValueOf(value)) + return nil + } case typeNullJSON: value = null.NewJSON([]byte(fmt.Sprintf(`"%s"`, randStr(s, 1))), true) field.Set(reflect.ValueOf(value)) @@ -185,6 +232,47 @@ func randomizeField(s *Seed, field reflect.Value, fieldType string, canBeNull bo field.Set(reflect.ValueOf(value)) return nil } + if fieldType == "box" || fieldType == "line" || fieldType == "lseg" || + fieldType == "path" || fieldType == "polygon" { + value = randBox() + field.Set(reflect.ValueOf(value)) + return nil + } + if fieldType == "cidr" || fieldType == "inet" { + value = randNetAddr() + field.Set(reflect.ValueOf(value)) + return nil + } + if fieldType == "macaddr" { + value = randMacAddr() + field.Set(reflect.ValueOf(value)) + return nil + } + if fieldType == "circle" { + value = randCircle() + field.Set(reflect.ValueOf(value)) + return nil + } + if fieldType == "pg_lsn" { + value = randLsn() + field.Set(reflect.ValueOf(value)) + return nil + } + if fieldType == "point" { + value = randPoint() + field.Set(reflect.ValueOf(value)) + return nil + } + if fieldType == "txid_snapshot" { + value = randTxID() + field.Set(reflect.ValueOf(value)) + return nil + } + if fieldType == "money" { + value = randMoney(s) + field.Set(reflect.ValueOf(value)) + return nil + } } switch typ { case typeJSON: @@ -416,3 +504,66 @@ func randByteSlice(s *Seed, ln int) []byte { return str } + +func randPoint() string { + a := rand.Intn(100) + b := a + 1 + return fmt.Sprintf("(%d,%d)", a, b) +} + +func randBox() string { + a := rand.Intn(100) + b := a + 1 + c := a + 2 + d := a + 3 + return fmt.Sprintf("(%d,%d),(%d,%d)", a, b, c, d) +} + +func randCircle() string { + a, b, c := rand.Intn(100), rand.Intn(100), rand.Intn(100) + return fmt.Sprintf("((%d,%d),%d)", a, b, c) +} + +func randNetAddr() string { + return fmt.Sprintf( + "%d.%d.%d.%d", + rand.Intn(254)+1, + rand.Intn(254)+1, + rand.Intn(254)+1, + rand.Intn(254)+1, + ) +} + +func randMacAddr() string { + buf := make([]byte, 6) + _, err := rand.Read(buf) + if err != nil { + panic(err) + } + + // Set the local bit + buf[0] |= 2 + return fmt.Sprintf( + "%02x:%02x:%02x:%02x:%02x:%02x", + buf[0], buf[1], buf[2], buf[3], buf[4], buf[5], + ) +} + +func randLsn() string { + a := rand.Int63n(9000000) + b := rand.Int63n(9000000) + return fmt.Sprintf("%d/%d", a, b) +} + +func randTxID() string { + // Order of integers is relevant + a := rand.Intn(200) + 100 + b := a + 100 + c := a + d := a + 50 + return fmt.Sprintf("%d:%d:%d,%d", a, b, c, d) +} + +func randMoney(s *Seed) string { + return fmt.Sprintf("%d.00", s.nextInt()) +} From ac42fbc2c7d968dc3a3aaba61c23ca0437ca7d45 Mon Sep 17 00:00:00 2001 From: Patrick O'brien Date: Fri, 9 Sep 2016 01:22:22 +1000 Subject: [PATCH 07/64] Add json to uppercase words --- strmangle/strmangle.go | 1 + 1 file changed, 1 insertion(+) diff --git a/strmangle/strmangle.go b/strmangle/strmangle.go index 822891963..86cab339e 100644 --- a/strmangle/strmangle.go +++ b/strmangle/strmangle.go @@ -22,6 +22,7 @@ var uppercaseWords = map[string]struct{}{ "id": {}, "uid": {}, "uuid": {}, + "json": {}, } func init() { From 8d486ef51be3f6a2a9af48bde1391258c6dcf153 Mon Sep 17 00:00:00 2001 From: Patrick O'brien Date: Fri, 9 Sep 2016 03:39:27 +1000 Subject: [PATCH 08/64] Add schema FAQ to readme --- README.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/README.md b/README.md index dbadf9ec2..2bc88e20b 100644 --- a/README.md +++ b/README.md @@ -1053,6 +1053,12 @@ The generated models might import a couple of packages that are not on your syst `cd` into your generated models directory and type `go get -u -t` to fetch them. You will only need to run this command once, not per generation. +#### How should I handle multiple schemas? + +If your database uses multiple schemas you should generate a new package for each of your schemas. +Note that this only applies to databases that use real, SQL standard schemas (like PostgreSQL), not +fake schemas (like MySQL). + ## Benchmarks If you'd like to run the benchmarks yourself check out our [boilbench](https://github.com/vattle/boilbench) repo. From 3929729a2c7ebe5af92c2b79255aa66b46bceb33 Mon Sep 17 00:00:00 2001 From: Patrick O'brien Date: Fri, 9 Sep 2016 03:43:36 +1000 Subject: [PATCH 09/64] Add schema flags to readme --- README.md | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 2bc88e20b..485fc256a 100644 --- a/README.md +++ b/README.md @@ -223,6 +223,7 @@ not to pass them through the command line or environment variables: | Name | Default | | --- | --- | | basedir | none | +| schema | "public" | | pkgname | "models" | | output | "models" | | whitelist | [ ] | @@ -261,15 +262,16 @@ sqlboiler postgres Flags: -b, --basedir string The base directory has the templates and templates_test folders - -d, --debug Debug mode prints stack traces on error -w, --whitelist stringSlice Only include these tables in your generated package -x, --exclude stringSlice Tables to be excluded from the generated package + -s, --schema string The name of your database schema, for databases that support real schemas (default "public") + -p, --pkgname string The name you wish to assign to your generated package (default "models") + -o, --output string The name of the folder to output to (default "models") + -t, --tag stringSlice Struct tags to be included on your models in addition to json, yaml, toml + -d, --debug Debug mode prints stack traces on error --no-auto-timestamps Disable automatic timestamps for created_at/updated_at --no-hooks Disable hooks feature for your models --no-tests Disable generated go test files - -o, --output string The name of the folder to output to (default "models") - -p, --pkgname string The name you wish to assign to your generated package (default "models") - -t, --tag stringSlice Struct tags to be included on your models in addition to json, yaml, toml ``` Follow the steps below to do some basic model generation. Once we've generated From 1c8a9d2e398dea22c6e6083f08fcf235834f6710 Mon Sep 17 00:00:00 2001 From: Patrick O'brien Date: Fri, 9 Sep 2016 07:23:10 +1000 Subject: [PATCH 10/64] Add schema feature to everything (except rels) * Add strmangle SchemaTable helper --- README.md | 2 +- bdb/drivers/mock.go | 8 +++--- bdb/drivers/postgres.go | 30 ++++++++++---------- bdb/interface.go | 22 +++++++------- bdb/table.go | 7 +++-- config.go | 1 + main.go | 2 ++ sqlboiler.go | 8 ++++-- strmangle/strmangle.go | 12 ++++++++ templates.go | 2 ++ templates/04_relationship_to_one.tpl | 5 ++-- templates/05_relationship_to_many.tpl | 6 ++-- templates/06_relationship_to_one_eager.tpl | 7 +++-- templates/07_relationship_to_many_eager.tpl | 4 +-- templates/09_relationship_to_many_setops.tpl | 8 +++--- templates/10_all.tpl | 2 +- templates/11_find.tpl | 2 +- templates/12_insert.tpl | 4 +-- templates/13_update.tpl | 4 +-- templates/15_delete.tpl | 4 +-- templates/16_reload.tpl | 2 +- templates/17_exists.tpl | 2 +- templates_test/relationship_to_many.tpl | 4 +-- 23 files changed, 88 insertions(+), 60 deletions(-) diff --git a/README.md b/README.md index 485fc256a..eb6576a8d 100644 --- a/README.md +++ b/README.md @@ -1065,7 +1065,7 @@ fake schemas (like MySQL). If you'd like to run the benchmarks yourself check out our [boilbench](https://github.com/vattle/boilbench) repo. -Here are the results **(lower is better)**: +Here are the results (lower is better): `go test -bench . -benchmem` ``` diff --git a/bdb/drivers/mock.go b/bdb/drivers/mock.go index 93d18cc45..4371e6f63 100644 --- a/bdb/drivers/mock.go +++ b/bdb/drivers/mock.go @@ -9,7 +9,7 @@ import ( type MockDriver struct{} // TableNames returns a list of mock table names -func (m *MockDriver) TableNames(whitelist, exclude []string) ([]string, error) { +func (m *MockDriver) TableNames(schema string, whitelist, exclude []string) ([]string, error) { if len(whitelist) > 0 { return whitelist, nil } @@ -18,7 +18,7 @@ func (m *MockDriver) TableNames(whitelist, exclude []string) ([]string, error) { } // Columns returns a list of mock columns -func (m *MockDriver) Columns(tableName string) ([]bdb.Column, error) { +func (m *MockDriver) Columns(schema, tableName string) ([]bdb.Column, error) { return map[string][]bdb.Column{ "pilots": { {Name: "id", Type: "int", DBType: "integer"}, @@ -59,7 +59,7 @@ func (m *MockDriver) Columns(tableName string) ([]bdb.Column, error) { } // ForeignKeyInfo returns a list of mock foreignkeys -func (m *MockDriver) ForeignKeyInfo(tableName string) ([]bdb.ForeignKey, error) { +func (m *MockDriver) ForeignKeyInfo(schema, tableName string) ([]bdb.ForeignKey, error) { return map[string][]bdb.ForeignKey{ "jets": { {Table: "jets", Name: "jets_pilot_id_fk", Column: "pilot_id", ForeignTable: "pilots", ForeignColumn: "id", ForeignColumnUnique: true}, @@ -82,7 +82,7 @@ func (m *MockDriver) TranslateColumnType(c bdb.Column) bdb.Column { } // PrimaryKeyInfo returns mock primary key info for the passed in table name -func (m *MockDriver) PrimaryKeyInfo(tableName string) (*bdb.PrimaryKey, error) { +func (m *MockDriver) PrimaryKeyInfo(schema, tableName string) (*bdb.PrimaryKey, error) { return map[string]*bdb.PrimaryKey{ "pilots": { Name: "pilot_id_pkey", diff --git a/bdb/drivers/postgres.go b/bdb/drivers/postgres.go index be69816a8..dace79eba 100644 --- a/bdb/drivers/postgres.go +++ b/bdb/drivers/postgres.go @@ -81,11 +81,11 @@ func (p *PostgresDriver) UseLastInsertID() bool { // TableNames connects to the postgres database and // retrieves all table names from the information_schema where the -// table schema is public. It uses a whitelist and exclude list. -func (p *PostgresDriver) TableNames(whitelist, exclude []string) ([]string, error) { +// table schema is schema. It uses a whitelist and exclude list. +func (p *PostgresDriver) TableNames(schema string, whitelist, exclude []string) ([]string, error) { var names []string - query := `select table_name from information_schema.tables where table_schema = 'public'` + query := fmt.Sprintf(`select table_name from information_schema.tables where table_schema = '%s'`, schema) if len(whitelist) > 0 { query = query + fmt.Sprintf("and table_name in ('%s');", strings.Join(whitelist, "','")) } else if len(exclude) > 0 { @@ -114,7 +114,7 @@ func (p *PostgresDriver) TableNames(whitelist, exclude []string) ([]string, erro // from the database information_schema.columns. It retrieves the column names // and column types and returns those as a []Column after TranslateColumnType() // converts the SQL types to Go types, for example: "varchar" to "string" -func (p *PostgresDriver) Columns(tableName string) ([]bdb.Column, error) { +func (p *PostgresDriver) Columns(schema, tableName string) ([]bdb.Column, error) { var columns []bdb.Column rows, err := p.dbConn.Query(` @@ -132,11 +132,11 @@ func (p *PostgresDriver) Columns(tableName string) ([]bdb.Column, error) { inner join pg_index pgi on pgi.indexrelid = pgc.oid inner join pg_attribute pga on pga.attrelid = pgi.indrelid and pga.attnum = ANY(pgi.indkey) where - pgix.schemaname = 'public' and pgix.tablename = c.table_name and pga.attname = c.column_name and pgi.indisunique = true + pgix.schemaname = $1 and pgix.tablename = c.table_name and pga.attname = c.column_name and pgi.indisunique = true )) as is_unique from information_schema.columns as c - where table_name=$1 and table_schema = 'public'; - `, tableName) + where table_name=$2 and table_schema = $3; + `, schema, tableName, schema) if err != nil { return nil, err @@ -172,16 +172,16 @@ func (p *PostgresDriver) Columns(tableName string) ([]bdb.Column, error) { } // PrimaryKeyInfo looks up the primary key for a table. -func (p *PostgresDriver) PrimaryKeyInfo(tableName string) (*bdb.PrimaryKey, error) { +func (p *PostgresDriver) PrimaryKeyInfo(schema, tableName string) (*bdb.PrimaryKey, error) { pkey := &bdb.PrimaryKey{} var err error query := ` select tc.constraint_name from information_schema.table_constraints as tc - where tc.table_name = $1 and tc.constraint_type = 'PRIMARY KEY' and tc.table_schema = 'public';` + where tc.table_name = $1 and tc.constraint_type = 'PRIMARY KEY' and tc.table_schema = $2;` - row := p.dbConn.QueryRow(query, tableName) + row := p.dbConn.QueryRow(query, tableName, schema) if err = row.Scan(&pkey.Name); err != nil { if err == sql.ErrNoRows { return nil, nil @@ -192,10 +192,10 @@ func (p *PostgresDriver) PrimaryKeyInfo(tableName string) (*bdb.PrimaryKey, erro queryColumns := ` select kcu.column_name from information_schema.key_column_usage as kcu - where constraint_name = $1 and table_schema = 'public';` + where constraint_name = $1 and table_schema = $2;` var rows *sql.Rows - if rows, err = p.dbConn.Query(queryColumns, pkey.Name); err != nil { + if rows, err = p.dbConn.Query(queryColumns, pkey.Name, schema); err != nil { return nil, err } defer rows.Close() @@ -222,7 +222,7 @@ func (p *PostgresDriver) PrimaryKeyInfo(tableName string) (*bdb.PrimaryKey, erro } // ForeignKeyInfo retrieves the foreign keys for a given table name. -func (p *PostgresDriver) ForeignKeyInfo(tableName string) ([]bdb.ForeignKey, error) { +func (p *PostgresDriver) ForeignKeyInfo(schema, tableName string) ([]bdb.ForeignKey, error) { var fkeys []bdb.ForeignKey query := ` @@ -235,11 +235,11 @@ func (p *PostgresDriver) ForeignKeyInfo(tableName string) ([]bdb.ForeignKey, err from information_schema.table_constraints as tc inner join information_schema.key_column_usage as kcu ON tc.constraint_name = kcu.constraint_name inner join information_schema.constraint_column_usage as ccu ON tc.constraint_name = ccu.constraint_name - where tc.table_name = $1 and tc.constraint_type = 'FOREIGN KEY' and tc.table_schema = 'public';` + where tc.table_name = $1 and tc.constraint_type = 'FOREIGN KEY' and tc.table_schema = $2;` var rows *sql.Rows var err error - if rows, err = p.dbConn.Query(query, tableName); err != nil { + if rows, err = p.dbConn.Query(query, tableName, schema); err != nil { return nil, err } diff --git a/bdb/interface.go b/bdb/interface.go index b59dbe1e4..a7f14e09a 100644 --- a/bdb/interface.go +++ b/bdb/interface.go @@ -6,10 +6,10 @@ import "github.com/pkg/errors" // Interface for a database driver. Functionality required to support a specific // database type (eg, MySQL, Postgres etc.) type Interface interface { - TableNames(whitelist, exclude []string) ([]string, error) - Columns(tableName string) ([]Column, error) - PrimaryKeyInfo(tableName string) (*PrimaryKey, error) - ForeignKeyInfo(tableName string) ([]ForeignKey, error) + TableNames(schema string, whitelist, exclude []string) ([]string, error) + Columns(schema, tableName string) ([]Column, error) + PrimaryKeyInfo(schema, tableName string) (*PrimaryKey, error) + ForeignKeyInfo(schema, tableName string) ([]ForeignKey, error) // TranslateColumnType takes a Database column type and returns a go column type. TranslateColumnType(Column) Column @@ -26,19 +26,21 @@ type Interface interface { // Tables returns the metadata for all tables, minus the tables // specified in the exclude slice. -func Tables(db Interface, whitelist, exclude []string) ([]Table, error) { +func Tables(db Interface, schema string, whitelist, exclude []string) ([]Table, error) { var err error - names, err := db.TableNames(whitelist, exclude) + names, err := db.TableNames(schema, whitelist, exclude) if err != nil { return nil, errors.Wrap(err, "unable to get table names") } var tables []Table for _, name := range names { - t := Table{Name: name} + t := Table{ + Name: name, + } - if t.Columns, err = db.Columns(name); err != nil { + if t.Columns, err = db.Columns(schema, name); err != nil { return nil, errors.Wrapf(err, "unable to fetch table column info (%s)", name) } @@ -46,11 +48,11 @@ func Tables(db Interface, whitelist, exclude []string) ([]Table, error) { t.Columns[i] = db.TranslateColumnType(c) } - if t.PKey, err = db.PrimaryKeyInfo(name); err != nil { + if t.PKey, err = db.PrimaryKeyInfo(schema, name); err != nil { return nil, errors.Wrapf(err, "unable to fetch table pkey info (%s)", name) } - if t.FKeys, err = db.ForeignKeyInfo(name); err != nil { + if t.FKeys, err = db.ForeignKeyInfo(schema, name); err != nil { return nil, errors.Wrapf(err, "unable to fetch table fkey info (%s)", name) } diff --git a/bdb/table.go b/bdb/table.go index 82e872d7e..28fa6248c 100644 --- a/bdb/table.go +++ b/bdb/table.go @@ -4,8 +4,11 @@ import "fmt" // Table metadata from the database schema. type Table struct { - Name string - Columns []Column + Name string + // For dbs with real schemas, like Postgres. + // Example value: "schema_name"."table_name" + SchemaName string + Columns []Column PKey *PrimaryKey FKeys []ForeignKey diff --git a/config.go b/config.go index 662747dfc..2c5a2b3e7 100644 --- a/config.go +++ b/config.go @@ -3,6 +3,7 @@ package main // Config for the running of the commands type Config struct { DriverName string + Schema string PkgName string OutFolder string BaseDir string diff --git a/main.go b/main.go index 058e617e0..2b7992332 100644 --- a/main.go +++ b/main.go @@ -61,6 +61,7 @@ func main() { // Set up the cobra root command flags rootCmd.PersistentFlags().StringP("output", "o", "models", "The name of the folder to output to") + rootCmd.PersistentFlags().StringP("schema", "s", "public", "The name of your database schema, for databases that support real schemas") rootCmd.PersistentFlags().StringP("pkgname", "p", "models", "The name you wish to assign to your generated package") rootCmd.PersistentFlags().StringP("basedir", "b", "", "The base directory has the templates and templates_test folders") rootCmd.PersistentFlags().StringSliceP("exclude", "x", nil, "Tables to be excluded from the generated package") @@ -108,6 +109,7 @@ func preRun(cmd *cobra.Command, args []string) error { cmdConfig = &Config{ DriverName: driverName, OutFolder: viper.GetString("output"), + Schema: viper.GetString("schema"), PkgName: viper.GetString("pkgname"), Debug: viper.GetBool("debug"), NoTests: viper.GetBool("no-tests"), diff --git a/sqlboiler.go b/sqlboiler.go index fcecb0d77..0757c3c5c 100644 --- a/sqlboiler.go +++ b/sqlboiler.go @@ -59,7 +59,7 @@ func New(config *Config) (*State, error) { return nil, errors.Wrap(err, "unable to connect to the database") } - err = s.initTables(config.WhitelistTables, config.ExcludeTables) + err = s.initTables(config.Schema, config.WhitelistTables, config.ExcludeTables) if err != nil { return nil, errors.Wrap(err, "unable to initialize tables") } @@ -96,6 +96,7 @@ func New(config *Config) (*State, error) { func (s *State) Run(includeTests bool) error { singletonData := &templateData{ Tables: s.Tables, + Schema: s.Config.Schema, DriverName: s.Config.DriverName, UseLastInsertID: s.Driver.UseLastInsertID(), PkgName: s.Config.PkgName, @@ -127,6 +128,7 @@ func (s *State) Run(includeTests bool) error { data := &templateData{ Tables: s.Tables, Table: table, + Schema: s.Config.Schema, DriverName: s.Config.DriverName, UseLastInsertID: s.Driver.UseLastInsertID(), PkgName: s.Config.PkgName, @@ -239,9 +241,9 @@ func (s *State) initDriver(driverName string) error { } // initTables retrieves all "public" schema table names from the database. -func (s *State) initTables(whitelist, exclude []string) error { +func (s *State) initTables(schema string, whitelist, exclude []string) error { var err error - s.Tables, err = bdb.Tables(s.Driver, whitelist, exclude) + s.Tables, err = bdb.Tables(s.Driver, schema, whitelist, exclude) if err != nil { return errors.Wrap(err, "unable to fetch table data") } diff --git a/strmangle/strmangle.go b/strmangle/strmangle.go index 86cab339e..5bda771e9 100644 --- a/strmangle/strmangle.go +++ b/strmangle/strmangle.go @@ -34,6 +34,18 @@ func init() { boilRuleset = newBoilRuleset() } +// SchemaTable returns a table name with a schema prefixed if +// using a database that supports real schemas, for example, +// for Postgres: "schema_name"."table_name", versus +// simply "table_name" for MySQL (because it does not support real schemas) +func SchemaTable(driver string, schema string, table string) string { + if driver == "postgres" { + return fmt.Sprintf(`"%s"."%s"`, schema, table) + } + + return fmt.Sprintf(`"%s"`, table) +} + // IdentQuote attempts to quote simple identifiers in SQL tatements func IdentQuote(s string) string { if strings.ToLower(s) == "null" { diff --git a/templates.go b/templates.go index e760711d7..8d171f2cc 100644 --- a/templates.go +++ b/templates.go @@ -15,6 +15,7 @@ import ( type templateData struct { Tables []bdb.Table Table bdb.Table + Schema string DriverName string UseLastInsertID bool PkgName string @@ -141,6 +142,7 @@ var templateFunctions = template.FuncMap{ // Database related mangling "whereClause": strmangle.WhereClause, + "schemaTable": strmangle.SchemaTable, // Text helpers "textsFromForeignKey": textsFromForeignKey, diff --git a/templates/04_relationship_to_one.tpl b/templates/04_relationship_to_one.tpl index 5081f397d..9b75a260c 100644 --- a/templates/04_relationship_to_one.tpl +++ b/templates/04_relationship_to_one.tpl @@ -14,12 +14,13 @@ func ({{.Function.Receiver}} *{{.LocalTable.NameGo}}) {{.Function.Name}}(exec bo queryMods = append(queryMods, mods...) query := {{.ForeignTable.NamePluralGo}}(exec, queryMods...) - boil.SetFrom(query.Query, "{{.ForeignTable.Name}}") + boil.SetFrom(query.Query, `{{schemaTable .DriverName .Schema .ForeignTable.Name}}`) return query } +{{end -}}{{/* end define */}} -{{end -}} +{{/* Begin execution of template for one-to-one relationship. */}} {{- if .Table.IsJoinTable -}} {{- else -}} {{- $dot := . -}} diff --git a/templates/05_relationship_to_many.tpl b/templates/05_relationship_to_many.tpl index e886258ef..dd08d32c6 100644 --- a/templates/05_relationship_to_many.tpl +++ b/templates/05_relationship_to_many.tpl @@ -5,8 +5,10 @@ {{- range .Table.ToManyRelationships -}} {{- $varNameSingular := .ForeignTable | singular | camelCase -}} {{- if (and .ForeignColumnUnique (not .ToJoinTable)) -}} +{{- /* Begin execution of template for many-to-one relationship. */ -}} {{- template "relationship_to_one_helper" (textsFromOneToOneRelationship $dot.PkgName $dot.Tables $table .) -}} {{- else -}} +{{- /* Begin execution of template for many-to-many relationship. */ -}} {{- $rel := textsFromRelationship $dot.Tables $table . -}} // {{$rel.Function.Name}}G retrieves all the {{$rel.LocalTable.NameSingular}}'s {{$rel.ForeignTable.NameHumanReadable}} {{- if not (eq $rel.Function.Name $rel.ForeignTable.NamePluralGo)}} via {{.ForeignColumn}} column{{- end}}. @@ -27,7 +29,7 @@ func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) {{$rel.Function.Na {{if .ToJoinTable -}} queryMods = append(queryMods, - qm.InnerJoin(`"{{.JoinTable}}" as "{{id 1}}" on "{{id 0}}"."{{.ForeignColumn}}" = "{{id 1}}"."{{.JoinForeignColumn}}"`), + qm.InnerJoin(`{{schemaTable $dot.DriverName $dot.Schema .JoinTable}} as "{{id 1}}" on "{{id 0}}"."{{.ForeignColumn}}" = "{{id 1}}"."{{.JoinForeignColumn}}"`), qm.Where(`"{{id 1}}"."{{.JoinLocalColumn}}"=$1`, {{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}), ) {{else -}} @@ -37,7 +39,7 @@ func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) {{$rel.Function.Na {{end}} query := {{$rel.ForeignTable.NamePluralGo}}(exec, queryMods...) - boil.SetFrom(query.Query, `"{{.ForeignTable}}" as "{{id 0}}"`) + boil.SetFrom(query.Query, `{{schemaTable $dot.DriverName $dot.Schema .ForeignTable}} as "{{id 0}}"`) return query } diff --git a/templates/06_relationship_to_one_eager.tpl b/templates/06_relationship_to_one_eager.tpl index 8d6a514ee..e7bcef7b6 100644 --- a/templates/06_relationship_to_one_eager.tpl +++ b/templates/06_relationship_to_one_eager.tpl @@ -28,7 +28,7 @@ func ({{$varNameSingular}}L) Load{{.Function.Name}}(e boil.Executor, singular bo } query := fmt.Sprintf( - `select * from "{{.ForeignKey.ForeignTable}}" where "{{.ForeignKey.ForeignColumn}}" in (%s)`, + `select * from {{schemaTable .DriverName .Schema .ForeignKey.ForeignTable}} where "{{.ForeignKey.ForeignColumn}}" in (%s)`, strmangle.Placeholders(count, 1, 1), ) @@ -79,8 +79,9 @@ func ({{$varNameSingular}}L) Load{{.Function.Name}}(e boil.Executor, singular bo return nil } - {{- end -}} -{{end -}} + {{- end -}}{{- /* end with */ -}} +{{end -}}{{- /* end define */ -}} + {{- if .Table.IsJoinTable -}} {{- else -}} {{- $dot := . -}} diff --git a/templates/07_relationship_to_many_eager.tpl b/templates/07_relationship_to_many_eager.tpl index b8276c61c..3d9dea746 100644 --- a/templates/07_relationship_to_many_eager.tpl +++ b/templates/07_relationship_to_many_eager.tpl @@ -35,12 +35,12 @@ func ({{$varNameSingular}}L) Load{{$txt.Function.Name}}(e boil.Executor, singula {{if .ToJoinTable -}} query := fmt.Sprintf( - `select "{{id 0}}".*, "{{id 1}}"."{{.JoinLocalColumn}}" from "{{.ForeignTable}}" as "{{id 0}}" inner join "{{.JoinTable}}" as "{{id 1}}" on "{{id 0}}"."{{.ForeignColumn}}" = "{{id 1}}"."{{.JoinForeignColumn}}" where "{{id 1}}"."{{.JoinLocalColumn}}" in (%s)`, + `select "{{id 0}}".*, "{{id 1}}"."{{.JoinLocalColumn}}" from {{schemaTable .DriverName .Schema .ForeignTable}} as "{{id 0}}" inner join {{schemaTable .DriverName .Schema .JoinTable}} as "{{id 1}}" on "{{id 0}}"."{{.ForeignColumn}}" = "{{id 1}}"."{{.JoinForeignColumn}}" where "{{id 1}}"."{{.JoinLocalColumn}}" in (%s)`, strmangle.Placeholders(count, 1, 1), ) {{else -}} query := fmt.Sprintf( - `select * from "{{.ForeignTable}}" where "{{.ForeignColumn}}" in (%s)`, + `select * from {{schemaTable $dot.DriverName $dot.Schema .ForeignTable}} where "{{.ForeignColumn}}" in (%s)`, strmangle.Placeholders(count, 1, 1), ) {{end -}} diff --git a/templates/09_relationship_to_many_setops.tpl b/templates/09_relationship_to_many_setops.tpl index d2d292202..860387acc 100644 --- a/templates/09_relationship_to_many_setops.tpl +++ b/templates/09_relationship_to_many_setops.tpl @@ -37,7 +37,7 @@ func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) Add{{$rel.Function {{if .ToJoinTable -}} for _, rel := range related { - query := `insert into "{{.JoinTable}}" ({{.JoinLocalColumn}}, {{.JoinForeignColumn}}) values ($1, $2)` + query := `insert into {{schemaTable .DriverName .Schema .JoinTable}} ({{.JoinLocalColumn}}, {{.JoinForeignColumn}}) values ($1, $2)` values := []interface{}{{"{"}}{{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}, rel.{{$rel.ForeignTable.ColumnNameGo}}} if boil.DebugMode { @@ -94,10 +94,10 @@ func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) Add{{$rel.Function // Sets related.R.{{$rel.Function.ForeignName}}'s {{$rel.Function.Name}} accordingly. func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) Set{{$rel.Function.Name}}(exec boil.Executor, insert bool, related ...*{{$rel.ForeignTable.NameGo}}) error { {{if .ToJoinTable -}} - query := `delete from "{{.JoinTable}}" where "{{.JoinLocalColumn}}" = $1` + query := `delete from {{schemaTable $dot.DriverName $dot.Schema .JoinTable}} where "{{.JoinLocalColumn}}" = $1` values := []interface{}{{"{"}}{{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}} {{else -}} - query := `update "{{.ForeignTable}}" set "{{.ForeignColumn}}" = null where "{{.ForeignColumn}}" = $1` + query := `update {{schemaTable $dot.DriverName $dot.Schema .ForeignTable}} set "{{.ForeignColumn}}" = null where "{{.ForeignColumn}}" = $1` values := []interface{}{{"{"}}{{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}} {{end -}} if boil.DebugMode { @@ -138,7 +138,7 @@ func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) Remove{{$rel.Funct var err error {{if .ToJoinTable -}} query := fmt.Sprintf( - `delete from "{{.JoinTable}}" where "{{.JoinLocalColumn}}" = $1 and "{{.JoinForeignColumn}}" in (%s)`, + `delete from {{schemaTable .DriverName .Schema .JoinTable}} where "{{.JoinLocalColumn}}" = $1 and "{{.JoinForeignColumn}}" in (%s)`, strmangle.Placeholders(len(related), 1, 1), ) values := []interface{}{{"{"}}{{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}} diff --git a/templates/10_all.tpl b/templates/10_all.tpl index a9470b18a..01489254f 100644 --- a/templates/10_all.tpl +++ b/templates/10_all.tpl @@ -8,6 +8,6 @@ func {{$tableNamePlural}}G(mods ...qm.QueryMod) {{$varNameSingular}}Query { // {{$tableNamePlural}} retrieves all the records using an executor. func {{$tableNamePlural}}(exec boil.Executor, mods ...qm.QueryMod) {{$varNameSingular}}Query { - mods = append(mods, qm.From("{{.Table.Name}}")) + mods = append(mods, qm.From(`{{schemaTable .DriverName .Schema .Table.Name}}`)) return {{$varNameSingular}}Query{NewQuery(exec, mods...)} } diff --git a/templates/11_find.tpl b/templates/11_find.tpl index 987afcf99..6cbe97465 100644 --- a/templates/11_find.tpl +++ b/templates/11_find.tpl @@ -29,7 +29,7 @@ func Find{{$tableNameSingular}}(exec boil.Executor, {{$pkArgs}}, selectCols ...s sel = strings.Join(strmangle.IdentQuoteSlice(selectCols), ",") } query := fmt.Sprintf( - `select %s from "{{.Table.Name}}" where {{whereClause 1 .Table.PKey.Columns}}`, sel, + `select %s from {{schemaTable .DriverName .Schema .Table.Name}} where {{whereClause 1 .Table.PKey.Columns}}`, sel, ) q := boil.SQL(exec, query, {{$pkNames | join ", "}}) diff --git a/templates/12_insert.tpl b/templates/12_insert.tpl index aa7958dca..4300124a7 100644 --- a/templates/12_insert.tpl +++ b/templates/12_insert.tpl @@ -64,11 +64,11 @@ func (o *{{$tableNameSingular}}) Insert(exec boil.Executor, whitelist ... string if err != nil { return err } - cache.query = fmt.Sprintf(`INSERT INTO {{.Table.Name}} ("%s") VALUES (%s)`, strings.Join(wl, `","`), strmangle.Placeholders(len(wl), 1, 1)) + cache.query = fmt.Sprintf(`INSERT INTO {{schemaTable .DriverName .Schema .Table.Name}} ("%s") VALUES (%s)`, strings.Join(wl, `","`), strmangle.Placeholders(len(wl), 1, 1)) if len(cache.retMapping) != 0 { {{if .UseLastInsertID -}} - cache.retQuery = fmt.Sprintf(`SELECT %s FROM {{.Table.Name}} WHERE %s`, strings.Join(returnColumns, `","`), strmangle.WhereClause(1, {{$varNameSingular}}PrimaryKeyColumns)) + cache.retQuery = fmt.Sprintf(`SELECT %s FROM {{schemaTable .DriverName .Schema .Table.Name}} WHERE %s`, strings.Join(returnColumns, `","`), strmangle.WhereClause(1, {{$varNameSingular}}PrimaryKeyColumns)) {{else -}} cache.query += fmt.Sprintf(` RETURNING %s`, strings.Join(returnColumns, ",")) {{end -}} diff --git a/templates/13_update.tpl b/templates/13_update.tpl index 8483d113e..ef623ace2 100644 --- a/templates/13_update.tpl +++ b/templates/13_update.tpl @@ -52,7 +52,7 @@ func (o *{{$tableNameSingular}}) Update(exec boil.Executor, whitelist ... string if !cached { wl := strmangle.UpdateColumnSet({{$varNameSingular}}Columns, {{$varNameSingular}}PrimaryKeyColumns, whitelist) - cache.query = fmt.Sprintf(`UPDATE "{{.Table.Name}}" SET %s WHERE %s`, strmangle.SetParamNames(wl), strmangle.WhereClause(len(wl)+1, {{$varNameSingular}}PrimaryKeyColumns)) + cache.query = fmt.Sprintf(`UPDATE {{schemaTable .DriverName .Schema .Table.Name}} SET %s WHERE %s`, strmangle.SetParamNames(wl), strmangle.WhereClause(len(wl)+1, {{$varNameSingular}}PrimaryKeyColumns)) cache.valueMapping, err = boil.BindMapping({{$varNameSingular}}Type, {{$varNameSingular}}Mapping, append(wl, {{$varNameSingular}}PrimaryKeyColumns...)) if err != nil { return err @@ -155,7 +155,7 @@ func (o {{$tableNameSingular}}Slice) UpdateAll(exec boil.Executor, cols M) error args = append(args, o.inPrimaryKeyArgs()...) sql := fmt.Sprintf( - `UPDATE {{.Table.Name}} SET (%s) = (%s) WHERE (%s) IN (%s)`, + `UPDATE {{schemaTable .DriverName .Schema .Table.Name}} SET (%s) = (%s) WHERE (%s) IN (%s)`, strings.Join(colNames, ", "), strmangle.Placeholders(len(colNames), 1, 1), strings.Join(strmangle.IdentQuoteSlice({{$varNameSingular}}PrimaryKeyColumns), ","), diff --git a/templates/15_delete.tpl b/templates/15_delete.tpl index a3c8b5145..ebcc7448f 100644 --- a/templates/15_delete.tpl +++ b/templates/15_delete.tpl @@ -43,7 +43,7 @@ func (o *{{$tableNameSingular}}) Delete(exec boil.Executor) error { args := o.inPrimaryKeyArgs() - sql := `DELETE FROM {{.Table.Name}} WHERE {{whereClause 1 .Table.PKey.Columns}}` + sql := `DELETE FROM {{schemaTable .DriverName .Schema .Table.Name}} WHERE {{whereClause 1 .Table.PKey.Columns}}` if boil.DebugMode { fmt.Fprintln(boil.DebugWriter, sql) @@ -132,7 +132,7 @@ func (o {{$tableNameSingular}}Slice) DeleteAll(exec boil.Executor) error { args := o.inPrimaryKeyArgs() sql := fmt.Sprintf( - `DELETE FROM {{.Table.Name}} WHERE (%s) IN (%s)`, + `DELETE FROM {{schemaTable .DriverName .Schema .Table.Name}} WHERE (%s) IN (%s)`, strings.Join(strmangle.IdentQuoteSlice({{$varNameSingular}}PrimaryKeyColumns), ","), strmangle.Placeholders(len(o) * len({{$varNameSingular}}PrimaryKeyColumns), 1, len({{$varNameSingular}}PrimaryKeyColumns)), ) diff --git a/templates/16_reload.tpl b/templates/16_reload.tpl index 38e83067c..f2d18a9db 100644 --- a/templates/16_reload.tpl +++ b/templates/16_reload.tpl @@ -67,7 +67,7 @@ func (o *{{$tableNameSingular}}Slice) ReloadAll(exec boil.Executor) error { args := o.inPrimaryKeyArgs() sql := fmt.Sprintf( - `SELECT {{.Table.Name}}.* FROM {{.Table.Name}} WHERE (%s) IN (%s)`, + `SELECT {{schemaTable .DriverName .Schema .Table.Name}}.* FROM {{schemaTable .DriverName .Schema .Table.Name}} WHERE (%s) IN (%s)`, strings.Join(strmangle.IdentQuoteSlice({{$varNameSingular}}PrimaryKeyColumns), ","), strmangle.Placeholders(len(*o) * len({{$varNameSingular}}PrimaryKeyColumns), 1, len({{$varNameSingular}}PrimaryKeyColumns)), ) diff --git a/templates/17_exists.tpl b/templates/17_exists.tpl index 899bec096..68b972906 100644 --- a/templates/17_exists.tpl +++ b/templates/17_exists.tpl @@ -6,7 +6,7 @@ func {{$tableNameSingular}}Exists(exec boil.Executor, {{$pkArgs}}) (bool, error) { var exists bool - sql := `select exists(select 1 from "{{.Table.Name}}" where {{whereClause 1 .Table.PKey.Columns}} limit 1)` + sql := `select exists(select 1 from {{schemaTable .DriverName .Schema .Table.Name}} where {{whereClause 1 .Table.PKey.Columns}} limit 1)` if boil.DebugMode { fmt.Fprintln(boil.DebugWriter, sql) diff --git a/templates_test/relationship_to_many.tpl b/templates_test/relationship_to_many.tpl index cb9ee364a..bf74cef9c 100644 --- a/templates_test/relationship_to_many.tpl +++ b/templates_test/relationship_to_many.tpl @@ -41,11 +41,11 @@ func test{{$rel.LocalTable.NameGo}}ToMany{{$rel.Function.Name}}(t *testing.T) { } {{if .ToJoinTable -}} - _, err = tx.Exec(`insert into "{{.JoinTable}}" ({{.JoinLocalColumn}}, {{.JoinForeignColumn}}) values ($1, $2)`, a.{{$rel.LocalTable.ColumnNameGo}}, b.{{$rel.ForeignTable.ColumnNameGo}}) + _, err = tx.Exec(`insert into "{{schemaTable .DriverName .Schema .JoinTable}}" ({{.JoinLocalColumn}}, {{.JoinForeignColumn}}) values ($1, $2)`, a.{{$rel.LocalTable.ColumnNameGo}}, b.{{$rel.ForeignTable.ColumnNameGo}}) if err != nil { t.Fatal(err) } - _, err = tx.Exec(`insert into "{{.JoinTable}}" ({{.JoinLocalColumn}}, {{.JoinForeignColumn}}) values ($1, $2)`, a.{{$rel.LocalTable.ColumnNameGo}}, c.{{$rel.ForeignTable.ColumnNameGo}}) + _, err = tx.Exec(`insert into "{{schemaTable .DriverName .Schema .JoinTable}}" ({{.JoinLocalColumn}}, {{.JoinForeignColumn}}) values ($1, $2)`, a.{{$rel.LocalTable.ColumnNameGo}}, c.{{$rel.ForeignTable.ColumnNameGo}}) if err != nil { t.Fatal(err) } From 6224b1c4632579a05a7a33f13ca903ba6d52e4e2 Mon Sep 17 00:00:00 2001 From: Aaron L Date: Thu, 8 Sep 2016 20:41:50 -0700 Subject: [PATCH 11/64] Move globals away so we can make mysql driver --- bdb/drivers/postgres.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/bdb/drivers/postgres.go b/bdb/drivers/postgres.go index dace79eba..9fc643081 100644 --- a/bdb/drivers/postgres.go +++ b/bdb/drivers/postgres.go @@ -19,7 +19,7 @@ type PostgresDriver struct { } // validatedTypes are types that cannot be zero values in the database. -var validatedTypes = []string{"uuid"} +var psqlValidatedTypes = []string{"uuid"} // NewPostgresDriver takes the database connection details as parameters and // returns a pointer to a PostgresDriver object. Note that it is required to @@ -27,14 +27,14 @@ var validatedTypes = []string{"uuid"} // the database connection once an object has been obtained. func NewPostgresDriver(user, pass, dbname, host string, port int, sslmode string) *PostgresDriver { driver := PostgresDriver{ - connStr: BuildQueryString(user, pass, dbname, host, port, sslmode), + connStr: PostgresBuildQueryString(user, pass, dbname, host, port, sslmode), } return &driver } -// BuildQueryString for Postgres -func BuildQueryString(user, pass, dbname, host string, port int, sslmode string) string { +// PostgresBuildQueryString builds a query string. +func PostgresBuildQueryString(user, pass, dbname, host string, port int, sslmode string) string { parts := []string{} if len(user) != 0 { parts = append(parts, fmt.Sprintf("user=%s", user)) @@ -163,7 +163,7 @@ func (p *PostgresDriver) Columns(schema, tableName string) ([]bdb.Column, error) Default: colDefault, Nullable: nullable == "YES", Unique: unique, - Validated: isValidated(colType), + Validated: psqlIsValidated(colType), } columns = append(columns, column) } @@ -323,8 +323,8 @@ func (p *PostgresDriver) TranslateColumnType(c bdb.Column) bdb.Column { } // isValidated checks if the database type is in the validatedTypes list. -func isValidated(typ string) bool { - for _, v := range validatedTypes { +func psqlIsValidated(typ string) bool { + for _, v := range psqlValidatedTypes { if v == typ { return true } From 0eac708c56ded298bdc940f77505fe140e55ee7c Mon Sep 17 00:00:00 2001 From: Aaron L Date: Thu, 8 Sep 2016 20:42:02 -0700 Subject: [PATCH 12/64] Fix bad SQL parameterization --- bdb/drivers/postgres.go | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/bdb/drivers/postgres.go b/bdb/drivers/postgres.go index 9fc643081..f88be8a50 100644 --- a/bdb/drivers/postgres.go +++ b/bdb/drivers/postgres.go @@ -9,6 +9,7 @@ import ( _ "github.com/lib/pq" "github.com/pkg/errors" "github.com/vattle/sqlboiler/bdb" + "github.com/vattle/sqlboiler/strmangle" ) // PostgresDriver holds the database connection string and a handle @@ -85,14 +86,21 @@ func (p *PostgresDriver) UseLastInsertID() bool { func (p *PostgresDriver) TableNames(schema string, whitelist, exclude []string) ([]string, error) { var names []string - query := fmt.Sprintf(`select table_name from information_schema.tables where table_schema = '%s'`, schema) + query := fmt.Sprintf(`select table_name from information_schema.tables where table_schema = ?`) + args := []interface{}{schema} if len(whitelist) > 0 { - query = query + fmt.Sprintf("and table_name in ('%s');", strings.Join(whitelist, "','")) + query += fmt.Sprintf("and table_name in (%s);", strmangle.Placeholders(len(whitelist), 1, 1)) + for _, w := range whitelist { + args = append(args, w) + } } else if len(exclude) > 0 { - query = query + fmt.Sprintf("and table_name not in ('%s');", strings.Join(exclude, "','")) + query += fmt.Sprintf("and table_name not in (%s);", strmangle.Placeholders(len(exclude), 1, 1)) + for _, e := range exclude { + args = append(args, e) + } } - rows, err := p.dbConn.Query(query) + rows, err := p.dbConn.Query(query, args...) if err != nil { return nil, err @@ -135,8 +143,8 @@ func (p *PostgresDriver) Columns(schema, tableName string) ([]bdb.Column, error) pgix.schemaname = $1 and pgix.tablename = c.table_name and pga.attname = c.column_name and pgi.indisunique = true )) as is_unique from information_schema.columns as c - where table_name=$2 and table_schema = $3; - `, schema, tableName, schema) + where table_name=$2 and table_schema = $1; + `, schema, tableName) if err != nil { return nil, err From 97d6636da44b2c51360eb3f031bae51f9c8ea392 Mon Sep 17 00:00:00 2001 From: Aaron L Date: Thu, 8 Sep 2016 22:22:28 -0700 Subject: [PATCH 13/64] Add a MySQL driver --- bdb/drivers/mysql.go | 321 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 321 insertions(+) create mode 100644 bdb/drivers/mysql.go diff --git a/bdb/drivers/mysql.go b/bdb/drivers/mysql.go new file mode 100644 index 000000000..5b5043f46 --- /dev/null +++ b/bdb/drivers/mysql.go @@ -0,0 +1,321 @@ +package drivers + +import ( + "database/sql" + "fmt" + "strconv" + "strings" + + "github.com/go-sql-driver/mysql" + "github.com/pkg/errors" + "github.com/vattle/sqlboiler/bdb" +) + +// MySQLDriver holds the database connection string and a handle +// to the database connection. +type MySQLDriver struct { + connStr string + dbConn *sql.DB +} + +// NewMySQLDriver takes the database connection details as parameters and +// returns a pointer to a MySQLDriver object. Note that it is required to +// call MySQLDriver.Open() and MySQLDriver.Close() to open and close +// the database connection once an object has been obtained. +func NewMySQLDriver(user, pass, dbname, host string, port int, sslmode string) *MySQLDriver { + driver := MySQLDriver{ + connStr: MySQLBuildQueryString(user, pass, dbname, host, port, sslmode), + } + + return &driver +} + +// MySQLBuildQueryString builds a query string for MySQL. +func MySQLBuildQueryString(user, pass, dbname, host string, port int, sslmode string) string { + var config mysql.Config + + config.User = user + if len(pass) != 0 { + config.Passwd = pass + } + config.DBName = dbname + config.Net = "tcp" + config.Addr = host + if port == 0 { + port = 3306 + } + config.Addr += ":" + strconv.Itoa(port) + config.TLSConfig = sslmode + + return config.FormatDSN() +} + +// Open opens the database connection using the connection string +func (m *MySQLDriver) Open() error { + var err error + m.dbConn, err = sql.Open("mysql", m.connStr) + if err != nil { + return err + } + + return nil +} + +// Close closes the database connection +func (m *MySQLDriver) Close() { + m.dbConn.Close() +} + +// UseLastInsertID returns false for postgres +func (m *MySQLDriver) UseLastInsertID() bool { + return true +} + +// TableNames connects to the postgres database and +// retrieves all table names from the information_schema where the +// table schema is public. It excludes common migration tool tables +// such as gorp_migrations +func (m *MySQLDriver) TableNames(schema string, whitelist, exclude []string) ([]string, error) { + var names []string + + query := fmt.Sprintf(`select table_name from information_schema.tables where table_schema = ?`) + args := []interface{}{schema} + if len(whitelist) > 0 { + query += fmt.Sprintf("and table_name in (%s);", strings.Repeat(",?", len(whitelist))[1:]) + for _, w := range whitelist { + args = append(args, w) + } + } else if len(exclude) > 0 { + query += fmt.Sprintf("and table_name not in (%s);", strings.Repeat(",?", len(exclude))[1:]) + for _, e := range exclude { + args = append(args, e) + } + } + + rows, err := m.dbConn.Query(query, args...) + + if err != nil { + return nil, err + } + + defer rows.Close() + for rows.Next() { + var name string + if err := rows.Scan(&name); err != nil { + return nil, err + } + names = append(names, name) + } + + return names, nil +} + +// Columns takes a table name and attempts to retrieve the table information +// from the database information_schema.columns. It retrieves the column names +// and column types and returns those as a []Column after TranslateColumnType() +// converts the SQL types to Go types, for example: "varchar" to "string" +func (m *MySQLDriver) Columns(schema, tableName string) ([]bdb.Column, error) { + var columns []bdb.Column + + rows, err := m.dbConn.Query(` + select column_name, data_type, column_default, is_nullable, + exists ( + select c.column_name + from information_schema.table_constraints tc + inner join information_schema.key_column_usage kcu + on tc.constraint_name = kcu.constraint_name and tc.table_name = kcu.table_name and tc.table_schema = kcu.table_schema + where c.column_name = kcu.column_name and tc.table_name = c.table_name and + (tc.constraint_type = 'PRIMARY KEY' or tc.constraint_type = 'UNIQUE') + ) as is_unique + from information_schema.columns as c + where table_name = ? and table_schema = ?; + `, tableName, schema) + + if err != nil { + return nil, err + } + defer rows.Close() + + for rows.Next() { + var colName, colType, colDefault, nullable string + var unique bool + var defaultPtr *string + if err := rows.Scan(&colName, &colType, &defaultPtr, &nullable, &unique); err != nil { + return nil, errors.Wrapf(err, "unable to scan for table %s", tableName) + } + + if defaultPtr != nil && *defaultPtr != "NULL" { + colDefault = *defaultPtr + } + + column := bdb.Column{ + Name: colName, + DBType: colType, + Default: colDefault, + Nullable: nullable == "YES", + Unique: unique, + Validated: psqlIsValidated(colType), + } + columns = append(columns, column) + } + + return columns, nil +} + +// PrimaryKeyInfo looks up the primary key for a table. +func (m *MySQLDriver) PrimaryKeyInfo(schema, tableName string) (*bdb.PrimaryKey, error) { + pkey := &bdb.PrimaryKey{} + var err error + + query := ` + select tc.constraint_name + from information_schema.table_constraints as tc + where tc.table_name = ? and tc.constraint_type = 'PRIMARY KEY' and tc.table_schema = ?;` + + row := m.dbConn.QueryRow(query, tableName, schema) + if err = row.Scan(&pkey.Name); err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, err + } + + queryColumns := ` + select kcu.column_name + from information_schema.key_column_usage as kcu + where table_name = ? and constraint_name = ? and table_schema = ?;` + + var rows *sql.Rows + if rows, err = m.dbConn.Query(queryColumns, tableName, pkey.Name, schema); err != nil { + return nil, err + } + defer rows.Close() + + var columns []string + for rows.Next() { + var column string + + err = rows.Scan(&column) + if err != nil { + return nil, err + } + + columns = append(columns, column) + } + + if err = rows.Err(); err != nil { + return nil, err + } + + pkey.Columns = columns + + return pkey, nil +} + +// ForeignKeyInfo retrieves the foreign keys for a given table name. +func (m *MySQLDriver) ForeignKeyInfo(schema, tableName string) ([]bdb.ForeignKey, error) { + var fkeys []bdb.ForeignKey + + query := ` + select constraint_name, table_name, column_name, referenced_table_name, referenced_column_name + from information_schema.key_column_usage + where table_schema = ? and referenced_table_schema = ? and table_name = ? + ` + + var rows *sql.Rows + var err error + if rows, err = m.dbConn.Query(query, schema, schema, tableName); err != nil { + return nil, err + } + + for rows.Next() { + var fkey bdb.ForeignKey + var sourceTable string + + fkey.Table = tableName + err = rows.Scan(&fkey.Name, &sourceTable, &fkey.Column, &fkey.ForeignTable, &fkey.ForeignColumn) + if err != nil { + return nil, err + } + + fkeys = append(fkeys, fkey) + } + + if err = rows.Err(); err != nil { + return nil, err + } + + return fkeys, nil +} + +// TranslateColumnType converts postgres database types to Go types, for example +// "varchar" to "string" and "bigint" to "int64". It returns this parsed data +// as a Column object. +func (m *MySQLDriver) TranslateColumnType(c bdb.Column) bdb.Column { + if c.Nullable { + switch c.DBType { + case "tinyint": + c.Type = "null.Int8" + case "smallint": + c.Type = "null.Int16" + case "mediumint", "int", "integer": + c.Type = "null.Int" + case "bigint": + c.Type = "null.Int64" + case "float": + c.Type = "null.Float32" + case "double", "double precision", "real": + c.Type = "null.Float64" + case "boolean", "bool": + c.Type = "null.Bool" + case "date", "datetime", "timestamp", "time": + c.Type = "null.Time" + case "binary", "varbinary", "tinyblob", "blob", "mediumblob", "longblob": + c.Type = "null.Bytes" + case "json": + c.Type = "types.JSON" + default: + c.Type = "null.String" + } + } else { + switch c.DBType { + case "tinyint": + c.Type = "int8" + case "smallint": + c.Type = "int16" + case "mediumint", "int", "integer": + c.Type = "int" + case "bigint": + c.Type = "null.Int64" + case "float": + c.Type = "float32" + case "double", "double precision", "real": + c.Type = "float64" + case "boolean", "bool": + c.Type = "bool" + case "date", "datetime", "timestamp", "time": + c.Type = "time.Time" + case "binary", "varbinary", "tinyblob", "blob", "mediumblob", "longblob": + c.Type = "[]byte" + case "json": + c.Type = "types.JSON" + default: + c.Type = "string" + } + } + + return c +} + +var mySQLValidatedTypes = []string{} + +// isValidated checks if the database type is in the validatedTypes list. +func mySQLIsValidated(typ string) bool { + for _, v := range mySQLValidatedTypes { + if v == typ { + return true + } + } + + return false +} From c65c1f6a2cd3e034ec8a2199afe0354d7411ca14 Mon Sep 17 00:00:00 2001 From: Aaron L Date: Thu, 8 Sep 2016 22:37:54 -0700 Subject: [PATCH 14/64] Fix a mistaken panic on bad parameters --- main.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/main.go b/main.go index 2b7992332..db7b29dba 100644 --- a/main.go +++ b/main.go @@ -81,7 +81,7 @@ func main() { if e, ok := err.(commandFailure); ok { fmt.Printf("Error: %v\n\n", string(e)) rootCmd.Help() - } else if !cmdConfig.Debug { + } else if !viper.GetBool("debug") { fmt.Printf("Error: %v\n", err) } else { fmt.Printf("Error: %+v\n", err) From 16b6a2b176d67e87f058c75dbe199927e8b3060e Mon Sep 17 00:00:00 2001 From: Aaron L Date: Thu, 8 Sep 2016 22:41:57 -0700 Subject: [PATCH 15/64] Rename Exclude -> Blacklist --- README.md | 12 ++++++------ bdb/drivers/mock.go | 4 ++-- bdb/drivers/mysql.go | 13 ++++++------- bdb/drivers/postgres.go | 12 ++++++------ bdb/interface.go | 8 ++++---- bdb/interface_test.go | 4 ++-- config.go | 2 +- main.go | 12 ++++++------ sqlboiler.go | 6 +++--- sqlboiler_test.go | 8 ++++---- 10 files changed, 40 insertions(+), 41 deletions(-) diff --git a/README.md b/README.md index eb6576a8d..09918daca 100644 --- a/README.md +++ b/README.md @@ -223,12 +223,12 @@ not to pass them through the command line or environment variables: | Name | Default | | --- | --- | | basedir | none | -| schema | "public" | +| schema | "public" | | pkgname | "models" | | output | "models" | -| whitelist | [ ] | -| exclude | [ ] | -| tag | [ ] | +| whitelist | [] | +| blacklist | [] | +| tag | [] | | debug | false | | no-hooks | false | | no-tests | false | @@ -261,14 +261,14 @@ Examples: sqlboiler postgres Flags: - -b, --basedir string The base directory has the templates and templates_test folders + -b, --blacklist stringSlice Do not include these tables in your generated package -w, --whitelist stringSlice Only include these tables in your generated package - -x, --exclude stringSlice Tables to be excluded from the generated package -s, --schema string The name of your database schema, for databases that support real schemas (default "public") -p, --pkgname string The name you wish to assign to your generated package (default "models") -o, --output string The name of the folder to output to (default "models") -t, --tag stringSlice Struct tags to be included on your models in addition to json, yaml, toml -d, --debug Debug mode prints stack traces on error + --basedir string The base directory has the templates and templates_test folders --no-auto-timestamps Disable automatic timestamps for created_at/updated_at --no-hooks Disable hooks feature for your models --no-tests Disable generated go test files diff --git a/bdb/drivers/mock.go b/bdb/drivers/mock.go index 4371e6f63..90d4e5401 100644 --- a/bdb/drivers/mock.go +++ b/bdb/drivers/mock.go @@ -9,12 +9,12 @@ import ( type MockDriver struct{} // TableNames returns a list of mock table names -func (m *MockDriver) TableNames(schema string, whitelist, exclude []string) ([]string, error) { +func (m *MockDriver) TableNames(schema string, whitelist, blacklist []string) ([]string, error) { if len(whitelist) > 0 { return whitelist, nil } tables := []string{"pilots", "jets", "airports", "licenses", "hangars", "languages", "pilot_languages"} - return strmangle.SetComplement(tables, exclude), nil + return strmangle.SetComplement(tables, blacklist), nil } // Columns returns a list of mock columns diff --git a/bdb/drivers/mysql.go b/bdb/drivers/mysql.go index 5b5043f46..963591515 100644 --- a/bdb/drivers/mysql.go +++ b/bdb/drivers/mysql.go @@ -73,9 +73,8 @@ func (m *MySQLDriver) UseLastInsertID() bool { // TableNames connects to the postgres database and // retrieves all table names from the information_schema where the -// table schema is public. It excludes common migration tool tables -// such as gorp_migrations -func (m *MySQLDriver) TableNames(schema string, whitelist, exclude []string) ([]string, error) { +// table schema is public. +func (m *MySQLDriver) TableNames(schema string, whitelist, blacklist []string) ([]string, error) { var names []string query := fmt.Sprintf(`select table_name from information_schema.tables where table_schema = ?`) @@ -85,10 +84,10 @@ func (m *MySQLDriver) TableNames(schema string, whitelist, exclude []string) ([] for _, w := range whitelist { args = append(args, w) } - } else if len(exclude) > 0 { - query += fmt.Sprintf("and table_name not in (%s);", strings.Repeat(",?", len(exclude))[1:]) - for _, e := range exclude { - args = append(args, e) + } else if len(blacklist) > 0 { + query += fmt.Sprintf("and table_name not in (%s);", strings.Repeat(",?", len(blacklist))[1:]) + for _, b := range blacklist { + args = append(args, b) } } diff --git a/bdb/drivers/postgres.go b/bdb/drivers/postgres.go index f88be8a50..e47829e9f 100644 --- a/bdb/drivers/postgres.go +++ b/bdb/drivers/postgres.go @@ -82,8 +82,8 @@ func (p *PostgresDriver) UseLastInsertID() bool { // TableNames connects to the postgres database and // retrieves all table names from the information_schema where the -// table schema is schema. It uses a whitelist and exclude list. -func (p *PostgresDriver) TableNames(schema string, whitelist, exclude []string) ([]string, error) { +// table schema is schema. It uses a whitelist and blacklist. +func (p *PostgresDriver) TableNames(schema string, whitelist, blacklist []string) ([]string, error) { var names []string query := fmt.Sprintf(`select table_name from information_schema.tables where table_schema = ?`) @@ -93,10 +93,10 @@ func (p *PostgresDriver) TableNames(schema string, whitelist, exclude []string) for _, w := range whitelist { args = append(args, w) } - } else if len(exclude) > 0 { - query += fmt.Sprintf("and table_name not in (%s);", strmangle.Placeholders(len(exclude), 1, 1)) - for _, e := range exclude { - args = append(args, e) + } else if len(blacklist) > 0 { + query += fmt.Sprintf("and table_name not in (%s);", strmangle.Placeholders(len(blacklist), 1, 1)) + for _, b := range blacklist { + args = append(args, b) } } diff --git a/bdb/interface.go b/bdb/interface.go index a7f14e09a..b4d6b7a5d 100644 --- a/bdb/interface.go +++ b/bdb/interface.go @@ -6,7 +6,7 @@ import "github.com/pkg/errors" // Interface for a database driver. Functionality required to support a specific // database type (eg, MySQL, Postgres etc.) type Interface interface { - TableNames(schema string, whitelist, exclude []string) ([]string, error) + TableNames(schema string, whitelist, blacklist []string) ([]string, error) Columns(schema, tableName string) ([]Column, error) PrimaryKeyInfo(schema, tableName string) (*PrimaryKey, error) ForeignKeyInfo(schema, tableName string) ([]ForeignKey, error) @@ -25,11 +25,11 @@ type Interface interface { } // Tables returns the metadata for all tables, minus the tables -// specified in the exclude slice. -func Tables(db Interface, schema string, whitelist, exclude []string) ([]Table, error) { +// specified in the blacklist. +func Tables(db Interface, schema string, whitelist, blacklist []string) ([]Table, error) { var err error - names, err := db.TableNames(schema, whitelist, exclude) + names, err := db.TableNames(schema, whitelist, blacklist) if err != nil { return nil, errors.Wrap(err, "unable to get table names") } diff --git a/bdb/interface_test.go b/bdb/interface_test.go index 048f3df5c..73cd4f54e 100644 --- a/bdb/interface_test.go +++ b/bdb/interface_test.go @@ -13,12 +13,12 @@ func (m mockDriver) UseLastInsertID() bool { return false } func (m mockDriver) Open() error { return nil } func (m mockDriver) Close() {} -func (m mockDriver) TableNames(whitelist, exclude []string) ([]string, error) { +func (m mockDriver) TableNames(whitelist, blacklist []string) ([]string, error) { if len(whitelist) > 0 { return whitelist, nil } tables := []string{"pilots", "jets", "airports", "licenses", "hangars", "languages", "pilot_languages"} - return strmangle.SetComplement(tables, exclude), nil + return strmangle.SetComplement(tables, blacklist), nil } // Columns returns a list of mock columns diff --git a/config.go b/config.go index 2c5a2b3e7..a637a9917 100644 --- a/config.go +++ b/config.go @@ -8,7 +8,7 @@ type Config struct { OutFolder string BaseDir string WhitelistTables []string - ExcludeTables []string + BlacklistTables []string Tags []string Debug bool NoTests bool diff --git a/main.go b/main.go index db7b29dba..a4d2a4df9 100644 --- a/main.go +++ b/main.go @@ -63,8 +63,8 @@ func main() { rootCmd.PersistentFlags().StringP("output", "o", "models", "The name of the folder to output to") rootCmd.PersistentFlags().StringP("schema", "s", "public", "The name of your database schema, for databases that support real schemas") rootCmd.PersistentFlags().StringP("pkgname", "p", "models", "The name you wish to assign to your generated package") - rootCmd.PersistentFlags().StringP("basedir", "b", "", "The base directory has the templates and templates_test folders") - rootCmd.PersistentFlags().StringSliceP("exclude", "x", nil, "Tables to be excluded from the generated package") + rootCmd.PersistentFlags().StringP("basedir", "", "", "The base directory has the templates and templates_test folders") + rootCmd.PersistentFlags().StringSliceP("blacklist", "b", nil, "Do not include these tables in your generated package") rootCmd.PersistentFlags().StringSliceP("whitelist", "w", nil, "Only include these tables in your generated package") rootCmd.PersistentFlags().StringSliceP("tag", "t", nil, "Struct tags to be included on your models in addition to json, yaml, toml") rootCmd.PersistentFlags().BoolP("debug", "d", false, "Debug mode prints stack traces on error") @@ -118,12 +118,12 @@ func preRun(cmd *cobra.Command, args []string) error { } // BUG: https://github.com/spf13/viper/issues/200 - // Look up the value of ExcludeTables & Tags directly from PFlags in Cobra if we + // Look up the value of blacklist, whitelist & tags directly from PFlags in Cobra if we // detect a malformed value coming out of viper. // Once the bug is fixed we'll be able to move this into the init above - cmdConfig.ExcludeTables = viper.GetStringSlice("exclude") - if len(cmdConfig.ExcludeTables) == 1 && strings.HasPrefix(cmdConfig.ExcludeTables[0], "[") { - cmdConfig.ExcludeTables, err = cmd.PersistentFlags().GetStringSlice("exclude") + cmdConfig.BlacklistTables = viper.GetStringSlice("blacklist") + if len(cmdConfig.BlacklistTables) == 1 && strings.HasPrefix(cmdConfig.BlacklistTables[0], "[") { + cmdConfig.BlacklistTables, err = cmd.PersistentFlags().GetStringSlice("blacklist") if err != nil { return err } diff --git a/sqlboiler.go b/sqlboiler.go index 0757c3c5c..1c3bf309d 100644 --- a/sqlboiler.go +++ b/sqlboiler.go @@ -59,7 +59,7 @@ func New(config *Config) (*State, error) { return nil, errors.Wrap(err, "unable to connect to the database") } - err = s.initTables(config.Schema, config.WhitelistTables, config.ExcludeTables) + err = s.initTables(config.Schema, config.WhitelistTables, config.BlacklistTables) if err != nil { return nil, errors.Wrap(err, "unable to initialize tables") } @@ -241,9 +241,9 @@ func (s *State) initDriver(driverName string) error { } // initTables retrieves all "public" schema table names from the database. -func (s *State) initTables(schema string, whitelist, exclude []string) error { +func (s *State) initTables(schema string, whitelist, blacklist []string) error { var err error - s.Tables, err = bdb.Tables(s.Driver, schema, whitelist, exclude) + s.Tables, err = bdb.Tables(s.Driver, schema, whitelist, blacklist) if err != nil { return errors.Wrap(err, "unable to fetch table data") } diff --git a/sqlboiler_test.go b/sqlboiler_test.go index a170361fc..367e429b5 100644 --- a/sqlboiler_test.go +++ b/sqlboiler_test.go @@ -37,10 +37,10 @@ func TestNew(t *testing.T) { }() config := &Config{ - DriverName: "mock", - PkgName: "models", - OutFolder: out, - ExcludeTables: []string{"hangars"}, + DriverName: "mock", + PkgName: "models", + OutFolder: out, + BlacklistTables: []string{"hangars"}, } state, err = New(config) From 1e0b90a99c6d333be7229b353539033b28138a97 Mon Sep 17 00:00:00 2001 From: Aaron L Date: Thu, 8 Sep 2016 22:42:49 -0700 Subject: [PATCH 16/64] Fix problem with table lookup in postgres --- bdb/drivers/postgres.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/bdb/drivers/postgres.go b/bdb/drivers/postgres.go index e47829e9f..a4e7b82fb 100644 --- a/bdb/drivers/postgres.go +++ b/bdb/drivers/postgres.go @@ -86,15 +86,15 @@ func (p *PostgresDriver) UseLastInsertID() bool { func (p *PostgresDriver) TableNames(schema string, whitelist, blacklist []string) ([]string, error) { var names []string - query := fmt.Sprintf(`select table_name from information_schema.tables where table_schema = ?`) + query := fmt.Sprintf(`select table_name from information_schema.tables where table_schema = $1`) args := []interface{}{schema} if len(whitelist) > 0 { - query += fmt.Sprintf("and table_name in (%s);", strmangle.Placeholders(len(whitelist), 1, 1)) + query += fmt.Sprintf(" and table_name in (%s);", strmangle.Placeholders(len(whitelist), 2, 1)) for _, w := range whitelist { args = append(args, w) } } else if len(blacklist) > 0 { - query += fmt.Sprintf("and table_name not in (%s);", strmangle.Placeholders(len(blacklist), 1, 1)) + query += fmt.Sprintf(" and table_name not in (%s);", strmangle.Placeholders(len(blacklist), 2+len(whitelist), 1)) for _, b := range blacklist { args = append(args, b) } From aadcf63e52b0e5afd60922976f85a4f2b6a62330 Mon Sep 17 00:00:00 2001 From: Aaron L Date: Thu, 8 Sep 2016 23:04:33 -0700 Subject: [PATCH 17/64] Fix problem with mysql table query. --- bdb/drivers/mysql.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bdb/drivers/mysql.go b/bdb/drivers/mysql.go index 963591515..ee6dfe643 100644 --- a/bdb/drivers/mysql.go +++ b/bdb/drivers/mysql.go @@ -80,12 +80,12 @@ func (m *MySQLDriver) TableNames(schema string, whitelist, blacklist []string) ( query := fmt.Sprintf(`select table_name from information_schema.tables where table_schema = ?`) args := []interface{}{schema} if len(whitelist) > 0 { - query += fmt.Sprintf("and table_name in (%s);", strings.Repeat(",?", len(whitelist))[1:]) + query += fmt.Sprintf(" and table_name in (%s);", strings.Repeat(",?", len(whitelist))[1:]) for _, w := range whitelist { args = append(args, w) } } else if len(blacklist) > 0 { - query += fmt.Sprintf("and table_name not in (%s);", strings.Repeat(",?", len(blacklist))[1:]) + query += fmt.Sprintf(" and table_name not in (%s);", strings.Repeat(",?", len(blacklist))[1:]) for _, b := range blacklist { args = append(args, b) } From 81148d4beb732f4518e0f6abdd0b416a53ac5b0c Mon Sep 17 00:00:00 2001 From: Aaron L Date: Thu, 8 Sep 2016 23:04:58 -0700 Subject: [PATCH 18/64] Add MySQL configuration. --- config.go | 11 ++++++++++ main.go | 57 +++++++++++++++++++++++++++++++++++++++++++++++++--- sqlboiler.go | 9 +++++++++ 3 files changed, 74 insertions(+), 3 deletions(-) diff --git a/config.go b/config.go index a637a9917..da2813448 100644 --- a/config.go +++ b/config.go @@ -16,6 +16,7 @@ type Config struct { NoAutoTimestamps bool Postgres PostgresConfig + MySQL MySQLConfig } // PostgresConfig configures a postgres database @@ -27,3 +28,13 @@ type PostgresConfig struct { DBName string SSLMode string } + +// MySQLConfig configures a mysql database +type MySQLConfig struct { + User string + Pass string + Host string + Port int + DBName string + SSLMode string +} diff --git a/main.go b/main.go index a4d2a4df9..9008892b7 100644 --- a/main.go +++ b/main.go @@ -74,6 +74,9 @@ func main() { viper.SetDefault("postgres.sslmode", "require") viper.SetDefault("postgres.port", "5432") + viper.SetDefault("mysql.sslmode", "true") + viper.SetDefault("mysql.port", "3306") + viper.BindPFlags(rootCmd.PersistentFlags()) viper.AutomaticEnv() @@ -155,10 +158,17 @@ func preRun(cmd *cobra.Command, args []string) error { SSLMode: viper.GetString("postgres.sslmode"), } - // Set the default SSLMode value + // BUG: https://github.com/spf13/viper/issues/71 + // Despite setting defaults, nested values don't get defaults + // Set them manually if cmdConfig.Postgres.SSLMode == "" { - viper.Set("postgres.sslmode", "require") - cmdConfig.Postgres.SSLMode = viper.GetString("postgres.sslmode") + cmdConfig.Postgres.SSLMode = "require" + viper.Set("postgres.sslmode", cmdConfig.Postgres.SSLMode) + } + + if cmdConfig.Postgres.Port == 0 { + cmdConfig.Postgres.Port = 5432 + viper.Set("postgres.port", cmdConfig.Postgres.Port) } err = vala.BeginValidation().Validate( @@ -176,6 +186,47 @@ func preRun(cmd *cobra.Command, args []string) error { return errors.New("postgres driver requires a postgres section in your config file") } + if viper.IsSet("mysql.dbname") { + cmdConfig.MySQL = MySQLConfig{ + User: viper.GetString("mysql.user"), + Pass: viper.GetString("mysql.pass"), + Host: viper.GetString("mysql.host"), + Port: viper.GetInt("mysql.port"), + DBName: viper.GetString("mysql.dbname"), + SSLMode: viper.GetString("mysql.sslmode"), + } + + // MySQL doesn't have schemas, just databases + cmdConfig.Schema = cmdConfig.MySQL.DBName + + // BUG: https://github.com/spf13/viper/issues/71 + // Despite setting defaults, nested values don't get defaults + // Set them manually + if cmdConfig.MySQL.SSLMode == "" { + cmdConfig.MySQL.SSLMode = "true" + viper.Set("mysql.sslmode", cmdConfig.MySQL.SSLMode) + } + + if cmdConfig.MySQL.Port == 0 { + cmdConfig.MySQL.Port = 3306 + viper.Set("mysql.port", cmdConfig.MySQL.Port) + } + + err = vala.BeginValidation().Validate( + vala.StringNotEmpty(cmdConfig.MySQL.User, "mysql.user"), + vala.StringNotEmpty(cmdConfig.MySQL.Host, "mysql.host"), + vala.Not(vala.Equals(cmdConfig.MySQL.Port, 0, "mysql.port")), + vala.StringNotEmpty(cmdConfig.MySQL.DBName, "mysql.dbname"), + vala.StringNotEmpty(cmdConfig.MySQL.SSLMode, "mysql.sslmode"), + ).Check() + + if err != nil { + return commandFailure(err.Error()) + } + } else if driverName == "mysql" { + return errors.New("mysql driver requires a mysql section in your config file") + } + cmdState, err = New(cmdConfig) return err } diff --git a/sqlboiler.go b/sqlboiler.go index 1c3bf309d..0234cd860 100644 --- a/sqlboiler.go +++ b/sqlboiler.go @@ -229,6 +229,15 @@ func (s *State) initDriver(driverName string) error { s.Config.Postgres.Port, s.Config.Postgres.SSLMode, ) + case "mysql": + s.Driver = drivers.NewMySQLDriver( + s.Config.MySQL.User, + s.Config.MySQL.Pass, + s.Config.MySQL.DBName, + s.Config.MySQL.Host, + s.Config.MySQL.Port, + s.Config.MySQL.SSLMode, + ) case "mock": s.Driver = &drivers.MockDriver{} } From b1efbd21c7ccb216b4598e2a2f9c4db9be6f5f3e Mon Sep 17 00:00:00 2001 From: Aaron L Date: Thu, 8 Sep 2016 23:05:09 -0700 Subject: [PATCH 19/64] Add a MySQL main test that does nothing --- templates_test/main_test/mysql_main.tpl | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 templates_test/main_test/mysql_main.tpl diff --git a/templates_test/main_test/mysql_main.tpl b/templates_test/main_test/mysql_main.tpl new file mode 100644 index 000000000..96dcbfc05 --- /dev/null +++ b/templates_test/main_test/mysql_main.tpl @@ -0,0 +1,2 @@ +func TestMain(m *testing.M) { +} From 51b4f9b309b27d8b61f1261f2084c2558aa93043 Mon Sep 17 00:00:00 2001 From: Patrick O'brien Date: Fri, 9 Sep 2016 17:57:01 +1000 Subject: [PATCH 20/64] Fix superfluous arg to placeholders --- bdb/drivers/postgres.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bdb/drivers/postgres.go b/bdb/drivers/postgres.go index a4e7b82fb..7c27d4741 100644 --- a/bdb/drivers/postgres.go +++ b/bdb/drivers/postgres.go @@ -94,7 +94,7 @@ func (p *PostgresDriver) TableNames(schema string, whitelist, blacklist []string args = append(args, w) } } else if len(blacklist) > 0 { - query += fmt.Sprintf(" and table_name not in (%s);", strmangle.Placeholders(len(blacklist), 2+len(whitelist), 1)) + query += fmt.Sprintf(" and table_name not in (%s);", strmangle.Placeholders(len(blacklist), 2, 1)) for _, b := range blacklist { args = append(args, b) } From ac02f7d2e0353b058b7124a7086e93cece13e8e7 Mon Sep 17 00:00:00 2001 From: Patrick O'brien Date: Fri, 9 Sep 2016 22:31:51 +1000 Subject: [PATCH 21/64] Fix broken relationship templates by adding schema --- templates/04_relationship_to_one.tpl | 13 +++++---- templates/05_relationship_to_many.tpl | 12 ++++---- templates/06_relationship_to_one_eager.tpl | 15 +++++----- templates/07_relationship_to_many_eager.tpl | 27 ++++++++++-------- templates/08_relationship_to_one_setops.tpl | 22 ++++++++------ templates/09_relationship_to_many_setops.tpl | 30 +++++++++++--------- templates_test/main_test/postgres_main.tpl | 2 +- templates_test/relationship_to_many.tpl | 4 +-- 8 files changed, 70 insertions(+), 55 deletions(-) diff --git a/templates/04_relationship_to_one.tpl b/templates/04_relationship_to_one.tpl index 9b75a260c..ca872a566 100644 --- a/templates/04_relationship_to_one.tpl +++ b/templates/04_relationship_to_one.tpl @@ -1,5 +1,7 @@ {{- define "relationship_to_one_helper" -}} -{{- $varNameSingular := .ForeignKey.ForeignTable | singular | camelCase -}} + {{- $tmplData := .Dot -}}{{/* .Dot holds the root templateData struct, passed in through preserveDot */}} + {{- with .Rel -}}{{/* Rel holds the text helper data, passed in through preserveDot */}} + {{- $varNameSingular := .ForeignKey.ForeignTable | singular | camelCase -}} // {{.Function.Name}}G pointed to by the foreign key. func ({{.Function.Receiver}} *{{.LocalTable.NameGo}}) {{.Function.Name}}G(mods ...qm.QueryMod) {{$varNameSingular}}Query { return {{.Function.Receiver}}.{{.Function.Name}}(boil.GetDB(), mods...) @@ -14,18 +16,19 @@ func ({{.Function.Receiver}} *{{.LocalTable.NameGo}}) {{.Function.Name}}(exec bo queryMods = append(queryMods, mods...) query := {{.ForeignTable.NamePluralGo}}(exec, queryMods...) - boil.SetFrom(query.Query, `{{schemaTable .DriverName .Schema .ForeignTable.Name}}`) + boil.SetFrom(query.Query, `{{schemaTable $tmplData.DriverName $tmplData.Schema .ForeignTable.Name}}`) return query } + {{- end -}}{{/* end with */}} {{end -}}{{/* end define */}} -{{/* Begin execution of template for one-to-one relationship. */}} +{{- /* Begin execution of template for one-to-one relationship */ -}} {{- if .Table.IsJoinTable -}} {{- else -}} {{- $dot := . -}} {{- range .Table.FKeys -}} - {{- $rel := textsFromForeignKey $dot.PkgName $dot.Tables $dot.Table . -}} -{{- template "relationship_to_one_helper" $rel -}} + {{- $txt := textsFromForeignKey $dot.PkgName $dot.Tables $dot.Table . -}} +{{- template "relationship_to_one_helper" (preserveDot $dot $txt) -}} {{- end -}} {{- end -}} diff --git a/templates/05_relationship_to_many.tpl b/templates/05_relationship_to_many.tpl index dd08d32c6..db7e81525 100644 --- a/templates/05_relationship_to_many.tpl +++ b/templates/05_relationship_to_many.tpl @@ -1,3 +1,4 @@ +{{- /* Begin execution of template for many-to-one or many-to-many relationship helper */ -}} {{- if .Table.IsJoinTable -}} {{- else -}} {{- $dot := . -}} @@ -5,11 +6,12 @@ {{- range .Table.ToManyRelationships -}} {{- $varNameSingular := .ForeignTable | singular | camelCase -}} {{- if (and .ForeignColumnUnique (not .ToJoinTable)) -}} -{{- /* Begin execution of template for many-to-one relationship. */ -}} -{{- template "relationship_to_one_helper" (textsFromOneToOneRelationship $dot.PkgName $dot.Tables $table .) -}} + {{- /* Begin execution of template for many-to-one relationship. */ -}} + {{- $txt := textsFromOneToOneRelationship $dot.PkgName $dot.Tables $table . -}} + {{- template "relationship_to_one_helper" (preserveDot $dot $txt) -}} {{- else -}} -{{- /* Begin execution of template for many-to-many relationship. */ -}} - {{- $rel := textsFromRelationship $dot.Tables $table . -}} + {{- /* Begin execution of template for many-to-many relationship. */ -}} + {{- $rel := textsFromRelationship $dot.Tables $table . -}} // {{$rel.Function.Name}}G retrieves all the {{$rel.LocalTable.NameSingular}}'s {{$rel.ForeignTable.NameHumanReadable}} {{- if not (eq $rel.Function.Name $rel.ForeignTable.NamePluralGo)}} via {{.ForeignColumn}} column{{- end}}. func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) {{$rel.Function.Name}}G(mods ...qm.QueryMod) {{$varNameSingular}}Query { @@ -45,4 +47,4 @@ func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) {{$rel.Function.Na {{end -}}{{- /* if unique foreign key */ -}} {{- end -}}{{- /* range relationships */ -}} -{{- end -}}{{- /* outer if join table */ -}} +{{- end -}}{{- /* if isJoinTable */ -}} diff --git a/templates/06_relationship_to_one_eager.tpl b/templates/06_relationship_to_one_eager.tpl index e7bcef7b6..a14cc438a 100644 --- a/templates/06_relationship_to_one_eager.tpl +++ b/templates/06_relationship_to_one_eager.tpl @@ -1,6 +1,6 @@ {{- define "relationship_to_one_eager_helper" -}} - {{- $varNameSingular := .Dot.Table.Name | singular | camelCase -}} - {{- $noHooks := .Dot.NoHooks -}} + {{- $tmplData := .Dot -}}{{/* .Dot holds the root templateData struct, passed in through preserveDot */}} + {{- $varNameSingular := $tmplData.Table.Name | singular | camelCase -}} {{- with .Rel -}} {{- $arg := printf "maybe%s" .LocalTable.NameGo -}} {{- $slice := printf "%sSlice" .LocalTable.NameGo -}} @@ -28,7 +28,7 @@ func ({{$varNameSingular}}L) Load{{.Function.Name}}(e boil.Executor, singular bo } query := fmt.Sprintf( - `select * from {{schemaTable .DriverName .Schema .ForeignKey.ForeignTable}} where "{{.ForeignKey.ForeignColumn}}" in (%s)`, + `select * from {{schemaTable $tmplData.DriverName $tmplData.Schema .ForeignKey.ForeignTable}} where "{{.ForeignKey.ForeignColumn}}" in (%s)`, strmangle.Placeholders(count, 1, 1), ) @@ -47,7 +47,7 @@ func ({{$varNameSingular}}L) Load{{.Function.Name}}(e boil.Executor, singular bo return errors.Wrap(err, "failed to bind eager loaded slice {{.ForeignTable.NameGo}}") } - {{if not $noHooks -}} + {{if not $tmplData.NoHooks -}} if len({{.ForeignTable.Name | singular | camelCase}}AfterSelectHooks) != 0 { for _, obj := range resultSlice { if err := obj.doAfterSelectHooks(e); err != nil { @@ -82,11 +82,12 @@ func ({{$varNameSingular}}L) Load{{.Function.Name}}(e boil.Executor, singular bo {{- end -}}{{- /* end with */ -}} {{end -}}{{- /* end define */ -}} +{{- /* Begin execution of template for one-to-one eager load */ -}} {{- if .Table.IsJoinTable -}} {{- else -}} {{- $dot := . -}} {{- range .Table.FKeys -}} - {{- $rel := textsFromForeignKey $dot.PkgName $dot.Tables $dot.Table . -}} -{{- template "relationship_to_one_eager_helper" (preserveDot $dot $rel) -}} -{{- end -}} + {{- $txt := textsFromForeignKey $dot.PkgName $dot.Tables $dot.Table . -}} + {{- template "relationship_to_one_eager_helper" (preserveDot $dot $txt) -}} + {{- end -}} {{end}} diff --git a/templates/07_relationship_to_many_eager.tpl b/templates/07_relationship_to_many_eager.tpl index 3d9dea746..d99953e77 100644 --- a/templates/07_relationship_to_many_eager.tpl +++ b/templates/07_relationship_to_many_eager.tpl @@ -1,15 +1,18 @@ +{{- /* Begin execution of template for many-to-one or many-to-many eager load */ -}} {{- if .Table.IsJoinTable -}} {{- else -}} -{{- $dot := . -}} -{{- range .Table.ToManyRelationships -}} -{{- if (and .ForeignColumnUnique (not .ToJoinTable)) -}} - {{- $txt := textsFromOneToOneRelationship $dot.PkgName $dot.Tables $dot.Table . -}} - {{- template "relationship_to_one_eager_helper" (preserveDot $dot $txt) -}} -{{- else -}} - {{- $varNameSingular := $dot.Table.Name | singular | camelCase -}} - {{- $txt := textsFromRelationship $dot.Tables $dot.Table . -}} - {{- $arg := printf "maybe%s" $txt.LocalTable.NameGo -}} - {{- $slice := printf "%sSlice" $txt.LocalTable.NameGo -}} + {{- $dot := . -}} + {{- range .Table.ToManyRelationships -}} + {{- if (and .ForeignColumnUnique (not .ToJoinTable)) -}} + {{- /* Begin execution of template for many-to-one eager load */ -}} + {{- $txt := textsFromOneToOneRelationship $dot.PkgName $dot.Tables $dot.Table . -}} + {{- template "relationship_to_one_eager_helper" (preserveDot $dot $txt) -}} + {{- else -}} + {{- /* Begin execution of template for many-to-many eager load */ -}} + {{- $varNameSingular := $dot.Table.Name | singular | camelCase -}} + {{- $txt := textsFromRelationship $dot.Tables $dot.Table . -}} + {{- $arg := printf "maybe%s" $txt.LocalTable.NameGo -}} + {{- $slice := printf "%sSlice" $txt.LocalTable.NameGo -}} // Load{{$txt.Function.Name}} allows an eager lookup of values, cached into the // loaded structs of the objects. func ({{$varNameSingular}}L) Load{{$txt.Function.Name}}(e boil.Executor, singular bool, {{$arg}} interface{}) error { @@ -35,7 +38,7 @@ func ({{$varNameSingular}}L) Load{{$txt.Function.Name}}(e boil.Executor, singula {{if .ToJoinTable -}} query := fmt.Sprintf( - `select "{{id 0}}".*, "{{id 1}}"."{{.JoinLocalColumn}}" from {{schemaTable .DriverName .Schema .ForeignTable}} as "{{id 0}}" inner join {{schemaTable .DriverName .Schema .JoinTable}} as "{{id 1}}" on "{{id 0}}"."{{.ForeignColumn}}" = "{{id 1}}"."{{.JoinForeignColumn}}" where "{{id 1}}"."{{.JoinLocalColumn}}" in (%s)`, + `select "{{id 0}}".*, "{{id 1}}"."{{.JoinLocalColumn}}" from {{schemaTable $dot.DriverName $dot.Schema .ForeignTable}} as "{{id 0}}" inner join {{schemaTable $dot.DriverName $dot.Schema .JoinTable}} as "{{id 1}}" on "{{id 0}}"."{{.ForeignColumn}}" = "{{id 1}}"."{{.JoinForeignColumn}}" where "{{id 1}}"."{{.JoinLocalColumn}}" in (%s)`, strmangle.Placeholders(count, 1, 1), ) {{else -}} @@ -133,4 +136,4 @@ func ({{$varNameSingular}}L) Load{{$txt.Function.Name}}(e boil.Executor, singula {{end -}}{{/* if ForeignColumnUnique */}} {{- end -}}{{/* range tomany */}} -{{- end -}}{{/* if isjointable */}} +{{- end -}}{{/* if IsJoinTable */}} diff --git a/templates/08_relationship_to_one_setops.tpl b/templates/08_relationship_to_one_setops.tpl index 8f8457fe1..66a6d8959 100644 --- a/templates/08_relationship_to_one_setops.tpl +++ b/templates/08_relationship_to_one_setops.tpl @@ -1,7 +1,8 @@ {{- define "relationship_to_one_setops_helper" -}} -{{- $varNameSingular := .ForeignKey.ForeignTable | singular | camelCase -}} -{{- $localNameSingular := .ForeignKey.Table | singular | camelCase}} - + {{- $tmplData := .Dot -}}{{/* .Dot holds the root templateData struct, passed in through preserveDot */}} + {{- with .Rel -}} + {{- $varNameSingular := .ForeignKey.ForeignTable | singular | camelCase -}} + {{- $localNameSingular := .ForeignKey.Table | singular | camelCase}} // Set{{.Function.Name}} of the {{.ForeignKey.Table | singular}} to the related item. // Sets {{.Function.Receiver}}.R.{{.Function.Name}} to related. // Adds {{.Function.Receiver}} to related.R.{{.Function.ForeignName}}. @@ -51,8 +52,8 @@ func ({{.Function.Receiver}} *{{.LocalTable.NameGo}}) Set{{.Function.Name}}(exec {{end -}} return nil } -{{- if .ForeignKey.Nullable}} + {{- if .ForeignKey.Nullable}} // Remove{{.Function.Name}} relationship. // Sets {{.Function.Receiver}}.R.{{.Function.Name}} to nil. // Removes {{.Function.Receiver}} from all passed in related items' relationships struct (Optional). @@ -89,13 +90,16 @@ func ({{.Function.Receiver}} *{{.LocalTable.NameGo}}) Remove{{.Function.Name}}(e return nil } -{{end -}} -{{- end -}} + {{- end -}}{{/* if foreignkey nullable */}} + {{end -}}{{/* end with */}} +{{- end -}}{{/* end define */}} + +{{- /* Begin execution of template for one-to-one setops */ -}} {{- if .Table.IsJoinTable -}} {{- else -}} {{- $dot := . -}} {{- range .Table.FKeys -}} - {{- $rel := textsFromForeignKey $dot.PkgName $dot.Tables $dot.Table . -}} -{{- template "relationship_to_one_setops_helper" $rel -}} -{{- end -}} + {{- $txt := textsFromForeignKey $dot.PkgName $dot.Tables $dot.Table . -}} + {{- template "relationship_to_one_setops_helper" (preserveDot $dot $txt) -}} + {{- end -}} {{- end -}} diff --git a/templates/09_relationship_to_many_setops.tpl b/templates/09_relationship_to_many_setops.tpl index 860387acc..1cc4c11ba 100644 --- a/templates/09_relationship_to_many_setops.tpl +++ b/templates/09_relationship_to_many_setops.tpl @@ -1,3 +1,4 @@ +{{- /* Begin execution of template for many-to-one or many-to-many setops */ -}} {{- if .Table.IsJoinTable -}} {{- else -}} {{- $dot := . -}} @@ -5,12 +6,13 @@ {{- range .Table.ToManyRelationships -}} {{- $varNameSingular := .ForeignTable | singular | camelCase -}} {{- if (and .ForeignColumnUnique (not .ToJoinTable)) -}} -{{- template "relationship_to_one_setops_helper" (textsFromOneToOneRelationship $dot.PkgName $dot.Tables $table .) -}} + {{- /* Begin execution of template for many-to-one setops */ -}} + {{- $txt := textsFromOneToOneRelationship $dot.PkgName $dot.Tables $table . -}} + {{- template "relationship_to_one_setops_helper" (preserveDot $dot $txt) -}} {{- else -}} - {{- $rel := textsFromRelationship $dot.Tables $table . -}} - {{- $localNameSingular := .Table | singular | camelCase -}} - {{- $foreignNameSingular := .ForeignTable | singular | camelCase}} - + {{- $rel := textsFromRelationship $dot.Tables $table . -}} + {{- $localNameSingular := .Table | singular | camelCase -}} + {{- $foreignNameSingular := .ForeignTable | singular | camelCase}} // Add{{$rel.Function.Name}} adds the given related objects to the existing relationships // of the {{$table.Name | singular}}, optionally inserting them as new records. // Appends related to {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}}. @@ -37,7 +39,7 @@ func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) Add{{$rel.Function {{if .ToJoinTable -}} for _, rel := range related { - query := `insert into {{schemaTable .DriverName .Schema .JoinTable}} ({{.JoinLocalColumn}}, {{.JoinForeignColumn}}) values ($1, $2)` + query := `insert into {{schemaTable $dot.DriverName $dot.Schema .JoinTable}} ({{.JoinLocalColumn}}, {{.JoinForeignColumn}}) values ($1, $2)` values := []interface{}{{"{"}}{{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}, rel.{{$rel.ForeignTable.ColumnNameGo}}} if boil.DebugMode { @@ -84,8 +86,8 @@ func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) Add{{$rel.Function return nil } -{{- if (or .ForeignColumnNullable .ToJoinTable)}} + {{- if (or .ForeignColumnNullable .ToJoinTable)}} // Set{{$rel.Function.Name}} removes all previously related items of the // {{$table.Name | singular}} replacing them completely with the passed // in related items, optionally inserting them as new records. @@ -138,7 +140,7 @@ func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) Remove{{$rel.Funct var err error {{if .ToJoinTable -}} query := fmt.Sprintf( - `delete from {{schemaTable .DriverName .Schema .JoinTable}} where "{{.JoinLocalColumn}}" = $1 and "{{.JoinForeignColumn}}" in (%s)`, + `delete from {{schemaTable $dot.DriverName $dot.Schema .JoinTable}} where "{{.JoinLocalColumn}}" = $1 and "{{.JoinForeignColumn}}" in (%s)`, strmangle.Placeholders(len(related), 1, 1), ) values := []interface{}{{"{"}}{{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}} @@ -191,7 +193,7 @@ func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) Remove{{$rel.Funct return nil } -{{if .ToJoinTable -}} + {{if .ToJoinTable -}} func remove{{$rel.LocalTable.NameGo}}From{{$rel.ForeignTable.NameGo}}Slice({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}, related []*{{$rel.ForeignTable.NameGo}}) { for _, rel := range related { if rel.R == nil { @@ -211,8 +213,8 @@ func remove{{$rel.LocalTable.NameGo}}From{{$rel.ForeignTable.NameGo}}Slice({{$re } } } -{{end -}}{{- /* if join table */ -}} -{{- end -}}{{- /* if nullable foreign key */ -}} -{{- end -}}{{- /* if unique foreign key */ -}} -{{- end -}}{{- /* range relationships */ -}} -{{- end -}}{{- /* outer if join table */ -}} + {{end -}}{{- /* if ToJoinTable */ -}} + {{- end -}}{{- /* if nullable foreign key */ -}} + {{- end -}}{{- /* if unique foreign key */ -}} + {{- end -}}{{- /* range relationships */ -}} +{{- end -}}{{- /* if IsJoinTable */ -}} diff --git a/templates_test/main_test/postgres_main.tpl b/templates_test/main_test/postgres_main.tpl index ee3ce6985..2750f60cc 100644 --- a/templates_test/main_test/postgres_main.tpl +++ b/templates_test/main_test/postgres_main.tpl @@ -98,7 +98,7 @@ func dropTestDB() error { // DBConnect connects to a database and returns the handle. func DBConnect(user, pass, dbname, host string, port int, sslmode string) (*sql.DB, error) { - connStr := drivers.BuildQueryString(user, pass, dbname, host, port, sslmode) + connStr := drivers.PostgresBuildQueryString(user, pass, dbname, host, port, sslmode) return sql.Open("postgres", connStr) } diff --git a/templates_test/relationship_to_many.tpl b/templates_test/relationship_to_many.tpl index bf74cef9c..85c9da83f 100644 --- a/templates_test/relationship_to_many.tpl +++ b/templates_test/relationship_to_many.tpl @@ -41,11 +41,11 @@ func test{{$rel.LocalTable.NameGo}}ToMany{{$rel.Function.Name}}(t *testing.T) { } {{if .ToJoinTable -}} - _, err = tx.Exec(`insert into "{{schemaTable .DriverName .Schema .JoinTable}}" ({{.JoinLocalColumn}}, {{.JoinForeignColumn}}) values ($1, $2)`, a.{{$rel.LocalTable.ColumnNameGo}}, b.{{$rel.ForeignTable.ColumnNameGo}}) + _, err = tx.Exec(`insert into {{schemaTable $dot.DriverName $dot.Schema .JoinTable}} ({{.JoinLocalColumn}}, {{.JoinForeignColumn}}) values ($1, $2)`, a.{{$rel.LocalTable.ColumnNameGo}}, b.{{$rel.ForeignTable.ColumnNameGo}}) if err != nil { t.Fatal(err) } - _, err = tx.Exec(`insert into "{{schemaTable .DriverName .Schema .JoinTable}}" ({{.JoinLocalColumn}}, {{.JoinForeignColumn}}) values ($1, $2)`, a.{{$rel.LocalTable.ColumnNameGo}}, c.{{$rel.ForeignTable.ColumnNameGo}}) + _, err = tx.Exec(`insert into {{schemaTable $dot.DriverName $dot.Schema .JoinTable}} ({{.JoinLocalColumn}}, {{.JoinForeignColumn}}) values ($1, $2)`, a.{{$rel.LocalTable.ColumnNameGo}}, c.{{$rel.ForeignTable.ColumnNameGo}}) if err != nil { t.Fatal(err) } From 817189fbfda225721f3d1e02360e6c54a9f99ab2 Mon Sep 17 00:00:00 2001 From: Patrick O'brien Date: Sat, 10 Sep 2016 01:06:07 +1000 Subject: [PATCH 22/64] Fixed outstanding failed tests --- bdb/interface_test.go | 20 ++++++++++---------- strmangle/strmangle.go | 2 +- text_helpers_test.go | 6 +++--- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/bdb/interface_test.go b/bdb/interface_test.go index 73cd4f54e..353a46b36 100644 --- a/bdb/interface_test.go +++ b/bdb/interface_test.go @@ -6,14 +6,14 @@ import ( "github.com/vattle/sqlboiler/strmangle" ) -type mockDriver struct{} +type testMockDriver struct{} -func (m mockDriver) TranslateColumnType(c Column) Column { return c } -func (m mockDriver) UseLastInsertID() bool { return false } -func (m mockDriver) Open() error { return nil } -func (m mockDriver) Close() {} +func (m testMockDriver) TranslateColumnType(c Column) Column { return c } +func (m testMockDriver) UseLastInsertID() bool { return false } +func (m testMockDriver) Open() error { return nil } +func (m testMockDriver) Close() {} -func (m mockDriver) TableNames(whitelist, blacklist []string) ([]string, error) { +func (m testMockDriver) TableNames(schema string, whitelist, blacklist []string) ([]string, error) { if len(whitelist) > 0 { return whitelist, nil } @@ -22,7 +22,7 @@ func (m mockDriver) TableNames(whitelist, blacklist []string) ([]string, error) } // Columns returns a list of mock columns -func (m mockDriver) Columns(tableName string) ([]Column, error) { +func (m testMockDriver) Columns(schema, tableName string) ([]Column, error) { return map[string][]Column{ "pilots": { {Name: "id", Type: "int", DBType: "integer"}, @@ -64,7 +64,7 @@ func (m mockDriver) Columns(tableName string) ([]Column, error) { } // ForeignKeyInfo returns a list of mock foreignkeys -func (m mockDriver) ForeignKeyInfo(tableName string) ([]ForeignKey, error) { +func (m testMockDriver) ForeignKeyInfo(schema, tableName string) ([]ForeignKey, error) { return map[string][]ForeignKey{ "jets": { {Table: "jets", Name: "jets_pilot_id_fk", Column: "pilot_id", ForeignTable: "pilots", ForeignColumn: "id", ForeignColumnUnique: true}, @@ -84,7 +84,7 @@ func (m mockDriver) ForeignKeyInfo(tableName string) ([]ForeignKey, error) { } // PrimaryKeyInfo returns mock primary key info for the passed in table name -func (m mockDriver) PrimaryKeyInfo(tableName string) (*PrimaryKey, error) { +func (m testMockDriver) PrimaryKeyInfo(schema, tableName string) (*PrimaryKey, error) { return map[string]*PrimaryKey{ "pilots": {Name: "pilot_id_pkey", Columns: []string{"id"}}, "airports": {Name: "airport_id_pkey", Columns: []string{"id"}}, @@ -99,7 +99,7 @@ func (m mockDriver) PrimaryKeyInfo(tableName string) (*PrimaryKey, error) { func TestTables(t *testing.T) { t.Parallel() - tables, err := Tables(mockDriver{}, nil, nil) + tables, err := Tables(testMockDriver{}, "public", nil, nil) if err != nil { t.Error(err) } diff --git a/strmangle/strmangle.go b/strmangle/strmangle.go index 5bda771e9..3c1fa87ec 100644 --- a/strmangle/strmangle.go +++ b/strmangle/strmangle.go @@ -39,7 +39,7 @@ func init() { // for Postgres: "schema_name"."table_name", versus // simply "table_name" for MySQL (because it does not support real schemas) func SchemaTable(driver string, schema string, table string) string { - if driver == "postgres" { + if driver == "postgres" && schema != "public" { return fmt.Sprintf(`"%s"."%s"`, schema, table) } diff --git a/text_helpers_test.go b/text_helpers_test.go index a85a4e5ae..5a09e8a5f 100644 --- a/text_helpers_test.go +++ b/text_helpers_test.go @@ -12,7 +12,7 @@ import ( func TestTextsFromForeignKey(t *testing.T) { t.Parallel() - tables, err := bdb.Tables(&drivers.MockDriver{}, nil, nil) + tables, err := bdb.Tables(&drivers.MockDriver{}, "public", nil, nil) if err != nil { t.Fatal(err) } @@ -81,7 +81,7 @@ func TestTextsFromForeignKey(t *testing.T) { func TestTextsFromOneToOneRelationship(t *testing.T) { t.Parallel() - tables, err := bdb.Tables(&drivers.MockDriver{}, nil, nil) + tables, err := bdb.Tables(&drivers.MockDriver{}, "public", nil, nil) if err != nil { t.Fatal(err) } @@ -130,7 +130,7 @@ func TestTextsFromOneToOneRelationship(t *testing.T) { func TestTextsFromRelationship(t *testing.T) { t.Parallel() - tables, err := bdb.Tables(&drivers.MockDriver{}, nil, nil) + tables, err := bdb.Tables(&drivers.MockDriver{}, "public", nil, nil) if err != nil { t.Fatal(err) } From 9e6a3d5ee3bd0456db4bea9e303db0171e4a811a Mon Sep 17 00:00:00 2001 From: Patrick O'brien Date: Sat, 10 Sep 2016 03:14:18 +1000 Subject: [PATCH 23/64] Add quote dialects --- bdb/drivers/mock.go | 15 +++++++++++++++ bdb/drivers/mysql.go | 15 +++++++++++++++ bdb/drivers/postgres.go | 15 +++++++++++++++ bdb/interface.go | 7 +++++++ boil/query.go | 20 ++++++++++++++++++++ sqlboiler.go | 11 +++++++++-- strmangle/strmangle.go | 12 ++++++++++++ templates.go | 6 ++++-- templates/singleton/boil_queries.tpl | 17 ++++++++++++----- 9 files changed, 109 insertions(+), 9 deletions(-) diff --git a/bdb/drivers/mock.go b/bdb/drivers/mock.go index 90d4e5401..6e0b8e1a6 100644 --- a/bdb/drivers/mock.go +++ b/bdb/drivers/mock.go @@ -123,3 +123,18 @@ func (m *MockDriver) Open() error { return nil } // Close mimics a database close call func (m *MockDriver) Close() {} + +// RightQuote is the quoting character for the right side of the identifier +func (m *MockDriver) RightQuote() string { + return "`" +} + +// LeftQuote is the quoting character for the left side of the identifier +func (m *MockDriver) LeftQuote() string { + return `"` +} + +// IndexPlaceholders returns true to indicate fake support of indexed placeholders +func (m *MockDriver) IndexPlaceholders() bool { + return false +} diff --git a/bdb/drivers/mysql.go b/bdb/drivers/mysql.go index ee6dfe643..d26956019 100644 --- a/bdb/drivers/mysql.go +++ b/bdb/drivers/mysql.go @@ -318,3 +318,18 @@ func mySQLIsValidated(typ string) bool { return false } + +// RightQuote is the quoting character for the right side of the identifier +func (m *MySQLDriver) RightQuote() string { + return "`" +} + +// LeftQuote is the quoting character for the left side of the identifier +func (m *MySQLDriver) LeftQuote() string { + return "`" +} + +// IndexPlaceholders returns false to indicate MySQL doesnt support indexed placeholders +func (m *MySQLDriver) IndexPlaceholders() bool { + return false +} diff --git a/bdb/drivers/postgres.go b/bdb/drivers/postgres.go index 7c27d4741..c39e3f51e 100644 --- a/bdb/drivers/postgres.go +++ b/bdb/drivers/postgres.go @@ -340,3 +340,18 @@ func psqlIsValidated(typ string) bool { return false } + +// RightQuote is the quoting character for the right side of the identifier +func (p *PostgresDriver) RightQuote() string { + return `"` +} + +// LeftQuote is the quoting character for the left side of the identifier +func (p *PostgresDriver) LeftQuote() string { + return `"` +} + +// IndexPlaceholders returns true to indicate PSQL supports indexed placeholders +func (p *PostgresDriver) IndexPlaceholders() bool { + return true +} diff --git a/bdb/interface.go b/bdb/interface.go index b4d6b7a5d..d26ef545e 100644 --- a/bdb/interface.go +++ b/bdb/interface.go @@ -22,6 +22,13 @@ type Interface interface { Open() error // Close the database connection Close() + + // Dialect helpers, these provide the values that will go into + // a boil.Dialect, so the query builder knows how to support + // your database driver properly. + LeftQuote() string + RightQuote() string + IndexPlaceholders() bool } // Tables returns the metadata for all tables, minus the tables diff --git a/boil/query.go b/boil/query.go index 2ea37e938..195e0b2f4 100644 --- a/boil/query.go +++ b/boil/query.go @@ -19,6 +19,7 @@ const ( // Query holds the state for the built up query type Query struct { executor Executor + dialect *Dialect plainSQL plainSQL load []string delete bool @@ -37,6 +38,20 @@ type Query struct { forlock string } +// Dialect holds values that direct the query builder +// how to build compatible queries for each database. +// Each database driver needs to implement functions +// that provide these values. +type Dialect struct { + // The left quote character for SQL identifiers + LQ string + // The right quote character for SQL identifiers + RQ string + // Bool flag indicating whether indexed + // placeholders ($1) are used, or ? placeholders. + IndexPlaceholders bool +} + type where struct { clause string orSeparator bool @@ -121,6 +136,11 @@ func GetExecutor(q *Query) Executor { return q.executor } +// SetDialect on the query. +func SetDialect(q *Query, dialect *Dialect) { + q.dialect = dialect +} + // SetSQL on the query. func SetSQL(q *Query, sql string, args ...interface{}) { q.plainSQL = plainSQL{sql: sql, args: args} diff --git a/sqlboiler.go b/sqlboiler.go index 0234cd860..d6e1bb262 100644 --- a/sqlboiler.go +++ b/sqlboiler.go @@ -32,8 +32,9 @@ const ( type State struct { Config *Config - Driver bdb.Interface - Tables []bdb.Table + Driver bdb.Interface + Tables []bdb.Table + Dialect boil.Dialect Templates *templateList TestTemplates *templateList @@ -102,6 +103,7 @@ func (s *State) Run(includeTests bool) error { PkgName: s.Config.PkgName, NoHooks: s.Config.NoHooks, NoAutoTimestamps: s.Config.NoAutoTimestamps, + Dialect: s.Dialect, StringFuncs: templateStringMappers, } @@ -135,6 +137,7 @@ func (s *State) Run(includeTests bool) error { NoHooks: s.Config.NoHooks, NoAutoTimestamps: s.Config.NoAutoTimestamps, Tags: s.Config.Tags, + Dialect: s.Dialect, StringFuncs: templateStringMappers, } @@ -246,6 +249,10 @@ func (s *State) initDriver(driverName string) error { return errors.New("An invalid driver name was provided") } + s.Dialect.LQ = s.Driver.LeftQuote() + s.Dialect.RQ = s.Driver.RightQuote() + s.Dialect.IndexPlaceholders = s.Driver.IndexPlaceholders() + return nil } diff --git a/strmangle/strmangle.go b/strmangle/strmangle.go index 3c1fa87ec..e62f83a57 100644 --- a/strmangle/strmangle.go +++ b/strmangle/strmangle.go @@ -46,6 +46,18 @@ func SchemaTable(driver string, schema string, table string) string { return fmt.Sprintf(`"%s"`, table) } +// WrapQuote wraps a quote character in quotes. +func WrapQuote(s string) string { + if s == `"` { + return "`\"`" + } + if s == "`" { + return "\"`\"" + } + + return fmt.Sprintf("`%s`", s) +} + // IdentQuote attempts to quote simple identifiers in SQL tatements func IdentQuote(s string) string { if strings.ToLower(s) == "null" { diff --git a/templates.go b/templates.go index 8d171f2cc..ce9425168 100644 --- a/templates.go +++ b/templates.go @@ -8,6 +8,7 @@ import ( "text/template" "github.com/vattle/sqlboiler/bdb" + "github.com/vattle/sqlboiler/boil" "github.com/vattle/sqlboiler/strmangle" ) @@ -22,8 +23,8 @@ type templateData struct { NoHooks bool NoAutoTimestamps bool Tags []string - - StringFuncs map[string]func(string) string + StringFuncs map[string]func(string) string + Dialect boil.Dialect } type templateList struct { @@ -116,6 +117,7 @@ var templateFunctions = template.FuncMap{ // String ops "quoteWrap": func(a string) string { return fmt.Sprintf(`"%s"`, a) }, "id": strmangle.Identifier, + "wrapQuote": strmangle.WrapQuote, // Pluralization "singular": strmangle.Singular, diff --git a/templates/singleton/boil_queries.tpl b/templates/singleton/boil_queries.tpl index 6cb607df2..162c764ae 100644 --- a/templates/singleton/boil_queries.tpl +++ b/templates/singleton/boil_queries.tpl @@ -1,13 +1,20 @@ +var dialect boil.Dialect = boil.Dialect{ + LQ: {{wrapQuote .Dialect.LQ}}, + RQ: {{wrapQuote .Dialect.RQ}}, + IndexPlaceholders: {{.Dialect.IndexPlaceholders}}, +} + // NewQueryG initializes a new Query using the passed in QueryMods func NewQueryG(mods ...qm.QueryMod) *boil.Query { - return NewQuery(boil.GetDB(), mods...) + return NewQuery(boil.GetDB(), mods...) } // NewQuery initializes a new Query using the passed in QueryMods func NewQuery(exec boil.Executor, mods ...qm.QueryMod) *boil.Query { - q := &boil.Query{} - boil.SetExecutor(q, exec) - qm.Apply(q, mods...) + q := &boil.Query{} + boil.SetExecutor(q, exec) + boil.SetDialect(q, &dialect) + qm.Apply(q, mods...) - return q + return q } From 419f2760c736835505c9dc9719c5738755739541 Mon Sep 17 00:00:00 2001 From: Patrick O'brien Date: Sat, 10 Sep 2016 03:30:46 +1000 Subject: [PATCH 24/64] Change quotes to bytes --- bdb/drivers/mock.go | 8 ++++---- bdb/drivers/mysql.go | 8 ++++---- bdb/drivers/postgres.go | 8 ++++---- bdb/interface.go | 4 ++-- boil/query.go | 4 ++-- strmangle/strmangle.go | 12 ------------ templates.go | 1 - templates/singleton/boil_queries.tpl | 4 ++-- 8 files changed, 18 insertions(+), 31 deletions(-) diff --git a/bdb/drivers/mock.go b/bdb/drivers/mock.go index 6e0b8e1a6..abe85a482 100644 --- a/bdb/drivers/mock.go +++ b/bdb/drivers/mock.go @@ -125,13 +125,13 @@ func (m *MockDriver) Open() error { return nil } func (m *MockDriver) Close() {} // RightQuote is the quoting character for the right side of the identifier -func (m *MockDriver) RightQuote() string { - return "`" +func (m *MockDriver) RightQuote() byte { + return '"' } // LeftQuote is the quoting character for the left side of the identifier -func (m *MockDriver) LeftQuote() string { - return `"` +func (m *MockDriver) LeftQuote() byte { + return '"' } // IndexPlaceholders returns true to indicate fake support of indexed placeholders diff --git a/bdb/drivers/mysql.go b/bdb/drivers/mysql.go index d26956019..0cb836149 100644 --- a/bdb/drivers/mysql.go +++ b/bdb/drivers/mysql.go @@ -320,13 +320,13 @@ func mySQLIsValidated(typ string) bool { } // RightQuote is the quoting character for the right side of the identifier -func (m *MySQLDriver) RightQuote() string { - return "`" +func (m *MySQLDriver) RightQuote() byte { + return '`' } // LeftQuote is the quoting character for the left side of the identifier -func (m *MySQLDriver) LeftQuote() string { - return "`" +func (m *MySQLDriver) LeftQuote() byte { + return '`' } // IndexPlaceholders returns false to indicate MySQL doesnt support indexed placeholders diff --git a/bdb/drivers/postgres.go b/bdb/drivers/postgres.go index c39e3f51e..08c08e9ce 100644 --- a/bdb/drivers/postgres.go +++ b/bdb/drivers/postgres.go @@ -342,13 +342,13 @@ func psqlIsValidated(typ string) bool { } // RightQuote is the quoting character for the right side of the identifier -func (p *PostgresDriver) RightQuote() string { - return `"` +func (p *PostgresDriver) RightQuote() byte { + return '"' } // LeftQuote is the quoting character for the left side of the identifier -func (p *PostgresDriver) LeftQuote() string { - return `"` +func (p *PostgresDriver) LeftQuote() byte { + return '"' } // IndexPlaceholders returns true to indicate PSQL supports indexed placeholders diff --git a/bdb/interface.go b/bdb/interface.go index d26ef545e..ab32c8c1a 100644 --- a/bdb/interface.go +++ b/bdb/interface.go @@ -26,8 +26,8 @@ type Interface interface { // Dialect helpers, these provide the values that will go into // a boil.Dialect, so the query builder knows how to support // your database driver properly. - LeftQuote() string - RightQuote() string + LeftQuote() byte + RightQuote() byte IndexPlaceholders() bool } diff --git a/boil/query.go b/boil/query.go index 195e0b2f4..ae6a08cca 100644 --- a/boil/query.go +++ b/boil/query.go @@ -44,9 +44,9 @@ type Query struct { // that provide these values. type Dialect struct { // The left quote character for SQL identifiers - LQ string + LQ byte // The right quote character for SQL identifiers - RQ string + RQ byte // Bool flag indicating whether indexed // placeholders ($1) are used, or ? placeholders. IndexPlaceholders bool diff --git a/strmangle/strmangle.go b/strmangle/strmangle.go index e62f83a57..3c1fa87ec 100644 --- a/strmangle/strmangle.go +++ b/strmangle/strmangle.go @@ -46,18 +46,6 @@ func SchemaTable(driver string, schema string, table string) string { return fmt.Sprintf(`"%s"`, table) } -// WrapQuote wraps a quote character in quotes. -func WrapQuote(s string) string { - if s == `"` { - return "`\"`" - } - if s == "`" { - return "\"`\"" - } - - return fmt.Sprintf("`%s`", s) -} - // IdentQuote attempts to quote simple identifiers in SQL tatements func IdentQuote(s string) string { if strings.ToLower(s) == "null" { diff --git a/templates.go b/templates.go index ce9425168..2ff4dc845 100644 --- a/templates.go +++ b/templates.go @@ -117,7 +117,6 @@ var templateFunctions = template.FuncMap{ // String ops "quoteWrap": func(a string) string { return fmt.Sprintf(`"%s"`, a) }, "id": strmangle.Identifier, - "wrapQuote": strmangle.WrapQuote, // Pluralization "singular": strmangle.Singular, diff --git a/templates/singleton/boil_queries.tpl b/templates/singleton/boil_queries.tpl index 162c764ae..eec7351c0 100644 --- a/templates/singleton/boil_queries.tpl +++ b/templates/singleton/boil_queries.tpl @@ -1,6 +1,6 @@ var dialect boil.Dialect = boil.Dialect{ - LQ: {{wrapQuote .Dialect.LQ}}, - RQ: {{wrapQuote .Dialect.RQ}}, + LQ: 0x{{printf "%x" .Dialect.LQ}}, + RQ: 0x{{printf "%x" .Dialect.RQ}}, IndexPlaceholders: {{.Dialect.IndexPlaceholders}}, } From f14301de7b54c0310e832ecefb032de2967085fa Mon Sep 17 00:00:00 2001 From: Patrick O'brien Date: Sat, 10 Sep 2016 03:35:32 +1000 Subject: [PATCH 25/64] Add things to shut tests up --- bdb/interface_test.go | 15 +++++++++++++++ strmangle/strmangle.go | 4 ++-- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/bdb/interface_test.go b/bdb/interface_test.go index 353a46b36..41e4c3b29 100644 --- a/bdb/interface_test.go +++ b/bdb/interface_test.go @@ -96,6 +96,21 @@ func (m testMockDriver) PrimaryKeyInfo(schema, tableName string) (*PrimaryKey, e }[tableName], nil } +// RightQuote is the quoting character for the right side of the identifier +func (m *testMockDriver) RightQuote() byte { + return '"' +} + +// LeftQuote is the quoting character for the left side of the identifier +func (m *testMockDriver) LeftQuote() byte { + return '"' +} + +// IndexPlaceholders returns true to indicate fake support of indexed placeholders +func (m *testMockDriver) IndexPlaceholders() bool { + return false +} + func TestTables(t *testing.T) { t.Parallel() diff --git a/strmangle/strmangle.go b/strmangle/strmangle.go index 3c1fa87ec..7ce994e67 100644 --- a/strmangle/strmangle.go +++ b/strmangle/strmangle.go @@ -47,7 +47,7 @@ func SchemaTable(driver string, schema string, table string) string { } // IdentQuote attempts to quote simple identifiers in SQL tatements -func IdentQuote(s string) string { +func IdentQuote(lq byte, rq byte, s string) string { if strings.ToLower(s) == "null" { return s } @@ -79,7 +79,7 @@ func IdentQuote(s string) string { } // IdentQuoteSlice applies IdentQuote to a slice. -func IdentQuoteSlice(s []string) []string { +func IdentQuoteSlice(lq byte, rq byte, s []string) []string { if len(s) == 0 { return s } From 793522650cb6c680177ed878aede7b5a2eeb669f Mon Sep 17 00:00:00 2001 From: Patrick O'brien Date: Sat, 10 Sep 2016 05:15:50 +1000 Subject: [PATCH 26/64] Add lq, rq, and indexplaceholders args everywhere --- README.md | 4 + bdb/drivers/postgres.go | 4 +- bdb/interface_test.go | 6 +- boil/query_builders.go | 79 +++++++++++++------- boil/query_builders_test.go | 15 +++- boil/reflect_test.go | 17 +++-- strmangle/strmangle.go | 25 ++++--- strmangle/strmangle_test.go | 28 +++++-- templates/04_relationship_to_one.tpl | 2 +- templates/05_relationship_to_many.tpl | 4 +- templates/06_relationship_to_one_eager.tpl | 4 +- templates/07_relationship_to_many_eager.tpl | 8 +- templates/09_relationship_to_many_setops.tpl | 10 +-- templates/10_all.tpl | 2 +- templates/11_find.tpl | 4 +- templates/12_insert.tpl | 4 +- templates/13_update.tpl | 12 +-- templates/14_upsert.tpl | 2 +- templates/15_delete.tpl | 8 +- templates/16_reload.tpl | 6 +- templates/17_exists.tpl | 2 +- templates/singleton/boil_queries.tpl | 2 +- templates_test/relationship_to_many.tpl | 4 +- 23 files changed, 160 insertions(+), 92 deletions(-) diff --git a/README.md b/README.md index 09918daca..7beba2229 100644 --- a/README.md +++ b/README.md @@ -1061,6 +1061,10 @@ If your database uses multiple schemas you should generate a new package for eac Note that this only applies to databases that use real, SQL standard schemas (like PostgreSQL), not fake schemas (like MySQL). +#### Where is the homepage? + +The homepage for the [SQLBoiler](https://github.com/vattle/sqlboiler) [Golang ORM](https://github.com/vattle/sqlboiler) generator is located at: https://github.com/vattle/sqlboiler + ## Benchmarks If you'd like to run the benchmarks yourself check out our [boilbench](https://github.com/vattle/boilbench) repo. diff --git a/bdb/drivers/postgres.go b/bdb/drivers/postgres.go index 08c08e9ce..8615013a5 100644 --- a/bdb/drivers/postgres.go +++ b/bdb/drivers/postgres.go @@ -89,12 +89,12 @@ func (p *PostgresDriver) TableNames(schema string, whitelist, blacklist []string query := fmt.Sprintf(`select table_name from information_schema.tables where table_schema = $1`) args := []interface{}{schema} if len(whitelist) > 0 { - query += fmt.Sprintf(" and table_name in (%s);", strmangle.Placeholders(len(whitelist), 2, 1)) + query += fmt.Sprintf(" and table_name in (%s);", strmangle.Placeholders(true, len(whitelist), 2, 1)) for _, w := range whitelist { args = append(args, w) } } else if len(blacklist) > 0 { - query += fmt.Sprintf(" and table_name not in (%s);", strmangle.Placeholders(len(blacklist), 2, 1)) + query += fmt.Sprintf(" and table_name not in (%s);", strmangle.Placeholders(true, len(blacklist), 2, 1)) for _, b := range blacklist { args = append(args, b) } diff --git a/bdb/interface_test.go b/bdb/interface_test.go index 41e4c3b29..48be0886b 100644 --- a/bdb/interface_test.go +++ b/bdb/interface_test.go @@ -97,17 +97,17 @@ func (m testMockDriver) PrimaryKeyInfo(schema, tableName string) (*PrimaryKey, e } // RightQuote is the quoting character for the right side of the identifier -func (m *testMockDriver) RightQuote() byte { +func (m testMockDriver) RightQuote() byte { return '"' } // LeftQuote is the quoting character for the left side of the identifier -func (m *testMockDriver) LeftQuote() byte { +func (m testMockDriver) LeftQuote() byte { return '"' } // IndexPlaceholders returns true to indicate fake support of indexed placeholders -func (m *testMockDriver) IndexPlaceholders() bool { +func (m testMockDriver) IndexPlaceholders() bool { return false } diff --git a/boil/query_builders.go b/boil/query_builders.go index 7ea79952f..3b1f8c655 100644 --- a/boil/query_builders.go +++ b/boil/query_builders.go @@ -57,7 +57,7 @@ func buildSelectQuery(q *Query) (*bytes.Buffer, []interface{}) { // Don't identQuoteSlice - writeAsStatements does this buf.WriteString(strings.Join(selectColsWithAs, ", ")) } else if hasSelectCols { - buf.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.selectCols), ", ")) + buf.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.dialect.LQ, q.dialect.RQ, q.selectCols), ", ")) } else if hasJoins { selectColsWithStars := writeStars(q) buf.WriteString(strings.Join(selectColsWithStars, ", ")) @@ -70,7 +70,7 @@ func buildSelectQuery(q *Query) (*bytes.Buffer, []interface{}) { buf.WriteByte(')') } - fmt.Fprintf(buf, " FROM %s", strings.Join(strmangle.IdentQuoteSlice(q.from), ", ")) + fmt.Fprintf(buf, " FROM %s", strings.Join(strmangle.IdentQuoteSlice(q.dialect.LQ, q.dialect.RQ, q.from), ", ")) if len(q.joins) > 0 { argsLen := len(args) @@ -82,7 +82,12 @@ func buildSelectQuery(q *Query) (*bytes.Buffer, []interface{}) { fmt.Fprintf(joinBuf, " INNER JOIN %s", j.clause) args = append(args, j.args...) } - resp, _ := convertQuestionMarks(joinBuf.String(), argsLen+1) + var resp string + if q.dialect.IndexPlaceholders { + resp, _ = convertQuestionMarks(joinBuf.String(), argsLen+1) + } else { + resp = joinBuf.String() + } fmt.Fprintf(buf, resp) strmangle.PutBuffer(joinBuf) } @@ -110,7 +115,7 @@ func buildDeleteQuery(q *Query) (*bytes.Buffer, []interface{}) { buf := strmangle.GetBuffer() buf.WriteString("DELETE FROM ") - buf.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.from), ", ")) + buf.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.dialect.LQ, q.dialect.RQ, q.from), ", ")) where, whereArgs := whereClause(q, 1) if len(whereArgs) != 0 { @@ -135,7 +140,7 @@ func buildUpdateQuery(q *Query) (*bytes.Buffer, []interface{}) { buf := strmangle.GetBuffer() buf.WriteString("UPDATE ") - buf.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.from), ", ")) + buf.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.dialect.LQ, q.dialect.RQ, q.from), ", ")) cols := make(sort.StringSlice, len(q.update)) var args []interface{} @@ -150,13 +155,13 @@ func buildUpdateQuery(q *Query) (*bytes.Buffer, []interface{}) { for i := 0; i < len(cols); i++ { args = append(args, q.update[cols[i]]) - cols[i] = strmangle.IdentQuote(cols[i]) + cols[i] = strmangle.IdentQuote(q.dialect.LQ, q.dialect.RQ, cols[i]) } buf.WriteString(fmt.Sprintf( " SET (%s) = (%s)", strings.Join(cols, ", "), - strmangle.Placeholders(len(cols), 1, 1)), + strmangle.Placeholders(q.dialect.IndexPlaceholders, len(cols), 1, 1)), ) where, whereArgs := whereClause(q, len(args)+1) @@ -179,10 +184,10 @@ func buildUpdateQuery(q *Query) (*bytes.Buffer, []interface{}) { } // BuildUpsertQuery builds a SQL statement string using the upsertData provided. -func BuildUpsertQuery(tableName string, updateOnConflict bool, ret, update, conflict, whitelist []string) string { - conflict = strmangle.IdentQuoteSlice(conflict) - whitelist = strmangle.IdentQuoteSlice(whitelist) - ret = strmangle.IdentQuoteSlice(ret) +func BuildUpsertQuery(dia Dialect, tableName string, updateOnConflict bool, ret, update, conflict, whitelist []string) string { + conflict = strmangle.IdentQuoteSlice(dia.LQ, dia.RQ, conflict) + whitelist = strmangle.IdentQuoteSlice(dia.LQ, dia.RQ, whitelist) + ret = strmangle.IdentQuoteSlice(dia.LQ, dia.RQ, ret) buf := strmangle.GetBuffer() defer strmangle.PutBuffer(buf) @@ -192,7 +197,7 @@ func BuildUpsertQuery(tableName string, updateOnConflict bool, ret, update, conf "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT ", tableName, strings.Join(whitelist, ", "), - strmangle.Placeholders(len(whitelist), 1, 1), + strmangle.Placeholders(dia.IndexPlaceholders, len(whitelist), 1, 1), ) if !updateOnConflict || len(update) == 0 { @@ -206,7 +211,7 @@ func BuildUpsertQuery(tableName string, updateOnConflict bool, ret, update, conf if i != 0 { buf.WriteByte(',') } - quoted := strmangle.IdentQuote(v) + quoted := strmangle.IdentQuote(dia.LQ, dia.RQ, v) buf.WriteString(quoted) buf.WriteString(" = EXCLUDED.") buf.WriteString(quoted) @@ -237,7 +242,12 @@ func writeModifiers(q *Query, buf *bytes.Buffer, args *[]interface{}) { fmt.Fprintf(havingBuf, j.clause) *args = append(*args, j.args...) } - resp, _ := convertQuestionMarks(havingBuf.String(), argsLen+1) + var resp string + if q.dialect.IndexPlaceholders { + resp, _ = convertQuestionMarks(havingBuf.String(), argsLen+1) + } else { + resp = havingBuf.String() + } fmt.Fprintf(buf, resp) strmangle.PutBuffer(havingBuf) } @@ -264,7 +274,7 @@ func writeStars(q *Query) []string { for i, f := range q.from { toks := strings.Split(f, " ") if len(toks) == 1 { - cols[i] = fmt.Sprintf(`%s.*`, strmangle.IdentQuote(toks[0])) + cols[i] = fmt.Sprintf(`%s.*`, strmangle.IdentQuote(q.dialect.LQ, q.dialect.RQ, toks[0])) continue } @@ -276,7 +286,7 @@ func writeStars(q *Query) []string { if len(alias) != 0 { name = alias } - cols[i] = fmt.Sprintf(`%s.*`, strmangle.IdentQuote(name)) + cols[i] = fmt.Sprintf(`%s.*`, strmangle.IdentQuote(q.dialect.LQ, q.dialect.RQ, name)) } return cols @@ -292,7 +302,7 @@ func writeAsStatements(q *Query) []string { toks := strings.Split(col, ".") if len(toks) == 1 { - cols[i] = strmangle.IdentQuote(col) + cols[i] = strmangle.IdentQuote(q.dialect.LQ, q.dialect.RQ, col) continue } @@ -301,7 +311,7 @@ func writeAsStatements(q *Query) []string { asParts[j] = strings.Trim(tok, `"`) } - cols[i] = fmt.Sprintf(`%s as "%s"`, strmangle.IdentQuote(col), strings.Join(asParts, ".")) + cols[i] = fmt.Sprintf(`%s as "%s"`, strmangle.IdentQuote(q.dialect.LQ, q.dialect.RQ, col), strings.Join(asParts, ".")) } return cols @@ -335,7 +345,13 @@ func whereClause(q *Query, startAt int) (string, []interface{}) { args = append(args, where.args...) } - resp, _ := convertQuestionMarks(buf.String(), startAt) + var resp string + if q.dialect.IndexPlaceholders { + resp, _ = convertQuestionMarks(buf.String(), startAt) + } else { + resp = buf.String() + } + return resp, args } @@ -374,7 +390,7 @@ func inClause(q *Query, startAt int) (string, []interface{}) { // column name side, however if this case is being hit then the regexp // probably needs adjustment, or the user is passing in invalid clauses. if matches == nil { - clause, count := convertInQuestionMarks(in.clause, startAt, 1, ln) + clause, count := convertInQuestionMarks(q.dialect.IndexPlaceholders, in.clause, startAt, 1, ln) buf.WriteString(clause) startAt = startAt + count } else { @@ -384,11 +400,24 @@ func inClause(q *Query, startAt int) (string, []interface{}) { // of the clause to determine how many columns they are using. // This number determines the groupAt for the convert function. cols := strings.Split(leftSide, ",") - cols = strmangle.IdentQuoteSlice(cols) + cols = strmangle.IdentQuoteSlice(q.dialect.LQ, q.dialect.RQ, cols) groupAt := len(cols) - leftClause, leftCount := convertQuestionMarks(strings.Join(cols, ","), startAt) - rightClause, rightCount := convertInQuestionMarks(rightSide, startAt+leftCount, groupAt, ln-leftCount) + var leftClause string + var leftCount int + if q.dialect.IndexPlaceholders { + leftClause, leftCount = convertQuestionMarks(strings.Join(cols, ","), startAt) + } else { + // Count the number of cols that are question marks, so we know + // how much to offset convertInQuestionMarks by + for _, v := range cols { + if v == "?" { + leftCount++ + } + } + leftClause = strings.Join(cols, ",") + } + rightClause, rightCount := convertInQuestionMarks(q.dialect.IndexPlaceholders, rightSide, startAt+leftCount, groupAt, ln-leftCount) buf.WriteString(leftClause) buf.WriteString(" IN ") buf.WriteString(rightClause) @@ -406,7 +435,7 @@ func inClause(q *Query, startAt int) (string, []interface{}) { // It uses groupAt to determine how many placeholders should be in each group, // for example, groupAt 2 would result in: (($1,$2),($3,$4)) // and groupAt 1 would result in ($1,$2,$3,$4) -func convertInQuestionMarks(clause string, startAt, groupAt, total int) (string, int) { +func convertInQuestionMarks(indexPlaceholders bool, clause string, startAt, groupAt, total int) (string, int) { if startAt == 0 || len(clause) == 0 { panic("Not a valid start number.") } @@ -428,7 +457,7 @@ func convertInQuestionMarks(clause string, startAt, groupAt, total int) (string, paramBuf.WriteString(clause[:foundAt]) paramBuf.WriteByte('(') - paramBuf.WriteString(strmangle.Placeholders(total, startAt, groupAt)) + paramBuf.WriteString(strmangle.Placeholders(indexPlaceholders, total, startAt, groupAt)) paramBuf.WriteByte(')') paramBuf.WriteString(clause[foundAt+1:]) diff --git a/boil/query_builders_test.go b/boil/query_builders_test.go index b5035d87c..db642b092 100644 --- a/boil/query_builders_test.go +++ b/boil/query_builders_test.go @@ -97,6 +97,7 @@ func TestBuildQuery(t *testing.T) { for i, test := range tests { filename := filepath.Join("_fixtures", fmt.Sprintf("%02d.sql", i)) + test.q.dialect = &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true} out, args := buildQuery(test.q) if *writeGoldenFiles { @@ -149,6 +150,7 @@ func TestWriteStars(t *testing.T) { } for i, test := range tests { + test.In.dialect = &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true} selects := writeStars(&test.In) if !reflect.DeepEqual(selects, test.Out) { t.Errorf("writeStar test fail %d\nwant: %v\ngot: %v", i, test.Out, selects) @@ -275,6 +277,7 @@ func TestWhereClause(t *testing.T) { } for i, test := range tests { + test.q.dialect = &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true} result, _ := whereClause(&test.q, 1) if result != test.expect { t.Errorf("%d) Mismatch between expect and result:\n%s\n%s\n", i, test.expect, result) @@ -407,6 +410,7 @@ func TestInClause(t *testing.T) { } for i, test := range tests { + test.q.dialect = &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true} result, args := inClause(&test.q, 1) if result != test.expect { t.Errorf("%d) Mismatch between expect and result:\n%s\n%s\n", i, test.expect, result) @@ -489,7 +493,7 @@ func TestConvertInQuestionMarks(t *testing.T) { } for i, test := range tests { - res, count := convertInQuestionMarks(test.clause, test.start, test.group, test.total) + res, count := convertInQuestionMarks(true, test.clause, test.start, test.group, test.total) if res != test.expect { t.Errorf("%d) Mismatch between expect and result:\n%s\n%s\n", i, test.expect, res) } @@ -497,6 +501,14 @@ func TestConvertInQuestionMarks(t *testing.T) { t.Errorf("%d) Expected %d, got %d", i, test.total, count) } } + + res, count := convertInQuestionMarks(false, "?", 1, 3, 9) + if res != "((?,?,?),(?,?,?),(?,?,?))" { + t.Errorf("Mismatch between expected and result: %s", res) + } + if count != 9 { + t.Errorf("Expected 9 results, got %d", count) + } } func TestWriteAsStatements(t *testing.T) { @@ -512,6 +524,7 @@ func TestWriteAsStatements(t *testing.T) { `a.clown.run`, `COUNT(a)`, }, + dialect: &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true}, } expect := []string{ diff --git a/boil/reflect_test.go b/boil/reflect_test.go index dab762a92..279e98b9f 100644 --- a/boil/reflect_test.go +++ b/boil/reflect_test.go @@ -44,7 +44,8 @@ func TestBindStruct(t *testing.T) { }{} query := &Query{ - from: []string{"fun"}, + from: []string{"fun"}, + dialect: &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true}, } db, mock, err := sqlmock.New() @@ -83,7 +84,8 @@ func TestBindSlice(t *testing.T) { }{} query := &Query{ - from: []string{"fun"}, + from: []string{"fun"}, + dialect: &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true}, } db, mock, err := sqlmock.New() @@ -133,7 +135,8 @@ func TestBindPtrSlice(t *testing.T) { }{} query := &Query{ - from: []string{"fun"}, + from: []string{"fun"}, + dialect: &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true}, } db, mock, err := sqlmock.New() @@ -369,7 +372,8 @@ func TestBindSingular(t *testing.T) { }{} query := &Query{ - from: []string{"fun"}, + from: []string{"fun"}, + dialect: &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true}, } db, mock, err := sqlmock.New() @@ -412,8 +416,9 @@ func TestBind_InnerJoin(t *testing.T) { }{} query := &Query{ - from: []string{"fun"}, - joins: []join{{kind: JoinInner, clause: "happy as h on fun.id = h.fun_id"}}, + from: []string{"fun"}, + joins: []join{{kind: JoinInner, clause: "happy as h on fun.id = h.fun_id"}}, + dialect: &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true}, } db, mock, err := sqlmock.New() diff --git a/strmangle/strmangle.go b/strmangle/strmangle.go index 7ce994e67..82ff09eea 100644 --- a/strmangle/strmangle.go +++ b/strmangle/strmangle.go @@ -38,17 +38,17 @@ func init() { // using a database that supports real schemas, for example, // for Postgres: "schema_name"."table_name", versus // simply "table_name" for MySQL (because it does not support real schemas) -func SchemaTable(driver string, schema string, table string) string { +func SchemaTable(lq byte, rq byte, driver string, schema string, table string) string { if driver == "postgres" && schema != "public" { - return fmt.Sprintf(`"%s"."%s"`, schema, table) + return fmt.Sprintf(`%c%s%c.%c%s%c`, lq, schema, rq, lq, table, rq) } - return fmt.Sprintf(`"%s"`, table) + return fmt.Sprintf(`%c%s%c`, lq, table, rq) } // IdentQuote attempts to quote simple identifiers in SQL tatements func IdentQuote(lq byte, rq byte, s string) string { - if strings.ToLower(s) == "null" { + if strings.ToLower(s) == "null" || s == "?" { return s } @@ -65,14 +65,14 @@ func IdentQuote(lq byte, rq byte, s string) string { buf.WriteByte('.') } - if strings.HasPrefix(split, `"`) || strings.HasSuffix(split, `"`) || split == "*" { + if split[0] == lq || split[len(split)-1] == rq || split == "*" { buf.WriteString(split) continue } - buf.WriteByte('"') + buf.WriteByte(lq) buf.WriteString(split) - buf.WriteByte('"') + buf.WriteByte(rq) } return buf.String() @@ -86,7 +86,7 @@ func IdentQuoteSlice(lq byte, rq byte, s []string) []string { strs := make([]string, len(s)) for i, str := range s { - strs[i] = IdentQuote(str) + strs[i] = IdentQuote(lq, rq, str) } return strs @@ -381,7 +381,8 @@ func PrefixStringSlice(str string, strs []string) []string { // Placeholders generates the SQL statement placeholders for in queries. // For example, ($1,$2,$3),($4,$5,$6) etc. // It will start counting placeholders at "start". -func Placeholders(count int, start int, group int) string { +// If indexPlaceholders is false, it will convert to ? instead of $1 etc. +func Placeholders(indexPlaceholders bool, count int, start int, group int) string { buf := GetBuffer() defer PutBuffer(buf) @@ -400,7 +401,11 @@ func Placeholders(count int, start int, group int) string { buf.WriteByte(',') } } - buf.WriteString(fmt.Sprintf("$%d", start+i)) + if indexPlaceholders { + buf.WriteString(fmt.Sprintf("$%d", start+i)) + } else { + buf.WriteByte('?') + } } if group > 1 { buf.WriteByte(')') diff --git a/strmangle/strmangle_test.go b/strmangle/strmangle_test.go index c3626e2ad..5311ac906 100644 --- a/strmangle/strmangle_test.go +++ b/strmangle/strmangle_test.go @@ -29,7 +29,7 @@ func TestIdentQuote(t *testing.T) { } for _, test := range tests { - if got := IdentQuote(test.In); got != test.Out { + if got := IdentQuote('"', '"', test.In); got != test.Out { t.Errorf("want: %s, got: %s", test.Out, got) } } @@ -38,7 +38,7 @@ func TestIdentQuote(t *testing.T) { func TestIdentQuoteSlice(t *testing.T) { t.Parallel() - ret := IdentQuoteSlice([]string{`thing`, `null`}) + ret := IdentQuoteSlice('"', '"', []string{`thing`, `null`}) if ret[0] != `"thing"` { t.Error(ret[0]) } @@ -72,31 +72,43 @@ func TestIdentifier(t *testing.T) { func TestPlaceholders(t *testing.T) { t.Parallel() - x := Placeholders(1, 2, 1) + x := Placeholders(true, 1, 2, 1) want := "$2" if want != x { t.Errorf("want %s, got %s", want, x) } - x = Placeholders(5, 1, 1) + x = Placeholders(true, 5, 1, 1) want = "$1,$2,$3,$4,$5" if want != x { t.Errorf("want %s, got %s", want, x) } - x = Placeholders(6, 1, 2) + x = Placeholders(false, 5, 1, 1) + want = "?,?,?,?,?" + if want != x { + t.Errorf("want %s, got %s", want, x) + } + + x = Placeholders(true, 6, 1, 2) + want = "($1,$2),($3,$4),($5,$6)" + if want != x { + t.Errorf("want %s, got %s", want, x) + } + + x = Placeholders(true, 6, 1, 2) want = "($1,$2),($3,$4),($5,$6)" if want != x { t.Errorf("want %s, got %s", want, x) } - x = Placeholders(9, 1, 3) - want = "($1,$2,$3),($4,$5,$6),($7,$8,$9)" + x = Placeholders(false, 9, 1, 3) + want = "(?,?,?),(?,?,?),(?,?,?)" if want != x { t.Errorf("want %s, got %s", want, x) } - x = Placeholders(7, 1, 3) + x = Placeholders(true, 7, 1, 3) want = "($1,$2,$3),($4,$5,$6),($7)" if want != x { t.Errorf("want %s, got %s", want, x) diff --git a/templates/04_relationship_to_one.tpl b/templates/04_relationship_to_one.tpl index ca872a566..8d9f10c74 100644 --- a/templates/04_relationship_to_one.tpl +++ b/templates/04_relationship_to_one.tpl @@ -16,7 +16,7 @@ func ({{.Function.Receiver}} *{{.LocalTable.NameGo}}) {{.Function.Name}}(exec bo queryMods = append(queryMods, mods...) query := {{.ForeignTable.NamePluralGo}}(exec, queryMods...) - boil.SetFrom(query.Query, `{{schemaTable $tmplData.DriverName $tmplData.Schema .ForeignTable.Name}}`) + boil.SetFrom(query.Query, `{{schemaTable $tmplData.Dialect.LQ $tmplData.Dialect.RQ $tmplData.DriverName $tmplData.Schema .ForeignTable.Name}}`) return query } diff --git a/templates/05_relationship_to_many.tpl b/templates/05_relationship_to_many.tpl index db7e81525..e52dc1088 100644 --- a/templates/05_relationship_to_many.tpl +++ b/templates/05_relationship_to_many.tpl @@ -31,7 +31,7 @@ func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) {{$rel.Function.Na {{if .ToJoinTable -}} queryMods = append(queryMods, - qm.InnerJoin(`{{schemaTable $dot.DriverName $dot.Schema .JoinTable}} as "{{id 1}}" on "{{id 0}}"."{{.ForeignColumn}}" = "{{id 1}}"."{{.JoinForeignColumn}}"`), + qm.InnerJoin(`{{schemaTable $dot.Dialect.LQ $dot.Dialect.RQ $dot.DriverName $dot.Schema .JoinTable}} as "{{id 1}}" on "{{id 0}}"."{{.ForeignColumn}}" = "{{id 1}}"."{{.JoinForeignColumn}}"`), qm.Where(`"{{id 1}}"."{{.JoinLocalColumn}}"=$1`, {{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}), ) {{else -}} @@ -41,7 +41,7 @@ func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) {{$rel.Function.Na {{end}} query := {{$rel.ForeignTable.NamePluralGo}}(exec, queryMods...) - boil.SetFrom(query.Query, `{{schemaTable $dot.DriverName $dot.Schema .ForeignTable}} as "{{id 0}}"`) + boil.SetFrom(query.Query, `{{schemaTable $dot.Dialect.LQ $dot.Dialect.RQ $dot.DriverName $dot.Schema .ForeignTable}} as "{{id 0}}"`) return query } diff --git a/templates/06_relationship_to_one_eager.tpl b/templates/06_relationship_to_one_eager.tpl index a14cc438a..f8c55eed8 100644 --- a/templates/06_relationship_to_one_eager.tpl +++ b/templates/06_relationship_to_one_eager.tpl @@ -28,8 +28,8 @@ func ({{$varNameSingular}}L) Load{{.Function.Name}}(e boil.Executor, singular bo } query := fmt.Sprintf( - `select * from {{schemaTable $tmplData.DriverName $tmplData.Schema .ForeignKey.ForeignTable}} where "{{.ForeignKey.ForeignColumn}}" in (%s)`, - strmangle.Placeholders(count, 1, 1), + `select * from {{schemaTable $tmplData.Dialect.LQ $tmplData.Dialect.RQ $tmplData.DriverName $tmplData.Schema .ForeignKey.ForeignTable}} where "{{.ForeignKey.ForeignColumn}}" in (%s)`, + strmangle.Placeholders(dialect.IndexPlaceholders, count, 1, 1), ) if boil.DebugMode { diff --git a/templates/07_relationship_to_many_eager.tpl b/templates/07_relationship_to_many_eager.tpl index d99953e77..a79d75f63 100644 --- a/templates/07_relationship_to_many_eager.tpl +++ b/templates/07_relationship_to_many_eager.tpl @@ -38,13 +38,13 @@ func ({{$varNameSingular}}L) Load{{$txt.Function.Name}}(e boil.Executor, singula {{if .ToJoinTable -}} query := fmt.Sprintf( - `select "{{id 0}}".*, "{{id 1}}"."{{.JoinLocalColumn}}" from {{schemaTable $dot.DriverName $dot.Schema .ForeignTable}} as "{{id 0}}" inner join {{schemaTable $dot.DriverName $dot.Schema .JoinTable}} as "{{id 1}}" on "{{id 0}}"."{{.ForeignColumn}}" = "{{id 1}}"."{{.JoinForeignColumn}}" where "{{id 1}}"."{{.JoinLocalColumn}}" in (%s)`, - strmangle.Placeholders(count, 1, 1), + `select "{{id 0}}".*, "{{id 1}}"."{{.JoinLocalColumn}}" from {{schemaTable $dot.Dialect.LQ $dot.Dialect.RQ $dot.DriverName $dot.Schema .ForeignTable}} as "{{id 0}}" inner join {{schemaTable $dot.Dialect.LQ $dot.Dialect.RQ $dot.DriverName $dot.Schema .JoinTable}} as "{{id 1}}" on "{{id 0}}"."{{.ForeignColumn}}" = "{{id 1}}"."{{.JoinForeignColumn}}" where "{{id 1}}"."{{.JoinLocalColumn}}" in (%s)`, + strmangle.Placeholders(dialect.IndexPlaceholders, count, 1, 1), ) {{else -}} query := fmt.Sprintf( - `select * from {{schemaTable $dot.DriverName $dot.Schema .ForeignTable}} where "{{.ForeignColumn}}" in (%s)`, - strmangle.Placeholders(count, 1, 1), + `select * from {{schemaTable $dot.Dialect.LQ $dot.Dialect.RQ $dot.DriverName $dot.Schema .ForeignTable}} where "{{.ForeignColumn}}" in (%s)`, + strmangle.Placeholders(dialect.IndexPlaceholders, count, 1, 1), ) {{end -}} diff --git a/templates/09_relationship_to_many_setops.tpl b/templates/09_relationship_to_many_setops.tpl index 1cc4c11ba..33e5ce468 100644 --- a/templates/09_relationship_to_many_setops.tpl +++ b/templates/09_relationship_to_many_setops.tpl @@ -39,7 +39,7 @@ func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) Add{{$rel.Function {{if .ToJoinTable -}} for _, rel := range related { - query := `insert into {{schemaTable $dot.DriverName $dot.Schema .JoinTable}} ({{.JoinLocalColumn}}, {{.JoinForeignColumn}}) values ($1, $2)` + query := `insert into {{schemaTable $dot.Dialect.LQ $dot.Dialect.RQ $dot.DriverName $dot.Schema .JoinTable}} ({{.JoinLocalColumn}}, {{.JoinForeignColumn}}) values ($1, $2)` values := []interface{}{{"{"}}{{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}, rel.{{$rel.ForeignTable.ColumnNameGo}}} if boil.DebugMode { @@ -96,10 +96,10 @@ func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) Add{{$rel.Function // Sets related.R.{{$rel.Function.ForeignName}}'s {{$rel.Function.Name}} accordingly. func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) Set{{$rel.Function.Name}}(exec boil.Executor, insert bool, related ...*{{$rel.ForeignTable.NameGo}}) error { {{if .ToJoinTable -}} - query := `delete from {{schemaTable $dot.DriverName $dot.Schema .JoinTable}} where "{{.JoinLocalColumn}}" = $1` + query := `delete from {{schemaTable $dot.Dialect.LQ $dot.Dialect.RQ $dot.DriverName $dot.Schema .JoinTable}} where "{{.JoinLocalColumn}}" = $1` values := []interface{}{{"{"}}{{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}} {{else -}} - query := `update {{schemaTable $dot.DriverName $dot.Schema .ForeignTable}} set "{{.ForeignColumn}}" = null where "{{.ForeignColumn}}" = $1` + query := `update {{schemaTable $dot.Dialect.LQ $dot.Dialect.RQ $dot.DriverName $dot.Schema .ForeignTable}} set "{{.ForeignColumn}}" = null where "{{.ForeignColumn}}" = $1` values := []interface{}{{"{"}}{{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}} {{end -}} if boil.DebugMode { @@ -140,8 +140,8 @@ func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) Remove{{$rel.Funct var err error {{if .ToJoinTable -}} query := fmt.Sprintf( - `delete from {{schemaTable $dot.DriverName $dot.Schema .JoinTable}} where "{{.JoinLocalColumn}}" = $1 and "{{.JoinForeignColumn}}" in (%s)`, - strmangle.Placeholders(len(related), 1, 1), + `delete from {{schemaTable $dot.Dialect.LQ $dot.Dialect.RQ $dot.DriverName $dot.Schema .JoinTable}} where "{{.JoinLocalColumn}}" = $1 and "{{.JoinForeignColumn}}" in (%s)`, + strmangle.Placeholders(dialect.IndexPlaceholders, len(related), 1, 1), ) values := []interface{}{{"{"}}{{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}} diff --git a/templates/10_all.tpl b/templates/10_all.tpl index 01489254f..62ddbeb7b 100644 --- a/templates/10_all.tpl +++ b/templates/10_all.tpl @@ -8,6 +8,6 @@ func {{$tableNamePlural}}G(mods ...qm.QueryMod) {{$varNameSingular}}Query { // {{$tableNamePlural}} retrieves all the records using an executor. func {{$tableNamePlural}}(exec boil.Executor, mods ...qm.QueryMod) {{$varNameSingular}}Query { - mods = append(mods, qm.From(`{{schemaTable .DriverName .Schema .Table.Name}}`)) + mods = append(mods, qm.From(`{{schemaTable .Dialect.LQ .Dialect.RQ .DriverName .Schema .Table.Name}}`)) return {{$varNameSingular}}Query{NewQuery(exec, mods...)} } diff --git a/templates/11_find.tpl b/templates/11_find.tpl index 6cbe97465..dfdc5b2a7 100644 --- a/templates/11_find.tpl +++ b/templates/11_find.tpl @@ -26,10 +26,10 @@ func Find{{$tableNameSingular}}(exec boil.Executor, {{$pkArgs}}, selectCols ...s sel := "*" if len(selectCols) > 0 { - sel = strings.Join(strmangle.IdentQuoteSlice(selectCols), ",") + sel = strings.Join(strmangle.IdentQuoteSlice(dialect.LQ, dialect.RQ, selectCols), ",") } query := fmt.Sprintf( - `select %s from {{schemaTable .DriverName .Schema .Table.Name}} where {{whereClause 1 .Table.PKey.Columns}}`, sel, + `select %s from {{schemaTable .Dialect.LQ .Dialect.RQ .DriverName .Schema .Table.Name}} where {{whereClause 1 .Table.PKey.Columns}}`, sel, ) q := boil.SQL(exec, query, {{$pkNames | join ", "}}) diff --git a/templates/12_insert.tpl b/templates/12_insert.tpl index 4300124a7..dba01c4d3 100644 --- a/templates/12_insert.tpl +++ b/templates/12_insert.tpl @@ -64,11 +64,11 @@ func (o *{{$tableNameSingular}}) Insert(exec boil.Executor, whitelist ... string if err != nil { return err } - cache.query = fmt.Sprintf(`INSERT INTO {{schemaTable .DriverName .Schema .Table.Name}} ("%s") VALUES (%s)`, strings.Join(wl, `","`), strmangle.Placeholders(len(wl), 1, 1)) + cache.query = fmt.Sprintf(`INSERT INTO {{schemaTable .Dialect.LQ .Dialect.RQ .DriverName .Schema .Table.Name}} ("%s") VALUES (%s)`, strings.Join(wl, `","`), strmangle.Placeholders(dialect.IndexPlaceholders, len(wl), 1, 1)) if len(cache.retMapping) != 0 { {{if .UseLastInsertID -}} - cache.retQuery = fmt.Sprintf(`SELECT %s FROM {{schemaTable .DriverName .Schema .Table.Name}} WHERE %s`, strings.Join(returnColumns, `","`), strmangle.WhereClause(1, {{$varNameSingular}}PrimaryKeyColumns)) + cache.retQuery = fmt.Sprintf(`SELECT %s FROM {{schemaTable .Dialect.LQ .Dialect.RQ .DriverName .Schema .Table.Name}} WHERE %s`, strings.Join(returnColumns, `","`), strmangle.WhereClause(1, {{$varNameSingular}}PrimaryKeyColumns)) {{else -}} cache.query += fmt.Sprintf(` RETURNING %s`, strings.Join(returnColumns, ",")) {{end -}} diff --git a/templates/13_update.tpl b/templates/13_update.tpl index ef623ace2..1aea39ad9 100644 --- a/templates/13_update.tpl +++ b/templates/13_update.tpl @@ -52,7 +52,7 @@ func (o *{{$tableNameSingular}}) Update(exec boil.Executor, whitelist ... string if !cached { wl := strmangle.UpdateColumnSet({{$varNameSingular}}Columns, {{$varNameSingular}}PrimaryKeyColumns, whitelist) - cache.query = fmt.Sprintf(`UPDATE {{schemaTable .DriverName .Schema .Table.Name}} SET %s WHERE %s`, strmangle.SetParamNames(wl), strmangle.WhereClause(len(wl)+1, {{$varNameSingular}}PrimaryKeyColumns)) + cache.query = fmt.Sprintf(`UPDATE {{schemaTable .Dialect.LQ .Dialect.RQ .DriverName .Schema .Table.Name}} SET %s WHERE %s`, strmangle.SetParamNames(wl), strmangle.WhereClause(len(wl)+1, {{$varNameSingular}}PrimaryKeyColumns)) cache.valueMapping, err = boil.BindMapping({{$varNameSingular}}Type, {{$varNameSingular}}Mapping, append(wl, {{$varNameSingular}}PrimaryKeyColumns...)) if err != nil { return err @@ -146,7 +146,7 @@ func (o {{$tableNameSingular}}Slice) UpdateAll(exec boil.Executor, cols M) error i := 0 for name, value := range cols { - colNames[i] = strmangle.IdentQuote(name) + colNames[i] = strmangle.IdentQuote(dialect.LQ, dialect.RQ, name) args[i] = value i++ } @@ -155,11 +155,11 @@ func (o {{$tableNameSingular}}Slice) UpdateAll(exec boil.Executor, cols M) error args = append(args, o.inPrimaryKeyArgs()...) sql := fmt.Sprintf( - `UPDATE {{schemaTable .DriverName .Schema .Table.Name}} SET (%s) = (%s) WHERE (%s) IN (%s)`, + `UPDATE {{schemaTable .Dialect.LQ .Dialect.RQ .DriverName .Schema .Table.Name}} SET (%s) = (%s) WHERE (%s) IN (%s)`, strings.Join(colNames, ", "), - strmangle.Placeholders(len(colNames), 1, 1), - strings.Join(strmangle.IdentQuoteSlice({{$varNameSingular}}PrimaryKeyColumns), ","), - strmangle.Placeholders(len(o) * len({{$varNameSingular}}PrimaryKeyColumns), len(colNames)+1, len({{$varNameSingular}}PrimaryKeyColumns)), + strmangle.Placeholders(dialect.IndexPlaceholders, len(colNames), 1, 1), + strings.Join(strmangle.IdentQuoteSlice(dialect.LQ, dialect.RQ, {{$varNameSingular}}PrimaryKeyColumns), ","), + strmangle.Placeholders(dialect.IndexPlaceholders, len(o) * len({{$varNameSingular}}PrimaryKeyColumns), len(colNames)+1, len({{$varNameSingular}}PrimaryKeyColumns)), ) if boil.DebugMode { diff --git a/templates/14_upsert.tpl b/templates/14_upsert.tpl index 55a50c126..2f41298a8 100644 --- a/templates/14_upsert.tpl +++ b/templates/14_upsert.tpl @@ -54,7 +54,7 @@ func (o *{{$tableNameSingular}}) Upsert(exec boil.Executor, updateOnConflict boo copy(conflict, {{$varNameSingular}}PrimaryKeyColumns) } - query := boil.BuildUpsertQuery("{{.Table.Name}}", updateOnConflict, ret, update, conflict, whitelist) + query := boil.BuildUpsertQuery(dialect, "{{.Table.Name}}", updateOnConflict, ret, update, conflict, whitelist) if boil.DebugMode { fmt.Fprintln(boil.DebugWriter, query) diff --git a/templates/15_delete.tpl b/templates/15_delete.tpl index ebcc7448f..367e74ac2 100644 --- a/templates/15_delete.tpl +++ b/templates/15_delete.tpl @@ -43,7 +43,7 @@ func (o *{{$tableNameSingular}}) Delete(exec boil.Executor) error { args := o.inPrimaryKeyArgs() - sql := `DELETE FROM {{schemaTable .DriverName .Schema .Table.Name}} WHERE {{whereClause 1 .Table.PKey.Columns}}` + sql := `DELETE FROM {{schemaTable .Dialect.LQ .Dialect.RQ .DriverName .Schema .Table.Name}} WHERE {{whereClause 1 .Table.PKey.Columns}}` if boil.DebugMode { fmt.Fprintln(boil.DebugWriter, sql) @@ -132,9 +132,9 @@ func (o {{$tableNameSingular}}Slice) DeleteAll(exec boil.Executor) error { args := o.inPrimaryKeyArgs() sql := fmt.Sprintf( - `DELETE FROM {{schemaTable .DriverName .Schema .Table.Name}} WHERE (%s) IN (%s)`, - strings.Join(strmangle.IdentQuoteSlice({{$varNameSingular}}PrimaryKeyColumns), ","), - strmangle.Placeholders(len(o) * len({{$varNameSingular}}PrimaryKeyColumns), 1, len({{$varNameSingular}}PrimaryKeyColumns)), + `DELETE FROM {{schemaTable .Dialect.LQ .Dialect.RQ .DriverName .Schema .Table.Name}} WHERE (%s) IN (%s)`, + strings.Join(strmangle.IdentQuoteSlice(dialect.LQ, dialect.RQ, {{$varNameSingular}}PrimaryKeyColumns), ","), + strmangle.Placeholders(dialect.IndexPlaceholders, len(o) * len({{$varNameSingular}}PrimaryKeyColumns), 1, len({{$varNameSingular}}PrimaryKeyColumns)), ) if boil.DebugMode { diff --git a/templates/16_reload.tpl b/templates/16_reload.tpl index f2d18a9db..ac00577cc 100644 --- a/templates/16_reload.tpl +++ b/templates/16_reload.tpl @@ -67,9 +67,9 @@ func (o *{{$tableNameSingular}}Slice) ReloadAll(exec boil.Executor) error { args := o.inPrimaryKeyArgs() sql := fmt.Sprintf( - `SELECT {{schemaTable .DriverName .Schema .Table.Name}}.* FROM {{schemaTable .DriverName .Schema .Table.Name}} WHERE (%s) IN (%s)`, - strings.Join(strmangle.IdentQuoteSlice({{$varNameSingular}}PrimaryKeyColumns), ","), - strmangle.Placeholders(len(*o) * len({{$varNameSingular}}PrimaryKeyColumns), 1, len({{$varNameSingular}}PrimaryKeyColumns)), + `SELECT {{schemaTable .Dialect.LQ .Dialect.RQ .DriverName .Schema .Table.Name}}.* FROM {{schemaTable .Dialect.LQ .Dialect.RQ .DriverName .Schema .Table.Name}} WHERE (%s) IN (%s)`, + strings.Join(strmangle.IdentQuoteSlice(dialect.LQ, dialect.RQ, {{$varNameSingular}}PrimaryKeyColumns), ","), + strmangle.Placeholders(dialect.IndexPlaceholders, len(*o) * len({{$varNameSingular}}PrimaryKeyColumns), 1, len({{$varNameSingular}}PrimaryKeyColumns)), ) q := boil.SQL(exec, sql, args...) diff --git a/templates/17_exists.tpl b/templates/17_exists.tpl index 68b972906..713a8bbaf 100644 --- a/templates/17_exists.tpl +++ b/templates/17_exists.tpl @@ -6,7 +6,7 @@ func {{$tableNameSingular}}Exists(exec boil.Executor, {{$pkArgs}}) (bool, error) { var exists bool - sql := `select exists(select 1 from {{schemaTable .DriverName .Schema .Table.Name}} where {{whereClause 1 .Table.PKey.Columns}} limit 1)` + sql := `select exists(select 1 from {{schemaTable .Dialect.LQ .Dialect.RQ .DriverName .Schema .Table.Name}} where {{whereClause 1 .Table.PKey.Columns}} limit 1)` if boil.DebugMode { fmt.Fprintln(boil.DebugWriter, sql) diff --git a/templates/singleton/boil_queries.tpl b/templates/singleton/boil_queries.tpl index eec7351c0..4db7187fa 100644 --- a/templates/singleton/boil_queries.tpl +++ b/templates/singleton/boil_queries.tpl @@ -1,4 +1,4 @@ -var dialect boil.Dialect = boil.Dialect{ +var dialect = boil.Dialect{ LQ: 0x{{printf "%x" .Dialect.LQ}}, RQ: 0x{{printf "%x" .Dialect.RQ}}, IndexPlaceholders: {{.Dialect.IndexPlaceholders}}, diff --git a/templates_test/relationship_to_many.tpl b/templates_test/relationship_to_many.tpl index 85c9da83f..e1c4ee0cb 100644 --- a/templates_test/relationship_to_many.tpl +++ b/templates_test/relationship_to_many.tpl @@ -41,11 +41,11 @@ func test{{$rel.LocalTable.NameGo}}ToMany{{$rel.Function.Name}}(t *testing.T) { } {{if .ToJoinTable -}} - _, err = tx.Exec(`insert into {{schemaTable $dot.DriverName $dot.Schema .JoinTable}} ({{.JoinLocalColumn}}, {{.JoinForeignColumn}}) values ($1, $2)`, a.{{$rel.LocalTable.ColumnNameGo}}, b.{{$rel.ForeignTable.ColumnNameGo}}) + _, err = tx.Exec(`insert into {{schemaTable $dot.Dialect.LQ $dot.Dialect.RQ $dot.DriverName $dot.Schema .JoinTable}} ({{.JoinLocalColumn}}, {{.JoinForeignColumn}}) values ($1, $2)`, a.{{$rel.LocalTable.ColumnNameGo}}, b.{{$rel.ForeignTable.ColumnNameGo}}) if err != nil { t.Fatal(err) } - _, err = tx.Exec(`insert into {{schemaTable $dot.DriverName $dot.Schema .JoinTable}} ({{.JoinLocalColumn}}, {{.JoinForeignColumn}}) values ($1, $2)`, a.{{$rel.LocalTable.ColumnNameGo}}, c.{{$rel.ForeignTable.ColumnNameGo}}) + _, err = tx.Exec(`insert into {{schemaTable $dot.Dialect.LQ $dot.Dialect.RQ $dot.DriverName $dot.Schema .JoinTable}} ({{.JoinLocalColumn}}, {{.JoinForeignColumn}}) values ($1, $2)`, a.{{$rel.LocalTable.ColumnNameGo}}, c.{{$rel.ForeignTable.ColumnNameGo}}) if err != nil { t.Fatal(err) } From e62dfe369ff8cfd00fc2947db2767804f1f3cd2a Mon Sep 17 00:00:00 2001 From: Patrick O'brien Date: Mon, 12 Sep 2016 03:40:59 +1000 Subject: [PATCH 27/64] Add array types and hstore types --- README.md | 10 + bdb/column.go | 5 + bdb/drivers/mysql.go | 24 +- bdb/drivers/postgres.go | 79 ++- boil/randomize/randomize.go | 126 +++- boil/types/array.go | 863 +++++++++++++++++++++++++++ boil/types/array_test.go | 1125 +++++++++++++++++++++++++++++++++++ boil/types/hstore.go | 135 +++++ imports.go | 18 + 9 files changed, 2321 insertions(+), 64 deletions(-) create mode 100644 boil/types/array.go create mode 100644 boil/types/array_test.go create mode 100644 boil/types/hstore.go diff --git a/README.md b/README.md index 7beba2229..3479428a4 100644 --- a/README.md +++ b/README.md @@ -1061,6 +1061,16 @@ If your database uses multiple schemas you should generate a new package for eac Note that this only applies to databases that use real, SQL standard schemas (like PostgreSQL), not fake schemas (like MySQL). +#### How do I use types.BytesArray for Postgres bytea arrays? + +Only "escaped format" is supported for types.BytesArray. This means that your byte slice needs to have +a format of "\\x00" (4 bytes per byte) opposed to "\x00" (1 byte per byte). This is to maintain compatibility +with all Postgres drivers. Example: + +`x := types.BytesArray{0: []byte("\\x68\\x69")}` + +Please note that multi-dimensional Postgres ARRAY types are not supported at this time. + #### Where is the homepage? The homepage for the [SQLBoiler](https://github.com/vattle/sqlboiler) [Golang ORM](https://github.com/vattle/sqlboiler) generator is located at: https://github.com/vattle/sqlboiler diff --git a/bdb/column.go b/bdb/column.go index 688bd95bc..b6cd3359d 100644 --- a/bdb/column.go +++ b/bdb/column.go @@ -5,6 +5,11 @@ import "github.com/vattle/sqlboiler/strmangle" // Column holds information about a database column. // Types are Go types, converted by TranslateColumnType. type Column struct { + // ArrType is the underlying data type of the Postgres + // ARRAY type. See here: + // https://www.postgresql.org/docs/9.1/static/infoschema-element-types.html + ArrType *string + UDTName string Name string Type string DBType string diff --git a/bdb/drivers/mysql.go b/bdb/drivers/mysql.go index 0cb836149..bee699d12 100644 --- a/bdb/drivers/mysql.go +++ b/bdb/drivers/mysql.go @@ -148,12 +148,11 @@ func (m *MySQLDriver) Columns(schema, tableName string) ([]bdb.Column, error) { } column := bdb.Column{ - Name: colName, - DBType: colType, - Default: colDefault, - Nullable: nullable == "YES", - Unique: unique, - Validated: psqlIsValidated(colType), + Name: colName, + DBType: colType, + Default: colDefault, + Nullable: nullable == "YES", + Unique: unique, } columns = append(columns, column) } @@ -306,19 +305,6 @@ func (m *MySQLDriver) TranslateColumnType(c bdb.Column) bdb.Column { return c } -var mySQLValidatedTypes = []string{} - -// isValidated checks if the database type is in the validatedTypes list. -func mySQLIsValidated(typ string) bool { - for _, v := range mySQLValidatedTypes { - if v == typ { - return true - } - } - - return false -} - // RightQuote is the quoting character for the right side of the identifier func (m *MySQLDriver) RightQuote() byte { return '`' diff --git a/bdb/drivers/postgres.go b/bdb/drivers/postgres.go index 8615013a5..77dbc85fd 100644 --- a/bdb/drivers/postgres.go +++ b/bdb/drivers/postgres.go @@ -19,9 +19,6 @@ type PostgresDriver struct { dbConn *sql.DB } -// validatedTypes are types that cannot be zero values in the database. -var psqlValidatedTypes = []string{"uuid"} - // NewPostgresDriver takes the database connection details as parameters and // returns a pointer to a PostgresDriver object. Note that it is required to // call PostgresDriver.Open() and PostgresDriver.Close() to open and close @@ -126,7 +123,7 @@ func (p *PostgresDriver) Columns(schema, tableName string) ([]bdb.Column, error) var columns []bdb.Column rows, err := p.dbConn.Query(` - select column_name, data_type, column_default, is_nullable, + select column_name, c.data_type, e.data_type, column_default, c.udt_name, is_nullable, (select exists( select 1 from information_schema.constraint_column_usage as ccu @@ -142,8 +139,10 @@ func (p *PostgresDriver) Columns(schema, tableName string) ([]bdb.Column, error) where pgix.schemaname = $1 and pgix.tablename = c.table_name and pga.attname = c.column_name and pgi.indisunique = true )) as is_unique - from information_schema.columns as c - where table_name=$2 and table_schema = $1; + from information_schema.columns as c LEFT JOIN information_schema.element_types e + ON ((c.table_catalog, c.table_schema, c.table_name, 'TABLE', c.dtd_identifier) + = (e.object_catalog, e.object_schema, e.object_name, e.object_type, e.collection_type_identifier)) + where c.table_name=$2 and c.table_schema = $1; `, schema, tableName) if err != nil { @@ -152,10 +151,11 @@ func (p *PostgresDriver) Columns(schema, tableName string) ([]bdb.Column, error) defer rows.Close() for rows.Next() { - var colName, colType, colDefault, nullable string + var colName, udtName, colType, colDefault, nullable string + var elementType *string var unique bool var defaultPtr *string - if err := rows.Scan(&colName, &colType, &defaultPtr, &nullable, &unique); err != nil { + if err := rows.Scan(&colName, &colType, &elementType, &defaultPtr, &udtName, &nullable, &unique); err != nil { return nil, errors.Wrapf(err, "unable to scan for table %s", tableName) } @@ -166,12 +166,13 @@ func (p *PostgresDriver) Columns(schema, tableName string) ([]bdb.Column, error) } column := bdb.Column{ - Name: colName, - DBType: colType, - Default: colDefault, - Nullable: nullable == "YES", - Unique: unique, - Validated: psqlIsValidated(colType), + Name: colName, + DBType: colType, + ArrType: elementType, + UDTName: udtName, + Default: colDefault, + Nullable: nullable == "YES", + Unique: unique, } columns = append(columns, column) } @@ -297,6 +298,21 @@ func (p *PostgresDriver) TranslateColumnType(c bdb.Column) bdb.Column { c.Type = "null.Bool" case "date", "time", "timestamp without time zone", "timestamp with time zone": c.Type = "null.Time" + case "ARRAY": + if c.ArrType == nil { + panic("unable to get postgres ARRAY underlying type") + } + c.Type = getArrayType(c) + // Make DBType something like ARRAYinteger for parsing with randomize.Struct + c.DBType = c.DBType + *c.ArrType + case "USER-DEFINED": + if c.UDTName == "hstore" { + c.Type = "types.Hstore" + c.DBType = "hstore" + } else { + c.Type = "string" + fmt.Printf("Warning: Incompatible data type detected: %s", c.UDTName) + } default: c.Type = "null.String" } @@ -322,6 +338,18 @@ func (p *PostgresDriver) TranslateColumnType(c bdb.Column) bdb.Column { c.Type = "bool" case "date", "time", "timestamp without time zone", "timestamp with time zone": c.Type = "time.Time" + case "ARRAY": + c.Type = getArrayType(c) + // Make DBType something like ARRAYinteger for parsing with randomize.Struct + c.DBType = c.DBType + *c.ArrType + case "USER-DEFINED": + if c.UDTName == "hstore" { + c.Type = "types.Hstore" + c.DBType = "hstore" + } else { + c.Type = "string" + fmt.Printf("Warning: Incompatible data type detected: %s", c.UDTName) + } default: c.Type = "string" } @@ -330,15 +358,22 @@ func (p *PostgresDriver) TranslateColumnType(c bdb.Column) bdb.Column { return c } -// isValidated checks if the database type is in the validatedTypes list. -func psqlIsValidated(typ string) bool { - for _, v := range psqlValidatedTypes { - if v == typ { - return true - } +// getArrayType returns the correct boil.Array type for each database type +func getArrayType(c bdb.Column) string { + switch *c.ArrType { + case "bigint", "bigserial", "integer", "serial", "smallint", "smallserial": + return "types.Int64Array" + case "bytea": + return "types.BytesArray" + case "bit", "interval", "uuint", "bit varying", "character", "money", "character varying", "cidr", "inet", "macaddr", "text", "uuid", "xml": + return "types.StringArray" + case "bool": + return "types.BoolArray" + case "decimal", "numeric", "double precision", "real": + return "types.Float64Array" + default: + return "types.GenericArray" } - - return false } // RightQuote is the quoting character for the right side of the identifier diff --git a/boil/randomize/randomize.go b/boil/randomize/randomize.go index 398d5e82c..255043855 100644 --- a/boil/randomize/randomize.go +++ b/boil/randomize/randomize.go @@ -2,17 +2,20 @@ package randomize import ( + "database/sql" "fmt" "math/rand" "reflect" "regexp" "sort" "strconv" + "strings" "sync/atomic" "time" "gopkg.in/nullbio/null.v5" + "github.com/lib/pq/hstore" "github.com/pkg/errors" "github.com/satori/go.uuid" "github.com/vattle/sqlboiler/boil/types" @@ -20,32 +23,39 @@ import ( ) var ( - typeNullFloat32 = reflect.TypeOf(null.Float32{}) - typeNullFloat64 = reflect.TypeOf(null.Float64{}) - typeNullInt = reflect.TypeOf(null.Int{}) - typeNullInt8 = reflect.TypeOf(null.Int8{}) - typeNullInt16 = reflect.TypeOf(null.Int16{}) - typeNullInt32 = reflect.TypeOf(null.Int32{}) - typeNullInt64 = reflect.TypeOf(null.Int64{}) - typeNullUint = reflect.TypeOf(null.Uint{}) - typeNullUint8 = reflect.TypeOf(null.Uint8{}) - typeNullUint16 = reflect.TypeOf(null.Uint16{}) - typeNullUint32 = reflect.TypeOf(null.Uint32{}) - typeNullUint64 = reflect.TypeOf(null.Uint64{}) - typeNullString = reflect.TypeOf(null.String{}) - typeNullBool = reflect.TypeOf(null.Bool{}) - typeNullTime = reflect.TypeOf(null.Time{}) - typeNullBytes = reflect.TypeOf(null.Bytes{}) - typeNullJSON = reflect.TypeOf(null.JSON{}) - typeTime = reflect.TypeOf(time.Time{}) - typeJSON = reflect.TypeOf(types.JSON{}) - rgxValidTime = regexp.MustCompile(`[2-9]+`) + typeNullFloat32 = reflect.TypeOf(null.Float32{}) + typeNullFloat64 = reflect.TypeOf(null.Float64{}) + typeNullInt = reflect.TypeOf(null.Int{}) + typeNullInt8 = reflect.TypeOf(null.Int8{}) + typeNullInt16 = reflect.TypeOf(null.Int16{}) + typeNullInt32 = reflect.TypeOf(null.Int32{}) + typeNullInt64 = reflect.TypeOf(null.Int64{}) + typeNullUint = reflect.TypeOf(null.Uint{}) + typeNullUint8 = reflect.TypeOf(null.Uint8{}) + typeNullUint16 = reflect.TypeOf(null.Uint16{}) + typeNullUint32 = reflect.TypeOf(null.Uint32{}) + typeNullUint64 = reflect.TypeOf(null.Uint64{}) + typeNullString = reflect.TypeOf(null.String{}) + typeNullBool = reflect.TypeOf(null.Bool{}) + typeNullTime = reflect.TypeOf(null.Time{}) + typeNullBytes = reflect.TypeOf(null.Bytes{}) + typeNullJSON = reflect.TypeOf(null.JSON{}) + typeTime = reflect.TypeOf(time.Time{}) + typeJSON = reflect.TypeOf(types.JSON{}) + typeInt64Array = reflect.TypeOf(types.Int64Array{}) + typeBytesArray = reflect.TypeOf(types.BytesArray{}) + typeBoolArray = reflect.TypeOf(types.BoolArray{}) + typeFloat64Array = reflect.TypeOf(types.Float64Array{}) + typeStringArray = reflect.TypeOf(types.StringArray{}) + typeGenericArray = reflect.TypeOf(types.GenericArray{}) + typeHstore = reflect.TypeOf(types.Hstore{}) + rgxValidTime = regexp.MustCompile(`[2-9]+`) validatedTypes = []string{ "inet", "line", "uuid", "interval", "json", "jsonb", "box", "cidr", "circle", "lseg", "macaddr", "path", "pg_lsn", "point", - "polygon", "txid_snapshot", "money", + "polygon", "txid_snapshot", "money", "hstore", } ) @@ -218,7 +228,14 @@ func randomizeField(s *Seed, field reflect.Value, fieldType string, canBeNull bo value = null.NewJSON([]byte(fmt.Sprintf(`"%s"`, randStr(s, 1))), true) field.Set(reflect.ValueOf(value)) return nil + case typeHstore: + value := hstore.Hstore{Map: map[string]sql.NullString{}} + value.Map[randStr(s, 3)] = sql.NullString{String: randStr(s, 3), Valid: s.nextInt()%3 == 0} + value.Map[randStr(s, 3)] = sql.NullString{String: randStr(s, 3), Valid: s.nextInt()%3 == 0} + field.Set(reflect.ValueOf(value)) + return nil } + } else { switch kind { case reflect.String: @@ -279,6 +296,12 @@ func randomizeField(s *Seed, field reflect.Value, fieldType string, canBeNull bo value = []byte(fmt.Sprintf(`"%s"`, randStr(s, 1))) field.Set(reflect.ValueOf(value)) return nil + case typeHstore: + value := types.Hstore{} + value[randStr(s, 3)] = sql.NullString{String: randStr(s, 3), Valid: s.nextInt()%3 == 0} + value[randStr(s, 3)] = sql.NullString{String: randStr(s, 3), Valid: s.nextInt()%3 == 0} + field.Set(reflect.ValueOf(value)) + return nil } } } @@ -293,8 +316,15 @@ func randomizeField(s *Seed, field reflect.Value, fieldType string, canBeNull bo isNull = false } - // Retrieve the value to be returned - if kind == reflect.Struct { + // If it's a Postgres array, treat it like one + if strings.HasPrefix(fieldType, "ARRAY") { + if isNull { + value = getArrayNullValue(typ) + } else { + value = getArrayRandValue(s, typ) + } + // Retrieve the value to be returned + } else if kind == reflect.Struct { if isNull { value = getStructNullValue(typ) } else { @@ -317,6 +347,45 @@ func randomizeField(s *Seed, field reflect.Value, fieldType string, canBeNull bo return nil } +func getArrayNullValue(typ reflect.Type) interface{} { + fmt.Println(typ) + switch typ { + case typeInt64Array: + return types.Int64Array{} + case typeFloat64Array: + return types.Float64Array{} + case typeBoolArray: + return types.BoolArray{} + case typeStringArray: + return types.StringArray{} + case typeBytesArray: + return types.BytesArray{} + case typeGenericArray: + return types.GenericArray{} + } + + return nil +} + +func getArrayRandValue(s *Seed, typ reflect.Type) interface{} { + switch typ { + case typeInt64Array: + return types.Int64Array{int64(s.nextInt()), int64(s.nextInt())} + case typeFloat64Array: + return types.Float64Array{float64(s.nextInt()), float64(s.nextInt())} + case typeBoolArray: + return types.BoolArray{s.nextInt()%2 == 0, s.nextInt()%2 == 0, s.nextInt()%2 == 0} + case typeStringArray: + return types.StringArray{randStr(s, 4), randStr(s, 4), randStr(s, 4)} + case typeBytesArray: + return types.BytesArray{randByteSlice(s, 4), randByteSlice(s, 4), randByteSlice(s, 4)} + case typeGenericArray: + return types.GenericArray{A: []types.JSON{randJSON(s, 4), randJSON(s, 4), randJSON(s, 4)}} + } + + return nil +} + // getStructNullValue for the matching type. func getStructNullValue(typ reflect.Type) interface{} { switch typ { @@ -505,6 +574,17 @@ func randByteSlice(s *Seed, ln int) []byte { return str } +func randJSON(s *Seed, ln int) types.JSON { + str := make(types.JSON, ln) + str[0] = '"' + for i := 1; i < ln-1; i++ { + str[i] = byte(s.nextInt() % 256) + } + str[ln-1] = '"' + + return str +} + func randPoint() string { a := rand.Intn(100) b := a + 1 diff --git a/boil/types/array.go b/boil/types/array.go new file mode 100644 index 000000000..e8ddfabc1 --- /dev/null +++ b/boil/types/array.go @@ -0,0 +1,863 @@ +// Copyright (c) 2011-2013, 'pq' Contributors Portions Copyright (C) 2011 Blake Mizerany. MIT license. +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files (the "Software"), +// to deal in the Software without restriction, including without limitation the +// rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software +// is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included +// in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT +// HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +package types + +import ( + "bytes" + "database/sql" + "database/sql/driver" + "encoding/hex" + "fmt" + "reflect" + "strconv" + "strings" + "time" +) + +var typeByteSlice = reflect.TypeOf([]byte{}) +var typeDriverValuer = reflect.TypeOf((*driver.Valuer)(nil)).Elem() +var typeSQLScanner = reflect.TypeOf((*sql.Scanner)(nil)).Elem() + +func encode(x interface{}) []byte { + switch v := x.(type) { + case int64: + return strconv.AppendInt(nil, v, 10) + case float64: + return strconv.AppendFloat(nil, v, 'f', -1, 64) + case []byte: + return encodeBytes(v) + case string: + return []byte(v) + case bool: + return strconv.AppendBool(nil, v) + case time.Time: + return formatTimestamp(v) + + default: + panic(fmt.Errorf("encode: unknown type for %T", v)) + } +} + +// FormatTimestamp formats t into Postgres' text format for timestamps. +func formatTimestamp(t time.Time) []byte { + // Need to send dates before 0001 A.D. with " BC" suffix, instead of the + // minus sign preferred by Go. + // Beware, "0000" in ISO is "1 BC", "-0001" is "2 BC" and so on + bc := false + if t.Year() <= 0 { + // flip year sign, and add 1, e.g: "0" will be "1", and "-10" will be "11" + t = t.AddDate((-t.Year())*2+1, 0, 0) + bc = true + } + b := []byte(t.Format(time.RFC3339Nano)) + + _, offset := t.Zone() + offset = offset % 60 + if offset != 0 { + // RFC3339Nano already printed the minus sign + if offset < 0 { + offset = -offset + } + + b = append(b, ':') + if offset < 10 { + b = append(b, '0') + } + b = strconv.AppendInt(b, int64(offset), 10) + } + + if bc { + b = append(b, " BC"...) + } + return b +} + +func encodeBytes(v []byte) (result []byte) { + for _, b := range v { + if b == '\\' { + result = append(result, '\\', '\\') + } else if b < 0x20 || b > 0x7e { + result = append(result, []byte(fmt.Sprintf("\\%03o", b))...) + } else { + result = append(result, b) + } + } + + return result +} + +// Parse a bytea value received from the server. Both "hex" and the legacy +// "escape" format are supported. +func parseBytes(s []byte) (result []byte, err error) { + if len(s) >= 2 && bytes.Equal(s[:2], []byte("\\x")) { + // bytea_output = hex + s = s[2:] // trim off leading "\\x" + result = make([]byte, hex.DecodedLen(len(s))) + _, err := hex.Decode(result, s) + if err != nil { + return nil, err + } + } else { + for len(s) > 0 { + if s[0] == '\\' { + // escaped '\\' + if len(s) >= 2 && s[1] == '\\' { + result = append(result, '\\') + s = s[2:] + continue + } + + // '\\' followed by an octal number + if len(s) < 4 { + return nil, fmt.Errorf("invalid bytea sequence %v", s) + } + r, err := strconv.ParseInt(string(s[1:4]), 8, 9) + if err != nil { + return nil, fmt.Errorf("could not parse bytea value: %s", err.Error()) + } + result = append(result, byte(r)) + s = s[4:] + } else { + // We hit an unescaped, raw byte. Try to read in as many as + // possible in one go. + i := bytes.IndexByte(s, '\\') + if i == -1 { + result = append(result, s...) + break + } + result = append(result, s[:i]...) + s = s[i:] + } + } + } + + return result, nil +} + +// Array returns the optimal driver.Valuer and sql.Scanner for an array or +// slice of any dimension. +// +// For example: +// db.Query(`SELECT * FROM t WHERE id = ANY($1)`, pq.Array([]int{235, 401})) +// +// var x []sql.NullInt64 +// db.QueryRow('SELECT ARRAY[235, 401]').Scan(pq.Array(&x)) +// +// Scanning multi-dimensional arrays is not supported. Arrays where the lower +// bound is not one (such as `[0:0]={1}') are not supported. +func Array(a interface{}) interface { + driver.Valuer + sql.Scanner +} { + switch a := a.(type) { + case []bool: + return (*BoolArray)(&a) + case []float64: + return (*Float64Array)(&a) + case []int64: + return (*Int64Array)(&a) + case []string: + return (*StringArray)(&a) + + case *[]bool: + return (*BoolArray)(a) + case *[]float64: + return (*Float64Array)(a) + case *[]int64: + return (*Int64Array)(a) + case *[]string: + return (*StringArray)(a) + } + + return GenericArray{a} +} + +// ArrayDelimiter may be optionally implemented by driver.Valuer or sql.Scanner +// to override the array delimiter used by GenericArray. +type ArrayDelimiter interface { + // ArrayDelimiter returns the delimiter character(s) for this element's type. + ArrayDelimiter() string +} + +// BoolArray represents a one-dimensional array of the PostgreSQL boolean type. +type BoolArray []bool + +// Scan implements the sql.Scanner interface. +func (a *BoolArray) Scan(src interface{}) error { + switch src := src.(type) { + case []byte: + return a.scanBytes(src) + case string: + return a.scanBytes([]byte(src)) + } + + return fmt.Errorf("pq: cannot convert %T to BoolArray", src) +} + +func (a *BoolArray) scanBytes(src []byte) error { + elems, err := scanLinearArray(src, []byte{','}, "BoolArray") + if err != nil { + return err + } + if len(elems) == 0 { + *a = (*a)[:0] + } else { + b := make(BoolArray, len(elems)) + for i, v := range elems { + if len(v) != 1 { + return fmt.Errorf("pq: could not parse boolean array index %d: invalid boolean %q", i, v) + } + switch v[0] { + case 't': + b[i] = true + case 'f': + b[i] = false + default: + return fmt.Errorf("pq: could not parse boolean array index %d: invalid boolean %q", i, v) + } + } + *a = b + } + return nil +} + +// Value implements the driver.Valuer interface. +func (a BoolArray) Value() (driver.Value, error) { + if a == nil { + return nil, nil + } + + if n := len(a); n > 0 { + // There will be exactly two curly brackets, N bytes of values, + // and N-1 bytes of delimiters. + b := make([]byte, 1+2*n) + + for i := 0; i < n; i++ { + b[2*i] = ',' + if a[i] { + b[1+2*i] = 't' + } else { + b[1+2*i] = 'f' + } + } + + b[0] = '{' + b[2*n] = '}' + + return string(b), nil + } + + return "{}", nil +} + +// BytesArray represents a one-dimensional array of the PostgreSQL bytea type. +type BytesArray [][]byte + +// Scan implements the sql.Scanner interface. +func (a *BytesArray) Scan(src interface{}) error { + switch src := src.(type) { + case []byte: + return a.scanBytes(src) + case string: + return a.scanBytes([]byte(src)) + } + + return fmt.Errorf("pq: cannot convert %T to BytesArray", src) +} + +func (a *BytesArray) scanBytes(src []byte) error { + elems, err := scanLinearArray(src, []byte{','}, "BytesArray") + if err != nil { + return err + } + if len(elems) == 0 { + *a = (*a)[:0] + } else { + b := make(BytesArray, len(elems)) + for i, v := range elems { + b[i], err = parseBytes(v) + if err != nil { + return fmt.Errorf("could not parse bytea array index %d: %s", i, err.Error()) + } + } + *a = b + } + return nil +} + +// Value implements the driver.Valuer interface. It uses the "hex" format which +// is only supported on PostgreSQL 9.0 or newer. +func (a BytesArray) Value() (driver.Value, error) { + if a == nil { + return nil, nil + } + + if n := len(a); n > 0 { + // There will be at least two curly brackets, 2*N bytes of quotes, + // 3*N bytes of hex formatting, and N-1 bytes of delimiters. + size := 1 + 6*n + for _, x := range a { + size += hex.EncodedLen(len(x)) + } + + b := make([]byte, size) + + for i, s := 0, b; i < n; i++ { + o := copy(s, `,"\\x`) + o += hex.Encode(s[o:], a[i]) + s[o] = '"' + s = s[o+1:] + } + + b[0] = '{' + b[size-1] = '}' + + return string(b), nil + } + + return "{}", nil +} + +// Float64Array represents a one-dimensional array of the PostgreSQL double +// precision type. +type Float64Array []float64 + +// Scan implements the sql.Scanner interface. +func (a *Float64Array) Scan(src interface{}) error { + switch src := src.(type) { + case []byte: + return a.scanBytes(src) + case string: + return a.scanBytes([]byte(src)) + } + + return fmt.Errorf("pq: cannot convert %T to Float64Array", src) +} + +func (a *Float64Array) scanBytes(src []byte) error { + elems, err := scanLinearArray(src, []byte{','}, "Float64Array") + if err != nil { + return err + } + if len(elems) == 0 { + *a = (*a)[:0] + } else { + b := make(Float64Array, len(elems)) + for i, v := range elems { + if b[i], err = strconv.ParseFloat(string(v), 64); err != nil { + return fmt.Errorf("pq: parsing array element index %d: %v", i, err) + } + } + *a = b + } + return nil +} + +// Value implements the driver.Valuer interface. +func (a Float64Array) Value() (driver.Value, error) { + if a == nil { + return nil, nil + } + + if n := len(a); n > 0 { + // There will be at least two curly brackets, N bytes of values, + // and N-1 bytes of delimiters. + b := make([]byte, 1, 1+2*n) + b[0] = '{' + + b = strconv.AppendFloat(b, a[0], 'f', -1, 64) + for i := 1; i < n; i++ { + b = append(b, ',') + b = strconv.AppendFloat(b, a[i], 'f', -1, 64) + } + + return string(append(b, '}')), nil + } + + return "{}", nil +} + +// GenericArray implements the driver.Valuer and sql.Scanner interfaces for +// an array or slice of any dimension. +type GenericArray struct{ A interface{} } + +func (GenericArray) evaluateDestination(rt reflect.Type) (reflect.Type, func([]byte, reflect.Value) error, string) { + var assign func([]byte, reflect.Value) error + var del = "," + + // TODO calculate the assign function for other types + // TODO repeat this section on the element type of arrays or slices (multidimensional) + { + if reflect.PtrTo(rt).Implements(typeSQLScanner) { + // dest is always addressable because it is an element of a slice. + assign = func(src []byte, dest reflect.Value) (err error) { + ss := dest.Addr().Interface().(sql.Scanner) + if src == nil { + err = ss.Scan(nil) + } else { + err = ss.Scan(src) + } + return + } + goto FoundType + } + + assign = func([]byte, reflect.Value) error { + return fmt.Errorf("pq: scanning to %s is not implemented; only sql.Scanner", rt) + } + } + +FoundType: + + if ad, ok := reflect.Zero(rt).Interface().(ArrayDelimiter); ok { + del = ad.ArrayDelimiter() + } + + return rt, assign, del +} + +// Scan implements the sql.Scanner interface. +func (a GenericArray) Scan(src interface{}) error { + dpv := reflect.ValueOf(a.A) + switch { + case dpv.Kind() != reflect.Ptr: + return fmt.Errorf("pq: destination %T is not a pointer to array or slice", a.A) + case dpv.IsNil(): + return fmt.Errorf("pq: destination %T is nil", a.A) + } + + dv := dpv.Elem() + switch dv.Kind() { + case reflect.Slice: + case reflect.Array: + default: + return fmt.Errorf("pq: destination %T is not a pointer to array or slice", a.A) + } + + switch src := src.(type) { + case []byte: + return a.scanBytes(src, dv) + case string: + return a.scanBytes([]byte(src), dv) + } + + return fmt.Errorf("pq: cannot convert %T to %s", src, dv.Type()) +} + +func (a GenericArray) scanBytes(src []byte, dv reflect.Value) error { + dtype, assign, del := a.evaluateDestination(dv.Type().Elem()) + dims, elems, err := parseArray(src, []byte(del)) + if err != nil { + return err + } + + // TODO allow multidimensional + + if len(dims) > 1 { + return fmt.Errorf("pq: scanning from multidimensional ARRAY%s is not implemented", + strings.Replace(fmt.Sprint(dims), " ", "][", -1)) + } + + // Treat a zero-dimensional array like an array with a single dimension of zero. + if len(dims) == 0 { + dims = append(dims, 0) + } + + for i, rt := 0, dv.Type(); i < len(dims); i, rt = i+1, rt.Elem() { + switch rt.Kind() { + case reflect.Slice: + case reflect.Array: + if rt.Len() != dims[i] { + return fmt.Errorf("pq: cannot convert ARRAY%s to %s", + strings.Replace(fmt.Sprint(dims), " ", "][", -1), dv.Type()) + } + default: + // TODO handle multidimensional + } + } + + values := reflect.MakeSlice(reflect.SliceOf(dtype), len(elems), len(elems)) + for i, e := range elems { + if err := assign(e, values.Index(i)); err != nil { + return fmt.Errorf("pq: parsing array element index %d: %v", i, err) + } + } + + // TODO handle multidimensional + + switch dv.Kind() { + case reflect.Slice: + dv.Set(values.Slice(0, dims[0])) + case reflect.Array: + for i := 0; i < dims[0]; i++ { + dv.Index(i).Set(values.Index(i)) + } + } + + return nil +} + +// Value implements the driver.Valuer interface. +func (a GenericArray) Value() (driver.Value, error) { + if a.A == nil { + return nil, nil + } + + rv := reflect.ValueOf(a.A) + + if k := rv.Kind(); k != reflect.Array && k != reflect.Slice { + return nil, fmt.Errorf("pq: Unable to convert %T to array", a.A) + } + + if n := rv.Len(); n > 0 { + // There will be at least two curly brackets, N bytes of values, + // and N-1 bytes of delimiters. + b := make([]byte, 0, 1+2*n) + + b, _, err := appendArray(b, rv, n) + return string(b), err + } + + return "{}", nil +} + +// Int64Array represents a one-dimensional array of the PostgreSQL integer types. +type Int64Array []int64 + +// Scan implements the sql.Scanner interface. +func (a *Int64Array) Scan(src interface{}) error { + switch src := src.(type) { + case []byte: + return a.scanBytes(src) + case string: + return a.scanBytes([]byte(src)) + } + + return fmt.Errorf("pq: cannot convert %T to Int64Array", src) +} + +func (a *Int64Array) scanBytes(src []byte) error { + elems, err := scanLinearArray(src, []byte{','}, "Int64Array") + if err != nil { + return err + } + if len(elems) == 0 { + *a = (*a)[:0] + } else { + b := make(Int64Array, len(elems)) + for i, v := range elems { + if b[i], err = strconv.ParseInt(string(v), 10, 64); err != nil { + return fmt.Errorf("pq: parsing array element index %d: %v", i, err) + } + } + *a = b + } + return nil +} + +// Value implements the driver.Valuer interface. +func (a Int64Array) Value() (driver.Value, error) { + if a == nil { + return nil, nil + } + + if n := len(a); n > 0 { + // There will be at least two curly brackets, N bytes of values, + // and N-1 bytes of delimiters. + b := make([]byte, 1, 1+2*n) + b[0] = '{' + + b = strconv.AppendInt(b, a[0], 10) + for i := 1; i < n; i++ { + b = append(b, ',') + b = strconv.AppendInt(b, a[i], 10) + } + + return string(append(b, '}')), nil + } + + return "{}", nil +} + +// StringArray represents a one-dimensional array of the PostgreSQL character types. +type StringArray []string + +// Scan implements the sql.Scanner interface. +func (a *StringArray) Scan(src interface{}) error { + switch src := src.(type) { + case []byte: + return a.scanBytes(src) + case string: + return a.scanBytes([]byte(src)) + } + + return fmt.Errorf("pq: cannot convert %T to StringArray", src) +} + +func (a *StringArray) scanBytes(src []byte) error { + elems, err := scanLinearArray(src, []byte{','}, "StringArray") + if err != nil { + return err + } + if len(elems) == 0 { + *a = (*a)[:0] + } else { + b := make(StringArray, len(elems)) + for i, v := range elems { + if b[i] = string(v); v == nil { + return fmt.Errorf("pq: parsing array element index %d: cannot convert nil to string", i) + } + } + *a = b + } + return nil +} + +// Value implements the driver.Valuer interface. +func (a StringArray) Value() (driver.Value, error) { + if a == nil { + return nil, nil + } + + if n := len(a); n > 0 { + // There will be at least two curly brackets, 2*N bytes of quotes, + // and N-1 bytes of delimiters. + b := make([]byte, 1, 1+3*n) + b[0] = '{' + + b = appendArrayQuotedBytes(b, []byte(a[0])) + for i := 1; i < n; i++ { + b = append(b, ',') + b = appendArrayQuotedBytes(b, []byte(a[i])) + } + + return string(append(b, '}')), nil + } + + return "{}", nil +} + +// appendArray appends rv to the buffer, returning the extended buffer and +// the delimiter used between elements. +// +// It panics when n <= 0 or rv's Kind is not reflect.Array nor reflect.Slice. +func appendArray(b []byte, rv reflect.Value, n int) ([]byte, string, error) { + var del string + var err error + + b = append(b, '{') + + if b, del, err = appendArrayElement(b, rv.Index(0)); err != nil { + return b, del, err + } + + for i := 1; i < n; i++ { + b = append(b, del...) + if b, del, err = appendArrayElement(b, rv.Index(i)); err != nil { + return b, del, err + } + } + + return append(b, '}'), del, nil +} + +// appendArrayElement appends rv to the buffer, returning the extended buffer +// and the delimiter to use before the next element. +// +// When rv's Kind is neither reflect.Array nor reflect.Slice, it is converted +// using driver.DefaultParameterConverter and the resulting []byte or string +// is double-quoted. +// +// See http://www.postgresql.org/docs/current/static/arrays.html#ARRAYS-IO +func appendArrayElement(b []byte, rv reflect.Value) ([]byte, string, error) { + if k := rv.Kind(); k == reflect.Array || k == reflect.Slice { + if t := rv.Type(); t != typeByteSlice && !t.Implements(typeDriverValuer) { + if n := rv.Len(); n > 0 { + return appendArray(b, rv, n) + } + + return b, "", nil + } + } + + var del = "," + var err error + var iv interface{} = rv.Interface() + + if ad, ok := iv.(ArrayDelimiter); ok { + del = ad.ArrayDelimiter() + } + + if iv, err = driver.DefaultParameterConverter.ConvertValue(iv); err != nil { + return b, del, err + } + + switch v := iv.(type) { + case nil: + return append(b, "NULL"...), del, nil + case []byte: + return appendArrayQuotedBytes(b, v), del, nil + case string: + return appendArrayQuotedBytes(b, []byte(v)), del, nil + } + + b, err = appendValue(b, iv) + return b, del, err +} + +func appendArrayQuotedBytes(b, v []byte) []byte { + b = append(b, '"') + for { + i := bytes.IndexAny(v, `"\`) + if i < 0 { + b = append(b, v...) + break + } + if i > 0 { + b = append(b, v[:i]...) + } + b = append(b, '\\', v[i]) + v = v[i+1:] + } + return append(b, '"') +} + +func appendValue(b []byte, v driver.Value) ([]byte, error) { + return append(b, encode(v)...), nil +} + +// parseArray extracts the dimensions and elements of an array represented in +// text format. Only representations emitted by the backend are supported. +// Notably, whitespace around brackets and delimiters is significant, and NULL +// is case-sensitive. +// +// See http://www.postgresql.org/docs/current/static/arrays.html#ARRAYS-IO +func parseArray(src, del []byte) (dims []int, elems [][]byte, err error) { + var depth, i int + + if len(src) < 1 || src[0] != '{' { + return nil, nil, fmt.Errorf("pq: unable to parse array; expected %q at offset %d", '{', 0) + } + +Open: + for i < len(src) { + switch src[i] { + case '{': + depth++ + i++ + case '}': + elems = make([][]byte, 0) + goto Close + default: + break Open + } + } + dims = make([]int, i) + +Element: + for i < len(src) { + switch src[i] { + case '{': + depth++ + dims[depth-1] = 0 + i++ + case '"': + var elem = []byte{} + var escape bool + for i++; i < len(src); i++ { + if escape { + elem = append(elem, src[i]) + escape = false + } else { + switch src[i] { + default: + elem = append(elem, src[i]) + case '\\': + escape = true + case '"': + elems = append(elems, elem) + i++ + break Element + } + } + } + default: + for start := i; i < len(src); i++ { + if bytes.HasPrefix(src[i:], del) || src[i] == '}' { + elem := src[start:i] + if len(elem) == 0 { + return nil, nil, fmt.Errorf("pq: unable to parse array; unexpected %q at offset %d", src[i], i) + } + if bytes.Equal(elem, []byte("NULL")) { + elem = nil + } + elems = append(elems, elem) + break Element + } + } + } + } + + for i < len(src) { + if bytes.HasPrefix(src[i:], del) { + dims[depth-1]++ + i += len(del) + goto Element + } else if src[i] == '}' { + dims[depth-1]++ + depth-- + i++ + } else { + return nil, nil, fmt.Errorf("pq: unable to parse array; unexpected %q at offset %d", src[i], i) + } + } + +Close: + for i < len(src) { + if src[i] == '}' && depth > 0 { + depth-- + i++ + } else { + return nil, nil, fmt.Errorf("pq: unable to parse array; unexpected %q at offset %d", src[i], i) + } + } + if depth > 0 { + err = fmt.Errorf("pq: unable to parse array; expected %q at offset %d", '}', i) + } + if err == nil { + for _, d := range dims { + if (len(elems) % d) != 0 { + err = fmt.Errorf("pq: multidimensional arrays must have elements with matching dimensions") + } + } + } + return +} + +func scanLinearArray(src, del []byte, typ string) (elems [][]byte, err error) { + dims, elems, err := parseArray(src, del) + if err != nil { + return nil, err + } + if len(dims) > 1 { + return nil, fmt.Errorf("pq: cannot convert ARRAY%s to %s", strings.Replace(fmt.Sprint(dims), " ", "][", -1), typ) + } + return elems, err +} diff --git a/boil/types/array_test.go b/boil/types/array_test.go new file mode 100644 index 000000000..6e24e447f --- /dev/null +++ b/boil/types/array_test.go @@ -0,0 +1,1125 @@ +// Copyright (c) 2011-2013, 'pq' Contributors Portions Copyright (C) 2011 Blake Mizerany. MIT license. +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files (the "Software"), +// to deal in the Software without restriction, including without limitation the +// rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software +// is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included +// in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT +// HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +package types + +import ( + "bytes" + "database/sql" + "database/sql/driver" + "math/rand" + "reflect" + "strings" + "testing" +) + +func TestParseArray(t *testing.T) { + for _, tt := range []struct { + input string + delim string + dims []int + elems [][]byte + }{ + {`{}`, `,`, nil, [][]byte{}}, + {`{NULL}`, `,`, []int{1}, [][]byte{nil}}, + {`{a}`, `,`, []int{1}, [][]byte{{'a'}}}, + {`{a,b}`, `,`, []int{2}, [][]byte{{'a'}, {'b'}}}, + {`{{a,b}}`, `,`, []int{1, 2}, [][]byte{{'a'}, {'b'}}}, + {`{{a},{b}}`, `,`, []int{2, 1}, [][]byte{{'a'}, {'b'}}}, + {`{{{a,b},{c,d},{e,f}}}`, `,`, []int{1, 3, 2}, [][]byte{ + {'a'}, {'b'}, {'c'}, {'d'}, {'e'}, {'f'}, + }}, + {`{""}`, `,`, []int{1}, [][]byte{{}}}, + {`{","}`, `,`, []int{1}, [][]byte{{','}}}, + {`{",",","}`, `,`, []int{2}, [][]byte{{','}, {','}}}, + {`{{",",","}}`, `,`, []int{1, 2}, [][]byte{{','}, {','}}}, + {`{{","},{","}}`, `,`, []int{2, 1}, [][]byte{{','}, {','}}}, + {`{{{",",","},{",",","},{",",","}}}`, `,`, []int{1, 3, 2}, [][]byte{ + {','}, {','}, {','}, {','}, {','}, {','}, + }}, + {`{"\"}"}`, `,`, []int{1}, [][]byte{{'"', '}'}}}, + {`{"\"","\""}`, `,`, []int{2}, [][]byte{{'"'}, {'"'}}}, + {`{{"\"","\""}}`, `,`, []int{1, 2}, [][]byte{{'"'}, {'"'}}}, + {`{{"\""},{"\""}}`, `,`, []int{2, 1}, [][]byte{{'"'}, {'"'}}}, + {`{{{"\"","\""},{"\"","\""},{"\"","\""}}}`, `,`, []int{1, 3, 2}, [][]byte{ + {'"'}, {'"'}, {'"'}, {'"'}, {'"'}, {'"'}, + }}, + {`{axyzb}`, `xyz`, []int{2}, [][]byte{{'a'}, {'b'}}}, + } { + dims, elems, err := parseArray([]byte(tt.input), []byte(tt.delim)) + + if err != nil { + t.Fatalf("Expected no error for %q, got %q", tt.input, err) + } + if !reflect.DeepEqual(dims, tt.dims) { + t.Errorf("Expected %v dimensions for %q, got %v", tt.dims, tt.input, dims) + } + if !reflect.DeepEqual(elems, tt.elems) { + t.Errorf("Expected %v elements for %q, got %v", tt.elems, tt.input, elems) + } + } +} + +func TestParseArrayError(t *testing.T) { + for _, tt := range []struct { + input, err string + }{ + {``, "expected '{' at offset 0"}, + {`x`, "expected '{' at offset 0"}, + {`}`, "expected '{' at offset 0"}, + {`{`, "expected '}' at offset 1"}, + {`{{}`, "expected '}' at offset 3"}, + {`{}}`, "unexpected '}' at offset 2"}, + {`{,}`, "unexpected ',' at offset 1"}, + {`{,x}`, "unexpected ',' at offset 1"}, + {`{x,}`, "unexpected '}' at offset 3"}, + {`{""x}`, "unexpected 'x' at offset 3"}, + {`{{a},{b,c}}`, "multidimensional arrays must have elements with matching dimensions"}, + } { + _, _, err := parseArray([]byte(tt.input), []byte{','}) + + if err == nil { + t.Fatalf("Expected error for %q, got none", tt.input) + } + if !strings.Contains(err.Error(), tt.err) { + t.Errorf("Expected error to contain %q for %q, got %q", tt.err, tt.input, err) + } + } +} + +func TestArrayScanner(t *testing.T) { + var s sql.Scanner + + s = Array(&[]bool{}) + if _, ok := s.(*BoolArray); !ok { + t.Errorf("Expected *BoolArray, got %T", s) + } + + s = Array(&[]float64{}) + if _, ok := s.(*Float64Array); !ok { + t.Errorf("Expected *Float64Array, got %T", s) + } + + s = Array(&[]int64{}) + if _, ok := s.(*Int64Array); !ok { + t.Errorf("Expected *Int64Array, got %T", s) + } + + s = Array(&[]string{}) + if _, ok := s.(*StringArray); !ok { + t.Errorf("Expected *StringArray, got %T", s) + } + + for _, tt := range []interface{}{ + &[]sql.Scanner{}, + &[][]bool{}, + &[][]float64{}, + &[][]int64{}, + &[][]string{}, + } { + s = Array(tt) + if _, ok := s.(GenericArray); !ok { + t.Errorf("Expected GenericArray for %T, got %T", tt, s) + } + } +} + +func TestArrayValuer(t *testing.T) { + var v driver.Valuer + + v = Array([]bool{}) + if _, ok := v.(*BoolArray); !ok { + t.Errorf("Expected *BoolArray, got %T", v) + } + + v = Array([]float64{}) + if _, ok := v.(*Float64Array); !ok { + t.Errorf("Expected *Float64Array, got %T", v) + } + + v = Array([]int64{}) + if _, ok := v.(*Int64Array); !ok { + t.Errorf("Expected *Int64Array, got %T", v) + } + + v = Array([]string{}) + if _, ok := v.(*StringArray); !ok { + t.Errorf("Expected *StringArray, got %T", v) + } + + for _, tt := range []interface{}{ + nil, + []driver.Value{}, + [][]bool{}, + [][]float64{}, + [][]int64{}, + [][]string{}, + } { + v = Array(tt) + if _, ok := v.(GenericArray); !ok { + t.Errorf("Expected GenericArray for %T, got %T", tt, v) + } + } +} + +func TestBoolArrayScanUnsupported(t *testing.T) { + var arr BoolArray + err := arr.Scan(1) + + if err == nil { + t.Fatal("Expected error when scanning from int") + } + if !strings.Contains(err.Error(), "int to BoolArray") { + t.Errorf("Expected type to be mentioned when scanning, got %q", err) + } +} + +var BoolArrayStringTests = []struct { + str string + arr BoolArray +}{ + {`{}`, BoolArray{}}, + {`{t}`, BoolArray{true}}, + {`{f,t}`, BoolArray{false, true}}, +} + +func TestBoolArrayScanBytes(t *testing.T) { + for _, tt := range BoolArrayStringTests { + bytes := []byte(tt.str) + arr := BoolArray{true, true, true} + err := arr.Scan(bytes) + + if err != nil { + t.Fatalf("Expected no error for %q, got %v", bytes, err) + } + if !reflect.DeepEqual(arr, tt.arr) { + t.Errorf("Expected %+v for %q, got %+v", tt.arr, bytes, arr) + } + } +} + +func BenchmarkBoolArrayScanBytes(b *testing.B) { + var a BoolArray + var x interface{} = []byte(`{t,f,t,f,t,f,t,f,t,f}`) + + for i := 0; i < b.N; i++ { + a = BoolArray{} + a.Scan(x) + } +} + +func TestBoolArrayScanString(t *testing.T) { + for _, tt := range BoolArrayStringTests { + arr := BoolArray{true, true, true} + err := arr.Scan(tt.str) + + if err != nil { + t.Fatalf("Expected no error for %q, got %v", tt.str, err) + } + if !reflect.DeepEqual(arr, tt.arr) { + t.Errorf("Expected %+v for %q, got %+v", tt.arr, tt.str, arr) + } + } +} + +func TestBoolArrayScanError(t *testing.T) { + for _, tt := range []struct { + input, err string + }{ + {``, "unable to parse array"}, + {`{`, "unable to parse array"}, + {`{{t},{f}}`, "cannot convert ARRAY[2][1] to BoolArray"}, + {`{NULL}`, `could not parse boolean array index 0: invalid boolean ""`}, + {`{a}`, `could not parse boolean array index 0: invalid boolean "a"`}, + {`{t,b}`, `could not parse boolean array index 1: invalid boolean "b"`}, + {`{t,f,cd}`, `could not parse boolean array index 2: invalid boolean "cd"`}, + } { + arr := BoolArray{true, true, true} + err := arr.Scan(tt.input) + + if err == nil { + t.Fatalf("Expected error for %q, got none", tt.input) + } + if !strings.Contains(err.Error(), tt.err) { + t.Errorf("Expected error to contain %q for %q, got %q", tt.err, tt.input, err) + } + if !reflect.DeepEqual(arr, BoolArray{true, true, true}) { + t.Errorf("Expected destination not to change for %q, got %+v", tt.input, arr) + } + } +} + +func TestBoolArrayValue(t *testing.T) { + result, err := BoolArray(nil).Value() + + if err != nil { + t.Fatalf("Expected no error for nil, got %v", err) + } + if result != nil { + t.Errorf("Expected nil, got %q", result) + } + + result, err = BoolArray([]bool{}).Value() + + if err != nil { + t.Fatalf("Expected no error for empty, got %v", err) + } + if expected := `{}`; !reflect.DeepEqual(result, expected) { + t.Errorf("Expected empty, got %q", result) + } + + result, err = BoolArray([]bool{false, true, false}).Value() + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if expected := `{f,t,f}`; !reflect.DeepEqual(result, expected) { + t.Errorf("Expected %q, got %q", expected, result) + } +} + +func BenchmarkBoolArrayValue(b *testing.B) { + rand.Seed(1) + x := make([]bool, 10) + for i := 0; i < len(x); i++ { + x[i] = rand.Intn(2) == 0 + } + a := BoolArray(x) + + for i := 0; i < b.N; i++ { + a.Value() + } +} + +func TestBytesArrayScanUnsupported(t *testing.T) { + var arr BytesArray + err := arr.Scan(1) + + if err == nil { + t.Fatal("Expected error when scanning from int") + } + if !strings.Contains(err.Error(), "int to BytesArray") { + t.Errorf("Expected type to be mentioned when scanning, got %q", err) + } +} + +var BytesArrayStringTests = []struct { + str string + arr BytesArray +}{ + {`{}`, BytesArray{}}, + {`{NULL}`, BytesArray{nil}}, + {`{"\\xfeff"}`, BytesArray{{'\xFE', '\xFF'}}}, + {`{"\\xdead","\\xbeef"}`, BytesArray{{'\xDE', '\xAD'}, {'\xBE', '\xEF'}}}, +} + +func TestBytesArrayScanBytes(t *testing.T) { + for _, tt := range BytesArrayStringTests { + bytes := []byte(tt.str) + arr := BytesArray{{2}, {6}, {0, 0}} + err := arr.Scan(bytes) + + if err != nil { + t.Fatalf("Expected no error for %q, got %v", bytes, err) + } + if !reflect.DeepEqual(arr, tt.arr) { + t.Errorf("Expected %+v for %q, got %+v", tt.arr, bytes, arr) + } + } +} + +func BenchmarkBytesArrayScanBytes(b *testing.B) { + var a BytesArray + var x interface{} = []byte(`{"\\xfe","\\xff","\\xdead","\\xbeef","\\xfe","\\xff","\\xdead","\\xbeef","\\xfe","\\xff"}`) + + for i := 0; i < b.N; i++ { + a = BytesArray{} + a.Scan(x) + } +} + +func TestBytesArrayScanString(t *testing.T) { + for _, tt := range BytesArrayStringTests { + arr := BytesArray{{2}, {6}, {0, 0}} + err := arr.Scan(tt.str) + + if err != nil { + t.Fatalf("Expected no error for %q, got %v", tt.str, err) + } + if !reflect.DeepEqual(arr, tt.arr) { + t.Errorf("Expected %+v for %q, got %+v", tt.arr, tt.str, arr) + } + } +} + +func TestBytesArrayScanError(t *testing.T) { + for _, tt := range []struct { + input, err string + }{ + {``, "unable to parse array"}, + {`{`, "unable to parse array"}, + {`{{"\\xfeff"},{"\\xbeef"}}`, "cannot convert ARRAY[2][1] to BytesArray"}, + {`{"\\abc"}`, "could not parse bytea array index 0: could not parse bytea value"}, + } { + arr := BytesArray{{2}, {6}, {0, 0}} + err := arr.Scan(tt.input) + + if err == nil { + t.Fatalf("Expected error for %q, got none", tt.input) + } + if !strings.Contains(err.Error(), tt.err) { + t.Errorf("Expected error to contain %q for %q, got %q", tt.err, tt.input, err) + } + if !reflect.DeepEqual(arr, BytesArray{{2}, {6}, {0, 0}}) { + t.Errorf("Expected destination not to change for %q, got %+v", tt.input, arr) + } + } +} + +func TestBytesArrayValue(t *testing.T) { + result, err := BytesArray(nil).Value() + + if err != nil { + t.Fatalf("Expected no error for nil, got %v", err) + } + if result != nil { + t.Errorf("Expected nil, got %q", result) + } + + result, err = BytesArray([][]byte{}).Value() + + if err != nil { + t.Fatalf("Expected no error for empty, got %v", err) + } + if expected := `{}`; !reflect.DeepEqual(result, expected) { + t.Errorf("Expected empty, got %q", result) + } + + result, err = BytesArray([][]byte{{'\xDE', '\xAD', '\xBE', '\xEF'}, {'\xFE', '\xFF'}, {}}).Value() + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if expected := `{"\\xdeadbeef","\\xfeff","\\x"}`; !reflect.DeepEqual(result, expected) { + t.Errorf("Expected %q, got %q", expected, result) + } +} + +func BenchmarkBytesArrayValue(b *testing.B) { + rand.Seed(1) + x := make([][]byte, 10) + for i := 0; i < len(x); i++ { + x[i] = make([]byte, len(x)) + for j := 0; j < len(x); j++ { + x[i][j] = byte(rand.Int()) + } + } + a := BytesArray(x) + + for i := 0; i < b.N; i++ { + a.Value() + } +} + +func TestFloat64ArrayScanUnsupported(t *testing.T) { + var arr Float64Array + err := arr.Scan(true) + + if err == nil { + t.Fatal("Expected error when scanning from bool") + } + if !strings.Contains(err.Error(), "bool to Float64Array") { + t.Errorf("Expected type to be mentioned when scanning, got %q", err) + } +} + +var Float64ArrayStringTests = []struct { + str string + arr Float64Array +}{ + {`{}`, Float64Array{}}, + {`{1.2}`, Float64Array{1.2}}, + {`{3.456,7.89}`, Float64Array{3.456, 7.89}}, + {`{3,1,2}`, Float64Array{3, 1, 2}}, +} + +func TestFloat64ArrayScanBytes(t *testing.T) { + for _, tt := range Float64ArrayStringTests { + bytes := []byte(tt.str) + arr := Float64Array{5, 5, 5} + err := arr.Scan(bytes) + + if err != nil { + t.Fatalf("Expected no error for %q, got %v", bytes, err) + } + if !reflect.DeepEqual(arr, tt.arr) { + t.Errorf("Expected %+v for %q, got %+v", tt.arr, bytes, arr) + } + } +} + +func BenchmarkFloat64ArrayScanBytes(b *testing.B) { + var a Float64Array + var x interface{} = []byte(`{1.2,3.4,5.6,7.8,9.01,2.34,5.67,8.90,1.234,5.678}`) + + for i := 0; i < b.N; i++ { + a = Float64Array{} + a.Scan(x) + } +} + +func TestFloat64ArrayScanString(t *testing.T) { + for _, tt := range Float64ArrayStringTests { + arr := Float64Array{5, 5, 5} + err := arr.Scan(tt.str) + + if err != nil { + t.Fatalf("Expected no error for %q, got %v", tt.str, err) + } + if !reflect.DeepEqual(arr, tt.arr) { + t.Errorf("Expected %+v for %q, got %+v", tt.arr, tt.str, arr) + } + } +} + +func TestFloat64ArrayScanError(t *testing.T) { + for _, tt := range []struct { + input, err string + }{ + {``, "unable to parse array"}, + {`{`, "unable to parse array"}, + {`{{5.6},{7.8}}`, "cannot convert ARRAY[2][1] to Float64Array"}, + {`{NULL}`, "parsing array element index 0:"}, + {`{a}`, "parsing array element index 0:"}, + {`{5.6,a}`, "parsing array element index 1:"}, + {`{5.6,7.8,a}`, "parsing array element index 2:"}, + } { + arr := Float64Array{5, 5, 5} + err := arr.Scan(tt.input) + + if err == nil { + t.Fatalf("Expected error for %q, got none", tt.input) + } + if !strings.Contains(err.Error(), tt.err) { + t.Errorf("Expected error to contain %q for %q, got %q", tt.err, tt.input, err) + } + if !reflect.DeepEqual(arr, Float64Array{5, 5, 5}) { + t.Errorf("Expected destination not to change for %q, got %+v", tt.input, arr) + } + } +} + +func TestFloat64ArrayValue(t *testing.T) { + result, err := Float64Array(nil).Value() + + if err != nil { + t.Fatalf("Expected no error for nil, got %v", err) + } + if result != nil { + t.Errorf("Expected nil, got %q", result) + } + + result, err = Float64Array([]float64{}).Value() + + if err != nil { + t.Fatalf("Expected no error for empty, got %v", err) + } + if expected := `{}`; !reflect.DeepEqual(result, expected) { + t.Errorf("Expected empty, got %q", result) + } + + result, err = Float64Array([]float64{1.2, 3.4, 5.6}).Value() + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if expected := `{1.2,3.4,5.6}`; !reflect.DeepEqual(result, expected) { + t.Errorf("Expected %q, got %q", expected, result) + } +} + +func BenchmarkFloat64ArrayValue(b *testing.B) { + rand.Seed(1) + x := make([]float64, 10) + for i := 0; i < len(x); i++ { + x[i] = rand.NormFloat64() + } + a := Float64Array(x) + + for i := 0; i < b.N; i++ { + a.Value() + } +} + +func TestInt64ArrayScanUnsupported(t *testing.T) { + var arr Int64Array + err := arr.Scan(true) + + if err == nil { + t.Fatal("Expected error when scanning from bool") + } + if !strings.Contains(err.Error(), "bool to Int64Array") { + t.Errorf("Expected type to be mentioned when scanning, got %q", err) + } +} + +var Int64ArrayStringTests = []struct { + str string + arr Int64Array +}{ + {`{}`, Int64Array{}}, + {`{12}`, Int64Array{12}}, + {`{345,678}`, Int64Array{345, 678}}, +} + +func TestInt64ArrayScanBytes(t *testing.T) { + for _, tt := range Int64ArrayStringTests { + bytes := []byte(tt.str) + arr := Int64Array{5, 5, 5} + err := arr.Scan(bytes) + + if err != nil { + t.Fatalf("Expected no error for %q, got %v", bytes, err) + } + if !reflect.DeepEqual(arr, tt.arr) { + t.Errorf("Expected %+v for %q, got %+v", tt.arr, bytes, arr) + } + } +} + +func BenchmarkInt64ArrayScanBytes(b *testing.B) { + var a Int64Array + var x interface{} = []byte(`{1,2,3,4,5,6,7,8,9,0}`) + + for i := 0; i < b.N; i++ { + a = Int64Array{} + a.Scan(x) + } +} + +func TestInt64ArrayScanString(t *testing.T) { + for _, tt := range Int64ArrayStringTests { + arr := Int64Array{5, 5, 5} + err := arr.Scan(tt.str) + + if err != nil { + t.Fatalf("Expected no error for %q, got %v", tt.str, err) + } + if !reflect.DeepEqual(arr, tt.arr) { + t.Errorf("Expected %+v for %q, got %+v", tt.arr, tt.str, arr) + } + } +} + +func TestInt64ArrayScanError(t *testing.T) { + for _, tt := range []struct { + input, err string + }{ + {``, "unable to parse array"}, + {`{`, "unable to parse array"}, + {`{{5},{6}}`, "cannot convert ARRAY[2][1] to Int64Array"}, + {`{NULL}`, "parsing array element index 0:"}, + {`{a}`, "parsing array element index 0:"}, + {`{5,a}`, "parsing array element index 1:"}, + {`{5,6,a}`, "parsing array element index 2:"}, + } { + arr := Int64Array{5, 5, 5} + err := arr.Scan(tt.input) + + if err == nil { + t.Fatalf("Expected error for %q, got none", tt.input) + } + if !strings.Contains(err.Error(), tt.err) { + t.Errorf("Expected error to contain %q for %q, got %q", tt.err, tt.input, err) + } + if !reflect.DeepEqual(arr, Int64Array{5, 5, 5}) { + t.Errorf("Expected destination not to change for %q, got %+v", tt.input, arr) + } + } +} + +func TestInt64ArrayValue(t *testing.T) { + result, err := Int64Array(nil).Value() + + if err != nil { + t.Fatalf("Expected no error for nil, got %v", err) + } + if result != nil { + t.Errorf("Expected nil, got %q", result) + } + + result, err = Int64Array([]int64{}).Value() + + if err != nil { + t.Fatalf("Expected no error for empty, got %v", err) + } + if expected := `{}`; !reflect.DeepEqual(result, expected) { + t.Errorf("Expected empty, got %q", result) + } + + result, err = Int64Array([]int64{1, 2, 3}).Value() + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if expected := `{1,2,3}`; !reflect.DeepEqual(result, expected) { + t.Errorf("Expected %q, got %q", expected, result) + } +} + +func BenchmarkInt64ArrayValue(b *testing.B) { + rand.Seed(1) + x := make([]int64, 10) + for i := 0; i < len(x); i++ { + x[i] = rand.Int63() + } + a := Int64Array(x) + + for i := 0; i < b.N; i++ { + a.Value() + } +} + +func TestStringArrayScanUnsupported(t *testing.T) { + var arr StringArray + err := arr.Scan(true) + + if err == nil { + t.Fatal("Expected error when scanning from bool") + } + if !strings.Contains(err.Error(), "bool to StringArray") { + t.Errorf("Expected type to be mentioned when scanning, got %q", err) + } +} + +var StringArrayStringTests = []struct { + str string + arr StringArray +}{ + {`{}`, StringArray{}}, + {`{t}`, StringArray{"t"}}, + {`{f,1}`, StringArray{"f", "1"}}, + {`{"a\\b","c d",","}`, StringArray{"a\\b", "c d", ","}}, +} + +func TestStringArrayScanBytes(t *testing.T) { + for _, tt := range StringArrayStringTests { + bytes := []byte(tt.str) + arr := StringArray{"x", "x", "x"} + err := arr.Scan(bytes) + + if err != nil { + t.Fatalf("Expected no error for %q, got %v", bytes, err) + } + if !reflect.DeepEqual(arr, tt.arr) { + t.Errorf("Expected %+v for %q, got %+v", tt.arr, bytes, arr) + } + } +} + +func BenchmarkStringArrayScanBytes(b *testing.B) { + var a StringArray + var x interface{} = []byte(`{a,b,c,d,e,f,g,h,i,j}`) + var y interface{} = []byte(`{"\a","\b","\c","\d","\e","\f","\g","\h","\i","\j"}`) + + for i := 0; i < b.N; i++ { + a = StringArray{} + a.Scan(x) + a = StringArray{} + a.Scan(y) + } +} + +func TestStringArrayScanString(t *testing.T) { + for _, tt := range StringArrayStringTests { + arr := StringArray{"x", "x", "x"} + err := arr.Scan(tt.str) + + if err != nil { + t.Fatalf("Expected no error for %q, got %v", tt.str, err) + } + if !reflect.DeepEqual(arr, tt.arr) { + t.Errorf("Expected %+v for %q, got %+v", tt.arr, tt.str, arr) + } + } +} + +func TestStringArrayScanError(t *testing.T) { + for _, tt := range []struct { + input, err string + }{ + {``, "unable to parse array"}, + {`{`, "unable to parse array"}, + {`{{a},{b}}`, "cannot convert ARRAY[2][1] to StringArray"}, + {`{NULL}`, "parsing array element index 0: cannot convert nil to string"}, + {`{a,NULL}`, "parsing array element index 1: cannot convert nil to string"}, + {`{a,b,NULL}`, "parsing array element index 2: cannot convert nil to string"}, + } { + arr := StringArray{"x", "x", "x"} + err := arr.Scan(tt.input) + + if err == nil { + t.Fatalf("Expected error for %q, got none", tt.input) + } + if !strings.Contains(err.Error(), tt.err) { + t.Errorf("Expected error to contain %q for %q, got %q", tt.err, tt.input, err) + } + if !reflect.DeepEqual(arr, StringArray{"x", "x", "x"}) { + t.Errorf("Expected destination not to change for %q, got %+v", tt.input, arr) + } + } +} + +func TestStringArrayValue(t *testing.T) { + result, err := StringArray(nil).Value() + + if err != nil { + t.Fatalf("Expected no error for nil, got %v", err) + } + if result != nil { + t.Errorf("Expected nil, got %q", result) + } + + result, err = StringArray([]string{}).Value() + + if err != nil { + t.Fatalf("Expected no error for empty, got %v", err) + } + if expected := `{}`; !reflect.DeepEqual(result, expected) { + t.Errorf("Expected empty, got %q", result) + } + + result, err = StringArray([]string{`a`, `\b`, `c"`, `d,e`}).Value() + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if expected := `{"a","\\b","c\"","d,e"}`; !reflect.DeepEqual(result, expected) { + t.Errorf("Expected %q, got %q", expected, result) + } +} + +func BenchmarkStringArrayValue(b *testing.B) { + x := make([]string, 10) + for i := 0; i < len(x); i++ { + x[i] = strings.Repeat(`abc"def\ghi`, 5) + } + a := StringArray(x) + + for i := 0; i < b.N; i++ { + a.Value() + } +} + +func TestGenericArrayScanUnsupported(t *testing.T) { + var s string + var ss []string + + for _, tt := range []struct { + src, dest interface{} + err string + }{ + {nil, nil, "destination is not a pointer to array or slice"}, + {nil, true, "destination bool is not a pointer to array or slice"}, + {nil, &s, "destination *string is not a pointer to array or slice"}, + {nil, ss, "destination []string is not a pointer to array or slice"}, + {true, &ss, "bool to []string"}, + {`{{x}}`, &ss, "multidimensional ARRAY[1][1] is not implemented"}, + {`{{x},{x}}`, &ss, "multidimensional ARRAY[2][1] is not implemented"}, + {`{x}`, &ss, "scanning to string is not implemented"}, + } { + err := GenericArray{tt.dest}.Scan(tt.src) + + if err == nil { + t.Fatalf("Expected error for [%#v %#v]", tt.src, tt.dest) + } + if !strings.Contains(err.Error(), tt.err) { + t.Errorf("Expected error to contain %q for [%#v %#v], got %q", tt.err, tt.src, tt.dest, err) + } + } +} + +func TestGenericArrayScanScannerArrayBytes(t *testing.T) { + src, expected, nsa := []byte(`{NULL,abc,"\""}`), + [3]sql.NullString{{}, {String: `abc`, Valid: true}, {String: `"`, Valid: true}}, + [3]sql.NullString{{String: ``, Valid: true}, {}, {}} + + if err := (GenericArray{&nsa}).Scan(src); err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if !reflect.DeepEqual(nsa, expected) { + t.Errorf("Expected %v, got %v", expected, nsa) + } +} + +func TestGenericArrayScanScannerArrayString(t *testing.T) { + src, expected, nsa := `{NULL,"\"",xyz}`, + [3]sql.NullString{{}, {String: `"`, Valid: true}, {String: `xyz`, Valid: true}}, + [3]sql.NullString{{String: ``, Valid: true}, {}, {}} + + if err := (GenericArray{&nsa}).Scan(src); err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if !reflect.DeepEqual(nsa, expected) { + t.Errorf("Expected %v, got %v", expected, nsa) + } +} + +func TestGenericArrayScanScannerSliceBytes(t *testing.T) { + src, expected, nss := []byte(`{NULL,abc,"\""}`), + []sql.NullString{{}, {String: `abc`, Valid: true}, {String: `"`, Valid: true}}, + []sql.NullString{{String: ``, Valid: true}, {}, {}, {}, {}} + + if err := (GenericArray{&nss}).Scan(src); err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if !reflect.DeepEqual(nss, expected) { + t.Errorf("Expected %v, got %v", expected, nss) + } +} + +func BenchmarkGenericArrayScanScannerSliceBytes(b *testing.B) { + var a GenericArray + var x interface{} = []byte(`{a,b,c,d,e,f,g,h,i,j}`) + var y interface{} = []byte(`{"\a","\b","\c","\d","\e","\f","\g","\h","\i","\j"}`) + + for i := 0; i < b.N; i++ { + a = GenericArray{new([]sql.NullString)} + a.Scan(x) + a = GenericArray{new([]sql.NullString)} + a.Scan(y) + } +} + +func TestGenericArrayScanScannerSliceString(t *testing.T) { + src, expected, nss := `{NULL,"\"",xyz}`, + []sql.NullString{{}, {String: `"`, Valid: true}, {String: `xyz`, Valid: true}}, + []sql.NullString{{String: ``, Valid: true}, {}, {}} + + if err := (GenericArray{&nss}).Scan(src); err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if !reflect.DeepEqual(nss, expected) { + t.Errorf("Expected %v, got %v", expected, nss) + } +} + +type TildeNullInt64 struct{ sql.NullInt64 } + +func (TildeNullInt64) ArrayDelimiter() string { return "~" } + +func TestGenericArrayScanDelimiter(t *testing.T) { + src, expected, tnis := `{12~NULL~76}`, + []TildeNullInt64{{sql.NullInt64{Int64: 12, Valid: true}}, {}, {sql.NullInt64{Int64: 76, Valid: true}}}, + []TildeNullInt64{{sql.NullInt64{Int64: 0, Valid: true}}, {}} + + if err := (GenericArray{&tnis}).Scan(src); err != nil { + t.Fatalf("Expected no error for %#v, got %v", src, err) + } + if !reflect.DeepEqual(tnis, expected) { + t.Errorf("Expected %v for %#v, got %v", expected, src, tnis) + } +} + +func TestGenericArrayScanErrors(t *testing.T) { + var sa [1]string + var nis []sql.NullInt64 + var pss *[]string + + for _, tt := range []struct { + src, dest interface{} + err string + }{ + {nil, pss, "destination *[]string is nil"}, + {`{`, &sa, "unable to parse"}, + {`{}`, &sa, "cannot convert ARRAY[0] to [1]string"}, + {`{x,x}`, &sa, "cannot convert ARRAY[2] to [1]string"}, + {`{x}`, &nis, `parsing array element index 0: converting`}, + } { + err := GenericArray{tt.dest}.Scan(tt.src) + + if err == nil { + t.Fatalf("Expected error for [%#v %#v]", tt.src, tt.dest) + } + if !strings.Contains(err.Error(), tt.err) { + t.Errorf("Expected error to contain %q for [%#v %#v], got %q", tt.err, tt.src, tt.dest, err) + } + } +} + +func TestGenericArrayValueUnsupported(t *testing.T) { + _, err := GenericArray{true}.Value() + + if err == nil { + t.Fatal("Expected error for bool") + } + if !strings.Contains(err.Error(), "bool to array") { + t.Errorf("Expected type to be mentioned, got %q", err) + } +} + +type ByteArrayValuer [1]byte +type ByteSliceValuer []byte +type FuncArrayValuer struct { + delimiter func() string + value func() (driver.Value, error) +} + +func (a ByteArrayValuer) Value() (driver.Value, error) { return a[:], nil } +func (b ByteSliceValuer) Value() (driver.Value, error) { return []byte(b), nil } +func (f FuncArrayValuer) ArrayDelimiter() string { return f.delimiter() } +func (f FuncArrayValuer) Value() (driver.Value, error) { return f.value() } + +func TestGenericArrayValue(t *testing.T) { + result, err := GenericArray{nil}.Value() + + if err != nil { + t.Fatalf("Expected no error for nil, got %v", err) + } + if result != nil { + t.Errorf("Expected nil, got %q", result) + } + + Tilde := func(v driver.Value) FuncArrayValuer { + return FuncArrayValuer{ + func() string { return "~" }, + func() (driver.Value, error) { return v, nil }} + } + + for _, tt := range []struct { + result string + input interface{} + }{ + {`{}`, []bool{}}, + {`{true}`, []bool{true}}, + {`{true,false}`, []bool{true, false}}, + {`{true,false}`, [2]bool{true, false}}, + + {`{}`, [][]int{{}}}, + {`{}`, [][]int{{}, {}}}, + {`{{1}}`, [][]int{{1}}}, + {`{{1},{2}}`, [][]int{{1}, {2}}}, + {`{{1,2},{3,4}}`, [][]int{{1, 2}, {3, 4}}}, + {`{{1,2},{3,4}}`, [2][2]int{{1, 2}, {3, 4}}}, + + {`{"a","\\b","c\"","d,e"}`, []string{`a`, `\b`, `c"`, `d,e`}}, + {`{"a","\\b","c\"","d,e"}`, [][]byte{{'a'}, {'\\', 'b'}, {'c', '"'}, {'d', ',', 'e'}}}, + + {`{NULL}`, []*int{nil}}, + {`{0,NULL}`, []*int{new(int), nil}}, + + {`{NULL}`, []sql.NullString{{}}}, + {`{"\"",NULL}`, []sql.NullString{{String: `"`, Valid: true}, {}}}, + + {`{"a","b"}`, []ByteArrayValuer{{'a'}, {'b'}}}, + {`{{"a","b"},{"c","d"}}`, [][]ByteArrayValuer{{{'a'}, {'b'}}, {{'c'}, {'d'}}}}, + + {`{"e","f"}`, []ByteSliceValuer{{'e'}, {'f'}}}, + {`{{"e","f"},{"g","h"}}`, [][]ByteSliceValuer{{{'e'}, {'f'}}, {{'g'}, {'h'}}}}, + + {`{1~2}`, []FuncArrayValuer{Tilde(int64(1)), Tilde(int64(2))}}, + {`{{1~2}~{3~4}}`, [][]FuncArrayValuer{{Tilde(int64(1)), Tilde(int64(2))}, {Tilde(int64(3)), Tilde(int64(4))}}}, + } { + result, err := GenericArray{tt.input}.Value() + + if err != nil { + t.Fatalf("Expected no error for %q, got %v", tt.input, err) + } + if !reflect.DeepEqual(result, tt.result) { + t.Errorf("Expected %q for %q, got %q", tt.result, tt.input, result) + } + } +} + +func TestGenericArrayValueErrors(t *testing.T) { + var v []interface{} + + v = []interface{}{func() {}} + if _, err := (GenericArray{v}).Value(); err == nil { + t.Errorf("Expected error for %q, got nil", v) + } + + v = []interface{}{nil, func() {}} + if _, err := (GenericArray{v}).Value(); err == nil { + t.Errorf("Expected error for %q, got nil", v) + } +} + +func BenchmarkGenericArrayValueBools(b *testing.B) { + rand.Seed(1) + x := make([]bool, 10) + for i := 0; i < len(x); i++ { + x[i] = rand.Intn(2) == 0 + } + a := GenericArray{x} + + for i := 0; i < b.N; i++ { + a.Value() + } +} + +func BenchmarkGenericArrayValueFloat64s(b *testing.B) { + rand.Seed(1) + x := make([]float64, 10) + for i := 0; i < len(x); i++ { + x[i] = rand.NormFloat64() + } + a := GenericArray{x} + + for i := 0; i < b.N; i++ { + a.Value() + } +} + +func BenchmarkGenericArrayValueInt64s(b *testing.B) { + rand.Seed(1) + x := make([]int64, 10) + for i := 0; i < len(x); i++ { + x[i] = rand.Int63() + } + a := GenericArray{x} + + for i := 0; i < b.N; i++ { + a.Value() + } +} + +func BenchmarkGenericArrayValueByteSlices(b *testing.B) { + x := make([][]byte, 10) + for i := 0; i < len(x); i++ { + x[i] = bytes.Repeat([]byte(`abc"def\ghi`), 5) + } + a := GenericArray{x} + + for i := 0; i < b.N; i++ { + a.Value() + } +} + +func BenchmarkGenericArrayValueStrings(b *testing.B) { + x := make([]string, 10) + for i := 0; i < len(x); i++ { + x[i] = strings.Repeat(`abc"def\ghi`, 5) + } + a := GenericArray{x} + + for i := 0; i < b.N; i++ { + a.Value() + } +} diff --git a/boil/types/hstore.go b/boil/types/hstore.go new file mode 100644 index 000000000..6d642c8be --- /dev/null +++ b/boil/types/hstore.go @@ -0,0 +1,135 @@ +// Copyright (c) 2011-2013, 'pq' Contributors Portions Copyright (C) 2011 Blake Mizerany. MIT license. +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files (the "Software"), +// to deal in the Software without restriction, including without limitation the +// rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software +// is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included +// in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT +// HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +package types + +import ( + "database/sql" + "database/sql/driver" + "strings" +) + +// Hstore is a wrapper for transferring Hstore values back and forth easily. +type Hstore map[string]sql.NullString + +// escapes and quotes hstore keys/values +// s should be a sql.NullString or string +func hQuote(s interface{}) string { + var str string + switch v := s.(type) { + case sql.NullString: + if !v.Valid { + return "NULL" + } + str = v.String + case string: + str = v + default: + panic("not a string or sql.NullString") + } + + str = strings.Replace(str, "\\", "\\\\", -1) + return `"` + strings.Replace(str, "\"", "\\\"", -1) + `"` +} + +// Scan implements the Scanner interface. +// +// Note h is reallocated before the scan to clear existing values. If the +// hstore column's database value is NULL, then h is set to nil instead. +func (h *Hstore) Scan(value interface{}) error { + if value == nil { + h = nil + return nil + } + *h = make(map[string]sql.NullString) + var b byte + pair := [][]byte{{}, {}} + pi := 0 + inQuote := false + didQuote := false + sawSlash := false + bindex := 0 + for bindex, b = range value.([]byte) { + if sawSlash { + pair[pi] = append(pair[pi], b) + sawSlash = false + continue + } + + switch b { + case '\\': + sawSlash = true + continue + case '"': + inQuote = !inQuote + if !didQuote { + didQuote = true + } + continue + default: + if !inQuote { + switch b { + case ' ', '\t', '\n', '\r': + continue + case '=': + continue + case '>': + pi = 1 + didQuote = false + continue + case ',': + s := string(pair[1]) + if !didQuote && len(s) == 4 && strings.ToLower(s) == "null" { + (*h)[string(pair[0])] = sql.NullString{String: "", Valid: false} + } else { + (*h)[string(pair[0])] = sql.NullString{String: string(pair[1]), Valid: true} + } + pair[0] = []byte{} + pair[1] = []byte{} + pi = 0 + continue + } + } + } + pair[pi] = append(pair[pi], b) + } + if bindex > 0 { + s := string(pair[1]) + if !didQuote && len(s) == 4 && strings.ToLower(s) == "null" { + (*h)[string(pair[0])] = sql.NullString{String: "", Valid: false} + } else { + (*h)[string(pair[0])] = sql.NullString{String: string(pair[1]), Valid: true} + } + } + return nil +} + +// Value implements the driver Valuer interface. Note if h is nil, the +// database column value will be set to NULL. +func (h Hstore) Value() (driver.Value, error) { + if h == nil { + return nil, nil + } + parts := []string{} + for key, val := range h { + thispart := hQuote(key) + "=>" + hQuote(val) + parts = append(parts, thispart) + } + return []byte(strings.Join(parts, ",")), nil +} diff --git a/imports.go b/imports.go index c6510c23c..ced6daa7b 100644 --- a/imports.go +++ b/imports.go @@ -302,4 +302,22 @@ var importsBasedOnType = map[string]imports{ "types.JSON": { thirdParty: importList{`"github.com/vattle/sqlboiler/boil/types"`}, }, + "types.BytesArray": { + thirdParty: importList{`"github.com/vattle/sqlboiler/boil/types"`}, + }, + "types.GenericArray": { + thirdParty: importList{`"github.com/vattle/sqlboiler/boil/types"`}, + }, + "types.Int64Array": { + thirdParty: importList{`"github.com/vattle/sqlboiler/boil/types"`}, + }, + "types.Float64Array": { + thirdParty: importList{`"github.com/vattle/sqlboiler/boil/types"`}, + }, + "types.BoolArray": { + thirdParty: importList{`"github.com/vattle/sqlboiler/boil/types"`}, + }, + "types.Hstore": { + thirdParty: importList{`"github.com/vattle/sqlboiler/boil/types"`}, + }, } From 9bcaf51493aa11a311aaf9ba29ca3136b41d941c Mon Sep 17 00:00:00 2001 From: Patrick O'brien Date: Mon, 12 Sep 2016 07:22:17 +1000 Subject: [PATCH 28/64] Fix randomize for all array types, remove generic * GenericArray can't work with generated code. * Multi-dimensional arrays can't work because PSQL does not have a method to discover array depth. --- bdb/drivers/postgres.go | 4 +- boil/randomize/randomize.go | 88 +++++----- boil/types/array.go | 184 +++----------------- boil/types/array_test.go | 325 ------------------------------------ imports.go | 3 - testdata/test_schema.sql | 19 +++ 6 files changed, 89 insertions(+), 534 deletions(-) diff --git a/bdb/drivers/postgres.go b/bdb/drivers/postgres.go index 77dbc85fd..e992299b9 100644 --- a/bdb/drivers/postgres.go +++ b/bdb/drivers/postgres.go @@ -367,12 +367,12 @@ func getArrayType(c bdb.Column) string { return "types.BytesArray" case "bit", "interval", "uuint", "bit varying", "character", "money", "character varying", "cidr", "inet", "macaddr", "text", "uuid", "xml": return "types.StringArray" - case "bool": + case "boolean": return "types.BoolArray" case "decimal", "numeric", "double precision", "real": return "types.Float64Array" default: - return "types.GenericArray" + return "types.StringArray" } } diff --git a/boil/randomize/randomize.go b/boil/randomize/randomize.go index 255043855..2277ef256 100644 --- a/boil/randomize/randomize.go +++ b/boil/randomize/randomize.go @@ -47,7 +47,6 @@ var ( typeBoolArray = reflect.TypeOf(types.BoolArray{}) typeFloat64Array = reflect.TypeOf(types.Float64Array{}) typeStringArray = reflect.TypeOf(types.StringArray{}) - typeGenericArray = reflect.TypeOf(types.GenericArray{}) typeHstore = reflect.TypeOf(types.Hstore{}) rgxValidTime = regexp.MustCompile(`[2-9]+`) @@ -318,11 +317,7 @@ func randomizeField(s *Seed, field reflect.Value, fieldType string, canBeNull bo // If it's a Postgres array, treat it like one if strings.HasPrefix(fieldType, "ARRAY") { - if isNull { - value = getArrayNullValue(typ) - } else { - value = getArrayRandValue(s, typ) - } + value = getArrayRandValue(s, typ, fieldType) // Retrieve the value to be returned } else if kind == reflect.Struct { if isNull { @@ -347,27 +342,8 @@ func randomizeField(s *Seed, field reflect.Value, fieldType string, canBeNull bo return nil } -func getArrayNullValue(typ reflect.Type) interface{} { - fmt.Println(typ) - switch typ { - case typeInt64Array: - return types.Int64Array{} - case typeFloat64Array: - return types.Float64Array{} - case typeBoolArray: - return types.BoolArray{} - case typeStringArray: - return types.StringArray{} - case typeBytesArray: - return types.BytesArray{} - case typeGenericArray: - return types.GenericArray{} - } - - return nil -} - -func getArrayRandValue(s *Seed, typ reflect.Type) interface{} { +func getArrayRandValue(s *Seed, typ reflect.Type, fieldType string) interface{} { + fieldType = strings.TrimLeft(fieldType, "ARRAY") switch typ { case typeInt64Array: return types.Int64Array{int64(s.nextInt()), int64(s.nextInt())} @@ -376,11 +352,54 @@ func getArrayRandValue(s *Seed, typ reflect.Type) interface{} { case typeBoolArray: return types.BoolArray{s.nextInt()%2 == 0, s.nextInt()%2 == 0, s.nextInt()%2 == 0} case typeStringArray: + if fieldType == "interval" { + value := strconv.Itoa((s.nextInt()%26)+2) + " days" + return types.StringArray{value, value} + } + if fieldType == "uuid" { + value := uuid.NewV4().String() + return types.StringArray{value, value} + } + if fieldType == "box" || fieldType == "line" || fieldType == "lseg" || + fieldType == "path" || fieldType == "polygon" { + value := randBox() + return types.StringArray{value, value} + } + if fieldType == "cidr" || fieldType == "inet" { + value := randNetAddr() + return types.StringArray{value, value} + } + if fieldType == "macaddr" { + value := randMacAddr() + return types.StringArray{value, value} + } + if fieldType == "circle" { + value := randCircle() + return types.StringArray{value, value} + } + if fieldType == "pg_lsn" { + value := randLsn() + return types.StringArray{value, value} + } + if fieldType == "point" { + value := randPoint() + return types.StringArray{value, value} + } + if fieldType == "txid_snapshot" { + value := randTxID() + return types.StringArray{value, value} + } + if fieldType == "money" { + value := randMoney(s) + return types.StringArray{value, value} + } + if fieldType == "json" || fieldType == "jsonb" { + value := []byte(fmt.Sprintf(`"%s"`, randStr(s, 1))) + return types.StringArray{string(value)} + } return types.StringArray{randStr(s, 4), randStr(s, 4), randStr(s, 4)} case typeBytesArray: return types.BytesArray{randByteSlice(s, 4), randByteSlice(s, 4), randByteSlice(s, 4)} - case typeGenericArray: - return types.GenericArray{A: []types.JSON{randJSON(s, 4), randJSON(s, 4), randJSON(s, 4)}} } return nil @@ -574,17 +593,6 @@ func randByteSlice(s *Seed, ln int) []byte { return str } -func randJSON(s *Seed, ln int) types.JSON { - str := make(types.JSON, ln) - str[0] = '"' - for i := 1; i < ln-1; i++ { - str[i] = byte(s.nextInt() % 256) - } - str[ln-1] = '"' - - return str -} - func randPoint() string { a := rand.Intn(100) b := a + 1 diff --git a/boil/types/array.go b/boil/types/array.go index e8ddfabc1..2924fa20e 100644 --- a/boil/types/array.go +++ b/boil/types/array.go @@ -184,9 +184,10 @@ func Array(a interface{}) interface { return (*Int64Array)(a) case *[]string: return (*StringArray)(a) - } - return GenericArray{a} + default: + panic(fmt.Sprintf("boil: invalid type received %T", a)) + } } // ArrayDelimiter may be optionally implemented by driver.Valuer or sql.Scanner @@ -208,7 +209,7 @@ func (a *BoolArray) Scan(src interface{}) error { return a.scanBytes([]byte(src)) } - return fmt.Errorf("pq: cannot convert %T to BoolArray", src) + return fmt.Errorf("boil: cannot convert %T to BoolArray", src) } func (a *BoolArray) scanBytes(src []byte) error { @@ -222,7 +223,7 @@ func (a *BoolArray) scanBytes(src []byte) error { b := make(BoolArray, len(elems)) for i, v := range elems { if len(v) != 1 { - return fmt.Errorf("pq: could not parse boolean array index %d: invalid boolean %q", i, v) + return fmt.Errorf("boil: could not parse boolean array index %d: invalid boolean %q", i, v) } switch v[0] { case 't': @@ -230,7 +231,7 @@ func (a *BoolArray) scanBytes(src []byte) error { case 'f': b[i] = false default: - return fmt.Errorf("pq: could not parse boolean array index %d: invalid boolean %q", i, v) + return fmt.Errorf("boil: could not parse boolean array index %d: invalid boolean %q", i, v) } } *a = b @@ -279,7 +280,7 @@ func (a *BytesArray) Scan(src interface{}) error { return a.scanBytes([]byte(src)) } - return fmt.Errorf("pq: cannot convert %T to BytesArray", src) + return fmt.Errorf("boil: cannot convert %T to BytesArray", src) } func (a *BytesArray) scanBytes(src []byte) error { @@ -348,7 +349,7 @@ func (a *Float64Array) Scan(src interface{}) error { return a.scanBytes([]byte(src)) } - return fmt.Errorf("pq: cannot convert %T to Float64Array", src) + return fmt.Errorf("boil: cannot convert %T to Float64Array", src) } func (a *Float64Array) scanBytes(src []byte) error { @@ -362,7 +363,7 @@ func (a *Float64Array) scanBytes(src []byte) error { b := make(Float64Array, len(elems)) for i, v := range elems { if b[i], err = strconv.ParseFloat(string(v), 64); err != nil { - return fmt.Errorf("pq: parsing array element index %d: %v", i, err) + return fmt.Errorf("boil: parsing array element index %d: %v", i, err) } } *a = b @@ -394,151 +395,6 @@ func (a Float64Array) Value() (driver.Value, error) { return "{}", nil } -// GenericArray implements the driver.Valuer and sql.Scanner interfaces for -// an array or slice of any dimension. -type GenericArray struct{ A interface{} } - -func (GenericArray) evaluateDestination(rt reflect.Type) (reflect.Type, func([]byte, reflect.Value) error, string) { - var assign func([]byte, reflect.Value) error - var del = "," - - // TODO calculate the assign function for other types - // TODO repeat this section on the element type of arrays or slices (multidimensional) - { - if reflect.PtrTo(rt).Implements(typeSQLScanner) { - // dest is always addressable because it is an element of a slice. - assign = func(src []byte, dest reflect.Value) (err error) { - ss := dest.Addr().Interface().(sql.Scanner) - if src == nil { - err = ss.Scan(nil) - } else { - err = ss.Scan(src) - } - return - } - goto FoundType - } - - assign = func([]byte, reflect.Value) error { - return fmt.Errorf("pq: scanning to %s is not implemented; only sql.Scanner", rt) - } - } - -FoundType: - - if ad, ok := reflect.Zero(rt).Interface().(ArrayDelimiter); ok { - del = ad.ArrayDelimiter() - } - - return rt, assign, del -} - -// Scan implements the sql.Scanner interface. -func (a GenericArray) Scan(src interface{}) error { - dpv := reflect.ValueOf(a.A) - switch { - case dpv.Kind() != reflect.Ptr: - return fmt.Errorf("pq: destination %T is not a pointer to array or slice", a.A) - case dpv.IsNil(): - return fmt.Errorf("pq: destination %T is nil", a.A) - } - - dv := dpv.Elem() - switch dv.Kind() { - case reflect.Slice: - case reflect.Array: - default: - return fmt.Errorf("pq: destination %T is not a pointer to array or slice", a.A) - } - - switch src := src.(type) { - case []byte: - return a.scanBytes(src, dv) - case string: - return a.scanBytes([]byte(src), dv) - } - - return fmt.Errorf("pq: cannot convert %T to %s", src, dv.Type()) -} - -func (a GenericArray) scanBytes(src []byte, dv reflect.Value) error { - dtype, assign, del := a.evaluateDestination(dv.Type().Elem()) - dims, elems, err := parseArray(src, []byte(del)) - if err != nil { - return err - } - - // TODO allow multidimensional - - if len(dims) > 1 { - return fmt.Errorf("pq: scanning from multidimensional ARRAY%s is not implemented", - strings.Replace(fmt.Sprint(dims), " ", "][", -1)) - } - - // Treat a zero-dimensional array like an array with a single dimension of zero. - if len(dims) == 0 { - dims = append(dims, 0) - } - - for i, rt := 0, dv.Type(); i < len(dims); i, rt = i+1, rt.Elem() { - switch rt.Kind() { - case reflect.Slice: - case reflect.Array: - if rt.Len() != dims[i] { - return fmt.Errorf("pq: cannot convert ARRAY%s to %s", - strings.Replace(fmt.Sprint(dims), " ", "][", -1), dv.Type()) - } - default: - // TODO handle multidimensional - } - } - - values := reflect.MakeSlice(reflect.SliceOf(dtype), len(elems), len(elems)) - for i, e := range elems { - if err := assign(e, values.Index(i)); err != nil { - return fmt.Errorf("pq: parsing array element index %d: %v", i, err) - } - } - - // TODO handle multidimensional - - switch dv.Kind() { - case reflect.Slice: - dv.Set(values.Slice(0, dims[0])) - case reflect.Array: - for i := 0; i < dims[0]; i++ { - dv.Index(i).Set(values.Index(i)) - } - } - - return nil -} - -// Value implements the driver.Valuer interface. -func (a GenericArray) Value() (driver.Value, error) { - if a.A == nil { - return nil, nil - } - - rv := reflect.ValueOf(a.A) - - if k := rv.Kind(); k != reflect.Array && k != reflect.Slice { - return nil, fmt.Errorf("pq: Unable to convert %T to array", a.A) - } - - if n := rv.Len(); n > 0 { - // There will be at least two curly brackets, N bytes of values, - // and N-1 bytes of delimiters. - b := make([]byte, 0, 1+2*n) - - b, _, err := appendArray(b, rv, n) - return string(b), err - } - - return "{}", nil -} - -// Int64Array represents a one-dimensional array of the PostgreSQL integer types. type Int64Array []int64 // Scan implements the sql.Scanner interface. @@ -550,7 +406,7 @@ func (a *Int64Array) Scan(src interface{}) error { return a.scanBytes([]byte(src)) } - return fmt.Errorf("pq: cannot convert %T to Int64Array", src) + return fmt.Errorf("boil: cannot convert %T to Int64Array", src) } func (a *Int64Array) scanBytes(src []byte) error { @@ -564,7 +420,7 @@ func (a *Int64Array) scanBytes(src []byte) error { b := make(Int64Array, len(elems)) for i, v := range elems { if b[i], err = strconv.ParseInt(string(v), 10, 64); err != nil { - return fmt.Errorf("pq: parsing array element index %d: %v", i, err) + return fmt.Errorf("boil: parsing array element index %d: %v", i, err) } } *a = b @@ -608,7 +464,7 @@ func (a *StringArray) Scan(src interface{}) error { return a.scanBytes([]byte(src)) } - return fmt.Errorf("pq: cannot convert %T to StringArray", src) + return fmt.Errorf("boil: cannot convert %T to StringArray", src) } func (a *StringArray) scanBytes(src []byte) error { @@ -622,7 +478,7 @@ func (a *StringArray) scanBytes(src []byte) error { b := make(StringArray, len(elems)) for i, v := range elems { if b[i] = string(v); v == nil { - return fmt.Errorf("pq: parsing array element index %d: cannot convert nil to string", i) + return fmt.Errorf("boil: parsing array element index %d: cannot convert nil to string", i) } } *a = b @@ -753,7 +609,7 @@ func parseArray(src, del []byte) (dims []int, elems [][]byte, err error) { var depth, i int if len(src) < 1 || src[0] != '{' { - return nil, nil, fmt.Errorf("pq: unable to parse array; expected %q at offset %d", '{', 0) + return nil, nil, fmt.Errorf("boil: unable to parse array; expected %q at offset %d", '{', 0) } Open: @@ -803,7 +659,7 @@ Element: if bytes.HasPrefix(src[i:], del) || src[i] == '}' { elem := src[start:i] if len(elem) == 0 { - return nil, nil, fmt.Errorf("pq: unable to parse array; unexpected %q at offset %d", src[i], i) + return nil, nil, fmt.Errorf("boil: unable to parse array; unexpected %q at offset %d", src[i], i) } if bytes.Equal(elem, []byte("NULL")) { elem = nil @@ -825,7 +681,7 @@ Element: depth-- i++ } else { - return nil, nil, fmt.Errorf("pq: unable to parse array; unexpected %q at offset %d", src[i], i) + return nil, nil, fmt.Errorf("boil: unable to parse array; unexpected %q at offset %d", src[i], i) } } @@ -835,16 +691,16 @@ Close: depth-- i++ } else { - return nil, nil, fmt.Errorf("pq: unable to parse array; unexpected %q at offset %d", src[i], i) + return nil, nil, fmt.Errorf("boil: unable to parse array; unexpected %q at offset %d", src[i], i) } } if depth > 0 { - err = fmt.Errorf("pq: unable to parse array; expected %q at offset %d", '}', i) + err = fmt.Errorf("boil: unable to parse array; expected %q at offset %d", '}', i) } if err == nil { for _, d := range dims { if (len(elems) % d) != 0 { - err = fmt.Errorf("pq: multidimensional arrays must have elements with matching dimensions") + err = fmt.Errorf("boil: multidimensional arrays must have elements with matching dimensions") } } } @@ -857,7 +713,7 @@ func scanLinearArray(src, del []byte, typ string) (elems [][]byte, err error) { return nil, err } if len(dims) > 1 { - return nil, fmt.Errorf("pq: cannot convert ARRAY%s to %s", strings.Replace(fmt.Sprint(dims), " ", "][", -1), typ) + return nil, fmt.Errorf("boil: cannot convert ARRAY%s to %s", strings.Replace(fmt.Sprint(dims), " ", "][", -1), typ) } return elems, err } diff --git a/boil/types/array_test.go b/boil/types/array_test.go index 6e24e447f..27e68cf1a 100644 --- a/boil/types/array_test.go +++ b/boil/types/array_test.go @@ -20,7 +20,6 @@ package types import ( - "bytes" "database/sql" "database/sql/driver" "math/rand" @@ -125,19 +124,6 @@ func TestArrayScanner(t *testing.T) { if _, ok := s.(*StringArray); !ok { t.Errorf("Expected *StringArray, got %T", s) } - - for _, tt := range []interface{}{ - &[]sql.Scanner{}, - &[][]bool{}, - &[][]float64{}, - &[][]int64{}, - &[][]string{}, - } { - s = Array(tt) - if _, ok := s.(GenericArray); !ok { - t.Errorf("Expected GenericArray for %T, got %T", tt, s) - } - } } func TestArrayValuer(t *testing.T) { @@ -162,20 +148,6 @@ func TestArrayValuer(t *testing.T) { if _, ok := v.(*StringArray); !ok { t.Errorf("Expected *StringArray, got %T", v) } - - for _, tt := range []interface{}{ - nil, - []driver.Value{}, - [][]bool{}, - [][]float64{}, - [][]int64{}, - [][]string{}, - } { - v = Array(tt) - if _, ok := v.(GenericArray); !ok { - t.Errorf("Expected GenericArray for %T, got %T", tt, v) - } - } } func TestBoolArrayScanUnsupported(t *testing.T) { @@ -826,300 +798,3 @@ func BenchmarkStringArrayValue(b *testing.B) { a.Value() } } - -func TestGenericArrayScanUnsupported(t *testing.T) { - var s string - var ss []string - - for _, tt := range []struct { - src, dest interface{} - err string - }{ - {nil, nil, "destination is not a pointer to array or slice"}, - {nil, true, "destination bool is not a pointer to array or slice"}, - {nil, &s, "destination *string is not a pointer to array or slice"}, - {nil, ss, "destination []string is not a pointer to array or slice"}, - {true, &ss, "bool to []string"}, - {`{{x}}`, &ss, "multidimensional ARRAY[1][1] is not implemented"}, - {`{{x},{x}}`, &ss, "multidimensional ARRAY[2][1] is not implemented"}, - {`{x}`, &ss, "scanning to string is not implemented"}, - } { - err := GenericArray{tt.dest}.Scan(tt.src) - - if err == nil { - t.Fatalf("Expected error for [%#v %#v]", tt.src, tt.dest) - } - if !strings.Contains(err.Error(), tt.err) { - t.Errorf("Expected error to contain %q for [%#v %#v], got %q", tt.err, tt.src, tt.dest, err) - } - } -} - -func TestGenericArrayScanScannerArrayBytes(t *testing.T) { - src, expected, nsa := []byte(`{NULL,abc,"\""}`), - [3]sql.NullString{{}, {String: `abc`, Valid: true}, {String: `"`, Valid: true}}, - [3]sql.NullString{{String: ``, Valid: true}, {}, {}} - - if err := (GenericArray{&nsa}).Scan(src); err != nil { - t.Fatalf("Expected no error, got %v", err) - } - if !reflect.DeepEqual(nsa, expected) { - t.Errorf("Expected %v, got %v", expected, nsa) - } -} - -func TestGenericArrayScanScannerArrayString(t *testing.T) { - src, expected, nsa := `{NULL,"\"",xyz}`, - [3]sql.NullString{{}, {String: `"`, Valid: true}, {String: `xyz`, Valid: true}}, - [3]sql.NullString{{String: ``, Valid: true}, {}, {}} - - if err := (GenericArray{&nsa}).Scan(src); err != nil { - t.Fatalf("Expected no error, got %v", err) - } - if !reflect.DeepEqual(nsa, expected) { - t.Errorf("Expected %v, got %v", expected, nsa) - } -} - -func TestGenericArrayScanScannerSliceBytes(t *testing.T) { - src, expected, nss := []byte(`{NULL,abc,"\""}`), - []sql.NullString{{}, {String: `abc`, Valid: true}, {String: `"`, Valid: true}}, - []sql.NullString{{String: ``, Valid: true}, {}, {}, {}, {}} - - if err := (GenericArray{&nss}).Scan(src); err != nil { - t.Fatalf("Expected no error, got %v", err) - } - if !reflect.DeepEqual(nss, expected) { - t.Errorf("Expected %v, got %v", expected, nss) - } -} - -func BenchmarkGenericArrayScanScannerSliceBytes(b *testing.B) { - var a GenericArray - var x interface{} = []byte(`{a,b,c,d,e,f,g,h,i,j}`) - var y interface{} = []byte(`{"\a","\b","\c","\d","\e","\f","\g","\h","\i","\j"}`) - - for i := 0; i < b.N; i++ { - a = GenericArray{new([]sql.NullString)} - a.Scan(x) - a = GenericArray{new([]sql.NullString)} - a.Scan(y) - } -} - -func TestGenericArrayScanScannerSliceString(t *testing.T) { - src, expected, nss := `{NULL,"\"",xyz}`, - []sql.NullString{{}, {String: `"`, Valid: true}, {String: `xyz`, Valid: true}}, - []sql.NullString{{String: ``, Valid: true}, {}, {}} - - if err := (GenericArray{&nss}).Scan(src); err != nil { - t.Fatalf("Expected no error, got %v", err) - } - if !reflect.DeepEqual(nss, expected) { - t.Errorf("Expected %v, got %v", expected, nss) - } -} - -type TildeNullInt64 struct{ sql.NullInt64 } - -func (TildeNullInt64) ArrayDelimiter() string { return "~" } - -func TestGenericArrayScanDelimiter(t *testing.T) { - src, expected, tnis := `{12~NULL~76}`, - []TildeNullInt64{{sql.NullInt64{Int64: 12, Valid: true}}, {}, {sql.NullInt64{Int64: 76, Valid: true}}}, - []TildeNullInt64{{sql.NullInt64{Int64: 0, Valid: true}}, {}} - - if err := (GenericArray{&tnis}).Scan(src); err != nil { - t.Fatalf("Expected no error for %#v, got %v", src, err) - } - if !reflect.DeepEqual(tnis, expected) { - t.Errorf("Expected %v for %#v, got %v", expected, src, tnis) - } -} - -func TestGenericArrayScanErrors(t *testing.T) { - var sa [1]string - var nis []sql.NullInt64 - var pss *[]string - - for _, tt := range []struct { - src, dest interface{} - err string - }{ - {nil, pss, "destination *[]string is nil"}, - {`{`, &sa, "unable to parse"}, - {`{}`, &sa, "cannot convert ARRAY[0] to [1]string"}, - {`{x,x}`, &sa, "cannot convert ARRAY[2] to [1]string"}, - {`{x}`, &nis, `parsing array element index 0: converting`}, - } { - err := GenericArray{tt.dest}.Scan(tt.src) - - if err == nil { - t.Fatalf("Expected error for [%#v %#v]", tt.src, tt.dest) - } - if !strings.Contains(err.Error(), tt.err) { - t.Errorf("Expected error to contain %q for [%#v %#v], got %q", tt.err, tt.src, tt.dest, err) - } - } -} - -func TestGenericArrayValueUnsupported(t *testing.T) { - _, err := GenericArray{true}.Value() - - if err == nil { - t.Fatal("Expected error for bool") - } - if !strings.Contains(err.Error(), "bool to array") { - t.Errorf("Expected type to be mentioned, got %q", err) - } -} - -type ByteArrayValuer [1]byte -type ByteSliceValuer []byte -type FuncArrayValuer struct { - delimiter func() string - value func() (driver.Value, error) -} - -func (a ByteArrayValuer) Value() (driver.Value, error) { return a[:], nil } -func (b ByteSliceValuer) Value() (driver.Value, error) { return []byte(b), nil } -func (f FuncArrayValuer) ArrayDelimiter() string { return f.delimiter() } -func (f FuncArrayValuer) Value() (driver.Value, error) { return f.value() } - -func TestGenericArrayValue(t *testing.T) { - result, err := GenericArray{nil}.Value() - - if err != nil { - t.Fatalf("Expected no error for nil, got %v", err) - } - if result != nil { - t.Errorf("Expected nil, got %q", result) - } - - Tilde := func(v driver.Value) FuncArrayValuer { - return FuncArrayValuer{ - func() string { return "~" }, - func() (driver.Value, error) { return v, nil }} - } - - for _, tt := range []struct { - result string - input interface{} - }{ - {`{}`, []bool{}}, - {`{true}`, []bool{true}}, - {`{true,false}`, []bool{true, false}}, - {`{true,false}`, [2]bool{true, false}}, - - {`{}`, [][]int{{}}}, - {`{}`, [][]int{{}, {}}}, - {`{{1}}`, [][]int{{1}}}, - {`{{1},{2}}`, [][]int{{1}, {2}}}, - {`{{1,2},{3,4}}`, [][]int{{1, 2}, {3, 4}}}, - {`{{1,2},{3,4}}`, [2][2]int{{1, 2}, {3, 4}}}, - - {`{"a","\\b","c\"","d,e"}`, []string{`a`, `\b`, `c"`, `d,e`}}, - {`{"a","\\b","c\"","d,e"}`, [][]byte{{'a'}, {'\\', 'b'}, {'c', '"'}, {'d', ',', 'e'}}}, - - {`{NULL}`, []*int{nil}}, - {`{0,NULL}`, []*int{new(int), nil}}, - - {`{NULL}`, []sql.NullString{{}}}, - {`{"\"",NULL}`, []sql.NullString{{String: `"`, Valid: true}, {}}}, - - {`{"a","b"}`, []ByteArrayValuer{{'a'}, {'b'}}}, - {`{{"a","b"},{"c","d"}}`, [][]ByteArrayValuer{{{'a'}, {'b'}}, {{'c'}, {'d'}}}}, - - {`{"e","f"}`, []ByteSliceValuer{{'e'}, {'f'}}}, - {`{{"e","f"},{"g","h"}}`, [][]ByteSliceValuer{{{'e'}, {'f'}}, {{'g'}, {'h'}}}}, - - {`{1~2}`, []FuncArrayValuer{Tilde(int64(1)), Tilde(int64(2))}}, - {`{{1~2}~{3~4}}`, [][]FuncArrayValuer{{Tilde(int64(1)), Tilde(int64(2))}, {Tilde(int64(3)), Tilde(int64(4))}}}, - } { - result, err := GenericArray{tt.input}.Value() - - if err != nil { - t.Fatalf("Expected no error for %q, got %v", tt.input, err) - } - if !reflect.DeepEqual(result, tt.result) { - t.Errorf("Expected %q for %q, got %q", tt.result, tt.input, result) - } - } -} - -func TestGenericArrayValueErrors(t *testing.T) { - var v []interface{} - - v = []interface{}{func() {}} - if _, err := (GenericArray{v}).Value(); err == nil { - t.Errorf("Expected error for %q, got nil", v) - } - - v = []interface{}{nil, func() {}} - if _, err := (GenericArray{v}).Value(); err == nil { - t.Errorf("Expected error for %q, got nil", v) - } -} - -func BenchmarkGenericArrayValueBools(b *testing.B) { - rand.Seed(1) - x := make([]bool, 10) - for i := 0; i < len(x); i++ { - x[i] = rand.Intn(2) == 0 - } - a := GenericArray{x} - - for i := 0; i < b.N; i++ { - a.Value() - } -} - -func BenchmarkGenericArrayValueFloat64s(b *testing.B) { - rand.Seed(1) - x := make([]float64, 10) - for i := 0; i < len(x); i++ { - x[i] = rand.NormFloat64() - } - a := GenericArray{x} - - for i := 0; i < b.N; i++ { - a.Value() - } -} - -func BenchmarkGenericArrayValueInt64s(b *testing.B) { - rand.Seed(1) - x := make([]int64, 10) - for i := 0; i < len(x); i++ { - x[i] = rand.Int63() - } - a := GenericArray{x} - - for i := 0; i < b.N; i++ { - a.Value() - } -} - -func BenchmarkGenericArrayValueByteSlices(b *testing.B) { - x := make([][]byte, 10) - for i := 0; i < len(x); i++ { - x[i] = bytes.Repeat([]byte(`abc"def\ghi`), 5) - } - a := GenericArray{x} - - for i := 0; i < b.N; i++ { - a.Value() - } -} - -func BenchmarkGenericArrayValueStrings(b *testing.B) { - x := make([]string, 10) - for i := 0; i < len(x); i++ { - x[i] = strings.Repeat(`abc"def\ghi`, 5) - } - a := GenericArray{x} - - for i := 0; i < b.N; i++ { - a.Value() - } -} diff --git a/imports.go b/imports.go index ced6daa7b..5fc089671 100644 --- a/imports.go +++ b/imports.go @@ -305,9 +305,6 @@ var importsBasedOnType = map[string]imports{ "types.BytesArray": { thirdParty: importList{`"github.com/vattle/sqlboiler/boil/types"`}, }, - "types.GenericArray": { - thirdParty: importList{`"github.com/vattle/sqlboiler/boil/types"`}, - }, "types.Int64Array": { thirdParty: importList{`"github.com/vattle/sqlboiler/boil/types"`}, }, diff --git a/testdata/test_schema.sql b/testdata/test_schema.sql index 9dd4953c7..14beeca1d 100644 --- a/testdata/test_schema.sql +++ b/testdata/test_schema.sql @@ -194,3 +194,22 @@ create table enemies ( enemies character varying, primary key (enemies) ); + +create table fun_arrays ( + id serial, + fun_one integer[] null, + fun_two integer[] not null, + fun_three boolean[] null, + fun_four boolean[] not null, + fun_five varchar[] null, + fun_six varchar[] not null, + fun_seven decimal[] null, + fun_eight decimal[] not null, + fun_nine bytea[] null, + fun_ten bytea[] not null, + fun_eleven jsonb[] null, + fun_twelve jsonb[] not null, + fun_thirteen json[] null, + fun_fourteen json[] not null, + primary key (id) +) From 9d29d2b9464ba5501e8c182290c38b88207950cb Mon Sep 17 00:00:00 2001 From: Aaron L Date: Sun, 11 Sep 2016 09:17:08 -0700 Subject: [PATCH 29/64] Refactor all the bits. - Make TestMain be driver-based - Move config to TestMain file - Make config a little more sane in pgmain --- imports.go | 23 +++ templates_test/main_test/mysql_main.tpl | 17 +- templates_test/main_test/postgres_main.tpl | 176 ++++++------------- templates_test/singleton/boil_main_test.tpl | 135 ++++++++++++++ templates_test/singleton/boil_viper_test.tpl | 37 ---- 5 files changed, 228 insertions(+), 160 deletions(-) create mode 100644 templates_test/singleton/boil_main_test.tpl delete mode 100644 templates_test/singleton/boil_viper_test.tpl diff --git a/imports.go b/imports.go index 5fc089671..bf08208d1 100644 --- a/imports.go +++ b/imports.go @@ -239,6 +239,29 @@ var defaultTestMainImports = map[string]imports{ `_ "github.com/lib/pq"`, }, }, + "mysql": { + standard: importList{ + `"testing"`, + `"os"`, + `"os/exec"`, + `"flag"`, + `"fmt"`, + `"io/ioutil"`, + `"bytes"`, + `"database/sql"`, + `"path/filepath"`, + `"time"`, + `"math/rand"`, + }, + thirdParty: importList{ + `"github.com/kat-co/vala"`, + `"github.com/pkg/errors"`, + `"github.com/spf13/viper"`, + `"github.com/vattle/sqlboiler/boil"`, + `"github.com/vattle/sqlboiler/bdb/drivers"`, + `_ "github.com/go-mysql-driver/mysql"`, + }, + }, } // importsBasedOnType imports are only included in the template output if the diff --git a/templates_test/main_test/mysql_main.tpl b/templates_test/main_test/mysql_main.tpl index 96dcbfc05..5643e0ebd 100644 --- a/templates_test/main_test/mysql_main.tpl +++ b/templates_test/main_test/mysql_main.tpl @@ -1,2 +1,17 @@ -func TestMain(m *testing.M) { +type mysqlTester struct { + dbConn *sql.DB +} + +dbMain = mysqlTester{} + +func (m mysqlTester) setup() error { + return nil +} + +func (m mysqlTester) teardown() error { + return nil +} + +func (m mysqlTester) conn() *sql.DB { + return m.dbConn } diff --git a/templates_test/main_test/postgres_main.tpl b/templates_test/main_test/postgres_main.tpl index 2750f60cc..2bc819eab 100644 --- a/templates_test/main_test/postgres_main.tpl +++ b/templates_test/main_test/postgres_main.tpl @@ -1,50 +1,21 @@ -type PostgresCfg struct { - User string `toml:"user"` - Pass string `toml:"pass"` - Host string `toml:"host"` - Port int `toml:"port"` - DBName string `toml:"dbname"` - SSLMode string `toml:"sslmode"` -} - -type Config struct { - Postgres PostgresCfg `toml:"postgres"` -} - -var flagDebugMode = flag.Bool("test.sqldebug", false, "Turns on debug mode for SQL statements") +type pgTester struct { + dbConn *sql.DB -func TestMain(m *testing.M) { - rand.Seed(time.Now().UnixNano()) + dbName string + host string + user string + pass string + sslmode string + port int - // Set DebugMode so we can see generated sql statements - flag.Parse() - boil.DebugMode = *flagDebugMode - - var err error - if err = setup(); err != nil { - fmt.Println("Unable to execute setup:", err) - os.Exit(-2) - } - - var code int - if err = disableTriggers(); err != nil { - fmt.Println("Unable to disable triggers:", err) - } else { - boil.SetDB(dbConn) - code = m.Run() - } - - if err = teardown(); err != nil { - fmt.Println("Unable to execute teardown:", err) - os.Exit(-3) - } - - os.Exit(code) + testDBName string } +dbMain = pgTester{} + // disableTriggers is used to disable foreign key constraints for every table. // If this is not used we cannot test inserts due to foreign key constraint errors. -func disableTriggers() error { +func (p pgTester) disableTriggers() error { var stmts []string {{range .Tables}} @@ -57,7 +28,7 @@ func disableTriggers() error { var err error for _, s := range stmts { - _, err = dbConn.Exec(s) + _, err = p.dbConn.Exec(s) if err != nil { return err } @@ -67,33 +38,37 @@ func disableTriggers() error { } // teardown executes cleanup tasks when the tests finish running -func teardown() error { +func (p pgTester) teardown() error { err := dropTestDB() return err } +func (p pgTester) conn() *sql.DB { + return p.dbConn +} + // dropTestDB switches its connection to the template1 database temporarily // so that it can drop the test database without causing "in use" conflicts. // The template1 database should be present on all default postgres installations. -func dropTestDB() error { +func (p pgTester) dropTestDB() error { var err error - if dbConn != nil { - if err = dbConn.Close(); err != nil { + if p.dbConn != nil { + if err = p.dbConn.Close(); err != nil { return err } } - dbConn, err = DBConnect(testCfg.Postgres.User, testCfg.Postgres.Pass, "template1", testCfg.Postgres.Host, testCfg.Postgres.Port, testCfg.Postgres.SSLMode) + p.dbConn, err = DBConnect(testCfg.Postgres.User, testCfg.Postgres.Pass, "template1", testCfg.Postgres.Host, testCfg.Postgres.Port, testCfg.Postgres.SSLMode) if err != nil { return err } - _, err = dbConn.Exec(fmt.Sprintf(`DROP DATABASE IF EXISTS %s;`, testCfg.Postgres.DBName)) + _, err = p.dbConn.Exec(fmt.Sprintf(`DROP DATABASE IF EXISTS %s;`, testCfg.Postgres.DBName)) if err != nil { return err } - return dbConn.Close() + return p.dbConn.Close() } // DBConnect connects to a database and returns the handle. @@ -106,43 +81,17 @@ func DBConnect(user, pass, dbname, host string, port int, sslmode string) (*sql. // setup dumps the database schema and imports it into a temporary randomly // generated test database so that tests can be run against it using the // generated sqlboiler ORM package. -func setup() error { +func (p pgTester) setup() error { var err error - // Initialize Viper and load the config file - err = InitViper() - if err != nil { - return errors.Wrap(err, "Unable to load config file") - } - - viper.SetDefault("postgres.sslmode", "require") - viper.SetDefault("postgres.port", "5432") - - // Create a randomized test configuration object. - testCfg.Postgres.Host = viper.GetString("postgres.host") - testCfg.Postgres.Port = viper.GetInt("postgres.port") - testCfg.Postgres.User = viper.GetString("postgres.user") - testCfg.Postgres.Pass = viper.GetString("postgres.pass") - testCfg.Postgres.DBName = getDBNameHash(viper.GetString("postgres.dbname")) - testCfg.Postgres.SSLMode = viper.GetString("postgres.sslmode") - - // Set the default SSLMode value - if testCfg.Postgres.SSLMode == "" { - viper.Set("postgres.sslmode", "require") - testCfg.Postgres.SSLMode = viper.GetString("postgres.sslmode") - } - - err = vala.BeginValidation().Validate( - vala.StringNotEmpty(testCfg.Postgres.User, "postgres.user"), - vala.StringNotEmpty(testCfg.Postgres.Host, "postgres.host"), - vala.Not(vala.Equals(testCfg.Postgres.Port, 0, "postgres.port")), - vala.StringNotEmpty(testCfg.Postgres.DBName, "postgres.dbname"), - vala.StringNotEmpty(testCfg.Postgres.SSLMode, "postgres.sslmode"), - ).Check() - - if err != nil { - return errors.Wrap(err, "Unable to load testCfg") - } + p.dbName = viper.GetString("postgres.dbname") + p.host = viper.GetString("postgres.host") + p.user = viper.GetString("postgres.user") + p.pass = viper.GetString("postgres.pass") + p.port = viper.GetInt("postgres.port") + p.sslmode = viper.GetString("postgres.dbname") + // Create a randomized db name. + p.testDBName = getDBNameHash(p.dbname) err = dropTestDB() if err != nil { @@ -163,15 +112,10 @@ func setup() error { defer os.RemoveAll(passDir) // Write the postgres user password to a tmp file for pg_dump - pwBytes := []byte(fmt.Sprintf("%s:%d:%s:%s", - viper.GetString("postgres.host"), - viper.GetInt("postgres.port"), - viper.GetString("postgres.dbname"), - viper.GetString("postgres.user"), - )) - - if pw := viper.GetString("postgres.pass"); len(pw) > 0 { - pwBytes = []byte(fmt.Sprintf("%s:%s", pwBytes, pw)) + pwBytes := []byte(fmt.Sprintf("%s:%d:%s:%s", p.host, p.port, p.dbname, p.user)) + + if len(p.pass) > 0 { + pwBytes = []byte(fmt.Sprintf("%s:%s", pwBytes, p.pass)) } passFilePath := filepath.Join(passDir, "pwfile") @@ -183,11 +127,11 @@ func setup() error { // The params for the pg_dump command to dump the database schema params := []string{ - fmt.Sprintf(`--host=%s`, viper.GetString("postgres.host")), - fmt.Sprintf(`--port=%d`, viper.GetInt("postgres.port")), - fmt.Sprintf(`--username=%s`, viper.GetString("postgres.user")), + fmt.Sprintf(`--host=%s`, p.host), + fmt.Sprintf(`--port=%d`, p.port), + fmt.Sprintf(`--username=%s`, p.user), "--schema-only", - viper.GetString("postgres.dbname"), + p.dbName, } // Dump the database schema into the sqlboilerschema tmp file @@ -202,45 +146,33 @@ func setup() error { return err } - dbConn, err = DBConnect( - viper.GetString("postgres.user"), - viper.GetString("postgres.pass"), - viper.GetString("postgres.dbname"), - viper.GetString("postgres.host"), - viper.GetInt("postgres.port"), - viper.GetString("postgres.sslmode"), - ) + p.dbConn, err = DBConnect(p.user, p.pass, p.dbName, p.host, p.port, p.sslmode) if err != nil { return err } // Create the randomly generated database - _, err = dbConn.Exec(fmt.Sprintf(`CREATE DATABASE %s WITH ENCODING 'UTF8'`, testCfg.Postgres.DBName)) + _, err = p.dbConn.Exec(fmt.Sprintf(`CREATE DATABASE %s WITH ENCODING 'UTF8'`, p.testDBName)) if err != nil { return err } // Close the old connection so we can reconnect to the test database - if err = dbConn.Close(); err != nil { + if err = p.dbConn.Close(); err != nil { return err } // Connect to the generated test db - dbConn, err = DBConnect(testCfg.Postgres.User, testCfg.Postgres.Pass, testCfg.Postgres.DBName, testCfg.Postgres.Host, testCfg.Postgres.Port, testCfg.Postgres.SSLMode) + p.dbConn, err = DBConnect(p.user, p.pass, p.testDBName, p.host, p.port, p.sslmode) if err != nil { return err } // Write the test config credentials to a tmp file for pg_dump - testPwBytes := []byte(fmt.Sprintf("%s:%d:%s:%s", - testCfg.Postgres.Host, - testCfg.Postgres.Port, - testCfg.Postgres.DBName, - testCfg.Postgres.User, - )) - - if len(testCfg.Postgres.Pass) > 0 { - testPwBytes = []byte(fmt.Sprintf("%s:%s", testPwBytes, testCfg.Postgres.Pass)) + testPwBytes := []byte(fmt.Sprintf("%s:%d:%s:%s", p.host, p.port, p.testDBName, p.user)) + + if len(p.pass) > 0 { + testPwBytes = []byte(fmt.Sprintf("%s:%s", testPwBytes, p.pass)) } testPassFilePath := passDir + "/testpwfile" @@ -252,10 +184,10 @@ func setup() error { // The params for the psql schema import command params = []string{ - fmt.Sprintf(`--dbname=%s`, testCfg.Postgres.DBName), - fmt.Sprintf(`--host=%s`, testCfg.Postgres.Host), - fmt.Sprintf(`--port=%d`, testCfg.Postgres.Port), - fmt.Sprintf(`--username=%s`, testCfg.Postgres.User), + fmt.Sprintf(`--dbname=%s`, p.testDBName), + fmt.Sprintf(`--host=%s`, p.host), + fmt.Sprintf(`--port=%d`, p.port), + fmt.Sprintf(`--username=%s`, p.user), fmt.Sprintf(`--file=%s`, fhSchema.Name()), } @@ -271,5 +203,5 @@ func setup() error { fmt.Printf("psql schema import exec failed: %s\n\n%s\n", err, errBuf.String()) } - return nil + return p.disableTriggers() } diff --git a/templates_test/singleton/boil_main_test.tpl b/templates_test/singleton/boil_main_test.tpl new file mode 100644 index 000000000..77fe62dd2 --- /dev/null +++ b/templates_test/singleton/boil_main_test.tpl @@ -0,0 +1,135 @@ +var flagDebugMode = flag.Bool("test.sqldebug", false, "Turns on debug mode for SQL statements") + +var ( + dbMain tester +) + +type tester interface { + setup() error + conn() *sql.DB + teardown() error +} + +func TestMain(m *testing.M) { + if dbMain == nil { + fmt.Println("no dbMain tester interface was ready") + os.Exit(-1) + } + + rand.Seed(time.Now().UnixNano()) + + // Load configuration + err = initViper() + if err != nil { + return errors.Wrap(err, "Unable to load config file") + } + + setConfigDefaults() + if err := validateConfig({{.DriverName}}); err != nil { + fmt.Println("failed to validate config", err) + os.Exit(-2) + } + + // Set DebugMode so we can see generated sql statements + flag.Parse() + boil.DebugMode = *flagDebugMode + + var err error + if err = dbMain.setup(); err != nil { + fmt.Println("Unable to execute setup:", err) + os.Exit(-3) + } + + var code int + boil.SetDB(dbMain.conn()) + code = m.Run() + + if err = dbMain.teardown(); err != nil { + fmt.Println("Unable to execute teardown:", err) + os.Exit(-4) + } + + os.Exit(code) +} + +func initViper() error { + var err error + + viper.SetConfigName("sqlboiler") + + configHome := os.Getenv("XDG_CONFIG_HOME") + homePath := os.Getenv("HOME") + wd, err := os.Getwd() + if err != nil { + wd = "../" + } else { + wd = wd + "/.." + } + + configPaths := []string{wd} + if len(configHome) > 0 { + configPaths = append(configPaths, filepath.Join(configHome, "sqlboiler")) + } else { + configPaths = append(configPaths, filepath.Join(homePath, ".config/sqlboiler")) + } + + for _, p := range configPaths { + viper.AddConfigPath(p) + } + + // Ignore errors here, fall back to defaults and validation to provide errs + _ = viper.ReadInConfig() + viper.AutomaticEnv() + + return nil +} + +// setDefaults is only necessary because of bugs in viper, noted in main +func setDefaults() { + if viper.GetString("postgres.sslmode") == "" { + viper.Set("postgres.sslmode", "require") + } + if viper.GetInt("postgres.port") == 0 { + viper.Set("postgres.port", 5432) + } + if viper.GetString("mysql.sslmode") == "" { + viper.Set("mysql.sslmode", "true") + } + if viper.GetInt("mysql.port") == 0 { + viper.Set("mysql.port", 3306) + } +} + +func validateConfig(driverName string) error { + if viper.IsSet("postgres.dbname") { + err = vala.BeginValidation().Validate( + vala.StringNotEmpty(viper.GetString("postgres.user"), "postgres.user"), + vala.StringNotEmpty(viper.GetString("postgres.host"), "postgres.host"), + vala.Not(vala.Equals(viper.GetInt("postgres.port"), 0, "postgres.port")), + vala.StringNotEmpty(viper.GetString("postgres.dbname"), "postgres.dbname"), + vala.StringNotEmpty(viper.GetString("postgres.sslmode"), "postgres.sslmode"), + ).Check() + + if err != nil { + return err + } + } else if driverName == "postgres" { + return errors.New("postgres driver requires a postgres section in your config file") + } + + if viper.IsSet("mysql.dbname") { + err = vala.BeginValidation().Validate( + vala.StringNotEmpty(viper.GetString("mysql.user"), "mysql.user"), + vala.StringNotEmpty(viper.GetString("mysql.host"), "mysql.host"), + vala.Not(vala.Equals(viper.GetInt("mysql.port"), 0, "mysql.port")), + vala.StringNotEmpty(viper.GetString("mysql.dbname"), "mysql.dbname"), + vala.StringNotEmpty(viper.GetString("mysql.sslmode"), "mysql.sslmode"), + ).Check() + + if err != nil { + return err + } + } else if driverName == "mysql" { + return errors.New("mysql driver requires a mysql section in your config file") + } +} diff --git a/templates_test/singleton/boil_viper_test.tpl b/templates_test/singleton/boil_viper_test.tpl deleted file mode 100644 index d05a20a7b..000000000 --- a/templates_test/singleton/boil_viper_test.tpl +++ /dev/null @@ -1,37 +0,0 @@ -var ( - testCfg *Config - dbConn *sql.DB -) - -func InitViper() error { - var err error - testCfg = &Config{} - - viper.SetConfigName("sqlboiler") - - configHome := os.Getenv("XDG_CONFIG_HOME") - homePath := os.Getenv("HOME") - wd, err := os.Getwd() - if err != nil { - wd = "../" - } else { - wd = wd + "/.." - } - - configPaths := []string{wd} - if len(configHome) > 0 { - configPaths = append(configPaths, filepath.Join(configHome, "sqlboiler")) - } else { - configPaths = append(configPaths, filepath.Join(homePath, ".config/sqlboiler")) - } - - for _, p := range configPaths { - viper.AddConfigPath(p) - } - - // Ignore errors here, fall back to defaults and validation to provide errs - _ = viper.ReadInConfig() - viper.AutomaticEnv() - - return nil -} From d1ea9255238db602dc7879bd197c598ec23dc1c2 Mon Sep 17 00:00:00 2001 From: Aaron L Date: Sun, 11 Sep 2016 09:22:20 -0700 Subject: [PATCH 30/64] Fix bug in debug output --- sqlboiler.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sqlboiler.go b/sqlboiler.go index d6e1bb262..3a4483c98 100644 --- a/sqlboiler.go +++ b/sqlboiler.go @@ -70,8 +70,7 @@ func New(config *Config) (*State, error) { if err != nil { return nil, errors.Wrap(err, "unable to json marshal tables") } - boil.DebugWriter.Write(b) - fmt.Fprintln(boil.DebugWriter) + fmt.Printf("%s\n", b) } err = s.initOutFolder() From d183ec4bb5e43bd5d6e47f5d9ee9bcc4eeb9a9c6 Mon Sep 17 00:00:00 2001 From: Aaron L Date: Sun, 11 Sep 2016 12:07:39 -0700 Subject: [PATCH 31/64] Postgres works again after refactor --- imports.go | 17 ++++----- main.go | 9 ++--- templates_test/main_test/mysql_main.tpl | 10 +++--- templates_test/main_test/postgres_main.tpl | 33 ++++++++--------- templates_test/singleton/boil_main_test.tpl | 39 ++++++++------------- 5 files changed, 49 insertions(+), 59 deletions(-) diff --git a/imports.go b/imports.go index bf08208d1..938e29480 100644 --- a/imports.go +++ b/imports.go @@ -186,14 +186,22 @@ var defaultTestTemplateImports = imports{ } var defaultSingletonTestTemplateImports = map[string]imports{ - "boil_viper_test": { + "boil_main_test": { standard: importList{ `"database/sql"`, + `"flag"`, + `"fmt"`, + `"math/rand"`, `"os"`, `"path/filepath"`, + `"testing"`, + `"time"`, }, thirdParty: importList{ + `"github.com/kat-co/vala"`, + `"github.com/pkg/errors"`, `"github.com/spf13/viper"`, + `"github.com/vattle/sqlboiler/boil"`, }, }, "boil_queries_test": { @@ -218,23 +226,17 @@ var defaultSingletonTestTemplateImports = map[string]imports{ var defaultTestMainImports = map[string]imports{ "postgres": { standard: importList{ - `"testing"`, `"os"`, `"os/exec"`, - `"flag"`, `"fmt"`, `"io/ioutil"`, `"bytes"`, `"database/sql"`, `"path/filepath"`, - `"time"`, - `"math/rand"`, }, thirdParty: importList{ - `"github.com/kat-co/vala"`, `"github.com/pkg/errors"`, `"github.com/spf13/viper"`, - `"github.com/vattle/sqlboiler/boil"`, `"github.com/vattle/sqlboiler/bdb/drivers"`, `_ "github.com/lib/pq"`, }, @@ -254,7 +256,6 @@ var defaultTestMainImports = map[string]imports{ `"math/rand"`, }, thirdParty: importList{ - `"github.com/kat-co/vala"`, `"github.com/pkg/errors"`, `"github.com/spf13/viper"`, `"github.com/vattle/sqlboiler/boil"`, diff --git a/main.go b/main.go index 9008892b7..2fd63ffba 100644 --- a/main.go +++ b/main.go @@ -2,7 +2,6 @@ package main import ( - "errors" "fmt" "os" "path/filepath" @@ -148,7 +147,7 @@ func preRun(cmd *cobra.Command, args []string) error { } } - if viper.IsSet("postgres.dbname") { + if driverName == "postgres" { cmdConfig.Postgres = PostgresConfig{ User: viper.GetString("postgres.user"), Pass: viper.GetString("postgres.pass"), @@ -182,11 +181,9 @@ func preRun(cmd *cobra.Command, args []string) error { if err != nil { return commandFailure(err.Error()) } - } else if driverName == "postgres" { - return errors.New("postgres driver requires a postgres section in your config file") } - if viper.IsSet("mysql.dbname") { + if driverName == "mysql" { cmdConfig.MySQL = MySQLConfig{ User: viper.GetString("mysql.user"), Pass: viper.GetString("mysql.pass"), @@ -223,8 +220,6 @@ func preRun(cmd *cobra.Command, args []string) error { if err != nil { return commandFailure(err.Error()) } - } else if driverName == "mysql" { - return errors.New("mysql driver requires a mysql section in your config file") } cmdState, err = New(cmdConfig) diff --git a/templates_test/main_test/mysql_main.tpl b/templates_test/main_test/mysql_main.tpl index 5643e0ebd..a519297ad 100644 --- a/templates_test/main_test/mysql_main.tpl +++ b/templates_test/main_test/mysql_main.tpl @@ -2,16 +2,18 @@ type mysqlTester struct { dbConn *sql.DB } -dbMain = mysqlTester{} +func init() { + dbMain = &mysqlTester{} +} -func (m mysqlTester) setup() error { +func (m *mysqlTester) setup() error { return nil } -func (m mysqlTester) teardown() error { +func (m *mysqlTester) teardown() error { return nil } -func (m mysqlTester) conn() *sql.DB { +func (m *mysqlTester) conn() *sql.DB { return m.dbConn } diff --git a/templates_test/main_test/postgres_main.tpl b/templates_test/main_test/postgres_main.tpl index 2bc819eab..800e76839 100644 --- a/templates_test/main_test/postgres_main.tpl +++ b/templates_test/main_test/postgres_main.tpl @@ -11,16 +11,18 @@ type pgTester struct { testDBName string } -dbMain = pgTester{} +func init() { + dbMain = &pgTester{} +} // disableTriggers is used to disable foreign key constraints for every table. // If this is not used we cannot test inserts due to foreign key constraint errors. -func (p pgTester) disableTriggers() error { +func (p *pgTester) disableTriggers() error { var stmts []string - {{range .Tables}} + {{range .Tables -}} stmts = append(stmts, `ALTER TABLE {{.Name}} DISABLE TRIGGER ALL;`) - {{- end}} + {{end -}} if len(stmts) == 0 { return nil @@ -38,19 +40,18 @@ func (p pgTester) disableTriggers() error { } // teardown executes cleanup tasks when the tests finish running -func (p pgTester) teardown() error { - err := dropTestDB() - return err +func (p *pgTester) teardown() error { + return p.dropTestDB() } -func (p pgTester) conn() *sql.DB { +func (p *pgTester) conn() *sql.DB { return p.dbConn } // dropTestDB switches its connection to the template1 database temporarily // so that it can drop the test database without causing "in use" conflicts. // The template1 database should be present on all default postgres installations. -func (p pgTester) dropTestDB() error { +func (p *pgTester) dropTestDB() error { var err error if p.dbConn != nil { if err = p.dbConn.Close(); err != nil { @@ -58,12 +59,12 @@ func (p pgTester) dropTestDB() error { } } - p.dbConn, err = DBConnect(testCfg.Postgres.User, testCfg.Postgres.Pass, "template1", testCfg.Postgres.Host, testCfg.Postgres.Port, testCfg.Postgres.SSLMode) + p.dbConn, err = DBConnect(p.user, p.pass, "template1", p.host, p.port, p.sslmode) if err != nil { return err } - _, err = p.dbConn.Exec(fmt.Sprintf(`DROP DATABASE IF EXISTS %s;`, testCfg.Postgres.DBName)) + _, err = p.dbConn.Exec(fmt.Sprintf(`DROP DATABASE IF EXISTS %s;`, p.testDBName)) if err != nil { return err } @@ -81,7 +82,7 @@ func DBConnect(user, pass, dbname, host string, port int, sslmode string) (*sql. // setup dumps the database schema and imports it into a temporary randomly // generated test database so that tests can be run against it using the // generated sqlboiler ORM package. -func (p pgTester) setup() error { +func (p *pgTester) setup() error { var err error p.dbName = viper.GetString("postgres.dbname") @@ -89,11 +90,11 @@ func (p pgTester) setup() error { p.user = viper.GetString("postgres.user") p.pass = viper.GetString("postgres.pass") p.port = viper.GetInt("postgres.port") - p.sslmode = viper.GetString("postgres.dbname") + p.sslmode = viper.GetString("postgres.sslmode") // Create a randomized db name. - p.testDBName = getDBNameHash(p.dbname) + p.testDBName = getDBNameHash(p.dbName) - err = dropTestDB() + err = p.dropTestDB() if err != nil { fmt.Printf("%#v\n", err) return err @@ -112,7 +113,7 @@ func (p pgTester) setup() error { defer os.RemoveAll(passDir) // Write the postgres user password to a tmp file for pg_dump - pwBytes := []byte(fmt.Sprintf("%s:%d:%s:%s", p.host, p.port, p.dbname, p.user)) + pwBytes := []byte(fmt.Sprintf("%s:%d:%s:%s", p.host, p.port, p.dbName, p.user)) if len(p.pass) > 0 { pwBytes = []byte(fmt.Sprintf("%s:%s", pwBytes, p.pass)) diff --git a/templates_test/singleton/boil_main_test.tpl b/templates_test/singleton/boil_main_test.tpl index 77fe62dd2..5cd4046a0 100644 --- a/templates_test/singleton/boil_main_test.tpl +++ b/templates_test/singleton/boil_main_test.tpl @@ -17,27 +17,28 @@ func TestMain(m *testing.M) { } rand.Seed(time.Now().UnixNano()) + var err error // Load configuration err = initViper() if err != nil { - return errors.Wrap(err, "Unable to load config file") + fmt.Println("unable to load config file") + os.Exit(-2) } setConfigDefaults() - if err := validateConfig({{.DriverName}}); err != nil { + if err := validateConfig("{{.DriverName}}"); err != nil { fmt.Println("failed to validate config", err) - os.Exit(-2) + os.Exit(-3) } // Set DebugMode so we can see generated sql statements flag.Parse() boil.DebugMode = *flagDebugMode - var err error if err = dbMain.setup(); err != nil { fmt.Println("Unable to execute setup:", err) - os.Exit(-3) + os.Exit(-4) } var code int @@ -46,7 +47,7 @@ func TestMain(m *testing.M) { if err = dbMain.teardown(); err != nil { fmt.Println("Unable to execute teardown:", err) - os.Exit(-4) + os.Exit(-5) } os.Exit(code) @@ -84,8 +85,8 @@ func initViper() error { return nil } -// setDefaults is only necessary because of bugs in viper, noted in main -func setDefaults() { +// setConfigDefaults is only necessary because of bugs in viper, noted in main +func setConfigDefaults() { if viper.GetString("postgres.sslmode") == "" { viper.Set("postgres.sslmode", "require") } @@ -101,35 +102,25 @@ func setDefaults() { } func validateConfig(driverName string) error { - if viper.IsSet("postgres.dbname") { - err = vala.BeginValidation().Validate( + if driverName == "postgres" { + return vala.BeginValidation().Validate( vala.StringNotEmpty(viper.GetString("postgres.user"), "postgres.user"), vala.StringNotEmpty(viper.GetString("postgres.host"), "postgres.host"), vala.Not(vala.Equals(viper.GetInt("postgres.port"), 0, "postgres.port")), vala.StringNotEmpty(viper.GetString("postgres.dbname"), "postgres.dbname"), vala.StringNotEmpty(viper.GetString("postgres.sslmode"), "postgres.sslmode"), ).Check() - - if err != nil { - return err - } - } else if driverName == "postgres" { - return errors.New("postgres driver requires a postgres section in your config file") } - if viper.IsSet("mysql.dbname") { - err = vala.BeginValidation().Validate( + if driverName == "mysql" { + return vala.BeginValidation().Validate( vala.StringNotEmpty(viper.GetString("mysql.user"), "mysql.user"), vala.StringNotEmpty(viper.GetString("mysql.host"), "mysql.host"), vala.Not(vala.Equals(viper.GetInt("mysql.port"), 0, "mysql.port")), vala.StringNotEmpty(viper.GetString("mysql.dbname"), "mysql.dbname"), vala.StringNotEmpty(viper.GetString("mysql.sslmode"), "mysql.sslmode"), ).Check() - - if err != nil { - return err - } - } else if driverName == "mysql" { - return errors.New("mysql driver requires a mysql section in your config file") } + + return errors.New("not a valid driver name") } From 8392a4ba2ab0be8fb3188c1edc838255174609c9 Mon Sep 17 00:00:00 2001 From: Aaron L Date: Sun, 11 Sep 2016 22:30:25 -0700 Subject: [PATCH 32/64] Fix quoting throughout templates - Use "" style strings for all templates - Attach a Quote and SchemaTable that understand escaped quotes so we can use "" style strings without repercussion. - Make SchemaTable use escaped quotes - Remove schemaTable from the templates in favor of .SchemaTable --- sqlboiler.go | 5 +++ strmangle/strmangle.go | 20 ++++++--- strmangle/strmangle_test.go | 16 +++++++- templates.go | 43 +++++++++++++++----- templates/04_relationship_to_one.tpl | 4 +- templates/05_relationship_to_many.tpl | 11 ++--- templates/06_relationship_to_one_eager.tpl | 8 ++-- templates/07_relationship_to_many_eager.tpl | 6 ++- templates/09_relationship_to_many_setops.tpl | 8 ++-- templates/10_all.tpl | 2 +- templates/11_find.tpl | 2 +- templates/12_insert.tpl | 7 ++-- templates/13_update.tpl | 7 ++-- templates/15_delete.tpl | 7 ++-- templates/16_reload.tpl | 3 +- templates/17_exists.tpl | 3 +- templates_test/relationship_to_many.tpl | 4 +- 17 files changed, 107 insertions(+), 49 deletions(-) diff --git a/sqlboiler.go b/sqlboiler.go index 3a4483c98..ae6b6d8fa 100644 --- a/sqlboiler.go +++ b/sqlboiler.go @@ -16,6 +16,7 @@ import ( "github.com/vattle/sqlboiler/bdb" "github.com/vattle/sqlboiler/bdb/drivers" "github.com/vattle/sqlboiler/boil" + "github.com/vattle/sqlboiler/strmangle" ) const ( @@ -103,6 +104,8 @@ func (s *State) Run(includeTests bool) error { NoHooks: s.Config.NoHooks, NoAutoTimestamps: s.Config.NoAutoTimestamps, Dialect: s.Dialect, + LQ: strmangle.QuoteCharacter(s.Dialect.LQ), + RQ: strmangle.QuoteCharacter(s.Dialect.RQ), StringFuncs: templateStringMappers, } @@ -137,6 +140,8 @@ func (s *State) Run(includeTests bool) error { NoAutoTimestamps: s.Config.NoAutoTimestamps, Tags: s.Config.Tags, Dialect: s.Dialect, + LQ: strmangle.QuoteCharacter(s.Dialect.LQ), + RQ: strmangle.QuoteCharacter(s.Dialect.RQ), StringFuncs: templateStringMappers, } diff --git a/strmangle/strmangle.go b/strmangle/strmangle.go index 82ff09eea..ab0147f14 100644 --- a/strmangle/strmangle.go +++ b/strmangle/strmangle.go @@ -38,12 +38,12 @@ func init() { // using a database that supports real schemas, for example, // for Postgres: "schema_name"."table_name", versus // simply "table_name" for MySQL (because it does not support real schemas) -func SchemaTable(lq byte, rq byte, driver string, schema string, table string) string { +func SchemaTable(lq, rq string, driver string, schema string, table string) string { if driver == "postgres" && schema != "public" { - return fmt.Sprintf(`%c%s%c.%c%s%c`, lq, schema, rq, lq, table, rq) + return fmt.Sprintf(`%s%s%s.%s%s%s`, lq, schema, rq, lq, table, rq) } - return fmt.Sprintf(`%c%s%c`, lq, table, rq) + return fmt.Sprintf(`%s%s%s`, lq, table, rq) } // IdentQuote attempts to quote simple identifiers in SQL tatements @@ -118,6 +118,16 @@ func Identifier(in int) string { return cols.String() } +// QuoteCharacter returns a string that allows the quote character +// to be embedded into a Go string that uses double quotes: +func QuoteCharacter(q byte) string { + if q == '"' { + return `\"` + } + + return string(q) +} + // Plural converts singular words to plural words (eg: person to people) func Plural(name string) string { buf := GetBuffer() @@ -433,7 +443,7 @@ func SetParamNames(columns []string) string { // WhereClause returns the where clause using start as the $ flag index // For example, if start was 2 output would be: "colthing=$2 AND colstuff=$3" -func WhereClause(start int, cols []string) string { +func WhereClause(lq, rq string, start int, cols []string) string { if start == 0 { panic("0 is not a valid start number for whereClause") } @@ -442,7 +452,7 @@ func WhereClause(start int, cols []string) string { defer PutBuffer(buf) for i, c := range cols { - buf.WriteString(fmt.Sprintf(`"%s"=$%d`, c, start+i)) + buf.WriteString(fmt.Sprintf(`%s%s%s=$%d`, lq, c, rq, start+i)) if i < len(cols)-1 { buf.WriteString(" AND ") } diff --git a/strmangle/strmangle_test.go b/strmangle/strmangle_test.go index 5311ac906..f44ebfa2e 100644 --- a/strmangle/strmangle_test.go +++ b/strmangle/strmangle_test.go @@ -69,6 +69,20 @@ func TestIdentifier(t *testing.T) { } } +func TestQuoteCharacter(t *testing.T) { + t.Parallel() + + if QuoteCharacter('[') != "[" { + t.Error("want just the normal quote character") + } + if QuoteCharacter('`') != "`" { + t.Error("want just the normal quote character") + } + if QuoteCharacter('"') != `\"` { + t.Error("want an escaped character") + } +} + func TestPlaceholders(t *testing.T) { t.Parallel() @@ -317,7 +331,7 @@ func TestWhereClause(t *testing.T) { } for i, test := range tests { - r := WhereClause(test.Start, test.Cols) + r := WhereClause(`"`, `"`, test.Start, test.Cols) if r != test.Should { t.Errorf("(%d) want: %s, got: %s\nTest: %#v", i, test.Should, r, test) } diff --git a/templates.go b/templates.go index 2ff4dc845..80fdfb232 100644 --- a/templates.go +++ b/templates.go @@ -14,17 +14,39 @@ import ( // templateData for sqlboiler templates type templateData struct { - Tables []bdb.Table - Table bdb.Table - Schema string - DriverName string - UseLastInsertID bool - PkgName string + Tables []bdb.Table + Table bdb.Table + + // Controls what names are output + PkgName string + Schema string + + // Controls which code is output (mysql vs postgres ...) + DriverName string + UseLastInsertID bool + + // Turn off auto timestamps or hook generation NoHooks bool NoAutoTimestamps bool - Tags []string - StringFuncs map[string]func(string) string - Dialect boil.Dialect + + // Tags control which + Tags []string + + // StringFuncs are usable in templates with stringMap + StringFuncs map[string]func(string) string + + // Dialect controls quoting + Dialect boil.Dialect + LQ string + RQ string +} + +func (t templateData) Quotes(s string) string { + return fmt.Sprintf("%s%s%s", t.LQ, s, t.RQ) +} + +func (t templateData) SchemaTable(table string) string { + return strmangle.SchemaTable(t.LQ, t.RQ, t.DriverName, t.Schema, table) } type templateList struct { @@ -115,7 +137,7 @@ var templateStringMappers = map[string]func(string) string{ // add a function pointer here. var templateFunctions = template.FuncMap{ // String ops - "quoteWrap": func(a string) string { return fmt.Sprintf(`"%s"`, a) }, + "quoteWrap": func(s string) string { return fmt.Sprintf(`"%s"`, s) }, "id": strmangle.Identifier, // Pluralization @@ -143,7 +165,6 @@ var templateFunctions = template.FuncMap{ // Database related mangling "whereClause": strmangle.WhereClause, - "schemaTable": strmangle.SchemaTable, // Text helpers "textsFromForeignKey": textsFromForeignKey, diff --git a/templates/04_relationship_to_one.tpl b/templates/04_relationship_to_one.tpl index 8d9f10c74..06dac02cd 100644 --- a/templates/04_relationship_to_one.tpl +++ b/templates/04_relationship_to_one.tpl @@ -1,5 +1,5 @@ {{- define "relationship_to_one_helper" -}} - {{- $tmplData := .Dot -}}{{/* .Dot holds the root templateData struct, passed in through preserveDot */}} + {{- $dot := .Dot -}}{{/* .Dot holds the root templateData struct, passed in through preserveDot */}} {{- with .Rel -}}{{/* Rel holds the text helper data, passed in through preserveDot */}} {{- $varNameSingular := .ForeignKey.ForeignTable | singular | camelCase -}} // {{.Function.Name}}G pointed to by the foreign key. @@ -16,7 +16,7 @@ func ({{.Function.Receiver}} *{{.LocalTable.NameGo}}) {{.Function.Name}}(exec bo queryMods = append(queryMods, mods...) query := {{.ForeignTable.NamePluralGo}}(exec, queryMods...) - boil.SetFrom(query.Query, `{{schemaTable $tmplData.Dialect.LQ $tmplData.Dialect.RQ $tmplData.DriverName $tmplData.Schema .ForeignTable.Name}}`) + boil.SetFrom(query.Query, "{{.ForeignTable.Name | $dot.SchemaTable}}") return query } diff --git a/templates/05_relationship_to_many.tpl b/templates/05_relationship_to_many.tpl index e52dc1088..9d8184711 100644 --- a/templates/05_relationship_to_many.tpl +++ b/templates/05_relationship_to_many.tpl @@ -12,6 +12,7 @@ {{- else -}} {{- /* Begin execution of template for many-to-many relationship. */ -}} {{- $rel := textsFromRelationship $dot.Tables $table . -}} + {{- $schemaForeignTable := .ForeignTable | $dot.SchemaTable -}} // {{$rel.Function.Name}}G retrieves all the {{$rel.LocalTable.NameSingular}}'s {{$rel.ForeignTable.NameHumanReadable}} {{- if not (eq $rel.Function.Name $rel.ForeignTable.NamePluralGo)}} via {{.ForeignColumn}} column{{- end}}. func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) {{$rel.Function.Name}}G(mods ...qm.QueryMod) {{$varNameSingular}}Query { @@ -22,7 +23,7 @@ func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) {{$rel.Function.Na {{- if not (eq $rel.Function.Name $rel.ForeignTable.NamePluralGo)}} via {{.ForeignColumn}} column{{- end}}. func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) {{$rel.Function.Name}}(exec boil.Executor, mods ...qm.QueryMod) {{$varNameSingular}}Query { queryMods := []qm.QueryMod{ - qm.Select(`"{{id 0}}".*`), + qm.Select("{{id 0 | $dot.Quotes}}.*"), } if len(mods) != 0 { @@ -31,17 +32,17 @@ func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) {{$rel.Function.Na {{if .ToJoinTable -}} queryMods = append(queryMods, - qm.InnerJoin(`{{schemaTable $dot.Dialect.LQ $dot.Dialect.RQ $dot.DriverName $dot.Schema .JoinTable}} as "{{id 1}}" on "{{id 0}}"."{{.ForeignColumn}}" = "{{id 1}}"."{{.JoinForeignColumn}}"`), - qm.Where(`"{{id 1}}"."{{.JoinLocalColumn}}"=$1`, {{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}), + qm.InnerJoin("{{.JoinTable | $dot.SchemaTable}} as {{id 1 | $dot.Quotes}} on {{id 0 | $dot.Quotes}}.{{.ForeignColumn | $dot.Quotes}} = {{id 1 | $dot.Quotes}}.{{.JoinForeignColumn | $dot.Quotes}}"), + qm.Where("{{id 1 | $dot.Quotes}}.{{.JoinLocalColumn | $dot.Quotes}}=$1", {{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}), ) {{else -}} queryMods = append(queryMods, - qm.Where(`"{{id 0}}"."{{.ForeignColumn}}"=$1`, {{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}), + qm.Where("{{id 0 | $dot.Quotes}}.{{.ForeignColumn | $dot.Quotes}}=$1", {{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}), ) {{end}} query := {{$rel.ForeignTable.NamePluralGo}}(exec, queryMods...) - boil.SetFrom(query.Query, `{{schemaTable $dot.Dialect.LQ $dot.Dialect.RQ $dot.DriverName $dot.Schema .ForeignTable}} as "{{id 0}}"`) + boil.SetFrom(query.Query, "{{$schemaForeignTable}} as {{id 0 | $dot.Quotes}}") return query } diff --git a/templates/06_relationship_to_one_eager.tpl b/templates/06_relationship_to_one_eager.tpl index f8c55eed8..cf851a572 100644 --- a/templates/06_relationship_to_one_eager.tpl +++ b/templates/06_relationship_to_one_eager.tpl @@ -1,6 +1,6 @@ {{- define "relationship_to_one_eager_helper" -}} - {{- $tmplData := .Dot -}}{{/* .Dot holds the root templateData struct, passed in through preserveDot */}} - {{- $varNameSingular := $tmplData.Table.Name | singular | camelCase -}} + {{- $dot := .Dot -}}{{/* .Dot holds the root templateData struct, passed in through preserveDot */}} + {{- $varNameSingular := $dot.Table.Name | singular | camelCase -}} {{- with .Rel -}} {{- $arg := printf "maybe%s" .LocalTable.NameGo -}} {{- $slice := printf "%sSlice" .LocalTable.NameGo -}} @@ -28,7 +28,7 @@ func ({{$varNameSingular}}L) Load{{.Function.Name}}(e boil.Executor, singular bo } query := fmt.Sprintf( - `select * from {{schemaTable $tmplData.Dialect.LQ $tmplData.Dialect.RQ $tmplData.DriverName $tmplData.Schema .ForeignKey.ForeignTable}} where "{{.ForeignKey.ForeignColumn}}" in (%s)`, + "select * from {{.ForeignKey.ForeignTable | $dot.SchemaTable}} where {{.ForeignKey.ForeignColumn | $dot.Quotes}} in (%s)", strmangle.Placeholders(dialect.IndexPlaceholders, count, 1, 1), ) @@ -47,7 +47,7 @@ func ({{$varNameSingular}}L) Load{{.Function.Name}}(e boil.Executor, singular bo return errors.Wrap(err, "failed to bind eager loaded slice {{.ForeignTable.NameGo}}") } - {{if not $tmplData.NoHooks -}} + {{if not $dot.NoHooks -}} if len({{.ForeignTable.Name | singular | camelCase}}AfterSelectHooks) != 0 { for _, obj := range resultSlice { if err := obj.doAfterSelectHooks(e); err != nil { diff --git a/templates/07_relationship_to_many_eager.tpl b/templates/07_relationship_to_many_eager.tpl index a79d75f63..21a6f45d0 100644 --- a/templates/07_relationship_to_many_eager.tpl +++ b/templates/07_relationship_to_many_eager.tpl @@ -13,6 +13,7 @@ {{- $txt := textsFromRelationship $dot.Tables $dot.Table . -}} {{- $arg := printf "maybe%s" $txt.LocalTable.NameGo -}} {{- $slice := printf "%sSlice" $txt.LocalTable.NameGo -}} + {{- $schemaForeignTable := .ForeignTable | $dot.SchemaTable -}} // Load{{$txt.Function.Name}} allows an eager lookup of values, cached into the // loaded structs of the objects. func ({{$varNameSingular}}L) Load{{$txt.Function.Name}}(e boil.Executor, singular bool, {{$arg}} interface{}) error { @@ -37,13 +38,14 @@ func ({{$varNameSingular}}L) Load{{$txt.Function.Name}}(e boil.Executor, singula } {{if .ToJoinTable -}} + {{- $schemaJoinTable := .JoinTable | $dot.SchemaTable -}} query := fmt.Sprintf( - `select "{{id 0}}".*, "{{id 1}}"."{{.JoinLocalColumn}}" from {{schemaTable $dot.Dialect.LQ $dot.Dialect.RQ $dot.DriverName $dot.Schema .ForeignTable}} as "{{id 0}}" inner join {{schemaTable $dot.Dialect.LQ $dot.Dialect.RQ $dot.DriverName $dot.Schema .JoinTable}} as "{{id 1}}" on "{{id 0}}"."{{.ForeignColumn}}" = "{{id 1}}"."{{.JoinForeignColumn}}" where "{{id 1}}"."{{.JoinLocalColumn}}" in (%s)`, + "select {{id 0 | $dot.Quotes}}.*, {{id 1 | $dot.Quotes}}.{{.JoinLocalColumn | $dot.Quotes}} from {{$schemaForeignTable}} as {{id 0 | $dot.Quotes}} inner join {{$schemaJoinTable}} as {{id 1 | $dot.Quotes}} on {{id 0 | $dot.Quotes}}.{{.ForeignColumn | $dot.Quotes}} = {{id 1 | $dot.Quotes}}.{{.JoinForeignColumn | $dot.Quotes}} where {{id 1 | $dot.Quotes}}.{{.JoinLocalColumn | $dot.Quotes}} in (%s)", strmangle.Placeholders(dialect.IndexPlaceholders, count, 1, 1), ) {{else -}} query := fmt.Sprintf( - `select * from {{schemaTable $dot.Dialect.LQ $dot.Dialect.RQ $dot.DriverName $dot.Schema .ForeignTable}} where "{{.ForeignColumn}}" in (%s)`, + "select * from {{$schemaForeignTable}} where {{.ForeignColumn | $dot.Quotes}} in (%s)", strmangle.Placeholders(dialect.IndexPlaceholders, count, 1, 1), ) {{end -}} diff --git a/templates/09_relationship_to_many_setops.tpl b/templates/09_relationship_to_many_setops.tpl index 33e5ce468..59842f66b 100644 --- a/templates/09_relationship_to_many_setops.tpl +++ b/templates/09_relationship_to_many_setops.tpl @@ -39,7 +39,7 @@ func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) Add{{$rel.Function {{if .ToJoinTable -}} for _, rel := range related { - query := `insert into {{schemaTable $dot.Dialect.LQ $dot.Dialect.RQ $dot.DriverName $dot.Schema .JoinTable}} ({{.JoinLocalColumn}}, {{.JoinForeignColumn}}) values ($1, $2)` + query := "insert into {{.JoinTable | $dot.SchemaTable}} ({{.JoinLocalColumn | $dot.Quotes}}, {{.JoinForeignColumn | $dot.Quotes}}) values ($1, $2)" values := []interface{}{{"{"}}{{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}, rel.{{$rel.ForeignTable.ColumnNameGo}}} if boil.DebugMode { @@ -96,10 +96,10 @@ func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) Add{{$rel.Function // Sets related.R.{{$rel.Function.ForeignName}}'s {{$rel.Function.Name}} accordingly. func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) Set{{$rel.Function.Name}}(exec boil.Executor, insert bool, related ...*{{$rel.ForeignTable.NameGo}}) error { {{if .ToJoinTable -}} - query := `delete from {{schemaTable $dot.Dialect.LQ $dot.Dialect.RQ $dot.DriverName $dot.Schema .JoinTable}} where "{{.JoinLocalColumn}}" = $1` + query := "delete from {{.JoinTable | $dot.SchemaTable}} where {{.JoinLocalColumn | $dot.Quotes}} = $1" values := []interface{}{{"{"}}{{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}} {{else -}} - query := `update {{schemaTable $dot.Dialect.LQ $dot.Dialect.RQ $dot.DriverName $dot.Schema .ForeignTable}} set "{{.ForeignColumn}}" = null where "{{.ForeignColumn}}" = $1` + query := "update {{.ForeignTable | $dot.SchemaTable}} set {{.ForeignColumn | $dot.Quotes}} = null where {{.ForeignColumn | $dot.Quotes}} = $1" values := []interface{}{{"{"}}{{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}} {{end -}} if boil.DebugMode { @@ -140,7 +140,7 @@ func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) Remove{{$rel.Funct var err error {{if .ToJoinTable -}} query := fmt.Sprintf( - `delete from {{schemaTable $dot.Dialect.LQ $dot.Dialect.RQ $dot.DriverName $dot.Schema .JoinTable}} where "{{.JoinLocalColumn}}" = $1 and "{{.JoinForeignColumn}}" in (%s)`, + "delete from {{.JoinTable | $dot.SchemaTable}} where {{.JoinLocalColumn | $dot.Quotes}} = $1 and {{.JoinForeignColumn | $dot.Quotes}} in (%s)", strmangle.Placeholders(dialect.IndexPlaceholders, len(related), 1, 1), ) values := []interface{}{{"{"}}{{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}} diff --git a/templates/10_all.tpl b/templates/10_all.tpl index 62ddbeb7b..5a2df1459 100644 --- a/templates/10_all.tpl +++ b/templates/10_all.tpl @@ -8,6 +8,6 @@ func {{$tableNamePlural}}G(mods ...qm.QueryMod) {{$varNameSingular}}Query { // {{$tableNamePlural}} retrieves all the records using an executor. func {{$tableNamePlural}}(exec boil.Executor, mods ...qm.QueryMod) {{$varNameSingular}}Query { - mods = append(mods, qm.From(`{{schemaTable .Dialect.LQ .Dialect.RQ .DriverName .Schema .Table.Name}}`)) + mods = append(mods, qm.From("{{.Table.Name | .SchemaTable}}")) return {{$varNameSingular}}Query{NewQuery(exec, mods...)} } diff --git a/templates/11_find.tpl b/templates/11_find.tpl index dfdc5b2a7..f7ad81806 100644 --- a/templates/11_find.tpl +++ b/templates/11_find.tpl @@ -29,7 +29,7 @@ func Find{{$tableNameSingular}}(exec boil.Executor, {{$pkArgs}}, selectCols ...s sel = strings.Join(strmangle.IdentQuoteSlice(dialect.LQ, dialect.RQ, selectCols), ",") } query := fmt.Sprintf( - `select %s from {{schemaTable .Dialect.LQ .Dialect.RQ .DriverName .Schema .Table.Name}} where {{whereClause 1 .Table.PKey.Columns}}`, sel, + "select %s from {{.Table.Name | .SchemaTable}} where {{whereClause .LQ .RQ 1 .Table.PKey.Columns}}", sel, ) q := boil.SQL(exec, query, {{$pkNames | join ", "}}) diff --git a/templates/12_insert.tpl b/templates/12_insert.tpl index dba01c4d3..31f7d636b 100644 --- a/templates/12_insert.tpl +++ b/templates/12_insert.tpl @@ -1,5 +1,6 @@ {{- $tableNameSingular := .Table.Name | singular | titleCase -}} {{- $varNameSingular := .Table.Name | singular | camelCase -}} +{{- $schemaTable := .Table.Name | .SchemaTable -}} // InsertG a single record. See Insert for whitelist behavior description. func (o *{{$tableNameSingular}}) InsertG(whitelist ... string) error { return o.Insert(boil.GetDB(), whitelist...) @@ -64,13 +65,13 @@ func (o *{{$tableNameSingular}}) Insert(exec boil.Executor, whitelist ... string if err != nil { return err } - cache.query = fmt.Sprintf(`INSERT INTO {{schemaTable .Dialect.LQ .Dialect.RQ .DriverName .Schema .Table.Name}} ("%s") VALUES (%s)`, strings.Join(wl, `","`), strmangle.Placeholders(dialect.IndexPlaceholders, len(wl), 1, 1)) + cache.query = fmt.Sprintf("INSERT INTO {{$schemaTable}} ({{.LQ}}%s{{.RQ}}) VALUES (%s)", strings.Join(wl, "{{.LQ}},{{.RQ}}"), strmangle.Placeholders(dialect.IndexPlaceholders, len(wl), 1, 1)) if len(cache.retMapping) != 0 { {{if .UseLastInsertID -}} - cache.retQuery = fmt.Sprintf(`SELECT %s FROM {{schemaTable .Dialect.LQ .Dialect.RQ .DriverName .Schema .Table.Name}} WHERE %s`, strings.Join(returnColumns, `","`), strmangle.WhereClause(1, {{$varNameSingular}}PrimaryKeyColumns)) + cache.retQuery = fmt.Sprintf("SELECT %s FROM {{$schemaTable}} WHERE %s", strings.Join(returnColumns, "{{.LQ}},{{.RQ}}"), strmangle.WhereClause("{{.LQ}}", "{{.RQ}}", 1, {{$varNameSingular}}PrimaryKeyColumns)) {{else -}} - cache.query += fmt.Sprintf(` RETURNING %s`, strings.Join(returnColumns, ",")) + cache.query += fmt.Sprintf(" RETURNING {{.LQ}}%s{{.RQ}}", strings.Join(returnColumns, "{{.LQ}},{{.RQ}}")) {{end -}} } } diff --git a/templates/13_update.tpl b/templates/13_update.tpl index 1aea39ad9..b3ade49a6 100644 --- a/templates/13_update.tpl +++ b/templates/13_update.tpl @@ -3,6 +3,7 @@ {{- $colDefs := sqlColDefinitions .Table.Columns .Table.PKey.Columns -}} {{- $pkNames := $colDefs.Names | stringMap .StringFuncs.camelCase -}} {{- $pkArgs := joinSlices " " $pkNames $colDefs.Types | join ", " -}} +{{- $schemaTable := .Table.Name | .SchemaTable -}} // UpdateG a single {{$tableNameSingular}} record. See Update for // whitelist behavior description. func (o *{{$tableNameSingular}}) UpdateG(whitelist ...string) error { @@ -52,7 +53,7 @@ func (o *{{$tableNameSingular}}) Update(exec boil.Executor, whitelist ... string if !cached { wl := strmangle.UpdateColumnSet({{$varNameSingular}}Columns, {{$varNameSingular}}PrimaryKeyColumns, whitelist) - cache.query = fmt.Sprintf(`UPDATE {{schemaTable .Dialect.LQ .Dialect.RQ .DriverName .Schema .Table.Name}} SET %s WHERE %s`, strmangle.SetParamNames(wl), strmangle.WhereClause(len(wl)+1, {{$varNameSingular}}PrimaryKeyColumns)) + cache.query = fmt.Sprintf("UPDATE {{$schemaTable}} SET %s WHERE %s", strmangle.SetParamNames(wl), strmangle.WhereClause("{{.LQ}}", "{{.RQ}}", len(wl)+1, {{$varNameSingular}}PrimaryKeyColumns)) cache.valueMapping, err = boil.BindMapping({{$varNameSingular}}Type, {{$varNameSingular}}Mapping, append(wl, {{$varNameSingular}}PrimaryKeyColumns...)) if err != nil { return err @@ -155,8 +156,8 @@ func (o {{$tableNameSingular}}Slice) UpdateAll(exec boil.Executor, cols M) error args = append(args, o.inPrimaryKeyArgs()...) sql := fmt.Sprintf( - `UPDATE {{schemaTable .Dialect.LQ .Dialect.RQ .DriverName .Schema .Table.Name}} SET (%s) = (%s) WHERE (%s) IN (%s)`, - strings.Join(colNames, ", "), + "UPDATE {{$schemaTable}} SET (%s) = (%s) WHERE (%s) IN (%s)", + strings.Join(colNames, ","), strmangle.Placeholders(dialect.IndexPlaceholders, len(colNames), 1, 1), strings.Join(strmangle.IdentQuoteSlice(dialect.LQ, dialect.RQ, {{$varNameSingular}}PrimaryKeyColumns), ","), strmangle.Placeholders(dialect.IndexPlaceholders, len(o) * len({{$varNameSingular}}PrimaryKeyColumns), len(colNames)+1, len({{$varNameSingular}}PrimaryKeyColumns)), diff --git a/templates/15_delete.tpl b/templates/15_delete.tpl index 367e74ac2..97343692f 100644 --- a/templates/15_delete.tpl +++ b/templates/15_delete.tpl @@ -1,5 +1,6 @@ {{- $tableNameSingular := .Table.Name | singular | titleCase -}} {{- $varNameSingular := .Table.Name | singular | camelCase -}} +{{- $schemaTable := .Table.Name | .SchemaTable -}} // DeleteP deletes a single {{$tableNameSingular}} record with an executor. // DeleteP will match against the primary key column to find the record to delete. // Panics on error. @@ -43,7 +44,7 @@ func (o *{{$tableNameSingular}}) Delete(exec boil.Executor) error { args := o.inPrimaryKeyArgs() - sql := `DELETE FROM {{schemaTable .Dialect.LQ .Dialect.RQ .DriverName .Schema .Table.Name}} WHERE {{whereClause 1 .Table.PKey.Columns}}` + sql := "DELETE FROM {{$schemaTable}} WHERE {{whereClause .LQ .RQ 1 .Table.PKey.Columns}}" if boil.DebugMode { fmt.Fprintln(boil.DebugWriter, sql) @@ -132,8 +133,8 @@ func (o {{$tableNameSingular}}Slice) DeleteAll(exec boil.Executor) error { args := o.inPrimaryKeyArgs() sql := fmt.Sprintf( - `DELETE FROM {{schemaTable .Dialect.LQ .Dialect.RQ .DriverName .Schema .Table.Name}} WHERE (%s) IN (%s)`, - strings.Join(strmangle.IdentQuoteSlice(dialect.LQ, dialect.RQ, {{$varNameSingular}}PrimaryKeyColumns), ","), + "DELETE FROM {{$schemaTable}} WHERE (%s) IN (%s)", + strings.Join(strmangle.IdentQuoteSlice(dialect.LQ, dialect.RQ, {{$varNameSingular}}PrimaryKeyColumns), "{{.LQ}},{{.RQ}}"), strmangle.Placeholders(dialect.IndexPlaceholders, len(o) * len({{$varNameSingular}}PrimaryKeyColumns), 1, len({{$varNameSingular}}PrimaryKeyColumns)), ) diff --git a/templates/16_reload.tpl b/templates/16_reload.tpl index ac00577cc..5277c9fcb 100644 --- a/templates/16_reload.tpl +++ b/templates/16_reload.tpl @@ -1,6 +1,7 @@ {{- $tableNameSingular := .Table.Name | singular | titleCase -}} {{- $varNameSingular := .Table.Name | singular | camelCase -}} {{- $varNamePlural := .Table.Name | plural | camelCase -}} +{{- $schemaTable := .Table.Name | .SchemaTable -}} // ReloadGP refetches the object from the database and panics on error. func (o *{{$tableNameSingular}}) ReloadGP() { if err := o.ReloadG(); err != nil { @@ -67,7 +68,7 @@ func (o *{{$tableNameSingular}}Slice) ReloadAll(exec boil.Executor) error { args := o.inPrimaryKeyArgs() sql := fmt.Sprintf( - `SELECT {{schemaTable .Dialect.LQ .Dialect.RQ .DriverName .Schema .Table.Name}}.* FROM {{schemaTable .Dialect.LQ .Dialect.RQ .DriverName .Schema .Table.Name}} WHERE (%s) IN (%s)`, + "SELECT {{$schemaTable}}.* FROM {{$schemaTable}} WHERE (%s) IN (%s)", strings.Join(strmangle.IdentQuoteSlice(dialect.LQ, dialect.RQ, {{$varNameSingular}}PrimaryKeyColumns), ","), strmangle.Placeholders(dialect.IndexPlaceholders, len(*o) * len({{$varNameSingular}}PrimaryKeyColumns), 1, len({{$varNameSingular}}PrimaryKeyColumns)), ) diff --git a/templates/17_exists.tpl b/templates/17_exists.tpl index 713a8bbaf..94f829b32 100644 --- a/templates/17_exists.tpl +++ b/templates/17_exists.tpl @@ -2,11 +2,12 @@ {{- $colDefs := sqlColDefinitions .Table.Columns .Table.PKey.Columns -}} {{- $pkNames := $colDefs.Names | stringMap .StringFuncs.camelCase -}} {{- $pkArgs := joinSlices " " $pkNames $colDefs.Types | join ", " -}} +{{- $schemaTable := .Table.Name | .SchemaTable -}} // {{$tableNameSingular}}Exists checks if the {{$tableNameSingular}} row exists. func {{$tableNameSingular}}Exists(exec boil.Executor, {{$pkArgs}}) (bool, error) { var exists bool - sql := `select exists(select 1 from {{schemaTable .Dialect.LQ .Dialect.RQ .DriverName .Schema .Table.Name}} where {{whereClause 1 .Table.PKey.Columns}} limit 1)` + sql := "select exists(select 1 from {{$schemaTable}} where {{whereClause .LQ .RQ 1 .Table.PKey.Columns}} limit 1)" if boil.DebugMode { fmt.Fprintln(boil.DebugWriter, sql) diff --git a/templates_test/relationship_to_many.tpl b/templates_test/relationship_to_many.tpl index e1c4ee0cb..a157f48d0 100644 --- a/templates_test/relationship_to_many.tpl +++ b/templates_test/relationship_to_many.tpl @@ -41,11 +41,11 @@ func test{{$rel.LocalTable.NameGo}}ToMany{{$rel.Function.Name}}(t *testing.T) { } {{if .ToJoinTable -}} - _, err = tx.Exec(`insert into {{schemaTable $dot.Dialect.LQ $dot.Dialect.RQ $dot.DriverName $dot.Schema .JoinTable}} ({{.JoinLocalColumn}}, {{.JoinForeignColumn}}) values ($1, $2)`, a.{{$rel.LocalTable.ColumnNameGo}}, b.{{$rel.ForeignTable.ColumnNameGo}}) + _, err = tx.Exec("insert into {{.JoinTable | $dot.SchemaTable}} ({{.JoinLocalColumn | $dot.Quotes}}, {{.JoinForeignColumn | $dot.Quotes}}) values ($1, $2)", a.{{$rel.LocalTable.ColumnNameGo}}, b.{{$rel.ForeignTable.ColumnNameGo}}) if err != nil { t.Fatal(err) } - _, err = tx.Exec(`insert into {{schemaTable $dot.Dialect.LQ $dot.Dialect.RQ $dot.DriverName $dot.Schema .JoinTable}} ({{.JoinLocalColumn}}, {{.JoinForeignColumn}}) values ($1, $2)`, a.{{$rel.LocalTable.ColumnNameGo}}, c.{{$rel.ForeignTable.ColumnNameGo}}) + _, err = tx.Exec("insert into {{.JoinTable | $dot.SchemaTable}} ({{.JoinLocalColumn | $dot.Quotes}}, {{.JoinForeignColumn | $dot.Quotes}}) values ($1, $2)", a.{{$rel.LocalTable.ColumnNameGo}}, c.{{$rel.ForeignTable.ColumnNameGo}}) if err != nil { t.Fatal(err) } From 37a05de380b80f2a12c9c1d822e359efd4b0bc98 Mon Sep 17 00:00:00 2001 From: Aaron L Date: Sun, 11 Sep 2016 22:36:12 -0700 Subject: [PATCH 33/64] Fix a mistake in the insert query --- templates/12_insert.tpl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/templates/12_insert.tpl b/templates/12_insert.tpl index 31f7d636b..a32087161 100644 --- a/templates/12_insert.tpl +++ b/templates/12_insert.tpl @@ -84,7 +84,7 @@ func (o *{{$tableNameSingular}}) Insert(exec boil.Executor, whitelist ... string fmt.Fprintln(boil.DebugWriter, vals) } - result, err := exec.Exec(ins, vals...) + result, err := exec.Exec(cache.query, vals...) if err != nil { return errors.Wrap(err, "{{.PkgName}}: unable to insert into {{.Table.Name}}") } From 1b5cea823f94fa072ea1a1eec7245aa3ce956796 Mon Sep 17 00:00:00 2001 From: Aaron L Date: Sun, 11 Sep 2016 23:49:47 -0700 Subject: [PATCH 34/64] Get mysql tests running. - Needs disabling of Foreign Key constraints + upsert deletion to have a chance of working. --- imports.go | 18 +-- templates_test/main_test/mysql_main.tpl | 144 +++++++++++++++++++- templates_test/main_test/postgres_main.tpl | 4 +- templates_test/singleton/boil_main_test.tpl | 9 +- 4 files changed, 157 insertions(+), 18 deletions(-) diff --git a/imports.go b/imports.go index 938e29480..ec3be4c37 100644 --- a/imports.go +++ b/imports.go @@ -243,24 +243,20 @@ var defaultTestMainImports = map[string]imports{ }, "mysql": { standard: importList{ - `"testing"`, - `"os"`, - `"os/exec"`, - `"flag"`, - `"fmt"`, - `"io/ioutil"`, `"bytes"`, `"database/sql"`, - `"path/filepath"`, - `"time"`, - `"math/rand"`, + `"fmt"`, + `"io"`, + `"io/ioutil"`, + `"os"`, + `"os/exec"`, + `"strings"`, }, thirdParty: importList{ `"github.com/pkg/errors"`, `"github.com/spf13/viper"`, - `"github.com/vattle/sqlboiler/boil"`, `"github.com/vattle/sqlboiler/bdb/drivers"`, - `_ "github.com/go-mysql-driver/mysql"`, + `_ "github.com/go-sql-driver/mysql"`, }, }, } diff --git a/templates_test/main_test/mysql_main.tpl b/templates_test/main_test/mysql_main.tpl index a519297ad..c3c7c4d60 100644 --- a/templates_test/main_test/mysql_main.tpl +++ b/templates_test/main_test/mysql_main.tpl @@ -1,19 +1,157 @@ type mysqlTester struct { dbConn *sql.DB + + dbName string + host string + user string + pass string + sslmode string + port int + + optionFile string + + testDBName string } func init() { - dbMain = &mysqlTester{} + dbMain = &mysqlTester{} } func (m *mysqlTester) setup() error { + var err error + + m.dbName = viper.GetString("mysql.dbname") + m.host = viper.GetString("mysql.host") + m.user = viper.GetString("mysql.user") + m.pass = viper.GetString("mysql.pass") + m.port = viper.GetInt("mysql.port") + m.sslmode = viper.GetString("mysql.sslmode") + // Create a randomized db name. + m.testDBName = getDBNameHash(m.dbName) + + if err = m.makeOptionFile(); err != nil { + return errors.Wrap(err, "couldn't make option file") + } + + if err = m.dropTestDB(); err != nil { + return err + } + if err = m.createTestDB(); err != nil { + return err + } + + dumpCmd := exec.Command("mysqldump", m.defaultsFile(), m.dbName) + createCmd := exec.Command("mysql", m.defaultsFile(), "--database", m.testDBName) + + r, w := io.Pipe() + dumpCmd.Stdout = w + createCmd.Stdin = io.TeeReader(r, os.Stdout) + + if err = dumpCmd.Start(); err != nil { + return errors.Wrap(err, "failed to start mysqldump command") + } + if err = createCmd.Start(); err != nil { + return errors.Wrap(err, "failed to start mysql command") + } + + if err = dumpCmd.Wait(); err != nil { + fmt.Println(err) + return errors.Wrap(err, "failed to wait for mysqldump command") + } + + w.Close() // After dumpCmd is done, close the write end of the pipe + + if err = createCmd.Wait(); err != nil { + fmt.Println(err) + return errors.Wrap(err, "failed to wait for mysql command") + } + return nil } +func (m *mysqlTester) defaultsFile() string { + return fmt.Sprintf("--defaults-file=%s", m.optionFile) +} + +func (m *mysqlTester) makeOptionFile() error { + tmp, err := ioutil.TempFile("", "optionfile") + if err != nil { + return errors.Wrap(err, "failed to create option file") + } + + fmt.Fprintln(tmp, "[client]") + fmt.Fprintf(tmp, "host=%s\n", m.host) + fmt.Fprintf(tmp, "port=%d\n", m.port) + fmt.Fprintf(tmp, "user=%s\n", m.user) + fmt.Fprintf(tmp, "password=%s\n", m.pass) + // BUG: SSL Mode for whatever reason is backwards in the mysql driver + // taking options like true or false, but here taking options like + // required/disabled. Until this gets sorted, ignore this. + //fmt.Fprintf("ssl-mode=%s\n", m.password) + + fmt.Fprintln(tmp, "[mysqldump]") + fmt.Fprintf(tmp, "host=%s\n", m.host) + fmt.Fprintf(tmp, "port=%d\n", m.port) + fmt.Fprintf(tmp, "user=%s\n", m.user) + fmt.Fprintf(tmp, "password=%s\n", m.pass) + + m.optionFile = tmp.Name() + + return tmp.Close() +} + +func (m *mysqlTester) createTestDB() error { + sql := fmt.Sprintf("create database %s;", m.testDBName) + return m.runCmd(sql, "mysql") +} + +func (m *mysqlTester) dropTestDB() error { + sql := fmt.Sprintf("drop database if exists %s;", m.testDBName) + return m.runCmd(sql, "mysql") +} + func (m *mysqlTester) teardown() error { + if err := m.dropTestDB(); err != nil { + return err + } + + if m.dbConn != nil { + m.dbConn.Close() + } + + return os.Remove(m.optionFile) +} + +func (m *mysqlTester) runCmd(stdin, command string, args ...string) error { + args = append([]string{m.defaultsFile()}, args...) + + cmd := exec.Command(command, args...) + cmd.Stdin = strings.NewReader(stdin) + + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + cmd.Stdout = stdout + cmd.Stderr = stderr + if err := cmd.Run(); err != nil { + fmt.Println("failed running:", command, args) + fmt.Println(stdout.String()) + fmt.Println(stderr.String()) + return err + } + return nil } -func (m *mysqlTester) conn() *sql.DB { - return m.dbConn +func (m *mysqlTester) conn() (*sql.DB, error) { + if m.dbConn != nil { + return m.dbConn, nil + } + + var err error + m.dbConn, err = sql.Open("mysql", drivers.MySQLBuildQueryString(m.user, m.pass, m.testDBName, m.host, m.port, m.sslmode)) + if err != nil { + return nil, err + } + + return m.dbConn, nil } diff --git a/templates_test/main_test/postgres_main.tpl b/templates_test/main_test/postgres_main.tpl index 800e76839..7d19f1ba6 100644 --- a/templates_test/main_test/postgres_main.tpl +++ b/templates_test/main_test/postgres_main.tpl @@ -44,8 +44,8 @@ func (p *pgTester) teardown() error { return p.dropTestDB() } -func (p *pgTester) conn() *sql.DB { - return p.dbConn +func (p *pgTester) conn() (*sql.DB, error) { + return p.dbConn, nil } // dropTestDB switches its connection to the template1 database temporarily diff --git a/templates_test/singleton/boil_main_test.tpl b/templates_test/singleton/boil_main_test.tpl index 5cd4046a0..0014a1e4b 100644 --- a/templates_test/singleton/boil_main_test.tpl +++ b/templates_test/singleton/boil_main_test.tpl @@ -6,7 +6,7 @@ var ( type tester interface { setup() error - conn() *sql.DB + conn() (*sql.DB, error) teardown() error } @@ -41,8 +41,13 @@ func TestMain(m *testing.M) { os.Exit(-4) } + conn, err := dbMain.conn() + if err != nil { + fmt.Println("failed to get connection:", err) + } + var code int - boil.SetDB(dbMain.conn()) + boil.SetDB(conn) code = m.Run() if err = dbMain.teardown(); err != nil { From f1f311b70f8720abe511962b690a636ebf470467 Mon Sep 17 00:00:00 2001 From: Aaron L Date: Mon, 12 Sep 2016 22:43:09 -0700 Subject: [PATCH 35/64] Add DB name stuff to randomize package --- boil/randomize/random.go | 121 ++++++++++++++++++++++++++++++++++ boil/randomize/random_test.go | 19 ++++++ boil/randomize/randomize.go | 84 ----------------------- 3 files changed, 140 insertions(+), 84 deletions(-) create mode 100644 boil/randomize/random.go create mode 100644 boil/randomize/random_test.go diff --git a/boil/randomize/random.go b/boil/randomize/random.go new file mode 100644 index 000000000..27395c6e1 --- /dev/null +++ b/boil/randomize/random.go @@ -0,0 +1,121 @@ +package randomize + +import ( + "crypto/md5" + "fmt" + "math/rand" +) + +const alphabetAll = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" +const alphabetLowerAlpha = "abcdefghijklmnopqrstuvwxyz" + +func randStr(s *Seed, ln int) string { + str := make([]byte, ln) + for i := 0; i < ln; i++ { + str[i] = byte(alphabetAll[s.nextInt()%len(alphabetAll)]) + } + + return string(str) +} + +func randByteSlice(s *Seed, ln int) []byte { + str := make([]byte, ln) + for i := 0; i < ln; i++ { + str[i] = byte(s.nextInt() % 256) + } + + return str +} + +func randPoint() string { + a := rand.Intn(100) + b := a + 1 + return fmt.Sprintf("(%d,%d)", a, b) +} + +func randBox() string { + a := rand.Intn(100) + b := a + 1 + c := a + 2 + d := a + 3 + return fmt.Sprintf("(%d,%d),(%d,%d)", a, b, c, d) +} + +func randCircle() string { + a, b, c := rand.Intn(100), rand.Intn(100), rand.Intn(100) + return fmt.Sprintf("((%d,%d),%d)", a, b, c) +} + +func randNetAddr() string { + return fmt.Sprintf( + "%d.%d.%d.%d", + rand.Intn(254)+1, + rand.Intn(254)+1, + rand.Intn(254)+1, + rand.Intn(254)+1, + ) +} + +func randMacAddr() string { + buf := make([]byte, 6) + _, err := rand.Read(buf) + if err != nil { + panic(err) + } + + // Set the local bit + buf[0] |= 2 + return fmt.Sprintf( + "%02x:%02x:%02x:%02x:%02x:%02x", + buf[0], buf[1], buf[2], buf[3], buf[4], buf[5], + ) +} + +func randLsn() string { + a := rand.Int63n(9000000) + b := rand.Int63n(9000000) + return fmt.Sprintf("%d/%d", a, b) +} + +func randTxID() string { + // Order of integers is relevant + a := rand.Intn(200) + 100 + b := a + 100 + c := a + d := a + 50 + return fmt.Sprintf("%d:%d:%d,%d", a, b, c, d) +} + +func randMoney(s *Seed) string { + return fmt.Sprintf("%d.00", s.nextInt()) +} + +// StableDBName takes a database name in, and generates +// a random string using the database name as the rand Seed. +// getDBNameHash is used to generate unique test database names. +func StableDBName(input string) string { + return randStrFromSource(stableSource(input), 40) +} + +// stableSource takes an input value, and produces a random +// seed from it that will produce very few collisions in +// a 40 character random string made from a different alphabet. +func stableSource(input string) *rand.Rand { + sum := md5.Sum([]byte(input)) + var seed int64 + for i, byt := range sum { + seed ^= int64(byt) << uint((i*4)%64) + } + return rand.New(rand.NewSource(seed)) +} + +func randStrFromSource(r *rand.Rand, length int) string { + ln := len(alphabetLowerAlpha) + + output := make([]rune, length) + for i := 0; i < length; i++ { + output[i] = rune(alphabetLowerAlpha[r.Intn(ln)]) + } + + return string(output) +} diff --git a/boil/randomize/random_test.go b/boil/randomize/random_test.go new file mode 100644 index 000000000..ee910ce61 --- /dev/null +++ b/boil/randomize/random_test.go @@ -0,0 +1,19 @@ +package randomize + +import "testing" + +func TestStableDBName(t *testing.T) { + t.Parallel() + + db := "awesomedb" + + one, two := StableDBName(db), StableDBName(db) + + if len(one) != 40 { + t.Error("want 40 characters:", len(one), one) + } + + if one != two { + t.Error("it should always produce the same value") + } +} diff --git a/boil/randomize/randomize.go b/boil/randomize/randomize.go index 2277ef256..e12193e10 100644 --- a/boil/randomize/randomize.go +++ b/boil/randomize/randomize.go @@ -4,7 +4,6 @@ package randomize import ( "database/sql" "fmt" - "math/rand" "reflect" "regexp" "sort" @@ -572,86 +571,3 @@ func getVariableRandValue(s *Seed, kind reflect.Kind, typ reflect.Type) interfac return nil } - -const alphabet = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" - -func randStr(s *Seed, ln int) string { - str := make([]byte, ln) - for i := 0; i < ln; i++ { - str[i] = byte(alphabet[s.nextInt()%len(alphabet)]) - } - - return string(str) -} - -func randByteSlice(s *Seed, ln int) []byte { - str := make([]byte, ln) - for i := 0; i < ln; i++ { - str[i] = byte(s.nextInt() % 256) - } - - return str -} - -func randPoint() string { - a := rand.Intn(100) - b := a + 1 - return fmt.Sprintf("(%d,%d)", a, b) -} - -func randBox() string { - a := rand.Intn(100) - b := a + 1 - c := a + 2 - d := a + 3 - return fmt.Sprintf("(%d,%d),(%d,%d)", a, b, c, d) -} - -func randCircle() string { - a, b, c := rand.Intn(100), rand.Intn(100), rand.Intn(100) - return fmt.Sprintf("((%d,%d),%d)", a, b, c) -} - -func randNetAddr() string { - return fmt.Sprintf( - "%d.%d.%d.%d", - rand.Intn(254)+1, - rand.Intn(254)+1, - rand.Intn(254)+1, - rand.Intn(254)+1, - ) -} - -func randMacAddr() string { - buf := make([]byte, 6) - _, err := rand.Read(buf) - if err != nil { - panic(err) - } - - // Set the local bit - buf[0] |= 2 - return fmt.Sprintf( - "%02x:%02x:%02x:%02x:%02x:%02x", - buf[0], buf[1], buf[2], buf[3], buf[4], buf[5], - ) -} - -func randLsn() string { - a := rand.Int63n(9000000) - b := rand.Int63n(9000000) - return fmt.Sprintf("%d/%d", a, b) -} - -func randTxID() string { - // Order of integers is relevant - a := rand.Intn(200) + 100 - b := a + 100 - c := a - d := a + 50 - return fmt.Sprintf("%d:%d:%d,%d", a, b, c, d) -} - -func randMoney(s *Seed) string { - return fmt.Sprintf("%d.00", s.nextInt()) -} From 76b75dfaaa5875045e43006db46c428cd32600a4 Mon Sep 17 00:00:00 2001 From: Aaron L Date: Mon, 12 Sep 2016 22:43:29 -0700 Subject: [PATCH 36/64] Remove foreign keys from mysql dump --- imports.go | 9 ++- templates_test/main_test/mysql_main.tpl | 4 +- templates_test/main_test/postgres_main.tpl | 2 +- .../singleton/boil_queries_test.tpl | 58 ++++++------------- 4 files changed, 26 insertions(+), 47 deletions(-) diff --git a/imports.go b/imports.go index ec3be4c37..66917a593 100644 --- a/imports.go +++ b/imports.go @@ -206,11 +206,12 @@ var defaultSingletonTestTemplateImports = map[string]imports{ }, "boil_queries_test": { standard: importList{ - `"crypto/md5"`, + `"bytes"`, `"fmt"`, - `"os"`, - `"strconv"`, + `"io"`, + `"io/ioutil"`, `"math/rand"`, + `"regexp"`, }, thirdParty: importList{ `"github.com/vattle/sqlboiler/boil"`, @@ -238,6 +239,7 @@ var defaultTestMainImports = map[string]imports{ `"github.com/pkg/errors"`, `"github.com/spf13/viper"`, `"github.com/vattle/sqlboiler/bdb/drivers"`, + `"github.com/vattle/sqlboiler/boil/randomize"`, `_ "github.com/lib/pq"`, }, }, @@ -256,6 +258,7 @@ var defaultTestMainImports = map[string]imports{ `"github.com/pkg/errors"`, `"github.com/spf13/viper"`, `"github.com/vattle/sqlboiler/bdb/drivers"`, + `"github.com/vattle/sqlboiler/boil/randomize"`, `_ "github.com/go-sql-driver/mysql"`, }, }, diff --git a/templates_test/main_test/mysql_main.tpl b/templates_test/main_test/mysql_main.tpl index c3c7c4d60..a9daf05d1 100644 --- a/templates_test/main_test/mysql_main.tpl +++ b/templates_test/main_test/mysql_main.tpl @@ -27,7 +27,7 @@ func (m *mysqlTester) setup() error { m.port = viper.GetInt("mysql.port") m.sslmode = viper.GetString("mysql.sslmode") // Create a randomized db name. - m.testDBName = getDBNameHash(m.dbName) + m.testDBName = randomize.StableDBName(m.dbName) if err = m.makeOptionFile(); err != nil { return errors.Wrap(err, "couldn't make option file") @@ -45,7 +45,7 @@ func (m *mysqlTester) setup() error { r, w := io.Pipe() dumpCmd.Stdout = w - createCmd.Stdin = io.TeeReader(r, os.Stdout) + createCmd.Stdin = newFKeyDestroyer(r) if err = dumpCmd.Start(); err != nil { return errors.Wrap(err, "failed to start mysqldump command") diff --git a/templates_test/main_test/postgres_main.tpl b/templates_test/main_test/postgres_main.tpl index 7d19f1ba6..585beafea 100644 --- a/templates_test/main_test/postgres_main.tpl +++ b/templates_test/main_test/postgres_main.tpl @@ -92,7 +92,7 @@ func (p *pgTester) setup() error { p.port = viper.GetInt("postgres.port") p.sslmode = viper.GetString("postgres.sslmode") // Create a randomized db name. - p.testDBName = getDBNameHash(p.dbName) + p.testDBName = randomize.StableDBName(p.dbName) err = p.dropTestDB() if err != nil { diff --git a/templates_test/singleton/boil_queries_test.tpl b/templates_test/singleton/boil_queries_test.tpl index 26f09d3c5..d7a6dd3d9 100644 --- a/templates_test/singleton/boil_queries_test.tpl +++ b/templates_test/singleton/boil_queries_test.tpl @@ -7,53 +7,29 @@ func MustTx(transactor boil.Transactor, err error) boil.Transactor { return transactor } -func initDBNameRand(input string) { - sum := md5.Sum([]byte(input)) +var rgxPGFkey = regexp.MustCompile(`(?m)(?s)^ALTER TABLE ONLY.*?ADD CONSTRAINT.*?FOREIGN KEY.*?;\n`) +var rgxMySQLkey = regexp.MustCompile(`(?m)((,\n)?\s+CONSTRAINT.*?FOREIGN KEY.*?\n)+`) - var sumInt string - for _, v := range sum { - sumInt = sumInt + strconv.Itoa(int(v)) - } - - // Cut integer to 18 digits to ensure no int64 overflow. - sumInt = sumInt[:18] - - sumTmp := sumInt - for i, v := range sumInt { - if v == '0' { - sumTmp = sumInt[i+1:] - continue - } - break - } - - sumInt = sumTmp - - randSeed, err := strconv.ParseInt(sumInt, 0, 64) - if err != nil { - fmt.Printf("Unable to parse sumInt: %s", err) - os.Exit(-1) +func newFKeyDestroyer(reader io.Reader) io.Reader { + return &fKeyDestroyer{ + reader: reader, } +} - dbNameRand = rand.New(rand.NewSource(randSeed)) +type fKeyDestroyer struct { + reader io.Reader + buf *bytes.Buffer } -var alphabetChars = "abcdefghijklmnopqrstuvwxyz" -func randStr(length int) string { - c := len(alphabetChars) +func (f *fKeyDestroyer) Read(b []byte) (int, error) { + if f.buf == nil { + all, err := ioutil.ReadAll(f.reader) + if err != nil { + return 0, err + } - output := make([]rune, length) - for i := 0; i < length; i++ { - output[i] = rune(alphabetChars[dbNameRand.Intn(c)]) + f.buf = bytes.NewBuffer(rgxMySQLkey.ReplaceAll(rgxPGFkey.ReplaceAll(all, []byte{}), []byte{})) } - return string(output) -} - -// getDBNameHash takes a database name in, and generates -// a random string using the database name as the rand Seed. -// getDBNameHash is used to generate unique test database names. -func getDBNameHash(input string) string { - initDBNameRand(input) - return randStr(40) + return f.buf.Read(b) } From 912693a12433241bf746ae8b7bc5b25f3406d2e7 Mon Sep 17 00:00:00 2001 From: Aaron L Date: Mon, 12 Sep 2016 23:28:23 -0700 Subject: [PATCH 37/64] Update parameter generation for mysql --- strmangle/strmangle.go | 22 ++++++++++++------- strmangle/strmangle_test.go | 23 ++++++++++++++++++++ templates/04_relationship_to_one.tpl | 2 +- templates/05_relationship_to_many.tpl | 4 ++-- templates/09_relationship_to_many_setops.tpl | 8 +++---- templates/11_find.tpl | 2 +- templates/12_insert.tpl | 2 +- templates/13_update.tpl | 5 ++++- templates/15_delete.tpl | 2 +- templates/17_exists.tpl | 2 +- templates_test/relationship_to_many.tpl | 4 ++-- templates_test/upsert.tpl | 3 +++ 12 files changed, 57 insertions(+), 22 deletions(-) diff --git a/strmangle/strmangle.go b/strmangle/strmangle.go index ab0147f14..5f62aca47 100644 --- a/strmangle/strmangle.go +++ b/strmangle/strmangle.go @@ -427,14 +427,19 @@ func Placeholders(indexPlaceholders bool, count int, start int, group int) strin // SetParamNames takes a slice of columns and returns a comma separated // list of parameter names for a template statement SET clause. // eg: "col1"=$1, "col2"=$2, "col3"=$3 -func SetParamNames(columns []string) string { +func SetParamNames(lq, rq string, start int, columns []string) string { buf := GetBuffer() defer PutBuffer(buf) for i, c := range columns { - buf.WriteString(fmt.Sprintf(`"%s"=$%d`, c, i+1)) + if start != 0 { + buf.WriteString(fmt.Sprintf(`%s%s%s=$%d`, lq, c, rq, i+start)) + } else { + buf.WriteString(fmt.Sprintf(`%s%s%s=?`, lq, c, rq)) + } + if i < len(columns)-1 { - buf.WriteString(", ") + buf.WriteByte(',') } } @@ -444,15 +449,16 @@ func SetParamNames(columns []string) string { // WhereClause returns the where clause using start as the $ flag index // For example, if start was 2 output would be: "colthing=$2 AND colstuff=$3" func WhereClause(lq, rq string, start int, cols []string) string { - if start == 0 { - panic("0 is not a valid start number for whereClause") - } - buf := GetBuffer() defer PutBuffer(buf) for i, c := range cols { - buf.WriteString(fmt.Sprintf(`%s%s%s=$%d`, lq, c, rq, start+i)) + if start != 0 { + buf.WriteString(fmt.Sprintf(`%s%s%s=$%d`, lq, c, rq, start+i)) + } else { + buf.WriteString(fmt.Sprintf(`%s%s%s=?`, lq, c, rq)) + } + if i < len(cols)-1 { buf.WriteString(" AND ") } diff --git a/strmangle/strmangle_test.go b/strmangle/strmangle_test.go index f44ebfa2e..6d802d4c0 100644 --- a/strmangle/strmangle_test.go +++ b/strmangle/strmangle_test.go @@ -317,6 +317,28 @@ func TestPrefixStringSlice(t *testing.T) { } } +func TestSetParamNames(t *testing.T) { + t.Parallel() + + tests := []struct { + Cols []string + Start int + Should string + }{ + {Cols: []string{"col1", "col2"}, Start: 0, Should: `"col1"=?,"col2"=?`}, + {Cols: []string{"col1"}, Start: 2, Should: `"col1"=$2`}, + {Cols: []string{"col1", "col2"}, Start: 4, Should: `"col1"=$4,"col2"=$5`}, + {Cols: []string{"col1", "col2", "col3"}, Start: 4, Should: `"col1"=$4,"col2"=$5,"col3"=$6`}, + } + + for i, test := range tests { + r := SetParamNames(`"`, `"`, test.Start, test.Cols) + if r != test.Should { + t.Errorf("(%d) want: %s, got: %s\nTest: %#v", i, test.Should, r, test) + } + } +} + func TestWhereClause(t *testing.T) { t.Parallel() @@ -325,6 +347,7 @@ func TestWhereClause(t *testing.T) { Start int Should string }{ + {Cols: []string{"col1", "col2"}, Start: 0, Should: `"col1"=? AND "col2"=?`}, {Cols: []string{"col1"}, Start: 2, Should: `"col1"=$2`}, {Cols: []string{"col1", "col2"}, Start: 4, Should: `"col1"=$4 AND "col2"=$5`}, {Cols: []string{"col1", "col2", "col3"}, Start: 4, Should: `"col1"=$4 AND "col2"=$5 AND "col3"=$6`}, diff --git a/templates/04_relationship_to_one.tpl b/templates/04_relationship_to_one.tpl index 06dac02cd..0fdfd8761 100644 --- a/templates/04_relationship_to_one.tpl +++ b/templates/04_relationship_to_one.tpl @@ -10,7 +10,7 @@ func ({{.Function.Receiver}} *{{.LocalTable.NameGo}}) {{.Function.Name}}G(mods . // {{.Function.Name}} pointed to by the foreign key. func ({{.Function.Receiver}} *{{.LocalTable.NameGo}}) {{.Function.Name}}(exec boil.Executor, mods ...qm.QueryMod) ({{$varNameSingular}}Query) { queryMods := []qm.QueryMod{ - qm.Where("{{.ForeignTable.ColumnName}}=$1", {{.Function.Receiver}}.{{.LocalTable.ColumnNameGo}}), + qm.Where("{{.ForeignTable.ColumnName}}={{if $dot.Dialect.IndexPlaceholders}}$1{{else}}?{{end}}", {{.Function.Receiver}}.{{.LocalTable.ColumnNameGo}}), } queryMods = append(queryMods, mods...) diff --git a/templates/05_relationship_to_many.tpl b/templates/05_relationship_to_many.tpl index 9d8184711..b96537738 100644 --- a/templates/05_relationship_to_many.tpl +++ b/templates/05_relationship_to_many.tpl @@ -33,11 +33,11 @@ func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) {{$rel.Function.Na {{if .ToJoinTable -}} queryMods = append(queryMods, qm.InnerJoin("{{.JoinTable | $dot.SchemaTable}} as {{id 1 | $dot.Quotes}} on {{id 0 | $dot.Quotes}}.{{.ForeignColumn | $dot.Quotes}} = {{id 1 | $dot.Quotes}}.{{.JoinForeignColumn | $dot.Quotes}}"), - qm.Where("{{id 1 | $dot.Quotes}}.{{.JoinLocalColumn | $dot.Quotes}}=$1", {{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}), + qm.Where("{{id 1 | $dot.Quotes}}.{{.JoinLocalColumn | $dot.Quotes}}={{if $dot.Dialect.IndexPlaceholders}}$1{{else}}?{{end}}", {{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}), ) {{else -}} queryMods = append(queryMods, - qm.Where("{{id 0 | $dot.Quotes}}.{{.ForeignColumn | $dot.Quotes}}=$1", {{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}), + qm.Where("{{id 0 | $dot.Quotes}}.{{.ForeignColumn | $dot.Quotes}}={{if $dot.Dialect.IndexPlaceholders}}$1{{else}}?{{end}}", {{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}), ) {{end}} diff --git a/templates/09_relationship_to_many_setops.tpl b/templates/09_relationship_to_many_setops.tpl index 59842f66b..0280c31aa 100644 --- a/templates/09_relationship_to_many_setops.tpl +++ b/templates/09_relationship_to_many_setops.tpl @@ -39,7 +39,7 @@ func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) Add{{$rel.Function {{if .ToJoinTable -}} for _, rel := range related { - query := "insert into {{.JoinTable | $dot.SchemaTable}} ({{.JoinLocalColumn | $dot.Quotes}}, {{.JoinForeignColumn | $dot.Quotes}}) values ($1, $2)" + query := "insert into {{.JoinTable | $dot.SchemaTable}} ({{.JoinLocalColumn | $dot.Quotes}}, {{.JoinForeignColumn | $dot.Quotes}}) values {{if $dot.Dialect.IndexPlaceholders}}($1, $2){{else}}(?, ?){{end}}" values := []interface{}{{"{"}}{{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}, rel.{{$rel.ForeignTable.ColumnNameGo}}} if boil.DebugMode { @@ -96,10 +96,10 @@ func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) Add{{$rel.Function // Sets related.R.{{$rel.Function.ForeignName}}'s {{$rel.Function.Name}} accordingly. func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) Set{{$rel.Function.Name}}(exec boil.Executor, insert bool, related ...*{{$rel.ForeignTable.NameGo}}) error { {{if .ToJoinTable -}} - query := "delete from {{.JoinTable | $dot.SchemaTable}} where {{.JoinLocalColumn | $dot.Quotes}} = $1" + query := "delete from {{.JoinTable | $dot.SchemaTable}} where {{.JoinLocalColumn | $dot.Quotes}} = {{if $dot.Dialect.IndexPlaceholders}}$1{{else}}?{{end}}" values := []interface{}{{"{"}}{{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}} {{else -}} - query := "update {{.ForeignTable | $dot.SchemaTable}} set {{.ForeignColumn | $dot.Quotes}} = null where {{.ForeignColumn | $dot.Quotes}} = $1" + query := "update {{.ForeignTable | $dot.SchemaTable}} set {{.ForeignColumn | $dot.Quotes}} = null where {{.ForeignColumn | $dot.Quotes}} = {{if $dot.Dialect.IndexPlaceholders}}$1{{else}}?{{end}}" values := []interface{}{{"{"}}{{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}} {{end -}} if boil.DebugMode { @@ -140,7 +140,7 @@ func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) Remove{{$rel.Funct var err error {{if .ToJoinTable -}} query := fmt.Sprintf( - "delete from {{.JoinTable | $dot.SchemaTable}} where {{.JoinLocalColumn | $dot.Quotes}} = $1 and {{.JoinForeignColumn | $dot.Quotes}} in (%s)", + "delete from {{.JoinTable | $dot.SchemaTable}} where {{.JoinLocalColumn | $dot.Quotes}} = {{if $dot.Dialect.IndexPlaceholders}}$1{{else}}?{{end}} and {{.JoinForeignColumn | $dot.Quotes}} in (%s)", strmangle.Placeholders(dialect.IndexPlaceholders, len(related), 1, 1), ) values := []interface{}{{"{"}}{{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}} diff --git a/templates/11_find.tpl b/templates/11_find.tpl index f7ad81806..aede96596 100644 --- a/templates/11_find.tpl +++ b/templates/11_find.tpl @@ -29,7 +29,7 @@ func Find{{$tableNameSingular}}(exec boil.Executor, {{$pkArgs}}, selectCols ...s sel = strings.Join(strmangle.IdentQuoteSlice(dialect.LQ, dialect.RQ, selectCols), ",") } query := fmt.Sprintf( - "select %s from {{.Table.Name | .SchemaTable}} where {{whereClause .LQ .RQ 1 .Table.PKey.Columns}}", sel, + "select %s from {{.Table.Name | .SchemaTable}} where {{if .Dialect.IndexPlaceholders}}{{whereClause .LQ .RQ 1 .Table.PKey.Columns}}{{else}}{{whereClause .LQ .RQ 0 .Table.PKey.Columns}}{{end}}", sel, ) q := boil.SQL(exec, query, {{$pkNames | join ", "}}) diff --git a/templates/12_insert.tpl b/templates/12_insert.tpl index a32087161..d05412474 100644 --- a/templates/12_insert.tpl +++ b/templates/12_insert.tpl @@ -69,7 +69,7 @@ func (o *{{$tableNameSingular}}) Insert(exec boil.Executor, whitelist ... string if len(cache.retMapping) != 0 { {{if .UseLastInsertID -}} - cache.retQuery = fmt.Sprintf("SELECT %s FROM {{$schemaTable}} WHERE %s", strings.Join(returnColumns, "{{.LQ}},{{.RQ}}"), strmangle.WhereClause("{{.LQ}}", "{{.RQ}}", 1, {{$varNameSingular}}PrimaryKeyColumns)) + cache.retQuery = fmt.Sprintf("SELECT %s FROM {{$schemaTable}} WHERE %s", strings.Join(returnColumns, "{{.LQ}},{{.RQ}}"), strmangle.WhereClause("{{.LQ}}", "{{.RQ}}", {{if .Dialect.IndexPlaceholders}}1{{else}}0{{end}}, {{$varNameSingular}}PrimaryKeyColumns)) {{else -}} cache.query += fmt.Sprintf(" RETURNING {{.LQ}}%s{{.RQ}}", strings.Join(returnColumns, "{{.LQ}},{{.RQ}}")) {{end -}} diff --git a/templates/13_update.tpl b/templates/13_update.tpl index b3ade49a6..5b858c135 100644 --- a/templates/13_update.tpl +++ b/templates/13_update.tpl @@ -53,7 +53,10 @@ func (o *{{$tableNameSingular}}) Update(exec boil.Executor, whitelist ... string if !cached { wl := strmangle.UpdateColumnSet({{$varNameSingular}}Columns, {{$varNameSingular}}PrimaryKeyColumns, whitelist) - cache.query = fmt.Sprintf("UPDATE {{$schemaTable}} SET %s WHERE %s", strmangle.SetParamNames(wl), strmangle.WhereClause("{{.LQ}}", "{{.RQ}}", len(wl)+1, {{$varNameSingular}}PrimaryKeyColumns)) + cache.query = fmt.Sprintf("UPDATE {{$schemaTable}} SET %s WHERE %s", + strmangle.SetParamNames("{{.LQ}}", "{{.RQ}}", {{if .Dialect.IndexPlaceholders}}1{{else}}0{{end}}, wl), + strmangle.WhereClause("{{.LQ}}", "{{.RQ}}", {{if .Dialect.IndexPlaceholders}}len(wl)+1{{else}}0{{end}}, {{$varNameSingular}}PrimaryKeyColumns), + ) cache.valueMapping, err = boil.BindMapping({{$varNameSingular}}Type, {{$varNameSingular}}Mapping, append(wl, {{$varNameSingular}}PrimaryKeyColumns...)) if err != nil { return err diff --git a/templates/15_delete.tpl b/templates/15_delete.tpl index 97343692f..025c5ba05 100644 --- a/templates/15_delete.tpl +++ b/templates/15_delete.tpl @@ -44,7 +44,7 @@ func (o *{{$tableNameSingular}}) Delete(exec boil.Executor) error { args := o.inPrimaryKeyArgs() - sql := "DELETE FROM {{$schemaTable}} WHERE {{whereClause .LQ .RQ 1 .Table.PKey.Columns}}" + sql := "DELETE FROM {{$schemaTable}} WHERE {{if .Dialect.IndexPlaceholders}}{{whereClause .LQ .RQ 1 .Table.PKey.Columns}}{{else}}{{whereClause .LQ .RQ 0 .Table.PKey.Columns}}{{end}}" if boil.DebugMode { fmt.Fprintln(boil.DebugWriter, sql) diff --git a/templates/17_exists.tpl b/templates/17_exists.tpl index 94f829b32..b64da2f9d 100644 --- a/templates/17_exists.tpl +++ b/templates/17_exists.tpl @@ -7,7 +7,7 @@ func {{$tableNameSingular}}Exists(exec boil.Executor, {{$pkArgs}}) (bool, error) { var exists bool - sql := "select exists(select 1 from {{$schemaTable}} where {{whereClause .LQ .RQ 1 .Table.PKey.Columns}} limit 1)" + sql := "select exists(select 1 from {{$schemaTable}} where {{if .Dialect.IndexPlaceholders}}{{whereClause .LQ .RQ 1 .Table.PKey.Columns}}{{else}}{{whereClause .LQ .RQ 0 .Table.PKey.Columns}}{{end}} limit 1)" if boil.DebugMode { fmt.Fprintln(boil.DebugWriter, sql) diff --git a/templates_test/relationship_to_many.tpl b/templates_test/relationship_to_many.tpl index a157f48d0..b147890bd 100644 --- a/templates_test/relationship_to_many.tpl +++ b/templates_test/relationship_to_many.tpl @@ -41,11 +41,11 @@ func test{{$rel.LocalTable.NameGo}}ToMany{{$rel.Function.Name}}(t *testing.T) { } {{if .ToJoinTable -}} - _, err = tx.Exec("insert into {{.JoinTable | $dot.SchemaTable}} ({{.JoinLocalColumn | $dot.Quotes}}, {{.JoinForeignColumn | $dot.Quotes}}) values ($1, $2)", a.{{$rel.LocalTable.ColumnNameGo}}, b.{{$rel.ForeignTable.ColumnNameGo}}) + _, err = tx.Exec("insert into {{.JoinTable | $dot.SchemaTable}} ({{.JoinLocalColumn | $dot.Quotes}}, {{.JoinForeignColumn | $dot.Quotes}}) values {{if $dot.Dialect.IndexPlaceholders}}($1, $2){{else}}(?, ?){{end}}", a.{{$rel.LocalTable.ColumnNameGo}}, b.{{$rel.ForeignTable.ColumnNameGo}}) if err != nil { t.Fatal(err) } - _, err = tx.Exec("insert into {{.JoinTable | $dot.SchemaTable}} ({{.JoinLocalColumn | $dot.Quotes}}, {{.JoinForeignColumn | $dot.Quotes}}) values ($1, $2)", a.{{$rel.LocalTable.ColumnNameGo}}, c.{{$rel.ForeignTable.ColumnNameGo}}) + _, err = tx.Exec("insert into {{.JoinTable | $dot.SchemaTable}} ({{.JoinLocalColumn | $dot.Quotes}}, {{.JoinForeignColumn | $dot.Quotes}}) values {{if $dot.Dialect.IndexPlaceholders}}($1, $2){{else}}(?, ?){{end}}", a.{{$rel.LocalTable.ColumnNameGo}}, c.{{$rel.ForeignTable.ColumnNameGo}}) if err != nil { t.Fatal(err) } diff --git a/templates_test/upsert.tpl b/templates_test/upsert.tpl index cb48b87a9..00c667b79 100644 --- a/templates_test/upsert.tpl +++ b/templates_test/upsert.tpl @@ -3,6 +3,9 @@ {{- $varNamePlural := .Table.Name | plural | camelCase -}} {{- $varNameSingular := .Table.Name | singular | camelCase -}} func test{{$tableNamePlural}}Upsert(t *testing.T) { + {{if not (eq .DriverName "postgres") -}} + t.Skip("not implemented for {{.DriverName}}") + {{end -}} t.Parallel() seed := randomize.NewSeed() From b1e8816d4234f1bb354494c786154eaa70c6e84d Mon Sep 17 00:00:00 2001 From: Aaron L Date: Mon, 12 Sep 2016 23:49:18 -0700 Subject: [PATCH 38/64] Fix update all --- templates/13_update.tpl | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/templates/13_update.tpl b/templates/13_update.tpl index 5b858c135..a7cd0346a 100644 --- a/templates/13_update.tpl +++ b/templates/13_update.tpl @@ -150,7 +150,7 @@ func (o {{$tableNameSingular}}Slice) UpdateAll(exec boil.Executor, cols M) error i := 0 for name, value := range cols { - colNames[i] = strmangle.IdentQuote(dialect.LQ, dialect.RQ, name) + colNames[i] = name args[i] = value i++ } @@ -159,10 +159,8 @@ func (o {{$tableNameSingular}}Slice) UpdateAll(exec boil.Executor, cols M) error args = append(args, o.inPrimaryKeyArgs()...) sql := fmt.Sprintf( - "UPDATE {{$schemaTable}} SET (%s) = (%s) WHERE (%s) IN (%s)", - strings.Join(colNames, ","), - strmangle.Placeholders(dialect.IndexPlaceholders, len(colNames), 1, 1), - strings.Join(strmangle.IdentQuoteSlice(dialect.LQ, dialect.RQ, {{$varNameSingular}}PrimaryKeyColumns), ","), + "UPDATE {{$schemaTable}} SET %s WHERE ({{.LQ}}{{.Table.PKey.Columns | join (printf "%s,%s" .LQ .RQ)}}{{.RQ}}) IN (%s)", + strmangle.SetParamNames("{{.LQ}}", "{{.RQ}}", {{if .Dialect.IndexPlaceholders}}1{{else}}0{{end}}, colNames), strmangle.Placeholders(dialect.IndexPlaceholders, len(o) * len({{$varNameSingular}}PrimaryKeyColumns), len(colNames)+1, len({{$varNameSingular}}PrimaryKeyColumns)), ) From 468e2f9ad3b3aad290b03e11af220f04405644a9 Mon Sep 17 00:00:00 2001 From: Aaron L Date: Tue, 13 Sep 2016 00:28:00 -0700 Subject: [PATCH 39/64] Fix default values in mysql driver --- bdb/drivers/mysql.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bdb/drivers/mysql.go b/bdb/drivers/mysql.go index bee699d12..af94db599 100644 --- a/bdb/drivers/mysql.go +++ b/bdb/drivers/mysql.go @@ -117,7 +117,7 @@ func (m *MySQLDriver) Columns(schema, tableName string) ([]bdb.Column, error) { var columns []bdb.Column rows, err := m.dbConn.Query(` - select column_name, data_type, column_default, is_nullable, + select column_name, data_type, if(extra = 'auto_increment','auto_increment', column_default), is_nullable, exists ( select c.column_name from information_schema.table_constraints tc From 1facccacc1cb0e9db39f80f67964fd1994457ec4 Mon Sep 17 00:00:00 2001 From: Aaron L Date: Tue, 13 Sep 2016 00:48:14 -0700 Subject: [PATCH 40/64] Fix an edge case for MySQL - This patch removes auto-generation of queries that have the pattern COUNT(tablename.*) which is a syntax error in mysql. --- boil/query.go | 5 +++++ boil/query_builders.go | 2 +- boil/query_test.go | 11 +++++++++++ templates/03_finishers.tpl | 1 + 4 files changed, 18 insertions(+), 1 deletion(-) diff --git a/boil/query.go b/boil/query.go index ae6a08cca..ae10ce26c 100644 --- a/boil/query.go +++ b/boil/query.go @@ -151,6 +151,11 @@ func SetLoad(q *Query, relationships ...string) { q.load = append([]string(nil), relationships...) } +// SetSelect on the query. +func SetSelect(q *Query, sel []string) { + q.selectCols = sel +} + // SetCount on the query. func SetCount(q *Query) { q.count = true diff --git a/boil/query_builders.go b/boil/query_builders.go index 3b1f8c655..8c3fd91c9 100644 --- a/boil/query_builders.go +++ b/boil/query_builders.go @@ -58,7 +58,7 @@ func buildSelectQuery(q *Query) (*bytes.Buffer, []interface{}) { buf.WriteString(strings.Join(selectColsWithAs, ", ")) } else if hasSelectCols { buf.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.dialect.LQ, q.dialect.RQ, q.selectCols), ", ")) - } else if hasJoins { + } else if hasJoins && !q.count { selectColsWithStars := writeStars(q) buf.WriteString(strings.Join(selectColsWithStars, ", ")) } else { diff --git a/boil/query_test.go b/boil/query_test.go index f61818929..005e834f5 100644 --- a/boil/query_test.go +++ b/boil/query_test.go @@ -290,6 +290,17 @@ func TestFrom(t *testing.T) { } } +func TestSetSelect(t *testing.T) { + t.Parallel() + + q := &Query{selectCols: []string{"hello"}} + SetSelect(q, nil) + + if q.selectCols != nil { + t.Errorf("want nil") + } +} + func TestSetCount(t *testing.T) { t.Parallel() diff --git a/templates/03_finishers.tpl b/templates/03_finishers.tpl index 6d259550e..970aa377d 100644 --- a/templates/03_finishers.tpl +++ b/templates/03_finishers.tpl @@ -79,6 +79,7 @@ func (q {{$varNameSingular}}Query) CountP() int64 { func (q {{$varNameSingular}}Query) Count() (int64, error) { var count int64 + boil.SetSelect(q.Query, nil) boil.SetCount(q.Query) err := boil.ExecQueryOne(q.Query).Scan(&count) From 91bb5ee940ca9ef6a89363a8df1efc211dcbe4b0 Mon Sep 17 00:00:00 2001 From: Patrick O'brien Date: Tue, 13 Sep 2016 19:46:32 +1000 Subject: [PATCH 41/64] Change Exec funcs to methods with Query receiver --- boil/query.go | 6 +++--- boil/reflect.go | 2 +- templates/03_finishers.tpl | 4 ++-- templates/13_update.tpl | 2 +- templates/15_delete.tpl | 2 +- 5 files changed, 8 insertions(+), 8 deletions(-) diff --git a/boil/query.go b/boil/query.go index ae10ce26c..5ae5fabc7 100644 --- a/boil/query.go +++ b/boil/query.go @@ -97,7 +97,7 @@ func SQLG(query string, args ...interface{}) *Query { } // ExecQuery executes a query that does not need a row returned -func ExecQuery(q *Query) (sql.Result, error) { +func (q *Query) ExecQuery() (sql.Result, error) { qs, args := buildQuery(q) if DebugMode { fmt.Fprintln(DebugWriter, qs) @@ -107,7 +107,7 @@ func ExecQuery(q *Query) (sql.Result, error) { } // ExecQueryOne executes the query for the One finisher and returns a row -func ExecQueryOne(q *Query) *sql.Row { +func (q *Query) ExecQueryOne() *sql.Row { qs, args := buildQuery(q) if DebugMode { fmt.Fprintln(DebugWriter, qs) @@ -117,7 +117,7 @@ func ExecQueryOne(q *Query) *sql.Row { } // ExecQueryAll executes the query for the All finisher and returns multiple rows -func ExecQueryAll(q *Query) (*sql.Rows, error) { +func (q *Query) ExecQueryAll() (*sql.Rows, error) { qs, args := buildQuery(q) if DebugMode { fmt.Fprintln(DebugWriter, qs) diff --git a/boil/reflect.go b/boil/reflect.go index d4c4dee99..dd961bbc3 100644 --- a/boil/reflect.go +++ b/boil/reflect.go @@ -100,7 +100,7 @@ func (q *Query) Bind(obj interface{}) error { return err } - rows, err := ExecQueryAll(q) + rows, err := q.ExecQueryAll() if err != nil { return errors.Wrap(err, "bind failed to execute query") } diff --git a/templates/03_finishers.tpl b/templates/03_finishers.tpl index 970aa377d..a0de3fc1e 100644 --- a/templates/03_finishers.tpl +++ b/templates/03_finishers.tpl @@ -82,7 +82,7 @@ func (q {{$varNameSingular}}Query) Count() (int64, error) { boil.SetSelect(q.Query, nil) boil.SetCount(q.Query) - err := boil.ExecQueryOne(q.Query).Scan(&count) + err := q.Query.ExecQueryOne().Scan(&count) if err != nil { return 0, errors.Wrap(err, "{{.PkgName}}: failed to count {{.Table.Name}} rows") } @@ -107,7 +107,7 @@ func (q {{$varNameSingular}}Query) Exists() (bool, error) { boil.SetCount(q.Query) boil.SetLimit(q.Query, 1) - err := boil.ExecQueryOne(q.Query).Scan(&count) + err := q.Query.ExecQueryOne().Scan(&count) if err != nil { return false, errors.Wrap(err, "{{.PkgName}}: failed to check if {{.Table.Name}} exists") } diff --git a/templates/13_update.tpl b/templates/13_update.tpl index a7cd0346a..28c66311e 100644 --- a/templates/13_update.tpl +++ b/templates/13_update.tpl @@ -107,7 +107,7 @@ func (q {{$varNameSingular}}Query) UpdateAllP(cols M) { func (q {{$varNameSingular}}Query) UpdateAll(cols M) error { boil.SetUpdate(q.Query, cols) - _, err := boil.ExecQuery(q.Query) + _, err := q.Query.ExecQuery() if err != nil { return errors.Wrap(err, "{{.PkgName}}: unable to update all for {{.Table.Name}}") } diff --git a/templates/15_delete.tpl b/templates/15_delete.tpl index 025c5ba05..2f6d6d9da 100644 --- a/templates/15_delete.tpl +++ b/templates/15_delete.tpl @@ -80,7 +80,7 @@ func (q {{$varNameSingular}}Query) DeleteAll() error { boil.SetDelete(q.Query) - _, err := boil.ExecQuery(q.Query) + _, err := q.Query.ExecQuery() if err != nil { return errors.Wrap(err, "{{.PkgName}}: unable to delete all from {{.Table.Name}}") } From a86e794b617b1e5566ae9394b62c404107b378a0 Mon Sep 17 00:00:00 2001 From: Aaron L Date: Tue, 13 Sep 2016 17:20:13 -0700 Subject: [PATCH 42/64] Fix composite primary keys for DeleteAll --- templates/15_delete.tpl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/templates/15_delete.tpl b/templates/15_delete.tpl index 2f6d6d9da..34193556a 100644 --- a/templates/15_delete.tpl +++ b/templates/15_delete.tpl @@ -134,7 +134,7 @@ func (o {{$tableNameSingular}}Slice) DeleteAll(exec boil.Executor) error { sql := fmt.Sprintf( "DELETE FROM {{$schemaTable}} WHERE (%s) IN (%s)", - strings.Join(strmangle.IdentQuoteSlice(dialect.LQ, dialect.RQ, {{$varNameSingular}}PrimaryKeyColumns), "{{.LQ}},{{.RQ}}"), + strings.Join(strmangle.IdentQuoteSlice(dialect.LQ, dialect.RQ, {{$varNameSingular}}PrimaryKeyColumns), ","), strmangle.Placeholders(dialect.IndexPlaceholders, len(o) * len({{$varNameSingular}}PrimaryKeyColumns), 1, len({{$varNameSingular}}PrimaryKeyColumns)), ) From e5edef144b02a1178fb79e4038d5496deaf99b37 Mon Sep 17 00:00:00 2001 From: Aaron L Date: Tue, 13 Sep 2016 21:46:58 -0700 Subject: [PATCH 43/64] Fix up postgres main just like mysql --- imports.go | 11 +- templates_test/main_test/mysql_main.tpl | 10 +- templates_test/main_test/postgres_main.tpl | 306 ++++++++---------- .../singleton/boil_queries_test.tpl | 8 +- 4 files changed, 148 insertions(+), 187 deletions(-) diff --git a/imports.go b/imports.go index 66917a593..9ffffa5d8 100644 --- a/imports.go +++ b/imports.go @@ -227,13 +227,14 @@ var defaultSingletonTestTemplateImports = map[string]imports{ var defaultTestMainImports = map[string]imports{ "postgres": { standard: importList{ - `"os"`, - `"os/exec"`, - `"fmt"`, - `"io/ioutil"`, `"bytes"`, `"database/sql"`, - `"path/filepath"`, + `"fmt"`, + `"io"`, + `"io/ioutil"`, + `"os"`, + `"os/exec"`, + `"strings"`, }, thirdParty: importList{ `"github.com/pkg/errors"`, diff --git a/templates_test/main_test/mysql_main.tpl b/templates_test/main_test/mysql_main.tpl index a9daf05d1..6ba0f415a 100644 --- a/templates_test/main_test/mysql_main.tpl +++ b/templates_test/main_test/mysql_main.tpl @@ -45,7 +45,7 @@ func (m *mysqlTester) setup() error { r, w := io.Pipe() dumpCmd.Stdout = w - createCmd.Stdin = newFKeyDestroyer(r) + createCmd.Stdin = newFKeyDestroyer(rgxMySQLkey, r) if err = dumpCmd.Start(); err != nil { return errors.Wrap(err, "failed to start mysqldump command") @@ -111,14 +111,14 @@ func (m *mysqlTester) dropTestDB() error { } func (m *mysqlTester) teardown() error { - if err := m.dropTestDB(); err != nil { - return err - } - if m.dbConn != nil { m.dbConn.Close() } + if err := m.dropTestDB(); err != nil { + return err + } + return os.Remove(m.optionFile) } diff --git a/templates_test/main_test/postgres_main.tpl b/templates_test/main_test/postgres_main.tpl index 585beafea..8f99790c5 100644 --- a/templates_test/main_test/postgres_main.tpl +++ b/templates_test/main_test/postgres_main.tpl @@ -6,7 +6,9 @@ type pgTester struct { user string pass string sslmode string - port int + port int + + pgPassFile string testDBName string } @@ -15,194 +17,150 @@ func init() { dbMain = &pgTester{} } -// disableTriggers is used to disable foreign key constraints for every table. -// If this is not used we cannot test inserts due to foreign key constraint errors. -func (p *pgTester) disableTriggers() error { - var stmts []string +// setup dumps the database schema and imports it into a temporary randomly +// generated test database so that tests can be run against it using the +// generated sqlboiler ORM package. +func (p *pgTester) setup() error { + var err error - {{range .Tables -}} - stmts = append(stmts, `ALTER TABLE {{.Name}} DISABLE TRIGGER ALL;`) - {{end -}} + p.dbName = viper.GetString("postgres.dbname") + p.host = viper.GetString("postgres.host") + p.user = viper.GetString("postgres.user") + p.pass = viper.GetString("postgres.pass") + p.port = viper.GetInt("postgres.port") + p.sslmode = viper.GetString("postgres.sslmode") + // Create a randomized db name. + p.testDBName = randomize.StableDBName(p.dbName) - if len(stmts) == 0 { - return nil - } + if err = p.makePGPassFile(); err != nil { + return err + } + + if err = p.dropTestDB(); err != nil { + return err + } + if err = p.createTestDB(); err != nil { + return err + } + + dumpCmd := exec.Command("pg_dump", "--schema-only", p.dbName) + dumpCmd.Env = append(os.Environ(), p.pgEnv()...) + createCmd := exec.Command("psql", p.testDBName) + createCmd.Env = append(os.Environ(), p.pgEnv()...) + + r, w := io.Pipe() + dumpCmd.Stdout = w + createCmd.Stdin = io.TeeReader(newFKeyDestroyer(rgxPGFkey, r), os.Stdout) + + if err = dumpCmd.Start(); err != nil { + return errors.Wrap(err, "failed to start pg_dump command") + } + if err = createCmd.Start(); err != nil { + return errors.Wrap(err, "failed to start psql command") + } + + if err = dumpCmd.Wait(); err != nil { + fmt.Println(err) + return errors.Wrap(err, "failed to wait for pg_dump command") + } + + w.Close() // After dumpCmd is done, close the write end of the pipe + + if err = createCmd.Wait(); err != nil { + fmt.Println(err) + return errors.Wrap(err, "failed to wait for psql command") + } + + return nil +} - var err error - for _, s := range stmts { - _, err = p.dbConn.Exec(s) - if err != nil { - return err - } - } - - return nil +func (p *pgTester) runCmd(stdin, command string, args ...string) error { + cmd := exec.Command(command, args...) + cmd.Env = append(os.Environ(), p.pgEnv()...) + + if len(stdin) != 0 { + cmd.Stdin = strings.NewReader(stdin) + } + + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + cmd.Stdout = stdout + cmd.Stderr = stderr + if err := cmd.Run(); err != nil { + fmt.Println("failed running:", command, args) + fmt.Println(stdout.String()) + fmt.Println(stderr.String()) + return err + } + + return nil } -// teardown executes cleanup tasks when the tests finish running -func (p *pgTester) teardown() error { - return p.dropTestDB() +func (p *pgTester) pgEnv() []string { + return []string{ + fmt.Sprintf("PGHOST=%s", p.host), + fmt.Sprintf("PGPORT=%d", p.port), + fmt.Sprintf("PGUSER=%s", p.user), + fmt.Sprintf("PGPASS=%s", p.pgPassFile), + } } -func (p *pgTester) conn() (*sql.DB, error) { - return p.dbConn, nil +func (p *pgTester) makePGPassFile() error { + tmp, err := ioutil.TempFile("", "pgpass") + if err != nil { + return errors.Wrap(err, "failed to create option file") + } + + fmt.Fprintf(tmp, "%s:%d:%s:%s", p.host, p.port, p.dbName, p.user) + if len(p.pass) != 0 { + fmt.Fprintf(tmp, ":%s", p.pass) + } + fmt.Fprintln(tmp) + + fmt.Fprintf(tmp, "%s:%d:%s:%s", p.host, p.port, p.testDBName, p.user) + if len(p.pass) != 0 { + fmt.Fprintf(tmp, ":%s", p.pass) + } + fmt.Fprintln(tmp) + + p.pgPassFile = tmp.Name() + return tmp.Close() +} + +func (p *pgTester) createTestDB() error { + return p.runCmd("", "createdb", p.testDBName) } -// dropTestDB switches its connection to the template1 database temporarily -// so that it can drop the test database without causing "in use" conflicts. -// The template1 database should be present on all default postgres installations. func (p *pgTester) dropTestDB() error { - var err error - if p.dbConn != nil { - if err = p.dbConn.Close(); err != nil { - return err - } - } - - p.dbConn, err = DBConnect(p.user, p.pass, "template1", p.host, p.port, p.sslmode) - if err != nil { - return err - } - - _, err = p.dbConn.Exec(fmt.Sprintf(`DROP DATABASE IF EXISTS %s;`, p.testDBName)) - if err != nil { - return err - } - - return p.dbConn.Close() + return p.runCmd("", "dropdb", "--if-exists", p.testDBName) } -// DBConnect connects to a database and returns the handle. -func DBConnect(user, pass, dbname, host string, port int, sslmode string) (*sql.DB, error) { - connStr := drivers.PostgresBuildQueryString(user, pass, dbname, host, port, sslmode) +// teardown executes cleanup tasks when the tests finish running +func (p *pgTester) teardown() error { + var err error + if err = p.dbConn.Close(); err != nil { + return err + } + p.dbConn = nil + + if err = p.dropTestDB(); err != nil { + return err + } - return sql.Open("postgres", connStr) + return os.Remove(p.pgPassFile) } -// setup dumps the database schema and imports it into a temporary randomly -// generated test database so that tests can be run against it using the -// generated sqlboiler ORM package. -func (p *pgTester) setup() error { - var err error +func (p *pgTester) conn() (*sql.DB, error) { + if p.dbConn != nil { + return p.dbConn, nil + } - p.dbName = viper.GetString("postgres.dbname") - p.host = viper.GetString("postgres.host") - p.user = viper.GetString("postgres.user") - p.pass = viper.GetString("postgres.pass") - p.port = viper.GetInt("postgres.port") - p.sslmode = viper.GetString("postgres.sslmode") - // Create a randomized db name. - p.testDBName = randomize.StableDBName(p.dbName) + var err error + p.dbConn, err = sql.Open("postgres", drivers.PostgresBuildQueryString(p.user, p.pass, p.testDBName, p.host, p.port, p.sslmode)) + if err != nil { + return nil, err + } - err = p.dropTestDB() - if err != nil { - fmt.Printf("%#v\n", err) - return err - } - - fhSchema, err := ioutil.TempFile(os.TempDir(), "sqlboilerschema") - if err != nil { - return errors.Wrap(err, "Unable to create sqlboiler schema tmp file") - } - defer os.Remove(fhSchema.Name()) - - passDir, err := ioutil.TempDir(os.TempDir(), "sqlboiler") - if err != nil { - return errors.Wrap(err, "Unable to create sqlboiler tmp dir for postgres pw file") - } - defer os.RemoveAll(passDir) - - // Write the postgres user password to a tmp file for pg_dump - pwBytes := []byte(fmt.Sprintf("%s:%d:%s:%s", p.host, p.port, p.dbName, p.user)) - - if len(p.pass) > 0 { - pwBytes = []byte(fmt.Sprintf("%s:%s", pwBytes, p.pass)) - } - - passFilePath := filepath.Join(passDir, "pwfile") - - err = ioutil.WriteFile(passFilePath, pwBytes, 0600) - if err != nil { - return errors.Wrap(err, "Unable to create pwfile in passDir") - } - - // The params for the pg_dump command to dump the database schema - params := []string{ - fmt.Sprintf(`--host=%s`, p.host), - fmt.Sprintf(`--port=%d`, p.port), - fmt.Sprintf(`--username=%s`, p.user), - "--schema-only", - p.dbName, - } - - // Dump the database schema into the sqlboilerschema tmp file - errBuf := bytes.Buffer{} - cmd := exec.Command("pg_dump", params...) - cmd.Stderr = &errBuf - cmd.Stdout = fhSchema - cmd.Env = append(os.Environ(), fmt.Sprintf(`PGPASSFILE=%s`, passFilePath)) - - if err := cmd.Run(); err != nil { - fmt.Printf("pg_dump exec failed: %s\n\n%s\n", err, errBuf.String()) - return err - } - - p.dbConn, err = DBConnect(p.user, p.pass, p.dbName, p.host, p.port, p.sslmode) - if err != nil { - return err - } - - // Create the randomly generated database - _, err = p.dbConn.Exec(fmt.Sprintf(`CREATE DATABASE %s WITH ENCODING 'UTF8'`, p.testDBName)) - if err != nil { - return err - } - - // Close the old connection so we can reconnect to the test database - if err = p.dbConn.Close(); err != nil { - return err - } - - // Connect to the generated test db - p.dbConn, err = DBConnect(p.user, p.pass, p.testDBName, p.host, p.port, p.sslmode) - if err != nil { - return err - } - - // Write the test config credentials to a tmp file for pg_dump - testPwBytes := []byte(fmt.Sprintf("%s:%d:%s:%s", p.host, p.port, p.testDBName, p.user)) - - if len(p.pass) > 0 { - testPwBytes = []byte(fmt.Sprintf("%s:%s", testPwBytes, p.pass)) - } - - testPassFilePath := passDir + "/testpwfile" - - err = ioutil.WriteFile(testPassFilePath, testPwBytes, 0600) - if err != nil { - return errors.Wrapf(err, "Unable to create testpwfile in passDir") - } - - // The params for the psql schema import command - params = []string{ - fmt.Sprintf(`--dbname=%s`, p.testDBName), - fmt.Sprintf(`--host=%s`, p.host), - fmt.Sprintf(`--port=%d`, p.port), - fmt.Sprintf(`--username=%s`, p.user), - fmt.Sprintf(`--file=%s`, fhSchema.Name()), - } - - // Import the database schema into the generated database. - // It is now ready to be used by the generated ORM package for testing. - outBuf := bytes.Buffer{} - cmd = exec.Command("psql", params...) - cmd.Stderr = &errBuf - cmd.Stdout = &outBuf - cmd.Env = append(os.Environ(), fmt.Sprintf(`PGPASSFILE=%s`, testPassFilePath)) - - if err = cmd.Run(); err != nil { - fmt.Printf("psql schema import exec failed: %s\n\n%s\n", err, errBuf.String()) - } - - return p.disableTriggers() + return p.dbConn, nil } + diff --git a/templates_test/singleton/boil_queries_test.tpl b/templates_test/singleton/boil_queries_test.tpl index d7a6dd3d9..90419be6c 100644 --- a/templates_test/singleton/boil_queries_test.tpl +++ b/templates_test/singleton/boil_queries_test.tpl @@ -7,18 +7,20 @@ func MustTx(transactor boil.Transactor, err error) boil.Transactor { return transactor } -var rgxPGFkey = regexp.MustCompile(`(?m)(?s)^ALTER TABLE ONLY.*?ADD CONSTRAINT.*?FOREIGN KEY.*?;\n`) +var rgxPGFkey = regexp.MustCompile(`(?m)^ALTER TABLE ONLY .*\n\s+ADD CONSTRAINT .*? FOREIGN KEY .*?;\n`) var rgxMySQLkey = regexp.MustCompile(`(?m)((,\n)?\s+CONSTRAINT.*?FOREIGN KEY.*?\n)+`) -func newFKeyDestroyer(reader io.Reader) io.Reader { +func newFKeyDestroyer(regex *regexp.Regexp, reader io.Reader) io.Reader { return &fKeyDestroyer{ reader: reader, + rgx: regex, } } type fKeyDestroyer struct { reader io.Reader buf *bytes.Buffer + rgx *regexp.Regexp } func (f *fKeyDestroyer) Read(b []byte) (int, error) { @@ -28,7 +30,7 @@ func (f *fKeyDestroyer) Read(b []byte) (int, error) { return 0, err } - f.buf = bytes.NewBuffer(rgxMySQLkey.ReplaceAll(rgxPGFkey.ReplaceAll(all, []byte{}), []byte{})) + f.buf = bytes.NewBuffer(f.rgx.ReplaceAll(all, []byte{})) } return f.buf.Read(b) From 4f1565147a698d5fd99e8a4df443b260eb30786c Mon Sep 17 00:00:00 2001 From: Aaron L Date: Tue, 13 Sep 2016 21:57:34 -0700 Subject: [PATCH 44/64] Fix indentation --- templates_test/main_test/postgres_main.tpl | 38 +++++++++++----------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/templates_test/main_test/postgres_main.tpl b/templates_test/main_test/postgres_main.tpl index 8f99790c5..cdd029e31 100644 --- a/templates_test/main_test/postgres_main.tpl +++ b/templates_test/main_test/postgres_main.tpl @@ -1,36 +1,36 @@ type pgTester struct { - dbConn *sql.DB + dbConn *sql.DB - dbName string - host string - user string - pass string - sslmode string - port int + dbName string + host string + user string + pass string + sslmode string + port int pgPassFile string - testDBName string + testDBName string } func init() { - dbMain = &pgTester{} + dbMain = &pgTester{} } // setup dumps the database schema and imports it into a temporary randomly // generated test database so that tests can be run against it using the // generated sqlboiler ORM package. func (p *pgTester) setup() error { - var err error - - p.dbName = viper.GetString("postgres.dbname") - p.host = viper.GetString("postgres.host") - p.user = viper.GetString("postgres.user") - p.pass = viper.GetString("postgres.pass") - p.port = viper.GetInt("postgres.port") - p.sslmode = viper.GetString("postgres.sslmode") - // Create a randomized db name. - p.testDBName = randomize.StableDBName(p.dbName) + var err error + + p.dbName = viper.GetString("postgres.dbname") + p.host = viper.GetString("postgres.host") + p.user = viper.GetString("postgres.user") + p.pass = viper.GetString("postgres.pass") + p.port = viper.GetInt("postgres.port") + p.sslmode = viper.GetString("postgres.sslmode") + // Create a randomized db name. + p.testDBName = randomize.StableDBName(p.dbName) if err = p.makePGPassFile(); err != nil { return err From 83f7092dc638dff0a3b575409123e8602a9b47f7 Mon Sep 17 00:00:00 2001 From: Patrick O'brien Date: Wed, 14 Sep 2016 18:08:30 +1000 Subject: [PATCH 45/64] Add MySQL Upsert, fix identation in all tpls --- README.md | 6 + bdb/keys.go | 28 + boil/query_builders.go | 33 +- templates.go | 1 + templates/00_struct.tpl | 38 +- templates/01_types.tpl | 34 +- templates/02_hooks.tpl | 148 ++--- templates/03_finishers.tpl | 138 ++--- templates/04_relationship_to_one.tpl | 30 +- templates/05_relationship_to_many.tpl | 64 +- templates/06_relationship_to_one_eager.tpl | 144 ++--- templates/07_relationship_to_many_eager.tpl | 234 +++---- templates/08_relationship_to_one_setops.tpl | 156 ++--- templates/09_relationship_to_many_setops.tpl | 384 ++++++------ templates/10_all.tpl | 6 +- templates/11_find.tpl | 64 +- templates/12_insert.tpl | 236 +++---- templates/13_update.tpl | 250 ++++---- templates/14_upsert.tpl | 150 +++-- templates/15_delete.tpl | 206 +++---- templates/16_reload.tpl | 86 +-- templates/17_exists.tpl | 46 +- templates/18_helpers.tpl | 24 +- templates/19_auto_timestamps.tpl | 152 ++--- templates/singleton/boil_queries.tpl | 18 +- templates/singleton/boil_types.tpl | 38 +- templates_test/all.tpl | 10 +- templates_test/delete.tpl | 166 ++--- templates_test/exists.tpl | 40 +- templates_test/find.tpl | 38 +- templates_test/finishers.tpl | 196 +++--- templates_test/helpers.tpl | 90 +-- templates_test/hooks.tpl | 220 +++---- templates_test/insert.tpl | 92 +-- templates_test/main_test/mysql_main.tpl | 251 ++++---- templates_test/relationship_to_many.tpl | 162 ++--- .../relationship_to_many_setops.tpl | 582 +++++++++--------- templates_test/relationship_to_one.tpl | 102 +-- templates_test/relationship_to_one_setops.tpl | 232 +++---- templates_test/reload.tpl | 76 +-- templates_test/select.tpl | 38 +- templates_test/update.tpl | 150 ++--- templates_test/upsert.tpl | 74 +-- 43 files changed, 2665 insertions(+), 2568 deletions(-) diff --git a/README.md b/README.md index 3479428a4..fa52f3eb4 100644 --- a/README.md +++ b/README.md @@ -1003,6 +1003,12 @@ p1.Name = "Hogan" err := p1.Upsert(db, true, []string{"id"}, []string{"name"}, "id", "name") ``` +The `updateOnConflict` argument allows you to specify whether you would like Postgres +to perform a `DO NOTHING` on conflict, opposed to a `DO UPDATE`. For MySQL, this param will not be generated. + +The `conflictColumns` argument allows you to specify the `ON CONFLICT` columns for Postgres. +For MySQL, this param will not be generated. + Note: Passing a different set of column values to the update component is not currently supported. If this feature is important to you let us know and we can consider adding something for this. diff --git a/bdb/keys.go b/bdb/keys.go index 35a8a84df..1ba1a2655 100644 --- a/bdb/keys.go +++ b/bdb/keys.go @@ -3,6 +3,7 @@ package bdb import ( "fmt" "regexp" + "strings" ) var rgxAutoIncColumn = regexp.MustCompile(`^nextval\(.*\)`) @@ -79,3 +80,30 @@ func SQLColDefinitions(cols []Column, names []string) SQLColumnDefs { return ret } + +// AutoIncPrimaryKey returns the auto-increment primary key column name or an +// empty string. Primary key columns with default values are presumed +// to be auto-increment, because pkeys need to be unique and a static +// default value would cause collisions. +func AutoIncPrimaryKey(cols []Column, pkey *PrimaryKey) *Column { + if pkey == nil { + return nil + } + + for _, pkeyColumn := range pkey.Columns { + for _, c := range cols { + if c.Name != pkeyColumn { + continue + } + + if c.Default != "auto_increment" || c.Nullable || + !(strings.HasPrefix(c.Type, "int") || strings.HasPrefix(c.Type, "uint")) { + continue + } + + return &c + } + } + + return nil +} diff --git a/boil/query_builders.go b/boil/query_builders.go index 8c3fd91c9..bbfe288a2 100644 --- a/boil/query_builders.go +++ b/boil/query_builders.go @@ -183,8 +183,37 @@ func buildUpdateQuery(q *Query) (*bytes.Buffer, []interface{}) { return buf, args } -// BuildUpsertQuery builds a SQL statement string using the upsertData provided. -func BuildUpsertQuery(dia Dialect, tableName string, updateOnConflict bool, ret, update, conflict, whitelist []string) string { +// BuildUpsertQueryMySQL builds a SQL statement string using the upsertData provided. +func BuildUpsertQueryMySQL(dia Dialect, tableName string, update, whitelist []string) string { + whitelist = strmangle.IdentQuoteSlice(dia.LQ, dia.RQ, whitelist) + + buf := strmangle.GetBuffer() + defer strmangle.PutBuffer(buf) + + fmt.Fprintf( + buf, + "INSERT INTO %s (%s) VALUES (%s) ON DUPLICATE KEY UPDATE ", + tableName, + strings.Join(whitelist, ", "), + strmangle.Placeholders(dia.IndexPlaceholders, len(whitelist), 1, 1), + ) + + for i, v := range update { + if i != 0 { + buf.WriteByte(',') + } + quoted := strmangle.IdentQuote(dia.LQ, dia.RQ, v) + buf.WriteString(quoted) + buf.WriteString(" = VALUES(") + buf.WriteString(quoted) + buf.WriteByte(')') + } + + return buf.String() +} + +// BuildUpsertQueryPostgres builds a SQL statement string using the upsertData provided. +func BuildUpsertQueryPostgres(dia Dialect, tableName string, updateOnConflict bool, ret, update, conflict, whitelist []string) string { conflict = strmangle.IdentQuoteSlice(dia.LQ, dia.RQ, conflict) whitelist = strmangle.IdentQuoteSlice(dia.LQ, dia.RQ, whitelist) ret = strmangle.IdentQuoteSlice(dia.LQ, dia.RQ, ret) diff --git a/templates.go b/templates.go index 80fdfb232..a1da91768 100644 --- a/templates.go +++ b/templates.go @@ -174,6 +174,7 @@ var templateFunctions = template.FuncMap{ // dbdrivers ops "filterColumnsByDefault": bdb.FilterColumnsByDefault, + "autoIncPrimaryKey": bdb.AutoIncPrimaryKey, "sqlColDefinitions": bdb.SQLColDefinitions, "columnNames": bdb.ColumnNames, "columnDBTypes": bdb.ColumnDBTypes, diff --git a/templates/00_struct.tpl b/templates/00_struct.tpl index 70d1e2961..60b64b2e6 100644 --- a/templates/00_struct.tpl +++ b/templates/00_struct.tpl @@ -1,5 +1,5 @@ {{- define "relationship_to_one_struct_helper" -}} - {{.Function.Name}} *{{.ForeignTable.NameGo}} + {{.Function.Name}} *{{.ForeignTable.NameGo}} {{- end -}} {{- $dot := . -}} @@ -8,30 +8,30 @@ {{- $modelNameCamel := $tableNameSingular | camelCase -}} // {{$modelName}} is an object representing the database table. type {{$modelName}} struct { - {{range $column := .Table.Columns -}} - {{titleCase $column.Name}} {{$column.Type}} `{{generateTags $dot.Tags $column.Name}}boil:"{{$column.Name}}" json:"{{$column.Name}}{{if $column.Nullable}},omitempty{{end}}" toml:"{{$column.Name}}" yaml:"{{$column.Name}}{{if $column.Nullable}},omitempty{{end}}"` - {{end -}} - {{- if .Table.IsJoinTable -}} - {{- else}} - R *{{$modelNameCamel}}R `{{generateIgnoreTags $dot.Tags}}boil:"-" json:"-" toml:"-" yaml:"-"` - L {{$modelNameCamel}}L `{{generateIgnoreTags $dot.Tags}}boil:"-" json:"-" toml:"-" yaml:"-"` - {{end -}} + {{range $column := .Table.Columns -}} + {{titleCase $column.Name}} {{$column.Type}} `{{generateTags $dot.Tags $column.Name}}boil:"{{$column.Name}}" json:"{{$column.Name}}{{if $column.Nullable}},omitempty{{end}}" toml:"{{$column.Name}}" yaml:"{{$column.Name}}{{if $column.Nullable}},omitempty{{end}}"` + {{end -}} + {{- if .Table.IsJoinTable -}} + {{- else}} + R *{{$modelNameCamel}}R `{{generateIgnoreTags $dot.Tags}}boil:"-" json:"-" toml:"-" yaml:"-"` + L {{$modelNameCamel}}L `{{generateIgnoreTags $dot.Tags}}boil:"-" json:"-" toml:"-" yaml:"-"` + {{end -}} } {{- if .Table.IsJoinTable -}} {{- else}} // {{$modelNameCamel}}R is where relationships are stored. type {{$modelNameCamel}}R struct { - {{range .Table.FKeys -}} - {{- $rel := textsFromForeignKey $dot.PkgName $dot.Tables $dot.Table . -}} - {{- template "relationship_to_one_struct_helper" $rel}} - {{end -}} - {{- range .Table.ToManyRelationships -}} - {{- if (and .ForeignColumnUnique (not .ToJoinTable)) -}} - {{- template "relationship_to_one_struct_helper" (textsFromOneToOneRelationship $dot.PkgName $dot.Tables $dot.Table .)}} - {{else -}} - {{- $rel := textsFromRelationship $dot.Tables $dot.Table . -}} - {{$rel.Function.Name}} {{$rel.ForeignTable.Slice}} + {{range .Table.FKeys -}} + {{- $rel := textsFromForeignKey $dot.PkgName $dot.Tables $dot.Table . -}} + {{- template "relationship_to_one_struct_helper" $rel}} + {{end -}} + {{- range .Table.ToManyRelationships -}} + {{- if (and .ForeignColumnUnique (not .ToJoinTable)) -}} + {{- template "relationship_to_one_struct_helper" (textsFromOneToOneRelationship $dot.PkgName $dot.Tables $dot.Table .)}} + {{else -}} + {{- $rel := textsFromRelationship $dot.Tables $dot.Table . -}} + {{$rel.Function.Name}} {{$rel.ForeignTable.Slice}} {{end -}}{{/* if ForeignColumnUnique */}} {{- end -}}{{/* range tomany */}} } diff --git a/templates/01_types.tpl b/templates/01_types.tpl index 790b56026..ecbeafb0f 100644 --- a/templates/01_types.tpl +++ b/templates/01_types.tpl @@ -3,31 +3,31 @@ {{- $varNameSingular := .Table.Name | singular | camelCase -}} {{- $tableNameSingular := .Table.Name | singular | titleCase -}} var ( - {{$varNameSingular}}Columns = []string{{"{"}}{{.Table.Columns | columnNames | stringMap .StringFuncs.quoteWrap | join ", "}}{{"}"}} - {{$varNameSingular}}ColumnsWithoutDefault = []string{{"{"}}{{.Table.Columns | filterColumnsByDefault false | columnNames | stringMap .StringFuncs.quoteWrap | join ","}}{{"}"}} - {{$varNameSingular}}ColumnsWithDefault = []string{{"{"}}{{.Table.Columns | filterColumnsByDefault true | columnNames | stringMap .StringFuncs.quoteWrap | join ","}}{{"}"}} - {{$varNameSingular}}PrimaryKeyColumns = []string{{"{"}}{{.Table.PKey.Columns | stringMap .StringFuncs.quoteWrap | join ", "}}{{"}"}} + {{$varNameSingular}}Columns = []string{{"{"}}{{.Table.Columns | columnNames | stringMap .StringFuncs.quoteWrap | join ", "}}{{"}"}} + {{$varNameSingular}}ColumnsWithoutDefault = []string{{"{"}}{{.Table.Columns | filterColumnsByDefault false | columnNames | stringMap .StringFuncs.quoteWrap | join ","}}{{"}"}} + {{$varNameSingular}}ColumnsWithDefault = []string{{"{"}}{{.Table.Columns | filterColumnsByDefault true | columnNames | stringMap .StringFuncs.quoteWrap | join ","}}{{"}"}} + {{$varNameSingular}}PrimaryKeyColumns = []string{{"{"}}{{.Table.PKey.Columns | stringMap .StringFuncs.quoteWrap | join ", "}}{{"}"}} ) type ( - {{$tableNameSingular}}Slice []*{{$tableNameSingular}} - {{if eq .NoHooks false -}} - {{$tableNameSingular}}Hook func(boil.Executor, *{{$tableNameSingular}}) error - {{- end}} + {{$tableNameSingular}}Slice []*{{$tableNameSingular}} + {{if eq .NoHooks false -}} + {{$tableNameSingular}}Hook func(boil.Executor, *{{$tableNameSingular}}) error + {{- end}} - {{$varNameSingular}}Query struct { - *boil.Query - } + {{$varNameSingular}}Query struct { + *boil.Query + } ) // Cache for insert and update var ( - {{$varNameSingular}}Type = reflect.TypeOf(&{{$tableNameSingular}}{}) - {{$varNameSingular}}Mapping = boil.MakeStructMapping({{$varNameSingular}}Type) - {{$varNameSingular}}InsertCacheMut sync.RWMutex - {{$varNameSingular}}InsertCache = make(map[string]insertCache) - {{$varNameSingular}}UpdateCacheMut sync.RWMutex - {{$varNameSingular}}UpdateCache = make(map[string]updateCache) + {{$varNameSingular}}Type = reflect.TypeOf(&{{$tableNameSingular}}{}) + {{$varNameSingular}}Mapping = boil.MakeStructMapping({{$varNameSingular}}Type) + {{$varNameSingular}}InsertCacheMut sync.RWMutex + {{$varNameSingular}}InsertCache = make(map[string]insertCache) + {{$varNameSingular}}UpdateCacheMut sync.RWMutex + {{$varNameSingular}}UpdateCache = make(map[string]updateCache) ) // Force time package dependency for automated UpdatedAt/CreatedAt. diff --git a/templates/02_hooks.tpl b/templates/02_hooks.tpl index 9e152d1cc..9fa123653 100644 --- a/templates/02_hooks.tpl +++ b/templates/02_hooks.tpl @@ -14,123 +14,123 @@ var {{$varNameSingular}}AfterUpsertHooks []{{$tableNameSingular}}Hook // doBeforeInsertHooks executes all "before insert" hooks. func (o *{{$tableNameSingular}}) doBeforeInsertHooks(exec boil.Executor) (err error) { - for _, hook := range {{$varNameSingular}}BeforeInsertHooks { - if err := hook(exec, o); err != nil { - return err - } - } + for _, hook := range {{$varNameSingular}}BeforeInsertHooks { + if err := hook(exec, o); err != nil { + return err + } + } - return nil + return nil } // doBeforeUpdateHooks executes all "before Update" hooks. func (o *{{$tableNameSingular}}) doBeforeUpdateHooks(exec boil.Executor) (err error) { - for _, hook := range {{$varNameSingular}}BeforeUpdateHooks { - if err := hook(exec, o); err != nil { - return err - } - } + for _, hook := range {{$varNameSingular}}BeforeUpdateHooks { + if err := hook(exec, o); err != nil { + return err + } + } - return nil + return nil } // doBeforeDeleteHooks executes all "before Delete" hooks. func (o *{{$tableNameSingular}}) doBeforeDeleteHooks(exec boil.Executor) (err error) { - for _, hook := range {{$varNameSingular}}BeforeDeleteHooks { - if err := hook(exec, o); err != nil { - return err - } - } + for _, hook := range {{$varNameSingular}}BeforeDeleteHooks { + if err := hook(exec, o); err != nil { + return err + } + } - return nil + return nil } // doBeforeUpsertHooks executes all "before Upsert" hooks. func (o *{{$tableNameSingular}}) doBeforeUpsertHooks(exec boil.Executor) (err error) { - for _, hook := range {{$varNameSingular}}BeforeUpsertHooks { - if err := hook(exec, o); err != nil { - return err - } - } + for _, hook := range {{$varNameSingular}}BeforeUpsertHooks { + if err := hook(exec, o); err != nil { + return err + } + } - return nil + return nil } // doAfterInsertHooks executes all "after Insert" hooks. func (o *{{$tableNameSingular}}) doAfterInsertHooks(exec boil.Executor) (err error) { - for _, hook := range {{$varNameSingular}}AfterInsertHooks { - if err := hook(exec, o); err != nil { - return err - } - } + for _, hook := range {{$varNameSingular}}AfterInsertHooks { + if err := hook(exec, o); err != nil { + return err + } + } - return nil + return nil } // doAfterSelectHooks executes all "after Select" hooks. func (o *{{$tableNameSingular}}) doAfterSelectHooks(exec boil.Executor) (err error) { - for _, hook := range {{$varNameSingular}}AfterSelectHooks { - if err := hook(exec, o); err != nil { - return err - } - } + for _, hook := range {{$varNameSingular}}AfterSelectHooks { + if err := hook(exec, o); err != nil { + return err + } + } - return nil + return nil } // doAfterUpdateHooks executes all "after Update" hooks. func (o *{{$tableNameSingular}}) doAfterUpdateHooks(exec boil.Executor) (err error) { - for _, hook := range {{$varNameSingular}}AfterUpdateHooks { - if err := hook(exec, o); err != nil { - return err - } - } + for _, hook := range {{$varNameSingular}}AfterUpdateHooks { + if err := hook(exec, o); err != nil { + return err + } + } - return nil + return nil } // doAfterDeleteHooks executes all "after Delete" hooks. func (o *{{$tableNameSingular}}) doAfterDeleteHooks(exec boil.Executor) (err error) { - for _, hook := range {{$varNameSingular}}AfterDeleteHooks { - if err := hook(exec, o); err != nil { - return err - } - } + for _, hook := range {{$varNameSingular}}AfterDeleteHooks { + if err := hook(exec, o); err != nil { + return err + } + } - return nil + return nil } // doAfterUpsertHooks executes all "after Upsert" hooks. func (o *{{$tableNameSingular}}) doAfterUpsertHooks(exec boil.Executor) (err error) { - for _, hook := range {{$varNameSingular}}AfterUpsertHooks { - if err := hook(exec, o); err != nil { - return err - } - } + for _, hook := range {{$varNameSingular}}AfterUpsertHooks { + if err := hook(exec, o); err != nil { + return err + } + } - return nil + return nil } func Add{{$tableNameSingular}}Hook(hookPoint boil.HookPoint, {{$varNameSingular}}Hook {{$tableNameSingular}}Hook) { - switch hookPoint { - case boil.BeforeInsertHook: - {{$varNameSingular}}BeforeInsertHooks = append({{$varNameSingular}}BeforeInsertHooks, {{$varNameSingular}}Hook) - case boil.BeforeUpdateHook: - {{$varNameSingular}}BeforeUpdateHooks = append({{$varNameSingular}}BeforeUpdateHooks, {{$varNameSingular}}Hook) - case boil.BeforeDeleteHook: - {{$varNameSingular}}BeforeDeleteHooks = append({{$varNameSingular}}BeforeDeleteHooks, {{$varNameSingular}}Hook) - case boil.BeforeUpsertHook: - {{$varNameSingular}}BeforeUpsertHooks = append({{$varNameSingular}}BeforeUpsertHooks, {{$varNameSingular}}Hook) - case boil.AfterInsertHook: - {{$varNameSingular}}AfterInsertHooks = append({{$varNameSingular}}AfterInsertHooks, {{$varNameSingular}}Hook) - case boil.AfterSelectHook: - {{$varNameSingular}}AfterSelectHooks = append({{$varNameSingular}}AfterSelectHooks, {{$varNameSingular}}Hook) - case boil.AfterUpdateHook: - {{$varNameSingular}}AfterUpdateHooks = append({{$varNameSingular}}AfterUpdateHooks, {{$varNameSingular}}Hook) - case boil.AfterDeleteHook: - {{$varNameSingular}}AfterDeleteHooks = append({{$varNameSingular}}AfterDeleteHooks, {{$varNameSingular}}Hook) - case boil.AfterUpsertHook: - {{$varNameSingular}}AfterUpsertHooks = append({{$varNameSingular}}AfterUpsertHooks, {{$varNameSingular}}Hook) - } + switch hookPoint { + case boil.BeforeInsertHook: + {{$varNameSingular}}BeforeInsertHooks = append({{$varNameSingular}}BeforeInsertHooks, {{$varNameSingular}}Hook) + case boil.BeforeUpdateHook: + {{$varNameSingular}}BeforeUpdateHooks = append({{$varNameSingular}}BeforeUpdateHooks, {{$varNameSingular}}Hook) + case boil.BeforeDeleteHook: + {{$varNameSingular}}BeforeDeleteHooks = append({{$varNameSingular}}BeforeDeleteHooks, {{$varNameSingular}}Hook) + case boil.BeforeUpsertHook: + {{$varNameSingular}}BeforeUpsertHooks = append({{$varNameSingular}}BeforeUpsertHooks, {{$varNameSingular}}Hook) + case boil.AfterInsertHook: + {{$varNameSingular}}AfterInsertHooks = append({{$varNameSingular}}AfterInsertHooks, {{$varNameSingular}}Hook) + case boil.AfterSelectHook: + {{$varNameSingular}}AfterSelectHooks = append({{$varNameSingular}}AfterSelectHooks, {{$varNameSingular}}Hook) + case boil.AfterUpdateHook: + {{$varNameSingular}}AfterUpdateHooks = append({{$varNameSingular}}AfterUpdateHooks, {{$varNameSingular}}Hook) + case boil.AfterDeleteHook: + {{$varNameSingular}}AfterDeleteHooks = append({{$varNameSingular}}AfterDeleteHooks, {{$varNameSingular}}Hook) + case boil.AfterUpsertHook: + {{$varNameSingular}}AfterUpsertHooks = append({{$varNameSingular}}AfterUpsertHooks, {{$varNameSingular}}Hook) + } } {{- end}} diff --git a/templates/03_finishers.tpl b/templates/03_finishers.tpl index a0de3fc1e..1d748e50b 100644 --- a/templates/03_finishers.tpl +++ b/templates/03_finishers.tpl @@ -2,115 +2,115 @@ {{- $varNameSingular := .Table.Name | singular | camelCase -}} // OneP returns a single {{$varNameSingular}} record from the query, and panics on error. func (q {{$varNameSingular}}Query) OneP() (*{{$tableNameSingular}}) { - o, err := q.One() - if err != nil { - panic(boil.WrapErr(err)) - } + o, err := q.One() + if err != nil { + panic(boil.WrapErr(err)) + } - return o + return o } // One returns a single {{$varNameSingular}} record from the query. func (q {{$varNameSingular}}Query) One() (*{{$tableNameSingular}}, error) { - o := &{{$tableNameSingular}}{} + o := &{{$tableNameSingular}}{} - boil.SetLimit(q.Query, 1) + boil.SetLimit(q.Query, 1) - err := q.Bind(o) - if err != nil { - if errors.Cause(err) == sql.ErrNoRows { - return nil, sql.ErrNoRows - } - return nil, errors.Wrap(err, "{{.PkgName}}: failed to execute a one query for {{.Table.Name}}") - } + err := q.Bind(o) + if err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, sql.ErrNoRows + } + return nil, errors.Wrap(err, "{{.PkgName}}: failed to execute a one query for {{.Table.Name}}") + } - {{if not .NoHooks -}} - if err := o.doAfterSelectHooks(boil.GetExecutor(q.Query)); err != nil { - return o, err - } - {{- end}} + {{if not .NoHooks -}} + if err := o.doAfterSelectHooks(boil.GetExecutor(q.Query)); err != nil { + return o, err + } + {{- end}} - return o, nil + return o, nil } // AllP returns all {{$tableNameSingular}} records from the query, and panics on error. func (q {{$varNameSingular}}Query) AllP() {{$tableNameSingular}}Slice { - o, err := q.All() - if err != nil { - panic(boil.WrapErr(err)) - } + o, err := q.All() + if err != nil { + panic(boil.WrapErr(err)) + } - return o + return o } // All returns all {{$tableNameSingular}} records from the query. func (q {{$varNameSingular}}Query) All() ({{$tableNameSingular}}Slice, error) { - var o {{$tableNameSingular}}Slice - - err := q.Bind(&o) - if err != nil { - return nil, errors.Wrap(err, "{{.PkgName}}: failed to assign all query results to {{$tableNameSingular}} slice") - } - - {{if not .NoHooks -}} - if len({{$varNameSingular}}AfterSelectHooks) != 0 { - for _, obj := range o { - if err := obj.doAfterSelectHooks(boil.GetExecutor(q.Query)); err != nil { - return o, err - } - } - } - {{- end}} - - return o, nil + var o {{$tableNameSingular}}Slice + + err := q.Bind(&o) + if err != nil { + return nil, errors.Wrap(err, "{{.PkgName}}: failed to assign all query results to {{$tableNameSingular}} slice") + } + + {{if not .NoHooks -}} + if len({{$varNameSingular}}AfterSelectHooks) != 0 { + for _, obj := range o { + if err := obj.doAfterSelectHooks(boil.GetExecutor(q.Query)); err != nil { + return o, err + } + } + } + {{- end}} + + return o, nil } // CountP returns the count of all {{$tableNameSingular}} records in the query, and panics on error. func (q {{$varNameSingular}}Query) CountP() int64 { - c, err := q.Count() - if err != nil { - panic(boil.WrapErr(err)) - } + c, err := q.Count() + if err != nil { + panic(boil.WrapErr(err)) + } - return c + return c } // Count returns the count of all {{$tableNameSingular}} records in the query. func (q {{$varNameSingular}}Query) Count() (int64, error) { - var count int64 + var count int64 - boil.SetSelect(q.Query, nil) - boil.SetCount(q.Query) + boil.SetSelect(q.Query, nil) + boil.SetCount(q.Query) - err := q.Query.ExecQueryOne().Scan(&count) - if err != nil { - return 0, errors.Wrap(err, "{{.PkgName}}: failed to count {{.Table.Name}} rows") - } + err := q.Query.ExecQueryOne().Scan(&count) + if err != nil { + return 0, errors.Wrap(err, "{{.PkgName}}: failed to count {{.Table.Name}} rows") + } - return count, nil + return count, nil } // Exists checks if the row exists in the table, and panics on error. func (q {{$varNameSingular}}Query) ExistsP() bool { - e, err := q.Exists() - if err != nil { - panic(boil.WrapErr(err)) - } + e, err := q.Exists() + if err != nil { + panic(boil.WrapErr(err)) + } - return e + return e } // Exists checks if the row exists in the table. func (q {{$varNameSingular}}Query) Exists() (bool, error) { - var count int64 + var count int64 - boil.SetCount(q.Query) - boil.SetLimit(q.Query, 1) + boil.SetCount(q.Query) + boil.SetLimit(q.Query, 1) - err := q.Query.ExecQueryOne().Scan(&count) - if err != nil { - return false, errors.Wrap(err, "{{.PkgName}}: failed to check if {{.Table.Name}} exists") - } + err := q.Query.ExecQueryOne().Scan(&count) + if err != nil { + return false, errors.Wrap(err, "{{.PkgName}}: failed to check if {{.Table.Name}} exists") + } - return count > 0, nil + return count > 0, nil } diff --git a/templates/04_relationship_to_one.tpl b/templates/04_relationship_to_one.tpl index 0fdfd8761..02e594cf0 100644 --- a/templates/04_relationship_to_one.tpl +++ b/templates/04_relationship_to_one.tpl @@ -1,34 +1,34 @@ {{- define "relationship_to_one_helper" -}} - {{- $dot := .Dot -}}{{/* .Dot holds the root templateData struct, passed in through preserveDot */}} - {{- with .Rel -}}{{/* Rel holds the text helper data, passed in through preserveDot */}} - {{- $varNameSingular := .ForeignKey.ForeignTable | singular | camelCase -}} + {{- $dot := .Dot -}}{{/* .Dot holds the root templateData struct, passed in through preserveDot */}} + {{- with .Rel -}}{{/* Rel holds the text helper data, passed in through preserveDot */}} + {{- $varNameSingular := .ForeignKey.ForeignTable | singular | camelCase -}} // {{.Function.Name}}G pointed to by the foreign key. func ({{.Function.Receiver}} *{{.LocalTable.NameGo}}) {{.Function.Name}}G(mods ...qm.QueryMod) {{$varNameSingular}}Query { - return {{.Function.Receiver}}.{{.Function.Name}}(boil.GetDB(), mods...) + return {{.Function.Receiver}}.{{.Function.Name}}(boil.GetDB(), mods...) } // {{.Function.Name}} pointed to by the foreign key. func ({{.Function.Receiver}} *{{.LocalTable.NameGo}}) {{.Function.Name}}(exec boil.Executor, mods ...qm.QueryMod) ({{$varNameSingular}}Query) { - queryMods := []qm.QueryMod{ - qm.Where("{{.ForeignTable.ColumnName}}={{if $dot.Dialect.IndexPlaceholders}}$1{{else}}?{{end}}", {{.Function.Receiver}}.{{.LocalTable.ColumnNameGo}}), - } + queryMods := []qm.QueryMod{ + qm.Where("{{.ForeignTable.ColumnName}}={{if $dot.Dialect.IndexPlaceholders}}$1{{else}}?{{end}}", {{.Function.Receiver}}.{{.LocalTable.ColumnNameGo}}), + } - queryMods = append(queryMods, mods...) + queryMods = append(queryMods, mods...) - query := {{.ForeignTable.NamePluralGo}}(exec, queryMods...) - boil.SetFrom(query.Query, "{{.ForeignTable.Name | $dot.SchemaTable}}") + query := {{.ForeignTable.NamePluralGo}}(exec, queryMods...) + boil.SetFrom(query.Query, "{{.ForeignTable.Name | $dot.SchemaTable}}") - return query + return query } - {{- end -}}{{/* end with */}} + {{- end -}}{{/* end with */}} {{end -}}{{/* end define */}} {{- /* Begin execution of template for one-to-one relationship */ -}} {{- if .Table.IsJoinTable -}} {{- else -}} - {{- $dot := . -}} - {{- range .Table.FKeys -}} - {{- $txt := textsFromForeignKey $dot.PkgName $dot.Tables $dot.Table . -}} + {{- $dot := . -}} + {{- range .Table.FKeys -}} + {{- $txt := textsFromForeignKey $dot.PkgName $dot.Tables $dot.Table . -}} {{- template "relationship_to_one_helper" (preserveDot $dot $txt) -}} {{- end -}} {{- end -}} diff --git a/templates/05_relationship_to_many.tpl b/templates/05_relationship_to_many.tpl index b96537738..92b998faf 100644 --- a/templates/05_relationship_to_many.tpl +++ b/templates/05_relationship_to_many.tpl @@ -1,49 +1,49 @@ {{- /* Begin execution of template for many-to-one or many-to-many relationship helper */ -}} {{- if .Table.IsJoinTable -}} {{- else -}} - {{- $dot := . -}} - {{- $table := .Table -}} - {{- range .Table.ToManyRelationships -}} - {{- $varNameSingular := .ForeignTable | singular | camelCase -}} - {{- if (and .ForeignColumnUnique (not .ToJoinTable)) -}} - {{- /* Begin execution of template for many-to-one relationship. */ -}} - {{- $txt := textsFromOneToOneRelationship $dot.PkgName $dot.Tables $table . -}} - {{- template "relationship_to_one_helper" (preserveDot $dot $txt) -}} - {{- else -}} - {{- /* Begin execution of template for many-to-many relationship. */ -}} - {{- $rel := textsFromRelationship $dot.Tables $table . -}} - {{- $schemaForeignTable := .ForeignTable | $dot.SchemaTable -}} + {{- $dot := . -}} + {{- $table := .Table -}} + {{- range .Table.ToManyRelationships -}} + {{- $varNameSingular := .ForeignTable | singular | camelCase -}} + {{- if (and .ForeignColumnUnique (not .ToJoinTable)) -}} + {{- /* Begin execution of template for many-to-one relationship. */ -}} + {{- $txt := textsFromOneToOneRelationship $dot.PkgName $dot.Tables $table . -}} + {{- template "relationship_to_one_helper" (preserveDot $dot $txt) -}} + {{- else -}} + {{- /* Begin execution of template for many-to-many relationship. */ -}} + {{- $rel := textsFromRelationship $dot.Tables $table . -}} + {{- $schemaForeignTable := .ForeignTable | $dot.SchemaTable -}} // {{$rel.Function.Name}}G retrieves all the {{$rel.LocalTable.NameSingular}}'s {{$rel.ForeignTable.NameHumanReadable}} {{- if not (eq $rel.Function.Name $rel.ForeignTable.NamePluralGo)}} via {{.ForeignColumn}} column{{- end}}. func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) {{$rel.Function.Name}}G(mods ...qm.QueryMod) {{$varNameSingular}}Query { - return {{$rel.Function.Receiver}}.{{$rel.Function.Name}}(boil.GetDB(), mods...) + return {{$rel.Function.Receiver}}.{{$rel.Function.Name}}(boil.GetDB(), mods...) } // {{$rel.Function.Name}} retrieves all the {{$rel.LocalTable.NameSingular}}'s {{$rel.ForeignTable.NameHumanReadable}} with an executor {{- if not (eq $rel.Function.Name $rel.ForeignTable.NamePluralGo)}} via {{.ForeignColumn}} column{{- end}}. func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) {{$rel.Function.Name}}(exec boil.Executor, mods ...qm.QueryMod) {{$varNameSingular}}Query { - queryMods := []qm.QueryMod{ - qm.Select("{{id 0 | $dot.Quotes}}.*"), - } + queryMods := []qm.QueryMod{ + qm.Select("{{id 0 | $dot.Quotes}}.*"), + } - if len(mods) != 0 { - queryMods = append(queryMods, mods...) - } + if len(mods) != 0 { + queryMods = append(queryMods, mods...) + } - {{if .ToJoinTable -}} - queryMods = append(queryMods, - qm.InnerJoin("{{.JoinTable | $dot.SchemaTable}} as {{id 1 | $dot.Quotes}} on {{id 0 | $dot.Quotes}}.{{.ForeignColumn | $dot.Quotes}} = {{id 1 | $dot.Quotes}}.{{.JoinForeignColumn | $dot.Quotes}}"), - qm.Where("{{id 1 | $dot.Quotes}}.{{.JoinLocalColumn | $dot.Quotes}}={{if $dot.Dialect.IndexPlaceholders}}$1{{else}}?{{end}}", {{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}), - ) - {{else -}} - queryMods = append(queryMods, - qm.Where("{{id 0 | $dot.Quotes}}.{{.ForeignColumn | $dot.Quotes}}={{if $dot.Dialect.IndexPlaceholders}}$1{{else}}?{{end}}", {{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}), - ) - {{end}} + {{if .ToJoinTable -}} + queryMods = append(queryMods, + qm.InnerJoin("{{.JoinTable | $dot.SchemaTable}} as {{id 1 | $dot.Quotes}} on {{id 0 | $dot.Quotes}}.{{.ForeignColumn | $dot.Quotes}} = {{id 1 | $dot.Quotes}}.{{.JoinForeignColumn | $dot.Quotes}}"), + qm.Where("{{id 1 | $dot.Quotes}}.{{.JoinLocalColumn | $dot.Quotes}}={{if $dot.Dialect.IndexPlaceholders}}$1{{else}}?{{end}}", {{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}), + ) + {{else -}} + queryMods = append(queryMods, + qm.Where("{{id 0 | $dot.Quotes}}.{{.ForeignColumn | $dot.Quotes}}={{if $dot.Dialect.IndexPlaceholders}}$1{{else}}?{{end}}", {{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}), + ) + {{end}} - query := {{$rel.ForeignTable.NamePluralGo}}(exec, queryMods...) - boil.SetFrom(query.Query, "{{$schemaForeignTable}} as {{id 0 | $dot.Quotes}}") - return query + query := {{$rel.ForeignTable.NamePluralGo}}(exec, queryMods...) + boil.SetFrom(query.Query, "{{$schemaForeignTable}} as {{id 0 | $dot.Quotes}}") + return query } {{end -}}{{- /* if unique foreign key */ -}} diff --git a/templates/06_relationship_to_one_eager.tpl b/templates/06_relationship_to_one_eager.tpl index cf851a572..279792fa7 100644 --- a/templates/06_relationship_to_one_eager.tpl +++ b/templates/06_relationship_to_one_eager.tpl @@ -1,93 +1,93 @@ {{- define "relationship_to_one_eager_helper" -}} - {{- $dot := .Dot -}}{{/* .Dot holds the root templateData struct, passed in through preserveDot */}} - {{- $varNameSingular := $dot.Table.Name | singular | camelCase -}} - {{- with .Rel -}} - {{- $arg := printf "maybe%s" .LocalTable.NameGo -}} - {{- $slice := printf "%sSlice" .LocalTable.NameGo -}} + {{- $dot := .Dot -}}{{/* .Dot holds the root templateData struct, passed in through preserveDot */}} + {{- $varNameSingular := $dot.Table.Name | singular | camelCase -}} + {{- with .Rel -}} + {{- $arg := printf "maybe%s" .LocalTable.NameGo -}} + {{- $slice := printf "%sSlice" .LocalTable.NameGo -}} // Load{{.Function.Name}} allows an eager lookup of values, cached into the // loaded structs of the objects. func ({{$varNameSingular}}L) Load{{.Function.Name}}(e boil.Executor, singular bool, {{$arg}} interface{}) error { - var slice []*{{.LocalTable.NameGo}} - var object *{{.LocalTable.NameGo}} + var slice []*{{.LocalTable.NameGo}} + var object *{{.LocalTable.NameGo}} - count := 1 - if singular { - object = {{$arg}}.(*{{.LocalTable.NameGo}}) - } else { - slice = *{{$arg}}.(*{{$slice}}) - count = len(slice) - } + count := 1 + if singular { + object = {{$arg}}.(*{{.LocalTable.NameGo}}) + } else { + slice = *{{$arg}}.(*{{$slice}}) + count = len(slice) + } - args := make([]interface{}, count) - if singular { - args[0] = object.{{.LocalTable.ColumnNameGo}} - } else { - for i, obj := range slice { - args[i] = obj.{{.LocalTable.ColumnNameGo}} - } - } + args := make([]interface{}, count) + if singular { + args[0] = object.{{.LocalTable.ColumnNameGo}} + } else { + for i, obj := range slice { + args[i] = obj.{{.LocalTable.ColumnNameGo}} + } + } - query := fmt.Sprintf( - "select * from {{.ForeignKey.ForeignTable | $dot.SchemaTable}} where {{.ForeignKey.ForeignColumn | $dot.Quotes}} in (%s)", - strmangle.Placeholders(dialect.IndexPlaceholders, count, 1, 1), - ) + query := fmt.Sprintf( + "select * from {{.ForeignKey.ForeignTable | $dot.SchemaTable}} where {{.ForeignKey.ForeignColumn | $dot.Quotes}} in (%s)", + strmangle.Placeholders(dialect.IndexPlaceholders, count, 1, 1), + ) - if boil.DebugMode { - fmt.Fprintf(boil.DebugWriter, "%s\n%v\n", query, args) - } + if boil.DebugMode { + fmt.Fprintf(boil.DebugWriter, "%s\n%v\n", query, args) + } - results, err := e.Query(query, args...) - if err != nil { - return errors.Wrap(err, "failed to eager load {{.ForeignTable.NameGo}}") - } - defer results.Close() + results, err := e.Query(query, args...) + if err != nil { + return errors.Wrap(err, "failed to eager load {{.ForeignTable.NameGo}}") + } + defer results.Close() - var resultSlice []*{{.ForeignTable.NameGo}} - if err = boil.Bind(results, &resultSlice); err != nil { - return errors.Wrap(err, "failed to bind eager loaded slice {{.ForeignTable.NameGo}}") - } + var resultSlice []*{{.ForeignTable.NameGo}} + if err = boil.Bind(results, &resultSlice); err != nil { + return errors.Wrap(err, "failed to bind eager loaded slice {{.ForeignTable.NameGo}}") + } - {{if not $dot.NoHooks -}} - if len({{.ForeignTable.Name | singular | camelCase}}AfterSelectHooks) != 0 { - for _, obj := range resultSlice { - if err := obj.doAfterSelectHooks(e); err != nil { - return err - } - } - } - {{- end}} + {{if not $dot.NoHooks -}} + if len({{.ForeignTable.Name | singular | camelCase}}AfterSelectHooks) != 0 { + for _, obj := range resultSlice { + if err := obj.doAfterSelectHooks(e); err != nil { + return err + } + } + } + {{- end}} - if singular && len(resultSlice) != 0 { - if object.R == nil { - object.R = &{{$varNameSingular}}R{} - } - object.R.{{.Function.Name}} = resultSlice[0] - return nil - } + if singular && len(resultSlice) != 0 { + if object.R == nil { + object.R = &{{$varNameSingular}}R{} + } + object.R.{{.Function.Name}} = resultSlice[0] + return nil + } - for _, foreign := range resultSlice { - for _, local := range slice { - if local.{{.Function.LocalAssignment}} == foreign.{{.Function.ForeignAssignment}} { - if local.R == nil { - local.R = &{{$varNameSingular}}R{} - } - local.R.{{.Function.Name}} = foreign - break - } - } - } + for _, foreign := range resultSlice { + for _, local := range slice { + if local.{{.Function.LocalAssignment}} == foreign.{{.Function.ForeignAssignment}} { + if local.R == nil { + local.R = &{{$varNameSingular}}R{} + } + local.R.{{.Function.Name}} = foreign + break + } + } + } - return nil + return nil } - {{- end -}}{{- /* end with */ -}} + {{- end -}}{{- /* end with */ -}} {{end -}}{{- /* end define */ -}} {{- /* Begin execution of template for one-to-one eager load */ -}} {{- if .Table.IsJoinTable -}} {{- else -}} - {{- $dot := . -}} - {{- range .Table.FKeys -}} - {{- $txt := textsFromForeignKey $dot.PkgName $dot.Tables $dot.Table . -}} - {{- template "relationship_to_one_eager_helper" (preserveDot $dot $txt) -}} - {{- end -}} + {{- $dot := . -}} + {{- range .Table.FKeys -}} + {{- $txt := textsFromForeignKey $dot.PkgName $dot.Tables $dot.Table . -}} + {{- template "relationship_to_one_eager_helper" (preserveDot $dot $txt) -}} + {{- end -}} {{end}} diff --git a/templates/07_relationship_to_many_eager.tpl b/templates/07_relationship_to_many_eager.tpl index 21a6f45d0..82b7a99ba 100644 --- a/templates/07_relationship_to_many_eager.tpl +++ b/templates/07_relationship_to_many_eager.tpl @@ -1,139 +1,139 @@ {{- /* Begin execution of template for many-to-one or many-to-many eager load */ -}} {{- if .Table.IsJoinTable -}} {{- else -}} - {{- $dot := . -}} - {{- range .Table.ToManyRelationships -}} - {{- if (and .ForeignColumnUnique (not .ToJoinTable)) -}} - {{- /* Begin execution of template for many-to-one eager load */ -}} - {{- $txt := textsFromOneToOneRelationship $dot.PkgName $dot.Tables $dot.Table . -}} - {{- template "relationship_to_one_eager_helper" (preserveDot $dot $txt) -}} - {{- else -}} - {{- /* Begin execution of template for many-to-many eager load */ -}} - {{- $varNameSingular := $dot.Table.Name | singular | camelCase -}} - {{- $txt := textsFromRelationship $dot.Tables $dot.Table . -}} - {{- $arg := printf "maybe%s" $txt.LocalTable.NameGo -}} - {{- $slice := printf "%sSlice" $txt.LocalTable.NameGo -}} - {{- $schemaForeignTable := .ForeignTable | $dot.SchemaTable -}} + {{- $dot := . -}} + {{- range .Table.ToManyRelationships -}} + {{- if (and .ForeignColumnUnique (not .ToJoinTable)) -}} + {{- /* Begin execution of template for many-to-one eager load */ -}} + {{- $txt := textsFromOneToOneRelationship $dot.PkgName $dot.Tables $dot.Table . -}} + {{- template "relationship_to_one_eager_helper" (preserveDot $dot $txt) -}} + {{- else -}} + {{- /* Begin execution of template for many-to-many eager load */ -}} + {{- $varNameSingular := $dot.Table.Name | singular | camelCase -}} + {{- $txt := textsFromRelationship $dot.Tables $dot.Table . -}} + {{- $arg := printf "maybe%s" $txt.LocalTable.NameGo -}} + {{- $slice := printf "%sSlice" $txt.LocalTable.NameGo -}} + {{- $schemaForeignTable := .ForeignTable | $dot.SchemaTable -}} // Load{{$txt.Function.Name}} allows an eager lookup of values, cached into the // loaded structs of the objects. func ({{$varNameSingular}}L) Load{{$txt.Function.Name}}(e boil.Executor, singular bool, {{$arg}} interface{}) error { - var slice []*{{$txt.LocalTable.NameGo}} - var object *{{$txt.LocalTable.NameGo}} + var slice []*{{$txt.LocalTable.NameGo}} + var object *{{$txt.LocalTable.NameGo}} - count := 1 - if singular { - object = {{$arg}}.(*{{$txt.LocalTable.NameGo}}) - } else { - slice = *{{$arg}}.(*{{$slice}}) - count = len(slice) - } + count := 1 + if singular { + object = {{$arg}}.(*{{$txt.LocalTable.NameGo}}) + } else { + slice = *{{$arg}}.(*{{$slice}}) + count = len(slice) + } - args := make([]interface{}, count) - if singular { - args[0] = object.{{.Column | titleCase}} - } else { - for i, obj := range slice { - args[i] = obj.{{.Column | titleCase}} - } - } + args := make([]interface{}, count) + if singular { + args[0] = object.{{.Column | titleCase}} + } else { + for i, obj := range slice { + args[i] = obj.{{.Column | titleCase}} + } + } - {{if .ToJoinTable -}} - {{- $schemaJoinTable := .JoinTable | $dot.SchemaTable -}} - query := fmt.Sprintf( - "select {{id 0 | $dot.Quotes}}.*, {{id 1 | $dot.Quotes}}.{{.JoinLocalColumn | $dot.Quotes}} from {{$schemaForeignTable}} as {{id 0 | $dot.Quotes}} inner join {{$schemaJoinTable}} as {{id 1 | $dot.Quotes}} on {{id 0 | $dot.Quotes}}.{{.ForeignColumn | $dot.Quotes}} = {{id 1 | $dot.Quotes}}.{{.JoinForeignColumn | $dot.Quotes}} where {{id 1 | $dot.Quotes}}.{{.JoinLocalColumn | $dot.Quotes}} in (%s)", - strmangle.Placeholders(dialect.IndexPlaceholders, count, 1, 1), - ) - {{else -}} - query := fmt.Sprintf( - "select * from {{$schemaForeignTable}} where {{.ForeignColumn | $dot.Quotes}} in (%s)", - strmangle.Placeholders(dialect.IndexPlaceholders, count, 1, 1), - ) - {{end -}} + {{if .ToJoinTable -}} + {{- $schemaJoinTable := .JoinTable | $dot.SchemaTable -}} + query := fmt.Sprintf( + "select {{id 0 | $dot.Quotes}}.*, {{id 1 | $dot.Quotes}}.{{.JoinLocalColumn | $dot.Quotes}} from {{$schemaForeignTable}} as {{id 0 | $dot.Quotes}} inner join {{$schemaJoinTable}} as {{id 1 | $dot.Quotes}} on {{id 0 | $dot.Quotes}}.{{.ForeignColumn | $dot.Quotes}} = {{id 1 | $dot.Quotes}}.{{.JoinForeignColumn | $dot.Quotes}} where {{id 1 | $dot.Quotes}}.{{.JoinLocalColumn | $dot.Quotes}} in (%s)", + strmangle.Placeholders(dialect.IndexPlaceholders, count, 1, 1), + ) + {{else -}} + query := fmt.Sprintf( + "select * from {{$schemaForeignTable}} where {{.ForeignColumn | $dot.Quotes}} in (%s)", + strmangle.Placeholders(dialect.IndexPlaceholders, count, 1, 1), + ) + {{end -}} - if boil.DebugMode { - fmt.Fprintf(boil.DebugWriter, "%s\n%v\n", query, args) - } + if boil.DebugMode { + fmt.Fprintf(boil.DebugWriter, "%s\n%v\n", query, args) + } - results, err := e.Query(query, args...) - if err != nil { - return errors.Wrap(err, "failed to eager load {{.ForeignTable}}") - } - defer results.Close() + results, err := e.Query(query, args...) + if err != nil { + return errors.Wrap(err, "failed to eager load {{.ForeignTable}}") + } + defer results.Close() - var resultSlice []*{{$txt.ForeignTable.NameGo}} - {{if .ToJoinTable -}} - {{- $foreignTable := getTable $dot.Tables .ForeignTable -}} - {{- $joinTable := getTable $dot.Tables .JoinTable -}} - {{- $localCol := $joinTable.GetColumn .JoinLocalColumn}} - var localJoinCols []{{$localCol.Type}} - for results.Next() { - one := new({{$txt.ForeignTable.NameGo}}) - var localJoinCol {{$localCol.Type}} + var resultSlice []*{{$txt.ForeignTable.NameGo}} + {{if .ToJoinTable -}} + {{- $foreignTable := getTable $dot.Tables .ForeignTable -}} + {{- $joinTable := getTable $dot.Tables .JoinTable -}} + {{- $localCol := $joinTable.GetColumn .JoinLocalColumn}} + var localJoinCols []{{$localCol.Type}} + for results.Next() { + one := new({{$txt.ForeignTable.NameGo}}) + var localJoinCol {{$localCol.Type}} - err = results.Scan({{$foreignTable.Columns | columnNames | stringMap $dot.StringFuncs.titleCase | prefixStringSlice "&one." | join ", "}}, &localJoinCol) - if err = results.Err(); err != nil { - return errors.Wrap(err, "failed to plebian-bind eager loaded slice {{.ForeignTable}}") - } + err = results.Scan({{$foreignTable.Columns | columnNames | stringMap $dot.StringFuncs.titleCase | prefixStringSlice "&one." | join ", "}}, &localJoinCol) + if err = results.Err(); err != nil { + return errors.Wrap(err, "failed to plebian-bind eager loaded slice {{.ForeignTable}}") + } - resultSlice = append(resultSlice, one) - localJoinCols = append(localJoinCols, localJoinCol) - } + resultSlice = append(resultSlice, one) + localJoinCols = append(localJoinCols, localJoinCol) + } - if err = results.Err(); err != nil { - return errors.Wrap(err, "failed to plebian-bind eager loaded slice {{.ForeignTable}}") - } - {{else -}} - if err = boil.Bind(results, &resultSlice); err != nil { - return errors.Wrap(err, "failed to bind eager loaded slice {{.ForeignTable}}") - } - {{end}} + if err = results.Err(); err != nil { + return errors.Wrap(err, "failed to plebian-bind eager loaded slice {{.ForeignTable}}") + } + {{else -}} + if err = boil.Bind(results, &resultSlice); err != nil { + return errors.Wrap(err, "failed to bind eager loaded slice {{.ForeignTable}}") + } + {{end}} - {{if not $dot.NoHooks -}} - if len({{.ForeignTable | singular | camelCase}}AfterSelectHooks) != 0 { - for _, obj := range resultSlice { - if err := obj.doAfterSelectHooks(e); err != nil { - return err - } - } - } + {{if not $dot.NoHooks -}} + if len({{.ForeignTable | singular | camelCase}}AfterSelectHooks) != 0 { + for _, obj := range resultSlice { + if err := obj.doAfterSelectHooks(e); err != nil { + return err + } + } + } - {{- end}} - if singular { - if object.R == nil { - object.R = &{{$varNameSingular}}R{} - } - object.R.{{$txt.Function.Name}} = resultSlice - return nil - } + {{- end}} + if singular { + if object.R == nil { + object.R = &{{$varNameSingular}}R{} + } + object.R.{{$txt.Function.Name}} = resultSlice + return nil + } - {{if .ToJoinTable -}} - for i, foreign := range resultSlice { - localJoinCol := localJoinCols[i] - for _, local := range slice { - if local.{{$txt.Function.LocalAssignment}} == localJoinCol { - if local.R == nil { - local.R = &{{$varNameSingular}}R{} - } - local.R.{{$txt.Function.Name}} = append(local.R.{{$txt.Function.Name}}, foreign) - break - } - } - } - {{else -}} - for _, foreign := range resultSlice { - for _, local := range slice { - if local.{{$txt.Function.LocalAssignment}} == foreign.{{$txt.Function.ForeignAssignment}} { - if local.R == nil { - local.R = &{{$varNameSingular}}R{} - } - local.R.{{$txt.Function.Name}} = append(local.R.{{$txt.Function.Name}}, foreign) - break - } - } - } - {{end}} + {{if .ToJoinTable -}} + for i, foreign := range resultSlice { + localJoinCol := localJoinCols[i] + for _, local := range slice { + if local.{{$txt.Function.LocalAssignment}} == localJoinCol { + if local.R == nil { + local.R = &{{$varNameSingular}}R{} + } + local.R.{{$txt.Function.Name}} = append(local.R.{{$txt.Function.Name}}, foreign) + break + } + } + } + {{else -}} + for _, foreign := range resultSlice { + for _, local := range slice { + if local.{{$txt.Function.LocalAssignment}} == foreign.{{$txt.Function.ForeignAssignment}} { + if local.R == nil { + local.R = &{{$varNameSingular}}R{} + } + local.R.{{$txt.Function.Name}} = append(local.R.{{$txt.Function.Name}}, foreign) + break + } + } + } + {{end}} - return nil + return nil } {{end -}}{{/* if ForeignColumnUnique */}} diff --git a/templates/08_relationship_to_one_setops.tpl b/templates/08_relationship_to_one_setops.tpl index 66a6d8959..e4e93d6fa 100644 --- a/templates/08_relationship_to_one_setops.tpl +++ b/templates/08_relationship_to_one_setops.tpl @@ -1,105 +1,105 @@ {{- define "relationship_to_one_setops_helper" -}} - {{- $tmplData := .Dot -}}{{/* .Dot holds the root templateData struct, passed in through preserveDot */}} - {{- with .Rel -}} - {{- $varNameSingular := .ForeignKey.ForeignTable | singular | camelCase -}} - {{- $localNameSingular := .ForeignKey.Table | singular | camelCase}} + {{- $tmplData := .Dot -}}{{/* .Dot holds the root templateData struct, passed in through preserveDot */}} + {{- with .Rel -}} + {{- $varNameSingular := .ForeignKey.ForeignTable | singular | camelCase -}} + {{- $localNameSingular := .ForeignKey.Table | singular | camelCase}} // Set{{.Function.Name}} of the {{.ForeignKey.Table | singular}} to the related item. // Sets {{.Function.Receiver}}.R.{{.Function.Name}} to related. // Adds {{.Function.Receiver}} to related.R.{{.Function.ForeignName}}. func ({{.Function.Receiver}} *{{.LocalTable.NameGo}}) Set{{.Function.Name}}(exec boil.Executor, insert bool, related *{{.ForeignTable.NameGo}}) error { - var err error - if insert { - if err = related.Insert(exec); err != nil { - return errors.Wrap(err, "failed to insert into foreign table") - } - } + var err error + if insert { + if err = related.Insert(exec); err != nil { + return errors.Wrap(err, "failed to insert into foreign table") + } + } - oldVal := {{.Function.Receiver}}.{{.LocalTable.ColumnNameGo}} - {{.Function.Receiver}}.{{.Function.LocalAssignment}} = related.{{.Function.ForeignAssignment}} - if err = {{.Function.Receiver}}.Update(exec, "{{.ForeignKey.Column}}"); err != nil { - {{.Function.Receiver}}.{{.LocalTable.ColumnNameGo}} = oldVal - return errors.Wrap(err, "failed to update local table") - } + oldVal := {{.Function.Receiver}}.{{.LocalTable.ColumnNameGo}} + {{.Function.Receiver}}.{{.Function.LocalAssignment}} = related.{{.Function.ForeignAssignment}} + if err = {{.Function.Receiver}}.Update(exec, "{{.ForeignKey.Column}}"); err != nil { + {{.Function.Receiver}}.{{.LocalTable.ColumnNameGo}} = oldVal + return errors.Wrap(err, "failed to update local table") + } - if {{.Function.Receiver}}.R == nil { - {{.Function.Receiver}}.R = &{{$localNameSingular}}R{ - {{.Function.Name}}: related, - } - } else { - {{.Function.Receiver}}.R.{{.Function.Name}} = related - } + if {{.Function.Receiver}}.R == nil { + {{.Function.Receiver}}.R = &{{$localNameSingular}}R{ + {{.Function.Name}}: related, + } + } else { + {{.Function.Receiver}}.R.{{.Function.Name}} = related + } - {{if (or .ForeignKey.Unique .Function.OneToOne) -}} - if related.R == nil { - related.R = &{{$varNameSingular}}R{ - {{.Function.ForeignName}}: {{.Function.Receiver}}, - } - } else { - related.R.{{.Function.ForeignName}} = {{.Function.Receiver}} - } - {{else -}} - if related.R == nil { - related.R = &{{$varNameSingular}}R{ - {{.Function.ForeignName}}: {{.LocalTable.NameGo}}Slice{{"{"}}{{.Function.Receiver}}{{"}"}}, - } - } else { - related.R.{{.Function.ForeignName}} = append(related.R.{{.Function.ForeignName}}, {{.Function.Receiver}}) - } - {{end -}} + {{if (or .ForeignKey.Unique .Function.OneToOne) -}} + if related.R == nil { + related.R = &{{$varNameSingular}}R{ + {{.Function.ForeignName}}: {{.Function.Receiver}}, + } + } else { + related.R.{{.Function.ForeignName}} = {{.Function.Receiver}} + } + {{else -}} + if related.R == nil { + related.R = &{{$varNameSingular}}R{ + {{.Function.ForeignName}}: {{.LocalTable.NameGo}}Slice{{"{"}}{{.Function.Receiver}}{{"}"}}, + } + } else { + related.R.{{.Function.ForeignName}} = append(related.R.{{.Function.ForeignName}}, {{.Function.Receiver}}) + } + {{end -}} - {{if .ForeignKey.Nullable}} - {{.Function.Receiver}}.{{.LocalTable.ColumnNameGo}}.Valid = true - {{end -}} - return nil + {{if .ForeignKey.Nullable}} + {{.Function.Receiver}}.{{.LocalTable.ColumnNameGo}}.Valid = true + {{end -}} + return nil } - {{- if .ForeignKey.Nullable}} + {{- if .ForeignKey.Nullable}} // Remove{{.Function.Name}} relationship. // Sets {{.Function.Receiver}}.R.{{.Function.Name}} to nil. // Removes {{.Function.Receiver}} from all passed in related items' relationships struct (Optional). func ({{.Function.Receiver}} *{{.LocalTable.NameGo}}) Remove{{.Function.Name}}(exec boil.Executor, related *{{.ForeignTable.NameGo}}) error { - var err error + var err error - {{.Function.Receiver}}.{{.LocalTable.ColumnNameGo}}.Valid = false - if err = {{.Function.Receiver}}.Update(exec, "{{.ForeignKey.Column}}"); err != nil { - {{.Function.Receiver}}.{{.LocalTable.ColumnNameGo}}.Valid = true - return errors.Wrap(err, "failed to update local table") - } + {{.Function.Receiver}}.{{.LocalTable.ColumnNameGo}}.Valid = false + if err = {{.Function.Receiver}}.Update(exec, "{{.ForeignKey.Column}}"); err != nil { + {{.Function.Receiver}}.{{.LocalTable.ColumnNameGo}}.Valid = true + return errors.Wrap(err, "failed to update local table") + } - {{.Function.Receiver}}.R.{{.Function.Name}} = nil - if related == nil || related.R == nil { - return nil - } + {{.Function.Receiver}}.R.{{.Function.Name}} = nil + if related == nil || related.R == nil { + return nil + } - {{if .ForeignKey.Unique -}} - related.R.{{.Function.ForeignName}} = nil - {{else -}} - for i, ri := range related.R.{{.Function.ForeignName}} { - if {{.Function.Receiver}}.{{.Function.LocalAssignment}} != ri.{{.Function.LocalAssignment}} { - continue - } + {{if .ForeignKey.Unique -}} + related.R.{{.Function.ForeignName}} = nil + {{else -}} + for i, ri := range related.R.{{.Function.ForeignName}} { + if {{.Function.Receiver}}.{{.Function.LocalAssignment}} != ri.{{.Function.LocalAssignment}} { + continue + } - ln := len(related.R.{{.Function.ForeignName}}) - if ln > 1 && i < ln-1 { - related.R.{{.Function.ForeignName}}[i] = related.R.{{.Function.ForeignName}}[ln-1] - } - related.R.{{.Function.ForeignName}} = related.R.{{.Function.ForeignName}}[:ln-1] - break - } - {{end -}} + ln := len(related.R.{{.Function.ForeignName}}) + if ln > 1 && i < ln-1 { + related.R.{{.Function.ForeignName}}[i] = related.R.{{.Function.ForeignName}}[ln-1] + } + related.R.{{.Function.ForeignName}} = related.R.{{.Function.ForeignName}}[:ln-1] + break + } + {{end -}} - return nil + return nil } - {{- end -}}{{/* if foreignkey nullable */}} - {{end -}}{{/* end with */}} + {{- end -}}{{/* if foreignkey nullable */}} + {{end -}}{{/* end with */}} {{- end -}}{{/* end define */}} {{- /* Begin execution of template for one-to-one setops */ -}} {{- if .Table.IsJoinTable -}} {{- else -}} - {{- $dot := . -}} - {{- range .Table.FKeys -}} - {{- $txt := textsFromForeignKey $dot.PkgName $dot.Tables $dot.Table . -}} - {{- template "relationship_to_one_setops_helper" (preserveDot $dot $txt) -}} - {{- end -}} + {{- $dot := . -}} + {{- range .Table.FKeys -}} + {{- $txt := textsFromForeignKey $dot.PkgName $dot.Tables $dot.Table . -}} + {{- template "relationship_to_one_setops_helper" (preserveDot $dot $txt) -}} + {{- end -}} {{- end -}} diff --git a/templates/09_relationship_to_many_setops.tpl b/templates/09_relationship_to_many_setops.tpl index 0280c31aa..b42842874 100644 --- a/templates/09_relationship_to_many_setops.tpl +++ b/templates/09_relationship_to_many_setops.tpl @@ -1,93 +1,93 @@ {{- /* Begin execution of template for many-to-one or many-to-many setops */ -}} {{- if .Table.IsJoinTable -}} {{- else -}} - {{- $dot := . -}} - {{- $table := .Table -}} - {{- range .Table.ToManyRelationships -}} - {{- $varNameSingular := .ForeignTable | singular | camelCase -}} - {{- if (and .ForeignColumnUnique (not .ToJoinTable)) -}} - {{- /* Begin execution of template for many-to-one setops */ -}} - {{- $txt := textsFromOneToOneRelationship $dot.PkgName $dot.Tables $table . -}} - {{- template "relationship_to_one_setops_helper" (preserveDot $dot $txt) -}} - {{- else -}} - {{- $rel := textsFromRelationship $dot.Tables $table . -}} - {{- $localNameSingular := .Table | singular | camelCase -}} - {{- $foreignNameSingular := .ForeignTable | singular | camelCase}} + {{- $dot := . -}} + {{- $table := .Table -}} + {{- range .Table.ToManyRelationships -}} + {{- $varNameSingular := .ForeignTable | singular | camelCase -}} + {{- if (and .ForeignColumnUnique (not .ToJoinTable)) -}} + {{- /* Begin execution of template for many-to-one setops */ -}} + {{- $txt := textsFromOneToOneRelationship $dot.PkgName $dot.Tables $table . -}} + {{- template "relationship_to_one_setops_helper" (preserveDot $dot $txt) -}} + {{- else -}} + {{- $rel := textsFromRelationship $dot.Tables $table . -}} + {{- $localNameSingular := .Table | singular | camelCase -}} + {{- $foreignNameSingular := .ForeignTable | singular | camelCase}} // Add{{$rel.Function.Name}} adds the given related objects to the existing relationships // of the {{$table.Name | singular}}, optionally inserting them as new records. // Appends related to {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}}. // Sets related.R.{{$rel.Function.ForeignName}} appropriately. func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) Add{{$rel.Function.Name}}(exec boil.Executor, insert bool, related ...*{{$rel.ForeignTable.NameGo}}) error { - var err error - for _, rel := range related { - {{if not .ToJoinTable -}} - rel.{{$rel.Function.ForeignAssignment}} = {{$rel.Function.Receiver}}.{{$rel.Function.LocalAssignment}} - {{if .ForeignColumnNullable -}} - rel.{{$rel.ForeignTable.ColumnNameGo}}.Valid = true - {{end -}} - {{end -}} - if insert { - if err = rel.Insert(exec); err != nil { - return errors.Wrap(err, "failed to insert into foreign table") - } - }{{if not .ToJoinTable}} else { - if err = rel.Update(exec, "{{.ForeignColumn}}"); err != nil { - return errors.Wrap(err, "failed to update foreign table") - } - }{{end -}} - } - - {{if .ToJoinTable -}} - for _, rel := range related { - query := "insert into {{.JoinTable | $dot.SchemaTable}} ({{.JoinLocalColumn | $dot.Quotes}}, {{.JoinForeignColumn | $dot.Quotes}}) values {{if $dot.Dialect.IndexPlaceholders}}($1, $2){{else}}(?, ?){{end}}" - values := []interface{}{{"{"}}{{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}, rel.{{$rel.ForeignTable.ColumnNameGo}}} - - if boil.DebugMode { - fmt.Fprintln(boil.DebugWriter, query) - fmt.Fprintln(boil.DebugWriter, values) - } - - _, err = exec.Exec(query, values...) - if err != nil { - return errors.Wrap(err, "failed to insert into join table") - } - } - {{end -}} - - if {{$rel.Function.Receiver}}.R == nil { - {{$rel.Function.Receiver}}.R = &{{$localNameSingular}}R{ - {{$rel.Function.Name}}: related, - } - } else { - {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}} = append({{$rel.Function.Receiver}}.R.{{$rel.Function.Name}}, related...) - } - - {{if .ToJoinTable -}} - for _, rel := range related { - if rel.R == nil { - rel.R = &{{$foreignNameSingular}}R{ - {{$rel.Function.ForeignName}}: {{$rel.LocalTable.NameGo}}Slice{{"{"}}{{$rel.Function.Receiver}}{{"}"}}, - } - } else { - rel.R.{{$rel.Function.ForeignName}} = append(rel.R.{{$rel.Function.ForeignName}}, {{$rel.Function.Receiver}}) - } - } - {{else -}} - for _, rel := range related { - if rel.R == nil { - rel.R = &{{$foreignNameSingular}}R{ - {{$rel.Function.ForeignName}}: {{$rel.Function.Receiver}}, - } - } else { - rel.R.{{$rel.Function.ForeignName}} = {{$rel.Function.Receiver}} - } - } - {{end -}} - - return nil + var err error + for _, rel := range related { + {{if not .ToJoinTable -}} + rel.{{$rel.Function.ForeignAssignment}} = {{$rel.Function.Receiver}}.{{$rel.Function.LocalAssignment}} + {{if .ForeignColumnNullable -}} + rel.{{$rel.ForeignTable.ColumnNameGo}}.Valid = true + {{end -}} + {{end -}} + if insert { + if err = rel.Insert(exec); err != nil { + return errors.Wrap(err, "failed to insert into foreign table") + } + }{{if not .ToJoinTable}} else { + if err = rel.Update(exec, "{{.ForeignColumn}}"); err != nil { + return errors.Wrap(err, "failed to update foreign table") + } + }{{end -}} + } + + {{if .ToJoinTable -}} + for _, rel := range related { + query := "insert into {{.JoinTable | $dot.SchemaTable}} ({{.JoinLocalColumn | $dot.Quotes}}, {{.JoinForeignColumn | $dot.Quotes}}) values {{if $dot.Dialect.IndexPlaceholders}}($1, $2){{else}}(?, ?){{end}}" + values := []interface{}{{"{"}}{{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}, rel.{{$rel.ForeignTable.ColumnNameGo}}} + + if boil.DebugMode { + fmt.Fprintln(boil.DebugWriter, query) + fmt.Fprintln(boil.DebugWriter, values) + } + + _, err = exec.Exec(query, values...) + if err != nil { + return errors.Wrap(err, "failed to insert into join table") + } + } + {{end -}} + + if {{$rel.Function.Receiver}}.R == nil { + {{$rel.Function.Receiver}}.R = &{{$localNameSingular}}R{ + {{$rel.Function.Name}}: related, + } + } else { + {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}} = append({{$rel.Function.Receiver}}.R.{{$rel.Function.Name}}, related...) + } + + {{if .ToJoinTable -}} + for _, rel := range related { + if rel.R == nil { + rel.R = &{{$foreignNameSingular}}R{ + {{$rel.Function.ForeignName}}: {{$rel.LocalTable.NameGo}}Slice{{"{"}}{{$rel.Function.Receiver}}{{"}"}}, + } + } else { + rel.R.{{$rel.Function.ForeignName}} = append(rel.R.{{$rel.Function.ForeignName}}, {{$rel.Function.Receiver}}) + } + } + {{else -}} + for _, rel := range related { + if rel.R == nil { + rel.R = &{{$foreignNameSingular}}R{ + {{$rel.Function.ForeignName}}: {{$rel.Function.Receiver}}, + } + } else { + rel.R.{{$rel.Function.ForeignName}} = {{$rel.Function.Receiver}} + } + } + {{end -}} + + return nil } - {{- if (or .ForeignColumnNullable .ToJoinTable)}} + {{- if (or .ForeignColumnNullable .ToJoinTable)}} // Set{{$rel.Function.Name}} removes all previously related items of the // {{$table.Name | singular}} replacing them completely with the passed // in related items, optionally inserting them as new records. @@ -95,126 +95,126 @@ func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) Add{{$rel.Function // Replaces {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}} with related. // Sets related.R.{{$rel.Function.ForeignName}}'s {{$rel.Function.Name}} accordingly. func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) Set{{$rel.Function.Name}}(exec boil.Executor, insert bool, related ...*{{$rel.ForeignTable.NameGo}}) error { - {{if .ToJoinTable -}} - query := "delete from {{.JoinTable | $dot.SchemaTable}} where {{.JoinLocalColumn | $dot.Quotes}} = {{if $dot.Dialect.IndexPlaceholders}}$1{{else}}?{{end}}" - values := []interface{}{{"{"}}{{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}} - {{else -}} - query := "update {{.ForeignTable | $dot.SchemaTable}} set {{.ForeignColumn | $dot.Quotes}} = null where {{.ForeignColumn | $dot.Quotes}} = {{if $dot.Dialect.IndexPlaceholders}}$1{{else}}?{{end}}" - values := []interface{}{{"{"}}{{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}} - {{end -}} - if boil.DebugMode { - fmt.Fprintln(boil.DebugWriter, query) - fmt.Fprintln(boil.DebugWriter, values) - } - - _, err := exec.Exec(query, values...) - if err != nil { - return errors.Wrap(err, "failed to remove relationships before set") - } - - {{if .ToJoinTable -}} - remove{{$rel.LocalTable.NameGo}}From{{$rel.ForeignTable.NameGo}}Slice({{$rel.Function.Receiver}}, related) - {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}} = nil - {{else -}} - if {{$rel.Function.Receiver}}.R != nil { - for _, rel := range {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}} { - rel.{{$rel.ForeignTable.ColumnNameGo}}.Valid = false - if rel.R == nil { - continue - } - - rel.R.{{$rel.Function.ForeignName}} = nil - } - - {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}} = nil - } - {{end -}} - - return {{$rel.Function.Receiver}}.Add{{$rel.Function.Name}}(exec, insert, related...) + {{if .ToJoinTable -}} + query := "delete from {{.JoinTable | $dot.SchemaTable}} where {{.JoinLocalColumn | $dot.Quotes}} = {{if $dot.Dialect.IndexPlaceholders}}$1{{else}}?{{end}}" + values := []interface{}{{"{"}}{{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}} + {{else -}} + query := "update {{.ForeignTable | $dot.SchemaTable}} set {{.ForeignColumn | $dot.Quotes}} = null where {{.ForeignColumn | $dot.Quotes}} = {{if $dot.Dialect.IndexPlaceholders}}$1{{else}}?{{end}}" + values := []interface{}{{"{"}}{{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}} + {{end -}} + if boil.DebugMode { + fmt.Fprintln(boil.DebugWriter, query) + fmt.Fprintln(boil.DebugWriter, values) + } + + _, err := exec.Exec(query, values...) + if err != nil { + return errors.Wrap(err, "failed to remove relationships before set") + } + + {{if .ToJoinTable -}} + remove{{$rel.LocalTable.NameGo}}From{{$rel.ForeignTable.NameGo}}Slice({{$rel.Function.Receiver}}, related) + {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}} = nil + {{else -}} + if {{$rel.Function.Receiver}}.R != nil { + for _, rel := range {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}} { + rel.{{$rel.ForeignTable.ColumnNameGo}}.Valid = false + if rel.R == nil { + continue + } + + rel.R.{{$rel.Function.ForeignName}} = nil + } + + {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}} = nil + } + {{end -}} + + return {{$rel.Function.Receiver}}.Add{{$rel.Function.Name}}(exec, insert, related...) } // Remove{{$rel.Function.Name}} relationships from objects passed in. // Removes related items from R.{{$rel.Function.Name}} (uses pointer comparison, removal does not keep order) // Sets related.R.{{$rel.Function.ForeignName}}. func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) Remove{{$rel.Function.Name}}(exec boil.Executor, related ...*{{$rel.ForeignTable.NameGo}}) error { - var err error - {{if .ToJoinTable -}} - query := fmt.Sprintf( - "delete from {{.JoinTable | $dot.SchemaTable}} where {{.JoinLocalColumn | $dot.Quotes}} = {{if $dot.Dialect.IndexPlaceholders}}$1{{else}}?{{end}} and {{.JoinForeignColumn | $dot.Quotes}} in (%s)", - strmangle.Placeholders(dialect.IndexPlaceholders, len(related), 1, 1), - ) - values := []interface{}{{"{"}}{{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}} - - if boil.DebugMode { - fmt.Fprintln(boil.DebugWriter, query) - fmt.Fprintln(boil.DebugWriter, values) - } - - _, err = exec.Exec(query, values...) - if err != nil { - return errors.Wrap(err, "failed to remove relationships before set") - } - {{else -}} - for _, rel := range related { - rel.{{$rel.ForeignTable.ColumnNameGo}}.Valid = false - {{if not .ToJoinTable -}} - if rel.R != nil { - rel.R.{{$rel.Function.ForeignName}} = nil - } - {{end -}} - if err = rel.Update(exec, "{{.ForeignColumn}}"); err != nil { - return err - } - } - {{end -}} - - {{if .ToJoinTable -}} - remove{{$rel.LocalTable.NameGo}}From{{$rel.ForeignTable.NameGo}}Slice({{$rel.Function.Receiver}}, related) - {{end -}} - if {{$rel.Function.Receiver}}.R == nil { - return nil - } - - for _, rel := range related { - for i, ri := range {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}} { - if rel != ri { - continue - } - - ln := len({{$rel.Function.Receiver}}.R.{{$rel.Function.Name}}) - if ln > 1 && i < ln-1 { - {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}}[i] = {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}}[ln-1] - } - {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}} = {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}}[:ln-1] - break - } - } - - return nil + var err error + {{if .ToJoinTable -}} + query := fmt.Sprintf( + "delete from {{.JoinTable | $dot.SchemaTable}} where {{.JoinLocalColumn | $dot.Quotes}} = {{if $dot.Dialect.IndexPlaceholders}}$1{{else}}?{{end}} and {{.JoinForeignColumn | $dot.Quotes}} in (%s)", + strmangle.Placeholders(dialect.IndexPlaceholders, len(related), 1, 1), + ) + values := []interface{}{{"{"}}{{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}} + + if boil.DebugMode { + fmt.Fprintln(boil.DebugWriter, query) + fmt.Fprintln(boil.DebugWriter, values) + } + + _, err = exec.Exec(query, values...) + if err != nil { + return errors.Wrap(err, "failed to remove relationships before set") + } + {{else -}} + for _, rel := range related { + rel.{{$rel.ForeignTable.ColumnNameGo}}.Valid = false + {{if not .ToJoinTable -}} + if rel.R != nil { + rel.R.{{$rel.Function.ForeignName}} = nil + } + {{end -}} + if err = rel.Update(exec, "{{.ForeignColumn}}"); err != nil { + return err + } + } + {{end -}} + + {{if .ToJoinTable -}} + remove{{$rel.LocalTable.NameGo}}From{{$rel.ForeignTable.NameGo}}Slice({{$rel.Function.Receiver}}, related) + {{end -}} + if {{$rel.Function.Receiver}}.R == nil { + return nil + } + + for _, rel := range related { + for i, ri := range {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}} { + if rel != ri { + continue + } + + ln := len({{$rel.Function.Receiver}}.R.{{$rel.Function.Name}}) + if ln > 1 && i < ln-1 { + {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}}[i] = {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}}[ln-1] + } + {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}} = {{$rel.Function.Receiver}}.R.{{$rel.Function.Name}}[:ln-1] + break + } + } + + return nil } - {{if .ToJoinTable -}} + {{if .ToJoinTable -}} func remove{{$rel.LocalTable.NameGo}}From{{$rel.ForeignTable.NameGo}}Slice({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}, related []*{{$rel.ForeignTable.NameGo}}) { - for _, rel := range related { - if rel.R == nil { - continue - } - for i, ri := range rel.R.{{$rel.Function.ForeignName}} { - if {{$rel.Function.Receiver}}.{{$rel.Function.LocalAssignment}} != ri.{{$rel.Function.LocalAssignment}} { - continue - } - - ln := len(rel.R.{{$rel.Function.ForeignName}}) - if ln > 1 && i < ln-1 { - rel.R.{{$rel.Function.ForeignName}}[i] = rel.R.{{$rel.Function.ForeignName}}[ln-1] - } - rel.R.{{$rel.Function.ForeignName}} = rel.R.{{$rel.Function.ForeignName}}[:ln-1] - break - } - } + for _, rel := range related { + if rel.R == nil { + continue + } + for i, ri := range rel.R.{{$rel.Function.ForeignName}} { + if {{$rel.Function.Receiver}}.{{$rel.Function.LocalAssignment}} != ri.{{$rel.Function.LocalAssignment}} { + continue + } + + ln := len(rel.R.{{$rel.Function.ForeignName}}) + if ln > 1 && i < ln-1 { + rel.R.{{$rel.Function.ForeignName}}[i] = rel.R.{{$rel.Function.ForeignName}}[ln-1] + } + rel.R.{{$rel.Function.ForeignName}} = rel.R.{{$rel.Function.ForeignName}}[:ln-1] + break + } + } } - {{end -}}{{- /* if ToJoinTable */ -}} - {{- end -}}{{- /* if nullable foreign key */ -}} - {{- end -}}{{- /* if unique foreign key */ -}} - {{- end -}}{{- /* range relationships */ -}} + {{end -}}{{- /* if ToJoinTable */ -}} + {{- end -}}{{- /* if nullable foreign key */ -}} + {{- end -}}{{- /* if unique foreign key */ -}} + {{- end -}}{{- /* range relationships */ -}} {{- end -}}{{- /* if IsJoinTable */ -}} diff --git a/templates/10_all.tpl b/templates/10_all.tpl index 5a2df1459..e1ed9ddd5 100644 --- a/templates/10_all.tpl +++ b/templates/10_all.tpl @@ -3,11 +3,11 @@ {{- $varNameSingular := .Table.Name | singular | camelCase -}} // {{$tableNamePlural}}G retrieves all records. func {{$tableNamePlural}}G(mods ...qm.QueryMod) {{$varNameSingular}}Query { - return {{$tableNamePlural}}(boil.GetDB(), mods...) + return {{$tableNamePlural}}(boil.GetDB(), mods...) } // {{$tableNamePlural}} retrieves all the records using an executor. func {{$tableNamePlural}}(exec boil.Executor, mods ...qm.QueryMod) {{$varNameSingular}}Query { - mods = append(mods, qm.From("{{.Table.Name | .SchemaTable}}")) - return {{$varNameSingular}}Query{NewQuery(exec, mods...)} + mods = append(mods, qm.From("{{.Table.Name | .SchemaTable}}")) + return {{$varNameSingular}}Query{NewQuery(exec, mods...)} } diff --git a/templates/11_find.tpl b/templates/11_find.tpl index aede96596..83eae619a 100644 --- a/templates/11_find.tpl +++ b/templates/11_find.tpl @@ -6,51 +6,51 @@ {{- $pkArgs := joinSlices " " $pkNames $colDefs.Types | join ", " -}} // {{$tableNameSingular}}FindG retrieves a single record by ID. func Find{{$tableNameSingular}}G({{$pkArgs}}, selectCols ...string) (*{{$tableNameSingular}}, error) { - return Find{{$tableNameSingular}}(boil.GetDB(), {{$pkNames | join ", "}}, selectCols...) + return Find{{$tableNameSingular}}(boil.GetDB(), {{$pkNames | join ", "}}, selectCols...) } // {{$tableNameSingular}}FindGP retrieves a single record by ID, and panics on error. func Find{{$tableNameSingular}}GP({{$pkArgs}}, selectCols ...string) *{{$tableNameSingular}} { - retobj, err := Find{{$tableNameSingular}}(boil.GetDB(), {{$pkNames | join ", "}}, selectCols...) - if err != nil { - panic(boil.WrapErr(err)) - } + retobj, err := Find{{$tableNameSingular}}(boil.GetDB(), {{$pkNames | join ", "}}, selectCols...) + if err != nil { + panic(boil.WrapErr(err)) + } - return retobj + return retobj } // {{$tableNameSingular}}Find retrieves a single record by ID with an executor. // If selectCols is empty Find will return all columns. func Find{{$tableNameSingular}}(exec boil.Executor, {{$pkArgs}}, selectCols ...string) (*{{$tableNameSingular}}, error) { - {{$varNameSingular}}Obj := &{{$tableNameSingular}}{} - - sel := "*" - if len(selectCols) > 0 { - sel = strings.Join(strmangle.IdentQuoteSlice(dialect.LQ, dialect.RQ, selectCols), ",") - } - query := fmt.Sprintf( - "select %s from {{.Table.Name | .SchemaTable}} where {{if .Dialect.IndexPlaceholders}}{{whereClause .LQ .RQ 1 .Table.PKey.Columns}}{{else}}{{whereClause .LQ .RQ 0 .Table.PKey.Columns}}{{end}}", sel, - ) - - q := boil.SQL(exec, query, {{$pkNames | join ", "}}) - - err := q.Bind({{$varNameSingular}}Obj) - if err != nil { - if errors.Cause(err) == sql.ErrNoRows { - return nil, sql.ErrNoRows - } - return nil, errors.Wrap(err, "{{.PkgName}}: unable to select from {{.Table.Name}}") - } - - return {{$varNameSingular}}Obj, nil + {{$varNameSingular}}Obj := &{{$tableNameSingular}}{} + + sel := "*" + if len(selectCols) > 0 { + sel = strings.Join(strmangle.IdentQuoteSlice(dialect.LQ, dialect.RQ, selectCols), ",") + } + query := fmt.Sprintf( + "select %s from {{.Table.Name | .SchemaTable}} where {{if .Dialect.IndexPlaceholders}}{{whereClause .LQ .RQ 1 .Table.PKey.Columns}}{{else}}{{whereClause .LQ .RQ 0 .Table.PKey.Columns}}{{end}}", sel, + ) + + q := boil.SQL(exec, query, {{$pkNames | join ", "}}) + + err := q.Bind({{$varNameSingular}}Obj) + if err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, sql.ErrNoRows + } + return nil, errors.Wrap(err, "{{.PkgName}}: unable to select from {{.Table.Name}}") + } + + return {{$varNameSingular}}Obj, nil } // {{$tableNameSingular}}FindP retrieves a single record by ID with an executor, and panics on error. func Find{{$tableNameSingular}}P(exec boil.Executor, {{$pkArgs}}, selectCols ...string) *{{$tableNameSingular}} { - retobj, err := Find{{$tableNameSingular}}(exec, {{$pkNames | join ", "}}, selectCols...) - if err != nil { - panic(boil.WrapErr(err)) - } + retobj, err := Find{{$tableNameSingular}}(exec, {{$pkNames | join ", "}}, selectCols...) + if err != nil { + panic(boil.WrapErr(err)) + } - return retobj + return retobj } diff --git a/templates/12_insert.tpl b/templates/12_insert.tpl index d05412474..03568a13a 100644 --- a/templates/12_insert.tpl +++ b/templates/12_insert.tpl @@ -3,23 +3,23 @@ {{- $schemaTable := .Table.Name | .SchemaTable -}} // InsertG a single record. See Insert for whitelist behavior description. func (o *{{$tableNameSingular}}) InsertG(whitelist ... string) error { - return o.Insert(boil.GetDB(), whitelist...) + return o.Insert(boil.GetDB(), whitelist...) } // InsertGP a single record, and panics on error. See Insert for whitelist // behavior description. func (o *{{$tableNameSingular}}) InsertGP(whitelist ... string) { - if err := o.Insert(boil.GetDB(), whitelist...); err != nil { - panic(boil.WrapErr(err)) - } + if err := o.Insert(boil.GetDB(), whitelist...); err != nil { + panic(boil.WrapErr(err)) + } } // InsertP a single record using an executor, and panics on error. See Insert // for whitelist behavior description. func (o *{{$tableNameSingular}}) InsertP(exec boil.Executor, whitelist ... string) { - if err := o.Insert(exec, whitelist...); err != nil { - panic(boil.WrapErr(err)) - } + if err := o.Insert(exec, whitelist...); err != nil { + panic(boil.WrapErr(err)) + } } // Insert a single record using an executor. @@ -28,115 +28,115 @@ func (o *{{$tableNameSingular}}) InsertP(exec boil.Executor, whitelist ... strin // - All columns without a default value are included (i.e. name, age) // - All columns with a default, but non-zero are included (i.e. health = 75) func (o *{{$tableNameSingular}}) Insert(exec boil.Executor, whitelist ... string) error { - if o == nil { - return errors.New("{{.PkgName}}: no {{.Table.Name}} provided for insertion") - } - - var err error - {{- template "timestamp_insert_helper" . }} - - {{if not .NoHooks -}} - if err := o.doBeforeInsertHooks(exec); err != nil { - return err - } - {{- end}} - - nzDefaults := boil.NonZeroDefaultSet({{$varNameSingular}}ColumnsWithDefault, o) - - key := makeCacheKey(whitelist, nzDefaults) - {{$varNameSingular}}InsertCacheMut.RLock() - cache, cached := {{$varNameSingular}}InsertCache[key] - {{$varNameSingular}}InsertCacheMut.RUnlock() - - if !cached { - wl, returnColumns := strmangle.InsertColumnSet( - {{$varNameSingular}}Columns, - {{$varNameSingular}}ColumnsWithDefault, - {{$varNameSingular}}ColumnsWithoutDefault, - nzDefaults, - whitelist, - ) - - cache.valueMapping, err = boil.BindMapping({{$varNameSingular}}Type, {{$varNameSingular}}Mapping, wl) - if err != nil { - return err - } - cache.retMapping, err = boil.BindMapping({{$varNameSingular}}Type, {{$varNameSingular}}Mapping, returnColumns) - if err != nil { - return err - } - cache.query = fmt.Sprintf("INSERT INTO {{$schemaTable}} ({{.LQ}}%s{{.RQ}}) VALUES (%s)", strings.Join(wl, "{{.LQ}},{{.RQ}}"), strmangle.Placeholders(dialect.IndexPlaceholders, len(wl), 1, 1)) - - if len(cache.retMapping) != 0 { - {{if .UseLastInsertID -}} - cache.retQuery = fmt.Sprintf("SELECT %s FROM {{$schemaTable}} WHERE %s", strings.Join(returnColumns, "{{.LQ}},{{.RQ}}"), strmangle.WhereClause("{{.LQ}}", "{{.RQ}}", {{if .Dialect.IndexPlaceholders}}1{{else}}0{{end}}, {{$varNameSingular}}PrimaryKeyColumns)) - {{else -}} - cache.query += fmt.Sprintf(" RETURNING {{.LQ}}%s{{.RQ}}", strings.Join(returnColumns, "{{.LQ}},{{.RQ}}")) - {{end -}} - } - } - - value := reflect.Indirect(reflect.ValueOf(o)) - vals := boil.ValuesFromMapping(value, cache.valueMapping) - {{if .UseLastInsertID}} - if boil.DebugMode { - fmt.Fprintln(boil.DebugWriter, cache.query) - fmt.Fprintln(boil.DebugWriter, vals) - } - - result, err := exec.Exec(cache.query, vals...) - if err != nil { - return errors.Wrap(err, "{{.PkgName}}: unable to insert into {{.Table.Name}}") - } - - if len(cache.retMapping) == 0 { - {{if not .NoHooks -}} - return o.doAfterInsertHooks(exec) - {{else -}} - return nil - {{end -}} - } - - lastID, err := result.LastInsertId() - if err != nil || lastID == 0 || len({{$varNameSingular}}PrimaryKeyColumns) != 1 { - return ErrSyncFail - } - - if boil.DebugMode { - fmt.Fprintln(boil.DebugWriter, cache.retQuery) - fmt.Fprintln(boil.DebugWriter, lastID) - } - - err = exec.QueryRow(cache.retQuery, lastID).Scan(boil.PtrsFromMapping(value, cache.retMapping)...) - if err != nil { - return errors.Wrap(err, "{{.PkgName}}: unable to populate default values for {{.Table.Name}}") - } - {{else}} - if len(cache.retMapping) != 0 { - err = exec.QueryRow(cache.query, vals...).Scan(boil.PtrsFromMapping(value, cache.retMapping)...) - } else { - _, err = exec.Exec(cache.query, vals...) - } - - if boil.DebugMode { - fmt.Fprintln(boil.DebugWriter, cache.query) - fmt.Fprintln(boil.DebugWriter, vals) - } - - if err != nil { - return errors.Wrap(err, "{{.PkgName}}: unable to insert into {{.Table.Name}}") - } - {{end}} - - if !cached { - {{$varNameSingular}}InsertCacheMut.Lock() - {{$varNameSingular}}InsertCache[key] = cache - {{$varNameSingular}}InsertCacheMut.Unlock() - } - - {{if not .NoHooks -}} - return o.doAfterInsertHooks(exec) - {{- else -}} - return nil - {{- end}} + if o == nil { + return errors.New("{{.PkgName}}: no {{.Table.Name}} provided for insertion") + } + + var err error + {{- template "timestamp_insert_helper" . }} + + {{if not .NoHooks -}} + if err := o.doBeforeInsertHooks(exec); err != nil { + return err + } + {{- end}} + + nzDefaults := boil.NonZeroDefaultSet({{$varNameSingular}}ColumnsWithDefault, o) + + key := makeCacheKey(whitelist, nzDefaults) + {{$varNameSingular}}InsertCacheMut.RLock() + cache, cached := {{$varNameSingular}}InsertCache[key] + {{$varNameSingular}}InsertCacheMut.RUnlock() + + if !cached { + wl, returnColumns := strmangle.InsertColumnSet( + {{$varNameSingular}}Columns, + {{$varNameSingular}}ColumnsWithDefault, + {{$varNameSingular}}ColumnsWithoutDefault, + nzDefaults, + whitelist, + ) + + cache.valueMapping, err = boil.BindMapping({{$varNameSingular}}Type, {{$varNameSingular}}Mapping, wl) + if err != nil { + return err + } + cache.retMapping, err = boil.BindMapping({{$varNameSingular}}Type, {{$varNameSingular}}Mapping, returnColumns) + if err != nil { + return err + } + cache.query = fmt.Sprintf("INSERT INTO {{$schemaTable}} ({{.LQ}}%s{{.RQ}}) VALUES (%s)", strings.Join(wl, "{{.LQ}},{{.RQ}}"), strmangle.Placeholders(dialect.IndexPlaceholders, len(wl), 1, 1)) + + if len(cache.retMapping) != 0 { + {{if .UseLastInsertID -}} + cache.retQuery = fmt.Sprintf("SELECT %s FROM {{$schemaTable}} WHERE %s", strings.Join(returnColumns, "{{.LQ}},{{.RQ}}"), strmangle.WhereClause("{{.LQ}}", "{{.RQ}}", {{if .Dialect.IndexPlaceholders}}1{{else}}0{{end}}, {{$varNameSingular}}PrimaryKeyColumns)) + {{else -}} + cache.query += fmt.Sprintf(" RETURNING {{.LQ}}%s{{.RQ}}", strings.Join(returnColumns, "{{.LQ}},{{.RQ}}")) + {{end -}} + } + } + + value := reflect.Indirect(reflect.ValueOf(o)) + vals := boil.ValuesFromMapping(value, cache.valueMapping) + {{if .UseLastInsertID}} + if boil.DebugMode { + fmt.Fprintln(boil.DebugWriter, cache.query) + fmt.Fprintln(boil.DebugWriter, vals) + } + + result, err := exec.Exec(cache.query, vals...) + if err != nil { + return errors.Wrap(err, "{{.PkgName}}: unable to insert into {{.Table.Name}}") + } + + if len(cache.retMapping) == 0 { + {{if not .NoHooks -}} + return o.doAfterInsertHooks(exec) + {{else -}} + return nil + {{end -}} + } + + lastID, err := result.LastInsertId() + if err != nil || lastID == 0 || len({{$varNameSingular}}PrimaryKeyColumns) != 1 { + return ErrSyncFail + } + + if boil.DebugMode { + fmt.Fprintln(boil.DebugWriter, cache.retQuery) + fmt.Fprintln(boil.DebugWriter, lastID) + } + + err = exec.QueryRow(cache.retQuery, lastID).Scan(boil.PtrsFromMapping(value, cache.retMapping)...) + if err != nil { + return errors.Wrap(err, "{{.PkgName}}: unable to populate default values for {{.Table.Name}}") + } + {{else}} + if len(cache.retMapping) != 0 { + err = exec.QueryRow(cache.query, vals...).Scan(boil.PtrsFromMapping(value, cache.retMapping)...) + } else { + _, err = exec.Exec(cache.query, vals...) + } + + if boil.DebugMode { + fmt.Fprintln(boil.DebugWriter, cache.query) + fmt.Fprintln(boil.DebugWriter, vals) + } + + if err != nil { + return errors.Wrap(err, "{{.PkgName}}: unable to insert into {{.Table.Name}}") + } + {{end}} + + if !cached { + {{$varNameSingular}}InsertCacheMut.Lock() + {{$varNameSingular}}InsertCache[key] = cache + {{$varNameSingular}}InsertCacheMut.Unlock() + } + + {{if not .NoHooks -}} + return o.doAfterInsertHooks(exec) + {{- else -}} + return nil + {{- end}} } diff --git a/templates/13_update.tpl b/templates/13_update.tpl index 28c66311e..219448fe2 100644 --- a/templates/13_update.tpl +++ b/templates/13_update.tpl @@ -7,25 +7,25 @@ // UpdateG a single {{$tableNameSingular}} record. See Update for // whitelist behavior description. func (o *{{$tableNameSingular}}) UpdateG(whitelist ...string) error { - return o.Update(boil.GetDB(), whitelist...) + return o.Update(boil.GetDB(), whitelist...) } // UpdateGP a single {{$tableNameSingular}} record. // UpdateGP takes a whitelist of column names that should be updated. // Panics on error. See Update for whitelist behavior description. func (o *{{$tableNameSingular}}) UpdateGP(whitelist ...string) { - if err := o.Update(boil.GetDB(), whitelist...); err != nil { - panic(boil.WrapErr(err)) - } + if err := o.Update(boil.GetDB(), whitelist...); err != nil { + panic(boil.WrapErr(err)) + } } // UpdateP uses an executor to update the {{$tableNameSingular}}, and panics on error. // See Update for whitelist behavior description. func (o *{{$tableNameSingular}}) UpdateP(exec boil.Executor, whitelist ... string) { - err := o.Update(exec, whitelist...) - if err != nil { - panic(boil.WrapErr(err)) - } + err := o.Update(exec, whitelist...) + if err != nil { + panic(boil.WrapErr(err)) + } } // Update uses an executor to update the {{$tableNameSingular}}. @@ -36,147 +36,147 @@ func (o *{{$tableNameSingular}}) UpdateP(exec boil.Executor, whitelist ... strin // Update does not automatically update the record in case of default values. Use .Reload() // to refresh the records. func (o *{{$tableNameSingular}}) Update(exec boil.Executor, whitelist ... string) error { - {{- template "timestamp_update_helper" . -}} - - var err error - {{if not .NoHooks -}} - if err = o.doBeforeUpdateHooks(exec); err != nil { - return err - } - {{end -}} - - key := makeCacheKey(whitelist, nil) - {{$varNameSingular}}UpdateCacheMut.RLock() - cache, cached := {{$varNameSingular}}UpdateCache[key] - {{$varNameSingular}}UpdateCacheMut.RUnlock() - - if !cached { - wl := strmangle.UpdateColumnSet({{$varNameSingular}}Columns, {{$varNameSingular}}PrimaryKeyColumns, whitelist) - - cache.query = fmt.Sprintf("UPDATE {{$schemaTable}} SET %s WHERE %s", - strmangle.SetParamNames("{{.LQ}}", "{{.RQ}}", {{if .Dialect.IndexPlaceholders}}1{{else}}0{{end}}, wl), - strmangle.WhereClause("{{.LQ}}", "{{.RQ}}", {{if .Dialect.IndexPlaceholders}}len(wl)+1{{else}}0{{end}}, {{$varNameSingular}}PrimaryKeyColumns), - ) - cache.valueMapping, err = boil.BindMapping({{$varNameSingular}}Type, {{$varNameSingular}}Mapping, append(wl, {{$varNameSingular}}PrimaryKeyColumns...)) - if err != nil { - return err - } - } - - if len(cache.valueMapping) == 0 { - return errors.New("{{.PkgName}}: unable to update {{.Table.Name}}, could not build whitelist") - } - - values := boil.ValuesFromMapping(reflect.Indirect(reflect.ValueOf(o)), cache.valueMapping) - - if boil.DebugMode { - fmt.Fprintln(boil.DebugWriter, cache.query) - fmt.Fprintln(boil.DebugWriter, values) - } - - result, err := exec.Exec(cache.query, values...) - if err != nil { - return errors.Wrap(err, "{{.PkgName}}: unable to update {{.Table.Name}} row") - } - - if r, err := result.RowsAffected(); err == nil && r != 1 { - return errors.Errorf("failed to update single row, updated %d rows", r) - } - - if !cached { - {{$varNameSingular}}UpdateCacheMut.Lock() - {{$varNameSingular}}UpdateCache[key] = cache - {{$varNameSingular}}UpdateCacheMut.Unlock() - } - - {{if not .NoHooks -}} - return o.doAfterUpdateHooks(exec) - {{- else -}} - return nil - {{- end}} + {{- template "timestamp_update_helper" . -}} + + var err error + {{if not .NoHooks -}} + if err = o.doBeforeUpdateHooks(exec); err != nil { + return err + } + {{end -}} + + key := makeCacheKey(whitelist, nil) + {{$varNameSingular}}UpdateCacheMut.RLock() + cache, cached := {{$varNameSingular}}UpdateCache[key] + {{$varNameSingular}}UpdateCacheMut.RUnlock() + + if !cached { + wl := strmangle.UpdateColumnSet({{$varNameSingular}}Columns, {{$varNameSingular}}PrimaryKeyColumns, whitelist) + + cache.query = fmt.Sprintf("UPDATE {{$schemaTable}} SET %s WHERE %s", + strmangle.SetParamNames("{{.LQ}}", "{{.RQ}}", {{if .Dialect.IndexPlaceholders}}1{{else}}0{{end}}, wl), + strmangle.WhereClause("{{.LQ}}", "{{.RQ}}", {{if .Dialect.IndexPlaceholders}}len(wl)+1{{else}}0{{end}}, {{$varNameSingular}}PrimaryKeyColumns), + ) + cache.valueMapping, err = boil.BindMapping({{$varNameSingular}}Type, {{$varNameSingular}}Mapping, append(wl, {{$varNameSingular}}PrimaryKeyColumns...)) + if err != nil { + return err + } + } + + if len(cache.valueMapping) == 0 { + return errors.New("{{.PkgName}}: unable to update {{.Table.Name}}, could not build whitelist") + } + + values := boil.ValuesFromMapping(reflect.Indirect(reflect.ValueOf(o)), cache.valueMapping) + + if boil.DebugMode { + fmt.Fprintln(boil.DebugWriter, cache.query) + fmt.Fprintln(boil.DebugWriter, values) + } + + result, err := exec.Exec(cache.query, values...) + if err != nil { + return errors.Wrap(err, "{{.PkgName}}: unable to update {{.Table.Name}} row") + } + + if r, err := result.RowsAffected(); err == nil && r != 1 { + return errors.Errorf("failed to update single row, updated %d rows", r) + } + + if !cached { + {{$varNameSingular}}UpdateCacheMut.Lock() + {{$varNameSingular}}UpdateCache[key] = cache + {{$varNameSingular}}UpdateCacheMut.Unlock() + } + + {{if not .NoHooks -}} + return o.doAfterUpdateHooks(exec) + {{- else -}} + return nil + {{- end}} } // UpdateAllP updates all rows with matching column names, and panics on error. func (q {{$varNameSingular}}Query) UpdateAllP(cols M) { - if err := q.UpdateAll(cols); err != nil { - panic(boil.WrapErr(err)) - } + if err := q.UpdateAll(cols); err != nil { + panic(boil.WrapErr(err)) + } } // UpdateAll updates all rows with the specified column values. func (q {{$varNameSingular}}Query) UpdateAll(cols M) error { - boil.SetUpdate(q.Query, cols) + boil.SetUpdate(q.Query, cols) - _, err := q.Query.ExecQuery() - if err != nil { - return errors.Wrap(err, "{{.PkgName}}: unable to update all for {{.Table.Name}}") - } + _, err := q.Query.ExecQuery() + if err != nil { + return errors.Wrap(err, "{{.PkgName}}: unable to update all for {{.Table.Name}}") + } - return nil + return nil } // UpdateAllG updates all rows with the specified column values. func (o {{$tableNameSingular}}Slice) UpdateAllG(cols M) error { - return o.UpdateAll(boil.GetDB(), cols) + return o.UpdateAll(boil.GetDB(), cols) } // UpdateAllGP updates all rows with the specified column values, and panics on error. func (o {{$tableNameSingular}}Slice) UpdateAllGP(cols M) { - if err := o.UpdateAll(boil.GetDB(), cols); err != nil { - panic(boil.WrapErr(err)) - } + if err := o.UpdateAll(boil.GetDB(), cols); err != nil { + panic(boil.WrapErr(err)) + } } // UpdateAllP updates all rows with the specified column values, and panics on error. func (o {{$tableNameSingular}}Slice) UpdateAllP(exec boil.Executor, cols M) { - if err := o.UpdateAll(exec, cols); err != nil { - panic(boil.WrapErr(err)) - } + if err := o.UpdateAll(exec, cols); err != nil { + panic(boil.WrapErr(err)) + } } // UpdateAll updates all rows with the specified column values, using an executor. func (o {{$tableNameSingular}}Slice) UpdateAll(exec boil.Executor, cols M) error { - ln := int64(len(o)) - if ln == 0 { - return nil - } - - if len(cols) == 0 { - return errors.New("{{.PkgName}}: update all requires at least one column argument") - } - - colNames := make([]string, len(cols)) - args := make([]interface{}, len(cols)) - - i := 0 - for name, value := range cols { - colNames[i] = name - args[i] = value - i++ - } - - // Append all of the primary key values for each column - args = append(args, o.inPrimaryKeyArgs()...) - - sql := fmt.Sprintf( - "UPDATE {{$schemaTable}} SET %s WHERE ({{.LQ}}{{.Table.PKey.Columns | join (printf "%s,%s" .LQ .RQ)}}{{.RQ}}) IN (%s)", - strmangle.SetParamNames("{{.LQ}}", "{{.RQ}}", {{if .Dialect.IndexPlaceholders}}1{{else}}0{{end}}, colNames), - strmangle.Placeholders(dialect.IndexPlaceholders, len(o) * len({{$varNameSingular}}PrimaryKeyColumns), len(colNames)+1, len({{$varNameSingular}}PrimaryKeyColumns)), - ) - - if boil.DebugMode { - fmt.Fprintln(boil.DebugWriter, sql) - fmt.Fprintln(boil.DebugWriter, args...) - } - - result, err := exec.Exec(sql, args...) - if err != nil { - return errors.Wrap(err, "{{.PkgName}}: unable to update all in {{$varNameSingular}} slice") - } - - if r, err := result.RowsAffected(); err == nil && r != ln { - return errors.Errorf("failed to update %d rows, only affected %d", ln, r) - } - - return nil + ln := int64(len(o)) + if ln == 0 { + return nil + } + + if len(cols) == 0 { + return errors.New("{{.PkgName}}: update all requires at least one column argument") + } + + colNames := make([]string, len(cols)) + args := make([]interface{}, len(cols)) + + i := 0 + for name, value := range cols { + colNames[i] = name + args[i] = value + i++ + } + + // Append all of the primary key values for each column + args = append(args, o.inPrimaryKeyArgs()...) + + sql := fmt.Sprintf( + "UPDATE {{$schemaTable}} SET %s WHERE ({{.LQ}}{{.Table.PKey.Columns | join (printf "%s,%s" .LQ .RQ)}}{{.RQ}}) IN (%s)", + strmangle.SetParamNames("{{.LQ}}", "{{.RQ}}", {{if .Dialect.IndexPlaceholders}}1{{else}}0{{end}}, colNames), + strmangle.Placeholders(dialect.IndexPlaceholders, len(o) * len({{$varNameSingular}}PrimaryKeyColumns), len(colNames)+1, len({{$varNameSingular}}PrimaryKeyColumns)), + ) + + if boil.DebugMode { + fmt.Fprintln(boil.DebugWriter, sql) + fmt.Fprintln(boil.DebugWriter, args...) + } + + result, err := exec.Exec(sql, args...) + if err != nil { + return errors.Wrap(err, "{{.PkgName}}: unable to update all in {{$varNameSingular}} slice") + } + + if r, err := result.RowsAffected(); err == nil && r != ln { + return errors.Errorf("failed to update %d rows, only affected %d", ln, r) + } + + return nil } diff --git a/templates/14_upsert.tpl b/templates/14_upsert.tpl index 2f41298a8..1f43c71d0 100644 --- a/templates/14_upsert.tpl +++ b/templates/14_upsert.tpl @@ -1,85 +1,109 @@ {{- $tableNameSingular := .Table.Name | singular | titleCase -}} {{- $varNameSingular := .Table.Name | singular | camelCase -}} // UpsertG attempts an insert, and does an update or ignore on conflict. -func (o *{{$tableNameSingular}}) UpsertG(updateOnConflict bool, conflictColumns []string, updateColumns []string, whitelist ...string) error { - return o.Upsert(boil.GetDB(), updateOnConflict, conflictColumns, updateColumns, whitelist...) +func (o *{{$tableNameSingular}}) UpsertG({{if eq .DriverName "postgres"}}updateOnConflict bool, conflictColumns []string, {{end}}updateColumns []string, whitelist ...string) error { + return o.Upsert(boil.GetDB(), {{if eq .DriverName "postgres"}}updateOnConflict, conflictColumns, {{end}}updateColumns, whitelist...) } // UpsertGP attempts an insert, and does an update or ignore on conflict. Panics on error. -func (o *{{$tableNameSingular}}) UpsertGP(updateOnConflict bool, conflictColumns []string, updateColumns []string, whitelist ...string) { - if err := o.Upsert(boil.GetDB(), updateOnConflict, conflictColumns, updateColumns, whitelist...); err != nil { - panic(boil.WrapErr(err)) - } +func (o *{{$tableNameSingular}}) UpsertGP({{if eq .DriverName "postgres"}}updateOnConflict bool, conflictColumns []string, {{end}}updateColumns []string, whitelist ...string) { + if err := o.Upsert(boil.GetDB(), {{if eq .DriverName "postgres"}}updateOnConflict, conflictColumns, {{end}}updateColumns, whitelist...); err != nil { + panic(boil.WrapErr(err)) + } } // UpsertP attempts an insert using an executor, and does an update or ignore on conflict. // UpsertP panics on error. -func (o *{{$tableNameSingular}}) UpsertP(exec boil.Executor, updateOnConflict bool, conflictColumns []string, updateColumns []string, whitelist ...string) { - if err := o.Upsert(exec, updateOnConflict, conflictColumns, updateColumns, whitelist...); err != nil { - panic(boil.WrapErr(err)) - } +func (o *{{$tableNameSingular}}) UpsertP(exec boil.Executor, {{if eq .DriverName "postgres"}}updateOnConflict bool, conflictColumns []string, {{end}}updateColumns []string, whitelist ...string) { + if err := o.Upsert(exec, {{if eq .DriverName "postgres"}}updateOnConflict, conflictColumns, {{end}}updateColumns, whitelist...); err != nil { + panic(boil.WrapErr(err)) + } } - + // Upsert attempts an insert using an executor, and does an update or ignore on conflict. -func (o *{{$tableNameSingular}}) Upsert(exec boil.Executor, updateOnConflict bool, conflictColumns []string, updateColumns []string, whitelist ...string) error { - if o == nil { - return errors.New("{{.PkgName}}: no {{.Table.Name}} provided for upsert") - } +func (o *{{$tableNameSingular}}) Upsert(exec boil.Executor, {{if eq .DriverName "postgres"}}updateOnConflict bool, conflictColumns []string, {{end}}updateColumns []string, whitelist ...string) error { + if o == nil { + return errors.New("{{.PkgName}}: no {{.Table.Name}} provided for upsert") + } + + {{- template "timestamp_upsert_helper" . }} + + {{if not .NoHooks -}} + if err := o.doBeforeUpsertHooks(exec); err != nil { + return err + } + {{- end}} - {{- template "timestamp_upsert_helper" . }} + var err error + var ret []string + whitelist, ret = strmangle.InsertColumnSet( + {{$varNameSingular}}Columns, + {{$varNameSingular}}ColumnsWithDefault, + {{$varNameSingular}}ColumnsWithoutDefault, + boil.NonZeroDefaultSet({{$varNameSingular}}ColumnsWithDefault, o), + whitelist, + ) + update := strmangle.UpdateColumnSet( + {{$varNameSingular}}Columns, + {{$varNameSingular}}PrimaryKeyColumns, + updateColumns, + ) - {{if not .NoHooks -}} - if err := o.doBeforeUpsertHooks(exec); err != nil { - return err - } - {{- end}} + {{if eq .DriverName "postgres" -}} + conflict := conflictColumns + if len(conflict) == 0 { + conflict = make([]string, len({{$varNameSingular}}PrimaryKeyColumns)) + copy(conflict, {{$varNameSingular}}PrimaryKeyColumns) + } + {{- end}} - var err error - var ret []string - whitelist, ret = strmangle.InsertColumnSet( - {{$varNameSingular}}Columns, - {{$varNameSingular}}ColumnsWithDefault, - {{$varNameSingular}}ColumnsWithoutDefault, - boil.NonZeroDefaultSet({{$varNameSingular}}ColumnsWithDefault, o), - whitelist, - ) - update := strmangle.UpdateColumnSet( - {{$varNameSingular}}Columns, - {{$varNameSingular}}PrimaryKeyColumns, - updateColumns, - ) - conflict := conflictColumns - if len(conflict) == 0 { - conflict = make([]string, len({{$varNameSingular}}PrimaryKeyColumns)) - copy(conflict, {{$varNameSingular}}PrimaryKeyColumns) - } + {{if eq .DriverName "postgres" -}} + query := boil.BuildUpsertQueryPostgres(dialect, "{{.Table.Name}}", updateOnConflict, ret, update, conflict, whitelist) + {{- else if eq .DriverName "mysql" -}} + query := boil.BuildUpsertQueryMySQL(dialect, "{{.Table.Name}}", update, whitelist) + {{- end}} - query := boil.BuildUpsertQuery(dialect, "{{.Table.Name}}", updateOnConflict, ret, update, conflict, whitelist) + if boil.DebugMode { + fmt.Fprintln(boil.DebugWriter, query) + fmt.Fprintln(boil.DebugWriter, boil.GetStructValues(o, whitelist...)) + } - if boil.DebugMode { - fmt.Fprintln(boil.DebugWriter, query) - fmt.Fprintln(boil.DebugWriter, boil.GetStructValues(o, whitelist...)) - } + {{- if .UseLastInsertID}} + res, err := exec.Exec(query, boil.GetStructValues(o, whitelist...)...) + {{- else}} + if len(ret) != 0 { + err = exec.QueryRow(query, boil.GetStructValues(o, whitelist...)...).Scan(boil.GetStructPointers(o, ret...)...) + } else { + _, err = exec.Exec(query, boil.GetStructValues(o, whitelist...)...) + } + {{- end}} - {{- if .UseLastInsertID}} - return errors.New("don't know how to do this yet") - {{- else}} - if len(ret) != 0 { - err = exec.QueryRow(query, boil.GetStructValues(o, whitelist...)...).Scan(boil.GetStructPointers(o, ret...)...) - } else { - _, err = exec.Exec(query, boil.GetStructValues(o, whitelist...)...) - } - {{- end}} + if err != nil { + return errors.Wrap(err, "{{.PkgName}}: unable to upsert for {{.Table.Name}}") + } - if err != nil { - return errors.Wrap(err, "{{.PkgName}}: unable to upsert for {{.Table.Name}}") - } + {{if .UseLastInsertID -}} + if len(ret) != 0 { + lid, err := res.LastInsertId() + if err != nil { + return errors.Wrap(err, "{{.PkgName}}: unable to get last insert id for {{.Table.Name}}") + } + {{$aipk := autoIncPrimaryKey .Table.Columns .Table.PKey}} + aipk := "{{$aipk.Name}}" + // if the update did not change anything, lid will be 0 + if lid == 0 && aipk == "" { + // do a select using all pkeys + } else if lid != 0 { + // do a select using all pkeys + lid + } + } + {{- end}} - {{if not .NoHooks -}} - if err := o.doAfterUpsertHooks(exec); err != nil { - return err - } - {{- end}} + {{if not .NoHooks -}} + if err := o.doAfterUpsertHooks(exec); err != nil { + return err + } + {{- end}} - return nil + return nil } diff --git a/templates/15_delete.tpl b/templates/15_delete.tpl index 34193556a..b154705c4 100644 --- a/templates/15_delete.tpl +++ b/templates/15_delete.tpl @@ -5,158 +5,158 @@ // DeleteP will match against the primary key column to find the record to delete. // Panics on error. func (o *{{$tableNameSingular}}) DeleteP(exec boil.Executor) { - if err := o.Delete(exec); err != nil { - panic(boil.WrapErr(err)) - } + if err := o.Delete(exec); err != nil { + panic(boil.WrapErr(err)) + } } // DeleteG deletes a single {{$tableNameSingular}} record. // DeleteG will match against the primary key column to find the record to delete. func (o *{{$tableNameSingular}}) DeleteG() error { - if o == nil { - return errors.New("{{.PkgName}}: no {{$tableNameSingular}} provided for deletion") - } + if o == nil { + return errors.New("{{.PkgName}}: no {{$tableNameSingular}} provided for deletion") + } - return o.Delete(boil.GetDB()) + return o.Delete(boil.GetDB()) } // DeleteGP deletes a single {{$tableNameSingular}} record. // DeleteGP will match against the primary key column to find the record to delete. // Panics on error. func (o *{{$tableNameSingular}}) DeleteGP() { - if err := o.DeleteG(); err != nil { - panic(boil.WrapErr(err)) - } + if err := o.DeleteG(); err != nil { + panic(boil.WrapErr(err)) + } } // Delete deletes a single {{$tableNameSingular}} record with an executor. // Delete will match against the primary key column to find the record to delete. func (o *{{$tableNameSingular}}) Delete(exec boil.Executor) error { - if o == nil { - return errors.New("{{.PkgName}}: no {{$tableNameSingular}} provided for delete") - } + if o == nil { + return errors.New("{{.PkgName}}: no {{$tableNameSingular}} provided for delete") + } - {{if not .NoHooks -}} - if err := o.doBeforeDeleteHooks(exec); err != nil { - return err - } - {{- end}} + {{if not .NoHooks -}} + if err := o.doBeforeDeleteHooks(exec); err != nil { + return err + } + {{- end}} - args := o.inPrimaryKeyArgs() + args := o.inPrimaryKeyArgs() - sql := "DELETE FROM {{$schemaTable}} WHERE {{if .Dialect.IndexPlaceholders}}{{whereClause .LQ .RQ 1 .Table.PKey.Columns}}{{else}}{{whereClause .LQ .RQ 0 .Table.PKey.Columns}}{{end}}" + sql := "DELETE FROM {{$schemaTable}} WHERE {{if .Dialect.IndexPlaceholders}}{{whereClause .LQ .RQ 1 .Table.PKey.Columns}}{{else}}{{whereClause .LQ .RQ 0 .Table.PKey.Columns}}{{end}}" - if boil.DebugMode { - fmt.Fprintln(boil.DebugWriter, sql) - fmt.Fprintln(boil.DebugWriter, args...) - } + if boil.DebugMode { + fmt.Fprintln(boil.DebugWriter, sql) + fmt.Fprintln(boil.DebugWriter, args...) + } - _, err := exec.Exec(sql, args...) - if err != nil { - return errors.Wrap(err, "{{.PkgName}}: unable to delete from {{.Table.Name}}") - } + _, err := exec.Exec(sql, args...) + if err != nil { + return errors.Wrap(err, "{{.PkgName}}: unable to delete from {{.Table.Name}}") + } - {{if not .NoHooks -}} - if err := o.doAfterDeleteHooks(exec); err != nil { - return err - } - {{- end}} + {{if not .NoHooks -}} + if err := o.doAfterDeleteHooks(exec); err != nil { + return err + } + {{- end}} - return nil + return nil } // DeleteAllP deletes all rows, and panics on error. func (q {{$varNameSingular}}Query) DeleteAllP() { - if err := q.DeleteAll(); err != nil { - panic(boil.WrapErr(err)) - } + if err := q.DeleteAll(); err != nil { + panic(boil.WrapErr(err)) + } } // DeleteAll deletes all matching rows. func (q {{$varNameSingular}}Query) DeleteAll() error { - if q.Query == nil { - return errors.New("{{.PkgName}}: no {{$varNameSingular}}Query provided for delete all") - } + if q.Query == nil { + return errors.New("{{.PkgName}}: no {{$varNameSingular}}Query provided for delete all") + } - boil.SetDelete(q.Query) + boil.SetDelete(q.Query) - _, err := q.Query.ExecQuery() - if err != nil { - return errors.Wrap(err, "{{.PkgName}}: unable to delete all from {{.Table.Name}}") - } + _, err := q.Query.ExecQuery() + if err != nil { + return errors.Wrap(err, "{{.PkgName}}: unable to delete all from {{.Table.Name}}") + } - return nil + return nil } // DeleteAll deletes all rows in the slice, and panics on error. func (o {{$tableNameSingular}}Slice) DeleteAllGP() { - if err := o.DeleteAllG(); err != nil { - panic(boil.WrapErr(err)) - } + if err := o.DeleteAllG(); err != nil { + panic(boil.WrapErr(err)) + } } // DeleteAllG deletes all rows in the slice. func (o {{$tableNameSingular}}Slice) DeleteAllG() error { - if o == nil { - return errors.New("{{.PkgName}}: no {{$tableNameSingular}} slice provided for delete all") - } - return o.DeleteAll(boil.GetDB()) + if o == nil { + return errors.New("{{.PkgName}}: no {{$tableNameSingular}} slice provided for delete all") + } + return o.DeleteAll(boil.GetDB()) } // DeleteAllP deletes all rows in the slice, using an executor, and panics on error. func (o {{$tableNameSingular}}Slice) DeleteAllP(exec boil.Executor) { - if err := o.DeleteAll(exec); err != nil { - panic(boil.WrapErr(err)) - } + if err := o.DeleteAll(exec); err != nil { + panic(boil.WrapErr(err)) + } } // DeleteAll deletes all rows in the slice, using an executor. func (o {{$tableNameSingular}}Slice) DeleteAll(exec boil.Executor) error { - if o == nil { - return errors.New("{{.PkgName}}: no {{$tableNameSingular}} slice provided for delete all") - } - - if len(o) == 0 { - return nil - } - - {{if not .NoHooks -}} - if len({{$varNameSingular}}BeforeDeleteHooks) != 0 { - for _, obj := range o { - if err := obj.doBeforeDeleteHooks(exec); err != nil { - return err - } - } - } - {{- end}} - - args := o.inPrimaryKeyArgs() - - sql := fmt.Sprintf( - "DELETE FROM {{$schemaTable}} WHERE (%s) IN (%s)", - strings.Join(strmangle.IdentQuoteSlice(dialect.LQ, dialect.RQ, {{$varNameSingular}}PrimaryKeyColumns), ","), - strmangle.Placeholders(dialect.IndexPlaceholders, len(o) * len({{$varNameSingular}}PrimaryKeyColumns), 1, len({{$varNameSingular}}PrimaryKeyColumns)), - ) - - if boil.DebugMode { - fmt.Fprintln(boil.DebugWriter, sql) - fmt.Fprintln(boil.DebugWriter, args) - } - - _, err := exec.Exec(sql, args...) - if err != nil { - return errors.Wrap(err, "{{.PkgName}}: unable to delete all from {{$varNameSingular}} slice") - } - - {{if not .NoHooks -}} - if len({{$varNameSingular}}AfterDeleteHooks) != 0 { - for _, obj := range o { - if err := obj.doAfterDeleteHooks(exec); err != nil { - return err - } - } - } - {{- end}} - - return nil + if o == nil { + return errors.New("{{.PkgName}}: no {{$tableNameSingular}} slice provided for delete all") + } + + if len(o) == 0 { + return nil + } + + {{if not .NoHooks -}} + if len({{$varNameSingular}}BeforeDeleteHooks) != 0 { + for _, obj := range o { + if err := obj.doBeforeDeleteHooks(exec); err != nil { + return err + } + } + } + {{- end}} + + args := o.inPrimaryKeyArgs() + + sql := fmt.Sprintf( + "DELETE FROM {{$schemaTable}} WHERE (%s) IN (%s)", + strings.Join(strmangle.IdentQuoteSlice(dialect.LQ, dialect.RQ, {{$varNameSingular}}PrimaryKeyColumns), ","), + strmangle.Placeholders(dialect.IndexPlaceholders, len(o) * len({{$varNameSingular}}PrimaryKeyColumns), 1, len({{$varNameSingular}}PrimaryKeyColumns)), + ) + + if boil.DebugMode { + fmt.Fprintln(boil.DebugWriter, sql) + fmt.Fprintln(boil.DebugWriter, args) + } + + _, err := exec.Exec(sql, args...) + if err != nil { + return errors.Wrap(err, "{{.PkgName}}: unable to delete all from {{$varNameSingular}} slice") + } + + {{if not .NoHooks -}} + if len({{$varNameSingular}}AfterDeleteHooks) != 0 { + for _, obj := range o { + if err := obj.doAfterDeleteHooks(exec); err != nil { + return err + } + } + } + {{- end}} + + return nil } diff --git a/templates/16_reload.tpl b/templates/16_reload.tpl index 5277c9fcb..8a3bc2cd8 100644 --- a/templates/16_reload.tpl +++ b/templates/16_reload.tpl @@ -4,83 +4,83 @@ {{- $schemaTable := .Table.Name | .SchemaTable -}} // ReloadGP refetches the object from the database and panics on error. func (o *{{$tableNameSingular}}) ReloadGP() { - if err := o.ReloadG(); err != nil { - panic(boil.WrapErr(err)) - } + if err := o.ReloadG(); err != nil { + panic(boil.WrapErr(err)) + } } // ReloadP refetches the object from the database with an executor. Panics on error. func (o *{{$tableNameSingular}}) ReloadP(exec boil.Executor) { - if err := o.Reload(exec); err != nil { - panic(boil.WrapErr(err)) - } + if err := o.Reload(exec); err != nil { + panic(boil.WrapErr(err)) + } } // ReloadG refetches the object from the database using the primary keys. func (o *{{$tableNameSingular}}) ReloadG() error { - if o == nil { - return errors.New("{{.PkgName}}: no {{$tableNameSingular}} provided for reload") - } + if o == nil { + return errors.New("{{.PkgName}}: no {{$tableNameSingular}} provided for reload") + } - return o.Reload(boil.GetDB()) + return o.Reload(boil.GetDB()) } // Reload refetches the object from the database // using the primary keys with an executor. func (o *{{$tableNameSingular}}) Reload(exec boil.Executor) error { - ret, err := Find{{$tableNameSingular}}(exec, {{.Table.PKey.Columns | stringMap .StringFuncs.titleCase | prefixStringSlice "o." | join ", "}}) - if err != nil { - return err - } + ret, err := Find{{$tableNameSingular}}(exec, {{.Table.PKey.Columns | stringMap .StringFuncs.titleCase | prefixStringSlice "o." | join ", "}}) + if err != nil { + return err + } - *o = *ret - return nil + *o = *ret + return nil } func (o *{{$tableNameSingular}}Slice) ReloadAllGP() { - if err := o.ReloadAllG(); err != nil { - panic(boil.WrapErr(err)) - } + if err := o.ReloadAllG(); err != nil { + panic(boil.WrapErr(err)) + } } func (o *{{$tableNameSingular}}Slice) ReloadAllP(exec boil.Executor) { - if err := o.ReloadAll(exec); err != nil { - panic(boil.WrapErr(err)) - } + if err := o.ReloadAll(exec); err != nil { + panic(boil.WrapErr(err)) + } } func (o *{{$tableNameSingular}}Slice) ReloadAllG() error { - if o == nil { - return errors.New("{{.PkgName}}: empty {{$tableNameSingular}}Slice provided for reload all") - } + if o == nil { + return errors.New("{{.PkgName}}: empty {{$tableNameSingular}}Slice provided for reload all") + } - return o.ReloadAll(boil.GetDB()) + return o.ReloadAll(boil.GetDB()) } // ReloadAll refetches every row with matching primary key column values // and overwrites the original object slice with the newly updated slice. func (o *{{$tableNameSingular}}Slice) ReloadAll(exec boil.Executor) error { - if o == nil || len(*o) == 0 { - return nil - } + if o == nil || len(*o) == 0 { + return nil + } - {{$varNamePlural}} := {{$tableNameSingular}}Slice{} - args := o.inPrimaryKeyArgs() + {{$varNamePlural}} := {{$tableNameSingular}}Slice{} + args := o.inPrimaryKeyArgs() - sql := fmt.Sprintf( - "SELECT {{$schemaTable}}.* FROM {{$schemaTable}} WHERE (%s) IN (%s)", - strings.Join(strmangle.IdentQuoteSlice(dialect.LQ, dialect.RQ, {{$varNameSingular}}PrimaryKeyColumns), ","), - strmangle.Placeholders(dialect.IndexPlaceholders, len(*o) * len({{$varNameSingular}}PrimaryKeyColumns), 1, len({{$varNameSingular}}PrimaryKeyColumns)), - ) + sql := fmt.Sprintf( + "SELECT {{$schemaTable}}.* FROM {{$schemaTable}} WHERE (%s) IN (%s)", + strings.Join(strmangle.IdentQuoteSlice(dialect.LQ, dialect.RQ, {{$varNameSingular}}PrimaryKeyColumns), ","), + strmangle.Placeholders(dialect.IndexPlaceholders, len(*o) * len({{$varNameSingular}}PrimaryKeyColumns), 1, len({{$varNameSingular}}PrimaryKeyColumns)), + ) - q := boil.SQL(exec, sql, args...) + q := boil.SQL(exec, sql, args...) - err := q.Bind(&{{$varNamePlural}}) - if err != nil { - return errors.Wrap(err, "{{.PkgName}}: unable to reload all in {{$tableNameSingular}}Slice") - } + err := q.Bind(&{{$varNamePlural}}) + if err != nil { + return errors.Wrap(err, "{{.PkgName}}: unable to reload all in {{$tableNameSingular}}Slice") + } - *o = {{$varNamePlural}} + *o = {{$varNamePlural}} - return nil + return nil } diff --git a/templates/17_exists.tpl b/templates/17_exists.tpl index b64da2f9d..48709aec8 100644 --- a/templates/17_exists.tpl +++ b/templates/17_exists.tpl @@ -5,46 +5,46 @@ {{- $schemaTable := .Table.Name | .SchemaTable -}} // {{$tableNameSingular}}Exists checks if the {{$tableNameSingular}} row exists. func {{$tableNameSingular}}Exists(exec boil.Executor, {{$pkArgs}}) (bool, error) { - var exists bool + var exists bool - sql := "select exists(select 1 from {{$schemaTable}} where {{if .Dialect.IndexPlaceholders}}{{whereClause .LQ .RQ 1 .Table.PKey.Columns}}{{else}}{{whereClause .LQ .RQ 0 .Table.PKey.Columns}}{{end}} limit 1)" + sql := "select exists(select 1 from {{$schemaTable}} where {{if .Dialect.IndexPlaceholders}}{{whereClause .LQ .RQ 1 .Table.PKey.Columns}}{{else}}{{whereClause .LQ .RQ 0 .Table.PKey.Columns}}{{end}} limit 1)" - if boil.DebugMode { - fmt.Fprintln(boil.DebugWriter, sql) - fmt.Fprintln(boil.DebugWriter, {{$pkNames | join ", "}}) - } + if boil.DebugMode { + fmt.Fprintln(boil.DebugWriter, sql) + fmt.Fprintln(boil.DebugWriter, {{$pkNames | join ", "}}) + } - row := exec.QueryRow(sql, {{$pkNames | join ", "}}) + row := exec.QueryRow(sql, {{$pkNames | join ", "}}) - err := row.Scan(&exists) - if err != nil { - return false, errors.Wrap(err, "{{.PkgName}}: unable to check if {{.Table.Name}} exists") - } + err := row.Scan(&exists) + if err != nil { + return false, errors.Wrap(err, "{{.PkgName}}: unable to check if {{.Table.Name}} exists") + } - return exists, nil + return exists, nil } // {{$tableNameSingular}}ExistsG checks if the {{$tableNameSingular}} row exists. func {{$tableNameSingular}}ExistsG({{$pkArgs}}) (bool, error) { - return {{$tableNameSingular}}Exists(boil.GetDB(), {{$pkNames | join ", "}}) + return {{$tableNameSingular}}Exists(boil.GetDB(), {{$pkNames | join ", "}}) } // {{$tableNameSingular}}ExistsGP checks if the {{$tableNameSingular}} row exists. Panics on error. func {{$tableNameSingular}}ExistsGP({{$pkArgs}}) bool { - e, err := {{$tableNameSingular}}Exists(boil.GetDB(), {{$pkNames | join ", "}}) - if err != nil { - panic(boil.WrapErr(err)) - } + e, err := {{$tableNameSingular}}Exists(boil.GetDB(), {{$pkNames | join ", "}}) + if err != nil { + panic(boil.WrapErr(err)) + } - return e + return e } // {{$tableNameSingular}}ExistsP checks if the {{$tableNameSingular}} row exists. Panics on error. func {{$tableNameSingular}}ExistsP(exec boil.Executor, {{$pkArgs}}) bool { - e, err := {{$tableNameSingular}}Exists(exec, {{$pkNames | join ", "}}) - if err != nil { - panic(boil.WrapErr(err)) - } + e, err := {{$tableNameSingular}}Exists(exec, {{$pkNames | join ", "}}) + if err != nil { + panic(boil.WrapErr(err)) + } - return e + return e } diff --git a/templates/18_helpers.tpl b/templates/18_helpers.tpl index a9dd023d0..36618441a 100644 --- a/templates/18_helpers.tpl +++ b/templates/18_helpers.tpl @@ -1,23 +1,23 @@ {{- $varNameSingular := .Table.Name | singular | camelCase -}} {{- $tableNameSingular := .Table.Name | singular | titleCase -}} func (o {{$tableNameSingular}}) inPrimaryKeyArgs() []interface{} { - var args []interface{} + var args []interface{} - {{- range $key, $value := .Table.PKey.Columns }} - args = append(args, o.{{titleCase $value}}) - {{ end -}} + {{- range $key, $value := .Table.PKey.Columns }} + args = append(args, o.{{titleCase $value}}) + {{ end -}} - return args + return args } func (o {{$tableNameSingular}}Slice) inPrimaryKeyArgs() []interface{} { - var args []interface{} + var args []interface{} - for i := 0; i < len(o); i++ { - {{- range $key, $value := .Table.PKey.Columns }} - args = append(args, o[i].{{titleCase $value}}) - {{ end -}} - } + for i := 0; i < len(o); i++ { + {{- range $key, $value := .Table.PKey.Columns }} + args = append(args, o[i].{{titleCase $value}}) + {{ end -}} + } - return args + return args } diff --git a/templates/19_auto_timestamps.tpl b/templates/19_auto_timestamps.tpl index d600ccb77..fcb00d1cd 100644 --- a/templates/19_auto_timestamps.tpl +++ b/templates/19_auto_timestamps.tpl @@ -1,82 +1,82 @@ {{- define "timestamp_insert_helper" -}} - {{- if not .NoAutoTimestamps -}} - {{- $colNames := .Table.Columns | columnNames -}} - {{if containsAny $colNames "created_at" "updated_at"}} - currTime := time.Now().In(boil.GetLocation()) - {{range $ind, $col := .Table.Columns}} - {{- if eq $col.Name "created_at" -}} - {{- if $col.Nullable}} - if o.CreatedAt.Time.IsZero() { - o.CreatedAt.Time = currTime - o.CreatedAt.Valid = true - } - {{- else}} - if o.CreatedAt.IsZero() { - o.CreatedAt = currTime - } - {{- end -}} - {{- end -}} - {{- if eq $col.Name "updated_at" -}} - {{- if $col.Nullable}} - if o.UpdatedAt.Time.IsZero() { - o.UpdatedAt.Time = currTime - o.UpdatedAt.Valid = true - } - {{- else}} - if o.UpdatedAt.IsZero() { - o.UpdatedAt = currTime - } - {{- end -}} - {{- end -}} - {{end}} - {{end}} - {{- end}} + {{- if not .NoAutoTimestamps -}} + {{- $colNames := .Table.Columns | columnNames -}} + {{if containsAny $colNames "created_at" "updated_at"}} + currTime := time.Now().In(boil.GetLocation()) + {{range $ind, $col := .Table.Columns}} + {{- if eq $col.Name "created_at" -}} + {{- if $col.Nullable}} + if o.CreatedAt.Time.IsZero() { + o.CreatedAt.Time = currTime + o.CreatedAt.Valid = true + } + {{- else}} + if o.CreatedAt.IsZero() { + o.CreatedAt = currTime + } + {{- end -}} + {{- end -}} + {{- if eq $col.Name "updated_at" -}} + {{- if $col.Nullable}} + if o.UpdatedAt.Time.IsZero() { + o.UpdatedAt.Time = currTime + o.UpdatedAt.Valid = true + } + {{- else}} + if o.UpdatedAt.IsZero() { + o.UpdatedAt = currTime + } + {{- end -}} + {{- end -}} + {{end}} + {{end}} + {{- end}} {{- end -}} {{- define "timestamp_update_helper" -}} - {{- if not .NoAutoTimestamps -}} - {{- $colNames := .Table.Columns | columnNames -}} - {{if containsAny $colNames "updated_at"}} - currTime := time.Now().In(boil.GetLocation()) - {{range $ind, $col := .Table.Columns}} - {{- if eq $col.Name "updated_at" -}} - {{- if $col.Nullable}} - o.UpdatedAt.Time = currTime - o.UpdatedAt.Valid = true - {{- else}} - o.UpdatedAt = currTime - {{- end -}} - {{- end -}} - {{end}} - {{end}} - {{- end}} + {{- if not .NoAutoTimestamps -}} + {{- $colNames := .Table.Columns | columnNames -}} + {{if containsAny $colNames "updated_at"}} + currTime := time.Now().In(boil.GetLocation()) + {{range $ind, $col := .Table.Columns}} + {{- if eq $col.Name "updated_at" -}} + {{- if $col.Nullable}} + o.UpdatedAt.Time = currTime + o.UpdatedAt.Valid = true + {{- else}} + o.UpdatedAt = currTime + {{- end -}} + {{- end -}} + {{end}} + {{end}} + {{- end}} {{end -}} {{- define "timestamp_upsert_helper" -}} - {{- if not .NoAutoTimestamps -}} - {{- $colNames := .Table.Columns | columnNames -}} - {{if containsAny $colNames "created_at" "updated_at"}} - currTime := time.Now().In(boil.GetLocation()) - {{range $ind, $col := .Table.Columns}} - {{- if eq $col.Name "created_at" -}} - {{- if $col.Nullable}} - if o.CreatedAt.Time.IsZero() { - o.CreatedAt.Time = currTime - o.CreatedAt.Valid = true - } - {{- else}} - if o.CreatedAt.IsZero() { - o.CreatedAt = currTime - } - {{- end -}} - {{- end -}} - {{- if eq $col.Name "updated_at" -}} - {{- if $col.Nullable}} - o.UpdatedAt.Time = currTime - o.UpdatedAt.Valid = true - {{- else}} - o.UpdatedAt = currTime - {{- end -}} - {{- end -}} - {{end}} - {{end}} - {{- end}} + {{- if not .NoAutoTimestamps -}} + {{- $colNames := .Table.Columns | columnNames -}} + {{if containsAny $colNames "created_at" "updated_at"}} + currTime := time.Now().In(boil.GetLocation()) + {{range $ind, $col := .Table.Columns}} + {{- if eq $col.Name "created_at" -}} + {{- if $col.Nullable}} + if o.CreatedAt.Time.IsZero() { + o.CreatedAt.Time = currTime + o.CreatedAt.Valid = true + } + {{- else}} + if o.CreatedAt.IsZero() { + o.CreatedAt = currTime + } + {{- end -}} + {{- end -}} + {{- if eq $col.Name "updated_at" -}} + {{- if $col.Nullable}} + o.UpdatedAt.Time = currTime + o.UpdatedAt.Valid = true + {{- else}} + o.UpdatedAt = currTime + {{- end -}} + {{- end -}} + {{end}} + {{end}} + {{- end}} {{end -}} diff --git a/templates/singleton/boil_queries.tpl b/templates/singleton/boil_queries.tpl index 4db7187fa..8326f24d2 100644 --- a/templates/singleton/boil_queries.tpl +++ b/templates/singleton/boil_queries.tpl @@ -1,20 +1,20 @@ var dialect = boil.Dialect{ - LQ: 0x{{printf "%x" .Dialect.LQ}}, - RQ: 0x{{printf "%x" .Dialect.RQ}}, - IndexPlaceholders: {{.Dialect.IndexPlaceholders}}, + LQ: 0x{{printf "%x" .Dialect.LQ}}, + RQ: 0x{{printf "%x" .Dialect.RQ}}, + IndexPlaceholders: {{.Dialect.IndexPlaceholders}}, } // NewQueryG initializes a new Query using the passed in QueryMods func NewQueryG(mods ...qm.QueryMod) *boil.Query { - return NewQuery(boil.GetDB(), mods...) + return NewQuery(boil.GetDB(), mods...) } // NewQuery initializes a new Query using the passed in QueryMods func NewQuery(exec boil.Executor, mods ...qm.QueryMod) *boil.Query { - q := &boil.Query{} - boil.SetExecutor(q, exec) - boil.SetDialect(q, &dialect) - qm.Apply(q, mods...) + q := &boil.Query{} + boil.SetExecutor(q, exec) + boil.SetDialect(q, &dialect) + qm.Apply(q, mods...) - return q + return q } diff --git a/templates/singleton/boil_types.tpl b/templates/singleton/boil_types.tpl index 143a18c39..5e3282826 100644 --- a/templates/singleton/boil_types.tpl +++ b/templates/singleton/boil_types.tpl @@ -7,32 +7,32 @@ type M map[string]interface{} var ErrSyncFail = errors.New("{{.PkgName}}: failed to synchronize data after insert") type insertCache struct{ - query string - retQuery string - valueMapping []uint64 - retMapping []uint64 + query string + retQuery string + valueMapping []uint64 + retMapping []uint64 } type updateCache struct{ - query string - valueMapping []uint64 + query string + valueMapping []uint64 } func makeCacheKey(wl, nzDefaults []string) string { - buf := strmangle.GetBuffer() + buf := strmangle.GetBuffer() - for _, w := range wl { - buf.WriteString(w) - } - if len(nzDefaults) != 0 { - buf.WriteByte('.') - } - for _, nz := range nzDefaults { - buf.WriteString(nz) - } + for _, w := range wl { + buf.WriteString(w) + } + if len(nzDefaults) != 0 { + buf.WriteByte('.') + } + for _, nz := range nzDefaults { + buf.WriteString(nz) + } - str := buf.String() - strmangle.PutBuffer(buf) - return str + str := buf.String() + strmangle.PutBuffer(buf) + return str } diff --git a/templates_test/all.tpl b/templates_test/all.tpl index 001c32299..532801395 100644 --- a/templates_test/all.tpl +++ b/templates_test/all.tpl @@ -3,11 +3,11 @@ {{- $varNamePlural := .Table.Name | plural | camelCase -}} {{- $varNameSingular := .Table.Name | singular | camelCase -}} func test{{$tableNamePlural}}(t *testing.T) { - t.Parallel() + t.Parallel() - query := {{$tableNamePlural}}(nil) + query := {{$tableNamePlural}}(nil) - if query.Query == nil { - t.Error("expected a query, got nothing") - } + if query.Query == nil { + t.Error("expected a query, got nothing") + } } diff --git a/templates_test/delete.tpl b/templates_test/delete.tpl index f6dfaca4d..f745ea48d 100644 --- a/templates_test/delete.tpl +++ b/templates_test/delete.tpl @@ -3,93 +3,93 @@ {{- $varNamePlural := .Table.Name | plural | camelCase -}} {{- $varNameSingular := .Table.Name | singular | camelCase -}} func test{{$tableNamePlural}}Delete(t *testing.T) { - t.Parallel() - - seed := randomize.NewSeed() - var err error - {{$varNameSingular}} := &{{$tableNameSingular}}{} - if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true); err != nil { - t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) - } - - tx := MustTx(boil.Begin()) - defer tx.Rollback() - if err = {{$varNameSingular}}.Insert(tx); err != nil { - t.Error(err) - } - - if err = {{$varNameSingular}}.Delete(tx); err != nil { - t.Error(err) - } - - count, err := {{$tableNamePlural}}(tx).Count() - if err != nil { - t.Error(err) - } - - if count != 0 { - t.Error("want zero records, got:", count) - } + t.Parallel() + + seed := randomize.NewSeed() + var err error + {{$varNameSingular}} := &{{$tableNameSingular}}{} + if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true); err != nil { + t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) + } + + tx := MustTx(boil.Begin()) + defer tx.Rollback() + if err = {{$varNameSingular}}.Insert(tx); err != nil { + t.Error(err) + } + + if err = {{$varNameSingular}}.Delete(tx); err != nil { + t.Error(err) + } + + count, err := {{$tableNamePlural}}(tx).Count() + if err != nil { + t.Error(err) + } + + if count != 0 { + t.Error("want zero records, got:", count) + } } func test{{$tableNamePlural}}QueryDeleteAll(t *testing.T) { - t.Parallel() - - seed := randomize.NewSeed() - var err error - {{$varNameSingular}} := &{{$tableNameSingular}}{} - if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true); err != nil { - t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) - } - - tx := MustTx(boil.Begin()) - defer tx.Rollback() - if err = {{$varNameSingular}}.Insert(tx); err != nil { - t.Error(err) - } - - if err = {{$tableNamePlural}}(tx).DeleteAll(); err != nil { - t.Error(err) - } - - count, err := {{$tableNamePlural}}(tx).Count() - if err != nil { - t.Error(err) - } - - if count != 0 { - t.Error("want zero records, got:", count) - } + t.Parallel() + + seed := randomize.NewSeed() + var err error + {{$varNameSingular}} := &{{$tableNameSingular}}{} + if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true); err != nil { + t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) + } + + tx := MustTx(boil.Begin()) + defer tx.Rollback() + if err = {{$varNameSingular}}.Insert(tx); err != nil { + t.Error(err) + } + + if err = {{$tableNamePlural}}(tx).DeleteAll(); err != nil { + t.Error(err) + } + + count, err := {{$tableNamePlural}}(tx).Count() + if err != nil { + t.Error(err) + } + + if count != 0 { + t.Error("want zero records, got:", count) + } } func test{{$tableNamePlural}}SliceDeleteAll(t *testing.T) { - t.Parallel() - - seed := randomize.NewSeed() - var err error - {{$varNameSingular}} := &{{$tableNameSingular}}{} - if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true); err != nil { - t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) - } - - tx := MustTx(boil.Begin()) - defer tx.Rollback() - if err = {{$varNameSingular}}.Insert(tx); err != nil { - t.Error(err) - } - - slice := {{$tableNameSingular}}Slice{{"{"}}{{$varNameSingular}}{{"}"}} - - if err = slice.DeleteAll(tx); err != nil { - t.Error(err) - } - - count, err := {{$tableNamePlural}}(tx).Count() - if err != nil { - t.Error(err) - } - - if count != 0 { - t.Error("want zero records, got:", count) - } + t.Parallel() + + seed := randomize.NewSeed() + var err error + {{$varNameSingular}} := &{{$tableNameSingular}}{} + if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true); err != nil { + t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) + } + + tx := MustTx(boil.Begin()) + defer tx.Rollback() + if err = {{$varNameSingular}}.Insert(tx); err != nil { + t.Error(err) + } + + slice := {{$tableNameSingular}}Slice{{"{"}}{{$varNameSingular}}{{"}"}} + + if err = slice.DeleteAll(tx); err != nil { + t.Error(err) + } + + count, err := {{$tableNamePlural}}(tx).Count() + if err != nil { + t.Error(err) + } + + if count != 0 { + t.Error("want zero records, got:", count) + } } diff --git a/templates_test/exists.tpl b/templates_test/exists.tpl index 36089214d..30bbcfbf3 100644 --- a/templates_test/exists.tpl +++ b/templates_test/exists.tpl @@ -3,27 +3,27 @@ {{- $varNamePlural := .Table.Name | plural | camelCase -}} {{- $varNameSingular := .Table.Name | singular | camelCase -}} func test{{$tableNamePlural}}Exists(t *testing.T) { - t.Parallel() + t.Parallel() - seed := randomize.NewSeed() - var err error - {{$varNameSingular}} := &{{$tableNameSingular}}{} - if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil { - t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) - } + seed := randomize.NewSeed() + var err error + {{$varNameSingular}} := &{{$tableNameSingular}}{} + if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil { + t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) + } - tx := MustTx(boil.Begin()) - defer tx.Rollback() - if err = {{$varNameSingular}}.Insert(tx); err != nil { - t.Error(err) - } + tx := MustTx(boil.Begin()) + defer tx.Rollback() + if err = {{$varNameSingular}}.Insert(tx); err != nil { + t.Error(err) + } - {{$pkeyArgs := .Table.PKey.Columns | stringMap .StringFuncs.titleCase | prefixStringSlice (printf "%s." $varNameSingular) | join ", " -}} - e, err := {{$tableNameSingular}}Exists(tx, {{$pkeyArgs}}) - if err != nil { - t.Errorf("Unable to check if {{$tableNameSingular}} exists: %s", err) - } - if e != true { - t.Errorf("Expected {{$tableNameSingular}}ExistsG to return true, but got false.") - } + {{$pkeyArgs := .Table.PKey.Columns | stringMap .StringFuncs.titleCase | prefixStringSlice (printf "%s." $varNameSingular) | join ", " -}} + e, err := {{$tableNameSingular}}Exists(tx, {{$pkeyArgs}}) + if err != nil { + t.Errorf("Unable to check if {{$tableNameSingular}} exists: %s", err) + } + if e != true { + t.Errorf("Expected {{$tableNameSingular}}ExistsG to return true, but got false.") + } } diff --git a/templates_test/find.tpl b/templates_test/find.tpl index 2da3fdadc..cd3ea9677 100644 --- a/templates_test/find.tpl +++ b/templates_test/find.tpl @@ -3,27 +3,27 @@ {{- $varNamePlural := .Table.Name | plural | camelCase -}} {{- $varNameSingular := .Table.Name | singular | camelCase -}} func test{{$tableNamePlural}}Find(t *testing.T) { - t.Parallel() + t.Parallel() - seed := randomize.NewSeed() - var err error - {{$varNameSingular}} := &{{$tableNameSingular}}{} - if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil { - t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) - } + seed := randomize.NewSeed() + var err error + {{$varNameSingular}} := &{{$tableNameSingular}}{} + if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil { + t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) + } - tx := MustTx(boil.Begin()) - defer tx.Rollback() - if err = {{$varNameSingular}}.Insert(tx); err != nil { - t.Error(err) - } + tx := MustTx(boil.Begin()) + defer tx.Rollback() + if err = {{$varNameSingular}}.Insert(tx); err != nil { + t.Error(err) + } - {{$varNameSingular}}Found, err := Find{{$tableNameSingular}}(tx, {{.Table.PKey.Columns | stringMap .StringFuncs.titleCase | prefixStringSlice (printf "%s." $varNameSingular) | join ", "}}) - if err != nil { - t.Error(err) - } + {{$varNameSingular}}Found, err := Find{{$tableNameSingular}}(tx, {{.Table.PKey.Columns | stringMap .StringFuncs.titleCase | prefixStringSlice (printf "%s." $varNameSingular) | join ", "}}) + if err != nil { + t.Error(err) + } - if {{$varNameSingular}}Found == nil { - t.Error("want a record, got nil") - } + if {{$varNameSingular}}Found == nil { + t.Error("want a record, got nil") + } } diff --git a/templates_test/finishers.tpl b/templates_test/finishers.tpl index f7cff0d3a..fa8b129d2 100644 --- a/templates_test/finishers.tpl +++ b/templates_test/finishers.tpl @@ -3,111 +3,111 @@ {{- $varNamePlural := .Table.Name | plural | camelCase -}} {{- $varNameSingular := .Table.Name | singular | camelCase -}} func test{{$tableNamePlural}}Bind(t *testing.T) { - t.Parallel() - - seed := randomize.NewSeed() - var err error - {{$varNameSingular}} := &{{$tableNameSingular}}{} - if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil { - t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) - } - - tx := MustTx(boil.Begin()) - defer tx.Rollback() - if err = {{$varNameSingular}}.Insert(tx); err != nil { - t.Error(err) - } - - if err = {{$tableNamePlural}}(tx).Bind({{$varNameSingular}}); err != nil { - t.Error(err) - } + t.Parallel() + + seed := randomize.NewSeed() + var err error + {{$varNameSingular}} := &{{$tableNameSingular}}{} + if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil { + t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) + } + + tx := MustTx(boil.Begin()) + defer tx.Rollback() + if err = {{$varNameSingular}}.Insert(tx); err != nil { + t.Error(err) + } + + if err = {{$tableNamePlural}}(tx).Bind({{$varNameSingular}}); err != nil { + t.Error(err) + } } func test{{$tableNamePlural}}One(t *testing.T) { - t.Parallel() - - seed := randomize.NewSeed() - var err error - {{$varNameSingular}} := &{{$tableNameSingular}}{} - if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil { - t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) - } - - tx := MustTx(boil.Begin()) - defer tx.Rollback() - if err = {{$varNameSingular}}.Insert(tx); err != nil { - t.Error(err) - } - - if x, err := {{$tableNamePlural}}(tx).One(); err != nil { - t.Error(err) - } else if x == nil { - t.Error("expected to get a non nil record") - } + t.Parallel() + + seed := randomize.NewSeed() + var err error + {{$varNameSingular}} := &{{$tableNameSingular}}{} + if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil { + t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) + } + + tx := MustTx(boil.Begin()) + defer tx.Rollback() + if err = {{$varNameSingular}}.Insert(tx); err != nil { + t.Error(err) + } + + if x, err := {{$tableNamePlural}}(tx).One(); err != nil { + t.Error(err) + } else if x == nil { + t.Error("expected to get a non nil record") + } } func test{{$tableNamePlural}}All(t *testing.T) { - t.Parallel() - - seed := randomize.NewSeed() - var err error - {{$varNameSingular}}One := &{{$tableNameSingular}}{} - {{$varNameSingular}}Two := &{{$tableNameSingular}}{} - if err = randomize.Struct(seed, {{$varNameSingular}}One, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}ColumnsWithDefault...); err != nil { - t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) - } - if err = randomize.Struct(seed, {{$varNameSingular}}Two, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}ColumnsWithDefault...); err != nil { - t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) - } - - tx := MustTx(boil.Begin()) - defer tx.Rollback() - if err = {{$varNameSingular}}One.Insert(tx); err != nil { - t.Error(err) - } - if err = {{$varNameSingular}}Two.Insert(tx); err != nil { - t.Error(err) - } - - slice, err := {{$tableNamePlural}}(tx).All() - if err != nil { - t.Error(err) - } - - if len(slice) != 2 { - t.Error("want 2 records, got:", len(slice)) - } + t.Parallel() + + seed := randomize.NewSeed() + var err error + {{$varNameSingular}}One := &{{$tableNameSingular}}{} + {{$varNameSingular}}Two := &{{$tableNameSingular}}{} + if err = randomize.Struct(seed, {{$varNameSingular}}One, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}ColumnsWithDefault...); err != nil { + t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) + } + if err = randomize.Struct(seed, {{$varNameSingular}}Two, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}ColumnsWithDefault...); err != nil { + t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) + } + + tx := MustTx(boil.Begin()) + defer tx.Rollback() + if err = {{$varNameSingular}}One.Insert(tx); err != nil { + t.Error(err) + } + if err = {{$varNameSingular}}Two.Insert(tx); err != nil { + t.Error(err) + } + + slice, err := {{$tableNamePlural}}(tx).All() + if err != nil { + t.Error(err) + } + + if len(slice) != 2 { + t.Error("want 2 records, got:", len(slice)) + } } func test{{$tableNamePlural}}Count(t *testing.T) { - t.Parallel() - - var err error - seed := randomize.NewSeed() - {{$varNameSingular}}One := &{{$tableNameSingular}}{} - {{$varNameSingular}}Two := &{{$tableNameSingular}}{} - if err = randomize.Struct(seed, {{$varNameSingular}}One, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}ColumnsWithDefault...); err != nil { - t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) - } - if err = randomize.Struct(seed, {{$varNameSingular}}Two, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}ColumnsWithDefault...); err != nil { - t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) - } - - tx := MustTx(boil.Begin()) - defer tx.Rollback() - if err = {{$varNameSingular}}One.Insert(tx); err != nil { - t.Error(err) - } - if err = {{$varNameSingular}}Two.Insert(tx); err != nil { - t.Error(err) - } - - count, err := {{$tableNamePlural}}(tx).Count() - if err != nil { - t.Error(err) - } - - if count != 2 { - t.Error("want 2 records, got:", count) - } + t.Parallel() + + var err error + seed := randomize.NewSeed() + {{$varNameSingular}}One := &{{$tableNameSingular}}{} + {{$varNameSingular}}Two := &{{$tableNameSingular}}{} + if err = randomize.Struct(seed, {{$varNameSingular}}One, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}ColumnsWithDefault...); err != nil { + t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) + } + if err = randomize.Struct(seed, {{$varNameSingular}}Two, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}ColumnsWithDefault...); err != nil { + t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) + } + + tx := MustTx(boil.Begin()) + defer tx.Rollback() + if err = {{$varNameSingular}}One.Insert(tx); err != nil { + t.Error(err) + } + if err = {{$varNameSingular}}Two.Insert(tx); err != nil { + t.Error(err) + } + + count, err := {{$tableNamePlural}}(tx).Count() + if err != nil { + t.Error(err) + } + + if count != 2 { + t.Error("want 2 records, got:", count) + } } diff --git a/templates_test/helpers.tpl b/templates_test/helpers.tpl index c20dbd0d9..09f9bff86 100644 --- a/templates_test/helpers.tpl +++ b/templates_test/helpers.tpl @@ -5,57 +5,57 @@ var {{$varNameSingular}}DBTypes = map[string]string{{"{"}}{{.Table.Columns | columnDBTypes | makeStringMap}}{{"}"}} func test{{$tableNamePlural}}InPrimaryKeyArgs(t *testing.T) { - t.Parallel() + t.Parallel() - var err error - var o {{$tableNameSingular}} - o = {{$tableNameSingular}}{} + var err error + var o {{$tableNameSingular}} + o = {{$tableNameSingular}}{} - seed := randomize.NewSeed() - if err = randomize.Struct(seed, &o, {{$varNameSingular}}DBTypes, true); err != nil { - t.Errorf("Could not randomize struct: %s", err) - } + seed := randomize.NewSeed() + if err = randomize.Struct(seed, &o, {{$varNameSingular}}DBTypes, true); err != nil { + t.Errorf("Could not randomize struct: %s", err) + } - args := o.inPrimaryKeyArgs() + args := o.inPrimaryKeyArgs() - if len(args) != len({{$varNameSingular}}PrimaryKeyColumns) { - t.Errorf("Expected args to be len %d, but got %d", len({{$varNameSingular}}PrimaryKeyColumns), len(args)) - } + if len(args) != len({{$varNameSingular}}PrimaryKeyColumns) { + t.Errorf("Expected args to be len %d, but got %d", len({{$varNameSingular}}PrimaryKeyColumns), len(args)) + } - {{range $key, $value := .Table.PKey.Columns}} - if o.{{titleCase $value}} != args[{{$key}}] { - t.Errorf("Expected args[{{$key}}] to be value of o.{{titleCase $value}}, but got %#v", args[{{$key}}]) - } - {{- end}} + {{range $key, $value := .Table.PKey.Columns}} + if o.{{titleCase $value}} != args[{{$key}}] { + t.Errorf("Expected args[{{$key}}] to be value of o.{{titleCase $value}}, but got %#v", args[{{$key}}]) + } + {{- end}} } func test{{$tableNamePlural}}SliceInPrimaryKeyArgs(t *testing.T) { - t.Parallel() - - var err error - o := make({{$tableNameSingular}}Slice, 3) - - seed := randomize.NewSeed() - for i := range o { - o[i] = &{{$tableNameSingular}}{} - if err = randomize.Struct(seed, o[i], {{$varNameSingular}}DBTypes, true); err != nil { - t.Errorf("Could not randomize struct: %s", err) - } - } - - args := o.inPrimaryKeyArgs() - - if len(args) != len({{$varNameSingular}}PrimaryKeyColumns) * 3 { - t.Errorf("Expected args to be len %d, but got %d", len({{$varNameSingular}}PrimaryKeyColumns) * 3, len(args)) - } - - argC := 0 - for i := 0; i < 3; i++ { - {{range $key, $value := .Table.PKey.Columns}} - if o[i].{{titleCase $value}} != args[argC] { - t.Errorf("Expected args[%d] to be value of o.{{titleCase $value}}, but got %#v", i, args[i]) - } - argC++ - {{- end}} - } + t.Parallel() + + var err error + o := make({{$tableNameSingular}}Slice, 3) + + seed := randomize.NewSeed() + for i := range o { + o[i] = &{{$tableNameSingular}}{} + if err = randomize.Struct(seed, o[i], {{$varNameSingular}}DBTypes, true); err != nil { + t.Errorf("Could not randomize struct: %s", err) + } + } + + args := o.inPrimaryKeyArgs() + + if len(args) != len({{$varNameSingular}}PrimaryKeyColumns) * 3 { + t.Errorf("Expected args to be len %d, but got %d", len({{$varNameSingular}}PrimaryKeyColumns) * 3, len(args)) + } + + argC := 0 + for i := 0; i < 3; i++ { + {{range $key, $value := .Table.PKey.Columns}} + if o[i].{{titleCase $value}} != args[argC] { + t.Errorf("Expected args[%d] to be value of o.{{titleCase $value}}, but got %#v", i, args[i]) + } + argC++ + {{- end}} + } } diff --git a/templates_test/hooks.tpl b/templates_test/hooks.tpl index dd6a4bb64..22dc84c7e 100644 --- a/templates_test/hooks.tpl +++ b/templates_test/hooks.tpl @@ -4,142 +4,142 @@ {{- $varNamePlural := .Table.Name | plural | camelCase -}} {{- $varNameSingular := .Table.Name | singular | camelCase -}} func {{$varNameSingular}}BeforeInsertHook(e boil.Executor, o *{{$tableNameSingular}}) error { - *o = {{$tableNameSingular}}{} - return nil + *o = {{$tableNameSingular}}{} + return nil } func {{$varNameSingular}}AfterInsertHook(e boil.Executor, o *{{$tableNameSingular}}) error { - *o = {{$tableNameSingular}}{} - return nil + *o = {{$tableNameSingular}}{} + return nil } func {{$varNameSingular}}AfterSelectHook(e boil.Executor, o *{{$tableNameSingular}}) error { - *o = {{$tableNameSingular}}{} - return nil + *o = {{$tableNameSingular}}{} + return nil } func {{$varNameSingular}}BeforeUpdateHook(e boil.Executor, o *{{$tableNameSingular}}) error { - *o = {{$tableNameSingular}}{} - return nil + *o = {{$tableNameSingular}}{} + return nil } func {{$varNameSingular}}AfterUpdateHook(e boil.Executor, o *{{$tableNameSingular}}) error { - *o = {{$tableNameSingular}}{} - return nil + *o = {{$tableNameSingular}}{} + return nil } func {{$varNameSingular}}BeforeDeleteHook(e boil.Executor, o *{{$tableNameSingular}}) error { - *o = {{$tableNameSingular}}{} - return nil + *o = {{$tableNameSingular}}{} + return nil } func {{$varNameSingular}}AfterDeleteHook(e boil.Executor, o *{{$tableNameSingular}}) error { - *o = {{$tableNameSingular}}{} - return nil + *o = {{$tableNameSingular}}{} + return nil } func {{$varNameSingular}}BeforeUpsertHook(e boil.Executor, o *{{$tableNameSingular}}) error { - *o = {{$tableNameSingular}}{} - return nil + *o = {{$tableNameSingular}}{} + return nil } func {{$varNameSingular}}AfterUpsertHook(e boil.Executor, o *{{$tableNameSingular}}) error { - *o = {{$tableNameSingular}}{} - return nil + *o = {{$tableNameSingular}}{} + return nil } func test{{$tableNamePlural}}Hooks(t *testing.T) { - t.Parallel() - - var err error - - empty := &{{$tableNameSingular}}{} - o := &{{$tableNameSingular}}{} - - seed := randomize.NewSeed() - if err = randomize.Struct(seed, o, {{$varNameSingular}}DBTypes, false); err != nil { - t.Errorf("Unable to randomize {{$tableNameSingular}} object: %s", err) - } - - Add{{$tableNameSingular}}Hook(boil.BeforeInsertHook, {{$varNameSingular}}BeforeInsertHook) - if err = o.doBeforeInsertHooks(nil); err != nil { - t.Errorf("Unable to execute doBeforeInsertHooks: %s", err) - } - if !reflect.DeepEqual(o, empty) { - t.Errorf("Expected BeforeInsertHook function to empty object, but got: %#v", o) - } - {{$varNameSingular}}BeforeInsertHooks = []{{$tableNameSingular}}Hook{} - - Add{{$tableNameSingular}}Hook(boil.AfterInsertHook, {{$varNameSingular}}AfterInsertHook) - if err = o.doAfterInsertHooks(nil); err != nil { - t.Errorf("Unable to execute doAfterInsertHooks: %s", err) - } - if !reflect.DeepEqual(o, empty) { - t.Errorf("Expected AfterInsertHook function to empty object, but got: %#v", o) - } - {{$varNameSingular}}AfterInsertHooks = []{{$tableNameSingular}}Hook{} - - Add{{$tableNameSingular}}Hook(boil.AfterSelectHook, {{$varNameSingular}}AfterSelectHook) - if err = o.doAfterSelectHooks(nil); err != nil { - t.Errorf("Unable to execute doAfterSelectHooks: %s", err) - } - if !reflect.DeepEqual(o, empty) { - t.Errorf("Expected AfterSelectHook function to empty object, but got: %#v", o) - } - {{$varNameSingular}}AfterSelectHooks = []{{$tableNameSingular}}Hook{} - - Add{{$tableNameSingular}}Hook(boil.BeforeUpdateHook, {{$varNameSingular}}BeforeUpdateHook) - if err = o.doBeforeUpdateHooks(nil); err != nil { - t.Errorf("Unable to execute doBeforeUpdateHooks: %s", err) - } - if !reflect.DeepEqual(o, empty) { - t.Errorf("Expected BeforeUpdateHook function to empty object, but got: %#v", o) - } - {{$varNameSingular}}BeforeUpdateHooks = []{{$tableNameSingular}}Hook{} - - Add{{$tableNameSingular}}Hook(boil.AfterUpdateHook, {{$varNameSingular}}AfterUpdateHook) - if err = o.doAfterUpdateHooks(nil); err != nil { - t.Errorf("Unable to execute doAfterUpdateHooks: %s", err) - } - if !reflect.DeepEqual(o, empty) { - t.Errorf("Expected AfterUpdateHook function to empty object, but got: %#v", o) - } - {{$varNameSingular}}AfterUpdateHooks = []{{$tableNameSingular}}Hook{} - - Add{{$tableNameSingular}}Hook(boil.BeforeDeleteHook, {{$varNameSingular}}BeforeDeleteHook) - if err = o.doBeforeDeleteHooks(nil); err != nil { - t.Errorf("Unable to execute doBeforeDeleteHooks: %s", err) - } - if !reflect.DeepEqual(o, empty) { - t.Errorf("Expected BeforeDeleteHook function to empty object, but got: %#v", o) - } - {{$varNameSingular}}BeforeDeleteHooks = []{{$tableNameSingular}}Hook{} - - Add{{$tableNameSingular}}Hook(boil.AfterDeleteHook, {{$varNameSingular}}AfterDeleteHook) - if err = o.doAfterDeleteHooks(nil); err != nil { - t.Errorf("Unable to execute doAfterDeleteHooks: %s", err) - } - if !reflect.DeepEqual(o, empty) { - t.Errorf("Expected AfterDeleteHook function to empty object, but got: %#v", o) - } - {{$varNameSingular}}AfterDeleteHooks = []{{$tableNameSingular}}Hook{} - - Add{{$tableNameSingular}}Hook(boil.BeforeUpsertHook, {{$varNameSingular}}BeforeUpsertHook) - if err = o.doBeforeUpsertHooks(nil); err != nil { - t.Errorf("Unable to execute doBeforeUpsertHooks: %s", err) - } - if !reflect.DeepEqual(o, empty) { - t.Errorf("Expected BeforeUpsertHook function to empty object, but got: %#v", o) - } - {{$varNameSingular}}BeforeUpsertHooks = []{{$tableNameSingular}}Hook{} - - Add{{$tableNameSingular}}Hook(boil.AfterUpsertHook, {{$varNameSingular}}AfterUpsertHook) - if err = o.doAfterUpsertHooks(nil); err != nil { - t.Errorf("Unable to execute doAfterUpsertHooks: %s", err) - } - if !reflect.DeepEqual(o, empty) { - t.Errorf("Expected AfterUpsertHook function to empty object, but got: %#v", o) - } - {{$varNameSingular}}AfterUpsertHooks = []{{$tableNameSingular}}Hook{} + t.Parallel() + + var err error + + empty := &{{$tableNameSingular}}{} + o := &{{$tableNameSingular}}{} + + seed := randomize.NewSeed() + if err = randomize.Struct(seed, o, {{$varNameSingular}}DBTypes, false); err != nil { + t.Errorf("Unable to randomize {{$tableNameSingular}} object: %s", err) + } + + Add{{$tableNameSingular}}Hook(boil.BeforeInsertHook, {{$varNameSingular}}BeforeInsertHook) + if err = o.doBeforeInsertHooks(nil); err != nil { + t.Errorf("Unable to execute doBeforeInsertHooks: %s", err) + } + if !reflect.DeepEqual(o, empty) { + t.Errorf("Expected BeforeInsertHook function to empty object, but got: %#v", o) + } + {{$varNameSingular}}BeforeInsertHooks = []{{$tableNameSingular}}Hook{} + + Add{{$tableNameSingular}}Hook(boil.AfterInsertHook, {{$varNameSingular}}AfterInsertHook) + if err = o.doAfterInsertHooks(nil); err != nil { + t.Errorf("Unable to execute doAfterInsertHooks: %s", err) + } + if !reflect.DeepEqual(o, empty) { + t.Errorf("Expected AfterInsertHook function to empty object, but got: %#v", o) + } + {{$varNameSingular}}AfterInsertHooks = []{{$tableNameSingular}}Hook{} + + Add{{$tableNameSingular}}Hook(boil.AfterSelectHook, {{$varNameSingular}}AfterSelectHook) + if err = o.doAfterSelectHooks(nil); err != nil { + t.Errorf("Unable to execute doAfterSelectHooks: %s", err) + } + if !reflect.DeepEqual(o, empty) { + t.Errorf("Expected AfterSelectHook function to empty object, but got: %#v", o) + } + {{$varNameSingular}}AfterSelectHooks = []{{$tableNameSingular}}Hook{} + + Add{{$tableNameSingular}}Hook(boil.BeforeUpdateHook, {{$varNameSingular}}BeforeUpdateHook) + if err = o.doBeforeUpdateHooks(nil); err != nil { + t.Errorf("Unable to execute doBeforeUpdateHooks: %s", err) + } + if !reflect.DeepEqual(o, empty) { + t.Errorf("Expected BeforeUpdateHook function to empty object, but got: %#v", o) + } + {{$varNameSingular}}BeforeUpdateHooks = []{{$tableNameSingular}}Hook{} + + Add{{$tableNameSingular}}Hook(boil.AfterUpdateHook, {{$varNameSingular}}AfterUpdateHook) + if err = o.doAfterUpdateHooks(nil); err != nil { + t.Errorf("Unable to execute doAfterUpdateHooks: %s", err) + } + if !reflect.DeepEqual(o, empty) { + t.Errorf("Expected AfterUpdateHook function to empty object, but got: %#v", o) + } + {{$varNameSingular}}AfterUpdateHooks = []{{$tableNameSingular}}Hook{} + + Add{{$tableNameSingular}}Hook(boil.BeforeDeleteHook, {{$varNameSingular}}BeforeDeleteHook) + if err = o.doBeforeDeleteHooks(nil); err != nil { + t.Errorf("Unable to execute doBeforeDeleteHooks: %s", err) + } + if !reflect.DeepEqual(o, empty) { + t.Errorf("Expected BeforeDeleteHook function to empty object, but got: %#v", o) + } + {{$varNameSingular}}BeforeDeleteHooks = []{{$tableNameSingular}}Hook{} + + Add{{$tableNameSingular}}Hook(boil.AfterDeleteHook, {{$varNameSingular}}AfterDeleteHook) + if err = o.doAfterDeleteHooks(nil); err != nil { + t.Errorf("Unable to execute doAfterDeleteHooks: %s", err) + } + if !reflect.DeepEqual(o, empty) { + t.Errorf("Expected AfterDeleteHook function to empty object, but got: %#v", o) + } + {{$varNameSingular}}AfterDeleteHooks = []{{$tableNameSingular}}Hook{} + + Add{{$tableNameSingular}}Hook(boil.BeforeUpsertHook, {{$varNameSingular}}BeforeUpsertHook) + if err = o.doBeforeUpsertHooks(nil); err != nil { + t.Errorf("Unable to execute doBeforeUpsertHooks: %s", err) + } + if !reflect.DeepEqual(o, empty) { + t.Errorf("Expected BeforeUpsertHook function to empty object, but got: %#v", o) + } + {{$varNameSingular}}BeforeUpsertHooks = []{{$tableNameSingular}}Hook{} + + Add{{$tableNameSingular}}Hook(boil.AfterUpsertHook, {{$varNameSingular}}AfterUpsertHook) + if err = o.doAfterUpsertHooks(nil); err != nil { + t.Errorf("Unable to execute doAfterUpsertHooks: %s", err) + } + if !reflect.DeepEqual(o, empty) { + t.Errorf("Expected AfterUpsertHook function to empty object, but got: %#v", o) + } + {{$varNameSingular}}AfterUpsertHooks = []{{$tableNameSingular}}Hook{} } {{- end}} diff --git a/templates_test/insert.tpl b/templates_test/insert.tpl index 63898dac1..d14a0c827 100644 --- a/templates_test/insert.tpl +++ b/templates_test/insert.tpl @@ -4,53 +4,53 @@ {{- $varNameSingular := .Table.Name | singular | camelCase -}} {{- $parent := . -}} func test{{$tableNamePlural}}Insert(t *testing.T) { - t.Parallel() - - seed := randomize.NewSeed() - var err error - {{$varNameSingular}} := &{{$tableNameSingular}}{} - if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil { - t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) - } - - tx := MustTx(boil.Begin()) - defer tx.Rollback() - if err = {{$varNameSingular}}.Insert(tx); err != nil { - t.Error(err) - } - - count, err := {{$tableNamePlural}}(tx).Count() - if err != nil { - t.Error(err) - } - - if count != 1 { - t.Error("want one record, got:", count) - } + t.Parallel() + + seed := randomize.NewSeed() + var err error + {{$varNameSingular}} := &{{$tableNameSingular}}{} + if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil { + t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) + } + + tx := MustTx(boil.Begin()) + defer tx.Rollback() + if err = {{$varNameSingular}}.Insert(tx); err != nil { + t.Error(err) + } + + count, err := {{$tableNamePlural}}(tx).Count() + if err != nil { + t.Error(err) + } + + if count != 1 { + t.Error("want one record, got:", count) + } } func test{{$tableNamePlural}}InsertWhitelist(t *testing.T) { - t.Parallel() - - seed := randomize.NewSeed() - var err error - {{$varNameSingular}} := &{{$tableNameSingular}}{} - if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true); err != nil { - t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) - } - - tx := MustTx(boil.Begin()) - defer tx.Rollback() - if err = {{$varNameSingular}}.Insert(tx, {{$varNameSingular}}Columns...); err != nil { - t.Error(err) - } - - count, err := {{$tableNamePlural}}(tx).Count() - if err != nil { - t.Error(err) - } - - if count != 1 { - t.Error("want one record, got:", count) - } + t.Parallel() + + seed := randomize.NewSeed() + var err error + {{$varNameSingular}} := &{{$tableNameSingular}}{} + if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true); err != nil { + t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) + } + + tx := MustTx(boil.Begin()) + defer tx.Rollback() + if err = {{$varNameSingular}}.Insert(tx, {{$varNameSingular}}Columns...); err != nil { + t.Error(err) + } + + count, err := {{$tableNamePlural}}(tx).Count() + if err != nil { + t.Error(err) + } + + if count != 1 { + t.Error("want one record, got:", count) + } } diff --git a/templates_test/main_test/mysql_main.tpl b/templates_test/main_test/mysql_main.tpl index 6ba0f415a..849aa9834 100644 --- a/templates_test/main_test/mysql_main.tpl +++ b/templates_test/main_test/mysql_main.tpl @@ -1,157 +1,166 @@ type mysqlTester struct { - dbConn *sql.DB + dbConn *sql.DB - dbName string - host string - user string - pass string - sslmode string - port int + dbName string + host string + user string + pass string + sslmode string + port int - optionFile string + optionFile string - testDBName string + testDBName string } func init() { - dbMain = &mysqlTester{} + dbMain = &mysqlTester{} } func (m *mysqlTester) setup() error { - var err error - - m.dbName = viper.GetString("mysql.dbname") - m.host = viper.GetString("mysql.host") - m.user = viper.GetString("mysql.user") - m.pass = viper.GetString("mysql.pass") - m.port = viper.GetInt("mysql.port") - m.sslmode = viper.GetString("mysql.sslmode") - // Create a randomized db name. - m.testDBName = randomize.StableDBName(m.dbName) - - if err = m.makeOptionFile(); err != nil { - return errors.Wrap(err, "couldn't make option file") - } - - if err = m.dropTestDB(); err != nil { - return err - } - if err = m.createTestDB(); err != nil { - return err - } - - dumpCmd := exec.Command("mysqldump", m.defaultsFile(), m.dbName) - createCmd := exec.Command("mysql", m.defaultsFile(), "--database", m.testDBName) - - r, w := io.Pipe() - dumpCmd.Stdout = w - createCmd.Stdin = newFKeyDestroyer(rgxMySQLkey, r) - - if err = dumpCmd.Start(); err != nil { - return errors.Wrap(err, "failed to start mysqldump command") - } - if err = createCmd.Start(); err != nil { - return errors.Wrap(err, "failed to start mysql command") - } - - if err = dumpCmd.Wait(); err != nil { - fmt.Println(err) - return errors.Wrap(err, "failed to wait for mysqldump command") - } - - w.Close() // After dumpCmd is done, close the write end of the pipe - - if err = createCmd.Wait(); err != nil { - fmt.Println(err) - return errors.Wrap(err, "failed to wait for mysql command") - } - - return nil + var err error + + m.dbName = viper.GetString("mysql.dbname") + m.host = viper.GetString("mysql.host") + m.user = viper.GetString("mysql.user") + m.pass = viper.GetString("mysql.pass") + m.port = viper.GetInt("mysql.port") + m.sslmode = viper.GetString("mysql.sslmode") + // Create a randomized db name. + m.testDBName = randomize.StableDBName(m.dbName) + + if err = m.makeOptionFile(); err != nil { + return errors.Wrap(err, "couldn't make option file") + } + + if err = m.dropTestDB(); err != nil { + return err + } + if err = m.createTestDB(); err != nil { + return err + } + + dumpCmd := exec.Command("mysqldump", m.defaultsFile(), m.dbName) + createCmd := exec.Command("mysql", m.defaultsFile(), "--database", m.testDBName) + + r, w := io.Pipe() + dumpCmd.Stdout = w + createCmd.Stdin = newFKeyDestroyer(rgxMySQLkey, r) + + if err = dumpCmd.Start(); err != nil { + return errors.Wrap(err, "failed to start mysqldump command") + } + if err = createCmd.Start(); err != nil { + return errors.Wrap(err, "failed to start mysql command") + } + + if err = dumpCmd.Wait(); err != nil { + fmt.Println(err) + return errors.Wrap(err, "failed to wait for mysqldump command") + } + + w.Close() // After dumpCmd is done, close the write end of the pipe + + if err = createCmd.Wait(); err != nil { + fmt.Println(err) + return errors.Wrap(err, "failed to wait for mysql command") + } + + return nil +} + +func (m *mysqlTester) sslMode(mode string) string { + switch mode { + case "true": + return "REQUIRED" + case "false": + return "DISABLED" + default: + return "PREFERRED" + } } func (m *mysqlTester) defaultsFile() string { - return fmt.Sprintf("--defaults-file=%s", m.optionFile) + return fmt.Sprintf("--defaults-file=%s", m.optionFile) } func (m *mysqlTester) makeOptionFile() error { - tmp, err := ioutil.TempFile("", "optionfile") - if err != nil { - return errors.Wrap(err, "failed to create option file") - } - - fmt.Fprintln(tmp, "[client]") - fmt.Fprintf(tmp, "host=%s\n", m.host) - fmt.Fprintf(tmp, "port=%d\n", m.port) - fmt.Fprintf(tmp, "user=%s\n", m.user) - fmt.Fprintf(tmp, "password=%s\n", m.pass) - // BUG: SSL Mode for whatever reason is backwards in the mysql driver - // taking options like true or false, but here taking options like - // required/disabled. Until this gets sorted, ignore this. - //fmt.Fprintf("ssl-mode=%s\n", m.password) - - fmt.Fprintln(tmp, "[mysqldump]") - fmt.Fprintf(tmp, "host=%s\n", m.host) - fmt.Fprintf(tmp, "port=%d\n", m.port) - fmt.Fprintf(tmp, "user=%s\n", m.user) - fmt.Fprintf(tmp, "password=%s\n", m.pass) - - m.optionFile = tmp.Name() - - return tmp.Close() + tmp, err := ioutil.TempFile("", "optionfile") + if err != nil { + return errors.Wrap(err, "failed to create option file") + } + + fmt.Fprintln(tmp, "[client]") + fmt.Fprintf(tmp, "host=%s\n", m.host) + fmt.Fprintf(tmp, "port=%d\n", m.port) + fmt.Fprintf(tmp, "user=%s\n", m.user) + fmt.Fprintf(tmp, "password=%s\n", m.pass) + fmt.Fprintf(tmp, "ssl-mode=%s\n", m.sslMode(m.sslmode)) + + fmt.Fprintln(tmp, "[mysqldump]") + fmt.Fprintf(tmp, "host=%s\n", m.host) + fmt.Fprintf(tmp, "port=%d\n", m.port) + fmt.Fprintf(tmp, "user=%s\n", m.user) + fmt.Fprintf(tmp, "password=%s\n", m.pass) + fmt.Fprintf(tmp, "ssl-mode=%s\n", m.sslMode(m.sslmode)) + + m.optionFile = tmp.Name() + + return tmp.Close() } func (m *mysqlTester) createTestDB() error { - sql := fmt.Sprintf("create database %s;", m.testDBName) - return m.runCmd(sql, "mysql") + sql := fmt.Sprintf("create database %s;", m.testDBName) + return m.runCmd(sql, "mysql") } func (m *mysqlTester) dropTestDB() error { - sql := fmt.Sprintf("drop database if exists %s;", m.testDBName) - return m.runCmd(sql, "mysql") + sql := fmt.Sprintf("drop database if exists %s;", m.testDBName) + return m.runCmd(sql, "mysql") } func (m *mysqlTester) teardown() error { - if m.dbConn != nil { - m.dbConn.Close() - } + if m.dbConn != nil { + m.dbConn.Close() + } - if err := m.dropTestDB(); err != nil { - return err - } + if err := m.dropTestDB(); err != nil { + return err + } - return os.Remove(m.optionFile) + return os.Remove(m.optionFile) } func (m *mysqlTester) runCmd(stdin, command string, args ...string) error { - args = append([]string{m.defaultsFile()}, args...) - - cmd := exec.Command(command, args...) - cmd.Stdin = strings.NewReader(stdin) - - stdout := &bytes.Buffer{} - stderr := &bytes.Buffer{} - cmd.Stdout = stdout - cmd.Stderr = stderr - if err := cmd.Run(); err != nil { - fmt.Println("failed running:", command, args) - fmt.Println(stdout.String()) - fmt.Println(stderr.String()) - return err - } - - return nil + args = append([]string{m.defaultsFile()}, args...) + + cmd := exec.Command(command, args...) + cmd.Stdin = strings.NewReader(stdin) + + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + cmd.Stdout = stdout + cmd.Stderr = stderr + if err := cmd.Run(); err != nil { + fmt.Println("failed running:", command, args) + fmt.Println(stdout.String()) + fmt.Println(stderr.String()) + return err + } + + return nil } func (m *mysqlTester) conn() (*sql.DB, error) { - if m.dbConn != nil { - return m.dbConn, nil - } + if m.dbConn != nil { + return m.dbConn, nil + } - var err error - m.dbConn, err = sql.Open("mysql", drivers.MySQLBuildQueryString(m.user, m.pass, m.testDBName, m.host, m.port, m.sslmode)) - if err != nil { - return nil, err - } + var err error + m.dbConn, err = sql.Open("mysql", drivers.MySQLBuildQueryString(m.user, m.pass, m.testDBName, m.host, m.port, m.sslmode)) + if err != nil { + return nil, err + } - return m.dbConn, nil + return m.dbConn, nil } diff --git a/templates_test/relationship_to_many.tpl b/templates_test/relationship_to_many.tpl index b147890bd..cade6cd0a 100644 --- a/templates_test/relationship_to_many.tpl +++ b/templates_test/relationship_to_many.tpl @@ -1,98 +1,98 @@ {{- if .Table.IsJoinTable -}} {{- else -}} - {{- $dot := . }} - {{- $table := .Table }} - {{- range .Table.ToManyRelationships -}} - {{- if (and .ForeignColumnUnique (not .ToJoinTable)) -}} + {{- $dot := . }} + {{- $table := .Table }} + {{- range .Table.ToManyRelationships -}} + {{- if (and .ForeignColumnUnique (not .ToJoinTable)) -}} {{- template "relationship_to_one_test_helper" (textsFromOneToOneRelationship $dot.PkgName $dot.Tables $table .) -}} - {{- else -}} - {{- $rel := textsFromRelationship $dot.Tables $table . -}} + {{- else -}} + {{- $rel := textsFromRelationship $dot.Tables $table . -}} func test{{$rel.LocalTable.NameGo}}ToMany{{$rel.Function.Name}}(t *testing.T) { - var err error - tx := MustTx(boil.Begin()) - defer tx.Rollback() + var err error + tx := MustTx(boil.Begin()) + defer tx.Rollback() - var a {{$rel.LocalTable.NameGo}} - var b, c {{$rel.ForeignTable.NameGo}} + var a {{$rel.LocalTable.NameGo}} + var b, c {{$rel.ForeignTable.NameGo}} - if err := a.Insert(tx); err != nil { - t.Fatal(err) - } + if err := a.Insert(tx); err != nil { + t.Fatal(err) + } - seed := randomize.NewSeed() - randomize.Struct(seed, &b, {{$rel.ForeignTable.NameSingular | camelCase}}DBTypes, false, "{{.ForeignColumn}}") - randomize.Struct(seed, &c, {{$rel.ForeignTable.NameSingular | camelCase}}DBTypes, false, "{{.ForeignColumn}}") - {{if .Nullable -}} - a.{{.Column | titleCase}}.Valid = true - {{- end}} - {{- if .ForeignColumnNullable -}} - b.{{.ForeignColumn | titleCase}}.Valid = true - c.{{.ForeignColumn | titleCase}}.Valid = true - {{- end}} - {{if not .ToJoinTable -}} - b.{{$rel.Function.ForeignAssignment}} = a.{{$rel.Function.LocalAssignment}} - c.{{$rel.Function.ForeignAssignment}} = a.{{$rel.Function.LocalAssignment}} - {{- end}} - if err = b.Insert(tx); err != nil { - t.Fatal(err) - } - if err = c.Insert(tx); err != nil { - t.Fatal(err) - } + seed := randomize.NewSeed() + randomize.Struct(seed, &b, {{$rel.ForeignTable.NameSingular | camelCase}}DBTypes, false, "{{.ForeignColumn}}") + randomize.Struct(seed, &c, {{$rel.ForeignTable.NameSingular | camelCase}}DBTypes, false, "{{.ForeignColumn}}") + {{if .Nullable -}} + a.{{.Column | titleCase}}.Valid = true + {{- end}} + {{- if .ForeignColumnNullable -}} + b.{{.ForeignColumn | titleCase}}.Valid = true + c.{{.ForeignColumn | titleCase}}.Valid = true + {{- end}} + {{if not .ToJoinTable -}} + b.{{$rel.Function.ForeignAssignment}} = a.{{$rel.Function.LocalAssignment}} + c.{{$rel.Function.ForeignAssignment}} = a.{{$rel.Function.LocalAssignment}} + {{- end}} + if err = b.Insert(tx); err != nil { + t.Fatal(err) + } + if err = c.Insert(tx); err != nil { + t.Fatal(err) + } - {{if .ToJoinTable -}} - _, err = tx.Exec("insert into {{.JoinTable | $dot.SchemaTable}} ({{.JoinLocalColumn | $dot.Quotes}}, {{.JoinForeignColumn | $dot.Quotes}}) values {{if $dot.Dialect.IndexPlaceholders}}($1, $2){{else}}(?, ?){{end}}", a.{{$rel.LocalTable.ColumnNameGo}}, b.{{$rel.ForeignTable.ColumnNameGo}}) - if err != nil { - t.Fatal(err) - } - _, err = tx.Exec("insert into {{.JoinTable | $dot.SchemaTable}} ({{.JoinLocalColumn | $dot.Quotes}}, {{.JoinForeignColumn | $dot.Quotes}}) values {{if $dot.Dialect.IndexPlaceholders}}($1, $2){{else}}(?, ?){{end}}", a.{{$rel.LocalTable.ColumnNameGo}}, c.{{$rel.ForeignTable.ColumnNameGo}}) - if err != nil { - t.Fatal(err) - } - {{end}} + {{if .ToJoinTable -}} + _, err = tx.Exec("insert into {{.JoinTable | $dot.SchemaTable}} ({{.JoinLocalColumn | $dot.Quotes}}, {{.JoinForeignColumn | $dot.Quotes}}) values {{if $dot.Dialect.IndexPlaceholders}}($1, $2){{else}}(?, ?){{end}}", a.{{$rel.LocalTable.ColumnNameGo}}, b.{{$rel.ForeignTable.ColumnNameGo}}) + if err != nil { + t.Fatal(err) + } + _, err = tx.Exec("insert into {{.JoinTable | $dot.SchemaTable}} ({{.JoinLocalColumn | $dot.Quotes}}, {{.JoinForeignColumn | $dot.Quotes}}) values {{if $dot.Dialect.IndexPlaceholders}}($1, $2){{else}}(?, ?){{end}}", a.{{$rel.LocalTable.ColumnNameGo}}, c.{{$rel.ForeignTable.ColumnNameGo}}) + if err != nil { + t.Fatal(err) + } + {{end}} - {{$varname := .ForeignTable | singular | camelCase -}} - {{$varname}}, err := a.{{$rel.Function.Name}}(tx).All() - if err != nil { - t.Fatal(err) - } + {{$varname := .ForeignTable | singular | camelCase -}} + {{$varname}}, err := a.{{$rel.Function.Name}}(tx).All() + if err != nil { + t.Fatal(err) + } - bFound, cFound := false, false - for _, v := range {{$varname}} { - if v.{{$rel.Function.ForeignAssignment}} == b.{{$rel.Function.ForeignAssignment}} { - bFound = true - } - if v.{{$rel.Function.ForeignAssignment}} == c.{{$rel.Function.ForeignAssignment}} { - cFound = true - } - } + bFound, cFound := false, false + for _, v := range {{$varname}} { + if v.{{$rel.Function.ForeignAssignment}} == b.{{$rel.Function.ForeignAssignment}} { + bFound = true + } + if v.{{$rel.Function.ForeignAssignment}} == c.{{$rel.Function.ForeignAssignment}} { + cFound = true + } + } - if !bFound { - t.Error("expected to find b") - } - if !cFound { - t.Error("expected to find c") - } + if !bFound { + t.Error("expected to find b") + } + if !cFound { + t.Error("expected to find c") + } - slice := {{$rel.LocalTable.NameGo}}Slice{&a} - if err = a.L.Load{{$rel.Function.Name}}(tx, false, &slice); err != nil { - t.Fatal(err) - } - if got := len(a.R.{{$rel.Function.Name}}); got != 2 { - t.Error("number of eager loaded records wrong, got:", got) - } + slice := {{$rel.LocalTable.NameGo}}Slice{&a} + if err = a.L.Load{{$rel.Function.Name}}(tx, false, &slice); err != nil { + t.Fatal(err) + } + if got := len(a.R.{{$rel.Function.Name}}); got != 2 { + t.Error("number of eager loaded records wrong, got:", got) + } - a.R.{{$rel.Function.Name}} = nil - if err = a.L.Load{{$rel.Function.Name}}(tx, true, &a); err != nil { - t.Fatal(err) - } - if got := len(a.R.{{$rel.Function.Name}}); got != 2 { - t.Error("number of eager loaded records wrong, got:", got) - } + a.R.{{$rel.Function.Name}} = nil + if err = a.L.Load{{$rel.Function.Name}}(tx, true, &a); err != nil { + t.Fatal(err) + } + if got := len(a.R.{{$rel.Function.Name}}); got != 2 { + t.Error("number of eager loaded records wrong, got:", got) + } - if t.Failed() { - t.Logf("%#v", {{$varname}}) - } + if t.Failed() { + t.Logf("%#v", {{$varname}}) + } } {{end -}}{{- /* if unique */ -}} diff --git a/templates_test/relationship_to_many_setops.tpl b/templates_test/relationship_to_many_setops.tpl index 22e1ea792..e653d3672 100644 --- a/templates_test/relationship_to_many_setops.tpl +++ b/templates_test/relationship_to_many_setops.tpl @@ -1,306 +1,306 @@ {{- if .Table.IsJoinTable -}} {{- else -}} - {{- $dot := . -}} - {{- $table := .Table -}} - {{- range .Table.ToManyRelationships -}} - {{- if (and .ForeignColumnUnique (not .ToJoinTable)) -}} + {{- $dot := . -}} + {{- $table := .Table -}} + {{- range .Table.ToManyRelationships -}} + {{- if (and .ForeignColumnUnique (not .ToJoinTable)) -}} {{- template "relationship_to_one_setops_test_helper" (textsFromOneToOneRelationship $dot.PkgName $dot.Tables $table .) -}} - {{- else -}} - {{- $varNameSingular := .Table | singular | camelCase -}} - {{- $foreignVarNameSingular := .ForeignTable | singular | camelCase -}} - {{- $rel := textsFromRelationship $dot.Tables $table .}} + {{- else -}} + {{- $varNameSingular := .Table | singular | camelCase -}} + {{- $foreignVarNameSingular := .ForeignTable | singular | camelCase -}} + {{- $rel := textsFromRelationship $dot.Tables $table .}} func test{{$rel.LocalTable.NameGo}}ToManyAddOp{{$rel.Function.Name}}(t *testing.T) { - var err error - - tx := MustTx(boil.Begin()) - defer tx.Rollback() - - var a {{$rel.LocalTable.NameGo}} - var b, c, d, e {{$rel.ForeignTable.NameGo}} - - seed := randomize.NewSeed() - if err = randomize.Struct(seed, &a, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}PrimaryKeyColumns...); err != nil { - t.Fatal(err) - } - foreigners := []*{{$rel.ForeignTable.NameGo}}{&b, &c, &d, &e} - for _, x := range foreigners { - if err = randomize.Struct(seed, x, {{$foreignVarNameSingular}}DBTypes, false, {{$foreignVarNameSingular}}PrimaryKeyColumns...); err != nil { - t.Fatal(err) - } - } - - if err := a.Insert(tx); err != nil { - t.Fatal(err) - } - if err = b.Insert(tx); err != nil { - t.Fatal(err) - } - if err = c.Insert(tx); err != nil { - t.Fatal(err) - } - - foreignersSplitByInsertion := [][]*{{$rel.ForeignTable.NameGo}}{ - {&b, &c}, - {&d, &e}, - } - - for i, x := range foreignersSplitByInsertion { - err = a.Add{{$rel.Function.Name}}(tx, i != 0, x...) - if err != nil { - t.Fatal(err) - } - - first := x[0] - second := x[1] - {{- if .ToJoinTable}} - - if first.R.{{$rel.Function.ForeignName}}[0] != &a { - t.Error("relationship was not added properly to the slice") - } - if second.R.{{$rel.Function.ForeignName}}[0] != &a { - t.Error("relationship was not added properly to the slice") - } - {{- else}} - - if a.{{$rel.Function.LocalAssignment}} != first.{{$rel.Function.ForeignAssignment}} { - t.Error("foreign key was wrong value", a.{{$rel.Function.LocalAssignment}}, first.{{$rel.Function.ForeignAssignment}}) - } - if a.{{$rel.Function.LocalAssignment}} != second.{{$rel.Function.ForeignAssignment}} { - t.Error("foreign key was wrong value", a.{{$rel.Function.LocalAssignment}}, second.{{$rel.Function.ForeignAssignment}}) - } - - if first.R.{{$rel.Function.ForeignName}} != &a { - t.Error("relationship was not added properly to the foreign slice") - } - if second.R.{{$rel.Function.ForeignName}} != &a { - t.Error("relationship was not added properly to the foreign slice") - } - {{- end}} - - if a.R.{{$rel.Function.Name}}[i*2] != first { - t.Error("relationship struct slice not set to correct value") - } - if a.R.{{$rel.Function.Name}}[i*2+1] != second { - t.Error("relationship struct slice not set to correct value") - } - - count, err := a.{{$rel.Function.Name}}(tx).Count() - if err != nil { - t.Fatal(err) - } - if want := int64((i+1)*2); count != want { - t.Error("want", want, "got", count) - } - } + var err error + + tx := MustTx(boil.Begin()) + defer tx.Rollback() + + var a {{$rel.LocalTable.NameGo}} + var b, c, d, e {{$rel.ForeignTable.NameGo}} + + seed := randomize.NewSeed() + if err = randomize.Struct(seed, &a, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}PrimaryKeyColumns...); err != nil { + t.Fatal(err) + } + foreigners := []*{{$rel.ForeignTable.NameGo}}{&b, &c, &d, &e} + for _, x := range foreigners { + if err = randomize.Struct(seed, x, {{$foreignVarNameSingular}}DBTypes, false, {{$foreignVarNameSingular}}PrimaryKeyColumns...); err != nil { + t.Fatal(err) + } + } + + if err := a.Insert(tx); err != nil { + t.Fatal(err) + } + if err = b.Insert(tx); err != nil { + t.Fatal(err) + } + if err = c.Insert(tx); err != nil { + t.Fatal(err) + } + + foreignersSplitByInsertion := [][]*{{$rel.ForeignTable.NameGo}}{ + {&b, &c}, + {&d, &e}, + } + + for i, x := range foreignersSplitByInsertion { + err = a.Add{{$rel.Function.Name}}(tx, i != 0, x...) + if err != nil { + t.Fatal(err) + } + + first := x[0] + second := x[1] + {{- if .ToJoinTable}} + + if first.R.{{$rel.Function.ForeignName}}[0] != &a { + t.Error("relationship was not added properly to the slice") + } + if second.R.{{$rel.Function.ForeignName}}[0] != &a { + t.Error("relationship was not added properly to the slice") + } + {{- else}} + + if a.{{$rel.Function.LocalAssignment}} != first.{{$rel.Function.ForeignAssignment}} { + t.Error("foreign key was wrong value", a.{{$rel.Function.LocalAssignment}}, first.{{$rel.Function.ForeignAssignment}}) + } + if a.{{$rel.Function.LocalAssignment}} != second.{{$rel.Function.ForeignAssignment}} { + t.Error("foreign key was wrong value", a.{{$rel.Function.LocalAssignment}}, second.{{$rel.Function.ForeignAssignment}}) + } + + if first.R.{{$rel.Function.ForeignName}} != &a { + t.Error("relationship was not added properly to the foreign slice") + } + if second.R.{{$rel.Function.ForeignName}} != &a { + t.Error("relationship was not added properly to the foreign slice") + } + {{- end}} + + if a.R.{{$rel.Function.Name}}[i*2] != first { + t.Error("relationship struct slice not set to correct value") + } + if a.R.{{$rel.Function.Name}}[i*2+1] != second { + t.Error("relationship struct slice not set to correct value") + } + + count, err := a.{{$rel.Function.Name}}(tx).Count() + if err != nil { + t.Fatal(err) + } + if want := int64((i+1)*2); count != want { + t.Error("want", want, "got", count) + } + } } {{- if (or .ForeignColumnNullable .ToJoinTable)}} func test{{$rel.LocalTable.NameGo}}ToManySetOp{{$rel.Function.Name}}(t *testing.T) { - var err error - - tx := MustTx(boil.Begin()) - defer tx.Rollback() - - var a {{$rel.LocalTable.NameGo}} - var b, c, d, e {{$rel.ForeignTable.NameGo}} - - seed := randomize.NewSeed() - if err = randomize.Struct(seed, &a, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}PrimaryKeyColumns...); err != nil { - t.Fatal(err) - } - foreigners := []*{{$rel.ForeignTable.NameGo}}{&b, &c, &d, &e} - for _, x := range foreigners { - if err = randomize.Struct(seed, x, {{$foreignVarNameSingular}}DBTypes, false, {{$foreignVarNameSingular}}PrimaryKeyColumns...); err != nil { - t.Fatal(err) - } - } - - if err = a.Insert(tx); err != nil { - t.Fatal(err) - } - if err = b.Insert(tx); err != nil { - t.Fatal(err) - } - if err = c.Insert(tx); err != nil { - t.Fatal(err) - } - - err = a.Set{{$rel.Function.Name}}(tx, false, &b, &c) - if err != nil { - t.Fatal(err) - } - - count, err := a.{{$rel.Function.Name}}(tx).Count() - if err != nil { - t.Fatal(err) - } - if count != 2 { - t.Error("count was wrong:", count) - } - - err = a.Set{{$rel.Function.Name}}(tx, true, &d, &e) - if err != nil { - t.Fatal(err) - } - - count, err = a.{{$rel.Function.Name}}(tx).Count() - if err != nil { - t.Fatal(err) - } - if count != 2 { - t.Error("count was wrong:", count) - } - - {{- if .ToJoinTable}} - - if len(b.R.{{$rel.Function.ForeignName}}) != 0 { - t.Error("relationship was not removed properly from the slice") - } - if len(c.R.{{$rel.Function.ForeignName}}) != 0 { - t.Error("relationship was not removed properly from the slice") - } - if d.R.{{$rel.Function.ForeignName}}[0] != &a { - t.Error("relationship was not added properly to the slice") - } - if e.R.{{$rel.Function.ForeignName}}[0] != &a { - t.Error("relationship was not added properly to the slice") - } - {{- else}} - - if b.{{$rel.ForeignTable.ColumnNameGo}}.Valid { - t.Error("want b's foreign key value to be nil") - } - if c.{{$rel.ForeignTable.ColumnNameGo}}.Valid { - t.Error("want c's foreign key value to be nil") - } - if a.{{$rel.Function.LocalAssignment}} != d.{{$rel.Function.ForeignAssignment}} { - t.Error("foreign key was wrong value", a.{{$rel.Function.LocalAssignment}}, d.{{$rel.Function.ForeignAssignment}}) - } - if a.{{$rel.Function.LocalAssignment}} != e.{{$rel.Function.ForeignAssignment}} { - t.Error("foreign key was wrong value", a.{{$rel.Function.LocalAssignment}}, e.{{$rel.Function.ForeignAssignment}}) - } - - if b.R.{{$rel.Function.ForeignName}} != nil { - t.Error("relationship was not removed properly from the foreign struct") - } - if c.R.{{$rel.Function.ForeignName}} != nil { - t.Error("relationship was not removed properly from the foreign struct") - } - if d.R.{{$rel.Function.ForeignName}} != &a { - t.Error("relationship was not added properly to the foreign struct") - } - if e.R.{{$rel.Function.ForeignName}} != &a { - t.Error("relationship was not added properly to the foreign struct") - } - {{- end}} - - if a.R.{{$rel.Function.Name}}[0] != &d { - t.Error("relationship struct slice not set to correct value") - } - if a.R.{{$rel.Function.Name}}[1] != &e { - t.Error("relationship struct slice not set to correct value") - } + var err error + + tx := MustTx(boil.Begin()) + defer tx.Rollback() + + var a {{$rel.LocalTable.NameGo}} + var b, c, d, e {{$rel.ForeignTable.NameGo}} + + seed := randomize.NewSeed() + if err = randomize.Struct(seed, &a, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}PrimaryKeyColumns...); err != nil { + t.Fatal(err) + } + foreigners := []*{{$rel.ForeignTable.NameGo}}{&b, &c, &d, &e} + for _, x := range foreigners { + if err = randomize.Struct(seed, x, {{$foreignVarNameSingular}}DBTypes, false, {{$foreignVarNameSingular}}PrimaryKeyColumns...); err != nil { + t.Fatal(err) + } + } + + if err = a.Insert(tx); err != nil { + t.Fatal(err) + } + if err = b.Insert(tx); err != nil { + t.Fatal(err) + } + if err = c.Insert(tx); err != nil { + t.Fatal(err) + } + + err = a.Set{{$rel.Function.Name}}(tx, false, &b, &c) + if err != nil { + t.Fatal(err) + } + + count, err := a.{{$rel.Function.Name}}(tx).Count() + if err != nil { + t.Fatal(err) + } + if count != 2 { + t.Error("count was wrong:", count) + } + + err = a.Set{{$rel.Function.Name}}(tx, true, &d, &e) + if err != nil { + t.Fatal(err) + } + + count, err = a.{{$rel.Function.Name}}(tx).Count() + if err != nil { + t.Fatal(err) + } + if count != 2 { + t.Error("count was wrong:", count) + } + + {{- if .ToJoinTable}} + + if len(b.R.{{$rel.Function.ForeignName}}) != 0 { + t.Error("relationship was not removed properly from the slice") + } + if len(c.R.{{$rel.Function.ForeignName}}) != 0 { + t.Error("relationship was not removed properly from the slice") + } + if d.R.{{$rel.Function.ForeignName}}[0] != &a { + t.Error("relationship was not added properly to the slice") + } + if e.R.{{$rel.Function.ForeignName}}[0] != &a { + t.Error("relationship was not added properly to the slice") + } + {{- else}} + + if b.{{$rel.ForeignTable.ColumnNameGo}}.Valid { + t.Error("want b's foreign key value to be nil") + } + if c.{{$rel.ForeignTable.ColumnNameGo}}.Valid { + t.Error("want c's foreign key value to be nil") + } + if a.{{$rel.Function.LocalAssignment}} != d.{{$rel.Function.ForeignAssignment}} { + t.Error("foreign key was wrong value", a.{{$rel.Function.LocalAssignment}}, d.{{$rel.Function.ForeignAssignment}}) + } + if a.{{$rel.Function.LocalAssignment}} != e.{{$rel.Function.ForeignAssignment}} { + t.Error("foreign key was wrong value", a.{{$rel.Function.LocalAssignment}}, e.{{$rel.Function.ForeignAssignment}}) + } + + if b.R.{{$rel.Function.ForeignName}} != nil { + t.Error("relationship was not removed properly from the foreign struct") + } + if c.R.{{$rel.Function.ForeignName}} != nil { + t.Error("relationship was not removed properly from the foreign struct") + } + if d.R.{{$rel.Function.ForeignName}} != &a { + t.Error("relationship was not added properly to the foreign struct") + } + if e.R.{{$rel.Function.ForeignName}} != &a { + t.Error("relationship was not added properly to the foreign struct") + } + {{- end}} + + if a.R.{{$rel.Function.Name}}[0] != &d { + t.Error("relationship struct slice not set to correct value") + } + if a.R.{{$rel.Function.Name}}[1] != &e { + t.Error("relationship struct slice not set to correct value") + } } func test{{$rel.LocalTable.NameGo}}ToManyRemoveOp{{$rel.Function.Name}}(t *testing.T) { - var err error - - tx := MustTx(boil.Begin()) - defer tx.Rollback() - - var a {{$rel.LocalTable.NameGo}} - var b, c, d, e {{$rel.ForeignTable.NameGo}} - - seed := randomize.NewSeed() - if err = randomize.Struct(seed, &a, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}PrimaryKeyColumns...); err != nil { - t.Fatal(err) - } - foreigners := []*{{$rel.ForeignTable.NameGo}}{&b, &c, &d, &e} - for _, x := range foreigners { - if err = randomize.Struct(seed, x, {{$foreignVarNameSingular}}DBTypes, false, {{$foreignVarNameSingular}}PrimaryKeyColumns...); err != nil { - t.Fatal(err) - } - } - - if err := a.Insert(tx); err != nil { - t.Fatal(err) - } - - err = a.Add{{$rel.Function.Name}}(tx, true, foreigners...) - if err != nil { - t.Fatal(err) - } - - count, err := a.{{$rel.Function.Name}}(tx).Count() - if err != nil { - t.Fatal(err) - } - if count != 4 { - t.Error("count was wrong:", count) - } - - err = a.Remove{{$rel.Function.Name}}(tx, foreigners[:2]...) - if err != nil { - t.Fatal(err) - } - - count, err = a.{{$rel.Function.Name}}(tx).Count() - if err != nil { - t.Fatal(err) - } - if count != 2 { - t.Error("count was wrong:", count) - } - - {{- if .ToJoinTable}} - - if len(b.R.{{$rel.Function.ForeignName}}) != 0 { - t.Error("relationship was not removed properly from the slice") - } - if len(c.R.{{$rel.Function.ForeignName}}) != 0 { - t.Error("relationship was not removed properly from the slice") - } - if d.R.{{$rel.Function.ForeignName}}[0] != &a { - t.Error("relationship was not added properly to the foreign struct") - } - if e.R.{{$rel.Function.ForeignName}}[0] != &a { - t.Error("relationship was not added properly to the foreign struct") - } - {{- else}} - - if b.{{$rel.ForeignTable.ColumnNameGo}}.Valid { - t.Error("want b's foreign key value to be nil") - } - if c.{{$rel.ForeignTable.ColumnNameGo}}.Valid { - t.Error("want c's foreign key value to be nil") - } - - if b.R.{{$rel.Function.ForeignName}} != nil { - t.Error("relationship was not removed properly from the foreign struct") - } - if c.R.{{$rel.Function.ForeignName}} != nil { - t.Error("relationship was not removed properly from the foreign struct") - } - if d.R.{{$rel.Function.ForeignName}} != &a { - t.Error("relationship to a should have been preserved") - } - if e.R.{{$rel.Function.ForeignName}} != &a { - t.Error("relationship to a should have been preserved") - } - {{- end}} - - if len(a.R.{{$rel.Function.Name}}) != 2 { - t.Error("should have preserved two relationships") - } - - // Removal doesn't do a stable deletion for performance so we have to flip the order - if a.R.{{$rel.Function.Name}}[1] != &d { - t.Error("relationship to d should have been preserved") - } - if a.R.{{$rel.Function.Name}}[0] != &e { - t.Error("relationship to e should have been preserved") - } + var err error + + tx := MustTx(boil.Begin()) + defer tx.Rollback() + + var a {{$rel.LocalTable.NameGo}} + var b, c, d, e {{$rel.ForeignTable.NameGo}} + + seed := randomize.NewSeed() + if err = randomize.Struct(seed, &a, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}PrimaryKeyColumns...); err != nil { + t.Fatal(err) + } + foreigners := []*{{$rel.ForeignTable.NameGo}}{&b, &c, &d, &e} + for _, x := range foreigners { + if err = randomize.Struct(seed, x, {{$foreignVarNameSingular}}DBTypes, false, {{$foreignVarNameSingular}}PrimaryKeyColumns...); err != nil { + t.Fatal(err) + } + } + + if err := a.Insert(tx); err != nil { + t.Fatal(err) + } + + err = a.Add{{$rel.Function.Name}}(tx, true, foreigners...) + if err != nil { + t.Fatal(err) + } + + count, err := a.{{$rel.Function.Name}}(tx).Count() + if err != nil { + t.Fatal(err) + } + if count != 4 { + t.Error("count was wrong:", count) + } + + err = a.Remove{{$rel.Function.Name}}(tx, foreigners[:2]...) + if err != nil { + t.Fatal(err) + } + + count, err = a.{{$rel.Function.Name}}(tx).Count() + if err != nil { + t.Fatal(err) + } + if count != 2 { + t.Error("count was wrong:", count) + } + + {{- if .ToJoinTable}} + + if len(b.R.{{$rel.Function.ForeignName}}) != 0 { + t.Error("relationship was not removed properly from the slice") + } + if len(c.R.{{$rel.Function.ForeignName}}) != 0 { + t.Error("relationship was not removed properly from the slice") + } + if d.R.{{$rel.Function.ForeignName}}[0] != &a { + t.Error("relationship was not added properly to the foreign struct") + } + if e.R.{{$rel.Function.ForeignName}}[0] != &a { + t.Error("relationship was not added properly to the foreign struct") + } + {{- else}} + + if b.{{$rel.ForeignTable.ColumnNameGo}}.Valid { + t.Error("want b's foreign key value to be nil") + } + if c.{{$rel.ForeignTable.ColumnNameGo}}.Valid { + t.Error("want c's foreign key value to be nil") + } + + if b.R.{{$rel.Function.ForeignName}} != nil { + t.Error("relationship was not removed properly from the foreign struct") + } + if c.R.{{$rel.Function.ForeignName}} != nil { + t.Error("relationship was not removed properly from the foreign struct") + } + if d.R.{{$rel.Function.ForeignName}} != &a { + t.Error("relationship to a should have been preserved") + } + if e.R.{{$rel.Function.ForeignName}} != &a { + t.Error("relationship to a should have been preserved") + } + {{- end}} + + if len(a.R.{{$rel.Function.Name}}) != 2 { + t.Error("should have preserved two relationships") + } + + // Removal doesn't do a stable deletion for performance so we have to flip the order + if a.R.{{$rel.Function.Name}}[1] != &d { + t.Error("relationship to d should have been preserved") + } + if a.R.{{$rel.Function.Name}}[0] != &e { + t.Error("relationship to e should have been preserved") + } } {{end -}} {{- end -}}{{- /* if unique foreign key */ -}} diff --git a/templates_test/relationship_to_one.tpl b/templates_test/relationship_to_one.tpl index 602f9b953..f3a60b4c9 100644 --- a/templates_test/relationship_to_one.tpl +++ b/templates_test/relationship_to_one.tpl @@ -1,69 +1,69 @@ {{- define "relationship_to_one_test_helper"}} func test{{.LocalTable.NameGo}}ToOne{{.ForeignTable.NameGo}}_{{.Function.Name}}(t *testing.T) { - tx := MustTx(boil.Begin()) - defer tx.Rollback() + tx := MustTx(boil.Begin()) + defer tx.Rollback() - var foreign {{.ForeignTable.NameGo}} - var local {{.LocalTable.NameGo}} - {{if .ForeignKey.Nullable -}} - local.{{.ForeignKey.Column | titleCase}}.Valid = true - {{end}} - {{- if .ForeignKey.ForeignColumnNullable -}} - foreign.{{.ForeignKey.ForeignColumn | titleCase}}.Valid = true - {{end}} + var foreign {{.ForeignTable.NameGo}} + var local {{.LocalTable.NameGo}} + {{if .ForeignKey.Nullable -}} + local.{{.ForeignKey.Column | titleCase}}.Valid = true + {{end}} + {{- if .ForeignKey.ForeignColumnNullable -}} + foreign.{{.ForeignKey.ForeignColumn | titleCase}}.Valid = true + {{end}} - {{if not .Function.OneToOne -}} - if err := foreign.Insert(tx); err != nil { - t.Fatal(err) - } + {{if not .Function.OneToOne -}} + if err := foreign.Insert(tx); err != nil { + t.Fatal(err) + } - local.{{.Function.LocalAssignment}} = foreign.{{.Function.ForeignAssignment}} - if err := local.Insert(tx); err != nil { - t.Fatal(err) - } - {{else -}} - if err := local.Insert(tx); err != nil { - t.Fatal(err) - } + local.{{.Function.LocalAssignment}} = foreign.{{.Function.ForeignAssignment}} + if err := local.Insert(tx); err != nil { + t.Fatal(err) + } + {{else -}} + if err := local.Insert(tx); err != nil { + t.Fatal(err) + } - foreign.{{.Function.ForeignAssignment}} = local.{{.Function.LocalAssignment}} - if err := foreign.Insert(tx); err != nil { - t.Fatal(err) - } - {{end -}} + foreign.{{.Function.ForeignAssignment}} = local.{{.Function.LocalAssignment}} + if err := foreign.Insert(tx); err != nil { + t.Fatal(err) + } + {{end -}} - check, err := local.{{.Function.Name}}(tx).One() - if err != nil { - t.Fatal(err) - } + check, err := local.{{.Function.Name}}(tx).One() + if err != nil { + t.Fatal(err) + } - if check.{{.Function.ForeignAssignment}} != foreign.{{.Function.ForeignAssignment}} { - t.Errorf("want: %v, got %v", foreign.{{.Function.ForeignAssignment}}, check.{{.Function.ForeignAssignment}}) - } + if check.{{.Function.ForeignAssignment}} != foreign.{{.Function.ForeignAssignment}} { + t.Errorf("want: %v, got %v", foreign.{{.Function.ForeignAssignment}}, check.{{.Function.ForeignAssignment}}) + } - slice := {{.LocalTable.NameGo}}Slice{&local} - if err = local.L.Load{{.Function.Name}}(tx, false, &slice); err != nil { - t.Fatal(err) - } - if local.R.{{.Function.Name}} == nil { - t.Error("struct should have been eager loaded") - } + slice := {{.LocalTable.NameGo}}Slice{&local} + if err = local.L.Load{{.Function.Name}}(tx, false, &slice); err != nil { + t.Fatal(err) + } + if local.R.{{.Function.Name}} == nil { + t.Error("struct should have been eager loaded") + } - local.R.{{.Function.Name}} = nil - if err = local.L.Load{{.Function.Name}}(tx, true, &local); err != nil { - t.Fatal(err) - } - if local.R.{{.Function.Name}} == nil { - t.Error("struct should have been eager loaded") - } + local.R.{{.Function.Name}} = nil + if err = local.L.Load{{.Function.Name}}(tx, true, &local); err != nil { + t.Fatal(err) + } + if local.R.{{.Function.Name}} == nil { + t.Error("struct should have been eager loaded") + } } {{end -}} {{- if .Table.IsJoinTable -}} {{- else -}} - {{- $dot := . -}} - {{- range .Table.FKeys -}} - {{- $rel := textsFromForeignKey $dot.PkgName $dot.Tables $dot.Table . -}} + {{- $dot := . -}} + {{- range .Table.FKeys -}} + {{- $rel := textsFromForeignKey $dot.PkgName $dot.Tables $dot.Table . -}} {{- template "relationship_to_one_test_helper" $rel -}} {{end -}} {{- end -}} diff --git a/templates_test/relationship_to_one_setops.tpl b/templates_test/relationship_to_one_setops.tpl index ab27db590..2ff2ca76f 100644 --- a/templates_test/relationship_to_one_setops.tpl +++ b/templates_test/relationship_to_one_setops.tpl @@ -2,131 +2,131 @@ {{- $varNameSingular := .ForeignKey.Table | singular | camelCase -}} {{- $foreignVarNameSingular := .ForeignKey.ForeignTable | singular | camelCase -}} func test{{.LocalTable.NameGo}}ToOneSetOp{{.ForeignTable.NameGo}}_{{.Function.Name}}(t *testing.T) { - var err error - - tx := MustTx(boil.Begin()) - defer tx.Rollback() - - var a {{.LocalTable.NameGo}} - var b, c {{.ForeignTable.NameGo}} - - seed := randomize.NewSeed() - if err = randomize.Struct(seed, &a, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}PrimaryKeyColumns...); err != nil { - t.Fatal(err) - } - if err = randomize.Struct(seed, &b, {{$foreignVarNameSingular}}DBTypes, false, {{$foreignVarNameSingular}}PrimaryKeyColumns...); err != nil { - t.Fatal(err) - } - if err = randomize.Struct(seed, &c, {{$foreignVarNameSingular}}DBTypes, false, {{$foreignVarNameSingular}}PrimaryKeyColumns...); err != nil { - t.Fatal(err) - } - - if err := a.Insert(tx); err != nil { - t.Fatal(err) - } - if err = b.Insert(tx); err != nil { - t.Fatal(err) - } - - for i, x := range []*{{.ForeignTable.NameGo}}{&b, &c} { - err = a.Set{{.Function.Name}}(tx, i != 0, x) - if err != nil { - t.Fatal(err) - } - - if a.{{.Function.LocalAssignment}} != x.{{.Function.ForeignAssignment}} { - t.Error("foreign key was wrong value", a.{{.Function.LocalAssignment}}) - } - if a.R.{{.Function.Name}} != x { - t.Error("relationship struct not set to correct value") - } - - zero := reflect.Zero(reflect.TypeOf(a.{{.Function.LocalAssignment}})) - reflect.Indirect(reflect.ValueOf(&a.{{.Function.LocalAssignment}})).Set(zero) - - if err = a.Reload(tx); err != nil { - t.Fatal("failed to reload", err) - } - - if a.{{.Function.LocalAssignment}} != x.{{.Function.ForeignAssignment}} { - t.Error("foreign key was wrong value", a.{{.Function.LocalAssignment}}, x.{{.Function.ForeignAssignment}}) - } - - {{if .ForeignKey.Unique -}} - if x.R.{{.Function.ForeignName}} != &a { - t.Error("failed to append to foreign relationship struct") - } - {{else -}} - if x.R.{{.Function.ForeignName}}[0] != &a { - t.Error("failed to append to foreign relationship struct") - } - {{end -}} - } + var err error + + tx := MustTx(boil.Begin()) + defer tx.Rollback() + + var a {{.LocalTable.NameGo}} + var b, c {{.ForeignTable.NameGo}} + + seed := randomize.NewSeed() + if err = randomize.Struct(seed, &a, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}PrimaryKeyColumns...); err != nil { + t.Fatal(err) + } + if err = randomize.Struct(seed, &b, {{$foreignVarNameSingular}}DBTypes, false, {{$foreignVarNameSingular}}PrimaryKeyColumns...); err != nil { + t.Fatal(err) + } + if err = randomize.Struct(seed, &c, {{$foreignVarNameSingular}}DBTypes, false, {{$foreignVarNameSingular}}PrimaryKeyColumns...); err != nil { + t.Fatal(err) + } + + if err := a.Insert(tx); err != nil { + t.Fatal(err) + } + if err = b.Insert(tx); err != nil { + t.Fatal(err) + } + + for i, x := range []*{{.ForeignTable.NameGo}}{&b, &c} { + err = a.Set{{.Function.Name}}(tx, i != 0, x) + if err != nil { + t.Fatal(err) + } + + if a.{{.Function.LocalAssignment}} != x.{{.Function.ForeignAssignment}} { + t.Error("foreign key was wrong value", a.{{.Function.LocalAssignment}}) + } + if a.R.{{.Function.Name}} != x { + t.Error("relationship struct not set to correct value") + } + + zero := reflect.Zero(reflect.TypeOf(a.{{.Function.LocalAssignment}})) + reflect.Indirect(reflect.ValueOf(&a.{{.Function.LocalAssignment}})).Set(zero) + + if err = a.Reload(tx); err != nil { + t.Fatal("failed to reload", err) + } + + if a.{{.Function.LocalAssignment}} != x.{{.Function.ForeignAssignment}} { + t.Error("foreign key was wrong value", a.{{.Function.LocalAssignment}}, x.{{.Function.ForeignAssignment}}) + } + + {{if .ForeignKey.Unique -}} + if x.R.{{.Function.ForeignName}} != &a { + t.Error("failed to append to foreign relationship struct") + } + {{else -}} + if x.R.{{.Function.ForeignName}}[0] != &a { + t.Error("failed to append to foreign relationship struct") + } + {{end -}} + } } {{- if .ForeignKey.Nullable}} func test{{.LocalTable.NameGo}}ToOneRemoveOp{{.ForeignTable.NameGo}}_{{.Function.Name}}(t *testing.T) { - var err error - - tx := MustTx(boil.Begin()) - defer tx.Rollback() - - var a {{.LocalTable.NameGo}} - var b {{.ForeignTable.NameGo}} - - seed := randomize.NewSeed() - if err = randomize.Struct(seed, &a, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}PrimaryKeyColumns...); err != nil { - t.Fatal(err) - } - if err = randomize.Struct(seed, &b, {{$foreignVarNameSingular}}DBTypes, false, {{$foreignVarNameSingular}}PrimaryKeyColumns...); err != nil { - t.Fatal(err) - } - - if err = a.Insert(tx); err != nil { - t.Fatal(err) - } - - if err = a.Set{{.Function.Name}}(tx, true, &b); err != nil { - t.Fatal(err) - } - - if err = a.Remove{{.Function.Name}}(tx, &b); err != nil { - t.Error("failed to remove relationship") - } - - count, err := a.{{.Function.Name}}(tx).Count() - if err != nil { - t.Error(err) - } - if count != 0 { - t.Error("want no relationships remaining") - } - - if a.R.{{.Function.Name}} != nil { - t.Error("R struct entry should be nil") - } - - if a.{{.LocalTable.ColumnNameGo}}.Valid { - t.Error("R struct entry should be nil") - } - - {{if .ForeignKey.Unique -}} - if b.R.{{.Function.ForeignName}} != nil { - t.Error("failed to remove a from b's relationships") - } - {{else -}} - if len(b.R.{{.Function.ForeignName}}) != 0 { - t.Error("failed to remove a from b's relationships") - } - {{end -}} + var err error + + tx := MustTx(boil.Begin()) + defer tx.Rollback() + + var a {{.LocalTable.NameGo}} + var b {{.ForeignTable.NameGo}} + + seed := randomize.NewSeed() + if err = randomize.Struct(seed, &a, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}PrimaryKeyColumns...); err != nil { + t.Fatal(err) + } + if err = randomize.Struct(seed, &b, {{$foreignVarNameSingular}}DBTypes, false, {{$foreignVarNameSingular}}PrimaryKeyColumns...); err != nil { + t.Fatal(err) + } + + if err = a.Insert(tx); err != nil { + t.Fatal(err) + } + + if err = a.Set{{.Function.Name}}(tx, true, &b); err != nil { + t.Fatal(err) + } + + if err = a.Remove{{.Function.Name}}(tx, &b); err != nil { + t.Error("failed to remove relationship") + } + + count, err := a.{{.Function.Name}}(tx).Count() + if err != nil { + t.Error(err) + } + if count != 0 { + t.Error("want no relationships remaining") + } + + if a.R.{{.Function.Name}} != nil { + t.Error("R struct entry should be nil") + } + + if a.{{.LocalTable.ColumnNameGo}}.Valid { + t.Error("R struct entry should be nil") + } + + {{if .ForeignKey.Unique -}} + if b.R.{{.Function.ForeignName}} != nil { + t.Error("failed to remove a from b's relationships") + } + {{else -}} + if len(b.R.{{.Function.ForeignName}}) != 0 { + t.Error("failed to remove a from b's relationships") + } + {{end -}} } {{end -}} {{- end -}} {{- if .Table.IsJoinTable -}} {{- else -}} - {{- $dot := . -}} - {{- range .Table.FKeys -}} - {{- $rel := textsFromForeignKey $dot.PkgName $dot.Tables $dot.Table .}} + {{- $dot := . -}} + {{- range .Table.FKeys -}} + {{- $rel := textsFromForeignKey $dot.PkgName $dot.Tables $dot.Table .}} {{template "relationship_to_one_setops_test_helper" $rel -}} {{- end -}} diff --git a/templates_test/reload.tpl b/templates_test/reload.tpl index da236bdb8..7879b40b9 100644 --- a/templates_test/reload.tpl +++ b/templates_test/reload.tpl @@ -3,45 +3,45 @@ {{- $varNamePlural := .Table.Name | plural | camelCase -}} {{- $varNameSingular := .Table.Name | singular | camelCase -}} func test{{$tableNamePlural}}Reload(t *testing.T) { - t.Parallel() - - seed := randomize.NewSeed() - var err error - {{$varNameSingular}} := &{{$tableNameSingular}}{} - if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil { - t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) - } - - tx := MustTx(boil.Begin()) - defer tx.Rollback() - if err = {{$varNameSingular}}.Insert(tx); err != nil { - t.Error(err) - } - - if err = {{$varNameSingular}}.Reload(tx); err != nil { - t.Error(err) - } + t.Parallel() + + seed := randomize.NewSeed() + var err error + {{$varNameSingular}} := &{{$tableNameSingular}}{} + if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil { + t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) + } + + tx := MustTx(boil.Begin()) + defer tx.Rollback() + if err = {{$varNameSingular}}.Insert(tx); err != nil { + t.Error(err) + } + + if err = {{$varNameSingular}}.Reload(tx); err != nil { + t.Error(err) + } } func test{{$tableNamePlural}}ReloadAll(t *testing.T) { - t.Parallel() - - seed := randomize.NewSeed() - var err error - {{$varNameSingular}} := &{{$tableNameSingular}}{} - if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil { - t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) - } - - tx := MustTx(boil.Begin()) - defer tx.Rollback() - if err = {{$varNameSingular}}.Insert(tx); err != nil { - t.Error(err) - } - - slice := {{$tableNameSingular}}Slice{{"{"}}{{$varNameSingular}}{{"}"}} - - if err = slice.ReloadAll(tx); err != nil { - t.Error(err) - } + t.Parallel() + + seed := randomize.NewSeed() + var err error + {{$varNameSingular}} := &{{$tableNameSingular}}{} + if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil { + t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) + } + + tx := MustTx(boil.Begin()) + defer tx.Rollback() + if err = {{$varNameSingular}}.Insert(tx); err != nil { + t.Error(err) + } + + slice := {{$tableNameSingular}}Slice{{"{"}}{{$varNameSingular}}{{"}"}} + + if err = slice.ReloadAll(tx); err != nil { + t.Error(err) + } } diff --git a/templates_test/select.tpl b/templates_test/select.tpl index 1b3aefbdf..0e5692fc0 100644 --- a/templates_test/select.tpl +++ b/templates_test/select.tpl @@ -3,27 +3,27 @@ {{- $varNamePlural := .Table.Name | plural | camelCase -}} {{- $varNameSingular := .Table.Name | singular | camelCase -}} func test{{$tableNamePlural}}Select(t *testing.T) { - t.Parallel() + t.Parallel() - seed := randomize.NewSeed() - var err error - {{$varNameSingular}} := &{{$tableNameSingular}}{} - if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil { - t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) - } + seed := randomize.NewSeed() + var err error + {{$varNameSingular}} := &{{$tableNameSingular}}{} + if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}ColumnsWithDefault...); err != nil { + t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) + } - tx := MustTx(boil.Begin()) - defer tx.Rollback() - if err = {{$varNameSingular}}.Insert(tx); err != nil { - t.Error(err) - } + tx := MustTx(boil.Begin()) + defer tx.Rollback() + if err = {{$varNameSingular}}.Insert(tx); err != nil { + t.Error(err) + } - slice, err := {{$tableNamePlural}}(tx).All() - if err != nil { - t.Error(err) - } + slice, err := {{$tableNamePlural}}(tx).All() + if err != nil { + t.Error(err) + } - if len(slice) != 1 { - t.Error("want one record, got:", len(slice)) - } + if len(slice) != 1 { + t.Error("want one record, got:", len(slice)) + } } diff --git a/templates_test/update.tpl b/templates_test/update.tpl index 5cd15bedb..98421c9d7 100644 --- a/templates_test/update.tpl +++ b/templates_test/update.tpl @@ -3,97 +3,97 @@ {{- $varNamePlural := .Table.Name | plural | camelCase -}} {{- $varNameSingular := .Table.Name | singular | camelCase -}} func test{{$tableNamePlural}}Update(t *testing.T) { - t.Parallel() + t.Parallel() - seed := randomize.NewSeed() - var err error - {{$varNameSingular}} := &{{$tableNameSingular}}{} - if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true); err != nil { - t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) - } + seed := randomize.NewSeed() + var err error + {{$varNameSingular}} := &{{$tableNameSingular}}{} + if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true); err != nil { + t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) + } - tx := MustTx(boil.Begin()) - defer tx.Rollback() - if err = {{$varNameSingular}}.Insert(tx); err != nil { - t.Error(err) - } + tx := MustTx(boil.Begin()) + defer tx.Rollback() + if err = {{$varNameSingular}}.Insert(tx); err != nil { + t.Error(err) + } - count, err := {{$tableNamePlural}}(tx).Count() - if err != nil { - t.Error(err) - } + count, err := {{$tableNamePlural}}(tx).Count() + if err != nil { + t.Error(err) + } - if count != 1 { - t.Error("want one record, got:", count) - } + if count != 1 { + t.Error("want one record, got:", count) + } - if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}PrimaryKeyColumns...); err != nil { - t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) - } + if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}PrimaryKeyColumns...); err != nil { + t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) + } - // If table only contains primary key columns, we need to pass - // them into a whitelist to get a valid test result, - // otherwise the Update method will error because it will not be able to - // generate a whitelist (due to it excluding primary key columns). - if strmangle.StringSliceMatch({{$varNameSingular}}Columns, {{$varNameSingular}}PrimaryKeyColumns) { - if err = {{$varNameSingular}}.Update(tx, {{$varNameSingular}}PrimaryKeyColumns...); err != nil { - t.Error(err) - } - } else { - if err = {{$varNameSingular}}.Update(tx); err != nil { - t.Error(err) - } - } + // If table only contains primary key columns, we need to pass + // them into a whitelist to get a valid test result, + // otherwise the Update method will error because it will not be able to + // generate a whitelist (due to it excluding primary key columns). + if strmangle.StringSliceMatch({{$varNameSingular}}Columns, {{$varNameSingular}}PrimaryKeyColumns) { + if err = {{$varNameSingular}}.Update(tx, {{$varNameSingular}}PrimaryKeyColumns...); err != nil { + t.Error(err) + } + } else { + if err = {{$varNameSingular}}.Update(tx); err != nil { + t.Error(err) + } + } } func test{{$tableNamePlural}}SliceUpdateAll(t *testing.T) { - t.Parallel() + t.Parallel() - seed := randomize.NewSeed() - var err error - {{$varNameSingular}} := &{{$tableNameSingular}}{} - if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true); err != nil { - t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) - } + seed := randomize.NewSeed() + var err error + {{$varNameSingular}} := &{{$tableNameSingular}}{} + if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true); err != nil { + t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) + } - tx := MustTx(boil.Begin()) - defer tx.Rollback() - if err = {{$varNameSingular}}.Insert(tx); err != nil { - t.Error(err) - } + tx := MustTx(boil.Begin()) + defer tx.Rollback() + if err = {{$varNameSingular}}.Insert(tx); err != nil { + t.Error(err) + } - count, err := {{$tableNamePlural}}(tx).Count() - if err != nil { - t.Error(err) - } + count, err := {{$tableNamePlural}}(tx).Count() + if err != nil { + t.Error(err) + } - if count != 1 { - t.Error("want one record, got:", count) - } + if count != 1 { + t.Error("want one record, got:", count) + } - if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}PrimaryKeyColumns...); err != nil { - t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) - } + if err = randomize.Struct(seed, {{$varNameSingular}}, {{$varNameSingular}}DBTypes, true, {{$varNameSingular}}PrimaryKeyColumns...); err != nil { + t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) + } - // Remove Primary keys and unique columns from what we plan to update - var fields []string - if strmangle.StringSliceMatch({{$varNameSingular}}Columns, {{$varNameSingular}}PrimaryKeyColumns) { - fields = {{$varNameSingular}}Columns - } else { - fields = strmangle.SetComplement( - {{$varNameSingular}}Columns, - {{$varNameSingular}}PrimaryKeyColumns, - ) - } + // Remove Primary keys and unique columns from what we plan to update + var fields []string + if strmangle.StringSliceMatch({{$varNameSingular}}Columns, {{$varNameSingular}}PrimaryKeyColumns) { + fields = {{$varNameSingular}}Columns + } else { + fields = strmangle.SetComplement( + {{$varNameSingular}}Columns, + {{$varNameSingular}}PrimaryKeyColumns, + ) + } value := reflect.Indirect(reflect.ValueOf({{$varNameSingular}})) - updateMap := M{} - for _, col := range fields { - updateMap[col] = value.FieldByName(strmangle.TitleCase(col)).Interface() - } + updateMap := M{} + for _, col := range fields { + updateMap[col] = value.FieldByName(strmangle.TitleCase(col)).Interface() + } - slice := {{$tableNameSingular}}Slice{{"{"}}{{$varNameSingular}}{{"}"}} - if err = slice.UpdateAll(tx, updateMap); err != nil { - t.Error(err) - } + slice := {{$tableNameSingular}}Slice{{"{"}}{{$varNameSingular}}{{"}"}} + if err = slice.UpdateAll(tx, updateMap); err != nil { + t.Error(err) + } } diff --git a/templates_test/upsert.tpl b/templates_test/upsert.tpl index 00c667b79..de5623b0c 100644 --- a/templates_test/upsert.tpl +++ b/templates_test/upsert.tpl @@ -3,47 +3,47 @@ {{- $varNamePlural := .Table.Name | plural | camelCase -}} {{- $varNameSingular := .Table.Name | singular | camelCase -}} func test{{$tableNamePlural}}Upsert(t *testing.T) { - {{if not (eq .DriverName "postgres") -}} - t.Skip("not implemented for {{.DriverName}}") - {{end -}} - t.Parallel() + {{if not (eq .DriverName "postgres") -}} + t.Skip("not implemented for {{.DriverName}}") + {{end -}} + t.Parallel() - seed := randomize.NewSeed() - var err error - // Attempt the INSERT side of an UPSERT - {{$varNameSingular}} := {{$tableNameSingular}}{} - if err = randomize.Struct(seed, &{{$varNameSingular}}, {{$varNameSingular}}DBTypes, true); err != nil { - t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) - } + seed := randomize.NewSeed() + var err error + // Attempt the INSERT side of an UPSERT + {{$varNameSingular}} := {{$tableNameSingular}}{} + if err = randomize.Struct(seed, &{{$varNameSingular}}, {{$varNameSingular}}DBTypes, true); err != nil { + t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) + } - tx := MustTx(boil.Begin()) - defer tx.Rollback() - if err = {{$varNameSingular}}.Upsert(tx, false, nil, nil); err != nil { - t.Errorf("Unable to upsert {{$tableNameSingular}}: %s", err) - } + tx := MustTx(boil.Begin()) + defer tx.Rollback() + if err = {{$varNameSingular}}.Upsert(tx, {{if eq .DriverName "postgres"}}false, nil, {{end}}nil); err != nil { + t.Errorf("Unable to upsert {{$tableNameSingular}}: %s", err) + } - count, err := {{$tableNamePlural}}(tx).Count() - if err != nil { - t.Error(err) - } - if count != 1 { - t.Error("want one record, got:", count) - } + count, err := {{$tableNamePlural}}(tx).Count() + if err != nil { + t.Error(err) + } + if count != 1 { + t.Error("want one record, got:", count) + } - // Attempt the UPDATE side of an UPSERT - if err = randomize.Struct(seed, &{{$varNameSingular}}, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}PrimaryKeyColumns...); err != nil { - t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) - } + // Attempt the UPDATE side of an UPSERT + if err = randomize.Struct(seed, &{{$varNameSingular}}, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}PrimaryKeyColumns...); err != nil { + t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) + } - if err = {{$varNameSingular}}.Upsert(tx, true, nil, nil); err != nil { - t.Errorf("Unable to upsert {{$tableNameSingular}}: %s", err) - } + if err = {{$varNameSingular}}.Upsert(tx, {{if eq .DriverName "postgres"}}true, nil, {{end}}nil); err != nil { + t.Errorf("Unable to upsert {{$tableNameSingular}}: %s", err) + } - count, err = {{$tableNamePlural}}(tx).Count() - if err != nil { - t.Error(err) - } - if count != 1 { - t.Error("want one record, got:", count) - } + count, err = {{$tableNamePlural}}(tx).Count() + if err != nil { + t.Error(err) + } + if count != 1 { + t.Error("want one record, got:", count) + } } From 931f3d2de57f4081d1b7fb9bb838305778582adb Mon Sep 17 00:00:00 2001 From: Patrick O'brien Date: Wed, 14 Sep 2016 18:27:20 +1000 Subject: [PATCH 46/64] Fix mock driver compat with upsert --- templates/14_upsert.tpl | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/templates/14_upsert.tpl b/templates/14_upsert.tpl index 1f43c71d0..07656e20e 100644 --- a/templates/14_upsert.tpl +++ b/templates/14_upsert.tpl @@ -1,27 +1,27 @@ {{- $tableNameSingular := .Table.Name | singular | titleCase -}} {{- $varNameSingular := .Table.Name | singular | camelCase -}} // UpsertG attempts an insert, and does an update or ignore on conflict. -func (o *{{$tableNameSingular}}) UpsertG({{if eq .DriverName "postgres"}}updateOnConflict bool, conflictColumns []string, {{end}}updateColumns []string, whitelist ...string) error { - return o.Upsert(boil.GetDB(), {{if eq .DriverName "postgres"}}updateOnConflict, conflictColumns, {{end}}updateColumns, whitelist...) +func (o *{{$tableNameSingular}}) UpsertG({{if ne .DriverName "mysql"}}updateOnConflict bool, conflictColumns []string, {{end}}updateColumns []string, whitelist ...string) error { + return o.Upsert(boil.GetDB(), {{if ne .DriverName "mysql"}}updateOnConflict, conflictColumns, {{end}}updateColumns, whitelist...) } // UpsertGP attempts an insert, and does an update or ignore on conflict. Panics on error. -func (o *{{$tableNameSingular}}) UpsertGP({{if eq .DriverName "postgres"}}updateOnConflict bool, conflictColumns []string, {{end}}updateColumns []string, whitelist ...string) { - if err := o.Upsert(boil.GetDB(), {{if eq .DriverName "postgres"}}updateOnConflict, conflictColumns, {{end}}updateColumns, whitelist...); err != nil { +func (o *{{$tableNameSingular}}) UpsertGP({{if ne .DriverName "mysql"}}updateOnConflict bool, conflictColumns []string, {{end}}updateColumns []string, whitelist ...string) { + if err := o.Upsert(boil.GetDB(), {{if ne .DriverName "mysql"}}updateOnConflict, conflictColumns, {{end}}updateColumns, whitelist...); err != nil { panic(boil.WrapErr(err)) } } // UpsertP attempts an insert using an executor, and does an update or ignore on conflict. // UpsertP panics on error. -func (o *{{$tableNameSingular}}) UpsertP(exec boil.Executor, {{if eq .DriverName "postgres"}}updateOnConflict bool, conflictColumns []string, {{end}}updateColumns []string, whitelist ...string) { - if err := o.Upsert(exec, {{if eq .DriverName "postgres"}}updateOnConflict, conflictColumns, {{end}}updateColumns, whitelist...); err != nil { +func (o *{{$tableNameSingular}}) UpsertP(exec boil.Executor, {{if ne .DriverName "mysql"}}updateOnConflict bool, conflictColumns []string, {{end}}updateColumns []string, whitelist ...string) { + if err := o.Upsert(exec, {{if ne .DriverName "mysql"}}updateOnConflict, conflictColumns, {{end}}updateColumns, whitelist...); err != nil { panic(boil.WrapErr(err)) } } // Upsert attempts an insert using an executor, and does an update or ignore on conflict. -func (o *{{$tableNameSingular}}) Upsert(exec boil.Executor, {{if eq .DriverName "postgres"}}updateOnConflict bool, conflictColumns []string, {{end}}updateColumns []string, whitelist ...string) error { +func (o *{{$tableNameSingular}}) Upsert(exec boil.Executor, {{if ne .DriverName "mysql"}}updateOnConflict bool, conflictColumns []string, {{end}}updateColumns []string, whitelist ...string) error { if o == nil { return errors.New("{{.PkgName}}: no {{.Table.Name}} provided for upsert") } @@ -49,7 +49,7 @@ func (o *{{$tableNameSingular}}) Upsert(exec boil.Executor, {{if eq .DriverName updateColumns, ) - {{if eq .DriverName "postgres" -}} + {{if ne .DriverName "mysql" -}} conflict := conflictColumns if len(conflict) == 0 { conflict = make([]string, len({{$varNameSingular}}PrimaryKeyColumns)) @@ -57,10 +57,10 @@ func (o *{{$tableNameSingular}}) Upsert(exec boil.Executor, {{if eq .DriverName } {{- end}} - {{if eq .DriverName "postgres" -}} - query := boil.BuildUpsertQueryPostgres(dialect, "{{.Table.Name}}", updateOnConflict, ret, update, conflict, whitelist) - {{- else if eq .DriverName "mysql" -}} + {{if eq .DriverName "mysql" -}} query := boil.BuildUpsertQueryMySQL(dialect, "{{.Table.Name}}", update, whitelist) + {{- else -}} + query := boil.BuildUpsertQueryPostgres(dialect, "{{.Table.Name}}", updateOnConflict, ret, update, conflict, whitelist) {{- end}} if boil.DebugMode { From 7ce5ac18acb0874e210065a6f4c582adfea1c50a Mon Sep 17 00:00:00 2001 From: Patrick O'brien Date: Wed, 14 Sep 2016 19:42:07 +1000 Subject: [PATCH 47/64] Add P versions of query exec funcs - Update readme --- README.md | 41 +++++++++-------------------------------- boil/query.go | 22 ++++++++++++++++++++++ 2 files changed, 31 insertions(+), 32 deletions(-) diff --git a/README.md b/README.md index fa52f3eb4..9165fb372 100644 --- a/README.md +++ b/README.md @@ -456,36 +456,8 @@ err := models.NewQuery(db, From("pilots")).All() As you can see, [Query Mods](#query-mods) allow you to modify your queries, and [Finishers](#finishers) allow you to execute the final action. -If you plan on executing the same query with the same values using the query builder, -you should do so like the following to utilize caching: - -```go -// Instead of this: -for i := 0; i < 10; i++ { - pilots := models.Pilots(qm.Where("id > ?", 5), qm.Limit(5)).All() -} - -// You should do this -query := models.Pilots(qm.Where("id > ?", 5), qm.Limit(5)) -for i := 0; i < 10; i++ { - pilots := query.All() -} - -// Every execution of All() after the first will use a cached version of -// the built query that short circuits the query builder all together. -// This allows you to save on performance. - -// Just something to be aware of: query mods don't store pointers, so if -// your passed in variable's value changes, your generated query will not change. -``` - -Note: You will see exported `boil.SetX` methods in the boil package. These should not be used on query -objects because they will break caching. Unfortunately these had to be exported due to some circular -dependency issues, but they're not functionality we want exposed. If you want a different -query object, generate a new one. - -Take a look at our [Relationships Query Building](#relationships) section for some additional query -building information. +We also generate query building helper methods for your relationships as well. Take a look at our +[Relationships Query Building](#relationships) section for some additional query building information. ### Query Mod System @@ -579,6 +551,9 @@ UpdateAll(models.M{"name": "John", "age": 23}) // Update all rows matching the b DeleteAll() // Delete all rows matching the built query. Exists() // Returns a bool indicating whether the row(s) for the built query exists. Bind(&myObj) // Bind the results of a query to your own struct object. +ExecQuery() // Execute an SQL query that does not require any rows returned. Equivalent to `sql.Exec()`. +ExecQueryOne() // Execute an SQL query expected to return only a single row. Equivalent to `sql.QueryRow()`. +ExecQueryAll() // Execute an SQL query expected to return multiple rows. Equivalent to `sql.Query()`. ``` ### Raw Query @@ -593,8 +568,10 @@ err := boil.SQL(db, "select * from pilots where id=$1", 5).Bind(&obj) You can use your own structs or a generated struct as a parameter to Bind. Bind supports both a single object for single row queries and a slice of objects for multiple row queries. -You also have `models.NewQuery()` at your disposal if you would still like to use [Query Build](#query-building) -but would like to build against a non-generated model. +`boil.SQL()` also has a method that can execute a query without binding to an object, if required. + +You also have `models.NewQuery()` at your disposal if you would still like to use [Query Building](#query-building) +in combination with your own custom, non-generated model. ### Binding diff --git a/boil/query.go b/boil/query.go index 5ae5fabc7..bf8221568 100644 --- a/boil/query.go +++ b/boil/query.go @@ -126,6 +126,28 @@ func (q *Query) ExecQueryAll() (*sql.Rows, error) { return q.executor.Query(qs, args...) } +// ExecQueryP executes a query that does not need a row returned +// It will panic on error +func (q *Query) ExecQueryP() sql.Result { + res, err := q.ExecQuery() + if err != nil { + panic(WrapErr(err)) + } + + return res +} + +// ExecQueryAllP executes the query for the All finisher and returns multiple rows +// It will panic on error +func (q *Query) ExecQueryAllP() *sql.Rows { + rows, err := q.ExecQueryAll() + if err != nil { + panic(WrapErr(err)) + } + + return rows +} + // SetExecutor on the query. func SetExecutor(q *Query, exec Executor) { q.executor = exec From 7d377f42aeeea2c44dd51a119f12ae3c456e93ad Mon Sep 17 00:00:00 2001 From: Patrick O'brien Date: Wed, 14 Sep 2016 20:42:20 +1000 Subject: [PATCH 48/64] Update readme --- README.md | 80 ++++++++++++++++++++++++++++++------------------------- 1 file changed, 43 insertions(+), 37 deletions(-) diff --git a/README.md b/README.md index 9165fb372..027f5380f 100644 --- a/README.md +++ b/README.md @@ -80,23 +80,28 @@ Table of Contents ### Features - Full model generation -- High performance through generation - Extremely fast code generation +- High performance through generation & intelligent caching - Uses boil.Executor (simple interface, sql.DB, sqlx.DB etc. compatible) - Easy workflow (models can always be regenerated, full auto-complete) - Strongly typed querying (usually no converting or binding to pointers) - Hooks (Before/After Create/Select/Update/Delete/Upsert) - Automatic CreatedAt/UpdatedAt +- Table whitelist/blacklist - Relationships/Associations - Eager loading (recursive) +- Custom struct tags +- Schema support - Transactions - Raw SQL fallback - Compatibility tests (Run against your own DB schema) - Debug logging +- Postgres 1d arrays, json, hstore & more ### Supported Databases - PostgreSQL +- MySQL *Note: Seeking contributors for other database engines.* @@ -203,36 +208,36 @@ order: - `$XDG_CONFIG_HOME/sqlboiler/` - `$HOME/.config/sqlboiler/` -We require you pass in the `postgres` configuration via the configuration file rather than env vars. -There is no command line argument support for database configuration. Values given under the `postgres` -block are passed directly to the [pq](github.com/lib/pq) driver. Here is a rundown of all the different +We require you pass in your `postgres` and `mysql` database configuration via the configuration file rather than env vars. +There is no command line argument support for database configuration. Values given under the `postgres` and `mysql` +block are passed directly to the postgres and mysql drivers. Here is a rundown of all the different values that can go in that section: -| Name | Required | Default | -| --- | --- | --- | -| dbname | yes | none | -| host | yes | none | -| port | no | 5432 | -| user | yes | none | -| pass | no | none | -| sslmode | no | "require" | +| Name | Required | Postgres Default | MySQL Default | +| --- | --- | --- | --- | +| dbname | yes | none | none | +| host | yes | none | none | +| port | no | 5432 | 3306 | +| user | yes | none | none | +| pass | no | none | none | +| sslmode | no | "require" | "true" | You can also pass in these top level configuration values if you would prefer not to pass them through the command line or environment variables: -| Name | Default | -| --- | --- | -| basedir | none | -| schema | "public" | -| pkgname | "models" | -| output | "models" | -| whitelist | [] | -| blacklist | [] | -| tag | [] | -| debug | false | -| no-hooks | false | -| no-tests | false | -| no-auto-timestamps | false | +| Name | Postgres Default | Mysql Default +| --- | --- | --- | +| basedir | none | none | +| schema | "public" | *N/A* | +| pkgname | "models" | "models" | +| output | "models" | "models" | +| whitelist | [] | [] | +| blacklist | [] | [] | +| tag | [] | [] | +| debug | false | false | +| no-hooks | false | false | +| no-tests | false | false | +| no-auto-timestamps | false | false | Example: @@ -258,7 +263,8 @@ Usage: sqlboiler [flags] Examples: -sqlboiler postgres + sqlboiler postgres + sqlboiler mysql Flags: -b, --blacklist stringSlice Do not include these tables in your generated package @@ -274,9 +280,9 @@ Flags: --no-tests Disable generated go test files ``` -Follow the steps below to do some basic model generation. Once we've generated -our models, we can run the compatibility tests which will exercise the entirety -of the generated code. This way we can ensure that our database is compatible +Follow the steps below to do some basic model generation. Once you've generated +your models, you can run the compatibility tests which will exercise the entirety +of the generated code. This way you can ensure that your database is compatible with SQLBoiler. If you find there are some failing tests, please check the [Diagnosing Problems](#diagnosing-problems) section. @@ -285,8 +291,7 @@ with SQLBoiler. If you find there are some failing tests, please check the sqlboiler -x goose_migrations postgres # Run the generated tests -go test ./models # This requires an administrator postgres user because of some - # voodoo we do to disable triggers for the generated test db +go test ./models ``` ## Diagnosing Problems @@ -296,7 +301,7 @@ The most common causes of problems and panics are: - Forgetting to exclude tables you do not want included in your generation, like migration tables. - Tables without a primary key. All tables require one. - Forgetting to put foreign key constraints on your columns that reference other tables. -- The compatibility tests that run against your own DB schema require a superuser, ensure the user +- The compatibility tests require privileges to create a database for testing purposes, ensure the user supplied in your `sqlboiler.toml` config has adequate privileges. - A nil or closed database handle. Ensure your passed in `boil.Executor` is not nil. - If you decide to use the `G` variant of functions instead, make sure you've initialized your @@ -349,9 +354,8 @@ ALTER TABLE pilot_languages ADD CONSTRAINT pilots_fkey FOREIGN KEY (pilot_id) RE ALTER TABLE pilot_languages ADD CONSTRAINT languages_fkey FOREIGN KEY (language_id) REFERENCES languages(id); ``` -The generated model structs for this schema look like the following. Note that I've included the relationship -structs as well so you can see how it all pieces together, but these are unexported and not something you should -ever need to touch directly: +The generated model structs for this schema look like the following. Note that we've included the relationship +structs as well so you can see how it all pieces together: ```go type Pilot struct { @@ -359,6 +363,7 @@ type Pilot struct { Name string `boil:"name" json:"name" toml:"name" yaml:"name"` R *pilotR `boil:"-" json:"-" toml:"-" yaml:"-"` + L pilotR `boil:"-" json:"-" toml:"-" yaml:"-"` } type pilotR struct { @@ -375,6 +380,7 @@ type Jet struct { Color string `boil:"color" json:"color" toml:"color" yaml:"color"` R *jetR `boil:"-" json:"-" toml:"-" yaml:"-"` + L jetR `boil:"-" json:"-" toml:"-" yaml:"-"` } type jetR struct { @@ -386,6 +392,7 @@ type Language struct { Language string `boil:"language" json:"language" toml:"language" yaml:"language"` R *languageR `boil:"-" json:"-" toml:"-" yaml:"-"` + L languageR `boil:"-" json:"-" toml:"-" yaml:"-"` } type languageR struct { @@ -987,7 +994,6 @@ The `conflictColumns` argument allows you to specify the `ON CONFLICT` columns f For MySQL, this param will not be generated. Note: Passing a different set of column values to the update component is not currently supported. -If this feature is important to you let us know and we can consider adding something for this. ### Reload In the event that your objects get out of sync with the database for whatever reason, @@ -997,7 +1003,7 @@ attached to the objects. ```go pilot, _ := models.FindPilot(db, 1) -// > Object becomes out of sync for some reason +// > Object becomes out of sync for some reason, perhaps async processing // Refresh the object with the latest data from the db err := pilot.Reload(db) From 1c28f761f1dda91fc580c786e1baa81dc582ea0c Mon Sep 17 00:00:00 2001 From: Patrick O'brien Date: Thu, 15 Sep 2016 01:36:36 +1000 Subject: [PATCH 49/64] Upsert fixed --- templates/12_insert.tpl | 30 +++++-- templates/14_upsert.tpl | 111 +++++++++++++++--------- templates_test/main_test/mysql_main.tpl | 2 +- 3 files changed, 93 insertions(+), 50 deletions(-) diff --git a/templates/12_insert.tpl b/templates/12_insert.tpl index 03568a13a..38c4dbfd7 100644 --- a/templates/12_insert.tpl +++ b/templates/12_insert.tpl @@ -98,18 +98,34 @@ func (o *{{$tableNameSingular}}) Insert(exec boil.Executor, whitelist ... string } lastID, err := result.LastInsertId() - if err != nil || lastID == 0 || len({{$varNameSingular}}PrimaryKeyColumns) != 1 { + if err != nil { return ErrSyncFail } - if boil.DebugMode { - fmt.Fprintln(boil.DebugWriter, cache.retQuery) - fmt.Fprintln(boil.DebugWriter, lastID) + var identifierCols []interface{} + if lastID != 0 { + {{- $colName := index .Table.PKey.Columns 0 -}} + {{- $col := .Table.GetColumn $colName -}} + o.{{$colName | singular | titleCase}} = {{$col.Type}}(lastID) + identifierCols = []interface{}{lastID} + } else { + identifierCols = []interface{}{ + {{range .Table.PKey.Columns -}} + o.{{. | singular | titleCase}}, + {{end -}} + } } - err = exec.QueryRow(cache.retQuery, lastID).Scan(boil.PtrsFromMapping(value, cache.retMapping)...) - if err != nil { - return errors.Wrap(err, "{{.PkgName}}: unable to populate default values for {{.Table.Name}}") + if lastID != 0 && len(cache.retMapping) == 1 { + if boil.DebugMode { + fmt.Fprintln(boil.DebugWriter, cache.retQuery) + fmt.Fprintln(boil.DebugWriter, identifierCols...) + } + + err = exec.QueryRow(cache.retQuery, identifierCols...).Scan(boil.PtrsFromMapping(value, cache.retMapping)...) + if err != nil { + return errors.Wrap(err, "{{.PkgName}}: unable to populate default values for {{.Table.Name}}") + } } {{else}} if len(cache.retMapping) != 0 { diff --git a/templates/14_upsert.tpl b/templates/14_upsert.tpl index 07656e20e..b993495f6 100644 --- a/templates/14_upsert.tpl +++ b/templates/14_upsert.tpl @@ -1,5 +1,6 @@ {{- $tableNameSingular := .Table.Name | singular | titleCase -}} {{- $varNameSingular := .Table.Name | singular | camelCase -}} +{{- $schemaTable := .Table.Name | .SchemaTable -}} // UpsertG attempts an insert, and does an update or ignore on conflict. func (o *{{$tableNameSingular}}) UpsertG({{if ne .DriverName "mysql"}}updateOnConflict bool, conflictColumns []string, {{end}}updateColumns []string, whitelist ...string) error { return o.Upsert(boil.GetDB(), {{if ne .DriverName "mysql"}}updateOnConflict, conflictColumns, {{end}}updateColumns, whitelist...) @@ -8,7 +9,7 @@ func (o *{{$tableNameSingular}}) UpsertG({{if ne .DriverName "mysql"}}updateOnCo // UpsertGP attempts an insert, and does an update or ignore on conflict. Panics on error. func (o *{{$tableNameSingular}}) UpsertGP({{if ne .DriverName "mysql"}}updateOnConflict bool, conflictColumns []string, {{end}}updateColumns []string, whitelist ...string) { if err := o.Upsert(boil.GetDB(), {{if ne .DriverName "mysql"}}updateOnConflict, conflictColumns, {{end}}updateColumns, whitelist...); err != nil { - panic(boil.WrapErr(err)) + panic(boil.WrapErr(err)) } } @@ -16,92 +17,118 @@ func (o *{{$tableNameSingular}}) UpsertGP({{if ne .DriverName "mysql"}}updateOnC // UpsertP panics on error. func (o *{{$tableNameSingular}}) UpsertP(exec boil.Executor, {{if ne .DriverName "mysql"}}updateOnConflict bool, conflictColumns []string, {{end}}updateColumns []string, whitelist ...string) { if err := o.Upsert(exec, {{if ne .DriverName "mysql"}}updateOnConflict, conflictColumns, {{end}}updateColumns, whitelist...); err != nil { - panic(boil.WrapErr(err)) + panic(boil.WrapErr(err)) } } // Upsert attempts an insert using an executor, and does an update or ignore on conflict. func (o *{{$tableNameSingular}}) Upsert(exec boil.Executor, {{if ne .DriverName "mysql"}}updateOnConflict bool, conflictColumns []string, {{end}}updateColumns []string, whitelist ...string) error { if o == nil { - return errors.New("{{.PkgName}}: no {{.Table.Name}} provided for upsert") + return errors.New("{{.PkgName}}: no {{.Table.Name}} provided for upsert") } {{- template "timestamp_upsert_helper" . }} {{if not .NoHooks -}} if err := o.doBeforeUpsertHooks(exec); err != nil { - return err + return err } {{- end}} var err error var ret []string whitelist, ret = strmangle.InsertColumnSet( - {{$varNameSingular}}Columns, - {{$varNameSingular}}ColumnsWithDefault, - {{$varNameSingular}}ColumnsWithoutDefault, - boil.NonZeroDefaultSet({{$varNameSingular}}ColumnsWithDefault, o), - whitelist, + {{$varNameSingular}}ColumnsWithDefault, + {{$varNameSingular}}Columns, + {{$varNameSingular}}ColumnsWithoutDefault, + boil.NonZeroDefaultSet({{$varNameSingular}}ColumnsWithDefault, o), + whitelist, ) update := strmangle.UpdateColumnSet( - {{$varNameSingular}}Columns, - {{$varNameSingular}}PrimaryKeyColumns, - updateColumns, + {{$varNameSingular}}Columns, + {{$varNameSingular}}PrimaryKeyColumns, + updateColumns, ) {{if ne .DriverName "mysql" -}} conflict := conflictColumns if len(conflict) == 0 { - conflict = make([]string, len({{$varNameSingular}}PrimaryKeyColumns)) - copy(conflict, {{$varNameSingular}}PrimaryKeyColumns) + conflict = make([]string, len({{$varNameSingular}}PrimaryKeyColumns)) + copy(conflict, {{$varNameSingular}}PrimaryKeyColumns) } - {{- end}} - - {{if eq .DriverName "mysql" -}} - query := boil.BuildUpsertQueryMySQL(dialect, "{{.Table.Name}}", update, whitelist) + query := boil.BuildUpsertQueryPostgres(dialect, "{{$schemaTable}}", updateOnConflict, ret, update, conflict, whitelist) {{- else -}} - query := boil.BuildUpsertQueryPostgres(dialect, "{{.Table.Name}}", updateOnConflict, ret, update, conflict, whitelist) + query := boil.BuildUpsertQueryMySQL(dialect, "{{.Table.Name}}", update, whitelist) {{- end}} if boil.DebugMode { - fmt.Fprintln(boil.DebugWriter, query) - fmt.Fprintln(boil.DebugWriter, boil.GetStructValues(o, whitelist...)) + fmt.Fprintln(boil.DebugWriter, query) + fmt.Fprintln(boil.DebugWriter, boil.GetStructValues(o, whitelist...)) } {{- if .UseLastInsertID}} - res, err := exec.Exec(query, boil.GetStructValues(o, whitelist...)...) - {{- else}} - if len(ret) != 0 { - err = exec.QueryRow(query, boil.GetStructValues(o, whitelist...)...).Scan(boil.GetStructPointers(o, ret...)...) - } else { - _, err = exec.Exec(query, boil.GetStructValues(o, whitelist...)...) + result, err := exec.Exec(query, boil.GetStructValues(o, whitelist...)...) + if err != nil { + return errors.Wrap(err, "{{.PkgName}}: unable to upsert for {{.Table.Name}}") } - {{- end}} - if err != nil { - return errors.Wrap(err, "{{.PkgName}}: unable to upsert for {{.Table.Name}}") + if len(ret) == 0 { + {{if not .NoHooks -}} + return o.doAfterUpsertHooks(exec) + {{else -}} + return nil + {{end -}} } - {{if .UseLastInsertID -}} - if len(ret) != 0 { - lid, err := res.LastInsertId() + lastID, err := result.LastInsertId() if err != nil { - return errors.Wrap(err, "{{.PkgName}}: unable to get last insert id for {{.Table.Name}}") + return ErrSyncFail } - {{$aipk := autoIncPrimaryKey .Table.Columns .Table.PKey}} - aipk := "{{$aipk.Name}}" - // if the update did not change anything, lid will be 0 - if lid == 0 && aipk == "" { - // do a select using all pkeys - } else if lid != 0 { - // do a select using all pkeys + lid + + var identifierCols []interface{} + if lastID != 0 { + {{- $colName := index .Table.PKey.Columns 0 -}} + {{- $col := .Table.GetColumn $colName -}} + o.{{$colName | singular | titleCase}} = {{$col.Type}}(lastID) + identifierCols = []interface{}{lastID} + } else { + identifierCols = []interface{}{ + {{range .Table.PKey.Columns -}} + o.{{. | singular | titleCase}}, + {{end -}} + } + } + + if lastID != 0 && len(ret) == 1 { + retQuery := fmt.Sprintf( + "SELECT %s FROM {{.LQ}}{{.Table.Name}}{{.RQ}} WHERE {{whereClause .LQ .RQ 0 .Table.PKey.Columns}}", + strings.Join(strmangle.IdentQuoteSlice(dialect.LQ, dialect.RQ, ret), ","), + ) + + if boil.DebugMode { + fmt.Fprintln(boil.DebugWriter, ret) + fmt.Fprintln(boil.DebugWriter, identifierCols...) + } + + err = exec.QueryRow(retQuery, identifierCols...).Scan(boil.GetStructPointers(o, ret...)...) + if err != nil { + return errors.Wrap(err, "{{.PkgName}}: unable to populate default values for {{.Table.Name}}") + } } + {{- else}} + if len(ret) != 0 { + err = exec.QueryRow(query, boil.GetStructValues(o, whitelist...)...).Scan(boil.GetStructPointers(o, ret...)...) + } else { + _, err = exec.Exec(query, boil.GetStructValues(o, whitelist...)...) + } + if err != nil { + return errors.Wrap(err, "{{.PkgName}}: unable to upsert for {{.Table.Name}}") } {{- end}} {{if not .NoHooks -}} if err := o.doAfterUpsertHooks(exec); err != nil { - return err + return err } {{- end}} diff --git a/templates_test/main_test/mysql_main.tpl b/templates_test/main_test/mysql_main.tpl index 849aa9834..fc43d3d8f 100644 --- a/templates_test/main_test/mysql_main.tpl +++ b/templates_test/main_test/mysql_main.tpl @@ -40,7 +40,7 @@ func (m *mysqlTester) setup() error { return err } - dumpCmd := exec.Command("mysqldump", m.defaultsFile(), m.dbName) + dumpCmd := exec.Command("mysqldump", m.defaultsFile(), "--no-data", m.dbName) createCmd := exec.Command("mysql", m.defaultsFile(), "--database", m.testDBName) r, w := io.Pipe() From 702bb2095ef478e435272b7e2408754a4c931aa5 Mon Sep 17 00:00:00 2001 From: Patrick O'brien Date: Thu, 15 Sep 2016 01:45:28 +1000 Subject: [PATCH 50/64] Don't output the schema --- templates_test/main_test/postgres_main.tpl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/templates_test/main_test/postgres_main.tpl b/templates_test/main_test/postgres_main.tpl index cdd029e31..d1951438a 100644 --- a/templates_test/main_test/postgres_main.tpl +++ b/templates_test/main_test/postgres_main.tpl @@ -50,7 +50,7 @@ func (p *pgTester) setup() error { r, w := io.Pipe() dumpCmd.Stdout = w - createCmd.Stdin = io.TeeReader(newFKeyDestroyer(rgxPGFkey, r), os.Stdout) + createCmd.Stdin = newFKeyDestroyer(rgxPGFkey, r) if err = dumpCmd.Start(); err != nil { return errors.Wrap(err, "failed to start pg_dump command") From f6b4d3c6fd901577ab28d83f3574c3f9a1d143ae Mon Sep 17 00:00:00 2001 From: Patrick O'brien Date: Thu, 15 Sep 2016 02:14:30 +1000 Subject: [PATCH 51/64] Rename exec funcs to conform to sql stdlib --- README.md | 6 +++--- boil/query.go | 24 ++++++++++++------------ boil/reflect.go | 2 +- templates/03_finishers.tpl | 4 ++-- templates/13_update.tpl | 2 +- templates/14_upsert.tpl | 2 +- templates/15_delete.tpl | 2 +- 7 files changed, 21 insertions(+), 21 deletions(-) diff --git a/README.md b/README.md index 027f5380f..9b486df17 100644 --- a/README.md +++ b/README.md @@ -558,9 +558,9 @@ UpdateAll(models.M{"name": "John", "age": 23}) // Update all rows matching the b DeleteAll() // Delete all rows matching the built query. Exists() // Returns a bool indicating whether the row(s) for the built query exists. Bind(&myObj) // Bind the results of a query to your own struct object. -ExecQuery() // Execute an SQL query that does not require any rows returned. Equivalent to `sql.Exec()`. -ExecQueryOne() // Execute an SQL query expected to return only a single row. Equivalent to `sql.QueryRow()`. -ExecQueryAll() // Execute an SQL query expected to return multiple rows. Equivalent to `sql.Query()`. +Exec() // Execute an SQL query that does not require any rows returned. +QueryRow() // Execute an SQL query expected to return only a single row. +Query() // Execute an SQL query expected to return multiple rows. ``` ### Raw Query diff --git a/boil/query.go b/boil/query.go index bf8221568..e3366bfb7 100644 --- a/boil/query.go +++ b/boil/query.go @@ -96,8 +96,8 @@ func SQLG(query string, args ...interface{}) *Query { return SQL(GetDB(), query, args...) } -// ExecQuery executes a query that does not need a row returned -func (q *Query) ExecQuery() (sql.Result, error) { +// Exec executes a query that does not need a row returned +func (q *Query) Exec() (sql.Result, error) { qs, args := buildQuery(q) if DebugMode { fmt.Fprintln(DebugWriter, qs) @@ -106,8 +106,8 @@ func (q *Query) ExecQuery() (sql.Result, error) { return q.executor.Exec(qs, args...) } -// ExecQueryOne executes the query for the One finisher and returns a row -func (q *Query) ExecQueryOne() *sql.Row { +// QueryRow executes the query for the One finisher and returns a row +func (q *Query) QueryRow() *sql.Row { qs, args := buildQuery(q) if DebugMode { fmt.Fprintln(DebugWriter, qs) @@ -116,8 +116,8 @@ func (q *Query) ExecQueryOne() *sql.Row { return q.executor.QueryRow(qs, args...) } -// ExecQueryAll executes the query for the All finisher and returns multiple rows -func (q *Query) ExecQueryAll() (*sql.Rows, error) { +// Query executes the query for the All finisher and returns multiple rows +func (q *Query) Query() (*sql.Rows, error) { qs, args := buildQuery(q) if DebugMode { fmt.Fprintln(DebugWriter, qs) @@ -126,10 +126,10 @@ func (q *Query) ExecQueryAll() (*sql.Rows, error) { return q.executor.Query(qs, args...) } -// ExecQueryP executes a query that does not need a row returned +// ExecP executes a query that does not need a row returned // It will panic on error -func (q *Query) ExecQueryP() sql.Result { - res, err := q.ExecQuery() +func (q *Query) ExecP() sql.Result { + res, err := q.Exec() if err != nil { panic(WrapErr(err)) } @@ -137,10 +137,10 @@ func (q *Query) ExecQueryP() sql.Result { return res } -// ExecQueryAllP executes the query for the All finisher and returns multiple rows +// QueryP executes the query for the All finisher and returns multiple rows // It will panic on error -func (q *Query) ExecQueryAllP() *sql.Rows { - rows, err := q.ExecQueryAll() +func (q *Query) QueryP() *sql.Rows { + rows, err := q.Query() if err != nil { panic(WrapErr(err)) } diff --git a/boil/reflect.go b/boil/reflect.go index dd961bbc3..483ecd91c 100644 --- a/boil/reflect.go +++ b/boil/reflect.go @@ -100,7 +100,7 @@ func (q *Query) Bind(obj interface{}) error { return err } - rows, err := q.ExecQueryAll() + rows, err := q.Query() if err != nil { return errors.Wrap(err, "bind failed to execute query") } diff --git a/templates/03_finishers.tpl b/templates/03_finishers.tpl index 1d748e50b..b27f7742d 100644 --- a/templates/03_finishers.tpl +++ b/templates/03_finishers.tpl @@ -82,7 +82,7 @@ func (q {{$varNameSingular}}Query) Count() (int64, error) { boil.SetSelect(q.Query, nil) boil.SetCount(q.Query) - err := q.Query.ExecQueryOne().Scan(&count) + err := q.Query.QueryRow().Scan(&count) if err != nil { return 0, errors.Wrap(err, "{{.PkgName}}: failed to count {{.Table.Name}} rows") } @@ -107,7 +107,7 @@ func (q {{$varNameSingular}}Query) Exists() (bool, error) { boil.SetCount(q.Query) boil.SetLimit(q.Query, 1) - err := q.Query.ExecQueryOne().Scan(&count) + err := q.Query.QueryRow().Scan(&count) if err != nil { return false, errors.Wrap(err, "{{.PkgName}}: failed to check if {{.Table.Name}} exists") } diff --git a/templates/13_update.tpl b/templates/13_update.tpl index 219448fe2..5b2674dd8 100644 --- a/templates/13_update.tpl +++ b/templates/13_update.tpl @@ -107,7 +107,7 @@ func (q {{$varNameSingular}}Query) UpdateAllP(cols M) { func (q {{$varNameSingular}}Query) UpdateAll(cols M) error { boil.SetUpdate(q.Query, cols) - _, err := q.Query.ExecQuery() + _, err := q.Query.Exec() if err != nil { return errors.Wrap(err, "{{.PkgName}}: unable to update all for {{.Table.Name}}") } diff --git a/templates/14_upsert.tpl b/templates/14_upsert.tpl index b993495f6..b9900a59e 100644 --- a/templates/14_upsert.tpl +++ b/templates/14_upsert.tpl @@ -38,8 +38,8 @@ func (o *{{$tableNameSingular}}) Upsert(exec boil.Executor, {{if ne .DriverName var err error var ret []string whitelist, ret = strmangle.InsertColumnSet( - {{$varNameSingular}}ColumnsWithDefault, {{$varNameSingular}}Columns, + {{$varNameSingular}}ColumnsWithDefault, {{$varNameSingular}}ColumnsWithoutDefault, boil.NonZeroDefaultSet({{$varNameSingular}}ColumnsWithDefault, o), whitelist, diff --git a/templates/15_delete.tpl b/templates/15_delete.tpl index b154705c4..7c80bc17d 100644 --- a/templates/15_delete.tpl +++ b/templates/15_delete.tpl @@ -80,7 +80,7 @@ func (q {{$varNameSingular}}Query) DeleteAll() error { boil.SetDelete(q.Query) - _, err := q.Query.ExecQuery() + _, err := q.Query.Exec() if err != nil { return errors.Wrap(err, "{{.PkgName}}: unable to delete all from {{.Table.Name}}") } From 5149df835996f6e7321dfd9a097785655c9c6c12 Mon Sep 17 00:00:00 2001 From: Aaron L Date: Wed, 14 Sep 2016 20:45:09 -0700 Subject: [PATCH 52/64] Move everything to better package structure --- boil/{ => queries}/eager_load.go | 0 imports.go | 22 +-- {boil => queries}/_fixtures/00.sql | 0 {boil => queries}/_fixtures/01.sql | 0 {boil => queries}/_fixtures/02.sql | 0 {boil => queries}/_fixtures/03.sql | 0 {boil => queries}/_fixtures/04.sql | 0 {boil => queries}/_fixtures/05.sql | 0 {boil => queries}/_fixtures/06.sql | 0 {boil => queries}/_fixtures/07.sql | 0 {boil => queries}/_fixtures/08.sql | 0 {boil => queries}/_fixtures/09.sql | 0 {boil => queries}/_fixtures/10.sql | 0 {boil => queries}/_fixtures/11.sql | 0 {boil => queries}/_fixtures/12.sql | 0 {boil => queries}/_fixtures/13.sql | 0 {boil => queries}/_fixtures/14.sql | 0 {boil => queries}/_fixtures/15.sql | 0 queries/eager_load.go | 148 ++++++++++++++++++ {boil => queries}/eager_load_test.go | 14 +- {boil => queries}/helpers.go | 2 +- {boil => queries}/helpers_test.go | 2 +- {boil => queries}/qm/query_mods.go | 78 ++++----- {boil => queries}/query.go | 38 ++--- {boil => queries}/query_builders.go | 2 +- {boil => queries}/query_builders_test.go | 2 +- {boil => queries}/query_test.go | 2 +- {boil => queries}/reflect.go | 5 +- {boil => queries}/reflect_test.go | 2 +- {boil/randomize => randomize}/random.go | 0 {boil/randomize => randomize}/random_test.go | 0 {boil/randomize => randomize}/randomize.go | 2 +- .../randomize => randomize}/randomize_test.go | 0 sqlboiler.go | 4 +- templates.go | 4 +- {boil/types => types}/array.go | 0 {boil/types => types}/array_test.go | 0 {boil/types => types}/hstore.go | 0 {boil/types => types}/json.go | 0 {boil/types => types}/json_test.go | 0 40 files changed, 241 insertions(+), 86 deletions(-) rename boil/{ => queries}/eager_load.go (100%) rename {boil => queries}/_fixtures/00.sql (100%) rename {boil => queries}/_fixtures/01.sql (100%) rename {boil => queries}/_fixtures/02.sql (100%) rename {boil => queries}/_fixtures/03.sql (100%) rename {boil => queries}/_fixtures/04.sql (100%) rename {boil => queries}/_fixtures/05.sql (100%) rename {boil => queries}/_fixtures/06.sql (100%) rename {boil => queries}/_fixtures/07.sql (100%) rename {boil => queries}/_fixtures/08.sql (100%) rename {boil => queries}/_fixtures/09.sql (100%) rename {boil => queries}/_fixtures/10.sql (100%) rename {boil => queries}/_fixtures/11.sql (100%) rename {boil => queries}/_fixtures/12.sql (100%) rename {boil => queries}/_fixtures/13.sql (100%) rename {boil => queries}/_fixtures/14.sql (100%) rename {boil => queries}/_fixtures/15.sql (100%) create mode 100644 queries/eager_load.go rename {boil => queries}/eager_load_test.go (92%) rename {boil => queries}/helpers.go (97%) rename {boil => queries}/helpers_test.go (98%) rename {boil => queries}/qm/query_mods.go (66%) rename {boil => queries}/query.go (87%) rename {boil => queries}/query_builders.go (99%) rename {boil => queries}/query_builders_test.go (99%) rename {boil => queries}/query_test.go (99%) rename {boil => queries}/reflect.go (99%) rename {boil => queries}/reflect_test.go (99%) rename {boil/randomize => randomize}/random.go (100%) rename {boil/randomize => randomize}/random_test.go (100%) rename {boil/randomize => randomize}/randomize.go (99%) rename {boil/randomize => randomize}/randomize_test.go (100%) rename {boil/types => types}/array.go (100%) rename {boil/types => types}/array_test.go (100%) rename {boil/types => types}/hstore.go (100%) rename {boil/types => types}/json.go (100%) rename {boil/types => types}/json_test.go (100%) diff --git a/boil/eager_load.go b/boil/queries/eager_load.go similarity index 100% rename from boil/eager_load.go rename to boil/queries/eager_load.go diff --git a/imports.go b/imports.go index 9ffffa5d8..61ce351f1 100644 --- a/imports.go +++ b/imports.go @@ -153,7 +153,7 @@ var defaultTemplateImports = imports{ thirdParty: importList{ `"github.com/pkg/errors"`, `"github.com/vattle/sqlboiler/boil"`, - `"github.com/vattle/sqlboiler/boil/qm"`, + `"github.com/vattle/sqlboiler/queries/qm"`, `"github.com/vattle/sqlboiler/strmangle"`, }, } @@ -162,7 +162,7 @@ var defaultSingletonTemplateImports = map[string]imports{ "boil_queries": { thirdParty: importList{ `"github.com/vattle/sqlboiler/boil"`, - `"github.com/vattle/sqlboiler/boil/qm"`, + `"github.com/vattle/sqlboiler/queries/qm"`, }, }, "boil_types": { @@ -180,7 +180,7 @@ var defaultTestTemplateImports = imports{ }, thirdParty: importList{ `"github.com/vattle/sqlboiler/boil"`, - `"github.com/vattle/sqlboiler/boil/randomize"`, + `"github.com/vattle/sqlboiler/randomize"`, `"github.com/vattle/sqlboiler/strmangle"`, }, } @@ -240,7 +240,7 @@ var defaultTestMainImports = map[string]imports{ `"github.com/pkg/errors"`, `"github.com/spf13/viper"`, `"github.com/vattle/sqlboiler/bdb/drivers"`, - `"github.com/vattle/sqlboiler/boil/randomize"`, + `"github.com/vattle/sqlboiler/randomize"`, `_ "github.com/lib/pq"`, }, }, @@ -259,7 +259,7 @@ var defaultTestMainImports = map[string]imports{ `"github.com/pkg/errors"`, `"github.com/spf13/viper"`, `"github.com/vattle/sqlboiler/bdb/drivers"`, - `"github.com/vattle/sqlboiler/boil/randomize"`, + `"github.com/vattle/sqlboiler/randomize"`, `_ "github.com/go-sql-driver/mysql"`, }, }, @@ -324,21 +324,21 @@ var importsBasedOnType = map[string]imports{ standard: importList{`"time"`}, }, "types.JSON": { - thirdParty: importList{`"github.com/vattle/sqlboiler/boil/types"`}, + thirdParty: importList{`"github.com/vattle/sqlboiler/types"`}, }, "types.BytesArray": { - thirdParty: importList{`"github.com/vattle/sqlboiler/boil/types"`}, + thirdParty: importList{`"github.com/vattle/sqlboiler/types"`}, }, "types.Int64Array": { - thirdParty: importList{`"github.com/vattle/sqlboiler/boil/types"`}, + thirdParty: importList{`"github.com/vattle/sqlboiler/types"`}, }, "types.Float64Array": { - thirdParty: importList{`"github.com/vattle/sqlboiler/boil/types"`}, + thirdParty: importList{`"github.com/vattle/sqlboiler/types"`}, }, "types.BoolArray": { - thirdParty: importList{`"github.com/vattle/sqlboiler/boil/types"`}, + thirdParty: importList{`"github.com/vattle/sqlboiler/types"`}, }, "types.Hstore": { - thirdParty: importList{`"github.com/vattle/sqlboiler/boil/types"`}, + thirdParty: importList{`"github.com/vattle/sqlboiler/types"`}, }, } diff --git a/boil/_fixtures/00.sql b/queries/_fixtures/00.sql similarity index 100% rename from boil/_fixtures/00.sql rename to queries/_fixtures/00.sql diff --git a/boil/_fixtures/01.sql b/queries/_fixtures/01.sql similarity index 100% rename from boil/_fixtures/01.sql rename to queries/_fixtures/01.sql diff --git a/boil/_fixtures/02.sql b/queries/_fixtures/02.sql similarity index 100% rename from boil/_fixtures/02.sql rename to queries/_fixtures/02.sql diff --git a/boil/_fixtures/03.sql b/queries/_fixtures/03.sql similarity index 100% rename from boil/_fixtures/03.sql rename to queries/_fixtures/03.sql diff --git a/boil/_fixtures/04.sql b/queries/_fixtures/04.sql similarity index 100% rename from boil/_fixtures/04.sql rename to queries/_fixtures/04.sql diff --git a/boil/_fixtures/05.sql b/queries/_fixtures/05.sql similarity index 100% rename from boil/_fixtures/05.sql rename to queries/_fixtures/05.sql diff --git a/boil/_fixtures/06.sql b/queries/_fixtures/06.sql similarity index 100% rename from boil/_fixtures/06.sql rename to queries/_fixtures/06.sql diff --git a/boil/_fixtures/07.sql b/queries/_fixtures/07.sql similarity index 100% rename from boil/_fixtures/07.sql rename to queries/_fixtures/07.sql diff --git a/boil/_fixtures/08.sql b/queries/_fixtures/08.sql similarity index 100% rename from boil/_fixtures/08.sql rename to queries/_fixtures/08.sql diff --git a/boil/_fixtures/09.sql b/queries/_fixtures/09.sql similarity index 100% rename from boil/_fixtures/09.sql rename to queries/_fixtures/09.sql diff --git a/boil/_fixtures/10.sql b/queries/_fixtures/10.sql similarity index 100% rename from boil/_fixtures/10.sql rename to queries/_fixtures/10.sql diff --git a/boil/_fixtures/11.sql b/queries/_fixtures/11.sql similarity index 100% rename from boil/_fixtures/11.sql rename to queries/_fixtures/11.sql diff --git a/boil/_fixtures/12.sql b/queries/_fixtures/12.sql similarity index 100% rename from boil/_fixtures/12.sql rename to queries/_fixtures/12.sql diff --git a/boil/_fixtures/13.sql b/queries/_fixtures/13.sql similarity index 100% rename from boil/_fixtures/13.sql rename to queries/_fixtures/13.sql diff --git a/boil/_fixtures/14.sql b/queries/_fixtures/14.sql similarity index 100% rename from boil/_fixtures/14.sql rename to queries/_fixtures/14.sql diff --git a/boil/_fixtures/15.sql b/queries/_fixtures/15.sql similarity index 100% rename from boil/_fixtures/15.sql rename to queries/_fixtures/15.sql diff --git a/queries/eager_load.go b/queries/eager_load.go new file mode 100644 index 000000000..5d4c9b1d4 --- /dev/null +++ b/queries/eager_load.go @@ -0,0 +1,148 @@ +package queries + +import ( + "database/sql" + "reflect" + + "github.com/pkg/errors" + "github.com/vattle/sqlboiler/boil" + "github.com/vattle/sqlboiler/strmangle" +) + +type loadRelationshipState struct { + exec boil.Executor + loaded map[string]struct{} + toLoad []string +} + +func (l loadRelationshipState) hasLoaded(depth int) bool { + _, ok := l.loaded[l.buildKey(depth)] + return ok +} + +func (l loadRelationshipState) setLoaded(depth int) { + l.loaded[l.buildKey(depth)] = struct{}{} +} + +func (l loadRelationshipState) buildKey(depth int) string { + buf := strmangle.GetBuffer() + + for i, piece := range l.toLoad[:depth+1] { + if i != 0 { + buf.WriteByte('.') + } + buf.WriteString(piece) + } + + str := buf.String() + strmangle.PutBuffer(buf) + return str +} + +// loadRelationships dynamically calls the template generated eager load +// functions of the form: +// +// func (t *TableR) LoadRelationshipName(exec Executor, singular bool, obj interface{}) +// +// The arguments to this function are: +// - t is not considered here, and is always passed nil. The function exists on a loaded +// struct to avoid a circular dependency with boil, and the receiver is ignored. +// - exec is used to perform additional queries that might be required for loading the relationships. +// - singular is passed in to identify whether or not this was a single object +// or a slice that must be loaded into. +// - obj is the object or slice of objects, always of the type *obj or *[]*obj as per bind. +// +// It takes list of nested relationships to load. +func (l loadRelationshipState) loadRelationships(depth int, obj interface{}, bkind bindKind) error { + typ := reflect.TypeOf(obj).Elem() + if bkind == kindPtrSliceStruct { + typ = typ.Elem().Elem() + } + + if !l.hasLoaded(depth) { + current := l.toLoad[depth] + ln, found := typ.FieldByName(loaderStructName) + // It's possible a Loaders struct doesn't exist on the struct. + if !found { + return errors.Errorf("attempted to load %s but no L struct was found", current) + } + + // Attempt to find the LoadRelationshipName function + loadMethod, found := ln.Type.MethodByName(loadMethodPrefix + current) + if !found { + return errors.Errorf("could not find %s%s method for eager loading", loadMethodPrefix, current) + } + + // Hack to allow nil executors + execArg := reflect.ValueOf(l.exec) + if !execArg.IsValid() { + execArg = reflect.ValueOf((*sql.DB)(nil)) + } + + val := reflect.ValueOf(obj).Elem() + if bkind == kindPtrSliceStruct { + val = val.Index(0).Elem() + } + + methodArgs := []reflect.Value{ + val.FieldByName(loaderStructName), + execArg, + reflect.ValueOf(bkind == kindStruct), + reflect.ValueOf(obj), + } + resp := loadMethod.Func.Call(methodArgs) + if intf := resp[0].Interface(); intf != nil { + return errors.Wrapf(intf.(error), "failed to eager load %s", current) + } + + l.setLoaded(depth) + } + + // Pull one off the queue, continue if there's still some to go + depth++ + if depth >= len(l.toLoad) { + return nil + } + + loadedObject := reflect.ValueOf(obj) + // If we eagerly loaded nothing + if loadedObject.IsNil() { + return nil + } + loadedObject = reflect.Indirect(loadedObject) + + // If it's singular we can just immediately call without looping + if bkind == kindStruct { + return l.loadRelationshipsRecurse(depth, loadedObject) + } + + // Loop over all eager loaded objects + ln := loadedObject.Len() + if ln == 0 { + return nil + } + for i := 0; i < ln; i++ { + iter := loadedObject.Index(i).Elem() + if err := l.loadRelationshipsRecurse(depth, iter); err != nil { + return err + } + } + + return nil +} + +// loadRelationshipsRecurse is a helper function for taking a reflect.Value and +// Basically calls loadRelationships with: obj.R.EagerLoadedObj, and whether it's a string or slice +func (l loadRelationshipState) loadRelationshipsRecurse(depth int, obj reflect.Value) error { + r := obj.FieldByName(relationshipStructName) + if !r.IsValid() || r.IsNil() { + return errors.Errorf("could not traverse into loaded %s relationship to load more things", l.toLoad[depth]) + } + newObj := reflect.Indirect(r).FieldByName(l.toLoad[depth]) + bkind := kindStruct + if reflect.Indirect(newObj).Kind() != reflect.Struct { + bkind = kindPtrSliceStruct + newObj = newObj.Addr() + } + return l.loadRelationships(depth, newObj.Interface(), bkind) +} diff --git a/boil/eager_load_test.go b/queries/eager_load_test.go similarity index 92% rename from boil/eager_load_test.go rename to queries/eager_load_test.go index 86ada1aa2..282bff019 100644 --- a/boil/eager_load_test.go +++ b/queries/eager_load_test.go @@ -1,6 +1,10 @@ -package boil +package queries -import "testing" +import ( + "testing" + + "github.com/vattle/sqlboiler/boil" +) var loadFunctionCalled bool var loadFunctionNestedCalled int @@ -32,12 +36,12 @@ type testNestedRSlice struct { type testNestedLSlice struct { } -func (testLStruct) LoadTestOne(exec Executor, singular bool, obj interface{}) error { +func (testLStruct) LoadTestOne(exec boil.Executor, singular bool, obj interface{}) error { loadFunctionCalled = true return nil } -func (testNestedLStruct) LoadToEagerLoad(exec Executor, singular bool, obj interface{}) error { +func (testNestedLStruct) LoadToEagerLoad(exec boil.Executor, singular bool, obj interface{}) error { switch x := obj.(type) { case *testNestedStruct: x.R = &testNestedRStruct{ @@ -54,7 +58,7 @@ func (testNestedLStruct) LoadToEagerLoad(exec Executor, singular bool, obj inter return nil } -func (testNestedLSlice) LoadToEagerLoad(exec Executor, singular bool, obj interface{}) error { +func (testNestedLSlice) LoadToEagerLoad(exec boil.Executor, singular bool, obj interface{}) error { switch x := obj.(type) { case *testNestedSlice: diff --git a/boil/helpers.go b/queries/helpers.go similarity index 97% rename from boil/helpers.go rename to queries/helpers.go index 43e3ff2ce..59ad8a3ff 100644 --- a/boil/helpers.go +++ b/queries/helpers.go @@ -1,4 +1,4 @@ -package boil +package queries import ( "fmt" diff --git a/boil/helpers_test.go b/queries/helpers_test.go similarity index 98% rename from boil/helpers_test.go rename to queries/helpers_test.go index 73d284e68..c87bc7602 100644 --- a/boil/helpers_test.go +++ b/queries/helpers_test.go @@ -1,4 +1,4 @@ -package boil +package queries import ( "reflect" diff --git a/boil/qm/query_mods.go b/queries/qm/query_mods.go similarity index 66% rename from boil/qm/query_mods.go rename to queries/qm/query_mods.go index 5cbf3c63d..b2e7e14f6 100644 --- a/boil/qm/query_mods.go +++ b/queries/qm/query_mods.go @@ -1,12 +1,12 @@ package qm -import "github.com/vattle/sqlboiler/boil" +import "github.com/vattle/sqlboiler/queries" // QueryMod to modify the query object -type QueryMod func(q *boil.Query) +type QueryMod func(q *queries.Query) // Apply the query mods to the Query object -func Apply(q *boil.Query, mods ...QueryMod) { +func Apply(q *queries.Query, mods ...QueryMod) { for _, mod := range mods { mod(q) } @@ -14,8 +14,8 @@ func Apply(q *boil.Query, mods ...QueryMod) { // SQL allows you to execute a plain SQL statement func SQL(sql string, args ...interface{}) QueryMod { - return func(q *boil.Query) { - boil.SetSQL(q, sql, args...) + return func(q *queries.Query) { + queries.SetSQL(q, sql, args...) } } @@ -25,29 +25,29 @@ func SQL(sql string, args ...interface{}) QueryMod { // Relationship name plurality is important, if your relationship is // singular, you need to specify the singular form and vice versa. func Load(relationships ...string) QueryMod { - return func(q *boil.Query) { - boil.SetLoad(q, relationships...) + return func(q *queries.Query) { + queries.SetLoad(q, relationships...) } } // InnerJoin on another table func InnerJoin(clause string, args ...interface{}) QueryMod { - return func(q *boil.Query) { - boil.AppendInnerJoin(q, clause, args...) + return func(q *queries.Query) { + queries.AppendInnerJoin(q, clause, args...) } } // Select specific columns opposed to all columns func Select(columns ...string) QueryMod { - return func(q *boil.Query) { - boil.AppendSelect(q, columns...) + return func(q *queries.Query) { + queries.AppendSelect(q, columns...) } } // Where allows you to specify a where clause for your statement func Where(clause string, args ...interface{}) QueryMod { - return func(q *boil.Query) { - boil.AppendWhere(q, clause, args...) + return func(q *queries.Query) { + queries.AppendWhere(q, clause, args...) } } @@ -55,24 +55,24 @@ func Where(clause string, args ...interface{}) QueryMod { // And is a duplicate of the Where function, but allows for more natural looking // query mod chains, for example: (Where("a=?"), And("b=?"), Or("c=?"))) func And(clause string, args ...interface{}) QueryMod { - return func(q *boil.Query) { - boil.AppendWhere(q, clause, args...) + return func(q *queries.Query) { + queries.AppendWhere(q, clause, args...) } } // Or allows you to specify a where clause separated by an OR for your statement func Or(clause string, args ...interface{}) QueryMod { - return func(q *boil.Query) { - boil.AppendWhere(q, clause, args...) - boil.SetLastWhereAsOr(q) + return func(q *queries.Query) { + queries.AppendWhere(q, clause, args...) + queries.SetLastWhereAsOr(q) } } // WhereIn allows you to specify a "x IN (set)" clause for your where statement // Example clauses: "column in ?", "(column1,column2) in ?" func WhereIn(clause string, args ...interface{}) QueryMod { - return func(q *boil.Query) { - boil.AppendIn(q, clause, args...) + return func(q *queries.Query) { + queries.AppendIn(q, clause, args...) } } @@ -81,65 +81,65 @@ func WhereIn(clause string, args ...interface{}) QueryMod { // allows for more natural looking query mod chains, for example: // (WhereIn("column1 in ?"), AndIn("column2 in ?"), OrIn("column3 in ?")) func AndIn(clause string, args ...interface{}) QueryMod { - return func(q *boil.Query) { - boil.AppendIn(q, clause, args...) + return func(q *queries.Query) { + queries.AppendIn(q, clause, args...) } } // OrIn allows you to specify an IN clause separated by // an OR for your where statement func OrIn(clause string, args ...interface{}) QueryMod { - return func(q *boil.Query) { - boil.AppendIn(q, clause, args...) - boil.SetLastInAsOr(q) + return func(q *queries.Query) { + queries.AppendIn(q, clause, args...) + queries.SetLastInAsOr(q) } } // GroupBy allows you to specify a group by clause for your statement func GroupBy(clause string) QueryMod { - return func(q *boil.Query) { - boil.AppendGroupBy(q, clause) + return func(q *queries.Query) { + queries.AppendGroupBy(q, clause) } } // OrderBy allows you to specify a order by clause for your statement func OrderBy(clause string) QueryMod { - return func(q *boil.Query) { - boil.AppendOrderBy(q, clause) + return func(q *queries.Query) { + queries.AppendOrderBy(q, clause) } } // Having allows you to specify a having clause for your statement func Having(clause string, args ...interface{}) QueryMod { - return func(q *boil.Query) { - boil.AppendHaving(q, clause, args...) + return func(q *queries.Query) { + queries.AppendHaving(q, clause, args...) } } // From allows to specify the table for your statement func From(from string) QueryMod { - return func(q *boil.Query) { - boil.AppendFrom(q, from) + return func(q *queries.Query) { + queries.AppendFrom(q, from) } } // Limit the number of returned rows func Limit(limit int) QueryMod { - return func(q *boil.Query) { - boil.SetLimit(q, limit) + return func(q *queries.Query) { + queries.SetLimit(q, limit) } } // Offset into the results func Offset(offset int) QueryMod { - return func(q *boil.Query) { - boil.SetOffset(q, offset) + return func(q *queries.Query) { + queries.SetOffset(q, offset) } } // For inserts a concurrency locking clause at the end of your statement func For(clause string) QueryMod { - return func(q *boil.Query) { - boil.SetFor(q, clause) + return func(q *queries.Query) { + queries.SetFor(q, clause) } } diff --git a/boil/query.go b/queries/query.go similarity index 87% rename from boil/query.go rename to queries/query.go index e3366bfb7..2a22be18f 100644 --- a/boil/query.go +++ b/queries/query.go @@ -1,8 +1,10 @@ -package boil +package queries import ( "database/sql" "fmt" + + "github.com/vattle/sqlboiler/boil" ) // joinKind is the type of join @@ -18,7 +20,7 @@ const ( // Query holds the state for the built up query type Query struct { - executor Executor + executor boil.Executor dialect *Dialect plainSQL plainSQL load []string @@ -81,7 +83,7 @@ type join struct { } // SQL makes a plainSQL query, usually for use with bind -func SQL(exec Executor, query string, args ...interface{}) *Query { +func SQL(exec boil.Executor, query string, args ...interface{}) *Query { return &Query{ executor: exec, plainSQL: plainSQL{ @@ -91,17 +93,17 @@ func SQL(exec Executor, query string, args ...interface{}) *Query { } } -// SQLG makes a plainSQL query using the global Executor, usually for use with bind +// SQLG makes a plainSQL query using the global boil.Executor, usually for use with bind func SQLG(query string, args ...interface{}) *Query { - return SQL(GetDB(), query, args...) + return SQL(boil.GetDB(), query, args...) } // Exec executes a query that does not need a row returned func (q *Query) Exec() (sql.Result, error) { qs, args := buildQuery(q) - if DebugMode { - fmt.Fprintln(DebugWriter, qs) - fmt.Fprintln(DebugWriter, args) + if boil.DebugMode { + fmt.Fprintln(boil.DebugWriter, qs) + fmt.Fprintln(boil.DebugWriter, args) } return q.executor.Exec(qs, args...) } @@ -109,9 +111,9 @@ func (q *Query) Exec() (sql.Result, error) { // QueryRow executes the query for the One finisher and returns a row func (q *Query) QueryRow() *sql.Row { qs, args := buildQuery(q) - if DebugMode { - fmt.Fprintln(DebugWriter, qs) - fmt.Fprintln(DebugWriter, args) + if boil.DebugMode { + fmt.Fprintln(boil.DebugWriter, qs) + fmt.Fprintln(boil.DebugWriter, args) } return q.executor.QueryRow(qs, args...) } @@ -119,9 +121,9 @@ func (q *Query) QueryRow() *sql.Row { // Query executes the query for the All finisher and returns multiple rows func (q *Query) Query() (*sql.Rows, error) { qs, args := buildQuery(q) - if DebugMode { - fmt.Fprintln(DebugWriter, qs) - fmt.Fprintln(DebugWriter, args) + if boil.DebugMode { + fmt.Fprintln(boil.DebugWriter, qs) + fmt.Fprintln(boil.DebugWriter, args) } return q.executor.Query(qs, args...) } @@ -131,7 +133,7 @@ func (q *Query) Query() (*sql.Rows, error) { func (q *Query) ExecP() sql.Result { res, err := q.Exec() if err != nil { - panic(WrapErr(err)) + panic(boil.WrapErr(err)) } return res @@ -142,19 +144,19 @@ func (q *Query) ExecP() sql.Result { func (q *Query) QueryP() *sql.Rows { rows, err := q.Query() if err != nil { - panic(WrapErr(err)) + panic(boil.WrapErr(err)) } return rows } // SetExecutor on the query. -func SetExecutor(q *Query, exec Executor) { +func SetExecutor(q *Query, exec boil.Executor) { q.executor = exec } // GetExecutor on the query. -func GetExecutor(q *Query) Executor { +func GetExecutor(q *Query) boil.Executor { return q.executor } diff --git a/boil/query_builders.go b/queries/query_builders.go similarity index 99% rename from boil/query_builders.go rename to queries/query_builders.go index bbfe288a2..0937bdc97 100644 --- a/boil/query_builders.go +++ b/queries/query_builders.go @@ -1,4 +1,4 @@ -package boil +package queries import ( "bytes" diff --git a/boil/query_builders_test.go b/queries/query_builders_test.go similarity index 99% rename from boil/query_builders_test.go rename to queries/query_builders_test.go index db642b092..9af45da09 100644 --- a/boil/query_builders_test.go +++ b/queries/query_builders_test.go @@ -1,4 +1,4 @@ -package boil +package queries import ( "bytes" diff --git a/boil/query_test.go b/queries/query_test.go similarity index 99% rename from boil/query_test.go rename to queries/query_test.go index 005e834f5..08778f6b0 100644 --- a/boil/query_test.go +++ b/queries/query_test.go @@ -1,4 +1,4 @@ -package boil +package queries import ( "database/sql" diff --git a/boil/reflect.go b/queries/reflect.go similarity index 99% rename from boil/reflect.go rename to queries/reflect.go index 483ecd91c..bee0a123e 100644 --- a/boil/reflect.go +++ b/queries/reflect.go @@ -1,4 +1,4 @@ -package boil +package queries import ( "database/sql" @@ -8,6 +8,7 @@ import ( "sync" "github.com/pkg/errors" + "github.com/vattle/sqlboiler/boil" "github.com/vattle/sqlboiler/strmangle" ) @@ -40,7 +41,7 @@ const ( // It panics on error. See boil.Bind() documentation. func (q *Query) BindP(obj interface{}) { if err := q.Bind(obj); err != nil { - panic(WrapErr(err)) + panic(boil.WrapErr(err)) } } diff --git a/boil/reflect_test.go b/queries/reflect_test.go similarity index 99% rename from boil/reflect_test.go rename to queries/reflect_test.go index 279e98b9f..7641557f9 100644 --- a/boil/reflect_test.go +++ b/queries/reflect_test.go @@ -1,4 +1,4 @@ -package boil +package queries import ( "database/sql/driver" diff --git a/boil/randomize/random.go b/randomize/random.go similarity index 100% rename from boil/randomize/random.go rename to randomize/random.go diff --git a/boil/randomize/random_test.go b/randomize/random_test.go similarity index 100% rename from boil/randomize/random_test.go rename to randomize/random_test.go diff --git a/boil/randomize/randomize.go b/randomize/randomize.go similarity index 99% rename from boil/randomize/randomize.go rename to randomize/randomize.go index e12193e10..3f5ae295c 100644 --- a/boil/randomize/randomize.go +++ b/randomize/randomize.go @@ -17,8 +17,8 @@ import ( "github.com/lib/pq/hstore" "github.com/pkg/errors" "github.com/satori/go.uuid" - "github.com/vattle/sqlboiler/boil/types" "github.com/vattle/sqlboiler/strmangle" + "github.com/vattle/sqlboiler/types" ) var ( diff --git a/boil/randomize/randomize_test.go b/randomize/randomize_test.go similarity index 100% rename from boil/randomize/randomize_test.go rename to randomize/randomize_test.go diff --git a/sqlboiler.go b/sqlboiler.go index ae6b6d8fa..0c52dcea1 100644 --- a/sqlboiler.go +++ b/sqlboiler.go @@ -15,7 +15,7 @@ import ( "github.com/pkg/errors" "github.com/vattle/sqlboiler/bdb" "github.com/vattle/sqlboiler/bdb/drivers" - "github.com/vattle/sqlboiler/boil" + "github.com/vattle/sqlboiler/queries" "github.com/vattle/sqlboiler/strmangle" ) @@ -35,7 +35,7 @@ type State struct { Driver bdb.Interface Tables []bdb.Table - Dialect boil.Dialect + Dialect queries.Dialect Templates *templateList TestTemplates *templateList diff --git a/templates.go b/templates.go index a1da91768..12108bbd2 100644 --- a/templates.go +++ b/templates.go @@ -8,7 +8,7 @@ import ( "text/template" "github.com/vattle/sqlboiler/bdb" - "github.com/vattle/sqlboiler/boil" + "github.com/vattle/sqlboiler/queries" "github.com/vattle/sqlboiler/strmangle" ) @@ -36,7 +36,7 @@ type templateData struct { StringFuncs map[string]func(string) string // Dialect controls quoting - Dialect boil.Dialect + Dialect queries.Dialect LQ string RQ string } diff --git a/boil/types/array.go b/types/array.go similarity index 100% rename from boil/types/array.go rename to types/array.go diff --git a/boil/types/array_test.go b/types/array_test.go similarity index 100% rename from boil/types/array_test.go rename to types/array_test.go diff --git a/boil/types/hstore.go b/types/hstore.go similarity index 100% rename from boil/types/hstore.go rename to types/hstore.go diff --git a/boil/types/json.go b/types/json.go similarity index 100% rename from boil/types/json.go rename to types/json.go diff --git a/boil/types/json_test.go b/types/json_test.go similarity index 100% rename from boil/types/json_test.go rename to types/json_test.go From 12967f7b663d6c2cf24f526884a2142e7f39f5f9 Mon Sep 17 00:00:00 2001 From: Aaron L Date: Wed, 14 Sep 2016 20:57:07 -0700 Subject: [PATCH 53/64] Fix up the interface to raw queries. --- bdb/interface.go | 2 +- queries/query.go | 18 +++++++++--------- queries/query_builders.go | 8 ++++---- queries/query_test.go | 24 ++++++++++++------------ 4 files changed, 26 insertions(+), 26 deletions(-) diff --git a/bdb/interface.go b/bdb/interface.go index ab32c8c1a..0cfee7369 100644 --- a/bdb/interface.go +++ b/bdb/interface.go @@ -24,7 +24,7 @@ type Interface interface { Close() // Dialect helpers, these provide the values that will go into - // a boil.Dialect, so the query builder knows how to support + // a queries.Dialect, so the query builder knows how to support // your database driver properly. LeftQuote() byte RightQuote() byte diff --git a/queries/query.go b/queries/query.go index 2a22be18f..81237a85a 100644 --- a/queries/query.go +++ b/queries/query.go @@ -22,7 +22,7 @@ const ( type Query struct { executor boil.Executor dialect *Dialect - plainSQL plainSQL + rawSQL rawSQL load []string delete bool update map[string]interface{} @@ -71,7 +71,7 @@ type having struct { args []interface{} } -type plainSQL struct { +type rawSQL struct { sql string args []interface{} } @@ -82,20 +82,20 @@ type join struct { args []interface{} } -// SQL makes a plainSQL query, usually for use with bind -func SQL(exec boil.Executor, query string, args ...interface{}) *Query { +// Raw makes a raw query, usually for use with bind +func Raw(exec boil.Executor, query string, args ...interface{}) *Query { return &Query{ executor: exec, - plainSQL: plainSQL{ + rawSQL: rawSQL{ sql: query, args: args, }, } } -// SQLG makes a plainSQL query using the global boil.Executor, usually for use with bind -func SQLG(query string, args ...interface{}) *Query { - return SQL(boil.GetDB(), query, args...) +// RawG makes a raw query using the global boil.Executor, usually for use with bind +func RawG(query string, args ...interface{}) *Query { + return Raw(boil.GetDB(), query, args...) } // Exec executes a query that does not need a row returned @@ -167,7 +167,7 @@ func SetDialect(q *Query, dialect *Dialect) { // SetSQL on the query. func SetSQL(q *Query, sql string, args ...interface{}) { - q.plainSQL = plainSQL{sql: sql, args: args} + q.rawSQL = rawSQL{sql: sql, args: args} } // SetLoad on the query. diff --git a/queries/query_builders.go b/queries/query_builders.go index 0937bdc97..8997eaa92 100644 --- a/queries/query_builders.go +++ b/queries/query_builders.go @@ -20,8 +20,8 @@ func buildQuery(q *Query) (string, []interface{}) { var args []interface{} switch { - case len(q.plainSQL.sql) != 0: - return q.plainSQL.sql, q.plainSQL.args + case len(q.rawSQL.sql) != 0: + return q.rawSQL.sql, q.rawSQL.args case q.delete: buf, args = buildDeleteQuery(q) case len(q.update) > 0: @@ -34,8 +34,8 @@ func buildQuery(q *Query) (string, []interface{}) { // Cache the generated query for query object re-use bufStr := buf.String() - q.plainSQL.sql = bufStr - q.plainSQL.args = args + q.rawSQL.sql = bufStr + q.rawSQL.args = args return bufStr, args } diff --git a/queries/query_test.go b/queries/query_test.go index 08778f6b0..97f8a9d7c 100644 --- a/queries/query_test.go +++ b/queries/query_test.go @@ -36,12 +36,12 @@ func TestSetSQL(t *testing.T) { q := &Query{} SetSQL(q, "select * from thing", 5, 3) - if len(q.plainSQL.args) != 2 { - t.Errorf("Expected len 2, got %d", len(q.plainSQL.args)) + if len(q.rawSQL.args) != 2 { + t.Errorf("Expected len 2, got %d", len(q.rawSQL.args)) } - if q.plainSQL.sql != "select * from thing" { - t.Errorf("Was not expected string, got %s", q.plainSQL.sql) + if q.rawSQL.sql != "select * from thing" { + t.Errorf("Was not expected string, got %s", q.rawSQL.sql) } } @@ -374,11 +374,11 @@ func TestSQL(t *testing.T) { t.Parallel() q := SQL(&sql.DB{}, "thing", 5) - if q.plainSQL.sql != "thing" { - t.Errorf("Expected %q, got %s", "thing", q.plainSQL.sql) + if q.rawSQL.sql != "thing" { + t.Errorf("Expected %q, got %s", "thing", q.rawSQL.sql) } - if q.plainSQL.args[0].(int) != 5 { - t.Errorf("Expected 5, got %v", q.plainSQL.args[0]) + if q.rawSQL.args[0].(int) != 5 { + t.Errorf("Expected 5, got %v", q.rawSQL.args[0]) } } @@ -386,11 +386,11 @@ func TestSQLG(t *testing.T) { t.Parallel() q := SQLG("thing", 5) - if q.plainSQL.sql != "thing" { - t.Errorf("Expected %q, got %s", "thing", q.plainSQL.sql) + if q.rawSQL.sql != "thing" { + t.Errorf("Expected %q, got %s", "thing", q.rawSQL.sql) } - if q.plainSQL.args[0].(int) != 5 { - t.Errorf("Expected 5, got %v", q.plainSQL.args[0]) + if q.rawSQL.args[0].(int) != 5 { + t.Errorf("Expected 5, got %v", q.rawSQL.args[0]) } } From f803cdd6bd32e51ce22f1220ec1e1bd38950365d Mon Sep 17 00:00:00 2001 From: Aaron L Date: Wed, 14 Sep 2016 20:59:55 -0700 Subject: [PATCH 54/64] Fix all references to moved elements. --- README.md | 20 ++++++++++---------- imports.go | 2 ++ templates/01_types.tpl | 4 ++-- templates/03_finishers.tpl | 14 +++++++------- templates/04_relationship_to_one.tpl | 2 +- templates/05_relationship_to_many.tpl | 2 +- templates/06_relationship_to_one_eager.tpl | 2 +- templates/07_relationship_to_many_eager.tpl | 2 +- templates/11_find.tpl | 2 +- templates/12_insert.tpl | 12 ++++++------ templates/13_update.tpl | 6 +++--- templates/14_upsert.tpl | 16 ++++++++-------- templates/15_delete.tpl | 2 +- templates/16_reload.tpl | 2 +- templates/singleton/boil_queries.tpl | 12 ++++++------ 15 files changed, 51 insertions(+), 49 deletions(-) diff --git a/README.md b/README.md index 9b486df17..6d39fbb04 100644 --- a/README.md +++ b/README.md @@ -291,7 +291,7 @@ with SQLBoiler. If you find there are some failing tests, please check the sqlboiler -x goose_migrations postgres # Run the generated tests -go test ./models +go test ./models ``` ## Diagnosing Problems @@ -425,7 +425,7 @@ Note: You can set the timezone for this feature by calling `boil.SetLocation()` This is somewhat of a work around until we can devise a better solution in a later version. * **Update** * The `updated_at` column will always be set to `time.Now()`. If you need to override - this value you will need to fall back to another method in the meantime: `boil.SQL()`, + this value you will need to fall back to another method in the meantime: `queries.Raw()`, overriding `updated_at` in all of your objects using a hook, or create your own wrapper. * **Upsert** * `created_at` will be set automatically if it is a zero value, otherwise your supplied value @@ -463,7 +463,7 @@ err := models.NewQuery(db, From("pilots")).All() As you can see, [Query Mods](#query-mods) allow you to modify your queries, and [Finishers](#finishers) allow you to execute the final action. -We also generate query building helper methods for your relationships as well. Take a look at our +We also generate query building helper methods for your relationships as well. Take a look at our [Relationships Query Building](#relationships) section for some additional query building information. @@ -565,17 +565,17 @@ Query() // Execute an SQL query expected to return multiple rows. ### Raw Query -We provide `boil.SQL()` for executing raw queries. Generally you will want to use `Bind()` with +We provide `queries.Raw()` for executing raw queries. Generally you will want to use `Bind()` with this, like the following: ```go -err := boil.SQL(db, "select * from pilots where id=$1", 5).Bind(&obj) +err := queries.Raw(db, "select * from pilots where id=$1", 5).Bind(&obj) ``` You can use your own structs or a generated struct as a parameter to Bind. Bind supports both a single object for single row queries and a slice of objects for multiple row queries. -`boil.SQL()` also has a method that can execute a query without binding to an object, if required. +`queries.Raw()` also has a method that can execute a query without binding to an object, if required. You also have `models.NewQuery()` at your disposal if you would still like to use [Query Building](#query-building) in combination with your own custom, non-generated model. @@ -601,7 +601,7 @@ type PilotAndJet struct { var paj PilotAndJet // Use a raw query -err := boil.SQL(` +err := queries.Raw(` select pilots.id as "pilots.id", pilots.name as "pilots.name", jets.id as "jets.id", jets.pilot_id as "jets.pilot_id", jets.age as "jets.age", jets.name as "jets.name", jets.color as "jets.color" @@ -629,7 +629,7 @@ var info JetInfo err := models.NewQuery(db, Select("sum(age) as age_sum", "count(*) as juicy_count", From("jets"))).Bind(&info) // Use a raw query -err := boil.SQL(`select sum(age) as "age_sum", count(*) as "juicy_count" from jets`).Bind(&info) +err := queries.Raw(`select sum(age) as "age_sum", count(*) as "juicy_count" from jets`).Bind(&info) ``` We support the following struct tag modes for `Bind()` control: @@ -990,7 +990,7 @@ err := p1.Upsert(db, true, []string{"id"}, []string{"name"}, "id", "name") The `updateOnConflict` argument allows you to specify whether you would like Postgres to perform a `DO NOTHING` on conflict, opposed to a `DO UPDATE`. For MySQL, this param will not be generated. -The `conflictColumns` argument allows you to specify the `ON CONFLICT` columns for Postgres. +The `conflictColumns` argument allows you to specify the `ON CONFLICT` columns for Postgres. For MySQL, this param will not be generated. Note: Passing a different set of column values to the update component is not currently supported. @@ -1062,7 +1062,7 @@ Please note that multi-dimensional Postgres ARRAY types are not supported at thi #### Where is the homepage? -The homepage for the [SQLBoiler](https://github.com/vattle/sqlboiler) [Golang ORM](https://github.com/vattle/sqlboiler) generator is located at: https://github.com/vattle/sqlboiler +The homepage for the [SQLBoiler](https://github.com/vattle/sqlboiler) [Golang ORM](https://github.com/vattle/sqlboiler) generator is located at: https://github.com/vattle/sqlboiler ## Benchmarks diff --git a/imports.go b/imports.go index 61ce351f1..bf9693055 100644 --- a/imports.go +++ b/imports.go @@ -153,6 +153,7 @@ var defaultTemplateImports = imports{ thirdParty: importList{ `"github.com/pkg/errors"`, `"github.com/vattle/sqlboiler/boil"`, + `"github.com/vattle/sqlboiler/queries"`, `"github.com/vattle/sqlboiler/queries/qm"`, `"github.com/vattle/sqlboiler/strmangle"`, }, @@ -162,6 +163,7 @@ var defaultSingletonTemplateImports = map[string]imports{ "boil_queries": { thirdParty: importList{ `"github.com/vattle/sqlboiler/boil"`, + `"github.com/vattle/sqlboiler/queries"`, `"github.com/vattle/sqlboiler/queries/qm"`, }, }, diff --git a/templates/01_types.tpl b/templates/01_types.tpl index ecbeafb0f..cf4e0c7d5 100644 --- a/templates/01_types.tpl +++ b/templates/01_types.tpl @@ -16,14 +16,14 @@ type ( {{- end}} {{$varNameSingular}}Query struct { - *boil.Query + *queries.Query } ) // Cache for insert and update var ( {{$varNameSingular}}Type = reflect.TypeOf(&{{$tableNameSingular}}{}) - {{$varNameSingular}}Mapping = boil.MakeStructMapping({{$varNameSingular}}Type) + {{$varNameSingular}}Mapping = queries.MakeStructMapping({{$varNameSingular}}Type) {{$varNameSingular}}InsertCacheMut sync.RWMutex {{$varNameSingular}}InsertCache = make(map[string]insertCache) {{$varNameSingular}}UpdateCacheMut sync.RWMutex diff --git a/templates/03_finishers.tpl b/templates/03_finishers.tpl index b27f7742d..429a27625 100644 --- a/templates/03_finishers.tpl +++ b/templates/03_finishers.tpl @@ -14,7 +14,7 @@ func (q {{$varNameSingular}}Query) OneP() (*{{$tableNameSingular}}) { func (q {{$varNameSingular}}Query) One() (*{{$tableNameSingular}}, error) { o := &{{$tableNameSingular}}{} - boil.SetLimit(q.Query, 1) + queries.SetLimit(q.Query, 1) err := q.Bind(o) if err != nil { @@ -25,7 +25,7 @@ func (q {{$varNameSingular}}Query) One() (*{{$tableNameSingular}}, error) { } {{if not .NoHooks -}} - if err := o.doAfterSelectHooks(boil.GetExecutor(q.Query)); err != nil { + if err := o.doAfterSelectHooks(queries.GetExecutor(q.Query)); err != nil { return o, err } {{- end}} @@ -55,7 +55,7 @@ func (q {{$varNameSingular}}Query) All() ({{$tableNameSingular}}Slice, error) { {{if not .NoHooks -}} if len({{$varNameSingular}}AfterSelectHooks) != 0 { for _, obj := range o { - if err := obj.doAfterSelectHooks(boil.GetExecutor(q.Query)); err != nil { + if err := obj.doAfterSelectHooks(queries.GetExecutor(q.Query)); err != nil { return o, err } } @@ -79,8 +79,8 @@ func (q {{$varNameSingular}}Query) CountP() int64 { func (q {{$varNameSingular}}Query) Count() (int64, error) { var count int64 - boil.SetSelect(q.Query, nil) - boil.SetCount(q.Query) + queries.SetSelect(q.Query, nil) + queries.SetCount(q.Query) err := q.Query.QueryRow().Scan(&count) if err != nil { @@ -104,8 +104,8 @@ func (q {{$varNameSingular}}Query) ExistsP() bool { func (q {{$varNameSingular}}Query) Exists() (bool, error) { var count int64 - boil.SetCount(q.Query) - boil.SetLimit(q.Query, 1) + queries.SetCount(q.Query) + queries.SetLimit(q.Query, 1) err := q.Query.QueryRow().Scan(&count) if err != nil { diff --git a/templates/04_relationship_to_one.tpl b/templates/04_relationship_to_one.tpl index 02e594cf0..b95662651 100644 --- a/templates/04_relationship_to_one.tpl +++ b/templates/04_relationship_to_one.tpl @@ -16,7 +16,7 @@ func ({{.Function.Receiver}} *{{.LocalTable.NameGo}}) {{.Function.Name}}(exec bo queryMods = append(queryMods, mods...) query := {{.ForeignTable.NamePluralGo}}(exec, queryMods...) - boil.SetFrom(query.Query, "{{.ForeignTable.Name | $dot.SchemaTable}}") + queries.SetFrom(query.Query, "{{.ForeignTable.Name | $dot.SchemaTable}}") return query } diff --git a/templates/05_relationship_to_many.tpl b/templates/05_relationship_to_many.tpl index 92b998faf..0cf62d8f6 100644 --- a/templates/05_relationship_to_many.tpl +++ b/templates/05_relationship_to_many.tpl @@ -42,7 +42,7 @@ func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) {{$rel.Function.Na {{end}} query := {{$rel.ForeignTable.NamePluralGo}}(exec, queryMods...) - boil.SetFrom(query.Query, "{{$schemaForeignTable}} as {{id 0 | $dot.Quotes}}") + queries.SetFrom(query.Query, "{{$schemaForeignTable}} as {{id 0 | $dot.Quotes}}") return query } diff --git a/templates/06_relationship_to_one_eager.tpl b/templates/06_relationship_to_one_eager.tpl index 279792fa7..0911029a6 100644 --- a/templates/06_relationship_to_one_eager.tpl +++ b/templates/06_relationship_to_one_eager.tpl @@ -43,7 +43,7 @@ func ({{$varNameSingular}}L) Load{{.Function.Name}}(e boil.Executor, singular bo defer results.Close() var resultSlice []*{{.ForeignTable.NameGo}} - if err = boil.Bind(results, &resultSlice); err != nil { + if err = queries.Bind(results, &resultSlice); err != nil { return errors.Wrap(err, "failed to bind eager loaded slice {{.ForeignTable.NameGo}}") } diff --git a/templates/07_relationship_to_many_eager.tpl b/templates/07_relationship_to_many_eager.tpl index 82b7a99ba..aae8c4e3c 100644 --- a/templates/07_relationship_to_many_eager.tpl +++ b/templates/07_relationship_to_many_eager.tpl @@ -83,7 +83,7 @@ func ({{$varNameSingular}}L) Load{{$txt.Function.Name}}(e boil.Executor, singula return errors.Wrap(err, "failed to plebian-bind eager loaded slice {{.ForeignTable}}") } {{else -}} - if err = boil.Bind(results, &resultSlice); err != nil { + if err = queries.Bind(results, &resultSlice); err != nil { return errors.Wrap(err, "failed to bind eager loaded slice {{.ForeignTable}}") } {{end}} diff --git a/templates/11_find.tpl b/templates/11_find.tpl index 83eae619a..4c0e53c05 100644 --- a/templates/11_find.tpl +++ b/templates/11_find.tpl @@ -32,7 +32,7 @@ func Find{{$tableNameSingular}}(exec boil.Executor, {{$pkArgs}}, selectCols ...s "select %s from {{.Table.Name | .SchemaTable}} where {{if .Dialect.IndexPlaceholders}}{{whereClause .LQ .RQ 1 .Table.PKey.Columns}}{{else}}{{whereClause .LQ .RQ 0 .Table.PKey.Columns}}{{end}}", sel, ) - q := boil.SQL(exec, query, {{$pkNames | join ", "}}) + q := queries.Raw(exec, query, {{$pkNames | join ", "}}) err := q.Bind({{$varNameSingular}}Obj) if err != nil { diff --git a/templates/12_insert.tpl b/templates/12_insert.tpl index 38c4dbfd7..f18d2d57c 100644 --- a/templates/12_insert.tpl +++ b/templates/12_insert.tpl @@ -41,7 +41,7 @@ func (o *{{$tableNameSingular}}) Insert(exec boil.Executor, whitelist ... string } {{- end}} - nzDefaults := boil.NonZeroDefaultSet({{$varNameSingular}}ColumnsWithDefault, o) + nzDefaults := queries.NonZeroDefaultSet({{$varNameSingular}}ColumnsWithDefault, o) key := makeCacheKey(whitelist, nzDefaults) {{$varNameSingular}}InsertCacheMut.RLock() @@ -57,11 +57,11 @@ func (o *{{$tableNameSingular}}) Insert(exec boil.Executor, whitelist ... string whitelist, ) - cache.valueMapping, err = boil.BindMapping({{$varNameSingular}}Type, {{$varNameSingular}}Mapping, wl) + cache.valueMapping, err = queries.BindMapping({{$varNameSingular}}Type, {{$varNameSingular}}Mapping, wl) if err != nil { return err } - cache.retMapping, err = boil.BindMapping({{$varNameSingular}}Type, {{$varNameSingular}}Mapping, returnColumns) + cache.retMapping, err = queries.BindMapping({{$varNameSingular}}Type, {{$varNameSingular}}Mapping, returnColumns) if err != nil { return err } @@ -77,7 +77,7 @@ func (o *{{$tableNameSingular}}) Insert(exec boil.Executor, whitelist ... string } value := reflect.Indirect(reflect.ValueOf(o)) - vals := boil.ValuesFromMapping(value, cache.valueMapping) + vals := queries.ValuesFromMapping(value, cache.valueMapping) {{if .UseLastInsertID}} if boil.DebugMode { fmt.Fprintln(boil.DebugWriter, cache.query) @@ -122,14 +122,14 @@ func (o *{{$tableNameSingular}}) Insert(exec boil.Executor, whitelist ... string fmt.Fprintln(boil.DebugWriter, identifierCols...) } - err = exec.QueryRow(cache.retQuery, identifierCols...).Scan(boil.PtrsFromMapping(value, cache.retMapping)...) + err = exec.QueryRow(cache.retQuery, identifierCols...).Scan(queries.PtrsFromMapping(value, cache.retMapping)...) if err != nil { return errors.Wrap(err, "{{.PkgName}}: unable to populate default values for {{.Table.Name}}") } } {{else}} if len(cache.retMapping) != 0 { - err = exec.QueryRow(cache.query, vals...).Scan(boil.PtrsFromMapping(value, cache.retMapping)...) + err = exec.QueryRow(cache.query, vals...).Scan(queries.PtrsFromMapping(value, cache.retMapping)...) } else { _, err = exec.Exec(cache.query, vals...) } diff --git a/templates/13_update.tpl b/templates/13_update.tpl index 5b2674dd8..7581b4566 100644 --- a/templates/13_update.tpl +++ b/templates/13_update.tpl @@ -57,7 +57,7 @@ func (o *{{$tableNameSingular}}) Update(exec boil.Executor, whitelist ... string strmangle.SetParamNames("{{.LQ}}", "{{.RQ}}", {{if .Dialect.IndexPlaceholders}}1{{else}}0{{end}}, wl), strmangle.WhereClause("{{.LQ}}", "{{.RQ}}", {{if .Dialect.IndexPlaceholders}}len(wl)+1{{else}}0{{end}}, {{$varNameSingular}}PrimaryKeyColumns), ) - cache.valueMapping, err = boil.BindMapping({{$varNameSingular}}Type, {{$varNameSingular}}Mapping, append(wl, {{$varNameSingular}}PrimaryKeyColumns...)) + cache.valueMapping, err = queries.BindMapping({{$varNameSingular}}Type, {{$varNameSingular}}Mapping, append(wl, {{$varNameSingular}}PrimaryKeyColumns...)) if err != nil { return err } @@ -67,7 +67,7 @@ func (o *{{$tableNameSingular}}) Update(exec boil.Executor, whitelist ... string return errors.New("{{.PkgName}}: unable to update {{.Table.Name}}, could not build whitelist") } - values := boil.ValuesFromMapping(reflect.Indirect(reflect.ValueOf(o)), cache.valueMapping) + values := queries.ValuesFromMapping(reflect.Indirect(reflect.ValueOf(o)), cache.valueMapping) if boil.DebugMode { fmt.Fprintln(boil.DebugWriter, cache.query) @@ -105,7 +105,7 @@ func (q {{$varNameSingular}}Query) UpdateAllP(cols M) { // UpdateAll updates all rows with the specified column values. func (q {{$varNameSingular}}Query) UpdateAll(cols M) error { - boil.SetUpdate(q.Query, cols) + queries.SetUpdate(q.Query, cols) _, err := q.Query.Exec() if err != nil { diff --git a/templates/14_upsert.tpl b/templates/14_upsert.tpl index b9900a59e..e3fc8163d 100644 --- a/templates/14_upsert.tpl +++ b/templates/14_upsert.tpl @@ -41,7 +41,7 @@ func (o *{{$tableNameSingular}}) Upsert(exec boil.Executor, {{if ne .DriverName {{$varNameSingular}}Columns, {{$varNameSingular}}ColumnsWithDefault, {{$varNameSingular}}ColumnsWithoutDefault, - boil.NonZeroDefaultSet({{$varNameSingular}}ColumnsWithDefault, o), + queries.NonZeroDefaultSet({{$varNameSingular}}ColumnsWithDefault, o), whitelist, ) update := strmangle.UpdateColumnSet( @@ -56,18 +56,18 @@ func (o *{{$tableNameSingular}}) Upsert(exec boil.Executor, {{if ne .DriverName conflict = make([]string, len({{$varNameSingular}}PrimaryKeyColumns)) copy(conflict, {{$varNameSingular}}PrimaryKeyColumns) } - query := boil.BuildUpsertQueryPostgres(dialect, "{{$schemaTable}}", updateOnConflict, ret, update, conflict, whitelist) + query := queries.BuildUpsertQueryPostgres(dialect, "{{$schemaTable}}", updateOnConflict, ret, update, conflict, whitelist) {{- else -}} - query := boil.BuildUpsertQueryMySQL(dialect, "{{.Table.Name}}", update, whitelist) + query := queries.BuildUpsertQueryMySQL(dialect, "{{.Table.Name}}", update, whitelist) {{- end}} if boil.DebugMode { fmt.Fprintln(boil.DebugWriter, query) - fmt.Fprintln(boil.DebugWriter, boil.GetStructValues(o, whitelist...)) + fmt.Fprintln(boil.DebugWriter, queries.GetStructValues(o, whitelist...)) } {{- if .UseLastInsertID}} - result, err := exec.Exec(query, boil.GetStructValues(o, whitelist...)...) + result, err := exec.Exec(query, queries.GetStructValues(o, whitelist...)...) if err != nil { return errors.Wrap(err, "{{.PkgName}}: unable to upsert for {{.Table.Name}}") } @@ -110,16 +110,16 @@ func (o *{{$tableNameSingular}}) Upsert(exec boil.Executor, {{if ne .DriverName fmt.Fprintln(boil.DebugWriter, identifierCols...) } - err = exec.QueryRow(retQuery, identifierCols...).Scan(boil.GetStructPointers(o, ret...)...) + err = exec.QueryRow(retQuery, identifierCols...).Scan(queries.GetStructPointers(o, ret...)...) if err != nil { return errors.Wrap(err, "{{.PkgName}}: unable to populate default values for {{.Table.Name}}") } } {{- else}} if len(ret) != 0 { - err = exec.QueryRow(query, boil.GetStructValues(o, whitelist...)...).Scan(boil.GetStructPointers(o, ret...)...) + err = exec.QueryRow(query, queries.GetStructValues(o, whitelist...)...).Scan(queries.GetStructPointers(o, ret...)...) } else { - _, err = exec.Exec(query, boil.GetStructValues(o, whitelist...)...) + _, err = exec.Exec(query, queries.GetStructValues(o, whitelist...)...) } if err != nil { return errors.Wrap(err, "{{.PkgName}}: unable to upsert for {{.Table.Name}}") diff --git a/templates/15_delete.tpl b/templates/15_delete.tpl index 7c80bc17d..839a695c3 100644 --- a/templates/15_delete.tpl +++ b/templates/15_delete.tpl @@ -78,7 +78,7 @@ func (q {{$varNameSingular}}Query) DeleteAll() error { return errors.New("{{.PkgName}}: no {{$varNameSingular}}Query provided for delete all") } - boil.SetDelete(q.Query) + queries.SetDelete(q.Query) _, err := q.Query.Exec() if err != nil { diff --git a/templates/16_reload.tpl b/templates/16_reload.tpl index 8a3bc2cd8..eb91ba20e 100644 --- a/templates/16_reload.tpl +++ b/templates/16_reload.tpl @@ -73,7 +73,7 @@ func (o *{{$tableNameSingular}}Slice) ReloadAll(exec boil.Executor) error { strmangle.Placeholders(dialect.IndexPlaceholders, len(*o) * len({{$varNameSingular}}PrimaryKeyColumns), 1, len({{$varNameSingular}}PrimaryKeyColumns)), ) - q := boil.SQL(exec, sql, args...) + q := queries.Raw(exec, sql, args...) err := q.Bind(&{{$varNamePlural}}) if err != nil { diff --git a/templates/singleton/boil_queries.tpl b/templates/singleton/boil_queries.tpl index 8326f24d2..d0879cb96 100644 --- a/templates/singleton/boil_queries.tpl +++ b/templates/singleton/boil_queries.tpl @@ -1,19 +1,19 @@ -var dialect = boil.Dialect{ +var dialect = queries.Dialect{ LQ: 0x{{printf "%x" .Dialect.LQ}}, RQ: 0x{{printf "%x" .Dialect.RQ}}, IndexPlaceholders: {{.Dialect.IndexPlaceholders}}, } // NewQueryG initializes a new Query using the passed in QueryMods -func NewQueryG(mods ...qm.QueryMod) *boil.Query { +func NewQueryG(mods ...qm.QueryMod) *queries.Query { return NewQuery(boil.GetDB(), mods...) } // NewQuery initializes a new Query using the passed in QueryMods -func NewQuery(exec boil.Executor, mods ...qm.QueryMod) *boil.Query { - q := &boil.Query{} - boil.SetExecutor(q, exec) - boil.SetDialect(q, &dialect) +func NewQuery(exec boil.Executor, mods ...qm.QueryMod) *queries.Query { + q := &queries.Query{} + queries.SetExecutor(q, exec) + queries.SetDialect(q, &dialect) qm.Apply(q, mods...) return q From 01f08efe8ae3a11270101bf6e8d4c1230f3b9884 Mon Sep 17 00:00:00 2001 From: Patrick O'brien Date: Thu, 15 Sep 2016 14:20:35 +1000 Subject: [PATCH 55/64] Fix lint errors for generated package --- templates/01_types.tpl | 3 +++ templates/02_hooks.tpl | 1 + templates/11_find.tpl | 8 ++++---- templates/15_delete.tpl | 2 +- templates/16_reload.tpl | 8 ++++++++ 5 files changed, 17 insertions(+), 5 deletions(-) diff --git a/templates/01_types.tpl b/templates/01_types.tpl index cf4e0c7d5..11d3f0860 100644 --- a/templates/01_types.tpl +++ b/templates/01_types.tpl @@ -10,8 +10,11 @@ var ( ) type ( + // {{$tableNameSingular}}Slice is an alias for a slice of pointers to {{$tableNameSingular}}. + // This should generally be used opposed to []{{$tableNameSingular}}. {{$tableNameSingular}}Slice []*{{$tableNameSingular}} {{if eq .NoHooks false -}} + // {{$tableNameSingular}}Hook is the signature for custom {{$tableNameSingular}} hook methods {{$tableNameSingular}}Hook func(boil.Executor, *{{$tableNameSingular}}) error {{- end}} diff --git a/templates/02_hooks.tpl b/templates/02_hooks.tpl index 9fa123653..e87283f4d 100644 --- a/templates/02_hooks.tpl +++ b/templates/02_hooks.tpl @@ -111,6 +111,7 @@ func (o *{{$tableNameSingular}}) doAfterUpsertHooks(exec boil.Executor) (err err return nil } +// Add{{$tableNameSingular}}Hook registers your hook function for all future operations. func Add{{$tableNameSingular}}Hook(hookPoint boil.HookPoint, {{$varNameSingular}}Hook {{$tableNameSingular}}Hook) { switch hookPoint { case boil.BeforeInsertHook: diff --git a/templates/11_find.tpl b/templates/11_find.tpl index 4c0e53c05..5fb5ffba9 100644 --- a/templates/11_find.tpl +++ b/templates/11_find.tpl @@ -4,12 +4,12 @@ {{- $colDefs := sqlColDefinitions .Table.Columns .Table.PKey.Columns -}} {{- $pkNames := $colDefs.Names | stringMap .StringFuncs.camelCase -}} {{- $pkArgs := joinSlices " " $pkNames $colDefs.Types | join ", " -}} -// {{$tableNameSingular}}FindG retrieves a single record by ID. +// Find{{$tableNameSingular}}G retrieves a single record by ID. func Find{{$tableNameSingular}}G({{$pkArgs}}, selectCols ...string) (*{{$tableNameSingular}}, error) { return Find{{$tableNameSingular}}(boil.GetDB(), {{$pkNames | join ", "}}, selectCols...) } -// {{$tableNameSingular}}FindGP retrieves a single record by ID, and panics on error. +// Find{{$tableNameSingular}}GP retrieves a single record by ID, and panics on error. func Find{{$tableNameSingular}}GP({{$pkArgs}}, selectCols ...string) *{{$tableNameSingular}} { retobj, err := Find{{$tableNameSingular}}(boil.GetDB(), {{$pkNames | join ", "}}, selectCols...) if err != nil { @@ -19,7 +19,7 @@ func Find{{$tableNameSingular}}GP({{$pkArgs}}, selectCols ...string) *{{$tableNa return retobj } -// {{$tableNameSingular}}Find retrieves a single record by ID with an executor. +// Find{{$tableNameSingular}} retrieves a single record by ID with an executor. // If selectCols is empty Find will return all columns. func Find{{$tableNameSingular}}(exec boil.Executor, {{$pkArgs}}, selectCols ...string) (*{{$tableNameSingular}}, error) { {{$varNameSingular}}Obj := &{{$tableNameSingular}}{} @@ -45,7 +45,7 @@ func Find{{$tableNameSingular}}(exec boil.Executor, {{$pkArgs}}, selectCols ...s return {{$varNameSingular}}Obj, nil } -// {{$tableNameSingular}}FindP retrieves a single record by ID with an executor, and panics on error. +// Find{{$tableNameSingular}}P retrieves a single record by ID with an executor, and panics on error. func Find{{$tableNameSingular}}P(exec boil.Executor, {{$pkArgs}}, selectCols ...string) *{{$tableNameSingular}} { retobj, err := Find{{$tableNameSingular}}(exec, {{$pkNames | join ", "}}, selectCols...) if err != nil { diff --git a/templates/15_delete.tpl b/templates/15_delete.tpl index 839a695c3..18205bdd1 100644 --- a/templates/15_delete.tpl +++ b/templates/15_delete.tpl @@ -88,7 +88,7 @@ func (q {{$varNameSingular}}Query) DeleteAll() error { return nil } -// DeleteAll deletes all rows in the slice, and panics on error. +// DeleteAllGP deletes all rows in the slice, and panics on error. func (o {{$tableNameSingular}}Slice) DeleteAllGP() { if err := o.DeleteAllG(); err != nil { panic(boil.WrapErr(err)) diff --git a/templates/16_reload.tpl b/templates/16_reload.tpl index eb91ba20e..60a1f369c 100644 --- a/templates/16_reload.tpl +++ b/templates/16_reload.tpl @@ -37,18 +37,26 @@ func (o *{{$tableNameSingular}}) Reload(exec boil.Executor) error { return nil } +// ReloadAllGP refetches every row with matching primary key column values +// and overwrites the original object slice with the newly updated slice. +// Panics on error. func (o *{{$tableNameSingular}}Slice) ReloadAllGP() { if err := o.ReloadAllG(); err != nil { panic(boil.WrapErr(err)) } } +// ReloadAllP refetches every row with matching primary key column values +// and overwrites the original object slice with the newly updated slice. +// Panics on error. func (o *{{$tableNameSingular}}Slice) ReloadAllP(exec boil.Executor) { if err := o.ReloadAll(exec); err != nil { panic(boil.WrapErr(err)) } } +// ReloadAllG refetches every row with matching primary key column values +// and overwrites the original object slice with the newly updated slice. func (o *{{$tableNameSingular}}Slice) ReloadAllG() error { if o == nil { return errors.New("{{.PkgName}}: empty {{$tableNameSingular}}Slice provided for reload all") From e52fac9c5e9f0f00f7294cfbeb7ee51ec3bc3aa8 Mon Sep 17 00:00:00 2001 From: Patrick O'brien Date: Thu, 15 Sep 2016 14:27:06 +1000 Subject: [PATCH 56/64] Fix indentation --- templates/01_types.tpl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/templates/01_types.tpl b/templates/01_types.tpl index 11d3f0860..36cfd5f9f 100644 --- a/templates/01_types.tpl +++ b/templates/01_types.tpl @@ -3,10 +3,10 @@ {{- $varNameSingular := .Table.Name | singular | camelCase -}} {{- $tableNameSingular := .Table.Name | singular | titleCase -}} var ( - {{$varNameSingular}}Columns = []string{{"{"}}{{.Table.Columns | columnNames | stringMap .StringFuncs.quoteWrap | join ", "}}{{"}"}} - {{$varNameSingular}}ColumnsWithoutDefault = []string{{"{"}}{{.Table.Columns | filterColumnsByDefault false | columnNames | stringMap .StringFuncs.quoteWrap | join ","}}{{"}"}} - {{$varNameSingular}}ColumnsWithDefault = []string{{"{"}}{{.Table.Columns | filterColumnsByDefault true | columnNames | stringMap .StringFuncs.quoteWrap | join ","}}{{"}"}} - {{$varNameSingular}}PrimaryKeyColumns = []string{{"{"}}{{.Table.PKey.Columns | stringMap .StringFuncs.quoteWrap | join ", "}}{{"}"}} + {{$varNameSingular}}Columns = []string{{"{"}}{{.Table.Columns | columnNames | stringMap .StringFuncs.quoteWrap | join ", "}}{{"}"}} + {{$varNameSingular}}ColumnsWithoutDefault = []string{{"{"}}{{.Table.Columns | filterColumnsByDefault false | columnNames | stringMap .StringFuncs.quoteWrap | join ","}}{{"}"}} + {{$varNameSingular}}ColumnsWithDefault = []string{{"{"}}{{.Table.Columns | filterColumnsByDefault true | columnNames | stringMap .StringFuncs.quoteWrap | join ","}}{{"}"}} + {{$varNameSingular}}PrimaryKeyColumns = []string{{"{"}}{{.Table.PKey.Columns | stringMap .StringFuncs.quoteWrap | join ", "}}{{"}"}} ) type ( From b3230c27576ab82f65e8c6b1678116d4fe9a5846 Mon Sep 17 00:00:00 2001 From: Patrick O'brien Date: Thu, 15 Sep 2016 14:46:03 +1000 Subject: [PATCH 57/64] Fix insert with goto --- templates/12_insert.tpl | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/templates/12_insert.tpl b/templates/12_insert.tpl index f18d2d57c..c22c18245 100644 --- a/templates/12_insert.tpl +++ b/templates/12_insert.tpl @@ -88,21 +88,19 @@ func (o *{{$tableNameSingular}}) Insert(exec boil.Executor, whitelist ... string if err != nil { return errors.Wrap(err, "{{.PkgName}}: unable to insert into {{.Table.Name}}") } + + var lastID int64 + var identifierCols []interface{} if len(cache.retMapping) == 0 { - {{if not .NoHooks -}} - return o.doAfterInsertHooks(exec) - {{else -}} - return nil - {{end -}} + goto CacheNoHooks } - lastID, err := result.LastInsertId() + lastID, err = result.LastInsertId() if err != nil { return ErrSyncFail } - var identifierCols []interface{} if lastID != 0 { {{- $colName := index .Table.PKey.Columns 0 -}} {{- $col := .Table.GetColumn $colName -}} @@ -143,7 +141,9 @@ func (o *{{$tableNameSingular}}) Insert(exec boil.Executor, whitelist ... string return errors.Wrap(err, "{{.PkgName}}: unable to insert into {{.Table.Name}}") } {{end}} - +{{if .UseLastInsertID -}} +CacheNoHooks: +{{- end}} if !cached { {{$varNameSingular}}InsertCacheMut.Lock() {{$varNameSingular}}InsertCache[key] = cache From 0abfe1cba618084d038fb8f4e93c7a173041760e Mon Sep 17 00:00:00 2001 From: Patrick O'brien Date: Thu, 15 Sep 2016 14:59:48 +1000 Subject: [PATCH 58/64] Fix upsert if else --- templates/14_upsert.tpl | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/templates/14_upsert.tpl b/templates/14_upsert.tpl index e3fc8163d..ebce7c52c 100644 --- a/templates/14_upsert.tpl +++ b/templates/14_upsert.tpl @@ -127,10 +127,8 @@ func (o *{{$tableNameSingular}}) Upsert(exec boil.Executor, {{if ne .DriverName {{- end}} {{if not .NoHooks -}} - if err := o.doAfterUpsertHooks(exec); err != nil { - return err - } - {{- end}} - + return o.doAfterUpsertHooks(exec) + {{- else -}} return nil + {{- end}} } From 78de983d7d2d58152e4371843ea6edc65ec1b7e8 Mon Sep 17 00:00:00 2001 From: Patrick O'brien Date: Thu, 15 Sep 2016 15:13:09 +1000 Subject: [PATCH 59/64] Fix calls to Raw and RawG --- boil/queries/eager_load.go | 147 ------------------------------------- queries/query_test.go | 4 +- 2 files changed, 2 insertions(+), 149 deletions(-) delete mode 100644 boil/queries/eager_load.go diff --git a/boil/queries/eager_load.go b/boil/queries/eager_load.go deleted file mode 100644 index e6eafd80a..000000000 --- a/boil/queries/eager_load.go +++ /dev/null @@ -1,147 +0,0 @@ -package boil - -import ( - "database/sql" - "reflect" - - "github.com/pkg/errors" - "github.com/vattle/sqlboiler/strmangle" -) - -type loadRelationshipState struct { - exec Executor - loaded map[string]struct{} - toLoad []string -} - -func (l loadRelationshipState) hasLoaded(depth int) bool { - _, ok := l.loaded[l.buildKey(depth)] - return ok -} - -func (l loadRelationshipState) setLoaded(depth int) { - l.loaded[l.buildKey(depth)] = struct{}{} -} - -func (l loadRelationshipState) buildKey(depth int) string { - buf := strmangle.GetBuffer() - - for i, piece := range l.toLoad[:depth+1] { - if i != 0 { - buf.WriteByte('.') - } - buf.WriteString(piece) - } - - str := buf.String() - strmangle.PutBuffer(buf) - return str -} - -// loadRelationships dynamically calls the template generated eager load -// functions of the form: -// -// func (t *TableR) LoadRelationshipName(exec Executor, singular bool, obj interface{}) -// -// The arguments to this function are: -// - t is not considered here, and is always passed nil. The function exists on a loaded -// struct to avoid a circular dependency with boil, and the receiver is ignored. -// - exec is used to perform additional queries that might be required for loading the relationships. -// - singular is passed in to identify whether or not this was a single object -// or a slice that must be loaded into. -// - obj is the object or slice of objects, always of the type *obj or *[]*obj as per bind. -// -// It takes list of nested relationships to load. -func (l loadRelationshipState) loadRelationships(depth int, obj interface{}, bkind bindKind) error { - typ := reflect.TypeOf(obj).Elem() - if bkind == kindPtrSliceStruct { - typ = typ.Elem().Elem() - } - - if !l.hasLoaded(depth) { - current := l.toLoad[depth] - ln, found := typ.FieldByName(loaderStructName) - // It's possible a Loaders struct doesn't exist on the struct. - if !found { - return errors.Errorf("attempted to load %s but no L struct was found", current) - } - - // Attempt to find the LoadRelationshipName function - loadMethod, found := ln.Type.MethodByName(loadMethodPrefix + current) - if !found { - return errors.Errorf("could not find %s%s method for eager loading", loadMethodPrefix, current) - } - - // Hack to allow nil executors - execArg := reflect.ValueOf(l.exec) - if !execArg.IsValid() { - execArg = reflect.ValueOf((*sql.DB)(nil)) - } - - val := reflect.ValueOf(obj).Elem() - if bkind == kindPtrSliceStruct { - val = val.Index(0).Elem() - } - - methodArgs := []reflect.Value{ - val.FieldByName(loaderStructName), - execArg, - reflect.ValueOf(bkind == kindStruct), - reflect.ValueOf(obj), - } - resp := loadMethod.Func.Call(methodArgs) - if intf := resp[0].Interface(); intf != nil { - return errors.Wrapf(intf.(error), "failed to eager load %s", current) - } - - l.setLoaded(depth) - } - - // Pull one off the queue, continue if there's still some to go - depth++ - if depth >= len(l.toLoad) { - return nil - } - - loadedObject := reflect.ValueOf(obj) - // If we eagerly loaded nothing - if loadedObject.IsNil() { - return nil - } - loadedObject = reflect.Indirect(loadedObject) - - // If it's singular we can just immediately call without looping - if bkind == kindStruct { - return l.loadRelationshipsRecurse(depth, loadedObject) - } - - // Loop over all eager loaded objects - ln := loadedObject.Len() - if ln == 0 { - return nil - } - for i := 0; i < ln; i++ { - iter := loadedObject.Index(i).Elem() - if err := l.loadRelationshipsRecurse(depth, iter); err != nil { - return err - } - } - - return nil -} - -// loadRelationshipsRecurse is a helper function for taking a reflect.Value and -// Basically calls loadRelationships with: obj.R.EagerLoadedObj, and whether it's a string or slice -func (l loadRelationshipState) loadRelationshipsRecurse(depth int, obj reflect.Value) error { - r := obj.FieldByName(relationshipStructName) - if !r.IsValid() || r.IsNil() { - return errors.Errorf("could not traverse into loaded %s relationship to load more things", l.toLoad[depth]) - } - newObj := reflect.Indirect(r).FieldByName(l.toLoad[depth]) - bkind := kindStruct - if reflect.Indirect(newObj).Kind() != reflect.Struct { - bkind = kindPtrSliceStruct - newObj = newObj.Addr() - } - return l.loadRelationships(depth, newObj.Interface(), bkind) -} diff --git a/queries/query_test.go b/queries/query_test.go index 97f8a9d7c..06bccf45e 100644 --- a/queries/query_test.go +++ b/queries/query_test.go @@ -373,7 +373,7 @@ func TestAppendSelect(t *testing.T) { func TestSQL(t *testing.T) { t.Parallel() - q := SQL(&sql.DB{}, "thing", 5) + q := Raw(&sql.DB{}, "thing", 5) if q.rawSQL.sql != "thing" { t.Errorf("Expected %q, got %s", "thing", q.rawSQL.sql) } @@ -385,7 +385,7 @@ func TestSQL(t *testing.T) { func TestSQLG(t *testing.T) { t.Parallel() - q := SQLG("thing", 5) + q := RawG("thing", 5) if q.rawSQL.sql != "thing" { t.Errorf("Expected %q, got %s", "thing", q.rawSQL.sql) } From 09eeef63af65e5bd26671c2707b65e5fba7e8db6 Mon Sep 17 00:00:00 2001 From: Aaron L Date: Wed, 14 Sep 2016 23:03:05 -0700 Subject: [PATCH 60/64] Optimize upsert. --- templates/01_types.tpl | 4 +- templates/14_upsert.tpl | 130 ++++++++++++++++++++--------- templates/singleton/boil_types.tpl | 13 ++- 3 files changed, 101 insertions(+), 46 deletions(-) diff --git a/templates/01_types.tpl b/templates/01_types.tpl index 36cfd5f9f..b3459d8bd 100644 --- a/templates/01_types.tpl +++ b/templates/01_types.tpl @@ -23,7 +23,7 @@ type ( } ) -// Cache for insert and update +// Cache for insert, update and upsert var ( {{$varNameSingular}}Type = reflect.TypeOf(&{{$tableNameSingular}}{}) {{$varNameSingular}}Mapping = queries.MakeStructMapping({{$varNameSingular}}Type) @@ -31,6 +31,8 @@ var ( {{$varNameSingular}}InsertCache = make(map[string]insertCache) {{$varNameSingular}}UpdateCacheMut sync.RWMutex {{$varNameSingular}}UpdateCache = make(map[string]updateCache) + {{$varNameSingular}}UpsertCacheMut sync.RWMutex + {{$varNameSingular}}UpsertCache = make(map[string]insertCache) ) // Force time package dependency for automated UpdatedAt/CreatedAt. diff --git a/templates/14_upsert.tpl b/templates/14_upsert.tpl index ebce7c52c..6505f9d5b 100644 --- a/templates/14_upsert.tpl +++ b/templates/14_upsert.tpl @@ -20,7 +20,7 @@ func (o *{{$tableNameSingular}}) UpsertP(exec boil.Executor, {{if ne .DriverName panic(boil.WrapErr(err)) } } - + // Upsert attempts an insert using an executor, and does an update or ignore on conflict. func (o *{{$tableNameSingular}}) Upsert(exec boil.Executor, {{if ne .DriverName "mysql"}}updateOnConflict bool, conflictColumns []string, {{end}}updateColumns []string, whitelist ...string) error { if o == nil { @@ -35,44 +35,97 @@ func (o *{{$tableNameSingular}}) Upsert(exec boil.Executor, {{if ne .DriverName } {{- end}} + // Build cache key in-line uglily - mysql vs postgres problems + buf := strmangle.GetBuffer() + {{if ne .DriverName "mysql" -}} + if updateOnConflict { + buf.WriteByte('t') + } else { + buf.WriteByte('f') + } + buf.WriteByte('.') + for _, c := range conflictColumns { + buf.WriteString(c) + } + buf.WriteByte('.') + {{end -}} + for _, c := range updateColumns { + buf.WriteString(c) + } + buf.WriteByte('.') + for _, c := range whitelist { + buf.WriteString(c) + } + key := buf.String() + strmangle.PutBuffer(buf) + + {{$varNameSingular}}UpsertCacheMut.RLock() + cache, cached := {{$varNameSingular}}UpsertCache[key] + {{$varNameSingular}}UpsertCacheMut.RUnlock() + var err error - var ret []string - whitelist, ret = strmangle.InsertColumnSet( - {{$varNameSingular}}Columns, - {{$varNameSingular}}ColumnsWithDefault, - {{$varNameSingular}}ColumnsWithoutDefault, - queries.NonZeroDefaultSet({{$varNameSingular}}ColumnsWithDefault, o), - whitelist, - ) - update := strmangle.UpdateColumnSet( - {{$varNameSingular}}Columns, - {{$varNameSingular}}PrimaryKeyColumns, - updateColumns, - ) - {{if ne .DriverName "mysql" -}} - conflict := conflictColumns - if len(conflict) == 0 { - conflict = make([]string, len({{$varNameSingular}}PrimaryKeyColumns)) - copy(conflict, {{$varNameSingular}}PrimaryKeyColumns) + if !cached { + var ret []string + whitelist, ret = strmangle.InsertColumnSet( + {{$varNameSingular}}Columns, + {{$varNameSingular}}ColumnsWithDefault, + {{$varNameSingular}}ColumnsWithoutDefault, + queries.NonZeroDefaultSet({{$varNameSingular}}ColumnsWithDefault, o), + whitelist, + ) + update := strmangle.UpdateColumnSet( + {{$varNameSingular}}Columns, + {{$varNameSingular}}PrimaryKeyColumns, + updateColumns, + ) + + {{if ne .DriverName "mysql" -}} + var conflict []string + if len(conflictColumns) == 0 { + conflict = make([]string, len({{$varNameSingular}}PrimaryKeyColumns)) + copy(conflict, {{$varNameSingular}}PrimaryKeyColumns) + } + cache.query = queries.BuildUpsertQueryPostgres(dialect, "{{$schemaTable}}", updateOnConflict, ret, update, conflict, whitelist) + {{- else -}} + cache.query = queries.BuildUpsertQueryMySQL(dialect, "{{.Table.Name}}", update, whitelist) + cache.retQuery = fmt.Sprintf( + "SELECT %s FROM {{.LQ}}{{.Table.Name}}{{.RQ}} WHERE {{whereClause .LQ .RQ 0 .Table.PKey.Columns}}", + strings.Join(strmangle.IdentQuoteSlice(dialect.LQ, dialect.RQ, ret), ","), + ) + {{- end}} + + cache.valueMapping, err = queries.BindMapping({{$varNameSingular}}Type, {{$varNameSingular}}Mapping, whitelist) + if err != nil { + return err + } + if len(ret) != 0 { + cache.retMapping, err = queries.BindMapping({{$varNameSingular}}Type, {{$varNameSingular}}Mapping, ret) + if err != nil { + return err + } + } + } + + value := reflect.Indirect(reflect.ValueOf(o)) + values := queries.ValuesFromMapping(value, cache.valueMapping) + var returns []interface{} + if len(cache.retMapping) != 0 { + returns = queries.PtrsFromMapping(value, cache.retMapping) } - query := queries.BuildUpsertQueryPostgres(dialect, "{{$schemaTable}}", updateOnConflict, ret, update, conflict, whitelist) - {{- else -}} - query := queries.BuildUpsertQueryMySQL(dialect, "{{.Table.Name}}", update, whitelist) - {{- end}} if boil.DebugMode { - fmt.Fprintln(boil.DebugWriter, query) - fmt.Fprintln(boil.DebugWriter, queries.GetStructValues(o, whitelist...)) + fmt.Fprintln(boil.DebugWriter, cache.query) + fmt.Fprintln(boil.DebugWriter, values) } {{- if .UseLastInsertID}} - result, err := exec.Exec(query, queries.GetStructValues(o, whitelist...)...) + result, err := exec.Exec(cache.query, values...) if err != nil { return errors.Wrap(err, "{{.PkgName}}: unable to upsert for {{.Table.Name}}") } - if len(ret) == 0 { + if len(cache.retMapping) == 0 { {{if not .NoHooks -}} return o.doAfterUpsertHooks(exec) {{else -}} @@ -99,33 +152,34 @@ func (o *{{$tableNameSingular}}) Upsert(exec boil.Executor, {{if ne .DriverName } } - if lastID != 0 && len(ret) == 1 { - retQuery := fmt.Sprintf( - "SELECT %s FROM {{.LQ}}{{.Table.Name}}{{.RQ}} WHERE {{whereClause .LQ .RQ 0 .Table.PKey.Columns}}", - strings.Join(strmangle.IdentQuoteSlice(dialect.LQ, dialect.RQ, ret), ","), - ) - + if lastID != 0 && len(cache.retMapping) == 1 { if boil.DebugMode { - fmt.Fprintln(boil.DebugWriter, ret) + fmt.Fprintln(boil.DebugWriter, cache.retQuery) fmt.Fprintln(boil.DebugWriter, identifierCols...) } - err = exec.QueryRow(retQuery, identifierCols...).Scan(queries.GetStructPointers(o, ret...)...) + err = exec.QueryRow(cache.retQuery, identifierCols...).Scan(returns...) if err != nil { return errors.Wrap(err, "{{.PkgName}}: unable to populate default values for {{.Table.Name}}") } } {{- else}} - if len(ret) != 0 { - err = exec.QueryRow(query, queries.GetStructValues(o, whitelist...)...).Scan(queries.GetStructPointers(o, ret...)...) + if len(cache.retMapping) != 0 { + err = exec.QueryRow(cache.query, values...).Scan(returns...) } else { - _, err = exec.Exec(query, queries.GetStructValues(o, whitelist...)...) + _, err = exec.Exec(cache.query, values...) } if err != nil { return errors.Wrap(err, "{{.PkgName}}: unable to upsert for {{.Table.Name}}") } {{- end}} + if !cached { + {{$varNameSingular}}UpsertCacheMut.Lock() + {{$varNameSingular}}UpsertCache[key] = cache + {{$varNameSingular}}UpsertCacheMut.Unlock() + } + {{if not .NoHooks -}} return o.doAfterUpsertHooks(exec) {{- else -}} diff --git a/templates/singleton/boil_types.tpl b/templates/singleton/boil_types.tpl index 5e3282826..ebb88918f 100644 --- a/templates/singleton/boil_types.tpl +++ b/templates/singleton/boil_types.tpl @@ -6,15 +6,15 @@ type M map[string]interface{} // fails or there was a primary key configuration that was not resolvable. var ErrSyncFail = errors.New("{{.PkgName}}: failed to synchronize data after insert") -type insertCache struct{ - query string - retQuery string +type insertCache struct { + query string + retQuery string valueMapping []uint64 - retMapping []uint64 + retMapping []uint64 } -type updateCache struct{ - query string +type updateCache struct { + query string valueMapping []uint64 } @@ -35,4 +35,3 @@ func makeCacheKey(wl, nzDefaults []string) string { strmangle.PutBuffer(buf) return str } - From f6323d5ebe4c99f302dc428235f2eb37a03e9783 Mon Sep 17 00:00:00 2001 From: Aaron L Date: Wed, 14 Sep 2016 23:04:42 -0700 Subject: [PATCH 61/64] Delete unused code --- queries/reflect.go | 71 ----------------------------- queries/reflect_test.go | 98 ----------------------------------------- 2 files changed, 169 deletions(-) diff --git a/queries/reflect.go b/queries/reflect.go index bee0a123e..1a1e6cbc5 100644 --- a/queries/reflect.go +++ b/queries/reflect.go @@ -407,74 +407,3 @@ func makeCacheKey(typ string, cols []string) string { return mapKey } - -// GetStructValues returns the values (as interface) of the matching columns in obj -func GetStructValues(obj interface{}, columns ...string) []interface{} { - ret := make([]interface{}, len(columns)) - val := reflect.Indirect(reflect.ValueOf(obj)) - - for i, c := range columns { - fieldName := strmangle.TitleCase(c) - field := val.FieldByName(fieldName) - if !field.IsValid() { - panic(fmt.Sprintf("unable to find field with name: %s\n%#v", fieldName, obj)) - } - ret[i] = field.Interface() - } - - return ret -} - -// GetSliceValues returns the values (as interface) of the matching columns in obj. -func GetSliceValues(slice []interface{}, columns ...string) []interface{} { - ret := make([]interface{}, len(slice)*len(columns)) - - for i, obj := range slice { - val := reflect.Indirect(reflect.ValueOf(obj)) - for j, c := range columns { - fieldName := strmangle.TitleCase(c) - field := val.FieldByName(fieldName) - if !field.IsValid() { - panic(fmt.Sprintf("unable to find field with name: %s\n%#v", fieldName, obj)) - } - ret[i*len(columns)+j] = field.Interface() - } - } - - return ret -} - -// GetStructPointers returns a slice of pointers to the matching columns in obj -func GetStructPointers(obj interface{}, columns ...string) []interface{} { - val := reflect.ValueOf(obj).Elem() - - var ln int - var getField func(reflect.Value, int) reflect.Value - - if len(columns) == 0 { - ln = val.NumField() - getField = func(v reflect.Value, i int) reflect.Value { - return v.Field(i) - } - } else { - ln = len(columns) - getField = func(v reflect.Value, i int) reflect.Value { - return v.FieldByName(strmangle.TitleCase(columns[i])) - } - } - - ret := make([]interface{}, ln) - for i := 0; i < ln; i++ { - field := getField(val, i) - - if !field.IsValid() { - // Although this breaks the abstraction of getField above - we know that v.Field(i) can't actually - // produce an Invalid value, so we make a hopefully safe assumption here. - panic(fmt.Sprintf("Could not find field on struct %T for field %s", obj, strmangle.TitleCase(columns[i]))) - } - - ret[i] = field.Addr().Interface() - } - - return ret -} diff --git a/queries/reflect_test.go b/queries/reflect_test.go index 7641557f9..a1484dc43 100644 --- a/queries/reflect_test.go +++ b/queries/reflect_test.go @@ -6,10 +6,8 @@ import ( "strconv" "strings" "testing" - "time" "gopkg.in/DATA-DOG/go-sqlmock.v1" - "gopkg.in/nullbio/null.v5" ) func bin64(i uint64) string { @@ -609,99 +607,3 @@ func TestBind_InnerJoin(t *testing.T) { // t.Error("id is the wrong pointer") // } // } - -func TestGetStructValues(t *testing.T) { - t.Parallel() - - timeThing := time.Now() - o := struct { - TitleThing string - Name string - ID int - Stuff int - Things int - Time time.Time - NullBool null.Bool - }{ - TitleThing: "patrick", - Stuff: 10, - Things: 0, - Time: timeThing, - NullBool: null.NewBool(true, false), - } - - vals := GetStructValues(&o, "title_thing", "name", "id", "stuff", "things", "time", "null_bool") - if vals[0].(string) != "patrick" { - t.Errorf("Want test, got %s", vals[0]) - } - if vals[1].(string) != "" { - t.Errorf("Want empty string, got %s", vals[1]) - } - if vals[2].(int) != 0 { - t.Errorf("Want 0, got %d", vals[2]) - } - if vals[3].(int) != 10 { - t.Errorf("Want 10, got %d", vals[3]) - } - if vals[4].(int) != 0 { - t.Errorf("Want 0, got %d", vals[4]) - } - if !vals[5].(time.Time).Equal(timeThing) { - t.Errorf("Want %s, got %s", o.Time, vals[5]) - } - if !vals[6].(null.Bool).IsZero() { - t.Errorf("Want %v, got %v", o.NullBool, vals[6]) - } -} - -func TestGetSliceValues(t *testing.T) { - t.Parallel() - - o := []struct { - ID int - Name string - }{ - {5, "a"}, - {6, "b"}, - } - - in := make([]interface{}, len(o)) - in[0] = o[0] - in[1] = o[1] - - vals := GetSliceValues(in, "id", "name") - if got := vals[0].(int); got != 5 { - t.Error(got) - } - if got := vals[1].(string); got != "a" { - t.Error(got) - } - if got := vals[2].(int); got != 6 { - t.Error(got) - } - if got := vals[3].(string); got != "b" { - t.Error(got) - } -} - -func TestGetStructPointers(t *testing.T) { - t.Parallel() - - o := struct { - Title string - ID *int - }{ - Title: "patrick", - } - - ptrs := GetStructPointers(&o, "title", "id") - *ptrs[0].(*string) = "test" - if o.Title != "test" { - t.Errorf("Expected test, got %s", o.Title) - } - x := 5 - *ptrs[1].(**int) = &x - if *o.ID != 5 { - t.Errorf("Expected 5, got %d", *o.ID) - } -} From de7ba2fa8efb0cc3c13a7a6561dd8884ecf127f6 Mon Sep 17 00:00:00 2001 From: Aaron L Date: Wed, 14 Sep 2016 23:20:42 -0700 Subject: [PATCH 62/64] Clean up the reflect tests. --- queries/reflect_test.go | 276 ++++++++++++++++++---------------------- 1 file changed, 126 insertions(+), 150 deletions(-) diff --git a/queries/reflect_test.go b/queries/reflect_test.go index a1484dc43..8d83ed35b 100644 --- a/queries/reflect_test.go +++ b/queries/reflect_test.go @@ -266,6 +266,76 @@ func TestPtrFromMapping(t *testing.T) { } } +func TestValuesFromMapping(t *testing.T) { + t.Parallel() + + type NestedPtrs struct { + Int int + IntP *int + NestedPtrsP *NestedPtrs + } + + val := &NestedPtrs{ + Int: 5, + IntP: new(int), + NestedPtrsP: &NestedPtrs{ + Int: 6, + IntP: new(int), + }, + } + + mapping := []uint64{testMakeMapping(0), testMakeMapping(1), testMakeMapping(2, 0), testMakeMapping(2, 1)} + v := ValuesFromMapping(reflect.Indirect(reflect.ValueOf(val)), mapping) + + if got := v[0].(int); got != 5 { + t.Error("flat int was wrong:", got) + } + if got := v[1].(int); got != 0 { + t.Error("flat pointer was wrong:", got) + } + if got := v[2].(int); got != 6 { + t.Error("nested int was wrong:", got) + } + if got := v[3].(int); got != 0 { + t.Error("nested pointer was wrong:", got) + } +} + +func TestPtrsFromMapping(t *testing.T) { + t.Parallel() + + type NestedPtrs struct { + Int int + IntP *int + NestedPtrsP *NestedPtrs + } + + val := &NestedPtrs{ + Int: 5, + IntP: new(int), + NestedPtrsP: &NestedPtrs{ + Int: 6, + IntP: new(int), + }, + } + + mapping := []uint64{testMakeMapping(0), testMakeMapping(1), testMakeMapping(2, 0), testMakeMapping(2, 1)} + v := PtrsFromMapping(reflect.Indirect(reflect.ValueOf(val)), mapping) + + if got := *v[0].(*int); got != 5 { + t.Error("flat int was wrong:", got) + } + if got := *v[1].(*int); got != 0 { + t.Error("flat pointer was wrong:", got) + } + if got := *v[2].(*int); got != 6 { + t.Error("nested int was wrong:", got) + } + if got := *v[3].(*int); got != 0 { + t.Error("nested pointer was wrong:", got) + } +} + func TestGetBoilTag(t *testing.T) { t.Parallel() @@ -457,153 +527,59 @@ func TestBind_InnerJoin(t *testing.T) { } } -// func TestBind_InnerJoinSelect(t *testing.T) { -// t.Parallel() -// -// testResults := []*struct { -// Happy struct { -// ID int -// } `boil:"h,bind"` -// Fun struct { -// ID int -// } `boil:",bind"` -// }{} -// -// query := &Query{ -// selectCols: []string{"fun.id", "h.id"}, -// from: []string{"fun"}, -// joins: []join{{kind: JoinInner, clause: "happy as h on fun.happy_id = h.id"}}, -// } -// -// db, mock, err := sqlmock.New() -// if err != nil { -// t.Error(err) -// } -// -// ret := sqlmock.NewRows([]string{"fun.id", "h.id"}) -// ret.AddRow(driver.Value(int64(10)), driver.Value(int64(11))) -// ret.AddRow(driver.Value(int64(12)), driver.Value(int64(13))) -// mock.ExpectQuery(`SELECT "fun"."id" as "fun.id", "h"."id" as "h.id" FROM "fun" INNER JOIN happy as h on fun.happy_id = h.id;`).WillReturnRows(ret) -// -// SetExecutor(query, db) -// err = query.Bind(&testResults) -// if err != nil { -// t.Error(err) -// } -// -// if len(testResults) != 2 { -// t.Fatal("wrong number of results:", len(testResults)) -// } -// if id := testResults[0].Happy.ID; id != 11 { -// t.Error("wrong ID:", id) -// } -// if id := testResults[0].Fun.ID; id != 10 { -// t.Error("wrong ID:", id) -// } -// -// if id := testResults[1].Happy.ID; id != 13 { -// t.Error("wrong ID:", id) -// } -// if id := testResults[1].Fun.ID; id != 12 { -// t.Error("wrong ID:", id) -// } -// -// if err := mock.ExpectationsWereMet(); err != nil { -// t.Error(err) -// } -// } - -// func TestBindPtrs_Easy(t *testing.T) { -// t.Parallel() -// -// testStruct := struct { -// ID int `boil:"identifier"` -// Date time.Time -// }{} -// -// cols := []string{"identifier", "date"} -// ptrs, err := bindPtrs(&testStruct, nil, cols...) -// if err != nil { -// t.Error(err) -// } -// -// if ptrs[0].(*int) != &testStruct.ID { -// t.Error("id is the wrong pointer") -// } -// if ptrs[1].(*time.Time) != &testStruct.Date { -// t.Error("id is the wrong pointer") -// } -// } -// -// func TestBindPtrs_Recursive(t *testing.T) { -// t.Parallel() -// -// testStruct := struct { -// Happy struct { -// ID int `boil:"identifier"` -// } -// Fun struct { -// ID int -// } `boil:",bind"` -// }{} -// -// cols := []string{"id", "fun.id"} -// ptrs, err := bindPtrs(&testStruct, nil, cols...) -// if err != nil { -// t.Error(err) -// } -// -// if ptrs[0].(*int) != &testStruct.Fun.ID { -// t.Error("id is the wrong pointer") -// } -// if ptrs[1].(*int) != &testStruct.Fun.ID { -// t.Error("id is the wrong pointer") -// } -// } -// -// func TestBindPtrs_RecursiveTags(t *testing.T) { -// t.Parallel() -// -// testStruct := struct { -// Happy struct { -// ID int `boil:"identifier"` -// } `boil:",bind"` -// Fun struct { -// ID int `boil:"identification"` -// } `boil:",bind"` -// }{} -// -// cols := []string{"happy.identifier", "fun.identification"} -// ptrs, err := bindPtrs(&testStruct, nil, cols...) -// if err != nil { -// t.Error(err) -// } -// -// if ptrs[0].(*int) != &testStruct.Happy.ID { -// t.Error("id is the wrong pointer") -// } -// if ptrs[1].(*int) != &testStruct.Fun.ID { -// t.Error("id is the wrong pointer") -// } -// } -// -// func TestBindPtrs_Ignore(t *testing.T) { -// t.Parallel() -// -// testStruct := struct { -// ID int `boil:"-"` -// Happy struct { -// ID int -// } `boil:",bind"` -// }{} -// -// cols := []string{"id"} -// ptrs, err := bindPtrs(&testStruct, nil, cols...) -// if err != nil { -// t.Error(err) -// } -// -// if ptrs[0].(*int) != &testStruct.Happy.ID { -// t.Error("id is the wrong pointer") -// } -// } +func TestBind_InnerJoinSelect(t *testing.T) { + t.Parallel() + + testResults := []*struct { + Happy struct { + ID int + } `boil:"h,bind"` + Fun struct { + ID int + } `boil:",bind"` + }{} + + query := &Query{ + dialect: &Dialect{LQ: '"', RQ: '"', IndexPlaceholders: true}, + selectCols: []string{"fun.id", "h.id"}, + from: []string{"fun"}, + joins: []join{{kind: JoinInner, clause: "happy as h on fun.happy_id = h.id"}}, + } + + db, mock, err := sqlmock.New() + if err != nil { + t.Error(err) + } + + ret := sqlmock.NewRows([]string{"fun.id", "h.id"}) + ret.AddRow(driver.Value(int64(10)), driver.Value(int64(11))) + ret.AddRow(driver.Value(int64(12)), driver.Value(int64(13))) + mock.ExpectQuery(`SELECT "fun"."id" as "fun.id", "h"."id" as "h.id" FROM "fun" INNER JOIN happy as h on fun.happy_id = h.id;`).WillReturnRows(ret) + + SetExecutor(query, db) + err = query.Bind(&testResults) + if err != nil { + t.Error(err) + } + + if len(testResults) != 2 { + t.Fatal("wrong number of results:", len(testResults)) + } + if id := testResults[0].Happy.ID; id != 11 { + t.Error("wrong ID:", id) + } + if id := testResults[0].Fun.ID; id != 10 { + t.Error("wrong ID:", id) + } + + if id := testResults[1].Happy.ID; id != 13 { + t.Error("wrong ID:", id) + } + if id := testResults[1].Fun.ID; id != 12 { + t.Error("wrong ID:", id) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Error(err) + } +} From c249cf49d06931913290646671494fb7a8207a06 Mon Sep 17 00:00:00 2001 From: Aaron L Date: Wed, 14 Sep 2016 23:33:18 -0700 Subject: [PATCH 63/64] Fix Boris' name in the README. - Fix table formatting - Remove section for MySQL vs Postgres config for global options - Fix path to Bind after refactor. --- README.md | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index 6d39fbb04..367b5d44c 100644 --- a/README.md +++ b/README.md @@ -225,19 +225,19 @@ values that can go in that section: You can also pass in these top level configuration values if you would prefer not to pass them through the command line or environment variables: -| Name | Postgres Default | Mysql Default -| --- | --- | --- | -| basedir | none | none | -| schema | "public" | *N/A* | -| pkgname | "models" | "models" | -| output | "models" | "models" | -| whitelist | [] | [] | -| blacklist | [] | [] | -| tag | [] | [] | -| debug | false | false | -| no-hooks | false | false | -| no-tests | false | false | -| no-auto-timestamps | false | false | +| Name | Defaults | +| ------------------ | --------- | +| basedir | none | +| schema | "public" *(or dbname for mysql)* | +| pkgname | "models" | +| output | "models" | +| whitelist | [] | +| blacklist | [] | +| tag | [] | +| debug | false | +| no-hooks | false | +| no-tests | false | +| no-auto-timestamps | false | Example: @@ -582,7 +582,7 @@ in combination with your own custom, non-generated model. ### Binding -For a comprehensive ruleset for `Bind()` you can refer to our [godoc](https://godoc.org/github.com/vattle/sqlboiler/boil#Bind). +For a comprehensive ruleset for `Bind()` you can refer to our [godoc](https://godoc.org/github.com/vattle/sqlboiler/queries#Bind). The `Bind()` [Finisher](#finisher) allows the results of a query built with the [Raw SQL](#raw-query) method or the [Query Builder](#query-building) methods to be bound @@ -893,8 +893,8 @@ err := p1.Insert(db) // Insert the first pilot with name "Larry" // p1 now has an ID field set to 1 var p2 models.Pilot -p2.Name "Borris" -err := p2.Insert(db) // Insert the second pilot with name "Borris" +p2.Name "Boris" +err := p2.Insert(db) // Insert the second pilot with name "Boris" // p2 now has an ID field set to 2 var p3 models.Pilot From 40ce5838f3132b3b965f7a02872b63b356e998a3 Mon Sep 17 00:00:00 2001 From: Patrick O'brien Date: Thu, 15 Sep 2016 16:58:24 +1000 Subject: [PATCH 64/64] Fix hstore naming --- bdb/drivers/postgres.go | 4 ++-- randomize/randomize.go | 15 +++++++-------- types/hstore.go | 8 ++++---- 3 files changed, 13 insertions(+), 14 deletions(-) diff --git a/bdb/drivers/postgres.go b/bdb/drivers/postgres.go index e992299b9..0422939bb 100644 --- a/bdb/drivers/postgres.go +++ b/bdb/drivers/postgres.go @@ -307,7 +307,7 @@ func (p *PostgresDriver) TranslateColumnType(c bdb.Column) bdb.Column { c.DBType = c.DBType + *c.ArrType case "USER-DEFINED": if c.UDTName == "hstore" { - c.Type = "types.Hstore" + c.Type = "types.HStore" c.DBType = "hstore" } else { c.Type = "string" @@ -344,7 +344,7 @@ func (p *PostgresDriver) TranslateColumnType(c bdb.Column) bdb.Column { c.DBType = c.DBType + *c.ArrType case "USER-DEFINED": if c.UDTName == "hstore" { - c.Type = "types.Hstore" + c.Type = "types.HStore" c.DBType = "hstore" } else { c.Type = "string" diff --git a/randomize/randomize.go b/randomize/randomize.go index 3f5ae295c..739198067 100644 --- a/randomize/randomize.go +++ b/randomize/randomize.go @@ -14,7 +14,6 @@ import ( "gopkg.in/nullbio/null.v5" - "github.com/lib/pq/hstore" "github.com/pkg/errors" "github.com/satori/go.uuid" "github.com/vattle/sqlboiler/strmangle" @@ -46,7 +45,7 @@ var ( typeBoolArray = reflect.TypeOf(types.BoolArray{}) typeFloat64Array = reflect.TypeOf(types.Float64Array{}) typeStringArray = reflect.TypeOf(types.StringArray{}) - typeHstore = reflect.TypeOf(types.Hstore{}) + typeHStore = reflect.TypeOf(types.HStore{}) rgxValidTime = regexp.MustCompile(`[2-9]+`) validatedTypes = []string{ @@ -226,10 +225,10 @@ func randomizeField(s *Seed, field reflect.Value, fieldType string, canBeNull bo value = null.NewJSON([]byte(fmt.Sprintf(`"%s"`, randStr(s, 1))), true) field.Set(reflect.ValueOf(value)) return nil - case typeHstore: - value := hstore.Hstore{Map: map[string]sql.NullString{}} - value.Map[randStr(s, 3)] = sql.NullString{String: randStr(s, 3), Valid: s.nextInt()%3 == 0} - value.Map[randStr(s, 3)] = sql.NullString{String: randStr(s, 3), Valid: s.nextInt()%3 == 0} + case typeHStore: + value := types.HStore{} + value[randStr(s, 3)] = sql.NullString{String: randStr(s, 3), Valid: s.nextInt()%3 == 0} + value[randStr(s, 3)] = sql.NullString{String: randStr(s, 3), Valid: s.nextInt()%3 == 0} field.Set(reflect.ValueOf(value)) return nil } @@ -294,8 +293,8 @@ func randomizeField(s *Seed, field reflect.Value, fieldType string, canBeNull bo value = []byte(fmt.Sprintf(`"%s"`, randStr(s, 1))) field.Set(reflect.ValueOf(value)) return nil - case typeHstore: - value := types.Hstore{} + case typeHStore: + value := types.HStore{} value[randStr(s, 3)] = sql.NullString{String: randStr(s, 3), Valid: s.nextInt()%3 == 0} value[randStr(s, 3)] = sql.NullString{String: randStr(s, 3), Valid: s.nextInt()%3 == 0} field.Set(reflect.ValueOf(value)) diff --git a/types/hstore.go b/types/hstore.go index 6d642c8be..101a4f111 100644 --- a/types/hstore.go +++ b/types/hstore.go @@ -25,8 +25,8 @@ import ( "strings" ) -// Hstore is a wrapper for transferring Hstore values back and forth easily. -type Hstore map[string]sql.NullString +// HStore is a wrapper for transferring HStore values back and forth easily. +type HStore map[string]sql.NullString // escapes and quotes hstore keys/values // s should be a sql.NullString or string @@ -52,7 +52,7 @@ func hQuote(s interface{}) string { // // Note h is reallocated before the scan to clear existing values. If the // hstore column's database value is NULL, then h is set to nil instead. -func (h *Hstore) Scan(value interface{}) error { +func (h *HStore) Scan(value interface{}) error { if value == nil { h = nil return nil @@ -122,7 +122,7 @@ func (h *Hstore) Scan(value interface{}) error { // Value implements the driver Valuer interface. Note if h is nil, the // database column value will be set to NULL. -func (h Hstore) Value() (driver.Value, error) { +func (h HStore) Value() (driver.Value, error) { if h == nil { return nil, nil }