diff --git a/examples/models/llama/attention.py b/examples/models/llama/attention.py
index 54f738ba737..f63f15d2699 100644
--- a/examples/models/llama/attention.py
+++ b/examples/models/llama/attention.py
@@ -133,7 +133,7 @@ def __init__(
 
     def forward(
         self,
-        input_pos: torch.Tensor,
+        input_pos: Optional[torch.Tensor],
         q: torch.Tensor,  # Already have rotary embeddings. (bs, n_local_heads, seqlen, head_dim)
         k: torch.Tensor,  # Already have rotary embeddings. (bs, n_local_kv_heads, seqlen, head_dim)
         v: torch.Tensor,  # (bs, n_local_kv_heads, seqlen, head_dim)
@@ -218,13 +218,17 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
                 self.head_dim,
                 args.enable_dynamic_shape,
             )
-            self.SDPA = SDPA(
-                dim=self.n_local_heads * self.head_dim,
-                head_dim=self.head_dim,
-                n_rep=self.n_rep,
-                max_context_len=self.max_context_len,
-                enable_dynamic_shape=args.enable_dynamic_shape,
-            )
+        else:
+            # Use a constant state to avoid export error
+            self.zero_pos = torch.tensor([0])
+
+        self.SDPA = SDPA(
+            dim=self.n_local_heads * self.head_dim,
+            head_dim=self.head_dim,
+            n_rep=self.n_rep,
+            max_context_len=self.max_context_len,
+            enable_dynamic_shape=args.enable_dynamic_shape,
+        )
 
     def forward(
         self,
@@ -258,20 +262,8 @@ def forward(
             assert input_pos is not None
             k, v = self.kv_cache.update(input_pos, k, v)
             output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask)
-            return self.wo(output), None
-
-        # grouped multiquery attention: expand out keys and values
-        k = k.repeat_interleave(self.n_rep, dim=1)
-        v = v.repeat_interleave(self.n_rep, dim=1)
-
-        assert hasattr(self, "mask")
-
-        mask = self.mask[:seqlen, :seqlen]
-
-        output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
-
-        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
-
-        output = self.wo(output)
-
-        return output, None
+        else:
+            mask = self.mask[:seqlen, :seqlen]
+            # No kv cache. Pass 0 input_pos
+            output = self.SDPA(self.zero_pos, q, k, v, bsz, seqlen, mask)
+        return self.wo(output), None
diff --git a/examples/models/llama/tests/test_attention_sma.py b/examples/models/llama/tests/test_attention_sma.py
new file mode 100644
index 00000000000..beff1e95f18
--- /dev/null
+++ b/examples/models/llama/tests/test_attention_sma.py
@@ -0,0 +1,88 @@
+import unittest
+
+import torch
+from executorch.examples.models.llama.attention import (
+    AttentionMHA,
+    KVCache,
+    ModelArgs,
+    Rope,
+    SDPA,
+)
+
+
+class TestAttentionMHA(unittest.TestCase):
+
+    def create_mock_args(self):
+        return ModelArgs(
+            use_kv_cache=True,
+            n_heads=8,
+            n_kv_heads=4,
+            head_dim=64,
+            max_batch_size=2,
+            max_context_len=16,
+            dim=512,
+            attention_qkv_bias=False,
+            enable_dynamic_shape=False,
+        )
+
+    def test_attentionmha_init(self):
+        args = self.create_mock_args()
+        rope = Rope(args)
+        attn = AttentionMHA(args, layer_id=0, rope=rope)
+
+        self.assertEqual(attn.n_heads, 8)
+        self.assertEqual(attn.n_kv_heads, 4)
+        self.assertEqual(attn.n_local_heads, 8)
+        self.assertEqual(attn.n_local_kv_heads, 4)
+        self.assertEqual(attn.head_dim, 64)
+        self.assertEqual(attn.dim, 512)
+        self.assertEqual(attn.mask.shape, (16, 16))  # Causal mask shape check
+        self.assertTrue(attn.use_kv_cache)
+
+        if attn.use_kv_cache:
+            self.assertIsInstance(attn.kv_cache, KVCache)
+            self.assertIsInstance(attn.SDPA, SDPA)
+
+    def test_attentionmha_forward(self):
+        args = self.create_mock_args()
+        rope = Rope(args)
+        attn = AttentionMHA(args, layer_id=0, rope=rope)
+
+        bsz, seqlen, dim = 2, 4, args.dim
+        x = torch.randn(bsz, seqlen, dim)
+        freqs_cos = torch.randn(seqlen, args.head_dim // 2)
+        freqs_sin = torch.randn(seqlen, args.head_dim // 2)
+        input_pos = torch.tensor([0, 1, 2, 3])
+
+        output, _ = attn.forward(x, freqs_cos, freqs_sin, input_pos=input_pos)
+
+        self.assertEqual(output.shape, (bsz, seqlen, dim))
+
+    def test_attentionmha_forward_no_kv_cache(self):
+        args = self.create_mock_args()
+        args.use_kv_cache = False  # Disable KV cache for this test
+        rope = Rope(args)
+        attn = AttentionMHA(args, layer_id=0, rope=rope)
+
+        bsz, seqlen, dim = 2, 4, args.dim
+        x = torch.randn(bsz, seqlen, dim)
+        freqs_cos = torch.randn(seqlen, args.head_dim // 2)
+        freqs_sin = torch.randn(seqlen, args.head_dim // 2)
+
+        output, _ = attn.forward(x, freqs_cos, freqs_sin)
+
+        self.assertEqual(output.shape, (bsz, seqlen, dim))
+
+    def test_attentionmha_invalid_kv_cache(self):
+        args = self.create_mock_args()
+        rope = Rope(args)
+        attn = AttentionMHA(args, layer_id=0, rope=rope)
+
+        bsz, seqlen, dim = 2, 4, args.dim
+        x = torch.randn(bsz, seqlen, dim)
+        freqs_cos = torch.randn(seqlen, args.head_dim // 2)
+        freqs_sin = torch.randn(seqlen, args.head_dim // 2)
+
+        # No input_pos provided, should raise assertion error
+        with self.assertRaises(AssertionError):
+            attn.forward(x, freqs_cos, freqs_sin)
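
For reference, a minimal sketch (not part of the diff) of exercising the new no-KV-cache path, which now routes through the shared SDPA module with the constant zero_pos instead of the old inline F.scaled_dot_product_attention branch. It reuses the mock ModelArgs values from the new test; treat the exact constructor arguments as assumptions borrowed from that test rather than required settings.

import torch
from executorch.examples.models.llama.attention import AttentionMHA, ModelArgs, Rope

# Same mock configuration as create_mock_args() in the new test, but with the
# KV cache disabled so forward() takes the new `else` branch through SDPA.
args = ModelArgs(
    use_kv_cache=False,
    n_heads=8,
    n_kv_heads=4,
    head_dim=64,
    max_batch_size=2,
    max_context_len=16,
    dim=512,
    attention_qkv_bias=False,
    enable_dynamic_shape=False,
)
attn = AttentionMHA(args, layer_id=0, rope=Rope(args))

bsz, seqlen = 2, 4
x = torch.randn(bsz, seqlen, args.dim)
freqs_cos = torch.randn(seqlen, args.head_dim // 2)
freqs_sin = torch.randn(seqlen, args.head_dim // 2)

# No input_pos given: forward() substitutes self.zero_pos and calls SDPA directly.
output, _ = attn.forward(x, freqs_cos, freqs_sin)
assert output.shape == (bsz, seqlen, args.dim)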