diff --git a/src/llama_stack_client/_client.py b/src/llama_stack_client/_client.py index e0094eda..00922520 100644 --- a/src/llama_stack_client/_client.py +++ b/src/llama_stack_client/_client.py @@ -34,7 +34,6 @@ scoring, shields, datasets, - datasetio, inference, providers, telemetry, @@ -90,7 +89,6 @@ class LlamaStackClient(SyncAPIClient): shields: shields.ShieldsResource synthetic_data_generation: synthetic_data_generation.SyntheticDataGenerationResource telemetry: telemetry.TelemetryResource - datasetio: datasetio.DatasetioResource scoring: scoring.ScoringResource scoring_functions: scoring_functions.ScoringFunctionsResource benchmarks: benchmarks.BenchmarksResource @@ -172,7 +170,6 @@ def __init__( self.shields = shields.ShieldsResource(self) self.synthetic_data_generation = synthetic_data_generation.SyntheticDataGenerationResource(self) self.telemetry = telemetry.TelemetryResource(self) - self.datasetio = datasetio.DatasetioResource(self) self.scoring = scoring.ScoringResource(self) self.scoring_functions = scoring_functions.ScoringFunctionsResource(self) self.benchmarks = benchmarks.BenchmarksResource(self) @@ -306,7 +303,6 @@ class AsyncLlamaStackClient(AsyncAPIClient): shields: shields.AsyncShieldsResource synthetic_data_generation: synthetic_data_generation.AsyncSyntheticDataGenerationResource telemetry: telemetry.AsyncTelemetryResource - datasetio: datasetio.AsyncDatasetioResource scoring: scoring.AsyncScoringResource scoring_functions: scoring_functions.AsyncScoringFunctionsResource benchmarks: benchmarks.AsyncBenchmarksResource @@ -388,7 +384,6 @@ def __init__( self.shields = shields.AsyncShieldsResource(self) self.synthetic_data_generation = synthetic_data_generation.AsyncSyntheticDataGenerationResource(self) self.telemetry = telemetry.AsyncTelemetryResource(self) - self.datasetio = datasetio.AsyncDatasetioResource(self) self.scoring = scoring.AsyncScoringResource(self) self.scoring_functions = scoring_functions.AsyncScoringFunctionsResource(self) self.benchmarks = benchmarks.AsyncBenchmarksResource(self) @@ -525,7 +520,6 @@ def __init__(self, client: LlamaStackClient) -> None: client.synthetic_data_generation ) self.telemetry = telemetry.TelemetryResourceWithRawResponse(client.telemetry) - self.datasetio = datasetio.DatasetioResourceWithRawResponse(client.datasetio) self.scoring = scoring.ScoringResourceWithRawResponse(client.scoring) self.scoring_functions = scoring_functions.ScoringFunctionsResourceWithRawResponse(client.scoring_functions) self.benchmarks = benchmarks.BenchmarksResourceWithRawResponse(client.benchmarks) @@ -554,7 +548,6 @@ def __init__(self, client: AsyncLlamaStackClient) -> None: client.synthetic_data_generation ) self.telemetry = telemetry.AsyncTelemetryResourceWithRawResponse(client.telemetry) - self.datasetio = datasetio.AsyncDatasetioResourceWithRawResponse(client.datasetio) self.scoring = scoring.AsyncScoringResourceWithRawResponse(client.scoring) self.scoring_functions = scoring_functions.AsyncScoringFunctionsResourceWithRawResponse( client.scoring_functions @@ -585,7 +578,6 @@ def __init__(self, client: LlamaStackClient) -> None: client.synthetic_data_generation ) self.telemetry = telemetry.TelemetryResourceWithStreamingResponse(client.telemetry) - self.datasetio = datasetio.DatasetioResourceWithStreamingResponse(client.datasetio) self.scoring = scoring.ScoringResourceWithStreamingResponse(client.scoring) self.scoring_functions = scoring_functions.ScoringFunctionsResourceWithStreamingResponse( client.scoring_functions @@ -618,7 +610,6 @@ def __init__(self, client: AsyncLlamaStackClient) -> None: ) ) self.telemetry = telemetry.AsyncTelemetryResourceWithStreamingResponse(client.telemetry) - self.datasetio = datasetio.AsyncDatasetioResourceWithStreamingResponse(client.datasetio) self.scoring = scoring.AsyncScoringResourceWithStreamingResponse(client.scoring) self.scoring_functions = scoring_functions.AsyncScoringFunctionsResourceWithStreamingResponse( client.scoring_functions diff --git a/src/llama_stack_client/_models.py b/src/llama_stack_client/_models.py index c4401ff8..b51a1bf5 100644 --- a/src/llama_stack_client/_models.py +++ b/src/llama_stack_client/_models.py @@ -65,7 +65,7 @@ from ._constants import RAW_RESPONSE_HEADER if TYPE_CHECKING: - from pydantic_core.core_schema import ModelField, LiteralSchema, ModelFieldsSchema + from pydantic_core.core_schema import ModelField, ModelSchema, LiteralSchema, ModelFieldsSchema __all__ = ["BaseModel", "GenericModel"] @@ -646,15 +646,18 @@ def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, def _extract_field_schema_pv2(model: type[BaseModel], field_name: str) -> ModelField | None: schema = model.__pydantic_core_schema__ + if schema["type"] == "definitions": + schema = schema["schema"] + if schema["type"] != "model": return None + schema = cast("ModelSchema", schema) fields_schema = schema["schema"] if fields_schema["type"] != "model-fields": return None fields_schema = cast("ModelFieldsSchema", fields_schema) - field = fields_schema["fields"].get(field_name) if not field: return None diff --git a/src/llama_stack_client/pagination.py b/src/llama_stack_client/pagination.py new file mode 100644 index 00000000..c2f7fe80 --- /dev/null +++ b/src/llama_stack_client/pagination.py @@ -0,0 +1,50 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import List, Generic, TypeVar, Optional +from typing_extensions import override + +from ._base_client import BasePage, PageInfo, BaseSyncPage, BaseAsyncPage + +__all__ = ["SyncDatasetsIterrows", "AsyncDatasetsIterrows"] + +_T = TypeVar("_T") + + +class SyncDatasetsIterrows(BaseSyncPage[_T], BasePage[_T], Generic[_T]): + data: List[_T] + next_index: Optional[int] = None + + @override + def _get_page_items(self) -> List[_T]: + data = self.data + if not data: + return [] + return data + + @override + def next_page_info(self) -> Optional[PageInfo]: + next_index = self.next_index + if not next_index: + return None + + return PageInfo(params={"start_index": next_index}) + + +class AsyncDatasetsIterrows(BaseAsyncPage[_T], BasePage[_T], Generic[_T]): + data: List[_T] + next_index: Optional[int] = None + + @override + def _get_page_items(self) -> List[_T]: + data = self.data + if not data: + return [] + return data + + @override + def next_page_info(self) -> Optional[PageInfo]: + next_index = self.next_index + if not next_index: + return None + + return PageInfo(params={"start_index": next_index}) diff --git a/src/llama_stack_client/resources/__init__.py b/src/llama_stack_client/resources/__init__.py index 449fb4a1..865d77e0 100644 --- a/src/llama_stack_client/resources/__init__.py +++ b/src/llama_stack_client/resources/__init__.py @@ -80,14 +80,6 @@ DatasetsResourceWithStreamingResponse, AsyncDatasetsResourceWithStreamingResponse, ) -from .datasetio import ( - DatasetioResource, - AsyncDatasetioResource, - DatasetioResourceWithRawResponse, - AsyncDatasetioResourceWithRawResponse, - DatasetioResourceWithStreamingResponse, - AsyncDatasetioResourceWithStreamingResponse, -) from .inference import ( InferenceResource, AsyncInferenceResource, @@ -300,12 +292,6 @@ "AsyncTelemetryResourceWithRawResponse", "TelemetryResourceWithStreamingResponse", "AsyncTelemetryResourceWithStreamingResponse", - "DatasetioResource", - "AsyncDatasetioResource", - "DatasetioResourceWithRawResponse", - "AsyncDatasetioResourceWithRawResponse", - "DatasetioResourceWithStreamingResponse", - "AsyncDatasetioResourceWithStreamingResponse", "ScoringResource", "AsyncScoringResource", "ScoringResourceWithRawResponse", diff --git a/src/llama_stack_client/resources/datasetio.py b/src/llama_stack_client/resources/datasetio.py deleted file mode 100644 index 23577926..00000000 --- a/src/llama_stack_client/resources/datasetio.py +++ /dev/null @@ -1,300 +0,0 @@ -# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. - -from __future__ import annotations - -from typing import Dict, Union, Iterable - -import httpx - -from ..types import datasetio_append_rows_params, datasetio_get_rows_paginated_params -from .._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven -from .._utils import ( - maybe_transform, - async_maybe_transform, -) -from .._compat import cached_property -from .._resource import SyncAPIResource, AsyncAPIResource -from .._response import ( - to_raw_response_wrapper, - to_streamed_response_wrapper, - async_to_raw_response_wrapper, - async_to_streamed_response_wrapper, -) -from .._base_client import make_request_options -from ..types.paginated_rows_result import PaginatedRowsResult - -__all__ = ["DatasetioResource", "AsyncDatasetioResource"] - - -class DatasetioResource(SyncAPIResource): - @cached_property - def with_raw_response(self) -> DatasetioResourceWithRawResponse: - """ - This property can be used as a prefix for any HTTP method call to return - the raw response object instead of the parsed content. - - For more information, see https://www.github.com/stainless-sdks/llama-stack-python#accessing-raw-response-data-eg-headers - """ - return DatasetioResourceWithRawResponse(self) - - @cached_property - def with_streaming_response(self) -> DatasetioResourceWithStreamingResponse: - """ - An alternative to `.with_raw_response` that doesn't eagerly read the response body. - - For more information, see https://www.github.com/stainless-sdks/llama-stack-python#with_streaming_response - """ - return DatasetioResourceWithStreamingResponse(self) - - def append_rows( - self, - *, - dataset_id: str, - rows: Iterable[Dict[str, Union[bool, float, str, Iterable[object], object, None]]], - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. - # The extra values given here take precedence over values defined on the client or passed to this method. - extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> None: - """ - Args: - extra_headers: Send extra headers - - extra_query: Add additional query parameters to the request - - extra_body: Add additional JSON properties to the request - - timeout: Override the client-level default timeout for this request, in seconds - """ - extra_headers = {"Accept": "*/*", **(extra_headers or {})} - return self._post( - "/v1/datasetio/rows", - body=maybe_transform( - { - "dataset_id": dataset_id, - "rows": rows, - }, - datasetio_append_rows_params.DatasetioAppendRowsParams, - ), - options=make_request_options( - extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout - ), - cast_to=NoneType, - ) - - def get_rows_paginated( - self, - *, - dataset_id: str, - rows_in_page: int, - filter_condition: str | NotGiven = NOT_GIVEN, - page_token: str | NotGiven = NOT_GIVEN, - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. - # The extra values given here take precedence over values defined on the client or passed to this method. - extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> PaginatedRowsResult: - """ - Get a paginated list of rows from a dataset. - - Args: - dataset_id: The ID of the dataset to get the rows from. - - rows_in_page: The number of rows to get per page. - - filter_condition: (Optional) A condition to filter the rows by. - - page_token: The token to get the next page of rows. - - extra_headers: Send extra headers - - extra_query: Add additional query parameters to the request - - extra_body: Add additional JSON properties to the request - - timeout: Override the client-level default timeout for this request, in seconds - """ - return self._get( - "/v1/datasetio/rows", - options=make_request_options( - extra_headers=extra_headers, - extra_query=extra_query, - extra_body=extra_body, - timeout=timeout, - query=maybe_transform( - { - "dataset_id": dataset_id, - "rows_in_page": rows_in_page, - "filter_condition": filter_condition, - "page_token": page_token, - }, - datasetio_get_rows_paginated_params.DatasetioGetRowsPaginatedParams, - ), - ), - cast_to=PaginatedRowsResult, - ) - - -class AsyncDatasetioResource(AsyncAPIResource): - @cached_property - def with_raw_response(self) -> AsyncDatasetioResourceWithRawResponse: - """ - This property can be used as a prefix for any HTTP method call to return - the raw response object instead of the parsed content. - - For more information, see https://www.github.com/stainless-sdks/llama-stack-python#accessing-raw-response-data-eg-headers - """ - return AsyncDatasetioResourceWithRawResponse(self) - - @cached_property - def with_streaming_response(self) -> AsyncDatasetioResourceWithStreamingResponse: - """ - An alternative to `.with_raw_response` that doesn't eagerly read the response body. - - For more information, see https://www.github.com/stainless-sdks/llama-stack-python#with_streaming_response - """ - return AsyncDatasetioResourceWithStreamingResponse(self) - - async def append_rows( - self, - *, - dataset_id: str, - rows: Iterable[Dict[str, Union[bool, float, str, Iterable[object], object, None]]], - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. - # The extra values given here take precedence over values defined on the client or passed to this method. - extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> None: - """ - Args: - extra_headers: Send extra headers - - extra_query: Add additional query parameters to the request - - extra_body: Add additional JSON properties to the request - - timeout: Override the client-level default timeout for this request, in seconds - """ - extra_headers = {"Accept": "*/*", **(extra_headers or {})} - return await self._post( - "/v1/datasetio/rows", - body=await async_maybe_transform( - { - "dataset_id": dataset_id, - "rows": rows, - }, - datasetio_append_rows_params.DatasetioAppendRowsParams, - ), - options=make_request_options( - extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout - ), - cast_to=NoneType, - ) - - async def get_rows_paginated( - self, - *, - dataset_id: str, - rows_in_page: int, - filter_condition: str | NotGiven = NOT_GIVEN, - page_token: str | NotGiven = NOT_GIVEN, - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. - # The extra values given here take precedence over values defined on the client or passed to this method. - extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> PaginatedRowsResult: - """ - Get a paginated list of rows from a dataset. - - Args: - dataset_id: The ID of the dataset to get the rows from. - - rows_in_page: The number of rows to get per page. - - filter_condition: (Optional) A condition to filter the rows by. - - page_token: The token to get the next page of rows. - - extra_headers: Send extra headers - - extra_query: Add additional query parameters to the request - - extra_body: Add additional JSON properties to the request - - timeout: Override the client-level default timeout for this request, in seconds - """ - return await self._get( - "/v1/datasetio/rows", - options=make_request_options( - extra_headers=extra_headers, - extra_query=extra_query, - extra_body=extra_body, - timeout=timeout, - query=await async_maybe_transform( - { - "dataset_id": dataset_id, - "rows_in_page": rows_in_page, - "filter_condition": filter_condition, - "page_token": page_token, - }, - datasetio_get_rows_paginated_params.DatasetioGetRowsPaginatedParams, - ), - ), - cast_to=PaginatedRowsResult, - ) - - -class DatasetioResourceWithRawResponse: - def __init__(self, datasetio: DatasetioResource) -> None: - self._datasetio = datasetio - - self.append_rows = to_raw_response_wrapper( - datasetio.append_rows, - ) - self.get_rows_paginated = to_raw_response_wrapper( - datasetio.get_rows_paginated, - ) - - -class AsyncDatasetioResourceWithRawResponse: - def __init__(self, datasetio: AsyncDatasetioResource) -> None: - self._datasetio = datasetio - - self.append_rows = async_to_raw_response_wrapper( - datasetio.append_rows, - ) - self.get_rows_paginated = async_to_raw_response_wrapper( - datasetio.get_rows_paginated, - ) - - -class DatasetioResourceWithStreamingResponse: - def __init__(self, datasetio: DatasetioResource) -> None: - self._datasetio = datasetio - - self.append_rows = to_streamed_response_wrapper( - datasetio.append_rows, - ) - self.get_rows_paginated = to_streamed_response_wrapper( - datasetio.get_rows_paginated, - ) - - -class AsyncDatasetioResourceWithStreamingResponse: - def __init__(self, datasetio: AsyncDatasetioResource) -> None: - self._datasetio = datasetio - - self.append_rows = async_to_streamed_response_wrapper( - datasetio.append_rows, - ) - self.get_rows_paginated = async_to_streamed_response_wrapper( - datasetio.get_rows_paginated, - ) diff --git a/src/llama_stack_client/resources/datasets.py b/src/llama_stack_client/resources/datasets.py index 144769f9..c4c6dc94 100644 --- a/src/llama_stack_client/resources/datasets.py +++ b/src/llama_stack_client/resources/datasets.py @@ -3,10 +3,11 @@ from __future__ import annotations from typing import Dict, Type, Union, Iterable, Optional, cast +from typing_extensions import Literal import httpx -from ..types import dataset_register_params +from ..types import dataset_iterrows_params, dataset_register_params from .._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven from .._utils import ( maybe_transform, @@ -23,7 +24,8 @@ from .._wrappers import DataWrapper from .._base_client import make_request_options from ..types.dataset_list_response import DatasetListResponse -from ..types.shared_params.param_type import ParamType +from ..types.dataset_iterrows_response import DatasetIterrowsResponse +from ..types.dataset_register_response import DatasetRegisterResponse from ..types.dataset_retrieve_response import DatasetRetrieveResponse __all__ = ["DatasetsResource", "AsyncDatasetsResource"] @@ -102,24 +104,100 @@ def list( cast_to=cast(Type[DatasetListResponse], DataWrapper[DatasetListResponse]), ) + def iterrows( + self, + dataset_id: str, + *, + limit: int | NotGiven = NOT_GIVEN, + start_index: int | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> DatasetIterrowsResponse: + """Get a paginated list of rows from a dataset. + + Uses cursor-based pagination. + + Args: + limit: The number of rows to get per page. + + start_index: Index into dataset for the first row to get. Get all rows if None. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not dataset_id: + raise ValueError(f"Expected a non-empty value for `dataset_id` but received {dataset_id!r}") + return self._get( + f"/v1/datasetio/iterrows/{dataset_id}", + options=make_request_options( + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + query=maybe_transform( + { + "limit": limit, + "start_index": start_index, + }, + dataset_iterrows_params.DatasetIterrowsParams, + ), + ), + cast_to=DatasetIterrowsResponse, + ) + def register( self, *, - dataset_id: str, - dataset_schema: Dict[str, ParamType], - url: dataset_register_params.URL, + purpose: Literal["post-training/messages", "eval/question-answer", "eval/messages-answer"], + source: dataset_register_params.Source, + dataset_id: str | NotGiven = NOT_GIVEN, metadata: Dict[str, Union[bool, float, str, Iterable[object], object, None]] | NotGiven = NOT_GIVEN, - provider_dataset_id: str | NotGiven = NOT_GIVEN, - provider_id: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> None: - """ + ) -> DatasetRegisterResponse: + """Register a new dataset. + Args: + purpose: The purpose of the dataset. + + One of - "post-training/messages": The dataset + contains a messages column with list of messages for post-training. { + "messages": [ {"role": "user", "content": "Hello, world!"}, {"role": + "assistant", "content": "Hello, world!"}, ] } - "eval/question-answer": The + dataset contains a question column and an answer column for evaluation. { + "question": "What is the capital of France?", "answer": "Paris" } - + "eval/messages-answer": The dataset contains a messages column with list of + messages and an answer column for evaluation. { "messages": [ {"role": "user", + "content": "Hello, my name is John Doe."}, {"role": "assistant", "content": + "Hello, John Doe. How can I help you today?"}, {"role": "user", "content": + "What's my name?"}, ], "answer": "John Doe" } + + source: + The data source of the dataset. Examples: - { "type": "uri", "uri": + "https://mywebsite.com/mydata.jsonl" } - { "type": "uri", "uri": + "lsfs://mydata.jsonl" } - { "type": "uri", "uri": + "data:csv;base64,{base64_content}" } - { "type": "uri", "uri": + "huggingface://llamastack/simpleqa?split=train" } - { "type": "rows", "rows": [ + { "messages": [ {"role": "user", "content": "Hello, world!"}, {"role": + "assistant", "content": "Hello, world!"}, ] } ] } + + dataset_id: The ID of the dataset. If not provided, an ID will be generated. + + metadata: The metadata for the dataset. - E.g. {"description": "My dataset"} + extra_headers: Send extra headers extra_query: Add additional query parameters to the request @@ -128,24 +206,21 @@ def register( timeout: Override the client-level default timeout for this request, in seconds """ - extra_headers = {"Accept": "*/*", **(extra_headers or {})} return self._post( "/v1/datasets", body=maybe_transform( { + "purpose": purpose, + "source": source, "dataset_id": dataset_id, - "dataset_schema": dataset_schema, - "url": url, "metadata": metadata, - "provider_dataset_id": provider_dataset_id, - "provider_id": provider_id, }, dataset_register_params.DatasetRegisterParams, ), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), - cast_to=NoneType, + cast_to=DatasetRegisterResponse, ) def unregister( @@ -254,24 +329,100 @@ async def list( cast_to=cast(Type[DatasetListResponse], DataWrapper[DatasetListResponse]), ) + async def iterrows( + self, + dataset_id: str, + *, + limit: int | NotGiven = NOT_GIVEN, + start_index: int | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> DatasetIterrowsResponse: + """Get a paginated list of rows from a dataset. + + Uses cursor-based pagination. + + Args: + limit: The number of rows to get per page. + + start_index: Index into dataset for the first row to get. Get all rows if None. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not dataset_id: + raise ValueError(f"Expected a non-empty value for `dataset_id` but received {dataset_id!r}") + return await self._get( + f"/v1/datasetio/iterrows/{dataset_id}", + options=make_request_options( + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + query=await async_maybe_transform( + { + "limit": limit, + "start_index": start_index, + }, + dataset_iterrows_params.DatasetIterrowsParams, + ), + ), + cast_to=DatasetIterrowsResponse, + ) + async def register( self, *, - dataset_id: str, - dataset_schema: Dict[str, ParamType], - url: dataset_register_params.URL, + purpose: Literal["post-training/messages", "eval/question-answer", "eval/messages-answer"], + source: dataset_register_params.Source, + dataset_id: str | NotGiven = NOT_GIVEN, metadata: Dict[str, Union[bool, float, str, Iterable[object], object, None]] | NotGiven = NOT_GIVEN, - provider_dataset_id: str | NotGiven = NOT_GIVEN, - provider_id: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> None: - """ + ) -> DatasetRegisterResponse: + """Register a new dataset. + Args: + purpose: The purpose of the dataset. + + One of - "post-training/messages": The dataset + contains a messages column with list of messages for post-training. { + "messages": [ {"role": "user", "content": "Hello, world!"}, {"role": + "assistant", "content": "Hello, world!"}, ] } - "eval/question-answer": The + dataset contains a question column and an answer column for evaluation. { + "question": "What is the capital of France?", "answer": "Paris" } - + "eval/messages-answer": The dataset contains a messages column with list of + messages and an answer column for evaluation. { "messages": [ {"role": "user", + "content": "Hello, my name is John Doe."}, {"role": "assistant", "content": + "Hello, John Doe. How can I help you today?"}, {"role": "user", "content": + "What's my name?"}, ], "answer": "John Doe" } + + source: + The data source of the dataset. Examples: - { "type": "uri", "uri": + "https://mywebsite.com/mydata.jsonl" } - { "type": "uri", "uri": + "lsfs://mydata.jsonl" } - { "type": "uri", "uri": + "data:csv;base64,{base64_content}" } - { "type": "uri", "uri": + "huggingface://llamastack/simpleqa?split=train" } - { "type": "rows", "rows": [ + { "messages": [ {"role": "user", "content": "Hello, world!"}, {"role": + "assistant", "content": "Hello, world!"}, ] } ] } + + dataset_id: The ID of the dataset. If not provided, an ID will be generated. + + metadata: The metadata for the dataset. - E.g. {"description": "My dataset"} + extra_headers: Send extra headers extra_query: Add additional query parameters to the request @@ -280,24 +431,21 @@ async def register( timeout: Override the client-level default timeout for this request, in seconds """ - extra_headers = {"Accept": "*/*", **(extra_headers or {})} return await self._post( "/v1/datasets", body=await async_maybe_transform( { + "purpose": purpose, + "source": source, "dataset_id": dataset_id, - "dataset_schema": dataset_schema, - "url": url, "metadata": metadata, - "provider_dataset_id": provider_dataset_id, - "provider_id": provider_id, }, dataset_register_params.DatasetRegisterParams, ), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), - cast_to=NoneType, + cast_to=DatasetRegisterResponse, ) async def unregister( @@ -343,6 +491,9 @@ def __init__(self, datasets: DatasetsResource) -> None: self.list = to_raw_response_wrapper( datasets.list, ) + self.iterrows = to_raw_response_wrapper( + datasets.iterrows, + ) self.register = to_raw_response_wrapper( datasets.register, ) @@ -361,6 +512,9 @@ def __init__(self, datasets: AsyncDatasetsResource) -> None: self.list = async_to_raw_response_wrapper( datasets.list, ) + self.iterrows = async_to_raw_response_wrapper( + datasets.iterrows, + ) self.register = async_to_raw_response_wrapper( datasets.register, ) @@ -379,6 +533,9 @@ def __init__(self, datasets: DatasetsResource) -> None: self.list = to_streamed_response_wrapper( datasets.list, ) + self.iterrows = to_streamed_response_wrapper( + datasets.iterrows, + ) self.register = to_streamed_response_wrapper( datasets.register, ) @@ -397,6 +554,9 @@ def __init__(self, datasets: AsyncDatasetsResource) -> None: self.list = async_to_streamed_response_wrapper( datasets.list, ) + self.iterrows = async_to_streamed_response_wrapper( + datasets.iterrows, + ) self.register = async_to_streamed_response_wrapper( datasets.register, ) diff --git a/src/llama_stack_client/types/__init__.py b/src/llama_stack_client/types/__init__.py index f0d6c2e2..b45996a9 100644 --- a/src/llama_stack_client/types/__init__.py +++ b/src/llama_stack_client/types/__init__.py @@ -74,7 +74,6 @@ from .list_shields_response import ListShieldsResponse as ListShieldsResponse from .memory_retrieval_step import MemoryRetrievalStep as MemoryRetrievalStep from .model_register_params import ModelRegisterParams as ModelRegisterParams -from .paginated_rows_result import PaginatedRowsResult as PaginatedRowsResult from .query_chunks_response import QueryChunksResponse as QueryChunksResponse from .query_condition_param import QueryConditionParam as QueryConditionParam from .algorithm_config_param import AlgorithmConfigParam as AlgorithmConfigParam @@ -86,6 +85,7 @@ from .tool_invocation_result import ToolInvocationResult as ToolInvocationResult from .vector_io_query_params import VectorIoQueryParams as VectorIoQueryParams from .benchmark_list_response import BenchmarkListResponse as BenchmarkListResponse +from .dataset_iterrows_params import DatasetIterrowsParams as DatasetIterrowsParams from .dataset_register_params import DatasetRegisterParams as DatasetRegisterParams from .list_providers_response import ListProvidersResponse as ListProvidersResponse from .scoring_fn_params_param import ScoringFnParamsParam as ScoringFnParamsParam @@ -96,6 +96,8 @@ from .list_vector_dbs_response import ListVectorDBsResponse as ListVectorDBsResponse from .safety_run_shield_params import SafetyRunShieldParams as SafetyRunShieldParams from .benchmark_register_params import BenchmarkRegisterParams as BenchmarkRegisterParams +from .dataset_iterrows_response import DatasetIterrowsResponse as DatasetIterrowsResponse +from .dataset_register_response import DatasetRegisterResponse as DatasetRegisterResponse from .dataset_retrieve_response import DatasetRetrieveResponse as DatasetRetrieveResponse from .eval_evaluate_rows_params import EvalEvaluateRowsParams as EvalEvaluateRowsParams from .list_tool_groups_response import ListToolGroupsResponse as ListToolGroupsResponse @@ -109,7 +111,6 @@ from .telemetry_get_span_response import TelemetryGetSpanResponse as TelemetryGetSpanResponse from .vector_db_register_response import VectorDBRegisterResponse as VectorDBRegisterResponse from .vector_db_retrieve_response import VectorDBRetrieveResponse as VectorDBRetrieveResponse -from .datasetio_append_rows_params import DatasetioAppendRowsParams as DatasetioAppendRowsParams from .scoring_score_batch_response import ScoringScoreBatchResponse as ScoringScoreBatchResponse from .telemetry_query_spans_params import TelemetryQuerySpansParams as TelemetryQuerySpansParams from .telemetry_query_traces_params import TelemetryQueryTracesParams as TelemetryQueryTracesParams @@ -127,7 +128,6 @@ from .telemetry_get_span_tree_response import TelemetryGetSpanTreeResponse as TelemetryGetSpanTreeResponse from .batch_inference_completion_params import BatchInferenceCompletionParams as BatchInferenceCompletionParams from .synthetic_data_generation_response import SyntheticDataGenerationResponse as SyntheticDataGenerationResponse -from .datasetio_get_rows_paginated_params import DatasetioGetRowsPaginatedParams as DatasetioGetRowsPaginatedParams from .chat_completion_response_stream_chunk import ( ChatCompletionResponseStreamChunk as ChatCompletionResponseStreamChunk, ) diff --git a/src/llama_stack_client/types/dataset_iterrows_params.py b/src/llama_stack_client/types/dataset_iterrows_params.py new file mode 100644 index 00000000..5c38d7c1 --- /dev/null +++ b/src/llama_stack_client/types/dataset_iterrows_params.py @@ -0,0 +1,15 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing_extensions import TypedDict + +__all__ = ["DatasetIterrowsParams"] + + +class DatasetIterrowsParams(TypedDict, total=False): + limit: int + """The number of rows to get per page.""" + + start_index: int + """Index into dataset for the first row to get. Get all rows if None.""" diff --git a/src/llama_stack_client/types/dataset_iterrows_response.py b/src/llama_stack_client/types/dataset_iterrows_response.py new file mode 100644 index 00000000..f82233b5 --- /dev/null +++ b/src/llama_stack_client/types/dataset_iterrows_response.py @@ -0,0 +1,18 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import Dict, List, Union, Optional + +from .._models import BaseModel + +__all__ = ["DatasetIterrowsResponse"] + + +class DatasetIterrowsResponse(BaseModel): + data: List[Dict[str, Union[bool, float, str, List[object], object, None]]] + """The rows in the current page.""" + + next_index: Optional[int] = None + """Index into dataset for the first row in the next page. + + None if there are no more rows. + """ diff --git a/src/llama_stack_client/types/dataset_list_response.py b/src/llama_stack_client/types/dataset_list_response.py index 1dc2afa4..902c6274 100644 --- a/src/llama_stack_client/types/dataset_list_response.py +++ b/src/llama_stack_client/types/dataset_list_response.py @@ -1,21 +1,49 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. from typing import Dict, List, Union -from typing_extensions import Literal, TypeAlias +from typing_extensions import Literal, Annotated, TypeAlias +from .._utils import PropertyInfo from .._models import BaseModel -from .shared.param_type import ParamType -__all__ = ["DatasetListResponse", "DatasetListResponseItem", "DatasetListResponseItemURL"] +__all__ = [ + "DatasetListResponse", + "DatasetListResponseItem", + "DatasetListResponseItemSource", + "DatasetListResponseItemSourceUriDataSource", + "DatasetListResponseItemSourceRowsDataSource", +] -class DatasetListResponseItemURL(BaseModel): +class DatasetListResponseItemSourceUriDataSource(BaseModel): + type: Literal["uri"] + uri: str + """The dataset can be obtained from a URI. + E.g. - "https://mywebsite.com/mydata.jsonl" - "lsfs://mydata.jsonl" - + "data:csv;base64,{base64_content}" + """ + + +class DatasetListResponseItemSourceRowsDataSource(BaseModel): + rows: List[Dict[str, Union[bool, float, str, List[object], object, None]]] + """The dataset is stored in rows. + + E.g. - [ {"messages": [{"role": "user", "content": "Hello, world!"}, {"role": + "assistant", "content": "Hello, world!"}]} ] + """ + + type: Literal["rows"] -class DatasetListResponseItem(BaseModel): - dataset_schema: Dict[str, ParamType] +DatasetListResponseItemSource: TypeAlias = Annotated[ + Union[DatasetListResponseItemSourceUriDataSource, DatasetListResponseItemSourceRowsDataSource], + PropertyInfo(discriminator="type"), +] + + +class DatasetListResponseItem(BaseModel): identifier: str metadata: Dict[str, Union[bool, float, str, List[object], object, None]] @@ -24,9 +52,13 @@ class DatasetListResponseItem(BaseModel): provider_resource_id: str - type: Literal["dataset"] + purpose: Literal["post-training/messages", "eval/question-answer", "eval/messages-answer"] + """Purpose of the dataset. Each purpose has a required input data schema.""" + + source: DatasetListResponseItemSource + """A dataset that can be obtained from a URI.""" - url: DatasetListResponseItemURL + type: Literal["dataset"] DatasetListResponse: TypeAlias = List[DatasetListResponseItem] diff --git a/src/llama_stack_client/types/dataset_register_params.py b/src/llama_stack_client/types/dataset_register_params.py index 1c1cf234..d2ff9d3a 100644 --- a/src/llama_stack_client/types/dataset_register_params.py +++ b/src/llama_stack_client/types/dataset_register_params.py @@ -3,26 +3,65 @@ from __future__ import annotations from typing import Dict, Union, Iterable -from typing_extensions import Required, TypedDict +from typing_extensions import Literal, Required, TypeAlias, TypedDict -from .shared_params.param_type import ParamType - -__all__ = ["DatasetRegisterParams", "URL"] +__all__ = ["DatasetRegisterParams", "Source", "SourceUriDataSource", "SourceRowsDataSource"] class DatasetRegisterParams(TypedDict, total=False): - dataset_id: Required[str] + purpose: Required[Literal["post-training/messages", "eval/question-answer", "eval/messages-answer"]] + """The purpose of the dataset. - dataset_schema: Required[Dict[str, ParamType]] + One of - "post-training/messages": The dataset contains a messages column with + list of messages for post-training. { "messages": [ {"role": "user", "content": + "Hello, world!"}, {"role": "assistant", "content": "Hello, world!"}, ] } - + "eval/question-answer": The dataset contains a question column and an answer + column for evaluation. { "question": "What is the capital of France?", "answer": + "Paris" } - "eval/messages-answer": The dataset contains a messages column with + list of messages and an answer column for evaluation. { "messages": [ {"role": + "user", "content": "Hello, my name is John Doe."}, {"role": "assistant", + "content": "Hello, John Doe. How can I help you today?"}, {"role": "user", + "content": "What's my name?"}, ], "answer": "John Doe" } + """ - url: Required[URL] + source: Required[Source] + """The data source of the dataset. - metadata: Dict[str, Union[bool, float, str, Iterable[object], object, None]] + Examples: - { "type": "uri", "uri": "https://mywebsite.com/mydata.jsonl" } - { + "type": "uri", "uri": "lsfs://mydata.jsonl" } - { "type": "uri", "uri": + "data:csv;base64,{base64_content}" } - { "type": "uri", "uri": + "huggingface://llamastack/simpleqa?split=train" } - { "type": "rows", "rows": [ + { "messages": [ {"role": "user", "content": "Hello, world!"}, {"role": + "assistant", "content": "Hello, world!"}, ] } ] } + """ + + dataset_id: str + """The ID of the dataset. If not provided, an ID will be generated.""" - provider_dataset_id: str + metadata: Dict[str, Union[bool, float, str, Iterable[object], object, None]] + """The metadata for the dataset. - E.g. {"description": "My dataset"}""" - provider_id: str +class SourceUriDataSource(TypedDict, total=False): + type: Required[Literal["uri"]] -class URL(TypedDict, total=False): uri: Required[str] + """The dataset can be obtained from a URI. + + E.g. - "https://mywebsite.com/mydata.jsonl" - "lsfs://mydata.jsonl" - + "data:csv;base64,{base64_content}" + """ + + +class SourceRowsDataSource(TypedDict, total=False): + rows: Required[Iterable[Dict[str, Union[bool, float, str, Iterable[object], object, None]]]] + """The dataset is stored in rows. + + E.g. - [ {"messages": [{"role": "user", "content": "Hello, world!"}, {"role": + "assistant", "content": "Hello, world!"}]} ] + """ + + type: Required[Literal["rows"]] + + +Source: TypeAlias = Union[SourceUriDataSource, SourceRowsDataSource] diff --git a/src/llama_stack_client/types/dataset_register_response.py b/src/llama_stack_client/types/dataset_register_response.py new file mode 100644 index 00000000..8038b192 --- /dev/null +++ b/src/llama_stack_client/types/dataset_register_response.py @@ -0,0 +1,52 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import Dict, List, Union +from typing_extensions import Literal, Annotated, TypeAlias + +from .._utils import PropertyInfo +from .._models import BaseModel + +__all__ = ["DatasetRegisterResponse", "Source", "SourceUriDataSource", "SourceRowsDataSource"] + + +class SourceUriDataSource(BaseModel): + type: Literal["uri"] + + uri: str + """The dataset can be obtained from a URI. + + E.g. - "https://mywebsite.com/mydata.jsonl" - "lsfs://mydata.jsonl" - + "data:csv;base64,{base64_content}" + """ + + +class SourceRowsDataSource(BaseModel): + rows: List[Dict[str, Union[bool, float, str, List[object], object, None]]] + """The dataset is stored in rows. + + E.g. - [ {"messages": [{"role": "user", "content": "Hello, world!"}, {"role": + "assistant", "content": "Hello, world!"}]} ] + """ + + type: Literal["rows"] + + +Source: TypeAlias = Annotated[Union[SourceUriDataSource, SourceRowsDataSource], PropertyInfo(discriminator="type")] + + +class DatasetRegisterResponse(BaseModel): + identifier: str + + metadata: Dict[str, Union[bool, float, str, List[object], object, None]] + + provider_id: str + + provider_resource_id: str + + purpose: Literal["post-training/messages", "eval/question-answer", "eval/messages-answer"] + """Purpose of the dataset. Each purpose has a required input data schema.""" + + source: Source + """A dataset that can be obtained from a URI.""" + + type: Literal["dataset"] diff --git a/src/llama_stack_client/types/dataset_retrieve_response.py b/src/llama_stack_client/types/dataset_retrieve_response.py index bd819a56..debce418 100644 --- a/src/llama_stack_client/types/dataset_retrieve_response.py +++ b/src/llama_stack_client/types/dataset_retrieve_response.py @@ -1,21 +1,40 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. from typing import Dict, List, Union -from typing_extensions import Literal +from typing_extensions import Literal, Annotated, TypeAlias +from .._utils import PropertyInfo from .._models import BaseModel -from .shared.param_type import ParamType -__all__ = ["DatasetRetrieveResponse", "URL"] +__all__ = ["DatasetRetrieveResponse", "Source", "SourceUriDataSource", "SourceRowsDataSource"] -class URL(BaseModel): +class SourceUriDataSource(BaseModel): + type: Literal["uri"] + uri: str + """The dataset can be obtained from a URI. + E.g. - "https://mywebsite.com/mydata.jsonl" - "lsfs://mydata.jsonl" - + "data:csv;base64,{base64_content}" + """ + + +class SourceRowsDataSource(BaseModel): + rows: List[Dict[str, Union[bool, float, str, List[object], object, None]]] + """The dataset is stored in rows. + + E.g. - [ {"messages": [{"role": "user", "content": "Hello, world!"}, {"role": + "assistant", "content": "Hello, world!"}]} ] + """ + + type: Literal["rows"] -class DatasetRetrieveResponse(BaseModel): - dataset_schema: Dict[str, ParamType] +Source: TypeAlias = Annotated[Union[SourceUriDataSource, SourceRowsDataSource], PropertyInfo(discriminator="type")] + + +class DatasetRetrieveResponse(BaseModel): identifier: str metadata: Dict[str, Union[bool, float, str, List[object], object, None]] @@ -24,6 +43,10 @@ class DatasetRetrieveResponse(BaseModel): provider_resource_id: str - type: Literal["dataset"] + purpose: Literal["post-training/messages", "eval/question-answer", "eval/messages-answer"] + """Purpose of the dataset. Each purpose has a required input data schema.""" + + source: Source + """A dataset that can be obtained from a URI.""" - url: URL + type: Literal["dataset"] diff --git a/src/llama_stack_client/types/datasetio_append_rows_params.py b/src/llama_stack_client/types/datasetio_append_rows_params.py deleted file mode 100644 index 2378454c..00000000 --- a/src/llama_stack_client/types/datasetio_append_rows_params.py +++ /dev/null @@ -1,14 +0,0 @@ -# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. - -from __future__ import annotations - -from typing import Dict, Union, Iterable -from typing_extensions import Required, TypedDict - -__all__ = ["DatasetioAppendRowsParams"] - - -class DatasetioAppendRowsParams(TypedDict, total=False): - dataset_id: Required[str] - - rows: Required[Iterable[Dict[str, Union[bool, float, str, Iterable[object], object, None]]]] diff --git a/src/llama_stack_client/types/datasetio_get_rows_paginated_params.py b/src/llama_stack_client/types/datasetio_get_rows_paginated_params.py deleted file mode 100644 index 7566c992..00000000 --- a/src/llama_stack_client/types/datasetio_get_rows_paginated_params.py +++ /dev/null @@ -1,21 +0,0 @@ -# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. - -from __future__ import annotations - -from typing_extensions import Required, TypedDict - -__all__ = ["DatasetioGetRowsPaginatedParams"] - - -class DatasetioGetRowsPaginatedParams(TypedDict, total=False): - dataset_id: Required[str] - """The ID of the dataset to get the rows from.""" - - rows_in_page: Required[int] - """The number of rows to get per page.""" - - filter_condition: str - """(Optional) A condition to filter the rows by.""" - - page_token: str - """The token to get the next page of rows.""" diff --git a/src/llama_stack_client/types/paginated_rows_result.py b/src/llama_stack_client/types/paginated_rows_result.py deleted file mode 100644 index 4eccb803..00000000 --- a/src/llama_stack_client/types/paginated_rows_result.py +++ /dev/null @@ -1,18 +0,0 @@ -# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. - -from typing import Dict, List, Union, Optional - -from .._models import BaseModel - -__all__ = ["PaginatedRowsResult"] - - -class PaginatedRowsResult(BaseModel): - rows: List[Dict[str, Union[bool, float, str, List[object], object, None]]] - """The rows in the current page.""" - - total_count: int - """The total number of rows in the dataset.""" - - next_page_token: Optional[str] = None - """The token to get the next page of rows.""" diff --git a/src/llama_stack_client/types/shared_params/__init__.py b/src/llama_stack_client/types/shared_params/__init__.py index d647c238..bd623812 100644 --- a/src/llama_stack_client/types/shared_params/__init__.py +++ b/src/llama_stack_client/types/shared_params/__init__.py @@ -3,7 +3,6 @@ from .message import Message as Message from .document import Document as Document from .tool_call import ToolCall as ToolCall -from .param_type import ParamType as ParamType from .return_type import ReturnType as ReturnType from .agent_config import AgentConfig as AgentConfig from .query_config import QueryConfig as QueryConfig diff --git a/src/llama_stack_client/types/shared_params/param_type.py b/src/llama_stack_client/types/shared_params/param_type.py deleted file mode 100644 index b93dfeff..00000000 --- a/src/llama_stack_client/types/shared_params/param_type.py +++ /dev/null @@ -1,74 +0,0 @@ -# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. - -from __future__ import annotations - -from typing import Union -from typing_extensions import Literal, Required, TypeAlias, TypedDict - -__all__ = [ - "ParamType", - "StringType", - "NumberType", - "BooleanType", - "ArrayType", - "ObjectType", - "JsonType", - "UnionType", - "ChatCompletionInputType", - "CompletionInputType", - "AgentTurnInputType", -] - - -class StringType(TypedDict, total=False): - type: Required[Literal["string"]] - - -class NumberType(TypedDict, total=False): - type: Required[Literal["number"]] - - -class BooleanType(TypedDict, total=False): - type: Required[Literal["boolean"]] - - -class ArrayType(TypedDict, total=False): - type: Required[Literal["array"]] - - -class ObjectType(TypedDict, total=False): - type: Required[Literal["object"]] - - -class JsonType(TypedDict, total=False): - type: Required[Literal["json"]] - - -class UnionType(TypedDict, total=False): - type: Required[Literal["union"]] - - -class ChatCompletionInputType(TypedDict, total=False): - type: Required[Literal["chat_completion_input"]] - - -class CompletionInputType(TypedDict, total=False): - type: Required[Literal["completion_input"]] - - -class AgentTurnInputType(TypedDict, total=False): - type: Required[Literal["agent_turn_input"]] - - -ParamType: TypeAlias = Union[ - StringType, - NumberType, - BooleanType, - ArrayType, - ObjectType, - JsonType, - UnionType, - ChatCompletionInputType, - CompletionInputType, - AgentTurnInputType, -] diff --git a/tests/api_resources/test_datasetio.py b/tests/api_resources/test_datasetio.py deleted file mode 100644 index cfd72d94..00000000 --- a/tests/api_resources/test_datasetio.py +++ /dev/null @@ -1,180 +0,0 @@ -# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. - -from __future__ import annotations - -import os -from typing import Any, cast - -import pytest - -from tests.utils import assert_matches_type -from llama_stack_client import LlamaStackClient, AsyncLlamaStackClient -from llama_stack_client.types import ( - PaginatedRowsResult, -) - -base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") - - -class TestDatasetio: - parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) - - @parametrize - def test_method_append_rows(self, client: LlamaStackClient) -> None: - datasetio = client.datasetio.append_rows( - dataset_id="dataset_id", - rows=[{"foo": True}], - ) - assert datasetio is None - - @parametrize - def test_raw_response_append_rows(self, client: LlamaStackClient) -> None: - response = client.datasetio.with_raw_response.append_rows( - dataset_id="dataset_id", - rows=[{"foo": True}], - ) - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - datasetio = response.parse() - assert datasetio is None - - @parametrize - def test_streaming_response_append_rows(self, client: LlamaStackClient) -> None: - with client.datasetio.with_streaming_response.append_rows( - dataset_id="dataset_id", - rows=[{"foo": True}], - ) as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - datasetio = response.parse() - assert datasetio is None - - assert cast(Any, response.is_closed) is True - - @parametrize - def test_method_get_rows_paginated(self, client: LlamaStackClient) -> None: - datasetio = client.datasetio.get_rows_paginated( - dataset_id="dataset_id", - rows_in_page=0, - ) - assert_matches_type(PaginatedRowsResult, datasetio, path=["response"]) - - @parametrize - def test_method_get_rows_paginated_with_all_params(self, client: LlamaStackClient) -> None: - datasetio = client.datasetio.get_rows_paginated( - dataset_id="dataset_id", - rows_in_page=0, - filter_condition="filter_condition", - page_token="page_token", - ) - assert_matches_type(PaginatedRowsResult, datasetio, path=["response"]) - - @parametrize - def test_raw_response_get_rows_paginated(self, client: LlamaStackClient) -> None: - response = client.datasetio.with_raw_response.get_rows_paginated( - dataset_id="dataset_id", - rows_in_page=0, - ) - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - datasetio = response.parse() - assert_matches_type(PaginatedRowsResult, datasetio, path=["response"]) - - @parametrize - def test_streaming_response_get_rows_paginated(self, client: LlamaStackClient) -> None: - with client.datasetio.with_streaming_response.get_rows_paginated( - dataset_id="dataset_id", - rows_in_page=0, - ) as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - datasetio = response.parse() - assert_matches_type(PaginatedRowsResult, datasetio, path=["response"]) - - assert cast(Any, response.is_closed) is True - - -class TestAsyncDatasetio: - parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) - - @parametrize - async def test_method_append_rows(self, async_client: AsyncLlamaStackClient) -> None: - datasetio = await async_client.datasetio.append_rows( - dataset_id="dataset_id", - rows=[{"foo": True}], - ) - assert datasetio is None - - @parametrize - async def test_raw_response_append_rows(self, async_client: AsyncLlamaStackClient) -> None: - response = await async_client.datasetio.with_raw_response.append_rows( - dataset_id="dataset_id", - rows=[{"foo": True}], - ) - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - datasetio = await response.parse() - assert datasetio is None - - @parametrize - async def test_streaming_response_append_rows(self, async_client: AsyncLlamaStackClient) -> None: - async with async_client.datasetio.with_streaming_response.append_rows( - dataset_id="dataset_id", - rows=[{"foo": True}], - ) as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - datasetio = await response.parse() - assert datasetio is None - - assert cast(Any, response.is_closed) is True - - @parametrize - async def test_method_get_rows_paginated(self, async_client: AsyncLlamaStackClient) -> None: - datasetio = await async_client.datasetio.get_rows_paginated( - dataset_id="dataset_id", - rows_in_page=0, - ) - assert_matches_type(PaginatedRowsResult, datasetio, path=["response"]) - - @parametrize - async def test_method_get_rows_paginated_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: - datasetio = await async_client.datasetio.get_rows_paginated( - dataset_id="dataset_id", - rows_in_page=0, - filter_condition="filter_condition", - page_token="page_token", - ) - assert_matches_type(PaginatedRowsResult, datasetio, path=["response"]) - - @parametrize - async def test_raw_response_get_rows_paginated(self, async_client: AsyncLlamaStackClient) -> None: - response = await async_client.datasetio.with_raw_response.get_rows_paginated( - dataset_id="dataset_id", - rows_in_page=0, - ) - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - datasetio = await response.parse() - assert_matches_type(PaginatedRowsResult, datasetio, path=["response"]) - - @parametrize - async def test_streaming_response_get_rows_paginated(self, async_client: AsyncLlamaStackClient) -> None: - async with async_client.datasetio.with_streaming_response.get_rows_paginated( - dataset_id="dataset_id", - rows_in_page=0, - ) as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - datasetio = await response.parse() - assert_matches_type(PaginatedRowsResult, datasetio, path=["response"]) - - assert cast(Any, response.is_closed) is True diff --git a/tests/api_resources/test_datasets.py b/tests/api_resources/test_datasets.py index 987f3c22..7f19e741 100644 --- a/tests/api_resources/test_datasets.py +++ b/tests/api_resources/test_datasets.py @@ -9,7 +9,12 @@ from tests.utils import assert_matches_type from llama_stack_client import LlamaStackClient, AsyncLlamaStackClient -from llama_stack_client.types import DatasetListResponse, DatasetRetrieveResponse +from llama_stack_client.types import ( + DatasetListResponse, + DatasetIterrowsResponse, + DatasetRegisterResponse, + DatasetRetrieveResponse, +) base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") @@ -80,52 +85,106 @@ def test_streaming_response_list(self, client: LlamaStackClient) -> None: assert cast(Any, response.is_closed) is True + @parametrize + def test_method_iterrows(self, client: LlamaStackClient) -> None: + dataset = client.datasets.iterrows( + dataset_id="dataset_id", + ) + assert_matches_type(DatasetIterrowsResponse, dataset, path=["response"]) + + @parametrize + def test_method_iterrows_with_all_params(self, client: LlamaStackClient) -> None: + dataset = client.datasets.iterrows( + dataset_id="dataset_id", + limit=0, + start_index=0, + ) + assert_matches_type(DatasetIterrowsResponse, dataset, path=["response"]) + + @parametrize + def test_raw_response_iterrows(self, client: LlamaStackClient) -> None: + response = client.datasets.with_raw_response.iterrows( + dataset_id="dataset_id", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + dataset = response.parse() + assert_matches_type(DatasetIterrowsResponse, dataset, path=["response"]) + + @parametrize + def test_streaming_response_iterrows(self, client: LlamaStackClient) -> None: + with client.datasets.with_streaming_response.iterrows( + dataset_id="dataset_id", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + dataset = response.parse() + assert_matches_type(DatasetIterrowsResponse, dataset, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_path_params_iterrows(self, client: LlamaStackClient) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `dataset_id` but received ''"): + client.datasets.with_raw_response.iterrows( + dataset_id="", + ) + @parametrize def test_method_register(self, client: LlamaStackClient) -> None: dataset = client.datasets.register( - dataset_id="dataset_id", - dataset_schema={"foo": {"type": "string"}}, - url={"uri": "uri"}, + purpose="post-training/messages", + source={ + "type": "uri", + "uri": "uri", + }, ) - assert dataset is None + assert_matches_type(DatasetRegisterResponse, dataset, path=["response"]) @parametrize def test_method_register_with_all_params(self, client: LlamaStackClient) -> None: dataset = client.datasets.register( + purpose="post-training/messages", + source={ + "type": "uri", + "uri": "uri", + }, dataset_id="dataset_id", - dataset_schema={"foo": {"type": "string"}}, - url={"uri": "uri"}, metadata={"foo": True}, - provider_dataset_id="provider_dataset_id", - provider_id="provider_id", ) - assert dataset is None + assert_matches_type(DatasetRegisterResponse, dataset, path=["response"]) @parametrize def test_raw_response_register(self, client: LlamaStackClient) -> None: response = client.datasets.with_raw_response.register( - dataset_id="dataset_id", - dataset_schema={"foo": {"type": "string"}}, - url={"uri": "uri"}, + purpose="post-training/messages", + source={ + "type": "uri", + "uri": "uri", + }, ) assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" dataset = response.parse() - assert dataset is None + assert_matches_type(DatasetRegisterResponse, dataset, path=["response"]) @parametrize def test_streaming_response_register(self, client: LlamaStackClient) -> None: with client.datasets.with_streaming_response.register( - dataset_id="dataset_id", - dataset_schema={"foo": {"type": "string"}}, - url={"uri": "uri"}, + purpose="post-training/messages", + source={ + "type": "uri", + "uri": "uri", + }, ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" dataset = response.parse() - assert dataset is None + assert_matches_type(DatasetRegisterResponse, dataset, path=["response"]) assert cast(Any, response.is_closed) is True @@ -234,52 +293,106 @@ async def test_streaming_response_list(self, async_client: AsyncLlamaStackClient assert cast(Any, response.is_closed) is True + @parametrize + async def test_method_iterrows(self, async_client: AsyncLlamaStackClient) -> None: + dataset = await async_client.datasets.iterrows( + dataset_id="dataset_id", + ) + assert_matches_type(DatasetIterrowsResponse, dataset, path=["response"]) + + @parametrize + async def test_method_iterrows_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: + dataset = await async_client.datasets.iterrows( + dataset_id="dataset_id", + limit=0, + start_index=0, + ) + assert_matches_type(DatasetIterrowsResponse, dataset, path=["response"]) + + @parametrize + async def test_raw_response_iterrows(self, async_client: AsyncLlamaStackClient) -> None: + response = await async_client.datasets.with_raw_response.iterrows( + dataset_id="dataset_id", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + dataset = await response.parse() + assert_matches_type(DatasetIterrowsResponse, dataset, path=["response"]) + + @parametrize + async def test_streaming_response_iterrows(self, async_client: AsyncLlamaStackClient) -> None: + async with async_client.datasets.with_streaming_response.iterrows( + dataset_id="dataset_id", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + dataset = await response.parse() + assert_matches_type(DatasetIterrowsResponse, dataset, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_path_params_iterrows(self, async_client: AsyncLlamaStackClient) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `dataset_id` but received ''"): + await async_client.datasets.with_raw_response.iterrows( + dataset_id="", + ) + @parametrize async def test_method_register(self, async_client: AsyncLlamaStackClient) -> None: dataset = await async_client.datasets.register( - dataset_id="dataset_id", - dataset_schema={"foo": {"type": "string"}}, - url={"uri": "uri"}, + purpose="post-training/messages", + source={ + "type": "uri", + "uri": "uri", + }, ) - assert dataset is None + assert_matches_type(DatasetRegisterResponse, dataset, path=["response"]) @parametrize async def test_method_register_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: dataset = await async_client.datasets.register( + purpose="post-training/messages", + source={ + "type": "uri", + "uri": "uri", + }, dataset_id="dataset_id", - dataset_schema={"foo": {"type": "string"}}, - url={"uri": "uri"}, metadata={"foo": True}, - provider_dataset_id="provider_dataset_id", - provider_id="provider_id", ) - assert dataset is None + assert_matches_type(DatasetRegisterResponse, dataset, path=["response"]) @parametrize async def test_raw_response_register(self, async_client: AsyncLlamaStackClient) -> None: response = await async_client.datasets.with_raw_response.register( - dataset_id="dataset_id", - dataset_schema={"foo": {"type": "string"}}, - url={"uri": "uri"}, + purpose="post-training/messages", + source={ + "type": "uri", + "uri": "uri", + }, ) assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" dataset = await response.parse() - assert dataset is None + assert_matches_type(DatasetRegisterResponse, dataset, path=["response"]) @parametrize async def test_streaming_response_register(self, async_client: AsyncLlamaStackClient) -> None: async with async_client.datasets.with_streaming_response.register( - dataset_id="dataset_id", - dataset_schema={"foo": {"type": "string"}}, - url={"uri": "uri"}, + purpose="post-training/messages", + source={ + "type": "uri", + "uri": "uri", + }, ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" dataset = await response.parse() - assert dataset is None + assert_matches_type(DatasetRegisterResponse, dataset, path=["response"]) assert cast(Any, response.is_closed) is True diff --git a/tests/test_models.py b/tests/test_models.py index ee96638a..8b4c0bc9 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -854,3 +854,35 @@ class Model(BaseModel): m = construct_type(value={"cls": "foo"}, type_=Model) assert isinstance(m, Model) assert isinstance(m.cls, str) + + +def test_discriminated_union_case() -> None: + class A(BaseModel): + type: Literal["a"] + + data: bool + + class B(BaseModel): + type: Literal["b"] + + data: List[Union[A, object]] + + class ModelA(BaseModel): + type: Literal["modelA"] + + data: int + + class ModelB(BaseModel): + type: Literal["modelB"] + + required: str + + data: Union[A, B] + + # when constructing ModelA | ModelB, value data doesn't match ModelB exactly - missing `required` + m = construct_type( + value={"type": "modelB", "data": {"type": "a", "data": True}}, + type_=cast(Any, Annotated[Union[ModelA, ModelB], PropertyInfo(discriminator="type")]), + ) + + assert isinstance(m, ModelB)