-
Notifications
You must be signed in to change notification settings - Fork 81
Support gqa in aten spda #2408
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Support gqa in aten spda #2408
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | |||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -1772,10 +1772,6 @@ def aten_scaled_dot_product_attention( | ||||||||||||||||
"is_causal and attn_mask cannot be set at the same time" | |||||||||||||||||
) | |||||||||||||||||
|
|||||||||||||||||
assert not enable_gqa, ( | |||||||||||||||||
"conversion of scaled_dot_product_attention not implemented if enable_gqa is True" | |||||||||||||||||
) | |||||||||||||||||
|
|||||||||||||||||
# Reference: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html | |||||||||||||||||
if scale is None: | |||||||||||||||||
scale = _attention_scale(query) | |||||||||||||||||
|
@@ -1790,7 +1786,7 @@ def aten_scaled_dot_product_attention( | ||||||||||||||||
) | |||||||||||||||||
|
|||||||||||||||||
return _aten_scaled_dot_product_attention_float_mask_onnx( | |||||||||||||||||
query, key, value, attn_mask, scale, dropout_p | |||||||||||||||||
query, key, value, attn_mask, scale, dropout_p, enable_gqa | |||||||||||||||||
) | |||||||||||||||||
|
|||||||||||||||||
|
|||||||||||||||||
|
@@ -1982,28 +1978,24 @@ def aten_scaled_dot_product_attention_bool_mask( | ||||||||||||||||
"is_causal and attn_mask cannot be set at the same time" | |||||||||||||||||
) | |||||||||||||||||
|
|||||||||||||||||
assert not enable_gqa, ( | |||||||||||||||||
"conversion of scaled_dot_product_attention not implemented if enable_gqa is True" | |||||||||||||||||
) | |||||||||||||||||
|
|||||||||||||||||
if scale is None: | |||||||||||||||||
scale = _attention_scale(query) | |||||||||||||||||
scale = op.CastLike(scale, query) | |||||||||||||||||
|
|||||||||||||||||
if is_causal: | |||||||||||||||||
attn_mask = _causal_attention_mask(query, key) | |||||||||||||||||
# The causal mask is always float | |||||||||||||||||
return _aten_scaled_dot_product_attention_float_mask_onnx( | |||||||||||||||||
query, key, value, attn_mask, scale, dropout_p | |||||||||||||||||
) | |||||||||||||||||
|
|||||||||||||||||
if attn_mask is None: | |||||||||||||||||
return _aten_scaled_dot_product_attention_no_mask_onnx( | |||||||||||||||||
query, key, value, scale, dropout_p | |||||||||||||||||
query, key, value, scale, dropout_p, enable_gqa=enable_gqa | |||||||||||||||||
) | |||||||||||||||||
|
|||||||||||||||||
return _aten_scaled_dot_product_attention_bool_mask_onnx( | |||||||||||||||||
query, key, value, attn_mask, scale, dropout_p | |||||||||||||||||
if attn_mask.dtype == ir.DataType.BOOL: | |||||||||||||||||
return _aten_scaled_dot_product_attention_bool_mask_onnx( | |||||||||||||||||
query, key, value, attn_mask, scale, dropout_p, enable_gqa=enable_gqa | |||||||||||||||||
) | |||||||||||||||||
Comment on lines
+1994
to
+1996
Check failureCode scanning / CodeQL Wrong name for an argument in a call
Keyword argument 'enable_gqa' is not a supported parameter name of [function _aten_scaled_dot_product_attention_bool_mask_onnx](1).
Copilot AutofixAI 3 months ago To fix the issue, the keyword argument
Suggested changeset
1
onnxscript/function_libs/torch_lib/ops/nn.py
Copilot is powered by AI and may make mistakes. Always verify output.
Unable to commit as this autofix suggestion is now outdated
Positive FeedbackNegative Feedback
Refresh and try again.
|
|||||||||||||||||
return _aten_scaled_dot_product_attention_float_mask_onnx( | |||||||||||||||||
|
|||||||||||||||||
query, key, value, attn_mask, scale, dropout_p, enable_gqa=enable_gqa | |||||||||||||||||
) | |||||||||||||||||
|
|||||||||||||||||
|
|||||||||||||||||
|
@@ -2013,7 +2005,55 @@ def _aten_scaled_dot_product_attention_no_mask_onnx( | ||||||||||||||||
value: TFloat, | |||||||||||||||||
scale: TFloat, | |||||||||||||||||
dropout_p: float, | |||||||||||||||||
enable_gqa: bool, | |||||||||||||||||
) -> TFloat: | |||||||||||||||||
# Handle Grouped Query Attention (GQA) if enabled | |||||||||||||||||
if enable_gqa: | |||||||||||||||||
# Get head dimensions | |||||||||||||||||
query_shape = op.Shape(query) | |||||||||||||||||
key_shape = op.Shape(key) | |||||||||||||||||
query_heads = op.Slice(query_shape, [-3], [-2]) # query.size(-3) | |||||||||||||||||
key_heads = op.Slice(key_shape, [-3], [-2]) # key.size(-3) | |||||||||||||||||
|
|||||||||||||||||
# Calculate the repeat factor: query_heads // key_heads | |||||||||||||||||
repeat_factor = op.Div(query_heads, key_heads) | |||||||||||||||||
|
|||||||||||||||||
# Expand key and value to match query head dimension | |||||||||||||||||
# Implement key.repeat_interleave(repeat_factor, -3) using Expand | |||||||||||||||||
# First, get the shape of key and modify the head dimension | |||||||||||||||||
key_shape_expanded = op.Concat( | |||||||||||||||||
op.Slice(key_shape, [0], [-3]), # batch and other dims | |||||||||||||||||
op.Mul(key_heads, repeat_factor), # expanded head dimension | |||||||||||||||||
op.Slice(key_shape, [-2], [_INT64_MAX]), # remaining dims | |||||||||||||||||
axis=0 | |||||||||||||||||
) | |||||||||||||||||
# Expand key by repeating each head 'repeat_factor' times | |||||||||||||||||
key_unsqueezed = op.Unsqueeze(key, [-2]) # Add dimension for repeating | |||||||||||||||||
key_tiled = op.Tile(key_unsqueezed, op.Concat( | |||||||||||||||||
op.Constant(value_ints=[1, 1, 1]), # don't repeat batch, seq, head dims | |||||||||||||||||
repeat_factor, # repeat factor for the new dimension | |||||||||||||||||
op.Constant(value_ints=[1, 1]), # don't repeat the remaining dims | |||||||||||||||||
axis=0 | |||||||||||||||||
)) | |||||||||||||||||
key = op.Reshape(key_tiled, key_shape_expanded) | |||||||||||||||||
|
|||||||||||||||||
# Same for value | |||||||||||||||||
value_shape = op.Shape(value) | |||||||||||||||||
value_shape_expanded = op.Concat( | |||||||||||||||||
op.Slice(value_shape, [0], [-3]), # batch and other dims | |||||||||||||||||
op.Mul(key_heads, repeat_factor), # expanded head dimension | |||||||||||||||||
op.Slice(value_shape, [-2], [_INT64_MAX]), # remaining dims | |||||||||||||||||
axis=0 | |||||||||||||||||
) | |||||||||||||||||
value_unsqueezed = op.Unsqueeze(value, [-2]) | |||||||||||||||||
value_tiled = op.Tile(value_unsqueezed, op.Concat( | |||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. op.Tile does not align to PyTorch inplementation. if (
(q_num_heads != k_num_heads)
and (q_num_heads % k_num_heads == 0)
and (k_num_heads == v_num_heads)
):
seq_reps = q_num_heads // k_num_heads
# Interleave-repeat each KV head: [h0, h0, h1, h1, ...]
K = np.repeat(K, repeats=seq_reps, axis=1)
V = np.repeat(V, repeats=seq_reps, axis=1) We should be able to reuse There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should we use expand for repeat interleave for simplicity over tile? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wonder if we can just adapt whatever function body is in defs.cc to torchlib? Is there any difference? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Probably not. I must have need using the old implementation |
|||||||||||||||||
op.Constant(value_ints=[1, 1, 1]), | |||||||||||||||||
repeat_factor, | |||||||||||||||||
op.Constant(value_ints=[1, 1]), | |||||||||||||||||
axis=0 | |||||||||||||||||
)) | |||||||||||||||||
value = op.Reshape(value_tiled, value_shape_expanded) | |||||||||||||||||
|
|||||||||||||||||
# Swap the last two axes of key | |||||||||||||||||
key_shape = op.Shape(key) | |||||||||||||||||
key_last_dim = op.Slice(key_shape, [-1], op.Constant(value_ints=[_INT64_MAX])) | |||||||||||||||||
|
@@ -2037,7 +2077,8 @@ def _aten_scaled_dot_product_attention_no_mask_onnx( | ||||||||||||||||
op.MatMul(query_scaled, key_transposed_scaled), | |||||||||||||||||
axis=-1, | |||||||||||||||||
) | |||||||||||||||||
attn_weight, _ = op.Dropout(attn_weight, dropout_p) | |||||||||||||||||
if dropout_p > 0.0: | |||||||||||||||||
attn_weight, _ = op.Dropout(attn_weight, dropout_p) | |||||||||||||||||
return op.MatMul(attn_weight, value) | |||||||||||||||||
|
|||||||||||||||||
|
|||||||||||||||||
|
@@ -2048,7 +2089,55 @@ def _aten_scaled_dot_product_attention_bool_mask_onnx( | ||||||||||||||||
attn_mask: BOOL, | |||||||||||||||||
scale: TFloat, | |||||||||||||||||
dropout_p: float, | |||||||||||||||||
enable_gqa: bool, | |||||||||||||||||
) -> TFloat: | |||||||||||||||||
# Handle Grouped Query Attention (GQA) if enabled | |||||||||||||||||
if enable_gqa: | |||||||||||||||||
# Get head dimensions | |||||||||||||||||
query_shape = op.Shape(query) | |||||||||||||||||
key_shape = op.Shape(key) | |||||||||||||||||
query_heads = op.Slice(query_shape, [-3], [-2]) # query.size(-3) | |||||||||||||||||
key_heads = op.Slice(key_shape, [-3], [-2]) # key.size(-3) | |||||||||||||||||
|
|||||||||||||||||
# Calculate the repeat factor: query_heads // key_heads | |||||||||||||||||
repeat_factor = op.Div(query_heads, key_heads) | |||||||||||||||||
|
|||||||||||||||||
# Expand key and value to match query head dimension | |||||||||||||||||
# Implement key.repeat_interleave(repeat_factor, -3) using Expand | |||||||||||||||||
# First, get the shape of key and modify the head dimension | |||||||||||||||||
key_shape_expanded = op.Concat( | |||||||||||||||||
op.Slice(key_shape, [0], [-3]), # batch and other dims | |||||||||||||||||
op.Mul(key_heads, repeat_factor), # expanded head dimension | |||||||||||||||||
op.Slice(key_shape, [-2], [_INT64_MAX]), # remaining dims | |||||||||||||||||
axis=0 | |||||||||||||||||
) | |||||||||||||||||
# Expand key by repeating each head 'repeat_factor' times | |||||||||||||||||
key_unsqueezed = op.Unsqueeze(key, [-2]) # Add dimension for repeating | |||||||||||||||||
key_tiled = op.Tile(key_unsqueezed, op.Concat( | |||||||||||||||||
op.Constant(value_ints=[1, 1, 1]), # don't repeat batch, seq, head dims | |||||||||||||||||
repeat_factor, # repeat factor for the new dimension | |||||||||||||||||
op.Constant(value_ints=[1, 1]), # don't repeat the remaining dims | |||||||||||||||||
axis=0 | |||||||||||||||||
)) | |||||||||||||||||
key = op.Reshape(key_tiled, key_shape_expanded) | |||||||||||||||||
|
|||||||||||||||||
# Same for value | |||||||||||||||||
value_shape = op.Shape(value) | |||||||||||||||||
value_shape_expanded = op.Concat( | |||||||||||||||||
op.Slice(value_shape, [0], [-3]), # batch and other dims | |||||||||||||||||
op.Mul(key_heads, repeat_factor), # expanded head dimension | |||||||||||||||||
op.Slice(value_shape, [-2], [_INT64_MAX]), # remaining dims | |||||||||||||||||
axis=0 | |||||||||||||||||
) | |||||||||||||||||
value_unsqueezed = op.Unsqueeze(value, [-2]) | |||||||||||||||||
value_tiled = op.Tile(value_unsqueezed, op.Concat( | |||||||||||||||||
op.Constant(value_ints=[1, 1, 1]), | |||||||||||||||||
repeat_factor, | |||||||||||||||||
op.Constant(value_ints=[1, 1]), | |||||||||||||||||
axis=0 | |||||||||||||||||
)) | |||||||||||||||||
value = op.Reshape(value_tiled, value_shape_expanded) | |||||||||||||||||
|
|||||||||||||||||
# Swap the last two axes of key | |||||||||||||||||
key_shape = op.Shape(key) | |||||||||||||||||
key_last_dim = op.Slice(key_shape, [-1], op.Constant(value_ints=[_INT64_MAX])) | |||||||||||||||||
|
@@ -2076,7 +2165,8 @@ def _aten_scaled_dot_product_attention_bool_mask_onnx( | ||||||||||||||||
op.Add(op.MatMul(query_scaled, key_transposed_scaled), attn_mask), | |||||||||||||||||
axis=-1, | |||||||||||||||||
) | |||||||||||||||||
attn_weight, _ = op.Dropout(attn_weight, dropout_p) | |||||||||||||||||
if dropout_p > 0.0: | |||||||||||||||||
attn_weight, _ = op.Dropout(attn_weight, dropout_p) | |||||||||||||||||
return op.MatMul(attn_weight, value) | |||||||||||||||||
|
|||||||||||||||||
|
|||||||||||||||||
|
@@ -2087,7 +2177,55 @@ def _aten_scaled_dot_product_attention_float_mask_onnx( | ||||||||||||||||
attn_mask: TFloat, | |||||||||||||||||
scale: TFloat, | |||||||||||||||||
dropout_p: float, | |||||||||||||||||
enable_gqa: bool, | |||||||||||||||||
) -> TFloat: | |||||||||||||||||
# Handle Grouped Query Attention (GQA) if enabled | |||||||||||||||||
if enable_gqa: | |||||||||||||||||
# Get head dimensions | |||||||||||||||||
query_shape = op.Shape(query) | |||||||||||||||||
key_shape = op.Shape(key) | |||||||||||||||||
query_heads = op.Slice(query_shape, [-3], [-2]) # query.size(-3) | |||||||||||||||||
key_heads = op.Slice(key_shape, [-3], [-2]) # key.size(-3) | |||||||||||||||||
|
|||||||||||||||||
# Calculate the repeat factor: query_heads // key_heads | |||||||||||||||||
repeat_factor = op.Div(query_heads, key_heads) | |||||||||||||||||
|
|||||||||||||||||
# Expand key and value to match query head dimension | |||||||||||||||||
# Implement key.repeat_interleave(repeat_factor, -3) using Expand | |||||||||||||||||
# First, get the shape of key and modify the head dimension | |||||||||||||||||
key_shape_expanded = op.Concat( | |||||||||||||||||
op.Slice(key_shape, [0], [-3]), # batch and other dims | |||||||||||||||||
op.Mul(key_heads, repeat_factor), # expanded head dimension | |||||||||||||||||
op.Slice(key_shape, [-2], [_INT64_MAX]), # remaining dims | |||||||||||||||||
axis=0 | |||||||||||||||||
) | |||||||||||||||||
# Expand key by repeating each head 'repeat_factor' times | |||||||||||||||||
key_unsqueezed = op.Unsqueeze(key, [-2]) # Add dimension for repeating | |||||||||||||||||
key_tiled = op.Tile(key_unsqueezed, op.Concat( | |||||||||||||||||
op.Constant(value_ints=[1, 1, 1]), # don't repeat batch, seq, head dims | |||||||||||||||||
repeat_factor, # repeat factor for the new dimension | |||||||||||||||||
op.Constant(value_ints=[1, 1]), # don't repeat the remaining dims | |||||||||||||||||
axis=0 | |||||||||||||||||
)) | |||||||||||||||||
key = op.Reshape(key_tiled, key_shape_expanded) | |||||||||||||||||
|
|||||||||||||||||
# Same for value | |||||||||||||||||
value_shape = op.Shape(value) | |||||||||||||||||
value_shape_expanded = op.Concat( | |||||||||||||||||
op.Slice(value_shape, [0], [-3]), # batch and other dims | |||||||||||||||||
op.Mul(key_heads, repeat_factor), # expanded head dimension | |||||||||||||||||
op.Slice(value_shape, [-2], [_INT64_MAX]), # remaining dims | |||||||||||||||||
axis=0 | |||||||||||||||||
) | |||||||||||||||||
value_unsqueezed = op.Unsqueeze(value, [-2]) | |||||||||||||||||
value_tiled = op.Tile(value_unsqueezed, op.Concat( | |||||||||||||||||
op.Constant(value_ints=[1, 1, 1]), | |||||||||||||||||
repeat_factor, | |||||||||||||||||
op.Constant(value_ints=[1, 1]), | |||||||||||||||||
axis=0 | |||||||||||||||||
)) | |||||||||||||||||
value = op.Reshape(value_tiled, value_shape_expanded) | |||||||||||||||||
|
|||||||||||||||||
# Swap the last two axes of key | |||||||||||||||||
key_shape = op.Shape(key) | |||||||||||||||||
key_last_dim = op.Slice(key_shape, [-1], op.Constant(value_ints=[_INT64_MAX])) | |||||||||||||||||
|
@@ -2111,7 +2249,8 @@ def _aten_scaled_dot_product_attention_float_mask_onnx( | ||||||||||||||||
op.Add(op.MatMul(query_scaled, key_transposed_scaled), attn_mask), | |||||||||||||||||
axis=-1, | |||||||||||||||||
) | |||||||||||||||||
attn_weight, _ = op.Dropout(attn_weight, dropout_p) | |||||||||||||||||
if dropout_p > 0.0: | |||||||||||||||||
attn_weight, _ = op.Dropout(attn_weight, dropout_p) | |||||||||||||||||
return op.MatMul(attn_weight, value) | |||||||||||||||||
|
|||||||||||||||||
|
|||||||||||||||||
|
Uh oh!
There was an error while loading. Please reload this page.