Skip to content

Commit

Permalink
Feature/sg 1033 fix yolox anchors (#1369)
Browse files Browse the repository at this point in the history
* Update readme

* Fix small bug in __repr__ implementation of KeypointsImageToTensor

* Test

* Test

* Test

* Test

* Test

* Test

* Make graphsurgeon an optional

* Make graphsurgeon an optional

* Properly handle imports of optional packages

* Added empty __init__.py files

* Do imports of gs inside the export call

* Do imports of gs inside the export call

* Fix DEKR's missing HasPredict interface

* Update notebook & example doc to reflect changes in imports & function names

* Update readme

* Put back images

* Install onnx_graphsurgeon in CI

* Install onnx_graphsurgeon in CI

* Working prototype of YoloX fix of Anchors that can load model weights as well

* Added more tests for detection predict() and yolox checkpoint loading

* Fix version of ONNX-GS installed in CI and installed on-demand

* Added docs

* Added docs

* Added docs

* Remove leftover

* Set ignore_errors=True to trainer test and declare why

* Fix bug in maybe_remove_module_prefix
  • Loading branch information
BloodAxe authored Aug 15, 2023
1 parent 138da9a commit 6a1f1bc
Show file tree
Hide file tree
Showing 7 changed files with 1,336 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,14 @@ def prep_model_for_conversion(self, input_size: Union[tuple, list] = None, **kwa
with convertible substitutes and remove all auxiliary or training related parts.
:param input_size: [H,W]
"""
self.head.cache_anchors(input_size)

# There is some discrepancy of what input_size is.
# When exporting to ONNX it is passed as 4-element tuple (B,C,H,W)
# When called from predict() it is just (H,W)
# So we take two last elements of the tuple which handles both cases but ultimately we should fix this
h, w = input_size[-2:]

self.head.cache_anchors((h, w))

for module in self.modules():
if isinstance(module, RepVGGBlock):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Union, Type, List, Tuple, Optional
from functools import lru_cache

import numpy as np
import torch
import torch.nn as nn

Expand Down Expand Up @@ -177,7 +178,7 @@ class DetectX(nn.Module):
def __init__(
self,
num_classes: int,
stride: torch.Tensor,
stride: np.ndarray,
activation_func_type: type,
channels: list,
depthwise=False,
Expand All @@ -203,7 +204,7 @@ def __init__(
self.n_anchors = 1
self.grid = [torch.zeros(1)] * self.detection_layers_num # init grid

self.register_buffer("stride", stride)
self.register_buffer("stride", torch.tensor(stride), persistent=False)

self.cls_convs = nn.ModuleList()
self.reg_convs = nn.ModuleList()
Expand Down Expand Up @@ -409,7 +410,7 @@ def __init__(self, arch_params):
) # 24

self._shortcuts = nn.ModuleList([CrossModelSkipConnection() for _ in range(len(self._skip_connections_dict.keys()) - 1)])
self.anchors = anchors

self.width_mult = width_mult

def forward(self, intermediate_output):
Expand Down Expand Up @@ -481,6 +482,7 @@ def __init__(self, backbone: Type[nn.Module], arch_params: HpmStruct, initialize
self._image_processor: Optional[Processing] = None
self._default_nms_iou: Optional[float] = None
self._default_nms_conf: Optional[float] = None
self.register_buffer("strides", torch.tensor(self.arch_params.anchors.stride), persistent=False)

@staticmethod
def get_post_prediction_callback(conf: float, iou: float) -> DetectionPostPredictionCallback:
Expand Down Expand Up @@ -617,8 +619,6 @@ def _check_strides(self):
if not torch.equal(m.stride, stride):
raise RuntimeError("Provided anchor strides do not match the model strides")

self.register_buffer("stride", m.stride) # USED ONLY FOR CONVERSION

def _initialize_biases(self):
"""initialize biases into DetectX()"""
detect_module = self._head._modules_list[-1] # DetectX() module
Expand Down Expand Up @@ -650,7 +650,7 @@ def prep_model_for_conversion(self, input_size: Union[tuple, list] = None, **kwa
assert not self.training, "model has to be in eval mode to be converted"

# Verify dummy_input from converter is of multiple of the grid size
max_stride = int(max(self.stride))
max_stride = int(max(self.strides))

# Validate the image size
image_dims = input_size[-2:] # assume torch uses channels first layout
Expand Down
1,310 changes: 1,278 additions & 32 deletions src/super_gradients/training/utils/checkpoint_utils.py

Large diffs are not rendered by default.

18 changes: 9 additions & 9 deletions src/super_gradients/training/utils/detection_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,7 +551,7 @@ def visualize_batch(
return out_images


class Anchors(nn.Module):
class Anchors:
"""
A wrapper function to hold the anchors used by detection models such as Yolo
"""
Expand All @@ -568,15 +568,15 @@ def __init__(self, anchors_list: List[List], strides: List[int]):
super().__init__()

self.__anchors_list = anchors_list
self.__strides = strides
self.__strides = tuple(strides)

self._check_all_lists(anchors_list)
self._check_all_len_equal_and_even(anchors_list)

self._stride = nn.Parameter(torch.Tensor(strides).float(), requires_grad=False)
anchors = torch.Tensor(anchors_list).float().view(len(anchors_list), -1, 2)
self._anchors = nn.Parameter(anchors / self._stride.view(-1, 1, 1), requires_grad=False)
self._anchor_grid = nn.Parameter(anchors.clone().view(len(anchors_list), 1, -1, 1, 1, 2), requires_grad=False)
self._stride = np.array(strides, dtype=np.float32)
anchors = np.array(anchors_list, dtype=np.float32).reshape((len(anchors_list), -1, 2))
self._anchors = anchors / self._stride.reshape((-1, 1, 1))
self._anchor_grid = anchors.copy().reshape(len(anchors_list), 1, -1, 1, 1, 2)

@staticmethod
def _check_all_lists(anchors: list) -> bool:
Expand All @@ -592,15 +592,15 @@ def _check_all_len_equal_and_even(anchors: list) -> bool:
raise RuntimeError("All objects of anchors_list must be of the same even length")

@property
def stride(self) -> nn.Parameter:
def stride(self) -> np.ndarray:
return self._stride

@property
def anchors(self) -> nn.Parameter:
def anchors(self) -> np.ndarray:
return self._anchors

@property
def anchor_grid(self) -> nn.Parameter:
def anchor_grid(self) -> np.ndarray:
return self._anchor_grid

@property
Expand Down
5 changes: 4 additions & 1 deletion tests/end_to_end_tests/trainer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,10 @@ def tearDownClass(cls) -> None:
for experiment_name in cls.experiment_names:
experiment_dir = get_checkpoints_dir_path(experiment_name=experiment_name)
if os.path.isdir(experiment_dir):
shutil.rmtree(experiment_dir)
# TODO: Occasionally this method fails because log files are still open (See setup_logging() call).
# TODO: Need to find a way to close them at the end of training, this is however tricky to achieve
# TODO: because setup_logging() called outside of Trainer class.
shutil.rmtree(experiment_dir, ignore_errors=True)

@staticmethod
def get_classification_trainer(name=""):
Expand Down
2 changes: 1 addition & 1 deletion tests/unit_tests/export_detection_model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,7 @@ def manual_test_export_export_all_variants(self):
# pass

for model_type in [
# Models.YOLOX_S don't have full support for YOLOX so it's commented out,
Models.YOLOX_S,
Models.PP_YOLOE_S,
Models.YOLO_NAS_S,
]:
Expand Down
30 changes: 30 additions & 0 deletions tests/unit_tests/yolox_unit_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import os
import tempfile
import unittest

import torch

from super_gradients.common.object_names import Models
from super_gradients.training import models
from super_gradients.training.losses import YoloXDetectionLoss, YoloXFastDetectionLoss
from super_gradients.training.models.detection_models.yolox import YoloX_N, YoloX_T, YoloX_S, YoloX_M, YoloX_L, YoloX_X
from super_gradients.training.utils.detection_utils import DetectionCollateFN
Expand Down Expand Up @@ -69,6 +73,32 @@ def test_yolox_loss(self):
result = loss(predictions, targets.to(device))
print(result)

def test_yolo_x_checkpoint_solver(self):
"""
This test checks whether we can:
1. load an old pretrained weights for YoloX that has non-matching keys (Using custom solver under the hood).
2. load a regular checkpoint (As if one would train a model from scratch).
3. that both models produce the same output.
:return:
"""
model_variant = [Models.YOLOX_S, Models.YOLOX_M, Models.YOLOX_L, Models.YOLOX_T, Models.YOLOX_N]
for model_name in model_variant:
model = models.get(model_name, pretrained_weights="coco").eval()
input = torch.randn((1, 3, 320, 320))

output1 = model(input)

sd = model.state_dict()

with tempfile.TemporaryDirectory() as tmp_dirname:
path = os.path.join(tmp_dirname, f"{model_name}_coco.pth")
torch.save({"net": sd}, path)
model = models.get(model_name, num_classes=80, checkpoint_path=path).eval()
output2 = model(input)

assert torch.allclose(output1[0], output2[0], atol=1e-4)


if __name__ == "__main__":
unittest.main()

0 comments on commit 6a1f1bc

Please sign in to comment.