diff --git a/README.md b/README.md
index 8c4fe73f..73733ea0 100644
--- a/README.md
+++ b/README.md
@@ -2,7 +2,7 @@
[](https://github.com/onnx/turnkeyml/tree/main/test "Check out our tests")
[](https://github.com/onnx/turnkeyml/tree/main/test "Check out our tests")
-[](https://github.com/onnx/turnkeyml/blob/main/docs/install.md "Check out our instructions")
+[](https://github.com/onnx/turnkeyml/blob/main/docs/install.md "Check out our instructions")
[](https://github.com/onnx/turnkeyml/blob/main/docs/install.md "Check out our instructions")
diff --git a/docs/install.md b/docs/install.md
index 020f2b1f..66b7c224 100644
--- a/docs/install.md
+++ b/docs/install.md
@@ -20,6 +20,13 @@ bash Miniconda3-latest-Linux-x86_64.sh
If you are installing TurnkeyML on **Windows**, manually download and install [Miniconda3 for Windows 64-bit](https://repo.anaconda.com/miniconda/Miniconda3-latest-Windows-x86_64.exe). Please note that PowerShell is recommended when using miniconda on Windows.
+
+If you are installing TurnkeyML on **MacOS**, run the commands below:
+```
+wget https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh
+bash Miniconda3-latest-MacOSX-x86_64.sh
+```
+
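+Note: the command above fetches the x86_64 installer; on an Apple Silicon machine you may prefer the native arm64 build, [Miniconda3-latest-MacOSX-arm64.sh](https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-arm64.sh).
+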
Then create and activate a virtual environment like this:
```
diff --git a/docs/tools_user_guide.md b/docs/tools_user_guide.md
index 53d24635..29c2463a 100644
--- a/docs/tools_user_guide.md
+++ b/docs/tools_user_guide.md
@@ -12,11 +12,14 @@ The tools currently support the following combinations of runtimes and devices:
| ----------- | ---------- | ------------------------------------------------------------------------------------- | -------------------------------- | --------------------------------------------- |
| Nvidia GPU | nvidia | TensorRT† | trt | Any Nvidia GPU supported by TensorRT |
 | x86 CPU | x86 | ONNX Runtime‡, Pytorch Eager, Pytorch 2.x Compiled | ort, torch-eager, torch-compiled | Any Intel or AMD CPU supported by the runtime |
+| Apple Silicon | apple_silicon | CoreML◊, ONNX Runtime‡, Pytorch Eager | coreml, ort, torch-eager | Any Apple M-series processor supported by the runtime |
+
† Requires TensorRT >= 8.5.2
‡ Requires ONNX Runtime >= 1.13.1
* Requires Pytorch >= 2.0.0
+◊ Requires coremltools >= 7.1
# Table of Contents
@@ -252,6 +255,9 @@ Each device type has its own default runtime, as indicated below.
- `torch-compiled`: PyTorch 2.x-style compiled graph execution using TorchInductor.
- Valid runtimes for `nvidia` device
- `trt`: Nvidia TensorRT (default).
+- Valid runtimes for `apple_silicon` device
+ - `coreml`: CoreML (default).
+ - `ort`: ONNX Runtime.
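+
+For example, an illustrative CLI invocation (`my_model.py` is a placeholder for
+your own input script):
+
+```
+turnkey my_model.py --device apple_silicon --runtime coreml
+```
+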
 This feature is also available as an API argument:
- `benchmark_files(runtime=[...])`
diff --git a/setup.py b/setup.py
index 381b340a..76987596 100644
--- a/setup.py
+++ b/setup.py
@@ -18,6 +18,7 @@
"turnkeyml.run.onnxrt",
"turnkeyml.run.tensorrt",
"turnkeyml.run.torchrt",
+ "turnkeyml.run.coreml",
"turnkeyml.cli",
"turnkeyml.common",
"turnkeyml_models",
@@ -46,6 +47,7 @@
"pandas>=1.5.3",
"fasteners",
"GitPython>=3.1.40",
+ "coremltools>=7.1",
],
extras_require={
"tensorflow": [
diff --git a/src/turnkeyml/build/export.py b/src/turnkeyml/build/export.py
index eee8a306..9f7eeb4a 100644
--- a/src/turnkeyml/build/export.py
+++ b/src/turnkeyml/build/export.py
@@ -10,6 +10,7 @@
import onnxruntime
import onnxmltools
import onnx
+import coremltools as ct
import turnkeyml.build.stage as stage
import turnkeyml.common.exceptions as exp
import turnkeyml.common.build as build
@@ -40,6 +41,20 @@ def _warn_to_stdout(message, category, filename, line_number, _, line):
)
+def validate_torch_args(state: build.State) -> None:
+ """
+ Ensure that the inputs received match the model's forward function
+ """
+ all_args = list(inspect.signature(state.model.forward).parameters.keys())
+ for inp in list(state.inputs.keys()):
+ if inp not in all_args:
+ msg = f"""
+ Input name {inp} not found in the model's forward method. Available
+            input names are: {all_args}
+ """
+ raise ValueError(msg)
+
+
def get_output_names(
onnx_model: Union[str, onnx.ModelProto]
): # pylint: disable=no-member
@@ -62,6 +77,13 @@ def base_onnx_file(state: build.State):
)
+def base_coreml_file(state: build.State):
+ return os.path.join(
+ onnx_dir(state),
+ f"{state.config.build_name}-op{state.config.onnx_opset}-base.mlmodel",
+ )
+
+
def opt_onnx_file(state: build.State):
return os.path.join(
onnx_dir(state),
@@ -234,16 +256,7 @@ def fire(self, state: build.State):
user_provided_args = list(state.inputs.keys())
if isinstance(state.model, torch.nn.Module):
- # Validate user provided args
- all_args = list(inspect.signature(state.model.forward).parameters.keys())
-
- for inp in user_provided_args:
- if inp not in all_args:
- msg = f"""
- Input name {inp} not found in the model's forward method. Available
- input names are: {all_args}"
- """
- raise ValueError(msg)
+ validate_torch_args(state)
# Most pytorch models have args that are kind = positional_or_keyword.
# The `torch.onnx.export()` function accepts model args as
@@ -252,6 +265,7 @@ def fire(self, state: build.State):
# the order of the input_names must reflect the order of the model args.
# Collect order of pytorch model args.
+ all_args = list(inspect.signature(state.model.forward).parameters.keys())
all_args_order_mapping = {arg: idx for idx, arg in enumerate(all_args)}
# Sort the user provided inputs with respect to model args and store as tuple.
@@ -620,3 +634,69 @@ def fire(self, state: build.State):
raise exp.StageError(msg)
return state
+
+
+class ExportToCoreML(stage.Stage):
+ """
+    Stage that takes a Pytorch model and inputs and converts them to CoreML format.
+
+ Expected inputs:
+ - state.model is a torch.nn.Module or torch.jit.ScriptModule
+ - state.inputs is a dict that represents valid kwargs to the forward
+ function of state.model
+ Outputs:
+ - A *.mlmodel file
+ """
+
+ def __init__(self):
+ super().__init__(
+ unique_name="coreml_conversion",
+ monitor_message="Converting to CoreML",
+ )
+
+ def fire(self, state: build.State):
+ if not isinstance(state.model, (torch.nn.Module, torch.jit.ScriptModule)):
+ msg = f"""
+ The current stage (ExportToCoreML) is only compatible with
+ models of type torch.nn.Module or torch.jit.ScriptModule, however
+ the stage received a model of type {type(state.model)}.
+ """
+ raise exp.StageError(msg)
+
+ if isinstance(state.model, torch.nn.Module):
+ validate_torch_args(state)
+
+ # Send warnings to stdout (and therefore the log file)
+ default_warnings = warnings.showwarning
+ warnings.showwarning = _warn_to_stdout
+
+ # Generate a TorchScript Version
+ dummy_inputs = copy.deepcopy(state.inputs)
+ traced_model = torch.jit.trace(state.model, example_kwarg_inputs=dummy_inputs)
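+        # Note: tracing/conversion can rename the model's graph inputs (e.g. "x"
+        # becomes "x_1"); within_conda.py renames the saved inputs to match before
+        # calling predict()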
+
+ # Export the model to CoreML
+ output_path = base_coreml_file(state)
+ os.makedirs(onnx_dir(state), exist_ok=True)
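+        # "neuralnetwork" yields a single .mlmodel file, matching base_coreml_file();
+        # the newer "mlprogram" format would instead produce an .mlpackage directory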
+ coreml_model = ct.convert(
+ traced_model,
+ inputs=[ct.TensorType(shape=inp.shape) for inp in dummy_inputs.values()],
+ convert_to="neuralnetwork",
+ )
+
+ # Save the CoreML model
+ coreml_model.save(output_path)
+
+        # Save output names to ensure we are preserving the order of the outputs.
+        # The CoreML spec is read directly because get_output_names() expects an
+        # ONNX model, not a *.mlmodel file.
+        state.expected_output_names = [
+            output.name for output in coreml_model.get_spec().description.output
+        ]
+
+ # Restore default warnings behavior
+ warnings.showwarning = default_warnings
+
+ tensor_helpers.save_inputs(
+ [state.inputs], state.original_inputs_file, downcast=False
+ )
+
+ # Save intermediate results
+ state.intermediate_results = [output_path]
+
+ return state
diff --git a/src/turnkeyml/build/sequences.py b/src/turnkeyml/build/sequences.py
index abeb159c..d4e8198d 100644
--- a/src/turnkeyml/build/sequences.py
+++ b/src/turnkeyml/build/sequences.py
@@ -32,6 +32,15 @@
enable_model_validation=True,
)
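+# Unlike the ONNX-based sequences above, the CoreML sequence converts the
+# Pytorch model directly to CoreML format in a single stage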
+coreml = stage.Sequence(
+ "coreml",
+ "CoreML Sequence",
+ [
+ export.ExportToCoreML(),
+ ],
+ enable_model_validation=True,
+)
+
# Plugin interface for sequences
discovered_plugins = plugins.discover()
@@ -40,6 +49,7 @@
"optimize-fp16": optimize_fp16,
"optimize-fp32": optimize_fp32,
"onnx-fp32": onnx_fp32,
+ "coreml": coreml,
}
# Add sequences from plugins to supported sequences dict
diff --git a/src/turnkeyml/run/basert.py b/src/turnkeyml/run/basert.py
index 05b69caf..30636ed1 100644
--- a/src/turnkeyml/run/basert.py
+++ b/src/turnkeyml/run/basert.py
@@ -53,6 +53,8 @@ def __init__(
requires_docker: bool = False,
tensor_type=np.array,
execute_function: Optional[callable] = None,
+ model_filename="model.onnx",
+ model_dirname="onnxmodel",
):
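+        # model_filename/model_dirname default to the ONNX conventions used by the
+        # existing runtimes, but can be overridden for other formats (e.g. CoreML's
+        # .mlmodel)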
self.tensor_type = tensor_type
self.cache_dir = cache_dir
@@ -71,8 +73,8 @@ def __init__(
self.runtime_version = runtime_version
self.model = model
self.inputs = inputs
- self.onnx_filename = "model.onnx"
- self.onnx_dirname = "onnxmodel"
+ self.model_filename = model_filename
+ self.model_dirname = model_dirname
self.outputs_filename = "outputs.json"
self.runtimes_supported = runtimes_supported
self.execute_function = execute_function
@@ -80,7 +82,7 @@ def __init__(
# Validate runtime is supported
if runtime not in runtimes_supported:
raise ValueError(
- f"'runtime' argument {runtime} passed to TensorRT, which only "
+ f"'runtime' argument {runtime} passed to a runtime that only "
f"supports runtimes: {runtimes_supported}"
)
@@ -102,23 +104,23 @@ def local_output_dir(self):
)
@property
- def local_onnx_dir(self):
- return os.path.join(self.local_output_dir, self.onnx_dirname)
+ def local_model_dir(self):
+ return os.path.join(self.local_output_dir, self.model_dirname)
@property
def docker_onnx_dir(self):
return self.posix_path_format(
- os.path.join(self.docker_output_dir, self.onnx_dirname)
+ os.path.join(self.docker_output_dir, self.model_dirname)
)
@property
- def local_onnx_file(self):
- return os.path.join(self.local_onnx_dir, self.onnx_filename)
+ def local_model_file(self):
+ return os.path.join(self.local_model_dir, self.model_filename)
@property
- def docker_onnx_file(self):
+ def docker_model_file(self):
return self.posix_path_format(
- os.path.join(self.docker_onnx_dir, self.onnx_filename)
+ os.path.join(self.docker_onnx_dir, self.model_filename)
)
@property
@@ -183,16 +185,16 @@ def benchmark(self) -> MeasuredPerformance:
raise exp.ModelRuntimeError(msg)
os.makedirs(self.local_output_dir, exist_ok=True)
- os.makedirs(self.local_onnx_dir, exist_ok=True)
- shutil.copy(model_file, self.local_onnx_file)
+ os.makedirs(self.local_model_dir, exist_ok=True)
+ shutil.copy(model_file, self.local_model_file)
# Execute benchmarking in hardware
if self.requires_docker:
_check_docker_install()
- onnx_file = self.docker_onnx_file
+ onnx_file = self.docker_model_file
_check_docker_running()
else:
- onnx_file = self.local_onnx_file
+ onnx_file = self.local_model_file
self._execute(
output_dir=self.local_output_dir,
diff --git a/src/turnkeyml/run/coreml/__init__.py b/src/turnkeyml/run/coreml/__init__.py
new file mode 100644
index 00000000..8f62e40b
--- /dev/null
+++ b/src/turnkeyml/run/coreml/__init__.py
@@ -0,0 +1,13 @@
+import turnkeyml.build.sequences as sequences
+from .runtime import CoreML
+
+implements = {
+ "runtimes": {
+ "coreml": {
+ "build_required": True,
+ "RuntimeClass": CoreML,
+ "supported_devices": {"apple_silicon"},
+ "default_sequence": sequences.coreml,
+ }
+ }
+}
diff --git a/src/turnkeyml/run/coreml/execute.py b/src/turnkeyml/run/coreml/execute.py
new file mode 100644
index 00000000..f4259acc
--- /dev/null
+++ b/src/turnkeyml/run/coreml/execute.py
@@ -0,0 +1,128 @@
+"""
+The following script is used to get the latency of a given run using CoreML on Apple Silicon.
+"""
+# pylint: disable = no-name-in-module
+# pylint: disable = import-error
+import os
+import subprocess
+import json
+from statistics import mean
+import platform
+import turnkeyml.run.plugin_helpers as plugin_helpers
+
+COREML_VERSION = "7.1"
+
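+# The benchmark currently runs with a fixed batch size of 1; throughput is
+# reported as BATCHSIZE / mean latency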
+BATCHSIZE = 1
+
+
+def create_conda_env(conda_env_name: str):
+ """Create a Conda environment with the given name and install requirements."""
+ conda_path = os.getenv("CONDA_EXE")
+ if conda_path is None:
+ raise EnvironmentError(
+            "CONDA_EXE environment variable not set. "
+ "Make sure Conda is properly installed."
+ )
+
+ env_path = os.path.join(
+ os.path.dirname(os.path.dirname(conda_path)), "envs", conda_env_name
+ )
+
+ # Only create the environment if it does not already exist
+ if not os.path.exists(env_path):
+ plugin_helpers.run_subprocess(
+ [
+ conda_path,
+ "create",
+ "--name",
+ conda_env_name,
+ "python=3.8",
+ "-y",
+ ]
+ )
+
+ # Using conda run to execute pip install within the environment
+ setup_cmd = [
+ conda_path,
+ "run",
+ "--name",
+ conda_env_name,
+ "pip",
+ "install",
+ f"coremltools=={COREML_VERSION}",
+ ]
+ plugin_helpers.run_subprocess(setup_cmd)
+
+
+def execute_benchmark(
+ coreml_file_path: str,
+ outputs_file: str,
+ output_dir: str,
+ conda_env_name: str,
+ iterations: int,
+):
+ """Execute the benchmark script and retrieve the output."""
+
+ python_in_env = plugin_helpers.get_python_path(conda_env_name)
+ iterations_file = os.path.join(output_dir, "per_iteration_latency.json")
+ benchmarking_log_file = os.path.join(output_dir, "coreml_benchmarking_log.txt")
+
+ cmd = [
+ python_in_env,
+ os.path.join(output_dir, "within_conda.py"),
+ "--coreml-file",
+ coreml_file_path,
+ "--iterations",
+ str(iterations),
+ "--iterations-file",
+ iterations_file,
+ ]
+
+ # Execute command and log stdout/stderr
+ plugin_helpers.logged_subprocess(
+ cmd=cmd,
+ cwd=os.path.dirname(output_dir),
+ log_to_std_streams=False,
+ log_to_file=True,
+ log_file_path=benchmarking_log_file,
+ )
+
+ # Parse per-iteration performance results and save aggregated results to a json file
+ if os.path.exists(iterations_file):
+ with open(iterations_file, "r", encoding="utf-8") as f:
+ per_iteration_latency = json.load(f)
+ else:
+ raise ValueError(
+ f"Execution of command {cmd} failed, see {benchmarking_log_file}"
+ )
+
+ cpu_performance = get_cpu_specs()
+ cpu_performance["CoreML Version"] = str(COREML_VERSION)
+ cpu_performance["Mean Latency(ms)"] = str(mean(per_iteration_latency) * 1000)
+ cpu_performance["Throughput"] = str(BATCHSIZE / mean(per_iteration_latency))
+ cpu_performance["Min Latency(ms)"] = str(min(per_iteration_latency) * 1000)
+ cpu_performance["Max Latency(ms)"] = str(max(per_iteration_latency) * 1000)
+
+ with open(outputs_file, "w", encoding="utf-8") as out_file:
+ json.dump(cpu_performance, out_file, ensure_ascii=False, indent=4)
+
+
+def get_cpu_specs() -> dict:
+ # Check the operating system and define the command accordingly
+ if platform.system() != "Darwin":
+        raise OSError("You must use MacOS to run models with CoreML.")
+
+ cpu_info_command = "sysctl -n machdep.cpu.brand_string"
+ cpu_info = subprocess.Popen(cpu_info_command.split(), stdout=subprocess.PIPE)
+ cpu_info_output, _ = cpu_info.communicate()
+ if not cpu_info_output:
+ raise EnvironmentError(
+ f"Could not get CPU info using '{cpu_info_command.split()[0]}'. "
+ "Please make sure this tool is correctly installed on your system before continuing."
+ )
+
+ # Store CPU specifications
+ decoded_info = cpu_info_output.decode().strip().split("\n")
+ cpu_spec = {"CPU Name": decoded_info[0]}
+
+ return cpu_spec
diff --git a/src/turnkeyml/run/coreml/runtime.py b/src/turnkeyml/run/coreml/runtime.py
new file mode 100644
index 00000000..0ea39681
--- /dev/null
+++ b/src/turnkeyml/run/coreml/runtime.py
@@ -0,0 +1,149 @@
+import platform
+import os
+import shutil
+import numpy as np
+from turnkeyml.run.basert import BaseRT
+import turnkeyml.common.exceptions as exp
+from turnkeyml.run.coreml.execute import (
+    COREML_VERSION,
+    create_conda_env,
+    execute_benchmark,
+)
+from turnkeyml.common.filesystem import Stats, rebase_cache_dir
+import turnkeyml.common.build as build
+from turnkeyml.common.performance import MeasuredPerformance
+import turnkeyml.run.plugin_helpers as plugin_helpers
+
+
+class CoreML(BaseRT):
+ def __init__(
+ self,
+ cache_dir: str,
+ build_name: str,
+ stats: Stats,
+ iterations: int,
+ device_type: str,
+ runtime: str = "coreml",
+ tensor_type=np.array,
+ model=None,
+ inputs=None,
+ ):
+ super().__init__(
+ cache_dir=cache_dir,
+ build_name=build_name,
+ stats=stats,
+ tensor_type=tensor_type,
+ device_type=device_type,
+ iterations=iterations,
+ runtime=runtime,
+ runtimes_supported=["coreml"],
+ runtime_version=COREML_VERSION,
+ base_path=os.path.dirname(__file__),
+ model=model,
+ inputs=inputs,
+ requires_docker=False,
+ model_filename="model.mlmodel",
+ model_dirname="mlmodel",
+ )
+
+ def _setup(self):
+ # Check OS
+ if platform.system() != "Darwin":
+ msg = "Only MacOS is supported for CoreML Runtime"
+ raise exp.ModelRuntimeError(msg)
+
+ # Check silicon
+ if "Apple M" not in self.device_name:
+ msg = (
+                "You need an 'Apple M*' processor to run using apple_silicon, "
+                f"got '{self.device_name}'"
+ )
+ raise exp.ModelRuntimeError(msg)
+
+ self._transfer_files([self.conda_script])
+
+ def benchmark(self) -> MeasuredPerformance:
+ """
+ Transfer input artifacts, execute model on hardware, analyze output artifacts,
+ and return the performance.
+ """
+
+ # Remove previous benchmarking artifacts
+ if os.path.exists(self.local_outputs_file):
+ os.remove(self.local_outputs_file)
+
+ # Transfer input artifacts
+ state = build.load_state(self.cache_dir, self.build_name)
+
+ # Just in case the model file was generated on a different machine:
+ # strip the state's cache dir, then prepend the current cache dir
+ model_file = rebase_cache_dir(
+ state.results[0], state.config.build_name, self.cache_dir
+ )
+
+ if not os.path.exists(model_file):
+ msg = "Model file not found"
+ raise exp.ModelRuntimeError(msg)
+
+ os.makedirs(self.local_output_dir, exist_ok=True)
+ os.makedirs(self.local_model_dir, exist_ok=True)
+ shutil.copy(model_file, self.local_model_file)
+
+ # Execute benchmarking in hardware
+ self._execute(
+ output_dir=self.local_output_dir,
+ coreml_file_path=self.local_model_file,
+ outputs_file=self.local_outputs_file,
+ )
+
+ if not os.path.isfile(self.local_outputs_file):
+ raise exp.BenchmarkException(
+ "No benchmarking outputs file found after benchmarking run. "
+ "Sorry we don't have more information."
+ )
+
+ # Call property methods to analyze the output artifacts for performance stats
+ # and return them
+ return MeasuredPerformance(
+ mean_latency=self.mean_latency,
+ throughput=self.throughput,
+ device=self.device_name,
+ device_type=self.device_type,
+ runtime=self.runtime,
+ runtime_version=self.runtime_version,
+ build_name=self.build_name,
+ )
+
+ def _execute(
+ self,
+ output_dir: str,
+ coreml_file_path: str,
+ outputs_file: str,
+ ):
+ conda_env_name = "turnkey-coreml-ep"
+
+ try:
+ # Create and setup the conda env
+ create_conda_env(conda_env_name)
+ except Exception as e:
+            raise plugin_helpers.CondaError(
+                f"Conda env setup failed with exception: {e}"
+            ) from e
+
+ # Execute the benchmark script in the conda environment
+ execute_benchmark(
+ coreml_file_path=coreml_file_path,
+ outputs_file=outputs_file,
+ output_dir=output_dir,
+ conda_env_name=conda_env_name,
+ iterations=self.iterations,
+ )
+
+ @property
+ def mean_latency(self):
+ return float(self._get_stat("Mean Latency(ms)"))
+
+ @property
+ def throughput(self):
+ return float(self._get_stat("Throughput"))
+
+ @property
+ def device_name(self):
+ return self._get_stat("CPU Name")
diff --git a/src/turnkeyml/run/coreml/within_conda.py b/src/turnkeyml/run/coreml/within_conda.py
new file mode 100644
index 00000000..7a8e91a2
--- /dev/null
+++ b/src/turnkeyml/run/coreml/within_conda.py
@@ -0,0 +1,66 @@
+import argparse
+import json
+import os
+import time
+from pathlib import Path
+import numpy as np
+import coremltools as ct
+
+
+def run_coreml_profile(
+ coreml_file_path: str,
+ iterations_file: str,
+ iterations: int,
+):
+    # Run the provided CoreML model using coremltools and measure per-iteration latency
+
+ per_iteration_latency = []
+
+ # Load the CoreML model
+ model = ct.models.MLModel(coreml_file_path)
+
+    # Load the inputs that were saved at build time; they live in the build
+    # directory, two levels above the copied .mlmodel file
+    inputs_path = os.path.join(Path(coreml_file_path).parents[2], "inputs.npy")
+    input_data = np.load(inputs_path, allow_pickle=True)[0]
+
+    # The converted model's input names carry a "_1" suffix (from the traced
+    # graph), so rename the keys to match
+    input_data = {key + "_1": value for key, value in input_data.items()}
+
+ # Run model for a certain number of iterations
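+    # Note: the first predict() call may include one-time model load/compile
+    # overhead, which will be reflected in the max latency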
+ for _ in range(iterations):
+ start = time.perf_counter()
+ model.predict(input_data)
+ end = time.perf_counter()
+ iteration_latency = end - start
+ per_iteration_latency.append(iteration_latency)
+
+ with open(iterations_file, "w", encoding="utf-8") as out_file:
+ json.dump(per_iteration_latency, out_file, ensure_ascii=False, indent=4)
+
+
+if __name__ == "__main__":
+ # Parse Inputs
+ parser = argparse.ArgumentParser(description="Execute models using coreml")
+ parser.add_argument(
+ "--coreml-file",
+ required=True,
+ help="Path where the coreml file is located",
+ )
+ parser.add_argument(
+ "--iterations-file",
+ required=True,
+ help="File in which to place the per-iteration execution timings",
+ )
+ parser.add_argument(
+ "--iterations",
+ required=True,
+ type=int,
+        help="Number of times to execute the received CoreML model",
+ )
+ args = parser.parse_args()
+
+ run_coreml_profile(
+ coreml_file_path=args.coreml_file,
+ iterations_file=args.iterations_file,
+ iterations=args.iterations,
+ )
diff --git a/src/turnkeyml/run/devices.py b/src/turnkeyml/run/devices.py
index c48d81b5..a7db351b 100644
--- a/src/turnkeyml/run/devices.py
+++ b/src/turnkeyml/run/devices.py
@@ -3,6 +3,7 @@
import turnkeyml.run.onnxrt as onnxrt
import turnkeyml.run.tensorrt as tensorrt
import turnkeyml.run.torchrt as torchrt
+import turnkeyml.run.coreml as coreml
import turnkeyml.common.plugins as plugins
from turnkeyml.build.stage import Sequence
import turnkeyml.build.sequences as sequences
@@ -30,7 +31,7 @@ def supported_devices_list(data: Dict, parent_key: str = "") -> List:
# Note: order matters here. We append the discovered_plugins after builtin so
# that the default runtime for each device will come from a builtin, whenever
# available.
-builtin_runtimes = [onnxrt, tensorrt, torchrt]
+builtin_runtimes = [onnxrt, tensorrt, torchrt, coreml]
plugin_modules = builtin_runtimes + list(discovered_plugins.values())
SUPPORTED_RUNTIMES = {}
diff --git a/src/turnkeyml/run/onnxrt/__init__.py b/src/turnkeyml/run/onnxrt/__init__.py
index 6ecc93de..86e14ef1 100644
--- a/src/turnkeyml/run/onnxrt/__init__.py
+++ b/src/turnkeyml/run/onnxrt/__init__.py
@@ -6,7 +6,7 @@
"ort": {
"build_required": True,
"RuntimeClass": OnnxRT,
- "supported_devices": {"x86"},
+            "supported_devices": {"x86", "apple_silicon"},
"default_sequence": sequences.optimize_fp32,
}
}
diff --git a/src/turnkeyml/run/onnxrt/execute.py b/src/turnkeyml/run/onnxrt/execute.py
index bdf3e60c..faf401b5 100644
--- a/src/turnkeyml/run/onnxrt/execute.py
+++ b/src/turnkeyml/run/onnxrt/execute.py
@@ -126,7 +126,8 @@ def get_cpu_specs() -> dict:
}
# Check the operating system and define the command accordingly
- if platform.system() == "Windows":
+ system_platform = platform.system()
+ if system_platform == "Windows":
cpu_info_command = (
"wmic CPU get Architecture,Manufacturer,MaxClockSpeed,"
"Name,NumberOfCores /format:list"
@@ -136,6 +137,9 @@ def get_cpu_specs() -> dict:
cpu_info_command, stdout=subprocess.PIPE, shell=True
)
separator = "="
+ elif system_platform == "Darwin":
+ cpu_info_command = "sysctl -n machdep.cpu.brand_string"
+ cpu_info = subprocess.Popen(cpu_info_command.split(), stdout=subprocess.PIPE)
else:
cpu_info_command = "lscpu"
cpu_info = subprocess.Popen(cpu_info_command.split(), stdout=subprocess.PIPE)
@@ -151,17 +155,20 @@ def get_cpu_specs() -> dict:
decoded_info = (
cpu_info_output.decode()
.strip()
- .split("\r\n" if platform.system() == "Windows" else "\n")
+ .split("\r\n" if system_platform == "Windows" else "\n")
)
# Initialize an empty dictionary to hold the CPU specifications
cpu_spec = {}
- for line in decoded_info:
- key, value = line.split(separator, 1)
- # Get the corresponding key from the field mapping
- key = field_mapping.get(key.strip())
- if key:
- # Add the key and value to the CPU specifications dictionary
- cpu_spec[key] = value.strip()
+ if system_platform != "Darwin":
+ for line in decoded_info:
+ key, value = line.split(separator, 1)
+ # Get the corresponding key from the field mapping
+ key = field_mapping.get(key.strip())
+ if key:
+ # Add the key and value to the CPU specifications dictionary
+ cpu_spec[key] = value.strip()
+ else:
+ cpu_spec["CPU Name"] = decoded_info[0]
return cpu_spec
diff --git a/src/turnkeyml/run/onnxrt/runtime.py b/src/turnkeyml/run/onnxrt/runtime.py
index 70eda897..b6ee1475 100644
--- a/src/turnkeyml/run/onnxrt/runtime.py
+++ b/src/turnkeyml/run/onnxrt/runtime.py
@@ -39,10 +39,21 @@ def __init__(
)
def _setup(self):
+ # Apple Silicon requires M* devices
+ if (
+ str(self.device_type) == "apple_silicon"
+ and "Apple M" not in self.device_name
+ ):
+ msg = (
+                "You need an 'Apple M*' processor to run using apple_silicon, "
+                f"got '{self.device_name}'"
+ )
+ raise exp.ModelRuntimeError(msg)
+
-        # Check if x86_64 (aka AMD64) CPU is available locally
+        # Check that a supported CPU architecture is available locally
         machine = platform.uname().machine
-        if machine != "x86_64" and machine != "AMD64":
-            msg = "Only x86_64 and AMD64 CPUs are supported, got {machine}"
-            raise exp.ModelRuntimeError(msg)
+        # Native Python on Apple Silicon reports arm64 rather than x86_64
+        if str(self.device_type) == "apple_silicon":
+            if machine != "arm64":
+                msg = f"apple_silicon requires an arm64 machine, got {machine}"
+                raise exp.ModelRuntimeError(msg)
+        elif machine != "x86_64" and machine != "AMD64":
+            msg = f"Only x86_64 and AMD64 CPUs are supported, got {machine}"
+            raise exp.ModelRuntimeError(msg)
self._transfer_files([self.conda_script])
diff --git a/src/turnkeyml/run/torchrt/__init__.py b/src/turnkeyml/run/torchrt/__init__.py
index dadc8d89..6987935a 100644
--- a/src/turnkeyml/run/torchrt/__init__.py
+++ b/src/turnkeyml/run/torchrt/__init__.py
@@ -5,7 +5,7 @@
"torch-eager": {
"build_required": False,
"RuntimeClass": TorchRT,
- "supported_devices": {"x86"},
+ "supported_devices": {"x86", "apple_silicon"},
},
"torch-compiled": {
"build_required": False,
diff --git a/src/turnkeyml/version.py b/src/turnkeyml/version.py
index 5becc17c..5c4105cd 100644
--- a/src/turnkeyml/version.py
+++ b/src/turnkeyml/version.py
@@ -1 +1 @@
-__version__ = "1.0.0"
+__version__ = "1.0.1"