Skip to content

Commit

Permalink
Hotfix/sg 000 fix breaking changes in some imports (#1101)
Browse files Browse the repository at this point in the history
* Added backward compatibility fixes for imports

* Added backward compatibility fixes for imports
  • Loading branch information
BloodAxe authored May 30, 2023
1 parent d92337a commit 6b5785d
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 14 deletions.
38 changes: 38 additions & 0 deletions src/super_gradients/training/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import warnings

from .sg_module import SgModule

# Classification models
Expand Down Expand Up @@ -33,6 +35,7 @@
PreActResNet152,
)
from super_gradients.training.models.classification_models.resnet import (
BasicResNetBlock,
ResNet,
ResNet18,
ResNet34,
Expand Down Expand Up @@ -129,6 +132,38 @@
from super_gradients.common.object_names import Models
from super_gradients.common.registry.registry import ARCHITECTURES

from super_gradients.training.utils import make_divisible as _make_divisible_current_version


def make_deprecated(func, reason):
def inner(*args, **kwargs):
with warnings.catch_warnings():
warnings.simplefilter("once", DeprecationWarning)
warnings.warn(reason, category=DeprecationWarning, stacklevel=2)
warnings.warn(reason, DeprecationWarning)
return func(*args, **kwargs)

return inner


make_divisible = make_deprecated(
func=_make_divisible_current_version,
reason="You're importing `make_divisible` from `super_gradients.training.models`. This is deprecated since SuperGradients 3.1.0.\n"
"Please update your code to import it as follows:\n"
"[-] from super_gradients.training.models import make_divisible\n"
"[+] from super_gradients.training.utils import make_divisible\n",
)


BasicBlock = make_deprecated(
func=BasicResNetBlock,
reason="You're importing `BasicBlock` class from `super_gradients.training.models`. This is deprecated since SuperGradients 3.1.0.\n"
"This block was renamed to BasicResNetBlock for better clarity.\n"
"Please update your code to import it as follows:\n"
"[-] from super_gradients.training.models import BasicBlock\n"
"[+] from super_gradients.training.models import BasicResNetBlock\n",
)

__all__ = [
"SPP",
"YoloNAS_S",
Expand Down Expand Up @@ -293,4 +328,7 @@
"SegFormerB4",
"SegFormerB5",
"DDRNet39Backbone",
"make_divisible",
"BasicResNetBlock",
"BasicBlock",
]
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@
from super_gradients.common.object_names import Models


class BasicBlock(nn.Module):
class BasicResNetBlock(nn.Module):
def __init__(self, in_planes, planes, stride=1, expansion=1, final_relu=True, droppath_prob=0.0):
super(BasicBlock, self).__init__()
super(BasicResNetBlock, self).__init__()
self.expansion = expansion
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
Expand Down Expand Up @@ -236,7 +236,7 @@ def replace_head(self, new_num_classes=None, new_head=None):
class ResNet18(ResNet):
def __init__(self, arch_params, num_classes=None):
super().__init__(
BasicBlock,
BasicResNetBlock,
[2, 2, 2, 2],
num_classes=num_classes or arch_params.num_classes,
droppath_prob=get_param(arch_params, "droppath_prob", 0),
Expand All @@ -247,14 +247,14 @@ def __init__(self, arch_params, num_classes=None):
@register_model(Models.RESNET18_CIFAR)
class ResNet18Cifar(CifarResNet):
def __init__(self, arch_params, num_classes=None):
super().__init__(BasicBlock, [2, 2, 2, 2], num_classes=num_classes or arch_params.num_classes)
super().__init__(BasicResNetBlock, [2, 2, 2, 2], num_classes=num_classes or arch_params.num_classes)


@register_model(Models.RESNET34)
class ResNet34(ResNet):
def __init__(self, arch_params, num_classes=None):
super().__init__(
BasicBlock,
BasicResNetBlock,
[3, 4, 6, 3],
num_classes=num_classes or arch_params.num_classes,
droppath_prob=get_param(arch_params, "droppath_prob", 0),
Expand Down Expand Up @@ -317,7 +317,7 @@ def __init__(self, arch_params, num_classes=None):
@register_model(Models.CUSTOM_RESNET_CIFAR)
class CustomizedResnetCifar(CifarResNet):
def __init__(self, arch_params, num_classes=None):
super().__init__(BasicBlock, arch_params.structure, width_mult=arch_params.width_mult, num_classes=num_classes or arch_params.num_classes)
super().__init__(BasicResNetBlock, arch_params.structure, width_mult=arch_params.width_mult, num_classes=num_classes or arch_params.num_classes)


@register_model(Models.CUSTOM_RESNET50_CIFAR)
Expand All @@ -330,7 +330,7 @@ def __init__(self, arch_params, num_classes=None):
class CustomizedResnet(ResNet):
def __init__(self, arch_params, num_classes=None):
super().__init__(
BasicBlock,
BasicResNetBlock,
arch_params.structure,
width_mult=arch_params.width_mult,
num_classes=num_classes or arch_params.num_classes,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from super_gradients.common.registry.registry import register_model
from super_gradients.common.object_names import Models
from super_gradients.training.models.classification_models.resnet import BasicBlock, Bottleneck
from super_gradients.training.models.classification_models.resnet import BasicResNetBlock, Bottleneck
from super_gradients.training.models.segmentation_models.segmentation_module import SegmentationModule
from super_gradients.training.utils import get_param, HpmStruct

Expand Down Expand Up @@ -530,8 +530,8 @@ def __init__(self, arch_params: HpmStruct):

DEFAULT_DDRNET_23_PARAMS = {
"input_channels": 3,
"block": BasicBlock,
"skip_block": BasicBlock,
"block": BasicResNetBlock,
"skip_block": BasicResNetBlock,
"layer5_block": Bottleneck,
"layer5_bottleneck_expansion": 2,
"layers": [2, 2, 2, 2, 1, 2, 2, 1],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from super_gradients.training.utils import HpmStruct
from super_gradients.common.registry.registry import register_model
from super_gradients.common.object_names import Models
from super_gradients.training.models.classification_models.resnet import BasicBlock, ResNet, Bottleneck
from super_gradients.training.models.classification_models.resnet import BasicResNetBlock, ResNet, Bottleneck


class FCNHead(nn.Module):
Expand Down Expand Up @@ -93,12 +93,12 @@ def forward(self, x):

class ShelfResNetBackBone18(ShelfResNetBackBone):
def __init__(self, num_classes: int):
super().__init__(BasicBlock, [2, 2, 2, 2], num_classes=num_classes)
super().__init__(BasicResNetBlock, [2, 2, 2, 2], num_classes=num_classes)


class ShelfResNetBackBone34(ShelfResNetBackBone):
def __init__(self, num_classes: int):
super().__init__(BasicBlock, [3, 4, 6, 3], num_classes=num_classes)
super().__init__(BasicResNetBlock, [3, 4, 6, 3], num_classes=num_classes)


class ShelfResNetBackBone503343(ShelfResNetBackBone):
Expand Down
12 changes: 11 additions & 1 deletion src/super_gradients/training/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,13 @@
from super_gradients.training.utils.utils import Timer, HpmStruct, WrappedModel, convert_to_tensor, get_param, tensor_container_to_device, random_seed
from super_gradients.training.utils.utils import (
Timer,
HpmStruct,
WrappedModel,
convert_to_tensor,
get_param,
tensor_container_to_device,
random_seed,
make_divisible,
)
from super_gradients.training.utils.checkpoint_utils import adapt_state_dict_to_fit_model_layer_names, raise_informative_runtime_error
from super_gradients.training.utils.version_utils import torch_version_is_greater_or_equal
from super_gradients.training.utils.config_utils import raise_if_unused_params, warn_if_unused_params
Expand All @@ -21,4 +30,5 @@
"EarlyStop",
"DEKRPoseEstimationDecodeCallback",
"DEKRVisualizationCallback",
"make_divisible",
]
18 changes: 18 additions & 0 deletions tests/unit_tests/test_deprecations.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,24 @@ def test_deprecated_arch_params_top_level_class_via_models_get(self):
model = models.get("DummyModelV2", arch_params=arch_params, num_classes=80)
assert isinstance(model, DummyModelV2)

def test_deprecated_make_divisible(self):
try:
with self.assertWarns(DeprecationWarning):
from super_gradients.training.models import make_divisible # noqa

assert make_divisible(1, 1) == 1
except ImportError:
self.fail("ImportError raised unexpectedly for make_divisible")

def test_deprecated_BasicBlock(self):
try:
with self.assertWarns(DeprecationWarning):
from super_gradients.training.models import BasicBlock, BasicResNetBlock # noqa

assert isinstance(BasicBlock(1, 1, 1), BasicResNetBlock)
except ImportError:
self.fail("ImportError raised unexpectedly for BasicBlock")


if __name__ == "__main__":
unittest.main()

0 comments on commit 6b5785d

Please sign in to comment.