[WIP]: Refactor DETR #8754
Conversation
Commits
* Upgrade onnxsim to 0.4.0; update pytorch2onnx.py
* Add .github/workflow/stale.yml; modify the prompt message in stale.yml; change the check strategy so that issues and PRs with any of the 'invalid' or 'awaiting response' labels will be checked
* Fix HTC link
* Fix the runtime error triggered when the offset parameter is not contiguous ("offset must be contiguous")
* Fix Swin backbone absolute pos_embed resizing; fix lint; add unit test; update swin.py (Co-authored-by: Cedric Luo <luochunhua1996@outlook.com>)
* Fix floordiv warning; add floordiv wrapper
* Update logger hook samples; [Docs] add MMDetWandB LoggerHook details; [Docs] lint test passed
Short-term Plan
We plan to implement a base detector class for Transformer-based detectors, TransformerDetector, move the transformer-related parts of the original head module into the new detector module, and use TransformerEncoder and TransformerDecoder directly in the detector.
NOTE: we are still exploring better refactoring solutions, and this short-term plan is only a temporary attempt. Suggestions for the refactor and bug reports are welcome, in the comments or reviews, or by contacting me.
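To make the plan concrete, here is a minimal sketch of what such a TransformerDetector base class might look like. This is only an illustration under assumptions: the constructor arguments, the extract_feat helper, and the exact split between detector and head are placeholders, not the final API.

import torch.nn as nn


class TransformerDetector(nn.Module):
    # Sketch only: the real class would inherit from BaseDetector and
    # build its submodules from configs.

    def __init__(self, backbone, neck, encoder, decoder, bbox_head):
        super().__init__()
        self.backbone = backbone    # feature extractor, e.g. a ResNet
        self.neck = neck            # e.g. ChannelMapper for DETR
        self.encoder = encoder      # TransformerEncoder, owned by the detector
        self.decoder = decoder      # TransformerDecoder, owned by the detector
        self.bbox_head = bbox_head  # now a lightweight prediction head

    def extract_feat(self, img):
        # Extract features from the backbone and optionally the neck.
        x = self.backbone(img)
        if self.neck is not None:
            x = self.neck(x)
        return x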
x = self.neck(x)
return x

def forward_train(self,
Some panoptic segmentation models may inherit from TransformerDetector in the future, so gt_semantic_seg should be passed in forward_train.
Hi~ Thanks for your review!
This suggestion is very important. We will discuss it at our next meeting.
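For concreteness, a sketch of a forward_train signature that already threads gt_semantic_seg through to the head. The argument list follows the usual mmdet 2.x detector convention; bbox_head.forward_train accepting gt_masks/gt_semantic_seg keywords is an assumption here, not the current API.

def forward_train(self,
                  img,
                  img_metas,
                  gt_bboxes,
                  gt_labels,
                  gt_bboxes_ignore=None,
                  gt_masks=None,
                  gt_semantic_seg=None,
                  **kwargs):
    x = self.extract_feat(img)
    # Forward gt_semantic_seg so panoptic models inheriting from
    # TransformerDetector can consume it; pure detection heads can
    # simply ignore the extra keyword.
    losses = self.bbox_head.forward_train(
        x,
        img_metas,
        gt_bboxes,
        gt_labels,
        gt_bboxes_ignore=gt_bboxes_ignore,
        gt_masks=gt_masks,
        gt_semantic_seg=gt_semantic_seg,
        **kwargs)
    return losses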
reference_points=reference_points,
level_start_index=level_start_index,
valid_ratios=valid_ratios,
query = self.self_attn(
In Mask2Former, operation_order=('cross_attn', 'norm', 'self_attn', 'norm', 'ffn', 'norm'), see:

operation_order=('cross_attn', 'norm', 'self_attn', 'norm',
Hi~ Thanks for your review!
We recommend customizing a decoder-layer module for each transformer detector, and we are going to move the heavy transformer modules into the detector.
We have resubmitted this work as PR #8763 on dev-3.x, and this PR will be closed.
Additionally, I can assist with the future refactoring of Mask2Former.
For Mask2Former, the refactored decoder layer could be implemented like:
# Assuming DetrTransformerDecoderLayer is importable from the refactored
# transformer layers; the exact import path is not fixed yet.
from mmdet.models.layers import DetrTransformerDecoderLayer


class Mask2FormerDecoderLayer(DetrTransformerDecoderLayer):
    # I'm not sure whether to inherit DetrTransformerDecoderLayer or
    # DeformableDetrTransformerDecoderLayer for now.

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self,
                query,
                key=None,
                value=None,
                query_pos=None,
                key_pos=None,
                self_attn_masks=None,
                cross_attn_masks=None,
                query_key_padding_mask=None,
                key_padding_mask=None,
                **kwargs):
        # Mask2Former runs cross-attention before self-attention, i.e.
        # operation_order=('cross_attn', 'norm', 'self_attn', 'norm',
        # 'ffn', 'norm').
        query = self.cross_attn(
            query=query,
            key=key,
            value=value,
            query_pos=query_pos,
            key_pos=key_pos,
            attn_mask=cross_attn_masks,
            key_padding_mask=key_padding_mask,
            **kwargs)
        query = self.norms[0](query)
        query = self.self_attn(
            query=query,
            key=query,
            value=query,
            query_pos=query_pos,
            key_pos=query_pos,
            attn_mask=self_attn_masks,
            key_padding_mask=query_key_padding_mask,
            **kwargs)
        query = self.norms[1](query)
        query = self.ffn(query)
        query = self.norms[2](query)
        return query
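For reference, in the current (pre-refactor) implementation this ordering is specified declaratively through mmcv's BaseTransformerLayer config rather than hard-coded in forward. A sketch along those lines, with values abbreviated for illustration and exact keys to be checked against the Mask2Former config in the repo:

# Config-style sketch of the same ordering; values are illustrative.
transformerlayers = dict(
    type='DetrTransformerDecoderLayer',
    attn_cfgs=dict(
        type='MultiheadAttention',
        embed_dims=256,
        num_heads=8),
    feedforward_channels=2048,
    operation_order=('cross_attn', 'norm', 'self_attn', 'norm',
                     'ffn', 'norm'))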
Great!
We decided to submit the refactored DETR-related modules as a PR against dev-3.x. Hence, this PR may not be updated further, but reviews and comments on it will continue to be answered. The new PR is #8763.
Motivation
We (@jshilong, @LYMDLUT, @KeiChiTse, and I) are going to refactor DETR-like models to enhance the usability and readability of our codebase.