Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for auth tokens. #336

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 113 additions & 13 deletions runners/mlcube_singularity/mlcube_singularity/singularity_client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import base64
import json
import logging
import platform
import typing as t
Expand Down Expand Up @@ -63,6 +65,31 @@ def __init__(
self.tag: t.Optional[str] = tag
self.digest: t.Optional[str] = digest

def resolve_host(self) -> str:
"""Return `canonical` host (registry URL without protocol).
Returns:
Canonical registry URL, examples are `docker.io`, `nvcr.io`, `docker.synapse.org`.
"""
if not self.host or self.host == "docker.io":
return "docker.io"
host: str = self.host[8:] if self.host.startswith("https://") else self.host
return host

def resolve_registry_url(self) -> str:
"""Return registry URL.

Return `https://registry-1.docker.io` for docker hub or https for canonical host name.
"""
host: str = self.resolve_host()
return (
"https://registry-1.docker.io" if host == "docker.io" else f"https://{host}"
)

def resolve_auths_url(self) -> str:
"""Return possible key in `auths` section of docker's JSON config file."""
host: str = self.resolve_host()
return "https://index.docker.io/v1/" if host == "docker.io" else host

def __str__(self) -> str:
name: str = ""
if self.host:
Expand Down Expand Up @@ -451,16 +478,14 @@ def get_manifest(self, image: t.Union[str, DockerImage]) -> t.Dict:
image,
)

if not image.host or image.host == "docker.io":
repository_url = "https://registry-1.docker.io"
else:
repository_url = image.host
if not repository_url.startswith("https://"):
repository_url = f"https://{repository_url}"
registry_url: str = image.resolve_registry_url()
auth_key: str = image.resolve_auths_url()

logger.debug(
"DockerHubClient.get_manifest repository_url=%s (image.host=%s).",
repository_url,
"DockerHubClient.get_manifest resolved image host (%s) to registry_url=%s and auth_key=%s.",
image.host,
registry_url,
auth_key,
)

if len(image.path) == 1:
Expand All @@ -473,7 +498,7 @@ def get_manifest(self, image: t.Union[str, DockerImage]) -> t.Dict:
name: str = "/".join(image.path)
reference: str = (image.digest or image.tag) or "latest"

url = f"{repository_url}/v2/{name}/manifests/{reference}"
url = f"{registry_url}/v2/{name}/manifests/{reference}"
headers = {
"Accept": "application/vnd.docker.distribution.manifest.v2+json," # single-arch image
"application/vnd.oci.image.index.v1+json," # multi-arch image
Expand All @@ -487,7 +512,7 @@ def get_manifest(self, image: t.Union[str, DockerImage]) -> t.Dict:
response.headers,
)
token = _get_authentication_token(
response.headers.get("www-authenticate", None)
response.headers.get("www-authenticate", None), auth_key
)
headers["Authorization"] = f"Bearer {token}"
response = requests.get(url, headers=headers)
Expand Down Expand Up @@ -529,12 +554,13 @@ def get_manifest(self, image: t.Union[str, DockerImage]) -> t.Dict:
return response


def _get_authentication_token(www_authenticate: t.Optional[str]) -> str:
def _get_authentication_token(www_authenticate: t.Optional[str], auth_key: str) -> str:
"""Retrieve bearer authentication token.

Args:
www_authenticate: A string that contains endpoint details where token must be requested. Must start with
`Bearer`: `Bearer realm="https://nvcr.io/proxy_auth",scope="repository:nvidia/pytorch:pull,push"`.
auth_key: Docker registry key in `auths` dictionary (for instance, in ~/.docker/config.json).

Returns:
Authentication token that can be used with docker registry API.
Expand All @@ -552,19 +578,29 @@ def _get_authentication_token(www_authenticate: t.Optional[str]) -> str:
)

url: t.Optional[str] = parsed.pop("realm", None)
loggable_url: t.Optional[str] = url

auth_token: t.Optional[str] = _get_auth_token(auth_key)
if auth_token:
logger.info("_get_authentication_token using auth token from config file.")
if url.startswith("https://"):
url = url[8:]
loggable_url = f"https://***:***@{url}"
url = "https://" + base64.b64decode(auth_token).decode() + "@" + url

if not url:
raise MLCubeError(
f"_get_authentication_token unrecognized www_authenticate format (www_authenticate={www_authenticate}, "
f"parsed={parsed})."
)
logger.debug(
"_get_authentication_token requesting token at %s for %s.", url, parsed
"_get_authentication_token requesting token at %s for %s.", loggable_url, parsed
)

response = requests.get(url, params=parsed)
if response.status_code != 200:
raise MLCubeError(
f"_get_authentication_token could not retrieve authentication token (url={url}, params={parsed}, "
f"_get_authentication_token could not retrieve authentication token (url={loggable_url}, params={parsed}, "
f"status_code={response.status_code}, content={response.json()}, headers={response.headers})"
)
token = response.json()["token"]
Expand Down Expand Up @@ -612,6 +648,70 @@ def _select_manifest(
return None, (system, arch)


def _get_auth_token(auth_key: str) -> t.Optional[str]:
"""Return (if found) auth token for docker registry.

In order for auth token to be present, users need to log in (e.g. `docker login nvcr.io`). It is assumed that
these tokens (which are base64-encoded "USER:PASSWORD" strings) are stored unencrypted in one of JSON config files
(e.g., ~/.docker/config.json).

Args:
auth_key: A key under which this function searches for the auth token (see `DockerImage.resolve_auths_url`).

Returns:
Auth token if found, else None.
"""
config_files: t.List[Path] = [
Path.home() / ".singularity" / "docker-config.json",
Path.home() / ".apptainer" / "docker-config.json",
Path.home() / ".docker" / "config.json",
]

for config_file in config_files:
config_file = config_file.expanduser().resolve().absolute()

if not config_file.is_file():
logger.debug("_get_auth_token %s does not exist.", config_file.as_posix())
continue

with open(config_file, "rt") as fp:
config: t.Dict = json.load(fp)
if not (
isinstance(config, dict) and isinstance(config.get("auths", None), dict)
):
logger.debug(
"_get_auth_token % exists but content is unsupported.",
config_file.as_posix(),
)
continue

auths: t.Dict = config["auths"]

def _get_auth(_config_file: Path, _key: str) -> t.Optional[str]:
if _key in auths:
if isinstance(auths[_key], dict) and "auth" in auths[_key]:
logger.info(
"%s contains auth token for %s", _config_file.as_posix(), _key
)
return auths[_key]["auth"]
return None

auth: t.Optional[str] = _get_auth(config_file, auth_key)
if not auth:
if auth_key.startswith("https://"):
auth = _get_auth(config_file, auth_key[8:])
else:
auth = _get_auth(config_file, "https://" + auth_key)
if auth:
return auth

logger.debug(
"%s does not contain auth token for %s", config_file.as_posix(), auth_key
)

return None


def parse_key_value_string(kv_str: str) -> t.Dict:
"""Parse key-value string into dictionaries.

Expand Down