Refactor backend artifacts implementation (#191)
* Use common backend logic for listing artifacts

* Refactor aws backend to use common artifact upload/download logic

* Implement artifacts upload/download for local backend

Closes #173.

* Implement signed upload/download urls for aws
r4victor authored Feb 15, 2023
1 parent 7ad54c2 · commit dd63876
Showing 21 changed files with 423 additions and 414 deletions.
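The first bullet of the commit message is the heart of the refactor: listing, uploading, and downloading artifacts now live once in dstack.backend.base.artifacts and operate against a storage abstraction, so each backend (AWS, local) only has to supply a Storage implementation. The shared module itself is among the files not shown on this page, so the following is only a hedged sketch of the shape implied by the calls visible in the diff; every name marked "assumed" or "hypothetical" below is an illustration, not the repository's actual code.

# Hedged sketch of dstack/backend/base/artifacts.py. The diff below confirms
# only the call signature:
#     list_run_artifact_files(storage, repo_address, run_name) -> List[Artifact]
from typing import List

from dstack.core.artifact import Artifact  # assumed module path
from dstack.core.repo import RepoAddress


def list_run_artifact_files(storage, repo_address: RepoAddress, run_name: str) -> List[Artifact]:
    # Hypothetical key scheme: artifacts stored under a per-repo, per-run prefix.
    prefix = f"artifacts/{repo_address.path()}/{run_name}/"  # path() is hypothetical
    files = storage.list_files(prefix)  # returns StorageFile objects, per storage.py below
    # Artifact is known (from the api/artifacts.py diff) to carry .name and .file.
    return [Artifact(name=f.filepath.split("/", 1)[0], file=f.filepath) for f in files]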
12 changes: 6 additions & 6 deletions cli/dstack/api/artifacts.py
@@ -13,18 +13,18 @@ def list_artifacts_with_merged_backends(
 ) -> List[Tuple[Artifact, List[Backend]]]:
     artifacts = list_artifacts(backends, repo_address, run_name)
 
-    artifact_name_file_to_artifact_map = {(a.name, a.file): a for a, _ in artifacts}
+    artifact_name_to_artifact_map = {a.name: a for a, _ in artifacts}
 
-    artifact_name_file_to_backends_map = defaultdict(list)
+    artifact_name_to_backends_map = defaultdict(list)
     for artifact, backend in artifacts:
-        artifact_name_file_to_backends_map[(artifact.name, artifact.file)].append(backend)
+        artifact_name_to_backends_map[artifact.name].append(backend)
 
     artifacts_with_merged_backends = []
-    for artifact_name_file in artifact_name_file_to_artifact_map:
+    for artifact_name in artifact_name_to_artifact_map:
         artifacts_with_merged_backends.append(
             (
-                artifact_name_file_to_artifact_map[artifact_name_file],
-                artifact_name_file_to_backends_map[artifact_name_file],
+                artifact_name_to_artifact_map[artifact_name],
+                artifact_name_to_backends_map[artifact_name],
             )
         )
     return artifacts_with_merged_backends
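The keying change above is the whole point of this file's diff: artifacts are now merged across backends by name alone, where previously they merged only when both name and file matched, which kept per-file duplicates apart. A self-contained toy run of the same logic with stand-in types (dstack's real Artifact and Backend classes are not reproduced here):

from collections import defaultdict
from dataclasses import dataclass


@dataclass(frozen=True)
class Artifact:  # stand-in; the real class lives elsewhere in dstack
    name: str
    file: str


# (artifact, backend) pairs as list_artifacts would return them; backends as strings
artifacts = [
    (Artifact("output", "model.pt"), "aws"),
    (Artifact("output", "model.pt"), "local"),
]

artifact_name_to_artifact_map = {a.name: a for a, _ in artifacts}
artifact_name_to_backends_map = defaultdict(list)
for artifact, backend in artifacts:
    artifact_name_to_backends_map[artifact.name].append(backend)

merged = [
    (artifact_name_to_artifact_map[name], artifact_name_to_backends_map[name])
    for name in artifact_name_to_artifact_map
]
print(merged)  # [(Artifact(name='output', file='model.pt'), ['aws', 'local'])]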
47 changes: 23 additions & 24 deletions cli/dstack/backend/aws/__init__.py
@@ -1,15 +1,16 @@
 from pathlib import Path
-from typing import Dict, Generator, List, Optional, Tuple
+from typing import Dict, Generator, List, Optional
 
 import boto3
 from botocore.client import BaseClient
 
-from dstack.backend.aws import artifacts, config, logs, tags
+from dstack.backend.aws import config, logs
 from dstack.backend.aws.compute import AWSCompute
 from dstack.backend.aws.config import AWSConfig
 from dstack.backend.aws.secrets import AWSSecretsManager
 from dstack.backend.aws.storage import AWSStorage
-from dstack.backend.base import RemoteBackend
+from dstack.backend.base import CloudBackend
+from dstack.backend.base import artifacts as base_artifacts
 from dstack.backend.base import jobs as base_jobs
 from dstack.backend.base import repos as base_repos
 from dstack.backend.base import runs as base_runs
@@ -20,19 +21,18 @@
 from dstack.core.error import ConfigError
 from dstack.core.job import Job, JobHead
 from dstack.core.log_event import LogEvent
-from dstack.core.repo import LocalRepoData, RepoAddress, RepoCredentials, RepoHead
+from dstack.core.repo import LocalRepoData, RepoAddress, RepoCredentials
 from dstack.core.run import RunHead
 from dstack.core.secret import Secret
 from dstack.core.tag import TagHead
 
 
-class AwsBackend(RemoteBackend):
+class AwsBackend(CloudBackend):
     @property
     def name(self):
         return "aws"
 
     def __init__(self, backend_config: Optional[BackendConfig] = None):
-        super().__init__(backend_config)
         if backend_config is None:
             self.backend_config = AWSConfig()
             try:
@@ -153,27 +153,20 @@ def poll_logs(
             attached,
         )
 
-    def list_run_artifact_files(
-        self, repo_address: RepoAddress, run_name: str
-    ) -> Generator[Artifact, None, None]:
-        return artifacts.list_run_artifact_files(
-            self._s3_client(), self.backend_config.bucket_name, repo_address, run_name
-        )
+    def list_run_artifact_files(self, repo_address: RepoAddress, run_name: str) -> List[Artifact]:
+        return base_artifacts.list_run_artifact_files(self._storage, repo_address, run_name)
 
     def download_run_artifact_files(
         self,
         repo_address: RepoAddress,
         run_name: str,
         output_dir: Optional[str],
-        output_job_dirs: bool = True,
     ):
-        artifacts.download_run_artifact_files(
-            self._s3_client(),
-            self.backend_config.bucket_name,
-            repo_address,
-            run_name,
-            output_dir,
-            output_job_dirs,
+        base_artifacts.download_run_artifact_files(
+            storage=self._storage,
+            repo_address=repo_address,
+            run_name=run_name,
+            output_dir=output_dir,
         )
 
     def upload_job_artifact_files(
@@ -183,9 +176,8 @@ def upload_job_artifact_files(
         artifact_name: str,
         local_path: Path,
     ):
-        artifacts.upload_job_artifact_files(
-            s3_client=self._s3_client(),
-            bucket_name=self.backend_config.bucket_name,
+        base_artifacts.upload_job_artifact_files(
+            storage=self._storage,
             repo_address=repo_address,
             job_id=job_id,
             artifact_name=artifact_name,
@@ -216,11 +208,12 @@ def add_tag_from_run(
     def add_tag_from_local_dirs(
         self, repo_data: LocalRepoData, tag_name: str, local_dirs: List[str]
     ):
-        tags.create_tag_from_local_dirs(
+        base_tags.create_tag_from_local_dirs(
+            self._storage,
             repo_data,
             tag_name,
             local_dirs,
+            self.type,
         )
 
     def delete_tag_head(self, repo_address: RepoAddress, tag_head: TagHead):
@@ -272,3 +265,9 @@ def delete_secret(self, repo_address: RepoAddress, secret_name: str):
             repo_address,
             secret_name,
         )
+
+    def get_signed_download_url(self, object_key: str) -> str:
+        return self._storage.get_signed_download_url(object_key)
+
+    def get_signed_upload_url(self, object_key: str) -> str:
+        return self._storage.get_signed_upload_url(object_key)
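The two methods added at the bottom implement the commit's last bullet: AwsBackend hands out presigned S3 URLs by delegating to its storage layer. A presigned URL is a plain HTTPS endpoint, so the client that receives it needs no AWS credentials; as a sketch, uploading through the signed URL is an ordinary HTTP PUT (the requests dependency, the backend variable, and the object key here are illustrative, not from the diff):

import requests  # illustrative client-side dependency

# `backend` stands for a configured AwsBackend instance; the key is hypothetical.
upload_url = backend.get_signed_upload_url("artifacts/my-run/model.pt")
with open("model.pt", "rb") as f:
    requests.put(upload_url, data=f).raise_for_status()

# Downloads work symmetrically: a GET against the signed download URL.
download_url = backend.get_signed_download_url("artifacts/my-run/model.pt")
content = requests.get(download_url).content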
157 changes: 0 additions & 157 deletions cli/dstack/backend/aws/artifacts.py

This file was deleted.

59 changes: 56 additions & 3 deletions cli/dstack/backend/aws/storage.py
@@ -1,12 +1,14 @@
-from typing import Dict, List, Optional
+from typing import Callable, Dict, List, Optional
 
 import botocore.exceptions
+from boto3.s3 import transfer
 from botocore.client import BaseClient
 
-from dstack.backend.base.storage import Storage
+from dstack.backend.base.storage import SIGNED_URL_EXPIRATION, CloudStorage
 from dstack.core.storage import StorageFile
 
 
-class AWSStorage(Storage):
+class AWSStorage(CloudStorage):
     def __init__(self, s3_client: BaseClient, bucket_name: str):
         self.s3_client = s3_client
         self.bucket_name = bucket_name
@@ -39,3 +41,54 @@ def list_objects(self, keys_prefix: str) -> List[str]:
             for obj_metadata in response["Contents"]:
                 object_keys.append(obj_metadata["Key"])
         return object_keys
+
+    def list_files(self, dirpath: str) -> List[StorageFile]:
+        prefix = dirpath
+        paginator = self.s3_client.get_paginator("list_objects")
+        page_iterator = paginator.paginate(Bucket=self.bucket_name, Prefix=prefix)
+        files = []
+        for page in page_iterator:
+            for obj in page.get("Contents") or []:
+                if obj["Size"] > 0:
+                    filepath = obj["Key"]
+                    files.append(
+                        StorageFile(
+                            filepath=filepath.removeprefix(prefix),
+                            filesize_in_bytes=obj["Size"],
+                        )
+                    )
+        return files
+
+    def download_file(self, source_path: str, dest_path: str, callback: Callable[[int], None]):
+        downloader = transfer.S3Transfer(
+            self.s3_client, transfer.TransferConfig(), transfer.OSUtils()
+        )
+        downloader.download_file(self.bucket_name, source_path, dest_path, callback=callback)
+
+    def upload_file(self, source_path: str, dest_path: str, callback: Callable[[int], None]):
+        uploader = transfer.S3Transfer(
+            self.s3_client, transfer.TransferConfig(), transfer.OSUtils()
+        )
+        uploader.upload_file(source_path, self.bucket_name, dest_path, callback)
+
+    def get_signed_download_url(self, key: str) -> str:
+        url = self.s3_client.generate_presigned_url(
+            "get_object",
+            Params={
+                "Bucket": self.bucket_name,
+                "Key": key,
+            },
+            ExpiresIn=SIGNED_URL_EXPIRATION,
+        )
+        return url
+
+    def get_signed_upload_url(self, key: str) -> str:
+        url = self.s3_client.generate_presigned_url(
+            "put_object",
+            Params={
+                "Bucket": self.bucket_name,
+                "Key": key,
+            },
+            ExpiresIn=SIGNED_URL_EXPIRATION,
+        )
+        return url
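boto3's S3Transfer invokes the supplied callback repeatedly with the number of bytes moved since the previous call, not a running total, which is what lets the new Callable[[int], None] parameters drive a progress indicator. A minimal sketch of wiring a counter into download_file (the bucket name and object key are illustrative):

import boto3

from dstack.backend.aws.storage import AWSStorage

storage = AWSStorage(boto3.client("s3"), bucket_name="my-dstack-bucket")  # illustrative bucket

transferred = 0


def on_progress(bytes_amount: int) -> None:
    # Accumulate the incremental byte counts reported by S3Transfer.
    global transferred
    transferred += bytes_amount
    print(f"\rdownloaded {transferred} bytes", end="")


storage.download_file("artifacts/my-run/model.pt", "model.pt", on_progress)  # hypothetical key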