
Commit

simplify the attention calculation
Jan Jirmasek committed Dec 23, 2024
1 parent c69cf80 commit 05c5fc3
Showing 2 changed files with 9 additions and 19 deletions.
@@ -160,32 +160,21 @@ def _replace_scaled_dot_product_attention(self, op):
 )
 mul_out = mb.add(x=mul_out, y=mask_out, before_op=op)
 
-# Numerical stability of softmax operation.
-max_out = mb.reduce_max(x=mul_out, axes=[-1], keep_dims=True, before_op=op)
-sub_out = mb.sub(x=mul_out, y=max_out, before_op=op)
+# Calculate softmax of the product.
+softmax_out = mb.softmax(x=mul_out, axis=-1, before_op=op)
 
-exp_out = mb.exp(x=sub_out, before_op=op)  # (x, bs, chunk_size, k_seq_length)
-
-# Calculate s_star.
-sum_out = mb.reduce_sum(x=exp_out, axes=[-1], keep_dims=True, before_op=op)
-tile_reps = [1] * (q_size - 1) + [dims]
-tile_out = mb.tile(x=sum_out, reps=tile_reps, before_op=op)  # (x, bs, chunk_size, dims)
-
-# Calculate v_star.
+# Calculate the chunk of attention.
 matmul_v_out = mb.matmul(
-    x=exp_out,
+    x=softmax_out,
     y=v,
     transpose_x=False,
     transpose_y=False,
     before_op=op,
 )
 
-# v_star / s_star
-div_out = mb.real_div(x=matmul_v_out, y=tile_out, before_op=op)  # (x, bs, chunk_size, dims)
-
 # Add the chunk of attention to the result value.
 concat_values = [concat_out] if concat_out is not None else []
-concat_out = mb.concat(values=concat_values + [div_out], axis=-2, interleave=False, before_op=op)
+concat_out = mb.concat(values=concat_values + [matmul_v_out], axis=-2, interleave=False, before_op=op)
 
 # Remove the original SDPA operation.
 op.enclosing_block.replace_uses_of_var_after_op(
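The removed ops spell out a numerically stable softmax by hand and defer the normalization until after the matmul with V (v_star / s_star); dividing the matmul result by the tiled row sums is equivalent to normalizing the weights first, because each row's sum is a scalar that factors out of the matmul. A minimal NumPy sketch of that equivalence (illustrative names and shapes, not the pass's actual code):

import numpy as np
from scipy.special import softmax  # reference softmax over the last axis

# Hypothetical shapes: (batch, chunk_size, k_seq_length) scores and
# (batch, k_seq_length, dims) values; names mirror the diff's comments only.
rng = np.random.default_rng(0)
scores = rng.standard_normal((2, 4, 8))
v = rng.standard_normal((2, 8, 16))

# Old path (removed ops): unnormalized exp weights, matmul with V,
# then divide by the row sums (reduce_max, sub, exp, reduce_sum, tile, real_div).
exp_out = np.exp(scores - scores.max(axis=-1, keepdims=True))
sum_out = exp_out.sum(axis=-1, keepdims=True)
old = (exp_out @ v) / sum_out  # tiling is implicit via broadcasting here

# New path: normalize first with a single softmax, then matmul, as mb.softmax + mb.matmul do.
new = softmax(scores, axis=-1) @ v

np.testing.assert_allclose(old, new, rtol=1e-7)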
7 changes: 4 additions & 3 deletions coremltools/converters/mil/mil/passes/tests/test_passes.py
@@ -7471,8 +7471,8 @@ def verify_sdpa_outputs(self, example_inputs: Dict[str, torch.Tensor]):
 
 assert ops_counts[0] == 1 or ops_counts[0] == 3  # (attn_mask might be cast to bool from input fp16 dtype)
 assert ops_counts[1] == 1 or ops_counts[1] == 3  # the Q seq length is less than the default min seq length
-assert ops_counts[2] >= 11 * 16  # 11 ops (without consts) per slice
-assert ops_counts[3] >= 11 * 32
+assert ops_counts[2] >= 6 * 16  # 6 ops (without consts) per slice
+assert ops_counts[3] >= 6 * 32
 
 predict_inputs = copy.deepcopy(example_inputs)
 if "attn_mask" in predict_inputs:
@@ -7481,7 +7481,8 @@ def verify_sdpa_outputs(self, example_inputs: Dict[str, torch.Tensor]):
 outputs = [list(coreml_model.predict(predict_inputs).values())[0] for coreml_model in coreml_models]
 
 for i in range(1, len(outputs)):
-    np.testing.assert_allclose(outputs[0], outputs[i], rtol=0.01, strict=True)
+    assert outputs[0].shape == outputs[i].shape
+    np.testing.assert_allclose(outputs[0], outputs[i], rtol=0.01)
 
 def test_scaled_dot_product_attention_sliced(self):
     # Confirm the basic scenario.
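The lowered per-slice op count is consistent with the pass change: six op calls (reduce_max, sub, exp, reduce_sum, tile, real_div) are replaced by a single softmax, so 11 - 6 + 1 = 6 non-const ops per slice. The strict=True flag on assert_allclose (which also enforces matching shape and dtype) is swapped for an explicit shape assertion plus a plain tolerance check; a tiny sketch of that comparison with dummy arrays standing in for the model outputs:

import numpy as np

# Dummy stand-ins for the predictions of the converted models.
outputs = [np.full((1, 4, 16), 1.0), np.full((1, 4, 16), 1.004)]

for i in range(1, len(outputs)):
    # Shape equality is asserted explicitly (instead of relying on strict=True),
    # then values must agree within a 1% relative tolerance.
    assert outputs[0].shape == outputs[i].shape
    np.testing.assert_allclose(outputs[0], outputs[i], rtol=0.01)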

