@@ -55,8 +55,9 @@
 from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
 
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal
-from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
-                    maybe_prefix, merge_multimodal_embeddings)
+from .llama4 import Llama4ForCausalLM
+from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
+                    merge_multimodal_embeddings)
 from .vision import scatter_patch_features, select_patch_features
 
 logger = init_logger(__name__)
@@ -710,12 +711,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             self.config,
             None,
             prefix=maybe_prefix(prefix, "multi_modal_projector"))
-        self.language_model = init_vllm_registered_model(
-            vllm_config=vllm_config,
-            hf_config=config.text_config,
-            architectures=["Llama4ForCausalLM"],
-            prefix=maybe_prefix(prefix, "language_model"))
-
+        language_model_vllm_config = vllm_config.with_hf_config(
+            config.text_config, architectures=["Llama4ForCausalLM"])
+        self.language_model = Llama4ForCausalLM(
+            vllm_config=language_model_vllm_config,
+            prefix=maybe_prefix(prefix, "language_model"),
+        )
         self.tokenizer = cached_tokenizer_from_config(vllm_config.model_config)
 
     def _parse_and_validate_image_input(
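
For context, here is a standalone sketch of the pattern the hunk above adopts: instead of resolving the text backbone through init_vllm_registered_model, the parent model derives a config scoped to the text sub-config and constructs the backbone class directly. The ToyConfig, ToyTextModel, and ToyMultiModalModel names below are made-up stand-ins, not vLLM classes.

# Illustrative sketch only -- the classes here are hypothetical stand-ins.
# It mirrors the change above: derive a child config pointed at the text
# sub-config, then construct the text backbone directly rather than going
# through a model registry lookup.
from dataclasses import dataclass, replace
from typing import Any, Dict, Tuple


@dataclass(frozen=True)
class ToyConfig:
    hf_config: Dict[str, Any]
    architectures: Tuple[str, ...]

    def with_hf_config(self, hf_config: Dict[str, Any],
                       architectures: Tuple[str, ...]) -> "ToyConfig":
        # Analogue of the with_hf_config call in the hunk: same wrapper
        # config, re-pointed at the text sub-config with a fixed
        # architecture list.
        return replace(self, hf_config=hf_config, architectures=architectures)


class ToyTextModel:
    def __init__(self, *, config: ToyConfig, prefix: str = "") -> None:
        self.config = config
        self.prefix = prefix


class ToyMultiModalModel:
    def __init__(self, *, config: ToyConfig, prefix: str = "") -> None:
        text_config = config.hf_config["text_config"]
        text_model_config = config.with_hf_config(
            text_config, architectures=("ToyTextModel",))
        # Direct construction, mirroring Llama4ForCausalLM(...) above.
        self.language_model = ToyTextModel(
            config=text_model_config,
            prefix=f"{prefix}language_model")


if __name__ == "__main__":
    cfg = ToyConfig(hf_config={"text_config": {"hidden_size": 16}},
                    architectures=("ToyMultiModalModel",))
    model = ToyMultiModalModel(config=cfg)
    print(model.language_model.config.architectures)  # ('ToyTextModel',)
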
@@ -857,9 +858,8 @@ def load_weights(self, weights: Iterable[Tuple[str,
 
         # language_model is a Llama4ForCausalLM instance. We load its weights
         # using llama4's load_weights routine.
-        language_model_prefix = "language_model.model."
         language_model_weights, other_weights = self.separate_weights(
-            weights, prefix=language_model_prefix)
+            weights, prefix="language_model.model.")
         loader = AutoWeightsLoader(self)
         loaded_language_model_params = loader.load_weights(
             language_model_weights)
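
A hedged sketch of the prefix split used in the last hunk: weight names beginning with "language_model.model." are routed to the Llama4 backbone's loader, while everything else (vision encoder, projector, and so on) stays with the outer multimodal model. The separate_weights function and the weight names below are simplified assumptions, not vLLM's actual helper or checkpoint layout.

# Sketch only: an assumed, simplified stand-in for self.separate_weights(...)
# from the hunk above. It partitions (name, tensor) pairs by name prefix so
# the "language_model.model." weights can be handed to the Llama4 loader.
from typing import Iterable, List, Tuple


def separate_weights(
    weights: Iterable[Tuple[str, object]],
    prefix: str,
) -> Tuple[List[Tuple[str, object]], List[Tuple[str, object]]]:
    matched: List[Tuple[str, object]] = []
    other: List[Tuple[str, object]] = []
    for name, tensor in weights:
        (matched if name.startswith(prefix) else other).append((name, tensor))
    return matched, other


if __name__ == "__main__":
    dummy = [
        ("language_model.model.layers.0.mlp.gate_proj.weight", None),
        ("vision_model.patch_embedding.weight", None),
        ("multi_modal_projector.linear_1.weight", None),
    ]
    lm_weights, other_weights = separate_weights(
        dummy, prefix="language_model.model.")
    print(len(lm_weights), len(other_weights))  # 1 2
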