
[Feature]: Add efficient interface for evaluating probabilities of fixed prompt-completion pairs #5234

@xinyangz

Description


Proposed Feature

Add an efficient interface for evaluating generation probabilities of fixed prompt and completion pairs. For example:

# ... load LLM or engine
prompt_completion_pairs = [
    ("1 + 1 = ", "2"),
    ("1 + 1 = ", "3"),
]
prompts, completions = list(zip(*prompt_completion_pairs))
probs = llm.completion_logprobs(prompts=prompts, completions=completions)

Alternatively, the interface could evaluate the probabilities of a fixed prompt with multiple generation options to better leverage prefix caching:

prompt = "1 + 1 = "
completions = ["2", "3", "4"]
probs = llm.completion_logprobs(prompt=prompt, completions=completions)

Currently, SamplingParams has interfaces to return the log probabilities of the prompt (prompt_logprobs) and of the generated tokens (logprobs). However, they are either inefficient or have incomplete support for this use case.
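For reference, these are the two existing knobs on SamplingParams (real parameters, shown here only for context):

from vllm import SamplingParams

# logprobs=k returns the top-k log probabilities for each generated token;
# prompt_logprobs=k additionally returns them for every prompt token.
params = SamplingParams(temperature=0, max_tokens=1, logprobs=1, prompt_logprobs=1)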

Motivation

The motivation for this feature comes from LLM evaluation on multiple-choice questions (e.g., MMLU). vLLM is a popular tool adopted by mainstream LLM evaluation frameworks (e.g., lm-evaluation-harness) for this purpose.

Using the following example:

Question: Which of the following is true?
(A) ABC
(B) DEF
The answer is:

Evaluating a base LLM on this question involves calculating the probability of each choice, $P_{\text{LLM}}(\text{choice} \mid \text{question})$, and selecting the choice with the highest probability.
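For a multi-token choice, this probability factorizes over the choice tokens by the chain rule, so per-token log probabilities are sufficient:

$$\log P_{\text{LLM}}(\text{choice} \mid \text{question}) = \sum_{t=1}^{T} \log P_{\text{LLM}}(c_t \mid \text{question}, c_{<t})$$

where $c_1, \dots, c_T$ are the tokens of the choice.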

Current solution

Currently, lm-evaluation-harness runs one generation per choice and evaluates the full prompt probabilities for this purpose:

question = "1 + 1 = "
choices = ["2", "3", "4"]
prompts = [question + c for c in choices]
sampling_params = SamplingParams(temperature=0, max_tokens=1, prompt_logprobs=1)
outputs = llm.generate(prompts=prompts, sampling_params=sampling_params)

Instead of evaluating probabilities only on the choice tokens, it evaluates them on the full question + choice strings and runs one generation per choice, because of limitations in vLLM's user interface.
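As a minimal sketch (not lm-evaluation-harness's actual code), the per-choice log probabilities can be recovered from the outputs above by summing the prompt_logprobs entries that correspond to the choice tokens at the end of each prompt. This continues from the snippet above and assumes the choice tokenizes to the same token ids when appended to the question:

# Hedged sketch: extract log P(choice | question) from the prompt_logprobs above.
tokenizer = llm.get_tokenizer()
choice_logprobs = []
for choice, output in zip(choices, outputs):
    # Number of tokens the choice occupies at the end of question + choice
    # (assumes the question/choice boundary does not change tokenization).
    n_choice_tokens = len(tokenizer.encode(choice, add_special_tokens=False))
    # prompt_logprobs[i] is a dict {token_id: Logprob} for prompt token i
    # (the very first entry is None, since the first token has no context).
    per_token = output.prompt_logprobs[-n_choice_tokens:]
    token_ids = output.prompt_token_ids[-n_choice_tokens:]
    choice_logprobs.append(sum(d[tid].logprob for d, tid in zip(per_token, token_ids)))

best_choice = choices[choice_logprobs.index(max(choice_logprobs))]

This works, but every prompt token pays the prompt_logprobs cost even though only the last few positions are actually needed.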

Efficiency issue with current solution

The issue with prompt_logprobs is that it is very inefficient on long prompts.

Let's use the following minimal profiling example (profiling.py):

import time

import numpy as np
from vllm import LLM, SamplingParams

# 100 random "prompts" of 4,000 tokens each (token ids drawn uniformly below 10,000).
n_seqs = 100
vocab_size = 10_000
seq_len = 4000
data = np.random.randint(0, vocab_size, (n_seqs, seq_len)).tolist()

llm = LLM("mistral-community/Mistral-7B-v0.2", max_model_len=8000, gpu_memory_utilization=0.6)
sampling_params = SamplingParams(temperature=0, max_tokens=1, prompt_logprobs=1)

start = time.perf_counter()
outputs = llm.generate(prompts=None, prompt_token_ids=data, sampling_params=sampling_params)
end = time.perf_counter()

print(f"Inference took {end - start:.4f} seconds")

Running the code with vLLM's official docker image:

docker run --gpus all --shm-size=10g --rm -e HF_TOKEN=[token] -v "$(pwd):/app" --entrypoint python3 vllm/vllm-openai:v0.4.3 /app/profiling.py

On a single A100-40G GPU, the script takes around 500 seconds with prompt_logprobs=1 and only 27 seconds without prompt_logprobs. Moreover, we can fit much longer input prompts if we turn it off.

Analysis on the efficiency issue

A quick search for prompt_logprobs leads to the Sampler.forward method in vllm/model_executor/layers/sampler.py.

First, we noticed that the shape of the logits tensor grows from (1, vocab_size) to (input_len, vocab_size) when prompt_logprobs is set. Second, we found that get_logprobs involves many Python for loops and CPU-GPU transfers.
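A back-of-the-envelope calculation (assuming Mistral's ~32k vocabulary and fp16 logits; neither number appears in the profiling script above) shows how much extra data each sequence drags through the sampler and the Python-side logprob extraction:

# Rough size of the per-sequence logits when prompt_logprobs is set.
input_len, vocab_size, bytes_per_elem = 4000, 32_000, 2  # assumed fp16
print(f"{input_len * vocab_size * bytes_per_elem / 2**20:.0f} MiB per sequence")  # ~244 MiB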

Potential Changes

I see two ways to fix the efficiency issue with the current approach.

Option 1: Use prompt_logprobs but don't calculate on the full prompt

We could reuse prompt_logprobs but limit the probability calculation to the final few tokens, so we don't have to pass around the large logits array.
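A hypothetical shape for such an interface (the parameter name below does not exist in vLLM and is only illustrative):

from vllm import SamplingParams

# Hypothetical: only compute prompt logprobs for the final N prompt tokens,
# so the full (input_len, vocab_size) logits never need to be materialized.
sampling_params = SamplingParams(
    temperature=0,
    max_tokens=1,
    prompt_logprobs=1,
    prompt_logprobs_last_n=4,  # hypothetical parameter, not in vLLM
)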

Option 2: Use the sampling logprobs but constrain the generation to the choices

Currently, there is a controlled (guided) generation interface for the OpenAI-compatible server, but not for offline inference.
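For context, a sketch of what the server-side interface already allows (guided_choice restricts decoding to the listed strings; exact request fields may differ between versions, so treat this as an approximation):

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
resp = client.completions.create(
    model="mistral-community/Mistral-7B-v0.2",
    prompt="1 + 1 = ",
    max_tokens=1,
    temperature=0,
    logprobs=1,
    extra_body={"guided_choice": ["2", "3", "4"]},  # server-only feature today
)
print(resp.choices[0].logprobs.token_logprobs)

Note that with guided decoding the returned logprobs are typically computed after the logits mask is applied, so they may be renormalized over the allowed tokens rather than the full vocabulary; the offline interface proposed here would return unconstrained probabilities.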

Alternatives

No response

Additional context

No response
