
Commit bca23f5

Update
[ghstack-poisoned]

2 parents: d6a70d6 + 5afe90d

File tree

18 files changed (+480, -35 lines)


.github/workflows/apple.yml

Lines changed: 1 addition & 0 deletions
@@ -156,6 +156,7 @@ jobs:
         "kernels_llm"
         "kernels_optimized"
         "kernels_quantized"
+        "kernels_torchao"
         "threadpool"
       )

backends/apple/coreml/compiler/coreml_preprocess.py

Lines changed: 77 additions & 1 deletion
@@ -17,6 +17,10 @@
 import coremltools as ct
 import coremltools.optimize as cto
 from executorch.backends.apple.coreml import executorchcoreml
+from executorch.backends.apple.coreml.compiler.enumerated_shape_utils import (
+    _get_ct_inputs,
+    _SymbolicShapeToEnumeratedShapeMap,
+)
 from executorch.backends.apple.coreml.logging import get_coreml_log_level
 from executorch.exir.backend.backend_details import (
     BackendDetails,
@@ -37,6 +41,7 @@ class COMPILE_SPEC_KEYS(Enum):
     MIN_DEPLOYMENT_TARGET = "min_deployment_target"
     MODEL_COMPUTE_PRECISION = "model_compute_precision"
     OP_LINEAR_QUANTIZER_CONFIG = "op_linear_quantizer_config"
+    ENUMERATED_SHAPES = "enumerated_shapes"


 class MODEL_PATHS(Enum):
@@ -143,7 +148,7 @@ def generate_minimum_deployment_target_compile_spec(
     @staticmethod
     def min_deployment_target_from_compile_specs(
         compile_specs: List[CompileSpec],
-    ) -> ct.target:
+    ) -> Optional[ct.target]:
         """
         Returns the minimum deployment target by parsing the list of compile specs.
         """
@@ -214,6 +219,54 @@ def op_linear_quantizer_config_from_compile_specs(

         return None

+    @staticmethod
+    def generate_enumerated_shapes_compile_spec(
+        ep: ExportedProgram,
+        enumerated_shapes: Dict[str, List[List[int]]],
+    ) -> CompileSpec:
+        """
+        Returns the compile spec representing the model's enumerated shapes.
+        enumerated_shapes is a dictionary mapping each input to its enumerated shapes, e.g.,
+
+        enumerated_shapes = {
+            "x": [[1, 1, 24], [8, 9, 24]],
+            "y": [[1, 6], [30, 6]],
+        }
+
+        means the model can handle x with shape [1, 1, 24] or [8, 9, 24] and y with shape [1, 6] or [30, 6].
+
+        Multiple inputs can only have enumerated shapes if using iOS18 or later.
+        In this case, each input must have the same number of enumerated shapes, and these shapes are tied together
+        by their order in the list. For example, the model above can handle x with shape [1, 1, 24] and y with shape [1, 6],
+        or x with shape [8, 9, 24] and y with shape [30, 6], but not x with shape [1, 1, 24] and y with shape [30, 6].
+
+        Passing incorrect shapes at runtime will result in an error.
+        """
+        emap = _SymbolicShapeToEnumeratedShapeMap.from_exported_program(
+            ep,
+            enumerated_shapes,
+        )
+        str_representation = emap.to_json()
+        byte_representation = str_representation.encode("utf-8")
+        return CompileSpec(
+            COMPILE_SPEC_KEYS.ENUMERATED_SHAPES.value,
+            byte_representation,
+        )
+
+    @staticmethod
+    def enumerated_shapes_from_compile_specs(
+        compile_specs: List[CompileSpec],
+    ) -> Optional[_SymbolicShapeToEnumeratedShapeMap]:
+        """
+        Returns the model's enumerated shapes by parsing the list of compile specs.
+        """
+        for compile_spec in compile_specs:
+            if compile_spec.key == COMPILE_SPEC_KEYS.ENUMERATED_SHAPES.value:
+                emap_json = compile_spec.value.decode("utf-8")
+                emap = _SymbolicShapeToEnumeratedShapeMap.from_json(emap_json)
+                return emap
+        return None
+
     @staticmethod
     def generate_compile_specs(
         compute_unit: ct.ComputeUnit = ct.ComputeUnit.ALL,
@@ -446,6 +499,28 @@ def preprocess(
         op_linear_quantizer_config = (
             CoreMLBackend.op_linear_quantizer_config_from_compile_specs(compile_specs)
         )
+        enumerated_shapes = CoreMLBackend.enumerated_shapes_from_compile_specs(
+            compile_specs
+        )
+
+        # If using enumerated shapes, we need to pass the inputs to CoreML's convert()
+        # function explicitly
+        ct_inputs = None
+        if enumerated_shapes is not None:
+            ct_inputs = _get_ct_inputs(edge_program, enumerated_shapes)
+
+            # Check there are not multiple enumerated inputs if iOS is below 18
+            if (minimum_deployment_target is None) or (
+                minimum_deployment_target < ct.target.iOS18
+            ):
+                n_enumerated_inputs = 0
+                for ct_in in ct_inputs:
+                    if isinstance(ct_in.shape, ct.EnumeratedShapes):
+                        n_enumerated_inputs += 1
+                if n_enumerated_inputs > 1:
+                    raise ValueError(
+                        f"Your program has {n_enumerated_inputs} enumerated inputs, but minimum_deployment_target is set to {minimum_deployment_target}. Multiple enumerated inputs require iOS18 or later."
+                    )

         # Load the model if MODEL_TYPE is 'COMPILED_MODEL'. This step is necessary because
         # get_compiled_model_path() requires a loaded model.
@@ -459,6 +534,7 @@ def preprocess(
             compute_precision=model_compute_precision,
             minimum_deployment_target=minimum_deployment_target,
             compute_units=compute_units,
+            inputs=ct_inputs,
         )

         if op_linear_quantizer_config is not None:
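
Taken together, the additions above let an exporter attach an enumerated-shapes compile spec that preprocess() later decodes and forwards to ct.convert(). The sketch below shows the intended call pattern under stated assumptions: the toy module, input names, dims, and shape lists are hypothetical; CoreMLBackend.generate_compile_specs, generate_enumerated_shapes_compile_spec, torch.export, and ct.target.iOS18 come from this diff and their respective libraries.

# Hypothetical sketch of producing compile specs with enumerated shapes.
import coremltools as ct
import torch
from executorch.backends.apple.coreml.compiler import CoreMLBackend


class ToyModel(torch.nn.Module):  # illustrative module, not part of this commit
    def forward(self, x, y):
        return x + 1.0, y * 2.0


batch_x = torch.export.Dim("batch_x", min=1, max=64)
batch_y = torch.export.Dim("batch_y", min=1, max=64)
ep = torch.export.export(
    ToyModel(),
    (torch.randn(8, 24), torch.randn(30, 6)),
    dynamic_shapes={"x": {0: batch_x}, "y": {0: batch_y}},
)

# Multiple enumerated inputs require an iOS18+ deployment target (checked in preprocess()).
compile_specs = CoreMLBackend.generate_compile_specs(
    minimum_deployment_target=ct.target.iOS18,
)
compile_specs.append(
    CoreMLBackend.generate_enumerated_shapes_compile_spec(
        ep,
        {"x": [[1, 24], [8, 24]], "y": [[1, 6], [30, 6]]},
    )
)
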
backends/apple/coreml/compiler/enumerated_shape_utils.py

Lines changed: 244 additions & 0 deletions

@@ -0,0 +1,244 @@
import json
from dataclasses import asdict, dataclass
from typing import Optional, Tuple

import coremltools as ct
import torch
from coremltools.converters.mil.frontend.torch.utils import TORCH_DTYPE_TO_MIL_DTYPE

_IGNORE_RANGE_CONSTRAINTS: bool = False


@dataclass(frozen=True, slots=True)
class _SymInt:
    key_name: str
    low: Optional[int]
    high: Optional[int]

    @classmethod
    def from_symint_and_range_constraints(cls, s: torch.SymInt, range_constraints=None):
        # Canonicalize: "Sym(s0)" -> "s0", or leave "s0" as is
        def _symkey(sym: torch.SymInt) -> str:
            s = str(sym)
            return s[4:-1] if s.startswith("Sym(") and s.endswith(")") else s

        # Convert symint to int. Infinity is converted to None
        def _as_int_or_none(b):
            if b is None:
                return None
            s = str(b)
            if s in {"int_oo", "-int_oo", "oo", "-oo", "Infinity", "-Infinity"}:
                return None
            return int(s)

        # Get low/high from range_constraints if provided
        low, high = None, None
        if range_constraints is not None:
            for k, v in range_constraints.items():
                if _symkey(k) == _symkey(s):
                    low = _as_int_or_none(getattr(v, "lower", getattr(v, "min", None)))
                    high = _as_int_or_none(getattr(v, "upper", getattr(v, "max", None)))
        return _SymInt(_symkey(s), low, high)


@dataclass(frozen=True, slots=True)
class _SymbolicShape:
    shape: Tuple[int | _SymInt]

    @classmethod
    def from_shape_and_range_constraints(cls, shape, range_constraints=None):
        out_shape = []
        for s in shape:
            if isinstance(s, int):
                assert s >= 0
                out_shape.append(s)
            elif isinstance(s, torch.SymInt):
                out_shape.append(
                    _SymInt.from_symint_and_range_constraints(s, range_constraints)
                )
            else:
                raise ValueError(f"Unexpected type found in shape: {type(s)}")
        out_shape = tuple(out_shape)
        return _SymbolicShape(out_shape)

    def is_static_shape(self):
        for s in self.shape:
            if isinstance(s, _SymInt):
                return False
        return True

    def __len__(self):
        return len(self.shape)

    def __getitem__(self, key):
        return self.shape[key]

    def to_dict(self):
        return asdict(self)

    @classmethod
    def from_dict(cls, d):
        assert len(d) == 1 and "shape" in d
        shape = []
        for s in d["shape"]:
            if isinstance(s, int):
                shape.append(s)
            elif isinstance(s, dict):
                assert len(s) == 3 and "key_name" in s and "low" in s and "high" in s
                shape.append(_SymInt(**s))
            else:
                raise ValueError(f"Unexpected type found in shape: {type(s)}")
        shape = tuple(shape)
        return _SymbolicShape(shape)


def _iterate_over_fake_user_inputs(ep):
    user_inputs = ep.graph_signature.user_inputs
    for node in ep.graph.nodes:
        if node.op == "placeholder" and node.name in user_inputs:
            yield (node.name, node.meta["val"])


def _create_enumeration_map(ep, enumerated_shapes, *, ignore_range_constraints=False):
    # Each input should have the same number of enumerations
    assert len(enumerated_shapes) > 0, "No enumerated shapes provided"
    num_enumerations = None
    for name, eshapes in enumerated_shapes.items():
        if num_enumerations is None:
            num_enumerations = len(eshapes)
        else:
            assert (
                len(eshapes) > 1
            ), f"Input {name} only has {len(eshapes)} enumerated shapes provided. You should not specify enumerated shapes for inputs with only 1 shape."
            assert (
                len(eshapes) == num_enumerations
            ), f"Input {name} has {len(eshapes)} enumerated shapes provided, but other inputs have {num_enumerations} enumerated shapes"

    symbolic_shape_to_enumerations = {}
    for name, fake_input in _iterate_over_fake_user_inputs(ep):
        shape = fake_input.shape
        serialized_shape = _SymbolicShape.from_shape_and_range_constraints(
            shape, ep.range_constraints if not ignore_range_constraints else None
        )
        if serialized_shape.is_static_shape():
            continue
        # Shape is dynamic
        if name not in enumerated_shapes:
            raise ValueError(
                f"The input {name} has a symbolic shape, but you did not provide an enumeration for it"
            )
        # Validate
        for eshape in enumerated_shapes[name]:
            assert len(serialized_shape) == len(
                eshape
            ), f"In {name}, the rank of the enumeration is {len(eshape)}, but the symbolic shape has rank {len(serialized_shape)}"
            for i in range(len(eshape)):
                assert isinstance(
                    eshape[i], int
                ), f"Enumerated shapes must be ints, but got {type(eshape[i])}."
                assert eshape[i] >= 1, "Each enumerated shape dimension must be >= 1"
                if isinstance(serialized_shape[i], int):
                    assert (
                        serialized_shape[i] == eshape[i]
                    ), f"In {name}, the shape enumeration {eshape} does not match {shape} on the non-symbolic value at index {i}"
                else:
                    # Check eshape is within bound
                    if serialized_shape[i].low is not None:
                        # We add a special case for when the low bound is 2. This is because Torch does not usually allow 1 as a lower bound
                        assert (eshape[i] >= serialized_shape[i].low) or (
                            eshape[i] == 1 and serialized_shape[i].low == 2
                        ), f"In {name}, the shape enumeration {eshape} violates the lower range-constraint on the symbolic shape {shape} at index {i}"
                    if serialized_shape[i].high is not None:
                        assert (
                            eshape[i] <= serialized_shape[i].high
                        ), f"In {name}, the shape enumeration {eshape} violates the upper range-constraint on the symbolic shape {shape} at index {i}"
        if serialized_shape in symbolic_shape_to_enumerations:
            enumerations, names = symbolic_shape_to_enumerations[serialized_shape]
            assert (
                enumerations == enumerated_shapes[name]
            ), f"The symbolic shape {shape} has multiple enumerations defined. A new enumeration is defined for input {name}, but the existing inputs {names} have a different one defined. If these inputs have different enumerations, they should be exported with different symbolic shapes."
            names.append(name)
            symbolic_shape_to_enumerations[serialized_shape] = (enumerations, names)
        else:
            symbolic_shape_to_enumerations[serialized_shape] = (
                enumerated_shapes[name],
                [name],
            )
    return symbolic_shape_to_enumerations


class _SymbolicShapeToEnumeratedShapeMap:
    def __init__(self, emap):
        self.emap = emap

    def to_json(self):
        json_list = []
        for k in self.emap:
            json_list.append((k.to_dict(), self.emap[k]))
        return json.dumps(json_list)

    @classmethod
    def from_json(cls, s):
        emap = {}
        json_list = json.loads(s)
        for k, v in json_list:
            k = _SymbolicShape.from_dict(k)
            emap[k] = tuple(v)
        return cls(emap)

    @classmethod
    def from_exported_program(
        cls,
        ep,
        enumerated_shapes,
        *,
        ignore_range_constraints=_IGNORE_RANGE_CONSTRAINTS,
    ):
        emap = _create_enumeration_map(
            ep, enumerated_shapes, ignore_range_constraints=ignore_range_constraints
        )
        return cls(emap)

    def __getitem__(self, key: _SymbolicShape):
        return self.emap[key][0]

    def __contains__(self, key):
        return key in self.emap

    def __repr__(self):
        return f"_SymbolicShapeToEnumeratedShapeMap(emap={self.emap})"


def _get_ct_inputs(ep, emap: _SymbolicShapeToEnumeratedShapeMap):
    ct_inputs = []
    for name, fake_input in _iterate_over_fake_user_inputs(ep):

        # CoreML can do funny conversions in ct.convert (e.g., int64 -> int32, int16 -> int32), so here
        # we restrict users to use dtypes we know are supported
        _ENUMERATED_SHAPE_INPUT_DTYPES = [torch.float16, torch.float32, torch.int32]
        for dtype in _ENUMERATED_SHAPE_INPUT_DTYPES:
            assert dtype in TORCH_DTYPE_TO_MIL_DTYPE
        assert (
            fake_input.dtype in _ENUMERATED_SHAPE_INPUT_DTYPES
        ), f"When using enumerated shapes, all inputs must have one of the following dtypes {_ENUMERATED_SHAPE_INPUT_DTYPES}, but {name} has dtype {fake_input.dtype}"

        ct_dtype = TORCH_DTYPE_TO_MIL_DTYPE[fake_input.dtype]
        shape = fake_input.shape
        serializable_shape = _SymbolicShape.from_shape_and_range_constraints(
            shape, ep.range_constraints if not _IGNORE_RANGE_CONSTRAINTS else None
        )
        if serializable_shape.is_static_shape():
            ct_inputs.append(
                ct.TensorType(name=name, shape=serializable_shape.shape, dtype=ct_dtype)
            )
            continue
        # Dynamic shape
        assert (
            serializable_shape in emap
        ), f"The shape of input {name} ({serializable_shape}) is not in the _SymbolicShapeToEnumeratedShapeMap={emap}"
        enumerations = emap[serializable_shape]
        ct_enumerated_shape = ct.EnumeratedShapes(shapes=enumerations)
        ct_inputs.append(
            ct.TensorType(name=name, shape=ct_enumerated_shape, dtype=ct_dtype)
        )
    return ct_inputs
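
The map above is exactly what CoreMLBackend serializes into the compile-spec bytes, so a quick way to reason about it is the JSON round trip. A minimal sketch, assuming the hypothetical ep and shape lists from the earlier example; the helper names are the private ones defined in this file:

from executorch.backends.apple.coreml.compiler.enumerated_shape_utils import (
    _get_ct_inputs,
    _SymbolicShapeToEnumeratedShapeMap,
)

# Build the map from the (hypothetical) exported program and its enumerations.
emap = _SymbolicShapeToEnumeratedShapeMap.from_exported_program(
    ep,
    {"x": [[1, 24], [8, 24]], "y": [[1, 6], [30, 6]]},
)

# generate_enumerated_shapes_compile_spec() stores this JSON (UTF-8 encoded) as the spec value;
# enumerated_shapes_from_compile_specs() decodes it back with from_json().
payload = emap.to_json()
restored = _SymbolicShapeToEnumeratedShapeMap.from_json(payload)

# preprocess() then turns the restored map into explicit ct.convert() inputs.
ct_inputs = _get_ct_inputs(ep, restored)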

backends/apple/coreml/partition/coreml_partitioner.py

Lines changed: 10 additions & 0 deletions
@@ -10,6 +10,9 @@
 import torch

 from executorch.backends.apple.coreml.compiler import CoreMLBackend
+from executorch.backends.apple.coreml.compiler.coreml_preprocess import (
+    COMPILE_SPEC_KEYS,
+)

 from executorch.backends.apple.coreml.logging import get_coreml_log_level
 from executorch.exir.backend.compile_spec_schema import CompileSpec

@@ -192,6 +195,13 @@ def __init__(
         if skip_ops_for_coreml_delegation is None:
             skip_ops_for_coreml_delegation = []
         self.skip_ops_for_coreml_delegation = skip_ops_for_coreml_delegation
+
+        for compile_spec in compile_specs or []:
+            if compile_spec.key == COMPILE_SPEC_KEYS.ENUMERATED_SHAPES.value:
+                assert (
+                    lower_full_graph
+                ), "lower_full_graph must be True in the CoreMLPartitioner when using an enumerated shape compile spec"
+
         self.delegation_spec = DelegationSpec(
             backend_id=CoreMLBackend.__name__,
             compile_specs=compile_specs if compile_specs is not None else [],
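
Because the enumerated-shapes compile spec assumes the whole graph runs in the Core ML delegate, the partitioner now asserts lower_full_graph. A hedged end-to-end sketch, continuing the hypothetical ep and compile_specs from the earlier examples and using the standard to_edge_transform_and_lower entry point:

from executorch.backends.apple.coreml.partition import CoreMLPartitioner
from executorch.exir import to_edge_transform_and_lower

partitioner = CoreMLPartitioner(
    compile_specs=compile_specs,  # includes the enumerated-shapes spec from above
    lower_full_graph=True,        # required when an enumerated-shapes spec is present
)
executorch_program = to_edge_transform_and_lower(
    ep,
    partitioner=[partitioner],
).to_executorch()

# At runtime the inputs must match one of the enumerated combinations exactly,
# e.g. x of shape [1, 24] with y of shape [1, 6], or x of shape [8, 24] with y of shape [30, 6];
# any other shapes result in an error.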
