From e3a6673748837a60d1812039cc7b5109979f1908 Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Tue, 10 Sep 2024 14:21:56 -0700 Subject: [PATCH] [Misc] remove peft as dependency for prompt models (#8162) --- vllm/config.py | 8 --- vllm/prompt_adapter/models.py | 2 +- vllm/prompt_adapter/utils.py | 93 +++++++++++++++++++++++++++++++++++ 3 files changed, 94 insertions(+), 9 deletions(-) create mode 100644 vllm/prompt_adapter/utils.py diff --git a/vllm/config.py b/vllm/config.py index 8f5e02e35f28d..9e7c107900aaf 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1558,14 +1558,6 @@ class PromptAdapterConfig: prompt_adapter_dtype: Optional[torch.dtype] = None def __post_init__(self): - library_name = 'peft' - try: - __import__(library_name) - except ImportError as e: - raise ImportError( - f"'{library_name}' is not installed for prompt adapter support." - f"Please install it using 'pip install {library_name}'." - ) from e if self.max_prompt_adapters < 1: raise ValueError(f"max_prompt_adapters " diff --git a/vllm/prompt_adapter/models.py b/vllm/prompt_adapter/models.py index 93eb3bde646ac..18a5f86c341a9 100644 --- a/vllm/prompt_adapter/models.py +++ b/vllm/prompt_adapter/models.py @@ -14,6 +14,7 @@ from vllm.prompt_adapter.layers import ( VocabParallelEmbeddingWithPromptAdapter) # yapf: disable from vllm.prompt_adapter.layers import PromptAdapterMapping +from vllm.prompt_adapter.utils import load_peft_weights logger = logging.getLogger(__name__) @@ -90,7 +91,6 @@ def from_local_checkpoint( config: PromptAdapterConfig, device: str = "cuda", ) -> "PromptAdapterModel": - from peft.utils import load_peft_weights if num_virtual_tokens > config.max_prompt_adapter_token: raise ValueError( diff --git a/vllm/prompt_adapter/utils.py b/vllm/prompt_adapter/utils.py new file mode 100644 index 0000000000000..989cc5a0f87c8 --- /dev/null +++ b/vllm/prompt_adapter/utils.py @@ -0,0 +1,93 @@ +# code borrowed from: https://github.com/huggingface/peft/blob/v0.12.0/src/peft/utils/save_and_load.py#L420 + +import os +from typing import Optional + +import torch +from huggingface_hub import file_exists, hf_hub_download +from huggingface_hub.utils import EntryNotFoundError +from safetensors.torch import load_file as safe_load_file + +WEIGHTS_NAME = "adapter_model.bin" +SAFETENSORS_WEIGHTS_NAME = "adapter_model.safetensors" + + +# Get current device name based on available devices +def infer_device() -> str: + if torch.cuda.is_available(): + return "cuda" + return "cpu" + + +def load_peft_weights(model_id: str, + device: Optional[str] = None, + **hf_hub_download_kwargs) -> dict: + r""" + A helper method to load the PEFT weights from the HuggingFace Hub or locally + + Args: + model_id (`str`): + The local path to the adapter weights or the name of the adapter to + load from the HuggingFace Hub. + device (`str`): + The device to load the weights onto. + hf_hub_download_kwargs (`dict`): + Additional arguments to pass to the `hf_hub_download` method when + loading from the HuggingFace Hub. + """ + path = (os.path.join(model_id, hf_hub_download_kwargs["subfolder"]) + if hf_hub_download_kwargs.get("subfolder", None) is not None else + model_id) + + if device is None: + device = infer_device() + + if os.path.exists(os.path.join(path, SAFETENSORS_WEIGHTS_NAME)): + filename = os.path.join(path, SAFETENSORS_WEIGHTS_NAME) + use_safetensors = True + elif os.path.exists(os.path.join(path, WEIGHTS_NAME)): + filename = os.path.join(path, WEIGHTS_NAME) + use_safetensors = False + else: + token = hf_hub_download_kwargs.get("token", None) + if token is None: + token = hf_hub_download_kwargs.get("use_auth_token", None) + + hub_filename = (os.path.join(hf_hub_download_kwargs["subfolder"], + SAFETENSORS_WEIGHTS_NAME) + if hf_hub_download_kwargs.get("subfolder", None) + is not None else SAFETENSORS_WEIGHTS_NAME) + has_remote_safetensors_file = file_exists( + repo_id=model_id, + filename=hub_filename, + revision=hf_hub_download_kwargs.get("revision", None), + repo_type=hf_hub_download_kwargs.get("repo_type", None), + token=token, + ) + use_safetensors = has_remote_safetensors_file + + if has_remote_safetensors_file: + # Priority 1: load safetensors weights + filename = hf_hub_download( + model_id, + SAFETENSORS_WEIGHTS_NAME, + **hf_hub_download_kwargs, + ) + else: + try: + filename = hf_hub_download(model_id, WEIGHTS_NAME, + **hf_hub_download_kwargs) + except EntryNotFoundError: + raise ValueError( # noqa: B904 + f"Can't find weights for {model_id} in {model_id} or \ + in the Hugging Face Hub. " + f"Please check that the file {WEIGHTS_NAME} or \ + {SAFETENSORS_WEIGHTS_NAME} is present at {model_id}.") + + if use_safetensors: + adapters_weights = safe_load_file(filename, device=device) + else: + adapters_weights = torch.load(filename, + map_location=torch.device(device)) + + return adapters_weights