diff --git a/changes/3090.feature.md b/changes/3090.feature.md new file mode 100644 index 00000000000..e725ca5ff61 --- /dev/null +++ b/changes/3090.feature.md @@ -0,0 +1 @@ +Implement CRUD API for managing Harbor per-project Quota. diff --git a/docs/manager/graphql-reference/schema.graphql b/docs/manager/graphql-reference/schema.graphql index 94bdb3a2635..d67f8806675 100644 --- a/docs/manager/graphql-reference/schema.graphql +++ b/docs/manager/graphql-reference/schema.graphql @@ -710,6 +710,9 @@ type GroupNode implements Node { """Added in 24.03.7.""" container_registry: JSONString scaling_groups: [String] + + """Added in 25.2.0.""" + registry_quota: BigInt user_nodes(filter: String, order: String, offset: Int, before: String, after: String, first: Int, last: Int): UserConnection } @@ -1933,6 +1936,15 @@ type Mutations { """Added in 25.1.0.""" delete_endpoint_auto_scaling_rule_node(id: String!): DeleteEndpointAutoScalingRuleNode + """Added in 25.2.0.""" + create_container_registry_quota(quota: BigInt!, scope_id: ScopeField!): CreateContainerRegistryQuota + + """Added in 25.2.0.""" + update_container_registry_quota(quota: BigInt!, scope_id: ScopeField!): UpdateContainerRegistryQuota + + """Added in 25.2.0.""" + delete_container_registry_quota(scope_id: ScopeField!): DeleteContainerRegistryQuota + """Deprecated since 24.09.0. use `CreateContainerRegistryNode` instead""" create_container_registry(hostname: String!, props: CreateContainerRegistryInput!): CreateContainerRegistry @@ -2787,6 +2799,24 @@ type DeleteEndpointAutoScalingRuleNode { msg: String } +"""Added in 25.2.0.""" +type CreateContainerRegistryQuota { + ok: Boolean + msg: String +} + +"""Added in 25.2.0.""" +type UpdateContainerRegistryQuota { + ok: Boolean + msg: String +} + +"""Added in 25.2.0.""" +type DeleteContainerRegistryQuota { + ok: Boolean + msg: String +} + """Deprecated since 24.09.0. use `CreateContainerRegistryNode` instead""" type CreateContainerRegistry { container_registry: ContainerRegistry diff --git a/docs/manager/rest-reference/openapi.json b/docs/manager/rest-reference/openapi.json index 336f51d482d..45b435562dd 100644 --- a/docs/manager/rest-reference/openapi.json +++ b/docs/manager/rest-reference/openapi.json @@ -8462,6 +8462,142 @@ "description": "\n**Preconditions:**\n* Admin privilege required.\n* Manager status required: one of FROZEN, RUNNING\n" } }, + "/group/registry-quota": { + "post": { + "operationId": "group.create_registry_quota", + "tags": [ + "group" + ], + "responses": { + "200": { + "description": "Successful response" + } + }, + "security": [ + { + "TokenAuth": [] + } + ], + "requestBody": { + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "group_id": { + "type": "string" + }, + "quota": { + "type": "integer" + } + }, + "required": [ + "group_id", + "quota" + ] + }, + "examples": {} + } + } + }, + "parameters": [], + "description": "\n**Preconditions:**\n* Superadmin privilege required.\n* Manager status required: one of FROZEN, RUNNING\n" + }, + "get": { + "operationId": "group.read_registry_quota", + "tags": [ + "group" + ], + "responses": { + "200": { + "description": "Successful response" + } + }, + "security": [ + { + "TokenAuth": [] + } + ], + "parameters": [ + { + "name": "group_id", + "schema": { + "type": "string" + }, + "required": true, + "in": "query" + } + ], + "description": "\n**Preconditions:**\n* Superadmin privilege required.\n* Manager status required: one of FROZEN, RUNNING\n" + }, + "patch": { + "operationId": "group.update_registry_quota", + "tags": [ + "group" + ], + "responses": { + "200": { + "description": "Successful response" + } + }, + "security": [ + { + "TokenAuth": [] + } + ], + "requestBody": { + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "group_id": { + "type": "string" + }, + "quota": { + "type": "integer" + } + }, + "required": [ + "group_id", + "quota" + ] + }, + "examples": {} + } + } + }, + "parameters": [], + "description": "\n**Preconditions:**\n* Superadmin privilege required.\n* Manager status required: one of FROZEN, RUNNING\n" + }, + "delete": { + "operationId": "group.delete_registry_quota", + "tags": [ + "group" + ], + "responses": { + "200": { + "description": "Successful response" + } + }, + "security": [ + { + "TokenAuth": [] + } + ], + "parameters": [ + { + "name": "group_id", + "schema": { + "type": "string" + }, + "required": true, + "in": "query" + } + ], + "description": "\n**Preconditions:**\n* Superadmin privilege required.\n* Manager status required: one of FROZEN, RUNNING\n" + } + }, "/group-config/dotfiles": { "post": { "operationId": "group-config.create", diff --git a/src/ai/backend/client/func/group.py b/src/ai/backend/client/func/group.py index 434460b9cfe..ba48a959523 100644 --- a/src/ai/backend/client/func/group.py +++ b/src/ai/backend/client/func/group.py @@ -1,7 +1,9 @@ +import textwrap from typing import Any, Iterable, Optional, Sequence from ai.backend.client.output.fields import group_fields from ai.backend.client.output.types import FieldSpec +from ai.backend.common.utils import b64encode from ...cli.types import Undefined, undefined from ..session import api_session @@ -293,3 +295,101 @@ async def remove_users( } data = await api_session.get().Admin._query(query, variables) return data["modify_group"] + + @api_function + @classmethod + async def get_container_registry_quota(cls, group_id: str) -> int: + """ + Get Quota Limit for the group's container registry. + Currently only HarborV2 registry is supported. + + You need an admin privilege for this operation. + """ + query = textwrap.dedent( + """\ + query($id: String!) { + group_node(id: $id) { + registry_quota + } + } + """ + ) + + variables = {"id": b64encode(f"group_node:{group_id}")} + data = await api_session.get().Admin._query(query, variables) + return data["group_node"]["registry_quota"] + + @api_function + @classmethod + async def create_container_registry_quota(cls, group_id: str, quota: int) -> dict: + """ + Create Quota Limit for the group's container registry. + Currently only HarborV2 registry is supported. + + You need an admin privilege for this operation. + """ + query = textwrap.dedent( + """\ + mutation($scope_id: ScopeField!, $quota: Int!) { + create_container_registry_quota( + scope_id: $scope_id, quota: $quota) { + ok msg + } + } + """ + ) + + scope_id = f"project:{group_id}" + variables = {"scope_id": scope_id, "quota": quota} + data = await api_session.get().Admin._query(query, variables) + return data["create_container_registry_quota"] + + @api_function + @classmethod + async def update_container_registry_quota(cls, group_id: str, quota: int) -> dict: + """ + Update Quota Limit for the group's container registry. + Currently only HarborV2 registry is supported. + + You need an admin privilege for this operation. + """ + query = textwrap.dedent( + """\ + mutation($scope_id: ScopeField!, $quota: Int!) { + update_container_registry_quota( + scope_id: $scope_id, quota: $quota) { + ok msg + } + } + """ + ) + + scope_id = f"project:{group_id}" + variables = {"scope_id": scope_id, "quota": quota} + data = await api_session.get().Admin._query(query, variables) + return data["update_container_registry_quota"] + + @api_function + @classmethod + async def delete_container_registry_quota(cls, group_id: str) -> dict: + """ + Delete Quota Limit for the group's container registry. + Currently only HarborV2 registry is supported. + + You need an admin privilege for this operation. + """ + query = textwrap.dedent( + """\ + mutation($scope_id: ScopeField!) { + delete_container_registry_quota( + scope_id: $scope_id) { + ok msg + } + } + """ + ) + + scope_id = f"project:{group_id}" + variables = {"scope_id": scope_id} + data = await api_session.get().Admin._query(query, variables) + return data["delete_container_registry_quota"] diff --git a/src/ai/backend/common/utils.py b/src/ai/backend/common/utils.py index 4eed2a03668..f8d31b7366f 100644 --- a/src/ai/backend/common/utils.py +++ b/src/ai/backend/common/utils.py @@ -425,3 +425,12 @@ def join_non_empty(*args: Optional[str], sep: str) -> str: """ filtered_args = [arg for arg in args if arg] return sep.join(filtered_args) + + +def b64encode(s: str) -> str: + """ + base64 encoding method of graphql_relay. + Use it in components where the graphql_relay package is unavailable. + """ + b: bytes = s.encode("utf-8") if isinstance(s, str) else s + return base64.b64encode(b).decode("ascii") diff --git a/src/ai/backend/manager/api/admin.py b/src/ai/backend/manager/api/admin.py index de97f51e0cb..5e59c349e80 100644 --- a/src/ai/backend/manager/api/admin.py +++ b/src/ai/backend/manager/api/admin.py @@ -85,6 +85,7 @@ async def _handle_gql_common(request: web.Request, params: Any) -> ExecutionResu manager_status=manager_status, known_slot_types=known_slot_types, background_task_manager=root_ctx.background_task_manager, + services_ctx=root_ctx.services_ctx, storage_manager=root_ctx.storage_manager, registry=root_ctx.registry, idle_checker_host=root_ctx.idle_checker_host, diff --git a/src/ai/backend/manager/api/container_registry.py b/src/ai/backend/manager/api/container_registry.py index 488fb69c010..9fe10cbb422 100644 --- a/src/ai/backend/manager/api/container_registry.py +++ b/src/ai/backend/manager/api/container_registry.py @@ -39,7 +39,7 @@ async def patch_container_registry( request: web.Request, params: PatchContainerRegistryRequestModel ) -> PatchContainerRegistryResponseModel: registry_id = uuid.UUID(request.match_info["registry_id"]) - log.info("PATCH_CONTAINER_REGISTRY (cr:{})", registry_id) + log.info("PATCH_CONTAINER_REGISTRY (registry:{})", registry_id) root_ctx: RootContext = request.app["_root.context"] registry_row_updates = params.model_dump(exclude={"allowed_groups"}, exclude_none=True) diff --git a/src/ai/backend/manager/api/context.py b/src/ai/backend/manager/api/context.py index 5d7cf4bb5dd..dcd2654076b 100644 --- a/src/ai/backend/manager/api/context.py +++ b/src/ai/backend/manager/api/context.py @@ -6,6 +6,7 @@ from ai.backend.common.metrics.metric import CommonMetricRegistry from ai.backend.manager.plugin.network import NetworkPluginContext +from ai.backend.manager.service.base import ServicesContext if TYPE_CHECKING: from ai.backend.common.bgtask import BackgroundTaskManager @@ -50,6 +51,7 @@ class RootContext(BaseContext): storage_manager: StorageSessionManager hook_plugin_ctx: HookPluginContext network_plugin_ctx: NetworkPluginContext + services_ctx: ServicesContext registry: AgentRegistry agent_cache: AgentRPCCache diff --git a/src/ai/backend/manager/api/group.py b/src/ai/backend/manager/api/group.py new file mode 100644 index 00000000000..2c68457933f --- /dev/null +++ b/src/ai/backend/manager/api/group.py @@ -0,0 +1,109 @@ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any, Iterable, Tuple + +import aiohttp_cors +import trafaret as t +from aiohttp import web + +from ai.backend.common import validators as tx +from ai.backend.logging import BraceStyleAdapter +from ai.backend.manager.models.rbac import ProjectScope + +if TYPE_CHECKING: + from .context import RootContext + +from .auth import superadmin_required +from .manager import READ_ALLOWED, server_status_required +from .types import CORSOptions, WebMiddleware +from .utils import check_api_params + +log = BraceStyleAdapter(logging.getLogger(__spec__.name)) + + +@server_status_required(READ_ALLOWED) +@superadmin_required +@check_api_params( + t.Dict({ + tx.AliasedKey(["group_id", "group"]): t.String, + tx.AliasedKey(["quota"]): t.Int, + }) +) +async def update_registry_quota(request: web.Request, params: Any) -> web.Response: + log.info("UPDATE_REGISTRY_QUOTA (group:{})", params["group_id"]) + root_ctx: RootContext = request.app["_root.context"] + group_id = params["group_id"] + scope_id = ProjectScope(project_id=group_id, domain_name=None) + quota = int(params["quota"]) + + await root_ctx.services_ctx.per_project_container_registries_quota.update_quota(scope_id, quota) + return web.Response(status=204) + + +@server_status_required(READ_ALLOWED) +@superadmin_required +@check_api_params( + t.Dict({ + tx.AliasedKey(["group_id", "group"]): t.String, + }) +) +async def delete_registry_quota(request: web.Request, params: Any) -> web.Response: + log.info("DELETE_REGISTRY_QUOTA (group:{})", params["group_id"]) + root_ctx: RootContext = request.app["_root.context"] + group_id = params["group_id"] + scope_id = ProjectScope(project_id=group_id, domain_name=None) + + await root_ctx.services_ctx.per_project_container_registries_quota.delete_quota(scope_id) + return web.Response(status=204) + + +@server_status_required(READ_ALLOWED) +@superadmin_required +@check_api_params( + t.Dict({ + tx.AliasedKey(["group_id", "group"]): t.String, + tx.AliasedKey(["quota"]): t.Int, + }) +) +async def create_registry_quota(request: web.Request, params: Any) -> web.Response: + log.info("CREATE_REGISTRY_QUOTA (group:{})", params["group_id"]) + root_ctx: RootContext = request.app["_root.context"] + group_id = params["group_id"] + scope_id = ProjectScope(project_id=group_id, domain_name=None) + quota = int(params["quota"]) + + await root_ctx.services_ctx.per_project_container_registries_quota.create_quota(scope_id, quota) + return web.Response(status=204) + + +@server_status_required(READ_ALLOWED) +@superadmin_required +@check_api_params( + t.Dict({ + tx.AliasedKey(["group_id", "group"]): t.String, + }) +) +async def read_registry_quota(request: web.Request, params: Any) -> web.Response: + log.info("READ_REGISTRY_QUOTA (group:{})", params["group_id"]) + root_ctx: RootContext = request.app["_root.context"] + group_id = params["group_id"] + scope_id = ProjectScope(project_id=group_id, domain_name=None) + + quota = await root_ctx.services_ctx.per_project_container_registries_quota.read_quota(scope_id) + + return web.json_response({"result": quota}) + + +def create_app( + default_cors_options: CORSOptions, +) -> Tuple[web.Application, Iterable[WebMiddleware]]: + app = web.Application() + app["api_versions"] = (1, 2, 3, 4, 5) + app["prefix"] = "group" + cors = aiohttp_cors.setup(app, defaults=default_cors_options) + cors.add(app.router.add_route("POST", "/registry-quota", create_registry_quota)) + cors.add(app.router.add_route("GET", "/registry-quota", read_registry_quota)) + cors.add(app.router.add_route("PATCH", "/registry-quota", update_registry_quota)) + cors.add(app.router.add_route("DELETE", "/registry-quota", delete_registry_quota)) + return app, [] diff --git a/src/ai/backend/manager/client/BUILD b/src/ai/backend/manager/client/BUILD new file mode 100644 index 00000000000..73574424040 --- /dev/null +++ b/src/ai/backend/manager/client/BUILD @@ -0,0 +1 @@ +python_sources(name="src") diff --git a/src/ai/backend/manager/client/__init__.py b/src/ai/backend/manager/client/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/ai/backend/manager/client/container_registry/BUILD b/src/ai/backend/manager/client/container_registry/BUILD new file mode 100644 index 00000000000..73574424040 --- /dev/null +++ b/src/ai/backend/manager/client/container_registry/BUILD @@ -0,0 +1 @@ +python_sources(name="src") diff --git a/src/ai/backend/manager/client/container_registry/__init__.py b/src/ai/backend/manager/client/container_registry/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/ai/backend/manager/client/container_registry/harbor.py b/src/ai/backend/manager/client/container_registry/harbor.py new file mode 100644 index 00000000000..5fc520a9588 --- /dev/null +++ b/src/ai/backend/manager/client/container_registry/harbor.py @@ -0,0 +1,174 @@ +from __future__ import annotations + +import abc +import logging +from typing import TYPE_CHECKING, Any, override + +import aiohttp +import yarl + +from ai.backend.logging import BraceStyleAdapter +from ai.backend.manager.api.exceptions import GenericBadRequest, InternalServerError, ObjectNotFound + +if TYPE_CHECKING: + from ai.backend.manager.service.container_registry.harbor import ( + HarborAuthArgs, + HarborProjectInfo, + HarborProjectQuotaInfo, + ) + +log = BraceStyleAdapter(logging.getLogger(__spec__.name)) + + +def _get_harbor_auth_args(auth_args: HarborAuthArgs) -> dict[str, Any]: + return {"auth": aiohttp.BasicAuth(auth_args["username"], auth_args["password"])} + + +class AbstractPerProjectRegistryQuotaClient(abc.ABC): + async def create_quota( + self, project_info: HarborProjectInfo, quota: int, auth_args: HarborAuthArgs + ) -> None: + raise NotImplementedError + + async def update_quota( + self, project_info: HarborProjectInfo, quota: int, auth_args: HarborAuthArgs + ) -> None: + raise NotImplementedError + + async def delete_quota( + self, project_info: HarborProjectInfo, auth_args: HarborAuthArgs + ) -> None: + raise NotImplementedError + + async def read_quota(self, project_info: HarborProjectInfo) -> int: + raise NotImplementedError + + +class PerProjectHarborQuotaClient(AbstractPerProjectRegistryQuotaClient): + async def _get_harbor_project_id( + self, + sess: aiohttp.ClientSession, + project_info: HarborProjectInfo, + rqst_args: dict[str, Any], + ) -> str: + get_project_id_api = ( + yarl.URL(project_info.url) / "api" / "v2.0" / "projects" / project_info.project + ) + + async with sess.get(get_project_id_api, allow_redirects=False, **rqst_args) as resp: + if resp.status != 200: + raise InternalServerError(f"Failed to get harbor project_id! response: {resp}") + + res = await resp.json() + harbor_project_id = res["project_id"] + return str(harbor_project_id) + + async def _get_quota_info( + self, + sess: aiohttp.ClientSession, + project_info: HarborProjectInfo, + rqst_args: dict[str, Any], + ) -> HarborProjectQuotaInfo: + from ...service.container_registry.harbor import HarborProjectQuotaInfo + + harbor_project_id = await self._get_harbor_project_id(sess, project_info, rqst_args) + get_quota_id_api = (yarl.URL(project_info.url) / "api" / "v2.0" / "quotas").with_query({ + "reference": "project", + "reference_id": harbor_project_id, + }) + + async with sess.get(get_quota_id_api, allow_redirects=False, **rqst_args) as resp: + if resp.status != 200: + raise InternalServerError(f"Failed to get quota info! response: {resp}") + + res = await resp.json() + if not res: + raise ObjectNotFound(object_name="quota entity") + if len(res) > 1: + raise InternalServerError( + f"Multiple quota entities found. (project_id: {harbor_project_id})" + ) + + previous_quota = res[0]["hard"]["storage"] + quota_id = res[0]["id"] + return HarborProjectQuotaInfo(previous_quota=previous_quota, quota_id=quota_id) + + @override + async def read_quota(self, project_info: HarborProjectInfo) -> int: + connector = aiohttp.TCPConnector(ssl=project_info.ssl_verify) + async with aiohttp.ClientSession(connector=connector) as sess: + rqst_args: dict[str, Any] = {} + quota_info = await self._get_quota_info(sess, project_info, rqst_args) + previous_quota = quota_info["previous_quota"] + if previous_quota == -1: + raise ObjectNotFound(object_name="quota entity") + return previous_quota + + @override + async def create_quota( + self, project_info: HarborProjectInfo, quota: int, auth_args: HarborAuthArgs + ) -> None: + connector = aiohttp.TCPConnector(ssl=project_info.ssl_verify) + async with aiohttp.ClientSession(connector=connector) as sess: + rqst_args = _get_harbor_auth_args(auth_args) + quota_info = await self._get_quota_info(sess, project_info, rqst_args) + previous_quota, quota_id = quota_info["previous_quota"], quota_info["quota_id"] + + if previous_quota > 0: + raise GenericBadRequest("Quota limit already exists!") + + put_quota_api = yarl.URL(project_info.url) / "api" / "v2.0" / "quotas" / str(quota_id) + payload = {"hard": {"storage": quota}} + + async with sess.put( + put_quota_api, json=payload, allow_redirects=False, **rqst_args + ) as resp: + if resp.status != 200: + log.error(f"Failed to create quota! response: {resp}") + raise InternalServerError(f"Failed to create quota! response: {resp}") + + @override + async def update_quota( + self, project_info: HarborProjectInfo, quota: int, auth_args: HarborAuthArgs + ) -> None: + connector = aiohttp.TCPConnector(ssl=project_info.ssl_verify) + async with aiohttp.ClientSession(connector=connector) as sess: + rqst_args = _get_harbor_auth_args(auth_args) + quota_info = await self._get_quota_info(sess, project_info, rqst_args) + previous_quota, quota_id = quota_info["previous_quota"], quota_info["quota_id"] + + if previous_quota == -1: + raise ObjectNotFound(object_name="quota entity") + + put_quota_api = yarl.URL(project_info.url) / "api" / "v2.0" / "quotas" / str(quota_id) + payload = {"hard": {"storage": quota}} + + async with sess.put( + put_quota_api, json=payload, allow_redirects=False, **rqst_args + ) as resp: + if resp.status != 200: + log.error(f"Failed to update quota! response: {resp}") + raise InternalServerError(f"Failed to update quota! response: {resp}") + + @override + async def delete_quota( + self, project_info: HarborProjectInfo, auth_args: HarborAuthArgs + ) -> None: + connector = aiohttp.TCPConnector(ssl=project_info.ssl_verify) + async with aiohttp.ClientSession(connector=connector) as sess: + rqst_args = _get_harbor_auth_args(auth_args) + quota_info = await self._get_quota_info(sess, project_info, rqst_args) + previous_quota, quota_id = quota_info["previous_quota"], quota_info["quota_id"] + + if previous_quota == -1: + raise ObjectNotFound(object_name="quota entity") + + put_quota_api = yarl.URL(project_info.url) / "api" / "v2.0" / "quotas" / str(quota_id) + payload = {"hard": {"storage": -1}} + + async with sess.put( + put_quota_api, json=payload, allow_redirects=False, **rqst_args + ) as resp: + if resp.status != 200: + log.error(f"Failed to delete quota! response: {resp}") + raise InternalServerError(f"Failed to delete quota! response: {resp}") diff --git a/src/ai/backend/manager/container_registry/__init__.py b/src/ai/backend/manager/container_registry/__init__.py index 89ba53dde0f..7ffe07c6283 100644 --- a/src/ai/backend/manager/container_registry/__init__.py +++ b/src/ai/backend/manager/container_registry/__init__.py @@ -5,9 +5,10 @@ import yarl from ai.backend.common.container_registry import ContainerRegistryType -from ai.backend.manager.models.container_registry import ContainerRegistryRow if TYPE_CHECKING: + from ai.backend.manager.models.container_registry import ContainerRegistryRow + from .base import BaseContainerRegistry diff --git a/src/ai/backend/manager/models/gql.py b/src/ai/backend/manager/models/gql.py index b4346104863..033967202bc 100644 --- a/src/ai/backend/manager/models/gql.py +++ b/src/ai/backend/manager/models/gql.py @@ -12,6 +12,7 @@ from graphql.type import GraphQLField from ai.backend.manager.plugin.network import NetworkPluginContext +from ai.backend.manager.service.base import ServicesContext set_input_object_type_default_value(Undefined) @@ -75,6 +76,11 @@ AgentSummaryList, ModifyAgent, ) +from .gql_models.container_registry import ( + CreateContainerRegistryQuota, + DeleteContainerRegistryQuota, + UpdateContainerRegistryQuota, +) from .gql_models.domain import ( CreateDomainNode, DomainConnection, @@ -227,6 +233,7 @@ class GraphQueryContext: access_key: str db: ExtendedAsyncSAEngine network_plugin_ctx: NetworkPluginContext + services_ctx: ServicesContext redis_stat: RedisConnectionInfo redis_live: RedisConnectionInfo redis_image: RedisConnectionInfo @@ -360,6 +367,15 @@ class Mutations(graphene.ObjectType): delete_endpoint_auto_scaling_rule_node = DeleteEndpointAutoScalingRuleNode.Field( description="Added in 25.1.0." ) + create_container_registry_quota = CreateContainerRegistryQuota.Field( + description="Added in 25.2.0." + ) + update_container_registry_quota = UpdateContainerRegistryQuota.Field( + description="Added in 25.2.0." + ) + delete_container_registry_quota = DeleteContainerRegistryQuota.Field( + description="Added in 25.2.0." + ) # Legacy mutations create_container_registry = CreateContainerRegistry.Field() diff --git a/src/ai/backend/manager/models/gql_models/container_registry.py b/src/ai/backend/manager/models/gql_models/container_registry.py new file mode 100644 index 00000000000..c5aa718fc0c --- /dev/null +++ b/src/ai/backend/manager/models/gql_models/container_registry.py @@ -0,0 +1,136 @@ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Self + +import graphene + +from ai.backend.logging import BraceStyleAdapter + +from ..base import BigInt +from ..rbac import ProjectScope, ScopeType +from ..user import UserRole +from .fields import ScopeField + +log = BraceStyleAdapter(logging.getLogger(__spec__.name)) # type: ignore + +if TYPE_CHECKING: + from ai.backend.manager.models.gql import GraphQueryContext + + +class CreateContainerRegistryQuota(graphene.Mutation): + """Added in 25.2.0.""" + + allowed_roles = ( + UserRole.SUPERADMIN, + UserRole.ADMIN, + ) + + class Arguments: + scope_id = ScopeField(required=True) + quota = BigInt(required=True) + + ok = graphene.Boolean() + msg = graphene.String() + + @classmethod + async def mutate( + cls, + root, + info: graphene.ResolveInfo, + scope_id: ScopeType, + quota: int | float, + ) -> Self: + graph_ctx: GraphQueryContext = info.context + try: + match scope_id: + case ProjectScope(_): + await ( + graph_ctx.services_ctx.per_project_container_registries_quota.create_quota( + scope_id, int(quota) + ) + ) + case _: + raise NotImplementedError("Only project scope is supported for now.") + + return cls(ok=True, msg="success") + except Exception as e: + return cls(ok=False, msg=str(e)) + + +class UpdateContainerRegistryQuota(graphene.Mutation): + """Added in 25.2.0.""" + + allowed_roles = ( + UserRole.SUPERADMIN, + UserRole.ADMIN, + ) + + class Arguments: + scope_id = ScopeField(required=True) + quota = BigInt(required=True) + + ok = graphene.Boolean() + msg = graphene.String() + + @classmethod + async def mutate( + cls, + root, + info: graphene.ResolveInfo, + scope_id: ScopeType, + quota: int | float, + ) -> Self: + graph_ctx: GraphQueryContext = info.context + try: + match scope_id: + case ProjectScope(_): + await ( + graph_ctx.services_ctx.per_project_container_registries_quota.update_quota( + scope_id, int(quota) + ) + ) + case _: + raise NotImplementedError("Only project scope is supported for now.") + + return cls(ok=True, msg="success") + except Exception as e: + return cls(ok=False, msg=str(e)) + + +class DeleteContainerRegistryQuota(graphene.Mutation): + """Added in 25.2.0.""" + + allowed_roles = ( + UserRole.SUPERADMIN, + UserRole.ADMIN, + ) + + class Arguments: + scope_id = ScopeField(required=True) + + ok = graphene.Boolean() + msg = graphene.String() + + @classmethod + async def mutate( + cls, + root, + info: graphene.ResolveInfo, + scope_id: ScopeType, + ) -> Self: + graph_ctx: GraphQueryContext = info.context + try: + match scope_id: + case ProjectScope(_): + await ( + graph_ctx.services_ctx.per_project_container_registries_quota.delete_quota( + scope_id + ) + ) + case _: + raise NotImplementedError("Only project scope is supported for now.") + + return cls(ok=True, msg="success") + except Exception as e: + return cls(ok=False, msg=str(e)) diff --git a/src/ai/backend/manager/models/gql_models/group.py b/src/ai/backend/manager/models/gql_models/group.py index f5bb2f6aede..780b85968cd 100644 --- a/src/ai/backend/manager/models/gql_models/group.py +++ b/src/ai/backend/manager/models/gql_models/group.py @@ -13,7 +13,10 @@ from dateutil.parser import parse as dtparse from graphene.types.datetime import DateTime as GQLDateTime +from ai.backend.manager.models.rbac import ProjectScope + from ..base import ( + BigInt, FilterExprArg, OrderExprArg, PaginatedConnectionField, @@ -117,6 +120,8 @@ class Meta: lambda: graphene.String, ) + registry_quota = BigInt(description="Added in 25.2.0.") + user_nodes = PaginatedConnectionField( UserConnection, ) @@ -209,6 +214,13 @@ async def resolve_user_nodes( total_cnt = await db_session.scalar(cnt_query) return ConnectionResolverResult(result, cursor, pagination_order, page_size, total_cnt) + async def resolve_registry_quota(self, info: graphene.ResolveInfo) -> int: + graph_ctx: GraphQueryContext = info.context + scope_id = ProjectScope(project_id=self.id, domain_name=None) + return await graph_ctx.services_ctx.per_project_container_registries_quota.read_quota( + scope_id + ) + @classmethod async def get_node(cls, info: graphene.ResolveInfo, id) -> Self: graph_ctx: GraphQueryContext = info.context diff --git a/src/ai/backend/manager/server.py b/src/ai/backend/manager/server.py index 34ede4a5da6..1a39fead1d1 100644 --- a/src/ai/backend/manager/server.py +++ b/src/ai/backend/manager/server.py @@ -63,6 +63,11 @@ from ai.backend.common.utils import env_info from ai.backend.logging import BraceStyleAdapter, Logger, LogLevel from ai.backend.manager.plugin.network import NetworkPluginContext +from ai.backend.manager.service.base import ServicesContext +from ai.backend.manager.service.container_registry.base import PerProjectRegistryQuotaRepository +from ai.backend.manager.service.container_registry.harbor import ( + PerProjectContainerRegistryQuotaService, +) from . import __version__ from .agent_cache import AgentRPCCache @@ -195,6 +200,7 @@ ".image", ".userconfig", ".domainconfig", + ".group", ".groupconfig", ".logs", ] @@ -692,6 +698,20 @@ async def _force_terminate_hanging_sessions( await task +@actxmgr +async def services_ctx(root_ctx: RootContext) -> AsyncIterator[None]: + db = root_ctx.db + + per_project_container_registries_quota = PerProjectContainerRegistryQuotaService( + repository=PerProjectRegistryQuotaRepository(db) + ) + + root_ctx.services_ctx = ServicesContext( + per_project_container_registries_quota, + ) + yield None + + class background_task_ctx: def __init__(self, root_ctx: RootContext) -> None: self.root_ctx = root_ctx @@ -859,6 +879,7 @@ def build_root_app( manager_status_ctx, redis_ctx, database_ctx, + services_ctx, distributed_lock_ctx, event_dispatcher_ctx, idle_checker_ctx, diff --git a/src/ai/backend/manager/service/BUILD b/src/ai/backend/manager/service/BUILD new file mode 100644 index 00000000000..73574424040 --- /dev/null +++ b/src/ai/backend/manager/service/BUILD @@ -0,0 +1 @@ +python_sources(name="src") diff --git a/src/ai/backend/manager/service/__init__.py b/src/ai/backend/manager/service/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/ai/backend/manager/service/base.py b/src/ai/backend/manager/service/base.py new file mode 100644 index 00000000000..bd4a22aa21a --- /dev/null +++ b/src/ai/backend/manager/service/base.py @@ -0,0 +1,18 @@ +from .container_registry.harbor import ( + PerProjectContainerRegistryQuota, +) + + +class ServicesContext: + """ + In the API layer, requests are processed through the ServicesContext and + its subordinate layers, including the DB, Client, and Repository layers. + Each layer separates the responsibilities specific to its respective level. + """ + + per_project_container_registries_quota: PerProjectContainerRegistryQuota + + def __init__( + self, per_project_container_registries_quota: PerProjectContainerRegistryQuota + ) -> None: + self.per_project_container_registries_quota = per_project_container_registries_quota diff --git a/src/ai/backend/manager/service/container_registry/BUILD b/src/ai/backend/manager/service/container_registry/BUILD new file mode 100644 index 00000000000..73574424040 --- /dev/null +++ b/src/ai/backend/manager/service/container_registry/BUILD @@ -0,0 +1 @@ +python_sources(name="src") diff --git a/src/ai/backend/manager/service/container_registry/__init__.py b/src/ai/backend/manager/service/container_registry/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/ai/backend/manager/service/container_registry/base.py b/src/ai/backend/manager/service/container_registry/base.py new file mode 100644 index 00000000000..f9567e2cdf8 --- /dev/null +++ b/src/ai/backend/manager/service/container_registry/base.py @@ -0,0 +1,107 @@ +from __future__ import annotations + +import abc +import logging +import uuid +from dataclasses import dataclass +from typing import Any, override + +import sqlalchemy as sa +from sqlalchemy.orm import load_only + +from ai.backend.common.container_registry import ContainerRegistryType +from ai.backend.logging import BraceStyleAdapter +from ai.backend.manager.api.exceptions import ( + ContainerRegistryNotFound, +) +from ai.backend.manager.models.container_registry import ContainerRegistryRow +from ai.backend.manager.models.group import GroupRow +from ai.backend.manager.models.rbac import ProjectScope +from ai.backend.manager.models.utils import ExtendedAsyncSAEngine + +log = BraceStyleAdapter(logging.getLogger(__spec__.name)) + + +@dataclass +class ContainerRegistryRowInfo: + id: uuid.UUID + url: str + registry_name: str + type: ContainerRegistryType + project: str + username: str + password: str + ssl_verify: bool + is_global: bool + extra: dict[str, Any] + + +class AbstractPerProjectRegistryQuotaRepository(abc.ABC): + async def fetch_container_registry_row( + self, scope_id: ProjectScope + ) -> ContainerRegistryRowInfo: + raise NotImplementedError + + +class PerProjectRegistryQuotaRepository(AbstractPerProjectRegistryQuotaRepository): + def __init__(self, db: ExtendedAsyncSAEngine): + self.db = db + + @classmethod + def _is_valid_group_row(cls, group_row: GroupRow) -> bool: + return ( + group_row + and group_row.container_registry + and "registry" in group_row.container_registry + and "project" in group_row.container_registry + ) + + @override + async def fetch_container_registry_row( + self, scope_id: ProjectScope + ) -> ContainerRegistryRowInfo: + async with self.db.begin_readonly_session() as db_sess: + project_id = scope_id.project_id + group_query = ( + sa.select(GroupRow) + .where(GroupRow.id == project_id) + .options(load_only(GroupRow.container_registry)) + ) + result = await db_sess.execute(group_query) + group_row = result.scalar_one_or_none() + + if not PerProjectRegistryQuotaRepository._is_valid_group_row(group_row): + raise ContainerRegistryNotFound( + f"Container registry info does not exist or is invalid in the group. (group: {project_id})" + ) + + registry_name, project = ( + group_row.container_registry["registry"], + group_row.container_registry["project"], + ) + + registry_query = sa.select(ContainerRegistryRow).where( + (ContainerRegistryRow.registry_name == registry_name) + & (ContainerRegistryRow.project == project) + ) + + result = await db_sess.execute(registry_query) + registry = result.scalars().one_or_none() + + if not registry: + raise ContainerRegistryNotFound( + f"Container registry row not found. (registry: {registry_name}, group: {project})" + ) + + return ContainerRegistryRowInfo( + id=registry.id, + url=registry.url, + registry_name=registry.registry_name, + type=registry.type, + project=registry.project, + username=registry.username, + password=registry.password, + ssl_verify=registry.ssl_verify, + is_global=registry.is_global, + extra=registry.extra, + ) diff --git a/src/ai/backend/manager/service/container_registry/harbor.py b/src/ai/backend/manager/service/container_registry/harbor.py new file mode 100644 index 00000000000..f74ec4d9c49 --- /dev/null +++ b/src/ai/backend/manager/service/container_registry/harbor.py @@ -0,0 +1,110 @@ +from __future__ import annotations + +import abc +import logging +from dataclasses import dataclass +from typing import TypedDict, override + +from ai.backend.common.container_registry import ContainerRegistryType +from ai.backend.logging import BraceStyleAdapter +from ai.backend.manager.api.exceptions import GenericBadRequest +from ai.backend.manager.client.container_registry.harbor import ( + AbstractPerProjectRegistryQuotaClient, + PerProjectHarborQuotaClient, +) +from ai.backend.manager.models.rbac import ProjectScope +from ai.backend.manager.service.container_registry.base import ( + ContainerRegistryRowInfo, + PerProjectRegistryQuotaRepository, +) + +log = BraceStyleAdapter(logging.getLogger(__spec__.name)) + + +@dataclass +class HarborProjectInfo: + url: str + project: str + ssl_verify: bool + + +class HarborAuthArgs(TypedDict): + username: str + password: str + + +class HarborProjectQuotaInfo(TypedDict): + previous_quota: int + quota_id: int + + +class PerProjectContainerRegistryQuota(abc.ABC): + async def create_quota(self, scope_id: ProjectScope, quota: int) -> None: + raise NotImplementedError + + async def update_quota(self, scope_id: ProjectScope, quota: int) -> None: + raise NotImplementedError + + async def delete_quota(self, scope_id: ProjectScope) -> None: + raise NotImplementedError + + async def read_quota(self, scope_id: ProjectScope) -> int: + raise NotImplementedError + + +class PerProjectContainerRegistryQuotaService(PerProjectContainerRegistryQuota): + repository: PerProjectRegistryQuotaRepository + + def __init__(self, repository: PerProjectRegistryQuotaRepository): + self.repository = repository + + def _registry_row_to_harbor_project_info( + self, registry_info: ContainerRegistryRowInfo + ) -> HarborProjectInfo: + return HarborProjectInfo( + url=registry_info.url, + project=registry_info.project, + ssl_verify=registry_info.ssl_verify, + ) + + def _make_client(self, type_: ContainerRegistryType) -> AbstractPerProjectRegistryQuotaClient: + match type_: + case ContainerRegistryType.HARBOR2: + return PerProjectHarborQuotaClient() + case _: + raise GenericBadRequest( + f"{type_} does not support registry quota per project management." + ) + + @override + async def create_quota(self, scope_id: ProjectScope, quota: int) -> None: + registry_info = await self.repository.fetch_container_registry_row(scope_id) + project_info = self._registry_row_to_harbor_project_info(registry_info) + credential = HarborAuthArgs( + username=registry_info.username, password=registry_info.password + ) + await self._make_client(registry_info.type).create_quota(project_info, quota, credential) + + @override + async def update_quota(self, scope_id: ProjectScope, quota: int) -> None: + registry_info = await self.repository.fetch_container_registry_row(scope_id) + project_info = self._registry_row_to_harbor_project_info(registry_info) + credential = HarborAuthArgs( + username=registry_info.username, password=registry_info.password + ) + await self._make_client(registry_info.type).update_quota(project_info, quota, credential) + + @override + async def delete_quota(self, scope_id: ProjectScope) -> None: + registry_info = await self.repository.fetch_container_registry_row(scope_id) + project_info = self._registry_row_to_harbor_project_info(registry_info) + credential = HarborAuthArgs( + username=registry_info.username, password=registry_info.password + ) + await self._make_client(registry_info.type).delete_quota(project_info, credential) + + @override + async def read_quota(self, scope_id: ProjectScope) -> int: + registry_info = await self.repository.fetch_container_registry_row(scope_id) + project_info = self._registry_row_to_harbor_project_info(registry_info) + return await self._make_client(registry_info.type).read_quota(project_info) diff --git a/src/ai/backend/testutils/extra_fixtures.py b/src/ai/backend/testutils/extra_fixtures.py new file mode 100644 index 00000000000..a7e18108779 --- /dev/null +++ b/src/ai/backend/testutils/extra_fixtures.py @@ -0,0 +1,34 @@ +FIXTURES_FOR_HARBOR_CRUD_TEST = [ + { + "container_registries": [ + { + "id": "00000000-0000-0000-0000-000000000000", + "type": "harbor2", + "url": "http://mock_registry", + "registry_name": "mock_registry", + "project": "mock_project", + "username": "mock_user", + "password": "mock_password", + "ssl_verify": False, + "is_global": True, + } + ], + "groups": [ + { + "id": "00000000-0000-0000-0000-000000000000", + "name": "mock_group", + "description": "", + "is_active": True, + "domain_name": "default", + "resource_policy": "default", + "total_resource_slots": {}, + "allowed_vfolder_hosts": {}, + "container_registry": { + "registry": "mock_registry", + "project": "mock_project", + }, + "type": "general", + } + ], + }, +] diff --git a/tests/manager/api/test_group.py b/tests/manager/api/test_group.py new file mode 100644 index 00000000000..2a18c69a99a --- /dev/null +++ b/tests/manager/api/test_group.py @@ -0,0 +1,331 @@ +import json +from urllib.parse import urlencode + +import pytest +from aioresponses import aioresponses + +from ai.backend.manager.server import ( + database_ctx, + hook_plugin_ctx, + monitoring_ctx, + redis_ctx, + shared_config_ctx, +) +from ai.backend.testutils.extra_fixtures import FIXTURES_FOR_HARBOR_CRUD_TEST + + +@pytest.mark.asyncio +@pytest.mark.parametrize("extra_fixtures", FIXTURES_FOR_HARBOR_CRUD_TEST) +@pytest.mark.parametrize( + "test_case", + [ + { + "mock_harbor_responses": { + "get_project_id": {"project_id": "1"}, + "get_quotas": [ + { + "id": 1, + "hard": {"storage": -1}, + } + ], + }, + "expected_code": 200, + }, + { + "mock_harbor_responses": { + "get_project_id": {"project_id": "1"}, + "get_quotas": [ + { + "id": 1, + "hard": {"storage": 100}, + } + ], + }, + "expected_code": 400, + }, + ], + ids=["Normal case", "Project Quota already exist"], +) +async def test_harbor_create_project_quota( + test_case, + etcd_fixture, + database_fixture, + create_app_and_client, + get_headers, +): + app, client = await create_app_and_client( + [ + shared_config_ctx, + database_ctx, + monitoring_ctx, + hook_plugin_ctx, + redis_ctx, + ], + [".group", ".auth"], + ) + + mock_harbor_responses = test_case["mock_harbor_responses"] + + url = "/group/registry-quota" + params = {"group_id": "00000000-0000-0000-0000-000000000000", "quota": 100} + req_bytes = json.dumps(params).encode() + headers = get_headers("POST", url, req_bytes) + + with aioresponses(passthrough=["http://127.0.0.1"]) as mocked: + get_project_id_url = "http://mock_registry/api/v2.0/projects/mock_project" + mocked.get( + get_project_id_url, + status=200, + payload=mock_harbor_responses["get_project_id"], + ) + + harbor_project_id = mock_harbor_responses["get_project_id"]["project_id"] + get_quota_url = f"http://mock_registry/api/v2.0/quotas?reference=project&reference_id={harbor_project_id}" + mocked.get( + get_quota_url, + status=200, + payload=mock_harbor_responses["get_quotas"], + ) + + harbor_quota_id = mock_harbor_responses["get_quotas"][0]["id"] + put_quota_url = f"http://mock_registry/api/v2.0/quotas/{harbor_quota_id}" + mocked.put( + put_quota_url, + status=200, + ) + + resp = await client.post(url, data=req_bytes, headers=headers) + assert resp.status == test_case["expected_code"] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("extra_fixtures", FIXTURES_FOR_HARBOR_CRUD_TEST) +@pytest.mark.parametrize( + "test_case", + [ + { + "mock_harbor_responses": { + "get_project_id": {"project_id": "1"}, + "get_quotas": [ + { + "id": 1, + "hard": {"storage": 100}, + } + ], + }, + "expected_code": 200, + }, + { + "mock_harbor_responses": { + "get_project_id": {"project_id": "1"}, + "get_quotas": [ + { + "id": 1, + "hard": {"storage": -1}, + } + ], + }, + "expected_code": 404, + }, + ], + ids=["Normal case", "Project Quota doesn't exist"], +) +async def test_harbor_read_project_quota( + test_case, + etcd_fixture, + database_fixture, + create_app_and_client, + get_headers, +): + app, client = await create_app_and_client( + [ + shared_config_ctx, + database_ctx, + monitoring_ctx, + hook_plugin_ctx, + redis_ctx, + ], + [".group", ".auth"], + ) + + mock_harbor_responses = test_case["mock_harbor_responses"] + + with aioresponses(passthrough=["http://127.0.0.1"]) as mocked: + get_project_id_url = "http://mock_registry/api/v2.0/projects/mock_project" + mocked.get(get_project_id_url, status=200, payload=mock_harbor_responses["get_project_id"]) + harbor_project_id = mock_harbor_responses["get_project_id"]["project_id"] + + get_quota_url = f"http://mock_registry/api/v2.0/quotas?reference=project&reference_id={harbor_project_id}" + mocked.get( + get_quota_url, + status=200, + payload=mock_harbor_responses["get_quotas"], + ) + + url = "/group/registry-quota" + params = {"group_id": "00000000-0000-0000-0000-000000000000"} + full_url = f"{url}?{urlencode(params)}" + headers = get_headers("GET", full_url, b"") + + resp = await client.get(url, params=params, headers=headers) + assert resp.status == test_case["expected_code"] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("extra_fixtures", FIXTURES_FOR_HARBOR_CRUD_TEST) +@pytest.mark.parametrize( + "test_case", + [ + { + "mock_harbor_responses": { + "get_project_id": {"project_id": "1"}, + "get_quotas": [ + { + "id": 1, + "hard": {"storage": 100}, + } + ], + }, + "expected_code": 200, + }, + { + "mock_harbor_responses": { + "get_project_id": {"project_id": "1"}, + "get_quotas": [ + { + "id": 1, + "hard": {"storage": -1}, + } + ], + }, + "expected_code": 404, + }, + ], + ids=["Normal case", "Project Quota not found"], +) +async def test_harbor_update_project_quota( + test_case, + etcd_fixture, + database_fixture, + create_app_and_client, + get_headers, +): + app, client = await create_app_and_client( + [ + shared_config_ctx, + database_ctx, + monitoring_ctx, + hook_plugin_ctx, + redis_ctx, + ], + [".group", ".auth"], + ) + + mock_harbor_responses = test_case["mock_harbor_responses"] + + url = "/group/registry-quota" + params = {"group_id": "00000000-0000-0000-0000-000000000000", "quota": 200} + req_bytes = json.dumps(params).encode() + headers = get_headers("PATCH", url, req_bytes) + + with aioresponses(passthrough=["http://127.0.0.1"]) as mocked: + get_project_id_url = "http://mock_registry/api/v2.0/projects/mock_project" + mocked.get(get_project_id_url, status=200, payload=mock_harbor_responses["get_project_id"]) + harbor_project_id = mock_harbor_responses["get_project_id"]["project_id"] + + get_quota_url = f"http://mock_registry/api/v2.0/quotas?reference=project&reference_id={harbor_project_id}" + mocked.get( + get_quota_url, + status=200, + payload=mock_harbor_responses["get_quotas"], + ) + harbor_quota_id = mock_harbor_responses["get_quotas"][0]["id"] + + put_quota_url = f"http://mock_registry/api/v2.0/quotas/{harbor_quota_id}" + mocked.put( + put_quota_url, + status=200, + ) + + resp = await client.patch(url, data=req_bytes, headers=headers) + assert resp.status == test_case["expected_code"] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("extra_fixtures", FIXTURES_FOR_HARBOR_CRUD_TEST) +@pytest.mark.parametrize( + "test_case", + [ + { + "mock_harbor_responses": { + "get_project_id": {"project_id": "1"}, + "get_quotas": [ + { + "id": 1, + "hard": {"storage": 100}, + } + ], + }, + "expected_code": 200, + }, + { + "mock_harbor_responses": { + "get_project_id": {"project_id": "1"}, + "get_quotas": [ + { + "id": 1, + "hard": {"storage": -1}, + } + ], + }, + "expected_code": 404, + }, + ], + ids=["Normal case", "Project Quota not found"], +) +async def test_harbor_delete_project_quota( + test_case, + etcd_fixture, + database_fixture, + create_app_and_client, + get_headers, +): + app, client = await create_app_and_client( + [ + shared_config_ctx, + database_ctx, + monitoring_ctx, + hook_plugin_ctx, + redis_ctx, + ], + [".group", ".auth"], + ) + + mock_harbor_responses = test_case["mock_harbor_responses"] + + url = "/group/registry-quota" + params = {"group_id": "00000000-0000-0000-0000-000000000000"} + req_bytes = json.dumps(params).encode() + headers = get_headers("DELETE", url, req_bytes) + + with aioresponses(passthrough=["http://127.0.0.1"]) as mocked: + get_project_id_url = "http://mock_registry/api/v2.0/projects/mock_project" + mocked.get(get_project_id_url, status=200, payload=mock_harbor_responses["get_project_id"]) + harbor_project_id = mock_harbor_responses["get_project_id"]["project_id"] + + get_quota_url = f"http://mock_registry/api/v2.0/quotas?reference=project&reference_id={harbor_project_id}" + mocked.get( + get_quota_url, + status=200, + payload=mock_harbor_responses["get_quotas"], + ) + harbor_quota_id = mock_harbor_responses["get_quotas"][0]["id"] + + put_quota_url = f"http://mock_registry/api/v2.0/quotas/{harbor_quota_id}" + mocked.put( + put_quota_url, + status=200, + ) + + resp = await client.delete(url, data=req_bytes, headers=headers) + assert resp.status == test_case["expected_code"] diff --git a/tests/manager/models/gql_models/test_container_registries.py b/tests/manager/models/gql_models/test_container_registries.py index fd61b6a9368..d1189f46977 100644 --- a/tests/manager/models/gql_models/test_container_registries.py +++ b/tests/manager/models/gql_models/test_container_registries.py @@ -1,9 +1,15 @@ import pytest +from aioresponses import aioresponses from graphene import Schema from graphene.test import Client +from ai.backend.manager.api.context import RootContext from ai.backend.manager.models.gql import GraphQueryContext, Mutations, Queries from ai.backend.manager.models.utils import ExtendedAsyncSAEngine +from ai.backend.manager.server import ( + database_ctx, +) +from ai.backend.testutils.extra_fixtures import FIXTURES_FOR_HARBOR_CRUD_TEST @pytest.fixture(scope="module") @@ -31,4 +37,271 @@ def get_graphquery_context(database_engine: ExtendedAsyncSAEngine) -> GraphQuery registry=None, # type: ignore idle_checker_host=None, # type: ignore network_plugin_ctx=None, # type: ignore + services_ctx=None, # type: ignore ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("extra_fixtures", FIXTURES_FOR_HARBOR_CRUD_TEST) +@pytest.mark.parametrize( + "test_case", + [ + { + "mock_harbor_responses": { + "get_project_id": {"project_id": "1"}, + "get_quotas": [ + { + "id": 1, + "hard": {"storage": -1}, + } + ], + }, + "expected": True, + }, + { + "mock_harbor_responses": { + "get_project_id": {"project_id": "1"}, + "get_quotas": [ + { + "id": 1, + "hard": {"storage": 100}, + } + ], + }, + "expected": False, + }, + ], + ids=["Normal case", "Project Quota already exist"], +) +async def test_harbor_create_project_quota( + client: Client, + test_case, + database_fixture, + create_app_and_client, +): + test_app, _ = await create_app_and_client( + [ + database_ctx, + ], + [], + ) + + root_ctx: RootContext = test_app["_root.context"] + context = get_graphquery_context(root_ctx.db) + + create_query = """ + mutation ($scope_id: ScopeField!, $quota: BigInt!) { + create_container_registry_quota(scope_id: $scope_id, quota: $quota) { + ok + msg + } + } + """ + variables = { + "scope_id": "project:00000000-0000-0000-0000-000000000000", + "quota": 100, + } + + mock_harbor_responses = test_case["mock_harbor_responses"] + + with aioresponses() as mocked: + get_project_id_url = "http://mock_registry/api/v2.0/projects/mock_project" + mocked.get(get_project_id_url, status=200, payload=mock_harbor_responses["get_project_id"]) + + harbor_project_id = mock_harbor_responses["get_project_id"]["project_id"] + get_quotas_url = f"http://mock_registry/api/v2.0/quotas?reference=project&reference_id={harbor_project_id}" + mocked.get( + get_quotas_url, + status=200, + payload=mock_harbor_responses["get_quotas"], + ) + + harbor_quota_id = mock_harbor_responses["get_quotas"][0]["id"] + put_quota_url = f"http://mock_registry/api/v2.0/quotas/{harbor_quota_id}" + mocked.put( + put_quota_url, + status=200, + ) + + response = await client.execute_async( + create_query, variables=variables, context_value=context + ) + + assert response["data"]["create_container_registry_quota"]["ok"] == test_case["expected"] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("extra_fixtures", FIXTURES_FOR_HARBOR_CRUD_TEST) +@pytest.mark.parametrize( + "test_case", + [ + { + "mock_harbor_responses": { + "get_project_id": {"project_id": "1"}, + "get_quotas": [ + { + "id": 1, + "hard": {"storage": 100}, + } + ], + }, + "expected": True, + }, + { + "mock_harbor_responses": { + "get_project_id": {"project_id": "1"}, + "get_quotas": [ + { + "id": 1, + "hard": {"storage": -1}, + } + ], + }, + "expected": False, + }, + ], + ids=["Normal case", "Project Quota not found"], +) +async def test_harbor_update_project_quota( + client: Client, + test_case, + database_fixture, + create_app_and_client, +): + test_app, _ = await create_app_and_client( + [ + database_ctx, + ], + [], + ) + + root_ctx: RootContext = test_app["_root.context"] + context = get_graphquery_context(root_ctx.db) + + update_query = """ + mutation ($scope_id: ScopeField!, $quota: BigInt!) { + update_container_registry_quota(scope_id: $scope_id, quota: $quota) { + ok + msg + } + } + """ + variables = { + "scope_id": "project:00000000-0000-0000-0000-000000000000", + "quota": 200, + } + + mock_harbor_responses = test_case["mock_harbor_responses"] + + with aioresponses() as mocked: + get_project_id_url = "http://mock_registry/api/v2.0/projects/mock_project" + mocked.get(get_project_id_url, status=200, payload=mock_harbor_responses["get_project_id"]) + + harbor_project_id = mock_harbor_responses["get_project_id"]["project_id"] + + get_quotas_url = f"http://mock_registry/api/v2.0/quotas?reference=project&reference_id={harbor_project_id}" + mocked.get( + get_quotas_url, + status=200, + payload=mock_harbor_responses["get_quotas"], + ) + + harbor_quota_id = mock_harbor_responses["get_quotas"][0]["id"] + put_quota_url = f"http://mock_registry/api/v2.0/quotas/{harbor_quota_id}" + mocked.put( + put_quota_url, + status=200, + ) + + response = await client.execute_async( + update_query, variables=variables, context_value=context + ) + assert response["data"]["update_container_registry_quota"]["ok"] == test_case["expected"] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("extra_fixtures", FIXTURES_FOR_HARBOR_CRUD_TEST) +@pytest.mark.parametrize( + "test_case", + [ + { + "mock_harbor_responses": { + "get_project_id": {"project_id": "1"}, + "get_quotas": [ + { + "id": 1, + "hard": {"storage": 100}, + } + ], + }, + "expected": True, + }, + { + "mock_harbor_responses": { + "get_project_id": {"project_id": "1"}, + "get_quotas": [ + { + "id": 1, + "hard": {"storage": -1}, + } + ], + }, + "expected": False, + }, + ], + ids=["Normal case", "Project Quota not found"], +) +async def test_harbor_delete_project_quota( + client: Client, + test_case, + database_fixture, + create_app_and_client, +): + test_app, _ = await create_app_and_client( + [ + database_ctx, + ], + [], + ) + + root_ctx: RootContext = test_app["_root.context"] + context = get_graphquery_context(root_ctx.db) + + delete_query = """ + mutation ($scope_id: ScopeField!) { + delete_container_registry_quota(scope_id: $scope_id) { + ok + msg + } + } + """ + variables = { + "scope_id": "project:00000000-0000-0000-0000-000000000000", + } + + mock_harbor_responses = test_case["mock_harbor_responses"] + + with aioresponses() as mocked: + get_project_id_url = "http://mock_registry/api/v2.0/projects/mock_project" + mocked.get(get_project_id_url, status=200, payload=mock_harbor_responses["get_project_id"]) + + harbor_project_id = mock_harbor_responses["get_project_id"]["project_id"] + + get_quotas_url = f"http://mock_registry/api/v2.0/quotas?reference=project&reference_id={harbor_project_id}" + mocked.get( + get_quotas_url, + status=200, + payload=mock_harbor_responses["get_quotas"], + ) + + harbor_quota_id = mock_harbor_responses["get_quotas"][0]["id"] + put_quota_url = f"http://mock_registry/api/v2.0/quotas/{harbor_quota_id}" + mocked.put( + put_quota_url, + status=200, + ) + + response = await client.execute_async( + delete_query, variables=variables, context_value=context + ) + assert response["data"]["delete_container_registry_quota"]["ok"] == test_case["expected"] diff --git a/tests/manager/models/gql_models/test_group.py b/tests/manager/models/gql_models/test_group.py new file mode 100644 index 00000000000..37880b2f340 --- /dev/null +++ b/tests/manager/models/gql_models/test_group.py @@ -0,0 +1,94 @@ +import pytest +from aioresponses import aioresponses +from graphene import Schema +from graphene.test import Client + +from ai.backend.common.utils import b64encode +from ai.backend.manager.api.context import RootContext +from ai.backend.manager.models.gql import GraphQueryContext, Mutations, Queries +from ai.backend.manager.models.utils import ExtendedAsyncSAEngine +from ai.backend.manager.server import ( + database_ctx, +) +from ai.backend.testutils.extra_fixtures import FIXTURES_FOR_HARBOR_CRUD_TEST + + +@pytest.fixture(scope="module") +def client() -> Client: + return Client(Schema(query=Queries, mutation=Mutations, auto_camelcase=False)) + + +def get_graphquery_context(database_engine: ExtendedAsyncSAEngine) -> GraphQueryContext: + return GraphQueryContext( + schema=None, # type: ignore + dataloader_manager=None, # type: ignore + local_config=None, # type: ignore + shared_config=None, # type: ignore + etcd=None, # type: ignore + user={"domain": "default", "role": "superadmin"}, + access_key="AKIAIOSFODNN7EXAMPLE", + db=database_engine, # type: ignore + redis_stat=None, # type: ignore + redis_image=None, # type: ignore + redis_live=None, # type: ignore + manager_status=None, # type: ignore + known_slot_types=None, # type: ignore + background_task_manager=None, # type: ignore + storage_manager=None, # type: ignore + registry=None, # type: ignore + idle_checker_host=None, # type: ignore + network_plugin_ctx=None, # type: ignore + services_ctx=None, # type: ignore + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("extra_fixtures", FIXTURES_FOR_HARBOR_CRUD_TEST) +async def test_harbor_read_project_quota( + client: Client, + database_fixture, + create_app_and_client, +): + test_app, _ = await create_app_and_client( + [ + database_ctx, + ], + [], + ) + + root_ctx: RootContext = test_app["_root.context"] + context = get_graphquery_context(root_ctx.db) + + # Arbitrary values for mocking Harbor API responses + HARBOR_PROJECT_ID = "123" + HARBOR_QUOTA_ID = 456 + HARBOR_QUOTA_VALUE = 1024 + + with aioresponses() as mocked: + get_project_id_url = "http://mock_registry/api/v2.0/projects/mock_project" + mocked.get(get_project_id_url, status=200, payload={"project_id": HARBOR_PROJECT_ID}) + + get_quota_url = f"http://mock_registry/api/v2.0/quotas?reference=project&reference_id={HARBOR_PROJECT_ID}" + mocked.get( + get_quota_url, + status=200, + payload=[{"id": HARBOR_QUOTA_ID, "hard": {"storage": HARBOR_QUOTA_VALUE}}], + ) + + groupnode_query = """ + query ($id: String!) { + group_node(id: $id) { + registry_quota + } + } + """ + + group_id = "00000000-0000-0000-0000-000000000000" + variables = { + "id": b64encode(f"group_node:{group_id}"), + } + + response = await client.execute_async( + groupnode_query, variables=variables, context_value=context + ) + assert response["data"]["group_node"]["registry_quota"] == HARBOR_QUOTA_VALUE diff --git a/tests/manager/models/test_container_registries.py b/tests/manager/models/test_container_registries.py index 0da01fab80d..33e2f109903 100644 --- a/tests/manager/models/test_container_registries.py +++ b/tests/manager/models/test_container_registries.py @@ -45,6 +45,7 @@ def get_graphquery_context(database_engine: ExtendedAsyncSAEngine) -> GraphQuery storage_manager=None, # type: ignore registry=None, # type: ignore idle_checker_host=None, # type: ignore + services_ctx=None, # type: ignore ) diff --git a/tests/manager/models/test_container_registry_nodes.py b/tests/manager/models/test_container_registry_nodes.py index 623719034f5..45a8ce05b23 100644 --- a/tests/manager/models/test_container_registry_nodes.py +++ b/tests/manager/models/test_container_registry_nodes.py @@ -108,6 +108,7 @@ def mock_shared_config_api_getitem(key): registry=None, # type: ignore idle_checker_host=None, # type: ignore network_plugin_ctx=None, # type: ignore + services_ctx=None, # type: ignore )