@@ -407,7 +407,7 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
         _query_start_loc_to_chunk_indices_offsets(
             cu_seqlens, chunk_size, cu_seqlens[-1])
     Y_ref = torch.empty_like(X)
-    state_ref = mamba_chunk_scan_combined(
+    state_ref = mamba_chunk_scan_combined_varlen(
         X,
         dt,
         A,
@@ -419,7 +419,6 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
         seq_idx=seq_idx,
         chunk_indices=chunk_indices,
         chunk_offsets=chunk_offsets,
-        return_varlen_states=True,
         initial_states=None,
         out=Y_ref,
     )
@@ -435,27 +434,27 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
     chunked_seq_idx = torch.repeat_interleave(
         torch.arange(len(chunked_seqlens), device=device),
         chunked_seqlens,
-        output_size=chunked_cu_seqlens[-1]).unsqueeze(0).to(torch.int32)
+        output_size=chunked_cu_seqlens[-1]).to(torch.int32)
     chunked_input_seq_len = chunked_cu_seqlens[-1]
-    X_chunked = torch.zeros_like(X)[:, :chunked_input_seq_len, ...]
-    dt_chunked = torch.zeros_like(dt)[:, :chunked_input_seq_len, ...]
-    B_chunked = torch.zeros_like(B)[:, :chunked_input_seq_len, ...]
-    C_chunked = torch.zeros_like(C)[:, :chunked_input_seq_len, ...]
+    X_chunked = torch.zeros_like(X)[:chunked_input_seq_len, ...]
+    dt_chunked = torch.zeros_like(dt)[:chunked_input_seq_len, ...]
+    B_chunked = torch.zeros_like(B)[:chunked_input_seq_len, ...]
+    C_chunked = torch.zeros_like(C)[:chunked_input_seq_len, ...]
     for i in range(num_sequences):
         # fmt: off
-        chunk_f = lambda x, i: x[:, cu_seqlens[i]:cu_seqlens[i] + chunked_seqlens[i], ...]  # noqa: E501
+        chunk_f = lambda x, i: x[cu_seqlens[i]:cu_seqlens[i] + chunked_seqlens[i], ...]  # noqa: E501

-        X_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i + 1], ...] = chunk_f(X, i)  # noqa: E501
-        dt_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i + 1], ...] = chunk_f(dt, i)  # noqa: E501
-        B_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i + 1], ...] = chunk_f(B, i)  # noqa: E501
-        C_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i + 1], ...] = chunk_f(C, i)  # noqa: E501
+        X_chunked[chunked_cu_seqlens[i]:chunked_cu_seqlens[i + 1], ...] = chunk_f(X, i)  # noqa: E501
+        dt_chunked[chunked_cu_seqlens[i]:chunked_cu_seqlens[i + 1], ...] = chunk_f(dt, i)  # noqa: E501
+        B_chunked[chunked_cu_seqlens[i]:chunked_cu_seqlens[i + 1], ...] = chunk_f(B, i)  # noqa: E501
+        C_chunked[chunked_cu_seqlens[i]:chunked_cu_seqlens[i + 1], ...] = chunk_f(C, i)  # noqa: E501
         # fmt: on

     chunk_indices, chunk_offsets = \
         _query_start_loc_to_chunk_indices_offsets(
             chunked_cu_seqlens, chunk_size, chunked_cu_seqlens[-1])
     Y_partial = torch.empty_like(X_chunked)
-    partial_state = mamba_chunk_scan_combined(
+    partial_state = mamba_chunk_scan_combined_varlen(
         X_chunked,
         dt_chunked,
         A,
@@ -467,7 +466,6 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
         seq_idx=chunked_seq_idx,
         chunk_indices=chunk_indices,
         chunk_offsets=chunk_offsets,
-        return_varlen_states=True,
         initial_states=None,
         out=Y_partial,
     )
@@ -482,29 +480,28 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
     remaining_chunked_seq_idx = torch.repeat_interleave(
         torch.arange(len(remaining_chunked_seqlens), device=device),
         remaining_chunked_seqlens,
-        output_size=remaining_chunked_cu_seqlens[-1]).unsqueeze(0).to(
-            torch.int32)
+        output_size=remaining_chunked_cu_seqlens[-1]).to(torch.int32)
     remaining_chunked_input_seq_len = remaining_chunked_cu_seqlens[-1]
     # fmt: off
-    remaining_X_chunked = torch.zeros_like(X)[:, :remaining_chunked_input_seq_len, ...]  # noqa: E501
-    remaining_dt_chunked = torch.zeros_like(dt)[:, :remaining_chunked_input_seq_len, ...]  # noqa: E501
-    remaining_B_chunked = torch.zeros_like(B)[:, :remaining_chunked_input_seq_len, ...]  # noqa: E501
-    remaining_C_chunked = torch.zeros_like(C)[:, :remaining_chunked_input_seq_len, ...]  # noqa: E501
+    remaining_X_chunked = torch.zeros_like(X)[:remaining_chunked_input_seq_len, ...]  # noqa: E501
+    remaining_dt_chunked = torch.zeros_like(dt)[:remaining_chunked_input_seq_len, ...]  # noqa: E501
+    remaining_B_chunked = torch.zeros_like(B)[:remaining_chunked_input_seq_len, ...]  # noqa: E501
+    remaining_C_chunked = torch.zeros_like(C)[:remaining_chunked_input_seq_len, ...]  # noqa: E501
     for i in range(num_sequences):
-        remaining_chunk_f = lambda x, i: x[:, cu_seqlens[i] + chunked_seqlens[i]:cu_seqlens[i + 1], ...]  # noqa: E501
+        remaining_chunk_f = lambda x, i: x[cu_seqlens[i] + chunked_seqlens[i]:cu_seqlens[i + 1], ...]  # noqa: E501

-        remaining_X_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i + 1], ...] = remaining_chunk_f(X, i)  # noqa: E501
-        remaining_dt_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i + 1], ...] = remaining_chunk_f(dt, i)  # noqa: E501
-        remaining_B_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i + 1], ...] = remaining_chunk_f(B, i)  # noqa: E501
-        remaining_C_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i + 1], ...] = remaining_chunk_f(C, i)  # noqa: E501
+        remaining_X_chunked[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i + 1], ...] = remaining_chunk_f(X, i)  # noqa: E501
+        remaining_dt_chunked[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i + 1], ...] = remaining_chunk_f(dt, i)  # noqa: E501
+        remaining_B_chunked[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i + 1], ...] = remaining_chunk_f(B, i)  # noqa: E501
+        remaining_C_chunked[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i + 1], ...] = remaining_chunk_f(C, i)  # noqa: E501

     # assert input chunking is correct
     concat_chunk_f = lambda pt1, pt2, i: torch.cat([
-        pt1[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i + 1], ...],
-        pt2[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i + 1], ...],
+        pt1[chunked_cu_seqlens[i]:chunked_cu_seqlens[i + 1], ...],
+        pt2[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i + 1], ...],
     ],
-    dim=1)
-    concat_batch_f = lambda pt1, pt2: torch.cat([concat_chunk_f(pt1, pt2, i) for i in range(num_sequences)], dim=1)  # noqa: E501
+    dim=0)
+    concat_batch_f = lambda pt1, pt2: torch.cat([concat_chunk_f(pt1, pt2, i) for i in range(num_sequences)], dim=0)  # noqa: E501
     # fmt: on

     assert concat_batch_f(X_chunked, remaining_X_chunked).equal(X)
@@ -519,7 +516,7 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
             remaining_chunked_cu_seqlens[-1])

     Y_chunked = torch.empty_like(remaining_X_chunked)
-    state_chunked = mamba_chunk_scan_combined(
+    state_chunked = mamba_chunk_scan_combined_varlen(
         remaining_X_chunked,
         remaining_dt_chunked,
         A,
@@ -531,25 +528,24 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
         seq_idx=remaining_chunked_seq_idx,
         chunk_indices=chunk_indices,
         chunk_offsets=chunk_offsets,
-        return_varlen_states=True,
         initial_states=partial_state,
         out=Y_chunked,
     )
     Y = concat_batch_f(Y_partial, Y_chunked)

     # kernel chunked is same as kernel overall
     for i in range(num_sequences):
-        Y_seq = Y[:, cu_seqlens[i]:cu_seqlens[i + 1], ...]
-        Y_ref_seq = Y_ref[:, cu_seqlens[i]:cu_seqlens[i + 1], ...]
+        Y_seq = Y[cu_seqlens[i]:cu_seqlens[i + 1], ...]
+        Y_ref_seq = Y_ref[cu_seqlens[i]:cu_seqlens[i + 1], ...]
         torch.testing.assert_close(
-            Y_seq[:, :chunked_seqlens[i], ...],
-            Y_ref_seq[:, :chunked_seqlens[i], ...],
+            Y_seq[:chunked_seqlens[i], ...],
+            Y_ref_seq[:chunked_seqlens[i], ...],
             atol=atol,
             rtol=rtol,
             msg=lambda x: f"seq{i} output part1 " + x)  # noqa: B023
         torch.testing.assert_close(
-            Y_seq[:, chunked_seqlens[i]:, ...],
-            Y_ref_seq[:, chunked_seqlens[i]:, ...],
+            Y_seq[chunked_seqlens[i]:, ...],
+            Y_ref_seq[chunked_seqlens[i]:, ...],
             atol=atol,
             rtol=rtol,
             msg=lambda x: f"seq{i} output part2 " + x)  # noqa: B023
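The net effect of these hunks is that the test now calls the varlen entry point on flattened tensors whose leading dimension is the total token count, rather than tensors carrying a dummy batch dimension of 1, so every slice, the seq_idx construction, and the concatenations move from dim 1 to dim 0, and the explicit return_varlen_states flag disappears. The sketch below is not part of the PR; the shapes, sizes, and variable values are illustrative assumptions that only mirror how the test builds cu_seqlens and seq_idx.

# Illustrative sketch only: shows the flattened (total_tokens, ...) layout
# implied by the diff above; concrete sizes are assumptions, not from the PR.
import torch

seqlens = torch.tensor([5, 3, 7])  # per-sequence lengths in the batch
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int32),
                        torch.cumsum(seqlens, dim=0).to(torch.int32)])
total_tokens = int(cu_seqlens[-1])

# Old layout: X.shape == (1, total_tokens, nheads, headdim), sliced as X[:, s:e, ...]
# New varlen layout: X.shape == (total_tokens, nheads, headdim), sliced as X[s:e, ...]
nheads, headdim = 4, 8
X = torch.randn(total_tokens, nheads, headdim)

# seq_idx maps each token to its sequence; note there is no .unsqueeze(0) anymore.
seq_idx = torch.repeat_interleave(torch.arange(len(seqlens)),
                                  seqlens,
                                  output_size=total_tokens).to(torch.int32)
assert seq_idx.shape == (total_tokens,)
# Slicing a sequence out of the flattened token dimension recovers its length.
assert X[cu_seqlens[1]:cu_seqlens[2], ...].shape[0] == seqlens[1]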