Add eleuther_eval as recipe #549
@@ -0,0 +1,30 @@
# Config for EleutherEvalRecipe in eleuther_eval.py
#
# To launch, run the following command from the root torchtune directory:
#    tune eleuther_eval --config llama2_eleuther_eval tasks=["truthfulqa_mc2","hellaswag"]

# Model Arguments
model:
  _component_: torchtune.models.llama2.llama2_7b

checkpointer:
  _component_: torchtune.utils.FullModelTorchTuneCheckpointer
  checkpoint_dir: /tmp/llama/
  checkpoint_files: [finetuned_model.pt]
Review comment: Actually, the TorchTune checkpointer outputs a model with a specific name format. Maybe we should update the config to that name so folks can use it out of the box (OOTB).
Reply: I commented this somewhere else, but the specific name format requires an epoch number, so there's no way to know exactly what the checkpoint file will be. I'm fine changing this name to resemble the outputted file, but it will not match OOTB.
  output_dir: /tmp/llama/
  model_type: LLAMA2

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama2.llama2_tokenizer
  path: /tmp/llama/tokenizer.model

# Environment
device: cuda
dtype: bf16
seed: 217

# EleutherAI specific eval args
tasks: ["truthfulqa_mc2"]
limit: null
max_seq_length: 4096
@@ -0,0 +1,203 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import sys
import time

from typing import Any, Dict, List

import torch
from omegaconf import DictConfig

from torch import nn

from torchtune import config, utils
from torchtune.modules import Tokenizer, TransformerDecoder
from torchtune.recipe_interfaces import EvalRecipeInterface


logger = utils.get_logger("DEBUG")

try:
    import lm_eval
    from lm_eval.evaluator import evaluate
    from lm_eval.models.huggingface import HFLM
    from lm_eval.tasks import get_task_dict
except ImportError:
Review comment: This catches if the user has an incorrect version installed or if they don't have any version installed.
Reply: So this is basically our workaround so that (a) we can still run Eleuther eval as a recipe and (b) we do not have to take every dep on god's green earth in our package?
Reply: Oui - I think it's reasonable that certain recipes may require other dependencies and we can make sure it's called out, but we ourselves don't have to depend on it in our torchtune pkg.
    logger.error(
        "Recipe requires EleutherAI Eval Harness v0.4. Please install with `pip install lm_eval==0.4.*`"
    )
    sys.exit(1)
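For the version concern raised in the thread above, a stricter guard is possible. The sketch below is only an illustration, not part of this diff: it assumes the harness is installed under the distribution name `lm_eval` and uses the standard-library `importlib.metadata` to check the installed version explicitly.

# Illustrative sketch, not part of this PR: fail fast with a clearer message when the
# wrong lm_eval version is installed. The distribution name "lm_eval" is an assumption.
from importlib.metadata import PackageNotFoundError, version

try:
    installed = version("lm_eval")
except PackageNotFoundError:
    logger.error("lm_eval is not installed. Please install with `pip install lm_eval==0.4.*`")
    sys.exit(1)
if not installed.startswith("0.4"):
    logger.error(f"Found lm_eval=={installed}, but this recipe requires 0.4.x.")
    sys.exit(1)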


class _EvalWrapper(HFLM):
    """An EvalWrapper for EleutherAI's eval harness based on gpt-fast's
    EvalWrapper: https://github.com/pytorch-labs/gpt-fast/blob/main/eval.py.

    Args:
        model (TransformerDecoder): The model to evaluate.
        tokenizer (Tokenizer): The tokenizer to use.
        device (torch.device): The device to use.
        max_seq_length (int): The maximum sequence length to use.
        batch_size (int): The batch size per GPU to use.
    """

    def __init__(
        self,
        model: TransformerDecoder,
        tokenizer: Tokenizer,
        *,
Review comment: KWARGS!
        device: torch.device,
        max_seq_length: int = 4096,
        batch_size: int = 32,
    ):
        super().__init__(device=str(device))
        self._model = model
        self._tokenizer = tokenizer
        self._max_seq_length = max_seq_length
        self._batch_size = batch_size

    @property
    def eot_token_id(self):
Review comment (with a suggested change): I think you mean this, but not fully sure.
Reply: Nope! This is an lm-eval specific function name we have to match exactly.
        return self._tokenizer.eos_id

    @property
    def max_length(self):
        return self._max_seq_length

    @property
    def max_gen_toks(self):
        return 256

    @property
    def batch_size(self):
        return self._batch_size

    @property
    def device(self):
        return self._device

    def tok_encode(self, text: str, **kwargs) -> List[int]:
        # Note on add_bos flag: setting to False as this gives better results, for example
        # +1% on truthfulqa_mc2 with a LoRA finetune. lit-gpt also sets this to False,
        # see https://github.com/Lightning-AI/lit-gpt/blob/main/eval/lm_eval_harness.py#L66,
        # though notably gpt-fast does the opposite:
        # https://github.com/pytorch-labs/gpt-fast/blob/main/eval.py#L123.
        return self._tokenizer.encode(text=text, add_bos=False, add_eos=False)
Review comment: Very noob question, but do all of these evals expect just tokenization? Do we need to make sure we align this with the special tokens used during training? Or how does that work? Sorry, at the moment I'm a bit confused about the use of special tokens. cc: @RdoubleA
Reply: A good question, and one that looks like it is still being discussed on Eleuther's GitHub: EleutherAI/lm-evaluation-harness#1017. They typically handle it all through HF's tokenization, which does special-token processing under the hood. If we use special tokens during training, we will need to update our tokenizer with those special tokens. As we're not doing that at all, I think we tackle that as part of the changes required to do so? Otherwise, this PR would have to implement those changes for training as well.
Reply: Yeah, definitely not something we should address here, but something we should call out (if needed), especially if it means our eval results will be sub-optimal.
Reply: All special tokens should be coupled with the tokenizer, and ideally we keep the special tokens that were used for training when we do eval. The same tokenizer object is used in evals, so as long as we're calling the tokenizer the same way as we do in training we should be good? I don't fully understand why people are removing these just because it gives a slight boost to performance.
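To make the add_bos trade-off discussed above concrete, here is a small hypothetical comparison. `tok` stands in for an already-instantiated torchtune llama2 tokenizer; only the flag values differ from the `tok_encode` call above.

# Hypothetical snippet: what the harness sees with and without a leading BOS token.
prompt = "The capital of France is"
ids_no_bos = tok.encode(text=prompt, add_bos=False, add_eos=False)    # this recipe's choice (matches lit-gpt)
ids_with_bos = tok.encode(text=prompt, add_bos=True, add_eos=False)   # gpt-fast's choice
# The two encodings differ only by the leading BOS id; that single token accounts for
# the ~1% truthfulqa_mc2 difference noted in the code comment above.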

    def tok_decode(self, tokens: List[int], **kwargs) -> str:
        return self._tokenizer.decode(tokens)

    def _model_call(self, inps: torch.Tensor, **kwargs) -> torch.Tensor:
        return self._model(inps)

    def _model_generate(self, *args, **kwargs):
        raise RuntimeError(
Review comment: Found this out the hard way. In a rough estimate, 85% of all tasks in Eleuther are not free generation, so we have the majority of our bases covered. However, if people open a bunch of issues asking for this, we can add a generation method.
Reply: What's the reason to fail on this?
Reply: Not sure I understand the question completely, but here are some possible responses. Why raise an error here instead of letting it fail in Eleuther? Better UX and a more descriptive message.
Reply: The second one - got you!
            "This recipe does not currently support tasks that evaluate free generation, "
            "e.g. `truthfulqa_gen` or `bigbench_color_generate_until`."
        )
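For context on why `_model_generate` simply raises: supporting free-generation tasks would require an actual decoding loop. Below is a very rough sketch of what one could look like using plain greedy decoding on the wrapped model; the method signature is an assumption (lm-eval's `HFLM` expects a specific interface), and a real implementation would more likely reuse torchtune's generation utilities.

# Sketch only, not part of this PR; the signature and stop condition are assumptions.
def _model_generate(self, context: torch.Tensor, max_length: int, eos_token_id: int) -> torch.Tensor:
    tokens = context.clone()
    while tokens.size(-1) < max_length:
        logits = self._model(tokens)                            # [bsz, seq_len, vocab_size]
        next_tok = logits[:, -1].argmax(dim=-1, keepdim=True)   # greedy pick of the next token
        tokens = torch.cat([tokens, next_tok], dim=-1)
        if (next_tok == eos_token_id).all():                    # stop once every sequence hit EOS
            break
    return tokens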


class EleutherEvalRecipe(EvalRecipeInterface):
    """This recipe runs evaluation on a trained model using EleutherAI's eval harness.
    This assumes the user has the EleutherAI eval harness installed.

    This recipe supports:
        - Single GPU evaluation
        - Loading model in fp32 or bf16
        - Any task from the EleutherAI eval harness that is *not* free generation

    Assumptions:
        - Evaluation is launched with the Tune CLI (recommended)
        - User has the EleutherAI eval harness installed, see https://github.com/EleutherAI/lm-evaluation-harness
Review comment: nit: I would move this point up to the top of the docstring (right after the first sentence). Just say: "To run this recipe, make sure you have installed ..."
Reply: Added a comment.

    The following configs can be used to run this recipe:
        - eleuther_eval.yaml

    Args:
        cfg (DictConfig): OmegaConf object parsed from YAML file
    """

    def __init__(self, cfg: DictConfig) -> None:
        self._cfg = cfg

    def load_checkpoint(self, checkpointer_cfg: DictConfig) -> Dict[str, Any]:
        checkpointer = config.instantiate(checkpointer_cfg)
        checkpoint_dict = checkpointer.load_checkpoint()
        return checkpoint_dict

    def setup(self) -> None:
        self._device = utils.get_device(device=self._cfg.device)
        self._dtype = utils.get_dtype(dtype=self._cfg.dtype)
        self._limit = self._cfg.limit
        self._tasks = list(self._cfg.tasks)

        seed = utils.set_seed(seed=self._cfg.seed)
        logger.info(f"Random seed set to {seed}.")

        ckpt_dict = self.load_checkpoint(self._cfg.checkpointer)
        self._model = self._setup_model(
            model_cfg=self._cfg.model,
            model_state_dict=ckpt_dict[utils.MODEL_KEY],
        )
        self._tokenizer = config.instantiate(self._cfg.tokenizer)
        logger.info("Tokenizer is initialized from file.")

    def _setup_model(
        self,
        model_cfg: DictConfig,
        model_state_dict: Dict[str, Any],
    ) -> nn.Module:
        with utils.set_default_dtype(self._dtype), self._device:
            model = config.instantiate(model_cfg)

        model.load_state_dict(model_state_dict)

        # Validate model was loaded in with the expected dtype.
        utils.validate_expected_param_dtype(model, dtype=self._dtype)
        logger.info(f"Model is initialized with precision {self._dtype}.")
        return model

    @torch.no_grad()
    def evaluate(self) -> None:
        t1 = time.time()

        model_eval_wrapper = _EvalWrapper(
            self._model,
            self._tokenizer,
            device=self._device,
            max_seq_length=self._cfg.max_seq_length,
        )

        # Task initialization API changed between v0.4.1 and 0.4.2
Review comment: Copied this from gpt-fast.
Reply: Can we give a bit more detail here? And maybe an explicit type of exception?
        try:
            lm_eval.tasks.initialize_tasks()
        except Exception:
            pass
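On the request above for an explicit exception type: one possible narrowing is sketched below. It assumes the failure mode on newer harness versions is that `initialize_tasks` no longer exists, so the attribute lookup raises AttributeError; that assumption would need to be checked against lm_eval 0.4.2 before adopting it.

# Sketch only: narrow the bare `except Exception` under the stated assumption.
try:
    lm_eval.tasks.initialize_tasks()
except AttributeError:
    # Assumed: newer lm_eval versions dropped initialize_tasks and register tasks differently.
    logger.info("initialize_tasks() not available; assuming this harness version does not need it.")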

        task_dict = get_task_dict(self._tasks)
        logger.info(f"Running evaluation on {self._tasks} tasks.")
        eleuther_output = evaluate(
            model_eval_wrapper,
            task_dict,
            limit=self._limit,
        )

        logger.info(f"Eval completed in {time.time() - t1:.02f} seconds.")
        for task, res in eleuther_output["results"].items():
            logger.info(f"{task}: {res}")


@config.parse
def recipe_main(cfg: DictConfig) -> None:
    """Entry point for the recipe."""
    recipe = EleutherEvalRecipe(cfg=cfg)
    recipe.setup()
    recipe.evaluate()


if __name__ == "__main__":
    sys.exit(recipe_main())
@@ -7,8 +7,5 @@ sentencepiece
 tqdm
 omegaconf
 
-# Evaluation
-lm_eval==0.4.1
-
 # Quantization
 torchao-nightly
This file was deleted.
Review comment: 😎