From 1de0d2300529a593639cb92ad5ddcd966212886d Mon Sep 17 00:00:00 2001
From: Kasper Nielsen
Date: Sun, 23 Nov 2025 19:29:38 +0100
Subject: [PATCH 1/4] Illustrate issue

---
 .../passes/tests/test_quantization_passes.py | 48 +++++++++++++++++++
 1 file changed, 48 insertions(+)

diff --git a/coremltools/converters/mil/mil/passes/tests/test_quantization_passes.py b/coremltools/converters/mil/mil/passes/tests/test_quantization_passes.py
index dac40f137..3985d9fde 100644
--- a/coremltools/converters/mil/mil/passes/tests/test_quantization_passes.py
+++ b/coremltools/converters/mil/mil/passes/tests/test_quantization_passes.py
@@ -2853,6 +2853,54 @@ def prog():
         assert cast_op.dtype.val == "uint16"
         assert cast_op.outputs[0] == block.find_ops(op_type="gather")[0].indices
 
+    def test_gather_along_axis_overflow_second_input(self):
+        """First input is safe to cast, but second input is not."""
+
+        @mb.program(
+            input_specs=[
+                mb.TensorSpec(shape=(7,), dtype=types.int32),
+                mb.TensorSpec(shape=(7,), dtype=types.int32),
+            ],
+            opset_version=ct.target.iOS17
+        )
+        def prog(a, b):
+            cast_0 = mb.cast(x=b, dtype="fp32")
+            mul_0 = mb.mul(x=[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0], y=0.07142857142857142)
+            add_0 = mb.add(x=cast_0, y=mul_0)
+            argsort_0 = mb.argsort(x=add_0, axis=0, ascending=True)
+            gather_along_axis_0 = mb.gather_along_axis(
+                x=a, indices=argsort_0, axis=0, validate_indices=False
+            )
+            cast_1 = mb.cast(x=gather_along_axis_0, dtype="fp32")
+            mul_1 = mb.mul(x=[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0], y=0.07142857142857142)
+            add_1 = mb.add(x=cast_1, y=mul_1)
+            argsort_1 = mb.argsort(x=add_1, axis=0, ascending=True)
+            gather_along_axis_1 = mb.gather_along_axis(
+                x=argsort_0, indices=argsort_1, axis=0, validate_indices=False
+            )
+            gather_along_axis_2 = mb.gather_along_axis(
+                x=a, indices=gather_along_axis_1, axis=0, validate_indices=False
+            )
+            gather_along_axis_3 = mb.gather_along_axis(
+                x=b, indices=gather_along_axis_1, axis=0, validate_indices=False
+            )
+            return gather_along_axis_2, gather_along_axis_3
+
+        # prev_prog, _, block = apply_pass_and_basic_check(prog, "common::add_int16_cast")
+        # assert get_op_types_in_program(prog) == get_op_types_in_program(prev_prog)
+
+        # prev_model = ct.convert(prev_prog, minimum_deployment_target=ct.target.iOS17)
+        model = ct.convert(prog, minimum_deployment_target=ct.target.iOS17)
+
+        # prev_output = list(prev_model.predict({"x": x, "y": y}).values())[0]
+
+        a = np.array([1, 3, 1, 4, 3, 5, 4], dtype=np.int32)
+        b = np.array([0, 4, 0, 4, 0, -21, -12], dtype=np.int32)
+        output1, output2 = tuple(model.predict({"a": a, "b": b}).values())
+        # The CoreML program implements a stable multi-key sort; ensure the output is sorted by a, then by b
+        assert np.array_equal(output2, np.array([1, 1, 3, 3, 4, 4, 5], dtype=np.int32))
+        assert np.array_equal(output1, np.array([0, 0, 0, 4, -12, 4, -21], dtype=np.int32))
+
     @pytest.mark.parametrize(
         "dtype, opset_version",
         itertools.product(
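 
For context, the symptom this test pins down is modular wraparound when int32
values are cast to uint16. A minimal numpy sketch (illustration only, not part
of the patch) of that wraparound on the test's `b` input:

    import numpy as np

    b = np.array([0, 4, 0, 4, 0, -21, -12], dtype=np.int32)
    # uint16 cannot represent negatives; numpy wraps them modulo 2**16.
    print(b.astype(np.uint16))  # [    0     4     0     4     0 65515 65524]
    # int16 covers [-32768, 32767], so these values survive the round trip.
    print(b.astype(np.int16))   # [  0   4   0   4   0 -21 -12]
 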
From 5903ef8869ca966104d9ff5f40d01671ddfbaccc Mon Sep 17 00:00:00 2001
From: Kasper Nielsen
Date: Sun, 23 Nov 2025 20:16:13 +0100
Subject: [PATCH 2/4] Attempt to fix side-effect in should_cast_parameter

---
 .../mil/mil/passes/defs/quantization.py | 81 ++++++++++---------
 1 file changed, 42 insertions(+), 39 deletions(-)

diff --git a/coremltools/converters/mil/mil/passes/defs/quantization.py b/coremltools/converters/mil/mil/passes/defs/quantization.py
index 4dec5cc6c..ff1c33e21 100644
--- a/coremltools/converters/mil/mil/passes/defs/quantization.py
+++ b/coremltools/converters/mil/mil/passes/defs/quantization.py
@@ -5,7 +5,7 @@
 
 from abc import abstractmethod
 from enum import Enum as _Enum
-from typing import Dict, Set, Text, Tuple
+from typing import Dict, Set, Text, Tuple, Optional
 
 import numpy as np
 
@@ -248,22 +248,30 @@ def transform_function_signatures(self, func: Function) -> None:
 
         func.output_types = output_types
 
-    def should_cast_parameter(self, op: Operation, param_name: str) -> bool:
+    def _check_target_dtype_for_param(self, op: Operation, param_name: str, param_target_dtype: str) -> bool:
         """
-        Determines if a param of an op should be cast to target_dtype.
+        Determines if a param of an op should be cast to param_target_dtype.
+        Returns True if the parameter can be cast, otherwise False.
 
         There are two cases that an op shouldn't be cast:
-        1. The op's parameter doesn't support target_dtype.
-        2. The cast op itself doesn't support target_dtype
+        1. The op's parameter doesn't support param_target_dtype.
+        2. The cast op itself doesn't support param_target_dtype
         """
         type_domain = getattr(op.input_spec.input_types[param_name], "type_domain", None)
-        if type_domain and types.string_to_builtin(self.target_dtype) not in type_domain:
+        if type_domain and types.string_to_builtin(param_target_dtype) not in type_domain:
             return False
-        if self.target_dtype not in SSAOpRegistry._get_core_op_cls("cast").supported_dtypes():
+        if param_target_dtype not in SSAOpRegistry._get_core_op_cls("cast").supported_dtypes():
             return False
         return True
 
+    def should_cast_parameter(self, op: Operation, param_name: str) -> Optional[str]:
+        """
+        Determines if a param of an op should be cast to target_dtype.
+        Returns the target dtype string if it should be cast, otherwise None.
+        """
+        raise NotImplementedError("Must be implemented in child class.")
+
     def _get_casted_outputs(self, op: Operation, casted_inputs: Dict[str, Var]) -> Tuple[Var]:
         """
         Given an op and casted_inputs, this utility returns the new resulting outputs.
@@ -277,7 +285,8 @@ def transform_op(self, op) -> None:
 
         inputs_modified = False
         for param, inputs in op.inputs.items():
-            if not self.should_cast_parameter(op, param):
+            param_target_dtype = self.should_cast_parameter(op, param)
+            if param_target_dtype is None:
                 continue
 
             is_list_input = isinstance(inputs, (list, tuple))
@@ -300,7 +309,7 @@ def transform_op(self, op) -> None:
                     continue
 
                 inputs_modified = True
-                casted_var_name = f"{var.name}_to_{self.target_dtype}"
+                casted_var_name = f"{var.name}_to_{param_target_dtype}"
                 if len(var._child_ops) > 1 and casted_var_name in self.current_cache_vars():
                     if self.current_cache_vars()[casted_var_name].op.x != var:
                         logger.warning(
@@ -311,11 +320,11 @@ def transform_op(self, op) -> None:
                 else:
                     x = mb.cast(
                         x=var,
-                        dtype=self.target_dtype,
+                        dtype=param_target_dtype,
                         name=casted_var_name,
                         before_op=op,
                     )
-                    if self.target_dtype == "fp16":
+                    if param_target_dtype == "fp16":
                         self._check_underflow_to_zero(x, var)
 
                 Block._copy_metadata(var, x)
@@ -446,24 +455,24 @@ def is_valid_op(self, op: Operation) -> bool:
 
         return True
 
-    def should_cast_parameter(self, op: Operation, param_name: str) -> bool:
+    def should_cast_parameter(self, op: Operation, param_name: str) -> Optional[str]:
         """Determines if a param of an op should be cast to fp16."""
-        if not super().should_cast_parameter(op, param_name):
-            return False
+        if not super()._check_target_dtype_for_param(op, param_name, self.target_dtype):
+            return None
 
         if is_current_opset_version_compatible_with(AvailableTarget.iOS17):
             # In IOS17+ activation ops with alpha/beta support mixed precision, and we don't want to
             # cast alpha/beta to fp16 for better numerical accuracy.
             if op.op_type in self._ACTIVATION_ALPHA_OPS and param_name == "alpha":
-                return False
+                return None
             if op.op_type in self._ACTIVATION_ALPHA_BETA_OPS and param_name in {"alpha", "beta"}:
-                return False
+                return None
             # Element-wise unary ops with epsilon also support mixed precision.
             if op.op_type in self._ELEMENTWISE_UNARY_EPSILON_OPS and param_name == "epsilon":
-                return False
+                return None
 
-        return True
+        return self.target_dtype
 
     def _check_underflow_to_zero(self, new_var, var):
         # We check whether there are casted values that "becomes" 0 which is not ideal for eps purposes.
@@ -539,9 +548,6 @@ class add_int16_cast(CastTypeQuantization):
 
     def __init__(self, op_selector=None):
         super().__init__(op_selector=op_selector)
-        # Use variable instead of hard-coded "int16" because the target dtype could be uint16
-        # depending on if the param is non-negative const and within uint16 range.
-        self._target_dtype: str = "int16"
 
     @property
     def origin_dtype(self) -> str:
@@ -549,15 +555,9 @@ def origin_dtype(self) -> str:
 
     @property
     def target_dtype(self) -> str:
-        return self._target_dtype
-
-    @target_dtype.setter
-    def target_dtype(self, target_dtype: str):
-        if target_dtype not in {"int16", "uint16"}:
-            raise ValueError("The target_dtype in add_int16_cast must be int16 or uint16")
-        self._target_dtype = target_dtype
+        return "int16"
 
-    def should_cast_parameter(self, op: Operation, param_name: str) -> bool:
+    def should_cast_parameter(self, op: Operation, param_name: str) -> Optional[str]:
         """
         Determine if a parameter should be cast or not. If should be cast, determine
         whether to use int16 or uint16.
@@ -574,8 +574,9 @@ def should_cast_parameter(self, op: Operation, param_name: str) -> bool:
             return input_var.val == "int32"
         # otherwise only int32 tensor / scalar should get cast to int16
         elif not input_var.is_tensor_or_scalar_of(dtype="int32"):
-            return False
+            return None
 
+        param_target_dtype = None
         input_op = input_var.op
         if input_op is not None:
             # here we do not handle input variables
@@ -585,14 +586,14 @@ def should_cast_parameter(self, op: Operation, param_name: str) -> bool:
                     input_op.outputs[0].val.min() >= _UINT16_MIN
                     and input_op.outputs[0].val.max() <= _UINT16_MAX
                 ):
-                    self._target_dtype = "uint16"
+                    param_target_dtype = "uint16"
                 elif (
                     input_op.outputs[0].val.min() >= _INT16_MIN
                     and input_op.outputs[0].val.max() <= _INT16_MAX
                 ):
-                    self._target_dtype = "int16"
+                    param_target_dtype = "int16"
                 else:
-                    return False
+                    return None
             elif input_op.op_type == "cast":
                 # If the input op is a `cast`, then check if it is "cast uint16 to int32".
                 # If so, then the correct 16-bit integer quantization for it should be to
@@ -601,9 +602,9 @@ def should_cast_parameter(self, op: Operation, param_name: str) -> bool:
                 # the only pattern for cast optimization to cancel these 2 casts, details see
                 # https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.passes.defs.html#coremltools.converters.mil.mil.passes.defs.optimize_repeat_ops.cast_optimization
                 if input_op.x.dtype == types.uint16:
-                    self._target_dtype = "uint16"
+                    param_target_dtype = "uint16"
                 else:
-                    self._target_dtype = "int16"
+                    param_target_dtype = "int16"
 
         # In `gather` and `gather_along_axis`, if the dim size of x is larger than int16
         # upperbound, the dynamic indices could overflow, so it shouldn't be cast.
@@ -611,12 +612,14 @@ def should_cast_parameter(self, op: Operation, param_name: str) -> bool:
             if op.indices.val is None and op.x.shape is not None:
                 dim_size = op.x.shape[op.axis.val]
                 if not is_symbolic(dim_size) and dim_size > _INT16_MAX:
-                    return False
+                    return None
 
-        if not super().should_cast_parameter(op, param_name):
-            return False
+        if not param_target_dtype:
+            param_target_dtype = self.target_dtype
+        if not super()._check_target_dtype_for_param(op, param_name, param_target_dtype):
+            return None
 
-        return True
+        return param_target_dtype
 
     def is_valid_op(self, op: Operation) -> bool:
         """Determines if op is valid for int16/uint16 casting."""
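 
The side effect this patch removes, reduced to a toy (names here are
hypothetical, not the real pass API): when the chosen dtype lives on the pass
object, a decision made for one parameter leaks into the next one.

    class StatefulToy:
        """Old shape of the logic: the dtype choice is object-level state."""

        def __init__(self):
            self._target_dtype = "int16"

        def pick_dtype(self, const_values):
            # The dtype is only updated when the values are a known constant...
            if const_values is not None and min(const_values) >= 0:
                self._target_dtype = "uint16"
            return self._target_dtype

    toy = StatefulToy()
    print(toy.pick_dtype([0, 1, 2]))  # "uint16": fine for non-negative constants
    print(toy.pick_dtype(None))       # "uint16": stale! a dynamic tensor that may
                                      # hold negatives inherits the last decision

Returning a local param_target_dtype per call, as the patch does, lets the
second lookup fall back to the safe "int16" default instead.
 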
From d509429b21503bc573240fabba81c711380ca841 Mon Sep 17 00:00:00 2001
From: Kasper Nielsen
Date: Sun, 23 Nov 2025 20:22:18 +0100
Subject: [PATCH 3/4] Clean up test

---
 .../mil/mil/passes/tests/test_quantization_passes.py | 11 +++--------
 1 file changed, 3 insertions(+), 8 deletions(-)

diff --git a/coremltools/converters/mil/mil/passes/tests/test_quantization_passes.py b/coremltools/converters/mil/mil/passes/tests/test_quantization_passes.py
index 3985d9fde..1f96f5d93 100644
--- a/coremltools/converters/mil/mil/passes/tests/test_quantization_passes.py
+++ b/coremltools/converters/mil/mil/passes/tests/test_quantization_passes.py
@@ -2853,8 +2853,8 @@ def prog():
         assert cast_op.dtype.val == "uint16"
         assert cast_op.outputs[0] == block.find_ops(op_type="gather")[0].indices
 
-    def test_gather_along_axis_overflow_second_input(self):
-        """First input is safe to cast, but second input is not."""
+    def test_stable_sort_program_does_not_underflow_output(self):
+        """Stable-sort implementation should not underflow output."""
 
         @mb.program(
             input_specs=[
@@ -2886,18 +2886,13 @@ def prog(a, b):
             )
             return gather_along_axis_2, gather_along_axis_3
 
-        # prev_prog, _, block = apply_pass_and_basic_check(prog, "common::add_int16_cast")
-        # assert get_op_types_in_program(prog) == get_op_types_in_program(prev_prog)
-
-        # prev_model = ct.convert(prev_prog, minimum_deployment_target=ct.target.iOS17)
         model = ct.convert(prog, minimum_deployment_target=ct.target.iOS17)
 
-        # prev_output = list(prev_model.predict({"x": x, "y": y}).values())[0]
-
         a = np.array([1, 3, 1, 4, 3, 5, 4], dtype=np.int32)
         b = np.array([0, 4, 0, 4, 0, -21, -12], dtype=np.int32)
         output1, output2 = tuple(model.predict({"a": a, "b": b}).values())
         # The CoreML program implements a stable multi-key sort; ensure the output is sorted by a, then by b
+        # Ensure that the output does not underflow due to uint16 casting
         assert np.array_equal(output2, np.array([1, 1, 3, 3, 4, 4, 5], dtype=np.int32))
         assert np.array_equal(output1, np.array([0, 0, 0, 4, -12, 4, -21], dtype=np.int32))
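 
For reference, the expected arrays in the cleaned-up test match a stable sort
by `a` with ties broken by `b`, which can be reproduced with numpy (again an
illustration only, not part of the patch):

    import numpy as np

    a = np.array([1, 3, 1, 4, 3, 5, 4], dtype=np.int32)
    b = np.array([0, 4, 0, 4, 0, -21, -12], dtype=np.int32)

    # np.lexsort sorts by its last key first: primary key a, tie-break on b.
    order = np.lexsort((b, a))
    print(a[order])  # [1 1 3 3 4 4 5]
    print(b[order])  # [  0   0   0   4 -12   4 -21]
 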
From c9b3812a6fbf462274b668bb87482de8104673a6 Mon Sep 17 00:00:00 2001
From: Kasper Nielsen
Date: Sat, 6 Dec 2025 11:12:22 +0100
Subject: [PATCH 4/4] Fix boolean return

---
 coremltools/converters/mil/mil/passes/defs/quantization.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/coremltools/converters/mil/mil/passes/defs/quantization.py b/coremltools/converters/mil/mil/passes/defs/quantization.py
index ff1c33e21..338f5f5ab 100644
--- a/coremltools/converters/mil/mil/passes/defs/quantization.py
+++ b/coremltools/converters/mil/mil/passes/defs/quantization.py
@@ -571,7 +571,10 @@ def should_cast_parameter(self, op: Operation, param_name: str) -> Optional[str]:
         # input may be a string that specifies the dtype,
         # so if it is "int32" then we would like to replace with "int16"
         if input_var.dtype == types.str:
-            return input_var.val == "int32"
+            if input_var.val == "int32":
+                return "int16"
+            else:
+                return None
         # otherwise only int32 tensor / scalar should get cast to int16
         elif not input_var.is_tensor_or_scalar_of(dtype="int32"):
             return None
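 
Why the boolean return mattered, sketched outside the pass (hypothetical
helpers, not the real API): under the Optional[str] contract introduced in
PATCH 2, a leaked True is not None, so it slips past the caller's check and
ends up where a dtype string belongs.

    def buggy(dtype_name):  # old behavior: bool where Optional[str] is expected
        return dtype_name == "int32"

    def fixed(dtype_name):  # patched behavior: a dtype string or None
        return "int16" if dtype_name == "int32" else None

    param_target_dtype = buggy("int32")      # True, not "int16"
    if param_target_dtype is not None:       # True passes the None check...
        print(f"x_to_{param_target_dtype}")  # x_to_True: bogus cast name and dtype

    param_target_dtype = fixed("int32")
    if param_target_dtype is not None:
        print(f"x_to_{param_target_dtype}")  # x_to_int16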