
Commit

Update

[ghstack-poisoned]

vmoens committed Nov 11, 2024
1 parent 14bfdd8 commit 01c4f88
Showing 10 changed files with 636 additions and 459 deletions.
39 changes: 0 additions & 39 deletions sota-implementations/ppo/config_atari.yaml

This file was deleted.

1 change: 1 addition & 0 deletions sota-implementations/ppo/config_atari.yaml
36 changes: 0 additions & 36 deletions sota-implementations/ppo/config_mujoco.yaml

This file was deleted.

1 change: 1 addition & 0 deletions sota-implementations/ppo/config_mujoco.yaml
14 changes: 9 additions & 5 deletions sota-implementations/ppo/ppo_atari.py
@@ -10,6 +10,7 @@
 import hydra
 from torchrl._utils import logger as torchrl_logger
 from torchrl.record import VideoRecorder
+from torchrl.trainers.agents.ppo import AtariPPOTrainer


 @hydra.main(config_path="", config_name="config_atari", version_base="1.1")
@@ -28,7 +29,6 @@ def main(cfg: "DictConfig"):  # noqa: F821
     from torchrl.objectives import ClipPPOLoss
     from torchrl.objectives.value.advantages import GAE
     from torchrl.record.loggers import generate_exp_name, get_logger
-    from utils_atari import eval_model, make_parallel_env, make_ppo_models

     device = "cpu" if not torch.cuda.device_count() else "cuda"

@@ -40,12 +40,14 @@ def main(cfg: "DictConfig"):  # noqa: F821
     test_interval = cfg.logger.test_interval // frame_skip

     # Create models (check utils_atari.py)
-    actor, critic = make_ppo_models(cfg.env.env_name)
+    actor, critic = AtariPPOTrainer.make_ppo_models(cfg.env.env_name)
     actor, critic = actor.to(device), critic.to(device)

     # Create collector
     collector = SyncDataCollector(
-        create_env_fn=make_parallel_env(cfg.env.env_name, cfg.env.num_envs, "cpu"),
+        create_env_fn=AtariPPOTrainer.make_parallel_env(
+            cfg.env.env_name, cfg.env.num_envs, "cpu"
+        ),
         policy=actor,
         frames_per_batch=frames_per_batch,
         total_frames=total_frames,
@@ -110,7 +112,9 @@ def main(cfg: "DictConfig"):  # noqa: F821
         logger_video = False

     # Create test environment
-    test_env = make_parallel_env(cfg.env.env_name, 1, device, is_test=True)
+    test_env = AtariPPOTrainer.make_parallel_env(
+        cfg.env.env_name, 1, device, is_test=True
+    )
     if logger_video:
         test_env = test_env.append_transform(
             VideoRecorder(logger, tag="rendering/test", in_keys=["pixels_int"])
@@ -223,7 +227,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
             ) // test_interval:
                 actor.eval()
                 eval_start = time.time()
-                test_rewards = eval_model(
+                test_rewards = AtariPPOTrainer.eval_model(
                     actor, test_env, num_episodes=cfg_logger_num_test_episodes
                 )
                 eval_time = time.time() - eval_start
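
The calls migrated above keep their signatures; only their home changes, from the script-local utils_atari module to the AtariPPOTrainer class introduced on this branch. A minimal sketch of the new entry points, using only the signatures that appear in this diff (the env name "PongNoFrameskip-v4", num_envs=4, and num_episodes=3 are illustrative placeholders, not values from the commit):

import torch
from torchrl.trainers.agents.ppo import AtariPPOTrainer  # introduced on this branch

device = "cpu" if not torch.cuda.device_count() else "cuda"

# Build the PPO actor and critic for an Atari task (env name is a placeholder).
actor, critic = AtariPPOTrainer.make_ppo_models("PongNoFrameskip-v4")
actor, critic = actor.to(device), critic.to(device)

# Training envs stay on CPU, as in the collector above; the test env
# is created with is_test=True, mirroring ppo_atari.py.
train_env = AtariPPOTrainer.make_parallel_env("PongNoFrameskip-v4", 4, "cpu")
test_env = AtariPPOTrainer.make_parallel_env(
    "PongNoFrameskip-v4", 1, device, is_test=True
)

# Evaluate the policy for a few episodes, as in the test block above.
actor.eval()
test_rewards = AtariPPOTrainer.eval_model(actor, test_env, num_episodes=3)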
12 changes: 7 additions & 5 deletions sota-implementations/ppo/ppo_mujoco.py
@@ -10,6 +10,7 @@
 import hydra
 from torchrl._utils import logger as torchrl_logger
 from torchrl.record import VideoRecorder
+from torchrl.trainers.agents.ppo import ContinuousControlPPOTrainer


 @hydra.main(config_path="", config_name="config_mujoco", version_base="1.1")
@@ -28,7 +29,6 @@ def main(cfg: "DictConfig"):  # noqa: F821
     from torchrl.objectives import ClipPPOLoss
     from torchrl.objectives.value.advantages import GAE
     from torchrl.record.loggers import generate_exp_name, get_logger
-    from utils_mujoco import eval_model, make_env, make_ppo_models

     device = "cpu" if not torch.cuda.device_count() else "cuda"
     num_mini_batches = cfg.collector.frames_per_batch // cfg.loss.mini_batch_size
@@ -39,12 +39,12 @@ def main(cfg: "DictConfig"):  # noqa: F821
     )

     # Create models (check utils_mujoco.py)
-    actor, critic = make_ppo_models(cfg.env.env_name)
+    actor, critic = ContinuousControlPPOTrainer.make_ppo_models(cfg.env.env_name)
     actor, critic = actor.to(device), critic.to(device)

     # Create collector
     collector = SyncDataCollector(
-        create_env_fn=make_env(cfg.env.env_name, device),
+        create_env_fn=ContinuousControlPPOTrainer.make_env(cfg.env.env_name, device),
         policy=actor,
         frames_per_batch=cfg.collector.frames_per_batch,
         total_frames=cfg.collector.total_frames,
@@ -102,7 +102,9 @@ def main(cfg: "DictConfig"):  # noqa: F821
         logger_video = False

     # Create test environment
-    test_env = make_env(cfg.env.env_name, device, from_pixels=logger_video)
+    test_env = ContinuousControlPPOTrainer.make_env(
+        cfg.env.env_name, device, from_pixels=logger_video
+    )
     if logger_video:
         test_env = test_env.append_transform(
             VideoRecorder(logger, tag="rendering/test", in_keys=["pixels"])
@@ -216,7 +218,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
             ) // cfg_logger_test_interval:
                 actor.eval()
                 eval_start = time.time()
-                test_rewards = eval_model(
+                test_rewards = ContinuousControlPPOTrainer.eval_model(
                     actor, test_env, num_episodes=cfg_logger_num_test_episodes
                 )
                 eval_time = time.time() - eval_start
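
The MuJoCo script follows the same pattern with ContinuousControlPPOTrainer, whose make_env takes a device and an optional from_pixels flag instead of a number of envs. A matching sketch under the same caveats ("HalfCheetah-v4" and num_episodes=3 are placeholders, and the class comes from this branch):

import torch
from torchrl.trainers.agents.ppo import ContinuousControlPPOTrainer  # introduced on this branch

device = "cpu" if not torch.cuda.device_count() else "cuda"

# Build the PPO actor and critic for a continuous-control task
# (env name is a placeholder).
actor, critic = ContinuousControlPPOTrainer.make_ppo_models("HalfCheetah-v4")
actor, critic = actor.to(device), critic.to(device)

# Env construction mirrors ppo_mujoco.py: from_pixels is only enabled
# when a video logger is attached to the test env.
train_env = ContinuousControlPPOTrainer.make_env("HalfCheetah-v4", device)
test_env = ContinuousControlPPOTrainer.make_env(
    "HalfCheetah-v4", device, from_pixels=False
)

actor.eval()
test_rewards = ContinuousControlPPOTrainer.eval_model(
    actor, test_env, num_episodes=3
)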
(The remaining four changed files were not loaded in this view.)
