1212from vllm .compilation .decorators import support_torch_compile
1313from vllm .config import CacheConfig , PoolerConfig , VllmConfig
1414from vllm .distributed import get_tensor_model_parallel_world_size
15- from vllm .forward_context import get_forward_context
1615from vllm .model_executor .layers .activation import get_act_fn
1716from vllm .model_executor .layers .linear import (ColumnParallelLinear ,
1817 QKVParallelLinear ,
@@ -60,7 +59,6 @@ def __init__(self, config: BertConfig):
6059 def forward (
6160 self ,
6261 input_ids : torch .Tensor ,
63- seq_lens : torch .Tensor ,
6462 position_ids : torch .Tensor ,
6563 token_type_ids : Optional [torch .Tensor ] = None ,
6664 ) -> torch .Tensor :
@@ -119,7 +117,6 @@ def forward(
119117 return pooled_output
120118
121119
122- @support_torch_compile
123120class BertEncoder (nn .Module ):
124121
125122 def __init__ (self , vllm_config : VllmConfig , prefix : str = "" ):
@@ -337,6 +334,7 @@ def forward(self, hidden_states: torch.Tensor,
337334 return hidden_states
338335
339336
337+ @support_torch_compile
340338class BertModel (nn .Module , SupportsQuant ):
341339
342340 is_pooling_model = True
@@ -368,13 +366,9 @@ def forward(
368366 if inputs_embeds is not None :
369367 hidden_states = inputs_embeds
370368 else :
371- attn_metadata = get_forward_context ().attn_metadata
372- assert hasattr (attn_metadata , "seq_lens_tensor" )
373- hidden_states = self .embeddings (
374- input_ids = input_ids ,
375- seq_lens = attn_metadata .seq_lens_tensor ,
376- position_ids = position_ids ,
377- token_type_ids = token_type_ids )
369+ hidden_states = self .embeddings (input_ids = input_ids ,
370+ position_ids = position_ids ,
371+ token_type_ids = token_type_ids )
378372 return self .encoder (hidden_states )
379373
380374 def _load_weights (self , weights : Iterable [tuple [str , torch .Tensor ]]):
@@ -447,7 +441,7 @@ def load_weights(self, weights: Iterable[tuple[str,
447441 return loaded_params
448442
449443
450- class BertEmbeddingModel (nn .Module , SupportsV0Only , SupportsQuant ):
444+ class BertEmbeddingModel (nn .Module , SupportsQuant ):
451445 """A model that uses Bert to provide embedding functionalities.
452446
453447 This class encapsulates the BertModel and provides an interface for
@@ -474,11 +468,13 @@ def forward(
474468 self ,
475469 input_ids : Optional [torch .Tensor ],
476470 positions : torch .Tensor ,
471+ token_type_ids : Optional [torch .Tensor ] = None ,
477472 intermediate_tensors : Optional [IntermediateTensors ] = None ,
478473 inputs_embeds : Optional [torch .Tensor ] = None ,
479474 ) -> torch .Tensor :
480475 return self .model (input_ids = input_ids ,
481476 position_ids = positions ,
477+ token_type_ids = token_type_ids ,
482478 inputs_embeds = inputs_embeds ,
483479 intermediate_tensors = intermediate_tensors )
484480
0 commit comments