Skip to content

Commit 571841b

Browse files
authored
[torch.compile] support encoder based models (#10613)
Signed-off-by: youkaichao <youkaichao@gmail.com>
1 parent 7ea3cd7 commit 571841b

File tree

2 files changed

+17
-10
lines changed

2 files changed

+17
-10
lines changed

tests/compile/test_basic_correctness.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -62,6 +62,16 @@ class TestSetting:
6262
method="encode",
6363
fullgraph=True,
6464
),
65+
# encoder-based embedding model (BERT)
66+
TestSetting(
67+
model="BAAI/bge-base-en-v1.5",
68+
model_args=["--task", "embedding"],
69+
pp_size=1,
70+
tp_size=1,
71+
attn_backend="XFORMERS",
72+
method="encode",
73+
fullgraph=True,
74+
),
6575
# vision language model
6676
TestSetting(
6777
model="microsoft/Phi-3.5-vision-instruct",

vllm/model_executor/models/bert.py

Lines changed: 7 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55
from transformers import BertConfig
66

77
from vllm.attention import Attention, AttentionMetadata, AttentionType
8+
from vllm.compilation.decorators import support_torch_compile
89
from vllm.config import CacheConfig, PoolerConfig, VllmConfig
910
from vllm.distributed import get_tensor_model_parallel_world_size
1011
from vllm.model_executor.layers.activation import get_act_fn
@@ -92,14 +93,14 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
9293
return pooled_output
9394

9495

96+
@support_torch_compile
9597
class BertEncoder(nn.Module):
9698

97-
def __init__(self,
98-
config: BertConfig,
99-
cache_config: Optional[CacheConfig] = None,
100-
quant_config: Optional[QuantizationConfig] = None,
101-
prefix: str = ""):
99+
def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
102100
super().__init__()
101+
config = vllm_config.model_config.hf_config
102+
cache_config = vllm_config.cache_config
103+
quant_config = vllm_config.quant_config
103104
self.layer = nn.ModuleList([
104105
BertLayer(config=config,
105106
cache_config=cache_config,
@@ -336,12 +337,8 @@ def __init__(self,
336337
add_pooling_layer: bool = False):
337338
super().__init__()
338339
config = vllm_config.model_config.hf_config
339-
cache_config = vllm_config.cache_config
340-
quant_config = vllm_config.quant_config
341340
self.embeddings = embedding_class(config)
342-
self.encoder = BertEncoder(config,
343-
cache_config,
344-
quant_config,
341+
self.encoder = BertEncoder(vllm_config=vllm_config,
345342
prefix=f"{prefix}.encoder")
346343
self.pooler = BertPooler(config) if add_pooling_layer else None
347344

0 commit comments

Comments
 (0)