Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrade to Pydantic v2 #8330

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion docs/api_reference/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,6 @@ def setup(app):

autodoc_pydantic_model_show_json = False
autodoc_pydantic_field_list_validators = False
autodoc_pydantic_config_members = False
autodoc_pydantic_model_show_config_summary = False
autodoc_pydantic_model_show_validator_members = False
autodoc_pydantic_model_show_validator_summary = False
Expand Down
2 changes: 1 addition & 1 deletion docs/api_reference/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
-e libs/langchain
autodoc_pydantic==1.8.0
autodoc_pydantic==2.0.0
myst_parser
nbsphinx==0.8.9
sphinx==4.5.0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from langchain.chains.base import Chain
from langchain.schema.language_model import BaseLanguageModel
from langchain.vectorstores.base import VectorStore
from pydantic import BaseModel, Field
from pydantic import ConfigDict, BaseModel, Field

from langchain_experimental.autonomous_agents.baby_agi.task_creation import (
TaskCreationChain,
Expand All @@ -29,11 +29,7 @@ class BabyAGI(Chain, BaseModel):
task_id_counter: int = Field(1)
vectorstore: VectorStore = Field(init=False)
max_iterations: Optional[int] = None

class Config:
"""Configuration for this pydantic object."""

arbitrary_types_allowed = True
model_config = ConfigDict(arbitrary_types_allowed=True)

def add_task(self, task: Dict) -> None:
self.task_list.append(task)
Expand Down
21 changes: 10 additions & 11 deletions libs/experimental/langchain_experimental/cpal/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import duckdb
import pandas as pd
from langchain.graphs.networkx_graph import NetworkxEntityGraph
from pydantic import BaseModel, Field, PrivateAttr, root_validator, validator
from pydantic import field_validator, ConfigDict, BaseModel, Field, PrivateAttr, root_validator

from langchain_experimental.cpal.constants import Constant

Expand All @@ -20,7 +20,8 @@ class NarrativeModel(BaseModel):
story_hypothetical: str
story_plot: str # causal stack of operations

@validator("*", pre=True)
@field_validator("*", mode="before")
@classmethod
def empty_str_to_none(cls, v: str) -> Union[str, None]:
"""Empty strings are not allowed"""
if v == "":
Expand All @@ -33,14 +34,10 @@ class EntityModel(BaseModel):
code: str = Field(description="entity actions")
value: float = Field(description="entity initial value")
depends_on: list[str] = Field(default=[], description="ancestor entities")
model_config = ConfigDict(validate_assignment=True)

# TODO: generalize to multivariate math
# TODO: acyclic graph

class Config:
validate_assignment = True

@validator("name")
@field_validator("name")
@classmethod
def lower_case_name(cls, v: str) -> str:
v = v.lower()
return v
Expand All @@ -64,7 +61,8 @@ class EntitySettingModel(BaseModel):
attribute: str = Field(description="name of the attribute to be calculated")
value: float = Field(description="entity's attribute value (calculated)")

@validator("name")
@field_validator("name")
@classmethod
def lower_case_transform(cls, v: str) -> str:
v = v.lower()
return v
Expand Down Expand Up @@ -98,7 +96,8 @@ class InterventionModel(BaseModel):
entity_settings: list[EntitySettingModel]
system_settings: Optional[list[SystemSettingModel]] = None

@validator("system_settings")
@field_validator("system_settings")
@classmethod
def lower_case_name(cls, v: str) -> Union[str, None]:
if v is not None:
raise NotImplementedError("system_setting is not implemented yet")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.schema.language_model import BaseLanguageModel
from pydantic import BaseModel, Field
from pydantic import ConfigDict, BaseModel, Field

from langchain_experimental.generative_agents.memory import GenerativeAgentMemory

Expand Down Expand Up @@ -38,11 +38,7 @@ class GenerativeAgent(BaseModel):

daily_summaries: List[str] = Field(default_factory=list) # : :meta private:
"""Summary of the events in the plan that the agent took."""

class Config:
"""Configuration for this pydantic object."""

arbitrary_types_allowed = True
model_config = ConfigDict(arbitrary_types_allowed=True)

# LLM-related methods
@staticmethod
Expand Down
12 changes: 4 additions & 8 deletions libs/experimental/langchain_experimental/pal_chain/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from langchain.schema import BasePromptTemplate
from langchain.schema.language_model import BaseLanguageModel
from langchain.utilities import PythonREPL
from pydantic import Extra, Field, root_validator
from pydantic import model_validator, ConfigDict, Field

from langchain_experimental.pal_chain.colored_object_prompt import COLORED_OBJECT_PROMPT
from langchain_experimental.pal_chain.math_prompt import MATH_PROMPT
Expand Down Expand Up @@ -114,14 +114,10 @@ class PALChain(Chain):
"""Validations to perform on the generated code."""
timeout: Optional[int] = 10
"""Timeout in seconds for the generated code to execute."""
model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)

class Config:
"""Configuration for this pydantic object."""

extra = Extra.forbid
arbitrary_types_allowed = True

@root_validator(pre=True)
@model_validator(mode="before")
@classmethod
def raise_deprecation(cls, values: Dict) -> Dict:
if "llm" in values:
warnings.warn(
Expand Down
12 changes: 4 additions & 8 deletions libs/experimental/langchain_experimental/sql/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools.sql_database.prompt import QUERY_CHECKER
from langchain.utilities.sql_database import SQLDatabase
from pydantic import Extra, Field, root_validator
from pydantic import model_validator, ConfigDict, Field

from langchain_experimental.sql.prompt import DECIDER_PROMPT, PROMPT, SQL_PROMPTS

Expand Down Expand Up @@ -53,14 +53,10 @@ class SQLDatabaseChain(Chain):
to fix the initial SQL from the LLM."""
query_checker_prompt: Optional[BasePromptTemplate] = None
"""The prompt template that should be used by the query checker"""
model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)

class Config:
"""Configuration for this pydantic object."""

extra = Extra.forbid
arbitrary_types_allowed = True

@root_validator(pre=True)
@model_validator(mode="before")
@classmethod
def raise_deprecation(cls, values: Dict) -> Dict:
if "llm" in values:
warnings.warn(
Expand Down
11 changes: 7 additions & 4 deletions libs/langchain/langchain/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import yaml
from pydantic import BaseModel, root_validator
from pydantic import BaseModel, model_validator

from langchain.agents.agent_iterator import AgentExecutorIterator
from langchain.agents.agent_types import AgentType
Expand Down Expand Up @@ -494,7 +494,8 @@ def input_keys(self) -> List[str]:
"""
return list(set(self.llm_chain.input_keys) - {"agent_scratchpad"})

@root_validator()
@model_validator()
@classmethod
def validate_prompt(cls, values: Dict) -> Dict:
"""Validate that prompt matches format."""
prompt = values["llm_chain"].prompt
Expand Down Expand Up @@ -693,7 +694,8 @@ def from_agent_and_tools(
agent=agent, tools=tools, callback_manager=callback_manager, **kwargs
)

@root_validator()
@model_validator()
@classmethod
def validate_tools(cls, values: Dict) -> Dict:
"""Validate that tools are compatible with agent."""
agent = values["agent"]
Expand All @@ -707,7 +709,8 @@ def validate_tools(cls, values: Dict) -> Dict:
)
return values

@root_validator()
@model_validator()
@classmethod
def validate_return_direct_tool(cls, values: Dict) -> Dict:
"""Validate that tools are compatible with agent."""
agent = values["agent"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import TYPE_CHECKING, List

from pydantic import Field
from pydantic import ConfigDict, Field

from langchain.agents.agent_toolkits.base import BaseToolkit
from langchain.tools import BaseTool
Expand All @@ -18,11 +18,7 @@ class AmadeusToolkit(BaseToolkit):
"""Toolkit for interacting with Office365."""

client: Client = Field(default_factory=authenticate)

class Config:
"""Pydantic config."""

arbitrary_types_allowed = True
model_config = ConfigDict(arbitrary_types_allowed=True)

def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import TYPE_CHECKING, List

from pydantic import Field
from pydantic import ConfigDict, Field

from langchain.agents.agent_toolkits.base import BaseToolkit
from langchain.tools import BaseTool
Expand Down Expand Up @@ -31,11 +31,7 @@ class GmailToolkit(BaseToolkit):
"""Toolkit for interacting with Gmail."""

api_resource: Resource = Field(default_factory=build_resource_service)

class Config:
"""Pydantic config."""

arbitrary_types_allowed = True
model_config = ConfigDict(arbitrary_types_allowed=True)

def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import TYPE_CHECKING, List

from pydantic import Field
from pydantic import ConfigDict, Field

from langchain.agents.agent_toolkits.base import BaseToolkit
from langchain.tools import BaseTool
Expand All @@ -21,11 +21,7 @@ class O365Toolkit(BaseToolkit):
"""Toolkit for interacting with Office 365."""

account: Account = Field(default_factory=authenticate)

class Config:
"""Pydantic config."""

arbitrary_types_allowed = True
model_config = ConfigDict(arbitrary_types_allowed=True)

def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from typing import TYPE_CHECKING, List, Optional, Type, cast

from pydantic import Extra, root_validator
from pydantic import ConfigDict, root_validator

from langchain.agents.agent_toolkits.base import BaseToolkit
from langchain.tools.base import BaseTool
Expand Down Expand Up @@ -36,12 +36,7 @@ class PlayWrightBrowserToolkit(BaseToolkit):

sync_browser: Optional["SyncBrowser"] = None
async_browser: Optional["AsyncBrowser"] = None

class Config:
"""Configuration for this pydantic object."""

extra = Extra.forbid
arbitrary_types_allowed = True
model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)

@root_validator
def validate_imports_and_browser_provided(cls, values: dict) -> dict:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Toolkit for interacting with a Power BI dataset."""
from typing import List, Optional, Union

from pydantic import Field
from pydantic import ConfigDict, Field

from langchain.agents.agent_toolkits.base import BaseToolkit
from langchain.callbacks.base import BaseCallbackManager
Expand Down Expand Up @@ -38,11 +38,7 @@ class PowerBIToolkit(BaseToolkit):
callback_manager: Optional[BaseCallbackManager] = None
output_token_limit: Optional[int] = None
tiktoken_model_name: Optional[str] = None

class Config:
"""Configuration for this pydantic object."""

arbitrary_types_allowed = True
model_config = ConfigDict(arbitrary_types_allowed=True)

def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Toolkit for interacting with Spark SQL."""
from typing import List

from pydantic import Field
from pydantic import ConfigDict, Field

from langchain.agents.agent_toolkits.base import BaseToolkit
from langchain.schema.language_model import BaseLanguageModel
Expand All @@ -20,11 +20,7 @@ class SparkSQLToolkit(BaseToolkit):

db: SparkSQL = Field(exclude=True)
llm: BaseLanguageModel = Field(exclude=True)

class Config:
"""Configuration for this pydantic object."""

arbitrary_types_allowed = True
model_config = ConfigDict(arbitrary_types_allowed=True)

def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
Expand Down
7 changes: 2 additions & 5 deletions libs/langchain/langchain/agents/agent_toolkits/sql/toolkit.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Toolkit for interacting with an SQL database."""
from typing import List

from pydantic import Field
from pydantic import ConfigDict, Field

from langchain.agents.agent_toolkits.base import BaseToolkit
from langchain.schema.language_model import BaseLanguageModel
Expand All @@ -26,10 +26,7 @@ def dialect(self) -> str:
"""Return string representation of SQL dialect to use."""
return self.db.dialect

class Config:
"""Configuration for this pydantic object."""

arbitrary_types_allowed = True
model_config = ConfigDict(arbitrary_types_allowed=True)

def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
Expand Down
Loading