Skip to content

Commit

Permalink
Merge pull request #23 from vwxyzjn/wandb
Browse files Browse the repository at this point in the history
Add wandb support
  • Loading branch information
ikostrikov2 authored May 13, 2022
2 parents 2418cef + 42dd5c5 commit 52b3653
Show file tree
Hide file tree
Showing 6 changed files with 329 additions and 6 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
*.tfevents.*
tmp
wandb

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
Expand Down
7 changes: 7 additions & 0 deletions examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,13 @@ OpenAI Gym MuJoCo tasks
python train.py --env_name=HalfCheetah-v2 --save_dir=./tmp/
```

Experiment tracking with Weights and Biases

```bash
python train.py --env_name=HalfCheetah-v2 --save_dir=./tmp/ --track
```


DeepMind Control suite (--env-name=domain-task)

```bash
Expand Down
25 changes: 22 additions & 3 deletions examples/train.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import random
import time

import numpy as np
import tqdm
Expand Down Expand Up @@ -29,6 +30,9 @@
'Number of training steps to start training.')
flags.DEFINE_boolean('tqdm', True, 'Use tqdm progress bar.')
flags.DEFINE_boolean('save_video', False, 'Save videos during evaluation.')
flags.DEFINE_boolean('track', False, 'Track experiments with Weights and Biases.')
flags.DEFINE_string('wandb_project_name', "jaxrl", "The wandb's project name.")
flags.DEFINE_string('wandb_entity', None, "the entity (team) of wandb's project")
config_flags.DEFINE_config_file(
'config',
'configs/sac_default.py',
Expand All @@ -37,8 +41,24 @@


def main(_):
kwargs = dict(FLAGS.config)
algo = kwargs.pop('algo')
run_name = f"{FLAGS.env_name}__{algo}__{FLAGS.seed}__{int(time.time())}"
if FLAGS.track:
import wandb

wandb.init(
project=FLAGS.wandb_project_name,
entity=FLAGS.wandb_entity,
sync_tensorboard=True,
config=FLAGS,
name=run_name,
monitor_gym=True,
save_code=True,
)

summary_writer = SummaryWriter(
os.path.join(FLAGS.save_dir, 'tb', str(FLAGS.seed)))
os.path.join(FLAGS.save_dir, run_name))

if FLAGS.save_video:
video_train_folder = os.path.join(FLAGS.save_dir, 'video', 'train')
Expand All @@ -53,8 +73,7 @@ def main(_):
np.random.seed(FLAGS.seed)
random.seed(FLAGS.seed)

kwargs = dict(FLAGS.config)
algo = kwargs.pop('algo')

replay_buffer_size = kwargs.pop('replay_buffer_size')
if algo == 'sac':
agent = SACLearner(FLAGS.seed,
Expand Down
3 changes: 1 addition & 2 deletions jaxrl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from gym.wrappers.pixel_observation import PixelObservationWrapper

from jaxrl import wrappers
from jaxrl.wrappers import VideoRecorder


def make_env(env_name: str,
Expand Down Expand Up @@ -44,7 +43,7 @@ def make_env(env_name: str,
env = RescaleAction(env, -1.0, 1.0)

if save_folder is not None:
env = VideoRecorder(env, save_folder=save_folder)
env = gym.wrappers.RecordVideo(env, save_folder)

if from_pixels:
if env_name in env_ids:
Expand Down
Loading

0 comments on commit 52b3653

Please sign in to comment.