From f9fa5097b6f4fd41b1510eef6497de43d672c89d Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenya Date: Wed, 24 Jan 2024 14:20:54 +0200 Subject: [PATCH 1/5] Deprecate cache and cache_dir parameters support --- .../coco_detection_dataset_params.yaml | 4 - ...coco_detection_ppyoloe_dataset_params.yaml | 4 - ..._ssd_lite_mobilenet_v2_dataset_params.yaml | 4 - ...ction_yolo_format_base_dataset_params.yaml | 4 - ...oco_detection_yolo_nas_dataset_params.yaml | 4 - .../pascal_voc_detection_dataset_params.yaml | 4 - .../roboflow_detection_dataset_params.yaml | 4 - .../detection_datasets/detection_dataset.py | 110 +++--------------- .../pascal_voc_detection.py | 4 +- tests/deci_core_unit_test_suite_runner.py | 2 - tests/unit_tests/detection_caching.py | 109 ----------------- 11 files changed, 21 insertions(+), 232 deletions(-) delete mode 100644 tests/unit_tests/detection_caching.py diff --git a/src/super_gradients/recipes/dataset_params/coco_detection_dataset_params.yaml b/src/super_gradients/recipes/dataset_params/coco_detection_dataset_params.yaml index 1a6ed2ec46..4c1b74073a 100644 --- a/src/super_gradients/recipes/dataset_params/coco_detection_dataset_params.yaml +++ b/src/super_gradients/recipes/dataset_params/coco_detection_dataset_params.yaml @@ -3,8 +3,6 @@ train_dataset_params: subdir: images/train2017 # sub directory path of data_dir containing the train data. json_file: instances_train2017.json # path to coco train json file, data_dir/annotations/train_json_file. input_dim: [640, 640] - cache_dir: - cache: False cache_annotations: True ignore_empty_annotations: True transforms: @@ -59,8 +57,6 @@ val_dataset_params: subdir: images/val2017 # sub directory path of data_dir containing the train data. json_file: instances_val2017.json # path to coco train json file, data_dir/annotations/train_json_file. input_dim: [640, 640] - cache_dir: - cache: False cache_annotations: True ignore_empty_annotations: True transforms: diff --git a/src/super_gradients/recipes/dataset_params/coco_detection_ppyoloe_dataset_params.yaml b/src/super_gradients/recipes/dataset_params/coco_detection_ppyoloe_dataset_params.yaml index eff4bc5fd1..c448fcc082 100644 --- a/src/super_gradients/recipes/dataset_params/coco_detection_ppyoloe_dataset_params.yaml +++ b/src/super_gradients/recipes/dataset_params/coco_detection_ppyoloe_dataset_params.yaml @@ -3,8 +3,6 @@ train_dataset_params: subdir: images/train2017 # sub directory path of data_dir containing the train data. json_file: instances_train2017.json # path to coco train json file, data_dir/annotations/train_json_file. input_dim: # None, do not resize dataset on load - cache_dir: - cache: False cache_annotations: True ignore_empty_annotations: True transforms: @@ -70,8 +68,6 @@ val_dataset_params: subdir: images/val2017 # sub directory path of data_dir containing the train data. json_file: instances_val2017.json # path to coco train json file, data_dir/annotations/train_json_file. input_dim: - cache_dir: - cache: False cache_annotations: True ignore_empty_annotations: True transforms: diff --git a/src/super_gradients/recipes/dataset_params/coco_detection_ssd_lite_mobilenet_v2_dataset_params.yaml b/src/super_gradients/recipes/dataset_params/coco_detection_ssd_lite_mobilenet_v2_dataset_params.yaml index fd6985cd90..0d6770d6ec 100644 --- a/src/super_gradients/recipes/dataset_params/coco_detection_ssd_lite_mobilenet_v2_dataset_params.yaml +++ b/src/super_gradients/recipes/dataset_params/coco_detection_ssd_lite_mobilenet_v2_dataset_params.yaml @@ -6,8 +6,6 @@ train_dataset_params: subdir: images/train2017 # sub directory path of data_dir containing the train data. json_file: instances_train2017.json # path to coco train json file, data_dir/annotations/train_json_file. input_dim: [320, 320] - cache_dir: - cache: False cache_annotations: True ignore_empty_annotations: True transforms: @@ -55,8 +53,6 @@ val_dataset_params: subdir: images/val2017 # sub directory path of data_dir containing the train data. json_file: instances_val2017.json # path to coco train json file, data_dir/annotations/train_json_file. input_dim: [320, 320] - cache_dir: - cache: False cache_annotations: True ignore_empty_annotations: True transforms: diff --git a/src/super_gradients/recipes/dataset_params/coco_detection_yolo_format_base_dataset_params.yaml b/src/super_gradients/recipes/dataset_params/coco_detection_yolo_format_base_dataset_params.yaml index 1a0e48c84d..a2c292730b 100644 --- a/src/super_gradients/recipes/dataset_params/coco_detection_yolo_format_base_dataset_params.yaml +++ b/src/super_gradients/recipes/dataset_params/coco_detection_yolo_format_base_dataset_params.yaml @@ -11,8 +11,6 @@ train_dataset_params: keyboard, cell phone, microwave, oven, toaster, sink, refrigerator, book, clock, vase, scissors, teddy bear, hair drier, toothbrush] # TO FILL: List of classes used in your dataset. input_dim: [640, 640] - cache_dir: - cache: False cache_annotations: True ignore_empty_annotations: True transforms: @@ -69,8 +67,6 @@ val_dataset_params: keyboard, cell phone, microwave, oven, toaster, sink, refrigerator, book, clock, vase, scissors, teddy bear, hair drier, toothbrush] # TO FILL: List of classes used in your dataset. input_dim: [640, 640] - cache_dir: - cache: False cache_annotations: True ignore_empty_annotations: True transforms: diff --git a/src/super_gradients/recipes/dataset_params/coco_detection_yolo_nas_dataset_params.yaml b/src/super_gradients/recipes/dataset_params/coco_detection_yolo_nas_dataset_params.yaml index 128196bd5e..61e5d3d08d 100644 --- a/src/super_gradients/recipes/dataset_params/coco_detection_yolo_nas_dataset_params.yaml +++ b/src/super_gradients/recipes/dataset_params/coco_detection_yolo_nas_dataset_params.yaml @@ -86,8 +86,6 @@ train_dataset_params: subdir: images/train2017 # sub directory path of data_dir containing the train data. json_file: instances_train2017.json # path to coco train json file, data_dir/annotations/train_json_file. input_dim: [640, 640] - cache_dir: - cache: False cache_annotations: True ignore_empty_annotations: True transforms: @@ -141,8 +139,6 @@ val_dataset_params: subdir: images/val2017 # sub directory path of data_dir containing the train data. json_file: instances_val2017.json # path to coco train json file, data_dir/annotations/train_json_file. input_dim: [636, 636] - cache_dir: - cache: False cache_annotations: True ignore_empty_annotations: True transforms: diff --git a/src/super_gradients/recipes/dataset_params/pascal_voc_detection_dataset_params.yaml b/src/super_gradients/recipes/dataset_params/pascal_voc_detection_dataset_params.yaml index 4772c44700..f9dcaad098 100644 --- a/src/super_gradients/recipes/dataset_params/pascal_voc_detection_dataset_params.yaml +++ b/src/super_gradients/recipes/dataset_params/pascal_voc_detection_dataset_params.yaml @@ -1,8 +1,6 @@ train_dataset_params: data_dir: ./data/pascal_voc/ input_dim: [320, 320] - cache: False - cache_dir: transforms: - DetectionPaddedRescale: input_dim: ${dataset_params.train_dataset_params.input_dim} @@ -16,8 +14,6 @@ train_dataset_params: val_dataset_params: data_dir: ./data/pascal_voc/ input_dim: [320, 320] - cache: False - cache_dir: transforms: - DetectionPaddedRescale: input_dim: ${dataset_params.train_dataset_params.input_dim} diff --git a/src/super_gradients/recipes/dataset_params/roboflow_detection_dataset_params.yaml b/src/super_gradients/recipes/dataset_params/roboflow_detection_dataset_params.yaml index 9f0e6f72db..fd819b649a 100644 --- a/src/super_gradients/recipes/dataset_params/roboflow_detection_dataset_params.yaml +++ b/src/super_gradients/recipes/dataset_params/roboflow_detection_dataset_params.yaml @@ -7,8 +7,6 @@ train_dataset_params: dataset_name: ${..dataset_name} split: train input_dim: [640, 640] - cache_dir: - cache: False cache_annotations: True ignore_empty_annotations: False transforms: @@ -68,8 +66,6 @@ val_dataset_params: dataset_name: ${..dataset_name} split: valid input_dim: [640, 640] - cache_dir: - cache: False cache_annotations: True ignore_empty_annotations: False transforms: diff --git a/src/super_gradients/training/datasets/detection_datasets/detection_dataset.py b/src/super_gradients/training/datasets/detection_datasets/detection_dataset.py index 2281d5cb02..d6a6d3f7c7 100644 --- a/src/super_gradients/training/datasets/detection_datasets/detection_dataset.py +++ b/src/super_gradients/training/datasets/detection_datasets/detection_dataset.py @@ -1,9 +1,8 @@ import collections -import hashlib import os import random +import warnings from copy import deepcopy -from multiprocessing.pool import ThreadPool from pathlib import Path from typing import List, Dict, Union, Any, Optional, Tuple @@ -50,7 +49,6 @@ class DetectionDataset(Dataset, HasPreprocessingParams): WORKFLOW: - On instantiation: - All annotations are cached. If class_inclusion_list was specified, there is also subclassing at this step. - - If cache is True, the images are also cached - On call (__getitem__) for a specific image index: - The image and annotations are grouped together in a dict called SAMPLE @@ -75,9 +73,7 @@ def __init__( data_dir: str, original_target_format: Union[ConcatenatedTensorFormat, DetectionTargetsFormat], max_num_samples: int = None, - cache: bool = False, cache_annotations: bool = True, - cache_dir: str = None, input_dim: Union[int, Tuple[int, int], None] = None, transforms: List[AbstractDetectionTransform] = [], all_classes_list: Optional[List[str]] = [], @@ -87,6 +83,8 @@ def __init__( output_fields: List[str] = None, verbose: bool = True, show_all_warnings: bool = False, + cache=None, + cache_dir=None, ): """Detection dataset. @@ -98,10 +96,8 @@ def __init__( :param original_target_format: Format of targets stored on disk. raw data format, the output format might differ based on transforms. :param max_num_samples: If not None, set the maximum size of the dataset by only indexing the first n annotations/images. - :param cache: Whether to cache images or not. :param cache_annotations: Whether to cache annotations or not. This reduces training time by pre-loading all the annotations, but requires more RAM and more time to instantiate the dataset when working on very large datasets. - :param cache_dir: Path to the directory where cached images will be stored in an optimized format. :param transforms: List of transforms to apply sequentially on sample. :param all_classes_list: All the class names. :param class_inclusion_list: If not None, define the subset of classes to be included as targets. @@ -116,7 +112,22 @@ def __init__( It has to include at least "image" and "target" but can include other. :param verbose: Whether to show additional information or not, such as loading progress. (doesnt include warnings) :param show_all_warnings: Whether to show all warnings or not. + :param cache: Deprecated. This parameter is not used and setting it has no effect. It will be removed in 3.8 + :param cache_dir: Deprecated. This parameter is not used and setting it has no effect. It will be removed in 3.8 """ + if cache is not None: + warnings.warn( + "cache parameter has been marked as deprecated and setting it has no effect. " + "It will be removed in SuperGradients 3.8. Please remove this parameter when instantiating a dataset instance", + DeprecationWarning, + ) + if cache_dir is not None: + warnings.warn( + "cache_dir parameter has been marked as deprecated and setting it has no effect. " + "It will be removed in SuperGradients 3.8. Please remove this parameter when instantiating a dataset instance", + DeprecationWarning, + ) + super().__init__() self.verbose = verbose self.show_all_warnings = show_all_warnings @@ -208,11 +219,6 @@ def __init__( self._n_samples = n_samples # Regardless of any filtering - # CACHE IMAGE - self.cache = cache - self.cache_dir = cache_dir - self.cached_imgs_padded = self._cache_images() if self.cache else None - @property def _all_classes(self): """Placeholder to setup the class names. This is an alternative to passing "all_classes_list" to __init__. @@ -327,64 +333,6 @@ def _sub_class_target(self, targets: np.ndarray, class_index: int) -> np.ndarray return np.array(targets_kept) if len(targets_kept) > 0 else np.zeros((0, 5), dtype=np.float32) - def _cache_images(self) -> np.ndarray: - """Cache the images. The cached image are stored in a file to be loaded faster mext time. - :return: Cached images - """ - cache_dir = Path(self.cache_dir) - if cache_dir is None: - raise ValueError("You must specify a cache_dir if you want to cache your images." "If you did not mean to use cache, please set cache=False ") - cache_dir.mkdir(parents=True, exist_ok=True) - - logger.warning( - "\n********************************************************************************\n" - "You are using cached images in RAM to accelerate training.\n" - "This requires large system RAM.\n" - "********************************************************************************" - ) - - if self.input_dim is None: - raise RuntimeError("caching is not possible without input_dim is not set") - max_h, max_w = self.input_dim[0], self.input_dim[1] - - # The cache should be the same as long as the images and their sizes are the same - hash = hashlib.sha256() - for index in range(len(self)): - annotation = self._get_sample_annotations(index=index, ignore_empty_annotations=self.ignore_empty_annotations) - values_to_hash = [annotation["resized_img_shape"][0], annotation["resized_img_shape"][1], Path(annotation["img_path"]).name] - for value in values_to_hash: - hash.update(str(value).encode("utf-8")) - cache_hash = hash.hexdigest() - - img_resized_cache_path = cache_dir / f"img_resized_cache_{cache_hash}.array" - - if not img_resized_cache_path.exists(): - logger.info("Caching images for the first time. Be aware that this will stay in the disk until you delete it yourself.") - NUM_THREADs = min(8, os.cpu_count()) - - # Inline-function because we should not to pollute the rest of the class with this function. - # This function is required because of legacy design - ideally we should not have to load annotations in order to get the image path. - def _load_image_from_index(index: int) -> np.ndarray: - annotations = self._get_sample_annotations(index=index, ignore_empty_annotations=self.ignore_empty_annotations) - return self._load_resized_img(image_path=annotations["img_path"]) - - loaded_images = ThreadPool(NUM_THREADs).imap(func=_load_image_from_index, iterable=range(len(self))) - - # Initialize placeholder for images - cached_imgs = np.memmap(str(img_resized_cache_path), shape=(len(self), max_h, max_w, 3), dtype=np.uint8, mode="w+") - - # Store images in the placeholder - with tqdm(enumerate(loaded_images), total=len(self), disable=not self.verbose) as loaded_images_pbar: - for i, image in loaded_images_pbar: - cached_imgs[i][: image.shape[0], : image.shape[1], :] = image.copy() - cached_imgs.flush() - else: - logger.warning("You are using cached imgs!") - - logger.info("Loading cached imgs...") - cached_imgs = np.memmap(str(img_resized_cache_path), shape=(len(self), max_h, max_w, 3), dtype=np.uint8, mode="r+") - return cached_imgs - def _load_resized_img(self, image_path: str) -> np.ndarray: """Load an image and resize it to the desired size (If relevant). :param image_path: Full path of the image @@ -411,11 +359,6 @@ def _load_image(self, image_path: str) -> np.ndarray: raise FileNotFoundError(f"{img_file} was no found. Please make sure that the dataset was" f"downloaded and that the path is correct") return img - def __del__(self): - """Clear the cached images""" - if hasattr(self, "cached_imgs_padded"): - del self.cached_imgs_padded - def __len__(self) -> int: """Get the length of the dataset. Note that this is the number of samples AFTER filtering (if relevant).""" return len(self._non_empty_sample_ids) if self.ignore_empty_annotations else self._n_samples @@ -444,23 +387,9 @@ def get_sample(self, index: int, ignore_empty_annotations: bool = False) -> Dict :return: Sample, i.e. a dictionary including at least "image" and "target" """ sample_annotations = self._get_sample_annotations(index=index, ignore_empty_annotations=ignore_empty_annotations) - if self.cache: - image = self._get_cached_image(index=index, cached_image_shape=sample_annotations["resized_img_shape"]) - else: - image = self._load_resized_img(image_path=sample_annotations["img_path"]) + image = self._load_resized_img(image_path=sample_annotations["img_path"]) return {"image": image, **deepcopy(sample_annotations)} - def _get_cached_image(self, index: int, cached_image_shape: Tuple[int, int]) -> np.ndarray: - """Load an image from cache. - :param index: Index refers to the index of the sample in the dataset, AFTER filtering (if relevant). 0<=index<=len(dataset)-1 - :param cached_image_shape: Shape of the cached image (after resizing if input_dim is set) - :return: Image - """ - padded_image = self.cached_imgs_padded[index] - cached_height, cached_width = cached_image_shape - resized_image = padded_image[:cached_height, :cached_width, :] - return resized_image.copy() - def apply_transforms(self, sample: Dict[str, Union[np.ndarray, Any]]) -> Dict[str, Union[np.ndarray, Any]]: """ Applies self.transforms sequentially to sample @@ -543,7 +472,6 @@ def plot( ) for plot_i in range(n_plots): - fig = plt.figure(figsize=(10, 10)) n_subplot = int(np.ceil(max_samples_per_plot**0.5)) diff --git a/src/super_gradients/training/datasets/detection_datasets/pascal_voc_detection.py b/src/super_gradients/training/datasets/detection_datasets/pascal_voc_detection.py index d4cb8709fb..003796137b 100755 --- a/src/super_gradients/training/datasets/detection_datasets/pascal_voc_detection.py +++ b/src/super_gradients/training/datasets/detection_datasets/pascal_voc_detection.py @@ -234,8 +234,8 @@ def __init__( self, data_dir: str, input_dim: tuple, - cache: bool = False, - cache_dir: str = None, + cache=None, + cache_dir=None, transforms: List[AbstractDetectionTransform] = [], class_inclusion_list: Optional[List[str]] = None, max_num_samples: int = None, diff --git a/tests/deci_core_unit_test_suite_runner.py b/tests/deci_core_unit_test_suite_runner.py index 2ac62152a5..fbb269058f 100644 --- a/tests/deci_core_unit_test_suite_runner.py +++ b/tests/deci_core_unit_test_suite_runner.py @@ -77,7 +77,6 @@ from tests.unit_tests.detection_sub_sampling_test import TestDetectionDatasetSubsampling from tests.unit_tests.detection_sub_classing_test import TestDetectionDatasetSubclassing from tests.unit_tests.detection_output_adapter_test import TestDetectionOutputAdapter -from tests.unit_tests.detection_caching import TestDetectionDatasetCaching from tests.unit_tests.multi_scaling_test import MultiScaleTest from tests.unit_tests.ppyoloe_unit_test import TestPPYOLOE from tests.unit_tests.bbox_formats_test import BBoxFormatsTest @@ -135,7 +134,6 @@ def _add_modules_to_unit_tests_suite(self): self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestDetectionDatasetSubsampling)) self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestDetectionDatasetSubclassing)) self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(QuantizationUtilityTest)) - self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestDetectionDatasetCaching)) self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(MultiScaleTest)) self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TrainingParamsTest)) self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(CallTrainTwiceTest)) diff --git a/tests/unit_tests/detection_caching.py b/tests/unit_tests/detection_caching.py deleted file mode 100644 index ece5266777..0000000000 --- a/tests/unit_tests/detection_caching.py +++ /dev/null @@ -1,109 +0,0 @@ -import unittest -import numpy as np -from pathlib import Path -import tempfile -import os - -from super_gradients.training.datasets import DetectionDataset -from super_gradients.training.datasets.data_formats.default_formats import XYXY_LABEL - - -class DummyDetectionDataset(DetectionDataset): - def __init__(self, input_dim, *args, **kwargs): - """Dummy Dataset testing subclassing, designed with no annotation that includes class_2.""" - - self.image_size = input_dim - self.n_samples = 321 - kwargs["all_classes_list"] = ["class_0", "class_1", "class_2"] - kwargs["original_target_format"] = XYXY_LABEL - super().__init__(input_dim=input_dim, *args, **kwargs) - - def _setup_data_source(self): - return self.n_samples - - def _load_annotation(self, sample_id: int) -> dict: - """Every image is made of one target, with label sample_id % len(all_classes_list) and - a seed to allow the random image to the same for a given sample_id - """ - cls_id = sample_id % len(self.all_classes_list) - return {"img_path": str(sample_id), "target": np.array([[0, 0, 10, 10, cls_id]]), "resized_img_shape": self.image_size, "seed": sample_id} - - # We overwrite this to fake images - def _load_image(self, image_path: str) -> np.ndarray: - np.random.seed(int(image_path)) - return np.random.random((self.image_size[0], self.image_size[1], 3)) * 255 - - -class TestDetectionDatasetCaching(unittest.TestCase): - def setUp(self) -> None: - self.temp_cache_dir = tempfile.TemporaryDirectory(prefix="cache").name - if not os.path.isdir(self.temp_cache_dir): - os.mkdir(self.temp_cache_dir) - - def _count_cached_array(self): - return len(list(Path(self.temp_cache_dir).glob("*.array"))) - - def _empty_cache(self): - for cache_file in Path(self.temp_cache_dir).glob("*.array"): - cache_file.unlink() - - def test_cache_keep_empty(self): - self._empty_cache() - - datasets = [ - DummyDetectionDataset( - input_dim=(640, 512), - ignore_empty_annotations=False, - class_inclusion_list=class_inclusion_list, - cache=True, - cache_dir=self.temp_cache_dir, - data_dir="/home/", - ) - for class_inclusion_list in [["class_0", "class_1", "class_2"], ["class_0"], ["class_1"], ["class_2"], ["class_1", "class_2"]] - ] - - self.assertEqual(1, self._count_cached_array()) - for first_dataset, second_dataset in zip(datasets[:-1], datasets[1:]): - self.assertTrue(np.array_equal(first_dataset.cached_imgs_padded, second_dataset.cached_imgs_padded)) - - self._empty_cache() - - def test_cache_ignore_empty(self): - self._empty_cache() - - datasets = [ - DummyDetectionDataset( - input_dim=(640, 512), - ignore_empty_annotations=True, - class_inclusion_list=class_inclusion_list, - cache=True, - cache_dir=self.temp_cache_dir, - data_dir="/home/", - ) - for class_inclusion_list in [["class_0", "class_1", "class_2"], ["class_0"], ["class_1"], ["class_2"], ["class_1", "class_2"]] - ] - - self.assertEqual(5, self._count_cached_array()) - for first_dataset, second_dataset in zip(datasets[:-1], datasets[1:]): - self.assertFalse(np.array_equal(first_dataset.cached_imgs_padded, second_dataset.cached_imgs_padded)) - - self._empty_cache() - - def test_cache_saved(self): - """Check that after the first time a dataset is called with specific params, - the next time it will call the saved array instead of building it.""" - self._empty_cache() - self.assertEqual(0, self._count_cached_array()) - - _ = DummyDetectionDataset(input_dim=(640, 512), ignore_empty_annotations=True, cache=True, cache_dir=self.temp_cache_dir, data_dir="/home/") - self.assertEqual(1, self._count_cached_array()) - - for _ in range(5): - _ = DummyDetectionDataset(input_dim=(640, 512), ignore_empty_annotations=True, cache=True, cache_dir=self.temp_cache_dir, data_dir="/home/") - self.assertEqual(1, self._count_cached_array()) - - self._empty_cache() - - -if __name__ == "__main__": - unittest.main() From 5a904e048c487795fdc0bbf30cb8809cdbed3a0d Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenya Date: Wed, 24 Jan 2024 16:11:48 +0200 Subject: [PATCH 2/5] Deprecate tight box rotation support for COCO dataset --- .../coco_detection_dataset_params.yaml | 2 -- ...coco_detection_ppyoloe_dataset_params.yaml | 2 -- ..._ssd_lite_mobilenet_v2_dataset_params.yaml | 2 -- ...oco_detection_yolo_nas_dataset_params.yaml | 2 -- .../roboflow_detection_dataset_params.yaml | 2 -- .../coco_format_detection.py | 27 ++++++------------- 6 files changed, 8 insertions(+), 29 deletions(-) diff --git a/src/super_gradients/recipes/dataset_params/coco_detection_dataset_params.yaml b/src/super_gradients/recipes/dataset_params/coco_detection_dataset_params.yaml index 4c1b74073a..b2b7911d8d 100644 --- a/src/super_gradients/recipes/dataset_params/coco_detection_dataset_params.yaml +++ b/src/super_gradients/recipes/dataset_params/coco_detection_dataset_params.yaml @@ -36,7 +36,6 @@ train_dataset_params: - DetectionTargetsFormatTransform: input_dim: ${dataset_params.train_dataset_params.input_dim} output_format: LABEL_CXCYWH - tight_box_rotation: False class_inclusion_list: max_num_samples: with_crowd: False @@ -65,7 +64,6 @@ val_dataset_params: - DetectionTargetsFormatTransform: input_dim: ${dataset_params.val_dataset_params.input_dim} output_format: LABEL_CXCYWH - tight_box_rotation: False class_inclusion_list: max_num_samples: with_crowd: True diff --git a/src/super_gradients/recipes/dataset_params/coco_detection_ppyoloe_dataset_params.yaml b/src/super_gradients/recipes/dataset_params/coco_detection_ppyoloe_dataset_params.yaml index c448fcc082..5ba3e73033 100644 --- a/src/super_gradients/recipes/dataset_params/coco_detection_ppyoloe_dataset_params.yaml +++ b/src/super_gradients/recipes/dataset_params/coco_detection_ppyoloe_dataset_params.yaml @@ -38,7 +38,6 @@ train_dataset_params: - DetectionTargetsFormatTransform: output_format: LABEL_CXCYWH - tight_box_rotation: False class_inclusion_list: max_num_samples: with_crowd: False @@ -78,7 +77,6 @@ val_dataset_params: std: [ 58.395, 57.12, 57.375 ] - DetectionTargetsFormatTransform: output_format: LABEL_CXCYWH - tight_box_rotation: False class_inclusion_list: max_num_samples: with_crowd: True diff --git a/src/super_gradients/recipes/dataset_params/coco_detection_ssd_lite_mobilenet_v2_dataset_params.yaml b/src/super_gradients/recipes/dataset_params/coco_detection_ssd_lite_mobilenet_v2_dataset_params.yaml index 0d6770d6ec..d6b51e977f 100644 --- a/src/super_gradients/recipes/dataset_params/coco_detection_ssd_lite_mobilenet_v2_dataset_params.yaml +++ b/src/super_gradients/recipes/dataset_params/coco_detection_ssd_lite_mobilenet_v2_dataset_params.yaml @@ -32,7 +32,6 @@ train_dataset_params: input_dim: ${dataset_params.train_dataset_params.input_dim} output_format: LABEL_NORMALIZED_CXCYWH - tight_box_rotation: False class_inclusion_list: max_num_samples: with_crowd: False @@ -61,7 +60,6 @@ val_dataset_params: - DetectionTargetsFormatTransform: input_dim: ${dataset_params.val_dataset_params.input_dim} output_format: LABEL_NORMALIZED_CXCYWH - tight_box_rotation: False class_inclusion_list: max_num_samples: with_crowd: True diff --git a/src/super_gradients/recipes/dataset_params/coco_detection_yolo_nas_dataset_params.yaml b/src/super_gradients/recipes/dataset_params/coco_detection_yolo_nas_dataset_params.yaml index 61e5d3d08d..240029f7c9 100644 --- a/src/super_gradients/recipes/dataset_params/coco_detection_yolo_nas_dataset_params.yaml +++ b/src/super_gradients/recipes/dataset_params/coco_detection_yolo_nas_dataset_params.yaml @@ -121,7 +121,6 @@ train_dataset_params: - DetectionTargetsFormatTransform: output_format: LABEL_CXCYWH - tight_box_rotation: False class_inclusion_list: max_num_samples: with_crowd: False @@ -153,7 +152,6 @@ val_dataset_params: - DetectionTargetsFormatTransform: input_dim: [640, 640] output_format: LABEL_CXCYWH - tight_box_rotation: False class_inclusion_list: max_num_samples: with_crowd: True diff --git a/src/super_gradients/recipes/dataset_params/roboflow_detection_dataset_params.yaml b/src/super_gradients/recipes/dataset_params/roboflow_detection_dataset_params.yaml index fd819b649a..8ede1af1d9 100644 --- a/src/super_gradients/recipes/dataset_params/roboflow_detection_dataset_params.yaml +++ b/src/super_gradients/recipes/dataset_params/roboflow_detection_dataset_params.yaml @@ -43,7 +43,6 @@ train_dataset_params: - DetectionTargetsFormatTransform: input_dim: ${dataset_params.train_dataset_params.input_dim} output_format: LABEL_CXCYWH - tight_box_rotation: False class_inclusion_list: max_num_samples: with_crowd: False @@ -77,7 +76,6 @@ val_dataset_params: - DetectionTargetsFormatTransform: input_dim: ${dataset_params.val_dataset_params.input_dim} output_format: LABEL_CXCYWH - tight_box_rotation: False class_inclusion_list: max_num_samples: with_crowd: True diff --git a/src/super_gradients/training/datasets/detection_datasets/coco_format_detection.py b/src/super_gradients/training/datasets/detection_datasets/coco_format_detection.py index 9e90b23d56..455d5b69fe 100644 --- a/src/super_gradients/training/datasets/detection_datasets/coco_format_detection.py +++ b/src/super_gradients/training/datasets/detection_datasets/coco_format_detection.py @@ -1,7 +1,6 @@ import copy import os -import cv2 import numpy as np from pycocotools.coco import COCO from typing import List, Optional @@ -28,9 +27,9 @@ def __init__( data_dir: str, json_annotation_file: str, images_dir: str, - tight_box_rotation: bool = False, with_crowd: bool = True, class_ids_to_ignore: Optional[List[int]] = None, + tight_box_rotation=None, *args, **kwargs, ): @@ -38,14 +37,16 @@ def __init__( :param data_dir: Where the data is stored. :param json_annotation_file: Name of the coco json file. Path relative to data_dir. :param images_dir: Name of the directory that includes all the images. Path relative to data_dir. - :param tight_box_rotation: bool, whether to use of segmentation maps convex hull as target_seg - (check get_sample docs). :param with_crowd: Add the crowd groundtruths to __getitem__ :param class_ids_to_ignore: List of class ids to ignore in the dataset. By default, doesnt ignore any class. + :param tight_box_rotation: This parameter is deprecated and will be removed in a SuperGradients 3.8. """ + if tight_box_rotation is not None: + logger.warning( + "Parameter `tight_box_rotation` is deprecated and will be removed in a SuperGradients 3.8." "Please remove this parameter from your code." + ) self.images_dir = images_dir self.json_annotation_file = json_annotation_file - self.tight_box_rotation = tight_box_rotation self.with_crowd = with_crowd self.class_ids_to_ignore = class_ids_to_ignore or [] @@ -95,7 +96,7 @@ def _init_coco(self) -> COCO: else: coco = COCO(annotation_file_path) - remove_useless_info(coco, self.tight_box_rotation) + remove_useless_info(coco, False) return coco def _load_annotation(self, sample_id: int) -> dict: @@ -133,21 +134,11 @@ def _load_annotation(self, sample_id: int) -> dict: non_crowd_annotations = [annotation for annotation in cleaned_annotations if annotation["iscrowd"] == 0] target = np.zeros((len(non_crowd_annotations), 5)) - num_seg_values = 98 if self.tight_box_rotation else 0 - target_segmentation = np.ones((len(non_crowd_annotations), num_seg_values)) - target_segmentation.fill(np.nan) + for ix, annotation in enumerate(non_crowd_annotations): cls = self.class_ids.index(annotation["category_id"]) target[ix, 0:4] = annotation["clean_bbox"] target[ix, 4] = cls - if self.tight_box_rotation: - seg_points = [j for i in annotation.get("segmentation", []) for j in i] - if seg_points: - seg_points_c = np.array(seg_points).reshape((-1, 2)).astype(np.int32) - seg_points_convex = cv2.convexHull(seg_points_c).ravel() - else: - seg_points_convex = [] - target_segmentation[ix, : len(seg_points_convex)] = seg_points_convex crowd_annotations = [annotation for annotation in cleaned_annotations if annotation["iscrowd"] == 1] @@ -163,7 +154,6 @@ def _load_annotation(self, sample_id: int) -> dict: r = min(self.input_dim[0] / height, self.input_dim[1] / width) target[:, :4] *= r crowd_target[:, :4] *= r - target_segmentation *= r resized_img_shape = (int(height * r), int(width * r)) else: resized_img_shape = initial_img_shape @@ -175,7 +165,6 @@ def _load_annotation(self, sample_id: int) -> dict: annotation = { "target": target, "crowd_target": crowd_target, - "target_segmentation": target_segmentation, "initial_img_shape": initial_img_shape, "resized_img_shape": resized_img_shape, "img_path": img_path, From 24cf66ca088ab0f7a5a84bdc8b8d834496e6683a Mon Sep 17 00:00:00 2001 From: Eugene Date: Fri, 26 Jan 2024 10:20:40 +0200 Subject: [PATCH 3/5] deprecated_parameter --- src/super_gradients/common/deprecate.py | 91 +++++++++++++++++++ .../detection_datasets/coco_detection.py | 2 - .../coco_format_detection.py | 11 ++- 3 files changed, 98 insertions(+), 6 deletions(-) diff --git a/src/super_gradients/common/deprecate.py b/src/super_gradients/common/deprecate.py index 65aedf5136..a91178a862 100644 --- a/src/super_gradients/common/deprecate.py +++ b/src/super_gradients/common/deprecate.py @@ -1,8 +1,11 @@ +import inspect import warnings from functools import wraps from typing import Optional, Callable from pkg_resources import parse_version +__all__ = ["deprecated", "deprecated_parameter", "deprecated_training_param", "deprecate_param"] + def deprecated(deprecated_since: str, removed_from: str, target: Optional[callable] = None, reason: str = ""): """ @@ -78,6 +81,94 @@ def wrapper(*args, **kwargs): return decorator +def deprecated_parameter(parameter_name: str, deprecated_since: str, removed_from: str, target: Optional[callable] = None, reason: str = ""): + """ + Decorator to mark a parameter of a callable as deprecated. + It provides a clear and actionable warning message informing + the user about the version in which parameter was deprecated, + the version in which it will be removed, and guidance on how to replace it. + + :param parameter_name: Name of the parameter + :param deprecated_since: Version number when the function was deprecated. + :param removed_from: Version number when the function will be removed. + :param target: (Optional) The new function that should be used as a replacement. If provided, it will guide the user to the updated function. + :param reason: (Optional) Additional information or reason for the deprecation. + + Example usage: + If a parameter removed with no replacement alternative: + >>> @deprecated_parameter("c",deprecated_since='3.2.0', removed_from='4.0.0', reason="This argument is not used") + >>> def do_some_work(a,b,c = None): + >>> return a+b + + If a parameter has new name: + >>> @deprecated_parameter("c", target="new_parameter", deprecated_since='3.2.0', removed_from='4.0.0', reason="This argument is not used") + >>> def do_some_work(a,b,target,c = None): + >>> return a+b+target + + """ + + def decorator(func: callable) -> callable: + argspec = inspect.getfullargspec(func) + argument_index = argspec.args.index(parameter_name) + + default_value = None + sig = inspect.signature(func) + for name, param in sig.parameters.items(): + if name == parameter_name: + default_value = param.default + break + + @wraps(func) + def wrapper(*args, **kwargs): + try: + value = args[argument_index] + except IndexError: + value = kwargs[parameter_name] + + if value != default_value and not wrapper._warned: + import super_gradients + + is_still_supported = parse_version(super_gradients.__version__) < parse_version(removed_from) + status_msg = "is deprecated" if is_still_supported else "was deprecated and has been removed" + message = ( + f"Parameter `{parameter_name}` of `{func.__module__}.{func.__name__}` {status_msg} since version `{deprecated_since}` " + f"and will be removed in version `{removed_from}`.\n" + ) + if reason: + message += f"Reason: {reason}.\n" + + if target is not None: + message += ( + f"Please update your code:\n" + f" [-] from `{func.__name__}(..., {parameter_name}={value})`\n" + f" [+] to `{func.__name__}(..., {target}={value})`\n" + ) + else: + # fmt: off + message += ( + f"Please update your code:\n" + f" [-] from `{func.__name__}(..., {parameter_name}={value})`\n" + f" [+] to `{func.__name__}(...)`\n" + ) + # fmt: on + + if is_still_supported: + warnings.simplefilter("once", DeprecationWarning) # Required, otherwise the warning may never be displayed. + warnings.warn(message, DeprecationWarning, stacklevel=2) + wrapper._warned = True + else: + raise ImportError(message) + + return func(*args, **kwargs) + + # Each decorated object will have its own _warned state + # This state ensures that the warning will appear only once, to avoid polluting the console in case the function is called too often. + wrapper._warned = False + return wrapper + + return decorator + + def deprecated_training_param(deprecated_tparam_name: str, deprecated_since: str, removed_from: str, new_arg_assigner: Callable, message: str = ""): """ Decorator for deprecating training hyperparameters. diff --git a/src/super_gradients/training/datasets/detection_datasets/coco_detection.py b/src/super_gradients/training/datasets/detection_datasets/coco_detection.py index 984f019b70..236fd8e9a7 100644 --- a/src/super_gradients/training/datasets/detection_datasets/coco_detection.py +++ b/src/super_gradients/training/datasets/detection_datasets/coco_detection.py @@ -49,8 +49,6 @@ def __init__( """ :param json_file: Name of the coco json file, that resides in data_dir/annotations/json_file. :param subdir: Sub directory of data_dir containing the data. - :param tight_box_rotation: bool, whether to use of segmentation maps convex hull as target_seg - (check get_sample docs). :param with_crowd: Add the crowd groundtruths to __getitem__ kwargs: diff --git a/src/super_gradients/training/datasets/detection_datasets/coco_format_detection.py b/src/super_gradients/training/datasets/detection_datasets/coco_format_detection.py index 455d5b69fe..82e2a94e63 100644 --- a/src/super_gradients/training/datasets/detection_datasets/coco_format_detection.py +++ b/src/super_gradients/training/datasets/detection_datasets/coco_format_detection.py @@ -7,6 +7,7 @@ from contextlib import redirect_stdout from super_gradients.common.abstractions.abstract_logger import get_logger +from super_gradients.common.deprecate import deprecated_parameter from super_gradients.training.datasets.detection_datasets.detection_dataset import DetectionDataset from super_gradients.common.exceptions.dataset_exceptions import DatasetValidationException, ParameterMismatchException from super_gradients.training.datasets.data_formats.default_formats import XYXY_LABEL @@ -22,6 +23,12 @@ class COCOFormatDetectionDataset(DetectionDataset): Output format: (x, y, x, y, class_id) """ + @deprecated_parameter( + "tight_box_rotation", + deprecated_since="3.7.0", + removed_from="3.8.0", + reason="Support of `tight_box_rotation` has been removed. This parameter has no effect anymore.", + ) def __init__( self, data_dir: str, @@ -41,10 +48,6 @@ def __init__( :param class_ids_to_ignore: List of class ids to ignore in the dataset. By default, doesnt ignore any class. :param tight_box_rotation: This parameter is deprecated and will be removed in a SuperGradients 3.8. """ - if tight_box_rotation is not None: - logger.warning( - "Parameter `tight_box_rotation` is deprecated and will be removed in a SuperGradients 3.8." "Please remove this parameter from your code." - ) self.images_dir = images_dir self.json_annotation_file = json_annotation_file self.with_crowd = with_crowd From 377b4db8026296df17420f037f033d2e14d4ee2b Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenya Date: Fri, 26 Jan 2024 14:33:34 +0200 Subject: [PATCH 4/5] Fix bug in deprecated_parameter --- src/super_gradients/common/deprecate.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/super_gradients/common/deprecate.py b/src/super_gradients/common/deprecate.py index a91178a862..7a8b75faaf 100644 --- a/src/super_gradients/common/deprecate.py +++ b/src/super_gradients/common/deprecate.py @@ -120,10 +120,17 @@ def decorator(func: callable) -> callable: @wraps(func) def wrapper(*args, **kwargs): + + # Initialize the value to the default value + value = default_value + + # Try to get the actual value from the arguments + # Have to check both positional and keyword arguments try: value = args[argument_index] except IndexError: - value = kwargs[parameter_name] + if parameter_name in kwargs: + value = kwargs[parameter_name] if value != default_value and not wrapper._warned: import super_gradients From 11ce40760cc6998693c65f7a3723809b3f1cc4ca Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenya Date: Mon, 5 Feb 2024 15:30:38 +0200 Subject: [PATCH 5/5] Improve test but adding subTest to indicate a tested architecture and use np.testing.assert_array_almost_equal to get more detailed output if test fails --- tests/unit_tests/repvgg_unit_test.py | 49 +++++++++++++++------------- 1 file changed, 26 insertions(+), 23 deletions(-) diff --git a/tests/unit_tests/repvgg_unit_test.py b/tests/unit_tests/repvgg_unit_test.py index 4d4e27a1e6..757da8d9af 100644 --- a/tests/unit_tests/repvgg_unit_test.py +++ b/tests/unit_tests/repvgg_unit_test.py @@ -4,6 +4,7 @@ from super_gradients.training.utils.utils import HpmStruct import torch import copy +import numpy as np class BackboneBasedModel(torch.nn.Module): @@ -50,29 +51,31 @@ def test_deployment_architecture(self): # skip custom constructors to keep all_arch_params as general as a possible if "repvgg" not in arch_name or "custom" in arch_name: continue - model = ARCHITECTURES[arch_name](arch_params=self.all_arch_params) - self.assertTrue(hasattr(model.stem, "branch_3x3")) # check single layer for training mode - self.assertTrue(model.build_residual_branches) - - training_mode_sd = model.state_dict() - for module in training_mode_sd: - self.assertFalse("reparam" in module) # deployment block included in training mode - test_input = torch.ones((1, in_channels, image_size, image_size)) - model.eval() - training_mode_output = model(test_input) - - model.prep_model_for_conversion() - self.assertTrue(hasattr(model.stem, "rbr_reparam")) # check single layer for training mode - self.assertFalse(model.build_residual_branches) - - deployment_mode_sd = model.state_dict() - for module in deployment_mode_sd: - self.assertFalse("running_mean" in module) # BN were not fused - self.assertFalse("branch" in module) # branches were not joined - - deployment_mode_output = model(test_input) - # difference is of very low magnitude - self.assertFalse(False in torch.isclose(training_mode_output, deployment_mode_output, atol=1e-4)) + + with self.subTest(arch_name=arch_name): + model = ARCHITECTURES[arch_name](arch_params=self.all_arch_params) + self.assertTrue(hasattr(model.stem, "branch_3x3")) # check single layer for training mode + self.assertTrue(model.build_residual_branches) + + training_mode_sd = model.state_dict() + for module in training_mode_sd: + self.assertFalse("reparam" in module) # deployment block included in training mode + test_input = torch.ones((1, in_channels, image_size, image_size)) + model.eval() + training_mode_output = model(test_input) + + model.prep_model_for_conversion() + self.assertTrue(hasattr(model.stem, "rbr_reparam")) # check single layer for training mode + self.assertFalse(model.build_residual_branches) + + deployment_mode_sd = model.state_dict() + for module in deployment_mode_sd: + self.assertFalse("running_mean" in module) # BN were not fused + self.assertFalse("branch" in module) # branches were not joined + + deployment_mode_output = model(test_input) + # difference is of very low magnitude + np.testing.assert_array_almost_equal(training_mode_output.detach().numpy(), deployment_mode_output.detach().numpy(), decimal=4) def test_backbone_mode(self): """