PPO timeout proper handling #198
Proper timeout handling is definitely an interesting issue I want to look into further. I think the best resource is DLR-RM/stable-baselines3#658. Happy to take a look into this together :) If you are submitting a fix PR, I highly suggest doing it just on
Thanks for the link! I'll take a look and see what I can come up with.
After some thinking and sketching on a piece of paper, it seems to me that it could be solved this way (just a proposal for now):

```python
# define buffers here: states, actions, rewards, dones, values, logprobs

# Initial obs and info
state, prev_info = torch.Tensor(envs.reset()).to(device), {}

for step in range(0, args.num_steps):
    with torch.no_grad():
        action, logprob, _, value = agent.get_action_and_value(state)
    next_state, reward, done, info = envs.step(action.cpu().numpy())

    # full transition on step T
    states[step] = state
    actions[step] = action
    rewards[step] = torch.tensor(reward).to(device).view(-1)
    dones[step] = done
    # also PPO stuff for step T
    logprobs[step] = logprob
    values[step] = value.flatten()

    # Here we should check for a timeout on the previous step
    if step > 0 and torch.any(dones[step - 1]):
        # if the previous step was a timeout, then we should
        # 1. set dones[step - 1][env_id] to False, as it was not a real done
        # 2. set values[step][env_id] to V(prev_info['terminal_observation'])
        #    (as it is the real next_state for the previous step)
        for env_id in range(args.num_envs):
            timeout = "TimeLimit.truncated" in prev_info[env_id] and prev_info[env_id]["TimeLimit.truncated"]
            if timeout:
                terminal_state = torch.tensor(prev_info[env_id]["terminal_observation"], device=device)
                # Set done to False as it was a timeout
                dones[step - 1, env_id] = 0
                # Set value to V of terminal_state (not state, as state is the first
                # state after the reset on the previous step)
                values[step, env_id] = agent.get_value(terminal_state).flatten()

    state, prev_info = torch.Tensor(next_state).to(device), info
```

This small rearrangement would, of course, require separately handling the last state after `num_steps`, since we need it to bootstrap the last transition, but it could be done in a similar way. The GAE computation would then not need to change at all! The main point here is that we never use `values[step]` before the GAE computation, so it seems to me that we can safely swap it in place here. Does that sound reasonable to you?
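A rough sketch (my own, not part of the original proposal) of how the last state after `num_steps` could be handled in the same way; it assumes the final transition is bootstrapped from a `next_value` tensor and from the `state` / `prev_info` variables left over after the loop:

```python
with torch.no_grad():
    # Bootstrap value for the observation left over after the rollout loop.
    next_value = agent.get_value(state).flatten()
    for env_id in range(args.num_envs):
        # Same timeout check as inside the loop, applied to the final step.
        if prev_info[env_id].get("TimeLimit.truncated", False):
            terminal_state = torch.tensor(prev_info[env_id]["terminal_observation"], device=device)
            # The last done was a timeout, not a real termination.
            dones[args.num_steps - 1, env_id] = 0
            # Bootstrap from the true terminal observation, not from the reset state.
            next_value[env_id] = agent.get_value(terminal_state).flatten()
```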
Thanks for the sketch. I think it's a little tricky, though, so I wrote down what the variables look like as below (link). I think the correct implementation is consistent with https://github.com/DLR-RM/stable-baselines3/pull/658/files#diff-384b5f21f2bed58d1d6e64da04a42fee52f353fcec38bf410338524336657bd8R205

```python
# Handle timeout by bootstrapping with value function
# see GitHub issue #633
for idx, done_ in enumerate(dones):
    if (
        done_
        and infos[idx].get("terminal_observation") is not None
        and infos[idx].get("TimeLimit.truncated", False)
    ):
        terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0]
        with th.no_grad():
            terminal_value = self.policy.predict_values(terminal_obs)[0]
        rewards[idx] += self.gamma * terminal_value
```

Is this equivalent to your implementation? I didn't quite get the reason behind

Next week will be pretty busy but will try to respond. CC @araffin
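To spell out the equivalence question a bit (my own worked step, in CleanRL-style GAE notation, not something from this thread): with the reward patched to include the bootstrap term and the done flag left at $d_t = 1$, the truncated step gives

$$
\delta_t = \big(r_t + \gamma V(s_t^{\mathrm{term}})\big) + \gamma\,V(s_{t+1})(1 - d_t) - V(s_t) = r_t + \gamma V(s_t^{\mathrm{term}}) - V(s_t),
\qquad
\hat{A}_t = \delta_t + \gamma\lambda\,(1 - d_t)\,\hat{A}_{t+1} = \delta_t,
$$

i.e. the step is bootstrapped through the time limit while the GAE recursion is still cut at the episode boundary.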
After thinking a bit more, it seems that my implementation is not correct, because it will not correctly account for the done flag in GAE. I don't like the fact that we're implicitly changing the rewards, though, since in other PPO variants they could be used for something else :(

I will still try to explain the thinking behind

We have two variants of value computation for step 1:

Because we will add

Then value estimation will be ok. Actually, you do the same in cleanrl/cleanrl/ppo_continuous_action.py (line 203, commit ee262da).

For now I think your proposal is better; I will try to make a PR & some runs on gym MuJoCo.
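One way to make the "done flag in GAE" point concrete (again my own sketch, not from the thread): with the in-place swap, $d_t$ is set to $0$ and $V(s_{t+1})$ is replaced by $V(s_t^{\mathrm{term}})$, so $\delta_t$ comes out the same as above, but the advantage recursion no longer stops at the truncated step:

$$
\hat{A}_t = \delta_t + \gamma\lambda\,(1 - d_t)\,\hat{A}_{t+1} = \delta_t + \gamma\lambda\,\hat{A}_{t+1},
$$

so advantages computed from the already-reset next episode leak back across the truncation boundary, which keeping $d_t = 1$ and folding the bootstrap into the reward avoids.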
That's a good point. One way we could address it is to create another storage variable.

Also a quick note: the new gym API for truncation is coming soon (openai/gym#2752). Ideally we should be prototyping using the new API.

One issue with the example you provided is storage -
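A minimal sketch of the "another storage variable" idea (the name `bootstrap_values` and the helper are my own placeholders, not anything decided in this thread): keep the stored rewards untouched and accumulate the extra bootstrap term separately, adding it in only when computing advantages.

```python
import torch

# Hypothetical extra buffer: gamma * V(terminal_obs) for steps that ended in a
# truncation, zero everywhere else. Same (num_steps, num_envs) shape as `rewards`.
bootstrap_values = torch.zeros((args.num_steps, args.num_envs), device=device)

def record_truncation(step, env_id, terminal_obs):
    # Call this from the rollout loop whenever env `env_id` is truncated at `step`.
    with torch.no_grad():
        bootstrap_values[step, env_id] = args.gamma * agent.get_value(terminal_obs).flatten()

# At GAE time, feed `rewards + bootstrap_values` into the advantage computation
# instead of `rewards`, so the raw rewards remain available for other uses.
```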
Found this issue as I was wondering whether the isaacgym ppo was handling timeouts; it is not, right? I will check the code more thoroughly tomorrow and try to find a solution for my use case. Additionally, I remember having issues with how isaacgymenvs handles timeouts and resets; I think my env diverges a bit from theirs, so I'm not sure if it is handling their case but not mine.
Yeah, the isaacgym ppo variant does not deal with truncation properly… I sort of lost interest in it because properly handling it did not seem to result in a significant performance difference. See sail-sg/envpool#194 (comment).
@vwxyzjn did you notice any significant overhead from handling truncations when benchmarking?
Not really any significant overhead.
Just chiming in here to agree with this comment in a related thread: I have experienced that proper timeout handling can make quite a difference. Thought I'd share my experience and solution here in case it's useful to others, and to encourage getting it implemented in CleanRL. (In relation to the discussions above, I am not sure about the best way to do this, though.)

Anyway: I have a vectorised environment that I was getting good results with from SB3, but to speed things up I wanted to move both the environment and the RL to GPU, building from CleanRL.

For reference, in my environment the agent can incur a lot of effort penalties early on while it's trying to learn how to get positive rewards, so early on rewards will be mostly negative. My understanding of what was happening here is that without truncation bootstrapping, just giving up and doing nothing becomes a very strong suboptimum, because every now and then the agent still gets a random (from its Markov perspective) treat of reaching the episode truncation, where without bootstrapping the reached state seems to have about zero value, which is much better than the negative values incurred while trying to learn. And the value function learning becomes noisy and slow, because there is this weird zero value that happens unpredictably every now and then. (Apols if this is obvious to others - it took me a while to suss it out.)

I based my bootstrap implementation on the SB3 one, as referenced in Costa's comment above, changing this line:

```python
next_obs, rewards[step], next_done, info = envs.step(action)
```

to:

```python
next_obs, rewards[step], terminations, truncations, info = envs.step(action)
next_done = (terminations | truncations).long()
with torch.no_grad():
    terminal_values = agent.critic(info['final_observation']).squeeze()
rewards[step] = torch.where(
    truncations, rewards[step] + args.gamma * terminal_values, rewards[step])
```

This approach directly modifies the rewards, as discussed in the thread above, but that doesn't matter to me in this case. Note that I am using my own environment here, not an IsaacGym environment, so the
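A caveat for anyone adapting the snippet above to a standard gymnasium vector env rather than a custom GPU env: in the pre-1.0 dict-style vector API, `final_observation` is (as far as I understand it) only populated for the sub-envs that actually finished that step, so a per-env loop along these lines may be needed. Treat this as a sketch; the exact keys and shapes depend on the gymnasium version:

```python
next_obs, reward, terminations, truncations, infos = envs.step(action.cpu().numpy())
rewards[step] = torch.tensor(reward, dtype=torch.float32, device=device).view(-1)
if "final_observation" in infos:
    for env_id, final_obs in enumerate(infos["final_observation"]):
        # Only bootstrap the sub-envs that were truncated (not terminated) this step.
        if truncations[env_id] and final_obs is not None:
            final_obs_t = torch.tensor(final_obs, dtype=torch.float32, device=device)
            with torch.no_grad():
                rewards[step, env_id] += args.gamma * agent.critic(final_obs_t).squeeze()
```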
Hi! I'm a bit puzzled as to how a timeout could be handled correctly in your implementation of PPO (well, this is relevant for all variants). I am especially surprised by envpool, because it seems like they do not return the last real state at all (like gym vec envs do in `info`). Were there any ideas on how to do it right? I'd like to make a PR with that fix, but I haven't figured out how yet. The buffer in PPO is a bit more confusing than in off-policy algos, where there is a clear (s, s', d) thing.
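For context, a small sketch of what "gym vec envs do in `info`" refers to (old gym API, before the terminated/truncated split; details may vary by version): on the step where a sub-env ends, the auto-reset vector env returns the reset observation as `obs` and stashes the true last state in the per-env info dict.

```python
import gym

envs = gym.vector.make("CartPole-v1", num_envs=2)
obs = envs.reset()
for _ in range(500):
    obs, rewards, dones, infos = envs.step(envs.action_space.sample())
    for env_id, done in enumerate(dones):
        if done:
            # `obs[env_id]` is already the first observation of the new episode;
            # the real final state of the old episode only lives in the info dict.
            final_obs = infos[env_id]["terminal_observation"]
            truncated = infos[env_id].get("TimeLimit.truncated", False)
```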