Add a performance test - might be slow?
RedTachyon committed Jul 7, 2023
1 parent d018cbd commit 29cdbfa
Showing 1 changed file with 61 additions and 0 deletions.
tests/algorithms/test_sqil.py (61 additions, 0 deletions)
@@ -5,10 +5,13 @@
import stable_baselines3.common.buffers as buffers
import stable_baselines3.common.vec_env as vec_env
import stable_baselines3.dqn as dqn
from stable_baselines3 import ppo
from stable_baselines3.common.evaluation import evaluate_policy

from imitation.algorithms import base as algo_base
from imitation.algorithms import sqil
from imitation.data import rollout, wrappers
from imitation.testing import reward_improvement


def test_sqil_demonstration_buffer(rng):
@@ -169,3 +172,61 @@ def test_sqil_cartpole_few_demonstrations(rng):
        dqn_kwargs=dict(learning_starts=10),
    )
    model.train(total_timesteps=1_000)


def test_sqil_performance(rng):
    env = gym.make("CartPole-v1")
    venv = vec_env.DummyVecEnv([lambda: wrappers.RolloutInfoWrapper(env)])

    expert = ppo.PPO(
        policy=ppo.MlpPolicy,
        env=env,
        seed=0,
        batch_size=64,
        ent_coef=0.0,
        learning_rate=0.0003,
        n_epochs=10,
        n_steps=64,
    )
    expert.learn(10_000)

    expert_reward, _ = evaluate_policy(expert, env, 10)
    print(expert_reward)

    rollouts = rollout.rollout(
        expert.policy,
        venv,
        rollout.make_sample_until(min_timesteps=None, min_episodes=50),
        rng=rng,
    )

    demonstrations = rollout.flatten_trajectories(rollouts)
    demonstrations = demonstrations[:5]

    model = sqil.SQIL(
        venv=venv,
        demonstrations=demonstrations,
        policy="MlpPolicy",
        dqn_kwargs=dict(learning_starts=1000),
    )

    rewards_before, _ = evaluate_policy(
        model.policy,
        env,
        10,
        return_episode_rewards=True,
    )

    model.train(total_timesteps=10_000)

    rewards_after, _ = evaluate_policy(
        model.policy,
        env,
        10,
        return_episode_rewards=True,
    )

    assert reward_improvement.is_significant_reward_improvement(
        rewards_before,
        rewards_after,
    )
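
For context, a minimal standalone sketch of the significance check used at the end of the new test. It assumes only the two-positional-argument call shown in the diff above; the reward lists are hypothetical illustrative data, not values from an actual test run:

    from imitation.testing import reward_improvement

    # Hypothetical episode returns before and after training (illustrative only).
    rewards_before = [9.0, 10.0, 11.0, 10.0, 9.0]
    rewards_after = [200.0, 190.0, 210.0, 205.0, 195.0]

    # Returns True when the post-training rewards show a statistically
    # significant improvement over the pre-training rewards.
    improved = reward_improvement.is_significant_reward_improvement(
        rewards_before,
        rewards_after,
    )
    print(improved)  # expected to be True for clearly separated samples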
