import json
from dataclasses import asdict, dataclass
from typing import Optional, Tuple

import coremltools as ct
import torch
from coremltools.converters.mil.frontend.torch.utils import TORCH_DTYPE_TO_MIL_DTYPE

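# When True, the range constraints recorded in the ExportedProgram are ignored
# when serializing symbolic shapes and validating enumerations against them.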
_IGNORE_RANGE_CONSTRAINTS: bool = False


@dataclass(frozen=True, slots=True)
class _SymInt:
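    """A JSON-serializable stand-in for one symbolic dimension: the symbol's
    name (e.g. "s0") plus optional lower/upper range constraints, where None
    means unbounded."""
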
    key_name: str
    low: Optional[int]
    high: Optional[int]

    @classmethod
    def from_symint_and_range_constraints(cls, s: torch.SymInt, range_constraints=None):
        # Canonicalize: "Sym(s0)" -> "s0"; leave "s0" as is
        def _symkey(sym) -> str:
            text = str(sym)
            return text[4:-1] if text.startswith("Sym(") and text.endswith(")") else text

        # Convert a bound to an int. Infinity (and None) is converted to None
        def _as_int_or_none(b):
            if b is None:
                return None
            text = str(b)
            if text in {"int_oo", "-int_oo", "oo", "-oo", "Infinity", "-Infinity"}:
                return None
            return int(text)

        # Look up low/high in range_constraints, if provided
        low, high = None, None
        if range_constraints is not None:
            for k, v in range_constraints.items():
                if _symkey(k) == _symkey(s):
                    low = _as_int_or_none(getattr(v, "lower", getattr(v, "min", None)))
                    high = _as_int_or_none(getattr(v, "upper", getattr(v, "max", None)))
                    break
        return _SymInt(_symkey(s), low, high)


@dataclass(frozen=True, slots=True)
class _SymbolicShape:
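    """A hashable, JSON-serializable shape: a tuple whose entries are concrete
    ints or _SymInt placeholders for symbolic dimensions."""
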
    shape: Tuple[int | _SymInt, ...]

    @classmethod
    def from_shape_and_range_constraints(cls, shape, range_constraints=None):
        out_shape = []
        for s in shape:
            if isinstance(s, int):
                assert s >= 0
                out_shape.append(s)
            elif isinstance(s, torch.SymInt):
                out_shape.append(
                    _SymInt.from_symint_and_range_constraints(s, range_constraints)
                )
            else:
                raise ValueError(f"Unexpected type found in shape: {type(s)}")
        return _SymbolicShape(tuple(out_shape))

    def is_static_shape(self):
        return not any(isinstance(s, _SymInt) for s in self.shape)

    def __len__(self):
        return len(self.shape)

    def __getitem__(self, key):
        return self.shape[key]

    def to_dict(self):
        return asdict(self)

    @classmethod
    def from_dict(cls, d):
        assert len(d) == 1 and "shape" in d
        shape = []
        for s in d["shape"]:
            if isinstance(s, int):
                shape.append(s)
            elif isinstance(s, dict):
                assert len(s) == 3 and "key_name" in s and "low" in s and "high" in s
                shape.append(_SymInt(**s))
            else:
                raise ValueError(f"Unexpected type found in shape: {type(s)}")
        return _SymbolicShape(tuple(shape))


def _iterate_over_fake_user_inputs(ep):
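    """Yield (name, FakeTensor) pairs for the user inputs of an ExportedProgram,
    skipping parameter/buffer placeholders."""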
    user_inputs = ep.graph_signature.user_inputs
    for node in ep.graph.nodes:
        if node.op == "placeholder" and node.name in user_inputs:
            yield (node.name, node.meta["val"])


def _create_enumeration_map(ep, enumerated_shapes, *, ignore_range_constraints=False):
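    """Validate the user-provided enumerated shapes against the exported
    program's symbolic input shapes, and map each distinct symbolic shape to
    (enumerations, [names of the inputs that share it])."""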
    # Each input should have the same number of enumerations
    assert len(enumerated_shapes) > 0, "No enumerated shapes provided"
    num_enumerations = None
    for name, eshapes in enumerated_shapes.items():
        assert (
            len(eshapes) > 1
        ), f"Input {name} only has {len(eshapes)} enumerated shape(s) provided. Do not specify enumerated shapes for an input with a single shape; use a static shape instead."
        if num_enumerations is None:
            num_enumerations = len(eshapes)
        assert (
            len(eshapes) == num_enumerations
        ), f"Input {name} has {len(eshapes)} enumerated shapes provided, but other inputs have {num_enumerations} enumerated shapes"

    symbolic_shape_to_enumerations = {}
    for name, fake_input in _iterate_over_fake_user_inputs(ep):
        shape = fake_input.shape
        serialized_shape = _SymbolicShape.from_shape_and_range_constraints(
            shape, ep.range_constraints if not ignore_range_constraints else None
        )
        if serialized_shape.is_static_shape():
            continue
        # Shape is dynamic
        if name not in enumerated_shapes:
            raise ValueError(
                f"The input {name} has a symbolic shape, but you did not provide an enumeration for it"
            )
        # Validate
        for eshape in enumerated_shapes[name]:
            assert len(serialized_shape) == len(
                eshape
            ), f"In {name}, the rank of the enumeration is {len(eshape)}, but the symbolic shape has rank {len(serialized_shape)}"
            for i in range(len(eshape)):
                assert isinstance(
                    eshape[i], int
                ), f"Enumerated shapes must be ints, but got {type(eshape[i])}."
                assert eshape[i] >= 1, "Each enumerated shape dimension must be >= 1"
                if isinstance(serialized_shape[i], int):
                    assert (
                        serialized_shape[i] == eshape[i]
                    ), f"In {name}, the shape enumeration {eshape} does not match {shape} on the non-symbolic value at index {i}"
                else:
                    # Check eshape is within bounds. We special-case a lower bound
                    # of 2: torch.export specializes on 0/1, so a dimension that may
                    # be 1 is usually recorded with a lower bound of 2.
                    if serialized_shape[i].low is not None:
                        assert (eshape[i] >= serialized_shape[i].low) or (
                            eshape[i] == 1 and serialized_shape[i].low == 2
                        ), f"In {name}, the shape enumeration {eshape} violates the lower range constraint on the symbolic shape {shape} at index {i}"
                    if serialized_shape[i].high is not None:
                        assert (
                            eshape[i] <= serialized_shape[i].high
                        ), f"In {name}, the shape enumeration {eshape} violates the upper range constraint on the symbolic shape {shape} at index {i}"
        if serialized_shape in symbolic_shape_to_enumerations:
            enumerations, names = symbolic_shape_to_enumerations[serialized_shape]
            assert (
                enumerations == enumerated_shapes[name]
            ), f"The symbolic shape {shape} has multiple enumerations defined: a new enumeration is defined for input {name}, but the existing inputs {names} have a different one. If these inputs need different enumerations, export them with different symbolic shapes."
            names.append(name)
            symbolic_shape_to_enumerations[serialized_shape] = (enumerations, names)
        else:
            symbolic_shape_to_enumerations[serialized_shape] = (
                enumerated_shapes[name],
                [name],
            )
    return symbolic_shape_to_enumerations


class _SymbolicShapeToEnumeratedShapeMap:
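    """Wraps the symbolic-shape -> enumerations map with JSON round-tripping,
    so the enumeration choices can be persisted alongside an exported model."""
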
    def __init__(self, emap):
        self.emap = emap

    def to_json(self):
        json_list = []
        for k, v in self.emap.items():
            json_list.append((k.to_dict(), v))
        return json.dumps(json_list)

    @classmethod
    def from_json(cls, s):
        emap = {}
        json_list = json.loads(s)
        for k, v in json_list:
            k = _SymbolicShape.from_dict(k)
            emap[k] = tuple(v)
        return cls(emap)

    @classmethod
    def from_exported_program(
        cls,
        ep,
        enumerated_shapes,
        *,
        ignore_range_constraints=_IGNORE_RANGE_CONSTRAINTS,
    ):
        emap = _create_enumeration_map(
            ep, enumerated_shapes, ignore_range_constraints=ignore_range_constraints
        )
        return cls(emap)

    def __getitem__(self, key: _SymbolicShape):
        return self.emap[key][0]

    def __contains__(self, key):
        return key in self.emap

    def __repr__(self):
        return f"_SymbolicShapeToEnumeratedShapeMap(emap={self.emap})"


def _get_ct_inputs(ep, emap: _SymbolicShapeToEnumeratedShapeMap):
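    """Build the `inputs` list for ct.convert: static shapes become plain
    ct.TensorType entries, while symbolic shapes are looked up in `emap` and
    converted to ct.EnumeratedShapes."""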
    # CoreML can do surprising dtype conversions in ct.convert (e.g., int64 -> int32,
    # int16 -> int32), so restrict users to dtypes we know are supported.
    _ENUMERATED_SHAPE_INPUT_DTYPES = [torch.float16, torch.float32, torch.int32]
    for dtype in _ENUMERATED_SHAPE_INPUT_DTYPES:
        assert dtype in TORCH_DTYPE_TO_MIL_DTYPE

    ct_inputs = []
    for name, fake_input in _iterate_over_fake_user_inputs(ep):
        assert (
            fake_input.dtype in _ENUMERATED_SHAPE_INPUT_DTYPES
        ), f"When using enumerated shapes, all inputs must have one of the following dtypes {_ENUMERATED_SHAPE_INPUT_DTYPES}, but {name} has dtype {fake_input.dtype}"

        ct_dtype = TORCH_DTYPE_TO_MIL_DTYPE[fake_input.dtype]
        shape = fake_input.shape
        serializable_shape = _SymbolicShape.from_shape_and_range_constraints(
            shape, ep.range_constraints if not _IGNORE_RANGE_CONSTRAINTS else None
        )
        if serializable_shape.is_static_shape():
            ct_inputs.append(
                ct.TensorType(name=name, shape=serializable_shape.shape, dtype=ct_dtype)
            )
            continue
        # Dynamic shape
        assert (
            serializable_shape in emap
        ), f"The shape of input {name} ({serializable_shape}) is not in the _SymbolicShapeToEnumeratedShapeMap={emap}"
        enumerations = emap[serializable_shape]
        ct_enumerated_shape = ct.EnumeratedShapes(shapes=enumerations)
        ct_inputs.append(
            ct.TensorType(name=name, shape=ct_enumerated_shape, dtype=ct_dtype)
        )
    return ct_inputs
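

if __name__ == "__main__":
    # Minimal end-to-end sketch of how these helpers compose. Illustrative only:
    # the toy model and shapes below are hypothetical, and it assumes a
    # coremltools version (>= 8.0) whose ct.convert accepts a torch.export
    # ExportedProgram.
    class _Toy(torch.nn.Module):
        def forward(self, x):
            return x + 1.0

    # torch.export typically records 2 (not 1) as a lower bound, which is why
    # _create_enumeration_map special-cases enumerations containing 1.
    batch = torch.export.Dim("batch", min=2, max=8)
    ep = torch.export.export(
        _Toy(), (torch.zeros(4, 4),), dynamic_shapes={"x": {0: batch}}
    )
    emap = _SymbolicShapeToEnumeratedShapeMap.from_exported_program(
        ep, {"x": [(1, 4), (8, 4)]}
    )
    mlmodel = ct.convert(ep, inputs=_get_ct_inputs(ep, emap))
    print(mlmodel)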