add new truncation strategy #2300

Draft: wants to merge 2 commits into base: main
27 changes: 27 additions & 0 deletions lm_eval/__main__.py
@@ -10,6 +10,7 @@
from lm_eval.evaluator import request_caching_arg_to_dict
from lm_eval.loggers import EvaluationTracker, WandbLogger
from lm_eval.tasks import TaskManager
from lm_eval.truncation_utils import process_truncation_args
from lm_eval.utils import handle_non_serializable, make_table, simple_parse_args_string


@@ -257,6 +258,28 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="Sets trust_remote_code to True to execute code to create HF Datasets from the Hub",
)
parser.add_argument(
"--truncation_args",
type=str,
default="how=default,on=tokens,side=left,keep_first=False,max_symbols=2048,max_new_symbols=256",
help=(
"The truncation mode. Available options:"
"\n`how`:\n\t`no` - no predefined truncation, default option"
"\n\t`default` - regular harness truncation style, default option"
"\n\t`fewshots` - truncate only fewshots, for zero-shot does nothing"
"\n\t`user` - truncate only user-prompt"
"\n\t`transformers` - use transformers truncation"
"\n`on`:\n\t`tokens` - truncate by tokens, tokenizer required, uses models' max_length param, default option"
"\n\t`symbols` - truncate by symbols, does not use tokenizer, require max_symbols setting"
"\n`side`:\n\t`left` - truncate from the left side, default option"
"\n\t`right` - truncate from the right side, the test doc is preserved"
"\n`keep_first`:\n\t`true` - keep the first few-shot from the defined `side`,"
" ignored for `how` != `fewshots`"
"\n\t`false` - does not keep the first few-shot, default option"
"\n`max_symbols`:\n\tinteger, the maximum number of symbols in request, default is 2048"
"\n`max_new_symbols`:\n\tinteger, the maximum number of new symbols to subtract this value from `max_symbols`, default is 256"
),
)
return parser
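Note: lm_eval/truncation_utils.py is not included in this diff, so process_truncation_args is not visible here. A minimal sketch of what it could do, assuming it parses the comma-separated key=value string into a typed dict (the validation rules are inferred from the help text above, not taken from the PR's actual code):

```python
from typing import Dict, Union

# Allowed values per the --truncation_args help text above.
_CHOICES = {
    "how": {"no", "default", "fewshots", "user", "transformers"},
    "on": {"tokens", "symbols"},
    "side": {"left", "right"},
}
_INT_KEYS = {"max_symbols", "max_new_symbols"}


def process_truncation_args(arg_string: str) -> Dict[str, Union[str, bool, int]]:
    """Parse e.g. 'how=fewshots,on=symbols,side=right,max_symbols=4096'."""
    out: Dict[str, Union[str, bool, int]] = {}
    for pair in arg_string.split(","):
        key, _, value = pair.partition("=")
        key, value = key.strip(), value.strip()
        if key in _INT_KEYS:
            out[key] = int(value)  # integer settings
        elif key == "keep_first":
            out[key] = value.lower() == "true"  # boolean setting
        else:
            # hypothetical validation; the real code may differ
            if key in _CHOICES and value not in _CHOICES[key]:
                raise ValueError(f"invalid value {value!r} for {key!r}")
            out[key] = value  # enumerated string settings
    return out
```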


@@ -314,6 +337,9 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
"REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
)

# parse the comma-separated CLI string into a dict of truncation params
truncation_args = process_truncation_args(args.truncation_args)

if args.tasks is None:
eval_logger.error("Need to specify task to evaluate.")
sys.exit()
@@ -404,6 +430,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
numpy_random_seed=args.seed[1],
torch_random_seed=args.seed[2],
fewshot_random_seed=args.seed[3],
truncation_args=truncation_args,
**request_caching_args,
)

12 changes: 7 additions & 5 deletions lm_eval/api/samplers.py
@@ -73,26 +73,28 @@ def get_context(self, doc, num_fewshot):
# TODO: should we just stop people from using fewshot from same split as evaluating?
selected_docs = [x for x in fewshotex if x != doc][:num_fewshot]

labeled_examples = ""
labeled_examples = []
for doc in selected_docs:
one_example = ""
doc_content = self.doc_to_text(doc)
doc_target = self.doc_to_target(doc)
labeled_examples += (
one_example += (
doc_content
if self.config.doc_to_choice is None or isinstance(doc_content, str)
else self.doc_to_choice(doc)[doc_content]
)

if doc_target != "":
labeled_examples += self.target_delimiter
labeled_examples += (
one_example += self.target_delimiter
one_example += (
str(doc_target[0])
if isinstance(doc_target, list)
else doc_target
if self.config.doc_to_choice is None or isinstance(doc_target, str)
else str(self.doc_to_choice(doc)[doc_target])
)
labeled_examples += self.fewshot_delimiter
one_example += self.fewshot_delimiter
labeled_examples.append(one_example)

return labeled_examples
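With this change, get_context returns one string per few-shot example instead of a single concatenated prompt, so a truncation strategy can later drop whole examples. A toy illustration of the new shape (the example contents are invented):

```python
# Invented contents; only the shape matters.
labeled_examples = [
    "Q: 2+2?\nA: 4\n\n",  # one_example built from the first fewshot doc
    "Q: 3+3?\nA: 6\n\n",  # one_example built from the second fewshot doc
]
prompt_prefix = "".join(labeled_examples)  # recovers the old single-string behavior
dropped_left = labeled_examples[1:]        # a strategy can now drop whole examples
```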

123 changes: 83 additions & 40 deletions lm_eval/api/task.py
@@ -28,6 +28,7 @@
from lm_eval.api import samplers
from lm_eval.api.instance import Instance, OutputType
from lm_eval.api.metrics import bits_per_byte, mean, weighted_perplexity
from lm_eval.api.model import LM
from lm_eval.api.registry import (
AGGREGATION_REGISTRY,
DEFAULT_METRIC_REGISTRY,
@@ -39,6 +40,7 @@
from lm_eval.caching.cache import load_from_cache, save_to_cache
from lm_eval.filters import build_filter_ensemble
from lm_eval.prompts import get_prompt
from lm_eval.truncation_utils import truncate_and_chat_template


ALL_OUTPUT_TYPES = [
@@ -382,6 +384,8 @@ def build_all_requests(
fewshot_as_multiturn: bool = False,
chat_template: Optional[Callable] = None,
tokenizer_name: str = "",
truncation_args: Optional[Dict[str, Union[str, bool, int]]] = None,
lm: "LM" = None,
) -> None:
"""Build a set of Instances for a task, and store them in task.instances"""

@@ -435,13 +439,19 @@
total=num_docs,
):
# sample fewshot context #TODO: need to offset doc_id by rank now!
fewshot_ctx = self.fewshot_context(
# fewshot_ctx is a List like [system_prompt, fewshot1, fewshot2, ...]
# TODO: system_prompt is defined once and added to all task samples
fewshot_ctx, first_system = self.fewshot_context(
doc,
0 if self.config.num_fewshot is None else self.config.num_fewshot,
system_instruction,
apply_chat_template,
fewshot_as_multiturn,
chat_template,
)

# just add the test sample at the end of the list
fewshot_ctx = self.add_test_sample(
doc, fewshot_ctx, apply_chat_template, fewshot_as_multiturn
)

# TODO: we should override self.config.repeats if doing greedy gen so users don't waste time+compute
@@ -454,6 +464,12 @@
if not isinstance(inst, list):
inst = [inst]

# TODO: add some notification system here based on returned status
for i, elem in enumerate(inst):
    inst[i], status = truncate_and_chat_template(
        elem, lm, chat_template, truncation_args, first_system
    )

instances.append(inst)

# now flatten, this is to allow slicing to work with pickles
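Since lm_eval/truncation_utils.py is not part of this diff, the behavior of truncate_and_chat_template is not visible here. As a rough sketch only, one of the strategies implied by the CLI options (how=fewshots, on=symbols, side=left) might operate on the list-of-segments context like this; every name below is hypothetical:

```python
from typing import List, Tuple


def truncate_fewshots_left(
    segments: List[str],   # [system?, fewshot1, ..., fewshotN, test_doc]
    max_symbols: int,
    first_system: bool,
    keep_first: bool = False,
) -> Tuple[List[str], bool]:
    """Drop whole few-shot segments from the left until the joined prompt
    fits in max_symbols; the test doc (last segment) is always preserved."""
    head = segments[:1] if first_system else []
    body = segments[len(head):-1]
    tail = segments[-1:]
    if keep_first and body:
        head, body = head + body[:1], body[1:]  # pin the first fewshot
    truncated = False
    while body and len("".join(head + body + tail)) > max_symbols:
        body = body[1:]  # drop the oldest remaining fewshot
        truncated = True
    return head + body + tail, truncated
```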
Expand Down Expand Up @@ -1044,49 +1060,43 @@ def fewshot_context(
The fewshot context.
"""

if apply_chat_template:
labeled_examples = []
else:
labeled_examples = ""

# get task description
if description := self.config.description:
description = utils.apply_template(self.config.description, doc)
# always make a list of fewshots
labeled_examples = []
first_system = False

# create system prompt based on the provided system instruction and description
if system_instruction is not None and description:
system_prompt = (
f"{system_instruction}{self.sampler.fewshot_delimiter}{description}"
)
elif system_instruction is not None:
system_prompt = system_instruction
elif description:
system_prompt = description
else:
system_prompt = ""
system_prompt = self.define_system_prompt(
doc, system_instruction, apply_chat_template
)

# add system prompt if specified
if system_prompt:
if apply_chat_template:
labeled_examples.append({"role": "system", "content": system_prompt})
else:
labeled_examples = system_prompt
if system_prompt is not None:
labeled_examples.append(system_prompt)
first_system = True

# if few-shot - append examples after the system prompt
# fewshots are still a list
if num_fewshot > 0:
if apply_chat_template:
labeled_examples.extend(
self.sampler.get_chat_context(
doc, num_fewshot, fewshot_as_multiturn
)
fewshots = self.sampler.get_chat_context(
doc, num_fewshot, fewshot_as_multiturn
)
else:
labeled_examples += self.sampler.get_context(doc, num_fewshot)
fewshots = self.sampler.get_context(doc, num_fewshot)
labeled_examples.extend(fewshots)

return labeled_examples, first_system

@utils.positional_deprecated
def add_test_sample(
self,
doc: str,
labeled_examples: Optional[List] = None,
apply_chat_template: bool = False,
fewshot_as_multiturn: bool = False,
):
if labeled_examples is None:
    labeled_examples = []
example = self.doc_to_text(doc)
if apply_chat_template:
if self.multiple_input:
return chat_template(labeled_examples)
return labeled_examples
if isinstance(example, str):
self.append_target_question(
labeled_examples, example, fewshot_as_multiturn
Expand All @@ -1098,7 +1108,7 @@ def fewshot_context(
for ex in example:
chat = deepcopy(labeled_examples)
self.append_target_question(chat, ex, fewshot_as_multiturn)
labeled_examples_list.append(chat_template(chat))
labeled_examples_list.append([chat])
return labeled_examples_list
# if example is an integer, append the choice or convert to string
elif isinstance(example, int):
@@ -1111,21 +1121,54 @@
self.append_target_question(
labeled_examples, str(example), fewshot_as_multiturn
)
# return lm.apply_chat_template(labeled_examples)
return chat_template(labeled_examples)
return labeled_examples
else:
if self.multiple_input:
return labeled_examples
if isinstance(example, str):
return labeled_examples + example
return labeled_examples + [example]
elif isinstance(example, list):
return [labeled_examples + ex for ex in example]
labeled_examples_list = [labeled_examples + [ex] for ex in example]
return labeled_examples_list
elif isinstance(example, int):
if self.config.doc_to_choice is not None:
choices = self.doc_to_choice(doc)
return labeled_examples + choices[example]
return labeled_examples + [choices[example]]
else:
return labeled_examples + str(example)
return labeled_examples + [str(example)]
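Taken together, fewshot_context and add_test_sample now build the context as a flat list of segments rather than a single string. A hypothetical end-to-end shape (all values invented):

```python
# Shapes only; the strings are invented.
fewshot_ctx = [
    "You are a helpful assistant.\n\n",  # system prompt (first_system is True)
    "Q: 2+2?\nA: 4\n\n",                 # fewshot 1
    "Q: 3+3?\nA: 6\n\n",                 # fewshot 2
]
first_system = True

# add_test_sample appends the test doc's question as one more segment:
fewshot_ctx = fewshot_ctx + ["Q: 5+5?\nA:"]

# A strategy such as how=fewshots, side=left, keep_first=True could then keep
# the system prompt, the first fewshot, and the test doc while dropping the rest:
kept = fewshot_ctx[:2] + fewshot_ctx[-1:]
```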

@utils.positional_deprecated
def define_system_prompt(
self,
doc: str,
system_instruction: Optional[str] = None,
apply_chat_template: bool = False,
):
# get task description
if description := self.config.description:
description = utils.apply_template(self.config.description, doc)

# create system prompt based on the provided system instruction and description
if system_instruction is not None and description:
system_prompt = (
f"{system_instruction}{self.sampler.fewshot_delimiter}{description}"
)
elif system_instruction is not None:
system_prompt = system_instruction
elif description:
system_prompt = description
else:
system_prompt = ""

# add system prompt if specified
if system_prompt:
if apply_chat_template:
return {"role": "system", "content": system_prompt}

return system_prompt
# TODO: returning None is bad practice
return None

def apply_filters(self):
"""Iterates over FilterEnsembles and applies them to instances"""
13 changes: 12 additions & 1 deletion lm_eval/evaluator.py
@@ -4,7 +4,7 @@
import random
import time
from collections import defaultdict
from typing import TYPE_CHECKING, List, Optional, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Union

import numpy as np
import torch
@@ -74,6 +74,7 @@ def simple_evaluate(
numpy_random_seed: int = 1234,
torch_random_seed: int = 1234,
fewshot_random_seed: int = 1234,
truncation_args: Optional[Dict[str, Union[str, bool, int]]] = None,
):
"""Instantiate and evaluate a model on a list of tasks.

@@ -132,6 +133,8 @@
Random seed for torch. If set to None, the seed will not be set.
:param fewshot_random_seed: int
Random seed for fewshot sampler random generator. If set to None, the seed of generator will be set to None.
:param truncation_args: Dict[str, Union[str, bool, int]]
Dictionary of truncation settings (how, on, side, keep_first, max_symbols, max_new_symbols) applied when building requests.

:return
Dictionary of results
@@ -311,6 +314,7 @@ def _adjust_config(task_dict):
apply_chat_template=apply_chat_template,
fewshot_as_multiturn=fewshot_as_multiturn,
verbosity=verbosity,
truncation_args=truncation_args,
)

if lm.rank == 0:
@@ -345,6 +349,7 @@
"numpy_seed": numpy_random_seed,
"torch_seed": torch_random_seed,
"fewshot_seed": fewshot_random_seed,
"truncation_args": truncation_args,
}
)
results["git_hash"] = get_git_commit_hash()
@@ -370,6 +375,7 @@ def evaluate(
apply_chat_template: Union[bool, str] = False,
fewshot_as_multiturn: bool = False,
verbosity: str = "INFO",
truncation_args: Optional[Dict[str, Union[str, bool, int]]] = None,
):
"""Instantiate and evaluate a model on a list of tasks.

@@ -394,6 +400,9 @@
Defaults to False (no chat template applied).
:param fewshot_as_multiturn: bool
Whether to provide the fewshot examples as a multiturn conversation or a single user turn.
:param truncation_args: Dict[str, Union[str, bool, int]]
Dictionary of truncation settings (how, on, side, keep_first, max_symbols, max_new_symbols) applied when building requests.

:return
Dictionary of results
"""
@@ -456,6 +465,8 @@
tokenizer_name=getattr(lm, "tokenizer_name", "")
if apply_chat_template
else "",
truncation_args=truncation_args,
lm=lm,
)
eval_logger.debug(
f"Task: {task_output.task_name}; number of requests on this rank: {len(task.instances)}"