Skip to content

Commit

Permalink
Revert "integrations[patch]: remove non-required chat param defaults" (
Browse files Browse the repository at this point in the history
…#29048)

Reverts #26730

discuss best way to release default changes (esp openai temperature)
  • Loading branch information
efriis authored Jan 6, 2025
1 parent 3d7ae8b commit 187131c
Show file tree
Hide file tree
Showing 15 changed files with 43 additions and 51 deletions.
8 changes: 4 additions & 4 deletions libs/partners/anthropic/langchain_anthropic/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ class ChatAnthropic(BaseChatModel):
Key init args — client params:
timeout: Optional[float]
Timeout for requests.
max_retries: Optional[int]
max_retries: int
Max number of retries if a request fails.
api_key: Optional[str]
Anthropic API key. If not passed in will be read from env var ANTHROPIC_API_KEY.
Expand Down Expand Up @@ -558,7 +558,8 @@ class Joke(BaseModel):
default_request_timeout: Optional[float] = Field(None, alias="timeout")
"""Timeout for requests to Anthropic Completion API."""

max_retries: Optional[int] = None
# sdk default = 2: https://github.com/anthropics/anthropic-sdk-python?tab=readme-ov-file#retries
max_retries: int = 2
"""Number of retries allowed for requests sent to the Anthropic Completion API."""

stop_sequences: Optional[List[str]] = Field(None, alias="stop")
Expand Down Expand Up @@ -661,10 +662,9 @@ def _client_params(self) -> Dict[str, Any]:
client_params: Dict[str, Any] = {
"api_key": self.anthropic_api_key.get_secret_value(),
"base_url": self.anthropic_api_url,
"max_retries": self.max_retries,
"default_headers": (self.default_headers or None),
}
if self.max_retries is not None:
client_params["max_retries"] = self.max_retries
# value <= 0 indicates the param should be ignored. None is a meaningful value
# for Anthropic client and treated differently than not specifying the param at
# all.
Expand Down
2 changes: 1 addition & 1 deletion libs/partners/fireworks/langchain_fireworks/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ def is_lc_serializable(cls) -> bool:
default="accounts/fireworks/models/mixtral-8x7b-instruct", alias="model"
)
"""Model name to use."""
temperature: Optional[float] = None
temperature: float = 0.0
"""What sampling temperature to use."""
stop: Optional[Union[str, List[str]]] = Field(default=None, alias="stop_sequences")
"""Default stop sequences."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
'request_timeout': 60.0,
'stop': list([
]),
'temperature': 0.0,
}),
'lc': 1,
'name': 'ChatFireworks',
Expand Down
16 changes: 7 additions & 9 deletions libs/partners/groq/langchain_groq/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ class ChatGroq(BaseChatModel):
Key init args — client params:
timeout: Union[float, Tuple[float, float], Any, None]
Timeout for requests.
max_retries: Optional[int]
max_retries: int
Max number of retries.
api_key: Optional[str]
Groq API key. If not passed in will be read from env var GROQ_API_KEY.
Expand Down Expand Up @@ -303,7 +303,7 @@ class Joke(BaseModel):
async_client: Any = Field(default=None, exclude=True) #: :meta private:
model_name: str = Field(default="mixtral-8x7b-32768", alias="model")
"""Model name to use."""
temperature: Optional[float] = None
temperature: float = 0.7
"""What sampling temperature to use."""
stop: Optional[Union[List[str], str]] = Field(default=None, alias="stop_sequences")
"""Default stop sequences."""
Expand All @@ -327,11 +327,11 @@ class Joke(BaseModel):
)
"""Timeout for requests to Groq completion API. Can be float, httpx.Timeout or
None."""
max_retries: Optional[int] = None
max_retries: int = 2
"""Maximum number of retries to make when generating."""
streaming: bool = False
"""Whether to stream the results or not."""
n: Optional[int] = None
n: int = 1
"""Number of chat completions to generate for each prompt."""
max_tokens: Optional[int] = None
"""Maximum number of tokens to generate."""
Expand Down Expand Up @@ -379,11 +379,10 @@ def build_extra(cls, values: Dict[str, Any]) -> Any:
@model_validator(mode="after")
def validate_environment(self) -> Self:
"""Validate that api key and python package exists in environment."""
if self.n is not None and self.n < 1:
if self.n < 1:
raise ValueError("n must be at least 1.")
elif self.n is not None and self.n > 1 and self.streaming:
if self.n > 1 and self.streaming:
raise ValueError("n must be 1 when streaming.")

if self.temperature == 0:
self.temperature = 1e-8

Expand All @@ -393,11 +392,10 @@ def validate_environment(self) -> Self:
),
"base_url": self.groq_api_base,
"timeout": self.request_timeout,
"max_retries": self.max_retries,
"default_headers": self.default_headers,
"default_query": self.default_query,
}
if self.max_retries is not None:
client_params["max_retries"] = self.max_retries

try:
import groq
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
'max_retries': 2,
'max_tokens': 100,
'model_name': 'mixtral-8x7b-32768',
'n': 1,
'request_timeout': 60.0,
'stop': list([
]),
Expand Down
15 changes: 6 additions & 9 deletions libs/partners/mistralai/langchain_mistralai/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,8 @@ def _create_retry_decorator(
"""Returns a tenacity retry decorator, preconfigured to handle exceptions"""

errors = [httpx.RequestError, httpx.StreamError]
kwargs: dict = dict(
error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
)
return create_base_retry_decorator(
**{k: v for k, v in kwargs.items() if v is not None}
error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
)


Expand Down Expand Up @@ -383,13 +380,13 @@ class ChatMistralAI(BaseChatModel):
default_factory=secret_from_env("MISTRAL_API_KEY", default=None),
)
endpoint: Optional[str] = Field(default=None, alias="base_url")
max_retries: Optional[int] = None
timeout: Optional[int] = None
max_concurrent_requests: Optional[int] = None
max_retries: int = 5
timeout: int = 120
max_concurrent_requests: int = 64
model: str = Field(default="mistral-small", alias="model_name")
temperature: Optional[float] = None
temperature: float = 0.7
max_tokens: Optional[int] = None
top_p: Optional[float] = None
top_p: float = 1
"""Decode using nucleus sampling: consider the smallest set of tokens whose
probability sum is at least top_p. Must be in the closed interval [0.0, 1.0]."""
random_seed: Optional[int] = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
]),
'kwargs': dict({
'endpoint': 'boo',
'max_concurrent_requests': 64,
'max_retries': 2,
'max_tokens': 100,
'mistral_api_key': dict({
Expand All @@ -21,6 +22,7 @@
'model': 'mistral-small',
'temperature': 0.0,
'timeout': 60,
'top_p': 1,
}),
'lc': 1,
'name': 'ChatMistralAI',
Expand Down
9 changes: 4 additions & 5 deletions libs/partners/openai/langchain_openai/chat_models/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class AzureChatOpenAI(BaseChatOpenAI):
https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#rest-api-versioning
timeout: Union[float, Tuple[float, float], Any, None]
Timeout for requests.
max_retries: Optional[int]
max_retries: int
Max number of retries.
organization: Optional[str]
OpenAI organization ID. If not passed in will be read from env
Expand Down Expand Up @@ -586,9 +586,9 @@ def is_lc_serializable(cls) -> bool:
@model_validator(mode="after")
def validate_environment(self) -> Self:
"""Validate that api key and python package exists in environment."""
if self.n is not None and self.n < 1:
if self.n < 1:
raise ValueError("n must be at least 1.")
elif self.n is not None and self.n > 1 and self.streaming:
if self.n > 1 and self.streaming:
raise ValueError("n must be 1 when streaming.")

if self.disabled_params is None:
Expand Down Expand Up @@ -641,11 +641,10 @@ def validate_environment(self) -> Self:
"organization": self.openai_organization,
"base_url": self.openai_api_base,
"timeout": self.request_timeout,
"max_retries": self.max_retries,
"default_headers": self.default_headers,
"default_query": self.default_query,
}
if self.max_retries is not None:
client_params["max_retries"] = self.max_retries
if not self.client:
sync_specific = {"http_client": self.http_client}
self.root_client = openai.AzureOpenAI(**client_params, **sync_specific) # type: ignore[arg-type]
Expand Down
20 changes: 9 additions & 11 deletions libs/partners/openai/langchain_openai/chat_models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ class BaseChatOpenAI(BaseChatModel):
root_async_client: Any = Field(default=None, exclude=True) #: :meta private:
model_name: str = Field(default="gpt-3.5-turbo", alias="model")
"""Model name to use."""
temperature: Optional[float] = None
temperature: float = 0.7
"""What sampling temperature to use."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified."""
Expand All @@ -430,7 +430,7 @@ class BaseChatOpenAI(BaseChatModel):
)
"""Timeout for requests to OpenAI completion API. Can be float, httpx.Timeout or
None."""
max_retries: Optional[int] = None
max_retries: int = 2
"""Maximum number of retries to make when generating."""
presence_penalty: Optional[float] = None
"""Penalizes repeated tokens."""
Expand All @@ -448,7 +448,7 @@ class BaseChatOpenAI(BaseChatModel):
"""Modify the likelihood of specified tokens appearing in the completion."""
streaming: bool = False
"""Whether to stream the results or not."""
n: Optional[int] = None
n: int = 1
"""Number of chat completions to generate for each prompt."""
top_p: Optional[float] = None
"""Total probability mass of tokens to consider at each step."""
Expand Down Expand Up @@ -532,9 +532,9 @@ def validate_temperature(cls, values: Dict[str, Any]) -> Any:
@model_validator(mode="after")
def validate_environment(self) -> Self:
"""Validate that api key and python package exists in environment."""
if self.n is not None and self.n < 1:
if self.n < 1:
raise ValueError("n must be at least 1.")
elif self.n is not None and self.n > 1 and self.streaming:
if self.n > 1 and self.streaming:
raise ValueError("n must be 1 when streaming.")

# Check OPENAI_ORGANIZATION for backwards compatibility.
Expand All @@ -551,12 +551,10 @@ def validate_environment(self) -> Self:
"organization": self.openai_organization,
"base_url": self.openai_api_base,
"timeout": self.request_timeout,
"max_retries": self.max_retries,
"default_headers": self.default_headers,
"default_query": self.default_query,
}
if self.max_retries is not None:
client_params["max_retries"] = self.max_retries

if self.openai_proxy and (self.http_client or self.http_async_client):
openai_proxy = self.openai_proxy
http_client = self.http_client
Expand Down Expand Up @@ -611,14 +609,14 @@ def _default_params(self) -> Dict[str, Any]:
"stop": self.stop or None, # also exclude empty list for this
"max_tokens": self.max_tokens,
"extra_body": self.extra_body,
"n": self.n,
"temperature": self.temperature,
"reasoning_effort": self.reasoning_effort,
}

params = {
"model": self.model_name,
"stream": self.streaming,
"n": self.n,
"temperature": self.temperature,
**{k: v for k, v in exclude_if_none.items() if v is not None},
**self.model_kwargs,
}
Expand Down Expand Up @@ -1567,7 +1565,7 @@ class ChatOpenAI(BaseChatOpenAI): # type: ignore[override]
timeout: Union[float, Tuple[float, float], Any, None]
Timeout for requests.
max_retries: Optional[int]
max_retries: int
Max number of retries.
api_key: Optional[str]
OpenAI API key. If not passed in will be read from env var OPENAI_API_KEY.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
}),
'max_retries': 2,
'max_tokens': 100,
'n': 1,
'openai_api_key': dict({
'id': list([
'AZURE_OPENAI_API_KEY',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
'max_retries': 2,
'max_tokens': 100,
'model_name': 'gpt-3.5-turbo',
'n': 1,
'openai_api_key': dict({
'id': list([
'OPENAI_API_KEY',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -877,6 +877,8 @@ def test__get_request_payload() -> None:
],
"model": "gpt-4o-2024-08-06",
"stream": False,
"n": 1,
"temperature": 0.7,
}
payload = llm._get_request_payload(messages)
assert payload == expected
Expand Down
8 changes: 1 addition & 7 deletions libs/partners/xai/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,7 @@ TEST_FILE ?= tests/unit_tests/

integration_test integration_tests: TEST_FILE=tests/integration_tests/

test tests:
poetry run pytest --disable-socket --allow-unix-socket $(TEST_FILE)

test_watch:
poetry run ptw --snapshot-update --now . -- -vv $(TEST_FILE)

integration_test integration_tests:
test tests integration_test integration_tests:
poetry run pytest $(TEST_FILE)

######################
Expand Down
7 changes: 3 additions & 4 deletions libs/partners/xai/langchain_xai/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,9 +320,9 @@ def _get_ls_params(
@model_validator(mode="after")
def validate_environment(self) -> Self:
"""Validate that api key and python package exists in environment."""
if self.n is not None and self.n < 1:
if self.n < 1:
raise ValueError("n must be at least 1.")
if self.n is not None and self.n > 1 and self.streaming:
if self.n > 1 and self.streaming:
raise ValueError("n must be 1 when streaming.")

client_params: dict = {
Expand All @@ -331,11 +331,10 @@ def validate_environment(self) -> Self:
),
"base_url": self.xai_api_base,
"timeout": self.request_timeout,
"max_retries": self.max_retries,
"default_headers": self.default_headers,
"default_query": self.default_query,
}
if self.max_retries is not None:
client_params["max_retries"] = self.max_retries

if client_params["api_key"] is None:
raise ValueError(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
'max_retries': 2,
'max_tokens': 100,
'model_name': 'grok-beta',
'n': 1,
'request_timeout': 60.0,
'stop': list([
]),
Expand Down

0 comments on commit 187131c

Please sign in to comment.