|
25 | 25 | from transformers import BatchFeature |
26 | 26 |
|
27 | 27 | from vllm.config import VllmConfig |
28 | | -from vllm.model_executor.layers.pooler import (AllPool, PoolerHead, |
29 | | - PoolerIdentity, SimplePooler) |
| 28 | +from vllm.model_executor.layers.pooler import DispatchPooler, Pooler |
30 | 29 | from vllm.model_executor.model_loader.weight_utils import default_weight_loader |
31 | 30 | from vllm.model_executor.models.interfaces import ( |
32 | | - IsAttentionFree, MultiModalEmbeddings, SupportsMultiModalWithRawInput) |
| 31 | + IsAttentionFree, MultiModalEmbeddings, SupportsMultiModalWithRawInput, |
| 32 | + default_pooling_type) |
33 | 33 | from vllm.model_executor.models.utils import AutoWeightsLoader |
34 | 34 | from vllm.multimodal import MULTIMODAL_REGISTRY |
35 | 35 | from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, |
@@ -142,6 +142,7 @@ def apply( |
142 | 142 | ) |
143 | 143 |
|
144 | 144 |
|
| 145 | +@default_pooling_type("All") |
145 | 146 | @MULTIMODAL_REGISTRY.register_processor( |
146 | 147 | PrithviGeoSpatialMAEMultiModalProcessor, |
147 | 148 | info=PrithviGeoSpatialMAEProcessingInfo, |
@@ -198,7 +199,11 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = ""): |
198 | 199 | "Only SemanticSegmentationTask is supported for now " |
199 | 200 | "by PrithviGeospatialMAE.") |
200 | 201 |
|
201 | | - self.pooler = SimplePooler(AllPool(), PoolerHead(PoolerIdentity())) |
| 202 | + pooler_config = vllm_config.model_config.pooler_config |
| 203 | + assert pooler_config is not None |
| 204 | + |
| 205 | + self.pooler = DispatchPooler( |
| 206 | + {"encode": Pooler.for_encode(pooler_config)}, ) |
202 | 207 |
|
203 | 208 | def _parse_and_validate_multimodal_data( |
204 | 209 | self, **kwargs) -> tuple[torch.Tensor, Optional[torch.Tensor]]: |
|
0 commit comments