Allow cache dir to be an s3 bucket
katxiao committed May 19, 2021
1 parent e1e4aae commit ea87801
Showing 4 changed files with 165 additions and 32 deletions.
20 changes: 13 additions & 7 deletions sdgym/benchmark.py
@@ -17,6 +17,7 @@
from sdgym.errors import SDGymError
from sdgym.metrics import get_metrics
from sdgym.progress import TqdmLogger, progress
from sdgym.s3 import is_s3_path, write_csv, write_file
from sdgym.synthesizers.base import Baseline
from sdgym.utils import (
build_synthesizer, format_exception, get_synthesizers_dict, import_object, used_memory)
@@ -153,7 +154,8 @@ def _run_job(args):
# Reset random seed
np.random.seed()

synthesizer, metadata, metrics, iteration, cache_dir, timeout, run_id = args
synthesizer, metadata, metrics, iteration, cache_dir, \
timeout, run_id, aws_key, aws_secret = args

name = synthesizer['name']
dataset_name = metadata._metadata['name']
@@ -183,14 +185,16 @@
scores['error'] = output['error']

if cache_dir:
base_path = str(cache_dir / f'{name}_{dataset_name}_{iteration}_{run_id}')
cache_dir_name = str(cache_dir)
base_path = f'{cache_dir_name}/{name}_{dataset_name}_{iteration}_{run_id}'
if scores is not None:
scores.to_csv(base_path + '_scores.csv', index=False)
write_csv(scores, f'{base_path}_scores.csv', aws_key, aws_secret)
if 'synthetic_data' in output:
compress_pickle.dump(output['synthetic_data'], base_path + '.data.gz')
synthetic_data = compress_pickle.dumps(output['synthetic_data'], compression='gzip')
write_file(synthetic_data, f'{base_path}.data.gz', aws_key, aws_secret)
if 'exception' in output:
with open(base_path + '_error.txt', 'w') as error_file:
error_file.write(output['exception'])
exception = output['exception'].encode('utf-8')
write_file(exception, f'{base_path}_error.txt', aws_key, aws_secret)

return scores

@@ -288,7 +292,7 @@ def run(synthesizers, datasets=None, datasets_path=None, modalities=None, bucket
datasets = get_dataset_paths(datasets, datasets_path, bucket, aws_key, aws_secret)
run_id = os.getenv('RUN_ID') or str(uuid.uuid4())[:10]

if cache_dir:
if cache_dir and not is_s3_path(cache_dir):
cache_dir = Path(cache_dir)
os.makedirs(cache_dir, exist_ok=True)

@@ -307,6 +311,8 @@
cache_dir,
timeout,
run_id,
aws_key,
aws_secret,
)
jobs.append(args)

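With this change, the cache directory passed to `run` no longer has to be a local folder: an `s3://` path is detected by `is_s3_path`, and the per-job scores, synthetic data and error files are written through `write_csv` / `write_file` instead of the local filesystem. A minimal sketch of how a caller might use this (the bucket name, credentials, synthesizer and dataset below are illustrative, not part of the commit):

import sdgym

# Hypothetical bucket and credentials, shown only to illustrate the new
# behaviour: when cache_dir is an s3 path, intermediate scores and
# synthetic data are uploaded there instead of being saved locally.
scores = sdgym.run(
    synthesizers=['GaussianCopula'],
    datasets=['adult'],
    cache_dir='s3://my-results-bucket/sdgym-cache',
    aws_key='MY_ACCESS_KEY_ID',
    aws_secret='MY_SECRET_ACCESS_KEY',
)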
27 changes: 4 additions & 23 deletions sdgym/datasets.py
@@ -4,11 +4,11 @@
from zipfile import ZipFile

import appdirs
import boto3
import botocore
import pandas as pd
from sdv import Metadata

from sdgym.s3 import get_s3_client

LOGGER = logging.getLogger(__name__)

DATASETS_PATH = Path(appdirs.user_data_dir()) / 'SDGym' / 'datasets'
@@ -17,31 +17,12 @@
TIMESERIES_FIELDS = ['sequence_index', 'entity_columns', 'context_columns', 'deepecho_version']


def _get_s3_client(aws_key=None, aws_secret=None):
if aws_key is not None and aws_secret is not None:
# credentials available
return boto3.client(
's3',
aws_access_key_id=aws_key,
aws_secret_access_key=aws_secret
)
else:
if boto3.Session().get_credentials():
# credentials available and will be detected automatically
config = None
else:
# no credentials available, make unsigned requests
config = botocore.config.Config(signature_version=botocore.UNSIGNED)

return boto3.client('s3', config=config)


def download_dataset(dataset_name, datasets_path=None, bucket=None, aws_key=None, aws_secret=None):
datasets_path = datasets_path or DATASETS_PATH
bucket = bucket or BUCKET

LOGGER.info('Downloading dataset %s from %s', dataset_name, bucket)
s3 = _get_s3_client(aws_key, aws_secret)
s3 = get_s3_client(aws_key, aws_secret)
obj = s3.get_object(Bucket=bucket, Key=f'{dataset_name}.zip')
bytes_io = io.BytesIO(obj['Body'].read())

@@ -103,7 +84,7 @@ def load_tables(metadata):


def get_available_datasets(bucket=None, aws_key=None, aws_secret=None):
s3 = _get_s3_client(aws_key, aws_secret)
s3 = get_s3_client(aws_key, aws_secret)
response = s3.list_objects(Bucket=bucket or BUCKET)
datasets = []
for content in response['Contents']:
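The private `_get_s3_client` helper is removed here in favour of the shared `get_s3_client` from the new `sdgym.s3` module, so dataset downloads and cache writes go through the same credential handling. As a rough sketch under that assumption, anonymous access to the public dataset bucket should keep working, because the client falls back to unsigned requests when no credentials are found:

from sdgym.datasets import get_available_datasets

# No credentials passed: get_s3_client falls back to an unsigned client,
# so the public SDGym dataset bucket can still be listed anonymously.
print(get_available_datasets())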
146 changes: 146 additions & 0 deletions sdgym/s3.py
@@ -0,0 +1,146 @@
import boto3
import botocore

S3_PREFIX = 's3://'


def is_s3_path(path):
"""Determine if the given path is an s3 path.
Args:
path (str):
The path, which might be an s3 path.
Returns:
bool:
A boolean representing if the path is an s3 path or not.
"""
return isinstance(path, str) and S3_PREFIX in path


def parse_s3_path(path):
"""Parse a s3 path into the bucket and key prefix.
Args:
path (str):
The s3 path to parse. The expected format for the s3 path is
`s3://<bucket-name>/path/to/dir`.
Returns:
tuple:
A tuple containing (`bucket_name`, `key_prefix`) where `bucket_name`
is the name of the s3 bucket, and `key_prefix` is the remainder
of the s3 path.
"""
bucket_parts = path.replace(S3_PREFIX, '').split('/')
bucket_name = bucket_parts[0]

key_prefix = ''
if len(bucket_parts) > 1:
key_prefix = '/'.join(bucket_parts[1:])

return bucket_name, key_prefix


def get_s3_client(aws_key=None, aws_secret=None):
"""Get the boto client for interfacing with AWS s3.
Args:
aws_key (str):
The access key id that will be used to communicate with
s3, if provided.
aws_secret (str):
The secret access key that will be used to communicate
with s3, if provided.
Returns:
boto3.session.Session.client:
The s3 client that can be used to read / write to s3.
"""
if aws_key is not None and aws_secret is not None:
# credentials available
return boto3.client(
's3',
aws_access_key_id=aws_key,
aws_secret_access_key=aws_secret
)
else:
if boto3.Session().get_credentials():
# credentials available and will be detected automatically
config = None
else:
# no credentials available, make unsigned requests
config = botocore.config.Config(signature_version=botocore.UNSIGNED)

return boto3.client('s3', config=config)


def write_file(contents, path, aws_key, aws_secret):
"""Write a file to the given path with the given contents.
If the path is an s3 directory, we will use the given aws credentials
to write to s3.
Args:
contents (bytes):
The contents that will be written to the file.
path (str):
The path to write the file to, which can be either local
or an s3 path.
aws_key (str):
The access key id that will be used to communicate with s3,
if provided.
aws_secret (str):
The secret access key that will be used to communicate
with s3, if provided.
Returns:
none
"""
content_encoding = ''
write_mode = 'w'
if path.endswith('gz') or path.endswith('gzip'):
content_encoding = 'gzip'
write_mode = 'wb'

if is_s3_path(path):
s3 = get_s3_client(aws_key, aws_secret)
bucket_name, key = parse_s3_path(path)
s3.put_object(
Bucket=bucket_name,
Key=key,
Body=contents,
ContentEncoding=content_encoding,
)
else:
with open(path, write_mode) as f:
if write_mode == 'w':
f.write(contents.decode('utf-8'))
else:
f.write(contents)


def write_csv(data, path, aws_key, aws_secret):
"""Write a csv file to the given path with the given contents.
If the path is an s3 directory, we will use the given aws credentials
to write to s3.
Args:
data (pandas.DataFrame):
The data that will be written to the csv file.
path (str):
The path to write the file to, which can be either local
or an s3 path.
aws_key (str):
The access key id that will be used to communicate with s3,
if provided.
aws_secret (str):
The secret access key that will be used to communicate
with s3, if provided.
Returns:
none
"""
contents = data.to_csv(index=False).encode('utf-8')
write_file(contents, path, aws_key, aws_secret)
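A short usage sketch of the new helpers (the file name and bucket are made up for illustration; actually writing to s3 would additionally require a reachable bucket and valid credentials):

import pandas as pd

from sdgym.s3 import is_s3_path, parse_s3_path, write_csv

scores = pd.DataFrame({'synthesizer': ['GaussianCopula'], 'score': [0.87]})

# A local path takes the plain-filesystem branch of write_file.
write_csv(scores, 'example_scores.csv', None, None)

# An s3 path is recognised by its prefix and split into bucket and key.
path = 's3://my-results-bucket/cache/example_scores.csv'
assert is_s3_path(path)
print(parse_s3_path(path))  # ('my-results-bucket', 'cache/example_scores.csv')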
4 changes: 2 additions & 2 deletions tests/unit/test_datasets.py
@@ -16,7 +16,7 @@ def __eq__(self, other):
return self.signature_version == other.signature_version


@patch('sdgym.datasets.boto3')
@patch('sdgym.s3.boto3')
def test_download_dataset_public_bucket(boto3_mock, tmpdir):
"""Test the ``sdv.datasets.download_dataset`` method. It calls `download_dataset`
with a dataset in a public bucket, and does not pass in any aws credentials.
@@ -74,7 +74,7 @@ def test_download_dataset_public_bucket(boto3_mock, tmpdir):
assert dataset_file.read() == 'test_content'


@patch('sdgym.datasets.boto3')
@patch('sdgym.s3.boto3')
def test_download_dataset_private_bucket(boto3_mock, tmpdir):
"""Test the ``sdv.datasets.download_dataset`` method. It calls `download_dataset`
with a dataset in a private bucket and uses aws credentials.
