Tongzhou Wang, Antonio Torralba, Phillip Isola, Amy Zhang
This repository is the official code release for the paper Optimal Goal-Reaching Reinforcement Learning via Quasimetric Learning, published at ICML 2023. We provide a PyTorch implementation of the proposed Quasimetric RL algorithm (QRL).
See the webpage for an explanation.
The code has been tested on
- CUDA 11 with NVIDIA RTX Titan, NVIDIA 2080Ti, NVIDIA Titan XP, NVIDIA V100, and NVIDIA 3080.
Software dependencies are listed in quasimetric-rl/requirements.txt (lines 1 to 10 at commit 4f11323).
Note
d4rl depends on mujoco_py, which can be difficult to install. The code lazily imports mujoco_py and d4rl only if the user requests such environments. Therefore, installing them is not necessary to run the QRL algorithm, e.g., on a custom environment. However, running QRL on the provided environments (d4rl.maze2d and GCRL) requires them.
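The lazy-import behavior described in the note can be sketched as follows. This is a minimal illustration of the general pattern, not the repository's actual code: the helper names `lazy_import` and `load_backend` and the error message are hypothetical.

```python
import importlib
from types import ModuleType
from typing import Optional


def lazy_import(module_name: str) -> ModuleType:
    """Import `module_name` only at the moment it is needed.

    Hypothetical helper for illustration; not part of quasimetric_rl.
    """
    try:
        return importlib.import_module(module_name)
    except ImportError as e:
        raise ImportError(
            f"'{module_name}' is required for this environment but is not installed."
        ) from e


def load_backend(kind: str) -> Optional[ModuleType]:
    """Return the optional dependency that an environment kind needs, if any."""
    if kind == "d4rl":
        # The heavy dependency is imported here, not at module load time,
        # so users of custom environments never need it installed.
        return lazy_import("d4rl")
    if kind == "custom":
        return None  # a user-defined environment needs no optional backend
    raise ValueError(f"unknown environment kind: {kind!r}")
```

With this structure, `load_backend("custom")` succeeds on a machine without d4rl or mujoco_py, and the import error only appears when a d4rl environment is actually requested.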
- quasimetric_rl.modules implements the actor and critic components, as well as their associated QRL losses.
- quasimetric_rl.data implements data loading and memory buffer utilities, as well as creation of environments.
- online.main provides an entry point to online experiments.
- offline.main provides an entry point to offline experiments.
Online and offline settings mostly differ in the usage of data storage:
- Offline: static dataset.
- Online: replay buffer that dynamically grows and stores more experiences.
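The difference between the two storage styles can be sketched with toy classes. This is an illustrative sketch only: the class names `StaticDataset` and `GrowingReplayBuffer` and the tuple-based transition format are assumptions, not the repository's data structures.

```python
import random
from typing import Iterable, List, Sequence, Tuple

Transition = Tuple[int, int, int]  # toy (state, action, next_state)


class _Storage:
    """Shared sampling logic for both settings."""

    _transitions: Sequence[Transition]

    def __len__(self) -> int:
        return len(self._transitions)

    def sample(self, batch_size: int, rng: random.Random) -> List[Transition]:
        # Uniform sampling with replacement from whatever is stored.
        return [rng.choice(self._transitions) for _ in range(batch_size)]


class StaticDataset(_Storage):
    """Offline setting: the data is fixed at construction time."""

    def __init__(self, transitions: Iterable[Transition]):
        self._transitions = tuple(transitions)  # immutable, never grows


class GrowingReplayBuffer(_Storage):
    """Online setting: new experience is appended during training."""

    def __init__(self) -> None:
        self._transitions = []

    def add(self, transition: Transition) -> None:
        self._transitions.append(transition)
```

In this sketch the training loop can call `sample` identically in both settings; only the online loop additionally calls `add` after each environment step.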
In both online.main and offline.main, there is a Conf object containing all the provided knobs with which you can customize QRL behavior. This Conf object is updated with command-line arguments via hydra, and then used to create the modules and losses.
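To illustrate how `key.subkey=value` arguments can update a nested configuration object, here is a dependency-free sketch. It only mimics the idea; the repository relies on hydra for this, and the dataclasses, default values, and `apply_overrides` helper below are hypothetical (the field names are borrowed from the example commands in this README).

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class DeviceConf:
    index: int = 0  # hypothetical default


@dataclass
class AgentConf:
    num_critics: int = 2  # hypothetical default


@dataclass
class Conf:
    seed: int = 0  # hypothetical default
    total_optim_steps: int = 1000  # hypothetical default
    device: DeviceConf = field(default_factory=DeviceConf)
    agent: AgentConf = field(default_factory=AgentConf)


def apply_overrides(conf: Conf, overrides: List[str]) -> Conf:
    """Apply `dotted.key=value` overrides in place and return `conf`."""
    for item in overrides:
        dotted_key, raw_value = item.split("=", 1)
        *parents, leaf = dotted_key.split(".")
        target = conf
        for name in parents:
            target = getattr(target, name)  # walk down to the nested conf
        if not hasattr(target, leaf):
            raise AttributeError(f"unknown config key: {dotted_key}")
        current = getattr(target, leaf)
        # Cast the string to the type of the existing default value.
        setattr(target, leaf, type(current)(raw_value))
    return conf
```

For example, `apply_overrides(Conf(), ["seed=12131415", "device.index=2"])` changes only those two fields and leaves every other knob at its default.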
To reproduce the offline d4rl experiments in the paper, you can use commands similar to these:
# run umaze
./offline/run_maze2d.sh env.name='maze2d-umaze-v1'
# run medium maze with custom seed, the GPU at index 2, and not training an actor
./offline/run_maze2d.sh env.name='maze2d-medium-v1' seed=12131415 device.index=2 agent.actor=null
# run large maze with custom seed, the GPU at index 3, and 100 gradient steps
./offline/run_maze2d.sh env.name='maze2d-large-v1' seed=44411223 device.index=3 total_optim_steps=100
To reproduce the online gcrl experiments in the paper, you can use commands similar to these:
# run state-input FetchReach
./online/run_gcrl.sh env.name='FetchReach'
# run image-input FetchPush with custom seed and the GPU at index 2
./online/run_gcrl.sh env.name='FetchPushImage' seed=12131415 device.index=2
# run state-input FetchSlide with custom seed, 10 environment steps, and 3 critics
./online/run_gcrl.sh env.name='FetchSlide' seed=44411223 interaction.total_env_steps=10 agent.num_critics=3
Example code for how to load a trained checkpoint:
import os
import torch
from omegaconf import OmegaConf, SCMode
import yaml
from quasimetric_rl.data import Dataset
from quasimetric_rl.modules import QRLAgent, QRLConf
expr_checkpoint = '/xxx/xx/xx/xxxx.pth' # FIXME
expr_dir = os.path.dirname(expr_checkpoint)
with open(expr_dir + '/config.yaml', 'r') as f:
    # load saved conf
    conf = OmegaConf.create(yaml.safe_load(f))
# 1. How to create env
dataset: Dataset = Dataset.Conf(kind=conf.env.kind, name=conf.env.name).make(dummy=True) # dummy: don't load data
env = dataset.create_env() # <-- you can use this now!
# episodes = list(dataset.load_episodes()) # if you want to load episodes for offline data
# 2. How to re-create QRL agent
agent_conf: QRLConf = OmegaConf.to_container(
    OmegaConf.merge(OmegaConf.structured(QRLConf()), conf.agent),  # overwrite with loaded conf
    structured_config_mode=SCMode.INSTANTIATE,  # create the object
)
agent: QRLAgent = agent_conf.make(env_spec=dataset.env_spec, total_optim_steps=1)[0] # you can move to your fav device
# 3. Load checkpoint
agent.load_state_dict(torch.load(expr_checkpoint, map_location='cpu')['agent'])
Note
- We recommend monitoring experiments with tensorboard.
- [Offline Only] If you do not want to train an actor (e.g., because the action space is discrete and the code only implements policy training via backpropagating through quasimetric critics), add agent.actor=null.
- Environment flag QRL_DEBUG=1 will enable additional checks and automatic pdb.post_mortem. It is your debugging friend.
- Adding environments can be done via quasimetric_rl.data.register_(online|offline)_env. See their docstrings for details. To construct a quasimetric_rl.data.EpisodeData from a trajectory, see the EpisodeData.from_simple_trajectory helper constructor.
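For readers unfamiliar with automatic post-mortem debugging, the general technique behind a flag like QRL_DEBUG=1 can be sketched as below. This is an assumption about how such a flag could be wired, not the repository's implementation: the function names are hypothetical, and only the exact value `"1"` is treated as enabling debug mode.

```python
import os
import pdb
import sys
from typing import Mapping


def debug_enabled(environ: Mapping[str, str] = os.environ) -> bool:
    """Return True when the environment flag is set to "1" (assumed convention)."""
    return environ.get("QRL_DEBUG", "0") == "1"


def install_post_mortem_hook() -> None:
    """Open pdb at the failing frame whenever an uncaught exception occurs."""

    def hook(exc_type, exc_value, traceback):
        sys.__excepthook__(exc_type, exc_value, traceback)  # print the traceback as usual
        pdb.post_mortem(traceback)  # then drop into the debugger

    sys.excepthook = hook
```

A script would typically call `install_post_mortem_hook()` once at startup, guarded by `if debug_enabled():`, so that normal runs are unaffected.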
Q: How to run QRL where the goal is not a single state?
A: If more than one state is considered as "reaching a goal", then we can think of the goal as a set of states. In this case, we can use the trick discussed in Appendix A of the paper: (1) Encode this goal as a tensor of the same format as states (but distinct from them, e.g., via an added indicator dimension). (2) Add transitions (state that reaches goal -> goal) whenever the agent reaches the goal. QRL can extend to such general goals in this way. This can be implemented by either modifying the dataset storage and sampling code [more flexible but involved], or changing the environment to append a transition when reaching the goal [simpler]. Coming soon: example code on the latter approach.
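The two steps of this trick can be sketched on toy data. This is not the promised example code and does not use the repository's data classes: states are plain tuples, and the helper names, the all-zero goal payload, and the `is_goal` predicate are assumptions made for illustration.

```python
from typing import Callable, List, Sequence, Tuple

State = Tuple[float, ...]
Transition = Tuple[State, State]  # (observation, next_observation)


def as_state(s: State) -> State:
    """Step (1): ordinary states get indicator 0 in an added last dimension."""
    return tuple(s) + (0.0,)


def goal_token(dim: int) -> State:
    """Step (1): the abstract goal has the same format but indicator 1.

    The all-zero payload is an arbitrary choice for this sketch.
    """
    return (0.0,) * dim + (1.0,)


def augment_episode(
    states: Sequence[State], is_goal: Callable[[State], bool]
) -> List[Transition]:
    """Step (2): keep the real transitions and add (goal-reaching state -> goal)."""
    dim = len(states[0])
    transitions = [
        (as_state(s), as_state(s_next)) for s, s_next in zip(states[:-1], states[1:])
    ]
    for s in states:
        if is_goal(s):
            transitions.append((as_state(s), goal_token(dim)))
    return transitions
```

The indicator dimension guarantees that the goal token never collides with a real state, even if their other coordinates happen to match.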
Q: How to deal with variable-cost transitions?
A: The current code assumes that each transition incurs a fixed cost. To support variable-cost transitions, simply modify the lines that set this fixed cost to use -data.rewards as costs. However, you should make sure that your environment/dataset is set up to provide the expected non-positive rewards. We do not check that in the current code.
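Since the code does not validate rewards, a small check such as the following could be added before using negated rewards as costs. This is a hedged sketch: `costs_from_rewards` is a hypothetical helper, and a plain Python sequence stands in for the reward tensor in `data.rewards`.

```python
from typing import List, Sequence


def costs_from_rewards(rewards: Sequence[float]) -> List[float]:
    """Convert non-positive rewards into non-negative per-transition costs."""
    positive = [r for r in rewards if r > 0]
    if positive:
        # A positive reward would become a negative cost, which the
        # non-positive-reward assumption stated above rules out.
        raise ValueError(f"expected non-positive rewards, got e.g. {positive[:3]}")
    return [-r for r in rewards]
```

Failing loudly here is a design choice for the sketch; silently clipping positive rewards would hide a misconfigured environment or dataset.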
Tongzhou Wang, Antonio Torralba, Phillip Isola, Amy Zhang. "Optimal Goal-Reaching Reinforcement Learning via Quasimetric Learning" International Conference on Machine Learning (ICML). 2023.
@inproceedings{tongzhouw2023qrl,
title={Optimal Goal-Reaching Reinforcement Learning via Quasimetric Learning},
author={Wang, Tongzhou and Torralba, Antonio and Isola, Phillip and Zhang, Amy},
booktitle={International Conference on Machine Learning},
organization={PMLR},
year={2023}
}
For questions about the code provided in this repository, please open a GitHub issue.
For questions about the paper, please contact Tongzhou Wang (tongzhou AT mit DOT edu).
This repo is under the MIT license. Please check the LICENSE file.