Skip to content

Commit

Permalink
Write output to s3
Browse files Browse the repository at this point in the history
  • Loading branch information
katxiao committed May 17, 2021
1 parent a55dd1f commit b2626df
Showing 1 changed file with 21 additions and 2 deletions.
23 changes: 21 additions & 2 deletions sdgym/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,25 @@ def _run_on_dask(jobs, verbose):
return dask.compute(*persisted)


def _write_output_to_csv(scores, output_path, aws_key, aws_secret):
    """Write the scores table out as CSV, to a local path or to S3.

    Nothing is written when ``output_path`` is falsy. If ``is_s3_path``
    reports the path as an S3 location, the CSV text is rendered in memory,
    encoded as UTF-8 and uploaded with ``put_object``; otherwise the CSV is
    written directly to ``output_path``. The row index is omitted in both
    cases.

    Args:
        scores:
            Object exposing ``to_csv`` (presumably a ``pandas.DataFrame``;
            confirm against the caller).
        output_path:
            Destination for the CSV. Falsy values disable writing.
        aws_key:
            Passed through to ``get_s3_client``; only used for S3 paths.
        aws_secret:
            Passed through to ``get_s3_client``; only used for S3 paths.

    Returns:
        None
    """
    if not output_path:
        # No destination requested, so there is nothing to do.
        return None

    if not is_s3_path(output_path):
        # Non-S3 destination: let ``to_csv`` write straight to the path.
        scores.to_csv(output_path, index=False)
        return None

    # S3 destination: build the client and resolve bucket/key first, then
    # serialize the table in memory and upload it as a single object.
    client = get_s3_client(aws_key, aws_secret)
    bucket, object_key = parse_s3_path(output_path)
    payload = scores.to_csv(index=False).encode('utf-8')
    client.put_object(Bucket=bucket, Key=object_key, Body=payload)
    return None


def run(synthesizers, datasets=None, datasets_path=None, modalities=None, bucket=None,
metrics=None, iterations=1, workers=1, cache_dir=None, show_progress=False,
timeout=None, output_path=None, aws_key=None, aws_secret=None):
Expand Down Expand Up @@ -392,7 +411,7 @@ def run(synthesizers, datasets=None, datasets_path=None, modalities=None, bucket
raise SDGymError("No valid Dataset/Synthesizer combination given")

scores = pd.concat(scores)
if output_path:
scores.to_csv(output_path, index=False)

_write_output_to_csv(scores, output_path, aws_key, aws_secret)

return scores

0 comments on commit b2626df

Please sign in to comment.