clarify updating state #224

Merged
merged 12 commits on Sep 22, 2020
1 change: 1 addition & 0 deletions README.md
@@ -27,6 +27,7 @@
- [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf)
- [Twin Delayed DDPG (TD3)](https://arxiv.org/pdf/1802.09477.pdf)
- [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf)
- [Discrete Soft Actor-Critic (SAC-Discrete)](https://arxiv.org/pdf/1910.07207.pdf)
- Vanilla Imitation Learning
- [Prioritized Experience Replay (PER)](https://arxiv.org/pdf/1511.05952.pdf)
- [Generalized Advantage Estimator (GAE)](https://arxiv.org/pdf/1506.02438.pdf)
1 change: 1 addition & 0 deletions docs/index.rst
@@ -18,6 +18,7 @@ Welcome to Tianshou!
* :class:`~tianshou.policy.PPOPolicy` `Proximal Policy Optimization <https://arxiv.org/pdf/1707.06347.pdf>`_
* :class:`~tianshou.policy.TD3Policy` `Twin Delayed DDPG <https://arxiv.org/pdf/1802.09477.pdf>`_
* :class:`~tianshou.policy.SACPolicy` `Soft Actor-Critic <https://arxiv.org/pdf/1812.05905.pdf>`_
* :class:`~tianshou.policy.DiscreteSACPolicy` `Discrete Soft Actor-Critic <https://arxiv.org/pdf/1910.07207.pdf>`_
* :class:`~tianshou.policy.ImitationPolicy` Imitation Learning
* :class:`~tianshou.data.PrioritizedReplayBuffer` `Prioritized Experience Replay <https://arxiv.org/pdf/1511.05952.pdf>`_
* :meth:`~tianshou.policy.BasePolicy.compute_episodic_return` `Generalized Advantage Estimator <https://arxiv.org/pdf/1506.02438.pdf>`_
28 changes: 28 additions & 0 deletions docs/tutorials/concepts.rst
@@ -75,6 +75,34 @@ A policy class typically has the following parts:
* :meth:`~tianshou.policy.BasePolicy.update`: the main interface for training. This function samples data from the buffer, pre-processes the data (such as computing the n-step return), learns from the data, and finally post-processes the data (such as updating the prioritized replay buffer); in short, ``process_fn -> learn -> post_process_fn``, as sketched below.
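
In code, these three steps chain roughly as follows (a conceptual sketch assuming an already constructed ``policy``, a non-empty ``buffer``, and a chosen ``sample_size``; the bookkeeping inside :meth:`~tianshou.policy.BasePolicy.update` is omitted)::

    batch, indice = buffer.sample(sample_size)        # draw a batch of transitions
    batch = policy.process_fn(batch, buffer, indice)  # e.g. compute the n-step return
    result = policy.learn(batch)                      # run the gradient update(s)
    policy.post_process_fn(batch, buffer, indice)     # e.g. update PER priorities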


.. _policy_state:

States for policy
^^^^^^^^^^^^^^^^^

During the training process, the policy is in one of two major states: the training state and the testing state. The training state can be further divided into the collecting state and the updating state.

The meaning of the training and testing states is straightforward: in the training state the agent interacts with the environment, collects training data, and updates the model; the testing state evaluates the performance of the current policy during the training process.

The collecting state is defined as interacting with environments and storing the collected training data into the buffer;
the updating state is defined as performing a model update via :meth:`~tianshou.policy.BasePolicy.update` during the training process.


To distinguish these states, check ``policy.training`` and ``policy.updating``. The flags are set as follows:

+-----------------------------------+-----------------+-----------------+
| State for policy                  | policy.training | policy.updating |
+================+==================+=================+=================+
|                | Collecting state | True            | False           |
| Training state +------------------+-----------------+-----------------+
|                | Updating state   | True            | True            |
+----------------+------------------+-----------------+-----------------+
| Testing state                     | False           | False           |
+-----------------------------------+-----------------+-----------------+

``policy.updating`` is useful for separating exploration from pure network updates: for example, DQN does not need epsilon-greedy exploration inside a network update, so ``policy.updating`` can be used to decide whether epsilon should be applied.
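
A rough sketch of reading these flags (``policy`` is any constructed :class:`~tianshou.policy.BasePolicy` subclass instance; the string labels are purely illustrative)::

    if policy.updating:
        state = "updating"    # inside policy.update(); no exploration needed
    elif policy.training:
        state = "collecting"  # interacting with envs and filling the buffer
    else:
        state = "testing"     # evaluating the current policy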


policy.forward
^^^^^^^^^^^^^^

6 changes: 5 additions & 1 deletion tianshou/data/collector.py
@@ -129,10 +129,14 @@ def reset(self) -> None:
obs_next={}, policy={})
self.reset_env()
self.reset_buffer()
self.collect_time, self.collect_step, self.collect_episode = 0.0, 0, 0
self.reset_stat()
if self._action_noise is not None:
self._action_noise.reset()

def reset_stat(self) -> None:
"""Reset the statistic variables."""
self.collect_time, self.collect_step, self.collect_episode = 0.0, 0, 0

def reset_buffer(self) -> None:
"""Reset the main data buffer."""
if self.buffer is not None:
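
A minimal usage sketch of the extracted helper (collector construction omitted; ``train_collector`` is an assumed name):

    # clear statistics left over from a previous run; the trainer changes
    # below now call this automatically at the start of training
    train_collector.reset_stat()
    assert train_collector.collect_step == 0
    assert train_collector.collect_episode == 0
    assert train_collector.collect_time == 0.0
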
14 changes: 14 additions & 0 deletions tianshou/policy/base.py
@@ -60,6 +60,7 @@ def __init__(
self.observation_space = observation_space
self.action_space = action_space
self.agent_id = 0
self.updating = False
self._compile()

def set_agent_id(self, agent_id: int) -> None:
@@ -118,6 +119,13 @@ def learn(

:return: A dict which includes loss and its corresponding label.

.. note::

In order to distinguish the collecting, updating and testing states, you
can check the policy state via ``self.training`` and ``self.updating``.
Please refer to :ref:`policy_state` for a more detailed explanation.

.. warning::

If you use ``torch.distributions.Normal`` and
@@ -146,6 +154,10 @@ def update(
"""Update the policy network and replay buffer.

It includes 3 function steps: process_fn, learn, and post_process_fn. In
addition, this function toggles the value of ``self.updating``: the flag is
False outside this function and True while :meth:`update` is executing.
Please refer to :ref:`policy_state` for a more detailed explanation.

:param int sample_size: 0 means it will extract all the data from the
buffer, otherwise it will sample a batch with given sample_size.
@@ -154,9 +166,11 @@
if buffer is None:
return {}
batch, indice = buffer.sample(sample_size)
self.updating = True
batch = self.process_fn(batch, buffer, indice)
result = self.learn(batch, **kwargs)
self.post_process_fn(batch, buffer, indice)
self.updating = False
return result

@staticmethod
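
A hedged sketch of the resulting flag lifecycle from the caller's side (``policy`` and ``buffer`` construction omitted; the sample size 64 is arbitrary):

    assert policy.updating is False     # collecting or testing state
    losses = policy.update(64, buffer)  # self.updating is True only inside this call
    assert policy.updating is False     # cleared again before update() returns
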
11 changes: 5 additions & 6 deletions tianshou/policy/modelfree/ddpg.py
@@ -103,9 +103,9 @@ def _target_q(
) -> torch.Tensor:
batch = buffer[indice] # batch.obs_next: s_{t+n}
with torch.no_grad():
target_q = self.critic_old(batch.obs_next, self(
batch, model='actor_old', input='obs_next',
explorating=False).act)
target_q = self.critic_old(
batch.obs_next,
self(batch, model='actor_old', input='obs_next').act)
return target_q

def process_fn(
@@ -124,7 +124,6 @@ def forward(
state: Optional[Union[dict, Batch, np.ndarray]] = None,
model: str = "actor",
input: str = "obs",
explorating: bool = True,
**kwargs: Any,
) -> Batch:
"""Compute action over the given batch data.
@@ -143,7 +142,7 @@
obs = batch[input]
actions, h = model(obs, state=state, info=batch.info)
actions += self._action_bias
if self._noise and self.training and explorating:
if self._noise and not self.updating:
actions += to_torch_as(self._noise(actions.shape), actions)
actions = actions.clamp(self._range[0], self._range[1])
return Batch(act=actions, state=h)
@@ -158,7 +157,7 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
self.critic_optim.zero_grad()
critic_loss.backward()
self.critic_optim.step()
action = self(batch, explorating=False).act
action = self(batch).act
actor_loss = -self.critic(batch.obs, action).mean()
self.actor_optim.zero_grad()
actor_loss.backward()
13 changes: 5 additions & 8 deletions tianshou/policy/modelfree/dqn.py
@@ -80,7 +80,7 @@ def _target_q(
batch = buffer[indice] # batch.obs_next: s_{t+n}
if self._target:
# target_Q = Q_old(s_, argmax(Q_new(s_, *)))
a = self(batch, input="obs_next", eps=0).act
a = self(batch, input="obs_next").act
with torch.no_grad():
target_q = self(
batch, model="model_old", input="obs_next"
@@ -110,7 +110,6 @@ def forward(
state: Optional[Union[dict, Batch, np.ndarray]] = None,
model: str = "model",
input: str = "obs",
eps: Optional[float] = None,
**kwargs: Any,
) -> Batch:
"""Compute action over the given batch data.
@@ -152,12 +151,10 @@
q_: np.ndarray = to_numpy(q)
q_[~obs.mask] = -np.inf
act = q_.argmax(axis=1)
# add eps to act
if eps is None:
eps = self.eps
if not np.isclose(eps, 0.0):
# add eps to act in training or testing phase
if not self.updating and not np.isclose(self.eps, 0.0):
for i in range(len(q)):
if np.random.rand() < eps:
if np.random.rand() < self.eps:
q_ = np.random.rand(*q[i].shape)
if hasattr(obs, "mask"):
q_[~obs.mask[i]] = -np.inf
@@ -169,7 +166,7 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
self.sync_weight()
self.optim.zero_grad()
weight = batch.pop("weight", 1.0)
q = self(batch, eps=0.0).logits
q = self(batch).logits
q = q[np.arange(len(q)), batch.act]
r = to_torch_as(batch.returns.flatten(), q)
td = r - q
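
With ``eps`` no longer threaded through ``forward``, DQN exploration is driven by the stored ``self.eps``, typically set from the trainer's ``train_fn``/``test_fn`` hooks; a hedged sketch (the values are arbitrary and the exact callback signature may differ between Tianshou versions):

    def train_fn(epoch):
        policy.set_eps(0.1)   # epsilon-greedy exploration while collecting

    def test_fn(epoch):
        policy.set_eps(0.05)  # smaller epsilon when evaluating the policy
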
7 changes: 3 additions & 4 deletions tianshou/policy/modelfree/sac.py
@@ -110,7 +110,6 @@ def forward( # type: ignore
batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
input: str = "obs",
explorating: bool = True,
**kwargs: Any,
) -> Batch:
obs = batch[input]
@@ -123,7 +122,7 @@
y = self._action_scale * (1 - y.pow(2)) + self.__eps
log_prob = dist.log_prob(x).unsqueeze(-1)
log_prob = log_prob - torch.log(y).sum(-1, keepdim=True)
if self._noise is not None and self.training and explorating:
if self._noise is not None and not self.updating:
act += to_torch_as(self._noise(act.shape), act)
act = act.clamp(self._range[0], self._range[1])
return Batch(
@@ -134,7 +133,7 @@ def _target_q(
) -> torch.Tensor:
batch = buffer[indice] # batch.obs: s_{t+n}
with torch.no_grad():
obs_next_result = self(batch, input='obs_next', explorating=False)
obs_next_result = self(batch, input='obs_next')
a_ = obs_next_result.act
batch.act = to_torch_as(batch.act, a_)
target_q = torch.min(
@@ -167,7 +166,7 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
batch.weight = (td1 + td2) / 2.0 # prio-buffer

# actor
obs_result = self(batch, explorating=False)
obs_result = self(batch)
a = obs_result.act
current_q1a = self.critic1(batch.obs, a).flatten()
current_q2a = self.critic2(batch.obs, a).flatten()
2 changes: 2 additions & 0 deletions tianshou/trainer/offpolicy.py
@@ -75,6 +75,8 @@ def offpolicy_trainer(
best_epoch, best_reward = -1, -1.0
stat: Dict[str, MovAvg] = {}
start_time = time.time()
train_collector.reset_stat()
test_collector.reset_stat()
test_in_train = test_in_train and train_collector.policy == policy
for epoch in range(1, 1 + max_epoch):
# train
2 changes: 2 additions & 0 deletions tianshou/trainer/onpolicy.py
@@ -75,6 +75,8 @@ def onpolicy_trainer(
best_epoch, best_reward = -1, -1.0
stat: Dict[str, MovAvg] = {}
start_time = time.time()
train_collector.reset_stat()
test_collector.reset_stat()
test_in_train = test_in_train and train_collector.policy == policy
for epoch in range(1, 1 + max_epoch):
# train
1 change: 1 addition & 0 deletions tianshou/utils/net/discrete.py
@@ -116,6 +116,7 @@ def conv2d_layers_size_out(
nn.ReLU(inplace=True),
nn.Flatten(),
nn.Linear(linear_input_size, 512),
nn.ReLU(inplace=True),
nn.Linear(512, np.prod(action_shape)),
)
