Graph Pass: scaled_dot_product_attention_sliced_q #2418

Merged: 7 commits, Jan 13, 2025
1 change: 1 addition & 0 deletions coremltools/converters/mil/mil/passes/__init__.py
@@ -44,5 +44,6 @@
optimize_state,
optimize_tensor_operation,
preprocess,
transformer,
symbol_transform,
)
185 changes: 185 additions & 0 deletions coremltools/converters/mil/mil/passes/defs/transformer.py
@@ -0,0 +1,185 @@
# Copyright (c) 2024, Apple Inc. All rights reserved.
#
# Use of this source code is governed by a BSD-3-clause license that can be
# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause

from typing import ClassVar, List, Tuple

import numpy as np

from coremltools.converters.mil.mil.passes.graph_pass import AbstractGraphPass
from coremltools.converters.mil.mil.passes.pass_registry import register_pass
from coremltools.converters.mil._deployment_compatibility import AvailableTarget as target
from coremltools.converters.mil.mil import Builder as mb
from coremltools.converters.mil.mil import types

from coremltools import _logger as logger


@register_pass(namespace="common")
class scaled_dot_product_attention_sliced_q(AbstractGraphPass):
"""
    Replace the ios18.scaled_dot_product_attention operation with a memory-efficient
    implementation that computes attention over slices of Q. The benefit is most
    pronounced for larger Q sequence lengths.

    Graph pass options:
      - min_seq_length: int
        Only operations whose Q sequence length is greater than or equal to this value are transformed.
      - seq_length_divider: int
        Defines the size of the Q chunks processed by SDPA (chunk_size = seq_length / seq_length_divider).
"""

_DEFAULT_MIN_SEQ_LENGTH: ClassVar[int] = 1280
_DEFAULT_SEQ_LENGTH_DIVIDER: ClassVar[int] = 16

_min_seq_length: int
_seq_length_divider: int

def __init__(self):
super().__init__()
self._min_seq_length = self._DEFAULT_MIN_SEQ_LENGTH
self._seq_length_divider = self._DEFAULT_SEQ_LENGTH_DIVIDER

@property
def min_seq_length(self) -> int:
return self._min_seq_length

@min_seq_length.setter
def min_seq_length(self, length: int) -> None:
if not isinstance(length, int):
raise ValueError("pass option min_seq_length must be an int")
if length < 0:
raise ValueError("pass option min_seq_length must be >= 0")
self._min_seq_length = length

@property
def seq_length_divider(self) -> int:
return self._seq_length_divider

@seq_length_divider.setter
def seq_length_divider(self, divider: int) -> None:
if not isinstance(divider, int):
raise ValueError("pass option seq_length_divider must be an int")
if divider < 1:
raise ValueError("pass option seq_length_divider must be >= 1")
self._seq_length_divider = divider

def apply(self, prog):
for f in prog.functions.values():
if f.opset_version < target.iOS18:
logger.debug(f"ignoring block '{f.name}', target {f.opset_version} (required min iOS18)")
return

for op in list(f.operations):
if op.op_type == "scaled_dot_product_attention":
self._replace_scaled_dot_product_attention(op)

@staticmethod
def _get_input_vars(op):
mandatory_params = ["query", "key", "value"]
inputs = {}
for param in mandatory_params:
inputs[param] = op.inputs.get(param)
if inputs[param] is None:
raise ValueError(f"operation 'scaled_dot_product_attention': mandatory input '{param}' not present")
return tuple([inputs[param] for param in mandatory_params]) + (op.inputs.get("attn_mask"),)

@staticmethod
def _split_to_chunks(seq_length: int, count: int) -> List[Tuple[int, int]]:
chunk_size = max(seq_length // count, 1)
remainder = seq_length % count
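        # The first `remainder` chunks absorb one extra element each; e.g. with hypothetical
        # values seq_length=10 and count=3 this yields [(0, 4), (4, 7), (7, 10)].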

result = []
chunk_start = 0
for i in range(count):
if chunk_start >= seq_length:
break
chunk_end = chunk_start + chunk_size + (1 if i < remainder else 0)
result.append((chunk_start, chunk_end))
chunk_start = chunk_end

return result

def _replace_scaled_dot_product_attention(self, op):
q, k, v, mask = self._get_input_vars(op)

q_size = len(q.shape)
q_seq_length = q.shape[-2]
if q_seq_length < self._min_seq_length:
logger.debug(
f"skipping SDPA op, Q seq_length is {q_seq_length} (minimum seq length needed: {self._min_seq_length}"
)
return

dims = q.shape[-1]
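        # Standard attention scaling factor 1 / sqrt(d_k), applied to Q @ K^T before the softmax.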
normalize_factor = float(dims) ** -0.5

q_dtype = types.nptype_from_builtin(type(q.dtype()))

chunks = self._split_to_chunks(q_seq_length, self._seq_length_divider)

concat_out = None
with op.enclosing_block:
if mask is not None:
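                # Convert a boolean mask into an additive float mask:
                # -inf where attention is disallowed, 0 elsewhere.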
if mask.dtype == types.bool:
cond_out = mb.logical_not(x=mask, before_op=op)
mask_zeros = mb.const(val=np.zeros(mask.shape, dtype=q_dtype), before_op=op)
mask_float = mb.select(cond=cond_out, a=q_dtype(-np.inf), b=mask_zeros, before_op=op)
else:
mask_float = mask

for chunk_start, chunk_end in chunks:
# Get a chunk of Q.
slice_begin = [0] * (q_size - 2) + [chunk_start, 0]
slice_end = list(q.shape[:-2] + (chunk_end, dims))
slice_end_mask = tuple([True] * (q_size - 2) + [False, True])
slice_out = mb.slice_by_index(
x=q,
begin=slice_begin,
end=slice_end,
end_mask=slice_end_mask,
before_op=op,
)

# Calculate chunk of Q x KT
matmul_out = mb.matmul(x=slice_out, y=k, transpose_x=False, transpose_y=True, before_op=op)
mul_out = mb.mul(x=matmul_out, y=np.array(normalize_factor, dtype=q_dtype), before_op=op)

# Apply the attention mask.
if mask is not None:
if mask.shape[-2] == 1:
mul_out = mb.add(x=mul_out, y=mask_float, before_op=op)
else:
mask_out = mb.slice_by_index(
x=mask_float,
begin=[chunk_start, 0],
end=[chunk_end, mask.shape[-1]],
end_mask=[False, True],
before_op=op,
)
mul_out = mb.add(x=mul_out, y=mask_out, before_op=op)

# Calculate softmax of the product.
softmax_out = mb.softmax(x=mul_out, axis=-1, before_op=op)

# Calculate the chunk of attention.
matmul_v_out = mb.matmul(
x=softmax_out,
y=v,
transpose_x=False,
transpose_y=False,
before_op=op,
)

# Add the chunk of attention to the result value.
concat_values = [concat_out] if concat_out is not None else []
concat_out = mb.concat(values=concat_values + [matmul_v_out], axis=-2, interleave=False, before_op=op)

# Remove the original SDPA operation.
op.enclosing_block.replace_uses_of_var_after_op(
anchor_op=op,
old_var=op.outputs[0],
new_var=concat_out,
)
op.enclosing_block.remove_ops([op])
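
For context, a minimal sketch (not part of this diff) of how a caller could enable the new pass and tune its options through a pass pipeline, mirroring the tests added below; `traced_model` and `model_inputs` are placeholder names:

import coremltools as ct

# Append the new pass to the default pipeline and set its options
# (the option values here are illustrative).
pipeline = ct.PassPipeline.DEFAULT
pipeline.append_pass("common::scaled_dot_product_attention_sliced_q")
pipeline.set_options(
    "common::scaled_dot_product_attention_sliced_q",
    {"min_seq_length": 1024, "seq_length_divider": 16},
)

# `traced_model` is a torch.jit.trace'd module and `model_inputs` its
# ct.TensorType descriptions (placeholders for illustration).
mlmodel = ct.convert(
    traced_model,
    inputs=model_inputs,
    minimum_deployment_target=ct.target.iOS18,
    convert_to="mlprogram",
    pass_pipeline=pipeline,
)
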
133 changes: 133 additions & 0 deletions coremltools/converters/mil/mil/passes/tests/test_passes.py
@@ -7,6 +7,8 @@
import itertools
import unittest

from typing import ClassVar, Dict, List, Optional

import numpy as np
import pytest
import torch
@@ -38,6 +40,7 @@
get_op_types_in_program,
)
from coremltools.models.utils import _macos_version
from coremltools.converters.mil.frontend.milproto.load import load as _milproto_to_pymil

np.random.seed(1984)
_VALIDATE_MODEL = True
@@ -7371,3 +7374,133 @@ def prog(x, y, z):

apply_pass_and_basic_check(prog, "common::fuse_stack_split")
assert get_op_types_in_program(prog) == ["stack", "split"] + ["squeeze"] * 3


class TestScaledDotProductAttentionSlicedQ:

class AttentionPyTorch(torch.nn.Module):
@staticmethod
def forward(q, k, v, attn_mask=None):
return torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask)

@staticmethod
def _get_example_inputs(
shape_size: int = 3,
qkv_same_shape: bool = True,
dtype: torch.dtype = torch.float16,
attn_mask_dtype: Optional[torch.dtype] = None,
):
batches, seq_length, dimensions = 4, 256, 768
q_shape = (batches, seq_length, dimensions)
kv_shape = q_shape if qkv_same_shape else (batches, seq_length - 16, dimensions)
if shape_size > 3:
q_shape = tuple([1] * (shape_size - len(q_shape)) + list(q_shape))
kv_shape = tuple([1] * (shape_size - len(kv_shape)) + list(kv_shape))
inputs = {
"q": torch.rand(q_shape, dtype=dtype),
"k": torch.rand(kv_shape, dtype=dtype),
"v": torch.rand(kv_shape, dtype=dtype),
}
if attn_mask_dtype is not None:
if attn_mask_dtype == torch.bool:
inputs["attn_mask"] = torch.randint(0, 2, (seq_length, seq_length), dtype=torch.bool)
else:
inputs["attn_mask"] = torch.randn((seq_length, seq_length), dtype=dtype)
return inputs

@staticmethod
def _get_trace_coreml_inputs(example_inputs: Dict[str, torch.Tensor]):
model_inputs = [example_inputs[key] for key in ["q", "k", "v"]]
if "attn_mask" in example_inputs:
model_inputs.append(example_inputs["attn_mask"])

coreml_model_inputs = []
for key in ["q", "k", "v", "attn_mask"]:
if key in example_inputs:
dtype = example_inputs[key].numpy().dtype
if dtype == bool:
dtype = np.float32
coreml_model_inputs.append(ct.TensorType(key, shape=example_inputs[key].shape, dtype=dtype))

return model_inputs, coreml_model_inputs

def verify_sdpa_outputs(self, example_inputs: Dict[str, torch.Tensor]):
pipeline_1 = ct.PassPipeline.DEFAULT

pipeline_2 = ct.PassPipeline.DEFAULT
pipeline_2.append_pass("common::scaled_dot_product_attention_sliced_q")

pipeline_3 = ct.PassPipeline.DEFAULT
pipeline_3.append_pass("common::scaled_dot_product_attention_sliced_q")
pipeline_3.set_options("common::scaled_dot_product_attention_sliced_q", {"min_seq_length": 256})

pipeline_4 = ct.PassPipeline.DEFAULT
pipeline_4.append_pass("common::scaled_dot_product_attention_sliced_q")
pipeline_4.set_options(
"common::scaled_dot_product_attention_sliced_q", {"min_seq_length": 256, "seq_length_divider": 32}
)

model = self.AttentionPyTorch()
model_inputs, coreml_model_inputs = self._get_trace_coreml_inputs(example_inputs)

coreml_models = [
ct.convert(
torch.jit.trace(model, model_inputs).eval(),
inputs=coreml_model_inputs,
minimum_deployment_target=ct.target.iOS18,
convert_to="mlprogram",
compute_units=ct.ComputeUnit.ALL,
skip_model_load=False,
pass_pipeline=pipeline,
)
for pipeline in [pipeline_1, pipeline_2, pipeline_3, pipeline_4]
]

model_specs = [coreml_model.get_spec() for coreml_model in coreml_models]
progs = []
for i in range(len(coreml_models)):
progs.append(
_milproto_to_pymil(
model_spec=model_specs[i],
specification_version=model_specs[i].specificationVersion,
file_weights_dir=coreml_models[i].weights_dir,
)
)

ops_counts = [len(prog.functions["main"].operations) for prog in progs]

assert ops_counts[0] == 1 or ops_counts[0] == 3 # (attn_mask might be cast to bool from input fp16 dtype)
assert ops_counts[1] == 1 or ops_counts[1] == 3 # the Q seq length is less than the default min seq length
assert ops_counts[2] >= 6 * 16 # 6 ops (without consts) per slice
assert ops_counts[3] >= 6 * 32

predict_inputs = copy.deepcopy(example_inputs)
if "attn_mask" in predict_inputs:
predict_inputs["attn_mask"] = predict_inputs["attn_mask"].to(dtype=torch.float32)

outputs = [list(coreml_model.predict(predict_inputs).values())[0] for coreml_model in coreml_models]

for i in range(1, len(outputs)):
assert outputs[0].shape == outputs[i].shape
np.testing.assert_allclose(outputs[0], outputs[i], rtol=0.01)

def test_scaled_dot_product_attention_sliced(self):
# Confirm the basic scenario.
example_inputs = self._get_example_inputs()
self.verify_sdpa_outputs(example_inputs)

# Confirm sdpa with Q, K and V as 4D tensors.
example_inputs = self._get_example_inputs(shape_size=4)
self.verify_sdpa_outputs(example_inputs)

# Confirm sdpa with attn_mask as a bias.
example_inputs = self._get_example_inputs(attn_mask_dtype=torch.float16)
self.verify_sdpa_outputs(example_inputs)

# Confirm sdpa with attn_mask as boolean flags.
example_inputs = self._get_example_inputs(attn_mask_dtype=torch.bool)
self.verify_sdpa_outputs(example_inputs)

# Confirm sdpa works well with different shapes for Q and K & V.
example_inputs = self._get_example_inputs(qkv_same_shape=False)
self.verify_sdpa_outputs(example_inputs)
8 changes: 8 additions & 0 deletions docs/source/coremltools.converters.mil.mil.passes.defs.rst
@@ -147,3 +147,11 @@ symbol_transform
.. automodule:: coremltools.converters.mil.mil.passes.defs.symbol_transform

.. autoclass:: materialize_symbolic_shape_program


scaled_dot_product_attention
---------------------------------------------------------

    [Collaborator review comment: Also need to be renamed?]

.. automodule:: coremltools.converters.mil.mil.passes.defs.scaled_dot_product_attention_sliced_q
    [Collaborator review comment: And here too]


.. autoclass:: scaled_dot_product_attention_sliced_q
2 changes: 1 addition & 1 deletion reqs/test.pip
@@ -53,7 +53,7 @@ pytest-timeout

transformers==4.26.0; platform_machine != "arm64"
transformers==4.38.2; platform_machine == "arm64"
peft
peft==0.13.2

# coremltools.optimize.torch
filelock==3.6.0