Skip to content

Commit

Permalink
feature(xyy):add HPT model and test_hpt
Browse files Browse the repository at this point in the history
  • Loading branch information
luodi-7 committed Dec 7, 2024
1 parent 7afb21d commit 366ef68
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 44 deletions.
59 changes: 22 additions & 37 deletions ding/model/template/hpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
from ding.model.common.head import DuelingHead
from ding.utils.registry_factory import MODEL_REGISTRY

INIT_CONST = 0.02


@MODEL_REGISTRY.register('hpt')
class HPT(nn.Module):
Expand All @@ -17,11 +15,10 @@ class HPT(nn.Module):
and the Dueling Head computes Q-values for discrete action spaces.
Interfaces:
__init__, forward
``__init__``, ``forward``
GitHub: [https://github.com/liruiw/HPT/blob/main/hpt/models/policy_stem.py]
.. note::
The model is designed to be flexible and can be adapted
for different input dimensions and action spaces.
"""

def __init__(self, state_dim: int, action_dim: int):
Expand Down Expand Up @@ -75,14 +72,15 @@ class PolicyStem(nn.Module):
to generate a set of latent tokens.
Interfaces:
__init__, init_cross_attn, compute_latent, forward
``__init__``, ``init_cross_attn``, ``compute_latent``, ``forward``
.. note::
This module is inspired by the implementation in the Perceiver IO model
and uses attention mechanisms for feature extraction.
"""
INIT_CONST = 0.02

def __init__(self, feature_dim: int = 8, token_dim: int = 128, **kwargs):
def __init__(self, feature_dim: int = 8, token_dim: int = 128):
"""
Overview:
Initialize the Policy Stem module with a feature extractor and cross-attention mechanism.
Expand All @@ -101,7 +99,7 @@ def __init__(self, feature_dim: int = 8, token_dim: int = 128, **kwargs):
def init_cross_attn(self):
"""Initialize cross-attention module and learnable tokens."""
token_num = 16
self.tokens = nn.Parameter(torch.randn(1, token_num, 128) * INIT_CONST)
self.tokens = nn.Parameter(torch.randn(1, token_num, 128) * self.INIT_CONST)
self.cross_attention = CrossAttention(128, heads=8, dim_head=64, dropout=0.1)

def compute_latent(self, x: torch.Tensor) -> torch.Tensor:
Expand Down Expand Up @@ -137,44 +135,31 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
return self.compute_latent(x)

def freeze(self):
"""Freeze the parameters of the model, preventing updates during training."""
for param in self.parameters():
param.requires_grad = False

def unfreeze(self):
"""Unfreeze the parameters of the model, allowing updates during training."""
for param in self.parameters():
param.requires_grad = True

def save(self, path: str):
"""Save the model state dictionary to a file."""
torch.save(self.state_dict(), path)

@property
def device(self):
def device(self) -> torch.device:
"""Returns the device on which the model parameters are located."""
return next(self.parameters()).device


class CrossAttention(nn.Module):
"""
Overview:
CrossAttention module used in the Perceiver IO model.
It computes the attention between the query and context tensors,
and returns the output tensor after applying attention.
Arguments:
query_dim (:obj:`int`): The dimension of the query input.
heads (:obj:`int`, optional): The number of attention heads. Defaults to 8.
dim_head (:obj:`int`, optional): The dimension of each attention head. Defaults to 64.
dropout (:obj:`float`, optional): The dropout probability. Defaults to 0.0.
"""

def __init__(self, query_dim: int, heads: int = 8, dim_head: int = 64, dropout: float = 0.0):
"""
Overview:
CrossAttention module used in the Perceiver IO model.
It computes the attention between the query and context tensors,
and returns the output tensor after applying attention.
Arguments:
- query_dim (:obj:`int`): The dimension of the query input.
- heads (:obj:`int`, optional): The number of attention heads. Defaults to 8.
- dim_head (:obj:`int`, optional): The dimension of each attention head. Defaults to 64.
- dropout (:obj:`float`, optional): The dropout probability. Defaults to 0.0.
"""
super().__init__()
inner_dim = dim_head * heads
context_dim = query_dim
# Scaling factor for the attention logits to ensure stable gradients.
self.scale = dim_head ** -0.5
self.heads = heads

Expand All @@ -193,7 +178,7 @@ def forward(self, x: torch.Tensor, context: torch.Tensor, mask: Optional[torch.T
Arguments:
- x (:obj:`torch.Tensor`): The query input tensor.
- context (:obj:`torch.Tensor`): The context input tensor.
- mask (:obj:`torch.Tensor`, optional): The attention mask tensor. Defaults to None.
- mask (:obj:`Optional[torch.Tensor]`): The attention mask tensor. Defaults to None.
Returns:
- torch.Tensor: The output tensor after applying attention.
Expand Down
4 changes: 2 additions & 2 deletions ding/model/template/tests/test_hpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from ding.torch_utils import is_differentiable

T, B = 3, 4
obs_shape = [4, (8, ), (4, 64, 64)]
act_shape = [3, (6, ), [2, 3, 6]]
obs_shape = [4, (8, )]
act_shape = [3, (6, )]
args = list(product(*[obs_shape, act_shape]))


Expand Down
10 changes: 6 additions & 4 deletions dizoo/box2d/lunarlander/config/lunarlander_hpt_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,10 @@
)
lunarlander_hpt_create_config = EasyDict(lunarlander_hpt_create_config)
create_config = lunarlander_hpt_create_config
"""
This is a configuration file for LunarLander environment with HPT (Hindsight Policy Transformer).
To run this config, please use the lunarlander_hpt_example.py script.
if __name__ == "__main__":
# or you can enter `ding -m serial -c lunarlander_hpt_config.py -s 0`
from ding.entry import serial_pipeline
serial_pipeline([main_config, create_config], seed=0)
Example:
python lunarlander_hpt_example.py
"""
1 change: 0 additions & 1 deletion dizoo/box2d/lunarlander/entry/lunarlander_dqn_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ def main():
task.use(ModelExchanger(model))

# Here is the part of single process pipeline.
evaluator_env.enable_save_replay(replay_path='./video')
task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
task.use(eps_greedy_handler(cfg))
task.use(StepCollector(cfg, policy.collect_mode, collector_env))
Expand Down

0 comments on commit 366ef68

Please sign in to comment.