From 6f955501decc60004a500db56b7b1716aa05cf68 Mon Sep 17 00:00:00 2001
From: Salman Mohammadi
Date: Tue, 8 Oct 2024 23:37:01 +0100
Subject: [PATCH] fixing

---
 recipes/eleuther_eval.py | 21 ++++++++++++++-------
 1 file changed, 14 insertions(+), 7 deletions(-)

diff --git a/recipes/eleuther_eval.py b/recipes/eleuther_eval.py
index ce07497899..dc96a48106 100644
--- a/recipes/eleuther_eval.py
+++ b/recipes/eleuther_eval.py
@@ -181,6 +181,7 @@ def tok_batch_multimodal_encode(
         self,
         all_texts: List[str],
         all_images: List[List[PIL.Image.Image]],
+        left_truncate_len: int = None,
         *args,
         **kwargs,
     ):
@@ -218,6 +219,12 @@ def tok_batch_multimodal_encode(
 
         # Convert the batch to the format expected by the HF
         tok_batch["input_ids"] = tok_batch.pop("tokens")
+
+        # the harness will use left_truncate_len to indicate that the current batch
+        # needs to be truncated to self.max_seq_len - self.max_gen_toks
+        if left_truncate_len is not None:
+            tok_batch["input_ids"] = tok_batch["input_ids"][:, -left_truncate_len:]
+
         return tok_batch
 
     @torch.inference_mode()
@@ -370,7 +377,7 @@ def tok_encode(self, text: str, **kwargs) -> List[int]:
         return self._tokenizer.encode(text=text, add_bos=False, add_eos=False)
 
     def tok_batch_encode(
-        self, text: List[str], **kwargs
+        self, text: List[str], left_truncate_len: int = None, **kwargs
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         tokenized_text = [self.tok_encode(x) for x in text]
 
@@ -381,6 +388,11 @@ def tok_batch_encode(
             padding_value=self._tokenizer.pad_id,
         )
 
+        # the harness will use left_truncate_len to indicate that the current batch
+        # needs to be truncated to self.max_seq_len - self.max_gen_toks
+        if left_truncate_len is not None:
+            x = x[:, -left_truncate_len:]
+
         return x, torch.ones_like(x)  # return 'mask' b/c it's expected by the harness
 
     def tok_decode(self, tokens: Union[List[int], int], **kwargs) -> str:
@@ -506,11 +518,6 @@ def setup(self, cfg: DictConfig) -> None:
 
         # Initialize tokenizer/transform
         model_transform = config.instantiate(cfg.tokenizer)
-        max_seq_len = (
-            model_transform.max_seq_len
-            if model_transform.max_seq_len is not None
-            else 4096  # default max_seq_len to 4096
-        )
 
         # Finally, we setup the actual EvalWrapper class
         if isinstance(model, DeepFusionModel):
@@ -526,7 +533,7 @@ def setup(self, cfg: DictConfig) -> None:
             model,
             model_transform,
             device=self.device,
-            max_seq_length=max_seq_len,
+            max_seq_length=cfg.max_seq_length,
             batch_size=self.batch_size,
             dtype=self.dtype,
             enable_kv_cache=self.enable_kv_cache,
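
Note (not part of the patch): both new branches reduce to the same left-truncation slice. A minimal sketch of that behavior follows, with made-up tensor values, assuming left_truncate_len is supplied by the harness as max_seq_len - max_gen_toks, as the added comments state.

import torch

# Illustrative only: a right-padded batch of token ids (values are made up).
input_ids = torch.tensor(
    [
        [101, 7, 8, 9, 10, 0],
        [101, 4, 5, 0, 0, 0],
    ]
)

# Assumed to come from the harness as max_seq_len - max_gen_toks, so that
# generated tokens still fit within the model's context window.
left_truncate_len = 4

# Same slicing as the patch: drop tokens from the left and keep only the
# last `left_truncate_len` positions (the most recent context).
truncated = input_ids[:, -left_truncate_len:]
print(truncated.shape)  # torch.Size([2, 4])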