Jax-Baseline is a reinforcement learning implementation built on JAX with the Flax/Haiku libraries, mirroring the functionality of Stable-Baselines.
- 2-3 times faster than comparable PyTorch and TensorFlow implementations
- Optimized with JAX's Just-In-Time (JIT) compilation (see the sketch below)
- Flexible support for Gym and Unity ML-Agents environments
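Much of the speedup comes from compiling the whole gradient update into a single XLA computation with `jax.jit`. The following is a minimal sketch of that pattern, not Jax-Baseline's actual API: the network, the plain-SGD step, and the hyperparameter values are all placeholders.

```python
import jax
import jax.numpy as jnp

def q_network(params, obs):
    # Placeholder MLP: params is a list of (weights, bias) pairs.
    x = obs
    for w, b in params[:-1]:
        x = jax.nn.relu(x @ w + b)
    w, b = params[-1]
    return x @ w + b

@jax.jit  # the entire TD update compiles into one fused XLA computation
def td_update(params, target_params, batch, lr=2e-4, gamma=0.995):
    obs, actions, rewards, next_obs, dones = batch

    def loss_fn(p):
        q = q_network(p, obs)[jnp.arange(actions.shape[0]), actions]
        next_q = q_network(target_params, next_obs).max(axis=-1)
        targets = rewards + gamma * (1.0 - dones) * next_q
        return jnp.mean((q - jax.lax.stop_gradient(targets)) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(params)
    # Plain SGD for brevity; a real implementation would use an Optax optimizer.
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss
```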
```bash
pip install -r requirement.txt
pip install .
```
- ✔️ : Implemented as an option
- ✅ : Implemented by default, as in the original paper
- ❌ : Not implemented yet, or cannot be implemented
Name | Q-Net based | Actor-Critic based | DPG based |
---|---|---|---|
Gymnasium | ✔️ | ✔️ | ✔️ |
VectorizedGym with Ray | ✔️ | ✔️ | ✔️ |
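The "VectorizedGym with Ray" row refers to collecting experience from several environment copies at once. Below is a generic sketch using Gymnasium's own vector API, not the repository's wrapper; Jax-Baseline's Ray backend distributes the same batched-step idea across worker processes.

```python
import gymnasium as gym

# Run 4 copies of the environment in lockstep.
envs = gym.vector.SyncVectorEnv(
    [lambda: gym.make("CartPole-v1") for _ in range(4)]
)
obs, info = envs.reset(seed=0)          # obs has shape (4, obs_dim)
actions = envs.action_space.sample()    # one action per environment
obs, rewards, terminated, truncated, infos = envs.step(actions)
envs.close()
```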
Name | Double [1] | Dueling [2] | PER [3] | N-step [4][5] | NoisyNet [6] | Munchausen [7] | Ape-X [8] | HL-Gauss [9] |
---|---|---|---|---|---|---|---|---|
DQN [10] | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ❌ |
C51 [11] | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
QRDQN [12] | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ❌ |
IQN [13] | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ❌ | ❌ |
FQF [14] | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ❌ | ❌ |
SPR [15] | ✅ | ✅ | ✅ | ✅ | ✅ | ✔️ | ❌ | ✔️ |
BBF [16] | ✅ | ✅ | ✅ | ✅ | ✔️ | ✔️ | ❌ | ✔️ |
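As an example of what the first column means, Double [1] decouples action selection (online network) from action evaluation (target network) when forming the TD target. An illustrative sketch, not the repository's code:

```python
import jax.numpy as jnp

def double_dqn_target(q_online_next, q_target_next, rewards, dones, gamma=0.99):
    """Double DQN: pick argmax actions with the online net,
    evaluate them with the target net (van Hasselt et al., 2016)."""
    best_actions = jnp.argmax(q_online_next, axis=-1)          # selection
    next_q = jnp.take_along_axis(
        q_target_next, best_actions[:, None], axis=-1
    ).squeeze(-1)                                              # evaluation
    return rewards + gamma * (1.0 - dones) * next_q
```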
Name | Box | Discrete | IMPALA [17] | Simba [18] |
---|---|---|---|---|
A2C [19] | ✔️ | ✔️ | ✔️ | ❌ |
PPO [20] | ✔️ | ✔️ | ✔️ [21] | ❌ |
Truly PPO (TPPO) [22] | ✔️ | ✔️ | ❌ | ❌ |
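For reference, the clipped surrogate objective that distinguishes PPO [20] from A2C [19] fits in a few lines. This is an illustrative sketch, not the repository's code; `clip_eps=0.2` is just a common default.

```python
import jax.numpy as jnp

def ppo_clip_loss(log_prob, old_log_prob, advantages, clip_eps=0.2):
    """Clipped surrogate objective from Schulman et al., 2017 (to minimize)."""
    ratio = jnp.exp(log_prob - old_log_prob)
    unclipped = ratio * advantages
    clipped = jnp.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    return -jnp.mean(jnp.minimum(unclipped, clipped))
```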
Name | PER [3] | N-step [4][5] | Ape-X [8] | Simba [18] |
---|---|---|---|---|
DDPG [23] | ✔️ | ✔️ | ✔️ | ✔️ |
TD3 [24] | ✔️ | ✔️ | ✔️ | ✔️ |
SAC [25] | ✔️ | ✔️ | ❌ | ✔️ |
TQC [26] | ✔️ | ✔️ | ❌ | ✔️ |
TD7 [27] | ✅ (LAP [28]) | ❌ | ❌ | ✔️ |
BRO [29] | ❌ | ❌ | ❌ | ❌ |
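To make the rows concrete, TD3 [24] forms its critic target from a noise-smoothed target action and the minimum of two target critics. A sketch under the assumption that `actor_target`, `q1_target`, and `q2_target` are callables closing over their own parameters; it is not the repository's code.

```python
import jax
import jax.numpy as jnp

def td3_target(rng, actor_target, q1_target, q2_target,
               next_obs, rewards, dones,
               gamma=0.99, policy_noise=0.2, noise_clip=0.5):
    """TD3 target: smoothed target action plus clipped double-Q minimum
    (Fujimoto et al., 2018)."""
    next_action = actor_target(next_obs)
    noise = jnp.clip(
        policy_noise * jax.random.normal(rng, next_action.shape),
        -noise_clip, noise_clip,
    )
    next_action = jnp.clip(next_action + noise, -1.0, 1.0)
    next_q = jnp.minimum(q1_target(next_obs, next_action),
                         q2_target(next_obs, next_action))
    return rewards + gamma * (1.0 - dones) * next_q
```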
To run DQN (or C51, QRDQN, IQN, FQF) on an Atari environment:
```bash
python test/run_qnet.py --algo DQN --env BreakoutNoFrameskip-v4 --learning_rate 0.0002 \
    --steps 5e5 --batch 32 --train_freq 1 --target_update 1000 --node 512 \
    --hidden_n 1 --final_eps 0.01 --learning_starts 20000 --gamma 0.995 --clip_rewards
```
On Atari Breakout, 500K steps complete in roughly 15 minutes (~540 steps/sec), measured on an NVIDIA RTX 3080 and an AMD Ryzen 9 5950X in a single process:
```
score : 9.600, epsilon : 0.010, loss : 0.181 |: 100%|███████| 500000/500000 [15:24<00:00, 540.88it/s]
```