Skip to content

Commit

Permalink
feat: Implement `ProjectPermissionContextBuilder.build_ctx_in_contain…
Browse files Browse the repository at this point in the history
…er_registry_scope()`
  • Loading branch information
jopemachine committed Jan 22, 2025
1 parent 1a2003a commit f5d681b
Show file tree
Hide file tree
Showing 5 changed files with 185 additions and 51 deletions.
2 changes: 1 addition & 1 deletion src/ai/backend/manager/api/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -1527,7 +1527,7 @@ type ContainerRegistryNode implements Node {
extra: JSONString

"""Added in 25.2.0."""
allowed_groups(limit: Int, offset: Int): [GroupNode]
allowed_groups(filter: String, order: String, offset: Int, before: String, after: String, first: Int, last: Int): GroupConnection

Check notice on line 1530 in src/ai/backend/manager/api/schema.graphql

View workflow job for this annotation

GitHub Actions / GraphQL Inspector

Field 'allowed_groups' was added to object type 'ContainerRegistryNode'

Field 'allowed_groups' was added to object type 'ContainerRegistryNode'
}

"""Added in 24.09.0."""
Expand Down
104 changes: 75 additions & 29 deletions src/ai/backend/manager/models/container_registry.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from __future__ import annotations

import enum
import logging
import uuid
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, Mapping, MutableMapping, cast
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Mapping, MutableMapping, Optional, TypeAlias, cast

import graphene
import graphql
Expand All @@ -18,27 +20,28 @@
from ai.backend.common.exception import UnknownImageRegistry
from ai.backend.common.logging_utils import BraceStyleAdapter
from ai.backend.manager.api.exceptions import ContainerRegistryNotFound
from ai.backend.manager.models.association_container_registries_groups import (
AssociationContainerRegistriesGroupsRow,
)
from ai.backend.manager.models.gql_models.group import GroupNode
from ai.backend.manager.models.group import GroupRow
from ai.backend.manager.models.utils import ExtendedAsyncSAEngine
from ai.backend.manager.models.rbac import SystemScope

from ..defs import PASSWORD_PLACEHOLDER
from .association_container_registries_groups import (
AssociationContainerRegistriesGroupsRow,
)
from .base import (
Base,
FilterExprArg,
IDColumn,
OrderExprArg,
PaginatedConnectionField,
StrEnumType,
generate_sql_info_for_gql_connection,
set_if_set,
)
from .gql_models.group import GroupConnection, GroupNode
from .gql_relay import AsyncNode, Connection, ConnectionResolverResult
from .minilang.ordering import OrderSpecItem, QueryOrderParser
from .minilang.queryfilter import FieldSpecItem, QueryFilterParser
from .user import UserRole
from .utils import ExtendedAsyncSAEngine

if TYPE_CHECKING:
from .gql import GraphQueryContext
Expand All @@ -56,9 +59,45 @@
"CreateContainerRegistryNode",
"ModifyContainerRegistryNode",
"DeleteContainerRegistryNode",
"ContainerRegistryScope",
)


WhereClauseType: TypeAlias = (
sa.sql.expression.BinaryExpression | sa.sql.expression.BooleanClauseList
)


class ContainerRegistryScopeType(enum.StrEnum):
USER = "user"
PROJECT = "project"


@dataclass
class ContainerRegistryScope:
scope_type: ContainerRegistryScopeType
registry_id: uuid.UUID

def __str__(self) -> str:
match self.registry_id:
case uuid.UUID():
return f"{self.scope_type}:{str(self.registry_id)}"
case _:
raise ValueError(f"Invalid container registry scope ID: {str(self.registry_id)!r}")

def __repr__(self) -> str:
return self.__str__()

@classmethod
def parse(cls, raw: str) -> ContainerRegistryScope:
scope_type, _, registry_id = raw.partition(":")
match scope_type.lower():
case ContainerRegistryScopeType.PROJECT | ContainerRegistryScopeType.USER as t:
return cls(t, uuid.UUID(registry_id))
case _:
raise ValueError(f"Invalid container registry scope type: {scope_type!r}")


class ContainerRegistryRow(Base):
__tablename__ = "container_registries"
id = IDColumn()
Expand Down Expand Up @@ -332,9 +371,7 @@ class Meta:
password = graphene.String(description="Added in 24.09.0.")
ssl_verify = graphene.Boolean(description="Added in 24.09.0.")
extra = graphene.JSONString(description="Added in 24.09.3.")
allowed_groups = graphene.List(
GroupNode, description="Added in 25.2.0.", limit=graphene.Int(), offset=graphene.Int()
)
allowed_groups = PaginatedConnectionField(GroupConnection, description="Added in 25.2.0.")

_queryfilter_fieldspec: dict[str, FieldSpecItem] = {
"row_id": ("id", None),
Expand Down Expand Up @@ -423,26 +460,35 @@ def from_row(cls, ctx: GraphQueryContext, row: ContainerRegistryRow) -> Containe
async def resolve_allowed_groups(
self,
info: graphene.ResolveInfo,
limit: int,
offset: int,
) -> list[GroupNode]:
graph_ctx: GraphQueryContext = info.context
registry_id = self.id

async with graph_ctx.db.begin_readonly() as db_session:
query = (
sa.select(GroupRow)
.select_from(GroupRow)
.join(
AssociationContainerRegistriesGroupsRow,
GroupRow.id == AssociationContainerRegistriesGroupsRow.group_id,
)
.where(AssociationContainerRegistriesGroupsRow.registry_id == registry_id)
.limit(limit)
.offset(offset)
filter: Optional[str] = None,
order: Optional[str] = None,
offset: Optional[int] = None,
after: Optional[str] = None,
first: Optional[int] = None,
before: Optional[str] = None,
last: Optional[int] = None,
) -> ConnectionResolverResult[GroupNode]:
if self.is_global:
scope = SystemScope()
container_registry_scope = None
else:
scope = None
container_registry_scope = ContainerRegistryScope.parse(
f"{ContainerRegistryScopeType.PROJECT}:{self.id}"
)
groups = (await db_session.execute(query)).all()
return [GroupNode.from_row(graph_ctx, row) for row in groups]

return await GroupNode.get_connection(
info,
scope,
container_registry_scope,
filter_expr=filter,
order_expr=order,
offset=offset,
after=after,
first=first,
before=before,
last=last,
)


class ContainerRegistryConnection(Connection):
Expand Down
31 changes: 23 additions & 8 deletions src/ai/backend/manager/models/gql.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
from ..idle import IdleCheckerHost
from ..models.utils import ExtendedAsyncSAEngine
from ..registry import AgentRegistry
from .container_registry import ContainerRegistryScope
from .storage import StorageSessionManager

from ..api.exceptions import (
Expand Down Expand Up @@ -133,7 +134,12 @@
from .keypair import CreateKeyPair, DeleteKeyPair, KeyPair, KeyPairList, ModifyKeyPair
from .network import CreateNetwork, DeleteNetwork, ModifyNetwork, NetworkConnection, NetworkNode
from .rbac import ProjectScope, ScopeType, SystemScope
from .rbac.permission_defs import AgentPermission, ComputeSessionPermission, DomainPermission
from .rbac.permission_defs import (
AgentPermission,
ComputeSessionPermission,
DomainPermission,
ProjectPermission,
)
from .rbac.permission_defs import VFolderPermission as VFolderRBACPermission
from .resource_policy import (
CreateKeyPairResourcePolicy,
Expand Down Expand Up @@ -447,6 +453,9 @@ class Queries(graphene.ObjectType):
description="Added in 24.03.0.",
filter=graphene.String(description="Added in 24.09.0."),
order=graphene.String(description="Added in 24.09.0."),
# TODO: Add this.
# scope=ScopeType(),
# container_registry_scope=ContainerRegistryScope(),
)

group = graphene.Field(
Expand Down Expand Up @@ -1126,16 +1135,22 @@ async def resolve_group_nodes(
root: Any,
info: graphene.ResolveInfo,
*,
filter: str | None = None,
order: str | None = None,
offset: int | None = None,
after: str | None = None,
first: int | None = None,
before: str | None = None,
last: int | None = None,
scope: Optional[ScopeType] = None,
container_registry_scope: Optional[ContainerRegistryScope] = None,
permission: ProjectPermission = ProjectPermission.READ_ATTRIBUTE,
filter: Optional[str] = None,
order: Optional[str] = None,
offset: Optional[int] = None,
after: Optional[str] = None,
first: Optional[int] = None,
before: Optional[str] = None,
last: Optional[int] = None,
) -> ConnectionResolverResult[GroupNode]:
return await GroupNode.get_connection(
info,
scope,
container_registry_scope,
permission,
filter,
order,
offset,
Expand Down
49 changes: 36 additions & 13 deletions src/ai/backend/manager/models/gql_models/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from collections.abc import Mapping
from typing import (
TYPE_CHECKING,
Optional,
Self,
Sequence,
)
Expand All @@ -23,13 +24,17 @@
Connection,
ConnectionResolverResult,
)
from ..group import AssocGroupUserRow, GroupRow, ProjectType
from ..group import AssocGroupUserRow, GroupRow, ProjectType, get_permission_ctx
from ..minilang.ordering import OrderSpecItem, QueryOrderParser
from ..minilang.queryfilter import FieldSpecItem, QueryFilterParser
from ..rbac.context import ClientContext
from ..rbac.permission_defs import ProjectPermission
from .user import UserConnection, UserNode

if TYPE_CHECKING:
from ..container_registry import ContainerRegistryScope
from ..gql import GraphQueryContext
from ..rbac import ScopeType
from ..scaling_group import ScalingGroup

_queryfilter_fieldspec: Mapping[str, FieldSpecItem] = {
Expand Down Expand Up @@ -217,13 +222,16 @@ async def get_node(cls, info: graphene.ResolveInfo, id) -> Self:
async def get_connection(
cls,
info: graphene.ResolveInfo,
filter_expr: str | None = None,
order_expr: str | None = None,
offset: int | None = None,
after: str | None = None,
first: int | None = None,
before: str | None = None,
last: int | None = None,
scope: Optional[ScopeType] = None,
container_registry_scope: Optional[ContainerRegistryScope] = None,
permission: ProjectPermission = ProjectPermission.READ_ATTRIBUTE,
filter_expr: Optional[str] = None,
order_expr: Optional[str] = None,
offset: Optional[int] = None,
after: Optional[str] = None,
first: Optional[int] = None,
before: Optional[str] = None,
last: Optional[int] = None,
) -> ConnectionResolverResult[Self]:
graph_ctx: GraphQueryContext = info.context
_filter_arg = (
Expand Down Expand Up @@ -255,11 +263,26 @@ async def get_connection(
before=before,
last=last,
)
async with graph_ctx.db.begin_readonly_session() as db_session:
group_rows = (await db_session.scalars(query)).all()
result = [cls.from_row(graph_ctx, row) for row in group_rows]
total_cnt = await db_session.scalar(cnt_query)
return ConnectionResolverResult(result, cursor, pagination_order, page_size, total_cnt)
async with graph_ctx.db.connect() as db_conn:
user = graph_ctx.user
client_ctx = ClientContext(
graph_ctx.db, user["domain_name"], user["uuid"], user["role"]
)
permission_ctx = await get_permission_ctx(
db_conn, client_ctx, permission, scope, container_registry_scope
)
cond = permission_ctx.query_condition
if cond is None:
return ConnectionResolverResult([], cursor, pagination_order, page_size, 0)
query = query.where(cond)
cnt_query = cnt_query.where(cond)

async with graph_ctx.db.begin_readonly_session(db_conn) as db_session:
group_rows = (await db_session.scalars(query)).all()
total_cnt = await db_session.scalar(cnt_query)
result = [cls.from_row(graph_ctx, row) for row in group_rows]

return ConnectionResolverResult(result, cursor, pagination_order, page_size, total_cnt)


class GroupConnection(Connection):
Expand Down
Loading

0 comments on commit f5d681b

Please sign in to comment.