Skip to content

[RLlib] SingleAgentEpisode is not designed to handle dict observations #54659

@wullli

Description

@wullli

What happened + What you expected to happen

If SingleAgentEpisode.concat_episode gets dict observations, it will fail with the following error:

...
File "/Users/<redacted>/miniconda3/envs/tax-metrics-exp/lib/python3.11/site-packages/ray/rllib/env/single_agent_episode.py", line 626, in concat_episode
    assert np.all(other.observations[0] == self.observations[-1])
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

This happens because np.all needs to be called on the numpy array of each dict key separately. Something like the following could work, but I'm not sure these values are always numpy arrays. Alternatively, overriding the equality check of the dict might be possible?

import numpy as np

class SingleAgentEpisode:
    ...
    def concat_episode(self, other: "SingleAgentEpisode") -> None:
        ...
        np.testing.assert_equal(other.observations[0], self.observations[-1])
        ...

Versions / Dependencies

ray[rllib]: 2.47.1
Python 3.11.13
MacOS Sequoia 15.5

Reproduction script

import gymnasium as gym
from gymnasium import spaces
import numpy as np
from ray.rllib.env.single_agent_episode import SingleAgentEpisode

class DummyDictObsEnv(gym.Env):
    """Minimal gymnasium env that emits Dict observations.

    The Dict observation space is what triggers the ambiguous-truth-value
    ``ValueError`` inside ``SingleAgentEpisode.concat_episode`` (which calls
    ``np.all(...)`` on a dict-to-dict comparison).

    Observations are constant per step so that the last observation of one
    episode equals the first observation of the next.
    """

    def __init__(self):
        super().__init__()
        self.n_features = 5
        self.n_actions = 4

        self.observation_space = spaces.Dict({
            "features": spaces.Box(shape=(self.n_features,),
                                   low=-np.inf, high=np.inf, dtype=np.float32),
            "action_mask": spaces.Box(shape=(self.n_actions,),
                                      low=0.0, high=1.0, dtype=np.float32)
        })

        self.action_space = spaces.Discrete(self.n_actions)
        self._current_step = 0
        self._max_steps = 3

    def _get_obs(self):
        # Constant observation: zeros for features, full mask.
        return {
            "features": np.zeros((self.n_features,), dtype=np.float32),
            "action_mask": np.ones((self.n_actions,), dtype=np.float32) # Mask always full for simplicity
        }

    def reset(self, seed=None, options=None):
        # Fix: per the gymnasium API contract, reset() must forward the seed
        # to the base class so self.np_random is (re)seeded; the original
        # silently dropped it.
        super().reset(seed=seed)
        self._current_step = 0
        return self._get_obs(), {}

    def step(self, _):
        # Never terminates or truncates on its own; the driver script
        # decides how many steps to take.
        self._current_step += 1
        return self._get_obs(), 1.0, False, False, {}


print("Setting up DummyDictObsEnv (Simplified with Gymnasium)...")
env = DummyDictObsEnv()

print("\nCollecting data for Episode 1...")
# Trajectory buffers for the first episode.
obs_list_1, reward_list_1 = [], []
action_list_1, terminal_list_1, info_list_1 = [], [], []

# Seed the buffers with the reset observation/info — an RLlib episode
# carries one more observation (and info) than actions/rewards.
initial_obs_1, initial_info_1 = env.reset()
obs_list_1.append(initial_obs_1)
info_list_1.append(initial_info_1)

for _ in range(env._max_steps):
    action = env.action_space.sample()
    observation, reward, terminated, truncated, info = env.step(action)
    obs_list_1.append(observation)
    action_list_1.append(action)
    reward_list_1.append(reward)
    terminal_list_1.append(terminated)
    info_list_1.append(info)

# Build the first episode from the collected buffers.
episode1 = SingleAgentEpisode(
    id_="test_episode",
    observations=obs_list_1,
    actions=action_list_1,
    rewards=reward_list_1,
    terminated=False,
    infos=info_list_1,
)
print(f"Episode 1 created with {len(episode1)} steps.")
print(f"Episode 1 last observation (t={len(episode1)-1}): {episode1.observations[-1]}")

print("\nCollecting data for Episode 2...")
# Trajectory buffers for the second episode.
obs_list_2 = []
reward_list_2 = []
action_list_2 = []
terminal_list_2 = []
info_list_2 = []

initial_obs_2, initial_info_2 = env.reset()
obs_list_2.append(initial_obs_2) # Add initial observation
info_list_2.append(initial_info_2)

for i in range(env._max_steps):
    action = env.action_space.sample()
    observation, reward, terminated, truncated, info = env.step(action)
    obs_list_2.append(observation)
    reward_list_2.append(reward)
    action_list_2.append(action)
    # Bug fix: the original appended to terminal_list_1 here (copy-paste
    # error), leaving terminal_list_2 empty.
    terminal_list_2.append(terminated)
    info_list_2.append(info)


episode2 = SingleAgentEpisode(
    id_="test_episode",
    observations=obs_list_2,
    rewards=reward_list_2,
    actions=action_list_2,
    terminated=False,
    infos=info_list_2,
)

# With Dict observations, concat_episode's `np.all(obs == obs)` compares two
# dicts and raises the ambiguous-truth-value ValueError reported above.
concatenated_episode = episode1.concat_episode(episode2) # Fails

Issue Severity

High: It blocks me from completing my task.

Metadata

Metadata

Assignees

No one assigned

    Labels

    P2 — Important issue, but not time-critical; bug — Something that is supposed to be working, but isn't; rllib — RLlib related issues; rllib-env — RLlib env related issues; rllib-envrunners — Issues around the sampling backend of RLlib; stability

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions