diff --git a/lm_eval/models/api_models.py b/lm_eval/models/api_models.py
index b16ac67ba8..25223e906a 100644
--- a/lm_eval/models/api_models.py
+++ b/lm_eval/models/api_models.py
@@ -104,9 +104,8 @@ def __init__(
         self._truncate = truncate
         self._max_gen_toks = int(max_gen_toks)
         self._seed = int(seed)
-        # max_length - 1 as we always have 1 token for generation
-        eval_logger.info(f"Using max length {max_length} - 1")
-        self.max_length = max_length - 1
+        eval_logger.info(f"Using max length {max_length}")
+        self.max_length = max_length
         if int(num_concurrent) <= 1:
             eval_logger.info(
                 "Concurrent requests are disabled. To enable concurrent requests, set `num_concurrent` > 1."
             )
@@ -419,9 +418,10 @@ def batch_logliklehood_requests(
         cache_keys = []
         for chunk in chunks:
             for cache_key, context_enc, continuation_enc in chunk:
-                inp = (context_enc + continuation_enc)[-(self.max_length) :]
+                # max_length - 1 as we always have 1 token for generation
+                inp = (context_enc + continuation_enc)[-(self.max_length - 1) :]
                 ctxlen = len(context_enc) - max(
-                    0, len(context_enc) + len(continuation_enc) - (self.max_length)
+                    0, len(context_enc) + len(continuation_enc) - (self.max_length - 1)
                 )
                 inputs.append(inp)
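
The diff moves the "reserve one token for generation" adjustment out of `__init__` (where it shrank `self.max_length` for every code path) and into the loglikelihood batching loop, where the truncation actually happens. A minimal standalone sketch of that arithmetic is below; the helper name `truncate_for_loglikelihood` and the example token ids are hypothetical and not part of the library's API, they only mirror the expressions in the second hunk.

```python
from typing import List, Tuple


def truncate_for_loglikelihood(
    context_enc: List[int], continuation_enc: List[int], max_length: int
) -> Tuple[List[int], int]:
    # Keep only the last (max_length - 1) tokens of context + continuation,
    # reserving one position for generation, as in the updated hunk.
    inp = (context_enc + continuation_enc)[-(max_length - 1):]
    # Tokens dropped from the front always come out of the context,
    # so the surviving context length shrinks by the overflow.
    overflow = max(0, len(context_enc) + len(continuation_enc) - (max_length - 1))
    ctxlen = len(context_enc) - overflow
    return inp, ctxlen


if __name__ == "__main__":
    ctx = list(range(10))       # 10 hypothetical context token ids
    cont = list(range(10, 14))  # 4 hypothetical continuation token ids
    inp, ctxlen = truncate_for_loglikelihood(ctx, cont, max_length=9)
    # 14 tokens total, budget is 9 - 1 = 8, so 6 context tokens are dropped.
    assert inp == [6, 7, 8, 9, 10, 11, 12, 13] and ctxlen == 4
    print(inp, ctxlen)
```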