Skip to content

Commit

Permalink
Support non-integer autoincrement fields in postgres
Browse files Browse the repository at this point in the history
  • Loading branch information
nelsam authored and James Cooper committed May 16, 2014
1 parent 9259f03 commit 2e7bcc3
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 19 deletions.
32 changes: 24 additions & 8 deletions dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@ type Dialect interface {
// string to truncate tables
TruncateClause() string

InsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error)

// bind variable string to use when forming SQL statements
// in many dbs it is "?", but Postgres appears to use $1
//
Expand All @@ -53,6 +51,25 @@ type Dialect interface {
QuotedTableForQuery(schema string, table string) string
}

// IntegerAutoIncrInserter is implemented by dialects that can perform
// inserts with automatically incremented integer primary keys. If
// the dialect can handle automatic assignment of more than just
// integers, see TargetedAutoIncrInserter.
type IntegerAutoIncrInserter interface {
InsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error)
}

// TargetedAutoIncrInserter is implemented by dialects that can
// perform automatic assignment of any primary key type (i.e. strings
// for uuids, integers for serials, etc).
type TargetedAutoIncrInserter interface {
// InsertAutoIncrToTarget runs an insert operation and assigns the
// resulting automatically generated primary key directly to the
// passed in target, which should be a pointer to the primary key
// element.
InsertAutoIncrToTarget(exec SqlExecutor, insertSql string, target interface{}, params ...interface{}) error
}

func standardInsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) {
res, err := exec.Exec(insertSql, params...)
if err != nil {
Expand Down Expand Up @@ -225,20 +242,19 @@ func (d PostgresDialect) BindVar(i int) string {
return fmt.Sprintf("$%d", i+1)
}

func (d PostgresDialect) InsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) {
func (d PostgresDialect) InsertAutoIncrToTarget(exec SqlExecutor, insertSql string, target interface{}, params ...interface{}) error {
rows, err := exec.query(insertSql, params...)
if err != nil {
return 0, err
return err
}
defer rows.Close()

if rows.Next() {
var id int64
err := rows.Scan(&id)
return id, err
err := rows.Scan(target)
return err
}

return 0, errors.New("No serial value returned for insert: " + insertSql + " Encountered error: " + rows.Err().Error())
return errors.New("No serial value returned for insert: " + insertSql + " Encountered error: " + rows.Err().Error())
}

func (d PostgresDialect) QuoteField(f string) string {
Expand Down
32 changes: 21 additions & 11 deletions gorp.go
Original file line number Diff line number Diff line change
Expand Up @@ -1843,18 +1843,28 @@ func insert(m *DbMap, exec SqlExecutor, list ...interface{}) error {
}

if bi.autoIncrIdx > -1 {
id, err := m.Dialect.InsertAutoIncr(exec, bi.query, bi.args...)
if err != nil {
return err
}
f := elem.FieldByName(bi.autoIncrFieldName)
k := f.Kind()
if (k == reflect.Int) || (k == reflect.Int16) || (k == reflect.Int32) || (k == reflect.Int64) {
f.SetInt(id)
} else if (k == reflect.Uint16) || (k == reflect.Uint32) || (k == reflect.Uint64) {
f.SetUint(uint64(id))
} else {
return fmt.Errorf("gorp: Cannot set autoincrement value on non-Int field. SQL=%s autoIncrIdx=%d autoIncrFieldName=%s", bi.query, bi.autoIncrIdx, bi.autoIncrFieldName)
switch inserter := m.Dialect.(type) {
case IntegerAutoIncrInserter:
id, err := inserter.InsertAutoIncr(exec, bi.query, bi.args...)
if err != nil {
return err
}
k := f.Kind()
if (k == reflect.Int) || (k == reflect.Int16) || (k == reflect.Int32) || (k == reflect.Int64) {
f.SetInt(id)
} else if (k == reflect.Uint16) || (k == reflect.Uint32) || (k == reflect.Uint64) {
f.SetUint(uint64(id))
} else {
return fmt.Errorf("gorp: Cannot set autoincrement value on non-Int field. SQL=%s autoIncrIdx=%d autoIncrFieldName=%s", bi.query, bi.autoIncrIdx, bi.autoIncrFieldName)
}
case TargetedAutoIncrInserter:
err := inserter.InsertAutoIncrToTarget(exec, bi.query, f.Addr().Interface(), bi.args...)
if err != nil {
return err
}
default:
return fmt.Errorf("gorp: Cannot use autoincrement fields on dialects that do not implement an autoincrementing interface")
}
} else {
_, err := exec.Exec(bi.query, bi.args...)
Expand Down

0 comments on commit 2e7bcc3

Please sign in to comment.