Commit 6c85da3

[V1] SupportsV0Only protocol for model definitions (#13959)
Signed-off-by: Roger Wang <ywang@roblox.com>
1 parent 67fc426 commit 6c85da3

19 files changed (+93, -32 lines)

vllm/config.py

Lines changed: 5 additions & 0 deletions
@@ -1039,6 +1039,11 @@ def supported_runner_types(self) -> Set[RunnerType]:
     def runner_type(self) -> RunnerType:
         return _TASK_RUNNER[self.task]

+    @property
+    def is_v1_compatible(self) -> bool:
+        architectures = getattr(self.hf_config, "architectures", [])
+        return ModelRegistry.is_v1_compatible(architectures)
+

 class CacheConfig:
     """Configuration for the KV cache.

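The new property boils engine compatibility down to a single boolean: it reads the architecture names from the model's Hugging Face config and asks the model registry whether all of them have V1-capable implementations. Below is a minimal sketch of how engine setup code could consult it; choose_engine_version is a hypothetical helper for illustration, not part of this commit:

    # Illustrative sketch only; `choose_engine_version` is hypothetical, not vLLM API.
    from vllm.config import ModelConfig


    def choose_engine_version(model_config: ModelConfig) -> str:
        # is_v1_compatible checks hf_config.architectures against ModelRegistry.
        if model_config.is_v1_compatible:
            return "v1"
        # Architectures tagged SupportsV0Only stay on the V0 engine.
        return "v0"
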
vllm/model_executor/models/__init__.py

Lines changed: 5 additions & 2 deletions
@@ -1,8 +1,9 @@
 # SPDX-License-Identifier: Apache-2.0

 from .interfaces import (HasInnerState, SupportsLoRA, SupportsMultiModal,
-                         SupportsPP, has_inner_state, supports_lora,
-                         supports_multimodal, supports_pp)
+                         SupportsPP, SupportsV0Only, has_inner_state,
+                         supports_lora, supports_multimodal, supports_pp,
+                         supports_v0_only)
 from .interfaces_base import (VllmModelForPooling, VllmModelForTextGeneration,
                               is_pooling_model, is_text_generation_model)
 from .registry import ModelRegistry
@@ -21,4 +22,6 @@
     "supports_multimodal",
     "SupportsPP",
     "supports_pp",
+    "SupportsV0Only",
+    "supports_v0_only",
 ]
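
With the re-exports above, downstream code can import both the marker protocol and its checker from the models package root instead of reaching into the interfaces module:

    # Both names are now part of the package's public API (__all__).
    from vllm.model_executor.models import SupportsV0Only, supports_v0_only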

vllm/model_executor/models/bamba.py

Lines changed: 3 additions & 2 deletions
@@ -32,7 +32,8 @@
 from vllm.sequence import IntermediateTensors
 from vllm.utils import LayerBlockType

-from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP
+from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP,
+                         SupportsV0Only)
 from .utils import (is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
@@ -366,7 +367,7 @@ def forward(


 class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
-                       IsHybrid):
+                       IsHybrid, SupportsV0Only):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",

vllm/model_executor/models/bart.py

Lines changed: 2 additions & 1 deletion
@@ -43,6 +43,7 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

+from .interfaces import SupportsV0Only
 from .utils import maybe_prefix

 logger = logging.get_logger(__name__)
@@ -776,7 +777,7 @@ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
         return decoder_outputs


-class BartForConditionalGeneration(nn.Module):
+class BartForConditionalGeneration(nn.Module, SupportsV0Only):
     base_model_prefix = "model"

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

vllm/model_executor/models/bert.py

Lines changed: 2 additions & 2 deletions
@@ -26,7 +26,7 @@
 from vllm.transformers_utils.config import (
     get_cross_encoder_activation_function)

-from .interfaces import SupportsCrossEncoding
+from .interfaces import SupportsCrossEncoding, SupportsV0Only
 from .utils import WeightsMapper, maybe_prefix


@@ -385,7 +385,7 @@ def load_weights(self, weights: Iterable[Tuple[str,
         return loaded_params


-class BertEmbeddingModel(nn.Module):
+class BertEmbeddingModel(nn.Module, SupportsV0Only):
     """A model that uses Bert to provide embedding functionalities.

     This class encapsulates the BertModel and provides an interface for

vllm/model_executor/models/florence2.py

Lines changed: 2 additions & 2 deletions
@@ -29,7 +29,7 @@
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors

-from .interfaces import SupportsMultiModal
+from .interfaces import SupportsMultiModal, SupportsV0Only
 from .utils import AutoWeightsLoader, flatten_bn, merge_multimodal_embeddings


@@ -651,7 +651,7 @@ def forward(
         return decoder_outputs


-class Florence2LanguageForConditionalGeneration(nn.Module):
+class Florence2LanguageForConditionalGeneration(nn.Module, SupportsV0Only):

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()

vllm/model_executor/models/gritlm.py

Lines changed: 3 additions & 1 deletion
@@ -19,6 +19,8 @@
     PoolingSequenceGroupOutput)
 from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config

+from .interfaces import SupportsV0Only
+
 logger = init_logger(__name__)


@@ -177,7 +179,7 @@ def forward(
         return PoolerOutput(outputs=pooled_outputs)


-class GritLM(LlamaForCausalLM):
+class GritLM(LlamaForCausalLM, SupportsV0Only):
     """This class implements the embedding model for parasail-ai/GritLM-7B-vllm.

     The class inherits from LlamaForCausalLM and provides a custom pooling

vllm/model_executor/models/interfaces.py

Lines changed: 26 additions & 0 deletions
@@ -498,3 +498,29 @@ def supports_transcription(
         return isinstance(model, SupportsTranscription)

     return isinstance(model, SupportsTranscription)
+
+
+@runtime_checkable
+class SupportsV0Only(Protocol):
+    """Models with this interface are not compatible with V1 vLLM."""
+
+    supports_v0_only: ClassVar[Literal[True]] = True
+
+
+@overload
+def supports_v0_only(model: Type[object]) -> TypeIs[Type[SupportsV0Only]]:
+    ...
+
+
+@overload
+def supports_v0_only(model: object) -> TypeIs[SupportsV0Only]:
+    ...
+
+
+def supports_v0_only(
+    model: Union[Type[object], object],
+) -> Union[TypeIs[Type[SupportsV0Only]], TypeIs[SupportsV0Only]]:
+    if isinstance(model, type):
+        return isinstance(model, SupportsV0Only)
+
+    return isinstance(model, SupportsV0Only)

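Because SupportsV0Only is a runtime-checkable Protocol and supports_v0_only is overloaded for both classes and instances, the checker can be applied to either. A small sketch with hypothetical toy classes:

    # Toy classes (hypothetical) showing how the checker behaves at runtime.
    from vllm.model_executor.models.interfaces import SupportsV0Only, supports_v0_only


    class LegacyOnlyModel(SupportsV0Only):  # opts in by inheriting the protocol
        pass


    class RegularModel:  # no marker, so treated as V1-compatible
        pass


    assert supports_v0_only(LegacyOnlyModel)    # works on the class itself
    assert supports_v0_only(LegacyOnlyModel())  # and on an instance
    assert not supports_v0_only(RegularModel)
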
vllm/model_executor/models/jamba.py

Lines changed: 3 additions & 2 deletions
@@ -30,7 +30,8 @@
 from vllm.sequence import IntermediateTensors, PoolerOutput
 from vllm.utils import LayerBlockType

-from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP
+from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP,
+                         SupportsV0Only)
 from .utils import (is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
@@ -353,7 +354,7 @@ def forward(


 class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
-                       IsHybrid):
+                       IsHybrid, SupportsV0Only):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",

vllm/model_executor/models/mamba.py

Lines changed: 4 additions & 2 deletions
@@ -19,7 +19,8 @@
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.interfaces import (HasInnerState,
-                                                   IsAttentionFree, SupportsPP)
+                                                   IsAttentionFree, SupportsPP,
+                                                   SupportsV0Only)
 from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
                                                     MambaCacheParams)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -155,7 +156,8 @@ def forward(
         return hidden_states


-class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
+class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP,
+                       SupportsV0Only):

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         config = vllm_config.model_config.hf_config

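The model-level edits above all follow one pattern: a class pins itself to the V0 engine simply by listing SupportsV0Only among its bases, which the is_v1_compatible check in vllm/config.py can then surface. A hypothetical new model definition would opt out the same way:

    # Hypothetical example of the opt-out pattern used by the models above.
    import torch.nn as nn

    from vllm.model_executor.models.interfaces import SupportsV0Only


    class MyLegacyModelForCausalLM(nn.Module, SupportsV0Only):
        """Listing SupportsV0Only keeps this architecture on the V0 engine."""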