chores: cleanup metrics #1348

Merged 9 commits on Sep 24, 2024
69 changes: 54 additions & 15 deletions src/ragas/metrics/__init__.py
@@ -2,33 +2,56 @@
import sys

from ragas.metrics._answer_correctness import AnswerCorrectness, answer_correctness
from ragas.metrics._answer_relevance import AnswerRelevancy, answer_relevancy
from ragas.metrics._answer_similarity import AnswerSimilarity, answer_similarity
from ragas.metrics._answer_relevance import (
AnswerRelevancy,
ResponseRelevancy,
answer_relevancy,
)
from ragas.metrics._answer_similarity import (
AnswerSimilarity,
SemanticSimilarity,
answer_similarity,
)
from ragas.metrics._aspect_critic import AspectCritic
from ragas.metrics._bleu_score import BleuScore
from ragas.metrics._context_entities_recall import (
ContextEntityRecall,
context_entity_recall,
)
from ragas.metrics._context_precision import (
ContextPrecision,
ContextUtilization,
LLMContextPrecisionWithoutReference,
NonLLMContextPrecisionWithReference,
context_precision,
context_utilization,
)
from ragas.metrics._context_recall import ContextRecall, context_recall
from ragas.metrics._context_recall import (
ContextRecall,
LLMContextRecall,
NonLLMContextRecall,
context_recall,
)
from ragas.metrics._datacompy_score import DataCompyScore
from ragas.metrics._domain_specific_rubrics import (
RubricsScoreWithoutReference,
RubricsScoreWithReference,
rubrics_score_with_reference,
rubrics_score_without_reference,
)
from ragas.metrics._factual_correctness import FactualCorrectness
from ragas.metrics._faithfulness import Faithfulness, FaithulnesswithHHEM, faithfulness
from ragas.metrics._noise_sensitivity import (
NoiseSensitivity,
noise_sensitivity_irrelevant,
noise_sensitivity_relevant,
from ragas.metrics._goal_accuracy import (
AgentGoalAccuracyWithoutReference,
AgentGoalAccuracyWithReference,
)
from ragas.metrics._instance_specific_rubrics import (
InstanceRubricsScoreWithoutReference,
InstanceRubricsWithReference,
)
from ragas.metrics._noise_sensitivity import NoiseSensitivity
from ragas.metrics._rogue_score import RougeScore
from ragas.metrics._sql_semantic_equivalence import LLMSQLEquivalence
from ragas.metrics._string import ExactMatch, NonLLMStringSimilarity, StringPresence, DistanceMeasure
from ragas.metrics._summarization import SummarizationScore, summarization_score
from ragas.metrics._tool_call_accuracy import ToolCallAccuracy

__all__ = [
"AnswerCorrectness",
@@ -41,7 +64,6 @@
"ContextPrecision",
"context_precision",
"ContextUtilization",
"context_utilization",
"ContextRecall",
"context_recall",
"AspectCritic",
@@ -52,12 +74,29 @@
"SummarizationScore",
"summarization_score",
"NoiseSensitivity",
"noise_sensitivity_irrelevant",
"noise_sensitivity_relevant",
"rubrics_score_with_reference",
"rubrics_score_without_reference",
"RubricsScoreWithoutReference",
"RubricsScoreWithReference",
"LLMContextPrecisionWithoutReference",
"NonLLMContextPrecisionWithReference",
"LLMContextPrecisionWithoutReference",
"LLMContextRecall",
"NonLLMContextRecall",
"FactualCorrectness",
"InstanceRubricsScoreWithoutReference",
"InstanceRubricsWithReference",
"NonLLMStringSimilarity",
"ExactMatch",
"StringPresence",
"BleuScore",
"RougeScore",
"DataCompyScore",
"LLMSQLEquivalence",
"AgentGoalAccuracyWithoutReference",
"AgentGoalAccuracyWithReference",
"ToolCallAccuracy",
"ResponseRelevancy",
"SemanticSimilarity",
"DistanceMeasure",
]

current_module = sys.modules[__name__]
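For reference, a quick sanity check of the re-exported surface after this change; a sketch that assumes a ragas build containing these commits and the post-PR package layout.

```python
# Sketch: every name below is re-exported from ragas.metrics by the updated
# __init__.py above (assumes a ragas build that includes this PR).
from ragas.metrics import (
    BleuScore,
    RougeScore,
    DataCompyScore,
    LLMSQLEquivalence,
    AgentGoalAccuracyWithReference,
    ToolCallAccuracy,
    ResponseRelevancy,
    SemanticSimilarity,
)
```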
23 changes: 14 additions & 9 deletions src/ragas/metrics/_answer_relevance.py
@@ -25,15 +25,15 @@
from ragas.llms.prompt import PromptValue


class AnswerRelevanceClassification(BaseModel):
class ResponseRelevanceClassification(BaseModel):
question: str
noncommittal: int


_output_instructions = get_json_format_instructions(
pydantic_object=AnswerRelevanceClassification
pydantic_object=ResponseRelevanceClassification
)
_output_parser = RagasoutputParser(pydantic_object=AnswerRelevanceClassification)
_output_parser = RagasoutputParser(pydantic_object=ResponseRelevanceClassification)


QUESTION_GEN = Prompt(
@@ -44,7 +44,7 @@ class AnswerRelevanceClassification(BaseModel):
{
"answer": """Albert Einstein was born in Germany.""",
"context": """Albert Einstein was a German-born theoretical physicist who is widely held to be one of the greatest and most influential scientists of all time""",
"output": AnswerRelevanceClassification.parse_obj(
"output": ResponseRelevanceClassification.parse_obj(
{
"question": "Where was Albert Einstein born?",
"noncommittal": 0,
@@ -54,7 +54,7 @@
{
"answer": """It can change its skin color based on the temperature of its environment.""",
"context": """A recent scientific study has discovered a new species of frog in the Amazon rainforest that has the unique ability to change its skin color based on the temperature of its environment.""",
"output": AnswerRelevanceClassification.parse_obj(
"output": ResponseRelevanceClassification.parse_obj(
{
"question": "What unique ability does the newly discovered species of frog have?",
"noncommittal": 0,
@@ -64,7 +64,7 @@
{
"answer": """Everest""",
"context": """The tallest mountain on Earth, measured from sea level, is a renowned peak located in the Himalayas.""",
"output": AnswerRelevanceClassification.parse_obj(
"output": ResponseRelevanceClassification.parse_obj(
{
"question": "What is the tallest mountain on Earth?",
"noncommittal": 0,
@@ -74,7 +74,7 @@
{
"answer": """I don't know about the groundbreaking feature of the smartphone invented in 2023 as am unaware of information beyond 2022. """,
"context": """In 2023, a groundbreaking invention was announced: a smartphone with a battery life of one month, revolutionizing the way people use mobile technology.""",
"output": AnswerRelevanceClassification.parse_obj(
"output": ResponseRelevanceClassification.parse_obj(
{
"question": "What was the groundbreaking feature of the smartphone invented in 2023?",
"noncommittal": 1,
@@ -89,7 +89,7 @@


@dataclass
class AnswerRelevancy(MetricWithLLM, MetricWithEmbeddings, SingleTurnMetric):
class ResponseRelevancy(MetricWithLLM, MetricWithEmbeddings, SingleTurnMetric):
"""
Scores the relevancy of the answer according to the given question.
Answers with incomplete, redundant or unnecessary information are penalized.
@@ -139,7 +139,7 @@ def calculate_similarity(
)

def _calculate_score(
self, answers: t.Sequence[AnswerRelevanceClassification], row: t.Dict
self, answers: t.Sequence[ResponseRelevanceClassification], row: t.Dict
) -> float:
question = row["user_input"]
gen_questions = [answer.question for answer in answers]
@@ -197,4 +197,9 @@ def save(self, cache_dir: str | None = None) -> None:
self.question_generation.save(cache_dir)


class AnswerRelevancy(ResponseRelevancy):
async def _ascore(self, row: t.Dict, callbacks: Callbacks) -> float:
return await super()._ascore(row, callbacks)


answer_relevancy = AnswerRelevancy()
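The rename keeps AnswerRelevancy as a thin subclass of the new ResponseRelevancy, so existing code that imports the old name keeps working. A minimal sketch, assuming a ragas build with this PR:

```python
# Both names resolve to the same metric; AnswerRelevancy is only a
# backward-compatible alias (subclass) of ResponseRelevancy.
from ragas.metrics import AnswerRelevancy, ResponseRelevancy

metric = ResponseRelevancy()   # preferred name going forward
legacy = AnswerRelevancy()     # old name, still constructible
assert isinstance(legacy, ResponseRelevancy)
```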
7 changes: 6 additions & 1 deletion src/ragas/metrics/_answer_similarity.py
@@ -23,7 +23,7 @@


@dataclass
class AnswerSimilarity(MetricWithLLM, MetricWithEmbeddings, SingleTurnMetric):
class SemanticSimilarity(MetricWithLLM, MetricWithEmbeddings, SingleTurnMetric):
"""
Scores the semantic similarity of ground truth with generated answer.
Cross-encoder score is used to quantify semantic similarity.
@@ -91,4 +91,9 @@ async def _ascore(self: t.Self, row: t.Dict, callbacks: Callbacks) -> float:
return score.tolist()[0]


class AnswerSimilarity(SemanticSimilarity):
async def _ascore(self: t.Self, row: t.Dict, callbacks: Callbacks) -> float:
return await super()._ascore(row, callbacks)


answer_similarity = AnswerSimilarity()
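The same aliasing pattern applies here: AnswerSimilarity now subclasses the renamed SemanticSimilarity. A one-line sketch of the relationship, assuming the post-PR API:

```python
from ragas.metrics import AnswerSimilarity, SemanticSimilarity

# The old class is kept as a subclass of the new one for backward compatibility.
assert issubclass(AnswerSimilarity, SemanticSimilarity)
```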
3 changes: 1 addition & 2 deletions src/ragas/metrics/_context_entities_recall.py
@@ -139,7 +139,6 @@ class ContextEntityRecall(MetricWithLLM, SingleTurnMetric):
context_entity_recall_prompt: Prompt = field(
default_factory=lambda: TEXT_ENTITY_EXTRACTION
)
batch_size: int = 15
max_retries: int = 1

def _compute_score(
@@ -195,4 +194,4 @@ def save(self, cache_dir: str | None = None) -> None:
return self.context_entity_recall_prompt.save(cache_dir)


context_entity_recall = ContextEntityRecall(batch_size=15)
context_entity_recall = ContextEntityRecall()
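Since the unused batch_size field is dropped, constructing the metric with that argument no longer works. A small sketch of the difference, assuming the post-PR dataclass definition:

```python
from ragas.metrics import ContextEntityRecall

metric = ContextEntityRecall()        # fine: batch_size is no longer a field
# ContextEntityRecall(batch_size=15)  # would now raise TypeError: unexpected keyword argument
```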
4 changes: 0 additions & 4 deletions src/ragas/metrics/_domain_specific_rubrics.py
@@ -307,7 +307,3 @@ def _create_single_turn_prompt(self, row: t.Dict) -> SingleTurnWithReferenceInput
reference=ground_truth,
rubrics=self.rubrics,
)


rubrics_score_with_reference = RubricsScoreWithReference()
rubrics_score_without_reference = RubricsScoreWithoutReference()
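With the module-level rubrics_score_with_reference / rubrics_score_without_reference instances removed, callers construct the classes directly. A sketch mirroring how those instances were previously created (no arguments), assuming the defaults are unchanged:

```python
from ragas.metrics import RubricsScoreWithReference, RubricsScoreWithoutReference

# Equivalent to the deleted module-level instances.
with_reference = RubricsScoreWithReference()
without_reference = RubricsScoreWithoutReference()
```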
2 changes: 1 addition & 1 deletion src/ragas/metrics/_factual_correctness.py
@@ -9,11 +9,11 @@
from numpy.typing import NDArray
from pydantic import BaseModel, Field

from ragas.experimental.prompt import PydanticPrompt
from ragas.experimental.metrics._faithfulness import (
NLIStatementInput,
NLIStatementPrompt,
)
from ragas.experimental.prompt import PydanticPrompt
from ragas.metrics.base import (
MetricType,
MetricWithLLM,
8 changes: 1 addition & 7 deletions src/ragas/metrics/_noise_sensitivity.py
@@ -38,7 +38,7 @@
@dataclass
class NoiseSensitivity(MetricWithLLM, SingleTurnMetric):
name: str = "noise_sensitivity" # type: ignore
focus: str = "relevant"
focus: t.Literal["relevant", "irrelevant"] = "relevant"
_required_columns: t.Dict[MetricType, t.Set[str]] = field(
default_factory=lambda: {
MetricType.SINGLE_TURN: {
@@ -266,8 +266,6 @@ async def _ascore(self: t.Self, row: t.Dict, callbacks: Callbacks) -> float:
def adapt(self, language: str, cache_dir: t.Optional[str] = None) -> None:
assert self.llm is not None, "LLM is not set"

logger.info(f"Adapting Faithfulness metric to {language}")

self.nli_statements_message = self.nli_statements_message.adapt(
language, self.llm, cache_dir
)
@@ -280,7 +278,3 @@ def adapt(self, language: str, cache_dir: t.Optional[str] = None) -> None:
def save(self, cache_dir: t.Optional[str] = None) -> None:
self.nli_statements_message.save(cache_dir)
self.statement_prompt.save(cache_dir)


noise_sensitivity_relevant = NoiseSensitivity()
noise_sensitivity_irrelevant = NoiseSensitivity(focus="irrelevant")
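The pre-built noise_sensitivity_relevant / noise_sensitivity_irrelevant instances are gone, and focus is now typed as a Literal. A sketch of constructing the metric explicitly instead, assuming the post-PR API:

```python
from ragas.metrics import NoiseSensitivity

relevant = NoiseSensitivity()                      # focus defaults to "relevant"
irrelevant = NoiseSensitivity(focus="irrelevant")  # replaces noise_sensitivity_irrelevant
# NoiseSensitivity(focus="other") would be flagged by type checkers via t.Literal
```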
15 changes: 10 additions & 5 deletions src/ragas/metrics/_rogue_score.py
@@ -2,7 +2,6 @@
from dataclasses import dataclass, field

from langchain_core.callbacks import Callbacks
from rouge_score import rouge_scorer

from ragas.dataset_schema import SingleTurnSample
from ragas.metrics.base import MetricType, SingleTurnMetric
@@ -18,6 +17,15 @@ class RougeScore(SingleTurnMetric):
rogue_type: t.Literal["rouge1", "rougeL"] = "rougeL"
measure_type: t.Literal["fmeasure", "precision", "recall"] = "fmeasure"

def __post_init__(self):
try:
from rouge_score import rouge_scorer
except ImportError as e:
raise ImportError(
f"{e.name} is required for rouge score. Please install it using `pip install {e.name}"
)
self.rouge_scorer = rouge_scorer

def init(self, run_config: RunConfig):
pass

@@ -26,12 +34,9 @@ async def _single_turn_ascore(
) -> float:
assert isinstance(sample.reference, str), "Sample reference must be a string"
assert isinstance(sample.response, str), "Sample response must be a string"
scorer = rouge_scorer.RougeScorer([self.rogue_type], use_stemmer=True)
scorer = self.rouge_scorer.RougeScorer([self.rogue_type], use_stemmer=True)
scores = scorer.score(sample.reference, sample.response)
return getattr(scores[self.rogue_type], self.measure_type)

async def _ascore(self, row: t.Dict, callbacks: Callbacks) -> float:
return await self._single_turn_ascore(SingleTurnSample(**row), callbacks)


rouge_score = RougeScore()
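Moving the rouge_score import into __post_init__ makes it a lazy, optional dependency: importing ragas.metrics no longer requires the package, but constructing RougeScore does. A usage sketch, assuming `pip install rouge_score` has been run:

```python
from ragas.metrics import RougeScore

# Any missing-dependency error now surfaces here, at construction time,
# rather than when ragas.metrics is imported.
metric = RougeScore(rogue_type="rougeL", measure_type="fmeasure")
```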
2 changes: 1 addition & 1 deletion src/ragas/metrics/_sql_semantic_equivalence.py
@@ -63,7 +63,7 @@ class EquivalencePrompt(PydanticPrompt[EquivalenceInput, EquivalenceOutput]):


@dataclass
class LLMSqlEquivalenceWithReference(MetricWithLLM, SingleTurnMetric):
class LLMSQLEquivalence(MetricWithLLM, SingleTurnMetric):
name: str = "llm_sql_equivalence_with_reference" # type: ignore
_required_columns: t.Dict[MetricType, t.Set[str]] = field(
default_factory=lambda: {
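The rename to LLMSQLEquivalence only appears to change the public class name; the internal metric name string stays llm_sql_equivalence_with_reference. A minimal import sketch, assuming the post-PR exports:

```python
# Formerly LLMSqlEquivalenceWithReference; only the class name changes here.
from ragas.metrics import LLMSQLEquivalence
```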
4 changes: 3 additions & 1 deletion src/ragas/metrics/_tool_call_accuracy.py
@@ -27,7 +27,9 @@ class ToolCallAccuracy(MultiTurnMetric):
}
)

arg_comparison_metric: SingleTurnMetric = ExactMatch()
arg_comparison_metric: SingleTurnMetric = field(
default_factory=lambda: ExactMatch()
)

def init(self, run_config):
pass
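Wrapping the ExactMatch default in a default_factory is the standard dataclass fix: a plain `= ExactMatch()` default is evaluated once and shared by every ToolCallAccuracy instance, and newer Python versions reject unhashable dataclass defaults outright. A self-contained sketch of the pattern; Comparator and DemoMetric below are hypothetical stand-ins, not ragas APIs:

```python
from dataclasses import dataclass, field

@dataclass
class Comparator:          # hypothetical stand-in for ExactMatch
    hits: int = 0

@dataclass
class DemoMetric:          # hypothetical stand-in for ToolCallAccuracy
    # default_factory builds a fresh Comparator per instance instead of
    # sharing one class-level default object.
    comparator: Comparator = field(default_factory=Comparator)

a, b = DemoMetric(), DemoMetric()
assert a.comparator is not b.comparator  # each instance gets its own comparator
```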