Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature(xyy):add HPT model to implement PolicyStem+DuelingHead #841

Merged
merged 21 commits into from
Dec 8, 2024

Conversation

luodi-7
Copy link
Contributor

@luodi-7 luodi-7 commented Nov 27, 2024

Description

Here are some tensorboard plots from the lunarlander_hpt_example.py run.
hpt_episode_return
hpt_train_q_value
hpt_target_q_value
hpt_train_total_loss

Related Issue

TODO

Check List

  • merge the latest version source branch/repo, and resolve all the conflicts
  • pass style check
  • pass all the tests

@PaParaZz1 PaParaZz1 added the algo Add new algorithm or improve old one label Nov 28, 2024
@@ -24,6 +24,8 @@
from .vae import VanillaVAE
from .decision_transformer import DecisionTransformer
from .procedure_cloning import ProcedureCloningMCTS, ProcedureCloningBFS
from .hpt import HPT

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

optimize import order



class PolicyStem(nn.Module):
"""policy stem
Copy link
Collaborator

@puyuan1996 puyuan1996 Nov 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

reformat the docstring as the DI-engine style

create_config = lunarlander_hpt_create_config

if __name__ == "__main__":
# or you can enter `ding -m serial -c lunarlander_dqn_config.py -s 0`
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

change the comments

import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'],
),
env_manager=dict(type='subprocess'),
# env_manager=dict(type='base'),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

move unused comments

@MODEL_REGISTRY.register('hpt')
class HPT(nn.Module):

def __init__(self, state_dim, action_dim):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add overview and related introduction

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add unittest like other template in DI-engine

policy=dict(
# Whether to use cuda for network.
cuda=True,
load_path="./lunarlander_hpt_seed0/ckpt/ckpt_best.pth.tar",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove unused part

dizoo/box2d/lunarlander/config/lunarlander_hpt_config.py Outdated Show resolved Hide resolved
dizoo/box2d/lunarlander/entry/lunarlander_dqn_example.py Outdated Show resolved Hide resolved
- mask (:obj:`torch.Tensor`, optional): The attention mask tensor. Defaults to None.

Returns:
- torch.Tensor: The output tensor after applying attention.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add dtype

ding/model/template/hpt.py Outdated Show resolved Hide resolved
ding/model/template/hpt.py Outdated Show resolved Hide resolved
ding/model/template/hpt.py Outdated Show resolved Hide resolved
ding/model/template/hpt.py Outdated Show resolved Hide resolved
ding/model/template/hpt.py Outdated Show resolved Hide resolved
ding/model/template/hpt.py Outdated Show resolved Hide resolved
from ding.torch_utils import is_differentiable

T, B = 3, 4
obs_shape = [4, (8, ), (4, 64, 64)]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove the group (4, 64, 64) here, the current HPT model you implemented here can't support the image input like (4, 64, 64). It need a CNN feature extractor, and the corresponding state_dim should be `(4, 64, 64) rather than the current scalar. We will left this part for future, thus you can remove this group now.

@PaParaZz1 PaParaZz1 merged commit bbc9cc4 into opendilab:main Dec 8, 2024
8 of 15 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
algo Add new algorithm or improve old one
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants