[BUGFIX] Adding sequence truncation to max_seq_length in eval recipe #1773

Merged
21 changes: 14 additions & 7 deletions recipes/eleuther_eval.py
@@ -181,6 +181,7 @@ def tok_batch_multimodal_encode(
self,
all_texts: List[str],
all_images: List[List[PIL.Image.Image]],
left_truncate_len: int = None,
*args,
**kwargs,
):
@@ -218,6 +219,12 @@ def tok_batch_multimodal_encode(

# Convert the batch to the format expected by the HF harness
tok_batch["input_ids"] = tok_batch.pop("tokens")

# The harness passes left_truncate_len when the current batch needs to
# be left-truncated to self.max_seq_len - self.max_gen_toks
if left_truncate_len is not None:
tok_batch["input_ids"] = tok_batch["input_ids"][:, -left_truncate_len:]

return tok_batch
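
For context, a minimal sketch of what this left truncation does to a batch tensor (the values and shapes are illustrative, not taken from the recipe):

import torch

# Illustrative batch of token ids, shape (batch_size=2, seq_len=6).
input_ids = torch.tensor([[1, 2, 3, 4, 5, 6],
                          [0, 0, 7, 8, 9, 10]])

# Keep only the rightmost left_truncate_len positions, mirroring
# tok_batch["input_ids"][:, -left_truncate_len:] above.
left_truncate_len = 4
truncated = input_ids[:, -left_truncate_len:]
print(truncated.shape)  # torch.Size([2, 4])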

@torch.inference_mode()
@@ -370,7 +377,7 @@ def tok_encode(self, text: str, **kwargs) -> List[int]:
return self._tokenizer.encode(text=text, add_bos=False, add_eos=False)

def tok_batch_encode(
self, text: List[str], **kwargs
self, text: List[str], left_truncate_len: int = None, **kwargs
) -> Tuple[torch.Tensor, torch.Tensor]:
tokenized_text = [self.tok_encode(x) for x in text]

@@ -381,6 +388,11 @@ def tok_batch_encode(
padding_value=self._tokenizer.pad_id,
)

# The harness passes left_truncate_len when the current batch needs to
# be left-truncated to self.max_seq_len - self.max_gen_toks
if left_truncate_len is not None:
x = x[:, -left_truncate_len:]

return x, torch.ones_like(x)  # also return a 'mask' tensor, since the harness expects one
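
To see the whole pad-then-truncate flow in one place, here is a hedged sketch; pad_id=0 and the token lists are made up, and the left-padding idiom is an assumption about the collapsed lines above (left padding is the usual choice for generation batches, and is what makes left truncation drop padding and the oldest prompt tokens rather than the newest ones):

import torch
from torch.nn.utils.rnn import pad_sequence

# Hypothetical tokenized prompts of unequal length (pad_id assumed to be 0).
tokenized_text = [[1, 2, 3], [4, 5, 6, 7, 8]]

# Left-pad: flip each sequence, right-pad, then flip the batch back.
x = pad_sequence(
    [torch.tensor(t[::-1]) for t in tokenized_text],
    batch_first=True,
    padding_value=0,
).flip(dims=[1])
# x == [[0, 0, 1, 2, 3],
#       [4, 5, 6, 7, 8]]

# With left padding, truncating from the left drops padding/oldest tokens.
left_truncate_len = 4
x = x[:, -left_truncate_len:]  # [[0, 1, 2, 3], [5, 6, 7, 8]]

mask = torch.ones_like(x)  # the harness expects a mask tensor alongside the ids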

def tok_decode(self, tokens: Union[List[int], int], **kwargs) -> str:
@@ -506,11 +518,6 @@ def setup(self, cfg: DictConfig) -> None:

# Initialize tokenizer/transform
model_transform = config.instantiate(cfg.tokenizer)
max_seq_len = (
model_transform.max_seq_len
if model_transform.max_seq_len is not None
else 4096 # default max_seq_len to 4096
)

# Finally, we set up the actual EvalWrapper class
if isinstance(model, DeepFusionModel):
@@ -526,7 +533,7 @@ def setup(self, cfg: DictConfig) -> None:
model,
model_transform,
device=self.device,
max_seq_length=max_seq_len,
max_seq_length=cfg.max_seq_length,
batch_size=self.batch_size,
dtype=self.dtype,
enable_kv_cache=self.enable_kv_cache,
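
With the hardcoded 4096 fallback above removed, the wrapper's context length now comes straight from the eval config. A small sketch of the new expectation (the key name is taken from the diff; the value and the standalone OmegaConf usage are illustrative):

from omegaconf import OmegaConf

# Illustrative eval config fragment; max_seq_length must now be set here.
cfg = OmegaConf.create({"max_seq_length": 4096})

# The recipe reads the value directly from the config instead of silently
# defaulting in code, making the config the single source of truth.
max_seq_length = cfg.max_seq_length
print(max_seq_length)  # 4096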