[BUG] RenameTransform of ParallelEnv is not the same as ParallelEnv of transformed environment #2445

Open
thomasbbrunner opened this issue Sep 19, 2024 · 0 comments
Labels: bug (Something isn't working)

thomasbbrunner (Contributor) commented on Sep 19, 2024

Describe the bug

In short:

transform(ParallelEnv(base_env)) != ParallelEnv(transform(base_env))

I'm aware that this equivalence does not hold in some cases, but I'd expect it to hold for RenameTransform.

The documentation even states it explicitly: "There are two equivalent ways of transforming parallel environments: in each process separately, or on the main process. It is even possible to do both."
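The equivalence the documentation promises can be illustrated with a toy model in plain Python (no torchrl involved). Here `rename` is a hypothetical stand-in for RenameTransform, `batch` is a hypothetical stand-in for how ParallelEnv stacks per-worker outputs, and nested keys are modelled as tuples of strings. For a pure key rename, the two orderings commute:

```python
from typing import Any, Dict, List, Tuple

# Toy model, not torchrl: a nested key is represented as a tuple of strings.
Key = Tuple[str, ...]


def rename(data: Dict[Key, Any], in_key: Key, out_key: Key) -> Dict[Key, Any]:
    """Hypothetical stand-in for RenameTransform: move one entry to a new key."""
    out = dict(data)
    out[out_key] = out.pop(in_key)
    return out


def batch(steps: List[Dict[Key, Any]]) -> Dict[Key, List[Any]]:
    """Hypothetical stand-in for ParallelEnv: stack per-worker values along a batch axis."""
    return {key: [step[key] for step in steps] for key in steps[0]}


# Two workers, each producing the three done-like flags.
steps = [
    {("done",): False, ("terminated",): False, ("truncated",): False},
    {("done",): True, ("terminated",): True, ("truncated",): False},
]
IN_KEY, OUT_KEY = ("terminated",), ("stuff", "terminated")

# transform(ParallelEnv(base_env)) vs ParallelEnv(transform(base_env))
transform_of_parallel = rename(batch(steps), IN_KEY, OUT_KEY)
parallel_of_transform = batch([rename(step, IN_KEY, OUT_KEY) for step in steps])
print(transform_of_parallel == parallel_of_transform)  # True
```

On the data side the rename is symmetric; the bug reported here is that the spec bookkeeping of the parallel wrapper does not preserve this symmetry.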

To Reproduce

A minimal script that reproduces the issue:

from torchrl.envs import RenameTransform, TransformedEnv
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs import check_env_specs, ParallelEnv

def _make_env():
    return GymEnv("CartPole-v1")

def _transform_env(env):
    return TransformedEnv(
        env,
        RenameTransform(
            in_keys=[
                "terminated",
            ],
            out_keys=[
                ("stuff", "terminated"),
            ],
        )
    )


def _make_transformed_env():
    return _transform_env(_make_env())

if __name__ == "__main__":

    base_env = _make_env()
    transformed_env = _make_transformed_env()
    trans_parallel_env = _transform_env(
        ParallelEnv(
            1,
            create_env_fn=_make_env,
        )
    )
    parallel_trans_env = ParallelEnv(
        1,
        create_env_fn=_make_transformed_env,
    )

    # Works!
    check_env_specs(base_env)
    # Works!
    check_env_specs(transformed_env)
    # Works!
    check_env_specs(trans_parallel_env)
    # RuntimeError: Cannot infer the value of terminated when only done and truncated are present.
    check_env_specs(parallel_trans_env)

Expected behavior

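The script above should run without errors: applying RenameTransform on the main process or inside each worker should produce equivalent specs, so all four check_env_specs calls should pass.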

System info

Ubuntu 22.04
Python 3.10.14
torch 2.4.1
torchrl 0.5.0

Reason and Possible fixes

I've traced the issue to the _set_properties method of the BatchedEnvBase class. When it assigns to the self.done_spec property, the EnvBase.done_spec setter does not respect the renamed keys and adds entries under the default names instead.
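A minimal sketch of what "respecting the renamed keys" could mean, written in plain Python rather than against torchrl's internals. The function `update_leaves_in_place` and the flat key-to-shape representation of a spec are hypothetical: the idea is that the batched env overwrites only the leaves that already exist at their (possibly renamed) paths and rejects anything else, instead of re-deriving entries from the default done/terminated/truncated names.

```python
from typing import Dict, Tuple

# Toy representation (hypothetical, not torchrl's API): a done spec is a flat
# mapping from a nested key path to the shape of that leaf.
Key = Tuple[str, ...]
Shape = Tuple[int, ...]


def update_leaves_in_place(
    full_done_spec: Dict[Key, Shape], new_leaves: Dict[Key, Shape]
) -> Dict[Key, Shape]:
    """Overwrite only leaves that already exist at their (possibly renamed) paths.

    Unknown keys are rejected rather than added, so a renamed leaf such as
    ("stuff", "terminated") is never re-created under a default name like
    ("terminated",).
    """
    unknown = set(new_leaves) - set(full_done_spec)
    if unknown:
        raise KeyError(f"unexpected done keys: {sorted(unknown)}")
    updated = dict(full_done_spec)
    updated.update(new_leaves)
    return updated


# Done spec of one transformed worker (leaf shapes as in the dumps in this issue).
worker_spec = {
    ("done",): (1,),
    ("truncated",): (1,),
    ("stuff", "terminated"): (1,),
}
# The batched env prepends a batch dimension of size 1 (one worker).
batched_leaves = {key: (1,) + shape for key, shape in worker_spec.items()}

result = update_leaves_in_place(worker_spec, batched_leaves)
print(sorted(result))  # [('done',), ('stuff', 'terminated'), ('truncated',)]
```

Raising on unknown keys is a design choice for the sketch: it turns a silent spec mismatch into an immediate error at the point where the keys diverge.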

Here's a comparison of the full_done_spec before and after calling the done_spec setter:

Before: Composite(
    done: Categorical(
        shape=torch.Size([1, 1]),
        space=CategoricalBox(n=2),
        device=cpu,
        dtype=torch.bool,
        domain=discrete),
    truncated: Categorical(
        shape=torch.Size([1, 1]),
        space=CategoricalBox(n=2),
        device=cpu,
        dtype=torch.bool,
        domain=discrete),
    stuff: Composite(
        terminated: Categorical(
            shape=torch.Size([1, 1]),
            space=CategoricalBox(n=2),
            device=cpu,
            dtype=torch.bool,
            domain=discrete),
        device=cpu,
        shape=torch.Size([1])),
    device=cpu,
    shape=torch.Size([1]))
After: Composite(
    done: Categorical(
        shape=torch.Size([1, 1]),
        space=CategoricalBox(n=2),
        device=cpu,
        dtype=torch.bool,
        domain=discrete),
    truncated: Categorical(
        shape=torch.Size([1, 1]),
        space=CategoricalBox(n=2),
        device=cpu,
        dtype=torch.bool,
        domain=discrete),
    stuff: Composite(
        terminated: Categorical(
            shape=torch.Size([1, 1]),
            space=CategoricalBox(n=2),
            device=cpu,
            dtype=torch.bool,
            domain=discrete),
        done: Categorical(
            shape=torch.Size([1, 1]),
            space=CategoricalBox(n=2),
            device=cpu,
            dtype=torch.bool,
            domain=discrete),
        device=cpu,
        shape=torch.Size([1])),
    terminated: Categorical(
        shape=torch.Size([1, 1]),
        space=CategoricalBox(n=2),
        device=cpu,
        dtype=torch.bool,
        domain=discrete),
    device=cpu,
    shape=torch.Size([1]))
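To pin down exactly what the setter changed, the two dumps can be reduced to sets of leaf keys and diffed. The sketch below is plain Python with the keys transcribed by hand from the Before/After output above; with torchrl itself, `full_done_spec.keys(include_nested=True, leaves_only=True)` should give comparable key lists (an assumption, not verified against 0.5.0).

```python
from typing import Set, Tuple

Key = Tuple[str, ...]


def leaf_key_diff(before: Set[Key], after: Set[Key]) -> Tuple[Set[Key], Set[Key]]:
    """Return (added, removed) leaf keys between two versions of a spec."""
    return after - before, before - after


# Leaf keys transcribed by hand from the Before/After dumps.
before = {("done",), ("truncated",), ("stuff", "terminated")}
after = {
    ("done",),
    ("truncated",),
    ("stuff", "terminated"),
    ("stuff", "done"),
    ("terminated",),
}

added, removed = leaf_key_diff(before, after)
print(sorted(added))    # [('stuff', 'done'), ('terminated',)]
print(sorted(removed))  # []
```

The setter adds a root-level ("terminated",) next to the renamed ("stuff", "terminated"), plus a ("stuff", "done"), and removes nothing. A root terminated entry that the workers never write is plausibly what triggers the RuntimeError in the reproduction script.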

The correct spec, as reported by the non-parallel transformed environment, should be:

>>> transformed_env.full_done_spec
Composite(
    done: Categorical(
        shape=torch.Size([1]),
        space=CategoricalBox(n=2),
        device=cpu,
        dtype=torch.bool,
        domain=discrete),
    truncated: Categorical(
        shape=torch.Size([1]),
        space=CategoricalBox(n=2),
        device=cpu,
        dtype=torch.bool,
        domain=discrete),
    stuff: Composite(
        terminated: Categorical(
            shape=torch.Size([1]),
            space=CategoricalBox(n=2),
            device=cpu,
            dtype=torch.bool,
            domain=discrete),
        device=None,
        shape=torch.Size([])),
    device=None,
    shape=torch.Size([]))
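The reference spec above comes from the non-parallel transformed_env, so its leaves have shape [1] and the Composite has shape []. A ParallelEnv with one worker should only prepend a batch dimension of size 1. The plain-Python sketch below (leaf shapes transcribed by hand, devices ignored, `add_batch_dim` is a hypothetical helper) checks that the Before spec matches this expectation while the After spec does not:

```python
from typing import Dict, Tuple

Key = Tuple[str, ...]
Shape = Tuple[int, ...]


def add_batch_dim(spec: Dict[Key, Shape], num_workers: int) -> Dict[Key, Shape]:
    """Prepend a leading batch dimension of size num_workers to every leaf shape."""
    return {key: (num_workers,) + shape for key, shape in spec.items()}


# Leaf shapes transcribed from the dumps in this issue (devices ignored).
correct_single = {("done",): (1,), ("truncated",): (1,), ("stuff", "terminated"): (1,)}
before = {("done",): (1, 1), ("truncated",): (1, 1), ("stuff", "terminated"): (1, 1)}
after = {**before, ("stuff", "done"): (1, 1), ("terminated",): (1, 1)}

expected = add_batch_dim(correct_single, num_workers=1)
print(expected == before)  # True
print(expected == after)   # False
```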

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)