diff --git a/README.md b/README.md
index e7bb424b..7aa537d5 100644
--- a/README.md
+++ b/README.md
@@ -226,9 +226,13 @@ dbmap := &gorp.DbMap{Db: db, Dialect: gorp.MySQLDialect{"InnoDB", "UTF8"}}
// SetKeys(true) means we have a auto increment primary key, which
// will get automatically bound to your struct post-insert
//
-t1 := dbmap.AddTableWithName(Invoice{}, "invoice_test").SetKeys(true, "Id")
-t2 := dbmap.AddTableWithName(Person{}, "person_test").SetKeys(true, "Id")
+t1 := dbmap.AddTableWithName(Person{}, "person_test").SetKeys(true, "Id")
+t2 := dbmap.AddTableWithName(Invoice{}, "invoice_test").SetKeys(true, "Id")
t3 := dbmap.AddTableWithName(Product{}, "product_test").SetKeys(true, "Id")
+
+// SetForeignKey will declare that Invoice.PersonId is a foreign key for
+// Person.Id, and delete/update actions are to be treated as specified.
+t2.ColMap("PersonId").SetForeignKey(gorp.NewForeignKey("Person", "Id").OnDelete(gorp.Restrict).OnUpdate(gorp.Cascade))
```
### Struct Embedding ###
@@ -450,6 +454,15 @@ func InsertInv(dbmap *DbMap, inv *Invoice, per *Person) error {
}
```
+### Foreign Keys ###
+
+You can define the foreign-key relationships when you create the table schema.
+The `ColumnMap` has a `SetForeignKey` method to do this, shown in Examples above.
+
+Gorp only uses this when `CreateTables` or `CreateTablesIfNotExists` are invoked.
+Thereafter, Gorp plays no further part because the database itself enforces the
+consistency of keys, returning an error any time that a constraint would be violated.
+
### Hooks ###
Use hooks to update data before/after saving to the db. Good for timestamps:
@@ -644,3 +657,4 @@ Thanks!
* matthias-margush - column aliasing via tags
* Rob Figueiredo - @robfig
* Quinn Slack - @sqs
+* Rick Beton - @rickb777 - foreign keys
diff --git a/dbmap.go b/dbmap.go
new file mode 100644
index 00000000..8af9a53b
--- /dev/null
+++ b/dbmap.go
@@ -0,0 +1,280 @@
+// Copyright 2012 James Cooper. All rights reserved.
+// Use of this source code is governed by a MIT-style
+// license that can be found in the LICENSE file.
+
+// Package gorp provides a simple way to marshal Go structs to and from
+// SQL databases. It uses the database/sql package, and should work with any
+// compliant database/sql driver.
+//
+// Source code and project home:
+// https://github.com/coopernurse/gorp
+//
+package gorp
+
+import (
+ "database/sql"
+ "errors"
+ "fmt"
+ "reflect"
+)
+
+// TraceOn turns on SQL statement logging for this DbMap. After this is
+// called, all SQL statements will be sent to the logger. If prefix is
+// a non-empty string, it will be written to the front of all logged
+// strings, which can aid in filtering log lines.
+//
+// Use TraceOn if you want to spy on the SQL statements that gorp
+// generates.
+//
+// Note that the base log.Logger type satisfies GorpLogger, but adapters can
+// easily be written for other logging packages (e.g., the golang-sanctioned
+// glog framework).
+func (m *DbMap) TraceOn(prefix string, logger GorpLogger) {
+ m.logger = logger
+ if prefix == "" {
+ m.logPrefix = prefix
+ } else {
+ m.logPrefix = fmt.Sprintf("%s ", prefix)
+ }
+}
+
+// TraceOff turns off tracing. It is idempotent.
+func (m *DbMap) TraceOff() {
+ m.logger = nil
+ m.logPrefix = ""
+}
+
+// TruncateTables iterates through TableMaps registered to this DbMap and
+// executes "truncate table" statements against the database for each, or in the case of
+// sqlite, a "delete from" with no "where" clause, which uses the truncate optimization
+// (http://www.sqlite.org/lang_delete.html)
+func (m *DbMap) TruncateTables() error {
+ var err error
+ for i := range m.tables {
+ table := m.tables[i]
+ _, e := m.Exec(fmt.Sprintf("%s %s;", m.Dialect.TruncateClause(), m.Dialect.QuotedTableForQuery(table.SchemaName, table.TableName)))
+ if e != nil {
+ err = e
+ }
+ }
+ return err
+}
+
+// Insert runs a SQL INSERT statement for each element in list. List
+// items must be pointers.
+//
+// Any interface whose TableMap has an auto-increment primary key will
+// have its last insert id bound to the PK field on the struct.
+//
+// The hook functions PreInsert() and/or PostInsert() will be executed
+// before/after the INSERT statement if the interface defines them.
+//
+// Panics if any interface in the list has not been registered with AddTable
+func (m *DbMap) Insert(list ...interface{}) error {
+ return insert(m, m, list...)
+}
+
+// Update runs a SQL UPDATE statement for each element in list. List
+// items must be pointers.
+//
+// The hook functions PreUpdate() and/or PostUpdate() will be executed
+// before/after the UPDATE statement if the interface defines them.
+//
+// Returns the number of rows updated.
+//
+// Returns an error if SetKeys has not been called on the TableMap
+// Panics if any interface in the list has not been registered with AddTable
+func (m *DbMap) Update(list ...interface{}) (int64, error) {
+ return update(m, m, list...)
+}
+
+// Delete runs a SQL DELETE statement for each element in list. List
+// items must be pointers.
+//
+// The hook functions PreDelete() and/or PostDelete() will be executed
+// before/after the DELETE statement if the interface defines them.
+//
+// Returns the number of rows deleted.
+//
+// Returns an error if SetKeys has not been called on the TableMap
+// Panics if any interface in the list has not been registered with AddTable
+func (m *DbMap) Delete(list ...interface{}) (int64, error) {
+ return delete(m, m, list...)
+}
+
+// Get runs a SQL SELECT to fetch a single row from the table based on the
+// primary key(s)
+//
+// i should be an empty value for the struct to load. keys should be
+// the primary key value(s) for the row to load. If multiple keys
+// exist on the table, the order should match the column order
+// specified in SetKeys() when the table mapping was defined.
+//
+// The hook function PostGet() will be executed after the SELECT
+// statement if the interface defines them.
+//
+// Returns a pointer to a struct that matches or nil if no row is found.
+//
+// Returns an error if SetKeys has not been called on the TableMap
+// Panics if any interface in the list has not been registered with AddTable
+func (m *DbMap) Get(i interface{}, keys ...interface{}) (interface{}, error) {
+ return get(m, m, i, keys...)
+}
+
+// Select runs an arbitrary SQL query, binding the columns in the result
+// to fields on the struct specified by i. args represent the bind
+// parameters for the SQL statement.
+//
+// Column names on the SELECT statement should be aliased to the field names
+// on the struct i. Returns an error if one or more columns in the result
+// do not match. It is OK if fields on i are not part of the SQL
+// statement.
+//
+// The hook function PostGet() will be executed after the SELECT
+// statement if the interface defines them.
+//
+// Values are returned in one of two ways:
+// 1. If i is a struct or a pointer to a struct, returns a slice of pointers to
+// matching rows of type i.
+// 2. If i is a pointer to a slice, the results will be appended to that slice
+// and nil returned.
+//
+// i does NOT need to be registered with AddTable()
+func (m *DbMap) Select(i interface{}, query string, args ...interface{}) ([]interface{}, error) {
+ return hookedselect(m, m, i, query, args...)
+}
+
+// Exec runs an arbitrary SQL statement. args represent the bind parameters.
+// This is equivalent to running: Exec() using database/sql
+func (m *DbMap) Exec(query string, args ...interface{}) (sql.Result, error) {
+ err := m.initialise()
+ if err != nil {
+ return nil, err
+ }
+ m.trace(query, args...)
+ return m.Db.Exec(query, args...)
+}
+
+// SelectInt is a convenience wrapper around the gorp.SelectInt function
+func (m *DbMap) SelectInt(query string, args ...interface{}) (int64, error) {
+ return SelectInt(m, query, args...)
+}
+
+// SelectNullInt is a convenience wrapper around the gorp.SelectNullInt function
+func (m *DbMap) SelectNullInt(query string, args ...interface{}) (sql.NullInt64, error) {
+ return SelectNullInt(m, query, args...)
+}
+
+// SelectFloat is a convenience wrapper around the gorp.SelectFlot function
+func (m *DbMap) SelectFloat(query string, args ...interface{}) (float64, error) {
+ return SelectFloat(m, query, args...)
+}
+
+// SelectNullFloat is a convenience wrapper around the gorp.SelectNullFloat function
+func (m *DbMap) SelectNullFloat(query string, args ...interface{}) (sql.NullFloat64, error) {
+ return SelectNullFloat(m, query, args...)
+}
+
+// SelectStr is a convenience wrapper around the gorp.SelectStr function
+func (m *DbMap) SelectStr(query string, args ...interface{}) (string, error) {
+ return SelectStr(m, query, args...)
+}
+
+// SelectNullStr is a convenience wrapper around the gorp.SelectNullStr function
+func (m *DbMap) SelectNullStr(query string, args ...interface{}) (sql.NullString, error) {
+ return SelectNullStr(m, query, args...)
+}
+
+// SelectOne is a convenience wrapper around the gorp.SelectOne function
+func (m *DbMap) SelectOne(holder interface{}, query string, args ...interface{}) error {
+ return SelectOne(m, m, holder, query, args...)
+}
+
+// Begin starts a gorp Transaction
+func (m *DbMap) Begin() (*Transaction, error) {
+ m.trace("begin;")
+ tx, err := m.Db.Begin()
+ if err != nil {
+ return nil, err
+ }
+ return &Transaction{m, tx, false}, nil
+}
+
+func (m *DbMap) tableFor(t reflect.Type, checkPK bool) (*TableMap, error) {
+ table := tableOrNil(m, t)
+ if table == nil {
+ return nil, errors.New(fmt.Sprintf("No table found for type: %v", t.Name()))
+ }
+
+ if checkPK && len(table.keys) < 1 {
+ e := fmt.Sprintf("gorp: No keys defined for table: %s",
+ table.TableName)
+ return nil, errors.New(e)
+ }
+
+ return table, nil
+}
+
+func tableOrNil(m *DbMap, t reflect.Type) *TableMap {
+ for i := range m.tables {
+ table := m.tables[i]
+ if table.gotype == t {
+ return table
+ }
+ }
+ return nil
+}
+
+func (m *DbMap) tableForPointer(ptr interface{}, checkPK bool) (*TableMap, reflect.Value, error) {
+ ptrv := reflect.ValueOf(ptr)
+ if ptrv.Kind() != reflect.Ptr {
+ e := fmt.Sprintf("gorp: passed non-pointer: %v (kind=%v)", ptr,
+ ptrv.Kind())
+ return nil, reflect.Value{}, errors.New(e)
+ }
+ elem := ptrv.Elem()
+ etype := reflect.TypeOf(elem.Interface())
+ t, err := m.tableFor(etype, checkPK)
+ if err != nil {
+ return nil, reflect.Value{}, err
+ }
+
+ return t, elem, nil
+}
+
+func (m *DbMap) QueryRow(query string, args ...interface{}) *sql.Row {
+ err := m.initialise()
+ if err != nil {
+ panic(err)
+ }
+ m.initialise()
+ m.trace(query, args...)
+ return m.Db.QueryRow(query, args...)
+}
+
+func (m *DbMap) Query(query string, args ...interface{}) (*sql.Rows, error) {
+ err := m.initialise()
+ if err != nil {
+ return nil, err
+ }
+ m.trace(query, args...)
+ return m.Db.Query(query, args...)
+}
+
+func (m *DbMap) trace(query string, args ...interface{}) {
+ if m.logger != nil {
+ m.logger.Printf("%s%s %v", m.logPrefix, query, args)
+ }
+}
+
+func (m *DbMap) initialise() (err error) {
+ if !m.initialised {
+ m.initialised = true
+ if m.Dialect.InitString() != "" {
+ m.trace(m.Dialect.InitString())
+ _, err = m.Db.Exec(m.Dialect.InitString())
+ }
+ }
+ return
+}
+
diff --git a/dialect.go b/dialect.go
index 6b6ef0e8..ccccf58b 100644
--- a/dialect.go
+++ b/dialect.go
@@ -25,6 +25,12 @@ type Dialect interface {
AutoIncrInsertSuffix(col *ColumnMap) string
+ // Creates the trailing foreign key reference in a column specification.
+ CreateForeignKeySuffix(references *ForeignKey) string
+
+ // Creates the separate foreign key reference for a column.
+ CreateForeignKeyBlock(col *ColumnMap) string
+
// string to append to "create table" statement for vendor specific
// table attributes
CreateTableSuffix() string
@@ -51,6 +57,11 @@ type Dialect interface {
// schema - The schema that
lives in
// table - The table name
QuotedTableForQuery(schema string, table string) string
+
+ // Sends an initialisation instruction when connecting to the database.
+ // Primarily, this exists for Sqlite3 because foreign keys are disable
+ // by default, unlike Postgresql and Mysql InnoDB.
+ InitString() string
}
func standardInsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) {
@@ -61,6 +72,19 @@ func standardInsertAutoIncr(exec SqlExecutor, insertSql string, params ...interf
return res.LastInsertId()
}
+func standardOnChangeStr(change string, action FKOnChangeAction) string {
+ prefix := "\n "
+ switch action {
+ case Unspecified: return ""
+ case NoAction: return prefix + "on " + change + " no action"
+ case Restrict: return prefix + "on " + change + " restrict"
+ case Cascade: return prefix + "on " + change + " cascade"
+ case SetNull: return prefix + "on " + change + " set null"
+ case Delete: return prefix + "on " + change + " delete"
+ }
+ return ""
+}
+
///////////////////////////////////////////////////////
// sqlite3 //
/////////////
@@ -115,6 +139,19 @@ func (d SqliteDialect) AutoIncrInsertSuffix(col *ColumnMap) string {
return ""
}
+func (d SqliteDialect) CreateForeignKeySuffix(references *ForeignKey) string {
+ return ""
+}
+
+func (d SqliteDialect) CreateForeignKeyBlock(col *ColumnMap) string {
+ return fmt.Sprintf("foreign key (%s) references %s (%s)",
+ d.QuoteField(col.ColumnName),
+ d.QuoteField(col.References.ReferencedTable),
+ d.QuoteField(col.References.ReferencedColumn)) +
+ standardOnChangeStr("update", col.References.ActionOnUpdate) +
+ standardOnChangeStr("delete", col.References.ActionOnDelete)
+}
+
// Returns suffix
func (d SqliteDialect) CreateTableSuffix() string {
return d.suffix
@@ -145,6 +182,11 @@ func (d SqliteDialect) QuotedTableForQuery(schema string, table string) string {
return d.QuoteField(table)
}
+// sqlite3 has foreign keys disabled by default (will be enabled in sqlite4).
+func (d SqliteDialect) InitString() string {
+ return "pragma foreign_keys = ON;"
+}
+
///////////////////////////////////////////////////////
// PostgreSQL //
////////////////
@@ -211,6 +253,18 @@ func (d PostgresDialect) AutoIncrInsertSuffix(col *ColumnMap) string {
return " returning " + col.ColumnName
}
+func (d PostgresDialect) CreateForeignKeySuffix(references *ForeignKey) string {
+ refTable := d.QuotedTableForQuery("", references.ReferencedTable)
+ refField := d.QuoteField(references.ReferencedColumn)
+ return fmt.Sprintf(" references %s (%s)%s%s", refTable, refField,
+ standardOnChangeStr("delete", references.ActionOnDelete),
+ standardOnChangeStr("update", references.ActionOnUpdate))
+}
+
+func (d PostgresDialect) CreateForeignKeyBlock(col *ColumnMap) string {
+ return ""
+}
+
// Returns suffix
func (d PostgresDialect) CreateTableSuffix() string {
return d.suffix
@@ -226,7 +280,7 @@ func (d PostgresDialect) BindVar(i int) string {
}
func (d PostgresDialect) InsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) {
- rows, err := exec.query(insertSql, params...)
+ rows, err := exec.Query(insertSql, params...)
if err != nil {
return 0, err
}
@@ -238,7 +292,7 @@ func (d PostgresDialect) InsertAutoIncr(exec SqlExecutor, insertSql string, para
return id, err
}
- return 0, errors.New("No serial value returned for insert: " + insertSql + " Encountered error: " + rows.Err().Error())
+ return 0, errors.New("No serial value returned for insert: "+insertSql+" Encountered error: "+rows.Err().Error())
}
func (d PostgresDialect) QuoteField(f string) string {
@@ -253,6 +307,10 @@ func (d PostgresDialect) QuotedTableForQuery(schema string, table string) string
return schema + "." + d.QuoteField(table)
}
+func (d PostgresDialect) InitString() string {
+ return ""
+}
+
///////////////////////////////////////////////////////
// MySQL //
///////////
@@ -267,10 +325,10 @@ type MySQLDialect struct {
Encoding string
}
-func (m MySQLDialect) ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string {
+func (d MySQLDialect) ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string {
switch val.Kind() {
case reflect.Ptr:
- return m.ToSqlType(val.Elem(), maxsize, isAutoIncr)
+ return d.ToSqlType(val.Elem(), maxsize, isAutoIncr)
case reflect.Bool:
return "boolean"
case reflect.Int8:
@@ -315,49 +373,62 @@ func (m MySQLDialect) ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool)
}
// Returns auto_increment
-func (m MySQLDialect) AutoIncrStr() string {
+func (d MySQLDialect) AutoIncrStr() string {
return "auto_increment"
}
-func (m MySQLDialect) AutoIncrBindValue() string {
+func (d MySQLDialect) AutoIncrBindValue() string {
return "null"
}
-func (m MySQLDialect) AutoIncrInsertSuffix(col *ColumnMap) string {
+func (d MySQLDialect) AutoIncrInsertSuffix(col *ColumnMap) string {
return ""
}
+func (d MySQLDialect) CreateForeignKeySuffix(references *ForeignKey) string {
+ return ""
+}
+
+func (d MySQLDialect) CreateForeignKeyBlock(col *ColumnMap) string {
+ return fmt.Sprintf("foreign key (%s) references %s (%s)",
+ d.QuoteField(col.ColumnName),
+ d.QuoteField(col.References.ReferencedTable),
+ d.QuoteField(col.References.ReferencedColumn)) +
+ standardOnChangeStr("update", col.References.ActionOnUpdate) +
+ standardOnChangeStr("delete", col.References.ActionOnDelete)
+}
+
// Returns engine=%s charset=%s based on values stored on struct
-func (m MySQLDialect) CreateTableSuffix() string {
- if m.Engine == "" || m.Encoding == "" {
+func (d MySQLDialect) CreateTableSuffix() string {
+ if d.Engine == "" || d.Encoding == "" {
msg := "gorp - undefined"
- if m.Engine == "" {
+ if d.Engine == "" {
msg += " MySQLDialect.Engine"
}
- if m.Engine == "" && m.Encoding == "" {
+ if d.Engine == "" && d.Encoding == "" {
msg += ","
}
- if m.Encoding == "" {
+ if d.Encoding == "" {
msg += " MySQLDialect.Encoding"
}
msg += ". Check that your MySQLDialect was correctly initialized when declared."
panic(msg)
}
- return fmt.Sprintf(" engine=%s charset=%s", m.Engine, m.Encoding)
+ return fmt.Sprintf(" engine=%s charset=%s", d.Engine, d.Encoding)
}
-func (m MySQLDialect) TruncateClause() string {
+func (d MySQLDialect) TruncateClause() string {
return "truncate"
}
// Returns "?"
-func (m MySQLDialect) BindVar(i int) string {
+func (d MySQLDialect) BindVar(i int) string {
return "?"
}
-func (m MySQLDialect) InsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) {
+func (d MySQLDialect) InsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) {
return standardInsertAutoIncr(exec, insertSql, params...)
}
@@ -369,3 +440,7 @@ func (d MySQLDialect) QuoteField(f string) string {
func (d MySQLDialect) QuotedTableForQuery(schema string, table string) string {
return d.QuoteField(table)
}
+
+func (d MySQLDialect) InitString() string {
+ return ""
+}
diff --git a/gorp.go b/gorp.go
index 5ba0ba37..a249d03f 100644
--- a/gorp.go
+++ b/gorp.go
@@ -14,7 +14,6 @@ package gorp
import (
"bytes"
"database/sql"
- "errors"
"fmt"
"reflect"
"regexp"
@@ -93,142 +92,6 @@ func (me CustomScanner) Bind() error {
return me.Binder(me.Holder, me.Target)
}
-// DbMap is the root gorp mapping object. Create one of these for each
-// database schema you wish to map. Each DbMap contains a list of
-// mapped tables.
-//
-// Example:
-//
-// dialect := gorp.MySQLDialect{"InnoDB", "UTF8"}
-// dbmap := &gorp.DbMap{Db: db, Dialect: dialect}
-//
-type DbMap struct {
- // Db handle to use with this map
- Db *sql.DB
-
- // Dialect implementation to use with this map
- Dialect Dialect
-
- TypeConverter TypeConverter
-
- tables []*TableMap
- logger GorpLogger
- logPrefix string
-}
-
-// TableMap represents a mapping between a Go struct and a database table
-// Use dbmap.AddTable() or dbmap.AddTableWithName() to create these
-type TableMap struct {
- // Name of database table.
- TableName string
- SchemaName string
- gotype reflect.Type
- columns []*ColumnMap
- keys []*ColumnMap
- uniqueTogether [][]string
- version *ColumnMap
- insertPlan bindPlan
- updatePlan bindPlan
- deletePlan bindPlan
- getPlan bindPlan
- dbmap *DbMap
-}
-
-// ResetSql removes cached insert/update/select/delete SQL strings
-// associated with this TableMap. Call this if you've modified
-// any column names or the table name itself.
-func (t *TableMap) ResetSql() {
- t.insertPlan = bindPlan{}
- t.updatePlan = bindPlan{}
- t.deletePlan = bindPlan{}
- t.getPlan = bindPlan{}
-}
-
-// SetKeys lets you specify the fields on a struct that map to primary
-// key columns on the table. If isAutoIncr is set, result.LastInsertId()
-// will be used after INSERT to bind the generated id to the Go struct.
-//
-// Automatically calls ResetSql() to ensure SQL statements are regenerated.
-//
-// Panics if isAutoIncr is true, and fieldNames length != 1
-//
-func (t *TableMap) SetKeys(isAutoIncr bool, fieldNames ...string) *TableMap {
- if isAutoIncr && len(fieldNames) != 1 {
- panic(fmt.Sprintf(
- "gorp: SetKeys: fieldNames length must be 1 if key is auto-increment. (Saw %v fieldNames)",
- len(fieldNames)))
- }
- t.keys = make([]*ColumnMap, 0)
- for _, name := range fieldNames {
- colmap := t.ColMap(name)
- colmap.isPK = true
- colmap.isAutoIncr = isAutoIncr
- t.keys = append(t.keys, colmap)
- }
- t.ResetSql()
-
- return t
-}
-
-// SetUniqueTogether lets you specify uniqueness constraints across multiple
-// columns on the table. Each call adds an additional constraint for the
-// specified columns.
-//
-// Automatically calls ResetSql() to ensure SQL statements are regenerated.
-//
-// Panics if fieldNames length < 2.
-//
-func (t *TableMap) SetUniqueTogether(fieldNames ...string) *TableMap {
- if len(fieldNames) < 2 {
- panic(fmt.Sprintf(
- "gorp: SetUniqueTogether: must provide at least two fieldNames to set uniqueness constraint."))
- }
-
- columns := make([]string, 0)
- for _, name := range fieldNames {
- columns = append(columns, name)
- }
- t.uniqueTogether = append(t.uniqueTogether, columns)
- t.ResetSql()
-
- return t
-}
-
-// ColMap returns the ColumnMap pointer matching the given struct field
-// name. It panics if the struct does not contain a field matching this
-// name.
-func (t *TableMap) ColMap(field string) *ColumnMap {
- col := colMapOrNil(t, field)
- if col == nil {
- e := fmt.Sprintf("No ColumnMap in table %s type %s with field %s",
- t.TableName, t.gotype.Name(), field)
-
- panic(e)
- }
- return col
-}
-
-func colMapOrNil(t *TableMap, field string) *ColumnMap {
- for _, col := range t.columns {
- if col.fieldName == field || col.ColumnName == field {
- return col
- }
- }
- return nil
-}
-
-// SetVersionCol sets the column to use as the Version field. By default
-// the "Version" field is used. Returns the column found, or panics
-// if the struct does not contain a field matching this name.
-//
-// Automatically calls ResetSql() to ensure SQL statements are regenerated.
-func (t *TableMap) SetVersionCol(field string) *ColumnMap {
- c := t.ColMap(field)
- t.version = c
- t.ResetSql()
- return c
-}
-
type bindPlan struct {
query string
argFields []string
@@ -491,71 +354,6 @@ func (t *TableMap) bindGet() bindPlan {
return plan
}
-// ColumnMap represents a mapping between a Go struct field and a single
-// column in a table.
-// Unique and MaxSize only inform the
-// CreateTables() function and are not used by Insert/Update/Delete/Get.
-type ColumnMap struct {
- // Column name in db table
- ColumnName string
-
- // If true, this column is skipped in generated SQL statements
- Transient bool
-
- // If true, " unique" is added to create table statements.
- // Not used elsewhere
- Unique bool
-
- // Passed to Dialect.ToSqlType() to assist in informing the
- // correct column type to map to in CreateTables()
- // Not used elsewhere
- MaxSize int
-
- fieldName string
- gotype reflect.Type
- isPK bool
- isAutoIncr bool
- isNotNull bool
-}
-
-// Rename allows you to specify the column name in the table
-//
-// Example: table.ColMap("Updated").Rename("date_updated")
-//
-func (c *ColumnMap) Rename(colname string) *ColumnMap {
- c.ColumnName = colname
- return c
-}
-
-// SetTransient allows you to mark the column as transient. If true
-// this column will be skipped when SQL statements are generated
-func (c *ColumnMap) SetTransient(b bool) *ColumnMap {
- c.Transient = b
- return c
-}
-
-// SetUnique adds "unique" to the create table statements for this
-// column, if b is true.
-func (c *ColumnMap) SetUnique(b bool) *ColumnMap {
- c.Unique = b
- return c
-}
-
-// SetNotNull adds "not null" to the create table statements for this
-// column, if nn is true.
-func (c *ColumnMap) SetNotNull(nn bool) *ColumnMap {
- c.isNotNull = nn
- return c
-}
-
-// SetMaxSize specifies the max length of values of this column. This is
-// passed to the dialect.ToSqlType() function, which can use the value
-// to alter the generated type for "create table" statements
-func (c *ColumnMap) SetMaxSize(size int) *ColumnMap {
- c.MaxSize = size
- return c
-}
-
// Transaction represents a database transaction.
// Insert/Update/Delete/Get/Exec operations will be run in the context
// of that transaction. Transactions should be terminated with
@@ -578,8 +376,7 @@ type SqlExecutor interface {
Update(list ...interface{}) (int64, error)
Delete(list ...interface{}) (int64, error)
Exec(query string, args ...interface{}) (sql.Result, error)
- Select(i interface{}, query string,
- args ...interface{}) ([]interface{}, error)
+ Select(i interface{}, query string, args ...interface{}) ([]interface{}, error)
SelectInt(query string, args ...interface{}) (int64, error)
SelectNullInt(query string, args ...interface{}) (sql.NullInt64, error)
SelectFloat(query string, args ...interface{}) (float64, error)
@@ -587,8 +384,8 @@ type SqlExecutor interface {
SelectStr(query string, args ...interface{}) (string, error)
SelectNullStr(query string, args ...interface{}) (sql.NullString, error)
SelectOne(holder interface{}, query string, args ...interface{}) error
- query(query string, args ...interface{}) (*sql.Rows, error)
- queryRow(query string, args ...interface{}) *sql.Row
+ Query(query string, args ...interface{}) (*sql.Rows, error)
+ QueryRow(query string, args ...interface{}) *sql.Row
}
// Compile-time check that DbMap and Transaction implement the SqlExecutor
@@ -599,494 +396,6 @@ type GorpLogger interface {
Printf(format string, v ...interface{})
}
-// TraceOn turns on SQL statement logging for this DbMap. After this is
-// called, all SQL statements will be sent to the logger. If prefix is
-// a non-empty string, it will be written to the front of all logged
-// strings, which can aid in filtering log lines.
-//
-// Use TraceOn if you want to spy on the SQL statements that gorp
-// generates.
-//
-// Note that the base log.Logger type satisfies GorpLogger, but adapters can
-// easily be written for other logging packages (e.g., the golang-sanctioned
-// glog framework).
-func (m *DbMap) TraceOn(prefix string, logger GorpLogger) {
- m.logger = logger
- if prefix == "" {
- m.logPrefix = prefix
- } else {
- m.logPrefix = fmt.Sprintf("%s ", prefix)
- }
-}
-
-// TraceOff turns off tracing. It is idempotent.
-func (m *DbMap) TraceOff() {
- m.logger = nil
- m.logPrefix = ""
-}
-
-// AddTable registers the given interface type with gorp. The table name
-// will be given the name of the TypeOf(i). You must call this function,
-// or AddTableWithName, for any struct type you wish to persist with
-// the given DbMap.
-//
-// This operation is idempotent. If i's type is already mapped, the
-// existing *TableMap is returned
-func (m *DbMap) AddTable(i interface{}) *TableMap {
- return m.AddTableWithName(i, "")
-}
-
-// AddTableWithName has the same behavior as AddTable, but sets
-// table.TableName to name.
-func (m *DbMap) AddTableWithName(i interface{}, name string) *TableMap {
- return m.AddTableWithNameAndSchema(i, "", name)
-}
-
-// AddTableWithNameAndSchema has the same behavior as AddTable, but sets
-// table.TableName to name.
-func (m *DbMap) AddTableWithNameAndSchema(i interface{}, schema string, name string) *TableMap {
- t := reflect.TypeOf(i)
- if name == "" {
- name = t.Name()
- }
-
- // check if we have a table for this type already
- // if so, update the name and return the existing pointer
- for i := range m.tables {
- table := m.tables[i]
- if table.gotype == t {
- table.TableName = name
- return table
- }
- }
-
- tmap := &TableMap{gotype: t, TableName: name, SchemaName: schema, dbmap: m}
- tmap.columns, tmap.version = readStructColumns(t)
- m.tables = append(m.tables, tmap)
-
- return tmap
-}
-
-func readStructColumns(t reflect.Type) (cols []*ColumnMap, version *ColumnMap) {
- n := t.NumField()
- for i := 0; i < n; i++ {
- f := t.Field(i)
- if f.Anonymous && f.Type.Kind() == reflect.Struct {
- // Recursively add nested fields in embedded structs.
- subcols, subversion := readStructColumns(f.Type)
- // Don't append nested fields that have the same field
- // name as an already-mapped field.
- for _, subcol := range subcols {
- shouldAppend := true
- for _, col := range cols {
- if !subcol.Transient && subcol.fieldName == col.fieldName {
- shouldAppend = false
- break
- }
- }
- if shouldAppend {
- cols = append(cols, subcol)
- }
- }
- if subversion != nil {
- version = subversion
- }
- } else {
- columnName := f.Tag.Get("db")
- if columnName == "" {
- columnName = f.Name
- }
- cm := &ColumnMap{
- ColumnName: columnName,
- Transient: columnName == "-",
- fieldName: f.Name,
- gotype: f.Type,
- }
- // Check for nested fields of the same field name and
- // override them.
- shouldAppend := true
- for index, col := range cols {
- if !col.Transient && col.fieldName == cm.fieldName {
- cols[index] = cm
- shouldAppend = false
- break
- }
- }
- if shouldAppend {
- cols = append(cols, cm)
- }
- if cm.fieldName == "Version" {
- version = cm
- }
- }
- }
- return
-}
-
-// CreateTables iterates through TableMaps registered to this DbMap and
-// executes "create table" statements against the database for each.
-//
-// This is particularly useful in unit tests where you want to create
-// and destroy the schema automatically.
-func (m *DbMap) CreateTables() error {
- return m.createTables(false)
-}
-
-// CreateTablesIfNotExists is similar to CreateTables, but starts
-// each statement with "create table if not exists" so that existing
-// tables do not raise errors
-func (m *DbMap) CreateTablesIfNotExists() error {
- return m.createTables(true)
-}
-
-func (m *DbMap) createTables(ifNotExists bool) error {
- var err error
- for i := range m.tables {
- table := m.tables[i]
-
- s := bytes.Buffer{}
-
- if strings.TrimSpace(table.SchemaName) != "" {
- schemaCreate := "create schema"
- if ifNotExists {
- schemaCreate += " if not exists"
- }
-
- s.WriteString(fmt.Sprintf("%s %s;", schemaCreate, table.SchemaName))
- }
-
- create := "create table"
- if ifNotExists {
- create += " if not exists"
- }
-
- s.WriteString(fmt.Sprintf("%s %s (", create, m.Dialect.QuotedTableForQuery(table.SchemaName, table.TableName)))
- x := 0
- for _, col := range table.columns {
- if !col.Transient {
- if x > 0 {
- s.WriteString(", ")
- }
- stype := m.Dialect.ToSqlType(col.gotype, col.MaxSize, col.isAutoIncr)
- s.WriteString(fmt.Sprintf("%s %s", m.Dialect.QuoteField(col.ColumnName), stype))
-
- if col.isPK || col.isNotNull {
- s.WriteString(" not null")
- }
- if col.isPK && len(table.keys) == 1 {
- s.WriteString(" primary key")
- }
- if col.Unique {
- s.WriteString(" unique")
- }
- if col.isAutoIncr {
- s.WriteString(fmt.Sprintf(" %s", m.Dialect.AutoIncrStr()))
- }
-
- x++
- }
- }
- if len(table.keys) > 1 {
- s.WriteString(", primary key (")
- for x := range table.keys {
- if x > 0 {
- s.WriteString(", ")
- }
- s.WriteString(m.Dialect.QuoteField(table.keys[x].ColumnName))
- }
- s.WriteString(")")
- }
- if len(table.uniqueTogether) > 0 {
- for _, columns := range table.uniqueTogether {
- s.WriteString(", unique (")
- for i, column := range columns {
- if i > 0 {
- s.WriteString(", ")
- }
- s.WriteString(m.Dialect.QuoteField(column))
- }
- s.WriteString(")")
- }
- }
- s.WriteString(") ")
- s.WriteString(m.Dialect.CreateTableSuffix())
- s.WriteString(";")
- _, err = m.Exec(s.String())
- if err != nil {
- break
- }
- }
- return err
-}
-
-// DropTable drops an individual table. Will throw an error
-// if the table does not exist.
-func (m *DbMap) DropTable(table interface{}) error {
- t := reflect.TypeOf(table)
- return m.dropTable(t, false)
-}
-
-// DropTable drops an individual table. Will NOT throw an error
-// if the table does not exist.
-func (m *DbMap) DropTableIfExists(table interface{}) error {
- t := reflect.TypeOf(table)
- return m.dropTable(t, true)
-}
-
-// DropTables iterates through TableMaps registered to this DbMap and
-// executes "drop table" statements against the database for each.
-func (m *DbMap) DropTables() error {
- return m.dropTables(false)
-}
-
-// DropTablesIfExists is the same as DropTables, but uses the "if exists" clause to
-// avoid errors for tables that do not exist.
-func (m *DbMap) DropTablesIfExists() error {
- return m.dropTables(true)
-}
-
-// Goes through all the registered tables, dropping them one by one.
-// If an error is encountered, then it is returned and the rest of
-// the tables are not dropped.
-func (m *DbMap) dropTables(addIfExists bool) (err error) {
- for _, table := range m.tables {
- err = m.dropTableImpl(table, addIfExists)
- if err != nil {
- return
- }
- }
- return err
-}
-
-// Implementation of dropping a single table.
-func (m *DbMap) dropTable(t reflect.Type, addIfExists bool) error {
- table := tableOrNil(m, t)
- if table == nil {
- return errors.New(fmt.Sprintf("table %s was not registered!", table.TableName))
- }
-
- return m.dropTableImpl(table, addIfExists)
-}
-
-func (m *DbMap) dropTableImpl(table *TableMap, addIfExists bool) (err error) {
- ifExists := ""
- if addIfExists {
- ifExists = " if exists"
- }
- _, err = m.Exec(fmt.Sprintf("drop table%s %s;", ifExists, m.Dialect.QuotedTableForQuery(table.SchemaName, table.TableName)))
- return err
-}
-
-// TruncateTables iterates through TableMaps registered to this DbMap and
-// executes "truncate table" statements against the database for each, or in the case of
-// sqlite, a "delete from" with no "where" clause, which uses the truncate optimization
-// (http://www.sqlite.org/lang_delete.html)
-func (m *DbMap) TruncateTables() error {
- var err error
- for i := range m.tables {
- table := m.tables[i]
- _, e := m.Exec(fmt.Sprintf("%s %s;", m.Dialect.TruncateClause(), m.Dialect.QuotedTableForQuery(table.SchemaName, table.TableName)))
- if e != nil {
- err = e
- }
- }
- return err
-}
-
-// Insert runs a SQL INSERT statement for each element in list. List
-// items must be pointers.
-//
-// Any interface whose TableMap has an auto-increment primary key will
-// have its last insert id bound to the PK field on the struct.
-//
-// The hook functions PreInsert() and/or PostInsert() will be executed
-// before/after the INSERT statement if the interface defines them.
-//
-// Panics if any interface in the list has not been registered with AddTable
-func (m *DbMap) Insert(list ...interface{}) error {
- return insert(m, m, list...)
-}
-
-// Update runs a SQL UPDATE statement for each element in list. List
-// items must be pointers.
-//
-// The hook functions PreUpdate() and/or PostUpdate() will be executed
-// before/after the UPDATE statement if the interface defines them.
-//
-// Returns the number of rows updated.
-//
-// Returns an error if SetKeys has not been called on the TableMap
-// Panics if any interface in the list has not been registered with AddTable
-func (m *DbMap) Update(list ...interface{}) (int64, error) {
- return update(m, m, list...)
-}
-
-// Delete runs a SQL DELETE statement for each element in list. List
-// items must be pointers.
-//
-// The hook functions PreDelete() and/or PostDelete() will be executed
-// before/after the DELETE statement if the interface defines them.
-//
-// Returns the number of rows deleted.
-//
-// Returns an error if SetKeys has not been called on the TableMap
-// Panics if any interface in the list has not been registered with AddTable
-func (m *DbMap) Delete(list ...interface{}) (int64, error) {
- return delete(m, m, list...)
-}
-
-// Get runs a SQL SELECT to fetch a single row from the table based on the
-// primary key(s)
-//
-// i should be an empty value for the struct to load. keys should be
-// the primary key value(s) for the row to load. If multiple keys
-// exist on the table, the order should match the column order
-// specified in SetKeys() when the table mapping was defined.
-//
-// The hook function PostGet() will be executed after the SELECT
-// statement if the interface defines them.
-//
-// Returns a pointer to a struct that matches or nil if no row is found.
-//
-// Returns an error if SetKeys has not been called on the TableMap
-// Panics if any interface in the list has not been registered with AddTable
-func (m *DbMap) Get(i interface{}, keys ...interface{}) (interface{}, error) {
- return get(m, m, i, keys...)
-}
-
-// Select runs an arbitrary SQL query, binding the columns in the result
-// to fields on the struct specified by i. args represent the bind
-// parameters for the SQL statement.
-//
-// Column names on the SELECT statement should be aliased to the field names
-// on the struct i. Returns an error if one or more columns in the result
-// do not match. It is OK if fields on i are not part of the SQL
-// statement.
-//
-// The hook function PostGet() will be executed after the SELECT
-// statement if the interface defines them.
-//
-// Values are returned in one of two ways:
-// 1. If i is a struct or a pointer to a struct, returns a slice of pointers to
-// matching rows of type i.
-// 2. If i is a pointer to a slice, the results will be appended to that slice
-// and nil returned.
-//
-// i does NOT need to be registered with AddTable()
-func (m *DbMap) Select(i interface{}, query string, args ...interface{}) ([]interface{}, error) {
- return hookedselect(m, m, i, query, args...)
-}
-
-// Exec runs an arbitrary SQL statement. args represent the bind parameters.
-// This is equivalent to running: Exec() using database/sql
-func (m *DbMap) Exec(query string, args ...interface{}) (sql.Result, error) {
- m.trace(query, args...)
- return m.Db.Exec(query, args...)
-}
-
-// SelectInt is a convenience wrapper around the gorp.SelectInt function
-func (m *DbMap) SelectInt(query string, args ...interface{}) (int64, error) {
- return SelectInt(m, query, args...)
-}
-
-// SelectNullInt is a convenience wrapper around the gorp.SelectNullInt function
-func (m *DbMap) SelectNullInt(query string, args ...interface{}) (sql.NullInt64, error) {
- return SelectNullInt(m, query, args...)
-}
-
-// SelectFloat is a convenience wrapper around the gorp.SelectFlot function
-func (m *DbMap) SelectFloat(query string, args ...interface{}) (float64, error) {
- return SelectFloat(m, query, args...)
-}
-
-// SelectNullFloat is a convenience wrapper around the gorp.SelectNullFloat function
-func (m *DbMap) SelectNullFloat(query string, args ...interface{}) (sql.NullFloat64, error) {
- return SelectNullFloat(m, query, args...)
-}
-
-// SelectStr is a convenience wrapper around the gorp.SelectStr function
-func (m *DbMap) SelectStr(query string, args ...interface{}) (string, error) {
- return SelectStr(m, query, args...)
-}
-
-// SelectNullStr is a convenience wrapper around the gorp.SelectNullStr function
-func (m *DbMap) SelectNullStr(query string, args ...interface{}) (sql.NullString, error) {
- return SelectNullStr(m, query, args...)
-}
-
-// SelectOne is a convenience wrapper around the gorp.SelectOne function
-func (m *DbMap) SelectOne(holder interface{}, query string, args ...interface{}) error {
- return SelectOne(m, m, holder, query, args...)
-}
-
-// Begin starts a gorp Transaction
-func (m *DbMap) Begin() (*Transaction, error) {
- m.trace("begin;")
- tx, err := m.Db.Begin()
- if err != nil {
- return nil, err
- }
- return &Transaction{m, tx, false}, nil
-}
-
-func (m *DbMap) tableFor(t reflect.Type, checkPK bool) (*TableMap, error) {
- table := tableOrNil(m, t)
- if table == nil {
- return nil, errors.New(fmt.Sprintf("No table found for type: %v", t.Name()))
- }
-
- if checkPK && len(table.keys) < 1 {
- e := fmt.Sprintf("gorp: No keys defined for table: %s",
- table.TableName)
- return nil, errors.New(e)
- }
-
- return table, nil
-}
-
-func tableOrNil(m *DbMap, t reflect.Type) *TableMap {
- for i := range m.tables {
- table := m.tables[i]
- if table.gotype == t {
- return table
- }
- }
- return nil
-}
-
-func (m *DbMap) tableForPointer(ptr interface{}, checkPK bool) (*TableMap, reflect.Value, error) {
- ptrv := reflect.ValueOf(ptr)
- if ptrv.Kind() != reflect.Ptr {
- e := fmt.Sprintf("gorp: passed non-pointer: %v (kind=%v)", ptr,
- ptrv.Kind())
- return nil, reflect.Value{}, errors.New(e)
- }
- elem := ptrv.Elem()
- etype := reflect.TypeOf(elem.Interface())
- t, err := m.tableFor(etype, checkPK)
- if err != nil {
- return nil, reflect.Value{}, err
- }
-
- return t, elem, nil
-}
-
-func (m *DbMap) queryRow(query string, args ...interface{}) *sql.Row {
- m.trace(query, args...)
- return m.Db.QueryRow(query, args...)
-}
-
-func (m *DbMap) query(query string, args ...interface{}) (*sql.Rows, error) {
- m.trace(query, args...)
- return m.Db.Query(query, args...)
-}
-
-func (m *DbMap) trace(query string, args ...interface{}) {
- if m.logger != nil {
- m.logger.Printf("%s%s %v", m.logPrefix, query, args)
- }
-}
-
///////////////
// Insert has the same behavior as DbMap.Insert(), but runs in a transaction.
@@ -1207,12 +516,12 @@ func (t *Transaction) ReleaseSavepoint(savepoint string) error {
return err
}
-func (t *Transaction) queryRow(query string, args ...interface{}) *sql.Row {
+func (t *Transaction) QueryRow(query string, args ...interface{}) *sql.Row {
t.dbmap.trace(query, args...)
return t.tx.QueryRow(query, args...)
}
-func (t *Transaction) query(query string, args ...interface{}) (*sql.Rows, error) {
+func (t *Transaction) Query(query string, args ...interface{}) (*sql.Rows, error) {
t.dbmap.trace(query, args...)
return t.tx.Query(query, args...)
}
@@ -1225,7 +534,7 @@ func (t *Transaction) query(query string, args ...interface{}) (*sql.Rows, error
func SelectInt(e SqlExecutor, query string, args ...interface{}) (int64, error) {
var h int64
err := selectVal(e, &h, query, args...)
- if err != nil {
+ if err != nil && err != sql.ErrNoRows {
return 0, err
}
return h, nil
@@ -1237,7 +546,7 @@ func SelectInt(e SqlExecutor, query string, args ...interface{}) (int64, error)
func SelectNullInt(e SqlExecutor, query string, args ...interface{}) (sql.NullInt64, error) {
var h sql.NullInt64
err := selectVal(e, &h, query, args...)
- if err != nil {
+ if err != nil && err != sql.ErrNoRows {
return h, err
}
return h, nil
@@ -1249,7 +558,7 @@ func SelectNullInt(e SqlExecutor, query string, args ...interface{}) (sql.NullIn
func SelectFloat(e SqlExecutor, query string, args ...interface{}) (float64, error) {
var h float64
err := selectVal(e, &h, query, args...)
- if err != nil {
+ if err != nil && err != sql.ErrNoRows {
return 0, err
}
return h, nil
@@ -1261,7 +570,7 @@ func SelectFloat(e SqlExecutor, query string, args ...interface{}) (float64, err
func SelectNullFloat(e SqlExecutor, query string, args ...interface{}) (sql.NullFloat64, error) {
var h sql.NullFloat64
err := selectVal(e, &h, query, args...)
- if err != nil {
+ if err != nil && err != sql.ErrNoRows {
return h, err
}
return h, nil
@@ -1273,7 +582,7 @@ func SelectNullFloat(e SqlExecutor, query string, args ...interface{}) (sql.Null
func SelectStr(e SqlExecutor, query string, args ...interface{}) (string, error) {
var h string
err := selectVal(e, &h, query, args...)
- if err != nil {
+ if err != nil && err != sql.ErrNoRows {
return "", err
}
return h, nil
@@ -1286,7 +595,7 @@ func SelectStr(e SqlExecutor, query string, args ...interface{}) (string, error)
func SelectNullStr(e SqlExecutor, query string, args ...interface{}) (sql.NullString, error) {
var h sql.NullString
err := selectVal(e, &h, query, args...)
- if err != nil {
+ if err != nil && err != sql.ErrNoRows {
return h, err
}
return h, nil
@@ -1344,20 +653,17 @@ func selectVal(e SqlExecutor, holder interface{}, query string, args ...interfac
query, args = maybeExpandNamedQuery(m.dbmap, query, args)
}
}
- rows, err := e.query(query, args...)
+ rows, err := e.Query(query, args...)
if err != nil {
return err
}
defer rows.Close()
- if rows.Next() {
- err = rows.Scan(holder)
- if err != nil {
- return err
- }
+ if !rows.Next() {
+ return sql.ErrNoRows
}
- return nil
+ return rows.Scan(holder)
}
///////////////
@@ -1424,7 +730,7 @@ func rawselect(m *DbMap, exec SqlExecutor, i interface{}, query string,
}
// Run the query
- rows, err := exec.query(query, args...)
+ rows, err := exec.Query(query, args...)
if err != nil {
return nil, err
}
@@ -1527,8 +833,8 @@ func maybeExpandNamedQuery(m *DbMap, query string, args []interface{}) (string,
switch {
case arg.Kind() == reflect.Map && arg.Type().Key().Kind() == reflect.String:
return expandNamedQuery(m, query, func(key string) reflect.Value {
- return arg.MapIndex(reflect.ValueOf(key))
- })
+ return arg.MapIndex(reflect.ValueOf(key))
+ })
// #84 - ignore time.Time structs here - there may be a cleaner way to do this
case arg.Kind() == reflect.Struct && !(arg.Type().PkgPath() == "time" && arg.Type().Name() == "Time"):
return expandNamedQuery(m, query, arg.FieldByName)
@@ -1548,15 +854,15 @@ func expandNamedQuery(m *DbMap, query string, keyGetter func(key string) reflect
args []interface{}
)
return keyRegexp.ReplaceAllStringFunc(query, func(key string) string {
- val := keyGetter(key[1:])
- if !val.IsValid() {
- return key
- }
- args = append(args, val.Interface())
- newVar := m.Dialect.BindVar(n)
- n++
- return newVar
- }), args
+ val := keyGetter(key[1:])
+ if !val.IsValid() {
+ return key
+ }
+ args = append(args, val.Interface())
+ newVar := m.Dialect.BindVar(n)
+ n++
+ return newVar
+ }), args
}
func columnToFieldIndex(m *DbMap, t reflect.Type, cols []string) ([][]int, error) {
@@ -1694,7 +1000,7 @@ func get(m *DbMap, exec SqlExecutor, i interface{},
dest[x] = target
}
- row := exec.queryRow(plan.query, keys...)
+ row := exec.QueryRow(plan.query, keys...)
err = row.Scan(dest...)
if err != nil {
if err == sql.ErrNoRows {
diff --git a/gorp_test.go b/gorp_test.go
index 6876be08..226e8caa 100644
--- a/gorp_test.go
+++ b/gorp_test.go
@@ -306,7 +306,7 @@ func TestSetUniqueTogether(t *testing.T) {
t.Error(err)
}
// "unique" for Postgres/SQLite, "Duplicate entry" for MySQL
- if !strings.Contains(err.Error(), "unique") && !strings.Contains(err.Error(), "Duplicate entry") {
+ if !stringContainsIgnoreCase(err.Error(), "unique") && !stringContainsIgnoreCase(err.Error(), "Duplicate entry") {
t.Error(err)
}
@@ -317,7 +317,7 @@ func TestSetUniqueTogether(t *testing.T) {
t.Error(err)
}
// "unique" for Postgres/SQLite, "Duplicate entry" for MySQL
- if !strings.Contains(err.Error(), "unique") && !strings.Contains(err.Error(), "Duplicate entry") {
+ if !stringContainsIgnoreCase(err.Error(), "unique") && !stringContainsIgnoreCase(err.Error(), "Duplicate entry") {
t.Error(err)
}
@@ -736,6 +736,113 @@ func TestColumnProps(t *testing.T) {
}
}
+func checkFkProperty(t *testing.T, createTableSql, expected string, required bool) {
+ if strings.Contains(createTableSql, expected) != required {
+ t.Errorf("Expected '%s' in:\n%s", expected, createTableSql)
+ }
+}
+
+func TestUnitFkColumnPropsMysql(t *testing.T) {
+ dialect := MySQLDialect{"InnoDB", "UTF8"}
+ dbmap := &DbMap{Dialect: dialect}
+ dbmap.AddTable(Person{}).SetKeys(true, "Id")
+ t1 := dbmap.AddTable(Invoice{}).SetKeys(true, "Id")
+
+ t1.ColMap("PersonId").SetForeignKey(NewForeignKey("Person", "Id").OnDelete(Restrict).OnUpdate(Cascade))
+ table1 := dbmap.createOneTableSql(true, dbmap.tables[1])
+ if !strings.Contains(table1, "foreign key (`PersonId`) references `Person` (`Id`)") {
+ t.Errorf("Expected foreign key reference in:\n%s", table1)
+ }
+ checkFkProperty(t, table1, "on update cascade", true)
+ checkFkProperty(t, table1, "on delete restrict", true)
+
+ t1.ColMap("PersonId").SetForeignKey(NewForeignKey("Person", "Id").OnDelete(Delete).OnUpdate(SetNull))
+ table1 = dbmap.createOneTableSql(true, dbmap.tables[1])
+ checkFkProperty(t, table1, "on update set null", true)
+ checkFkProperty(t, table1, "on delete delete", true)
+
+ t1.ColMap("PersonId").SetForeignKey(NewForeignKey("Person", "Id").OnDelete(NoAction).OnUpdate(Unspecified))
+ table1 = dbmap.createOneTableSql(true, dbmap.tables[1])
+ checkFkProperty(t, table1, "on update", false)
+ checkFkProperty(t, table1, "on delete no action", true)
+}
+
+func TestUnitFkColumnPropsPsql(t *testing.T) {
+ dialect := PostgresDialect{}
+ dbmap := &DbMap{Dialect: dialect}
+ dbmap.AddTable(Person{}).SetKeys(true, "Id")
+ t1 := dbmap.AddTable(Invoice{}).SetKeys(true, "Id")
+
+ t1.ColMap("PersonId").SetForeignKey(NewForeignKey("Person", "Id").OnDelete(Restrict).OnUpdate(Cascade))
+ table1 := dbmap.createOneTableSql(true, dbmap.tables[1])
+ if !strings.Contains(table1, `"personid" bigint references "person" ("id")`) {
+ t.Errorf("Expected foreign key reference in:\n%s", table1)
+ }
+ checkFkProperty(t, table1, "on update cascade", true)
+ checkFkProperty(t, table1, "on delete restrict", true)
+
+ t1.ColMap("PersonId").SetForeignKey(NewForeignKey("Person", "Id").OnDelete(Delete).OnUpdate(SetNull))
+ table1 = dbmap.createOneTableSql(true, dbmap.tables[1])
+ checkFkProperty(t, table1, "on update set null", true)
+ checkFkProperty(t, table1, "on delete delete", true)
+
+ t1.ColMap("PersonId").SetForeignKey(NewForeignKey("Person", "Id").OnDelete(NoAction).OnUpdate(Unspecified))
+ table1 = dbmap.createOneTableSql(true, dbmap.tables[1])
+ checkFkProperty(t, table1, "on update", false)
+ checkFkProperty(t, table1, "on delete no action", true)
+}
+
+func TestUnitFkColumnPropsSqlite(t *testing.T) {
+ dialect := SqliteDialect{}
+ dbmap := &DbMap{Dialect: dialect}
+ dbmap.AddTable(Person{}).SetKeys(true, "Id")
+ t1 := dbmap.AddTable(Invoice{}).SetKeys(true, "Id")
+
+ t1.ColMap("PersonId").SetForeignKey(NewForeignKey("Person", "Id").OnDelete(Restrict).OnUpdate(Cascade))
+ table1 := dbmap.createOneTableSql(true, dbmap.tables[1])
+ if !strings.Contains(table1, `foreign key ("PersonId") references "Person" ("Id")`) {
+ t.Errorf("Expected foreign key reference in:\n%s", table1)
+ }
+ checkFkProperty(t, table1, "on update cascade", true)
+ checkFkProperty(t, table1, "on delete restrict", true)
+
+ t1.ColMap("PersonId").SetForeignKey(NewForeignKey("Person", "Id").OnDelete(Delete).OnUpdate(SetNull))
+ table1 = dbmap.createOneTableSql(true, dbmap.tables[1])
+ checkFkProperty(t, table1, "on update set null", true)
+ checkFkProperty(t, table1, "on delete delete", true)
+
+ t1.ColMap("PersonId").SetForeignKey(NewForeignKey("Person", "Id").OnDelete(NoAction).OnUpdate(Unspecified))
+ table1 = dbmap.createOneTableSql(true, dbmap.tables[1])
+ checkFkProperty(t, table1, "on update", false)
+ checkFkProperty(t, table1, "on delete no action", true)
+}
+
+func TestFkColumnProps(t *testing.T) {
+ dbmap := newDbMap()
+ dbmap.TraceOn("", log.New(os.Stdout, "gorptest: ", log.Lmicroseconds))
+ dbmap.AddTable(Person{}).SetKeys(true, "Id")
+ t1 := dbmap.AddTable(Invoice{}).SetKeys(true, "Id")
+ t1.ColMap("PersonId").SetForeignKey(NewForeignKey("Person", "Id").OnDelete(Restrict).OnUpdate(Cascade))
+ // note: "on update" is not yet tested
+
+ err := dbmap.CreateTables()
+ if err != nil {
+ log.Fatalln(err)
+ }
+ defer dropAndClose(dbmap)
+
+ person := &Person{0, 0, 0, "John", "Cooper", 0}
+ _insert(dbmap, person)
+
+ inv := &Invoice{0, 0, 1, "my invoice", person.Id, true}
+ _insert(dbmap, inv)
+
+ n, err := dbmap.Delete(person)
+ if err == nil {
+ t.Errorf("Restricted delete failed; deleted %d. %d\n", n, person.Id)
+ }
+}
+
func TestRawSelect(t *testing.T) {
dbmap := initDbMap()
defer dropAndClose(dbmap)
@@ -1408,7 +1515,20 @@ func TestSelectSingleVal(t *testing.T) {
_insert(dbmap, &Person{0, 0, 0, "bob", "smith", 0})
err = dbmap.SelectOne(&p2, "select * from person_test where Fname='bob'")
if err == nil {
- t.Error("Expected nil when two rows found")
+ t.Error("Expected error when two rows found")
+ }
+
+ // tests for #150
+ var tInt int64
+ var tStr string
+ var tBool bool
+ var tFloat float64
+ primVals := []interface{}{tInt, tStr, tBool, tFloat}
+ for _, prim := range primVals {
+ err = dbmap.SelectOne(&prim, "select * from person_test where Id=-123")
+ if err == nil || err != sql.ErrNoRows {
+ t.Error("primVals: SelectOne should have returned sql.ErrNoRows")
+ }
}
}
@@ -1542,9 +1662,9 @@ func initDbMapBench() *DbMap {
func initDbMap() *DbMap {
dbmap := newDbMap()
dbmap.TraceOn("", log.New(os.Stdout, "gorptest: ", log.Lmicroseconds))
+ dbmap.AddTableWithName(Person{}, "person_test").SetKeys(true, "Id")
dbmap.AddTableWithName(Invoice{}, "invoice_test").SetKeys(true, "Id")
dbmap.AddTableWithName(OverriddenInvoice{}, "invoice_override_test").SetKeys(false, "Id")
- dbmap.AddTableWithName(Person{}, "person_test").SetKeys(true, "Id")
dbmap.AddTableWithName(WithIgnoredColumn{}, "ignored_column_test").SetKeys(true, "Id")
dbmap.AddTableWithName(TypeConversionExample{}, "type_conv_test").SetKeys(true, "Id")
dbmap.AddTableWithName(WithEmbeddedStruct{}, "embedded_struct_test").SetKeys(true, "Id")
@@ -1577,7 +1697,10 @@ func newDbMap() *DbMap {
}
func dropAndClose(dbmap *DbMap) {
- dbmap.DropTablesIfExists()
+ err := dbmap.DropTablesIfExists()
+ if err != nil {
+ log.Println(err)
+ }
dbmap.Db.Close()
}
@@ -1710,3 +1833,9 @@ func _rawselect(dbmap *DbMap, i interface{}, query string, args ...interface{})
}
return list
}
+
+func stringContainsIgnoreCase(value, lookingFor string) bool {
+ valueLC := strings.ToLower(value)
+ lookingForLC := strings.ToLower(lookingFor)
+ return strings.Contains(valueLC, lookingForLC)
+}
diff --git a/schema.go b/schema.go
new file mode 100644
index 00000000..7aef30a1
--- /dev/null
+++ b/schema.go
@@ -0,0 +1,572 @@
+// Copyright 2012 James Cooper. All rights reserved.
+// Use of this source code is governed by a MIT-style
+// license that can be found in the LICENSE file.
+
+// Package gorp provides a simple way to marshal Go structs to and from
+// SQL databases. It uses the database/sql package, and should work with any
+// compliant database/sql driver.
+//
+// Source code and project home:
+// https://github.com/coopernurse/gorp
+//
+package gorp
+
+import (
+ "bytes"
+ "database/sql"
+ "errors"
+ "fmt"
+ "reflect"
+ "strings"
+)
+
+// DbMap is the root gorp mapping object. Create one of these for each
+// database schema you wish to map. Each DbMap contains a list of
+// mapped tables.
+//
+// Example:
+//
+// dialect := gorp.MySQLDialect{"InnoDB", "UTF8"}
+// dbmap := &gorp.DbMap{Db: db, Dialect: dialect}
+//
+type DbMap struct {
+ // Db handle to use with this map
+ Db *sql.DB
+
+ // Dialect implementation to use with this map
+ Dialect Dialect
+
+ TypeConverter TypeConverter
+
+ tables []*TableMap
+ logger GorpLogger
+ logPrefix string
+ initialised bool
+}
+
+// AddTable registers the given interface type with gorp. The table name
+// will be given the name of the TypeOf(i). You must call this function,
+// or AddTableWithName, for any struct type you wish to persist with
+// the given DbMap.
+//
+// This operation is idempotent. If i's type is already mapped, the
+// existing *TableMap is returned
+func (m *DbMap) AddTable(i interface{}) *TableMap {
+ return m.AddTableWithName(i, "")
+}
+
+// AddTableWithName has the same behavior as AddTable, but sets
+// table.TableName to name.
+func (m *DbMap) AddTableWithName(i interface{}, name string) *TableMap {
+ return m.AddTableWithNameAndSchema(i, "", name)
+}
+
+// AddTableWithNameAndSchema has the same behavior as AddTable, but sets
+// table.TableName to name.
+func (m *DbMap) AddTableWithNameAndSchema(i interface{}, schema string, name string) *TableMap {
+ t := reflect.TypeOf(i)
+ if name == "" {
+ name = t.Name()
+ }
+
+ // check if we have a table for this type already
+ // if so, update the name and return the existing pointer
+ for i := range m.tables {
+ table := m.tables[i]
+ if table.gotype == t {
+ table.TableName = name
+ return table
+ }
+ }
+
+ tmap := &TableMap{gotype: t, TableName: name, SchemaName: schema, dbmap: m}
+ tmap.columns, tmap.version = readStructColumns(t)
+ m.tables = append(m.tables, tmap)
+
+ return tmap
+}
+
+func readStructColumns(t reflect.Type) (cols []*ColumnMap, version *ColumnMap) {
+ n := t.NumField()
+ for i := 0; i < n; i++ {
+ f := t.Field(i)
+ if f.Anonymous && f.Type.Kind() == reflect.Struct {
+ // Recursively add nested fields in embedded structs.
+ subcols, subversion := readStructColumns(f.Type)
+ // Don't append nested fields that have the same field
+ // name as an already-mapped field.
+ for _, subcol := range subcols {
+ shouldAppend := true
+ for _, col := range cols {
+ if !subcol.Transient && subcol.fieldName == col.fieldName {
+ shouldAppend = false
+ break
+ }
+ }
+ if shouldAppend {
+ cols = append(cols, subcol)
+ }
+ }
+ if subversion != nil {
+ version = subversion
+ }
+ } else {
+ columnName := f.Tag.Get("db")
+ if columnName == "" {
+ columnName = f.Name
+ }
+ cm := &ColumnMap{
+ ColumnName: columnName,
+ Transient: columnName == "-",
+ fieldName: f.Name,
+ gotype: f.Type,
+ }
+ // Check for nested fields of the same field name and
+ // override them.
+ shouldAppend := true
+ for index, col := range cols {
+ if !col.Transient && col.fieldName == cm.fieldName {
+ cols[index] = cm
+ shouldAppend = false
+ break
+ }
+ }
+ if shouldAppend {
+ cols = append(cols, cm)
+ }
+ if cm.fieldName == "Version" {
+ version = cm
+ }
+ }
+ }
+ return
+}
+
+// CreateTables iterates through TableMaps registered to this DbMap and
+// executes "create table" statements against the database for each.
+//
+// This is particularly useful in unit tests where you want to create
+// and destroy the schema automatically.
+func (m *DbMap) CreateTables() error {
+ return m.createTables(false)
+}
+
+// CreateTablesIfNotExists is similar to CreateTables, but starts
+// each statement with "create table if not exists" so that existing
+// tables do not raise errors
+func (m *DbMap) CreateTablesIfNotExists() error {
+ return m.createTables(true)
+}
+
+func (m *DbMap) createTables(ifNotExists bool) error {
+ for _, t := range m.tables {
+ ddl := m.createOneTableSql(ifNotExists, t)
+ _, err := m.Exec(ddl)
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func (m *DbMap) createOneTableSql(ifNotExists bool, table *TableMap) string {
+ s := bytes.Buffer{}
+
+ if strings.TrimSpace(table.SchemaName) != "" {
+ s.WriteString("create schema ")
+ if ifNotExists {
+ s.WriteString("if not exists ")
+ }
+
+ s.WriteString(table.SchemaName)
+ s.WriteString(";")
+ }
+
+ s.WriteString("create table ")
+ if ifNotExists {
+ s.WriteString("if not exists ")
+ }
+ tableName := m.Dialect.QuotedTableForQuery(table.SchemaName, table.TableName)
+ s.WriteString(fmt.Sprintf("%s (\n ", tableName))
+
+ s = m.createOneTableColumns(table, s)
+ s = m.createOneTablePrimaryKeys(table, s)
+ s = m.createOneTableIndexes(table, s)
+ s = m.createOneTableForeignKeys(table, s)
+
+ s.WriteString("\n) ")
+ s.WriteString(m.Dialect.CreateTableSuffix())
+ s.WriteString(";")
+ return s.String()
+}
+
+func (m *DbMap) createOneTableColumns(table *TableMap, s bytes.Buffer) bytes.Buffer {
+ x := 0
+ for _, col := range table.columns {
+ if !col.Transient {
+ if x > 0 {
+ s.WriteString(",\n ")
+ }
+ field := m.Dialect.QuoteField(col.ColumnName)
+ stype := m.Dialect.ToSqlType(col.gotype, col.MaxSize, col.isAutoIncr)
+ s.WriteString(fmt.Sprintf("%s %s", field, stype))
+
+ if col.isPK || col.isNotNull {
+ s.WriteString(" not null")
+ }
+ if col.isPK && len(table.keys) == 1 {
+ s.WriteString(" primary key")
+ }
+ if col.Unique {
+ s.WriteString(" unique")
+ }
+ if col.isAutoIncr {
+ s.WriteString(" " + m.Dialect.AutoIncrStr())
+ }
+ if col.References != nil {
+ s.WriteString(m.Dialect.CreateForeignKeySuffix(col.References))
+ }
+
+ x++
+ }
+ }
+ return s
+}
+
+func (m *DbMap) createOneTablePrimaryKeys(table *TableMap, s bytes.Buffer) bytes.Buffer {
+ if len(table.keys) > 1 {
+ s.WriteString(",\n primary key (")
+ for x := range table.keys {
+ if x > 0 {
+ s.WriteString(", ")
+ }
+ s.WriteString(m.Dialect.QuoteField(table.keys[x].ColumnName))
+ }
+ s.WriteString(")")
+ }
+ return s
+}
+
+func (m *DbMap) createOneTableIndexes(table *TableMap, s bytes.Buffer) bytes.Buffer {
+ if len(table.uniqueTogether) > 0 {
+ for _, columns := range table.uniqueTogether {
+ s.WriteString(",\n unique (")
+ for i, column := range columns {
+ if i > 0 {
+ s.WriteString(", ")
+ }
+ s.WriteString(m.Dialect.QuoteField(column))
+ }
+ s.WriteString(")")
+ }
+ }
+ return s
+}
+
+func (m *DbMap) createOneTableForeignKeys(table *TableMap, s bytes.Buffer) bytes.Buffer {
+ for _, col := range table.columns {
+ if !col.Transient && col.References != nil {
+ fkBlock := m.Dialect.CreateForeignKeyBlock(col)
+ if fkBlock != "" {
+ s.WriteString(",\n ")
+ s.WriteString(fkBlock)
+ }
+ }
+ }
+ return s
+}
+
+// DropTable drops an individual table. Will throw an error
+// if the table does not exist.
+func (m *DbMap) DropTable(table interface{}) error {
+ t := reflect.TypeOf(table)
+ return m.dropTable(t, false)
+}
+
+// DropTable drops an individual table. Will NOT throw an error
+// if the table does not exist.
+func (m *DbMap) DropTableIfExists(table interface{}) error {
+ t := reflect.TypeOf(table)
+ return m.dropTable(t, true)
+}
+
+// DropTables iterates through TableMaps registered to this DbMap and
+// executes "drop table" statements against the database for each.
+func (m *DbMap) DropTables() error {
+ return m.dropTables(false)
+}
+
+// DropTablesIfExists is the same as DropTables, but uses the "if exists" clause to
+// avoid errors for tables that do not exist.
+func (m *DbMap) DropTablesIfExists() error {
+ return m.dropTables(true)
+}
+
+// Goes through all the registered tables, dropping them one by one.
+// If an error is encountered, then it is returned and the rest of
+// the tables are not dropped.
+func (m *DbMap) dropTables(addIfExists bool) (err error) {
+ // drop in reverse order, assuming that foreign keys were created in
+ // the order that the tables were created
+ n := len(m.tables) - 1
+ for i, _ := range m.tables {
+ table := m.tables[n - i]
+ err = m.dropTableImpl(table, addIfExists)
+ if err != nil {
+ return
+ }
+ }
+ return
+}
+
+// Implementation of dropping a single table.
+func (m *DbMap) dropTable(t reflect.Type, addIfExists bool) error {
+ table := tableOrNil(m, t)
+ if table == nil {
+ return errors.New(fmt.Sprintf("table %s was not registered!", table.TableName))
+ }
+
+ return m.dropTableImpl(table, addIfExists)
+}
+
+func (m *DbMap) dropTableImpl(table *TableMap, addIfExists bool) (err error) {
+ ifExists := ""
+ if addIfExists {
+ ifExists = " if exists"
+ }
+ _, err = m.Exec(fmt.Sprintf("drop table%s %s;", ifExists, m.Dialect.QuotedTableForQuery(table.SchemaName, table.TableName)))
+ return err
+}
+
+// TableMap represents a mapping between a Go struct and a database table
+// Use dbmap.AddTable() or dbmap.AddTableWithName() to create these
+type TableMap struct {
+ // Name of database table.
+ TableName string
+ SchemaName string
+ gotype reflect.Type
+ columns []*ColumnMap
+ keys []*ColumnMap
+ uniqueTogether [][]string
+ version *ColumnMap
+ insertPlan bindPlan
+ updatePlan bindPlan
+ deletePlan bindPlan
+ getPlan bindPlan
+ dbmap *DbMap
+}
+
+// ResetSql removes cached insert/update/select/delete SQL strings
+// associated with this TableMap. Call this if you've modified
+// any column names or the table name itself.
+func (t *TableMap) ResetSql() {
+ t.insertPlan = bindPlan{}
+ t.updatePlan = bindPlan{}
+ t.deletePlan = bindPlan{}
+ t.getPlan = bindPlan{}
+}
+
+// SetKeys lets you specify the fields on a struct that map to primary
+// key columns on the table. If isAutoIncr is set, result.LastInsertId()
+// will be used after INSERT to bind the generated id to the Go struct.
+//
+// Automatically calls ResetSql() to ensure SQL statements are regenerated.
+//
+// Panics if isAutoIncr is true, and fieldNames length != 1
+//
+func (t *TableMap) SetKeys(isAutoIncr bool, fieldNames ...string) *TableMap {
+ if isAutoIncr && len(fieldNames) != 1 {
+ panic(fmt.Sprintf(
+ "gorp: SetKeys: fieldNames length must be 1 if key is auto-increment. (Saw %v fieldNames)",
+ len(fieldNames)))
+ }
+ t.keys = make([]*ColumnMap, 0)
+ for _, name := range fieldNames {
+ colmap := t.ColMap(name)
+ colmap.isPK = true
+ colmap.isAutoIncr = isAutoIncr
+ t.keys = append(t.keys, colmap)
+ }
+ t.ResetSql()
+
+ return t
+}
+
+// SetUniqueTogether lets you specify uniqueness constraints across multiple
+// columns on the table. Each call adds an additional constraint for the
+// specified columns.
+//
+// Automatically calls ResetSql() to ensure SQL statements are regenerated.
+//
+// Panics if fieldNames length < 2.
+//
+func (t *TableMap) SetUniqueTogether(fieldNames ...string) *TableMap {
+ if len(fieldNames) < 2 {
+ panic(fmt.Sprintf(
+ "gorp: SetUniqueTogether: must provide at least two fieldNames to set uniqueness constraint."))
+ }
+
+ columns := make([]string, 0)
+ for _, name := range fieldNames {
+ columns = append(columns, name)
+ }
+ t.uniqueTogether = append(t.uniqueTogether, columns)
+ t.ResetSql()
+
+ return t
+}
+
+// ColMap returns the ColumnMap pointer matching the given struct field
+// name. It panics if the struct does not contain a field matching this
+// name.
+func (t *TableMap) ColMap(field string) *ColumnMap {
+ col := colMapOrNil(t, field)
+ if col == nil {
+ e := fmt.Sprintf("No ColumnMap in table %s type %s with field %s",
+ t.TableName, t.gotype.Name(), field)
+
+ panic(e)
+ }
+ return col
+}
+
+func colMapOrNil(t *TableMap, field string) *ColumnMap {
+ for _, col := range t.columns {
+ if col.fieldName == field || col.ColumnName == field {
+ return col
+ }
+ }
+ return nil
+}
+
+// SetVersionCol sets the column to use as the Version field. By default
+// the "Version" field is used. Returns the column found, or panics
+// if the struct does not contain a field matching this name.
+//
+// Automatically calls ResetSql() to ensure SQL statements are regenerated.
+func (t *TableMap) SetVersionCol(field string) *ColumnMap {
+ c := t.ColMap(field)
+ t.version = c
+ t.ResetSql()
+ return c
+}
+
+// ColumnMap represents a mapping between a Go struct field and a single
+// column in a table.
+// Unique and MaxSize only inform the
+// CreateTables() function and are not used by Insert/Update/Delete/Get.
+type ColumnMap struct {
+ // Column name in db table
+ ColumnName string
+
+ // If true, this column is skipped in generated SQL statements
+ Transient bool
+
+ // If true, " unique" is added to create table statements.
+ // Not used elsewhere
+ Unique bool
+
+ // Passed to Dialect.ToSqlType() to assist in informing the
+ // correct column type to map to in CreateTables()
+ // Not used elsewhere
+ MaxSize int
+
+ // If present, specifies that this column is a foreign key that
+ // references another column of another table.
+ References *ForeignKey
+
+ fieldName string
+ gotype reflect.Type
+ isPK bool
+ isAutoIncr bool
+ isNotNull bool
+}
+
+// Rename allows you to specify the column name in the table
+//
+// Example: table.ColMap("Updated").Rename("date_updated")
+//
+func (c *ColumnMap) Rename(colname string) *ColumnMap {
+ c.ColumnName = colname
+ return c
+}
+
+// SetTransient allows you to mark the column as transient. If true
+// this column will be skipped when SQL statements are generated
+func (c *ColumnMap) SetTransient(b bool) *ColumnMap {
+ c.Transient = b
+ return c
+}
+
+// SetUnique adds "unique" to the create table statements for this
+// column, if b is true.
+func (c *ColumnMap) SetUnique(b bool) *ColumnMap {
+ c.Unique = b
+ return c
+}
+
+// SetNotNull adds "not null" to the create table statements for this
+// column, if nn is true.
+func (c *ColumnMap) SetNotNull(nn bool) *ColumnMap {
+ c.isNotNull = nn
+ return c
+}
+
+// SetMaxSize specifies the max length of values of this column. This is
+// passed to the dialect.ToSqlType() function, which can use the value
+// to alter the generated type for "create table" statements
+func (c *ColumnMap) SetMaxSize(size int) *ColumnMap {
+ c.MaxSize = size
+ return c
+}
+
+// SetForeignKey specifies the foreign-key relationship between this column
+// and a column in another table.
+func (c *ColumnMap) SetForeignKey(fk *ForeignKey) *ColumnMap {
+ c.References = fk
+ return c
+}
+
+// Specifies what foreign-key constraints will be enforced by the database.
+type FKOnChangeAction int
+
+const (
+ Unspecified FKOnChangeAction = iota
+ NoAction
+ Restrict
+ Cascade
+ SetNull
+ //SetDefault // may not be supported by MySql
+ Delete
+)
+
+// ForeignKey specifies the relationship formed when one column refers to the
+// primary key of another table.
+type ForeignKey struct {
+ ReferencedTable string
+ ReferencedColumn string
+ ActionOnDelete FKOnChangeAction
+ ActionOnUpdate FKOnChangeAction
+}
+
+// NewForeignKey creates a new ForeignKey for a specified table/column reference. If
+// the table is part of a named schema, include the schema prefix in the referencedTable
+// value.
+func NewForeignKey(referencedTable, referencedColumn string) *ForeignKey {
+ return &ForeignKey{referencedTable, referencedColumn, Unspecified, Unspecified}
+}
+
+// Sets the action that the database is to perform when the parent record
+// is updated. The default is usually Restrict.
+func (fk *ForeignKey) OnUpdate(action FKOnChangeAction) *ForeignKey {
+ fk.ActionOnUpdate = action
+ return fk
+}
+
+// Sets the action that the database is to perform when the parent record
+// is deleted. The default is usually Restrict.
+func (fk *ForeignKey) OnDelete(action FKOnChangeAction) *ForeignKey {
+ fk.ActionOnDelete = action
+ return fk
+}
+
diff --git a/test_all.sh b/test_all.sh
index f870b39a..edabf7a2 100755
--- a/test_all.sh
+++ b/test_all.sh
@@ -3,20 +3,40 @@
# on macs, you may need to:
# export GOBUILDFLAG=-ldflags -linkmode=external
-set -e
+# Using "-n", the environment variables will be set but tests will not run.
+# You can "dot" this script, then use go test directly, which will give verbose progress info.
+if [ "$1" = "-n" ]; then
+ NOOP=":"
+ shift
+fi
-export GORP_TEST_DSN=gorptest/gorptest/gorptest
-export GORP_TEST_DIALECT=mysql
-go test $GOBUILDFLAG .
+set -e
-export GORP_TEST_DSN=gorptest:gorptest@/gorptest
-export GORP_TEST_DIALECT=gomysql
-go test $GOBUILDFLAG .
+testMysql() {
+ export GORP_TEST_DSN=gorptest/gorptest/gorptest
+ export GORP_TEST_DIALECT=mysql
+ $NOOP go test $GOBUILDFLAG .
-export GORP_TEST_DSN="user=gorptest password=gorptest dbname=gorptest sslmode=disable"
-export GORP_TEST_DIALECT=postgres
-go test $GOBUILDFLAG .
+ export GORP_TEST_DSN=gorptest:gorptest@/gorptest
+ export GORP_TEST_DIALECT=gomysql
+ $NOOP go test $GOBUILDFLAG .
+}
-export GORP_TEST_DSN=/tmp/gorptest.bin
-export GORP_TEST_DIALECT=sqlite
-go test $GOBUILDFLAG .
+testPostgresql() {
+ export GORP_TEST_DSN="user=gorptest password=gorptest dbname=gorptest sslmode=disable"
+ export GORP_TEST_DIALECT=postgres
+ $NOOP go test $GOBUILDFLAG .
+}
+
+testSqlite() {
+ export GORP_TEST_DSN=/tmp/gorptest.bin
+ export GORP_TEST_DIALECT=sqlite
+ $NOOP go test $GOBUILDFLAG .
+}
+
+case "$1" in
+ mysql) testMysql ;;
+ psql | postgresql) testPostgresql ;;
+ sqlite) testSqlite ;;
+ *) testMysql ; testPostgresql ; testSqlite ;;
+esac