Added support for litellm as judge backend.
JoelNiklaus committed Dec 22, 2024
1 parent 221d5d5 commit 81f02ca
Showing 2 changed files with 40 additions and 4 deletions.
39 changes: 36 additions & 3 deletions src/lighteval/metrics/llm_as_judge.py
@@ -28,7 +28,7 @@

from tqdm import tqdm

from lighteval.utils.imports import is_openai_available, is_vllm_available
from lighteval.utils.imports import is_litellm_available, is_openai_available, is_vllm_available


logging.getLogger("openai").setLevel(logging.ERROR)
@@ -73,7 +73,7 @@ def __init__(
model: str,
templates: Callable,
process_judge_response: Callable,
judge_backend: Literal["openai", "transformers", "tgi", "vllm"],
judge_backend: Literal["litellm", "openai", "transformers", "tgi", "vllm"],
url: str | None = None,
api_key: str | None = None,
):
@@ -93,7 +93,7 @@ def __init__(

def __lazy_load_client(self):
match self.backend:
# Wether we use openai or TGI models, we go trhough the openai API
# Whether we use openai or TGI models, we go through the openai API
# to route to the endpoint
case "openai" | "tgi" if is_openai_available():
if self.client is None:
@@ -104,6 +104,8 @@ def __lazy_load_client(self):
else:
self.client = OpenAI(base_url=self.url, api_key=self.api_key)
return self.__call_api_parallel
case "litellm" if is_litellm_available():
return self.__call_litellm
case "vllm" if is_vllm_available():
if self.pipe is None:
from vllm import LLM, SamplingParams
@@ -187,6 +189,37 @@ def __call_vllm(self, prompt):
outputs = [output.outputs[0].text for output in output]
return outputs

def __call_litellm(self, prompts):
import litellm

def __call_api(prompt):
for _ in range(self.API_MAX_RETRY):
try:
response = litellm.completion(
model=self.model,
messages=prompt,
response_format={"type": "text"},
max_tokens=512,
n=1,
caching=True,
)
text = response.choices[0].message.content
return text
except Exception as e:
logger.warning(f"{type(e), e}")
time.sleep(self.API_RETRY_SLEEP)
raise Exception("Failed to get response from the API")

results = []
with ThreadPoolExecutor(100) as executor:
for entry in tqdm(executor.map(__call_api, prompts), total=len(prompts)):
results.append(entry)

if None in results:
raise ValueError("Some entries are not annotated due to errors in annotate_p, please inspect and retry.")

return results

def __call_api_parallel(self, prompts):
results = []
with ThreadPoolExecutor(100) as executor:
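The new __call_litellm path above expects each prompt to already be an OpenAI-style list of chat messages, fans the calls out over a thread pool, and retries transient failures before giving up. A minimal usage sketch of the new backend follows; the class name JudgeLM is assumed (it is not visible in this hunk), and the model, template, and score parser are illustrative only.

import re

from lighteval.metrics.llm_as_judge import JudgeLM  # class name assumed, not shown in the hunk


def judge_template(question, answer, **kwargs):
    # Illustrative template: build OpenAI-style chat messages for the judge.
    return [
        {"role": "system", "content": "You are an impartial judge. Reply with a score from 1 to 10."},
        {"role": "user", "content": f"Question: {question}\nAnswer: {answer}"},
    ]


def parse_score(judge_response: str) -> float:
    # Illustrative parser: take the first number in the judge's reply, default to 0.
    match = re.search(r"\d+(\.\d+)?", judge_response)
    return float(match.group()) if match else 0.0


judge = JudgeLM(
    model="gpt-4o-mini",               # any model string litellm can route
    templates=judge_template,
    process_judge_response=parse_score,
    judge_backend="litellm",           # the backend added by this commit
)

Prompts produced by the template are handed straight to litellm.completion as messages, which is why they must already be chat-formatted.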
5 changes: 4 additions & 1 deletion src/lighteval/metrics/metrics_sample.py
@@ -858,7 +858,7 @@ def __init__(
judge_model_name: str,
template: Callable,
process_judge_response: Callable,
judge_backend: Literal["openai", "transformers", "vllm", "tgi"],
judge_backend: Literal["litellm", "openai", "transformers", "vllm", "tgi"],
short_judge_name: str | None = None,
) -> None:
match judge_backend:
@@ -871,6 +871,9 @@ def __init__(
case "tgi":
api_key = os.getenv("HF_TOKEN")
url = "https://api-inference.huggingface.co/v1/"
case "litellm":
api_key = None
url = None
case "transformers" | "vllm":
api = HfApi()
models = api.list_models(model_name=judge_model_name)
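For the new litellm branch both api_key and url stay None: litellm infers the provider from the model string and reads the matching credentials from the environment on its own. A small sketch of that behaviour under those assumptions (environment variable names and model strings are illustrative; check the litellm provider docs for the exact ones):

import os

import litellm

# No explicit api_key or url is passed for the litellm backend; the provider
# is picked from the model name and the key is read from the environment.
os.environ["OPENAI_API_KEY"] = "sk-..."            # used for gpt-* models
# os.environ["ANTHROPIC_API_KEY"] = "sk-ant-..."   # would be used for claude-* models

response = litellm.completion(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "Reply with OK."}],
    max_tokens=8,
)
print(response.choices[0].message.content)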
