Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] GAIL compatibility with compile #2573

Open
wants to merge 28 commits into
base: gh/vmoens/42/base
Choose a base branch
from
Open
5 changes: 5 additions & 0 deletions sota-implementations/gail/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ gail:
gp_lambda: 10.0
device: null

compile:
compile: False
compile_mode:
cudagraphs: False

replay_buffer:
dataset: halfcheetah-expert-v2
batch_size: 256
152 changes: 88 additions & 64 deletions sota-implementations/gail/gail.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@
from ppo_utils import eval_model, make_env, make_ppo_models
from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
from tensordict.nn import CudaGraphModule
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement

from torchrl.envs import set_gym_backend
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.objectives import ClipPPOLoss, GAILLoss
from torchrl.objectives import ClipPPOLoss, GAILLoss, group_optimizers
from torchrl.objectives.value.advantages import GAE
from torchrl.record import VideoRecorder
from torchrl.record.loggers import generate_exp_name, get_logger
Expand Down Expand Up @@ -69,20 +70,9 @@ def main(cfg: "DictConfig"): # noqa: F821
np.random.seed(cfg.env.seed)

# Create models (see make_ppo_models in ppo_utils.py)
actor, critic = make_ppo_models(cfg.env.env_name)
actor, critic = make_ppo_models(cfg.env.env_name, compile=cfg.compile.compile)
actor, critic = actor.to(device), critic.to(device)

# Create collector
collector = SyncDataCollector(
create_env_fn=make_env(cfg.env.env_name, device),
policy=actor,
frames_per_batch=cfg.ppo.collector.frames_per_batch,
total_frames=cfg.ppo.collector.total_frames,
device=device,
storing_device=device,
max_frames_per_traj=-1,
)

# Create data buffer
data_buffer = TensorDictReplayBuffer(
storage=LazyMemmapStorage(cfg.ppo.collector.frames_per_batch),
Expand Down Expand Up @@ -111,6 +101,30 @@ def main(cfg: "DictConfig"): # noqa: F821
# Create optimizers
actor_optim = torch.optim.Adam(actor.parameters(), lr=cfg.ppo.optim.lr, eps=1e-5)
critic_optim = torch.optim.Adam(critic.parameters(), lr=cfg.ppo.optim.lr, eps=1e-5)
optim = group_optimizers(actor_optim, critic_optim)
del actor_optim, critic_optim

compile_mode = None
if cfg.compile.compile:
compile_mode = cfg.compile.compile_mode
if compile_mode in ("", None):
if cfg.compile.cudagraphs:
compile_mode = "default"
else:
compile_mode = "reduce-overhead"

# Create collector
collector = SyncDataCollector(
create_env_fn=make_env(cfg.env.env_name, device),
policy=actor,
frames_per_batch=cfg.ppo.collector.frames_per_batch,
total_frames=cfg.ppo.collector.total_frames,
device=device,
storing_device=device,
max_frames_per_traj=-1,
compile_policy={"mode": compile_mode} if compile_mode is not None else False,
cudagraph_policy=cfg.compile.cudagraphs,
)

# Create replay buffer
replay_buffer = make_offline_replay_buffer(cfg.replay_buffer)
Expand Down Expand Up @@ -138,32 +152,9 @@ def main(cfg: "DictConfig"): # noqa: F821
VideoRecorder(logger, tag="rendering/test", in_keys=["pixels"])
)
test_env.eval()
num_network_updates = torch.zeros((), dtype=torch.int64, device=device)

# Training loop
collected_frames = 0
num_network_updates = 0
pbar = tqdm.tqdm(total=cfg.ppo.collector.total_frames)

# extract cfg variables
cfg_loss_ppo_epochs = cfg.ppo.loss.ppo_epochs
cfg_optim_anneal_lr = cfg.ppo.optim.anneal_lr
cfg_optim_lr = cfg.ppo.optim.lr
cfg_loss_anneal_clip_eps = cfg.ppo.loss.anneal_clip_epsilon
cfg_loss_clip_epsilon = cfg.ppo.loss.clip_epsilon
cfg_logger_test_interval = cfg.logger.test_interval
cfg_logger_num_test_episodes = cfg.logger.num_test_episodes

for i, data in enumerate(collector):

log_info = {}
frames_in_batch = data.numel()
collected_frames += frames_in_batch
pbar.update(data.numel())

# Update discriminator
# Get expert data
expert_data = replay_buffer.sample()
expert_data = expert_data.to(device)
def update(data, expert_data, num_network_updates=num_network_updates):
# Add collector data to expert data
expert_data.set(
discriminator_loss.tensor_keys.collector_action,
Expand All @@ -176,9 +167,9 @@ def main(cfg: "DictConfig"): # noqa: F821
d_loss = discriminator_loss(expert_data)

# Backward pass
discriminator_optim.zero_grad()
d_loss.get("loss").backward()
discriminator_optim.step()
discriminator_optim.zero_grad(set_to_none=True)

# Compute discriminator reward
with torch.no_grad():
Expand All @@ -188,32 +179,19 @@ def main(cfg: "DictConfig"): # noqa: F821
# Set discriminator rewards to tensordict
data.set(("next", "reward"), d_rewards)

# Get training rewards and episode lengths
episode_rewards = data["next", "episode_reward"][data["next", "done"]]
if len(episode_rewards) > 0:
episode_length = data["next", "step_count"][data["next", "done"]]
log_info.update(
{
"train/reward": episode_rewards.mean().item(),
"train/episode_length": episode_length.sum().item()
/ len(episode_length),
}
)
# Update PPO
for _ in range(cfg_loss_ppo_epochs):

# Compute GAE
with torch.no_grad():
data = adv_module(data)
data_reshape = data.reshape(-1)

# Update the data buffer
data_buffer.empty()
data_buffer.extend(data_reshape)

for _, batch in enumerate(data_buffer):

# Get a data batch
batch = batch.to(device)
for batch in data_buffer:
optim.zero_grad(set_to_none=True)

# Linearly decrease the learning rate and clip epsilon
alpha = 1.0
Expand All @@ -233,20 +211,66 @@ def main(cfg: "DictConfig"): # noqa: F821
actor_loss = loss["loss_objective"] + loss["loss_entropy"]

# Backward pass
actor_loss.backward()
critic_loss.backward()
(actor_loss + critic_loss).backward()

# Update the networks
actor_optim.step()
critic_optim.step()
actor_optim.zero_grad()
critic_optim.zero_grad()
optim.step()
return d_loss.detach()

if cfg.compile.compile:
update = torch.compile(update, mode=compile_mode)
if cfg.compile.cudagraphs:
warnings.warn(
"CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
category=UserWarning,
)
update = CudaGraphModule(update, warmup=50)

# Training loop
collected_frames = 0
pbar = tqdm.tqdm(total=cfg.ppo.collector.total_frames)

# extract cfg variables
cfg_loss_ppo_epochs = cfg.ppo.loss.ppo_epochs
cfg_optim_anneal_lr = cfg.ppo.optim.anneal_lr
cfg_optim_lr = cfg.ppo.optim.lr
cfg_loss_anneal_clip_eps = cfg.ppo.loss.anneal_clip_epsilon
cfg_loss_clip_epsilon = cfg.ppo.loss.clip_epsilon
cfg_logger_test_interval = cfg.logger.test_interval
cfg_logger_num_test_episodes = cfg.logger.num_test_episodes

for i, data in enumerate(collector):

log_info = {}
frames_in_batch = data.numel()
collected_frames += frames_in_batch
pbar.update(data.numel())

# Update discriminator
# Get expert data
expert_data = replay_buffer.sample()
expert_data = expert_data.to(device)

d_loss = update(data, expert_data)

# Get training rewards and episode lengths
episode_rewards = data["next", "episode_reward"][data["next", "done"]]
if len(episode_rewards) > 0:
episode_length = data["next", "step_count"][data["next", "done"]]

log_info.update(
{
"train/reward": episode_rewards.mean().item(),
"train/episode_length": episode_length.sum().item()
/ len(episode_length),
}
)

log_info.update(
{
"train/actor_loss": actor_loss.item(),
"train/critic_loss": critic_loss.item(),
"train/discriminator_loss": d_loss["loss"].item(),
# "train/actor_loss": actor_loss.item(),
# "train/critic_loss": critic_loss.item(),
"train/discriminator_loss": d_loss["loss"],
"train/lr": alpha * cfg_optim_lr,
"train/clip_epsilon": (
alpha * cfg_loss_clip_epsilon
Expand Down
7 changes: 4 additions & 3 deletions sota-implementations/gail/ppo_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def make_env(env_name="HalfCheetah-v4", device="cpu", from_pixels: bool = False)
# --------------------------------------------------------------------


def make_ppo_models_state(proof_environment):
def make_ppo_models_state(proof_environment, compile):

# Define input shape
input_shape = proof_environment.observation_spec["observation"].shape
Expand All @@ -54,6 +54,7 @@ def make_ppo_models_state(proof_environment):
"low": proof_environment.action_spec_unbatched.space.low,
"high": proof_environment.action_spec_unbatched.space.high,
"tanh_loc": False,
"safe_tanh": not compile,
}

# Define policy architecture
Expand Down Expand Up @@ -116,9 +117,9 @@ def make_ppo_models_state(proof_environment):
return policy_module, value_module


def make_ppo_models(env_name, compile=False):
    """Build the PPO actor and critic for the environment named ``env_name``.

    A throwaway "proof" environment is created on CPU with ``make_env`` and
    handed to ``make_ppo_models_state``, which reads its specs to size the
    networks.

    Args:
        env_name: name of the environment passed to ``make_env``.
        compile: whether the models are going to be compiled. Forwarded to
            ``make_ppo_models_state``, where it sets ``safe_tanh = not compile``
            in the policy distribution kwargs. Defaults to ``False`` so that
            callers using the original one-argument form keep working and
            keep the previous (``safe_tanh=True``) behaviour.

    Returns:
        The ``(actor, critic)`` pair returned by ``make_ppo_models_state``.
    """
    # NOTE: ``compile`` shadows the builtin of the same name; the name is kept
    # because callers pass it by keyword (``compile=cfg.compile.compile``).
    proof_environment = make_env(env_name, device="cpu")
    actor, critic = make_ppo_models_state(proof_environment, compile=compile)
    return actor, critic


Expand Down
Loading