From 0d429794814e0368358ab3e1e9e1077e080b9db2 Mon Sep 17 00:00:00 2001 From: Matt Culbreth Date: Wed, 21 May 2014 17:22:29 -0400 Subject: [PATCH] mattc58 merge of develop, to include support for Transactions for Goose migrations commit 50ff6b2d8eeac38ccfad2618078346bd797c5b3d Merge: 96ea999 def0f72 Author: Matt Culbreth Date: Wed May 21 17:12:50 2014 -0400 fixed merge issue with mattc58/gorp for transaction support commit 96ea99914b04733d38e45afbcd7b30303b08b7c8 Author: Matt Culbreth Date: Wed May 21 17:09:31 2014 -0400 Changed DbMap and createTables() to be able to support Transactions in addtion to Connections. This is for Goose migration support. commit 1e2e4143621a5e2d7632ebd70b047b5212eb078e Author: Matt Culbreth Date: Fri Apr 11 09:20:37 2014 -0400 fixed else statement commit 4fc09f6483481da40023f1796fd25dc5dd15fe7b Author: Matt Culbreth Date: Fri Apr 11 09:18:40 2014 -0400 changed dropTableImpl to use the transaction if it's there commit 25cbcabae357931971e0387fca203ee185f432ab Author: Matt Culbreth Date: Fri Apr 11 12:56:10 2014 +0000 * added Tx to DbMap * added logic in createTables() to switch between Tx and DbMap for the Exec commit def0f726bd2740814f2f988d21330e56cba89271 Author: James Cooper Date: Mon May 19 07:31:13 2014 -0700 clean up dialect receiver var name. add QuerySuffix() to SqlServerDialect commit 325f5a760e1508ae82e30c39a631e93ab1113ab7 Author: James Cooper Date: Mon May 19 07:30:45 2014 -0700 ensure that dialect structs comply with interface commit 096714fd38f16eb6fa00fc04415b5f228502add4 Author: James Cooper Date: Mon May 19 07:30:05 2014 -0700 add sql server driver to README commit ef29a36bc77802041e18d2a93e5b247aafdf8e6d Author: James Cooper Date: Fri May 16 13:22:59 2014 -0700 readme: add notes about SQL Server and Oracle support commit 214aeb9d7e79eea57ac347195331934233e895f1 Author: James Cooper Date: Fri May 16 13:10:19 2014 -0700 Added docs for DbMap.TableFor commit 4593fad38df54751aa9811a739d09f96dce2e688 Author: Samuel Nelson Date: Fri May 16 13:27:27 2014 -0600 Export TableMap.Columns and DbMap.TableFor() commit 728a08eb9169ec77726ab406641748579496b668 Author: Alex Guerrieri Date: Fri May 16 16:51:28 2014 +0200 Oracle diver develop, new method QuerySuffix() in dialect, better args string commit d69be845787ce72661a90166fce22b83e1ed8506 Author: Pierre Date: Fri May 16 10:23:29 2014 +0200 fix QuoteField for SqlServerDialect from blank to dbl-quote commit 354af1951ea9a2d206b35834bac02eb05b6dbfed Author: Pierre Prinetti Date: Thu May 15 10:56:09 2014 +0200 add support for SQL Server commit 108d32dba8d839104b061e6ab433051fc2594498 Author: Harley Laue Date: Wed May 7 16:29:35 2014 -0500 MySQL certainly does have schemas * I'd wager it was a copy/paste job from SQLite. * http://dev.mysql.com/doc/refman/5.1/en/create-database.html commit 86bc8f669d57644b06b779ce3a84ee47b2df636e Author: Samuel Nelson Date: Thu Apr 17 15:34:29 2014 -0600 More clarification commit 059256e3e256fc717010dff2bc0d4de6d4029600 Author: Samuel Nelson Date: Thu Apr 17 15:32:41 2014 -0600 Minor documentation clarification commit 2e7bcc335fa21be773cfacd8813ad77d3dacfc2d Author: Samuel Nelson Date: Thu Apr 17 15:31:01 2014 -0600 Support non-integer autoincrement fields in postgres commit 9259f03aa6d4d724149e1c4ec5ef040e6e587262 Author: Mike Thompson Date: Mon Apr 7 16:06:36 2014 -0700 Update statements to handle Non-incrementing PK's The change in this if statement allows for the driver to correctly handle non-auto-incrementing primary keys. The check has been changed to if it is auto-incrementing or transient. commit ed5dce528750a790c8d73f794ae53760a6e3e6a5 Author: Mike Thompson Date: Mon Apr 7 15:02:48 2014 -0700 Initial test to show failure This initial test, against this code base, shows that "update" against a table, with only one column, where that coumn is the primary key, and is not auto-incrementing, will fail. commit b7953545f3151ddda68fe6914614e8524f5c1e7f Author: umisama Date: Mon Mar 17 19:35:29 2014 +0900 Add Pre/Post functions to godoc commit f1c93ef3e5e62d2bfa148f751acc51fbc05d5cdd Author: umisama Date: Mon Mar 17 19:11:23 2014 +0900 Modified Pre/Post functions into using interface commit b5ce3b9a12a14e306c83e460200e003f66cf95f0 Author: zhengjia Date: Wed Mar 5 21:38:57 2014 -0600 Add support for alias column commit eeb38f76f723211b6ab83e744a763c49a99f92af Author: James Cooper Date: Wed May 14 13:11:44 2014 -0700 TestSetUniqueTogether: fix sqlite test expectation (force err string to lower) commit f296a21ab03af5752155f2a56415c8d82d269c5f Author: James Cooper Date: Sun Mar 23 07:53:33 2014 -0700 #150 - modify selectVal() to return sql.ErrNoRows, and modify Select* funcs to ignore this error and continue to return zeroVal (per docs). This fixes SelectOne() for pointers to primitive values. --- README.md | 9 ++ dialect.go | 294 +++++++++++++++++++++++++++++++++++---- gorp.go | 381 ++++++++++++++++++++++++++++++++++----------------- gorp_test.go | 104 +++++++++++++- 4 files changed, 638 insertions(+), 150 deletions(-) mode change 100644 => 100755 gorp.go diff --git a/README.md b/README.md index e7bb424b..cbef88be 100644 --- a/README.md +++ b/README.md @@ -564,6 +564,15 @@ implemented per database vendor. Dialects are provided for: Each of these three databases pass the test suite. See `gorp_test.go` for example DSNs for these three databases. +Support is also provided for: + +* Oracle (contributed by @klaidliadon) +* SQL Server (contributed by @qrawl) - use driver: github.com/denisenkom/go-mssqldb + +Note that these databases are not covered by CI and I (@coopernurse) have no good way to +test them locally. So please try them and send patches as needed, but expect a bit more +unpredicability. + ## Known Issues ## ### SQL placeholder portability ### diff --git a/dialect.go b/dialect.go index 6b6ef0e8..963a3eab 100644 --- a/dialect.go +++ b/dialect.go @@ -12,6 +12,9 @@ import ( // but this could change in the future type Dialect interface { + // adds a suffix to any query, usually ";" + QuerySuffix() string + // ToSqlType returns the SQL column type to use when creating a // table of the given Go Type. maxsize can be used to switch based on // size. For example, in MySQL []byte could map to BLOB, MEDIUMBLOB, @@ -21,6 +24,8 @@ type Dialect interface { // string to append to primary key column definitions AutoIncrStr() string + // string to bind autoincrement columns to. Empty string will + // remove reference to those columns in the INSERT statement. AutoIncrBindValue() string AutoIncrInsertSuffix(col *ColumnMap) string @@ -32,8 +37,6 @@ type Dialect interface { // string to truncate tables TruncateClause() string - InsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) - // bind variable string to use when forming SQL statements // in many dbs it is "?", but Postgres appears to use $1 // @@ -53,6 +56,25 @@ type Dialect interface { QuotedTableForQuery(schema string, table string) string } +// IntegerAutoIncrInserter is implemented by dialects that can perform +// inserts with automatically incremented integer primary keys. If +// the dialect can handle automatic assignment of more than just +// integers, see TargetedAutoIncrInserter. +type IntegerAutoIncrInserter interface { + InsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) +} + +// TargetedAutoIncrInserter is implemented by dialects that can +// perform automatic assignment of any primary key type (i.e. strings +// for uuids, integers for serials, etc). +type TargetedAutoIncrInserter interface { + // InsertAutoIncrToTarget runs an insert operation and assigns the + // automatically generated primary key directly to the passed in + // target. The target should be a pointer to the primary key + // field of the value being inserted. + InsertAutoIncrToTarget(exec SqlExecutor, insertSql string, target interface{}, params ...interface{}) error +} + func standardInsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) { res, err := exec.Exec(insertSql, params...) if err != nil { @@ -69,6 +91,8 @@ type SqliteDialect struct { suffix string } +func (d SqliteDialect) QuerySuffix() string { return ";" } + func (d SqliteDialect) ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string { switch val.Kind() { case reflect.Ptr: @@ -153,6 +177,8 @@ type PostgresDialect struct { suffix string } +func (d PostgresDialect) QuerySuffix() string { return ";" } + func (d PostgresDialect) ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string { switch val.Kind() { case reflect.Ptr: @@ -225,20 +251,19 @@ func (d PostgresDialect) BindVar(i int) string { return fmt.Sprintf("$%d", i+1) } -func (d PostgresDialect) InsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) { +func (d PostgresDialect) InsertAutoIncrToTarget(exec SqlExecutor, insertSql string, target interface{}, params ...interface{}) error { rows, err := exec.query(insertSql, params...) if err != nil { - return 0, err + return err } defer rows.Close() if rows.Next() { - var id int64 - err := rows.Scan(&id) - return id, err + err := rows.Scan(target) + return err } - return 0, errors.New("No serial value returned for insert: " + insertSql + " Encountered error: " + rows.Err().Error()) + return errors.New("No serial value returned for insert: " + insertSql + " Encountered error: " + rows.Err().Error()) } func (d PostgresDialect) QuoteField(f string) string { @@ -267,10 +292,12 @@ type MySQLDialect struct { Encoding string } -func (m MySQLDialect) ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string { +func (d MySQLDialect) QuerySuffix() string { return ";" } + +func (d MySQLDialect) ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string { switch val.Kind() { case reflect.Ptr: - return m.ToSqlType(val.Elem(), maxsize, isAutoIncr) + return d.ToSqlType(val.Elem(), maxsize, isAutoIncr) case reflect.Bool: return "boolean" case reflect.Int8: @@ -315,49 +342,49 @@ func (m MySQLDialect) ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) } // Returns auto_increment -func (m MySQLDialect) AutoIncrStr() string { +func (d MySQLDialect) AutoIncrStr() string { return "auto_increment" } -func (m MySQLDialect) AutoIncrBindValue() string { +func (d MySQLDialect) AutoIncrBindValue() string { return "null" } -func (m MySQLDialect) AutoIncrInsertSuffix(col *ColumnMap) string { +func (d MySQLDialect) AutoIncrInsertSuffix(col *ColumnMap) string { return "" } // Returns engine=%s charset=%s based on values stored on struct -func (m MySQLDialect) CreateTableSuffix() string { - if m.Engine == "" || m.Encoding == "" { +func (d MySQLDialect) CreateTableSuffix() string { + if d.Engine == "" || d.Encoding == "" { msg := "gorp - undefined" - if m.Engine == "" { + if d.Engine == "" { msg += " MySQLDialect.Engine" } - if m.Engine == "" && m.Encoding == "" { + if d.Engine == "" && d.Encoding == "" { msg += "," } - if m.Encoding == "" { + if d.Encoding == "" { msg += " MySQLDialect.Encoding" } msg += ". Check that your MySQLDialect was correctly initialized when declared." panic(msg) } - return fmt.Sprintf(" engine=%s charset=%s", m.Engine, m.Encoding) + return fmt.Sprintf(" engine=%s charset=%s", d.Engine, d.Encoding) } -func (m MySQLDialect) TruncateClause() string { +func (d MySQLDialect) TruncateClause() string { return "truncate" } // Returns "?" -func (m MySQLDialect) BindVar(i int) string { +func (d MySQLDialect) BindVar(i int) string { return "?" } -func (m MySQLDialect) InsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) { +func (d MySQLDialect) InsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) { return standardInsertAutoIncr(exec, insertSql, params...) } @@ -365,7 +392,226 @@ func (d MySQLDialect) QuoteField(f string) string { return "`" + f + "`" } -// MySQL does not have schemas like PostgreSQL does, so just escape it like normal func (d MySQLDialect) QuotedTableForQuery(schema string, table string) string { - return d.QuoteField(table) + if strings.TrimSpace(schema) == "" { + return d.QuoteField(table) + } + + return schema + "." + d.QuoteField(table) +} + +/////////////////////////////////////////////////////// +// Sql Server // +//////////////// + +// Implementation of Dialect for Microsoft SQL Server databases. +// Tested on SQL Server 2008 with driver: github.com/denisenkom/go-mssqldb +// Presently, it doesn't work with CreateTablesIfNotExists(). + +type SqlServerDialect struct { + suffix string +} + +func (d SqlServerDialect) ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string { + switch val.Kind() { + case reflect.Ptr: + return d.ToSqlType(val.Elem(), maxsize, isAutoIncr) + case reflect.Bool: + return "bit" + case reflect.Int8: + return "tinyint" + case reflect.Uint8: + return "smallint" + case reflect.Int16: + return "smallint" + case reflect.Uint16: + return "int" + case reflect.Int, reflect.Int32: + return "int" + case reflect.Uint, reflect.Uint32: + return "bigint" + case reflect.Int64: + return "bigint" + case reflect.Uint64: + return "bigint" + case reflect.Float32: + return "real" + case reflect.Float64: + return "float(53)" + case reflect.Slice: + if val.Elem().Kind() == reflect.Uint8 { + return "varbinary" + } + } + + switch val.Name() { + case "NullInt64": + return "bigint" + case "NullFloat64": + return "float(53)" + case "NullBool": + return "tinyint" + case "Time": + return "datetime" + } + + if maxsize < 1 { + maxsize = 255 + } + return fmt.Sprintf("varchar(%d)", maxsize) +} + +// Returns auto_increment +func (d SqlServerDialect) AutoIncrStr() string { + return "identity(0,1)" +} + +// Empty string removes autoincrement columns from the INSERT statements. +func (d SqlServerDialect) AutoIncrBindValue() string { + return "" +} + +func (d SqlServerDialect) AutoIncrInsertSuffix(col *ColumnMap) string { + return "" +} + +// Returns suffix +func (d SqlServerDialect) CreateTableSuffix() string { + + return d.suffix +} + +func (d SqlServerDialect) TruncateClause() string { + return "delete from" +} + +// Returns "?" +func (d SqlServerDialect) BindVar(i int) string { + return "?" +} + +func (d SqlServerDialect) InsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) { + return standardInsertAutoIncr(exec, insertSql, params...) +} + +func (d SqlServerDialect) QuoteField(f string) string { + return `"` + f + `"` +} + +func (d SqlServerDialect) QuotedTableForQuery(schema string, table string) string { + if strings.TrimSpace(schema) == "" { + return table + } + return schema + "." + table +} + +func (d SqlServerDialect) QuerySuffix() string { return ";" } + +/////////////////////////////////////////////////////// +// Oracle // +/////////// + +// Implementation of Dialect for Oracle databases. +type OracleDialect struct{} + +func (d OracleDialect) QuerySuffix() string { return "" } + +func (d OracleDialect) ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string { + switch val.Kind() { + case reflect.Ptr: + return d.ToSqlType(val.Elem(), maxsize, isAutoIncr) + case reflect.Bool: + return "boolean" + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint8, reflect.Uint16, reflect.Uint32: + if isAutoIncr { + return "serial" + } + return "integer" + case reflect.Int64, reflect.Uint64: + if isAutoIncr { + return "bigserial" + } + return "bigint" + case reflect.Float64: + return "double precision" + case reflect.Float32: + return "real" + case reflect.Slice: + if val.Elem().Kind() == reflect.Uint8 { + return "bytea" + } + } + + switch val.Name() { + case "NullInt64": + return "bigint" + case "NullFloat64": + return "double precision" + case "NullBool": + return "boolean" + case "NullTime", "Time": + return "timestamp with time zone" + } + + if maxsize > 0 { + return fmt.Sprintf("varchar(%d)", maxsize) + } else { + return "text" + } + +} + +// Returns empty string +func (d OracleDialect) AutoIncrStr() string { + return "" +} + +func (d OracleDialect) AutoIncrBindValue() string { + return "default" +} + +func (d OracleDialect) AutoIncrInsertSuffix(col *ColumnMap) string { + return " returning " + col.ColumnName +} + +// Returns suffix +func (d OracleDialect) CreateTableSuffix() string { + return "" +} + +func (d OracleDialect) TruncateClause() string { + return "truncate" +} + +// Returns "$(i+1)" +func (d OracleDialect) BindVar(i int) string { + return fmt.Sprintf(":%d", i+1) +} + +func (d OracleDialect) InsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) { + rows, err := exec.query(insertSql, params...) + if err != nil { + return 0, err + } + defer rows.Close() + + if rows.Next() { + var id int64 + err := rows.Scan(&id) + return id, err + } + + return 0, errors.New("No serial value returned for insert: " + insertSql + " Encountered error: " + rows.Err().Error()) +} + +func (d OracleDialect) QuoteField(f string) string { + return `"` + strings.ToUpper(f) + `"` +} + +func (d OracleDialect) QuotedTableForQuery(schema string, table string) string { + if strings.TrimSpace(schema) == "" { + return d.QuoteField(table) + } + + return schema + "." + d.QuoteField(table) } diff --git a/gorp.go b/gorp.go old mode 100644 new mode 100755 index 5ba0ba37..53be19fa --- a/gorp.go +++ b/gorp.go @@ -14,13 +14,58 @@ package gorp import ( "bytes" "database/sql" + "database/sql/driver" "errors" "fmt" "reflect" "regexp" "strings" + "time" ) +// Oracle String (empty string is null) +type OracleString struct { + sql.NullString +} + +// Scan implements the Scanner interface. +func (os *OracleString) Scan(value interface{}) error { + if value == nil { + os.String, os.Valid = "", false + return nil + } + os.Valid = true + return os.NullString.Scan(value) +} + +// Value implements the driver Valuer interface. +func (os OracleString) Value() (driver.Value, error) { + if !os.Valid || os.String == "" { + return nil, nil + } + return os.String, nil +} + +// A nullable Time value +type NullTime struct { + Time time.Time + Valid bool // Valid is true if Time is not NULL +} + +// Scan implements the Scanner interface. +func (nt *NullTime) Scan(value interface{}) error { + nt.Time, nt.Valid = value.(time.Time) + return nil +} + +// Value implements the driver Valuer interface. +func (nt NullTime) Value() (driver.Value, error) { + if !nt.Valid { + return nil, nil + } + return nt.Time, nil +} + var zeroVal reflect.Value var versFieldConst = "[gorp_ver_field]" @@ -106,6 +151,9 @@ type DbMap struct { // Db handle to use with this map Db *sql.DB + // Transaction handle to use with this map + Tx *sql.Tx + // Dialect implementation to use with this map Dialect Dialect @@ -123,7 +171,7 @@ type TableMap struct { TableName string SchemaName string gotype reflect.Type - columns []*ColumnMap + Columns []*ColumnMap keys []*ColumnMap uniqueTogether [][]string version *ColumnMap @@ -209,7 +257,7 @@ func (t *TableMap) ColMap(field string) *ColumnMap { } func colMapOrNil(t *TableMap, field string) *ColumnMap { - for _, col := range t.columns { + for _, col := range t.Columns { if col.fieldName == field || col.ColumnName == field { return col } @@ -302,42 +350,45 @@ func (t *TableMap) bindInsert(elem reflect.Value) (bindInstance, error) { x := 0 first := true - for y := range t.columns { - col := t.columns[y] - - if !col.Transient { - if !first { - s.WriteString(",") - s2.WriteString(",") - } - s.WriteString(t.dbmap.Dialect.QuoteField(col.ColumnName)) + for y := range t.Columns { + col := t.Columns[y] + if !(col.isAutoIncr && t.dbmap.Dialect.AutoIncrBindValue() == "") { + if !col.Transient { + if !first { + s.WriteString(",") + s2.WriteString(",") + } + s.WriteString(t.dbmap.Dialect.QuoteField(col.ColumnName)) - if col.isAutoIncr { - s2.WriteString(t.dbmap.Dialect.AutoIncrBindValue()) - plan.autoIncrIdx = y - plan.autoIncrFieldName = col.fieldName - } else { - s2.WriteString(t.dbmap.Dialect.BindVar(x)) - if col == t.version { - plan.versField = col.fieldName - plan.argFields = append(plan.argFields, versFieldConst) + if col.isAutoIncr { + s2.WriteString(t.dbmap.Dialect.AutoIncrBindValue()) + plan.autoIncrIdx = y + plan.autoIncrFieldName = col.fieldName } else { - plan.argFields = append(plan.argFields, col.fieldName) + s2.WriteString(t.dbmap.Dialect.BindVar(x)) + if col == t.version { + plan.versField = col.fieldName + plan.argFields = append(plan.argFields, versFieldConst) + } else { + plan.argFields = append(plan.argFields, col.fieldName) + } + + x++ } - - x++ + first = false } - - first = false + } else { + plan.autoIncrIdx = y + plan.autoIncrFieldName = col.fieldName } } s.WriteString(") values (") s.WriteString(s2.String()) s.WriteString(")") if plan.autoIncrIdx > -1 { - s.WriteString(t.dbmap.Dialect.AutoIncrInsertSuffix(t.columns[plan.autoIncrIdx])) + s.WriteString(t.dbmap.Dialect.AutoIncrInsertSuffix(t.Columns[plan.autoIncrIdx])) } - s.WriteString(";") + s.WriteString(t.dbmap.Dialect.QuerySuffix()) plan.query = s.String() t.insertPlan = plan @@ -354,9 +405,9 @@ func (t *TableMap) bindUpdate(elem reflect.Value) (bindInstance, error) { s.WriteString(fmt.Sprintf("update %s set ", t.dbmap.Dialect.QuotedTableForQuery(t.SchemaName, t.TableName))) x := 0 - for y := range t.columns { - col := t.columns[y] - if !col.isPK && !col.Transient { + for y := range t.Columns { + col := t.Columns[y] + if !col.isAutoIncr && !col.Transient { if x > 0 { s.WriteString(", ") } @@ -395,7 +446,7 @@ func (t *TableMap) bindUpdate(elem reflect.Value) (bindInstance, error) { s.WriteString(t.dbmap.Dialect.BindVar(x)) plan.argFields = append(plan.argFields, plan.versField) } - s.WriteString(";") + s.WriteString(t.dbmap.Dialect.QuerySuffix()) plan.query = s.String() t.updatePlan = plan @@ -411,8 +462,8 @@ func (t *TableMap) bindDelete(elem reflect.Value) (bindInstance, error) { s := bytes.Buffer{} s.WriteString(fmt.Sprintf("delete from %s", t.dbmap.Dialect.QuotedTableForQuery(t.SchemaName, t.TableName))) - for y := range t.columns { - col := t.columns[y] + for y := range t.Columns { + col := t.Columns[y] if !col.Transient { if col == t.version { plan.versField = col.fieldName @@ -441,7 +492,7 @@ func (t *TableMap) bindDelete(elem reflect.Value) (bindInstance, error) { plan.argFields = append(plan.argFields, plan.versField) } - s.WriteString(";") + s.WriteString(t.dbmap.Dialect.QuerySuffix()) plan.query = s.String() t.deletePlan = plan @@ -458,7 +509,7 @@ func (t *TableMap) bindGet() bindPlan { s.WriteString("select ") x := 0 - for _, col := range t.columns { + for _, col := range t.Columns { if !col.Transient { if x > 0 { s.WriteString(",") @@ -482,7 +533,7 @@ func (t *TableMap) bindGet() bindPlan { plan.keyFields = append(plan.keyFields, col.fieldName) } - s.WriteString(";") + s.WriteString(t.dbmap.Dialect.QuerySuffix()) plan.query = s.String() t.getPlan = plan @@ -661,7 +712,7 @@ func (m *DbMap) AddTableWithNameAndSchema(i interface{}, schema string, name str } tmap := &TableMap{gotype: t, TableName: name, SchemaName: schema, dbmap: m} - tmap.columns, tmap.version = readStructColumns(t) + tmap.Columns, tmap.version = readStructColumns(t) m.tables = append(m.tables, tmap) return tmap @@ -762,7 +813,7 @@ func (m *DbMap) createTables(ifNotExists bool) error { s.WriteString(fmt.Sprintf("%s %s (", create, m.Dialect.QuotedTableForQuery(table.SchemaName, table.TableName))) x := 0 - for _, col := range table.columns { + for _, col := range table.Columns { if !col.Transient { if x > 0 { s.WriteString(", ") @@ -810,8 +861,15 @@ func (m *DbMap) createTables(ifNotExists bool) error { } s.WriteString(") ") s.WriteString(m.Dialect.CreateTableSuffix()) - s.WriteString(";") - _, err = m.Exec(s.String()) + s.WriteString(m.Dialect.QuerySuffix()) + + // use the transaction if it's there. otherwise, use the db connection. + if m.Tx != nil { + _, err = m.Tx.Exec(s.String()) + } else { + _, err = m.Exec(s.String()) + } + if err != nil { break } @@ -873,7 +931,14 @@ func (m *DbMap) dropTableImpl(table *TableMap, addIfExists bool) (err error) { if addIfExists { ifExists = " if exists" } - _, err = m.Exec(fmt.Sprintf("drop table%s %s;", ifExists, m.Dialect.QuotedTableForQuery(table.SchemaName, table.TableName))) + + // use the transaction if it's there. otherwise, use the db connection. + if m.Tx != nil { + _, err = m.Tx.Exec(fmt.Sprintf("drop table%s %s;", ifExists, m.Dialect.QuotedTableForQuery(table.SchemaName, table.TableName))) + } else { + _, err = m.Exec(fmt.Sprintf("drop table%s %s;", ifExists, m.Dialect.QuotedTableForQuery(table.SchemaName, table.TableName))) + } + return err } @@ -1029,7 +1094,10 @@ func (m *DbMap) Begin() (*Transaction, error) { return &Transaction{m, tx, false}, nil } -func (m *DbMap) tableFor(t reflect.Type, checkPK bool) (*TableMap, error) { +// TableFor returns the *TableMap corresponding to the given Go Type +// If no table is mapped to that type an error is returned. +// If checkPK is true and the mapped table has no registered PKs, an error is returned. +func (m *DbMap) TableFor(t reflect.Type, checkPK bool) (*TableMap, error) { table := tableOrNil(m, t) if table == nil { return nil, errors.New(fmt.Sprintf("No table found for type: %v", t.Name())) @@ -1063,7 +1131,7 @@ func (m *DbMap) tableForPointer(ptr interface{}, checkPK bool) (*TableMap, refle } elem := ptrv.Elem() etype := reflect.TypeOf(elem.Interface()) - t, err := m.tableFor(etype, checkPK) + t, err := m.TableFor(etype, checkPK) if err != nil { return nil, reflect.Value{}, err } @@ -1083,8 +1151,33 @@ func (m *DbMap) query(query string, args ...interface{}) (*sql.Rows, error) { func (m *DbMap) trace(query string, args ...interface{}) { if m.logger != nil { - m.logger.Printf("%s%s %v", m.logPrefix, query, args) + var margs = argsString(args...) + m.logger.Printf("%s%s [%s]", m.logPrefix, query, margs) + } +} + +func argsString(args ...interface{}) string { + var margs string + for i, a := range args { + var v interface{} = a + if x, ok := v.(driver.Valuer); ok { + y, err := x.Value() + if err == nil { + v = y + } + } + switch v.(type) { + case string: + v = fmt.Sprintf("%q", v) + default: + v = fmt.Sprintf("%v", v) + } + margs += fmt.Sprintf("%d:%s", i+1, v) + if i+1 < len(args) { + margs += " " + } } + return margs } /////////////// @@ -1225,7 +1318,7 @@ func (t *Transaction) query(query string, args ...interface{}) (*sql.Rows, error func SelectInt(e SqlExecutor, query string, args ...interface{}) (int64, error) { var h int64 err := selectVal(e, &h, query, args...) - if err != nil { + if err != nil && err != sql.ErrNoRows { return 0, err } return h, nil @@ -1237,7 +1330,7 @@ func SelectInt(e SqlExecutor, query string, args ...interface{}) (int64, error) func SelectNullInt(e SqlExecutor, query string, args ...interface{}) (sql.NullInt64, error) { var h sql.NullInt64 err := selectVal(e, &h, query, args...) - if err != nil { + if err != nil && err != sql.ErrNoRows { return h, err } return h, nil @@ -1249,7 +1342,7 @@ func SelectNullInt(e SqlExecutor, query string, args ...interface{}) (sql.NullIn func SelectFloat(e SqlExecutor, query string, args ...interface{}) (float64, error) { var h float64 err := selectVal(e, &h, query, args...) - if err != nil { + if err != nil && err != sql.ErrNoRows { return 0, err } return h, nil @@ -1261,7 +1354,7 @@ func SelectFloat(e SqlExecutor, query string, args ...interface{}) (float64, err func SelectNullFloat(e SqlExecutor, query string, args ...interface{}) (sql.NullFloat64, error) { var h sql.NullFloat64 err := selectVal(e, &h, query, args...) - if err != nil { + if err != nil && err != sql.ErrNoRows { return h, err } return h, nil @@ -1273,7 +1366,7 @@ func SelectNullFloat(e SqlExecutor, query string, args ...interface{}) (sql.Null func SelectStr(e SqlExecutor, query string, args ...interface{}) (string, error) { var h string err := selectVal(e, &h, query, args...) - if err != nil { + if err != nil && err != sql.ErrNoRows { return "", err } return h, nil @@ -1286,7 +1379,7 @@ func SelectStr(e SqlExecutor, query string, args ...interface{}) (string, error) func SelectNullStr(e SqlExecutor, query string, args ...interface{}) (sql.NullString, error) { var h sql.NullString err := selectVal(e, &h, query, args...) - if err != nil { + if err != nil && err != sql.ErrNoRows { return h, err } return h, nil @@ -1350,14 +1443,11 @@ func selectVal(e SqlExecutor, holder interface{}, query string, args ...interfac } defer rows.Close() - if rows.Next() { - err = rows.Scan(holder) - if err != nil { - return err - } + if !rows.Next() { + return sql.ErrNoRows } - return nil + return rows.Scan(holder) } /////////////// @@ -1373,17 +1463,21 @@ func hookedselect(m *DbMap, exec SqlExecutor, i interface{}, query string, // Determine where the results are: written to i, or returned in list if t, _ := toSliceType(i); t == nil { for _, v := range list { - err = runHook("PostGet", reflect.ValueOf(v), hookArg(exec)) - if err != nil { - return nil, err + if v, ok := v.(HasPostGet); ok { + err := v.PostGet(exec) + if err != nil { + return nil, err + } } } } else { resultsValue := reflect.Indirect(reflect.ValueOf(i)) for i := 0; i < resultsValue.Len(); i++ { - err = runHook("PostGet", resultsValue.Index(i), hookArg(exec)) - if err != nil { - return nil, err + if v, ok := resultsValue.Index(i).Interface().(HasPostGet); ok { + err := v.PostGet(exec) + if err != nil { + return nil, err + } } } } @@ -1575,14 +1669,16 @@ func columnToFieldIndex(m *DbMap, t reflect.Type, cols []string) ([][]int, error // a field in the i struct for x := range cols { colName := strings.ToLower(cols[x]) - field, found := t.FieldByNameFunc(func(fieldName string) bool { + var mappedFieldName string field, _ := t.FieldByName(fieldName) - fieldName = field.Tag.Get("db") - - if fieldName == "-" { + lowerFieldName := strings.ToLower(field.Name) + mappedFieldName = field.Tag.Get("db") + if mappedFieldName == "-" && colName != lowerFieldName { return false - } else if fieldName == "" { + } else if mappedFieldName == "-" && colName == lowerFieldName { + return true + } else if mappedFieldName == "" { fieldName = field.Name } if tableMapped { @@ -1591,7 +1687,6 @@ func columnToFieldIndex(m *DbMap, t reflect.Type, cols []string) ([][]int, error fieldName = colMap.ColumnName } } - return colName == strings.ToLower(fieldName) }) if found { @@ -1668,7 +1763,7 @@ func get(m *DbMap, exec SqlExecutor, i interface{}, return nil, err } - table, err := m.tableFor(t, true) + table, err := m.TableFor(t, true) if err != nil { return nil, err } @@ -1710,16 +1805,17 @@ func get(m *DbMap, exec SqlExecutor, i interface{}, } } - err = runHook("PostGet", v, hookArg(exec)) - if err != nil { - return nil, err + if v, ok := v.Interface().(HasPostGet); ok { + err := v.PostGet(exec) + if err != nil { + return nil, err + } } return v.Interface(), nil } func delete(m *DbMap, exec SqlExecutor, list ...interface{}) (int64, error) { - hookarg := hookArg(exec) count := int64(0) for _, ptr := range list { table, elem, err := m.tableForPointer(ptr, true) @@ -1727,10 +1823,12 @@ func delete(m *DbMap, exec SqlExecutor, list ...interface{}) (int64, error) { return -1, err } - eptr := elem.Addr() - err = runHook("PreDelete", eptr, hookarg) - if err != nil { - return -1, err + eval := elem.Addr().Interface() + if v, ok := eval.(HasPreDelete); ok { + err = v.PreDelete(exec) + if err != nil { + return -1, err + } } bi, err := table.bindDelete(elem) @@ -1754,9 +1852,11 @@ func delete(m *DbMap, exec SqlExecutor, list ...interface{}) (int64, error) { count += rows - err = runHook("PostDelete", eptr, hookarg) - if err != nil { - return -1, err + if v, ok := eval.(HasPostDelete); ok { + err := v.PostDelete(exec) + if err != nil { + return -1, err + } } } @@ -1764,7 +1864,6 @@ func delete(m *DbMap, exec SqlExecutor, list ...interface{}) (int64, error) { } func update(m *DbMap, exec SqlExecutor, list ...interface{}) (int64, error) { - hookarg := hookArg(exec) count := int64(0) for _, ptr := range list { table, elem, err := m.tableForPointer(ptr, true) @@ -1772,10 +1871,12 @@ func update(m *DbMap, exec SqlExecutor, list ...interface{}) (int64, error) { return -1, err } - eptr := elem.Addr() - err = runHook("PreUpdate", eptr, hookarg) - if err != nil { - return -1, err + eval := elem.Addr().Interface() + if v, ok := eval.(HasPreUpdate); ok { + err = v.PreUpdate(exec) + if err != nil { + return -1, err + } } bi, err := table.bindUpdate(elem) @@ -1804,26 +1905,29 @@ func update(m *DbMap, exec SqlExecutor, list ...interface{}) (int64, error) { count += rows - err = runHook("PostUpdate", eptr, hookarg) - if err != nil { - return -1, err + if v, ok := eval.(HasPostUpdate); ok { + err = v.PostUpdate(exec) + if err != nil { + return -1, err + } } } return count, nil } func insert(m *DbMap, exec SqlExecutor, list ...interface{}) error { - hookarg := hookArg(exec) for _, ptr := range list { table, elem, err := m.tableForPointer(ptr, false) if err != nil { return err } - eptr := elem.Addr() - err = runHook("PreInsert", eptr, hookarg) - if err != nil { - return err + eval := elem.Addr().Interface() + if v, ok := eval.(HasPreInsert); ok { + err := v.PreInsert(exec) + if err != nil { + return err + } } bi, err := table.bindInsert(elem) @@ -1832,18 +1936,28 @@ func insert(m *DbMap, exec SqlExecutor, list ...interface{}) error { } if bi.autoIncrIdx > -1 { - id, err := m.Dialect.InsertAutoIncr(exec, bi.query, bi.args...) - if err != nil { - return err - } f := elem.FieldByName(bi.autoIncrFieldName) - k := f.Kind() - if (k == reflect.Int) || (k == reflect.Int16) || (k == reflect.Int32) || (k == reflect.Int64) { - f.SetInt(id) - } else if (k == reflect.Uint16) || (k == reflect.Uint32) || (k == reflect.Uint64) { - f.SetUint(uint64(id)) - } else { - return fmt.Errorf("gorp: Cannot set autoincrement value on non-Int field. SQL=%s autoIncrIdx=%d autoIncrFieldName=%s", bi.query, bi.autoIncrIdx, bi.autoIncrFieldName) + switch inserter := m.Dialect.(type) { + case IntegerAutoIncrInserter: + id, err := inserter.InsertAutoIncr(exec, bi.query, bi.args...) + if err != nil { + return err + } + k := f.Kind() + if (k == reflect.Int) || (k == reflect.Int16) || (k == reflect.Int32) || (k == reflect.Int64) { + f.SetInt(id) + } else if (k == reflect.Uint16) || (k == reflect.Uint32) || (k == reflect.Uint64) { + f.SetUint(uint64(id)) + } else { + return fmt.Errorf("gorp: Cannot set autoincrement value on non-Int field. SQL=%s autoIncrIdx=%d autoIncrFieldName=%s", bi.query, bi.autoIncrIdx, bi.autoIncrFieldName) + } + case TargetedAutoIncrInserter: + err := inserter.InsertAutoIncrToTarget(exec, bi.query, f.Addr().Interface(), bi.args...) + if err != nil { + return err + } + default: + return fmt.Errorf("gorp: Cannot use autoincrement fields on dialects that do not implement an autoincrementing interface") } } else { _, err := exec.Exec(bi.query, bi.args...) @@ -1852,25 +1966,11 @@ func insert(m *DbMap, exec SqlExecutor, list ...interface{}) error { } } - err = runHook("PostInsert", eptr, hookarg) - if err != nil { - return err - } - } - return nil -} - -func hookArg(exec SqlExecutor) []reflect.Value { - execval := reflect.ValueOf(exec) - return []reflect.Value{execval} -} - -func runHook(name string, eptr reflect.Value, arg []reflect.Value) error { - hook := eptr.MethodByName(name) - if hook != zeroVal { - ret := hook.Call(arg) - if len(ret) > 0 && !ret[0].IsNil() { - return ret[0].Interface().(error) + if v, ok := eval.(HasPostInsert); ok { + err := v.PostInsert(exec) + if err != nil { + return err + } } } return nil @@ -1891,3 +1991,38 @@ func lockError(m *DbMap, exec SqlExecutor, tableName string, } return -1, ole } + +// PostUpdate() will be executed after the GET statement. +type HasPostGet interface { + PostGet(SqlExecutor) error +} + +// PostUpdate() will be executed after the DELETE statement +type HasPostDelete interface { + PostDelete(SqlExecutor) error +} + +// PostUpdate() will be executed after the UPDATE statement +type HasPostUpdate interface { + PostUpdate(SqlExecutor) error +} + +// PostInsert() will be executed after the INSERT statement +type HasPostInsert interface { + PostInsert(SqlExecutor) error +} + +// PreDelete() will be executed before the DELETE statement. +type HasPreDelete interface { + PreDelete(SqlExecutor) error +} + +// PreUpdate() will be executed before UPDATE statement. +type HasPreUpdate interface { + PreUpdate(SqlExecutor) error +} + +// PreInsert() will be executed before INSERT statement. +type HasPreInsert interface { + PreInsert(SqlExecutor) error +} diff --git a/gorp_test.go b/gorp_test.go index 6876be08..45587853 100644 --- a/gorp_test.go +++ b/gorp_test.go @@ -18,6 +18,13 @@ import ( "time" ) +// verify interface compliance +var _ Dialect = SqliteDialect{} +var _ Dialect = PostgresDialect{} +var _ Dialect = MySQLDialect{} +var _ Dialect = SqlServerDialect{} +var _ Dialect = OracleDialect{} + type Invoice struct { Id int64 Created int64 @@ -64,6 +71,12 @@ type WithIgnoredColumn struct { Created int64 } +type IgnoredColumnExported struct { + Id int64 + External int64 `db:"-"` + Created int64 +} + type WithStringPk struct { Id string Name string @@ -119,6 +132,10 @@ type UniqueColumns struct { ZipCode int64 } +type SingleColumnTable struct { + SomeId string +} + type testTypeConverter struct{} func (me testTypeConverter) ToDb(val interface{}) (interface{}, error) { @@ -306,7 +323,8 @@ func TestSetUniqueTogether(t *testing.T) { t.Error(err) } // "unique" for Postgres/SQLite, "Duplicate entry" for MySQL - if !strings.Contains(err.Error(), "unique") && !strings.Contains(err.Error(), "Duplicate entry") { + errLower := strings.ToLower(err.Error()) + if !strings.Contains(errLower, "unique") && !strings.Contains(errLower, "duplicate entry") { t.Error(err) } @@ -317,7 +335,8 @@ func TestSetUniqueTogether(t *testing.T) { t.Error(err) } // "unique" for Postgres/SQLite, "Duplicate entry" for MySQL - if !strings.Contains(err.Error(), "unique") && !strings.Contains(err.Error(), "Duplicate entry") { + errLower = strings.ToLower(err.Error()) + if !strings.Contains(errLower, "unique") && !strings.Contains(errLower, "duplicate entry") { t.Error(err) } @@ -1408,7 +1427,54 @@ func TestSelectSingleVal(t *testing.T) { _insert(dbmap, &Person{0, 0, 0, "bob", "smith", 0}) err = dbmap.SelectOne(&p2, "select * from person_test where Fname='bob'") if err == nil { - t.Error("Expected nil when two rows found") + t.Error("Expected error when two rows found") + } + + // tests for #150 + var tInt int64 + var tStr string + var tBool bool + var tFloat float64 + primVals := []interface{}{tInt, tStr, tBool, tFloat} + for _, prim := range primVals { + err = dbmap.SelectOne(&prim, "select * from person_test where Id=-123") + if err == nil || err != sql.ErrNoRows { + t.Error("primVals: SelectOne should have returned sql.ErrNoRows") + } + } +} + +func TestSelectAlias(t *testing.T) { + dbmap := initDbMap() + defer dropAndClose(dbmap) + + p1 := &IgnoredColumnExported{Id: 1, External: 2, Created: 3} + _insert(dbmap, p1) + + var p2 IgnoredColumnExported + + err := dbmap.SelectOne(&p2, "select * from ignored_column_exported_test where Id=1") + if err != nil { + t.Error(err) + } + if p2.Id != 1 || p2.Created != 3 || p2.External != 0 { + t.Error("Expected ignorred field defaults to not set") + } + + err = dbmap.SelectOne(&p2, "SELECT *, 1 AS external FROM ignored_column_exported_test") + if err != nil { + t.Error(err) + } + if p2.External != 1 { + t.Error("Expected select as can map to exported field.") + } + + var rows *sql.Rows + var cols []string + rows, err = dbmap.Db.Query("SELECT * FROM ignored_column_exported_test") + cols, err = rows.Columns() + if err != nil || len(cols) != 2 { + t.Error("Expected ignored column not created") } } @@ -1436,6 +1502,37 @@ func TestMysqlPanicIfDialectNotInitialized(t *testing.T) { db.CreateTables() } +func TestSingleColumnKeyDbReturnsZeroRowsUpdatedOnPKChange(t *testing.T) { + dbmap := initDbMap() + defer dropAndClose(dbmap) + dbmap.AddTableWithName(SingleColumnTable{}, "single_column_table").SetKeys(false, "SomeId") + err := dbmap.DropTablesIfExists() + if err != nil { + t.Error("Drop tables failed") + } + err = dbmap.CreateTablesIfNotExists() + if err != nil { + t.Error("Create tables failed") + } + err = dbmap.TruncateTables() + if err != nil { + t.Error("Truncate tables failed") + } + + sct := SingleColumnTable{ + SomeId: "A Unique Id String", + } + + count, err := dbmap.Update(&sct) + if err != nil { + t.Error(err) + } + if count != 0 { + t.Errorf("Expected 0 updated rows, got %d", count) + } + +} + func BenchmarkNativeCrud(b *testing.B) { b.StopTimer() dbmap := initDbMapBench() @@ -1546,6 +1643,7 @@ func initDbMap() *DbMap { dbmap.AddTableWithName(OverriddenInvoice{}, "invoice_override_test").SetKeys(false, "Id") dbmap.AddTableWithName(Person{}, "person_test").SetKeys(true, "Id") dbmap.AddTableWithName(WithIgnoredColumn{}, "ignored_column_test").SetKeys(true, "Id") + dbmap.AddTableWithName(IgnoredColumnExported{}, "ignored_column_exported_test").SetKeys(true, "Id") dbmap.AddTableWithName(TypeConversionExample{}, "type_conv_test").SetKeys(true, "Id") dbmap.AddTableWithName(WithEmbeddedStruct{}, "embedded_struct_test").SetKeys(true, "Id") dbmap.AddTableWithName(WithEmbeddedStructBeforeAutoincrField{}, "embedded_struct_before_autoincr_test").SetKeys(true, "Id")