Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Apple CoreML compilation tutorials for any SG model #1007

Closed
wants to merge 43 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
416f6f1
Added Apple CoreML compilation tutorials for any SG model
avideci May 15, 2023
e3daa5e
Merge branch 'master' into feature/SG-000_add_coreml_compilation_note…
avideci May 15, 2023
73341ec
Updated CoreML notebook with CR fixes
avideci May 15, 2023
bb4e300
Merge branch 'feature/SG-000_add_coreml_compilation_notebook' of gith…
avideci May 15, 2023
99871ec
Merge branch 'master' into feature/SG-000_add_coreml_compilation_note…
avideci May 15, 2023
c60e64c
fixed prep_model_for_conversion call
shaydeci May 15, 2023
7502bd7
Merge branch 'master' of github.com:Deci-AI/super-gradients into HEAD
avideci May 18, 2023
0205617
Added coreml exprot method to super_gradients.models, updated CoreML …
avideci May 18, 2023
aee5758
Ran black
avideci May 18, 2023
08a5d89
Ran black
avideci May 18, 2023
f2986ab
Merge branch 'master' into feature/SG-000_add_coreml_compilation_note…
avideci May 22, 2023
2cbfc3d
Replaced exists with isfile
avideci May 22, 2023
9157bd6
Merge branch 'feature/SG-000_add_coreml_compilation_notebook' of gith…
avideci May 22, 2023
d22d4c0
Merge branch 'master' into feature/SG-000_add_coreml_compilation_note…
avideci May 22, 2023
2fcffc6
Remove input_shape and added input_size, added tests
avideci May 22, 2023
3fc785e
Added tests, CR fixes
avideci May 22, 2023
2b16f44
Merge branch 'feature/SG-000_add_coreml_compilation_notebook' of gith…
avideci May 22, 2023
62a5d7f
Fixed dir test in mlpackage export
avideci May 22, 2023
62cdc00
Formetted with black
avideci May 22, 2023
7cb4249
Merge branch 'master' into feature/SG-000_add_coreml_compilation_note…
avideci May 23, 2023
d037b17
Updated CoreML notebook with CR fixes
avideci May 15, 2023
5b7a589
Fix train_loader not initalized properly (#981)
Louis-Dupont May 11, 2023
8e06ce7
empty mask (#982)
lkdci May 11, 2023
ae74b35
changed requirement (#1002)
shaydeci May 14, 2023
b956f1c
Fix convert_to_onnx to correctly handle case when input_shape is None…
BloodAxe May 15, 2023
0f7f633
Feature/sg 757 resume for spots (#870)
shaydeci May 15, 2023
67adb2d
fixed prep_model_for_conversion call
shaydeci May 15, 2023
4892bfd
Cityscapes AutoLabelling dataset (#1000)
lkdci May 15, 2023
33bfebb
Predict on fused model (#998)
Louis-Dupont May 16, 2023
80c2b5e
Proposal of issue template (#1018)
Louis-Dupont May 17, 2023
d97d795
Added coreml exprot method to super_gradients.models, updated CoreML …
avideci May 18, 2023
689289b
Ran black
avideci May 18, 2023
05a8c69
Ran black
avideci May 18, 2023
d6fb152
Replaced exists with isfile
avideci May 22, 2023
95c04f4
Fix doc (#1019)
Louis-Dupont May 18, 2023
e4f9885
Fixed missing encoding (Issue #999) (#1044)
T0T4R4 May 21, 2023
92143e4
Remove input_shape and added input_size, added tests
avideci May 22, 2023
07d06e3
Added tests, CR fixes
avideci May 22, 2023
851c761
Cache model from platform locally (#1009)
BloodAxe May 22, 2023
e77836c
Fixed dir test in mlpackage export
avideci May 22, 2023
cd5b8d7
Formetted with black
avideci May 22, 2023
49d89d9
Feature/infra 000 nightly (#1051)
shaydeci May 22, 2023
e14634c
Merge branch 'feature/SG-000_add_coreml_compilation_notebook' of gith…
avideci May 23, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/super_gradients/training/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@
import super_gradients.training.models.user_models as user_models
from super_gradients.training.models.model_factory import get, get_model_name
from super_gradients.training.models.arch_params_factory import get_arch_params
from super_gradients.training.models.conversion import convert_to_onnx, convert_from_config
from super_gradients.training.models.conversion import convert_to_coreml, convert_to_onnx, convert_from_config


from super_gradients.common.object_names import Models
Expand Down Expand Up @@ -280,6 +280,7 @@
"get",
"get_model_name",
"get_arch_params",
"convert_to_coreml",
"convert_to_onnx",
"convert_from_config",
"ARCHITECTURES",
Expand Down
126 changes: 118 additions & 8 deletions src/super_gradients/training/models/conversion.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,34 @@
import os
import pathlib
from pathlib import Path

import hydra
import numpy as np
import onnx
import torch
from omegaconf import DictConfig
import numpy as np
from onnxsim import simplify
from torch.nn import Identity

from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.common.decorators.factory_decorator import resolve_param
from super_gradients.common.environment.cfg_utils import load_experiment_cfg
from super_gradients.common.environment.checkpoints_dir_utils import get_checkpoints_dir_path
from super_gradients.common.factories.transforms_factory import TransformsFactory
from super_gradients.training import models
from super_gradients.common.environment.checkpoints_dir_utils import get_checkpoints_dir_path
from super_gradients.common.environment.cfg_utils import load_experiment_cfg
from super_gradients.training.utils.sg_trainer_utils import parse_args
import os
import pathlib

from onnxsim import simplify
import onnx

logger = get_logger(__name__)

ct = None

try:
import coremltools as coreml_tools

ct = coreml_tools
except (ImportError, ModuleNotFoundError):
pass


class ConvertableCompletePipelineModel(torch.nn.Module):
"""
Expand Down Expand Up @@ -48,6 +56,108 @@ def forward(self, x):
return self.post_process(self.model(self.pre_process(x)))


@resolve_param("pre_process", TransformsFactory())
@resolve_param("post_process", TransformsFactory())
def convert_to_coreml(
model: torch.nn.Module,
out_path: str,
input_size: tuple = None,
pre_process: torch.nn.Module = None,
post_process: torch.nn.Module = None,
prep_model_for_conversion_kwargs=None,
export_as_ml_program=False,
torch_trace_kwargs=None,
):
"""
Exports a given SG model to CoreML mlprogram or package.

:param model: torch.nn.Module, model to export to ONNX.
:param out_path: str, destination path for the .onnx file.
:param input_size: Input shape without batch dimensions ([C,H,W]). Batch size assumed to be 1.
:param pre_process: torch.nn.Module, preprocessing pipeline, will be resolved by TransformsFactory()
:param post_process: torch.nn.Module, postprocessing pipeline, will be resolved by TransformsFactory()
:param prep_model_for_conversion_kwargs: dict, for SgModules- args to be passed to model.prep_model_for_conversion
prior to torch.onnx.export call. Supported keys are:
- input_size - Shape of inputs with batch dimension, [C,H,W] for image inputs.
When true, the simplified model will be saved in out_path (default=True).
:param export_as_ml_program: Whether to convert to the new program format (better) or legacy coreml proto file
(Supports more iOS versions and devices, but this format will be deprecated at some point).
:param torch_trace_kwargs: kwargs for torch.jit.trace
:return: Path
"""
if ct is None:
raise ImportError(
'"coremltools" is required for CoreML export, but is not installed. Please install CoreML Tools using:\n'
' "python3 -m pip install coremltools" and try again (Tested with version 6.3.0);'
BloodAxe marked this conversation as resolved.
Show resolved Hide resolved
)

logger.debug("Building model...")
logger.debug(model)
logger.debug("Model child nodes:")
logger.debug(next(model.named_children()))

if not os.path.isdir(pathlib.Path(out_path).parent.resolve()):
raise FileNotFoundError(f"Could not find destination directory {out_path} for the ONNX file.")
torch_trace_kwargs = torch_trace_kwargs or dict()
prep_model_for_conversion_kwargs = prep_model_for_conversion_kwargs or dict()

if input_size is not None:
input_size = (1, *input_size)
logger.warning(
f"input_shape is deprecated and will be removed in the next major release."
f"Use the convert_to_coreml(..., prep_model_for_conversion_kwargs(input_size={input_size})) instead"
)
prep_model_for_conversion_kwargs["input_size"] = input_size

if "input_size" not in prep_model_for_conversion_kwargs:
raise KeyError("input_size must be provided in prep_model_for_conversion_kwargs")

input_size = prep_model_for_conversion_kwargs["input_size"]

# TODO: support more than 1 input when prep_for_conversoin will support it.
example_inputs = [torch.Tensor(np.zeros(input_size))]

if not out_path.endswith(".mlpackage") and not out_path.endswith(".mlmodel"):
out_path += ".mlpackage" if export_as_ml_program else ".mlmodel"

complete_model = ConvertableCompletePipelineModel(model, pre_process, post_process, **prep_model_for_conversion_kwargs)

# Set the model in evaluation mode.
complete_model.eval()

logger.info("Creating torch jit trace...")
traced_model = torch.jit.trace(complete_model, example_inputs, **torch_trace_kwargs)
logger.info("Tracing the model with the provided inputs...")
out = traced_model(*example_inputs) # using * because example_inputs is a list
logger.info(f"Inferred output shapes: {[o.shape for o in out]}")
if export_as_ml_program:
coreml_model = ct.convert(
traced_model, convert_to="mlprogram", inputs=[ct.ImageType(name=f"x_{i + 1}", shape=_.shape) for i, _ in enumerate(example_inputs)]
)
else:
coreml_model = ct.convert(traced_model, inputs=[ct.ImageType(name=f"x_{i + 1}", shape=_.shape) for i, _ in enumerate(example_inputs)])

spec = coreml_model.get_spec()
logger.debug(spec.description)

# Changing the input names:
# In CoreML, the input name is compiled into classes (named keyword argument in predict).
# We want to re-use the same names among different models to make research easier.
# We normalize the inputs names to be x_1, x_2, etc.
for i, _input in enumerate(spec.description.input):
new_input_name = "x_" + str(i + 1)
logger.info(f"Renaming input {_input.name} to {new_input_name}")
ct.utils.rename_feature(spec, _input.name, new_input_name)

# Re-Initializing the model with the new spec
coreml_model = ct.models.MLModel(spec, weights_dir=coreml_model.weights_dir)

# Saving the model
coreml_model.save(out_path)
logger.info(f"CoreML model successfully save to {os.path.abspath(out_path)}")
return out_path


@resolve_param("pre_process", TransformsFactory())
@resolve_param("post_process", TransformsFactory())
def convert_to_onnx(
Expand Down
56 changes: 56 additions & 0 deletions tests/unit_tests/export_coreml_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import os
import tempfile
import unittest

from torchvision.transforms import Compose, Normalize, Resize

from super_gradients.common.object_names import Models
from super_gradients.training import models
from super_gradients.training.transforms import Standardize


class TestModelsCoreMLExport(unittest.TestCase):
def test_models_onnx_export_with_explicit_input_size(self):
pretrained_model = models.get(Models.RESNET18, num_classes=1000, pretrained_weights="imagenet")
preprocess = Compose([Resize(224), Standardize(), Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
with tempfile.TemporaryDirectory() as tmpdirname:
out_path = os.path.join(tmpdirname, "resnet18.mlmodel")
models.convert_to_coreml(model=pretrained_model, out_path=out_path, input_size=(3, 256, 256), pre_process=preprocess)
self.assertTrue(os.path.isfile(out_path))

def test_models_onnx_export_without_explicit_input_size_raises_error(self):
pretrained_model = models.get(Models.RESNET18, num_classes=1000, pretrained_weights="imagenet")
preprocess = Compose([Resize(224), Standardize(), Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
with self.assertRaises(KeyError):
models.convert_to_coreml(model=pretrained_model, out_path="some-output-path.coreml", pre_process=preprocess)

def test_models_coreml_export(self, **export_kwargs):
pretrained_model = models.get(Models.YOLO_NAS_S, num_classes=1000, pretrained_weights="coco")

# Just for the sake of testing, not really COCO preprocessing
preprocess = Compose([Resize(224), Standardize(), Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
with tempfile.TemporaryDirectory() as tmpdirname:
out_path = os.path.join(tmpdirname, "yolo_nas_s")
model_path = models.convert_to_coreml(
model=pretrained_model,
out_path=out_path,
pre_process=preprocess,
prep_model_for_conversion_kwargs=dict(input_size=(1, 3, 640, 640)),
**export_kwargs,
)

if export_kwargs.get("export_as_ml_program"):
# Expecting a directory
self.assertTrue(os.path.isdir(model_path))
self.assertTrue(model_path.endswith(".mlpackage"))
else:
# Expecting a single file
self.assertTrue(os.path.isfile(model_path))
self.assertTrue(model_path.endswith(".mlmodel"))

def test_models_coreml_export_as_mlprogram(self):
self.test_models_coreml_export(export_as_ml_program=True)


if __name__ == "__main__":
unittest.main()
2 changes: 2 additions & 0 deletions tests/unit_tests/export_onnx_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def test_models_onnx_export_with_deprecated_input_shape(self):
with tempfile.TemporaryDirectory() as tmpdirname:
out_path = os.path.join(tmpdirname, "resnet18.onnx")
models.convert_to_onnx(model=pretrained_model, out_path=out_path, input_shape=(3, 256, 256), pre_process=preprocess)
self.assertTrue(os.path.exists(out_path))
avideci marked this conversation as resolved.
Show resolved Hide resolved

def test_models_onnx_export(self):
pretrained_model = models.get(Models.RESNET18, num_classes=1000, pretrained_weights="imagenet")
Expand All @@ -24,6 +25,7 @@ def test_models_onnx_export(self):
models.convert_to_onnx(
model=pretrained_model, out_path=out_path, pre_process=preprocess, prep_model_for_conversion_kwargs=dict(input_size=(1, 3, 640, 640))
)
self.assertTrue(os.path.exists(out_path))
avideci marked this conversation as resolved.
Show resolved Hide resolved


if __name__ == "__main__":
Expand Down
Loading