standard-tests: model param test #29595

Open · wants to merge 3 commits into base: master
@@ -13,7 +13,7 @@
 )
 from langchain_core.messages.ai import UsageMetadata
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
-from pydantic import Field
+from pydantic import ConfigDict, Field


 class Chat__ModuleName__(BaseChatModel):
@@ -266,7 +266,7 @@ class Joke(BaseModel):

    """  # noqa: E501

-    model_name: str = Field(alias="model")
+    model: str = Field(alias="model_name")
     """The name of the model"""
     parrot_buffer_length: int
     """The number of characters from the last message of the prompt to be echoed."""
@@ -276,6 +276,10 @@ class Joke(BaseModel):
     stop: Optional[List[str]] = None
     max_retries: int = 2

+    model_config = ConfigDict(
+        populate_by_name=True,
+    )
+
     @property
     def _llm_type(self) -> str:
         """Return type of chat model."""
@@ -293,7 +297,7 @@ def _identifying_params(self) -> Dict[str, Any]:
             # rules in LLM monitoring applications (e.g., in LangSmith users
             # can provide per token pricing for their model and monitor
             # costs for the given LLM.)
-            "model_name": self.model_name,
+            "model": self.model,
         }

     def _generate(
6 changes: 5 additions & 1 deletion libs/partners/openai/langchain_openai/chat_models/base.py
@@ -440,7 +440,7 @@ class BaseChatOpenAI(BaseChatModel):
     async_client: Any = Field(default=None, exclude=True)  #: :meta private:
     root_client: Any = Field(default=None, exclude=True)  #: :meta private:
     root_async_client: Any = Field(default=None, exclude=True)  #: :meta private:
-    model_name: str = Field(default="gpt-3.5-turbo", alias="model")
+    model: str = Field(default="gpt-3.5-turbo", alias="model_name")
     """Model name to use."""
     temperature: Optional[float] = None
     """What sampling temperature to use."""
@@ -545,6 +545,10 @@ class BaseChatOpenAI(BaseChatModel):

     model_config = ConfigDict(populate_by_name=True)

+    @property
+    def model_name(self) -> str:
+        return self.model
+
     @model_validator(mode="before")
     @classmethod
     def build_extra(cls, values: Dict[str, Any]) -> Any:
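
For illustration, not part of the diff: with the field renamed to ``model``, ``populate_by_name=True`` set, and the new ``model_name`` property, initialization should work under both the field name and the legacy alias, and old attribute access should keep working. A minimal sketch (the model string and the placeholder key are arbitrary; no request is made at construction time):

    from langchain_openai import ChatOpenAI

    # Both spellings should populate the same field after this change.
    by_field = ChatOpenAI(model="gpt-4o-mini", api_key="placeholder")       # canonical field name
    by_alias = ChatOpenAI(model_name="gpt-4o-mini", api_key="placeholder")  # legacy alias

    assert by_field.model == by_alias.model == "gpt-4o-mini"
    # The read-only property preserves backwards compatibility:
    assert by_field.model_name == "gpt-4o-mini"
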
43 changes: 43 additions & 0 deletions libs/standard-tests/langchain_tests/unit_tests/chat_models.py
@@ -85,6 +85,10 @@ def chat_model_params(self) -> dict:
         """Initialization parameters for the chat model."""
         return {}

+    @property
+    def chat_model_model_param(self) -> str:
+        return self.chat_model_params.get("model", "test-model-name")
+
     @property
     def standard_chat_model_params(self) -> dict:
         """:private:"""
@@ -521,6 +525,45 @@ def test_init(self) -> None:
         )
         assert model is not None

+    def test_model_param_name(self) -> None:
+        """Tests model initialization with a ``model`` parameter. This should
+        pass for all integrations.
+
+        .. dropdown:: Troubleshooting
+
+            If this test fails, ensure that the model can be initialized with a
+            ``model`` parameter, and that the model parameter can be accessed as
+            ``.model``.
+
+            If not, the easiest way to configure this is likely to add
+            ``from pydantic import ConfigDict`` at the top of your file, and add a
+            ``model_config`` class attribute to your model class:
+
+            .. code-block:: python
+
+                class MyChatModel(BaseChatModel):
+                    model: str = Field(alias="model_name")
+                    model_config = ConfigDict(populate_by_name=True)
+
+                    # optional property for backwards-compatibility
+                    # for folks accessing chat_model.model_name
+                    @property
+                    def model_name(self) -> str:
+                        return self.model
+        """
+        params = {
+            **self.standard_chat_model_params,
+            **self.chat_model_params,
+        }
+        if "model_name" in params:
+            params["model"] = params.pop("model_name")
+        else:
+            params["model"] = self.chat_model_model_param
+
+        model = self.chat_model_class(**params)
+        assert model is not None
+        assert model.model == params["model"]
+
     def test_init_from_env(self) -> None:
         """Test initialization from environment variables. Relies on the
         ``init_from_env_params`` property. Test is skipped if that property is not
10 changes: 7 additions & 3 deletions libs/standard-tests/tests/unit_tests/custom_chat_model.py
@@ -11,7 +11,7 @@
 )
 from langchain_core.messages.ai import UsageMetadata
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
-from pydantic import Field
+from pydantic import ConfigDict, Field


 class ChatParrotLink(BaseChatModel):
@@ -33,7 +33,7 @@ class ChatParrotLink(BaseChatModel):
         [HumanMessage(content="world")]])
     """

-    model_name: str = Field(alias="model")
+    model: str = Field(alias="model_name")
     """The name of the model"""
     parrot_buffer_length: int
     """The number of characters from the last message of the prompt to be echoed."""
@@ -43,6 +43,10 @@ class ChatParrotLink(BaseChatModel):
     stop: Optional[List[str]] = None
     max_retries: int = 2

+    model_config = ConfigDict(
+        populate_by_name=True,
+    )
+
     def _generate(
         self,
         messages: List[BaseMessage],
@@ -163,5 +167,5 @@ def _identifying_params(self) -> Dict[str, Any]:
             # rules in LLM monitoring applications (e.g., in LangSmith users
             # can provide per token pricing for their model and monitor
             # costs for the given LLM.)
-            "model_name": self.model_name,
+            "model": self.model,
         }
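
For context, an illustrative sketch rather than part of the diff: integrations inherit the standard unit tests by subclassing ``ChatModelUnitTests`` and providing ``chat_model_class`` and ``chat_model_params``; the new ``test_model_param_name`` then runs against those params automatically. The test-class name, import path, and parameter values below are assumed:

    # Hypothetical test module (names and values assumed for illustration).
    from langchain_tests.unit_tests import ChatModelUnitTests

    from .custom_chat_model import ChatParrotLink


    class TestChatParrotLinkUnit(ChatModelUnitTests):
        @property
        def chat_model_class(self) -> type[ChatParrotLink]:
            return ChatParrotLink

        @property
        def chat_model_params(self) -> dict:
            # test_model_param_name will assert that
            # self.chat_model_class(**params).model == "bird-brain-001"
            return {"model": "bird-brain-001", "parrot_buffer_length": 50}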