
[BUG] RenameTransform of ParallelEnv is not the same as ParallelEnv of transformed environment #2445

Open
@thomasbbrunner


Describe the bug

In short:

transform(ParallelEnv(base_env)) != ParallelEnv(transform(base_env))

I'm aware that this equivalence does not hold in some cases, but I would expect it to hold for RenameTransform.

The documentation even states this: "There are two equivalent ways of transforming parallel environments: in each process separately, or on the main process. It is even possible to do both."
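To make the expected property concrete, here is a dependency-free toy model in plain Python (not torchrl code). A spec is a nested dict mapping keys to leaf shapes. batch() stands in for ParallelEnv by prepending a batch dimension to every leaf, and rename() stands in for RenameTransform by moving a leaf to a new nested key. The helper names batch, get_path and rename are hypothetical. For a pure key rename, the two orders should produce the same spec.

```python
from copy import deepcopy


def batch(spec, n):
    """Prepend a batch dimension of size n to every leaf shape (toy stand-in for ParallelEnv)."""
    return {k: batch(v, n) if isinstance(v, dict) else (n, *v) for k, v in spec.items()}


def get_path(spec, path):
    """Follow a tuple of nested keys and return the sub-dict it points to."""
    for k in path:
        spec = spec[k]
    return spec


def rename(spec, in_key, out_key):
    """Move the leaf at in_key to out_key (toy stand-in for RenameTransform).

    Keys are tuples of nested names; the input spec is not mutated.
    """
    spec = deepcopy(spec)
    leaf = get_path(spec, in_key[:-1]).pop(in_key[-1])
    node = spec
    for k in out_key[:-1]:
        node = node.setdefault(k, {})  # create intermediate sub-specs as needed
    node[out_key[-1]] = leaf
    return spec


# Done keys of CartPole-v1, each with leaf shape [1].
base = {"done": (1,), "terminated": (1,), "truncated": (1,)}
rename_args = (("terminated",), ("stuff", "terminated"))

# transform(ParallelEnv(base_env)) vs ParallelEnv(transform(base_env))
transform_of_batched = rename(batch(base, 1), *rename_args)
batched_of_transform = batch(rename(base, *rename_args), 1)
print(transform_of_batched == batched_of_transform)  # True
```

In this toy model the two orders agree; the bug is that torchrl's ParallelEnv(transform(base_env)) does not produce the same spec.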

To Reproduce

Simple script to reproduce the issue:

from torchrl.envs import RenameTransform, TransformedEnv
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs import check_env_specs, ParallelEnv

def _make_env():
    return GymEnv("CartPole-v1")

def _transform_env(env):
    return TransformedEnv(
        env,
        RenameTransform(
            in_keys=[
                "terminated",
            ],
            out_keys=[
                ("stuff", "terminated"),
            ],
        )
    )


def _make_transformed_env():
    return _transform_env(_make_env())

if __name__ == "__main__":

    base_env = _make_env()
    transformed_env = _make_transformed_env()
    trans_parallel_env = _transform_env(
        ParallelEnv(1, create_env_fn=_make_env)
    )
    parallel_trans_env = ParallelEnv(
        1,
        create_env_fn=_make_transformed_env,
    )

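    # Works: base environment
    check_env_specs(base_env)
    # Works: transform applied to a single environment
    check_env_specs(transformed_env)
    # Works: transform applied on the main process, on top of ParallelEnv
    check_env_specs(trans_parallel_env)
    # Fails: transform applied inside each worker, then wrapped by ParallelEnv
    # RuntimeError: Cannot infer the value of terminated when only done and truncated are present.
    check_env_specs(parallel_trans_env)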

Expected behavior

The script above should run without errors.

System info

Ubuntu 22.04
Python 3.10.14
torch 2.4.1
torchrl 0.5.0

Reason and Possible fixes

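I've tracked the issue down to the _set_properties method of the BatchedEnvBase class. When it writes to the self.done_spec property, the EnvBase.done_spec setter does not respect the renamed keys: it adds a top-level terminated entry and a nested ("stuff", "done") entry that the transformed environment does not have.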

Here's a comparison of the full_done_spec before and after calling the done_spec setter:

Before: Composite(
    done: Categorical(
        shape=torch.Size([1, 1]),
        space=CategoricalBox(n=2),
        device=cpu,
        dtype=torch.bool,
        domain=discrete),
    truncated: Categorical(
        shape=torch.Size([1, 1]),
        space=CategoricalBox(n=2),
        device=cpu,
        dtype=torch.bool,
        domain=discrete),
    stuff: Composite(
        terminated: Categorical(
            shape=torch.Size([1, 1]),
            space=CategoricalBox(n=2),
            device=cpu,
            dtype=torch.bool,
            domain=discrete),
        device=cpu,
        shape=torch.Size([1])),
    device=cpu,
    shape=torch.Size([1]))
After: Composite(
    done: Categorical(
        shape=torch.Size([1, 1]),
        space=CategoricalBox(n=2),
        device=cpu,
        dtype=torch.bool,
        domain=discrete),
    truncated: Categorical(
        shape=torch.Size([1, 1]),
        space=CategoricalBox(n=2),
        device=cpu,
        dtype=torch.bool,
        domain=discrete),
    stuff: Composite(
        terminated: Categorical(
            shape=torch.Size([1, 1]),
            space=CategoricalBox(n=2),
            device=cpu,
            dtype=torch.bool,
            domain=discrete),
        done: Categorical(
            shape=torch.Size([1, 1]),
            space=CategoricalBox(n=2),
            device=cpu,
            dtype=torch.bool,
            domain=discrete),
        device=cpu,
        shape=torch.Size([1])),
    terminated: Categorical(
        shape=torch.Size([1, 1]),
        space=CategoricalBox(n=2),
        device=cpu,
        dtype=torch.bool,
        domain=discrete),
    device=cpu,
    shape=torch.Size([1]))
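The difference between the two dumps above can be made explicit by flattening each spec into its set of leaf key paths. The sketch below is plain Python over nested dicts transcribed from the dumps (leaf_paths is a hypothetical helper; on a real torchrl spec, Composite.keys(include_nested=True, leaves_only=True) should give an equivalent flattening). It shows that the setter only adds keys, namely a top-level terminated and a nested ("stuff", "done"). The spurious top-level terminated is presumably what check_env_specs trips over, since the workers only return done and truncated at the root.

```python
def leaf_paths(spec, prefix=()):
    """Flatten a nested dict into the set of tuple key paths of its leaves."""
    paths = set()
    for key, value in spec.items():
        path = prefix + (key,)
        if isinstance(value, dict):
            paths |= leaf_paths(value, path)  # recurse into sub-specs
        else:
            paths.add(path)
    return paths


# Every leaf in both dumps is the same boolean Categorical with shape [1, 1].
leaf = "Categorical(shape=[1, 1], dtype=bool)"
before = {"done": leaf, "truncated": leaf, "stuff": {"terminated": leaf}}
after = {
    "done": leaf,
    "truncated": leaf,
    "stuff": {"terminated": leaf, "done": leaf},
    "terminated": leaf,
}

added = sorted(leaf_paths(after) - leaf_paths(before))
removed = sorted(leaf_paths(before) - leaf_paths(after))
print(added)    # [('stuff', 'done'), ('terminated',)]
print(removed)  # []
```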

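The correct spec is that of the transformed base environment (the specs above additionally carry the leading batch dimension added by ParallelEnv):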

>>> transformed_env.full_done_spec
Composite(
    done: Categorical(
        shape=torch.Size([1]),
        space=CategoricalBox(n=2),
        device=cpu,
        dtype=torch.bool,
        domain=discrete),
    truncated: Categorical(
        shape=torch.Size([1]),
        space=CategoricalBox(n=2),
        device=cpu,
        dtype=torch.bool,
        domain=discrete),
    stuff: Composite(
        terminated: Categorical(
            shape=torch.Size([1]),
            space=CategoricalBox(n=2),
            device=cpu,
            dtype=torch.bool,
            domain=discrete),
        device=None,
        shape=torch.Size([])),
    device=None,
    shape=torch.Size([]))
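Since transformed_env is not batched, its leaves have shape [1], whereas ParallelEnv(1, ...) adds a leading batch dimension of size 1. A short dependency-free check (leaf shapes transcribed from the dumps; strip_batch_dim is a hypothetical helper) suggests that the leaf shapes of the "Before" spec already match the correct spec once that dimension is removed, so the problem lies in the keys the setter adds rather than in the shapes.

```python
def strip_batch_dim(spec):
    """Drop the leading (batch) dimension from every leaf shape in a nested dict."""
    return {k: strip_batch_dim(v) if isinstance(v, dict) else v[1:] for k, v in spec.items()}


# Leaf shapes of the "Before" spec (batched) and of transformed_env (unbatched).
before_setter = {"done": (1, 1), "truncated": (1, 1), "stuff": {"terminated": (1, 1)}}
expected_unbatched = {"done": (1,), "truncated": (1,), "stuff": {"terminated": (1,)}}

print(strip_batch_dim(before_setter) == expected_unbatched)  # True
```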

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)
