Skip to content

Commit

Permalink
feat: Implement required_roles
Browse files Browse the repository at this point in the history
  • Loading branch information
jopemachine committed Jan 31, 2025
1 parent cd61cc4 commit f7ccbee
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 60 deletions.
103 changes: 56 additions & 47 deletions src/ai/backend/manager/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
from aiodataloader import DataLoader
from aiotools import apartial
from graphene.types import Scalar
from graphene.types.objecttype import ObjectTypeMeta
from graphene.types.scalars import MAX_INT, MIN_INT
from graphql import Undefined
from graphql.language.ast import IntValueNode
Expand Down Expand Up @@ -1014,62 +1013,72 @@ async def wrapped(
return wrap


def restricted_field_resolver(field_name: str):
from .user import UserRole

async def resolver(root, info, *args, **kwargs):
cls = type(root)
required_roles = getattr(cls, "required_roles_for_fields", {}).get(field_name)

if required_roles is None:
return getattr(root, field_name, None)

if isinstance(required_roles, UserRole):
required_roles = [required_roles]

ctx: GraphQueryContext = info.context
user_role: UserRole = ctx.user["role"]

if user_role not in required_roles:
raise GenericForbidden(f"Access denied for the '{field_name}' field")

return getattr(root, field_name, None)

return resolver

def required_roles(
roles: UserRole | list[UserRole],
field_name: str | None = None,
):
"""
A flexible function that can act as either:
1) A decorator for custom resolvers
2) A resolver argument for simple fields (using the 'field_name' parameter)
Usage:
------
1) Decorator form:
@require_roles([UserRole.SUPERADMIN, UserRole.ADMIN])
async def resolve_something(root, info, *args, **kwargs):
# original resolver logic
return ...
2) Resolver argument form (for simple fields):
myfield = graphene.String(
resolver=require_roles([UserRole.SUPERADMIN], "myfield")
)
class FieldRestrictedMeta(ObjectTypeMeta):
def __new__(mcs, name, bases, attrs, **options):
cls = super().__new__(mcs, name, bases, attrs, **options)
Parameters:
-----------
roles: UserRole | list[UserRole]
A single role or a list of roles required to access the field or resolver.
field_name: str | None
If provided, returns a resolver function that fetches `field_name` from `root`.
If None, returns a decorator for custom resolver functions.
for field_name, field_obj in cls._meta.fields.items():
# Skip if the field has a custom resolver
if hasattr(cls, f"resolve_{field_name}") or field_obj.resolver:
continue
Returns:
--------
- An async resolver function if `field_name` is set.
- A decorator function if `field_name` is None.
"""
from .user import UserRole

field_obj.resolver = restricted_field_resolver(field_name)
if isinstance(roles, UserRole):
roles = [roles]

return cls
def decorator(func):
"""Decorator that checks user role before running 'func'."""

@functools.wraps(func)
async def wrapper(root, info, *args, **kwargs):
ctx = info.context
user_role: UserRole = ctx.user["role"]
if user_role not in roles:
raise GenericForbidden(
f"One of {roles} permission is required. Current role: {user_role}"
)
return await func(root, info, *args, **kwargs)

class FieldRestrictedObjectType(graphene.ObjectType, metaclass=FieldRestrictedMeta):
"""
This base class automatically assigns a resolver to each field
that checks if the user's role is in the list (or single value)
defined in `required_roles_for_fields`.
return wrapper

Usage example in a subclass:
# In case of "field_name" is provided, it returns a dynamic resolver function.
if field_name is not None:

required_roles_for_fields = {
"some_field": UserRole.SUPERADMIN,
"another_field": [UserRole.SUPERADMIN, UserRole.ADMIN],
}
@decorator
async def dynamic_resolver(root, info, *args, **kwargs):
return getattr(root, field_name, None)

If the field is not listed in `required_roles_for_fields`, no special role is required.
Note that if there's already a custom resolver for a field, that field is skipped.
"""
return dynamic_resolver

pass
# Otherwise, it returns the decorator function.
return decorator


def scoped_query(
Expand Down
51 changes: 38 additions & 13 deletions src/ai/backend/manager/models/container_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,13 @@
)
from .base import (
Base,
FieldRestrictedObjectType,
FilterExprArg,
IDColumn,
OrderExprArg,
PaginatedConnectionField,
StrEnumType,
generate_sql_info_for_gql_connection,
required_roles,
set_if_set,
)
from .gql_models.group import GroupConnection, GroupNode
Expand Down Expand Up @@ -370,24 +370,49 @@ async def handle_allowed_groups_update(
raise ContainerRegistryNotFound()


class ContainerRegistryNode(FieldRestrictedObjectType):
class ContainerRegistryNode(graphene.ObjectType):
class Meta:
interfaces = (AsyncNode,)
description = "Added in 24.09.0."

row_id = graphene.UUID(
description="Added in 24.09.0. The UUID type id of DB container_registries row."
description="Added in 24.09.0. The UUID type id of DB container_registries row.",
resolver=required_roles(UserRole.SUPERADMIN, "row_id"),
)
name = graphene.String(resolver=required_roles(UserRole.SUPERADMIN, "name"))
url = graphene.String(
required=True,
description="Added in 24.09.0.",
resolver=required_roles(UserRole.SUPERADMIN, "url"),
)
type = ContainerRegistryTypeField(
required=True,
description="Added in 24.09.0.",
resolver=required_roles(UserRole.SUPERADMIN, "type"),
)
registry_name = graphene.String(
required=True,
description="Added in 24.09.0.",
resolver=required_roles(UserRole.SUPERADMIN, "registry_name"),
)
is_global = graphene.Boolean(
description="Added in 24.09.0.", resolver=required_roles(UserRole.SUPERADMIN, "is_global")
)
project = graphene.String(
description="Added in 24.09.0.", resolver=required_roles(UserRole.SUPERADMIN, "project")
)
username = graphene.String(
description="Added in 24.09.0.", resolver=required_roles(UserRole.SUPERADMIN, "username")
)
password = graphene.String(
description="Added in 24.09.0.", resolver=required_roles(UserRole.SUPERADMIN, "password")
)
ssl_verify = graphene.Boolean(
description="Added in 24.09.0.", resolver=required_roles(UserRole.SUPERADMIN, "ssl_verify")
)
extra = graphene.JSONString(
description="Added in 24.09.3.", resolver=required_roles(UserRole.SUPERADMIN, "extra")
)
name = graphene.String()
url = graphene.String(required=True, description="Added in 24.09.0.")
type = ContainerRegistryTypeField(required=True, description="Added in 24.09.0.")
registry_name = graphene.String(required=True, description="Added in 24.09.0.")
is_global = graphene.Boolean(description="Added in 24.09.0.")
project = graphene.String(description="Added in 24.09.0.")
username = graphene.String(description="Added in 24.09.0.")
password = graphene.String(description="Added in 24.09.0.")
ssl_verify = graphene.Boolean(description="Added in 24.09.0.")
extra = graphene.JSONString(description="Added in 24.09.3.")
allowed_groups = PaginatedConnectionField(GroupConnection, description="Added in 25.2.0.")

_queryfilter_fieldspec: dict[str, FieldSpecItem] = {
Expand Down

0 comments on commit f7ccbee

Please sign in to comment.