diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..5570e4b --- /dev/null +++ b/Dockerfile @@ -0,0 +1,25 @@ +FROM nvidia/cuda:10.0-devel-ubuntu18.04 + +#RUN yes | unminimize + +RUN apt-get update && apt-get install -y wget bzip2 +RUN wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh +RUN bash miniconda.sh -b -p /opt/conda && \ + rm miniconda.sh +ENV PATH="/opt/conda/bin:${PATH}" +RUN conda config --set always_yes yes + +RUN conda install pytorch==1.3.1 torchvision==0.4.2 cudatoolkit=10.0 -c pytorch +RUN pip install scikit-image tqdm pyyaml easydict future pip +RUN apt-get install unzip + +COPY ./ /obow +RUN pip install -e /obow + +WORKDIR /obow + +# Test imports +RUN python -c "" +RUN python -c "import main_linear_classification" +RUN python -c "import main_obow" +RUN python -c "import main_semisupervised" diff --git a/LICENSE b/LICENSE index 261eeb9..8cf0b16 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,24 @@ - Apache License + OBoW + + Copyright 2020 Valeo + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + + + Apache License Version 2.0, January 2004 - http://www.apache.org/licenses/ + https://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION @@ -174,28 +192,3 @@ of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
diff --git a/README.md b/README.md
index b1f28e1..c68beb9 100644
--- a/README.md
+++ b/README.md
@@ -1,16 +1,226 @@
-# Online Bag-of-Visual-Words Generation for Unsupervised Representation Learning
-
-![OBoW](./img/obow_overview.png)
-
-This is a PyTorch implementation of the OBoW paper:
-
-If you use the OBoW code or framework in your research, please consider citing:
-
-```
-@article{gidaris2020obow,
-  title={Online Bag-of-Visual-Words Generation for Unsupervised Representation Learning},
-  author={Gidaris, Spyros and Bursuc, Andrei and Komodakis, Nikos and P{\'e}rez, Patrick and Cord, Matthieu},
-  journal={arXiv preprint arXiv:2012.xxxx},
-  year={2020}
-}
-```
+# **Online Bag-of-Visual-Words Generation for Unsupervised Representation Learning**
+
+![OBoW](./img/obow_overview.png)
+
+This is a PyTorch implementation of the OBoW paper:
+**Title:** "Online Bag-of-Visual-Words Generation for Unsupervised Representation Learning"
+**Authors:** S. Gidaris, A. Bursuc, G. Puy, N. Komodakis, M. Cord, and P. Pérez
+
+If you use the OBoW code or framework in your research, please consider citing:
+
+```
+@article{gidaris2020obow,
+  title={Online Bag-of-Visual-Words Generation for Unsupervised Representation Learning},
+  author={Gidaris, Spyros and Bursuc, Andrei and Komodakis, Nikos and Cord, Matthieu and P{\'e}rez, Patrick},
+  journal={arXiv preprint arXiv:2012.xxxx},
+  year={2020}
+}
+```
+
+### **License**
+This code is released under the Apache 2.0 license (refer to the LICENSE file for details).
+
+## **Preparation**
+
+### **Pre-requisites**
+* Python 3.7
+* PyTorch >= 1.3.1 (tested with 1.3.1)
+* CUDA 10.0 or higher
+
+### **Installation**
+
+**(1)** Clone the repo:
+```bash
+$ git clone https://github.com/valeoai/obow
+```
+
+**(2)** Install this repository and the dependencies using pip:
+```bash
+$ pip install -e ./obow
+```
+
+With this, you can edit the obow code on the fly and import functions
+and classes of obow in other projects as well.
+
+**(3)** Optional. To uninstall this package, run:
+```bash
+$ pip uninstall obow
+```
+
+**(4)** Create the *experiments* directory:
+```bash
+$ cd obow
+$ mkdir ./experiments
+```
+
+You can take a look at the [Dockerfile](./Dockerfile) if you are uncertain
+about the steps to install this project.
+
+### **Download pre-trained models (optional).**
+
+TODO
+
+## **Experiments: Training and evaluating ImageNet self-supervised features.**
+
+### **Train a ResNet50-based OBoW model (full solution) on the ImageNet dataset.**
+
+```bash
+# Run from the obow directory
+# Train the OBoW model.
+$ python main_obow.py --config=ImageNetFull/ResNet50_OBoW_full --workers=32 -p=250 --dst-dir=./experiments/ --data-dir=/datasets_local/ImageNet --multiprocessing-distributed --dist-url='tcp://127.0.0.1:4444'
+```
+
+Here, with `--data-dir=/datasets_local/ImageNet`, it is assumed that the ImageNet
+dataset is located at `/datasets_local/ImageNet`.
+The configuration file for running the above experiment, which is specified by
+the `--config` argument, is located at `./config/ImageNetFull/ResNet50_OBoW_full.yaml`.
+Note that all the experiment configuration files are placed in the `./config/`
+directory. The data of this experiment, such as checkpoints and logs, will be
+stored at `./experiments/ImageNetFull/ResNet50_OBoW_full`.
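+
+Each configuration file is a plain YAML file with `model`, `optim` and `data` sections.
+As a quick sanity check you can load it the same way the training scripts do; a minimal
+sketch (the printed values simply mirror the ResNet50 config shipped in `./config/`):
+
+```python
+import yaml
+
+# Load the experiment configuration as main_obow.py does (SafeLoader -> nested dict).
+with open("./config/ImageNetFull/ResNet50_OBoW_full.yaml", "r") as f:
+    exp_config = yaml.load(f, Loader=yaml.SafeLoader)
+
+print(exp_config["model"]["feature_extractor_arch"])  # "resnet50"
+print(exp_config["optim"]["num_epochs"])              # 200
+print(exp_config["data"]["num_img_crops"])            # 2 crops of size 160x160
+```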
+
+### **Evaluate on the ImageNet linear classification protocol**
+
+Train an ImageNet linear classification model on top of frozen features learned by the student network of the OBoW model.
+```bash
+# Run from the obow directory
+# Train and evaluate a linear classifier for the 1000-way ImageNet classification task.
+$ python main_linear_classification.py --config=ImageNetFull/ResNet50_OBoW_full --workers=32 -p=250 -b 1024 --wd 0.0 --lr 10.0 --epochs 100 --cos-schedule --dataset ImageNet --name "ImageNet_LinCls_b1024_wd0lr10_e100" --dst-dir=./experiments/ --data-dir=/datasets_local/ImageNet --multiprocessing-distributed --dist-url='tcp://127.0.0.1:4444'
+```
+
+The data of this experiment, such as checkpoints and logs, will be
+stored at `./experiments/ImageNetFull/ResNet50_OBoW_full/ImageNet_LinCls_b1024_wd0lr10_e100`.
+
+### **Evaluate on the Places205 linear classification protocol**
+
+Train a Places205 linear classification model on top of frozen features extracted from the OBoW model.
+```bash
+# Run from the obow directory
+# Train and evaluate a linear classifier for the 205-way Places205 classification task.
+$ python main_linear_classification.py --config=ImageNetFull/ResNet50_OBoW_full --dataset Places205 --batch-norm --workers=32 -p=500 -b 256 --wd 0.00001 --lr 0.01 --epochs 28 --schedule 10 20 --name "Places205_LinCls_b256_wd1e4lr0p01_e28" --dst-dir=./experiments/ --data-dir=/datasets_local/Places205 --multiprocessing-distributed --dist-url='tcp://127.0.0.1:4444'
+```
+
+The data of this experiment, such as checkpoints and logs, will be
+stored at `./experiments/ImageNetFull/ResNet50_OBoW_full/Places205_LinCls_b256_wd1e4lr0p01_e28`.
+
+### **ImageNet semi-supervised evaluation setting.**
+
+```bash
+# Run from the obow directory
+# Fine-tune with 1% of ImageNet annotated images.
+$ python main_semisupervised.py --config=ImageNetFull/ResNet50_OBoW_full --workers=32 -p=50 --dst-dir=./experiments/ --data-dir=/datasets_local/ImageNet --multiprocessing-distributed --dist-url='tcp://127.0.0.1:4444' --percentage 1 --lr=0.0002 --lr-head=0.5 --lr-decay=0.2 --wd=0.0 --epochs=40 --schedule 24 32 --name="semi_supervised_prc1_wd0_lr0002lrp5_e40"
+# Fine-tune with 10% of ImageNet annotated images.
+$ python main_semisupervised.py --config=ImageNetFull/ResNet50_OBoW_full --workers=32 -p=50 --dst-dir=./experiments/ --data-dir=/datasets_local/ImageNet --multiprocessing-distributed --dist-url='tcp://127.0.0.1:4444' --percentage 10 --lr=0.0002 --lr-head=0.5 --lr-decay=0.2 --wd=0.0 --epochs=20 --schedule 12 16 --name="semi_supervised_prc10_wd0_lr0002lrp5_e20"
+```
+
+The data of these experiments, such as checkpoints and logs, will be
+stored at `./experiments/ImageNetFull/ResNet50_OBoW_full/semi_supervised_prc1_wd0_lr0002lrp5_e40` and
+`./experiments/ImageNetFull/ResNet50_OBoW_full/semi_supervised_prc10_wd0_lr0002lrp5_e20`
+(for the 1% and 10% settings respectively).
+
+### **Convert to torchvision format.**
+
+The ResNet50 model that we trained is stored in a different format than that of the torchvision ResNet50 model.
+The following command converts it to the torchvision format.
+
+```bash
+$ python main_obow.py --config=ImageNetFull/ResNet50_OBoW_full --dst-dir=./experiments/ --data-dir=/datasets_local/ImageNet --multiprocessing-distributed --dist-url='tcp://127.0.0.1:4444' --convert-to-torchvision
+```
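+
+The converted checkpoint can then be loaded into a standard torchvision ResNet-50
+backbone. A minimal sketch (the checkpoint filename is the one produced by the command
+above; the exact layout of the saved file is an assumption, so inspect it and adapt the
+key handling if needed):
+
+```python
+import torch
+import torchvision.models
+
+ckpt_path = ("./experiments/ImageNetFull/ResNet50_OBoW_full/"
+             "tochvision_resnet50_student_K8192_epoch200.pth.tar")
+# NOTE: assumed structure -- the file may store the weights directly or under a key.
+checkpoint = torch.load(ckpt_path, map_location="cpu")
+state_dict = checkpoint["network"] if "network" in checkpoint else checkpoint
+
+model = torchvision.models.resnet50()
+# strict=False: the self-supervised backbone does not include a supervised fc head.
+print(model.load_state_dict(state_dict, strict=False))
+```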
+
+### **Pascal VOC07 Classification evaluation.**
+
+First, convert the model from the torchvision format produced above to the caffe2 format.
+```bash
+# Run from the obow directory
+$ python utils/convert_pytorch_to_caffe2.py --pth_model ./experiments/ImageNetFull/ResNet50_OBoW_full/tochvision_resnet50_student_K8192_epoch200.pth.tar --output_model ./experiments/ImageNetFull/ResNet50_OBoW_full/caffe2_resnet50_student_K8192_epoch200_bgr.pkl --rgb2bgr True
+```
+
+For the following steps you first need to download and install [fair_self_supervision_benchmark](https://github.com/facebookresearch/fair_self_supervision_benchmark).
+
+```bash
+# Run from the fair_self_supervision_benchmark directory
+$ python setup.py install
+$ python -c 'import self_supervision_benchmark'
+# Step 1: prepare the dataset.
+$ mkdir obow_ep200
+$ mkdir obow_ep200/voc
+$ mkdir obow_ep200/voc/voc07
+$ python extra_scripts/create_voc_data_files.py --data_source_dir /datasets_local/VOC2007/ --output_dir ./obow_ep200/voc/voc07/
+# Step 2: extract features from voc2007.
+$ mkdir obow_ep200/ssl-benchmark-output
+$ mkdir obow_ep200/ssl-benchmark-output/extract_features_gap
+$ mkdir obow_ep200/ssl-benchmark-output/extract_features_gap/data
+# ==> Extract pool5 features from the train split.
+$ python tools/extract_features.py \
+    --config_file [obow directory path]/utils/configs/benchmark_tasks/image_classification/voc07/resnet50_supervised_extract_gap_features.yaml \
+    --data_type train \
+    --output_file_prefix trainval \
+    --output_dir ./obow_ep200/ssl-benchmark-output/extract_features_gap/data \
+    NUM_DEVICES 1 TEST.BATCH_SIZE 64 TRAIN.BATCH_SIZE 64 \
+    TEST.PARAMS_FILE [obow directory path]/experiments/obow/ImageNetFull/ResNet50_OBoW_full/caffe2_resnet50_student_K8192_epoch200_bgr.pkl \
+    TRAIN.DATA_FILE ./obow_ep200/voc/voc07/train_images.npy \
+    TRAIN.LABELS_FILE ./obow_ep200/voc/voc07/train_labels.npy
+# ==> Extract pool5 features from the test split.
+$ python tools/extract_features.py \
+    --config_file [obow directory path]/utils/configs/benchmark_tasks/image_classification/voc07/resnet50_supervised_extract_gap_features.yaml \
+    --data_type test \
+    --output_file_prefix test \
+    --output_dir ./obow_ep200/ssl-benchmark-output/extract_features_gap/data \
+    NUM_DEVICES 1 TEST.BATCH_SIZE 64 TRAIN.BATCH_SIZE 64 \
+    TEST.PARAMS_FILE [obow directory path]/experiments/obow/ImageNetFull/ResNet50_OBoW_full/caffe2_resnet50_student_K8192_epoch200_bgr.pkl \
+    TRAIN.DATA_FILE ./obow_ep200/voc/voc07/test_images.npy TEST.DATA_FILE ./obow_ep200/voc/voc07/test_images.npy \
+    TRAIN.LABELS_FILE ./obow_ep200/voc/voc07/test_labels.npy TEST.LABELS_FILE ./obow_ep200/voc/voc07/test_labels.npy
+# Step 3: train and test the linear svms.
+# ==> Train the linear svms.
+$ mkdir obow_ep200/ssl-benchmark-output/extract_features_gap/data/voc07_svm
+$ mkdir obow_ep200/ssl-benchmark-output/extract_features_gap/data/voc07_svm/svm_pool5bn
+$ python tools/svm/train_svm_kfold.py \
+    --data_file ./obow_ep200/ssl-benchmark-output/extract_features_gap/data/trainval_pool5_bn_features.npy \
+    --targets_data_file ./obow_ep200/ssl-benchmark-output/extract_features_gap/data/trainval_pool5_bn_targets.npy \
+    --costs_list "0.05,0.1,0.3,0.5,1.0,3.0,5.0" \
+    --output_path ./obow_ep200/ssl-benchmark-output/extract_features_gap/data/voc07_svm/svm_pool5bn/
+# ==> Test the linear svms.
+$ python tools/svm/test_svm.py \
+    --data_file ./obow_ep200/ssl-benchmark-output/extract_features_gap/data/test_pool5_bn_features.npy \
+    --targets_data_file ./obow_ep200/ssl-benchmark-output/extract_features_gap/data/test_pool5_bn_targets.npy \
+    --costs_list "0.05,0.1,0.3,0.5,1.0,3.0,5.0" \
+    --output_path ./obow_ep200/ssl-benchmark-output/extract_features_gap/data/voc07_svm/svm_pool5bn/
+```
+
+## **Other experiments: Training using 20% of ImageNet and ResNet18.**
+
+A single GPU is enough for the following experiments.
+
+### **ResNet18-based OBoW vanilla solution.**
+
+```bash
+# Run from the obow directory
+# Train the model.
+$ python main_obow.py --config=ImageNet20/ResNet18_OBoW_vanilla --workers=16 --dst-dir=./experiments/ --data-dir=/datasets_local/ImageNet
+# Few-shot evaluation.
+$ python main_obow.py --config=ImageNet20/ResNet18_OBoW_vanilla --workers=16 --episodes 200 --fewshot-q 1 --fewshot-n 50 --fewshot-k 1 5 --evaluate --start-epoch=-1 --dst-dir=./experiments/ --data-dir=/datasets_local/ImageNet
+# Linear classification evaluation. Note that the following command precaches the extracted features at /root/local_storage/spyros/cache/obow.
+$ python main_linear_classification.py --config=ImageNet20/ResNet18_OBoW_vanilla --workers=16 -b 256 --wd 0.000002 --dataset ImageNet --name "ImageNet_LinCls_precache_b256_lr10p0wd2e6" --precache --lr 10.0 --epochs 50 --schedule 15 30 45 --subset=260 --dst-dir=./experiments/ --data-dir=/datasets_local/ImageNet --cache-dir=/root/local_storage/spyros/cache/obow
+```
+
+### **ResNet18-based OBoW full solution.**
+
+```bash
+# Run from the obow directory
+# Train the model.
+$ python main_obow.py --config=ImageNet20/ResNet18_OBoW_full --workers=16 --dst-dir=./experiments/ --data-dir=/datasets_local/ImageNet
+# Few-shot evaluation.
+$ python main_obow.py --config=ImageNet20/ResNet18_OBoW_full --workers=16 --episodes 200 --fewshot-q 1 --fewshot-n 50 --fewshot-k 1 5 --evaluate --start-epoch=-1 --dst-dir=./experiments/ --data-dir=/datasets_local/ImageNet
+# Linear classification evaluation. Note that the following command precaches the extracted features at /root/local_storage/spyros/cache/obow.
+$ python main_linear_classification.py --config=ImageNet20/ResNet18_OBoW_full --workers=16 -b 256 --wd 0.000002 --dataset ImageNet --name "ImageNet_LinCls_precache_b256_lr10p0wd2e6" --precache --lr 10.0 --epochs 50 --schedule 15 30 45 --subset=260 --dst-dir=./experiments/ --data-dir=/datasets_local/ImageNet --cache-dir=/root/local_storage/spyros/cache/obow
+```
diff --git a/config/ImageNet20/ResNet18_OBoW_full.yaml b/config/ImageNet20/ResNet18_OBoW_full.yaml new file mode 100644 index 0000000..00c4af1 --- /dev/null +++ b/config/ImageNet20/ResNet18_OBoW_full.yaml @@ -0,0 +1,49 @@ +# Model parameters. +model: + alpha: 0.99 + alpha_cosine: True + feature_extractor_arch: "resnet18" + feature_extractor_opts: + global_pooling: True + # Use two feature levels for BoW: "block3" (aka conv4 of ResNet) and "block4" + # (aka conv5 of ResNet). + bow_levels: ["block3", "block4"] + bow_extractor_opts: + inv_delta: 10 + num_words: 8192 + bow_predictor_opts: + kappa: 5 + # (Optional) on-line learning of a linear classifier on top of teacher + # features for monitoring purposes. + num_classes: 1000 + +# Optimization parameters.
+optim: + optim_type: "sgd" + momentum: 0.9 + weight_decay: 0.0005 + nesterov: False + num_epochs: 80 + lr: 0.05 + lr_schedule_type: "cos" + +# Data parameters: +data: + dataset_name: "ImageNet" + batch_size: 128 + epoch_size: + subset: 260 # use only 260 images per class, i.e., 20% of ImageNet. + cjitter: [0.4, 0.4, 0.4, 0.1] + cjitter_p: 0.8 + gray_p: 0.2 + gaussian_blur: [0.1, 2.0] + gaussian_blur_p: 0.5 + num_img_crops: 2 # 2 crops of size 160x160. + image_crop_size: 160 + image_crop_range: [0.08, 0.6] + num_img_patches: 5 # 5 patches of size 96x96. + img_patch_preresize: 256 + img_patch_preresize_range: [0.6, 1.0] + img_patch_size: 96 + img_patch_jitter: 24 + only_patches: False diff --git a/config/ImageNet20/ResNet18_OBoW_vanilla.yaml b/config/ImageNet20/ResNet18_OBoW_vanilla.yaml new file mode 100644 index 0000000..91cae05 --- /dev/null +++ b/config/ImageNet20/ResNet18_OBoW_vanilla.yaml @@ -0,0 +1,48 @@ +# Model parameters. +model: + alpha: 0.99 + alpha_cosine: True + feature_extractor_arch: "resnet18" + feature_extractor_opts: + global_pooling: True + # Use a single feature level for BoW: "block4" (aka conv5 of ResNet). + bow_levels: ["block4",] + bow_extractor_opts: + inv_delta: 10 + num_words: 8192 + bow_predictor_opts: + kappa: 5 + # (Optional) on-line learning of a linear classifier on top of teacher + # features for monitoring purposes. + num_classes: 1000 + +# Optimization parameters. +optim: + optim_type: "sgd" + momentum: 0.9 + weight_decay: 0.0005 + nesterov: False + num_epochs: 80 + lr: 0.05 + lr_schedule_type: "cos" + +# Data parameters: +data: + dataset_name: "ImageNet" + batch_size: 256 + epoch_size: + subset: 260 # use only 260 images per class, i.e., 20% of ImageNet. + cjitter: [0.4, 0.4, 0.4, 0.1] + cjitter_p: 0.8 + gray_p: 0.2 + gaussian_blur: [0.1, 2.0] + gaussian_blur_p: 0.5 + num_img_crops: 1 # 1 crop of size 160x160. + image_crop_size: 160 + image_crop_range: [0.08, 0.6] + num_img_patches: 0 # 0 patches. + img_patch_preresize: 256 + img_patch_preresize_range: [0.6, 1.0] + img_patch_size: 96 + img_patch_jitter: 24 + only_patches: False diff --git a/config/ImageNetFull/ResNet50_OBoW_full.yaml b/config/ImageNetFull/ResNet50_OBoW_full.yaml new file mode 100644 index 0000000..5f03f3c --- /dev/null +++ b/config/ImageNetFull/ResNet50_OBoW_full.yaml @@ -0,0 +1,52 @@ +# Model parameters. +model: + alpha: 0.99 + alpha_cosine: True + feature_extractor_arch: "resnet50" + feature_extractor_opts: + global_pooling: True + # Use two feature levels for BoW: "block3" (aka conv4 of ResNet) and "block4" + # (aka conv5 of ResNet). + bow_levels: ["block3", "block4"] + bow_extractor_opts: + inv_delta: 15 + num_words: 8192 + bow_predictor_opts: + kappa: 8 + # (Optional) on-line learning of a linear classifier on top of teacher + # features for monitoring purposes. + num_classes: 1000 + +# Optimization parameters. +optim: + optim_type: "sgd" + momentum: 0.9 + weight_decay: 0.0001 + nesterov: False + num_epochs: 200 + lr: 0.03 + end_lr: 0.00003 + lr_schedule_type: "cos_warmup" + warmup_epochs: 10 + permanent: 10 # save a permanent checkpoint every 10 epochs. + +# Data parameters: +data: + dataset_name: "ImageNet" + batch_size: 256 + epoch_size: + subset: + cjitter: [0.4, 0.4, 0.4, 0.1] + cjitter_p: 0.8 + gray_p: 0.2 + gaussian_blur: [0.1, 2.0] + gaussian_blur_p: 0.5 + num_img_crops: 2 # 2 crops of size 160x160. + image_crop_size: 160 + image_crop_range: [0.08, 0.6] + num_img_patches: 5 # 5 patches of size 96x96. 
+ img_patch_preresize: 256 + img_patch_preresize_range: [0.6, 1.0] + img_patch_size: 96 + img_patch_jitter: 24 + only_patches: False diff --git a/data/IMAGENET_LOWSHOT_BENCHMARK_CATEGORY_SPLITS.json b/data/IMAGENET_LOWSHOT_BENCHMARK_CATEGORY_SPLITS.json new file mode 100644 index 0000000..c6c2772 --- /dev/null +++ b/data/IMAGENET_LOWSHOT_BENCHMARK_CATEGORY_SPLITS.json @@ -0,0 +1 @@ +{"novel_classes_1": [3, 4, 5, 7, 15, 16, 24, 25, 27, 29, 30, 37, 39, 40, 41, 43, 47, 48, 55, 58, 61, 65, 66, 68, 70, 81, 83, 87, 88, 91, 92, 93, 95, 105, 107, 108, 112, 115, 118, 123, 124, 127, 128, 129, 130, 131, 134, 135, 138, 142, 145, 150, 155, 157, 158, 164, 166, 169, 170, 174, 181, 185, 186, 188, 190, 192, 193, 195, 199, 203, 214, 218, 222, 224, 227, 229, 230, 232, 236, 237, 241, 243, 246, 247, 250, 263, 266, 269, 272, 286, 287, 288, 291, 293, 298, 301, 304, 312, 322, 326, 329, 330, 342, 345, 346, 350, 351, 354, 358, 362, 365, 366, 378, 381, 382, 384, 385, 388, 390, 391, 393, 400, 407, 409, 411, 412, 424, 429, 430, 431, 432, 437, 442, 443, 445, 446, 447, 449, 451, 452, 456, 457, 458, 461, 464, 465, 467, 469, 477, 478, 479, 491, 492, 498, 501, 509, 512, 514, 518, 522, 523, 524, 525, 528, 533, 534, 535, 542, 543, 548, 549, 550, 554, 563, 565, 567, 568, 578, 579, 589, 592, 595, 596, 600, 610, 612, 613, 624, 630, 632, 635, 642, 644, 645, 646, 648, 649, 652, 654, 661, 665, 669, 680, 682, 683, 687, 689, 692, 694, 697, 705, 706, 712, 713, 715, 716, 719, 721, 724, 726, 729, 730, 733, 740, 741, 742, 745, 749, 751, 756, 759, 763, 766, 769, 770, 772, 778, 780, 782, 783, 792, 796, 799, 804, 807, 808, 811, 812, 814, 815, 816, 818, 822, 831, 832, 833, 842, 846, 847, 855, 856, 857, 859, 861, 868, 869, 871, 873, 877, 878, 880, 883, 888, 890, 891, 894, 896, 904, 909, 915, 918, 923, 928, 935, 941, 942, 944, 946, 950, 951, 955, 963, 965, 972, 976, 985, 988, 989, 991, 998], "base_classes": [0, 1, 2, 6, 9, 11, 12, 13, 17, 20, 21, 22, 26, 32, 33, 34, 35, 36, 38, 42, 44, 50, 52, 60, 62, 63, 69, 72, 74, 75, 76, 77, 84, 89, 90, 94, 98, 100, 110, 113, 117, 119, 120, 122, 125, 133, 140, 141, 148, 149, 152, 154, 156, 160, 162, 165, 167, 171, 172, 173, 176, 178, 182, 183, 184, 187, 191, 196, 198, 200, 205, 207, 208, 212, 213, 217, 219, 220, 221, 223, 226, 228, 231, 234, 240, 242, 245, 248, 249, 252, 253, 256, 257, 258, 259, 260, 262, 265, 267, 273, 274, 275, 277, 278, 281, 282, 285, 289, 292, 295, 297, 299, 300, 302, 305, 307, 308, 309, 310, 311, 316, 318, 320, 325, 331, 332, 335, 337, 341, 343, 344, 348, 353, 363, 367, 368, 369, 372, 374, 375, 379, 380, 383, 394, 396, 398, 403, 405, 408, 413, 414, 418, 420, 421, 422, 423, 425, 426, 433, 434, 435, 436, 438, 441, 444, 455, 462, 463, 470, 471, 473, 474, 476, 480, 481, 482, 483, 484, 485, 486, 489, 493, 494, 495, 496, 497, 500, 502, 503, 504, 507, 508, 510, 511, 513, 515, 516, 517, 519, 520, 531, 538, 540, 541, 544, 545, 547, 551, 553, 559, 560, 561, 564, 566, 572, 576, 577, 580, 581, 582, 583, 588, 590, 594, 597, 598, 601, 602, 603, 604, 605, 607, 608, 609, 611, 616, 618, 621, 622, 625, 627, 634, 640, 641, 643, 647, 650, 655, 657, 658, 659, 660, 662, 663, 664, 666, 670, 672, 675, 677, 679, 681, 684, 685, 686, 688, 690, 691, 693, 695, 696, 699, 702, 704, 707, 711, 714, 718, 720, 725, 727, 728, 731, 732, 734, 735, 737, 738, 743, 744, 747, 748, 750, 754, 757, 758, 762, 764, 765, 767, 768, 771, 775, 776, 779, 786, 788, 789, 791, 793, 795, 797, 798, 802, 805, 810, 813, 817, 821, 824, 825, 826, 827, 828, 830, 834, 836, 839, 840, 843, 844, 848, 853, 858, 860, 862, 863, 864, 865, 872, 874, 
875, 876, 887, 889, 892, 893, 899, 900, 905, 906, 907, 908, 910, 912, 917, 919, 920, 924, 926, 929, 933, 936, 938, 939, 940, 943, 945, 947, 949, 952, 956, 957, 959, 961, 964, 966, 967, 968, 970, 971, 973, 974, 977, 979, 983, 994, 996, 999], "label_names": ["n01807496", "n02916936", "n03794056", "n10565667", "n02978881", "n03126707", "n03394916", "n07693725", "n03710193", "n02342885", "n02105412", "n02782093", "n01847000", "n04044716", "n07753275", "n01818515", "n02802426", "n04136333", "n03908714", "n03535780", "n11879895", "n03534580", "n02676566", "n09468604", "n03877845", "n02094114", "n03000247", "n03781244", "n02113023", "n03443371", "n02256656", "n01677366", "n04482393", "n03062245", "n09399592", "n03127925", "n02264363", "n02087394", "n04542943", "n02111129", "n02074367", "n02892767", "n01968897", "n03724870", "n02169497", "n02536864", "n01728920", "n04204347", "n03888257", "n02483362", "n07615774", "n02128757", "n01484850", "n04154565", "n04344873", "n01770393", "n09256479", "n07720875", "n02107574", "n03196217", "n02871525", "n03498962", "n03290653", "n01819313", "n07875152", "n07714571", "n06785654", "n01871265", "n02091032", "n02102318", "n02095889", "n01796340", "n03874293", "n04265275", "n02526121", "n02835271", "n03983396", "n07749582", "n03961711", "n01496331", "n03272010", "n01770081", "n03388043", "n03188531", "n07584110", "n02102480", "n02837789", "n02089973", "n01924916", "n02091244", "n04487394", "n04335435", "n01534433", "n04090263", "n04517823", "n02090622", "n03721384", "n03710637", "n03950228", "n02097209", "n02321529", "n02109047", "n02095314", "n03717622", "n01601694", "n13040303", "n02056570", "n09835506", "n01739381", "n02480855", "n01773157", "n03207941", "n02667093", "n03271574", "n04149813", "n01740131", "n02233338", "n07697313", "n02687172", "n03602883", "n02749479", "n02101556", "n07583066", "n03720891", "n01530575", "n01943899", "n03388183", "n04074963", "n01797886", "n02799071", "n09229709", "n03873416", "n02058221", "n01582220", "n02097298", "n03871628", "n03742115", "n04118776", "n03777568", "n02879718", "n04019541", "n02486261", "n02883205", "n02102177", "n03637318", "n01694178", "n03179701", "n02328150", "n03344393", "n03379051", "n03529860", "n04612504", "n03538406", "n02099712", "n01698640", "n03995372", "n02666196", "n04033995", "n02795169", "n02504458", "n02009229", "n03729826", "n02093754", "n03042490", "n02980441", "n03627232", "n04532670", "n01491361", "n04209133", "n01882714", "n07768694", "n02486410", "n01756291", "n02988304", "n04152593", "n04251144", "n03843555", "n04356056", "n01693334", "n03249569", "n04037443", "n02112018", "n04548280", "n04336792", "n03425413", "n02607072", "n02493793", "n02783161", "n02669723", "n01644373", "n02492660", "n02894605", "n04404412", "n04476259", "n03314780", "n04033901", "n02317335", "n02487347", "n02134084", "n04505470", "n03201208", "n02699494", "n03837869", "n01749939", "n03125729", "n04254680", "n02950826", "n03814639", "n06874185", "n03733131", "n03887697", "n03942813", "n03495258", "n04086273", "n02948072", "n03095699", "n02807133", "n03661043", "n04483307", "n02119789", "n03376595", "n02206856", "n04147183", "n03527444", "n03697007", "n04228054", "n04579145", "n02443484", "n02096177", "n03372029", "n02804610", "n02114548", "n04579432", "n02088632", "n04254777", "n02280649", "n03761084", "n02086910", "n02114712", "n03000134", "n04005630", "n01776313", "n03594734", "n02085936", "n02110341", "n03924679", "n02398521", "n03743016", "n01755581", "n03785016", "n01697457", "n01824575", "n03494278", 
"n04099969", "n04040759", "n07717556", "n02268443", "n02977058", "n01641577", "n04350905", "n02012849", "n04523525", "n07734744", "n03868863", "n01774750", "n07565083", "n03874599", "n03916031", "n03599486", "n02107142", "n01682714", "n04041544", "n02115641", "n02391049", "n02111500", "n01806567", "n07760859", "n02106662", "n01632458", "n02279972", "n02002724", "n03481172", "n02704792", "n01955084", "n02007558", "n04525305", "n04204238", "n04550184", "n03929660", "n02091134", "n01748264", "n01843065", "n04560804", "n03590841", "n03935335", "n04192698", "n02096585", "n03388549", "n04366367", "n03854065", "n04116512", "n03793489", "n07930864", "n03838899", "n02514041", "n02690373", "n02281787", "n03775071", "n03485794", "n04509417", "n03450230", "n01443537", "n02692877", "n02086079", "n04554684", "n04229816", "n01632777", "n03649909", "n04330267", "n03110669", "n01990800", "n04447861", "n02123159", "n03804744", "n02107683", "n02086646", "n04201297", "n02025239", "n02094433", "n02396427", "n03764736", "n02096051", "n04286575", "n03676483", "n01855032", "n02120505", "n02037110", "n04485082", "n09421951", "n02093991", "n04328186", "n01978455", "n02051845", "n03457902", "n04493381", "n01704323", "n02641379", "n03899768", "n02859443", "n02493509", "n01983481", "n03666591", "n01768244", "n12768682", "n03832673", "n03633091", "n02116738", "n03657121", "n03786901", "n03895866", "n01978287", "n09193705", "n02088364", "n01687978", "n03773504", "n02097474", "n07695742", "n12144580", "n01560419", "n02992529", "n03146219", "n03445777", "n02488702", "n02128925", "n02509815", "n03127747", "n07753592", "n02128385", "n03930313", "n04153751", "n04332243", "n03902125", "n02823428", "n01806143", "n02089867", "n03982430", "n02111889", "n04442312", "n04515003", "n03325584", "n03272562", "n03623198", "n02797295", "n04376876", "n02910353", "n02088094", "n01829413", "n02092339", "n04200800", "n01729322", "n01532829", "n04487081", "n03598930", "n03483316", "n01667778", "n03089624", "n02787622", "n02971356", "n04277352", "n13054560", "n07718472", "n07613480", "n02101006", "n09428293", "n04591713", "n07754684", "n03938244", "n02219486", "n04606251", "n03791053", "n02097658", "n04252225", "n03769881", "n03796401", "n03770679", "n03444034", "n02167151", "n06794110", "n01917289", "n02095570", "n01855672", "n03877472", "n01986214", "n04258138", "n03763968", "n03016953", "n03393912", "n02028035", "n01695060", "n03958227", "n03018349", "n02091831", "n07892512", "n02096294", "n02410509", "n02422106", "n02437312", "n07745940", "n02113712", "n02701002", "n03980874", "n07248320", "n02843684", "n03530642", "n02483708", "n02009912", "n02165456", "n02105855", "n01980166", "n01644900", "n02497673", "n03461385", "n04209239", "n04039381", "n03041632", "n03956157", "n04461696", "n03476684", "n03733805", "n02727426", "n02422699", "n03929855", "n03710721", "n04456115", "n03218198", "n04458633", "n01728572", "n04259630", "n15075141", "n13133613", "n04325704", "n02071294", "n02966193", "n04296562", "n02117135", "n02965783", "n01820546", "n01984695", "n04357314", "n02123597", "n01689811", "n04254120", "n03662601", "n02930766", "n02786058", "n03992509", "n03908618", "n02090379", "n03485407", "n03803284", "n04552348", "n07715103", "n01608432", "n02110185", "n02096437", "n03944341", "n12620546", "n07742313", "n01945685", "n02119022", "n03496892", "n03026506", "n01985128", "n02018207", "n03075370", "n01614925", "n07747607", "n04131690", "n02457408", "n07730033", "n04372370", "n03680355", "n04326547", "n02093647", "n04562935", "n04026417", 
"n02672831", "n02110627", "n04125021", "n07716358", "n04238763", "n02091635", "n03891332", "n02174001", "n02895154", "n04270147", "n02132136", "n02105251", "n04081281", "n03133878", "n03467068", "n07932039", "n06359193", "n07873807", "n02364673", "n03384352", "n03825788", "n04557648", "n04501370", "n13052670", "n02130308", "n01735189", "n01828970", "n03920288", "n02277742", "n01877812", "n04355933", "n03891251", "n04380533", "n07860988", "n02108551", "n04311004", "n02113186", "n02123394", "n07579787", "n03617480", "n04264628", "n01817953", "n12998815", "n02017213", "n02013706", "n02108089", "n02100877", "n02097130", "n02963159", "n03857828", "n01616318", "n02510455", "n04371430", "n02168699", "n04553703", "n02268853", "n03788195", "n02454379", "n02190166", "n03459775", "n01774384", "n02817516", "n03532672", "n03482405", "n03291819", "n07836838", "n02098413", "n04347754", "n02114855", "n07717410", "n04367480", "n02110806", "n13044778", "n03690938", "n02109961", "n02606052", "n02408429", "n03814906", "n02089078", "n04310018", "n01744401", "n03884397", "n09246464", "n02643566", "n04275548", "n02825657", "n07871810", "n02877765", "n04008634", "n01664065", "n03197337", "n01631663", "n04118538", "n02104029", "n07590611", "n03355925", "n04266014", "n02108422", "n04409515", "n04591157", "n02094258", "n02109525", "n01629819", "n07714990", "n07831146", "n07920052", "n02389026", "n03180011", "n03131574", "n03250847", "n03255030", "n01531178", "n03775546", "n01843383", "n02006656", "n04162706", "n01734418", "n02769748", "n02033041", "n02415577", "n02808304", "n02480495", "n04599235", "n03866082", "n04435653", "n04141327", "n02259212", "n01692333", "n02927161", "n01729977", "n03207743", "n01592084", "n02808440", "n02437616", "n03476991", "n04522168", "n02951585", "n02011460", "n04613696", "n02397096", "n03297495", "n04398044", "n03063689", "n03788365", "n03014705", "n02101388", "n03240683", "n03709823", "n07614500", "n02823750", "n02110063", "n02917067", "n04127249", "n07718747", "n01833805", "n02091467", "n02115913", "n03888605", "n01742172", "n04525038", "n12985857", "n01784675", "n01514859", "n04479046", "n04589890", "n02791124", "n04235860", "n02423022", "n13037406", "n02412080", "n03947888", "n02909870", "n02085620", "n01981276", "n04263257", "n02129165", "n03345487", "n02129604", "n02870880", "n01685808", "n03452741", "n04346328", "n09288635", "n02102973", "n02788148", "n02447366", "n03445924", "n01872401", "n04423845", "n04208210", "n02814533", "n03424325", "n01494475", "n02113624", "n02236044", "n01675722", "n02490219", "n02231487", "n03933933", "n11939491", "n03998194", "n03208938", "n02445715", "n03447447", "n02113799", "n02112706", "n03124170", "n04604644", "n02814860", "n02099429", "n02106382", "n12057211", "n04429376", "n02229544", "n02107312", "n01580077", "n02708093", "n02087046", "n03045698", "n03937543", "n01930112", "n03930630", "n01688243", "n02112137", "n07697537", "n02124075", "n03642806", "n04067472", "n04540053", "n02106166", "n02489166", "n02395406", "n03673027", "n02125311", "n04536866", "n02098286", "n04584207", "n04141975", "n03792972", "n02105641", "n02027492", "n02092002", "n04065272", "n03970156", "n02104365", "n02865351", "n02100236", "n01518878", "n01440764", "n03976657", "n01773797", "n02444819", "n04486054", "n02112350", "n04243546", "n02484975", "n02137549", "n02443114", "n02093859", "n03868242", "n02120079", "n03223299", "n02492035", "n04146614", "n02869837", "n01798484", "n02002556", "n04252077", "n02346627", "n02966687", "n04141076", "n04507155", "n02403003", 
"n02100583", "n03017168", "n02906734", "n04389033", "n02105162", "n03776460", "n03134739", "n02860847", "n02804414", "n01860187", "n01498041", "n01630670", "n01773549", "n03692522", "n07753113", "n03347037", "n03259280", "n02085782", "n04070727", "n02655020", "n02793495", "n02730930", "n02361337", "n02815834", "n02088466", "n04548362", "n01910747", "n02114367", "n03594945", "n03976467", "n03337140", "n03478589", "n04417672", "n02093428", "n03691459", "n04239074", "n01944390", "n03792782", "n04311174", "n03782006", "n04120489", "n03787032", "n04399382", "n02177972", "n04285008", "n03216828", "n02172182", "n02791270", "n02097047", "n02481823", "n03447721", "n09472597", "n09332890", "n07716906", "n02281406", "n03876231", "n03991062", "n04069434", "n04418357", "n01914609", "n04467665", "n02442845", "n04111531", "n02108000", "n02992211", "n02077923", "n03400231", "n03028079", "n02777292", "n02979186", "n04465501", "n02110958", "n01751748", "n03658185", "n03584829", "n01795545", "n02107908", "n04179913", "n02319095", "n03404251", "n01514668", "n03124043", "n03047690", "n02892201", "n04392985", "n01775062", "n04370456", "n03100240", "n03544143", "n03141823", "n02099267", "n03967562", "n03903868", "n03417042", "n02951358", "n01558993", "n03063599", "n02090721", "n01753488", "n01667114", "n02066245", "n07880968", "n04133789", "n02133161", "n02018795", "n02106030", "n02776631", "n01950731", "n02123045", "n02488291", "n07711569", "n02356798", "n07684084", "n02999410", "n03595614", "n02099849", "n04049303", "n02098105", "n03085013", "n06596364", "n02417914", "n02326432", "n04004767", "n02840245", "n04317175", "n03777754", "n04355338", "n02640242", "n03109150", "n02363005", "n01537544", "n02165105", "n02325366", "n02138441", "n02974003", "n03840681", "n04443257", "n03630383", "n02093256", "n01873310", "n12267677", "n02226429", "n03759954", "n02113978", "n02108915", "n02504013", "n03670208", "n03220513", "n02088238", "n03065424", "n02099601", "n03032252", "n03954731", "n02747177", "n02127052", "n04592741", "n03000684", "n01665541", "n02105056", "n02105505", "n01883070", "n04273569", "n04597913", "n02494079", "n02106550", "n04590129", "n04532106", "n03187595", "n03706229", "n03770439", "n02790996", "n04428191", "n02276258", "n03733281", "n01622779", "n02100735", "n02134418", "n02500267", "n02939185", "n10148035", "n03841143", "n03160309", "n02981792", "n04009552", "n07802026", "n03977966", "n02086240", "n03584254", "n04371774", "n04023962", "n01737021", "n02102040", "n02441942", "n04596742", "n02111277", "n02794156", "n01669191", "n02841315", "n03492542", "n04462240", "n02834397"], "novel_classes_2": [8, 10, 14, 18, 19, 23, 28, 31, 45, 46, 49, 51, 53, 54, 56, 57, 59, 64, 67, 71, 73, 78, 79, 80, 82, 85, 86, 96, 97, 99, 101, 102, 103, 104, 106, 109, 111, 114, 116, 121, 126, 132, 136, 137, 139, 143, 144, 146, 147, 151, 153, 159, 161, 163, 168, 175, 177, 179, 180, 189, 194, 197, 201, 202, 204, 206, 209, 210, 211, 215, 216, 225, 233, 235, 238, 239, 244, 251, 254, 255, 261, 264, 268, 270, 271, 276, 279, 280, 283, 284, 290, 294, 296, 303, 306, 313, 314, 315, 317, 319, 321, 323, 324, 327, 328, 333, 334, 336, 338, 339, 340, 347, 349, 352, 355, 356, 357, 359, 360, 361, 364, 370, 371, 373, 376, 377, 386, 387, 389, 392, 395, 397, 399, 401, 402, 404, 406, 410, 415, 416, 417, 419, 427, 428, 439, 440, 448, 450, 453, 454, 459, 460, 466, 468, 472, 475, 487, 488, 490, 499, 505, 506, 521, 526, 527, 529, 530, 532, 536, 537, 539, 546, 552, 555, 556, 557, 558, 562, 569, 570, 571, 573, 574, 575, 584, 585, 586, 587, 591, 593, 
599, 606, 614, 615, 617, 619, 620, 623, 626, 628, 629, 631, 633, 636, 637, 638, 639, 651, 653, 656, 667, 668, 671, 673, 674, 676, 678, 698, 700, 701, 703, 708, 709, 710, 717, 722, 723, 736, 739, 746, 752, 753, 755, 760, 761, 773, 774, 777, 781, 784, 785, 787, 790, 794, 800, 801, 803, 806, 809, 819, 820, 823, 829, 835, 837, 838, 841, 845, 849, 850, 851, 852, 854, 866, 867, 870, 879, 881, 882, 884, 885, 886, 895, 897, 898, 901, 902, 903, 911, 913, 914, 916, 921, 922, 925, 927, 930, 931, 932, 934, 937, 948, 953, 954, 958, 960, 962, 969, 975, 978, 980, 981, 982, 984, 986, 987, 990, 992, 993, 995, 997], "base_classes_1": [6, 9, 11, 22, 26, 33, 34, 44, 50, 52, 63, 69, 72, 74, 75, 76, 77, 84, 89, 94, 120, 122, 125, 148, 149, 152, 154, 156, 165, 171, 173, 176, 178, 183, 184, 187, 196, 198, 208, 217, 219, 221, 226, 228, 231, 234, 240, 245, 248, 249, 257, 258, 259, 260, 265, 267, 273, 274, 277, 278, 281, 289, 295, 300, 302, 309, 310, 311, 316, 318, 320, 331, 343, 348, 363, 367, 368, 374, 375, 383, 394, 405, 408, 420, 422, 423, 425, 426, 433, 434, 436, 438, 444, 463, 471, 473, 474, 476, 480, 483, 486, 493, 494, 495, 496, 497, 503, 504, 507, 513, 515, 517, 519, 531, 540, 541, 544, 551, 561, 566, 572, 576, 577, 582, 590, 594, 598, 607, 618, 622, 634, 643, 655, 657, 660, 663, 664, 670, 677, 679, 684, 688, 690, 691, 695, 696, 699, 702, 707, 714, 718, 727, 731, 732, 734, 735, 747, 765, 768, 775, 788, 795, 802, 813, 824, 825, 834, 836, 839, 844, 863, 865, 874, 875, 887, 889, 893, 906, 908, 917, 924, 936, 938, 945, 957, 959, 961, 964, 967, 973, 974, 977, 979], "base_classes_2": [0, 1, 2, 12, 13, 17, 20, 21, 32, 35, 36, 38, 42, 60, 62, 90, 98, 100, 110, 113, 117, 119, 133, 140, 141, 160, 162, 167, 172, 182, 191, 200, 205, 207, 212, 213, 220, 223, 242, 252, 253, 256, 262, 275, 282, 285, 292, 297, 299, 305, 307, 308, 325, 332, 335, 337, 341, 344, 353, 369, 372, 379, 380, 396, 398, 403, 413, 414, 418, 421, 435, 441, 455, 462, 470, 481, 482, 484, 485, 489, 500, 502, 508, 510, 511, 516, 520, 538, 545, 547, 553, 559, 560, 564, 580, 581, 583, 588, 597, 601, 602, 603, 604, 605, 608, 609, 611, 616, 621, 625, 627, 640, 641, 647, 650, 658, 659, 662, 666, 672, 675, 681, 685, 686, 693, 704, 711, 720, 725, 728, 737, 738, 743, 744, 748, 750, 754, 757, 758, 762, 764, 767, 771, 776, 779, 786, 789, 791, 793, 797, 798, 805, 810, 817, 821, 826, 827, 828, 830, 840, 843, 848, 853, 858, 860, 862, 864, 872, 876, 892, 899, 900, 905, 907, 910, 912, 919, 920, 926, 929, 933, 939, 940, 943, 947, 949, 952, 956, 966, 968, 970, 971, 983, 994, 996, 999]} diff --git a/main_linear_classification.py b/main_linear_classification.py new file mode 100644 index 0000000..d062614 --- /dev/null +++ b/main_linear_classification.py @@ -0,0 +1,387 @@ +import argparse +import os +import random +import warnings +import pathlib +import yaml + +import torch +import torch.nn +import torch.nn.parallel +import torch.backends.cudnn +import torch.distributed +import torch.multiprocessing + +import obow.feature_extractor +import obow.classification +import obow.utils +import obow.datasets +from obow import project_root + + +def get_arguments(): + """ Parse input arguments. 
""" + default_dst_dir = str(pathlib.Path(project_root) / "experiments") + parser = argparse.ArgumentParser( + description='Linear classification evaluation using a pre-trained with ' + 'OBoW feature extractor (from the student network).') + parser.add_argument( + '-j', '--workers', default=4, type=int, + help='Number of data loading workers (default: 4)') + parser.add_argument( + '-b', '--batch-size', default=256, type=int, + help='Mini-batch size (default: 256), this is the total ' + 'batch size of all GPUs on the current node when ' + 'using Data Parallel or Distributed Data Parallel.') + parser.add_argument( + '--start-epoch', default=0, type=int, + help='Manual epoch number to start training in case of restart (default 0).' + 'If -1, then it stargs training from the last available checkpoint.') + parser.add_argument( + '-p', '--print-freq', default=200, type=int, + help='Print frequency (default: 200)') + parser.add_argument( + '--world-size', default=1, type=int, + help='Number of nodes for distributed training (default 1)') + parser.add_argument( + '--rank', default=0, type=int, + help='Node rank for distributed training (default 0)') + parser.add_argument( + '--dist-url', default='tcp://127.0.0.1:4444', type=str, + help='Url used to set up distributed training ' + '(default tcp://127.0.0.1:4444)') + parser.add_argument( + '--dist-backend', default='nccl', type=str, + help='Distributed backend (default nccl)') + parser.add_argument( + '--seed', default=None, type=int, + help='Seed for initializing training (default None)') + parser.add_argument( + '--gpu', default=None, type=int, + help='GPU id to use (default: None). If None it will try to use all ' + 'the available GPUs.') + parser.add_argument( + '--multiprocessing-distributed', action='store_true', + help='Use multi-processing distributed training to launch ' + 'N processes per node, which has N GPUs. This is the ' + 'fastest way to use PyTorch for either single node or ' + 'multi node data parallel training') + parser.add_argument( + '--dst-dir', default=default_dst_dir, type=str, + help='Base directory where the experiments data (i.e, checkpoints) of ' + 'the pre-trained OBoW model is stored (default: ' + f'{default_dst_dir}). The final directory path would be: ' + '"dst-dir / config", where config is the name of the config file.') + parser.add_argument( + '--config', type=str, required=True, default="", + help='Config file that was used for training the OBoW model.') + parser.add_argument( + '--name', default='semi_supervised', type=str, + help='The directory name of the experiment. The final directory ' + 'where the model and logs would be stored is: ' + '"dst-dir / config / name", where dst-dir is the base directory ' + 'for the OBoW model and config is the name of the config file ' + 'that was used for training the model.') + parser.add_argument( + '--evaluate', action='store_true', help='Evaluate the model.') + parser.add_argument( + '--dataset', required=True, default='', type=str, + help='Dataset that will be used for the linear classification ' + 'evaluation. Supported options: ImageNet, Places205.') + parser.add_argument( + '--data-dir', required=True, type=str, default='', + help='Directory path to the ImageNet or Places205 datasets.') + parser.add_argument('--subset', default=-1, type=int, + help='The number of images per class that they would be use for ' + 'training (default -1). 
If -1, then all the available images are ' + 'used.') + parser.add_argument( + '-n', '--batch-norm', action='store_true', + help='Use batch normalization (without affine transform) on the linear ' + 'classifier. By default this option is deactivated.') + parser.add_argument('--epochs', default=100, type=int, + help='Number of total epochs to run (default 100).') + parser.add_argument('--lr', '--learning-rate', default=10.0, type=float, + help='Initial learning rate (default 10.0)', dest='lr') + parser.add_argument('--cos-schedule', action='store_true', + help='If True then a cosine learning rate schedule is used. Otherwise ' + 'a step-wise learning rate schedule is used. In this latter case, ' + 'the schedule and lr-decay arguments must be specified.') + parser.add_argument( + '--schedule', default=[15, 30, 45,], nargs='*', type=int, + help='Learning rate schedule (when to drop lr by a lr-decay ratio) ' + '(default: 15, 30, 45). This argument is only used in case of ' + 'step-wise learning rate schedule (when the cos-schedule flag is ' + 'not activated).') + parser.add_argument( + '--lr-decay', default=0.1, type=float, + help='Learning rate decay step (default 0.1). This argument is only ' + 'used in case of step-wise learning rate schedule (when the ' + 'cos-schedule flag is not activated).' ) + parser.add_argument('--momentum', default=0.9, type=float, + help='Momentum (default 0.9)') + parser.add_argument('--wd', '--weight-decay', default=0.0, type=float, + help='Weight decay (default: 0.)', dest='weight_decay') + parser.add_argument('--nesterov', action='store_true') + parser.add_argument( + '--precache', action='store_true', + help='Precache features for the linear classifier. Those features are ' + 'deleted after the end of training.') + parser.add_argument( + '--cache-dir', default='', type=str, + help='Destination directory for the precached features.') + parser.add_argument( + '--cache-5crop', action='store_true', + help='Use five crops when precaching features (only for the train set).') + + args = parser.parse_args() + args.feature_extractor_dir = pathlib.Path(args.dst_dir) / args.config + os.makedirs(args.feature_extractor_dir, exist_ok=True) + args.exp_dir = args.feature_extractor_dir / args.name + os.makedirs(args.exp_dir, exist_ok=True) + + # Load the configuration params of the experiment + full_config_path = pathlib.Path(project_root) / "config" / (args.config + ".yaml") + print(f"Loading experiment {full_config_path}") + with open(full_config_path, "r") as f: + args.exp_config = yaml.load(f, Loader=yaml.SafeLoader) + + print(f"Logs and/or checkpoints will be stored on {args.exp_dir}") + + if args.precache: + if args.cache_dir == '': + raise ValueError( + 'To precache the features (--precache argument) you need to ' + 'specify with the --cache-dir argument the directory where the ' + 'features will be stored.') + cache_dir_name = f"{args.config}" + args.cache_dir = pathlib.Path(args.cache_dir) / cache_dir_name + os.makedirs(args.cache_dir, exist_ok=True) + args.cache_dir = pathlib.Path(args.cache_dir) / "cache_features" + os.makedirs(args.cache_dir, exist_ok=True) + + return args + + +def setup_model_for_distributed_training(model, args, ngpus_per_node): + if args.distributed: + # For multiprocessing distributed, DistributedDataParallel constructor + # should always set the single device scope, otherwise, + # DistributedDataParallel will use all available devices. + if args.gpu is not None: + torch.cuda.set_device(args.gpu) + model.cuda(args.gpu) + model.linear_classifier = torch.nn.SyncBatchNorm.convert_sync_batchnorm( + model.linear_classifier) + # When using a single GPU per process and per + # DistributedDataParallel, we need to divide the batch size + # ourselves based on the total number of GPUs we have + args.batch_size = int(args.batch_size / ngpus_per_node) + args.workers = int( + (args.workers + ngpus_per_node - 1) / ngpus_per_node) + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[args.gpu]) + else: + model.cuda() + # DistributedDataParallel will divide and allocate batch_size to all + # available GPUs if device_ids are not set + model = torch.nn.parallel.DistributedDataParallel(model) + + # According to the Distributed Data Parallel (DDP) pytorch page: + # https://pytorch.org/docs/stable/notes/ddp.html?highlight=distributed + # "The DDP constructor takes a reference to the local module, and + # broadcasts state_dict() from the process with rank 0 to all other + # processes in the group to make sure that all model replicas start + # from the exact same state" + # So, all processes have exactly the same replica of the model at this + # moment. + elif (args.gpu is not None) or (ngpus_per_node == 1): + if (args.gpu is None) and ngpus_per_node == 1: + args.gpu = 0 + torch.cuda.set_device(args.gpu) + model = model.cuda(args.gpu) + else: + raise NotImplementedError( + "torch.nn.DataParallel is not supported. " + "Use DistributedDataParallel instead with the argument " + "--multiprocessing-distributed.") + + print(f'==> workers={args.workers}') + + return model, args + + +def main(): + args = get_arguments() + + if args.seed is not None: + random.seed(args.seed) + torch.manual_seed(args.seed) + torch.backends.cudnn.deterministic = True + warnings.warn('You have chosen to seed training. ' + 'This will turn on the CUDNN deterministic setting, ' + 'which can slow down your training considerably! ' + 'You may see unexpected behavior when restarting ' + 'from checkpoints.') + + if args.gpu is not None: + warnings.warn('You have chosen a specific GPU. 
This will completely ' + 'disable data parallelism.') + + if args.dist_url == "env://" and args.world_size == -1: + args.world_size = int(os.environ["WORLD_SIZE"]) + + args.distributed = args.world_size > 1 or args.multiprocessing_distributed + + ngpus_per_node = torch.cuda.device_count() + if args.multiprocessing_distributed: + # Since we have ngpus_per_node processes per node, the total world_size + # needs to be adjusted accordingly + args.world_size = ngpus_per_node * args.world_size + # Use torch.multiprocessing.spawn to launch distributed processes: the + # main_worker process function + torch.multiprocessing.spawn( + main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) + else: + # Simply call main_worker function + main_worker(args.gpu, ngpus_per_node, args) + + +def main_worker(gpu, ngpus_per_node, args): + print( + f"gpu = {gpu} ngpus_per_node={ngpus_per_node} " + f"distributed={args.distributed} args={args}") + args.gpu = gpu + + if args.gpu is not None: + print("Use GPU: {} for training".format(args.gpu)) + + if args.distributed: + if args.dist_url == "env://" and args.rank == -1: + args.rank = int(os.environ["RANK"]) + if args.multiprocessing_distributed: + # For multiprocessing distributed training, rank needs to be the + # global rank among all the processes + args.rank = args.rank * ngpus_per_node + gpu + print(f"args.rank = {args.rank}") + torch.distributed.init_process_group( + backend=args.dist_backend, + init_method=args.dist_url, + world_size=args.world_size, + rank=args.rank) + + torch.backends.cudnn.benchmark = True + arch = args.exp_config['model']['feature_extractor_arch'] + if args.gpu == 0 or args.gpu is None: + obow.utils.setup_logger(args.exp_dir, "obow") + print(f"Creating linear classifier model with {arch} backbone.") + feature_extractor, channels = obow.feature_extractor.FeatureExtractor( + arch=arch, opts=args.exp_config['model']['feature_extractor_opts']) + dataset_to_num_classes = { + "ImageNet": 1000, + "Places205": 205, + } + assert args.dataset in dataset_to_num_classes + linear_classifier_opts = { + "num_classes": dataset_to_num_classes[args.dataset], + "num_channels": channels, + "batch_norm": args.batch_norm, + "pool_type": "none", + } + search_pattern = "feature_extractor_net_checkpoint_{epoch}.pth.tar" + search_pattern = str(args.feature_extractor_dir / search_pattern) + _, filename = obow.utils.find_last_epoch(search_pattern) + print(f"Loading pre-trained feature extractor from: {filename}") + out_msg = obow.utils.load_network_params( + feature_extractor, filename, strict=False) + print(f"Loading output msg: {out_msg}") + #assert str(out_msg) == "" + + model = obow.classification.FrozenFeaturesLinearClassifier( + feature_extractor=feature_extractor, + linear_classifier_opts=linear_classifier_opts, + ) + if args.gpu == 0 or args.gpu is None: + print(f"Model:\n{model}") + + model_without_ddp = model + model, args = setup_model_for_distributed_training( + model, args, ngpus_per_node) + + if args.precache: + feature_extractor = model.precache_feature_extractor() + if args.distributed: + raise NotImplementedError( + "Precaching with distributed is not supported.") + loader_train, sampler_train, _, loader_test, _, _ = ( + obow.datasets.get_data_loaders_linear_classification_precache( + dataset_name=args.dataset, + data_dir=args.data_dir, + batch_size=args.batch_size, + workers=args.workers, + epoch_size=None, + feature_extractor=feature_extractor, + cache_dir=args.cache_dir, + device=torch.device(args.gpu), + precache_batch_size=200, + 
five_crop=args.cache_5crop, + subset=args.subset)) + else: + loader_train, sampler_train, _, loader_test, _, _ = ( + obow.datasets.get_data_loaders_classification( + dataset_name=args.dataset, + data_dir=args.data_dir, + batch_size=args.batch_size, + workers=args.workers, + distributed=args.distributed, + epoch_size=None, + subset=args.subset)) + + if args.cos_schedule: + optim_opts = { + "optim_type": "sgd", + "lr": args.lr, + "num_epochs": args.epochs, + "lr_schedule_type": "cos", + "momentum": args.momentum, + "weight_decay": args.weight_decay, + "nesterov": args.nesterov, + } + else: + optim_opts = { + "optim_type": "sgd", + "lr": args.lr, + "num_epochs": args.epochs, + "lr_schedule_type": "step_lr", + "lr_schedule": args.schedule, + "lr_decay": args.lr_decay, + "momentum": args.momentum, + "weight_decay": args.weight_decay, + "nesterov": args.nesterov, + } + device = torch.device(args.gpu) # or torch.device('cuda') + solver = obow.classification.SupervisedClassifierSolver( + model, args.exp_dir, device, optim_opts, args.print_freq) + if args.start_epoch != 0: + print(f"[Rank {args.gpu}] - Loading checkpoint of: {args.start_epoch}") + solver.load_checkpoint(epoch=args.start_epoch) + + if args.start_epoch != 0 or args.evaluate: + solver.evaluate(loader_test) + + if args.evaluate: + return + + solver.solve( + loader_train=loader_train, + distributed=args.distributed, + sampler_train=sampler_train, + loader_test=loader_test) + + if args.precache: + # Delete precached features. + import shutil + shutil.rmtree(args.cache_dir) + +if __name__ == '__main__': + main() diff --git a/main_obow.py b/main_obow.py new file mode 100644 index 0000000..3f3b4c6 --- /dev/null +++ b/main_obow.py @@ -0,0 +1,419 @@ +import argparse +import copy +import os +import random +import warnings +import pathlib +import yaml + +import torch +import torch.nn +import torch.nn.parallel +import torch.backends.cudnn +import torch.distributed +import torch.multiprocessing + +import obow.builder_obow +import obow.feature_extractor +import obow.utils +import obow.datasets +import obow.visualization +import numpy as np + +from obow import project_root + + +def get_arguments(): + """ Parse input arguments. """ + default_dst_dir = str(pathlib.Path(project_root) / "experiments") + parser = argparse.ArgumentParser( + description="Trains OBoW self-supervised models." + ) + parser.add_argument( + '-j', '--workers', default=4, type=int, + help='Number of data loading workers (default: 4).') + parser.add_argument( + '-b', '--batch-size', default=256, type=int, + help='Mini-batch size (default: 256), this is the total ' + 'batch size of all GPUs on the current node when ' + 'using Distributed Data Parallel. Note that if batch_size has ' + 'specified in the config file, then the batch_size of the config ' + 'file overloads this agruement.') + parser.add_argument( + '--start-epoch', default=0, type=int, + help='Manual epoch number to start training in case of restart. 
' + 'If -1, then it restarts from the last available checkpoint.') + parser.add_argument( + '-p', '--print-freq', default=200, type=int, + help='Print frequency (default: 200)') + parser.add_argument( + '--world-size', default=1, type=int, + help='Number of nodes for distributed training (default: 1)') + parser.add_argument( + '--rank', default=0, type=int, + help='Node rank for distributed training (default: 0)') + parser.add_argument( + '--dist-url', default='tcp://127.0.0.1:4444', type=str, + help='URL used to set up distributed training ' + '(default: tcp://127.0.0.1:4444)') + parser.add_argument( + '--dist-backend', default='nccl', type=str, + help='Distributed backend (default: nccl)') + parser.add_argument( + '--seed', default=None, type=int, + help='Seed for initializing training (default: None).') + parser.add_argument( + '--gpu', default=None, type=int, + help='GPU id to use (default: None). If None it will try to use all ' + 'the available GPUs.') + parser.add_argument( + '--multiprocessing-distributed', action='store_true', + help='Use multi-processing distributed training to launch ' + 'N processes per node, which has N GPUs. This is the ' + 'fastest way to use PyTorch for either single node or ' + 'multi node data parallel training.') + parser.add_argument( + '--dst-dir', default=default_dst_dir, type=str, + help='Base directory where the experiments data ' + '(i.e., checkpoints, logs, etc) would be stored (default: ' + f'{default_dst_dir}). The final directory path would be: ' + '"dst-dir / config", where config is the name of the config file.') + parser.add_argument( + '--config', type=str, required=True, default="", + help='Config file with parameters of the experiment.') + parser.add_argument( + '--data-dir', required=True, type=str, default="", + help='Directory path to the ImageNet dataset.') + + # Arguments related to the few-shot evaluation of the learned features. + parser.add_argument( + '--evaluate', action='store_true', + help='Evaluate the model. No training is performed in this case. ' + 'By default it evaluates the model of the last available checkpoint.') + parser.add_argument( + '--episodes', default=0, type=int, + help='Number of episodes for few-shot evaluation (default 0).') + parser.add_argument('--fewshot-k', default=[1,], nargs='*', type=int, + help='Number of training examples per class for few-shot evaluation.') + parser.add_argument( + '--fewshot-n', default=50, type=int, + help='Number of novel classes per episode for few-shot evaluation.') + parser.add_argument( + '--fewshot-q', default=1, type=int, + help='Number of test examples per class for few-shot evaluation.') + parser.add_argument( + '--convert-to-torchvision', action='store_true', + help='Converts and saves the student resnet backbone in torchvision ' + 'format. No training or evaluation is performed in this case. ' + 'Note that it converts the model of the last available checkpoint.') + parser.add_argument( + '--visualize-words', action='store_true', + help='Visualize the visual words of OBoW. ' + 'No training or evaluation is performed in this case. ' + 'Note that it visualizes the model of the last available checkpoint.') + + args = parser.parse_args() + exp_directory = pathlib.Path(args.dst_dir) / args.config + os.makedirs(exp_directory, exist_ok=True) + + # Load the configuration params of the experiment + full_config_path = pathlib.Path(project_root) / "config" / (args.config + ".yaml") + print(f"Loading experiment {full_config_path}") + with open(full_config_path, "r") as f: + args.exp_config = yaml.load(f, Loader=yaml.SafeLoader) + args.exp_dir = exp_directory + + if "batch_size" in args.exp_config["data"]: + args.batch_size = args.exp_config["data"].pop("batch_size") + + print(f"Logs and/or checkpoints will be stored on {exp_directory}") + + return args + + +def setup_model_distributed_data_parallel(model, args): + if args.distributed: + if args.gpu is not None: + model.cuda(args.gpu) + make_bn_sync = torch.nn.SyncBatchNorm.convert_sync_batchnorm + model.feature_extractor = make_bn_sync(model.feature_extractor) + model.feature_extractor_teacher = make_bn_sync(model.feature_extractor_teacher) + # The BN layer of the weight generator in the bow_predictor is + # not converted to synchronized batchnorm because it is + # unnecessary: the weight generators in all GPUs get as input the + # same vocabulary. + print("Use synchronized BN for the feature extractors.") + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) + else: + model.cuda() + # DistributedDataParallel will divide and allocate batch_size to all + # available GPUs if device_ids are not set + model = torch.nn.parallel.DistributedDataParallel(model) + elif (args.gpu is not None): + model = model.cuda(args.gpu) + + return model, args + + +def main(): + args = get_arguments() + + if args.seed is not None: + random.seed(args.seed) + torch.manual_seed(args.seed) + torch.backends.cudnn.deterministic = True + warnings.warn('You have chosen to seed training. ' + 'This will turn on the CUDNN deterministic setting, ' + 'which can slow down your training considerably! ' + 'You may see unexpected behavior when restarting ' + 'from checkpoints.') + + if args.gpu is not None: + warnings.warn('You have chosen a specific GPU. This will completely ' + 'disable data parallelism.') + + if args.dist_url == "env://" and args.world_size == -1: + args.world_size = int(os.environ["WORLD_SIZE"]) + + if args.world_size > 1: + raise NotImplementedError( + f"Multi-machine distributed training (i.e., " + f"world_size={args.world_size} > 1) is not supported. 
" + f"Only single-machine single-GPU and single-machine multi-GPU " + f"training is supported.") + + if ((torch.cuda.device_count() > 1) and + (args.gpu is not None) and + (not args.multiprocessing_distributed)): + raise NotImplementedError( + f"There are {torch.cuda.device_count()} GPUs available in the " + "machine.\nHowever, Multi-GPU training is only supported via " + "DistributedDataParallel and requires to activate the argument " + "--multiprocessing-distributed.\nOtherwise choose a single GPU to " + "run the experiment, e.g., by adding the argument --gpu=0.") + + args.distributed = args.world_size > 1 or args.multiprocessing_distributed + ngpus_per_node = torch.cuda.device_count() + + if args.multiprocessing_distributed: + # Since we have ngpus_per_node processes per node, the total world_size + # needs to be adjusted accordingly + args.world_size = ngpus_per_node * args.world_size + # Use torch.multiprocessing.spawn to launch distributed processes: the + # main_worker process function + torch.multiprocessing.spawn( + main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) + else: + # Single-machine single-GPU training setting. + # Simply call main_worker function + main_worker(args.gpu, ngpus_per_node, args) + + +def parse_model_opts(model_opts, num_channels, num_iters_total): + bow_extractor_opts = model_opts["bow_extractor_opts"] + num_words = bow_extractor_opts["num_words"] + inv_delta = bow_extractor_opts["inv_delta"] + bow_levels = model_opts["bow_levels"] + num_bow_levels = len(bow_levels) + if not isinstance(inv_delta, (list, tuple)): + inv_delta = [inv_delta for _ in range(num_bow_levels)] + if not isinstance(num_words, (list, tuple)): + num_words = [num_words for _ in range(num_bow_levels)] + + bow_extractor_opts_list = [] + for i in range(num_bow_levels): + bow_extr_this = copy.deepcopy(bow_extractor_opts) + if isinstance(bow_extr_this["inv_delta"], (list, tuple)): + bow_extr_this["inv_delta"] = bow_extr_this["inv_delta"][i] + if isinstance(bow_extr_this["num_words"], (list, tuple)): + bow_extr_this["num_words"] = bow_extr_this["num_words"][i] + bow_extr_this["num_channels"] = num_channels // (2**(num_bow_levels - 1 - i)) + bow_extractor_opts_list.append(bow_extr_this) + + model_opts["bow_extractor_opts_list"] = bow_extractor_opts_list + + if model_opts.pop("alpha_cosine", False): + alpha_base = model_opts["alpha"] + model_opts["alpha"] = (alpha_base, num_iters_total) + + return model_opts + + +def visualize_words(model, args, data_opts, dataset_name, data_dir): + loader, dataset = obow.datasets.get_data_loaders_for_visualization( + dataset_name=dataset_name, + data_dir=data_dir, + batch_size=args.batch_size, + workers=args.workers, + distributed=args.distributed, + split="train", + **data_opts) + + all_vword_ids, all_vword_mag, num_words = ( + obow.visualization.extract_visual_words(model, loader)) + + num_words_freq, vwords_order = [], [] + for i, v in enumerate(all_vword_ids): + print(f"all_vword_ids[{i}]: {v.shape}") + num_words_freq.append(np.bincount(v.reshape(-1), minlength=num_words[i])) + num_words_freq[i] = num_words_freq[i].reshape(-1) + print(f"num_words_freq[{i}]: {num_words_freq[i].shape}") + vwords_order.append(np.argsort(-num_words_freq[i])) + print(f"vwords_order[{i}]: {vwords_order[i].shape}") + + num_patches = 8 + patch_size = 64 + num_levels = len(num_words) + levels = list(range(num_levels)) + levels.reverse() + for i in levels: + dst_dir = os.path.join(str(args.exp_dir), f"visual_words_L{i}") + print(f"Saving visualizations on {dst_dir}") + 
os.makedirs(dst_dir, exist_ok=True) + obow.visualization.visualize_visual_words( + num_words[i], num_patches, patch_size, dataset, all_vword_ids[i], + all_vword_mag[i], vwords_order[i], dst_dir) + + +def main_worker(gpu, ngpus_per_node, args): + print(f"main_worker(gpu={gpu} ngpus_per_node={ngpus_per_node} args={args})") + args.gpu = gpu + + if args.gpu is not None: + print(f"==> Use GPU: {args.gpu} for training.") + + if args.distributed: + # Single-machine Multi-GPU training setting. + if args.dist_url == "env://" and args.rank == -1: + args.rank = int(os.environ["RANK"]) + if args.rank < 0 or args.rank > (args.world_size // ngpus_per_node): + raise ValueError( + f"Invalid rank argument {args.rank}. " + "Rank must specify the id of the current machine in the " + "multi-machine distributed training setting. In case of " + "single-machine multi-gpu distributed setting (which is the " + "most common) then rank must be 0, ie, the id of the single " + "machine.") + if args.multiprocessing_distributed: + # For multiprocessing distributed training, rank needs to be the + # global rank among all the processes + args.rank = args.rank * ngpus_per_node + gpu + torch.distributed.init_process_group( + backend=args.dist_backend, + init_method=args.dist_url, + world_size=args.world_size, + rank=args.rank) + if args.gpu is not None: + torch.cuda.set_device(args.gpu) + # When using a single GPU per process and per + # DistributedDataParallel, we need to divide the batch size + # ourselves based on the total number of GPUs we have + args.batch_size = int(args.batch_size / ngpus_per_node) + args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node) + print(f'Rank={args.rank}: workers={args.workers} batch_size={args.batch_size}') + else: + # Single-machine single-GPU training setting. 
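+ # With a single visible GPU, device id 0 is picked by default below, so
+ # omitting --gpu on a one-GPU machine should behave the same as --gpu=0.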
+ if (args.gpu is None) and ngpus_per_node == 1: + args.gpu = 0 + torch.cuda.set_device(args.gpu) + + torch.backends.cudnn.benchmark = True + + if args.gpu == 0 or args.gpu is None: + obow.utils.setup_logger(args.exp_dir, "obow") + + data_opts = args.exp_config["data"] + dataset_name = data_opts.pop("dataset_name") + epoch_size = data_opts.pop("epoch_size", None) + + loader_train, sampler_train, _, loader_test, _, _ = ( + obow.datasets.get_data_loaders_for_OBoW( + dataset_name=dataset_name, + data_dir=args.data_dir, + batch_size=args.batch_size, + workers=args.workers, + distributed=args.distributed, + epoch_size=epoch_size, + **data_opts)) + num_iters_total = len(loader_train) * args.exp_config['optim']["num_epochs"] + + model_opts = args.exp_config["model"] + print(f"Creating an OBoW model with opts: {model_opts}") + feature_extractor, num_channels = obow.feature_extractor.FeatureExtractor( + arch=model_opts['feature_extractor_arch'], + opts=model_opts['feature_extractor_opts']) + model_opts = parse_model_opts(model_opts, num_channels, num_iters_total) + + model = obow.builder_obow.OBoW( + feature_extractor=feature_extractor, + num_channels=num_channels, + bow_levels=model_opts["bow_levels"], + bow_extractor_opts_list=model_opts["bow_extractor_opts_list"], + bow_predictor_opts=model_opts["bow_predictor_opts"], + alpha=model_opts["alpha"], + num_classes=model_opts.get("num_classes", None)) + + model_without_ddp = model + model, args = setup_model_distributed_data_parallel(model, args) + print(f"Model:\n{model}") + + optim_opts = args.exp_config['optim'] + device = torch.device(args.gpu) # or torch.device('cuda') + solver = obow.builder_obow.OBoWSolver( + model, args.exp_dir, device, optim_opts, args.print_freq) + + if args.evaluate or args.convert_to_torchvision or args.visualize_words: + args.start_epoch = -1 # load the last available checkpoint. 
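+ # Convention for start_epoch: 0 trains from scratch, a positive value N
+ # resumes from the checkpoint saved for epoch N, and -1 (set above for the
+ # evaluate / convert / visualize modes) loads the last checkpoint found in
+ # the experiment directory.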
+ if args.start_epoch != 0: + print(f"==> [Rank {args.gpu}] - Loading checkpoint of: {args.start_epoch}") + solver.load_checkpoint(epoch=args.start_epoch) + + if args.convert_to_torchvision: + arch = model_opts['feature_extractor_arch'] + assert arch in ("resnet50", "resnet18") + solver.save_feature_extractor_in_torchvision_format( + arch=model_opts['feature_extractor_arch']) + return + + if args.visualize_words: + model = solver.model.module if args.distributed else solver.model + args.batch_size = 128 + visualize_words(model, args, data_opts, dataset_name, args.data_dir) + return + + loaders_test_all = [loader_test,] + if args.episodes > 0: + print(f"==> Few-shot evaluation: #{args.episodes} " + f"{args.fewshot_n}-way {args.fewshot_k}-shot " + f"(q={args.fewshot_q}) tasks") + loaders_fewshot_test, _, _ = obow.datasets.get_data_loaders_fewshot( + dataset_name=dataset_name, + data_dir=args.data_dir, + batch_size=1, + workers=args.workers, + distributed=args.distributed, + split="test" if args.evaluate else "val", + epoch_size=args.episodes, + num_novel=args.fewshot_n, + num_train=args.fewshot_k, + num_test=args.fewshot_q*args.fewshot_n) + loaders_test_all += loaders_fewshot_test + + + if args.start_epoch != 0 or args.evaluate: + for i, loaders_test_this in enumerate(loaders_test_all): + solver.evaluate(loaders_test_this) + + if args.evaluate: + return + + solver.solve( + loader_train=loader_train, + distributed=args.distributed, + sampler_train=sampler_train, + loader_test=loaders_test_all) + + solver.save_feature_extractor(distributed=args.distributed) + +if __name__ == '__main__': + main() diff --git a/main_semisupervised.py b/main_semisupervised.py new file mode 100644 index 0000000..4713945 --- /dev/null +++ b/main_semisupervised.py @@ -0,0 +1,301 @@ +import argparse +import os +import random +import warnings +import pathlib +import yaml + +import torch +import torch.nn +import torch.nn.parallel +import torch.backends.cudnn +import torch.distributed +import torch.multiprocessing + +import obow.feature_extractor +import obow.classification +import obow.utils +import obow.datasets +from obow import project_root + + +def get_arguments(): + """ Parse input arguments. """ + default_dst_dir = str(pathlib.Path(project_root) / "experiments") + parser = argparse.ArgumentParser( + description='Semi-supervised ImageNet evaluation using a pre-trained ' + 'feature extractor.') + parser.add_argument( + '-j', '--workers', default=4, type=int, + help='Number of data loading workers (default 4)') + parser.add_argument( + '-b', '--batch-size', default=256, type=int, + help='Mini-batch size (default: 256), this is the total ' + 'batch size of all GPUs on the current node when ' + 'using Data Parallel or Distributed Data Parallel.') + parser.add_argument( + '--start-epoch', default=0, type=int, + help='Manual epoch number to start training in case of restart.' 
+ 'If -1, then it starts training from the last available checkpoint.') + parser.add_argument( + '-p', '--print-freq', default=200, type=int, + help='Print frequency (default: 200)') + parser.add_argument( + '--world-size', default=1, type=int, + help='Number of nodes for distributed training (default 1)') + parser.add_argument( + '--rank', default=0, type=int, + help='Node rank for distributed training (default 0)') + parser.add_argument( + '--dist-url', default='tcp://127.0.0.1:4444', type=str, + help='Url used to set up distributed training ' + '(default tcp://127.0.0.1:4444)') + parser.add_argument( + '--dist-backend', default='nccl', type=str, + help='Distributed backend (default nccl)') + parser.add_argument( + '--seed', default=None, type=int, + help='Seed for initializing training (default None)') + parser.add_argument( + '--gpu', default=None, type=int, + help='GPU id to use (default: None). If None it will try to use all ' + 'the available GPUs.') + parser.add_argument( + '--multiprocessing-distributed', action='store_true', + help='Use multi-processing distributed training to launch ' + 'N processes per node, which has N GPUs. This is the ' + 'fastest way to use PyTorch for either single node or ' + 'multi node data parallel training') + parser.add_argument( + '--dst-dir', default=default_dst_dir, type=str, + help='Base directory where the experiments data (i.e, checkpoints) of ' + 'the pre-trained OBoW model is stored (default: ' + f'{default_dst_dir}). The final directory path would be: ' + '"dst-dir / config", where config is the name of the config file.') + parser.add_argument( + '--config', type=str, required=True, default="", + help='Config file that was used for training the OBoW model.') + parser.add_argument( + '--evaluate', action='store_true', help='Evaluate the model.') + parser.add_argument( + '--name', default='semi_supervised', type=str, + help='The directory name of the experiment. The final directory ' + 'where the model and logs would be stored is: ' + '"dst-dir / config / name", where dst-dir is the base directory ' + 'for the OBoW model and config is the name of the config file ' + 'that was used for training the model.') + parser.add_argument( + '--data-dir', required=True, type=str, default="", + help='Directory path to the ImageNet dataset.') + parser.add_argument( + '--percentage', default=1, type=int, + help='Percentage of ImageNet annotated images (default 1). 
Only the ' + 'values 1 (for 1 percent of annotated images) and 10 (for 10 ' + 'percent of annotated images) are supported.') + parser.add_argument('--epochs', default=40, type=int, + help='Number of total epochs to run.') + parser.add_argument('--lr', default=0.0002, type=float, + help='Initial learning rate for the feature extractor trunk ' + '(default 0.0002).') + parser.add_argument('--lr-head', default=0.5, type=float, + help='Initial learning rate of the classification head (default 0.5).') + parser.add_argument('--momentum', default=0.9, type=float, + help='Momentum (default 0.9).') + parser.add_argument('--wd', '--weight-decay', default=0.0, type=float, + help='Weight decay (default: 0.)', dest='weight_decay') + parser.add_argument( + '--lr-decay', default=0.2, type=float, + help='Learning rate decay step (default 0.2).') + parser.add_argument( + '--schedule', default=[24, 32,], nargs='*', type=int, + help='Learning rate schedule, i.e., when to drop lr by a lr-decay ratio' + ' (default: 24, 32 which means after 24 and 32 epochs)') + parser.add_argument('--nesterov', action='store_true') + + args = parser.parse_args() + args.feature_extractor_dir = pathlib.Path(args.dst_dir) / args.config + os.makedirs(args.feature_extractor_dir, exist_ok=True) + args.exp_dir = args.feature_extractor_dir / args.name + os.makedirs(args.exp_dir, exist_ok=True) + + # Load the configuration params of the experiment + full_config_path = pathlib.Path(project_root) / "config" / (args.config + ".yaml") + print(f"Loading experiment {full_config_path}") + with open(full_config_path, "r") as f: + args.exp_config = yaml.load(f, Loader=yaml.SafeLoader) + + print(f"Logs and/or checkpoints will be stored on {args.exp_dir}") + + return args + + +def setup_model_for_distributed_training(model, args, ngpus_per_node): + if args.distributed: + # For multiprocessing distributed, DistributedDataParallel constructor + # should always set the single device scope, otherwise, + # DistributedDataParallel will use all available devices. + if args.gpu is not None: + torch.cuda.set_device(args.gpu) + model.cuda(args.gpu) + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + # When using a single GPU per process and per + # DistributedDataParallel, we need to divide the batch size + # ourselves based on the total number of GPUs we have. + args.batch_size = int(args.batch_size / ngpus_per_node) + args.workers = int( + (args.workers + ngpus_per_node - 1) / ngpus_per_node) + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[args.gpu]) + else: + model.cuda() + # DistributedDataParallel will divide and allocate batch_size to all + # available GPUs if device_ids are not set + model = torch.nn.parallel.DistributedDataParallel(model) + elif (args.gpu is not None) or (ngpus_per_node == 1): + if (args.gpu is None) and ngpus_per_node == 1: + args.gpu = 0 + torch.cuda.set_device(args.gpu) + model = model.cuda(args.gpu) + else: + raise NotImplementedError( + "torch.nn.DataParallel is not supported. " + "Use DistributedDataParallel instead with the argument " + "--multiprocessing-distributed).") + + return model, args + + +def main(): + args = get_arguments() + + if args.seed is not None: + random.seed(args.seed) + torch.manual_seed(args.seed) + torch.backends.cudnn.deterministic = True + warnings.warn('You have chosen to seed training. ' + 'This will turn on the CUDNN deterministic setting, ' + 'which can slow down your training considerably! 
' + 'You may see unexpected behavior when restarting ' + 'from checkpoints.') + + if args.gpu is not None: + warnings.warn('You have chosen a specific GPU. This will completely ' + 'disable data parallelism.') + + if args.dist_url == "env://" and args.world_size == -1: + args.world_size = int(os.environ["WORLD_SIZE"]) + + args.distributed = args.world_size > 1 or args.multiprocessing_distributed + + ngpus_per_node = torch.cuda.device_count() + if args.multiprocessing_distributed: + # Since we have ngpus_per_node processes per node, the total world_size + # needs to be adjusted accordingly + args.world_size = ngpus_per_node * args.world_size + # Use torch.multiprocessing.spawn to launch distributed processes: the + # main_worker process function + torch.multiprocessing.spawn( + main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) + else: + # Simply call main_worker function + main_worker(args.gpu, ngpus_per_node, args) + + +def main_worker(gpu, ngpus_per_node, args): + print( + f"gpu = {gpu} ngpus_per_node={ngpus_per_node} " + f"distributed={args.distributed} args={args}") + args.gpu = gpu + + if args.gpu is not None: + print("Use GPU: {} for training".format(args.gpu)) + + if args.distributed: + if args.dist_url == "env://" and args.rank == -1: + args.rank = int(os.environ["RANK"]) + if args.multiprocessing_distributed: + # For multiprocessing distributed training, rank needs to be the + # global rank among all the processes + args.rank = args.rank * ngpus_per_node + gpu + print(f"args.rank = {args.rank}") + torch.distributed.init_process_group( + backend=args.dist_backend, + init_method=args.dist_url, + world_size=args.world_size, + rank=args.rank) + + torch.backends.cudnn.benchmark = True + arch = args.exp_config['model']['feature_extractor_arch'] + if args.gpu == 0 or args.gpu is None: + obow.utils.setup_logger(args.exp_dir, "obow") + print(f"Creating classification model with {arch} backbone.") + + feature_extractor, num_channels = obow.feature_extractor.FeatureExtractor( + arch=arch, opts={"global_pooling": True}) + linear_classifier_opts = { + "num_classes": 1000, + "num_channels": num_channels, + "batch_norm": False, + "pool_type": "none", + } + search_pattern = "feature_extractor_net_checkpoint_{epoch}.pth.tar" + search_pattern = str(args.feature_extractor_dir / search_pattern) + _, pretrained = obow.utils.find_last_epoch(search_pattern) + print(f"Loading pre-trained feature extractor from: {pretrained}") + out_msg = obow.utils.load_network_params( + feature_extractor, pretrained, strict=False) + print(f"Loading output msg: {out_msg}") + + model = obow.classification.SupervisedClassification( + feature_extractor=feature_extractor, + linear_classifier_opts=linear_classifier_opts, + ) + model_without_ddp = model + model, args = setup_model_for_distributed_training( + model, args, ngpus_per_node) + if args.gpu == 0 or args.gpu is None: + print(f"Model:\n{model}") + + loader_train, sampler_train, _, loader_test, _, _ = ( + obow.datasets.get_data_loaders_semisupervised_classification( + dataset_name="ImageNet", + data_dir=args.data_dir, + batch_size=args.batch_size, + workers=args.workers, + distributed=args.distributed, + epoch_size=None, + percentage=args.percentage)) + optim_opts = { + "optim_type": "sgd", + "lr": args.lr, + "start_lr_head": args.lr_head, + "num_epochs": args.epochs, + "lr_schedule_type": "step_lr", + "lr_schedule": args.schedule, + "lr_decay": args.lr_decay, + "momentum": args.momentum, + "weight_decay": args.weight_decay, + "nesterov": args.nesterov, + 
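+ # "eval_freq" below appears to control how often the held-out set is
+ # evaluated while training: every 4 epochs in the 1%-labels setting and
+ # every epoch in the 10% setting, presumably to keep the evaluation cost
+ # proportionate to the much shorter 1% epochs.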
"eval_freq": 4 if args.percentage == 1 else 1, + } + device = torch.device(args.gpu) + solver = obow.classification.SupervisedClassifierSolver( + model, args.exp_dir, device, optim_opts, args.print_freq) + if args.start_epoch != 0: + print(f"[Rank {args.gpu}] - Loading checkpoint of: {args.start_epoch}") + solver.load_checkpoint(epoch=args.start_epoch) + + if args.start_epoch != 0 or args.evaluate: + solver.evaluate(loader_test) + + if args.evaluate: + return + + solver.solve( + loader_train=loader_train, + distributed=args.distributed, + sampler_train=sampler_train, + loader_test=loader_test) + +if __name__ == '__main__': + main() diff --git a/obow/__init__.py b/obow/__init__.py new file mode 100644 index 0000000..124bcc8 --- /dev/null +++ b/obow/__init__.py @@ -0,0 +1,4 @@ +import pathlib + + +project_root = pathlib.Path(__file__).resolve().parents[1] diff --git a/obow/builder_obow.py b/obow/builder_obow.py new file mode 100644 index 0000000..f2a3f1f --- /dev/null +++ b/obow/builder_obow.py @@ -0,0 +1,767 @@ +import copy +import math +import logging +import torch +import torch.nn as nn +import torch.nn.functional as F + +import obow.utils as utils +import obow.solver as solver +import obow.fewshot as fewshot +from obow.classification import PredictionHead + + +logger = utils.setup_dist_logger(logging.getLogger(__name__)) + + +@torch.no_grad() +def compute_bow_perplexity(bow_target): + """ Compute the per image and per batch perplexity of the bow_targets. """ + assert isinstance(bow_target, (list, tuple)) + + perplexity_batch, perplexity_img = [], [] + for bow_target_level in bow_target: # For each bow level. + assert bow_target_level.dim() == 2 + # shape of bow_target_level: [batch_size x num_words] + + probs = F.normalize(bow_target_level, p=1, dim=1) + perplexity_img_level = torch.exp( + -torch.sum(probs * torch.log(probs + 1e-5), dim=1)).mean() + + bow_target_sum_all = bow_target_level.sum(dim=0) + # Uncomment the following line if you want to compute the perplexity of + # of the entire batch in case of distributed training. + # bow_target_sum_all = utils.reduce_all(bow_target_sum_all) + probs = F.normalize(bow_target_sum_all, p=1, dim=0) + perplexity_batch_level = torch.exp( + -torch.sum(probs * torch.log(probs + 1e-5), dim=0)) + + perplexity_img.append(perplexity_img_level) + perplexity_batch.append(perplexity_batch_level) + + perplexity_batch = torch.stack(perplexity_batch, dim=0).view(-1).tolist() + perplexity_img = torch.stack(perplexity_img, dim=0).view(-1).tolist() + + return perplexity_batch, perplexity_img + + +def expand_target(target, prediction): + """Expands the target in case of BoW predictions from multiple crops.""" + assert prediction.size(1) == target.size(1) + batch_size_x_num_crops, num_words = prediction.size() + batch_size = target.size(0) + assert batch_size_x_num_crops % batch_size == 0 + num_crops = batch_size_x_num_crops // batch_size + + if num_crops > 1: + target = target.unsqueeze(1).repeat(1, num_crops, 1).view(-1, num_words) + + return target + + +class OBoW(nn.Module): + def __init__( + self, + feature_extractor, + num_channels, + bow_levels, + bow_extractor_opts_list, + bow_predictor_opts, + alpha=0.99, + num_classes=None, + ): + """Builds an OBoW model. + + Args: + feature_extractor: essentially the convnet model that is going to be + trained in order to learn image representations. + num_channels: number of channels of the output global feature vector of + the feature_extractor. 
+ bow_levels: a list with the names (strings) of the feature levels from + which the teacher network in OBoW will create BoW targets. + bow_extractor_opts_list: a list of dictionaries with the configuration + options for the BoW extraction (at teacher side) for each BoW level. + Each dictionary should define the following keys (1) "num_words" + with the vocabulary size of this level, (2) "num_channels", + optionally (3) "update_type" (default: "local_averaging"), + optionally (4) "inv_delta" (default: 15), which is the inverse + temperature that is used for computing the soft assignment codes, + and optionally (5) "bow_pool" (default: "max"). For more details + see the documentation of the BoWExtractor class. + bow_predictor_opts: a dictionary with configuration options for the + BoW prediction head of the student. The dictionary must define + the following keys (1) "kappa", a coefficent for scaling the + magnitude of the predicted weights, and optionally (2) "learn_kappa" + (default: False), a boolean value that if true kappa becomes a + learnable parameter. For all the OBoW experiments "learn_kappa" is + set to False. For more details see the documentation of the + BoWPredictor class. + alpha: the momentum coefficient between 0.0 and 1.0 for the teacher + network updates. If alpha is a scalar (e.g., 0.99) then a static + momentum coefficient is used during training. If alpha is tuple of + two values, e.g., alpha=(alpha_base, num_iterations), then OBoW + uses a cosine schedule that starts from alpha_base and it increases + it to 1.0 over num_iterations. + num_classes: (optional) if not None, then it creates a + linear classification head with num_classes outputs that would be + on top of the teacher features for on-line monitoring the quality + of the learned features. No gradients would back-propagated from + this head to the feature extractor trunks. So, it does not + influence the learning of the feature extractor. Note, at the end + the features that are used are those of the student network, not + of the teacher. + """ + super(OBoW, self).__init__() + assert isinstance(bow_levels, (list, tuple)) + assert isinstance(bow_extractor_opts_list, (list, tuple)) + assert len(bow_extractor_opts_list) == len(bow_levels) + + self._bow_levels = bow_levels + self._num_bow_levels = len(bow_levels) + if isinstance(alpha, (tuple, list)): + # Use cosine schedule in order to increase the alpha from + # alpha_base (e.g., 0.99) to 1.0. + alpha_base, num_iterations = alpha + self._alpha_base = alpha_base + self._num_iterations = num_iterations + self.register_buffer("_alpha", torch.FloatTensor(1).fill_(alpha_base)) + self.register_buffer("_iteration", torch.zeros(1)) + self._alpha_cosine_schedule = True + else: + self._alpha = alpha + self._alpha_cosine_schedule = False + + # Build the student network components. + self.feature_extractor = feature_extractor + assert "kappa" in bow_predictor_opts + bow_predictor_opts["num_channels_out"] = num_channels + bow_predictor_opts["num_channels_hidden"] = num_channels * 2 + bow_predictor_opts["num_channels_in"] = [ + d["num_channels"] for d in bow_extractor_opts_list] + self.bow_predictor = BoWPredictor(**bow_predictor_opts) + + # Build the teacher network components. 
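+ # A rough sketch of the teacher update (see _update_teacher below): the
+ # teacher starts as an exact copy of the student and is then refreshed only
+ # by an exponential moving average,
+ #     param_teacher = alpha * param_teacher + (1 - alpha) * param_student,
+ # where, with the cosine schedule, alpha is annealed from alpha_base
+ # (e.g. 0.99) towards 1.0 over the course of training.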
+ self.feature_extractor_teacher = copy.deepcopy(self.feature_extractor) + self.bow_extractor = BoWExtractorMultipleLevels(bow_extractor_opts_list) + + if (num_classes is not None): + self.linear_classifier = PredictionHead( + num_channels=num_channels, num_classes=num_classes, + batch_norm=True, pool_type="global_avg") + else: + self.linear_classifier = None + + for param, param_teacher in zip( + self.feature_extractor.parameters(), + self.feature_extractor_teacher.parameters()): + param_teacher.data.copy_(param.data) # initialize + param_teacher.requires_grad = False # not update by gradient + + @torch.no_grad() + def _get_momentum_alpha(self): + if self._alpha_cosine_schedule: + scale = 0.5 * (1. + math.cos((math.pi * self._iteration.item()) / self._num_iterations)) + self._alpha.fill_(1.0 - (1.0 - self._alpha_base) * scale) + self._iteration += 1 + return self._alpha.item() + else: + return self._alpha + + @torch.no_grad() + def _update_teacher(self): + """ Exponetial moving average for the feature_extractor_teacher params: + param_teacher = param_teacher * alpha + param * (1-alpha) + """ + if not self.training: + return + alpha = self._get_momentum_alpha() + if alpha >= 1.0: + return + for param, param_teacher in zip( + self.feature_extractor.parameters(), + self.feature_extractor_teacher.parameters()): + param_teacher.data.mul_(alpha).add_( + param.detach().data, alpha=(1. - alpha)) + + def _bow_loss(self, bow_prediction, bow_target): + assert isinstance(bow_prediction, (list, tuple)) + assert isinstance(bow_target, (list, tuple)) + assert len(bow_prediction) == self._num_bow_levels + assert len(bow_target) == self._num_bow_levels + + # Instead of using a custom made cross-entropy loss for soft targets, + # we use the pytorch kl-divergence loss that is defined as the + # cross-entropy plus the entropy of targets. Since there is no gradient + # back-propagation from the targets, it is equivalent to cross entropy. + loss = [ + F.kl_div(F.log_softmax(p, dim=1), expand_target(t, p), reduction="batchmean") + for (p, t) in zip(bow_prediction, bow_target)] + return torch.stack(loss).mean() + + def _linear_classification(self, features, labels): + # With .detach() no gradients of the classification loss are + # back-propagated to the feature extractor. + # The reason for training such a linear classifier is in order to be + # able to monitor while training the quality of the learned features. + features = features.detach() + if (labels is None) or (self.linear_classifier is None): + return (features.new_full((1,), 0.0).squeeze(), + features.new_full((1,), 0.0).squeeze()) + + scores = self.linear_classifier(features) + loss = F.cross_entropy(scores, labels) + with torch.no_grad(): + accuracy = utils.top1accuracy(scores, labels).item() + + return loss, accuracy + + def generate_bow_targets(self, image): + features = self.feature_extractor_teacher(image, self._bow_levels) + if isinstance(features, torch.Tensor): + features = [features,] + bow_target, _ = self.bow_extractor(features) + return bow_target, features + + def forward_test(self, img_orig, labels): + with torch.no_grad(): + features = self.feature_extractor_teacher(img_orig, self._bow_levels) + features = features if isinstance(features, torch.Tensor) else features[-1] + features = features.detach() + loss_cls, accuracy = self._linear_classification(features, labels) + + return loss_cls, accuracy + + def forward(self, img_orig, img_crops, labels=None): + """ Applies the OBoW self-supervised task to a mini-batch of images. 
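+
+ In rough terms, a training forward pass has three stages: (i) the student
+ feature extractor encodes each set of image crops and the BoW prediction
+ head, conditioned on the teacher's current visual-word dictionary, outputs
+ BoW prediction logits; (ii) the momentum-updated teacher encodes the
+ original image and its BoW extractor produces the (soft) BoW targets;
+ (iii) a cross-entropy-style loss (implemented via kl_div on the soft
+ targets) is computed between predictions and targets, plus, optionally,
+ the loss of the monitoring linear classifier.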
+ + Args: + img_orig: 4D tensor with shape [batch_size x 3 x img_height x img_width] + with the mini-batch of images from which the teacher network + generates the BoW targets. + img_crops: list of 4D tensors where each of them is a mini-batch of + image crops with shape [(batch_size * num_crops) x 3 x crop_height x crop_width] + from which the student network predicts the BoW targets. For + example, in the full version of OBoW this list will iclude a + [(batch_size * 2) x 3 x 160 x 160]-shaped tensor with two image crops + of size [160 x 160] pixels and a [(batch_size * 5) x 3 x 96 x 96]- + shaped tensor with five image patches of size [96 x 96] pixels. + labels: (optional) 1D tensor with shape [batch_size] with the class + labels of the img_orig images. If available, it would be used for + on-line monitoring the performance of the linear classifier. + + Returns: + losses: a tensor with the losses for each type of image crop and + (optionally) the loss of the linear classifier. + logs: a list of metrics for monitoring the training progress. It + includes the perplexity of the bow targets in a mini-batch + (perp_b), the perplexity of the bow targets in an image (perp_i), + and (optionally) the accuracy of a linear classifier on-line + trained on the teacher features (this is a proxy for monitoring + during training the quality of the learned features; Note, at the + end the features that are used are those of the student). + """ + if self.training is False: + # For testing, it only computes the linear classification accuracy. + return self.forward_test(img_orig, labels) + + #*********************** MAKE BOW PREDICTIONS ************************** + dictionary = self.bow_extractor.get_dictionary() + features = [self.feature_extractor(x) for x in img_crops] + bow_predictions = self.bow_predictor(features, dictionary) + #*********************************************************************** + #******************** COMPUTE THE BOW TARGETS ************************** + with torch.no_grad(): + self._update_teacher() + bow_target, features_t = self.generate_bow_targets(img_orig) + perp_b, perp_i = compute_bow_perplexity(bow_target) + #*********************************************************************** + #***************** COMPUTE THE BOW PREDICTION LOSSES ******************* + losses = [self._bow_loss(pred, bow_target) for pred in bow_predictions] + #*********************************************************************** + #****** MONITORING: APPLY LINEAR CLASSIFIER ON TEACHER FEATURES ******** + loss_cls, accuracy = self._linear_classification(features_t[-1], labels) + #*********************************************************************** + + losses = torch.stack(losses + [loss_cls,], dim=0).view(-1) + logs = list(perp_b + perp_i) + [accuracy,] + + return losses, logs + + +class BoWExtractor(nn.Module): + def __init__( + self, + num_words, + num_channels, + update_type="local_average", + inv_delta=15, + bow_pool="max"): + """Builds a BoW extraction module for the teacher network. + + It builds a BoW extraction module for the teacher network in which the + visual words vocabulary is on-line updated during training via a + queue-based vocabular/dictionary of randomly sampled local features. + + Args: + num_words: the number of visual words in the vocabulary/dictionary. + num_channels: the number of channels in the teacher feature maps and + visual word embeddings (of the vocabulary). + update_type: with what type of local features to update the queue-based + visual words vocabulary. 
Three update types are implemenented: + (a) "no_averaging": to update the queue it samples with uniform + distribution one local feature vector per image from the given + teacher feature maps. + (b) "global_averaging": to update the queue it computes from each + image a feature vector by globally average pooling the given + teacher feature maps. + (c) "local_averaging" (default option): to update the queue it + computes from each image a feature vector by first locally averaging + the given teacher feature map with a 3x3 kernel and then samples one + of the resulting feature vectors with uniform distribution. + inv_delta: the base value for the inverse temperature that is used for + computing the soft assignment codes over the visual words, used for + building the BoW targets. If inv_delta is None, then hard assignment + is used instead. + bow_pool: (default "max") how to reduce the assignment codes to BoW + vectors. Two options are supported, "max" for max-pooling and "avg" + for average-pooling. + """ + super(BoWExtractor, self).__init__() + + if inv_delta is not None: + assert isinstance(inv_delta, (float, int)) + assert inv_delta > 0.0 + assert bow_pool in ("max", "avg") + assert update_type in ("local_average", "global_average", "no_averaging") + + self._num_channels = num_channels + self._num_words = num_words + self._update_type = update_type + self._inv_delta = inv_delta + self._bow_pool = bow_pool + self._decay = 0.99 + + embedding = torch.randn(num_words, num_channels).clamp(min=0) + self.register_buffer("_embedding", embedding) + self.register_buffer("_embedding_ptr", torch.zeros(1, dtype=torch.long)) + self.register_buffer("_track_num_batches", torch.zeros(1)) + self.register_buffer("_min_distance_mean", torch.ones(1) * 0.5) + + @torch.no_grad() + def _update_dictionary(self, features): + """Given a teacher feature map it updates the queue-based vocabulary.""" + assert features.dim() == 4 + if self._update_type in ("local_average", "no_averaging"): + if self._update_type == "local_average": + features = F.avg_pool2d(features, kernel_size=3, stride=1, padding=0) + features = features.flatten(2) + batch_size, _, num_locs = features.size() + index = torch.randint(0, num_locs, (batch_size,), device=features.device) + index += torch.arange(batch_size, device=features.device) * num_locs + selected_features = features.permute(0,2,1).reshape(batch_size*num_locs, -1) + selected_features = selected_features[index].contiguous() + elif self._update_type == "global_average": + selected_features = utils.global_pooling(features, type="avg").flatten(1) + + assert selected_features.dim() == 2 + # Gather the selected_features from all nodes in the distributed setting. + selected_features = utils.concat_all_gather(selected_features) + + # To simplify the queue update implementation, it is assumed that the + # number of words is a multiple of the batch-size. + assert self._num_words % selected_features.shape[0] == 0 + batch_size = selected_features.shape[0] + # Replace the oldest visual word embeddings with the selected ones + # using the self._embedding_ptr pointer. Note that each training step + # self._embedding_ptr points to the older visual words. + ptr = int(self._embedding_ptr) + self._embedding[ptr:(ptr + batch_size),:] = selected_features + # move the pointer. 
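+ # The vocabulary thus acts as a circular (FIFO) queue: each step overwrites
+ # the `batch_size` oldest word embeddings and advances the pointer modulo
+ # `num_words`. As a purely illustrative example, with num_words=8192 and a
+ # gathered batch of 256 selected features, the whole dictionary would be
+ # refreshed every 8192 / 256 = 32 training steps.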
+ self._embedding_ptr[0] = (ptr + batch_size) % self._num_words + + @torch.no_grad() + def get_dictionary(self): + """Returns the visual word embeddings of the dictionary/vocabulary.""" + return self._embedding.detach().clone() + + @torch.no_grad() + def _broadast_initial_dictionary(self): + # Make sure every node in the distributed setting starts with the + # same dictionary. Maybe this is not necessary and copying the buffers + # across the models on all gpus is handled by nn.DistributedDataParallel + embedding = self._embedding.data.clone() + torch.distributed.broadcast(embedding, src=0) + self._embedding.data.copy_(embedding) + + def forward(self, features): + """Given a teacher feature maps, it generates BoW targets.""" + features = features[:, :, 1:-1, 1:-1].contiguous() + + # Compute distances between features and visual words embeddings. + embeddings_b = self._embedding.pow(2).sum(1) + embeddings_w = -2*self._embedding.unsqueeze(2).unsqueeze(3) + # dist = ||features||^2 + |embeddings||^2 + conv(features, -2 * embedding) + dist = (features.pow(2).sum(1, keepdim=True) + + F.conv2d(features, weight=embeddings_w, bias=embeddings_b)) + # dist shape: [batch_size, num_words, height, width] + min_dist, enc_indices = torch.min(dist, dim=1) + mu_min_dist = min_dist.mean() + mu_min_dist = utils.reduce_all(mu_min_dist) / utils.get_world_size() + + if self.training: + # exponential moving average update of self._min_distance_mean. + self._min_distance_mean.data.mul_(self._decay).add_( + mu_min_dist, alpha=(1. - self._decay)) + self._update_dictionary(features) + self._track_num_batches += 1 + + if self._inv_delta is None: + # Hard assignment codes. + codes = dist.new_full(list(dist.shape), 0.0) + codes.scatter_(1, enc_indices.unsqueeze(1), 1) + else: + # Soft assignment codes. + inv_delta_adaptive = self._inv_delta / self._min_distance_mean + codes = F.softmax(-inv_delta_adaptive * dist, dim=1) + + # Reduce assignment codes to bag-of-word vectors with global pooling. + bow = utils.global_pooling(codes, type=self._bow_pool).flatten(1) + bow = F.normalize(bow, p=1, dim=1) # L1-normalization. + return bow, codes + + def extra_repr(self): + str_options = ( + f"num_words={self._num_words}, num_channels={self._num_channels}, " + f"update_type={self._update_type}, inv_delta={self._inv_delta}, " + f"pool={self._bow_pool}, " + f"decay={self._decay}, " + f"track_num_batches={self._track_num_batches.item()}") + return str_options + + +class BoWExtractorMultipleLevels(nn.Module): + def __init__(self, opts_list): + """Builds a BoW extractor for each BoW level.""" + super(BoWExtractorMultipleLevels, self).__init__() + assert isinstance(opts_list, (list, tuple)) + self.bow_extractor = nn.ModuleList([ + BoWExtractor(**opts) for opts in opts_list]) + + @torch.no_grad() + def get_dictionary(self): + """Returns the dictionary of visual words from each BoW level.""" + return [b.get_dictionary() for b in self.bow_extractor] + + def forward(self, features): + """Given a list of feature levels, it generates multi-level BoWs.""" + assert isinstance(features, (list, tuple)) + assert len(features) == len(self.bow_extractor) + out = list(zip(*[b(f) for b, f in zip(self.bow_extractor, features)])) + return out + + +class BoWPredictor(nn.Module): + def __init__( + self, + num_channels_out=2048, + num_channels_in=[1024, 2048], + num_channels_hidden=4096, + kappa=8, + learn_kappa=False + ): + """ Builds the dynamic BoW prediction head of the student network. 
+ + It essentially builds a weight generation module for each BoW level for + which the student network needs to predict BoW. For example, in its + full version, OBoW uses two BoW levels, one for conv4 of ResNet (i.e., + penultimate feature scale of ResNet) and one for conv5 of ResNet (i.e., + final feature scale of ResNet). Therefore, in this case, the dynamic + BoW prediction head has two weight generation modules. + + Args: + num_channels_in: a list with the number of input feature channels for + each weight generation module. For example, if OBoW uses two BoW + levels and a ResNet50 backbone, then num_channels_in should be + [1024, 2048], where the first number is the number of channels of + the conv4 level of ResNet50 and the second number is the number of + channels of the conv5 level of ResNet50. + num_channels_out: the number of output feature channels for the weight + generation modules. + num_channels_hidden: the number of feature channels at the hidden + layers of the weight generator modules. + kappa: scalar with scale coefficient for the output weight vectors that + the weight generation modules produce. + learn_kappa (default False): if True kappa is a learnable parameter. + """ + super(BoWPredictor, self).__init__() + + assert isinstance(num_channels_in, (list, tuple)) + num_bow_levels = len(num_channels_in) + + generators = [] + for i in range(num_bow_levels): + generators.append(nn.Sequential()) + generators[i].add_module(f"b{i}_l2norm_in", utils.L2Normalize(dim=1)) + generators[i].add_module(f"b{i}_fc", nn.Linear(num_channels_in[i], num_channels_hidden, bias=False)) + generators[i].add_module(f"b{i}_bn", nn.BatchNorm1d(num_channels_hidden)) + generators[i].add_module(f"b{i}_rl", nn.ReLU(inplace=True)) + generators[i].add_module(f"b{i}_last_layer", nn.Linear(num_channels_hidden, num_channels_out)) + generators[i].add_module(f"b{i}_l2norm_out", utils.L2Normalize(dim=1)) + self.layers_w = nn.ModuleList(generators) + + self.scale = nn.Parameter( + torch.FloatTensor(num_bow_levels).fill_(kappa), + requires_grad=learn_kappa) + + def forward(self, features, dictionary): + """Dynamically predicts the BoW from the features of cropped images. + + During the forward pass, it gets as input a list with the features from + each type of extracted image crop and a list with the visual word + dictionaries of each BoW level. First, it uses the weight generation + modules for producing from each dictionary level the weight vectors + that would be used for the BoW prediction. Then, it applies the + produced weight vectors of each dictionary level to the given features + to compute the BoW prediction logits. + + Args: + features: list of 2D tensors where each of them is a mini-batch of + features (extracted from the image crops) with shape + [(batch_size * num_crops) x num_channels_out] from which the BoW + prediction head predicts the BoW targets. For example, in the full + version of OBoW, in which it reconstructs BoW from (a) 2 image crops + of size [160 x 160] and (b) 5 image patches of size [96 x 96], the + features argument includes a 2D tensor of shape + [(batch_size * 2) x num_channels_out] (extracted from the 2 + 160x160-sized crops) and a 2D tensor of shape + [(batch_size * 5) x num_channels_out] (extractted from the 5 + 96x96-sized crops). + dictionary: list of 2D tensors with the visual word embeddings + (i.e., dictionaries) for each BoW level. 
So, the i-th item of + dictionary has shape [num_words x num_channels_in[i]], where + num_channels_in[i] is the number of channels of the visual word + embeddings at the i-th BoW level. + + Output: + logits_list: list of lists of 2D tensors. Specifically, logits_list[i][j] + contains the 2D tensor of size [(batch_size * num_crops) x num_words] + with the BoW predictions from features[i] for the j-th BoW level + (made using the dictionary[j]). + """ + assert isinstance(dictionary, (list, tuple)) + assert len(dictionary) == len(self.layers_w) + + weight = [gen(dict).t() for gen, dict in zip(self.layers_w, dictionary)] + kappa = torch.split(self.scale, 1, dim=0) + logits_list = [ + [torch.mm(f.flatten(1) * k, w) for k, w in zip(kappa, weight)] + for f in features] + + return logits_list + + def extra_repr(self): + kappa = self.scale.data + s = f"(kappa, learnable={kappa.requires_grad}): {kappa.tolist()}" + return s + + +class OBoWSolver(solver.Solver): + def end_of_training_epoch(self): + if self._epoch == self._start_epoch: + # In case of distributed training, it checks if all processes have + # the same parameters. + utils.sanity_check_for_distributed_training(self.model) + + def start_of_training_epoch(self): + if self._epoch == 0: + if utils.is_dist_avail_and_initialized(): + # In case of distributed training, it ensures that alls + # processes start with the same version of the dictionaries. + for b in self.model.module.bow_extractor.bow_extractor: + b._broadast_initial_dictionary() + + if self._epoch == self._start_epoch: + # In case of distributed training, it checks if all processes have + # the same parameters. + utils.sanity_check_for_distributed_training(self.model) + + if utils.get_rank() == 0: + model = ( + self.model.module + if utils.is_dist_avail_and_initialized() else self.model) + alpha = ( + model._alpha.item() + if isinstance(model._alpha, torch.Tensor) else model._alpha) + logger.info(f"alpha: {alpha}") + + def evaluation_step(self, mini_batch, metric_logger): + if len(mini_batch) == 6: + return self._process_fewshot(mini_batch, metric_logger) + else: + return self.eval_lincls(mini_batch, metric_logger) + + def eval_lincls(self, mini_batch, metric_logger): + assert len(mini_batch) == 2 + images, labels = mini_batch + img_orig = images[0] + + if self.device is not None: + img_orig = img_orig.cuda(self.device , non_blocking=True) + labels = labels.cuda(self.device , non_blocking=True) + + batch_size = img_orig.size(0) + with torch.no_grad(): + # Forward model and compute lossses. + lincls, acc1 = self.model(img_orig, None, labels) + lincls = lincls.item() + + metric_logger["lincls"].update(lincls, batch_size) + metric_logger["acc@1"].update(acc1, batch_size) + + return metric_logger + + def train_step(self, mini_batch, metric_logger): + assert len(mini_batch) == 2 + images, labels = mini_batch + + if self.device is not None: + images = [img.cuda(self.device , non_blocking=True) for img in images] + labels = labels.cuda(self.device , non_blocking=True) + + img_orig = images[0] + img_crops = images[1:] + for i in range(len(img_crops)): + if img_crops[i].dim() == 5: + # [B x C x 3 x H x W] ==> [(B * C) x 3 x H x W] + img_crops[i] = utils.convert_from_5d_to_4d(img_crops[i]) + batch_size = img_orig.size(0) + + # Forward model and compute lossses. 
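+ # As described in OBoW.forward, `losses` is expected to hold one BoW
+ # prediction loss per crop type followed by the (detached) linear-classifier
+ # loss, while `logs` holds the per-batch and per-image BoW perplexities for
+ # each level followed by the linear-classification accuracy.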
+ losses, logs = self.model(img_orig, img_crops, labels) + loss_total = losses.sum() + + # compute gradient and do SGD step + self.optimizer.zero_grad() + loss_total.backward() + self.optimizer.step() + losses = losses.view(-1).tolist() + num_levels = (len(logs)-1) // 2 + assert len(logs) == (2*num_levels+1) + assert len(losses) == (len(img_crops) + 1) + + for i in range(len(img_crops)): + crop_sz = img_crops[i].size(3) + metric_logger[f"loss_crop{crop_sz}"].update(losses[0], batch_size) + + for i in range(num_levels): + metric_logger[f"perp_b_lev@{i}"].update(logs[i], batch_size) + metric_logger[f"perp_i_lev@{i}"].update(logs[i+num_levels], batch_size) + + # Linear classification accuracy. + metric_logger["linear_acc@1"].update(logs[-1], batch_size) + + return metric_logger + + def save_feature_extractor(self, distributed=False): + if utils.get_rank() == 0: + epoch = self._epoch + 1 + filename = f"feature_extractor_net_checkpoint_{epoch}.pth.tar" + filename = str(self.exp_dir / filename) + model = self.model.module if distributed else self.model + state = { + "epoch": epoch, + "network": model.feature_extractor.state_dict(), + "meters": None,} + torch.save(state, filename) + + def save_feature_extractor_in_torchvision_format(self, arch="resnet50"): + import torchvision.models as torchvision_models + distributed = utils.is_dist_avail_and_initialized() + model = self.model.module if distributed else self.model + model.eval() + + dictionary = model.bow_extractor.get_dictionary() + dictionary_w = [ + model.bow_predictor.layers_w[d](dictionary[d]) + for d in range(len(dictionary))] + scale = torch.chunk(model.bow_predictor.scale, len(dictionary), dim=0) + weight = dictionary_w[-1] * scale[-1].item() + num_words = weight.size(0) + logger.info('==> Converting and saving the OBoW student resnet ' + 'backbone to torchvision format.') + torchvision_resnet = torchvision_models.__dict__[arch]( + num_classes=num_words) + torchvision_resnet.eval() + + logger.info('====> Converting 1st convolutional layer (aka conv1)') + torchvision_resnet.conv1.load_state_dict( + model.feature_extractor._feature_blocks[0][0].state_dict()) + torchvision_resnet.bn1.load_state_dict( + model.feature_extractor._feature_blocks[0][1].state_dict()) + + logger.info('====> Converting 1st residual block (aka conv2).') + torchvision_resnet.layer1.load_state_dict( + model.feature_extractor._feature_blocks[1].state_dict()) + + logger.info('====> Converting 2nd residual block (aka conv3).') + torchvision_resnet.layer2.load_state_dict( + model.feature_extractor._feature_blocks[2].state_dict()) + + logger.info('====> Converting 3rd residual block (aka conv4).') + torchvision_resnet.layer3.load_state_dict( + model.feature_extractor._feature_blocks[3].state_dict()) + + logger.info('====> Converting 4th residual block (aka conv5).') + torchvision_resnet.layer4.load_state_dict( + model.feature_extractor._feature_blocks[4].state_dict()) + + logger.info('====> Converting and fixing the BoW classification ' + 'head for the last BoW level.') + with torch.no_grad(): + torchvision_resnet.fc.weight.copy_(weight) + torchvision_resnet.fc.bias.fill_(0) + + epoch = self._epoch + filename = f"tochvision_{arch}_student_K{num_words}_epoch{epoch}.pth.tar" + filename = str(self.exp_dir / filename) + logger.info(f'==> Saving the torchvision resnet model at: {filename}') + + if utils.get_rank() == 0: + torch.save({'network': torchvision_resnet.state_dict()}, filename) + + def _process_fewshot(self, episode, metric_logger): + """ Evaluates the OBoW's feature 
extractor on few-shot classifcation """ + images_train, labels_train, images_test, labels_test, _, _ = episode + if (self.device is not None): + images_train = images_train.cuda(self.device , non_blocking=True) + images_test = images_test.cuda(self.device , non_blocking=True) + labels_train = labels_train.cuda(self.device , non_blocking=True) + labels_test = labels_test.cuda(self.device , non_blocking=True) + + nKnovel = 1 + labels_train.max().item() + labels_train_1hot_size = list(labels_train.size()) + [nKnovel,] + labels_train_unsqueeze = labels_train.unsqueeze(dim=labels_train.dim()) + labels_train_1hot = images_train.new_full(labels_train_1hot_size, 0.0) + labels_train_1hot.scatter_( + len(labels_train_1hot_size) - 1, labels_train_unsqueeze, 1) + + model = ( + self.model.module if + utils.is_dist_avail_and_initialized() else self.model) + + model.feature_extractor.eval() + with torch.no_grad(): + _, accuracies = fewshot.fewshot_classification( + feature_extractor=model.feature_extractor, + images_train=images_train, + labels_train=labels_train, + labels_train_1hot=labels_train_1hot, + images_test=images_test, + labels_test=labels_test, + feature_levels=None) + + accuracies = accuracies.view(-1) + for i in range(accuracies.numel()): + metric_logger[f"acc_novel_@{i}"].update(accuracies[i].item(), 1) + + return metric_logger diff --git a/obow/classification.py b/obow/classification.py new file mode 100644 index 0000000..cf72c30 --- /dev/null +++ b/obow/classification.py @@ -0,0 +1,235 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import obow.utils as utils +import obow.solver as solver + +import logging +logger = utils.setup_dist_logger(logging.getLogger(__name__)) + + +class PredictionHead(nn.Module): + def __init__( + self, + num_channels, + num_classes, + batch_norm=False, + pred_type="linear", + pool_type="global_avg", + pool_params=None, + ): + """ Builds a prediction head for the classification task.""" + + super(PredictionHead, self).__init__() + + if pred_type != "linear": + raise NotImplementedError( + f"Not recognized / supported prediction head type '{pred_type}'." + f" Currently, only pred_type 'linear' is implemented.") + self.pred_type = pred_type + total_num_channels = num_channels + + self.layers = nn.Sequential() + if pool_type == "none": + if isinstance(pool_params, int): + output_size = pool_params + total_num_channels *= (output_size * output_size) + elif pool_type == "global_avg": + self.layers.add_module( + "pooling", utils.GlobalPooling(type="avg")) + elif pool_type == "avg": + assert isinstance(pool_params, (list, tuple)) + assert len(pool_params) == 4 + kernel_size, stride, padding, output_size = pool_params + total_num_channels *= (output_size * output_size) + self.layers.add_module( + "pooling", nn.AvgPool2d(kernel_size, stride, padding)) + elif pool_type == "adaptive_avg": + assert isinstance(pool_params, int) + output_size = pool_params + total_num_channels *= (output_size * output_size) + self.layers.add_module( + "pooling", nn.AdaptiveAvgPool2d(output_size)) + else: + raise NotImplementedError( + f"Not supported pool_type '{pool_type}'. Valid pooling types: " + "('none', 'global_avg', 'avg', 'adaptive_avg').") + + assert isinstance(batch_norm, bool) + if batch_norm: + # Affine is set to False. So, this batch norm layer does not have + # any learnable (scale and bias) parameters. It's only purpose is + # to normalize the features. So, the prediction layer is still + # linear. 
It is only used for the Places205 linear classification + # setting to make it the same as the benchmark code: + # https://github.com/facebookresearch/fair_self_supervision_benchmark + self.layers.add_module( + "batch_norm", nn.BatchNorm2d(num_channels, affine=False)) + self.layers.add_module("flattening", nn.Flatten()) + + prediction_layer = nn.Linear(total_num_channels, num_classes) + prediction_layer.weight.data.normal_(0.0, 0.01) + prediction_layer.bias.data.fill_(0.0) + + self.layers.add_module("prediction_layer", prediction_layer) + + def forward(self, features): + return self.layers(features) + + +class SupervisedClassification(nn.Module): + def __init__( + self, + feature_extractor, + linear_classifier_opts, + ): + super(SupervisedClassification, self).__init__() + self.feature_extractor = feature_extractor + self.linear_classifier = PredictionHead(**linear_classifier_opts) + + def forward(self, images, labels): + features = self.feature_extractor(images) + scores = self.linear_classifier(features) + loss = F.cross_entropy(scores, labels) + with torch.no_grad(): + accuracies = utils.accuracy(scores, labels, topk=(1,5)) + accuracies = [a.item() for a in accuracies] + + return loss, accuracies + + +class FrozenFeaturesLinearClassifier(nn.Module): + def __init__( + self, + feature_extractor, + linear_classifier_opts, + feature_levels=None, + ): + super(FrozenFeaturesLinearClassifier, self).__init__() + self.feature_levels = feature_levels + self.feature_extractor = feature_extractor + for param in feature_extractor.parameters(): + param.requires_grad = False + + self.linear_classifier = PredictionHead(**linear_classifier_opts) + + @torch.no_grad() + def precache_feature_extractor(self): + """Returns the feature extractor for precaching features.""" + out_feature_extractor = self.feature_extractor.get_subnetwork( + self.feature_levels) + self.feature_extractor = nn.Sequential() + liner_classifier_layers = self.linear_classifier._modules["layers"] + pooling_layer = liner_classifier_layers._modules.pop("pooling", None) + if (pooling_layer is not None): + out_feature_extractor.add_module( + "pooling_layer_from_linear_classifier", + pooling_layer) + return out_feature_extractor + + def forward(self, images, labels): + # Set to evaluation mode the feature extractor to avoid training / + # updating its batch norm statistics. + self.feature_extractor.eval() + with torch.no_grad(): + features = ( + self.feature_extractor(images) if self.feature_levels is None + else self.feature_extractor(images, self.feature_levels)) + scores = self.linear_classifier(features) + loss = F.cross_entropy(scores, labels) + with torch.no_grad(): + accuracies = utils.accuracy(scores, labels, topk=(1,5)) + accuracies = [a.item() for a in accuracies] + + return loss, accuracies + + +def get_parameters(model, start_lr_head=None): + if start_lr_head is not None: + # Use different learning rate for the classification head of the model. 
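+ # Parameters are split into two optimizer groups: those of the linear
+ # prediction head (matched by name below) get their own "start_lr"
+ # (e.g. --lr-head=0.5 in main_semisupervised.py), while the backbone trunk
+ # keeps the base learning rate (e.g. --lr=0.0002).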
+ def is_linear_head(key): + return key.find('linear_classifier.layers.prediction_layer.') != -1 + param_group_head = [ + param for key, param in model.named_parameters() + if param.requires_grad and is_linear_head(key)] + param_group_trunk = [ + param for key, param in model.named_parameters() + if param.requires_grad and (not is_linear_head(key))] + param_group_all = [ + param for key, param in model.named_parameters() + if param.requires_grad] + assert len(param_group_all) == (len(param_group_head) + len(param_group_trunk)) + parameters = [ + {"params": iter(param_group_head), "start_lr": start_lr_head}, + {"params": iter(param_group_trunk)}] + logger.info(f"#params in head: {len(param_group_head)}") + logger.info(f"#params in trunk: {len(param_group_trunk)}") + return parameters + else: + return filter(lambda p: p.requires_grad, model.parameters()) + + +def initialize_optimizer(model, opts): + logger.info(f"Initialize optimizer") + parameters = filter(lambda p: p.requires_grad, model.parameters()) + return solver.initialize_optimizer(parameters, opts) + + +class SupervisedClassifierSolver(solver.Solver): + def end_of_training_epoch(self): + if self._epoch == self._start_epoch: + utils.sanity_check_for_distributed_training(self.model, False) + + def start_of_training_epoch(self): + if self._epoch == self._start_epoch: + utils.sanity_check_for_distributed_training(self.model, False) + + def save_feature_extractor(self, distributed=False): + if utils.get_rank() == 0: + epoch = self._epoch + 1 + filename = f"feature_extractor_net_checkpoint_{epoch}.pth.tar" + filename = str(self.exp_dir / filename) + model = self.model.module if distributed else self.model + state = { + "epoch": epoch, + "network": model.feature_extractor.state_dict(), + "meters": None,} + torch.save(state, filename) + + def initialize_optimizer(self): + if self.optimizer is None: + logger.info(f"Initialize optimizer") + start_lr_head = self.opts.get("start_lr_head", None) + parameters = get_parameters(self.model, start_lr_head=start_lr_head) + self.optimizer = solver.initialize_optimizer(parameters, self.opts) + assert self.optimizer is not None + + def evaluation_step(self, mini_batch, metric_logger): + return self._process(mini_batch, metric_logger, training=False) + + def train_step(self, mini_batch, metric_logger): + return self._process(mini_batch, metric_logger, training=True) + + def _process(self, mini_batch, metric_logger, training): + assert isinstance(training, bool) + images, labels = mini_batch + if self.device is not None: + images = images.cuda(self.device , non_blocking=True) + labels = labels.cuda(self.device , non_blocking=True) + + with torch.set_grad_enabled(training): + loss, accuracies = self.model(images, labels) + + if training: + # compute gradient and do SGD step + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + + batch_size = images.size(0) + assert isinstance(accuracies, (list, tuple)) and len(accuracies) == 2 + metric_logger[f"loss"].update(loss.item(), batch_size) + metric_logger["acc@1"].update(accuracies[0], batch_size) + metric_logger["acc@5"].update(accuracies[1], batch_size) + + return metric_logger diff --git a/obow/datasets.py b/obow/datasets.py new file mode 100644 index 0000000..beda749 --- /dev/null +++ b/obow/datasets.py @@ -0,0 +1,1360 @@ +from __future__ import print_function + +import os +import os.path +import numpy as np +import torch +import torch.utils.data as data +import torchvision.transforms as T +import torchvision.datasets +import random +import 
json + +from PIL import ImageFilter + + +_MEAN_PIXEL_IMAGENET = [0.485, 0.456, 0.406] +_STD_PIXEL_IMAGENET = [0.229, 0.224, 0.225] + + +def generate_element_list(list_size, dataset_size): + if list_size == dataset_size: + return list(range(dataset_size)) + elif list_size < dataset_size: + return np.random.choice( + dataset_size, list_size, replace=False).tolist() + else: # list_size > list_size + num_times = list_size // dataset_size + residual = list_size % dataset_size + assert((num_times * dataset_size + residual) == list_size) + elem_list = list(range(dataset_size)) * num_times + if residual: + elem_list += np.random.choice( + dataset_size, residual, replace=False).tolist() + + return elem_list + + +def buildLabelIndex(labels): + label2inds = {} + for idx, label in enumerate(labels): + if label not in label2inds: + label2inds[label] = [] + label2inds[label].append(idx) + + return label2inds + + +class ParallelTransforms: + def __init__(self, transform_list): + assert isinstance(transform_list, (list, tuple)) + self.transform_list = transform_list + + def __call__(self, x): + return [transform(x) for transform in self.transform_list] + + def __str__(self): + str_transforms = f"ParallelTransforms([" + for i, transform in enumerate(self.transform_list): + str_transforms += f"\nTransform #{i}:\n{transform}, " + str_transforms += "\n])" + return str_transforms + + +class StackMultipleViews: + def __init__(self, transform, num_views): + assert num_views >= 1 + self.transform = transform + self.num_views = num_views + + def __call__(self, x): + if self.num_views == 1: + return self.transform(x).unsqueeze(dim=0) + else: + x_views = [self.transform(x) for _ in range(self.num_views)] + return torch.stack(x_views, dim=0) + + def __str__(self): + str_transforms = f"StackMultipleViews({self.num_views} x \n{self.transform})" + return str_transforms + + +class GaussianBlur(object): + def __init__(self, sigma=[.1, 2.]): + self.sigma = sigma + + def __call__(self, x): + sigma = random.uniform(self.sigma[0], self.sigma[1]) + x = x.filter(ImageFilter.GaussianBlur(radius=sigma)) + return x + + def __str__(self): + str_transforms = f"GaussianBlur(sigma={self.sigma})" + return str_transforms + + +class CropImagePatches: + """Crops from an image 3 x 3 overlapping patches.""" + def __init__( + self, + patch_size=96, + patch_jitter=24, + num_patches=5, + split_per_side=3): + + self.split_per_side = split_per_side + self.patch_size = patch_size + assert patch_jitter >= 0 + self.patch_jitter = patch_jitter + if num_patches is None: + num_patches = split_per_side**2 + assert num_patches > 0 and num_patches <= (split_per_side**2) + self.num_patches = num_patches + + def __call__(self, img): + _, height, width = img.size() + offset_y = ((height - self.patch_size - self.patch_jitter) + // (self.split_per_side - 1)) + offset_x = ((width - self.patch_size - self.patch_jitter) + // (self.split_per_side - 1)) + + patches = [] + for i in range(self.split_per_side): + for j in range(self.split_per_side): + y_top = i * offset_y + random.randint(0, self.patch_jitter) + x_left = j * offset_x + random.randint(0, self.patch_jitter) + y_bottom = y_top + self.patch_size + x_right = x_left + self.patch_size + patches.append(img[:, y_top:y_bottom, x_left:x_right]) + + if self.num_patches < (self.split_per_side * self.split_per_side): + indices = torch.randperm(len(patches))[:self.num_patches] + patches = [patches[i] for i in indices.tolist()] + + return torch.stack(patches, dim=0) + + def __str__(self): + print_str = ( + 
f'{self.__class__.__name__}(' + f'split_per_side={self.split_per_side}, ' + f'patch_size={self.patch_size}, ' + f'patch_jitter={self.patch_jitter}, ' + f'num_patches={self.num_patches}/{self.split_per_side**2})' + ) + return print_str + + +def subset_of_ImageNet_train_split(dataset_train, subset): + assert isinstance(subset, int) + assert subset > 0 + + all_indices = [] + for _, img_indices in buildLabelIndex(dataset_train.targets).items(): + assert len(img_indices) >= subset + all_indices += img_indices[:subset] + + dataset_train.imgs = [dataset_train.imgs[idx] for idx in all_indices] + dataset_train.samples = [dataset_train.samples[idx] for idx in all_indices] + dataset_train.targets = [dataset_train.targets[idx] for idx in all_indices] + assert len(dataset_train) == (subset * 1000) + + return dataset_train + + +def get_ImageNet_data_for_obow( + data_dir, + subset=None, + cjitter=[0.4, 0.4, 0.4, 0.1], + cjitter_p=0.8, + gray_p=0.2, + gaussian_blur=[0.1, 2.0], + gaussian_blur_p=0.5, + num_img_crops=2, + image_crop_size=160, + image_crop_range=[0.08, 0.6], + num_img_patches=0, + img_patch_preresize=256, + img_patch_preresize_range=[0.6, 1.0], + img_patch_size=96, + img_patch_jitter=24, + only_patches=False): + + normalize = T.Normalize(mean=_MEAN_PIXEL_IMAGENET, std=_STD_PIXEL_IMAGENET) + + image_crops_transform = T.Compose([ + T.RandomResizedCrop(image_crop_size, scale=image_crop_range), + T.RandomApply([T.ColorJitter(*cjitter)], p=cjitter_p), + T.RandomGrayscale(p=gray_p), + T.RandomApply([GaussianBlur(gaussian_blur)], p=gaussian_blur_p), + T.RandomHorizontalFlip(), + T.ToTensor(), + normalize, + ]) + + image_crops_transform = StackMultipleViews( + image_crops_transform, num_views=num_img_crops) + + transform_original_train = T.Compose([ + T.Resize(256), + T.CenterCrop(224), + T.RandomHorizontalFlip(), # So as, to see both image views. 
+ T.ToTensor(), + normalize, + ]) + transform_train = [transform_original_train, image_crops_transform] + + if num_img_patches > 0: + assert num_img_patches <= 9 + image_patch_transform = T.Compose([ + T.RandomResizedCrop(img_patch_preresize, scale=img_patch_preresize_range), + T.RandomApply([T.ColorJitter(*cjitter)], p=cjitter_p), + T.RandomGrayscale(p=gray_p), + T.RandomHorizontalFlip(), + T.ToTensor(), + normalize, + CropImagePatches( + patch_size=img_patch_size, patch_jitter=img_patch_jitter, + num_patches=num_img_patches, split_per_side=3), + ]) + if only_patches: + transform_train[-1] = image_patch_transform + else: + transform_train.append(image_patch_transform) + + transform_train = ParallelTransforms(transform_train) + + transform_original = T.Compose([ + T.Resize(256), + T.CenterCrop(224), + T.ToTensor(), + normalize, + ]) + transform_test = ParallelTransforms([transform_original,]) + + print(f"Image transforms during training: {transform_train}") + print(f"Image transforms during testing: {transform_test}") + + print("Loading data.") + dataset_train = torchvision.datasets.ImageFolder( + os.path.join(data_dir, 'train'), transform=transform_train) + dataset_test = torchvision.datasets.ImageFolder( + os.path.join(data_dir, 'val'), transform=transform_test) + + if (subset is not None) and (subset >= 1): + dataset_train = subset_of_ImageNet_train_split(dataset_train, subset) + + return dataset_train, dataset_test + + +def get_data_loaders_for_OBoW( + dataset_name, + data_dir, + batch_size, + workers, + distributed, + epoch_size, + **kwargs): + + assert isinstance(dataset_name, str) + assert isinstance(batch_size, int) + assert isinstance(workers, int) + assert isinstance(distributed, bool) + assert (epoch_size is None) or isinstance(epoch_size, int) + + if dataset_name == "ImageNet": + dataset_train, dataset_test = get_ImageNet_data_for_obow(data_dir, **kwargs) + else: + raise NotImplementedError(f"Not supported dataset {dataset_name}") + + if (epoch_size is not None) and (epoch_size != len(dataset_train)): + elem_list = generate_element_list(epoch_size, len(dataset_train)) + dataset_train = torch.utils.data.Subset(dataset_train, elem_list) + + print("Creating data loaders") + if distributed: + sampler_train = torch.utils.data.distributed.DistributedSampler(dataset_train) + sampler_test = torch.utils.data.distributed.DistributedSampler(dataset_test) + else: + sampler_train = torch.utils.data.RandomSampler(dataset_train) + sampler_test = torch.utils.data.SequentialSampler(dataset_test) + + loader_train = torch.utils.data.DataLoader( + dataset_train, + batch_size=batch_size, + shuffle=(sampler_train is None), + num_workers=workers, + pin_memory=True, + sampler=sampler_train, + drop_last=True) + + loader_test = torch.utils.data.DataLoader( + dataset_test, + batch_size=batch_size, + shuffle=False, + num_workers=workers, + pin_memory=True, + sampler=sampler_test, + drop_last=False) + + return ( + loader_train, sampler_train, dataset_train, + loader_test, sampler_test, dataset_test) + + +#**************** DATA LOADERS FOR IMAGE CLASSIFICATIONS *********************** +def get_ImageNet_data_classification(data_dir, subset=None): + normalize = T.Normalize(mean=_MEAN_PIXEL_IMAGENET, std=_STD_PIXEL_IMAGENET) + transform_train = T.Compose([ + T.RandomResizedCrop(224), + T.RandomHorizontalFlip(), + T.ToTensor(), + normalize, + ]) + transform_test = T.Compose([ + T.Resize(256), + T.CenterCrop(224), + T.ToTensor(), + normalize, + ]) + + print("Loading data.") + dataset_train = 
torchvision.datasets.ImageFolder( + os.path.join(data_dir, 'train'), transform=transform_train) + dataset_test = torchvision.datasets.ImageFolder( + os.path.join(data_dir, 'val'), transform=transform_test) + + if (subset is not None) and (subset >= 1): + dataset_train = subset_of_ImageNet_train_split(dataset_train, subset) + + return dataset_train, dataset_test + + +def get_ImageNet_data_semisupervised_classification(data_dir, percentage=1): + normalize = T.Normalize( + mean=_MEAN_PIXEL_IMAGENET, + std=_STD_PIXEL_IMAGENET + ) + transform_train = T.Compose([ + T.RandomResizedCrop(224), + T.RandomHorizontalFlip(), + T.ToTensor(), + normalize, + ]) + + transform_test = T.Compose([ + T.Resize(256), + T.CenterCrop(224), + T.ToTensor(), + normalize, + ]) + + print("Loading data.") + train_data_path = os.path.join(data_dir, "train") + dataset_train = torchvision.datasets.ImageFolder( + train_data_path, transform=transform_train) + dataset_test = torchvision.datasets.ImageFolder( + os.path.join(data_dir, 'val'), transform=transform_test) + + # take either 1% or 10% of images + assert percentage in (1, 10) + import urllib.request + BASE_URL_PATH = "https://raw.githubusercontent.com/google-research/simclr/master/imagenet_subsets/" + subset_file = urllib.request.urlopen( + BASE_URL_PATH + str(percentage) + "percent.txt") + list_imgs = [li.decode("utf-8").split('\n')[0] for li in subset_file] + + samples, imgs, targets = [], [], [] + for file in list_imgs: + file_path = pathlib.Path(os.path.join(train_data_path, file.split('_')[0], file)) + assert file_path.is_file() + file_path_str = str(file_path) + target = dataset_train.class_to_idx[file.split('_')[0]] + imgs.append((file_path_str, target)) + targets.append(targets) + samples.append((file_path_str, target)) + + dataset_train.imgs = imgs + dataset_train.targets = targets + dataset_train.samples = samples + + assert len(dataset_train) == len(list_imgs) + + return dataset_train, dataset_test + + +def get_Places205_data_classification(data_dir): + normalize = T.Normalize( + mean=_MEAN_PIXEL_IMAGENET, + std=_STD_PIXEL_IMAGENET + ) + transform_train = T.Compose([ + T.Resize(256), + T.RandomCrop(224), + T.RandomHorizontalFlip(), + T.ToTensor(), + normalize, + ]) + + transform_test = T.Compose([ + T.Resize(256), + T.CenterCrop(224), + T.ToTensor(), + normalize, + ]) + + print("Loading data.") + dataset_train = torchvision.datasets.ImageFolder( + os.path.join(data_dir, 'train'), transform=transform_train) + dataset_test = torchvision.datasets.ImageFolder( + os.path.join(data_dir, 'val'), transform=transform_test) + + return dataset_train, dataset_test + + +def get_data_loaders_classification( + dataset_name, + data_dir, + batch_size, + workers, + distributed, + epoch_size, + **kwargs): + + assert isinstance(dataset_name, str) + assert isinstance(batch_size, int) + assert isinstance(workers, int) + assert isinstance(distributed, bool) + assert (epoch_size is None) or isinstance(epoch_size, int) + + if dataset_name == "ImageNet": + dataset_train, dataset_test = get_ImageNet_data_classification( + data_dir, **kwargs) + elif dataset_name == "Places205": + dataset_train, dataset_test = get_Places205_data_classification( + data_dir) + else: + raise NotImplementedError(f"Not supported dataset {dataset_name}") + + if (epoch_size is not None) and (epoch_size != len(dataset_train)): + elem_list = generate_element_list(epoch_size, len(dataset_train)) + dataset_train = torch.utils.data.Subset(dataset_train, elem_list) + + print("Creating data loaders") + if 
distributed: + sampler_train = torch.utils.data.distributed.DistributedSampler( + dataset_train) + sampler_test = torch.utils.data.distributed.DistributedSampler( + dataset_test) + else: + sampler_train = torch.utils.data.RandomSampler(dataset_train) + sampler_test = torch.utils.data.SequentialSampler(dataset_test) + + loader_train = torch.utils.data.DataLoader( + dataset_train, + batch_size=batch_size, + shuffle=(sampler_train is None), + num_workers=workers, + pin_memory=True, + sampler=sampler_train, + drop_last=True) + + loader_test = torch.utils.data.DataLoader( + dataset_test, + batch_size=batch_size, + shuffle=False, + num_workers=workers, + pin_memory=True, + sampler=sampler_test, + drop_last=False) + + return ( + loader_train, sampler_train, dataset_train, + loader_test, sampler_test, dataset_test) + + +def get_data_loaders_semisupervised_classification( + dataset_name, + data_dir, + batch_size, + workers, + distributed, + epoch_size, + percentage, + **kwargs): + + assert isinstance(dataset_name, str) + assert isinstance(batch_size, int) + assert isinstance(workers, int) + assert isinstance(distributed, bool) + assert (epoch_size is None) or isinstance(epoch_size, int) + + if dataset_name == "ImageNet": + dataset_train, dataset_test = get_ImageNet_data_semisupervised_classification( + data_dir, percentage=percentage, **kwargs) + else: + raise NotImplementedError(f"Not supported dataset {dataset_name}") + + if (epoch_size is not None) and (epoch_size != len(dataset_train)): + elem_list = generate_element_list(epoch_size, len(dataset_train)) + dataset_train = torch.utils.data.Subset(dataset_train, elem_list) + + print("Creating data loaders") + if distributed: + sampler_train = torch.utils.data.distributed.DistributedSampler( + dataset_train) + sampler_test = torch.utils.data.distributed.DistributedSampler( + dataset_test) + else: + sampler_train = torch.utils.data.RandomSampler(dataset_train) + sampler_test = torch.utils.data.SequentialSampler(dataset_test) + + loader_train = torch.utils.data.DataLoader( + dataset_train, + batch_size=batch_size, + shuffle=(sampler_train is None), + num_workers=workers, + pin_memory=True, + sampler=sampler_train, + drop_last=True) + + loader_test = torch.utils.data.DataLoader( + dataset_test, + batch_size=batch_size, + shuffle=False, + num_workers=workers, + pin_memory=True, + sampler=sampler_test, + drop_last=False) + + return ( + loader_train, sampler_train, dataset_train, + loader_test, sampler_test, dataset_test) + + +#*************** DATA LOADERS FOR FEW-SHOT EVALUATION ************************** +def load_ImageNet_fewshot_split(class_names, version=1): + _IMAGENET_LOWSHOT_BENCHMARK_CATEGORY_SPLITS_PATH = ( + './data/IMAGENET_LOWSHOT_BENCHMARK_CATEGORY_SPLITS.json') + with open(_IMAGENET_LOWSHOT_BENCHMARK_CATEGORY_SPLITS_PATH, 'r') as f: + label_idx = json.load(f) + + assert len(label_idx['label_names']) == len(class_names) + + def get_class_indices(class_indices1): + class_indices2 = [] + for index in class_indices1: + class_name_this = label_idx['label_names'][index] + assert class_name_this in class_names + class_indices2.append(class_names.index(class_name_this)) + + class_names_tmp1 = [ + label_idx['label_names'][index] for index in class_indices1] + class_names_tmp2 = [class_names[index] for index in class_indices2] + + assert class_names_tmp1 == class_names_tmp2 + + return class_indices2 + + if version == 1: + base_classes = label_idx['base_classes'] + base_classes_val = label_idx['base_classes_1'] + base_classes_test = 
label_idx['base_classes_2'] + novel_classes_val = label_idx['novel_classes_1'] + novel_classes_test = label_idx['novel_classes_2'] + elif version == 2: + base_classes = get_class_indices(label_idx['base_classes']) + base_classes_val = get_class_indices(label_idx['base_classes_1']) + base_classes_test = get_class_indices(label_idx['base_classes_2']) + novel_classes_val = get_class_indices(label_idx['novel_classes_1']) + novel_classes_test = get_class_indices(label_idx['novel_classes_2']) + + return (base_classes, + base_classes_val, base_classes_test, + novel_classes_val, novel_classes_test) + + +class ImageNetLowShot: + def __init__(self, dataset, phase='train'): + assert phase in ('train', 'test', 'val') + self.data = dataset + self.phase = phase + print(f'Loading ImageNet dataset (few-shot benchmark) - phase {phase}') + #*********************************************************************** + (base_classes, _, _, novel_classes_val, novel_classes_test) = ( + load_ImageNet_fewshot_split(self.data.classes, version=1)) + #*********************************************************************** + + self.labels = [item[1] for item in self.data.imgs] + self.label2ind = buildLabelIndex(self.labels) + self.labelIds = sorted(self.label2ind.keys()) + self.num_cats = len(self.labelIds) + assert self.num_cats == 1000 + + self.labelIds_base = base_classes + self.num_cats_base = len(self.labelIds_base) + if self.phase=='val' or self.phase=='test': + self.labelIds_novel = ( + novel_classes_val if (self.phase=='val') else + novel_classes_test) + self.num_cats_novel = len(self.labelIds_novel) + + intersection = set(self.labelIds_base) & set(self.labelIds_novel) + assert len(intersection) == 0 + + def __getitem__(self, index): + return self.data[index] + + def __len__(self): + return len(self.data) + + +class FewShotDataset: + def __init__( + self, + dataset, + nKnovel=5, + nKbase=0, + nExemplars=1, + nTestNovel=15*5, + nTestBase=0, + epoch_size=500): + + self.dataset = dataset + self.phase = self.dataset.phase + max_possible_nKnovel = ( + self.dataset.num_cats_base if ( + self.phase=='train' or self.phase=='trainval') + else self.dataset.num_cats_novel) + + assert 0 <= nKnovel <= max_possible_nKnovel + self.nKnovel = nKnovel + + max_possible_nKbase = self.dataset.num_cats_base + nKbase = nKbase if nKbase >= 0 else max_possible_nKbase + if (self.phase=='train' or self.phase=='trainval') and nKbase > 0: + nKbase -= self.nKnovel + max_possible_nKbase -= self.nKnovel + + assert 0 <= nKbase <= max_possible_nKbase + self.nKbase = nKbase + self.nExemplars = nExemplars + self.nTestNovel = nTestNovel + self.nTestBase = nTestBase + self.epoch_size = epoch_size + self.is_eval_mode = (self.phase=='test') or (self.phase=='val') + + # remeber this state + state = random.getstate() + np_state = np.random.get_state() + + random.seed(0) + np.random.seed(0) + self._all_episodes = [] + for i in range(self.epoch_size): + Exemplars, Test, Kall, nKbase = self.sample_episode() + self._all_episodes.append((Exemplars, Test, Kall, nKbase)) + + # restore state + random.setstate(state) + np.random.set_state(np_state) + + + def sampleImageIdsFrom(self, cat_id, sample_size=1): + """ + Samples `sample_size` number of unique image ids picked from the + category `cat_id` (i.e., self.dataset.label2ind[cat_id]). + + Args: + cat_id: a scalar with the id of the category from which images will + be sampled. + sample_size: number of images that will be sampled. + + Returns: + image_ids: a list of length `sample_size` with unique image ids. 
+ """ + assert(cat_id in self.dataset.label2ind.keys()) + assert(len(self.dataset.label2ind[cat_id]) >= sample_size) + # Note: random.sample samples elements without replacement. + return random.sample(self.dataset.label2ind[cat_id], sample_size) + + def sampleCategories(self, cat_set, sample_size=1): + """ + Samples `sample_size` number of unique categories picked from the + `cat_set` set of categories. `cat_set` can be either 'base' or 'novel'. + + Args: + cat_set: string that specifies the set of categories from which + categories will be sampled. + sample_size: number of categories that will be sampled. + + Returns: + cat_ids: a list of length `sample_size` with unique category ids. + """ + if cat_set=='base': + labelIds = self.dataset.labelIds_base + elif cat_set=='novel': + labelIds = self.dataset.labelIds_novel + else: + raise ValueError('Not recognized category set {}'.format(cat_set)) + + assert(len(labelIds) >= sample_size) + # return sample_size unique categories chosen from labelIds set of + # categories (that can be either self.labelIds_base or self.labelIds_novel) + # Note: random.sample samples elements without replacement. + return random.sample(labelIds, sample_size) + + def sample_base_and_novel_categories(self, nKbase, nKnovel): + """ + Samples `nKbase` number of base categories and `nKnovel` number of novel + categories. + + Args: + nKbase: number of base categories + nKnovel: number of novel categories + + Returns: + Kbase: a list of length 'nKbase' with the ids of the sampled base + categories. + Knovel: a list of lenght 'nKnovel' with the ids of the sampled novel + categories. + """ + if self.is_eval_mode: + assert(nKnovel <= self.dataset.num_cats_novel) + # sample from the set of base categories 'nKbase' number of base + # categories. + Kbase = sorted(self.sampleCategories('base', nKbase)) + # sample from the set of novel categories 'nKnovel' number of novel + # categories. + Knovel = sorted(self.sampleCategories('novel', nKnovel)) + else: + # sample from the set of base categories 'nKnovel' + 'nKbase' number + # of categories. + cats_ids = self.sampleCategories('base', nKnovel+nKbase) + assert(len(cats_ids) == (nKnovel+nKbase)) + # Randomly pick 'nKnovel' number of fake novel categories and keep + # the rest as base categories. + random.shuffle(cats_ids) + Knovel = sorted(cats_ids[:nKnovel]) + Kbase = sorted(cats_ids[nKnovel:]) + + return Kbase, Knovel + + def sample_test_examples_for_base_categories(self, Kbase, nTestBase): + """ + Sample `nTestBase` number of images from the `Kbase` categories. + + Args: + Kbase: a list of length `nKbase` with the ids of the categories from + where the images will be sampled. + nTestBase: the total number of images that will be sampled. + + Returns: + Tbase: a list of length `nTestBase` with 2-element tuples. The 1st + element of each tuple is the image id that was sampled and the + 2nd elemend is its category label (which is in the range + [0, len(Kbase)-1]). + """ + Tbase = [] + if len(Kbase) > 0: + # Sample for each base category a number images such that the total + # number sampled images of all categories to be equal to `nTestBase`. 
+ KbaseIndices = np.random.choice( + np.arange(len(Kbase)), size=nTestBase, replace=True) + KbaseIndices, NumImagesPerCategory = np.unique( + KbaseIndices, return_counts=True) + for Kbase_idx, NumImages in zip(KbaseIndices, NumImagesPerCategory): + imd_ids = self.sampleImageIdsFrom( + Kbase[Kbase_idx], sample_size=NumImages) + Tbase += [(img_id, Kbase_idx) for img_id in imd_ids] + + assert len(Tbase) == nTestBase + + return Tbase + + def sample_train_and_test_examples_for_novel_categories( + self, Knovel, nTestExamplesTotal, nExemplars, nKbase): + """Samples train and test examples of the novel categories. + + Args: + Knovel: a list with the ids of the novel categories. + nTestExamplesTotal: the total number of test images that will be sampled + from all the novel categories. + nExemplars: the number of training examples per novel category that + will be sampled. + nKbase: the number of base categories. It is used as offset of the + category index of each sampled image. + + Returns: + Tnovel: a list of length `nTestNovel` with 2-element tuples. The + 1st element of each tuple is the image id that was sampled and + the 2nd element is its category label (which is in the range + [nKbase, nKbase + len(Knovel) - 1]). + Exemplars: a list of length len(Knovel) * nExemplars of 2-element + tuples. The 1st element of each tuple is the image id that was + sampled and the 2nd element is its category label (which is in + the ragne [nKbase, nKbase + len(Knovel) - 1]). + """ + + if len(Knovel) == 0: + return [], [] + + nKnovel = len(Knovel) + Tnovel = [] + Exemplars = [] + + assert (nTestExamplesTotal % nKnovel) == 0 + nTestExamples = nTestExamplesTotal // nKnovel + + for Knovel_idx in range(len(Knovel)): + img_ids = self.sampleImageIdsFrom( + Knovel[Knovel_idx], + sample_size=(nTestExamples + nExemplars)) + + img_labeled = img_ids[:(nTestExamples + nExemplars)] + img_tnovel = img_labeled[:nTestExamples] + img_exemplars = img_labeled[nTestExamples:] + + Tnovel += [ + (img_id, nKbase+Knovel_idx) for img_id in img_tnovel] + Exemplars += [ + (img_id, nKbase+Knovel_idx) for img_id in img_exemplars] + + assert len(Tnovel) == nTestExamplesTotal + assert len(Exemplars) == len(Knovel) * nExemplars + random.shuffle(Exemplars) + + return Tnovel, Exemplars + + def sample_episode(self): + """Samples a training episode.""" + nKnovel = self.nKnovel + nKbase = self.nKbase + nTestNovel = self.nTestNovel + nTestBase = self.nTestBase + nExemplars = self.nExemplars + + Kbase, Knovel = self.sample_base_and_novel_categories(nKbase, nKnovel) + Tbase = self.sample_test_examples_for_base_categories(Kbase, nTestBase) + Tnovel, Exemplars = self.sample_train_and_test_examples_for_novel_categories( + Knovel, nTestNovel, nExemplars, nKbase) + + # concatenate the base and novel category examples. + Test = Tbase + Tnovel + random.shuffle(Test) + Kall = Kbase + Knovel + return Exemplars, Test, Kall, nKbase + + + def createExamplesTensorData(self, examples): + """ + Creates the examples image and label tensor data. + + Args: + examples: a list of 2-element tuples, each representing a + train or test example. The 1st element of each tuple + is the image id of the example and 2nd element is the + category label of the example, which is in the range + [0, nK - 1], where nK is the total number of categories + (both novel and base). + + Returns: + images: a tensor of shape [nExamples, Height, Width, 3] with the + example images, where nExamples is the number of examples + (i.e., nExamples = len(examples)). 
+ labels: a tensor of shape [nExamples] with the category label + of each example. + """ + images = torch.stack( + [self.dataset[img_idx][0] for img_idx, _ in examples], + dim=0) + labels = torch.LongTensor( + [label for _, label in examples]) + return images, labels + + def __getitem__(self, index): + Exemplars, Test, Kall, nKbase = self._all_episodes[index] + Xt, Yt = self.createExamplesTensorData(Test) + Kall = torch.LongTensor(Kall) + if len(Exemplars) > 0: + Xe, Ye = self.createExamplesTensorData(Exemplars) + return Xe, Ye, Xt, Yt, Kall, nKbase + else: + return Xt, Yt, Kall, nKbase + + def __len__(self): + return self.epoch_size // self.batch_size + + +def get_ImageNet_fewshot_data(data_dir, split): + normalize = T.Normalize( + mean=_MEAN_PIXEL_IMAGENET, + std=_STD_PIXEL_IMAGENET + ) + transform = T.Compose([ + T.Resize(256), + T.CenterCrop(224), + T.ToTensor(), + normalize, + ]) + + print("Loading data.") + dataset = torchvision.datasets.ImageFolder( + os.path.join(data_dir, "val"), transform=transform) + + assert split in ("train", "val", "test") + dataset = ImageNetLowShot(dataset, phase=split) + + return dataset + + +def get_data_loaders_fewshot( + dataset_name, + data_dir, + batch_size, + workers, + distributed, + split='val', + epoch_size=500, + num_novel=5, + num_train=[1,5,10], + num_test=15*5): + + assert isinstance(dataset_name, str) + assert isinstance(batch_size, int) + assert isinstance(workers, int) + assert isinstance(distributed, bool) + + if dataset_name == "ImageNet": + dataset = get_ImageNet_fewshot_data(data_dir, split=split) + else: + raise NotImplementedError(f"Not supported dataset {dataset_name}.") + + assert isinstance(epoch_size, int) + sampler = torch.utils.data.SubsetRandomSampler(list(range(epoch_size))) + + if isinstance(num_train, int): + num_train = [num_train,] + + dataset_fewshot = [] + loader_fewshot = [] + for num_train_this in num_train: + dataset_fewshot_this = FewShotDataset( + dataset=dataset, + nKnovel=num_novel, + nKbase=0, + nExemplars=num_train_this, + nTestNovel=num_test, + nTestBase=0, + epoch_size=epoch_size) + + loader_fewshot_this = torch.utils.data.DataLoader( + dataset_fewshot_this, + batch_size=batch_size, + shuffle=False, + num_workers=workers, + pin_memory=False, + sampler=sampler, + drop_last=False) + dataset_fewshot.append(dataset_fewshot_this) + loader_fewshot.append(loader_fewshot_this) + + return loader_fewshot, sampler, dataset_fewshot + + +#******************************************************************************* +def get_ImageNet_data_for_visualization(data_dir, subset=None, split="train"): + normalize = T.Normalize( + mean=_MEAN_PIXEL_IMAGENET, + std=_STD_PIXEL_IMAGENET + ) + transform = T.Compose([ + T.Resize(256), + T.CenterCrop(224), + T.ToTensor(), + normalize, + ]) + + dataset = torchvision.datasets.ImageFolder( + os.path.join(data_dir, split), transform=transform) + + if (split == "train") and subset is not None: + assert isinstance(subset, int) + assert subset > 0 + + all_indices = [] + for _, img_indices in buildLabelIndex(dataset.targets).items(): + assert len(img_indices) >= subset + all_indices += img_indices[:subset] + + dataset.imgs = [dataset.imgs[idx] for idx in all_indices] + dataset.samples = [dataset.samples[idx] for idx in all_indices] + dataset.targets = [dataset.targets[idx] for idx in all_indices] + assert len(dataset) == (subset * 1000) + + return dataset + + +def get_data_loaders_for_visualization( + dataset_name, + data_dir, + batch_size, + workers, + distributed, + split, + **kwargs): + + 
assert isinstance(dataset_name, str) + assert isinstance(batch_size, int) + assert isinstance(workers, int) + assert isinstance(distributed, bool) + + if dataset_name == "ImageNet": + subset = kwargs.pop("subset", None) + dataset = get_ImageNet_data_for_visualization( + data_dir, split=split, subset=subset) + else: + raise NotImplementedError(f"Not supported dataset {dataset_name}.") + + print("Creating data loaders") + loader = torch.utils.data.DataLoader( + dataset, + batch_size=batch_size, + shuffle=False, + num_workers=workers, + pin_memory=True, + sampler=torch.utils.data.SequentialSampler(dataset), + drop_last=False) + + + return loader, dataset +#******************************************************************************* + + +#******************************************************************************* +# Code for pre-caching the features of the linear classifier. +#******************************************************************************* +import pathlib +from tqdm import tqdm + +NUM_VIEWS = 10 +CENTRAL_VIEW = 4 +COMMON_NP_TYPES = [ + np.uint8, + np.uint16, + np.uint32, + np.uint64, + np.int8, + np.int16, + np.int32, + np.int64, + np.float32, + np.float64 +] + +def str_to_dtype(string): + for dtype in COMMON_NP_TYPES: + if dtype.__name__ == string: + return dtype + raise KeyError + + +def create_memmap(dir_path, dtype, shape): + dir_path = pathlib.Path(dir_path) + dir_path.mkdir(parents=True, exist_ok=True) + + def dtype_to_str(dtype): + return dtype.__name__ + + metadata_dict = dict(dtype=dtype_to_str(dtype), shape=list(shape), count=0) + metadata_file = dir_path / 'metadata.json' + + def update_metadata(count): + metadata_dict['count'] = count + print(f'Update metadata with metadata_dict={metadata_dict}') + metadata_file.write_text(json.dumps(metadata_dict, indent=4)) + + update_metadata(count=0) + memmap_file = dir_path / 'memmap.npy' + memmap = np.memmap(memmap_file, dtype, mode='w+', shape=shape) + return memmap, update_metadata + + +def open_memmap(dir_path): + dir_path = pathlib.Path(dir_path) + metadata_dict = json.loads((dir_path / 'metadata.json').read_text()) + dtype = str_to_dtype(metadata_dict['dtype']) + shape = tuple(metadata_dict['shape']) + #count = metadata_dict['count'] + return np.memmap(dir_path / 'memmap.npy', dtype, 'r+', shape=shape) + + +class ExtractCropsDataset: + def __init__(self, dataset, init_size, crop_size, mean, std, five_crop=True): + self.data = dataset + self.init_size = init_size + self.crop_size = crop_size + self.five_crop = five_crop + + if five_crop: + self.crop = T.Compose([ + T.Resize(self.init_size), + T.FiveCrop((self.crop_size, self.crop_size)), + ]) + else: + self.crop = T.Compose([ + T.Resize(self.init_size), + T.CenterCrop(self.crop_size), + ]) + + self.normalize = T.Compose([ + T.ToTensor(), + T.Normalize(mean=mean, std=std), + ]) + + def __repr__(self): + suffix = f'(\nfive_crop={self.crop}, normalize={self.normalize}\n)' + return self.__class__.__name__ + suffix + + def __getitem__(self, index): + img, labels = self.data[index] + crops = self.crop(img) + if self.five_crop: + crops = torch.stack([self.normalize(x) for x in crops], dim=0) + else: + crops = self.normalize(crops).unsqueeze(0) + + assert crops.dim() == 4 + assert ( + (self.five_crop and crops.size(0) == 5) or + (not self.five_crop and crops.size(0) == 1)) + assert crops.size(1) == 3 + assert crops.size(2) == self.crop_size + assert crops.size(3) == self.crop_size + + return crops, labels + + def __len__(self): + return len(self.data) + + +def 
make_memmap_crops( + memmap_path, + feature_extractor, + dataset, + device, + num_workers, + batch_size, + num_views=5): + + feature_extractor = feature_extractor.cuda() + feature_extractor.eval() + + dataloader = torch.utils.data.DataLoader( + dataset, + num_workers=num_workers, + batch_size=batch_size, + shuffle=False, + drop_last=False, + pin_memory=True) + + update_metadata_step = 100 + memmap = None + count = 0 + num_imgs = len(dataset) + num_views *= 2 + num_features = len(dataset) * num_views + + with torch.no_grad(): + for i, (crops, _) in enumerate(tqdm(dataloader)): + crops = crops.cuda(device, non_blocking=True) + assert crops.dim() == 5 + # Add crop flips. + crops = torch.cat([crops, torch.flip(crops, dims=(4,))], dim=1) + assert crops.size(1) == num_views + batch_size_x_num_views = crops.size(0) * num_views + crops = crops.view([batch_size_x_num_views, ] + list(crops.size()[2:])) + features = feature_extractor(crops) + features_np = features.detach().cpu().numpy() + + if memmap is None: + memmap_shape = (num_features,) + features_np.shape[1:] + print(f'Creating dataset of size: {memmap_shape}') + memmap, update_metadata = create_memmap( + memmap_path, np.float32, memmap_shape) + + memmap[count:(count+batch_size_x_num_views)] = features_np + count += batch_size_x_num_views + + if ((i+1) % update_metadata_step) == 0: + # Update metadata every update_metadata_step mini-batches + update_metadata(count=count) + + if count != num_features: + raise ValueError(f'Count ({count}) must be equal to {num_features}.') + + update_metadata(count=count) + + +class PrecacheFeaturesDataset: + def __init__( + self, + data, + labels, + feature_extractor, + cache_dir, + random_view, + device, + init_size=256, + crop_size=224, + mean=_MEAN_PIXEL_IMAGENET, + std=_STD_PIXEL_IMAGENET, + precache_num_workers=4, + precache_batch_size=10, + epoch_size=None, + five_crop=True): + """ If the cache is made, we don't need the feature extractor.""" + + cache_dir = pathlib.Path(cache_dir) + cache_dir.mkdir(parents=True, exist_ok=True) + print(f'cache_dir: {cache_dir}') + if five_crop: + done_file = cache_dir / 'cache_done' + memmap_dir = cache_dir / 'ten_crop' + else: + done_file = cache_dir / 'cache_done_1crop' + memmap_dir = cache_dir / 'single_crop' + + if (epoch_size is not None) and (epoch_size != len(data)): + elem_list = generate_element_list(epoch_size, len(data)) + data = torch.utils.data.Subset(data, elem_list) + labels = [labels[i] for i in elem_list] + assert len(labels) == len(data) + + self.labels = labels + self.data = ExtractCropsDataset( + data, init_size=init_size, crop_size=crop_size, mean=mean, std=std, + five_crop=five_crop) + + if not done_file.exists(): + print("Creating the memmap cache. 
It's going to take a while") + make_memmap_crops( + memmap_path=memmap_dir, + feature_extractor=feature_extractor, + dataset=self.data, + device=device, + num_workers=precache_num_workers, + batch_size=precache_batch_size, + num_views=(5 if five_crop else 1)) + done_file.touch() + + self._num_view = 10 if five_crop else 2 + self._central_view = CENTRAL_VIEW if five_crop else 0 + self.all_features = open_memmap(memmap_dir) + self.random_view = random_view + #self.view_index = 0 + + def __len__(self): + return len(self.labels) + + def __getitem__(self, index): + view_index = ( + random.randint(0, self._num_view-1) + if self.random_view else self._central_view) + + total_index = index * self._num_view + view_index + feature = torch.from_numpy(self.all_features[total_index]) + label = self.labels[index] + return feature, label + + +def get_data_loaders_linear_classification_precache( + dataset_name, + data_dir, + batch_size, + workers, + epoch_size, + feature_extractor, + cache_dir, + device, + precache_batch_size=10, + five_crop=True, + subset=None): + + assert isinstance(dataset_name, str) + assert isinstance(batch_size, int) + assert isinstance(workers, int) + assert (epoch_size is None) or isinstance(epoch_size, int) + + if dataset_name in ("ImageNet", "Places205"): + print("Loading data.") + dataset_train = torchvision.datasets.ImageFolder( + os.path.join(data_dir, "train"), transform=None) + dataset_test = torchvision.datasets.ImageFolder( + os.path.join(data_dir, "val"), transform=None) + train_split_str = "train" + if (subset is not None and subset >= 1): + train_split_str += f"_subset{subset}" + dataset_train = subset_of_ImageNet_train_split(dataset_train, subset) + + precache_batch_size_train = ( + (precache_batch_size // 10) + if five_crop else precache_batch_size) + dataset_train = PrecacheFeaturesDataset( + data=dataset_train, + labels=dataset_train.targets, + feature_extractor=feature_extractor, + cache_dir=cache_dir / dataset_name / train_split_str, + random_view=True, + device=device, + init_size=256, + crop_size=224, + mean=_MEAN_PIXEL_IMAGENET, + std=_STD_PIXEL_IMAGENET, + precache_num_workers=workers, + precache_batch_size=precache_batch_size_train, + epoch_size=epoch_size, + five_crop=five_crop) + + dataset_test = PrecacheFeaturesDataset( + data=dataset_test, + labels=dataset_test.targets, + feature_extractor=feature_extractor, + cache_dir=cache_dir / dataset_name / "val", + random_view=False, + device=device, + init_size=256, + crop_size=224, + mean=_MEAN_PIXEL_IMAGENET, + std=_STD_PIXEL_IMAGENET, + precache_num_workers=workers, + precache_batch_size=precache_batch_size, + epoch_size=None, + five_crop=False) + else: + raise VallueError(f"Unknown/not supported dataset {dataset_name}") + + print("Creating data loaders") + sampler_train = torch.utils.data.RandomSampler(dataset_train) + sampler_test = torch.utils.data.SequentialSampler(dataset_test) + + loader_train = torch.utils.data.DataLoader( + dataset_train, + batch_size=batch_size, + shuffle=(sampler_train is None), + num_workers=workers, + pin_memory=True, + sampler=sampler_train, + drop_last=True) + + loader_test = torch.utils.data.DataLoader( + dataset_test, + batch_size=batch_size, + shuffle=False, + num_workers=workers, + pin_memory=True, + sampler=sampler_test, + drop_last=False) + + return ( + loader_train, sampler_train, dataset_train, + loader_test, sampler_test, dataset_test) diff --git a/obow/feature_extractor.py b/obow/feature_extractor.py new file mode 100644 index 0000000..a862b3c --- /dev/null +++ 
b/obow/feature_extractor.py @@ -0,0 +1,299 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.models as models + +import obow.utils as utils + + +class SequentialFeatureExtractorAbstractClass(nn.Module): + def __init__(self, all_feat_names, feature_blocks): + super(SequentialFeatureExtractorAbstractClass, self).__init__() + + assert(isinstance(feature_blocks, list)) + assert(isinstance(all_feat_names, list)) + assert(len(all_feat_names) == len(feature_blocks)) + + self.all_feat_names = all_feat_names + self._feature_blocks = nn.ModuleList(feature_blocks) + + + def _parse_out_keys_arg(self, out_feat_keys): + # By default return the features of the last layer / module. + out_feat_keys = ( + [self.all_feat_names[-1],] if out_feat_keys is None else + out_feat_keys) + + if len(out_feat_keys) == 0: + raise ValueError('Empty list of output feature keys.') + + for f, key in enumerate(out_feat_keys): + if key not in self.all_feat_names: + raise ValueError( + 'Feature with name {0} does not exist. ' + 'Existing features: {1}.'.format(key, self.all_feat_names)) + elif key in out_feat_keys[:f]: + raise ValueError( + 'Duplicate output feature key: {0}.'.format(key)) + + # Find the highest output feature in `out_feat_keys + max_out_feat = max( + [self.all_feat_names.index(key) for key in out_feat_keys]) + + return out_feat_keys, max_out_feat + + def get_subnetwork(self, out_feat_key): + if isinstance(out_feat_key, str): + out_feat_key = [out_feat_key,] + _, max_out_feat = self._parse_out_keys_arg(out_feat_key) + subnetwork = nn.Sequential() + for f in range(max_out_feat+1): + subnetwork.add_module( + self.all_feat_names[f], + self._feature_blocks[f] + ) + return subnetwork + + def forward(self, x, out_feat_keys=None): + """Forward the image `x` through the network and output the asked features. + Args: + x: input image. + out_feat_keys: a list/tuple with the feature names of the features + that the function should return. If out_feat_keys is None ( + default value) then the last feature of the network is returned. + + Return: + out_feats: If multiple output features were asked then `out_feats` + is a list with the asked output features placed in the same + order as in `out_feat_keys`. If a single output feature was + asked then `out_feats` is that output feature (and not a list). 
+ """ + out_feat_keys, max_out_feat = self._parse_out_keys_arg(out_feat_keys) + out_feats = [None] * len(out_feat_keys) + + feat = x + for f in range(max_out_feat+1): + feat = self._feature_blocks[f](feat) + key = self.all_feat_names[f] + if key in out_feat_keys: + out_feats[out_feat_keys.index(key)] = feat + + out_feats = (out_feats[0] if len(out_feats) == 1 else out_feats) + + return out_feats + + +class BasicBlock(nn.Module): + def __init__( + self, + in_planes, + out_planes, + stride, + drop_rate=0.0, + kernel_size=3): + super(BasicBlock, self).__init__() + + if not isinstance(kernel_size, (list, tuple)): + kernel_size = [kernel_size, kernel_size] + assert isinstance(kernel_size, (list, tuple)) + assert len(kernel_size) == 2 + + kernel_size1, kernel_size2 = kernel_size + + assert kernel_size1 == 1 or kernel_size1 == 3 + padding1 = 1 if kernel_size1 == 3 else 0 + assert kernel_size2 == 1 or kernel_size2 == 3 + padding2 = 1 if kernel_size2 == 3 else 0 + + + self.equalInOut = (in_planes == out_planes and stride == 1) + + self.convResidual = nn.Sequential() + + if self.equalInOut: + self.convResidual.add_module('bn1', nn.BatchNorm2d(in_planes)) + self.convResidual.add_module('relu1', nn.ReLU(inplace=True)) + + self.convResidual.add_module( + 'conv1', + nn.Conv2d( + in_planes, out_planes, kernel_size=kernel_size1, + stride=stride, padding=padding1, bias=False)) + + self.convResidual.add_module('bn2', nn.BatchNorm2d(out_planes)) + self.convResidual.add_module('relu2', nn.ReLU(inplace=True)) + self.convResidual.add_module( + 'conv2', + nn.Conv2d( + out_planes, out_planes, kernel_size=kernel_size2, + stride=1, padding=padding2, bias=False)) + + if drop_rate > 0: + self.convResidual.add_module('dropout', nn.Dropout(p=drop_rate)) + + if self.equalInOut: + self.convShortcut = nn.Sequential() + else: + self.convShortcut = nn.Conv2d( + in_planes, out_planes, kernel_size=1, stride=stride, + padding=0, bias=False) + + def forward(self, x): + return self.convShortcut(x) + self.convResidual(x) + + +class NetworkBlock(nn.Module): + def __init__( + self, nb_layers, in_planes, out_planes, block, stride, drop_rate=0.0): + super(NetworkBlock, self).__init__() + + self.layer = self._make_layer( + block, in_planes, out_planes, nb_layers, stride, drop_rate) + + def _make_layer( + self, block, in_planes, out_planes, nb_layers, stride, drop_rate): + + layers = [] + for i in range(nb_layers): + in_planes_arg = i == 0 and in_planes or out_planes + stride_arg = i == 0 and stride or 1 + layers.append( + block(in_planes_arg, out_planes, stride_arg, drop_rate)) + + return nn.Sequential(*layers) + + def forward(self, x): + return self.layer(x) + + +class WideResNet(SequentialFeatureExtractorAbstractClass): + def __init__( + self, + depth, + widen_factor=1, + drop_rate=0.0, + strides=[2, 2, 2], + global_pooling=True): + + assert (depth - 4) % 6 == 0 + num_channels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor] + num_layers = [int((depth - 4) / 6) for _ in range(3)] + + block = BasicBlock + + all_feat_names = [] + feature_blocks = [] + + # 1st conv before any network block + conv1 = nn.Sequential() + conv1.add_module( + 'Conv', + nn.Conv2d(3, num_channels[0], kernel_size=3, padding=1, bias=False)) + conv1.add_module('BN', nn.BatchNorm2d(num_channels[0])) + conv1.add_module('ReLU', nn.ReLU(inplace=True)) + feature_blocks.append(conv1) + all_feat_names.append('conv1') + + # 1st block. 
+ block1 = nn.Sequential() + block1.add_module( + 'Block', + NetworkBlock( + num_layers[0], num_channels[0], num_channels[1], BasicBlock, + strides[0], drop_rate)) + block1.add_module('BN', nn.BatchNorm2d(num_channels[1])) + block1.add_module('ReLU', nn.ReLU(inplace=True)) + feature_blocks.append(block1) + all_feat_names.append('block1') + + # 2nd block. + block2 = nn.Sequential() + block2.add_module( + 'Block', + NetworkBlock( + num_layers[1], num_channels[1], num_channels[2], BasicBlock, + strides[1], drop_rate)) + block2.add_module('BN', nn.BatchNorm2d(num_channels[2])) + block2.add_module('ReLU', nn.ReLU(inplace=True)) + feature_blocks.append(block2) + all_feat_names.append('block2') + + # 3rd block. + block3 = nn.Sequential() + block3.add_module( + 'Block', + NetworkBlock( + num_layers[2], num_channels[2], num_channels[3], BasicBlock, + strides[2], drop_rate)) + block3.add_module('BN', nn.BatchNorm2d(num_channels[3])) + block3.add_module('ReLU', nn.ReLU(inplace=True)) + feature_blocks.append(block3) + all_feat_names.append('block3') + + # global average pooling. + if global_pooling: + feature_blocks.append(utils.GlobalPooling(type="avg")) + all_feat_names.append('GlobalPooling') + + super(WideResNet, self).__init__(all_feat_names, feature_blocks) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + +class ResNet(SequentialFeatureExtractorAbstractClass): + def __init__(self, arch, pretrained=False, global_pooling=True): + net = models.__dict__[arch](num_classes=1000, pretrained=pretrained) + print(f'==> Pretrained parameters: {pretrained}') + all_feat_names = [] + feature_blocks = [] + + # 1st conv before any resnet block + conv1 = nn.Sequential() + conv1.add_module('Conv', net.conv1) + conv1.add_module('bn', net.bn1) + conv1.add_module('relu', net.relu) + conv1.add_module('maxpool', net.maxpool) + feature_blocks.append(conv1) + all_feat_names.append('conv1') + + # 1st block. + feature_blocks.append(net.layer1) + all_feat_names.append('block1') + # 2nd block. + feature_blocks.append(net.layer2) + all_feat_names.append('block2') + # 3rd block. + feature_blocks.append(net.layer3) + all_feat_names.append('block3') + # 4th block. + feature_blocks.append(net.layer4) + all_feat_names.append('block4') + # global average pooling. 
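(Assumption: `utils.GlobalPooling(type="avg")` is taken here to behave like global average pooling, i.e. reducing each feature map to its spatial mean; a rough plain-PyTorch equivalent is sketched below.)

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 2048, 7, 7)        # e.g. ResNet-50 block4 output for 224x224 inputs
pooled = F.adaptive_avg_pool2d(x, 1)  # -> (2, 2048, 1, 1): spatial average per channel
print(pooled.shape)
```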
+ if global_pooling: + feature_blocks.append(utils.GlobalPooling(type="avg")) + all_feat_names.append('GlobalPooling') + + super(ResNet, self).__init__(all_feat_names, feature_blocks) + self.num_channels = net.fc.in_features + + +def FeatureExtractor(arch, opts): + all_architectures = ( + 'wrn', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', + 'resnext101_32x8d', 'resnext50_32x4d', 'wide_resnet101_2', + 'wide_resnet50_2') + + assert arch in all_architectures + if arch == 'wrn': + num_channels = opts["widen_factor"] * 64 + return WideResNet(**opts), num_channels + else: + resnet_extractor = ResNet(arch=arch, **opts) + return resnet_extractor, resnet_extractor.num_channels diff --git a/obow/fewshot.py b/obow/fewshot.py new file mode 100644 index 0000000..ae94878 --- /dev/null +++ b/obow/fewshot.py @@ -0,0 +1,133 @@ +from __future__ import print_function + +import torch +import torch.nn.functional as F +import obow.utils as utils + + +def preprocess_5D_features(features, global_pooling): + meta_batch_size, num_examples, channels, height, width = features.size() + features = features.view( + meta_batch_size * num_examples, channels, height, width) + + if global_pooling: + features = utils.global_pooling(features, "avg") + + features = features.view(meta_batch_size, num_examples, -1) + + return features + + +def average_train_features(features_train, labels_train): + labels_train_transposed = labels_train.transpose(1,2) + weight_novel = torch.bmm(labels_train_transposed, features_train) + weight_novel = weight_novel.div( + labels_train_transposed.sum(dim=2, keepdim=True).expand_as( + weight_novel)) + + return weight_novel + + +def few_shot_classifier_with_prototypes( + features_test, features_train, labels_train, + scale_cls=10.0, global_pooling=True): + + #******* Generate classification weights for the novel categories ****** + if features_train.dim() == 5: + features_train = preprocess_5D_features(features_train, global_pooling) + features_test = preprocess_5D_features(features_test, global_pooling) + + assert features_train.dim() == 3 + assert features_test.dim() == 3 + + meta_batch_size = features_train.size(0) + num_novel = labels_train.size(2) + features_train = F.normalize(features_train, p=2, dim=2) + prototypes = average_train_features(features_train, labels_train) + prototypes = prototypes.view(meta_batch_size, num_novel, -1) + #*********************************************************************** + features_test = F.normalize(features_test, p=2, dim=2) + prototypes = F.normalize(prototypes, p=2, dim=2) + scores = scale_cls * torch.bmm(features_test, prototypes.transpose(1,2)) + + return scores + + +def few_shot_feature_classification( + classifier, features_test, features_train, labels_train_1hot, labels_test): + + scores = few_shot_classifier_with_prototypes( + features_test=features_test, + features_train=features_train, + labels_train=labels_train_1hot) + + assert scores.dim() == 3 + + scores = scores.view(scores.size(0) * scores.size(1), -1) + labels_test = labels_test.view(-1) + assert scores.size(0) == labels_test.size(0) + + loss = F.cross_entropy(scores, labels_test) + + with torch.no_grad(): + accuracy = utils.accuracy(scores, labels_test, topk=(1,)) + + return scores, loss, accuracy + + +@torch.no_grad() +def fewshot_classification( + feature_extractor, + images_train, + labels_train, + labels_train_1hot, + images_test, + labels_test, + feature_levels): + assert images_train.dim() == 5 + assert images_test.dim() == 5 + assert images_train.size(0) == 
images_test.size(0) + assert images_train.size(2) == images_test.size(2) + assert images_train.size(3) == images_test.size(3) + assert images_train.size(4) == images_test.size(4) + assert labels_train.dim() == 2 + assert labels_test.dim() == 2 + assert labels_train.size(0) == labels_test.size(0) + assert labels_train.size(0) == images_train.size(0) + assert (feature_levels is None) or isinstance(feature_levels, (list, tuple)) + meta_batch_size = images_train.size(0) + + images_train = utils.convert_from_5d_to_4d(images_train) + images_test = utils.convert_from_5d_to_4d(images_test) + labels_test = labels_test.view(-1) + batch_size_train = images_train.size(0) + images = torch.cat([images_train, images_test], dim=0) + + # Extract features from the train and test images. + features = feature_extractor(images, feature_levels) + if isinstance(features, torch.Tensor): + features = [features,] + + labels_test =labels_test.view(-1) + + loss, accuracy = [], [] + for i, features_i in enumerate(features): + features_train = features_i[:batch_size_train] + features_test = features_i[batch_size_train:] + features_train = utils.add_dimension(features_train, meta_batch_size) + features_test = utils.add_dimension(features_test, meta_batch_size) + + scores = few_shot_classifier_with_prototypes( + features_test, features_train, labels_train_1hot, + scale_cls=10.0, global_pooling=True) + + scores = scores.view(scores.size(0) * scores.size(1), -1) + assert scores.size(0) == labels_test.size(0) + loss.append(F.cross_entropy(scores, labels_test)) + with torch.no_grad(): + accuracy.append(utils.accuracy(scores, labels_test, topk=(1,))[0]) + + loss = torch.stack(loss, dim=0) + accuracy = torch.stack(accuracy, dim=0) + + return loss, accuracy diff --git a/obow/solver.py b/obow/solver.py new file mode 100644 index 0000000..39e47f7 --- /dev/null +++ b/obow/solver.py @@ -0,0 +1,372 @@ +"""Define a generic class for training and testing learning algorithms.""" +import sys +import os +import os.path +import pathlib +import datetime +import glob +import logging +import time +import math + +import torch +import torch.optim +import obow.utils as utils + + +logger = utils.setup_dist_logger(logging.getLogger(__name__)) + + +def initialize_optimizer(parameters, opts): + if opts is None: + return None + + optim_type = opts["optim_type"] + learning_rate = opts["lr"] + + if optim_type == "adam": + optimizer = torch.optim.Adam( + parameters, + lr=learning_rate, + betas=opts["beta"], + weight_decay=opts["weight_decay"]) + elif optim_type == "sgd": + optimizer = torch.optim.SGD( + parameters, + lr=learning_rate, + momentum=opts["momentum"], + weight_decay=opts["weight_decay"], + nesterov=opts["nesterov"]) + else: + raise NotImplementedError(f"Unrecognized optim_type: {optim_type}.") + + return optimizer + + +def compute_cosine_learning_rate(epoch, start_lr, end_lr, num_epochs, warmup_epochs=0): + if (warmup_epochs > 0) and (epoch < warmup_epochs): + # Warm-up period. + return start_lr * (float(epoch) / warmup_epochs) + assert epoch >= warmup_epochs + + scale = 0.5 * (1. 
+ math.cos((math.pi * (epoch-warmup_epochs)) / (num_epochs-warmup_epochs))) + return end_lr + (start_lr - end_lr) * scale + + +class Solver: + def __init__( + self, + model, + exp_dir, + device, + opts, + print_freq=100, + optimizer=None, + use_fp16=False, + amp=None): + logger.info(f"Initialize solver: {opts}") + self.exp_dir = pathlib.Path(exp_dir) + self.exp_name = self.exp_dir.name + if utils.get_rank() == 0: + os.makedirs(self.exp_dir, exist_ok=True) + + self.model = model + self.opts = opts + self.optimizer = optimizer + self.use_fp16 = use_fp16 + if self.use_fp16: + assert amp is not None + self.amp = amp + + self.start_lr = self.opts["lr"] + self.current_lr = self.start_lr + self.num_epochs = self.opts["num_epochs"] + self.lr_schedule_type = self.opts["lr_schedule_type"] + assert self.lr_schedule_type in ("cos", "step_lr", "cos_warmup") + # TODO: use torch.optim.lr_scheduler the package. + if self.lr_schedule_type == "step_lr": + self.lr_schedule = self.opts["lr_schedule"] + self.lr_decay = self.opts["lr_decay"] + elif self.lr_schedule_type in ("cos", "cos_warmup"): + self.end_lr = self.opts.pop("end_lr", 0.0) + if self.lr_schedule_type == "cos_warmup": + self.warmup_epochs = self.opts.pop("warmup_epochs") + else: + self.warmup_epochs = 0 + self.eval_freq = self.opts.pop("eval_freq", 1) + self._best_metric_name = self.opts.get("best_metric_name") + self._best_largest = ( + self.opts["best_largest"] + if ("best_largest" in self.opts) else True) + self.reset_best_model_record() + + self.permanent = self.opts.get("permanent", -1) + self.print_freq = print_freq + self.device = device + self._epoch = 0 + + def set_device(self, device): + self.device = device + + def reset_best_model_record(self): + self._best_metric_val = None + self._best_model_meters = None + self._best_epoch = None + + def initialize_optimizer(self): + if self.optimizer is None: + logger.info(f"Initialize optimizer") + parameters = filter(lambda p: p.requires_grad, self.model.parameters()) + self.optimizer = initialize_optimizer(parameters, self.opts) + assert self.optimizer is not None + + def adjust_learning_rate(self, epoch): + """Decay the learning rate based on schedule""" + for i, param_group in enumerate(self.optimizer.param_groups): + if self.lr_schedule_type in ("cos", "cos_warmup"): + start_lr = param_group.get("start_lr", self.start_lr) + end_lr = param_group.get("end_lr", self.end_lr) + learning_rate = compute_cosine_learning_rate( + epoch, start_lr, end_lr, self.num_epochs, self.warmup_epochs) + elif self.lr_schedule_type == "step_lr": # stepwise lr schedule + learning_rate = param_group.get("start_lr", self.start_lr) + for milestone in self.lr_schedule: + learning_rate *= self.lr_decay if (epoch >= milestone) else 1. + else: + raise NotImplementedError( + f"Not supported learning rate schedule type: {self.lr_schedule_type}") + + param_group["lr"] = learning_rate + logger.info(f"==> Set lr for group {i}: {learning_rate:.10f}") + + def adjust_learning_rate_per_iter(self, epoch, iter, num_batches): + # TODO: the code for adjusting the learning rate needs cleaning up and + # refactoring. 
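+        # During the "cos_warmup" warm-up phase, the learning rate of each
+        # parameter group is ramped up linearly with the global iteration
+        # index: lr = start_lr * total_iter / (warmup_epochs * num_batches).
+        # Outside of warm-up this method leaves the per-epoch value set by
+        # adjust_learning_rate() untouched.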
+ for i, param_group in enumerate(self.optimizer.param_groups): + if self.lr_schedule_type != "cos_warmup" or epoch >= self.warmup_epochs: + continue + total_iter = epoch * num_batches + iter + start_lr = param_group.get("start_lr", self.start_lr) + learning_rate = start_lr * (float(total_iter) / (self.warmup_epochs * num_batches)) + param_group["lr"] = learning_rate + if (iter % 100) == 0: + logger.info(f"==> Set lr for group {i}: {learning_rate:.10f}") + + def find_last_epoch(self, suffix): + search_pattern = self.net_checkpoint_filename("{epoch}", suffix) + last_epoch, _ = utils.find_last_epoch(search_pattern) + logger.info(f"Load checkpoint of last epoch: {str(last_epoch)}") + return last_epoch + + def delete_checkpoint(self, epoch, suffix=""): + if utils.get_rank() == 0: + filename = pathlib.Path(self.net_checkpoint_filename(epoch, suffix)) + if filename.is_file(): + logger.info(f"Deleting {filename}") + os.remove(filename) + filename = pathlib.Path(self.optim_checkpoint_filename(epoch, suffix)) + if filename.is_file(): + logger.info(f"Deleting {filename}") + os.remove(filename) + + def save_checkpoint(self, epoch, suffix="", meters=None): + if utils.get_rank() == 0: + self.save_network(epoch, suffix, meters) + self.save_optimizer(epoch, suffix) + + def save_network(self, epoch, suffix="", meters=None): + filename = self.net_checkpoint_filename(epoch, suffix) + logger.info(f"Saving model params to: {filename}") + state = { + "epoch": epoch, + "network": self.model.state_dict(), + "meters": meters,} + if self.use_fp16: + state["amp"] = self.amp.state_dict() + torch.save(state, filename) + + def save_optimizer(self, epoch, suffix=""): + assert self.optimizer is not None + filename = self.optim_checkpoint_filename(epoch, suffix) + logger.info(f"Saving model optimizer to: {filename}") + state = { + "epoch": epoch, + "optimizer": self.optimizer.state_dict() + } + torch.save(state, filename) + + def load_checkpoint(self, epoch, suffix="", load_optimizer=True): + if epoch == -1: + epoch = self.find_last_epoch(suffix) + assert isinstance(epoch, int) + assert epoch >= 0 + self.load_network(epoch, suffix) # load network parameters + if load_optimizer: # initialize and load optimizer + self.load_optimizer(epoch, suffix) + self._epoch = epoch + + def load_network(self, epoch, suffix=""): + filename = pathlib.Path(self.net_checkpoint_filename(epoch, suffix)) + logger.info(f"Loading model params from: {filename}") + assert filename.is_file() + checkpoint = torch.load(filename, map_location="cpu") + self.model.load_state_dict(checkpoint["network"]) + if self.use_fp16: + self.amp.load_state_dict(checkpoint["amp"]) + return checkpoint["epoch"] + + def load_optimizer(self, epoch, suffix=""): + self.initialize_optimizer() + filename = pathlib.Path(self.optim_checkpoint_filename(epoch, suffix)) + logger.info(f"Loading model optimizer from: {filename}") + assert filename.is_file() + checkpoint = torch.load(filename, map_location="cpu") + self.optimizer.load_state_dict(checkpoint["optimizer"]) + return checkpoint["epoch"] + + def net_checkpoint_filename(self, epoch, suffix=""): + return str(self.exp_dir / f"model_net_checkpoint_{epoch}{suffix}.pth.tar") + + def optim_checkpoint_filename(self, epoch, suffix=""): + return str(self.exp_dir / f"model_optim_checkpoint_{epoch}{suffix}.pth.tar") + + def solve( + self, + loader_train, + distributed, + sampler_train, + loader_test=None): + + assert isinstance(distributed, bool) + + self.initialize_optimizer() + self.reset_best_model_record() + num_epochs = 
self.num_epochs + start_epoch = self._epoch + self._start_epoch = start_epoch + logger.info(f"Start training from epoch {start_epoch}") + start_time = time.time() + + for epoch in range(start_epoch, num_epochs): + self._epoch = epoch + logger.info(f"Training epoch: [{epoch+1}/{num_epochs}] ({self.exp_name})") + + if distributed: + logger.info( + f"Setting epoch={epoch} for distributed sampling.") + assert not (sampler_train is None) + sampler_train.set_epoch(epoch) + + self.adjust_learning_rate(epoch) + self.run_train_epoch(loader_train, epoch) + + self.save_checkpoint(epoch + 1) # create a checkpoint in the current epoch + is_permanent = (self.permanent > 0) and (epoch % self.permanent) == 0 + if (start_epoch != epoch) and (is_permanent is False): + # delete the checkpoint of the previous epoch + self.delete_checkpoint(epoch) + + if (loader_test is not None) and ((epoch+1) % self.eval_freq) == 0: + if not isinstance(loader_test, (list, tuple)): + loader_test = [loader_test,] + logger.info(f"Evaluate ({self.exp_name})") + test_metric_logger = [] + for i, loader_test_this in enumerate(loader_test): + test_metric_logger.append( + self.evaluate(loader_test_this, test_name=str(i))) + + self.keep_best_model_record(test_metric_logger[0], epoch) + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + logger.info(f"Training time {total_time_str}") + + def run_train_epoch(self, loader_train, epoch): + self.model.train() + self.start_of_training_epoch() + num_batches = len(loader_train) + metric_logger = utils.MetricLogger(delimiter=" ") + metric_logger.add_meter("iter/s", utils.AverageMeter(":.2f", out_val=True)) + header = f"Epoch: [{epoch+1}]" + for self._iter, mini_batch in enumerate( + metric_logger.log_every(loader_train, self.print_freq, header)): + start_time = time.time() + self.adjust_learning_rate_per_iter(epoch, self._iter, num_batches) + self._total_iter = self._iter + epoch * num_batches + self.train_step(mini_batch, metric_logger) + metric_logger["iter/s"].update(1.0 / (time.time() - start_time)) + + # gather the stats from all processes + metric_logger.synchronize_between_processes() + logger.info(f"==> Results: {str(metric_logger)}") + self.end_of_training_epoch() + + return metric_logger + + def evaluate(self, loader_test, test_name=None): + self.model.eval() + num_batches = len(loader_test) + metric_logger = utils.MetricLogger(delimiter=" ") + metric_logger.add_meter("iter/s", utils.AverageMeter(":.2f", out_val=True)) + header = "Test :" if (test_name is None) else f"Test {test_name}:" + for iter, mini_batch in enumerate( + metric_logger.log_every(loader_test, self.print_freq, header)): + start_time = time.time() + self.evaluation_step(mini_batch, metric_logger) + metric_logger["iter/s"].update(1.0 / (time.time() - start_time)) + + # gather the stats from all processes + metric_logger.synchronize_between_processes() + logger.info(f"==> Results: {str(metric_logger)}") + + return metric_logger + + def keep_best_model_record(self, test_metric_logger, epoch): + if self._best_metric_name is None: + return + if (self._best_metric_name not in test_metric_logger.meters): + raise Warning( + f"The provided metric {self._best_metric_name} for keeping the " + "best model is not computed by the evaluation routine.") + return + + val = test_metric_logger[self._best_metric_name].val + if ((self._best_metric_val is None) or + (self._best_largest and (val >= self._best_metric_val)) or + ((not self._best_largest) and (val <= 
self._best_metric_val))):
+
+            self._best_metric_val = val
+            self._best_model_meters = str(test_metric_logger)
+            self.save_checkpoint(epoch+1, suffix=".best")
+            if (self._best_epoch is not None):
+                self.delete_checkpoint(self._best_epoch+1, suffix=".best")
+            self._best_epoch = epoch
+            logger.info(
+                f"==> Best results w.r.t. {self._best_metric_name}: "
+                f"Epoch: [{self._best_epoch+1}] {self._best_model_meters}")
+
+    # FROM HERE ON THERE ARE ABSTRACT FUNCTIONS THAT MUST BE IMPLEMENTED BY THE
+    # CLASS THAT INHERITS THE Solver CLASS
+    def train_step(self, mini_batch, metric_logger):
+        """Implements a training step that includes:
+            * Forward a batch through the network(s)
+            * Compute loss(es)
+            * Backward propagation through the networks
+            * Apply optimization step(s)
+        """
+        pass
+
+    def evaluation_step(self, mini_batch, metric_logger):
+        """Implements an evaluation step that includes:
+            * Forward a batch through the network(s)
+            * Compute loss(es) or any other evaluation metrics.
+        """
+        pass
+
+
+    def end_of_training_epoch(self):
+        pass
+
+
+    def start_of_training_epoch(self):
+        pass
diff --git a/obow/utils.py b/obow/utils.py
new file mode 100644
index 0000000..4a77cf1
--- /dev/null
+++ b/obow/utils.py
@@ -0,0 +1,392 @@
+import glob
+import os
+import pathlib
+import datetime
+import logging
+import time
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.distributed
+
+
+from collections import defaultdict
+
+
+def setup_printing(is_master):
+    """
+    This function disables printing when not in the master process.
+    """
+    import builtins as __builtin__
+    builtin_print = __builtin__.print
+
+    def print(*args, **kwargs):
+        force = kwargs.pop('force', False)
+        if is_master or force:
+            builtin_print(*args, **kwargs)
+
+    __builtin__.print = print
+
+
+def is_dist_avail_and_initialized():
+    if not torch.distributed.is_available():
+        return False
+    if not torch.distributed.is_initialized():
+        return False
+    return True
+
+
+def get_world_size():
+    if not is_dist_avail_and_initialized():
+        return 1
+    return torch.distributed.get_world_size()
+
+
+def get_rank():
+    if not is_dist_avail_and_initialized():
+        return 0
+    return torch.distributed.get_rank()
+
+
+def is_main_process():
+    return get_rank() == 0
+
+
+class setup_dist_logger(object): # very hacky...
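+    # Wraps a standard logging.Logger so that .info() messages are emitted
+    # only by the main (rank-0) process; all other workers silently drop them.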
+ def __init__(self, logger): + self.logger = logger + self.is_main_process = is_main_process() + + def info(self, msg, *args, **kwargs): + if self.is_main_process: + self.logger.info(msg, *args, **kwargs) + + +def setup_logger(dst_dir, name): + logger = logging.getLogger(name) + + strHandler = logging.StreamHandler() + formatter = logging.Formatter( + '%(asctime)s - %(name)-8s - %(levelname)-6s - %(message)s') + strHandler.setFormatter(formatter) + logger.addHandler(strHandler) + logger.setLevel(logging.INFO) + + log_dir = dst_dir / "logs" + os.makedirs(log_dir, exist_ok=True) + now_str = datetime.datetime.now().__str__().replace(' ','_') + now_str = now_str.replace(' ','_').replace('-','').replace(':','') + logger.addHandler(logging.FileHandler(log_dir / f'LOG_INFO_{now_str}.txt')) + + return logger + + +logger = setup_dist_logger(logging.getLogger(__name__)) + + +@torch.no_grad() +def reduce_all(tensor): + if get_world_size() > 1: + torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM) + return tensor + + +@torch.no_grad() +def concat_all_gather(tensor): + if get_world_size() > 1: + tensors_gather = [torch.ones_like(tensor) + for _ in range(torch.distributed.get_world_size())] + torch.distributed.all_gather(tensors_gather, tensor, async_op=False) + return torch.cat(tensors_gather, dim=0) + else: + return tensor + + +@torch.no_grad() +def top1accuracy(output, target): + pred = output.max(dim=1)[1] + pred = pred.view(-1) + target = target.view(-1) + accuracy = 100 * pred.eq(target).float().mean() + return accuracy + + +@torch.no_grad() +def sanity_check_for_distributed_training(model, buffers_only_bow_extr=True): + """ Verifies that all nodes have the same copy of params & bow buffers. """ + if get_world_size() > 1: + world_size = get_world_size() + rank = get_rank() + is_close_all = True + list_of_failed_states = [] + torch.distributed.barrier() + for name, state in model.named_parameters(): + state = state.data.detach() + state_src = state.clone() + torch.distributed.barrier() + torch.distributed.broadcast(state_src, src=0) + torch.distributed.barrier() + is_close = torch.allclose(state, state_src, rtol=1e-05, atol=1e-08) + is_close_tensor = torch.tensor( + [is_close], dtype=torch.float64, device='cuda') + torch.distributed.barrier() + is_close_all_nodes = concat_all_gather(is_close_tensor) + is_close_all_nodes = [v >= 0.5 for v in is_close_all_nodes.tolist()] + is_close_all_nodes_reduce = all(is_close_all_nodes) + is_close_all &= is_close_all_nodes_reduce + + status = "PASSED" if is_close_all_nodes_reduce else "FAILED" + + logger.info(f"====> Check {name}: [{status}]") + if not is_close_all_nodes_reduce: + logger.info(f"======> Failed nodes: [{is_close_all_nodes}]") + list_of_failed_states.append(name) + + for name, state in model.named_buffers(): + if buffers_only_bow_extr and name.find("module.bow_extractor") == -1: + continue + state = state.data.detach().float() + state_src = state.clone() + torch.distributed.barrier() + torch.distributed.broadcast(state_src, src=0) + torch.distributed.barrier() + is_close = torch.allclose(state, state_src, rtol=1e-05, atol=1e-08) + is_close_tensor = torch.tensor( + [is_close], dtype=torch.float64, device='cuda') + torch.distributed.barrier() + is_close_all_nodes = concat_all_gather(is_close_tensor) + is_close_all_nodes = [v >= 0.5 for v in is_close_all_nodes.tolist()] + is_close_all_nodes_reduce = all(is_close_all_nodes) + is_close_all &= is_close_all_nodes_reduce + + status = "PASSED" if is_close_all_nodes_reduce else "FAILED" + 
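+            # Log the per-buffer verdict; a FAILED status means this buffer
+            # (by default only the bow_extractor buffers are checked) is no
+            # longer identical across workers.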
+ logger.info(f"====> Check {name}: [{status}]") + if not is_close_all_nodes_reduce: + logger.info(f"======> Failed nodes: [{is_close_all_nodes}]") + list_of_failed_states.append(name) + + status = "ALL PASSED" if is_close_all else "FAILED" + logger.info(f"==> Sanity checked [{status}]") + if not is_close_all: + logger.info(f"====> List of failed states:\n{list_of_failed_states}") + + +def accuracy(output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + with torch.no_grad(): + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +class AverageMeter(object): + """Computes and stores the average and current value""" + def __init__(self, fmt=':.4f', out_val=False): + self.fmt = fmt + self.out_val = out_val + self.reset() + + def reset(self): + self.val = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + + @property + def avg(self): + if self.count > 0: + return self.sum / self.count + else: + return 0 + + def synchronize_between_processes(self): + if not is_dist_avail_and_initialized(): + return + values = torch.tensor( + [self.count, self.sum], dtype=torch.float64, device='cuda') + torch.distributed.barrier() + torch.distributed.all_reduce(values) + values = values.tolist() + self.count = int(values[0]) + self.sum = values[1] + + def __str__(self): + if self.out_val: + fmtstr = '{avg' + self.fmt + '} ({val' + self.fmt + '})' + return fmtstr.format(avg=self.avg, val=self.val) + else: + fmtstr = '{avg' + self.fmt + '}' + return fmtstr.format(avg=self.avg) + + +class MetricLogger(object): + def __init__(self, delimiter="\t", prefix=""): + self.meters = defaultdict(AverageMeter) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getitem__(self, attr): + if not (attr in self.meters): + self.meters[attr] = AverageMeter() + return self.meters[attr] + + def __str__(self): + meters_str = [] + for key, meter in self.meters.items(): + meters_str.append("{}: {}".format(key, str(meter))) + return self.delimiter.join(meters_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None, sync=True): + i = 0 + if not header: + header = '' + start_time = time.time() + end = time.time() + iter_time = AverageMeter(out_val=True) + data_time = AverageMeter(out_val=True) + space_fmt = ':' + str(len(str(len(iterable)))) + 'd' + + log_msg_fmt = self.delimiter.join([ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}']) + if torch.cuda.is_available(): + log_msg_cuda_fmt = 'max mem: {memory:.0f}' + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0: + eta_seconds = iter_time.avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + log_msg = log_msg_fmt.format( + i, 
len(iterable), + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time)) + if torch.cuda.is_available(): + log_msg_cuda = log_msg_cuda_fmt.format( + memory=torch.cuda.max_memory_allocated() / MB) + log_msg = self.delimiter.join([log_msg, log_msg_cuda]) + logger.info(log_msg) + + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + logger.info(f'{header} Total time: {total_time}') + + + +def global_pooling(x, type): + assert x.dim() == 4 + if type == 'max': + return F.max_pool2d(x, (x.size(2), x.size(3))) + elif type == 'avg': + return F.avg_pool2d(x, (x.size(2), x.size(3))) + else: + raise ValueError( + f"Unknown pooling type '{type}'. Supported types: ('avg', 'max').") + + +class GlobalPooling(nn.Module): + def __init__(self, type): + super(GlobalPooling, self).__init__() + assert type in ("avg", "max") + self.type = type + + def forward(self, x): + return global_pooling(x, self.type) + + def extra_repr(self): + s = f'type={self.type}' + return s + + +class L2Normalize(nn.Module): + def __init__(self, dim): + super(L2Normalize, self).__init__() + self.dim = dim + + def forward(self, x): + return F.normalize(x, p=2, dim=self.dim) + + +def convert_from_5d_to_4d(tensor_5d): + _, _, channels, height, width = tensor_5d.size() + return tensor_5d.view(-1, channels, height, width) + + +def add_dimension(tensor, dim_size): + assert((tensor.size(0) % dim_size) == 0) + return tensor.view( + [dim_size, tensor.size(0) // dim_size,] + list(tensor.size()[1:])) + + +def find_last_epoch(search_pattern): + print(f"Search the last checkpoint with pattern {str(search_pattern)}") + + search_pattern = search_pattern.format(epoch="*") + + all_files = glob.glob(search_pattern) + if len(all_files) == 0: + raise ValueError(f"{search_pattern}: no such file.") + + substrings = search_pattern.split("*") + assert(len(substrings) == 2) + start, end = substrings + all_epochs = [fname.replace(start,"").replace(end,"") for fname in all_files] + all_epochs = [int(epoch) for epoch in all_epochs if epoch.isdigit()] + assert(len(all_epochs) > 0) + all_epochs = sorted(all_epochs) + last_epoch = int(all_epochs[-1]) + + checkpoint_filename = search_pattern.replace("*", str(last_epoch)) + print(f"Last epoch: {str(last_epoch)} ({checkpoint_filename})") + + checkpoint_filename = pathlib.Path(checkpoint_filename) + assert checkpoint_filename.is_file() + + return last_epoch, checkpoint_filename + + +def load_network_params(network, filename, strict=True): + if isinstance(filename, str): + filename = pathlib.Path(filename) + + print(f"[Rank {get_rank()}: load network params from: {filename}") + assert filename.is_file() + checkpoint = torch.load(filename, map_location="cpu") + return network.load_state_dict(checkpoint["network"], strict=strict) diff --git a/obow/visualization.py b/obow/visualization.py new file mode 100644 index 0000000..f6e49f5 --- /dev/null +++ b/obow/visualization.py @@ -0,0 +1,202 @@ +import logging +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision +import obow.utils as utils +import numpy as np +import datetime +import time + +from tqdm import tqdm + + +logger = utils.setup_dist_logger(logging.getLogger(__name__)) + + +def extract_visual_words(model, dataloader, dtype='uint32'): + model.eval() + all_vword_ids, all_vword_mag, num_words = [], [], [] + count = 0 + for i, batch in enumerate(tqdm(dataloader)): + with torch.no_grad(): + img = batch[0] if isinstance(batch, 
(list, tuple)) else batch + img = img[0] if isinstance(img, (list, tuple)) else img + img = img.cuda() + assert img.dim()==4 + + # Add horizontal flip: + img = torch.stack([img, torch.flip(img, dims=(3,))], dim=1) + assert img.dim() == 5 and img.size(1) == 2 + img = img.view(2 * img.size(0), 3, img.size(3), img.size(4)) + + features = model.feature_extractor_teacher(img, model._bow_levels) + _, vword_codes = model.bow_extractor(features) + assert isinstance(vword_codes, (list, tuple)) + num_levels = len(vword_codes) + batch_size = img.size(0) // 2 + + if i == 0: + max_count = len(dataloader) * batch_size + logger.info(f'image size: {img.size()}') + logger.info(f'max count: {max_count}') + logger.info(f'batch size: {batch_size}') + for level in range(num_levels): + _, num_words_this, height, width = vword_codes[level].size() + dshape = [max_count * 2, height, width] + all_vword_ids.append(np.zeros(dshape, dtype=dtype)) + all_vword_mag.append(np.zeros(dshape, dtype='float32')) + num_words.append(num_words_this) + logger.info( + f'Level {level}: shape: {dshape[1:]}, ' + f'num_words: {num_words[level]}') + + for level in range(num_levels): + vwords_mag, vwords_ids = vword_codes[level].max(dim=1) + assert vwords_mag.dim() == 3 + assert vwords_ids.dim() == 3 + vwords_ids = vwords_ids.cpu().numpy() + vwords_mag = vwords_mag.cpu().numpy().astype('float32') + all_vword_ids[level][count:(count + batch_size*2)] = vwords_ids.astype(dtype) + all_vword_mag[level][count:(count + batch_size*2)] = vwords_mag + + count += batch_size*2 + + for level in range(num_levels): + all_vword_ids[level] = all_vword_ids[level][:count] + all_vword_mag[level] = all_vword_mag[level][:count] + logger.info(f'Shape of extracted dataset: {all_vword_ids[level].shape}') + + return all_vword_ids, all_vword_mag, num_words + + +def visualize_visual_words( + num_words, + num_patches, + patch_size, + dataset_images, + all_vword_ids, + all_vword_mag, + words_order, + dst_dir, + rank=0, + offset_k=0, + mean_pixel=[0.485, 0.456, 0.406], + std_pixel=[0.229, 0.224, 0.225], + skip_border=True): + + assert all_vword_mag.shape == all_vword_ids.shape + assert (len(dataset_images) * 2) == all_vword_ids.shape[0] + + mean_pixel = torch.Tensor(mean_pixel).view(1, 3, 1, 1) + std_pixel = torch.Tensor(std_pixel).view(1, 3, 1, 1) + + num_images, height, width = all_vword_ids.shape + num_images = num_images // 2 + all_vword_ids = all_vword_ids.reshape(num_images, 2, height, width) + all_vword_mag = all_vword_mag.reshape(num_images, 2, height, width) + all_vword_mag_flat = all_vword_mag.reshape(-1) + all_vword_ids_flat = all_vword_ids.reshape(-1) + + num_locs = height * width + num_locs_flip = 2 * height * width + + assert height == width + size_out = (height+2) if skip_border else height + + def parse_index(index): + img = index // num_locs_flip + flip_loc = index % num_locs_flip + flip = flip_loc // num_locs + loc = flip_loc % num_locs + y = loc // width + x = loc % width + return img, flip, y, x + + def extract_patch(image, flip, y, x): + assert image.dim() == 3 + assert image.size(0) == 3 + assert image.size(1) == image.size(2) + assert (image.size(2) % size_out) == 0 + size_in = image.size(2) + stride = size_in // size_out + offset = stride // 2 # offset due to image padding + halfp = patch_size // 2 + + if skip_border: + x = x + 1 + y = y + 1 + + if flip == 1: + image = torch.flip(image, dims=(2,)) + + image_pad = F.pad(image, (halfp, halfp, halfp, halfp), 'constant', 0.0) + + xc = x * stride + offset + halfp + yc = y * stride + offset + halfp + 
x1 = xc - halfp + y1 = yc - halfp + x2 = xc + halfp + y2 = yc + halfp + assert x1 > 0 + assert y1 > 0 + assert y2 < image_pad.size(1) + assert x2 < image_pad.size(2) + + #print(x1, x2, y1, y2) + return image_pad[:, y1:y2, x1:x2] + + num_words_order = words_order.shape[0] + iter_start_time = time.time() + total_time = 0 + for k in range(num_words_order): + visual_word_id = words_order[k] + indices_k = np.nonzero(all_vword_ids_flat == visual_word_id)[0] + if indices_k.shape[0] == 0: + print(f"==> The visual word with id {visual_word_id} is empty.") + else: + vword_mag_k = all_vword_mag_flat[indices_k] + order = np.argsort(-vword_mag_k) + vword_mag_k = vword_mag_k[order] + indices_k = indices_k[order] + + if order.shape[0] >= 2: + assert vword_mag_k[0] >= vword_mag_k[1] + + count_patches = 0 + count = 0 + used_image = np.zeros(num_images, dtype='uint8') + patches_k = torch.zeros(num_patches, 3, patch_size, patch_size) + while (count_patches < num_patches) and (count < order.shape[0]): + index = indices_k[count] + img, flip, y, x = parse_index(index) + assert all_vword_mag_flat[index] == vword_mag_k[count] + if used_image[img] == 0: + used_image[img] = 1 + assert all_vword_ids[img,flip,y,x] == visual_word_id + assert all_vword_mag[img,flip,y,x] == vword_mag_k[count] + image = dataset_images[img][0] # get image + patch_this = extract_patch(image, flip, y, x) # extract patch + patches_k[count_patches].copy_(patch_this) # copy patch. + count_patches += 1 + count += 1 + + #image_normalized = (image - mean) / std + patches_k_unnormalized = patches_k.mul(std_pixel).add(mean_pixel) + patches_k_unnormalized = patches_k_unnormalized.clamp_(0.0, 1.0) + patches_k_vis = torchvision.utils.make_grid( + patches_k_unnormalized, nrow=8, padding=5, normalize=False) + + dst_file = dst_dir + f'/freq_{offset_k+k}_visual_word_{visual_word_id}.jpg' + torchvision.utils.save_image(patches_k_vis, dst_file) + + iter_time = time.time() - iter_start_time + total_time += iter_time + if (k % 20) == 0: + avg_time = total_time / (k+1) + eta_secs = avg_time * (num_words_order - k) + elaphsed_time_string = str(datetime.timedelta(seconds=int(total_time))) + eta_string = str(datetime.timedelta(seconds=int(eta_secs))) + print(f"Iteration [{k}/{num_words_order}][rank={rank}]: elapsed_time = {elaphsed_time_string}, eta = {eta_string}.") + + iter_start_time = time.time() diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..b6eea2e --- /dev/null +++ b/setup.py @@ -0,0 +1,15 @@ +from setuptools import setup +from setuptools import find_packages + +setup( + name='OBoW', + version='0.0.1', + description='OBoW', + author='Spyros Gidaris', + packages=find_packages(), + install_requires=["tqdm", + "numpy", + "torch", + "torchvision", + "Pillow"] + ) diff --git a/utils/configs/benchmark_tasks/image_classification/voc07/resnet50_supervised_extract_gap_features.yaml b/utils/configs/benchmark_tasks/image_classification/voc07/resnet50_supervised_extract_gap_features.yaml new file mode 100644 index 0000000..a04ad4e --- /dev/null +++ b/utils/configs/benchmark_tasks/image_classification/voc07/resnet50_supervised_extract_gap_features.yaml @@ -0,0 +1,25 @@ +DATASET: voc2007 +NUM_DEVICES: 8 +LOGGER_FREQUENCY: 10 +MODEL: + NUM_CLASSES: 20 + MODEL_NAME: resnet_supervised_finetune_linear + DEPTH: 50 + ALLOW_INPLACE_SUM: True + MEMONGER: True + EXTRACT_FEATURES_ONLY: True + EXTRACT_BLOBS: [pool5, pool5_bn] +TRAIN: + DATA_TYPE: train + BATCH_SIZE: 256 + GLOBAL_RESIZE_VALUE: 224 + DATA_TRANSFORMS: [scale, global_resize] + DATA_PROCESSING: 
[color_normalization] +TEST: + # for VOC2007, we train on the trainval split and evaluate on the test set. + DATA_TYPE: test + BATCH_SIZE: 256 + # IN1k RN50 supervised + PARAMS_FILE: https://dl.fbaipublicfiles.com/fair_self_supervision_benchmark/models/resnet50_in1k_supervised.pkl + # 386 init places205 supervised + PARAMS_FILE: https://dl.fbaipublicfiles.com/fair_self_supervision_benchmark/models/resnet50_places205_supervised.pkl diff --git a/utils/convert_pytorch_to_caffe2.py b/utils/convert_pytorch_to_caffe2.py new file mode 100644 index 0000000..006bfa7 --- /dev/null +++ b/utils/convert_pytorch_to_caffe2.py @@ -0,0 +1,154 @@ +""" +Converts a torchvision model to caffe2 format. +""" + +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division +from __future__ import unicode_literals + +import argparse +import errno +import os +import pickle +import re +import six +import sys +import torch +from collections import OrderedDict +from uuid import uuid4 + + +def _rename_basic_resnet_weights(layer_keys): + # Make Caffe2 compatible - architecture: + # https://github.com/facebookresearch/fair_self_supervision_benchmark/blob/master/self_supervision_benchmark/modeling/supervised/resnet_supervised_finetune_full.py + layer_keys = [k.replace(".downsample.1.", ".branch1_bn.") for k in layer_keys] + layer_keys = [k.replace(".downsample.0.", ".branch1.") for k in layer_keys] + layer_keys = [k.replace(".conv1.", ".branch2a.") for k in layer_keys] + layer_keys = [k.replace(".bn1.", ".branch2a_bn.") for k in layer_keys] + layer_keys = [k.replace(".conv2.", ".branch2b.") for k in layer_keys] + layer_keys = [k.replace(".bn2.", ".branch2b_bn.") for k in layer_keys] + layer_keys = [k.replace(".conv3.", ".branch2c.") for k in layer_keys] + layer_keys = [k.replace(".bn3.", ".branch2c_bn.") for k in layer_keys] + layer_keys = [k.replace("layer1.", "res2.") for k in layer_keys] + layer_keys = [k.replace("layer2.", "res3.") for k in layer_keys] + layer_keys = [k.replace("layer3.", "res4.") for k in layer_keys] + layer_keys = [k.replace("layer4.", "res5.") for k in layer_keys] + + layer_keys = [k.replace("bn1.", "conv1_bn.") for k in layer_keys] + layer_keys = [k.replace("conv1_", "res.conv1_") for k in layer_keys] + + layer_keys = [k.replace( "_bn.weight", "_bn.scale") for k in layer_keys] + layer_keys = [k.replace("_bn.scale", "_bn.s") for k in layer_keys] + layer_keys = [k.replace("_bn.running_mean", "_bn.rm") for k in layer_keys] + layer_keys = [k.replace("_bn.running_var", "_bn.riv") for k in layer_keys] + layer_keys = [k.replace(".weight", ".w") for k in layer_keys] + layer_keys = [k.replace(".bias", ".b") for k in layer_keys] + + layer_keys = [k.replace("_bn", ".bn") for k in layer_keys] + layer_keys = [k.replace(".", "_") for k in layer_keys] + + return layer_keys + + +def _rename_weights_for_resnet(weights): + original_keys = sorted(weights.keys()) + layer_keys = sorted(weights.keys()) + + layer_keys = _rename_basic_resnet_weights(layer_keys) + key_map = {k: v for k, v in zip(original_keys, layer_keys)} + + print("Remapping PyTorch weights") + max_pth_key_size = max([len(k) for k in original_keys]) + + new_weights = OrderedDict() + for k in original_keys: + v = weights[k] + if 'fc' in k: + continue + w = v.data.cpu().numpy() + print("PyTorch name: {: <{}} mapped name: {}".format( + k, max_pth_key_size, key_map[k])) + new_weights[key_map[k]] = w + print('Number of blobs: {}'.format(len(new_weights))) + return new_weights + + +def 
_load_pytorch_weights(file_path): + checkpoint = torch.load(file_path) + if "state_dict" in checkpoint: + weights = checkpoint["state_dict"] + elif "network" in checkpoint: + weights = checkpoint["network"] + else: + weights = checkpoint + return weights + + +def convert_rgb2bgr(state_dict): + key = 'conv1.weight' + weight = state_dict[key] + assert (weight.shape == (64, 3, 7, 7)) + weight_np = weight.detach().numpy() + weight_np = weight_np[:, ::-1, :, :].copy() + weight = torch.from_numpy(weight_np) + state_dict[key] = weight + print(f'BGR ===> RGB for {key}.') + return state_dict + + +def save_object(obj, file_name, pickle_format=2): + file_name = os.path.abspath(file_name) + # Avoid filesystem race conditions (particularly on network filesystems) + # by saving to a random tmp file on the same filesystem, and then + # atomically rename to the target filename. + tmp_file_name = file_name + ".tmp." + uuid4().hex + try: + with open(tmp_file_name, 'wb') as f: + pickle.dump(obj, f, pickle_format) + f.flush() # make sure it's written to disk + os.fsync(f.fileno()) + os.rename(tmp_file_name, file_name) + print('Saved: {}'.format(file_name)) + finally: + # Clean up the temp file on failure. Rather than using os.path.exists(), + # which can be unreliable on network filesystems, attempt to delete and + # ignore os errors. + try: + os.remove(tmp_file_name) + except EnvironmentError as e: # parent class of IOError, OSError + if getattr(e, 'errno', None) != errno.ENOENT: # We expect ENOENT + print("Could not delete temp file %r", + tmp_file_name, exc_info=True) + # pass through since we don't want the job to crash + + +def main(): + parser = argparse.ArgumentParser(description="Convert PyTorch model to C2") + parser.add_argument('--pth_model', type=str, default=None, + help='Path to PyTorch RN-50 model') + parser.add_argument('--output_model', type=str, default=None, + help='Path to save C2 RN-50 model') + parser.add_argument('--arch', type=str, default="R-50", + help='R-50 | R-101 | R-152') + parser.add_argument('--rgb2bgr', dest='rgb2bgr', default=False, + help='Revert bgr order to rgb order') + args = parser.parse_args() + + # load the pytorch model first + state_dict = _load_pytorch_weights(args.pth_model) + + # depending on the image reading library, we convert the weights to be + # compatible order. The default order of caffe2 weights is BGR (openCV). + if args.rgb2bgr: + state_dict = convert_rgb2bgr(state_dict) + + blobs = _rename_weights_for_resnet(state_dict) + blobs = dict(blobs=blobs) + print('Saving converted weights to: {}'.format(args.output_model)) + save_object(blobs, args.output_model) + print('Done!!') + + +if __name__ == '__main__': + main()
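+
+# Example invocation (illustrative paths; substitute your own checkpoint and
+# output locations):
+#   python utils/convert_pytorch_to_caffe2.py \
+#       --pth_model ./experiments/<exp_name>/model_net_checkpoint_<epoch>.pth.tar \
+#       --output_model ./experiments/<exp_name>/model_caffe2.pkl \
+#       --arch R-50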