diff --git a/README.md b/README.md
index f65837e1c..c68fbf4ad 100644
--- a/README.md
+++ b/README.md
@@ -27,6 +27,7 @@
 - [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf)
 - [Twin Delayed DDPG (TD3)](https://arxiv.org/pdf/1802.09477.pdf)
 - [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf)
+- [Discrete Soft Actor-Critic (SAC-Discrete)](https://arxiv.org/pdf/1910.07207.pdf)
 - Vanilla Imitation Learning
 - [Prioritized Experience Replay (PER)](https://arxiv.org/pdf/1511.05952.pdf)
 - [Generalized Advantage Estimator (GAE)](https://arxiv.org/pdf/1506.02438.pdf)
diff --git a/docs/index.rst b/docs/index.rst
index 587dc7e5c..454997ef7 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -18,6 +18,7 @@ Welcome to Tianshou!
 * :class:`~tianshou.policy.PPOPolicy` `Proximal Policy Optimization <https://arxiv.org/pdf/1707.06347.pdf>`_
 * :class:`~tianshou.policy.TD3Policy` `Twin Delayed DDPG <https://arxiv.org/pdf/1802.09477.pdf>`_
 * :class:`~tianshou.policy.SACPolicy` `Soft Actor-Critic <https://arxiv.org/pdf/1812.05905.pdf>`_
+* :class:`~tianshou.policy.DiscreteSACPolicy` `Discrete Soft Actor-Critic <https://arxiv.org/pdf/1910.07207.pdf>`_
 * :class:`~tianshou.policy.ImitationPolicy` Imitation Learning
 * :class:`~tianshou.data.PrioritizedReplayBuffer` `Prioritized Experience Replay <https://arxiv.org/pdf/1511.05952.pdf>`_
 * :meth:`~tianshou.policy.BasePolicy.compute_episodic_return` `Generalized Advantage Estimator <https://arxiv.org/pdf/1506.02438.pdf>`_
diff --git a/docs/tutorials/concepts.rst b/docs/tutorials/concepts.rst
index d7f5971d2..7d644587d 100644
--- a/docs/tutorials/concepts.rst
+++ b/docs/tutorials/concepts.rst
@@ -75,6 +75,34 @@
 * :meth:`~tianshou.policy.BasePolicy.update`: the main interface for training. This function samples data from buffer, pre-process data (such as computing n-step return), learn with the data, and finally post-process the data (such as updating prioritized replay buffer); in short, ``process_fn -> learn -> post_process_fn``.
 
 
+.. _policy_state:
+
+States for policy
+^^^^^^^^^^^^^^^^^
+
+During the training process, the policy has two main states: the training state and the testing state. The training state can be further divided into the collecting state and the updating state.
+
+The meaning of the training and testing states is straightforward: in the training state, the agent interacts with the environment, collects training data, and updates its model; in the testing state, the performance of the current policy is evaluated during the training process.
+
+The collecting state is defined as interacting with environments and collecting training data into the buffer;
+the updating state is defined as performing a model update through :meth:`~tianshou.policy.BasePolicy.update` during the training process.
+
+
+To distinguish these states, check ``policy.training`` and ``policy.updating``. The states are set as follows:
+
++-----------------------------------+-----------------+-----------------+
+|          State for policy         | policy.training | policy.updating |
++================+==================+=================+=================+
+|                | Collecting state |      True       |      False      |
+| Training state +------------------+-----------------+-----------------+
+|                | Updating state   |      True       |      True       |
++----------------+------------------+-----------------+-----------------+
+|           Testing state           |      False      |      False      |
++-----------------------------------+-----------------+-----------------+
+
+``policy.updating`` is helpful for distinguishing different exploration conditions: for example, DQN does not need epsilon-greedy exploration during a pure network update, so ``policy.updating`` can be used to decide when to apply epsilon.
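+
+As a rough sketch (not actual Tianshou code), a custom policy could use these flags to skip exploration noise during a pure network update. The ``MyPolicy`` class, its ``eps`` attribute, and the assumption that ``self.model`` returns a ``(logits, hidden_state)`` pair are purely illustrative::
+
+    import numpy as np
+    import torch
+
+    from tianshou.data import Batch, to_numpy
+    from tianshou.policy import BasePolicy
+
+    class MyPolicy(BasePolicy):
+        """Illustrative epsilon-greedy policy gated by self.updating."""
+
+        def __init__(self, model: torch.nn.Module, eps: float = 0.1, **kwargs):
+            super().__init__(**kwargs)
+            self.model = model
+            self.eps = eps
+
+        def forward(self, batch, state=None, **kwargs):
+            logits, h = self.model(batch.obs, state=state)
+            act = to_numpy(logits.argmax(dim=1))
+            # add exploration noise only outside of a pure network update
+            if not self.updating and not np.isclose(self.eps, 0.0):
+                explore = np.random.rand(len(act)) < self.eps
+                act[explore] = np.random.randint(
+                    logits.shape[1], size=int(explore.sum()))
+            return Batch(logits=logits, act=act, state=h)
+
+        def learn(self, batch, **kwargs):
+            # called inside update(), where self.updating is True,
+            # so self(batch) here would add no exploration noise
+            return {}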
+
+
 
 policy.forward
 ^^^^^^^^^^^^^^
 
diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py
index 531553846..0d755fbc2 100644
--- a/tianshou/data/collector.py
+++ b/tianshou/data/collector.py
@@ -129,10 +129,14 @@ def reset(self) -> None:
             obs_next={}, policy={})
         self.reset_env()
         self.reset_buffer()
-        self.collect_time, self.collect_step, self.collect_episode = 0.0, 0, 0
+        self.reset_stat()
         if self._action_noise is not None:
             self._action_noise.reset()
 
+    def reset_stat(self) -> None:
+        """Reset the statistic variables."""
+        self.collect_time, self.collect_step, self.collect_episode = 0.0, 0, 0
+
     def reset_buffer(self) -> None:
         """Reset the main data buffer."""
         if self.buffer is not None:
diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py
index 7785b8f38..809751fe7 100644
--- a/tianshou/policy/base.py
+++ b/tianshou/policy/base.py
@@ -60,6 +60,7 @@ def __init__(
         self.observation_space = observation_space
         self.action_space = action_space
         self.agent_id = 0
+        self.updating = False
         self._compile()
 
     def set_agent_id(self, agent_id: int) -> None:
@@ -118,6 +119,13 @@ def learn(
 
         :return: A dict which includes loss and its corresponding label.
 
+        .. note::
+
+            To distinguish between the collecting state, updating state, and
+            testing state, check the policy state via ``self.training`` and
+            ``self.updating``. Please refer to :ref:`policy_state` for a more
+            detailed explanation.
+
         .. warning::
 
             If you use ``torch.distributions.Normal`` and
@@ -146,6 +154,10 @@ def update(
         """Update the policy network and replay buffer.
 
         It includes 3 function steps: process_fn, learn, and post_process_fn.
+        In addition, this function will change the value of ``self.updating``:
+        it is False outside of this function and True while this function is
+        executing. Please refer to :ref:`policy_state` for a more detailed
+        explanation.
 
         :param int sample_size: 0 means it will extract all the data from the
             buffer, otherwise it will sample a batch with given sample_size.
@@ -154,9 +166,11 @@
         if buffer is None:
             return {}
         batch, indice = buffer.sample(sample_size)
+        self.updating = True
         batch = self.process_fn(batch, buffer, indice)
         result = self.learn(batch, **kwargs)
         self.post_process_fn(batch, buffer, indice)
+        self.updating = False
         return result
 
     @staticmethod
diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py
index 81bf7f6d7..ab28b6b7d 100644
--- a/tianshou/policy/modelfree/ddpg.py
+++ b/tianshou/policy/modelfree/ddpg.py
@@ -103,9 +103,9 @@ def _target_q(
     ) -> torch.Tensor:
         batch = buffer[indice]  # batch.obs_next: s_{t+n}
         with torch.no_grad():
-            target_q = self.critic_old(batch.obs_next, self(
-                batch, model='actor_old', input='obs_next',
-                explorating=False).act)
+            target_q = self.critic_old(
+                batch.obs_next,
+                self(batch, model='actor_old', input='obs_next').act)
         return target_q
 
     def process_fn(
@@ -124,7 +124,6 @@ def forward(
         self,
         batch: Batch,
         state: Optional[Union[dict, Batch, np.ndarray]] = None,
         model: str = "actor",
         input: str = "obs",
-        explorating: bool = True,
         **kwargs: Any,
     ) -> Batch:
         """Compute action over the given batch data.
@@ -143,7 +142,7 @@ def forward(
         obs = batch[input]
         actions, h = model(obs, state=state, info=batch.info)
         actions += self._action_bias
-        if self._noise and self.training and explorating:
+        if self._noise and not self.updating:
             actions += to_torch_as(self._noise(actions.shape), actions)
         actions = actions.clamp(self._range[0], self._range[1])
         return Batch(act=actions, state=h)
@@ -158,7 +157,7 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
         self.critic_optim.zero_grad()
         critic_loss.backward()
         self.critic_optim.step()
-        action = self(batch, explorating=False).act
+        action = self(batch).act
         actor_loss = -self.critic(batch.obs, action).mean()
         self.actor_optim.zero_grad()
         actor_loss.backward()
diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py
index 71d16f6b6..91cca6139 100644
--- a/tianshou/policy/modelfree/dqn.py
+++ b/tianshou/policy/modelfree/dqn.py
@@ -80,7 +80,7 @@ def _target_q(
         batch = buffer[indice]  # batch.obs_next: s_{t+n}
         if self._target:
             # target_Q = Q_old(s_, argmax(Q_new(s_, *)))
-            a = self(batch, input="obs_next", eps=0).act
+            a = self(batch, input="obs_next").act
             with torch.no_grad():
                 target_q = self(
                     batch, model="model_old", input="obs_next"
@@ -110,7 +110,6 @@ def forward(
         self,
         batch: Batch,
         state: Optional[Union[dict, Batch, np.ndarray]] = None,
         model: str = "model",
         input: str = "obs",
-        eps: Optional[float] = None,
         **kwargs: Any,
     ) -> Batch:
         """Compute action over the given batch data.
@@ -152,12 +151,10 @@ def forward(
             q_: np.ndarray = to_numpy(q)
             q_[~obs.mask] = -np.inf
             act = q_.argmax(axis=1)
-        # add eps to act
-        if eps is None:
-            eps = self.eps
-        if not np.isclose(eps, 0.0):
+        # add eps to act in training or testing phase
+        if not self.updating and not np.isclose(self.eps, 0.0):
             for i in range(len(q)):
-                if np.random.rand() < eps:
+                if np.random.rand() < self.eps:
                     q_ = np.random.rand(*q[i].shape)
                     if hasattr(obs, "mask"):
                         q_[~obs.mask[i]] = -np.inf
@@ -169,7 +166,7 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
             self.sync_weight()
         self.optim.zero_grad()
         weight = batch.pop("weight", 1.0)
-        q = self(batch, eps=0.0).logits
+        q = self(batch).logits
         q = q[np.arange(len(q)), batch.act]
         r = to_torch_as(batch.returns.flatten(), q)
         td = r - q
diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py
index e44f8a124..8d1d72369 100644
--- a/tianshou/policy/modelfree/sac.py
+++ b/tianshou/policy/modelfree/sac.py
@@ -110,7 +110,6 @@ def forward(  # type: ignore
         batch: Batch,
         state: Optional[Union[dict, Batch, np.ndarray]] = None,
         input: str = "obs",
-        explorating: bool = True,
         **kwargs: Any,
     ) -> Batch:
         obs = batch[input]
@@ -123,7 +122,7 @@ def forward(  # type: ignore
         y = self._action_scale * (1 - y.pow(2)) + self.__eps
         log_prob = dist.log_prob(x).unsqueeze(-1)
         log_prob = log_prob - torch.log(y).sum(-1, keepdim=True)
-        if self._noise is not None and self.training and explorating:
+        if self._noise is not None and not self.updating:
             act += to_torch_as(self._noise(act.shape), act)
         act = act.clamp(self._range[0], self._range[1])
         return Batch(
@@ -134,7 +133,7 @@ def _target_q(
     ) -> torch.Tensor:
         batch = buffer[indice]  # batch.obs: s_{t+n}
         with torch.no_grad():
-            obs_next_result = self(batch, input='obs_next', explorating=False)
+            obs_next_result = self(batch, input='obs_next')
             a_ = obs_next_result.act
             batch.act = to_torch_as(batch.act, a_)
             target_q = torch.min(
@@ -167,7 +166,7 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
         batch.weight = (td1 + td2) / 2.0  # prio-buffer
 
         # actor
-        obs_result = self(batch, explorating=False)
+        obs_result = self(batch)
         a = obs_result.act
         current_q1a = self.critic1(batch.obs, a).flatten()
         current_q2a = self.critic2(batch.obs, a).flatten()
diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py
index 01e6530bb..170fd6835 100644
--- a/tianshou/trainer/offpolicy.py
+++ b/tianshou/trainer/offpolicy.py
@@ -75,6 +75,8 @@ def offpolicy_trainer(
     best_epoch, best_reward = -1, -1.0
     stat: Dict[str, MovAvg] = {}
     start_time = time.time()
+    train_collector.reset_stat()
+    test_collector.reset_stat()
     test_in_train = test_in_train and train_collector.policy == policy
     for epoch in range(1, 1 + max_epoch):
         # train
diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py
index 37f427826..877c6348c 100644
--- a/tianshou/trainer/onpolicy.py
+++ b/tianshou/trainer/onpolicy.py
@@ -75,6 +75,8 @@ def onpolicy_trainer(
     best_epoch, best_reward = -1, -1.0
     stat: Dict[str, MovAvg] = {}
     start_time = time.time()
+    train_collector.reset_stat()
+    test_collector.reset_stat()
     test_in_train = test_in_train and train_collector.policy == policy
     for epoch in range(1, 1 + max_epoch):
         # train
diff --git a/tianshou/utils/net/discrete.py b/tianshou/utils/net/discrete.py
index 547593f66..0a8452ced 100644
--- a/tianshou/utils/net/discrete.py
+++ b/tianshou/utils/net/discrete.py
@@ -116,6 +116,7 @@ def conv2d_layers_size_out(
             nn.ReLU(inplace=True),
             nn.Flatten(),
             nn.Linear(linear_input_size, 512),
+            nn.ReLU(inplace=True),
             nn.Linear(512, np.prod(action_shape)),
         )
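
For reference, a minimal usage sketch of how the pieces above fit together at the call site: resetting collector statistics before training (as the trainers now do) and triggering an update, during which ``policy.updating`` is True. The environment, network, and hyper-parameters below are placeholders chosen for illustration, and the API details follow the Tianshou version this patch targets:

    import gym
    import torch

    from tianshou.data import Collector, ReplayBuffer
    from tianshou.policy import DQNPolicy


    class QNet(torch.nn.Module):
        """Tiny Q-network following the (logits, state) model convention."""

        def __init__(self, state_dim: int, action_dim: int):
            super().__init__()
            self.net = torch.nn.Sequential(
                torch.nn.Linear(state_dim, 128), torch.nn.ReLU(),
                torch.nn.Linear(128, action_dim))

        def forward(self, obs, state=None, info={}):
            obs = torch.as_tensor(obs, dtype=torch.float32)
            return self.net(obs), state


    env = gym.make("CartPole-v0")
    net = QNet(env.observation_space.shape[0], env.action_space.n)
    optim = torch.optim.Adam(net.parameters(), lr=1e-3)
    policy = DQNPolicy(net, optim, discount_factor=0.99, estimation_step=3)

    buffer = ReplayBuffer(size=20000)
    train_collector = Collector(policy, env, buffer)
    train_collector.reset_stat()  # zero collect_time / collect_step / collect_episode

    policy.set_eps(0.1)                   # exploration used while collecting
    train_collector.collect(n_step=1000)  # policy.updating is False here
    result = policy.update(64, train_collector.buffer)
    # inside update(): updating=True -> process_fn -> learn -> post_process_fn,
    # then updating=False again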