From 366ef6880eeeefbe9ff262ad96702f7b50ada0da Mon Sep 17 00:00:00 2001
From: xueyingyi <3245896298@qq.com>
Date: Sun, 8 Dec 2024 01:21:42 +0800
Subject: [PATCH] feature(xyy): add HPT model and test_hpt

---
 ding/model/template/hpt.py                    | 59 +++++++------------
 ding/model/template/tests/test_hpt.py         |  4 +-
 .../config/lunarlander_hpt_config.py          | 10 ++--
 .../entry/lunarlander_dqn_example.py          |  1 -
 4 files changed, 30 insertions(+), 44 deletions(-)

diff --git a/ding/model/template/hpt.py b/ding/model/template/hpt.py
index cf38a93f39..07a47bacea 100644
--- a/ding/model/template/hpt.py
+++ b/ding/model/template/hpt.py
@@ -5,8 +5,6 @@
 from ding.model.common.head import DuelingHead
 from ding.utils.registry_factory import MODEL_REGISTRY
 
-INIT_CONST = 0.02
-
 
 @MODEL_REGISTRY.register('hpt')
 class HPT(nn.Module):
@@ -17,11 +15,10 @@ class HPT(nn.Module):
     and the Dueling Head computes Q-values for discrete action spaces.
 
     Interfaces:
-        __init__, forward
+        ``__init__``, ``forward``
+
+    GitHub: https://github.com/liruiw/HPT/blob/main/hpt/models/policy_stem.py
 
-    .. note::
-        The model is designed to be flexible and can be adapted
-        for different input dimensions and action spaces.
     """
 
     def __init__(self, state_dim: int, action_dim: int):
@@ -75,14 +72,15 @@ class PolicyStem(nn.Module):
     to generate a set of latent tokens.
 
     Interfaces:
-        __init__, init_cross_attn, compute_latent, forward
+        ``__init__``, ``init_cross_attn``, ``compute_latent``, ``forward``
 
     .. note::
         This module is inspired by the implementation in the Perceiver IO model
         and uses attention mechanisms for feature extraction.
     """
+    INIT_CONST = 0.02
 
-    def __init__(self, feature_dim: int = 8, token_dim: int = 128, **kwargs):
+    def __init__(self, feature_dim: int = 8, token_dim: int = 128):
         """
         Overview:
             Initialize the Policy Stem module with a feature extractor and cross-attention mechanism.
@@ -101,7 +99,7 @@ def __init__(self, feature_dim: int = 8, token_dim: int = 128, **kwargs):
     def init_cross_attn(self):
         """Initialize cross-attention module and learnable tokens."""
         token_num = 16
-        self.tokens = nn.Parameter(torch.randn(1, token_num, 128) * INIT_CONST)
+        self.tokens = nn.Parameter(torch.randn(1, token_num, 128) * self.INIT_CONST)
         self.cross_attention = CrossAttention(128, heads=8, dim_head=64, dropout=0.1)
 
     def compute_latent(self, x: torch.Tensor) -> torch.Tensor:
@@ -137,44 +135,31 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         return self.compute_latent(x)
 
-    def freeze(self):
-        """Freeze the parameters of the model, preventing updates during training."""
-        for param in self.parameters():
-            param.requires_grad = False
-
-    def unfreeze(self):
-        """Unfreeze the parameters of the model, allowing updates during training."""
-        for param in self.parameters():
-            param.requires_grad = True
-
-    def save(self, path: str):
-        """Save the model state dictionary to a file."""
-        torch.save(self.state_dict(), path)
-
     @property
-    def device(self):
+    def device(self) -> torch.device:
         """Returns the device on which the model parameters are located."""
         return next(self.parameters()).device
 
 
 class CrossAttention(nn.Module):
-    """
-    Overview:
-        CrossAttention module used in the Perceiver IO model.
-        It computes the attention between the query and context tensors,
-        and returns the output tensor after applying attention.
-
-    Arguments:
-        query_dim (:obj:`int`): The dimension of the query input.
-        heads (:obj:`int`, optional): The number of attention heads. Defaults to 8.
-        dim_head (:obj:`int`, optional): The dimension of each attention head. Defaults to 64.
-        dropout (:obj:`float`, optional): The dropout probability. Defaults to 0.0.
-    """
 
     def __init__(self, query_dim: int, heads: int = 8, dim_head: int = 64, dropout: float = 0.0):
+        """
+        Overview:
+            CrossAttention module used in the Perceiver IO model.
+            It computes the attention between the query and context tensors,
+            and returns the output tensor after applying attention.
+
+        Arguments:
+            - query_dim (:obj:`int`): The dimension of the query input.
+            - heads (:obj:`int`, optional): The number of attention heads. Defaults to 8.
+            - dim_head (:obj:`int`, optional): The dimension of each attention head. Defaults to 64.
+            - dropout (:obj:`float`, optional): The dropout probability. Defaults to 0.0.
+        """
         super().__init__()
         inner_dim = dim_head * heads
         context_dim = query_dim
+        # Scaling factor for the attention logits to ensure stable gradients.
         self.scale = dim_head ** -0.5
         self.heads = heads
 
@@ -193,7 +178,7 @@ def forward(self, x: torch.Tensor, context: torch.Tensor, mask: Optional[torch.T
         Arguments:
             - x (:obj:`torch.Tensor`): The query input tensor.
             - context (:obj:`torch.Tensor`): The context input tensor.
-            - mask (:obj:`torch.Tensor`, optional): The attention mask tensor. Defaults to None.
+            - mask (:obj:`Optional[torch.Tensor]`): The attention mask tensor. Defaults to None.
 
         Returns:
             - torch.Tensor: The output tensor after applying attention.

diff --git a/ding/model/template/tests/test_hpt.py b/ding/model/template/tests/test_hpt.py
index bfb7844f32..e4d9729d59 100644
--- a/ding/model/template/tests/test_hpt.py
+++ b/ding/model/template/tests/test_hpt.py
@@ -5,8 +5,8 @@
 from ding.torch_utils import is_differentiable
 
 T, B = 3, 4
-obs_shape = [4, (8, ), (4, 64, 64)]
-act_shape = [3, (6, ), [2, 3, 6]]
+obs_shape = [4, (8, )]
+act_shape = [3, (6, )]
 args = list(product(*[obs_shape, act_shape]))
 

diff --git a/dizoo/box2d/lunarlander/config/lunarlander_hpt_config.py b/dizoo/box2d/lunarlander/config/lunarlander_hpt_config.py
index fea3aca6bc..b58166d566 100644
--- a/dizoo/box2d/lunarlander/config/lunarlander_hpt_config.py
+++ b/dizoo/box2d/lunarlander/config/lunarlander_hpt_config.py
@@ -69,8 +69,10 @@
 )
 lunarlander_hpt_create_config = EasyDict(lunarlander_hpt_create_config)
 create_config = lunarlander_hpt_create_config
+"""
+This is a configuration file for the LunarLander environment with HPT (Heterogeneous Pre-trained Transformer).
+To run this config, please use the lunarlander_hpt_example.py script.
 
-if __name__ == "__main__":
-    # or you can enter `ding -m serial -c lunarlander_hpt_config.py -s 0`
-    from ding.entry import serial_pipeline
-    serial_pipeline([main_config, create_config], seed=0)
+Example:
+    python lunarlander_hpt_example.py
+"""

diff --git a/dizoo/box2d/lunarlander/entry/lunarlander_dqn_example.py b/dizoo/box2d/lunarlander/entry/lunarlander_dqn_example.py
index 048554c310..ae9b30d1fd 100644
--- a/dizoo/box2d/lunarlander/entry/lunarlander_dqn_example.py
+++ b/dizoo/box2d/lunarlander/entry/lunarlander_dqn_example.py
@@ -59,7 +59,6 @@ def main():
     task.use(ModelExchanger(model))
 
     # Here is the part of single process pipeline.
-    evaluator_env.enable_save_replay(replay_path='./video')
     task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
     task.use(eps_greedy_handler(cfg))
    task.use(StepCollector(cfg, policy.collect_mode, collector_env))
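
A minimal smoke-test sketch for the HPT model added in hpt.py. It assumes HPT.forward accepts a flat observation tensor of shape (B, state_dim) and that the Dueling Head returns either a raw Q-value tensor or, following the usual DI-engine head convention, a dict with a 'logit' entry of shape (B, action_dim); the batch and dimension sizes below are illustrative, not taken from the lunarlander config.

    # Hedged usage sketch: exercise the HPT model end to end on random data.
    import torch
    from ding.model.template.hpt import HPT

    state_dim, action_dim, batch_size = 8, 4, 32  # assumed sizes for illustration
    model = HPT(state_dim=state_dim, action_dim=action_dim)

    obs = torch.randn(batch_size, state_dim)  # fake observation batch
    out = model(obs)
    # DuelingHead in DI-engine typically returns {'logit': tensor}; handle both cases.
    q_values = out['logit'] if isinstance(out, dict) else out
    assert q_values.shape == (batch_size, action_dim)

Because PolicyStem uses 16 learnable tokens of width 128 (see init_cross_attn), the latent produced by the stem should be (batch_size, 16, 128) before the head flattens it, which is a quick way to sanity-check the cross-attention path in isolation.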