Distributed mlx_lm.evaluate #1174

Open. Wants to merge 3 commits into base: main.
172 changes: 92 additions & 80 deletions llms/mlx_lm/evaluate.py
@@ -10,7 +10,7 @@
import os
from importlib.metadata import version
from pathlib import Path
from typing import Optional, Union
from typing import Optional

import lm_eval
import mlx.core as mx
@@ -20,11 +20,10 @@
from lm_eval.api.registry import register_model
from tqdm import tqdm

from .models.base import create_causal_mask
from .models.cache import make_prompt_cache
from .utils import load, stream_generate

PAD = 0


def _len_longest_common_prefix(a, b):
l = 0
@@ -43,31 +42,14 @@ def _rstrip_until(s, untils):
return s[: min(f)]


def _pad_inputs(
inputs,
maxlen,
genlen=0,
pad_left=False,
pad_multiple=32,
truncate=False,
):
# pad the prompts to the left with at least genlen tokens.
actual_maxlen = max(len(p) for p in inputs) + genlen
if actual_maxlen > maxlen:
if not truncate:
raise ValueError("Inputs are too long.")
else: # drop the beginning
actual_maxlen = maxlen
inputs = [p[max(0, len(p) - maxlen) :] for p in inputs]
if pad_multiple > 0:
maxlen = (actual_maxlen + pad_multiple - 1) // pad_multiple
maxlen *= pad_multiple
assert PAD == 0
lr = np.array((1, 0) if pad_left else (0, 1))
return np.stack(
[np.pad(np.array(x, np.int32), lr * (maxlen - len(x))) for x in inputs],
def _pad_inputs(inputs):
lengths = np.array([len(x) for x in inputs])
maxlen = lengths.max()
padded = np.stack(
[np.pad(x, (0, maxlen - len(x))) for x in inputs],
axis=0,
)
return mx.array(padded), mx.array(lengths)
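
For reference, a minimal standalone sketch of what the new `_pad_inputs` produces (hypothetical two-prompt batch, NumPy only):

```python
import numpy as np

# Hypothetical token batch: two prompts of unequal length.
batch = [[5, 7, 9], [3, 4]]

lengths = np.array([len(x) for x in batch])  # [3, 2]
maxlen = lengths.max()                       # 3
padded = np.stack(
    [np.pad(x, (0, maxlen - len(x))) for x in batch],
    axis=0,
)
# padded -> [[5, 7, 9],
#            [3, 4, 0]]   zero-padded on the right
# lengths is returned alongside so padded positions can be masked out later.
```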


@register_model("mlxlm")
@@ -83,32 +65,33 @@ def __init__(
self._batch_size = batch_size
self._model, self.tokenizer = load(path_or_hf_repo)
self._max_tokens = max_tokens or self.tokenizer.model_max_length
self.use_chat_template = use_chat_template or (
self.use_chat_template = use_chat_template and (
self.tokenizer.chat_template is not None
)

def _score_fn(self, inputs, tokenize=True, step_size=32):
if tokenize:
inputs = self._tokenize(inputs)
inputs = _pad_inputs(inputs, self._max_tokens, truncate=False)
inputs = mx.array(inputs)
def _score_fn(self, inputs, step_size: int = 64):
inputs, lengths = _pad_inputs(inputs)
inputs, targets = inputs[..., :-1], inputs[..., 1:]

cache = make_prompt_cache(self._model)

mask = targets != PAD

scores, is_greedy = [], []
for i in range(0, inputs.shape[1], step_size):
logits = self._model(inputs[:, i : i + step_size], cache=cache)
inp = inputs[:, i : i + step_size]
T = inp.shape[1]

offset = cache[0].offset
mask = create_causal_mask(T, offset, lengths=lengths)
mask = mask == 0

logits = self._model(inp, cache=cache, mask=mask)
log_probs = nn.log_softmax(logits.astype(mx.float32))

score = mx.take_along_axis(
log_probs, targets[:, i : i + step_size, mx.newaxis], axis=-1
)[..., 0]
ig = mask[:, i : i + step_size] * (
targets[:, i : i + step_size] == mx.argmax(logits, axis=-1)
)
ig = targets[:, i : i + step_size] == mx.argmax(logits, axis=-1)
ig = mx.where(mx.arange(T) + offset < lengths[:, None], ig, False)

mx.eval(score, ig)
mx.metal.clear_cache()
@@ -119,38 +102,32 @@ def _score_fn(self, inputs, tokenize=True, step_size=32):
scores = mx.concatenate(scores, axis=1)
is_greedy = mx.concatenate(is_greedy, axis=1)

return scores, mask.sum(axis=-1), is_greedy
return scores, lengths, is_greedy
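
The boolean mask passed to the model behaves roughly like the sketch below, an approximation of `create_causal_mask` plus the `mask == 0` flip, not the library code itself:

```python
import mlx.core as mx

def sketch_mask(T, offset, lengths):
    # Query positions offset..offset+T-1 attend to key positions 0..offset+T-1.
    q = mx.arange(offset, offset + T)[:, None]
    k = mx.arange(offset + T)[None, :]
    causal = k <= q                      # never attend to future tokens
    in_seq = k < lengths[:, None, None]  # never attend to right-padding
    return causal & in_seq               # True where attention is allowed
```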

def _loglikelihood(self, texts, score_spans=None):
all_scores = mx.zeros(len(texts))
all_is_greedy = mx.zeros(len(texts), dtype=mx.bool_)
for i in tqdm(range(0, len(texts), self._batch_size)):
batch = texts[i : i + self._batch_size]
scores, lengths, is_greedy = self._score_fn(batch)

ind = np.arange(scores.shape[-1])
if score_spans is not None:
spans = score_spans[i : i + self._batch_size]
lengths = [end - start for start, end in spans]
masks = mx.array(
np.array([(ind >= start) & (ind < end) for start, end in spans])
)
else:
masks = ind[None] < lengths[:, None]

def _loglikelihood(self, texts, score_spans=None, tokenize=True):
# sort by length to get batches with little padding.
sorted_indices = sorted(range(len(texts)), key=lambda i: -len(texts[i]))
sorted_inputs = [texts[sorted_indices[i]] for i in range(len(texts))]
sorted_spans = None
if score_spans is not None:
sorted_spans = [score_spans[sorted_indices[i]] for i in range(len(texts))]

results = []
for i in tqdm(range(0, len(sorted_inputs), self._batch_size)):
batch = sorted_inputs[i : i + self._batch_size]
scores, length, is_greedy = self._score_fn(batch, tokenize=tokenize)
for j in range(len(batch)):
if sorted_spans is None: # full sequence score
mask = mx.arange(scores[j].shape[-1]) < length
score = (scores[j].astype(mx.float32) * mask).sum(axis=-1)
ig = (is_greedy[j].astype(mx.int32) * mask).sum(axis=-1)
else: # subsequence score
start, end = sorted_spans[i + j]
score = scores[j][start:end].astype(mx.float32).sum()
ig = is_greedy[j][start:end].astype(mx.int32).sum()
length = end - start

results.append((score.item(), ig.item(), length))

# reorder the outputs
inv_sort = np.argsort(sorted_indices)
results = [results[inv_sort[i]] for i in range(len(results))]
scores = (masks * scores).sum(axis=-1)
is_greedy = (masks * is_greedy).sum(axis=-1)

return results
all_scores[i : i + self._batch_size] = scores
all_is_greedy[i : i + self._batch_size] = is_greedy == lengths

return all_scores, all_is_greedy
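
The span masking above reduces per-token scores to a single log-likelihood per sequence; a NumPy sketch with made-up numbers:

```python
import numpy as np

# Hypothetical per-token log-probs for one 8-token sequence, where only
# the completion span [3, 6) should be scored.
scores = np.array([-0.5, -1.2, -0.3, -0.7, -0.1, -0.9, -2.0, -0.4])
start, end = 3, 6

ind = np.arange(scores.shape[-1])
mask = (ind >= start) & (ind < end)
print((mask * scores).sum())  # sums the span: -0.7 + -0.1 + -0.9 = -1.7
```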

def _tokenize(self, texts):
return [
@@ -222,16 +199,53 @@ def loglikelihood(self, requests) -> list[tuple[float, bool]]:
+ "completion longer than context."
)

num_results = len(shortened)

# sort by length to get batches with little padding.
sorted_indices = sorted(range(len(shortened)), key=lambda i: -len(shortened[i]))
shortened = [shortened[i] for i in sorted_indices]
completion_spans = [completion_spans[i] for i in sorted_indices]

group = mx.distributed.init()

# split strided so we have approximately the same lengths on each node
shortened = shortened[group.rank() :: group.size()]
completion_spans = completion_spans[group.rank() :: group.size()]

# model scoring, returns num_requests x (logp, is_greedy, length).
results = self._loglikelihood(
scores, is_greedy = self._loglikelihood(
shortened,
score_spans=completion_spans,
tokenize=False,
)
return [(r[0], r[1] == r[2]) for r in results]

# all gather the results across groups
if group.size() > 1:
per_group = int(np.ceil(num_results / group.size()))
scores = mx.pad(scores, ((0, per_group - len(scores)),))
is_greedy = mx.pad(is_greedy, ((0, per_group - len(is_greedy))))
scores = mx.distributed.all_gather(scores[mx.newaxis], stream=mx.cpu)
is_greedy = mx.distributed.all_gather(is_greedy[mx.newaxis], stream=mx.cpu)
scores = scores.T.reshape(-1)
is_greedy = is_greedy.T.reshape(-1)

scores = np.array(scores[:num_results])
is_greedy = np.array(is_greedy[:num_results])

results = [(score, ig) for score, ig in zip(scores, is_greedy)]
inv_sort = np.argsort(sorted_indices)
results = [results[inv_sort[i]] for i in range(len(inv_sort))]
return results
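
Why `scores.T.reshape(-1)` restores the original order: each rank holds a strided slice, so the gathered (ranks, per_group) matrix interleaves back correctly once transposed. A toy NumPy illustration (not the mlx distributed API):

```python
import numpy as np

# 7 results split across 2 ranks with stride 2, padded to ceil(7/2) = 4.
results = np.arange(7)                 # results in length-sorted order
rank0 = results[0::2]                  # [0, 2, 4, 6]
rank1 = np.pad(results[1::2], (0, 1))  # [1, 3, 5, 0] (one pad entry)

gathered = np.stack([rank0, rank1])    # shape (ranks, per_group), as all-gathered
restored = gathered.T.reshape(-1)[:7]  # [0, 1, 2, 3, 4, 5, 6]
```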

tokenizer_name = lm_eval.models.huggingface.HFLM.tokenizer_name
apply_chat_template = lm_eval.models.huggingface.HFLM.apply_chat_template

def apply_chat_template(
self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True
) -> str:
if len(chat_history) == 0:
return ""
return lm_eval.models.huggingface.HFLM.apply_chat_template(
self, chat_history, add_generation_prompt
)

def loglikelihood_rolling(self, requests) -> list[float]:
"""Compute full log-likelihood of a string, with no truncation, for perplexity computation
Expand Down Expand Up @@ -268,8 +282,9 @@ def loglikelihood_rolling(self, requests) -> list[float]:
logging.info(
"Estimating loglikelihood rolling for %d sequences." % len(requests)
)
inputs = [req.args[0] for req in requests]
return [t[0] for t in self._loglikelihood(inputs)]
inputs = self._tokenize([req.args[0] for req in requests])
scores, _ = self._loglikelihood(inputs)
return scores.tolist()

def generate_until(self, requests) -> list[str]:
"""Generate greedily until a stopping sequence
Expand Down Expand Up @@ -332,7 +347,7 @@ def main():
)
parser.add_argument(
"--limit",
default=1.0,
default=None,
help="Limit the number of examples per task.",
type=float,
)
@@ -346,11 +361,8 @@
)
parser.add_argument(
"--apply-chat-template",
Collaborator (Author):

It was impossible to disable this before, so I've changed it to be off by default (which mirrors the lm_eval behavior).

Member:

I'm not sure about defaulting it to off for instruct models. It seems like you would always want this on for most models that are used regularly. Does it make sense to change this to --ignore-chat-template instead, so it can be shut off if needed?
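
For illustration, a minimal argparse sketch of the suggested flag (hypothetical name `--ignore-chat-template`, not part of this diff):

```python
import argparse

parser = argparse.ArgumentParser()
# Chat template stays on by default; the flag opts out.
parser.add_argument(
    "--ignore-chat-template",
    action="store_true",
    help="Do not apply the tokenizer's chat template to prompts.",
)
args = parser.parse_args(["--ignore-chat-template"])
use_chat_template = not args.ignore_chat_template  # False in this example
```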

action=argparse.BooleanOptionalAction,
help="Specifies whether to apply a chat template to the prompt. If "
"the model has a chat template, this defaults to `True`, "
"otherwise `False`.",
default=None,
action="store_true",
help="Specifies whether to apply a chat template to the prompt.",
)
args = parser.parse_args()
