feature(xyy):add HPT model to implement PolicyStem+DuelingHead #841
ding/model/template/__init__.py
Outdated
@@ -24,6 +24,8 @@
from .vae import VanillaVAE
from .decision_transformer import DecisionTransformer
from .procedure_cloning import ProcedureCloningMCTS, ProcedureCloningBFS
from .hpt import HPT
optimize import order
ding/model/template/hpt.py
Outdated
class PolicyStem(nn.Module):
    """policy stem
reformat the docstring in the DI-engine style
create_config = lunarlander_hpt_create_config

if __name__ == "__main__":
    # or you can enter `ding -m serial -c lunarlander_dqn_config.py -s 0`
change this comment: it still refers to `lunarlander_dqn_config.py`
import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'],
),
env_manager=dict(type='subprocess'),
# env_manager=dict(type='base'),
remove unused comments
ding/model/template/hpt.py
Outdated
@MODEL_REGISTRY.register('hpt')
class HPT(nn.Module):

    def __init__(self, state_dim, action_dim):
add overview and related introduction
add a unit test, like the other templates in DI-engine
policy=dict(
    # Whether to use cuda for network.
    cuda=True,
    load_path="./lunarlander_hpt_seed0/ckpt/ckpt_best.pth.tar",
remove unused part
    - mask (:obj:`torch.Tensor`, optional): The attention mask tensor. Defaults to None.

Returns:
    - torch.Tensor: The output tensor after applying attention.
add dtype
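One possible reading of this request is that the `Returns` entry should follow the same `- name (:obj:`type`): description` form as the `mask` argument above it. The sketch below uses only the standard library to flag entries that lack that annotation. The helper `untyped_entries`, the pattern `ENTRY`, and the return name `output` are hypothetical and chosen for illustration; they are not part of DI-engine.

```python
import re

# Hypothetical pattern (an assumption, not a DI-engine API). It matches
#   - name (:obj:`type`): description
# optionally with ", optional" after the type, as in the `mask` argument.
ENTRY = re.compile(r"^- \w+ \(:obj:`[^`]+`(, optional)?\): \S")


def untyped_entries(section_lines):
    """Return the bullet lines that lack a (:obj:`type`) annotation."""
    stripped = [line.strip() for line in section_lines]
    return [
        line for line in stripped
        if line.startswith("- ") and not ENTRY.match(line)
    ]


# The entry as it appears in the diff under review.
before = ["- torch.Tensor: The output tensor after applying attention."]
# One way it could be rewritten; the name `output` is an assumption.
after = ["- output (:obj:`torch.Tensor`): The output tensor after applying attention."]

print(untyped_entries(before))  # the original entry is flagged
print(untyped_entries(after))   # nothing is flagged
```

The same check accepts the existing `mask` line, so it only singles out the `Returns` entry the reviewer pointed at.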
from ding.torch_utils import is_differentiable

T, B = 3, 4
obs_shape = [4, (8, ), (4, 64, 64)]
Remove the `(4, 64, 64)` group here. The HPT model as currently implemented cannot support image input like `(4, 64, 64)`: it would need a CNN feature extractor, and the corresponding `state_dim` should be `(4, 64, 64)` rather than the current scalar. We will leave this part for the future, so you can remove this group for now.
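To illustrate the distinction the reviewer draws, here is a minimal sketch in plain Python of how test shapes might be mapped to a scalar `state_dim`. The helper `flat_state_dim` is hypothetical and not part of the PR or DI-engine. It assumes a vector-only encoder, so an int or a 1-D shape is accepted and anything with more dimensions is rejected.

```python
from typing import Sequence, Union

Shape = Union[int, Sequence[int]]


def flat_state_dim(obs_shape: Shape) -> int:
    """Hypothetical helper: map an observation shape to a scalar state_dim.

    Accepts an int or a 1-D shape such as (8, ). Multi-dimensional shapes
    (for example images) are rejected, because a vector-only encoder has
    no CNN feature extractor to handle them.
    """
    if isinstance(obs_shape, int):
        return obs_shape
    if len(obs_shape) == 1:
        return int(obs_shape[0])
    raise ValueError(
        f"unsupported obs_shape {tuple(obs_shape)}: needs a CNN feature extractor"
    )


# Test shapes with the image group removed, as the review asks.
obs_shape = [4, (8, )]
state_dims = [flat_state_dim(s) for s in obs_shape]
print(state_dims)
```

Passing `(4, 64, 64)` to this helper raises `ValueError`, which mirrors why that group does not belong in the test until image support is added.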
Description
Here are some tensorboard plots from the lunarlander_hpt_example.py run.
Related Issue
TODO
Check List