Skip to content

Commit

Permalink
refactor(BA-462): Container registry project traversal logic of image…
Browse files Browse the repository at this point in the history
… rescanning (#2979) (#3472)

Co-authored-by: Gyubong Lee <jopemachine@naver.com>
  • Loading branch information
lablup-octodog and jopemachine authored Jan 16, 2025
1 parent f57f392 commit 15c7530
Show file tree
Hide file tree
Showing 14 changed files with 505 additions and 262 deletions.
1 change: 1 addition & 0 deletions changes/2979.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Refactor the container registry project traversal logic used in image rescanning.
7 changes: 5 additions & 2 deletions src/ai/backend/common/docker.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from typing import (
Final,
Iterable,
Literal,
Mapping,
NamedTuple,
Optional,
Expand Down Expand Up @@ -437,7 +438,9 @@ def parse_image_tag(
return image_str, tag

@classmethod
def parse_image_str(cls, image_str: str, registry: str | None = None) -> ParsedImageStr:
def parse_image_str(
cls, image_str: str, registry: str | Literal["*"] | None = None
) -> ParsedImageStr:
"""
Parses a string representing an image.
Expand All @@ -458,7 +461,7 @@ def parse_image_str(cls, image_str: str, registry: str | None = None) -> ParsedI
if "://" in image_str or image_str.startswith("//"):
raise InvalidImageName(image_str)

def divide_parts(image_str: str, registry: str | None) -> tuple[str, str]:
def divide_parts(image_str: str, registry: str | Literal["*"] | None) -> tuple[str, str]:
if "/" not in image_str:
return (default_registry, image_str)

Expand Down
13 changes: 10 additions & 3 deletions src/ai/backend/manager/container_registry/aws_ecr.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import logging
from typing import AsyncIterator
from typing import AsyncIterator, override

import aiohttp
import boto3

from ai.backend.logging import BraceStyleAdapter
from ai.backend.manager.exceptions import ContainerRegistryProjectEmpty

from .base import (
BaseContainerRegistry,
Expand All @@ -14,10 +15,14 @@


class AWSElasticContainerRegistry(BaseContainerRegistry):
@override
async def fetch_repositories(
self,
sess: aiohttp.ClientSession,
) -> AsyncIterator[str]:
if not self.registry_info.project:
raise ContainerRegistryProjectEmpty(self.registry_info.type, self.registry_info.project)

access_key, secret_access_key, region, type_ = (
self.registry_info.extra.get("access_key"),
self.registry_info.extra.get("secret_access_key"),
Expand All @@ -43,10 +48,12 @@ async def fetch_repositories(
for repo in response["repositories"]:
match type_:
case "ecr":
yield repo["repositoryName"]
if repo["repositoryName"].startswith(self.registry_info.project):
yield repo["repositoryName"]
case "ecr-public":
registry_alias = (repo["repositoryUri"].split("/"))[1]
yield f"{registry_alias}/{repo['repositoryName']}"
if self.registry_info.project == registry_alias:
yield f"{registry_alias}/{repo['repositoryName']}"
case _:
raise ValueError(f"Unknown registry type: {type_}")

Expand Down
87 changes: 35 additions & 52 deletions src/ai/backend/manager/container_registry/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import sqlalchemy as sa
import trafaret as t
import yarl
from sqlalchemy.orm import load_only

from ai.backend.common.bgtask import ProgressReporter
from ai.backend.common.docker import (
Expand Down Expand Up @@ -147,61 +146,45 @@ async def commit_rescan_result(self) -> None:
image_row.resources = update["resources"]
image_row.is_local = is_local

registries = cast(
list[ContainerRegistryRow],
(
await session.scalars(
sa.select(ContainerRegistryRow).options(
load_only(
ContainerRegistryRow.project,
ContainerRegistryRow.registry_name,
ContainerRegistryRow.url,
)
)
)
).all(),
)

for image_identifier, update in _all_updates.items():
for registry in registries:
try:
parsed_img = ImageRef.from_image_str(
image_identifier.canonical, registry.project, registry.registry_name
)
except ProjectMismatchWithCanonical:
continue
except ValueError as e:
skip_reason = str(e)
progress_msg = f"Skipped image - {image_identifier.canonical}/{image_identifier.architecture} ({skip_reason})"
log.warning(progress_msg)
break

session.add(
ImageRow(
name=parsed_img.canonical,
project=registry.project,
registry=parsed_img.registry,
registry_id=registry.id,
image=join_non_empty(parsed_img.project, parsed_img.name, sep="/"),
tag=parsed_img.tag,
architecture=image_identifier.architecture,
is_local=is_local,
config_digest=update["config_digest"],
size_bytes=update["size_bytes"],
type=ImageType.COMPUTE,
accelerators=update.get("accels"),
labels=update["labels"],
resources=update["resources"],
)
try:
parsed_img = ImageRef.from_image_str(
image_identifier.canonical,
self.registry_info.project,
self.registry_info.registry_name,
is_local=is_local,
)
progress_msg = f"Updated image - {parsed_img.canonical}/{image_identifier.architecture} ({update['config_digest']})"
log.info(progress_msg)
break

else:
skip_reason = "No container registry found matching the image."
except ProjectMismatchWithCanonical:
continue
except ValueError as e:
skip_reason = str(e)
progress_msg = f"Skipped image - {image_identifier.canonical}/{image_identifier.architecture} ({skip_reason})"
log.warning(progress_msg)
if (reporter := progress_reporter.get()) is not None:
await reporter.update(1, message=progress_msg)

continue

session.add(
ImageRow(
name=parsed_img.canonical,
project=self.registry_info.project,
registry=parsed_img.registry,
registry_id=self.registry_info.id,
image=join_non_empty(parsed_img.project, parsed_img.name, sep="/"),
tag=parsed_img.tag,
architecture=image_identifier.architecture,
is_local=is_local,
config_digest=update["config_digest"],
size_bytes=update["size_bytes"],
type=ImageType.COMPUTE,
accelerators=update.get("accels"),
labels=update["labels"],
resources=update["resources"],
)
)
progress_msg = f"Updated image - {parsed_img.canonical}/{image_identifier.architecture} ({update['config_digest']})"
log.info(progress_msg)

if (reporter := progress_reporter.get()) is not None:
await reporter.update(1, message=progress_msg)
Expand Down
11 changes: 9 additions & 2 deletions src/ai/backend/manager/container_registry/docker.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import json
import logging
from typing import AsyncIterator, Optional, cast
from typing import AsyncIterator, Optional, cast, override

import aiohttp
import typing_extensions
import yarl

from ai.backend.common.docker import login as registry_login
from ai.backend.logging import BraceStyleAdapter
from ai.backend.manager.exceptions import ContainerRegistryProjectEmpty

from .base import BaseContainerRegistry

Expand All @@ -18,6 +19,7 @@ class DockerHubRegistry(BaseContainerRegistry):
@typing_extensions.deprecated(
"Rescanning a whole Docker Hub account is disabled due to the API rate limit."
)
@override
async def fetch_repositories(
self,
sess: aiohttp.ClientSession,
Expand Down Expand Up @@ -63,10 +65,14 @@ async def fetch_repositories_legacy(


class DockerRegistry_v2(BaseContainerRegistry):
@override
async def fetch_repositories(
self,
sess: aiohttp.ClientSession,
) -> AsyncIterator[str]:
if not self.registry_info.project:
raise ContainerRegistryProjectEmpty(self.registry_info.type, self.registry_info.project)

# The credential should have the catalog search privilege.
rqst_args = await registry_login(
sess,
Expand All @@ -83,7 +89,8 @@ async def fetch_repositories(
if resp.status == 200:
data = json.loads(await resp.read())
for item in data["repositories"]:
yield item
if item.startswith(self.registry_info.project):
yield item
log.debug("found {} repositories", len(data["repositories"]))
else:
log.warning(
Expand Down
18 changes: 10 additions & 8 deletions src/ai/backend/manager/container_registry/github.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import logging
from typing import AsyncIterator
from typing import AsyncIterator, override

import aiohttp

from ai.backend.logging import BraceStyleAdapter
from ai.backend.manager.exceptions import ContainerRegistryProjectEmpty

from .base import (
BaseContainerRegistry,
Expand All @@ -13,18 +14,19 @@


class GitHubRegistry(BaseContainerRegistry):
async def fetch_repositories(
self,
sess: aiohttp.ClientSession,
) -> AsyncIterator[str]:
username = self.registry_info.username
@override
async def fetch_repositories(self, sess: aiohttp.ClientSession) -> AsyncIterator[str]:
if not self.registry_info.project:
raise ContainerRegistryProjectEmpty(self.registry_info.type, self.registry_info.project)

project = self.registry_info.project
access_token = self.registry_info.password
entity_type = self.registry_info.extra.get("entity_type", None)

if entity_type is None:
raise RuntimeError("Entity type is not provided for GitHub registry!")

base_url = f"https://api.github.com/{entity_type}/{username}/packages"
base_url = f"https://api.github.com/{entity_type}/{project}/packages"

headers = {
"Authorization": f"Bearer {access_token}",
Expand All @@ -41,7 +43,7 @@ async def fetch_repositories(
if response.status == 200:
data = await response.json()
for repo in data:
yield f"{username}/{repo['name']}"
yield f"{project}/{repo['name']}"
if "next" in response.links:
page += 1
else:
Expand Down
74 changes: 33 additions & 41 deletions src/ai/backend/manager/container_registry/gitlab.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import logging
import urllib.parse
from typing import AsyncIterator, cast
from typing import AsyncIterator, override

import aiohttp
import sqlalchemy as sa

from ai.backend.logging import BraceStyleAdapter
from ai.backend.manager.models.container_registry import ContainerRegistryRow
from ai.backend.manager.exceptions import ContainerRegistryProjectEmpty

from .base import (
BaseContainerRegistry,
Expand All @@ -16,49 +15,42 @@


class GitLabRegistry(BaseContainerRegistry):
@override
async def fetch_repositories(self, sess: aiohttp.ClientSession) -> AsyncIterator[str]:
if not self.registry_info.project:
raise ContainerRegistryProjectEmpty(self.registry_info.type, self.registry_info.project)

access_token = self.registry_info.password
api_endpoint = self.registry_info.extra.get("api_endpoint", None)

if api_endpoint is None:
raise RuntimeError('"api_endpoint" is not provided for GitLab registry!')

async with self.db.begin_readonly_session() as db_sess:
result = await db_sess.execute(
sa.select(ContainerRegistryRow.project).where(
ContainerRegistryRow.registry_name == self.registry_info.registry_name
)
)
projects = cast(list[str], result.scalars().all())

for project in projects:
encoded_project_id = urllib.parse.quote(project, safe="")
repo_list_url = (
f"{api_endpoint}/api/v4/projects/{encoded_project_id}/registry/repositories"
)

headers = {
"Accept": "application/json",
"PRIVATE-TOKEN": access_token,
}
page = 1

while True:
async with sess.get(
repo_list_url,
headers=headers,
params={"per_page": 30, "page": page},
) as response:
if response.status == 200:
data = await response.json()

for repo in data:
yield repo["path"]
if "next" in response.headers.get("Link", ""):
page += 1
else:
break
encoded_project_id = urllib.parse.quote(self.registry_info.project, safe="")
repo_list_url = f"{api_endpoint}/api/v4/projects/{encoded_project_id}/registry/repositories"

headers = {
"Accept": "application/json",
"PRIVATE-TOKEN": access_token,
}
page = 1

while True:
async with sess.get(
repo_list_url,
headers=headers,
params={"per_page": 30, "page": page},
) as response:
if response.status == 200:
data = await response.json()

for repo in data:
yield repo["path"]
if "next" in response.headers.get("Link", ""):
page += 1
else:
raise RuntimeError(
f"Failed to fetch repositories for project {project}! {response.status} error occurred."
)
break
else:
raise RuntimeError(
f"Failed to fetch repositories for project {self.registry_info.project}! {response.status} error occurred."
)
Loading

0 comments on commit 15c7530

Please sign in to comment.