Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/algo eval #1074

Merged
merged 60 commits into from
Apr 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
60 commits
Select commit Hold shift + click to select a range
95cbfe6
added explicit env seeding for train and test envs
Mar 6, 2024
32cd3b4
logger updates
Mar 11, 2024
734119e
logger updates
Mar 12, 2024
5762d2c
extend hl experiment builder
Mar 12, 2024
6c1bd85
add mujoco example with multiple runs and performance plots
Mar 12, 2024
f730782
Merge branch 'thuml_master' into feature/algo-eval
Mar 12, 2024
d9a612a
format, type check and small fixes
Mar 12, 2024
a7898b1
small fix
Mar 12, 2024
5259d5f
Merge branch 'thuml_master' into feature/algo-eval
Mar 15, 2024
516c956
Merge branch 'thuml_master' into feature/algo-eval
Mar 25, 2024
d9a2017
updates
Mar 26, 2024
2e3f0b5
move doc string
Mar 26, 2024
85204b1
added matplotlib dependency
Mar 26, 2024
5a3f229
added pandas dependency
Mar 26, 2024
dffe8cd
fix pandas dependency
Mar 26, 2024
e95fa26
replace assert with exception in wandb logger
Mar 27, 2024
18d8ffa
removed name shortener
Mar 27, 2024
6d9b697
restructured and moved RLiableExperimentResult
Mar 27, 2024
9055eb5
removed attributes from pandas logger
Mar 27, 2024
ce5fa0d
fixed logger test
Mar 27, 2024
9c645ff
pleased the mypy gods
Mar 27, 2024
ec2c5c1
added primitive joblib launcher
Mar 27, 2024
929dd10
Merge branch 'thuml_master' into feature/algo-eval
Mar 28, 2024
f2e10b0
Merge branch 'thuml_master' into feature/algo-eval
Apr 2, 2024
85e910e
Added launcher interface and registry
Apr 3, 2024
ed12b16
Added contextmanager for ExperimentBuilder modifications
Apr 3, 2024
7d479af
Experiment: use name attribute during run except if overriden explicitly
Apr 3, 2024
60e75e3
Adjusted launchers to new interface
Apr 3, 2024
c6ee225
Merge branch 'thuml_master' into feature/algo-eval
Apr 5, 2024
152b6d5
create evaluation package
Apr 8, 2024
85909d3
updated examples
Apr 8, 2024
0957d2d
some documentation and mypy stuff
Apr 8, 2024
d7d3a54
handle experiment name if name str is empty
Apr 8, 2024
6925fec
more epochs
Apr 8, 2024
65e7cfa
updated dependencies
Apr 8, 2024
2e410ee
made loading from disk safer
Apr 8, 2024
a5988ac
clean up...
Apr 8, 2024
c751b6a
mypy stuff
Apr 8, 2024
1eb7bae
spelling word list
Apr 8, 2024
135c376
removed unnecessary + 1
Apr 9, 2024
769b97f
Merge branch 'master' into feature/algo-eval
MischaPanch Apr 15, 2024
617efe4
Merge branch 'aai-master' into feature/algo-eval
Apr 17, 2024
1a9b5d0
removed pandas logger
Apr 17, 2024
49f5b12
fixed rliable dependency and some docs
Apr 17, 2024
6146ad2
updated lock file
Apr 17, 2024
7ebcf93
suppressed ImportError on optional dependencies
Apr 17, 2024
0c8b4df
added eval to pytest.yml and removed contextlib suppress
Apr 18, 2024
3b1ec50
lint
Apr 18, 2024
c27b577
Merge branch 'aai-master' into feature/algo-eval
Apr 18, 2024
32c8eb1
install rliable with https
Apr 18, 2024
19f3fdf
updated lint_and_docs.yml
Apr 18, 2024
0592b6a
Renamed and commented `restore_logged_data` in TensorboardLogger [ski…
Apr 20, 2024
6183f70
Removed old and deprecated BasicLogger
Apr 20, 2024
10d1d34
Logging: improved typing using recursive type definition
Apr 20, 2024
96e42dc
Env: added argparse deps tp eval extra
Apr 20, 2024
34d1fec
Experiment: use absolute paths
Apr 20, 2024
9fafe7a
Rliable eval: added docstring, improved figure layout, option to disp…
Apr 20, 2024
b42ad64
Launcher: don't modify user input, set loky as default backend
Apr 20, 2024
31f40c9
Multi-experiment script: run sequentially by default, added docstring
Apr 20, 2024
edda9af
Merge branch 'master' into feature/algo-eval
MischaPanch Apr 20, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/lint_and_docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ jobs:
key: venv-${{ hashFiles('poetry.lock') }}
- name: Install the project dependencies
run: |
poetry install --with dev
poetry install --with dev --extras "eval"
- name: Lint
run: poetry run poe lint
- name: Types
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ jobs:
key: venv-${{ hashFiles('poetry.lock') }}
- name: Install the project dependencies
run: |
poetry install --with dev --extras "envpool"
poetry install --with dev --extras "envpool eval"
- name: wandb login
run: |
poetry run wandb login e2366d661b89f2bee877c40bee15502d67b7abef
Expand Down
6 changes: 5 additions & 1 deletion docs/spelling_wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -261,4 +261,8 @@ BA
BH
BO
BD

configs
postfix
backend
rliable
hl
10 changes: 8 additions & 2 deletions examples/atari/atari_dqn_hl.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,13 @@ def main(
replay_buffer_save_only_last_obs=True,
)

env_factory = AtariEnvFactory(task, experiment_config.seed, frames_stack, scale=scale_obs)
env_factory = AtariEnvFactory(
task,
sampling_config.train_seed,
sampling_config.test_seed,
frames_stack,
scale=scale_obs,
)

builder = (
DQNExperimentBuilder(env_factory, experiment_config, sampling_config)
Expand Down Expand Up @@ -98,7 +104,7 @@ def main(
)

experiment = builder.build()
experiment.run(log_name)
experiment.run(override_experiment_name=log_name)


if __name__ == "__main__":
Expand Down
10 changes: 8 additions & 2 deletions examples/atari/atari_iqn_hl.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,13 @@ def main(
replay_buffer_save_only_last_obs=True,
)

env_factory = AtariEnvFactory(task, experiment_config.seed, frames_stack, scale=scale_obs)
env_factory = AtariEnvFactory(
task,
sampling_config.train_seed,
sampling_config.test_seed,
frames_stack,
scale=scale_obs,
)

experiment = (
IQNExperimentBuilder(env_factory, experiment_config, sampling_config)
Expand All @@ -90,7 +96,7 @@ def main(
.with_epoch_stop_callback(AtariEpochStopCallback(task))
.build()
)
experiment.run(log_name)
experiment.run(override_experiment_name=log_name)


if __name__ == "__main__":
Expand Down
10 changes: 8 additions & 2 deletions examples/atari/atari_ppo_hl.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,13 @@ def main(
replay_buffer_save_only_last_obs=True,
)

env_factory = AtariEnvFactory(task, experiment_config.seed, frames_stack)
env_factory = AtariEnvFactory(
task,
sampling_config.train_seed,
sampling_config.test_seed,
frames_stack,
scale=scale_obs,
)

builder = (
PPOExperimentBuilder(env_factory, experiment_config, sampling_config)
Expand Down Expand Up @@ -109,7 +115,7 @@ def main(
),
)
experiment = builder.build()
experiment.run(log_name)
experiment.run(override_experiment_name=log_name)


if __name__ == "__main__":
Expand Down
10 changes: 8 additions & 2 deletions examples/atari/atari_sac_hl.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,13 @@ def main(
replay_buffer_save_only_last_obs=True,
)

env_factory = AtariEnvFactory(task, experiment_config.seed, frames_stack, scale=scale_obs)
env_factory = AtariEnvFactory(
task,
sampling_config.train_seed,
sampling_config.test_seed,
frames_stack,
scale=scale_obs,
)

builder = (
DiscreteSACExperimentBuilder(env_factory, experiment_config, sampling_config)
Expand Down Expand Up @@ -97,7 +103,7 @@ def main(
),
)
experiment = builder.build()
experiment.run(log_name)
experiment.run(override_experiment_name=log_name)


if __name__ == "__main__":
Expand Down
8 changes: 5 additions & 3 deletions examples/atari/atari_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ def make_atari_env(

:return: a tuple of (single env, training envs, test envs).
"""
env_factory = AtariEnvFactory(task, seed, frame_stack, scale=bool(scale))
env_factory = AtariEnvFactory(task, seed, seed + training_num, frame_stack, scale=bool(scale))
MischaPanch marked this conversation as resolved.
Show resolved Hide resolved
envs = env_factory.create_envs(training_num, test_num)
return envs.env, envs.train_envs, envs.test_envs

Expand All @@ -392,7 +392,8 @@ class AtariEnvFactory(EnvFactoryRegistered):
def __init__(
self,
task: str,
seed: int,
train_seed: int,
test_seed: int,
frame_stack: int,
scale: bool = False,
use_envpool_if_available: bool = True,
Expand All @@ -409,7 +410,8 @@ def __init__(
log.info("Not using envpool, because it is not available")
super().__init__(
task=task,
seed=seed,
train_seed=train_seed,
test_seed=test_seed,
venv_type=VectorEnvType.SUBPROC_SHARED_MEM,
envpool_factory=envpool_factory,
)
Expand Down
8 changes: 7 additions & 1 deletion examples/discrete/discrete_dqn_hl.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,13 @@
def main() -> None:
experiment = (
DQNExperimentBuilder(
EnvFactoryRegistered(task="CartPole-v1", seed=0, venv_type=VectorEnvType.DUMMY),
EnvFactoryRegistered(
task="CartPole-v1",
seed=0,
venv_type=VectorEnvType.DUMMY,
train_seed=0,
test_seed=10,
),
ExperimentConfig(
persistence_enabled=False,
watch=True,
Expand Down
9 changes: 7 additions & 2 deletions examples/mujoco/mujoco_a2c_hl.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,12 @@ def main(
repeat_per_collect=repeat_per_collect,
)

env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=True)
env_factory = MujocoEnvFactory(
task,
train_seed=sampling_config.train_seed,
test_seed=sampling_config.test_seed,
obs_norm=True,
)

experiment = (
A2CExperimentBuilder(env_factory, experiment_config, sampling_config)
Expand All @@ -78,7 +83,7 @@ def main(
.with_critic_factory_default(hidden_sizes, nn.Tanh)
.build()
)
experiment.run(log_name)
experiment.run(override_experiment_name=log_name)


if __name__ == "__main__":
Expand Down
9 changes: 7 additions & 2 deletions examples/mujoco/mujoco_ddpg_hl.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,12 @@ def main(
start_timesteps_random=True,
)

env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=False)
env_factory = MujocoEnvFactory(
task,
train_seed=sampling_config.train_seed,
test_seed=sampling_config.test_seed,
obs_norm=False,
)

experiment = (
DDPGExperimentBuilder(env_factory, experiment_config, sampling_config)
Expand All @@ -69,7 +74,7 @@ def main(
.with_critic_factory_default(hidden_sizes)
.build()
)
experiment.run(log_name)
experiment.run(override_experiment_name=log_name)


if __name__ == "__main__":
Expand Down
8 changes: 5 additions & 3 deletions examples/mujoco/mujoco_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def make_mujoco_env(

:return: a tuple of (single env, training envs, test envs).
"""
envs = MujocoEnvFactory(task, seed, obs_norm=obs_norm).create_envs(
envs = MujocoEnvFactory(task, seed, seed + num_train_envs, obs_norm=obs_norm).create_envs(
MischaPanch marked this conversation as resolved.
Show resolved Hide resolved
num_train_envs,
num_test_envs,
)
Expand Down Expand Up @@ -73,13 +73,15 @@ class MujocoEnvFactory(EnvFactoryRegistered):
def __init__(
self,
task: str,
seed: int,
train_seed: int,
test_seed: int,
obs_norm: bool = True,
venv_type: VectorEnvType = VectorEnvType.SUBPROC_SHARED_MEM,
) -> None:
super().__init__(
task=task,
seed=seed,
train_seed=train_seed,
test_seed=test_seed,
venv_type=venv_type,
envpool_factory=EnvPoolFactory() if envpool_is_available else None,
)
Expand Down
9 changes: 7 additions & 2 deletions examples/mujoco/mujoco_npg_hl.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,12 @@ def main(
repeat_per_collect=repeat_per_collect,
)

env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=True)
env_factory = MujocoEnvFactory(
task,
train_seed=sampling_config.train_seed,
test_seed=sampling_config.test_seed,
obs_norm=True,
)

experiment = (
NPGExperimentBuilder(env_factory, experiment_config, sampling_config)
Expand All @@ -80,7 +85,7 @@ def main(
.with_critic_factory_default(hidden_sizes, torch.nn.Tanh)
.build()
)
experiment.run(log_name)
experiment.run(override_experiment_name=log_name)


if __name__ == "__main__":
Expand Down
9 changes: 7 additions & 2 deletions examples/mujoco/mujoco_ppo_hl.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,12 @@ def main(
repeat_per_collect=repeat_per_collect,
)

env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=True)
env_factory = MujocoEnvFactory(
task,
train_seed=sampling_config.train_seed,
test_seed=sampling_config.test_seed,
obs_norm=True,
)

experiment = (
PPOExperimentBuilder(env_factory, experiment_config, sampling_config)
Expand Down Expand Up @@ -90,7 +95,7 @@ def main(
.with_critic_factory_default(hidden_sizes, torch.nn.Tanh)
.build()
)
experiment.run(log_name)
experiment.run(override_experiment_name=log_name)


if __name__ == "__main__":
Expand Down
Loading
Loading