 from vllm.config import CompilationConfig, CompilationLevel
 from vllm.distributed import cleanup_dist_env_and_memory
 from vllm.forward_context import get_forward_context
-from vllm.model_executor.models.gemma3n import Gemma3nForConditionalGeneration
+from vllm.model_executor.models.gemma3n_mm import (
+    Gemma3nForConditionalGeneration)
 from vllm.model_executor.models.registry import ModelRegistry
 from vllm.model_executor.models.utils import extract_layer_index
 from vllm.sequence import IntermediateTensors
@@ -32,12 +33,13 @@ def forward(
         inputs_embeds: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = self.model(input_ids, positions, intermediate_tensors,
-                                   inputs_embeds, **kwargs)
+        hidden_states = super().forward(input_ids, positions,
+                                        intermediate_tensors, inputs_embeds,
+                                        **kwargs)
         attn_metadata = get_forward_context().attn_metadata
         # attn_metadata is None during dummy runs
         if (attn_metadata is not None
-                and self.cache_config.kv_sharing_fast_prefill):
+                and self.language_model.cache_config.kv_sharing_fast_prefill):
             assert isinstance(attn_metadata, dict)  # true in V1
             # Gemma3n-E2B has 30 layers, with last 20 layers being
             # cross-decoder layers. Check attention metadata is correct
@@ -52,7 +54,7 @@ def forward(

             # Last layer will be a KV sharing layer
             layer_attn_metadata = attn_metadata[
-                self.model.language_model.layers[-1].self_attn.attn.layer_name]
+                self.language_model.model.layers[-1].self_attn.attn.layer_name]
             logits_indices_padded = (layer_attn_metadata.logits_indices_padded)
             assert logits_indices_padded is not None
             num_logits_indices = layer_attn_metadata.num_logits_indices