From c3f77f79892516fd878ef64c1625b7548c512e47 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Fri, 20 Jun 2025 11:51:33 -0700
Subject: [PATCH 1/4] Support gqa in aten sdpa

Signed-off-by: Justin Chu
---
 onnxscript/function_libs/torch_lib/ops/nn.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py
index f62a4f27a1..aa108970c2 100644
--- a/onnxscript/function_libs/torch_lib/ops/nn.py
+++ b/onnxscript/function_libs/torch_lib/ops/nn.py
@@ -1772,10 +1772,6 @@ def aten_scaled_dot_product_attention(
         "is_causal and attn_mask cannot be set at the same time"
     )
 
-    assert not enable_gqa, (
-        "conversion of scaled_dot_product_attention not implemented if enable_gqa is True"
-    )
-
     # Reference: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
     if scale is None:
         scale = _attention_scale(query)

From 06f20eee8ae8e86789112d0fd281da44a1e19104 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Fri, 20 Jun 2025 12:08:09 -0700
Subject: [PATCH 2/4] Add gqa

Signed-off-by: Justin Chu
---
 onnxscript/function_libs/torch_lib/ops/nn.py | 20 ++++++++------------
 1 file changed, 8 insertions(+), 12 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py
index aa108970c2..e479023a21 100644
--- a/onnxscript/function_libs/torch_lib/ops/nn.py
+++ b/onnxscript/function_libs/torch_lib/ops/nn.py
@@ -1786,7 +1786,7 @@ def aten_scaled_dot_product_attention(
     )
 
     return _aten_scaled_dot_product_attention_float_mask_onnx(
-        query, key, value, attn_mask, scale, dropout_p
+        query, key, value, attn_mask, scale, dropout_p, enable_gqa
     )
 
 
@@ -1978,28 +1978,24 @@ def aten_scaled_dot_product_attention_bool_mask(
         "is_causal and attn_mask cannot be set at the same time"
     )
 
-    assert not enable_gqa, (
-        "conversion of scaled_dot_product_attention not implemented if enable_gqa is True"
-    )
-
     if scale is None:
         scale = _attention_scale(query)
         scale = op.CastLike(scale, query)
 
     if is_causal:
         attn_mask = _causal_attention_mask(query, key)
-        # The causal mask is always float
-        return _aten_scaled_dot_product_attention_float_mask_onnx(
-            query, key, value, attn_mask, scale, dropout_p
-        )
 
     if attn_mask is None:
         return _aten_scaled_dot_product_attention_no_mask_onnx(
-            query, key, value, scale, dropout_p
+            query, key, value, scale, dropout_p, enable_gqa=enable_gqa
         )
 
-    return _aten_scaled_dot_product_attention_bool_mask_onnx(
-        query, key, value, attn_mask, scale, dropout_p
+    if attn_mask.dtype == ir.DataType.BOOL:
+        return _aten_scaled_dot_product_attention_bool_mask_onnx(
+            query, key, value, attn_mask, scale, dropout_p, enable_gqa=enable_gqa
+        )
+    return _aten_scaled_dot_product_attention_float_mask_onnx(
+        query, key, value, attn_mask, scale, dropout_p, enable_gqa=enable_gqa
     )
 
 

From 5a939cc8b79b36521db19767462e69c9aa592356 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Fri, 20 Jun 2025 12:09:19 -0700
Subject: [PATCH 3/4] Optional dropout

Signed-off-by: Justin Chu
---
 onnxscript/function_libs/torch_lib/ops/nn.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py
index e479023a21..858c741aeb 100644
--- a/onnxscript/function_libs/torch_lib/ops/nn.py
+++ b/onnxscript/function_libs/torch_lib/ops/nn.py
@@ -2005,6 +2005,7 @@ def _aten_scaled_dot_product_attention_no_mask_onnx(
     value: TFloat,
     scale: TFloat,
     dropout_p: float,
+    enable_gqa: bool,
 ) -> TFloat:
     # Swap the last two axes of key
     key_shape = op.Shape(key)
@@ -2029,7 +2030,8 @@ def _aten_scaled_dot_product_attention_no_mask_onnx(
         op.MatMul(query_scaled, key_transposed_scaled),
         axis=-1,
     )
-    attn_weight, _ = op.Dropout(attn_weight, dropout_p)
+    if dropout_p > 0.0:
+        attn_weight, _ = op.Dropout(attn_weight, dropout_p)
     return op.MatMul(attn_weight, value)
 
 
@@ -2040,6 +2042,7 @@ def _aten_scaled_dot_product_attention_bool_mask_onnx(
     attn_mask: BOOL,
     scale: TFloat,
     dropout_p: float,
+    enable_gqa: bool,
 ) -> TFloat:
     # Swap the last two axes of key
     key_shape = op.Shape(key)
@@ -2068,7 +2071,8 @@ def _aten_scaled_dot_product_attention_bool_mask_onnx(
         op.Add(op.MatMul(query_scaled, key_transposed_scaled), attn_mask),
         axis=-1,
     )
-    attn_weight, _ = op.Dropout(attn_weight, dropout_p)
+    if dropout_p > 0.0:
+        attn_weight, _ = op.Dropout(attn_weight, dropout_p)
     return op.MatMul(attn_weight, value)
 
 
@@ -2079,6 +2083,7 @@ def _aten_scaled_dot_product_attention_float_mask_onnx(
     attn_mask: TFloat,
     scale: TFloat,
     dropout_p: float,
+    enable_gqa: bool,
 ) -> TFloat:
     # Swap the last two axes of key
     key_shape = op.Shape(key)
@@ -2103,7 +2108,8 @@ def _aten_scaled_dot_product_attention_float_mask_onnx(
         op.Add(op.MatMul(query_scaled, key_transposed_scaled), attn_mask),
         axis=-1,
     )
-    attn_weight, _ = op.Dropout(attn_weight, dropout_p)
+    if dropout_p > 0.0:
+        attn_weight, _ = op.Dropout(attn_weight, dropout_p)
     return op.MatMul(attn_weight, value)
 
 

From 16d75e94d4f67abd85b9f1299bdca390b7c50a04 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Fri, 20 Jun 2025 12:14:53 -0700
Subject: [PATCH 4/4] wip

Signed-off-by: Justin Chu
---
 onnxscript/function_libs/torch_lib/ops/nn.py | 141 +++++++++++++++++++
 1 file changed, 141 insertions(+)

diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py
index 858c741aeb..219318aab8 100644
--- a/onnxscript/function_libs/torch_lib/ops/nn.py
+++ b/onnxscript/function_libs/torch_lib/ops/nn.py
@@ -2007,6 +2007,53 @@ def _aten_scaled_dot_product_attention_no_mask_onnx(
     dropout_p: float,
     enable_gqa: bool,
 ) -> TFloat:
+    # Handle Grouped Query Attention (GQA) if enabled
+    if enable_gqa:
+        # Get head dimensions
+        query_shape = op.Shape(query)
+        key_shape = op.Shape(key)
+        query_heads = op.Slice(query_shape, [-3], [-2])  # query.size(-3)
+        key_heads = op.Slice(key_shape, [-3], [-2])  # key.size(-3)
+
+        # Calculate the repeat factor: query_heads // key_heads
+        repeat_factor = op.Div(query_heads, key_heads)
+
+        # Expand key and value to match the query head dimension
+        # Implement key.repeat_interleave(repeat_factor, -3) with Unsqueeze/Tile/Reshape
+        # First, get the shape of key and modify the head dimension
+        key_shape_expanded = op.Concat(
+            op.Slice(key_shape, [0], [-3]),  # batch and other leading dims
+            op.Mul(key_heads, repeat_factor),  # expanded head dimension
+            op.Slice(key_shape, [-2], [_INT64_MAX]),  # remaining dims
+            axis=0
+        )
+        # Expand key by repeating each head 'repeat_factor' times
+        key_unsqueezed = op.Unsqueeze(key, [-3])  # insert a repeat axis after the kv-head dim
+        key_tiled = op.Tile(key_unsqueezed, op.Concat(
+            op.Constant(value_ints=[1, 1]),  # keep batch and kv-head dims (assumes 4D input)
+            repeat_factor,  # repeat along the inserted axis
+            op.Constant(value_ints=[1, 1]),  # keep the seq and embedding dims
+            axis=0
+        ))
+        key = op.Reshape(key_tiled, key_shape_expanded)
+
+        # Same for value
+        value_shape = op.Shape(value)
+        value_shape_expanded = op.Concat(
+            op.Slice(value_shape, [0], [-3]),  # batch and other leading dims
+            op.Mul(key_heads, repeat_factor),  # expanded head dimension
+            op.Slice(value_shape, [-2], [_INT64_MAX]),  # remaining dims
+            axis=0
+        )
+        value_unsqueezed = op.Unsqueeze(value, [-3])
+        value_tiled = op.Tile(value_unsqueezed, op.Concat(
+            op.Constant(value_ints=[1, 1]),
+            repeat_factor,
+            op.Constant(value_ints=[1, 1]),
+            axis=0
+        ))
+        value = op.Reshape(value_tiled, value_shape_expanded)
+
     # Swap the last two axes of key
     key_shape = op.Shape(key)
     key_last_dim = op.Slice(key_shape, [-1], op.Constant(value_ints=[_INT64_MAX]))
@@ -2044,6 +2091,53 @@ def _aten_scaled_dot_product_attention_bool_mask_onnx(
     dropout_p: float,
     enable_gqa: bool,
 ) -> TFloat:
+    # Handle Grouped Query Attention (GQA) if enabled
+    if enable_gqa:
+        # Get head dimensions
+        query_shape = op.Shape(query)
+        key_shape = op.Shape(key)
+        query_heads = op.Slice(query_shape, [-3], [-2])  # query.size(-3)
+        key_heads = op.Slice(key_shape, [-3], [-2])  # key.size(-3)
+
+        # Calculate the repeat factor: query_heads // key_heads
+        repeat_factor = op.Div(query_heads, key_heads)
+
+        # Expand key and value to match the query head dimension
+        # Implement key.repeat_interleave(repeat_factor, -3) with Unsqueeze/Tile/Reshape
+        # First, get the shape of key and modify the head dimension
+        key_shape_expanded = op.Concat(
+            op.Slice(key_shape, [0], [-3]),  # batch and other leading dims
+            op.Mul(key_heads, repeat_factor),  # expanded head dimension
+            op.Slice(key_shape, [-2], [_INT64_MAX]),  # remaining dims
+            axis=0
+        )
+        # Expand key by repeating each head 'repeat_factor' times
+        key_unsqueezed = op.Unsqueeze(key, [-3])  # insert a repeat axis after the kv-head dim
+        key_tiled = op.Tile(key_unsqueezed, op.Concat(
+            op.Constant(value_ints=[1, 1]),  # keep batch and kv-head dims (assumes 4D input)
+            repeat_factor,  # repeat along the inserted axis
+            op.Constant(value_ints=[1, 1]),  # keep the seq and embedding dims
+            axis=0
+        ))
+        key = op.Reshape(key_tiled, key_shape_expanded)
+
+        # Same for value
+        value_shape = op.Shape(value)
+        value_shape_expanded = op.Concat(
+            op.Slice(value_shape, [0], [-3]),  # batch and other leading dims
+            op.Mul(key_heads, repeat_factor),  # expanded head dimension
+            op.Slice(value_shape, [-2], [_INT64_MAX]),  # remaining dims
+            axis=0
+        )
+        value_unsqueezed = op.Unsqueeze(value, [-3])
+        value_tiled = op.Tile(value_unsqueezed, op.Concat(
+            op.Constant(value_ints=[1, 1]),
+            repeat_factor,
+            op.Constant(value_ints=[1, 1]),
+            axis=0
+        ))
+        value = op.Reshape(value_tiled, value_shape_expanded)
+
     # Swap the last two axes of key
     key_shape = op.Shape(key)
     key_last_dim = op.Slice(key_shape, [-1], op.Constant(value_ints=[_INT64_MAX]))
@@ -2085,6 +2179,53 @@ def _aten_scaled_dot_product_attention_float_mask_onnx(
     dropout_p: float,
     enable_gqa: bool,
 ) -> TFloat:
+    # Handle Grouped Query Attention (GQA) if enabled
+    if enable_gqa:
+        # Get head dimensions
+        query_shape = op.Shape(query)
+        key_shape = op.Shape(key)
+        query_heads = op.Slice(query_shape, [-3], [-2])  # query.size(-3)
+        key_heads = op.Slice(key_shape, [-3], [-2])  # key.size(-3)
+
+        # Calculate the repeat factor: query_heads // key_heads
+        repeat_factor = op.Div(query_heads, key_heads)
+
+        # Expand key and value to match the query head dimension
+        # Implement key.repeat_interleave(repeat_factor, -3) with Unsqueeze/Tile/Reshape
+        # First, get the shape of key and modify the head dimension
+        key_shape_expanded = op.Concat(
+            op.Slice(key_shape, [0], [-3]),  # batch and other leading dims
+            op.Mul(key_heads, repeat_factor),  # expanded head dimension
+            op.Slice(key_shape, [-2], [_INT64_MAX]),  # remaining dims
+            axis=0
+        )
+        # Expand key by repeating each head 'repeat_factor' times
+        key_unsqueezed = op.Unsqueeze(key, [-3])  # insert a repeat axis after the kv-head dim
+        key_tiled = op.Tile(key_unsqueezed, op.Concat(
+            op.Constant(value_ints=[1, 1]),  # keep batch and kv-head dims (assumes 4D input)
+            repeat_factor,  # repeat along the inserted axis
+            op.Constant(value_ints=[1, 1]),  # keep the seq and embedding dims
+            axis=0
+        ))
+        key = op.Reshape(key_tiled, key_shape_expanded)
+
+        # Same for value
+        value_shape = op.Shape(value)
+        value_shape_expanded = op.Concat(
+            op.Slice(value_shape, [0], [-3]),  # batch and other leading dims
+            op.Mul(key_heads, repeat_factor),  # expanded head dimension
+            op.Slice(value_shape, [-2], [_INT64_MAX]),  # remaining dims
+            axis=0
+        )
+        value_unsqueezed = op.Unsqueeze(value, [-3])
+        value_tiled = op.Tile(value_unsqueezed, op.Concat(
+            op.Constant(value_ints=[1, 1]),
+            repeat_factor,
+            op.Constant(value_ints=[1, 1]),
+            axis=0
+        ))
+        value = op.Reshape(value_tiled, value_shape_expanded)
+
     # Swap the last two axes of key
     key_shape = op.Shape(key)
     key_last_dim = op.Slice(key_shape, [-1], op.Constant(value_ints=[_INT64_MAX]))
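
Reviewer notes (not part of the patches)

The Unsqueeze/Tile/Reshape subgraph added in PATCH 4/4 is intended to reproduce the repeat_interleave expansion that torch.nn.functional.scaled_dot_product_attention documents for enable_gqa=True. Below is a minimal PyTorch sketch of that equivalence; the 4D (batch, heads, seq, dim) layout and the sizes are illustrative assumptions, not values taken from the ONNX Script test suite.

    # Sketch only: unsqueeze + tile + reshape matches repeat_interleave along the
    # head dimension (dim -3), which is what the added ONNX subgraph builds.
    import torch

    batch, kv_heads, seq, head_dim = 2, 4, 16, 32  # assumed 4D layout
    repeat_factor = 3  # query_heads // kv_heads

    key = torch.randn(batch, kv_heads, seq, head_dim)

    # Reference semantics documented for scaled_dot_product_attention(enable_gqa=True)
    expected = key.repeat_interleave(repeat_factor, dim=-3)

    # Unsqueeze -> Tile -> Reshape formulation mirrored by the patch
    tiled = (
        key.unsqueeze(-3)                   # (batch, kv_heads, 1, seq, head_dim)
        .repeat(1, 1, repeat_factor, 1, 1)  # (batch, kv_heads, repeat_factor, seq, head_dim)
        .reshape(batch, kv_heads * repeat_factor, seq, head_dim)
    )

    assert torch.equal(expected, tiled)

This also makes the usual GQA precondition explicit: the number of query heads must be an integer multiple of the number of key/value heads. The patch computes the factor with an integer Div and does not yet guard against a non-divisible configuration.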
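
PATCH 2/4 routes a boolean attn_mask and a float attn_mask to separate helpers via the dtype check. The two mask conventions are related by a masked_fill with -inf, which is why dispatching on dtype preserves the same attention weights; a small PyTorch sketch with made-up shapes:

    # Sketch only: a boolean mask (True = may attend) and the additive float mask
    # derived from it yield the same softmax attention weights.
    import torch

    scores = torch.randn(2, 8, 16, 16)  # (batch, heads, q_len, kv_len) attention logits
    bool_mask = torch.tril(torch.ones(16, 16)).bool()  # every query keeps at least one key

    float_mask = torch.zeros(16, 16).masked_fill(~bool_mask, float("-inf"))

    weights_from_bool = scores.masked_fill(~bool_mask, float("-inf")).softmax(dim=-1)
    weights_from_float = (scores + float_mask).softmax(dim=-1)

    assert torch.allclose(weights_from_bool, weights_from_float)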