# [BUG] Eval recipe not using max_seq_length #1644
We should address this by truncating the context ourselves in the `generate` call.
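A minimal sketch of that idea, assuming hypothetical `max_seq_len` and `max_gen_toks` values on the recipe side (the names and wiring here are illustrative, not torchtune's actual code): left-truncate the prompt so that the prompt plus the generation budget fits in the context window before generating.

```python
import torch

def truncate_for_generation(
    tokens: torch.Tensor, max_seq_len: int, max_gen_toks: int
) -> torch.Tensor:
    """Left-truncate a [batch, seq_len] prompt tensor so that the prompt
    plus up to max_gen_toks generated tokens fits within max_seq_len."""
    max_ctx_len = max_seq_len - max_gen_toks
    if tokens.size(1) > max_ctx_len:
        # keep the most recent tokens; left truncation matches the
        # harness's convention for causal LMs
        tokens = tokens[:, -max_ctx_len:]
    return tokens
```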
@SalmanMohammadi What's the update here? Is Eleuther going to push a fix?
Yep yep. https://github.com/EleutherAI/lm-evaluation-harness/pull/2353/files
Not sure I follow - how does this deal with max_seq_len? |
After some investigation, I think we'll need the above PR to land because it addresses these lines in lm_eval:

```python
# set the max length in tokens of inputs ("context_enc")
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
    # max len for inputs = max length, minus room to generate the max new tokens
    max_ctx_len = self.max_length - max_gen_toks
elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
    # max len for inputs = encoder's whole max_length
    max_ctx_len = self.max_length
```

which determine the appropriate length to truncate the prompt to, given the configured `max_length` and `max_gen_toks`. The actual truncation happens in:

```python
# in lm_eval
context_enc, attn_masks = self.tok_batch_encode(
    contexts,
    left_truncate_len=max_ctx_len,
    truncation=self.truncation,
)
```
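For concreteness, a toy trace of those two steps (the numbers are illustrative, not from the issue):

```python
max_length, max_gen_toks = 4096, 256     # configured on the model wrapper
max_ctx_len = max_length - max_gen_toks  # 3840 for a causal LM
# tok_batch_encode(..., left_truncate_len=3840) then keeps only the last
# 3840 tokens of each prompt, reserving 256 slots for generated tokens
```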
and in torchtune's wrapper, which applies it:

```python
# in torchtune
def tok_batch_encode(
    self, text: List[str], left_truncate_len: int = None, **kwargs
) -> Tuple[torch.Tensor, torch.Tensor]:
    tokenized_text = [self.tok_encode(x) for x in text]
    # pad left
    x = left_pad_sequence(
        [torch.tensor(x) for x in tokenized_text],
        batch_first=True,
        padding_value=self._tokenizer.pad_id,
    )
    if left_truncate_len is not None:
        x = x[:, -left_truncate_len:]
    return x, torch.ones_like(x)  # return 'mask' b/c it's expected by the harness
```
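Note that `left_truncate_len` ultimately derives from `self.max_length` on the model wrapper, so for a configured `max_seq_length` to be respected, the wrapper has to report it there. A minimal sketch of that plumbing (the class and attribute names are assumptions for illustration, not torchtune's actual implementation):

```python
class EvalWrapper:
    """Hypothetical harness-facing wrapper: shows where max_length must
    come from for the harness's truncation math to be correct."""

    def __init__(self, tokenizer, max_seq_length: int, max_gen_toks: int = 256):
        self._tokenizer = tokenizer
        self._max_seq_length = max_seq_length
        self._max_gen_toks = max_gen_toks

    @property
    def max_length(self) -> int:
        # the harness computes max_ctx_len = self.max_length - max_gen_toks
        return self._max_seq_length

    @property
    def max_gen_toks(self) -> int:
        return self._max_gen_toks
```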
cc @joecummings