chore!: moving to pydantic2 (explodinggradients#1394)
jjmachan authored and shahules786 committed Oct 2, 2024
1 parent deafb4a commit ac572a2
Showing 16 changed files with 45 additions and 249 deletions.
7 changes: 4 additions & 3 deletions pyproject.toml
@@ -5,13 +5,14 @@ dependencies = [
"datasets",
"tiktoken",
"langchain",
-"langchain-core<0.3",
+"langchain-core",
"langchain-community",
"langchain_openai",
-"openai>1",
-"pysbd>=0.3.4",
"nest-asyncio",
"appdirs",
+"pydantic>=2",
+"openai>1",
+"pysbd>=0.3.4",
]
dynamic = ["version", "readme"]

2 changes: 1 addition & 1 deletion src/ragas/_analytics.py
@@ -9,7 +9,7 @@

import requests
from appdirs import user_data_dir
-from langchain_core.pydantic_v1 import BaseModel, Field
+from pydantic import BaseModel, Field

from ragas.utils import get_debug_mode

6 changes: 4 additions & 2 deletions src/ragas/cost.py
@@ -3,7 +3,7 @@

from langchain_core.callbacks.base import BaseCallbackHandler
from langchain_core.outputs import ChatGeneration, ChatResult, LLMResult
-from langchain_core.pydantic_v1 import BaseModel
+from pydantic import BaseModel

from ragas.utils import get_from_dict

@@ -39,7 +39,9 @@ def cost(
+ self.output_tokens * cost_per_output_token
)

-def __eq__(self, other: "TokenUsage") -> bool:
+def __eq__(self, other: object) -> bool:
+if not isinstance(other, TokenUsage):
+return False
return (
self.input_tokens == other.input_tokens
and self.output_tokens == other.output_tokens
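The `__eq__` hunk above follows the conventional Python equality pattern that also keeps type checkers happy once pydantic v2's stricter typing is in play: accept any `object`, narrow with `isinstance`, and only then compare fields. A minimal self-contained sketch of that pattern (a hypothetical `Usage` class, not the ragas `TokenUsage` model):

from dataclasses import dataclass

@dataclass
class Usage:  # hypothetical stand-in for a token-usage record
    input_tokens: int
    output_tokens: int

    def __eq__(self, other: object) -> bool:
        # Narrow the type first; comparing against anything else is simply False.
        if not isinstance(other, Usage):
            return False
        return (
            self.input_tokens == other.input_tokens
            and self.output_tokens == other.output_tokens
        )

assert Usage(1, 2) == Usage(1, 2)
assert Usage(1, 2) != "not a usage record"  # __eq__ returns False here rather than assuming the attributes exist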
2 changes: 1 addition & 1 deletion src/ragas/llms/output_parser.py
@@ -4,7 +4,7 @@

from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers import PydanticOutputParser
-from langchain_core.pydantic_v1 import BaseModel
+from pydantic import BaseModel

from ragas.llms import BaseRagasLLM
from ragas.llms.prompt import Prompt, PromptValue
2 changes: 1 addition & 1 deletion src/ragas/metrics/_answer_correctness.py
@@ -5,7 +5,7 @@
from dataclasses import dataclass, field

import numpy as np
-from langchain_core.pydantic_v1 import BaseModel
+from pydantic import BaseModel

from ragas.dataset_schema import SingleTurnSample
from ragas.llms.output_parser import RagasoutputParser, get_json_format_instructions
2 changes: 1 addition & 1 deletion src/ragas/metrics/_answer_relevance.py
@@ -5,7 +5,7 @@
from dataclasses import dataclass, field

import numpy as np
-from langchain_core.pydantic_v1 import BaseModel
+from pydantic import BaseModel

from ragas.dataset_schema import SingleTurnSample
from ragas.llms.output_parser import RagasoutputParser, get_json_format_instructions
2 changes: 1 addition & 1 deletion src/ragas/metrics/_context_entities_recall.py
@@ -24,7 +24,7 @@ class ContextEntitiesResponse(BaseModel):


_output_instructions = get_json_format_instructions(
-pydantic_object=ContextEntitiesResponse
+pydantic_object=ContextEntitiesResponse # type: ignore
)
_output_parser = RagasoutputParser(pydantic_object=ContextEntitiesResponse)

2 changes: 1 addition & 1 deletion src/ragas/metrics/_context_precision.py
@@ -33,7 +33,7 @@ class ContextPrecisionVerifications(BaseModel):


_verification_output_instructions = get_json_format_instructions(
-ContextPrecisionVerification
+ContextPrecisionVerification # type: ignore
)
_output_parser = RagasoutputParser(pydantic_object=ContextPrecisionVerification)

8 changes: 4 additions & 4 deletions src/ragas/metrics/_context_recall.py
@@ -5,7 +5,7 @@
from dataclasses import dataclass, field

import numpy as np
-from langchain_core.pydantic_v1 import BaseModel
+from pydantic import BaseModel, RootModel

from ragas.dataset_schema import SingleTurnSample
from ragas.llms.output_parser import RagasoutputParser, get_json_format_instructions
@@ -29,11 +29,11 @@ class ContextRecallClassificationAnswer(BaseModel):
reason: str


-class ContextRecallClassificationAnswers(BaseModel):
-__root__: t.List[ContextRecallClassificationAnswer]
+class ContextRecallClassificationAnswers(RootModel):
+root: t.List[ContextRecallClassificationAnswer]

def dicts(self) -> t.List[t.Dict]:
-return self.dict()["__root__"]
+return self.model_dump()


_classification_output_instructions = get_json_format_instructions(
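The `ContextRecallClassificationAnswers` change is the standard pydantic v1 to v2 migration for custom root types: the magic `__root__` field becomes a `RootModel` subclass, the wrapped value is reached through `.root`, and `model_dump()` serializes straight to the underlying list. A short sketch of the pattern with a hypothetical `Answer` item model (not the ragas classes):

import typing as t

from pydantic import BaseModel, RootModel

class Answer(BaseModel):  # hypothetical item model
    statement: str
    attributed: int

class Answers(RootModel):
    root: t.List[Answer]

answers = Answers.model_validate([{"statement": "grass is green", "attributed": 1}])
print(answers.root[0].attributed)  # 1, items in .root are Answer instances
print(answers.model_dump())        # [{'statement': 'grass is green', 'attributed': 1}]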
35 changes: 16 additions & 19 deletions src/ragas/metrics/_faithfulness.py
@@ -6,7 +6,7 @@
from dataclasses import dataclass, field

import numpy as np
-from langchain_core.pydantic_v1 import BaseModel, Field
+from pydantic import BaseModel, Field, RootModel

from ragas.dataset_schema import SingleTurnSample
from ragas.llms.output_parser import RagasoutputParser, get_json_format_instructions
@@ -39,11 +39,8 @@ class Statements(BaseModel):
simpler_statements: t.List[str] = Field(..., description="the simpler statements")


-class StatementsAnswers(BaseModel):
-__root__: t.List[Statements]
-
-def dicts(self) -> t.List[t.Dict]:
-return self.dict()["__root__"]
+class StatementsAnswers(RootModel):
+root: t.List[Statements]


_statements_output_instructions = get_json_format_instructions(StatementsAnswers)
@@ -79,7 +76,7 @@ def dicts(self) -> t.List[t.Dict]:
],
},
]
-).dicts(),
+).model_dump(),
}
],
input_keys=["question", "answer", "sentences"],
@@ -94,11 +91,11 @@ class StatementFaithfulnessAnswer(BaseModel):
verdict: int = Field(..., description="the verdict(0/1) of the faithfulness.")


-class StatementFaithfulnessAnswers(BaseModel):
-__root__: t.List[StatementFaithfulnessAnswer]
+class StatementFaithfulnessAnswers(RootModel):
+root: t.List[StatementFaithfulnessAnswer]

-def dicts(self) -> t.List[t.Dict]:
-return self.dict()["__root__"]
+def dicts(self):
+return self.model_dump()


_faithfulness_output_instructions = get_json_format_instructions(
@@ -144,20 +141,20 @@ def dicts(self) -> t.List[t.Dict]:
"verdict": 0,
},
]
-).dicts(),
+).model_dump(),
},
{
"context": """Photosynthesis is a process used by plants, algae, and certain bacteria to convert light energy into chemical energy.""",
"statements": ["Albert Einstein was a genius."],
-"answer": StatementFaithfulnessAnswers.parse_obj(
+"answer": StatementFaithfulnessAnswers.model_validate(
[
{
"statement": "Albert Einstein was a genius.",
"reason": "The context and statement are unrelated",
"verdict": 0,
}
]
-).dicts(),
+).model_dump(),
},
],
input_keys=["context", "statements"],
@@ -237,9 +234,9 @@ def _create_statements_prompt(self, row: t.Dict) -> PromptValue:
def _compute_score(self, answers: StatementFaithfulnessAnswers):
# check the verdicts and compute the score
faithful_statements = sum(
-1 if answer.verdict else 0 for answer in answers.__root__
+1 if answer.verdict else 0 for answer in answers.model_dump()
)
-num_statements = len(answers.__root__)
+num_statements = len(answers.model_dump())
if num_statements:
score = faithful_statements / num_statements
else:
@@ -272,7 +269,7 @@ async def _ascore(self: t.Self, row: t.Dict, callbacks: Callbacks) -> float:
if statements is None:
return np.nan

-statements = [item["simpler_statements"] for item in statements.dicts()]
+statements = [item["simpler_statements"] for item in statements.model_dump()]
statements = [item for sublist in statements for item in sublist]

assert isinstance(statements, t.List), "statements must be a list"
@@ -295,7 +292,7 @@ async def _ascore(self: t.Self, row: t.Dict, callbacks: Callbacks) -> float:
]

faithfulness_list = [
-faith.dicts() for faith in faithfulness_list if faith is not None
+faith.model_dump() for faith in faithfulness_list if faith is not None
]

if faithfulness_list:
@@ -385,7 +382,7 @@ async def _ascore(self: t.Self, row: t.Dict, callbacks: Callbacks) -> float:
if statements is None:
return np.nan

-statements = [item["simpler_statements"] for item in statements.dicts()]
+statements = [item["simpler_statements"] for item in statements.model_dump()]
statements = [item for sublist in statements for item in sublist]

assert isinstance(statements, t.List), "statements must be a list"
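The remaining `_faithfulness.py` edits are the mechanical renames of the same migration: `parse_obj` becomes `model_validate` and the v1 `.dict()` call becomes `model_dump()`. One general pydantic v2 detail worth noting (a library behaviour, not a claim about this file): `model_dump()` on a root model returns plain dicts, whereas `.root` keeps the validated model instances, so attribute access such as `.verdict` is only available on the latter. A hedged sketch with a hypothetical `Verdict` model:

import typing as t

from pydantic import BaseModel, RootModel

class Verdict(BaseModel):  # hypothetical stand-in for a per-statement verdict
    statement: str
    verdict: int

class Verdicts(RootModel):
    root: t.List[Verdict]

vs = Verdicts.model_validate(
    [{"statement": "s1", "verdict": 1}, {"statement": "s2", "verdict": 0}]
)

# Attribute access works on the model instances held in .root ...
score = sum(1 if v.verdict else 0 for v in vs.root) / len(vs.root)

# ... while model_dump() yields plain dicts, handy for JSON-style prompt examples.
as_dicts = vs.model_dump()
print(score, as_dicts[0]["verdict"])  # 0.5 1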
12 changes: 7 additions & 5 deletions src/ragas/metrics/_noise_sensitivity.py
@@ -137,13 +137,12 @@ async def _evaluate_statement_faithfulness(
"verdict",
)

-faithfulness_list = StatementFaithfulnessAnswers.parse_obj(
+faithfulness_list = StatementFaithfulnessAnswers.model_validate(
faithfulness_list
)

verdict_list = [
-1 if statement.verdict else 0
-for statement in faithfulness_list.__root__
+1 if statement.verdict else 0 for statement in faithfulness_list.dicts()
]
return np.array(verdict_list)
else:
@@ -162,14 +161,17 @@ async def _decompose_answer_into_statements(
callbacks=callbacks,
)
else:
-statements_gen = self.llm.generate(
+statements_gen = await self.llm.generate(
p_value,
callbacks=callbacks,
)

# Await the aparse method
statements = await _statements_output_parser.aparse(
-statements_gen.generations[0][0].text, p_value, self.llm, self.max_retries # type: ignore
+statements_gen.generations[0][0].text,
+p_value,
+self.llm,
+self.max_retries, # type: ignore
)

if statements is None:
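Besides the same `model_validate` rename, the `_decompose_answer_into_statements` hunk adds a previously missing `await` on the fallback `self.llm.generate(...)` call. As a general Python note, illustrated with a toy coroutine rather than the ragas LLM interface: calling an async function without `await` only creates a coroutine object, so any attribute access on the supposed result (such as `.generations`) would fail.

import asyncio

async def generate(prompt: str) -> str:  # hypothetical async LLM call
    await asyncio.sleep(0)  # stand-in for network I/O
    return f"completion for: {prompt}"

async def main() -> None:
    not_awaited = generate("hi")       # a coroutine object, not a string
    print(type(not_awaited).__name__)  # "coroutine"
    not_awaited.close()                # close it to avoid a "never awaited" warning

    awaited = await generate("hi")     # the actual result
    print(awaited)

asyncio.run(main())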
6 changes: 3 additions & 3 deletions src/ragas/metrics/_summarization.py
@@ -31,13 +31,13 @@ class GenerateAnswersResponse(BaseModel):


_output_instructions_question_generation = get_json_format_instructions(
-pydantic_object=GenerateQuestionsResponse
+pydantic_object=GenerateQuestionsResponse # type: ignore
)
_output_instructions_answer_generation = get_json_format_instructions(
-pydantic_object=GenerateAnswersResponse
+pydantic_object=GenerateAnswersResponse # type: ignore
)
_output_instructions_keyphrase_extraction = get_json_format_instructions(
-pydantic_object=ExtractKeyphrasesResponse
+pydantic_object=ExtractKeyphrasesResponse # type: ignore
)
_output_parser_question_generation = RagasoutputParser(
pydantic_object=GenerateQuestionsResponse
1 change: 1 addition & 0 deletions src/ragas/testset/prompts.py
@@ -1,3 +1,4 @@
+# type: ignore
from langchain_core.pydantic_v1 import BaseModel

from ragas.llms.output_parser import RagasoutputParser, get_json_format_instructions
