@@ -181,6 +181,7 @@ def tok_batch_multimodal_encode(
         self,
         all_texts: List[str],
         all_images: List[List[PIL.Image.Image]],
+        left_truncate_len: int = None,
         *args,
         **kwargs,
     ):
@@ -218,6 +219,12 @@ def tok_batch_multimodal_encode(
 
         # Convert the batch to the format expected by the HF
         tok_batch["input_ids"] = tok_batch.pop("tokens")
+
+        # the harness will use left_truncate_len to indicate that the current batch
+        # needs to be truncated to self.max_seq_len - self.max_gen_toks
+        if left_truncate_len is not None:
+            tok_batch["input_ids"] = tok_batch["input_ids"][:, -left_truncate_len:]
+
         return tok_batch
 
     @torch.inference_mode()
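For reference, here is a minimal, self-contained sketch of the slicing the new branch performs: keeping only the rightmost `left_truncate_len` columns preserves the most recent context. The tensor contents and the truncation length below are made up for illustration.

```python
import torch

# Stand-in for the transformed batch; only "input_ids" matters here.
tok_batch = {"input_ids": torch.arange(12).reshape(2, 6)}
left_truncate_len = 4  # e.g. max_seq_len - max_gen_toks

# Same slice as the hunk above: drop the oldest (leftmost) tokens.
tok_batch["input_ids"] = tok_batch["input_ids"][:, -left_truncate_len:]
print(tok_batch["input_ids"].shape)  # torch.Size([2, 4])
```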
@@ -370,7 +377,7 @@ def tok_encode(self, text: str, **kwargs) -> List[int]:
         return self._tokenizer.encode(text=text, add_bos=False, add_eos=False)
 
     def tok_batch_encode(
-        self, text: List[str], **kwargs
+        self, text: List[str], left_truncate_len: int = None, **kwargs
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         tokenized_text = [self.tok_encode(x) for x in text]
 
@@ -381,6 +388,11 @@ def tok_batch_encode(
             padding_value=self._tokenizer.pad_id,
         )
 
+        # the harness will use left_truncate_len to indicate that the current batch
+        # needs to be truncated to self.max_seq_len - self.max_gen_toks
+        if left_truncate_len is not None:
+            x = x[:, -left_truncate_len:]
+
         return x, torch.ones_like(x)  # return 'mask' b/c it's expected by the harness
 
     def tok_decode(self, tokens: Union[List[int], int], **kwargs) -> str:
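The truncation in `tok_batch_encode` only does what the comment promises when the batch is left-padded, so the slice drops padding and the oldest prompt tokens rather than real content. Below is a hedged, self-contained sketch of that interplay; the concrete numbers and the flip-based left padding are illustrative assumptions, not taken from the diff.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

max_seq_len, max_gen_toks = 4096, 256           # illustrative values
left_truncate_len = max_seq_len - max_gen_toks  # 3840-token prompt budget

# Fake token ids standing in for tokenized prompts of different lengths.
tokenized_text = [torch.randint(0, 32_000, (n,)) for n in (5000, 120)]

# Left-pad by flipping, padding, and flipping back, so the truncation
# below removes padding / the oldest tokens first.
x = pad_sequence(
    [t.flip(0) for t in tokenized_text], batch_first=True, padding_value=0
).flip(dims=[1])

x = x[:, -left_truncate_len:]  # keep only the most recent tokens
assert x.shape[1] + max_gen_toks <= max_seq_len
```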
@@ -506,11 +518,6 @@ def setup(self, cfg: DictConfig) -> None:
 
         # Initialize tokenizer/transform
         model_transform = config.instantiate(cfg.tokenizer)
-        max_seq_len = (
-            model_transform.max_seq_len
-            if model_transform.max_seq_len is not None
-            else 4096  # default max_seq_len to 4096
-        )
 
         # Finally, we setup the actual EvalWrapper class
         if isinstance(model, DeepFusionModel):
@@ -526,7 +533,7 @@ def setup(self, cfg: DictConfig) -> None:
             model,
             model_transform,
             device=self.device,
-            max_seq_length=max_seq_len,
+            max_seq_length=cfg.max_seq_length,
             batch_size=self.batch_size,
             dtype=self.dtype,
             enable_kv_cache=self.enable_kv_cache,