diff --git a/README.md b/README.md
index 1447803f..c264a0b8 100644
--- a/README.md
+++ b/README.md
@@ -27,7 +27,7 @@ and the configurations reuse the hyperparameters such as number of epochs used i
 Executable code can be found in [examples/](examples/) such as
 - [Image classification](examples/image_classification.py): ImageNet (ILSVRC 2012), CIFAR-10, CIFAR-100, etc
 - [Object detection](examples/object_detection.py): COCO 2017, etc
-- [Semantic segmentation](examples/semantic_segmentation.py): COCO 2017, etc
+- [Semantic segmentation](examples/semantic_segmentation.py): COCO 2017, PASCAL VOC, etc
 
 ## Google Colab Examples
 ### CIFAR-10 and CIFAR-100
diff --git a/configs/sample/cifar10/ce/densenet_bc_k12_depth100-final_run.yaml b/configs/sample/cifar10/ce/densenet_bc_k12_depth100-final_run.yaml
index d34f61c2..abfb8240 100644
--- a/configs/sample/cifar10/ce/densenet_bc_k12_depth100-final_run.yaml
+++ b/configs/sample/cifar10/ce/densenet_bc_k12_depth100-final_run.yaml
@@ -1,5 +1,5 @@
 datasets:
-  ilsvrc2012:
+  cifar10:
     name: &dataset_name 'cifar10'
     type: 'CIFAR10'
     root: &root_dir !join ['./resource/dataset/', *dataset_name]
diff --git a/configs/sample/cifar10/ce/densenet_bc_k12_depth100-hyperparameter_tuning.yaml b/configs/sample/cifar10/ce/densenet_bc_k12_depth100-hyperparameter_tuning.yaml
index 95e59eea..61311d9e 100644
--- a/configs/sample/cifar10/ce/densenet_bc_k12_depth100-hyperparameter_tuning.yaml
+++ b/configs/sample/cifar10/ce/densenet_bc_k12_depth100-hyperparameter_tuning.yaml
@@ -1,5 +1,5 @@
 datasets:
-  ilsvrc2012:
+  cifar10:
     name: &dataset_name 'cifar10'
     type: 'CIFAR10'
     root: &root_dir !join ['./resource/dataset/', *dataset_name]
diff --git a/configs/sample/cifar10/ce/densenet_bc_k24_depth250-final_run.yaml b/configs/sample/cifar10/ce/densenet_bc_k24_depth250-final_run.yaml
index 006a7eee..ee705f41 100644
--- a/configs/sample/cifar10/ce/densenet_bc_k24_depth250-final_run.yaml
+++ b/configs/sample/cifar10/ce/densenet_bc_k24_depth250-final_run.yaml
@@ -1,5 +1,5 @@
 datasets:
-  ilsvrc2012:
+  cifar10:
     name: &dataset_name 'cifar10'
     type: 'CIFAR10'
     root: &root_dir !join ['./resource/dataset/', *dataset_name]
diff --git a/configs/sample/cifar10/ce/densenet_bc_k24_depth250-hyperparameter_tuning.yaml b/configs/sample/cifar10/ce/densenet_bc_k24_depth250-hyperparameter_tuning.yaml
index b4d24ff9..9ce2b6be 100644
--- a/configs/sample/cifar10/ce/densenet_bc_k24_depth250-hyperparameter_tuning.yaml
+++ b/configs/sample/cifar10/ce/densenet_bc_k24_depth250-hyperparameter_tuning.yaml
@@ -1,5 +1,5 @@
 datasets:
-  ilsvrc2012:
+  cifar10:
     name: &dataset_name 'cifar10'
     type: 'CIFAR10'
     root: &root_dir !join ['./resource/dataset/', *dataset_name]
diff --git a/configs/sample/cifar10/ce/densenet_bc_k40_depth190-final_run.yaml b/configs/sample/cifar10/ce/densenet_bc_k40_depth190-final_run.yaml
index 82dfc7e6..e07e7141 100644
--- a/configs/sample/cifar10/ce/densenet_bc_k40_depth190-final_run.yaml
+++ b/configs/sample/cifar10/ce/densenet_bc_k40_depth190-final_run.yaml
@@ -1,5 +1,5 @@
 datasets:
-  ilsvrc2012:
+  cifar10:
     name: &dataset_name 'cifar10'
     type: 'CIFAR10'
     root: &root_dir !join ['./resource/dataset/', *dataset_name]
diff --git a/configs/sample/cifar10/ce/densenet_bc_k40_depth190-hyperparameter_tuning.yaml b/configs/sample/cifar10/ce/densenet_bc_k40_depth190-hyperparameter_tuning.yaml
index 1e3f3a5d..97324ed7 100644
--- a/configs/sample/cifar10/ce/densenet_bc_k40_depth190-hyperparameter_tuning.yaml
+++ b/configs/sample/cifar10/ce/densenet_bc_k40_depth190-hyperparameter_tuning.yaml
@@ -1,5 +1,5 @@
 datasets:
-  ilsvrc2012:
+  cifar10:
     name: &dataset_name 'cifar10'
     type: 'CIFAR10'
     root: &root_dir !join ['./resource/dataset/', *dataset_name]
diff --git a/configs/sample/cifar10/ce/resnet110-final_run.yaml b/configs/sample/cifar10/ce/resnet110-final_run.yaml
index de914f03..00a00bcc 100644
--- a/configs/sample/cifar10/ce/resnet110-final_run.yaml
+++ b/configs/sample/cifar10/ce/resnet110-final_run.yaml
@@ -1,5 +1,5 @@
 datasets:
-  ilsvrc2012:
+  cifar10:
     name: &dataset_name 'cifar10'
     type: 'CIFAR10'
     root: &root_dir !join ['./resource/dataset/', *dataset_name]
diff --git a/configs/sample/cifar10/ce/resnet110-hyperparameter_tuning.yaml b/configs/sample/cifar10/ce/resnet110-hyperparameter_tuning.yaml
index 750c8f3f..b431c8d1 100644
--- a/configs/sample/cifar10/ce/resnet110-hyperparameter_tuning.yaml
+++ b/configs/sample/cifar10/ce/resnet110-hyperparameter_tuning.yaml
@@ -1,5 +1,5 @@
 datasets:
-  ilsvrc2012:
+  cifar10:
     name: &dataset_name 'cifar10'
     type: 'CIFAR10'
     root: &root_dir !join ['./resource/dataset/', *dataset_name]
diff --git a/configs/sample/cifar10/ce/resnet1202-final_run.yaml b/configs/sample/cifar10/ce/resnet1202-final_run.yaml
index c5f0a339..3bc6eb7a 100644
--- a/configs/sample/cifar10/ce/resnet1202-final_run.yaml
+++ b/configs/sample/cifar10/ce/resnet1202-final_run.yaml
@@ -1,5 +1,5 @@
 datasets:
-  ilsvrc2012:
+  cifar10:
     name: &dataset_name 'cifar10'
     type: 'CIFAR10'
     root: &root_dir !join ['./resource/dataset/', *dataset_name]
diff --git a/configs/sample/cifar10/ce/resnet1202-hyperparameter_tuning.yaml b/configs/sample/cifar10/ce/resnet1202-hyperparameter_tuning.yaml
index 9e2e73c6..88acc144 100644
--- a/configs/sample/cifar10/ce/resnet1202-hyperparameter_tuning.yaml
+++ b/configs/sample/cifar10/ce/resnet1202-hyperparameter_tuning.yaml
@@ -1,5 +1,5 @@
 datasets:
-  ilsvrc2012:
+  cifar10:
     name: &dataset_name 'cifar10'
     type: 'CIFAR10'
     root: &root_dir !join ['./resource/dataset/', *dataset_name]
diff --git a/configs/sample/cifar10/ce/resnet20-final_run.yaml b/configs/sample/cifar10/ce/resnet20-final_run.yaml
index e5ea5bfc..3831348b 100644
--- a/configs/sample/cifar10/ce/resnet20-final_run.yaml
+++ b/configs/sample/cifar10/ce/resnet20-final_run.yaml
@@ -1,5 +1,5 @@
 datasets:
-  ilsvrc2012:
+  cifar10:
     name: &dataset_name 'cifar10'
     type: 'CIFAR10'
     root: &root_dir !join ['./resource/dataset/', *dataset_name]
diff --git a/configs/sample/cifar10/ce/resnet20-hyperparameter_tuning.yaml b/configs/sample/cifar10/ce/resnet20-hyperparameter_tuning.yaml
index 0278dc1b..16277d8f 100644
--- a/configs/sample/cifar10/ce/resnet20-hyperparameter_tuning.yaml
+++ b/configs/sample/cifar10/ce/resnet20-hyperparameter_tuning.yaml
@@ -1,5 +1,5 @@
 datasets:
-  ilsvrc2012:
+  cifar10:
     name: &dataset_name 'cifar10'
     type: 'CIFAR10'
     root: &root_dir !join ['./resource/dataset/', *dataset_name]
diff --git a/configs/sample/cifar10/ce/resnet32-final_run.yaml b/configs/sample/cifar10/ce/resnet32-final_run.yaml
index b55e046b..59c8c9d2 100644
--- a/configs/sample/cifar10/ce/resnet32-final_run.yaml
+++ b/configs/sample/cifar10/ce/resnet32-final_run.yaml
@@ -1,5 +1,5 @@
 datasets:
-  ilsvrc2012:
+  cifar10:
     name: &dataset_name 'cifar10'
     type: 'CIFAR10'
     root: &root_dir !join ['./resource/dataset/', *dataset_name]
diff --git a/configs/sample/cifar10/ce/resnet32-hyperparameter_tuning.yaml b/configs/sample/cifar10/ce/resnet32-hyperparameter_tuning.yaml
index a1703a4a..977a74ad 100644
--- a/configs/sample/cifar10/ce/resnet32-hyperparameter_tuning.yaml
+++ b/configs/sample/cifar10/ce/resnet32-hyperparameter_tuning.yaml
@@ -1,5 +1,5 @@
 datasets:
-  ilsvrc2012:
+  cifar10:
     name: &dataset_name 'cifar10'
     type: 'CIFAR10'
     root: &root_dir !join ['./resource/dataset/', *dataset_name]
diff --git a/configs/sample/cifar10/ce/resnet44-final_run.yaml b/configs/sample/cifar10/ce/resnet44-final_run.yaml
index 46574680..2bd14db9 100644
--- a/configs/sample/cifar10/ce/resnet44-final_run.yaml
+++ b/configs/sample/cifar10/ce/resnet44-final_run.yaml
@@ -1,5 +1,5 @@
 datasets:
-  ilsvrc2012:
+  cifar10:
     name: &dataset_name 'cifar10'
     type: 'CIFAR10'
     root: &root_dir !join ['./resource/dataset/', *dataset_name]
diff --git a/configs/sample/cifar10/ce/resnet44-hyperparameter_tuning.yaml b/configs/sample/cifar10/ce/resnet44-hyperparameter_tuning.yaml
index 24e66c89..0cb64f40 100644
--- a/configs/sample/cifar10/ce/resnet44-hyperparameter_tuning.yaml
+++ b/configs/sample/cifar10/ce/resnet44-hyperparameter_tuning.yaml
@@ -1,5 +1,5 @@
 datasets:
-  ilsvrc2012:
+  cifar10:
     name: &dataset_name 'cifar10'
     type: 'CIFAR10'
     root: &root_dir !join ['./resource/dataset/', *dataset_name]
diff --git a/configs/sample/cifar10/ce/resnet56-final_run.yaml b/configs/sample/cifar10/ce/resnet56-final_run.yaml
index 2aad3f3e..664be6e9 100644
--- a/configs/sample/cifar10/ce/resnet56-final_run.yaml
+++ b/configs/sample/cifar10/ce/resnet56-final_run.yaml
@@ -1,5 +1,5 @@
 datasets:
-  ilsvrc2012:
+  cifar10:
     name: &dataset_name 'cifar10'
     type: 'CIFAR10'
     root: &root_dir !join ['./resource/dataset/', *dataset_name]
diff --git a/configs/sample/cifar10/ce/resnet56-hyperparameter_tuning.yaml b/configs/sample/cifar10/ce/resnet56-hyperparameter_tuning.yaml
index 50c38ae7..b57632d6 100644
--- a/configs/sample/cifar10/ce/resnet56-hyperparameter_tuning.yaml
+++ b/configs/sample/cifar10/ce/resnet56-hyperparameter_tuning.yaml
@@ -1,5 +1,5 @@
 datasets:
-  ilsvrc2012:
+  cifar10:
     name: &dataset_name 'cifar10'
     type: 'CIFAR10'
     root: &root_dir !join ['./resource/dataset/', *dataset_name]
diff --git a/configs/sample/cifar10/ce/wide_resnet16_8-final_run.yaml b/configs/sample/cifar10/ce/wide_resnet16_8-final_run.yaml
index 16191b64..788c69a7 100644
--- a/configs/sample/cifar10/ce/wide_resnet16_8-final_run.yaml
+++ b/configs/sample/cifar10/ce/wide_resnet16_8-final_run.yaml
@@ -1,5 +1,5 @@
 datasets:
-  ilsvrc2012:
+  cifar10:
     name: &dataset_name 'cifar10'
     type: 'CIFAR10'
     root: &root_dir !join ['./resource/dataset/', *dataset_name]
diff --git a/configs/sample/cifar10/ce/wide_resnet16_8-hyperparameter_tuning.yaml b/configs/sample/cifar10/ce/wide_resnet16_8-hyperparameter_tuning.yaml
index 8e5c4fd9..33b03bd8 100644
--- a/configs/sample/cifar10/ce/wide_resnet16_8-hyperparameter_tuning.yaml
+++ b/configs/sample/cifar10/ce/wide_resnet16_8-hyperparameter_tuning.yaml
@@ -1,5 +1,5 @@
 datasets:
-  ilsvrc2012:
+  cifar10:
     name: &dataset_name 'cifar10'
     type: 'CIFAR10'
     root: &root_dir !join ['./resource/dataset/', *dataset_name]
diff --git a/configs/sample/cifar10/ce/wide_resnet28_10-final_run.yaml b/configs/sample/cifar10/ce/wide_resnet28_10-final_run.yaml
index cd368ed7..4160b7b2 100644
--- a/configs/sample/cifar10/ce/wide_resnet28_10-final_run.yaml
+++ b/configs/sample/cifar10/ce/wide_resnet28_10-final_run.yaml
@@ -1,5 +1,5 @@
 datasets:
-  ilsvrc2012:
+  cifar10:
     name: &dataset_name 'cifar10'
     type: 'CIFAR10'
     root: &root_dir !join ['./resource/dataset/', *dataset_name]
diff --git a/configs/sample/cifar10/ce/wide_resnet28_10-hyperparameter_tuning.yaml b/configs/sample/cifar10/ce/wide_resnet28_10-hyperparameter_tuning.yaml
index 883a5fac..a0f5bf39 100644
--- a/configs/sample/cifar10/ce/wide_resnet28_10-hyperparameter_tuning.yaml
+++ b/configs/sample/cifar10/ce/wide_resnet28_10-hyperparameter_tuning.yaml
@@ -1,5 +1,5 @@
 datasets:
-  ilsvrc2012:
+  cifar10:
     name: &dataset_name 'cifar10'
     type: 'CIFAR10'
     root: &root_dir !join ['./resource/dataset/', *dataset_name]
diff --git a/configs/sample/cifar10/ce/wide_resnet40_4-final_run.yaml b/configs/sample/cifar10/ce/wide_resnet40_4-final_run.yaml
index dfd28660..ec652ec4 100644
--- a/configs/sample/cifar10/ce/wide_resnet40_4-final_run.yaml
+++ b/configs/sample/cifar10/ce/wide_resnet40_4-final_run.yaml
@@ -1,5 +1,5 @@
 datasets:
-  ilsvrc2012:
+  cifar10:
     name: &dataset_name 'cifar10'
     type: 'CIFAR10'
     root: &root_dir !join ['./resource/dataset/', *dataset_name]
diff --git a/configs/sample/cifar10/ce/wide_resnet40_4-hyperparameter_tuning.yaml b/configs/sample/cifar10/ce/wide_resnet40_4-hyperparameter_tuning.yaml
index 6b03cc47..adfec716 100644
--- a/configs/sample/cifar10/ce/wide_resnet40_4-hyperparameter_tuning.yaml
+++ b/configs/sample/cifar10/ce/wide_resnet40_4-hyperparameter_tuning.yaml
@@ -1,5 +1,5 @@
 datasets:
-  ilsvrc2012:
+  cifar10:
     name: &dataset_name 'cifar10'
     type: 'CIFAR10'
     root: &root_dir !join ['./resource/dataset/', *dataset_name]
diff --git a/configs/sample/cifar10/kd/resnet20_from_densenet_bc_k12_depth100-final_run.yaml b/configs/sample/cifar10/kd/resnet20_from_densenet_bc_k12_depth100-final_run.yaml
index ae0a7004..4bdb4959 100644
--- a/configs/sample/cifar10/kd/resnet20_from_densenet_bc_k12_depth100-final_run.yaml
+++ b/configs/sample/cifar10/kd/resnet20_from_densenet_bc_k12_depth100-final_run.yaml
@@ -1,5 +1,5 @@
 datasets:
-  ilsvrc2012:
+  cifar10:
     name: &dataset_name 'cifar10'
     type: 'CIFAR10'
     root: &root_dir !join ['./resource/dataset/', *dataset_name]
diff --git a/configs/sample/cifar10/kd/resnet20_from_densenet_bc_k12_depth100-hyperparameter_tuning.yaml b/configs/sample/cifar10/kd/resnet20_from_densenet_bc_k12_depth100-hyperparameter_tuning.yaml
index ff5116a9..3f17e6e3 100644
--- a/configs/sample/cifar10/kd/resnet20_from_densenet_bc_k12_depth100-hyperparameter_tuning.yaml
+++ b/configs/sample/cifar10/kd/resnet20_from_densenet_bc_k12_depth100-hyperparameter_tuning.yaml
@@ -1,5 +1,5 @@
 datasets:
-  ilsvrc2012:
+  cifar10:
     name: &dataset_name 'cifar10'
     type: 'CIFAR10'
     root: &root_dir !join ['./resource/dataset/', *dataset_name]
diff --git a/configs/sample/cifar10/kd/wide_resnet40_1_from_wide_resnet40_4-final_run.yaml b/configs/sample/cifar10/kd/wide_resnet40_1_from_wide_resnet40_4-final_run.yaml
index 493619a3..9efc19d4 100644
--- a/configs/sample/cifar10/kd/wide_resnet40_1_from_wide_resnet40_4-final_run.yaml
+++ b/configs/sample/cifar10/kd/wide_resnet40_1_from_wide_resnet40_4-final_run.yaml
@@ -1,5 +1,5 @@
 datasets:
-  ilsvrc2012:
+  cifar10:
     name: &dataset_name 'cifar10'
     type: 'CIFAR10'
     root: &root_dir !join ['./resource/dataset/', *dataset_name]
diff --git a/configs/sample/cifar10/kd/wide_resnet40_1_from_wide_resnet40_4-hyperparameter_tuning.yaml b/configs/sample/cifar10/kd/wide_resnet40_1_from_wide_resnet40_4-hyperparameter_tuning.yaml
index 7babf19a..684926a4 100644
--- a/configs/sample/cifar10/kd/wide_resnet40_1_from_wide_resnet40_4-hyperparameter_tuning.yaml
+++ b/configs/sample/cifar10/kd/wide_resnet40_1_from_wide_resnet40_4-hyperparameter_tuning.yaml
@@ -1,5 +1,5 @@
 datasets:
-  ilsvrc2012:
+  cifar10:
     name: &dataset_name 'cifar10'
     type: 'CIFAR10'
     root: &root_dir !join ['./resource/dataset/', *dataset_name]
diff --git a/configs/sample/cifar100/ce/densenet_bc_k12_depth100-final_run.yaml b/configs/sample/cifar100/ce/densenet_bc_k12_depth100-final_run.yaml
index c6c97686..f8d832dc 100644
--- a/configs/sample/cifar100/ce/densenet_bc_k12_depth100-final_run.yaml
+++ b/configs/sample/cifar100/ce/densenet_bc_k12_depth100-final_run.yaml
@@ -1,5 +1,5 @@
 datasets:
-  ilsvrc2012:
+  cifar100:
     name: &dataset_name 'cifar100'
     type: 'CIFAR100'
     root: &root_dir !join ['./resource/dataset/', *dataset_name]
diff --git a/configs/sample/cifar100/ce/densenet_bc_k12_depth100-hyperparameter_tuning.yaml b/configs/sample/cifar100/ce/densenet_bc_k12_depth100-hyperparameter_tuning.yaml
index 6a6dec82..4b2a97bf 100644
--- a/configs/sample/cifar100/ce/densenet_bc_k12_depth100-hyperparameter_tuning.yaml
+++ b/configs/sample/cifar100/ce/densenet_bc_k12_depth100-hyperparameter_tuning.yaml
@@ -1,5 +1,5 @@
 datasets:
-  ilsvrc2012:
+  cifar100:
     name: &dataset_name 'cifar100'
     type: 'CIFAR100'
     root: &root_dir !join ['./resource/dataset/', *dataset_name]
diff --git a/configs/sample/cifar100/ce/densenet_bc_k24_depth250-final_run.yaml b/configs/sample/cifar100/ce/densenet_bc_k24_depth250-final_run.yaml
index 981aaf78..fabeafd5 100644
--- a/configs/sample/cifar100/ce/densenet_bc_k24_depth250-final_run.yaml
+++ b/configs/sample/cifar100/ce/densenet_bc_k24_depth250-final_run.yaml
@@ -1,5 +1,5 @@
 datasets:
-  ilsvrc2012:
+  cifar100:
     name: &dataset_name 'cifar100'
     type: 'CIFAR100'
     root: &root_dir !join ['./resource/dataset/', *dataset_name]
diff --git a/configs/sample/cifar100/ce/densenet_bc_k24_depth250-hyperparameter_tuning.yaml b/configs/sample/cifar100/ce/densenet_bc_k24_depth250-hyperparameter_tuning.yaml
index 40ce29f1..b37a7aea 100644
--- a/configs/sample/cifar100/ce/densenet_bc_k24_depth250-hyperparameter_tuning.yaml
+++ b/configs/sample/cifar100/ce/densenet_bc_k24_depth250-hyperparameter_tuning.yaml
@@ -1,5 +1,5 @@
 datasets:
-  ilsvrc2012:
+  cifar100:
     name: &dataset_name 'cifar100'
     type: 'CIFAR100'
     root: &root_dir !join ['./resource/dataset/', *dataset_name]
diff --git a/configs/sample/cifar100/ce/densenet_bc_k40_depth190-final_run.yaml b/configs/sample/cifar100/ce/densenet_bc_k40_depth190-final_run.yaml
index 4171e419..d30f9f95 100644
--- a/configs/sample/cifar100/ce/densenet_bc_k40_depth190-final_run.yaml
+++ b/configs/sample/cifar100/ce/densenet_bc_k40_depth190-final_run.yaml
@@ -1,5 +1,5 @@
 datasets:
-  ilsvrc2012:
+  cifar100:
     name: &dataset_name 'cifar100'
     type: 'CIFAR100'
     root: &root_dir !join ['./resource/dataset/', *dataset_name]
diff --git a/configs/sample/cifar100/ce/densenet_bc_k40_depth190-hyperparameter_tuning.yaml b/configs/sample/cifar100/ce/densenet_bc_k40_depth190-hyperparameter_tuning.yaml
index 0e97e25e..77da9df7 100644
--- a/configs/sample/cifar100/ce/densenet_bc_k40_depth190-hyperparameter_tuning.yaml
+++ b/configs/sample/cifar100/ce/densenet_bc_k40_depth190-hyperparameter_tuning.yaml
@@ -1,5 +1,5 @@
 datasets:
-  ilsvrc2012:
+  cifar100:
     name: &dataset_name 'cifar100'
     type: 'CIFAR100'
     root: &root_dir !join ['./resource/dataset/', *dataset_name]
diff --git a/configs/sample/cifar100/ce/wide_resnet16_8-final_run.yaml b/configs/sample/cifar100/ce/wide_resnet16_8-final_run.yaml
index 907e2ba9..edab5b70 100644
--- a/configs/sample/cifar100/ce/wide_resnet16_8-final_run.yaml
+++ b/configs/sample/cifar100/ce/wide_resnet16_8-final_run.yaml
@@ -1,5 +1,5 @@
 datasets:
-  ilsvrc2012:
+  cifar100:
     name: &dataset_name 'cifar100'
     type: 'CIFAR100'
     root: &root_dir !join ['./resource/dataset/', *dataset_name]
diff --git a/configs/sample/cifar100/ce/wide_resnet16_8-hyperparameter_tuning.yaml b/configs/sample/cifar100/ce/wide_resnet16_8-hyperparameter_tuning.yaml
index ba54cf73..6b33f1b2 100644
--- a/configs/sample/cifar100/ce/wide_resnet16_8-hyperparameter_tuning.yaml
+++ b/configs/sample/cifar100/ce/wide_resnet16_8-hyperparameter_tuning.yaml
@@ -1,5 +1,5 @@
 datasets:
-  ilsvrc2012:
+  cifar100:
     name: &dataset_name 'cifar100'
     type: 'CIFAR100'
     root: &root_dir !join ['./resource/dataset/', *dataset_name]
diff --git a/configs/sample/cifar100/ce/wide_resnet28_10-final_run.yaml b/configs/sample/cifar100/ce/wide_resnet28_10-final_run.yaml
index 8fd95f84..f8861463 100644
--- a/configs/sample/cifar100/ce/wide_resnet28_10-final_run.yaml
+++ b/configs/sample/cifar100/ce/wide_resnet28_10-final_run.yaml
@@ -1,5 +1,5 @@
 datasets:
-  ilsvrc2012:
+  cifar100:
     name: &dataset_name 'cifar100'
     type: 'CIFAR100'
     root: &root_dir !join ['./resource/dataset/', *dataset_name]
diff --git a/configs/sample/cifar100/ce/wide_resnet28_10-hyperparameter_tuning.yaml b/configs/sample/cifar100/ce/wide_resnet28_10-hyperparameter_tuning.yaml
index a313f0b5..854a454e 100644
--- a/configs/sample/cifar100/ce/wide_resnet28_10-hyperparameter_tuning.yaml
+++ b/configs/sample/cifar100/ce/wide_resnet28_10-hyperparameter_tuning.yaml
@@ -1,5 +1,5 @@
 datasets:
-  ilsvrc2012:
+  cifar100:
     name: &dataset_name 'cifar100'
     type: 'CIFAR100'
     root: &root_dir !join ['./resource/dataset/', *dataset_name]
diff --git a/configs/sample/cifar100/ce/wide_resnet40_4-final_run.yaml b/configs/sample/cifar100/ce/wide_resnet40_4-final_run.yaml
index ed203124..f0eaa6f8 100644
--- a/configs/sample/cifar100/ce/wide_resnet40_4-final_run.yaml
+++ b/configs/sample/cifar100/ce/wide_resnet40_4-final_run.yaml
@@ -1,5 +1,5 @@
 datasets:
-  ilsvrc2012:
+  cifar100:
     name: &dataset_name 'cifar100'
     type: 'CIFAR100'
     root: &root_dir !join ['./resource/dataset/', *dataset_name]
diff --git a/configs/sample/cifar100/ce/wide_resnet40_4-hyperparameter_tuning.yaml b/configs/sample/cifar100/ce/wide_resnet40_4-hyperparameter_tuning.yaml
index d4173dab..cb9a2876 100644
--- a/configs/sample/cifar100/ce/wide_resnet40_4-hyperparameter_tuning.yaml
+++ b/configs/sample/cifar100/ce/wide_resnet40_4-hyperparameter_tuning.yaml
@@ -1,5 +1,5 @@
 datasets:
-  ilsvrc2012:
+  cifar100:
     name: &dataset_name 'cifar100'
     type: 'CIFAR100'
     root: &root_dir !join ['./resource/dataset/', *dataset_name]
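Every sample config above previously used `ilsvrc2012` as its top-level dataset key, a copy-paste leftover from the ImageNet configs; the patch renames it to match the dataset each config actually loads. For reference, the corrected block in the CIFAR-10 configs reads as follows (the CIFAR-100 configs are analogous, with `cifar100` / `'CIFAR100'`):

```yaml
datasets:
  cifar10:
    name: &dataset_name 'cifar10'
    type: 'CIFAR10'
    root: &root_dir !join ['./resource/dataset/', *dataset_name]
```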
diff --git a/setup.py b/setup.py
index 51aacd3a..2e61493b 100644
--- a/setup.py
+++ b/setup.py
@@ -8,7 +8,7 @@
 'Trained models, training logs and configurations are available for ensuring the reproducibility.'
 setup(
     name='torchdistill',
-    version='0.1.2',
+    version='0.1.3',
     description=description,
     long_description=long_description,
     long_description_content_type='text/markdown',
diff --git a/torchdistill/core/distillation.py b/torchdistill/core/distillation.py
index 89c24e9d..e92aea76 100644
--- a/torchdistill/core/distillation.py
+++ b/torchdistill/core/distillation.py
@@ -144,13 +144,18 @@ def setup(self, train_config):
             trainable_module_list.append(self.teacher_model)
 
         self.optimizer = get_optimizer(trainable_module_list, optim_config['type'], optim_params_config)
+        self.optimizer.zero_grad()
+        self.max_grad_norm = optim_config.get('max_grad_norm', None)
+        self.grad_accum_step = optim_config.get('grad_accum_step', 1)
         optimizer_reset = True
 
         scheduler_config = train_config.get('scheduler', None)
         if scheduler_config is not None and len(scheduler_config) > 0:
             self.lr_scheduler = get_scheduler(self.optimizer, scheduler_config['type'], scheduler_config['params'])
+            self.scheduling_step = scheduler_config.get('scheduling_step', 0)
         elif optimizer_reset:
             self.lr_scheduler = None
+            self.scheduling_step = None
 
         # Set up apex if you require mixed-precision training
         self.apex = False
@@ -168,6 +173,7 @@ def __init__(self, teacher_model, student_model, dataset_dict,
                  train_config, device, device_ids, distributed, lr_factor):
         super().__init__()
+        # Key attributes (should not be modified)
         self.org_teacher_model = teacher_model
         self.org_student_model = student_model
         self.dataset_dict = dataset_dict
@@ -175,6 +181,7 @@ def __init__(self, teacher_model, student_model, dataset_dict,
         self.device_ids = device_ids
         self.distributed = distributed
         self.lr_factor = lr_factor
+        # Local attributes (can be updated at each stage)
         self.teacher_model = None
         self.student_model = None
         self.teacher_forward_proc, self.student_forward_proc = None, None
@@ -183,6 +190,10 @@ def __init__(self, teacher_model, student_model, dataset_dict,
         self.train_data_loader, self.val_data_loader, self.optimizer, self.lr_scheduler = None, None, None, None
         self.org_criterion, self.criterion, self.uses_teacher_output, self.extract_org_loss = None, None, None, None
         self.teacher_updatable, self.teacher_any_frozen, self.student_any_frozen = None, None, None
+        self.grad_accum_step = None
+        self.max_grad_norm = None
+        self.scheduling_step = 0
+        self.stage_grad_count = 0
         self.apex = None
         self.setup(train_config)
         self.num_epochs = train_config['num_epochs']
@@ -265,16 +276,33 @@ def forward(self, sample_batch, targets, supp_dict):
         return total_loss
 
     def update_params(self, loss):
-        self.optimizer.zero_grad()
+        self.stage_grad_count += 1
+        if self.grad_accum_step > 1:
+            loss /= self.grad_accum_step
+
         if self.apex:
             with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                 scaled_loss.backward()
         else:
             loss.backward()
-        self.optimizer.step()
+
+        if self.stage_grad_count % self.grad_accum_step == 0:
+            if self.max_grad_norm is not None:
+                target_params = amp.master_params(self.optimizer) if self.apex \
+                    else [p for group in self.optimizer.param_groups for p in group['params']]
+                torch.nn.utils.clip_grad_norm_(target_params, self.max_grad_norm)
+
+            self.optimizer.step()
+            self.optimizer.zero_grad()
+
+            # Step-wise scheduler step
+            if self.lr_scheduler is not None and self.scheduling_step > 0 \
+                    and self.stage_grad_count % self.scheduling_step == 0:
+                self.lr_scheduler.step()
 
     def post_process(self, **kwargs):
-        if self.lr_scheduler is not None:
+        # Epoch-wise scheduler step
+        if self.lr_scheduler is not None and self.scheduling_step <= 0:
             self.lr_scheduler.step()
         if isinstance(self.teacher_model, SpecialModule):
             self.teacher_model.post_process()
@@ -310,6 +338,7 @@ def __init__(self, teacher_model, student_model, data_loader_dict,
 
     def advance_to_next_stage(self):
         self.clean_modules()
+        self.stage_grad_count = 0
         self.stage_number += 1
         next_stage_config = self.train_config['stage{}'.format(self.stage_number)]
         self.setup(next_stage_config)
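The new `grad_accum_step` and `max_grad_norm` values are read from the stage's `optimizer` entry (`optim_config.get(...)` above): the loss is scaled by `1 / grad_accum_step`, gradients are accumulated over that many batches, optionally clipped with `torch.nn.utils.clip_grad_norm_`, and only then applied with `optimizer.step()`. A minimal sketch of an optimizer entry enabling both; the surrounding `train:` layout mirrors the sample configs, and the SGD hyperparameters are illustrative, not taken from this patch:

```yaml
train:
  optimizer:
    type: 'SGD'
    params:
      lr: 0.1
      momentum: 0.9
      weight_decay: 0.0005
    grad_accum_step: 4    # call optimizer.step() once every 4 batches
    max_grad_norm: 5.0    # clip the accumulated gradient norm before each step
```

With a per-iteration batch size of 32, for example, this would give an effective batch size of 128 without additional GPU memory.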
diff --git a/torchdistill/core/training.py b/torchdistill/core/training.py
index 8d404968..1474cc64 100644
--- a/torchdistill/core/training.py
+++ b/torchdistill/core/training.py
@@ -1,5 +1,6 @@
 import sys
 
+import torch
 from torch import distributed as dist
 from torch import nn
 
@@ -108,13 +109,18 @@ def setup(self, train_config):
             trainable_module_list = nn.ModuleList([self.model])
 
         self.optimizer = get_optimizer(trainable_module_list, optim_config['type'], optim_params_config)
+        self.optimizer.zero_grad()
+        self.max_grad_norm = optim_config.get('max_grad_norm', None)
+        self.grad_accum_step = optim_config.get('grad_accum_step', 1)
         optimizer_reset = True
 
         scheduler_config = train_config.get('scheduler', None)
         if scheduler_config is not None and len(scheduler_config) > 0:
             self.lr_scheduler = get_scheduler(self.optimizer, scheduler_config['type'], scheduler_config['params'])
+            self.scheduling_step = scheduler_config.get('scheduling_step', 0)
         elif optimizer_reset:
             self.lr_scheduler = None
+            self.scheduling_step = None
 
         # Set up apex if you require mixed-precision training
         self.apex = False
@@ -131,12 +137,14 @@ def __init__(self, model, dataset_dict, train_config, device, device_ids, distributed, lr_factor):
         super().__init__()
+        # Key attributes (should not be modified)
         self.org_model = model
         self.dataset_dict = dataset_dict
         self.device = device
         self.device_ids = device_ids
         self.distributed = distributed
         self.lr_factor = lr_factor
+        # Local attributes (can be updated at each stage)
         self.model = None
         self.model_forward_proc = None
         self.target_model_pairs = list()
@@ -144,6 +152,10 @@ def __init__(self, model, dataset_dict, train_config, device, device_ids, distri
         self.train_data_loader, self.val_data_loader, self.optimizer, self.lr_scheduler = None, None, None, None
         self.org_criterion, self.criterion, self.extract_org_loss = None, None, None
         self.model_any_frozen = None
+        self.grad_accum_step = None
+        self.max_grad_norm = None
+        self.scheduling_step = 0
+        self.stage_grad_count = 0
         self.apex = None
         self.setup(train_config)
         self.num_epochs = train_config['num_epochs']
@@ -168,16 +180,33 @@ def forward(self, sample_batch, targets, supp_dict):
         return total_loss
 
     def update_params(self, loss):
-        self.optimizer.zero_grad()
+        self.stage_grad_count += 1
+        if self.grad_accum_step > 1:
+            loss /= self.grad_accum_step
+
         if self.apex:
             with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                 scaled_loss.backward()
         else:
             loss.backward()
-        self.optimizer.step()
+
+        if self.stage_grad_count % self.grad_accum_step == 0:
+            if self.max_grad_norm is not None:
+                target_params = amp.master_params(self.optimizer) if self.apex \
+                    else [p for group in self.optimizer.param_groups for p in group['params']]
+                torch.nn.utils.clip_grad_norm_(target_params, self.max_grad_norm)
+
+            self.optimizer.step()
+            self.optimizer.zero_grad()
+
+            # Step-wise scheduler step
+            if self.lr_scheduler is not None and self.scheduling_step > 0 \
+                    and self.stage_grad_count % self.scheduling_step == 0:
+                self.lr_scheduler.step()
 
     def post_process(self, **kwargs):
-        if self.lr_scheduler is not None:
+        # Epoch-wise scheduler step
+        if self.lr_scheduler is not None and self.scheduling_step <= 0:
             self.lr_scheduler.step()
         if isinstance(self.model, SpecialModule):
             self.model.post_process()
@@ -206,6 +235,7 @@ def __init__(self, model, data_loader_dict, train_config, device, device_ids, di
 
     def advance_to_next_stage(self):
         self.clean_modules()
+        self.stage_grad_count = 0
         self.stage_number += 1
         next_stage_config = self.train_config['stage{}'.format(self.stage_number)]
         self.setup(next_stage_config)
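`scheduling_step` is read from the `scheduler` entry the same way. When it is positive, `lr_scheduler.step()` is called inside `update_params` every `scheduling_step` parameter updates; when it is 0 or unset, the scheduler keeps stepping once per epoch in `post_process`, as before. A sketch with illustrative values (the milestones here count parameter updates, not epochs):

```yaml
train:
  scheduler:
    type: 'MultiStepLR'
    params:
      milestones: [5000, 7500]
      gamma: 0.1
    scheduling_step: 1    # step the scheduler after every update; 0 keeps epoch-wise stepping
```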
diff --git a/torchdistill/datasets/util.py b/torchdistill/datasets/util.py
index 572ce7e3..1497acc4 100644
--- a/torchdistill/datasets/util.py
+++ b/torchdistill/datasets/util.py
@@ -4,7 +4,8 @@
 import torchvision
 from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, random_split
 from torch.utils.data.distributed import DistributedSampler
-from torchvision.datasets import PhotoTour, VOCDetection, Kinetics400, HMDB51, UCF101
+from torchvision.datasets import PhotoTour, Kinetics400, HMDB51, UCF101, Cityscapes, CocoCaptions, CocoDetection, \
+    SBDataset, VOCSegmentation, VOCDetection
 
 from torchdistill.common.constant import def_logger
 from torchdistill.datasets.coco import ImageToTensor, Compose, CocoRandomHorizontalFlip, get_coco
@@ -35,6 +36,9 @@ def build_transform(transform_params_config, compose_cls=None):
     if not isinstance(transform_params_config, (dict, list)) or len(transform_params_config) == 0:
         return None
 
+    if isinstance(compose_cls, str):
+        compose_cls = TRANSFORM_CLASS_DICT[compose_cls]
+
     component_list = list()
     if isinstance(transform_params_config, dict):
         for component_key in sorted(transform_params_config.keys()):
@@ -58,8 +62,13 @@ def get_torchvision_dataset(dataset_cls, dataset_params_config):
     params_config = dataset_params_config.copy()
-    transform = build_transform(params_config.pop('transform_params', None))
-    target_transform = build_transform(params_config.pop('target_transform_params', None))
+    transform_compose_cls_name = params_config.pop('transform_compose_cls', None)
+    transform = build_transform(params_config.pop('transform_params', None), compose_cls=transform_compose_cls_name)
+    target_transform_compose_cls_name = params_config.pop('target_transform_compose_cls', None)
+    target_transform = build_transform(params_config.pop('target_transform_params', None),
+                                       compose_cls=target_transform_compose_cls_name)
+    transforms_compose_cls_name = params_config.pop('transforms_compose_cls', None)
+    transforms = build_transform(params_config.pop('transforms_params', None), compose_cls=transforms_compose_cls_name)
     if 'loader' in params_config:
         loader_config = params_config.pop('loader')
         loader_type = loader_config['type']
@@ -69,8 +78,12 @@ def get_torchvision_dataset(dataset_cls, dataset_params_config):
         params_config['loader'] = loader
 
     # For datasets without target_transform
-    if dataset_cls in (PhotoTour, VOCDetection, Kinetics400, HMDB51, UCF101):
+    if dataset_cls in (PhotoTour, Kinetics400, HMDB51, UCF101):
         return dataset_cls(transform=transform, **params_config)
+    # For datasets with transforms
+    if dataset_cls in (Cityscapes, CocoCaptions, CocoDetection, SBDataset, VOCSegmentation, VOCDetection):
+        return dataset_cls(transform=transform, target_transform=target_transform,
+                           transforms=transforms, **params_config)
     return dataset_cls(transform=transform, target_transform=target_transform, **params_config)
@@ -96,10 +109,13 @@
             params_config = sub_split_params.copy()
             transform = build_transform(params_config.pop('transform_params', None))
             target_transform = build_transform(params_config.pop('target_transform_params', None))
+            transforms = build_transform(params_config.pop('transforms_params', None))
             if hasattr(sub_dataset.dataset, 'transform') and transform is not None:
                 sub_dataset.dataset.transform = transform
             if hasattr(sub_dataset.dataset, 'target_transform') and target_transform is not None:
                 sub_dataset.dataset.target_transform = target_transform
+            if hasattr(sub_dataset.dataset, 'transforms') and transforms is not None:
+                sub_dataset.dataset.transforms = transforms
             dataset_dict[sub_dataset_id] = sub_dataset
 
@@ -133,7 +149,7 @@ def get_dataset_dict(dataset_config):
                 dataset_dict[dataset_id] = org_dataset
             else:
                 split_dataset(org_dataset, random_split_config, dataset_id, dataset_dict)
-            logger.info('{} sec'.format(time.time() - st))
+            logger.info('dataset_id `{}`: {} sec'.format(dataset_id, time.time() - st))
     else:
         raise ValueError('dataset_type `{}` is not expected'.format(dataset_type))
     return dataset_dict
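The `datasets/util.py` changes let a config route joint image/target transforms to datasets such as `VOCSegmentation` through the `transforms` argument, with the compose class resolved by name via `TRANSFORM_CLASS_DICT`. A hypothetical dataset entry sketching the new keys; the `splits`/`params` layout and the transform component format follow the existing sample configs, and whether `'Compose'`, `'CocoRandomHorizontalFlip'`, and `'ImageToTensor'` are registered under those names in `TRANSFORM_CLASS_DICT` is an assumption here:

```yaml
datasets:
  pascal_voc2012:
    name: &dataset_name 'pascal_voc2012'
    type: 'VOCSegmentation'
    root: &root_dir !join ['./resource/dataset/', *dataset_name]
    splits:
      train:
        dataset_id: !join [*dataset_name, '/train']
        params:
          root: *root_dir
          year: '2012'
          image_set: 'train'
          download: True
          transforms_compose_cls: 'Compose'   # resolved via TRANSFORM_CLASS_DICT in build_transform
          transforms_params:                  # applied jointly to image and segmentation mask
            CocoRandomHorizontalFlip:
              params:
                p: 0.5
            ImageToTensor:
              params: {}
```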