|
7 | 7 |
|
8 | 8 | import torch |
9 | 9 | import torch.nn.functional as F |
| 10 | +from typing_extensions import assert_never |
10 | 11 |
|
11 | 12 | from vllm.config import VllmConfig |
12 | 13 | from vllm.logger import init_logger |
13 | | -from vllm.model_executor.layers.pooler import (HAS_TRITON, Pooler, PoolingType, |
| 14 | +from vllm.model_executor.layers.pooler import (HAS_TRITON, Pooler, PoolingTask, |
| 15 | + PoolingType, |
14 | 16 | extract_vision_tokens_kernel) |
15 | 17 | # yapf: disable |
16 | 18 | from vllm.model_executor.pooling_metadata import ( |
17 | 19 | PoolingMetadata as V0PoolingMetadata) |
18 | 20 | from vllm.model_executor.pooling_metadata import PoolingTensors |
19 | 21 | # yapf: enable |
20 | 22 | from vllm.multimodal import MULTIMODAL_REGISTRY |
| 23 | +from vllm.pooling_params import PoolingParams |
21 | 24 | from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput |
22 | 25 | from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata |
23 | 26 |
|
|
36 | 39 | PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata] |
37 | 40 |
|
38 | 41 |
|
39 | | -@MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor, |
40 | | - info=Qwen2VLProcessingInfo, |
41 | | - dummy_inputs=Qwen2VLDummyInputsBuilder) |
42 | | -class JinaVLForEmbedding(Qwen2VLForConditionalGeneration, |
43 | | - SupportsCrossEncoding, SupportsMultiModal): |
44 | | - # Weight mapping for HuggingFace checkpoint compatibility |
45 | | - weight_mapper = WeightsMapper( |
46 | | - orig_to_new_prefix={ |
47 | | - "model.": "language_model.model.", |
48 | | - "visual.": "visual.", |
49 | | - "lm_head.": "language_model.lm_head.", |
50 | | - }) |
51 | | - |
52 | | - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): |
53 | | - super().__init__(vllm_config=vllm_config, |
54 | | - prefix=maybe_prefix(prefix, "qwen2_vl")) |
| 42 | +class JinaVLPooler(Pooler): |
| 43 | + """Vision-aware pooler for Jina V4 with special vision token handling.""" |
55 | 44 |
|
| 45 | + def __init__(self, |
| 46 | + vllm_config: VllmConfig, |
| 47 | + pooling_backend: str = "pytorch"): |
| 48 | + super().__init__() |
56 | 49 | self.hidden_size = vllm_config.model_config.hf_config.hidden_size |
57 | | - pooler_config = vllm_config.model_config.pooler_config |
| 50 | + self.pooling_backend = pooling_backend |
58 | 51 | self.observability_config = vllm_config.observability_config |
59 | 52 |
|
60 | | - # Configuration for vision pooling backend |
61 | | - self.pooling_backend = getattr(vllm_config.model_config, |
62 | | - "jina_pooling_backend", "pytorch") |
63 | | - if self.pooling_backend not in ("triton", "pytorch"): |
64 | | - logger.warning( |
65 | | - "Invalid jina_pooling_backend '%s'. " |
66 | | - "Must be 'triton' or 'pytorch'. Defaulting to 'pytorch'.", |
67 | | - self.pooling_backend) |
68 | | - self.pooling_backend = "pytorch" |
| 53 | + # Performance tracking |
| 54 | + self._pooling_time_ms = 0.0 |
| 55 | + self._pooling_count = 0 |
69 | 56 |
|
70 | 57 | # Initialize base pooler for fallback |
| 58 | + pooler_config = vllm_config.model_config.pooler_config |
71 | 59 | self._base_pooler = Pooler.from_config_with_defaults( |
72 | 60 | pooler_config, |
73 | 61 | pooling_type=PoolingType.MEAN, |
74 | 62 | normalize=True, |
75 | 63 | softmax=False) |
76 | 64 |
|
77 | | - # Performance tracking |
78 | | - self._pooling_time_ms = 0.0 |
79 | | - self._pooling_count = 0 |
| 65 | + def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]: |
| 66 | + """Return pooling params for embedding task.""" |
| 67 | + if task == "embed": |
| 68 | + return PoolingParams() |
80 | 69 |
|
81 | | - logger.info("Initialized JinaVLForEmbedding with thread-safe pooling") |
| 70 | + # The equalities are split up to keep mypy happy |
| 71 | + if task == "encode" or task == "classify" or task == "score": |
| 72 | + return None |
| 73 | + |
| 74 | + assert_never(task) |
| 75 | + |
| 76 | + def forward( |
| 77 | + self, |
| 78 | + hidden_states: Union[torch.Tensor, list[torch.Tensor]], |
| 79 | + pooling_metadata: PoolingMetadata, |
| 80 | + ) -> PoolerOutput: |
| 81 | + """Apply vision-aware pooling to hidden states.""" |
| 82 | + start_time = time.time() if self.observability_config else None |
| 83 | + |
| 84 | + # Validate inputs |
| 85 | + if hidden_states is None or hidden_states.numel() == 0: |
| 86 | + logger.warning("Empty hidden states received") |
| 87 | + return PoolerOutput(outputs=[]) |
| 88 | + |
| 89 | + # Extract token IDs safely from metadata |
| 90 | + token_ids_list, seq_ids = self._extract_token_ids_safe( |
| 91 | + pooling_metadata) |
| 92 | + |
| 93 | + if not token_ids_list: |
| 94 | + logger.warning("No valid sequences found for pooling") |
| 95 | + # Fallback to base pooler |
| 96 | + return self._base_pooler(hidden_states, pooling_metadata) |
| 97 | + |
| 98 | + # Get prompt lengths based on metadata type |
| 99 | + if isinstance(pooling_metadata, V1PoolingMetadata): |
| 100 | + prompt_lens = pooling_metadata.prompt_lens |
| 101 | + else: |
| 102 | + prompt_lens = PoolingTensors.from_pooling_metadata( |
| 103 | + pooling_metadata, hidden_states.device).prompt_lens |
| 104 | + |
| 105 | + # Validate lengths match |
| 106 | + assert len(token_ids_list) == len(prompt_lens), ( |
| 107 | + f"Mismatch: {len(token_ids_list)} sequences vs " |
| 108 | + f"{len(prompt_lens)} lengths") |
| 109 | + |
| 110 | + # Apply pooling based on configured backend |
| 111 | + if self.pooling_backend == "triton": |
| 112 | + pooled_data = self._apply_vision_pooling_optimized( |
| 113 | + hidden_states, token_ids_list, prompt_lens) |
| 114 | + else: # self.pooling_backend == "pytorch" |
| 115 | + pooled_data = self._apply_vision_pooling_pytorch( |
| 116 | + hidden_states, token_ids_list, prompt_lens) |
| 117 | + |
| 118 | + # Build output |
| 119 | + pooled_outputs = [ |
| 120 | + PoolingSequenceGroupOutput(data) for data in pooled_data |
| 121 | + ] |
| 122 | + |
| 123 | + # Record metrics |
| 124 | + if self.observability_config: |
| 125 | + elapsed_ms = (time.time() - start_time) * 1000 |
| 126 | + self._pooling_time_ms += elapsed_ms |
| 127 | + self._pooling_count += 1 |
| 128 | + |
| 129 | + if self._pooling_count % 100 == 0: |
| 130 | + avg_time = self._pooling_time_ms / self._pooling_count |
| 131 | + logger.debug("Average pooling time: %.2fms", avg_time) |
| 132 | + |
| 133 | + return PoolerOutput(outputs=pooled_outputs) |
82 | 134 |
|
83 | 135 | def _extract_token_ids_safe( |
84 | 136 | self, pooling_metadata: PoolingMetadata |
@@ -239,64 +291,41 @@ def _apply_vision_pooling_pytorch( |
239 | 291 |
|
240 | 292 | return pooled_outputs |
241 | 293 |
|
242 | | - def pooler( |
243 | | - self, |
244 | | - hidden_states: torch.Tensor, |
245 | | - pooling_metadata: PoolingMetadata, |
246 | | - ) -> Optional[PoolerOutput]: |
247 | | - """Thread-safe pooler with production error handling.""" |
248 | | - start_time = time.time() if self.observability_config else None |
249 | | - |
250 | | - # Validate inputs |
251 | | - if hidden_states is None or hidden_states.numel() == 0: |
252 | | - logger.warning("Empty hidden states received") |
253 | | - return PoolerOutput(outputs=[]) |
254 | | - |
255 | | - # Extract token IDs safely from metadata |
256 | | - token_ids_list, seq_ids = self._extract_token_ids_safe( |
257 | | - pooling_metadata) |
258 | 294 |
|
259 | | - if not token_ids_list: |
260 | | - logger.warning("No valid sequences found for pooling") |
261 | | - # Fallback to base pooler |
262 | | - return self._base_pooler(hidden_states, pooling_metadata) |
263 | | - |
264 | | - # Get prompt lengths based on metadata type |
265 | | - if isinstance(pooling_metadata, V1PoolingMetadata): |
266 | | - prompt_lens = pooling_metadata.prompt_lens |
267 | | - else: |
268 | | - prompt_lens = PoolingTensors.from_pooling_metadata( |
269 | | - pooling_metadata, hidden_states.device).prompt_lens |
| 295 | +@MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor, |
| 296 | + info=Qwen2VLProcessingInfo, |
| 297 | + dummy_inputs=Qwen2VLDummyInputsBuilder) |
| 298 | +class JinaVLForEmbedding(Qwen2VLForConditionalGeneration, |
| 299 | + SupportsCrossEncoding, SupportsMultiModal): |
270 | 300 |
|
271 | | - # Validate lengths match |
272 | | - assert len(token_ids_list) == len(prompt_lens), ( |
273 | | - f"Mismatch: {len(token_ids_list)} sequences vs " |
274 | | - f"{len(prompt_lens)} lengths") |
| 301 | + is_pooling_model = True |
275 | 302 |
|
276 | | - # Apply pooling based on configured backend |
277 | | - if self.pooling_backend == "triton": |
278 | | - pooled_data = self._apply_vision_pooling_optimized( |
279 | | - hidden_states, token_ids_list, prompt_lens) |
280 | | - else: # self.pooling_backend == "pytorch" |
281 | | - pooled_data = self._apply_vision_pooling_pytorch( |
282 | | - hidden_states, token_ids_list, prompt_lens) |
| 303 | + # Weight mapping for HuggingFace checkpoint compatibility |
| 304 | + weight_mapper = WeightsMapper( |
| 305 | + orig_to_new_prefix={ |
| 306 | + "model.": "language_model.model.", |
| 307 | + "visual.": "visual.", |
| 308 | + "lm_head.": "language_model.lm_head.", |
| 309 | + }) |
283 | 310 |
|
284 | | - # Build output |
285 | | - pooled_outputs = [ |
286 | | - PoolingSequenceGroupOutput(data) for data in pooled_data |
287 | | - ] |
| 311 | + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): |
| 312 | + super().__init__(vllm_config=vllm_config, |
| 313 | + prefix=maybe_prefix(prefix, "qwen2_vl")) |
288 | 314 |
|
289 | | - # Record metrics |
290 | | - if self.observability_config: |
291 | | - elapsed_ms = (time.time() - start_time) * 1000 |
292 | | - self._pooling_time_ms += elapsed_ms |
293 | | - self._pooling_count += 1 |
| 315 | + # Configuration for vision pooling backend |
| 316 | + self.pooling_backend = getattr(vllm_config.model_config, |
| 317 | + "jina_pooling_backend", "pytorch") |
| 318 | + if self.pooling_backend not in ("triton", "pytorch"): |
| 319 | + logger.warning( |
| 320 | + "Invalid jina_pooling_backend '%s'. " |
| 321 | + "Must be 'triton' or 'pytorch'. Defaulting to 'pytorch'.", |
| 322 | + self.pooling_backend) |
| 323 | + self.pooling_backend = "pytorch" |
294 | 324 |
|
295 | | - if self._pooling_count % 100 == 0: |
296 | | - avg_time = self._pooling_time_ms / self._pooling_count |
297 | | - logger.debug("Average pooling time: %.2fms", avg_time) |
| 325 | + # Initialize the vision-aware pooler |
| 326 | + self.pooler = JinaVLPooler(vllm_config, self.pooling_backend) |
298 | 327 |
|
299 | | - return PoolerOutput(outputs=pooled_outputs) |
| 328 | + logger.info("Initialized JinaVLForEmbedding with thread-safe pooling") |
300 | 329 |
|
301 | 330 | def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): |
302 | 331 | """Load weights with validation and error handling.""" |
|
0 commit comments