Commit 6f95550

fixing
1 parent 27b0fcc commit 6f95550

File tree

1 file changed: +14 −7 lines changed

recipes/eleuther_eval.py

Lines changed: 14 additions & 7 deletions
@@ -181,6 +181,7 @@ def tok_batch_multimodal_encode(
         self,
         all_texts: List[str],
         all_images: List[List[PIL.Image.Image]],
+        left_truncate_len: int = None,
         *args,
         **kwargs,
     ):
@@ -218,6 +219,12 @@ def tok_batch_multimodal_encode(
 
         # Convert the batch to the format expected by the HF
         tok_batch["input_ids"] = tok_batch.pop("tokens")
+
+        # the harness will use left_truncate_len to indicate that the current batch
+        # needs to be truncated to self.max_seq_len - self.max_gen_toks
+        if left_truncate_len is not None:
+            tok_batch["input_ids"] = tok_batch["input_ids"][:, -left_truncate_len:]
+
         return tok_batch
 
     @torch.inference_mode()
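
The slice keeps only the rightmost `left_truncate_len` positions of each row, dropping the oldest prompt tokens so that the prompt plus the generation budget still fits in the context window. A minimal sketch of that arithmetic (the standalone names `max_seq_len`/`max_gen_toks` and all values are illustrative, not the recipe's attributes):

```python
import torch

# Illustrative values; per the comment in the diff, the harness computes
# left_truncate_len = max_seq_len - max_gen_toks.
max_seq_len, max_gen_toks = 8, 3
left_truncate_len = max_seq_len - max_gen_toks  # 5

# A (batch=2, seq=7) batch of token ids.
input_ids = torch.arange(14).reshape(2, 7)

# Same slice as in the diff: keep the last left_truncate_len columns,
# i.e. drop the oldest tokens on the left.
input_ids = input_ids[:, -left_truncate_len:]
assert input_ids.shape == (2, 5)
```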
@@ -370,7 +377,7 @@ def tok_encode(self, text: str, **kwargs) -> List[int]:
         return self._tokenizer.encode(text=text, add_bos=False, add_eos=False)
 
     def tok_batch_encode(
-        self, text: List[str], **kwargs
+        self, text: List[str], left_truncate_len: int = None, **kwargs
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         tokenized_text = [self.tok_encode(x) for x in text]
 
@@ -381,6 +388,11 @@ def tok_batch_encode(
             padding_value=self._tokenizer.pad_id,
         )
 
+        # the harness will use left_truncate_len to indicate that the current batch
+        # needs to be truncated to self.max_seq_len - self.max_gen_toks
+        if left_truncate_len is not None:
+            x = x[:, -left_truncate_len:]
+
         return x, torch.ones_like(x)  # return 'mask' b/c it's expected by the harness
 
     def tok_decode(self, tokens: Union[List[int], int], **kwargs) -> str:
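
Put together, the text path tokenizes, pads to a rectangular batch, then applies the same left truncation before returning the tokens and an all-ones mask. A self-contained sketch of that flow (the toy `tok_encode` and `PAD_ID` below stand in for the recipe's tokenizer; this is an approximation, not the recipe itself):

```python
from typing import List, Optional, Tuple

import torch
from torch.nn.utils.rnn import pad_sequence

PAD_ID = 0  # stand-in for self._tokenizer.pad_id


def tok_encode(text: str) -> List[int]:
    # Toy tokenizer: one id per character, offset past the pad id.
    return [ord(c) % 100 + 1 for c in text]


def tok_batch_encode(
    text: List[str], left_truncate_len: Optional[int] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    tokenized = [torch.tensor(tok_encode(t)) for t in text]
    x = pad_sequence(tokenized, batch_first=True, padding_value=PAD_ID)
    # Same truncation as in the diff above.
    if left_truncate_len is not None:
        x = x[:, -left_truncate_len:]
    return x, torch.ones_like(x)  # mask expected by the harness


tokens, mask = tok_batch_encode(["hello there", "hi"], left_truncate_len=5)
assert tokens.shape == (2, 5) and mask.shape == (2, 5)
```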
@@ -506,11 +518,6 @@ def setup(self, cfg: DictConfig) -> None:
 
         # Initialize tokenizer/transform
         model_transform = config.instantiate(cfg.tokenizer)
-        max_seq_len = (
-            model_transform.max_seq_len
-            if model_transform.max_seq_len is not None
-            else 4096  # default max_seq_len to 4096
-        )
 
         # Finally, we setup the actual EvalWrapper class
         if isinstance(model, DeepFusionModel):
@@ -526,7 +533,7 @@ def setup(self, cfg: DictConfig) -> None:
             model,
             model_transform,
             device=self.device,
-            max_seq_length=max_seq_len,
+            max_seq_length=cfg.max_seq_length,
             batch_size=self.batch_size,
             dtype=self.dtype,
             enable_kv_cache=self.enable_kv_cache,
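
With the fallback removed, `max_seq_length` is read directly from the eval config instead of being derived from the tokenizer (previously defaulting to 4096 when the tokenizer set none), so the config becomes the single source of truth. A hedged sketch of the caller side, assuming an OmegaConf `DictConfig` as the `cfg: DictConfig` annotation suggests (the keys here are illustrative):

```python
from omegaconf import OmegaConf

# Illustrative eval config: max_seq_length must now be set explicitly,
# since the tokenizer-derived fallback (else 4096) was removed above.
cfg = OmegaConf.create({"max_seq_length": 4096, "batch_size": 8})
print(cfg.max_seq_length)  # forwarded as-is to max_seq_length= in the wrapper
```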
