From 96672115ac28cfde6c0872708f8c8227bca4e073 Mon Sep 17 00:00:00 2001 From: xueyingyi <3245896298@qq.com> Date: Wed, 27 Nov 2024 16:38:24 +0800 Subject: [PATCH 01/21] add hpt model and corresponding examples. --- ding/model/template/__init__.py | 3 + ding/model/template/hpt.py | 25 ++++ ding/model/template/policy_stem.py | 133 ++++++++++++++++++ .../config/lunarlander_hpt_config.py | 77 ++++++++++ .../entry/lunarlander_dqn_example.py | 18 ++- .../entry/lunarlander_hpt_example.py | 91 ++++++++++++ 6 files changed, 346 insertions(+), 1 deletion(-) create mode 100644 ding/model/template/hpt.py create mode 100644 ding/model/template/policy_stem.py create mode 100644 dizoo/box2d/lunarlander/config/lunarlander_hpt_config.py create mode 100644 dizoo/box2d/lunarlander/entry/lunarlander_hpt_example.py diff --git a/ding/model/template/__init__.py b/ding/model/template/__init__.py index 8e902f1504..d121d7f30e 100755 --- a/ding/model/template/__init__.py +++ b/ding/model/template/__init__.py @@ -24,8 +24,11 @@ from .vae import VanillaVAE from .decision_transformer import DecisionTransformer from .procedure_cloning import ProcedureCloningMCTS, ProcedureCloningBFS +from .hpt import HPT + from .bcq import BCQ from .edac import EDAC from .qgpo import QGPO from .ebm import EBM, AutoregressiveEBM from .havac import HAVAC +from .policy_stem import PolicyStem diff --git a/ding/model/template/hpt.py b/ding/model/template/hpt.py new file mode 100644 index 0000000000..ba550a8549 --- /dev/null +++ b/ding/model/template/hpt.py @@ -0,0 +1,25 @@ +from typing import Union, Optional, Dict, Callable, List +import torch +import torch.nn as nn + +from ding.model.common.head import DuelingHead +from ding.utils.registry_factory import MODEL_REGISTRY +from ding.model.template.policy_stem import PolicyStem +@MODEL_REGISTRY.register('hpt') +class HPT(nn.Module): + def __init__(self, state_dim, action_dim): + super(HPT, self).__init__() + # 初始化 Policy Stem + self.policy_stem = PolicyStem() + 
self.policy_stem.init_cross_attn() + + # Dueling Head,输入为 16*128,输出为动作维度 + self.head = DuelingHead(hidden_size=16*128, output_size=action_dim) + def forward(self, x): + # Policy Stem 输出 [B, 16, 128] + tokens = self.policy_stem.compute_latent(x) + # Flatten 操作 + tokens_flattened = tokens.view(tokens.size(0), -1) # [B, 16*128] + # 输入到 Dueling Head + q_values = self.head(tokens_flattened) + return q_values \ No newline at end of file diff --git a/ding/model/template/policy_stem.py b/ding/model/template/policy_stem.py new file mode 100644 index 0000000000..81963cdbc4 --- /dev/null +++ b/ding/model/template/policy_stem.py @@ -0,0 +1,133 @@ +# -------------------------------------------------------- +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- + + +import torch +from torch import nn +from typing import List, Optional + + +import torch +import torch.nn as nn +INIT_CONST = 0.02 +from einops import rearrange, repeat, reduce +class CrossAttention(nn.Module): + """ + CrossAttention module used in the Perceiver IO model. + + Args: + query_dim (int): The dimension of the query input. + heads (int, optional): The number of attention heads. Defaults to 8. + dim_head (int, optional): The dimension of each attention head. Defaults to 64. + dropout (float, optional): The dropout probability. Defaults to 0.0. 
+ """ + + def __init__(self, query_dim: int, heads: int = 8, dim_head: int = 64, dropout: float = 0.0): + super().__init__() + inner_dim = dim_head * heads + context_dim = query_dim + self.scale = dim_head**-0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, query_dim) + + self.dropout = nn.Dropout(dropout) + + def forward(self, x: torch.Tensor, context: torch.Tensor, + mask: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Forward pass of the CrossAttention module. + + Args: + x (torch.Tensor): The query input tensor. + context (torch.Tensor): The context input tensor. + mask (torch.Tensor, optional): The attention mask tensor. Defaults to None. + + Returns: + torch.Tensor: The output tensor. + """ + h = self.heads + q = self.to_q(x) + k, v = self.to_kv(context).chunk(2, dim=-1) + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) + sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale + + if mask is not None: + # fill in the masks with negative values + mask = rearrange(mask, "b ... 
-> b (...)") + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, "b j -> (b h) () j", h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + attn = sim.softmax(dim=-1) + + # dropout + attn = self.dropout(attn) + out = torch.einsum("b i j, b j d -> b i d", attn, v) + out = rearrange(out, "(b h) n d -> b n (h d)", h=h) + return self.to_out(out) + +class PolicyStem(nn.Module): + """policy stem""" + + def __init__(self, feature_dim: int = 8, token_dim: int = 128, **kwargs): + super().__init__() + # 初始化特征提取模块 + self.feature_extractor = nn.Linear(feature_dim, token_dim) + # 初始化 CrossAttention + self.init_cross_attn() + + def init_cross_attn(self): + """Initialize cross attention module and learnable tokens.""" + token_num = 16 + self.tokens = nn.Parameter(torch.randn(1, token_num, 128) * INIT_CONST) + self.cross_attention = CrossAttention(128, heads=8, dim_head=64, dropout=0.1) + + def compute_latent(self, x: torch.Tensor) -> torch.Tensor: + """ + Compute latent representations of input data using attention. + + Args: + x (torch.Tensor): Input tensor with shape [B, T, ..., F]. + + Returns: + torch.Tensor: Latent tokens, shape [B, 16, 128]. + """ + # 使用特征提取器而不是直接调用 self(x) + stem_feat = self.feature_extractor(x) + stem_feat = stem_feat.reshape(stem_feat.shape[0], -1, stem_feat.shape[-1]) # (B, N, 128) + # 使用 CrossAttention 计算 latent tokens + stem_tokens = self.tokens.repeat(len(stem_feat), 1, 1) # (B, 16, 128) + stem_tokens = self.cross_attention(stem_tokens, stem_feat) # (B, 16, 128) + return stem_tokens + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass to compute latent tokens. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Latent tokens tensor. 
+ """ + return self.compute_latent(x) + + def freeze(self): + for param in self.parameters(): + param.requires_grad = False + + def unfreeze(self): + for param in self.parameters(): + param.requires_grad = True + + def save(self, path : str): + torch.save(self.state_dict(), path) + + @property + def device(self): + return next(self.parameters()).device + diff --git a/dizoo/box2d/lunarlander/config/lunarlander_hpt_config.py b/dizoo/box2d/lunarlander/config/lunarlander_hpt_config.py new file mode 100644 index 0000000000..9df8a034c1 --- /dev/null +++ b/dizoo/box2d/lunarlander/config/lunarlander_hpt_config.py @@ -0,0 +1,77 @@ +from easydict import EasyDict + +nstep = 3 +lunarlander_hpt_config = dict( + exp_name='lunarlander_hpt_seed0', + env=dict( + # Whether to use shared memory. Only effective if "env_manager_type" is 'subprocess' + # Env number respectively for collector and evaluator. + collector_env_num=8, + evaluator_env_num=8, + env_id='LunarLander-v2', + n_evaluator_episode=8, + stop_value=200, + # The path to save the game replay + # replay_path='./lunarlander_hpt_seed0/video', + ), + policy=dict( + # Whether to use cuda for network. + cuda=True, + load_path="./lunarlander_hpt_seed0/ckpt/ckpt_best.pth.tar", + model=dict( + obs_shape=8, + action_shape=4, + ), + # Reward's future discount factor, aka. gamma. + discount_factor=0.99, + # How many steps in td error. + nstep=nstep, + # learn_mode config + learn=dict( + update_per_collect=10, + batch_size=64, + learning_rate=0.0005, + # Frequency of target network update. + target_update_freq=100, + ), + # collect_mode config + collect=dict( + # You can use either "n_sample" or "n_episode" in collector.collect. + # Get "n_sample" samples per collect. + n_sample=64, + # Cut trajectories into pieces with length "unroll_len". + unroll_len=1, + ), + # command_mode config + other=dict( + # Epsilon greedy with decay. + eps=dict( + # Decay type. Support ['exp', 'linear']. 
+ type='exp', + start=0.95, + end=0.1, + decay=50000, + ), + replay_buffer=dict(replay_buffer_size=100000, ) + ), + ), +) +lunarlander_hpt_config = EasyDict(lunarlander_hpt_config) +main_config = lunarlander_hpt_config + +lunarlander_hpt_create_config = dict( + env=dict( + type='lunarlander', + import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'], + ), + env_manager=dict(type='subprocess'), + # env_manager=dict(type='base'), + policy=dict(type='dqn'), +) +lunarlander_hpt_create_config = EasyDict(lunarlander_hpt_create_config) +create_config = lunarlander_hpt_create_config + +if __name__ == "__main__": + # or you can enter `ding -m serial -c lunarlander_dqn_config.py -s 0` + from ding.entry import serial_pipeline + serial_pipeline([main_config, create_config], seed=0) diff --git a/dizoo/box2d/lunarlander/entry/lunarlander_dqn_example.py b/dizoo/box2d/lunarlander/entry/lunarlander_dqn_example.py index b1c28ed975..bfcc1ab1d9 100644 --- a/dizoo/box2d/lunarlander/entry/lunarlander_dqn_example.py +++ b/dizoo/box2d/lunarlander/entry/lunarlander_dqn_example.py @@ -15,11 +15,14 @@ from ding.utils import set_pkg_seed from dizoo.box2d.lunarlander.config.lunarlander_dqn_config import main_config, create_config +import torch + def main(): logging.getLogger().setLevel(logging.INFO) cfg = compile_config(main_config, create_cfg=create_config, auto=True, save_cfg=task.router.node_id == 0) ding_init(cfg) + with task.start(async_mode=False, ctx=OnlineRLContext()): collector_env = SubprocessEnvManagerV2( env_fn=[lambda: DingEnvWrapper(gym.make("LunarLander-v2")) for _ in range(cfg.env.collector_env_num)], @@ -32,9 +35,20 @@ def main(): set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) - model = DQN(**cfg.policy.model) + # 迁移模型到 GPU + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = DQN(**cfg.policy.model).to(device) + + # 检查模型是否在 GPU + for param in model.parameters(): + print("模型参数所在设备:", param.device) + break buffer_ = 
DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size) + + # 将模型传入 Policy policy = DQNPolicy(cfg.policy, model=model) + print("日志保存路径:", cfg.exp_name) + # Consider the case with multiple processes if task.router.is_active: @@ -50,8 +64,10 @@ def main(): # Sync their context and model between each worker. task.use(ContextExchanger(skip_n_iter=1)) task.use(ModelExchanger(model)) + # Here is the part of single process pipeline. + evaluator_env.enable_save_replay(replay_path='./video') task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env)) task.use(eps_greedy_handler(cfg)) task.use(StepCollector(cfg, policy.collect_mode, collector_env)) diff --git a/dizoo/box2d/lunarlander/entry/lunarlander_hpt_example.py b/dizoo/box2d/lunarlander/entry/lunarlander_hpt_example.py new file mode 100644 index 0000000000..b9dd6d8681 --- /dev/null +++ b/dizoo/box2d/lunarlander/entry/lunarlander_hpt_example.py @@ -0,0 +1,91 @@ + +import gym +from ditk import logging +from ding.data.model_loader import FileModelLoader +from ding.data.storage_loader import FileStorageLoader +from ding.model.common.head import DuelingHead +from ding.model.template.hpt import HPT +from ding.policy import DQNPolicy +from ding.envs import DingEnvWrapper, SubprocessEnvManagerV2 +from ding.data import DequeBuffer +from ding.config import compile_config +from ding.framework import task, ding_init +from ding.framework.context import OnlineRLContext +from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \ + eps_greedy_handler, CkptSaver, ContextExchanger, ModelExchanger, online_logger, termination_checker, \ + nstep_reward_enhancer +from ding.utils import set_pkg_seed +from dizoo.box2d.lunarlander.config.lunarlander_hpt_config import main_config, create_config +import torch +import torch.nn as nn + + + +def main(): + logging.getLogger().setLevel(logging.INFO) + cfg = compile_config(main_config, create_cfg=create_config, auto=True, 
save_cfg=task.router.node_id == 0) + ding_init(cfg) + + with task.start(async_mode=False, ctx=OnlineRLContext()): + collector_env = SubprocessEnvManagerV2( + env_fn=[lambda: DingEnvWrapper(gym.make("LunarLander-v2")) for _ in range(cfg.env.collector_env_num)], + cfg=cfg.env.manager + ) + evaluator_env = SubprocessEnvManagerV2( + env_fn=[lambda: DingEnvWrapper(gym.make("LunarLander-v2")) for _ in range(cfg.env.evaluator_env_num)], + cfg=cfg.env.manager + ) + + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + + # 迁移模型到 GPU + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # model = DQN(**cfg.policy.model).to(device) + model = HPT(cfg.policy.model.obs_shape,cfg.policy.model.action_shape).to(device) + + + + # 检查模型是否在 GPU + for param in model.parameters(): + print("模型参数所在设备:", param.device) + break + buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size) + + # 将模型传入 Policy + policy = DQNPolicy(cfg.policy, model=model) + print("日志保存路径:", cfg.exp_name) + + + # Consider the case with multiple processes + if task.router.is_active: + # You can use labels to distinguish between workers with different roles, + # here we use node_id to distinguish. + if task.router.node_id == 0: + task.add_role(task.role.LEARNER) + elif task.router.node_id == 1: + task.add_role(task.role.EVALUATOR) + else: + task.add_role(task.role.COLLECTOR) + + # Sync their context and model between each worker. + task.use(ContextExchanger(skip_n_iter=1)) + task.use(ModelExchanger(model)) + + + # Here is the part of single process pipeline. 
+ # evaluator_env.enable_save_replay(replay_path='./video') + task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env)) + task.use(eps_greedy_handler(cfg)) + task.use(StepCollector(cfg, policy.collect_mode, collector_env)) + task.use(nstep_reward_enhancer(cfg)) + task.use(data_pusher(cfg, buffer_)) + task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_)) + task.use(online_logger(train_show_freq=50)) + task.use(CkptSaver(policy, cfg.exp_name, train_freq=1000)) + task.use(termination_checker(max_env_step=int(3e6))) + task.run() + + +if __name__ == "__main__": + main() From f7f4d04aeec0858e1f9104cb8ef55bd5efcb0d58 Mon Sep 17 00:00:00 2001 From: xueyingyi <3245896298@qq.com> Date: Wed, 27 Nov 2024 21:46:48 +0800 Subject: [PATCH 02/21] feature(xyy):add HPT model to implement PolicyStem+DuelingHead --- ding/model/template/__init__.py | 2 +- ding/model/template/hpt.py | 143 +++++++++++++++++- ding/model/template/policy_stem.py | 133 ---------------- .../entry/lunarlander_dqn_example.py | 12 +- .../entry/lunarlander_hpt_example.py | 21 +-- 5 files changed, 145 insertions(+), 166 deletions(-) delete mode 100644 ding/model/template/policy_stem.py diff --git a/ding/model/template/__init__.py b/ding/model/template/__init__.py index d121d7f30e..de506123ba 100755 --- a/ding/model/template/__init__.py +++ b/ding/model/template/__init__.py @@ -31,4 +31,4 @@ from .qgpo import QGPO from .ebm import EBM, AutoregressiveEBM from .havac import HAVAC -from .policy_stem import PolicyStem + diff --git a/ding/model/template/hpt.py b/ding/model/template/hpt.py index ba550a8549..e7a2960b2b 100644 --- a/ding/model/template/hpt.py +++ b/ding/model/template/hpt.py @@ -1,25 +1,152 @@ from typing import Union, Optional, Dict, Callable, List +from einops import rearrange, repeat import torch import torch.nn as nn - from ding.model.common.head import DuelingHead from ding.utils.registry_factory import MODEL_REGISTRY -from ding.model.template.policy_stem import PolicyStem + + 
+INIT_CONST = 0.02 + @MODEL_REGISTRY.register('hpt') class HPT(nn.Module): def __init__(self, state_dim, action_dim): super(HPT, self).__init__() - # 初始化 Policy Stem + # Initialise Policy Stem self.policy_stem = PolicyStem() self.policy_stem.init_cross_attn() - # Dueling Head,输入为 16*128,输出为动作维度 + # Dueling Head, input is 16*128, output is action dimension self.head = DuelingHead(hidden_size=16*128, output_size=action_dim) def forward(self, x): - # Policy Stem 输出 [B, 16, 128] + # Policy Stem Outputs [B, 16, 128] tokens = self.policy_stem.compute_latent(x) - # Flatten 操作 + # Flatten Operation tokens_flattened = tokens.view(tokens.size(0), -1) # [B, 16*128] - # 输入到 Dueling Head + # Enter to Dueling Head q_values = self.head(tokens_flattened) - return q_values \ No newline at end of file + return q_values + + + +class PolicyStem(nn.Module): + """policy stem + Overview: + The reference uses PolicyStem from + + """ + def __init__(self, feature_dim: int = 8, token_dim: int = 128, **kwargs): + super().__init__() + # Initialise the feature extraction module + self.feature_extractor = nn.Linear(feature_dim, token_dim) + # Initialise CrossAttention + self.init_cross_attn() + + def init_cross_attn(self): + """Initialize cross attention module and learnable tokens.""" + token_num = 16 + self.tokens = nn.Parameter(torch.randn(1, token_num, 128) * INIT_CONST) + self.cross_attention = CrossAttention(128, heads=8, dim_head=64, dropout=0.1) + + def compute_latent(self, x: torch.Tensor) -> torch.Tensor: + """ + Compute latent representations of input data using attention. + + Args: + x (torch.Tensor): Input tensor with shape [B, T, ..., F]. + + Returns: + torch.Tensor: Latent tokens, shape [B, 16, 128]. 
+ """ + # Using the Feature Extractor + stem_feat = self.feature_extractor(x) + stem_feat = stem_feat.reshape(stem_feat.shape[0], -1, stem_feat.shape[-1]) # (B, N, 128) + # Calculating latent tokens using CrossAttention + stem_tokens = self.tokens.repeat(len(stem_feat), 1, 1) # (B, 16, 128) + stem_tokens = self.cross_attention(stem_tokens, stem_feat) # (B, 16, 128) + return stem_tokens + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass to compute latent tokens. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Latent tokens tensor. + """ + return self.compute_latent(x) + + def freeze(self): + for param in self.parameters(): + param.requires_grad = False + + def unfreeze(self): + for param in self.parameters(): + param.requires_grad = True + + def save(self, path : str): + torch.save(self.state_dict(), path) + + @property + def device(self): + return next(self.parameters()).device + +class CrossAttention(nn.Module): + """ + CrossAttention module used in the Perceiver IO model. + + Args: + query_dim (int): The dimension of the query input. + heads (int, optional): The number of attention heads. Defaults to 8. + dim_head (int, optional): The dimension of each attention head. Defaults to 64. + dropout (float, optional): The dropout probability. Defaults to 0.0. + """ + + def __init__(self, query_dim: int, heads: int = 8, dim_head: int = 64, dropout: float = 0.0): + super().__init__() + inner_dim = dim_head * heads + context_dim = query_dim + self.scale = dim_head**-0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, query_dim) + + self.dropout = nn.Dropout(dropout) + + def forward(self, x: torch.Tensor, context: torch.Tensor, + mask: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Forward pass of the CrossAttention module. + + Args: + x (torch.Tensor): The query input tensor. 
+ context (torch.Tensor): The context input tensor. + mask (torch.Tensor, optional): The attention mask tensor. Defaults to None. + + Returns: + torch.Tensor: The output tensor. + """ + h = self.heads + q = self.to_q(x) + k, v = self.to_kv(context).chunk(2, dim=-1) + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) + sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale + + if mask is not None: + # fill in the masks with negative values + mask = rearrange(mask, "b ... -> b (...)") + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, "b j -> (b h) () j", h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + attn = sim.softmax(dim=-1) + + # dropout + attn = self.dropout(attn) + out = torch.einsum("b i j, b j d -> b i d", attn, v) + out = rearrange(out, "(b h) n d -> b n (h d)", h=h) + return self.to_out(out) \ No newline at end of file diff --git a/ding/model/template/policy_stem.py b/ding/model/template/policy_stem.py deleted file mode 100644 index 81963cdbc4..0000000000 --- a/ding/model/template/policy_stem.py +++ /dev/null @@ -1,133 +0,0 @@ -# -------------------------------------------------------- -# Licensed under The MIT License [see LICENSE for details] -# -------------------------------------------------------- - - -import torch -from torch import nn -from typing import List, Optional - - -import torch -import torch.nn as nn -INIT_CONST = 0.02 -from einops import rearrange, repeat, reduce -class CrossAttention(nn.Module): - """ - CrossAttention module used in the Perceiver IO model. - - Args: - query_dim (int): The dimension of the query input. - heads (int, optional): The number of attention heads. Defaults to 8. - dim_head (int, optional): The dimension of each attention head. Defaults to 64. - dropout (float, optional): The dropout probability. Defaults to 0.0. 
- """ - - def __init__(self, query_dim: int, heads: int = 8, dim_head: int = 64, dropout: float = 0.0): - super().__init__() - inner_dim = dim_head * heads - context_dim = query_dim - self.scale = dim_head**-0.5 - self.heads = heads - - self.to_q = nn.Linear(query_dim, inner_dim, bias=False) - self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False) - self.to_out = nn.Linear(inner_dim, query_dim) - - self.dropout = nn.Dropout(dropout) - - def forward(self, x: torch.Tensor, context: torch.Tensor, - mask: Optional[torch.Tensor] = None) -> torch.Tensor: - """ - Forward pass of the CrossAttention module. - - Args: - x (torch.Tensor): The query input tensor. - context (torch.Tensor): The context input tensor. - mask (torch.Tensor, optional): The attention mask tensor. Defaults to None. - - Returns: - torch.Tensor: The output tensor. - """ - h = self.heads - q = self.to_q(x) - k, v = self.to_kv(context).chunk(2, dim=-1) - q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) - sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale - - if mask is not None: - # fill in the masks with negative values - mask = rearrange(mask, "b ... 
-> b (...)") - max_neg_value = -torch.finfo(sim.dtype).max - mask = repeat(mask, "b j -> (b h) () j", h=h) - sim.masked_fill_(~mask, max_neg_value) - - # attention, what we cannot get enough of - attn = sim.softmax(dim=-1) - - # dropout - attn = self.dropout(attn) - out = torch.einsum("b i j, b j d -> b i d", attn, v) - out = rearrange(out, "(b h) n d -> b n (h d)", h=h) - return self.to_out(out) - -class PolicyStem(nn.Module): - """policy stem""" - - def __init__(self, feature_dim: int = 8, token_dim: int = 128, **kwargs): - super().__init__() - # 初始化特征提取模块 - self.feature_extractor = nn.Linear(feature_dim, token_dim) - # 初始化 CrossAttention - self.init_cross_attn() - - def init_cross_attn(self): - """Initialize cross attention module and learnable tokens.""" - token_num = 16 - self.tokens = nn.Parameter(torch.randn(1, token_num, 128) * INIT_CONST) - self.cross_attention = CrossAttention(128, heads=8, dim_head=64, dropout=0.1) - - def compute_latent(self, x: torch.Tensor) -> torch.Tensor: - """ - Compute latent representations of input data using attention. - - Args: - x (torch.Tensor): Input tensor with shape [B, T, ..., F]. - - Returns: - torch.Tensor: Latent tokens, shape [B, 16, 128]. - """ - # 使用特征提取器而不是直接调用 self(x) - stem_feat = self.feature_extractor(x) - stem_feat = stem_feat.reshape(stem_feat.shape[0], -1, stem_feat.shape[-1]) # (B, N, 128) - # 使用 CrossAttention 计算 latent tokens - stem_tokens = self.tokens.repeat(len(stem_feat), 1, 1) # (B, 16, 128) - stem_tokens = self.cross_attention(stem_tokens, stem_feat) # (B, 16, 128) - return stem_tokens - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Forward pass to compute latent tokens. - - Args: - x (torch.Tensor): Input tensor. - - Returns: - torch.Tensor: Latent tokens tensor. 
- """ - return self.compute_latent(x) - - def freeze(self): - for param in self.parameters(): - param.requires_grad = False - - def unfreeze(self): - for param in self.parameters(): - param.requires_grad = True - - def save(self, path : str): - torch.save(self.state_dict(), path) - - @property - def device(self): - return next(self.parameters()).device - diff --git a/dizoo/box2d/lunarlander/entry/lunarlander_dqn_example.py b/dizoo/box2d/lunarlander/entry/lunarlander_dqn_example.py index bfcc1ab1d9..2ca2c06361 100644 --- a/dizoo/box2d/lunarlander/entry/lunarlander_dqn_example.py +++ b/dizoo/box2d/lunarlander/entry/lunarlander_dqn_example.py @@ -1,4 +1,5 @@ import gym +import torch from ditk import logging from ding.data.model_loader import FileModelLoader from ding.data.storage_loader import FileStorageLoader @@ -15,7 +16,7 @@ from ding.utils import set_pkg_seed from dizoo.box2d.lunarlander.config.lunarlander_dqn_config import main_config, create_config -import torch + def main(): @@ -35,19 +36,14 @@ def main(): set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) - # 迁移模型到 GPU + # # Migrating models to the GPU device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = DQN(**cfg.policy.model).to(device) - # 检查模型是否在 GPU - for param in model.parameters(): - print("模型参数所在设备:", param.device) - break buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size) - # 将模型传入 Policy + # Pass the model into Policy policy = DQNPolicy(cfg.policy, model=model) - print("日志保存路径:", cfg.exp_name) # Consider the case with multiple processes diff --git a/dizoo/box2d/lunarlander/entry/lunarlander_hpt_example.py b/dizoo/box2d/lunarlander/entry/lunarlander_hpt_example.py index b9dd6d8681..1410648dd3 100644 --- a/dizoo/box2d/lunarlander/entry/lunarlander_hpt_example.py +++ b/dizoo/box2d/lunarlander/entry/lunarlander_hpt_example.py @@ -1,5 +1,7 @@ import gym +import torch +import torch.nn as nn from ditk import logging from ding.data.model_loader import 
FileModelLoader from ding.data.storage_loader import FileStorageLoader @@ -16,8 +18,7 @@ nstep_reward_enhancer from ding.utils import set_pkg_seed from dizoo.box2d.lunarlander.config.lunarlander_hpt_config import main_config, create_config -import torch -import torch.nn as nn + @@ -38,24 +39,13 @@ def main(): set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) - # 迁移模型到 GPU + # Migrating models to the GPU device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - # model = DQN(**cfg.policy.model).to(device) model = HPT(cfg.policy.model.obs_shape,cfg.policy.model.action_shape).to(device) - - - - # 检查模型是否在 GPU - for param in model.parameters(): - print("模型参数所在设备:", param.device) - break buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size) - # 将模型传入 Policy + # Pass the model into Policy policy = DQNPolicy(cfg.policy, model=model) - print("日志保存路径:", cfg.exp_name) - # Consider the case with multiple processes if task.router.is_active: @@ -74,7 +64,6 @@ def main(): # Here is the part of single process pipeline. 
- # evaluator_env.enable_save_replay(replay_path='./video') task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env)) task.use(eps_greedy_handler(cfg)) task.use(StepCollector(cfg, policy.collect_mode, collector_env)) From 53c9d9abd26d17f241b84168d651f5a8bb08e5ce Mon Sep 17 00:00:00 2001 From: xueyingyi <3245896298@qq.com> Date: Wed, 27 Nov 2024 22:48:39 +0800 Subject: [PATCH 03/21] feature(xyy):add HPT model --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index 59dd9ce782..01132439b0 100644 --- a/setup.py +++ b/setup.py @@ -78,6 +78,7 @@ 'sniffio', # parallel 'redis', # parallel 'mpire>=2.3.5', # parallel + 'einops=0.8.0' ], extras_require={ 'test': [ From db7ef14e9bac61bd9d0fc83ba521c4999dd40175 Mon Sep 17 00:00:00 2001 From: xueyingyi <3245896298@qq.com> Date: Wed, 27 Nov 2024 22:53:49 +0800 Subject: [PATCH 04/21] feature(xyy):add HPT model to implement PolicyStem+DuelingHead --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 01132439b0..d4e954089c 100644 --- a/setup.py +++ b/setup.py @@ -78,7 +78,7 @@ 'sniffio', # parallel 'redis', # parallel 'mpire>=2.3.5', # parallel - 'einops=0.8.0' + 'einops=0.8.0', ], extras_require={ 'test': [ From 3ed5a7ec18f1cb571d84e48f558574d3fa28070d Mon Sep 17 00:00:00 2001 From: xueyingyi <3245896298@qq.com> Date: Thu, 28 Nov 2024 01:49:05 +0800 Subject: [PATCH 05/21] feature(xyy):add HPT model and examples --- ding/model/template/hpt.py | 25 +++++++++++-------- .../entry/lunarlander_dqn_example.py | 4 --- .../entry/lunarlander_hpt_example.py | 8 ++---- setup.py | 2 +- 4 files changed, 17 insertions(+), 22 deletions(-) diff --git a/ding/model/template/hpt.py b/ding/model/template/hpt.py index e7a2960b2b..67de99ea60 100644 --- a/ding/model/template/hpt.py +++ b/ding/model/template/hpt.py @@ -5,19 +5,21 @@ from ding.model.common.head import DuelingHead from ding.utils.registry_factory import MODEL_REGISTRY - INIT_CONST = 0.02 + 
@MODEL_REGISTRY.register('hpt') class HPT(nn.Module): + def __init__(self, state_dim, action_dim): super(HPT, self).__init__() # Initialise Policy Stem self.policy_stem = PolicyStem() self.policy_stem.init_cross_attn() - + # Dueling Head, input is 16*128, output is action dimension - self.head = DuelingHead(hidden_size=16*128, output_size=action_dim) + self.head = DuelingHead(hidden_size=16 * 128, output_size=action_dim) + def forward(self, x): # Policy Stem Outputs [B, 16, 128] tokens = self.policy_stem.compute_latent(x) @@ -28,13 +30,13 @@ def forward(self, x): return q_values - class PolicyStem(nn.Module): """policy stem Overview: The reference uses PolicyStem from """ + def __init__(self, feature_dim: int = 8, token_dim: int = 128, **kwargs): super().__init__() # Initialise the feature extraction module @@ -59,12 +61,13 @@ def compute_latent(self, x: torch.Tensor) -> torch.Tensor: torch.Tensor: Latent tokens, shape [B, 16, 128]. """ # Using the Feature Extractor - stem_feat = self.feature_extractor(x) + stem_feat = self.feature_extractor(x) stem_feat = stem_feat.reshape(stem_feat.shape[0], -1, stem_feat.shape[-1]) # (B, N, 128) # Calculating latent tokens using CrossAttention stem_tokens = self.tokens.repeat(len(stem_feat), 1, 1) # (B, 16, 128) stem_tokens = self.cross_attention(stem_tokens, stem_feat) # (B, 16, 128) return stem_tokens + def forward(self, x: torch.Tensor) -> torch.Tensor: """ Forward pass to compute latent tokens. @@ -76,7 +79,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: torch.Tensor: Latent tokens tensor. 
""" return self.compute_latent(x) - + def freeze(self): for param in self.parameters(): param.requires_grad = False @@ -85,13 +88,14 @@ def unfreeze(self): for param in self.parameters(): param.requires_grad = True - def save(self, path : str): + def save(self, path: str): torch.save(self.state_dict(), path) @property def device(self): return next(self.parameters()).device + class CrossAttention(nn.Module): """ CrossAttention module used in the Perceiver IO model. @@ -107,7 +111,7 @@ def __init__(self, query_dim: int, heads: int = 8, dim_head: int = 64, dropout: super().__init__() inner_dim = dim_head * heads context_dim = query_dim - self.scale = dim_head**-0.5 + self.scale = dim_head ** -0.5 self.heads = heads self.to_q = nn.Linear(query_dim, inner_dim, bias=False) @@ -116,8 +120,7 @@ def __init__(self, query_dim: int, heads: int = 8, dim_head: int = 64, dropout: self.dropout = nn.Dropout(dropout) - def forward(self, x: torch.Tensor, context: torch.Tensor, - mask: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward(self, x: torch.Tensor, context: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: """ Forward pass of the CrossAttention module. 
@@ -149,4 +152,4 @@ def forward(self, x: torch.Tensor, context: torch.Tensor, attn = self.dropout(attn) out = torch.einsum("b i j, b j d -> b i d", attn, v) out = rearrange(out, "(b h) n d -> b n (h d)", h=h) - return self.to_out(out) \ No newline at end of file + return self.to_out(out) diff --git a/dizoo/box2d/lunarlander/entry/lunarlander_dqn_example.py b/dizoo/box2d/lunarlander/entry/lunarlander_dqn_example.py index 2ca2c06361..9894711a31 100644 --- a/dizoo/box2d/lunarlander/entry/lunarlander_dqn_example.py +++ b/dizoo/box2d/lunarlander/entry/lunarlander_dqn_example.py @@ -17,8 +17,6 @@ from dizoo.box2d.lunarlander.config.lunarlander_dqn_config import main_config, create_config - - def main(): logging.getLogger().setLevel(logging.INFO) cfg = compile_config(main_config, create_cfg=create_config, auto=True, save_cfg=task.router.node_id == 0) @@ -45,7 +43,6 @@ def main(): # Pass the model into Policy policy = DQNPolicy(cfg.policy, model=model) - # Consider the case with multiple processes if task.router.is_active: # You can use labels to distinguish between workers with different roles, @@ -60,7 +57,6 @@ def main(): # Sync their context and model between each worker. task.use(ContextExchanger(skip_n_iter=1)) task.use(ModelExchanger(model)) - # Here is the part of single process pipeline. 
evaluator_env.enable_save_replay(replay_path='./video') diff --git a/dizoo/box2d/lunarlander/entry/lunarlander_hpt_example.py b/dizoo/box2d/lunarlander/entry/lunarlander_hpt_example.py index 1410648dd3..c48125af10 100644 --- a/dizoo/box2d/lunarlander/entry/lunarlander_hpt_example.py +++ b/dizoo/box2d/lunarlander/entry/lunarlander_hpt_example.py @@ -1,4 +1,3 @@ - import gym import torch import torch.nn as nn @@ -20,8 +19,6 @@ from dizoo.box2d.lunarlander.config.lunarlander_hpt_config import main_config, create_config - - def main(): logging.getLogger().setLevel(logging.INFO) cfg = compile_config(main_config, create_cfg=create_config, auto=True, save_cfg=task.router.node_id == 0) @@ -41,11 +38,11 @@ def main(): # Migrating models to the GPU device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - model = HPT(cfg.policy.model.obs_shape,cfg.policy.model.action_shape).to(device) + model = HPT(cfg.policy.model.obs_shape, cfg.policy.model.action_shape).to(device) buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size) # Pass the model into Policy - policy = DQNPolicy(cfg.policy, model=model) + policy = DQNPolicy(cfg.policy, model=model) # Consider the case with multiple processes if task.router.is_active: @@ -61,7 +58,6 @@ def main(): # Sync their context and model between each worker. task.use(ContextExchanger(skip_n_iter=1)) task.use(ModelExchanger(model)) - # Here is the part of single process pipeline. 
task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env)) diff --git a/setup.py b/setup.py index d4e954089c..e0b6604600 100644 --- a/setup.py +++ b/setup.py @@ -78,7 +78,7 @@ 'sniffio', # parallel 'redis', # parallel 'mpire>=2.3.5', # parallel - 'einops=0.8.0', + 'einops<=0.8.0', ], extras_require={ 'test': [ From 912b37d80dc1bd80786d03c4be6f3b282f1628e2 Mon Sep 17 00:00:00 2001 From: xueyingyi <3245896298@qq.com> Date: Thu, 28 Nov 2024 09:04:02 +0800 Subject: [PATCH 06/21] feature(xyy):add HPT model and examples --- ding/model/template/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ding/model/template/__init__.py b/ding/model/template/__init__.py index de506123ba..a25a448b61 100755 --- a/ding/model/template/__init__.py +++ b/ding/model/template/__init__.py @@ -31,4 +31,3 @@ from .qgpo import QGPO from .ebm import EBM, AutoregressiveEBM from .havac import HAVAC - From 9afedc76bcd3ca29879f28fd7c954f61ca0e4282 Mon Sep 17 00:00:00 2001 From: xueyingyi <3245896298@qq.com> Date: Wed, 4 Dec 2024 20:14:16 +0800 Subject: [PATCH 07/21] feature(xyy):add HPT model and test_hpt --- ding/model/template/__init__.py | 1 - ding/model/template/hpt.py | 111 +++++++++++++----- ding/model/template/tests/test_hpt.py | 42 +++++++ .../config/lunarlander_hpt_config.py | 3 +- .../entry/lunarlander_dqn_example.py | 2 +- .../entry/lunarlander_hpt_example.py | 1 + 6 files changed, 128 insertions(+), 32 deletions(-) create mode 100644 ding/model/template/tests/test_hpt.py diff --git a/ding/model/template/__init__.py b/ding/model/template/__init__.py index a25a448b61..d79b4a1919 100755 --- a/ding/model/template/__init__.py +++ b/ding/model/template/__init__.py @@ -25,7 +25,6 @@ from .decision_transformer import DecisionTransformer from .procedure_cloning import ProcedureCloningMCTS, ProcedureCloningBFS from .hpt import HPT - from .bcq import BCQ from .edac import EDAC from .qgpo import QGPO diff --git a/ding/model/template/hpt.py b/ding/model/template/hpt.py index 
67de99ea60..a09f52da47 100644 --- a/ding/model/template/hpt.py +++ b/ding/model/template/hpt.py @@ -10,8 +10,30 @@ @MODEL_REGISTRY.register('hpt') class HPT(nn.Module): + """ + Overview: + The HPT model for reinforcement learning, which consists of a Policy Stem and a Dueling Head. The Policy Stem \ + utilizes cross-attention to process input data, and the Dueling Head computes Q-values for discrete action spaces. + + Interfaces: + __init__, forward + + .. note:: + The model is designed to be flexible and can be adapted for different input dimensions and action spaces. + """ + + def __init__(self, state_dim: int, action_dim: int): + """ + Overview: + Initialize the HPT model, including the Policy Stem and the Dueling Head. - def __init__(self, state_dim, action_dim): + Arguments: + - state_dim (:obj:`int`): The dimension of the input state. + - action_dim (:obj:`int`): The dimension of the action space. + + .. note:: + The Policy Stem is initialized with cross-attention, and the Dueling Head is set to process the resulting tokens. + """ super(HPT, self).__init__() # Initialise Policy Stem self.policy_stem = PolicyStem() @@ -20,7 +42,17 @@ def __init__(self, state_dim, action_dim): # Dueling Head, input is 16*128, output is action dimension self.head = DuelingHead(hidden_size=16 * 128, output_size=action_dim) - def forward(self, x): + def forward(self, x: torch.Tensor): + """ + Overview: + Forward pass of the HPT model. Computes latent tokens from the input state and passes them through the Dueling Head. + + Arguments: + - x (:obj:`torch.Tensor`): The input tensor representing the state. + + Returns: + - q_values (:obj:`torch.Tensor`): The predicted Q-values for each action. 
+ """ # Policy Stem Outputs [B, 16, 128] tokens = self.policy_stem.compute_latent(x) # Flatten Operation @@ -31,13 +63,27 @@ def forward(self, x): class PolicyStem(nn.Module): - """policy stem - Overview: - The reference uses PolicyStem from - + """ + Overview: + The Policy Stem module is responsible for processing input features and generating latent tokens using a cross-attention mechanism. + It extracts features from the input and then applies cross-attention to generate a set of latent tokens. + + Interfaces: + __init__, init_cross_attn, compute_latent, forward + + .. note:: + This module is inspired by the implementation in the Perceiver IO model and uses attention mechanisms for feature extraction. """ def __init__(self, feature_dim: int = 8, token_dim: int = 128, **kwargs): + """ + Overview: + Initialize the Policy Stem module with a feature extractor and cross-attention mechanism. + + Arguments: + - feature_dim (:obj:`int`): The dimension of the input features. + - token_dim (:obj:`int`): The dimension of the latent tokens generated by the attention mechanism. + """ super().__init__() # Initialise the feature extraction module self.feature_extractor = nn.Linear(feature_dim, token_dim) @@ -45,20 +91,21 @@ def __init__(self, feature_dim: int = 8, token_dim: int = 128, **kwargs): self.init_cross_attn() def init_cross_attn(self): - """Initialize cross attention module and learnable tokens.""" + """Initialize cross-attention module and learnable tokens.""" token_num = 16 self.tokens = nn.Parameter(torch.randn(1, token_num, 128) * INIT_CONST) self.cross_attention = CrossAttention(128, heads=8, dim_head=64, dropout=0.1) def compute_latent(self, x: torch.Tensor) -> torch.Tensor: """ - Compute latent representations of input data using attention. + Overview: + Compute latent representations of the input data using the feature extractor and cross-attention. - Args: - x (torch.Tensor): Input tensor with shape [B, T, ..., F]. 
+ Arguments: + - x (:obj:`torch.Tensor`): Input tensor with shape [B, T, ..., F]. Returns: - torch.Tensor: Latent tokens, shape [B, 16, 128]. + - stem_tokens (:obj:`torch.Tensor`): Latent tokens with shape [B, 16, 128]. """ # Using the Feature Extractor stem_feat = self.feature_extractor(x) @@ -70,41 +117,48 @@ def compute_latent(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor: """ - Forward pass to compute latent tokens. + Overview: + Forward pass to compute latent tokens. - Args: - x (torch.Tensor): Input tensor. + Arguments: + - x (:obj:`torch.Tensor`): Input tensor. Returns: - torch.Tensor: Latent tokens tensor. + - torch.Tensor: Latent tokens tensor. """ return self.compute_latent(x) def freeze(self): + """Freeze the parameters of the model, preventing updates during training.""" for param in self.parameters(): param.requires_grad = False def unfreeze(self): + """Unfreeze the parameters of the model, allowing updates during training.""" for param in self.parameters(): param.requires_grad = True def save(self, path: str): + """Save the model state dictionary to a file.""" torch.save(self.state_dict(), path) @property def device(self): + """Returns the device on which the model parameters are located.""" return next(self.parameters()).device class CrossAttention(nn.Module): """ - CrossAttention module used in the Perceiver IO model. - - Args: - query_dim (int): The dimension of the query input. - heads (int, optional): The number of attention heads. Defaults to 8. - dim_head (int, optional): The dimension of each attention head. Defaults to 64. - dropout (float, optional): The dropout probability. Defaults to 0.0. + Overview: + CrossAttention module used in the Perceiver IO model. It computes the attention between the query and context tensors, + and returns the output tensor after applying attention. + + Arguments: + query_dim (:obj:`int`): The dimension of the query input. 
+ heads (:obj:`int`, optional): The number of attention heads. Defaults to 8. + dim_head (:obj:`int`, optional): The dimension of each attention head. Defaults to 64. + dropout (:obj:`float`, optional): The dropout probability. Defaults to 0.0. """ def __init__(self, query_dim: int, heads: int = 8, dim_head: int = 64, dropout: float = 0.0): @@ -122,15 +176,16 @@ def __init__(self, query_dim: int, heads: int = 8, dim_head: int = 64, dropout: def forward(self, x: torch.Tensor, context: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: """ - Forward pass of the CrossAttention module. + Overview: + Forward pass of the CrossAttention module. Computes the attention between the query and context tensors. - Args: - x (torch.Tensor): The query input tensor. - context (torch.Tensor): The context input tensor. - mask (torch.Tensor, optional): The attention mask tensor. Defaults to None. + Arguments: + - x (:obj:`torch.Tensor`): The query input tensor. + - context (:obj:`torch.Tensor`): The context input tensor. + - mask (:obj:`torch.Tensor`, optional): The attention mask tensor. Defaults to None. Returns: - torch.Tensor: The output tensor. + - torch.Tensor: The output tensor after applying attention. 
""" h = self.heads q = self.to_q(x) diff --git a/ding/model/template/tests/test_hpt.py b/ding/model/template/tests/test_hpt.py new file mode 100644 index 0000000000..1f77c308a7 --- /dev/null +++ b/ding/model/template/tests/test_hpt.py @@ -0,0 +1,42 @@ +import pytest +import torch +from itertools import product +from ding.model.template.hpt import HPT +from ding.torch_utils import is_differentiable + + +T, B = 3, 4 +obs_shape = [4, (8,), (4, 64, 64)] # Example observation shapes +act_shape = [3, (6,), [2, 3, 6]] # Example action shapes +args = list(product(*[obs_shape, act_shape])) + +@pytest.mark.unittest +class TestHPT: + + def output_check(self, model, outputs): + if isinstance(outputs, torch.Tensor): + loss = outputs.sum() + elif isinstance(outputs, list): + loss = sum([t.sum() for t in outputs]) + elif isinstance(outputs, dict): + loss = sum([v.sum() for v in outputs.values()]) + is_differentiable(loss, model) + + @pytest.mark.parametrize('obs_shape, act_shape', args) + def test_hpt(self, obs_shape, act_shape): + if isinstance(obs_shape, int): + inputs = torch.randn(B, obs_shape) + else: + inputs = torch.randn(B, *obs_shape) + model = HPT(state_dim=obs_shape, action_dim=act_shape) + outputs = model(inputs) + assert isinstance(outputs, torch.Tensor) + if isinstance(act_shape, int): + assert outputs.shape == (B, act_shape) + elif len(act_shape) == 1: + assert outputs.shape == (B, *act_shape) + else: + for i, s in enumerate(act_shape): + assert outputs[i].shape == (B, s) + self.output_check(model, outputs) + \ No newline at end of file diff --git a/dizoo/box2d/lunarlander/config/lunarlander_hpt_config.py b/dizoo/box2d/lunarlander/config/lunarlander_hpt_config.py index 9df8a034c1..fea3aca6bc 100644 --- a/dizoo/box2d/lunarlander/config/lunarlander_hpt_config.py +++ b/dizoo/box2d/lunarlander/config/lunarlander_hpt_config.py @@ -65,13 +65,12 @@ import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'], ), env_manager=dict(type='subprocess'), - # 
env_manager=dict(type='base'), policy=dict(type='dqn'), ) lunarlander_hpt_create_config = EasyDict(lunarlander_hpt_create_config) create_config = lunarlander_hpt_create_config if __name__ == "__main__": - # or you can enter `ding -m serial -c lunarlander_dqn_config.py -s 0` + # or you can enter `ding -m serial -c lunarlander_hpt_config.py -s 0` from ding.entry import serial_pipeline serial_pipeline([main_config, create_config], seed=0) diff --git a/dizoo/box2d/lunarlander/entry/lunarlander_dqn_example.py b/dizoo/box2d/lunarlander/entry/lunarlander_dqn_example.py index 9894711a31..048554c310 100644 --- a/dizoo/box2d/lunarlander/entry/lunarlander_dqn_example.py +++ b/dizoo/box2d/lunarlander/entry/lunarlander_dqn_example.py @@ -34,7 +34,7 @@ def main(): set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) - # # Migrating models to the GPU + # Migrating models to the GPU device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = DQN(**cfg.policy.model).to(device) diff --git a/dizoo/box2d/lunarlander/entry/lunarlander_hpt_example.py b/dizoo/box2d/lunarlander/entry/lunarlander_hpt_example.py index c48125af10..5bcf7be472 100644 --- a/dizoo/box2d/lunarlander/entry/lunarlander_hpt_example.py +++ b/dizoo/box2d/lunarlander/entry/lunarlander_hpt_example.py @@ -38,6 +38,7 @@ def main(): # Migrating models to the GPU device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + # HPT introduces a Policy Stem module, which processes the input features using Cross-Attention and generates a set of latent tokens. 
model = HPT(cfg.policy.model.obs_shape, cfg.policy.model.action_shape).to(device) buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size) From 19da3b3fd2dc5150222403f7303700c03c37a0e9 Mon Sep 17 00:00:00 2001 From: xueyingyi <3245896298@qq.com> Date: Wed, 4 Dec 2024 20:16:30 +0800 Subject: [PATCH 08/21] feature(xyy):add HPT model and test_hpt --- ding/model/template/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ding/model/template/__init__.py b/ding/model/template/__init__.py index d79b4a1919..d9b05b59dd 100755 --- a/ding/model/template/__init__.py +++ b/ding/model/template/__init__.py @@ -24,9 +24,9 @@ from .vae import VanillaVAE from .decision_transformer import DecisionTransformer from .procedure_cloning import ProcedureCloningMCTS, ProcedureCloningBFS -from .hpt import HPT from .bcq import BCQ from .edac import EDAC +from .hpt import HPT from .qgpo import QGPO from .ebm import EBM, AutoregressiveEBM from .havac import HAVAC From 26f4c974746d3008931d04e1e7be2eda9e7c1836 Mon Sep 17 00:00:00 2001 From: xueyingyi <3245896298@qq.com> Date: Wed, 4 Dec 2024 20:26:11 +0800 Subject: [PATCH 09/21] feature(xyy):add HPT model and test_hpt --- ding/model/template/tests/test_hpt.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/ding/model/template/tests/test_hpt.py b/ding/model/template/tests/test_hpt.py index 1f77c308a7..5b17bbefe5 100644 --- a/ding/model/template/tests/test_hpt.py +++ b/ding/model/template/tests/test_hpt.py @@ -4,12 +4,12 @@ from ding.model.template.hpt import HPT from ding.torch_utils import is_differentiable - T, B = 3, 4 -obs_shape = [4, (8,), (4, 64, 64)] # Example observation shapes -act_shape = [3, (6,), [2, 3, 6]] # Example action shapes +obs_shape = [4, (8, ), (4, 64, 64)] # Example observation shapes +act_shape = [3, (6, ), [2, 3, 6]] # Example action shapes args = list(product(*[obs_shape, act_shape])) + @pytest.mark.unittest class TestHPT: @@ -39,4 +39,3 @@ def 
test_hpt(self, obs_shape, act_shape): for i, s in enumerate(act_shape): assert outputs[i].shape == (B, s) self.output_check(model, outputs) - \ No newline at end of file From 79aa427d3b4249e2cfde574089a8574dd05bdb7c Mon Sep 17 00:00:00 2001 From: xueyingyi <3245896298@qq.com> Date: Wed, 4 Dec 2024 20:41:45 +0800 Subject: [PATCH 10/21] feature(xyy):add HPT model and test_hpt --- ding/model/template/hpt.py | 12 ++++++++---- .../lunarlander/entry/lunarlander_hpt_example.py | 2 +- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/ding/model/template/hpt.py b/ding/model/template/hpt.py index a09f52da47..d55a91ab74 100644 --- a/ding/model/template/hpt.py +++ b/ding/model/template/hpt.py @@ -45,7 +45,8 @@ def __init__(self, state_dim: int, action_dim: int): def forward(self, x: torch.Tensor): """ Overview: - Forward pass of the HPT model. Computes latent tokens from the input state and passes them through the Dueling Head. + Forward pass of the HPT model. + Computes latent tokens from the input state and passes them through the Dueling Head. Arguments: - x (:obj:`torch.Tensor`): The input tensor representing the state. @@ -65,7 +66,8 @@ def forward(self, x: torch.Tensor): class PolicyStem(nn.Module): """ Overview: - The Policy Stem module is responsible for processing input features and generating latent tokens using a cross-attention mechanism. + The Policy Stem module is responsible for processing input features + and generating latent tokens using a cross-attention mechanism. It extracts features from the input and then applies cross-attention to generate a set of latent tokens. Interfaces: @@ -151,7 +153,8 @@ def device(self): class CrossAttention(nn.Module): """ Overview: - CrossAttention module used in the Perceiver IO model. It computes the attention between the query and context tensors, + CrossAttention module used in the Perceiver IO model. 
+ It computes the attention between the query and context tensors, and returns the output tensor after applying attention. Arguments: @@ -177,7 +180,8 @@ def __init__(self, query_dim: int, heads: int = 8, dim_head: int = 64, dropout: def forward(self, x: torch.Tensor, context: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: """ Overview: - Forward pass of the CrossAttention module. Computes the attention between the query and context tensors. + Forward pass of the CrossAttention module. + Computes the attention between the query and context tensors. Arguments: - x (:obj:`torch.Tensor`): The query input tensor. diff --git a/dizoo/box2d/lunarlander/entry/lunarlander_hpt_example.py b/dizoo/box2d/lunarlander/entry/lunarlander_hpt_example.py index 5bcf7be472..f5da04eaf9 100644 --- a/dizoo/box2d/lunarlander/entry/lunarlander_hpt_example.py +++ b/dizoo/box2d/lunarlander/entry/lunarlander_hpt_example.py @@ -38,7 +38,7 @@ def main(): # Migrating models to the GPU device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - # HPT introduces a Policy Stem module, which processes the input features using Cross-Attention and generates a set of latent tokens. + # HPT introduces a Policy Stem module, which processes the input features using Cross-Attention. 
model = HPT(cfg.policy.model.obs_shape, cfg.policy.model.action_shape).to(device) buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size) From 0709f83d4e845b8d71a56901192839a3707ba06a Mon Sep 17 00:00:00 2001 From: xueyingyi <3245896298@qq.com> Date: Wed, 4 Dec 2024 20:55:40 +0800 Subject: [PATCH 11/21] feature(xyy):add HPT model and test_hpt --- ding/model/template/hpt.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/ding/model/template/hpt.py b/ding/model/template/hpt.py index d55a91ab74..908f977774 100644 --- a/ding/model/template/hpt.py +++ b/ding/model/template/hpt.py @@ -12,8 +12,9 @@ class HPT(nn.Module): """ Overview: - The HPT model for reinforcement learning, which consists of a Policy Stem and a Dueling Head. The Policy Stem \ - utilizes cross-attention to process input data, and the Dueling Head computes Q-values for discrete action spaces. + The HPT model for reinforcement learning, which consists of a Policy Stem and a Dueling Head. + The Policy Stem utilizes cross-attention to process input data, + and the Dueling Head computes Q-values for discrete action spaces. Interfaces: __init__, forward From 32f147f3d4cb8955ee789cec582e1f248cbe7773 Mon Sep 17 00:00:00 2001 From: xueyingyi <3245896298@qq.com> Date: Wed, 4 Dec 2024 21:04:23 +0800 Subject: [PATCH 12/21] feature(xyy):add HPT model and test_hpt --- ding/model/template/hpt.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ding/model/template/hpt.py b/ding/model/template/hpt.py index 908f977774..4264a9487c 100644 --- a/ding/model/template/hpt.py +++ b/ding/model/template/hpt.py @@ -12,8 +12,8 @@ class HPT(nn.Module): """ Overview: - The HPT model for reinforcement learning, which consists of a Policy Stem and a Dueling Head. - The Policy Stem utilizes cross-attention to process input data, + The HPT model for reinforcement learning, which consists of a Policy Stem and a Dueling Head. 
+ The Policy Stem utilizes cross-attention to process input data, and the Dueling Head computes Q-values for discrete action spaces. Interfaces: From f3d55075eab73744c2e18da1d8b6e12f9d5fa0e9 Mon Sep 17 00:00:00 2001 From: xueyingyi <3245896298@qq.com> Date: Wed, 4 Dec 2024 21:32:21 +0800 Subject: [PATCH 13/21] feature(xyy):add HPT model and test_hpt --- ding/model/template/hpt.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/ding/model/template/hpt.py b/ding/model/template/hpt.py index 4264a9487c..daffce3c55 100644 --- a/ding/model/template/hpt.py +++ b/ding/model/template/hpt.py @@ -97,7 +97,8 @@ def init_cross_attn(self): """Initialize cross-attention module and learnable tokens.""" token_num = 16 self.tokens = nn.Parameter(torch.randn(1, token_num, 128) * INIT_CONST) - self.cross_attention = CrossAttention(128, heads=8, dim_head=64, dropout=0.1) + self.cross_attention = CrossAttention( + 128, heads=8, dim_head=64, dropout=0.1) def compute_latent(self, x: torch.Tensor) -> torch.Tensor: """ @@ -112,10 +113,12 @@ def compute_latent(self, x: torch.Tensor) -> torch.Tensor: """ # Using the Feature Extractor stem_feat = self.feature_extractor(x) - stem_feat = stem_feat.reshape(stem_feat.shape[0], -1, stem_feat.shape[-1]) # (B, N, 128) + stem_feat = stem_feat.reshape( + stem_feat.shape[0], -1, stem_feat.shape[-1]) # (B, N, 128) # Calculating latent tokens using CrossAttention stem_tokens = self.tokens.repeat(len(stem_feat), 1, 1) # (B, 16, 128) - stem_tokens = self.cross_attention(stem_tokens, stem_feat) # (B, 16, 128) + stem_tokens = self.cross_attention( + stem_tokens, stem_feat) # (B, 16, 128) return stem_tokens def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -195,7 +198,8 @@ def forward(self, x: torch.Tensor, context: torch.Tensor, mask: Optional[torch.T h = self.heads q = self.to_q(x) k, v = self.to_kv(context).chunk(2, dim=-1) - q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) + q, k, v = 
map(lambda t: rearrange( + t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale if mask is not None: From 25f1d2fe5e2b0e49669851365255f6b26f2db669 Mon Sep 17 00:00:00 2001 From: xueyingyi <3245896298@qq.com> Date: Wed, 4 Dec 2024 21:43:04 +0800 Subject: [PATCH 14/21] feature(xyy):add HPT model and test_hpt --- .../lunarlander/entry/lunarlander_dqn_example.py | 12 ++++++++---- .../lunarlander/entry/lunarlander_hpt_example.py | 15 ++++++++++----- 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/dizoo/box2d/lunarlander/entry/lunarlander_dqn_example.py b/dizoo/box2d/lunarlander/entry/lunarlander_dqn_example.py index 048554c310..a90344e6c0 100644 --- a/dizoo/box2d/lunarlander/entry/lunarlander_dqn_example.py +++ b/dizoo/box2d/lunarlander/entry/lunarlander_dqn_example.py @@ -19,16 +19,19 @@ def main(): logging.getLogger().setLevel(logging.INFO) - cfg = compile_config(main_config, create_cfg=create_config, auto=True, save_cfg=task.router.node_id == 0) + cfg = compile_config(main_config, create_cfg=create_config, + auto=True, save_cfg=task.router.node_id == 0) ding_init(cfg) with task.start(async_mode=False, ctx=OnlineRLContext()): collector_env = SubprocessEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("LunarLander-v2")) for _ in range(cfg.env.collector_env_num)], + env_fn=[lambda: DingEnvWrapper( + gym.make("LunarLander-v2")) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager ) evaluator_env = SubprocessEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("LunarLander-v2")) for _ in range(cfg.env.evaluator_env_num)], + env_fn=[lambda: DingEnvWrapper( + gym.make("LunarLander-v2")) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager ) @@ -38,7 +41,8 @@ def main(): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = DQN(**cfg.policy.model).to(device) - buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size) + buffer_ = 
DequeBuffer( + size=cfg.policy.other.replay_buffer.replay_buffer_size) # Pass the model into Policy policy = DQNPolicy(cfg.policy, model=model) diff --git a/dizoo/box2d/lunarlander/entry/lunarlander_hpt_example.py b/dizoo/box2d/lunarlander/entry/lunarlander_hpt_example.py index f5da04eaf9..1fe52cffdd 100644 --- a/dizoo/box2d/lunarlander/entry/lunarlander_hpt_example.py +++ b/dizoo/box2d/lunarlander/entry/lunarlander_hpt_example.py @@ -21,16 +21,19 @@ def main(): logging.getLogger().setLevel(logging.INFO) - cfg = compile_config(main_config, create_cfg=create_config, auto=True, save_cfg=task.router.node_id == 0) + cfg = compile_config(main_config, create_cfg=create_config, + auto=True, save_cfg=task.router.node_id == 0) ding_init(cfg) with task.start(async_mode=False, ctx=OnlineRLContext()): collector_env = SubprocessEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("LunarLander-v2")) for _ in range(cfg.env.collector_env_num)], + env_fn=[lambda: DingEnvWrapper( + gym.make("LunarLander-v2")) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager ) evaluator_env = SubprocessEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("LunarLander-v2")) for _ in range(cfg.env.evaluator_env_num)], + env_fn=[lambda: DingEnvWrapper( + gym.make("LunarLander-v2")) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager ) @@ -39,8 +42,10 @@ def main(): # Migrating models to the GPU device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # HPT introduces a Policy Stem module, which processes the input features using Cross-Attention. 
- model = HPT(cfg.policy.model.obs_shape, cfg.policy.model.action_shape).to(device) - buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size) + model = HPT(cfg.policy.model.obs_shape, + cfg.policy.model.action_shape).to(device) + buffer_ = DequeBuffer( + size=cfg.policy.other.replay_buffer.replay_buffer_size) # Pass the model into Policy policy = DQNPolicy(cfg.policy, model=model) From 188759bc18d5478d06a9cd4a9522eb50f3a54fc9 Mon Sep 17 00:00:00 2001 From: xueyingyi <3245896298@qq.com> Date: Wed, 4 Dec 2024 23:48:55 +0800 Subject: [PATCH 15/21] feature(xyy):add HPT model and test_hpt --- ding/model/template/hpt.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/ding/model/template/hpt.py b/ding/model/template/hpt.py index daffce3c55..4264a9487c 100644 --- a/ding/model/template/hpt.py +++ b/ding/model/template/hpt.py @@ -97,8 +97,7 @@ def init_cross_attn(self): """Initialize cross-attention module and learnable tokens.""" token_num = 16 self.tokens = nn.Parameter(torch.randn(1, token_num, 128) * INIT_CONST) - self.cross_attention = CrossAttention( - 128, heads=8, dim_head=64, dropout=0.1) + self.cross_attention = CrossAttention(128, heads=8, dim_head=64, dropout=0.1) def compute_latent(self, x: torch.Tensor) -> torch.Tensor: """ @@ -113,12 +112,10 @@ def compute_latent(self, x: torch.Tensor) -> torch.Tensor: """ # Using the Feature Extractor stem_feat = self.feature_extractor(x) - stem_feat = stem_feat.reshape( - stem_feat.shape[0], -1, stem_feat.shape[-1]) # (B, N, 128) + stem_feat = stem_feat.reshape(stem_feat.shape[0], -1, stem_feat.shape[-1]) # (B, N, 128) # Calculating latent tokens using CrossAttention stem_tokens = self.tokens.repeat(len(stem_feat), 1, 1) # (B, 16, 128) - stem_tokens = self.cross_attention( - stem_tokens, stem_feat) # (B, 16, 128) + stem_tokens = self.cross_attention(stem_tokens, stem_feat) # (B, 16, 128) return stem_tokens def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -198,8 
+195,7 @@ def forward(self, x: torch.Tensor, context: torch.Tensor, mask: Optional[torch.T h = self.heads q = self.to_q(x) k, v = self.to_kv(context).chunk(2, dim=-1) - q, k, v = map(lambda t: rearrange( - t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale if mask is not None: From cbe7dea79d85c9d4e58ce0f5968725a4fc058df3 Mon Sep 17 00:00:00 2001 From: xueyingyi <3245896298@qq.com> Date: Wed, 4 Dec 2024 23:59:30 +0800 Subject: [PATCH 16/21] feature(xyy):add HPT model and test_hpt --- ding/model/template/hpt.py | 38 ++++++++++++++++++++++++++------------ 1 file changed, 26 insertions(+), 12 deletions(-) diff --git a/ding/model/template/hpt.py b/ding/model/template/hpt.py index 4264a9487c..a652efee38 100644 --- a/ding/model/template/hpt.py +++ b/ding/model/template/hpt.py @@ -20,7 +20,8 @@ class HPT(nn.Module): __init__, forward .. note:: - The model is designed to be flexible and can be adapted for different input dimensions and action spaces. + The model is designed to be flexible and can be adapted + for different input dimensions and action spaces. """ def __init__(self, state_dim: int, action_dim: int): @@ -33,7 +34,8 @@ def __init__(self, state_dim: int, action_dim: int): - action_dim (:obj:`int`): The dimension of the action space. .. note:: - The Policy Stem is initialized with cross-attention, and the Dueling Head is set to process the resulting tokens. + The Policy Stem is initialized with cross-attention, + and the Dueling Head is set to process the resulting tokens. """ super(HPT, self).__init__() # Initialise Policy Stem @@ -69,13 +71,15 @@ class PolicyStem(nn.Module): Overview: The Policy Stem module is responsible for processing input features and generating latent tokens using a cross-attention mechanism. - It extracts features from the input and then applies cross-attention to generate a set of latent tokens. 
+ It extracts features from the input and then applies cross-attention + to generate a set of latent tokens. Interfaces: __init__, init_cross_attn, compute_latent, forward .. note:: - This module is inspired by the implementation in the Perceiver IO model and uses attention mechanisms for feature extraction. + This module is inspired by the implementation in the Perceiver IO model + and uses attention mechanisms for feature extraction. """ def __init__(self, feature_dim: int = 8, token_dim: int = 128, **kwargs): @@ -85,7 +89,8 @@ def __init__(self, feature_dim: int = 8, token_dim: int = 128, **kwargs): Arguments: - feature_dim (:obj:`int`): The dimension of the input features. - - token_dim (:obj:`int`): The dimension of the latent tokens generated by the attention mechanism. + - token_dim (:obj:`int`): The dimension of the latent tokens generated + by the attention mechanism. """ super().__init__() # Initialise the feature extraction module @@ -97,12 +102,14 @@ def init_cross_attn(self): """Initialize cross-attention module and learnable tokens.""" token_num = 16 self.tokens = nn.Parameter(torch.randn(1, token_num, 128) * INIT_CONST) - self.cross_attention = CrossAttention(128, heads=8, dim_head=64, dropout=0.1) + self.cross_attention = CrossAttention( + 128, heads=8, dim_head=64, dropout=0.1) def compute_latent(self, x: torch.Tensor) -> torch.Tensor: """ Overview: - Compute latent representations of the input data using the feature extractor and cross-attention. + Compute latent representations of the input data using + the feature extractor and cross-attention. Arguments: - x (:obj:`torch.Tensor`): Input tensor with shape [B, T, ..., F]. 
@@ -112,10 +119,12 @@ def compute_latent(self, x: torch.Tensor) -> torch.Tensor: """ # Using the Feature Extractor stem_feat = self.feature_extractor(x) - stem_feat = stem_feat.reshape(stem_feat.shape[0], -1, stem_feat.shape[-1]) # (B, N, 128) + stem_feat = stem_feat.reshape( + stem_feat.shape[0], -1, stem_feat.shape[-1]) # (B, N, 128) # Calculating latent tokens using CrossAttention stem_tokens = self.tokens.repeat(len(stem_feat), 1, 1) # (B, 16, 128) - stem_tokens = self.cross_attention(stem_tokens, stem_feat) # (B, 16, 128) + stem_tokens = self.cross_attention( + stem_tokens, stem_feat) # (B, 16, 128) return stem_tokens def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -165,7 +174,9 @@ class CrossAttention(nn.Module): dropout (:obj:`float`, optional): The dropout probability. Defaults to 0.0. """ - def __init__(self, query_dim: int, heads: int = 8, dim_head: int = 64, dropout: float = 0.0): + def __init__(self, query_dim: int, + heads: int = 8, + dim_head: int = 64, dropout: float = 0.0): super().__init__() inner_dim = dim_head * heads context_dim = query_dim @@ -178,7 +189,9 @@ def __init__(self, query_dim: int, heads: int = 8, dim_head: int = 64, dropout: self.dropout = nn.Dropout(dropout) - def forward(self, x: torch.Tensor, context: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward(self, x: torch.Tensor, + context: torch.Tensor, + mask: Optional[torch.Tensor] = None) -> torch.Tensor: """ Overview: Forward pass of the CrossAttention module. 
@@ -195,7 +208,8 @@ def forward(self, x: torch.Tensor, context: torch.Tensor, mask: Optional[torch.T h = self.heads q = self.to_q(x) k, v = self.to_kv(context).chunk(2, dim=-1) - q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) + q, k, v = map(lambda t: rearrange( + t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale if mask is not None: From 611343340964abe9e0e55e3f3741c7f565a5204e Mon Sep 17 00:00:00 2001 From: xueyingyi <3245896298@qq.com> Date: Thu, 5 Dec 2024 00:12:48 +0800 Subject: [PATCH 17/21] feature(xyy):add HPT model and test_hpt --- ding/model/template/hpt.py | 20 ++++------- dizoo/atari/config/serial/__init__.py | 2 +- .../serial/phoenix/phoenix_fqf_config.py | 2 +- dizoo/atari/config/serial/pong/__init__.py | 2 +- .../config/serial/pong/pong_fqf_config.py | 2 +- .../serial/pong/pong_gail_dqn_config.py | 2 +- dizoo/atari/config/serial/qbert/__init__.py | 2 +- .../config/serial/qbert/qbert_fqf_config.py | 2 +- .../config/serial/spaceinvaders/__init__.py | 2 +- .../spaceinvaders/spaceinvaders_fqf_config.py | 2 +- dizoo/beergame/entry/beergame_eval.py | 2 +- .../config/lunarlander_sqil_config.py | 2 +- .../entry/lunarlander_dqn_example.py | 12 +++---- .../entry/lunarlander_hpt_example.py | 15 +++------ .../cartpole/config/cartpole_fqf_config.py | 2 +- .../pendulum/config/__init__.py | 2 +- .../config/dmc2gym_sac_pixel_config.py | 2 +- dizoo/minigrid/__init__.py | 2 +- .../config/minigrid_icm_onppo_config.py | 2 +- dizoo/minigrid/envs/__init__.py | 2 +- .../mujoco/config/halfcheetah_ddpg_config.py | 2 +- .../config/ptz_pistonball_qmix_config.py | 8 ++--- .../envs/petting_zoo_pistonball_env.py | 33 ++++++++----------- .../envs/test_petting_zoo_pistonball_env.py | 2 +- 24 files changed, 51 insertions(+), 75 deletions(-) diff --git a/ding/model/template/hpt.py b/ding/model/template/hpt.py index a652efee38..27d194ea88 100644 --- a/ding/model/template/hpt.py +++ 
b/ding/model/template/hpt.py @@ -102,8 +102,7 @@ def init_cross_attn(self): """Initialize cross-attention module and learnable tokens.""" token_num = 16 self.tokens = nn.Parameter(torch.randn(1, token_num, 128) * INIT_CONST) - self.cross_attention = CrossAttention( - 128, heads=8, dim_head=64, dropout=0.1) + self.cross_attention = CrossAttention(128, heads=8, dim_head=64, dropout=0.1) def compute_latent(self, x: torch.Tensor) -> torch.Tensor: """ @@ -119,12 +118,10 @@ def compute_latent(self, x: torch.Tensor) -> torch.Tensor: """ # Using the Feature Extractor stem_feat = self.feature_extractor(x) - stem_feat = stem_feat.reshape( - stem_feat.shape[0], -1, stem_feat.shape[-1]) # (B, N, 128) + stem_feat = stem_feat.reshape(stem_feat.shape[0], -1, stem_feat.shape[-1]) # (B, N, 128) # Calculating latent tokens using CrossAttention stem_tokens = self.tokens.repeat(len(stem_feat), 1, 1) # (B, 16, 128) - stem_tokens = self.cross_attention( - stem_tokens, stem_feat) # (B, 16, 128) + stem_tokens = self.cross_attention(stem_tokens, stem_feat) # (B, 16, 128) return stem_tokens def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -174,9 +171,7 @@ class CrossAttention(nn.Module): dropout (:obj:`float`, optional): The dropout probability. Defaults to 0.0. """ - def __init__(self, query_dim: int, - heads: int = 8, - dim_head: int = 64, dropout: float = 0.0): + def __init__(self, query_dim: int, heads: int = 8, dim_head: int = 64, dropout: float = 0.0): super().__init__() inner_dim = dim_head * heads context_dim = query_dim @@ -189,9 +184,7 @@ def __init__(self, query_dim: int, self.dropout = nn.Dropout(dropout) - def forward(self, x: torch.Tensor, - context: torch.Tensor, - mask: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward(self, x: torch.Tensor, context: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: """ Overview: Forward pass of the CrossAttention module. 
@@ -208,8 +201,7 @@ def forward(self, x: torch.Tensor, h = self.heads q = self.to_q(x) k, v = self.to_kv(context).chunk(2, dim=-1) - q, k, v = map(lambda t: rearrange( - t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale if mask is not None: diff --git a/dizoo/atari/config/serial/__init__.py b/dizoo/atari/config/serial/__init__.py index 1ecd50235d..ce0ea3b805 100644 --- a/dizoo/atari/config/serial/__init__.py +++ b/dizoo/atari/config/serial/__init__.py @@ -2,4 +2,4 @@ from dizoo.atari.config.serial.pong import * from dizoo.atari.config.serial.qbert import * from dizoo.atari.config.serial.spaceinvaders import * -from dizoo.atari.config.serial.asterix import * \ No newline at end of file +from dizoo.atari.config.serial.asterix import * diff --git a/dizoo/atari/config/serial/phoenix/phoenix_fqf_config.py b/dizoo/atari/config/serial/phoenix/phoenix_fqf_config.py index bc0273ad56..736a7a6cff 100644 --- a/dizoo/atari/config/serial/phoenix/phoenix_fqf_config.py +++ b/dizoo/atari/config/serial/phoenix/phoenix_fqf_config.py @@ -59,4 +59,4 @@ if __name__ == '__main__': # or you can enter `ding -m serial -c phoenix_fqf_config.py -s 0` from ding.entry import serial_pipeline - serial_pipeline((main_config, create_config), seed=0) \ No newline at end of file + serial_pipeline((main_config, create_config), seed=0) diff --git a/dizoo/atari/config/serial/pong/__init__.py b/dizoo/atari/config/serial/pong/__init__.py index 5ce3db9a5b..93cba0f3ac 100644 --- a/dizoo/atari/config/serial/pong/__init__.py +++ b/dizoo/atari/config/serial/pong/__init__.py @@ -1,3 +1,3 @@ from .pong_dqn_config import pong_dqn_config, pong_dqn_create_config from .pong_dqn_envpool_config import pong_dqn_envpool_config, pong_dqn_envpool_create_config -from .pong_dqfd_config import pong_dqfd_config, pong_dqfd_create_config \ No newline at end of file +from .pong_dqfd_config 
import pong_dqfd_config, pong_dqfd_create_config diff --git a/dizoo/atari/config/serial/pong/pong_fqf_config.py b/dizoo/atari/config/serial/pong/pong_fqf_config.py index 25a788aa0b..e1dbc346e9 100644 --- a/dizoo/atari/config/serial/pong/pong_fqf_config.py +++ b/dizoo/atari/config/serial/pong/pong_fqf_config.py @@ -59,4 +59,4 @@ if __name__ == '__main__': # or you can enter `ding -m serial -c pong_fqf_config.py -s 0` from ding.entry import serial_pipeline - serial_pipeline((main_config, create_config), seed=0) \ No newline at end of file + serial_pipeline((main_config, create_config), seed=0) diff --git a/dizoo/atari/config/serial/pong/pong_gail_dqn_config.py b/dizoo/atari/config/serial/pong/pong_gail_dqn_config.py index 505b75b626..201ea414ff 100644 --- a/dizoo/atari/config/serial/pong/pong_gail_dqn_config.py +++ b/dizoo/atari/config/serial/pong/pong_gail_dqn_config.py @@ -87,4 +87,4 @@ max_env_step=1000000, seed=0, collect_data=True - ) \ No newline at end of file + ) diff --git a/dizoo/atari/config/serial/qbert/__init__.py b/dizoo/atari/config/serial/qbert/__init__.py index 5032c3a751..ee170aba99 100644 --- a/dizoo/atari/config/serial/qbert/__init__.py +++ b/dizoo/atari/config/serial/qbert/__init__.py @@ -1,2 +1,2 @@ from .qbert_dqn_config import qbert_dqn_config, qbert_dqn_create_config -from .qbert_dqfd_config import qbert_dqfd_config, qbert_dqfd_create_config \ No newline at end of file +from .qbert_dqfd_config import qbert_dqfd_config, qbert_dqfd_create_config diff --git a/dizoo/atari/config/serial/qbert/qbert_fqf_config.py b/dizoo/atari/config/serial/qbert/qbert_fqf_config.py index 3241c924f3..af1c7f5035 100644 --- a/dizoo/atari/config/serial/qbert/qbert_fqf_config.py +++ b/dizoo/atari/config/serial/qbert/qbert_fqf_config.py @@ -61,4 +61,4 @@ if __name__ == '__main__': # or you can enter `ding -m serial -c qbert_fqf_config.py -s 0` from ding.entry import serial_pipeline - serial_pipeline((main_config, create_config), seed=0) \ No newline at end of file + 
serial_pipeline((main_config, create_config), seed=0) diff --git a/dizoo/atari/config/serial/spaceinvaders/__init__.py b/dizoo/atari/config/serial/spaceinvaders/__init__.py index f4ff222aa6..8be7e58bfe 100644 --- a/dizoo/atari/config/serial/spaceinvaders/__init__.py +++ b/dizoo/atari/config/serial/spaceinvaders/__init__.py @@ -1,2 +1,2 @@ from .spaceinvaders_dqn_config import spaceinvaders_dqn_config, spaceinvaders_dqn_create_config -from .spaceinvaders_dqfd_config import spaceinvaders_dqfd_config, spaceinvaders_dqfd_create_config \ No newline at end of file +from .spaceinvaders_dqfd_config import spaceinvaders_dqfd_config, spaceinvaders_dqfd_create_config diff --git a/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_fqf_config.py b/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_fqf_config.py index 95df0d4657..ce3cf7c52b 100644 --- a/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_fqf_config.py +++ b/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_fqf_config.py @@ -60,4 +60,4 @@ if __name__ == '__main__': # or you can enter `ding -m serial -c spaceinvaders_fqf_config.py -s 0` from ding.entry import serial_pipeline - serial_pipeline((main_config, create_config), seed=0) \ No newline at end of file + serial_pipeline((main_config, create_config), seed=0) diff --git a/dizoo/beergame/entry/beergame_eval.py b/dizoo/beergame/entry/beergame_eval.py index 5299107e78..c37a076f67 100644 --- a/dizoo/beergame/entry/beergame_eval.py +++ b/dizoo/beergame/entry/beergame_eval.py @@ -39,4 +39,4 @@ def main(cfg, seed=0): if __name__ == "__main__": beergame_ppo_config.exp_name = 'beergame_evaluate' - main(beergame_ppo_config) \ No newline at end of file + main(beergame_ppo_config) diff --git a/dizoo/box2d/lunarlander/config/lunarlander_sqil_config.py b/dizoo/box2d/lunarlander/config/lunarlander_sqil_config.py index 638c2d2981..b5deffc897 100644 --- a/dizoo/box2d/lunarlander/config/lunarlander_sqil_config.py +++ 
b/dizoo/box2d/lunarlander/config/lunarlander_sqil_config.py @@ -63,4 +63,4 @@ from dizoo.box2d.lunarlander.config import lunarlander_dqn_config, lunarlander_dqn_create_config expert_main_config = lunarlander_dqn_config expert_create_config = lunarlander_dqn_create_config - serial_pipeline_sqil([main_config, create_config], [expert_main_config, expert_create_config], seed=0) \ No newline at end of file + serial_pipeline_sqil([main_config, create_config], [expert_main_config, expert_create_config], seed=0) diff --git a/dizoo/box2d/lunarlander/entry/lunarlander_dqn_example.py b/dizoo/box2d/lunarlander/entry/lunarlander_dqn_example.py index a90344e6c0..048554c310 100644 --- a/dizoo/box2d/lunarlander/entry/lunarlander_dqn_example.py +++ b/dizoo/box2d/lunarlander/entry/lunarlander_dqn_example.py @@ -19,19 +19,16 @@ def main(): logging.getLogger().setLevel(logging.INFO) - cfg = compile_config(main_config, create_cfg=create_config, - auto=True, save_cfg=task.router.node_id == 0) + cfg = compile_config(main_config, create_cfg=create_config, auto=True, save_cfg=task.router.node_id == 0) ding_init(cfg) with task.start(async_mode=False, ctx=OnlineRLContext()): collector_env = SubprocessEnvManagerV2( - env_fn=[lambda: DingEnvWrapper( - gym.make("LunarLander-v2")) for _ in range(cfg.env.collector_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("LunarLander-v2")) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager ) evaluator_env = SubprocessEnvManagerV2( - env_fn=[lambda: DingEnvWrapper( - gym.make("LunarLander-v2")) for _ in range(cfg.env.evaluator_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("LunarLander-v2")) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager ) @@ -41,8 +38,7 @@ def main(): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = DQN(**cfg.policy.model).to(device) - buffer_ = DequeBuffer( - size=cfg.policy.other.replay_buffer.replay_buffer_size) + buffer_ = 
DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size) # Pass the model into Policy policy = DQNPolicy(cfg.policy, model=model) diff --git a/dizoo/box2d/lunarlander/entry/lunarlander_hpt_example.py b/dizoo/box2d/lunarlander/entry/lunarlander_hpt_example.py index 1fe52cffdd..f5da04eaf9 100644 --- a/dizoo/box2d/lunarlander/entry/lunarlander_hpt_example.py +++ b/dizoo/box2d/lunarlander/entry/lunarlander_hpt_example.py @@ -21,19 +21,16 @@ def main(): logging.getLogger().setLevel(logging.INFO) - cfg = compile_config(main_config, create_cfg=create_config, - auto=True, save_cfg=task.router.node_id == 0) + cfg = compile_config(main_config, create_cfg=create_config, auto=True, save_cfg=task.router.node_id == 0) ding_init(cfg) with task.start(async_mode=False, ctx=OnlineRLContext()): collector_env = SubprocessEnvManagerV2( - env_fn=[lambda: DingEnvWrapper( - gym.make("LunarLander-v2")) for _ in range(cfg.env.collector_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("LunarLander-v2")) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager ) evaluator_env = SubprocessEnvManagerV2( - env_fn=[lambda: DingEnvWrapper( - gym.make("LunarLander-v2")) for _ in range(cfg.env.evaluator_env_num)], + env_fn=[lambda: DingEnvWrapper(gym.make("LunarLander-v2")) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager ) @@ -42,10 +39,8 @@ def main(): # Migrating models to the GPU device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # HPT introduces a Policy Stem module, which processes the input features using Cross-Attention. 
- model = HPT(cfg.policy.model.obs_shape, - cfg.policy.model.action_shape).to(device) - buffer_ = DequeBuffer( - size=cfg.policy.other.replay_buffer.replay_buffer_size) + model = HPT(cfg.policy.model.obs_shape, cfg.policy.model.action_shape).to(device) + buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size) # Pass the model into Policy policy = DQNPolicy(cfg.policy, model=model) diff --git a/dizoo/classic_control/cartpole/config/cartpole_fqf_config.py b/dizoo/classic_control/cartpole/config/cartpole_fqf_config.py index ca65670d81..3c1103ba75 100644 --- a/dizoo/classic_control/cartpole/config/cartpole_fqf_config.py +++ b/dizoo/classic_control/cartpole/config/cartpole_fqf_config.py @@ -60,4 +60,4 @@ if __name__ == '__main__': # or you can enter `ding -m serial -c cartpole_fqf_config.py -s 0` from ding.entry import serial_pipeline - serial_pipeline((main_config, create_config), seed=0) \ No newline at end of file + serial_pipeline((main_config, create_config), seed=0) diff --git a/dizoo/classic_control/pendulum/config/__init__.py b/dizoo/classic_control/pendulum/config/__init__.py index e7c2988f06..642386797f 100644 --- a/dizoo/classic_control/pendulum/config/__init__.py +++ b/dizoo/classic_control/pendulum/config/__init__.py @@ -4,4 +4,4 @@ from .pendulum_d4pg_config import pendulum_d4pg_config, pendulum_d4pg_create_config from .pendulum_ppo_config import pendulum_ppo_config, pendulum_ppo_create_config from .pendulum_td3_bc_config import pendulum_td3_bc_config, pendulum_td3_bc_create_config -from .pendulum_td3_data_generation_config import pendulum_td3_generation_config, pendulum_td3_generation_create_config \ No newline at end of file +from .pendulum_td3_data_generation_config import pendulum_td3_generation_config, pendulum_td3_generation_create_config diff --git a/dizoo/dmc2gym/config/dmc2gym_sac_pixel_config.py b/dizoo/dmc2gym/config/dmc2gym_sac_pixel_config.py index c0155b1ebd..88c4e6a9d7 100644 --- 
a/dizoo/dmc2gym/config/dmc2gym_sac_pixel_config.py +++ b/dizoo/dmc2gym/config/dmc2gym_sac_pixel_config.py @@ -76,4 +76,4 @@ if __name__ == "__main__": # or you can enter `ding -m serial -c ant_sac_config.py -s 0 --env-step 1e7` from ding.entry import serial_pipeline - serial_pipeline((main_config, create_config), seed=0) \ No newline at end of file + serial_pipeline((main_config, create_config), seed=0) diff --git a/dizoo/minigrid/__init__.py b/dizoo/minigrid/__init__.py index db6673867b..b05d8f7ceb 100644 --- a/dizoo/minigrid/__init__.py +++ b/dizoo/minigrid/__init__.py @@ -12,4 +12,4 @@ register(id='MiniGrid-AKTDT-19x19-3-v0', entry_point='dizoo.minigrid.envs:AppleKeyToDoorTreasure_19x19_3') -register(id='MiniGrid-NoisyTV-v0', entry_point='dizoo.minigrid.envs:NoisyTVEnv') \ No newline at end of file +register(id='MiniGrid-NoisyTV-v0', entry_point='dizoo.minigrid.envs:NoisyTVEnv') diff --git a/dizoo/minigrid/config/minigrid_icm_onppo_config.py b/dizoo/minigrid/config/minigrid_icm_onppo_config.py index dc21fe5fc4..9e56b30cbe 100644 --- a/dizoo/minigrid/config/minigrid_icm_onppo_config.py +++ b/dizoo/minigrid/config/minigrid_icm_onppo_config.py @@ -79,4 +79,4 @@ if __name__ == "__main__": # or you can enter `ding -m serial -c minigrid_icm_onppo_config.py -s 0` from ding.entry import serial_pipeline_reward_model_onpolicy - serial_pipeline_reward_model_onpolicy([main_config, create_config], seed=0, max_env_step=int(10e6)) \ No newline at end of file + serial_pipeline_reward_model_onpolicy([main_config, create_config], seed=0, max_env_step=int(10e6)) diff --git a/dizoo/minigrid/envs/__init__.py b/dizoo/minigrid/envs/__init__.py index 02d73004ff..46eb0846e1 100644 --- a/dizoo/minigrid/envs/__init__.py +++ b/dizoo/minigrid/envs/__init__.py @@ -1,3 +1,3 @@ from .minigrid_env import MiniGridEnv from dizoo.minigrid.envs.app_key_to_door_treasure import AppleKeyToDoorTreasure, AppleKeyToDoorTreasure_13x13, AppleKeyToDoorTreasure_19x19, AppleKeyToDoorTreasure_13x13_1, 
AppleKeyToDoorTreasure_19x19_3, AppleKeyToDoorTreasure_7x7_1 -from dizoo.minigrid.envs.noisy_tv import NoisyTVEnv \ No newline at end of file +from dizoo.minigrid.envs.noisy_tv import NoisyTVEnv diff --git a/dizoo/mujoco/config/halfcheetah_ddpg_config.py b/dizoo/mujoco/config/halfcheetah_ddpg_config.py index 640717b8c6..c9031d36f4 100644 --- a/dizoo/mujoco/config/halfcheetah_ddpg_config.py +++ b/dizoo/mujoco/config/halfcheetah_ddpg_config.py @@ -62,4 +62,4 @@ if __name__ == "__main__": # or you can enter `ding -m serial -c halfcheetah_ddpg_config.py -s 0` from ding.entry import serial_pipeline - serial_pipeline((main_config, create_config), seed=0) \ No newline at end of file + serial_pipeline((main_config, create_config), seed=0) diff --git a/dizoo/petting_zoo/config/ptz_pistonball_qmix_config.py b/dizoo/petting_zoo/config/ptz_pistonball_qmix_config.py index 3816db6ef5..b6bfc0818e 100644 --- a/dizoo/petting_zoo/config/ptz_pistonball_qmix_config.py +++ b/dizoo/petting_zoo/config/ptz_pistonball_qmix_config.py @@ -18,7 +18,7 @@ evaluator_env_num=evaluator_env_num, n_evaluator_episode=evaluator_env_num, stop_value=1e6, - manager=dict(shared_memory=False,), + manager=dict(shared_memory=False, ), ), policy=dict( cuda=True, @@ -52,9 +52,7 @@ end=0.05, decay=100000, ), - replay_buffer=dict( - replay_buffer_size=5000, - ), + replay_buffer=dict(replay_buffer_size=5000, ), ), ), ) @@ -76,4 +74,4 @@ if __name__ == '__main__': # or you can enter `ding -m serial -c ptz_pistonball_qmix_config.py -s 0` from ding.entry import serial_pipeline - serial_pipeline((main_config, create_config), seed=0, max_env_step=max_env_step) \ No newline at end of file + serial_pipeline((main_config, create_config), seed=0, max_env_step=max_env_step) diff --git a/dizoo/petting_zoo/envs/petting_zoo_pistonball_env.py b/dizoo/petting_zoo/envs/petting_zoo_pistonball_env.py index 775af37d6a..d21ed049db 100644 --- a/dizoo/petting_zoo/envs/petting_zoo_pistonball_env.py +++ 
b/dizoo/petting_zoo/envs/petting_zoo_pistonball_env.py @@ -42,9 +42,7 @@ def reset(self) -> np.ndarray: # Initialize the pistonball environment parallel_env = pistonball_v6.parallel_env self._env = parallel_env( - n_pistons=self._num_pistons, - continuous=self._continuous_actions, - max_cycles=self._max_cycles + n_pistons=self._num_pistons, continuous=self._continuous_actions, max_cycles=self._max_cycles ) self._env.reset() self._agents = self._env.agents @@ -72,14 +70,16 @@ def reset(self) -> np.ndarray: self._reward_space = gym.spaces.Dict( { - agent: gym.spaces.Box(low=float('-inf'), high=float('inf'), shape=(1,), dtype=np.float32) + agent: gym.spaces.Box(low=float('-inf'), high=float('inf'), shape=(1, ), dtype=np.float32) for agent in self._agents } ) if self._replay_path is not None: self._env.render_mode = 'rgb_array' - self._env = PTZRecordVideo(self._env, self._replay_path, name_prefix=f'rl-video-{id(self)}', disable_logger=True) + self._env = PTZRecordVideo( + self._env, self._replay_path, name_prefix=f'rl-video-{id(self)}', disable_logger=True + ) self._init_flag = True if hasattr(self, '_seed'): @@ -123,7 +123,9 @@ def step(self, action: np.ndarray) -> BaseEnvTimestep: action = self._process_action(action) if self._act_scale: for agent in self._agents: - action[agent] = affine_transform(action[agent], min_val=self.action_space[agent].low, max_val=self.action_space[agent].high) + action[agent] = affine_transform( + action[agent], min_val=self.action_space[agent].low, max_val=self.action_space[agent].high + ) obs, rew, done, trunc, info = self._env.step(action) obs_n = self._process_obs(obs) @@ -132,14 +134,13 @@ def step(self, action: np.ndarray) -> BaseEnvTimestep: if self.normalize_reward: # TODO: more elegant scale factor - rew_n = rew_n / (self._num_pistons*50) + rew_n = rew_n / (self._num_pistons * 50) self._eval_episode_return += rew_n.item() done_n = reduce(lambda x, y: x and y, done.values()) or self._step_count >= self._max_cycles if done_n: 
info['eval_episode_return'] = self._eval_episode_return - return BaseEnvTimestep(obs_n, rew_n, done_n, info) @@ -157,8 +158,7 @@ def _process_obs(self, obs: Dict[str, np.ndarray]) -> np.ndarray: """ # Process agent observations, transpose if channel_first is True obs = np.array( - [np.transpose(obs[agent], (2, 0, 1)) if self._channel_first else obs[agent] - for agent in self._agents], + [np.transpose(obs[agent], (2, 0, 1)) if self._channel_first else obs[agent] for agent in self._agents], dtype=np.uint8 ) @@ -167,9 +167,7 @@ def _process_obs(self, obs: Dict[str, np.ndarray]) -> np.ndarray: return obs # Initialize return dictionary - ret = { - 'agent_state': (obs / 255.0).astype(np.float32) - } + ret = {'agent_state': (obs / 255.0).astype(np.float32)} # Obtain global state, transpose if channel_first is True global_state = self._env.state() @@ -179,10 +177,7 @@ def _process_obs(self, obs: Dict[str, np.ndarray]) -> np.ndarray: # Handle agent-specific global states by repeating the global state for each agent if self._agent_specific_global_state: - ret['global_state'] = np.tile( - np.expand_dims(ret['global_state'], axis=0), - (self._num_pistons, 1, 1, 1) - ) + ret['global_state'] = np.tile(np.expand_dims(ret['global_state'], axis=0), (self._num_pistons, 1, 1, 1)) # Set action mask for each agent ret['action_mask'] = np.ones((self._num_pistons, *self._action_dim), dtype=np.float32) @@ -239,6 +234,6 @@ def create_evaluator_env_cfg(cfg: dict) -> List[dict]: cfg = copy.deepcopy(cfg) cfg.normalize_reward = False return [cfg for _ in range(evaluator_env_num)] - + def __repr__(self) -> str: - return "DI-engine PettingZoo Pistonball Env" \ No newline at end of file + return "DI-engine PettingZoo Pistonball Env" diff --git a/dizoo/petting_zoo/envs/test_petting_zoo_pistonball_env.py b/dizoo/petting_zoo/envs/test_petting_zoo_pistonball_env.py index ea5ac988d7..1e9a88292d 100644 --- a/dizoo/petting_zoo/envs/test_petting_zoo_pistonball_env.py +++ 
b/dizoo/petting_zoo/envs/test_petting_zoo_pistonball_env.py @@ -103,4 +103,4 @@ def test_agent_specific_global_state(self): assert isinstance(timestep.done, bool), timestep.done assert isinstance(timestep.reward, np.ndarray), timestep.reward print(env.observation_space, env.action_space, env.reward_space) - env.close() \ No newline at end of file + env.close() From 31f2398547e6591122b332d91d7c850838c5fb00 Mon Sep 17 00:00:00 2001 From: xueyingyi <3245896298@qq.com> Date: Thu, 5 Dec 2024 00:17:51 +0800 Subject: [PATCH 18/21] feature(xyy):add HPT model and test_hpt --- ding/model/template/hpt.py | 34 ++++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/ding/model/template/hpt.py b/ding/model/template/hpt.py index 27d194ea88..c7cb84c73e 100644 --- a/ding/model/template/hpt.py +++ b/ding/model/template/hpt.py @@ -20,7 +20,7 @@ class HPT(nn.Module): __init__, forward .. note:: - The model is designed to be flexible and can be adapted + The model is designed to be flexible and can be adapted for different input dimensions and action spaces. """ @@ -34,7 +34,7 @@ def __init__(self, state_dim: int, action_dim: int): - action_dim (:obj:`int`): The dimension of the action space. .. note:: - The Policy Stem is initialized with cross-attention, + The Policy Stem is initialized with cross-attention, and the Dueling Head is set to process the resulting tokens. """ super(HPT, self).__init__() @@ -48,7 +48,7 @@ def __init__(self, state_dim: int, action_dim: int): def forward(self, x: torch.Tensor): """ Overview: - Forward pass of the HPT model. + Forward pass of the HPT model. Computes latent tokens from the input state and passes them through the Dueling Head. Arguments: @@ -71,14 +71,14 @@ class PolicyStem(nn.Module): Overview: The Policy Stem module is responsible for processing input features and generating latent tokens using a cross-attention mechanism. 
- It extracts features from the input and then applies cross-attention + It extracts features from the input and then applies cross-attention to generate a set of latent tokens. Interfaces: __init__, init_cross_attn, compute_latent, forward .. note:: - This module is inspired by the implementation in the Perceiver IO model + This module is inspired by the implementation in the Perceiver IO model and uses attention mechanisms for feature extraction. """ @@ -89,7 +89,7 @@ def __init__(self, feature_dim: int = 8, token_dim: int = 128, **kwargs): Arguments: - feature_dim (:obj:`int`): The dimension of the input features. - - token_dim (:obj:`int`): The dimension of the latent tokens generated + - token_dim (:obj:`int`): The dimension of the latent tokens generated by the attention mechanism. """ super().__init__() @@ -102,12 +102,13 @@ def init_cross_attn(self): """Initialize cross-attention module and learnable tokens.""" token_num = 16 self.tokens = nn.Parameter(torch.randn(1, token_num, 128) * INIT_CONST) - self.cross_attention = CrossAttention(128, heads=8, dim_head=64, dropout=0.1) + self.cross_attention = CrossAttention( + 128, heads=8, dim_head=64, dropout=0.1) def compute_latent(self, x: torch.Tensor) -> torch.Tensor: """ Overview: - Compute latent representations of the input data using + Compute latent representations of the input data using the feature extractor and cross-attention. 
Arguments: @@ -118,10 +119,12 @@ def compute_latent(self, x: torch.Tensor) -> torch.Tensor: """ # Using the Feature Extractor stem_feat = self.feature_extractor(x) - stem_feat = stem_feat.reshape(stem_feat.shape[0], -1, stem_feat.shape[-1]) # (B, N, 128) + stem_feat = stem_feat.reshape( + stem_feat.shape[0], -1, stem_feat.shape[-1]) # (B, N, 128) # Calculating latent tokens using CrossAttention stem_tokens = self.tokens.repeat(len(stem_feat), 1, 1) # (B, 16, 128) - stem_tokens = self.cross_attention(stem_tokens, stem_feat) # (B, 16, 128) + stem_tokens = self.cross_attention( + stem_tokens, stem_feat) # (B, 16, 128) return stem_tokens def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -160,7 +163,7 @@ def device(self): class CrossAttention(nn.Module): """ Overview: - CrossAttention module used in the Perceiver IO model. + CrossAttention module used in the Perceiver IO model. It computes the attention between the query and context tensors, and returns the output tensor after applying attention. @@ -171,7 +174,8 @@ class CrossAttention(nn.Module): dropout (:obj:`float`, optional): The dropout probability. Defaults to 0.0. """ - def __init__(self, query_dim: int, heads: int = 8, dim_head: int = 64, dropout: float = 0.0): + def __init__(self, query_dim: int, heads: int = 8, + dim_head: int = 64, dropout: float = 0.0): super().__init__() inner_dim = dim_head * heads context_dim = query_dim @@ -184,7 +188,8 @@ def __init__(self, query_dim: int, heads: int = 8, dim_head: int = 64, dropout: self.dropout = nn.Dropout(dropout) - def forward(self, x: torch.Tensor, context: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward(self, x: torch.Tensor, context: torch.Tensor, + mask: Optional[torch.Tensor] = None) -> torch.Tensor: """ Overview: Forward pass of the CrossAttention module. 
@@ -201,7 +206,8 @@ def forward(self, x: torch.Tensor, context: torch.Tensor, mask: Optional[torch.T h = self.heads q = self.to_q(x) k, v = self.to_kv(context).chunk(2, dim=-1) - q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) + q, k, v = map(lambda t: rearrange( + t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale if mask is not None: From 9d30de84a8f70febb65c4e4bfc68b55e84b91b63 Mon Sep 17 00:00:00 2001 From: xueyingyi <3245896298@qq.com> Date: Thu, 5 Dec 2024 00:55:11 +0800 Subject: [PATCH 19/21] feature(xyy):add HPT model and test_hpt --- ding/model/template/tests/test_hpt.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/ding/model/template/tests/test_hpt.py b/ding/model/template/tests/test_hpt.py index 5b17bbefe5..bfb7844f32 100644 --- a/ding/model/template/tests/test_hpt.py +++ b/ding/model/template/tests/test_hpt.py @@ -5,8 +5,8 @@ from ding.torch_utils import is_differentiable T, B = 3, 4 -obs_shape = [4, (8, ), (4, 64, 64)] # Example observation shapes -act_shape = [3, (6, ), [2, 3, 6]] # Example action shapes +obs_shape = [4, (8, ), (4, 64, 64)] +act_shape = [3, (6, ), [2, 3, 6]] args = list(product(*[obs_shape, act_shape])) @@ -26,11 +26,21 @@ def output_check(self, model, outputs): def test_hpt(self, obs_shape, act_shape): if isinstance(obs_shape, int): inputs = torch.randn(B, obs_shape) + state_dim = obs_shape else: inputs = torch.randn(B, *obs_shape) - model = HPT(state_dim=obs_shape, action_dim=act_shape) + state_dim = obs_shape[0] + + if isinstance(act_shape, int): + action_dim = act_shape + else: + action_dim = len(act_shape) + + model = HPT(state_dim=state_dim, action_dim=action_dim) outputs = model(inputs) + assert isinstance(outputs, torch.Tensor) + if isinstance(act_shape, int): assert outputs.shape == (B, act_shape) elif len(act_shape) == 1: @@ -38,4 +48,5 @@ def test_hpt(self, obs_shape, act_shape): else: for i, s 
in enumerate(act_shape): assert outputs[i].shape == (B, s) + self.output_check(model, outputs) From 7afb21d50a872335c157dd8a6fd64b3f53b0c884 Mon Sep 17 00:00:00 2001 From: xueyingyi <3245896298@qq.com> Date: Thu, 5 Dec 2024 01:03:19 +0800 Subject: [PATCH 20/21] feature(xyy):add HPT model and test_hpt --- ding/model/template/hpt.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/ding/model/template/hpt.py b/ding/model/template/hpt.py index c7cb84c73e..cf38a93f39 100644 --- a/ding/model/template/hpt.py +++ b/ding/model/template/hpt.py @@ -102,8 +102,7 @@ def init_cross_attn(self): """Initialize cross-attention module and learnable tokens.""" token_num = 16 self.tokens = nn.Parameter(torch.randn(1, token_num, 128) * INIT_CONST) - self.cross_attention = CrossAttention( - 128, heads=8, dim_head=64, dropout=0.1) + self.cross_attention = CrossAttention(128, heads=8, dim_head=64, dropout=0.1) def compute_latent(self, x: torch.Tensor) -> torch.Tensor: """ @@ -119,12 +118,10 @@ def compute_latent(self, x: torch.Tensor) -> torch.Tensor: """ # Using the Feature Extractor stem_feat = self.feature_extractor(x) - stem_feat = stem_feat.reshape( - stem_feat.shape[0], -1, stem_feat.shape[-1]) # (B, N, 128) + stem_feat = stem_feat.reshape(stem_feat.shape[0], -1, stem_feat.shape[-1]) # (B, N, 128) # Calculating latent tokens using CrossAttention stem_tokens = self.tokens.repeat(len(stem_feat), 1, 1) # (B, 16, 128) - stem_tokens = self.cross_attention( - stem_tokens, stem_feat) # (B, 16, 128) + stem_tokens = self.cross_attention(stem_tokens, stem_feat) # (B, 16, 128) return stem_tokens def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -174,8 +171,7 @@ class CrossAttention(nn.Module): dropout (:obj:`float`, optional): The dropout probability. Defaults to 0.0. 
""" - def __init__(self, query_dim: int, heads: int = 8, - dim_head: int = 64, dropout: float = 0.0): + def __init__(self, query_dim: int, heads: int = 8, dim_head: int = 64, dropout: float = 0.0): super().__init__() inner_dim = dim_head * heads context_dim = query_dim @@ -188,8 +184,7 @@ def __init__(self, query_dim: int, heads: int = 8, self.dropout = nn.Dropout(dropout) - def forward(self, x: torch.Tensor, context: torch.Tensor, - mask: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward(self, x: torch.Tensor, context: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: """ Overview: Forward pass of the CrossAttention module. @@ -206,8 +201,7 @@ def forward(self, x: torch.Tensor, context: torch.Tensor, h = self.heads q = self.to_q(x) k, v = self.to_kv(context).chunk(2, dim=-1) - q, k, v = map(lambda t: rearrange( - t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale if mask is not None: From 366ef6880eeeefbe9ff262ad96702f7b50ada0da Mon Sep 17 00:00:00 2001 From: xueyingyi <3245896298@qq.com> Date: Sun, 8 Dec 2024 01:21:42 +0800 Subject: [PATCH 21/21] feature(xyy):add HPT model and test_hpt --- ding/model/template/hpt.py | 59 +++++++------------ ding/model/template/tests/test_hpt.py | 4 +- .../config/lunarlander_hpt_config.py | 10 ++-- .../entry/lunarlander_dqn_example.py | 1 - 4 files changed, 30 insertions(+), 44 deletions(-) diff --git a/ding/model/template/hpt.py b/ding/model/template/hpt.py index cf38a93f39..07a47bacea 100644 --- a/ding/model/template/hpt.py +++ b/ding/model/template/hpt.py @@ -5,8 +5,6 @@ from ding.model.common.head import DuelingHead from ding.utils.registry_factory import MODEL_REGISTRY -INIT_CONST = 0.02 - @MODEL_REGISTRY.register('hpt') class HPT(nn.Module): @@ -17,11 +15,10 @@ class HPT(nn.Module): and the Dueling Head computes Q-values for discrete action spaces. 
Interfaces: - __init__, forward + ``__init__``, ``forward`` + + GitHub: [https://github.com/liruiw/HPT/blob/main/hpt/models/policy_stem.py] - .. note:: - The model is designed to be flexible and can be adapted - for different input dimensions and action spaces. """ def __init__(self, state_dim: int, action_dim: int): @@ -75,14 +72,15 @@ class PolicyStem(nn.Module): to generate a set of latent tokens. Interfaces: - __init__, init_cross_attn, compute_latent, forward + ``__init__``, ``init_cross_attn``, ``compute_latent``, ``forward`` .. note:: This module is inspired by the implementation in the Perceiver IO model and uses attention mechanisms for feature extraction. """ + INIT_CONST = 0.02 - def __init__(self, feature_dim: int = 8, token_dim: int = 128, **kwargs): + def __init__(self, feature_dim: int = 8, token_dim: int = 128): """ Overview: Initialize the Policy Stem module with a feature extractor and cross-attention mechanism. @@ -101,7 +99,7 @@ def __init__(self, feature_dim: int = 8, token_dim: int = 128, **kwargs): def init_cross_attn(self): """Initialize cross-attention module and learnable tokens.""" token_num = 16 - self.tokens = nn.Parameter(torch.randn(1, token_num, 128) * INIT_CONST) + self.tokens = nn.Parameter(torch.randn(1, token_num, 128) * self.INIT_CONST) self.cross_attention = CrossAttention(128, heads=8, dim_head=64, dropout=0.1) def compute_latent(self, x: torch.Tensor) -> torch.Tensor: @@ -137,44 +135,31 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: """ return self.compute_latent(x) - def freeze(self): - """Freeze the parameters of the model, preventing updates during training.""" - for param in self.parameters(): - param.requires_grad = False - - def unfreeze(self): - """Unfreeze the parameters of the model, allowing updates during training.""" - for param in self.parameters(): - param.requires_grad = True - - def save(self, path: str): - """Save the model state dictionary to a file.""" - torch.save(self.state_dict(), path) - 
@property - def device(self): + def device(self) -> torch.device: """Returns the device on which the model parameters are located.""" return next(self.parameters()).device class CrossAttention(nn.Module): - """ - Overview: - CrossAttention module used in the Perceiver IO model. - It computes the attention between the query and context tensors, - and returns the output tensor after applying attention. - - Arguments: - query_dim (:obj:`int`): The dimension of the query input. - heads (:obj:`int`, optional): The number of attention heads. Defaults to 8. - dim_head (:obj:`int`, optional): The dimension of each attention head. Defaults to 64. - dropout (:obj:`float`, optional): The dropout probability. Defaults to 0.0. - """ def __init__(self, query_dim: int, heads: int = 8, dim_head: int = 64, dropout: float = 0.0): + """ + Overview: + CrossAttention module used in the Perceiver IO model. + It computes the attention between the query and context tensors, + and returns the output tensor after applying attention. + + Arguments: + - query_dim (:obj:`int`): The dimension of the query input. + - heads (:obj:`int`, optional): The number of attention heads. Defaults to 8. + - dim_head (:obj:`int`, optional): The dimension of each attention head. Defaults to 64. + - dropout (:obj:`float`, optional): The dropout probability. Defaults to 0.0. + """ super().__init__() inner_dim = dim_head * heads context_dim = query_dim + # Scaling factor for the attention logits to ensure stable gradients. self.scale = dim_head ** -0.5 self.heads = heads @@ -193,7 +178,7 @@ def forward(self, x: torch.Tensor, context: torch.Tensor, mask: Optional[torch.T Arguments: - x (:obj:`torch.Tensor`): The query input tensor. - context (:obj:`torch.Tensor`): The context input tensor. - - mask (:obj:`torch.Tensor`, optional): The attention mask tensor. Defaults to None. + - mask (:obj:`Optional[torch.Tensor]`): The attention mask tensor. Defaults to None. 
Returns: - torch.Tensor: The output tensor after applying attention. diff --git a/ding/model/template/tests/test_hpt.py b/ding/model/template/tests/test_hpt.py index bfb7844f32..e4d9729d59 100644 --- a/ding/model/template/tests/test_hpt.py +++ b/ding/model/template/tests/test_hpt.py @@ -5,8 +5,8 @@ from ding.torch_utils import is_differentiable T, B = 3, 4 -obs_shape = [4, (8, ), (4, 64, 64)] -act_shape = [3, (6, ), [2, 3, 6]] +obs_shape = [4, (8, )] +act_shape = [3, (6, )] args = list(product(*[obs_shape, act_shape])) diff --git a/dizoo/box2d/lunarlander/config/lunarlander_hpt_config.py b/dizoo/box2d/lunarlander/config/lunarlander_hpt_config.py index fea3aca6bc..b58166d566 100644 --- a/dizoo/box2d/lunarlander/config/lunarlander_hpt_config.py +++ b/dizoo/box2d/lunarlander/config/lunarlander_hpt_config.py @@ -69,8 +69,10 @@ ) lunarlander_hpt_create_config = EasyDict(lunarlander_hpt_create_config) create_config = lunarlander_hpt_create_config +""" +This is a configuration file for the LunarLander environment with HPT (Heterogeneous Pre-trained Transformers). +To run this config, please use the lunarlander_hpt_example.py script. -if __name__ == "__main__": - # or you can enter `ding -m serial -c lunarlander_hpt_config.py -s 0` - from ding.entry import serial_pipeline - serial_pipeline([main_config, create_config], seed=0) +Example: + python lunarlander_hpt_example.py +""" diff --git a/dizoo/box2d/lunarlander/entry/lunarlander_dqn_example.py b/dizoo/box2d/lunarlander/entry/lunarlander_dqn_example.py index 048554c310..ae9b30d1fd 100644 --- a/dizoo/box2d/lunarlander/entry/lunarlander_dqn_example.py +++ b/dizoo/box2d/lunarlander/entry/lunarlander_dqn_example.py @@ -59,7 +59,6 @@ def main(): task.use(ModelExchanger(model)) # Here is the part of single process pipeline. 
- evaluator_env.enable_save_replay(replay_path='./video') task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env)) task.use(eps_greedy_handler(cfg)) task.use(StepCollector(cfg, policy.collect_mode, collector_env))