From 91b26aaf316bd375aad42140b24acd86161d0518 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Sun, 6 Mar 2022 17:40:47 -0500 Subject: [PATCH] Update WandbLogger implementation (#558) * Use `global_step` as the x-axis for wandb * Use Tensorboard SummaryWriter as core with `wandb.init(..., sync_tensorboard=True)` * Update all atari examples with wandb Co-authored-by: Jiayi Weng --- docs/tutorials/logger.rst | 4 ++ examples/atari/atari_c51.py | 100 ++++++++++++++++---------- examples/atari/atari_dqn.py | 109 +++++++++++++++++----------- examples/atari/atari_fqf.py | 104 +++++++++++++++++---------- examples/atari/atari_iqn.py | 104 +++++++++++++++++---------- examples/atari/atari_ppo.py | 123 +++++++++++++++++--------------- examples/atari/atari_qrdqn.py | 96 ++++++++++++++++--------- examples/atari/atari_rainbow.py | 120 ++++++++++++++++++------------- test/modelbased/test_psrl.py | 7 +- tianshou/utils/logger/wandb.py | 35 +++++---- 10 files changed, 482 insertions(+), 320 deletions(-) diff --git a/docs/tutorials/logger.rst b/docs/tutorials/logger.rst index 6a65e3296..c8161374d 100644 --- a/docs/tutorials/logger.rst +++ b/docs/tutorials/logger.rst @@ -34,8 +34,12 @@ WandbLogger :: from tianshou.utils import WandbLogger + from torch.utils.tensorboard import SummaryWriter logger = WandbLogger(...) + writer = SummaryWriter(log_path) + writer.add_text("args", str(args)) + logger.load(writer) result = trainer(..., logger=logger) Please refer to :class:`~tianshou.utils.WandbLogger` documentation for advanced configuration. diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py index 4d5a91c2b..71c52e0d3 100644 --- a/examples/atari/atari_c51.py +++ b/examples/atari/atari_c51.py @@ -1,4 +1,5 @@ import argparse +import datetime import os import pprint @@ -11,46 +12,54 @@ from tianshou.data import Collector, VectorReplayBuffer from tianshou.policy import C51Policy from tianshou.trainer import offpolicy_trainer -from tianshou.utils import TensorboardLogger +from tianshou.utils import TensorboardLogger, WandbLogger def get_args(): parser = argparse.ArgumentParser() - parser.add_argument('--task', type=str, default='PongNoFrameskip-v4') - parser.add_argument('--seed', type=int, default=0) - parser.add_argument('--scale-obs', type=int, default=0) - parser.add_argument('--eps-test', type=float, default=0.005) - parser.add_argument('--eps-train', type=float, default=1.) - parser.add_argument('--eps-train-final', type=float, default=0.05) - parser.add_argument('--buffer-size', type=int, default=100000) - parser.add_argument('--lr', type=float, default=0.0001) - parser.add_argument('--gamma', type=float, default=0.99) - parser.add_argument('--num-atoms', type=int, default=51) - parser.add_argument('--v-min', type=float, default=-10.) - parser.add_argument('--v-max', type=float, default=10.) - parser.add_argument('--n-step', type=int, default=3) - parser.add_argument('--target-update-freq', type=int, default=500) - parser.add_argument('--epoch', type=int, default=100) - parser.add_argument('--step-per-epoch', type=int, default=100000) - parser.add_argument('--step-per-collect', type=int, default=10) - parser.add_argument('--update-per-step', type=float, default=0.1) - parser.add_argument('--batch-size', type=int, default=32) - parser.add_argument('--training-num', type=int, default=10) - parser.add_argument('--test-num', type=int, default=10) - parser.add_argument('--logdir', type=str, default='log') - parser.add_argument('--render', type=float, default=0.)
+ parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--scale-obs", type=int, default=0) + parser.add_argument("--eps-test", type=float, default=0.005) + parser.add_argument("--eps-train", type=float, default=1.) + parser.add_argument("--eps-train-final", type=float, default=0.05) + parser.add_argument("--buffer-size", type=int, default=100000) + parser.add_argument("--lr", type=float, default=0.0001) + parser.add_argument("--gamma", type=float, default=0.99) + parser.add_argument("--num-atoms", type=int, default=51) + parser.add_argument("--v-min", type=float, default=-10.) + parser.add_argument("--v-max", type=float, default=10.) + parser.add_argument("--n-step", type=int, default=3) + parser.add_argument("--target-update-freq", type=int, default=500) + parser.add_argument("--epoch", type=int, default=100) + parser.add_argument("--step-per-epoch", type=int, default=100000) + parser.add_argument("--step-per-collect", type=int, default=10) + parser.add_argument("--update-per-step", type=float, default=0.1) + parser.add_argument("--batch-size", type=int, default=32) + parser.add_argument("--training-num", type=int, default=10) + parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--logdir", type=str, default="log") + parser.add_argument("--render", type=float, default=0.) parser.add_argument( - '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu" ) - parser.add_argument('--frames-stack', type=int, default=4) - parser.add_argument('--resume-path', type=str, default=None) + parser.add_argument("--frames-stack", type=int, default=4) + parser.add_argument("--resume-path", type=str, default=None) + parser.add_argument("--resume-id", type=str, default=None) parser.add_argument( - '--watch', + "--logger", + type=str, + default="tensorboard", + choices=["tensorboard", "wandb"], + ) + parser.add_argument("--wandb-project", type=str, default="atari.benchmark") + parser.add_argument( + "--watch", default=False, - action='store_true', - help='watch the play of pre-trained policy only' + action="store_true", + help="watch the play of pre-trained policy only" ) - parser.add_argument('--save-buffer-name', type=str, default=None) + parser.add_argument("--save-buffer-name", type=str, default=None) return parser.parse_args() @@ -101,19 +110,36 @@ def test_c51(args=get_args()): # collector train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) test_collector = Collector(policy, test_envs, exploration_noise=True) + # log - log_path = os.path.join(args.logdir, args.task, 'c51') + now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") + args.algo_name = "c51" + log_name = os.path.join(args.task, args.algo_name, str(args.seed), now) + log_path = os.path.join(args.logdir, log_name) + + # logger + if args.logger == "wandb": + logger = WandbLogger( + save_interval=1, + name=log_name.replace(os.path.sep, "__"), + run_id=args.resume_id, + config=args, + project=args.wandb_project, + ) writer = SummaryWriter(log_path) writer.add_text("args", str(args)) - logger = TensorboardLogger(writer) + if args.logger == "tensorboard": + logger = TensorboardLogger(writer) + else: # wandb + logger.load(writer) def save_fn(policy): - torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) + torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def 
stop_fn(mean_rewards): if env.spec.reward_threshold: return mean_rewards >= env.spec.reward_threshold - elif 'Pong' in args.task: + elif "Pong" in args.task: return mean_rewards >= 20 else: return False @@ -159,7 +185,7 @@ def watch(): n_episode=args.test_num, render=args.render ) rew = result["rews"].mean() - print(f'Mean reward (over {result["n/ep"]} episodes): {rew}') + print(f"Mean reward (over {result['n/ep']} episodes): {rew}") if args.watch: watch() @@ -190,5 +216,5 @@ def watch(): watch() -if __name__ == '__main__': +if __name__ == "__main__": test_c51(get_args()) diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index 4bae24baa..c5aacd2b7 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -1,4 +1,5 @@ import argparse +import datetime import os import pprint @@ -18,62 +19,63 @@ def get_args(): parser = argparse.ArgumentParser() - parser.add_argument('--task', type=str, default='PongNoFrameskip-v4') - parser.add_argument('--seed', type=int, default=0) - parser.add_argument('--scale-obs', type=int, default=0) - parser.add_argument('--eps-test', type=float, default=0.005) - parser.add_argument('--eps-train', type=float, default=1.) - parser.add_argument('--eps-train-final', type=float, default=0.05) - parser.add_argument('--buffer-size', type=int, default=100000) - parser.add_argument('--lr', type=float, default=0.0001) - parser.add_argument('--gamma', type=float, default=0.99) - parser.add_argument('--n-step', type=int, default=3) - parser.add_argument('--target-update-freq', type=int, default=500) - parser.add_argument('--epoch', type=int, default=100) - parser.add_argument('--step-per-epoch', type=int, default=100000) - parser.add_argument('--step-per-collect', type=int, default=10) - parser.add_argument('--update-per-step', type=float, default=0.1) - parser.add_argument('--batch-size', type=int, default=32) - parser.add_argument('--training-num', type=int, default=10) - parser.add_argument('--test-num', type=int, default=10) - parser.add_argument('--logdir', type=str, default='log') - parser.add_argument('--render', type=float, default=0.) + parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--scale-obs", type=int, default=0) + parser.add_argument("--eps-test", type=float, default=0.005) + parser.add_argument("--eps-train", type=float, default=1.) + parser.add_argument("--eps-train-final", type=float, default=0.05) + parser.add_argument("--buffer-size", type=int, default=100000) + parser.add_argument("--lr", type=float, default=0.0001) + parser.add_argument("--gamma", type=float, default=0.99) + parser.add_argument("--n-step", type=int, default=3) + parser.add_argument("--target-update-freq", type=int, default=500) + parser.add_argument("--epoch", type=int, default=100) + parser.add_argument("--step-per-epoch", type=int, default=100000) + parser.add_argument("--step-per-collect", type=int, default=10) + parser.add_argument("--update-per-step", type=float, default=0.1) + parser.add_argument("--batch-size", type=int, default=32) + parser.add_argument("--training-num", type=int, default=10) + parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--logdir", type=str, default="log") + parser.add_argument("--render", type=float, default=0.) 
parser.add_argument( - '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu" ) - parser.add_argument('--frames-stack', type=int, default=4) - parser.add_argument('--resume-path', type=str, default=None) - parser.add_argument('--resume-id', type=str, default=None) + parser.add_argument("--frames-stack", type=int, default=4) + parser.add_argument("--resume-path", type=str, default=None) + parser.add_argument("--resume-id", type=str, default=None) parser.add_argument( - '--logger', + "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) + parser.add_argument("--wandb-project", type=str, default="atari.benchmark") parser.add_argument( - '--watch', + "--watch", default=False, - action='store_true', - help='watch the play of pre-trained policy only' + action="store_true", + help="watch the play of pre-trained policy only" ) - parser.add_argument('--save-buffer-name', type=str, default=None) + parser.add_argument("--save-buffer-name", type=str, default=None) parser.add_argument( - '--icm-lr-scale', + "--icm-lr-scale", type=float, default=0., - help='use intrinsic curiosity module with this lr scale' + help="use intrinsic curiosity module with this lr scale" ) parser.add_argument( - '--icm-reward-scale', + "--icm-reward-scale", type=float, default=0.01, - help='scaling factor for intrinsic curiosity reward' + help="scaling factor for intrinsic curiosity reward" ) parser.add_argument( - '--icm-forward-loss-weight', + "--icm-forward-loss-weight", type=float, default=0.2, - help='weight for the forward model loss in ICM' + help="weight for the forward model loss in ICM" ) return parser.parse_args() @@ -140,29 +142,36 @@ def test_dqn(args=get_args()): # collector train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) test_collector = Collector(policy, test_envs, exploration_noise=True) + # log - log_name = 'dqn_icm' if args.icm_lr_scale > 0 else 'dqn' - log_path = os.path.join(args.logdir, args.task, log_name) - if args.logger == "tensorboard": - writer = SummaryWriter(log_path) - writer.add_text("args", str(args)) - logger = TensorboardLogger(writer) - else: + now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") + args.algo_name = "dqn_icm" if args.icm_lr_scale > 0 else "dqn" + log_name = os.path.join(args.task, args.algo_name, str(args.seed), now) + log_path = os.path.join(args.logdir, log_name) + + # logger + if args.logger == "wandb": logger = WandbLogger( save_interval=1, - project=args.task, - name=log_name, + name=log_name.replace(os.path.sep, "__"), run_id=args.resume_id, config=args, + project=args.wandb_project, ) + writer = SummaryWriter(log_path) + writer.add_text("args", str(args)) + if args.logger == "tensorboard": + logger = TensorboardLogger(writer) + else: # wandb + logger.load(writer) def save_fn(policy): - torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) + torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards): if env.spec.reward_threshold: return mean_rewards >= env.spec.reward_threshold - elif 'Pong' in args.task: + elif "Pong" in args.task: return mean_rewards >= 20 else: return False @@ -183,8 +192,8 @@ def test_fn(epoch, env_step): def save_checkpoint_fn(epoch, env_step, gradient_step): # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html - ckpt_path = os.path.join(log_path, 'checkpoint.pth') - torch.save({'model': policy.state_dict()}, 
ckpt_path) + ckpt_path = os.path.join(log_path, "checkpoint.pth") + torch.save({"model": policy.state_dict()}, ckpt_path) return ckpt_path # watch agent's performance @@ -214,7 +223,7 @@ def watch(): n_episode=args.test_num, render=args.render ) rew = result["rews"].mean() - print(f'Mean reward (over {result["n/ep"]} episodes): {rew}') + print(f"Mean reward (over {result['n/ep']} episodes): {rew}") if args.watch: watch() @@ -247,5 +256,5 @@ def watch(): watch() -if __name__ == '__main__': +if __name__ == "__main__": test_dqn(get_args()) diff --git a/examples/atari/atari_fqf.py b/examples/atari/atari_fqf.py index 4a8133556..c4c1b36c8 100644 --- a/examples/atari/atari_fqf.py +++ b/examples/atari/atari_fqf.py @@ -1,4 +1,5 @@ import argparse +import datetime import os import pprint @@ -11,49 +12,57 @@ from tianshou.data import Collector, VectorReplayBuffer from tianshou.policy import FQFPolicy from tianshou.trainer import offpolicy_trainer -from tianshou.utils import TensorboardLogger +from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction def get_args(): parser = argparse.ArgumentParser() - parser.add_argument('--task', type=str, default='PongNoFrameskip-v4') - parser.add_argument('--seed', type=int, default=3128) - parser.add_argument('--scale-obs', type=int, default=0) - parser.add_argument('--eps-test', type=float, default=0.005) - parser.add_argument('--eps-train', type=float, default=1.) - parser.add_argument('--eps-train-final', type=float, default=0.05) - parser.add_argument('--buffer-size', type=int, default=100000) - parser.add_argument('--lr', type=float, default=5e-5) - parser.add_argument('--fraction-lr', type=float, default=2.5e-9) - parser.add_argument('--gamma', type=float, default=0.99) - parser.add_argument('--num-fractions', type=int, default=32) - parser.add_argument('--num-cosines', type=int, default=64) - parser.add_argument('--ent-coef', type=float, default=10.) - parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[512]) - parser.add_argument('--n-step', type=int, default=3) - parser.add_argument('--target-update-freq', type=int, default=500) - parser.add_argument('--epoch', type=int, default=100) - parser.add_argument('--step-per-epoch', type=int, default=100000) - parser.add_argument('--step-per-collect', type=int, default=10) - parser.add_argument('--update-per-step', type=float, default=0.1) - parser.add_argument('--batch-size', type=int, default=32) - parser.add_argument('--training-num', type=int, default=10) - parser.add_argument('--test-num', type=int, default=10) - parser.add_argument('--logdir', type=str, default='log') - parser.add_argument('--render', type=float, default=0.) + parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") + parser.add_argument("--seed", type=int, default=3128) + parser.add_argument("--scale-obs", type=int, default=0) + parser.add_argument("--eps-test", type=float, default=0.005) + parser.add_argument("--eps-train", type=float, default=1.) 
+ parser.add_argument("--eps-train-final", type=float, default=0.05) + parser.add_argument("--buffer-size", type=int, default=100000) + parser.add_argument("--lr", type=float, default=5e-5) + parser.add_argument("--fraction-lr", type=float, default=2.5e-9) + parser.add_argument("--gamma", type=float, default=0.99) + parser.add_argument("--num-fractions", type=int, default=32) + parser.add_argument("--num-cosines", type=int, default=64) + parser.add_argument("--ent-coef", type=float, default=10.) + parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[512]) + parser.add_argument("--n-step", type=int, default=3) + parser.add_argument("--target-update-freq", type=int, default=500) + parser.add_argument("--epoch", type=int, default=100) + parser.add_argument("--step-per-epoch", type=int, default=100000) + parser.add_argument("--step-per-collect", type=int, default=10) + parser.add_argument("--update-per-step", type=float, default=0.1) + parser.add_argument("--batch-size", type=int, default=32) + parser.add_argument("--training-num", type=int, default=10) + parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--logdir", type=str, default="log") + parser.add_argument("--render", type=float, default=0.) parser.add_argument( - '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu" ) - parser.add_argument('--frames-stack', type=int, default=4) - parser.add_argument('--resume-path', type=str, default=None) + parser.add_argument("--frames-stack", type=int, default=4) + parser.add_argument("--resume-path", type=str, default=None) + parser.add_argument("--resume-id", type=str, default=None) parser.add_argument( - '--watch', + "--logger", + type=str, + default="tensorboard", + choices=["tensorboard", "wandb"], + ) + parser.add_argument("--wandb-project", type=str, default="atari.benchmark") + parser.add_argument( + "--watch", default=False, - action='store_true', - help='watch the play of pre-trained policy only' + action="store_true", + help="watch the play of pre-trained policy only" ) - parser.add_argument('--save-buffer-name', type=str, default=None) + parser.add_argument("--save-buffer-name", type=str, default=None) return parser.parse_args() @@ -118,19 +127,36 @@ def test_fqf(args=get_args()): # collector train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) test_collector = Collector(policy, test_envs, exploration_noise=True) + # log - log_path = os.path.join(args.logdir, args.task, 'fqf') + now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") + args.algo_name = "fqf" + log_name = os.path.join(args.task, args.algo_name, str(args.seed), now) + log_path = os.path.join(args.logdir, log_name) + + # logger + if args.logger == "wandb": + logger = WandbLogger( + save_interval=1, + name=log_name.replace(os.path.sep, "__"), + run_id=args.resume_id, + config=args, + project=args.wandb_project, + ) writer = SummaryWriter(log_path) writer.add_text("args", str(args)) - logger = TensorboardLogger(writer) + if args.logger == "tensorboard": + logger = TensorboardLogger(writer) + else: # wandb + logger.load(writer) def save_fn(policy): - torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) + torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards): if env.spec.reward_threshold: return mean_rewards >= env.spec.reward_threshold - elif 'Pong' in args.task: + elif "Pong" in args.task: return 
mean_rewards >= 20 else: return False @@ -176,7 +202,7 @@ def watch(): n_episode=args.test_num, render=args.render ) rew = result["rews"].mean() - print(f'Mean reward (over {result["n/ep"]} episodes): {rew}') + print(f"Mean reward (over {result['n/ep']} episodes): {rew}") if args.watch: watch() @@ -207,5 +233,5 @@ def watch(): watch() -if __name__ == '__main__': +if __name__ == "__main__": test_fqf(get_args()) diff --git a/examples/atari/atari_iqn.py b/examples/atari/atari_iqn.py index 0ae1d94a7..49eb6e9d5 100644 --- a/examples/atari/atari_iqn.py +++ b/examples/atari/atari_iqn.py @@ -1,4 +1,5 @@ import argparse +import datetime import os import pprint @@ -11,49 +12,57 @@ from tianshou.data import Collector, VectorReplayBuffer from tianshou.policy import IQNPolicy from tianshou.trainer import offpolicy_trainer -from tianshou.utils import TensorboardLogger +from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.discrete import ImplicitQuantileNetwork def get_args(): parser = argparse.ArgumentParser() - parser.add_argument('--task', type=str, default='PongNoFrameskip-v4') - parser.add_argument('--seed', type=int, default=1234) - parser.add_argument('--scale-obs', type=int, default=0) - parser.add_argument('--eps-test', type=float, default=0.005) - parser.add_argument('--eps-train', type=float, default=1.) - parser.add_argument('--eps-train-final', type=float, default=0.05) - parser.add_argument('--buffer-size', type=int, default=100000) - parser.add_argument('--lr', type=float, default=0.0001) - parser.add_argument('--gamma', type=float, default=0.99) - parser.add_argument('--sample-size', type=int, default=32) - parser.add_argument('--online-sample-size', type=int, default=8) - parser.add_argument('--target-sample-size', type=int, default=8) - parser.add_argument('--num-cosines', type=int, default=64) - parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[512]) - parser.add_argument('--n-step', type=int, default=3) - parser.add_argument('--target-update-freq', type=int, default=500) - parser.add_argument('--epoch', type=int, default=100) - parser.add_argument('--step-per-epoch', type=int, default=100000) - parser.add_argument('--step-per-collect', type=int, default=10) - parser.add_argument('--update-per-step', type=float, default=0.1) - parser.add_argument('--batch-size', type=int, default=32) - parser.add_argument('--training-num', type=int, default=10) - parser.add_argument('--test-num', type=int, default=10) - parser.add_argument('--logdir', type=str, default='log') - parser.add_argument('--render', type=float, default=0.) + parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") + parser.add_argument("--seed", type=int, default=1234) + parser.add_argument("--scale-obs", type=int, default=0) + parser.add_argument("--eps-test", type=float, default=0.005) + parser.add_argument("--eps-train", type=float, default=1.) 
+ parser.add_argument("--eps-train-final", type=float, default=0.05) + parser.add_argument("--buffer-size", type=int, default=100000) + parser.add_argument("--lr", type=float, default=0.0001) + parser.add_argument("--gamma", type=float, default=0.99) + parser.add_argument("--sample-size", type=int, default=32) + parser.add_argument("--online-sample-size", type=int, default=8) + parser.add_argument("--target-sample-size", type=int, default=8) + parser.add_argument("--num-cosines", type=int, default=64) + parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[512]) + parser.add_argument("--n-step", type=int, default=3) + parser.add_argument("--target-update-freq", type=int, default=500) + parser.add_argument("--epoch", type=int, default=100) + parser.add_argument("--step-per-epoch", type=int, default=100000) + parser.add_argument("--step-per-collect", type=int, default=10) + parser.add_argument("--update-per-step", type=float, default=0.1) + parser.add_argument("--batch-size", type=int, default=32) + parser.add_argument("--training-num", type=int, default=10) + parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--logdir", type=str, default="log") + parser.add_argument("--render", type=float, default=0.) parser.add_argument( - '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu" ) - parser.add_argument('--frames-stack', type=int, default=4) - parser.add_argument('--resume-path', type=str, default=None) + parser.add_argument("--frames-stack", type=int, default=4) + parser.add_argument("--resume-path", type=str, default=None) + parser.add_argument("--resume-id", type=str, default=None) parser.add_argument( - '--watch', + "--logger", + type=str, + default="tensorboard", + choices=["tensorboard", "wandb"], + ) + parser.add_argument("--wandb-project", type=str, default="atari.benchmark") + parser.add_argument( + "--watch", default=False, - action='store_true', - help='watch the play of pre-trained policy only' + action="store_true", + help="watch the play of pre-trained policy only" ) - parser.add_argument('--save-buffer-name', type=str, default=None) + parser.add_argument("--save-buffer-name", type=str, default=None) return parser.parse_args() @@ -113,19 +122,36 @@ def test_iqn(args=get_args()): # collector train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) test_collector = Collector(policy, test_envs, exploration_noise=True) + # log - log_path = os.path.join(args.logdir, args.task, 'iqn') + now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") + args.algo_name = "iqn" + log_name = os.path.join(args.task, args.algo_name, str(args.seed), now) + log_path = os.path.join(args.logdir, log_name) + + # logger + if args.logger == "wandb": + logger = WandbLogger( + save_interval=1, + name=log_name.replace(os.path.sep, "__"), + run_id=args.resume_id, + config=args, + project=args.wandb_project, + ) writer = SummaryWriter(log_path) writer.add_text("args", str(args)) - logger = TensorboardLogger(writer) + if args.logger == "tensorboard": + logger = TensorboardLogger(writer) + else: # wandb + logger.load(writer) def save_fn(policy): - torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) + torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards): if env.spec.reward_threshold: return mean_rewards >= env.spec.reward_threshold - elif 'Pong' in args.task: + elif "Pong" in args.task: 
return mean_rewards >= 20 else: return False @@ -171,7 +197,7 @@ def watch(): n_episode=args.test_num, render=args.render ) rew = result["rews"].mean() - print(f'Mean reward (over {result["n/ep"]} episodes): {rew}') + print(f"Mean reward (over {result['n/ep']} episodes): {rew}") if args.watch: watch() @@ -202,5 +228,5 @@ def watch(): watch() -if __name__ == '__main__': +if __name__ == "__main__": test_iqn(get_args()) diff --git a/examples/atari/atari_ppo.py b/examples/atari/atari_ppo.py index 94e2c3f47..8c70bf66e 100644 --- a/examples/atari/atari_ppo.py +++ b/examples/atari/atari_ppo.py @@ -1,4 +1,5 @@ import argparse +import datetime import os import pprint @@ -19,69 +20,70 @@ def get_args(): parser = argparse.ArgumentParser() - parser.add_argument('--task', type=str, default='PongNoFrameskip-v4') - parser.add_argument('--seed', type=int, default=4213) - parser.add_argument('--scale-obs', type=int, default=0) - parser.add_argument('--buffer-size', type=int, default=100000) - parser.add_argument('--lr', type=float, default=5e-5) - parser.add_argument('--gamma', type=float, default=0.99) - parser.add_argument('--epoch', type=int, default=100) - parser.add_argument('--step-per-epoch', type=int, default=100000) - parser.add_argument('--step-per-collect', type=int, default=1000) - parser.add_argument('--repeat-per-collect', type=int, default=4) - parser.add_argument('--batch-size', type=int, default=256) - parser.add_argument('--hidden-size', type=int, default=512) - parser.add_argument('--training-num', type=int, default=10) - parser.add_argument('--test-num', type=int, default=10) - parser.add_argument('--rew-norm', type=int, default=False) - parser.add_argument('--vf-coef', type=float, default=0.5) - parser.add_argument('--ent-coef', type=float, default=0.01) - parser.add_argument('--gae-lambda', type=float, default=0.95) - parser.add_argument('--lr-decay', type=int, default=True) - parser.add_argument('--max-grad-norm', type=float, default=0.5) - parser.add_argument('--eps-clip', type=float, default=0.2) - parser.add_argument('--dual-clip', type=float, default=None) - parser.add_argument('--value-clip', type=int, default=0) - parser.add_argument('--norm-adv', type=int, default=1) - parser.add_argument('--recompute-adv', type=int, default=0) - parser.add_argument('--logdir', type=str, default='log') - parser.add_argument('--render', type=float, default=0.) 
+ parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") + parser.add_argument("--seed", type=int, default=4213) + parser.add_argument("--scale-obs", type=int, default=0) + parser.add_argument("--buffer-size", type=int, default=100000) + parser.add_argument("--lr", type=float, default=5e-5) + parser.add_argument("--gamma", type=float, default=0.99) + parser.add_argument("--epoch", type=int, default=100) + parser.add_argument("--step-per-epoch", type=int, default=100000) + parser.add_argument("--step-per-collect", type=int, default=1000) + parser.add_argument("--repeat-per-collect", type=int, default=4) + parser.add_argument("--batch-size", type=int, default=256) + parser.add_argument("--hidden-size", type=int, default=512) + parser.add_argument("--training-num", type=int, default=10) + parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--rew-norm", type=int, default=False) + parser.add_argument("--vf-coef", type=float, default=0.5) + parser.add_argument("--ent-coef", type=float, default=0.01) + parser.add_argument("--gae-lambda", type=float, default=0.95) + parser.add_argument("--lr-decay", type=int, default=True) + parser.add_argument("--max-grad-norm", type=float, default=0.5) + parser.add_argument("--eps-clip", type=float, default=0.2) + parser.add_argument("--dual-clip", type=float, default=None) + parser.add_argument("--value-clip", type=int, default=0) + parser.add_argument("--norm-adv", type=int, default=1) + parser.add_argument("--recompute-adv", type=int, default=0) + parser.add_argument("--logdir", type=str, default="log") + parser.add_argument("--render", type=float, default=0.) parser.add_argument( - '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu" ) - parser.add_argument('--frames-stack', type=int, default=4) - parser.add_argument('--resume-path', type=str, default=None) - parser.add_argument('--resume-id', type=str, default=None) + parser.add_argument("--frames-stack", type=int, default=4) + parser.add_argument("--resume-path", type=str, default=None) + parser.add_argument("--resume-id", type=str, default=None) parser.add_argument( - '--logger', + "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) + parser.add_argument("--wandb-project", type=str, default="atari.benchmark") parser.add_argument( - '--watch', + "--watch", default=False, - action='store_true', - help='watch the play of pre-trained policy only' + action="store_true", + help="watch the play of pre-trained policy only" ) - parser.add_argument('--save-buffer-name', type=str, default=None) + parser.add_argument("--save-buffer-name", type=str, default=None) parser.add_argument( - '--icm-lr-scale', + "--icm-lr-scale", type=float, default=0., - help='use intrinsic curiosity module with this lr scale' + help="use intrinsic curiosity module with this lr scale" ) parser.add_argument( - '--icm-reward-scale', + "--icm-reward-scale", type=float, default=0.01, - help='scaling factor for intrinsic curiosity reward' + help="scaling factor for intrinsic curiosity reward" ) parser.add_argument( - '--icm-forward-loss-weight', + "--icm-forward-loss-weight", type=float, default=0.2, - help='weight for the forward model loss in ICM' + help="weight for the forward model loss in ICM" ) return parser.parse_args() @@ -184,37 +186,44 @@ def dist(p): # collector train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) test_collector = 
Collector(policy, test_envs, exploration_noise=True) + # log - log_name = 'ppo_icm' if args.icm_lr_scale > 0 else 'ppo' - log_path = os.path.join(args.logdir, args.task, log_name) - if args.logger == "tensorboard": - writer = SummaryWriter(log_path) - writer.add_text("args", str(args)) - logger = TensorboardLogger(writer) - else: + now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") + args.algo_name = "ppo_icm" if args.icm_lr_scale > 0 else "ppo" + log_name = os.path.join(args.task, args.algo_name, str(args.seed), now) + log_path = os.path.join(args.logdir, log_name) + + # logger + if args.logger == "wandb": logger = WandbLogger( save_interval=1, - project=args.task, - name=log_name, + name=log_name.replace(os.path.sep, "__"), run_id=args.resume_id, config=args, + project=args.wandb_project, ) + writer = SummaryWriter(log_path) + writer.add_text("args", str(args)) + if args.logger == "tensorboard": + logger = TensorboardLogger(writer) + else: # wandb + logger.load(writer) def save_fn(policy): - torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) + torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards): if env.spec.reward_threshold: return mean_rewards >= env.spec.reward_threshold - elif 'Pong' in args.task: + elif "Pong" in args.task: return mean_rewards >= 20 else: return False def save_checkpoint_fn(epoch, env_step, gradient_step): # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html - ckpt_path = os.path.join(log_path, 'checkpoint.pth') - torch.save({'model': policy.state_dict()}, ckpt_path) + ckpt_path = os.path.join(log_path, "checkpoint.pth") + torch.save({"model": policy.state_dict()}, ckpt_path) return ckpt_path # watch agent's performance @@ -243,7 +252,7 @@ def watch(): n_episode=args.test_num, render=args.render ) rew = result["rews"].mean() - print(f'Mean reward (over {result["n/ep"]} episodes): {rew}') + print(f"Mean reward (over {result['n/ep']} episodes): {rew}") if args.watch: watch() @@ -274,5 +283,5 @@ def watch(): watch() -if __name__ == '__main__': +if __name__ == "__main__": test_ppo(get_args()) diff --git a/examples/atari/atari_qrdqn.py b/examples/atari/atari_qrdqn.py index 1ae4a8c88..06f1cbb56 100644 --- a/examples/atari/atari_qrdqn.py +++ b/examples/atari/atari_qrdqn.py @@ -1,4 +1,5 @@ import argparse +import datetime import os import pprint @@ -6,7 +7,7 @@ import torch from atari_network import QRDQN from atari_wrapper import make_atari_env from torch.utils.tensorboard import SummaryWriter from tianshou.data import Collector, VectorReplayBuffer from tianshou.policy import QRDQNPolicy from tianshou.trainer import offpolicy_trainer -from tianshou.utils import TensorboardLogger +from tianshou.utils import TensorboardLogger, WandbLogger @@ -16,39 +17,47 @@ def get_args(): parser = argparse.ArgumentParser() - parser.add_argument('--task', type=str, default='PongNoFrameskip-v4') - parser.add_argument('--seed', type=int, default=0) - parser.add_argument('--scale-obs', type=int, default=0) - parser.add_argument('--eps-test', type=float, default=0.005) - parser.add_argument('--eps-train', type=float, default=1.)
- parser.add_argument('--eps-train-final', type=float, default=0.05) - parser.add_argument('--buffer-size', type=int, default=100000) - parser.add_argument('--lr', type=float, default=0.0001) - parser.add_argument('--gamma', type=float, default=0.99) - parser.add_argument('--num-quantiles', type=int, default=200) - parser.add_argument('--n-step', type=int, default=3) - parser.add_argument('--target-update-freq', type=int, default=500) - parser.add_argument('--epoch', type=int, default=100) - parser.add_argument('--step-per-epoch', type=int, default=100000) - parser.add_argument('--step-per-collect', type=int, default=10) - parser.add_argument('--update-per-step', type=float, default=0.1) - parser.add_argument('--batch-size', type=int, default=32) - parser.add_argument('--training-num', type=int, default=10) - parser.add_argument('--test-num', type=int, default=10) - parser.add_argument('--logdir', type=str, default='log') - parser.add_argument('--render', type=float, default=0.) + parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--scale-obs", type=int, default=0) + parser.add_argument("--eps-test", type=float, default=0.005) + parser.add_argument("--eps-train", type=float, default=1.) + parser.add_argument("--eps-train-final", type=float, default=0.05) + parser.add_argument("--buffer-size", type=int, default=100000) + parser.add_argument("--lr", type=float, default=0.0001) + parser.add_argument("--gamma", type=float, default=0.99) + parser.add_argument("--num-quantiles", type=int, default=200) + parser.add_argument("--n-step", type=int, default=3) + parser.add_argument("--target-update-freq", type=int, default=500) + parser.add_argument("--epoch", type=int, default=100) + parser.add_argument("--step-per-epoch", type=int, default=100000) + parser.add_argument("--step-per-collect", type=int, default=10) + parser.add_argument("--update-per-step", type=float, default=0.1) + parser.add_argument("--batch-size", type=int, default=32) + parser.add_argument("--training-num", type=int, default=10) + parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--logdir", type=str, default="log") + parser.add_argument("--render", type=float, default=0.) 
parser.add_argument( - '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu" ) - parser.add_argument('--frames-stack', type=int, default=4) - parser.add_argument('--resume-path', type=str, default=None) + parser.add_argument("--frames-stack", type=int, default=4) + parser.add_argument("--resume-path", type=str, default=None) + parser.add_argument("--resume-id", type=str, default=None) parser.add_argument( - '--watch', + "--logger", + type=str, + default="tensorboard", + choices=["tensorboard", "wandb"], + ) + parser.add_argument("--wandb-project", type=str, default="atari.benchmark") + parser.add_argument( + "--watch", default=False, - action='store_true', - help='watch the play of pre-trained policy only' + action="store_true", + help="watch the play of pre-trained policy only" ) - parser.add_argument('--save-buffer-name', type=str, default=None) + parser.add_argument("--save-buffer-name", type=str, default=None) return parser.parse_args() @@ -97,19 +106,36 @@ def test_qrdqn(args=get_args()): # collector train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) test_collector = Collector(policy, test_envs, exploration_noise=True) + # log - log_path = os.path.join(args.logdir, args.task, 'qrdqn') + now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") + args.algo_name = "qrdqn" + log_name = os.path.join(args.task, args.algo_name, str(args.seed), now) + log_path = os.path.join(args.logdir, log_name) + + # logger + if args.logger == "wandb": + logger = WandbLogger( + save_interval=1, + name=log_name.replace(os.path.sep, "__"), + run_id=args.resume_id, + config=args, + project=args.wandb_project, + ) writer = SummaryWriter(log_path) writer.add_text("args", str(args)) - logger = TensorboardLogger(writer) + if args.logger == "tensorboard": + logger = TensorboardLogger(writer) + else: # wandb + logger.load(writer) def save_fn(policy): - torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) + torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards): if env.spec.reward_threshold: return mean_rewards >= env.spec.reward_threshold - elif 'Pong' in args.task: + elif "Pong" in args.task: return mean_rewards >= 20 else: return False @@ -155,7 +181,7 @@ def watch(): n_episode=args.test_num, render=args.render ) rew = result["rews"].mean() - print(f'Mean reward (over {result["n/ep"]} episodes): {rew}') + print(f"Mean reward (over {result['n/ep']} episodes): {rew}") if args.watch: watch() @@ -186,5 +212,5 @@ def watch(): watch() -if __name__ == '__main__': +if __name__ == "__main__": test_qrdqn(get_args()) diff --git a/examples/atari/atari_rainbow.py b/examples/atari/atari_rainbow.py index 478a6d2b8..5109a1eeb 100644 --- a/examples/atari/atari_rainbow.py +++ b/examples/atari/atari_rainbow.py @@ -7,7 +7,7 @@ import torch from atari_network import Rainbow from atari_wrapper import make_atari_env from torch.utils.tensorboard import SummaryWriter from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer from tianshou.policy import RainbowPolicy from tianshou.trainer import offpolicy_trainer -from tianshou.utils import TensorboardLogger +from tianshou.utils import TensorboardLogger, WandbLogger @@ -17,50 +17,58 @@ def get_args(): parser = argparse.ArgumentParser() - parser.add_argument('--task', type=str, default='PongNoFrameskip-v4') - parser.add_argument('--seed', type=int, default=0) - parser.add_argument('--scale-obs', type=int, default=0) - parser.add_argument('--eps-test', type=float,
default=0.005) - parser.add_argument('--eps-train', type=float, default=1.) - parser.add_argument('--eps-train-final', type=float, default=0.05) - parser.add_argument('--buffer-size', type=int, default=100000) - parser.add_argument('--lr', type=float, default=0.0000625) - parser.add_argument('--gamma', type=float, default=0.99) - parser.add_argument('--num-atoms', type=int, default=51) - parser.add_argument('--v-min', type=float, default=-10.) - parser.add_argument('--v-max', type=float, default=10.) - parser.add_argument('--noisy-std', type=float, default=0.1) - parser.add_argument('--no-dueling', action='store_true', default=False) - parser.add_argument('--no-noisy', action='store_true', default=False) - parser.add_argument('--no-priority', action='store_true', default=False) - parser.add_argument('--alpha', type=float, default=0.5) - parser.add_argument('--beta', type=float, default=0.4) - parser.add_argument('--beta-final', type=float, default=1.) - parser.add_argument('--beta-anneal-step', type=int, default=5000000) - parser.add_argument('--no-weight-norm', action='store_true', default=False) - parser.add_argument('--n-step', type=int, default=3) - parser.add_argument('--target-update-freq', type=int, default=500) - parser.add_argument('--epoch', type=int, default=100) - parser.add_argument('--step-per-epoch', type=int, default=100000) - parser.add_argument('--step-per-collect', type=int, default=10) - parser.add_argument('--update-per-step', type=float, default=0.1) - parser.add_argument('--batch-size', type=int, default=32) - parser.add_argument('--training-num', type=int, default=10) - parser.add_argument('--test-num', type=int, default=10) - parser.add_argument('--logdir', type=str, default='log') - parser.add_argument('--render', type=float, default=0.) + parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--scale-obs", type=int, default=0) + parser.add_argument("--eps-test", type=float, default=0.005) + parser.add_argument("--eps-train", type=float, default=1.) + parser.add_argument("--eps-train-final", type=float, default=0.05) + parser.add_argument("--buffer-size", type=int, default=100000) + parser.add_argument("--lr", type=float, default=0.0000625) + parser.add_argument("--gamma", type=float, default=0.99) + parser.add_argument("--num-atoms", type=int, default=51) + parser.add_argument("--v-min", type=float, default=-10.) + parser.add_argument("--v-max", type=float, default=10.) + parser.add_argument("--noisy-std", type=float, default=0.1) + parser.add_argument("--no-dueling", action="store_true", default=False) + parser.add_argument("--no-noisy", action="store_true", default=False) + parser.add_argument("--no-priority", action="store_true", default=False) + parser.add_argument("--alpha", type=float, default=0.5) + parser.add_argument("--beta", type=float, default=0.4) + parser.add_argument("--beta-final", type=float, default=1.) 
+ parser.add_argument("--beta-anneal-step", type=int, default=5000000) + parser.add_argument("--no-weight-norm", action="store_true", default=False) + parser.add_argument("--n-step", type=int, default=3) + parser.add_argument("--target-update-freq", type=int, default=500) + parser.add_argument("--epoch", type=int, default=100) + parser.add_argument("--step-per-epoch", type=int, default=100000) + parser.add_argument("--step-per-collect", type=int, default=10) + parser.add_argument("--update-per-step", type=float, default=0.1) + parser.add_argument("--batch-size", type=int, default=32) + parser.add_argument("--training-num", type=int, default=10) + parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--logdir", type=str, default="log") + parser.add_argument("--render", type=float, default=0.) parser.add_argument( - '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu" ) - parser.add_argument('--frames-stack', type=int, default=4) - parser.add_argument('--resume-path', type=str, default=None) + parser.add_argument("--frames-stack", type=int, default=4) + parser.add_argument("--resume-path", type=str, default=None) + parser.add_argument("--resume-id", type=str, default=None) parser.add_argument( - '--watch', + "--logger", + type=str, + default="tensorboard", + choices=["tensorboard", "wandb"], + ) + parser.add_argument("--wandb-project", type=str, default="atari.benchmark") + parser.add_argument( + "--watch", default=False, - action='store_true', - help='watch the play of pre-trained policy only' + action="store_true", + help="watch the play of pre-trained policy only" ) - parser.add_argument('--save-buffer-name', type=str, default=None) + parser.add_argument("--save-buffer-name", type=str, default=None) return parser.parse_args() @@ -131,22 +139,36 @@ def test_rainbow(args=get_args()): # collector train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) test_collector = Collector(policy, test_envs, exploration_noise=True) + # log - log_path = os.path.join( - args.logdir, args.task, 'rainbow', - f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}' - ) + now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") + args.algo_name = "rainbow" + log_name = os.path.join(args.task, args.algo_name, str(args.seed), now) + log_path = os.path.join(args.logdir, log_name) + + # logger + if args.logger == "wandb": + logger = WandbLogger( + save_interval=1, + name=log_name.replace(os.path.sep, "__"), + run_id=args.resume_id, + config=args, + project=args.wandb_project, + ) writer = SummaryWriter(log_path) writer.add_text("args", str(args)) - logger = TensorboardLogger(writer) + if args.logger == "tensorboard": + logger = TensorboardLogger(writer) + else: # wandb + logger.load(writer) def save_fn(policy): - torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) + torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards): if env.spec.reward_threshold: return mean_rewards >= env.spec.reward_threshold - elif 'Pong' in args.task: + elif "Pong" in args.task: return mean_rewards >= 20 else: return False @@ -203,7 +225,7 @@ def watch(): n_episode=args.test_num, render=args.render ) rew = result["rews"].mean() - print(f'Mean reward (over {result["n/ep"]} episodes): {rew}') + print(f"Mean reward (over {result['n/ep']} episodes): {rew}") if args.watch: watch() @@ -234,5 +256,5 @@ def watch(): watch() -if 
__name__ == '__main__': +if __name__ == "__main__": test_rainbow(get_args()) diff --git a/test/modelbased/test_psrl.py b/test/modelbased/test_psrl.py index a20d3dd0c..3a3023005 100644 --- a/test/modelbased/test_psrl.py +++ b/test/modelbased/test_psrl.py @@ -79,11 +79,14 @@ def test_psrl(args=get_args()): logger = WandbLogger( save_interval=1, project='psrl', name='wandb_test', config=args ) - elif args.logger == "tensorboard": + if args.logger != "none": log_path = os.path.join(args.logdir, args.task, 'psrl') writer = SummaryWriter(log_path) writer.add_text("args", str(args)) - logger = TensorboardLogger(writer) + if args.logger == "tensorboard": + logger = TensorboardLogger(writer) + else: + logger.load(writer) else: logger = LazyLogger() diff --git a/tianshou/utils/logger/wandb.py b/tianshou/utils/logger/wandb.py index f9c047c59..32a89d2a0 100644 --- a/tianshou/utils/logger/wandb.py +++ b/tianshou/utils/logger/wandb.py @@ -2,7 +2,9 @@ import os from typing import Callable, Optional, Tuple -from tianshou.utils import BaseLogger +from torch.utils.tensorboard import SummaryWriter + +from tianshou.utils import BaseLogger, TensorboardLogger from tianshou.utils.logger.base import LOG_DATA_TYPE try: @@ -17,17 +19,13 @@ class WandbLogger(BaseLogger): This logger creates three panels with plots: train, test, and update. Make sure to select the correct access for each panel in weights and biases: - - ``train/env_step`` for train plots - - ``test/env_step`` for test plots - - ``update/gradient_step`` for update plots - Example of usage: :: - with wandb.init(project="My Project"): - logger = WandBLogger() - result = onpolicy_trainer(policy, train_collector, test_collector, - logger=logger) + logger = WandbLogger() + logger.load(SummaryWriter(log_path)) + result = onpolicy_trainer(policy, train_collector, test_collector, + logger=logger) :param int train_interval: the log interval in log_train_data(). Default to 1000. :param int test_interval: the log interval in log_test_data(). Default to 1. @@ -46,7 +44,7 @@ def __init__( test_interval: int = 1, update_interval: int = 1000, save_interval: int = 1000, - project: str = 'tianshou', + project: Optional[str] = None, name: Optional[str] = None, entity: Optional[str] = None, run_id: Optional[str] = None, @@ -56,6 +54,8 @@ def __init__( self.last_save_step = -1 self.save_interval = save_interval self.restored = False + if project is None: + project = os.getenv("WANDB_PROJECT", "tianshou") self.wandb_run = wandb.init( project=project, @@ -63,14 +63,25 @@ def __init__( id=run_id, resume="allow", entity=entity, + sync_tensorboard=True, monitor_gym=True, config=config, # type: ignore ) if not wandb.run else wandb.run self.wandb_run._label(repo="tianshou") # type: ignore + self.tensorboard_logger: Optional[TensorboardLogger] = None + + def load(self, writer: SummaryWriter) -> None: + self.writer = writer + self.tensorboard_logger = TensorboardLogger(writer) def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None: - data[step_type] = step - wandb.log(data) + if self.tensorboard_logger is None: + raise Exception( + "`logger` needs to load the Tensorboard Writer before " + "writing data. Try `logger.load(SummaryWriter(log_path))`" + ) + else: + self.tensorboard_logger.write(step_type, step, data) def save_data( self,