diff --git a/cli/server/auth.go b/cli/server/auth.go index 8dbb4510d..79d2c2ff0 100644 --- a/cli/server/auth.go +++ b/cli/server/auth.go @@ -98,7 +98,7 @@ func (ah *copAuthHandler) serveHTTP(w http.ResponseWriter, r *http.Request) erro log.Debugf("Basic auth is not allowed; found %s", authHdr) return errBasicAuthNotAllowed } - _, err := cfg.UserRegistery.LoginUserBasicAuth(user, pwd) + _, err := cfg.UserRegistry.LoginUserBasicAuth(user, pwd) if err != nil { log.Errorf("Failed authorizing user, [error: %s]", err) return err diff --git a/cli/server/bootstrap.go b/cli/server/bootstrap.go index f6cd69638..86191c740 100644 --- a/cli/server/bootstrap.go +++ b/cli/server/bootstrap.go @@ -27,20 +27,18 @@ import ( // Bootstrap is used for bootstrapping database type Bootstrap struct { - cfg *Config } // BootstrapDB is a constructor to bootstrap the database at server startup func BootstrapDB() *Bootstrap { b := new(Bootstrap) - b.cfg = CFG return b } // PopulateUsersTable populates the user table with the users defined in the server configuration file func (b *Bootstrap) PopulateUsersTable() error { log.Debug("populateUsersTable") - for name, info := range b.cfg.Users { + for name, info := range CFG.Users { reg := NewRegisterUser() reg.RegisterUser(name, info.Type, info.Group, info.Attributes, "", info.Pass) @@ -71,7 +69,7 @@ func (b *Bootstrap) PopulateGroupsTable() { replacer := strings.NewReplacer(".", "_") viper.SetEnvKeyReplacer(replacer) - base := filepath.Base(b.cfg.ConfigFile) + base := filepath.Base(CFG.ConfigFile) filename := strings.Split(base, ".") name := filename[0] typ := filename[1] @@ -79,7 +77,7 @@ func (b *Bootstrap) PopulateGroupsTable() { viper.SetConfigName(name) viper.SetConfigType(typ) - configPath := filepath.Dir(b.cfg.ConfigFile) + configPath := filepath.Dir(CFG.ConfigFile) viper.AddConfigPath(configPath) err := viper.ReadInConfig() if err != nil { @@ -103,13 +101,13 @@ func (b *Bootstrap) registerGroup(name string, parentName string) error { log.Debugf("Registering affiliation group (%s) with parent (%s)", name, parentName) var err error - _, err = b.cfg.UserRegistery.GetGroup(name) + _, err = CFG.UserRegistry.GetGroup(name) if err == nil { log.Error("Group already registered") return errors.New("Group already registered") } - err = b.cfg.UserRegistery.InsertGroup(name, parentName) + err = CFG.UserRegistry.InsertGroup(name, parentName) if err != nil { log.Error(err) } diff --git a/cli/server/bootstrap_test.go b/cli/server/bootstrap_test.go index 4e6e3299d..4f02b09c0 100644 --- a/cli/server/bootstrap_test.go +++ b/cli/server/bootstrap_test.go @@ -49,7 +49,7 @@ func prepBootstrap() (*Bootstrap, error) { bootCFG.Home = bootPath bootCFG.DataSource = bootCFG.Home + "/cop.db" - CFG.UserRegistery, err = NewUserRegistry(bootCFG.DBdriver, bootCFG.DataSource) + CFG.UserRegistry, err = NewUserRegistry(bootCFG.DBdriver, bootCFG.DataSource) if err != nil { return nil, err } @@ -73,7 +73,7 @@ func TestAllBootstrap(t *testing.T) { func testBootstrapGroup(b *Bootstrap, t *testing.T) { b.PopulateGroupsTable() - _, err := b.cfg.UserRegistery.GetGroup("bank_b") + _, err := CFG.UserRegistry.GetGroup("bank_b") if err != nil { t.Error("Failed bootstrapping groups table") @@ -83,7 +83,7 @@ func testBootstrapGroup(b *Bootstrap, t *testing.T) { func testBootstrapUsers(b *Bootstrap, t *testing.T) { b.PopulateUsersTable() - _, err := b.cfg.UserRegistery.GetUser("admin") + _, err := CFG.UserRegistry.GetUser("admin") if err != nil { t.Error("Failed bootstrapping users table") diff --git a/cli/server/config.go b/cli/server/config.go index da4e40d54..6b854bc44 100644 --- a/cli/server/config.go +++ b/cli/server/config.go @@ -39,6 +39,7 @@ type Config struct { Users map[string]*User `json:"users,omitempty"` DBdriver string `json:"driver"` DataSource string `json:"data_source"` + UsrReg UserReg `json:"user_registry"` Home string ConfigFile string CACert string @@ -46,7 +47,12 @@ type Config struct { DB *sqlx.DB certDBAccessor certdb.Accessor Signer signer.Signer - UserRegistery spi.UserRegistry + UserRegistry spi.UserRegistry +} + +// UserReg defines the user registry properties +type UserReg struct { + MaxEnrollments int `json:"max_enrollments"` } // User information @@ -92,6 +98,10 @@ func configInit(cfg *cli.Config) { } } + if CFG.UsrReg.MaxEnrollments == 0 { + CFG.UsrReg.MaxEnrollments = 1 + } + dbg := os.Getenv("COP_DEBUG") if dbg != "" { CFG.Debug = dbg == "true" diff --git a/cli/server/dasqlite_test.go b/cli/server/dasqlite_test.go index c4873988d..2de0e4bb6 100644 --- a/cli/server/dasqlite_test.go +++ b/cli/server/dasqlite_test.go @@ -38,7 +38,7 @@ DELETE FROM Groups; ) type TestAccessor struct { - Accessor spi.UserRegistry + Accessor *Accessor DB *sqlx.DB } @@ -55,19 +55,16 @@ func TestSQLite(t *testing.T) { os.RemoveAll(dbPath) os.MkdirAll(dbPath, 0755) } - var cfg = new(Config) cfg.DBdriver = "sqlite3" cfg.DataSource = dbPath + "/cop.db" - accessor, err := NewUserRegistry(cfg.DBdriver, cfg.DataSource) - if err != nil { - t.Error("Failed to get new user registery") - } db, err := dbutil.GetDB(cfg.DBdriver, cfg.DataSource) if err != nil { t.Error("Failed to open connection to DB") } + accessor := NewDBAccessor() + accessor.SetDB(db) ta := TestAccessor{ Accessor: accessor, @@ -102,6 +99,7 @@ func testEverything(ta TestAccessor, t *testing.T) { testUpdateUser(ta, t) testInsertAndGetGroup(ta, t) testDeleteGroup(ta, t) + testUpdateAndGetField(ta, t) } func testInsertAndGetUser(ta TestAccessor, t *testing.T) { @@ -234,3 +232,30 @@ func testDeleteGroup(ta TestAccessor, t *testing.T) { t.Error("Should have errored, and not returned any results") } } + +func testUpdateAndGetField(ta TestAccessor, t *testing.T) { + ta.Truncate() + + insert := spi.UserInfo{ + Name: "testId", + Pass: "123456", + Type: "client", + Attributes: []idp.Attribute{}, + } + + err := ta.Accessor.InsertUser(insert) + if err != nil { + t.Errorf("Error occured during insert query of ID: %s, error: %s", insert.Name, err) + } + + err = ta.Accessor.UpdateField(insert.Name, serialNumber, "1234567890") + if err != nil { + t.Errorf("Error occured during updating of field serial_number for ID: %s, error: %s", insert.Name, err) + } + + _, err = ta.Accessor.GetField(insert.Name, serialNumber) + if err != nil { + t.Errorf("Error occured during get of field serial_number for ID: %s, error: %s", insert.Name, err) + } + +} diff --git a/cli/server/dbaccessor.go b/cli/server/dbaccessor.go index 3d2e80ce5..a583ae1fa 100644 --- a/cli/server/dbaccessor.go +++ b/cli/server/dbaccessor.go @@ -17,7 +17,6 @@ limitations under the License. package server import ( - "database/sql" "encoding/json" "errors" "fmt" @@ -39,58 +38,66 @@ func init() { const ( insertUser = ` -INSERT INTO Users (id, token, type, attributes, state, serial_number, authority_key_identifier) - VALUES (:id, :token, :type, :attributes, :state, :serial_number, :authority_key_identifier);` +INSERT INTO users (id, token, type, user_group, attributes, state, max_enrollments, serial_number, authority_key_identifier) + VALUES (:id, :token, :type, :user_group, :attributes, :state, :max_enrollments, :serial_number, :authority_key_identifier);` deleteUser = ` -DELETE FROM Users +DELETE FROM users WHERE (id = ?);` updateUser = ` -UPDATE Users - SET token = :token, type = :type, attributes = :attributes +UPDATE users + SET token = :token, type = :type, user_group = :user_group, attributes = :attributes WHERE (id = :id);` getUser = ` -SELECT * FROM Users +SELECT * FROM users WHERE (id = ?)` insertGroup = ` -INSERT INTO Groups (name, parent_id) +INSERT INTO groups (name, parent_id) VALUES (?, ?)` deleteGroup = ` -DELETE FROM Groups +DELETE FROM groups WHERE (name = ?)` getGroup = ` -SELECT name, parent_id FROM Groups +SELECT name, parent_id FROM groups WHERE (name = ?)` ) const ( - password = iota - state - serialNumber + serialNumber = iota aki + prekey + maxEnrollments + state ) // UserRecord defines the properties of a user type UserRecord struct { - Name string `db:"id"` - Pass string `db:"token"` - Type string `db:"type"` - Attributes string `db:"attributes"` - State int `db:"state"` - SerialNumber string `db:"serial_number"` - AKI string `db:"authority_key_identifier"` + Name string `db:"id"` + Pass string `db:"token"` + Type string `db:"type"` + Group string `db:"user_group"` + Attributes string `db:"attributes"` + State int `db:"state"` + MaxEnrollments int `db:"max_enrollments"` + SerialNumber string `db:"serial_number"` + AKI string `db:"authority_key_identifier"` +} + +// GroupRecord defines the properties of a group +type GroupRecord struct { + Name string `db:"name"` + ParentID string `db:"parent_id"` + Prekey string `db:"prekey"` } // Accessor implements db.Accessor interface. type Accessor struct { - state int - serialNumber string - db *sqlx.DB + db *sqlx.DB } // NewDBAccessor is a constructor for the database API @@ -125,15 +132,39 @@ func (d *Accessor) LoginUserBasicAuth(user, pass string) (spi.User, error) { userInfo := convertToUserInfo(&userRec) if userRec.Pass == pass { - if userRec.State == 0 { + if userRec.State >= 0 && userRec.State < userRec.MaxEnrollments { + state := userRec.State + 1 + res, err := d.db.Exec("UPDATE users SET state = ? WHERE (id = ?)", state, user) + if err != nil { + return nil, err + } + + numRowsAffected, err := res.RowsAffected() + + if err != nil { + return nil, err + } + + if numRowsAffected == 0 { + return nil, cop.NewError(cop.UserStoreError, "Failed to update the user record") + } + + if numRowsAffected != 1 { + return nil, cop.NewError(cop.UserStoreError, "%d rows are affected, should be 1 row", numRowsAffected) + } + return userInfo, nil } + _, err := d.db.Exec("UPDATE users SET token = ? WHERE (id = ?)", "", user) + if err != nil { + return nil, err + } log.Errorf("User (%s) has already been enrolled", user) return nil, cop.NewError(cop.AuthorizationFailure, "User has already been enrolled") } - log.Errorf("Incorrect password provided for user (%s)", user) - return nil, cop.NewError(cop.AuthorizationFailure, "Incorrect password provided for user (%s)", user) + log.Errorf("Incorrect username/password provided") + return nil, cop.NewError(cop.AuthorizationFailure, "Incorrect username/password provided)") } // InsertUser inserts user into database @@ -154,6 +185,7 @@ func (d *Accessor) InsertUser(user spi.UserInfo) error { Name: user.Name, Pass: user.Pass, Type: user.Type, + Group: user.Group, Attributes: string(attrBytes), }) @@ -218,6 +250,7 @@ func (d *Accessor) UpdateUser(user spi.UserInfo) error { Name: user.Name, Pass: user.Pass, Type: user.Type, + Group: user.Group, Attributes: string(attributes), }) @@ -247,21 +280,35 @@ func (d *Accessor) UpdateField(id string, field int, value interface{}) error { return err } - var res sql.Result - switch field { - case password: - log.Debug("DB: Updating field: token") - v := value.(string) - res, err = d.db.Exec("UPDATE Users SET token = ? WHERE (id = ?)", v, id) + case serialNumber: + log.Debug("Update serial number") + val := value.(string) + _, err = d.db.Exec("UPDATE users SET serial_number = ? WHERE (id = ?)", val, id) + if err != nil { + return err + } + case aki: + log.Debug("Update authority key idenitifier") + val := value.(string) + _, err = d.db.Exec("UPDATE users SET authority_key_identifier = ? WHERE (id = ?)", val, id) + if err != nil { + return err + } + case maxEnrollments: + log.Debug("Update max enrollments") + val := value.(int) + _, err = d.db.Exec("UPDATE users SET max_enrollments = ? WHERE (id = ?)", val, id) if err != nil { + log.Error(err) return err } - case field: - log.Debug("DB: Updating field: state") - v := value.(int) - res, err = d.db.Exec("UPDATE Users SET state = ? WHERE (id = ?)", v, id) + case state: + log.Debug("Update state") + val := value.(int) + _, err = d.db.Exec("UPDATE users SET state = ? WHERE (id = ?)", val, id) if err != nil { + log.Error(err) return err } default: @@ -269,17 +316,46 @@ func (d *Accessor) UpdateField(id string, field int, value interface{}) error { return cop.NewError(cop.DatabaseError, "DB: Specified field does not exist or cannot be updated") } - numRowsAffected, err := res.RowsAffected() + return err +} - if numRowsAffected == 0 { - return cop.NewError(cop.UserStoreError, "Failed to update the user record") +// GetField updates a specific field in database +func (d *Accessor) GetField(id string, field int) (interface{}, error) { + err := d.checkDB() + if err != nil { + return nil, err } - if numRowsAffected != 1 { - return cop.NewError(cop.UserStoreError, "%d rows are affected, should be 1 row", numRowsAffected) + switch field { + case prekey: + log.Debug("Get prekey") + var groupRec GroupRecord + err = d.db.Get(&groupRec, "SELECT prekey FROM groups WHERE (name = ?)", id) + if err != nil { + return nil, err + } + return groupRec.Prekey, nil + case serialNumber: + log.Debug("Get serial number") + var userRec UserRecord + err = d.db.Get(&userRec, "SELECT serial_number FROM users WHERE (id = ?)", id) + if err != nil { + return nil, err + } + return userRec.SerialNumber, nil + case aki: + log.Debug("Get authority key idenitifier") + var userRec UserRecord + err = d.db.Get(&userRec, "SELECT authority_key_identifier FROM users WHERE (id = ?)", id) + if err != nil { + return nil, err + } + return userRec.AKI, nil + default: + log.Error("DB: Specified field does not exist or cannot be retrieved") + return nil, cop.NewError(cop.DatabaseError, "DB: Specified field does not exist or cannot be retrieved") } - return err } // GetUser gets user from database @@ -314,6 +390,15 @@ func (d *Accessor) InsertGroup(name string, parentID string) error { return err } + /* + preKeyString := crypto.CreateRootPreKey() + + _, err = d.db.Exec("UPDATE groups SET prekey = ? WHERE (name = ?)", preKeyString, name) + if err != nil { + return err + } + */ + return nil } diff --git a/cli/server/dbutil/dbutil.go b/cli/server/dbutil/dbutil.go index 2700c2922..ba816f42d 100644 --- a/cli/server/dbutil/dbutil.go +++ b/cli/server/dbutil/dbutil.go @@ -56,6 +56,7 @@ func GetDB(dbdriver string, datasource string) (*sqlx.DB, error) { // NewUserRegistrySQLLite3 returns a pointer to a sqlite database func NewUserRegistrySQLLite3(datasource string) (*sqlx.DB, bool, error) { log.Debugf("Using sqlite database, connect to database in home (%s) directory", datasource) + datasource = filepath.Join(datasource) var exists bool @@ -92,15 +93,16 @@ func createSQLiteDBTables(datasource string) error { return cop.WrapError(err, cop.DatabaseError, "Failed to connect to database") } - if _, err := db.Exec("CREATE TABLE IF NOT EXISTS users (id VARCHAR(64), token bytea, type VARCHAR(64), attributes VARCHAR(256), state INTEGER, serial_number bytea, authority_key_identifier bytea)"); err != nil { + log.Debug("Creating tables...") + if _, err := db.Exec("CREATE TABLE IF NOT EXISTS users (id VARCHAR(64), token bytea, type VARCHAR(64), user_group VARCHAR(64), attributes VARCHAR(256), state INTEGER, max_enrollments INTEGER, serial_number bytea, authority_key_identifier bytea)"); err != nil { return err } - if _, err := db.Exec("CREATE TABLE IF NOT EXISTS groups (name VARCHAR(64), parent_id VARCHAR(64), group_key VARCHAR(48))"); err != nil { + if _, err := db.Exec("CREATE TABLE IF NOT EXISTS groups (name VARCHAR(64), parent_id VARCHAR(64), prekey VARCHAR(48))"); err != nil { return err } - if _, err := db.Exec("CREATE TABLE IF NOT EXISTS certificates (serial_number bytea NOT NULL, authority_key_identifier bytea NOT NULL, ca_label bytea, status bytea NOT NULL, reason int, expiry timestamp, revoked_at timestamp, pem bytea NOT NULL, PRIMARY KEY(serial_number, authority_key_identifier))"); err != nil { + if _, err := db.Exec("CREATE TABLE IF NOT EXISTS certificates (id VARCHAR(64), serial_number bytea NOT NULL, authority_key_identifier bytea NOT NULL, ca_label bytea, status bytea NOT NULL, reason int, expiry timestamp, revoked_at timestamp, pem bytea NOT NULL, PRIMARY KEY(serial_number, authority_key_identifier))"); err != nil { return err } @@ -168,17 +170,16 @@ func createPostgresDBTables(datasource string, dbName string, db *sqlx.DB) error log.Errorf("Failed to open database (%s)", dbName) } - log.Debug("Create Tables...") - if _, err := database.Exec("CREATE TABLE users (id VARCHAR(64), token bytea, type VARCHAR(64), attributes VARCHAR(256), state INTEGER, serial_number bytea, authority_key_identifier bytea)"); err != nil { + log.Debug("Creating Tables...") + if _, err := database.Exec("CREATE TABLE users (id VARCHAR(64), token bytea, type VARCHAR(64), user_group VARCHAR(64), attributes VARCHAR(256), state INTEGER, max_enrollments INTEGER, serial_number bytea, authority_key_identifier bytea)"); err != nil { log.Errorf("Error creating users table [error: %s] ", err) - return err } - if _, err := database.Exec("CREATE TABLE groups (name VARCHAR(64), parent_id VARCHAR(64), group_key VARCHAR(48))"); err != nil { + if _, err := database.Exec("CREATE TABLE groups (name VARCHAR(64), parent_id VARCHAR(64), prekey VARCHAR(48))"); err != nil { log.Errorf("Error creating groups table [error: %s] ", err) return err } - if _, err := database.Exec("CREATE TABLE certificates (serial_number bytea NOT NULL, authority_key_identifier bytea NOT NULL, ca_label bytea, status bytea NOT NULL, reason int, expiry timestamp, revoked_at timestamp, pem bytea NOT NULL, PRIMARY KEY(serial_number, authority_key_identifier))"); err != nil { + if _, err := database.Exec("CREATE TABLE certificates (id VARCHAR(64), serial_number bytea NOT NULL, authority_key_identifier bytea NOT NULL, ca_label bytea, status bytea NOT NULL, reason int, expiry timestamp, revoked_at timestamp, pem bytea NOT NULL, PRIMARY KEY(serial_number, authority_key_identifier))"); err != nil { log.Errorf("Error creating certificates table [error: %s] ", err) return err } @@ -242,16 +243,16 @@ func createMySQLTables(datasource string, dbName string, db *sqlx.DB) error { if err != nil { log.Errorf("Failed to open database (%s), err: %s", dbName, err) } - log.Debug("Create Tables...") - if _, err := database.Exec("CREATE TABLE users (id VARCHAR(64) NOT NULL, token blob, type VARCHAR(64), attributes VARCHAR(256), state INTEGER, serial_number varbinary(20), authority_key_identifier varbinary(128), PRIMARY KEY (id))"); err != nil { + log.Debug("Creating Tables...") + if _, err := database.Exec("CREATE TABLE users (id VARCHAR(64) NOT NULL, token blob, type VARCHAR(64), user_group VARCHAR(64), attributes VARCHAR(256), state INTEGER, max_enrollments INTEGER, serial_number varbinary(20), authority_key_identifier varbinary(128), PRIMARY KEY (id))"); err != nil { log.Errorf("Error creating users table [error: %s] ", err) return err } - if _, err := database.Exec("CREATE TABLE groups (name VARCHAR(64), parent_id VARCHAR(64), group_key VARCHAR(48))"); err != nil { + if _, err := database.Exec("CREATE TABLE groups (name VARCHAR(64), parent_id VARCHAR(64), prekey VARCHAR(48))"); err != nil { log.Errorf("Error creating groups table [error: %s] ", err) return err } - if _, err := database.Exec("CREATE TABLE certificates (serial_number varbinary(20) NOT NULL, authority_key_identifier varbinary(128) NOT NULL, ca_label varbinary(128), status varbinary(128) NOT NULL, reason int, expiry timestamp DEFAULT '1970-01-01 00:00:01', revoked_at timestamp DEFAULT '1970-01-01 00:00:01', pem varbinary(4096) NOT NULL, PRIMARY KEY(serial_number, authority_key_identifier))"); err != nil { + if _, err := database.Exec("CREATE TABLE certificates (id VARCHAR(64), serial_number varbinary(20) NOT NULL, authority_key_identifier varbinary(128) NOT NULL, ca_label varbinary(128), status varbinary(128) NOT NULL, reason int, expiry timestamp DEFAULT '1970-01-01 00:00:01', revoked_at timestamp DEFAULT '1970-01-01 00:00:01', pem varbinary(4096) NOT NULL, PRIMARY KEY(serial_number, authority_key_identifier))"); err != nil { log.Errorf("Error creating certificates table [error: %s] ", err) return err } diff --git a/cli/server/enroll.go b/cli/server/enroll.go index bb7a27734..54dc7fee1 100644 --- a/cli/server/enroll.go +++ b/cli/server/enroll.go @@ -28,7 +28,6 @@ import ( "github.com/cloudflare/cfssl/log" "github.com/cloudflare/cfssl/signer" cop "github.com/hyperledger/fabric-cop/api" - "github.com/hyperledger/fabric-cop/util" "github.com/jmoiron/sqlx" ) @@ -96,20 +95,6 @@ func (e *Enroll) Enroll(id string, token []byte, csrPEM []byte) ([]byte, cop.Err return nil, signErr } - tok := util.RandomString(12) - - err := e.cfg.UserRegistery.UpdateField(id, password, tok) - if err != nil { - log.Errorf("Failed to update user token - Enroll Failed [error: %s]", err) - return nil, cop.WrapError(err, cop.EnrollingUserError, "Failed to update user token - Enroll Failed") - } - - err = e.cfg.UserRegistery.UpdateField(id, state, 1) - if err != nil { - log.Errorf("Failed to update user state - Enroll Failed [error: %s]", err) - return nil, cop.WrapError(err, cop.EnrollingUserError, "Failed to update user state - Enroll Failed") - } - return cert, nil } diff --git a/cli/server/factory.go b/cli/server/factory.go index 5bfa2bb57..8a03de86f 100644 --- a/cli/server/factory.go +++ b/cli/server/factory.go @@ -22,6 +22,7 @@ limitations under the License. package server import ( + "github.com/cloudflare/cfssl/log" cop "github.com/hyperledger/fabric-cop/api" "github.com/hyperledger/fabric-cop/cli/server/dbutil" "github.com/hyperledger/fabric-cop/cli/server/spi" @@ -30,6 +31,7 @@ import ( // NewUserRegistry abstracts out the user retreival func NewUserRegistry(typ string, config string) (spi.UserRegistry, error) { + log.Debugf("Create new user registry of type: %s", typ) var db *sqlx.DB var err error var exists bool @@ -60,7 +62,7 @@ func NewUserRegistry(typ string, config string) (spi.UserRegistry, error) { dbAccessor := new(Accessor) dbAccessor.SetDB(db) - CFG.UserRegistery = dbAccessor + CFG.UserRegistry = dbAccessor if !exists { err := bootstrapDB() diff --git a/cli/server/register.go b/cli/server/register.go index b14f76e89..bff9358b3 100644 --- a/cli/server/register.go +++ b/cli/server/register.go @@ -63,8 +63,6 @@ func (h *registerHandler) Handle(w http.ResponseWriter, r *http.Request) error { return err } - // attributes, _ := json.Marshal(req.Attributes) - // Register User tok, err := reg.RegisterUser(req.User, req.Type, req.Group, req.Attributes, req.CallerID) if err != nil { @@ -116,7 +114,8 @@ func (r *Register) RegisterUser(id string, userType string, group string, attrib return "", err } - tok, err = r.registerUserID(id, userType, attributes, opt...) + tok, err = r.registerUserID(id, userType, group, attributes, opt...) + if err != nil { return "", err } @@ -147,7 +146,7 @@ func (r *Register) validateID(id string, userType string, group string) error { } // registerUserID registers a new user and its enrollmentID, role and state -func (r *Register) registerUserID(id string, userType string, attributes []idp.Attribute, opt ...string) (string, error) { +func (r *Register) registerUserID(id string, userType string, group string, attributes []idp.Attribute, opt ...string) (string, error) { log.Debugf("Registering user id: %s\n", id) mutex.Lock() defer mutex.Unlock() @@ -163,15 +162,21 @@ func (r *Register) registerUserID(id string, userType string, attributes []idp.A Name: id, Pass: tok, Type: userType, + Group: group, Attributes: attributes, } - _, err := r.cfg.UserRegistery.GetUser(id) + _, err := r.cfg.UserRegistry.GetUser(id) if err == nil { log.Error("User is already registered") - return "", errors.New("User is already registered") + return "", cop.NewError(cop.RegisteringUserError, "User is already registered") } - err = r.cfg.UserRegistery.InsertUser(insert) + err = r.cfg.UserRegistry.InsertUser(insert) + if err != nil { + return "", err + } + + err = r.cfg.UserRegistry.UpdateField(id, maxEnrollments, CFG.UsrReg.MaxEnrollments) if err != nil { return "", err } @@ -181,9 +186,8 @@ func (r *Register) registerUserID(id string, userType string, attributes []idp.A func (r *Register) isValidGroup(group string) (bool, error) { log.Debug("Validating group: " + group) - // Check cop.yaml to see if group is valid - _, err := r.cfg.UserRegistery.GetGroup(group) + _, err := r.cfg.UserRegistry.GetGroup(group) if err != nil { log.Error("Error occured getting group: ", err) return false, err @@ -209,7 +213,7 @@ func (r *Register) canRegister(registrar string, userType string) error { user, check, err := r.isRegistrar(registrar) if err != nil { - return errors.New("Can't Register: " + err.Error()) + return cop.NewError(cop.RegisteringUserError, "Can't Register: [error: %s]"+err.Error()) } if check != true { @@ -226,7 +230,7 @@ func (r *Register) canRegister(registrar string, userType string) error { if strings.ToLower(rAttr.Name) == strings.ToLower(delegateRoles) { registrarRoles := strings.Split(rAttr.Value, ",") if !util.StrContained(userType, registrarRoles) { - return errors.New("user " + registrar + " may not register type " + userType) + return cop.NewError(cop.RegisteringUserError, "user %s may not register type %s", registrar, userType) } } } @@ -238,7 +242,7 @@ func (r *Register) canRegister(registrar string, userType string) error { func (r *Register) isRegistrar(registrar string) (spi.User, bool, error) { log.Debugf("isRegistrar - Check if specified registrar (%s) has appropriate permissions", registrar) - user, err := r.cfg.UserRegistery.GetUser(registrar) + user, err := r.cfg.UserRegistry.GetUser(registrar) if err != nil { return nil, false, errors.New("Registrar does not exist") } @@ -252,5 +256,5 @@ func (r *Register) isRegistrar(registrar string) (spi.User, bool, error) { } log.Errorf("%s is not a registrar", registrar) - return nil, false, errors.New("Is not registrar") + return nil, false, cop.NewError(cop.RegisteringUserError, "%s is not a registrar", registrar) } diff --git a/cli/server/register_test.go b/cli/server/register_test.go index 3a844f2e7..e1befb196 100644 --- a/cli/server/register_test.go +++ b/cli/server/register_test.go @@ -18,6 +18,7 @@ package server import ( "errors" + "fmt" "os" "testing" @@ -67,7 +68,7 @@ func prepRegister() error { regCFG.Home = regPath regCFG.DataSource = regCFG.Home + "/cop.db" - CFG.UserRegistery, err = NewUserRegistry(regCFG.DBdriver, regCFG.DataSource) + CFG.UserRegistry, err = NewUserRegistry(regCFG.DBdriver, regCFG.DataSource) if err != nil { return err } @@ -136,7 +137,8 @@ func testRegisterDuplicateUser(t *testing.T) { t.Fatal("Expected an error when registering the same user twice") } - if err.Error() != "User is already registered" { + expectedError := fmt.Sprintf("%d: User is already registered", cop.RegisteringUserError) + if err.Error() != expectedError { t.Fatalf("Expected error was not returned when registering user twice: [%s]", err.Error()) } } diff --git a/cli/server/revoke.go b/cli/server/revoke.go index 5d51fa078..1a6559d51 100644 --- a/cli/server/revoke.go +++ b/cli/server/revoke.go @@ -83,16 +83,12 @@ func (h *revokeHandler) Handle(w http.ResponseWriter, r *http.Request) error { return notFound(w, err) } } else if req.Name != "" { - _, err := CFG.UserRegistery.GetUser(req.Name) + _, err := CFG.UserRegistry.GetUser(req.Name) if err != nil { log.Warningf("Revoke failed: %s", err) return notFound(w, err) } - // userInfo := user.(*spi.UserInfo) - // userInfo. - // user.State = -1 - // err = CFG.UserRegistery.UpdateUser(user) - err = CFG.UserRegistery.UpdateField(req.Name, state, -1) + err = CFG.UserRegistry.UpdateField(req.Name, state, -1) if err != nil { log.Warningf("Revoke failed: %s", err) return dbErr(w, err) diff --git a/cli/server/server_test.go b/cli/server/server_test.go index 216b2a89e..91cca7de3 100644 --- a/cli/server/server_test.go +++ b/cli/server/server_test.go @@ -17,6 +17,7 @@ limitations under the License. package server import ( + "encoding/base64" "fmt" "io/ioutil" "os" @@ -82,12 +83,12 @@ func TestRegisterUser(t *testing.T) { copServer := `{"serverURL":"http://localhost:8888"}` c, _ := lib.NewClient(copServer) - regReq := &idp.EnrollmentRequest{ + enrollReq := &idp.EnrollmentRequest{ Name: "admin", Secret: "adminpw", } - ID, err := c.Enroll(regReq) + ID, err := c.Enroll(enrollReq) if err != nil { t.Error("enroll of user 'admin' with password 'adminpw' failed") return @@ -99,7 +100,7 @@ func TestRegisterUser(t *testing.T) { return } - enrollReq := &idp.RegistrationRequest{ + regReq := &idp.RegistrationRequest{ Name: "TestUser1", Type: "Client", Group: "bank_a", @@ -112,9 +113,9 @@ func TestRegisterUser(t *testing.T) { } util.Unmarshal(identity, id, "identity") - enrollReq.Registrar = id + regReq.Registrar = id - _, err = c.Register(enrollReq) + _, err = c.Register(regReq) if err != nil { t.Error(err) } @@ -187,6 +188,58 @@ func TestRevoke(t *testing.T) { if err == nil { t.Error("Revoke with with bogus serial and AKI should have failed but did not") } +} + +func TestMaxEnrollment(t *testing.T) { + CFG.UsrReg.MaxEnrollments = 2 + + copServer := `{"serverURL":"http://localhost:8888"}` + c, _ := lib.NewClient(copServer) + + regReq := &idp.RegistrationRequest{ + Name: "MaxTestUser", + Type: "Client", + Group: "bank_a", + } + + id, _ := factory.NewIdentity() + identity, err := ioutil.ReadFile("/tmp/home/client.json") + if err != nil { + t.Error(err) + } + util.Unmarshal(identity, id, "identity") + + regReq.Registrar = id + + resp, err := c.Register(regReq) + if err != nil { + t.Error(err) + } + + secretBytes, err := base64.StdEncoding.DecodeString(resp.Secret) + + enrollReq := &idp.EnrollmentRequest{ + Name: "MaxTestUser", + Secret: string(secretBytes), + } + + _, err = c.Enroll(enrollReq) + if err != nil { + t.Error("Enroll of user 'MaxTestUser' failed") + return + } + + _, err = c.Enroll(enrollReq) + if err != nil { + t.Error("Enroll of user 'MaxTestUser' failed") + return + } + + _, err = c.Enroll(enrollReq) + if err == nil { + t.Error("Enroll of user should have failed, max enrollment reached") + return + } } diff --git a/cli/server/spi/user_test.go b/cli/server/spi/user_test.go index 0f7935ffb..c95146bbf 100644 --- a/cli/server/spi/user_test.go +++ b/cli/server/spi/user_test.go @@ -28,7 +28,7 @@ import ( ) func TestGetAttributes(t *testing.T) { - userInfo := &UserInfo{"TestUser1", "User1", "Client", []idp.Attribute{idp.Attribute{Name: "testName", Value: "testValue"}}} + userInfo := &UserInfo{"TestUser1", "User1", "Client", "bank_a", []idp.Attribute{idp.Attribute{Name: "testName", Value: "testValue"}}} user := NewUser(userInfo) attributes, err := user.GetAttributes() if err != nil { diff --git a/cli/server/spi/userregistry.go b/cli/server/spi/userregistry.go index 365ae5b01..ef8523247 100644 --- a/cli/server/spi/userregistry.go +++ b/cli/server/spi/userregistry.go @@ -28,6 +28,7 @@ type UserInfo struct { Name string Pass string Type string + Group string Attributes []idp.Attribute } @@ -57,6 +58,7 @@ type UserRegistry interface { UpdateUser(user UserInfo) error DeleteUser(id string) error UpdateField(id string, field int, value interface{}) error + GetField(id string, field int) (interface{}, error) GetGroup(name string) (Group, error) GetRootGroup() (Group, error) InsertGroup(name string, parentID string) error diff --git a/cli/server/user.go b/cli/server/user.go index 9b6f19e9a..6fd24f023 100644 --- a/cli/server/user.go +++ b/cli/server/user.go @@ -55,7 +55,7 @@ func getUserAttrValue(username, attrname string) (string, error) { // getUserAttrs returns a user's attributes func getUserAttrs(username string) ([]idp.Attribute, error) { log.Debugf("getUserAttributes %s", username) - user, err := CFG.UserRegistery.GetUser(username) + user, err := CFG.UserRegistry.GetUser(username) if err != nil { return nil, fmt.Errorf("user '%s' not found", username) } diff --git a/testdata/cop.json b/testdata/cop.json index a8282b26a..2b81d0c05 100644 --- a/testdata/cop.json +++ b/testdata/cop.json @@ -1,6 +1,9 @@ { "driver":"sqlite3", "data_source":"cop.db", + "user_registry": { + "max_enrollments": 1 + }, "users": { "admin": { "pass": "adminpw", diff --git a/testdata/testconfig.json b/testdata/testconfig.json index 08149043b..d79f98847 100644 --- a/testdata/testconfig.json +++ b/testdata/testconfig.json @@ -1,6 +1,9 @@ { "driver":"sqlite3", "data_source":"cop.db", + "user_registry": { + "max_enrollments": 1 + }, "users": { "admin": { "pass": "adminpw", @@ -42,13 +45,13 @@ "testUser": { "pass": "user1", "type": "client", - "group": "bank_a", + "group": "bank_b", "attrs": [] }, "testUser2": { "pass": "user2", "type": "client", - "group": "bank_a", + "group": "bank_c", "attrs": [] }, "testUser3": {