Skip to content

Commit

Permalink
Remove FloatReward (#829)
Browse files Browse the repository at this point in the history
* Remove FloatReward. Fixes #794

* Bump the SB3 version to ensure we have the bug fix that makes FloatReward unnecessary.
  • Loading branch information
ernestum authored Jan 17, 2024
1 parent d74e903 commit a8b079c
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 12 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def get_local_version(version: "ScmVersion", time_format="%Y%m%d") -> str:
"rich",
"scikit-learn>=0.21.2",
"seals~=0.2.1",
-"stable-baselines3~=2.0",
+"stable-baselines3~=2.2.1",
"sacred>=0.8.4",
"tensorboard>=1.14",
"huggingface_sb3~=3.0",
Expand Down
11 changes: 0 additions & 11 deletions tests/algorithms/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Fixtures common across algorithm tests."""
from typing import Sequence

import gymnasium as gym
import pytest
from stable_baselines3.common import envs
from stable_baselines3.common.policies import BasePolicy
Expand Down Expand Up @@ -113,20 +112,10 @@ def pendulum_single_venv(rng) -> VecEnv:
)


# TODO(GH#794): Remove after https://github.com/DLR-RM/stable-baselines3/pull/1676
# is merged and released.
class FloatReward(gym.RewardWrapper):
    """Typecasts reward to a float.

    Overrides ``reward`` so that the value it is given is returned as a
    built-in ``float``.
    """

    def reward(self, reward) -> float:
        """Return ``reward`` converted with the built-in ``float()``."""
        return float(reward)


@pytest.fixture
def multi_obs_venv() -> VecEnv:
    """Return a ``DummyVecEnv`` holding two ``SimpleMultiObsEnv`` instances.

    Each environment is created with ``channel_last=False``, wrapped in
    ``FloatReward`` and then in ``RolloutInfoWrapper``.
    """

    def make_env():
        """Build one wrapped ``SimpleMultiObsEnv``."""
        env = envs.SimpleMultiObsEnv(channel_last=False)
        # NOTE(review): the TODO(GH#794) on FloatReward marks that wrapper as
        # temporary; this line can presumably go once the installed
        # stable-baselines3 release contains the upstream fix — confirm first.
        env = FloatReward(env)
        return RolloutInfoWrapper(env)

    return DummyVecEnv([make_env, make_env])

0 comments on commit a8b079c

Please sign in to comment.