diff --git a/.gitignore b/.gitignore index 21d3acc8c..46b806b7c 100644 --- a/.gitignore +++ b/.gitignore @@ -13,6 +13,7 @@ cleanrl/videos/* benchmark/**/*.svg benchmark/**/*.pkl mjkey.txt +.idea # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/benchmark/apo.sh b/benchmark/apo.sh new file mode 100644 index 000000000..0933d3e99 --- /dev/null +++ b/benchmark/apo.sh @@ -0,0 +1,28 @@ +export WANDB_ENTITY=openrlbenchmark + +poetry install -E "mujoco pybullet" +python -c "import mujoco_py" + +#xvfb-run -a python -m cleanrl_utils.benchmark \ +# --env-ids Swimmer-v3 \ +# --command "poetry run python cleanrl/apo_continuous_action.py --track --capture-video --gae-lambda=0.99" \ +# --num-seeds 3 \ +# --workers 3 + +#xvfb-run -a python -m cleanrl_utils.benchmark \ +# --env-ids HalfCheetah-v3 \ +# --command "poetry run python cleanrl/apo_continuous_action.py --track --capture-video --gae-lambda=0.9" \ +# --num-seeds 3 \ +# --workers 3 + +#xvfb-run -a python -m cleanrl_utils.benchmark \ +# --env-ids Ant-v3 \ +# --command "poetry run python cleanrl/apo_continuous_action.py --track --capture-video --gae-lambda=0.9" \ +# --num-seeds 3 \ +# --workers 3 + +xvfb-run -a python -m cleanrl_utils.benchmark \ + --env-ids Hopper-v3 \ + --command "poetry run python cleanrl/apo_continuous_action.py --track --capture-video --gae-lambda=0.99" \ + --num-seeds 3 \ + --workers 3 \ No newline at end of file diff --git a/cleanrl/apo_continuous_action.py b/cleanrl/apo_continuous_action.py new file mode 100644 index 000000000..247989f93 --- /dev/null +++ b/cleanrl/apo_continuous_action.py @@ -0,0 +1,325 @@ +import argparse +import os +import random +import time +from distutils.util import strtobool + +import gym +import numpy as np +import pybullet_envs # noqa +import torch +import torch.nn as nn +import torch.optim as optim +from torch.distributions.normal import Normal +from torch.utils.tensorboard import SummaryWriter + + +def parse_args(): + # fmt: off + parser = argparse.ArgumentParser() + parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"), + help="the name of this experiment") + parser.add_argument("--seed", type=int, default=1, + help="seed of the experiment") + parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="if toggled, `torch.backends.cudnn.deterministic=False`") + parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="if toggled, cuda will be enabled by default") + parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, + help="if toggled, this experiment will be tracked with Weights and Biases") + parser.add_argument("--wandb-project-name", type=str, default="cleanRL", + help="the wandb's project name") + parser.add_argument("--wandb-entity", type=str, default=None, + help="the entity (team) of wandb's project") + parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, + help="weather to capture videos of the agent performances (check out `videos` folder)") + + # Algorithm specific arguments + parser.add_argument("--env-id", type=str, default="HalfCheetahBulletEnv-v0", + help="the id of the environment") + parser.add_argument("--total-timesteps", type=int, default=1_000_000, + help="total timesteps of the experiments") + parser.add_argument("--learning-rate", type=float, default=3e-4, + help="the learning rate of the optimizer") + parser.add_argument("--num-envs", type=int, default=1, + help="the number of parallel game environments") + parser.add_argument("--num-steps", type=int, default=2048, + help="the number of steps to run in each environment per policy rollout") + parser.add_argument("--anneal-lr", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, + help="Toggle learning rate annealing for policy and value networks") + parser.add_argument("--gae-lambda", type=float, default=0.95, + help="the lambda for the general advantage estimation") + parser.add_argument("--num-minibatches", type=int, default=32, + help="the number of mini-batches") + parser.add_argument("--update-epochs", type=int, default=10, + help="the K epochs to update the policy") + parser.add_argument("--norm-adv", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="Toggles advantages normalization") + parser.add_argument("--clip-coef", type=float, default=0.2, + help="the surrogate clipping coefficient") + parser.add_argument("--ent-coef", type=float, default=0.0, + help="coefficient of the entropy") + parser.add_argument("--vf-coef", type=float, default=0.5, + help="coefficient of the value function") + parser.add_argument("--max-grad-norm", type=float, default=0.5, + help="the maximum norm for the gradient clipping") + parser.add_argument("--target-kl", type=float, default=None, + help="the target KL divergence threshold") + + # APO specific arguments + parser.add_argument("--tau", type=float, default=0.1, + help="exponential moving average coefficient for reward and value rate estimates") + parser.add_argument("--value-constraint", type=float, default=1.0, + help="lagrangian multiplier, controls the bias in value estimation (average value constraint)") + args = parser.parse_args() + args.batch_size = int(args.num_envs * args.num_steps) + args.minibatch_size = int(args.batch_size // args.num_minibatches) + # fmt: on + return args + + +def make_env(env_id, seed, idx, capture_video, run_name): + def thunk(): + env = gym.make(env_id) + env = gym.wrappers.RecordEpisodeStatistics(env) + if capture_video and idx == 0: + env = gym.wrappers.RecordVideo(env, f"videos/{run_name}") + env = gym.wrappers.ClipAction(env) + env = gym.wrappers.NormalizeObservation(env) + env = gym.wrappers.TransformObservation(env, lambda obs: np.clip(obs, -10, 10)) + # IMPORTANT: in average reward setting gamma is always 1.0 + env = gym.wrappers.NormalizeReward(env, gamma=1.0) + env = gym.wrappers.TransformReward(env, lambda reward: np.clip(reward, -10, 10)) + env.seed(seed) + env.action_space.seed(seed) + env.observation_space.seed(seed) + return env + + return thunk + + +def layer_init(layer, std=np.sqrt(2), bias_const=0.0): + torch.nn.init.orthogonal_(layer.weight, std) + torch.nn.init.constant_(layer.bias, bias_const) + return layer + + +class Agent(nn.Module): + def __init__(self, envs): + super().__init__() + self.critic = nn.Sequential( + layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)), + nn.Tanh(), + layer_init(nn.Linear(64, 64)), + nn.Tanh(), + layer_init(nn.Linear(64, 1), std=1.0), + ) + self.actor_mean = nn.Sequential( + layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)), + nn.Tanh(), + layer_init(nn.Linear(64, 64)), + nn.Tanh(), + layer_init(nn.Linear(64, np.prod(envs.single_action_space.shape)), std=0.01), + ) + self.actor_logstd = nn.Parameter(torch.zeros(1, np.prod(envs.single_action_space.shape))) + + def get_value(self, x): + return self.critic(x) + + def get_action_and_value(self, x, action=None): + action_mean = self.actor_mean(x) + action_logstd = self.actor_logstd.expand_as(action_mean) + action_std = torch.exp(action_logstd) + probs = Normal(action_mean, action_std) + if action is None: + action = probs.sample() + return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.critic(x) + + +if __name__ == "__main__": + args = parse_args() + run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}" + if args.track: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=vars(args), + name=run_name, + monitor_gym=True, + save_code=True, + ) + writer = SummaryWriter(f"runs/{run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + + # TRY NOT TO MODIFY: seeding + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.backends.cudnn.deterministic = args.torch_deterministic + + device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") + + # env setup + envs = gym.vector.SyncVectorEnv( + [make_env(args.env_id, args.seed + i, i, args.capture_video, run_name) for i in range(args.num_envs)] + ) + assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported" + + agent = Agent(envs).to(device) + optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5) + + # ALGO Logic: Storage setup + obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device) + actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device) + logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device) + rewards = torch.zeros((args.num_steps, args.num_envs)).to(device) + dones = torch.zeros((args.num_steps, args.num_envs)).to(device) + values = torch.zeros((args.num_steps, args.num_envs)).to(device) + + # IMPORTANT: reward and value rate estimates, they are crucial for the APO algorithm + reward_rate, value_rate = 0.0, 0.0 + + # TRY NOT TO MODIFY: start the game + global_step = 0 + start_time = time.time() + next_obs = torch.Tensor(envs.reset()).to(device) + next_done = torch.zeros(args.num_envs).to(device) + num_updates = args.total_timesteps // args.batch_size + + for update in range(1, num_updates + 1): + # Annealing the rate if instructed to do so. + if args.anneal_lr: + frac = 1.0 - (update - 1.0) / num_updates + lrnow = frac * args.learning_rate + optimizer.param_groups[0]["lr"] = lrnow + + for step in range(0, args.num_steps): + global_step += 1 * args.num_envs + obs[step] = next_obs + dones[step] = next_done + + # ALGO LOGIC: action logic + with torch.no_grad(): + action, logprob, _, value = agent.get_action_and_value(next_obs) + values[step] = value.flatten() + actions[step] = action + logprobs[step] = logprob + + # TRY NOT TO MODIFY: execute the game and log data. + next_obs, reward, done, info = envs.step(action.cpu().numpy()) + rewards[step] = torch.tensor(reward).to(device).view(-1) + next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(done).to(device) + + for item in info: + if "episode" in item.keys(): + print(f"global_step={global_step}, episodic_return={item['episode']['r']}") + writer.add_scalar("charts/episodic_return", item["episode"]["r"], global_step) + writer.add_scalar("charts/episodic_length", item["episode"]["l"], global_step) + writer.add_scalar("charts/episodic_mean_reward", item["episode"]["r"] / item["episode"]["l"], global_step) + break + + # ALGO LOGIC: update reward and value rate estimates + mean_batch_value = values.mean().item() + reward_rate = (1 - args.tau) * reward_rate + args.tau * rewards.mean().item() + value_rate = (1 - args.tau) * value_rate + args.tau * mean_batch_value + + # bootstrap value if not done + with torch.no_grad(): + next_value = agent.get_value(next_obs).reshape(1, -1) + advantages = torch.zeros_like(rewards).to(device) + lastgaelam = 0 + for t in reversed(range(args.num_steps)): + if t == args.num_steps - 1: + nextnonterminal = 1.0 - next_done + nextvalues = next_value + else: + nextnonterminal = 1.0 - dones[t + 1] + nextvalues = values[t + 1] + # IMPORTANT: GAE for average reward value function, for formulas see APO paper + delta = rewards[t] - reward_rate + nextnonterminal * nextvalues - values[t] + advantages[t] = lastgaelam = delta + args.gae_lambda * nextnonterminal * lastgaelam + # A = R - V => R = A + V + returns = advantages + values + + # flatten the batch + b_obs = obs.reshape((-1,) + envs.single_observation_space.shape) + b_logprobs = logprobs.reshape(-1) + b_actions = actions.reshape((-1,) + envs.single_action_space.shape) + b_advantages = advantages.reshape(-1) + b_returns = returns.reshape(-1) + b_values = values.reshape(-1) + + # Optimizing the policy and value network + b_inds = np.arange(args.batch_size) + clipfracs = [] + for epoch in range(args.update_epochs): + np.random.shuffle(b_inds) + for start in range(0, args.batch_size, args.minibatch_size): + end = start + args.minibatch_size + mb_inds = b_inds[start:end] + + _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions[mb_inds]) + logratio = newlogprob - b_logprobs[mb_inds] + ratio = logratio.exp() + + with torch.no_grad(): + # calculate approx_kl http://joschu.net/blog/kl-approx.html + old_approx_kl = (-logratio).mean() + approx_kl = ((ratio - 1) - logratio).mean() + clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()] + + mb_advantages = b_advantages[mb_inds] + if args.norm_adv: + mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8) + + # Policy loss + pg_loss1 = -mb_advantages * ratio + pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef) + pg_loss = torch.max(pg_loss1, pg_loss2).mean() + + # Value loss + newvalue = newvalue.view(-1) + # IMPORTANT: average value constraint as it is called in the paper + target_return = b_returns[mb_inds] - args.value_constraint * value_rate + v_loss = 0.5 * ((newvalue - target_return) ** 2).mean() + + entropy_loss = entropy.mean() + loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef + + optimizer.zero_grad() + loss.backward() + nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm) + optimizer.step() + + if args.target_kl is not None: + if approx_kl > args.target_kl: + break + + y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy() + var_y = np.var(y_true) + explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y + + # TRY NOT TO MODIFY: record rewards for plotting purposes + writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step) + writer.add_scalar("losses/value_loss", v_loss.item(), global_step) + writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step) + writer.add_scalar("losses/entropy", entropy_loss.item(), global_step) + writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step) + writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step) + writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step) + writer.add_scalar("losses/explained_variance", explained_var, global_step) + print("SPS:", int(global_step / (time.time() - start_time))) + writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step) + # APO specific metrics + writer.add_scalar("losses/reward_rate", reward_rate, global_step) + writer.add_scalar("losses/value_rate", value_rate, global_step) + writer.add_scalar("losses/mean_batch_value", mean_batch_value, global_step) + + envs.close() + writer.close()