Skip to content

Commit

Permalink
Merge pull request #57 from JemmaLDaniel/fix/pull-neptune-data
Browse files Browse the repository at this point in the history
fix: unique file names and squash neptune stdout
  • Loading branch information
JemmaLDaniel authored Mar 6, 2024
2 parents 0829a63 + e9db8c3 commit 63895c9
Showing 1 changed file with 18 additions and 4 deletions.
22 changes: 18 additions & 4 deletions marl_eval/json_tools/json_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.

import json
import logging
import os
import zipfile
from collections import defaultdict
Expand Down Expand Up @@ -135,16 +136,23 @@ def pull_neptune_data(
if not os.path.exists(store_directory):
os.makedirs(store_directory)

# Suppress neptune logger
neptune_logger = logging.getLogger("neptune")
neptune_logger.setLevel(logging.ERROR)

# Download and unzip the data
for run_id in tqdm(run_ids, desc="Downloading Neptune Data"):
run = neptune.init_run(project=project_name, with_id=run_id, mode="read-only")
for data_key in run.get_structure()[neptune_data_key].keys():
file_path = f"{store_directory}/{data_key}"
for j, data_key in enumerate(
run.get_structure()[neptune_data_key].keys(), start=1
):
# Create a unique filename
file_path = f"{store_directory}/{data_key}_{run_id}_{j}"
run[f"{neptune_data_key}/{data_key}"].download(destination=file_path)
# Try to unzip the file else continue to the next file
try:
with zipfile.ZipFile(file_path, "r") as zip_ref:
# Create a directory with to store unzipped data
# Create a directory to store unzipped data
os.makedirs(f"{file_path}_unzip", exist_ok=True)
# Unzip the data
zip_ref.extractall(f"{file_path}_unzip")
Expand All @@ -156,7 +164,13 @@ def pull_neptune_data(
# unzipped.
continue
except Exception as e:
print(f"An error occurred while unzipping or storing {file_path}: {e}")
print(
f"The following error occurred while unzipping or storing JSON \
data for run {run_id} at path {file_path}: {e}"
)
run.stop()

# Restore neptune logger level
neptune_logger.setLevel(logging.INFO)

print(f"{Fore.CYAN}{Style.BRIGHT}Data downloaded successfully!{Style.RESET_ALL}")

0 comments on commit 63895c9

Please sign in to comment.