Commit 897f061

Improve instance flattening based on robust checks on predictions and references

NISH1001 committed Dec 14, 2024
1 parent 6498625 commit 897f061

Showing 4 changed files with 89 additions and 24 deletions.
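The gist of the change: the old `_flatten_references` could expand a (single prediction, multiple references) row into repeated one-to-one pairs but raised `TypeError` on multi-prediction rows; this commit splits the logic into explicit per-shape helpers so both directions flatten cleanly. A standalone sketch of the target behavior (the `flatten_pairs` helper and the plain `str`/`list` checks below are stand-ins for evalem's `PredictionInstance`/`SequenceType` aliases, not the library's code):

```python
from typing import Any, List, Tuple

def flatten_pairs(predictions: List[Any], references: List[Any]) -> Tuple[list, list]:
    """Expand one-to-many (prediction, reference) rows into one-to-one rows."""
    rows = []
    for pred, ref in zip(predictions, references):
        if isinstance(pred, list) and isinstance(ref, str):
            # multiple predictions, single reference
            rows.extend((p, ref) for p in pred)
        elif isinstance(pred, str) and isinstance(ref, list):
            # single prediction, multiple references
            rows.extend((pred, r) for r in ref)
        else:
            # already one-to-one (many-to-many passes through unchanged)
            rows.append((pred, ref))
    preds, refs = zip(*rows)
    return list(preds), list(refs)

print(flatten_pairs(["cat"], [["cat", "kitty"]]))
# (['cat', 'cat'], ['cat', 'kitty'])
```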
84 changes: 71 additions & 13 deletions evalem/_base/metrics.py
@@ -14,8 +14,12 @@
    EvaluationPredictionInstance,
    EvaluationReferenceInstance,
    MetricResult,
    MultiplePredictionInstance,
    MultipleReferenceInstance,
    PredictionInstance,
    SequenceType,
    SinglePredictionInstance,
    SingleReferenceInstance,
)


@@ -105,7 +109,64 @@ def __call__(
        )

    @staticmethod
-    def _flatten_references(
+    def _is_single_prediction_multi_reference(predictions, references) -> bool:
        return isinstance(predictions, PredictionInstance) and isinstance(
            references,
            SequenceType,
        )

    @staticmethod
    def _is_multi_prediction_single_reference(predictions, references) -> bool:
        return isinstance(predictions, SequenceType) and isinstance(
            references,
            PredictionInstance,
        )

    @staticmethod
    def _is_multi_prediction_multi_reference(predictions, references) -> bool:
        return isinstance(predictions, SequenceType) and isinstance(
            references,
            SequenceType,
        )
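One subtlety these checks inherit: a `str` is itself a sequence, which is why the old code needed explicit `not isinstance(..., str)` guards. Assuming `SequenceType` matches containers but not bare strings (an assumption here, not something this diff shows), the predicates stay unambiguous. A minimal reimplementation with stand-in aliases:

```python
# Assumed stand-ins for evalem's aliases (invented for this demo):
SequenceType = (list, tuple)        # containers holding several instances
PredictionInstance = (str, dict)    # a single prediction-like value

def is_single_pred_multi_ref(pred, refs) -> bool:
    # mirrors Metric._is_single_prediction_multi_reference
    return isinstance(pred, PredictionInstance) and isinstance(refs, SequenceType)

print(is_single_pred_multi_ref("cat", ["cat", "kitty"]))  # True
print(is_single_pred_multi_ref("cat", "cat"))             # False: single-single
print(is_single_pred_multi_ref(["a", "b"], "a"))          # False: multi-single
```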

    def _flatten_single_prediction_multi_reference(
        self,
        predictions: SinglePredictionInstance,
        references: MultipleReferenceInstance,
    ) -> Tuple[SinglePredictionInstance, SingleReferenceInstance]:
        res = []
        for preds, refs in zip(predictions, references):
            if Metric._is_single_prediction_multi_reference(preds, refs):
                res.extend(list(map(lambda r: (preds, r), refs)))
            else:
                res.append((preds, refs))
        predictions, references = zip(*res)
        return predictions, references
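A usage sketch (hypothetical metric instance `m`, sample strings invented; assumes `SequenceType` matches lists but not bare strings):

```python
preds = ["the cat sat", "hello"]
refs = [["the cat sat", "a cat sat"], "hello"]
preds, refs = m._flatten_single_prediction_multi_reference(preds, refs)
# preds -> ("the cat sat", "the cat sat", "hello")
# refs  -> ("the cat sat", "a cat sat", "hello")
```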

    def _flatten_multi_prediction_single_reference(
        self,
        predictions: MultiplePredictionInstance,
        references: SingleReferenceInstance,
    ) -> Tuple[SinglePredictionInstance, SingleReferenceInstance]:
        res = []
        for preds, refs in zip(predictions, references):
            if Metric._is_multi_prediction_single_reference(preds, refs):
                res.extend(list(map(lambda p: (p, refs), preds)))
            else:
                res.append((preds, refs))
        predictions, references = zip(*res)
        return predictions, references

    def _flatten_multi_prediction_multi_reference(
        self,
        predictions: MultiplePredictionInstance,
        references: MultipleReferenceInstance,
    ) -> Tuple[MultiplePredictionInstance, MultipleReferenceInstance]:
        # No-op: many-to-many pairs are passed through unchanged here;
        # subclasses can override (e.g. to raise) if they can't handle them.
        return predictions, references
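Note that, as shown in this diff, `_flatten_instances` below never calls this no-op variant directly: a many-to-many row fails both predicates and survives both passes untouched, which this method documents as the intended default. Tracing one such row (sample values invented):

```python
row_preds, row_refs = ["p1", "p2"], ["r1", "r2"]
# _is_multi_prediction_single_reference: refs is a list, not a single
#   PredictionInstance -> False, row kept as-is
# _is_single_prediction_multi_reference: preds is a list, not a single
#   PredictionInstance -> False, row kept as-is
# => the (multi, multi) row reaches the metric unflattened
```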

    def _flatten_instances(
        self,
        predictions: EvaluationPredictionInstance,
        references: EvaluationReferenceInstance,
    ) -> Tuple[EvaluationPredictionInstance, EvaluationReferenceInstance]:
@@ -124,17 +185,14 @@ def _flatten_references(
        Returns:
            Tuple of flattened lists (predictions, references)
        """
-        res = []
-        for pred, ref in zip(predictions, references):
-            # if multiple predictions, skip for now
-            if isinstance(pred, SequenceType) and not isinstance(pred, str):
-                raise TypeError("Cannot handle multiple prediction instance")
-            # if multiple references
-            elif isinstance(ref, SequenceType) and not isinstance(ref, str):
-                res.extend(list(map(lambda r: (pred, r), ref)))
-            else:
-                res.append((pred, ref))
-        predictions, references = zip(*res)
+        predictions, references = self._flatten_multi_prediction_single_reference(
+            predictions,
+            references,
+        )
+        predictions, references = self._flatten_single_prediction_multi_reference(
+            predictions,
+            references,
+        )
        return predictions, references
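The payoff of the refactor: a batch can now mix one-to-many rows in either direction, where the old `_flatten_references` raised `TypeError` on any multi-prediction row. A hypothetical walkthrough under the same assumptions as above (instance `m` and sample strings invented):

```python
preds = [["p1", "p2"], "q"]
refs = ["r", ["s1", "s2"]]
preds, refs = m._flatten_instances(preds, refs)
# pass 1 (multi-pred, single-ref): [("p1", "r"), ("p2", "r"), ("q", ["s1", "s2"])]
# pass 2 (single-pred, multi-ref): [("p1", "r"), ("p2", "r"), ("q", "s1"), ("q", "s2")]
# preds -> ("p1", "p2", "q", "q"); refs -> ("r", "r", "s1", "s2")
```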


@@ -266,7 +324,7 @@ def compute(
            references,
        )

-        predictions, references = self._flatten_references(predictions, references)
+        predictions, references = self._flatten_instances(predictions, references)

        labels = self.__get_labels(predictions, references)
        return MetricResult.from_dict(
6 changes: 3 additions & 3 deletions evalem/_base/structures.py
@@ -84,14 +84,14 @@ def __hash__(self) -> str:
# Represents type instance for any single downstream prediction
PredictionInstance = Union[
    str,
-    Type[PredictionDTO],
+    PredictionDTO,
    dict,
    ImageTensor,
-    Type[ClassificationDTO],
+    ClassificationDTO,
]

# Represents type instance for any single downstream reference/ground-truth
-ReferenceInstance = Union[str, Type[ReferenceDTO]]
+ReferenceInstance = Union[str, ReferenceDTO]

SinglePredictionInstance = List[PredictionInstance]
MultiplePredictionInstance = List[List[PredictionInstance]]
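Why dropping `Type[...]` matters: `Type[PredictionDTO]` describes the class object itself, while the values flowing through the metrics are DTO instances. A minimal illustration (the `PredictionDTO` stub is invented for this demo):

```python
from dataclasses import dataclass
from typing import Type, Union

@dataclass
class PredictionDTO:
    text: str

PredictionInstance = Union[str, PredictionDTO]  # new style: admits instances

pred: PredictionInstance = PredictionDTO(text="cat")  # OK: an instance
klass: Type[PredictionDTO] = PredictionDTO            # the class object itself
# The old Union[str, Type[PredictionDTO]] admitted `klass`, not `pred`.
```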
11 changes: 10 additions & 1 deletion evalem/nlp/metrics/basics.py
@@ -1,9 +1,11 @@
#!/usr/bin/env python3

import dataclasses
from typing import Tuple

from ..._base.metrics import JuryBasedMetric
from ..._base.structures import (
    EvaluationPredictionInstance,
    EvaluationReferenceInstance,
    MetricResult,
    SinglePredictionInstance,
@@ -15,6 +17,13 @@ class ExactMatchMetric(JuryBasedMetric, NLPMetric):
    def __init__(self) -> None:
        super().__init__(metrics="exact_match")

    def _flatten_multi_prediction_multi_reference(
        self,
        predictions: EvaluationPredictionInstance,
        references: EvaluationReferenceInstance,
    ) -> Tuple[EvaluationPredictionInstance, EvaluationReferenceInstance]:
        raise NotImplementedError()

    def compute(
        self,
        predictions: SinglePredictionInstance,
@@ -24,7 +33,7 @@ def compute(
        # This metric doesn't support multi-reference format.
        # So, we flatten everything:
        # Single Prediction, Multi-Ref -> Single Prediction, Single-Ref
-        predictions, references = self._flatten_references(predictions, references)
+        predictions, references = self._flatten_instances(predictions, references)
        result = super().compute(
            predictions=predictions,
            references=references,
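Taken together, the two changes make ExactMatchMetric's contract explicit: one-to-many rows are flattened before Jury's exact_match scores them, and many-to-many flattening is declared unsupported rather than silently passed through. A hypothetical sketch of the observable behavior (sample data invented):

```python
em = ExactMatchMetric()

# single prediction vs. multiple references: flattened, then scored
em.compute(predictions=["cat"], references=[["cat", "kitty"]])
# scored as the pairs ("cat", "cat") and ("cat", "kitty")

# many-to-many input is explicitly unsupported at the flattening layer:
em._flatten_multi_prediction_multi_reference([["a", "b"]], [["a", "c"]])
# -> NotImplementedError
```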
12 changes: 5 additions & 7 deletions evalem/nlp/metrics/llm.py
Expand Up @@ -153,8 +153,8 @@ def __clean_url(self, url: str) -> str:
        url = urljoin(url, "v1")
        return url

-    @staticmethod
-    def _flatten_references(
+    def _flatten_instances(
+        self,
        predictions,
        references,
        max_n: Optional[int] = None,
@@ -164,13 +164,11 @@ def _flatten_references(
        res = []
        for preds, refs in zip(predictions, references):
            # multiple predictions, single reference
-            if isinstance(preds, SequenceType) and isinstance(refs, str):
+            if self._is_multi_prediction_single_reference(preds, refs):
                res.extend(list(map(lambda p: (p, refs), preds[slice(max_n)])))
-
            # single prediction, multiple references
-            elif isinstance(preds, str) and isinstance(refs, SequenceType):
+            elif self._is_single_prediction_multi_reference(preds, refs):
                res.extend(list(map(lambda r: (preds, r), refs[slice(max_n)])))
-
            # single prediction, single reference
            else:
                res.append((preds, refs))
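The `max_n` cap is what distinguishes this override from the base flatteners: `preds[slice(max_n)]` keeps only the first `max_n` candidates, and `slice(None)` degrades gracefully to keeping everything. A quick check:

```python
preds = ["p1", "p2", "p3"]
print(preds[slice(2)])     # ['p1', 'p2']        -> capped at max_n=2
print(preds[slice(None)])  # ['p1', 'p2', 'p3']  -> max_n=None keeps all
```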
@@ -185,7 +183,7 @@ def compute(
        **kwargs,
    ) -> MetricResult:
        # make sure to flatten
-        predictions, references = self._flatten_references(
+        predictions, references = self._flatten_instances(
            predictions,
            references,
            max_n=self.max_n,
