diff --git a/coremltools/converters/mil/mil/passes/__init__.py b/coremltools/converters/mil/mil/passes/__init__.py
index 4e3f40b1b..05b5e149b 100644
--- a/coremltools/converters/mil/mil/passes/__init__.py
+++ b/coremltools/converters/mil/mil/passes/__init__.py
@@ -44,5 +44,6 @@
     optimize_state,
     optimize_tensor_operation,
     preprocess,
+    transformer,
     symbol_transform,
 )
diff --git a/coremltools/converters/mil/mil/passes/defs/transformer.py b/coremltools/converters/mil/mil/passes/defs/transformer.py
new file mode 100644
index 000000000..ca3b79c4a
--- /dev/null
+++ b/coremltools/converters/mil/mil/passes/defs/transformer.py
@@ -0,0 +1,185 @@
+# Copyright (c) 2024, Apple Inc. All rights reserved.
+#
+# Use of this source code is governed by a BSD-3-clause license that can be
+# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
+
+from typing import ClassVar, List, Tuple
+
+import numpy as np
+
+from coremltools.converters.mil.mil.passes.graph_pass import AbstractGraphPass
+from coremltools.converters.mil.mil.passes.pass_registry import register_pass
+from coremltools.converters.mil._deployment_compatibility import AvailableTarget as target
+from coremltools.converters.mil.mil import Builder as mb
+from coremltools.converters.mil.mil import types
+
+from coremltools import _logger as logger
+
+
+@register_pass(namespace="common")
+class scaled_dot_product_attention_sliced_q(AbstractGraphPass):
+    """
+    Replace the ios18.scaled_dot_product_attention operation with a memory-efficient
+    implementation of the attention calculation based on slicing Q. The benefit is
+    most visible for large Q sequence lengths.
+
+    Graph pass options:
+    - min_seq_length: int
+      Only operations whose Q sequence length is greater than or equal to this value are transformed.
+    - seq_length_divider: int
+      Defines the size of the chunks of Q processed in SDPA (chunk_size = seq_length / seq_length_divider).
+    """
+
+    _DEFAULT_MIN_SEQ_LENGTH: ClassVar[int] = 1280
+    _DEFAULT_SEQ_LENGTH_DIVIDER: ClassVar[int] = 16
+
+    _min_seq_length: int
+    _seq_length_divider: int
+
+    def __init__(self):
+        super().__init__()
+        self._min_seq_length = self._DEFAULT_MIN_SEQ_LENGTH
+        self._seq_length_divider = self._DEFAULT_SEQ_LENGTH_DIVIDER
+
+    @property
+    def min_seq_length(self) -> int:
+        return self._min_seq_length
+
+    @min_seq_length.setter
+    def min_seq_length(self, length: int) -> None:
+        if not isinstance(length, int):
+            raise ValueError("pass option min_seq_length must be an int")
+        if length < 0:
+            raise ValueError("pass option min_seq_length must be >= 0")
+        self._min_seq_length = length
+
+    @property
+    def seq_length_divider(self) -> int:
+        return self._seq_length_divider
+
+    @seq_length_divider.setter
+    def seq_length_divider(self, divider: int) -> None:
+        if not isinstance(divider, int):
+            raise ValueError("pass option seq_length_divider must be an int")
+        if divider < 1:
+            raise ValueError("pass option seq_length_divider must be >= 1")
+        self._seq_length_divider = divider
+
+    def apply(self, prog):
+        for f in prog.functions.values():
+            if f.opset_version < target.iOS18:
+                logger.debug(f"ignoring block '{f.name}', target {f.opset_version} (minimum required target is iOS18)")
+                return
+
+            for op in list(f.operations):
+                if op.op_type == "scaled_dot_product_attention":
+                    self._replace_scaled_dot_product_attention(op)
+
+    @staticmethod
+    def _get_input_vars(op):
+        mandatory_params = ["query", "key", "value"]
+        inputs = {}
+        for param in mandatory_params:
+            inputs[param] = op.inputs.get(param)
+            if inputs[param] is None:
+                raise ValueError(f"operation 'scaled_dot_product_attention': mandatory input '{param}' not present")
+        return tuple([inputs[param] for param in mandatory_params]) + (op.inputs.get("attn_mask"),)
+
+    @staticmethod
+    def _split_to_chunks(seq_length: int, count: int) -> List[Tuple[int, int]]:
+        chunk_size = max(seq_length // count, 1)
+        remainder = seq_length % count
+
+        result = []
+        chunk_start = 0
+        for i in range(count):
+            if chunk_start >= seq_length:
+                break
+            chunk_end = chunk_start + chunk_size + (1 if i < remainder else 0)
+            result.append((chunk_start, chunk_end))
+            chunk_start = chunk_end
+
+        return result
+
+    def _replace_scaled_dot_product_attention(self, op):
+        q, k, v, mask = self._get_input_vars(op)
+
+        q_size = len(q.shape)
+        q_seq_length = q.shape[-2]
+        if q_seq_length < self._min_seq_length:
+            logger.debug(
+                f"skipping SDPA op, Q seq_length is {q_seq_length} (minimum seq length needed: {self._min_seq_length})"
+            )
+            return
+
+        dims = q.shape[-1]
+        normalize_factor = float(dims) ** -0.5
+
+        q_dtype = types.nptype_from_builtin(type(q.dtype()))
+
+        chunks = self._split_to_chunks(q_seq_length, self._seq_length_divider)
+
+        concat_out = None
+        with op.enclosing_block:
+            if mask is not None:
+                if mask.dtype == types.bool:
+                    cond_out = mb.logical_not(x=mask, before_op=op)
+                    mask_zeros = mb.const(val=np.zeros(mask.shape, dtype=q_dtype), before_op=op)
+                    mask_float = mb.select(cond=cond_out, a=q_dtype(-np.inf), b=mask_zeros, before_op=op)
+                else:
+                    mask_float = mask
+
+            for chunk_start, chunk_end in chunks:
+                # Get a chunk of Q.
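+                # The slice keeps all leading batch/head dimensions and the full head
+                # dimension, and selects rows [chunk_start, chunk_end) along the Q
+                # sequence axis (-2).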
+                slice_begin = [0] * (q_size - 2) + [chunk_start, 0]
+                slice_end = list(q.shape[:-2] + (chunk_end, dims))
+                slice_end_mask = tuple([True] * (q_size - 2) + [False, True])
+                slice_out = mb.slice_by_index(
+                    x=q,
+                    begin=slice_begin,
+                    end=slice_end,
+                    end_mask=slice_end_mask,
+                    before_op=op,
+                )
+
+                # Calculate the chunk of Q x K^T.
+                matmul_out = mb.matmul(x=slice_out, y=k, transpose_x=False, transpose_y=True, before_op=op)
+                mul_out = mb.mul(x=matmul_out, y=np.array(normalize_factor, dtype=q_dtype), before_op=op)
+
+                # Apply the attention mask.
+                if mask is not None:
+                    if mask.shape[-2] == 1:
+                        mul_out = mb.add(x=mul_out, y=mask_float, before_op=op)
+                    else:
+                        mask_out = mb.slice_by_index(
+                            x=mask_float,
+                            begin=[chunk_start, 0],
+                            end=[chunk_end, mask.shape[-1]],
+                            end_mask=[False, True],
+                            before_op=op,
+                        )
+                        mul_out = mb.add(x=mul_out, y=mask_out, before_op=op)
+
+                # Calculate softmax of the product.
+                softmax_out = mb.softmax(x=mul_out, axis=-1, before_op=op)
+
+                # Calculate the chunk of attention.
+                matmul_v_out = mb.matmul(
+                    x=softmax_out,
+                    y=v,
+                    transpose_x=False,
+                    transpose_y=False,
+                    before_op=op,
+                )
+
+                # Add the chunk of attention to the result value.
+                concat_values = [concat_out] if concat_out is not None else []
+                concat_out = mb.concat(values=concat_values + [matmul_v_out], axis=-2, interleave=False, before_op=op)
+
+            # Remove the original SDPA operation.
+            op.enclosing_block.replace_uses_of_var_after_op(
+                anchor_op=op,
+                old_var=op.outputs[0],
+                new_var=concat_out,
+            )
+            op.enclosing_block.remove_ops([op])
diff --git a/coremltools/converters/mil/mil/passes/tests/test_passes.py b/coremltools/converters/mil/mil/passes/tests/test_passes.py
index 4a07b1627..b5fa03560 100644
--- a/coremltools/converters/mil/mil/passes/tests/test_passes.py
+++ b/coremltools/converters/mil/mil/passes/tests/test_passes.py
@@ -7,6 +7,8 @@
 import itertools
 import unittest
 
+from typing import ClassVar, Dict, List, Optional
+
 import numpy as np
 import pytest
 import torch
@@ -38,6 +40,7 @@
     get_op_types_in_program,
 )
 from coremltools.models.utils import _macos_version
+from coremltools.converters.mil.frontend.milproto.load import load as _milproto_to_pymil
 
 np.random.seed(1984)
 _VALIDATE_MODEL = True
@@ -7371,3 +7374,133 @@ def prog(x, y, z):
 
         apply_pass_and_basic_check(prog, "common::fuse_stack_split")
         assert get_op_types_in_program(prog) == ["stack", "split"] + ["squeeze"] * 3
+
+
+class TestScaledDotProductAttentionSlicedQ:
+
+    class AttentionPyTorch(torch.nn.Module):
+        @staticmethod
+        def forward(q, k, v, attn_mask=None):
+            return torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask)
+
+    @staticmethod
+    def _get_example_inputs(
+        shape_size: int = 3,
+        qkv_same_shape: bool = True,
+        dtype: torch.dtype = torch.float16,
+        attn_mask_dtype: Optional[torch.dtype] = None,
+    ):
+        batches, seq_length, dimensions = 4, 256, 768
+        q_shape = (batches, seq_length, dimensions)
+        kv_shape = q_shape if qkv_same_shape else (batches, seq_length - 16, dimensions)
+        if shape_size > 3:
+            q_shape = tuple([1] * (shape_size - len(q_shape)) + list(q_shape))
+            kv_shape = tuple([1] * (shape_size - len(kv_shape)) + list(kv_shape))
+        inputs = {
+            "q": torch.rand(q_shape, dtype=dtype),
+            "k": torch.rand(kv_shape, dtype=dtype),
+            "v": torch.rand(kv_shape, dtype=dtype),
+        }
+        if attn_mask_dtype is not None:
+            if attn_mask_dtype == torch.bool:
+                inputs["attn_mask"] = torch.randint(0, 2, (seq_length, seq_length), dtype=torch.bool)
+            else:
+                inputs["attn_mask"] = torch.randn((seq_length, seq_length), dtype=dtype)
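+
+        # Note: torch SDPA treats a bool attn_mask as keep/drop flags and a float
+        # attn_mask as an additive bias on the attention scores; both variants are
+        # covered by the tests below.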
+        return inputs
+
+    @staticmethod
+    def _get_trace_coreml_inputs(example_inputs: Dict[str, torch.Tensor]):
+        model_inputs = [example_inputs[key] for key in ["q", "k", "v"]]
+        if "attn_mask" in example_inputs:
+            model_inputs.append(example_inputs["attn_mask"])
+
+        coreml_model_inputs = []
+        for key in ["q", "k", "v", "attn_mask"]:
+            if key in example_inputs:
+                dtype = example_inputs[key].numpy().dtype
+                if dtype == bool:
+                    dtype = np.float32
+                coreml_model_inputs.append(ct.TensorType(key, shape=example_inputs[key].shape, dtype=dtype))
+
+        return model_inputs, coreml_model_inputs
+
+    def verify_sdpa_outputs(self, example_inputs: Dict[str, torch.Tensor]):
+        pipeline_1 = ct.PassPipeline.DEFAULT
+
+        pipeline_2 = ct.PassPipeline.DEFAULT
+        pipeline_2.append_pass("common::scaled_dot_product_attention_sliced_q")
+
+        pipeline_3 = ct.PassPipeline.DEFAULT
+        pipeline_3.append_pass("common::scaled_dot_product_attention_sliced_q")
+        pipeline_3.set_options("common::scaled_dot_product_attention_sliced_q", {"min_seq_length": 256})
+
+        pipeline_4 = ct.PassPipeline.DEFAULT
+        pipeline_4.append_pass("common::scaled_dot_product_attention_sliced_q")
+        pipeline_4.set_options(
+            "common::scaled_dot_product_attention_sliced_q", {"min_seq_length": 256, "seq_length_divider": 32}
+        )
+
+        model = self.AttentionPyTorch()
+        model_inputs, coreml_model_inputs = self._get_trace_coreml_inputs(example_inputs)
+
+        coreml_models = [
+            ct.convert(
+                torch.jit.trace(model, model_inputs).eval(),
+                inputs=coreml_model_inputs,
+                minimum_deployment_target=ct.target.iOS18,
+                convert_to="mlprogram",
+                compute_units=ct.ComputeUnit.ALL,
+                skip_model_load=False,
+                pass_pipeline=pipeline,
+            )
+            for pipeline in [pipeline_1, pipeline_2, pipeline_3, pipeline_4]
+        ]
+
+        model_specs = [coreml_model.get_spec() for coreml_model in coreml_models]
+        progs = []
+        for i in range(len(coreml_models)):
+            progs.append(
+                _milproto_to_pymil(
+                    model_spec=model_specs[i],
+                    specification_version=model_specs[i].specificationVersion,
+                    file_weights_dir=coreml_models[i].weights_dir,
+                )
+            )
+
+        ops_counts = [len(prog.functions["main"].operations) for prog in progs]
+
+        assert ops_counts[0] == 1 or ops_counts[0] == 3  # (attn_mask might be cast to bool from input fp16 dtype)
+        assert ops_counts[1] == 1 or ops_counts[1] == 3  # the Q seq length is less than the default min seq length
+        assert ops_counts[2] >= 6 * 16  # 6 ops (without consts) per slice
+        assert ops_counts[3] >= 6 * 32
+
+        predict_inputs = copy.deepcopy(example_inputs)
+        if "attn_mask" in predict_inputs:
+            predict_inputs["attn_mask"] = predict_inputs["attn_mask"].to(dtype=torch.float32)
+
+        outputs = [list(coreml_model.predict(predict_inputs).values())[0] for coreml_model in coreml_models]
+
+        for i in range(1, len(outputs)):
+            assert outputs[0].shape == outputs[i].shape
+            np.testing.assert_allclose(outputs[0], outputs[i], rtol=0.01)
+
+    def test_scaled_dot_product_attention_sliced(self):
+        # Confirm the basic scenario.
+        example_inputs = self._get_example_inputs()
+        self.verify_sdpa_outputs(example_inputs)
+
+        # Confirm sdpa with Q, K and V as 4D tensors.
+        example_inputs = self._get_example_inputs(shape_size=4)
+        self.verify_sdpa_outputs(example_inputs)
+
+        # Confirm sdpa with attn_mask as a bias.
+        example_inputs = self._get_example_inputs(attn_mask_dtype=torch.float16)
+        self.verify_sdpa_outputs(example_inputs)
+
+        # Confirm sdpa with attn_mask as boolean flags.
+        example_inputs = self._get_example_inputs(attn_mask_dtype=torch.bool)
+        self.verify_sdpa_outputs(example_inputs)
+
+        # Confirm sdpa works well with different shapes for Q and K & V.
+        example_inputs = self._get_example_inputs(qkv_same_shape=False)
+        self.verify_sdpa_outputs(example_inputs)
diff --git a/docs/source/coremltools.converters.mil.mil.passes.defs.rst b/docs/source/coremltools.converters.mil.mil.passes.defs.rst
index 2c361eec7..5a9e3919d 100644
--- a/docs/source/coremltools.converters.mil.mil.passes.defs.rst
+++ b/docs/source/coremltools.converters.mil.mil.passes.defs.rst
@@ -147,3 +147,11 @@ symbol_transform
 .. automodule:: coremltools.converters.mil.mil.passes.defs.symbol_transform
 
    .. autoclass:: materialize_symbolic_shape_program
+
+
+transformer
+---------------------------------------------------------
+
+.. automodule:: coremltools.converters.mil.mil.passes.defs.transformer
+
+   .. autoclass:: scaled_dot_product_attention_sliced_q
diff --git a/reqs/test.pip b/reqs/test.pip
index 0800bc334..27fb447b0 100644
--- a/reqs/test.pip
+++ b/reqs/test.pip
@@ -53,7 +53,7 @@ pytest-timeout
 
 transformers==4.26.0; platform_machine != "arm64"
 transformers==4.38.2; platform_machine == "arm64"
-peft
+peft==0.13.2
 
 # coremltools.optimize.torch
 filelock==3.6.0
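For reference, a minimal usage sketch of the new pass (not part of the patch): the toy Attention module and tensor shapes below are illustrative, while the pass name and option keys match those exercised in the tests above.

import numpy as np
import torch

import coremltools as ct


class Attention(torch.nn.Module):
    def forward(self, q, k, v):
        return torch.nn.functional.scaled_dot_product_attention(q, k, v)


# A Q sequence length of 2048 exceeds the default min_seq_length (1280), so the pass applies.
q = torch.rand(1, 2048, 768)
k = torch.rand(1, 2048, 768)
v = torch.rand(1, 2048, 768)
traced = torch.jit.trace(Attention().eval(), (q, k, v))

# Opt in to the sliced-Q SDPA pass; the option values shown here are the defaults,
# set explicitly only to demonstrate the option API.
pipeline = ct.PassPipeline.DEFAULT
pipeline.append_pass("common::scaled_dot_product_attention_sliced_q")
pipeline.set_options(
    "common::scaled_dot_product_attention_sliced_q",
    {"min_seq_length": 1280, "seq_length_divider": 16},
)

mlmodel = ct.convert(
    traced,
    inputs=[
        ct.TensorType("q", shape=q.shape, dtype=np.float32),
        ct.TensorType("k", shape=k.shape, dtype=np.float32),
        ct.TensorType("v", shape=v.shape, dtype=np.float32),
    ],
    minimum_deployment_target=ct.target.iOS18,
    convert_to="mlprogram",
    pass_pipeline=pipeline,
)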