@@ -5,6 +5,7 @@
 from transformers import BertConfig
 
 from vllm.attention import Attention, AttentionMetadata, AttentionType
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, PoolerConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
@@ -92,14 +93,14 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return pooled_output
 
 
+@support_torch_compile
 class BertEncoder(nn.Module):
 
-    def __init__(self,
-                 config: BertConfig,
-                 cache_config: Optional[CacheConfig] = None,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
+        config = vllm_config.model_config.hf_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
         self.layer = nn.ModuleList([
             BertLayer(config=config,
                       cache_config=cache_config,
@@ -336,12 +337,8 @@ def __init__(self,
                  add_pooling_layer: bool = False):
         super().__init__()
         config = vllm_config.model_config.hf_config
-        cache_config = vllm_config.cache_config
-        quant_config = vllm_config.quant_config
         self.embeddings = embedding_class(config)
-        self.encoder = BertEncoder(config,
-                                   cache_config,
-                                   quant_config,
+        self.encoder = BertEncoder(vllm_config=vllm_config,
                                    prefix=f"{prefix}.encoder")
         self.pooler = BertPooler(config) if add_pooling_layer else None
 
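For context on the constructor change: vLLM's support_torch_compile decorator (imported in the first hunk) wraps the module so its forward pass can run under torch.compile, and it relies on the module exposing vLLM's standard (vllm_config, prefix) constructor signature, which is why BertEncoder.__init__ is refactored above. A minimal sketch of the resulting pattern, assuming only what the diff shows; ToyEncoder and its single linear layer are illustrative, not part of this commit:

# Minimal sketch of the constructor pattern this commit adopts.
# ToyEncoder is hypothetical; the vLLM imports mirror the diff above.
import torch
from torch import nn

from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig


@support_torch_compile
class ToyEncoder(nn.Module):

    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        # Sub-configs are unpacked from the single VllmConfig inside
        # __init__, exactly as BertEncoder now does above.
        config = vllm_config.model_config.hf_config
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.dense(hidden_states)

The second hunk's change to BertModel follows from this: since BertEncoder now unpacks cache_config and quant_config itself, the caller passes the whole vllm_config through instead of forwarding each sub-config individually.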