Gym fixes - Follow up from #705 (#734)

Merged on Feb 4, 2022 (41 commits)

Commits:
243457e  fix Atari in CI (jkterry1, Sep 16, 2021)
0d94863  fix dtype and atari extra (jkterry1, Sep 16, 2021)
774b7c9  Merge branch 'master' into master (jkterry1, Sep 27, 2021)
cdb4028  Merge branch 'master' into master (Miffyli, Sep 28, 2021)
4899c60  Update setup.py (jkterry1, Oct 21, 2021)
4329f4b  remove 3.6 (jkterry1, Oct 21, 2021)
d2ad8fd  note about how to install Atari (jkterry1, Oct 21, 2021)
cd29301  pendulum-v1 (jkterry1, Oct 21, 2021)
e01e535  atari v5 (jkterry1, Oct 21, 2021)
20b1ac9  Merge branch 'master' into master (jkterry1, Oct 21, 2021)
c4e4f0a  black (jkterry1, Oct 21, 2021)
4279d63  fix pendulum capitalization (jkterry1, Oct 21, 2021)
f549fc8  add minimum version (jkterry1, Oct 21, 2021)
1db85d1  moved things in changelog to breaking changes (jkterry1, Oct 21, 2021)
9abfafb  Merge branch 'master' into master (jkterry1, Oct 23, 2021)
d72cdf6  partial v5 fix (jkterry1, Oct 23, 2021)
ba0db77  Merge branch 'master' into master (jkterry1, Nov 14, 2021)
30c9f4d  Merge branch 'master' into master (araffin, Dec 10, 2021)
de74ec8  Merge branch 'master' into master (araffin, Dec 21, 2021)
3620a04  Merge branch 'master' into master (jkterry1, Dec 22, 2021)
55414c3  env update to pass tests (modanesh, Dec 28, 2021)
790cb6b  Merge branch 'master' into pr/572 (araffin, Dec 28, 2021)
319ce24  mismatch env version fixed (modanesh, Dec 28, 2021)
8ad3f75  Merge branch 'pr/571' of https://github.com/modanesh/stable-baselines… (modanesh, Dec 28, 2021)
e527efe  Merge branch 'master' into pr/572 (araffin, Jan 2, 2022)
4adde2f  Merge branch 'master' into gym_fixes (carlosluis, Jan 22, 2022)
218bc1a  Fix tests after merge (carlosluis, Jan 22, 2022)
e5f7012  Include autorom in setup.py (carlosluis, Jan 22, 2022)
7bde14c  Blacken code (AdamGleave, Feb 3, 2022)
f6414e7  Fix dtype issue in more robust way (AdamGleave, Feb 3, 2022)
f4b3342  Fix GitLab CI: switch to Docker container with new black version (AdamGleave, Feb 4, 2022)
7f1e99e  Remove workaround from GitLab. (May need to rebuild Docker for this t… (AdamGleave, Feb 4, 2022)
92c1bc7  Merge branch 'blacken' into gym_fixes (AdamGleave, Feb 4, 2022)
09a3a42  Merge branch 'master' into gym_fixes (AdamGleave, Feb 4, 2022)
ea073ae  Revert to v4 (AdamGleave, Feb 4, 2022)
8f7d26b  Update setup.py (AdamGleave, Feb 4, 2022)
0f158f1  Merge branch 'gym_fixes' of github.com:carlosluis/stable-baselines3 i… (AdamGleave, Feb 4, 2022)
edb504a  Apply suggestions from code review (AdamGleave, Feb 4, 2022)
b211781  Merge branch 'master' into gym_fixes (AdamGleave, Feb 4, 2022)
f34ea24  Remove unnecessary autorom (AdamGleave, Feb 4, 2022)
d7de342  Consistent gym versions (AdamGleave, Feb 4, 2022)
2 changes: 0 additions & 2 deletions .github/workflows/ci.yml
@@ -32,8 +32,6 @@ jobs:
pip install .[extra,tests,docs]
# Use headless version
pip install opencv-python-headless
# Tmp fix: ROM missing in the newest atari-py version
pip install atari-py==0.2.5
- name: Build the doc
run: |
make doc
2 changes: 0 additions & 2 deletions .gitlab-ci.yml
@@ -7,8 +7,6 @@ type-check:
pytest:
script:
- python --version
# Fix to get atari ROMs
- pip install atari-py==0.2.5
# MKL_THREADING_LAYER=GNU to avoid MKL_THREADING_LAYER=INTEL incompatibility error
- MKL_THREADING_LAYER=GNU make pytest

2 changes: 1 addition & 1 deletion README.md
@@ -87,7 +87,7 @@ Documentation is available online: [https://sb3-contrib.readthedocs.io/](https:/
**Note:** Stable-Baselines3 supports PyTorch >= 1.8.1.

### Prerequisites
Stable Baselines3 requires python 3.7+.
Stable Baselines3 requires Python 3.7+.

#### Windows 10

18 changes: 9 additions & 9 deletions docs/guide/callbacks.rst
@@ -174,7 +174,7 @@ and optionally a prefix for the checkpoints (``rl_model`` by default).
checkpoint_callback = CheckpointCallback(save_freq=1000, save_path='./logs/',
name_prefix='rl_model')

model = SAC('MlpPolicy', 'Pendulum-v0')
model = SAC('MlpPolicy', 'Pendulum-v1')
model.learn(2000, callback=checkpoint_callback)


@@ -206,13 +206,13 @@ It will save the best model if ``best_model_save_path`` folder is specified and
from stable_baselines3.common.callbacks import EvalCallback

# Separate evaluation env
eval_env = gym.make('Pendulum-v0')
eval_env = gym.make('Pendulum-v1')
# Use deterministic actions for evaluation
eval_callback = EvalCallback(eval_env, best_model_save_path='./logs/',
log_path='./logs/', eval_freq=500,
deterministic=True, render=False)

model = SAC('MlpPolicy', 'Pendulum-v0')
model = SAC('MlpPolicy', 'Pendulum-v1')
model.learn(5000, callback=eval_callback)


@@ -234,13 +234,13 @@ Alternatively, you can pass directly a list of callbacks to the ``learn()`` meth

checkpoint_callback = CheckpointCallback(save_freq=1000, save_path='./logs/')
# Separate evaluation env
eval_env = gym.make('Pendulum-v0')
eval_env = gym.make('Pendulum-v1')
eval_callback = EvalCallback(eval_env, best_model_save_path='./logs/best_model',
log_path='./logs/results', eval_freq=500)
# Create the callback list
callback = CallbackList([checkpoint_callback, eval_callback])

model = SAC('MlpPolicy', 'Pendulum-v0')
model = SAC('MlpPolicy', 'Pendulum-v1')
# Equivalent to:
# model.learn(5000, callback=[checkpoint_callback, eval_callback])
model.learn(5000, callback=callback)
Expand All @@ -263,12 +263,12 @@ It must be used with the :ref:`EvalCallback` and use the event triggered by a ne
from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnRewardThreshold

# Separate evaluation env
eval_env = gym.make('Pendulum-v0')
eval_env = gym.make('Pendulum-v1')
# Stop training when the model reaches the reward threshold
callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=-200, verbose=1)
eval_callback = EvalCallback(eval_env, callback_on_new_best=callback_on_best, verbose=1)

model = SAC('MlpPolicy', 'Pendulum-v0', verbose=1)
model = SAC('MlpPolicy', 'Pendulum-v1', verbose=1)
# Almost infinite number of timesteps, but the training will stop
# early as soon as the reward threshold is reached
model.learn(int(1e10), callback=eval_callback)
@@ -299,7 +299,7 @@ An :ref:`EventCallback` that will trigger its child callback every ``n_steps`` t
checkpoint_on_event = CheckpointCallback(save_freq=1, save_path='./logs/')
event_callback = EveryNTimesteps(n_steps=500, callback=checkpoint_on_event)

model = PPO('MlpPolicy', 'Pendulum-v0', verbose=1)
model = PPO('MlpPolicy', 'Pendulum-v1', verbose=1)

model.learn(int(2e4), callback=event_callback)

@@ -328,7 +328,7 @@ and in total for ``max_episodes * n_envs`` episodes.
# Stops training when the model reaches the maximum number of episodes
callback_max_episodes = StopTrainingOnMaxEpisodes(max_episodes=5, verbose=1)

model = A2C('MlpPolicy', 'Pendulum-v0', verbose=1)
model = A2C('MlpPolicy', 'Pendulum-v1', verbose=1)
# Almost infinite number of timesteps, but the training will stop
# early as soon as the max number of episodes is reached
model.learn(int(1e10), callback=callback_max_episodes)
2 changes: 1 addition & 1 deletion docs/guide/custom_policy.rst
@@ -407,5 +407,5 @@ you only need to specify ``net_arch=[256, 256]`` (here, two hidden layers of 256
# Custom critic architecture with two layers of 400 and 300 units
policy_kwargs = dict(net_arch=dict(pi=[64, 64], qf=[400, 300]))
# Create the agent
model = SAC("MlpPolicy", "Pendulum-v0", policy_kwargs=policy_kwargs, verbose=1)
model = SAC("MlpPolicy", "Pendulum-v1", policy_kwargs=policy_kwargs, verbose=1)
model.learn(5000)
4 changes: 2 additions & 2 deletions docs/guide/examples.rst
@@ -321,7 +321,7 @@ Atari Games

Training a RL agent on Atari games is straightforward thanks to ``make_atari_env`` helper function.
It will do `all the preprocessing <https://danieltakeshi.github.io/2016/11/25/frame-skipping-and-preprocessing-for-deep-q-networks-on-atari-2600-games/>`_
and multiprocessing for you.
and multiprocessing for you. To install the Atari environments and ROMs, run ``pip install gym[atari,accept-rom-license]``, or install Stable Baselines3 with ``pip install stable-baselines3[extra]`` to get these and other optional dependencies.
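
For context, a minimal sketch of how this helper is typically used once the ROMs are installed (illustrative, not part of this diff; the environment id and hyperparameters are just examples):

from stable_baselines3 import A2C
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack

# make_atari_env applies the standard Atari preprocessing wrappers
# and vectorizes the environment (4 parallel copies here).
env = make_atari_env("BreakoutNoFrameskip-v4", n_envs=4, seed=0)
# Frame stacking is applied on top of the vectorized environment.
env = VecFrameStack(env, n_stack=4)

model = A2C("CnnPolicy", env, verbose=1)
model.learn(total_timesteps=10_000)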

.. image:: ../_static/img/colab-badge.svg
:target: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/atari_games.ipynb
@@ -564,7 +564,7 @@ Behind the scene, SB3 uses an :ref:`EvalCallback <callbacks>`.

# Create the model, the training environment
# and the test environment (for evaluation)
model = SAC('MlpPolicy', 'Pendulum-v0', verbose=1,
model = SAC('MlpPolicy', 'Pendulum-v1', verbose=1,
learning_rate=1e-3, create_eval_env=True)

# Evaluate the model every 1000 steps on 5 test episodes
2 changes: 1 addition & 1 deletion docs/guide/export.rst
@@ -62,7 +62,7 @@ For PPO, assuming a shared feature extactor.
action_hidden, value_hidden = self.extractor(observation)
return self.action_net(action_hidden), self.value_net(value_hidden)

# Example: model = PPO("MlpPolicy", "Pendulum-v0")
# Example: model = PPO("MlpPolicy", "Pendulum-v1")
model = PPO.load("PathToTrainedModel.zip")
model.policy.to("cpu")
onnxable_model = OnnxablePolicy(model.policy.mlp_extractor, model.policy.action_net, model.policy.value_net)
8 changes: 4 additions & 4 deletions docs/guide/tensorboard.rst
@@ -61,7 +61,7 @@ Here is a simple example on how to log both additional tensor or arbitrary scala
from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import BaseCallback

model = SAC("MlpPolicy", "Pendulum-v0", tensorboard_log="/tmp/sac/", verbose=1)
model = SAC("MlpPolicy", "Pendulum-v1", tensorboard_log="/tmp/sac/", verbose=1)


class TensorboardCallback(BaseCallback):
@@ -104,7 +104,7 @@ Here is an example of how to render an image to TensorBoard at regular intervals
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.logger import Image

model = SAC("MlpPolicy", "Pendulum-v0", tensorboard_log="/tmp/sac/", verbose=1)
model = SAC("MlpPolicy", "Pendulum-v1", tensorboard_log="/tmp/sac/", verbose=1)


class ImageRecorderCallback(BaseCallback):
@@ -141,7 +141,7 @@ Here is an example of how to store a plot in TensorBoard at regular intervals:
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.logger import Figure

model = SAC("MlpPolicy", "Pendulum-v0", tensorboard_log="/tmp/sac/", verbose=1)
model = SAC("MlpPolicy", "Pendulum-v1", tensorboard_log="/tmp/sac/", verbose=1)


class FigureRecorderCallback(BaseCallback):
@@ -251,7 +251,7 @@ can get direct access to the underlying SummaryWriter in a callback:



model = SAC("MlpPolicy", "Pendulum-v0", tensorboard_log="/tmp/sac/", verbose=1)
model = SAC("MlpPolicy", "Pendulum-v1", tensorboard_log="/tmp/sac/", verbose=1)


class SummaryWriterCallback(BaseCallback):
3 changes: 3 additions & 0 deletions docs/misc/changelog.rst
@@ -119,6 +119,7 @@ Release 1.3.0 (2021-10-23)

Breaking Changes:
^^^^^^^^^^^^^^^^^
- Support for Python 3.6 was removed.
- ``sde_net_arch`` argument in policies is deprecated and will be removed in a future version.
- ``_get_latent`` (``ActorCriticPolicy``) was removed
- All logging keys now use underscores instead of spaces (@timokau). Concretely this changes:
@@ -127,6 +128,7 @@ Breaking Changes:
- ``rollout/exploration rate`` to ``rollout/exploration_rate`` and
- ``rollout/success rate`` to ``rollout/success_rate``.


New Features:
^^^^^^^^^^^^^
- Added methods ``get_distribution`` and ``predict_values`` for ``ActorCriticPolicy`` for A2C/PPO/TRPO (@cyprienc)
@@ -145,6 +147,7 @@ Bug Fixes:

Deprecations:
^^^^^^^^^^^^^
- Switched minimum Gym version to 0.21.0.

Others:
^^^^^^^
2 changes: 1 addition & 1 deletion docs/modules/ddpg.rst
@@ -67,7 +67,7 @@ This example is only to demonstrate the use of the library and its functions, an
from stable_baselines3 import DDPG
from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise

env = gym.make("Pendulum-v0")
env = gym.make("Pendulum-v1")

# The noise objects for DDPG
n_actions = env.action_space.shape[-1]
2 changes: 1 addition & 1 deletion docs/modules/sac.rst
@@ -73,7 +73,7 @@ This example is only to demonstrate the use of the library and its functions, an

from stable_baselines3 import SAC

env = gym.make("Pendulum-v0")
env = gym.make("Pendulum-v1")

model = SAC("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10000, log_interval=4)
2 changes: 1 addition & 1 deletion docs/modules/td3.rst
@@ -67,7 +67,7 @@ This example is only to demonstrate the use of the library and its functions, an
from stable_baselines3 import TD3
from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise

env = gym.make("Pendulum-v0")
env = gym.make("Pendulum-v1")

# The noise objects for TD3
n_actions = env.action_space.shape[-1]
4 changes: 2 additions & 2 deletions setup.py
@@ -73,7 +73,7 @@
packages=[package for package in find_packages() if package.startswith("stable_baselines3")],
package_data={"stable_baselines3": ["py.typed", "version.txt"]},
install_requires=[
"gym>=0.17,<0.20", # gym 0.20 breaks atari-py behavior
"gym>=0.21", # Remember to also update gym version in "extra" below when this changes
Review comment (Member): I would probably fix the version until #780 is ready.

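For illustration only (a hypothetical fragment, not a change made in this PR), an exact pin in install_requires would look roughly like this:

# Hypothetical, minimal setup.py sketch showing an exact gym pin kept
# until #780 is resolved; not the actual project setup.py.
from setuptools import find_packages, setup

setup(
    name="example-project",  # placeholder name
    packages=find_packages(),
    install_requires=[
        "gym==0.21.0",  # exact pin instead of "gym>=0.21"
        "torch>=1.8.1",
    ],
)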
"numpy",
"torch>=1.8.1",
# For saving models
@@ -116,7 +116,7 @@
# For render
"opencv-python",
# For atari games,
"atari_py==0.2.6",
"gym[atari,accept-rom-license]>=0.21",
Review comment (Member): I think this doesn't work (I remember testing it in the past); we should put autorom with the accept-license extra here.

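A sketch of that suggestion (assumed package name and extra; not what this PR ultimately ships):

# Hypothetical, minimal setup.py sketch declaring AutoROM in the "extra"
# dependencies instead of relying on gym's accept-rom-license extra.
from setuptools import find_packages, setup

setup(
    name="example-project",  # placeholder name
    packages=find_packages(),
    extras_require={
        "extra": [
            "autorom[accept-rom-license]",  # downloads the Atari ROMs and accepts the license
        ],
    },
)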
Review comment (Collaborator): I just tested it and it seems to work OK:

virtualenv test_venv
. ./test_venv/bin/activate
pip install -e .[extra]

then:

$ python -c 'import gym; gym.make("BreakoutNoFrameskip-v4")'
A.L.E: Arcade Learning Environment (version 0.7.4+069f8bd)
[Powered by Stella]

IIRC it fails on some very old pip versions that don't support backtracking. I tested on pip 20.3.4 and Python 3.9, but I'm pretty sure it works on older versions (there's some discussion about this in the review).
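
A possible workaround on those older pips (illustrative; assumes only a standard virtualenv) is to upgrade pip to a release that ships the backtracking resolver before installing the extras:

pip install --upgrade "pip>=20.3"
pip install -e .[extra]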

"pillow",
# Tensorboard support
"tensorboard>=2.2.0",
4 changes: 2 additions & 2 deletions tests/test_callbacks.py
@@ -75,7 +75,7 @@ def test_callbacks(tmp_path, model_class):
if model_class in [A2C, PPO]:
max_episodes = 1
n_envs = 2
# Pendulum-v0 has a timelimit of 200 timesteps
# Pendulum-v1 has a timelimit of 200 timesteps
max_episode_length = 200
envs = make_vec_env(env_name, n_envs=n_envs, seed=0)

@@ -99,7 +99,7 @@ def select_env(model_class) -> str:
if model_class is DQN:
return "CartPole-v0"
else:
return "Pendulum-v0"
return "Pendulum-v1"


def test_eval_callback_vec_env():
6 changes: 3 additions & 3 deletions tests/test_custom_policy.py
@@ -25,7 +25,7 @@ def test_flexible_mlp(model_class, net_arch):
@pytest.mark.parametrize("net_arch", [[], [4], [4, 4], dict(qf=[8], pi=[8, 4])])
@pytest.mark.parametrize("model_class", [SAC, TD3])
def test_custom_offpolicy(model_class, net_arch):
_ = model_class("MlpPolicy", "Pendulum-v0", policy_kwargs=dict(net_arch=net_arch), learning_starts=100).learn(300)
_ = model_class("MlpPolicy", "Pendulum-v1", policy_kwargs=dict(net_arch=net_arch), learning_starts=100).learn(300)


@pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3])
@@ -38,12 +38,12 @@ def test_custom_optimizer(model_class, optimizer_kwargs):
kwargs = dict(n_steps=64)

policy_kwargs = dict(optimizer_class=th.optim.AdamW, optimizer_kwargs=optimizer_kwargs, net_arch=[32])
_ = model_class("MlpPolicy", "Pendulum-v0", policy_kwargs=policy_kwargs, **kwargs).learn(300)
_ = model_class("MlpPolicy", "Pendulum-v1", policy_kwargs=policy_kwargs, **kwargs).learn(300)


def test_tf_like_rmsprop_optimizer():
policy_kwargs = dict(optimizer_class=RMSpropTFLike, net_arch=[32])
_ = A2C("MlpPolicy", "Pendulum-v0", policy_kwargs=policy_kwargs).learn(500)
_ = A2C("MlpPolicy", "Pendulum-v1", policy_kwargs=policy_kwargs).learn(500)


def test_dqn_custom_policy():
2 changes: 1 addition & 1 deletion tests/test_deterministic.py
@@ -13,7 +13,7 @@ def test_deterministic_training_common(algo):
rewards = [[], []]
# Smaller network
kwargs = {"policy_kwargs": dict(net_arch=[64])}
env_id = "Pendulum-v0"
env_id = "Pendulum-v1"
if algo in [TD3, SAC]:
kwargs.update({"action_noise": NormalActionNoise(0.0, 0.1), "learning_starts": 100, "train_freq": 4})
else:
6 changes: 3 additions & 3 deletions tests/test_distributions.py
@@ -43,7 +43,7 @@ def test_squashed_gaussian(model_class):
"""
Test run with squashed Gaussian (notably entropy computation)
"""
model = model_class("MlpPolicy", "Pendulum-v0", use_sde=True, n_steps=64, policy_kwargs=dict(squash_output=True))
model = model_class("MlpPolicy", "Pendulum-v1", use_sde=True, n_steps=64, policy_kwargs=dict(squash_output=True))
model.learn(500)

gaussian_mean = th.rand(N_SAMPLES, N_ACTIONS)
@@ -57,10 +57,10 @@
@pytest.fixture()
def dummy_model_distribution_obs_and_actions() -> Tuple[A2C, np.array, np.array]:
"""
Fixture creating a Pendulum-v0 gym env, an A2C model and sampling 10 random observations and actions from the env
Fixture creating a Pendulum-v1 gym env, an A2C model and sampling 10 random observations and actions from the env
:return: A2C model, random observations, random actions
"""
env = gym.make("Pendulum-v0")
env = gym.make("Pendulum-v1")
model = A2C("MlpPolicy", env, seed=23)
random_obs = np.array([env.observation_space.sample() for _ in range(10)])
random_actions = np.array([env.action_space.sample() for _ in range(10)])
4 changes: 2 additions & 2 deletions tests/test_env_checker.py
@@ -11,14 +11,14 @@ class ActionDictTestEnv(gym.Env):
observation_space = Box(low=-1.0, high=2.0, shape=(3,), dtype=np.float32)

def step(self, action):
observation = np.array([1.0, 1.5, 0.5])
observation = np.array([1.0, 1.5, 0.5], dtype=self.observation_space.dtype)
reward = 1
done = True
info = {}
return observation, reward, done, info

def reset(self):
return np.array([1.0, 1.5, 0.5])
return np.array([1.0, 1.5, 0.5], dtype=self.observation_space.dtype)

def render(self, mode="human"):
pass
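
For context, a minimal illustration (not part of the diff) of the mismatch these explicit dtypes avoid:

import numpy as np

# A bare np.array of Python floats defaults to float64, while the Box
# observation space above is declared float32, so observations are now
# cast explicitly to observation_space.dtype.
assert np.array([1.0, 1.5, 0.5]).dtype == np.float64
assert np.array([1.0, 1.5, 0.5], dtype=np.float32).dtype == np.float32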
6 changes: 3 additions & 3 deletions tests/test_envs.py
@@ -27,7 +27,7 @@
]


@pytest.mark.parametrize("env_id", ["CartPole-v0", "Pendulum-v0"])
@pytest.mark.parametrize("env_id", ["CartPole-v0", "Pendulum-v1"])
def test_env(env_id):
"""
Check that environmnent integrated in Gym pass the test.
@@ -38,9 +38,9 @@ def test_env(env_id):
with pytest.warns(None) as record:
check_env(env)

# Pendulum-v0 will produce a warning because the action space is
# Pendulum-v1 will produce a warning because the action space is
# in [-2, 2] and not [-1, 1]
if env_id == "Pendulum-v0":
if env_id == "Pendulum-v1":
assert len(record) == 1
else:
# The other environments must pass without warning
4 changes: 2 additions & 2 deletions tests/test_predict.py
@@ -43,15 +43,15 @@ def test_auto_wrap(model_class):
if model_class is DQN:
env_name = "CartPole-v0"
else:
env_name = "Pendulum-v0"
env_name = "Pendulum-v1"
env = gym.make(env_name)
eval_env = gym.make(env_name)
model = model_class("MlpPolicy", env)
model.learn(100, eval_env=eval_env)


@pytest.mark.parametrize("model_class", MODEL_LIST)
@pytest.mark.parametrize("env_id", ["Pendulum-v0", "CartPole-v1"])
@pytest.mark.parametrize("env_id", ["Pendulum-v1", "CartPole-v1"])
@pytest.mark.parametrize("device", ["cpu", "cuda", "auto"])
def test_predict(model_class, env_id, device):
if device == "cuda" and not th.cuda.is_available():