diff --git a/evals/metrics/ragas/ragas.py b/evals/metrics/ragas/ragas.py index 9b0a1d3e..c80ff94e 100644 --- a/evals/metrics/ragas/ragas.py +++ b/evals/metrics/ragas/ragas.py @@ -4,12 +4,16 @@ # SPDX-License-Identifier: Apache-2.0 # import os +import re from typing import Dict, Optional, Union from langchain_core.embeddings import Embeddings from langchain_core.language_models import BaseLanguageModel from langchain_huggingface import HuggingFaceEndpoint +# import * is only allowed at module level according to python syntax +from ragas.metrics import * + def format_ragas_metric_name(name: str): return f"{name} (ragas)" @@ -29,16 +33,17 @@ def __init__( self.model = model self.embeddings = embeddings self.metrics = metrics - self.validated_list = [ - "answer_correctness", - "answer_relevancy", - "answer_similarity", - "context_precision", - "context_recall", - "faithfulness", - "context_utilization", - # "reference_free_rubrics_score", - ] + + # self.validated_list = [ + # "answer_correctness", + # "answer_relevancy", + # "answer_similarity", + # "context_precision", + # "context_recall", + # "faithfulness", + # "context_utilization", + # # "reference_free_rubrics_score", + # ] async def a_measure(self, test_case: Dict): return self.measure(test_case) @@ -47,37 +52,51 @@ def measure(self, test_case: Dict): # sends to server try: from ragas import evaluate - from ragas.metrics import ( - answer_correctness, - answer_relevancy, - answer_similarity, - context_precision, - context_recall, - context_utilization, - faithfulness, - ) + from ragas.metrics import ALL_METRICS + + self.metric_names = [metric.__class__.__name__ for metric in ALL_METRICS] + self.metric_names = [re.sub(r"(?