From 548cdfe5af1e51fe30b083007c34a2af4b3d30fc Mon Sep 17 00:00:00 2001 From: Yuqing Wei Date: Thu, 15 Sep 2022 16:10:35 +0800 Subject: [PATCH 1/3] extend rbac to support project id as input Signed-off-by: Yuqing Wei --- registry/access_control/api.py | 42 ++++++------ registry/access_control/rbac/access.py | 41 ++++++++---- registry/access_control/rbac/db_rbac.py | 16 +++-- registry/access_control/rbac/models.py | 66 +++++++++++++++++++ registry/purview-registry/main.py | 3 + .../registry/purview_registry.py | 13 +++- registry/sql-registry/main.py | 3 + registry/sql-registry/registry/db_registry.py | 8 +++ 8 files changed, 153 insertions(+), 39 deletions(-) diff --git a/registry/access_control/api.py b/registry/access_control/api.py index 3e7343e47..e49c4203c 100644 --- a/registry/access_control/api.py +++ b/registry/access_control/api.py @@ -20,30 +20,30 @@ async def get_projects() -> list[str]: @router.get('/projects/{project}', name="Get My Project [Read Access Required]") -async def get_project(project: str, requestor: User = Depends(project_read_access)): +async def get_project(project: str, access: UserAccess = Depends(project_read_access)): response = requests.get(url=f"{registry_url}/projects/{project}", - headers=get_api_header(requestor)).content.decode('utf-8') + headers=get_api_header(access.user_name)).content.decode('utf-8') return json.loads(response) @router.get("/projects/{project}/datasources", name="Get data sources of my project [Read Access Required]") -def get_project_datasources(project: str, requestor: User = Depends(project_read_access)) -> list: +def get_project_datasources(project: str, access: UserAccess = Depends(project_read_access)) -> list: response = requests.get(url=f"{registry_url}/projects/{project}/datasources", - headers=get_api_header(requestor)).content.decode('utf-8') + headers=get_api_header(access.user_name)).content.decode('utf-8') return json.loads(response) @router.get("/projects/{project}/features", name="Get features under my project [Read Access Required]") -def get_project_features(project: str, keyword: Optional[str] = None, requestor: User = Depends(project_read_access)) -> list: +def get_project_features(project: str, keyword: Optional[str] = None, access: UserAccess = Depends(project_read_access)) -> list: response = requests.get(url=f"{registry_url}/projects/{project}/features", - headers=get_api_header(requestor)).content.decode('utf-8') + headers=get_api_header(access.user_name)).content.decode('utf-8') return json.loads(response) @router.get("/features/{feature}", name="Get a single feature by feature Id [Read Access Required]") def get_feature(feature: str, requestor: User = Depends(get_user)) -> dict: response = requests.get(url=f"{registry_url}/features/{feature}", - headers=get_api_header(requestor)).content.decode('utf-8') + headers=get_api_header(requestor.username)).content.decode('utf-8') ret = json.loads(response) feature_qualifiedName = ret['attributes']['qualifiedName'] @@ -55,7 +55,7 @@ def get_feature(feature: str, requestor: User = Depends(get_user)) -> dict: @router.get("/features/{feature}/lineage", name="Get Feature Lineage [Read Access Required]") def get_feature_lineage(feature: str, requestor: User = Depends(get_user)) -> dict: response = requests.get(url=f"{registry_url}/features/{feature}/lineage", - headers=get_api_header(requestor)).content.decode('utf-8') + headers=get_api_header(requestor.username)).content.decode('utf-8') ret = json.loads(response) feature_qualifiedName = ret['guidEntityMap'][feature]['attributes']['qualifiedName'] @@ -68,35 +68,35 @@ def get_feature_lineage(feature: str, requestor: User = Depends(get_user)) -> di def new_project(definition: dict, requestor: User = Depends(get_user)) -> dict: rbac.init_userrole(requestor.username, definition["name"]) response = requests.post(url=f"{registry_url}/projects", json=definition, - headers=get_api_header(requestor)).content.decode('utf-8') + headers=get_api_header(requestor.username)).content.decode('utf-8') return json.loads(response) @router.post("/projects/{project}/datasources", name="Create new data source of my project [Write Access Required]") -def new_project_datasource(project: str, definition: dict, requestor: User = Depends(project_write_access)) -> dict: +def new_project_datasource(project: str, definition: dict, access: UserAccess = Depends(project_write_access)) -> dict: response = requests.post(url=f"{registry_url}/projects/{project}/datasources", json=definition, headers=get_api_header( - requestor)).content.decode('utf-8') + access.user_name)).content.decode('utf-8') return json.loads(response) @router.post("/projects/{project}/anchors", name="Create new anchors of my project [Write Access Required]") -def new_project_anchor(project: str, definition: dict, requestor: User = Depends(project_write_access)) -> dict: +def new_project_anchor(project: str, definition: dict, access: UserAccess = Depends(project_write_access)) -> dict: response = requests.post(url=f"{registry_url}/projects/{project}/anchors", json=definition, headers=get_api_header( - requestor)).content.decode('utf-8') + access.user_name)).content.decode('utf-8') return json.loads(response) @router.post("/projects/{project}/anchors/{anchor}/features", name="Create new anchor features of my project [Write Access Required]") -def new_project_anchor_feature(project: str, anchor: str, definition: dict, requestor: User = Depends(project_write_access)) -> dict: +def new_project_anchor_feature(project: str, anchor: str, definition: dict, access: UserAccess = Depends(project_write_access)) -> dict: response = requests.post(url=f"{registry_url}/projects/{project}/anchors/{anchor}/features", json=definition, headers=get_api_header( - requestor)).content.decode('utf-8') + access.user_name)).content.decode('utf-8') return json.loads(response) @router.post("/projects/{project}/derivedfeatures", name="Create new derived features of my project [Write Access Required]") -def new_project_derived_feature(project: str, definition: dict, requestor: User = Depends(project_write_access)) -> dict: +def new_project_derived_feature(project: str, definition: dict, access: UserAccess = Depends(project_write_access)) -> dict: response = requests.post(url=f"{registry_url}/projects/{project}/derivedfeatures", - json=definition, headers=get_api_header(requestor)).content.decode('utf-8') + json=definition, headers=get_api_header(access.user_name)).content.decode('utf-8') return json.loads(response) # Below are access control management APIs @@ -106,10 +106,10 @@ def get_userroles(requestor: User = Depends(get_user)) -> list: @router.post("/users/{user}/userroles/add", name="Add a new user role [Project Manage Access Required]") -def add_userrole(project: str, user: str, role: str, reason: str, requestor: User = Depends(project_manage_access)): - return rbac.add_userrole(project, user, role, reason, requestor.username) +def add_userrole(project: str, user: str, role: str, reason: str, access: UserAccess = Depends(project_manage_access)): + return rbac.add_userrole(access.project_name, user, role, reason, access.user_name) @router.delete("/users/{user}/userroles/delete", name="Delete a user role [Project Manage Access Required]") -def delete_userrole(project: str, user: str, role: str, reason: str, requestor: User = Depends(project_manage_access)): - return rbac.delete_userrole(project, user, role, reason, requestor.username) +def delete_userrole(user: str, role: str, reason: str, access: UserAccess= Depends(project_manage_access)): + return rbac.delete_userrole(access.project_name, user, role, reason, access.user_name) diff --git a/registry/access_control/rbac/access.py b/registry/access_control/rbac/access.py index f7949c307..ff74495a1 100644 --- a/registry/access_control/rbac/access.py +++ b/registry/access_control/rbac/access.py @@ -1,8 +1,9 @@ -from typing import Any +from typing import Any, Union +from uuid import UUID from fastapi import Depends, HTTPException, status from rbac.db_rbac import DbRBAC -from rbac.models import AccessType, User +from rbac.models import AccessType, User, UserAccess,_to_uuid from rbac.auth import authorize """ @@ -22,24 +23,25 @@ def get_user(user: User = Depends(authorize)) -> User: return user -def project_read_access(project: str, user: User = Depends(authorize)) -> User: +def project_read_access(project: str, user: User = Depends(authorize)) -> UserAccess: return _project_access(project, user, AccessType.READ) -def project_write_access(project: str, user: User = Depends(authorize)) -> User: +def project_write_access(project: str, user: User = Depends(authorize)) -> UserAccess: return _project_access(project, user, AccessType.WRITE) -def project_manage_access(project: str, user: User = Depends(authorize)) -> User: +def project_manage_access(project: str, user: User = Depends(authorize)) -> UserAccess: return _project_access(project, user, AccessType.MANAGE) -def _project_access(project: str, user: User, access: str): +def _project_access(project: str, user: User, access: str) -> UserAccess: + project = _get_project_name(project) if rbac.validate_project_access_users(project, user.username, access): - return user + return UserAccess(user.username, project) else: raise ForbiddenAccess( - f"{access} privileges for project {project} required for user {user.username}") + f"{access} access for project {project} is required for user {user.username}") def global_admin_access(user: User = Depends(authorize)): @@ -48,16 +50,29 @@ def global_admin_access(user: User = Depends(authorize)): else: raise ForbiddenAccess('Admin privileges required') -def validate_project_access_for_feature(feature:str, user:str, access:str): +def validate_project_access_for_feature(feature:str, user:User, access:str): project = _get_project_from_feature(feature) _project_access(project, user, access) - def _get_project_from_feature(feature: str): feature_delimiter = "__" return feature.split(feature_delimiter)[0] -def get_api_header(requestor: User): +def get_api_header(username: str): return { - "x-registry-requestor": requestor.username - } \ No newline at end of file + "x-registry-requestor": username + } + +def _get_project_name(id_or_name: Union[str, UUID]): + try: + _to_uuid(id_or_name) + if id_or_name not in rbac.projects_ids: + # refresh project id map if id not found + rbac.get_projects_ids() + return rbac.projects_ids[id_or_name] + except KeyError: + raise ForbiddenAccess(f"Project Id {id_or_name} not found in Registry") + except ValueError: + pass + # It is a name + return id_or_name \ No newline at end of file diff --git a/registry/access_control/rbac/db_rbac.py b/registry/access_control/rbac/db_rbac.py index e0f1f37ed..5dd2ee25e 100644 --- a/registry/access_control/rbac/db_rbac.py +++ b/registry/access_control/rbac/db_rbac.py @@ -1,8 +1,10 @@ +import json +import requests from fastapi import HTTPException, status from typing import Any from rbac import config from rbac.database import connect -from rbac.models import AccessType, UserRole, RoleType, SUPER_ADMIN_SCOPE +from rbac.models import AccessType, UserRole, RoleType, SUPER_ADMIN_SCOPE, _to_uuid from rbac.interface import RBAC import os import logging @@ -19,6 +21,7 @@ def __init__(self): os.environ["RBAC_CONNECTION_STR"] = config.RBAC_CONNECTION_STR self.conn = connect() self.get_userroles() + self.get_projects_ids() def get_userroles(self): # Cache is not supported in cluster, make sure every operation read from database. @@ -107,7 +110,7 @@ def add_userrole(self, project_name: str, user_name: str, role_name: str, create query = fr"""insert into userroles (project_name, user_name, role_name, create_by, create_reason, create_time) values ('%s','%s','%s','%s' ,'%s', getutcdate())""" self.conn.update(query % (project_name, user_name, - role_name, by, create_reason)) + role_name, by, create_reason.replace("'", "''"))) logging.info( f"Userrole added with query: {query%(project_name, user_name, role_name, by, create_reason)}") self.get_userroles() @@ -122,7 +125,7 @@ def delete_userrole(self, project_name: str, user_name: str, role_name: str, del [delete_time] = getutcdate() WHERE [user_name] = '%s' and [project_name] = '%s' and [role_name] = '%s' and [delete_time] is null""" - self.conn.update(query % (by, delete_reason, + self.conn.update(query % (by, delete_reason.replace("'", "''"), user_name, project_name, role_name)) logging.info( f"Userrole removed with query: {query%(by, delete_reason, user_name, project_name, role_name)}") @@ -148,8 +151,8 @@ def init_userrole(self, creator_name: str, project_name:str): else: # initialize project admin if project not exist: self.init_project_admin(creator_name, project_name) + - def init_project_admin(self, creator_name: str, project_name: str): """initialize the creator as project admin when a new project is created """ @@ -160,3 +163,8 @@ def init_project_admin(self, creator_name: str, project_name: str): self.conn.update(query % (project_name, creator_name, RoleType.ADMIN.value, create_by, create_reason)) logging.info(f"Userrole initialized with query: {query%(project_name, creator_name, RoleType.ADMIN.value, create_by, create_reason)}") return self.get_userroles() + + def get_projects_ids(self): + """cache all project ids from registry api""" + response = requests.get(url=f"{config.RBAC_REGISTRY_URL}/projects-ids").content.decode('utf-8') + self.projects_ids = json.loads(response) \ No newline at end of file diff --git a/registry/access_control/rbac/models.py b/registry/access_control/rbac/models.py index 602763c3e..d865a8cb4 100644 --- a/registry/access_control/rbac/models.py +++ b/registry/access_control/rbac/models.py @@ -1,7 +1,9 @@ +import re from typing import List, Optional from pydantic import BaseModel from datetime import datetime from enum import Enum +from uuid import UUID class User(BaseModel): id: str @@ -97,3 +99,67 @@ def to_dict(self) -> dict: "project_name": self.project_name, "access": self.access_name, } + +class UserAccess(): + def __init__(self, + user_name: str, + project_name: str): + self.user_name = user_name + self.project_name = project_name + +def to_snake(d, level: int = 0): + """ + Convert `string`, `list[string]`, or all keys in a `dict` into snake case + The maximum length of input string or list is 100, or it will be truncated before being processed, for dict, the exception will be thrown if it has more than 100 keys. + the maximum nested level is 10, otherwise the exception will be thrown + """ + if level >= 10: + raise ValueError("Too many nested levels") + if isinstance(d, str): + d = d[:100] + return re.sub(r'(? 100: + raise ValueError("Dict has too many keys") + return {to_snake(a, level + 1): to_snake(b, level + 1) if isinstance(b, (dict, list)) else b for a, b in d.items()} + + + +def _to_type(value, type): + """ + Convert `value` into `type`, + or `list[type]` if `value` is a list + NOTE: This is **not** a generic implementation, only for objects in this module + """ + if isinstance(value, type): + return value + if isinstance(value, list): + return list([_to_type(v, type) for v in value]) + if isinstance(value, dict): + if hasattr(type, "new"): + try: + # The convention is to use `new` method to create the object from a dict + return type.new(**to_snake(value)) + except TypeError: + pass + return type(**to_snake(value)) + if issubclass(type, Enum): + try: + n = int(value) + return type(n) + except ValueError: + pass + if hasattr(type, "new"): + try: + # As well as Enum types, some of them have alias that cannot be handled by default Enum constructor + return type.new(value) + except KeyError: + pass + return type[value] + return type(value) + + +def _to_uuid(value): + return _to_type(value, UUID) diff --git a/registry/purview-registry/main.py b/registry/purview-registry/main.py index 5818cd513..61c746410 100644 --- a/registry/purview-registry/main.py +++ b/registry/purview-registry/main.py @@ -48,6 +48,9 @@ def to_camel(s): def get_projects() -> list[str]: return registry.get_projects() +@router.get("/projects-ids") +def get_projects_ids() -> dict: + return registry.get_projects_ids() @router.get("/projects/{project}",tags=["Project"]) def get_projects(project: str) -> dict: diff --git a/registry/purview-registry/registry/purview_registry.py b/registry/purview-registry/registry/purview_registry.py index 0d43dd6e6..9f5f47560 100644 --- a/registry/purview-registry/registry/purview_registry.py +++ b/registry/purview-registry/registry/purview_registry.py @@ -47,7 +47,18 @@ def get_projects(self) -> list[str]: result = self.purview_client.discovery.query(filter=searchTerm) result_entities = result['value'] return [x['qualifiedName'] for x in result_entities] - + + def get_projects_ids(self) -> dict: + """ + Returns the names and ids of all projects""" + searchTerm = {"entityType": str(EntityType.Project)} + result = self.purview_client.discovery.query(filter=searchTerm) + result_entities = result['value'] + projects = {} + for x in result_entities: + projects[x['id']] = x['qualifiedName'] + return projects + def get_entity(self, id_or_name: Union[str, UUID],recursive = False) -> Entity: id = self.get_entity_id(id_or_name) if not id: diff --git a/registry/sql-registry/main.py b/registry/sql-registry/main.py index 00ac1d422..ae7662b6f 100644 --- a/registry/sql-registry/main.py +++ b/registry/sql-registry/main.py @@ -33,6 +33,9 @@ def get_projects() -> list[str]: return registry.get_projects() +@router.get("/projects-ids") +def get_projects_ids() -> dict: + return registry.get_projects_ids() @router.get("/projects/{project}") def get_projects(project: str) -> dict: diff --git a/registry/sql-registry/registry/db_registry.py b/registry/sql-registry/registry/db_registry.py index 58f4b98db..9686ac5a7 100644 --- a/registry/sql-registry/registry/db_registry.py +++ b/registry/sql-registry/registry/db_registry.py @@ -25,6 +25,14 @@ def get_projects(self) -> list[str]: ret = self.conn.query( f"select qualified_name from entities where entity_type=%s", str(EntityType.Project)) return list([r["qualified_name"] for r in ret]) + + def get_projects_ids(self) -> dict: + projects = {} + ret = self.conn.query( + f"select entity_id, qualified_name from entities where entity_type=%s", str(EntityType.Project)) + for r in ret: + projects[r['entity_id']] = r['qualified_name'] + return projects def get_entity(self, id_or_name: Union[str, UUID]) -> Entity: return self._fill_entity(self._get_entity(id_or_name)) From 6c1bae9c4e8e16ba31f9e5aacf9577845c021a04 Mon Sep 17 00:00:00 2001 From: Yuqing Wei Date: Thu, 15 Sep 2022 16:25:36 +0800 Subject: [PATCH 2/3] update registry docs and interface Signed-off-by: Yuqing Wei --- registry/purview-registry/api-spec.md | 5 +++++ registry/purview-registry/registry/interface.py | 7 +++++++ registry/sql-registry/api-spec.md | 5 +++++ registry/sql-registry/registry/interface.py | 7 +++++++ 4 files changed, 24 insertions(+) diff --git a/registry/purview-registry/api-spec.md b/registry/purview-registry/api-spec.md index 1b14cae8b..d2e82a878 100644 --- a/registry/purview-registry/api-spec.md +++ b/registry/purview-registry/api-spec.md @@ -277,6 +277,11 @@ List **names** of all projects. Response Type: `array` +### `GET /projects-ids` +Dictionary of **id** to **names** mapping of all projects. + +Response Type: `dict` + ### `GET /projects/{project}` Get everything defined in the project diff --git a/registry/purview-registry/registry/interface.py b/registry/purview-registry/registry/interface.py index 78e79cb88..7559a3f27 100644 --- a/registry/purview-registry/registry/interface.py +++ b/registry/purview-registry/registry/interface.py @@ -12,6 +12,13 @@ def get_projects(self) -> list[str]: """ pass + @abstractmethod + def get_projects_ids(self) -> dict: + """ + Returns the ids to names mapping of all projects + """ + pass + @abstractmethod def get_entity(self, id_or_name: Union[str, UUID],recursive = False) -> Entity: """ diff --git a/registry/sql-registry/api-spec.md b/registry/sql-registry/api-spec.md index 1b14cae8b..d2e82a878 100644 --- a/registry/sql-registry/api-spec.md +++ b/registry/sql-registry/api-spec.md @@ -277,6 +277,11 @@ List **names** of all projects. Response Type: `array` +### `GET /projects-ids` +Dictionary of **id** to **names** mapping of all projects. + +Response Type: `dict` + ### `GET /projects/{project}` Get everything defined in the project diff --git a/registry/sql-registry/registry/interface.py b/registry/sql-registry/registry/interface.py index dbaf2e8fd..7f1439079 100644 --- a/registry/sql-registry/registry/interface.py +++ b/registry/sql-registry/registry/interface.py @@ -14,6 +14,13 @@ def get_projects(self) -> list[str]: """ pass + @abstractmethod + def get_projects_ids(self) -> dict: + """ + Returns the ids to names mapping of all projects + """ + pass + @abstractmethod def get_entity(self, id_or_name: Union[str, UUID]) -> Entity: """ From 8478d147eb0b639e694e3a7d30b5468c2ca99ae4 Mon Sep 17 00:00:00 2001 From: Yuqing Wei Date: Mon, 19 Sep 2022 20:20:16 +0800 Subject: [PATCH 3/3] user name case sensitive hot fix Signed-off-by: Yuqing Wei --- registry/access_control/README.md | 2 +- registry/access_control/rbac/db_rbac.py | 18 +++++++++--------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/registry/access_control/README.md b/registry/access_control/README.md index 56f6b6cc2..ac317239d 100644 --- a/registry/access_control/README.md +++ b/registry/access_control/README.md @@ -58,7 +58,7 @@ Admin roles can add or delete roles in management UI page or through management | RBAC_API_AUDIENCE | Used as audience to decode jwt tokens | ## Notes - +Please notice that User Role records are **NOT** case sensitive. All records will be converted to lower case before saving to database. Supported scenarios status are tracked below: - General Foundations: diff --git a/registry/access_control/rbac/db_rbac.py b/registry/access_control/rbac/db_rbac.py index 5dd2ee25e..6b04c4a45 100644 --- a/registry/access_control/rbac/db_rbac.py +++ b/registry/access_control/rbac/db_rbac.py @@ -59,9 +59,9 @@ def get_userroles_by_user(self, user_name: str, role_name: str = None) -> list[U where delete_reason is null and user_name ='%s'""" if role_name: query += fr"and role_name = '%s'" - rows = self.conn.query(query % (user_name, role_name)) + rows = self.conn.query(query % (user_name.lower(), role_name.lower())) else: - rows = self.conn.query(query % (user_name)) + rows = self.conn.query(query % (user_name.lower())) ret = [] for row in rows: ret.append(UserRole(**row)) @@ -75,9 +75,9 @@ def get_userroles_by_project(self, project_name: str, role_name: str = None) -> where delete_reason is null and project_name ='%s'""" if role_name: query += fr"and role_name = '%s'" - rows = self.conn.query(query % (project_name, role_name)) + rows = self.conn.query(query % (project_name.lower(), role_name.lower())) else: - rows = self.conn.query(query % (project_name)) + rows = self.conn.query(query % (project_name.lower())) ret = [] for row in rows: ret.append(UserRole(**row)) @@ -109,8 +109,8 @@ def add_userrole(self, project_name: str, user_name: str, role_name: str, create # insert new record query = fr"""insert into userroles (project_name, user_name, role_name, create_by, create_reason, create_time) values ('%s','%s','%s','%s' ,'%s', getutcdate())""" - self.conn.update(query % (project_name, user_name, - role_name, by, create_reason.replace("'", "''"))) + self.conn.update(query % (project_name.lower(), user_name.lower(), + role_name.lower(), by, create_reason.replace("'", "''"))) logging.info( f"Userrole added with query: {query%(project_name, user_name, role_name, by, create_reason)}") self.get_userroles() @@ -126,7 +126,7 @@ def delete_userrole(self, project_name: str, user_name: str, role_name: str, del WHERE [user_name] = '%s' and [project_name] = '%s' and [role_name] = '%s' and [delete_time] is null""" self.conn.update(query % (by, delete_reason.replace("'", "''"), - user_name, project_name, role_name)) + user_name.lower(), project_name.lower(), role_name.lower())) logging.info( f"Userrole removed with query: {query%(by, delete_reason, user_name, project_name, role_name)}") self.get_userroles() @@ -144,7 +144,7 @@ def init_userrole(self, creator_name: str, project_name:str): query = fr"""select project_name, user_name, role_name, create_by, create_reason, create_time, delete_reason, delete_time from userroles where delete_reason is null and project_name ='%s'""" - rows = self.conn.query(query%(project_name)) + rows = self.conn.query(query%(project_name.lower())) if len(rows) > 0: logging.warning(f"{project_name} already exist, please pick another name.") return @@ -160,7 +160,7 @@ def init_project_admin(self, creator_name: str, project_name: str): create_reason = "creator of project, get admin by default." query = fr"""insert into userroles (project_name, user_name, role_name, create_by, create_reason, create_time) values ('%s','%s','%s','%s','%s', getutcdate())""" - self.conn.update(query % (project_name, creator_name, RoleType.ADMIN.value, create_by, create_reason)) + self.conn.update(query % (project_name.lower(), creator_name.lower(), RoleType.ADMIN.value, create_by, create_reason)) logging.info(f"Userrole initialized with query: {query%(project_name, creator_name, RoleType.ADMIN.value, create_by, create_reason)}") return self.get_userroles()