Skip to content

Commit

Permalink
Some documentation updates (not complete)
Browse files Browse the repository at this point in the history
  • Loading branch information
RedTachyon committed Jul 6, 2023
1 parent c4b0521 commit 1b5338b
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ Currently, we have implementations of the algorithms below. 'Discrete' and 'Cont
| [Adversarial Inverse Reinforcement Learning](https://arxiv.org/abs/1710.11248) | [`algoritms.airl`](https://imitation.readthedocs.io/en/latest/algorithms/airl.html) |||
| [Generative Adversarial Imitation Learning](https://arxiv.org/abs/1606.03476) | [`algorithms.gail`](https://imitation.readthedocs.io/en/latest/algorithms/gail.html) |||
| [Deep RL from Human Preferences](https://arxiv.org/abs/1706.03741) | [`algorithms.preference_comparisons`](https://imitation.readthedocs.io/en/latest/algorithms/preference_comparisons.html) |||
| [Soft Q Imitation Learning](https://arxiv.org/abs/1905.11108) | [`algorithms.sqil`](https://imitation.readthedocs.io/en/latest/algorithms/sqil.html) |||


You can find [the documentation here](https://imitation.readthedocs.io/en/latest/).
Expand Down
59 changes: 59 additions & 0 deletions docs/algorithms/sqil.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
.. _soft q imitation learning docs:

=======================
Soft Q Imitation Learning (SQIL)
=======================

<add description of SQIL>

Example
=======

Detailed example notebook: :doc:`../tutorials/10_train_sqil`

.. testcode::
:skipif: skip_doctests

import numpy as np
import gym
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.ppo import MlpPolicy

from imitation.algorithms import sqil
from imitation.data import rollout
from imitation.data.wrappers import RolloutInfoWrapper

rng = np.random.default_rng(0)
env = gym.make("CartPole-v1")
expert = PPO(policy=MlpPolicy, env=env)
expert.learn(1000)

rollouts = rollout.rollout(
expert,
DummyVecEnv([lambda: RolloutInfoWrapper(env)]),
rollout.make_sample_until(min_timesteps=None, min_episodes=50),
rng=rng,
)
transitions = rollout.flatten_trajectories(rollouts)

sqil_trainer = sqil.SQIL(
venv=DummyVecEnv([lambda: env]),
demonstrations=transitions,
policy="MlpPolicy",
)
sqil_trainer.train(n_epochs=1)
reward, _ = evaluate_policy(sqil_trainer.policy, env, 10)
print("Reward:", reward)

.. testoutput::
:hide:

...

API
===
.. autoclass:: imitation.algorithms.sqil.SQIL
:members:
:noindex:

0 comments on commit 1b5338b

Please sign in to comment.