
[BUGFIX] Adding sequence truncation to max_seq_length in eval recipe (
SalmanMohammadi authored Oct 9, 2024
1 parent 8d96d6c commit 89f21c2
Showing 1 changed file with 14 additions and 7 deletions.
21 changes: 14 additions & 7 deletions recipes/eleuther_eval.py
@@ -181,6 +181,7 @@ def tok_batch_multimodal_encode(
         self,
         all_texts: List[str],
         all_images: List[List[PIL.Image.Image]],
+        left_truncate_len: int = None,
         *args,
         **kwargs,
     ):
@@ -218,6 +219,12 @@ def tok_batch_multimodal_encode(
 
         # Convert the batch to the format expected by the HF
         tok_batch["input_ids"] = tok_batch.pop("tokens")
+
+        # the harness will use left_truncate_len to indicate that the current batch
+        # needs to be truncated to self.max_seq_len - self.max_gen_toks
+        if left_truncate_len is not None:
+            tok_batch["input_ids"] = tok_batch["input_ids"][:, -left_truncate_len:]
+
         return tok_batch
 
     @torch.inference_mode()
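For readers unfamiliar with the harness convention referenced in the comment above: the caller reserves room for generation, so the prompt side of the batch is cut down to max_seq_len - max_gen_toks and only the most recent tokens are kept. A minimal, self-contained sketch of that slicing (the tensor values, shapes, and limits below are illustrative, not taken from the recipe):

import torch

# Illustrative padded batch of token ids: 2 prompts, 6 tokens each.
input_ids = torch.tensor(
    [
        [101, 7, 8, 9, 10, 11],
        [101, 1, 2, 3, 4, 5],
    ]
)

# Suppose max_seq_len = 8 and max_gen_toks = 4; the harness would then ask for
# left_truncate_len = 8 - 4 = 4 prompt tokens per sequence.
left_truncate_len = 4

# Keep only the last `left_truncate_len` tokens; older context is dropped from
# the left so the generated tokens still fit within max_seq_len.
truncated = input_ids[:, -left_truncate_len:]
print(truncated.shape)  # torch.Size([2, 4])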
@@ -370,7 +377,7 @@ def tok_encode(self, text: str, **kwargs) -> List[int]:
         return self._tokenizer.encode(text=text, add_bos=False, add_eos=False)
 
     def tok_batch_encode(
-        self, text: List[str], **kwargs
+        self, text: List[str], left_truncate_len: int = None, **kwargs
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         tokenized_text = [self.tok_encode(x) for x in text]
 
@@ -381,6 +388,11 @@ def tok_batch_encode(
             padding_value=self._tokenizer.pad_id,
         )
 
+        # the harness will use left_truncate_len to indicate that the current batch
+        # needs to be truncated to self.max_seq_len - self.max_gen_toks
+        if left_truncate_len is not None:
+            x = x[:, -left_truncate_len:]
+
         return x, torch.ones_like(x)  # return 'mask' b/c it's expected by the harness
 
     def tok_decode(self, tokens: Union[List[int], int], **kwargs) -> str:
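In this text-only path the batch is padded first and truncated second. The call that produces x is only partially visible in this hunk, so the sketch below simply assumes right padding via torch's pad_sequence for illustration; the token ids and pad id are made up:

import torch
from torch.nn.utils.rnn import pad_sequence

# Illustrative values; the recipe uses the model tokenizer's pad_id instead of 0.
pad_id = 0
tokenized_text = [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10]]

# Pad to the longest prompt in the batch.
x = pad_sequence(
    [torch.tensor(t) for t in tokenized_text],
    batch_first=True,
    padding_value=pad_id,
)  # shape: (2, 7)

# With, say, max_seq_len = 6 and max_gen_toks = 2, the harness passes
# left_truncate_len = 4; the oldest tokens of the long prompt are dropped.
x = x[:, -4:]
print(x)
# tensor([[ 4,  5,  6,  7],
#         [10,  0,  0,  0]])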
@@ -506,11 +518,6 @@ def setup(self, cfg: DictConfig) -> None:
 
         # Initialize tokenizer/transform
         model_transform = config.instantiate(cfg.tokenizer)
-        max_seq_len = (
-            model_transform.max_seq_len
-            if model_transform.max_seq_len is not None
-            else 4096  # default max_seq_len to 4096
-        )
 
         # Finally, we setup the actual EvalWrapper class
         if isinstance(model, DeepFusionModel):
@@ -526,7 +533,7 @@ def setup(self, cfg: DictConfig) -> None:
             model,
             model_transform,
             device=self.device,
-            max_seq_length=max_seq_len,
+            max_seq_length=cfg.max_seq_length,
             batch_size=self.batch_size,
             dtype=self.dtype,
             enable_kv_cache=self.enable_kv_cache,
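With the tokenizer-based fallback removed above, max_seq_length now comes straight from the recipe config. A hypothetical sketch of reading that value with OmegaConf (the setup method receives a DictConfig; the surrounding keys below are illustrative, not the full eval config):

from omegaconf import OmegaConf

# Hypothetical, trimmed-down eval config; max_seq_length is the key the
# recipe now reads directly.
cfg = OmegaConf.create(
    {
        "max_seq_length": 4096,
        "batch_size": 8,
    }
)

# The wrapper receives the configured value instead of falling back to the
# tokenizer's max_seq_len (which previously defaulted to 4096 when unset).
print(cfg.max_seq_length)  # 4096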
