Allow reading datasets from private s3 bucket (#74)
* Allow reading datasets from private s3 bucket

* Add unit tests for reading datasets

* Update docstrings

* Add cli arguments for reading private datasets
katxiao authored May 14, 2021
1 parent 70e5010 commit 900ec9d
Showing 7 changed files with 209 additions and 27 deletions.
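For orientation before the per-file diffs, here is a minimal sketch of the headline feature: reading datasets from a private S3 bucket through the Python API. The function signatures follow the sdgym/datasets.py changes below; the bucket name, credential values, and dataset name are placeholders.

from sdgym.datasets import download_dataset, get_available_datasets

# Placeholder bucket and credential values; substitute your own.
BUCKET = 'my-private-bucket'
AWS_KEY = 'MY_ACCESS_KEY_ID'
AWS_SECRET = 'MY_SECRET_ACCESS_KEY'

# List the <dataset>.zip objects that the bucket exposes.
available = get_available_datasets(BUCKET, AWS_KEY, AWS_SECRET)
print(available['name'])

# Download and extract one dataset into a local datasets directory.
download_dataset(
    'adult',                   # placeholder dataset name
    datasets_path='datasets',  # extraction target
    bucket=BUCKET,
    aws_key=AWS_KEY,
    aws_secret=AWS_SECRET,
)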
29 changes: 24 additions & 5 deletions sdgym/__main__.py
@@ -86,12 +86,15 @@ def _run(args):
datasets_path=args.datasets_path,
modalities=args.modalities,
metrics=args.metrics,
bucket=args.bucket,
iterations=args.iterations,
cache_dir=args.cache_dir,
workers=workers,
show_progress=args.progress,
timeout=args.timeout,
output_path=args.output_path,
aws_key=args.aws_key,
aws_secret=args.aws_secret,
)

if args.groupby:
@@ -105,10 +108,12 @@ def _download_datasets(args):
_env_setup(args.logfile, args.verbose)
datasets = args.datasets
if not datasets:
datasets = sdgym.datasets.get_available_datasets(args.bucket)['name']
datasets = sdgym.datasets.get_available_datasets(
args.bucket, args.aws_key, args.aws_secret)['name']

for dataset in tqdm.tqdm(datasets):
sdgym.datasets.load_dataset(dataset, args.datasets_path, args.bucket)
sdgym.datasets.load_dataset(
dataset, args.datasets_path, args.bucket, args.aws_key, args.aws_secret)


def _list_downloaded(args):
@@ -118,7 +123,7 @@ def _list_downloaded(args):


def _list_available(args):
datasets = sdgym.datasets.get_available_datasets(args.bucket)
datasets = sdgym.datasets.get_available_datasets(args.bucket, args.aws_key, args.aws_secret)
_print_table(datasets, args.sort, args.reverse, {'size': humanfriendly.format_size})


@@ -140,6 +145,8 @@ def _get_parser():
help='Synthesizer/s to be benchmarked. Accepts multiple names.')
run.add_argument('-m', '--metrics', nargs='+',
help='Metrics to apply. Accepts multiple names.')
run.add_argument('-b', '--bucket',
help='Bucket from which to download the datasets.')
run.add_argument('-d', '--datasets', nargs='+',
help='Datasets/s to be used. Accepts multiple names.')
run.add_argument('-dp', '--datasets-path',
@@ -163,7 +170,11 @@ def _get_parser():
run.add_argument('-t', '--timeout', type=int,
help='Maximum seconds to run for each dataset.')
run.add_argument('-g', '--groupby', nargs='+',
help='Group scores leaderboard by the given fields')
help='Group scores leaderboard by the given fields.')
run.add_argument('-ak', '--aws-key', type=str, required=False,
help='AWS access key ID to use when reading datasets.')
run.add_argument('-as', '--aws-secret', type=str, required=False,
help='AWS secret access key to use when reading datasets.')

# download-datasets
download = action.add_parser('download-datasets', help='Download datasets.')
@@ -178,8 +189,12 @@ def _get_parser():
help='Be verbose. Repeat for increased verbosity.')
download.add_argument('-l', '--logfile', type=str,
help='Name of the log file.')
download.add_argument('-ak', '--aws-key', type=str, required=False,
help='AWS access key ID to use when reading datasets.')
download.add_argument('-as', '--aws-secret', type=str, required=False,
help='AWS secret access key to use when reading datasets.')

# list-available-datasets
# list-downloaded-datasets
list_downloaded = action.add_parser('list-downloaded', help='List downloaded datasets.')
list_downloaded.set_defaults(action=_list_downloaded)
list_downloaded.add_argument('-s', '--sort', default='name',
@@ -198,6 +213,10 @@ def _get_parser():
help='Reverse the order.')
list_available.add_argument('-b', '--bucket',
help='Bucket from which to download the datasets.')
list_available.add_argument('-ak', '--aws-key', type=str, required=False,
help='AWS access key ID to use when reading datasets.')
list_available.add_argument('-as', '--aws-secret', type=str, required=False,
help='AWS secret access key to use when reading datasets.')
list_available.set_defaults(action=_list_available)

return parser
10 changes: 8 additions & 2 deletions sdgym/benchmark.py
@@ -220,7 +220,7 @@ def _run_on_dask(jobs, verbose):

def run(synthesizers, datasets=None, datasets_path=None, modalities=None, bucket=None,
metrics=None, iterations=1, workers=1, cache_dir=None, show_progress=False,
timeout=None, output_path=None):
timeout=None, output_path=None, aws_key=None, aws_secret=None):
"""Run the SDGym benchmark and return a leaderboard.
The ``synthesizers`` object can either be a single synthesizer or an iterable of
@@ -273,13 +273,19 @@ def run(synthesizers, datasets=None, datasets_path=None, modalities=None, bucket
If an ``output_path`` is given, the generated leaderboard will be stored in the
indicated path as a CSV file. The given path must be a complete path including
the ``.csv`` filename.
aws_key (str):
If an ``aws_key`` is provided, the given access key ID will be used to read
from the specified bucket.
aws_secret (str):
If an ``aws_secret`` is provided, the given secret access key will be used to read
from the specified bucket.
Returns:
pandas.DataFrame:
A table containing one row per synthesizer + dataset + metric + iteration.
"""
synthesizers = get_synthesizers_dict(synthesizers)
datasets = get_dataset_paths(datasets, datasets_path, bucket)
datasets = get_dataset_paths(datasets, datasets_path, bucket, aws_key, aws_secret)
run_id = os.getenv('RUN_ID') or str(uuid.uuid4())[:10]

if cache_dir:
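As a usage note, a hedged sketch of how the extended run signature might be invoked; the synthesizer name, dataset name, bucket, and credential values are placeholders, and every other keyword argument keeps its default.

from sdgym.benchmark import run

# The new aws_key/aws_secret keyword arguments are forwarded to
# get_dataset_paths so the datasets can be fetched from a private bucket.
leaderboard = run(
    synthesizers=['CTGAN'],    # placeholder synthesizer name
    datasets=['adult'],        # placeholder dataset name
    bucket='my-private-bucket',
    aws_key='MY_ACCESS_KEY_ID',
    aws_secret='MY_SECRET_ACCESS_KEY',
)
print(leaderboard.head())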
58 changes: 38 additions & 20 deletions sdgym/datasets.py
@@ -1,11 +1,11 @@
import io
import logging
import urllib.request
from pathlib import Path
from xml.etree import ElementTree
from zipfile import ZipFile

import appdirs
import boto3
import botocore
import pandas as pd
from sdv import Metadata

@@ -17,21 +17,40 @@
TIMESERIES_FIELDS = ['sequence_index', 'entity_columns', 'context_columns', 'deepecho_version']


def download_dataset(dataset_name, datasets_path=None, bucket=None):
def _get_s3_client(aws_key=None, aws_secret=None):
if aws_key is not None and aws_secret is not None:
# credentials available
return boto3.client(
's3',
aws_access_key_id=aws_key,
aws_secret_access_key=aws_secret
)
else:
if boto3.Session().get_credentials():
# credentials available and will be detected automatically
config = None
else:
# no credentials available, make unsigned requests
config = botocore.config.Config(signature_version=botocore.UNSIGNED)

return boto3.client('s3', config=config)


def download_dataset(dataset_name, datasets_path=None, bucket=None, aws_key=None, aws_secret=None):
datasets_path = datasets_path or DATASETS_PATH
bucket = bucket or BUCKET
url = BUCKET_URL.format(bucket) + f'{dataset_name}.zip'

LOGGER.info('Downloading dataset %s from %s', dataset_name, url)
response = urllib.request.urlopen(url)
bytes_io = io.BytesIO(response.read())
LOGGER.info('Downloading dataset %s from %s', dataset_name, bucket)
s3 = _get_s3_client(aws_key, aws_secret)
obj = s3.get_object(Bucket=bucket, Key=f'{dataset_name}.zip')
bytes_io = io.BytesIO(obj['Body'].read())

LOGGER.info('Extracting dataset into %s', datasets_path)
with ZipFile(bytes_io) as zf:
zf.extractall(datasets_path)


def _get_dataset_path(dataset, datasets_path, bucket=None):
def _get_dataset_path(dataset, datasets_path, bucket=None, aws_key=None, aws_secret=None):
dataset = Path(dataset)
if dataset.exists():
return dataset
@@ -41,12 +60,12 @@ def _get_dataset_path(dataset, datasets_path, bucket=None):
if dataset_path.exists():
return dataset_path

download_dataset(dataset, datasets_path, bucket=bucket)
download_dataset(dataset, datasets_path, bucket=bucket, aws_key=aws_key, aws_secret=aws_secret)
return dataset_path


def load_dataset(dataset, datasets_path=None, bucket=None):
dataset_path = _get_dataset_path(dataset, datasets_path, bucket)
def load_dataset(dataset, datasets_path=None, bucket=None, aws_key=None, aws_secret=None):
dataset_path = _get_dataset_path(dataset, datasets_path, bucket, aws_key, aws_secret)
metadata = Metadata(str(dataset_path / 'metadata.json'))
tables = metadata.get_tables()
if not hasattr(metadata, 'modality'):
@@ -83,14 +102,13 @@ def load_tables(metadata):
return real_data


def get_available_datasets(bucket=None):
bucket_url = BUCKET_URL.format(bucket or BUCKET)
response = urllib.request.urlopen(bucket_url)
tree = ElementTree.fromstring(response.read())
def get_available_datasets(bucket=None, aws_key=None, aws_secret=None):
s3 = _get_s3_client(aws_key, aws_secret)
response = s3.list_objects(Bucket=bucket or BUCKET)
datasets = []
for content in tree.findall('{*}Contents'):
key = content.find('{*}Key').text
size = int(content.find('{*}Size').text)
for content in response['Contents']:
key = content['Key']
size = int(content['Size'])
if key.endswith('.zip'):
datasets.append({
'name': key[:-len('.zip')],
@@ -118,7 +136,7 @@ def get_downloaded_datasets(datasets_path=None):
return pd.DataFrame(datasets)


def get_dataset_paths(datasets, datasets_path, bucket):
def get_dataset_paths(datasets, datasets_path, bucket, aws_key, aws_secret):
"""Build the full path to datasets and ensure they exist."""
if datasets_path is None:
datasets_path = DATASETS_PATH
@@ -132,6 +150,6 @@ def get_dataset_paths(datasets, datasets_path, bucket):
datasets = get_available_datasets()['name'].tolist()

return [
_get_dataset_path(dataset, datasets_path, bucket)
_get_dataset_path(dataset, datasets_path, bucket, aws_key, aws_secret)
for dataset in datasets
]
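The credential handling added above follows a standard boto3 pattern: use explicit keys when they are supplied, otherwise fall back to whatever the default session can resolve (environment variables, ~/.aws/credentials, an instance role), and only when nothing is found make unsigned requests, which is what a public bucket expects. A standalone sketch of that pattern, independent of SDGym:

import boto3
from botocore import UNSIGNED
from botocore.config import Config


def make_s3_client(aws_key=None, aws_secret=None):
    """Return an S3 client: signed when credentials exist, unsigned otherwise."""
    if aws_key is not None and aws_secret is not None:
        # Explicit credentials always win.
        return boto3.client('s3', aws_access_key_id=aws_key, aws_secret_access_key=aws_secret)

    if boto3.Session().get_credentials():
        # Ambient credentials are picked up automatically by boto3.
        return boto3.client('s3')

    # No credentials anywhere: make anonymous (unsigned) requests.
    return boto3.client('s3', config=Config(signature_version=UNSIGNED))

With such a client, the download path reduces to a single s3.get_object(Bucket=bucket, Key=f'{dataset_name}.zip') call, as download_dataset above shows.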
1 change: 1 addition & 0 deletions setup.py
@@ -13,6 +13,7 @@

install_requires = [
'appdirs>1.1.4,<2',
'boto3>=1.15.0,<2',
'compress-pickle>=1.2.0,<2',
'humanfriendly>=8.2,<9',
'numpy>=1.15.4,<2',
Empty file added tests/__init__.py
Empty file added tests/unit/__init__.py
138 changes: 138 additions & 0 deletions tests/unit/test_datasets.py
@@ -0,0 +1,138 @@
import io
from unittest.mock import Mock, patch
from zipfile import ZipFile

import botocore

from sdgym.datasets import download_dataset


class AnyConfigWith:
"""AnyConfigWith matches any s3 config with the specified signature version."""
def __init__(self, signature_version):
self.signature_version = signature_version

def __eq__(self, other):
return self.signature_version == other.signature_version


@patch('sdgym.datasets.boto3')
def test_download_dataset_public_bucket(boto3_mock, tmpdir):
"""Test the ``sdv.datasets.download_dataset`` method. It calls `download_dataset`
with a dataset in a public bucket, and does not pass in any aws credentials.
Setup:
The boto3 library for s3 access is patched, and mocks are created for the
s3 bucket and dataset zipfile. The tmpdir fixture is used for the expected
file creation.
Input:
- dataset name
- datasets path
- bucket
Output:
- n/a
Side effects:
- s3 client creation
- s3 method call to the bucket
- file creation for dataset in datasets path
"""
# setup
dataset = 'my_dataset'
bucket = 'my_bucket'
bytesio = io.BytesIO()

with ZipFile(bytesio, mode='w') as zf:
zf.writestr(dataset, 'test_content')

s3_mock = Mock()
body_mock = Mock()
body_mock.read.return_value = bytesio.getvalue()
obj = {
'Body': body_mock
}
s3_mock.get_object.return_value = obj
boto3_mock.client.return_value = s3_mock
boto3_mock.Session().get_credentials.return_value = None

# run
download_dataset(
dataset,
datasets_path=str(tmpdir),
bucket=bucket
)

# asserts
boto3_mock.client.assert_called_once_with(
's3',
config=AnyConfigWith(botocore.UNSIGNED)
)
s3_mock.get_object.assert_called_once_with(Bucket=bucket, Key=f'{dataset}.zip')
with open(f'{tmpdir}/{dataset}') as dataset_file:
assert dataset_file.read() == 'test_content'


@patch('sdgym.datasets.boto3')
def test_download_dataset_private_bucket(boto3_mock, tmpdir):
"""Test the ``sdv.datasets.download_dataset`` method. It calls `download_dataset`
with a dataset in a private bucket and uses aws credentials.
Setup:
The boto3 library for s3 access is patched, and mocks are created for the
s3 bucket and dataset zipfile. The tmpdir fixture is used for the expected
file creation.
Input:
- dataset name
- datasets path
- bucket
- aws key
- aws secret
Output:
- n/a
Side effects:
- s3 client creation with aws credentials
- s3 method call to the bucket
- file creation for dataset in datasets path
"""
# setup
dataset = 'my_dataset'
bucket = 'my_bucket'
aws_key = 'my_key'
aws_secret = 'my_secret'
bytesio = io.BytesIO()

with ZipFile(bytesio, mode='w') as zf:
zf.writestr(dataset, 'test_content')

s3_mock = Mock()
body_mock = Mock()
body_mock.read.return_value = bytesio.getvalue()
obj = {
'Body': body_mock
}
s3_mock.get_object.return_value = obj
boto3_mock.client.return_value = s3_mock

# run
download_dataset(
dataset,
datasets_path=str(tmpdir),
bucket=bucket,
aws_key=aws_key,
aws_secret=aws_secret
)

# asserts
boto3_mock.client.assert_called_once_with(
's3',
aws_access_key_id=aws_key,
aws_secret_access_key=aws_secret
)
s3_mock.get_object.assert_called_once_with(Bucket=bucket, Key=f'{dataset}.zip')
with open(f'{tmpdir}/{dataset}') as dataset_file:
assert dataset_file.read() == 'test_content'
