diff --git a/README.md b/README.md
index a39da242e..dc99cc7b5 100644
--- a/README.md
+++ b/README.md
@@ -31,6 +31,7 @@
 - Vanilla Imitation Learning
 - [Batch-Constrained deep Q-Learning (BCQ)](https://arxiv.org/pdf/1812.02900.pdf)
 - [Conservative Q-Learning (CQL)](https://arxiv.org/pdf/2006.04779.pdf)
+- [Twin Delayed DDPG with Behavior Cloning (TD3+BC)](https://arxiv.org/pdf/2106.06860.pdf)
 - [Discrete Batch-Constrained deep Q-Learning (BCQ-Discrete)](https://arxiv.org/pdf/1910.01708.pdf)
 - [Discrete Conservative Q-Learning (CQL-Discrete)](https://arxiv.org/pdf/2006.04779.pdf)
 - [Discrete Critic Regularized Regression (CRR-Discrete)](https://arxiv.org/pdf/2006.15134.pdf)
diff --git a/docs/api/tianshou.policy.rst b/docs/api/tianshou.policy.rst
index fe129d8ec..6acd1d54b 100644
--- a/docs/api/tianshou.policy.rst
+++ b/docs/api/tianshou.policy.rst
@@ -129,6 +129,11 @@ Imitation
     :undoc-members:
     :show-inheritance:
+.. autoclass:: tianshou.policy.TD3BCPolicy
+    :members:
+    :undoc-members:
+    :show-inheritance:
+
 .. autoclass:: tianshou.policy.DiscreteBCQPolicy
     :members:
     :undoc-members:
     :show-inheritance:
diff --git a/docs/index.rst b/docs/index.rst
index c18782718..7fce12f6b 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -31,6 +31,7 @@ Welcome to Tianshou!
 * :class:`~tianshou.policy.ImitationPolicy` Imitation Learning
 * :class:`~tianshou.policy.BCQPolicy` `Batch-Constrained deep Q-Learning <https://arxiv.org/pdf/1812.02900.pdf>`_
 * :class:`~tianshou.policy.CQLPolicy` `Conservative Q-Learning <https://arxiv.org/pdf/2006.04779.pdf>`_
+* :class:`~tianshou.policy.TD3BCPolicy` `Twin Delayed DDPG with Behavior Cloning <https://arxiv.org/pdf/2106.06860.pdf>`_
 * :class:`~tianshou.policy.DiscreteBCQPolicy` `Discrete Batch-Constrained deep Q-Learning <https://arxiv.org/pdf/1910.01708.pdf>`_
 * :class:`~tianshou.policy.DiscreteCQLPolicy` `Discrete Conservative Q-Learning <https://arxiv.org/pdf/2006.04779.pdf>`_
 * :class:`~tianshou.policy.DiscreteCRRPolicy` `Critic Regularized Regression <https://arxiv.org/pdf/2006.15134.pdf>`_
diff --git a/examples/offline/README.md b/examples/offline/README.md
index 61052c7c6..f52fb918c 100644
--- a/examples/offline/README.md
+++ b/examples/offline/README.md
@@ -35,6 +35,24 @@ Tianshou provides an `offline_trainer` for offline reinforcement learning. You c
 | HalfCheetah-v2        | halfcheetah-expert-v2 | 2864.37         | `python3 d4rl_cql.py --task HalfCheetah-v2 --expert-data-task halfcheetah-expert-v2` |
 | HalfCheetah-v2        | halfcheetah-medium-v2 | 6505.41         | `python3 d4rl_cql.py --task HalfCheetah-v2 --expert-data-task halfcheetah-medium-v2` |
 
+### TD3+BC
+
+| Environment           | Dataset               | TD3+BC          | Parameters                                                |
+| --------------------- | --------------------- | --------------- | -------------------------------------------------------- |
+| HalfCheetah-v2        | halfcheetah-expert-v2 | 11788.25        | `python3 d4rl_td3_bc.py --task HalfCheetah-v2 --expert-data-task halfcheetah-expert-v2` |
+| HalfCheetah-v2        | halfcheetah-medium-v2 | 5741.13         | `python3 d4rl_td3_bc.py --task HalfCheetah-v2 --expert-data-task halfcheetah-medium-v2` |
+
+#### Observation normalization
+
+Following the original paper, we use observation normalization by default. You can turn it off by setting `--norm-obs 0`. The differences are small but consistent.
+
+| Dataset | w/ norm-obs | w/o norm-obs |
+| :--- | :--- | :--- |
+| halfcheetah-medium-v2 | 5741.13 | 5724.41 |
+| halfcheetah-expert-v2 | 11788.25 | 11665.77 |
+| walker2d-medium-v2 | 4051.76 | 3985.59 |
+| walker2d-expert-v2 | 5068.15 | 5027.75 |
+
 ## Discrete control
 
 For discrete control, we currently use ad hoc Atari data generated from a trained QRDQN agent.
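The observation normalization described above is done once over the whole offline dataset by `normalize_all_obs_in_replay_buffer` in `examples/offline/utils.py` (added later in this patch). For orientation only, here is a minimal standalone sketch of the same idea; the function and variable names are illustrative and not part of the patch:

```python
import numpy as np


def normalize_dataset_obs(obs: np.ndarray, obs_next: np.ndarray, eps: float = 1e-8):
    # Compute mean/std once over the entire offline dataset ...
    mean = obs.mean(axis=0)
    std = np.sqrt(obs.var(axis=0) + eps)
    # ... then apply the same (frozen) statistics to obs and obs_next;
    # the evaluation environments must reuse exactly these statistics.
    return (obs - mean) / std, (obs_next - mean) / std, (mean, std)
```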
diff --git a/examples/offline/d4rl_td3_bc.py b/examples/offline/d4rl_td3_bc.py
new file mode 100644
index 000000000..43ee21b31
--- /dev/null
+++ b/examples/offline/d4rl_td3_bc.py
@@ -0,0 +1,227 @@
+#!/usr/bin/env python3
+
+import argparse
+import datetime
+import os
+import pprint
+
+import gym
+import numpy as np
+import torch
+from torch.utils.tensorboard import SummaryWriter
+
+from examples.offline.utils import load_buffer_d4rl, normalize_all_obs_in_replay_buffer
+from tianshou.data import Collector
+from tianshou.env import SubprocVectorEnv, VectorEnvNormObs
+from tianshou.exploration import GaussianNoise
+from tianshou.policy import TD3BCPolicy
+from tianshou.trainer import offline_trainer
+from tianshou.utils import TensorboardLogger, WandbLogger
+from tianshou.utils.net.common import Net
+from tianshou.utils.net.continuous import Actor, Critic
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--task", type=str, default="HalfCheetah-v2")
+    parser.add_argument("--seed", type=int, default=0)
+    parser.add_argument(
+        "--expert-data-task", type=str, default="halfcheetah-expert-v2"
+    )
+    parser.add_argument("--buffer-size", type=int, default=1000000)
+    parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[256, 256])
+    parser.add_argument("--actor-lr", type=float, default=3e-4)
+    parser.add_argument("--critic-lr", type=float, default=3e-4)
+    parser.add_argument("--epoch", type=int, default=200)
+    parser.add_argument("--step-per-epoch", type=int, default=5000)
+    parser.add_argument("--n-step", type=int, default=3)
+    parser.add_argument("--batch-size", type=int, default=256)
+
+    parser.add_argument("--alpha", type=float, default=2.5)
+    parser.add_argument("--exploration-noise", type=float, default=0.1)
+    parser.add_argument("--policy-noise", type=float, default=0.2)
+    parser.add_argument("--noise-clip", type=float, default=0.5)
+    parser.add_argument("--update-actor-freq", type=int, default=2)
+    parser.add_argument("--tau", type=float, default=0.005)
+    parser.add_argument("--gamma", type=float, default=0.99)
+    parser.add_argument("--norm-obs", type=int, default=1)
+
+    parser.add_argument("--eval-freq", type=int, default=1)
+    parser.add_argument("--test-num", type=int, default=10)
+    parser.add_argument("--logdir", type=str, default="log")
+    parser.add_argument("--render", type=float, default=1 / 35)
+    parser.add_argument(
+        "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
+    )
+    parser.add_argument("--resume-path", type=str, default=None)
+    parser.add_argument("--resume-id", type=str, default=None)
+    parser.add_argument(
+        "--logger",
+        type=str,
+        default="tensorboard",
+        choices=["tensorboard", "wandb"],
+    )
+    parser.add_argument("--wandb-project", type=str, default="offline_d4rl.benchmark")
+    parser.add_argument(
+        "--watch",
+        default=False,
+        action="store_true",
+        help="watch the play of pre-trained policy only",
+    )
+    return parser.parse_args()
+
+
+def test_td3_bc():
+    args = get_args()
+    env = gym.make(args.task)
+    args.state_shape = env.observation_space.shape or env.observation_space.n
+    args.action_shape = env.action_space.shape or env.action_space.n
+    args.max_action = env.action_space.high[0]  # float
+    print("device:", args.device)
+    print("Observations shape:", args.state_shape)
+    print("Actions shape:", args.action_shape)
+    print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high))
+
+    args.state_dim = args.state_shape[0]
+    args.action_dim = args.action_shape[0]
+    print("Max_action", args.max_action)
+
+    test_envs = SubprocVectorEnv(
+        [lambda: gym.make(args.task) for _ in range(args.test_num)]
+    )
+    if args.norm_obs:
+        test_envs = VectorEnvNormObs(test_envs, update_obs_rms=False)
+
+    # seed
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    test_envs.seed(args.seed)
+
+    # model
+    # actor network
+    net_a = Net(
+        args.state_shape,
+        hidden_sizes=args.hidden_sizes,
+        device=args.device,
+    )
+    actor = Actor(
+        net_a,
+        action_shape=args.action_shape,
+        max_action=args.max_action,
+        device=args.device,
+    ).to(args.device)
+    actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
+
+    # critic network
+    net_c1 = Net(
+        args.state_shape,
+        args.action_shape,
+        hidden_sizes=args.hidden_sizes,
+        concat=True,
+        device=args.device,
+    )
+    net_c2 = Net(
+        args.state_shape,
+        args.action_shape,
+        hidden_sizes=args.hidden_sizes,
+        concat=True,
+        device=args.device,
+    )
+    critic1 = Critic(net_c1, device=args.device).to(args.device)
+    critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
+    critic2 = Critic(net_c2, device=args.device).to(args.device)
+    critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
+
+    policy = TD3BCPolicy(
+        actor,
+        actor_optim,
+        critic1,
+        critic1_optim,
+        critic2,
+        critic2_optim,
+        tau=args.tau,
+        gamma=args.gamma,
+        exploration_noise=GaussianNoise(sigma=args.exploration_noise),
+        policy_noise=args.policy_noise,
+        update_actor_freq=args.update_actor_freq,
+        noise_clip=args.noise_clip,
+        alpha=args.alpha,
+        estimation_step=args.n_step,
+        action_space=env.action_space,
+    )
+
+    # load a previous policy
+    if args.resume_path:
+        policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
+        print("Loaded agent from: ", args.resume_path)
+
+    # collector
+    test_collector = Collector(policy, test_envs)
+
+    # log
+    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
+    args.algo_name = "td3_bc"
+    log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
+    log_path = os.path.join(args.logdir, log_name)
+
+    # logger
+    if args.logger == "wandb":
+        logger = WandbLogger(
+            save_interval=1,
+            name=log_name.replace(os.path.sep, "__"),
+            run_id=args.resume_id,
+            config=args,
+            project=args.wandb_project,
+        )
+    writer = SummaryWriter(log_path)
+    writer.add_text("args", str(args))
+    if args.logger == "tensorboard":
+        logger = TensorboardLogger(writer)
+    else:  # wandb
+        logger.load(writer)
+
+    def save_best_fn(policy):
+        torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))
+
+    def watch():
+        if args.resume_path is None:
+            args.resume_path = os.path.join(log_path, "policy.pth")
+
+        policy.load_state_dict(
+            torch.load(args.resume_path, map_location=torch.device("cpu"))
+        )
+        policy.eval()
+        collector = Collector(policy, env)
+        collector.collect(n_episode=1, render=1 / 35)
+
+    if not args.watch:
+        replay_buffer = load_buffer_d4rl(args.expert_data_task)
+        if args.norm_obs:
+            replay_buffer, obs_rms = normalize_all_obs_in_replay_buffer(replay_buffer)
+            test_envs.set_obs_rms(obs_rms)
+        # trainer
+        result = offline_trainer(
+            policy,
+            replay_buffer,
+            test_collector,
+            args.epoch,
+            args.step_per_epoch,
+            args.test_num,
+            args.batch_size,
+            save_best_fn=save_best_fn,
+            logger=logger,
+        )
+        pprint.pprint(result)
+    else:
+        watch()
+
+    # Let's watch its performance!
+    policy.eval()
+    test_envs.seed(args.seed)
+    test_collector.reset()
+    result = test_collector.collect(n_episode=args.test_num, render=args.render)
+    print(f"Final reward: {result['rews'].mean()}, length: {result['lens'].mean()}")
+
+
+if __name__ == "__main__":
+    test_td3_bc()
diff --git a/examples/offline/utils.py b/examples/offline/utils.py
index 757baf69a..c6052795c 100644
--- a/examples/offline/utils.py
+++ b/examples/offline/utils.py
@@ -1,8 +1,12 @@
+from typing import Tuple
+
 import d4rl
 import gym
 import h5py
+import numpy as np
 
 from tianshou.data import ReplayBuffer
+from tianshou.utils import RunningMeanStd
 
 
 def load_buffer_d4rl(expert_data_task: str) -> ReplayBuffer:
@@ -27,3 +31,18 @@ def load_buffer(buffer_path: str) -> ReplayBuffer:
         obs_next=dataset["next_observations"]
     )
     return buffer
+
+
+def normalize_all_obs_in_replay_buffer(
+    replay_buffer: ReplayBuffer
+) -> Tuple[ReplayBuffer, RunningMeanStd]:
+    # compute obs mean and var
+    obs_rms = RunningMeanStd()
+    obs_rms.update(replay_buffer.obs)
+    _eps = np.finfo(np.float32).eps.item()
+    # normalize obs
+    replay_buffer._meta["obs"] = (replay_buffer.obs -
+                                  obs_rms.mean) / np.sqrt(obs_rms.var + _eps)
+    replay_buffer._meta["obs_next"] = (replay_buffer.obs_next -
+                                       obs_rms.mean) / np.sqrt(obs_rms.var + _eps)
+    return replay_buffer, obs_rms
diff --git a/test/offline/test_td3_bc.py b/test/offline/test_td3_bc.py
new file mode 100644
index 000000000..e34dfc6d2
--- /dev/null
+++ b/test/offline/test_td3_bc.py
@@ -0,0 +1,215 @@
+import argparse
+import datetime
+import os
+import pickle
+import pprint
+
+import gym
+import numpy as np
+import torch
+from torch.utils.tensorboard import SummaryWriter
+
+from tianshou.data import Collector, VectorReplayBuffer
+from tianshou.env import DummyVectorEnv
+from tianshou.exploration import GaussianNoise
+from tianshou.policy import TD3BCPolicy
+from tianshou.trainer import OfflineTrainer
+from tianshou.utils import TensorboardLogger
+from tianshou.utils.net.common import Net
+from tianshou.utils.net.continuous import Actor, Critic
+
+if __name__ == "__main__":
+    from gather_pendulum_data import expert_file_name, gather_data
+else:  # pytest
+    from test.offline.gather_pendulum_data import expert_file_name, gather_data
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--task', type=str, default='Pendulum-v1')
+    parser.add_argument('--reward-threshold', type=float, default=None)
+    parser.add_argument('--seed', type=int, default=0)
+    parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
+    parser.add_argument('--actor-lr', type=float, default=1e-3)
+    parser.add_argument('--critic-lr', type=float, default=1e-3)
+    parser.add_argument('--epoch', type=int, default=5)
+    parser.add_argument('--step-per-epoch', type=int, default=500)
+    parser.add_argument('--n-step', type=int, default=3)
+    parser.add_argument('--batch-size', type=int, default=64)
+    parser.add_argument('--alpha', type=float, default=2.5)
+    parser.add_argument("--exploration-noise", type=float, default=0.1)
+    parser.add_argument("--policy-noise", type=float, default=0.2)
+    parser.add_argument("--noise-clip", type=float, default=0.5)
+    parser.add_argument("--update-actor-freq", type=int, default=2)
+    parser.add_argument("--tau", type=float, default=0.005)
+    parser.add_argument("--gamma", type=float, default=0.99)
+
+    parser.add_argument("--eval-freq", type=int, default=1)
+    parser.add_argument('--test-num', type=int, default=10)
+    parser.add_argument('--logdir', type=str, default='log')
+    parser.add_argument('--render', type=float, default=1 / 35)
+    parser.add_argument(
+        '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
+    )
+    parser.add_argument('--resume-path', type=str, default=None)
+    parser.add_argument(
+        '--watch',
+        default=False,
+        action='store_true',
+        help='watch the play of pre-trained policy only',
+    )
+    parser.add_argument("--load-buffer-name", type=str, default=expert_file_name())
+    args = parser.parse_known_args()[0]
+    return args
+
+
+def test_td3_bc(args=get_args()):
+    if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name):
+        if args.load_buffer_name.endswith(".hdf5"):
+            buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
+        else:
+            buffer = pickle.load(open(args.load_buffer_name, "rb"))
+    else:
+        buffer = gather_data()
+    env = gym.make(args.task)
+    args.state_shape = env.observation_space.shape or env.observation_space.n
+    args.action_shape = env.action_space.shape or env.action_space.n
+    args.max_action = env.action_space.high[0]  # float
+    if args.reward_threshold is None:
+        # too low?
+        default_reward_threshold = {"Pendulum-v0": -1200, "Pendulum-v1": -1200}
+        args.reward_threshold = default_reward_threshold.get(
+            args.task, env.spec.reward_threshold
+        )
+
+    args.state_dim = args.state_shape[0]
+    args.action_dim = args.action_shape[0]
+    # test_envs = gym.make(args.task)
+    test_envs = DummyVectorEnv(
+        [lambda: gym.make(args.task) for _ in range(args.test_num)]
+    )
+    # seed
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    test_envs.seed(args.seed)
+
+    # model
+    # actor network
+    net_a = Net(
+        args.state_shape,
+        hidden_sizes=args.hidden_sizes,
+        device=args.device,
+    )
+    actor = Actor(
+        net_a,
+        action_shape=args.action_shape,
+        max_action=args.max_action,
+        device=args.device,
+    ).to(args.device)
+    actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
+
+    # critic network
+    net_c1 = Net(
+        args.state_shape,
+        args.action_shape,
+        hidden_sizes=args.hidden_sizes,
+        concat=True,
+        device=args.device,
+    )
+    net_c2 = Net(
+        args.state_shape,
+        args.action_shape,
+        hidden_sizes=args.hidden_sizes,
+        concat=True,
+        device=args.device,
+    )
+    critic1 = Critic(net_c1, device=args.device).to(args.device)
+    critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
+    critic2 = Critic(net_c2, device=args.device).to(args.device)
+    critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
+
+    policy = TD3BCPolicy(
+        actor,
+        actor_optim,
+        critic1,
+        critic1_optim,
+        critic2,
+        critic2_optim,
+        tau=args.tau,
+        gamma=args.gamma,
+        exploration_noise=GaussianNoise(sigma=args.exploration_noise),
+        policy_noise=args.policy_noise,
+        update_actor_freq=args.update_actor_freq,
+        noise_clip=args.noise_clip,
+        alpha=args.alpha,
+        estimation_step=args.n_step,
+        action_space=env.action_space,
+    )
+
+    # load a previous policy
+    if args.resume_path:
+        policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
+        print("Loaded agent from: ", args.resume_path)
+
+    # collector
+    # buffer has been gathered
+    # train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
+    test_collector = Collector(policy, test_envs)
+    # log
+    t0 = datetime.datetime.now().strftime("%m%d_%H%M%S")
+    log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_td3_bc'
+    log_path = os.path.join(args.logdir, args.task, 'td3_bc', log_file)
+    writer = SummaryWriter(log_path)
+    writer.add_text("args", str(args))
+    logger = TensorboardLogger(writer)
+
+    def save_best_fn(policy):
+        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
+
+    def stop_fn(mean_rewards):
+        return mean_rewards >= args.reward_threshold
+
+    def watch():
+        policy.load_state_dict(
+            torch.load(
+                os.path.join(log_path, 'policy.pth'), map_location=torch.device('cpu')
+            )
+        )
+        policy.eval()
+        collector = Collector(policy, env)
+        collector.collect(n_episode=1, render=1 / 35)
+
+    # trainer
+    trainer = OfflineTrainer(
+        policy,
+        buffer,
+        test_collector,
+        args.epoch,
+        args.step_per_epoch,
+        args.test_num,
+        args.batch_size,
+        save_best_fn=save_best_fn,
+        stop_fn=stop_fn,
+        logger=logger,
+    )
+
+    for epoch, epoch_stat, info in trainer:
+        print(f"Epoch: {epoch}")
+        print(epoch_stat)
+        print(info)
+
+    assert stop_fn(info["best_reward"])
+
+    # Let's watch its performance!
+    if __name__ == "__main__":
+        pprint.pprint(info)
+        env = gym.make(args.task)
+        policy.eval()
+        collector = Collector(policy, env)
+        result = collector.collect(n_episode=1, render=args.render)
+        rews, lens = result["rews"], result["lens"]
+        print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
+
+
+if __name__ == '__main__':
+    test_td3_bc()
diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py
index 8007bad9c..c8fa45e8e 100644
--- a/tianshou/policy/__init__.py
+++ b/tianshou/policy/__init__.py
@@ -23,6 +23,7 @@
 from tianshou.policy.imitation.base import ImitationPolicy
 from tianshou.policy.imitation.bcq import BCQPolicy
 from tianshou.policy.imitation.cql import CQLPolicy
+from tianshou.policy.imitation.td3_bc import TD3BCPolicy
 from tianshou.policy.imitation.discrete_bcq import DiscreteBCQPolicy
 from tianshou.policy.imitation.discrete_cql import DiscreteCQLPolicy
 from tianshou.policy.imitation.discrete_crr import DiscreteCRRPolicy
@@ -54,6 +55,7 @@
     "ImitationPolicy",
     "BCQPolicy",
     "CQLPolicy",
+    "TD3BCPolicy",
     "DiscreteBCQPolicy",
     "DiscreteCQLPolicy",
     "DiscreteCRRPolicy",
diff --git a/tianshou/policy/imitation/td3_bc.py b/tianshou/policy/imitation/td3_bc.py
new file mode 100644
index 000000000..9746659ad
--- /dev/null
+++ b/tianshou/policy/imitation/td3_bc.py
@@ -0,0 +1,107 @@
+from typing import Any, Dict, Optional
+
+import torch
+import torch.nn.functional as F
+
+from tianshou.data import Batch, to_torch_as
+from tianshou.exploration import BaseNoise, GaussianNoise
+from tianshou.policy import TD3Policy
+
+
+class TD3BCPolicy(TD3Policy):
+    """Implementation of TD3+BC. arXiv:2106.06860.
+
+    :param torch.nn.Module actor: the actor network following the rules in
+        :class:`~tianshou.policy.BasePolicy`. (s -> logits)
+    :param torch.optim.Optimizer actor_optim: the optimizer for actor network.
+    :param torch.nn.Module critic1: the first critic network. (s, a -> Q(s, a))
+    :param torch.optim.Optimizer critic1_optim: the optimizer for the first
+        critic network.
+    :param torch.nn.Module critic2: the second critic network. (s, a -> Q(s, a))
+    :param torch.optim.Optimizer critic2_optim: the optimizer for the second
+        critic network.
+    :param float tau: param for soft update of the target network. Default to 0.005.
+    :param float gamma: discount factor, in [0, 1]. Default to 0.99.
+    :param float exploration_noise: the exploration noise, add to the action.
+        Default to ``GaussianNoise(sigma=0.1)``
+    :param float policy_noise: the noise used in updating policy network.
+        Default to 0.2.
+    :param int update_actor_freq: the update frequency of actor network.
+        Default to 2.
+    :param float noise_clip: the clipping range used in updating policy network.
+        Default to 0.5.
+    :param float alpha: the value of alpha, which controls the weight for TD3 learning
+        relative to behavior cloning. Default to 2.5.
+    :param bool reward_normalization: normalize the reward to Normal(0, 1).
+        Default to False.
+    :param bool action_scaling: whether to map actions from range [-1, 1] to range
+        [action_spaces.low, action_spaces.high]. Default to True.
+    :param str action_bound_method: method to bound action to range [-1, 1], can be
+        either "clip" (for simply clipping the action) or empty string for no bounding.
+        Default to "clip".
+    :param Optional[gym.Space] action_space: env's action space, mandatory if you want
+        to use option "action_scaling" or "action_bound_method". Default to None.
+    :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+        optimizer in each policy.update(). Default to None (no lr_scheduler).
+
+    .. seealso::
+
+        Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
+        explanation.
+    """
+
+    def __init__(
+        self,
+        actor: torch.nn.Module,
+        actor_optim: torch.optim.Optimizer,
+        critic1: torch.nn.Module,
+        critic1_optim: torch.optim.Optimizer,
+        critic2: torch.nn.Module,
+        critic2_optim: torch.optim.Optimizer,
+        tau: float = 0.005,
+        gamma: float = 0.99,
+        exploration_noise: Optional[BaseNoise] = GaussianNoise(sigma=0.1),
+        policy_noise: float = 0.2,
+        update_actor_freq: int = 2,
+        noise_clip: float = 0.5,
+        alpha: float = 2.5,
+        reward_normalization: bool = False,
+        estimation_step: int = 1,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(
+            actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, tau,
+            gamma, exploration_noise, policy_noise, update_actor_freq, noise_clip,
+            reward_normalization, estimation_step, **kwargs
+        )
+        self._alpha = alpha
+
+    def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
+        # critic 1&2
+        td1, critic1_loss = self._mse_optimizer(
+            batch, self.critic1, self.critic1_optim
+        )
+        td2, critic2_loss = self._mse_optimizer(
+            batch, self.critic2, self.critic2_optim
+        )
+        batch.weight = (td1 + td2) / 2.0  # prio-buffer
+
+        # actor
+        if self._cnt % self._freq == 0:
+            act = self(batch, eps=0.0).act
+            q_value = self.critic1(batch.obs, act)
+            lmbda = self._alpha / q_value.abs().mean().detach()
+            actor_loss = -lmbda * q_value.mean() + F.mse_loss(
+                act, to_torch_as(batch.act, act)
+            )
+            self.actor_optim.zero_grad()
+            actor_loss.backward()
+            self._last = actor_loss.item()
+            self.actor_optim.step()
+            self.sync_weight()
+        self._cnt += 1
+        return {
+            "loss/actor": self._last,
+            "loss/critic1": critic1_loss.item(),
+            "loss/critic2": critic2_loss.item(),
+        }
diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py
index b4dcfb6be..aa3d1fc77 100644
--- a/tianshou/trainer/base.py
+++ b/tianshou/trainer/base.py
@@ -301,6 +301,10 @@ def __next__(self) -> Union[None, Tuple[int, Dict[str, Any], Dict[str, Any]]]:
             if t.n <= t.total and not self.stop_fn_flag:
                 t.update()
 
+        # for offline RL
+        if self.train_collector is None:
+            self.env_step = self.gradient_step * self.batch_size
+
         if not self.stop_fn_flag:
             self.logger.save_data(
                 self.epoch, self.env_step, self.gradient_step, self.save_checkpoint_fn
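For reference, the actor update in `TD3BCPolicy.learn()` above reduces to the following standalone computation. This is only a sketch with illustrative tensor names (not part of the patch): the Q term is rescaled by `lambda = alpha / mean(|Q|)` so that `alpha` sets the trade-off between reinforcement learning and behavior cloning regardless of the Q-value scale.

```python
import torch
import torch.nn.functional as F


def td3_bc_actor_loss(
    q_values: torch.Tensor,        # Q(s, pi(s)) from the first critic
    policy_actions: torch.Tensor,  # pi(s)
    dataset_actions: torch.Tensor, # actions a from the offline batch
    alpha: float = 2.5,
) -> torch.Tensor:
    # Detach the normalizer so it acts as a per-batch constant.
    lmbda = alpha / q_values.abs().mean().detach()
    # Maximize Q while staying close to the dataset actions (BC regularizer).
    return -lmbda * q_values.mean() + F.mse_loss(policy_actions, dataset_actions)
```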