From 5df83bf2d205a6552c64d73c6949e8da5ce0aecc Mon Sep 17 00:00:00 2001
From: Katharine Xiao <2405771+katxiao@users.noreply.github.com>
Date: Mon, 10 May 2021 15:40:00 -0700
Subject: [PATCH] Allow reading datasets from private s3 bucket

---
 sdgym/benchmark.py | 10 ++++++++--
 sdgym/datasets.py  | 49 +++++++++++++++++++++++++++++++++++------------------
 setup.py           |  1 +
 3 files changed, 40 insertions(+), 20 deletions(-)

diff --git a/sdgym/benchmark.py b/sdgym/benchmark.py
index c0e6e43c..ef9086d9 100644
--- a/sdgym/benchmark.py
+++ b/sdgym/benchmark.py
@@ -220,7 +220,7 @@ def _run_on_dask(jobs, verbose):
 
 def run(synthesizers, datasets=None, datasets_path=None, modalities=None, bucket=None,
         metrics=None, iterations=1, workers=1, cache_dir=None, show_progress=False,
-        timeout=None, output_path=None):
+        timeout=None, output_path=None, aws_key=None, aws_secret=None):
     """Run the SDGym benchmark and return a leaderboard.
 
     The ``synthesizers`` object can either be a single synthesizer or, an iterable of
@@ -273,13 +273,19 @@ def run(synthesizers, datasets=None, datasets_path=None, modalities=None, bucket
             If an ``output_path`` is given, the generated leaderboard will be stored in the
             indicated path as a CSV file. The given path must be a complete path including
             the ``.csv`` filename.
+        aws_key (str):
+            If an ``aws_key`` is provided, the given access key id will be used to read
+            from the specified bucket.
+        aws_secret (str):
+            If an ``aws_secret`` is provided, the given secret access key will be used to read
+            from the specified bucket.
 
     Returns:
         pandas.DataFrame:
             A table containing one row per synthesizer + dataset + metric + iteration.
     """
     synthesizers = get_synthesizers_dict(synthesizers)
-    datasets = get_dataset_paths(datasets, datasets_path, bucket)
+    datasets = get_dataset_paths(datasets, datasets_path, bucket, aws_key, aws_secret)
     run_id = os.getenv('RUN_ID') or str(uuid.uuid4())[:10]
 
     if cache_dir:
diff --git a/sdgym/datasets.py b/sdgym/datasets.py
index c247b10c..cd577aa1 100644
--- a/sdgym/datasets.py
+++ b/sdgym/datasets.py
@@ -1,11 +1,11 @@
 import io
 import logging
-import urllib.request
 from pathlib import Path
-from xml.etree import ElementTree
 from zipfile import ZipFile
 
 import appdirs
+import boto3
+import botocore
 import pandas as pd
 from sdv import Metadata
 
@@ -17,21 +17,35 @@
 TIMESERIES_FIELDS = ['sequence_index', 'entity_columns', 'context_columns', 'deepecho_version']
 
 
-def download_dataset(dataset_name, datasets_path=None, bucket=None):
+def download_dataset(dataset_name, datasets_path=None, bucket=None, aws_key=None, aws_secret=None):
     datasets_path = datasets_path or DATASETS_PATH
     bucket = bucket or BUCKET
-    url = BUCKET_URL.format(bucket) + f'{dataset_name}.zip'
 
-    LOGGER.info('Downloading dataset %s from %s', dataset_name, url)
-    response = urllib.request.urlopen(url)
-    bytes_io = io.BytesIO(response.read())
+    LOGGER.info('Downloading dataset %s from %s', dataset_name, bucket)
+    s3 = _get_s3_client(aws_key, aws_secret)
+    obj = s3.get_object(Bucket=bucket, Key=f'{dataset_name}.zip')
+    bytes_io = io.BytesIO(obj['Body'].read())
 
     LOGGER.info('Extracting dataset into %s', datasets_path)
     with ZipFile(bytes_io) as zf:
         zf.extractall(datasets_path)
 
 
-def _get_dataset_path(dataset, datasets_path, bucket=None):
+def _get_s3_client(aws_key=None, aws_secret=None):
+    if aws_key is not None and aws_secret is not None:
+        # credentials available
+        return boto3.client(
+            's3',
+            aws_access_key_id=aws_key,
+            aws_secret_access_key=aws_secret
+        )
+    else:
+        # no credentials available, make unsigned requests
+        config = botocore.config.Config(signature_version=botocore.UNSIGNED)
+        return boto3.client('s3', config=config)
+
+
+def _get_dataset_path(dataset, datasets_path, bucket=None, aws_key=None, aws_secret=None):
     dataset = Path(dataset)
     if dataset.exists():
         return dataset
@@ -41,7 +55,7 @@ def _get_dataset_path(dataset, datasets_path, bucket=None):
     if dataset_path.exists():
         return dataset_path
 
-    download_dataset(dataset, datasets_path, bucket=bucket)
+    download_dataset(dataset, datasets_path, bucket=bucket, aws_key=aws_key, aws_secret=aws_secret)
     return dataset_path
 
 
@@ -83,14 +97,13 @@ def load_tables(metadata):
     return real_data
 
 
-def get_available_datasets(bucket=None):
-    bucket_url = BUCKET_URL.format(bucket or BUCKET)
-    response = urllib.request.urlopen(bucket_url)
-    tree = ElementTree.fromstring(response.read())
+def get_available_datasets(bucket=None, aws_key=None, aws_secret=None):
+    s3 = _get_s3_client(aws_key, aws_secret)
+    response = s3.list_objects(Bucket=bucket or BUCKET)
     datasets = []
-    for content in tree.findall('{*}Contents'):
-        key = content.find('{*}Key').text
-        size = int(content.find('{*}Size').text)
+    for content in response['Contents']:
+        key = content['Key']
+        size = int(content['Size'])
         if key.endswith('.zip'):
             datasets.append({
                 'name': key[:-len('.zip')],
@@ -118,7 +131,7 @@ def get_downloaded_datasets(datasets_path=None):
     return pd.DataFrame(datasets)
 
 
-def get_dataset_paths(datasets, datasets_path, bucket):
+def get_dataset_paths(datasets, datasets_path, bucket, aws_key, aws_secret):
     """Build the full path to datasets and ensure they exist."""
     if datasets_path is None:
         datasets_path = DATASETS_PATH
@@ -132,6 +145,6 @@ def get_dataset_paths(datasets, datasets_path, bucket):
         datasets = get_available_datasets()['name'].tolist()
 
     return [
-        _get_dataset_path(dataset, datasets_path, bucket)
+        _get_dataset_path(dataset, datasets_path, bucket, aws_key, aws_secret)
         for dataset in datasets
     ]
diff --git a/setup.py b/setup.py
index 680ee907..6b1fa10e 100644
--- a/setup.py
+++ b/setup.py
@@ -13,6 +13,7 @@
 
 install_requires = [
     'appdirs>1.1.4,<2',
+    'boto3>=1.15.0,<2',
     'compress-pickle>=1.2.0,<2',
     'humanfriendly>=8.2,<9',
     'numpy>=1.15.4,<2',
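
Usage note: a minimal sketch of how the new credential arguments might be exercised once this patch is applied. The bucket name, access key id, and secret access key below are placeholders, and the top-level ``sdgym.run`` entry point and the ``Identity`` synthesizer name are assumptions not taken from this diff.

    import sdgym
    from sdgym.datasets import get_available_datasets

    # Placeholder credentials for a hypothetical private bucket.
    bucket = 'my-private-sdv-datasets'
    aws_key = 'AKIA...'    # access key id (placeholder)
    aws_secret = '...'     # secret access key (placeholder)

    # List the dataset zip files in the private bucket. If the credentials are
    # omitted, _get_s3_client falls back to unsigned (anonymous) requests.
    available = get_available_datasets(bucket=bucket, aws_key=aws_key, aws_secret=aws_secret)

    # Benchmark against those datasets, downloading them with the same credentials.
    scores = sdgym.run(
        synthesizers=['Identity'],              # assumed built-in synthesizer name
        datasets=available['name'].tolist(),    # names come from the listing above
        bucket=bucket,
        aws_key=aws_key,
        aws_secret=aws_secret,
    )
    print(scores)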