Skip to content

Commit

Permalink
Fix loglikelihood_rolling caching (EleutherAI#1821) (EleutherAI#2187)
Browse files Browse the repository at this point in the history

* fix revision type

* allow for None-input loglikelihood reqs to be cached

* handle no remaining cache items

* pre-commit

* change cache_hook.add_partial(loglikelihood_rolling...) convention

---------

Co-authored-by: Baber Abbasi <baber@eleuther.ai>
  • Loading branch information
haileyschoelkopf and baberabb authored Aug 28, 2024
1 parent 2de3688 commit 8138fd5
Show file tree
Hide file tree
Showing 7 changed files with 42 additions and 7 deletions.
7 changes: 5 additions & 2 deletions lm_eval/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,8 +283,11 @@ def fn(requests):
eval_logger.info(
f"Cached requests: {len(requests) - len(remaining_reqs)}, Requests remaining: {len(remaining_reqs)}"
)
# actually run the LM on the requests that do not have cached results
rem_res = getattr(self.lm, attr)(remaining_reqs)
if remaining_reqs:
# actually run the LM on the requests that do not have cached results
rem_res = getattr(self.lm, attr)(remaining_reqs)
else:
rem_res = []

# stick the new ones back into the list and also cache any of the new ones
resptr = 0
Expand Down
5 changes: 4 additions & 1 deletion lm_eval/models/api_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,7 @@ def _collate(req: LogLikelihoodInputs):
):
if answer_ is not None:
res.append(answer_)
# partial caching
# cache requests that aren't from a loglikelihood_rolling request
if cache_key is not None:
self.cache_hook.add_partial(
"loglikelihood", cache_key, answer_
Expand Down Expand Up @@ -638,4 +638,7 @@ def loglikelihood_rolling(

string_nll = sum(string_nll)
loglikelihoods.append(string_nll)

# cache this loglikelihood_rolling request
self.cache_hook.add_partial("loglikelihood_rolling", (string,), string_nll)
return loglikelihoods
11 changes: 10 additions & 1 deletion lm_eval/models/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -1018,6 +1018,9 @@ def loglikelihood_rolling(
string_nll = sum(string_nll)
loglikelihoods.append(string_nll)

# cache this loglikelihood_rolling request
self.cache_hook.add_partial("loglikelihood_rolling", (string,), string_nll)

return loglikelihoods

def _batch_scheduler(self, pos, n_reordered_requests):
Expand Down Expand Up @@ -1246,7 +1249,13 @@ def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):

res.append(answer)

self.cache_hook.add_partial("loglikelihood", request_str, answer)
if request_str is not None:
# special case: loglikelihood_rolling produces a number of loglikelihood requests
# all with cache key None. instead do add_partial on the per-example level
# in the loglikelihood_rolling() function for those.
self.cache_hook.add_partial(
"loglikelihood", request_str, answer
)
pbar.update(1)

pbar.close()
Expand Down
6 changes: 6 additions & 0 deletions lm_eval/models/nemo_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,9 @@ def loglikelihood_rolling(

string_nll = sum(string_nll)
loglikelihoods.append(string_nll)

# cache this loglikelihood_rolling request
self.cache_hook.add_partial("loglikelihood_rolling", (string,), string_nll)
return loglikelihoods

def _loglikelihood_tokens(self, requests, disable_tqdm=False):
Expand Down Expand Up @@ -468,6 +471,9 @@ def _collate(x):
answer = (logprob, is_greedy)

if cache_key is not None:
# special case: loglikelihood_rolling produces a number of loglikelihood requests
# all with cache key None. instead do add_partial on the per-example level
# in the loglikelihood_rolling() function for those.
self.cache_hook.add_partial("loglikelihood", cache_key, answer)

res.append(answer)
Expand Down
3 changes: 3 additions & 0 deletions lm_eval/models/neuralmagic.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,9 @@ def _collate(x):
res.append(answer)

if cache_key is not None:
# special case: loglikelihood_rolling produces a number of loglikelihood requests
# all with cache key None. instead do add_partial on the per-example level
# in the loglikelihood_rolling() function for those.
self.cache_hook.add_partial("loglikelihood", cache_key, answer)

return re_ord.get_original(res)
Expand Down
9 changes: 7 additions & 2 deletions lm_eval/models/neuron_optimum.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,8 @@ def loglikelihood_rolling(self, requests, disable_tqdm: bool = False):

string_nll = sum(string_nll)
loglikelihoods.append(string_nll)

# cache this loglikelihood_rolling request
self.cache_hook.add_partial("loglikelihood_rolling", (string,), string_nll)
return loglikelihoods

def _loglikelihood_tokens(
Expand Down Expand Up @@ -620,7 +621,11 @@ def _collate(x):

res.append(answer)

self.cache_hook.add_partial("loglikelihood", cache_key, answer)
if cache_key is not None:
# special case: loglikelihood_rolling produces a number of loglikelihood requests
# all with cache key None. instead do add_partial on the per-example level
# in the loglikelihood_rolling() function for those.
self.cache_hook.add_partial("loglikelihood", cache_key, answer)

return re_ord.get_original(res)

Expand Down
8 changes: 7 additions & 1 deletion lm_eval/models/vllm_causallms.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,10 @@ def loglikelihood_rolling(

string_nll = sum(string_nll)
loglikelihoods.append(string_nll)

# cache this loglikelihood_rolling request
self.cache_hook.add_partial("loglikelihood_rolling", (string,), string_nll)

return loglikelihoods

def generate_until(
Expand Down Expand Up @@ -453,8 +457,10 @@ def _collate(x):

res.append(answer)

# partial caching
if cache_key is not None:
# special case: loglikelihood_rolling produces a number of loglikelihood requests
# all with cache key None. instead do add_partial on the per-example level
# in the loglikelihood_rolling() function for those.
self.cache_hook.add_partial("loglikelihood", cache_key, answer)
pbar.update(1)
pbar.close()
Expand Down

0 comments on commit 8138fd5

Please sign in to comment.