@@ -297,8 +297,6 @@ def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         if is_hpu:
             # need reshape from tensor(x0, y0) to tensor(x1) for hpu
@@ -353,7 +351,7 @@ def forward(
             q = q.reshape(_batch_size, q.shape[0] // _batch_size, q.shape[1])
             k = k.reshape(_batch_size, k.shape[0] // _batch_size, k.shape[1])
             v = v.reshape(_batch_size, v.shape[0] // _batch_size, v.shape[1])
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        attn_output = self.attn(q, k, v)
         if is_hpu:
             # need restore from tensor(x0, y0, z0) to tensor(x1, y1) for hpu
             attn_output = attn_output.reshape(
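
Both removals above follow the same pattern: the attention layer no longer receives its KV cache and metadata as call arguments but resolves them itself from an ambient per-step context. Below is a minimal, self-contained sketch of that pattern; the names (StepContext, set_forward_context, get_forward_context) mirror vLLM's vllm.forward_context module in spirit but are written out here as hypothetical stand-ins, not the exact API.

from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Dict, Optional

@dataclass
class StepContext:
    attn_metadata: Any
    kv_caches: Dict[str, Any]  # layer name -> cache buffer

_CURRENT_CTX: Optional[StepContext] = None

@contextmanager
def set_forward_context(ctx: StepContext):
    # Published once by the model runner before each forward pass.
    global _CURRENT_CTX
    _CURRENT_CTX = ctx
    try:
        yield
    finally:
        _CURRENT_CTX = None

def get_forward_context() -> StepContext:
    ctx = _CURRENT_CTX
    assert ctx is not None, "forward() called outside a model step"
    return ctx

class Attention:
    def __init__(self, layer_name: str):
        self.layer_name = layer_name

    def __call__(self, q, k, v):
        ctx = get_forward_context()
        kv_cache = ctx.kv_caches[self.layer_name]  # was a forward() argument
        attn_metadata = ctx.attn_metadata          # was a forward() argument
        # ... an attention kernel would run against kv_cache here ...
        return q

# Usage: the runner wraps each step, and layers need only q/k/v.
attn = Attention("model.layers.0.self_attn")
step = StepContext(attn_metadata=None, kv_caches={"model.layers.0.self_attn": []})
with set_forward_context(step):
    out = attn(q=1.0, k=1.0, v=1.0)

This is why every signature in the diff shrinks the same way: threading kv_cache and attn_metadata through each module becomes unnecessary once they are reachable from the context.
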
@@ -500,8 +498,6 @@ def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         if self.q_lora_rank is not None:
             ckq = self.q_a_proj(hidden_states)[0]
@@ -511,8 +507,7 @@ def forward(
         kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
             [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
         kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
-        return self.mla_attn(hidden_states_or_q_c, kv_c_normed, k_pe, kv_cache,
-                             attn_metadata)
+        return self.mla_attn(hidden_states_or_q_c, kv_c_normed, k_pe, output_shape=hidden_states.shape)


 class DeepseekV2DecoderLayer(nn.Module):
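
The new output_shape keyword deserves a note: in MLA, the tensors handed to the attention wrapper carry compressed (latent) dimensions, so the output buffer's shape generally cannot be inferred from the inputs and is passed explicitly. A small illustration with assumed DeepSeek-V2-like sizes; the exact numbers here are for exposition only.

import torch

num_tokens = 8
hidden_size = 5120                            # assumed model width
hidden_states = torch.randn(num_tokens, hidden_size)

q_c = torch.randn(num_tokens, 1536)           # compressed query (q_lora_rank)
kv_c_normed = torch.randn(num_tokens, 512)    # compressed KV (kv_lora_rank)
k_pe = torch.randn(num_tokens, 64)            # decoupled RoPE key (qk_rope_head_dim)

# None of the three inputs mentions hidden_size, so the wrapper is told
# what to allocate for its output:
output_shape = hidden_states.shape            # torch.Size([8, 5120])
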
@@ -581,8 +576,6 @@ def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
     ) -> torch.Tensor:
         # Self Attention
@@ -595,8 +588,6 @@ def forward(
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-            kv_cache=kv_cache,
-            attn_metadata=attn_metadata,
         )

         # Fully Connected
@@ -657,8 +648,6 @@ def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
@@ -673,12 +662,8 @@ def forward(
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]

-        for i in range(self.start_layer, self.end_layer):
-            layer = self.layers[i]
-            kvcaches = None if kv_caches is None else kv_caches[i - self.start_layer]
-            hidden_states, residual = layer(positions, hidden_states,
-                                            kvcaches,
-                                            attn_metadata, residual)
+        for layer in self.layers[self.start_layer:self.end_layer]:
+            hidden_states, residual = layer(positions, hidden_states, residual)

         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
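
The rewritten loop is behavior-preserving: the old index arithmetic existed only to pair each of this pipeline rank's layers with its entry in the now-removed kv_caches list. A quick sanity check of the equivalence, using a hypothetical two-stage pipeline over eight layers:

layers = [f"layer_{i}" for i in range(8)]
start_layer, end_layer = 4, 8    # this rank owns the second half

# Old style: index into self.layers and a parallel kv_caches list.
old_order = [layers[i] for i in range(start_layer, end_layer)]

# New style: plain slice, since each layer now finds its own cache.
new_order = list(layers[start_layer:end_layer])

assert old_order == new_order == ["layer_4", "layer_5", "layer_6", "layer_7"]
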
@@ -715,13 +700,10 @@ def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, intermediate_tensors,
+        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                    inputs_embeds)
         return hidden_states

@@ -778,13 +760,9 @@ def load_weights(self, weights: Iterable[Tuple[str,
             if "rotary_emb.inv_freq" in name:
                 continue

-            # TODO(simon): support nextn predict layers
-            if hasattr(self.config, "num_nextn_predict_layers"
-                       ) and self.config.num_nextn_predict_layers > 0:
-                assert self.config.num_nextn_predict_layers == 1
-                layer_idx = self.config.num_hidden_layers
-                if name.startswith(f"model.layers.{layer_idx}"):
-                    continue
+            spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
+            if spec_layer is not None:
+                continue  # skip spec decode layers for main model

             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 # Skip non-stacked layers and experts (experts handled below).
@@ -860,3 +838,15 @@ def load_weights(self, weights: Iterable[Tuple[str,


 class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
     pass
+
+
+def get_spec_layer_idx_from_weight_name(config: PretrainedConfig,
+                                        weight_name: str) -> Optional[int]:
+    if hasattr(config,
+               "num_nextn_predict_layers") and (config.num_nextn_predict_layers
+                                                > 0):
+        layer_idx = config.num_hidden_layers
+        for i in range(config.num_nextn_predict_layers):
+            if weight_name.startswith(f"model.layers.{layer_idx + i}."):
+                return layer_idx + i
+    return None
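
To make the new helper's behavior concrete, here is a quick usage check with hypothetical config values (61 decoder layers plus one next-token-prediction layer, so the speculative weights live under model.layers.61.); any duck-typed object with the two attributes works in place of a real PretrainedConfig:

from types import SimpleNamespace

cfg = SimpleNamespace(num_hidden_layers=61, num_nextn_predict_layers=1)

# Weights of the spec-decode (MTP) layer report their index and get skipped.
assert get_spec_layer_idx_from_weight_name(
    cfg, "model.layers.61.embed_tokens.weight") == 61

# Ordinary main-model weights fall through to normal loading.
assert get_spec_layer_idx_from_weight_name(
    cfg, "model.layers.60.mlp.gate.weight") is None
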