Remove __target__ for Detection CollateFN #1470

Merged
merged 10 commits on Oct 3, 2023
6 changes: 2 additions & 4 deletions documentation/source/ObjectDetection.md
@@ -500,8 +500,7 @@ train_dataloader_params:
   worker_init_fn:
     _target_: super_gradients.training.utils.utils.load_func
     dotpath: super_gradients.training.datasets.datasets_utils.worker_init_reset_seed
-  collate_fn:
-    _target_: super_gradients.training.utils.detection_utils.DetectionCollateFN
+  collate_fn: DetectionCollateFN

 val_dataset_params:
   data_dir: ${dataset_params.root_dir}
@@ -521,8 +520,7 @@ val_dataloader_params:
   num_workers: 8
   drop_last: True
   pin_memory: True
-  collate_fn:
-    _target_: super_gradients.training.utils.detection_utils.DetectionCollateFN
+  collate_fn: DetectionCollateFN
 ```

 In your training recipe add/change the following lines to:
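The bare class name in these recipes works because collate functions are registered by name: the new modules later in this diff decorate each class with `@register_collate_function()`. As a hedged sketch, assuming the decorator registers a class under its own class name, a custom collate could presumably be exposed to recipes the same way; `MyDetectionCollateFN` is a hypothetical example, not part of the library:

```python
from super_gradients.common.registry import register_collate_function
from super_gradients.training.utils.collate_fn import DetectionCollateFN


@register_collate_function()  # presumably registers the class under its class name
class MyDetectionCollateFN(DetectionCollateFN):
    """Hypothetical subclass; once registered, a recipe could reference it
    as `collate_fn: MyDetectionCollateFN` instead of a `_target_` dotpath."""
```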
43 changes: 43 additions & 0 deletions src/super_gradients/common/exceptions/__init__.py
@@ -0,0 +1,43 @@
from .dataset_exceptions import (
    EmptyDatasetException,
    DatasetItemsException,
    DatasetValidationException,
    IllegalDatasetParameterException,
    ParameterMismatchException,
    UnsupportedBatchItemsFormat,
)
from .loss_exceptions import RequiredLossComponentReductionException, IllegalRangeForLossAttributeException
from .factory_exceptions import UnknownTypeException
from .kd_trainer_exceptions import (
    KDModelException,
    UnsupportedKDModelArgException,
    UnsupportedKDArchitectureException,
    ArchitectureKwargsException,
    InconsistentParamsException,
    TeacherKnowledgeException,
    UndefinedNumClassesException,
)
from .sg_trainer_exceptions import IllegalDataloaderInitialization, UnsupportedOptimizerFormat, UnsupportedTrainingParameterFormat, GPUModeNotSetupError

__all__ = [
    "EmptyDatasetException",
    "DatasetItemsException",
    "DatasetValidationException",
    "IllegalDatasetParameterException",
    "ParameterMismatchException",
    "UnsupportedBatchItemsFormat",
    "RequiredLossComponentReductionException",
    "IllegalRangeForLossAttributeException",
    "UnknownTypeException",
    "KDModelException",
    "UnsupportedKDModelArgException",
    "UnsupportedKDArchitectureException",
    "ArchitectureKwargsException",
    "InconsistentParamsException",
    "TeacherKnowledgeException",
    "UndefinedNumClassesException",
    "IllegalDataloaderInitialization",
    "UnsupportedOptimizerFormat",
    "UnsupportedTrainingParameterFormat",
    "GPUModeNotSetupError",
]
16 changes: 16 additions & 0 deletions src/super_gradients/common/exceptions/dataset_exceptions.py
@@ -1,3 +1,6 @@
+from typing import Tuple, Type
+
+
 class DatasetValidationException(Exception):
     pass
@@ -46,3 +49,16 @@ def __init__(self, batch_items: tuple):
             "To fix this, please change the implementation of your dataset __getitem__ method, so that it would return the items defined above.\n"
         )
         super().__init__(self.message)
+
+
+class DatasetItemsException(Exception):
+    def __init__(self, data_sample: Tuple, collate_type: Type, expected_item_names: Tuple):
+        """
+        :param data_sample:         item(s) returned by a dataset
+        :param collate_type:        type of the collate that caused the exception
+        :param expected_item_names: tuple of names of items that the collate expects the dataset to return
+        """
+        collate_type_name = collate_type.__name__
+        num_sample_items = len(data_sample) if isinstance(data_sample, tuple) else 1
+        error_msg = f"`{collate_type_name}` only supports Datasets that return a tuple {expected_item_names}, but got a tuple of len={num_sample_items}"
+        super().__init__(error_msg)
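For illustration, a short sketch of the message this exception carries when a collate expecting `("image", "targets")` receives a three-item sample; it relies only on `DetectionCollateFN` as defined later in this PR:

```python
from super_gradients.common.exceptions import DatasetItemsException
from super_gradients.training.utils.collate_fn import DetectionCollateFN

exc = DatasetItemsException(
    data_sample=("image", "targets", "extra"),  # a 3-item sample
    collate_type=DetectionCollateFN,
    expected_item_names=("image", "targets"),
)
print(exc)
# `DetectionCollateFN` only supports Datasets that return a tuple
# ('image', 'targets'), but got a tuple of len=3
```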
(recipe YAML file, filename not shown)
@@ -52,8 +52,7 @@ train_dataloader_params:
   worker_init_fn:
     _target_: super_gradients.training.utils.utils.load_func
     dotpath: super_gradients.training.datasets.datasets_utils.worker_init_reset_seed
-  collate_fn: # collate function for trainset
-    _target_: super_gradients.training.utils.detection_utils.DetectionCollateFN
+  collate_fn: DetectionCollateFN

 val_dataset_params:
   data_dir: /data/coco # root path to coco data
@@ -80,8 +79,7 @@ val_dataloader_params:
   num_workers: 8
   drop_last: False
   pin_memory: True
-  collate_fn: # collate function for valset
-    _target_: super_gradients.training.utils.detection_utils.CrowdDetectionCollateFN
+  collate_fn: CrowdDetectionCollateFN


 _convert_: all
(recipe YAML file, filename not shown)
@@ -55,15 +55,15 @@ train_dataloader_params:
   worker_init_fn:
     _target_: super_gradients.training.utils.utils.load_func
     dotpath: super_gradients.training.datasets.datasets_utils.worker_init_reset_seed
-  collate_fn: # collate function for trainset
-    _target_: super_gradients.training.utils.detection_utils.PPYoloECollateFN
-    random_resize_sizes: [ 320, 352, 384, 416, 448, 480, 512, 544, 576, 608, 640, 672, 704, 736, 768 ]
-    random_resize_modes:
-      - 0 # cv::INTER_NEAREST
-      - 1 # cv::INTER_LINEAR
-      - 2 # cv::INTER_CUBIC
-      - 3 # cv::INTER_AREA
-      - 4 # cv::INTER_LANCZOS4
+  collate_fn:
+    PPYoloECollateFN:
+      random_resize_sizes: [ 320, 352, 384, 416, 448, 480, 512, 544, 576, 608, 640, 672, 704, 736, 768 ]
+      random_resize_modes:
+        - 0 # cv::INTER_NEAREST
+        - 1 # cv::INTER_LINEAR
+        - 2 # cv::INTER_CUBIC
+        - 3 # cv::INTER_AREA
+        - 4 # cv::INTER_LANCZOS4

 val_dataset_params:
   data_dir: /data/coco # root path to coco data
@@ -93,7 +93,6 @@ val_dataloader_params:
   drop_last: False
   shuffle: False
   pin_memory: False
-  collate_fn: # collate function for valset
-    _target_: super_gradients.training.utils.detection_utils.CrowdDetectionPPYoloECollateFN
+  collate_fn: CrowdDetectionPPYoloECollateFN

 _convert_: all
(recipe YAML file, filename not shown)
@@ -48,8 +48,7 @@ train_dataloader_params:
   worker_init_fn:
     _target_: super_gradients.training.utils.utils.load_func
     dotpath: super_gradients.training.datasets.datasets_utils.worker_init_reset_seed
-  collate_fn: # collate function for trainset
-    _target_: super_gradients.training.utils.detection_utils.DetectionCollateFN
+  collate_fn: DetectionCollateFN

 val_dataset_params:
   data_dir: /data/coco # root path to coco data
@@ -76,7 +75,6 @@ val_dataloader_params:
   num_workers: 8
   drop_last: False
   pin_memory: True
-  collate_fn: # collate function for valset
-    _target_: super_gradients.training.utils.detection_utils.CrowdDetectionCollateFN
+  collate_fn: CrowdDetectionCollateFN

 _convert_: all
(recipe YAML file, filename not shown)
@@ -55,8 +55,7 @@ train_dataloader_params:
   shuffle: True
   drop_last: True
   pin_memory: True
-  collate_fn:
-    _target_: super_gradients.training.utils.detection_utils.DetectionCollateFN
+  collate_fn: DetectionCollateFN

 val_dataset_params:
   data_dir: /data/coco # TO FILL: Where the data is stored.
@@ -88,7 +87,6 @@ val_dataloader_params:
   num_workers: 8
   drop_last: False
   pin_memory: True
-  collate_fn:
-    _target_: super_gradients.training.utils.detection_utils.DetectionCollateFN
+  collate_fn: DetectionCollateFN

 _convert_: all
(recipe YAML file, filename not shown)
@@ -51,8 +51,7 @@ train_dataloader_params:
   shuffle: True
   drop_last: True
   pin_memory: True
-  collate_fn:
-    _target_: super_gradients.training.utils.detection_utils.DetectionCollateFN
+  collate_fn: DetectionCollateFN

 val_dataset_params:
   data_dir: /data/coco # root path to coco data
@@ -86,7 +85,6 @@ val_dataloader_params:
   drop_last: False
   shuffle: False
   pin_memory: True
-  collate_fn:
-    _target_: super_gradients.training.utils.detection_utils.CrowdDetectionCollateFN
+  collate_fn: CrowdDetectionCollateFN

 _convert_: all
(recipe YAML file, filename not shown)
@@ -39,16 +39,14 @@ train_dataloader_params:
   worker_init_fn:
     _target_: super_gradients.training.utils.utils.load_func
     dotpath: super_gradients.training.datasets.datasets_utils.worker_init_reset_seed
-  collate_fn: # collate function for trainset
-    _target_: super_gradients.training.utils.detection_utils.DetectionCollateFN
+  collate_fn: DetectionCollateFN

 val_dataloader_params:
   batch_size: 64
   num_workers: 8
   drop_last: False
   pin_memory: True
-  collate_fn: # collate function for trainset
-    _target_: super_gradients.training.utils.detection_utils.DetectionCollateFN
+  collate_fn: DetectionCollateFN


 _convert_: all
(recipe YAML file, filename not shown)
@@ -61,8 +61,7 @@ train_dataloader_params:
   worker_init_fn:
     _target_: super_gradients.training.utils.utils.load_func
     dotpath: super_gradients.training.datasets.datasets_utils.worker_init_reset_seed
-  collate_fn: # collate function for trainset
-    _target_: super_gradients.training.utils.detection_utils.DetectionCollateFN
+  collate_fn: DetectionCollateFN

 val_dataset_params:
   data_dir: ${..data_dir} # root path to Robflow datasets
@@ -94,8 +93,7 @@ val_dataloader_params:
   drop_last: False
   shuffle: False
   pin_memory: True
-  collate_fn: # collate function for valset
-    _target_: super_gradients.training.utils.detection_utils.CrowdDetectionCollateFN
+  collate_fn: CrowdDetectionCollateFN


 _convert_: all
6 changes: 6 additions & 0 deletions src/super_gradients/training/utils/collate_fn/__init__.py
@@ -0,0 +1,6 @@
from .detection_collate_fn import DetectionCollateFN
from .ppyoloe_collate_fn import PPYoloECollateFN
from .crowd_detection_collate_fn import CrowdDetectionCollateFN
from .crowd_detection_ppyoloe_collate_fn import CrowdDetectionPPYoloECollateFN

__all__ = ["DetectionCollateFN", "PPYoloECollateFN", "CrowdDetectionCollateFN", "CrowdDetectionPPYoloECollateFN"]
26 changes: 26 additions & 0 deletions src/super_gradients/training/utils/collate_fn/crowd_detection_collate_fn.py
@@ -0,0 +1,26 @@
from typing import Tuple, Dict

import torch

from super_gradients.common.registry import register_collate_function
from super_gradients.common.exceptions.dataset_exceptions import DatasetItemsException
from super_gradients.training.utils.collate_fn.detection_collate_fn import DetectionCollateFN


@register_collate_function()
class CrowdDetectionCollateFN(DetectionCollateFN):
    """
    Collate function for Yolox training with additional_batch_items that includes crowd targets
    """

    def __init__(self):
        super().__init__()
        self.expected_item_names = ("image", "targets", "crowd_targets")

    def __call__(self, data) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]:
        try:
            images_batch, labels_batch, crowd_labels_batch = list(zip(*data))
        except (ValueError, TypeError):
            raise DatasetItemsException(data_sample=data[0], collate_type=type(self), expected_item_names=self.expected_item_names)

        return self._format_images(images_batch), self._format_targets(labels_batch), {"crowd_targets": self._format_targets(crowd_labels_batch)}
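A minimal usage sketch of the batch this collate yields, assuming each dataset item is an `(image, targets, crowd_targets)` tuple with HWC images and `[class, x, y, w, h]` rows; the toy arrays below are assumptions:

```python
import numpy as np

from super_gradients.training.utils.collate_fn import CrowdDetectionCollateFN

collate = CrowdDetectionCollateFN()

image = np.zeros((320, 320, 3), dtype=np.float32)   # HWC image
targets = np.zeros((2, 5), dtype=np.float32)        # [class, x, y, w, h]
crowd_targets = np.zeros((1, 5), dtype=np.float32)

# Collate a batch of two identical samples:
images, targets_batch, extras = collate([(image, targets, crowd_targets)] * 2)
print(images.shape)                   # torch.Size([2, 3, 320, 320])
print(targets_batch.shape)            # torch.Size([4, 6])
print(extras["crowd_targets"].shape)  # torch.Size([2, 6])
```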
30 changes: 30 additions & 0 deletions src/super_gradients/training/utils/collate_fn/crowd_detection_ppyoloe_collate_fn.py
@@ -0,0 +1,30 @@
from typing import Union, List, Tuple, Dict

import torch

from super_gradients.common.registry import register_collate_function
from super_gradients.common.exceptions.dataset_exceptions import DatasetItemsException
from super_gradients.training.utils.collate_fn.ppyoloe_collate_fn import PPYoloECollateFN


@register_collate_function()
class CrowdDetectionPPYoloECollateFN(PPYoloECollateFN):
    """
    Collate function for PPYoloE training with additional_batch_items that includes crowd targets
    """

    def __init__(self, random_resize_sizes: Union[List[int], None] = None, random_resize_modes: Union[List[int], None] = None):
        super().__init__(random_resize_sizes, random_resize_modes)
        self.expected_item_names = ("image", "targets", "crowd_targets")

    def __call__(self, data) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]:
        if self.random_resize_sizes is not None:
            data = self.random_resize(data)

        try:
            images_batch, labels_batch, crowd_labels_batch = list(zip(*data))
        except (ValueError, TypeError):
            raise DatasetItemsException(data_sample=data[0], collate_type=type(self), expected_item_names=self.expected_item_names)

        return self._format_images(images_batch), self._format_targets(labels_batch), {"crowd_targets": self._format_targets(crowd_labels_batch)}
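The multi-scale resize logic lives in the parent `PPYoloECollateFN`, whose body is not shown in this diff; below is a hedged instantiation sketch mirroring the recipe values shown earlier. The exact resize behavior is inferred from the `random_resize` call above and is an assumption:

```python
from super_gradients.training.utils.collate_fn import CrowdDetectionPPYoloECollateFN

# Values mirror the PPYoloE recipe diff above; the integers are OpenCV
# interpolation flags (cv::INTER_NEAREST .. cv::INTER_LANCZOS4).
collate = CrowdDetectionPPYoloECollateFN(
    random_resize_sizes=[320, 384, 448, 512, 576, 640],
    random_resize_modes=[0, 1, 2, 3, 4],
)
# When random_resize_sizes is not None, __call__ first runs
# self.random_resize(data) (defined on the parent class, presumably picking a
# random target size per batch) and then collates like CrowdDetectionCollateFN.
```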
47 changes: 47 additions & 0 deletions src/super_gradients/training/utils/collate_fn/detection_collate_fn.py
@@ -0,0 +1,47 @@
from typing import Tuple, List, Union

import numpy as np
import torch

from super_gradients.common.registry import register_collate_function
from super_gradients.common.exceptions.dataset_exceptions import DatasetItemsException


@register_collate_function()
class DetectionCollateFN:
    """
    Collate function for Yolox training
    """

    def __init__(self):
        self.expected_item_names = ("image", "targets")

    def __call__(self, data) -> Tuple[torch.Tensor, torch.Tensor]:
        try:
            images_batch, labels_batch = list(zip(*data))
        except (ValueError, TypeError):
            raise DatasetItemsException(data_sample=data[0], collate_type=type(self), expected_item_names=self.expected_item_names)

        return self._format_images(images_batch), self._format_targets(labels_batch)

    def _format_images(self, images_batch: List[Union[torch.Tensor, np.ndarray]]) -> torch.Tensor:
        images_batch = [torch.tensor(img) for img in images_batch]
        images_batch_stack = torch.stack(images_batch, 0)
        if images_batch_stack.shape[3] == 3:
            # Batch arrived as NHWC; move channels to dim 1 (NCHW) for the network
            images_batch_stack = torch.moveaxis(images_batch_stack, -1, 1).float()
        return images_batch_stack

    def _format_targets(self, labels_batch: List[Union[torch.Tensor, np.ndarray]]) -> torch.Tensor:
        """
        Stack a batch id column to targets and concatenate
        :param labels_batch: a list of targets per image (each of arbitrary length)
        :return: one tensor of targets of all images of shape [N, 6], where N is the total number of targets in a batch
                 and the 1st column is batch item index
        """
        labels_batch = [torch.tensor(labels) for labels in labels_batch]
        labels_batch_indexed = []
        for i, labels in enumerate(labels_batch):
            batch_column = labels.new_ones((labels.shape[0], 1)) * i
            labels = torch.cat((batch_column, labels), dim=-1)
            labels_batch_indexed.append(labels)
        return torch.cat(labels_batch_indexed, 0)
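A worked example of the batch-index column `_format_targets` prepends; the 5-column `[class, x, y, w, h]` row layout is an assumption used only for illustration:

```python
import torch

from super_gradients.training.utils.collate_fn import DetectionCollateFN

collate = DetectionCollateFN()

labels_0 = torch.zeros(2, 5)  # image 0: two targets, 5 columns each
labels_1 = torch.ones(1, 5)   # image 1: one target

targets = collate._format_targets([labels_0, labels_1])
print(targets.shape)  # torch.Size([3, 6]) -- 5 columns + batch index
print(targets[:, 0])  # tensor([0., 0., 1.]) -- batch item index per row
```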