From 6d0027913b01936f8f505a110d7e1abe8a0bd211 Mon Sep 17 00:00:00 2001
From: artemorloff
Date: Sun, 15 Sep 2024 18:52:52 +0300
Subject: [PATCH] add new truncation

---
 lm_eval/__main__.py         |  27 ++
 lm_eval/api/samplers.py     |  12 +-
 lm_eval/api/task.py         | 123 ++++++---
 lm_eval/evaluator.py        |  13 +-
 lm_eval/truncation_utils.py | 531 ++++++++++++++++++++++++++++++++++++
 5 files changed, 660 insertions(+), 46 deletions(-)
 create mode 100644 lm_eval/truncation_utils.py

diff --git a/lm_eval/__main__.py b/lm_eval/__main__.py
index ab68781939..4b86b97250 100644
--- a/lm_eval/__main__.py
+++ b/lm_eval/__main__.py
@@ -10,6 +10,7 @@
 from lm_eval.evaluator import request_caching_arg_to_dict
 from lm_eval.loggers import EvaluationTracker, WandbLogger
 from lm_eval.tasks import TaskManager
+from lm_eval.truncation_utils import process_truncation_args
 from lm_eval.utils import handle_non_serializable, make_table, simple_parse_args_string


@@ -257,6 +258,28 @@ def setup_parser() -> argparse.ArgumentParser:
         action="store_true",
         help="Sets trust_remote_code to True to execute code to create HF Datasets from the Hub",
     )
+    parser.add_argument(
+        "--truncation_args",
+        type=str,
+        default="how=default,on=tokens,side=left,keep_first=False,max_symbols=2048,max_new_symbols=256",
+        help=(
+            "The truncation mode. Available options:"
+            "\n`how`:\n\t`no` - no predefined truncation (used when `how` is omitted from a user-supplied value)"
+            "\n\t`default` - regular harness truncation style, default option"
+            "\n\t`fewshots` - truncate only the few-shot examples; does nothing for zero-shot"
+            "\n\t`user` - truncate only the user prompt"
+            "\n\t`transformers` - use transformers truncation"
+            "\n`on`:\n\t`tokens` - truncate by tokens; requires a tokenizer and uses the model's max_length param, default option"
+            "\n\t`symbols` - truncate by symbols without a tokenizer; requires the `max_symbols` setting"
+            "\n`side`:\n\t`left` - truncate from the left side, default option"
+            "\n\t`right` - truncate from the right side; the test doc is always preserved"
+            "\n`keep_first`:\n\t`true` - keep the first few-shot on the defined `side`,"
+            " ignored unless `how=fewshots`"
+            "\n\t`false` - do not keep the first few-shot, default option"
+            "\n`max_symbols`:\n\tinteger, the maximum number of symbols in the request, default is 2048"
+            "\n`max_new_symbols`:\n\tinteger, the number of symbols reserved for generation and subtracted from `max_symbols`, default is 256"
+        ),
+    )
     return parser


@@ -314,6 +337,9 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
             "REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
         )

+    # parse the CLI string into a valid dict of truncation params
+    truncation_args = process_truncation_args(args.truncation_args)
+
     if args.tasks is None:
         eval_logger.error("Need to specify task to evaluate.")
         sys.exit()
@@ -404,6 +430,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
         numpy_random_seed=args.seed[1],
         torch_random_seed=args.seed[2],
         fewshot_random_seed=args.seed[3],
+        truncation_args=truncation_args,
         **request_caching_args,
     )

diff --git a/lm_eval/api/samplers.py b/lm_eval/api/samplers.py
index 2cdc4e43e7..765f8040c7 100644
--- a/lm_eval/api/samplers.py
+++ b/lm_eval/api/samplers.py
@@ -73,26 +73,28 @@ def get_context(self, doc, num_fewshot):
         # TODO: should we just stop people from using fewshot from same split as evaluating?
         selected_docs = [x for x in fewshotex if x != doc][:num_fewshot]

-        labeled_examples = ""
+        labeled_examples = []
         for doc in selected_docs:
+            one_example = ""
             doc_content = self.doc_to_text(doc)
             doc_target = self.doc_to_target(doc)
-            labeled_examples += (
+            one_example += (
                 doc_content
                 if self.config.doc_to_choice is None or isinstance(doc_content, str)
                 else self.doc_to_choice(doc)[doc_content]
             )
             if doc_target != "":
-                labeled_examples += self.target_delimiter
-                labeled_examples += (
+                one_example += self.target_delimiter
+                one_example += (
                     str(doc_target[0])
                     if isinstance(doc_target, list)
                     else doc_target
                     if self.config.doc_to_choice is None or isinstance(doc_target, str)
                     else str(self.doc_to_choice(doc)[doc_target])
                 )
-            labeled_examples += self.fewshot_delimiter
+            one_example += self.fewshot_delimiter
+            labeled_examples.append(one_example)

         return labeled_examples

diff --git a/lm_eval/api/task.py b/lm_eval/api/task.py
index 8dadf94857..9cc387d2cb 100644
--- a/lm_eval/api/task.py
+++ b/lm_eval/api/task.py
@@ -28,6 +28,7 @@
 from lm_eval.api import samplers
 from lm_eval.api.instance import Instance, OutputType
 from lm_eval.api.metrics import bits_per_byte, mean, weighted_perplexity
+from lm_eval.api.model import LM
 from lm_eval.api.registry import (
     AGGREGATION_REGISTRY,
     DEFAULT_METRIC_REGISTRY,
@@ -39,6 +40,7 @@
 from lm_eval.caching.cache import load_from_cache, save_to_cache
 from lm_eval.filters import build_filter_ensemble
 from lm_eval.prompts import get_prompt
+from lm_eval.truncation_utils import truncate_and_chat_template


 ALL_OUTPUT_TYPES = [
@@ -395,6 +397,8 @@ def build_all_requests(
         fewshot_as_multiturn: bool = False,
         chat_template: Optional[Callable] = None,
         tokenizer_name: str = "",
+        truncation_args: Optional[Dict[str, Union[str, bool, int]]] = None,
+        lm: Optional["LM"] = None,
     ) -> None:
         """Build a set of Instances for a task, and store them in task.instances"""

@@ -448,13 +452,19 @@ def build_all_requests(
             total=num_docs,
         ):
             # sample fewshot context #TODO: need to offset doc_id by rank now!
-            fewshot_ctx = self.fewshot_context(
+            # fewshot_ctx is a List like [system_prompt, fewshot1, fewshot2, ...]
+            # TODO: system_prompt is defined once and added to all task samples
+            fewshot_ctx, first_system = self.fewshot_context(
                 doc,
                 0 if self.config.num_fewshot is None else self.config.num_fewshot,
                 system_instruction,
                 apply_chat_template,
                 fewshot_as_multiturn,
-                chat_template,
+            )
+
+            # just add the test sample at the end of the list
+            fewshot_ctx = self.add_test_sample(
+                doc, fewshot_ctx, apply_chat_template, fewshot_as_multiturn
             )

             # TODO: we should override self.config.repeats if doing greedy gen so users don't waste time+compute
@@ -467,6 +477,12 @@ def build_all_requests(
             if not isinstance(inst, list):
                 inst = [inst]

+            # TODO: add some notification system here based on returned status
+            for elem in inst:
+                elem, status = truncate_and_chat_template(
+                    elem, lm, chat_template, truncation_args, first_system
+                )
+
             instances.append(inst)

         # now flatten, this is to allow slicing to work with pickles
@@ -1057,49 +1073,43 @@ def fewshot_context(
             The fewshot context.
""" - if apply_chat_template: - labeled_examples = [] - else: - labeled_examples = "" - - # get task description - if description := self.config.description: - description = utils.apply_template(self.config.description, doc) + # always make a list of fewshots + labeled_examples = [] + first_system = False - # create system prompt based on the provided system instruction and description - if system_instruction is not None and description: - system_prompt = ( - f"{system_instruction}{self.sampler.fewshot_delimiter}{description}" - ) - elif system_instruction is not None: - system_prompt = system_instruction - elif description: - system_prompt = description - else: - system_prompt = "" + system_prompt = self.define_system_prompt( + doc, system_instruction, apply_chat_template + ) - # add system prompt if specified - if system_prompt: - if apply_chat_template: - labeled_examples.append({"role": "system", "content": system_prompt}) - else: - labeled_examples = system_prompt + if system_prompt is not None: + labeled_examples.extend(system_prompt) + first_system = True # if few-shot - append examples after the system prompt + # fewshots are still a list if num_fewshot > 0: if apply_chat_template: - labeled_examples.extend( - self.sampler.get_chat_context( - doc, num_fewshot, fewshot_as_multiturn - ) + fewshots = self.sampler.get_chat_context( + doc, num_fewshot, fewshot_as_multiturn ) else: - labeled_examples += self.sampler.get_context(doc, num_fewshot) + fewshots = self.sampler.get_context(doc, num_fewshot) + labeled_examples.extend(fewshots) + return labeled_examples, first_system + + @utils.positional_deprecated + def add_test_sample( + self, + doc: str, + labeled_examples: List = [], + apply_chat_template: bool = False, + fewshot_as_multiturn: bool = False, + ): example = self.doc_to_text(doc) if apply_chat_template: if self.multiple_input: - return chat_template(labeled_examples) + return labeled_examples if isinstance(example, str): self.append_target_question( labeled_examples, example, fewshot_as_multiturn @@ -1111,7 +1121,7 @@ def fewshot_context( for ex in example: chat = deepcopy(labeled_examples) self.append_target_question(chat, ex, fewshot_as_multiturn) - labeled_examples_list.append(chat_template(chat)) + labeled_examples_list.append([chat]) return labeled_examples_list # if example is an integer, append the choice or convert to string elif isinstance(example, int): @@ -1124,21 +1134,54 @@ def fewshot_context( self.append_target_question( labeled_examples, str(example), fewshot_as_multiturn ) - # return lm.apply_chat_template(labeled_examples) - return chat_template(labeled_examples) + return labeled_examples else: if self.multiple_input: return labeled_examples if isinstance(example, str): - return labeled_examples + example + return labeled_examples + [example] elif isinstance(example, list): - return [labeled_examples + ex for ex in example] + labeled_examples_list = [labeled_examples + [ex] for ex in example] + return labeled_examples_list elif isinstance(example, int): if self.config.doc_to_choice is not None: choices = self.doc_to_choice(doc) - return labeled_examples + choices[example] + return labeled_examples + [choices[example]] else: - return labeled_examples + str(example) + return labeled_examples + [str(example)] + + @utils.positional_deprecated + def define_system_prompt( + self, + doc: str, + system_instruction: Optional[str] = None, + apply_chat_template: bool = False, + ): + # get task description + if description := self.config.description: + description = 
utils.apply_template(self.config.description, doc) + + # create system prompt based on the provided system instruction and description + if system_instruction is not None and description: + system_prompt = ( + f"{system_instruction}{self.sampler.fewshot_delimiter}{description}" + ) + elif system_instruction is not None: + system_prompt = system_instruction + elif description: + system_prompt = description + else: + system_prompt = "" + + # add system prompt if specified + if system_prompt: + # add system prompt if specified + if apply_chat_template: + return {"role": "system", "content": system_prompt} + + return system_prompt + # TODO: returning None is bad practice + return None def apply_filters(self): """Iterates over FilterEnsembles and applies them to instances""" diff --git a/lm_eval/evaluator.py b/lm_eval/evaluator.py index 4f4c8f7fba..f615dc7a10 100644 --- a/lm_eval/evaluator.py +++ b/lm_eval/evaluator.py @@ -4,7 +4,7 @@ import random import time from collections import defaultdict -from typing import TYPE_CHECKING, List, Optional, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Union import numpy as np import torch @@ -74,6 +74,7 @@ def simple_evaluate( numpy_random_seed: int = 1234, torch_random_seed: int = 1234, fewshot_random_seed: int = 1234, + truncation_args: Optional[Dict[str, Union[str, bool, int]]] = None, ): """Instantiate and evaluate a model on a list of tasks. @@ -132,6 +133,8 @@ def simple_evaluate( Random seed for torch. If set to None, the seed will not be set. :param fewshot_random_seed: int Random seed for fewshot sampler random generator. If set to None, the seed of generator will be set to None. + :param truncation_args: Dict[str, Union[str, bool, int]] + The mode of truncation applied to sequences. :return Dictionary of results @@ -311,6 +314,7 @@ def _adjust_config(task_dict): apply_chat_template=apply_chat_template, fewshot_as_multiturn=fewshot_as_multiturn, verbosity=verbosity, + truncation_args=truncation_args, ) if lm.rank == 0: @@ -345,6 +349,7 @@ def _adjust_config(task_dict): "numpy_seed": numpy_random_seed, "torch_seed": torch_random_seed, "fewshot_seed": fewshot_random_seed, + "truncation_args": truncation_args, } ) results["git_hash"] = get_git_commit_hash() @@ -370,6 +375,7 @@ def evaluate( apply_chat_template: Union[bool, str] = False, fewshot_as_multiturn: bool = False, verbosity: str = "INFO", + truncation_args: Optional[Dict[str, Union[str, bool, int]]] = None, ): """Instantiate and evaluate a model on a list of tasks. @@ -394,6 +400,9 @@ def evaluate( Defaults to False (no chat template applied). :param fewshot_as_multiturn: bool Whether to provide the fewshot examples as a multiturn conversation or a single user turn. + :param truncation_args: Dict[str, Union[str, bool, int]] + The mode of truncation applied to sequences. 
+ :return Dictionary of results """ @@ -452,6 +461,8 @@ def evaluate( tokenizer_name=getattr(lm, "tokenizer_name", "") if apply_chat_template else "", + truncation_args=truncation_args, + lm=lm, ) eval_logger.debug( f"Task: {task_output.task_name}; number of requests on this rank: {len(task.instances)}" diff --git a/lm_eval/truncation_utils.py b/lm_eval/truncation_utils.py new file mode 100644 index 0000000000..b71c2962b5 --- /dev/null +++ b/lm_eval/truncation_utils.py @@ -0,0 +1,531 @@ +from typing import Dict, Union + +from lm_eval.utils import simple_parse_args_string + + +def process_truncation_args(args: Dict[str, str]) -> Dict[str, Union[str, int, bool]]: + default_args = { + "how": "no", + "on": "tokens", + "side": "left", + "keep_first": False, + "max_symbols": 2048, + "max_new_symbols": 256, + } + if args: + args = simple_parse_args_string(args) + default_args.update(args) + return default_args + + +def unpack_group(lst): + return [elem for pack in lst for elem in pack] + + +def group_dicts(req, skip_system): + groups = [] + if skip_system: + groups.extend([[req[0]]]) # system instaruction goes as separate list + # split remaining seq into pairs user-item + for first in range(skip_system, len(req) - 1, 2): + groups.extend([[req[first], req[first + 1]]]) + groups.extend([[req[-1]]]) + return groups + + +def tokenize_sequence(seq, tokenizer, add_special_tokens, symbols): + # TODO: make sure it works for API and vLLM tokenizers + if symbols: + # consider 1 symbol = 1 token, for models with no tokenizer the only option + return seq + return tokenizer(seq, add_special_tokens=add_special_tokens)["input_ids"] + + +def apply_chat_template(seq, tokenizer, chat_template, add_generation_prompt, tokenize): + # TODO: definitely work only for HF models + if tokenizer is None: + return chat_template(seq) + return tokenizer.apply_chat_template( + seq, add_generation_prompt=add_generation_prompt, tokenize=tokenize + ) + + +def instance_type(func): + def wrapper(request, **kwargs): + if len(request) == 1: + return func( + main=request[0], instance_type="loglikelihood_rolling", **kwargs + ) + elif request[0] == "": + return func(main=request[1], instance_type="acc_mutual_info", **kwargs) + elif isinstance(request[1], dict): + return func(main=request[0], instance_type="generate_until", **kwargs) + elif isinstance(request[0], list) and isinstance(request[1], (str, int, float)): + return func( + main=request[0], + additional=request[1], + instance_type="loglikelihood", + **kwargs, + ) + + return wrapper + + +def context_type(func): + def wrapper(main, additional="", instance_type="generate_until", **kwargs): + if isinstance(main, str): + return func(main, additional, instance_type, "string", **kwargs) + elif isinstance(main, list): + if isinstance(main[0], str): + return func(main, additional, instance_type, "no_template", **kwargs) + if isinstance(main[0], dict): + if isinstance(main[-1]["content"], list): + return func( + main, additional, instance_type, "chat_template", **kwargs + ) + else: + return func(main, additional, instance_type, "multiturn", **kwargs) + + return wrapper + + +def fewshots_truncation( + request, + target, + instance_type, + context_type, + first_system, + tokenizer, + truncation_args, + chat_template, + add_special_tokens, + max_new_tokens, + max_length, +): + if context_type == "string": + return request, "nothing" + if not len(request): + return request, "empty" + # do not cut off system prompt, so skip it if any + skip_system = int(first_system) + # whether to preserve the 
first element for truncation + skip_first = int(truncation_args["keep_first"]) + # small hack, with no tokenizer this case may be reduced to no chat template one + + if len(request) and isinstance(request[0], list): + req = [] + for lst in request: + if len(lst) == 1: + req.extend([lst[0]["content"]]) + else: + req.extend([lst[0]["content"] + lst[1]["content"]]) + else: + # do not change the initial sequence + req = request[:] + + if context_type == "no_template": + # append target for loglikelihood/multiple-choice, otherwise add empty string + req[-1] += target + # minimal length of request to truncate anything, +1 is for doc itself + min_number_elements = skip_system + skip_first + 1 + # if skip_first and zero-shot = error + if len(req) < min_number_elements: + return request, "bad params" + # TODO: do not tokenize the entire seq, tokenize shot by shot + tokens = tokenize_sequence( + req, tokenizer, add_special_tokens, truncation_args["on"] == "symbols" + ) + # remaining for user prompt tokens, for generation tasks subtract tokens to be generated + # TODO: take into account model type (seq2seq do not need subtraction) + remain_tokens = max_length - max_new_tokens * int( + instance_type == "generate_until" + ) + # accumulate the total sum + sum_seq = 0 + if truncation_args["side"] == "right": + # skip system prompt + start = skip_system + # keep doc (question from test set) and may keep last shot + end = len(tokens) - skip_first - 1 + reverse = 1 + else: + # consider left to be default option + # may skip first shot and skip system prompt + start = skip_system + skip_first + # keep doc (question from test set) + end = len(tokens) - 1 + reverse = -1 + # minimal amount of tokens decided by the user + sum_seq += sum(map(len, tokens[:start])) + sum(map(len, tokens[end:])) + if sum_seq > max_length: + return request, "error" + result = 0 + for seq in tokens[start:end][::reverse]: + sum_seq += len(seq) + if sum_seq > remain_tokens: + break + result += 1 + # final = system_prompt + shots + keep_first + doc (right) + if truncation_args["side"] == "right": + final = ( + request[:skip_system] + + request[skip_system : skip_system + result] + + request[-2 : -2 + skip_first] + + request[-1:] + ) + # final = system_prompt + keep_first + shots + doc (left) + else: + final = ( + request[:skip_system] + + request[skip_system : skip_system + skip_first] + + request[-1 - result : -1] + + request[-1:] + ) + + return final, result + skip_first + elif context_type == "chat_template": + if truncation_args["on"] != "symbols": + if first_system: + system = [req[0]] + system_tokens = apply_chat_template( + system, tokenizer, chat_template, False, True + ) + total_tokens = apply_chat_template( + [req[0], {"role": "user", "content": "".join(req[-1]["content"])}], + tokenizer, + chat_template, + add_generation_prompt=True, + tokenize=True, + ) + else: + system_tokens = [] + total_tokens = apply_chat_template( + [{"role": "user", "content": "".join(req[-1]["content"])}], + tokenizer, + chat_template, + add_generation_prompt=True, + tokenize=True, + ) + user_tokens = tokenize_sequence( + ["".join(req[-1]["content"])], tokenizer, add_special_tokens, False + )[0] + # offset = system prompt tokens + all special and generation prompt tokens that will always be in input + offset = len(system_tokens) + ( + len(total_tokens) - len(system_tokens) - len(user_tokens) + ) + # with no chat template there could be only two (system, user) or one (user) role in request + # user prompt is just a list of fewshots, this case is reduced 
to the no_template one + user, status = fewshots_truncation( + req[-1]["content"], + target, + instance_type, + "no_template", + False, + tokenizer, + truncation_args, + chat_template, + add_special_tokens, + max_new_tokens, + max_length - offset, + ) + else: + if first_system: + system = [req[0]["content"]] + len_system = len(system) + else: + len_system = 0 + user, status = fewshots_truncation( + req[-1]["content"], + target, + instance_type, + "no_template", + False, + tokenizer, + truncation_args, + chat_template, + add_special_tokens, + max_new_tokens, + max_length - len_system, + ) + if first_system: + final = [request[0], {"role": "user", "content": user}] + else: + final = [{"role": "user", "content": user}] + return final, status + elif context_type == "multiturn": + # for symbols truncation take into account only `content` of each dict + if truncation_args["on"] == "symbols": + groups = group_dicts(req, skip_system) + final, status = fewshots_truncation( + groups, + target, + instance_type, + "no_template", + first_system, + tokenizer, + truncation_args, + chat_template, + add_special_tokens, + max_new_tokens, + max_length, + ) + result = [] + for element in final: + for dictionary in element: + result.extend([dictionary]) + return result, status + else: + offset_max_len = max_length - max_new_tokens * int( + instance_type == "generate_until" + ) + # get number of fewshots (subtract system prompt and doc) + num_fewshots = (len(req) - skip_system - 1) // 2 + if num_fewshots == 0 and skip_first: + return request, "bad params" + # get the number of tokens for the entire request + tokenized_target = tokenize_sequence( + target, add_special_tokens=False, tokenizer=tokenizer, symbols=False + ) + tokens = ( + apply_chat_template( + req, + tokenizer, + chat_template, + tokenize=True, + add_generation_prompt=True, + ) + + tokenized_target + ) + # fits with no truncation + if len(tokens) <= offset_max_len: + return request, num_fewshots + # zero-shot, but still to long to fit into max_length + elif len(tokens) > offset_max_len and num_fewshots == 0: + return request, "error" + elif len(tokens) > offset_max_len and num_fewshots == 1 and skip_first: + return request, "bad params" + else: + # define the length of the system prompt (same for docs from the same task) + if first_system: + system_tokens = apply_chat_template( + [req[0]], + tokenizer, + chat_template, + add_generation_prompt=False, + tokenize=True, + ) + len_system = len(system_tokens) + else: + len_system = 0 + + system_and_doc = req[:skip_system] + [req[-1]] + sys_doc_tokens = apply_chat_template( + system_and_doc, + tokenizer, + chat_template, + tokenize=True, + add_generation_prompt=True, + ) + # even if has default system prompt, len_doc takes it into account + len_doc = len(sys_doc_tokens) - len_system + mean_tokens = (len(tokens) - len_system - len_doc) // num_fewshots + + approx_fewshots_num = ( + offset_max_len - len_system - len_doc - len(tokenized_target) + ) // mean_tokens - skip_first + + groups = group_dicts(req, skip_system) + + const_parts = [[], []] + if skip_system: + const_parts[0].extend(groups[0]) + if skip_first and truncation_args["side"] == "right": + const_parts[1].extend(groups[-2]) + start = skip_system + end = -2 + elif skip_first: + const_parts[0].extend(groups[skip_system]) + start = skip_system + 1 + end = -1 + else: + start = skip_system + end = -1 + const_parts[1].extend(groups[-1]) + + actual_shots = approx_fewshots_num + + if truncation_args["side"] == "right": + temp_result = ( + const_parts[0] 
+ + unpack_group(groups[start : start + actual_shots]) + + const_parts[1] + ) + else: + temp_result = ( + const_parts[0] + + unpack_group(groups[start:end][::-1][:actual_shots][::-1]) + + const_parts[1] + ) + + sum_seq = len( + apply_chat_template( + temp_result, + tokenizer, + chat_template, + tokenize=True, + add_generation_prompt=True, + ) + ) + len(tokenized_target) + + if sum_seq > offset_max_len: + for i in range(1, num_fewshots - skip_first - actual_shots): + if truncation_args["side"] == "right": + temp_result = ( + const_parts[0] + + unpack_group( + groups[start : start + (actual_shots - i)] + ) + + const_parts[1] + ) + else: + temp_result = ( + const_parts[0] + + unpack_group( + groups[start:end][::-1][: (actual_shots - i)][::-1] + ) + + const_parts[1] + ) + sum_seq = len( + apply_chat_template( + temp_result, + tokenizer, + chat_template, + tokenize=True, + add_generation_prompt=True, + ) + ) + len(tokenized_target) + if sum_seq <= offset_max_len: + return temp_result, actual_shots - i + skip_first + if sum_seq > offset_max_len: + return request, "bad params" + elif sum_seq < offset_max_len: + prev = temp_result[:] + for i in range(actual_shots + 1, num_fewshots - skip_first + 1): + if truncation_args["side"] == "right": + temp_result = ( + const_parts[0] + + unpack_group(groups[start : start + i]) + + const_parts[1] + ) + else: + temp_result = ( + const_parts[0] + + unpack_group(groups[start:end][::-1][:i][::-1]) + + const_parts[1] + ) + sum_seq = len( + apply_chat_template( + temp_result, + tokenizer, + chat_template, + tokenize=True, + add_generation_prompt=True, + ) + ) + len(tokenized_target) + if sum_seq >= offset_max_len or i == num_fewshots - skip_first: + return prev, i + prev = temp_result + else: + return temp_result, actual_shots + skip_first + + +@instance_type +@context_type +def truncate( + request, + target, + instance_type, + context_type, + first_system, + tokenizer, + truncation_args, + chat_template, + add_special_tokens, + max_new_tokens, + max_length, + **kwargs, +): + if truncation_args["how"] == "fewshots": + return fewshots_truncation( + request, + target, + instance_type, + context_type, + first_system, + tokenizer, + truncation_args, + chat_template, + add_special_tokens, + max_new_tokens, + max_length, + ) + elif truncation_args["how"] == "no": + return request, "not_used" + return request, "not_implemented" + + +def restore_form(request, new_query, chat_template, tokenizer): + if isinstance(new_query, list): + if len(new_query): + if isinstance(new_query[0], str): + new_query = "".join(new_query) + elif isinstance(new_query[-1]["content"], list): + new_query[-1]["content"] = "".join(new_query[-1]["content"]) + else: + new_query = "" + + if not isinstance(new_query, str): + new_query = apply_chat_template( + new_query, + tokenizer, + chat_template, + tokenize=False, + add_generation_prompt=True, + ) + args = request.arguments + + if len(args) == 1: + new_pair = (new_query,) + elif args[0] == "": + new_pair = ("", new_query) + else: + new_pair = (new_query, args[1]) + + request.arguments = new_pair + return request + + +def truncate_and_chat_template( + request, lm, chat_template, truncation_args, first_system +): + if truncation_args["on"] == "symbols": + max_len = truncation_args["max_symbols"] + max_new = truncation_args["max_new_symbols"] + else: + max_len = getattr(lm, "max_length", 2048) + max_new = getattr(lm, "max_gen_toks", 256) + special_tokens = getattr(lm, "add_bos_token", False) + tokenizer = getattr(lm, "tokenizer", None) + req = 
request.arguments
+    new_query, status = truncate(
+        req,
+        first_system=first_system,
+        tokenizer=tokenizer,
+        truncation_args=truncation_args,
+        chat_template=chat_template,
+        add_special_tokens=special_tokens,
+        max_new_tokens=max_new,
+        max_length=max_len,
+    )
+    processed_request = restore_form(request, new_query, chat_template, tokenizer)
+    return processed_request, status
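
Usage note (editorial, not part of the patch): a minimal sketch of how the new flag is intended to be driven, assuming this branch is installed; the option string below is only an illustrative placeholder. The same string passed on the CLI via `--truncation_args` is parsed by `process_truncation_args` in the new `lm_eval/truncation_utils.py` and then handed to `simple_evaluate` as the `truncation_args` keyword added above:

    from lm_eval.truncation_utils import process_truncation_args

    # Parse a CLI-style option string; keys not given fall back to the built-in
    # defaults ("how" falls back to "no", i.e. no truncation, when omitted).
    cfg = process_truncation_args("how=fewshots,on=tokens,side=left,keep_first=True")

    # cfg is roughly:
    # {"how": "fewshots", "on": "tokens", "side": "left", "keep_first": True,
    #  "max_symbols": 2048, "max_new_symbols": 256}

    # The dict is then threaded through simple_evaluate -> evaluate ->
    # Task.build_all_requests, where truncate_and_chat_template() rewrites each
    # request's arguments in place before the model is called.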