@@ -292,28 +292,32 @@ def get_active_block_tables(block_tables, query_lens, seq_lens, block_size,
 
 
 @pytest.mark.parametrize(
-    "prefill_batch_size,decode_batch_size,block_size,large_tile_size",
+    "prefill_batch_size,decode_batch_size,block_size,large_tile_size,num_heads,num_queries_per_kv,head_size,mixed_precision",
     [
-        (1, 199, 1, 512),  # 512 blocks
-        (4, 12, 256, 2048),  # 128 blocks
-        (4, 12, 16, 2048),  # 128 blocks
-        (4, 12, 4, 1024),  # 256 blocks
-        (4, 12, 32, 2048),  # 64 blocks
-        (4, 12, 32, 4096),  # 128 blocks
-        (4, 12, 32, 8192),  # 256 blocks
-        (4, 12, 64, 8192),  # 128 blocks
-    ],
-)
-@pytest.mark.parametrize(
-    "num_heads,num_queries_per_kv,head_size",
-    [
-        (4, 2, 8),
-        (32, 8, 64),
-        (4, 4, 128),
-        (8, 1, 32),
-    ],
-)
-@pytest.mark.parametrize("mixed_precision", [True, False])
+        # Test minimal configurations (small block size)
+        (1, 199, 1, 512, 4, 2, 8, False
+         ),  # minimal block size, small dimensions
+        (1, 199, 1, 512, 4, 2, 8, True),  # same with mixed precision
+
+        # Test common/medium configurations
+        (4, 12, 32, 2048, 32, 8, 64, False),  # common case, larger heads
+        (4, 12, 32, 2048, 16, 4, 32,
+         True),  # medium size, mixed precision, grouped-query attention (GQA)
+
+        # Test large configurations
+        (4, 12, 256, 8192, 8, 1, 128, False),  # large blocks, large head size
+        (4, 12, 256, 8192, 64, 8, 64, True),  # large blocks, many heads
+
+        # Test asymmetric configurations
+        (2, 24, 64, 4096, 12, 4, 96, False),  # varied batch sizes
+        (8, 8, 128, 2048, 24, 2, 48, True),  # balanced batches
+
+        # Test edge cases
+        (1, 128, 16, 1024, 4, 2, 16, False),  # large decode batch
+        (16, 4, 8, 8192, 48, 1, 128, True),  # large prefill batch
+        (4, 12, 32, 2048, 16, 1, 32, True),  # multi-head attention (MHA)
+        (4, 12, 32, 2048, 16, 16, 32, True),  # multi-query attention (MQA)
+    ])
 @torch.inference_mode()
 def test_contexted_kv_attention(
     prefill_batch_size: int,
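
For context on why the three stacked decorators were merged: stacked @pytest.mark.parametrize decorators take the Cartesian product of their value lists, so the old version collected 8 x 4 x 2 = 64 combinations, while the single combined list above runs 12 curated cases. A minimal sketch of that cross-product behavior (the toy test below is illustrative, not part of this file):

import pytest

# Stacked parametrize decorators multiply: this toy test is collected
# 4 times, once per combination: (1, "x"), (1, "y"), (2, "x"), (2, "y").
@pytest.mark.parametrize("a", [1, 2])
@pytest.mark.parametrize("b", ["x", "y"])
def test_cross_product(a, b):
    assert a in (1, 2) and b in ("x", "y")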
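The new num_queries_per_kv column determines which attention variant a case exercises, as the inline comments in the added tuples note. A hypothetical helper (not part of the test file) spelling out that mapping under the usual convention that num_kv_heads = num_heads // num_queries_per_kv:

def attention_variant(num_heads: int, num_queries_per_kv: int) -> str:
    """Classify a head configuration; hypothetical, for illustration only."""
    assert num_heads % num_queries_per_kv == 0
    num_kv_heads = num_heads // num_queries_per_kv
    if num_queries_per_kv == 1:
        return "MHA"  # every query head has its own KV head
    if num_kv_heads == 1:
        return "MQA"  # all query heads share a single KV head
    return "GQA"  # groups of query heads share each KV head

print(attention_variant(16, 1))   # MHA, matches the (..., 16, 1, 32, True) case
print(attention_variant(16, 16))  # MQA, matches the (..., 16, 16, 32, True) case
print(attention_variant(16, 4))   # GQA, matches the (..., 16, 4, 32, True) case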