Skip to content
This repository has been archived by the owner on Dec 22, 2023. It is now read-only.

Fix various issues with creating user on sign up #224

Merged
merged 6 commits into from
Nov 9, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions glide.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 0 additions & 2 deletions glide.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ import:
- package: github.com/gorilla/websocket
version: b6ab76f1fe9803ee1d59e7e5b2a797c1fe897ce5
- package: github.com/jmoiron/sqlx
version: 344b1e96d6e410a093b7bce40287731a49e253f6
subpackages:
- reflectx
- package: github.com/joho/godotenv
Expand All @@ -40,7 +39,6 @@ import:
- package: github.com/lann/squirrel
version: e13dbacee404686afd0acf2d44b8d34869605e03
- package: github.com/lib/pq
version: b269bd035a727d6c1081f76e7a239a1b00674c40
subpackages:
- oid
- package: github.com/mattn/go-xmpp
Expand Down
29 changes: 16 additions & 13 deletions pkg/server/handler/recordutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,21 +135,24 @@ func (f recordFetcher) getCreationAccess(recordType string) skydb.RecordACL {

func (f recordFetcher) fetchOrCreateRecord(recordID skydb.RecordID, userInfo *skydb.UserInfo) (record *skydb.Record, err skyerr.Error) {
dbRecord := skydb.Record{}
if f.db.Get(recordID, &dbRecord) == skydb.ErrRecordNotFound {
// new record
if f.withMasterKey {
return
}
if dbErr := f.db.Get(recordID, &dbRecord); dbErr != nil {
if dbErr == skydb.ErrRecordNotFound {
// new record
if f.withMasterKey {
return
}

creationAccess := f.getCreationAccess(recordID.Type)
if !creationAccess.Accessible(userInfo, skydb.CreateLevel) {
err = skyerr.NewError(
skyerr.PermissionDenied,
"no permission to create",
)
}
creationAccess := f.getCreationAccess(recordID.Type)
if !creationAccess.Accessible(userInfo, skydb.CreateLevel) {
err = skyerr.NewError(
skyerr.PermissionDenied,
"no permission to create",
)
}

return
return
}
return nil, skyerr.NewError(skyerr.UnexpectedError, dbErr.Error())
}

record = &dbRecord
Expand Down
37 changes: 22 additions & 15 deletions pkg/server/skydb/pq/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,43 +94,50 @@ func (c *conn) Db() Ext {
}

// Begin begins a transaction.
func (c *conn) Begin() (err error) {
func (c *conn) Begin() error {
log.Debugf("%p: Beginning transaction", c)
if c.tx != nil {
return skydb.ErrDatabaseTxDidBegin
}

c.tx, err = c.db.Beginx()
tx, err := c.db.Beginx()
if err != nil {
log.Debugf("%p: Unable to begin transaction %p: %v", c, err)
return err
}
c.tx = tx
log.Debugf("%p: Done beginning transaction %p", c, c.tx)
return
return nil
}

// Commit commits a transaction.
func (c *conn) Commit() (err error) {
func (c *conn) Commit() error {
if c.tx == nil {
return skydb.ErrDatabaseTxDidNotBegin
}

err = c.tx.Commit()
if err == nil {
c.tx = nil
if err := c.tx.Commit(); err != nil {
log.Errorf("%p: Unable to commit transaction %p: %v", c, c.tx, err)
return err
}
log.Debugf("%p: Committed transaction %p", c, c.tx)
return
c.tx = nil
log.Debugf("%p: Committed transaction", c)
return nil
}

// Rollback rollbacks a transaction.
func (c *conn) Rollback() (err error) {
func (c *conn) Rollback() error {
if c.tx == nil {
return skydb.ErrDatabaseTxDidNotBegin
}

err = c.tx.Rollback()
if err == nil {
c.tx = nil
if err := c.tx.Rollback(); err != nil {
log.Errorf("%p: Unable to rollback transaction %p: %v", c, c.tx, err)
return err
}
log.Debugf("%p: Rolled back transaction %p", c, c.tx)
return
c.tx = nil
log.Debugf("%p: Rolled back transaction", c)
return nil
}

func (c *conn) PublicDB() skydb.Database {
Expand Down
22 changes: 16 additions & 6 deletions pkg/server/skydb/pq/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -283,18 +283,16 @@ func (db *database) getSequences(recordType string) ([]string, error) {
// AND tc.table_name = 'note';
func (db *database) remoteColumnTypes(recordType string) (skydb.RecordSchema, error) {
typemap := skydb.RecordSchema{}
var err error
// STEP 0: Return the cached ColumnType
if schema, ok := db.c.RecordSchema[recordType]; ok {
log.Debugf("Using cached remoteColumnTypes %s", recordType)
return schema, nil
}
defer func() {
db.c.RecordSchema[recordType] = typemap
log.Debugf("Cache remoteColumnTypes %s", recordType)
}()
log.Debugf("Querying remoteColumnTypes %s", recordType)
// STEP 1: Get the oid of the current table
var oid int
err := db.c.QueryRowx(`
err = db.c.QueryRowx(`
SELECT c.oid
FROM pg_catalog.pg_class c
LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
Expand All @@ -303,6 +301,8 @@ WHERE c.relname = $1
recordType, db.schemaName()).Scan(&oid)

if err == sql.ErrNoRows {
db.c.RecordSchema[recordType] = nil
log.Debugf("Cache remoteColumnTypes %s (no table)", recordType)
return nil, nil
}
if err != nil {
Expand Down Expand Up @@ -334,6 +334,7 @@ WHERE a.attrelid = $1 AND a.attnum > 0 AND NOT a.attisdropped`,

var columnName, pqType string
var integerColumns = []string{}
var columnErrors []error
for rows.Next() {
if err := rows.Scan(&columnName, &pqType); err != nil {
return nil, err
Expand All @@ -357,15 +358,21 @@ WHERE a.attrelid = $1 AND a.attnum > 0 AND NOT a.attisdropped`,
}
case TypeLocation:
schema.Type = skydb.TypeLocation
case TypeBigInteger:
fallthrough
case TypeInteger:
schema.Type = skydb.TypeInteger
integerColumns = append(integerColumns, columnName)
default:
return nil, fmt.Errorf("received unknown data type = %s for column = %s", pqType, columnName)
// We need to enumerate all rows, so do not simply return with the error here
columnErrors = append(columnErrors, fmt.Errorf("received unknown data type = %s for column = %s", pqType, columnName))
}

typemap[columnName] = schema
}
if len(columnErrors) > 0 {
return nil, columnErrors[0]
}

// STEP 2.1: Convert integer column to sequence column if applicable
if len(integerColumns) > 0 {
Expand Down Expand Up @@ -423,6 +430,9 @@ WHERE a.attrelid = $1 AND a.attnum > 0 AND NOT a.attisdropped`,
}
typemap[primaryColumn] = s
}

db.c.RecordSchema[recordType] = typemap
log.Debugf("Cache remoteColumnTypes %s", recordType)
return typemap, nil
}

Expand Down
21 changes: 11 additions & 10 deletions pkg/server/skydb/pq/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,15 @@ import (
// NOTE(limouren): varchar is missing because text can replace them,
// see the docs here: http://www.postgresql.org/docs/9.5/static/datatype-character.html
const (
TypeString = "text"
TypeNumber = "double precision"
TypeBoolean = "boolean"
TypeJSON = "jsonb"
TypeTimestamp = "timestamp without time zone"
TypeLocation = "geometry(Point)"
TypeInteger = "integer"
TypeSerial = "serial UNIQUE"
TypeString = "text"
TypeNumber = "double precision"
TypeBoolean = "boolean"
TypeJSON = "jsonb"
TypeTimestamp = "timestamp without time zone"
TypeLocation = "geometry(Point)"
TypeInteger = "integer"
TypeSerial = "serial UNIQUE"
TypeBigInteger = "bigint"
)

type nullJSON struct {
Expand Down Expand Up @@ -106,13 +107,13 @@ func (na *nullAsset) Scan(value interface{}) error {
return nil
}

assetName, ok := value.([]byte)
assetName, ok := value.(string)
if !ok {
return fmt.Errorf("failed to scan Asset: got type(value) = %T, expect []byte", value)
}

na.Asset = &skydb.Asset{
Name: string(assetName),
Name: assetName,
}
na.Valid = true

Expand Down