@@ -17,11 +17,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional
+from typing import List, Optional
 
 import torch
 import torch.nn as nn
from transformers import PretrainedConfig
+from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.config import CacheConfig, ModelConfig, VllmConfig
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -34,6 +35,7 @@
     SharedHead)
 from vllm.model_executor.models.utils import maybe_prefix
 from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.sequence import IntermediateTensors
 
 from .deepseek_v2 import CustomDeepseekV2DecoderLayer
 
@@ -69,6 +71,8 @@ def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
+        kv_cache: Optional[torch.Tensor],
+        attn_metadata: AttentionMetadata,
         previous_hidden_states: torch.Tensor,
         inputs_embeds: Optional[torch.Tensor] = None,
         spec_step_index: int = 0,
@@ -88,6 +92,8 @@ def forward(
 
         hidden_states, residual = self.mtp_block(positions=positions,
                                                  hidden_states=hidden_states,
+                                                 kv_cache=kv_cache,
+                                                 attn_metadata=attn_metadata,
                                                  residual=None)
         hidden_states = residual + hidden_states
         return hidden_states
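
Note: the call above assumes CustomDeepseekV2DecoderLayer.forward accepts the two new keyword arguments. A minimal sketch of the implied layer-side signature (parameter order illustrative; the real layer may take further arguments):

    # Sketch only, not the actual vllm-ascend implementation: the layer's
    # attention consumes kv_cache and attn_metadata for this decode step.
    def forward(self, positions, hidden_states, kv_cache, attn_metadata,
                residual):
        ...
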
@@ -125,14 +131,20 @@ def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
+        kv_caches: Optional[List[torch.Tensor]],
+        attn_metadata: AttentionMetadata,
         previous_hidden_states: torch.Tensor,
         inputs_embeds: Optional[torch.Tensor] = None,
         spec_step_idx: int = 0,
     ) -> torch.Tensor:
         current_step_idx = (spec_step_idx % self.num_mtp_layers)
+        step_kv_cache = kv_caches[
+            current_step_idx] if kv_caches is not None else None
         return self.layers_list[current_step_idx](
             input_ids,
             positions,
+            step_kv_cache,
+            attn_metadata,
             previous_hidden_states,
             inputs_embeds,
             current_step_idx,
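
For illustration, the modulo above cycles speculative steps through the available MTP layers, so each step is served by that layer's own KV cache. A runnable walk-through with hypothetical values (not taken from the diff):

    # Hypothetical: assume an MTP depth of 2 layers.
    num_mtp_layers = 2
    for spec_step_idx in range(4):
        current_step_idx = spec_step_idx % num_mtp_layers
        # Steps 0, 1, 2, 3 map to layers 0, 1, 0, 1; the selected layer
        # receives kv_caches[current_step_idx], or None when the runner
        # supplies no caches.
        print(spec_step_idx, "->", current_step_idx)
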
@@ -170,3 +182,19 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                                                           prefix, "model"))
 
         self.sampler = get_sampler()
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: Optional[List[torch.Tensor]] = None,
+        attn_metadata: Optional[AttentionMetadata] = None,
+        previous_hidden_states: Optional[torch.Tensor] = None,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        spec_step_idx: int = 0,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, kv_caches,
+                                   attn_metadata, previous_hidden_states,
+                                   inputs_embeds, spec_step_idx)
+        return hidden_states
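
A minimal sketch of how a runner might invoke the new entry point. Every name below (draft_model, caches, metadata, prev_hidden, and the tensor inputs) is hypothetical; only the parameter names and order come from the forward() added above:

    # Hypothetical call site: `caches` would be a per-MTP-layer list of
    # KV-cache tensors and `metadata` the AttentionMetadata the runner
    # built for this speculative step.
    hidden_states = draft_model(
        input_ids=input_ids,
        positions=positions,
        kv_caches=caches,
        attn_metadata=metadata,
        previous_hidden_states=prev_hidden,
        spec_step_idx=0,
    )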