Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions src/llama_stack_client/_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,6 +792,7 @@ def __init__(
custom_query: Mapping[str, object] | None = None,
_strict_response_validation: bool,
) -> None:
kwargs: dict[str, Any] = {}
if limits is not None:
warnings.warn(
"The `connection_pool_limits` argument is deprecated. The `http_client` argument should be passed instead",
Expand All @@ -804,6 +805,7 @@ def __init__(
limits = DEFAULT_CONNECTION_LIMITS

if transport is not None:
kwargs["transport"] = transport
warnings.warn(
"The `transport` argument is deprecated. The `http_client` argument should be passed instead",
category=DeprecationWarning,
Expand All @@ -813,6 +815,7 @@ def __init__(
raise ValueError("The `http_client` argument is mutually exclusive with `transport`")

if proxies is not None:
kwargs["proxies"] = proxies
warnings.warn(
"The `proxies` argument is deprecated. The `http_client` argument should be passed instead",
category=DeprecationWarning,
Expand Down Expand Up @@ -856,10 +859,9 @@ def __init__(
base_url=base_url,
# cast to a valid type because mypy doesn't understand our type narrowing
timeout=cast(Timeout, timeout),
proxies=proxies,
transport=transport,
limits=limits,
follow_redirects=True,
**kwargs, # type: ignore
)

def is_closed(self) -> bool:
Expand Down Expand Up @@ -1358,6 +1360,7 @@ def __init__(
custom_headers: Mapping[str, str] | None = None,
custom_query: Mapping[str, object] | None = None,
) -> None:
kwargs: dict[str, Any] = {}
if limits is not None:
warnings.warn(
"The `connection_pool_limits` argument is deprecated. The `http_client` argument should be passed instead",
Expand All @@ -1370,6 +1373,7 @@ def __init__(
limits = DEFAULT_CONNECTION_LIMITS

if transport is not None:
kwargs["transport"] = transport
warnings.warn(
"The `transport` argument is deprecated. The `http_client` argument should be passed instead",
category=DeprecationWarning,
Expand All @@ -1379,6 +1383,7 @@ def __init__(
raise ValueError("The `http_client` argument is mutually exclusive with `transport`")

if proxies is not None:
kwargs["proxies"] = proxies
warnings.warn(
"The `proxies` argument is deprecated. The `http_client` argument should be passed instead",
category=DeprecationWarning,
Expand Down Expand Up @@ -1422,10 +1427,9 @@ def __init__(
base_url=base_url,
# cast to a valid type because mypy doesn't understand our type narrowing
timeout=cast(Timeout, timeout),
proxies=proxies,
transport=transport,
limits=limits,
follow_redirects=True,
**kwargs, # type: ignore
)

def is_closed(self) -> bool:
Expand Down
8 changes: 3 additions & 5 deletions src/llama_stack_client/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,8 @@ def model_dump(
exclude=exclude,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
warnings=warnings,
# warnings are not supported in Pydantic v1
warnings=warnings if PYDANTIC_V2 else True,
)
return cast(
"dict[str, Any]",
Expand Down Expand Up @@ -213,9 +214,6 @@ def __set_name__(self, owner: type[Any], name: str) -> None: ...
# __set__ is not defined at runtime, but @cached_property is designed to be settable
def __set__(self, instance: object, value: _T) -> None: ...
else:
try:
from functools import cached_property as cached_property
except ImportError:
from cached_property import cached_property as cached_property
from functools import cached_property as cached_property

typed_cached_property = cached_property
158 changes: 125 additions & 33 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,12 +675,25 @@ def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str
@mock.patch("llama_stack_client._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
respx_mock.post("/alpha/models/register").mock(side_effect=httpx.TimeoutException("Test timeout error"))
respx_mock.post("/alpha/inference/chat-completion").mock(
side_effect=httpx.TimeoutException("Test timeout error")
)

with pytest.raises(APITimeoutError):
self.client.post(
"/alpha/models/register",
body=cast(object, dict(model_id="model_id")),
"/alpha/inference/chat-completion",
body=cast(
object,
dict(
messages=[
{
"content": "string",
"role": "user",
}
],
model_id="model_id",
),
),
cast_to=httpx.Response,
options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
)
Expand All @@ -690,12 +703,23 @@ def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> No
@mock.patch("llama_stack_client._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
respx_mock.post("/alpha/models/register").mock(return_value=httpx.Response(500))
respx_mock.post("/alpha/inference/chat-completion").mock(return_value=httpx.Response(500))

with pytest.raises(APIStatusError):
self.client.post(
"/alpha/models/register",
body=cast(object, dict(model_id="model_id")),
"/alpha/inference/chat-completion",
body=cast(
object,
dict(
messages=[
{
"content": "string",
"role": "user",
}
],
model_id="model_id",
),
),
cast_to=httpx.Response,
options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
)
Expand Down Expand Up @@ -726,9 +750,17 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
return httpx.Response(500)
return httpx.Response(200)

respx_mock.post("/alpha/models/register").mock(side_effect=retry_handler)
respx_mock.post("/alpha/inference/chat-completion").mock(side_effect=retry_handler)

response = client.models.with_raw_response.register(model_id="model_id")
response = client.inference.with_raw_response.chat_completion(
messages=[
{
"content": "string",
"role": "user",
}
],
model_id="model_id",
)

assert response.retries_taken == failures_before_success
assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
Expand All @@ -750,10 +782,17 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
return httpx.Response(500)
return httpx.Response(200)

respx_mock.post("/alpha/models/register").mock(side_effect=retry_handler)

response = client.models.with_raw_response.register(
model_id="model_id", extra_headers={"x-stainless-retry-count": Omit()}
respx_mock.post("/alpha/inference/chat-completion").mock(side_effect=retry_handler)

response = client.inference.with_raw_response.chat_completion(
messages=[
{
"content": "string",
"role": "user",
}
],
model_id="model_id",
extra_headers={"x-stainless-retry-count": Omit()},
)

assert len(response.http_request.headers.get_list("x-stainless-retry-count")) == 0
Expand All @@ -775,10 +814,17 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
return httpx.Response(500)
return httpx.Response(200)

respx_mock.post("/alpha/models/register").mock(side_effect=retry_handler)

response = client.models.with_raw_response.register(
model_id="model_id", extra_headers={"x-stainless-retry-count": "42"}
respx_mock.post("/alpha/inference/chat-completion").mock(side_effect=retry_handler)

response = client.inference.with_raw_response.chat_completion(
messages=[
{
"content": "string",
"role": "user",
}
],
model_id="model_id",
extra_headers={"x-stainless-retry-count": "42"},
)

assert response.http_request.headers.get("x-stainless-retry-count") == "42"
Expand Down Expand Up @@ -1416,12 +1462,25 @@ async def test_parse_retry_after_header(self, remaining_retries: int, retry_afte
@mock.patch("llama_stack_client._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
respx_mock.post("/alpha/models/register").mock(side_effect=httpx.TimeoutException("Test timeout error"))
respx_mock.post("/alpha/inference/chat-completion").mock(
side_effect=httpx.TimeoutException("Test timeout error")
)

with pytest.raises(APITimeoutError):
await self.client.post(
"/alpha/models/register",
body=cast(object, dict(model_id="model_id")),
"/alpha/inference/chat-completion",
body=cast(
object,
dict(
messages=[
{
"content": "string",
"role": "user",
}
],
model_id="model_id",
),
),
cast_to=httpx.Response,
options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
)
Expand All @@ -1431,12 +1490,23 @@ async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter)
@mock.patch("llama_stack_client._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
respx_mock.post("/alpha/models/register").mock(return_value=httpx.Response(500))
respx_mock.post("/alpha/inference/chat-completion").mock(return_value=httpx.Response(500))

with pytest.raises(APIStatusError):
await self.client.post(
"/alpha/models/register",
body=cast(object, dict(model_id="model_id")),
"/alpha/inference/chat-completion",
body=cast(
object,
dict(
messages=[
{
"content": "string",
"role": "user",
}
],
model_id="model_id",
),
),
cast_to=httpx.Response,
options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
)
Expand Down Expand Up @@ -1468,9 +1538,17 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
return httpx.Response(500)
return httpx.Response(200)

respx_mock.post("/alpha/models/register").mock(side_effect=retry_handler)
respx_mock.post("/alpha/inference/chat-completion").mock(side_effect=retry_handler)

response = await client.models.with_raw_response.register(model_id="model_id")
response = await client.inference.with_raw_response.chat_completion(
messages=[
{
"content": "string",
"role": "user",
}
],
model_id="model_id",
)

assert response.retries_taken == failures_before_success
assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
Expand All @@ -1493,10 +1571,17 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
return httpx.Response(500)
return httpx.Response(200)

respx_mock.post("/alpha/models/register").mock(side_effect=retry_handler)

response = await client.models.with_raw_response.register(
model_id="model_id", extra_headers={"x-stainless-retry-count": Omit()}
respx_mock.post("/alpha/inference/chat-completion").mock(side_effect=retry_handler)

response = await client.inference.with_raw_response.chat_completion(
messages=[
{
"content": "string",
"role": "user",
}
],
model_id="model_id",
extra_headers={"x-stainless-retry-count": Omit()},
)

assert len(response.http_request.headers.get_list("x-stainless-retry-count")) == 0
Expand All @@ -1519,10 +1604,17 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
return httpx.Response(500)
return httpx.Response(200)

respx_mock.post("/alpha/models/register").mock(side_effect=retry_handler)

response = await client.models.with_raw_response.register(
model_id="model_id", extra_headers={"x-stainless-retry-count": "42"}
respx_mock.post("/alpha/inference/chat-completion").mock(side_effect=retry_handler)

response = await client.inference.with_raw_response.chat_completion(
messages=[
{
"content": "string",
"role": "user",
}
],
model_id="model_id",
extra_headers={"x-stainless-retry-count": "42"},
)

assert response.http_request.headers.get("x-stainless-retry-count") == "42"
Expand All @@ -1539,7 +1631,7 @@ def test_get_platform(self) -> None:
import threading

from llama_stack_client._utils import asyncify
from llama_stack_client._base_client import get_platform
from llama_stack_client._base_client import get_platform

async def test_main() -> None:
result = await asyncify(get_platform)()
Expand Down
8 changes: 8 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,14 @@ class Model(BaseModel):
m.model_dump(warnings=False)


def test_compat_method_no_error_for_warnings() -> None:
    """The compat `model_dump` helper must accept `warnings=False` without raising.

    Regression check for Pydantic v1, where `.dict()` has no `warnings`
    keyword — the compat shim is expected to drop it silently.
    """

    class Model(BaseModel):
        foo: Optional[str]

    instance = Model(foo="hello")
    dumped = model_dump(instance, warnings=False)
    assert isinstance(dumped, dict)


def test_to_json() -> None:
class Model(BaseModel):
foo: Optional[str] = Field(alias="FOO", default=None)
Expand Down