diff --git a/sky/adaptors/cloudflare.py b/sky/adaptors/cloudflare.py new file mode 100644 index 00000000000..145665d701b --- /dev/null +++ b/sky/adaptors/cloudflare.py @@ -0,0 +1,121 @@ +"""Cloudflare cloud adaptors""" + +# pylint: disable=import-outside-toplevel + +import functools +import threading +import os + +boto3 = None +botocore = None +_session_creation_lock = threading.RLock() +ACCOUNT_ID_PATH = '~/.cloudflare/accountid' +R2_PROFILE_NAME = 'r2' + + +def import_package(func): + + @functools.wraps(func) + def wrapper(*args, **kwargs): + global boto3, botocore + if boto3 is None or botocore is None: + try: + import boto3 as _boto3 + import botocore as _botocore + boto3 = _boto3 + botocore = _botocore + except ImportError: + raise ImportError('Fail to import dependencies for Cloudflare.' + 'Try pip install "skypilot[aws]"') from None + return func(*args, **kwargs) + + return wrapper + + +# lru_cache() is thread-safe and it will return the same session object +# for different threads. +# Reference: https://docs.python.org/3/library/functools.html#functools.lru_cache # pylint: disable=line-too-long +@functools.lru_cache() +@import_package +def session(): + """Create an AWS session.""" + # Creating the session object is not thread-safe for boto3, + # so we add a reentrant lock to synchronize the session creation. + # Reference: https://github.com/boto/boto3/issues/1592 + # However, the session object itself is thread-safe, so we are + # able to use lru_cache() to cache the session object. + with _session_creation_lock: + return boto3.session.Session(profile_name=R2_PROFILE_NAME) + + +@functools.lru_cache() +@import_package +def resource(resource_name: str, **kwargs): + """Create a Cloudflare resource. + + Args: + resource_name: Cloudflare resource name (e.g., 's3'). + kwargs: Other options. + """ + # Need to use the resource retrieved from the per-thread session + # to avoid thread-safety issues (Directly creating the client + # with boto3.resource() is not thread-safe). + # Reference: https://stackoverflow.com/a/59635814 + + session_ = session() + cloudflare_credentials = session_.get_credentials().get_frozen_credentials() + endpoint = create_endpoint() + + return session_.resource( + resource_name, + endpoint_url=endpoint, + aws_access_key_id=cloudflare_credentials.access_key, + aws_secret_access_key=cloudflare_credentials.secret_key, + region_name='auto', + **kwargs) + + +@functools.lru_cache() +def client(service_name: str, region): + """Create an CLOUDFLARE client of a certain service. + + Args: + service_name: CLOUDFLARE service name (e.g., 's3'). + kwargs: Other options. + """ + # Need to use the client retrieved from the per-thread session + # to avoid thread-safety issues (Directly creating the client + # with boto3.client() is not thread-safe). 
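# Illustrative usage sketch for the adaptor above, assuming boto3 is installed
# and an `r2` profile exists in ~/.aws/credentials. `session()` is wrapped in
# functools.lru_cache(), so repeated calls return the same boto3 Session and
# the RLock only guards the first, non-thread-safe construction.
from sky.adaptors import cloudflare

s1 = cloudflare.session()
s2 = cloudflare.session()
assert s1 is s2                  # lru_cache() hands back the cached Session
assert s1.profile_name == 'r2'   # the Session is bound to the R2 profile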
+ # Reference: https://stackoverflow.com/a/59635814 + + session_ = session() + cloudflare_credentials = session_.get_credentials().get_frozen_credentials() + endpoint = create_endpoint() + + return session_.client( + service_name, + endpoint_url=endpoint, + aws_access_key_id=cloudflare_credentials.access_key, + aws_secret_access_key=cloudflare_credentials.secret_key, + region_name=region) + + +@import_package +def botocore_exceptions(): + """AWS botocore exception.""" + from botocore import exceptions + return exceptions + + +def create_endpoint(): + """Reads accountid necessary to interact with R2""" + + accountid_path = os.path.expanduser(ACCOUNT_ID_PATH) + with open(accountid_path, 'r') as f: + lines = f.readlines() + accountid = lines[0] + + accountid = accountid.strip() + endpoint = 'https://' + accountid + '.r2.cloudflarestorage.com' + + return endpoint diff --git a/sky/cloud_stores.py b/sky/cloud_stores.py index f9197a7960e..879d06ec18f 100644 --- a/sky/cloud_stores.py +++ b/sky/cloud_stores.py @@ -12,7 +12,7 @@ from sky.clouds import gcp from sky.data import data_utils -from sky.adaptors import aws +from sky.adaptors import aws, cloudflare class CloudStorage: @@ -146,6 +146,65 @@ def make_sync_file_command(self, source: str, destination: str) -> str: return ' && '.join(all_commands) +class R2CloudStorage(CloudStorage): + """Cloudflare Cloud Storage.""" + + # List of commands to install AWS CLI + _GET_AWSCLI = [ + 'aws --version >/dev/null 2>&1 || pip3 install awscli', + ] + + def is_directory(self, url: str) -> bool: + """Returns whether R2 'url' is a directory. + + In cloud object stores, a "directory" refers to a regular object whose + name is a prefix of other objects. + """ + r2 = cloudflare.resource('s3') + bucket_name, path = data_utils.split_r2_path(url) + bucket = r2.Bucket(bucket_name) + + num_objects = 0 + for obj in bucket.objects.filter(Prefix=path): + num_objects += 1 + if obj.key == path: + return False + # If there are more than 1 object in filter, then it is a directory + if num_objects == 3: + return True + + # A directory with few or no items + return True + + def make_sync_dir_command(self, source: str, destination: str) -> str: + """Downloads using AWS CLI.""" + # AWS Sync by default uses 10 threads to upload files to the bucket. + # To increase parallelism, modify max_concurrent_requests in your + # aws config file (Default path: ~/.aws/config). 
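# Illustrative expansion of the command assembled below, for a hypothetical
# bucket (the account id is read from ~/.cloudflare/accountid and elided here):
#   aws s3 sync --no-follow-symlinks s3://my-bucket/train /tmp/train \
#       --endpoint https://<accountid>.r2.cloudflarestorage.com --profile=r2
# The sync command is prefixed with the awscli install check from _GET_AWSCLI
# and the pieces are joined with ' && '.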
+ endpoint_url = cloudflare.create_endpoint() + if 'r2://' in source: + source = source.replace('r2://', 's3://') + download_via_awscli = ('aws s3 sync --no-follow-symlinks ' + f'{source} {destination} ' + f'--endpoint {endpoint_url} ' + f'--profile={cloudflare.R2_PROFILE_NAME}') + + all_commands = list(self._GET_AWSCLI) + all_commands.append(download_via_awscli) + return ' && '.join(all_commands) + + def make_sync_file_command(self, source: str, destination: str) -> str: + """Downloads a file using AWS CLI.""" + endpoint_url = cloudflare.create_endpoint() + download_via_awscli = (f'aws s3 cp s3://{source} {destination} ' + f'--endpoint {endpoint_url} ' + f'--profile={cloudflare.R2_PROFILE_NAME}') + + all_commands = list(self._GET_AWSCLI) + all_commands.append(download_via_awscli) + return ' && '.join(all_commands) + + def get_storage_from_path(url: str) -> CloudStorage: """Returns a CloudStorage by identifying the scheme:// in a URL.""" result = urllib.parse.urlsplit(url) @@ -159,4 +218,5 @@ def get_storage_from_path(url: str) -> CloudStorage: _REGISTRY = { 'gs': GcsCloudStorage(), 's3': S3CloudStorage(), + 'r2': R2CloudStorage() } diff --git a/sky/data/data_transfer.py b/sky/data/data_transfer.py index f4ad1b9e5f7..b7fb8f0a8c9 100644 --- a/sky/data/data_transfer.py +++ b/sky/data/data_transfer.py @@ -13,6 +13,7 @@ TODO: - All combinations of Azure Transfer +- All combinations of R2 Transfer - GCS -> S3 """ import json @@ -116,6 +117,21 @@ def s3_to_gcs(s3_bucket_name: str, gs_bucket_name: str) -> None: f'Transfer finished in {(time.time() - start) / 60:.2f} minutes.') +def s3_to_r2(s3_bucket_name: str, r2_bucket_name: str) -> None: + """Creates a one-time transfer from Amazon S3 to Google Cloud Storage. + + Can be viewed from: https://console.cloud.google.com/transfer/cloud + it will block until the transfer is complete. + + Args: + s3_bucket_name: str; Name of the Amazon S3 Bucket + r2_bucket_name: str; Name of the Cloudflare R2 Bucket + """ + raise NotImplementedError('Moving data directly from clouds to R2 is ' + 'currently not supported. Please specify ' + 'a local source for the storage object.') + + def gcs_to_s3(gs_bucket_name: str, s3_bucket_name: str) -> None: """Creates a one-time transfer from Google Cloud Storage to Amazon S3. @@ -129,6 +145,48 @@ def gcs_to_s3(gs_bucket_name: str, s3_bucket_name: str) -> None: subprocess.call(sync_command, shell=True) +def gcs_to_r2(gs_bucket_name: str, r2_bucket_name: str) -> None: + """Creates a one-time transfer from Google Cloud Storage to Amazon S3. + + Args: + gs_bucket_name: str; Name of the Google Cloud Storage Bucket + r2_bucket_name: str; Name of the Cloudflare R2 Bucket + """ + raise NotImplementedError('Moving data directly from clouds to R2 is ' + 'currently not supported. Please specify ' + 'a local source for the storage object.') + + +def r2_to_gcs(r2_bucket_name: str, gs_bucket_name: str) -> None: + """Creates a one-time transfer from Cloudflare R2 to Google Cloud Storage. + + Can be viewed from: https://console.cloud.google.com/transfer/cloud + it will block until the transfer is complete. + + Args: + r2_bucket_name: str; Name of the Cloudflare R2 Bucket + gs_bucket_name: str; Name of the Google Cloud Storage Bucket + """ + raise NotImplementedError('Moving data directly from R2 to clouds is ' + 'currently not supported. Please specify ' + 'a local source for the storage object.') + + +def r2_to_s3(r2_bucket_name: str, s3_bucket_name: str) -> None: + """Creates a one-time transfer from Amazon S3 to Google Cloud Storage. 
+ + Can be viewed from: https://console.cloud.google.com/transfer/cloud + it will block until the transfer is complete. + + Args: + r2_bucket_name: str; Name of the Cloudflare R2 Bucket\ + s3_bucket_name: str; Name of the Amazon S3 Bucket + """ + raise NotImplementedError('Moving data directly from R2 to clouds is ' + 'currently not supported. Please specify ' + 'a local source for the storage object.') + + def _add_bucket_iam_member(bucket_name: str, role: str, member: str) -> None: storage_client = gcp.storage_client() bucket = storage_client.bucket(bucket_name) diff --git a/sky/data/data_utils.py b/sky/data/data_utils.py index 754a5a4bfb1..b1048c3ec83 100644 --- a/sky/data/data_utils.py +++ b/sky/data/data_utils.py @@ -8,7 +8,7 @@ from sky import exceptions from sky import sky_logging -from sky.adaptors import aws, gcp +from sky.adaptors import aws, gcp, cloudflare from sky.utils import ux_utils Client = Any @@ -40,6 +40,18 @@ def split_gcs_path(gcs_path: str) -> Tuple[str, str]: return bucket, key +def split_r2_path(r2_path: str) -> Tuple[str, str]: + """Splits R2 Path into Bucket name and Relative Path to Bucket + + Args: + r2_path: str; R2 Path, e.g. r2://imagenet/train/ + """ + path_parts = r2_path.replace('r2://', '').split('/') + bucket = path_parts.pop(0) + key = '/'.join(path_parts) + return bucket, key + + def create_s3_client(region: str = 'us-east-2') -> Client: """Helper method that connects to Boto3 client for S3 Bucket @@ -73,6 +85,26 @@ def verify_gcs_bucket(name: str) -> bool: return False +def create_r2_client(region: str = 'auto') -> Client: + """Helper method that connects to Boto3 client for R2 Bucket + + Args: + region: str; Region for CLOUDFLARE R2 is set to auto + """ + return cloudflare.client('s3', region) + + +def verify_r2_bucket(name: str) -> bool: + """Helper method that checks if the R2 bucket exists + + Args: + name: str; Name of R2 Bucket (without r2:// prefix) + """ + r2 = cloudflare.resource('s3') + bucket = r2.Bucket(name) + return bucket in r2.buckets.all() + + def is_cloud_store_url(url): result = urllib.parse.urlsplit(url) # '' means non-cloud URLs. @@ -118,7 +150,7 @@ def parallel_upload(source_path_list: List[str], max_concurrent_uploads: Optional[int] = None) -> None: """Helper function to run parallel uploads for a list of paths. - Used by S3Store and GCSStore to run rsync commands in parallel by + Used by S3Store, GCSStore, and R2Store to run rsync commands in parallel by providing appropriate command generators. Args: diff --git a/sky/data/storage.py b/sky/data/storage.py index 629ecba5d6c..8439572bcf5 100644 --- a/sky/data/storage.py +++ b/sky/data/storage.py @@ -12,6 +12,7 @@ from sky import clouds from sky.adaptors import aws from sky.adaptors import gcp +from sky.adaptors import cloudflare from sky.backends import backend_utils from sky.utils import schemas from sky.data import data_transfer @@ -36,6 +37,8 @@ # Clouds with object storage implemented in this module. Azure Blob # Storage isn't supported yet (even though Azure is). 
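# Illustrative sketch of the data_utils helpers added above, with a
# hypothetical bucket name. Only split_r2_path() is pure; the other two calls
# require working R2 credentials (the `r2` profile and the account id file).
from sky.data import data_utils

bucket, key = data_utils.split_r2_path('r2://imagenet/train/images')
assert (bucket, key) == ('imagenet', 'train/images')

r2_client = data_utils.create_r2_client()                 # boto3 S3 client on the R2 endpoint
bucket_exists = data_utils.verify_r2_bucket('imagenet')   # scans the account's buckets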
+# TODO(Doyoung): need to add clouds.CLOUDFLARE() to support +# R2 to be an option as preferred store type STORE_ENABLED_CLOUDS = [clouds.AWS(), clouds.GCP()] # Maximum number of concurrent rsync upload processes @@ -52,6 +55,7 @@ class StoreType(enum.Enum): S3 = 'S3' GCS = 'GCS' AZURE = 'AZURE' + R2 = 'R2' @classmethod def from_cloud(cls, cloud: clouds.Cloud) -> 'StoreType': @@ -70,6 +74,8 @@ def from_store(cls, store: 'AbstractStore') -> 'StoreType': return StoreType.S3 elif isinstance(store, GcsStore): return StoreType.GCS + elif isinstance(store, R2Store): + return StoreType.R2 else: with ux_utils.print_exception_no_traceback(): raise ValueError(f'Unknown store type: {store}') @@ -101,6 +107,9 @@ def get_store_prefix(storetype: StoreType) -> str: return 's3://' elif storetype == StoreType.GCS: return 'gs://' + # R2 storages use 's3://' as a prefix for various aws cli commands + elif storetype == StoreType.R2: + return 's3://' elif storetype == StoreType.AZURE: with ux_utils.print_exception_no_traceback(): raise ValueError('Azure Blob Storage is not supported yet.') @@ -409,6 +418,9 @@ def __init__(self, elif s_type == StoreType.GCS: store = GcsStore.from_metadata(s_metadata, source=self.source) + elif s_type == StoreType.R2: + store = R2Store.from_metadata(s_metadata, + source=self.source) else: with ux_utils.print_exception_no_traceback(): raise ValueError(f'Unknown store type: {s_type}') @@ -445,6 +457,8 @@ def __init__(self, self.add_store(StoreType.S3) elif self.source.startswith('gs://'): self.add_store(StoreType.GCS) + elif self.source.startswith('r2://'): + self.add_store(StoreType.R2) @staticmethod def _validate_source( @@ -454,8 +468,8 @@ def _validate_source( Args: source: str; File path where the data is initially stored. Can be a - local path or a cloud URI (s3://, gs://, etc.). Local paths do not - need to be absolute. + local path or a cloud URI (s3://, gs://, r2:// etc.). + Local paths do not need to be absolute. mode: StorageMode; StorageMode of the storage object Returns: @@ -525,7 +539,7 @@ def _validate_local_source(local_source): 'using a bucket by writing : ' f'{source} in the file_mounts section of your YAML') is_local_source = True - elif split_path.scheme in ['s3', 'gs']: + elif split_path.scheme in ['s3', 'gs', 'r2']: is_local_source = False # Storage mounting does not support mounting specific files from # cloud store - ensure path points to only a directory @@ -541,7 +555,8 @@ def _validate_local_source(local_source): else: with ux_utils.print_exception_no_traceback(): raise exceptions.StorageSourceError( - f'Supported paths: local, s3://, gs://. Got: {source}') + f'Supported paths: local, s3://, gs://, ' + f'r2://. Got: {source}') return source, is_local_source def _validate_storage_spec(self) -> None: @@ -611,7 +626,7 @@ def add_store(self, store_type: Union[str, StoreType]) -> AbstractStore: add it to Storage. 
Args: - store_type: StoreType; Type of the storage [S3, GCS, AZURE] + store_type: StoreType; Type of the storage [S3, GCS, AZURE, R2] """ if isinstance(store_type, str): store_type = StoreType(store_type) @@ -625,6 +640,8 @@ def add_store(self, store_type: Union[str, StoreType]) -> AbstractStore: store_cls = S3Store elif store_type == StoreType.GCS: store_cls = GcsStore + elif store_type == StoreType.R2: + store_cls = R2Store else: with ux_utils.print_exception_no_traceback(): raise exceptions.StorageSpecError( @@ -836,6 +853,13 @@ def _validate(self): assert data_utils.verify_gcs_bucket(self.name), ( f'Source specified as {self.source}, a GCS bucket. ', 'GCS Bucket should exist.') + elif self.source.startswith('r2://'): + assert self.name == data_utils.split_r2_path(self.source)[0], ( + 'R2 Bucket is specified as path, the name should be ' + 'the same as R2 bucket.') + assert data_utils.verify_r2_bucket(self.name), ( + f'Source specified as {self.source}, a R2 bucket. ', + 'R2 Bucket should exist.') def initialize(self): """Initializes the S3 store object on the cloud. @@ -874,6 +898,8 @@ def upload(self): pass elif self.source.startswith('gs://'): self._transfer_to_s3() + elif self.source.startswith('r2://'): + self._transfer_to_s3() else: self.batch_aws_rsync([self.source]) except exceptions.StorageUploadError: @@ -949,6 +975,8 @@ def _transfer_to_s3(self) -> None: assert isinstance(self.source, str), self.source if self.source.startswith('gs://'): data_transfer.gcs_to_s3(self.name, self.name) + elif self.source.startswith('r2://'): + data_transfer.r2_to_s3(self.name, self.name) def _get_bucket(self) -> Tuple[StorageHandle, bool]: """Obtains the S3 bucket. @@ -1114,6 +1142,14 @@ def _validate(self): )[0], ( 'GCS Bucket is specified as path, the name should be ' 'the same as GCS bucket.') + elif self.source.startswith('r2://'): + assert self.name == data_utils.split_r2_path( + self.source + )[0], ('R2 Bucket is specified as path, the name should be ' + 'the same as R2 bucket.') + assert data_utils.verify_r2_bucket(self.name), ( + f'Source specified as {self.source}, a R2 bucket. ', + 'R2 Bucket should exist.') def initialize(self): """Initializes the GCS store object on the cloud. @@ -1152,6 +1188,8 @@ def upload(self): pass elif self.source.startswith('s3://'): self._transfer_to_gcs() + elif self.source.startswith('r2://'): + self._transfer_to_gcs() else: # If a single directory is specified in source, upload # contents to root of bucket by suffixing /*. @@ -1262,6 +1300,8 @@ def get_dir_sync_command(src_dir_path, dest_dir_name): def _transfer_to_gcs(self) -> None: if isinstance(self.source, str) and self.source.startswith('s3://'): data_transfer.s3_to_gcs(self.name, self.name) + elif isinstance(self.source, str) and self.source.startswith('r2://'): + data_transfer.r2_to_gcs(self.name, self.name) def _get_bucket(self) -> Tuple[StorageHandle, bool]: """Obtains the GCS bucket. @@ -1386,3 +1426,304 @@ def _delete_gcs_bucket(self, bucket_name: str) -> None: with ux_utils.print_exception_no_traceback(): raise exceptions.StorageBucketDeleteError( f'Failed to delete GCS bucket {bucket_name}.') + + +class R2Store(AbstractStore): + """R2Store inherits from S3Store Object and represents the backend + for R2 buckets. 
+ """ + + ACCESS_DENIED_MESSAGE = 'Access Denied' + + def __init__(self, + name: str, + source: str, + region: Optional[str] = 'auto', + is_sky_managed: Optional[bool] = None): + self.client: 'boto3.client.Client' + self.bucket: 'StorageHandle' + super().__init__(name, source, region, is_sky_managed) + + def _validate(self): + if self.source is not None and isinstance(self.source, str): + if self.source.startswith('s3://'): + assert self.name == data_utils.split_s3_path(self.source)[0], ( + 'S3 Bucket is specified as path, the name should be the' + ' same as S3 bucket.') + assert data_utils.verify_s3_bucket(self.name), ( + f'Source specified as {self.source}, a S3 bucket. ', + 'S3 Bucket should exist.') + elif self.source.startswith('gs://'): + assert self.name == data_utils.split_gcs_path(self.source)[0], ( + 'GCS Bucket is specified as path, the name should be ' + 'the same as GCS bucket.') + assert data_utils.verify_gcs_bucket(self.name), ( + f'Source specified as {self.source}, a GCS bucket. ', + 'GCS Bucket should exist.') + elif self.source.startswith('r2://'): + assert self.name == data_utils.split_r2_path(self.source)[0], ( + 'R2 Bucket is specified as path, the name should be ' + 'the same as R2 bucket.') + + def initialize(self): + """Initializes the R2 store object on the cloud. + + Initialization involves fetching bucket if exists, or creating it if + it does not. + + Raises: + StorageBucketCreateError: If bucket creation fails + StorageBucketGetError: If fetching existing bucket fails + StorageInitError: If general initialization fails. + """ + self.client = data_utils.create_r2_client(self.region) + self.bucket, is_new_bucket = self._get_bucket() + if self.is_sky_managed is None: + # If is_sky_managed is not specified, then this is a new storage + # object (i.e., did not exist in global_user_state) and we should + # set the is_sky_managed property. + # If is_sky_managed is specified, then we take no action. + self.is_sky_managed = is_new_bucket + + def upload(self): + """Uploads source to store bucket. + + Upload must be called by the Storage handler - it is not called on + Store initialization. + + Raises: + StorageUploadError: if upload fails. + """ + try: + if isinstance(self.source, list): + self.batch_aws_rsync(self.source, create_dirs=True) + elif self.source is not None: + if self.source.startswith('s3://'): + self._transfer_to_r2() + elif self.source.startswith('gs://'): + self._transfer_to_r2() + elif self.source.startswith('r2://'): + pass + else: + self.batch_aws_rsync([self.source]) + except exceptions.StorageUploadError: + raise + except Exception as e: + raise exceptions.StorageUploadError( + f'Upload failed for store {self.name}') from e + + def delete(self) -> None: + self._delete_r2_bucket(self.name) + logger.info(f'{colorama.Fore.GREEN}Deleted R2 bucket {self.name}.' + f'{colorama.Style.RESET_ALL}') + + def get_handle(self) -> StorageHandle: + return cloudflare.resource('s3').Bucket(self.name) + + def batch_aws_rsync(self, + source_path_list: List[Path], + create_dirs: bool = False) -> None: + """Invokes aws s3 sync to batch upload a list of local paths to S3 + + AWS Sync by default uses 10 threads to upload files to the bucket. To + increase parallelism, modify max_concurrent_requests in your aws config + file (Default path: ~/.aws/config). + + Since aws s3 sync does not support batch operations, we construct + multiple commands to be run in parallel. 
+ + Args: + source_path_list: List of paths to local files or directories + create_dirs: If the local_path is a directory and this is set to + False, the contents of the directory are directly uploaded to + root of the bucket. If the local_path is a directory and this is + set to True, the directory is created in the bucket root and + contents are uploaded to it. + """ + + def get_file_sync_command(base_dir_path, file_names): + includes = ' '.join( + [f'--include "{file_name}"' for file_name in file_names]) + endpoint_url = cloudflare.create_endpoint() + sync_command = ('aws s3 sync --no-follow-symlinks --exclude="*" ' + f'{includes} {base_dir_path} ' + f's3://{self.name} ' + f'--endpoint {endpoint_url} ' + f'--profile={cloudflare.R2_PROFILE_NAME}') + return sync_command + + def get_dir_sync_command(src_dir_path, dest_dir_name): + # we exclude .git directory from the sync + endpoint_url = cloudflare.create_endpoint() + sync_command = ( + 'aws s3 sync --no-follow-symlinks --exclude ".git/*" ' + f'{src_dir_path} ' + f's3://{self.name}/{dest_dir_name} ' + f'--endpoint {endpoint_url} ' + f'--profile={cloudflare.R2_PROFILE_NAME}') + return sync_command + + # Generate message for upload + if len(source_path_list) > 1: + source_message = f'{len(source_path_list)} paths' + else: + source_message = source_path_list[0] + + with log_utils.safe_rich_status( + f'[bold cyan]Syncing ' + f'[green]{source_message}[/] to [green]r2://{self.name}/[/]'): + data_utils.parallel_upload( + source_path_list, + get_file_sync_command, + get_dir_sync_command, + self.name, + self.ACCESS_DENIED_MESSAGE, + create_dirs=create_dirs, + max_concurrent_uploads=_MAX_CONCURRENT_UPLOADS) + + def _transfer_to_r2(self) -> None: + assert isinstance(self.source, str), self.source + if self.source.startswith('gs://'): + data_transfer.gcs_to_r2(self.name, self.name) + elif self.source.startswith('s3://'): + data_transfer.s3_to_r2(self.name, self.name) + + def _get_bucket(self) -> Tuple[StorageHandle, bool]: + """Obtains the R2 bucket. + + If the bucket exists, this method will connect to the bucket. + If the bucket does not exist, there are two cases: + 1) Raise an error if the bucket source starts with s3:// + 2) Create a new bucket otherwise + + Raises: + StorageBucketCreateError: If creating the bucket fails + StorageBucketGetError: If fetching a bucket fails + """ + r2 = cloudflare.resource('s3') + bucket = r2.Bucket(self.name) + endpoint_url = cloudflare.create_endpoint() + try: + # Try Public bucket case. + # This line does not error out if the bucket is an external public + # bucket or if it is a user's bucket that is publicly + # accessible. + self.client.head_bucket(Bucket=self.name) + return bucket, False + except aws.botocore_exceptions().ClientError as e: + error_code = e.response['Error']['Code'] + # AccessDenied error for buckets that are private and not owned by + # user. + if error_code == '403': + command = (f'aws s3 ls s3://{self.name} ' + f'--endpoint {endpoint_url} ' + f'--profile={cloudflare.R2_PROFILE_NAME}') + with ux_utils.print_exception_no_traceback(): + raise exceptions.StorageBucketGetError( + _BUCKET_FAIL_TO_CONNECT_MESSAGE.format(name=self.name) + + f' To debug, consider using {command}.') from e + + if isinstance(self.source, str) and self.source.startswith('r2://'): + with ux_utils.print_exception_no_traceback(): + raise exceptions.StorageBucketGetError( + 'Attempted to connect to a non-existent bucket: ' + f'{self.source}. 
Consider using `aws s3 ls ' + f's3://{self.name} ' + f'--endpoint {endpoint_url} ' + f'--profile={cloudflare.R2_PROFILE_NAME}\' ' + 'to debug.') + + # If bucket cannot be found in both private and public settings, + # the bucket is created by Sky. + bucket = self._create_r2_bucket(self.name) + return bucket, True + + def _download_file(self, remote_path: str, local_path: str) -> None: + """Downloads file from remote to local on r2 bucket + using the boto3 API + + Args: + remote_path: str; Remote path on R2 bucket + local_path: str; Local path on user's device + """ + self.bucket.download_file(remote_path, local_path) + + def mount_command(self, mount_path: str) -> str: + """Returns the command to mount the bucket to the mount_path. + + Uses goofys to mount the bucket. + + Args: + mount_path: str; Path to mount the bucket to. + """ + install_cmd = ('sudo wget -nc https://github.com/romilbhardwaj/goofys/' + 'releases/download/0.24.0-romilb-upstream/goofys ' + '-O /usr/local/bin/goofys && ' + 'sudo chmod +x /usr/local/bin/goofys') + endpoint_url = cloudflare.create_endpoint() + mount_cmd = ( + f'AWS_PROFILE={cloudflare.R2_PROFILE_NAME} goofys -o allow_other ' + f'--stat-cache-ttl {self._STAT_CACHE_TTL} ' + f'--type-cache-ttl {self._TYPE_CACHE_TTL} ' + f'--endpoint {endpoint_url} ' + f'{self.bucket.name} {mount_path}') + return mounting_utils.get_mounting_command(mount_path, install_cmd, + mount_cmd) + + def _create_r2_bucket(self, + bucket_name: str, + region='auto') -> StorageHandle: + """Creates R2 bucket with specific name in specific region + + Args: + bucket_name: str; Name of bucket + region: str; Region name, r2 automatically sets region + Raises: + StorageBucketCreateError: If bucket creation fails. + """ + r2_client = self.client + try: + if region is None: + r2_client.create_bucket(Bucket=bucket_name) + else: + location = {'LocationConstraint': region} + r2_client.create_bucket(Bucket=bucket_name, + CreateBucketConfiguration=location) + logger.info(f'Created R2 bucket {bucket_name} in {region}') + except aws.botocore_exceptions().ClientError as e: + with ux_utils.print_exception_no_traceback(): + raise exceptions.StorageBucketCreateError( + f'Attempted to create a bucket ' + f'{self.name} but failed.') from e + return cloudflare.resource('s3').Bucket(bucket_name) + + def _delete_r2_bucket(self, bucket_name: str) -> None: + """Deletes R2 bucket, including all objects in bucket + + Args: + bucket_name: str; Name of bucket + """ + # Deleting objects is very slow programatically + # (i.e. bucket.objects.all().delete() is slow). + # In addition, standard delete operations (i.e. via `aws s3 rm`) + # are slow, since AWS puts deletion markers. + # https://stackoverflow.com/questions/49239351/why-is-it-so-much-slower-to-delete-objects-in-aws-s3-than-it-is-to-create-them + # The fastest way to delete is to run `aws s3 rb --force`, + # which removes the bucket by force. 
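# Illustrative sketch of the R2Store lifecycle with a hypothetical bucket and
# local path. In normal operation these calls are driven by the Storage
# handler rather than invoked directly, and they assume the `r2` profile and
# ~/.cloudflare/accountid are configured.
from sky.data import storage as storage_lib

store = storage_lib.R2Store(name='sky-r2-demo', source='~/my-dataset')
store.initialize()   # fetches the bucket via the R2 endpoint, or creates it
store.upload()       # aws s3 sync ... --endpoint <r2 endpoint> --profile=r2
store.delete()       # aws s3 rb s3://sky-r2-demo --force --endpoint <r2 endpoint>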
+ endpoint_url = cloudflare.create_endpoint() + remove_command = (f'aws s3 rb s3://{bucket_name} --force ' + f'--endpoint {endpoint_url} ' + f'--profile={cloudflare.R2_PROFILE_NAME}') + try: + with log_utils.safe_rich_status( + f'[bold cyan]Deleting R2 bucket {bucket_name}[/]'): + subprocess.check_output(remove_command.split(' ')) + except subprocess.CalledProcessError as e: + logger.error(e.output) + with ux_utils.print_exception_no_traceback(): + raise exceptions.StorageBucketDeleteError( + f'Failed to delete R2 bucket {bucket_name}.') + + # Wait until bucket deletion propagates on AWS servers + while data_utils.verify_r2_bucket(bucket_name): + time.sleep(0.1) diff --git a/sky/execution.py b/sky/execution.py index 599ea10b7f5..9d08c0ffe46 100644 --- a/sky/execution.py +++ b/sky/execution.py @@ -819,6 +819,8 @@ def _maybe_translate_local_file_mounts_and_sync_up( storage_obj.source = f's3://{storage_obj.name}' elif store_type == storage_lib.StoreType.GCS: storage_obj.source = f'gs://{storage_obj.name}' + elif store_type == storage_lib.StoreType.R2: + storage_obj.source = f'r2://{storage_obj.name}' else: with ux_utils.print_exception_no_traceback(): raise exceptions.NotSupportedError( diff --git a/sky/setup_files/setup.py b/sky/setup_files/setup.py index fb9ef0d6cfb..20fbdd4b40e 100644 --- a/sky/setup_files/setup.py +++ b/sky/setup_files/setup.py @@ -106,14 +106,15 @@ def parse_readme(readme: str) -> str: # NOTE: Change the templates/spot-controller.yaml.j2 file if any of the following # packages dependencies are changed. +aws_dependencies = [ + # awscli>=1.27.10 is required for SSO support. + 'awscli', + 'boto3', + # 'Crypto' module used in authentication.py for AWS. + 'pycryptodome==3.12.0', +] extras_require: Dict[str, List[str]] = { - 'aws': [ - # awscli>=1.27.10 is required for SSO support. - 'awscli', - 'boto3', - # 'Crypto' module used in authentication.py for AWS. - 'pycryptodome==3.12.0', - ], + 'aws': aws_dependencies, # TODO(zongheng): azure-cli is huge and takes a long time to install. # Tracked in: https://github.com/Azure/azure-cli/issues/7387 # azure-identity is needed in node_provider. @@ -121,6 +122,7 @@ def parse_readme(readme: str) -> str: 'gcp': ['google-api-python-client', 'google-cloud-storage'], 'docker': ['docker'], 'lambda': [], + 'cloudflare': aws_dependencies } extras_require['all'] = sum(extras_require.values(), []) diff --git a/sky/task.py b/sky/task.py index 62c96f3e363..bbaf9306011 100644 --- a/sky/task.py +++ b/sky/task.py @@ -736,6 +736,16 @@ def sync_storage_mounts(self) -> None: self.update_file_mounts({ mnt_path: blob_path, }) + elif store_type is storage_lib.StoreType.R2: + if storage.source is not None and not isinstance( + storage.source, + list) and storage.source.startswith('r2://'): + blob_path = storage.source + else: + blob_path = 'r2://' + storage.name + self.update_file_mounts({ + mnt_path: blob_path, + }) elif store_type is storage_lib.StoreType.AZURE: # TODO when Azure Blob is done: sync ~/.azure assert False, 'TODO: Azure Blob not mountable yet' diff --git a/tests/conftest.py b/tests/conftest.py index 9c20348fd1a..647a1988d3e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -18,7 +18,7 @@ # --aws, --gcp, --azure, or --lambda. # # To only run tests for managed spot (without generic tests), use --managed-spot. -all_clouds_in_smoke_tests = ['aws', 'gcp', 'azure', 'lambda'] +all_clouds_in_smoke_tests = ['aws', 'gcp', 'azure', 'lambda', 'cloudflare'] default_clouds_to_run = ['gcp', 'azure'] # Translate cloud name to pytest keyword. 
We need this because @@ -28,7 +28,8 @@ 'aws': 'aws', 'gcp': 'gcp', 'azure': 'azure', - 'lambda': 'lambda_cloud' + 'lambda': 'lambda_cloud', + 'cloudflare': 'cloudflare' } @@ -76,7 +77,10 @@ def _get_cloud_to_run(config) -> List[str]: cloud_to_run = [] for cloud in all_clouds_in_smoke_tests: if config.getoption(f'--{cloud}'): - cloud_to_run.append(cloud) + if cloud == 'cloudflare': + cloud_to_run.append(default_clouds_to_run[0]) + else: + cloud_to_run.append(cloud) if not cloud_to_run: cloud_to_run = default_clouds_to_run return cloud_to_run @@ -104,6 +108,10 @@ def pytest_collection_modifyitems(config, items): for cloud in all_clouds_in_smoke_tests: cloud_keyword = cloud_to_pytest_keyword[cloud] if (cloud_keyword in item.keywords and cloud not in cloud_to_run): + # Need to check both conditions as 'gcp' is added to cloud_to_run + # when tested for cloudflare + if config.getoption('--cloudflare') and cloud == 'cloudflare': + continue item.add_marker(skip_marks[cloud]) if (not 'managed_spot' diff --git a/tests/test_smoke.py b/tests/test_smoke.py index d7638a89569..279a5063fd8 100644 --- a/tests/test_smoke.py +++ b/tests/test_smoke.py @@ -32,6 +32,7 @@ from typing import Dict, List, NamedTuple, Optional, Tuple import urllib.parse import uuid +import os import colorama import jinja2 @@ -40,6 +41,7 @@ import sky from sky import global_user_state from sky.data import storage as storage_lib +from sky.adaptors import cloudflare from sky.skylet import events from sky.utils import common_utils from sky.utils import subprocess_utils @@ -678,6 +680,35 @@ def test_gcp_storage_mounts(): run_one_test(test) +@pytest.mark.cloudflare +def test_cloudflare_storage_mounts(generic_cloud: str): + name = _get_cluster_name() + storage_name = f'sky-test-{int(time.time())}' + template_str = pathlib.Path( + 'tests/test_yamls/test_r2_storage_mounting.yaml').read_text() + template = jinja2.Template(template_str) + content = template.render(storage_name=storage_name) + endpoint_url = cloudflare.create_endpoint() + with tempfile.NamedTemporaryFile(suffix='.yaml', mode='w') as f: + f.write(content) + f.flush() + file_path = f.name + test_commands = [ + *storage_setup_commands, + f'sky launch -y -c {name} --cloud {generic_cloud} {file_path}', + f'sky logs {name} 1 --status', # Ensure job succeeded. 
+ f'aws s3 ls s3://{storage_name}/hello.txt --endpoint {endpoint_url} --profile=r2' + ] + + test = Test( + 'cloudflare_storage_mounts', + test_commands, + f'sky down -y {name}; sky storage delete {storage_name}', + timeout=20 * 60, # 20 mins + ) + run_one_test(test) + + # ---------- CLI logs ---------- def test_cli_logs(generic_cloud: str): name = _get_cluster_name() @@ -1725,6 +1756,20 @@ def tmp_gsutil_bucket(self, tmp_bucket_name): yield tmp_bucket_name subprocess.check_call(['gsutil', 'rm', '-r', f'gs://{tmp_bucket_name}']) + @pytest.fixture + def tmp_awscli_bucket_r2(self, tmp_bucket_name): + # Creates a temporary bucket using awscli + endpoint_url = cloudflare.create_endpoint() + subprocess.check_call([ + 'aws', 's3', 'mb', f's3://{tmp_bucket_name}', '--endpoint', + endpoint_url, '--profile=r2' + ]) + yield tmp_bucket_name + subprocess.check_call([ + 'aws', 's3', 'rb', f's3://{tmp_bucket_name}', '--force', + '--endpoint', endpoint_url, '--profile=r2' + ]) + @pytest.fixture def tmp_public_storage_obj(self, request): # Initializes a storage object with a public bucket @@ -1733,8 +1778,10 @@ def tmp_public_storage_obj(self, request): # This does not require any deletion logic because it is a public bucket # and should not get added to global_user_state. - @pytest.mark.parametrize( - 'store_type', [storage_lib.StoreType.S3, storage_lib.StoreType.GCS]) + @pytest.mark.parametrize('store_type', [ + storage_lib.StoreType.S3, storage_lib.StoreType.GCS, + pytest.param(storage_lib.StoreType.R2, marks=pytest.mark.cloudflare) + ]) def test_new_bucket_creation_and_deletion(self, tmp_local_storage_obj, store_type): # Creates a new bucket with a local source, uploads files to it @@ -1753,8 +1800,10 @@ def test_new_bucket_creation_and_deletion(self, tmp_local_storage_obj, out = subprocess.check_output(['sky', 'storage', 'ls']) assert tmp_local_storage_obj.name not in out.decode('utf-8') - @pytest.mark.parametrize( - 'store_type', [storage_lib.StoreType.S3, storage_lib.StoreType.GCS]) + @pytest.mark.parametrize('store_type', [ + storage_lib.StoreType.S3, storage_lib.StoreType.GCS, + pytest.param(storage_lib.StoreType.R2, marks=pytest.mark.cloudflare) + ]) def test_bucket_bulk_deletion(self, store_type): # Create a temp folder with over 256 files and folders, upload # files and folders to a new bucket, then delete bucket. @@ -1791,8 +1840,10 @@ def test_public_bucket(self, tmp_public_storage_obj, store_type): out = subprocess.check_output(['sky', 'storage', 'ls']) assert tmp_public_storage_obj.name not in out.decode('utf-8') - @pytest.mark.parametrize('nonexist_bucket_url', - ['s3://{random_name}', 'gs://{random_name}']) + @pytest.mark.parametrize('nonexist_bucket_url', [ + 's3://{random_name}', 'gs://{random_name}', + pytest.param('r2://{random_name}', marks=pytest.mark.cloudflare) + ]) def test_nonexistent_bucket(self, nonexist_bucket_url): # Attempts to create fetch a stroage with a non-existent source. 
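# The R2 test cases above are gated behind the new pytest marker; with the
# conftest.py changes they are expected to be opted into via, e.g.:
#   pytest tests/test_smoke.py -k cloudflare --cloudflare
# This invocation is an assumption based on the marker and flag names, and it
# presumes the `r2` profile and ~/.cloudflare/accountid are configured.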
# Generate a random bucket name and verify it doesn't exist: @@ -1811,6 +1862,14 @@ def test_nonexistent_bucket(self, nonexist_bucket_url): nonexist_bucket_url.format(random_name=nonexist_bucket_name) ] expected_output = 'BucketNotFoundException' + elif nonexist_bucket_url.startswith('r2'): + endpoint_url = cloudflare.create_endpoint() + command = [ + 'aws', 's3api', 'head-bucket', '--bucket', + nonexist_bucket_name, '--endpoint', endpoint_url, + '--profile=r2' + ] + expected_output = '404' else: raise ValueError('Unsupported bucket type ' f'{nonexist_bucket_url}') @@ -1863,10 +1922,23 @@ def cli_ls_cmd(store_type, bucket_name, suffix=''): else: url = f'gs://{bucket_name}' return ['gsutil', 'ls', url] + if store_type == storage_lib.StoreType.R2: + endpoint_url = cloudflare.create_endpoint() + if suffix: + url = f's3://{bucket_name}/{suffix}' + else: + url = f's3://{bucket_name}' + return [ + 'aws', 's3', 'ls', url, '--endpoint', endpoint_url, + '--profile=r2' + ] @pytest.mark.parametrize('ext_bucket_fixture, store_type', [('tmp_awscli_bucket', storage_lib.StoreType.S3), - ('tmp_gsutil_bucket', storage_lib.StoreType.GCS)]) + ('tmp_gsutil_bucket', storage_lib.StoreType.GCS), + pytest.param('tmp_awscli_bucket_r2', + storage_lib.StoreType.R2, + marks=pytest.mark.cloudflare)]) def test_upload_to_existing_bucket(self, ext_bucket_fixture, request, tmp_source, store_type): # Tries uploading existing files to newly created bucket (outside of @@ -1905,8 +1977,10 @@ def test_copy_mount_existing_storage(self, out = subprocess.check_output(['sky', 'storage', 'ls']).decode('utf-8') assert storage_name in out, f'Storage {storage_name} not found in sky storage ls.' - @pytest.mark.parametrize( - 'store_type', [storage_lib.StoreType.S3, storage_lib.StoreType.GCS]) + @pytest.mark.parametrize('store_type', [ + storage_lib.StoreType.S3, storage_lib.StoreType.GCS, + pytest.param(storage_lib.StoreType.R2, marks=pytest.mark.cloudflare) + ]) def test_list_source(self, tmp_local_list_storage_obj, store_type): # Uses a list in the source field to specify a file and a directory to # be uploaded to the storage object. diff --git a/tests/test_yamls/test_r2_storage_mounting.yaml b/tests/test_yamls/test_r2_storage_mounting.yaml new file mode 100644 index 00000000000..1af93cc29fa --- /dev/null +++ b/tests/test_yamls/test_r2_storage_mounting.yaml @@ -0,0 +1,39 @@ +file_mounts: + # Mounting private buckets in COPY mode with a source dir + /mount_private_copy: + name: {{storage_name}} + source: ~/tmp-workdir + store: r2 + mode: COPY + + # Mounting private buckets in COPY mode with a list of files as source + /mount_private_copy_lof: + name: {{storage_name}} + source: ['~/tmp-workdir/tmp file', '~/tmp-workdir/tmp file2'] + store: r2 + mode: COPY + + # Mounting private buckets in MOUNT mode + /mount_private_mount: + name: {{storage_name}} + source: ~/tmp-workdir + store: r2 + mode: MOUNT + +run: | + set -ex + + # Check private bucket contents + ls -ltr /mount_private_copy/foo + ls -ltr /mount_private_copy/tmp\ file + ls -ltr /mount_private_copy_lof/tmp\ file + ls -ltr /mount_private_copy_lof/tmp\ file2 + ls -ltr /mount_private_mount/foo + ls -ltr /mount_private_mount/tmp\ file + + # Symlinks are not copied to buckets + ! ls /mount_private_copy/circle-link + ! ls /mount_private_mount/circle-link + + # Write to private bucket in MOUNT mode should pass + echo "hello" > /mount_private_mount/hello.txt \ No newline at end of file
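For reference, a minimal Python-API sketch of the R2 usage that the YAML above
exercises (bucket name and source path are hypothetical; R2 credentials and
~/.cloudflare/accountid are assumed to be configured):

import sky
from sky.data import storage as storage_lib

# Equivalent of the file_mounts entries above: sync a local directory to an
# R2 bucket and mount it into the task at /mydata.
storage = storage_lib.Storage(name='sky-r2-demo', source='~/tmp-workdir')
storage.add_store(storage_lib.StoreType.R2)

task = sky.Task(run='ls -l /mydata')
task.set_storage_mounts({'/mydata': storage})
# sky.launch(task)  # would upload to r2://sky-r2-demo and mount it on the VM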