diff --git a/README.md b/README.md index e7bb424b..7aa537d5 100644 --- a/README.md +++ b/README.md @@ -226,9 +226,13 @@ dbmap := &gorp.DbMap{Db: db, Dialect: gorp.MySQLDialect{"InnoDB", "UTF8"}} // SetKeys(true) means we have a auto increment primary key, which // will get automatically bound to your struct post-insert // -t1 := dbmap.AddTableWithName(Invoice{}, "invoice_test").SetKeys(true, "Id") -t2 := dbmap.AddTableWithName(Person{}, "person_test").SetKeys(true, "Id") +t1 := dbmap.AddTableWithName(Person{}, "person_test").SetKeys(true, "Id") +t2 := dbmap.AddTableWithName(Invoice{}, "invoice_test").SetKeys(true, "Id") t3 := dbmap.AddTableWithName(Product{}, "product_test").SetKeys(true, "Id") + +// SetForeignKey will declare that Invoice.PersonId is a foreign key for +// Person.Id, and delete/update actions are to be treated as specified. +t2.ColMap("PersonId").SetForeignKey(gorp.NewForeignKey("Person", "Id").OnDelete(gorp.Restrict).OnUpdate(gorp.Cascade)) ``` ### Struct Embedding ### @@ -450,6 +454,15 @@ func InsertInv(dbmap *DbMap, inv *Invoice, per *Person) error { } ``` +### Foreign Keys ### + +You can define the foreign-key relationships when you create the table schema. +The `ColumnMap` has a `SetForeignKey` method to do this, shown in Examples above. + +Gorp only uses this when `CreateTables` or `CreateTablesIfNotExists` are invoked. +Thereafter, Gorp plays no further part because the database itself enforces the +consistency of keys, returning an error any time that a constraint would be violated. + ### Hooks ### Use hooks to update data before/after saving to the db. Good for timestamps: @@ -644,3 +657,4 @@ Thanks! * matthias-margush - column aliasing via tags * Rob Figueiredo - @robfig * Quinn Slack - @sqs +* Rick Beton - @rickb777 - foreign keys diff --git a/dbmap.go b/dbmap.go new file mode 100644 index 00000000..8af9a53b --- /dev/null +++ b/dbmap.go @@ -0,0 +1,280 @@ +// Copyright 2012 James Cooper. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +// Package gorp provides a simple way to marshal Go structs to and from +// SQL databases. It uses the database/sql package, and should work with any +// compliant database/sql driver. +// +// Source code and project home: +// https://github.com/coopernurse/gorp +// +package gorp + +import ( + "database/sql" + "errors" + "fmt" + "reflect" +) + +// TraceOn turns on SQL statement logging for this DbMap. After this is +// called, all SQL statements will be sent to the logger. If prefix is +// a non-empty string, it will be written to the front of all logged +// strings, which can aid in filtering log lines. +// +// Use TraceOn if you want to spy on the SQL statements that gorp +// generates. +// +// Note that the base log.Logger type satisfies GorpLogger, but adapters can +// easily be written for other logging packages (e.g., the golang-sanctioned +// glog framework). +func (m *DbMap) TraceOn(prefix string, logger GorpLogger) { + m.logger = logger + if prefix == "" { + m.logPrefix = prefix + } else { + m.logPrefix = fmt.Sprintf("%s ", prefix) + } +} + +// TraceOff turns off tracing. It is idempotent. +func (m *DbMap) TraceOff() { + m.logger = nil + m.logPrefix = "" +} + +// TruncateTables iterates through TableMaps registered to this DbMap and +// executes "truncate table" statements against the database for each, or in the case of +// sqlite, a "delete from" with no "where" clause, which uses the truncate optimization +// (http://www.sqlite.org/lang_delete.html) +func (m *DbMap) TruncateTables() error { + var err error + for i := range m.tables { + table := m.tables[i] + _, e := m.Exec(fmt.Sprintf("%s %s;", m.Dialect.TruncateClause(), m.Dialect.QuotedTableForQuery(table.SchemaName, table.TableName))) + if e != nil { + err = e + } + } + return err +} + +// Insert runs a SQL INSERT statement for each element in list. List +// items must be pointers. +// +// Any interface whose TableMap has an auto-increment primary key will +// have its last insert id bound to the PK field on the struct. +// +// The hook functions PreInsert() and/or PostInsert() will be executed +// before/after the INSERT statement if the interface defines them. +// +// Panics if any interface in the list has not been registered with AddTable +func (m *DbMap) Insert(list ...interface{}) error { + return insert(m, m, list...) +} + +// Update runs a SQL UPDATE statement for each element in list. List +// items must be pointers. +// +// The hook functions PreUpdate() and/or PostUpdate() will be executed +// before/after the UPDATE statement if the interface defines them. +// +// Returns the number of rows updated. +// +// Returns an error if SetKeys has not been called on the TableMap +// Panics if any interface in the list has not been registered with AddTable +func (m *DbMap) Update(list ...interface{}) (int64, error) { + return update(m, m, list...) +} + +// Delete runs a SQL DELETE statement for each element in list. List +// items must be pointers. +// +// The hook functions PreDelete() and/or PostDelete() will be executed +// before/after the DELETE statement if the interface defines them. +// +// Returns the number of rows deleted. +// +// Returns an error if SetKeys has not been called on the TableMap +// Panics if any interface in the list has not been registered with AddTable +func (m *DbMap) Delete(list ...interface{}) (int64, error) { + return delete(m, m, list...) +} + +// Get runs a SQL SELECT to fetch a single row from the table based on the +// primary key(s) +// +// i should be an empty value for the struct to load. keys should be +// the primary key value(s) for the row to load. If multiple keys +// exist on the table, the order should match the column order +// specified in SetKeys() when the table mapping was defined. +// +// The hook function PostGet() will be executed after the SELECT +// statement if the interface defines them. +// +// Returns a pointer to a struct that matches or nil if no row is found. +// +// Returns an error if SetKeys has not been called on the TableMap +// Panics if any interface in the list has not been registered with AddTable +func (m *DbMap) Get(i interface{}, keys ...interface{}) (interface{}, error) { + return get(m, m, i, keys...) +} + +// Select runs an arbitrary SQL query, binding the columns in the result +// to fields on the struct specified by i. args represent the bind +// parameters for the SQL statement. +// +// Column names on the SELECT statement should be aliased to the field names +// on the struct i. Returns an error if one or more columns in the result +// do not match. It is OK if fields on i are not part of the SQL +// statement. +// +// The hook function PostGet() will be executed after the SELECT +// statement if the interface defines them. +// +// Values are returned in one of two ways: +// 1. If i is a struct or a pointer to a struct, returns a slice of pointers to +// matching rows of type i. +// 2. If i is a pointer to a slice, the results will be appended to that slice +// and nil returned. +// +// i does NOT need to be registered with AddTable() +func (m *DbMap) Select(i interface{}, query string, args ...interface{}) ([]interface{}, error) { + return hookedselect(m, m, i, query, args...) +} + +// Exec runs an arbitrary SQL statement. args represent the bind parameters. +// This is equivalent to running: Exec() using database/sql +func (m *DbMap) Exec(query string, args ...interface{}) (sql.Result, error) { + err := m.initialise() + if err != nil { + return nil, err + } + m.trace(query, args...) + return m.Db.Exec(query, args...) +} + +// SelectInt is a convenience wrapper around the gorp.SelectInt function +func (m *DbMap) SelectInt(query string, args ...interface{}) (int64, error) { + return SelectInt(m, query, args...) +} + +// SelectNullInt is a convenience wrapper around the gorp.SelectNullInt function +func (m *DbMap) SelectNullInt(query string, args ...interface{}) (sql.NullInt64, error) { + return SelectNullInt(m, query, args...) +} + +// SelectFloat is a convenience wrapper around the gorp.SelectFlot function +func (m *DbMap) SelectFloat(query string, args ...interface{}) (float64, error) { + return SelectFloat(m, query, args...) +} + +// SelectNullFloat is a convenience wrapper around the gorp.SelectNullFloat function +func (m *DbMap) SelectNullFloat(query string, args ...interface{}) (sql.NullFloat64, error) { + return SelectNullFloat(m, query, args...) +} + +// SelectStr is a convenience wrapper around the gorp.SelectStr function +func (m *DbMap) SelectStr(query string, args ...interface{}) (string, error) { + return SelectStr(m, query, args...) +} + +// SelectNullStr is a convenience wrapper around the gorp.SelectNullStr function +func (m *DbMap) SelectNullStr(query string, args ...interface{}) (sql.NullString, error) { + return SelectNullStr(m, query, args...) +} + +// SelectOne is a convenience wrapper around the gorp.SelectOne function +func (m *DbMap) SelectOne(holder interface{}, query string, args ...interface{}) error { + return SelectOne(m, m, holder, query, args...) +} + +// Begin starts a gorp Transaction +func (m *DbMap) Begin() (*Transaction, error) { + m.trace("begin;") + tx, err := m.Db.Begin() + if err != nil { + return nil, err + } + return &Transaction{m, tx, false}, nil +} + +func (m *DbMap) tableFor(t reflect.Type, checkPK bool) (*TableMap, error) { + table := tableOrNil(m, t) + if table == nil { + return nil, errors.New(fmt.Sprintf("No table found for type: %v", t.Name())) + } + + if checkPK && len(table.keys) < 1 { + e := fmt.Sprintf("gorp: No keys defined for table: %s", + table.TableName) + return nil, errors.New(e) + } + + return table, nil +} + +func tableOrNil(m *DbMap, t reflect.Type) *TableMap { + for i := range m.tables { + table := m.tables[i] + if table.gotype == t { + return table + } + } + return nil +} + +func (m *DbMap) tableForPointer(ptr interface{}, checkPK bool) (*TableMap, reflect.Value, error) { + ptrv := reflect.ValueOf(ptr) + if ptrv.Kind() != reflect.Ptr { + e := fmt.Sprintf("gorp: passed non-pointer: %v (kind=%v)", ptr, + ptrv.Kind()) + return nil, reflect.Value{}, errors.New(e) + } + elem := ptrv.Elem() + etype := reflect.TypeOf(elem.Interface()) + t, err := m.tableFor(etype, checkPK) + if err != nil { + return nil, reflect.Value{}, err + } + + return t, elem, nil +} + +func (m *DbMap) QueryRow(query string, args ...interface{}) *sql.Row { + err := m.initialise() + if err != nil { + panic(err) + } + m.initialise() + m.trace(query, args...) + return m.Db.QueryRow(query, args...) +} + +func (m *DbMap) Query(query string, args ...interface{}) (*sql.Rows, error) { + err := m.initialise() + if err != nil { + return nil, err + } + m.trace(query, args...) + return m.Db.Query(query, args...) +} + +func (m *DbMap) trace(query string, args ...interface{}) { + if m.logger != nil { + m.logger.Printf("%s%s %v", m.logPrefix, query, args) + } +} + +func (m *DbMap) initialise() (err error) { + if !m.initialised { + m.initialised = true + if m.Dialect.InitString() != "" { + m.trace(m.Dialect.InitString()) + _, err = m.Db.Exec(m.Dialect.InitString()) + } + } + return +} + diff --git a/dialect.go b/dialect.go index 6b6ef0e8..ccccf58b 100644 --- a/dialect.go +++ b/dialect.go @@ -25,6 +25,12 @@ type Dialect interface { AutoIncrInsertSuffix(col *ColumnMap) string + // Creates the trailing foreign key reference in a column specification. + CreateForeignKeySuffix(references *ForeignKey) string + + // Creates the separate foreign key reference for a column. + CreateForeignKeyBlock(col *ColumnMap) string + // string to append to "create table" statement for vendor specific // table attributes CreateTableSuffix() string @@ -51,6 +57,11 @@ type Dialect interface { // schema - The schema that lives in // table - The table name QuotedTableForQuery(schema string, table string) string + + // Sends an initialisation instruction when connecting to the database. + // Primarily, this exists for Sqlite3 because foreign keys are disable + // by default, unlike Postgresql and Mysql InnoDB. + InitString() string } func standardInsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) { @@ -61,6 +72,19 @@ func standardInsertAutoIncr(exec SqlExecutor, insertSql string, params ...interf return res.LastInsertId() } +func standardOnChangeStr(change string, action FKOnChangeAction) string { + prefix := "\n " + switch action { + case Unspecified: return "" + case NoAction: return prefix + "on " + change + " no action" + case Restrict: return prefix + "on " + change + " restrict" + case Cascade: return prefix + "on " + change + " cascade" + case SetNull: return prefix + "on " + change + " set null" + case Delete: return prefix + "on " + change + " delete" + } + return "" +} + /////////////////////////////////////////////////////// // sqlite3 // ///////////// @@ -115,6 +139,19 @@ func (d SqliteDialect) AutoIncrInsertSuffix(col *ColumnMap) string { return "" } +func (d SqliteDialect) CreateForeignKeySuffix(references *ForeignKey) string { + return "" +} + +func (d SqliteDialect) CreateForeignKeyBlock(col *ColumnMap) string { + return fmt.Sprintf("foreign key (%s) references %s (%s)", + d.QuoteField(col.ColumnName), + d.QuoteField(col.References.ReferencedTable), + d.QuoteField(col.References.ReferencedColumn)) + + standardOnChangeStr("update", col.References.ActionOnUpdate) + + standardOnChangeStr("delete", col.References.ActionOnDelete) +} + // Returns suffix func (d SqliteDialect) CreateTableSuffix() string { return d.suffix @@ -145,6 +182,11 @@ func (d SqliteDialect) QuotedTableForQuery(schema string, table string) string { return d.QuoteField(table) } +// sqlite3 has foreign keys disabled by default (will be enabled in sqlite4). +func (d SqliteDialect) InitString() string { + return "pragma foreign_keys = ON;" +} + /////////////////////////////////////////////////////// // PostgreSQL // //////////////// @@ -211,6 +253,18 @@ func (d PostgresDialect) AutoIncrInsertSuffix(col *ColumnMap) string { return " returning " + col.ColumnName } +func (d PostgresDialect) CreateForeignKeySuffix(references *ForeignKey) string { + refTable := d.QuotedTableForQuery("", references.ReferencedTable) + refField := d.QuoteField(references.ReferencedColumn) + return fmt.Sprintf(" references %s (%s)%s%s", refTable, refField, + standardOnChangeStr("delete", references.ActionOnDelete), + standardOnChangeStr("update", references.ActionOnUpdate)) +} + +func (d PostgresDialect) CreateForeignKeyBlock(col *ColumnMap) string { + return "" +} + // Returns suffix func (d PostgresDialect) CreateTableSuffix() string { return d.suffix @@ -226,7 +280,7 @@ func (d PostgresDialect) BindVar(i int) string { } func (d PostgresDialect) InsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) { - rows, err := exec.query(insertSql, params...) + rows, err := exec.Query(insertSql, params...) if err != nil { return 0, err } @@ -238,7 +292,7 @@ func (d PostgresDialect) InsertAutoIncr(exec SqlExecutor, insertSql string, para return id, err } - return 0, errors.New("No serial value returned for insert: " + insertSql + " Encountered error: " + rows.Err().Error()) + return 0, errors.New("No serial value returned for insert: "+insertSql+" Encountered error: "+rows.Err().Error()) } func (d PostgresDialect) QuoteField(f string) string { @@ -253,6 +307,10 @@ func (d PostgresDialect) QuotedTableForQuery(schema string, table string) string return schema + "." + d.QuoteField(table) } +func (d PostgresDialect) InitString() string { + return "" +} + /////////////////////////////////////////////////////// // MySQL // /////////// @@ -267,10 +325,10 @@ type MySQLDialect struct { Encoding string } -func (m MySQLDialect) ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string { +func (d MySQLDialect) ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string { switch val.Kind() { case reflect.Ptr: - return m.ToSqlType(val.Elem(), maxsize, isAutoIncr) + return d.ToSqlType(val.Elem(), maxsize, isAutoIncr) case reflect.Bool: return "boolean" case reflect.Int8: @@ -315,49 +373,62 @@ func (m MySQLDialect) ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) } // Returns auto_increment -func (m MySQLDialect) AutoIncrStr() string { +func (d MySQLDialect) AutoIncrStr() string { return "auto_increment" } -func (m MySQLDialect) AutoIncrBindValue() string { +func (d MySQLDialect) AutoIncrBindValue() string { return "null" } -func (m MySQLDialect) AutoIncrInsertSuffix(col *ColumnMap) string { +func (d MySQLDialect) AutoIncrInsertSuffix(col *ColumnMap) string { return "" } +func (d MySQLDialect) CreateForeignKeySuffix(references *ForeignKey) string { + return "" +} + +func (d MySQLDialect) CreateForeignKeyBlock(col *ColumnMap) string { + return fmt.Sprintf("foreign key (%s) references %s (%s)", + d.QuoteField(col.ColumnName), + d.QuoteField(col.References.ReferencedTable), + d.QuoteField(col.References.ReferencedColumn)) + + standardOnChangeStr("update", col.References.ActionOnUpdate) + + standardOnChangeStr("delete", col.References.ActionOnDelete) +} + // Returns engine=%s charset=%s based on values stored on struct -func (m MySQLDialect) CreateTableSuffix() string { - if m.Engine == "" || m.Encoding == "" { +func (d MySQLDialect) CreateTableSuffix() string { + if d.Engine == "" || d.Encoding == "" { msg := "gorp - undefined" - if m.Engine == "" { + if d.Engine == "" { msg += " MySQLDialect.Engine" } - if m.Engine == "" && m.Encoding == "" { + if d.Engine == "" && d.Encoding == "" { msg += "," } - if m.Encoding == "" { + if d.Encoding == "" { msg += " MySQLDialect.Encoding" } msg += ". Check that your MySQLDialect was correctly initialized when declared." panic(msg) } - return fmt.Sprintf(" engine=%s charset=%s", m.Engine, m.Encoding) + return fmt.Sprintf(" engine=%s charset=%s", d.Engine, d.Encoding) } -func (m MySQLDialect) TruncateClause() string { +func (d MySQLDialect) TruncateClause() string { return "truncate" } // Returns "?" -func (m MySQLDialect) BindVar(i int) string { +func (d MySQLDialect) BindVar(i int) string { return "?" } -func (m MySQLDialect) InsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) { +func (d MySQLDialect) InsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) { return standardInsertAutoIncr(exec, insertSql, params...) } @@ -369,3 +440,7 @@ func (d MySQLDialect) QuoteField(f string) string { func (d MySQLDialect) QuotedTableForQuery(schema string, table string) string { return d.QuoteField(table) } + +func (d MySQLDialect) InitString() string { + return "" +} diff --git a/gorp.go b/gorp.go index 5ba0ba37..a249d03f 100644 --- a/gorp.go +++ b/gorp.go @@ -14,7 +14,6 @@ package gorp import ( "bytes" "database/sql" - "errors" "fmt" "reflect" "regexp" @@ -93,142 +92,6 @@ func (me CustomScanner) Bind() error { return me.Binder(me.Holder, me.Target) } -// DbMap is the root gorp mapping object. Create one of these for each -// database schema you wish to map. Each DbMap contains a list of -// mapped tables. -// -// Example: -// -// dialect := gorp.MySQLDialect{"InnoDB", "UTF8"} -// dbmap := &gorp.DbMap{Db: db, Dialect: dialect} -// -type DbMap struct { - // Db handle to use with this map - Db *sql.DB - - // Dialect implementation to use with this map - Dialect Dialect - - TypeConverter TypeConverter - - tables []*TableMap - logger GorpLogger - logPrefix string -} - -// TableMap represents a mapping between a Go struct and a database table -// Use dbmap.AddTable() or dbmap.AddTableWithName() to create these -type TableMap struct { - // Name of database table. - TableName string - SchemaName string - gotype reflect.Type - columns []*ColumnMap - keys []*ColumnMap - uniqueTogether [][]string - version *ColumnMap - insertPlan bindPlan - updatePlan bindPlan - deletePlan bindPlan - getPlan bindPlan - dbmap *DbMap -} - -// ResetSql removes cached insert/update/select/delete SQL strings -// associated with this TableMap. Call this if you've modified -// any column names or the table name itself. -func (t *TableMap) ResetSql() { - t.insertPlan = bindPlan{} - t.updatePlan = bindPlan{} - t.deletePlan = bindPlan{} - t.getPlan = bindPlan{} -} - -// SetKeys lets you specify the fields on a struct that map to primary -// key columns on the table. If isAutoIncr is set, result.LastInsertId() -// will be used after INSERT to bind the generated id to the Go struct. -// -// Automatically calls ResetSql() to ensure SQL statements are regenerated. -// -// Panics if isAutoIncr is true, and fieldNames length != 1 -// -func (t *TableMap) SetKeys(isAutoIncr bool, fieldNames ...string) *TableMap { - if isAutoIncr && len(fieldNames) != 1 { - panic(fmt.Sprintf( - "gorp: SetKeys: fieldNames length must be 1 if key is auto-increment. (Saw %v fieldNames)", - len(fieldNames))) - } - t.keys = make([]*ColumnMap, 0) - for _, name := range fieldNames { - colmap := t.ColMap(name) - colmap.isPK = true - colmap.isAutoIncr = isAutoIncr - t.keys = append(t.keys, colmap) - } - t.ResetSql() - - return t -} - -// SetUniqueTogether lets you specify uniqueness constraints across multiple -// columns on the table. Each call adds an additional constraint for the -// specified columns. -// -// Automatically calls ResetSql() to ensure SQL statements are regenerated. -// -// Panics if fieldNames length < 2. -// -func (t *TableMap) SetUniqueTogether(fieldNames ...string) *TableMap { - if len(fieldNames) < 2 { - panic(fmt.Sprintf( - "gorp: SetUniqueTogether: must provide at least two fieldNames to set uniqueness constraint.")) - } - - columns := make([]string, 0) - for _, name := range fieldNames { - columns = append(columns, name) - } - t.uniqueTogether = append(t.uniqueTogether, columns) - t.ResetSql() - - return t -} - -// ColMap returns the ColumnMap pointer matching the given struct field -// name. It panics if the struct does not contain a field matching this -// name. -func (t *TableMap) ColMap(field string) *ColumnMap { - col := colMapOrNil(t, field) - if col == nil { - e := fmt.Sprintf("No ColumnMap in table %s type %s with field %s", - t.TableName, t.gotype.Name(), field) - - panic(e) - } - return col -} - -func colMapOrNil(t *TableMap, field string) *ColumnMap { - for _, col := range t.columns { - if col.fieldName == field || col.ColumnName == field { - return col - } - } - return nil -} - -// SetVersionCol sets the column to use as the Version field. By default -// the "Version" field is used. Returns the column found, or panics -// if the struct does not contain a field matching this name. -// -// Automatically calls ResetSql() to ensure SQL statements are regenerated. -func (t *TableMap) SetVersionCol(field string) *ColumnMap { - c := t.ColMap(field) - t.version = c - t.ResetSql() - return c -} - type bindPlan struct { query string argFields []string @@ -491,71 +354,6 @@ func (t *TableMap) bindGet() bindPlan { return plan } -// ColumnMap represents a mapping between a Go struct field and a single -// column in a table. -// Unique and MaxSize only inform the -// CreateTables() function and are not used by Insert/Update/Delete/Get. -type ColumnMap struct { - // Column name in db table - ColumnName string - - // If true, this column is skipped in generated SQL statements - Transient bool - - // If true, " unique" is added to create table statements. - // Not used elsewhere - Unique bool - - // Passed to Dialect.ToSqlType() to assist in informing the - // correct column type to map to in CreateTables() - // Not used elsewhere - MaxSize int - - fieldName string - gotype reflect.Type - isPK bool - isAutoIncr bool - isNotNull bool -} - -// Rename allows you to specify the column name in the table -// -// Example: table.ColMap("Updated").Rename("date_updated") -// -func (c *ColumnMap) Rename(colname string) *ColumnMap { - c.ColumnName = colname - return c -} - -// SetTransient allows you to mark the column as transient. If true -// this column will be skipped when SQL statements are generated -func (c *ColumnMap) SetTransient(b bool) *ColumnMap { - c.Transient = b - return c -} - -// SetUnique adds "unique" to the create table statements for this -// column, if b is true. -func (c *ColumnMap) SetUnique(b bool) *ColumnMap { - c.Unique = b - return c -} - -// SetNotNull adds "not null" to the create table statements for this -// column, if nn is true. -func (c *ColumnMap) SetNotNull(nn bool) *ColumnMap { - c.isNotNull = nn - return c -} - -// SetMaxSize specifies the max length of values of this column. This is -// passed to the dialect.ToSqlType() function, which can use the value -// to alter the generated type for "create table" statements -func (c *ColumnMap) SetMaxSize(size int) *ColumnMap { - c.MaxSize = size - return c -} - // Transaction represents a database transaction. // Insert/Update/Delete/Get/Exec operations will be run in the context // of that transaction. Transactions should be terminated with @@ -578,8 +376,7 @@ type SqlExecutor interface { Update(list ...interface{}) (int64, error) Delete(list ...interface{}) (int64, error) Exec(query string, args ...interface{}) (sql.Result, error) - Select(i interface{}, query string, - args ...interface{}) ([]interface{}, error) + Select(i interface{}, query string, args ...interface{}) ([]interface{}, error) SelectInt(query string, args ...interface{}) (int64, error) SelectNullInt(query string, args ...interface{}) (sql.NullInt64, error) SelectFloat(query string, args ...interface{}) (float64, error) @@ -587,8 +384,8 @@ type SqlExecutor interface { SelectStr(query string, args ...interface{}) (string, error) SelectNullStr(query string, args ...interface{}) (sql.NullString, error) SelectOne(holder interface{}, query string, args ...interface{}) error - query(query string, args ...interface{}) (*sql.Rows, error) - queryRow(query string, args ...interface{}) *sql.Row + Query(query string, args ...interface{}) (*sql.Rows, error) + QueryRow(query string, args ...interface{}) *sql.Row } // Compile-time check that DbMap and Transaction implement the SqlExecutor @@ -599,494 +396,6 @@ type GorpLogger interface { Printf(format string, v ...interface{}) } -// TraceOn turns on SQL statement logging for this DbMap. After this is -// called, all SQL statements will be sent to the logger. If prefix is -// a non-empty string, it will be written to the front of all logged -// strings, which can aid in filtering log lines. -// -// Use TraceOn if you want to spy on the SQL statements that gorp -// generates. -// -// Note that the base log.Logger type satisfies GorpLogger, but adapters can -// easily be written for other logging packages (e.g., the golang-sanctioned -// glog framework). -func (m *DbMap) TraceOn(prefix string, logger GorpLogger) { - m.logger = logger - if prefix == "" { - m.logPrefix = prefix - } else { - m.logPrefix = fmt.Sprintf("%s ", prefix) - } -} - -// TraceOff turns off tracing. It is idempotent. -func (m *DbMap) TraceOff() { - m.logger = nil - m.logPrefix = "" -} - -// AddTable registers the given interface type with gorp. The table name -// will be given the name of the TypeOf(i). You must call this function, -// or AddTableWithName, for any struct type you wish to persist with -// the given DbMap. -// -// This operation is idempotent. If i's type is already mapped, the -// existing *TableMap is returned -func (m *DbMap) AddTable(i interface{}) *TableMap { - return m.AddTableWithName(i, "") -} - -// AddTableWithName has the same behavior as AddTable, but sets -// table.TableName to name. -func (m *DbMap) AddTableWithName(i interface{}, name string) *TableMap { - return m.AddTableWithNameAndSchema(i, "", name) -} - -// AddTableWithNameAndSchema has the same behavior as AddTable, but sets -// table.TableName to name. -func (m *DbMap) AddTableWithNameAndSchema(i interface{}, schema string, name string) *TableMap { - t := reflect.TypeOf(i) - if name == "" { - name = t.Name() - } - - // check if we have a table for this type already - // if so, update the name and return the existing pointer - for i := range m.tables { - table := m.tables[i] - if table.gotype == t { - table.TableName = name - return table - } - } - - tmap := &TableMap{gotype: t, TableName: name, SchemaName: schema, dbmap: m} - tmap.columns, tmap.version = readStructColumns(t) - m.tables = append(m.tables, tmap) - - return tmap -} - -func readStructColumns(t reflect.Type) (cols []*ColumnMap, version *ColumnMap) { - n := t.NumField() - for i := 0; i < n; i++ { - f := t.Field(i) - if f.Anonymous && f.Type.Kind() == reflect.Struct { - // Recursively add nested fields in embedded structs. - subcols, subversion := readStructColumns(f.Type) - // Don't append nested fields that have the same field - // name as an already-mapped field. - for _, subcol := range subcols { - shouldAppend := true - for _, col := range cols { - if !subcol.Transient && subcol.fieldName == col.fieldName { - shouldAppend = false - break - } - } - if shouldAppend { - cols = append(cols, subcol) - } - } - if subversion != nil { - version = subversion - } - } else { - columnName := f.Tag.Get("db") - if columnName == "" { - columnName = f.Name - } - cm := &ColumnMap{ - ColumnName: columnName, - Transient: columnName == "-", - fieldName: f.Name, - gotype: f.Type, - } - // Check for nested fields of the same field name and - // override them. - shouldAppend := true - for index, col := range cols { - if !col.Transient && col.fieldName == cm.fieldName { - cols[index] = cm - shouldAppend = false - break - } - } - if shouldAppend { - cols = append(cols, cm) - } - if cm.fieldName == "Version" { - version = cm - } - } - } - return -} - -// CreateTables iterates through TableMaps registered to this DbMap and -// executes "create table" statements against the database for each. -// -// This is particularly useful in unit tests where you want to create -// and destroy the schema automatically. -func (m *DbMap) CreateTables() error { - return m.createTables(false) -} - -// CreateTablesIfNotExists is similar to CreateTables, but starts -// each statement with "create table if not exists" so that existing -// tables do not raise errors -func (m *DbMap) CreateTablesIfNotExists() error { - return m.createTables(true) -} - -func (m *DbMap) createTables(ifNotExists bool) error { - var err error - for i := range m.tables { - table := m.tables[i] - - s := bytes.Buffer{} - - if strings.TrimSpace(table.SchemaName) != "" { - schemaCreate := "create schema" - if ifNotExists { - schemaCreate += " if not exists" - } - - s.WriteString(fmt.Sprintf("%s %s;", schemaCreate, table.SchemaName)) - } - - create := "create table" - if ifNotExists { - create += " if not exists" - } - - s.WriteString(fmt.Sprintf("%s %s (", create, m.Dialect.QuotedTableForQuery(table.SchemaName, table.TableName))) - x := 0 - for _, col := range table.columns { - if !col.Transient { - if x > 0 { - s.WriteString(", ") - } - stype := m.Dialect.ToSqlType(col.gotype, col.MaxSize, col.isAutoIncr) - s.WriteString(fmt.Sprintf("%s %s", m.Dialect.QuoteField(col.ColumnName), stype)) - - if col.isPK || col.isNotNull { - s.WriteString(" not null") - } - if col.isPK && len(table.keys) == 1 { - s.WriteString(" primary key") - } - if col.Unique { - s.WriteString(" unique") - } - if col.isAutoIncr { - s.WriteString(fmt.Sprintf(" %s", m.Dialect.AutoIncrStr())) - } - - x++ - } - } - if len(table.keys) > 1 { - s.WriteString(", primary key (") - for x := range table.keys { - if x > 0 { - s.WriteString(", ") - } - s.WriteString(m.Dialect.QuoteField(table.keys[x].ColumnName)) - } - s.WriteString(")") - } - if len(table.uniqueTogether) > 0 { - for _, columns := range table.uniqueTogether { - s.WriteString(", unique (") - for i, column := range columns { - if i > 0 { - s.WriteString(", ") - } - s.WriteString(m.Dialect.QuoteField(column)) - } - s.WriteString(")") - } - } - s.WriteString(") ") - s.WriteString(m.Dialect.CreateTableSuffix()) - s.WriteString(";") - _, err = m.Exec(s.String()) - if err != nil { - break - } - } - return err -} - -// DropTable drops an individual table. Will throw an error -// if the table does not exist. -func (m *DbMap) DropTable(table interface{}) error { - t := reflect.TypeOf(table) - return m.dropTable(t, false) -} - -// DropTable drops an individual table. Will NOT throw an error -// if the table does not exist. -func (m *DbMap) DropTableIfExists(table interface{}) error { - t := reflect.TypeOf(table) - return m.dropTable(t, true) -} - -// DropTables iterates through TableMaps registered to this DbMap and -// executes "drop table" statements against the database for each. -func (m *DbMap) DropTables() error { - return m.dropTables(false) -} - -// DropTablesIfExists is the same as DropTables, but uses the "if exists" clause to -// avoid errors for tables that do not exist. -func (m *DbMap) DropTablesIfExists() error { - return m.dropTables(true) -} - -// Goes through all the registered tables, dropping them one by one. -// If an error is encountered, then it is returned and the rest of -// the tables are not dropped. -func (m *DbMap) dropTables(addIfExists bool) (err error) { - for _, table := range m.tables { - err = m.dropTableImpl(table, addIfExists) - if err != nil { - return - } - } - return err -} - -// Implementation of dropping a single table. -func (m *DbMap) dropTable(t reflect.Type, addIfExists bool) error { - table := tableOrNil(m, t) - if table == nil { - return errors.New(fmt.Sprintf("table %s was not registered!", table.TableName)) - } - - return m.dropTableImpl(table, addIfExists) -} - -func (m *DbMap) dropTableImpl(table *TableMap, addIfExists bool) (err error) { - ifExists := "" - if addIfExists { - ifExists = " if exists" - } - _, err = m.Exec(fmt.Sprintf("drop table%s %s;", ifExists, m.Dialect.QuotedTableForQuery(table.SchemaName, table.TableName))) - return err -} - -// TruncateTables iterates through TableMaps registered to this DbMap and -// executes "truncate table" statements against the database for each, or in the case of -// sqlite, a "delete from" with no "where" clause, which uses the truncate optimization -// (http://www.sqlite.org/lang_delete.html) -func (m *DbMap) TruncateTables() error { - var err error - for i := range m.tables { - table := m.tables[i] - _, e := m.Exec(fmt.Sprintf("%s %s;", m.Dialect.TruncateClause(), m.Dialect.QuotedTableForQuery(table.SchemaName, table.TableName))) - if e != nil { - err = e - } - } - return err -} - -// Insert runs a SQL INSERT statement for each element in list. List -// items must be pointers. -// -// Any interface whose TableMap has an auto-increment primary key will -// have its last insert id bound to the PK field on the struct. -// -// The hook functions PreInsert() and/or PostInsert() will be executed -// before/after the INSERT statement if the interface defines them. -// -// Panics if any interface in the list has not been registered with AddTable -func (m *DbMap) Insert(list ...interface{}) error { - return insert(m, m, list...) -} - -// Update runs a SQL UPDATE statement for each element in list. List -// items must be pointers. -// -// The hook functions PreUpdate() and/or PostUpdate() will be executed -// before/after the UPDATE statement if the interface defines them. -// -// Returns the number of rows updated. -// -// Returns an error if SetKeys has not been called on the TableMap -// Panics if any interface in the list has not been registered with AddTable -func (m *DbMap) Update(list ...interface{}) (int64, error) { - return update(m, m, list...) -} - -// Delete runs a SQL DELETE statement for each element in list. List -// items must be pointers. -// -// The hook functions PreDelete() and/or PostDelete() will be executed -// before/after the DELETE statement if the interface defines them. -// -// Returns the number of rows deleted. -// -// Returns an error if SetKeys has not been called on the TableMap -// Panics if any interface in the list has not been registered with AddTable -func (m *DbMap) Delete(list ...interface{}) (int64, error) { - return delete(m, m, list...) -} - -// Get runs a SQL SELECT to fetch a single row from the table based on the -// primary key(s) -// -// i should be an empty value for the struct to load. keys should be -// the primary key value(s) for the row to load. If multiple keys -// exist on the table, the order should match the column order -// specified in SetKeys() when the table mapping was defined. -// -// The hook function PostGet() will be executed after the SELECT -// statement if the interface defines them. -// -// Returns a pointer to a struct that matches or nil if no row is found. -// -// Returns an error if SetKeys has not been called on the TableMap -// Panics if any interface in the list has not been registered with AddTable -func (m *DbMap) Get(i interface{}, keys ...interface{}) (interface{}, error) { - return get(m, m, i, keys...) -} - -// Select runs an arbitrary SQL query, binding the columns in the result -// to fields on the struct specified by i. args represent the bind -// parameters for the SQL statement. -// -// Column names on the SELECT statement should be aliased to the field names -// on the struct i. Returns an error if one or more columns in the result -// do not match. It is OK if fields on i are not part of the SQL -// statement. -// -// The hook function PostGet() will be executed after the SELECT -// statement if the interface defines them. -// -// Values are returned in one of two ways: -// 1. If i is a struct or a pointer to a struct, returns a slice of pointers to -// matching rows of type i. -// 2. If i is a pointer to a slice, the results will be appended to that slice -// and nil returned. -// -// i does NOT need to be registered with AddTable() -func (m *DbMap) Select(i interface{}, query string, args ...interface{}) ([]interface{}, error) { - return hookedselect(m, m, i, query, args...) -} - -// Exec runs an arbitrary SQL statement. args represent the bind parameters. -// This is equivalent to running: Exec() using database/sql -func (m *DbMap) Exec(query string, args ...interface{}) (sql.Result, error) { - m.trace(query, args...) - return m.Db.Exec(query, args...) -} - -// SelectInt is a convenience wrapper around the gorp.SelectInt function -func (m *DbMap) SelectInt(query string, args ...interface{}) (int64, error) { - return SelectInt(m, query, args...) -} - -// SelectNullInt is a convenience wrapper around the gorp.SelectNullInt function -func (m *DbMap) SelectNullInt(query string, args ...interface{}) (sql.NullInt64, error) { - return SelectNullInt(m, query, args...) -} - -// SelectFloat is a convenience wrapper around the gorp.SelectFlot function -func (m *DbMap) SelectFloat(query string, args ...interface{}) (float64, error) { - return SelectFloat(m, query, args...) -} - -// SelectNullFloat is a convenience wrapper around the gorp.SelectNullFloat function -func (m *DbMap) SelectNullFloat(query string, args ...interface{}) (sql.NullFloat64, error) { - return SelectNullFloat(m, query, args...) -} - -// SelectStr is a convenience wrapper around the gorp.SelectStr function -func (m *DbMap) SelectStr(query string, args ...interface{}) (string, error) { - return SelectStr(m, query, args...) -} - -// SelectNullStr is a convenience wrapper around the gorp.SelectNullStr function -func (m *DbMap) SelectNullStr(query string, args ...interface{}) (sql.NullString, error) { - return SelectNullStr(m, query, args...) -} - -// SelectOne is a convenience wrapper around the gorp.SelectOne function -func (m *DbMap) SelectOne(holder interface{}, query string, args ...interface{}) error { - return SelectOne(m, m, holder, query, args...) -} - -// Begin starts a gorp Transaction -func (m *DbMap) Begin() (*Transaction, error) { - m.trace("begin;") - tx, err := m.Db.Begin() - if err != nil { - return nil, err - } - return &Transaction{m, tx, false}, nil -} - -func (m *DbMap) tableFor(t reflect.Type, checkPK bool) (*TableMap, error) { - table := tableOrNil(m, t) - if table == nil { - return nil, errors.New(fmt.Sprintf("No table found for type: %v", t.Name())) - } - - if checkPK && len(table.keys) < 1 { - e := fmt.Sprintf("gorp: No keys defined for table: %s", - table.TableName) - return nil, errors.New(e) - } - - return table, nil -} - -func tableOrNil(m *DbMap, t reflect.Type) *TableMap { - for i := range m.tables { - table := m.tables[i] - if table.gotype == t { - return table - } - } - return nil -} - -func (m *DbMap) tableForPointer(ptr interface{}, checkPK bool) (*TableMap, reflect.Value, error) { - ptrv := reflect.ValueOf(ptr) - if ptrv.Kind() != reflect.Ptr { - e := fmt.Sprintf("gorp: passed non-pointer: %v (kind=%v)", ptr, - ptrv.Kind()) - return nil, reflect.Value{}, errors.New(e) - } - elem := ptrv.Elem() - etype := reflect.TypeOf(elem.Interface()) - t, err := m.tableFor(etype, checkPK) - if err != nil { - return nil, reflect.Value{}, err - } - - return t, elem, nil -} - -func (m *DbMap) queryRow(query string, args ...interface{}) *sql.Row { - m.trace(query, args...) - return m.Db.QueryRow(query, args...) -} - -func (m *DbMap) query(query string, args ...interface{}) (*sql.Rows, error) { - m.trace(query, args...) - return m.Db.Query(query, args...) -} - -func (m *DbMap) trace(query string, args ...interface{}) { - if m.logger != nil { - m.logger.Printf("%s%s %v", m.logPrefix, query, args) - } -} - /////////////// // Insert has the same behavior as DbMap.Insert(), but runs in a transaction. @@ -1207,12 +516,12 @@ func (t *Transaction) ReleaseSavepoint(savepoint string) error { return err } -func (t *Transaction) queryRow(query string, args ...interface{}) *sql.Row { +func (t *Transaction) QueryRow(query string, args ...interface{}) *sql.Row { t.dbmap.trace(query, args...) return t.tx.QueryRow(query, args...) } -func (t *Transaction) query(query string, args ...interface{}) (*sql.Rows, error) { +func (t *Transaction) Query(query string, args ...interface{}) (*sql.Rows, error) { t.dbmap.trace(query, args...) return t.tx.Query(query, args...) } @@ -1225,7 +534,7 @@ func (t *Transaction) query(query string, args ...interface{}) (*sql.Rows, error func SelectInt(e SqlExecutor, query string, args ...interface{}) (int64, error) { var h int64 err := selectVal(e, &h, query, args...) - if err != nil { + if err != nil && err != sql.ErrNoRows { return 0, err } return h, nil @@ -1237,7 +546,7 @@ func SelectInt(e SqlExecutor, query string, args ...interface{}) (int64, error) func SelectNullInt(e SqlExecutor, query string, args ...interface{}) (sql.NullInt64, error) { var h sql.NullInt64 err := selectVal(e, &h, query, args...) - if err != nil { + if err != nil && err != sql.ErrNoRows { return h, err } return h, nil @@ -1249,7 +558,7 @@ func SelectNullInt(e SqlExecutor, query string, args ...interface{}) (sql.NullIn func SelectFloat(e SqlExecutor, query string, args ...interface{}) (float64, error) { var h float64 err := selectVal(e, &h, query, args...) - if err != nil { + if err != nil && err != sql.ErrNoRows { return 0, err } return h, nil @@ -1261,7 +570,7 @@ func SelectFloat(e SqlExecutor, query string, args ...interface{}) (float64, err func SelectNullFloat(e SqlExecutor, query string, args ...interface{}) (sql.NullFloat64, error) { var h sql.NullFloat64 err := selectVal(e, &h, query, args...) - if err != nil { + if err != nil && err != sql.ErrNoRows { return h, err } return h, nil @@ -1273,7 +582,7 @@ func SelectNullFloat(e SqlExecutor, query string, args ...interface{}) (sql.Null func SelectStr(e SqlExecutor, query string, args ...interface{}) (string, error) { var h string err := selectVal(e, &h, query, args...) - if err != nil { + if err != nil && err != sql.ErrNoRows { return "", err } return h, nil @@ -1286,7 +595,7 @@ func SelectStr(e SqlExecutor, query string, args ...interface{}) (string, error) func SelectNullStr(e SqlExecutor, query string, args ...interface{}) (sql.NullString, error) { var h sql.NullString err := selectVal(e, &h, query, args...) - if err != nil { + if err != nil && err != sql.ErrNoRows { return h, err } return h, nil @@ -1344,20 +653,17 @@ func selectVal(e SqlExecutor, holder interface{}, query string, args ...interfac query, args = maybeExpandNamedQuery(m.dbmap, query, args) } } - rows, err := e.query(query, args...) + rows, err := e.Query(query, args...) if err != nil { return err } defer rows.Close() - if rows.Next() { - err = rows.Scan(holder) - if err != nil { - return err - } + if !rows.Next() { + return sql.ErrNoRows } - return nil + return rows.Scan(holder) } /////////////// @@ -1424,7 +730,7 @@ func rawselect(m *DbMap, exec SqlExecutor, i interface{}, query string, } // Run the query - rows, err := exec.query(query, args...) + rows, err := exec.Query(query, args...) if err != nil { return nil, err } @@ -1527,8 +833,8 @@ func maybeExpandNamedQuery(m *DbMap, query string, args []interface{}) (string, switch { case arg.Kind() == reflect.Map && arg.Type().Key().Kind() == reflect.String: return expandNamedQuery(m, query, func(key string) reflect.Value { - return arg.MapIndex(reflect.ValueOf(key)) - }) + return arg.MapIndex(reflect.ValueOf(key)) + }) // #84 - ignore time.Time structs here - there may be a cleaner way to do this case arg.Kind() == reflect.Struct && !(arg.Type().PkgPath() == "time" && arg.Type().Name() == "Time"): return expandNamedQuery(m, query, arg.FieldByName) @@ -1548,15 +854,15 @@ func expandNamedQuery(m *DbMap, query string, keyGetter func(key string) reflect args []interface{} ) return keyRegexp.ReplaceAllStringFunc(query, func(key string) string { - val := keyGetter(key[1:]) - if !val.IsValid() { - return key - } - args = append(args, val.Interface()) - newVar := m.Dialect.BindVar(n) - n++ - return newVar - }), args + val := keyGetter(key[1:]) + if !val.IsValid() { + return key + } + args = append(args, val.Interface()) + newVar := m.Dialect.BindVar(n) + n++ + return newVar + }), args } func columnToFieldIndex(m *DbMap, t reflect.Type, cols []string) ([][]int, error) { @@ -1694,7 +1000,7 @@ func get(m *DbMap, exec SqlExecutor, i interface{}, dest[x] = target } - row := exec.queryRow(plan.query, keys...) + row := exec.QueryRow(plan.query, keys...) err = row.Scan(dest...) if err != nil { if err == sql.ErrNoRows { diff --git a/gorp_test.go b/gorp_test.go index 6876be08..226e8caa 100644 --- a/gorp_test.go +++ b/gorp_test.go @@ -306,7 +306,7 @@ func TestSetUniqueTogether(t *testing.T) { t.Error(err) } // "unique" for Postgres/SQLite, "Duplicate entry" for MySQL - if !strings.Contains(err.Error(), "unique") && !strings.Contains(err.Error(), "Duplicate entry") { + if !stringContainsIgnoreCase(err.Error(), "unique") && !stringContainsIgnoreCase(err.Error(), "Duplicate entry") { t.Error(err) } @@ -317,7 +317,7 @@ func TestSetUniqueTogether(t *testing.T) { t.Error(err) } // "unique" for Postgres/SQLite, "Duplicate entry" for MySQL - if !strings.Contains(err.Error(), "unique") && !strings.Contains(err.Error(), "Duplicate entry") { + if !stringContainsIgnoreCase(err.Error(), "unique") && !stringContainsIgnoreCase(err.Error(), "Duplicate entry") { t.Error(err) } @@ -736,6 +736,113 @@ func TestColumnProps(t *testing.T) { } } +func checkFkProperty(t *testing.T, createTableSql, expected string, required bool) { + if strings.Contains(createTableSql, expected) != required { + t.Errorf("Expected '%s' in:\n%s", expected, createTableSql) + } +} + +func TestUnitFkColumnPropsMysql(t *testing.T) { + dialect := MySQLDialect{"InnoDB", "UTF8"} + dbmap := &DbMap{Dialect: dialect} + dbmap.AddTable(Person{}).SetKeys(true, "Id") + t1 := dbmap.AddTable(Invoice{}).SetKeys(true, "Id") + + t1.ColMap("PersonId").SetForeignKey(NewForeignKey("Person", "Id").OnDelete(Restrict).OnUpdate(Cascade)) + table1 := dbmap.createOneTableSql(true, dbmap.tables[1]) + if !strings.Contains(table1, "foreign key (`PersonId`) references `Person` (`Id`)") { + t.Errorf("Expected foreign key reference in:\n%s", table1) + } + checkFkProperty(t, table1, "on update cascade", true) + checkFkProperty(t, table1, "on delete restrict", true) + + t1.ColMap("PersonId").SetForeignKey(NewForeignKey("Person", "Id").OnDelete(Delete).OnUpdate(SetNull)) + table1 = dbmap.createOneTableSql(true, dbmap.tables[1]) + checkFkProperty(t, table1, "on update set null", true) + checkFkProperty(t, table1, "on delete delete", true) + + t1.ColMap("PersonId").SetForeignKey(NewForeignKey("Person", "Id").OnDelete(NoAction).OnUpdate(Unspecified)) + table1 = dbmap.createOneTableSql(true, dbmap.tables[1]) + checkFkProperty(t, table1, "on update", false) + checkFkProperty(t, table1, "on delete no action", true) +} + +func TestUnitFkColumnPropsPsql(t *testing.T) { + dialect := PostgresDialect{} + dbmap := &DbMap{Dialect: dialect} + dbmap.AddTable(Person{}).SetKeys(true, "Id") + t1 := dbmap.AddTable(Invoice{}).SetKeys(true, "Id") + + t1.ColMap("PersonId").SetForeignKey(NewForeignKey("Person", "Id").OnDelete(Restrict).OnUpdate(Cascade)) + table1 := dbmap.createOneTableSql(true, dbmap.tables[1]) + if !strings.Contains(table1, `"personid" bigint references "person" ("id")`) { + t.Errorf("Expected foreign key reference in:\n%s", table1) + } + checkFkProperty(t, table1, "on update cascade", true) + checkFkProperty(t, table1, "on delete restrict", true) + + t1.ColMap("PersonId").SetForeignKey(NewForeignKey("Person", "Id").OnDelete(Delete).OnUpdate(SetNull)) + table1 = dbmap.createOneTableSql(true, dbmap.tables[1]) + checkFkProperty(t, table1, "on update set null", true) + checkFkProperty(t, table1, "on delete delete", true) + + t1.ColMap("PersonId").SetForeignKey(NewForeignKey("Person", "Id").OnDelete(NoAction).OnUpdate(Unspecified)) + table1 = dbmap.createOneTableSql(true, dbmap.tables[1]) + checkFkProperty(t, table1, "on update", false) + checkFkProperty(t, table1, "on delete no action", true) +} + +func TestUnitFkColumnPropsSqlite(t *testing.T) { + dialect := SqliteDialect{} + dbmap := &DbMap{Dialect: dialect} + dbmap.AddTable(Person{}).SetKeys(true, "Id") + t1 := dbmap.AddTable(Invoice{}).SetKeys(true, "Id") + + t1.ColMap("PersonId").SetForeignKey(NewForeignKey("Person", "Id").OnDelete(Restrict).OnUpdate(Cascade)) + table1 := dbmap.createOneTableSql(true, dbmap.tables[1]) + if !strings.Contains(table1, `foreign key ("PersonId") references "Person" ("Id")`) { + t.Errorf("Expected foreign key reference in:\n%s", table1) + } + checkFkProperty(t, table1, "on update cascade", true) + checkFkProperty(t, table1, "on delete restrict", true) + + t1.ColMap("PersonId").SetForeignKey(NewForeignKey("Person", "Id").OnDelete(Delete).OnUpdate(SetNull)) + table1 = dbmap.createOneTableSql(true, dbmap.tables[1]) + checkFkProperty(t, table1, "on update set null", true) + checkFkProperty(t, table1, "on delete delete", true) + + t1.ColMap("PersonId").SetForeignKey(NewForeignKey("Person", "Id").OnDelete(NoAction).OnUpdate(Unspecified)) + table1 = dbmap.createOneTableSql(true, dbmap.tables[1]) + checkFkProperty(t, table1, "on update", false) + checkFkProperty(t, table1, "on delete no action", true) +} + +func TestFkColumnProps(t *testing.T) { + dbmap := newDbMap() + dbmap.TraceOn("", log.New(os.Stdout, "gorptest: ", log.Lmicroseconds)) + dbmap.AddTable(Person{}).SetKeys(true, "Id") + t1 := dbmap.AddTable(Invoice{}).SetKeys(true, "Id") + t1.ColMap("PersonId").SetForeignKey(NewForeignKey("Person", "Id").OnDelete(Restrict).OnUpdate(Cascade)) + // note: "on update" is not yet tested + + err := dbmap.CreateTables() + if err != nil { + log.Fatalln(err) + } + defer dropAndClose(dbmap) + + person := &Person{0, 0, 0, "John", "Cooper", 0} + _insert(dbmap, person) + + inv := &Invoice{0, 0, 1, "my invoice", person.Id, true} + _insert(dbmap, inv) + + n, err := dbmap.Delete(person) + if err == nil { + t.Errorf("Restricted delete failed; deleted %d. %d\n", n, person.Id) + } +} + func TestRawSelect(t *testing.T) { dbmap := initDbMap() defer dropAndClose(dbmap) @@ -1408,7 +1515,20 @@ func TestSelectSingleVal(t *testing.T) { _insert(dbmap, &Person{0, 0, 0, "bob", "smith", 0}) err = dbmap.SelectOne(&p2, "select * from person_test where Fname='bob'") if err == nil { - t.Error("Expected nil when two rows found") + t.Error("Expected error when two rows found") + } + + // tests for #150 + var tInt int64 + var tStr string + var tBool bool + var tFloat float64 + primVals := []interface{}{tInt, tStr, tBool, tFloat} + for _, prim := range primVals { + err = dbmap.SelectOne(&prim, "select * from person_test where Id=-123") + if err == nil || err != sql.ErrNoRows { + t.Error("primVals: SelectOne should have returned sql.ErrNoRows") + } } } @@ -1542,9 +1662,9 @@ func initDbMapBench() *DbMap { func initDbMap() *DbMap { dbmap := newDbMap() dbmap.TraceOn("", log.New(os.Stdout, "gorptest: ", log.Lmicroseconds)) + dbmap.AddTableWithName(Person{}, "person_test").SetKeys(true, "Id") dbmap.AddTableWithName(Invoice{}, "invoice_test").SetKeys(true, "Id") dbmap.AddTableWithName(OverriddenInvoice{}, "invoice_override_test").SetKeys(false, "Id") - dbmap.AddTableWithName(Person{}, "person_test").SetKeys(true, "Id") dbmap.AddTableWithName(WithIgnoredColumn{}, "ignored_column_test").SetKeys(true, "Id") dbmap.AddTableWithName(TypeConversionExample{}, "type_conv_test").SetKeys(true, "Id") dbmap.AddTableWithName(WithEmbeddedStruct{}, "embedded_struct_test").SetKeys(true, "Id") @@ -1577,7 +1697,10 @@ func newDbMap() *DbMap { } func dropAndClose(dbmap *DbMap) { - dbmap.DropTablesIfExists() + err := dbmap.DropTablesIfExists() + if err != nil { + log.Println(err) + } dbmap.Db.Close() } @@ -1710,3 +1833,9 @@ func _rawselect(dbmap *DbMap, i interface{}, query string, args ...interface{}) } return list } + +func stringContainsIgnoreCase(value, lookingFor string) bool { + valueLC := strings.ToLower(value) + lookingForLC := strings.ToLower(lookingFor) + return strings.Contains(valueLC, lookingForLC) +} diff --git a/schema.go b/schema.go new file mode 100644 index 00000000..7aef30a1 --- /dev/null +++ b/schema.go @@ -0,0 +1,572 @@ +// Copyright 2012 James Cooper. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +// Package gorp provides a simple way to marshal Go structs to and from +// SQL databases. It uses the database/sql package, and should work with any +// compliant database/sql driver. +// +// Source code and project home: +// https://github.com/coopernurse/gorp +// +package gorp + +import ( + "bytes" + "database/sql" + "errors" + "fmt" + "reflect" + "strings" +) + +// DbMap is the root gorp mapping object. Create one of these for each +// database schema you wish to map. Each DbMap contains a list of +// mapped tables. +// +// Example: +// +// dialect := gorp.MySQLDialect{"InnoDB", "UTF8"} +// dbmap := &gorp.DbMap{Db: db, Dialect: dialect} +// +type DbMap struct { + // Db handle to use with this map + Db *sql.DB + + // Dialect implementation to use with this map + Dialect Dialect + + TypeConverter TypeConverter + + tables []*TableMap + logger GorpLogger + logPrefix string + initialised bool +} + +// AddTable registers the given interface type with gorp. The table name +// will be given the name of the TypeOf(i). You must call this function, +// or AddTableWithName, for any struct type you wish to persist with +// the given DbMap. +// +// This operation is idempotent. If i's type is already mapped, the +// existing *TableMap is returned +func (m *DbMap) AddTable(i interface{}) *TableMap { + return m.AddTableWithName(i, "") +} + +// AddTableWithName has the same behavior as AddTable, but sets +// table.TableName to name. +func (m *DbMap) AddTableWithName(i interface{}, name string) *TableMap { + return m.AddTableWithNameAndSchema(i, "", name) +} + +// AddTableWithNameAndSchema has the same behavior as AddTable, but sets +// table.TableName to name. +func (m *DbMap) AddTableWithNameAndSchema(i interface{}, schema string, name string) *TableMap { + t := reflect.TypeOf(i) + if name == "" { + name = t.Name() + } + + // check if we have a table for this type already + // if so, update the name and return the existing pointer + for i := range m.tables { + table := m.tables[i] + if table.gotype == t { + table.TableName = name + return table + } + } + + tmap := &TableMap{gotype: t, TableName: name, SchemaName: schema, dbmap: m} + tmap.columns, tmap.version = readStructColumns(t) + m.tables = append(m.tables, tmap) + + return tmap +} + +func readStructColumns(t reflect.Type) (cols []*ColumnMap, version *ColumnMap) { + n := t.NumField() + for i := 0; i < n; i++ { + f := t.Field(i) + if f.Anonymous && f.Type.Kind() == reflect.Struct { + // Recursively add nested fields in embedded structs. + subcols, subversion := readStructColumns(f.Type) + // Don't append nested fields that have the same field + // name as an already-mapped field. + for _, subcol := range subcols { + shouldAppend := true + for _, col := range cols { + if !subcol.Transient && subcol.fieldName == col.fieldName { + shouldAppend = false + break + } + } + if shouldAppend { + cols = append(cols, subcol) + } + } + if subversion != nil { + version = subversion + } + } else { + columnName := f.Tag.Get("db") + if columnName == "" { + columnName = f.Name + } + cm := &ColumnMap{ + ColumnName: columnName, + Transient: columnName == "-", + fieldName: f.Name, + gotype: f.Type, + } + // Check for nested fields of the same field name and + // override them. + shouldAppend := true + for index, col := range cols { + if !col.Transient && col.fieldName == cm.fieldName { + cols[index] = cm + shouldAppend = false + break + } + } + if shouldAppend { + cols = append(cols, cm) + } + if cm.fieldName == "Version" { + version = cm + } + } + } + return +} + +// CreateTables iterates through TableMaps registered to this DbMap and +// executes "create table" statements against the database for each. +// +// This is particularly useful in unit tests where you want to create +// and destroy the schema automatically. +func (m *DbMap) CreateTables() error { + return m.createTables(false) +} + +// CreateTablesIfNotExists is similar to CreateTables, but starts +// each statement with "create table if not exists" so that existing +// tables do not raise errors +func (m *DbMap) CreateTablesIfNotExists() error { + return m.createTables(true) +} + +func (m *DbMap) createTables(ifNotExists bool) error { + for _, t := range m.tables { + ddl := m.createOneTableSql(ifNotExists, t) + _, err := m.Exec(ddl) + if err != nil { + return err + } + } + return nil +} + +func (m *DbMap) createOneTableSql(ifNotExists bool, table *TableMap) string { + s := bytes.Buffer{} + + if strings.TrimSpace(table.SchemaName) != "" { + s.WriteString("create schema ") + if ifNotExists { + s.WriteString("if not exists ") + } + + s.WriteString(table.SchemaName) + s.WriteString(";") + } + + s.WriteString("create table ") + if ifNotExists { + s.WriteString("if not exists ") + } + tableName := m.Dialect.QuotedTableForQuery(table.SchemaName, table.TableName) + s.WriteString(fmt.Sprintf("%s (\n ", tableName)) + + s = m.createOneTableColumns(table, s) + s = m.createOneTablePrimaryKeys(table, s) + s = m.createOneTableIndexes(table, s) + s = m.createOneTableForeignKeys(table, s) + + s.WriteString("\n) ") + s.WriteString(m.Dialect.CreateTableSuffix()) + s.WriteString(";") + return s.String() +} + +func (m *DbMap) createOneTableColumns(table *TableMap, s bytes.Buffer) bytes.Buffer { + x := 0 + for _, col := range table.columns { + if !col.Transient { + if x > 0 { + s.WriteString(",\n ") + } + field := m.Dialect.QuoteField(col.ColumnName) + stype := m.Dialect.ToSqlType(col.gotype, col.MaxSize, col.isAutoIncr) + s.WriteString(fmt.Sprintf("%s %s", field, stype)) + + if col.isPK || col.isNotNull { + s.WriteString(" not null") + } + if col.isPK && len(table.keys) == 1 { + s.WriteString(" primary key") + } + if col.Unique { + s.WriteString(" unique") + } + if col.isAutoIncr { + s.WriteString(" " + m.Dialect.AutoIncrStr()) + } + if col.References != nil { + s.WriteString(m.Dialect.CreateForeignKeySuffix(col.References)) + } + + x++ + } + } + return s +} + +func (m *DbMap) createOneTablePrimaryKeys(table *TableMap, s bytes.Buffer) bytes.Buffer { + if len(table.keys) > 1 { + s.WriteString(",\n primary key (") + for x := range table.keys { + if x > 0 { + s.WriteString(", ") + } + s.WriteString(m.Dialect.QuoteField(table.keys[x].ColumnName)) + } + s.WriteString(")") + } + return s +} + +func (m *DbMap) createOneTableIndexes(table *TableMap, s bytes.Buffer) bytes.Buffer { + if len(table.uniqueTogether) > 0 { + for _, columns := range table.uniqueTogether { + s.WriteString(",\n unique (") + for i, column := range columns { + if i > 0 { + s.WriteString(", ") + } + s.WriteString(m.Dialect.QuoteField(column)) + } + s.WriteString(")") + } + } + return s +} + +func (m *DbMap) createOneTableForeignKeys(table *TableMap, s bytes.Buffer) bytes.Buffer { + for _, col := range table.columns { + if !col.Transient && col.References != nil { + fkBlock := m.Dialect.CreateForeignKeyBlock(col) + if fkBlock != "" { + s.WriteString(",\n ") + s.WriteString(fkBlock) + } + } + } + return s +} + +// DropTable drops an individual table. Will throw an error +// if the table does not exist. +func (m *DbMap) DropTable(table interface{}) error { + t := reflect.TypeOf(table) + return m.dropTable(t, false) +} + +// DropTable drops an individual table. Will NOT throw an error +// if the table does not exist. +func (m *DbMap) DropTableIfExists(table interface{}) error { + t := reflect.TypeOf(table) + return m.dropTable(t, true) +} + +// DropTables iterates through TableMaps registered to this DbMap and +// executes "drop table" statements against the database for each. +func (m *DbMap) DropTables() error { + return m.dropTables(false) +} + +// DropTablesIfExists is the same as DropTables, but uses the "if exists" clause to +// avoid errors for tables that do not exist. +func (m *DbMap) DropTablesIfExists() error { + return m.dropTables(true) +} + +// Goes through all the registered tables, dropping them one by one. +// If an error is encountered, then it is returned and the rest of +// the tables are not dropped. +func (m *DbMap) dropTables(addIfExists bool) (err error) { + // drop in reverse order, assuming that foreign keys were created in + // the order that the tables were created + n := len(m.tables) - 1 + for i, _ := range m.tables { + table := m.tables[n - i] + err = m.dropTableImpl(table, addIfExists) + if err != nil { + return + } + } + return +} + +// Implementation of dropping a single table. +func (m *DbMap) dropTable(t reflect.Type, addIfExists bool) error { + table := tableOrNil(m, t) + if table == nil { + return errors.New(fmt.Sprintf("table %s was not registered!", table.TableName)) + } + + return m.dropTableImpl(table, addIfExists) +} + +func (m *DbMap) dropTableImpl(table *TableMap, addIfExists bool) (err error) { + ifExists := "" + if addIfExists { + ifExists = " if exists" + } + _, err = m.Exec(fmt.Sprintf("drop table%s %s;", ifExists, m.Dialect.QuotedTableForQuery(table.SchemaName, table.TableName))) + return err +} + +// TableMap represents a mapping between a Go struct and a database table +// Use dbmap.AddTable() or dbmap.AddTableWithName() to create these +type TableMap struct { + // Name of database table. + TableName string + SchemaName string + gotype reflect.Type + columns []*ColumnMap + keys []*ColumnMap + uniqueTogether [][]string + version *ColumnMap + insertPlan bindPlan + updatePlan bindPlan + deletePlan bindPlan + getPlan bindPlan + dbmap *DbMap +} + +// ResetSql removes cached insert/update/select/delete SQL strings +// associated with this TableMap. Call this if you've modified +// any column names or the table name itself. +func (t *TableMap) ResetSql() { + t.insertPlan = bindPlan{} + t.updatePlan = bindPlan{} + t.deletePlan = bindPlan{} + t.getPlan = bindPlan{} +} + +// SetKeys lets you specify the fields on a struct that map to primary +// key columns on the table. If isAutoIncr is set, result.LastInsertId() +// will be used after INSERT to bind the generated id to the Go struct. +// +// Automatically calls ResetSql() to ensure SQL statements are regenerated. +// +// Panics if isAutoIncr is true, and fieldNames length != 1 +// +func (t *TableMap) SetKeys(isAutoIncr bool, fieldNames ...string) *TableMap { + if isAutoIncr && len(fieldNames) != 1 { + panic(fmt.Sprintf( + "gorp: SetKeys: fieldNames length must be 1 if key is auto-increment. (Saw %v fieldNames)", + len(fieldNames))) + } + t.keys = make([]*ColumnMap, 0) + for _, name := range fieldNames { + colmap := t.ColMap(name) + colmap.isPK = true + colmap.isAutoIncr = isAutoIncr + t.keys = append(t.keys, colmap) + } + t.ResetSql() + + return t +} + +// SetUniqueTogether lets you specify uniqueness constraints across multiple +// columns on the table. Each call adds an additional constraint for the +// specified columns. +// +// Automatically calls ResetSql() to ensure SQL statements are regenerated. +// +// Panics if fieldNames length < 2. +// +func (t *TableMap) SetUniqueTogether(fieldNames ...string) *TableMap { + if len(fieldNames) < 2 { + panic(fmt.Sprintf( + "gorp: SetUniqueTogether: must provide at least two fieldNames to set uniqueness constraint.")) + } + + columns := make([]string, 0) + for _, name := range fieldNames { + columns = append(columns, name) + } + t.uniqueTogether = append(t.uniqueTogether, columns) + t.ResetSql() + + return t +} + +// ColMap returns the ColumnMap pointer matching the given struct field +// name. It panics if the struct does not contain a field matching this +// name. +func (t *TableMap) ColMap(field string) *ColumnMap { + col := colMapOrNil(t, field) + if col == nil { + e := fmt.Sprintf("No ColumnMap in table %s type %s with field %s", + t.TableName, t.gotype.Name(), field) + + panic(e) + } + return col +} + +func colMapOrNil(t *TableMap, field string) *ColumnMap { + for _, col := range t.columns { + if col.fieldName == field || col.ColumnName == field { + return col + } + } + return nil +} + +// SetVersionCol sets the column to use as the Version field. By default +// the "Version" field is used. Returns the column found, or panics +// if the struct does not contain a field matching this name. +// +// Automatically calls ResetSql() to ensure SQL statements are regenerated. +func (t *TableMap) SetVersionCol(field string) *ColumnMap { + c := t.ColMap(field) + t.version = c + t.ResetSql() + return c +} + +// ColumnMap represents a mapping between a Go struct field and a single +// column in a table. +// Unique and MaxSize only inform the +// CreateTables() function and are not used by Insert/Update/Delete/Get. +type ColumnMap struct { + // Column name in db table + ColumnName string + + // If true, this column is skipped in generated SQL statements + Transient bool + + // If true, " unique" is added to create table statements. + // Not used elsewhere + Unique bool + + // Passed to Dialect.ToSqlType() to assist in informing the + // correct column type to map to in CreateTables() + // Not used elsewhere + MaxSize int + + // If present, specifies that this column is a foreign key that + // references another column of another table. + References *ForeignKey + + fieldName string + gotype reflect.Type + isPK bool + isAutoIncr bool + isNotNull bool +} + +// Rename allows you to specify the column name in the table +// +// Example: table.ColMap("Updated").Rename("date_updated") +// +func (c *ColumnMap) Rename(colname string) *ColumnMap { + c.ColumnName = colname + return c +} + +// SetTransient allows you to mark the column as transient. If true +// this column will be skipped when SQL statements are generated +func (c *ColumnMap) SetTransient(b bool) *ColumnMap { + c.Transient = b + return c +} + +// SetUnique adds "unique" to the create table statements for this +// column, if b is true. +func (c *ColumnMap) SetUnique(b bool) *ColumnMap { + c.Unique = b + return c +} + +// SetNotNull adds "not null" to the create table statements for this +// column, if nn is true. +func (c *ColumnMap) SetNotNull(nn bool) *ColumnMap { + c.isNotNull = nn + return c +} + +// SetMaxSize specifies the max length of values of this column. This is +// passed to the dialect.ToSqlType() function, which can use the value +// to alter the generated type for "create table" statements +func (c *ColumnMap) SetMaxSize(size int) *ColumnMap { + c.MaxSize = size + return c +} + +// SetForeignKey specifies the foreign-key relationship between this column +// and a column in another table. +func (c *ColumnMap) SetForeignKey(fk *ForeignKey) *ColumnMap { + c.References = fk + return c +} + +// Specifies what foreign-key constraints will be enforced by the database. +type FKOnChangeAction int + +const ( + Unspecified FKOnChangeAction = iota + NoAction + Restrict + Cascade + SetNull + //SetDefault // may not be supported by MySql + Delete +) + +// ForeignKey specifies the relationship formed when one column refers to the +// primary key of another table. +type ForeignKey struct { + ReferencedTable string + ReferencedColumn string + ActionOnDelete FKOnChangeAction + ActionOnUpdate FKOnChangeAction +} + +// NewForeignKey creates a new ForeignKey for a specified table/column reference. If +// the table is part of a named schema, include the schema prefix in the referencedTable +// value. +func NewForeignKey(referencedTable, referencedColumn string) *ForeignKey { + return &ForeignKey{referencedTable, referencedColumn, Unspecified, Unspecified} +} + +// Sets the action that the database is to perform when the parent record +// is updated. The default is usually Restrict. +func (fk *ForeignKey) OnUpdate(action FKOnChangeAction) *ForeignKey { + fk.ActionOnUpdate = action + return fk +} + +// Sets the action that the database is to perform when the parent record +// is deleted. The default is usually Restrict. +func (fk *ForeignKey) OnDelete(action FKOnChangeAction) *ForeignKey { + fk.ActionOnDelete = action + return fk +} + diff --git a/test_all.sh b/test_all.sh index f870b39a..edabf7a2 100755 --- a/test_all.sh +++ b/test_all.sh @@ -3,20 +3,40 @@ # on macs, you may need to: # export GOBUILDFLAG=-ldflags -linkmode=external -set -e +# Using "-n", the environment variables will be set but tests will not run. +# You can "dot" this script, then use go test directly, which will give verbose progress info. +if [ "$1" = "-n" ]; then + NOOP=":" + shift +fi -export GORP_TEST_DSN=gorptest/gorptest/gorptest -export GORP_TEST_DIALECT=mysql -go test $GOBUILDFLAG . +set -e -export GORP_TEST_DSN=gorptest:gorptest@/gorptest -export GORP_TEST_DIALECT=gomysql -go test $GOBUILDFLAG . +testMysql() { + export GORP_TEST_DSN=gorptest/gorptest/gorptest + export GORP_TEST_DIALECT=mysql + $NOOP go test $GOBUILDFLAG . -export GORP_TEST_DSN="user=gorptest password=gorptest dbname=gorptest sslmode=disable" -export GORP_TEST_DIALECT=postgres -go test $GOBUILDFLAG . + export GORP_TEST_DSN=gorptest:gorptest@/gorptest + export GORP_TEST_DIALECT=gomysql + $NOOP go test $GOBUILDFLAG . +} -export GORP_TEST_DSN=/tmp/gorptest.bin -export GORP_TEST_DIALECT=sqlite -go test $GOBUILDFLAG . +testPostgresql() { + export GORP_TEST_DSN="user=gorptest password=gorptest dbname=gorptest sslmode=disable" + export GORP_TEST_DIALECT=postgres + $NOOP go test $GOBUILDFLAG . +} + +testSqlite() { + export GORP_TEST_DSN=/tmp/gorptest.bin + export GORP_TEST_DIALECT=sqlite + $NOOP go test $GOBUILDFLAG . +} + +case "$1" in + mysql) testMysql ;; + psql | postgresql) testPostgresql ;; + sqlite) testSqlite ;; + *) testMysql ; testPostgresql ; testSqlite ;; +esac