diff --git a/changes/2979.fix.md b/changes/2979.fix.md new file mode 100644 index 00000000000..f5bad08f80b --- /dev/null +++ b/changes/2979.fix.md @@ -0,0 +1 @@ +Refactor the container registry project traversal logic used during image rescanning. diff --git a/src/ai/backend/common/docker.py b/src/ai/backend/common/docker.py index ff816964153..1ab9d6cd48b 100644 --- a/src/ai/backend/common/docker.py +++ b/src/ai/backend/common/docker.py @@ -13,6 +13,7 @@ from typing import ( Final, Iterable, + Literal, Mapping, NamedTuple, Optional, @@ -437,7 +438,9 @@ def parse_image_tag( return image_str, tag @classmethod - def parse_image_str(cls, image_str: str, registry: str | None = None) -> ParsedImageStr: + def parse_image_str( + cls, image_str: str, registry: str | Literal["*"] | None = None + ) -> ParsedImageStr: """ Parses a string representing an image. @@ -458,7 +461,7 @@ def parse_image_str(cls, image_str: str, registry: str | None = None) -> ParsedI if "://" in image_str or image_str.startswith("//"): raise InvalidImageName(image_str) - def divide_parts(image_str: str, registry: str | None) -> tuple[str, str]: + def divide_parts(image_str: str, registry: str | Literal["*"] | None) -> tuple[str, str]: if "/" not in image_str: return (default_registry, image_str) diff --git a/src/ai/backend/manager/container_registry/aws_ecr.py b/src/ai/backend/manager/container_registry/aws_ecr.py index 25fee9a50df..0865eecc406 100644 --- a/src/ai/backend/manager/container_registry/aws_ecr.py +++ b/src/ai/backend/manager/container_registry/aws_ecr.py @@ -1,10 +1,11 @@ import logging -from typing import AsyncIterator +from typing import AsyncIterator, override import aiohttp import boto3 from ai.backend.logging import BraceStyleAdapter +from ai.backend.manager.exceptions import ContainerRegistryProjectEmpty from .base import ( BaseContainerRegistry, @@ -14,10 +15,14 @@ class AWSElasticContainerRegistry(BaseContainerRegistry): + @override async def fetch_repositories( self, sess:
aiohttp.ClientSession, ) -> AsyncIterator[str]: + if not self.registry_info.project: + raise ContainerRegistryProjectEmpty(self.registry_info.type, self.registry_info.project) + access_key, secret_access_key, region, type_ = ( self.registry_info.extra.get("access_key"), self.registry_info.extra.get("secret_access_key"), @@ -43,10 +48,12 @@ async def fetch_repositories( for repo in response["repositories"]: match type_: case "ecr": - yield repo["repositoryName"] + if repo["repositoryName"].startswith(self.registry_info.project): + yield repo["repositoryName"] case "ecr-public": registry_alias = (repo["repositoryUri"].split("/"))[1] - yield f"{registry_alias}/{repo['repositoryName']}" + if self.registry_info.project == registry_alias: + yield f"{registry_alias}/{repo['repositoryName']}" case _: raise ValueError(f"Unknown registry type: {type_}") diff --git a/src/ai/backend/manager/container_registry/base.py b/src/ai/backend/manager/container_registry/base.py index 023c7b667c1..9b2574f3dca 100644 --- a/src/ai/backend/manager/container_registry/base.py +++ b/src/ai/backend/manager/container_registry/base.py @@ -13,7 +13,6 @@ import sqlalchemy as sa import trafaret as t import yarl -from sqlalchemy.orm import load_only from ai.backend.common.bgtask import ProgressReporter from ai.backend.common.docker import ( @@ -147,61 +146,45 @@ async def commit_rescan_result(self) -> None: image_row.resources = update["resources"] image_row.is_local = is_local - registries = cast( - list[ContainerRegistryRow], - ( - await session.scalars( - sa.select(ContainerRegistryRow).options( - load_only( - ContainerRegistryRow.project, - ContainerRegistryRow.registry_name, - ContainerRegistryRow.url, - ) - ) - ) - ).all(), - ) - for image_identifier, update in _all_updates.items(): - for registry in registries: - try: - parsed_img = ImageRef.from_image_str( - image_identifier.canonical, registry.project, registry.registry_name - ) - except ProjectMismatchWithCanonical: - continue - except 
ValueError as e: - skip_reason = str(e) - progress_msg = f"Skipped image - {image_identifier.canonical}/{image_identifier.architecture} ({skip_reason})" - log.warning(progress_msg) - break - - session.add( - ImageRow( - name=parsed_img.canonical, - project=registry.project, - registry=parsed_img.registry, - registry_id=registry.id, - image=join_non_empty(parsed_img.project, parsed_img.name, sep="/"), - tag=parsed_img.tag, - architecture=image_identifier.architecture, - is_local=is_local, - config_digest=update["config_digest"], - size_bytes=update["size_bytes"], - type=ImageType.COMPUTE, - accelerators=update.get("accels"), - labels=update["labels"], - resources=update["resources"], - ) + try: + parsed_img = ImageRef.from_image_str( + image_identifier.canonical, + self.registry_info.project, + self.registry_info.registry_name, + is_local=is_local, ) - progress_msg = f"Updated image - {parsed_img.canonical}/{image_identifier.architecture} ({update['config_digest']})" - log.info(progress_msg) - break - - else: - skip_reason = "No container registry found matching the image." 
+ except ProjectMismatchWithCanonical: + continue + except ValueError as e: + skip_reason = str(e) progress_msg = f"Skipped image - {image_identifier.canonical}/{image_identifier.architecture} ({skip_reason})" log.warning(progress_msg) + if (reporter := progress_reporter.get()) is not None: + await reporter.update(1, message=progress_msg) + + continue + + session.add( + ImageRow( + name=parsed_img.canonical, + project=self.registry_info.project, + registry=parsed_img.registry, + registry_id=self.registry_info.id, + image=join_non_empty(parsed_img.project, parsed_img.name, sep="/"), + tag=parsed_img.tag, + architecture=image_identifier.architecture, + is_local=is_local, + config_digest=update["config_digest"], + size_bytes=update["size_bytes"], + type=ImageType.COMPUTE, + accelerators=update.get("accels"), + labels=update["labels"], + resources=update["resources"], + ) + ) + progress_msg = f"Updated image - {parsed_img.canonical}/{image_identifier.architecture} ({update['config_digest']})" + log.info(progress_msg) if (reporter := progress_reporter.get()) is not None: await reporter.update(1, message=progress_msg) diff --git a/src/ai/backend/manager/container_registry/docker.py b/src/ai/backend/manager/container_registry/docker.py index b391178c075..a290ddad506 100644 --- a/src/ai/backend/manager/container_registry/docker.py +++ b/src/ai/backend/manager/container_registry/docker.py @@ -1,6 +1,6 @@ import json import logging -from typing import AsyncIterator, Optional, cast +from typing import AsyncIterator, Optional, cast, override import aiohttp import typing_extensions @@ -8,6 +8,7 @@ from ai.backend.common.docker import login as registry_login from ai.backend.logging import BraceStyleAdapter +from ai.backend.manager.exceptions import ContainerRegistryProjectEmpty from .base import BaseContainerRegistry @@ -18,6 +19,7 @@ class DockerHubRegistry(BaseContainerRegistry): @typing_extensions.deprecated( "Rescanning a whole Docker Hub account is disabled due to the API 
rate limit." ) + @override async def fetch_repositories( self, sess: aiohttp.ClientSession, @@ -63,10 +65,14 @@ async def fetch_repositories_legacy( class DockerRegistry_v2(BaseContainerRegistry): + @override async def fetch_repositories( self, sess: aiohttp.ClientSession, ) -> AsyncIterator[str]: + if not self.registry_info.project: + raise ContainerRegistryProjectEmpty(self.registry_info.type, self.registry_info.project) + # The credential should have the catalog search privilege. rqst_args = await registry_login( sess, @@ -83,7 +89,8 @@ async def fetch_repositories( if resp.status == 200: data = json.loads(await resp.read()) for item in data["repositories"]: - yield item + if item.startswith(self.registry_info.project): + yield item log.debug("found {} repositories", len(data["repositories"])) else: log.warning( diff --git a/src/ai/backend/manager/container_registry/github.py b/src/ai/backend/manager/container_registry/github.py index 0fcebf1365c..1ba9020129e 100644 --- a/src/ai/backend/manager/container_registry/github.py +++ b/src/ai/backend/manager/container_registry/github.py @@ -1,9 +1,10 @@ import logging -from typing import AsyncIterator +from typing import AsyncIterator, override import aiohttp from ai.backend.logging import BraceStyleAdapter +from ai.backend.manager.exceptions import ContainerRegistryProjectEmpty from .base import ( BaseContainerRegistry, @@ -13,18 +14,19 @@ class GitHubRegistry(BaseContainerRegistry): - async def fetch_repositories( - self, - sess: aiohttp.ClientSession, - ) -> AsyncIterator[str]: - username = self.registry_info.username + @override + async def fetch_repositories(self, sess: aiohttp.ClientSession) -> AsyncIterator[str]: + if not self.registry_info.project: + raise ContainerRegistryProjectEmpty(self.registry_info.type, self.registry_info.project) + + project = self.registry_info.project access_token = self.registry_info.password entity_type = self.registry_info.extra.get("entity_type", None) if entity_type is None: 
raise RuntimeError("Entity type is not provided for GitHub registry!") - base_url = f"https://api.github.com/{entity_type}/{username}/packages" + base_url = f"https://api.github.com/{entity_type}/{project}/packages" headers = { "Authorization": f"Bearer {access_token}", @@ -41,7 +43,7 @@ async def fetch_repositories( if response.status == 200: data = await response.json() for repo in data: - yield f"{username}/{repo['name']}" + yield f"{project}/{repo['name']}" if "next" in response.links: page += 1 else: diff --git a/src/ai/backend/manager/container_registry/gitlab.py b/src/ai/backend/manager/container_registry/gitlab.py index 02875a608b9..98515b1e367 100644 --- a/src/ai/backend/manager/container_registry/gitlab.py +++ b/src/ai/backend/manager/container_registry/gitlab.py @@ -1,12 +1,11 @@ import logging import urllib.parse -from typing import AsyncIterator, cast +from typing import AsyncIterator, override import aiohttp -import sqlalchemy as sa from ai.backend.logging import BraceStyleAdapter -from ai.backend.manager.models.container_registry import ContainerRegistryRow +from ai.backend.manager.exceptions import ContainerRegistryProjectEmpty from .base import ( BaseContainerRegistry, @@ -16,49 +15,42 @@ class GitLabRegistry(BaseContainerRegistry): + @override async def fetch_repositories(self, sess: aiohttp.ClientSession) -> AsyncIterator[str]: + if not self.registry_info.project: + raise ContainerRegistryProjectEmpty(self.registry_info.type, self.registry_info.project) + access_token = self.registry_info.password api_endpoint = self.registry_info.extra.get("api_endpoint", None) if api_endpoint is None: raise RuntimeError('"api_endpoint" is not provided for GitLab registry!') - async with self.db.begin_readonly_session() as db_sess: - result = await db_sess.execute( - sa.select(ContainerRegistryRow.project).where( - ContainerRegistryRow.registry_name == self.registry_info.registry_name - ) - ) - projects = cast(list[str], result.scalars().all()) - - for project 
in projects: - encoded_project_id = urllib.parse.quote(project, safe="") - repo_list_url = ( - f"{api_endpoint}/api/v4/projects/{encoded_project_id}/registry/repositories" - ) - - headers = { - "Accept": "application/json", - "PRIVATE-TOKEN": access_token, - } - page = 1 - - while True: - async with sess.get( - repo_list_url, - headers=headers, - params={"per_page": 30, "page": page}, - ) as response: - if response.status == 200: - data = await response.json() - - for repo in data: - yield repo["path"] - if "next" in response.headers.get("Link", ""): - page += 1 - else: - break + encoded_project_id = urllib.parse.quote(self.registry_info.project, safe="") + repo_list_url = f"{api_endpoint}/api/v4/projects/{encoded_project_id}/registry/repositories" + + headers = { + "Accept": "application/json", + "PRIVATE-TOKEN": access_token, + } + page = 1 + + while True: + async with sess.get( + repo_list_url, + headers=headers, + params={"per_page": 30, "page": page}, + ) as response: + if response.status == 200: + data = await response.json() + + for repo in data: + yield repo["path"] + if "next" in response.headers.get("Link", ""): + page += 1 else: - raise RuntimeError( - f"Failed to fetch repositories for project {project}! {response.status} error occurred." - ) + break + else: + raise RuntimeError( + f"Failed to fetch repositories for project {self.registry_info.project}! {response.status} error occurred." 
+ ) diff --git a/src/ai/backend/manager/container_registry/harbor.py b/src/ai/backend/manager/container_registry/harbor.py index 094dec67716..a999ffccf1e 100644 --- a/src/ai/backend/manager/container_registry/harbor.py +++ b/src/ai/backend/manager/container_registry/harbor.py @@ -3,18 +3,17 @@ import json import logging import urllib.parse -from typing import Any, AsyncIterator, Mapping, Optional, cast +from typing import Any, AsyncIterator, Mapping, Optional, cast, override import aiohttp import aiohttp.client_exceptions import aiotools -import sqlalchemy as sa import yarl from ai.backend.common.docker import ImageRef, arch_name_aliases from ai.backend.common.docker import login as registry_login from ai.backend.logging import BraceStyleAdapter -from ai.backend.manager.models.container_registry import ContainerRegistryRow +from ai.backend.manager.exceptions import ContainerRegistryProjectEmpty from .base import ( BaseContainerRegistry, @@ -26,19 +25,15 @@ class HarborRegistry_v1(BaseContainerRegistry): + @override async def fetch_repositories( self, sess: aiohttp.ClientSession, ) -> AsyncIterator[str]: - api_url = self.registry_url / "api" + if not self.registry_info.project: + raise ContainerRegistryProjectEmpty(self.registry_info.type, self.registry_info.project) - async with self.db.begin_readonly_session() as db_sess: - result = await db_sess.execute( - sa.select(ContainerRegistryRow.project).where( - ContainerRegistryRow.registry_name == self.registry_info.registry_name - ) - ) - registry_projects = cast(list[str | None], result.scalars().all()) + api_url = self.registry_url / "api" rqst_args: dict[str, Any] = {} if self.credentials: @@ -55,7 +50,7 @@ async def fetch_repositories( async with sess.get(project_list_url, allow_redirects=False, **rqst_args) as resp: projects = await resp.json() for item in projects: - if item["name"] in registry_projects: + if item["name"] == self.registry_info.project: project_ids.append(item["project_id"]) project_list_url = 
None next_page_link = resp.links.get("next") @@ -86,6 +81,7 @@ async def fetch_repositories( next_page_url.query ) + @override async def _scan_tag( self, sess: aiohttp.ClientSession, @@ -141,11 +137,10 @@ async def untag( image: ImageRef, ) -> None: project = image.project - repository = image.name - - if project is None: - raise ValueError("project is required for Harbor registry") + if not project: + raise ContainerRegistryProjectEmpty(self.registry_info.type, project) + repository = image.name base_url = ( self.registry_url / "api" @@ -181,19 +176,15 @@ async def untag( ): # 404 means image is already removed from harbor so we can just safely ignore the exception raise RuntimeError(f"Failed to untag {image}: {e.message}") from e + @override async def fetch_repositories( self, sess: aiohttp.ClientSession, ) -> AsyncIterator[str]: - api_url = self.registry_url / "api" / "v2.0" + if not self.registry_info.project: + raise ContainerRegistryProjectEmpty(self.registry_info.type, self.registry_info.project) - async with self.db.begin_readonly_session() as db_sess: - result = await db_sess.execute( - sa.select(ContainerRegistryRow.project).where( - ContainerRegistryRow.registry_name == self.registry_info.registry_name - ) - ) - registry_projects = cast(list[str | None], result.scalars().all()) + api_url = self.registry_url / "api" / "v2.0" rqst_args: dict[str, Any] = {} if self.credentials: @@ -201,33 +192,34 @@ async def fetch_repositories( self.credentials["username"], self.credentials["password"], ) - repo_list_url: Optional[yarl.URL] - for project_name in registry_projects: - assert project_name is not None - repo_list_url = (api_url / "projects" / project_name / "repositories").with_query( - {"page_size": "30"}, - ) - while repo_list_url is not None: - async with sess.get(repo_list_url, allow_redirects=False, **rqst_args) as resp: - items = await resp.json() - if isinstance(items, dict) and (errors := items.get("errors", [])): - raise RuntimeError( - f"failed to 
fetch repositories in project {project_name}", - errors[0]["code"], - errors[0]["message"], - ) - repos = [item["name"] for item in items] - for item in repos: - yield item - repo_list_url = None - next_page_link = resp.links.get("next") - if next_page_link: - next_page_url = cast(yarl.URL, next_page_link["url"]) - repo_list_url = self.registry_url.with_path(next_page_url.path).with_query( - next_page_url.query - ) + repo_list_url: Optional[yarl.URL] + repo_list_url = ( + api_url / "projects" / self.registry_info.project / "repositories" + ).with_query( + {"page_size": "30"}, + ) + while repo_list_url is not None: + async with sess.get(repo_list_url, allow_redirects=False, **rqst_args) as resp: + items = await resp.json() + if isinstance(items, dict) and (errors := items.get("errors", [])): + raise RuntimeError( + f"failed to fetch repositories in project {self.registry_info.project}", + errors[0]["code"], + errors[0]["message"], + ) + repos = [item["name"] for item in items] + for item in repos: + yield item + repo_list_url = None + next_page_link = resp.links.get("next") + if next_page_link: + next_page_url = cast(yarl.URL, next_page_link["url"]) + repo_list_url = self.registry_url.with_path(next_page_url.path).with_query( + next_page_url.query + ) + @override async def _scan_image( self, sess: aiohttp.ClientSession, @@ -293,6 +285,7 @@ async def _scan_image( next_page_url.query ) + @override async def _scan_tag( self, sess: aiohttp.ClientSession, @@ -333,6 +326,7 @@ async def _scan_tag( case _ as media_type: raise RuntimeError(f"Unsupported artifact media-type: {media_type}") + @override async def _process_oci_index( self, tg: aiotools.TaskGroup, @@ -369,6 +363,7 @@ async def _process_oci_index( ) ) + @override async def _process_docker_v2_multiplatform_image( self, tg: aiotools.TaskGroup, @@ -407,6 +402,7 @@ async def _process_docker_v2_multiplatform_image( ) ) + @override async def _process_docker_v2_image( self, tg: aiotools.TaskGroup, diff --git 
a/src/ai/backend/manager/container_registry/local.py b/src/ai/backend/manager/container_registry/local.py index 1177e4a211d..d5a46c13c09 100644 --- a/src/ai/backend/manager/container_registry/local.py +++ b/src/ai/backend/manager/container_registry/local.py @@ -3,7 +3,7 @@ import json import logging from contextlib import asynccontextmanager as actxmgr -from typing import AsyncIterator, Optional +from typing import AsyncIterator, Optional, override import aiohttp import sqlalchemy as sa @@ -29,6 +29,7 @@ async def prepare_client_session(self) -> AsyncIterator[tuple[yarl.URL, aiohttp. async with aiohttp.ClientSession(connector=connector.connector) as sess: yield connector.docker_host, sess + @override async def fetch_repositories( self, sess: aiohttp.ClientSession, @@ -48,6 +49,7 @@ async def fetch_repositories( continue yield image_ref_str # this includes the tag part + @override async def _scan_image( self, sess: aiohttp.ClientSession, diff --git a/src/ai/backend/manager/exceptions.py b/src/ai/backend/manager/exceptions.py index dfa5a99b806..389a10b7f2b 100644 --- a/src/ai/backend/manager/exceptions.py +++ b/src/ai/backend/manager/exceptions.py @@ -5,6 +5,7 @@ TYPE_CHECKING, Any, List, + Literal, NotRequired, Optional, Tuple, @@ -143,3 +144,10 @@ def convert_to_status_data( if is_debug: data["error"]["traceback"] = "\n".join(traceback.format_tb(e.__traceback__)) return data + + +class ContainerRegistryProjectEmpty(RuntimeError): + def __init__(self, type: str, project: Literal[""] | None): + super().__init__( + f"{type} container registry requires project value, but {project} is provided" + ) diff --git a/src/ai/backend/manager/models/image.py b/src/ai/backend/manager/models/image.py index 387d32e201e..ca1a816d80d 100644 --- a/src/ai/backend/manager/models/image.py +++ b/src/ai/backend/manager/models/image.py @@ -8,6 +8,7 @@ from typing import ( TYPE_CHECKING, Any, + Callable, List, NamedTuple, Optional, @@ -100,51 +101,136 @@ class ImageLoadFilter(enum.StrEnum): 
"""Include every customized images filed at the system. Effective only for superadmin. CUSTOMIZED and CUSTOMIZED_GLOBAL are mutually exclusive.""" -async def rescan_images( +class RelationLoadingOption(enum.StrEnum): + ALIASES = enum.auto() + ENDPOINTS = enum.auto() + REGISTRY = enum.auto() + + +async def load_all_registries( db: ExtendedAsyncSAEngine, - registry_or_image: str | None = None, - *, - reporter: ProgressReporter | None = None, -) -> None: +) -> dict[str, ContainerRegistryRow]: + join = functools.partial(join_non_empty, sep="/") + async with db.begin_readonly_session() as session: result = await session.execute(sa.select(ContainerRegistryRow)) - latest_registry_config = cast( - dict[str, ContainerRegistryRow], - {row.registry_name: row for row in result.scalars().all()}, - ) + all_registry_config = { + join(row.registry_name, row.project): row for row in result.scalars().all() + } + return cast(dict[str, ContainerRegistryRow], all_registry_config) - # TODO: delete images from registries removed from the previous config? 
- if registry_or_image is None: - # scan all configured registries - registries = latest_registry_config - else: - # find if it's a full image ref of one of configured registries - for registry_name, registry_info in latest_registry_config.items(): - if registry_or_image.startswith(registry_name + "/"): - repo_with_tag = registry_or_image.removeprefix(registry_name + "/") - log.debug( - "running a per-image metadata scan: {}, {}", - registry_name, - repo_with_tag, - ) - scanner_cls = get_container_registry_cls(registry_info) - scanner = scanner_cls(db, registry_name, registry_info) - await scanner.scan_single_ref(repo_with_tag) - return - else: - # treat it as a normal registry name - registry = registry_or_image - try: - registries = {registry: latest_registry_config[registry]} - log.debug("running a per-registry metadata scan") - except KeyError: - raise RuntimeError("It is an unknown registry.", registry) + +async def scan_registries( + db: ExtendedAsyncSAEngine, + registries: dict[str, ContainerRegistryRow], + reporter: Optional[ProgressReporter] = None, +) -> None: + """ + Performs an image rescan for all images in the registries. + """ async with aiotools.TaskGroup() as tg: - for registry_name, registry_info in registries.items(): + for registry_key, registry_row in registries.items(): + registry_name = ImageRef.parse_image_str(registry_key, "*").registry log.info('Scanning kernel images from the registry "{0}"', registry_name) - scanner_cls = get_container_registry_cls(registry_info) - scanner = scanner_cls(db, registry_name, registry_info) + + scanner_cls = get_container_registry_cls(registry_row) + scanner = scanner_cls(db, registry_name, registry_row) tg.create_task(scanner.rescan_single_registry(reporter)) + + +async def scan_single_image( + db: ExtendedAsyncSAEngine, + registry_key: str, + registry_row: ContainerRegistryRow, + image_canonical: str, +) -> None: + """ + Performs a scan for a single image. 
+ """ + registry_name = ImageRef.parse_image_str(registry_key, "*").registry + image_name = image_canonical.removeprefix(registry_name + "/") + + log.debug("running a per-image metadata scan: {}, {}", registry_name, image_name) + + scanner_cls = get_container_registry_cls(registry_row) + scanner = scanner_cls(db, registry_name, registry_row) + await scanner.scan_single_ref(image_name) + + +def filter_registry_dict( + registries: dict[str, ContainerRegistryRow], + condition: Callable[[str, ContainerRegistryRow], bool], +) -> dict[str, ContainerRegistryRow]: + return { + registry_key: registry_row + for registry_key, registry_row in registries.items() + if condition(registry_key, registry_row) + } + + +def filter_registries_by_img_canonical( + registries: dict[str, ContainerRegistryRow], registry_or_image: str +) -> dict[str, ContainerRegistryRow]: + """ + Filters the matching registry assuming `registry_or_image` is an image canonical name. + """ + return filter_registry_dict( + registries, + lambda registry_key, _row: registry_or_image.startswith(registry_key + "/"), + ) + + +def filter_registries_by_registry_name( + registries: dict[str, ContainerRegistryRow], registry_or_image: str +) -> dict[str, ContainerRegistryRow]: + """ + Filters the matching registry assuming `registry_or_image` is a registry name. + """ + return filter_registry_dict( + registries, + lambda registry_key, _row: registry_key.startswith(registry_or_image), + ) + + +async def rescan_images( + db: ExtendedAsyncSAEngine, + registry_or_image: Optional[str] = None, + *, + reporter: Optional[ProgressReporter] = None, +) -> None: + """ + Performs an image rescan and updates the database. + Refer to the comments below for details on the function's behavior. + + If registry name is provided for `registry_or_image`, scans all images in the specified registry. + If image canonical name is provided for `registry_or_image`, only scan the image. 
+ If the `registry_or_image` is not provided, scan all configured registries. + """ + all_registry_config = await load_all_registries(db) + + if registry_or_image is None: + await scan_registries(db, all_registry_config, reporter=reporter) + return + + matching_registries = filter_registries_by_img_canonical(all_registry_config, registry_or_image) + + if matching_registries: + if len(matching_registries) > 1: + raise RuntimeError( + "ContainerRegistryRows exist with the same registry_name and project!", + ) + + registry_key, registry_row = next(iter(matching_registries.items())) + await scan_single_image(db, registry_key, registry_row, registry_or_image) + return + + matching_registries = filter_registries_by_registry_name(all_registry_config, registry_or_image) + + if not matching_registries: + raise RuntimeError("It is an unknown registry.", registry_or_image) + + log.debug("running a per-registry metadata scan") + await scan_registries(db, matching_registries, reporter=reporter) # TODO: delete images removed from registry? @@ -256,7 +342,9 @@ async def from_alias( cls, session: AsyncSession, alias: str, - load_aliases=False, + load_aliases: bool = False, + *, + loading_options: Iterable[RelationLoadingOption] = tuple(), ) -> ImageRow: query = ( sa.select(ImageRow) diff --git a/tests/client/integration/test_image.py b/tests/client/integration/test_image.py index 978cd8fa736..08c64be7a5d 100644 --- a/tests/client/integration/test_image.py +++ b/tests/client/integration/test_image.py @@ -29,11 +29,6 @@ async def test_list_images_by_user(userconfig): assert "hash" in image -# This is invasive... 
-# async def test_rescan_images(): -# pass - - @pytest.mark.asyncio async def test_alias_dealias_image_by_admin(): with Session() as sess: diff --git a/tests/manager/models/test_image.py b/tests/manager/models/test_image.py new file mode 100644 index 00000000000..2e928b15c50 --- /dev/null +++ b/tests/manager/models/test_image.py @@ -0,0 +1,219 @@ +import asyncio +import json + +import attr +import pytest +import sqlalchemy as sa +from aiohttp import web +from aioresponses import aioresponses +from graphene import Schema +from graphene.test import Client + +from ai.backend.common.events import BgtaskDoneEvent, EventDispatcher +from ai.backend.common.types import AgentId +from ai.backend.manager.api.context import RootContext +from ai.backend.manager.models.gql import GraphQueryContext, Mutations, Queries +from ai.backend.manager.models.image import ImageRow +from ai.backend.manager.models.utils import ExtendedAsyncSAEngine +from ai.backend.manager.server import ( + background_task_ctx, + database_ctx, + event_dispatcher_ctx, + hook_plugin_ctx, + monitoring_ctx, + redis_ctx, + shared_config_ctx, +) + + +@pytest.fixture(scope="module") +def client() -> Client: + return Client(Schema(query=Queries, mutation=Mutations, auto_camelcase=False)) + + +def get_graphquery_context( + background_task_manager, database_engine: ExtendedAsyncSAEngine +) -> GraphQueryContext: + return GraphQueryContext( + schema=None, # type: ignore + dataloader_manager=None, # type: ignore + local_config=None, # type: ignore + shared_config=None, # type: ignore + etcd=None, # type: ignore + user={"domain": "default", "role": "superadmin"}, + access_key="AKIAIOSFODNN7EXAMPLE", + db=database_engine, # type: ignore + redis_stat=None, # type: ignore + redis_image=None, # type: ignore + redis_live=None, # type: ignore + manager_status=None, # type: ignore + known_slot_types=None, # type: ignore + background_task_manager=background_task_manager, # type: ignore + storage_manager=None, # type: ignore + 
registry=None, # type: ignore + idle_checker_host=None, # type: ignore + ) + + +FIXTURES_REGISTRIES = [ + { + "container_registries": [ + { + "id": "00000000-0000-0000-0000-000000000000", + "url": "http://mock_registry", + "type": "docker", + "project": "lablup", + "registry_name": "mock_registry", + } + ] + } +] + + +@pytest.mark.asyncio +@pytest.mark.timeout(60) +@pytest.mark.parametrize("extra_fixtures", FIXTURES_REGISTRIES) +@pytest.mark.parametrize( + "test_case", + [ + { + "mock_dockerhub_responses": { + "get_token": {"token": "fake-token"}, + "get_catalog": {"repositories": ["lablup/python"]}, + "get_tags": {"tags": ["latest"]}, + "get_manifest": { + "schemaVersion": 2, + "mediaType": "application/vnd.docker.distribution.manifest.v2+json", + "config": { + "mediaType": "application/vnd.docker.container.image.v1+json", + "size": 100, + "digest": "sha256:1111111111111111111111111111111111111111111111111111111111111111", + }, + "layers": [], + }, + "get_config": { + "architecture": "amd64", + "os": "linux", + }, + } + } + ], +) +async def test_image_rescan( + client: Client, + test_case, + etcd_fixture, + extra_fixtures, + database_fixture, + create_app_and_client, +): + app, _ = await create_app_and_client( + [ + shared_config_ctx, + database_ctx, + monitoring_ctx, + hook_plugin_ctx, + redis_ctx, + event_dispatcher_ctx, + background_task_ctx, + ], + [".events", ".auth"], + ) + root_ctx: RootContext = app["_root.context"] + dispatcher: EventDispatcher = root_ctx.event_dispatcher + done_handler_ctx = {} + done_event = asyncio.Event() + + async def done_sub( + context: web.Application, + source: AgentId, + event: BgtaskDoneEvent, + ) -> None: + done_handler_ctx["event_name"] = event.name + update_body = attr.asdict(event) # type: ignore + done_handler_ctx.update(**update_body) + done_event.set() + + dispatcher.subscribe(BgtaskDoneEvent, app, done_sub) + + mock_dockerhub_responses = test_case["mock_dockerhub_responses"] + + def setup_dockerhub_mocking(mocked): + 
registry_url = extra_fixtures["container_registries"][0]["url"] + + # /v2/ endpoint + mocked.get( + f"{registry_url}/v2/", + status=200, + payload=mock_dockerhub_responses["get_tags"], + repeat=True, + ) + + # catalog + mocked.get( + f"{registry_url}/v2/_catalog?n=30", + status=200, + payload=mock_dockerhub_responses["get_catalog"], + ) + + # tags + mocked.get( + f"{registry_url}/v2/lablup/python/tags/list?n=10", + status=200, + payload=mock_dockerhub_responses["get_tags"], + ) + + # manifest + mocked.get( + f"{registry_url}/v2/lablup/python/manifests/latest", + status=200, + payload=mock_dockerhub_responses["get_manifest"], + headers={ + "Content-Type": "application/vnd.docker.distribution.manifest.v2+json", + }, + ) + + config_data = mock_dockerhub_responses["get_manifest"]["config"] + image_digest = config_data["digest"] + + # config blob (JSON) + mocked.get( + f"{registry_url}/v2/lablup/python/blobs/{image_digest}", + status=200, + body=json.dumps(mock_dockerhub_responses["get_config"]).encode("utf-8"), + payload=mock_dockerhub_responses["get_config"], + repeat=True, + ) + + with aioresponses() as mocked: + setup_dockerhub_mocking(mocked) + + context = get_graphquery_context(root_ctx.background_task_manager, root_ctx.db) + image_rescan_query = """ + mutation ($registry: String!) { + rescan_images(registry: $registry) { + ok + msg + task_id + } + } + """ + variables = { + "registry": "mock_registry", + } + + res = await client.execute_async(image_rescan_query, context=context, variables=variables) + assert res["data"]["rescan_images"]["ok"] + + await done_event.wait() + # Even if the response value is ok: true, the rescan background task might have failed. + # So we need to separately verify whether the actual task was successful. 
+ assert str(done_handler_ctx["task_id"]) == res["data"]["rescan_images"]["task_id"] + + async with root_ctx.db.begin_readonly_session() as db_session: + target_registry_id = extra_fixtures["container_registries"][0]["id"] + res = await db_session.execute( + sa.select(sa.exists().where(ImageRow.registry_id == target_registry_id)) + ) + image_row_populated = res.scalar() + assert image_row_populated diff --git a/tests/manager/test_image.py b/tests/manager/test_image.py index 23e5a5cad12..678efdd38f8 100644 --- a/tests/manager/test_image.py +++ b/tests/manager/test_image.py @@ -1,70 +1,10 @@ -import uuid from pathlib import Path import pytest -from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import sessionmaker - -from ai.backend.manager.models import ImageAliasRow, ImageRow -from ai.backend.manager.models.base import metadata as old_metadata -from ai.backend.manager.models.utils import regenerate_table column_keys = ["nullable", "index", "unique", "primary_key"] -@pytest.fixture -async def virtual_image_db(): - engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=True) - base = declarative_base() - metadata = base.metadata - - regenerate_table(old_metadata.tables["images"], metadata) - regenerate_table(old_metadata.tables["image_aliases"], metadata) - ImageAliasRow.metadata = metadata - ImageRow.metadata = metadata - async_session = sessionmaker(engine, class_=AsyncSession, autoflush=False) - async with engine.begin() as conn: - await conn.run_sync(metadata.create_all) - await conn.commit() - async with async_session() as session: - image_1 = ImageRow( - name="index.docker.io/lablup/test-python:latest", - architecture="x86_64", - registry_id=uuid.UUID("00000000-0000-0000-0000-000000000000"), - registry="index.docker.io", - image="lablup/test-python", - tag="latest", - config_digest="sha256:2d577a600afe2d1b38d78bc2ee5abe3bd350890d0652e48096249694e074f9c3", - 
size_bytes=123123123, - type="COMPUTE", - accelerators="", - labels={}, - resources={}, - ) - image_1.id = uuid.uuid4() - image_2 = ImageRow( - name="index.docker.io/lablup/test-python:3.6-debian", - architecture="aarch64", - registry_id=uuid.UUID("00000000-0000-0000-0000-000000000000"), - registry="index.docker.io", - image="lablup/test-python", - tag="3.6-debian", - config_digest="sha256:2d577a600afe2d1b38d78bc2ee5abe3bd350890d0652e48096249694e074f9c3", - size_bytes=123123123, - type="COMPUTE", - accelerators="", - labels={}, - resources={}, - ) - image_2.id = uuid.uuid4() - session.add(image_1) - session.add(image_2) - await session.commit() - yield async_session - await engine.dispose() - - @pytest.fixture async def image_aliases(tmpdir): content = """