Skip to content

Commit

Permalink
polish(nyz): polish dqn and ppo comments (opendilab#732)
Browse files Browse the repository at this point in the history
* polish(nyz) polish dqn and ppo comments

* polish(nyz) polish ddpg comments

* polish(nyz) polish impala comments

* polish(nyz) polish pdqn comments

* polish(nyz) polish r2d2 comments

* polish(nyz): polish policy mode comments

* polish(nyz): polish sac comments

* polish(nyz): polish cql/dt comments

* polish(nyz): complete dqn comments

* fix(nyz): fix discrete cql/sac unittest bugs

* polish(nyz): complete r2d2 comments

* polish(nyz): complete ddpg/bc comments

* polish(nyz): complete sac/cql comments

* polish(nyz): polish qmix/mdqn/pdqn comments

* polish(nyz): complete ppo/impala/dt comments
  • Loading branch information
PaParaZz1 authored Oct 31, 2023
1 parent c005205 commit 111bf24
Show file tree
Hide file tree
Showing 22 changed files with 3,391 additions and 1,383 deletions.
6 changes: 3 additions & 3 deletions ding/policy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .base_policy import Policy, CommandModePolicy, create_policy, get_policy_cls
from .common_utils import single_env_forward_wrapper, single_env_forward_wrapper_ttorch
from .common_utils import single_env_forward_wrapper, single_env_forward_wrapper_ttorch, default_preprocess_learn
from .dqn import DQNSTDIMPolicy, DQNPolicy
from .mdqn import MDQNPolicy
from .iqn import IQNPolicy
Expand All @@ -17,8 +17,8 @@
from .pg import PGPolicy
from .a2c import A2CPolicy
from .ppo import PPOPolicy, PPOPGPolicy, PPOOffPolicy
from .sac import SACPolicy, SACDiscretePolicy, SQILSACPolicy
from .cql import CQLPolicy, CQLDiscretePolicy
from .sac import SACPolicy, DiscreteSACPolicy, SQILSACPolicy
from .cql import CQLPolicy, DiscreteCQLPolicy
from .edac import EDACPolicy
from .impala import IMPALAPolicy
from .ngu import NGUPolicy
Expand Down
550 changes: 531 additions & 19 deletions ding/policy/base_policy.py

Large diffs are not rendered by default.

171 changes: 107 additions & 64 deletions ding/policy/bc.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@

@POLICY_REGISTRY.register('bc')
class BehaviourCloningPolicy(Policy):
"""
Overview:
Behaviour Cloning (BC) policy class, which supports both discrete and continuous action space. \
The policy is trained by supervised learning, and the data is a offline dataset collected by expert.
"""

config = dict(
type='bc',
Expand Down Expand Up @@ -52,18 +57,46 @@ class BehaviourCloningPolicy(Policy):
max=0.5,
),
),
eval=dict(),
other=dict(replay_buffer=dict(replay_buffer_size=10000, )),
eval=dict(), # for compatibility
)

def default_model(self) -> Tuple[str, List[str]]:
"""
Overview:
Return this algorithm default neural network model setting for demonstration. ``__init__`` method will \
automatically call this method to get the default model setting and create model.
Returns:
- model_info (:obj:`Tuple[str, List[str]]`): The registered model name and model's import_names.
.. note::
The user can define and use customized network model but must obey the same inferface definition indicated \
by import_names path. For example about discrete BC, its registered name is ``discrete_bc`` and the \
import_names is ``ding.model.template.bc``.
"""
if self._cfg.continuous:
return 'continuous_bc', ['ding.model.template.bc']
else:
return 'discrete_bc', ['ding.model.template.bc']

def _init_learn(self):
assert self._cfg.learn.optimizer in ['SGD', 'Adam']
def _init_learn(self) -> None:
"""
Overview:
Initialize the learn mode of policy, including related attributes and modules. For BC, it mainly contains \
optimizer, algorithm-specific arguments such as lr_scheduler, loss, etc. \
This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``.
.. note::
For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
and ``_load_state_dict_learn`` methods.
.. note::
For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.
.. note::
If you want to set some spacial member variables in ``_init_learn`` method, you'd better name them \
with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``.
"""
assert self._cfg.learn.optimizer in ['SGD', 'Adam'], self._cfg.learn.optimizer
if self._cfg.learn.optimizer == 'SGD':
self._optimizer = SGD(
self._model.parameters(),
Expand Down Expand Up @@ -103,20 +136,38 @@ def lr_scheduler_fn(epoch):
elif self._cfg.loss_type == 'mse_loss':
self._loss = nn.MSELoss()
else:
raise KeyError
raise KeyError("not support loss type: {}".format(self._cfg.loss_type))
else:
if not self._cfg.learn.ce_label_smooth:
self._loss = nn.CrossEntropyLoss()
else:
self._loss = LabelSmoothCELoss(0.1)

if self._cfg.learn.show_accuracy:
# accuracy statistics for debugging in discrete action space env, e.g. for gfootball
self.total_accuracy_in_dataset = []
self.action_accuracy_in_dataset = {k: [] for k in range(self._cfg.action_shape)}
def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
"""
Overview:
Policy forward function of learn mode (training policy and updating parameters). Forward means \
that the policy inputs some training batch data from the replay buffer and then returns the output \
result, including various training information such as loss and time.
Arguments:
- data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including a batch of \
training samples. For each element in list, the key of the dict is the name of data items and the \
value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or there dict/list \
combinations. In the ``_forward_learn`` method, data often need to first be stacked in the batch \
dimension by some utility functions such as ``default_preprocess_learn``. \
For BC, each element in list is a dict containing at least the following keys: ``obs``, ``action``.
Returns:
- info_dict (:obj:`Dict[str, Any]`): The information dict that indicated training result, which will be \
recorded in text log and tensorboard, values must be python scalar or a list of scalars. For the \
detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method.
def _forward_learn(self, data):
if not isinstance(data, dict):
.. note::
The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \
For the data type that not supported, the main reason is that the corresponding model does not support it. \
You can implement you own model rather than use the default model. For more information, please raise an \
issue in GitHub repo and we will continue to follow up.
"""
if isinstance(data, list):
data = default_collate(data)
if self._cuda:
data = to_device(data, self._device)
Expand All @@ -125,10 +176,10 @@ def _forward_learn(self, data):
obs, action = data['obs'], data['action'].squeeze()
if self._cfg.continuous:
if self._cfg.learn.tanh_mask:
'''
"""tanh_mask
We mask the action out of range of [tanh(-1),tanh(1)], model will learn information
and produce action in [-1,1]. So the action won't always converge to -1 or 1.
'''
"""
mu = self._eval_model.forward(data['obs'])['action']
bound = 1 - 2 / (math.exp(2) + 1) # tanh(1): (e-e**(-1))/(e+e**(-1))
mask = mu.ge(-bound) & mu.le(bound)
Expand Down Expand Up @@ -183,28 +234,57 @@ def _forward_learn(self, data):
'sync_time': sync_time,
}

def _monitor_vars_learn(self):
def _monitor_vars_learn(self) -> List[str]:
"""
Overview:
Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \
as text logger, tensorboard logger, will use these keys to save the corresponding data.
Returns:
- necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged.
"""
return ['cur_lr', 'total_loss', 'forward_time', 'backward_time', 'sync_time']

def _init_eval(self):
"""
Overview:
Initialize the eval mode of policy, including related attributes and modules. For BC, it contains the \
eval model to greedily select action with argmax q_value mechanism for discrete action space.
This method will be called in ``__init__`` method if ``eval`` field is in ``enable_field``.
.. note::
If you want to set some spacial member variables in ``_init_eval`` method, you'd better name them \
with prefix ``_eval_`` to avoid conflict with other modes, such as ``self._eval_attr1``.
"""
if self._cfg.continuous:
self._eval_model = model_wrap(self._model, wrapper_name='base')
else:
self._eval_model = model_wrap(self._model, wrapper_name='argmax_sample')
self._eval_model.reset()

def _forward_eval(self, data):
gfootball_flag = False
def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
"""
Overview:
Policy forward function of eval mode (evaluation policy performance by interacting with envs). Forward \
means that the policy gets some necessary data (mainly observation) from the envs and then returns the \
action to interact with the envs.
Arguments:
- data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
key of the dict is environment id and the value is the corresponding data of the env.
Returns:
- output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \
key of the dict is the same as the input data, i.e. environment id.
.. note::
The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \
For the data type that not supported, the main reason is that the corresponding model does not support it. \
You can implement you own model rather than use the default model. For more information, please raise an \
issue in GitHub repo and we will continue to follow up.
"""
tensor_input = isinstance(data, torch.Tensor)
if tensor_input:
data = default_collate(list(data))
else:
data_id = list(data.keys())
if data_id == ['processed_obs', 'raw_obs']:
# for gfootball
gfootball_flag = True
data = {0: data}
data_id = list(data.keys())
data = default_collate(list(data.values()))
if self._cuda:
data = to_device(data, self._device)
Expand All @@ -213,22 +293,20 @@ def _forward_eval(self, data):
output = self._eval_model.forward(data)
if self._cuda:
output = to_device(output, 'cpu')
if tensor_input or gfootball_flag:
if tensor_input:
return output
else:
output = default_decollate(output)
return {i: d for i, d in zip(data_id, output)}

def _init_collect(self) -> None:
r"""
"""
Overview:
Collect mode init method. Called by ``self.__init__``.
Init traj and unroll length, collect model.
Enable the eps_greedy_sample
BC policy uses offline dataset so it does not need to collect data. However, sometimes we need to use the \
trained BC policy to collect data for other purposes.
"""
self._unroll_len = self._cfg.collect.unroll_len
if self._cfg.continuous:
# self._collect_model = model_wrap(self._model, wrapper_name='base')
self._collect_model = model_wrap(
self._model,
wrapper_name='action_noise',
Expand All @@ -244,14 +322,6 @@ def _init_collect(self) -> None:
self._collect_model.reset()

def _forward_collect(self, data: Dict[int, Any], **kwargs) -> Dict[int, Any]:
r"""
Overview:
Forward function for collect mode with eps_greedy
Arguments:
- data (:obj:`dict`): Dict type data, including at least ['obs'].
Returns:
- data (:obj:`dict`): The collected data
"""
data_id = list(data.keys())
data = default_collate(list(data.values()))
if self._cuda:
Expand All @@ -268,43 +338,16 @@ def _forward_collect(self, data: Dict[int, Any], **kwargs) -> Dict[int, Any]:
output = default_decollate(output)
return {i: d for i, d in zip(data_id, output)}

def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
r"""
Overview:
Generate dict type transition data from inputs.
Arguments:
- obs (:obj:`Any`): Env observation
- model_output (:obj:`dict`): Output of collect model, including at least ['action']
- timestep (:obj:`namedtuple`): Output after env step, including at least ['obs', 'reward', 'done'] \
(here 'obs' indicates obs after env step).
Returns:
- transition (:obj:`dict`): Dict type transition data.
"""
def _process_transition(self, obs: Any, policy_output: dict, timestep: namedtuple) -> dict:
transition = {
'obs': obs,
'next_obs': timestep.obs,
'action': model_output['action'],
'action': policy_output['action'],
'reward': timestep.reward,
'done': timestep.done,
}
return EasyDict(transition)

def _get_train_sample(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Overview:
For a given trajectory(transitions, a list of transition) data, process it into a list of sample that \
can be used for training directly. A train sample can be a processed transition(DQN with nstep TD) \
or some continuous transitions(DRQN).
Arguments:
- data (:obj:`List[Dict[str, Any]`): The trajectory data(a list of transition), each element is the same \
format as the return value of ``self._process_transition`` method.
Returns:
- samples (:obj:`dict`): The list of training samples.
.. note::
We will vectorize ``process_transition`` and ``get_train_sample`` method in the following release version. \
And the user can customize the this data processing procecure by overriding this two methods and collector \
itself.
"""
data = get_nstep_return_data(data, 1, 1)
return get_train_sample(data, self._unroll_len)
13 changes: 6 additions & 7 deletions ding/policy/command_mode_policy_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from .td3 import TD3Policy
from .td3_vae import TD3VAEPolicy
from .td3_bc import TD3BCPolicy
from .sac import SACPolicy, SACDiscretePolicy
from .sac import SACPolicy, DiscreteSACPolicy, SQILSACPolicy
from .mbpolicy.mbsac import MBSACPolicy, STEVESACPolicy
from .mbpolicy.dreamer import DREAMERPolicy
from .qmix import QMIXPolicy
Expand All @@ -42,10 +42,9 @@
from .r2d3 import R2D3Policy

from .d4pg import D4PGPolicy
from .cql import CQLPolicy, CQLDiscretePolicy
from .cql import CQLPolicy, DiscreteCQLPolicy
from .dt import DTPolicy
from .pdqn import PDQNPolicy
from .sac import SQILSACPolicy
from .madqn import MADQNPolicy
from .bdq import BDQPolicy
from .bcq import BCQPolicy
Expand Down Expand Up @@ -316,8 +315,8 @@ class CQLCommandModePolicy(CQLPolicy, DummyCommandModePolicy):
pass


@POLICY_REGISTRY.register('cql_discrete_command')
class CQLDiscreteCommandModePolicy(CQLDiscretePolicy, EpsCommandModePolicy):
@POLICY_REGISTRY.register('discrete_cql_command')
class DiscreteCQLCommandModePolicy(DiscreteCQLPolicy, EpsCommandModePolicy):
pass


Expand Down Expand Up @@ -376,8 +375,8 @@ class PDQNCommandModePolicy(PDQNPolicy, EpsCommandModePolicy):
pass


@POLICY_REGISTRY.register('sac_discrete_command')
class SACDiscreteCommandModePolicy(SACDiscretePolicy, EpsCommandModePolicy):
@POLICY_REGISTRY.register('discrete_sac_command')
class DiscreteSACCommandModePolicy(DiscreteSACPolicy, EpsCommandModePolicy):
pass


Expand Down
Loading

0 comments on commit 111bf24

Please sign in to comment.