Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support K-Net #1289

Merged
merged 24 commits into from
Mar 10, 2022
Merged

[Feature] Support K-Net #1289

merged 24 commits into from
Mar 10, 2022

Conversation

MengzhangLI
Copy link
Contributor

@MengzhangLI MengzhangLI commented Feb 15, 2022

Re-implementing K-Net.

Original paper: https://arxiv.org/abs/2106.14855.
Original repo:https://github.com/ZwwWayne/K-Net.

Results on ADE20K

Method Backbone Crop Size Lr schd Mem (GB) Inf time (fps) mIoU mIoU from original repo
KNet + FCN R-50-D8 512x512 80000 7.01 19.24 43.60 43.30
KNet + PSPNet R-50-D8 512x512 80000 6.98 20.04 44.18 43.90
KNet + DeepLabV3 R-50-D8 512x512 80000 7.42 12.10 45.06 44.60
KNet + UperNet R-50-D8 512x512 80000 7.34 17.11 43.45 43.60
KNet + UperNet Swin-T 512x512 80000 7.57 15.56 45.84 45.40
KNet + UperNet Swin-L 512x512 80000 13.5 8.29 52.05 52.0
KNet + UperNet Swin-L 640x640 80000 13.54 8.29 52.21 52.7

@MengzhangLI MengzhangLI self-assigned this Feb 15, 2022
@MengzhangLI MengzhangLI added the WIP Work in process label Feb 15, 2022
@MengzhangLI
Copy link
Contributor Author

After refactoring, the model can directly used for inference and result is the same, which means the key is still the same with old one.

image

@codecov
Copy link

codecov bot commented Feb 22, 2022

Codecov Report

Merging #1289 (d0a7c08) into master (a7c2f68) will increase coverage by 0.06%.
The diff coverage is 93.75%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master    #1289      +/-   ##
==========================================
+ Coverage   90.32%   90.39%   +0.06%     
==========================================
  Files         132      133       +1     
  Lines        7699     7879     +180     
  Branches     1290     1316      +26     
==========================================
+ Hits         6954     7122     +168     
- Misses        531      534       +3     
- Partials      214      223       +9     
Flag Coverage Δ
unittests 90.39% <93.75%> (+0.06%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

Impacted Files Coverage Δ
mmseg/models/decode_heads/knet_head.py 93.71% <93.71%> (ø)
mmseg/models/decode_heads/__init__.py 100.00% <100.00%> (ø)
mmseg/datasets/builder.py 86.74% <0.00%> (-1.06%) ⬇️
mmseg/models/losses/accuracy.py 100.00% <0.00%> (ø)

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update e8cc322...d0a7c08. Read the comment docs.

@MengzhangLI
Copy link
Contributor Author

Add ep3,4,5.

@MengzhangLI
Copy link
Contributor Author

Could anyone @ZwwWayne for code review? Thx in advance.

@MeowZheng
Copy link
Collaborator

_forward_feature was added most recently
ref: #1299

mmseg/models/decode_heads/knet.py Outdated Show resolved Hide resolved
mmseg/models/decode_heads/knet.py Outdated Show resolved Hide resolved
mmseg/models/decode_heads/knet.py Outdated Show resolved Hide resolved
mmseg/models/decode_heads/knet.py Outdated Show resolved Hide resolved
mmseg/models/decode_heads/knet.py Outdated Show resolved Hide resolved
mmseg/models/decode_heads/knet.py Outdated Show resolved Hide resolved
mmseg/models/decode_heads/knet.py Outdated Show resolved Hide resolved
Copy link
Collaborator

@MeowZheng MeowZheng left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

might add readme file and metafile

mmseg/models/decode_heads/knet_head.py Outdated Show resolved Hide resolved
mmseg/models/decode_heads/knet_head.py Outdated Show resolved Hide resolved
mmseg/models/decode_heads/knet_head.py Outdated Show resolved Hide resolved
mmseg/models/decode_heads/knet_head.py Show resolved Hide resolved
mmseg/models/decode_heads/knet_head.py Outdated Show resolved Hide resolved
mmseg/models/decode_heads/knet_head.py Outdated Show resolved Hide resolved
Comment on lines 92 to 93
Tensor: The output tensor of shape
(self.num_classes(KernelUpdateHead) * self.in_channels(KernelUpdateHead)/ in_channels, kernel size * kernel size, in_channels). # noqa
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tensor: The dynamic kernel with shape ()?
is there any short description for the shape of the dynamic kernel? @MengzhangLI @ZwwWayne

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Something like: 'The dynamic kernel with shape (N, C), where N is the number of classes, C is the feature map channels'?

@MeowZheng MeowZheng merged commit 2e28db0 into open-mmlab:master Mar 10, 2022
@MengzhangLI MengzhangLI deleted the KNET branch March 17, 2022 09:24
@MengzhangLI MengzhangLI removed the WIP Work in process label Mar 18, 2022
mob5566 pushed a commit to mob5566/mmsegmentation that referenced this pull request Apr 13, 2022
* knet first commit

* fix import error in knet

* remove kernel update head from decoder head

* [Feature] Add kenerl updation for some decoder heads.

* [Feature] Add kenerl updation for some decoder heads.

* directly use forward_feature && modify other 3 decoder heads

* remover kernel_update attr

* delete unnecessary variables in forward function

* delete kernel update function

* delete kernel update function

* delete kernel_generate_head

* add unit test & comments in knet.py

* add copyright to fix lint error

* modify config names of knet

* rename swin-l 640

* upload models&logs and refactor knet_head.py

* modify docstrings and add some ut

* add url, modify docstring and add loss ut

* modify docstrings
ZhimingNJ pushed a commit to AetrexTechnology/mmsegmentation that referenced this pull request Jun 29, 2022
* knet first commit

* fix import error in knet

* remove kernel update head from decoder head

* [Feature] Add kenerl updation for some decoder heads.

* [Feature] Add kenerl updation for some decoder heads.

* directly use forward_feature && modify other 3 decoder heads

* remover kernel_update attr

* delete unnecessary variables in forward function

* delete kernel update function

* delete kernel update function

* delete kernel_generate_head

* add unit test & comments in knet.py

* add copyright to fix lint error

* modify config names of knet

* rename swin-l 640

* upload models&logs and refactor knet_head.py

* modify docstrings and add some ut

* add url, modify docstring and add loss ut

* modify docstrings
aravind-h-v pushed a commit to aravind-h-v/mmsegmentation that referenced this pull request Mar 27, 2023
…pen-mmlab#1289)

* fix non square images with UNet2DModel and DDIM/DDPM pipelines

* fix unet_2d `sample_size` docstring

* update pipeline tests for unet uncond

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants