86 changes: 46 additions & 40 deletions coremltools/converters/mil/mil/passes/defs/quantization.py
@@ -5,7 +5,7 @@

from abc import abstractmethod
from enum import Enum as _Enum
from typing import Dict, Set, Text, Tuple
from typing import Dict, Set, Text, Tuple, Optional

import numpy as np

@@ -248,22 +248,30 @@ def transform_function_signatures(self, func: Function) -> None:

func.output_types = output_types

def should_cast_parameter(self, op: Operation, param_name: str) -> bool:
def _check_target_dtype_for_param(self, op: Operation, param_name: str, param_target_dtype: str) -> bool:
"""
Determines if a param of an op should be cast to target_dtype.
Determines if a param of an op should be cast to param_target_dtype.
Returns True if the param can be cast to param_target_dtype, otherwise False.

There are two cases in which the param shouldn't be cast:
1. The op's parameter doesn't support target_dtype.
2. The cast op itself doesn't support target_dtype
1. The op's parameter doesn't support param_target_dtype.
2. The cast op itself doesn't support param_target_dtype
"""
type_domain = getattr(op.input_spec.input_types[param_name], "type_domain", None)
if type_domain and types.string_to_builtin(self.target_dtype) not in type_domain:
if type_domain and types.string_to_builtin(param_target_dtype) not in type_domain:
return False
if self.target_dtype not in SSAOpRegistry._get_core_op_cls("cast").supported_dtypes():
if param_target_dtype not in SSAOpRegistry._get_core_op_cls("cast").supported_dtypes():
return False

return True

def should_cast_parameter(self, op: Operation, param_name: str) -> Optional[str]:
"""
Determines if a param of an op should be cast to target_dtype.
Returns the target dtype string if it should be cast, otherwise None.
"""
raise NotImplementedError("Must be implemented in child class.")

def _get_casted_outputs(self, op: Operation, casted_inputs: Dict[str, Var]) -> Tuple[Var]:
"""
Given an op and casted_inputs, this utility returns the new resulting outputs.
@@ -277,7 +285,8 @@ def transform_op(self, op) -> None:
inputs_modified = False

for param, inputs in op.inputs.items():
if not self.should_cast_parameter(op, param):
param_target_dtype = self.should_cast_parameter(op, param)
if param_target_dtype is None:
continue

is_list_input = isinstance(inputs, (list, tuple))
@@ -300,7 +309,7 @@ def transform_op(self, op) -> None:
continue

inputs_modified = True
casted_var_name = f"{var.name}_to_{self.target_dtype}"
casted_var_name = f"{var.name}_to_{param_target_dtype}"
if len(var._child_ops) > 1 and casted_var_name in self.current_cache_vars():
if self.current_cache_vars()[casted_var_name].op.x != var:
logger.warning(
@@ -311,11 +320,11 @@
else:
x = mb.cast(
x=var,
dtype=self.target_dtype,
dtype=param_target_dtype,
name=casted_var_name,
before_op=op,
)
if self.target_dtype == "fp16":
if param_target_dtype == "fp16":
self._check_underflow_to_zero(x, var)
Block._copy_metadata(var, x)

@@ -446,24 +455,24 @@ def is_valid_op(self, op: Operation) -> bool:

return True

def should_cast_parameter(self, op: Operation, param_name: str) -> bool:
def should_cast_parameter(self, op: Operation, param_name: str) -> Optional[str]:
"""Determines if a param of an op should be cast to fp16."""
if not super().should_cast_parameter(op, param_name):
return False
if not super()._check_target_dtype_for_param(op, param_name, self.target_dtype):
return None

if is_current_opset_version_compatible_with(AvailableTarget.iOS17):
# In iOS17+, activation ops with alpha/beta support mixed precision, so we keep
# alpha/beta in fp32 for better numerical accuracy.
if op.op_type in self._ACTIVATION_ALPHA_OPS and param_name == "alpha":
return False
return None
if op.op_type in self._ACTIVATION_ALPHA_BETA_OPS and param_name in {"alpha", "beta"}:
return False
return None

# Element-wise unary ops with epsilon also support mixed precision.
if op.op_type in self._ELEMENTWISE_UNARY_EPSILON_OPS and param_name == "epsilon":
return False
return None

return True
return self.target_dtype

def _check_underflow_to_zero(self, new_var, var):
# We check whether any casted values become 0, which is not ideal for eps purposes.
@@ -539,25 +548,16 @@ class add_int16_cast(CastTypeQuantization):

def __init__(self, op_selector=None):
super().__init__(op_selector=op_selector)
# Use variable instead of hard-coded "int16" because the target dtype could be uint16
# depending on if the param is non-negative const and within uint16 range.
self._target_dtype: str = "int16"

@property
def origin_dtype(self) -> str:
return "int32"

@property
def target_dtype(self) -> str:
return self._target_dtype

@target_dtype.setter
def target_dtype(self, target_dtype: str):
if target_dtype not in {"int16", "uint16"}:
raise ValueError("The target_dtype in add_int16_cast must be int16 or uint16")
self._target_dtype = target_dtype
return "int16"

def should_cast_parameter(self, op: Operation, param_name: str) -> bool:
def should_cast_parameter(self, op: Operation, param_name: str) -> Optional[str]:
"""
Determine if a parameter should be cast or not.
If it should be cast, determine whether to use int16 or uint16.
@@ -571,11 +571,15 @@ def should_cast_parameter(self, op: Operation, param_name: str) -> bool:
# The input may be a string that specifies the dtype,
# so if it is "int32" then we would like to replace it with "int16".
if input_var.dtype == types.str:
return input_var.val == "int32"
if input_var.val == "int32":
return "int16"
else:
return None
# otherwise only int32 tensor / scalar should get cast to int16
elif not input_var.is_tensor_or_scalar_of(dtype="int32"):
return False
return None

param_target_dtype = None
input_op = input_var.op
if input_op is not None:
# here we do not handle input variables
@@ -585,14 +589,14 @@ def should_cast_parameter(self, op: Operation, param_name: str) -> bool:
input_op.outputs[0].val.min() >= _UINT16_MIN
and input_op.outputs[0].val.max() <= _UINT16_MAX
):
self._target_dtype = "uint16"
param_target_dtype = "uint16"
elif (
input_op.outputs[0].val.min() >= _INT16_MIN
and input_op.outputs[0].val.max() <= _INT16_MAX
):
self._target_dtype = "int16"
param_target_dtype = "int16"
else:
return False
return None
elif input_op.op_type == "cast":
# If the input op is a `cast`, then check if it is "cast uint16 to int32".
# If so, then the correct 16-bit integer quantization for it should be to
@@ -601,22 +605,24 @@ def should_cast_parameter(self, op: Operation, param_name: str) -> bool:
# the only pattern for cast optimization to cancel these 2 casts, details see
# https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.passes.defs.html#coremltools.converters.mil.mil.passes.defs.optimize_repeat_ops.cast_optimization
if input_op.x.dtype == types.uint16:
self._target_dtype = "uint16"
param_target_dtype = "uint16"
else:
self._target_dtype = "int16"
param_target_dtype = "int16"

# In `gather` and `gather_along_axis`, if the dim size of x is larger than int16
# upper bound, the dynamic indices could overflow, so it shouldn't be cast.
if op.op_type in {"gather", "gather_along_axis"} and param_name == "indices":
if op.indices.val is None and op.x.shape is not None:
dim_size = op.x.shape[op.axis.val]
if not is_symbolic(dim_size) and dim_size > _INT16_MAX:
return False
return None

if not super().should_cast_parameter(op, param_name):
return False
if not param_target_dtype:
param_target_dtype = self.target_dtype
if not super()._check_target_dtype_for_param(op, param_name, param_target_dtype):
return None

return True
return param_target_dtype

def is_valid_op(self, op: Operation) -> bool:
"""Determines if op is valid for int16/uint16 casting."""
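A minimal, self-contained sketch (not part of the diff) of the refactored contract above: should_cast_parameter now returns the per-parameter target dtype string, or None to skip the cast, instead of mutating a shared self._target_dtype, and transform_op consumes that return value directly. The class and parameter names below are hypothetical stand-ins for the real MIL pass classes, assuming only the int16/uint16 value-range rule shown in add_int16_cast.

from typing import Dict, Optional

_INT16_MIN, _INT16_MAX = -32768, 32767
_UINT16_MIN, _UINT16_MAX = 0, 65535


class Int16CastSketch:
    def should_cast_parameter(self, param_name: str, value: Optional[int]) -> Optional[str]:
        # Return the dtype this parameter should be cast to, or None to leave it alone.
        if not isinstance(value, int):
            return None
        if _UINT16_MIN <= value <= _UINT16_MAX:
            return "uint16"  # non-negative and within uint16 range
        if _INT16_MIN <= value <= _INT16_MAX:
            return "int16"
        return None  # would overflow 16-bit integers

    def transform_op(self, inputs: Dict[str, Optional[int]]) -> Dict[str, str]:
        # Mirror the transform_op loop in the diff: skip params for which
        # should_cast_parameter returns None, otherwise record the chosen dtype.
        casts = {}
        for name, value in inputs.items():
            target = self.should_cast_parameter(name, value)
            if target is None:
                continue
            casts[name] = target
        return casts


# Example usage: 'indices' picks uint16, 'axis' picks int16, 'x' is skipped.
print(Int16CastSketch().transform_op({"indices": 7, "axis": -1, "x": None}))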
@@ -2853,6 +2853,49 @@ def prog():
assert cast_op.dtype.val == "uint16"
assert cast_op.outputs[0] == block.find_ops(op_type="gather")[0].indices

def test_stable_sort_program_does_not_underflow_output(self):
"""Stable-sort implementation should not underflow output."""

@mb.program(
input_specs=[
mb.TensorSpec(shape=(7,), dtype=types.int32),
mb.TensorSpec(shape=(7,), dtype=types.int32),
],
opset_version=ct.target.iOS17
)
def prog(a, b):
cast_0 = mb.cast(x=b, dtype="fp32")
mul_0 = mb.mul(x=[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0], y=0.07142857142857142)
add_0 = mb.add(x=cast_0, y=mul_0)
argsort_0 = mb.argsort(x=add_0, axis=0, ascending=True)
gather_along_axis_0 = mb.gather_along_axis(
x=a, indices=argsort_0, axis=0, validate_indices=False
)
cast_1 = mb.cast(x=gather_along_axis_0, dtype="fp32")
mul_1 = mb.mul(x=[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0], y=0.07142857142857142)
add_1 = mb.add(x=cast_1, y=mul_1)
argsort_1 = mb.argsort(x=add_1, axis=0, ascending=True)
gather_along_axis_1 = mb.gather_along_axis(
x=argsort_0, indices=argsort_1, axis=0, validate_indices=False
)
gather_along_axis_2 = mb.gather_along_axis(
x=a, indices=gather_along_axis_1, axis=0, validate_indices=False
)
gather_along_axis_3 = mb.gather_along_axis(
x=b, indices=gather_along_axis_1, axis=0, validate_indices=False
)
return gather_along_axis_2, gather_along_axis_3

model = ct.convert(prog, minimum_deployment_target=ct.target.iOS17)

a = np.array([1, 3, 1, 4, 3, 5, 4], dtype=np.int32)
b = np.array([0, 4, 0, 4, 0, -21, -12], dtype=np.int32)
output1, output2 = tuple(model.predict({"a": a, "b": b}).values())
# The CoreML program implements a stable multi-key sort; ensure the output is sorted by a, then by b.
# Also ensure the output does not underflow due to uint16 casting of the negative values in b.
assert np.array_equal(output2, np.array([1, 1, 3, 3, 4, 4, 5], dtype=np.int32))
assert np.array_equal(output1, np.array([0, 0, 0, 4, -12, 4, -21], dtype=np.int32))
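As a cross-check (not part of the PR), the expected values asserted above correspond to a stable sort by a with ties broken by b, and a short NumPy lexsort sketch reproduces them; it assumes nothing beyond standard NumPy and the same a/b arrays as the test.

import numpy as np

a = np.array([1, 3, 1, 4, 3, 5, 4], dtype=np.int32)
b = np.array([0, 4, 0, 4, 0, -21, -12], dtype=np.int32)

# np.lexsort sorts by the last key first, so pass (b, a) to sort by a, then b.
order = np.lexsort((b, a))
assert np.array_equal(a[order], np.array([1, 1, 3, 3, 4, 4, 5], dtype=np.int32))
assert np.array_equal(b[order], np.array([0, 0, 0, 4, -12, 4, -21], dtype=np.int32))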

@pytest.mark.parametrize(
"dtype, opset_version",
itertools.product(