Feature/algo eval #1074

Merged
merged 60 commits on Apr 20, 2024

Commits (60)
95cbfe6
added explicit env seeding for train and test envs
Mar 6, 2024
32cd3b4
logger updates
Mar 11, 2024
734119e
logger updates
Mar 12, 2024
5762d2c
extend hl experiment builder
Mar 12, 2024
6c1bd85
add mujoco example with multiple runs and performance plots
Mar 12, 2024
f730782
Merge branch 'thuml_master' into feature/algo-eval
Mar 12, 2024
d9a612a
format, type check and small fixes
Mar 12, 2024
a7898b1
small fix
Mar 12, 2024
5259d5f
Merge branch 'thuml_master' into feature/algo-eval
Mar 15, 2024
516c956
Merge branch 'thuml_master' into feature/algo-eval
Mar 25, 2024
d9a2017
updates
Mar 26, 2024
2e3f0b5
move doc string
Mar 26, 2024
85204b1
added matplotlib dependency
Mar 26, 2024
5a3f229
added pandas dependency
Mar 26, 2024
dffe8cd
fix pandas dependency
Mar 26, 2024
e95fa26
replace assert with exception in wandb logger
Mar 27, 2024
18d8ffa
removed name shortener
Mar 27, 2024
6d9b697
restructured and moved RLiableExperimentResult
Mar 27, 2024
9055eb5
removed attributes from pandas logger
Mar 27, 2024
ce5fa0d
fixed logger test
Mar 27, 2024
9c645ff
pleased the mypy gods
Mar 27, 2024
ec2c5c1
added primitive joblib launcher
Mar 27, 2024
929dd10
Merge branch 'thuml_master' into feature/algo-eval
Mar 28, 2024
f2e10b0
Merge branch 'thuml_master' into feature/algo-eval
Apr 2, 2024
85e910e
Added launcher interface and registry
Apr 3, 2024
ed12b16
Added contextmanager for ExperimentBuilder modifications
Apr 3, 2024
7d479af
Experiment: use name attribute during run except if overriden explicitly
Apr 3, 2024
60e75e3
Adjusted launchers to new interface
Apr 3, 2024
c6ee225
Merge branch 'thuml_master' into feature/algo-eval
Apr 5, 2024
152b6d5
create evaluation package
Apr 8, 2024
85909d3
updated examples
Apr 8, 2024
0957d2d
some documentation and mypy stuff
Apr 8, 2024
d7d3a54
handle experiment name if name str is empty
Apr 8, 2024
6925fec
more epochs
Apr 8, 2024
65e7cfa
updated dependencies
Apr 8, 2024
2e410ee
made loading from disk safer
Apr 8, 2024
a5988ac
clean up...
Apr 8, 2024
c751b6a
mypy stuff
Apr 8, 2024
1eb7bae
spelling word list
Apr 8, 2024
135c376
removed unnecessary + 1
Apr 9, 2024
769b97f
Merge branch 'master' into feature/algo-eval
MischaPanch Apr 15, 2024
617efe4
Merge branch 'aai-master' into feature/algo-eval
Apr 17, 2024
1a9b5d0
removed pandas logger
Apr 17, 2024
49f5b12
fixed rliable dependency and some docs
Apr 17, 2024
6146ad2
updated lock file
Apr 17, 2024
7ebcf93
suppressed ImportError on optional dependencies
Apr 17, 2024
0c8b4df
added eval to pytest.yml and removed contextlib suppress
Apr 18, 2024
3b1ec50
lint
Apr 18, 2024
c27b577
Merge branch 'aai-master' into feature/algo-eval
Apr 18, 2024
32c8eb1
install rliable with https
Apr 18, 2024
19f3fdf
updated lint_and_docs.yml
Apr 18, 2024
0592b6a
Renamed and commented `restore_logged_data` in TensorboardLogger [ski…
Apr 20, 2024
6183f70
Removed old and deprecated BasicLogger
Apr 20, 2024
10d1d34
Logging: improved typing using recursive type definition
Apr 20, 2024
96e42dc
Env: added argparse deps tp eval extra
Apr 20, 2024
34d1fec
Experiment: use absolute paths
Apr 20, 2024
9fafe7a
Rliable eval: added docstring, improved figure layout, option to disp…
Apr 20, 2024
b42ad64
Launcher: don't modify user input, set loky as default backend
Apr 20, 2024
31f40c9
Multi-experiment script: run sequentially by default, added docstring
Apr 20, 2024
edda9af
Merge branch 'master' into feature/algo-eval
MischaPanch Apr 20, 2024
5 changes: 2 additions & 3 deletions examples/mujoco/mujoco_ppo_hl_multi.py
@@ -7,9 +7,8 @@
import torch

from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.env import VectorEnvType
from tianshou.highlevel.evaluation import RLiableExperimentResult
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.evaluation import RLiableExperimentResult
from tianshou.highlevel.experiment import (
ExperimentConfig,
PPOExperimentBuilder,
@@ -65,6 +64,7 @@ def main(
batch_size=batch_size,
num_train_envs=training_num,
num_test_envs=test_num,
num_test_episodes=test_num,
buffer_size=buffer_size,
step_per_collect=step_per_collect,
repeat_per_collect=repeat_per_collect,
@@ -75,7 +75,6 @@
train_seed=sampling_config.train_seed,
test_seed=sampling_config.test_seed,
obs_norm=True,
venv_type=VectorEnvType.SUBPROC_SHARED_MEM_FORK_CONTEXT
)

experiments = (
8 changes: 4 additions & 4 deletions test/base/test_logger.py
@@ -21,7 +21,7 @@ def test_flatten_dict_basic(
| dict[str, dict[str, dict[str, int]]],
expected_output: dict[str, int],
) -> None:
logger = TensorboardLogger(SummaryWriter('log/logger'))
logger = TensorboardLogger(SummaryWriter("log/logger"))
result = logger.prepare_dict_for_logging(input_dict)
assert result == expected_output

@@ -38,7 +38,7 @@ def test_flatten_dict_custom_delimiter(
delimiter: Literal["|", "."],
expected_output: dict[str, int],
) -> None:
logger = TensorboardLogger(SummaryWriter('log/logger'))
logger = TensorboardLogger(SummaryWriter("log/logger"))
result = logger.prepare_dict_for_logging(input_dict, delimiter=delimiter)
assert result == expected_output

@@ -59,7 +59,7 @@ def test_flatten_dict_exclude_arrays(
exclude_arrays: bool,
expected_output: dict[str, np.ndarray],
) -> None:
logger = TensorboardLogger(SummaryWriter('log/logger'))
logger = TensorboardLogger(SummaryWriter("log/logger"))
result = logger.prepare_dict_for_logging(input_dict, exclude_arrays=exclude_arrays)
assert result.keys() == expected_output.keys()
for val1, val2 in zip(result.values(), expected_output.values(), strict=True):
@@ -76,6 +76,6 @@ def test_flatten_dict_invalid_values_filtered_out(
input_dict: dict[str, tuple[Literal[1]] | dict[str, str | dict[str, int]]],
expected_output: dict[str, int],
) -> None:
logger = TensorboardLogger(SummaryWriter('log/logger'))
logger = TensorboardLogger(SummaryWriter("log/logger"))
result = logger.prepare_dict_for_logging(input_dict)
assert result == expected_output
80 changes: 64 additions & 16 deletions tianshou/highlevel/evaluation.py
@@ -1,11 +1,46 @@
import os
from dataclasses import asdict, dataclass

import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as sst
from rliable import library as rly
from rliable import plot_utils

from tianshou.highlevel.experiment import Experiment


@dataclass
class LoggedSummaryData:
mean: np.ndarray
std: np.ndarray
max: np.ndarray
min: np.ndarray


@dataclass
class LoggedCollectStats:
env_step: np.ndarray
n_collected_episodes: np.ndarray
n_collected_steps: np.ndarray
collect_time: np.ndarray
collect_speed: np.ndarray
returns_stat: LoggedSummaryData
lens_stat: LoggedSummaryData

@classmethod
def from_data_dict(cls, data: dict) -> "LoggedCollectStats":
return cls(
env_step=np.array(data["env_step"]),
n_collected_episodes=np.array(data["n_collected_episodes"]),
n_collected_steps=np.array(data["n_collected_steps"]),
collect_time=np.array(data["collect_time"]),
collect_speed=np.array(data["collect_speed"]),
returns_stat=LoggedSummaryData(**data["returns_stat"]),
lens_stat=LoggedSummaryData(**data["lens_stat"]),
)


@dataclass
class RLiableExperimentResult:
"""The result of an experiment that can be used with the rliable library."""
@@ -26,7 +61,7 @@ def load_from_disk(cls, exp_dir: str) -> "RLiableExperimentResult":
:param exp_dir: The directory from where the experiment results are restored.
"""
test_episode_returns = []
test_data = None
env_step_at_test = None

for entry in os.scandir(exp_dir):
if entry.name.startswith(".") or not entry.is_dir():
@@ -41,22 +76,27 @@ def load_from_disk(cls, exp_dir: str) -> "RLiableExperimentResult":
)
data = logger.restore_logged_data(entry.path)

test_data = data["test"]
if "test" not in data or not data["test"]:
continue
test_data = LoggedCollectStats.from_data_dict(data["test"])

test_episode_returns.append(test_data["returns_stat"]["mean"])
test_episode_returns.append(test_data.returns_stat.mean)
env_step_at_test = test_data.env_step

if test_data is None:
if not test_episode_returns or env_step_at_test is None:
raise ValueError(f"No experiment data found in {exp_dir}.")

env_step = test_data["env_step"]

return cls(
test_episode_returns_RE=np.array(test_episode_returns),
env_steps_E=np.array(env_step),
env_steps_E=np.array(env_step_at_test),
exp_dir=exp_dir,
)

def _get_rliable_data(self, algo_name: str | None = None, score_thresholds: np.ndarray = None) -> (dict, np.ndarray, np.ndarray):
def _get_rliable_data(
self,
algo_name: str | None = None,
score_thresholds: np.ndarray | None = None,
) -> tuple[dict, np.ndarray, np.ndarray]:
"""Return the data in the format expected by the rliable library.

:param algo_name: The name of the algorithm to be shown in the figure legend. If None, the name of the algorithm
@@ -67,7 +107,11 @@ def _get_rliable_data(self, algo_name: str | None = None, score_thresholds: np.n
:return: A tuple score_dict, env_steps, and score_thresholds.
"""
if score_thresholds is None:
score_thresholds = np.linspace(np.min(self.test_episode_returns_RE), np.max(self.test_episode_returns_RE), 101)
score_thresholds = np.linspace(
np.min(self.test_episode_returns_RE),
np.max(self.test_episode_returns_RE),
101,
)

if algo_name is None:
algo_name = os.path.basename(self.exp_dir)
@@ -76,22 +120,26 @@ def _get_rliable_data(self, algo_name: str | None = None, score_thresholds: np.n

return score_dict, self.env_steps_E, score_thresholds

def eval_results(self, algo_name: str | None = None, score_thresholds: np.ndarray = None, save_figure: bool = False):
def eval_results(
self,
algo_name: str | None = None,
score_thresholds: np.ndarray | None = None,
save_figure: bool = False,
) -> tuple[plt.Figure, plt.Axes, plt.Figure, plt.Axes]:
"""Evaluate the results of an experiment and create a sample efficiency curve and a performance profile.

:param algo_name: The name of the algorithm to be shown in the figure legend. If None, the name of the algorithm
is set to the experiment dir.
:param score_thresholds: The score thresholds for the performance profile. If None, the thresholds are inferred
from the minimum and maximum test episode returns.
:param save_figure: If True, the figures are saved to the experiment directory.

:return: The created figures and axes.
"""
import matplotlib.pyplot as plt
import scipy.stats as sst
from rliable import library as rly
from rliable import plot_utils

score_dict, env_steps, score_thresholds = self._get_rliable_data(algo_name, score_thresholds)
score_dict, env_steps, score_thresholds = self._get_rliable_data(
algo_name,
score_thresholds,
)

iqm = lambda scores: sst.trim_mean(scores, proportiontocut=0.25, axis=0)
iqm_scores, iqm_cis = rly.get_interval_estimates(score_dict, iqm)
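As a usage illustration (not part of the diff itself), a minimal sketch of how the new evaluation API might be driven; the log directory and the algorithm name are hypothetical placeholders.

from tianshou.highlevel.evaluation import RLiableExperimentResult

# "log/ppo_experiments" is a hypothetical directory containing several seeded runs of one experiment
result = RLiableExperimentResult.load_from_disk("log/ppo_experiments")
# creates a sample efficiency curve and a performance profile; save_figure=True writes them to the experiment dir
result.eval_results(algo_name="ppo", save_figure=True)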
5 changes: 4 additions & 1 deletion tianshou/utils/logger/base.py
@@ -145,7 +145,10 @@ def restore_data(self) -> tuple[int, int, int]:
"""

@abstractmethod
def restore_logged_data(self, log_path: str) -> dict[str, VALID_LOG_VALS_TYPE]:
def restore_logged_data(
self,
log_path: str,
) -> dict[str, dict[str, VALID_LOG_VALS_TYPE | dict[str, VALID_LOG_VALS_TYPE]]]:
"""Load the logged data from disk for post-processing.

:return: a dict containing the logged data.
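For orientation, a small sketch (not from the PR itself) of how the nested structure described by the new return annotation is consumed; the run directory is a hypothetical placeholder, and the "test"/"returns_stat"/"mean" keys follow the evaluation code above.

from torch.utils.tensorboard import SummaryWriter

from tianshou.utils.logger.tensorboard import TensorboardLogger

log_path = "log/some_run"  # hypothetical run directory containing tensorboard event files
logger = TensorboardLogger(SummaryWriter(log_path))
data = logger.restore_logged_data(log_path)
# outer key: scope (e.g. "test"), inner keys: metric names or nested stat dicts
mean_test_returns = data["test"]["returns_stat"]["mean"]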
6 changes: 5 additions & 1 deletion tianshou/utils/logger/pandas_logger.py
@@ -95,7 +95,10 @@ def restore_data(self) -> tuple[int, int, int]:

return epoch, env_step, gradient_step

def restore_logged_data(self, log_path: str) -> dict[str, Any]:
def restore_logged_data(
self,
log_path: str,
) -> dict[str, dict[str, VALID_LOG_VALS_TYPE | dict[str, VALID_LOG_VALS_TYPE]]]:
data = {}

def merge_dicts(list_of_dicts: list[dict]) -> dict[str, Any]:
@@ -118,4 +121,5 @@ def merge_dicts(list_of_dicts: list[dict]) -> dict[str, Any]:
data[scope] = merge_dicts(dict_list)
except FileNotFoundError:
logging.warning(f"Failed to restore {scope} data")
data[scope] = {}
return data
9 changes: 6 additions & 3 deletions tianshou/utils/logger/tensorboard.py
@@ -132,14 +132,17 @@ def restore_data(self) -> tuple[int, int, int]:

return epoch, env_step, gradient_step

def restore_logged_data(self, log_path: str) -> dict[str, Any]:
def restore_logged_data(
self,
log_path: str,
) -> dict[str, dict[str, VALID_LOG_VALS_TYPE | dict[str, VALID_LOG_VALS_TYPE]]]:
ea = event_accumulator.EventAccumulator(log_path)
ea.Reload()

def add_to_dict(data_dict: dict[str, Any], keys: list[str], value: Any) -> None:
current_dict = data_dict
for key in keys[:-1]:
current_dict = current_dict.setdefault(key, {})
for k in keys[:-1]:
current_dict = current_dict.setdefault(k, {})
current_dict[keys[-1]] = value
Contributor

Since only the last key is assigned a value, I guess this is a fairly specialised add_to_dict.
You might wish to reflect this in the naming.

Collaborator

The logic is a bit complicated here, I renamed the function name and added more comments to explain it. Defaultdict wouldn't help here, unless there's something like an infinitely nested default dict ^^


data: dict[str, Any] = {}
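To illustrate the "infinitely nested default dict" mentioned in the review thread above, a small sketch (not part of the PR) contrasting it with the setdefault-based helper in the diff:

from collections import defaultdict

def nested_defaultdict() -> defaultdict:
    # each missing key transparently creates another nested defaultdict
    return defaultdict(nested_defaultdict)

d = nested_defaultdict()
d["test"]["returns_stat"]["mean"] = 1.0  # intermediate levels appear on access
# the helper in the diff achieves the same with plain dicts, building one level at a time via setdefault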