Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add GCS authentication for service accounts #315

Merged
merged 15 commits into from
Jun 29, 2023
2 changes: 1 addition & 1 deletion STYLE_GUIDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ Streaming uses the [yapf](https://github.com/google/yapf) formatter for general
(see section 2.2). These checks can also be run manually via:

```
pre-commit run yapf --all-files # for yahp
pre-commit run yapf --all-files # for yapf
pre-commit run isort --all-files # for isort
```

Expand Down
18 changes: 17 additions & 1 deletion docs/source/how_to_guides/configure_cloud_storage_credentials.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,23 @@ export S3_ENDPOINT_URL='https://<accountid>.r2.cloudflarestorage.com'

For [MosaicML platform](https://www.mosaicml.com/cloud) users, follow the steps mentioned in the [Google Cloud Storage](https://mcli.docs.mosaicml.com/en/latest/secrets/gcp.html) MCLI doc on how to configure the cloud provider credentials.

### Others

### GCP Service Account Credentials Mounted as Environment Variables

Users must set the `GOOGLE_APPLICATION_CREDENTIALS` environment variable to the path of their GCP service account credentials file in the run environment.

````{tabs}
```{code-tab} py
import os
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = 'KEY_FILE'
```

```{code-tab} sh
export GOOGLE_APPLICATION_CREDENTIALS='KEY_FILE'
```
````

### GCP User Auth Credentials Mounted as Environment Variables

Streaming dataset supports [GCP user credentials](https://cloud.google.com/storage/docs/authentication#user_accounts) or [HMAC keys for User account](https://cloud.google.com/storage/docs/authentication/hmackeys). Users must set their GCP `user access key` and GCP `user access secret` in the run environment.

Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
install_requires = [
'boto3>=1.21.45,<2',
'Brotli>=1.0.9',
'google-cloud-storage>=2.9.0',
'matplotlib>=3.5.2,<4',
'paramiko>=2.11.0,<4',
'python-snappy>=0.6.1,<1',
Expand Down
42 changes: 38 additions & 4 deletions streaming/base/storage/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,15 +158,33 @@ def download_from_gcs(remote: str, local: str) -> None:
remote (str): Remote path (GCS).
local (str): Local path (local filesystem).
"""
import boto3
from boto3.s3.transfer import TransferConfig
from botocore.exceptions import ClientError

obj = urllib.parse.urlparse(remote)
if obj.scheme != 'gs':
raise ValueError(
f'Expected obj.scheme to be `gs`, instead, got {obj.scheme} for remote={remote}')

if 'GOOGLE_APPLICATION_CREDENTIALS' in os.environ:
_gcs_with_service_account(local, obj)
elif 'GCS_KEY' in os.environ and 'GCS_SECRET' in os.environ:
_gcs_with_hmac(remote, local, obj)
else:
raise ValueError(f'Either GOOGLE_APPLICATION_CREDENTIALS needs to be set for ' +
f'service level accounts or GCS_KEY and GCS_SECRET needs to be ' +
f'set for HMAC authentication')


def _gcs_with_hmac(remote: str, local: str, obj: urllib.parse.ParseResult) -> None:
"""Download a file from remote GCS to local using user level credentials.

Args:
remote (str): Remote path (GCS).
local (str): Local path (local filesystem).
obj (ParseResult): ParseResult object of remote.
"""
import boto3
from boto3.s3.transfer import TransferConfig
from botocore.exceptions import ClientError

# Create a new session per thread
session = boto3.session.Session()
# Create a resource client using a thread's session object
Expand All @@ -190,6 +208,22 @@ def download_from_gcs(remote: str, local: str) -> None:
raise


def _gcs_with_service_account(local: str, obj: urllib.parse.ParseResult) -> None:
    """Download a file from remote GCS to local using service account credentials.

    The path to the service account JSON key file is read from the
    ``GOOGLE_APPLICATION_CREDENTIALS`` environment variable.

    Args:
        local (str): Local path (local filesystem).
        obj (ParseResult): ParseResult object of remote path (GCS).
    """
    from google.cloud.storage import Blob, Bucket, Client

    key_file = os.environ['GOOGLE_APPLICATION_CREDENTIALS']
    client = Client.from_service_account_json(key_file)

    bucket = Bucket(client, obj.netloc)
    Blob(obj.path.lstrip('/'), bucket).download_to_filename(local)
b-chu marked this conversation as resolved.
Show resolved Hide resolved


def download_from_oci(remote: str, local: str) -> None:
"""Download a file from remote OCI to local.

Expand Down
119 changes: 80 additions & 39 deletions streaming/base/storage/upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import shutil
import sys
import urllib.parse
from enum import Enum
from tempfile import mkdtemp
from typing import Any, Tuple, Union

Expand All @@ -16,8 +17,12 @@
from streaming.base.storage.download import BOTOCORE_CLIENT_ERROR_CODES

__all__ = [
'CloudUploader', 'S3Uploader', 'GCSUploader', 'OCIUploader', 'AzureUploader',
'AzureDataLakeUploader', 'LocalUploader'
'CloudUploader',
'S3Uploader',
'GCSUploader',
'OCIUploader',
'AzureUploader',
'LocalUploader',
]

logger = logging.getLogger(__name__)
Expand All @@ -32,6 +37,11 @@
}


class GCSAuthentication(Enum):
    """Supported ways of authenticating to Google Cloud Storage."""

    # HMAC key/secret pair, used through the S3-compatible endpoint.
    HMAC = 1
    # Service account JSON key file (GOOGLE_APPLICATION_CREDENTIALS).
    SERVICE_ACCOUNT = 2


class CloudUploader:
"""Upload local files to a cloud storage."""

Expand Down Expand Up @@ -84,10 +94,9 @@ def _validate(self, out: Union[str, Tuple[str, str]]) -> None:
obj = urllib.parse.urlparse(out)
else:
if len(out) != 2:
raise ValueError(''.join([
f'Invalid `out` argument. It is either a string of local/remote directory ',
'or a list of two strings with [local, remote].'
]))
raise ValueError(f'Invalid `out` argument. It is either a string of ' +
f'local/remote directory or a list of two strings with ' +
f'[local, remote].')
obj = urllib.parse.urlparse(out[1])
if obj.scheme not in UPLOADERS:
raise ValueError(f'Invalid Cloud provider prefix: {obj.scheme}.')
Expand Down Expand Up @@ -183,6 +192,7 @@ def __init__(self,

import boto3
from botocore.config import Config

b-chu marked this conversation as resolved.
Show resolved Hide resolved
config = Config()
# Create a session and use it to make our client. Unlike Resources and Sessions,
# clients are generally thread-safe.
Expand Down Expand Up @@ -261,19 +271,34 @@ def __init__(self,
progress_bar: bool = False) -> None:
super().__init__(out, keep_local, progress_bar)

import boto3
if 'GOOGLE_APPLICATION_CREDENTIALS' in os.environ:
from google.cloud.storage import Client

service_account_path = os.environ['GOOGLE_APPLICATION_CREDENTIALS']
self.gcs_client = Client.from_service_account_json(service_account_path)
self.authentication = GCSAuthentication.SERVICE_ACCOUNT
elif 'GCS_KEY' in os.environ and 'GCS_SECRET' in os.environ:
import boto3

# Create a session and use it to make our client. Unlike Resources and Sessions,
# clients are generally thread-safe.
session = boto3.session.Session()
self.gcs_client = session.client(
's3',
region_name='auto',
endpoint_url='https://storage.googleapis.com',
aws_access_key_id=os.environ['GCS_KEY'],
aws_secret_access_key=os.environ['GCS_SECRET'],
)
self.authentication = GCSAuthentication.HMAC
else:
raise ValueError(f'Either GOOGLE_APPLICATION_CREDENTIALS needs to be set for ' +
f'service level accounts or GCS_KEY and GCS_SECRET needs to ' +
f'be set for HMAC authentication')

# Create a session and use it to make our client. Unlike Resources and Sessions,
# clients are generally thread-safe.
session = boto3.session.Session()
self.gcs_client = session.client('s3',
region_name='auto',
endpoint_url='https://storage.googleapis.com',
aws_access_key_id=os.environ['GCS_KEY'],
aws_secret_access_key=os.environ['GCS_SECRET'])
self.check_bucket_exists(self.remote) # pyright: ignore

def upload_file(self, filename: str):
def upload_file(self, filename: str) -> None:
"""Upload file from local instance to Google Cloud Storage bucket.

Args:
Expand All @@ -283,21 +308,31 @@ def upload_file(self, filename: str):
remote_filename = os.path.join(self.remote, filename) # pyright: ignore
obj = urllib.parse.urlparse(remote_filename)
logger.debug(f'Uploading to {remote_filename}')
file_size = os.stat(local_filename).st_size
with tqdm.tqdm(total=file_size,
unit='B',
unit_scale=True,
desc=f'Uploading to {remote_filename}',
disable=(not self.progress_bar)) as pbar:
self.gcs_client.upload_file(
local_filename,
obj.netloc,
obj.path.lstrip('/'),
Callback=lambda bytes_transferred: pbar.update(bytes_transferred),
)

if self.authentication == GCSAuthentication.HMAC:
file_size = os.stat(local_filename).st_size
with tqdm.tqdm(
total=file_size,
unit='B',
unit_scale=True,
desc=f'Uploading to {remote_filename}',
disable=(not self.progress_bar),
) as pbar:
self.gcs_client.upload_file(
local_filename,
obj.netloc,
obj.path.lstrip('/'),
Callback=lambda bytes_transferred: pbar.update(bytes_transferred),
)
elif self.authentication == GCSAuthentication.SERVICE_ACCOUNT:
from google.cloud.storage import Blob, Bucket

blob = Blob(obj.path.lstrip('/'), Bucket(self.gcs_client, obj.netloc))
blob.upload_from_filename(local_filename)

self.clear_local(local=local_filename)

def check_bucket_exists(self, remote: str):
def check_bucket_exists(self, remote: str) -> None:
"""Raise an exception if the bucket does not exist.

Args:
Expand All @@ -306,16 +341,20 @@ def check_bucket_exists(self, remote: str):
Raises:
error: Bucket does not exist.
"""
from botocore.exceptions import ClientError

bucket_name = urllib.parse.urlparse(remote).netloc
try:
self.gcs_client.head_bucket(Bucket=bucket_name)
except ClientError as error:
if error.response['Error']['Code'] == BOTOCORE_CLIENT_ERROR_CODES:
error.args = (f'Either bucket `{bucket_name}` does not exist! ' +
f'or check the bucket permission.',)
raise error

if self.authentication == GCSAuthentication.HMAC:
from botocore.exceptions import ClientError

try:
self.gcs_client.head_bucket(Bucket=bucket_name)
except ClientError as error:
if (error.response['Error']['Code'] == BOTOCORE_CLIENT_ERROR_CODES):
error.args = (f'Either bucket `{bucket_name}` does not exist! ' +
f'or check the bucket permission.',)
raise error
elif self.authentication == GCSAuthentication.SERVICE_ACCOUNT:
self.gcs_client.get_bucket(bucket_name)


class OCIUploader(CloudUploader):
Expand Down Expand Up @@ -343,6 +382,7 @@ def __init__(self,
super().__init__(out, keep_local, progress_bar)

import oci

config = oci.config.from_file()
self.client = oci.object_storage.ObjectStorageClient(
config=config, retry_strategy=oci.retry.DEFAULT_RETRY_STRATEGY)
Expand Down Expand Up @@ -430,7 +470,8 @@ def __init__(self,
# clients are generally thread-safe.
self.azure_service = BlobServiceClient(
account_url=f"https://{os.environ['AZURE_ACCOUNT_NAME']}.blob.core.windows.net",
credential=os.environ['AZURE_ACCOUNT_ACCESS_KEY'])
credential=os.environ['AZURE_ACCOUNT_ACCESS_KEY'],
)
self.check_bucket_exists(self.remote) # pyright: ignore

def upload_file(self, filename: str):
Expand Down
Loading