
Fix Type Hints for vLLM CausalLM model #2408

Status: Open
Wants to merge 7 commits into base: main
2 changes: 1 addition & 1 deletion lm_eval/api/instance.py
@@ -29,7 +29,7 @@ def __post_init__(self) -> None:
self.task_name, self.doc_id, self.repeats = self.metadata

@property
def args(self):
def args(self) -> Tuple[str, ...]:
"""
Returns (string,) where `string` is the string to calculate loglikelihood over
"""
134 changes: 100 additions & 34 deletions lm_eval/models/vllm_causallms.py
@@ -1,11 +1,23 @@
import copy
from importlib.metadata import version
from importlib.util import find_spec
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Literal,
Optional,
Tuple,
Union,
overload,
)

import transformers
from more_itertools import distribute
from packaging.version import parse as parse_version
from tqdm import tqdm
from typing_extensions import TypedDict, Unpack

from lm_eval.api.instance import Instance
from lm_eval.api.model import TemplateLM
@@ -32,6 +44,34 @@
eval_logger = eval_logger


class OptionalModelArgs(TypedDict, total=False):
skip_tokenizer_init: bool
cpu_offload_gb: float
enforce_eager: Optional[bool]
max_context_len_to_capture: Optional[int]
max_seq_len_to_capture: int
disable_custom_all_reduce: bool
disable_async_output_proc: bool
mm_processor_kwargs: Optional[Dict[str, Any]]
worker_use_ray: bool


class ModelArgs(OptionalModelArgs):
model: str
tokenizer: Optional[str]
tokenizer_mode: Literal["auto", "slow"]
trust_remote_code: bool
tensor_parallel_size: int
dtype: Literal["float16", "bfloat16", "float32", "auto"]
gpu_memory_utilization: float
revision: Optional[str]
tokenizer_revision: Optional[str]
max_model_len: Optional[int]
swap_space: int
quantization: Optional[str]
seed: int


@register_model("vllm")
class VLLM(TemplateLM):
_DEFAULT_MAX_LENGTH = 2048
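
Note (not part of the diff): a minimal sketch of how a type checker treats the TypedDicts above. Keys declared directly on ModelArgs are required, while keys inherited from OptionalModelArgs (declared with total=False) may be omitted. The model name is a hypothetical value and the import assumes this branch is installed.

from lm_eval.models.vllm_causallms import ModelArgs

model_args: ModelArgs = {
    # Required keys, declared directly on ModelArgs (total=True by default).
    "model": "facebook/opt-125m",  # hypothetical model name
    "tokenizer": None,
    "tokenizer_mode": "auto",
    "trust_remote_code": False,
    "tensor_parallel_size": 1,
    "dtype": "auto",
    "gpu_memory_utilization": 0.9,
    "revision": None,
    "tokenizer_revision": None,
    "max_model_len": None,
    "swap_space": 4,
    "quantization": None,
    "seed": 1234,
    # Optional key inherited from OptionalModelArgs; it may be omitted entirely.
    "enforce_eager": True,
}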
@@ -41,26 +81,26 @@ def __init__(
pretrained: str,
dtype: Literal["float16", "bfloat16", "float32", "auto"] = "auto",
revision: Optional[str] = None,
trust_remote_code: Optional[bool] = False,
trust_remote_code: bool = False,
tokenizer: Optional[str] = None,
tokenizer_mode: Literal["auto", "slow"] = "auto",
tokenizer_revision: Optional[str] = None,
add_bos_token: Optional[bool] = False,
add_bos_token: bool = False,
prefix_token_id: Optional[int] = None,
tensor_parallel_size: int = 1,
quantization: Optional[str] = None,
max_gen_toks: int = 256,
swap_space: int = 4,
batch_size: Union[str, int] = 1,
max_batch_size=None,
max_length: int = None,
max_model_len: int = None,
max_batch_size: Optional[int] = None,
max_length: Optional[int] = None,
max_model_len: Optional[int] = None,
seed: int = 1234,
gpu_memory_utilization: float = 0.9,
device: str = "cuda",
data_parallel_size: int = 1,
lora_local_path: str = None,
**kwargs,
lora_local_path: Optional[str] = None,
**kwargs: Unpack[OptionalModelArgs],
):
super().__init__()
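
Note (not part of the diff): a minimal sketch of what typing **kwargs as Unpack[OptionalModelArgs] buys callers: extra keyword arguments forwarded to the vLLM engine are now checked by name and type. The model name is hypothetical and the snippet assumes this branch and vllm are installed.

from lm_eval.models.vllm_causallms import VLLM

lm = VLLM(
    pretrained="facebook/opt-125m",  # hypothetical model name
    enforce_eager=True,              # accepted: declared in OptionalModelArgs
    # enforce_eagre=True,            # a typo like this is now flagged by mypy/pyright
)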

@@ -75,10 +115,12 @@ def __init__(
max_length is None or max_model_len is None
), "Either max_length or max_model_len may be provided, but not both"

self._max_length = max_model_len if max_model_len is not None else max_length
self._max_length: Optional[int] = (
max_model_len if max_model_len is not None else max_length
)
self.tensor_parallel_size = int(tensor_parallel_size)
self.data_parallel_size = int(data_parallel_size)
self.model_args = {
self.model_args: ModelArgs = {
"model": pretrained,
"gpu_memory_utilization": float(gpu_memory_utilization),
"revision": revision,
@@ -114,7 +156,7 @@ def __init__(
self._config = AutoConfig.from_pretrained(
pretrained, trust_remote_code=trust_remote_code, revision=revision
)
self.tokenizer = get_tokenizer(
self.tokenizer: transformers.PreTrainedTokenizerBase = get_tokenizer(
tokenizer if tokenizer else pretrained,
tokenizer_mode=tokenizer_mode,
trust_remote_code=trust_remote_code,
@@ -136,6 +178,7 @@ def __init__(

self._max_gen_toks = max_gen_toks

self.lora_request: Optional[LoRARequest]
if lora_local_path is not None:
assert parse_version(version("vllm")) > parse_version(
"0.3.0"
@@ -191,16 +234,34 @@ def apply_chat_template(self, chat_history: List[Dict[str, str]]) -> str:
def tokenizer_name(self) -> str:
return self.tokenizer.name_or_path.replace("/", "__")

@overload # type: ignore[override]
def tok_encode(
self,
string: str,
left_truncate_len: Optional[int] = ...,
add_special_tokens: bool = ...,
truncation: bool = ...,
) -> List[int]: ...

@overload
def tok_encode(
self,
string: Union[str, List[str]],
left_truncate_len: int = None,
string: Union[List[str], Tuple[str, ...]],
left_truncate_len: Optional[int] = ...,
add_special_tokens: bool = ...,
truncation: bool = ...,
) -> List[List[int]]: ...

def tok_encode(
self,
string: Union[str, List[str], Tuple[str, ...]],
left_truncate_len: Optional[int] = None,
add_special_tokens: bool = False,
truncation: bool = False,
) -> Union[List[int], List[List[int]]]:
if not add_special_tokens:
add_special_tokens = False or self.add_bos_token
encoding: Union[List[List[int]], List[int]] = self.tokenizer(
add_special_tokens = self.add_bos_token or False
encoding = self.tokenizer(
string,
add_special_tokens=add_special_tokens,
truncation=truncation,
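
Note (not part of the diff): a minimal sketch of how the overloads above resolve for callers; the encode_prompts helper is hypothetical and assumes this branch is installed.

from typing import List

from lm_eval.models.vllm_causallms import VLLM


def encode_prompts(lm: VLLM) -> None:
    # A single string selects the first overload, so the result is List[int].
    single: List[int] = lm.tok_encode("The quick brown fox")
    # A list of strings selects the second overload, so the result is List[List[int]].
    batched: List[List[int]] = lm.tok_encode(["first prompt", "second prompt"])
    print(len(single), len(batched))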
@@ -218,9 +279,9 @@ def tok_encode(

def _model_generate(
self,
requests: List[List[int]] = None,
requests: List[List[int]],
generate: bool = False,
max_tokens: int = None,
max_tokens: Optional[int] = None,
stop: Optional[List[str]] = None,
**kwargs,
):
@@ -239,17 +300,21 @@ def _model_generate(
# but then tensor_parallel breaks
@ray.remote
def run_inference_one_model(
model_args: dict, sampling_params, requests: List[List[int]]
model_args: ModelArgs,
sampling_params: SamplingParams,
requests: List[List[int]],
):
llm = LLM(**model_args)
return llm.generate(
prompt_token_ids=requests, sampling_params=sampling_params
)
return llm.generate(requests, sampling_params=sampling_params)

# dispatch requests to all self.data_parallel_size workers, in interleaved fashion
# interleaved important to balance context lengths across workers
requests = [list(x) for x in distribute(self.data_parallel_size, requests)]
inputs = ((self.model_args, sampling_params, req) for req in requests)
distributed_requests = [
list(x) for x in distribute(self.data_parallel_size, requests)
]
inputs = (
(self.model_args, sampling_params, req) for req in distributed_requests
)
object_refs = [run_inference_one_model.remote(*x) for x in inputs]
results = ray.get(object_refs)
# Invoke ray.shutdown() to prevent hang-ups if subsequent calls required.
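
Note (not part of the diff): a minimal sketch of the interleaving that more_itertools.distribute performs, which is what balances context lengths across the data-parallel workers; the integers stand in for tokenized prompts.

from more_itertools import distribute

requests = list(range(10))  # stand-ins for tokenized prompts
shards = [list(shard) for shard in distribute(3, requests)]
print(shards)  # round-robin assignment: [[0, 3, 6, 9], [1, 4, 7], [2, 5, 8]]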
@@ -259,14 +324,14 @@ def run_inference_one_model(

if self.lora_request is not None:
outputs = self.model.generate(
prompt_token_ids=requests,
requests,
sampling_params=sampling_params,
use_tqdm=True if self.batch_size == "auto" else False,
lora_request=self.lora_request,
)
else:
outputs = self.model.generate(
prompt_token_ids=requests,
requests,
sampling_params=sampling_params,
use_tqdm=True if self.batch_size == "auto" else False,
)
@@ -298,13 +363,14 @@ def loglikelihood_rolling(
)

# discard is_greedy
string_nll = [x[0] for x in string_nll]
summed_loglikelihood = sum(x[0] for x in string_nll)

string_nll = sum(string_nll)
loglikelihoods.append(string_nll)
loglikelihoods.append(summed_loglikelihood)

# cache this loglikelihood_rolling request
self.cache_hook.add_partial("loglikelihood_rolling", (string,), string_nll)
self.cache_hook.add_partial(
"loglikelihood_rolling", (string,), summed_loglikelihood
)

return loglikelihoods
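
Note (not part of the diff): a minimal sketch of the summing change above; each per-window result is a (loglikelihood, is_greedy) pair, and the rolling score keeps only the summed loglikelihoods. The numbers are hypothetical.

window_results = [(-1.25, True), (-0.50, False), (-2.00, True)]  # hypothetical (loglikelihood, is_greedy) pairs
summed_loglikelihood = sum(ll for ll, _is_greedy in window_results)
assert summed_loglikelihood == -3.75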

@@ -316,9 +382,9 @@ def generate_until(
# batch tokenize contexts
context, all_gen_kwargs = zip(*(req.args for req in requests))
context_encoding: List[List[int]] = self.tok_encode(
context, add_special_tokens=self.add_bos_token
list(context), add_special_tokens=self.add_bos_token
)
requests = [
requests_grouped = [
((a, b), c) for a, b, c in zip(context, context_encoding, all_gen_kwargs)
]

@@ -334,13 +400,13 @@ def _collate_gen(_requests):
# we group requests by their generation_kwargs,
# so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
# in the same batch.
re_ords = Collator(requests, _collate_gen, group_by="gen_kwargs")
re_ords = Collator(requests_grouped, _collate_gen, group_by="gen_kwargs")
chunks = re_ords.get_batched(
n=int(self.batch_size) if self.batch_size != "auto" else 0, batch_fn=None
)

pbar = tqdm(
total=len(requests),
total=len(requests_grouped),
disable=(disable_tqdm or (self.rank != 0)),
desc="Running generate_until requests",
)
@@ -405,7 +471,7 @@ def _collate_gen(_requests):
# reorder all group of results back to original unsorted form
return re_ords.get_original(res)

def _loglikelihood_tokens(
def _loglikelihood_tokens( # type: ignore[override]
self,
requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
disable_tqdm: bool = False,