Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Algorithm] Update DDPG Example #1525

Merged
merged 22 commits into from
Oct 3, 2023
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions .github/unittest/linux_examples/scripts/run_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,12 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/ppo/ppo.py \
python .github/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \
collector.total_frames=48 \
collector.init_random_frames=10 \
optimization.batch_size=10 \
optim.batch_size=10 \
collector.frames_per_batch=16 \
collector.num_workers=4 \
collector.env_per_collector=2 \
collector.collector_device=cuda:0 \
network.device=cuda:0 \
optimization.utd_ratio=1 \
optim.utd_ratio=1 \
replay_buffer.size=120 \
env.name=Pendulum-v1 \
logger.backend=
Expand Down Expand Up @@ -175,13 +174,12 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/dreamer/dreame
python .github/unittest/helpers/coverage_run_parallel.py examples/ddpg/ddpg.py \
collector.total_frames=48 \
collector.init_random_frames=10 \
optimization.batch_size=10 \
optim.batch_size=10 \
collector.frames_per_batch=16 \
collector.num_workers=2 \
collector.env_per_collector=1 \
collector.collector_device=cuda:0 \
network.device=cuda:0 \
optimization.utd_ratio=1 \
optim.utd_ratio=1 \
replay_buffer.size=120 \
env.name=Pendulum-v1 \
logger.backend=
Expand Down
25 changes: 13 additions & 12 deletions examples/ddpg/config.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Environment
# environment and task
env:
name: HalfCheetah-v3
task: ""
Expand All @@ -7,39 +7,40 @@ env:
frame_skip: 1
seed: 1

# Collection
# collector
collector:
total_frames: 1000000
init_random_frames: 10000
total_frames: 3_000_000
init_random_frames: 25_000
frames_per_batch: 1000
max_frames_per_traj: 1000
init_env_steps: 1000
async_collection: 1
collector_device: cpu
env_per_collector: 1
num_workers: 1

# Replay Buffer

# replay buffer
replay_buffer:
size: 1000000
prb: 0 # use prioritized experience replay

# Optimization
optimization:
# optimization
optim:
utd_ratio: 1.0
gamma: 0.99
loss_function: smooth_l1
loss_function: l2
lr: 3e-4
weight_decay: 2e-4
weight_decay: 0.0
batch_size: 256
target_update_polyak: 0.995

# network
network:
hidden_sizes: [256, 256]
activation: relu
device: "cuda:0"
noise_type: "ou" # ou or gaussian

# Logging
# logging
logger:
backend: wandb
mode: online
Expand Down
121 changes: 68 additions & 53 deletions examples/ddpg/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,19 @@
The helper functions are coded in the utils.py associated with this script.
"""

import time

import hydra

import numpy as np
import torch
import torch.cuda
import tqdm

from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.record.loggers import generate_exp_name, get_logger
from utils import (
log_metrics,
make_collector,
make_ddpg_agent,
make_environment,
Expand All @@ -33,6 +37,7 @@
def main(cfg: "DictConfig"): # noqa: F821
device = torch.device(cfg.network.device)

# Create logger
exp_name = generate_exp_name("DDPG", cfg.env.exp_name)
logger = None
if cfg.logger.backend:
Expand All @@ -43,137 +48,147 @@ def main(cfg: "DictConfig"): # noqa: F821
wandb_kwargs={"mode": cfg.logger.mode, "config": cfg},
)

# Set seeds
torch.manual_seed(cfg.env.seed)
np.random.seed(cfg.env.seed)

# Create Environments
# Create environments
train_env, eval_env = make_environment(cfg)

# Create Agent
# Create agent
model, exploration_policy = make_ddpg_agent(cfg, train_env, eval_env, device)

# Create Loss Module and Target Updater
# Create DDPG loss
loss_module, target_net_updater = make_loss_module(cfg, model)

# Make Off-Policy Collector
# Create off-policy collector
collector = make_collector(cfg, train_env, exploration_policy)

# Make Replay Buffer
# Create replay buffer
replay_buffer = make_replay_buffer(
batch_size=cfg.optimization.batch_size,
batch_size=cfg.optim.batch_size,
prb=cfg.replay_buffer.prb,
buffer_size=cfg.replay_buffer.size,
device=device,
)

# Make Optimizers
# Create optimizers
optimizer_actor, optimizer_critic = make_optimizer(cfg, loss_module)

rewards = []
rewards_eval = []

# Main loop
start_time = time.time()
collected_frames = 0
pbar = tqdm.tqdm(total=cfg.collector.total_frames)
r0 = None
q_loss = None

init_random_frames = cfg.collector.init_random_frames
num_updates = int(
cfg.collector.env_per_collector
* cfg.collector.frames_per_batch
* cfg.optimization.utd_ratio
* cfg.optim.utd_ratio
)
prb = cfg.replay_buffer.prb
env_per_collector = cfg.collector.env_per_collector
frames_per_batch, frame_skip = cfg.collector.frames_per_batch, cfg.env.frame_skip
eval_iter = cfg.logger.eval_iter
eval_rollout_steps = cfg.collector.max_frames_per_traj // frame_skip

for i, tensordict in enumerate(collector):
sampling_start = time.time()
for _, tensordict in enumerate(collector):
sampling_time = time.time() - sampling_start
# Update exploration policy
exploration_policy.step(tensordict.numel())
# update weights of the inference policy

# Update weights of the inference policy
collector.update_policy_weights_()

if r0 is None:
r0 = tensordict["next", "reward"].sum(-1).mean().item()
pbar.update(tensordict.numel())

tensordict = tensordict.reshape(-1)
current_frames = tensordict.numel()
# Add to replay buffer
replay_buffer.extend(tensordict.cpu())
collected_frames += current_frames

# optimization steps
# Optimization steps
training_start = time.time()
if collected_frames >= init_random_frames:
(
actor_losses,
q_losses,
) = ([], [])
for _ in range(num_updates):
# sample from replay buffer
# Sample from replay buffer
sampled_tensordict = replay_buffer.sample().clone()

# Compute loss
loss_td = loss_module(sampled_tensordict)

optimizer_critic.zero_grad()
optimizer_actor.zero_grad()

actor_loss = loss_td["loss_actor"]
q_loss = loss_td["loss_value"]
(actor_loss + q_loss).backward()

# Update critic
optimizer_critic.zero_grad()
q_loss.backward()
optimizer_critic.step()
q_losses.append(q_loss.item())

# Update actor
optimizer_actor.zero_grad()
actor_loss.backward()
optimizer_actor.step()

q_losses.append(q_loss.item())
actor_losses.append(actor_loss.item())

# update qnet_target params
# Update qnet_target params
target_net_updater.step()

# update priority
# Update priority
if prb:
replay_buffer.update_priority(sampled_tensordict)

rewards.append(
(i, tensordict["next", "reward"].sum().item() / env_per_collector)
)
train_log = {
"train_reward": rewards[-1][1],
"collected_frames": collected_frames,
}
if q_loss is not None:
train_log.update(
{
"actor_loss": np.mean(actor_losses),
"q_loss": np.mean(q_losses),
}
training_time = time.time() - training_start
episode_rewards = tensordict["next", "episode_reward"][
tensordict["next", "done"]
]

# Logging
metrics_to_log = {}
if len(episode_rewards) > 0:
episode_length = tensordict["next", "step_count"][
tensordict["next", "done"]
]
metrics_to_log["train/reward"] = episode_rewards.mean().item()
metrics_to_log["train/episode_length"] = episode_length.sum().item() / len(
episode_length
)
if logger is not None:
for key, value in train_log.items():
logger.log_scalar(key, value, step=collected_frames)

if collected_frames >= init_random_frames:
metrics_to_log["train/q_loss"] = np.mean(q_losses)
metrics_to_log["train/a_loss"] = np.mean(actor_losses)
metrics_to_log["train/sampling_time"] = sampling_time
metrics_to_log["train/training_time"] = training_time

# Evaluation
if abs(collected_frames % eval_iter) < frames_per_batch * frame_skip:
with set_exploration_type(ExplorationType.MODE), torch.no_grad():
eval_start = time.time()
eval_rollout = eval_env.rollout(
eval_rollout_steps,
exploration_policy,
auto_cast_to_device=True,
break_when_any_done=True,
)
eval_time = time.time() - eval_start
eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
rewards_eval.append((i, eval_reward))
eval_str = f"eval cumulative reward: {rewards_eval[-1][1]: 4.4f} (init: {rewards_eval[0][1]: 4.4f})"
if logger is not None:
logger.log_scalar(
"evaluation_reward", rewards_eval[-1][1], step=collected_frames
)
if len(rewards_eval):
pbar.set_description(
f"reward: {rewards[-1][1]: 4.4f} (r0 = {r0: 4.4f})," + eval_str
)
metrics_to_log["eval/reward"] = eval_reward
metrics_to_log["eval/time"] = eval_time

log_metrics(logger, metrics_to_log, collected_frames)
sampling_start = time.time()

collector.shutdown()
end_time = time.time()
execution_time = end_time - start_time
print(f"Training took {execution_time:.2f} seconds to finish")


if __name__ == "__main__":
Expand Down
Loading
Loading