diff --git a/examples/mujoco/README.md b/examples/mujoco/README.md
index 3f116fd6b..0f1e7f9a2 100644
--- a/examples/mujoco/README.md
+++ b/examples/mujoco/README.md
@@ -1,27 +1,135 @@
-# Mujoco Result
+# Tianshou's Mujoco Benchmark
+We benchmarked Tianshou algorithm implementations in 9 out of 13 environments from the MuJoCo Gym task suite[[1]](#footnote1).
+For each supported algorithm and supported mujoco environments, we provide:
+- Default hyperparameters used for benchmark and scripts to reproduce the benchmark;
+- A comparison of performance (or code level details) with other open source implementations or classic papers;
+- Graphs and raw data that can be used for research purposes[[2]](#footnote2);
+- Log details obtained during training[[2]](#footnote2);
+- Pretrained agents[[2]](#footnote2);
+- Some hints on how to tune the algorithm.
+
-## SAC (single run)
+Supported algorithms are listed below:
+- [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/pdf/1509.02971.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/v0.4.0)
+- [Twin Delayed DDPG (TD3)](https://arxiv.org/pdf/1802.09477.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/v0.4.0)
+- [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/v0.4.0)
-The best reward computes from 100 episodes returns in the test phase.
+## Offpolicy algorithms
-SAC on Swimmer-v3 always stops at 47\~48.
+#### Usage
-| task | 3M best reward | parameters | time cost (3M) |
-| -------------- | ----------------- | ------------------------------------------------------- | -------------- |
-| HalfCheetah-v3 | 10157.70 ± 171.70 | `python3 mujoco_sac.py --task HalfCheetah-v3` | 2~3h |
-| Walker2d-v3 | 5143.04 ± 15.57 | `python3 mujoco_sac.py --task Walker2d-v3` | 2~3h |
-| Hopper-v3 | 3604.19 ± 169.55 | `python3 mujoco_sac.py --task Hopper-v3` | 2~3h |
-| Humanoid-v3 | 6579.20 ± 1470.57 | `python3 mujoco_sac.py --task Humanoid-v3 --alpha 0.05` | 2~3h |
-| Ant-v3 | 6281.65 ± 686.28 | `python3 mujoco_sac.py --task Ant-v3` | 2~3h |
+Run
-![](results/sac/all.png)
+```bash
+$ python mujoco_sac.py --task Ant-v3
+```
-### Which parts are important?
+Logs is saved in `./log/` and can be monitored with tensorboard.
+
+```bash
+$ tensorboard --logdir log
+```
+
+You can also reproduce the benchmark (e.g. SAC in Ant-v3) with the example script we provide under `examples/mujoco/`:
+
+```bash
+$ ./run_experiments.sh Ant-v3
+```
+
+This will start 10 experiments with different seeds.
+
+#### Example benchmark
+
+
+
+Other graphs can be found under `/examples/mujuco/benchmark/`
+
+#### Hints
+
+In offpolicy algorithms(DDPG, TD3, SAC), the shared hyperparameters are almost the same[[8]](#footnote8), and most hyperparameters are consistent with those used for benchmark in SpinningUp's implementations[[9]](#footnote9).
+
+By comparison to both classic literature and open source implementations (e.g., SpinningUp)[[1]](#footnote1)[[2]](#footnote2), Tianshou's implementations of DDPG, TD3, and SAC are roughly at-parity with or better than the best reported results for these algorithms.
+
+### DDPG
+
+| Environment | Tianshou | [SpinningUp (PyTorch)](https://spinningup.openai.com/en/latest/spinningup/bench.html) | [TD3 paper (DDPG)](https://arxiv.org/abs/1802.09477) | [TD3 paper (OurDDPG)](https://arxiv.org/abs/1802.09477) |
+| :--------------------: | :---------------: | :----------------------------------------------------------: | :--------------------------------------------------: | :-----------------------------------------------------: |
+| Ant | 990.4±4.3 | ~840 | **1005.3** | 888.8 |
+| HalfCheetah | **11718.7±465.6** | ~11000 | 3305.6 | 8577.3 |
+| Hopper | **2197.0±971.6** | ~1800 | **2020.5** | 1860.0 |
+| Walker2d | 1400.6±905.0 | ~1950 | 1843.6 | **3098.1** |
+| Swimmer | **144.1±6.5** | ~137 | N | N |
+| Humanoid | **177.3±77.6** | N | N | N |
+| Reacher | **-3.3±0.3** | N | -6.51 | -4.01 |
+| InvertedPendulum | **1000.0±0.0** | N | **1000.0** | **1000.0** |
+| InvertedDoublePendulum | 8364.3±2778.9 | N | **9355.5** | 8370.0 |
+
+\* details[[5]](#footnote5)[[6]](#footnote6)[[7]](#footnote7)
+
+### TD3
+
+| Environment | Tianshou | [SpinningUp (Pytorch)](https://spinningup.openai.com/en/latest/spinningup/bench.html) | [TD3 paper](https://arxiv.org/abs/1802.09477) |
+| :--------------------: | :---------------: | :-------------------: | :--------------: |
+| Ant | **5116.4±799.9** | ~3800 | 4372.4±1000.3 |
+| HalfCheetah | **10201.2±772.8** | ~9750 | 9637.0±859.1 |
+| Hopper | 3472.2±116.8 | ~2860 | **3564.1±114.7** |
+| Walker2d | 3982.4±274.5 | ~4000 | **4682.8±539.6** |
+| Swimmer | **104.2±34.2** | ~78 | N |
+| Humanoid | **5189.5±178.5** | N | N |
+| Reacher | **-2.7±0.2** | N | -3.6±0.6 |
+| InvertedPendulum | **1000.0±0.0** | N | **1000.0±0.0** |
+| InvertedDoublePendulum | **9349.2±14.3** | N | **9337.5±15.0** |
+
+\* details[[5]](#footnote5)[[6]](#footnote6)[[7]](#footnote7)
+
+### SAC
+
+| Environment | Tianshou | [SpinningUp (Pytorch)](https://spinningup.openai.com/en/latest/spinningup/bench.html) | [SAC paper](https://arxiv.org/abs/1801.01290) |
+| :--------------------: | :----------------: | :-------------------: | :---------: |
+| Ant | **5850.2±475.7** | ~3980 | ~3720 |
+| HalfCheetah | **12138.8±1049.3** | ~11520 | ~10400 |
+| Hopper | **3542.2±51.5** | ~3150 | ~3370 |
+| Walker2d | **5007.0±251.5** | ~4250 | ~3740 |
+| Swimmer | **44.4±0.5** | ~41.7 | N |
+| Humanoid | **5488.5±81.2** | N | ~5200 |
+| Reacher | **-2.6±0.2** | N | N |
+| InvertedPendulum | **1000.0±0.0** | N | N |
+| InvertedDoublePendulum | **9359.5±0.4** | N | N |
+
+\* details[[5]](#footnote5)[[6]](#footnote6)
+
+#### Hints for SAC
0. DO NOT share the same network with two critic networks.
-1. The sigma (of the Gaussian policy) MUST be conditioned on input.
+1. The sigma (of the Gaussian policy) should be conditioned on input.
2. The network size should not be less than 256.
3. The deterministic evaluation helps a lot :)
+## Onpolicy Algorithms
+
+TBD
+
+
+
+
+## Note
+
+[1] Supported environments include HalfCheetah-v3, Hopper-v3, Swimmer-v3, Walker2d-v3, Ant-v3, Humanoid-v3, Reacher-v2, InvertedPendulum-v2 and InvertedDoublePendulum-v2. Pusher, Thrower, Striker and HumanoidStandup are not supported because they are not commonly seen in literatures.
+
+[2] Pretrained agents, detailed graphs (single agent, single game) and log details can all be found [here](https://cloud.tsinghua.edu.cn/d/356e0f5d1e66426b9828/).
+
+[3] We used the latest version of all mujoco environments in gym (0.17.3 with mujoco==2.0.2.13), but it's not often the case with other benchmarks. Please check for details yourself in the original paper. (Different version's outcomes are usually similar, though)
+
+[4] We didn't compare offpolicy algorithms to OpenAI baselines [benchmark](https://github.com/openai/baselines/blob/master/benchmarks_mujoco1M.htm), because for now it seems that they haven't provided benchmark for offpolicy algorithms, but in [SpinningUp docs](https://spinningup.openai.com/en/latest/spinningup/bench.html) they stated that "SpinningUp implementations of DDPG, TD3, and SAC are roughly at-parity with the best-reported results for these algorithms", so we think lack of comparisons with OpenAI baselines is okay.
+
+[5] ~ means the number is approximated from the graph because accurate numbers is not provided in the paper. N means graphs not provided.
+
+[6] Reward metric: The meaning of the table value is the max average return over 10 trails (different seeds) ± a single standard deviation over trails. Each trial is averaged on another 10 test seeds. Only the first 1M steps data will be considered. The shaded region on the graph also represents a single standard deviation. It is the same as [TD3 evaluation method](https://github.com/sfujim/TD3/issues/34).
+
+[7] In TD3 paper, shaded region represents only half of standard deviation.
+
+[8] SAC's start-timesteps is set to 10000 by default while it is 25000 is DDPG/TD3. TD3's learning rate is set to 3e-4 while it is 1e-3 for DDPG/SAC. However, there is NO enough evidence to support our choice of such hyperparameters (we simply choose them because of SpinningUp) and you can try playing with those hyperparameters to see if you can improve performance. Do tell us if you can!
+
+[9] We use batchsize of 256 in DDPG/TD3/SAC while SpinningUp use 100. Minor difference also lies with `start-timesteps`, data loop method `step_per_collect`, method to deal with/bootstrap truncated steps because of timelimit and unfinished/collecting episodes (contribute to performance improvement), etc.
diff --git a/examples/mujoco/benchmark/Ant-v3/ddpg/figure.png b/examples/mujoco/benchmark/Ant-v3/ddpg/figure.png
new file mode 100644
index 000000000..a7323780d
Binary files /dev/null and b/examples/mujoco/benchmark/Ant-v3/ddpg/figure.png differ
diff --git a/examples/mujoco/benchmark/Ant-v3/figure.png b/examples/mujoco/benchmark/Ant-v3/figure.png
new file mode 100644
index 000000000..afb48c2bb
Binary files /dev/null and b/examples/mujoco/benchmark/Ant-v3/figure.png differ
diff --git a/examples/mujoco/benchmark/Ant-v3/sac/figure.png b/examples/mujoco/benchmark/Ant-v3/sac/figure.png
new file mode 100644
index 000000000..5d4d452e8
Binary files /dev/null and b/examples/mujoco/benchmark/Ant-v3/sac/figure.png differ
diff --git a/examples/mujoco/benchmark/Ant-v3/td3/figure.png b/examples/mujoco/benchmark/Ant-v3/td3/figure.png
new file mode 100644
index 000000000..3bf9043ba
Binary files /dev/null and b/examples/mujoco/benchmark/Ant-v3/td3/figure.png differ
diff --git a/examples/mujoco/benchmark/HalfCheetah-v3/ddpg/figure.png b/examples/mujoco/benchmark/HalfCheetah-v3/ddpg/figure.png
new file mode 100644
index 000000000..8df4fd841
Binary files /dev/null and b/examples/mujoco/benchmark/HalfCheetah-v3/ddpg/figure.png differ
diff --git a/examples/mujoco/benchmark/HalfCheetah-v3/figure.png b/examples/mujoco/benchmark/HalfCheetah-v3/figure.png
new file mode 100644
index 000000000..6459af89d
Binary files /dev/null and b/examples/mujoco/benchmark/HalfCheetah-v3/figure.png differ
diff --git a/examples/mujoco/benchmark/HalfCheetah-v3/sac/figure.png b/examples/mujoco/benchmark/HalfCheetah-v3/sac/figure.png
new file mode 100644
index 000000000..0e692f797
Binary files /dev/null and b/examples/mujoco/benchmark/HalfCheetah-v3/sac/figure.png differ
diff --git a/examples/mujoco/benchmark/HalfCheetah-v3/td3/figure.png b/examples/mujoco/benchmark/HalfCheetah-v3/td3/figure.png
new file mode 100644
index 000000000..76cb4bf9d
Binary files /dev/null and b/examples/mujoco/benchmark/HalfCheetah-v3/td3/figure.png differ
diff --git a/examples/mujoco/benchmark/Hopper-v3/ddpg/figure.png b/examples/mujoco/benchmark/Hopper-v3/ddpg/figure.png
new file mode 100644
index 000000000..2ed9d3a9d
Binary files /dev/null and b/examples/mujoco/benchmark/Hopper-v3/ddpg/figure.png differ
diff --git a/examples/mujoco/benchmark/Hopper-v3/figure.png b/examples/mujoco/benchmark/Hopper-v3/figure.png
new file mode 100644
index 000000000..0f41d5c99
Binary files /dev/null and b/examples/mujoco/benchmark/Hopper-v3/figure.png differ
diff --git a/examples/mujoco/benchmark/Hopper-v3/sac/figure.png b/examples/mujoco/benchmark/Hopper-v3/sac/figure.png
new file mode 100644
index 000000000..da077e640
Binary files /dev/null and b/examples/mujoco/benchmark/Hopper-v3/sac/figure.png differ
diff --git a/examples/mujoco/benchmark/Hopper-v3/td3/figure.png b/examples/mujoco/benchmark/Hopper-v3/td3/figure.png
new file mode 100644
index 000000000..62ccc221b
Binary files /dev/null and b/examples/mujoco/benchmark/Hopper-v3/td3/figure.png differ
diff --git a/examples/mujoco/benchmark/Humanoid-v3/ddpg/figure.png b/examples/mujoco/benchmark/Humanoid-v3/ddpg/figure.png
new file mode 100644
index 000000000..7a84f66c4
Binary files /dev/null and b/examples/mujoco/benchmark/Humanoid-v3/ddpg/figure.png differ
diff --git a/examples/mujoco/benchmark/Humanoid-v3/figure.png b/examples/mujoco/benchmark/Humanoid-v3/figure.png
new file mode 100644
index 000000000..3d788b8e9
Binary files /dev/null and b/examples/mujoco/benchmark/Humanoid-v3/figure.png differ
diff --git a/examples/mujoco/benchmark/Humanoid-v3/sac/figure.png b/examples/mujoco/benchmark/Humanoid-v3/sac/figure.png
new file mode 100644
index 000000000..b585f5996
Binary files /dev/null and b/examples/mujoco/benchmark/Humanoid-v3/sac/figure.png differ
diff --git a/examples/mujoco/benchmark/Humanoid-v3/td3/figure.png b/examples/mujoco/benchmark/Humanoid-v3/td3/figure.png
new file mode 100644
index 000000000..d919ceb46
Binary files /dev/null and b/examples/mujoco/benchmark/Humanoid-v3/td3/figure.png differ
diff --git a/examples/mujoco/benchmark/InvertedDoublePendulum-v2/ddpg/figure.png b/examples/mujoco/benchmark/InvertedDoublePendulum-v2/ddpg/figure.png
new file mode 100644
index 000000000..128d86e59
Binary files /dev/null and b/examples/mujoco/benchmark/InvertedDoublePendulum-v2/ddpg/figure.png differ
diff --git a/examples/mujoco/benchmark/InvertedDoublePendulum-v2/figure.png b/examples/mujoco/benchmark/InvertedDoublePendulum-v2/figure.png
new file mode 100644
index 000000000..ded29d538
Binary files /dev/null and b/examples/mujoco/benchmark/InvertedDoublePendulum-v2/figure.png differ
diff --git a/examples/mujoco/benchmark/InvertedDoublePendulum-v2/sac/figure.png b/examples/mujoco/benchmark/InvertedDoublePendulum-v2/sac/figure.png
new file mode 100644
index 000000000..fc23e5ce3
Binary files /dev/null and b/examples/mujoco/benchmark/InvertedDoublePendulum-v2/sac/figure.png differ
diff --git a/examples/mujoco/benchmark/InvertedDoublePendulum-v2/td3/figure.png b/examples/mujoco/benchmark/InvertedDoublePendulum-v2/td3/figure.png
new file mode 100644
index 000000000..71f0e92c3
Binary files /dev/null and b/examples/mujoco/benchmark/InvertedDoublePendulum-v2/td3/figure.png differ
diff --git a/examples/mujoco/benchmark/InvertedPendulum-v2/ddpg/figure.png b/examples/mujoco/benchmark/InvertedPendulum-v2/ddpg/figure.png
new file mode 100644
index 000000000..f0e33a76f
Binary files /dev/null and b/examples/mujoco/benchmark/InvertedPendulum-v2/ddpg/figure.png differ
diff --git a/examples/mujoco/benchmark/InvertedPendulum-v2/figure.png b/examples/mujoco/benchmark/InvertedPendulum-v2/figure.png
new file mode 100644
index 000000000..f5f1e71fb
Binary files /dev/null and b/examples/mujoco/benchmark/InvertedPendulum-v2/figure.png differ
diff --git a/examples/mujoco/benchmark/InvertedPendulum-v2/sac/figure.png b/examples/mujoco/benchmark/InvertedPendulum-v2/sac/figure.png
new file mode 100644
index 000000000..11d8ada9f
Binary files /dev/null and b/examples/mujoco/benchmark/InvertedPendulum-v2/sac/figure.png differ
diff --git a/examples/mujoco/benchmark/InvertedPendulum-v2/td3/figure.png b/examples/mujoco/benchmark/InvertedPendulum-v2/td3/figure.png
new file mode 100644
index 000000000..bd8e5255b
Binary files /dev/null and b/examples/mujoco/benchmark/InvertedPendulum-v2/td3/figure.png differ
diff --git a/examples/mujoco/benchmark/Reacher-v2/ddpg/figure.png b/examples/mujoco/benchmark/Reacher-v2/ddpg/figure.png
new file mode 100644
index 000000000..baf2b6f82
Binary files /dev/null and b/examples/mujoco/benchmark/Reacher-v2/ddpg/figure.png differ
diff --git a/examples/mujoco/benchmark/Reacher-v2/figure.png b/examples/mujoco/benchmark/Reacher-v2/figure.png
new file mode 100644
index 000000000..8943139dc
Binary files /dev/null and b/examples/mujoco/benchmark/Reacher-v2/figure.png differ
diff --git a/examples/mujoco/benchmark/Reacher-v2/sac/figure.png b/examples/mujoco/benchmark/Reacher-v2/sac/figure.png
new file mode 100644
index 000000000..be2debc93
Binary files /dev/null and b/examples/mujoco/benchmark/Reacher-v2/sac/figure.png differ
diff --git a/examples/mujoco/benchmark/Reacher-v2/td3/figure.png b/examples/mujoco/benchmark/Reacher-v2/td3/figure.png
new file mode 100644
index 000000000..f883ffe23
Binary files /dev/null and b/examples/mujoco/benchmark/Reacher-v2/td3/figure.png differ
diff --git a/examples/mujoco/benchmark/Swimmer-v3/ddpg/figure.png b/examples/mujoco/benchmark/Swimmer-v3/ddpg/figure.png
new file mode 100644
index 000000000..6982db5cc
Binary files /dev/null and b/examples/mujoco/benchmark/Swimmer-v3/ddpg/figure.png differ
diff --git a/examples/mujoco/benchmark/Swimmer-v3/figure.png b/examples/mujoco/benchmark/Swimmer-v3/figure.png
new file mode 100644
index 000000000..c7345c1a1
Binary files /dev/null and b/examples/mujoco/benchmark/Swimmer-v3/figure.png differ
diff --git a/examples/mujoco/benchmark/Swimmer-v3/sac/figure.png b/examples/mujoco/benchmark/Swimmer-v3/sac/figure.png
new file mode 100644
index 000000000..7a2ac165f
Binary files /dev/null and b/examples/mujoco/benchmark/Swimmer-v3/sac/figure.png differ
diff --git a/examples/mujoco/benchmark/Swimmer-v3/td3/figure.png b/examples/mujoco/benchmark/Swimmer-v3/td3/figure.png
new file mode 100644
index 000000000..f9f8219f6
Binary files /dev/null and b/examples/mujoco/benchmark/Swimmer-v3/td3/figure.png differ
diff --git a/examples/mujoco/benchmark/Walker2d-v3/ddpg/figure.png b/examples/mujoco/benchmark/Walker2d-v3/ddpg/figure.png
new file mode 100644
index 000000000..bbe52a3bf
Binary files /dev/null and b/examples/mujoco/benchmark/Walker2d-v3/ddpg/figure.png differ
diff --git a/examples/mujoco/benchmark/Walker2d-v3/figure.png b/examples/mujoco/benchmark/Walker2d-v3/figure.png
new file mode 100644
index 000000000..5201bad79
Binary files /dev/null and b/examples/mujoco/benchmark/Walker2d-v3/figure.png differ
diff --git a/examples/mujoco/benchmark/Walker2d-v3/sac/figure.png b/examples/mujoco/benchmark/Walker2d-v3/sac/figure.png
new file mode 100644
index 000000000..44581d1e9
Binary files /dev/null and b/examples/mujoco/benchmark/Walker2d-v3/sac/figure.png differ
diff --git a/examples/mujoco/benchmark/Walker2d-v3/td3/figure.png b/examples/mujoco/benchmark/Walker2d-v3/td3/figure.png
new file mode 100644
index 000000000..389a9f288
Binary files /dev/null and b/examples/mujoco/benchmark/Walker2d-v3/td3/figure.png differ
diff --git a/examples/mujoco/mujoco_ddpg.py b/examples/mujoco/mujoco_ddpg.py
new file mode 100755
index 000000000..491d42375
--- /dev/null
+++ b/examples/mujoco/mujoco_ddpg.py
@@ -0,0 +1,131 @@
+#!/usr/bin/env python3
+
+import os
+import gym
+import torch
+import datetime
+import argparse
+import numpy as np
+from torch.utils.tensorboard import SummaryWriter
+
+from tianshou.policy import DDPGPolicy
+from tianshou.utils import BasicLogger
+from tianshou.env import SubprocVectorEnv
+from tianshou.utils.net.common import Net
+from tianshou.exploration import GaussianNoise
+from tianshou.trainer import offpolicy_trainer
+from tianshou.utils.net.continuous import Actor, Critic
+from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--task', type=str, default='Ant-v3')
+ parser.add_argument('--seed', type=int, default=0)
+ parser.add_argument('--buffer-size', type=int, default=1000000)
+ parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[256, 256])
+ parser.add_argument('--actor-lr', type=float, default=1e-3)
+ parser.add_argument('--critic-lr', type=float, default=1e-3)
+ parser.add_argument('--gamma', type=float, default=0.99)
+ parser.add_argument('--tau', type=float, default=0.005)
+ parser.add_argument('--exploration-noise', type=float, default=0.1)
+ parser.add_argument("--start-timesteps", type=int, default=25000)
+ parser.add_argument('--epoch', type=int, default=200)
+ parser.add_argument('--step-per-epoch', type=int, default=5000)
+ parser.add_argument('--step-per-collect', type=int, default=1)
+ parser.add_argument('--update-per-step', type=int, default=1)
+ parser.add_argument('--n-step', type=int, default=1)
+ parser.add_argument('--batch-size', type=int, default=256)
+ parser.add_argument('--training-num', type=int, default=1)
+ parser.add_argument('--test-num', type=int, default=10)
+ parser.add_argument('--logdir', type=str, default='log')
+ parser.add_argument('--render', type=float, default=0.)
+ parser.add_argument(
+ '--device', type=str,
+ default='cuda' if torch.cuda.is_available() else 'cpu')
+ parser.add_argument('--resume-path', type=str, default=None)
+ return parser.parse_args()
+
+
+def test_ddpg(args=get_args()):
+ env = gym.make(args.task)
+ args.state_shape = env.observation_space.shape or env.observation_space.n
+ args.action_shape = env.action_space.shape or env.action_space.n
+ args.max_action = env.action_space.high[0]
+ args.exploration_noise = args.exploration_noise * args.max_action
+ print("Observations shape:", args.state_shape)
+ print("Actions shape:", args.action_shape)
+ print("Action range:", np.min(env.action_space.low),
+ np.max(env.action_space.high))
+ # train_envs = gym.make(args.task)
+ if args.training_num > 1:
+ train_envs = SubprocVectorEnv(
+ [lambda: gym.make(args.task) for _ in range(args.training_num)])
+ else:
+ train_envs = gym.make(args.task)
+ # test_envs = gym.make(args.task)
+ test_envs = SubprocVectorEnv(
+ [lambda: gym.make(args.task) for _ in range(args.test_num)])
+ # seed
+ np.random.seed(args.seed)
+ torch.manual_seed(args.seed)
+ train_envs.seed(args.seed)
+ test_envs.seed(args.seed)
+ # model
+ net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
+ actor = Actor(
+ net_a, args.action_shape, max_action=args.max_action,
+ device=args.device).to(args.device)
+ actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
+ net_c = Net(args.state_shape, args.action_shape,
+ hidden_sizes=args.hidden_sizes,
+ concat=True, device=args.device)
+ critic = Critic(net_c, device=args.device).to(args.device)
+ critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr)
+ policy = DDPGPolicy(
+ actor, actor_optim, critic, critic_optim,
+ action_range=[env.action_space.low[0], env.action_space.high[0]],
+ tau=args.tau, gamma=args.gamma,
+ exploration_noise=GaussianNoise(sigma=args.exploration_noise),
+ estimation_step=args.n_step)
+ # load a previous policy
+ if args.resume_path:
+ policy.load_state_dict(torch.load(
+ args.resume_path, map_location=args.device
+ ))
+ print("Loaded agent from: ", args.resume_path)
+
+ # collector
+ if args.training_num > 1:
+ buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
+ else:
+ buffer = ReplayBuffer(args.buffer_size)
+ train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
+ test_collector = Collector(policy, test_envs)
+ train_collector.collect(n_step=args.start_timesteps, random=True)
+ # log
+ log_path = os.path.join(args.logdir, args.task, 'ddpg', 'seed_' + str(
+ args.seed) + '_' + datetime.datetime.now().strftime('%m%d-%H%M%S'))
+ writer = SummaryWriter(log_path)
+ writer.add_text("args", str(args))
+ logger = BasicLogger(writer)
+
+ def save_fn(policy):
+ torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
+ # trainer
+ result = offpolicy_trainer(
+ policy, train_collector, test_collector, args.epoch,
+ args.step_per_epoch, args.step_per_collect, args.test_num,
+ args.batch_size, save_fn=save_fn, logger=logger,
+ update_per_step=args.update_per_step, test_in_train=False)
+
+ # Let's watch its performance!
+ policy.eval()
+ test_envs.seed(args.seed)
+ test_collector.reset()
+ result = test_collector.collect(n_episode=args.test_num, render=args.render)
+ print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')
+
+
+if __name__ == '__main__':
+ test_ddpg()
diff --git a/examples/mujoco/mujoco_sac.py b/examples/mujoco/mujoco_sac.py
old mode 100644
new mode 100755
index c07877324..73ef4eb49
--- a/examples/mujoco/mujoco_sac.py
+++ b/examples/mujoco/mujoco_sac.py
@@ -1,7 +1,8 @@
+#!/usr/bin/env python3
+
import os
import gym
import torch
-import pprint
import datetime
import argparse
import numpy as np
@@ -12,42 +13,38 @@
from tianshou.env import SubprocVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import offpolicy_trainer
-from tianshou.data import Collector, VectorReplayBuffer
from tianshou.utils.net.continuous import ActorProb, Critic
+from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='Ant-v3')
- parser.add_argument('--seed', type=int, default=1626)
+ parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--buffer-size', type=int, default=1000000)
- parser.add_argument('--actor-lr', type=float, default=3e-4)
- parser.add_argument('--critic-lr', type=float, default=3e-4)
+ parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[256, 256])
+ parser.add_argument('--actor-lr', type=float, default=1e-3)
+ parser.add_argument('--critic-lr', type=float, default=1e-3)
parser.add_argument('--gamma', type=float, default=0.99)
parser.add_argument('--tau', type=float, default=0.005)
parser.add_argument('--alpha', type=float, default=0.2)
parser.add_argument('--auto-alpha', default=False, action='store_true')
parser.add_argument('--alpha-lr', type=float, default=3e-4)
- parser.add_argument('--n-step', type=int, default=2)
- parser.add_argument('--epoch', type=int, default=100)
- parser.add_argument('--step-per-epoch', type=int, default=40000)
- parser.add_argument('--step-per-collect', type=int, default=4)
- parser.add_argument('--update-per-step', type=float, default=0.25)
- parser.add_argument('--pre-collect-step', type=int, default=10000)
+ parser.add_argument("--start-timesteps", type=int, default=10000)
+ parser.add_argument('--epoch', type=int, default=200)
+ parser.add_argument('--step-per-epoch', type=int, default=5000)
+ parser.add_argument('--step-per-collect', type=int, default=1)
+ parser.add_argument('--update-per-step', type=int, default=1)
+ parser.add_argument('--n-step', type=int, default=1)
parser.add_argument('--batch-size', type=int, default=256)
- parser.add_argument('--hidden-sizes', type=int,
- nargs='*', default=[128, 128])
- parser.add_argument('--training-num', type=int, default=4)
- parser.add_argument('--test-num', type=int, default=100)
+ parser.add_argument('--training-num', type=int, default=1)
+ parser.add_argument('--test-num', type=int, default=10)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
- parser.add_argument('--log-interval', type=int, default=1000)
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
parser.add_argument('--resume-path', type=str, default=None)
- parser.add_argument('--watch', default=False, action='store_true',
- help='watch the play of pre-trained policy only')
return parser.parse_args()
@@ -61,8 +58,11 @@ def test_sac(args=get_args()):
print("Action range:", np.min(env.action_space.low),
np.max(env.action_space.high))
# train_envs = gym.make(args.task)
- train_envs = SubprocVectorEnv(
- [lambda: gym.make(args.task) for _ in range(args.training_num)])
+ if args.training_num > 1:
+ train_envs = SubprocVectorEnv(
+ [lambda: gym.make(args.task) for _ in range(args.training_num)])
+ else:
+ train_envs = gym.make(args.task)
# test_envs = gym.make(args.task)
test_envs = SubprocVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)])
@@ -72,21 +72,20 @@ def test_sac(args=get_args()):
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
- net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
- device=args.device)
+ net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
actor = ActorProb(
- net, args.action_shape, max_action=args.max_action,
+ net_a, args.action_shape, max_action=args.max_action,
device=args.device, unbounded=True, conditioned_sigma=True
).to(args.device)
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
net_c1 = Net(args.state_shape, args.action_shape,
hidden_sizes=args.hidden_sizes,
concat=True, device=args.device)
- critic1 = Critic(net_c1, device=args.device).to(args.device)
- critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
net_c2 = Net(args.state_shape, args.action_shape,
hidden_sizes=args.hidden_sizes,
concat=True, device=args.device)
+ critic1 = Critic(net_c1, device=args.device).to(args.device)
+ critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
critic2 = Critic(net_c2, device=args.device).to(args.device)
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
@@ -109,46 +108,35 @@ def test_sac(args=get_args()):
print("Loaded agent from: ", args.resume_path)
# collector
- train_collector = Collector(
- policy, train_envs,
- VectorReplayBuffer(args.buffer_size, len(train_envs)),
- exploration_noise=True)
+ if args.training_num > 1:
+ buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
+ else:
+ buffer = ReplayBuffer(args.buffer_size)
+ train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs)
+ train_collector.collect(n_step=args.start_timesteps, random=True)
# log
log_path = os.path.join(args.logdir, args.task, 'sac', 'seed_' + str(
args.seed) + '_' + datetime.datetime.now().strftime('%m%d-%H%M%S'))
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
- logger = BasicLogger(writer, train_interval=args.log_interval)
-
- def watch():
- # watch agent's performance
- print("Testing agent ...")
- policy.eval()
- test_envs.seed(args.seed)
- test_collector.reset()
- result = test_collector.collect(n_episode=args.test_num, render=args.render)
- pprint.pprint(result)
+ logger = BasicLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
-
- def stop_fn(mean_rewards):
- return False
-
- if args.watch:
- watch()
- exit(0)
-
# trainer
- train_collector.collect(n_step=args.pre_collect_step, random=True)
result = offpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.step_per_collect, args.test_num,
- args.batch_size, stop_fn=stop_fn, save_fn=save_fn, logger=logger,
- update_per_step=args.update_per_step)
- pprint.pprint(result)
- watch()
+ args.batch_size, save_fn=save_fn, logger=logger,
+ update_per_step=args.update_per_step, test_in_train=False)
+
+ # Let's watch its performance!
+ policy.eval()
+ test_envs.seed(args.seed)
+ test_collector.reset()
+ result = test_collector.collect(n_episode=args.test_num, render=args.render)
+ print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')
if __name__ == '__main__':
diff --git a/examples/mujoco/mujoco_td3.py b/examples/mujoco/mujoco_td3.py
new file mode 100755
index 000000000..9066cbee1
--- /dev/null
+++ b/examples/mujoco/mujoco_td3.py
@@ -0,0 +1,144 @@
+#!/usr/bin/env python3
+
+import os
+import gym
+import torch
+import datetime
+import argparse
+import numpy as np
+from torch.utils.tensorboard import SummaryWriter
+
+from tianshou.policy import TD3Policy
+from tianshou.utils import BasicLogger
+from tianshou.env import SubprocVectorEnv
+from tianshou.utils.net.common import Net
+from tianshou.exploration import GaussianNoise
+from tianshou.trainer import offpolicy_trainer
+from tianshou.utils.net.continuous import Actor, Critic
+from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--task', type=str, default='Ant-v3')
+ parser.add_argument('--seed', type=int, default=0)
+ parser.add_argument('--buffer-size', type=int, default=1000000)
+ parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[256, 256])
+ parser.add_argument('--actor-lr', type=float, default=3e-4)
+ parser.add_argument('--critic-lr', type=float, default=3e-4)
+ parser.add_argument('--gamma', type=float, default=0.99)
+ parser.add_argument('--tau', type=float, default=0.005)
+ parser.add_argument('--exploration-noise', type=float, default=0.1)
+ parser.add_argument('--policy-noise', type=float, default=0.2)
+ parser.add_argument('--noise-clip', type=float, default=0.5)
+ parser.add_argument('--update-actor-freq', type=int, default=2)
+ parser.add_argument("--start-timesteps", type=int, default=25000)
+ parser.add_argument('--epoch', type=int, default=200)
+ parser.add_argument('--step-per-epoch', type=int, default=5000)
+ parser.add_argument('--step-per-collect', type=int, default=1)
+ parser.add_argument('--update-per-step', type=int, default=1)
+ parser.add_argument('--n-step', type=int, default=1)
+ parser.add_argument('--batch-size', type=int, default=256)
+ parser.add_argument('--training-num', type=int, default=1)
+ parser.add_argument('--test-num', type=int, default=10)
+ parser.add_argument('--logdir', type=str, default='log')
+ parser.add_argument('--render', type=float, default=0.)
+ parser.add_argument(
+ '--device', type=str,
+ default='cuda' if torch.cuda.is_available() else 'cpu')
+ parser.add_argument('--resume-path', type=str, default=None)
+ return parser.parse_args()
+
+
+def test_td3(args=get_args()):
+ env = gym.make(args.task)
+ args.state_shape = env.observation_space.shape or env.observation_space.n
+ args.action_shape = env.action_space.shape or env.action_space.n
+ args.max_action = env.action_space.high[0]
+ args.exploration_noise = args.exploration_noise * args.max_action
+ args.policy_noise = args.policy_noise * args.max_action
+ args.noise_clip = args.noise_clip * args.max_action
+ print("Observations shape:", args.state_shape)
+ print("Actions shape:", args.action_shape)
+ print("Action range:", np.min(env.action_space.low),
+ np.max(env.action_space.high))
+ # train_envs = gym.make(args.task)
+ if args.training_num > 1:
+ train_envs = SubprocVectorEnv(
+ [lambda: gym.make(args.task) for _ in range(args.training_num)])
+ else:
+ train_envs = gym.make(args.task)
+ # test_envs = gym.make(args.task)
+ test_envs = SubprocVectorEnv(
+ [lambda: gym.make(args.task) for _ in range(args.test_num)])
+ # seed
+ np.random.seed(args.seed)
+ torch.manual_seed(args.seed)
+ train_envs.seed(args.seed)
+ test_envs.seed(args.seed)
+ # model
+ net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
+ actor = Actor(
+ net_a, args.action_shape, max_action=args.max_action,
+ device=args.device).to(args.device)
+ actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
+ net_c1 = Net(args.state_shape, args.action_shape,
+ hidden_sizes=args.hidden_sizes,
+ concat=True, device=args.device)
+ net_c2 = Net(args.state_shape, args.action_shape,
+ hidden_sizes=args.hidden_sizes,
+ concat=True, device=args.device)
+ critic1 = Critic(net_c1, device=args.device).to(args.device)
+ critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
+ critic2 = Critic(net_c2, device=args.device).to(args.device)
+ critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
+
+ policy = TD3Policy(
+ actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
+ action_range=[env.action_space.low[0], env.action_space.high[0]],
+ tau=args.tau, gamma=args.gamma,
+ exploration_noise=GaussianNoise(sigma=args.exploration_noise),
+ policy_noise=args.policy_noise, update_actor_freq=args.update_actor_freq,
+ noise_clip=args.noise_clip, estimation_step=args.n_step)
+
+ # load a previous policy
+ if args.resume_path:
+ policy.load_state_dict(torch.load(
+ args.resume_path, map_location=args.device
+ ))
+ print("Loaded agent from: ", args.resume_path)
+
+ # collector
+ if args.training_num > 1:
+ buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
+ else:
+ buffer = ReplayBuffer(args.buffer_size)
+ train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
+ test_collector = Collector(policy, test_envs)
+ train_collector.collect(n_step=args.start_timesteps, random=True)
+ # log
+ log_path = os.path.join(args.logdir, args.task, 'td3', 'seed_' + str(
+ args.seed) + '_' + datetime.datetime.now().strftime('%m%d-%H%M%S'))
+ writer = SummaryWriter(log_path)
+ writer.add_text("args", str(args))
+ logger = BasicLogger(writer)
+
+ def save_fn(policy):
+ torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
+ # trainer
+ result = offpolicy_trainer(
+ policy, train_collector, test_collector, args.epoch,
+ args.step_per_epoch, args.step_per_collect, args.test_num,
+ args.batch_size, save_fn=save_fn, logger=logger,
+ update_per_step=args.update_per_step, test_in_train=False)
+
+ # Let's watch its performance!
+ policy.eval()
+ test_envs.seed(args.seed)
+ test_collector.reset()
+ result = test_collector.collect(n_episode=args.test_num, render=args.render)
+ print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')
+
+
+if __name__ == '__main__':
+ test_td3()
diff --git a/examples/mujoco/results/sac/all.png b/examples/mujoco/results/sac/all.png
deleted file mode 100644
index 7f314f46f..000000000
Binary files a/examples/mujoco/results/sac/all.png and /dev/null differ
diff --git a/examples/mujoco/run_experiments.sh b/examples/mujoco/run_experiments.sh
new file mode 100755
index 000000000..4de3263f5
--- /dev/null
+++ b/examples/mujoco/run_experiments.sh
@@ -0,0 +1,10 @@
+#!/bin/bash
+
+LOGDIR="results"
+TASK=$1
+
+echo "Experiments started."
+for seed in $(seq 1 10)
+do
+ python mujoco_sac.py --task $TASK --epoch 200 --seed $seed --logdir $LOGDIR > ${TASK}_`date '+%m-%d-%H-%M-%S'`_seed_$seed.txt 2>&1
+done
diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py
index 9a4dad062..87a06c53e 100644
--- a/tianshou/policy/modelfree/ddpg.py
+++ b/tianshou/policy/modelfree/ddpg.py
@@ -129,18 +129,31 @@ def forward(
obs = batch[input]
actions, h = model(obs, state=state, info=batch.info)
actions += self._action_bias
+ actions = actions.clamp(self._range[0], self._range[1])
return Batch(act=actions, state=h)
- def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
- weight = batch.pop("weight", 1.0)
- current_q = self.critic(batch.obs, batch.act).flatten()
+ @staticmethod
+ def _mse_optimizer(
+ batch: Batch, critic: torch.nn.Module, optimizer: torch.optim.Optimizer
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """A simple wrapper script for updating critic network."""
+ weight = getattr(batch, "weight", 1.0)
+ current_q = critic(batch.obs, batch.act).flatten()
target_q = batch.returns.flatten()
td = current_q - target_q
+ # critic_loss = F.mse_loss(current_q1, target_q)
critic_loss = (td.pow(2) * weight).mean()
- batch.weight = td # prio-buffer
- self.critic_optim.zero_grad()
+ optimizer.zero_grad()
critic_loss.backward()
- self.critic_optim.step()
+ optimizer.step()
+ return td, critic_loss
+
+ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
+ # critic
+ td, critic_loss = self._mse_optimizer(
+ batch, self.critic, self.critic_optim)
+ batch.weight = td # prio-buffer
+ # actor
action = self(batch).act
actor_loss = -self.critic(batch.obs, action).mean()
self.actor_optim.zero_grad()
diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py
index bd1fea14a..1edeea2fb 100644
--- a/tianshou/policy/modelfree/dqn.py
+++ b/tianshou/policy/modelfree/dqn.py
@@ -69,7 +69,7 @@ def train(self, mode: bool = True) -> "DQNPolicy":
def sync_weight(self) -> None:
"""Synchronize the weight for the target network."""
- self.model_old.load_state_dict(self.model.state_dict())
+ self.model_old.load_state_dict(self.model.state_dict()) # type: ignore
def _target_q(self, buffer: ReplayBuffer, indice: np.ndarray) -> torch.Tensor:
batch = buffer[indice] # batch.obs_next: s_{t+n}
diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py
index 68bef3971..ac81cbfc2 100644
--- a/tianshou/policy/modelfree/sac.py
+++ b/tianshou/policy/modelfree/sac.py
@@ -115,7 +115,11 @@ def forward( # type: ignore
x = dist.rsample()
y = torch.tanh(x)
act = y * self._action_scale + self._action_bias
+ # __eps is used to avoid log of zero/negative number.
y = self._action_scale * (1 - y.pow(2)) + self.__eps
+ # Compute logprob from Gaussian, and then apply correction for Tanh squashing.
+ # You can check out the original SAC paper (arXiv 1801.01290): Eq 21.
+ # in appendix C to get some understanding of this equation.
log_prob = dist.log_prob(x).unsqueeze(-1)
log_prob = log_prob - torch.log(y).sum(-1, keepdim=True)
@@ -134,26 +138,11 @@ def _target_q(
return target_q
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
- weight = batch.pop("weight", 1.0)
-
- # critic 1
- current_q1 = self.critic1(batch.obs, batch.act).flatten()
- target_q = batch.returns.flatten()
- td1 = current_q1 - target_q
- critic1_loss = (td1.pow(2) * weight).mean()
- # critic1_loss = F.mse_loss(current_q1, target_q)
- self.critic1_optim.zero_grad()
- critic1_loss.backward()
- self.critic1_optim.step()
-
- # critic 2
- current_q2 = self.critic2(batch.obs, batch.act).flatten()
- td2 = current_q2 - target_q
- critic2_loss = (td2.pow(2) * weight).mean()
- # critic2_loss = F.mse_loss(current_q2, target_q)
- self.critic2_optim.zero_grad()
- critic2_loss.backward()
- self.critic2_optim.step()
+ # critic 1&2
+ td1, critic1_loss = self._mse_optimizer(
+ batch, self.critic1, self.critic1_optim)
+ td2, critic2_loss = self._mse_optimizer(
+ batch, self.critic2, self.critic2_optim)
batch.weight = (td1 + td2) / 2.0 # prio-buffer
# actor
diff --git a/tianshou/policy/modelfree/td3.py b/tianshou/policy/modelfree/td3.py
index bd6572205..09f288ff6 100644
--- a/tianshou/policy/modelfree/td3.py
+++ b/tianshou/policy/modelfree/td3.py
@@ -105,25 +105,14 @@ def _target_q(self, buffer: ReplayBuffer, indice: np.ndarray) -> torch.Tensor:
return target_q
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
- weight = batch.pop("weight", 1.0)
- # critic 1
- current_q1 = self.critic1(batch.obs, batch.act).flatten()
- target_q = batch.returns.flatten()
- td1 = current_q1 - target_q
- critic1_loss = (td1.pow(2) * weight).mean()
- # critic1_loss = F.mse_loss(current_q1, target_q)
- self.critic1_optim.zero_grad()
- critic1_loss.backward()
- self.critic1_optim.step()
- # critic 2
- current_q2 = self.critic2(batch.obs, batch.act).flatten()
- td2 = current_q2 - target_q
- critic2_loss = (td2.pow(2) * weight).mean()
- # critic2_loss = F.mse_loss(current_q2, target_q)
- self.critic2_optim.zero_grad()
- critic2_loss.backward()
- self.critic2_optim.step()
+ # critic 1&2
+ td1, critic1_loss = self._mse_optimizer(
+ batch, self.critic1, self.critic1_optim)
+ td2, critic2_loss = self._mse_optimizer(
+ batch, self.critic2, self.critic2_optim)
batch.weight = (td1 + td2) / 2.0 # prio-buffer
+
+ # actor
if self._cnt % self._freq == 0:
actor_loss = -self.critic1(batch.obs, self(batch, eps=0.0).act).mean()
self.actor_optim.zero_grad()
diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py
index 13f96faeb..802e34963 100644
--- a/tianshou/trainer/offline.py
+++ b/tianshou/trainer/offline.py
@@ -93,8 +93,9 @@ def offline_trainer(
if save_fn:
save_fn(policy)
if verbose:
- print(f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
- f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}")
+ print(
+ f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_reward:"
+ f" {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}")
if stop_fn and stop_fn(best_reward):
break
return gather_info(start_time, None, test_collector, best_reward, best_reward_std)
diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py
index 72a243d9a..444dbc08a 100644
--- a/tianshou/trainer/offpolicy.py
+++ b/tianshou/trainer/offpolicy.py
@@ -145,8 +145,9 @@ def offpolicy_trainer(
if save_fn:
save_fn(policy)
if verbose:
- print(f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
- f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}")
+ print(
+ f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_reward:"
+ f" {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}")
if stop_fn and stop_fn(best_reward):
break
return gather_info(start_time, train_collector, test_collector,
diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py
index dae20a741..ab731270e 100644
--- a/tianshou/trainer/onpolicy.py
+++ b/tianshou/trainer/onpolicy.py
@@ -155,8 +155,9 @@ def onpolicy_trainer(
if save_fn:
save_fn(policy)
if verbose:
- print(f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
- f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}")
+ print(
+ f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_reward:"
+ f" {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}")
if stop_fn and stop_fn(best_reward):
break
return gather_info(start_time, train_collector, test_collector,