[BUG] Incorrect reset handling in collectors #937

Closed
3 tasks done
btx0424 opened this issue Feb 26, 2023 · 33 comments · Fixed by #938

@btx0424
Contributor

btx0424 commented Feb 26, 2023

Describe the bug

After auto-resetting an environment with _reset_if_necessary, the initial obs returned by the reset is ignored: the obs actually seen by the policy at the next step is always zeros.

To Reproduce

Here we use a dummy env where the obs is just the time stamp (starting from 1).

from torchrl.collectors import SyncDataCollector
from torchrl.envs import EnvBase
from tensordict import TensorDict
from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec
import torch

class DummyEnv(EnvBase):
    def __init__(
        self, 
        device = "cpu", 
        dtype = None, 
        batch_size = None, 
        run_type_checks: bool = True
    ):
        super().__init__(device, dtype, batch_size, run_type_checks)
        self.observation_spec = CompositeSpec({
            "time": UnboundedContinuousTensorSpec((*batch_size, 1)),
        }, shape=batch_size)
        self.action_spec = UnboundedContinuousTensorSpec((*batch_size, 1))
        self.reward_spec = UnboundedContinuousTensorSpec((*batch_size, 1,))
        self.time: torch.Tensor = torch.zeros(*self.batch_size, 1, device=self.device)
    
    def _step(self, tensordict: TensorDict) -> TensorDict:
        result = TensorDict({
            "time": self.time.clone(),
            "reward": self.reward_spec.rand(),
            "done": self.time > 4,
        }, self.batch_size)
        self.time += 1
        return result
    
    def _reset(self, tensordict: TensorDict, **kwargs) -> TensorDict:
        if tensordict is not None:
            reset_mask = tensordict.get("_reset")
            self.time[reset_mask] = 1
        else:
            # reset all envs
            self.time[:] = 1
        result = TensorDict({
            "time": self.time.clone()
        }, self.batch_size)
        self.time += 1
        return result

    def _set_seed(self, seed):
        torch.manual_seed(seed)

if __name__ == "__main__":
    batch_size = [4]
    env = DummyEnv(batch_size=batch_size)
    def policy(tensordict: TensorDict):
        if "collector" in tensordict.keys():
            step = tensordict[("collector", "step_count")]
            print(f'step: {step}, obs: {tensordict["time"].squeeze(-1)}')
        tensordict.set("action", env.action_spec.rand())
        return tensordict

    collector = SyncDataCollector(
        env, policy, split_trajs=False, frames_per_batch=4 * 10
    )

    for i, data in enumerate(collector):
        break

Running the above code gives

step: tensor([0, 0, 0, 0], dtype=torch.int32), obs: tensor([1., 1., 1., 1.]) # the first episode is correct
step: tensor([1, 1, 1, 1], dtype=torch.int32), obs: tensor([2., 2., 2., 2.])
step: tensor([2, 2, 2, 2], dtype=torch.int32), obs: tensor([3., 3., 3., 3.])
step: tensor([3, 3, 3, 3], dtype=torch.int32), obs: tensor([4., 4., 4., 4.])
step: tensor([0, 0, 0, 0], dtype=torch.int32), obs: tensor([0., 0., 0., 0.]) # the policy should have never seen zeros
step: tensor([1, 1, 1, 1], dtype=torch.int32), obs: tensor([2., 2., 2., 2.])
step: tensor([2, 2, 2, 2], dtype=torch.int32), obs: tensor([3., 3., 3., 3.])
step: tensor([3, 3, 3, 3], dtype=torch.int32), obs: tensor([4., 4., 4., 4.])
step: tensor([0, 0, 0, 0], dtype=torch.int32), obs: tensor([0., 0., 0., 0.]) # the policy should have never seen zeros
step: tensor([1, 1, 1, 1], dtype=torch.int32), obs: tensor([2., 2., 2., 2.])

Expected behavior

step: tensor([0, 0, 0, 0], dtype=torch.int32), obs: tensor([1., 1., 1., 1.]) # the first episode is correct
step: tensor([1, 1, 1, 1], dtype=torch.int32), obs: tensor([2., 2., 2., 2.])
step: tensor([2, 2, 2, 2], dtype=torch.int32), obs: tensor([3., 3., 3., 3.])
step: tensor([3, 3, 3, 3], dtype=torch.int32), obs: tensor([4., 4., 4., 4.])
step: tensor([0, 0, 0, 0], dtype=torch.int32), obs: tensor([1., 1., 1., 1.]) # consistent with the first episode
step: tensor([1, 1, 1, 1], dtype=torch.int32), obs: tensor([2., 2., 2., 2.])
step: tensor([2, 2, 2, 2], dtype=torch.int32), obs: tensor([3., 3., 3., 3.])
step: tensor([3, 3, 3, 3], dtype=torch.int32), obs: tensor([4., 4., 4., 4.])
step: tensor([0, 0, 0, 0], dtype=torch.int32), obs: tensor([1., 1., 1., 1.]) # consistent with the first episode
step: tensor([1, 1, 1, 1], dtype=torch.int32), obs: tensor([2., 2., 2., 2.])

Reason and Possible fixes

This occurs because in SyncDataCollector:

def rollout(self):
    ...
    self._reset_if_necessary()
    self._tensordict.update(step_mdp(self._tensordict), inplace=True) 
    ...

def _reset_if_necessary(self):
    ...
    if done_or_terminated.any():
        traj_ids = self._tensordict.get(("collector", "traj_ids")).clone()
        steps = steps.clone()
        if len(self.env.batch_size):
            self._tensordict.masked_fill_(done_or_terminated, 0)
            _reset = done_or_terminated
            self._tensordict.set("_reset", _reset)
        else:
            _reset = None
            self._tensordict.zero_()
        self.env.reset(self._tensordict)
    ...

The initial obs of the new episode gets discarded by step_mdp because it is not in self._tensordict["next"]. What the policy will see is the zeros set by self._tensordict.masked_fill_(done_or_terminated, 0).
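
To see the mechanism, here is a toy, pure-tensordict illustration of what happens (this mimics the conceptual effect of step_mdp; it is not TorchRL's actual implementation):

import torch
from tensordict import TensorDict

# collector state right after _reset_if_necessary: the env wrote the fresh obs at the
# root of the tensordict, while "next" still holds the values zeroed by masked_fill_
td = TensorDict({
    "time": torch.ones(4, 1),                              # fresh obs from reset
    "next": TensorDict({"time": torch.zeros(4, 1)}, [4]),  # stale, zeroed "next" obs
}, [4])

# stepping the MDP conceptually promotes the "next" entries to the root
stepped = td.get("next").clone()
print(stepped["time"])  # zeros: the reset obs written at the root is dropped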

The most straightforward fix is to change the above to:

def _reset_if_necessary(self):
    ...
    if done_or_terminated.any():
        traj_ids = self._tensordict.get(("collector", "traj_ids")).clone()
        steps = steps.clone()
        if len(self.env.batch_size):
            self._tensordict.masked_fill_(done_or_terminated, 0)
            _reset = done_or_terminated
            self._tensordict.set(("next", "_reset"), _reset)
        else:
            _reset = None
            self._tensordict.zero_()
        self.env.reset(self._tensordict["next"])
    ...

so that the initial obs gets carried over to the next step by step_mdp.

However, this would break some tests, e.g., test_collector.py::test_traj_len_consistency, because now we have ("next", "done") in the keys after the first reset, which causes a key inconsistency when doing torch.cat.

I recall that earlier versions of torchrl required step_mdp(env.reset()). Here I wonder whether the coexistence of "done" and ("next", "done") makes sense. I personally think having both is more rigorous: "done" indicates whether this step is an initial step, and ("next", "done") indicates whether the episode terminates after this step (IIUC currently we only have "done" for the latter). That way, inside an RNN policy module, we can decide whether we need to reset some of its hidden states.

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)
btx0424 added the bug (Something isn't working) label on Feb 26, 2023
@vmoens
Contributor

vmoens commented Feb 26, 2023

Thanks for reporting this.
I think we should first fix the bug and then consider having ("next", "done") along with "done".
It's a challenging one: there are some conventions, like r_t goes with s_t, a_t, and only s_{t+1} is considered the "next" value. Whether the related final_t is a property of t or of t+1 is not super clear IMO.
IIUC what you're suggesting, we should move the current done to the next state, is that right?
I'm just afraid that it would not be obvious for all users. But I'm open to discussing it...

I'll make a PR to fix this bug but feel free to keep the conversation going regarding "done".

@btx0424
Contributor Author

btx0424 commented Feb 26, 2023

Yep, that's challenging and requires a lot of code changes, including the objective and loss modules.

But I do think having "done" and ("next", "done") mean different things is sometimes favorable, e.g., resetting a recurrent module to some non-trivial hidden states would require identifying whether the current step is just after a reset by checking input_td["done"].

Currently the "done" return by env.reset is not very clear in its meaning since no transition has taken place yet.

BTW, a test for this:

@pytest.mark.parametrize("env_class", [MockSerialEnv, MockBatchedLockedEnv])
def test_initial_obs_consistency(
    env_class, seed=1
):
    if env_class == MockSerialEnv:
        num_envs = 1
        env = MockSerialEnv(device="cpu")
    elif env_class == MockBatchedLockedEnv:
        num_envs = 2
        env = MockBatchedLockedEnv(device="cpu", batch_size=[num_envs])
    env.set_seed(seed)
    collector = SyncDataCollector(
        create_env_fn=env,
        frames_per_batch=(env.max_val * 2 + 2) * num_envs, # at least two episodes
        split_trajs=False
    )
    for _, d in enumerate(collector):
        break
    obs = d["observation"].squeeze()
    arange = torch.arange(1, collector.env.counter).float().expand_as(obs)
    assert torch.allclose(obs, arange)

@vmoens
Contributor

vmoens commented Feb 27, 2023

After due consideration I'm open to it.
Bringing a few folks into the discussion for visibility:
We're considering moving away from a step that returns

TensorDict({
    "done": torch.Tensor(...),
    "action": torch.Tensor(...),
    "reward": torch.Tensor(...),
    "next": TensorDict({
        "observation": torch.Tensor(...),
    }, batch_size),
}, batch_size)

to one where the "done" state is in the "next" tensordict:

TensorDict({
    "action": torch.Tensor(...),
    "reward": torch.Tensor(...),
    "next": TensorDict({
        "done": torch.Tensor(...),
        "observation": torch.Tensor(...),
    }, batch_size),
}, batch_size)

This is a major BC-breaking change, meaning that if we don't do it now (before beta) it'll be more difficult to bring it later.

This has several advantages, but the main point is that "done" is a property of the next state (i.e., it is final) and not of the current state. See above for the full discussion.

cc @shagunsodhani @matteobettini @albertbou92 @Benjamin-eecs @riiswa @smorad @XuehaiPan

@btx0424
Contributor Author

btx0424 commented Feb 27, 2023

For policy input consistency, I suggest having reset also return a done of Trues so that we can do

hidden_state[done] = init_hidden_state

either inside the policy or in the collector.
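
For illustration, a self-contained version of that masking in plain torch (hypothetical shapes and names, not tied to a specific TorchRL module):

import torch

batch, hidden_dim = 4, 8
hidden_state = torch.randn(batch, hidden_dim)
init_hidden_state = torch.zeros(hidden_dim)

# True where the current step starts a new episode (e.g., the done returned by reset)
done = torch.tensor([True, False, False, True])
hidden_state[done] = init_hidden_state  # re-initialize only the rows that were reset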

@vmoens
Contributor

vmoens commented Feb 27, 2023

For policy input consistency, I suggest having reset also return a done of Trues so that we can do

hidden_state[done] = init_hidden_state

either inside the policy or in the collector.

Isn't it sufficient that ("next", "done") is True? This way, two consecutive trajectories are clearly delimited.
In other words: we don't need to set ("done",) to True since it already is.

@smorad
Contributor

smorad commented Feb 27, 2023

Would it make sense to have done in both the current timestep and next timestep TensorDicts, similar to how observations are treated? It might be confusing to diverge from how gym/dm_env/Sutton&Barto treat done, unless there is a compelling reason to do so.

@matteobettini
Contributor

matteobettini commented Feb 27, 2023

So, just to recap and think about this:

Before the step we have (o_t, done_t, r_t, a_t), which means that we saw state (o_t, done_t, r_t) and took action a_t.

TensorDict({
    "done": torch.Tensor(...),
    "action": torch.Tensor(...),
    "reward": torch.Tensor(...),
    "observation": torch.Tensor(...),
}, batch_size)

Now, after the step, we should get (o_t+1, done_t+1, r_t+1), which, paired with the previous data, would be a td like

TensorDict({
    "done": torch.Tensor(...),
    "action": torch.Tensor(...),
    "reward": torch.Tensor(...),
    "observation": torch.Tensor(...),
    "next": TensorDict({
        "observation": torch.Tensor(...),
        "done": torch.Tensor(...),
        "reward": torch.Tensor(...),
    }, batch_size),
}, batch_size)

Now, we can take the new action

TensorDict({
    "done": torch.Tensor(...),
    "action": torch.Tensor(...),
    "reward": torch.Tensor(...),
    "observation": torch.Tensor(...),
    "next": TensorDict({
        "observation": torch.Tensor(...),
        "done": torch.Tensor(...),
        "reward": torch.Tensor(...),
        "action": torch.Tensor(...),
    }, batch_size),
}, batch_size)

And finally step the MDP to go back to the top of this comment

TensorDict({
    "done": torch.Tensor(...),
    "reward": torch.Tensor(...),
    "observation": torch.Tensor(...),
    "action": torch.Tensor(...),
}, batch_size)

The intermediate view

TensorDict({
    "done": torch.Tensor(...),
    "action": torch.Tensor(...),
    "reward": torch.Tensor(...),
    "observation": torch.Tensor(...),
    "next": TensorDict({
        "observation": torch.Tensor(...),
        "done": torch.Tensor(...),
        "reward": torch.Tensor(...),
        "action": torch.Tensor(...),
    }, batch_size),
}, batch_size)

contains all the info we could possibly want.
The action could also be taken after stepping the MDP. In that case it would just have access to less info.

@matteobettini
Contributor

matteobettini commented Feb 27, 2023

After due consideration I'm open to it. [...] We're considering moving away from a step that returns "done" at the root of the tensordict to one where the "done" state is in the "next" tensordict. [...] This has several advantages but mainly the point is that "done" is a property of the next state (i.e., it is final) and not of the current state.

Commenting on this, I think reward should go together with done and obs, whatever we do.

@vmoens
Contributor

vmoens commented Feb 27, 2023

Now, after the step, we should get (o_t+1, done_t+1, r_t+1), which, paired with the previous data, would be a td like

TensorDict({
    "done": torch.Tensor(...),
    "action": torch.Tensor(...),
    "reward": torch.Tensor(...),
    "observation": torch.Tensor(...),
    "next": TensorDict({
        "observation": torch.Tensor(...),
        "done": torch.Tensor(...),
        "reward": torch.Tensor(...),
    }, batch_size),
}, batch_size)

I think reward is still at t

When we write RL papers we usually say s_t, a_t, r_t, s_{t+1}
Not sure about terminal. The suggestion is to say "done belongs to t+1". IMO reward should stay out of "next".

Besides, reset will not return a "reward" but it can return a "done"

EDIT:
Happy to move r_t to r_{t+1}.
But then do we need reset to fill in a reward too? What sense does that have? I can see why an env would be immediately done, but the reward is linked to an action, no?

@matteobettini
Contributor

matteobettini commented Feb 27, 2023

[Figure: agent-environment interaction diagram, Sutton & Barto 2017]

I mean one could argue it is a subjective choice, but this was decided by Sutton and Barto and picked up by OpenAI Gym, so I would not break it. In the book they always make rewards, dones and obs go together. And I think that is what gym wanted to do with the values returned by step.

@vmoens
Contributor

vmoens commented Feb 27, 2023

[Figure: agent-environment interaction diagram, Sutton & Barto 2017]

I mean one could argue it is a subjective choice, but this was decided by Sutton and Barto and picked up by OpenAI Gym, so I would not break it. In the book they always make rewards, dones and obs go together. And I think that is what gym wanted to do with the values returned by step.

Reset could omit returning a reward, but it would still return done and obs since the reset could be partial and only some envs may actually have been reset

See my edited comment above

@matteobettini
Contributor

matteobettini commented Feb 27, 2023

EDIT: Happy to move r_{t} to r_{t+1} But then do we need reset to fill a reward too? What sense does it have? I can see why an env would be immediately done but the reward is linked to an action no?

Would reset need to return a reward in that case? Couldn't it just return done and obs? Done is needed as a return value because the reset could be partial.

@vmoens
Contributor

vmoens commented Feb 27, 2023

Yeah maybe we can do without.
So for a trajectory we'd have

TensorDict({
   "obs": torch.Tensor([T, ...]),
   "action": torch.Tensor([T, ...]),
   "done": torch.Tensor([T, ...]),
   "next": TensorDict({
        "obs": torch.Tensor([T, ...]),
        "done": torch.Tensor([T, ...]),
        "reward": torch.Tensor([T, ...]),
    }, batch_size=[T]),
}, batch_size=[T])
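
As an illustration of how a loss would consume this layout, a minimal sketch (toy tensors and a stand-in critic value, not TorchRL code) computing a TD(0) target by reading reward and done from the "next" sub-tensordict:

import torch
from tensordict import TensorDict

T = 5
traj = TensorDict({
    "obs": torch.randn(T, 3),
    "action": torch.randn(T, 1),
    "done": torch.zeros(T, 1, dtype=torch.bool),
    "next": TensorDict({
        "obs": torch.randn(T, 3),
        "done": torch.zeros(T, 1, dtype=torch.bool),
        "reward": torch.randn(T, 1),
    }, [T]),
}, [T])

gamma = 0.99
next_value = torch.randn(T, 1)  # stand-in for V(s_{t+1}) from a critic
target = traj["next", "reward"] + gamma * (~traj["next", "done"]).float() * next_value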

@giadefa
Contributor

giadefa commented Feb 27, 2023 via email

@matteobettini
Contributor

Yeah maybe we can do without. So for a trajectory we'd have

TensorDict({
   "obs": torch.Tensor([T, ...]),
   "action": torch.Tensor([T, ...]),
   "done": torch.Tensor([T, ...]),
   "next": TensorDict({
        "obs": torch.Tensor([T, ...]),
        "done": torch.Tensor([T, ...]),
        "reward": torch.Tensor([T, ...]),
    }, batch_size=[T]),
}, batch_size=[T])

This looks good to me

@vmoens
Contributor

vmoens commented Feb 27, 2023

At this point, the reward specs should go in the observation_spec, like action is in input_spec

@matteobettini
Contributor

matteobettini commented Feb 27, 2023

At this point, the reward specs should go in the observation_spec, like action is in input_spec

I mean it could be an output_spec which is a composite of the reward, obs and **info specs.

Reset would return output to fit the obs_spec and **info specs.
Step would return output to fit the output_spec.
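
For concreteness, a hypothetical sketch of what such a composite spec could look like, reusing the spec classes from the reproduction script above (the name output_spec and the shapes here are illustrative assumptions, not an existing attribute):

from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec

batch_size = (4,)
# hypothetical composite gathering everything a step outputs: obs, reward (and info specs)
output_spec = CompositeSpec({
    "observation": UnboundedContinuousTensorSpec((*batch_size, 3)),
    "reward": UnboundedContinuousTensorSpec((*batch_size, 1)),
}, shape=batch_size)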

@shagunsodhani
Contributor

Following the standard convention in RL:

  • obs, done and reward should go together (seconding @matteobettini )

  • Regarding @vmoens comment "But then do we need reset to fill a reward too? What sense does it have? I can see why an env would be immediately done but the reward is linked to an action no?", one could argue that the action (at reset) is no-op so the reward = f(obs, no-op). Having said that, I would prefer not returning a reward at reset as this seems to be the common convention.

Maybe I missed something but is the plan to go with

TensorDict({
   "obs": torch.Tensor([T, ...]),
   "action": torch.Tensor([T, ...]),
   "done": torch.Tensor([T, ...]),
   "next": TensorDict({
        "obs": torch.Tensor([T, ...]),
        "done": torch.Tensor([T, ...]),
        "reward": torch.Tensor([T, ...]),
    }, batch_size=[T]),
}, batch_size=[T])

I thought we wanted to keep reward next to obs, so I would have expected this:

TensorDict({
   "obs": torch.Tensor([T, ...]),
   "action": torch.Tensor([T, ...]),
   "done": torch.Tensor([T, ...]),
   "reward": torch.Tensor([T, ...]),
   "next": TensorDict({
        "obs": torch.Tensor([T, ...]),
        "done": torch.Tensor([T, ...]),     
    }, batch_size=[T]),
}, batch_size=[T])

@matteobettini
Contributor

matteobettini commented Mar 1, 2023

Following the standard convention in RL: obs, done and reward should go together (seconding @matteobettini). [...] Maybe I missed something but is the plan to go with the first layout above, or to keep reward next to obs at the root as in the second layout?

I think the solution proposed by Vincent (the first layout in your comment):

TensorDict({
   "obs": torch.Tensor([T, ...]),
   "action": torch.Tensor([T, ...]),
   "done": torch.Tensor([T, ...]),
   "next": TensorDict({
        "obs": torch.Tensor([T, ...]),
        "done": torch.Tensor([T, ...]),
        "reward": torch.Tensor([T, ...]),
    }, batch_size=[T]),
}, batch_size=[T])

does what we all want.

If you think about it, at time step 0 you have

TensorDict({
   "obs": torch.Tensor([1, ...]),
   "action": torch.Tensor([1, ...]),
   "done": torch.Tensor([1, ...]),
}, batch_size=[1])

which has obs_0 and done_0 coming from reset and action_0, taken for that obs.

You call step with this and you get done_1, obs_1, rew_1, which go in "next".

TensorDict({
   "obs": torch.Tensor([1, ...]),
   "action": torch.Tensor([1, ...]),
   "done": torch.Tensor([1, ...]),
   "next": TensorDict({
        "obs": torch.Tensor([1, ...]),
        "done": torch.Tensor([1, ...]),
        "reward": torch.Tensor([1, ...]),
    }, batch_size=[1]),
}, batch_size=[1])

Then, when you step the mdp, you obtain

TensorDict({
   "obs": torch.Tensor([1, ...]),
   "action": torch.Tensor([1, ...]),
   "done": torch.Tensor([1, ...]),
}, batch_size=[1])

again, which now contains obs_1 and done_1, with added action_1.
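
To make the flow above concrete, a toy, pure-tensordict mimic of this convention (a counter env in plain Python; this is not TorchRL's actual step/step_mdp implementation):

import torch
from tensordict import TensorDict

def toy_reset(batch=1):
    # reset returns obs_0 and done_0 only
    return TensorDict({"obs": torch.zeros(batch, 1),
                       "done": torch.zeros(batch, 1, dtype=torch.bool)}, [batch])

def toy_step(td):
    # step writes obs_{t+1}, done_{t+1}, reward_{t+1} under "next"
    next_obs = td["obs"] + 1  # dummy dynamics: obs is a step counter
    td["next"] = TensorDict({
        "obs": next_obs,
        "done": next_obs > 3,
        "reward": torch.ones_like(next_obs),
    }, td.batch_size)
    return td

td = toy_reset()
for t in range(3):
    td["action"] = torch.randn(*td.batch_size, 1)  # policy acts on (obs_t, done_t)
    td = toy_step(td)
    td = td["next"].exclude("reward").clone()      # "step the MDP": carry obs/done to t+1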

@shagunsodhani
Contributor

Sorry, I didn't follow. Let's take the example of the gym API and then we will continue with the TorchRL example.

At time 0, we just reset the env so we get just obs. At time t, the agent uses the last observed state [footnote 0], obs_{t-1}, to predict the action to take (a_t), performs that action, reaches a state obs_t, and gets a reward r_t and a done signal d_t. Now, in the standard gym API, obs_t, r_t, d_t are returned by the env.step call.

Now lets extend this to TorchRL

If you think aboout it, at time step 0 you have

 TensorDict({
   "obs": torch.Tensor([1, ...]),
  "action": torch.Tensor([1, ...]),
   "done": torch.Tensor([1, ...]),
}, batch_size=[1])

What is the action here? The agent hasn't taken any action so far. If the agent has taken an action, there should be a reward field as well.

You call step with this and you get done_1, obs_1, rew_1, which go in "next".

TensorDict({
  "obs": torch.Tensor([1, ...]),
  "action": torch.Tensor([1, ...]),
  "done": torch.Tensor([1, ...]),
  "next": TensorDict({
       "obs": torch.Tensor([1, ...]),
       "done": torch.Tensor([1, ...]),
       "reward": torch.Tensor([1, ...]),
   }, batch_size=[1]),
}, batch_size=[1])

Why are we still returning the old obs and the old done?

Then, when you step the mdp, you obtain

TensorDict({
   "obs": torch.Tensor([1, ...]),
   "action": torch.Tensor([1, ...]),
   "done": torch.Tensor([1, ...]),
}, batch_size=[1])
again, which now contains obs_1 and done_1, with added action_1.

Just to confirm, we will have the reward here as well, correct?

[0]: or observation, or all observations seen so far etc. I am overloading the notation here.

@matteobettini
Contributor

matteobettini commented Mar 1, 2023

@shagunsodhani

I think you are using a different notation for the action. My notation is (as in the picture from the RL book above): a_t is the action taken when seeing obs_t, so at time 0 that is the action taken on the reset observation.

[Figure: agent-environment interaction diagram, Sutton & Barto 2017]

The reward, done and state obtained after taking action a_t are r_{t+1}, done_{t+1} and s_{t+1}.

So in OpenAI Gym, step(a_t) returns r_{t+1}, done_{t+1} and s_{t+1}.

@shagunsodhani
Contributor

@matteobettini Thanks for the clarification. In that case, is the following flow correct:

at time t = 0, env emits

 TensorDict({
   "obs": torch.Tensor([1, ...]), # say obs_0
   "done": torch.Tensor([1, ...]), # say d_0
}, batch_size=[1])

agent chooses an action and returns

 TensorDict({
   "obs": torch.Tensor([1, ...]), # obs_0
    "action": torch.Tensor([1, ...]), # say a_0
   "done": torch.Tensor([1, ...]), # d_0
}, batch_size=[1])

The action is executed and the env returns

TensorDict({
  "obs": torch.Tensor([1, ...]), # obs_0
  "action": torch.Tensor([1, ...]), # a_0
  "done": torch.Tensor([1, ...]), # d_0
  "next": TensorDict({
       "obs": torch.Tensor([1, ...]), # say obs_1
       "done": torch.Tensor([1, ...]), # say d_1
       "reward": torch.Tensor([1, ...]), # say r_1
   }, batch_size=[1]),
}, batch_size=[1])

Could you also clarify what object the agent sees at time step 1?

@matteobettini
Contributor

matteobettini commented Mar 1, 2023

What you wrote is correct. At timestep 1 the agent sees what you put in "next" (obs_1, d_1):

TensorDict({
   "obs": torch.Tensor([1, ...]), # o_1
   "done": torch.Tensor([1, ...]), # d_1
}, batch_size=[1])

We could also keep r_1 in memory but there is not really a purpose for this. Here we keep d_1 and o_1 in memory because this aligns with the info returned by reset and thus available at the start of a trajectory.

@shagunsodhani
Contributor

Sounds good - thanks for the clarification :)

@vmoens
Contributor

vmoens commented Mar 2, 2023

Something I'm having some trouble figuring out is:
do we need "done" to be True when calling reset?
Here's the rationale in favour:

  • Makes sense when stacking two trajectories one after the other: since traj["next", "done"] at t is the same as traj["done"] at t+1, but t+1 is the result of a reset, having "done" set to True marks this step as the result of a reset.
  • Following this idea, "done" could substitute for "_reset" altogether (wdyt @matteobettini?)
  • As pointed out by @btx0424, that could help with the implementation of RNN policies.

Cons:

  • What if the env is actually done after a reset? In some cases (e.g. using noops reset) we do have envs that are immediately done.

@matteobettini
Contributor

matteobettini commented Mar 2, 2023

@vmoens

We need to keep in mind vectorized and multi-agent environments. If I have a batch of 32 vectorized environments, each with 4 agents let's say, I want to be able to reset just one of the agents in one of the envs, independently of whether they are done or not. So, if I reset just one agent in just one part of the batch, the rest of the done will stay what it was before when returning from the reset. In other words, we need to keep in mind that "done" is multidimensional and the dimensions can be independent and unrelated. For the same reason, a part of the env that is done and not reset has to stay done.

Therefore, I think reset should only depend on the "_reset" flag, which may or may not match the previous done. It will return a truthful done, stating which dimensions are still done after the partial reset (always think of a reset as partial, since we are batched).
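
A minimal sketch of that partial-reset logic in plain torch (illustrative shapes, not TorchRL internals): only the masked entries are re-initialized, and "done" stays truthful for the entries that were not reset.

import torch

done = torch.tensor([[True], [False], [True], [False]])
_reset = torch.tensor([[True], [False], [False], [False]])  # reset only env 0

state = torch.arange(4.0).unsqueeze(-1)
state = torch.where(_reset, torch.zeros_like(state), state)  # re-init only the masked envs
done = done & ~_reset                                        # env 2 was not reset, so it stays done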

@XuehaiPan
Contributor

do we need "done" to be True when calling reset?
Here's the rationale in favour:

  • makes sense when stacking two trajectories one after the other.
  • As pointed by @btx0424 that could help with the implementation of RNN policies.

@vmoens I don't think so. You should never do that if you only have a single boolean flag (i.e., done) when stacking episodes with an RNN or GAE. The semantics of done are a bit ambiguous. Have you ever thought about using the new Gymnasium step API, i.e., splitting the single boolean done flag into two booleans, terminated and truncated, where:

  • terminated: the game is over and needs reset(). Calling step() further will give undefined values.
  • truncated: the time limit was reached; you can call reset. Also, you can still run further step() calls, which will give valid values.

The old done = terminated or truncated.

For value learning (V or Q):

$$
V(s_t) = E_{a_t \sim \pi} \left[ r_t + \gamma \cdot (1 - \text{terminated}_{t+1}) \cdot V(s_{t+1}) \right]
$$
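
As a sanity check on the formula above, a minimal sketch in plain torch of a TD(0) target that bootstraps through truncation but not termination (an illustration only, not one of TorchRL's loss modules):

import torch

def td0_target(reward, terminated, next_value, gamma=0.99):
    # truncated-but-not-terminated steps still bootstrap from V(s_{t+1});
    # terminated steps zero out the bootstrap term, exactly as in the formula above
    return reward + gamma * (1.0 - terminated.float()) * next_value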


@vmoens
Contributor

vmoens commented Mar 2, 2023

It's something we could support but I don't think we want to enforce it.
I personally don't mind, but from the interactions I had with the community, many feel that this change in the gym API was unwanted and harmful. It blocks many users from adopting gymnasium afaict.
Using TorchRL's env API you could easily modify StepCounter to register a new copy of "done" that carries the extra information that you want. I'd rather do that and leave it to users to choose whether they want to differentiate terminated from truncated, rather than imposing it when the community hasn't fully bought into the feature.

@XuehaiPan
Contributor

It's something we could support but I don't think we want to enforce it.
I personally don't mind, but from the interactions I had with the community, many feel that this change in the gym API was unwanted and harmful. It blocks many users from adopting gymnasium afaict.

I agree that it's painful to migrate to the Gymnasium API. There are still many RL frameworks using the old Gym API. But I think implementation correctness is the top priority for an RL framework. It is not worth shipping "wrong" implementations of value-based algorithms (e.g., Q-learning) or policy-based algorithms (e.g., Actor-Critic, PPO). All RL algorithms using TD learning need to take the "done" information seriously. Since TorchRL is a relatively new repo and its APIs are not finalized yet, it does not have too much tech debt to resolve.

Using TorchRL's env API you could easily modify StepCounter to register a new copy of "done" that carries the extra information that you want.

A step counter is not doing the same thing as a truncated flag does. You can truncate and reset your environment whenever you want; you do not need to reach some condition such as the step counter hitting a limit. For example, you have an env with max_episode_steps=200, your batch size is 1024, and you have already sampled 1000 steps (5 episodes). Then you only need to sample 24 more steps and set the truncated flag.

@btx0424
Contributor Author

btx0424 commented Mar 4, 2023

I have the following suggestions:

  1. API-wise, since two dones could be a bit confusing to those who are comfortable with the Gym-like convention, we can consider having different names for them, e.g., is_initial and done. When stepping the MDP we let tensordict["is_initial"]=tensordict[("next", "done")]. It should resolve the ambiguity while keeping the advantages.
  2. Keep a clear description of the stepping process somewhere obvious in the doc since it's a fundamental design choice here.
  3. Clarify the role of _reset somewhere in the doc/code as many other users are not familiar with a natively vectorized (and multi-agent) environment.

@vmoens
Contributor

vmoens commented Mar 7, 2023

This one is almost ready
#941

@btx0424

API-wise, since two dones could be a bit confusing to those who are comfortable with the Gym-like convention, we can consider having different names for them, e.g., is_initial and done. When stepping the MDP we let tensordict["is_initial"]=tensordict[("next", "done")]. It should resolve the ambiguity while keeping the advantages.

I will later add a transform that allocates an "is_initial" key when calling reset. Would that serve your purpose?

Keep a clear description of the stepping process somewhere obvious in the doc since it's a fundamental design choice here.

I updated the readme by adding a gif of the env API, putting the env feature on top, and referring to the tutorial and doc in there. Suggestions are welcome!

Clarify the role of _reset somewhere in the doc/code as many other users are not familiar with a natively vectorized (and multi-agent) environment.

We should do that, but as of now it is a private key that is handled by the classes on a per-usage basis. @matteobettini may have an opinion on where and what should be said about it?

@matteobettini
Contributor

matteobettini commented Mar 7, 2023

Regarding the gif, I really like it! The only thing I wonder about is, in the rollout on top, why we have reward at t=0 (and, in the same way, why it is in parentheses in policy.forward).

[rollout gif]

Regarding the reset, I think the normal user who does not use vectorized envs can completely ignore it, as in its absence it will default to all Trues (which is just one True in the case of a classic gym env). Users will not have to set this and it will be True by default when calling reset.

To make users aware of it, I would probably talk about it in a tutorial that explains how to make your own environment / use vectorized envs (such as brax) / a multi-agent env tutorial.

@vmoens
Contributor

vmoens commented Mar 10, 2023

@btx0424

  • API-wise, since two dones could be a bit confusing to those who are comfortable with the Gym-like convention, we can consider having different names for them, e.g., is_initial and done. When stepping the MDP we let tensordict["is_initial"]=tensordict[("next", "done")]. It should resolve the ambiguity while keeping the advantages.

Can you check #962 ?
