@@ -412,71 +412,13 @@ def forward_cuda(
             dim=-1,
         )
 
-        # 2. Convolution sequence transformation
         conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
                                                self.conv1d.weight.size(2))
 
-        # causal_conv1d_fn deals with both prefill and decode if input
-        # has prefill requests.
-        if has_prefill:
-            # |---------- N-1 iteration --------|
-            # |---------------- N iteration ---------------------|
-            # |- tokenA -|......................|-- newTokens ---|
-            # |---------- context_len ----------|
-            # |-------------------- seq_len ---------------------|
-            #                                   |-- query_len ---|
-
-            # - "cache_indices" updates the conv_state cache in positions
-            #   pointed to by "mamba_cache_params.state_indices_tensor"
-            hidden_states_B_C = causal_conv1d_fn(
-                hidden_states_B_C.transpose(0, 1),
-                conv_weights,
-                self.conv1d.bias,
-                activation=self.activation,
-                conv_states=mamba_cache_params.conv_state,
-                has_initial_state=mamba2_metadata.has_initial_states,
-                cache_indices=mamba_cache_params.state_indices_tensor,
-                query_start_loc=attn_metadata.query_start_loc).transpose(
-                    0, 1)[:seq_len]
-
-            # TODO: Why is this needed?
-            hidden_states_B_C = hidden_states_B_C.contiguous()
-        else:
-            hidden_states_B_C = causal_conv1d_update(
-                hidden_states_B_C,
-                mamba_cache_params.conv_state,
-                conv_weights,
-                self.conv1d.bias,
-                self.activation,
-                conv_state_indices=mamba_cache_params.state_indices_tensor)
-
-        # - get hidden_states, B and C after depthwise convolution.
-        hidden_states, B, C = torch.split(
-            hidden_states_B_C,
-            [
-                self.intermediate_size // self.tp_size,
-                groups_time_state_size // self.tp_size,
-                groups_time_state_size // self.tp_size,
-            ],
-            dim=-1,
-        )
-
-        # 3. State Space Model sequence transformation
-
         # Separate prefill and decode by splitting varlen input
         # Split along token dimension
-        hidden_states_p, hidden_states_d = torch.split(
-            hidden_states,
-            [num_prefill_tokens, num_decodes],
-            dim=0,
-        )
-        B_p, B_d = torch.split(
-            B,
-            [num_prefill_tokens, num_decodes],
-            dim=0,
-        )
-        C_p, C_d = torch.split(
-            C,
+        hidden_states_B_C_p, hidden_states_B_C_d = torch.split(
+            hidden_states_B_C,
             [num_prefill_tokens, num_decodes],
             dim=0,
         )
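The hunk above moves the prefill/decode separation ahead of the convolution: instead of convolving the whole batch and then splitting `hidden_states`, `B`, and `C` three separate times, the fused `hidden_states_B_C` tensor is split once along the token dimension. A minimal sketch of that split, with invented sizes (vLLM packs all prefill tokens before the single-token decode rows):

```python
import torch

# Invented sizes; in vLLM these come from the attention metadata.
num_prefill_tokens, num_decodes, width = 10, 3, 8
hidden_states_B_C = torch.randn(num_prefill_tokens + num_decodes, width)

# Prefill tokens occupy the first rows, one row per decode request follows,
# so a single split along dim 0 separates the two execution paths.
hidden_states_B_C_p, hidden_states_B_C_d = torch.split(
    hidden_states_B_C, [num_prefill_tokens, num_decodes], dim=0)

assert hidden_states_B_C_p.shape == (num_prefill_tokens, width)
assert hidden_states_B_C_d.shape == (num_decodes, width)
```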
@@ -491,18 +433,50 @@ def forward_cuda(
             [num_prefills, num_decodes],
             dim=0,
         )
+        query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills + 1]
+                             if has_prefill else None)
 
-        hidden_states_list = []
+        # - get hidden_states, B and C after depthwise convolution.
+        split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split(
+            hidden_states_B_C,
+            [
+                self.intermediate_size // self.tp_size,
+                groups_time_state_size // self.tp_size,
+                groups_time_state_size // self.tp_size,
+            ],
+            dim=-1,
+        )
+
+        ssd_output_list = []
 
         # Process prefill requests
         if has_prefill:
+            # 2. Convolution sequence transformation
+            # - "cache_indices" updates the conv_state cache in positions
+            #   pointed to by "mamba_cache_params.state_indices_tensor"
+            hidden_states_B_C_p = causal_conv1d_fn(
+                hidden_states_B_C_p.transpose(0, 1),
+                conv_weights,
+                self.conv1d.bias,
+                activation=self.activation,
+                conv_states=mamba_cache_params.conv_state,
+                has_initial_state=mamba2_metadata.has_initial_states,
+                cache_indices=state_indices_tensor_p,
+                query_start_loc=query_start_loc_p).transpose(
+                    0, 1)[:num_prefill_tokens]
+
+            # TODO: Why is this needed?
+            hidden_states_B_C_p = hidden_states_B_C_p.contiguous()
+            hidden_states_p, B_p, C_p = split_hidden_states_B_C_fn(
+                hidden_states_B_C_p)
+
+            # 3. State Space Model sequence transformation
             initial_states = None
             if (mamba2_metadata.has_initial_states is not None
                     and mamba2_metadata.prep_initial_states):
                 # making a copy of the states
                 initial_states = torch.where(
-                    mamba2_metadata.has_initial_states[:num_prefills, None,
-                                                       None, None],
+                    mamba2_metadata.has_initial_states[:, None, None, None],
                     mamba_cache_params.ssm_state[state_indices_tensor_p], 0)
 
             scan_output, varlen_state = mamba_chunk_scan_combined(
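The `has_initial_states` indexing change drops the `[:num_prefills]` slice, presumably because the metadata now covers only the prefill requests, so a plain broadcast of the 1-D mask suffices. A small illustration of the masked-copy pattern used for `initial_states`, with toy dimensions (the 4-D state layout and all sizes here are invented for the example):

```python
import torch

# Invented dimensions for the cached SSM state.
num_prefills, nheads, head_dim, dstate = 4, 2, 3, 5
ssm_state = torch.randn(num_prefills, nheads, head_dim, dstate)
has_initial_states = torch.tensor([True, False, True, False])

# Broadcasting the 1-D mask to 4-D keeps the cached state for continued
# sequences and zeroes it for sequences that start fresh.
initial_states = torch.where(
    has_initial_states[:, None, None, None], ssm_state, 0)

assert torch.all(initial_states[1] == 0)
assert torch.equal(initial_states[0], ssm_state[0])
```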
@@ -535,10 +509,23 @@ def forward_cuda(
             mamba_cache_params.ssm_state[state_indices_tensor_p] = varlen_state
 
             # - reshape
-            hidden_states_list.append(scan_output.view(num_prefill_tokens, -1))
+            ssd_output_list.append(scan_output.view(num_prefill_tokens, -1))
 
         # Process decode requests
         if has_decode:
+            # 2. Convolution sequence transformation
+            hidden_states_B_C_d = causal_conv1d_update(
+                hidden_states_B_C_d,
+                mamba_cache_params.conv_state,
+                conv_weights,
+                self.conv1d.bias,
+                self.activation,
+                conv_state_indices=state_indices_tensor_d)
+
+            hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn(
+                hidden_states_B_C_d)
+
+            # 3. State Space Model sequence transformation
             n_groups = self.n_groups // self.tp_size
             A_d = self.A[:, None, ...][:, :, None].expand(
                 -1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
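The decode path gets the same restructuring: the single-token conv-state update now runs only on the decode slice, and the shared `split_hidden_states_B_C_fn` then separates the conv output from the B and C projections. A sketch of that three-way split with toy widths (the real ones are `intermediate_size // tp_size` and `groups_time_state_size // tp_size`):

```python
import torch

# Invented per-rank widths for illustration.
intermediate, groups_time_state = 16, 4
tokens = 5
fused = torch.randn(tokens, intermediate + 2 * groups_time_state)

# Same pattern as split_hidden_states_B_C_fn: the fused tensor carries the
# conv output plus the B and C projections side by side.
hidden_states, B, C = torch.split(
    fused, [intermediate, groups_time_state, groups_time_state], dim=-1)

assert hidden_states.shape == (tokens, intermediate)
assert B.shape == C.shape == (tokens, groups_time_state)
```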
@@ -567,12 +554,12 @@ def forward_cuda(
                 dt_softplus=True,
                 state_batch_indices=state_indices_tensor_d,
             )
-            hidden_states_list.append(
+            ssd_output_list.append(
                 hidden_states_d.view(-1, (self.num_heads // self.tp_size) *
                                      self.head_dim))
 
         # Merge prefill and decode outputs before passing to gated MLP
-        hidden_states = torch.vstack(hidden_states_list)
+        hidden_states = torch.vstack(ssd_output_list)
 
         # 4. gated MLP
         hidden_states = self.norm(hidden_states, gate)
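Both branches append into `ssd_output_list`, so a single `torch.vstack` restores the packed `[prefill tokens, decode tokens]` layout before the gated MLP. For example, with invented shapes:

```python
import torch

# Invented shapes; the two widths must match for vstack.
prefill_out = torch.randn(10, 8)  # num_prefill_tokens x hidden
decode_out = torch.randn(3, 8)    # num_decodes x hidden

# Stacking in [prefill, decode] order restores the packed token layout
# that the downstream gated MLP expects.
hidden_states = torch.vstack([prefill_out, decode_out])
assert hidden_states.shape == (13, 8)
```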