diff --git a/classification b/classification deleted file mode 160000 index 6cb1441..0000000 --- a/classification +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 6cb144105fc5c2f778e51cc66e35314938f96fae diff --git a/classification/LICENSE b/classification/LICENSE new file mode 100644 index 0000000..526d738 --- /dev/null +++ b/classification/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." 
+ + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. 
+ + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. 
We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2022 image_classification_sota contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/classification/README.md b/classification/README.md new file mode 100644 index 0000000..744641c --- /dev/null +++ b/classification/README.md @@ -0,0 +1,87 @@ +# Image Classification SOTA + +`Image Classification SOTA` is an image classification toolbox based on PyTorch. + +## Updates +### May 27, 2022 +* Add knowledge distillation methods (KD and [DIST](https://github.com/hunto/DIST_KD)). + +### March 24, 2022 +* Support training strategies in DeiT (ViT). + +### March 11, 2022 +* Release training code. + +## Supported Algorithms +### Structural Re-parameterization (Rep) +* DBB (CVPR 2021) [[paper]](https://arxiv.org/abs/2103.13425) [[original repo]](https://github.com/DingXiaoH/DiverseBranchBlock) +* DyRep (CVPR 2022) [[README]](https://github.com/hunto/DyRep) + +### Knowledge Distillation (KD) +* KD [[paper]](https://arxiv.org/abs/1503.02531) +* DIST [[README]](https://github.com/hunto/DIST_KD) [[paper]](https://arxiv.org/abs/2205.10536) + +## Requirements +``` +torch>=1.0.1 +torchvision +``` + +## Getting Started +### Prepare datasets +It is recommended to symlink the dataset root to `image_classification_sota/data`. Then the file structure should be like +``` +image_classification_sota +├── lib +├── tools +├── configs +├── data +│ ├── imagenet +│ │ ├── meta +│ │ ├── train +│ │ ├── val +│ ├── cifar +│ │ ├── cifar-10-batches-py +│ │ ├── cifar-100-python +``` + +### Training configurations +* `Strategies`: The training strategies are configured using yaml file or arguments. Examples are in `configs/strategies` directory. 
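+
+  A minimal strategy file might look like the sketch below (the field names are taken from the configs shipped under `configs/strategies`; the values are illustrative, and as noted above the same options can also be passed as command-line arguments):
+  ```yaml
+  # hypothetical minimal strategy; see configs/strategies/resnet/resnet.yaml for a complete example
+  dataset: imagenet
+  batch_size: 32
+  epochs: 120
+  opt: sgd
+  lr: 0.1
+  momentum: 0.9
+  weight_decay: 1.0e-4
+  sched: step
+  decay_epochs: 30
+  decay_rate: 0.1
+  workers: 4
+  ```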
+ +### Train a model + +* Training with a single GPU + ```shell + python tools/train.py -c ${CONFIG} --model ${MODEL} [optional arguments] + ``` + +* Training with multiple GPUs + ```shell + sh tools/dist_train.sh ${GPU_NUM} ${CONFIG} ${MODEL} [optional arguments] + ``` + +* For slurm users + ```shell + sh tools/slurm_train.sh ${PARTITION} ${GPU_NUM} ${CONFIG} ${MODEL} [optional arguments] + ``` + +**Examples** +* Train ResNet-50 on ImageNet + ```shell + sh tools/dist_train.sh 8 configs/strategies/resnet/resnet.yaml resnet50 --experiment imagenet_res50 + ``` + +* Train MobileNetV2 on ImageNet + ```shell + sh tools/dist_train.sh 8 configs/strategies/MBV2/mbv2.yaml nas_model --model-config configs/models/MobileNetV2/MobileNetV2.yaml --experiment imagenet_mbv2 + ``` + +* Train VGG-16 on CIFAR-10 + ```shell + sh tools/dist_train.sh 1 configs/strategies/CIFAR/cifar.yaml nas_model --model-config configs/models/VGG/vgg16_cifar10.yaml --experiment cifar10_vgg16 + ``` + +## Projects based on Image Classification SOTA +* [CVPR 2022] [DyRep](https://github.com/hunto/DyRep): Bootstrapping Training with Dynamic Re-parameterization +* [NeurIPS 2022] [DIST](https://github.com/hunto/DIST_KD): Knowledge Distillation from A Stronger Teacher +* [LightViT](https://github.com/hunto/LightViT): Towards Light-Weight Convolution-Free Vision Transformers diff --git a/classification/configs/models/DARTS/DARTS_V2_cifar.yaml b/classification/configs/models/DARTS/DARTS_V2_cifar.yaml new file mode 100644 index 0000000..85a15bb --- /dev/null +++ b/classification/configs/models/DARTS/DARTS_V2_cifar.yaml @@ -0,0 +1,7 @@ +genotype: + normal: "[('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 1), ('skip_connect', 0), ('skip_connect', 0), ('dil_conv_3x3', 2)]" + reduce: "[('max_pool_3x3', 0), ('max_pool_3x3', 1), ('skip_connect', 2), ('max_pool_3x3', 1), ('max_pool_3x3', 0), ('skip_connect', 2), ('skip_connect', 2), ('max_pool_3x3', 1)]" + +init_channels: 36 +layers: 20 +num_classes: 10 diff --git a/classification/configs/models/DARTS/DARTS_V2_imagenet.yaml b/classification/configs/models/DARTS/DARTS_V2_imagenet.yaml new file mode 100644 index 0000000..b6c7ddc --- /dev/null +++ b/classification/configs/models/DARTS/DARTS_V2_imagenet.yaml @@ -0,0 +1,7 @@ +genotype: + normal: "[('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 1), ('skip_connect', 0), ('skip_connect', 0), ('dil_conv_3x3', 2)]" + reduce: "[('max_pool_3x3', 0), ('max_pool_3x3', 1), ('skip_connect', 2), ('max_pool_3x3', 1), ('max_pool_3x3', 0), ('skip_connect', 2), ('skip_connect', 2), ('max_pool_3x3', 1)]" + +init_channels: 48 +layers: 14 +num_classes: 1000 diff --git a/classification/configs/models/GreedyNAS/GreedyNAS-A.yaml b/classification/configs/models/GreedyNAS/GreedyNAS-A.yaml new file mode 100644 index 0000000..4f57833 --- /dev/null +++ b/classification/configs/models/GreedyNAS/GreedyNAS-A.yaml @@ -0,0 +1,31 @@ +backbone: + # name: [stride, in_channels, out_channels, expand_ratio, op] + layer0: [2, 3, 32, 1, 'conv3x3'] + layer1: [1, 32, 16, 1, 'ir_3x3_nse'] + layer2: [2, 16, 32, 3, 'ir_5x5_se'] + layer3: [1, 32, 32, 3, 'ir_3x3_nse'] + layer4: [1, 32, 32, 1, 'id'] + layer5: [1, 32, 32, 3, 'ir_3x3_se'] + layer6: [2, 32, 40, 3, 'ir_5x5_nse'] + layer7: [1, 40, 40, 3, 'ir_7x7_nse'] + layer8: [1, 40, 40, 6, 'ir_3x3_se'] + layer9: [1, 40, 40, 6, 'ir_5x5_se'] + layer10: [2, 40, 80, 3, 'ir_5x5_se'] + layer11: [1, 80, 80, 3, 'ir_7x7_se'] + layer12: [1, 80, 80, 6, 
'ir_7x7_nse'] + layer13: [1, 80, 80, 3, 'ir_5x5_nse'] + layer14: [1, 80, 96, 3, 'ir_5x5_se'] + layer15: [1, 96, 96, 3, 'ir_3x3_se'] + layer16: [1, 96, 96, 6, 'ir_3x3_nse'] + layer17: [1, 96, 96, 3, 'ir_7x7_se'] + layer18: [2, 96, 192, 3, 'ir_7x7_se'] + layer19: [1, 192, 192, 6, 'ir_7x7_se'] + layer20: [1, 192, 192, 3, 'ir_5x5_nse'] + layer21: [1, 192, 192, 6, 'ir_3x3_se'] + layer22: [1, 192, 320, 6, 'ir_7x7_se'] + layer23: [1, 320, 1280, 1, 'conv1x1'] + layer24: [1, 1280, 1280, 1, 'gavgp'] +head: + linear1: + dim_in: 1280 + dim_out: 1000 diff --git a/classification/configs/models/GreedyNAS/GreedyNAS-B.yaml b/classification/configs/models/GreedyNAS/GreedyNAS-B.yaml new file mode 100644 index 0000000..483ece2 --- /dev/null +++ b/classification/configs/models/GreedyNAS/GreedyNAS-B.yaml @@ -0,0 +1,31 @@ +backbone: + # name: [stride, in_channels, out_channels, expand_ratio, op] + layer0: [2, 3, 32, 1, 'conv3x3'] + layer1: [1, 32, 16, 1, 'ir_3x3_nse'] + layer2: [2, 16, 32, 3, 'ir_5x5_se'] + layer3: [1, 32, 32, 3, 'ir_5x5_se'] + layer4: [1, 32, 32, 3, 'ir_5x5_se'] + layer5: [1, 32, 32, 1, 'id'] + layer6: [2, 32, 40, 3, 'ir_5x5_se'] + layer7: [1, 40, 40, 3, 'ir_7x7_se'] + layer8: [1, 40, 40, 3, 'ir_5x5_se'] + layer9: [1, 40, 40, 3, 'ir_3x3_nse'] + layer10: [2, 40, 80, 6, 'ir_5x5_nse'] + layer11: [1, 80, 80, 1, 'id'] + layer12: [1, 80, 80, 3, 'ir_7x7_se'] + layer13: [1, 80, 80, 3, 'ir_7x7_se'] + layer14: [1, 80, 96, 3, 'ir_3x3_se'] + layer15: [1, 96, 96, 3, 'ir_5x5_nse'] + layer16: [1, 96, 96, 3, 'ir_5x5_nse'] + layer17: [1, 96, 96, 3, 'ir_7x7_se'] + layer18: [2, 96, 192, 6, 'ir_5x5_se'] + layer19: [1, 192, 192, 3, 'ir_7x7_se'] + layer20: [1, 192, 192, 3, 'ir_3x3_se'] + layer21: [1, 192, 192, 3, 'ir_7x7_se'] + layer22: [1, 192, 320, 6, 'ir_7x7_se'] + layer23: [1, 320, 1280, 1, 'conv1x1'] + layer24: [1, 1280, 1280, 1, 'gavgp'] +head: + linear1: + dim_in: 1280 + dim_out: 1000 diff --git a/classification/configs/models/GreedyNAS/GreedyNAS-C.yaml b/classification/configs/models/GreedyNAS/GreedyNAS-C.yaml new file mode 100644 index 0000000..ede08ca --- /dev/null +++ b/classification/configs/models/GreedyNAS/GreedyNAS-C.yaml @@ -0,0 +1,32 @@ +backbone: + # name: [stride, in_channels, out_channels, expand_ratio, op] + layer0: [2, 3, 32, 1, 'conv3x3'] + layer1: [1, 32, 16, 1, 'ir_3x3_nse'] + layer2: [2, 16, 32, 3, 'ir_7x7_nse'] + layer3: [1, 32, 32, 1, 'id'] + layer4: [1, 32, 32, 1, 'id'] + layer5: [1, 32, 32, 3, 'ir_3x3_se'] + layer6: [2, 32, 40, 3, 'ir_5x5_se'] + layer7: [1, 40, 40, 1, 'id'] + layer8: [1, 40, 40, 3, 'ir_3x3_se'] + layer9: [1, 40, 40, 3, 'ir_3x3_se'] + layer10: [2, 40, 80, 3, 'ir_7x7_se'] + layer11: [1, 80, 80, 6, 'ir_5x5_se'] + layer12: [1, 80, 80, 3, 'ir_5x5_se'] + layer13: [1, 80, 80, 1, 'id'] + layer14: [1, 80, 96, 3, 'ir_5x5_se'] + layer15: [1, 96, 96, 3, 'ir_5x5_nse'] + layer16: [1, 96, 96, 6, 'ir_3x3_nse'] + layer17: [1, 96, 96, 3, 'ir_7x7_nse'] + layer18: [2, 96, 192, 3, 'ir_5x5_se'] + layer19: [1, 192, 192, 3, 'ir_3x3_nse'] + layer20: [1, 192, 192, 3, 'ir_7x7_se'] + layer21: [1, 192, 192, 3, 'ir_5x5_nse'] + layer22: [1, 192, 320, 6, 'ir_7x7_se'] + layer23: [1, 320, 1280, 1, 'conv1x1'] + layer24: [1, 1280, 1280, 1, 'gavgp'] +head: + linear1: + dim_in: 1280 + dim_out: 1000 + diff --git a/classification/configs/models/MCT-NAS/MCT-NAS-A.yaml b/classification/configs/models/MCT-NAS/MCT-NAS-A.yaml new file mode 100644 index 0000000..c37773e --- /dev/null +++ b/classification/configs/models/MCT-NAS/MCT-NAS-A.yaml @@ -0,0 +1,32 @@ +backbone: + # name: [stride, in_channels, out_channels, 
expand_ratio, op] + layer0: [2, 3, 32, 1, 'conv3x3'] + layer1: [1, 32, 16, 1, 'ir_3x3_nse'] + layer2: [2, 16, 32, 6, 'ir_3x3_se'] + layer3: [1, 32, 32, 3, 'ir_3x3_se'] + layer4: [1, 32, 32, 3, 'ir_3x3_se'] + layer5: [1, 32, 32, 3, 'ir_3x3_se'] + layer6: [2, 32, 40, 3, 'ir_3x3_se'] + layer7: [1, 40, 40, 3, 'ir_5x5_se'] + layer8: [1, 40, 40, 3, 'ir_5x5_se'] + layer9: [1, 40, 40, 3, 'ir_3x3_se'] + layer10: [2, 40, 80, 6, 'ir_3x3_se'] + layer11: [1, 80, 80, 3, 'ir_5x5_se'] + layer12: [1, 80, 80, 3, 'ir_5x5_se'] + layer13: [1, 80, 80, 6, 'ir_3x3_se'] + layer14: [1, 80, 96, 6, 'ir_3x3_se'] + layer15: [1, 96, 96, 6, 'ir_7x7_se'] + layer16: [1, 96, 96, 6, 'ir_3x3_se'] + layer17: [1, 96, 96, 6, 'ir_5x5_se'] + layer18: [2, 96, 192, 6, 'ir_7x7_se'] + layer19: [1, 192, 192, 6, 'ir_7x7_se'] + layer20: [1, 192, 192, 6, 'ir_5x5_se'] + layer21: [1, 192, 192, 6, 'ir_5x5_se'] + layer22: [1, 192, 320, 6, 'ir_7x7_se'] + layer23: [1, 320, 1280, 1, 'conv1x1'] + layer24: [1, 1280, 1280, 1, 'gavgp'] +head: + linear1: + dim_in: 1280 + dim_out: 1000 + diff --git a/classification/configs/models/MCT-NAS/MCT-NAS-B.yaml b/classification/configs/models/MCT-NAS/MCT-NAS-B.yaml new file mode 100644 index 0000000..ece4790 --- /dev/null +++ b/classification/configs/models/MCT-NAS/MCT-NAS-B.yaml @@ -0,0 +1,32 @@ +backbone: + # name: [stride, in_channels, out_channels, expand_ratio, op] + layer0: [2, 3, 32, 1, 'conv3x3'] + layer1: [1, 32, 16, 1, 'ir_3x3_nse'] + layer2: [2, 16, 32, 3, 'ir_7x7_se'] + layer3: [1, 32, 32, 1, 'id'] + layer4: [1, 32, 32, 1, 'id'] + layer5: [1, 32, 32, 3, 'ir_3x3_se'] + layer6: [2, 32, 40, 3, 'ir_5x5_se'] + layer7: [1, 40, 40, 3, 'ir_5x5_se'] + layer8: [1, 40, 40, 6, 'ir_3x3_se'] + layer9: [1, 40, 40, 3, 'ir_3x3_se'] + layer10: [2, 40, 80, 3, 'ir_3x3_se'] + layer11: [1, 80, 80, 1, 'id'] + layer12: [1, 80, 80, 6, 'ir_5x5_se'] + layer13: [1, 80, 80, 6, 'ir_5x5_se'] + layer14: [1, 80, 96, 6, 'ir_5x5_se'] + layer15: [1, 96, 96, 3, 'ir_3x3_se'] + layer16: [1, 96, 96, 3, 'ir_3x3_se'] + layer17: [1, 96, 96, 3, 'ir_3x3_se'] + layer18: [2, 96, 192, 6, 'ir_5x5_se'] + layer19: [1, 192, 192, 3, 'ir_5x5_se'] + layer20: [1, 192, 192, 6, 'ir_7x7_se'] + layer21: [1, 192, 192, 3, 'ir_3x3_se'] + layer22: [1, 192, 320, 6, 'ir_7x7_se'] + layer23: [1, 320, 1280, 1, 'conv2d'] + layer24: [1, 1280, 1280, 1, 'gavgp'] +head: + linear1: + dim_in: 1280 + dim_out: 1000 + diff --git a/classification/configs/models/MCT-NAS/MCT-NAS-C.yaml b/classification/configs/models/MCT-NAS/MCT-NAS-C.yaml new file mode 100644 index 0000000..df1d415 --- /dev/null +++ b/classification/configs/models/MCT-NAS/MCT-NAS-C.yaml @@ -0,0 +1,32 @@ +backbone: + # name: [stride, in_channels, out_channels, expand_ratio, op] + layer0: [2, 3, 32, 1, 'conv3x3'] + layer1: [1, 32, 16, 1, 'ir_3x3_nse'] + layer2: [2, 16, 32, 3, 'ir_3x3_se'] + layer3: [1, 32, 32, 1, 'id'] + layer4: [1, 32, 32, 3, 'ir_3x3_se'] + layer5: [1, 32, 32, 3, 'ir_3x3_se'] + layer6: [2, 32, 40, 3, 'ir_5x5_se'] + layer7: [1, 40, 40, 3, 'ir_5x5_se'] + layer8: [1, 40, 40, 3, 'ir_3x3_se'] + layer9: [1, 40, 40, 3, 'ir_3x3_se'] + layer10: [2, 40, 80, 3, 'ir_5x5_se'] + layer11: [1, 80, 80, 3, 'ir_7x7_se'] + layer12: [1, 80, 80, 1, 'id'] + layer13: [1, 80, 80, 3, 'ir_5x5_se'] + layer14: [1, 80, 96, 3, 'ir_7x7_se'] + layer15: [1, 96, 96, 3, 'ir_7x7_se'] + layer16: [1, 96, 96, 3, 'id'] + layer17: [1, 96, 96, 3, 'ir_3x3_se'] + layer18: [2, 96, 192, 3, 'ir_5x5_se'] + layer19: [1, 192, 192, 3, 'ir_5x5_se'] + layer20: [1, 192, 192, 3, 'ir_5x5_se'] + layer21: [1, 192, 192, 3, 'ir_7x7_se'] + layer22: [1, 
192, 320, 6, 'ir_7x7_se'] + layer23: [1, 320, 1280, 1, 'conv2d'] + layer24: [1, 1280, 1280, 1, 'gavgp'] +head: + linear1: + dim_in: 1280 + dim_out: 1000 + diff --git a/classification/configs/models/MobileNetV2/MobileNetV2.yaml b/classification/configs/models/MobileNetV2/MobileNetV2.yaml new file mode 100644 index 0000000..66681b5 --- /dev/null +++ b/classification/configs/models/MobileNetV2/MobileNetV2.yaml @@ -0,0 +1,17 @@ +backbone: + # name: [n, stride, in_channels, out_channels, expand_ratio, op] + conv_stem: [1, 2, 3, 32, 1, 'conv3x3'] + stage_0: [1, 1, 32, 16, 1, 'ir_3x3'] + stage1: [2, 2, 16, 24, 6, 'ir_3x3'] + stage2: [3, 2, 24, 32, 6, 'ir_3x3'] + stage3: [4, 2, 32, 64, 6, 'ir_3x3'] + stage4: [3, 1, 64, 96, 6, 'ir_3x3'] + stage5: [3, 2, 96, 160, 6, 'ir_3x3'] + stage6: [1, 1, 160, 320, 6, 'ir_3x3'] + conv_out: [1, 1, 320, 1280, 1, 'conv1x1'] + gavg_pool: [1, 1280, 1280, 1, 'gavgp'] +head: + linear1: + dim_in: 1280 + dim_out: 1000 + diff --git a/classification/configs/models/MobileNetV2/MobileNetV2_cifar10.yaml b/classification/configs/models/MobileNetV2/MobileNetV2_cifar10.yaml new file mode 100644 index 0000000..eb621f4 --- /dev/null +++ b/classification/configs/models/MobileNetV2/MobileNetV2_cifar10.yaml @@ -0,0 +1,13 @@ +backbone: + # name: [n, stride, in_channels, out_channels, expand_ratio, op] + conv_stem: [1, 1, 3, 32, 1, 'conv3x3'] + stage_0: [2, 2, 32, 64, 6, 'ir_3x3'] + stage1: [3, 2, 64, 128, 6, 'ir_3x3'] + stage2: [3, 2, 128, 256, 6, 'ir_3x3'] + conv_out: [1, 1, 256, 1280, 1, 'conv1x1'] + gavg_pool: [1, 1280, 1280, 1, 'gavgp'] +head: + linear1: + dim_in: 1280 + dim_out: 10 + diff --git a/classification/configs/models/PC-DARTS/PC-DARTS_imagenet.yaml b/classification/configs/models/PC-DARTS/PC-DARTS_imagenet.yaml new file mode 100644 index 0000000..5ca4886 --- /dev/null +++ b/classification/configs/models/PC-DARTS/PC-DARTS_imagenet.yaml @@ -0,0 +1,6 @@ +genotype: + normal: "[('skip_connect', 1), ('sep_conv_3x3', 0), ('sep_conv_3x3', 0), ('skip_connect', 1), ('sep_conv_3x3', 1), ('sep_conv_3x3', 3), ('sep_conv_3x3', 1), ('dil_conv_5x5', 4)]" + reduce: "[('sep_conv_3x3', 0), ('skip_connect', 1), ('dil_conv_5x5', 2), ('max_pool_3x3', 1), ('sep_conv_3x3', 2), ('sep_conv_3x3', 1), ('sep_conv_5x5', 0), ('sep_conv_3x3', 3)]" + + + diff --git a/classification/configs/models/ProxylessNAS/ProxylessR-mobile.yaml b/classification/configs/models/ProxylessNAS/ProxylessR-mobile.yaml new file mode 100644 index 0000000..199a494 --- /dev/null +++ b/classification/configs/models/ProxylessNAS/ProxylessR-mobile.yaml @@ -0,0 +1,32 @@ +backbone: + # name: [stride, in_channels, out_channels, expand_ratio, op] + conv_stem: [2, 3, 32, 1, 'conv3x3'] + stage_0: [1, 32, 16, 1, 'ir_3x3'] + stage1_1: [2, 16, 32, 3, 'ir_5x5'] + stage1_2: [1, 32 , 32, 3, 'ir_3x3'] + stage1_3: [1, 32, 32, 1, 'id'] + stage1_4: [1, 32, 32, 1, 'id'] + stage2_1: [2, 32, 40, 3, 'ir_7x7'] + stage2_2: [1, 40, 40, 3, 'ir_3x3'] + stage2_3: [1, 40, 40, 3, 'ir_5x5'] + stage2_4: [1, 40, 40, 3, 'ir_5x5'] + stage3_1: [2, 40, 80, 6, 'ir_7x7'] + stage3_2: [1, 80, 80, 3, 'ir_5x5'] + stage3_3: [1, 80, 80, 3, 'ir_5x5'] + stage3_4: [1, 80, 80, 3, 'ir_5x5'] + stage3_5: [1, 80, 96, 6, 'ir_5x5'] + stage3_6: [1, 96, 96, 3, 'ir_5x5'] + stage3_7: [1, 96, 96, 3, 'ir_5x5'] + stage3_8: [1, 96, 96, 3, 'ir_5x5'] + stage4_1: [2, 96, 192, 6, 'ir_7x7'] + stage4_2: [1, 192, 192, 6, 'ir_7x7'] + stage4_3: [1, 192, 192, 3, 'ir_7x7'] + stage4_4: [1, 192, 192, 3, 'ir_7x7'] + stage5: [1, 192, 320, 6, 'ir_7x7'] + conv_out: [1, 320, 1280, 1, 'conv1x1'] + gavg_pool: 
[1, 1280, 1280, 1, 'gavgp'] +head: + linear1: + dim_in: 1280 + dim_out: 1000 + diff --git a/classification/configs/models/ResNet/resnet-50.yaml b/classification/configs/models/ResNet/resnet-50.yaml new file mode 100644 index 0000000..4ff20f4 --- /dev/null +++ b/classification/configs/models/ResNet/resnet-50.yaml @@ -0,0 +1,14 @@ +backbone: + # name: [n, stride, in_channels, out_channels, expand_ratio, op] + conv_stem: [1, 2, 3, 64, 1, 'conv7x7'] + pool1: [1, 2, 64, 64, 1, 'maxp_3x3'] + stage1: [3, 1, 64, 256, 1, 'res_3x3', {'planes': 64}] + stage2: [4, 2, 256, 512, 1, 'res_3x3'] + stage3: [6, 2, 512, 1024, 1, 'res_3x3'] + stage4: [3, 2, 1024, 2048, 1, 'res_3x3'] + gavg_pool: [1, 2048, 2048, 1, 'gavgp'] +head: + linear1: + dim_in: 2048 + dim_out: 1000 + diff --git a/classification/configs/models/ResNet/resnext-50_32x4d.yaml b/classification/configs/models/ResNet/resnext-50_32x4d.yaml new file mode 100644 index 0000000..7116c19 --- /dev/null +++ b/classification/configs/models/ResNet/resnext-50_32x4d.yaml @@ -0,0 +1,14 @@ +backbone: + # name: [n, stride, in_channels, out_channels, expand_ratio, op] + conv_stem: [1, 2, 3, 64, 1, 'conv7x7'] + pool1: [1, 2, 64, 64, 1, 'maxp_3x3'] + stage1: [3, 1, 64, 256, 1, 'resnext_3x3', {'planes': 64}] + stage2: [4, 2, 256, 512, 1, 'resnext_3x3'] + stage3: [6, 2, 512, 1024, 1, 'resnext_3x3'] + stage4: [3, 2, 1024, 2048, 1, 'resnext_3x3'] + gavg_pool: [1, 2048, 2048, 1, 'gavgp'] +head: + linear1: + dim_in: 2048 + dim_out: 1000 + diff --git a/classification/configs/models/ResNet/seresnet-50.yaml b/classification/configs/models/ResNet/seresnet-50.yaml new file mode 100644 index 0000000..eb0294c --- /dev/null +++ b/classification/configs/models/ResNet/seresnet-50.yaml @@ -0,0 +1,14 @@ +backbone: + # name: [n, stride, in_channels, out_channels, expand_ratio, op] + conv_stem: [1, 2, 3, 64, 1, 'conv7x7'] + pool1: [1, 2, 64, 64, 1, 'maxp_3x3'] + stage1: [3, 1, 64, 256, 1, 'res_3x3_se', {'planes': 64}] + stage2: [4, 2, 256, 512, 1, 'res_3x3_se'] + stage3: [6, 2, 512, 1024, 1, 'res_3x3_se'] + stage4: [3, 2, 1024, 2048, 1, 'res_3x3_se'] + gavg_pool: [1, 2048, 2048, 1, 'gavgp'] +head: + linear1: + dim_in: 2048 + dim_out: 1000 + diff --git a/classification/configs/models/ResNet/seresnext-50_32x4d.yaml b/classification/configs/models/ResNet/seresnext-50_32x4d.yaml new file mode 100644 index 0000000..df73526 --- /dev/null +++ b/classification/configs/models/ResNet/seresnext-50_32x4d.yaml @@ -0,0 +1,14 @@ +backbone: + # name: [n, stride, in_channels, out_channels, expand_ratio, op] + conv_stem: [1, 2, 3, 64, 1, 'conv7x7'] + pool1: [1, 2, 64, 64, 1, 'maxp_3x3'] + stage1: [3, 1, 64, 256, 1, 'resnext_3x3_se', {'planes': 64}] + stage2: [4, 2, 256, 512, 1, 'resnext_3x3_se'] + stage3: [6, 2, 512, 1024, 1, 'resnext_3x3_se'] + stage4: [3, 2, 1024, 2048, 1, 'resnext_3x3_se'] + gavg_pool: [1, 2048, 2048, 1, 'gavgp'] +head: + linear1: + dim_in: 2048 + dim_out: 1000 + diff --git a/classification/configs/models/VGG/vgg16_cifar10.yaml b/classification/configs/models/VGG/vgg16_cifar10.yaml new file mode 100644 index 0000000..b06ac05 --- /dev/null +++ b/classification/configs/models/VGG/vgg16_cifar10.yaml @@ -0,0 +1,31 @@ +backbone: + # name: [n, stride, in_channels, out_channels, expand_ratio, op] + conv0: [1, 1, 3, 64, 1, 'conv3x3'] + conv1: [1, 1, 64, 64, 1, 'conv3x3'] + pool1: [1, 2, 64, 64, 1, 'maxp'] + + conv2: [1, 1, 64, 128, 1, 'conv3x3'] + conv3: [1, 1, 128, 128, 1, 'conv3x3'] + pool2: [1, 2, 128, 128, 1, 'maxp'] + + conv4: [1, 1, 128, 256, 1, 'conv3x3'] + conv5: [1, 1, 256, 256, 
1, 'conv3x3'] + conv6: [1, 1, 256, 256, 1, 'conv3x3'] + pool3: [1, 2, 256, 256, 1, 'maxp'] + + conv7: [1, 1, 256, 512, 1, 'conv3x3'] + conv8: [1, 1, 512, 512, 1, 'conv3x3'] + conv9: [1, 1, 512, 512, 1, 'conv3x3'] + pool4: [1, 2, 512, 512, 1, 'maxp'] + + conv10: [1, 1, 512, 512, 1, 'conv3x3'] + conv11: [1, 1, 512, 512, 1, 'conv3x3'] + conv12: [1, 1, 512, 512, 1, 'conv3x3'] + pool5: [1, 1, 512, 512, 1, 'maxp'] + fc: [1, 1, 512, 512, 1, 'linear_relu'] + #pool6: [1, 1, 512, 512, 1, 'gavgp'] + +head: + linear1: + dim_in: 512 + dim_out: 10 diff --git a/classification/configs/models/VGG/vgg16_cifar100.yaml b/classification/configs/models/VGG/vgg16_cifar100.yaml new file mode 100644 index 0000000..d085555 --- /dev/null +++ b/classification/configs/models/VGG/vgg16_cifar100.yaml @@ -0,0 +1,31 @@ +backbone: + # name: [n, stride, in_channels, out_channels, expand_ratio, op] + conv0: [1, 1, 3, 64, 1, 'conv3x3'] + conv1: [1, 1, 64, 64, 1, 'conv3x3'] + pool1: [1, 2, 64, 64, 1, 'maxp'] + + conv2: [1, 1, 64, 128, 1, 'conv3x3'] + conv3: [1, 1, 128, 128, 1, 'conv3x3'] + pool2: [1, 2, 128, 128, 1, 'maxp'] + + conv4: [1, 1, 128, 256, 1, 'conv3x3'] + conv5: [1, 1, 256, 256, 1, 'conv3x3'] + conv6: [1, 1, 256, 256, 1, 'conv3x3'] + pool3: [1, 2, 256, 256, 1, 'maxp'] + + conv7: [1, 1, 256, 512, 1, 'conv3x3'] + conv8: [1, 1, 512, 512, 1, 'conv3x3'] + conv9: [1, 1, 512, 512, 1, 'conv3x3'] + pool4: [1, 2, 512, 512, 1, 'maxp'] + + conv10: [1, 1, 512, 512, 1, 'conv3x3'] + conv11: [1, 1, 512, 512, 1, 'conv3x3'] + conv12: [1, 1, 512, 512, 1, 'conv3x3'] + pool5: [1, 1, 512, 512, 1, 'maxp'] + fc: [1, 1, 512, 512, 1, 'linear_relu'] + #pool6: [1, 1, 512, 512, 1, 'gavgp'] + +head: + linear1: + dim_in: 512 + dim_out: 100 diff --git a/classification/configs/strategies/CIFAR/cifar.yaml b/classification/configs/strategies/CIFAR/cifar.yaml new file mode 100644 index 0000000..9d65ac9 --- /dev/null +++ b/classification/configs/strategies/CIFAR/cifar.yaml @@ -0,0 +1,35 @@ +dataset: cifar10 +aa: null +batch_size: 128 +color_jitter: 0.0 +cutout_length: 16 +decay_by_epoch: false +decay_epochs: 3 +decay_rate: 0.967 +drop: 0.0 +epochs: 600 +log_interval: 50 +lr: 0.1 +smoothing: 0.0 +min_lr: 1.0e-06 +model_ema: false +model_ema_decay: 0.9998 +momentum: 0.9 +opt: sgd +opt_betas: null +opt_eps: 1.0e-08 +remode: const +reprob: 0.0 +sched: cosine +seed: 42 +warmup_epochs: 0 +warmup_lr: 0.2 +weight_decay: 1.0e-04 +workers: 4 +sgd_no_nesterov: True +opt_no_filter: True +# dyrep +dyrep: False +dyrep_adjust_interval: 15 +dyrep_recal_bn_every_epoch: False +dyrep_max_adjust_epochs: 500 diff --git a/classification/configs/strategies/DARTS/darts.yaml b/classification/configs/strategies/DARTS/darts.yaml new file mode 100644 index 0000000..8235b66 --- /dev/null +++ b/classification/configs/strategies/DARTS/darts.yaml @@ -0,0 +1,27 @@ +aa: null +batch_size: 128 +color_jitter: 0.0 +decay_by_epoch: false +decay_epochs: 3 +decay_rate: 0.967 +drop: 0.0 +epochs: 250 +log_interval: 50 +lr: 0.1 +min_lr: 1.0e-05 +model_ema: false +model_ema_decay: 0.9998 +momentum: 0.9 +opt: sgd +opt_betas: null +opt_eps: 1.0e-08 +remode: const +reprob: 0.0 +auxiliary: True +auxiliary_weight: 0.4 +sched: cosine +seed: 42 +warmup_epochs: 0 +warmup_lr: 0.2 +weight_decay: 3.0e-05 +workers: 4 diff --git a/classification/configs/strategies/DARTS/darts_cifar10.yaml b/classification/configs/strategies/DARTS/darts_cifar10.yaml new file mode 100644 index 0000000..cbe4d1b --- /dev/null +++ b/classification/configs/strategies/DARTS/darts_cifar10.yaml @@ -0,0 +1,30 @@ +dataset: cifar10 +aa: 
null +batch_size: 96 +color_jitter: 0.0 +decay_by_epoch: false +decay_epochs: 3 +decay_rate: 0.967 +drop: 0.0 +drop_path_rate: 0.2 +drop_path_strategy: linear +epochs: 600 +log_interval: 50 +lr: 0.175 +min_lr: 1.0e-05 +model_ema: false +model_ema_decay: 0.9998 +momentum: 0.9 +opt: sgd +opt_betas: null +opt_eps: 1.0e-08 +remode: const +reprob: 1.0 +auxiliary: True +auxiliary_weight: 0.4 +sched: cosine +seed: 42 +warmup_epochs: 0 +warmup_lr: 0.2 +weight_decay: 3.0e-05 +workers: 4 diff --git a/classification/configs/strategies/DARTS/pc-darts.yaml b/classification/configs/strategies/DARTS/pc-darts.yaml new file mode 100644 index 0000000..7c5eb69 --- /dev/null +++ b/classification/configs/strategies/DARTS/pc-darts.yaml @@ -0,0 +1,27 @@ +aa: null +batch_size: 128 +color_jitter: 0.4 +decay_by_epoch: True +decay_epochs: 3 +decay_rate: 0.967 +drop: 0.0 +epochs: 250 +log_interval: 50 +lr: 0.5 +min_lr: 1.0e-05 +model_ema: false +model_ema_decay: 0.9998 +momentum: 0.9 +opt: sgd +opt_betas: null +opt_eps: 1.0e-08 +remode: const +reprob: 0.0 +sched: linear +auxiliary: True +auxiliary_weight: 0.4 +seed: 42 +warmup_epochs: 5 +warmup_lr: 0.0125 +weight_decay: 3.0e-05 +workers: 4 diff --git a/classification/configs/strategies/DyRep/cifar.yaml b/classification/configs/strategies/DyRep/cifar.yaml new file mode 100644 index 0000000..9d65ac9 --- /dev/null +++ b/classification/configs/strategies/DyRep/cifar.yaml @@ -0,0 +1,35 @@ +dataset: cifar10 +aa: null +batch_size: 128 +color_jitter: 0.0 +cutout_length: 16 +decay_by_epoch: false +decay_epochs: 3 +decay_rate: 0.967 +drop: 0.0 +epochs: 600 +log_interval: 50 +lr: 0.1 +smoothing: 0.0 +min_lr: 1.0e-06 +model_ema: false +model_ema_decay: 0.9998 +momentum: 0.9 +opt: sgd +opt_betas: null +opt_eps: 1.0e-08 +remode: const +reprob: 0.0 +sched: cosine +seed: 42 +warmup_epochs: 0 +warmup_lr: 0.2 +weight_decay: 1.0e-04 +workers: 4 +sgd_no_nesterov: True +opt_no_filter: True +# dyrep +dyrep: False +dyrep_adjust_interval: 15 +dyrep_recal_bn_every_epoch: False +dyrep_max_adjust_epochs: 500 diff --git a/classification/configs/strategies/DyRep/mbv1.yaml b/classification/configs/strategies/DyRep/mbv1.yaml new file mode 100644 index 0000000..17af0c1 --- /dev/null +++ b/classification/configs/strategies/DyRep/mbv1.yaml @@ -0,0 +1,36 @@ +# Note: +# Differences between our strategy and the official dbb strategy: +# 1. 
we do not use PCA lighting augmentation +dataset: imagenet +aa: null +batch_size: 32 +color_jitter: 0.06 +decay_by_epoch: false +decay_epochs: 3 +decay_rate: 0.967 +drop: 0.0 +drop-path-rate: 0.0 +epochs: 90 +log_interval: 50 +lr: 0.1 +smoothing: 0.0 +min_lr: 1.0e-07 +model_ema: false +model_ema_decay: 0.9998 +momentum: 0.9 +opt: sgd +opt_betas: null +opt_eps: 1.0e-08 +remode: const +reprob: 0.0 +sched: cosine +seed: 42 +warmup_epochs: 0 +warmup_lr: 0.2 +weight_decay: 4.0e-05 +workers: 4 +# dyrep +dyrep: False +dyrep_adjust_interval: 5 +dyrep_recal_bn_every_epoch: False +dyrep_max_adjust_epochs: 70 diff --git a/classification/configs/strategies/DyRep/repvgg_baseline.yaml b/classification/configs/strategies/DyRep/repvgg_baseline.yaml new file mode 100644 index 0000000..d6704e0 --- /dev/null +++ b/classification/configs/strategies/DyRep/repvgg_baseline.yaml @@ -0,0 +1,32 @@ +dataset: imagenet +aa: null +batch_size: 32 +color_jitter: 0.0 +decay_by_epoch: false +decay_epochs: 3 +decay_rate: 0.967 +drop: 0.0 +epochs: 120 +log_interval: 50 +lr: 0.1 +smoothing: 0.0 +min_lr: 1.0e-07 +model_ema: false +model_ema_decay: 0.9998 +momentum: 0.9 +opt: sgd +opt_betas: null +opt_eps: 1.0e-08 +remode: const +reprob: 0.0 +sched: cosine +seed: 42 +warmup_epochs: 0 +warmup_lr: 0.2 +weight_decay: 1.0e-04 +workers: 4 +# dyrep +dyrep: False +dyrep_adjust_interval: 5 +dyrep_recal_bn_every_epoch: False +dyrep_max_adjust_epochs: 100 diff --git a/classification/configs/strategies/DyRep/repvgg_strong.yaml b/classification/configs/strategies/DyRep/repvgg_strong.yaml new file mode 100644 index 0000000..38b8f1a --- /dev/null +++ b/classification/configs/strategies/DyRep/repvgg_strong.yaml @@ -0,0 +1,32 @@ +dataset: imagenet +aa: rand-m9-mstd0.5 +batch_size: 32 +color_jitter: 0.0 +decay_by_epoch: false +decay_epochs: 3 +decay_rate: 0.967 +drop: 0.0 +epochs: 200 +log_interval: 50 +lr: 0.1 +smoothing: 0.1 +min_lr: 1.0e-07 +model_ema: false +model_ema_decay: 0.9998 +momentum: 0.9 +opt: sgd +opt_betas: null +opt_eps: 1.0e-08 +remode: pixel +reprob: 0.2 +sched: cosine +seed: 42 +warmup_epochs: 5 +warmup_lr: 1.0e-6 +weight_decay: 1.0e-04 +workers: 4 +# dyrep +dyrep: True +dyrep_adjust_interval: 5 +dyrep_recal_bn_every_epoch: False +dyrep_max_adjust_epochs: 100 diff --git a/classification/configs/strategies/DyRep/resnet.yaml b/classification/configs/strategies/DyRep/resnet.yaml new file mode 100644 index 0000000..6a66394 --- /dev/null +++ b/classification/configs/strategies/DyRep/resnet.yaml @@ -0,0 +1,36 @@ +# Note: +# Differences between our strategy and the official dbb strategy: +# 1. 
we do not use PCA lighting augmentation +dataset: imagenet +aa: null +batch_size: 32 +color_jitter: 0.06 +decay_by_epoch: false +decay_epochs: 3 +decay_rate: 0.967 +drop: 0.0 +drop-path-rate: 0.0 +epochs: 120 +log_interval: 50 +lr: 0.1 +smoothing: 0.0 +min_lr: 1.0e-07 +model_ema: false +model_ema_decay: 0.9998 +momentum: 0.9 +opt: sgd +opt_betas: null +opt_eps: 1.0e-08 +remode: const +reprob: 0.0 +sched: cosine +seed: 42 +warmup_epochs: 0 +warmup_lr: 0.2 +weight_decay: 1.0e-04 +workers: 4 +# dyrep +dyrep: False +dyrep_adjust_interval: 5 +dyrep_recal_bn_every_epoch: False +dyrep_max_adjust_epochs: 100 diff --git a/classification/configs/strategies/MBV2/mbv2.yaml b/classification/configs/strategies/MBV2/mbv2.yaml new file mode 100644 index 0000000..c9d7a2e --- /dev/null +++ b/classification/configs/strategies/MBV2/mbv2.yaml @@ -0,0 +1,25 @@ +aa: null +batch_size: 64 +color_jitter: 0.0 +decay_by_epoch: false +decay_epochs: 3 +decay_rate: 0.967 +drop: 0.0 +epochs: 300 +log_interval: 50 +lr: 0.1 +min_lr: 1.0e-05 +model_ema: false +model_ema_decay: 0.9998 +momentum: 0.9 +opt: sgd +opt_betas: null +opt_eps: 1.0e-08 +remode: const +reprob: 0.0 +sched: cosine +seed: 42 +warmup_epochs: 5 +warmup_lr: 0.2 +weight_decay: 5.0e-05 +workers: 4 diff --git a/classification/configs/strategies/MBV2/mbv2_cifar10.yaml b/classification/configs/strategies/MBV2/mbv2_cifar10.yaml new file mode 100644 index 0000000..08deb40 --- /dev/null +++ b/classification/configs/strategies/MBV2/mbv2_cifar10.yaml @@ -0,0 +1,27 @@ +dataset: cifar10 +aa: null +batch_size: 256 +color_jitter: 0.0 +decay_by_epoch: false +decay_epochs: 3 +decay_rate: 0.967 +drop: 0.0 +epochs: 50 +log_interval: 50 +lr: 0.1 +smoothing: 0.0 +min_lr: 1.0e-05 +model_ema: false +model_ema_decay: 0.9998 +momentum: 0.9 +opt: sgd +opt_betas: null +opt_eps: 1.0e-08 +remode: const +reprob: 0.0 +sched: cosine +seed: 42 +warmup_epochs: 0 +warmup_lr: 0.2 +weight_decay: 5.0e-04 +workers: 4 diff --git a/classification/configs/strategies/MBV2/mbv2_se_aa.yaml b/classification/configs/strategies/MBV2/mbv2_se_aa.yaml new file mode 100644 index 0000000..efc097f --- /dev/null +++ b/classification/configs/strategies/MBV2/mbv2_se_aa.yaml @@ -0,0 +1,24 @@ +aa: rand-m9-mstd0.5 +batch_size: 96 +decay_by_epoch: false +decay_epochs: 2.4 +decay_rate: 0.97 +drop: 0.2 +epochs: 450 +log_interval: 50 +lr: 0.048 +min_lr: 1.0e-05 +model_ema: true +model_ema_decay: 0.9999 +momentum: 0.9 +opt: rmsproptf +opt_betas: null +opt_eps: 0.001 +remode: pixel +reprob: 0.2 +sched: step +seed: 42 +warmup_epochs: 3 +warmup_lr: 1.0e-06 +weight_decay: 1.0e-05 +workers: 4 diff --git a/classification/configs/strategies/deit/deit_tiny.yaml b/classification/configs/strategies/deit/deit_tiny.yaml new file mode 100644 index 0000000..9a01b39 --- /dev/null +++ b/classification/configs/strategies/deit/deit_tiny.yaml @@ -0,0 +1,40 @@ +aa: rand-m9-mstd0.5 +batch_size: 128 +color_jitter: 0.0 +decay_by_epoch: false +decay_epochs: 3 +decay_rate: 0.967 +# dropout +drop: 0.0 +drop_path_rate: 0.1 + +epochs: 300 +log_interval: 50 +lr: 1.e-3 +min_lr: 1.0e-05 +model_ema: False +model_ema_decay: 0.99996 +momentum: 0.9 +opt: adamw +opt_betas: null +opt_eps: 1.0e-08 + +interpolation: 'bicubic' + +# random erase +remode: pixel +reprob: 0.25 + +# mixup +mixup: 0.8 +cutmix: 1.0 +mixup_prob: 1.0 +mixup_switch_prob: 0.5 +mixup_mode: 'batch' + +sched: cosine +seed: 42 +warmup_epochs: 5 +warmup_lr: 1.e-6 +weight_decay: 0.05 +workers: 8 diff --git a/classification/configs/strategies/distill/diffkd/diffkd_b1.yaml 
b/classification/configs/strategies/distill/diffkd/diffkd_b1.yaml new file mode 100644 index 0000000..a722d37 --- /dev/null +++ b/classification/configs/strategies/distill/diffkd/diffkd_b1.yaml @@ -0,0 +1,36 @@ +aa: null +batch_size: 32 +color_jitter: 0.0 +decay_by_epoch: True +decay_epochs: 30 +decay_rate: 0.1 +drop: 0.0 +epochs: 100 +log_interval: 50 +lr: 0.1 +min_lr: 1.0e-05 +model_ema: false +model_ema_decay: 0.9998 +momentum: 0.9 +smoothing: 0.0 +opt: sgd +opt_betas: null +opt_eps: 1.0e-08 +remode: const +reprob: 0.0 +sched: step +seed: 42 +warmup_epochs: 0 +warmup_lr: 0.2 +weight_decay: 1.0e-4 +workers: 8 +# kd +kd: 'diffkd' +ori_loss_weight: 1. +kd_loss_weight: 1. +teacher_model: 'tv_resnet34' +teacher_pretrained: True +kd_loss_kwargs: + use_ae: True + ae_channels: 1024 + tau: 1 \ No newline at end of file diff --git a/classification/configs/strategies/distill/diffkd/diffkd_b2.yaml b/classification/configs/strategies/distill/diffkd/diffkd_b2.yaml new file mode 100644 index 0000000..2a3e8cc --- /dev/null +++ b/classification/configs/strategies/distill/diffkd/diffkd_b2.yaml @@ -0,0 +1,34 @@ +aa: rand-m9-mstd0.5 +batch_size: 96 +decay_by_epoch: false +decay_epochs: 2.4 +decay_rate: 0.97 +drop: 0.2 +epochs: 450 +log_interval: 50 +lr: 0.048 +min_lr: 1.0e-05 +model_ema: true +model_ema_decay: 0.9999 +momentum: 0.9 +opt: rmsproptf +opt_betas: null +opt_eps: 0.001 +remode: pixel +reprob: 0.2 +sched: step +seed: 42 +warmup_epochs: 3 +warmup_lr: 1.0e-06 +weight_decay: 1.0e-05 +workers: 4 +# kd +kd: 'dist' +ori_loss_weight: 1. +kd_loss_weight: 1. +teacher_model: 'tv_resnet34' +teacher_pretrained: True +kd_loss_kwargs: + use_ae: True + ae_channels: 1024 + tau: 1 \ No newline at end of file diff --git a/classification/configs/strategies/distill/diffkd/diffkd_b3.yaml b/classification/configs/strategies/distill/diffkd/diffkd_b3.yaml new file mode 100644 index 0000000..cf3d2a4 --- /dev/null +++ b/classification/configs/strategies/distill/diffkd/diffkd_b3.yaml @@ -0,0 +1,53 @@ +aa: rand-m9-mstd0.5 +batch_size: 64 # x 16 gpus = 1024bs +color_jitter: 0.4 +decay_by_epoch: false +decay_epochs: 3 +decay_rate: 0.967 +# dropout +drop: 0.0 +drop_path_rate: 0.2 + +epochs: 300 +log_interval: 50 +lr: 1.e-3 +min_lr: 5.0e-06 +model_ema: False +model_ema_decay: 0.999 +momentum: 0.9 +opt: adamw +opt_betas: null +opt_eps: 1.0e-08 +clip_grad_norm: true +clip_grad_max_norm: 5.0 + +interpolation: 'bicubic' + +# random erase +remode: pixel +reprob: 0.25 + +# mixup +mixup: 0.8 +cutmix: 1.0 +mixup_prob: 1.0 +mixup_switch_prob: 0.5 +mixup_mode: 'batch' + +sched: cosine +seed: 42 +warmup_epochs: 20 +warmup_lr: 5.e-7 +weight_decay: 0.04 +workers: 16 + +# kd +#kd: 'dist' +ori_loss_weight: 1. +kd_loss_weight: 1. 
+teacher_model: 'timm_swin_large_patch4_window7_224' +teacher_pretrained: True +kd_loss_kwargs: + use_ae: True + ae_channels: 1024 + tau: 1 \ No newline at end of file diff --git a/classification/configs/strategies/distill/dist_b2.yaml b/classification/configs/strategies/distill/dist_b2.yaml new file mode 100644 index 0000000..7b0ce71 --- /dev/null +++ b/classification/configs/strategies/distill/dist_b2.yaml @@ -0,0 +1,30 @@ +aa: rand-m9-mstd0.5 +batch_size: 96 +decay_by_epoch: false +decay_epochs: 2.4 +decay_rate: 0.97 +drop: 0.2 +epochs: 450 +log_interval: 50 +lr: 0.048 +min_lr: 1.0e-05 +model_ema: true +model_ema_decay: 0.9999 +momentum: 0.9 +opt: rmsproptf +opt_betas: null +opt_eps: 0.001 +remode: pixel +reprob: 0.2 +sched: step +seed: 42 +warmup_epochs: 3 +warmup_lr: 1.0e-06 +weight_decay: 1.0e-05 +workers: 4 +# kd +kd: 'dist' +ori_loss_weight: 1. +kd_loss_weight: 2. +teacher_model: 'tv_resnet34' +teacher_pretrained: True diff --git a/classification/configs/strategies/distill/dist_cifar.yaml b/classification/configs/strategies/distill/dist_cifar.yaml new file mode 100644 index 0000000..2846948 --- /dev/null +++ b/classification/configs/strategies/distill/dist_cifar.yaml @@ -0,0 +1,39 @@ +dataset: cifar100 +image_mean: [0.5071, 0.4867, 0.4408] +image_std: [0.2675, 0.2565, 0.2761] +aa: null +batch_size: 64 +color_jitter: 0.0 +cutout_length: 0 +decay_by_epoch: True +decay_epochs: 30 +decay_rate: 0.1 +drop: 0.0 +epochs: 240 +log_interval: 50 +lr: 0.05 +smoothing: 0.0 +min_lr: 1.0e-06 +model_ema: false +model_ema_decay: 0.9998 +momentum: 0.9 +opt: sgd +opt_betas: null +opt_eps: 1.0e-08 +remode: const +reprob: 0.0 +sched: step +seed: 42 +warmup_epochs: 120 +warmup_lr: 0.05 +weight_decay: 5.0e-04 +workers: 4 +sgd_no_nesterov: True +opt_no_filter: True +# kd +ori_loss_weight: 1. +kd_loss_weight: 2. +kd: dist_t4 +teacher_model: cifar_wrn_40_2 +teacher_pretrained: True +teacher_ckpt: ./data/saved_ckpts/wrn_40_2_vanilla/ckpt_epoch_240.pth diff --git a/classification/configs/strategies/distill/resnet_dist.yaml b/classification/configs/strategies/distill/resnet_dist.yaml new file mode 100644 index 0000000..c81764e --- /dev/null +++ b/classification/configs/strategies/distill/resnet_dist.yaml @@ -0,0 +1,32 @@ +aa: null +batch_size: 32 +color_jitter: 0.0 +decay_by_epoch: True +decay_epochs: 30 +decay_rate: 0.1 +drop: 0.0 +epochs: 100 +log_interval: 50 +lr: 0.1 +min_lr: 1.0e-05 +model_ema: false +model_ema_decay: 0.9998 +momentum: 0.9 +smoothing: 0.0 +opt: sgd +opt_betas: null +opt_eps: 1.0e-08 +remode: const +reprob: 0.0 +sched: step +seed: 42 +warmup_epochs: 0 +warmup_lr: 0.2 +weight_decay: 1.0e-4 +workers: 8 +# kd +kd: 'dist' +ori_loss_weight: 1. +kd_loss_weight: 2. 
+teacher_model: 'tv_resnet34' +teacher_pretrained: True diff --git a/classification/configs/strategies/distill/resnet_kdt4.yaml b/classification/configs/strategies/distill/resnet_kdt4.yaml new file mode 100644 index 0000000..86bb811 --- /dev/null +++ b/classification/configs/strategies/distill/resnet_kdt4.yaml @@ -0,0 +1,32 @@ +aa: null +batch_size: 32 +color_jitter: 0.0 +decay_by_epoch: True +decay_epochs: 30 +decay_rate: 0.1 +drop: 0.0 +epochs: 100 +log_interval: 50 +lr: 0.1 +min_lr: 1.0e-05 +model_ema: false +model_ema_decay: 0.9998 +momentum: 0.9 +smoothing: 0.0 +opt: sgd +opt_betas: null +opt_eps: 1.0e-08 +remode: const +reprob: 0.0 +sched: step +seed: 42 +warmup_epochs: 0 +warmup_lr: 0.2 +weight_decay: 1.0e-4 +workers: 8 +# kd +kd: 'kdt4' +ori_loss_weight: 0.9 +kd_loss_weight: 1.0 +teacher_model: 'tv_resnet34' +teacher_pretrained: True diff --git a/classification/configs/strategies/lightvit/config.yaml b/classification/configs/strategies/lightvit/config.yaml new file mode 100644 index 0000000..d54a92e --- /dev/null +++ b/classification/configs/strategies/lightvit/config.yaml @@ -0,0 +1,40 @@ +aa: rand-m9-mstd0.5 +batch_size: 128 +color_jitter: 0.3 +decay_by_epoch: false +decay_epochs: 3 +decay_rate: 0.967 +# dropout +drop: 0.0 +drop_path_rate: 0.1 + +epochs: 300 +log_interval: 50 +lr: 1.e-3 +min_lr: 1.0e-06 +model_ema: False +model_ema_decay: 0.999 +momentum: 0.9 +opt: adamw +opt_betas: null +opt_eps: 1.0e-08 + +interpolation: 'bicubic' + +# random erase +remode: pixel +reprob: 0.25 + +# mixup +mixup: 0.2 +cutmix: 1.0 +mixup_prob: 1.0 +mixup_switch_prob: 0.5 +mixup_mode: 'batch' + +sched: cosine +seed: 42 +warmup_epochs: 20 +warmup_lr: 1.e-7 +weight_decay: 0.04 +workers: 16 diff --git a/classification/configs/strategies/local_mamba/config.yaml b/classification/configs/strategies/local_mamba/config.yaml new file mode 100644 index 0000000..d54a92e --- /dev/null +++ b/classification/configs/strategies/local_mamba/config.yaml @@ -0,0 +1,40 @@ +aa: rand-m9-mstd0.5 +batch_size: 128 +color_jitter: 0.3 +decay_by_epoch: false +decay_epochs: 3 +decay_rate: 0.967 +# dropout +drop: 0.0 +drop_path_rate: 0.1 + +epochs: 300 +log_interval: 50 +lr: 1.e-3 +min_lr: 1.0e-06 +model_ema: False +model_ema_decay: 0.999 +momentum: 0.9 +opt: adamw +opt_betas: null +opt_eps: 1.0e-08 + +interpolation: 'bicubic' + +# random erase +remode: pixel +reprob: 0.25 + +# mixup +mixup: 0.2 +cutmix: 1.0 +mixup_prob: 1.0 +mixup_switch_prob: 0.5 +mixup_mode: 'batch' + +sched: cosine +seed: 42 +warmup_epochs: 20 +warmup_lr: 1.e-7 +weight_decay: 0.04 +workers: 16 diff --git a/classification/configs/strategies/resnet/resnet.yaml b/classification/configs/strategies/resnet/resnet.yaml new file mode 100644 index 0000000..903549a --- /dev/null +++ b/classification/configs/strategies/resnet/resnet.yaml @@ -0,0 +1,26 @@ +aa: null +batch_size: 32 +color_jitter: 0.0 +decay_by_epoch: True +decay_epochs: 30 +decay_rate: 0.1 +drop: 0.0 +epochs: 120 +log_interval: 50 +lr: 0.1 +min_lr: 1.0e-05 +model_ema: false +model_ema_decay: 0.9998 +momentum: 0.9 +smoothing: 0.0 +opt: sgd +opt_betas: null +opt_eps: 1.0e-08 +remode: const +reprob: 0.0 +sched: step +seed: 42 +warmup_epochs: 0 +warmup_lr: 0.2 +weight_decay: 1.0e-4 +workers: 4 diff --git a/classification/configs/strategies/resnet/seresnext.yaml b/classification/configs/strategies/resnet/seresnext.yaml new file mode 100644 index 0000000..26e0c3b --- /dev/null +++ b/classification/configs/strategies/resnet/seresnext.yaml @@ -0,0 +1,24 @@ +aa: rand-m9-mstd0.5 +batch_size: 192 +decay_by_epoch: 
false +decay_epochs: 2.4 +decay_rate: 0.97 +drop: 0.2 +epochs: 240 +log_interval: 50 +lr: 0.6 +min_lr: 1.0e-06 +model_ema: true +model_ema_decay: 0.9999 +momentum: 0.9 +opt: sgd +opt_betas: null +opt_eps: 0.001 +remode: pixel +reprob: 0.4 +sched: cosine +seed: 42 +warmup_epochs: 5 +warmup_lr: 1.0e-06 +weight_decay: 1.0e-04 +workers: 4 diff --git a/classification/lib/__init__.py b/classification/lib/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/classification/lib/dataset/__init__.py b/classification/lib/dataset/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/classification/lib/dataset/augment_ops.py b/classification/lib/dataset/augment_ops.py new file mode 100644 index 0000000..1eff442 --- /dev/null +++ b/classification/lib/dataset/augment_ops.py @@ -0,0 +1,809 @@ +""" +Random augmentation implemented by https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/auto_augment.py +""" +import math +import random +import re +import torch +import numpy as np +import PIL +from PIL import Image, ImageEnhance, ImageOps +import torchvision.transforms.functional as F + + +_PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]]) +_FILL = (128, 128, 128) +_MAX_LEVEL = 10. +_HPARAMS_DEFAULT = dict(translate_const=250, img_mean=_FILL) +_RAND_TRANSFORMS = [ + 'Distort', + 'Zoom', + 'Blur', + 'Skew', + 'AutoContrast', + 'Equalize', + 'Invert', + 'Rotate', + 'PosterizeTpu', + 'Solarize', + 'SolarizeAdd', + 'Color', + 'Contrast', + 'Brightness', + 'Sharpness', + 'ShearX', + 'ShearY', + 'TranslateXRel', + 'TranslateYRel', +] +_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) +_RAND_CHOICE_WEIGHTS_0 = { + 'Rotate': 0.3, + 'ShearX': 0.2, + 'ShearY': 0.2, + 'TranslateXRel': 0.1, + 'TranslateYRel': 0.1, + 'Color': .025, + 'Sharpness': 0.025, + 'AutoContrast': 0.025, + 'Solarize': .005, + 'SolarizeAdd': .005, + 'Contrast': .005, + 'Brightness': .005, + 'Equalize': .005, + 'PosterizeTpu': 0, + 'Invert': 0, + 'Distort': 0, + 'Zoom': 0, + 'Blur': 0, + 'Skew': 0, +} + + +def _interpolation(kwargs): + interpolation = kwargs.pop('resample', Image.BILINEAR) + if isinstance(interpolation, (list, tuple)): + return random.choice(interpolation) + return interpolation + + +def _check_args_tf(kwargs): + if 'fillcolor' in kwargs and _PIL_VER < (5, 0): + kwargs.pop('fillcolor') + kwargs['resample'] = _interpolation(kwargs) + +# define all kinds of functions + + +def _randomly_negate(v): + return -v if random.random() > 0.5 else v + + +def shear_x(img, factor, **kwargs): + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs) + + +def shear_y(img, factor, **kwargs): + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs) + + +def translate_x_rel(img, pct, **kwargs): + pixels = pct * img.size[0] + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs) + + +def translate_y_rel(img, pct, **kwargs): + pixels = pct * img.size[1] + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs) + + +def translate_x_abs(img, pixels, **kwargs): + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs) + + +def translate_y_abs(img, pixels, **kwargs): + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs) + + +def rotate(img, degrees, **kwargs): + kwargs_new = kwargs + 
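+    # discard the caller-supplied resample filter and force bicubic resampling for the rotation below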
kwargs_new.pop('resample') + kwargs_new['resample'] = Image.BICUBIC + if _PIL_VER >= (5, 2): + return img.rotate(degrees, **kwargs_new) + if _PIL_VER >= (5, 0): + w, h = img.size + post_trans = (0, 0) + rotn_center = (w / 2.0, h / 2.0) + angle = -math.radians(degrees) + matrix = [ + round(math.cos(angle), 15), + round(math.sin(angle), 15), + 0.0, + round(-math.sin(angle), 15), + round(math.cos(angle), 15), + 0.0, + ] + + def transform(x, y, matrix): + (a, b, c, d, e, f) = matrix + return a * x + b * y + c, d * x + e * y + f + + matrix[2], matrix[5] = transform( + -rotn_center[0] - post_trans[0], -rotn_center[1] - post_trans[1], matrix + ) + matrix[2] += rotn_center[0] + matrix[5] += rotn_center[1] + return img.transform(img.size, Image.AFFINE, matrix, **kwargs_new) + return img.rotate(degrees, resample=kwargs['resample']) + + +def auto_contrast(img, **__): + return ImageOps.autocontrast(img) + + +def invert(img, **__): + return ImageOps.invert(img) + + +def equalize(img, **__): + return ImageOps.equalize(img) + + +def solarize(img, thresh, **__): + return ImageOps.solarize(img, thresh) + + +def solarize_add(img, add, thresh=128, **__): + lut = [] + for i in range(256): + if i < thresh: + lut.append(min(255, i + add)) + else: + lut.append(i) + if img.mode in ("L", "RGB"): + if img.mode == "RGB" and len(lut) == 256: + lut = lut + lut + lut + return img.point(lut) + return img + + +def posterize(img, bits_to_keep, **__): + if bits_to_keep >= 8: + return img + return ImageOps.posterize(img, bits_to_keep) + + +def contrast(img, factor, **__): + return ImageEnhance.Contrast(img).enhance(factor) + + +def color(img, factor, **__): + return ImageEnhance.Color(img).enhance(factor) + + +def brightness(img, factor, **__): + return ImageEnhance.Brightness(img).enhance(factor) + + +def sharpness(img, factor, **__): + return ImageEnhance.Sharpness(img).enhance(factor) + + +def _rotate_level_to_arg(level, _hparams): + # range [-30, 30] + level = (level / _MAX_LEVEL) * 30. 
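+    # _randomly_negate flips the sign with probability 0.5, so the rotation is drawn from the full [-30, 30] degree range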
+ level = _randomly_negate(level) + return (level,) + + +def _enhance_level_to_arg(level, _hparams): + # range [0.1, 1.9] + return ((level / _MAX_LEVEL) * 1.8 + 0.1,) + + +def _shear_level_to_arg(level, _hparams): + # range [-0.3, 0.3] + level = (level / _MAX_LEVEL) * 0.3 + level = _randomly_negate(level) + return (level,) + + +def _translate_abs_level_to_arg(level, hparams): + translate_const = hparams['translate_const'] + level = (level / _MAX_LEVEL) * float(translate_const) + level = _randomly_negate(level) + return (level,) + + +def _translate_rel_level_to_arg(level, _hparams): + # range [-0.45, 0.45] + level = (level / _MAX_LEVEL) * 0.45 + level = _randomly_negate(level) + return (level,) + + +def _posterize_original_level_to_arg(level, _hparams): + # As per original AutoAugment paper description + # range [4, 8], 'keep 4 up to 8 MSB of image' + return (int((level / _MAX_LEVEL) * 4) + 4,) + + +def _posterize_research_level_to_arg(level, _hparams): + # As per Tensorflow models research and UDA impl + # range [4, 0], 'keep 4 down to 0 MSB of original image' + return (4 - int((level / _MAX_LEVEL) * 4),) + + +def _posterize_tpu_level_to_arg(level, _hparams): + # As per Tensorflow TPU EfficientNet impl + # range [0, 4], 'keep 0 up to 4 MSB of original image' + return (int((level / _MAX_LEVEL) * 4),) + + +def _solarize_level_to_arg(level, _hparams): + # range [0, 256] + return (int((level / _MAX_LEVEL) * 256),) + + +def _solarize_add_level_to_arg(level, _hparams): + # range [0, 110] + return (int((level / _MAX_LEVEL) * 110),) + + +def _distort_level_to_arg(level, _hparams): + return (int((level / _MAX_LEVEL) * 10 + 10),) + + +def _zoom_level_to_arg(level, _hparams): + return ((level / _MAX_LEVEL) * 0.4,) + + +def _blur_level_to_arg(level, _hparams): + level = (level / _MAX_LEVEL) * 0.5 + level = _randomly_negate(level) + return (level,) + + +def _skew_level_to_arg(level, _hparams): + level = (level / _MAX_LEVEL) * 0.3 + level = _randomly_negate(level) + return (level,) + + +def distort(img, v, **__): + w, h = img.size + horizontal_tiles = int(0.1 * v) + vertical_tiles = int(0.1 * v) + + width_of_square = int(math.floor(w / float(horizontal_tiles))) + height_of_square = int(math.floor(h / float(vertical_tiles))) + width_of_last_square = w - (width_of_square * (horizontal_tiles - 1)) + height_of_last_square = h - (height_of_square * (vertical_tiles - 1)) + dimensions = [] + + for vertical_tile in range(vertical_tiles): + for horizontal_tile in range(horizontal_tiles): + if vertical_tile == (vertical_tiles - 1) and horizontal_tile == (horizontal_tiles - 1): + dimensions.append([horizontal_tile * width_of_square, + vertical_tile * height_of_square, + width_of_last_square + (horizontal_tile * width_of_square), + height_of_last_square + (height_of_square * vertical_tile)]) + elif vertical_tile == (vertical_tiles - 1): + dimensions.append([horizontal_tile * width_of_square, + vertical_tile * height_of_square, + width_of_square + (horizontal_tile * width_of_square), + height_of_last_square + (height_of_square * vertical_tile)]) + elif horizontal_tile == (horizontal_tiles - 1): + dimensions.append([horizontal_tile * width_of_square, + vertical_tile * height_of_square, + width_of_last_square + (horizontal_tile * width_of_square), + height_of_square + (height_of_square * vertical_tile)]) + else: + dimensions.append([horizontal_tile * width_of_square, + vertical_tile * height_of_square, + width_of_square + (horizontal_tile * width_of_square), + height_of_square + (height_of_square * vertical_tile)]) + 
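+ # the loop above tiles the image into a vertical_tiles x horizontal_tiles grid of
+ # boxes, with the last row/column absorbing any remainder pixels; the shared
+ # corners of these tiles are shifted by v pixels below to build the MESH warp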
last_column = [] + for i in range(vertical_tiles): + last_column.append((horizontal_tiles - 1) + horizontal_tiles * i) + + last_row = range((horizontal_tiles * vertical_tiles) - horizontal_tiles, horizontal_tiles * vertical_tiles) + + polygons = [] + for x1, y1, x2, y2 in dimensions: + polygons.append([x1, y1, x1, y2, x2, y2, x2, y1]) + + polygon_indices = [] + for i in range((vertical_tiles * horizontal_tiles) - 1): + if i not in last_row and i not in last_column: + polygon_indices.append([i, i + 1, i + horizontal_tiles, i + 1 + horizontal_tiles]) + + for a, b, c, d in polygon_indices: + dx = v + dy = v + + x1, y1, x2, y2, x3, y3, x4, y4 = polygons[a] + polygons[a] = [x1, y1, + x2, y2, + x3 + dx, y3 + dy, + x4, y4] + + x1, y1, x2, y2, x3, y3, x4, y4 = polygons[b] + polygons[b] = [x1, y1, + x2 + dx, y2 + dy, + x3, y3, + x4, y4] + + x1, y1, x2, y2, x3, y3, x4, y4 = polygons[c] + polygons[c] = [x1, y1, + x2, y2, + x3, y3, + x4 + dx, y4 + dy] + + x1, y1, x2, y2, x3, y3, x4, y4 = polygons[d] + polygons[d] = [x1 + dx, y1 + dy, + x2, y2, + x3, y3, + x4, y4] + + generated_mesh = [] + for idx, i in enumerate(dimensions): + generated_mesh.append([dimensions[idx], polygons[idx]]) + return img.transform(img.size, PIL.Image.MESH, generated_mesh, resample=PIL.Image.BICUBIC) + + +def zoom(img, v, **__): + #assert 0.1 <= v <= 2 + w, h = img.size + image_zoomed = img.resize((int(round(img.size[0] * v)), + int(round(img.size[1] * v))), + resample=PIL.Image.BICUBIC) + w_zoomed, h_zoomed = image_zoomed.size + + return image_zoomed.crop((math.floor((float(w_zoomed) / 2) - (float(w) / 2)), + math.floor((float(h_zoomed) / 2) - (float(h) / 2)), + math.floor((float(w_zoomed) / 2) + (float(w) / 2)), + math.floor((float(h_zoomed) / 2) + (float(h) / 2)))) + + +def erase(img, v, **__): + #assert 0.1<= v <= 1 + w, h = img.size + w_occlusion = int(w * v) + h_occlusion = int(h * v) + if len(img.getbands()) == 1: + rectangle = PIL.Image.fromarray(np.uint8(np.random.rand(w_occlusion, h_occlusion) * 255)) + else: + rectangle = PIL.Image.fromarray(np.uint8(np.random.rand(w_occlusion, h_occlusion, len(img.getbands())) * 255)) + + random_position_x = random.randint(0, w - w_occlusion) + random_position_y = random.randint(0, h - h_occlusion) + img.paste(rectangle, (random_position_x, random_position_y)) + return img + + +def skew(img, v, **__): + #assert -1 <= v <= 1 + w, h = img.size + x1 = 0 + x2 = h + y1 = 0 + y2 = w + original_plane = [(y1, x1), (y2, x1), (y2, x2), (y1, x2)] + max_skew_amount = max(w, h) + max_skew_amount = int(math.ceil(max_skew_amount * v)) + skew_amount = max_skew_amount + new_plane = [(y1 - skew_amount, x1), # Top Left + (y2, x1 - skew_amount), # Top Right + (y2 + skew_amount, x2), # Bottom Right + (y1, x2 + skew_amount)] + matrix = [] + for p1, p2 in zip(new_plane, original_plane): + matrix.append([p1[0], p1[1], 1, 0, 0, 0, -p2[0] * p1[0], -p2[0] * p1[1]]) + matrix.append([0, 0, 0, p1[0], p1[1], 1, -p2[1] * p1[0], -p2[1] * p1[1]]) + + A = np.matrix(matrix, dtype=np.float32) + B = np.array(original_plane).reshape(8) + perspective_skew_coefficients_matrix = np.dot(np.linalg.pinv(A), B) + perspective_skew_coefficients_matrix = np.array(perspective_skew_coefficients_matrix).reshape(8) + + return img.transform(img.size, PIL.Image.PERSPECTIVE, perspective_skew_coefficients_matrix, + resample=PIL.Image.BICUBIC) + + +def blur(img, v, **__): + #assert -3 <= v <= 3 + return img.filter(PIL.ImageFilter.GaussianBlur(v)) + + +def rand_augment_ops(magnitude=10, hparams=None, transforms=None): + hparams = hparams or 
_HPARAMS_DEFAULT + transforms = transforms or _RAND_TRANSFORMS + return [AutoAugmentOp(name, prob=0.5, magnitude=magnitude, hparams=hparams) for name in transforms] + + +def _select_rand_weights(weight_idx=0, transforms=None): + transforms = transforms or _RAND_TRANSFORMS + assert weight_idx == 0 # only one set of weights currently + rand_weights = _RAND_CHOICE_WEIGHTS_0 + probs = [rand_weights[k] for k in transforms] + probs /= np.sum(probs) + return probs + + +def rand_augment_transform(config_str, hparams): + magnitude = _MAX_LEVEL # default to _MAX_LEVEL for magnitude (currently 10) + num_layers = 2 # default to 2 ops per image + weight_idx = None # default to no probability weights for op choice + config = config_str.split('-') + assert config[0] == 'rand' + config = config[1:] + for c in config: + cs = re.split(r'(\d.*)', c) + if len(cs) < 2: + continue + key, val = cs[:2] + if key == 'mstd': + # noise param injected via hparams for now + hparams.setdefault('magnitude_std', float(val)) + elif key == 'm': + magnitude = int(val) + elif key == 'n': + num_layers = int(val) + elif key == 'w': + weight_idx = int(val) + else: + assert False, 'Unknown RandAugment config section' + ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams) + choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx) + + final_result = RandAugment(ra_ops, num_layers, choice_weights=choice_weights) + return final_result + + +LEVEL_TO_ARG = { + 'Distort': _distort_level_to_arg, + 'Zoom': _zoom_level_to_arg, + 'Blur': _blur_level_to_arg, + 'Skew': _skew_level_to_arg, + 'AutoContrast': None, + 'Equalize': None, + 'Invert': None, + 'Rotate': _rotate_level_to_arg, + 'PosterizeOriginal': _posterize_original_level_to_arg, + 'PosterizeResearch': _posterize_research_level_to_arg, + 'PosterizeTpu': _posterize_tpu_level_to_arg, + 'Solarize': _solarize_level_to_arg, + 'SolarizeAdd': _solarize_add_level_to_arg, + 'Color': _enhance_level_to_arg, + 'Contrast': _enhance_level_to_arg, + 'Brightness': _enhance_level_to_arg, + 'Sharpness': _enhance_level_to_arg, + 'ShearX': _shear_level_to_arg, + 'ShearY': _shear_level_to_arg, + 'TranslateX': _translate_abs_level_to_arg, + 'TranslateY': _translate_abs_level_to_arg, + 'TranslateXRel': _translate_rel_level_to_arg, + 'TranslateYRel': _translate_rel_level_to_arg, +} + +NAME_TO_OP = { + 'Distort': distort, + 'Zoom': zoom, + 'Blur': blur, + 'Skew': skew, + 'AutoContrast': auto_contrast, + 'Equalize': equalize, + 'Invert': invert, + 'Rotate': rotate, + 'PosterizeOriginal': posterize, + 'PosterizeResearch': posterize, + 'PosterizeTpu': posterize, + 'Solarize': solarize, + 'SolarizeAdd': solarize_add, + 'Color': color, + 'Contrast': contrast, + 'Brightness': brightness, + 'Sharpness': sharpness, + 'ShearX': shear_x, + 'ShearY': shear_y, + 'TranslateX': translate_x_abs, + 'TranslateY': translate_y_abs, + 'TranslateXRel': translate_x_rel, + 'TranslateYRel': translate_y_rel, +} + + +class AutoAugmentOp: + def __init__(self, name, prob=0.5, magnitude=10, hparams=None): + hparams = hparams or _HPARAMS_DEFAULT + self.aug_fn = NAME_TO_OP[name] + self.level_fn = LEVEL_TO_ARG[name] + self.prob = prob + self.magnitude = magnitude + self.hparams = hparams.copy() + self.kwargs = dict( + fillcolor=hparams['img_mean'] if 'img_mean' in hparams else _FILL, + resample=hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION, + ) + self.magnitude_std = self.hparams.get('magnitude_std', 0) + + def __call__(self, img): + if random.random() > self.prob: + 
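+ # skip this op entirely with probability (1 - self.prob)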
return img + magnitude = self.magnitude + if self.magnitude_std and self.magnitude_std > 0: + magnitude = random.gauss(magnitude, self.magnitude_std) + level_args = self.level_fn(magnitude, self.hparams) if self.level_fn is not None else tuple() + return self.aug_fn(img, *level_args, **self.kwargs) + + +class RandAugment: + def __init__(self, ops, num_layers=2, choice_weights=None): + self.ops = ops + self.num_layers = num_layers + self.choice_weights = choice_weights + + def __call__(self, img): + ops = np.random.choice( + self.ops, self.num_layers, replace=self.choice_weights is None, p=self.choice_weights) + for op in ops: + img = op(img) + return img + + +'''random erasing''' +def _get_pixels(per_pixel, rand_color, patch_size, dtype=torch.float32, device='cuda'): + # NOTE I've seen CUDA illegal memory access errors being caused by the normal_() + # paths, flip the order so normal is run on CPU if this becomes a problem + # Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508 + if per_pixel: + return torch.empty(patch_size, dtype=dtype, device=device).normal_() + elif rand_color: + return torch.empty((patch_size[0], 1, 1), dtype=dtype, device=device).normal_() + else: + return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device) + + +class RandomResizedCropAndInterpolation: + """Crop the given PIL Image to random size and aspect ratio with random interpolation. + + A crop of random size (default: of 0.08 to 1.0) of the original size and a random + aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop + is finally resized to given size. + This is popularly used to train the Inception networks. + + Args: + size: expected output size of each edge + scale: range of size of the origin size cropped + ratio: range of aspect ratio of the origin aspect ratio cropped + interpolation: Default: PIL.Image.BILINEAR + """ + + def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation='bilinear'): + if isinstance(size, tuple): + self.size = size + else: + self.size = (size, size) + if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): + warnings.warn("range should be of kind (min, max)") + + self.scale = scale + self.ratio = ratio + + if interpolation == 'bilinear': + self.interpolation = Image.BILINEAR + elif interpolation == 'bicubic': + self.interpolation = Image.BICUBIC + + @staticmethod + def get_params(img, scale, ratio): + """Get parameters for ``crop`` for a random sized crop. + + Args: + img (PIL Image): Image to be cropped. + scale (tuple): range of size of the origin size cropped + ratio (tuple): range of aspect ratio of the origin aspect ratio cropped + + Returns: + tuple: params (i, j, h, w) to be passed to ``crop`` for a random + sized crop. 
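+ Falls back to a centered crop (clamped to a valid aspect ratio) when 10
+ random attempts fail to produce a crop that fits inside the image.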
+ """ + area = img.size[0] * img.size[1] + + for attempt in range(10): + target_area = random.uniform(*scale) * area + log_ratio = (math.log(ratio[0]), math.log(ratio[1])) + aspect_ratio = math.exp(random.uniform(*log_ratio)) + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + if w <= img.size[0] and h <= img.size[1]: + i = random.randint(0, img.size[1] - h) + j = random.randint(0, img.size[0] - w) + return i, j, h, w + + # Fallback to central crop + in_ratio = img.size[0] / img.size[1] + if in_ratio < min(ratio): + w = img.size[0] + h = int(round(w / min(ratio))) + elif in_ratio > max(ratio): + h = img.size[1] + w = int(round(h * max(ratio))) + else: # whole image + w = img.size[0] + h = img.size[1] + i = (img.size[1] - h) // 2 + j = (img.size[0] - w) // 2 + return i, j, h, w + + def __call__(self, img): + """ + Args: + img (PIL Image): Image to be cropped and resized. + + Returns: + PIL Image: Randomly cropped and resized image. + """ + i, j, h, w = self.get_params(img, self.scale, self.ratio) + return F.resized_crop(img, i, j, h, w, self.size, interpolation=self.interpolation) + + +class RandomErasing: + """ Randomly selects a rectangle region in an image and erases its pixels. + 'Random Erasing Data Augmentation' by Zhong et al. + See https://arxiv.org/pdf/1708.04896.pdf + + This variant of RandomErasing is intended to be applied to either a batch + or single image tensor after it has been normalized by dataset mean and std. + Args: + probability: Probability that the Random Erasing operation will be performed. + min_area: Minimum percentage of erased area wrt input image area. + max_area: Maximum percentage of erased area wrt input image area. + min_aspect: Minimum aspect ratio of erased area. + mode: pixel color mode, one of 'const', 'rand', or 'pixel' + 'const' - erase block is constant color of 0 for all channels + 'rand' - erase block is same per-channel random (normal) color + 'pixel' - erase block is per-pixel random (normal) color + max_count: maximum number of erasing blocks per image, area per box is scaled by count. + per-image count is randomly chosen between 1 and this value. 
+ """ + + def __init__( + self, + probability=0.5, min_area=0.02, max_area=1/3, min_aspect=0.3, max_aspect=None, + mode='const', min_count=1, max_count=None, num_splits=0, device='cuda'): + self.probability = probability + self.min_area = min_area + self.max_area = max_area + max_aspect = max_aspect or 1 / min_aspect + self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) + self.min_count = min_count + self.max_count = max_count or min_count + self.num_splits = num_splits + mode = mode.lower() + self.rand_color = False + self.per_pixel = False + if mode == 'rand': + self.rand_color = True # per block random normal + elif mode == 'pixel': + self.per_pixel = True # per pixel random normal + else: + assert not mode or mode == 'const' + self.device = device + + def _erase(self, img, chan, img_h, img_w, dtype): + if random.random() > self.probability: + return + area = img_h * img_w + count = self.min_count if self.min_count == self.max_count else \ + random.randint(self.min_count, self.max_count) + for _ in range(count): + for attempt in range(10): + target_area = random.uniform(self.min_area, self.max_area) * area / count + aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) + h = int(round(math.sqrt(target_area * aspect_ratio))) + w = int(round(math.sqrt(target_area / aspect_ratio))) + if w < img_w and h < img_h: + top = random.randint(0, img_h - h) + left = random.randint(0, img_w - w) + img[:, top:top + h, left:left + w] = _get_pixels( + self.per_pixel, self.rand_color, (chan, h, w), + dtype=dtype, device=self.device) + break + + def __call__(self, input): + if len(input.size()) == 3: + self._erase(input, *input.size(), input.dtype) + else: + batch_size, chan, img_h, img_w = input.size() + # skip first slice of batch if num_splits is set (for clean portion of samples) + batch_start = batch_size // self.num_splits if self.num_splits > 1 else 0 + for i in range(batch_start, batch_size): + self._erase(input[i], chan, img_h, img_w, input.dtype) + return input + + +class Cutout(object): + """Randomly mask out one or more patches from an image. + Args: + n_holes (int): Number of patches to cut out of each image. + length (int): The length (in pixels) of each square patch. + """ + def __init__(self, n_holes=1, length=16): + self.n_holes = n_holes + self.length = length + + def __call__(self, img): + """ + Args: + img (Tensor): Tensor image of size (N, C, H, W). + Returns: + Tensor: Image with n_holes of dimension length x length cut out of it. + """ + if img.ndim == 4: + n = img.size(0) + h = img.size(2) + w = img.size(3) + elif img.ndim == 3: + n = 1 + h = img.size(1) + w = img.size(2) + + mask = np.ones((n, h, w), np.float32) + + for i in range(n): + for _ in range(self.n_holes): + y = np.random.randint(h) + x = np.random.randint(w) + + y1 = np.clip(y - self.length // 2, 0, h) + y2 = np.clip(y + self.length // 2, 0, h) + x1 = np.clip(x - self.length // 2, 0, w) + x2 = np.clip(x + self.length // 2, 0, w) + + mask[i, y1: y2, x1: x2] = 0. 
+ + mask = torch.from_numpy(mask).cuda() + if img.ndim == 4: + mask = mask.unsqueeze(1) + #mask = mask.expand_as(img) + img = img * mask + + return img + + +class Normalize: + def __init__(self, mean, std, inplace=True, use_cuda=True): + self.mean = torch.tensor(mean).view(1, 3, 1, 1) + self.std = torch.tensor(std).view(1, 3, 1, 1) + if use_cuda: + self.mean = self.mean.cuda() + self.std = self.std.cuda() + self.inplace = inplace + + def __call__(self, input): + if self.inplace: + input = input.sub_(self.mean).div_(self.std) + else: + input = (input - self.mean) / self.std + return input + + +class ToNumpy: + def __call__(self, pil_img): + np_img = np.array(pil_img, dtype=np.uint8) + if np_img.ndim < 3: + np_img = np.expand_dims(np_img, axis=-1) + np_img = np.rollaxis(np_img, 2) # HWC to CHW + return np_img + diff --git a/classification/lib/dataset/builder.py b/classification/lib/dataset/builder.py new file mode 100644 index 0000000..e653664 --- /dev/null +++ b/classification/lib/dataset/builder.py @@ -0,0 +1,102 @@ +import os +import re +import torch +import torchvision.datasets as datasets + +from .dataset import ImageNetDataset +from .dataloader import fast_collate, DataPrefetcher +from .mixup import Mixup +from . import transform + + +def _check_torch_version(target='1.7.0'): + if torch.__version__ == 'parrots': + return False + version = re.match('([\d.])*', torch.__version__).group() + target = re.match('([\d.])*', target).group() + major, minor, patch = [int(x) for x in version.split('.')[:3]] + t_major, t_minor, t_patch = [int(x) for x in target.split('.')[:3]] + if major > t_major: + return True + elif major == t_major: + if minor > t_minor: + return True + elif minor == t_minor: + if patch >= t_patch: + return True + return False + + +# for pytorch>=1.7.0, we add persistent_workers=True in +# dataloader params +if _check_torch_version('1.7.0'): + _LOADER_PARAMS = dict(persistent_workers=True) +else: + _LOADER_PARAMS = dict() + + +def build_dataloader(args): + # pre-configuration for the dataset + if args.dataset == 'imagenet': + args.data_path = 'data/imagenet' if args.data_path == '' else args.data_path + args.num_classes = 1000 + args.input_shape = (3, 224, 224) + elif args.dataset == 'cifar10': + args.data_path = 'data/cifar' if args.data_path == '' else args.data_path + args.num_classes = 10 + args.input_shape = (3, 32, 32) + elif args.dataset == 'cifar100': + args.data_path = 'data/cifar' if args.data_path == '' else args.data_path + args.num_classes = 100 + args.input_shape = (3, 32, 32) + + # train + if args.dataset == 'imagenet': + train_transforms_l, train_transforms_r = transform.build_train_transforms( + args.aa, args.color_jitter, args.reprob, args.remode, args.interpolation, args.image_mean, args.image_std) + train_dataset = ImageNetDataset( + os.path.join(args.data_path, 'train'), os.path.join(args.data_path, 'meta/train.txt'), transform=train_transforms_l) + elif args.dataset == 'cifar10': + train_transforms_l, train_transforms_r = transform.build_train_transforms_cifar10( + args.cutout_length, args.image_mean, args.image_std) + train_dataset = datasets.CIFAR10( + root=args.data_path, train=True, download=True, transform=train_transforms_l) + elif args.dataset == 'cifar100': + train_transforms_l, train_transforms_r = transform.build_train_transforms_cifar10( + args.cutout_length, args.image_mean, args.image_std) + train_dataset = datasets.CIFAR100( + root=args.data_path, train=True, download=True, transform=train_transforms_l) + + # mixup + mixup_active = 
args.mixup > 0. or args.cutmix > 0. or args.cutmix_minmax is not None + if mixup_active: + mixup_transform = Mixup(mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, prob=args.mixup_prob, + switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, label_smoothing=args.smoothing, num_classes=args.num_classes) + else: + mixup_transform = None + + train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True) + train_loader = torch.utils.data.DataLoader( + train_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, + pin_memory=False, sampler=train_sampler, collate_fn=fast_collate, drop_last=True, **_LOADER_PARAMS) + train_loader = DataPrefetcher(train_loader, train_transforms_r, mixup_transform) + + # val + if args.dataset == 'imagenet': + val_transforms_l, val_transforms_r = transform.build_val_transforms(args.interpolation, args.image_mean, args.image_std) + val_dataset = ImageNetDataset(os.path.join(args.data_path, 'val'), os.path.join(args.data_path, 'meta/val.txt'), transform=val_transforms_l) + elif args.dataset == 'cifar10': + val_transforms_l, val_transforms_r = transform.build_val_transforms_cifar10(args.image_mean, args.image_std) + val_dataset = datasets.CIFAR10(root=args.data_path, train=False, download=True, transform=val_transforms_l) + elif args.dataset == 'cifar100': + val_transforms_l, val_transforms_r = transform.build_val_transforms_cifar10(args.image_mean, args.image_std) + val_dataset = datasets.CIFAR100(root=args.data_path, train=False, download=True, transform=val_transforms_l) + + val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False) + val_loader = torch.utils.data.DataLoader( + val_dataset, batch_size=int(args.batch_size * args.val_batch_size_multiplier), + shuffle=False, num_workers=args.workers, pin_memory=False, + sampler=val_sampler, collate_fn=fast_collate, **_LOADER_PARAMS) + val_loader = DataPrefetcher(val_loader, val_transforms_r) + + return train_dataset, val_dataset, train_loader, val_loader diff --git a/classification/lib/dataset/dataloader.py b/classification/lib/dataset/dataloader.py new file mode 100644 index 0000000..3ce8ce1 --- /dev/null +++ b/classification/lib/dataset/dataloader.py @@ -0,0 +1,74 @@ +import torch +import numpy as np + + +def fast_collate(batch, memory_format=torch.contiguous_format): + imgs = [img[0] for img in batch] + targets = torch.tensor([target[1] for target in batch], dtype=torch.int64) + w = imgs[0].shape[2] + h = imgs[0].shape[1] + tensor = torch.zeros( (len(imgs), 3, h, w), dtype=torch.uint8).contiguous(memory_format=memory_format) + for i, nump_array in enumerate(imgs): + if(nump_array.ndim < 3): + nump_array = np.expand_dims(nump_array, axis=-1) + #nump_array = np.rollaxis(nump_array, 2) + tensor[i] += torch.from_numpy(nump_array) + return tensor, targets + + +class DataPrefetcher(): + def __init__(self, loader, transforms, mixup_transform=None): + self.loader = loader + self.loader_iter = iter(loader) + self.transforms = transforms + self.mixup_transform = mixup_transform + self.stream = torch.cuda.Stream() + + def preload(self): + try: + self.next_input, self.next_target = next(self.loader_iter) + except StopIteration: + self.next_input = None + self.next_target = None + return + # if record_stream() doesn't work, another option is to make sure device inputs are created + # on the main stream. 
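+ # (the commented-out lines below sketch that alternative; in this version the
+ # copy and preprocessing run on self.stream so the main stream is not blocked)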
+ # self.next_input_gpu = torch.empty_like(self.next_input, device='cuda') + # self.next_target_gpu = torch.empty_like(self.next_target, device='cuda') + # Need to make sure the memory allocated for next_* is not still in use by the main stream + # at the time we start copying to next_*: + # self.stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(self.stream): + self.next_input = self.next_input.cuda(non_blocking=True) + self.next_target = self.next_target.cuda(non_blocking=True) + self.next_input = self.transforms(self.next_input.float()) + if self.mixup_transform is not None: + self.next_input, self.next_target = \ + self.mixup_transform(self.next_input, self.next_target) + + def next(self): + torch.cuda.current_stream().wait_stream(self.stream) + input = self.next_input + target = self.next_target + if input is not None: + input.record_stream(torch.cuda.current_stream()) + if target is not None: + target.record_stream(torch.cuda.current_stream()) + self.preload() + return input, target + + def __iter__(self): + self.loader_iter = iter(self.loader) # re-generate an iter for each epoch + self.preload() + return self + + def __next__(self): + input, target = self.next() + if input is None: + raise StopIteration + return input, target + + def __len__(self): + return len(self.loader) + + diff --git a/classification/lib/dataset/dataset.py b/classification/lib/dataset/dataset.py new file mode 100644 index 0000000..5b6f4d9 --- /dev/null +++ b/classification/lib/dataset/dataset.py @@ -0,0 +1,81 @@ +import io +import torch +import warnings +from PIL import Image +from torch.utils.data import Dataset + +try: + import mc + from .file_io import PetrelMCBackend + _has_mc = True +except ModuleNotFoundError: + warnings.warn('mc module not found, using original ' + 'Image.open to read images') + _has_mc = False + + +class ImageNetDataset(Dataset): + r""" + Dataset using memcached to read data. + + Arguments + * root (string): Root directory of the Dataset. + * meta_file (string): The meta file of the Dataset. Each line has a image path + and a label. Eg: ``nm091234/image_56.jpg 18``. + * transform (callable, optional): A function that transforms the given PIL image + and returns a transformed image. 
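+ Example (mirroring the call in lib/dataset/builder.py; the transform argument
+ is the first, CPU-side pipeline returned by build_train_transforms):
+ ImageNetDataset('data/imagenet/train', 'data/imagenet/meta/train.txt',
+ transform=train_transforms)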
+ """ + def __init__(self, root, meta_file, transform=None): + self.root = root + if _has_mc: + with open('./data/mc_prefix.txt', 'r') as f: + prefix = f.readline().strip() + self.root = prefix + '/' + \ + ('train' if 'train' in self.root else 'val') + self.transform = transform + with open(meta_file) as f: + meta_list = f.readlines() + self.num = len(meta_list) + self.metas = [] + for line in meta_list: + path, cls = line.strip().split() + self.metas.append((path, int(cls))) + self._mc_initialized = False + + def __len__(self): + return self.num + + def _init_memcached(self): + if not self._mc_initialized: + ''' + server_list_config_file = "/mnt/lustre/share/memcached_client/server_list.conf" + client_config_file = "/mnt/lustre/share/memcached_client/client.conf" + self.mclient = mc.MemcachedClient.GetInstance( + server_list_config_file, client_config_file) + self._mc_initialized = True + ''' + self.backend = PetrelMCBackend() + + def __getitem__(self, index): + filename = self.root + '/' + self.metas[index][0] + cls = self.metas[index][1] + + if _has_mc: + # memcached + self._init_memcached() + ''' + value = mc.pyvector() + self.mclient.Get(filename, value) + value_buf = mc.ConvertBuffer(value) + buff = io.BytesIO(value_buf) + ''' + buff = self.backend.get(filename) + with Image.open(buff) as img: + img = img.convert('RGB') + else: + img = Image.open(filename).convert('RGB') + + # transform + if self.transform is not None: + img = self.transform(img) + return img, cls diff --git a/classification/lib/dataset/file_io.py b/classification/lib/dataset/file_io.py new file mode 100644 index 0000000..b381bd8 --- /dev/null +++ b/classification/lib/dataset/file_io.py @@ -0,0 +1,65 @@ +import os +import io +import warnings + + +with warnings.catch_warnings(): + # ignore warnings when importing mc + warnings.simplefilter("ignore") + try: + import mc + except ModuleNotFoundError: + pass + + +class PetrelMCBackend(): + """Petrel storage backend with multiple clusters (for internal use). + + Args: + path_mapping (dict|None): path mapping dict from local path to Petrel + path. When `path_mapping={'src': 'dst'}`, `src` in `filepath` will + be replaced by `dst`. Default: None. + enable_mc (bool): whether to enable memcached support. Default: True. 
+ """ + def __init__(self, path_mapping=None, enable_mc=True): + self.enable_mc = enable_mc + assert isinstance(path_mapping, dict) or path_mapping is None + self.path_mapping = path_mapping + + self._client = None + self._mc_client = None + + def _init_clients(self): + try: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + from petrel_client import client + except ImportError: + raise ImportError('Please install petrel_client to enable ' + 'PetrelBackend.') + self._client = client.Client(enable_mc=self.enable_mc, + boto=True, + enable_multi_cluster=True, + conf_path='{}/.s3cfg'.format( + os.environ['HOME'])) + server_list_cfg = "/mnt/lustre/share/memcached_client/server_list.conf" + client_cfg = "/mnt/lustre/share/memcached_client/client.conf" + self._mc_client = mc.MemcachedClient.GetInstance( + server_list_cfg, client_cfg) + self._mc_buffer = mc.pyvector() + + def get(self, filepath): + if self._client is None: + self._init_clients() + filepath = str(filepath) + if self.path_mapping is not None: + for k, v in self.path_mapping.items(): + filepath = filepath.replace(k, v) + if filepath.startswith('cluster'): + value = self._client.Get(filepath) + else: + self._mc_client.Get(filepath, self._mc_buffer) + value = mc.ConvertBuffer(self._mc_buffer) + value_buf = memoryview(value) + buff = io.BytesIO(value_buf) + return buff diff --git a/classification/lib/dataset/mixup.py b/classification/lib/dataset/mixup.py new file mode 100644 index 0000000..fcf6b48 --- /dev/null +++ b/classification/lib/dataset/mixup.py @@ -0,0 +1,317 @@ +# Code from https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/mixup.py +""" Mixup and Cutmix + +Papers: +mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412) + +CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features (https://arxiv.org/abs/1905.04899) + +Code Reference: +CutMix: https://github.com/clovaai/CutMix-PyTorch + +Hacked together by / Copyright 2019, Ross Wightman +""" +import numpy as np +import torch + + +def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'): + x = x.long().view(-1, 1) + return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value) + + +def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'): + off_value = smoothing / num_classes + on_value = 1. - smoothing + off_value + y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value, device=device) + y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value, device=device) + return y1 * lam + y2 * (1. - lam) + + +def rand_bbox(img_shape, lam, margin=0., count=None): + """ Standard CutMix bounding-box + Generates a random square bbox based on lambda value. This impl includes + support for enforcing a border margin as percent of bbox dimensions. 
+ + Args: + img_shape (tuple): Image shape as tuple + lam (float): Cutmix lambda value + margin (float): Percentage of bbox dimension to enforce as margin (reduce amount of box outside image) + count (int): Number of bbox to generate + """ + ratio = np.sqrt(1 - lam) + img_h, img_w = img_shape[-2:] + cut_h, cut_w = int(img_h * ratio), int(img_w * ratio) + margin_y, margin_x = int(margin * cut_h), int(margin * cut_w) + cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count) + cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count) + yl = np.clip(cy - cut_h // 2, 0, img_h) + yh = np.clip(cy + cut_h // 2, 0, img_h) + xl = np.clip(cx - cut_w // 2, 0, img_w) + xh = np.clip(cx + cut_w // 2, 0, img_w) + return yl, yh, xl, xh + + +def rand_bbox_minmax(img_shape, minmax, count=None): + """ Min-Max CutMix bounding-box + Inspired by Darknet cutmix impl, generates a random rectangular bbox + based on min/max percent values applied to each dimension of the input image. + + Typical defaults for minmax are usually in the .2-.3 for min and .8-.9 range for max. + + Args: + img_shape (tuple): Image shape as tuple + minmax (tuple or list): Min and max bbox ratios (as percent of image size) + count (int): Number of bbox to generate + """ + assert len(minmax) == 2 + img_h, img_w = img_shape[-2:] + cut_h = np.random.randint(int(img_h * minmax[0]), int(img_h * minmax[1]), size=count) + cut_w = np.random.randint(int(img_w * minmax[0]), int(img_w * minmax[1]), size=count) + yl = np.random.randint(0, img_h - cut_h, size=count) + xl = np.random.randint(0, img_w - cut_w, size=count) + yu = yl + cut_h + xu = xl + cut_w + return yl, yu, xl, xu + + +def cutmix_bbox_and_lam(img_shape, lam, ratio_minmax=None, correct_lam=True, count=None): + """ Generate bbox and apply lambda correction. + """ + if ratio_minmax is not None: + yl, yu, xl, xu = rand_bbox_minmax(img_shape, ratio_minmax, count=count) + else: + yl, yu, xl, xu = rand_bbox(img_shape, lam, count=count) + if correct_lam or ratio_minmax is not None: + bbox_area = (yu - yl) * (xu - xl) + lam = 1. - bbox_area / float(img_shape[-2] * img_shape[-1]) + return (yl, yu, xl, xu), lam + + +class Mixup: + """ Mixup/Cutmix that applies different params to each element or whole batch + + Args: + mixup_alpha (float): mixup alpha value, mixup is active if > 0. + cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0. + cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None. 
+ prob (float): probability of applying mixup or cutmix per batch or element + switch_prob (float): probability of switching to cutmix instead of mixup when both are active + mode (str): how to apply mixup/cutmix params (per 'batch', 'pair' (pair of elements), 'elem' (element) + correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders + label_smoothing (float): apply label smoothing to the mixed target tensor + num_classes (int): number of classes for target + """ + def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5, + mode='batch', correct_lam=True, label_smoothing=0.1, num_classes=1000): + self.mixup_alpha = mixup_alpha + self.cutmix_alpha = cutmix_alpha + self.cutmix_minmax = cutmix_minmax + if self.cutmix_minmax is not None: + assert len(self.cutmix_minmax) == 2 + # force cutmix alpha == 1.0 when minmax active to keep logic simple & safe + self.cutmix_alpha = 1.0 + self.mix_prob = prob + self.switch_prob = switch_prob + self.label_smoothing = label_smoothing + self.num_classes = num_classes + self.mode = mode + self.correct_lam = correct_lam # correct lambda based on clipped area for cutmix + self.mixup_enabled = True # set to false to disable mixing (intended tp be set by train loop) + + def _params_per_elem(self, batch_size): + lam = np.ones(batch_size, dtype=np.float32) + use_cutmix = np.zeros(batch_size, dtype=np.bool) + if self.mixup_enabled: + if self.mixup_alpha > 0. and self.cutmix_alpha > 0.: + use_cutmix = np.random.rand(batch_size) < self.switch_prob + lam_mix = np.where( + use_cutmix, + np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size), + np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size)) + elif self.mixup_alpha > 0.: + lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size) + elif self.cutmix_alpha > 0.: + use_cutmix = np.ones(batch_size, dtype=np.bool) + lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size) + else: + assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true." + lam = np.where(np.random.rand(batch_size) < self.mix_prob, lam_mix.astype(np.float32), lam) + return lam, use_cutmix + + def _params_per_batch(self): + lam = 1. + use_cutmix = False + if self.mixup_enabled and np.random.rand() < self.mix_prob: + if self.mixup_alpha > 0. and self.cutmix_alpha > 0.: + use_cutmix = np.random.rand() < self.switch_prob + lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) if use_cutmix else \ + np.random.beta(self.mixup_alpha, self.mixup_alpha) + elif self.mixup_alpha > 0.: + lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha) + elif self.cutmix_alpha > 0.: + use_cutmix = True + lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) + else: + assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true." 
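+ # lam is the weight of the original sample; the flipped batch contributes
+ # (1 - lam) (cf. mixup_target above)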
+ lam = float(lam_mix) + return lam, use_cutmix + + def _mix_elem(self, x): + batch_size = len(x) + lam_batch, use_cutmix = self._params_per_elem(batch_size) + x_orig = x.clone() # need to keep an unmodified original for mixing source + for i in range(batch_size): + j = batch_size - i - 1 + lam = lam_batch[i] + if lam != 1.: + if use_cutmix[i]: + (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( + x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) + x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh] + lam_batch[i] = lam + else: + x[i] = x[i] * lam + x_orig[j] * (1 - lam) + return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1) + + def _mix_pair(self, x): + batch_size = len(x) + lam_batch, use_cutmix = self._params_per_elem(batch_size // 2) + x_orig = x.clone() # need to keep an unmodified original for mixing source + for i in range(batch_size // 2): + j = batch_size - i - 1 + lam = lam_batch[i] + if lam != 1.: + if use_cutmix[i]: + (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( + x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) + x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh] + x[j][:, yl:yh, xl:xh] = x_orig[i][:, yl:yh, xl:xh] + lam_batch[i] = lam + else: + x[i] = x[i] * lam + x_orig[j] * (1 - lam) + x[j] = x[j] * lam + x_orig[i] * (1 - lam) + lam_batch = np.concatenate((lam_batch, lam_batch[::-1])) + return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1) + + def _mix_batch(self, x): + lam, use_cutmix = self._params_per_batch() + if lam == 1.: + return 1. + if use_cutmix: + (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( + x.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) + x[:, :, yl:yh, xl:xh] = x.flip(0)[:, :, yl:yh, xl:xh] + else: + x_flipped = x.flip(0).mul_(1. - lam) + x.mul_(lam).add_(x_flipped) + return lam + + def __call__(self, x, target): + assert len(x) % 2 == 0, 'Batch size should be even when using this' + if self.mode == 'elem': + lam = self._mix_elem(x) + elif self.mode == 'pair': + lam = self._mix_pair(x) + else: + lam = self._mix_batch(x) + target = mixup_target(target, self.num_classes, lam, self.label_smoothing, x.device) + return x, target + + +class FastCollateMixup(Mixup): + """ Fast Collate w/ Mixup/Cutmix that applies different params to each element or whole batch + + A Mixup impl that's performed while collating the batches. 
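+ Mixing is done on the raw uint8 numpy images from the sample list and written
+ into the pre-allocated uint8 batch tensor, i.e. on the CPU before normalization.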
+ """ + + def _mix_elem_collate(self, output, batch, half=False): + batch_size = len(batch) + num_elem = batch_size // 2 if half else batch_size + assert len(output) == num_elem + lam_batch, use_cutmix = self._params_per_elem(num_elem) + for i in range(num_elem): + j = batch_size - i - 1 + lam = lam_batch[i] + mixed = batch[i][0] + if lam != 1.: + if use_cutmix[i]: + if not half: + mixed = mixed.copy() + (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( + output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) + mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh] + lam_batch[i] = lam + else: + mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam) + np.rint(mixed, out=mixed) + output[i] += torch.from_numpy(mixed.astype(np.uint8)) + if half: + lam_batch = np.concatenate((lam_batch, np.ones(num_elem))) + return torch.tensor(lam_batch).unsqueeze(1) + + def _mix_pair_collate(self, output, batch): + batch_size = len(batch) + lam_batch, use_cutmix = self._params_per_elem(batch_size // 2) + for i in range(batch_size // 2): + j = batch_size - i - 1 + lam = lam_batch[i] + mixed_i = batch[i][0] + mixed_j = batch[j][0] + assert 0 <= lam <= 1.0 + if lam < 1.: + if use_cutmix[i]: + (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( + output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) + patch_i = mixed_i[:, yl:yh, xl:xh].copy() + mixed_i[:, yl:yh, xl:xh] = mixed_j[:, yl:yh, xl:xh] + mixed_j[:, yl:yh, xl:xh] = patch_i + lam_batch[i] = lam + else: + mixed_temp = mixed_i.astype(np.float32) * lam + mixed_j.astype(np.float32) * (1 - lam) + mixed_j = mixed_j.astype(np.float32) * lam + mixed_i.astype(np.float32) * (1 - lam) + mixed_i = mixed_temp + np.rint(mixed_j, out=mixed_j) + np.rint(mixed_i, out=mixed_i) + output[i] += torch.from_numpy(mixed_i.astype(np.uint8)) + output[j] += torch.from_numpy(mixed_j.astype(np.uint8)) + lam_batch = np.concatenate((lam_batch, lam_batch[::-1])) + return torch.tensor(lam_batch).unsqueeze(1) + + def _mix_batch_collate(self, output, batch): + batch_size = len(batch) + lam, use_cutmix = self._params_per_batch() + if use_cutmix: + (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( + output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) + for i in range(batch_size): + j = batch_size - i - 1 + mixed = batch[i][0] + if lam != 1.: + if use_cutmix: + mixed = mixed.copy() # don't want to modify the original while iterating + mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh] + else: + mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam) + np.rint(mixed, out=mixed) + output[i] += torch.from_numpy(mixed.astype(np.uint8)) + return lam + + def __call__(self, batch, _=None): + batch_size = len(batch) + assert batch_size % 2 == 0, 'Batch size should be even when using this' + half = 'half' in self.mode + if half: + batch_size //= 2 + output = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8) + if self.mode == 'elem' or self.mode == 'half': + lam = self._mix_elem_collate(output, batch, half=half) + elif self.mode == 'pair': + lam = self._mix_pair_collate(output, batch) + else: + lam = self._mix_batch_collate(output, batch) + target = torch.tensor([b[1] for b in batch], dtype=torch.int64) + target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu') + target = target[:batch_size] + return output, target + diff --git a/classification/lib/dataset/transform.py b/classification/lib/dataset/transform.py new file mode 100644 
index 0000000..bbbf8c8 --- /dev/null +++ b/classification/lib/dataset/transform.py @@ -0,0 +1,102 @@ +import torch +import torchvision.transforms as transforms +from PIL import Image +from . import augment_ops + + +IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) +IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) + +CIFAR_DEFAULT_MEAN = (0.49139968, 0.48215827, 0.44653124) +CIFAR_DEFAULT_STD = (0.24703233, 0.24348505, 0.26158768) + + +def build_train_transforms(aa_config_str="rand-m9-mstd0.5", color_jitter=None, + reprob=0., remode='pixel', interpolation='bilinear', mean=None, std=None): + mean = IMAGENET_DEFAULT_MEAN if mean is None else mean + std = IMAGENET_DEFAULT_STD if std is None else std + trans_l = [] + trans_r = [] + trans_l.extend([ + augment_ops.RandomResizedCropAndInterpolation(224, interpolation=interpolation), + transforms.RandomHorizontalFlip()]) + if aa_config_str is not None and aa_config_str != '': + if interpolation == 'bilinear': + aa_interpolation = Image.BILINEAR + elif interpolation == 'bicubic': + aa_interpolation = Image.BICUBIC + else: + raise RuntimeError(f'Interpolation mode {interpolation} not found.') + + aa_params = dict( + translate_const=int(224 * 0.45), + img_mean=tuple([round(x * 255) for x in IMAGENET_DEFAULT_MEAN]), + interpolation=aa_interpolation + ) + trans_l.append(augment_ops.rand_augment_transform(aa_config_str, aa_params)) + elif color_jitter != 0 and color_jitter is not None: + # enable color_jitter when not using AA + trans_l.append(transforms.ColorJitter(color_jitter, color_jitter, color_jitter)) + trans_l.append(augment_ops.ToNumpy()) + + trans_r.append(augment_ops.Normalize(mean=[x * 255 for x in mean], + std=[x * 255 for x in std])) + if reprob > 0: + trans_r.append(augment_ops.RandomErasing(reprob, mode=remode, max_count=1, num_splits=0, device='cuda')) + return transforms.Compose(trans_l), transforms.Compose(trans_r) + + +def build_val_transforms(interpolation='bilinear', mean=None, std=None): + mean = IMAGENET_DEFAULT_MEAN if mean is None else mean + std = IMAGENET_DEFAULT_STD if std is None else std + if interpolation == 'bilinear': + interpolation = Image.BILINEAR + elif interpolation == 'bicubic': + interpolation = Image.BICUBIC + else: + raise RuntimeError(f'Interpolation mode {interpolation} not found.') + + trans_l = transforms.Compose([ + transforms.Resize(256, interpolation=interpolation), + transforms.CenterCrop(224), + augment_ops.ToNumpy() + ]) + trans_r = transforms.Compose([ + augment_ops.Normalize(mean=[x * 255 for x in mean], + std=[x * 255 for x in std]) + ]) + return trans_l, trans_r + + +def build_train_transforms_cifar10(cutout_length=0., mean=None, std=None): + mean = CIFAR_DEFAULT_MEAN if mean is None else mean + std = CIFAR_DEFAULT_STD if std is None else std + trans_l = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + augment_ops.ToNumpy() + ]) + trans_r = [ + augment_ops.Normalize(mean=[x * 255 for x in mean], + std=[x * 255 for x in std]) + ] + if cutout_length != 0: + trans_r.append(augment_ops.Cutout(length=cutout_length)) + trans_r = transforms.Compose(trans_r) + return trans_l, trans_r + + +def build_val_transforms_cifar10(mean=None, std=None): + mean = CIFAR_DEFAULT_MEAN if mean is None else mean + std = CIFAR_DEFAULT_STD if std is None else std + trans_l = transforms.Compose([ + augment_ops.ToNumpy() + ]) + trans_r = transforms.Compose([ + augment_ops.Normalize(mean=[x * 255 for x in mean], + std=[x * 255 for x in std]) + ]) + return trans_l, trans_r + + + diff 
--git a/classification/lib/models/__init__.py b/classification/lib/models/__init__.py new file mode 100644 index 0000000..aadf17b --- /dev/null +++ b/classification/lib/models/__init__.py @@ -0,0 +1,12 @@ +from . import operations +from . import operations_resnet + +# models which use timm's registry +try: + import timm + _has_timm = True +except ModuleNotFoundError: + _has_timm = False + +if _has_timm: + from . import lightvit diff --git a/classification/lib/models/builder.py b/classification/lib/models/builder.py new file mode 100644 index 0000000..066baf0 --- /dev/null +++ b/classification/lib/models/builder.py @@ -0,0 +1,100 @@ +import yaml +import torch +import torchvision +import logging + +from .nas_model import gen_nas_model +from .darts_model import gen_darts_model +from .mobilenet_v1 import MobileNetV1 +from . import resnet + + +logger = logging.getLogger() + + +def build_model(args, model_name, pretrained=False, pretrained_ckpt=''): + if model_name.lower() == 'nas_model': + # model with architectures specific in yaml file + model = gen_nas_model(yaml.safe_load(open(args.model_config, 'r')), drop_rate=args.drop, + drop_path_rate=args.drop_path_rate, auxiliary_head=args.auxiliary) + + elif model_name.lower() == 'darts_model': + # DARTS evaluation models + model = gen_darts_model(yaml.safe_load(open(args.model_config, 'r')), args.dataset, drop_rate=args.drop, + drop_path_rate=args.drop_path_rate, auxiliary_head=args.auxiliary) + + elif model_name.lower() == 'nas_pruning_model': + # model with architectures specific in yaml file + # the model is searched by pruning algorithms + from edgenn.models import EdgeNNModel + model_config = yaml.safe_load(open(args.model_config, 'r')) + channel_settings = model_config.pop('channel_settings') + model = gen_nas_model(model_config, drop_rate=args.drop, drop_path_rate=args.drop_path_rate, auxiliary_head=args.auxiliary) + edgenn_model = EdgeNNModel(model, loss_fn=None, pruning=True, input_shape=args.input_shape) + logger.info(edgenn_model.graph) + edgenn_model.fold_dynamic_nn(channel_settings['choices'], channel_settings['bins'], channel_settings['min_bins']) + logger.info(model) + + elif model_name.lower().startswith('resnet'): + # resnet variants (the same as torchvision) + model = getattr(resnet, model_name.lower())(num_classes=args.num_classes) + + elif model_name.lower() == 'mobilenet_v1': + # mobilenet v1 + model = MobileNetV1(num_classes=args.num_classes) + + elif model_name.startswith('tv_'): + # build model using torchvision + import torchvision + model = getattr(torchvision.models, model_name[3:])(pretrained=pretrained) + + elif model_name.startswith('timm_'): + # build model using timm + # we import local_vim and local_vmamba here to register the models + if 'local_vim' in model_name: + from lib.models import local_vim + if 'local_vmamba' in model_name: + from lib.models import local_vmamba + import timm + model = timm.create_model(model_name[5:], pretrained=pretrained, drop_path_rate=args.drop_path_rate) + + elif model_name.startswith('cifar_'): + from .cifar import model_dict + model_name = model_name[6:] + model = model_dict[model_name](num_classes=args.num_classes) + else: + raise RuntimeError(f'Model {model_name} not found.') + + if pretrained and pretrained_ckpt != '': + logger.info(f'Loading pretrained checkpoint from {pretrained_ckpt}') + ckpt = torch.load(pretrained_ckpt, map_location='cpu') + if 'state_dict' in ckpt: + ckpt = ckpt['state_dict'] + elif 'model' in ckpt: + ckpt = ckpt['model'] + missing_keys, unexpected_keys = \ 
+ model.load_state_dict(ckpt, strict=False) + if len(missing_keys) != 0: + logger.info(f'Missing keys in source state dict: {missing_keys}') + if len(unexpected_keys) != 0: + logger.info(f'Unexpected keys in source state dict: {unexpected_keys}') + + return model + + +def build_edgenn_model(args, edgenn_cfgs=None): + import edgenn + if args.model.lower() in ['nas_model', 'nas_pruning_model']: + # gen model with yaml config first + model = gen_nas_model(yaml.load(open(args.model_config, 'r'), Loader=yaml.FullLoader), drop_rate=args.drop, drop_path_rate=args.drop_path_rate) + # wrap the model with EdgeNNModel + model = edgenn.models.EdgeNNModel(model, loss_fn, pruning=(args.model=='nas_pruning_model')) + + elif args.model == 'edgenn': + # build model from edgenn + model = edgenn.build_model(edgenn_cfgs.model) + + else: + raise RuntimeError(f'Model {args.model} not found.') + + return model diff --git a/classification/lib/models/cifar/ShuffleNetv1.py b/classification/lib/models/cifar/ShuffleNetv1.py new file mode 100644 index 0000000..8e5cd24 --- /dev/null +++ b/classification/lib/models/cifar/ShuffleNetv1.py @@ -0,0 +1,138 @@ +'''ShuffleNet in PyTorch. +See the paper "ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices" for more details. +''' +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ShuffleBlock(nn.Module): + def __init__(self, groups): + super(ShuffleBlock, self).__init__() + self.groups = groups + + def forward(self, x): + '''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,w] -> [N,C,H,W]''' + N,C,H,W = x.size() + g = self.groups + return x.view(N,g,C//g,H,W).permute(0,2,1,3,4).reshape(N,C,H,W) + + +class Bottleneck(nn.Module): + def __init__(self, in_planes, out_planes, stride, groups, is_last=False): + super(Bottleneck, self).__init__() + self.is_last = is_last + self.stride = stride + + mid_planes = int(out_planes/4) + g = 1 if in_planes == 24 else groups + self.conv1 = nn.Conv2d(in_planes, mid_planes, kernel_size=1, groups=g, bias=False) + self.bn1 = nn.BatchNorm2d(mid_planes) + self.shuffle1 = ShuffleBlock(groups=g) + self.conv2 = nn.Conv2d(mid_planes, mid_planes, kernel_size=3, stride=stride, padding=1, groups=mid_planes, bias=False) + self.bn2 = nn.BatchNorm2d(mid_planes) + self.conv3 = nn.Conv2d(mid_planes, out_planes, kernel_size=1, groups=groups, bias=False) + self.bn3 = nn.BatchNorm2d(out_planes) + + self.shortcut = nn.Sequential() + if stride == 2: + self.shortcut = nn.Sequential(nn.AvgPool2d(3, stride=2, padding=1)) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = self.shuffle1(out) + out = F.relu(self.bn2(self.conv2(out))) + out = self.bn3(self.conv3(out)) + res = self.shortcut(x) + preact = torch.cat([out, res], 1) if self.stride == 2 else out+res + out = F.relu(preact) + # out = F.relu(torch.cat([out, res], 1)) if self.stride == 2 else F.relu(out+res) + if self.is_last: + return out, preact + else: + return out + + +class ShuffleNet(nn.Module): + def __init__(self, cfg, num_classes=10): + super(ShuffleNet, self).__init__() + out_planes = cfg['out_planes'] + num_blocks = cfg['num_blocks'] + groups = cfg['groups'] + + self.conv1 = nn.Conv2d(3, 24, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(24) + self.in_planes = 24 + self.layer1 = self._make_layer(out_planes[0], num_blocks[0], groups) + self.layer2 = self._make_layer(out_planes[1], num_blocks[1], groups) + self.layer3 = self._make_layer(out_planes[2], num_blocks[2], groups) + self.linear = 
nn.Linear(out_planes[2], num_classes) + + def _make_layer(self, out_planes, num_blocks, groups): + layers = [] + for i in range(num_blocks): + stride = 2 if i == 0 else 1 + cat_planes = self.in_planes if i == 0 else 0 + layers.append(Bottleneck(self.in_planes, out_planes-cat_planes, + stride=stride, + groups=groups, + is_last=(i == num_blocks - 1))) + self.in_planes = out_planes + return nn.Sequential(*layers) + + def get_feat_modules(self): + feat_m = nn.ModuleList([]) + feat_m.append(self.conv1) + feat_m.append(self.bn1) + feat_m.append(self.layer1) + feat_m.append(self.layer2) + feat_m.append(self.layer3) + return feat_m + + def get_bn_before_relu(self): + raise NotImplementedError('ShuffleNet currently is not supported for "Overhaul" teacher') + + def forward(self, x, is_feat=False, preact=False): + out = F.relu(self.bn1(self.conv1(x))) + f0 = out + out, f1_pre = self.layer1(out) + f1 = out + out, f2_pre = self.layer2(out) + f2 = out + out, f3_pre = self.layer3(out) + f3 = out + out = F.avg_pool2d(out, 4) + out = out.view(out.size(0), -1) + f4 = out + out = self.linear(out) + + if is_feat: + if preact: + return [f0, f1_pre, f2_pre, f3_pre, f4], out + else: + return [f0, f1, f2, f3, f4], out + else: + return out + + +def ShuffleV1(**kwargs): + cfg = { + 'out_planes': [240, 480, 960], + 'num_blocks': [4, 8, 4], + 'groups': 3 + } + return ShuffleNet(cfg, **kwargs) + + +if __name__ == '__main__': + + x = torch.randn(2, 3, 32, 32) + net = ShuffleV1(num_classes=100) + import time + a = time.time() + feats, logit = net(x, is_feat=True, preact=True) + b = time.time() + print(b - a) + for f in feats: + print(f.shape, f.min().item()) + print(logit.shape) diff --git a/classification/lib/models/cifar/ShuffleNetv2.py b/classification/lib/models/cifar/ShuffleNetv2.py new file mode 100644 index 0000000..bd0821b --- /dev/null +++ b/classification/lib/models/cifar/ShuffleNetv2.py @@ -0,0 +1,210 @@ +'''ShuffleNetV2 in PyTorch. +See the paper "ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" for more details. 
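+This variant (used for 32x32 CIFAR inputs) replaces the original stem with a 1x1
+convolution and omits the initial max-pooling; the original 3x3/stride-2 stem is
+left commented out in the code below.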
+''' +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ShuffleBlock(nn.Module): + def __init__(self, groups=2): + super(ShuffleBlock, self).__init__() + self.groups = groups + + def forward(self, x): + '''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,w] -> [N,C,H,W]''' + N, C, H, W = x.size() + g = self.groups + return x.view(N, g, C//g, H, W).permute(0, 2, 1, 3, 4).reshape(N, C, H, W) + + +class SplitBlock(nn.Module): + def __init__(self, ratio): + super(SplitBlock, self).__init__() + self.ratio = ratio + + def forward(self, x): + c = int(x.size(1) * self.ratio) + return x[:, :c, :, :], x[:, c:, :, :] + + +class BasicBlock(nn.Module): + def __init__(self, in_channels, split_ratio=0.5, is_last=False): + super(BasicBlock, self).__init__() + self.is_last = is_last + self.split = SplitBlock(split_ratio) + in_channels = int(in_channels * split_ratio) + self.conv1 = nn.Conv2d(in_channels, in_channels, + kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(in_channels) + self.conv2 = nn.Conv2d(in_channels, in_channels, + kernel_size=3, stride=1, padding=1, groups=in_channels, bias=False) + self.bn2 = nn.BatchNorm2d(in_channels) + self.conv3 = nn.Conv2d(in_channels, in_channels, + kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(in_channels) + self.shuffle = ShuffleBlock() + + def forward(self, x): + x1, x2 = self.split(x) + out = F.relu(self.bn1(self.conv1(x2))) + out = self.bn2(self.conv2(out)) + preact = self.bn3(self.conv3(out)) + out = F.relu(preact) + # out = F.relu(self.bn3(self.conv3(out))) + preact = torch.cat([x1, preact], 1) + out = torch.cat([x1, out], 1) + out = self.shuffle(out) + if self.is_last: + return out, preact + else: + return out + + +class DownBlock(nn.Module): + def __init__(self, in_channels, out_channels): + super(DownBlock, self).__init__() + mid_channels = out_channels // 2 + # left + self.conv1 = nn.Conv2d(in_channels, in_channels, + kernel_size=3, stride=2, padding=1, groups=in_channels, bias=False) + self.bn1 = nn.BatchNorm2d(in_channels) + self.conv2 = nn.Conv2d(in_channels, mid_channels, + kernel_size=1, bias=False) + self.bn2 = nn.BatchNorm2d(mid_channels) + # right + self.conv3 = nn.Conv2d(in_channels, mid_channels, + kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(mid_channels) + self.conv4 = nn.Conv2d(mid_channels, mid_channels, + kernel_size=3, stride=2, padding=1, groups=mid_channels, bias=False) + self.bn4 = nn.BatchNorm2d(mid_channels) + self.conv5 = nn.Conv2d(mid_channels, mid_channels, + kernel_size=1, bias=False) + self.bn5 = nn.BatchNorm2d(mid_channels) + + self.shuffle = ShuffleBlock() + + def forward(self, x): + # left + out1 = self.bn1(self.conv1(x)) + out1 = F.relu(self.bn2(self.conv2(out1))) + # right + out2 = F.relu(self.bn3(self.conv3(x))) + out2 = self.bn4(self.conv4(out2)) + out2 = F.relu(self.bn5(self.conv5(out2))) + # concat + out = torch.cat([out1, out2], 1) + out = self.shuffle(out) + return out + + +class ShuffleNetV2(nn.Module): + def __init__(self, net_size, num_classes=10): + super(ShuffleNetV2, self).__init__() + out_channels = configs[net_size]['out_channels'] + num_blocks = configs[net_size]['num_blocks'] + + # self.conv1 = nn.Conv2d(3, 24, kernel_size=3, + # stride=1, padding=1, bias=False) + self.conv1 = nn.Conv2d(3, 24, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(24) + self.in_channels = 24 + self.layer1 = self._make_layer(out_channels[0], num_blocks[0]) + self.layer2 = self._make_layer(out_channels[1], num_blocks[1]) + self.layer3 = 
self._make_layer(out_channels[2], num_blocks[2]) + self.conv2 = nn.Conv2d(out_channels[2], out_channels[3], + kernel_size=1, stride=1, padding=0, bias=False) + self.bn2 = nn.BatchNorm2d(out_channels[3]) + self.linear = nn.Linear(out_channels[3], num_classes) + + def _make_layer(self, out_channels, num_blocks): + layers = [DownBlock(self.in_channels, out_channels)] + for i in range(num_blocks): + layers.append(BasicBlock(out_channels, is_last=(i == num_blocks - 1))) + self.in_channels = out_channels + return nn.Sequential(*layers) + + def get_feat_modules(self): + feat_m = nn.ModuleList([]) + feat_m.append(self.conv1) + feat_m.append(self.bn1) + feat_m.append(self.layer1) + feat_m.append(self.layer2) + feat_m.append(self.layer3) + return feat_m + + def get_bn_before_relu(self): + raise NotImplementedError('ShuffleNetV2 currently is not supported for "Overhaul" teacher') + + def forward(self, x, is_feat=False, preact=False): + out = F.relu(self.bn1(self.conv1(x))) + # out = F.max_pool2d(out, 3, stride=2, padding=1) + f0 = out + out, f1_pre = self.layer1(out) + f1 = out + out, f2_pre = self.layer2(out) + f2 = out + out, f3_pre = self.layer3(out) + f3 = out + out = F.relu(self.bn2(self.conv2(out))) + out = F.avg_pool2d(out, 4) + out = out.view(out.size(0), -1) + f4 = out + out = self.linear(out) + if is_feat: + if preact: + return [f0, f1_pre, f2_pre, f3_pre, f4], out + else: + return [f0, f1, f2, f3, f4], out + else: + return out + + +configs = { + 0.2: { + 'out_channels': (40, 80, 160, 512), + 'num_blocks': (3, 3, 3) + }, + + 0.3: { + 'out_channels': (40, 80, 160, 512), + 'num_blocks': (3, 7, 3) + }, + + 0.5: { + 'out_channels': (48, 96, 192, 1024), + 'num_blocks': (3, 7, 3) + }, + + 1: { + 'out_channels': (116, 232, 464, 1024), + 'num_blocks': (3, 7, 3) + }, + 1.5: { + 'out_channels': (176, 352, 704, 1024), + 'num_blocks': (3, 7, 3) + }, + 2: { + 'out_channels': (224, 488, 976, 2048), + 'num_blocks': (3, 7, 3) + } +} + + +def ShuffleV2(**kwargs): + model = ShuffleNetV2(net_size=1, **kwargs) + return model + + +if __name__ == '__main__': + net = ShuffleV2(num_classes=100) + x = torch.randn(3, 3, 32, 32) + import time + a = time.time() + feats, logit = net(x, is_feat=True, preact=True) + b = time.time() + print(b - a) + for f in feats: + print(f.shape, f.min().item()) + print(logit.shape) diff --git a/classification/lib/models/cifar/__init__.py b/classification/lib/models/cifar/__init__.py new file mode 100644 index 0000000..1720555 --- /dev/null +++ b/classification/lib/models/cifar/__init__.py @@ -0,0 +1,32 @@ +from .resnet import resnet8, resnet14, resnet20, resnet32, resnet44, resnet56, resnet110, resnet8x4, resnet32x4 +from .resnetv2 import ResNet50 +from .wrn import wrn_16_1, wrn_16_2, wrn_40_1, wrn_40_2 +from .vgg import vgg19_bn, vgg16_bn, vgg13_bn, vgg11_bn, vgg8_bn +from .mobilenetv2 import mobile_half +from .ShuffleNetv1 import ShuffleV1 +from .ShuffleNetv2 import ShuffleV2 + +model_dict = { + 'resnet8': resnet8, + 'resnet14': resnet14, + 'resnet20': resnet20, + 'resnet32': resnet32, + 'resnet44': resnet44, + 'resnet56': resnet56, + 'resnet110': resnet110, + 'resnet8x4': resnet8x4, + 'resnet32x4': resnet32x4, + 'ResNet50': ResNet50, + 'wrn_16_1': wrn_16_1, + 'wrn_16_2': wrn_16_2, + 'wrn_40_1': wrn_40_1, + 'wrn_40_2': wrn_40_2, + 'vgg8': vgg8_bn, + 'vgg11': vgg11_bn, + 'vgg13': vgg13_bn, + 'vgg16': vgg16_bn, + 'vgg19': vgg19_bn, + 'MobileNetV2': mobile_half, + 'ShuffleV1': ShuffleV1, + 'ShuffleV2': ShuffleV2, +} diff --git a/classification/lib/models/cifar/classifier.py 
b/classification/lib/models/cifar/classifier.py new file mode 100644 index 0000000..167ddb6 --- /dev/null +++ b/classification/lib/models/cifar/classifier.py @@ -0,0 +1,35 @@ +from __future__ import print_function + +import torch.nn as nn + + +######################################### +# ===== Classifiers ===== # +######################################### + +class LinearClassifier(nn.Module): + + def __init__(self, dim_in, n_label=10): + super(LinearClassifier, self).__init__() + + self.net = nn.Linear(dim_in, n_label) + + def forward(self, x): + return self.net(x) + + +class NonLinearClassifier(nn.Module): + + def __init__(self, dim_in, n_label=10, p=0.1): + super(NonLinearClassifier, self).__init__() + + self.net = nn.Sequential( + nn.Linear(dim_in, 200), + nn.Dropout(p=p), + nn.BatchNorm1d(200), + nn.ReLU(inplace=True), + nn.Linear(200, n_label), + ) + + def forward(self, x): + return self.net(x) diff --git a/classification/lib/models/cifar/mobilenetv2.py b/classification/lib/models/cifar/mobilenetv2.py new file mode 100644 index 0000000..6bfe9fa --- /dev/null +++ b/classification/lib/models/cifar/mobilenetv2.py @@ -0,0 +1,202 @@ +""" +MobileNetV2 implementation used in + +""" + +import torch +import torch.nn as nn +import math + +__all__ = ['mobilenetv2_T_w', 'mobile_half'] + +BN = None + + +def conv_bn(inp, oup, stride): + return nn.Sequential( + nn.Conv2d(inp, oup, 3, stride, 1, bias=False), + nn.BatchNorm2d(oup), + nn.ReLU(inplace=True) + ) + + +def conv_1x1_bn(inp, oup): + return nn.Sequential( + nn.Conv2d(inp, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + nn.ReLU(inplace=True) + ) + + +class InvertedResidual(nn.Module): + def __init__(self, inp, oup, stride, expand_ratio): + super(InvertedResidual, self).__init__() + self.blockname = None + + self.stride = stride + assert stride in [1, 2] + + self.use_res_connect = self.stride == 1 and inp == oup + + self.conv = nn.Sequential( + # pw + nn.Conv2d(inp, inp * expand_ratio, 1, 1, 0, bias=False), + nn.BatchNorm2d(inp * expand_ratio), + nn.ReLU(inplace=True), + # dw + nn.Conv2d(inp * expand_ratio, inp * expand_ratio, 3, stride, 1, groups=inp * expand_ratio, bias=False), + nn.BatchNorm2d(inp * expand_ratio), + nn.ReLU(inplace=True), + # pw-linear + nn.Conv2d(inp * expand_ratio, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + ) + self.names = ['0', '1', '2', '3', '4', '5', '6', '7'] + + def forward(self, x): + t = x + if self.use_res_connect: + return t + self.conv(x) + else: + return self.conv(x) + + +class MobileNetV2(nn.Module): + """mobilenetV2""" + def __init__(self, T, + feature_dim, + input_size=32, + width_mult=1., + remove_avg=False): + super(MobileNetV2, self).__init__() + self.remove_avg = remove_avg + + # setting of inverted residual blocks + self.interverted_residual_setting = [ + # t, c, n, s + [1, 16, 1, 1], + [T, 24, 2, 1], + [T, 32, 3, 2], + [T, 64, 4, 2], + [T, 96, 3, 1], + [T, 160, 3, 2], + [T, 320, 1, 1], + ] + + # building first layer + assert input_size % 32 == 0 + input_channel = int(32 * width_mult) + self.conv1 = conv_bn(3, input_channel, 2) + + # building inverted residual blocks + self.blocks = nn.ModuleList([]) + for t, c, n, s in self.interverted_residual_setting: + output_channel = int(c * width_mult) + layers = [] + strides = [s] + [1] * (n - 1) + for stride in strides: + layers.append( + InvertedResidual(input_channel, output_channel, stride, t) + ) + input_channel = output_channel + self.blocks.append(nn.Sequential(*layers)) + + self.last_channel = int(1280 * width_mult) if width_mult > 1.0 else 
1280 + self.conv2 = conv_1x1_bn(input_channel, self.last_channel) + + # building classifier + self.classifier = nn.Sequential( + # nn.Dropout(0.5), + nn.Linear(self.last_channel, feature_dim), + ) + + H = input_size // (32//2) + self.avgpool = nn.AvgPool2d(H, ceil_mode=True) + + self._initialize_weights() + print(T, width_mult) + + def get_bn_before_relu(self): + bn1 = self.blocks[1][-1].conv[-1] + bn2 = self.blocks[2][-1].conv[-1] + bn3 = self.blocks[4][-1].conv[-1] + bn4 = self.blocks[6][-1].conv[-1] + return [bn1, bn2, bn3, bn4] + + def get_feat_modules(self): + feat_m = nn.ModuleList([]) + feat_m.append(self.conv1) + feat_m.append(self.blocks) + return feat_m + + def forward(self, x, is_feat=False, preact=False): + + out = self.conv1(x) + f0 = out + + out = self.blocks[0](out) + out = self.blocks[1](out) + f1 = out + out = self.blocks[2](out) + f2 = out + out = self.blocks[3](out) + out = self.blocks[4](out) + f3 = out + out = self.blocks[5](out) + out = self.blocks[6](out) + f4 = out + + out = self.conv2(out) + + if not self.remove_avg: + out = self.avgpool(out) + out = out.view(out.size(0), -1) + f5 = out + out = self.classifier(out) + + if is_feat: + return [f0, f1, f2, f3, f4, f5], out + else: + return out + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + n = m.weight.size(1) + m.weight.data.normal_(0, 0.01) + m.bias.data.zero_() + + +def mobilenetv2_T_w(T, W, feature_dim=100): + model = MobileNetV2(T=T, feature_dim=feature_dim, width_mult=W) + return model + + +def mobile_half(num_classes): + return mobilenetv2_T_w(6, 0.5, num_classes) + + +if __name__ == '__main__': + x = torch.randn(2, 3, 32, 32) + + net = mobile_half(100) + + feats, logit = net(x, is_feat=True, preact=True) + for f in feats: + print(f.shape, f.min().item()) + print(logit.shape) + + for m in net.get_bn_before_relu(): + if isinstance(m, nn.BatchNorm2d): + print('pass') + else: + print('warning') + diff --git a/classification/lib/models/cifar/resnet.py b/classification/lib/models/cifar/resnet.py new file mode 100644 index 0000000..e5d9a27 --- /dev/null +++ b/classification/lib/models/cifar/resnet.py @@ -0,0 +1,256 @@ +from __future__ import absolute_import + +'''Resnet for cifar dataset. 
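+
+Example (a minimal sketch, mirroring the self-test at the bottom of this file):
+    net = resnet8x4(num_classes=20)
+    feats, logits = net(torch.randn(2, 3, 32, 32), is_feat=True, preact=True)
+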
+Ported form +https://github.com/facebook/fb.resnet.torch +and +https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py +(c) YANG, Wei +''' +import torch.nn as nn +import torch.nn.functional as F +import math + + +__all__ = ['resnet'] + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, is_last=False): + super(BasicBlock, self).__init__() + self.is_last = is_last + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + preact = out + out = F.relu(out) + if self.is_last: + return out, preact + else: + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, is_last=False): + super(Bottleneck, self).__init__() + self.is_last = is_last + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + preact = out + out = F.relu(out) + if self.is_last: + return out, preact + else: + return out + + +class ResNet(nn.Module): + + def __init__(self, depth, num_filters, block_name='BasicBlock', num_classes=10): + super(ResNet, self).__init__() + # Model type specifies number of layers for CIFAR-10 model + if block_name.lower() == 'basicblock': + assert (depth - 2) % 6 == 0, 'When use basicblock, depth should be 6n+2, e.g. 20, 32, 44, 56, 110, 1202' + n = (depth - 2) // 6 + block = BasicBlock + elif block_name.lower() == 'bottleneck': + assert (depth - 2) % 9 == 0, 'When use bottleneck, depth should be 9n+2, e.g. 
20, 29, 47, 56, 110, 1199' + n = (depth - 2) // 9 + block = Bottleneck + else: + raise ValueError('block_name shoule be Basicblock or Bottleneck') + + self.inplanes = num_filters[0] + self.conv1 = nn.Conv2d(3, num_filters[0], kernel_size=3, padding=1, + bias=False) + self.bn1 = nn.BatchNorm2d(num_filters[0]) + self.relu = nn.ReLU(inplace=True) + self.layer1 = self._make_layer(block, num_filters[1], n) + self.layer2 = self._make_layer(block, num_filters[2], n, stride=2) + self.layer3 = self._make_layer(block, num_filters[3], n, stride=2) + self.avgpool = nn.AvgPool2d(8) + self.fc = nn.Linear(num_filters[3] * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = list([]) + layers.append(block(self.inplanes, planes, stride, downsample, is_last=(blocks == 1))) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes, is_last=(i == blocks-1))) + + return nn.Sequential(*layers) + + def get_feat_modules(self): + feat_m = nn.ModuleList([]) + feat_m.append(self.conv1) + feat_m.append(self.bn1) + feat_m.append(self.relu) + feat_m.append(self.layer1) + feat_m.append(self.layer2) + feat_m.append(self.layer3) + return feat_m + + def get_bn_before_relu(self): + if isinstance(self.layer1[0], Bottleneck): + bn1 = self.layer1[-1].bn3 + bn2 = self.layer2[-1].bn3 + bn3 = self.layer3[-1].bn3 + elif isinstance(self.layer1[0], BasicBlock): + bn1 = self.layer1[-1].bn2 + bn2 = self.layer2[-1].bn2 + bn3 = self.layer3[-1].bn2 + else: + raise NotImplementedError('ResNet unknown block error !!!') + + return [bn1, bn2, bn3] + + def forward(self, x, is_feat=False, preact=False): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) # 32x32 + f0 = x + + x, f1_pre = self.layer1(x) # 32x32 + f1 = x + x, f2_pre = self.layer2(x) # 16x16 + f2 = x + x, f3_pre = self.layer3(x) # 8x8 + f3 = x + + x = self.avgpool(x) + x = x.view(x.size(0), -1) + f4 = x + x = self.fc(x) + + if is_feat: + if preact: + return [f0, f1_pre, f2_pre, f3_pre, f4], x + else: + return [f0, f1, f2, f3, f4], x + else: + return x + + +def resnet8(**kwargs): + return ResNet(8, [16, 16, 32, 64], 'basicblock', **kwargs) + + +def resnet14(**kwargs): + return ResNet(14, [16, 16, 32, 64], 'basicblock', **kwargs) + + +def resnet20(**kwargs): + return ResNet(20, [16, 16, 32, 64], 'basicblock', **kwargs) + + +def resnet32(**kwargs): + return ResNet(32, [16, 16, 32, 64], 'basicblock', **kwargs) + + +def resnet44(**kwargs): + return ResNet(44, [16, 16, 32, 64], 'basicblock', **kwargs) + + +def resnet56(**kwargs): + return ResNet(56, [16, 16, 32, 64], 'basicblock', **kwargs) + + +def resnet110(**kwargs): + return ResNet(110, [16, 16, 32, 64], 'basicblock', **kwargs) + + +def resnet8x4(**kwargs): + return ResNet(8, [32, 64, 128, 256], 'basicblock', **kwargs) + + +def resnet32x4(**kwargs): + return ResNet(32, [32, 64, 128, 256], 'basicblock', **kwargs) + + +if __name__ == '__main__': + import torch + + x = torch.randn(2, 3, 32, 32) + net = 
resnet8x4(num_classes=20) + feats, logit = net(x, is_feat=True, preact=True) + + for f in feats: + print(f.shape, f.min().item()) + print(logit.shape) + + for m in net.get_bn_before_relu(): + if isinstance(m, nn.BatchNorm2d): + print('pass') + else: + print('warning') diff --git a/classification/lib/models/cifar/resnetv2.py b/classification/lib/models/cifar/resnetv2.py new file mode 100644 index 0000000..bc03eaf --- /dev/null +++ b/classification/lib/models/cifar/resnetv2.py @@ -0,0 +1,198 @@ +'''ResNet in PyTorch. +For Pre-activation ResNet, see 'preact_resnet.py'. +Reference: +[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun + Deep Residual Learning for Image Recognition. arXiv:1512.03385 +''' +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, in_planes, planes, stride=1, is_last=False): + super(BasicBlock, self).__init__() + self.is_last = is_last + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion * planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(self.expansion * planes) + ) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = self.bn2(self.conv2(out)) + out += self.shortcut(x) + preact = out + out = F.relu(out) + if self.is_last: + return out, preact + else: + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, in_planes, planes, stride=1, is_last=False): + super(Bottleneck, self).__init__() + self.is_last = is_last + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(self.expansion * planes) + + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion * planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(self.expansion * planes) + ) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = F.relu(self.bn2(self.conv2(out))) + out = self.bn3(self.conv3(out)) + out += self.shortcut(x) + preact = out + out = F.relu(out) + if self.is_last: + return out, preact + else: + return out + + +class ResNet(nn.Module): + def __init__(self, block, num_blocks, num_classes=10, zero_init_residual=False): + super(ResNet, self).__init__() + self.in_planes = 64 + + self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) + self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) + self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) + self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.linear = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, 
mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. + # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + def get_feat_modules(self): + feat_m = nn.ModuleList([]) + feat_m.append(self.conv1) + feat_m.append(self.bn1) + feat_m.append(self.layer1) + feat_m.append(self.layer2) + feat_m.append(self.layer3) + feat_m.append(self.layer4) + return feat_m + + def get_bn_before_relu(self): + if isinstance(self.layer1[0], Bottleneck): + bn1 = self.layer1[-1].bn3 + bn2 = self.layer2[-1].bn3 + bn3 = self.layer3[-1].bn3 + bn4 = self.layer4[-1].bn3 + elif isinstance(self.layer1[0], BasicBlock): + bn1 = self.layer1[-1].bn2 + bn2 = self.layer2[-1].bn2 + bn3 = self.layer3[-1].bn2 + bn4 = self.layer4[-1].bn2 + else: + raise NotImplementedError('ResNet unknown block error !!!') + + return [bn1, bn2, bn3, bn4] + + def _make_layer(self, block, planes, num_blocks, stride): + strides = [stride] + [1] * (num_blocks - 1) + layers = [] + for i in range(num_blocks): + stride = strides[i] + layers.append(block(self.in_planes, planes, stride, i == num_blocks - 1)) + self.in_planes = planes * block.expansion + return nn.Sequential(*layers) + + def forward(self, x, is_feat=False, preact=False): + out = F.relu(self.bn1(self.conv1(x))) + f0 = out + out, f1_pre = self.layer1(out) + f1 = out + out, f2_pre = self.layer2(out) + f2 = out + out, f3_pre = self.layer3(out) + f3 = out + out, f4_pre = self.layer4(out) + f4 = out + out = self.avgpool(out) + out = out.view(out.size(0), -1) + f5 = out + out = self.linear(out) + if is_feat: + if preact: + return [[f0, f1_pre, f2_pre, f3_pre, f4_pre, f5], out] + else: + return [f0, f1, f2, f3, f4, f5], out + else: + return out + + +def ResNet18(**kwargs): + return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) + + +def ResNet34(**kwargs): + return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) + + +def ResNet50(**kwargs): + return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) + + +def ResNet101(**kwargs): + return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) + + +def ResNet152(**kwargs): + return ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) + + +if __name__ == '__main__': + net = ResNet18(num_classes=100) + x = torch.randn(2, 3, 32, 32) + feats, logit = net(x, is_feat=True, preact=True) + + for f in feats: + print(f.shape, f.min().item()) + print(logit.shape) + + for m in net.get_bn_before_relu(): + if isinstance(m, nn.BatchNorm2d): + print('pass') + else: + print('warning') diff --git a/classification/lib/models/cifar/util.py b/classification/lib/models/cifar/util.py new file mode 100644 index 0000000..90293f1 --- /dev/null +++ b/classification/lib/models/cifar/util.py @@ -0,0 +1,290 @@ +from __future__ import print_function + +import torch.nn as nn +import math + + +class Paraphraser(nn.Module): + """Paraphrasing Complex Network: Network Compression via Factor Transfer""" + def __init__(self, t_shape, k=0.5, use_bn=False): + super(Paraphraser, self).__init__() + in_channel = t_shape[1] + out_channel = int(t_shape[1] * k) + self.encoder = nn.Sequential( + nn.Conv2d(in_channel, in_channel, 3, 1, 1), + 
nn.BatchNorm2d(in_channel) if use_bn else nn.Sequential(), + nn.LeakyReLU(0.1, inplace=True), + nn.Conv2d(in_channel, out_channel, 3, 1, 1), + nn.BatchNorm2d(out_channel) if use_bn else nn.Sequential(), + nn.LeakyReLU(0.1, inplace=True), + nn.Conv2d(out_channel, out_channel, 3, 1, 1), + nn.BatchNorm2d(out_channel) if use_bn else nn.Sequential(), + nn.LeakyReLU(0.1, inplace=True), + ) + self.decoder = nn.Sequential( + nn.ConvTranspose2d(out_channel, out_channel, 3, 1, 1), + nn.BatchNorm2d(out_channel) if use_bn else nn.Sequential(), + nn.LeakyReLU(0.1, inplace=True), + nn.ConvTranspose2d(out_channel, in_channel, 3, 1, 1), + nn.BatchNorm2d(in_channel) if use_bn else nn.Sequential(), + nn.LeakyReLU(0.1, inplace=True), + nn.ConvTranspose2d(in_channel, in_channel, 3, 1, 1), + nn.BatchNorm2d(in_channel) if use_bn else nn.Sequential(), + nn.LeakyReLU(0.1, inplace=True), + ) + + def forward(self, f_s, is_factor=False): + factor = self.encoder(f_s) + if is_factor: + return factor + rec = self.decoder(factor) + return factor, rec + + +class Translator(nn.Module): + def __init__(self, s_shape, t_shape, k=0.5, use_bn=True): + super(Translator, self).__init__() + in_channel = s_shape[1] + out_channel = int(t_shape[1] * k) + self.encoder = nn.Sequential( + nn.Conv2d(in_channel, in_channel, 3, 1, 1), + nn.BatchNorm2d(in_channel) if use_bn else nn.Sequential(), + nn.LeakyReLU(0.1, inplace=True), + nn.Conv2d(in_channel, out_channel, 3, 1, 1), + nn.BatchNorm2d(out_channel) if use_bn else nn.Sequential(), + nn.LeakyReLU(0.1, inplace=True), + nn.Conv2d(out_channel, out_channel, 3, 1, 1), + nn.BatchNorm2d(out_channel) if use_bn else nn.Sequential(), + nn.LeakyReLU(0.1, inplace=True), + ) + + def forward(self, f_s): + return self.encoder(f_s) + + +class Connector(nn.Module): + """Connect for Knowledge Transfer via Distillation of Activation Boundaries Formed by Hidden Neurons""" + def __init__(self, s_shapes, t_shapes): + super(Connector, self).__init__() + self.s_shapes = s_shapes + self.t_shapes = t_shapes + + self.connectors = nn.ModuleList(self._make_conenctors(s_shapes, t_shapes)) + + @staticmethod + def _make_conenctors(s_shapes, t_shapes): + assert len(s_shapes) == len(t_shapes), 'unequal length of feat list' + connectors = [] + for s, t in zip(s_shapes, t_shapes): + if s[1] == t[1] and s[2] == t[2]: + connectors.append(nn.Sequential()) + else: + connectors.append(ConvReg(s, t, use_relu=False)) + return connectors + + def forward(self, g_s): + out = [] + for i in range(len(g_s)): + out.append(self.connectors[i](g_s[i])) + + return out + + +class ConnectorV2(nn.Module): + """A Comprehensive Overhaul of Feature Distillation (ICCV 2019)""" + def __init__(self, s_shapes, t_shapes): + super(ConnectorV2, self).__init__() + self.s_shapes = s_shapes + self.t_shapes = t_shapes + + self.connectors = nn.ModuleList(self._make_conenctors(s_shapes, t_shapes)) + + def _make_conenctors(self, s_shapes, t_shapes): + assert len(s_shapes) == len(t_shapes), 'unequal length of feat list' + t_channels = [t[1] for t in t_shapes] + s_channels = [s[1] for s in s_shapes] + connectors = nn.ModuleList([self._build_feature_connector(t, s) + for t, s in zip(t_channels, s_channels)]) + return connectors + + @staticmethod + def _build_feature_connector(t_channel, s_channel): + C = [nn.Conv2d(s_channel, t_channel, kernel_size=1, stride=1, padding=0, bias=False), + nn.BatchNorm2d(t_channel)] + for m in C: + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. 
/ n))
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+        return nn.Sequential(*C)
+
+    def forward(self, g_s):
+        out = []
+        for i in range(len(g_s)):
+            out.append(self.connectors[i](g_s[i]))
+
+        return out
+
+
+class ConvReg(nn.Module):
+    """Convolutional regression for FitNet"""
+    def __init__(self, s_shape, t_shape, use_relu=True):
+        super(ConvReg, self).__init__()
+        self.use_relu = use_relu
+        s_N, s_C, s_H, s_W = s_shape
+        t_N, t_C, t_H, t_W = t_shape
+        if s_H == 2 * t_H:
+            self.conv = nn.Conv2d(s_C, t_C, kernel_size=3, stride=2, padding=1)
+        elif s_H * 2 == t_H:
+            self.conv = nn.ConvTranspose2d(s_C, t_C, kernel_size=4, stride=2, padding=1)
+        elif s_H >= t_H:
+            self.conv = nn.Conv2d(s_C, t_C, kernel_size=(1+s_H-t_H, 1+s_W-t_W))
+        else:
+            raise NotImplementedError('student size {}, teacher size {}'.format(s_H, t_H))
+        self.bn = nn.BatchNorm2d(t_C)
+        self.relu = nn.ReLU(inplace=True)
+
+    def forward(self, x):
+        x = self.conv(x)
+        if self.use_relu:
+            return self.relu(self.bn(x))
+        else:
+            return self.bn(x)
+
+
+class Regress(nn.Module):
+    """Simple Linear Regression for hints"""
+    def __init__(self, dim_in=1024, dim_out=1024):
+        super(Regress, self).__init__()
+        self.linear = nn.Linear(dim_in, dim_out)
+        self.relu = nn.ReLU(inplace=True)
+
+    def forward(self, x):
+        x = x.view(x.shape[0], -1)
+        x = self.linear(x)
+        x = self.relu(x)
+        return x
+
+
+class Embed(nn.Module):
+    """Embedding module"""
+    def __init__(self, dim_in=1024, dim_out=128):
+        super(Embed, self).__init__()
+        self.linear = nn.Linear(dim_in, dim_out)
+        self.l2norm = Normalize(2)
+
+    def forward(self, x):
+        x = x.view(x.shape[0], -1)
+        x = self.linear(x)
+        x = self.l2norm(x)
+        return x
+
+
+class LinearEmbed(nn.Module):
+    """Linear Embedding"""
+    def __init__(self, dim_in=1024, dim_out=128):
+        super(LinearEmbed, self).__init__()
+        self.linear = nn.Linear(dim_in, dim_out)
+
+    def forward(self, x):
+        x = x.view(x.shape[0], -1)
+        x = self.linear(x)
+        return x
+
+
+class MLPEmbed(nn.Module):
+    """non-linear embed by MLP"""
+    def __init__(self, dim_in=1024, dim_out=128):
+        super(MLPEmbed, self).__init__()
+        self.linear1 = nn.Linear(dim_in, 2 * dim_out)
+        self.relu = nn.ReLU(inplace=True)
+        self.linear2 = nn.Linear(2 * dim_out, dim_out)
+        self.l2norm = Normalize(2)
+
+    def forward(self, x):
+        x = x.view(x.shape[0], -1)
+        x = self.relu(self.linear1(x))
+        x = self.l2norm(self.linear2(x))
+        return x
+
+
+class Normalize(nn.Module):
+    """normalization layer"""
+    def __init__(self, power=2):
+        super(Normalize, self).__init__()
+        self.power = power
+
+    def forward(self, x):
+        norm = x.pow(self.power).sum(1, keepdim=True).pow(1. 
/ self.power) + out = x.div(norm) + return out + + +class Flatten(nn.Module): + """flatten module""" + def __init__(self): + super(Flatten, self).__init__() + + def forward(self, feat): + return feat.view(feat.size(0), -1) + + +class PoolEmbed(nn.Module): + """pool and embed""" + def __init__(self, layer=0, dim_out=128, pool_type='avg'): + super().__init__() + if layer == 0: + pool_size = 8 + nChannels = 16 + elif layer == 1: + pool_size = 8 + nChannels = 16 + elif layer == 2: + pool_size = 6 + nChannels = 32 + elif layer == 3: + pool_size = 4 + nChannels = 64 + elif layer == 4: + pool_size = 1 + nChannels = 64 + else: + raise NotImplementedError('layer not supported: {}'.format(layer)) + + self.embed = nn.Sequential() + if layer <= 3: + if pool_type == 'max': + self.embed.add_module('MaxPool', nn.AdaptiveMaxPool2d((pool_size, pool_size))) + elif pool_type == 'avg': + self.embed.add_module('AvgPool', nn.AdaptiveAvgPool2d((pool_size, pool_size))) + + self.embed.add_module('Flatten', Flatten()) + self.embed.add_module('Linear', nn.Linear(nChannels*pool_size*pool_size, dim_out)) + self.embed.add_module('Normalize', Normalize(2)) + + def forward(self, x): + return self.embed(x) + + +if __name__ == '__main__': + import torch + + g_s = [ + torch.randn(2, 16, 16, 16), + torch.randn(2, 32, 8, 8), + torch.randn(2, 64, 4, 4), + ] + g_t = [ + torch.randn(2, 32, 16, 16), + torch.randn(2, 64, 8, 8), + torch.randn(2, 128, 4, 4), + ] + s_shapes = [s.shape for s in g_s] + t_shapes = [t.shape for t in g_t] + + net = ConnectorV2(s_shapes, t_shapes) + out = net(g_s) + for f in out: + print(f.shape) diff --git a/classification/lib/models/cifar/vgg.py b/classification/lib/models/cifar/vgg.py new file mode 100644 index 0000000..b7bd5fe --- /dev/null +++ b/classification/lib/models/cifar/vgg.py @@ -0,0 +1,236 @@ +'''VGG for CIFAR10. FC layers are removed. 
+(c) YANG, Wei +''' +import torch.nn as nn +import torch.nn.functional as F +import math + + +__all__ = [ + 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', + 'vgg19_bn', 'vgg19', +] + + +model_urls = { + 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth', + 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth', + 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', + 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', +} + + +class VGG(nn.Module): + + def __init__(self, cfg, batch_norm=False, num_classes=1000): + super(VGG, self).__init__() + self.block0 = self._make_layers(cfg[0], batch_norm, 3) + self.block1 = self._make_layers(cfg[1], batch_norm, cfg[0][-1]) + self.block2 = self._make_layers(cfg[2], batch_norm, cfg[1][-1]) + self.block3 = self._make_layers(cfg[3], batch_norm, cfg[2][-1]) + self.block4 = self._make_layers(cfg[4], batch_norm, cfg[3][-1]) + + self.pool0 = nn.MaxPool2d(kernel_size=2, stride=2) + self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) + self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) + self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2) + self.pool4 = nn.AdaptiveAvgPool2d((1, 1)) + # self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2) + + self.classifier = nn.Linear(512, num_classes) + self._initialize_weights() + + def get_feat_modules(self): + feat_m = nn.ModuleList([]) + feat_m.append(self.block0) + feat_m.append(self.pool0) + feat_m.append(self.block1) + feat_m.append(self.pool1) + feat_m.append(self.block2) + feat_m.append(self.pool2) + feat_m.append(self.block3) + feat_m.append(self.pool3) + feat_m.append(self.block4) + feat_m.append(self.pool4) + return feat_m + + def get_bn_before_relu(self): + bn1 = self.block1[-1] + bn2 = self.block2[-1] + bn3 = self.block3[-1] + bn4 = self.block4[-1] + return [bn1, bn2, bn3, bn4] + + def forward(self, x, is_feat=False, preact=False): + h = x.shape[2] + x = F.relu(self.block0(x)) + f0 = x + x = self.pool0(x) + x = self.block1(x) + f1_pre = x + x = F.relu(x) + f1 = x + x = self.pool1(x) + x = self.block2(x) + f2_pre = x + x = F.relu(x) + f2 = x + x = self.pool2(x) + x = self.block3(x) + f3_pre = x + x = F.relu(x) + f3 = x + if h == 64: + x = self.pool3(x) + x = self.block4(x) + f4_pre = x + x = F.relu(x) + f4 = x + x = self.pool4(x) + x = x.view(x.size(0), -1) + f5 = x + x = self.classifier(x) + + if is_feat: + if preact: + return [f0, f1_pre, f2_pre, f3_pre, f4_pre, f5], x + else: + return [f0, f1, f2, f3, f4, f5], x + else: + return x + + @staticmethod + def _make_layers(cfg, batch_norm=False, in_channels=3): + layers = [] + for v in cfg: + if v == 'M': + layers += [nn.MaxPool2d(kernel_size=2, stride=2)] + else: + conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) + if batch_norm: + layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] + else: + layers += [conv2d, nn.ReLU(inplace=True)] + in_channels = v + layers = layers[:-1] + return nn.Sequential(*layers) + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. 
/ n)) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + n = m.weight.size(1) + m.weight.data.normal_(0, 0.01) + m.bias.data.zero_() + + +cfg = { + 'A': [[64], [128], [256, 256], [512, 512], [512, 512]], + 'B': [[64, 64], [128, 128], [256, 256], [512, 512], [512, 512]], + 'D': [[64, 64], [128, 128], [256, 256, 256], [512, 512, 512], [512, 512, 512]], + 'E': [[64, 64], [128, 128], [256, 256, 256, 256], [512, 512, 512, 512], [512, 512, 512, 512]], + 'S': [[64], [128], [256], [512], [512]], +} + + +def vgg8(**kwargs): + """VGG 8-layer model (configuration "S") + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = VGG(cfg['S'], **kwargs) + return model + + +def vgg8_bn(**kwargs): + """VGG 8-layer model (configuration "S") + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = VGG(cfg['S'], batch_norm=True, **kwargs) + return model + + +def vgg11(**kwargs): + """VGG 11-layer model (configuration "A") + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = VGG(cfg['A'], **kwargs) + return model + + +def vgg11_bn(**kwargs): + """VGG 11-layer model (configuration "A") with batch normalization""" + model = VGG(cfg['A'], batch_norm=True, **kwargs) + return model + + +def vgg13(**kwargs): + """VGG 13-layer model (configuration "B") + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = VGG(cfg['B'], **kwargs) + return model + + +def vgg13_bn(**kwargs): + """VGG 13-layer model (configuration "B") with batch normalization""" + model = VGG(cfg['B'], batch_norm=True, **kwargs) + return model + + +def vgg16(**kwargs): + """VGG 16-layer model (configuration "D") + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = VGG(cfg['D'], **kwargs) + return model + + +def vgg16_bn(**kwargs): + """VGG 16-layer model (configuration "D") with batch normalization""" + model = VGG(cfg['D'], batch_norm=True, **kwargs) + return model + + +def vgg19(**kwargs): + """VGG 19-layer model (configuration "E") + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = VGG(cfg['E'], **kwargs) + return model + + +def vgg19_bn(**kwargs): + """VGG 19-layer model (configuration 'E') with batch normalization""" + model = VGG(cfg['E'], batch_norm=True, **kwargs) + return model + + +if __name__ == '__main__': + import torch + + x = torch.randn(2, 3, 32, 32) + net = vgg19_bn(num_classes=100) + feats, logit = net(x, is_feat=True, preact=True) + + for f in feats: + print(f.shape, f.min().item()) + print(logit.shape) + + for m in net.get_bn_before_relu(): + if isinstance(m, nn.BatchNorm2d): + print('pass') + else: + print('warning') diff --git a/classification/lib/models/cifar/wrn.py b/classification/lib/models/cifar/wrn.py new file mode 100644 index 0000000..72a7e10 --- /dev/null +++ b/classification/lib/models/cifar/wrn.py @@ -0,0 +1,170 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +""" +Original Author: Wei Yang +""" + +__all__ = ['wrn'] + + +class BasicBlock(nn.Module): + def __init__(self, in_planes, out_planes, stride, dropRate=0.0): + super(BasicBlock, self).__init__() + self.bn1 = nn.BatchNorm2d(in_planes) + self.relu1 = nn.ReLU(inplace=True) + self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + 
self.bn2 = nn.BatchNorm2d(out_planes) + self.relu2 = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, + padding=1, bias=False) + self.droprate = dropRate + self.equalInOut = (in_planes == out_planes) + self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, + padding=0, bias=False) or None + + def forward(self, x): + if not self.equalInOut: + x = self.relu1(self.bn1(x)) + else: + out = self.relu1(self.bn1(x)) + out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) + if self.droprate > 0: + out = F.dropout(out, p=self.droprate, training=self.training) + out = self.conv2(out) + return torch.add(x if self.equalInOut else self.convShortcut(x), out) + + +class NetworkBlock(nn.Module): + def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0): + super(NetworkBlock, self).__init__() + self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate) + + def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate): + layers = [] + for i in range(nb_layers): + layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate)) + return nn.Sequential(*layers) + + def forward(self, x): + return self.layer(x) + + +class WideResNet(nn.Module): + def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0): + super(WideResNet, self).__init__() + nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor] + assert (depth - 4) % 6 == 0, 'depth should be 6n+4' + n = (depth - 4) // 6 + block = BasicBlock + # 1st conv before any network block + self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, + padding=1, bias=False) + # 1st block + self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) + # 2nd block + self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate) + # 3rd block + self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate) + # global average pooling and classifier + self.bn1 = nn.BatchNorm2d(nChannels[3]) + self.relu = nn.ReLU(inplace=True) + self.fc = nn.Linear(nChannels[3], num_classes) + self.nChannels = nChannels[3] + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + m.bias.data.zero_() + + def get_feat_modules(self): + feat_m = nn.ModuleList([]) + feat_m.append(self.conv1) + feat_m.append(self.block1) + feat_m.append(self.block2) + feat_m.append(self.block3) + return feat_m + + def get_bn_before_relu(self): + bn1 = self.block2.layer[0].bn1 + bn2 = self.block3.layer[0].bn1 + bn3 = self.bn1 + + return [bn1, bn2, bn3] + + def forward(self, x, is_feat=False, preact=False): + out = self.conv1(x) + f0 = out + out = self.block1(out) + f1 = out + out = self.block2(out) + f2 = out + out = self.block3(out) + f3 = out + out = self.relu(self.bn1(out)) + out = F.avg_pool2d(out, 8) + out = out.view(-1, self.nChannels) + f4 = out + out = self.fc(out) + if is_feat: + if preact: + f1 = self.block2.layer[0].bn1(f1) + f2 = self.block3.layer[0].bn1(f2) + f3 = self.bn1(f3) + return [f0, f1, f2, f3, f4], out + else: + return out + + +def wrn(**kwargs): + """ + Constructs a Wide Residual Networks. 
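+    Example (a minimal sketch, mirroring the self-test at the bottom of this file):
+        net = wrn_40_2(num_classes=100)          # depth=40, widen_factor=2
+        logits = net(torch.randn(2, 3, 32, 32))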
+ """ + model = WideResNet(**kwargs) + return model + + +def wrn_40_2(**kwargs): + model = WideResNet(depth=40, widen_factor=2, **kwargs) + return model + + +def wrn_40_1(**kwargs): + model = WideResNet(depth=40, widen_factor=1, **kwargs) + return model + + +def wrn_16_2(**kwargs): + model = WideResNet(depth=16, widen_factor=2, **kwargs) + return model + + +def wrn_16_1(**kwargs): + model = WideResNet(depth=16, widen_factor=1, **kwargs) + return model + + +if __name__ == '__main__': + import torch + + x = torch.randn(2, 3, 32, 32) + net = wrn_40_2(num_classes=100) + feats, logit = net(x, is_feat=True, preact=True) + + for f in feats: + print(f.shape, f.min().item()) + print(logit.shape) + + for m in net.get_bn_before_relu(): + if isinstance(m, nn.BatchNorm2d): + print('pass') + else: + print('warning') diff --git a/classification/lib/models/darts_model.py b/classification/lib/models/darts_model.py new file mode 100644 index 0000000..0432b33 --- /dev/null +++ b/classification/lib/models/darts_model.py @@ -0,0 +1,96 @@ +import torch +import torch.nn as nn +from .operations import DARTSCell, AuxiliaryHead + + +def gen_darts_model(net_cfg, dataset='imagenet', drop_rate=0., drop_path_rate=0., auxiliary_head=False, **kwargs): + if dataset.lower() == 'imagenet': + dataset = 'imagenet' + elif dataset.lower() in ['cifar', 'cifar10', 'cifar100']: + dataset = 'cifar' + model = DARTSModel(net_cfg, dataset, drop_rate, drop_path_rate, auxiliary_head=auxiliary_head) + return model + + +class DARTSModel(nn.Module): + def __init__(self, net_cfg, dataset='imagenet', drop_rate=0., drop_path_rate=0., auxiliary_head=False): + super(DARTSModel, self).__init__() + self.drop_rate = drop_rate + self.drop_path_rate = drop_path_rate + cell_normal = eval(net_cfg['genotype']['normal']) + cell_reduce = eval(net_cfg['genotype']['reduce']) + init_channels = net_cfg.get('init_channels', 48) + layers = net_cfg.get('layers', 14) + cell_multiplier = net_cfg.get('cell_multiplier', 4) + num_classes = net_cfg.get('num_classes', 1000) + + reduction_layers = [layers // 3, layers * 2 // 3] + C = init_channels + + if dataset == 'imagenet': + C_curr = C + self.stem0 = nn.Sequential( + nn.Conv2d(3, C_curr // 2, kernel_size=3, stride=2, padding=1, bias=False), + nn.BatchNorm2d(C_curr // 2), + nn.ReLU(inplace=True), + nn.Conv2d(C_curr // 2, C_curr, 3, stride=2, padding=1, bias=False), + nn.BatchNorm2d(C_curr), + ) + + self.stem1 = nn.Sequential( + nn.ReLU(inplace=True), + nn.Conv2d(C_curr, C_curr, 3, stride=2, padding=1, bias=False), + nn.BatchNorm2d(C_curr), + ) + elif dataset == 'cifar': + stem_multiplier = 3 + C_curr = C * stem_multiplier + self.stem0 = nn.Sequential( + nn.Conv2d(3, C_curr, 3, stride=1, padding=1, bias=False), + nn.BatchNorm2d(C_curr), + ) + self.stem1 = nn.Identity() + + C_prev_prev, C_prev, C_curr = C_curr, C_curr, C + + # cell blocks + self.nas_cells = nn.Sequential() + reduction_prev = dataset == 'imagenet' + for layer_idx in range(layers): + s = 1 + cell_arch = cell_normal + if layer_idx in reduction_layers: + s = 2 + C_curr *= 2 + cell_arch = cell_reduce + cell = DARTSCell(cell_arch, C_prev_prev, C_prev, C_curr, stride=s, reduction_prev=reduction_prev) + self.nas_cells.add_module('cell_{}'.format(layer_idx), cell) + reduction_prev = (s == 2) + C_prev_prev, C_prev = C_prev, C_curr * cell_multiplier + if auxiliary_head and layer_idx == 2 * layers // 3: + C_to_auxiliary = C_prev + + self.pool = nn.AdaptiveAvgPool2d(1) + self.classifier = nn.Linear(C_prev, num_classes) + + if auxiliary_head: + 
object.__setattr__(self, 'module_to_auxiliary', cell)
+            self.auxiliary_head = nn.Sequential(
+                nn.ReLU(inplace=False),
+                AuxiliaryHead(C_to_auxiliary, num_classes, avg_pool_stride=2 if dataset=='imagenet' else 3)
+            )
+
+    def get_classifier(self):
+        return self.classifier
+
+    def forward(self, x):
+        s0 = self.stem0(x)
+        s1 = self.stem1(s0)
+        for cell in self.nas_cells:
+            s0, s1 = s1, cell(s0, s1, self.drop_path_rate)
+        x = self.pool(s1)
+        if self.drop_rate > 0.:
+            # torch.nn.functional is not imported in this file, so go through the nn alias
+            x = nn.functional.dropout(x, p=self.drop_rate, training=self.training)
+        x = x.view(x.size(0), -1)
+        return self.classifier(x)
+
diff --git a/classification/lib/models/lightvit.py b/classification/lib/models/lightvit.py
new file mode 100644
index 0000000..5bcd793
--- /dev/null
+++ b/classification/lib/models/lightvit.py
@@ -0,0 +1,513 @@
+import math
+import torch
+import torch.nn as nn
+from functools import partial
+
+from timm.models.layers import DropPath, trunc_normal_, lecun_normal_
+from timm.models.registry import register_model
+
+
+class ConvStem(nn.Module):
+    """ Image to Patch Embedding
+    """
+    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
+        super().__init__()
+        self.patch_size = patch_size
+        stem_dim = embed_dim // 2
+        self.stem = nn.Sequential(
+            nn.Conv2d(in_chans, stem_dim, kernel_size=3,
+                      stride=2, padding=1, bias=False),
+            nn.BatchNorm2d(stem_dim),
+            nn.GELU(),
+            nn.Conv2d(stem_dim, stem_dim, kernel_size=3,
+                      groups=stem_dim, stride=1, padding=1, bias=False),
+            nn.BatchNorm2d(stem_dim),
+            nn.GELU(),
+            nn.Conv2d(stem_dim, stem_dim, kernel_size=3,
+                      groups=stem_dim, stride=1, padding=1, bias=False),
+            nn.BatchNorm2d(stem_dim),
+            nn.GELU(),
+            nn.Conv2d(stem_dim, stem_dim, kernel_size=3,
+                      groups=stem_dim, stride=2, padding=1, bias=False),
+            nn.BatchNorm2d(stem_dim),
+            nn.GELU(),
+        )
+        self.proj = nn.Conv2d(stem_dim, embed_dim,
+                              kernel_size=3,
+                              stride=2, padding=1)
+        self.norm = nn.LayerNorm(embed_dim)
+
+    def forward(self, x):
+        x = self.proj(self.stem(x))
+        _, _, H, W = x.shape
+        x = x.flatten(2).transpose(1, 2)
+        x = self.norm(x)
+        return x, (H, W)
+
+
+class BiAttn(nn.Module):
+    def __init__(self, in_channels, act_ratio=0.25, act_fn=nn.GELU, gate_fn=nn.Sigmoid):
+        super().__init__()
+        reduce_channels = int(in_channels * act_ratio)
+        self.norm = nn.LayerNorm(in_channels)
+        self.global_reduce = nn.Linear(in_channels, reduce_channels)
+        self.local_reduce = nn.Linear(in_channels, reduce_channels)
+        self.act_fn = act_fn()
+        self.channel_select = nn.Linear(reduce_channels, in_channels)
+        self.spatial_select = nn.Linear(reduce_channels * 2, 1)
+        self.gate_fn = gate_fn()
+
+    def forward(self, x):
+        ori_x = x
+        x = self.norm(x)
+        x_global = x.mean(1, keepdim=True)
+        x_global = self.act_fn(self.global_reduce(x_global))
+        x_local = self.act_fn(self.local_reduce(x))
+
+        c_attn = self.channel_select(x_global)
+        c_attn = self.gate_fn(c_attn)  # [B, 1, C]
+        s_attn = self.spatial_select(torch.cat([x_local, x_global.expand(-1, x.shape[1], -1)], dim=-1))
+        s_attn = self.gate_fn(s_attn)  # [B, N, 1]
+
+        attn = c_attn * s_attn  # [B, N, C]
+        return ori_x * attn
+
+
+class BiAttnMlp(nn.Module):
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.attn = BiAttn(out_features)
+        self.drop = nn.Dropout(drop) if drop > 
0 else nn.Identity() + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.attn(x) + x = self.drop(x) + return x + + +def window_reverse( + windows: torch.Tensor, + original_size, + window_size=(7, 7) +) -> torch.Tensor: + """ Reverses the window partition. + Args: + windows (torch.Tensor): Window tensor of the shape [B * windows, window_size[0] * window_size[1], C]. + original_size (Tuple[int, int]): Original shape. + window_size (Tuple[int, int], optional): Window size which have been applied. Default (7, 7) + Returns: + output (torch.Tensor): Folded output tensor of the shape [B, original_size[0] * original_size[1], C]. + """ + # Get height and width + H, W = original_size + # Compute original batch size + B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1])) + # Fold grid tensor + output = windows.view(B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1) + output = output.permute(0, 1, 3, 2, 4, 5).reshape(B, H * W, -1) + return output + + +def get_relative_position_index( + win_h: int, + win_w: int +) -> torch.Tensor: + """ Function to generate pair-wise relative position index for each token inside the window. + Taken from Timms Swin V1 implementation. + Args: + win_h (int): Window/Grid height. + win_w (int): Window/Grid width. + Returns: + relative_coords (torch.Tensor): Pair-wise relative position indexes [height * width, height * width]. + """ + coords = torch.stack(torch.meshgrid([torch.arange(win_h), torch.arange(win_w)])) + coords_flatten = torch.flatten(coords, 1) + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + relative_coords[:, :, 0] += win_h - 1 + relative_coords[:, :, 1] += win_w - 1 + relative_coords[:, :, 0] *= 2 * win_w - 1 + return relative_coords.sum(-1) + + +class LightViTAttention(nn.Module): + def __init__(self, dim, num_tokens=1, num_heads=8, window_size=7, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.num_tokens = num_tokens + self.window_size = window_size + self.attn_area = window_size * window_size + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.kv_global = nn.Linear(dim, dim * 2, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) if attn_drop > 0 else nn.Identity() + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) if proj_drop > 0 else nn.Identity() + + # Define a parameter table of relative position bias, shape: 2*Wh-1 * 2*Ww-1, nH + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads)) + + # Get pair-wise relative position index for each token inside the window + self.register_buffer("relative_position_index", get_relative_position_index(window_size, + window_size).view(-1)) + # Init relative positional bias + trunc_normal_(self.relative_position_bias_table, std=.02) + + def _get_relative_positional_bias( + self + ) -> torch.Tensor: + """ Returns the relative positional bias. + Returns: + relative_position_bias (torch.Tensor): Relative positional bias. 
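+                The returned tensor has shape [1, num_heads, window_size**2, window_size**2].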
+ """ + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index].view(self.attn_area, self.attn_area, -1) + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() + return relative_position_bias.unsqueeze(0) + + def forward_global_aggregation(self, q, k, v): + """ + q: global tokens + k: image tokens + v: image tokens + """ + B, _, N, _ = q.shape + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + return x + + def forward_local(self, q, k, v, H, W): + """ + q: image tokens + k: image tokens + v: image tokens + """ + B, num_heads, N, C = q.shape + ws = self.window_size + h_group, w_group = H // ws, W // ws + + # partition to windows + q = q.view(B, num_heads, h_group, ws, w_group, ws, -1).permute(0, 2, 4, 1, 3, 5, 6).contiguous() + q = q.view(-1, num_heads, ws*ws, C) + k = k.view(B, num_heads, h_group, ws, w_group, ws, -1).permute(0, 2, 4, 1, 3, 5, 6).contiguous() + k = k.view(-1, num_heads, ws*ws, C) + v = v.view(B, num_heads, h_group, ws, w_group, ws, -1).permute(0, 2, 4, 1, 3, 5, 6).contiguous() + v = v.view(-1, num_heads, ws*ws, v.shape[-1]) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + pos_bias = self._get_relative_positional_bias() + attn = (attn + pos_bias).softmax(dim=-1) + attn = self.attn_drop(attn) + x = (attn @ v).transpose(1, 2).reshape(v.shape[0], ws*ws, -1) + + # reverse + x = window_reverse(x, (H, W), (ws, ws)) + return x + + def forward_global_broadcast(self, q, k, v): + """ + q: image tokens + k: global tokens + v: global tokens + """ + B, num_heads, N, _ = q.shape + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + return x + + def forward(self, x, H, W): + B, N, C = x.shape + NT = self.num_tokens + # qkv + qkv = self.qkv(x) + q, k, v = qkv.view(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).unbind(0) + + # split img tokens & global tokens + q_img, k_img, v_img = q[:, :, NT:], k[:, :, NT:], v[:, :, NT:] + q_glb, _, _ = q[:, :, :NT], k[:, :, :NT], v[:, :, :NT] + + # local window attention + x_img = self.forward_local(q_img, k_img, v_img, H, W) + + # global aggregation + x_glb = self.forward_global_aggregation(q_glb, k_img, v_img) + + # global broadcast + k_glb, v_glb = self.kv_global(x_glb).view(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).unbind(0) + + x_img = x_img + self.forward_global_broadcast(q_img, k_glb, v_glb) + x = torch.cat([x_glb, x_img], dim=1) + x = self.proj(x) + return x + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, num_tokens=1, window_size=7, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, attention=LightViTAttention): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = attention(dim, num_heads=num_heads, num_tokens=num_tokens, window_size=window_size, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = BiAttnMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, H, W): + x = x + self.drop_path(self.attn(self.norm1(x), H, W)) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class ResidualMergePatch(nn.Module): + def __init__(self, dim, out_dim, num_tokens=1): + super().__init__() + self.num_tokens = num_tokens + self.norm = nn.LayerNorm(4 * dim) + self.reduction = nn.Linear(4 * dim, out_dim, bias=False) + self.norm2 = nn.LayerNorm(dim) + self.proj = nn.Linear(dim, out_dim, bias=False) + # use MaxPool3d to avoid permutations + self.maxp = nn.MaxPool3d((2, 2, 1), (2, 2, 1)) + self.res_proj = nn.Linear(dim, out_dim, bias=False) + + def forward(self, x, H, W): + global_token, x = x[:, :self.num_tokens].contiguous(), x[:, self.num_tokens:].contiguous() + B, L, C = x.shape + + x = x.view(B, H, W, C) + res = self.res_proj(self.maxp(x).view(B, -1, C)) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + x = x + res + global_token = self.proj(self.norm2(global_token)) + x = torch.cat([global_token, x], 1) + return x, (H // 2, W // 2) + + +class LightViT(nn.Module): + + def __init__(self, img_size=224, patch_size=8, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256], num_layers=[2, 6, 6], + num_heads=[2, 4, 8], mlp_ratios=[8, 4, 4], num_tokens=8, window_size=7, neck_dim=1280, qkv_bias=True, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=ConvStem, norm_layer=None, + act_layer=None, weight_init=''): + super().__init__() + self.num_classes = num_classes + self.embed_dims = embed_dims + self.num_tokens = num_tokens + self.mlp_ratios = mlp_ratios + self.patch_size = patch_size + self.num_layers = num_layers + self.window_size = window_size + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + act_layer = act_layer or nn.GELU + + self.patch_embed = embed_layer( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dims[0]) + + self.global_token = nn.Parameter(torch.zeros(1, self.num_tokens, embed_dims[0])) + + stages = [] + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(num_layers))] # stochastic depth decay rule + for stage, (embed_dim, num_layer, num_head, mlp_ratio) in enumerate(zip(embed_dims, num_layers, num_heads, mlp_ratios)): + blocks = [] + if stage > 0: + # downsample + blocks.append(ResidualMergePatch(embed_dims[stage-1], embed_dim, num_tokens=num_tokens)) + blocks += [ + Block( + dim=embed_dim, num_heads=num_head, num_tokens=num_tokens, window_size=window_size, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, + attn_drop=attn_drop_rate, drop_path=dpr[sum(num_layers[:stage]) + i], norm_layer=norm_layer, act_layer=act_layer, attention=LightViTAttention) + for i in range(num_layer) + ] + blocks = nn.Sequential(*blocks) + stages.append(blocks) + self.stages = nn.Sequential(*stages) + + self.norm = norm_layer(embed_dim) + + self.neck = nn.Sequential( + nn.Linear(embed_dim, neck_dim), + nn.LayerNorm(neck_dim), + nn.GELU() + ) + + self.head = nn.Linear(neck_dim, num_classes) if num_classes > 0 else nn.Identity() + self.init_weights(weight_init) + + def init_weights(self, mode=''): + assert mode 
in ('jax', 'jax_nlhb', 'nlhb', '') + head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. + if mode.startswith('jax'): + # leave cls token as zeros to match jax impl + named_apply(partial(_init_vit_weights, head_bias=head_bias, jax_impl=True), self) + else: + trunc_normal_(self.global_token, std=.02) + self.apply(_init_vit_weights) + + def _init_weights(self, m): + # this fn left here for compat with downstream users + _init_vit_weights(m) + + @torch.jit.ignore + def no_weight_decay(self): + return {'global_token', '[g]relative_position_bias_table'} + + def forward_features(self, x): + x, (H, W) = self.patch_embed(x) + global_token = self.global_token.expand(x.shape[0], -1, -1) + x = torch.cat((global_token, x), dim=1) + for stage in self.stages: + for block in stage: + if isinstance(block, ResidualMergePatch): + x, (H, W) = block(x, H, W) + elif isinstance(block, Block): + x = block(x, H, W) + else: + x = block(x) + x = self.norm(x) + x = self.neck(x) + return x.mean(1) + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x + + def flops(self, input_shape=(3, 224, 224)): + flops = 0 + ws = self.window_size + # stem + from lib.utils.measure import get_flops + flops += get_flops(self.patch_embed, input_shape) + H = input_shape[1] // self.patch_size + W = input_shape[2] // self.patch_size + N = self.num_tokens + H * W + # blocks + for stage in range(len(self.stages)): + embed_dim = self.embed_dims[stage] + if stage > 0: + # merge patch + # mp - reduction + flops += (H // 2) * (W // 2) * self.embed_dims[stage-1] * (4 * embed_dim) + # mp - residual + flops += (H // 2) * (W // 2) * self.embed_dims[stage-1] * embed_dim + # mp - cls proj + flops += self.num_tokens * self.embed_dims[stage-1] * embed_dim + H, W = H // 2, W // 2 + N = H * W + self.num_tokens + + for i in range(self.num_layers[stage]): + # attn - qkv (img & glb) + flops += N * embed_dim * embed_dim * 3 + # local window self-attn + flops += (H // ws) * (W // ws) * (ws * ws) * embed_dim * 2 + # global aggregation + flops += (H * W) * self.num_tokens * embed_dim * 2 + # global broadcast + flops += (H * W) * self.num_tokens * embed_dim * 2 + # attn - proj + flops += N * embed_dim * embed_dim + + # FFN - mlp + flops += (N * embed_dim * (embed_dim * self.mlp_ratios[stage])) * 2 + # FFN - biattn + attn_ratio = 0.25 + # c attn + flops += embed_dim * embed_dim * attn_ratio * 2 + # s attn + flops += N * embed_dim * embed_dim * attn_ratio + N * embed_dim * attn_ratio * 2 * 1 + # dot product + flops += N * embed_dim + # neck + neck_dim = self.neck[0].out_features + flops += N * embed_dim * neck_dim + # head + flops += neck_dim * 1000 + return flops + + +def _init_vit_weights(module: nn.Module, name: str = '', head_bias: float = 0., jax_impl: bool = False): + """ ViT weight initialization + * When called without n, head_bias, jax_impl args it will behave exactly the same + as my original init for compatibility with prev hparam / downstream use cases (ie DeiT). 
+ * When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl + """ + if isinstance(module, nn.Linear): + if name.startswith('head'): + nn.init.zeros_(module.weight) + nn.init.constant_(module.bias, head_bias) + elif name.startswith('pre_logits'): + lecun_normal_(module.weight) + nn.init.zeros_(module.bias) + else: + if jax_impl: + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + if 'mlp' in name: + nn.init.normal_(module.bias, std=1e-6) + else: + nn.init.zeros_(module.bias) + else: + trunc_normal_(module.weight, std=.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif jax_impl and isinstance(module, nn.Conv2d): + # NOTE conv was left to pytorch default in my original init + lecun_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)): + nn.init.zeros_(module.bias) + nn.init.ones_(module.weight) + + +@register_model +def lightvit_tiny(pretrained=False, **kwargs): + model_kwargs = dict(patch_size=8, embed_dims=[64, 128, 256], num_layers=[2, 6, 6], + num_heads=[2, 4, 8, ], mlp_ratios=[8, 4, 4], num_tokens=8, **kwargs) + model = LightViT(**model_kwargs) + return model + + +@register_model +def lightvit_small(pretrained=False, **kwargs): + model_kwargs = dict(patch_size=8, embed_dims=[96, 192, 384], num_layers=[2, 6, 6], + num_heads=[3, 6, 12, ], mlp_ratios=[8, 4, 4], num_tokens=16, **kwargs) + model = LightViT(**model_kwargs) + return model + + +@register_model +def lightvit_base(pretrained=False, **kwargs): + model_kwargs = dict(patch_size=8, embed_dims=[128, 256, 512], num_layers=[3, 8, 6], + num_heads=[4, 8, 16, ], mlp_ratios=[8, 4, 4], num_tokens=24, **kwargs) + model = LightViT(**model_kwargs) + return model diff --git a/classification/lib/models/local_vim.py b/classification/lib/models/local_vim.py new file mode 100644 index 0000000..08f04c6 --- /dev/null +++ b/classification/lib/models/local_vim.py @@ -0,0 +1,652 @@ +import torch +import torch.nn as nn +from functools import partial +from torch import Tensor +from typing import Optional + +from timm.models.vision_transformer import _cfg +from timm.models.registry import register_model +from timm.models.layers import trunc_normal_ + +from timm.models.layers import DropPath, PatchEmbed +from timm.models.vision_transformer import _load_weights + +import math + +from .mamba.multi_mamba import MultiMamba + +from .mamba.rope import * + +from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn + + + +class Block(nn.Module): + def __init__( + self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False,drop_path=0., + ): + """ + Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection" + + This Block has a slightly different structure compared to a regular + prenorm Transformer block. + The standard block is: LN -> MHA/MLP -> Add. + [Ref: https://arxiv.org/abs/2002.04745] + Here we have: Add -> LN -> Mixer, returning both + the hidden_states (output of the mixer) and the residual. + This is purely for performance reasons, as we can fuse add and LayerNorm. + The residual needs to be provided (except for the very first block). + """ + super().__init__() + self.residual_in_fp32 = residual_in_fp32 + self.fused_add_norm = fused_add_norm + self.mixer = mixer_cls(dim) + self.norm = norm_cls(dim) + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + if self.fused_add_norm: + assert RMSNorm is not None, "RMSNorm import fails" + assert isinstance( + self.norm, (nn.LayerNorm, RMSNorm) + ), "Only LayerNorm and RMSNorm are supported for fused_add_norm" + + def forward( + self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None + ): + r"""Pass the input through the encoder layer. + + Args: + hidden_states: the sequence to the encoder layer (required). + residual: hidden_states = Mixer(LN(residual)) + """ + if not self.fused_add_norm: + if residual is None: + residual = hidden_states + else: + residual = residual + self.drop_path(hidden_states) + + hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype)) + if self.residual_in_fp32: + residual = residual.to(torch.float32) + else: + fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn + if residual is None: + hidden_states, residual = fused_add_norm_fn( + hidden_states, + self.norm.weight, + self.norm.bias, + residual=residual, + prenorm=True, + residual_in_fp32=self.residual_in_fp32, + eps=self.norm.eps, + ) + else: + hidden_states, residual = fused_add_norm_fn( + self.drop_path(hidden_states), + self.norm.weight, + self.norm.bias, + residual=residual, + prenorm=True, + residual_in_fp32=self.residual_in_fp32, + eps=self.norm.eps, + ) + hidden_states = self.mixer(hidden_states, inference_params=inference_params) + return hidden_states, residual + + +def create_block( + d_model, + ssm_cfg=None, + norm_epsilon=1e-5, + drop_path=0., + rms_norm=False, + residual_in_fp32=False, + fused_add_norm=False, + layer_idx=None, + device=None, + dtype=None, + bimamba_type="none", + directions=None, + token_size=(14, 14), + mamba_cls=None, +): + if ssm_cfg is None: + ssm_cfg = {} + factory_kwargs = {"device": device, "dtype": dtype} + mixer_cls = partial(mamba_cls, layer_idx=layer_idx, bimamba_type=bimamba_type, directions=directions, token_size=token_size, **ssm_cfg, **factory_kwargs) + norm_cls = partial( + nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs + ) + block = Block( + d_model, + mixer_cls, + norm_cls=norm_cls, + drop_path=drop_path, + fused_add_norm=fused_add_norm, + residual_in_fp32=residual_in_fp32, + ) + block.layer_idx = layer_idx + return block + + +# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454 +def _init_weights( + module, + n_layer, + initializer_range=0.02, # Now only used for embedding layer. + rescale_prenorm_residual=True, + n_residuals_per_layer=1, # Change to 2 if we have MLP +): + if isinstance(module, nn.Linear): + if module.bias is not None: + if not getattr(module.bias, "_no_reinit", False): + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, std=initializer_range) + + if rescale_prenorm_residual: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. 
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name in ["out_proj.weight", "fc2.weight"]: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + with torch.no_grad(): + p /= math.sqrt(n_residuals_per_layer * n_layer) + + +def segm_init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + if m.bias is not None: + nn.init.constant_(m.bias, 0) + if m.weight is not None: + nn.init.constant_(m.weight, 1.0) + + +class VisionMamba(nn.Module): + def __init__(self, + img_size=224, + patch_size=16, + depth=24, + embed_dim=192, + channels=3, + num_classes=1000, + ssm_cfg=None, + drop_rate=0., + drop_path_rate=0.1, + norm_epsilon: float = 1e-5, + rms_norm: bool = False, + initializer_cfg=None, + fused_add_norm=False, + residual_in_fp32=False, + device=None, + dtype=None, + ft_seq_len=None, + pt_hw_seq_len=14, + final_pool_type='none', + if_abs_pos_embed=False, + if_rope=False, + if_rope_residual=False, + bimamba_type="none", + if_cls_token=False, + directions=None, + mamba_cls=MultiMamba, + **kwargs): + factory_kwargs = {"device": device, "dtype": dtype} + # add factory_kwargs into kwargs + kwargs.update(factory_kwargs) + super().__init__() + self.residual_in_fp32 = residual_in_fp32 + self.fused_add_norm = fused_add_norm + self.final_pool_type = final_pool_type + self.if_abs_pos_embed = if_abs_pos_embed + self.if_rope = if_rope + self.if_rope_residual = if_rope_residual + self.if_cls_token = if_cls_token + self.num_tokens = 1 if if_cls_token else 0 + self.patch_size = patch_size + + # pretrain parameters + self.num_classes = num_classes + self.d_model = self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=channels, embed_dim=embed_dim, strict_img_size=False, dynamic_img_pad=True) + num_patches = self.patch_embed.num_patches + self.token_size = self.patch_embed.grid_size + + if if_cls_token: + self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) + + if if_abs_pos_embed: + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, self.embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + + if if_rope: + half_head_dim = embed_dim // 2 + if isinstance(img_size, (tuple, list)): + hw_seq_len = img_size[0] // patch_size + else: + hw_seq_len = img_size // patch_size + self.rope = VisionRotaryEmbeddingFast( + dim=half_head_dim, + pt_seq_len=pt_hw_seq_len, + ft_seq_len=hw_seq_len + ) + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + + # TODO: release this comment + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + # import ipdb;ipdb.set_trace() + inter_dpr = [0.0] + dpr + self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. 
else nn.Identity() + # transformer blocks + if directions is None: + directions = [None] * depth + self.layers = nn.ModuleList( + [ + create_block( + embed_dim, + ssm_cfg=ssm_cfg, + norm_epsilon=norm_epsilon, + rms_norm=rms_norm, + residual_in_fp32=residual_in_fp32, + fused_add_norm=fused_add_norm, + layer_idx=i, + bimamba_type=bimamba_type, + drop_path=inter_dpr[i], + directions=directions[i], + token_size=self.token_size, + mamba_cls=mamba_cls, + **factory_kwargs, + ) + for i in range(depth) + ] + ) + + # output head + self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)( + embed_dim, eps=norm_epsilon, **factory_kwargs + ) + + self.pre_logits = nn.Identity() + + # original init + self.apply(segm_init_weights) + self.head.apply(segm_init_weights) + if if_abs_pos_embed: + trunc_normal_(self.pos_embed, std=.02) + + # mamba init + self.apply( + partial( + _init_weights, + n_layer=depth, + **(initializer_cfg if initializer_cfg is not None else {}), + ) + ) + + @torch.jit.ignore + def no_weight_decay(self): + return {"pos_embed", "cls_token", "dist_token"} + + @torch.jit.ignore() + def load_pretrained(self, checkpoint_path, prefix=""): + _load_weights(self, checkpoint_path, prefix) + + def forward_features(self, x, inference_params=None, out_indices=None): + # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py + # with slight modifications to add the dist_token + B, _, H, W = x.shape + x = self.patch_embed(x) + if self.if_cls_token: + cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_token, x), dim=1) + + if self.if_abs_pos_embed: + H, W = math.ceil(H / self.patch_size), math.ceil(W / self.patch_size) + for layer in self.layers: + layer.mixer.multi_scan.token_size = (H, W) + if H != self.token_size[0] or W != self.token_size[1]: + # downstream tasks such as det and seg may have various input resolutions + pos_embed = self.resize_pos_embed(self.pos_embed, (H, W), self.token_size, 'bicubic') + if self.if_rope: + freqs_cos = self.resize_pos_embed(self.rope.freqs_cos.unsqueeze(0), (H, W), self.token_size, 'bicubic')[0] + freqs_sin = self.resize_pos_embed(self.rope.freqs_sin.unsqueeze(0), (H, W), self.token_size, 'bicubic')[0] + else: + pos_embed = self.pos_embed + freqs_cos = None + freqs_sin = None + x = x + pos_embed + x = self.pos_drop(x) + + outs = [] + + # mamba impl + residual = None + hidden_states = x + for layer_idx, layer in enumerate(self.layers): + # rope about + if self.if_rope: + hidden_states = self.rope(hidden_states, freqs_cos=freqs_cos, freqs_sin=freqs_sin) + if residual is not None and self.if_rope_residual: + residual = self.rope(residual, freqs_cos=freqs_cos, freqs_sin=freqs_sin) + + hidden_states, residual = layer( + hidden_states, residual, inference_params=inference_params + ) + + if out_indices is not None and layer_idx in out_indices: + outs.append(hidden_states) + + if out_indices is not None: + assert len(outs) == len(out_indices) + return outs, (H, W) + + if not self.fused_add_norm: + if residual is None: + residual = hidden_states + else: + residual = residual + self.drop_path(hidden_states) + hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype)) + else: + # Set prenorm=False here since we don't need the residual + fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn + hidden_states = fused_add_norm_fn( + self.drop_path(hidden_states), + self.norm_f.weight, + self.norm_f.bias, + 
eps=self.norm_f.eps, + residual=residual, + prenorm=False, + residual_in_fp32=self.residual_in_fp32, + ) + + # return only cls token if it exists + if self.if_cls_token: + return hidden_states[:, 0, :] + + if self.final_pool_type == 'none': + return hidden_states[:, -1, :] + elif self.final_pool_type == 'mean': + return hidden_states.mean(dim=1) + elif self.final_pool_type == 'max': + return hidden_states.max(dim=1) + elif self.final_pool_type == 'all': + return hidden_states + else: + raise NotImplementedError + + def forward(self, x, return_features=False, inference_params=None): + x = self.forward_features(x, inference_params) + if return_features: + return x + x = self.head(x) + return x + + def flops(self, input_shape=(3, 224, 224)): + flops = 0 + from lib.utils.measure import get_flops + flops += get_flops(self.patch_embed, input_shape) + + L = self.patch_embed.num_patches + for layer in self.layers: + # 1 in_proj + flops += layer.mixer.in_proj.in_features * layer.mixer.in_proj.out_features * L + # 2 MambaInnerFnNoOutProj + # 2.1 causual conv1d + flops += (L + layer.mixer.d_conv - 1) * layer.mixer.d_inner * layer.mixer.d_conv + # 2.2 x_proj + flops += L * layer.mixer.x_proj_0.in_features * layer.mixer.x_proj_0.out_features + # 2.3 dt_proj + flops += L * layer.mixer.dt_proj_0.in_features * layer.mixer.dt_proj_0.out_features + # 2.4 selective scan + """ + u: r(B D L) + delta: r(B D L) + A: r(D N) + B: r(B N L) + C: r(B N L) + D: r(D) + z: r(B D L) + delta_bias: r(D), fp32 + """ + D = layer.mixer.d_inner + N = layer.mixer.d_state + for i in range(4): + # flops += 9 * L * D * N + 2 * D * L + # A + flops += D * L * N + # B + flops += D * L * N * 2 + # C + flops += (D * N + D * N) * L + # D + flops += D * L + # Z + flops += D * L + # merge + attn = layer.mixer.attn + flops += attn.global_reduce.in_features * attn.global_reduce.out_features + # flops += attn.local_reduce.in_features * attn.local_reduce.out_features * L + flops += attn.channel_select.in_features * attn.channel_select.out_features + # flops += attn.spatial_select.in_features * attn.spatial_select.out_features * L + # 2.5 out_proj + flops += L * layer.mixer.out_proj.in_features * layer.mixer.out_proj.out_features + # layer norm + flops += L * layer.mixer.out_proj.out_features + + # head + flops += self.embed_dim * 1000 + return flops + + +class Backbone_LocalVisionMamba(VisionMamba): + def __init__(self, out_indices=[4, 9, 14, 19], pretrained_ckpt=None, **kwargs): + super().__init__(**kwargs) + del self.head + del self.norm_f + + self.out_indices = out_indices + for i in range(len(out_indices)): + layer = nn.LayerNorm(self.embed_dim) + layer_name = f'outnorm_{i}' + self.add_module(layer_name, layer) + + self.load_pretrained(pretrained_ckpt) + + + def load_pretrained(self, ckpt): + print(f'Load backbone state dict from {ckpt}') + state_dict = torch.load(ckpt, map_location='cpu')['state_dict'] + if 'pos_embed' in state_dict: + pos_size = int(math.sqrt(state_dict['pos_embed'].shape[1])) + state_dict['pos_embed'] = self.resize_pos_embed( + state_dict['pos_embed'], + self.token_size, + (pos_size, pos_size), + 'bicubic' + ) + if 'rope.freqs_cos' in state_dict: + pos_size = int(math.sqrt(state_dict['rope.freqs_cos'].shape[0])) + state_dict['rope.freqs_cos'] = self.resize_pos_embed( + state_dict['rope.freqs_cos'].unsqueeze(0), + self.token_size, + (pos_size, pos_size), + 'bicubic' + )[0] + if 'rope.freqs_cos' in state_dict: + pos_size = int(math.sqrt(state_dict['rope.freqs_sin'].shape[0])) + state_dict['rope.freqs_sin'] = 
self.resize_pos_embed( + state_dict['rope.freqs_sin'].unsqueeze(0), + self.token_size, + (pos_size, pos_size), + 'bicubic' + )[0] + a, b = self.load_state_dict(state_dict, strict=False) + print(a, b) + + @staticmethod + def resize_pos_embed(pos_embed, input_shpae, pos_shape, mode): + from mmseg.models.utils import resize + """Resize pos_embed weights. + + Resize pos_embed using bicubic interpolate method. + Args: + pos_embed (torch.Tensor): Position embedding weights. + input_shpae (tuple): Tuple for (downsampled input image height, + downsampled input image width). + pos_shape (tuple): The resolution of downsampled origin training + image. + mode (str): Algorithm used for upsampling: + ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` | + ``'trilinear'``. Default: ``'nearest'`` + Return: + torch.Tensor: The resized pos_embed of shape [B, L_new, C] + """ + assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]' + pos_h, pos_w = pos_shape + pos_embed_weight = pos_embed + pos_embed_weight = pos_embed_weight.reshape( + 1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2) + pos_embed_weight = resize( + pos_embed_weight, size=input_shpae, align_corners=False, mode=mode) + pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2) + return pos_embed_weight + + def forward(self, x): + C = self.embed_dim + outs, (H, W) = self.forward_features(x, out_indices=self.out_indices) + outs = [getattr(self, f'outnorm_{i}')(o) for i, o in enumerate(outs)] + outs = [o.view(-1, H, W, C).permute(0, 3, 1, 2).contiguous() for o in outs] + if len(self.out_indices) == 1: + return outs[0] + return outs + + +@register_model +def local_vim_tiny_search(pretrained=False, **kwargs): + directions = None + model = VisionMamba( + patch_size=16, embed_dim=128, depth=20, rms_norm=True, residual_in_fp32=True, fused_add_norm=True, final_pool_type='mean', + if_abs_pos_embed=True, if_rope=True, if_rope_residual=True, bimamba_type="v2", directions=directions, **kwargs) + model.default_cfg = _cfg() + if pretrained: + checkpoint = torch.hub.load_state_dict_from_url( + url="to.do", + map_location="cpu", check_hash=True + ) + model.load_state_dict(checkpoint["model"]) + return model + + +@register_model +def local_vim_tiny(pretrained=False, **kwargs): + directions = ( + ['h', 'v_flip', 'w7', 'w7_flip'], + ['h_flip', 'w2_flip', 'w7', 'w7_flip'], + ['h', 'h_flip', 'v', 'w7'], + ['h', 'h_flip', 'v', 'v_flip'], + ['h', 'h_flip', 'v', 'v_flip'], + ['h', 'h_flip', 'v', 'v_flip'], + ['h_flip', 'v', 'v_flip', 'w7'], + ['h_flip', 'v', 'v_flip', 'w2_flip'], + ['h', 'h_flip', 'v', 'v_flip'], + ['h', 'h_flip', 'v', 'v_flip'], + ['h', 'v', 'v_flip', 'w2'], + ['h', 'v', 'v_flip', 'w2_flip'], + ['h', 'h_flip', 'v_flip', 'w7'], + ['h_flip', 'v', 'v_flip', 'w2'], + ['h', 'h_flip', 'v', 'v_flip'], + ['h', 'v', 'w2', 'w2_flip'], + ['v', 'v_flip', 'w2', 'w7'], + ['h', 'h_flip', 'v', 'w2'], + ['h', 'h_flip', 'w2_flip', 'w7'], + ['v', 'v_flip', 'w2', 'w2_flip'], + ) + model = VisionMamba( + patch_size=16, embed_dim=192, depth=20, rms_norm=True, residual_in_fp32=True, fused_add_norm=True, final_pool_type='mean', + if_abs_pos_embed=True, if_rope=True, if_rope_residual=True, bimamba_type="v2", directions=directions, mamba_cls=MultiMamba, **kwargs) + model.default_cfg = _cfg() + if pretrained: + checkpoint = torch.hub.load_state_dict_from_url( + url="to.do", + map_location="cpu", check_hash=True + ) + model.load_state_dict(checkpoint["model"]) + return model + +@register_model +def local_vim_small(pretrained=False, 
**kwargs): + directions = ( + ['h', 'v_flip', 'w7', 'w7_flip'], + ['h_flip', 'w2_flip', 'w7', 'w7_flip'], + ['h', 'h_flip', 'v', 'w7'], + ['h', 'h_flip', 'v', 'v_flip'], + ['h', 'h_flip', 'v', 'v_flip'], + ['h', 'h_flip', 'v', 'v_flip'], + ['h_flip', 'v', 'v_flip', 'w7'], + ['h_flip', 'v', 'v_flip', 'w2_flip'], + ['h', 'h_flip', 'v', 'v_flip'], + ['h', 'h_flip', 'v', 'v_flip'], + ['h', 'v', 'v_flip', 'w2'], + ['h', 'v', 'v_flip', 'w2_flip'], + ['h', 'h_flip', 'v_flip', 'w7'], + ['h_flip', 'v', 'v_flip', 'w2'], + ['h', 'h_flip', 'v', 'v_flip'], + ['h', 'v', 'w2', 'w2_flip'], + ['v', 'v_flip', 'w2', 'w7'], + ['h', 'h_flip', 'v', 'w2'], + ['h', 'h_flip', 'w2_flip', 'w7'], + ['v', 'v_flip', 'w2', 'w2_flip'], + ) + model = VisionMamba( + patch_size=16, embed_dim=384, depth=20, rms_norm=True, residual_in_fp32=True, fused_add_norm=True, final_pool_type='mean', if_abs_pos_embed=True, if_rope=True, if_rope_residual=True, bimamba_type="v2", directions=directions, **kwargs) + model.default_cfg = _cfg() + if pretrained: + checkpoint = torch.hub.load_state_dict_from_url( + url="to.do", + map_location="cpu", check_hash=True + ) + model.load_state_dict(checkpoint["model"]) + return model + +@register_model +def local_vim_tiny_wo_search(pretrained=False, **kwargs): + directions = (('h', 'h_flip', 'w2', 'w2_flip'),) * 20 + model = VisionMamba( + patch_size=16, embed_dim=192, depth=20, rms_norm=True, residual_in_fp32=True, fused_add_norm=True, final_pool_type='mean', + if_abs_pos_embed=True, if_rope=True, if_rope_residual=True, bimamba_type="v2", directions=directions, mamba_cls=MultiMamba, **kwargs) + model.default_cfg = _cfg() + if pretrained: + checkpoint = torch.hub.load_state_dict_from_url( + url="to.do", + map_location="cpu", check_hash=True + ) + model.load_state_dict(checkpoint["model"]) + return model + +@register_model +def local_vim_small(pretrained=False, **kwargs): + directions = (('h', 'h_flip', 'w2', 'w2_flip'),) * 20 + model = VisionMamba( + patch_size=16, embed_dim=384, depth=20, rms_norm=True, residual_in_fp32=True, fused_add_norm=True, final_pool_type='mean', if_abs_pos_embed=True, if_rope=True, if_rope_residual=True, bimamba_type="v2", directions=directions, **kwargs) + model.default_cfg = _cfg() + if pretrained: + checkpoint = torch.hub.load_state_dict_from_url( + url="to.do", + map_location="cpu", check_hash=True + ) + model.load_state_dict(checkpoint["model"]) + return model \ No newline at end of file diff --git a/classification/lib/models/losses/__init__.py b/classification/lib/models/losses/__init__.py new file mode 100644 index 0000000..27fae28 --- /dev/null +++ b/classification/lib/models/losses/__init__.py @@ -0,0 +1 @@ +from .cross_entropy import CrossEntropyLabelSmooth, SoftTargetCrossEntropy diff --git a/classification/lib/models/losses/cross_entropy.py b/classification/lib/models/losses/cross_entropy.py new file mode 100644 index 0000000..d902239 --- /dev/null +++ b/classification/lib/models/losses/cross_entropy.py @@ -0,0 +1,31 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class CrossEntropyLabelSmooth(nn.Module): + + def __init__(self, num_classes, epsilon=0.1): + super(CrossEntropyLabelSmooth, self).__init__() + self.num_classes = num_classes + self.epsilon = epsilon + self.logsoftmax = nn.LogSoftmax(dim=1) + + def forward(self, inputs, targets): + log_probs = self.logsoftmax(inputs) + targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1) + targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes + loss 
= (-targets * log_probs).mean(0).sum() + return loss + + +class SoftTargetCrossEntropy(nn.Module): + + def __init__(self): + super(SoftTargetCrossEntropy, self).__init__() + + def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + if target.dtype == torch.int64: + target = F.one_hot(target, x.shape[-1]).float() + loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1) + return loss.mean() diff --git a/classification/lib/models/losses/diffkd/__init__.py b/classification/lib/models/losses/diffkd/__init__.py new file mode 100644 index 0000000..a07c18b --- /dev/null +++ b/classification/lib/models/losses/diffkd/__init__.py @@ -0,0 +1 @@ +from .diffkd import DiffKD diff --git a/classification/lib/models/losses/diffkd/diffkd.py b/classification/lib/models/losses/diffkd/diffkd.py new file mode 100644 index 0000000..efdc301 --- /dev/null +++ b/classification/lib/models/losses/diffkd/diffkd.py @@ -0,0 +1,79 @@ +import torch +from torch import nn +import torch.nn.functional as F +from .diffkd_modules import DiffusionModel, NoiseAdapter, AutoEncoder, DDIMPipeline +from .scheduling_ddim import DDIMScheduler + + +class DiffKD(nn.Module): + def __init__( + self, + student_channels, + teacher_channels, + kernel_size=3, + inference_steps=5, + num_train_timesteps=1000, + use_ae=False, + ae_channels=None, + ): + super().__init__() + self.use_ae = use_ae + self.diffusion_inference_steps = inference_steps + # AE for compress teacher feature + if use_ae: + if ae_channels is None: + ae_channels = teacher_channels // 2 + self.ae = AutoEncoder(teacher_channels, ae_channels) + teacher_channels = ae_channels + + # transform student feature to the same dimension as teacher + self.trans = nn.Conv2d(student_channels, teacher_channels, 1) + # diffusion model - predict noise + self.model = DiffusionModel(channels_in=teacher_channels, kernel_size=kernel_size) + self.scheduler = DDIMScheduler(num_train_timesteps=num_train_timesteps, clip_sample=False, beta_schedule="linear") + self.noise_adapter = NoiseAdapter(teacher_channels, kernel_size) + # pipeline for denoising student feature + self.pipeline = DDIMPipeline(self.model, self.scheduler, self.noise_adapter) + self.proj = nn.Sequential(nn.Conv2d(teacher_channels, teacher_channels, 1), nn.BatchNorm2d(teacher_channels)) + + def forward(self, student_feat, teacher_feat): + # project student feature to the same dimension as teacher feature + student_feat = self.trans(student_feat) + + # use autoencoder on teacher feature + if self.use_ae: + hidden_t_feat, rec_t_feat = self.ae(teacher_feat) + rec_loss = F.mse_loss(teacher_feat, rec_t_feat) + teacher_feat = hidden_t_feat.detach() + else: + rec_loss = None + + # denoise student feature + refined_feat = self.pipeline( + batch_size=student_feat.shape[0], + device=student_feat.device, + dtype=student_feat.dtype, + shape=student_feat.shape[1:], + feat=student_feat, + num_inference_steps=self.diffusion_inference_steps, + proj=self.proj + ) + refined_feat = self.proj(refined_feat) + + # train diffusion model + ddim_loss = self.ddim_loss(teacher_feat) + return refined_feat, teacher_feat, ddim_loss, rec_loss + + def ddim_loss(self, gt_feat): + # Sample noise to add to the images + noise = torch.randn(gt_feat.shape, device=gt_feat.device) #.to(gt_feat.device) + bs = gt_feat.shape[0] + + # Sample a random timestep for each image + timesteps = torch.randint(0, self.scheduler.num_train_timesteps, (bs,), device=gt_feat.device).long() + # Add noise to the clean images according to the noise magnitude at each 
timestep + # (this is the forward diffusion process) + noisy_images = self.scheduler.add_noise(gt_feat, noise, timesteps) + noise_pred = self.model(noisy_images, timesteps) + loss = F.mse_loss(noise_pred, noise) + return loss diff --git a/classification/lib/models/losses/diffkd/diffkd_modules.py b/classification/lib/models/losses/diffkd/diffkd_modules.py new file mode 100644 index 0000000..c352297 --- /dev/null +++ b/classification/lib/models/losses/diffkd/diffkd_modules.py @@ -0,0 +1,153 @@ +import torch +import torch.nn as nn + + +class NoiseAdapter(nn.Module): + def __init__(self, channels, kernel_size=3): + super().__init__() + if kernel_size == 3: + self.feat = nn.Sequential( + Bottleneck(channels, channels, reduction=8), + nn.AdaptiveAvgPool2d(1) + ) + else: + self.feat = nn.Sequential( + nn.Conv2d(channels, channels * 2, 1), + nn.BatchNorm2d(channels * 2), + nn.ReLU(inplace=True), + nn.Conv2d(channels * 2, channels, 1), + nn.BatchNorm2d(channels), + ) + self.pred = nn.Linear(channels, 2) + + def forward(self, x): + x = self.feat(x).flatten(1) + x = self.pred(x).softmax(1)[:, 0] + return x + + +class DiffusionModel(nn.Module): + def __init__(self, channels_in, kernel_size=3): + super().__init__() + self.kernel_size = kernel_size + self.time_embedding = nn.Embedding(1280, channels_in) + + if kernel_size == 3: + self.pred = nn.Sequential( + Bottleneck(channels_in, channels_in), + Bottleneck(channels_in, channels_in), + nn.Conv2d(channels_in, channels_in, 1), + nn.BatchNorm2d(channels_in) + ) + else: + self.pred = nn.Sequential( + nn.Conv2d(channels_in, channels_in * 4, 1), + nn.BatchNorm2d(channels_in * 4), + nn.ReLU(inplace=True), + nn.Conv2d(channels_in * 4, channels_in, 1), + nn.BatchNorm2d(channels_in), + nn.Conv2d(channels_in, channels_in * 4, 1), + nn.BatchNorm2d(channels_in * 4), + nn.ReLU(inplace=True), + nn.Conv2d(channels_in * 4, channels_in, 1) + ) + + def forward(self, noisy_image, t): + if t.dtype != torch.long: + t = t.type(torch.long) + feat = noisy_image + feat = feat + self.time_embedding(t)[..., None, None] + ret = self.pred(feat) + return ret + + +class AutoEncoder(nn.Module): + def __init__(self, channels, latent_channels): + super().__init__() + self.encoder = nn.Sequential( + nn.Conv2d(channels, latent_channels, 1, padding=0), + nn.BatchNorm2d(latent_channels) + ) + self.decoder = nn.Sequential( + nn.Conv2d(latent_channels, channels, 1, padding=0), + ) + + def forward(self, x): + hidden = self.encoder(x) + out = self.decoder(hidden) + return hidden, out + + def forward_encoder(self, x): + return self.encoder(x) + + +class DDIMPipeline: + ''' + Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/ddim/pipeline_ddim.py + ''' + + def __init__(self, model, scheduler, noise_adapter=None, solver='ddim'): + super().__init__() + self.model = model + self.scheduler = scheduler + self.noise_adapter = noise_adapter + self._iter = 0 + self.solver = solver + + def __call__( + self, + batch_size, + device, + dtype, + shape, + feat, + generator = None, + eta: float = 0.0, + num_inference_steps: int = 50, + proj = None + ): + + # Sample gaussian noise to begin loop + image_shape = (batch_size, *shape) + + if self.noise_adapter is not None: + noise = torch.randn(image_shape, device=device, dtype=dtype) + timesteps = self.noise_adapter(feat) + image = self.scheduler.add_noise_diff2(feat, noise, timesteps) + else: + image = feat + + # set step values + self.scheduler.set_timesteps(num_inference_steps*2) + + for t in 
self.scheduler.timesteps[len(self.scheduler.timesteps)//2:]: + noise_pred = self.model(image, t.to(device)) + + # 2. predict previous mean of image x_t-1 and add variance depending on eta + # eta corresponds to η in paper and should be between [0, 1] + # do x_t -> x_t-1 + image = self.scheduler.step( + noise_pred, t, image, eta=eta, use_clipped_model_output=True, generator=generator + )['prev_sample'] + + self._iter += 1 + return image + + +class Bottleneck(nn.Module): + def __init__(self, in_channels, out_channels, reduction=4): + super().__init__() + self.block = nn.Sequential( + nn.Conv2d(in_channels, in_channels // reduction, 1), + nn.BatchNorm2d(in_channels // reduction), + nn.ReLU(inplace=True), + nn.Conv2d(in_channels // reduction, in_channels // reduction, 3, padding=1), + nn.BatchNorm2d(in_channels // reduction), + nn.ReLU(inplace=True), + nn.Conv2d(in_channels // reduction, out_channels, 1), + nn.BatchNorm2d(out_channels), + ) + + def forward(self, x): + out = self.block(x) + return out + x diff --git a/classification/lib/models/losses/diffkd/scheduling_ddim.py b/classification/lib/models/losses/diffkd/scheduling_ddim.py new file mode 100644 index 0000000..9b76037 --- /dev/null +++ b/classification/lib/models/losses/diffkd/scheduling_ddim.py @@ -0,0 +1,456 @@ +# modified from https://raw.githubusercontent.com/huggingface/diffusers/main/src/diffusers/schedulers/scheduling_ddim.py + +# Copyright 2022 Stanford University Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion +# and https://github.com/hojonathanho/diffusion + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from typing import Dict +import functools +import inspect +from types import SimpleNamespace + + +def register_to_config(init): + r""" + Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are + automatically sent to `self.register_for_config`. To ignore a specific argument accepted by the init but that + shouldn't be registered in the config, use the `ignore_for_config` class variable + Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init! + """ + + @functools.wraps(init) + def inner_init(self, *args, **kwargs): + # Ignore private kwargs in the init. 
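+        # "_"-prefixed kwargs are collected into `config_init_kwargs` and recorded in the config
+        # only; all other kwargs are forwarded to the wrapped __init__. Positional args are
+        # matched to parameter names via `inspect.signature` so they are registered as well.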
+ init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")} + config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")} + + ignore = getattr(self, "ignore_for_config", []) + # Get positional arguments aligned with kwargs + new_kwargs = {} + signature = inspect.signature(init) + parameters = { + name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore + } + for arg, name in zip(args, parameters.keys()): + new_kwargs[name] = arg + + # Then add all kwargs + new_kwargs.update( + { + k: init_kwargs.get(k, default) + for k, default in parameters.items() + if k not in ignore and k not in new_kwargs + } + ) + new_kwargs = {**config_init_kwargs, **new_kwargs} + getattr(self, "register_to_config")(**new_kwargs) + init(self, *args, **init_kwargs) + + return inner_init + + +def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor: + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + + def alpha_bar(time_step): + return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return torch.tensor(betas) + + +class DDIMScheduler(): + config_name = "scheduler_config.json" + _deprecated_kwargs = ["predict_epsilon"] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + clip_sample: bool = False, + set_alpha_to_one: bool = True, + steps_offset: int = 0, + prediction_type: str = "epsilon", + **kwargs, + ): + message = ( + "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler =" + " DDIMScheduler.from_pretrained(, prediction_type='epsilon')`." + ) + predict_epsilon = kwargs.get('predict_epsilon', None) + if predict_epsilon is not None: + self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample") + + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. 
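+            # values are spaced linearly in sqrt(beta) and then squared, i.e.
+            # beta_t = (sqrt(beta_start) + t / (N - 1) * (sqrt(beta_end) - sqrt(beta_start))) ** 2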
+ self.betas = ( + torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + ) + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + # At every step in ddim, we are looking into the previous alphas_cumprod + # For the final step, there is no previous alphas_cumprod because we are already at 0 + # `set_alpha_to_one` decides whether we set this parameter simply to one or + # whether we use the final alpha of the "non-previous" one. + self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # setable values + self.num_inference_steps = None + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)) + + def register_to_config(self, **kwargs): + if self.config_name is None: + raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`") + # Special case for `kwargs` used in deprecation warning added to schedulers + # TODO: remove this when we remove the deprecation warning, and the `kwargs` argument, + # or solve in a more general way. + kwargs.pop("kwargs", None) + for key, value in kwargs.items(): + try: + setattr(self, key, value) + except AttributeError as err: + print(f"Can't set {key} with value {value} for {self}") + raise err + + if not hasattr(self, "_internal_dict"): + internal_dict = kwargs + else: + previous_dict = dict(self._internal_dict) + internal_dict = {**self._internal_dict, **kwargs} + print(f"Updating config from {previous_dict} to {internal_dict}") + + self._internal_dict = internal_dict + + @property + def config(self): + """ + Returns the config of the class as a frozen dictionary + Returns: + `Dict[str, Any]`: Config of the class. + """ + return SimpleNamespace(**self._internal_dict) + + def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.FloatTensor`): input sample + timestep (`int`, optional): current timestep + + Returns: + `torch.FloatTensor`: scaled input sample + """ + return sample + + def _get_variance(self, timestep, prev_timestep): + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) + + return variance + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. 
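+            device (`str` or `torch.device`, optional):
+                the device to which the generated timesteps should be moved.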
+ """ + self.num_inference_steps = num_inference_steps + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) + self.timesteps = torch.from_numpy(timesteps).to(device) + self.timesteps += self.config.steps_offset + + def step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + eta: float = 0.0, + use_clipped_model_output: bool = False, + generator=None, + variance_noise: Optional[torch.FloatTensor] = None, + return_dict: bool = True, + ) -> Union[Dict, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + eta (`float`): weight of noise for added noise in diffusion step. + use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped + predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when + `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would + coincide with the one provided as input and `use_clipped_model_output` will have not effect. + generator: random number generator. + variance_noise (`torch.FloatTensor`): instead of generating noise for the variance using `generator`, we + can directly provide the noise for the variance itself. This is useful for methods such as + CycleDiffusion. (https://arxiv.org/abs/2210.05559) + return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class + + Returns: + [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf + # Ideally, read DDIM paper in-detail understanding + + # Notation ( -> + # - pred_noise_t -> e_theta(x_t, t) + # - pred_original_sample -> f_theta(x_t, t) or x_0 + # - std_dev_t -> sigma_t + # - eta -> η + # - pred_sample_direction -> "direction pointing to x_t" + # - pred_prev_sample -> "x_t-1" + + # 1. get previous step value (=t-1) + prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps + + # 2. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + + beta_prod_t = 1 - alpha_prod_t + + # 3. 
compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + if self.config.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + elif self.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + # predict V + model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction`" + ) + + # 4. Clip "predicted x_0" + if self.config.clip_sample: + pred_original_sample = torch.clamp(pred_original_sample, -1, 1) + + # 5. compute variance: "sigma_t(η)" -> see formula (16) + # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) + variance = self._get_variance(timestep, prev_timestep) + std_dev_t = eta * variance ** (0.5) + + if use_clipped_model_output: + # the model_output is always re-derived from the clipped x_0 in Glide + model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + + # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output + + # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + + if eta > 0: + # randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072 + device = model_output.device + if variance_noise is not None and generator is not None: + raise ValueError( + "Cannot pass both generator and variance_noise. Please make sure that either `generator` or" + " `variance_noise` stays `None`." 
+ ) + + if variance_noise is None: + if device.type == "mps": + # randn does not work reproducibly on mps + variance_noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator) + variance_noise = variance_noise.to(device) + else: + variance_noise = torch.randn( + model_output.shape, generator=generator, device=device, dtype=model_output.dtype + ) + variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * variance_noise + + prev_sample = prev_sample + variance + + if not return_dict: + return (prev_sample,) + + return dict(prev_sample=prev_sample, pred_original_sample=pred_original_sample) + + + def add_noise_diff2( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + alpha_prod: torch.FloatTensor, + ) -> torch.FloatTensor: + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + #self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + #timesteps = timesteps.to(original_samples.device) + + #sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 + #alpha_prod = (self.alphas_cumprod.unsqueeze(0) * timesteps).sum(1) + #sqrt_alpha_prod = (alpha_prod + 1e-6) ** 0.5 + sqrt_alpha_prod = alpha_prod + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + #sqrt_one_minus_alpha_prod = (1 - alpha_prod + 1e-6) ** 0.5 + sqrt_one_minus_alpha_prod = 1 - alpha_prod + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + + def add_noise_diff( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + #sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 + alpha_prod = (self.alphas_cumprod.unsqueeze(0) * timesteps).sum(1) + sqrt_alpha_prod = alpha_prod ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alpha_prod) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + 
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + def get_velocity( + self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor + ) -> torch.FloatTensor: + # Make sure alphas_cumprod and timestep have same device and dtype as sample + self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype) + timesteps = timesteps.to(sample.device) + + sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(sample.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample + return velocity + + def __len__(self): + return self.config.num_train_timesteps + + + diff --git a/classification/lib/models/losses/dist_kd.py b/classification/lib/models/losses/dist_kd.py new file mode 100644 index 0000000..4ae9aa0 --- /dev/null +++ b/classification/lib/models/losses/dist_kd.py @@ -0,0 +1,34 @@ +import torch.nn as nn + + +def cosine_similarity(a, b, eps=1e-8): + return (a * b).sum(1) / (a.norm(dim=1) * b.norm(dim=1) + eps) + + +def pearson_correlation(a, b, eps=1e-8): + return cosine_similarity(a - a.mean(1).unsqueeze(1), + b - b.mean(1).unsqueeze(1), eps) + + +def inter_class_relation(y_s, y_t): + return 1 - pearson_correlation(y_s, y_t).mean() + + +def intra_class_relation(y_s, y_t): + return inter_class_relation(y_s.transpose(0, 1), y_t.transpose(0, 1)) + + +class DIST(nn.Module): + def __init__(self, beta=1.0, gamma=1.0, tau=1.0): + super(DIST, self).__init__() + self.beta = beta + self.gamma = gamma + self.tau = tau + + def forward(self, z_s, z_t): + y_s = (z_s / self.tau).softmax(dim=1) + y_t = (z_t / self.tau).softmax(dim=1) + inter_loss = self.tau**2 * inter_class_relation(y_s, y_t) + intra_loss = self.tau**2 * intra_class_relation(y_s, y_t) + kd_loss = self.beta * inter_loss + self.gamma * intra_loss + return kd_loss diff --git a/classification/lib/models/losses/kd_loss.py b/classification/lib/models/losses/kd_loss.py new file mode 100644 index 0000000..3e1935d --- /dev/null +++ b/classification/lib/models/losses/kd_loss.py @@ -0,0 +1,198 @@ +import math +import torch +import torch.nn as nn +from functools import partial + +from .kl_div import KLDivergence +from .dist_kd import DIST +from .diffkd import DiffKD + +import logging +logger = logging.getLogger() + + + +KD_MODULES = { + 'cifar_wrn_40_1': dict(modules=['relu', 'fc'], channels=[64, 100]), + 'cifar_wrn_40_2': dict(modules=['relu', 'fc'], channels=[128, 100]), + 'cifar_resnet56': dict(modules=['layer3', 'fc'], channels=[64, 100]), + 'cifar_resnet20': dict(modules=['layer3', 'fc'], channels=[64, 100]), + 'tv_resnet50': dict(modules=['layer4', 'fc'], channels=[2048, 1000]), + 'tv_resnet34': dict(modules=['layer4', 'fc'], channels=[512, 1000]), + 'tv_resnet18': 
dict(modules=['layer4', 'fc'], channels=[512, 1000]), + 'resnet18': dict(modules=['layer4', 'fc'], channels=[512, 1000]), + 'tv_mobilenet_v2': dict(modules=['features.18', 'classifier'], channels=[1280, 1000]), + 'nas_model': dict(modules=['features.conv_out', 'classifier'], channels=[1280, 1000]), # mbv2 + 'timm_tf_efficientnet_b0': dict(modules=['conv_head', 'classifier'], channels=[1280, 1000]), + 'mobilenet_v1': dict(modules=['model.13', 'fc'], channels=[1024, 1000]), + 'timm_swin_large_patch4_window7_224': dict(modules=['norm', 'head'], channels=[1536, 1000]), + 'timm_swin_tiny_patch4_window7_224': dict(modules=['norm', 'head'], channels=[768, 1000]), +} + + + +class KDLoss(): + ''' + kd loss wrapper. + ''' + + def __init__( + self, + student, + teacher, + student_name, + teacher_name, + ori_loss, + kd_method='kdt4', + ori_loss_weight=1.0, + kd_loss_weight=1.0, + kd_loss_kwargs={} + ): + self.student = student + self.teacher = teacher + self.ori_loss = ori_loss + self.ori_loss_weight = ori_loss_weight + self.kd_method = kd_method + self.kd_loss_weight = kd_loss_weight + + self._teacher_out = None + self._student_out = None + + # init kd loss + # module keys for distillation. '': output logits + teacher_modules = ['',] + student_modules = ['',] + if kd_method == 'kd': + self.kd_loss = KLDivergence(tau=4) + elif kd_method == 'dist': + self.kd_loss = DIST(beta=1, gamma=1, tau=1) + elif kd_method.startswith('dist_t'): + tau = float(kd_method[6:]) + self.kd_loss = DIST(beta=1, gamma=1, tau=tau) + elif kd_method.startswith('kdt'): + tau = float(kd_method[3:]) + self.kd_loss = KLDivergence(tau) + elif kd_method == 'diffkd': + # get configs + ae_channels = kd_loss_kwargs.get('ae_channels', 1024) + use_ae = kd_loss_kwargs.get('use_ae', True) + tau = kd_loss_kwargs.get('tau', 1) + + print(kd_loss_kwargs) + kernel_sizes = [3, 1] # distillation on feature and logits + student_modules = KD_MODULES[student_name]['modules'] + student_channels = KD_MODULES[student_name]['channels'] + teacher_modules = KD_MODULES[teacher_name]['modules'] + teacher_channels = KD_MODULES[teacher_name]['channels'] + self.diff = nn.ModuleDict() + self.kd_loss = nn.ModuleDict() + for tm, tc, sc, ks in zip(teacher_modules, teacher_channels, student_channels, kernel_sizes): + self.diff[tm] = DiffKD(sc, tc, kernel_size=ks, use_ae=(ks!=1) and use_ae, ae_channels=ae_channels) + self.kd_loss[tm] = nn.MSELoss() if ks != 1 else KLDivergence(tau=tau) + self.diff.cuda() + # add diff module to student for optimization + self.student._diff = self.diff + elif kd_method == 'mse': + # distillation on feature + student_modules = KD_MODULES[student_name]['modules'][:1] + student_channels = KD_MODULES[student_name]['channels'][:1] + teacher_modules = KD_MODULES[teacher_name]['modules'][:1] + teacher_channels = KD_MODULES[teacher_name]['channels'][:1] + self.kd_loss = nn.MSELoss() + self.align = nn.Conv2d(student_channels[0], teacher_channels[0], 1) + self.align.cuda() + # add align module to student for optimization + self.student._align = self.align + else: + raise RuntimeError(f'KD method {kd_method} not found.') + + # register forward hook + # dicts that store distillation outputs of student and teacher + self._teacher_out = {} + self._student_out = {} + + for student_module, teacher_module in zip(student_modules, teacher_modules): + self._register_forward_hook(student, student_module, teacher=False) + self._register_forward_hook(teacher, teacher_module, teacher=True) + self.student_modules = student_modules + self.teacher_modules = 
teacher_modules + + teacher.eval() + self._iter = 0 + + def __call__(self, x, targets): + with torch.no_grad(): + t_logits = self.teacher(x) + + # compute ori loss of student + logits = self.student(x) + ori_loss = self.ori_loss(logits, targets) + + kd_loss = 0 + + for tm, sm in zip(self.teacher_modules, self.student_modules): + + # transform student feature + if self.kd_method == 'diffkd': + self._student_out[sm], self._teacher_out[tm], diff_loss, ae_loss = \ + self.diff[tm](self._reshape_BCHW(self._student_out[sm]), self._reshape_BCHW(self._teacher_out[tm])) + if hasattr(self, 'align'): + self._student_out[sm] = self.align(self._student_out[sm]) + + # compute kd loss + if isinstance(self.kd_loss, nn.ModuleDict): + kd_loss_ = self.kd_loss[tm](self._student_out[sm], self._teacher_out[tm]) + else: + kd_loss_ = self.kd_loss(self._student_out[sm], self._teacher_out[tm]) + + if self.kd_method == 'diffkd': + # add additional losses in DiffKD + if ae_loss is not None: + kd_loss += diff_loss + ae_loss + if self._iter % 50 == 0: + logger.info(f'[{tm}-{sm}] KD ({self.kd_method}) loss: {kd_loss_.item():.4f} Diff loss: {diff_loss.item():.4f} AE loss: {ae_loss.item():.4f}') + else: + kd_loss += diff_loss + if self._iter % 50 == 0: + logger.info(f'[{tm}-{sm}] KD ({self.kd_method}) loss: {kd_loss_.item():.4f} Diff loss: {diff_loss.item():.4f}') + else: + if self._iter % 50 == 0: + logger.info(f'[{tm}-{sm}] KD ({self.kd_method}) loss: {kd_loss_.item():.4f}') + kd_loss += kd_loss_ + + self._teacher_out = {} + self._student_out = {} + + self._iter += 1 + return ori_loss * self.ori_loss_weight + kd_loss * self.kd_loss_weight + + def _register_forward_hook(self, model, name, teacher=False): + if name == '': + # use the output of model + model.register_forward_hook(partial(self._forward_hook, name=name, teacher=teacher)) + else: + module = None + for k, m in model.named_modules(): + if k == name: + module = m + break + module.register_forward_hook(partial(self._forward_hook, name=name, teacher=teacher)) + + def _forward_hook(self, module, input, output, name, teacher=False): + if teacher: + self._teacher_out[name] = output[0] if len(output) == 1 else output + else: + self._student_out[name] = output[0] if len(output) == 1 else output + + def _reshape_BCHW(self, x): + """ + Reshape a 2d (B, C) or 3d (B, N, C) tensor to 4d BCHW format. + """ + if x.dim() == 2: + x = x.view(x.shape[0], x.shape[1], 1, 1) + elif x.dim() == 3: + # swin [B, N, C] + B, N, C = x.shape + H = W = int(math.sqrt(N)) + x = x.transpose(-2, -1).reshape(B, C, H, W) + return x \ No newline at end of file diff --git a/classification/lib/models/losses/kl_div.py b/classification/lib/models/losses/kl_div.py new file mode 100644 index 0000000..8ca4104 --- /dev/null +++ b/classification/lib/models/losses/kl_div.py @@ -0,0 +1,54 @@ +import torch.nn as nn +import torch.nn.functional as F + + +class KLDivergence(nn.Module): + """A measure of how one probability distribution Q is different from a + second, reference probability distribution P. + + Args: + tau (float): Temperature coefficient. Defaults to 1.0. + reduction (str): Specifies the reduction to apply to the loss: + ``'none'`` | ``'batchmean'`` | ``'sum'`` | ``'mean'``. + ``'none'``: no reduction will be applied, + ``'batchmean'``: the sum of the output will be divided by + the batchsize, + ``'sum'``: the output will be summed, + ``'mean'``: the output will be divided by the number of + elements in the output. + Default: ``'batchmean'`` + loss_weight (float): Weight of loss. Defaults to 1.0. 
+ """ + + def __init__( + self, + tau=1.0, + reduction='batchmean', + ): + super(KLDivergence, self).__init__() + self.tau = tau + + accept_reduction = {'none', 'batchmean', 'sum', 'mean'} + assert reduction in accept_reduction, \ + f'KLDivergence supports reduction {accept_reduction}, ' \ + f'but gets {reduction}.' + self.reduction = reduction + + def forward(self, preds_S, preds_T): + """Forward computation. + + Args: + preds_S (torch.Tensor): The student model prediction with + shape (N, C, H, W) or shape (N, C). + preds_T (torch.Tensor): The teacher model prediction with + shape (N, C, H, W) or shape (N, C). + + Return: + torch.Tensor: The calculated loss value. + """ + preds_T = preds_T.detach() + softmax_pred_T = F.softmax(preds_T / self.tau, dim=1) + logsoftmax_preds_S = F.log_softmax(preds_S / self.tau, dim=1) + loss = (self.tau**2) * F.kl_div( + logsoftmax_preds_S, softmax_pred_T, reduction=self.reduction) + return loss diff --git a/classification/lib/models/mamba/__init__.py b/classification/lib/models/mamba/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/classification/lib/models/mamba/local_scan.py b/classification/lib/models/mamba/local_scan.py new file mode 100644 index 0000000..24b669d --- /dev/null +++ b/classification/lib/models/mamba/local_scan.py @@ -0,0 +1,261 @@ + +import math +import torch +import torch.nn.functional as F +import triton +import triton.language as tl + + +@triton.jit +def triton_local_scan( + x, # x point (B, C, H, W) or (B, C, L) + y, # y point (B, C, H, W) or (B, C, L) + K: tl.constexpr, # window size + flip: tl.constexpr, # whether to flip the tokens + BC: tl.constexpr, # number of channels in each program + BH: tl.constexpr, # number of heights in each program + BW: tl.constexpr, # number of width in each program + DC: tl.constexpr, # original channels + DH: tl.constexpr, # original height + DW: tl.constexpr, # original width + NH: tl.constexpr, # number of programs on height + NW: tl.constexpr, # number of programs on width +): + i_hw, i_c, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2) # program id of hw axis, c axis, batch axis + i_h, i_w = (i_hw // NW), (i_hw % NW) # program idx of h and w + _mask_h = (i_h * BH + tl.arange(0, BH)) < DH + _mask_w = (i_w * BW + tl.arange(0, BW)) < DW + _mask_hw = _mask_h[:, None] & _mask_w[None, :] # [BH, BW] + _for_C = min(DC - i_c * BC, BC) # valid number of c in the program + + _tmp0 = i_c * BC * DH * DW # start offset of this program + _tmp1 = DC * DH * DW # n_elements in one batch + _tmp2 = _tmp0 + i_h * BH * DW + tl.arange(0, BH)[:, None] * DW + i_w * BW + tl.arange(0, BW)[None, :] # offsets of elements in this program + + p_x = x + i_b * _tmp1 + _tmp2 + + _i = (tl.arange(0, BH) + BH * i_h)[:, None] + _j = (tl.arange(0, BW) + BW * i_w)[None, :] + _c_offset = ((DW // K) * (_i // K) + (_j // K)) * K * K + (_i % K) * K + _j % K + if flip: + _c_offset = DH * DW - _c_offset - 1 + + p_y = y + i_b * _tmp1 + _tmp0 + _c_offset + for idxc in range(_for_C): + _idx = idxc * DH * DW + _x = tl.load(p_x + _idx, mask=_mask_hw) + tl.store(p_y + _idx, _x, mask=_mask_hw) + tl.debug_barrier() + + +@triton.jit +def triton_local_reverse( + x, # x point (B, C, H, W) or (B, C, L) + y, # y point (B, C, H, W) or (B, C, L) + K: tl.constexpr, # window size + flip: tl.constexpr, # whether to flip the tokens + BC: tl.constexpr, # number of channels in each program + BH: tl.constexpr, # number of heights in each program + BW: tl.constexpr, # number of width in each program + DC: tl.constexpr, # original 
channels + DH: tl.constexpr, # original height + DW: tl.constexpr, # original width + NH: tl.constexpr, # number of programs on height + NW: tl.constexpr, # number of programs on width +): + i_hw, i_c, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2) # program id of hw axis, c axis, batch axis + i_h, i_w = (i_hw // NW), (i_hw % NW) # program idx of h and w + _mask_h = (i_h * BH + tl.arange(0, BH)) < DH + _mask_w = (i_w * BW + tl.arange(0, BW)) < DW + _mask_hw = _mask_h[:, None] & _mask_w[None, :] # [BH, BW] + _for_C = min(DC - i_c * BC, BC) # valid number of c in the program + + _tmp0 = i_c * BC * DH * DW # start offset of this program + _tmp1 = DC * DH * DW # n_elements in one batch + _tmp2 = _tmp0 + i_h * BH * DW + tl.arange(0, BH)[:, None] * DW + i_w * BW + tl.arange(0, BW)[None, :] # offsets of elements in this program + + p_x = x + i_b * _tmp1 + _tmp2 + + _i = (tl.arange(0, BH) + BH * i_h)[:, None] + _j = (tl.arange(0, BW) + BW * i_w)[None, :] + _o = _i * DW + _j + + _i = _o // (K * K) // (DW // K) * K + _o % (K * K) // K + _j = _o // (K * K) % (DW // K) * K + _o % (K * K) % K + _c_offset = _i * DW + _j + if flip: + _c_offset = DH * DW - _c_offset - 1 + + p_y = y + i_b * _tmp1 + _tmp0 + _c_offset + for idxc in range(_for_C): + _idx = idxc * DH * DW + _x = tl.load(p_x + _idx, mask=_mask_hw) + tl.store(p_y + _idx, _x, mask=_mask_hw) + tl.debug_barrier() + + +class LocalScanTriton(torch.autograd.Function): + @staticmethod + def forward(ctx, x: torch.Tensor, K: int, flip: bool, H: int = None, W: int = None): + ori_x = x + if len(x.shape) == 4: + B, C, H, W = x.shape + elif len(x.shape) == 3: + B, C, _ = x.shape + assert H is not None and W is not None, "x must be BCHW format to infer the H W" + else: + raise RuntimeError(f"Unsupported shape of x: {x.shape}") + B, C, H, W = int(B), int(C), int(H), int(W) + + ctx.ori_shape = (B, C, H, W) + # pad tensor to make it evenly divisble by window size + x, (H, W) = pad_tensor(x, K, H, W) + ctx.shape = (B, C, H, W) + + BC, BH, BW = min(triton.next_power_of_2(C), 1), min(triton.next_power_of_2(H), 64), min(triton.next_power_of_2(W), 64) + NH, NW, NC = triton.cdiv(H, BH), triton.cdiv(W, BW), triton.cdiv(C, BC) + ctx.triton_shape = (BC, BH, BW, NC, NH, NW) + ctx.K = K + ctx.flip = flip + + if x.stride(-1) != 1: + x = x.contiguous() + + if len(ori_x.shape) == 4: + y = x.new_empty((B, C, H, W)) + elif len(ori_x.shape) == 3: + y = x.new_empty((B, C, H * W)) + + triton_local_scan[(NH * NW, NC, B)](x, y, K, flip, BC, BH, BW, C, H, W, NH, NW) + return y + + @staticmethod + def backward(ctx, y: torch.Tensor): + # out: (b, k, d, l) + B, C, H, W = ctx.shape + BC, BH, BW, NC, NH, NW = ctx.triton_shape + + if y.stride(-1) != 1: + y = y.contiguous() + if len(y.shape) == 4: + x = y.new_empty((B, C, H, W)) + else: + x = y.new_empty((B, C, H * W)) + + triton_local_reverse[(NH * NW, NC, B)](y, x, ctx.K, ctx.flip, BC, BH, BW, C, H, W, NH, NW) + return x, None, None, None, None + + +class LocalReverseTriton(torch.autograd.Function): + @staticmethod + def forward(ctx, x: torch.Tensor, K: int, flip: bool, H: int = None, W: int = None): + if len(x.shape) == 4: + B, C, H, W = x.shape + elif len(x.shape) == 3: + B, C, _ = x.shape + assert H is not None and W is not None, "x must be BCHW format to infer the H W" + else: + raise RuntimeError(f"Unsupported shape of x: {x.shape}") + B, C, H, W = int(B), int(C), int(H), int(W) + + ctx.ori_shape = (B, C, H, W) + # x may have been padded + Hg, Wg = math.ceil(H / K), math.ceil(W / K) + H, W = Hg * K, Wg * K + ctx.shape = 
(B, C, H, W) + + BC, BH, BW = min(triton.next_power_of_2(C), 1), min(triton.next_power_of_2(H), 64), min(triton.next_power_of_2(W), 64) + NH, NW, NC = triton.cdiv(H, BH), triton.cdiv(W, BW), triton.cdiv(C, BC) + ctx.triton_shape = (BC, BH, BW, NC, NH, NW) + ctx.K = K + ctx.flip = flip + + if x.stride(-1) != 1: + x = x.contiguous() + + if len(x.shape) == 4: + y = x.new_empty((B, C, H, W)) + else: + y = x.new_empty((B, C, H * W)) + + triton_local_reverse[(NH * NW, NC, B)](x, y, K, flip, BC, BH, BW, C, H, W, NH, NW) + + if ctx.ori_shape != ctx.shape: + ori_H, ori_W = ctx.ori_shape[-2:] + if len(y.shape) == 3: + y = y.view(B, C, H, W)[:, :, :ori_H, :ori_W].flatten(2) + else: + y = y[:, :, :ori_H, :ori_W] + + return y + + @staticmethod + def backward(ctx, y: torch.Tensor): + # out: (b, k, d, l) + B, C, H, W = ctx.ori_shape + BC, BH, BW, NC, NH, NW = ctx.triton_shape + + x, (H, W) = pad_tensor(x, ctx.K, H, W) + + if y.stride(-1) != 1: + y = y.contiguous() + + if len(y.shape) == 4: + x = y.new_empty((B, C, H, W)) + else: + x = y.new_empty((B, C, H * W)) + + triton_local_scan[(NH * NW, NC, B)](y, x, ctx.K, ctx.flip, BC, BH, BW, C, H, W, NH, NW) + + return x, None, None, None, None + + + +def pad_tensor(x, w, H, W): + if H % w == 0 and W % w == 0: + return x, (H, W) + ori_x = x + B, C = x.shape[:2] + if len(x.shape) == 3: + x = x.view(B, C, H, W) + + Hg, Wg = math.ceil(H / w), math.ceil(W / w) + newH, newW = Hg * w, Wg * w + x = F.pad(x, (0, newW - W, 0, newH - H)) + + # if len(ori_x.shape) == 3: + # x = x.flatten(2) + + return x, (newH, newW) + + +"""PyTorch code for local scan and local reverse""" + +def local_scan(x, w=7, H=14, W=14, h_scan=False): + B, L, C = x.shape + x = x.view(B, H, W, C) + Hg, Wg = math.ceil(H / w), math.ceil(W / w) + if H % w != 0 or W % w != 0: + newH, newW = Hg * w, Wg * w + x = F.pad(x, (0, 0, 0, newW - W, 0, newH - H)) + if h_scan: + x = x.view(B, Hg, w, Wg, w, C).permute(0, 3, 1, 4, 2, 5).reshape(B, -1, C) + else: + x = x.view(B, Hg, w, Wg, w, C).permute(0, 1, 3, 2, 4, 5).reshape(B, -1, C) + return x + +def local_reverse(x, w=7, H=14, W=14, h_scan=False): + B, L, C = x.shape + Hg, Wg = math.ceil(H / w), math.ceil(W / w) + if H % w != 0 or W % w != 0: + if h_scan: + x = x.view(B, Wg, Hg, w, w, C).permute(0, 2, 4, 1, 3, 5).reshape(B, Hg * w, Wg * w, C) + else: + x = x.view(B, Hg, Wg, w, w, C).permute(0, 1, 3, 2, 4, 5).reshape(B, Hg * w, Wg * w, C) + x = x[:, :H, :W].reshape(B, -1, C) + else: + if h_scan: + x = x.view(B, Wg, Hg, w, w, C).permute(0, 2, 4, 1, 3, 5).reshape(B, L, C) + else: + x = x.view(B, Hg, Wg, w, w, C).permute(0, 1, 3, 2, 4, 5).reshape(B, L, C) + return x \ No newline at end of file diff --git a/classification/lib/models/mamba/multi_mamba.py b/classification/lib/models/mamba/multi_mamba.py new file mode 100644 index 0000000..e7c01e3 --- /dev/null +++ b/classification/lib/models/mamba/multi_mamba.py @@ -0,0 +1,480 @@ +# Copyright (c) 2023, Tri Dao, Albert Gu. 
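+# Multi-scan Mamba: MultiScan enumerates scan orders (horizontal, vertical, their flipped versions, and local
+# window scans w2/w7); MultiMamba/MultiVMamba run one selective-scan branch per direction, reverse each scan,
+# gate the branch outputs with BiAttn, and fuse them (learnable softmax weights during search, a plain sum otherwise).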
+ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + +from einops import rearrange, repeat +import logging + + +try: + from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn_no_out_proj +except ImportError: + mamba_inner_fn_no_out_proj = None + +from .local_scan import LocalScanTriton, LocalReverseTriton + +class MultiScan(nn.Module): + + ALL_CHOICES = ('h', 'h_flip', 'v', 'v_flip', 'w2', 'w2_flip', 'w7', 'w7_flip') + + def __init__(self, dim, choices=None, token_size=(14, 14)): + super().__init__() + self.token_size = token_size + if choices is None: + self.choices = MultiScan.ALL_CHOICES + self.norms = nn.ModuleList([nn.LayerNorm(dim, elementwise_affine=False) for _ in self.choices]) + self.weights = nn.Parameter(1e-3 * torch.randn(len(self.choices), 1, 1, 1)) + self._iter = 0 + self.logger = logging.getLogger() + self.search = True + else: + self.choices = choices + self.search = False + + def forward(self, xs): + """ + Input @xs: [[B, L, D], ...] + """ + if self.search: + weights = self.weights.softmax(0) + xs = [norm(x) for norm, x in zip(self.norms, xs)] + xs = torch.stack(xs) * weights + x = xs.sum(0) + if self._iter % 200 == 0: + if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0: + self.logger.info(str(weights.detach().view(-1).tolist())) + self._iter += 1 + else: + x = torch.stack(xs).sum(0) + return x + + def multi_scan(self, x): + """ + Input @x: shape [B, L, D] + """ + xs = [] + for direction in self.choices: + xs.append(self.scan(x, direction)) + return xs + + def multi_reverse(self, xs): + new_xs = [] + for x, direction in zip(xs, self.choices): + new_xs.append(self.reverse(x, direction)) + return new_xs + + def scan(self, x, direction='h'): + """ + Input @x: shape [B, L, D] or [B, C, H, W] + Return torch.Tensor: shape [B, D, L] + """ + H, W = self.token_size + if len(x.shape) == 3: + if direction == 'h': + return x.transpose(-2, -1) + elif direction == 'h_flip': + return x.transpose(-2, -1).flip([-1]) + elif direction == 'v': + return rearrange(x, 'b (h w) d -> b d (w h)', h=H, w=W) + elif direction == 'v_flip': + return rearrange(x, 'b (h w) d -> b d (w h)', h=H, w=W).flip([-1]) + elif direction.startswith('w'): + K = int(direction[1:].split('_')[0]) + flip = direction.endswith('flip') + return LocalScanTriton.apply(x.transpose(-2, -1), K, flip, H, W) + else: + raise RuntimeError(f'Direction {direction} not found.') + elif len(x.shape) == 4: + if direction == 'h': + return x.flatten(2) + elif direction == 'h_flip': + return x.flatten(2).flip([-1]) + elif direction == 'v': + return rearrange(x, 'b d h w -> b d (w h)', h=H, w=W) + elif direction == 'v_flip': + return rearrange(x, 'b d h w -> b d (w h)', h=H, w=W).flip([-1]) + elif direction.startswith('w'): + K = int(direction[1:].split('_')[0]) + flip = direction.endswith('flip') + return LocalScanTriton.apply(x, K, flip, H, W).flatten(2) + else: + raise RuntimeError(f'Direction {direction} not found.') + + def reverse(self, x, direction='h'): + """ + Input @x: shape [B, D, L] + Return torch.Tensor: shape [B, D, L] + """ + H, W = self.token_size + if direction == 'h': + return x + elif direction == 'h_flip': + return x.flip([-1]) + elif direction == 'v': + return rearrange(x, 'b d (h w) -> b d (w h)', h=H, w=W) + elif direction == 'v_flip': + return rearrange(x.flip([-1]), 'b d (h w) -> b d (w h)', h=H, w=W) + elif direction.startswith('w'): + K = int(direction[1:].split('_')[0]) + flip = direction.endswith('flip') + return 
LocalReverseTriton.apply(x, K, flip, H, W) + else: + raise RuntimeError(f'Direction {direction} not found.') + + def __repr__(self): + scans = ', '.join(self.choices) + return super().__repr__().replace(self.__class__.__name__, f'{self.__class__.__name__}[{scans}]') + + +class BiAttn(nn.Module): + def __init__(self, in_channels, act_ratio=0.125, act_fn=nn.GELU, gate_fn=nn.Sigmoid): + super().__init__() + reduce_channels = int(in_channels * act_ratio) + self.norm = nn.LayerNorm(in_channels) + self.global_reduce = nn.Linear(in_channels, reduce_channels) + # self.local_reduce = nn.Linear(in_channels, reduce_channels) + self.act_fn = act_fn() + self.channel_select = nn.Linear(reduce_channels, in_channels) + # self.spatial_select = nn.Linear(reduce_channels * 2, 1) + self.gate_fn = gate_fn() + + def forward(self, x): + ori_x = x + x = self.norm(x) + x_global = x.mean(1, keepdim=True) + x_global = self.act_fn(self.global_reduce(x_global)) + # x_local = self.act_fn(self.local_reduce(x)) + + c_attn = self.channel_select(x_global) + c_attn = self.gate_fn(c_attn) # [B, 1, C] + # s_attn = self.spatial_select(torch.cat([x_local, x_global.expand(-1, x.shape[1], -1)], dim=-1)) + # s_attn = self.gate_fn(s_attn) # [B, N, 1] + + attn = c_attn #* s_attn # [B, N, C] + return ori_x * attn + + +class MultiMamba(nn.Module): + def __init__( + self, + d_model, + d_state=16, + d_conv=4, + expand=2, + dt_rank="auto", + dt_min=0.001, + dt_max=0.1, + dt_init="random", + dt_scale=1.0, + dt_init_floor=1e-4, + conv_bias=True, + bias=False, + use_fast_path=True, # Fused kernel options + layer_idx=None, + device=None, + dtype=None, + bimamba_type="none", + directions=None, + token_size=(14, 14), + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.d_model = d_model + self.d_state = d_state + self.d_conv = d_conv + self.expand = expand + self.d_inner = int(self.expand * self.d_model) + self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank + self.use_fast_path = use_fast_path + self.layer_idx = layer_idx + self.bimamba_type = bimamba_type + self.token_size = token_size + + self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs) + + self.activation = "silu" + self.act = nn.SiLU() + + + self.multi_scan = MultiScan(self.d_inner, choices=directions, token_size=token_size) + '''new for search''' + A = repeat( + torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device), + "n -> d n", + d=self.d_inner, + ).contiguous() + A_log = torch.log(A) # Keep A_log in fp32 + for i in range(len(self.multi_scan.choices)): + setattr(self, f'A_log_{i}', nn.Parameter(A_log)) + getattr(self, f'A_log_{i}')._no_weight_decay = True + + conv1d = nn.Conv1d( + in_channels=self.d_inner, + out_channels=self.d_inner, + bias=conv_bias, + kernel_size=d_conv, + groups=self.d_inner, + padding=d_conv - 1, + **factory_kwargs, + ) + setattr(self, f'conv1d_{i}', conv1d) + + x_proj = nn.Linear( + self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs + ) + setattr(self, f'x_proj_{i}', x_proj) + + dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs) + + # Initialize special dt projection to preserve variance at initialization + dt_init_std = self.dt_rank**-0.5 * dt_scale + if dt_init == "constant": + nn.init.constant_(dt_proj.weight, dt_init_std) + elif dt_init == "random": + nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std) + else: + raise NotImplementedError + + # Initialize dt bias so that F.softplus(dt_bias) 
is between dt_min and dt_max + dt = torch.exp( + torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) + + math.log(dt_min) + ).clamp(min=dt_init_floor) + # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 + inv_dt = dt + torch.log(-torch.expm1(-dt)) + with torch.no_grad(): + dt_proj.bias.copy_(inv_dt) + # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit + dt_proj.bias._no_reinit = True + + setattr(self, f'dt_proj_{i}', dt_proj) + + D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32 + D._no_weight_decay = True + setattr(self, f'D_{i}', D) + + self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) + + self.attn = BiAttn(self.d_inner) + + def forward(self, hidden_states, inference_params=None): + """ + hidden_states: (B, L, D) + Returns: same shape as hidden_states + """ + xz = self.in_proj(hidden_states) + + xs = self.multi_scan.multi_scan(xz) # [[BDL], [BDL], ...] + outs = [] + for i, xz in enumerate(xs): + # xz = rearrange(xz, "b l d -> b d l") + A = -torch.exp(getattr(self, f'A_log_{i}').float()) + conv1d = getattr(self, f'conv1d_{i}') + x_proj = getattr(self, f'x_proj_{i}') + dt_proj = getattr(self, f'dt_proj_{i}') + D = getattr(self, f'D_{i}') + + out = mamba_inner_fn_no_out_proj( + xz, + conv1d.weight, + conv1d.bias, + x_proj.weight, + dt_proj.weight, + A, + None, # input-dependent B + None, # input-dependent C + D, + delta_bias=dt_proj.bias.float(), + delta_softplus=True, + ) + outs.append(out) + + outs = self.multi_scan.multi_reverse(outs) + outs = [self.attn(rearrange(out, 'b d l -> b l d')) for out in outs] + out = self.multi_scan(outs) + out = F.linear(out, self.out_proj.weight, self.out_proj.bias) + + return out + + +try: + import selective_scan_cuda_oflex +except: + selective_scan_cuda_oflex = None + +class SelectiveScanOflex(torch.autograd.Function): + @staticmethod + @torch.cuda.amp.custom_fwd + def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False, oflex=True): + ctx.delta_softplus = delta_softplus + out, x, *rest = selective_scan_cuda_oflex.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, 1, oflex) + ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x) + return out + + @staticmethod + @torch.cuda.amp.custom_bwd + def backward(ctx, dout, *args): + u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors + if dout.stride(-1) != 1: + dout = dout.contiguous() + du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda_oflex.bwd( + u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, 1 + ) + return du, ddelta, dA, dB, dC, dD, ddelta_bias, None, None + + +class MultiVMamba(nn.Module): + def __init__( + self, + d_model, + d_state=16, + d_conv=4, + expand=2, + dt_rank="auto", + dt_min=0.001, + dt_max=0.1, + dt_init="random", + dt_scale=1.0, + dt_init_floor=1e-4, + conv_bias=True, + bias=False, + use_fast_path=True, # Fused kernel options + layer_idx=None, + device=None, + dtype=None, + bimamba_type="none", + directions=None, + token_size=(14, 14), + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.d_model = d_model + self.d_state = d_state + self.d_conv = d_conv + self.expand = expand + self.d_inner = int(self.expand * self.d_model) + self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank + self.use_fast_path = use_fast_path + self.layer_idx = layer_idx + self.bimamba_type = bimamba_type + self.token_size = token_size + + 
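# in_proj expands d_model to 2 * d_inner; forward() later chunks the result into the SSM input x and the gate z +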
self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs) + + self.activation = "silu" + self.act = nn.SiLU() + + + self.multi_scan = MultiScan(self.d_inner, choices=directions, token_size=token_size) + '''new for search''' + A = repeat( + torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device), + "n -> d n", + d=self.d_inner, + ).contiguous() + A_log = torch.log(A) # Keep A_log in fp32 + for i in range(len(self.multi_scan.choices)): + setattr(self, f'A_log_{i}', nn.Parameter(A_log)) + getattr(self, f'A_log_{i}')._no_weight_decay = True + + x_proj = nn.Linear( + self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs + ) + setattr(self, f'x_proj_{i}', x_proj) + + conv1d = nn.Conv1d( + in_channels=self.d_inner, + out_channels=self.d_inner, + bias=conv_bias, + kernel_size=d_conv, + groups=self.d_inner, + padding=d_conv - 1, + **factory_kwargs, + ) + setattr(self, f'conv1d_{i}', conv1d) + + dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs) + + # Initialize special dt projection to preserve variance at initialization + dt_init_std = self.dt_rank**-0.5 * dt_scale + if dt_init == "constant": + nn.init.constant_(dt_proj.weight, dt_init_std) + elif dt_init == "random": + nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std) + else: + raise NotImplementedError + + # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max + dt = torch.exp( + torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) + + math.log(dt_min) + ).clamp(min=dt_init_floor) + # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 + inv_dt = dt + torch.log(-torch.expm1(-dt)) + with torch.no_grad(): + dt_proj.bias.copy_(inv_dt) + # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit + dt_proj.bias._no_reinit = True + + setattr(self, f'dt_proj_{i}', dt_proj) + + D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32 + D._no_weight_decay = True + setattr(self, f'D_{i}', D) + + self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) + + self.attn = BiAttn(self.d_inner) + + def forward(self, hidden_states, inference_params=None): + """ + hidden_states: (B, L, D) + Returns: same shape as hidden_states + """ + batch_size, seq_len, dim = hidden_states.shape + xz = self.in_proj(hidden_states) + x, z = xz.chunk(2, dim=2) + z = self.act(z) + + xs = self.multi_scan.multi_scan(x) + outs = [] + for i, xz in enumerate(xs): + xz = rearrange(xz, "b l d -> b d l") + A = -torch.exp(getattr(self, f'A_log_{i}').float()) + x_proj = getattr(self, f'x_proj_{i}') + conv1d = getattr(self, f'conv1d_{i}') + dt_proj = getattr(self, f'dt_proj_{i}') + D = getattr(self, f'D_{i}') + + xz = conv1d(xz)[:, :, :seq_len] + xz = self.act(xz) + + N = A.shape[-1] + R = dt_proj.weight.shape[-1] + + x_dbl = F.linear(rearrange(xz, 'b d l -> b l d'), x_proj.weight) + dts, B, C = torch.split(x_dbl, [R, N, N], dim=2) + dts = F.linear(dts, dt_proj.weight) + + dts = rearrange(dts, 'b l d -> b d l') + B = rearrange(B, 'b l d -> b 1 d l') + C = rearrange(C, 'b l d -> b 1 d l') + D = D.float() + delta_bias = dt_proj.bias.float() + + out = SelectiveScanOflex.apply(xz.contiguous(), dts.contiguous(), A.contiguous(), B.contiguous(), C.contiguous(), D.contiguous(), delta_bias, True, True) + + outs.append(rearrange(out, "b d l -> b l d")) + + outs = self.multi_scan.multi_reverse(outs) + outs = [self.attn(out) for out in outs] + out = 
self.multi_scan(outs) + out = out * z + out = self.out_proj(out) + + return out + diff --git a/classification/lib/models/mamba/rope.py b/classification/lib/models/mamba/rope.py new file mode 100644 index 0000000..d835b8a --- /dev/null +++ b/classification/lib/models/mamba/rope.py @@ -0,0 +1,147 @@ +# -------------------------------------------------------- +# EVA-02: A Visual Representation for Neon Genesis +# Github source: https://github.com/baaivision/EVA/EVA02 +# Copyright (c) 2023 Beijing Academy of Artificial Intelligence (BAAI) +# Licensed under The MIT License [see LICENSE for details] +# By Yuxin Fang +# +# Based on https://github.com/lucidrains/rotary-embedding-torch +# --------------------------------------------------------' + +from math import pi + +import torch +from torch import nn + +from einops import rearrange, repeat + + + +def broadcat(tensors, dim = -1): + num_tensors = len(tensors) + shape_lens = set(list(map(lambda t: len(t.shape), tensors))) + assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions' + shape_len = list(shape_lens)[0] + dim = (dim + shape_len) if dim < 0 else dim + dims = list(zip(*map(lambda t: list(t.shape), tensors))) + expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim] + assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), 'invalid dimensions for broadcastable concatentation' + max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims)) + expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims)) + expanded_dims.insert(dim, (dim, dims[dim])) + expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims))) + tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes))) + return torch.cat(tensors, dim = dim) + + + +def rotate_half(x): + x = rearrange(x, '... (d r) -> ... d r', r = 2) + x1, x2 = x.unbind(dim = -1) + x = torch.stack((-x2, x1), dim = -1) + return rearrange(x, '... d r -> ... (d r)') + + + +class VisionRotaryEmbedding(nn.Module): + def __init__( + self, + dim, + pt_seq_len, + ft_seq_len=None, + custom_freqs = None, + freqs_for = 'lang', + theta = 10000, + max_freq = 10, + num_freqs = 1, + ): + super().__init__() + if custom_freqs: + freqs = custom_freqs + elif freqs_for == 'lang': + freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) + elif freqs_for == 'pixel': + freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi + elif freqs_for == 'constant': + freqs = torch.ones(num_freqs).float() + else: + raise ValueError(f'unknown modality {freqs_for}') + + if ft_seq_len is None: ft_seq_len = pt_seq_len + t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len + + freqs_h = torch.einsum('..., f -> ... f', t, freqs) + freqs_h = repeat(freqs_h, '... n -> ... (n r)', r = 2) + + freqs_w = torch.einsum('..., f -> ... f', t, freqs) + freqs_w = repeat(freqs_w, '... n -> ... 
(n r)', r = 2) + + freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim = -1) + + self.register_buffer("freqs_cos", freqs.cos()) + self.register_buffer("freqs_sin", freqs.sin()) + + print('======== shape of rope freq', self.freqs_cos.shape, '========') + + def forward(self, t, start_index = 0): + rot_dim = self.freqs_cos.shape[-1] + end_index = start_index + rot_dim + assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}' + t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:] + t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin) + return torch.cat((t_left, t, t_right), dim = -1) + + + +class VisionRotaryEmbeddingFast(nn.Module): + def __init__( + self, + dim, + pt_seq_len=16, + ft_seq_len=None, + custom_freqs = None, + freqs_for = 'lang', + theta = 10000, + max_freq = 10, + num_freqs = 1, + has_cls_token = False + ): + super().__init__() + self.has_cls_token = has_cls_token + if custom_freqs: + freqs = custom_freqs + elif freqs_for == 'lang': + freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) + elif freqs_for == 'pixel': + freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi + elif freqs_for == 'constant': + freqs = torch.ones(num_freqs).float() + else: + raise ValueError(f'unknown modality {freqs_for}') + + if ft_seq_len is None: ft_seq_len = pt_seq_len + t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len + + freqs = torch.einsum('..., f -> ... f', t, freqs) + freqs = repeat(freqs, '... n -> ... (n r)', r = 2) + freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim = -1) + + freqs_cos = freqs.cos().view(-1, freqs.shape[-1]) + freqs_sin = freqs.sin().view(-1, freqs.shape[-1]) + + self.register_buffer("freqs_cos", freqs_cos) + self.register_buffer("freqs_sin", freqs_sin) + + print('======== shape of rope freq', self.freqs_cos.shape, '========') + + def forward(self, t, freqs_cos=None, freqs_sin=None): + if freqs_cos is None: + freqs_cos = self.freqs_cos + if freqs_sin is None: + freqs_sin = self.freqs_sin + if self.has_cls_token: + t_spatial = t[:, 1:, :] + t_spatial = t_spatial * freqs_cos + rotate_half(t_spatial) * freqs_sin + return torch.cat((t[:, :1, :], t_spatial), dim=1) + else: + return t * freqs_cos + rotate_half(t) * freqs_sin \ No newline at end of file diff --git a/classification/lib/models/mdconv.py b/classification/lib/models/mdconv.py new file mode 100644 index 0000000..0b39e4c --- /dev/null +++ b/classification/lib/models/mdconv.py @@ -0,0 +1,68 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +import numpy as np + + +def split_layer(total_channels, num_groups): + split = [int(np.ceil(total_channels / num_groups)) for _ in range(num_groups)] + split[num_groups - 1] += total_channels - sum(split) + return split + + +class DepthwiseConv2D(nn.Module): + def __init__(self, in_channels, kernal_size, stride, bias=False): + super(DepthwiseConv2D, self).__init__() + padding = (kernal_size - 1) // 2 + + self.depthwise_conv = nn.Conv2d(in_channels, in_channels, kernel_size=kernal_size, padding=padding, stride=stride, groups=in_channels, bias=bias) + + def forward(self, x): + out = self.depthwise_conv(x) + return out + + +class GroupConv2D(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=1, n_chunks=1, bias=False): + super(GroupConv2D, self).__init__() + self.n_chunks = n_chunks + self.split_in_channels = split_layer(in_channels, n_chunks) + split_out_channels 
= split_layer(out_channels, n_chunks) + + if n_chunks == 1: + self.group_conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=bias) + else: + self.group_layers = nn.ModuleList() + for idx in range(n_chunks): + self.group_layers.append(nn.Conv2d(self.split_in_channels[idx], split_out_channels[idx], kernel_size=kernel_size, bias=bias)) + + def forward(self, x): + if self.n_chunks == 1: + return self.group_conv(x) + else: + split = torch.split(x, self.split_in_channels, dim=1) + out = torch.cat([layer(s) for layer, s in zip(self.group_layers, split)], dim=1) + return out + + +class MDConv(nn.Module): + def __init__(self, out_channels, n_chunks, stride=1, bias=False): + super(MDConv, self).__init__() + self.n_chunks = n_chunks + self.split_out_channels = split_layer(out_channels, n_chunks) + + self.layers = nn.ModuleList() + for idx in range(self.n_chunks): + kernel_size = 2 * idx + 3 + self.layers.append(DepthwiseConv2D(self.split_out_channels[idx], kernal_size=kernel_size, stride=stride, bias=bias)) + + def forward(self, x): + split = torch.split(x, self.split_out_channels, dim=1) + out = torch.cat([layer(s) for layer, s in zip(self.layers, split)], dim=1) + return out + + +# temp = torch.randn((16, 3, 32, 32)) +# group = GroupConv2D(3, 16, n_chunks=2) +# print(group(temp).size()) diff --git a/classification/lib/models/mobilenet_v1.py b/classification/lib/models/mobilenet_v1.py new file mode 100644 index 0000000..e9a6bf2 --- /dev/null +++ b/classification/lib/models/mobilenet_v1.py @@ -0,0 +1,73 @@ +import math +import torch.nn as nn + + +def _initialize_weight_goog(m): + # weight init as per Tensorflow Official impl + # https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels # fan-out + m.weight.data.normal_(0, math.sqrt(2.0 / n)) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1.0) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + n = m.weight.size(0) # fan-out + init_range = 1.0 / math.sqrt(n) + m.weight.data.uniform_(-init_range, init_range) + m.bias.data.zero_() + + +class MobileNetV1(nn.Module): + def __init__(self, ch_in=3, num_classes=1000): + super(MobileNetV1, self).__init__() + + def conv_bn(inp, oup, stride): + return nn.Sequential( + nn.Conv2d(inp, oup, 3, stride, 1, bias=False), + nn.BatchNorm2d(oup), + nn.ReLU(inplace=True) + ) + + def conv_dw(inp, oup, stride): + return nn.Sequential( + # dw + nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), + nn.BatchNorm2d(inp), + nn.ReLU(inplace=True), + + # pw + nn.Conv2d(inp, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + nn.ReLU(inplace=True), + ) + + self.model = nn.Sequential( + conv_bn(ch_in, 32, 2), + conv_dw(32, 64, 1), + conv_dw(64, 128, 2), + conv_dw(128, 128, 1), + conv_dw(128, 256, 2), + conv_dw(256, 256, 1), + conv_dw(256, 512, 2), + conv_dw(512, 512, 1), + conv_dw(512, 512, 1), + conv_dw(512, 512, 1), + conv_dw(512, 512, 1), + conv_dw(512, 512, 1), + conv_dw(512, 1024, 2), + conv_dw(1024, 1024, 1), + nn.AdaptiveAvgPool2d(1) + ) + self.fc = nn.Linear(1024, num_classes) + + for m in self.modules(): + _initialize_weight_goog(m) + + def forward(self, x): + x = self.model(x) + x = x.view(-1, 1024) + x = self.fc(x) + return x diff --git a/classification/lib/models/nas_model.py b/classification/lib/models/nas_model.py new file mode 100644 index 0000000..14cdc78 --- /dev/null +++ 
b/classification/lib/models/nas_model.py @@ -0,0 +1,130 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from .operations import OPS, AuxiliaryHead + + +def _initialize_weight_goog(m): + # weight init as per Tensorflow Official impl + # https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels # fan-out + m.weight.data.normal_(0, math.sqrt(2.0 / n)) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1.0) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + n = m.weight.size(0) # fan-out + init_range = 1.0 / math.sqrt(n) + m.weight.data.uniform_(-init_range, init_range) + m.bias.data.zero_() + + +def _initialize_weight_default(m): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1.0) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='linear') + + +class NASModel(nn.Module): + def __init__(self, net_cfg, weight_init='goog', drop_rate=0.2, drop_path_rate=0.0, auxiliary_head=False, **kwargs): + super(NASModel, self).__init__() + self.drop_rate = drop_rate + self.drop_path_rate = drop_path_rate + if self.drop_path_rate != 0.: + raise NotImplementedError('Drop path is not implemented in NAS model.') + + backbone_cfg = net_cfg.pop('backbone') + self.features = nn.Sequential() + downsample_num = 0 + for layer in backbone_cfg: + if len(backbone_cfg[layer]) == 5: + stride, inp, oup, t, op = backbone_cfg[layer] + n = 1 + kwargs = {} + elif len(backbone_cfg[layer]) == 6 and isinstance(backbone_cfg[layer][-1], dict): + stride, inp, oup, t, op, kwargs = backbone_cfg[layer] + n = 1 + elif len(backbone_cfg[layer]) == 6: + n, stride, inp, oup, t, op = backbone_cfg[layer] + kwargs = {} + elif len(backbone_cfg[layer]) == 7: + n, stride, inp, oup, t, op, kwargs = backbone_cfg[layer] + else: + raise RuntimeError(f'Invalid layer configuration: {backbone_cfg[layer]}') + + for idx in range(n): + layer_ = layer + f'_{idx}' if n > 1 else layer + if isinstance(t, (list, tuple)) or isinstance(op, (list, tuple)): + # NAS supernet + if not isinstance(t, (list, tuple)): + t = [t] + if not isinstance(op, (list, tuple)): + op = [op] + from edgenn.models import ListChoice + blocks = [] + for t_ in t: + for op_ in op: + if op_ == 'id': + # add it later + continue + blocks.append(OPS[op_](inp, oup, t_, stride, kwargs)) + if 'id' in op: + blocks.append(OPS['id'](inp, oup, 1, stride, kwargs)) + self.features.add_module(layer_, ListChoice(blocks)) + else: + if t is None: + t = 1 + self.features.add_module(layer_, OPS[op](inp, oup, t, stride, kwargs)) + if stride == 2: + downsample_num += 1 + if auxiliary_head and downsample_num == 5: + # auxiliary head added after the 5-th downsampling layer + object.__setattr__(self, 'module_to_auxiliary', self.features[-1]) + C_to_auxiliary = oup + inp = oup + stride = 1 + + # build head + head_cfg = net_cfg.pop('head') + self.classifier = nn.Sequential() + for layer in head_cfg: + self.classifier.add_module(layer, nn.Linear(head_cfg[layer]['dim_in'], head_cfg[layer]['dim_out'])) + + if auxiliary_head: + self.auxiliary_head = AuxiliaryHead(C_to_auxiliary, 1000) + + # init weight + for m in self.modules(): + if weight_init == 'goog': + _initialize_weight_goog(m) + else: + 
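# any other weight_init value falls back to the Kaiming-style default initializer +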
_initialize_weight_default(m) + + def get_classifier(self): + return self.classifier + + def forward(self, x): + x = self.features(x) + if self.drop_rate > 0.: + x = F.dropout(x, p=self.drop_rate, training=self.training) + x = x.view(x.size(0), -1) + return self.classifier(x) + + +def gen_nas_model(net_cfg, drop_rate=0.2, drop_path_rate=0.0, auxiliary_head=False, **kwargs): + model = NASModel( + net_cfg, + drop_rate=drop_rate, + drop_path_rate=drop_path_rate, + auxiliary_head=auxiliary_head + ) + return model + diff --git a/classification/lib/models/operations.py b/classification/lib/models/operations.py new file mode 100644 index 0000000..fde1e69 --- /dev/null +++ b/classification/lib/models/operations.py @@ -0,0 +1,605 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from collections import OrderedDict +from .mdconv import MDConv + + +OPS = OrderedDict() +OPS['id'] = lambda inp, oup, t, stride, kwargs: Identity(in_channels=inp, out_channels=oup, kernel_size=1, stride=stride, **kwargs) + +'''MixConv''' +OPS['ir_mix_se'] = lambda inp, oup, t, stride, kwargs: InvertedResidualMixConv(in_channels=inp, out_channels=oup, dw_kernel_size=3, + stride=stride, act_fn=HSwish, expand_ratio=t, se_ratio=0.25, se_gate_fn=HSigmoid, **kwargs) +OPS['ir_mix_nse'] = lambda inp, oup, t, stride, kwargs: InvertedResidualMixConv(in_channels=inp, out_channels=oup, dw_kernel_size=3, + stride=stride, act_fn=HSwish, expand_ratio=t, se_ratio=None, **kwargs) + +'''MobileNet V2 Inverted Residual''' +OPS['ir_3x3_se'] = lambda inp, oup, t, stride, kwargs: InvertedResidual(in_channels=inp, out_channels=oup, dw_kernel_size=3, + stride=stride, act_fn=HSwish, expand_ratio=t, se_ratio=0.25, se_gate_fn=HSigmoid, **kwargs) +OPS['ir_5x5_se'] = lambda inp, oup, t, stride, kwargs: InvertedResidual(in_channels=inp, out_channels=oup, dw_kernel_size=5, + stride=stride, act_fn=HSwish, expand_ratio=t, se_ratio=0.25, se_gate_fn=HSigmoid, **kwargs) +OPS['ir_7x7_se'] = lambda inp, oup, t, stride, kwargs: InvertedResidual(in_channels=inp, out_channels=oup, dw_kernel_size=7, + stride=stride, act_fn=HSwish, expand_ratio=t, se_ratio=0.25, se_gate_fn=HSigmoid, **kwargs) +OPS['ir_3x3_nse'] = lambda inp, oup, t, stride, kwargs: InvertedResidual(in_channels=inp, out_channels=oup, dw_kernel_size=3, + stride=stride, act_fn=HSwish, expand_ratio=t, se_ratio=None, **kwargs) +OPS['ir_5x5_nse'] = lambda inp, oup, t, stride, kwargs: InvertedResidual(in_channels=inp, out_channels=oup, dw_kernel_size=5, + stride=stride, act_fn=HSwish, expand_ratio=t, se_ratio=None, **kwargs) +OPS['ir_7x7_nse'] = lambda inp, oup, t, stride, kwargs: InvertedResidual(in_channels=inp, out_channels=oup, dw_kernel_size=7, + stride=stride, act_fn=HSwish, expand_ratio=t, se_ratio=None, **kwargs) +OPS['ir_3x3'] = lambda inp, oup, t, stride, kwargs: InvertedResidual(in_channels=inp, out_channels=oup, dw_kernel_size=3, + stride=stride, act_fn=nn.ReLU, expand_ratio=t, se_ratio=None, **kwargs) +OPS['ir_5x5'] = lambda inp, oup, t, stride, kwargs: InvertedResidual(in_channels=inp, out_channels=oup, dw_kernel_size=5, + stride=stride, act_fn=nn.ReLU, expand_ratio=t, se_ratio=None, **kwargs) +OPS['ir_7x7'] = lambda inp, oup, t, stride, kwargs: InvertedResidual(in_channels=inp, out_channels=oup, dw_kernel_size=7, + stride=stride, act_fn=nn.ReLU, expand_ratio=t, se_ratio=None, **kwargs) + +# assign ops with given expand ratios +class OpWrapper: + def __init__(self, t, op_func): + self.t = t + self.op_func = op_func + def __call__(self, inp, oup, t, stride, kwargs): + 
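# the expand ratio t passed at call time is ignored; the fixed self.t bound at construction is used instead +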
return self.op_func(inp, oup, self.t, stride, kwargs) + +_t = [1, 3, 6] +new_ops = {} +for op in OPS: + if 'ir' in op and 't' not in op: + for given_t in _t: + newop = op + f'_t{given_t}' + func = OpWrapper(given_t, OPS[op]) + new_ops[newop] = func #lambda inp, oup, t, stride, kwargs: OPS[op](inp, oup, given_t, stride, kwargs) +for op in new_ops: + OPS[op] = new_ops[op] + +OPS['conv1x1'] = lambda inp, oup, t, stride, kwargs: ConvBnAct(in_channels=inp, out_channels=oup, kernel_size=1, stride=stride, **kwargs) +OPS['conv3x3'] = lambda inp, oup, t, stride, kwargs: ConvBnAct(in_channels=inp, out_channels=oup, kernel_size=3, stride=stride, **kwargs) +OPS['gavgp'] = lambda inp, oup, t, stride, kwargs: nn.AdaptiveAvgPool2d(1, **kwargs) +OPS['maxp'] = lambda inp, oup, t, stride, kwargs: nn.MaxPool2d(kernel_size=2, stride=stride, **kwargs) + +OPS['linear_relu'] = lambda inp, oup, t, stride, kwargs: LinearReLU(inp, oup) + +'''for NAS-Bench-Macro''' +OPS['ir_3x3_t3'] = lambda inp, oup, t, stride, kwargs: InvertedResidual(in_channels=inp, out_channels=oup, dw_kernel_size=3, + stride=stride, act_fn=nn.ReLU, expand_ratio=3, se_ratio=None, **kwargs) +OPS['ir_5x5_t6'] = lambda inp, oup, t, stride, kwargs: InvertedResidual(in_channels=inp, out_channels=oup, dw_kernel_size=5, + stride=stride, act_fn=nn.ReLU, expand_ratio=6, se_ratio=None, **kwargs) +OPS['ID'] = lambda inp, oup, t, stride, kwargs: Identity(in_channels=inp, out_channels=oup, kernel_size=1, stride=stride, **kwargs) + + +""" +========================== +basic operations & modules +========================== +""" + +class HSwish(nn.Module): + def __init__(self, inplace=True): + super(HSwish, self).__init__() + self.inplace = inplace + + def forward(self, x): + out = x * F.relu6(x + 3, inplace=self.inplace) / 6 + return out + + +class HSigmoid(nn.Module): + def __init__(self, inplace=True): + super(HSigmoid, self).__init__() + self.inplace = inplace + + def forward(self, x): + out = F.relu6(x + 3, inplace=self.inplace) / 6 + return out + + +class LinearReLU(nn.Module): + def __init__(self, inp, oup,): + super(LinearReLU, self).__init__() + self.fc = nn.Sequential( + nn.Linear(inp, oup, bias=True), + nn.ReLU(inplace=True)) + + def forward(self, x): + #if x.ndims != 2: + if len(x.shape) != 2: + x = x.view(x.shape[0], -1) + return self.fc(x) + + +def conv2d(in_channels, out_channels, kernel_size, stride=1, pad_type='SAME', **kwargs): + if pad_type == 'SAME' or pad_type == '': + if isinstance(kernel_size, (tuple, list)): + padding = [(kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2] + else: + padding = (kernel_size - 1) // 2 + elif pad_type == 'NONE': + padding = 0 + else: + raise NotImplementedError('Not supported padding type: {}.'.format(pad_type)) + return nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, **kwargs) + + +class ConvBnAct(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, + stride=1, pad_type='SAME', act_fn=nn.ReLU, **attrs): + super(ConvBnAct, self).__init__() + for k, v in attrs.items(): + setattr(self, k, v) + self.conv = conv2d(in_channels, out_channels, kernel_size, stride=stride, pad_type=pad_type, bias=False) + self.bn1 = nn.BatchNorm2d(out_channels) + self.act = act_fn(inplace=True) + + def forward(self, x): + x = self.conv(x) + x = self.bn1(x) + x = self.act(x) + return x + + +class Identity(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, **kwargs): + super(Identity, self).__init__() + if in_channels != out_channels or stride != 1: + self.conv = 
nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size, stride, bias=False), + nn.BatchNorm2d(out_channels) + ) + else: + self.conv = None + + def forward(self, x): + if self.conv is not None: + return self.conv(x) + else: + return x + + +class SqueezeExcite(nn.Module): + def __init__(self, in_channels, reduce_channels, act_fn=nn.ReLU, gate_fn=nn.Sigmoid): + super(SqueezeExcite, self).__init__() + self.avgp = nn.AdaptiveAvgPool2d(1) + self.conv_reduce = nn.Conv2d(in_channels, reduce_channels, 1, bias=True) + self.act_fn = act_fn(inplace=True) + self.conv_expand = nn.Conv2d(reduce_channels, in_channels, 1, bias=True) + self.gate_fn = gate_fn() + + def forward(self, x): + x_se = self.avgp(x) + x_se = self.conv_reduce(x_se) + x_se = self.act_fn(x_se) + x_se = self.conv_expand(x_se) + x = x * self.gate_fn(x_se) + return x + + +def drop_path(x, drop_prob: float = 0., training: bool = False): + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +""" +========================== +ShuffleNetV2-ops +========================== +""" +OPS['shuffle_3x3_se'] = lambda inp, oup, t, stride, kwargs: ShufflenetBlock(inp, oup, ksize=3, stride=stride, activation='HSwish', use_se=True) +OPS['shuffle_5x5_se'] = lambda inp, oup, t, stride, kwargs: ShufflenetBlock(inp, oup, ksize=5, stride=stride, activation='HSwish', use_se=True) +OPS['shuffle_7x7_se'] = lambda inp, oup, t, stride, kwargs: ShufflenetBlock(inp, oup, ksize=7, stride=stride, activation='HSwish', use_se=True) +OPS['shuffle_x_se'] = lambda inp, oup, t, stride, kwargs: ShufflenetBlock(inp, oup, ksize='x', stride=stride, activation='HSwish', use_se=True) + + +def channel_shuffle(x): + batchsize, num_channels, height, width = x.data.size() + assert (num_channels % 4 == 0) + x = x.reshape(batchsize * num_channels // 2, 2, height * width) + x = x.permute(1, 0, 2) + x = x.reshape(2, -1, num_channels // 2, height, width) + return x[0], x[1] + + +class ShufflenetBlock(nn.Module): + + def __init__(self, inp, oup, ksize, stride, activation='ReLU', use_se=False, **kwargs): + super(ShufflenetBlock, self).__init__() + self.stride = stride + assert stride in [1, 2] + assert ksize in [3, 5, 7, 'x'] + base_mid_channels = oup // 2 + + self.base_mid_channel = base_mid_channels + self.ksize = ksize + pad = ksize // 2 if ksize != 'x' else 3 // 2 + self.pad = pad + if stride == 1: + inp = inp // 2 + outputs = oup - inp + else: + outputs = oup // 2 + + self.inp = inp + + + if ksize != 'x': + branch_main = [ + # pw + nn.Conv2d(inp, base_mid_channels, 1, 1, 0, bias=False), + nn.BatchNorm2d(base_mid_channels), + nn.ReLU(inplace=True) if activation == 'ReLU' else HSwish(inplace=True), + # dw + nn.Conv2d(base_mid_channels, base_mid_channels, ksize, stride, pad, groups=base_mid_channels, bias=False), + nn.BatchNorm2d(base_mid_channels), + # pw-linear + nn.Conv2d(base_mid_channels, outputs, 1, 1, 0, bias=False), + nn.BatchNorm2d(outputs), + nn.ReLU(inplace=True) if activation == 'ReLU' else HSwish(inplace=True), + ] + else: + ksize = 3 + branch_main = [ + # dw + nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), + nn.BatchNorm2d(inp), + # pw + nn.Conv2d(inp, base_mid_channels, 1, 1, 0, bias=False), + nn.BatchNorm2d(base_mid_channels), + nn.ReLU(inplace=True) if 
activation == 'ReLU' else HSwish(inplace=True), + # dw + nn.Conv2d(base_mid_channels, base_mid_channels, 3, 1, 1, groups=base_mid_channels, bias=False), + nn.BatchNorm2d(base_mid_channels), + # pw + nn.Conv2d(base_mid_channels, base_mid_channels, 1, 1, 0, bias=False), + nn.BatchNorm2d(base_mid_channels), + nn.ReLU(inplace=True) if activation == 'ReLU' else HSwish(inplace=True), + # dw + nn.Conv2d(base_mid_channels, base_mid_channels, 3, 1, 1, groups=base_mid_channels, bias=False), + nn.BatchNorm2d(base_mid_channels), + # pw + nn.Conv2d(base_mid_channels, outputs, 1, 1, 0, bias=False), + nn.BatchNorm2d(outputs), + nn.ReLU(inplace=True) if activation == 'ReLU' else HSwish(inplace=True), + ] + if use_se: + assert activation != 'ReLU' + branch_main.append(SqueezeExcite(outputs, outputs // 4, act_fn=HSwish, gate_fn=HSigmoid)) + self.branch_main = nn.Sequential(*branch_main) + + if stride == 2: + branch_proj = [ + # dw + nn.Conv2d(inp, inp, ksize, stride, pad, groups=inp, bias=False), + nn.BatchNorm2d(inp), + # pw-linear + nn.Conv2d(inp, outputs, 1, 1, 0, bias=False), + nn.BatchNorm2d(outputs), + nn.ReLU(inplace=True) if activation == 'ReLU' else HSwish(inplace=True), + ] + self.branch_proj = nn.Sequential(*branch_proj) + else: + self.branch_proj = None + + def forward(self, old_x): + if self.stride == 1: + x_proj, x = channel_shuffle(old_x) + return torch.cat((x_proj, self.branch_main(x)), 1) + elif self.stride == 2: + x_proj = old_x + x = old_x + return torch.cat((self.branch_proj(x_proj), self.branch_main(x)), 1) + + + +""" +========================== +DARTS-ops +========================== +""" +OPS['avg_pool_3x3'] = lambda inp, oup, t, stride, kwargs: nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False) +OPS['max_pool_3x3'] = lambda inp, oup, t, stride, kwargs: nn.MaxPool2d(3, stride=stride, padding=1) +OPS['skip_connect'] = lambda inp, oup, t, stride, kwargs: nn.Identity() if stride == 1 else FactorizedReduce(inp, oup) +OPS['sep_conv_3x3'] = lambda inp, oup, t, stride, kwargs: SepConv(inp, oup, 3, stride) +OPS['sep_conv_5x5'] = lambda inp, oup, t, stride, kwargs: SepConv(inp, oup, 5, stride) +OPS['dil_conv_3x3'] = lambda inp, oup, t, stride, kwargs: DilConv(inp, oup, 3, stride, padding=2) +OPS['dil_conv_5x5'] = lambda inp, oup, t, stride, kwargs: DilConv(inp, oup, 5, stride, padding=4) + + +class ReLUConvBN(nn.Module): + + def __init__(self, C_in, C_out, kernel_size, stride): + super(ReLUConvBN, self).__init__() + padding = (kernel_size - 1) // 2 + self.op = nn.Sequential( + nn.ReLU(inplace=False), + nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=False), + nn.BatchNorm2d(C_out) + ) + + def forward(self, x): + return self.op(x) + + +class FactorizedReduce(nn.Module): + + def __init__(self, C_in, C_out): + super(FactorizedReduce, self).__init__() + assert C_out % 2 == 0 + self.relu = nn.ReLU(inplace=False) + self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False) + self.conv_2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False) + self.bn = nn.BatchNorm2d(C_out) + + def forward(self, x): + x = self.relu(x) + out = torch.cat([self.conv_1(x), self.conv_2(x[:,:,1:,1:])], dim=1) + out = self.bn(out) + return out + + +class SepConv(nn.Module): + + def __init__(self, C_in, C_out, kernel_size, stride): + super(SepConv, self).__init__() + padding = (kernel_size - 1) // 2 + self.op = nn.Sequential( + nn.ReLU(inplace=False), + nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, groups=C_in, 
bias=False), + nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False), + nn.BatchNorm2d(C_in), + nn.ReLU(inplace=False), + nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=1, padding=padding, groups=C_in, bias=False), + nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False), + nn.BatchNorm2d(C_out), + ) + + def forward(self, x): + return self.op(x) + + + +class DilConv(nn.Module): + + def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation=2): + super(DilConv, self).__init__() + self.op = nn.Sequential( + nn.ReLU(inplace=False), + nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=C_in, bias=False), + nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False), + nn.BatchNorm2d(C_out), + ) + + def forward(self, x): + return self.op(x) + + +""" +========================== +blocks +========================== +""" + +class InvertedResidualMixConv(nn.Module): + '''Inverted Residual block from MobileNet V2''' + def __init__(self, in_channels, out_channels, dw_kernel_size=3, + stride=1, pad_type='', act_fn=nn.ReLU, + expand_ratio=1.0, se_ratio=0., se_gate_fn=nn.Sigmoid, + drop_connect_rate=0.0, use_residual=True, use_3x3_dw_only=False, **attrs): + super(InvertedResidualMixConv, self).__init__() + mid_channels = int(in_channels * expand_ratio) + self.has_se = se_ratio is not None and se_ratio > 0. + self.has_residual = in_channels == out_channels and stride == 1 and use_residual + self.drop_connect_rate = drop_connect_rate + + for k, v in attrs.items(): + # for edgenn: NAS and pruning + setattr(self, k, v) + + # Point-wise convolution + if expand_ratio == 1: + self.conv_pw = nn.Sequential() + else: + self.conv_pw = nn.Sequential( + conv2d(in_channels, mid_channels, 1, 1, bias=False), + nn.BatchNorm2d(mid_channels), + act_fn(inplace=True) + ) + + use_3x3_dw_only = False + # Depth-wise convolution + if not use_3x3_dw_only: + self.conv_dw = nn.Sequential( + #conv2d(mid_channels, mid_channels, dw_kernel_size, stride, groups=mid_channels, bias=False), + MDConv(mid_channels, n_chunks=3, stride=stride, bias=False), + nn.BatchNorm2d(mid_channels), + act_fn(inplace=True) + ) + else: + conv_dw = [] + for i in range((dw_kernel_size - 3) // 2 + 1): + conv_dw.extend([ + conv2d(mid_channels, mid_channels, 3, stride if i == 0 else 1, groups=mid_channels, bias=False), + nn.BatchNorm2d(mid_channels), + ]) + conv_dw.append(act_fn(inplace=True)) + self.conv_dw = nn.Sequential(*conv_dw) + + # Squeeze-and-excitation + if self.has_se: + self.se = SqueezeExcite( + mid_channels, reduce_channels=max(1, int(mid_channels * se_ratio)), act_fn=act_fn, gate_fn=se_gate_fn) + + # Point-wise convolution + self.conv_pw2 = nn.Sequential( + conv2d(mid_channels, out_channels, 1, 1, bias=False), + nn.BatchNorm2d(out_channels), + ) + + def forward(self, x): + residual = x + + x = self.conv_pw(x) + x = self.conv_dw(x) + + if self.has_se: + x = self.se(x) + + x = self.conv_pw2(x) + + if self.has_residual: + if self.drop_connect_rate > 0.: + x = drop_path(x, self.drop_connect_rate, self.training) + x += residual + + return x + + + + +class InvertedResidual(nn.Module): + '''Inverted Residual block from MobileNet V2''' + def __init__(self, in_channels, out_channels, dw_kernel_size=3, + stride=1, pad_type='', act_fn=nn.ReLU, + expand_ratio=1.0, se_ratio=0., se_gate_fn=nn.Sigmoid, + drop_connect_rate=0.0, use_residual=True, use_3x3_dw_only=False, **attrs): + super(InvertedResidual, self).__init__() + mid_channels = int(in_channels * expand_ratio) + 
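+ # mid_channels is the expanded width of the MobileNetV2-style inverted bottleneck; when expand_ratio == 1 the first point-wise conv below collapses to an empty nn.Sequential.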
self.has_se = se_ratio is not None and se_ratio > 0. + self.has_residual = in_channels == out_channels and stride == 1 and use_residual + self.drop_connect_rate = drop_connect_rate + + for k, v in attrs.items(): + # for edgenn: NAS and pruning + setattr(self, k, v) + + # Point-wise convolution + if expand_ratio == 1: + self.conv_pw = nn.Sequential() + else: + self.conv_pw = nn.Sequential( + conv2d(in_channels, mid_channels, 1, 1, bias=False), + nn.BatchNorm2d(mid_channels), + act_fn(inplace=True) + ) + + # Depth-wise convolution + if not use_3x3_dw_only: + self.conv_dw = nn.Sequential( + conv2d(mid_channels, mid_channels, dw_kernel_size, stride, groups=mid_channels, bias=False), + nn.BatchNorm2d(mid_channels), + act_fn(inplace=True) + ) + else: + conv_dw = [] + for i in range((dw_kernel_size - 3) // 2 + 1): + conv_dw.extend([ + conv2d(mid_channels, mid_channels, 3, stride if i == 0 else 1, groups=mid_channels, bias=False), + nn.BatchNorm2d(mid_channels), + ]) + conv_dw.append(act_fn(inplace=True)) + self.conv_dw = nn.Sequential(*conv_dw) + + # Squeeze-and-excitation + if self.has_se: + self.se = SqueezeExcite( + mid_channels, reduce_channels=max(1, int(mid_channels * se_ratio)), act_fn=act_fn, gate_fn=se_gate_fn) + + # Point-wise convolution + self.conv_pw2 = nn.Sequential( + conv2d(mid_channels, out_channels, 1, 1, bias=False), + nn.BatchNorm2d(out_channels), + ) + + def forward(self, x): + residual = x + + x = self.conv_pw(x) + x = self.conv_dw(x) + + if self.has_se: + x = self.se(x) + + x = self.conv_pw2(x) + + if self.has_residual: + if self.drop_connect_rate > 0.: + x = drop_path(x, self.drop_connect_rate, self.training) + x += residual + + return x + + +class DARTSCell(nn.Module): + def __init__(self, cell_arch, c_prev_prev, c_prev, c, stride=1, reduction_prev=False, steps=4): + super().__init__() + self.cell_arch = cell_arch + self.steps = steps + self.preprocess0 = FactorizedReduce(c_prev_prev, c) if reduction_prev else \ + ReLUConvBN(c_prev_prev, c, 1, stride=1) + self.preprocess1 = ReLUConvBN(c_prev, c, 1, stride=1) + + if len(cell_arch[0]) != 0 and isinstance(cell_arch[0][0], str): + # DARTS-like genotype, convert it to topo-free type + cell_arch = [[cell_arch[idx*2], cell_arch[idx*2+1]] for idx in range(len(cell_arch) // 2)] + + self.ops = nn.ModuleList() + self.inputs = [] + for step in cell_arch: + step_ops = nn.ModuleList() + step_inputs = [] + for op_name, input_idx in step: + step_ops += [OPS[op_name](c, c, None, stride if input_idx < 2 else 1, {})] + step_inputs.append(input_idx) + self.ops += [step_ops] + self.inputs.append(step_inputs) + + def forward(self, s0, s1, drop_path_rate=0.): + s0 = self.preprocess0(s0) + s1 = self.preprocess1(s1) + states = [s0, s1] + + for step_idx, (step_inputs, step_ops) in enumerate(zip(self.inputs, self.ops)): + step_outs = [] + for input_idx, op in zip(step_inputs, step_ops): + out = op(states[input_idx]) + if drop_path_rate > 0. 
and not isinstance(op, (FactorizedReduce, nn.Identity)): + out = drop_path(out, drop_path_rate, self.training) + step_outs.append(out) + states.append(sum(step_outs)) + + return torch.cat(states[-4:], dim=1) + + +""" +========================= +Auxiliary Heads +========================= +""" +class AuxiliaryHead(nn.Module): + + def __init__(self, C, num_classes, avg_pool_stride=2): + """with avg_pol_stride=2, assuming input size 14x14""" + super(AuxiliaryHead, self).__init__() + self.features = nn.Sequential( + nn.AvgPool2d(5, stride=2, padding=0, count_include_pad=False), + nn.Conv2d(C, 128, 1, bias=False), + nn.BatchNorm2d(128), + nn.ReLU(inplace=True), + nn.Conv2d(128, 768, 2, bias=False), + nn.BatchNorm2d(768), + nn.ReLU(inplace=True) + ) + self.classifier = nn.Linear(768, num_classes) + + def forward(self, x): + x = self.features(x) + x = self.classifier(x.view(x.size(0),-1)) + return x + + + diff --git a/classification/lib/models/operations_resnet.py b/classification/lib/models/operations_resnet.py new file mode 100644 index 0000000..32fd997 --- /dev/null +++ b/classification/lib/models/operations_resnet.py @@ -0,0 +1,192 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Type, Any, Callable, Union, List, Optional +from torch import Tensor +from collections import OrderedDict +from .operations import OPS, conv2d, ConvBnAct, SqueezeExcite + + +'''ResNet''' +OPS['maxp_3x3'] = lambda inp, oup, t, stride, kwargs: nn.MaxPool2d(kernel_size=3, stride=stride, padding=1) +OPS['conv7x7'] = lambda inp, oup, t, stride, kwargs: ConvBnAct(inp, oup, kernel_size=7, stride=stride, **kwargs) +OPS['res_3x3'] = lambda inp, oup, t, stride, kwargs: Bottleneck(inplanes=inp, outplanes=oup, kernel_size=3, stride=stride, **kwargs) +OPS['res_5x5'] = lambda inp, oup, t, stride, kwargs: Bottleneck(inplanes=inp, outplanes=oup, kernel_size=5, stride=stride, **kwargs) +OPS['res_7x7'] = lambda inp, oup, t, stride, kwargs: Bottleneck(inplanes=inp, outplanes=oup, kernel_size=7, stride=stride, **kwargs) +OPS['res_3x3_se'] = lambda inp, oup, t, stride, kwargs: ResNeXtBottleneck(inplanes=inp, outplanes=oup, kernel_size=3, stride=stride, use_se=True, expansion=4, **kwargs) +OPS['res_5x5_se'] = lambda inp, oup, t, stride, kwargs: ResNeXtBottleneck(inplanes=inp, outplanes=oup, kernel_size=5, stride=stride, use_se=True, expansion=4, **kwargs) +OPS['res_7x7_se'] = lambda inp, oup, t, stride, kwargs: ResNeXtBottleneck(inplanes=inp, outplanes=oup, kernel_size=7, stride=stride, use_se=True, expansion=4, **kwargs) +OPS['res_3x3_se_e'] = lambda inp, oup, t, stride, kwargs: ResNeXtBottleneck(inplanes=inp, outplanes=oup, kernel_size=3, stride=stride, use_se=True, expansion=t, **kwargs) +OPS['res_5x5_se_e'] = lambda inp, oup, t, stride, kwargs: ResNeXtBottleneck(inplanes=inp, outplanes=oup, kernel_size=5, stride=stride, use_se=True, expansion=t, **kwargs) +OPS['res_7x7_se_e'] = lambda inp, oup, t, stride, kwargs: ResNeXtBottleneck(inplanes=inp, outplanes=oup, kernel_size=7, stride=stride, use_se=True, expansion=t, **kwargs) +OPS['resnext_3x3'] = lambda inp, oup, t, stride, kwargs: ResNeXtBottleneck(inplanes=inp, outplanes=oup, kernel_size=3, stride=stride, **kwargs) +OPS['resnext_5x5'] = lambda inp, oup, t, stride, kwargs: ResNeXtBottleneck(inplanes=inp, outplanes=oup, kernel_size=5, stride=stride, **kwargs) +OPS['resnext_7x7'] = lambda inp, oup, t, stride, kwargs: ResNeXtBottleneck(inplanes=inp, outplanes=oup, kernel_size=7, stride=stride, **kwargs) +OPS['resnext_3x3_se'] = 
lambda inp, oup, t, stride, kwargs: ResNeXtBottleneck(inplanes=inp, outplanes=oup, kernel_size=3, stride=stride, use_se=True, expansion=4, **kwargs) +OPS['resnext_5x5_se'] = lambda inp, oup, t, stride, kwargs: ResNeXtBottleneck(inplanes=inp, outplanes=oup, kernel_size=5, stride=stride, use_se=True, expansion=4, **kwargs) +OPS['resnext_7x7_se'] = lambda inp, oup, t, stride, kwargs: ResNeXtBottleneck(inplanes=inp, outplanes=oup, kernel_size=7, stride=stride, use_se=True, expansion=4, **kwargs) +OPS['resnext_3x3_se_e'] = lambda inp, oup, t, stride, kwargs: ResNeXtBottleneck(inplanes=inp, outplanes=oup, kernel_size=3, stride=stride, use_se=True, expansion=t, **kwargs) +OPS['resnext_5x5_se_e'] = lambda inp, oup, t, stride, kwargs: ResNeXtBottleneck(inplanes=inp, outplanes=oup, kernel_size=5, stride=stride, use_se=True, expansion=t, **kwargs) +OPS['resnext_7x7_se_e'] = lambda inp, oup, t, stride, kwargs: ResNeXtBottleneck(inplanes=inp, outplanes=oup, kernel_size=7, stride=stride, use_se=True, expansion=t, **kwargs) + + +class Bottleneck(nn.Module): + # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) + # while original implementation places the stride at the first 1x1 convolution(self.conv1) + # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. + # This variant is also known as ResNet V1.5 and improves accuracy according to + # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. + + expansion: int = 4 + + def __init__( + self, + inplanes: int, + outplanes: int, + stride: int = 1, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None, + kernel_size: int = 3, + use_se: bool = False, + planes: int = None, + expansion = 4 + ) -> None: + super(Bottleneck, self).__init__() + self.expansion = expansion + if norm_layer is None: + norm_layer = nn.BatchNorm2d + + if stride != 1 or inplanes != outplanes: + self.downsample = nn.Sequential( + nn.Conv2d(inplanes, outplanes, stride=stride, kernel_size=1, bias=False), + norm_layer(outplanes), + ) + if planes is None: + planes = int(inplanes // self.expansion * 2) + else: + self.downsample = None + planes = int(inplanes // self.expansion) + + width = int(planes * (base_width / 64.)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False) + self.bn1 = norm_layer(width) + self.conv2 = conv2d(width, width, kernel_size, stride, bias=False, groups=groups) + self.bn2 = norm_layer(width) + self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False) + self.bn3 = norm_layer(outplanes) + if use_se: + self.se = SqueezeExcite(outplanes, reduce_channels=max(1, outplanes // 16)) + else: + self.se = None + self.relu = nn.ReLU(inplace=True) + self.stride = stride + + def forward(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.se is not None: + out = self.se(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResNeXtBottleneck(nn.Module): + # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) + # while original implementation places the stride at the first 1x1 
convolution(self.conv1) + # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. + # This variant is also known as ResNet V1.5 and improves accuracy according to + # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. + + def __init__( + self, + inplanes: int, + outplanes: int, + stride: int = 1, + groups: int = 32, + base_width: int = 4, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None, + kernel_size: int = 3, + use_se: bool = False, + planes: int = None, + expansion = 4, + ) -> None: + super(ResNeXtBottleneck, self).__init__() + self.expansion = expansion + + if stride != 1 or inplanes != outplanes: + self.downsample = nn.Sequential( + nn.Conv2d(inplanes, outplanes, stride=stride, kernel_size=1, bias=False), + nn.BatchNorm2d(outplanes), + ) + if planes is None: + planes = int(inplanes // self.expansion * 2 ) + else: + self.downsample = None + planes = int(inplanes // self.expansion) + + width = math.floor(planes * (base_width / 64)) * groups + self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False, + stride=1) + self.bn1 = nn.BatchNorm2d(width) + self.conv2 = conv2d(width, width, kernel_size=kernel_size, stride=stride, + groups=groups, bias=False) + self.bn2 = nn.BatchNorm2d(width) + self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(outplanes) + if use_se: + self.se = SqueezeExcite(outplanes, reduce_channels=max(1, outplanes // 16)) + else: + self.se = None + self.relu = nn.ReLU(inplace=True) + self.stride = stride + + def forward(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.se is not None: + out = self.se(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + + + diff --git a/classification/lib/models/resnet.py b/classification/lib/models/resnet.py new file mode 100644 index 0000000..40e9ec8 --- /dev/null +++ b/classification/lib/models/resnet.py @@ -0,0 +1,392 @@ +"""resnet implemented in torchvision: +https://pytorch.org/vision/stable/_modules/torchvision/models/resnet.html +""" +import torch +from torch import Tensor +import torch.nn as nn +from typing import Type, Any, Callable, Union, List, Optional + + +__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', + 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', + 'wide_resnet50_2', 'wide_resnet101_2'] + + +def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=dilation, groups=groups, bias=False, dilation=dilation) + + +def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class BasicBlock(nn.Module): + expansion: int = 1 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None + ) -> None: + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width 
!= 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) + # while original implementation places the stride at the first 1x1 convolution(self.conv1) + # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. + # This variant is also known as ResNet V1.5 and improves accuracy according to + # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. + + expansion: int = 4 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None + ) -> None: + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__( + self, + block: Type[Union[BasicBlock, Bottleneck]], + layers: List[int], + num_classes: int = 1000, + zero_init_residual: bool = False, + groups: int = 1, + width_per_group: int = 64, + replace_stride_with_dilation: Optional[List[bool]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None + ) -> None: + super(ResNet, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError("replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d(3, self.inplanes, 
kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, + dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, + dilate=replace_stride_with_dilation[1]) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2, + dilate=replace_stride_with_dilation[2]) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. + # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] + + def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, + stride: int = 1, dilate: bool = False) -> nn.Sequential: + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation, norm_layer)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes, groups=self.groups, + base_width=self.base_width, dilation=self.dilation, + norm_layer=norm_layer)) + + return nn.Sequential(*layers) + + def _forward_impl(self, x: Tensor) -> Tensor: + # See note [TorchScript super()] + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = torch.flatten(x, 1) + x = self.fc(x) + + return x + + def forward(self, x: Tensor) -> Tensor: + return self._forward_impl(x) + + +def _resnet( + arch: str, + block: Type[Union[BasicBlock, Bottleneck]], + layers: List[int], + pretrained: bool, + progress: bool, + **kwargs: Any +) -> ResNet: + model = ResNet(block, layers, **kwargs) + return model + + + +def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNet-18 model from + `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, + **kwargs) + + + + +def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNet-34 model from + `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_. 
+ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, + **kwargs) + + + + +def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNet-50 model from + `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, + **kwargs) + + + + +def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNet-101 model from + `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, + **kwargs) + + + + +def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNet-152 model from + `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, + **kwargs) + + + + +def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNeXt-50 32x4d model from + `"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/abs/1611.05431>`_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['groups'] = 32 + kwargs['width_per_group'] = 4 + return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], + pretrained, progress, **kwargs) + + + + +def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNeXt-101 32x8d model from + `"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/abs/1611.05431>`_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['groups'] = 32 + kwargs['width_per_group'] = 8 + return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], + pretrained, progress, **kwargs) + + + + +def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""Wide ResNet-50-2 model from + `"Wide Residual Networks" <https://arxiv.org/abs/1605.07146>`_. + + The model is the same as ResNet except for the bottleneck number of channels + which is twice larger in every block. The number of channels in outer 1x1 + convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 + channels, and in Wide ResNet-50-2 has 2048-1024-2048. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['width_per_group'] = 64 * 2 + return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], + pretrained, progress, **kwargs) + + + + +def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""Wide ResNet-101-2 model from + `"Wide Residual Networks" <https://arxiv.org/abs/1605.07146>`_. 
+ + The model is the same as ResNet except for the bottleneck number of channels + which is twice larger in every block. The number of channels in outer 1x1 + convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 + channels, and in Wide ResNet-50-2 has 2048-1024-2048. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['width_per_group'] = 64 * 2 + return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], + pretrained, progress, **kwargs) + diff --git a/classification/lib/models/utils/__init__.py b/classification/lib/models/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/classification/lib/models/utils/dbb/__init__.py b/classification/lib/models/utils/dbb/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/classification/lib/models/utils/dbb/dbb_block.py b/classification/lib/models/utils/dbb/dbb_block.py new file mode 100644 index 0000000..edb6d23 --- /dev/null +++ b/classification/lib/models/utils/dbb/dbb_block.py @@ -0,0 +1,510 @@ +import math + +import torch +import torch.nn as nn + +from .dbb_transforms import (transI_fusebn, transII_addbranch, + transIII_1x1_kxk, transVI_multiscale, + transV_avg, transIX_bn_to_1x1) + + +class ConvBN(nn.Module): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=False, + bn=nn.BatchNorm2d): + super(ConvBN, self).__init__() + + self.in_channels = in_channels + self.kernel_size = kernel_size + self.out_channels = out_channels + self.stride = stride + self.dilation = dilation + self.groups = groups + self.padding = padding + + self.conv = nn.Conv2d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias) + + self.bn = bn(num_features=out_channels, + affine=True, + track_running_stats=True) + self.deployed = False + + def forward(self, input): + output = self.conv(input) + output = self.bn(output) + return output + + +class BNAndPad(nn.Module): + def __init__(self, + pad_pixels, + num_features, + eps=1e-5, + momentum=0.1, + affine=True, + track_running_stats=True, + last_conv_bias=None, + bn=nn.BatchNorm2d): + super(BNAndPad, self).__init__() + self.bn = bn(num_features, eps, momentum, affine, track_running_stats) + self.pad_pixels = pad_pixels + self.last_conv_bias = last_conv_bias + + def forward(self, input): + output = self.bn(input) + if self.pad_pixels > 0: + bias = -self.bn.running_mean + if self.last_conv_bias is not None: + bias += self.last_conv_bias + pad_values = self.bn.bias.data + self.bn.weight.data * ( + bias / torch.sqrt(self.bn.running_var + self.bn.eps)) + ''' pad ''' + n, c, h, w = output.size() + values = pad_values.view(1, -1, 1, 1) + w_values = values.expand(n, -1, self.pad_pixels, w) + x = torch.cat([w_values, output, w_values], dim=2) + h = h + self.pad_pixels * 2 + h_values = values.expand(n, -1, h, self.pad_pixels) + x = torch.cat([h_values, x, h_values], dim=3) + output = x + return output + + @property + def weight(self): + return self.bn.weight + + @property + def bias(self): + return self.bn.bias + + @property + def running_mean(self): + return self.bn.running_mean + + @property + def running_var(self): + return self.bn.running_var + + @property + def eps(self): + return self.bn.eps + + +class DiverseBranchBlock(nn.Module): + def __init__( + self, + 
in_channels, + out_channels, + kernel_size, + stride=1, + dilation=1, + groups=1, + branches=[1, 1, 1, 1, 1, 1, 1 + ], # stands for 1x1, 1x1_kxk, 1x1_avg, kxk, 1xk, kx1, id + internal_channels=None, # internal channel between 1x1 and kxk + nonlinear=None, + ori_conv=None, + padding=None, + bn=nn.BatchNorm2d, + recal_bn_fn=None, + **kwargs): + super(DiverseBranchBlock, self).__init__() + if isinstance(stride, tuple): + stride = stride[0] + if not (out_channels == in_channels and stride == 1): + branches[6] = 0 + assert branches[3] == 1 # original kxk branch should always be active + self.deployed = False + self.branches = branches + if nonlinear is None: + self.nonlinear = nn.Sequential() + else: + self.nonlinear = nonlinear + self.in_channels = in_channels + self.kernel_size = kernel_size + self.out_channels = out_channels + self.stride = stride + self.dilation = dilation + self.groups = groups + if padding is None: + padding = kernel_size // 2 + self.padding = padding + + self.active_branch_num = sum(branches) + if branches[0]: + self.dbb_1x1 = ConvBN(in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + stride=stride, + padding=0, + groups=groups, + bn=bn) + if branches[1]: + if internal_channels is None: + internal_channels = in_channels + self.dbb_1x1_kxk = nn.Sequential() + self.dbb_1x1_kxk.add_module( + 'conv1', + nn.Conv2d(in_channels=in_channels, + out_channels=internal_channels, + kernel_size=1, + stride=1, + padding=0, + groups=groups, + bias=False)) + self.dbb_1x1_kxk.add_module( + 'bn1', + BNAndPad(pad_pixels=padding, + num_features=internal_channels, + affine=True, + last_conv_bias=self.dbb_1x1_kxk.conv1.bias, + bn=bn)) + self.dbb_1x1_kxk.add_module( + 'conv2', + nn.Conv2d(in_channels=internal_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=0, + groups=groups, + bias=False)) + self.dbb_1x1_kxk.add_module( + 'bn2', + bn(num_features=out_channels, + affine=True, + track_running_stats=True)) + if branches[2]: + self.dbb_1x1_avg = nn.Sequential() + if self.groups < self.out_channels: + self.dbb_1x1_avg.add_module( + 'conv', + nn.Conv2d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + padding=0, + groups=groups, + bias=False)) + self.dbb_1x1_avg.add_module( + 'bn', + BNAndPad(pad_pixels=padding, + num_features=out_channels, + last_conv_bias=self.dbb_1x1_avg.conv.bias, + bn=bn)) + self.dbb_1x1_avg.add_module( + 'avg', + nn.AvgPool2d(kernel_size=kernel_size, + stride=stride, + padding=0)) + else: + self.dbb_1x1_avg.add_module( + 'avg', + nn.AvgPool2d(kernel_size=kernel_size, + stride=stride, + padding=padding)) + self.dbb_1x1_avg.add_module( + 'avgbn', + bn( + num_features=out_channels, + affine=True, + track_running_stats=True, + )) + if branches[3]: + self.dbb_kxk = ConvBN(in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=True, + bn=bn) + if branches[4]: + self.dbb_1xk = ConvBN(in_channels=in_channels, + out_channels=out_channels, + kernel_size=(1, kernel_size), + stride=stride, + padding=(0, self.padding), + dilation=dilation, + groups=groups, + bias=False, + bn=bn) + if branches[5]: + self.dbb_kx1 = ConvBN(in_channels=in_channels, + out_channels=out_channels, + kernel_size=(kernel_size, 1), + stride=stride, + padding=(self.padding, 0), + dilation=dilation, + groups=groups, + bias=False, + bn=bn) + if branches[6]: + self.dbb_id = bn( + num_features=out_channels, + 
affine=True, + track_running_stats=True, + ) + + if ori_conv is not None: + self.recal_bn_fn = recal_bn_fn + + def branch_weights(self): + def _cal_weight(data): + return data.abs().mean().item() # L1 + + weights = [-1] * len(self.branches) + kxk_weight = _cal_weight(self.dbb_kxk.bn.weight.data) + # Make the weight of kxk branch as 1, + # this is for better generalization of the thrd value (lambda) + weights[3] = 1 + if self.branches[0]: + weights[0] = _cal_weight(self.dbb_1x1.bn.weight.data) / kxk_weight + if self.branches[1]: + weights[1] = _cal_weight( + self.dbb_1x1_kxk[-1].weight.data) / kxk_weight + if self.branches[2]: + weights[2] = _cal_weight( + self.dbb_1x1_avg[-1].weight.data) / kxk_weight + if self.branches[4]: + weights[4] = _cal_weight(self.dbb_1xk.bn.weight.data) / kxk_weight + if self.branches[5]: + weights[5] = _cal_weight(self.dbb_kx1.bn.weight.data) / kxk_weight + if self.branches[6]: + weights[6] = _cal_weight(self.dbb_id.weight.data) / kxk_weight + return weights + + def _reset_dbb(self, + kernel, + bias, + no_init_branches=[0, 0, 0, 0, 0, 0, 0, 0]): + self._init_branch(self.dbb_kxk, set_zero=True, norm=1) + if self.branches[0] and no_init_branches[0] == 0: + self._init_branch(self.dbb_1x1) + if self.branches[1] and no_init_branches[1] == 0: + self._init_branch(self.dbb_1x1_kxk) + if self.branches[2] and no_init_branches[2] == 0: + self._init_branch(self.dbb_1x1_avg) + if self.branches[4] and no_init_branches[4] == 0: + self._init_branch(self.dbb_1xk) + if self.branches[5] and no_init_branches[5] == 0: + self._init_branch(self.dbb_kx1) + if self.branches[6] and no_init_branches[6] == 0: + self._init_branch(self.dbb_id) + + if self.recal_bn_fn is not None and sum( + no_init_branches) == 0 and isinstance(kernel, nn.Parameter): + self.dbb_kxk.conv.weight.data.copy_(kernel) + if bias is not None: + self.dbb_kxk.conv.bias = bias + self.recal_bn_fn(self) + self.dbb_kxk.bn.reset_running_stats() + cur_w, cur_b = self.get_actual_kernel(ignore_kxk=True) + # reverse dbb transform + new_w = kernel.data.to(cur_w.device) - cur_w + if bias is not None: + new_b = bias.data.to(cur_b.device) - cur_b + else: + new_b = -cur_b + + if isinstance(self.dbb_kxk.conv, nn.Conv2d): + if isinstance(self.dbb_kxk.bn, nn.BatchNorm2d): + self.dbb_kxk.bn.weight.data.fill_(1.) + self.dbb_kxk.bn.bias.data.zero_() + self.dbb_kxk.conv.weight.data = new_w + self.dbb_kxk.conv.bias.data = new_b + elif isinstance(self.dbb_kxk.conv, DiverseBranchBlock): + self.dbb_kxk.conv._reset_dbb(new_w, new_b) + + def _init_branch(self, branch, set_zero=False, norm=0.01): + bns = [] + for m in branch.modules(): + if isinstance(m, nn.Conv2d): + if set_zero: + m.weight.data.zero_() + else: + n = m.kernel_size[0] * m.kernel_size[ + 1] * m.out_channels # fan-out + m.weight.data.normal_(0, math.sqrt(2.0 / n)) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + bns.append(m) + for idx, m in enumerate(bns): + m.reset_parameters() + m.reset_running_stats() + if idx == len(bns) - 1: + m.weight.data.fill_(norm) # set to a small value + else: + m.weight.data.fill_(1.) 
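+ # only the last BN of the branch is shrunk to the small value `norm` above; the earlier BNs keep weight 1, so a freshly initialized branch starts with a small output.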
+ if m.bias is not None: + m.bias.data.zero_() + + def get_actual_kernel(self, ignore_kxk=False): + if self.deployed: + return self.conv_deployed.weight.data, self.conv_deployed.bias.data + ws = [] + bs = [] + if not ignore_kxk: # kxk-bn + if isinstance(self.dbb_kxk.conv, nn.Conv2d): + w, b = self.dbb_kxk.conv.weight, self.dbb_kxk.conv.bias + elif isinstance(self.dbb_kxk.conv, DiverseBranchBlock): + w, b = self.dbb_kxk.conv.get_actual_kernel() + if not isinstance(self.dbb_kxk.bn, nn.Identity): + w, b = transI_fusebn(w, self.dbb_kxk.bn, b) + ws.append(w.unsqueeze(0)) + bs.append(b.unsqueeze(0)) + if self.branches[0]: # 1x1-bn + w_1x1, b_1x1 = transI_fusebn(self.dbb_1x1.conv.weight, + self.dbb_1x1.bn, + self.dbb_1x1.conv.bias) + w_1x1 = transVI_multiscale(w_1x1, self.kernel_size) + ws.append(w_1x1.unsqueeze(0)) + bs.append(b_1x1.unsqueeze(0)) + if self.branches[1]: # 1x1-bn-kxk-bn + if isinstance(self.dbb_1x1_kxk.conv2, nn.Conv2d): + w_1x1_kxk, b_1x1_kxk = transI_fusebn( + self.dbb_1x1_kxk.conv2.weight, self.dbb_1x1_kxk.bn2, + self.dbb_1x1_kxk.conv2.bias) + elif isinstance(self.dbb_1x1_kxk.conv2, DiverseBranchBlock): + w_1x1_kxk, b_1x1_kxk = \ + self.dbb_1x1_kxk.conv2.get_actual_kernel() + w_1x1_kxk, b_1x1_kxk = transI_fusebn(w_1x1_kxk, + self.dbb_1x1_kxk.bn2, + b_1x1_kxk) + w_1x1_kxk_first, b_1x1_kxk_first = transI_fusebn( + self.dbb_1x1_kxk.conv1.weight, self.dbb_1x1_kxk.bn1, + self.dbb_1x1_kxk.conv1.bias) + w_1x1_kxk, b_1x1_kxk = transIII_1x1_kxk(w_1x1_kxk_first, + b_1x1_kxk_first, + w_1x1_kxk, + b_1x1_kxk, + groups=self.groups) + ws.append(w_1x1_kxk.unsqueeze(0)) + bs.append(b_1x1_kxk.unsqueeze(0)) + if self.branches[2]: # 1x1-bn-avg-bn + w_1x1_avg = transV_avg(self.out_channels, self.kernel_size, + self.groups) + w_1x1_avg, b_1x1_avg = transI_fusebn( + w_1x1_avg.to(self.dbb_1x1_avg.avgbn.weight.device), + self.dbb_1x1_avg.avgbn, None) + if self.groups < self.out_channels: + w_1x1_avg_first, b_1x1_avg_first = transI_fusebn( + self.dbb_1x1_avg.conv.weight, self.dbb_1x1_avg.bn, + self.dbb_1x1_avg.conv.bias) + w_1x1_avg, b_1x1_avg = transIII_1x1_kxk(w_1x1_avg_first, + b_1x1_avg_first, + w_1x1_avg, + b_1x1_avg, + groups=self.groups) + ws.append(w_1x1_avg.unsqueeze(0)) + bs.append(b_1x1_avg.unsqueeze(0)) + if self.branches[4]: # 1xk-bn + w_1xk, b_1xk = transI_fusebn(self.dbb_1xk.conv.weight, + self.dbb_1xk.bn, + self.dbb_1xk.conv.bias) + w_1xk = transVI_multiscale(w_1xk, self.kernel_size) + ws.append(w_1xk.unsqueeze(0)) + bs.append(b_1xk.unsqueeze(0)) + if self.branches[5]: # kx1-bn + w_kx1, b_kx1 = transI_fusebn(self.dbb_kx1.conv.weight, + self.dbb_kx1.bn, + self.dbb_kx1.conv.bias) + w_kx1 = transVI_multiscale(w_kx1, self.kernel_size) + ws.append(w_kx1.unsqueeze(0)) + bs.append(b_kx1.unsqueeze(0)) + if self.branches[6]: # BN + w_id, b_id = transIX_bn_to_1x1(self.dbb_id, + self.dbb_kxk.conv.in_channels, + self.dbb_kxk.conv.groups) + w_id = transVI_multiscale(w_id, self.kernel_size) + ws.append(w_id.unsqueeze(0)) + bs.append(b_id.unsqueeze(0)) + + ws = torch.cat(ws) + bs = torch.cat(bs) + + return transII_addbranch(ws, bs) + + def switch_to_deploy(self): + if self.deployed: + return + w, b = self.get_actual_kernel() + + self.conv_deployed = nn.Conv2d(in_channels=self.in_channels, + out_channels=self.out_channels, + kernel_size=self.kernel_size, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + groups=self.groups, + bias=True) + + self.conv_deployed.weight.data = w + self.conv_deployed.bias.data = b + for para in self.parameters(): + para.detach_() + if self.branches[0]: 
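+ # the training-time branches below are deleted once their kernels have been fused into conv_deployed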
+ self.__delattr__('dbb_1x1') + if self.branches[1]: + self.__delattr__('dbb_1x1_kxk') + if self.branches[2]: + self.__delattr__('dbb_1x1_avg') + if self.branches[3]: + self.__delattr__('dbb_kxk') + if self.branches[4]: + self.__delattr__('dbb_1xk') + if self.branches[5]: + self.__delattr__('dbb_kx1') + if self.branches[6]: + self.__delattr__('dbb_id') + self.deployed = True + + def forward(self, inputs): + if self.deployed: + return self.nonlinear(self.conv_deployed(inputs)) + + branch_outs = [] + branch_outs.append(self.dbb_kxk(inputs)) + if self.branches[0]: + branch_outs.append(self.dbb_1x1(inputs)) + if self.branches[1]: + branch_outs.append(self.dbb_1x1_kxk(inputs)) + if self.branches[2]: + branch_outs.append(self.dbb_1x1_avg(inputs)) + if self.branches[4]: + branch_outs.append(self.dbb_1xk(inputs)) + if self.branches[5]: + branch_outs.append(self.dbb_kx1(inputs)) + if self.branches[6]: + branch_outs.append(self.dbb_id(inputs)) + + out = self.nonlinear(torch.stack(branch_outs).sum(0)) + return out + + def cut_branch(self, branches): + ori_w, ori_b = self.get_actual_kernel() + _branch_names = [ + 'dbb_1x1', 'dbb_1x1_kxk', 'dbb_1x1_avg', 'dbb_kxk', 'dbb_1xk', + 'dbb_kx1', 'dbb_id' + ] + for idx, status in enumerate(branches): + if status == 0 and self.branches[idx] == 1: + self.branches[idx] = 0 + self.__delattr__(_branch_names[idx]) + self._reset_dbb(ori_w, ori_b, no_init_branches=branches) + diff --git a/classification/lib/models/utils/dbb/dbb_transforms.py b/classification/lib/models/utils/dbb/dbb_transforms.py new file mode 100644 index 0000000..1a54b56 --- /dev/null +++ b/classification/lib/models/utils/dbb/dbb_transforms.py @@ -0,0 +1,126 @@ +import numpy as np +import torch +import torch.nn.functional as F + + +def restore_bn(kernel, bn, conv_bias): + gamma = bn.weight + std = (bn.running_var + bn.eps).sqrt() + bias = -bn.running_mean + new_bias = (conv_bias - bn.bias) / gamma * std - bias + new_weight = kernel * (std / gamma).reshape(-1, 1, 1, 1) + return new_weight, new_bias + + +def transI_fusebn(kernel, bn, conv_bias): + gamma = bn.weight + std = (bn.running_var + bn.eps).sqrt() + bias = -bn.running_mean + if conv_bias is not None: + bias += conv_bias + return kernel * ( + (gamma / std).reshape(-1, 1, 1, 1)), bn.bias + bias * gamma / std + + +def transII_addbranch(kernels, biases): + return torch.sum(kernels, dim=0), torch.sum(biases, dim=0) + + +def transIII_1x1_kxk(k1, b1, k2, b2, groups=1): + if groups == 1: + k = F.conv2d(k2, k1.permute(1, 0, 2, 3)) + b_hat = (k2 * b1.reshape(1, -1, 1, 1)).sum((1, 2, 3)) + else: + k_slices = [] + b_slices = [] + k1_T = k1.permute(1, 0, 2, 3) + k1_group_width = k1.size(0) // groups + k2_group_width = k2.size(0) // groups + for g in range(groups): + k1_T_slice = k1_T[:, g * k1_group_width:(g + 1) * + k1_group_width, :, :] + k2_slice = k2[g * k2_group_width:(g + 1) * k2_group_width, :, :, :] + k_slices.append(F.conv2d(k2_slice, k1_T_slice)) + b_slices.append( + (k2_slice * + b1[g * k1_group_width:(g + 1) * k1_group_width].reshape( + 1, -1, 1, 1)).sum((1, 2, 3))) + k, b_hat = transIV_depthconcat(k_slices, b_slices) + return k, b_hat + b2 + + +def transIV_depthconcat(kernels, biases): + return torch.cat(kernels), torch.cat(biases) + + +def transV_avg(channels, kernel_size, groups): + input_dim = channels // groups + k = torch.zeros((channels, input_dim, kernel_size, kernel_size)) + k[np.arange(channels).tolist(), + np.tile(np.arange(input_dim), groups).tolist( + ), :, :] = 1.0 / kernel_size**2 + return k + + +def transVI_multiscale(kernel, 
target_kernel_size): + """ + NOTE: This has not been tested with non-square kernels + (kernel.size(2) != kernel.size(3)) nor even-size kernels + """ + W_pixels_to_pad = (target_kernel_size - kernel.size(2)) // 2 + H_pixels_to_pad = (target_kernel_size - kernel.size(3)) // 2 + return F.pad( + kernel, + [H_pixels_to_pad, H_pixels_to_pad, W_pixels_to_pad, W_pixels_to_pad]) + + +def transVII_kxk_1x1(k1, b1, k2, b2): + return F.conv2d(k1.permute(1, 0, 2, 3), + k2).permute(1, 0, 2, + 3), (k2 * b1.reshape(-1, 1, 1, 1)).sum( + (1, 2, 3)) + b2 + + +def transIIX_kxk_kxk(k1, b1, k2, b2, groups=1): + k1 = torch.from_numpy( + np.flip(np.flip(np.array(k1), axis=3), axis=2).copy()) + k_size = k1.size(2) + padding = k_size // 2 + 1 + if groups == 1: + k = F.conv2d(k2, k1.permute(1, 0, 2, 3), padding=padding) + b_hat = (k2 * b1.reshape(1, -1, 1, 1)).sum((1, 2, 3)) + else: + k_slices = [] + b_slices = [] + k1_T = k1.permute(1, 0, 2, 3) + k1_group_width = k1.size(0) // groups + k2_group_width = k2.size(0) // groups + for g in range(groups): + k1_T_slice = k1_T[:, g * k1_group_width:(g + 1) * + k1_group_width, :, :] + k2_slice = k2[g * k2_group_width:(g + 1) * k2_group_width, :, :, :] + k_slices.append(F.conv2d(k2_slice, k1_T_slice, padding=padding)) + b_slices.append( + (k2_slice * + b1[g * k1_group_width:(g + 1) * k1_group_width].reshape( + 1, -1, 1, 1)).sum((1, 2, 3))) + k, b_hat = transIV_depthconcat(k_slices, b_slices) + return k, b_hat + b2 + + +def transIX_bn_to_1x1(bn, in_channels, groups=1): + input_dim = in_channels // groups + kernel_value = np.zeros((in_channels, input_dim, 3, 3), dtype=np.float32) + for i in range(in_channels): + kernel_value[i, i % input_dim, 1, 1] = 1 + id_tensor = torch.from_numpy(kernel_value).to(bn.weight.device) + kernel = id_tensor + running_mean = bn.running_mean + running_var = bn.running_var + gamma = bn.weight + beta = bn.bias + eps = bn.eps + std = (running_var + eps).sqrt() + t = (gamma / std).reshape(-1, 1, 1, 1) + return kernel * t, beta - running_mean * gamma / std + diff --git a/classification/lib/models/utils/dbb_converter.py b/classification/lib/models/utils/dbb_converter.py new file mode 100644 index 0000000..ccb62f6 --- /dev/null +++ b/classification/lib/models/utils/dbb_converter.py @@ -0,0 +1,41 @@ +import torch.nn as nn + +from .dbb.dbb_block import DiverseBranchBlock + + +# Convert all the 3x3 convs in the model to DBB blocks +def convert_to_dbb(model, ignore_key=None, dbb_branches=[1, 1, 1, 1, 0, 0, 0]): + named_children = list(model.named_children()) + next_bn = False + for k, m in named_children: + if k == '': + continue + if ignore_key is not None and k.startswith(ignore_key): + continue + if isinstance( + m, nn.Conv2d + ) and m.kernel_size[0] == 3 and m.kernel_size[0] == m.kernel_size[1]: + # dbb_branches = [1, 1, 1, 1, 1, 1, 0] + # dbb_branches = [1, 1, 1, 1, 0, 0, 0] + if m.padding[0] != m.kernel_size[0] // 2: + dbb_branches_ = [0, 1, 1, 1, 0, 0, 0] + else: + dbb_branches_ = dbb_branches + setattr( + model, k, + DiverseBranchBlock(m.in_channels, + m.out_channels, + m.kernel_size[0], + stride=m.stride, + groups=m.groups, + padding=m.padding[0], + ori_conv=None, + branches=dbb_branches_, + use_bn=True)) + next_bn = True + if isinstance(m, nn.BatchNorm2d) and next_bn: + setattr(model, k, nn.Identity()) + next_bn = False + else: + convert_to_dbb(m, ignore_key=None) + print(model) diff --git a/classification/lib/models/utils/dyrep.py b/classification/lib/models/utils/dyrep.py new file mode 100644 index 0000000..55bca4c --- /dev/null +++ 
b/classification/lib/models/utils/dyrep.py @@ -0,0 +1,252 @@ +import random +import logging +import warnings +import torch +import torch.distributed as dist +import torch.nn as nn + +from lib.utils.optim import get_params +from lib.utils.misc import AverageMeter +from .dbb.dbb_block import DiverseBranchBlock + + +logger = logging.getLogger() + + +class DyRep(object): + def __init__(self, + model, + optimizer, + recal_bn_fn=None, + grow_metric='synflow', + dbb_branches=[1, 1, 1, 1, 1, 1, 1], + filter_bias_and_bn=False): + self.model = model + self.recal_bn_fn = recal_bn_fn + self.optimizer = optimizer + self.filter_bias_and_bn = filter_bias_and_bn # used in optimizer get_params + + accept_metrics = ('grad_norm', 'snip', 'synflow', 'random') + assert grow_metric in accept_metrics, \ + f'DyRep supports metrics {accept_metrics}, ' \ + f'but gets {grow_metric}' + self.grow_metric = grow_metric + self.dbb_branches = dbb_branches + # valid dbb branches for conv with unequal shapes of input and output + self.dbb_branches_unequal = [ + v if i not in (0, 4, 5, 6) else 0 + for i, v in enumerate(dbb_branches) + ] + + # dict for recording the metric of each conv modules + self._metric_records = {} + self._weight_records = {} + + self.new_param_group = None + + self.last_growed_module = 'none' + + def _get_module(self, path): + path_split = path.split('.') + m = self.model + for key in path_split: + if not hasattr(m, key): + return None + m = getattr(m, key) + return m + + def record_metrics(self): + for k, m in self.model.named_modules(): + if not isinstance(m, nn.Conv2d) \ + or m.kernel_size[0] != m.kernel_size[1] \ + or m.kernel_size[0] == 1 \ + or k.count('dbb') >= 2: + # Requirements for growing the module: + # 1. the module is a nn.Conv2d module; + # 2. it must has the same kernel_size (>1) in `h` and `w` axes; + # 3. we restrict the number of growths in each layer. 
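+ # modules that fail any of these checks are skipped and never become candidates for growing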
+ continue + + if m.weight.grad is None: + continue + grad = m.weight.grad.data.reshape(-1) + weight = m.weight.data.reshape(-1) + + if self.grow_metric == 'grad_norm': + metric_val = grad.norm().item() + elif self.grow_metric == 'snip': + metric_val = (grad * weight).abs().sum().item() + elif self.grow_metric == 'synflow': + metric_val = (grad * weight).sum().item() + elif self.grow_metric == 'random': + metric_val = random.random() + + if k not in self._metric_records: + self._metric_records[k] = AverageMeter(dist=True) + self._metric_records[k].update(metric_val) + + def _grow(self, metric_records_sorted, topk=1): + if len(metric_records_sorted) == 0: + return + for i in range(topk): + conv_to_grow = metric_records_sorted[i][0] + logger.info('grow: {}'.format(conv_to_grow)) + len_parent_str = conv_to_grow.rfind('.') + if len_parent_str != -1: + parent = conv_to_grow[:len_parent_str] + conv_key = conv_to_grow[len_parent_str + 1:] + # get the target conv module and its parent + parent_m = self._get_module(parent) + else: + conv_key = conv_to_grow + parent_m = self.model + conv_m = getattr(parent_m, conv_key, None) + # replace target conv module with DBB + conv_m_padding = conv_m.padding[0] + conv_m_kernel_size = conv_m.kernel_size[0] + + if conv_m_padding == conv_m_kernel_size // 2: + dbb_branches = self.dbb_branches.copy() + else: + dbb_branches = self.dbb_branches_unequal.copy() + dbb_block = DiverseBranchBlock( + conv_m.in_channels, + conv_m.out_channels, + conv_m_kernel_size, + stride=conv_m.stride, + groups=conv_m.groups, + padding=conv_m_padding, + ori_conv=conv_m, + branches=dbb_branches, + use_bn=True, + bn=nn.BatchNorm2d, + recal_bn_fn=self.recal_bn_fn).cuda() + setattr(parent_m, conv_key, dbb_block) + dbb_block._reset_dbb(conv_m.weight, conv_m.bias) + self.last_growed_module = conv_to_grow + logger.info(str(self.model)) + + def _cut(self, dbb_key, cut_branches, remove_bn=False): + dbb_m = self._get_module(dbb_key) + if dbb_m is None: + return + if sum(cut_branches) == 1: + # only keep the original 3x3 conv + parent = self._get_module(dbb_key[:dbb_key.rfind('.')]) + weight, bias = dbb_m.get_actual_kernel() + conv = nn.Conv2d(dbb_m.in_channels, + dbb_m.out_channels, + dbb_m.kernel_size, + stride=dbb_m.stride, + groups=dbb_m.groups, + padding=dbb_m.padding, + bias=True).cuda() + conv.weight.data = weight + conv.bias.data = bias + setattr(parent, dbb_key[dbb_key.rfind('.') + 1:], conv) + else: + dbb_m.cut_branch(cut_branches) + + def _reset_optimizer(self): + param_groups = get_params(self.model, lr=0.1, weight_decay=1e-5, filter_bias_and_bn=self.filter_bias_and_bn, sort_params=True) + + # remove the states of removed paramters + assert len(param_groups) == len(self.optimizer.param_groups) + for param_group, param_group_old in zip(param_groups, self.optimizer.param_groups): + params, params_old = param_group['params'], param_group_old['params'] + params = set(params) + for param_old in params_old: + if param_old not in params: + if param_old in self.optimizer.state: + del self.optimizer.state[param_old] + param_group_old['params'] = param_group['params'] + + def adjust_model(self): + records = {} + for key in self._metric_records: + records[key] = self._metric_records[key].avg + + metric_records_sorted = sorted(records.items(), + key=lambda item: item[1], + reverse=True) + + logger.info('metric: {}'.format(metric_records_sorted)) + self._grow(metric_records_sorted) + + # reset records + self._metric_records = {} + + for k, m in self.model.named_modules(): + if isinstance(m, 
DiverseBranchBlock): + weights = m.branch_weights() + logger.info(k + ': ' + str(weights)) + valid_weights = torch.tensor( + [x for x in weights[:3] + weights[4:] if x not in [-1, 1]]) + if valid_weights.std() > 0.02: + mean = valid_weights.mean() + # cut those branches less than 0.1 + need_cut = False + cut_branches = [1] * len(weights) + for idx in range(len(weights)): + if weights[idx] < mean and weights[idx] < 0.1: + cut_branches[idx] = 0 + if weights[idx] != -1: + need_cut = True + if need_cut: + self._cut(k, cut_branches) + logger.info( + f'cut: {k}, new branches: {cut_branches}') + + self._reset_optimizer() + + def state_dict(self): + # save dbb graph + res = {} + res['dbb_graph'] = self.dbb_graph() + return res + + def load_state_dict(self, state_dict): + if 'dbb_graph' in state_dict: + self.load_dbb_graph(state_dict['dbb_graph']) + + + def dbb_graph(self): + dbb_list = [] + + def traverse(parent, prefix=''): + for k, m in parent.named_children(): + path = prefix + '.' + k if prefix != '' else k + if isinstance(m, DiverseBranchBlock): + dbb_list.append((path, m.branches)) + traverse(m, prefix=path) + + traverse(self.model) + print(dbb_list) + return dbb_list + + def load_dbb_graph(self, dbb_list: list): + if dbb_list is None or len(dbb_list) == 0: + return + print(dbb_list) + assert not any( + [isinstance(m, DiverseBranchBlock) + for m in self.model.modules()]), 'model must be clean' + for key, branches in dbb_list: + parent = self._get_module(key[:key.rfind('.')]) + conv_key = key[key.rfind('.') + 1:] + conv_m = getattr(parent, conv_key) + dbb_m = DiverseBranchBlock(conv_m.in_channels, + conv_m.out_channels, + conv_m.kernel_size[0], + stride=conv_m.stride, + groups=conv_m.groups, + padding=conv_m.padding[0], + ori_conv=conv_m, + branches=branches, + use_bn=True) + setattr(parent, conv_key, dbb_m) + self.model.cuda() + # reset optimizer + if self.optimizer is not None: + self._reset_optimizer() + # print(self.model) diff --git a/classification/lib/models/utils/recal_bn.py b/classification/lib/models/utils/recal_bn.py new file mode 100644 index 0000000..a348771 --- /dev/null +++ b/classification/lib/models/utils/recal_bn.py @@ -0,0 +1,46 @@ +import torch +import torch.nn as nn +import torch.distributed as dist +import logging + + +logger = logging.getLogger() + + +def recal_bn(model, train_loader, recal_bn_iters=200, module=None): + status = model.training + model.eval() + m = model if module is None else module + if recal_bn_iters > 0: + # recal bn + logger.info(f'recalculating bn stats {recal_bn_iters} iters') + for mod in m.modules(): + if isinstance(mod, nn.BatchNorm2d) or issubclass(mod.__class__, nn.BatchNorm2d): + mod.reset_running_stats() + # for small recal_bn_iters like 20, must set mod.momentum = None + # for big recal_bn_iters like 300, mod.momentum can be 0.1 + mod.momentum = None + mod.train() + + with torch.no_grad(): + cnt = 0 + while cnt < recal_bn_iters: + for i, (images, target) in enumerate(train_loader): + images = images.cuda() + target = target.cuda() + output = model(images) + cnt += 1 + if i % 20 == 0 or cnt == recal_bn_iters: + logger.info(f'recal bn iter {i}') + if cnt >= recal_bn_iters: + break + + for mod in m.modules(): + if isinstance(mod, nn.BatchNorm2d) or issubclass(mod.__class__, nn.BatchNorm2d): + if mod.track_running_stats: + dist.all_reduce(mod.running_mean) + dist.all_reduce(mod.running_var) + mod.running_mean /= dist.get_world_size() + mod.running_var /= dist.get_world_size() + mod.momentum = 0.1 + model.train(status) diff --git 
a/classification/lib/utils/args.py b/classification/lib/utils/args.py new file mode 100644 index 0000000..8fb3b38 --- /dev/null +++ b/classification/lib/utils/args.py @@ -0,0 +1,211 @@ +import argparse +import yaml +import torch + + +class ParseKwargs(argparse.Action): + def __call__(self, parser, namespace, values, option_string=None): + setattr(namespace, self.dest, dict()) + for value in values: + key, value = value.split('=') + getattr(namespace, self.dest)[key] = value + + +config_parser = argparse.ArgumentParser(description='Training Config', add_help=False) +config_parser.add_argument('-c', '--config', default='', type=str, + help='YAML config file specifying default arguments') + + +parser = argparse.ArgumentParser(description='ImageNet Training') + +# Dataset / Model parameters +parser.add_argument('--dataset', default='imagenet', type=str, choices=['cifar10', 'cifar100', 'imagenet'], + help='Dataset to use') +parser.add_argument('--data-path', default='', type=str, + help='Path to load dataset') +parser.add_argument('--model', default='nas_model', type=str, + help='Name of model to train (default: "countception"') +parser.add_argument('--model-config', type=str, default='', + help='Path to net config. Used for NAS model.') +parser.add_argument('--resume', default='', type=str, + help='Resume the states of model, optimizer, etc. in a checkpoint file') +parser.add_argument('-b', '--batch-size', type=int, default=32, + help='input batch size for training (default: 32)') +parser.add_argument('--val-batch-size-multiplier', type=float, default=1.0, + help='batch size of validation data equals to (batch-size * val-batch-size-multiplier)') +parser.add_argument('--auxiliary', action='store_true', default=False, help='use auxiliary tower') +parser.add_argument('--auxiliary_weight', type=float, default=0.4, help='weight for auxiliary loss') +parser.add_argument('--smoothing', default=0.1, type=float, + help='Epsilon value of label smoothing') + +# Optimizer parameters +parser.add_argument('--opt', default='sgd', type=str, + help='Optimizer (default: "sgd"') +parser.add_argument('--opt-eps', default=1e-8, type=float, + help='Optimizer Epsilon (default: 1e-8, use opt default)') +parser.add_argument('--opt-no-filter', action='store_true', default=False, + help='disable bias and bn filter of weight decay') +parser.add_argument('--momentum', type=float, default=0.9, + help='Optimizer momentum (default: 0.9)') +parser.add_argument('--sgd-no-nesterov', action='store_true', default=False, + help='set nesterov=False in SGD optimizer') +parser.add_argument('--weight-decay', type=float, default=0.0001, + help='weight decay (default: 0.0001)') +parser.add_argument('--clip-grad-norm', action='store_true', default=False, + help='clip gradients of network') +parser.add_argument('--clip-grad-max-norm', type=float, default=5., + help='value of max_norm in clip_grad_norm') +parser.add_argument('--amp', action='store_true', default=False, + help='Use automatic mixed precision training (torch.cuda.amp)') + +# Learning rate schedule parameters +parser.add_argument('--sched', default='step', type=str, + help='LR scheduler (default: "step"') +parser.add_argument('--decay-epochs', type=float, default=3, + help='epoch interval to decay LR') +parser.add_argument('--lr', type=float, default=0.01, + help='learning rate (default: 0.01)') +parser.add_argument('--warmup-lr', type=float, default=0.0001, + help='warmup learning rate (default: 0.0001)') +parser.add_argument('--min-lr', type=float, default=1e-5, + 
help='minimal learning rate (default: 1e-5)') +parser.add_argument('--epochs', type=int, default=200, + help='number of epochs to train (default: 2)') +parser.add_argument('--warmup-epochs', type=int, default=3, + help='epochs to warmup LR, if scheduler supports') +parser.add_argument('--decay-rate', '--dr', type=float, + help='LR decay rate') +parser.add_argument('--decay_by_epoch', action='store_true', default=False, + help='decay LR by epoch, valid only for cosine scheduler') + +# Augmentation & regularization parameters +parser.add_argument('--image-mean', type=float, nargs=3, default=None, + help='Mean values of image normalization') +parser.add_argument('--image-std', type=float, nargs=3, default=None, + help='Std values of image normalization') +parser.add_argument('--interpolation', type=str, default='bilinear', choices=['bilinear', 'bicubic'], + help='Interpolation mode in image resize') +parser.add_argument('--color-jitter', type=float, default=0., + help='Color jitter factor (default: 0.)') +parser.add_argument('--cutout-length', type=int, default=0, + help='Cutout length. Only used in CIFAR transforms') +parser.add_argument('--aa', type=str, default=None, + help='Use AutoAugment policy. "v0" or "original". (default: None)'), +parser.add_argument('--reprob', type=float, default=0., + help='Random erase prob (default: 0.)') +parser.add_argument('--remode', type=str, default='const', + help='Random erase mode (default: "const")') +parser.add_argument('--drop', type=float, default=0.0, + help='Dropout rate (default: 0.)') +parser.add_argument('--drop-path-rate', type=float, default=0., + help='Drop path rate, (default: 0.)') +parser.add_argument('--drop-path-strategy', type=str, default='const', choices=['const', 'linear'], + help='Drop path rate update strategy, default: const') + +# Mixup +parser.add_argument('--mixup', type=float, default=0., + help='mixup alpha, mixup enabled if > 0. (default: 0.8)') +parser.add_argument('--cutmix', type=float, default=0., + help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)') +parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None, + help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') +parser.add_argument('--mixup-prob', type=float, default=1.0, + help='Probability of performing mixup or cutmix when either/both is enabled') +parser.add_argument('--mixup-switch-prob', type=float, default=0.5, + help='Probability of switching to cutmix when both mixup and cutmix enabled') +parser.add_argument('--mixup-mode', type=str, default='batch', + help='How to apply mixup/cutmix params. 
Per "batch", "pair", or "elem"') + +# Model Exponential Moving Average +parser.add_argument('--model-ema', action='store_true', default=False, + help='Enable tracking moving average of model weights') +parser.add_argument('--model-ema-decay', type=float, default=0.9998, + help='decay factor for model weights moving average (default: 0.9998)') + +# Misc +parser.add_argument('--seed', type=int, default=42, + help='random seed (default: 42)') +parser.add_argument('--log-interval', type=int, default=50, + help='how many batches to wait before logging training status') +parser.add_argument('-j', '--workers', type=int, default=4, + help='how many training processes to use (default: 1)') +parser.add_argument('--experiment', default='exp', type=str, + help='name of train experiment, name of sub-folder for output') +parser.add_argument('--slurm', action='store_true', default=False, + help='Use slurm') +if torch.__version__ >= '2.0.0': + parser.add_argument('--local-rank', default=0, type=int, + help='local rank of current process in distributed running') +else: + parser.add_argument('--local_rank', default=0, type=int, + help='local rank of current process in distributed running') +parser.add_argument('--dist-port', default='12345', type=str, + help='port for distributed communication') + +# KD +parser.add_argument('--kd', type=str, default='', + help='Knowledge distillation method. Default: disable') +parser.add_argument('--teacher-model', type=str, default='', + help='teacher model name') +parser.add_argument('--teacher-pretrained', action='store_true', + help='load pretrained model of teacher') +parser.add_argument('--teacher-no-pretrained', action='store_false', dest='teacher_pretrained') +parser.set_defaults(teacher_pretrained=True) +parser.add_argument('--teacher-ckpt', type=str, default='', + help='path to the ckpt of teacher model') +parser.add_argument('--kd-loss-weight', type=float, default=1., + help='weight of kd loss') +parser.add_argument('--ori-loss-weight', type=float, default=1., + help='weight of original loss') +parser.add_argument('--teacher-module', type=str, default='', + help='name of the teacher module used in kd. Default (""): use the output of model.') +parser.add_argument('--student-module', type=str, default='', + help='name of the student module used in kd. Default (""): use the output of model.') +parser.add_argument('--kd-loss-kwargs', nargs='*', action=ParseKwargs) + +# DBB +parser.add_argument('--dbb', action='store_true', default=False, + help='Use DBB') + +# DyRep +parser.add_argument('--dyrep', action='store_true', default=False, + help='Use DyRep') +parser.add_argument('--dyrep-adjust-interval', type=int, default=10, + help='how many epochs to rep & dep the dyrep model') +parser.add_argument('--dyrep-max-adjust-epochs', type=int, default=100, + help='after how many epochs the dyrep model will be fixed.') +parser.add_argument('--dyrep-recal-bn-iters', type=int, default=20, + help='how many iterations for recalibrating the bn states in dyrep') +parser.add_argument('--dyrep-recal-bn-every-epoch', action='store_true', default=False, + help='Recal BN after every epoch in DyRep') + +# EdgeNN +parser.add_argument('--edgenn-config', type=str, default='', + help='path to edgenn config') + + +def parse_args(): + # Do we have a config file to parse? 
+ args_config, remaining = config_parser.parse_known_args() + default_dicts = {} + if args_config.config: + with open(args_config.config, 'r') as f: + cfg = yaml.safe_load(f) + parser.set_defaults(**cfg) + + for k, v in cfg.items(): + if isinstance(v, dict): + default_dicts[k] = v + + # The main arg parser parses the rest of the args, the usual + # defaults will have been overridden if config file specified. + args = parser.parse_args(remaining) + for k, v in default_dicts.items(): + v.update(args.__dict__[k]) + args.__dict__[k] = v + + # Cache the args as a text string to save them in the output dir later + args_text = yaml.safe_dump(args.__dict__, default_flow_style=False) + return args, args_text + + diff --git a/classification/lib/utils/dist_utils.py b/classification/lib/utils/dist_utils.py new file mode 100644 index 0000000..081963c --- /dev/null +++ b/classification/lib/utils/dist_utils.py @@ -0,0 +1,77 @@ +import os +import time +import shutil +import logging +import subprocess +import torch + + +def init_dist(args): + args.distributed = False + if 'WORLD_SIZE' in os.environ: + args.distributed = int(os.environ['WORLD_SIZE']) > 1 + if args.slurm: + args.distributed = True + if not args.distributed: + # task with single GPU also needs to use distributed module + os.environ['MASTER_ADDR'] = '127.0.0.1' + os.environ['WORLD_SIZE'] = '1' + os.environ['LOCAL_RANK'] = '0' + os.environ['RANK'] = '0' + args.local_rank = 0 + args.distributed = True + + args.device = 'cuda:0' + args.world_size = 1 + args.rank = 0 # global rank + if args.distributed: + if args.slurm: + # processes are created with slurm + proc_id = int(os.environ['SLURM_PROCID']) + ntasks = int(os.environ['SLURM_NTASKS']) + node_list = os.environ['SLURM_NODELIST'] + num_gpus = torch.cuda.device_count() + addr = subprocess.getoutput( + f'scontrol show hostname {node_list} | head -n1') + os.environ['MASTER_ADDR'] = addr + os.environ['WORLD_SIZE'] = str(ntasks) + args.local_rank = proc_id % num_gpus + os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) + os.environ['RANK'] = str(proc_id) + print(f'Using slurm with master node: {addr}, rank: {proc_id}, world size: {ntasks}') + else: + addr = os.environ['MASTER_ADDR'] + ntasks = os.environ['WORLD_SIZE'] + proc_id = os.environ['RANK'] + args.local_rank = int(os.environ['LOCAL_RANK']) + print(f'Using torch.distributed with master node: {addr}, rank: {proc_id}, local_rank: {args.local_rank} world size: {ntasks}') + + + #os.environ['MASTER_PORT'] = args.dist_port + args.device = 'cuda:%d' % args.local_rank + torch.distributed.init_process_group(backend='nccl', init_method='env://') + args.world_size = torch.distributed.get_world_size() + args.rank = torch.distributed.get_rank() + # if not args.slurm: + torch.cuda.set_device(args.local_rank) + print(f'Training in distributed model with multiple processes, 1 GPU per process. 
Process {args.rank}, total {args.world_size}.') + else: + print('Training with a single process on 1 GPU.') + + +# create logger file handler for rank 0, +# ignore the outputs of the other ranks +def init_logger(args): + logger = logging.getLogger() + if args.rank == 0: + if not os.path.exists(args.exp_dir): + os.makedirs(args.exp_dir) + logger.setLevel(logging.INFO) + fh = logging.FileHandler(os.path.join(args.exp_dir, f'log_{time.strftime("%Y%m%d_%H%M%S", time.localtime())}.txt')) + fh.setFormatter(logging.Formatter(fmt='%(asctime)s %(levelname)s %(message)s', + datefmt='%y-%m-%d %H:%M:%S')) + logger.addHandler(fh) + logger.info(f'Experiment directory: {args.exp_dir}') + + else: + logger.setLevel(logging.ERROR) diff --git a/classification/lib/utils/gen_network.py b/classification/lib/utils/gen_network.py new file mode 100644 index 0000000..b1964f6 --- /dev/null +++ b/classification/lib/utils/gen_network.py @@ -0,0 +1,48 @@ +import yaml +import sys +from itertools import product +import json + + +def gen_network(supernet_fp, subnet, output_fp=''): + supernet = yaml.safe_load(open(supernet_fp, 'r')) + network = supernet.copy() + for layer in network['backbone']: + _, _, _, t, op, *_ = network['backbone'][layer] + if isinstance(t, (list, tuple)) or isinstance(op, (list, tuple)): + if not isinstance(t, (list, tuple)): + t = [t] + if not isinstance(op, (list, tuple)): + op = [op] + has_id = 'id' in op + if has_id: + op.remove('id') + blocks = list(product(t, op)) + if has_id: + blocks += [(1, 'id')] + selected_block = blocks[subnet.pop(0)] + network['backbone'][layer][3] = selected_block[0] + network['backbone'][layer][4] = selected_block[1] + assert len(subnet) == 0 + res = [] + dict_formatter(network, res) + network = '\n'.join(res) + '\n' + if output_fp != '': + open(output_fp, 'w').write(network) + return network + + +def dict_formatter(item, res, indent=0): + if isinstance(item, dict): + for key in item: + if not isinstance(item[key], dict): + res.append(' '*indent + key + ': ' + str(item[key])) + else: + res.append(' '*indent + key + ':') + dict_formatter(item[key], res, indent+4) + + + +if __name__ == '__main__': + print(gen_network(sys.argv[1], list(eval(sys.argv[2])), sys.argv[3])) + diff --git a/classification/lib/utils/measure.py b/classification/lib/utils/measure.py new file mode 100644 index 0000000..9c415d9 --- /dev/null +++ b/classification/lib/utils/measure.py @@ -0,0 +1,67 @@ +import torch + + +def get_params(model, ignore_auxiliary_head=True): + if not ignore_auxiliary_head: + params = sum([m.numel() for m in model.parameters()]) + else: + params = sum([m.numel() for k, m in model.named_parameters() if 'auxiliary_head' not in k]) + return params + +def get_flops(model, input_shape=(3, 224, 224)): + if hasattr(model, 'flops'): + return model.flops(input_shape) + else: + return get_flops_hook(model, input_shape) + +def get_flops_hook(model, input_shape=(3, 224, 224)): + is_training = model.training + list_conv = [] + + def conv_hook(self, input, output): + batch_size, input_channels, input_height, input_width = input[0].size() + output_channels, output_height, output_width = output[0].size() + + assert self.in_channels % self.groups == 0 + + kernel_ops = self.kernel_size[0] * self.kernel_size[ + 1] * (self.in_channels // self.groups) + params = output_channels * kernel_ops + flops = batch_size * params * output_height * output_width + + list_conv.append(flops) + + list_linear = [] + + def linear_hook(self, input, output): + batch_size = input[0].size(0) if input[0].dim() == 2 
else 1 + + weight_ops = self.weight.nelement() + + flops = batch_size * weight_ops + list_linear.append(flops) + + def foo(net, hook_handle): + childrens = list(net.children()) + if not childrens: + if isinstance(net, torch.nn.Conv2d): + hook_handle.append(net.register_forward_hook(conv_hook)) + if isinstance(net, torch.nn.Linear): + hook_handle.append(net.register_forward_hook(linear_hook)) + return + for c in childrens: + foo(c, hook_handle) + + hook_handle = [] + foo(model, hook_handle) + input = torch.rand(*input_shape).unsqueeze(0).to(next(model.parameters()).device) + model.eval() + with torch.no_grad(): + out = model(input) + for handle in hook_handle: + handle.remove() + + total_flops = sum(sum(i) for i in [list_conv, list_linear]) + model.train(is_training) + return total_flops + diff --git a/classification/lib/utils/misc.py b/classification/lib/utils/misc.py new file mode 100644 index 0000000..d9de467 --- /dev/null +++ b/classification/lib/utils/misc.py @@ -0,0 +1,178 @@ +import shutil +import os +import torch +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP + +import logging +logger = logging.getLogger() + +def accuracy(output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].reshape(-1).float().sum(0) + res.append(correct_k.item() / batch_size) + return res + + +class AverageMeter: + """Computes and stores the average and current value""" + def __init__(self, dist=False): + self.dist = dist + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + self.local_val = 0 + self.local_sum = 0 + self.local_count = 0 + + def update(self, val, n=1): + self.local_val = val + self.local_sum += val * n + self.local_count += n + if not self.dist: + self.val = self.local_val + self.sum = self.local_sum + self.count = self.local_count + self.avg = self.sum / self.count + else: + self._dist_reduce() + + def _dist_reduce(self): + '''gather results from all ranks''' + reduce_tensor = torch.Tensor([self.local_val, self.local_sum, self.local_count]).cuda() + dist.all_reduce(reduce_tensor) + world_size = dist.get_world_size() + self.val = reduce_tensor[0].item() / world_size + self.sum = reduce_tensor[1].item() + self.count = reduce_tensor[2].item() + self.avg = self.sum / self.count + + +class CheckpointManager(): + def __init__(self, model, optimizer=None, ema_model=None, save_dir='', keep_num=10, rank=0, additions={}): + self.model = model + self.optimizer = optimizer + self.ema_model = ema_model + self.additions = additions + self.save_dir = save_dir + self.keep_num = keep_num + self.rank = rank + self.ckpts = [] + if self.rank == 0: + if not os.path.exists(save_dir): + os.makedirs(save_dir) + self.metrics_fp = open(os.path.join(save_dir, 'metrics.csv'), 'a') + self.metrics_fp.write('epoch,train_loss,test_loss,top1,top5\n') + + def update(self, epoch, metrics, score_key='top1'): + if self.rank == 0: + self.metrics_fp.write('{},{},{},{},{}\n'.format(epoch, metrics['train_loss'], metrics['test_loss'], metrics['top1'], metrics['top5'])) + self.metrics_fp.flush() + + score = metrics[score_key] + insert_idx = 0 + for ckpt_, score_ in self.ckpts: + if score > score_: + break + insert_idx += 1 + if insert_idx < self.keep_num: + 
save_path = os.path.join(self.save_dir, 'checkpoint-{}.pth.tar'.format(epoch)) + self.ckpts.insert(insert_idx, [save_path, score]) + if len(self.ckpts) > self.keep_num: + remove_ckpt = self.ckpts.pop(-1)[0] + if self.rank == 0: + if os.path.exists(remove_ckpt): + os.remove(remove_ckpt) + self._save(save_path, epoch, is_best=(insert_idx == 0)) + else: + self._save(os.path.join(self.save_dir, 'last.pth.tar'), epoch) + return self.ckpts + + def _save(self, save_path, epoch, is_best=False): + if self.rank != 0: + return + save_dict = { + 'epoch': epoch, + 'model': self.model.module.state_dict() if isinstance(self.model, DDP) else self.model.state_dict(), + 'ema_model': self.ema_model.state_dict() if self.ema_model else None, + 'optimizer': self.optimizer.state_dict() if self.optimizer else None, + } + for key, value in self.additions.items(): + save_dict[key] = value.state_dict() if hasattr(value, 'state_dict') else value + + torch.save(save_dict, save_path) + if save_path != os.path.join(self.save_dir, 'last.pth.tar'): + shutil.copy(save_path, os.path.join(self.save_dir, 'last.pth.tar')) + if is_best: + shutil.copy(save_path, os.path.join(self.save_dir, 'best.pth.tar')) + + def load(self, ckpt_path): + save_dict = torch.load(ckpt_path, map_location='cpu') + + for key, value in self.additions.items(): + if hasattr(value, 'load_state_dict'): + value.load_state_dict(save_dict[key]) + else: + self.additions[key] = save_dict[key] + + if 'state_dict' in save_dict and 'model' not in save_dict: + save_dict['model'] = save_dict['state_dict'] + if isinstance(self.model, DDP): + missing_keys, unexpected_keys = \ + self.model.module.load_state_dict(save_dict['model'], strict=False) + else: + missing_keys, unexpected_keys = \ + self.model.load_state_dict(save_dict['model'], strict=False) + if len(missing_keys) != 0: + logger.info(f'Missing keys in source state dict: {missing_keys}') + if len(unexpected_keys) != 0: + logger.info(f'Unexpected keys in source state dict: {unexpected_keys}') + + if self.ema_model is not None and 'ema_model' in save_dict: + self.ema_model.load_state_dict(save_dict['ema_model']) + if self.optimizer is not None and 'optimizer' in save_dict: + self.optimizer.load_state_dict(save_dict['optimizer']) + + if 'epoch' in save_dict: + epoch = save_dict['epoch'] + else: + epoch = -1 + + '''avoid memory leak''' + del save_dict + torch.cuda.empty_cache() + + return epoch + + +class AuxiliaryOutputBuffer: + + def __init__(self, model, loss_weight=1.0): + self.loss_weight = loss_weight + self.model = model + self.aux_head = model.module.auxiliary_head + self._output = None + self.model.module.module_to_auxiliary.register_forward_hook(lambda net, input, output: self._forward_hook(net, input, output)) + + def _forward_hook(self, net, input, output): + if net.training: + self._output = self.aux_head(output) + + @property + def output(self): + output = self._output + self._output = None + return output diff --git a/classification/lib/utils/model_ema.py b/classification/lib/utils/model_ema.py new file mode 100644 index 0000000..6c5ad4b --- /dev/null +++ b/classification/lib/utils/model_ema.py @@ -0,0 +1,53 @@ +import torch +import torch.nn as nn +from copy import deepcopy + + +class ModelEMA(nn.Module): + """ Model Exponential Moving Average V2 + Implemented by: https://github.com/rwightman/pytorch-image-models/tree/master/timm/utils/model_ema.py + + Keep a moving average of everything in the model state_dict (parameters and buffers). 
+ V2 of this module is simpler, it does not match params/buffers based on name but simply + iterates in order. It works with torchscript (JIT of full model). + + This is intended to allow functionality like + https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage + + A smoothed version of the weights is necessary for some training schemes to perform well. + E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use + RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA + smoothing of weights to match results. Pay attention to the decay constant you are using + relative to your update count per epoch. + + To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but + disable validation of the EMA weights. Validation will have to be done manually in a separate + process, or after the training stops converging. + + This class is sensitive where it is initialized in the sequence of model init, + GPU assignment and distributed training wrappers. + """ + def __init__(self, model, decay=0.9999, device=None): + super(ModelEMA, self).__init__() + # make a copy of the model for accumulating moving average of weights + self.module = deepcopy(model) + self.module.eval() + self.decay = decay + self.device = device # perform ema on different device from model if set + if self.device is not None: + self.module.to(device=device) + + def _update(self, model, update_fn): + with torch.no_grad(): + for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()): + if self.device is not None: + model_v = model_v.to(device=self.device) + ema_v.copy_(update_fn(ema_v, model_v)) + + def update(self, model): + self._update(model, update_fn=lambda e, m: self.decay * e + (1. 
- self.decay) * m) + + def set(self, model): + self._update(model, update_fn=lambda e, m: m) + + diff --git a/classification/lib/utils/optim.py b/classification/lib/utils/optim.py new file mode 100644 index 0000000..0eacaed --- /dev/null +++ b/classification/lib/utils/optim.py @@ -0,0 +1,183 @@ +import torch +import torch.optim as optim + + +def build_optimizer(opt, model, lr, eps=1e-10, momentum=0.9, weight_decay=1e-5, filter_bias_and_bn=True, nesterov=True, sort_params=False): + # params in dyrep must be sorted to make sure optimizer can correctly + # load the states in resuming + params = get_params(model, lr, weight_decay, filter_bias_and_bn, sort_params=sort_params) + + if opt == 'rmsprop': + optimizer = optim.RMSprop(params, lr, eps=eps, weight_decay=weight_decay, momentum=momentum) + elif opt == 'rmsproptf': + optimizer = RMSpropTF(params, lr, eps=eps, weight_decay=weight_decay, momentum=momentum) + elif opt == 'sgd': + optimizer = optim.SGD(params, lr, momentum=momentum, weight_decay=weight_decay, nesterov=nesterov) + elif opt == 'adamw': + optimizer = optim.AdamW(params, lr, eps=eps, weight_decay=weight_decay) + else: + raise NotImplementedError(f'Optimizer {opt} not implemented.') + return optimizer + + +def get_params(model, lr, weight_decay=1e-5, filter_bias_and_bn=True, sort_params=False): + if weight_decay != 0 and filter_bias_and_bn: + if hasattr(model, 'no_weight_decay'): + skip_list = model.no_weight_decay() + print(f'no weight decay: {skip_list}') + else: + skip_list = () + params = _add_weight_decay(model, lr, weight_decay, skip_list=skip_list, sort_params=sort_params) + weight_decay = 0 + else: + named_params = list(model.named_parameters()) + if sort_params: + named_params.sort(key=lambda x: x[0]) + params = [x[1] for x in named_params] + params = [{'params': params, 'initial_lr': lr}] + return params + + +def _add_weight_decay(model, lr, weight_decay=1e-5, skip_list=(), sort_params=False): + decay = [] + no_decay = [] + named_params = list(model.named_parameters()) + if sort_params: + named_params.sort(key=lambda x: x[0]) + for name, param in named_params: + if not param.requires_grad: + continue # frozen weights + skip = False + for skip_name in skip_list: + if skip_name.startswith('[g]'): + if skip_name[3:] in name: + skip = True + elif name == skip_name: + skip = True + if len(param.shape) == 1 or name.endswith(".bias") or skip: + no_decay.append(param) + else: + decay.append(param) + return [ + {'params': no_decay, 'weight_decay': 0., 'initial_lr': lr}, + {'params': decay, 'weight_decay': weight_decay, 'initial_lr': lr}] + + +class RMSpropTF(optim.Optimizer): + """Implements RMSprop algorithm (TensorFlow style epsilon) + NOTE: This is a direct cut-and-paste of PyTorch RMSprop with eps applied before sqrt + and a few other modifications to closer match Tensorflow for matching hyper-params. + Noteworthy changes include: + 1. Epsilon applied inside square-root + 2. square_avg initialized to ones + 3. LR scaling of update accumulated in momentum buffer + Proposed by G. Hinton in his + `course `_. + The centered version first appears in `Generating Sequences + With Recurrent Neural Networks `_. 
+ Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-2) + momentum (float, optional): momentum factor (default: 0) + alpha (float, optional): smoothing (decay) constant (default: 0.9) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-10) + centered (bool, optional) : if ``True``, compute the centered RMSProp, + the gradient is normalized by an estimation of its variance + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + decoupled_decay (bool, optional): decoupled weight decay as per https://arxiv.org/abs/1711.05101 + lr_in_momentum (bool, optional): learning rate scaling is included in the momentum buffer + update as per defaults in Tensorflow + """ + + def __init__(self, params, lr=1e-2, alpha=0.9, eps=1e-10, weight_decay=0, momentum=0., centered=False, + decoupled_decay=False, lr_in_momentum=True): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= momentum: + raise ValueError("Invalid momentum value: {}".format(momentum)) + if not 0.0 <= weight_decay: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + if not 0.0 <= alpha: + raise ValueError("Invalid alpha value: {}".format(alpha)) + + defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay, + decoupled_decay=decoupled_decay, lr_in_momentum=lr_in_momentum) + super(RMSpropTF, self).__init__(params, defaults) + + def __setstate__(self, state): + super(RMSpropTF, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('momentum', 0) + group.setdefault('centered', False) + + def step(self, closure=None): + """Performs a single optimization step. + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError('RMSprop does not support sparse gradients') + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + state['square_avg'] = torch.ones_like(p.data) # PyTorch inits to zero + if group['momentum'] > 0: + state['momentum_buffer'] = torch.zeros_like(p.data) + if group['centered']: + state['grad_avg'] = torch.zeros_like(p.data) + + square_avg = state['square_avg'] + one_minus_alpha = 1. 
- group['alpha'] + + state['step'] += 1 + + if group['weight_decay'] != 0: + if 'decoupled_decay' in group and group['decoupled_decay']: + p.data.add_(-group['weight_decay'], p.data) + else: + grad = grad.add(group['weight_decay'], p.data) + + # Tensorflow order of ops for updating squared avg + square_avg.add_(one_minus_alpha, grad.pow(2) - square_avg) + # square_avg.mul_(alpha).addcmul_(1 - alpha, grad, grad) # PyTorch original + + if group['centered']: + grad_avg = state['grad_avg'] + grad_avg.add_(one_minus_alpha, grad - grad_avg) + # grad_avg.mul_(alpha).add_(1 - alpha, grad) # PyTorch original + avg = square_avg.addcmul(-1, grad_avg, grad_avg).add(group['eps']).sqrt_() # eps moved in sqrt + else: + avg = square_avg.add(group['eps']).sqrt_() # eps moved in sqrt + + if group['momentum'] > 0: + buf = state['momentum_buffer'] + # Tensorflow accumulates the LR scaling in the momentum buffer + if 'lr_in_momentum' in group and group['lr_in_momentum']: + buf.mul_(group['momentum']).addcdiv_(group['lr'], grad, avg) + p.data.add_(-buf) + else: + # PyTorch scales the param update by LR + buf.mul_(group['momentum']).addcdiv_(grad, avg) + p.data.add_(-group['lr'], buf) + else: + p.data.addcdiv_(-group['lr'], grad, avg) + + return loss + + diff --git a/classification/lib/utils/scheduler.py b/classification/lib/utils/scheduler.py new file mode 100644 index 0000000..1c8cc3c --- /dev/null +++ b/classification/lib/utils/scheduler.py @@ -0,0 +1,99 @@ +from collections import OrderedDict +from torch.optim.lr_scheduler import _LRScheduler +from torch.optim.lr_scheduler import ReduceLROnPlateau +from torch.optim.lr_scheduler import StepLR, CosineAnnealingLR, LambdaLR + + +def build_scheduler(sched_type, optimizer, warmup_steps, warmup_lr, step_size, decay_rate, total_steps=-1, multiplier=1, steps_per_epoch=1, decay_by_epoch=True, min_lr=1e-5): + if sched_type == 'step': + scheduler = StepLR(optimizer, step_size, gamma=decay_rate) + decay_by_epoch = False + elif sched_type == 'cosine': + scheduler = CosineAnnealingLR(optimizer, T_max=total_steps - warmup_steps, eta_min=min_lr) + elif sched_type == 'linear': + scheduler = LambdaLR(optimizer, lambda epoch: (total_steps - warmup_steps - epoch) / (total_steps - warmup_steps)) + else: + raise NotImplementedError(f'Scheduler {sched_type} not implemented.') + scheduler = GradualWarmupScheduler(optimizer, multiplier=multiplier, total_epoch=warmup_steps, after_scheduler=scheduler, warmup_lr=warmup_lr, step_size=steps_per_epoch, decay_by_epoch=decay_by_epoch) + return scheduler + + +class GradualWarmupScheduler(_LRScheduler): + """ Gradually warm-up(increasing) learning rate in optimizer. + Modified based on: https://github.com/ildoonet/pytorch-gradual-warmup-lr + Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'. + Args: + optimizer (Optimizer): Wrapped optimizer. + multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. if multiplier = 1.0, lr starts from 0 and ends up with the base_lr. + total_epoch: target learning rate is reached at total_epoch, gradually + after_scheduler: after target_epoch, use this scheduler(eg. 
ReduceLROnPlateau) + warmup_lr: warmup learning rate for the first epoch + step_size: step number in one epoch + decay_by_epoch: if True, decay lr in after_scheduler after each epoch; otherwise decay after every step + """ + + def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None, warmup_lr=1e-6, step_size=1, decay_by_epoch=True): + self.multiplier = multiplier + if self.multiplier < 1.: + raise ValueError('multiplier should be greater thant or equal to 1.') + self.total_epoch = total_epoch + self.after_scheduler = after_scheduler + self.warmup_lr = warmup_lr + self.step_size = step_size + self.finished = False + if self.total_epoch == 0: + self.finished = True + self.total_epoch = -1 + self.decay_by_epoch = decay_by_epoch + super(GradualWarmupScheduler, self).__init__(optimizer) + + def get_lr(self): + if self.last_epoch > self.total_epoch or self.finished: + if self.after_scheduler: + if not self.finished: + self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs] + self.finished = True + return self.after_scheduler.get_lr() + return [base_lr * self.multiplier for base_lr in self.base_lrs] + + if self.multiplier == 1.0: + return [self.warmup_lr + (base_lr - self.warmup_lr) * (float(self.last_epoch // self.step_size * self.step_size) / self.total_epoch) for base_lr in self.base_lrs] + else: + return [base_lr * ((self.multiplier - 1.) * (self.last_epoch // self.step_size * self.step_size) / self.total_epoch + 1.) for base_lr in self.base_lrs] + + def step_ReduceLROnPlateau(self, metrics, epoch=None): + if epoch is None: + epoch = self.last_epoch + 1 + if self.last_epoch <= self.total_epoch: + if self.multiplier == 1.0: + warmup_lr = [self.warmup_lr + (base_lr - self.warmup_lr) * (float(self.last_epoch // self.step_size * self.step_size) / self.total_epoch) for base_lr in self.base_lrs] + else: + warmup_lr = [base_lr * ((self.multiplier - 1.) * (self.last_epoch // self.step_size * self.step_size) / self.total_epoch + 1.) 
for base_lr in self.base_lrs] + for param_group, lr in zip(self.optimizer.param_groups, warmup_lr): + param_group['lr'] = lr + else: + if epoch is None: + self.after_scheduler.step(metrics, None) + else: + if self.decay_by_epoch: + self.after_scheduler.step(metrics, (epoch - self.total_epoch - 1) // self.step_size * self.step_size) + else: + self.after_scheduler.step(metrics, epoch - self.total_epoch - 1) + + def step(self, epoch=None, metrics=None): + if type(self.after_scheduler) != ReduceLROnPlateau: + if self.finished and self.after_scheduler: + if epoch is None: + self.after_scheduler.step(None) + else: + if self.decay_by_epoch: + self.after_scheduler.step((epoch - self.total_epoch - 1) // self.step_size * self.step_size) + else: + self.after_scheduler.step(epoch - self.total_epoch - 1) + self._last_lr = self.after_scheduler.get_lr() + else: + return super(GradualWarmupScheduler, self).step(epoch) + else: + self.step_ReduceLROnPlateau(metrics, epoch) + + diff --git a/classification/tools/convert.py b/classification/tools/convert.py new file mode 100644 index 0000000..f7d6495 --- /dev/null +++ b/classification/tools/convert.py @@ -0,0 +1,162 @@ +import os +import torch +import torch.nn as nn +import logging +import time +from torch.nn.parallel import DistributedDataParallel as DDP + +from lib.models.builder import build_model +from lib.models.loss import CrossEntropyLabelSmooth +from lib.models.utils.dbb.dbb_block import DiverseBranchBlock +from lib.dataset.builder import build_dataloader +from lib.utils.args import parse_args +from lib.utils.dist_utils import init_dist, init_logger +from lib.utils.misc import accuracy, AverageMeter, CheckpointManager +from lib.utils.model_ema import ModelEMA +from lib.utils.measure import get_params, get_flops + +torch.backends.cudnn.benchmark = True +'''init logger''' +logging.basicConfig(format='%(asctime)s %(levelname)s %(message)s', + datefmt='%H:%M:%S') +logger = logging.getLogger() +logger.setLevel(logging.INFO) + + +def main(): + args, args_text = parse_args() + assert args.resume != '' + args.exp_dir = f'{os.path.dirname(args.resume)}/convert' + + '''distributed''' + init_dist(args) + init_logger(args) + + '''build dataloader''' + train_dataset, val_dataset, train_loader, val_loader = \ + build_dataloader(args) + + '''build model''' + if args.smoothing == 0.: + loss_fn = nn.CrossEntropyLoss().cuda() + else: + loss_fn = CrossEntropyLabelSmooth(num_classes=args.num_classes, + epsilon=args.smoothing).cuda() + + model = build_model(args) + logger.info( + f'Model {args.model} created, params: {get_params(model)}, ' + f'FLOPs: {get_flops(model, input_shape=args.input_shape)}') + + # Diverse Branch Blocks + if args.dbb: + # convert 3x3 convs to dbb blocks + from lib.models.utils.dbb_converter import convert_to_dbb + convert_to_dbb(model) + logger.info(model) + logger.info( + f'Converted to DBB blocks, model params: {get_params(model)}, ' + f'FLOPs: {get_flops(model, input_shape=args.input_shape)}') + + model.cuda() + model = DDP(model, + device_ids=[args.local_rank], + find_unused_parameters=False) + + if args.model_ema: + model_ema = ModelEMA(model, decay=args.model_ema_decay) + else: + model_ema = None + + '''dyrep''' + if args.dyrep: + from lib.models.utils.dyrep import DyRep + dyrep = DyRep( + model.module, + None) + logger.info('Init DyRep done.') + else: + dyrep = None + + '''resume''' + ckpt_manager = CheckpointManager(model, + ema_model=model_ema, + save_dir=args.exp_dir, + rank=args.rank, + additions={ + 'dyrep': dyrep + }) + + if 
args.resume: + epoch = ckpt_manager.load(args.resume) + if args.dyrep: + model = DDP(model.module, + device_ids=[args.local_rank], + find_unused_parameters=True) + logger.info( + f'Resume ckpt {args.resume} done, ' + f'epoch {epoch}' + ) + else: + epoch = 0 + + # validate + test_metrics = validate(args, epoch, model, val_loader, loss_fn) + # convert dyrep / dbb model to inference model + for m in model.module.modules(): + if isinstance(m, DiverseBranchBlock): + m.switch_to_deploy() + logger.info(str(model)) + logger.info( + f'Converted DBB / DyRep model to inference model, params: {get_params(model)}, ' + f'FLOPs: {get_flops(model, input_shape=args.input_shape)}') + test_metrics = validate(args, epoch, model, val_loader, loss_fn) + + '''save converted checkpoint''' + if args.rank == 0: + save_path = os.path.join(args.exp_dir, 'model.ckpt') + torch.save(model.module.state_dict(), save_path) + logger.info(f'Saved converted model checkpoint into {save_path} .') + + +def validate(args, epoch, model, loader, loss_fn, log_suffix=''): + loss_m = AverageMeter(dist=True) + top1_m = AverageMeter(dist=True) + top5_m = AverageMeter(dist=True) + batch_time_m = AverageMeter(dist=True) + start_time = time.time() + + model.eval() + for batch_idx, (input, target) in enumerate(loader): + with torch.no_grad(): + output = model(input) + loss = loss_fn(output, target) + + top1, top5 = accuracy(output, target, topk=(1, 5)) + loss_m.update(loss.item(), n=input.size(0)) + top1_m.update(top1 * 100, n=input.size(0)) + top5_m.update(top5 * 100, n=input.size(0)) + + batch_time = time.time() - start_time + batch_time_m.update(batch_time) + if batch_idx % args.log_interval == 0 or batch_idx == len(loader) - 1: + logger.info('Test{}: {} [{:>4d}/{}] ' + 'Loss: {loss.val:.3f} ({loss.avg:.3f}) ' + 'Top-1: {top1.val:.3f}% ({top1.avg:.3f}%) ' + 'Top-5: {top5.val:.3f}% ({top5.avg:.3f}%) ' + 'Time: {batch_time.val:.2f}s'.format( + log_suffix, + epoch, + batch_idx, + len(loader), + loss=loss_m, + top1=top1_m, + top5=top5_m, + batch_time=batch_time_m)) + start_time = time.time() + + return {'test_loss': loss_m.avg, 'top1': top1_m.avg, 'top5': top5_m.avg} + + +if __name__ == '__main__': + main() diff --git a/classification/tools/convert_ckpt.py b/classification/tools/convert_ckpt.py new file mode 100644 index 0000000..196611b --- /dev/null +++ b/classification/tools/convert_ckpt.py @@ -0,0 +1,14 @@ +import sys +import torch + + +ori_path = sys.argv[1] +new_path = sys.argv[2] +key = sys.argv[3] + +ckpt = torch.load(ori_path, map_location='cpu') +new_ckpt = {} + +new_ckpt['state_dict'] = ckpt[key] + +torch.save(new_ckpt, new_path) diff --git a/classification/tools/dist_run.sh b/classification/tools/dist_run.sh new file mode 100644 index 0000000..3019bbf --- /dev/null +++ b/classification/tools/dist_run.sh @@ -0,0 +1,11 @@ +#!/bin/bash +ENTRY=$1 +GPUS=$2 +CONFIG=$3 +MODEL=$4 +PY_ARGS=${@:5} + +set -x + +python -m torch.distributed.launch --nproc_per_node=${GPUS} \ + ${ENTRY} -c ${CONFIG} --model ${MODEL} ${PY_ARGS} diff --git a/classification/tools/dist_train.sh b/classification/tools/dist_train.sh new file mode 100644 index 0000000..5f8a79d --- /dev/null +++ b/classification/tools/dist_train.sh @@ -0,0 +1,15 @@ +#!/bin/bash +GPUS=$1 +CONFIG=$2 +MODEL=$3 +PY_ARGS=${@:4} + +MASTER_PORT=29500 + +set -x + +# NOTE: This script only supports run on single machine and single (multiple) GPUs. +# You may need to modify it to support multi-machine multi-card training on your distributed platform. 
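# Example invocation (illustrative config path / model name, not from the patch):
#   sh tools/dist_train.sh 8 path/to/config.yaml resnet50 --dyrep --experiment dyrep_r50
# i.e. the number of GPUs, the YAML config passed to -c, the model name passed
# to --model, and any remaining flags forwarded to tools/train.py.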
+ +python -m torch.distributed.launch --nproc_per_node=${GPUS} \ + tools/train.py -c ${CONFIG} --model ${MODEL} ${PY_ARGS} diff --git a/classification/tools/slurm_run.sh b/classification/tools/slurm_run.sh new file mode 100644 index 0000000..0693e0a --- /dev/null +++ b/classification/tools/slurm_run.sh @@ -0,0 +1,20 @@ +#!/usr/bin/env bash +ENTRY=$1 +PARTITION=$2 +GPUS=$3 +CONFIG=$4 +MODEL=$5 +PY_ARGS=${@:6} + +N=${GPUS} +if [ ${GPUS} -gt 8 ] +then + echo "multi machine" + N=8 +fi + +set -x + +PYTHONPATH=$PWD:$PYTHONPATH PYTHONWARNINGS=ignore GLOG_logtostderr=-1 GLOG_vmodule=MemcachedClient=-1 OMPI_MCA_btl_smcuda_use_cuda_ipc=0 OMPI_MCA_mpi_warn_on_fork=0 \ + srun --mpi=pmi2 --job-name train --partition=${PARTITION} -n${GPUS} --gres=gpu:${N} --ntasks-per-node=${N} \ + python -u ${ENTRY} -c ${CONFIG} --model ${MODEL} --slurm ${PY_ARGS} diff --git a/classification/tools/slurm_train.sh b/classification/tools/slurm_train.sh new file mode 100644 index 0000000..05e8ba2 --- /dev/null +++ b/classification/tools/slurm_train.sh @@ -0,0 +1,19 @@ +#!/usr/bin/env bash +PARTITION=$1 +GPUS=$2 +CONFIG=$3 +MODEL=$4 +PY_ARGS=${@:5} + +N=${GPUS} +if [ ${GPUS} -gt 8 ] +then + echo "multi machine" + N=8 +fi + +set -x + +PYTHONPATH=$PWD:$PYTHONPATH PYTHONWARNINGS=ignore GLOG_logtostderr=-1 GLOG_vmodule=MemcachedClient=-1 OMPI_MCA_btl_smcuda_use_cuda_ipc=0 OMPI_MCA_mpi_warn_on_fork=0 \ + srun --mpi=pmi2 --job-name train --partition=${PARTITION} -n${GPUS} --gres=gpu:${N} --ntasks-per-node=${N} \ + python -u tools/train.py -c ${CONFIG} --model ${MODEL} --slurm ${PY_ARGS} diff --git a/classification/tools/speed_test.py b/classification/tools/speed_test.py new file mode 100644 index 0000000..475e496 --- /dev/null +++ b/classification/tools/speed_test.py @@ -0,0 +1,78 @@ +import os +import torch +import torch.nn as nn +import logging +import time +import random +import numpy as np +from torch.nn.parallel import DistributedDataParallel as DDP + +from lib.models.builder import build_model +from lib.utils.args import parse_args +from lib.utils.dist_utils import init_dist, init_logger +from lib.utils.measure import get_params, get_flops + +torch.backends.cudnn.benchmark = True + +'''init logger''' +logging.basicConfig(format='%(asctime)s %(levelname)s %(message)s', + datefmt='%H:%M:%S') +logger = logging.getLogger() +logger.setLevel(logging.INFO) + + +def main(): + args, args_text = parse_args() + args.input_shape = (3, 224, 224) + + '''fix random seed''' + seed = args.seed + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + np.random.seed(seed) + random.seed(seed) + # torch.backends.cudnn.deterministic = True + + '''build model''' + model = build_model(args, args.model) + logger.info( + f'Model {args.model} created, params: {get_params(model)}, ' + f'FLOPs: {get_flops(model, input_shape=args.input_shape)}') + + # Diverse Branch Blocks + if args.dbb: + # convert 3x3 convs to dbb blocks + from lib.models.utils.dbb_converter import convert_to_dbb + convert_to_dbb(model) + logger.info(model) + logger.info( + f'Converted to DBB blocks, model params: {get_params(model)}, ' + f'FLOPs: {get_flops(model, input_shape=args.input_shape)}') + + speed_test(model, batch_size=args.batch_size, input_shape=args.input_shape) + + +def speed_test(model, warmup_iters=100, n_iters=1000, batch_size=128, input_shape=(3, 224, 224), device='cuda'): + device = torch.device(device) + model.to(device) + model.eval() + x = torch.randn((batch_size, *input_shape), device=device) + + with torch.no_grad(): + for _ in range(warmup_iters): + 
model(x) + logger.info('Start measuring speed.') + torch.cuda.synchronize() + t = time.time() + for i in range(n_iters): + model(x) + torch.cuda.synchronize() + total_time = time.time() - t + total_samples = batch_size * n_iters + speed = total_samples / total_time + logger.info(f'Done, n_iters: {n_iters}, batch size: {batch_size}, image shape: {input_shape}') + logger.info(f'total time: {total_time} s, total samples: {total_samples}, throughput: {speed:.3f} samples/second.') + + +if __name__ == '__main__': + main() diff --git a/classification/tools/test.py b/classification/tools/test.py new file mode 100644 index 0000000..b614779 --- /dev/null +++ b/classification/tools/test.py @@ -0,0 +1,125 @@ +import os +import torch +import torch.nn as nn +import logging +import time +import numpy as np +from torch.nn.parallel import DistributedDataParallel as DDP + +from lib.models.builder import build_model +from lib.dataset.builder import build_dataloader +from lib.utils.args import parse_args +from lib.utils.dist_utils import init_dist, init_logger +from lib.utils.misc import accuracy, AverageMeter, CheckpointManager +from lib.utils.model_ema import ModelEMA +from lib.utils.measure import get_params, get_flops + +torch.backends.cudnn.benchmark = True + +'''init logger''' +logging.basicConfig(format='%(asctime)s %(levelname)s %(message)s', + datefmt='%H:%M:%S') +logger = logging.getLogger() +logger.setLevel(logging.INFO) + + +def main(): + args, args_text = parse_args() + args.exp_dir = f'experiments/{args.experiment}' + + '''distributed''' + init_dist(args) + init_logger(args) + + '''build dataloader''' + _, val_dataset, _, val_loader = \ + build_dataloader(args) + + '''build model''' + loss_fn = nn.CrossEntropyLoss().cuda() + val_loss_fn = loss_fn + + model = build_model(args, args.model) + logger.info(model) + logger.info( + f'Model {args.model} created, params: {get_params(model) / 1e6:.3f} M, ' + f'FLOPs: {get_flops(model, input_shape=args.input_shape) / 1e9:.3f} G') + + model.cuda() + model = DDP(model, + device_ids=[args.local_rank], + find_unused_parameters=False) + + if args.model_ema: + model_ema = ModelEMA(model, decay=args.model_ema_decay) + else: + model_ema = None + + '''resume''' + ckpt_manager = CheckpointManager(model, + ema_model=model_ema, + save_dir=args.exp_dir, + rank=args.rank) + + if args.resume: + epoch = ckpt_manager.load(args.resume) + logger.info( + f'Resume ckpt {args.resume} done, ' + f'epoch {epoch}' + ) + else: + epoch = 0 + + # validate + test_metrics = validate(args, epoch, model, val_loader, val_loss_fn) + if model_ema is not None: + test_metrics = validate(args, + epoch, + model_ema.module, + val_loader, + loss_fn, + log_suffix='(EMA)') + logger.info(test_metrics) + + +def validate(args, epoch, model, loader, loss_fn, log_suffix=''): + loss_m = AverageMeter(dist=True) + top1_m = AverageMeter(dist=True) + top5_m = AverageMeter(dist=True) + batch_time_m = AverageMeter(dist=True) + start_time = time.time() + + model.eval() + for batch_idx, (input, target) in enumerate(loader): + with torch.no_grad(): + output = model(input) + loss = loss_fn(output, target) + + top1, top5 = accuracy(output, target, topk=(1, 5)) + loss_m.update(loss.item(), n=input.size(0)) + top1_m.update(top1 * 100, n=input.size(0)) + top5_m.update(top5 * 100, n=input.size(0)) + + batch_time = time.time() - start_time + batch_time_m.update(batch_time) + if batch_idx % args.log_interval == 0 or batch_idx == len(loader) - 1: + logger.info('Test{}: {} [{:>4d}/{}] ' + 'Loss: {loss.val:.3f} 
({loss.avg:.3f}) ' + 'Top-1: {top1.val:.3f}% ({top1.avg:.3f}%) ' + 'Top-5: {top5.val:.3f}% ({top5.avg:.3f}%) ' + 'Time: {batch_time.val:.2f}s'.format( + log_suffix, + epoch, + batch_idx, + len(loader), + loss=loss_m, + top1=top1_m, + top5=top5_m, + batch_time=batch_time_m)) + start_time = time.time() + + return {'test_loss': loss_m.avg, 'top1': top1_m.avg, 'top5': top5_m.avg} + + +if __name__ == '__main__': + main() diff --git a/classification/tools/train.py b/classification/tools/train.py new file mode 100644 index 0000000..70d4fae --- /dev/null +++ b/classification/tools/train.py @@ -0,0 +1,406 @@ +import os +import torch +import torch.nn as nn +import logging +import time +import random +import numpy as np +from torch.nn.parallel import DistributedDataParallel as DDP + +from lib.models.builder import build_model +from lib.models.losses import CrossEntropyLabelSmooth, \ + SoftTargetCrossEntropy +from lib.dataset.builder import build_dataloader +from lib.utils.optim import build_optimizer +from lib.utils.scheduler import build_scheduler +from lib.utils.args import parse_args +from lib.utils.dist_utils import init_dist, init_logger +from lib.utils.misc import accuracy, AverageMeter, \ + CheckpointManager, AuxiliaryOutputBuffer +from lib.utils.model_ema import ModelEMA +from lib.utils.measure import get_params, get_flops + +try: + # need `pip install nvidia-ml-py3` to measure gpu stats + import nvidia_smi + nvidia_smi.nvmlInit() + handle = nvidia_smi.nvmlDeviceGetHandleByIndex(0) + _has_nvidia_smi = True +except ModuleNotFoundError: + _has_nvidia_smi = False + + +torch.backends.cudnn.benchmark = True + +'''init logger''' +logging.basicConfig(format='%(asctime)s %(levelname)s %(message)s', + datefmt='%H:%M:%S') +logger = logging.getLogger() +logger.setLevel(logging.INFO) + + +def main(): + args, args_text = parse_args() + args.exp_dir = f'experiments/{args.experiment}' + + '''distributed''' + init_dist(args) + init_logger(args) + + # save args + logger.info(args) + if args.rank == 0: + with open(os.path.join(args.exp_dir, 'args.yaml'), 'w') as f: + f.write(args_text) + + '''fix random seed''' + seed = args.seed + args.rank + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + np.random.seed(seed) + random.seed(seed) + # torch.backends.cudnn.deterministic = True + + '''build dataloader''' + train_dataset, val_dataset, train_loader, val_loader = \ + build_dataloader(args) + + '''build model''' + if args.mixup > 0. 
or args.cutmix > 0 or args.cutmix_minmax is not None: + loss_fn = SoftTargetCrossEntropy() + elif args.smoothing == 0.: + loss_fn = nn.CrossEntropyLoss().cuda() + else: + loss_fn = CrossEntropyLabelSmooth(num_classes=args.num_classes, + epsilon=args.smoothing).cuda() + val_loss_fn = loss_fn + + model = build_model(args, args.model) + logger.info( + f'Model {args.model} created, params: {get_params(model) / 1e6:.3f} M, ' + f'FLOPs: {get_flops(model, input_shape=args.input_shape) / 1e9:.3f} G') + + # logger.info( + # f'Model {args.model} created, params: {get_params(model) / 1e6:.3f} M') + + # Diverse Branch Blocks + if args.dbb: + # convert 3x3 convs to dbb blocks + from lib.models.utils.dbb_converter import convert_to_dbb + convert_to_dbb(model) + logger.info(model) + logger.info( + f'Converted to DBB blocks, model params: {get_params(model) / 1e6:.3f} M, ' + f'FLOPs: {get_flops(model, input_shape=args.input_shape) / 1e9:.3f} G') + + model.cuda() + + + # knowledge distillation + if args.kd != '': + # build teacher model + teacher_model = build_model(args, args.teacher_model, args.teacher_pretrained, args.teacher_ckpt) + logger.info( + f'Teacher model {args.teacher_model} created, params: {get_params(teacher_model) / 1e6:.3f} M, ' + f'FLOPs: {get_flops(teacher_model, input_shape=args.input_shape) / 1e9:.3f} G') + teacher_model.cuda() + test_metrics = validate(args, 0, teacher_model, val_loader, val_loss_fn, log_suffix=' (teacher)') + logger.info(f'Top-1 accuracy of teacher model {args.teacher_model}: {test_metrics["top1"]:.2f}') + + # build kd loss + from lib.models.losses.kd_loss import KDLoss + loss_fn = KDLoss(model, teacher_model, args.model, args.teacher_model, loss_fn, + args.kd, args.ori_loss_weight, args.kd_loss_weight, args.kd_loss_kwargs) + + model = DDP(model, + device_ids=[args.local_rank], + find_unused_parameters=False) + from torch.distributed.algorithms.ddp_comm_hooks import default as comm_hooks + model.register_comm_hook(None, comm_hooks.fp16_compress_hook) + if args.kd != '': + loss_fn.student = model + logger.info(model) + + if args.model_ema: + model_ema = ModelEMA(model, decay=args.model_ema_decay) + else: + model_ema = None + + '''build optimizer''' + optimizer = build_optimizer(args.opt, + model.module, + args.lr, + eps=args.opt_eps, + momentum=args.momentum, + weight_decay=args.weight_decay, + filter_bias_and_bn=not args.opt_no_filter, + nesterov=not args.sgd_no_nesterov, + sort_params=args.dyrep) + + '''build scheduler''' + steps_per_epoch = len(train_loader) + warmup_steps = args.warmup_epochs * steps_per_epoch + decay_steps = args.decay_epochs * steps_per_epoch + total_steps = args.epochs * steps_per_epoch + scheduler = build_scheduler(args.sched, + optimizer, + warmup_steps, + args.warmup_lr, + decay_steps, + args.decay_rate, + total_steps, + steps_per_epoch=steps_per_epoch, + decay_by_epoch=args.decay_by_epoch, + min_lr=args.min_lr) + + '''dyrep''' + if args.dyrep: + from lib.models.utils.dyrep import DyRep + from lib.models.utils.recal_bn import recal_bn + dyrep = DyRep( + model.module, + optimizer, + recal_bn_fn=lambda m: recal_bn(model.module, train_loader, + args.dyrep_recal_bn_iters, m), + filter_bias_and_bn=not args.opt_no_filter) + logger.info('Init DyRep done.') + else: + dyrep = None + + '''amp''' + if args.amp: + loss_scaler = torch.cuda.amp.GradScaler() + else: + loss_scaler = None + + '''resume''' + ckpt_manager = CheckpointManager(model, + optimizer, + ema_model=model_ema, + save_dir=args.exp_dir, + rank=args.rank, + additions={ + 'scaler': 
loss_scaler, + 'dyrep': dyrep + }) + + if args.resume: + start_epoch = ckpt_manager.load(args.resume) + 1 + if start_epoch > args.warmup_epochs: + scheduler.finished = True + scheduler.step(start_epoch * len(train_loader)) + if args.dyrep: + model = DDP(model.module, + device_ids=[args.local_rank], + find_unused_parameters=True) + logger.info( + f'Resume ckpt {args.resume} done, ' + f'start training from epoch {start_epoch}' + ) + else: + start_epoch = 0 + + '''auxiliary tower''' + if args.auxiliary: + auxiliary_buffer = AuxiliaryOutputBuffer(model, args.auxiliary_weight) + else: + auxiliary_buffer = None + + '''train & val''' + for epoch in range(start_epoch, args.epochs): + train_loader.loader.sampler.set_epoch(epoch) + + if args.drop_path_rate > 0. and args.drop_path_strategy == 'linear': + # update drop path rate + if hasattr(model.module, 'drop_path_rate'): + model.module.drop_path_rate = \ + args.drop_path_rate * epoch / args.epochs + + # train + metrics = train_epoch(args, epoch, model, model_ema, train_loader, + optimizer, loss_fn, scheduler, auxiliary_buffer, + dyrep, loss_scaler) + + # validate + test_metrics = validate(args, epoch, model, val_loader, val_loss_fn) + if model_ema is not None: + test_metrics = validate(args, + epoch, + model_ema.module, + val_loader, + val_loss_fn, + log_suffix='(EMA)') + + # dyrep + if dyrep is not None: + if epoch < args.dyrep_max_adjust_epochs: + if (epoch + 1) % args.dyrep_adjust_interval == 0: + # adjust + logger.info('DyRep: adjust model.') + dyrep.adjust_model() + logger.info( + f'Model params: {get_params(model)/1e6:.3f} M, FLOPs: {get_flops(model, input_shape=args.input_shape)/1e9:.3f} G' + ) + # re-init DDP + model = DDP(model.module, + device_ids=[args.local_rank], + find_unused_parameters=True) + test_metrics = validate(args, epoch, model, val_loader, val_loss_fn) + elif args.dyrep_recal_bn_every_epoch: + logger.info('DyRep: recalibrate BN.') + recal_bn(model.module, train_loader, 200) + test_metrics = validate(args, epoch, model, val_loader, val_loss_fn) + + metrics.update(test_metrics) + ckpts = ckpt_manager.update(epoch, metrics) + logger.info('\n'.join(['Checkpoints:'] + [ + ' {} : {:.3f}%'.format(ckpt, score) for ckpt, score in ckpts + ])) + + +def train_epoch(args, + epoch, + model, + model_ema, + loader, + optimizer, + loss_fn, + scheduler, + auxiliary_buffer=None, + dyrep=None, + loss_scaler=None): + loss_m = AverageMeter(dist=True) + data_time_m = AverageMeter(dist=True) + batch_time_m = AverageMeter(dist=True) + start_time = time.time() + + model.train() + for batch_idx, (input, target) in enumerate(loader): + data_time = time.time() - start_time + data_time_m.update(data_time) + + # optimizer.zero_grad() + # use optimizer.zero_grad(set_to_none=True) for speedup + for p in model.parameters(): + p.grad = None + + with torch.cuda.amp.autocast(enabled=loss_scaler is not None): + if not args.kd: + output = model(input) + loss = loss_fn(output, target) + else: + loss = loss_fn(input, target) + + if auxiliary_buffer is not None: + loss_aux = loss_fn(auxiliary_buffer.output, target) + loss += loss_aux * auxiliary_buffer.loss_weight + + if loss_scaler is None: + loss.backward() + else: + # amp + loss_scaler.scale(loss).backward() + if args.clip_grad_norm: + if loss_scaler is not None: + loss_scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), + args.clip_grad_max_norm) + + if dyrep is not None: + # record states of model in dyrep + dyrep.record_metrics() + + if loss_scaler is None: + optimizer.step() + 
else: + loss_scaler.step(optimizer) + loss_scaler.update() + + if model_ema is not None: + model_ema.update(model) + + loss_m.update(loss.item(), n=input.size(0)) + batch_time = time.time() - start_time + batch_time_m.update(batch_time) + if batch_idx % args.log_interval == 0 or batch_idx == len(loader) - 1: + if _has_nvidia_smi: + util = int(nvidia_smi.nvmlDeviceGetUtilizationRates(handle).gpu) + mem = nvidia_smi.nvmlDeviceGetMemoryInfo(handle).used / 1024 / 1024 + logger.info('Train: {} [{:>4d}/{}] ' + 'Loss: {loss.val:.3f} ({loss.avg:.3f}) ' + 'LR: {lr:.3e} ' + 'Mem: {memory:.0f} ' + 'Util: {util:d}% ' + 'Time: {batch_time.val:.2f}s ({batch_time.avg:.2f}s) ' + 'Data: {data_time.val:.2f}s'.format( + epoch, + batch_idx, + len(loader), + loss=loss_m, + lr=optimizer.param_groups[0]['lr'], + util=util, + memory=mem, + batch_time=batch_time_m, + data_time=data_time_m)) + else: + logger.info('Train: {} [{:>4d}/{}] ' + 'Loss: {loss.val:.3f} ({loss.avg:.3f}) ' + 'LR: {lr:.3e} ' + 'Mem: {memory:.0f} ' + 'Time: {batch_time.val:.2f}s ({batch_time.avg:.2f}s) ' + 'Data: {data_time.val:.2f}s'.format( + epoch, + batch_idx, + len(loader), + loss=loss_m, + lr=optimizer.param_groups[0]['lr'], + memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0, + batch_time=batch_time_m, + data_time=data_time_m)) + scheduler.step(epoch * len(loader) + batch_idx + 1) + start_time = time.time() + + return {'train_loss': loss_m.avg} + + +def validate(args, epoch, model, loader, loss_fn, log_suffix=''): + loss_m = AverageMeter(dist=True) + top1_m = AverageMeter(dist=True) + top5_m = AverageMeter(dist=True) + batch_time_m = AverageMeter(dist=True) + start_time = time.time() + + model.eval() + for batch_idx, (input, target) in enumerate(loader): + with torch.no_grad(): + output = model(input) + loss = loss_fn(output, target) + + top1, top5 = accuracy(output, target, topk=(1, 5)) + loss_m.update(loss.item(), n=input.size(0)) + top1_m.update(top1 * 100, n=input.size(0)) + top5_m.update(top5 * 100, n=input.size(0)) + + batch_time = time.time() - start_time + batch_time_m.update(batch_time) + if batch_idx % args.log_interval == 0 or batch_idx == len(loader) - 1: + logger.info('Test{}: {} [{:>4d}/{}] ' + 'Loss: {loss.val:.3f} ({loss.avg:.3f}) ' + 'Top-1: {top1.val:.3f}% ({top1.avg:.3f}%) ' + 'Top-5: {top5.val:.3f}% ({top5.avg:.3f}%) ' + 'Time: {batch_time.val:.2f}s'.format( + log_suffix, + epoch, + batch_idx, + len(loader), + loss=loss_m, + top1=top1_m, + top5=top5_m, + batch_time=batch_time_m)) + start_time = time.time() + + return {'test_loss': loss_m.avg, 'top1': top1_m.avg, 'top5': top5_m.avg} + + +if __name__ == '__main__': + main() diff --git a/classification/tools/vis_search_prob.py b/classification/tools/vis_search_prob.py new file mode 100644 index 0000000..93722e7 --- /dev/null +++ b/classification/tools/vis_search_prob.py @@ -0,0 +1,18 @@ +import sys +import torch + +ALL_CHOICES = ('h', 'h_flip', 'v', 'v_flip', 'w2', 'w2_flip', 'w7', 'w7_flip') + +ckpt = sys.argv[1] + +state = torch.load(ckpt, map_location='cpu')['model'] + +choices = [] +for k, v in state.items(): + if 'multi_scan.weights' in k: + probs = v.view(-1).softmax(-1) + print(probs) + topk = probs.topk(4)[1].sort()[0].tolist() + choices.append(str([ALL_CHOICES[idx] for idx in topk])+',') +for c in choices: + print(c)
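A closing illustration: the deploy-time conversion performed by tools/convert.py above boils down to a few lines. This is a sketch only; it assumes the repo's DiverseBranchBlock (with its switch_to_deploy() re-parameterisation) is importable:

```python
import torch
from lib.models.utils.dbb.dbb_block import DiverseBranchBlock

def deploy_and_save(model, save_path='model.ckpt'):
    # Fold every DBB / DyRep block back into a single convolution, then store
    # the plain state_dict for inference, mirroring tools/convert.py.
    model.eval()
    for m in model.modules():
        if isinstance(m, DiverseBranchBlock):
            m.switch_to_deploy()
    torch.save(model.state_dict(), save_path)
```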