From dcb54919e132cf8b901c68e8d4bc834f2aff0aa6 Mon Sep 17 00:00:00 2001 From: Vinicius Reis Date: Mon, 24 Feb 2020 11:14:52 -0800 Subject: [PATCH] Add standard resnet models (#405) Summary: Pull Request resolved: https://github.com/facebookresearch/ClassyVision/pull/405 It is annoying to configure the ResNet blocks all the time. Add the standard models to the code so we can refer to them by name Differential Revision: D20050757 fbshipit-source-id: 82ef26dc660f5ac3b9778560f40135810c5acdc4 --- classy_vision/models/resnext.py | 121 +++++++++++++++++++++++++++++--- 1 file changed, 113 insertions(+), 8 deletions(-) diff --git a/classy_vision/models/resnext.py b/classy_vision/models/resnext.py index 044de5bb99..d017d9314a 100644 --- a/classy_vision/models/resnext.py +++ b/classy_vision/models/resnext.py @@ -9,7 +9,7 @@ """ import math -from typing import Any, Dict +from typing import Any, Dict, List, Optional, Tuple, Union import torch.nn as nn from classy_vision.generic.util import is_pos_int @@ -228,13 +228,13 @@ class ResNeXt(ClassyModel): def __init__( self, num_blocks, - init_planes, - reduction, - small_input, - zero_init_bn_residuals, - base_width_and_cardinality, - basic_layer, - final_bn_relu, + init_planes: int = 64, + reduction: int = 4, + small_input: bool = False, + zero_init_bn_residuals: bool = False, + base_width_and_cardinality: Optional[Union[Tuple, List]] = None, + basic_layer: bool = False, + final_bn_relu: bool = True, ): """ Implementation of `ResNeXt `_. @@ -414,3 +414,108 @@ def output_shape(self): @property def model_depth(self): return sum(self.num_blocks) + + +@register_model("resnet18") +class ResNet18(ResNeXt): + def __init__(self): + super().__init__( + num_blocks=[2, 2, 2, 2], basic_layer=True, zero_init_bn_residuals=True + ) + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "ResNeXt": + return cls() + + +@register_model("resnet34") +class ResNet34(ResNeXt): + def __init__(self): + super().__init__( + num_blocks=[3, 4, 6, 3], basic_layer=True, zero_init_bn_residuals=True + ) + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "ResNeXt": + return cls() + + +@register_model("resnet50") +class ResNet50(ResNeXt): + def __init__(self): + super().__init__( + num_blocks=[3, 4, 6, 3], basic_layer=False, zero_init_bn_residuals=True + ) + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "ResNeXt": + return cls() + + +@register_model("resnet101") +class ResNet101(ResNeXt): + def __init__(self): + super().__init__( + num_blocks=[3, 4, 23, 3], basic_layer=False, zero_init_bn_residuals=True + ) + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "ResNeXt": + return cls() + + +@register_model("resnet152") +class ResNet152(ResNeXt): + def __init__(self): + super().__init__( + num_blocks=[3, 8, 36, 3], basic_layer=False, zero_init_bn_residuals=True + ) + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "ResNeXt": + return cls() + + +@register_model("resnext50_32x4d") +class ResNeXt50(ResNeXt): + def __init__(self): + super().__init__( + num_blocks=[3, 4, 6, 3], + basic_layer=False, + zero_init_bn_residuals=True, + base_width_and_cardinality=(4, 32), + ) + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "ResNeXt": + return cls() + + +@register_model("resnext101_32x4d") +class ResNeXt101(ResNeXt): + def __init__(self): + super().__init__( + num_blocks=[3, 4, 23, 3], + basic_layer=False, + zero_init_bn_residuals=True, + base_width_and_cardinality=(4, 32), + ) + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "ResNeXt": + return cls() + + +@register_model("resnext152_32x4d") +class ResNeXt152(ResNeXt): + def __init__(self): + super().__init__( + num_blocks=[3, 8, 36, 3], + basic_layer=False, + zero_init_bn_residuals=True, + base_width_and_cardinality=(4, 32), + ) + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "ResNeXt": + return cls()