Update WandbLogger implementation (#558)
* Use `global_step` as the x-axis for wandb
* Use a TensorBoard `SummaryWriter` as the core, with `wandb.init(..., sync_tensorboard=True)` (see the sketch below)
* Update all atari examples with wandb

Co-authored-by: Jiayi Weng <trinkle23897@gmail.com>
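
A minimal sketch of the pattern this commit standardizes on (run name, project, and log directory below are illustrative; `wandb` must be installed and logged in):

```python
# WandbLogger is backed by a TensorBoard SummaryWriter: per the commit message,
# it calls wandb.init(..., sync_tensorboard=True), so every scalar written
# through TensorBoard is mirrored to W&B with `global_step` as the x-axis.
from torch.utils.tensorboard import SummaryWriter

from tianshou.utils import WandbLogger

logger = WandbLogger(project="atari.benchmark", name="demo-run")  # illustrative names
writer = SummaryWriter("log/demo-run")     # illustrative log directory
writer.add_text("args", "demo arguments")  # hyperparameters as text, as in the examples
logger.load(writer)                        # attach the writer; scalars now sync to W&B
# result = offpolicy_trainer(..., logger=logger)
```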
vwxyzjn and Trinkle23897 authored Mar 6, 2022
1 parent 2377f2f commit df3d7f5
Showing 10 changed files with 482 additions and 320 deletions.
4 changes: 4 additions & 0 deletions docs/tutorials/logger.rst
@@ -34,8 +34,12 @@ WandbLogger
::

from tianshou.utils import WandbLogger
from torch.utils.tensorboard import SummaryWriter

logger = WandbLogger(...)
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger.load(writer)
result = trainer(..., logger=logger)

Please refer to :class:`~tianshou.utils.WandbLogger` documentation for advanced configuration.
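
For a concrete (hedged) illustration of that advanced configuration as the updated atari scripts use it — timestamped run naming, resuming via `run_id`, and checkpoint saving — where the task, project, paths, and the stand-in policy below are illustrative assumptions:

```python
import datetime
import os

import torch
from torch.utils.tensorboard import SummaryWriter

from tianshou.utils import WandbLogger

now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
log_name = os.path.join("PongNoFrameskip-v4", "dqn", "0", now)  # task/algo/seed/time
log_path = os.path.join("log", log_name)
os.makedirs(log_path, exist_ok=True)

logger = WandbLogger(
    save_interval=1,                           # record resume info every epoch
    name=log_name.replace(os.path.sep, "__"),  # flatten the path-style name for W&B
    run_id=None,                               # pass a previous wandb run id to resume
    config={"task": "PongNoFrameskip-v4", "seed": 0},  # normally the parsed args
    project="atari.benchmark",
)
writer = SummaryWriter(log_path)
writer.add_text("args", "task=PongNoFrameskip-v4 seed=0")
logger.load(writer)

policy_stub = torch.nn.Linear(4, 2)  # stand-in for the trained policy


def save_checkpoint_fn(epoch, env_step, gradient_step):
    # Invoked through the logger (at most every `save_interval` epochs); the
    # returned path is what a run resumed with the same run_id restores from.
    ckpt_path = os.path.join(log_path, "checkpoint.pth")
    torch.save({"model": policy_stub.state_dict()}, ckpt_path)
    return ckpt_path


# result = offpolicy_trainer(..., logger=logger,
#                            save_checkpoint_fn=save_checkpoint_fn,
#                            resume_from_log=True)  # when resuming an earlier run
```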
100 changes: 63 additions & 37 deletions examples/atari/atari_c51.py
@@ -1,4 +1,5 @@
import argparse
import datetime
import os
import pprint

@@ -11,46 +12,54 @@
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.policy import C51Policy
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils import TensorboardLogger, WandbLogger


def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='PongNoFrameskip-v4')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--scale-obs', type=int, default=0)
parser.add_argument('--eps-test', type=float, default=0.005)
parser.add_argument('--eps-train', type=float, default=1.)
parser.add_argument('--eps-train-final', type=float, default=0.05)
parser.add_argument('--buffer-size', type=int, default=100000)
parser.add_argument('--lr', type=float, default=0.0001)
parser.add_argument('--gamma', type=float, default=0.99)
parser.add_argument('--num-atoms', type=int, default=51)
parser.add_argument('--v-min', type=float, default=-10.)
parser.add_argument('--v-max', type=float, default=10.)
parser.add_argument('--n-step', type=int, default=3)
parser.add_argument('--target-update-freq', type=int, default=500)
parser.add_argument('--epoch', type=int, default=100)
parser.add_argument('--step-per-epoch', type=int, default=100000)
parser.add_argument('--step-per-collect', type=int, default=10)
parser.add_argument('--update-per-step', type=float, default=0.1)
parser.add_argument('--batch-size', type=int, default=32)
parser.add_argument('--training-num', type=int, default=10)
parser.add_argument('--test-num', type=int, default=10)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument("--task", type=str, default="PongNoFrameskip-v4")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--scale-obs", type=int, default=0)
parser.add_argument("--eps-test", type=float, default=0.005)
parser.add_argument("--eps-train", type=float, default=1.)
parser.add_argument("--eps-train-final", type=float, default=0.05)
parser.add_argument("--buffer-size", type=int, default=100000)
parser.add_argument("--lr", type=float, default=0.0001)
parser.add_argument("--gamma", type=float, default=0.99)
parser.add_argument("--num-atoms", type=int, default=51)
parser.add_argument("--v-min", type=float, default=-10.)
parser.add_argument("--v-max", type=float, default=10.)
parser.add_argument("--n-step", type=int, default=3)
parser.add_argument("--target-update-freq", type=int, default=500)
parser.add_argument("--epoch", type=int, default=100)
parser.add_argument("--step-per-epoch", type=int, default=100000)
parser.add_argument("--step-per-collect", type=int, default=10)
parser.add_argument("--update-per-step", type=float, default=0.1)
parser.add_argument("--batch-size", type=int, default=32)
parser.add_argument("--training-num", type=int, default=10)
parser.add_argument("--test-num", type=int, default=10)
parser.add_argument("--logdir", type=str, default="log")
parser.add_argument("--render", type=float, default=0.)
parser.add_argument(
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
"--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
)
parser.add_argument('--frames-stack', type=int, default=4)
parser.add_argument('--resume-path', type=str, default=None)
parser.add_argument("--frames-stack", type=int, default=4)
parser.add_argument("--resume-path", type=str, default=None)
parser.add_argument("--resume-id", type=str, default=None)
parser.add_argument(
'--watch',
"--logger",
type=str,
default="tensorboard",
choices=["tensorboard", "wandb"],
)
parser.add_argument("--wandb-project", type=str, default="atari.benchmark")
parser.add_argument(
"--watch",
default=False,
action='store_true',
help='watch the play of pre-trained policy only'
action="store_true",
help="watch the play of pre-trained policy only"
)
parser.add_argument('--save-buffer-name', type=str, default=None)
parser.add_argument("--save-buffer-name", type=str, default=None)
return parser.parse_args()


@@ -101,19 +110,36 @@ def test_c51(args=get_args()):
# collector
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)

# log
log_path = os.path.join(args.logdir, args.task, 'c51')
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
args.algo_name = "c51"
log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
log_path = os.path.join(args.logdir, log_name)

# logger
if args.logger == "wandb":
logger = WandbLogger(
save_interval=1,
name=log_name.replace(os.path.sep, "__"),
run_id=args.resume_id,
config=args,
project=args.wandb_project,
)
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = TensorboardLogger(writer)
if args.logger == "tensorboard":
logger = TensorboardLogger(writer)
else: # wandb
logger.load(writer)

def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

def stop_fn(mean_rewards):
if env.spec.reward_threshold:
return mean_rewards >= env.spec.reward_threshold
elif 'Pong' in args.task:
elif "Pong" in args.task:
return mean_rewards >= 20
else:
return False
@@ -159,7 +185,7 @@ def watch():
n_episode=args.test_num, render=args.render
)
rew = result["rews"].mean()
print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')
print(f"Mean reward (over {result['n/ep']} episodes): {rew}")

if args.watch:
watch()
@@ -190,5 +216,5 @@ def watch():
watch()


if __name__ == '__main__':
if __name__ == "__main__":
test_c51(get_args())
109 changes: 59 additions & 50 deletions examples/atari/atari_dqn.py
@@ -1,4 +1,5 @@
import argparse
import datetime
import os
import pprint

@@ -18,62 +19,63 @@

def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='PongNoFrameskip-v4')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--scale-obs', type=int, default=0)
parser.add_argument('--eps-test', type=float, default=0.005)
parser.add_argument('--eps-train', type=float, default=1.)
parser.add_argument('--eps-train-final', type=float, default=0.05)
parser.add_argument('--buffer-size', type=int, default=100000)
parser.add_argument('--lr', type=float, default=0.0001)
parser.add_argument('--gamma', type=float, default=0.99)
parser.add_argument('--n-step', type=int, default=3)
parser.add_argument('--target-update-freq', type=int, default=500)
parser.add_argument('--epoch', type=int, default=100)
parser.add_argument('--step-per-epoch', type=int, default=100000)
parser.add_argument('--step-per-collect', type=int, default=10)
parser.add_argument('--update-per-step', type=float, default=0.1)
parser.add_argument('--batch-size', type=int, default=32)
parser.add_argument('--training-num', type=int, default=10)
parser.add_argument('--test-num', type=int, default=10)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument("--task", type=str, default="PongNoFrameskip-v4")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--scale-obs", type=int, default=0)
parser.add_argument("--eps-test", type=float, default=0.005)
parser.add_argument("--eps-train", type=float, default=1.)
parser.add_argument("--eps-train-final", type=float, default=0.05)
parser.add_argument("--buffer-size", type=int, default=100000)
parser.add_argument("--lr", type=float, default=0.0001)
parser.add_argument("--gamma", type=float, default=0.99)
parser.add_argument("--n-step", type=int, default=3)
parser.add_argument("--target-update-freq", type=int, default=500)
parser.add_argument("--epoch", type=int, default=100)
parser.add_argument("--step-per-epoch", type=int, default=100000)
parser.add_argument("--step-per-collect", type=int, default=10)
parser.add_argument("--update-per-step", type=float, default=0.1)
parser.add_argument("--batch-size", type=int, default=32)
parser.add_argument("--training-num", type=int, default=10)
parser.add_argument("--test-num", type=int, default=10)
parser.add_argument("--logdir", type=str, default="log")
parser.add_argument("--render", type=float, default=0.)
parser.add_argument(
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
"--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
)
parser.add_argument('--frames-stack', type=int, default=4)
parser.add_argument('--resume-path', type=str, default=None)
parser.add_argument('--resume-id', type=str, default=None)
parser.add_argument("--frames-stack", type=int, default=4)
parser.add_argument("--resume-path", type=str, default=None)
parser.add_argument("--resume-id", type=str, default=None)
parser.add_argument(
'--logger',
"--logger",
type=str,
default="tensorboard",
choices=["tensorboard", "wandb"],
)
parser.add_argument("--wandb-project", type=str, default="atari.benchmark")
parser.add_argument(
'--watch',
"--watch",
default=False,
action='store_true',
help='watch the play of pre-trained policy only'
action="store_true",
help="watch the play of pre-trained policy only"
)
parser.add_argument('--save-buffer-name', type=str, default=None)
parser.add_argument("--save-buffer-name", type=str, default=None)
parser.add_argument(
'--icm-lr-scale',
"--icm-lr-scale",
type=float,
default=0.,
help='use intrinsic curiosity module with this lr scale'
help="use intrinsic curiosity module with this lr scale"
)
parser.add_argument(
'--icm-reward-scale',
"--icm-reward-scale",
type=float,
default=0.01,
help='scaling factor for intrinsic curiosity reward'
help="scaling factor for intrinsic curiosity reward"
)
parser.add_argument(
'--icm-forward-loss-weight',
"--icm-forward-loss-weight",
type=float,
default=0.2,
help='weight for the forward model loss in ICM'
help="weight for the forward model loss in ICM"
)
return parser.parse_args()

@@ -140,29 +142,36 @@ def test_dqn(args=get_args()):
# collector
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)

# log
log_name = 'dqn_icm' if args.icm_lr_scale > 0 else 'dqn'
log_path = os.path.join(args.logdir, args.task, log_name)
if args.logger == "tensorboard":
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = TensorboardLogger(writer)
else:
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
args.algo_name = "dqn_icm" if args.icm_lr_scale > 0 else "dqn"
log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
log_path = os.path.join(args.logdir, log_name)

# logger
if args.logger == "wandb":
logger = WandbLogger(
save_interval=1,
project=args.task,
name=log_name,
name=log_name.replace(os.path.sep, "__"),
run_id=args.resume_id,
config=args,
project=args.wandb_project,
)
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
if args.logger == "tensorboard":
logger = TensorboardLogger(writer)
else: # wandb
logger.load(writer)

def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

def stop_fn(mean_rewards):
if env.spec.reward_threshold:
return mean_rewards >= env.spec.reward_threshold
elif 'Pong' in args.task:
elif "Pong" in args.task:
return mean_rewards >= 20
else:
return False
@@ -183,8 +192,8 @@ def test_fn(epoch, env_step):

def save_checkpoint_fn(epoch, env_step, gradient_step):
# see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
ckpt_path = os.path.join(log_path, 'checkpoint.pth')
torch.save({'model': policy.state_dict()}, ckpt_path)
ckpt_path = os.path.join(log_path, "checkpoint.pth")
torch.save({"model": policy.state_dict()}, ckpt_path)
return ckpt_path

# watch agent's performance
@@ -214,7 +223,7 @@ def watch():
n_episode=args.test_num, render=args.render
)
rew = result["rews"].mean()
print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')
print(f"Mean reward (over {result['n/ep']} episodes): {rew}")

if args.watch:
watch()
@@ -247,5 +256,5 @@ def watch():
watch()


if __name__ == '__main__':
if __name__ == "__main__":
test_dqn(get_args())