Skip to content

Commit

Permalink
fix the summarizer
Browse files Browse the repository at this point in the history
  • Loading branch information
dcfidalgo committed Nov 14, 2023
1 parent 0dcba2a commit 19bc30b
Showing 1 changed file with 25 additions and 7 deletions.
32 changes: 25 additions & 7 deletions src/slurm_sweeps/experiment.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import logging
import os
import shutil
Expand All @@ -14,8 +15,11 @@
from .backend import Backend, SlurmBackend
from .constants import (
ASHA_PKL,
DB_CFG,
DB_ITERATION,
DB_LOGGED,
DB_PATH,
DB_TIMESTAMP,
DB_TRIAL_ID,
EXPERIMENT_NAME,
STORAGE_PATH,
Expand Down Expand Up @@ -253,12 +257,17 @@ def _print_summary(
# add database info
database = self._database.read(experiment=self._name)
if not database.empty:
metrics = [
col
for col in database.columns
if not (col.startswith("_") or col.endswith(DB_LOGGED))
]

if summarize_cfg_and_metrics is True:
summarize_cfg_and_metrics = [
col
for col in database.columns
if col not in [DB_ITERATION, DB_TRIAL_ID]
]
summarize_cfg_and_metrics = list(
json.loads(database[DB_CFG].iloc[0]).keys()
)
summarize_cfg_and_metrics += metrics
elif summarize_cfg_and_metrics is False:
summarize_cfg_and_metrics = []

Expand All @@ -269,8 +278,17 @@ def _print_summary(
continue

trial_dict["ITERATION"] = trial_df.iloc[-1][DB_ITERATION]
for key in summarize_cfg_and_metrics:
trial_dict[key] = trial_df.iloc[-1][key]
# adding cfg
for key, val in json.loads(trial_df.iloc[-1][DB_CFG]).items():
if key in summarize_cfg_and_metrics:
trial_dict[key] = val

# adding metrics
for metric in metrics:
if metric not in summarize_cfg_and_metrics:
continue
metric_df = trial_df[trial_df[f"{metric}{DB_LOGGED}"] == 1]
trial_dict[metric] = metric_df.iloc[-1][metric]

summary_df = pd.DataFrame(summary_dicts).set_index("TRIAL_ID")

Expand Down

0 comments on commit 19bc30b

Please sign in to comment.