Bugfix/llm as judge #41

Merged 2 commits on Dec 16, 2024
84 changes: 71 additions & 13 deletions evalem/_base/metrics.py
@@ -3,19 +3,23 @@
from __future__ import annotations

from abc import abstractmethod
from typing import Iterable, List, Tuple

Check failure (GitHub Actions / Flake8) on line 6 in evalem/_base/metrics.py: 'typing.Iterable' imported but unused (F401)

from jury import Jury
from sklearn.metrics import confusion_matrix

from ..misc.utils import format_to_jury
from .abc import AbstractBase
from .structures import (

Check failure (GitHub Actions / Flake8) on line 13 in evalem/_base/metrics.py: '.structures.MultiplePredictionInstance' imported but unused (F401)
EvaluationPredictionInstance,
EvaluationReferenceInstance,
MetricResult,
MultiplePredictionInstance,
MultipleReferenceInstance,
PredictionInstance,
SequenceType,
SinglePredictionInstance,
SingleReferenceInstance,
)


@@ -105,7 +109,64 @@
)

@staticmethod
def _flatten_references(
def _is_single_prediction_multi_reference(predictions, references) -> bool:
return isinstance(predictions, PredictionInstance) and isinstance(
references,
SequenceType,
)

@staticmethod
def _is_multi_prediction_single_reference(predictions, references) -> bool:
return isinstance(predictions, SequenceType) and isinstance(
references,
PredictionInstance,
)

@staticmethod
def _is_multi_prediction_multi_reference(predictions, references) -> bool:
return isinstance(predictions, SequenceType) and isinstance(
references,
SequenceType,
)

def _flatten_single_prediction_multi_reference(
self,
predictions: SinglePredictionInstance,
references: MultipleReferenceInstance,
) -> Tuple[SinglePredictionInstance, SingleReferenceInstance]:
res = []
for preds, refs in zip(predictions, references):
if Metric._is_single_prediction_multi_reference(preds, refs):
res.extend(list(map(lambda r: (preds, r), refs)))
else:
res.append((preds, refs))
predictions, references = zip(*res)
return predictions, references

def _flatten_multi_prediction_single_reference(
self,
predictions: MultipleReferenceInstance,
references: SingleReferenceInstance,
) -> Tuple[SinglePredictionInstance, SingleReferenceInstance]:
res = []
for preds, refs in zip(predictions, references):
if Metric._is_multi_prediction_single_reference(preds, refs):
res.extend(list(map(lambda p: (p, refs), preds)))
else:
res.append((preds, refs))
predictions, references = zip(*res)
return predictions, references

def _flatten_multi_prediction_multi_reference(
self,
predictions: MultipleReferenceInstance,
references: SingleReferenceInstance,
) -> Tuple[SinglePredictionInstance, SingleReferenceInstance]:
# No-op
return predictions, references

def _flatten_instances(
self,
predictions: EvaluationPredictionInstance,
references: EvaluationReferenceInstance,
) -> Tuple[EvaluationPredictionInstance, EvaluationReferenceInstance]:
@@ -124,17 +185,14 @@
Returns:
Tuple of flattened lists (predictions, references)
"""
res = []
for pred, ref in zip(predictions, references):
# if multiple predictions, skip for now
if isinstance(pred, SequenceType) and not isinstance(pred, str):
raise TypeError("Cannot handle multiple prediction instance")
# if multiple references
elif isinstance(ref, SequenceType) and not isinstance(ref, str):
res.extend(list(map(lambda r: (pred, r), ref)))
else:
res.append((pred, ref))
predictions, references = zip(*res)
predictions, references = self._flatten_multi_prediction_single_reference(
predictions,
references,
)
predictions, references = self._flatten_single_prediction_multi_reference(
predictions,
references,
)
return predictions, references


@@ -266,7 +324,7 @@
references,
)

predictions, references = self._flatten_references(predictions, references)
predictions, references = self._flatten_instances(predictions, references)

labels = self.__get_labels(predictions, references)
return MetricResult.from_dict(
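
For context, here is a minimal standalone sketch (not part of the diff) of the pairing behaviour the new flatten helpers implement, using plain lists of strings in place of the prediction/reference instances; the function name is made up for illustration:

from typing import List, Tuple, Union

Item = Union[str, List[str]]

def flatten_pairs(predictions: List[Item], references: List[Item]) -> Tuple[list, list]:
    # Sketch of the logic behind _flatten_instances: pair each prediction
    # with each of its references (or vice versa) so both sides end up flat.
    pairs = []
    for pred, ref in zip(predictions, references):
        if isinstance(pred, list) and isinstance(ref, str):
            # multi-prediction, single-reference: repeat the reference
            pairs.extend((p, ref) for p in pred)
        elif isinstance(pred, str) and isinstance(ref, list):
            # single-prediction, multi-reference: repeat the prediction
            pairs.extend((pred, r) for r in ref)
        else:
            pairs.append((pred, ref))
    preds, refs = zip(*pairs)
    return list(preds), list(refs)

preds, refs = flatten_pairs(
    ["p1", ["p2a", "p2b"], "p3"],
    [["r1a", "r1b"], "r2", "r3"],
)
print(preds)  # ['p1', 'p1', 'p2a', 'p2b', 'p3']
print(refs)   # ['r1a', 'r1b', 'r2', 'r2', 'r3']
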
6 changes: 3 additions & 3 deletions evalem/_base/structures.py
@@ -5,7 +5,7 @@
from copy import deepcopy
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union

Check failure (GitHub Actions / Flake8) on line 8 in evalem/_base/structures.py: 'typing.Type' imported but unused (F401)

import numpy as np
import torch
@@ -84,14 +84,14 @@
# Represents type instance for any single downstream prediction
PredictionInstance = Union[
str,
Type[PredictionDTO],
PredictionDTO,
dict,
ImageTensor,
Type[ClassificationDTO],
ClassificationDTO,
]

# Represents type instance for any single downstream reference/ground-truth
ReferenceInstance = Union[str, Type[ReferenceDTO]]
ReferenceInstance = Union[str, ReferenceDTO]

SinglePredictionInstance = List[PredictionInstance]
MultiplePredictionInstance = List[List[PredictionInstance]]
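
The switch from `Type[PredictionDTO]` to `PredictionDTO` (and likewise for the other DTOs) matters because these unions annotate instances rather than class objects. A rough illustration with a hypothetical stand-in DTO, not the project's real class:

from dataclasses import dataclass
from typing import Union

@dataclass
class PredictionDTO:  # hypothetical stand-in for the project's DTO
    value: str

# Union members name the classes themselves when the annotated value is an instance.
PredictionInstance = Union[str, PredictionDTO, dict]

def first_text(p: PredictionInstance) -> str:
    # isinstance checks against the class, which is exactly what the union now lists
    return p.value if isinstance(p, PredictionDTO) else str(p)

print(first_text(PredictionDTO(value="cat")))  # cat
print(first_text("dog"))                       # dog
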
11 changes: 10 additions & 1 deletion evalem/nlp/metrics/basics.py
@@ -1,9 +1,11 @@
#!/usr/bin/env python3

import dataclasses
from typing import Tuple

from ..._base.metrics import JuryBasedMetric
from ..._base.structures import (
EvaluationPredictionInstance,
EvaluationReferenceInstance,
MetricResult,
SinglePredictionInstance,
@@ -15,6 +17,13 @@ class ExactMatchMetric(JuryBasedMetric, NLPMetric):
def __init__(self) -> None:
super().__init__(metrics="exact_match")

def _flatten_multi_prediction_multi_reference(
self,
predictions: EvaluationPredictionInstance,
references: EvaluationReferenceInstance,
) -> Tuple[EvaluationPredictionInstance, EvaluationReferenceInstance]:
raise NotImplementedError()

def compute(
self,
predictions: SinglePredictionInstance,
@@ -24,7 +33,7 @@ def compute(
# This metric doesn't support multi-reference format.
# So, we flatten everything:
# Single Prediction, Multi-Ref -> Single Prediction, Single-Ref
predictions, references = self._flatten_references(predictions, references)
predictions, references = self._flatten_instances(predictions, references)
result = super().compute(
predictions=predictions,
references=references,
48 changes: 35 additions & 13 deletions evalem/nlp/metrics/llm.py
@@ -9,7 +9,7 @@
from loguru import logger
from outlines.models.openai import OpenAIConfig

from ..._base.structures import (

Check failure (GitHub Actions / Flake8) on line 12 in evalem/nlp/metrics/llm.py: '..._base.structures.SequenceType' imported but unused (F401)
EvaluationPredictionInstance,
EvaluationReferenceInstance,
MetricResult,
@@ -54,6 +54,15 @@
```aggregation_type```: ```Optional[AggregationType]```
Decides how to aggregate scores from the multiple judgement tries.
Defaults to `AggregationType.MEAN` if not provided.
```max_n```: ```Optional[int]```
If set, the total number of references or predictions per item
will be truncated. This is to reduce LLM calls and thus minimize
scoring time. Default behaviour is no truncation when set to `None`
or less than 1.
- If single reference, multiple predictions, the total number of
predictions will be truncated
- If multiple references, single prediction, the total number of
references will be truncated
```debug```: ```bool```
Boolean flag for debug-mode outputs

@@ -103,21 +112,29 @@
temperature: float = 0.0,
prompt: Optional[str] = None,
aggregation_type: Optional[List[AggregationType]] = None,
max_n: Optional[int] = None,
debug: bool = False,
) -> None:
super().__init__(debug=debug)

model = self.__clean_model(model)
api_base = self.__clean_url(api_base)
self.model = outlines.models.openai(
self.__clean_model(model),
model,
base_url=api_base,
api_key=api_key,
config=OpenAIConfig(temperature=temperature),
)
self.api_base = self.__clean_url(api_base)
self.api_base = api_base
self.n_tries = n_tries or 1
self.prompt = prompt or LLMAsJudgeMetric._prompt
self.aggregation_type = aggregation_type or AggregationType.MEAN

self._sanity_check_prmopt(self.prompt)
self.max_n = max_n or None
if self.max_n:
logger.warning(
f"Total number of predictions/references per item will be truncated based on `max_n` value.",

Check failure (GitHub Actions / Flake8) on line 136 in evalem/nlp/metrics/llm.py: F-string is missing placeholders (F541)
)

def _sanity_check_prmopt(self, prompt: str) -> bool:
if "{prediction}" not in prompt or "{reference}" not in prompt:
@@ -133,24 +150,25 @@

def __clean_url(self, url: str) -> str:
if not url.endswith("/v1"):
url = urljoin(url, "/v1")
url = urljoin(url, "v1")
return url
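
For reference, the urljoin change matters when `api_base` carries a path prefix; this is standard-library behaviour, with made-up URLs:

from urllib.parse import urljoin

base = "http://localhost:8000/serve/"
print(urljoin(base, "/v1"))  # http://localhost:8000/v1       (old join: path prefix dropped)
print(urljoin(base, "v1"))   # http://localhost:8000/serve/v1 (new join: prefix preserved)
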

@staticmethod
def _flatten_references(
def _flatten_instances(
self,
predictions,
references,
max_n: Optional[int] = None,
) -> Tuple[EvaluationPredictionInstance, EvaluationReferenceInstance]:
if max_n is not None and max_n < 1:
max_n = None
res = []
for preds, refs in zip(predictions, references):
# multiple predictions, single reference
if isinstance(preds, SequenceType) and isinstance(refs, str):
res.extend(list(map(lambda p: (p, refs), preds)))

if self._is_multi_prediction_single_reference(preds, refs):
res.extend(list(map(lambda p: (p, refs), preds[slice(max_n)])))
# single prediction, multiple references
elif isinstance(preds, str) and isinstance(refs, SequenceType):
res.extend(list(map(lambda r: (preds, r), refs)))

elif self._is_single_prediction_multi_reference(preds, refs):
res.extend(list(map(lambda r: (preds, r), refs[slice(max_n)])))
# single prediction, single reference
else:
res.append((preds, refs))
@@ -165,7 +183,11 @@
**kwargs,
) -> MetricResult:
# make sure to flatten
predictions, references = self._flatten_references(predictions, references)
predictions, references = self._flatten_instances(
predictions,
references,
max_n=self.max_n,
)
if self.debug:
logger.debug(f"Evaluating for {len(predictions)} predictions.")
generator = outlines.generate.choice(self.model, ["0", "1"])
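
A rough usage sketch of the new `max_n` knob (keyword names are taken from the diff; the import path, model name, and endpoint are assumptions rather than anything the PR prescribes):

from evalem.nlp.metrics.llm import LLMAsJudgeMetric

metric = LLMAsJudgeMetric(
    model="gpt-4o-mini",                  # assumed model identifier
    api_base="http://localhost:8000/v1",  # assumed OpenAI-compatible endpoint
    api_key="EMPTY",
    n_tries=3,
    max_n=2,     # keep at most 2 references (or predictions) per item before judging
    debug=True,
)

result = metric.compute(
    predictions=["Paris is the capital of France."],
    references=[["Paris", "Paris, France", "The capital of France is Paris"]],
)
print(result)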