Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 60 additions & 47 deletions src/llama_stack_client/resources/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
)
from .._wrappers import DataWrapper
from .._base_client import make_request_options
from ..types.provider_info import ProviderInfo
from ..types.provider_list_response import ProviderListResponse
from ..types.provider_get_response import GetProviderResponse

__all__ = ["ProvidersResource", "AsyncProvidersResource"]

Expand All @@ -43,49 +43,57 @@ def with_streaming_response(self) -> ProvidersResourceWithStreamingResponse:
"""
return ProvidersResourceWithStreamingResponse(self)

def list(
def retrieve(
self,
provider_id: str,
*,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> ProviderListResponse:
) -> ProviderInfo:
"""
Args:
extra_headers: Send extra headers

extra_query: Add additional query parameters to the request

extra_body: Add additional JSON properties to the request

timeout: Override the client-level default timeout for this request, in seconds
"""
if not provider_id:
raise ValueError(f"Expected a non-empty value for `provider_id` but received {provider_id!r}")
return self._get(
"/v1/providers",
f"/v1/providers/{provider_id}",
options=make_request_options(
extra_headers=extra_headers,
extra_query=extra_query,
extra_body=extra_body,
timeout=timeout,
post_parser=DataWrapper[ProviderListResponse]._unwrapper,
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
cast_to=cast(Type[ProviderListResponse], DataWrapper[ProviderListResponse]),
cast_to=ProviderInfo,
)

def inspect(
def list(
self,
provider_id,
*,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> GetProviderResponse:
) -> ProviderListResponse:
return self._get(
f"/v1/providers/{provider_id}",
"/v1/providers",
options=make_request_options(
extra_headers=extra_headers,
extra_query=extra_query,
extra_body=extra_body,
timeout=timeout,
post_parser=DataWrapper[GetProviderResponse]._unwrapper,
post_parser=DataWrapper[ProviderListResponse]._unwrapper,
),
cast_to=cast(Type[GetProviderResponse], DataWrapper[GetProviderResponse]),
cast_to=cast(Type[ProviderListResponse], DataWrapper[ProviderListResponse]),
)


Expand All @@ -109,98 +117,103 @@ def with_streaming_response(self) -> AsyncProvidersResourceWithStreamingResponse
"""
return AsyncProvidersResourceWithStreamingResponse(self)

async def list(
async def retrieve(
self,
provider_id: str,
*,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> ProviderListResponse:
) -> ProviderInfo:
"""
Args:
extra_headers: Send extra headers

extra_query: Add additional query parameters to the request

extra_body: Add additional JSON properties to the request

timeout: Override the client-level default timeout for this request, in seconds
"""
if not provider_id:
raise ValueError(f"Expected a non-empty value for `provider_id` but received {provider_id!r}")
return await self._get(
"/v1/providers",
f"/v1/providers/{provider_id}",
options=make_request_options(
extra_headers=extra_headers,
extra_query=extra_query,
extra_body=extra_body,
timeout=timeout,
post_parser=DataWrapper[ProviderListResponse]._unwrapper,
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
cast_to=cast(Type[ProviderListResponse], DataWrapper[ProviderListResponse]),
cast_to=ProviderInfo,
)
async def inspect(

async def list(
self,
provider_id,
*,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> GetProviderResponse:
) -> ProviderListResponse:
return await self._get(
f"/v1/providers/{provider_id}",
"/v1/providers",
options=make_request_options(
extra_headers=extra_headers,
extra_query=extra_query,
extra_body=extra_body,
timeout=timeout,
post_parser=DataWrapper[GetProviderResponse]._unwrapper,
post_parser=DataWrapper[ProviderListResponse]._unwrapper,
),
cast_to=cast(Type[GetProviderResponse], DataWrapper[GetProviderResponse]),
cast_to=cast(Type[ProviderListResponse], DataWrapper[ProviderListResponse]),
)


class ProvidersResourceWithRawResponse:
def __init__(self, providers: ProvidersResource) -> None:
self._providers = providers

self.retrieve = to_raw_response_wrapper(
providers.retrieve,
)
self.list = to_raw_response_wrapper(
providers.list,
)

self.inspect = to_raw_response_wrapper(
providers.inspect,
)

class AsyncProvidersResourceWithRawResponse:
def __init__(self, providers: AsyncProvidersResource) -> None:
self._providers = providers

self.retrieve = async_to_raw_response_wrapper(
providers.retrieve,
)
self.list = async_to_raw_response_wrapper(
providers.list,
)

self.inspect = async_to_raw_response_wrapper(
providers.inspect,
)



class ProvidersResourceWithStreamingResponse:
def __init__(self, providers: ProvidersResource) -> None:
self._providers = providers

self.retrieve = to_streamed_response_wrapper(
providers.retrieve,
)
self.list = to_streamed_response_wrapper(
providers.list,
)

self.inspect = to_streamed_response_wrapper(
providers.inspect,
)


class AsyncProvidersResourceWithStreamingResponse:
def __init__(self, providers: AsyncProvidersResource) -> None:
self._providers = providers

self.retrieve = async_to_streamed_response_wrapper(
providers.retrieve,
)
self.list = async_to_streamed_response_wrapper(
providers.list,
)
self.inspect = async_to_streamed_response_wrapper(
providers.inspect,
)
9 changes: 0 additions & 9 deletions src/llama_stack_client/types/provider_get_response.py

This file was deleted.

11 changes: 2 additions & 9 deletions src/llama_stack_client/types/provider_info.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,17 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.

from typing import Dict, List, Union

from .._models import BaseModel
from typing import Dict, Any

__all__ = ["ProviderInfo"]


class ProviderInfo(BaseModel):
api: str

provider_id: str

provider_type: str

class ProviderInfoWithConfig(BaseModel):
api: str
config: Dict[str, Union[bool, float, str, List[object], object, None]]

provider_id: str

provider_type: str

config: Dict[str, Any]
78 changes: 77 additions & 1 deletion tests/api_resources/test_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,52 @@

from tests.utils import assert_matches_type
from llama_stack_client import LlamaStackClient, AsyncLlamaStackClient
from llama_stack_client.types import ProviderListResponse
from llama_stack_client.types import ProviderInfo, ProviderListResponse

base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")


class TestProviders:
parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"])

@parametrize
def test_method_retrieve(self, client: LlamaStackClient) -> None:
provider = client.providers.retrieve(
"provider_id",
)
assert_matches_type(ProviderInfo, provider, path=["response"])

@parametrize
def test_raw_response_retrieve(self, client: LlamaStackClient) -> None:
response = client.providers.with_raw_response.retrieve(
"provider_id",
)

assert response.is_closed is True
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
provider = response.parse()
assert_matches_type(ProviderInfo, provider, path=["response"])

@parametrize
def test_streaming_response_retrieve(self, client: LlamaStackClient) -> None:
with client.providers.with_streaming_response.retrieve(
"provider_id",
) as response:
assert not response.is_closed
assert response.http_request.headers.get("X-Stainless-Lang") == "python"

provider = response.parse()
assert_matches_type(ProviderInfo, provider, path=["response"])

assert cast(Any, response.is_closed) is True

@parametrize
def test_path_params_retrieve(self, client: LlamaStackClient) -> None:
with pytest.raises(ValueError, match=r"Expected a non-empty value for `provider_id` but received ''"):
client.providers.with_raw_response.retrieve(
"",
)

@parametrize
def test_method_list(self, client: LlamaStackClient) -> None:
provider = client.providers.list()
Expand Down Expand Up @@ -46,6 +84,44 @@ def test_streaming_response_list(self, client: LlamaStackClient) -> None:
class TestAsyncProviders:
parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"])

@parametrize
async def test_method_retrieve(self, async_client: AsyncLlamaStackClient) -> None:
provider = await async_client.providers.retrieve(
"provider_id",
)
assert_matches_type(ProviderInfo, provider, path=["response"])

@parametrize
async def test_raw_response_retrieve(self, async_client: AsyncLlamaStackClient) -> None:
response = await async_client.providers.with_raw_response.retrieve(
"provider_id",
)

assert response.is_closed is True
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
provider = await response.parse()
assert_matches_type(ProviderInfo, provider, path=["response"])

@parametrize
async def test_streaming_response_retrieve(self, async_client: AsyncLlamaStackClient) -> None:
async with async_client.providers.with_streaming_response.retrieve(
"provider_id",
) as response:
assert not response.is_closed
assert response.http_request.headers.get("X-Stainless-Lang") == "python"

provider = await response.parse()
assert_matches_type(ProviderInfo, provider, path=["response"])

assert cast(Any, response.is_closed) is True

@parametrize
async def test_path_params_retrieve(self, async_client: AsyncLlamaStackClient) -> None:
with pytest.raises(ValueError, match=r"Expected a non-empty value for `provider_id` but received ''"):
await async_client.providers.with_raw_response.retrieve(
"",
)

@parametrize
async def test_method_list(self, async_client: AsyncLlamaStackClient) -> None:
provider = await async_client.providers.list()
Expand Down