From f5640cdde31a43581d084c741e889ccb04b4fd4e Mon Sep 17 00:00:00 2001
From: qthequartermasterman
Date: Thu, 10 Oct 2024 14:00:45 -0500
Subject: [PATCH 1/7] fix argument type hints (remove implicit optional types)

---
 lm_eval/models/vllm_causallms.py | 29 ++++++++++++++++-------------
 1 file changed, 16 insertions(+), 13 deletions(-)

diff --git a/lm_eval/models/vllm_causallms.py b/lm_eval/models/vllm_causallms.py
index 168f490a7b..f0efb925ca 100644
--- a/lm_eval/models/vllm_causallms.py
+++ b/lm_eval/models/vllm_causallms.py
@@ -3,6 +3,7 @@
 from importlib.util import find_spec
 from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
 
+import transformers
 from more_itertools import distribute
 from packaging.version import parse as parse_version
 from tqdm import tqdm
@@ -41,25 +42,25 @@ def __init__(
         pretrained: str,
         dtype: Literal["float16", "bfloat16", "float32", "auto"] = "auto",
         revision: Optional[str] = None,
-        trust_remote_code: Optional[bool] = False,
+        trust_remote_code: bool = False,
         tokenizer: Optional[str] = None,
         tokenizer_mode: Literal["auto", "slow"] = "auto",
         tokenizer_revision: Optional[str] = None,
-        add_bos_token: Optional[bool] = False,
+        add_bos_token: bool = False,
         prefix_token_id: Optional[int] = None,
         tensor_parallel_size: int = 1,
         quantization: Optional[str] = None,
         max_gen_toks: int = 256,
         swap_space: int = 4,
         batch_size: Union[str, int] = 1,
-        max_batch_size=None,
-        max_length: int = None,
-        max_model_len: int = None,
+        max_batch_size: Optional[int] = None,
+        max_length: Optional[int] = None,
+        max_model_len: Optional[int] = None,
         seed: int = 1234,
         gpu_memory_utilization: float = 0.9,
         device: str = "cuda",
         data_parallel_size: int = 1,
-        lora_local_path: str = None,
+        lora_local_path: Optional[str] = None,
         **kwargs,
     ):
         super().__init__()
@@ -75,7 +76,9 @@ def __init__(
             max_length is None or max_model_len is None
         ), "Either max_length or max_model_len may be provided, but not both"
 
-        self._max_length = max_model_len if max_model_len is not None else max_length
+        self._max_length: Optional[int] = (
+            max_model_len if max_model_len is not None else max_length
+        )
         self.tensor_parallel_size = int(tensor_parallel_size)
         self.data_parallel_size = int(data_parallel_size)
         self.model_args = {
@@ -114,7 +117,7 @@ def __init__(
         self._config = AutoConfig.from_pretrained(
             pretrained, trust_remote_code=trust_remote_code, revision=revision
         )
-        self.tokenizer = get_tokenizer(
+        self.tokenizer: transformers.PreTrainedTokenizerBase = get_tokenizer(
             tokenizer if tokenizer else pretrained,
             tokenizer_mode=tokenizer_mode,
             trust_remote_code=trust_remote_code,
@@ -136,6 +139,7 @@ def __init__(
 
         self._max_gen_toks = max_gen_toks
 
+        self.lora_request: Optional[LoRARequest]
         if lora_local_path is not None:
             assert parse_version(version("vllm")) > parse_version(
                 "0.3.0"
@@ -194,7 +198,7 @@ def tokenizer_name(self) -> str:
     def tok_encode(
         self,
         string: Union[str, List[str]],
-        left_truncate_len: int = None,
+        left_truncate_len: Optional[int] = None,
         add_special_tokens: bool = False,
         truncation: bool = False,
     ) -> Union[List[int], List[List[int]]]:
@@ -218,9 +222,9 @@ def tok_encode(
 
     def _model_generate(
         self,
-        requests: List[List[int]] = None,
+        requests: Optional[List[List[int]]] = None,
         generate: bool = False,
-        max_tokens: int = None,
+        max_tokens: Optional[int] = None,
         stop: Optional[List[str]] = None,
         **kwargs,
     ):
@@ -298,9 +302,8 @@ def loglikelihood_rolling(
             )
 
             # discard is_greedy
-            string_nll = [x[0] for x in string_nll]
+            string_nll = sum(x[0] for x in string_nll)
 
-            string_nll = sum(string_nll)
             loglikelihoods.append(string_nll)
 
             # cache this loglikelihood_rolling request
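
PATCH 1/7 replaces implicit-optional defaults such as `max_length: int = None` with explicit `Optional[...]` annotations, the form PEP 484 tooling expects (recent mypy releases reject the implicit style by default). A minimal standalone sketch of the pattern; the function and parameters below are hypothetical illustrations, not part of lm_eval:

```python
from typing import Optional


def resolve_max_length(
    max_length: Optional[int] = None,  # explicit Optional instead of `int = None`
    max_model_len: Optional[int] = None,
) -> int:
    """Prefer max_model_len, then max_length, then a default."""
    default = 2048
    chosen = max_model_len if max_model_len is not None else max_length
    return chosen if chosen is not None else default


print(resolve_max_length())                 # 2048
print(resolve_max_length(max_length=4096))  # 4096
```
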
From e9a7f3789415fa314c4b3d771042100627849bf6 Mon Sep 17 00:00:00 2001
From: qthequartermasterman
Date: Thu, 10 Oct 2024 15:32:54 -0500
Subject: [PATCH 2/7] fix type hints for _model_generate and tok_encode

---
 lm_eval/models/vllm_causallms.py | 79 ++++++++++++++++++++++++++++----
 1 file changed, 71 insertions(+), 8 deletions(-)

diff --git a/lm_eval/models/vllm_causallms.py b/lm_eval/models/vllm_causallms.py
index f0efb925ca..243ef1e3de 100644
--- a/lm_eval/models/vllm_causallms.py
+++ b/lm_eval/models/vllm_causallms.py
@@ -1,12 +1,23 @@
 import copy
 from importlib.metadata import version
 from importlib.util import find_spec
-from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    List,
+    Literal,
+    Optional,
+    Tuple,
+    Union,
+    overload,
+)
 
 import transformers
 from more_itertools import distribute
 from packaging.version import parse as parse_version
 from tqdm import tqdm
+from typing_extensions import TypedDict, Unpack
 
 from lm_eval.api.instance import Instance
 from lm_eval.api.model import TemplateLM
@@ -33,6 +44,34 @@
 eval_logger = eval_logger
 
 
+class OptionalModelArgs(TypedDict, total=False):
+    skip_tokenizer_init: bool
+    cpu_offload_gb: float
+    enforce_eager: Optional[bool]
+    max_context_len_to_capture: Optional[int]
+    max_seq_len_to_capture: int
+    disable_custom_all_reduce: bool
+    disable_async_output_proc: bool
+    mm_processor_kwargs: Optional[Dict[str, Any]]
+    worker_use_ray: bool
+
+
+class ModelArgs(OptionalModelArgs):
+    model: str
+    tokenizer: Optional[str]
+    tokenizer_mode: Literal["auto", "slow"]
+    trust_remote_code: bool
+    tensor_parallel_size: int
+    dtype: Literal["float16", "bfloat16", "float32", "auto"]
+    gpu_memory_utilization: float
+    revision: Optional[str]
+    tokenizer_revision: Optional[str]
+    max_model_len: Optional[int]
+    swap_space: int
+    quantization: Optional[str]
+    seed: int
+
+
 @register_model("vllm")
 class VLLM(TemplateLM):
     _DEFAULT_MAX_LENGTH = 2048
@@ -61,7 +100,7 @@ def __init__(
         device: str = "cuda",
         data_parallel_size: int = 1,
         lora_local_path: Optional[str] = None,
-        **kwargs,
+        **kwargs: Unpack[OptionalModelArgs],
     ):
         super().__init__()
 
@@ -81,7 +120,7 @@ def __init__(
         )
         self.tensor_parallel_size = int(tensor_parallel_size)
         self.data_parallel_size = int(data_parallel_size)
-        self.model_args = {
+        self.model_args: ModelArgs = {
             "model": pretrained,
             "gpu_memory_utilization": float(gpu_memory_utilization),
             "revision": revision,
@@ -195,6 +234,24 @@ def apply_chat_template(self, chat_history: List[Dict[str, str]]) -> str:
     def tokenizer_name(self) -> str:
         return self.tokenizer.name_or_path.replace("/", "__")
 
+    @overload  # type: ignore[override]
+    def tok_encode(
+        self,
+        string: str,
+        left_truncate_len: Optional[int] = ...,
+        add_special_tokens: bool = ...,
+        truncation: bool = ...,
+    ) -> List[int]: ...
+
+    @overload
+    def tok_encode(
+        self,
+        string: List[str],
+        left_truncate_len: Optional[int],
+        add_special_tokens: bool,
+        truncation: bool = ...,
+    ) -> List[List[int]]: ...
+
     def tok_encode(
         self,
         string: Union[str, List[str]],
@@ -204,7 +261,7 @@ def tok_encode(
     ) -> Union[List[int], List[List[int]]]:
         if not add_special_tokens:
             add_special_tokens = False or self.add_bos_token
-        encoding: Union[List[List[int]], List[int]] = self.tokenizer(
+        encoding = self.tokenizer(
             string,
             add_special_tokens=add_special_tokens,
             truncation=truncation,
@@ -222,7 +279,7 @@ def tok_encode(
 
     def _model_generate(
         self,
-        requests: Optional[List[List[int]]] = None,
+        requests: List[List[int]],
         generate: bool = False,
         max_tokens: Optional[int] = None,
         stop: Optional[List[str]] = None,
@@ -243,7 +300,9 @@ def _model_generate(
             # but then tensor_parallel breaks
             @ray.remote
             def run_inference_one_model(
-                model_args: dict, sampling_params, requests: List[List[int]]
+                model_args: ModelArgs,
+                sampling_params: SamplingParams,
+                requests: List[List[int]],
             ):
                 llm = LLM(**model_args)
                 return llm.generate(
@@ -252,8 +311,12 @@ def run_inference_one_model(
 
             # dispatch requests to all self.data_parallel_size workers, in interleaved fashion
             # interleaved important to balance context lengths across workers
-            requests = [list(x) for x in distribute(self.data_parallel_size, requests)]
-            inputs = ((self.model_args, sampling_params, req) for req in requests)
+            distributed_requests = [
+                list(x) for x in distribute(self.data_parallel_size, requests)
+            ]
+            inputs = (
+                (self.model_args, sampling_params, req) for req in distributed_requests
+            )
             object_refs = [run_inference_one_model.remote(*x) for x in inputs]
             results = ray.get(object_refs)
             # Invoke ray.shutdown() to prevent hang-ups if subsequent calls required.
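
PATCH 2/7 describes the vLLM engine keyword arguments with `TypedDict` classes and types `**kwargs` via `typing_extensions.Unpack` (PEP 692), so stray or mistyped engine options are caught by the type checker rather than at engine construction. A small sketch of the same pattern; `EngineOptions` and `build_engine_args` are hypothetical names, not the classes added by the patch:

```python
from typing_extensions import TypedDict, Unpack


class EngineOptions(TypedDict, total=False):
    # total=False makes every key optional, like OptionalModelArgs in the patch.
    enforce_eager: bool
    swap_space: int


def build_engine_args(model: str, **kwargs: Unpack[EngineOptions]) -> dict:
    """Merge required and optional options; checkers flag unknown or mistyped keys."""
    return {"model": model, **kwargs}


print(build_engine_args("some-model", enforce_eager=True, swap_space=4))
# build_engine_args("some-model", swap_space="4")  # rejected by mypy/pyright
```
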
From a2d96ea92bb2ceac22b3ce851f4ea8a9edc09f77 Mon Sep 17 00:00:00 2001
From: qthequartermasterman
Date: Thu, 10 Oct 2024 15:35:10 -0500
Subject: [PATCH 3/7] fix type hints for loglikelihood_rolling

---
 lm_eval/models/vllm_causallms.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/lm_eval/models/vllm_causallms.py b/lm_eval/models/vllm_causallms.py
index 243ef1e3de..edfa43a6e6 100644
--- a/lm_eval/models/vllm_causallms.py
+++ b/lm_eval/models/vllm_causallms.py
@@ -365,12 +365,14 @@ def loglikelihood_rolling(
             )
 
             # discard is_greedy
-            string_nll = sum(x[0] for x in string_nll)
+            summed_loglikelihood = sum(x[0] for x in string_nll)
 
-            loglikelihoods.append(string_nll)
+            loglikelihoods.append(summed_loglikelihood)
 
             # cache this loglikelihood_rolling request
-            self.cache_hook.add_partial("loglikelihood_rolling", (string,), string_nll)
+            self.cache_hook.add_partial(
+                "loglikelihood_rolling", (string,), summed_loglikelihood
+            )
 
         return loglikelihoods
 

From 363ab6c5d1562da044349feb60e2cd8f81b072e0 Mon Sep 17 00:00:00 2001
From: qthequartermasterman
Date: Thu, 10 Oct 2024 20:43:49 -0500
Subject: [PATCH 4/7] fix type hints for generate_until, tok_encode

---
 lm_eval/api/instance.py          |  2 +-
 lm_eval/models/vllm_causallms.py | 12 ++++++------
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/lm_eval/api/instance.py b/lm_eval/api/instance.py
index d3c6afa064..d3633f7023 100644
--- a/lm_eval/api/instance.py
+++ b/lm_eval/api/instance.py
@@ -29,7 +29,7 @@ def __post_init__(self) -> None:
         self.task_name, self.doc_id, self.repeats = self.metadata
 
     @property
-    def args(self):
+    def args(self) -> Tuple[str, ...]:
         """
         Returns (string,) where `string` is the string to calculate loglikelihood over
         """
diff --git a/lm_eval/models/vllm_causallms.py b/lm_eval/models/vllm_causallms.py
index edfa43a6e6..3c82a1043c 100644
--- a/lm_eval/models/vllm_causallms.py
+++ b/lm_eval/models/vllm_causallms.py
@@ -246,15 +246,15 @@ def tok_encode(
     @overload
     def tok_encode(
         self,
-        string: List[str],
-        left_truncate_len: Optional[int],
-        add_special_tokens: bool,
+        string: Union[List[str], Tuple[str, ...]],
+        left_truncate_len: Optional[int] = ...,
+        add_special_tokens: bool = ...,
         truncation: bool = ...,
     ) -> List[List[int]]: ...
 
     def tok_encode(
         self,
-        string: Union[str, List[str]],
+        string: Union[str, List[str], Tuple[str, ...]],
         left_truncate_len: Optional[int] = None,
         add_special_tokens: bool = False,
         truncation: bool = False,
@@ -384,7 +384,7 @@ def generate_until(
         # batch tokenize contexts
         context, all_gen_kwargs = zip(*(req.args for req in requests))
         context_encoding: List[List[int]] = self.tok_encode(
-            context, add_special_tokens=self.add_bos_token
+            list(context), add_special_tokens=self.add_bos_token
         )
         requests = [
             ((a, b), c) for a, b, c in zip(context, context_encoding, all_gen_kwargs)
@@ -473,7 +473,7 @@ def _collate_gen(_requests):
         # reorder all group of results back to original unsorted form
         return re_ords.get_original(res)
 
-    def _loglikelihood_tokens(
+    def _loglikelihood_tokens(  # type: ignore[override]
         self,
         requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
         disable_tqdm: bool = False,
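
PATCH 2/7 adds `typing.overload` signatures to `tok_encode`, and PATCH 4/7 widens them to accept tuples, since `Instance.args` is now annotated `Tuple[str, ...]` and `zip(*...)` in `generate_until` produces tuples. A standalone sketch of that overload pattern with a toy character-code tokenizer; nothing below is lm_eval code:

```python
from typing import List, Tuple, Union, overload


@overload
def encode(text: str) -> List[int]: ...
@overload
def encode(text: Union[List[str], Tuple[str, ...]]) -> List[List[int]]: ...


def encode(
    text: Union[str, List[str], Tuple[str, ...]],
) -> Union[List[int], List[List[int]]]:
    """Toy tokenizer: one string gives one id list, a batch gives a list of lists."""
    if isinstance(text, str):
        return [ord(ch) for ch in text]
    return [[ord(ch) for ch in item] for item in text]


single: List[int] = encode("hi")               # checker infers List[int]
batch: List[List[int]] = encode(("hi", "yo"))  # checker infers List[List[int]]
print(single, batch)
```
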
From 35901f3f570ebb151f145572c5032903e8e69710 Mon Sep 17 00:00:00 2001
From: qthequartermasterman
Date: Thu, 10 Oct 2024 20:46:12 -0500
Subject: [PATCH 5/7] fix type hints for generate_until, tok_encode

---
 lm_eval/models/vllm_causallms.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/lm_eval/models/vllm_causallms.py b/lm_eval/models/vllm_causallms.py
index 3c82a1043c..9754e5b891 100644
--- a/lm_eval/models/vllm_causallms.py
+++ b/lm_eval/models/vllm_causallms.py
@@ -386,7 +386,7 @@ def generate_until(
         context_encoding: List[List[int]] = self.tok_encode(
             list(context), add_special_tokens=self.add_bos_token
         )
-        requests = [
+        requests_grouped = [
             ((a, b), c) for a, b, c in zip(context, context_encoding, all_gen_kwargs)
         ]
 
@@ -402,13 +402,13 @@ def _collate_gen(_requests):
         # we group requests by their generation_kwargs,
         # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
         # in the same batch.
-        re_ords = Collator(requests, _collate_gen, group_by="gen_kwargs")
+        re_ords = Collator(requests_grouped, _collate_gen, group_by="gen_kwargs")
         chunks = re_ords.get_batched(
             n=int(self.batch_size) if self.batch_size != "auto" else 0, batch_fn=None
         )
 
         pbar = tqdm(
-            total=len(requests),
+            total=len(requests_grouped),
            disable=(disable_tqdm or (self.rank != 0)),
            desc="Running generate_until requests",
         )
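
PATCH 5/7 stops re-binding the `requests` parameter and introduces `requests_grouped` for the restructured value, presumably because a type checker assigns one type per name and reports reassignment to a differently shaped list as an incompatible assignment. A hypothetical sketch of the same refactor, not taken from the repository:

```python
from typing import List, Tuple


def group_with_lengths(items: List[str]) -> List[Tuple[str, int]]:
    # Re-binding the parameter (`items = [(s, len(s)) for s in items]`) would make
    # mypy report an incompatible assignment (List[Tuple[str, int]] vs List[str]);
    # a distinct name keeps both types intact.
    items_grouped = [(s, len(s)) for s in items]
    return items_grouped


print(group_with_lengths(["alpha", "beta"]))
```
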
From eacd5a64a5ab6135051be1690617ddd1f9e50d76 Mon Sep 17 00:00:00 2001
From: qthequartermasterman
Date: Fri, 18 Oct 2024 11:23:04 -0500
Subject: [PATCH 6/7] fix incorrect shortcircuit expression order

---
 lm_eval/models/vllm_causallms.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lm_eval/models/vllm_causallms.py b/lm_eval/models/vllm_causallms.py
index 9754e5b891..b185804414 100644
--- a/lm_eval/models/vllm_causallms.py
+++ b/lm_eval/models/vllm_causallms.py
@@ -260,7 +260,7 @@ def tok_encode(
         truncation: bool = False,
     ) -> Union[List[int], List[List[int]]]:
         if not add_special_tokens:
-            add_special_tokens = False or self.add_bos_token
+            add_special_tokens = self.add_bos_token or False
         encoding = self.tokenizer(
             string,
             add_special_tokens=add_special_tokens,

From d98165bb1d054085885315a46598c8e4cb6e3b28 Mon Sep 17 00:00:00 2001
From: qthequartermasterman
Date: Fri, 18 Oct 2024 11:46:19 -0500
Subject: [PATCH 7/7] do not use now-deprecated prompt_token_ids kwarg

---
 lm_eval/models/vllm_causallms.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/lm_eval/models/vllm_causallms.py b/lm_eval/models/vllm_causallms.py
index b185804414..4e86700247 100644
--- a/lm_eval/models/vllm_causallms.py
+++ b/lm_eval/models/vllm_causallms.py
@@ -305,9 +305,7 @@ def run_inference_one_model(
                 requests: List[List[int]],
             ):
                 llm = LLM(**model_args)
-                return llm.generate(
-                    prompt_token_ids=requests, sampling_params=sampling_params
-                )
+                return llm.generate(requests, sampling_params=sampling_params)
 
             # dispatch requests to all self.data_parallel_size workers, in interleaved fashion
             # interleaved important to balance context lengths across workers
@@ -326,14 +324,14 @@ def run_inference_one_model(
 
         if self.lora_request is not None:
             outputs = self.model.generate(
-                prompt_token_ids=requests,
+                requests,
                 sampling_params=sampling_params,
                 use_tqdm=True if self.batch_size == "auto" else False,
                 lora_request=self.lora_request,
             )
         else:
             outputs = self.model.generate(
-                prompt_token_ids=requests,
+                requests,
                 sampling_params=sampling_params,
                 use_tqdm=True if self.batch_size == "auto" else False,
             )
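
PATCH 6/7 reverses `False or self.add_bos_token` to `self.add_bos_token or False`. The two expressions differ when the attribute is falsy but not `False` (for example `None`): the original evaluates to that falsy value, while the reordered form collapses it to `False`, so the result stays a plain `bool`. A small standalone illustration:

```python
from typing import Optional


def resolve_flag(add_bos_token: Optional[bool]) -> bool:
    # `False or x` evaluates to x itself (possibly None);
    # `x or False` coerces any falsy value to False.
    return add_bos_token or False


print(False or None)       # None  -> old expression order can leak None
print(resolve_flag(None))  # False -> new order always yields a bool
print(resolve_flag(True))  # True
```
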