Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for custom registration fields for registry #675

Merged
merged 13 commits into from
Jan 31, 2024
Merged
83 changes: 71 additions & 12 deletions registry/registry_db.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,8 @@ func createNamespaceTable() {
prefix TEXT NOT NULL UNIQUE,
pubkey TEXT NOT NULL,
identity TEXT,
admin_metadata TEXT CHECK (length("admin_metadata") <= 4000)
admin_metadata TEXT CHECK (length("admin_metadata") <= 4000),
custom_fields TEXT CHECK (length("custom_fields") <= 4000)
);`

_, err := db.Exec(query)
Expand Down Expand Up @@ -415,8 +416,9 @@ func getNamespaceById(id int) (*Namespace, error) {
}
ns := &Namespace{}
adminMetadataStr := ""
query := `SELECT id, prefix, pubkey, identity, admin_metadata FROM namespace WHERE id = ?`
err := db.QueryRow(query, id).Scan(&ns.ID, &ns.Prefix, &ns.Pubkey, &ns.Identity, &adminMetadataStr)
customRegFieldsStr := ""
query := `SELECT id, prefix, pubkey, identity, admin_metadata, custom_fields FROM namespace WHERE id = ?`
err := db.QueryRow(query, id).Scan(&ns.ID, &ns.Prefix, &ns.Pubkey, &ns.Identity, &adminMetadataStr, &customRegFieldsStr)
if err != nil {
return nil, err
}
Expand All @@ -426,6 +428,21 @@ func getNamespaceById(id int) (*Namespace, error) {
return nil, err
}
}
if customRegFieldsStr != "" {
if err := json.Unmarshal([]byte(customRegFieldsStr), &ns.CustomFields); err != nil {
return nil, err
}
}
// By default, JSON unmarshall will convert any generic number to float
// and we only allow integer in custom fields, so we convert them back
for key, val := range ns.CustomFields {
switch v := val.(type) {
case float64:
ns.CustomFields[key] = int(v)
case float32:
ns.CustomFields[key] = int(v)
}
}
return ns, nil
}

Expand All @@ -435,8 +452,9 @@ func getNamespaceByPrefix(prefix string) (*Namespace, error) {
}
ns := &Namespace{}
adminMetadataStr := ""
query := `SELECT id, prefix, pubkey, identity, admin_metadata FROM namespace WHERE prefix = ?`
err := db.QueryRow(query, prefix).Scan(&ns.ID, &ns.Prefix, &ns.Pubkey, &ns.Identity, &adminMetadataStr)
customRegFieldsStr := ""
query := `SELECT id, prefix, pubkey, identity, admin_metadata, custom_fields FROM namespace WHERE prefix = ?`
err := db.QueryRow(query, prefix).Scan(&ns.ID, &ns.Prefix, &ns.Pubkey, &ns.Identity, &adminMetadataStr, &customRegFieldsStr)
if err != nil {
return nil, err
}
Expand All @@ -446,6 +464,21 @@ func getNamespaceByPrefix(prefix string) (*Namespace, error) {
return nil, err
}
}
if customRegFieldsStr != "" {
if err := json.Unmarshal([]byte(customRegFieldsStr), &ns.CustomFields); err != nil {
return nil, err
}
}
// By default, JSON unmarshall will convert any generic number to float
// and we only allow integer in custom fields, so we convert them back
for key, val := range ns.CustomFields {
switch v := val.(type) {
case float64:
ns.CustomFields[key] = int(v)
case float32:
ns.CustomFields[key] = int(v)
}
}
return ns, nil
}

Expand All @@ -465,7 +498,9 @@ func getNamespacesByFilter(filterNs Namespace, serverType ServerType) ([]*Namesp
} else if serverType != "" {
return nil, errors.New(fmt.Sprint("Can't get namespace: unsupported server type: ", serverType))
}

if filterNs.CustomFields != nil {
return nil, errors.New("Unsupported operation: Can't filter against Custrom Registration field.")
}
if filterNs.ID != 0 {
return nil, errors.New("Unsupported operation: Can't filter against ID field.")
}
Expand Down Expand Up @@ -553,7 +588,7 @@ functions) used by the client.
*/

func addNamespace(ns *Namespace) error {
query := `INSERT INTO namespace (prefix, pubkey, identity, admin_metadata) VALUES (?, ?, ?, ?)`
query := `INSERT INTO namespace (prefix, pubkey, identity, admin_metadata, custom_fields) VALUES (?, ?, ?, ?, ?)`
tx, err := db.Begin()
if err != nil {
return err
Expand All @@ -573,8 +608,12 @@ func addNamespace(ns *Namespace) error {
if err != nil {
return errors.Wrap(err, "Fail to marshall AdminMetadata")
}
strCustomRegFields, err := json.Marshal(ns.CustomFields)
if err != nil {
return errors.Wrap(err, "Fail to marshall custom registration fields")
}

_, err = tx.Exec(query, ns.Prefix, ns.Pubkey, ns.Identity, strAdminMetadata)
_, err = tx.Exec(query, ns.Prefix, ns.Pubkey, ns.Identity, strAdminMetadata, strCustomRegFields)
if err != nil {
if errRoll := tx.Rollback(); errRoll != nil {
log.Errorln("Failed to rollback transaction:", errRoll)
Expand Down Expand Up @@ -610,15 +649,19 @@ func updateNamespace(ns *Namespace) error {
if err != nil {
return errors.Wrap(err, "Fail to marshall AdminMetadata")
}
strCustomRegFields, err := json.Marshal(ns.CustomFields)
if err != nil {
return errors.Wrap(err, "Fail to marshall custom registration fields")
}

// We intentionally exclude updating "identity" as this should only be updated
// when user registered through Pelican client with identity
query := `UPDATE namespace SET prefix = ?, pubkey = ?, admin_metadata = ? WHERE id = ?`
query := `UPDATE namespace SET prefix = ?, pubkey = ?, admin_metadata = ?, custom_fields = ? WHERE id = ?`
tx, err := db.Begin()
if err != nil {
return err
}
_, err = tx.Exec(query, ns.Prefix, ns.Pubkey, strAdminMetadata, ns.ID)
_, err = tx.Exec(query, ns.Prefix, ns.Pubkey, strAdminMetadata, strCustomRegFields, ns.ID)
if err != nil {
if errRoll := tx.Rollback(); errRoll != nil {
log.Errorln("Failed to rollback transaction:", errRoll)
Expand Down Expand Up @@ -681,7 +724,7 @@ func deleteNamespace(prefix string) error {
}

func getAllNamespaces() ([]*Namespace, error) {
query := `SELECT id, prefix, pubkey, identity, admin_metadata FROM namespace ORDER BY id ASC`
query := `SELECT id, prefix, pubkey, identity, admin_metadata, custom_fields FROM namespace ORDER BY id ASC`
turetske marked this conversation as resolved.
Show resolved Hide resolved
rows, err := db.Query(query)
if err != nil {
return nil, err
Expand All @@ -692,7 +735,8 @@ func getAllNamespaces() ([]*Namespace, error) {
for rows.Next() {
ns := &Namespace{}
adminMetadataStr := ""
if err := rows.Scan(&ns.ID, &ns.Prefix, &ns.Pubkey, &ns.Identity, &adminMetadataStr); err != nil {
customRegFieldsStr := ""
if err := rows.Scan(&ns.ID, &ns.Prefix, &ns.Pubkey, &ns.Identity, &adminMetadataStr, &customRegFieldsStr); err != nil {
return nil, err
}
// For backward compatibility, if adminMetadata is an empty string, don't unmarshall json
Expand All @@ -701,6 +745,21 @@ func getAllNamespaces() ([]*Namespace, error) {
return nil, err
}
}
if customRegFieldsStr != "" {
if err := json.Unmarshal([]byte(customRegFieldsStr), &ns.CustomFields); err != nil {
return nil, err
}
}
// By default, JSON unmarshall will convert any generic number to float
// and we only allow integer in custom fields, so we convert them back
for key, val := range ns.CustomFields {
switch v := val.(type) {
case float64:
ns.CustomFields[key] = int(v)
haoming29 marked this conversation as resolved.
Show resolved Hide resolved
case float32:
ns.CustomFields[key] = int(v)
}
}
namespaces = append(namespaces, ns)
}

Expand Down
28 changes: 25 additions & 3 deletions registry/registry_db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func teardownMockNamespaceDB(t *testing.T) {
}

func insertMockDBData(nss []Namespace) error {
query := `INSERT INTO namespace (prefix, pubkey, identity, admin_metadata) VALUES (?, ?, ?, ?)`
query := `INSERT INTO namespace (prefix, pubkey, identity, admin_metadata, custom_fields) VALUES (?, ?, ?, ?, ?)`
tx, err := db.Begin()
if err != nil {
return err
Expand All @@ -71,8 +71,15 @@ func insertMockDBData(nss []Namespace) error {
}
return err
}
customFieldsStr, err := json.Marshal(ns.CustomFields)
if err != nil {
if errRoll := tx.Rollback(); errRoll != nil {
return errors.Wrap(errRoll, "Failed to rollback transaction")
}
return err
}

_, err = tx.Exec(query, ns.Prefix, ns.Pubkey, ns.Identity, adminMetaStr)
_, err = tx.Exec(query, ns.Prefix, ns.Pubkey, ns.Identity, adminMetaStr, customFieldsStr)
if err != nil {
if errRoll := tx.Rollback(); errRoll != nil {
return errors.Wrap(errRoll, "Failed to rollback transaction")
Expand Down Expand Up @@ -170,6 +177,12 @@ var (
mixed = append(mixed, mockNssWithCachesNotApproved...)
return
}()

mockCustomFields = map[string]interface{}{
"key1": "value1",
"key2": 2,
"key3": true,
}
)

func TestNamespaceExistsByPrefix(t *testing.T) {
Expand All @@ -192,7 +205,7 @@ func TestNamespaceExistsByPrefix(t *testing.T) {
})
}

func TestGetNamespacesById(t *testing.T) {
func TestGetNamespaceById(t *testing.T) {
setupMockRegistryDB(t)
defer teardownMockNamespaceDB(t)

Expand All @@ -212,6 +225,7 @@ func TestGetNamespacesById(t *testing.T) {
t.Run("return-namespace-with-correct-id", func(t *testing.T) {
defer resetNamespaceDB(t)
mockNs := mockNamespace("/test", "", "", AdminMetadata{UserID: "foo"})
mockNs.CustomFields = mockCustomFields
err := insertMockDBData([]Namespace{mockNs})
require.NoError(t, err)
nss, err := getAllNamespaces()
Expand Down Expand Up @@ -317,6 +331,7 @@ func TestAddNamespace(t *testing.T) {
t.Run("insert-data-integrity", func(t *testing.T) {
defer resetNamespaceDB(t)
mockNs := mockNamespace("/test", "pubkey", "identity", AdminMetadata{UserID: "someone", Description: "Some description", SiteName: "OSG", SecurityContactUserID: "security-001"})
mockNs.CustomFields = mockCustomFields
err := addNamespace(&mockNs)
require.NoError(t, err)
got, err := getAllNamespaces()
Expand All @@ -328,6 +343,7 @@ func TestAddNamespace(t *testing.T) {
assert.Equal(t, mockNs.AdminMetadata.Description, got[0].AdminMetadata.Description)
assert.Equal(t, mockNs.AdminMetadata.SiteName, got[0].AdminMetadata.SiteName)
assert.Equal(t, mockNs.AdminMetadata.SecurityContactUserID, got[0].AdminMetadata.SecurityContactUserID)
assert.Equal(t, mockCustomFields, got[0].CustomFields)
})
}

Expand Down Expand Up @@ -459,6 +475,12 @@ func TestGetNamespacesByFilter(t *testing.T) {
_, err := getNamespacesByFilter(filterNsID, "")
require.Error(t, err, "Should return error for filtering against unsupported field ID")

filterNsCF := Namespace{
CustomFields: mockCustomFields,
}
_, err = getNamespacesByFilter(filterNsCF, "")
require.Error(t, err, "Should return error for filtering against unsupported custom fields")

filterNsIdentity := Namespace{
Identity: "someIdentity",
}
Expand Down