Jax-Baseline

Jax-Baseline is a Reinforcement Learning implementation using JAX and Flax/Haiku libraries, mirroring the functionality of Stable-Baselines.

Features

  • 2–3 times faster than previous PyTorch and TensorFlow implementations
  • Optimized with JAX's Just-In-Time (JIT) compilation (see the sketch below)
  • Works with both Gymnasium and Unity ML environments
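
As a minimal illustration of why JIT compilation helps, the sketch below compiles an entire DQN-style TD update into a single XLA call that is reused every training step. It is a generic JAX example, not code from this repository; q_network and dqn_update are illustrative names only.

import jax
import jax.numpy as jnp

def q_network(params, obs):
    # Toy linear Q-network; the real library builds its networks with Flax/Haiku.
    return obs @ params["w"] + params["b"]

@jax.jit  # the whole update below is compiled once and reused every step
def dqn_update(params, target_params, batch, lr=2.5e-4, gamma=0.995):
    obs, actions, rewards, next_obs, dones = batch
    def loss_fn(p):
        q_taken = jnp.take_along_axis(q_network(p, obs), actions[:, None], axis=1)[:, 0]
        q_next = q_network(target_params, next_obs).max(axis=1)
        target = rewards + gamma * (1.0 - dones) * q_next
        return jnp.mean((q_taken - jax.lax.stop_gradient(target)) ** 2)
    loss, grads = jax.value_and_grad(loss_fn)(params)
    # Plain SGD step for brevity; a real implementation would use a proper optimizer.
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss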

Installation

pip install -r requirement.txt
pip install .
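
After installing, a quick way to confirm that JAX actually sees your accelerator (a generic JAX check, not a command provided by this repository):

python -c "import jax; print(jax.default_backend(), jax.devices())"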

Implementation Status

  • ✔️ : Optionally implemented
  • ✅ : Implemented by default, as in the original paper
  • ❌ : Not implemented yet, or cannot be implemented

Supported Environments

| Name | Q-Net based | Actor-Critic based | DPG based |
| --- | --- | --- | --- |
| Gymnasium | ✔️ | ✔️ | ✔️ |
| Vectorized Gym with Ray | ✔️ | ✔️ | ✔️ |
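
The environment names passed to the training scripts (see the Test section below) are standard Gymnasium IDs; Atari IDs such as BreakoutNoFrameskip-v4 additionally require the Atari extras (ale-py). A minimal Gymnasium sanity check, independent of this library:

import gymnasium as gym

env = gym.make("CartPole-v1")  # ships with Gymnasium; Atari IDs need gymnasium[atari]
obs, info = env.reset(seed=0)
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
env.close()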

Implemented Algorithms

Q-Net based

Name Double [1] Dueling [2] PER [3] N-step [4][5] NoisyNet [6] Munchausen [7] Ape-X [8] HL-Gauss [9]
DQN [10] ✔️ ✔️ ✔️ ✔️ ✔️ ✔️ ✔️
C51 [11] ✔️ ✔️ ✔️ ✔️ ✔️ ✔️ ✔️ ✔️
QRDQN [12] ✔️ ✔️ ✔️ ✔️ ✔️ ✔️ ✔️
IQN [13] ✔️ ✔️ ✔️ ✔️ ✔️ ✔️
FQF [14] ✔️ ✔️ ✔️ ✔️ ✔️ ✔️
SPR [15] ✔️ ✔️
BBF [16] ✔️ ✔️ ✔️

Actor-Critic based

Name Box Discrete IMPALA [17] Simba [18]
A2C [19] ✔️ ✔️ ✔️
PPO [20] ✔️ ✔️ ✔️ [21]
Truly PPO (TPPO) [22] ✔️ ✔️

DPG based

Name PER [3] N-step [4][5] Ape-X [8] Simba [18]
DDPG [23] ✔️ ✔️ ✔️ ✔️
TD3 [24] ✔️ ✔️ ✔️ ✔️
SAC [25] ✔️ ✔️ ✔️
TQC [26] ✔️ ✔️ ✔️
TD7 [27] ✅ (LAP [28]) ✔️
BRO [29]

Performance Comparison

Test

To test DQN (or C51, QRDQN, IQN, FQF) on Atari:

python test/run_qnet.py --algo DQN --env BreakoutNoFrameskip-v4 --learning_rate 0.0002 \
		--steps 5e5 --batch 32 --train_freq 1 --target_update 1000 --node 512 \
		--hidden_n 1 --final_eps 0.01 --learning_starts 20000 --gamma 0.995 --clip_rewards
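
The flags map onto standard DQN hyperparameters. The sketch below shows how they typically interact in a DQN training loop; it illustrates standard DQN semantics and is not taken from run_qnet.py (the exploration-annealing fraction in particular is an assumption):

TOTAL_STEPS = 500_000      # --steps 5e5
LEARNING_STARTS = 20_000   # --learning_starts: fill the replay buffer first
TRAIN_FREQ = 1             # --train_freq: one gradient step per environment step
TARGET_UPDATE = 1_000      # --target_update: sync the target network this often
FINAL_EPS = 0.01           # --final_eps: floor of the epsilon-greedy schedule
GAMMA = 0.995              # --gamma: discount used in the TD target

for step in range(TOTAL_STEPS):
    # Anneal epsilon from 1.0 down to FINAL_EPS over the first ~10% of training
    # (the 10% fraction is an assumed value, not a run_qnet.py flag).
    eps = max(FINAL_EPS, 1.0 - step / (0.1 * TOTAL_STEPS))
    # ... act epsilon-greedily, store the transition in the replay buffer ...
    if step >= LEARNING_STARTS and step % TRAIN_FREQ == 0:
        pass  # sample a batch of 32 (--batch) and take one gradient step
    if step >= LEARNING_STARTS and step % TARGET_UPDATE == 0:
        pass  # copy the online network weights into the target network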

500K steps run in about 15 minutes on Atari Breakout (roughly 540 steps/sec). Performance was measured on an NVIDIA RTX 3080 and an AMD Ryzen 9 5950X in a single process.

score : 9.600, epsilon : 0.010, loss : 0.181 |: 100%|███████| 500000/500000 [15:24<00:00, 540.88it/s]

Footnotes

  1. Double DQN

  2. Dueling DQN

  3. PER

  4. N-step TD

  5. Rainbow DQN

  6. Noisy Networks

  7. Munchausen RL

  8. Ape-X

  9. HL-Gauss

  10. DQN

  11. C51

  12. QRDQN

  13. IQN

  14. FQF

  15. SPR

  16. BBF

  17. IMPALA

  18. SIMBA

  19. A3C

  20. PPO

  21. IMPALA + PPO (APPO)

  22. Truly PPO

  23. DDPG

  24. TD3

  25. SAC

  26. TQC

  27. TD7

  28. LAP

  29. BRO
