# See the License for the specific language governing permissions and
# limitations under the License.
1919
20- from typing import Optional
20+ from typing import List , Optional
2121
2222import torch
2323import torch .nn as nn
2424from transformers import PretrainedConfig
25+ from vllm .attention .backends .abstract import AttentionMetadata
2526from vllm .config import CacheConfig , ModelConfig , VllmConfig
2627from vllm .model_executor .layers .layernorm import RMSNorm
2728from vllm .model_executor .layers .logits_processor import LogitsProcessor
3435 SharedHead )
3536from vllm .model_executor .models .utils import maybe_prefix
3637from vllm .model_executor .sampling_metadata import SamplingMetadata
38+ from vllm .sequence import IntermediateTensors
3739
3840from .deepseek_v2 import CustomDeepseekV2DecoderLayer
3941
@@ -69,6 +71,8 @@ def forward(
6971 self ,
7072 input_ids : torch .Tensor ,
7173 positions : torch .Tensor ,
74+ kv_cache : torch .Tensor ,
75+ attn_metadata : AttentionMetadata ,
7276 previous_hidden_states : torch .Tensor ,
7377 inputs_embeds : Optional [torch .Tensor ] = None ,
7478 spec_step_index : int = 0 ,
@@ -88,6 +92,8 @@ def forward(
8892
8993 hidden_states , residual = self .mtp_block (positions = positions ,
9094 hidden_states = hidden_states ,
95+ kv_cache = kv_cache ,
96+ attn_metadata = attn_metadata ,
9197 residual = None )
9298 hidden_states = residual + hidden_states
9399 return hidden_states
@@ -125,14 +131,20 @@ def forward(
125131 self ,
126132 input_ids : torch .Tensor ,
127133 positions : torch .Tensor ,
134+ kv_caches : torch .Tensor ,
135+ attn_metadata : AttentionMetadata ,
128136 previous_hidden_states : torch .Tensor ,
129137 inputs_embeds : Optional [torch .Tensor ] = None ,
130138 spec_step_idx : int = 0 ,
131139 ) -> torch .Tensor :
132140 current_step_idx = (spec_step_idx % self .num_mtp_layers )
141+ step_kv_cache = kv_caches [
142+ current_step_idx ] if kv_caches is not None else None
133143 return self .layers_list [current_step_idx ](
134144 input_ids ,
135145 positions ,
146+ step_kv_cache ,
147+ attn_metadata ,
136148 previous_hidden_states ,
137149 inputs_embeds ,
138150 current_step_idx ,
@@ -170,3 +182,19 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
170182 prefix , "model" ))
171183
172184 self .sampler = get_sampler ()
185+
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    kv_caches: Optional[List[torch.Tensor]] = None,
    attn_metadata: Optional["AttentionMetadata"] = None,
    previous_hidden_states: Optional[torch.Tensor] = None,
    intermediate_tensors: Optional["IntermediateTensors"] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    spec_step_idx: int = 0,
) -> torch.Tensor:
    """Run the wrapped ``self.model`` and return its output unchanged.

    Args:
        input_ids: Forwarded to ``self.model`` as the first argument.
        positions: Forwarded to ``self.model`` as the second argument.
        kv_caches: Optional list of KV-cache tensors, forwarded as-is
            (``None`` is passed through when omitted).
        attn_metadata: Optional attention metadata, forwarded as-is.
        previous_hidden_states: Optional hidden states, forwarded as-is.
            NOTE(review): presumably the target model's hidden states
            from the previous step -- confirm against the caller.
        intermediate_tensors: Accepted but not used by this method; it is
            NOT forwarded to ``self.model``.
        inputs_embeds: Optional precomputed embeddings, forwarded as-is.
        spec_step_idx: Step index forwarded as the last argument.

    Returns:
        Whatever ``self.model(...)`` returns.
    """
    # Arguments are passed positionally, in the order the wrapped model's
    # forward declares them; ``intermediate_tensors`` is deliberately
    # left out of the call.
    return self.model(
        input_ids,
        positions,
        kv_caches,
        attn_metadata,
        previous_hidden_states,
        inputs_embeds,
        spec_step_idx,
    )
0 commit comments