Allow reading datasets from private s3 bucket
katxiao committed May 11, 2021
1 parent 70e5010 commit 5df83bf
Showing 3 changed files with 40 additions and 20 deletions.
10 changes: 8 additions & 2 deletions sdgym/benchmark.py
@@ -220,7 +220,7 @@ def _run_on_dask(jobs, verbose):

def run(synthesizers, datasets=None, datasets_path=None, modalities=None, bucket=None,
metrics=None, iterations=1, workers=1, cache_dir=None, show_progress=False,
timeout=None, output_path=None):
timeout=None, output_path=None, aws_key=None, aws_secret=None):
"""Run the SDGym benchmark and return a leaderboard.
The ``synthesizers`` object can either be a single synthesizer or, an iterable of
@@ -273,13 +273,19 @@ def run(synthesizers, datasets=None, datasets_path=None, modalities=None, bucket
If an ``output_path`` is given, the generated leaderboard will be stored in the
indicated path as a CSV file. The given path must be a complete path including
the ``.csv`` filename.
aws_key (str):
If an ``aws_key`` is provided, the given access key id will be used to read
from the specified bucket.
aws_secret (str):
If an ``aws_secret`` is provided, the given secret access key will be used to read
from the specified bucket.
Returns:
pandas.DataFrame:
A table containing one row per synthesizer + dataset + metric + iteration.
"""
synthesizers = get_synthesizers_dict(synthesizers)
datasets = get_dataset_paths(datasets, datasets_path, bucket)
datasets = get_dataset_paths(datasets, datasets_path, bucket, aws_key, aws_secret)
run_id = os.getenv('RUN_ID') or str(uuid.uuid4())[:10]

if cache_dir:
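
With this change, AWS credentials can be passed straight through ``run`` and are forwarded to the dataset download code. A minimal usage sketch follows (not part of the commit); the synthesizer, dataset and bucket names are placeholders:

import sdgym

# Hypothetical private bucket that holds 'my_private_dataset.zip'. The
# aws_key / aws_secret keyword arguments are the ones added by this commit.
leaderboard = sdgym.run(
    synthesizers=['CTGAN'],              # example synthesizer name
    datasets=['my_private_dataset'],
    bucket='my-private-bucket',
    aws_key='AKIAXXXXXXXXXXXXXXXX',      # AWS access key id (placeholder)
    aws_secret='****************',       # AWS secret access key (placeholder)
)
print(leaderboard)
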
49 changes: 31 additions & 18 deletions sdgym/datasets.py
@@ -1,11 +1,11 @@
import io
import logging
import urllib.request
from pathlib import Path
from xml.etree import ElementTree
from zipfile import ZipFile

import appdirs
import boto3
import botocore
import pandas as pd
from sdv import Metadata

@@ -17,21 +17,35 @@
TIMESERIES_FIELDS = ['sequence_index', 'entity_columns', 'context_columns', 'deepecho_version']


def download_dataset(dataset_name, datasets_path=None, bucket=None):
def download_dataset(dataset_name, datasets_path=None, bucket=None, aws_key=None, aws_secret=None):
datasets_path = datasets_path or DATASETS_PATH
bucket = bucket or BUCKET
url = BUCKET_URL.format(bucket) + f'{dataset_name}.zip'

LOGGER.info('Downloading dataset %s from %s', dataset_name, url)
response = urllib.request.urlopen(url)
bytes_io = io.BytesIO(response.read())
LOGGER.info('Downloading dataset %s from %s', dataset_name, bucket)
s3 = _get_s3_client(aws_key, aws_secret)
obj = s3.get_object(Bucket=bucket, Key=f'{dataset_name}.zip')
bytes_io = io.BytesIO(obj['Body'].read())

LOGGER.info('Extracting dataset into %s', datasets_path)
with ZipFile(bytes_io) as zf:
zf.extractall(datasets_path)


def _get_dataset_path(dataset, datasets_path, bucket=None):
def _get_s3_client(aws_key=None, aws_secret=None):
if aws_key is not None and aws_secret is not None:
# credentials available
return boto3.client(
's3',
aws_access_key_id=aws_key,
aws_secret_access_key=aws_secret
)
else:
# no credentials available, make unsigned requests
config = botocore.config.Config(signature_version=botocore.UNSIGNED)
return boto3.client('s3', config=config)


def _get_dataset_path(dataset, datasets_path, bucket=None, aws_key=None, aws_secret=None):
dataset = Path(dataset)
if dataset.exists():
return dataset
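
Note (not part of the commit): ``_get_s3_client`` only signs requests when both ``aws_key`` and ``aws_secret`` are given; otherwise it builds an anonymous client with ``botocore.UNSIGNED``, which preserves the existing behaviour for the public bucket. A hypothetical direct call to ``download_dataset`` under each mode:

from sdgym.datasets import download_dataset

# Public dataset from the default bucket: no credentials, unsigned requests.
download_dataset('adult')

# Private bucket: both credentials must be supplied; if either is missing the
# helper falls back to the unsigned client and S3 will reject the request.
download_dataset(
    'my_private_dataset',                # expects 'my_private_dataset.zip' in the bucket
    bucket='my-private-bucket',          # placeholder bucket name
    aws_key='AKIAXXXXXXXXXXXXXXXX',      # placeholder credentials
    aws_secret='****************',
)
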
@@ -41,7 +55,7 @@ def _get_dataset_path(dataset, datasets_path, bucket=None):
if dataset_path.exists():
return dataset_path

download_dataset(dataset, datasets_path, bucket=bucket)
download_dataset(dataset, datasets_path, bucket=bucket, aws_key=aws_key, aws_secret=aws_secret)
return dataset_path


@@ -83,14 +97,13 @@ def load_tables(metadata):
return real_data


def get_available_datasets(bucket=None):
bucket_url = BUCKET_URL.format(bucket or BUCKET)
response = urllib.request.urlopen(bucket_url)
tree = ElementTree.fromstring(response.read())
def get_available_datasets(bucket=None, aws_key=None, aws_secret=None):
s3 = _get_s3_client(aws_key, aws_secret)
response = s3.list_objects(Bucket=bucket or BUCKET)
datasets = []
for content in tree.findall('{*}Contents'):
key = content.find('{*}Key').text
size = int(content.find('{*}Size').text)
for content in response['Contents']:
key = content['Key']
size = int(content['Size'])
if key.endswith('.zip'):
datasets.append({
'name': key[:-len('.zip')],
@@ -118,7 +131,7 @@ def get_downloaded_datasets(datasets_path=None):
return pd.DataFrame(datasets)


def get_dataset_paths(datasets, datasets_path, bucket):
def get_dataset_paths(datasets, datasets_path, bucket, aws_key, aws_secret):
"""Build the full path to datasets and ensure they exist."""
if datasets_path is None:
datasets_path = DATASETS_PATH
@@ -132,6 +145,6 @@ def get_dataset_paths(datasets, datasets_path, bucket):
datasets = get_available_datasets()['name'].tolist()

return [
_get_dataset_path(dataset, datasets_path, bucket)
_get_dataset_path(dataset, datasets_path, bucket, aws_key, aws_secret)
for dataset in datasets
]
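
Similarly, ``get_available_datasets`` now lists the bucket contents through the same client helper, so a private bucket can be inspected before running the benchmark. A hedged sketch (not part of the commit; bucket name and credentials are placeholders), assuming the function still returns a ``pandas.DataFrame`` with at least ``name`` and ``size`` columns:

from sdgym.datasets import get_available_datasets

# Default public bucket, anonymous access.
public_datasets = get_available_datasets()

# Private bucket, signed access via the new credential arguments.
private_datasets = get_available_datasets(
    bucket='my-private-bucket',
    aws_key='AKIAXXXXXXXXXXXXXXXX',
    aws_secret='****************',
)
print(private_datasets[['name', 'size']].head())

One caveat: ``list_objects`` returns at most 1000 keys per call, so a very large bucket would need pagination to be listed completely.
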
1 change: 1 addition & 0 deletions setup.py
@@ -13,6 +13,7 @@

install_requires = [
'appdirs>1.1.4,<2',
'boto3>=1.15.0,<2',
'compress-pickle>=1.2.0,<2',
'humanfriendly>=8.2,<9',
'numpy>=1.15.4,<2',
