[Misc] remove peft as dependency for prompt models (vllm-project#8162)
prashantgupta24 authored and siddharth9820 committed Sep 30, 2024
1 parent 3728549 commit e3a6673
Showing 3 changed files with 94 additions and 9 deletions.
vllm/config.py (8 changes: 0 additions & 8 deletions)

@@ -1558,14 +1558,6 @@ class PromptAdapterConfig:
     prompt_adapter_dtype: Optional[torch.dtype] = None

     def __post_init__(self):
-        library_name = 'peft'
-        try:
-            __import__(library_name)
-        except ImportError as e:
-            raise ImportError(
-                f"'{library_name}' is not installed for prompt adapter support."
-                f"Please install it using 'pip install {library_name}'."
-            ) from e

         if self.max_prompt_adapters < 1:
             raise ValueError(f"max_prompt_adapters "
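With the eager availability check removed from __post_init__, PromptAdapterConfig can now be constructed on systems where peft is not installed. A minimal sketch of the resulting behavior (the field values are hypothetical; only max_prompt_adapters and max_prompt_adapter_token are assumed from the surrounding code):

    from vllm.config import PromptAdapterConfig

    # Before this commit, __post_init__ raised ImportError whenever peft was
    # missing; after it, construction only validates the numeric limits.
    config = PromptAdapterConfig(max_prompt_adapters=4,
                                 max_prompt_adapter_token=64)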
vllm/prompt_adapter/models.py (2 changes: 1 addition & 1 deletion)

@@ -14,6 +14,7 @@
 from vllm.prompt_adapter.layers import (
     VocabParallelEmbeddingWithPromptAdapter)  # yapf: disable
 from vllm.prompt_adapter.layers import PromptAdapterMapping
+from vllm.prompt_adapter.utils import load_peft_weights

 logger = logging.getLogger(__name__)

@@ -90,7 +91,6 @@ def from_local_checkpoint(
         config: PromptAdapterConfig,
         device: str = "cuda",
     ) -> "PromptAdapterModel":
-        from peft.utils import load_peft_weights

         if num_virtual_tokens > config.max_prompt_adapter_token:
             raise ValueError(
vllm/prompt_adapter/utils.py (93 changes: 93 additions & 0 deletions)

@@ -0,0 +1,93 @@
+# code borrowed from: https://github.com/huggingface/peft/blob/v0.12.0/src/peft/utils/save_and_load.py#L420
+
+import os
+from typing import Optional
+
+import torch
+from huggingface_hub import file_exists, hf_hub_download
+from huggingface_hub.utils import EntryNotFoundError
+from safetensors.torch import load_file as safe_load_file
+
+WEIGHTS_NAME = "adapter_model.bin"
+SAFETENSORS_WEIGHTS_NAME = "adapter_model.safetensors"
+
+
+# Get current device name based on available devices
+def infer_device() -> str:
+    if torch.cuda.is_available():
+        return "cuda"
+    return "cpu"
+
+
+def load_peft_weights(model_id: str,
+                      device: Optional[str] = None,
+                      **hf_hub_download_kwargs) -> dict:
+    r"""
+    A helper method to load the PEFT weights from the HuggingFace Hub or locally
+
+    Args:
+        model_id (`str`):
+            The local path to the adapter weights or the name of the adapter to
+            load from the HuggingFace Hub.
+        device (`str`):
+            The device to load the weights onto.
+        hf_hub_download_kwargs (`dict`):
+            Additional arguments to pass to the `hf_hub_download` method when
+            loading from the HuggingFace Hub.
+    """
+    path = (os.path.join(model_id, hf_hub_download_kwargs["subfolder"])
+            if hf_hub_download_kwargs.get("subfolder", None) is not None else
+            model_id)
+
+    if device is None:
+        device = infer_device()
+
+    if os.path.exists(os.path.join(path, SAFETENSORS_WEIGHTS_NAME)):
+        filename = os.path.join(path, SAFETENSORS_WEIGHTS_NAME)
+        use_safetensors = True
+    elif os.path.exists(os.path.join(path, WEIGHTS_NAME)):
+        filename = os.path.join(path, WEIGHTS_NAME)
+        use_safetensors = False
+    else:
+        token = hf_hub_download_kwargs.get("token", None)
+        if token is None:
+            token = hf_hub_download_kwargs.get("use_auth_token", None)
+
+        hub_filename = (os.path.join(hf_hub_download_kwargs["subfolder"],
+                                     SAFETENSORS_WEIGHTS_NAME)
+                        if hf_hub_download_kwargs.get("subfolder", None)
+                        is not None else SAFETENSORS_WEIGHTS_NAME)
+        has_remote_safetensors_file = file_exists(
+            repo_id=model_id,
+            filename=hub_filename,
+            revision=hf_hub_download_kwargs.get("revision", None),
+            repo_type=hf_hub_download_kwargs.get("repo_type", None),
+            token=token,
+        )
+        use_safetensors = has_remote_safetensors_file
+
+        if has_remote_safetensors_file:
+            # Priority 1: load safetensors weights
+            filename = hf_hub_download(
+                model_id,
+                SAFETENSORS_WEIGHTS_NAME,
+                **hf_hub_download_kwargs,
+            )
+        else:
+            try:
+                filename = hf_hub_download(model_id, WEIGHTS_NAME,
+                                           **hf_hub_download_kwargs)
+            except EntryNotFoundError:
+                raise ValueError(  # noqa: B904
+                    f"Can't find weights for {model_id} in {model_id} or \
+                    in the Hugging Face Hub. "
+                    f"Please check that the file {WEIGHTS_NAME} or \
+                    {SAFETENSORS_WEIGHTS_NAME} is present at {model_id}.")
+
+    if use_safetensors:
+        adapters_weights = safe_load_file(filename, device=device)
+    else:
+        adapters_weights = torch.load(filename,
+                                      map_location=torch.device(device))
+
+    return adapters_weights
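Since the vendored helper preserves peft's lookup order (local safetensors, then local .bin, then the Hub, again preferring safetensors), existing adapter checkpoints resolve exactly as before, just without importing peft. A short usage sketch (the adapter directory path is hypothetical):

    from vllm.prompt_adapter.utils import load_peft_weights

    # Hypothetical local directory holding adapter_model.safetensors or
    # adapter_model.bin; safetensors is preferred when both are present.
    weights = load_peft_weights("/path/to/prompt-adapter", device="cpu")
    for name, tensor in weights.items():
        print(name, tuple(tensor.shape))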
