Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for segmentation models to support .export() #1860

Merged
merged 12 commits into from
Mar 13, 2024
Merged
2 changes: 1 addition & 1 deletion .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -631,7 +631,7 @@ jobs:


run_and_convert_notebooks_to_docs:
parallelism: 11 # Adjust based on your needs and available resources
parallelism: 12 # Adjust based on your needs and available resources
docker:
- image: 307629990626.dkr.ecr.us-east-1.amazonaws.com/deci/infra/circleci/runner/sg-gpu:<< pipeline.parameters.sg_docker_version >>
resource_class: deci-ai/sg-gpu-on-premise
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ NOTEBOOKS_TO_CHECK += notebooks/DEKR_PoseEstimationFineTuning.ipynb
NOTEBOOKS_TO_CHECK += notebooks/albumentations_tutorial.ipynb
NOTEBOOKS_TO_CHECK += notebooks/yolo_nas_pose_eval_with_pycocotools.ipynb
NOTEBOOKS_TO_CHECK += notebooks/dataloader_adapter.ipynb

NOTEBOOKS_TO_CHECK += notebooks/Segmentation_Model_Export.ipynb

# This Makefile target runs notebooks listed below and converts them to markdown files in documentation/source/
check_notebooks_version_match: $(NOTEBOOKS_TO_CHECK)
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ model = models.get("model-name", pretrained_weights="pretrained-model-name")
* [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Deci-AI/super-gradients/blob/master/notebooks/quickstart_segmentation.ipynb) [Segmentation Quick Start](https://github.com/Deci-AI/super-gradients/blob/master/notebooks/quickstart_segmentation.ipynb)
* [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Deci-AI/super-gradients/blob/master/notebooks/transfer_learning_semantic_segmentation.ipynb) [Segmentation Transfer Learning](https://github.com/Deci-AI/super-gradients/blob/master/notebooks/transfer_learning_semantic_segmentation.ipynb)
* [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Deci-AI/super-gradients/blob/master/notebooks/segmentation_connect_custom_dataset.ipynb) [How to Connect Custom Dataset](https://github.com/Deci-AI/super-gradients/blob/master/notebooks/segmentation_connect_custom_dataset.ipynb)
* [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Deci-AI/super-gradients/blob/master/notebooks/Segmentation_Model_Export.ipynb) [How to export segmentation model to ONNX](https://github.com/Deci-AI/super-gradients/blob/master/notebooks/Segmentation_Model_Export.ipynb)


### Pose Estimation
Expand Down
549 changes: 549 additions & 0 deletions notebooks/Segmentation_Model_Export.ipynb

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions src/super_gradients/conversion/export_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@ class ExportParams:
:param postprocessing: If True, the postprocessing will be included in the ONNX model.
This option is only available for models that support model.export() syntax.

:param confidence_threshold: The confidence threshold for object detection models.
This option is only available for models that support model.export() syntax.
:param confidence_threshold: The confidence threshold for object detection models
or image binary segmentation models.
This attribute used only for models inheriting ExportableSegmentationModel
and ExportableObjectDetectionModel.
If None, the default confidence threshold for a given model will be used.
:param onnx_export_kwargs: (dict) Optional keyword arguments for torch.onnx.export() function.
:param onnx_simplify: (bool) If True, apply onnx-simplifier to the exported model.
Expand Down
13 changes: 12 additions & 1 deletion src/super_gradients/module_interfaces/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@
from .pose_estimation_post_prediction_callback import AbstractPoseEstimationPostPredictionCallback, PoseEstimationPredictions
from .supports_input_shape_check import SupportsInputShapeCheck
from .quantization_result import QuantizationResult

from .exportable_segmentation import (
SegmentationModelExportResult,
ExportableSegmentationModel,
AbstractSegmentationDecodingModule,
SemanticSegmentationDecodingModule,
BinarySegmentationDecodingModule,
)

__all__ = [
"HasPredict",
Expand All @@ -24,4 +30,9 @@
"SupportsInputShapeCheck",
"ObjectDetectionModelExportResult",
"QuantizationResult",
"SegmentationModelExportResult",
"ExportableSegmentationModel",
"AbstractSegmentationDecodingModule",
"SemanticSegmentationDecodingModule",
"BinarySegmentationDecodingModule",
]
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from super_gradients.conversion import ExportTargetBackend, ExportQuantizationMode, DetectionOutputFormatMode
from super_gradients.conversion.conversion_utils import find_compatible_model_device_for_dtype
from super_gradients.conversion.gs_utils import import_onnx_graphsurgeon_or_install
from super_gradients.module_interfaces.exceptions import ModelHasNoPreprocessingParamsException
from super_gradients.module_interfaces.supports_input_shape_check import SupportsInputShapeCheck
from super_gradients.training.utils.export_utils import infer_format_from_file_name, infer_image_shape_from_model, infer_image_input_channels
from super_gradients.training.utils.quantization.fix_pytorch_quantization_modules import patch_pytorch_quantization_modules_if_needed
Expand All @@ -31,14 +32,6 @@
]


class ModelHasNoPreprocessingParamsException(Exception):
"""
Exception that is raised when model does not have preprocessing parameters.
"""

pass


class AbstractObjectDetectionDecodingModule(nn.Module):
"""
Abstract class for decoding outputs from object detection models to a tuple of two tensors (boxes, scores)
Expand Down
Loading
Loading