Skip to content

Commit

Permalink
[RLlib; Offline RL] Store episodes in state form. (ray-project#47294)
Browse files Browse the repository at this point in the history
Signed-off-by: ujjawal-khare <ujjawal.khare@dream11.com>
  • Loading branch information
simonsays1980 authored and ujjawal-khare committed Oct 15, 2024
1 parent 8a87272 commit 192990a
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 17 deletions.
19 changes: 10 additions & 9 deletions rllib/env/utils/infinite_lookback_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,12 @@ def get_state(self) -> Dict[str, Any]:
A dict containing all the data and metadata from the buffer.
"""
return {
"data": to_jsonable_if_needed(self.data, self.space)
if self.space
else self.data,
"data": self.data,
"lookback": self.lookback,
"finalized": self.finalized,
"space_struct": gym_space_to_dict(self.space_struct)
if self.space_struct
else self.space_struct,
"space": gym_space_to_dict(self.space) if self.space else self.space,
}

Expand All @@ -93,16 +94,16 @@ def from_state(state: Dict[str, Any]) -> None:
from the state dict.
"""
buffer = InfiniteLookbackBuffer()
buffer.data = state["data"]
buffer.lookback = state["lookback"]
buffer.finalized = state["finalized"]
buffer.space = gym_space_from_dict(state["space"]) if state["space"] else None
buffer.space_struct = (
get_base_struct_from_space(buffer.space) if buffer.space else None
gym_space_from_dict(state["space_struct"])
if state["space_struct"]
else state["space_struct"]
)
buffer.data = (
from_jsonable_if_needed(state["data"], buffer.space)
if buffer.space
else state["data"]
buffer.space = (
gym_space_from_dict(state["space"]) if state["space"] else state["space"]
)

return buffer
Expand Down
8 changes: 0 additions & 8 deletions rllib/offline/offline_prelearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,14 +179,6 @@ def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, List[EpisodeType]]
)
for state in batch["item"]
]
self.episode_buffer.add(episodes)
episodes = self.episode_buffer.sample(
num_items=self.config.train_batch_size_per_learner,
# TODO (simon): This can be removed as soon as DreamerV3 has been
# cleaned up, i.e. can use episode samples for training.
sample_episodes=True,
finalize=True,
)
# Else, if we have old stack `SampleBatch`es.
elif self.input_read_sample_batches:
episodes = OfflinePreLearner._map_sample_batch_to_episode(
Expand Down

0 comments on commit 192990a

Please sign in to comment.