Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature(xyy):add HPT model to implement PolicyStem+DuelingHead #841

Merged
merged 21 commits into from
Dec 8, 2024
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ding/model/template/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from .procedure_cloning import ProcedureCloningMCTS, ProcedureCloningBFS
from .bcq import BCQ
from .edac import EDAC
from .hpt import HPT
from .qgpo import QGPO
from .ebm import EBM, AutoregressiveEBM
from .havac import HAVAC
221 changes: 221 additions & 0 deletions ding/model/template/hpt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
from typing import Union, Optional, Dict, Callable, List
from einops import rearrange, repeat
import torch
import torch.nn as nn
from ding.model.common.head import DuelingHead
from ding.utils.registry_factory import MODEL_REGISTRY

INIT_CONST = 0.02
PaParaZz1 marked this conversation as resolved.
Show resolved Hide resolved


@MODEL_REGISTRY.register('hpt')
class HPT(nn.Module):
"""
Overview:
The HPT model for reinforcement learning, which consists of a Policy Stem and a Dueling Head.
PaParaZz1 marked this conversation as resolved.
Show resolved Hide resolved
The Policy Stem utilizes cross-attention to process input data,
and the Dueling Head computes Q-values for discrete action spaces.

Interfaces:
__init__, forward
PaParaZz1 marked this conversation as resolved.
Show resolved Hide resolved

.. note::
The model is designed to be flexible and can be adapted
for different input dimensions and action spaces.
"""

def __init__(self, state_dim: int, action_dim: int):
"""
Overview:
Initialize the HPT model, including the Policy Stem and the Dueling Head.

Arguments:
- state_dim (:obj:`int`): The dimension of the input state.
- action_dim (:obj:`int`): The dimension of the action space.

.. note::
The Policy Stem is initialized with cross-attention,
and the Dueling Head is set to process the resulting tokens.
"""
super(HPT, self).__init__()
# Initialise Policy Stem
self.policy_stem = PolicyStem()
self.policy_stem.init_cross_attn()

# Dueling Head, input is 16*128, output is action dimension
self.head = DuelingHead(hidden_size=16 * 128, output_size=action_dim)

def forward(self, x: torch.Tensor):
"""
Overview:
Forward pass of the HPT model.
Computes latent tokens from the input state and passes them through the Dueling Head.

Arguments:
- x (:obj:`torch.Tensor`): The input tensor representing the state.

Returns:
- q_values (:obj:`torch.Tensor`): The predicted Q-values for each action.
"""
# Policy Stem Outputs [B, 16, 128]
tokens = self.policy_stem.compute_latent(x)
# Flatten Operation
tokens_flattened = tokens.view(tokens.size(0), -1) # [B, 16*128]
# Enter to Dueling Head
q_values = self.head(tokens_flattened)
return q_values


class PolicyStem(nn.Module):
"""
Overview:
The Policy Stem module is responsible for processing input features
and generating latent tokens using a cross-attention mechanism.
It extracts features from the input and then applies cross-attention
to generate a set of latent tokens.

Interfaces:
__init__, init_cross_attn, compute_latent, forward

.. note::
This module is inspired by the implementation in the Perceiver IO model
and uses attention mechanisms for feature extraction.
"""

def __init__(self, feature_dim: int = 8, token_dim: int = 128, **kwargs):
PaParaZz1 marked this conversation as resolved.
Show resolved Hide resolved
"""
Overview:
Initialize the Policy Stem module with a feature extractor and cross-attention mechanism.

Arguments:
- feature_dim (:obj:`int`): The dimension of the input features.
- token_dim (:obj:`int`): The dimension of the latent tokens generated
by the attention mechanism.
"""
super().__init__()
# Initialise the feature extraction module
self.feature_extractor = nn.Linear(feature_dim, token_dim)
# Initialise CrossAttention
self.init_cross_attn()

def init_cross_attn(self):
"""Initialize cross-attention module and learnable tokens."""
token_num = 16
self.tokens = nn.Parameter(torch.randn(1, token_num, 128) * INIT_CONST)
self.cross_attention = CrossAttention(128, heads=8, dim_head=64, dropout=0.1)

def compute_latent(self, x: torch.Tensor) -> torch.Tensor:
"""
Overview:
Compute latent representations of the input data using
the feature extractor and cross-attention.

Arguments:
- x (:obj:`torch.Tensor`): Input tensor with shape [B, T, ..., F].

Returns:
- stem_tokens (:obj:`torch.Tensor`): Latent tokens with shape [B, 16, 128].
"""
# Using the Feature Extractor
stem_feat = self.feature_extractor(x)
stem_feat = stem_feat.reshape(stem_feat.shape[0], -1, stem_feat.shape[-1]) # (B, N, 128)
# Calculating latent tokens using CrossAttention
stem_tokens = self.tokens.repeat(len(stem_feat), 1, 1) # (B, 16, 128)
stem_tokens = self.cross_attention(stem_tokens, stem_feat) # (B, 16, 128)
return stem_tokens

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Overview:
Forward pass to compute latent tokens.

Arguments:
- x (:obj:`torch.Tensor`): Input tensor.

Returns:
- torch.Tensor: Latent tokens tensor.
"""
return self.compute_latent(x)

def freeze(self):
"""Freeze the parameters of the model, preventing updates during training."""
for param in self.parameters():
param.requires_grad = False

def unfreeze(self):
"""Unfreeze the parameters of the model, allowing updates during training."""
for param in self.parameters():
param.requires_grad = True

def save(self, path: str):
PaParaZz1 marked this conversation as resolved.
Show resolved Hide resolved
"""Save the model state dictionary to a file."""
torch.save(self.state_dict(), path)

@property
def device(self):
PaParaZz1 marked this conversation as resolved.
Show resolved Hide resolved
"""Returns the device on which the model parameters are located."""
return next(self.parameters()).device


class CrossAttention(nn.Module):
"""
Overview:
CrossAttention module used in the Perceiver IO model.
It computes the attention between the query and context tensors,
and returns the output tensor after applying attention.

Arguments:
PaParaZz1 marked this conversation as resolved.
Show resolved Hide resolved
query_dim (:obj:`int`): The dimension of the query input.
heads (:obj:`int`, optional): The number of attention heads. Defaults to 8.
dim_head (:obj:`int`, optional): The dimension of each attention head. Defaults to 64.
dropout (:obj:`float`, optional): The dropout probability. Defaults to 0.0.
"""

def __init__(self, query_dim: int, heads: int = 8, dim_head: int = 64, dropout: float = 0.0):
super().__init__()
inner_dim = dim_head * heads
context_dim = query_dim
self.scale = dim_head ** -0.5
PaParaZz1 marked this conversation as resolved.
Show resolved Hide resolved
self.heads = heads

self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)
self.to_out = nn.Linear(inner_dim, query_dim)

self.dropout = nn.Dropout(dropout)

def forward(self, x: torch.Tensor, context: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Overview:
Forward pass of the CrossAttention module.
Computes the attention between the query and context tensors.

Arguments:
- x (:obj:`torch.Tensor`): The query input tensor.
- context (:obj:`torch.Tensor`): The context input tensor.
- mask (:obj:`torch.Tensor`, optional): The attention mask tensor. Defaults to None.
PaParaZz1 marked this conversation as resolved.
Show resolved Hide resolved

Returns:
- torch.Tensor: The output tensor after applying attention.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add dtype

"""
h = self.heads
q = self.to_q(x)
k, v = self.to_kv(context).chunk(2, dim=-1)
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale

if mask is not None:
# fill in the masks with negative values
mask = rearrange(mask, "b ... -> b (...)")
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, "b j -> (b h) () j", h=h)
sim.masked_fill_(~mask, max_neg_value)

# attention, what we cannot get enough of
attn = sim.softmax(dim=-1)

# dropout
attn = self.dropout(attn)
out = torch.einsum("b i j, b j d -> b i d", attn, v)
out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
return self.to_out(out)
52 changes: 52 additions & 0 deletions ding/model/template/tests/test_hpt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import pytest
import torch
from itertools import product
from ding.model.template.hpt import HPT
from ding.torch_utils import is_differentiable

T, B = 3, 4
obs_shape = [4, (8, ), (4, 64, 64)]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove the group (4, 64, 64) here, the current HPT model you implemented here can't support the image input like (4, 64, 64). It need a CNN feature extractor, and the corresponding state_dim should be `(4, 64, 64) rather than the current scalar. We will left this part for future, thus you can remove this group now.

act_shape = [3, (6, ), [2, 3, 6]]
args = list(product(*[obs_shape, act_shape]))


@pytest.mark.unittest
class TestHPT:

def output_check(self, model, outputs):
if isinstance(outputs, torch.Tensor):
loss = outputs.sum()
elif isinstance(outputs, list):
loss = sum([t.sum() for t in outputs])
elif isinstance(outputs, dict):
loss = sum([v.sum() for v in outputs.values()])
is_differentiable(loss, model)

@pytest.mark.parametrize('obs_shape, act_shape', args)
def test_hpt(self, obs_shape, act_shape):
if isinstance(obs_shape, int):
inputs = torch.randn(B, obs_shape)
state_dim = obs_shape
else:
inputs = torch.randn(B, *obs_shape)
state_dim = obs_shape[0]

if isinstance(act_shape, int):
action_dim = act_shape
else:
action_dim = len(act_shape)

model = HPT(state_dim=state_dim, action_dim=action_dim)
outputs = model(inputs)

assert isinstance(outputs, torch.Tensor)

if isinstance(act_shape, int):
assert outputs.shape == (B, act_shape)
elif len(act_shape) == 1:
assert outputs.shape == (B, *act_shape)
else:
for i, s in enumerate(act_shape):
assert outputs[i].shape == (B, s)

self.output_check(model, outputs)
2 changes: 1 addition & 1 deletion dizoo/atari/config/serial/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
from dizoo.atari.config.serial.pong import *
from dizoo.atari.config.serial.qbert import *
from dizoo.atari.config.serial.spaceinvaders import *
from dizoo.atari.config.serial.asterix import *
from dizoo.atari.config.serial.asterix import *
2 changes: 1 addition & 1 deletion dizoo/atari/config/serial/phoenix/phoenix_fqf_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,4 +59,4 @@
if __name__ == '__main__':
# or you can enter `ding -m serial -c phoenix_fqf_config.py -s 0`
from ding.entry import serial_pipeline
serial_pipeline((main_config, create_config), seed=0)
serial_pipeline((main_config, create_config), seed=0)
2 changes: 1 addition & 1 deletion dizoo/atari/config/serial/pong/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .pong_dqn_config import pong_dqn_config, pong_dqn_create_config
from .pong_dqn_envpool_config import pong_dqn_envpool_config, pong_dqn_envpool_create_config
from .pong_dqfd_config import pong_dqfd_config, pong_dqfd_create_config
from .pong_dqfd_config import pong_dqfd_config, pong_dqfd_create_config
2 changes: 1 addition & 1 deletion dizoo/atari/config/serial/pong/pong_fqf_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,4 +59,4 @@
if __name__ == '__main__':
# or you can enter `ding -m serial -c pong_fqf_config.py -s 0`
from ding.entry import serial_pipeline
serial_pipeline((main_config, create_config), seed=0)
serial_pipeline((main_config, create_config), seed=0)
2 changes: 1 addition & 1 deletion dizoo/atari/config/serial/pong/pong_gail_dqn_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,4 +87,4 @@
max_env_step=1000000,
seed=0,
collect_data=True
)
)
2 changes: 1 addition & 1 deletion dizoo/atari/config/serial/qbert/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .qbert_dqn_config import qbert_dqn_config, qbert_dqn_create_config
from .qbert_dqfd_config import qbert_dqfd_config, qbert_dqfd_create_config
from .qbert_dqfd_config import qbert_dqfd_config, qbert_dqfd_create_config
2 changes: 1 addition & 1 deletion dizoo/atari/config/serial/qbert/qbert_fqf_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,4 +61,4 @@
if __name__ == '__main__':
# or you can enter `ding -m serial -c qbert_fqf_config.py -s 0`
from ding.entry import serial_pipeline
serial_pipeline((main_config, create_config), seed=0)
serial_pipeline((main_config, create_config), seed=0)
2 changes: 1 addition & 1 deletion dizoo/atari/config/serial/spaceinvaders/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .spaceinvaders_dqn_config import spaceinvaders_dqn_config, spaceinvaders_dqn_create_config
from .spaceinvaders_dqfd_config import spaceinvaders_dqfd_config, spaceinvaders_dqfd_create_config
from .spaceinvaders_dqfd_config import spaceinvaders_dqfd_config, spaceinvaders_dqfd_create_config
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,4 @@
if __name__ == '__main__':
# or you can enter `ding -m serial -c spaceinvaders_fqf_config.py -s 0`
from ding.entry import serial_pipeline
serial_pipeline((main_config, create_config), seed=0)
serial_pipeline((main_config, create_config), seed=0)
2 changes: 1 addition & 1 deletion dizoo/beergame/entry/beergame_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,4 @@ def main(cfg, seed=0):

if __name__ == "__main__":
beergame_ppo_config.exp_name = 'beergame_evaluate'
main(beergame_ppo_config)
main(beergame_ppo_config)
Loading
Loading