diff --git a/.mock/definition/__package__.yml b/.mock/definition/__package__.yml
index ae77adcd4..1b195fc5c 100644
--- a/.mock/definition/__package__.yml
+++ b/.mock/definition/__package__.yml
@@ -2347,6 +2347,19 @@ types:
       organization: optional
     source:
       openapi: openapi/openapi.yaml
+  InferenceRunCostEstimate:
+    properties:
+      prompt_cost_usd:
+        type: optional
+        docs: Cost of the prompt (in USD)
+      completion_cost_usd:
+        type: optional
+        docs: Cost of the completion (in USD)
+      total_cost_usd:
+        type: optional
+        docs: Total cost of the inference (in USD)
+    source:
+      openapi: openapi/openapi.yaml
   InferenceRunOrganization:
     discriminated: false
     union:
diff --git a/.mock/definition/prompts/versions.yml b/.mock/definition/prompts/versions.yml
index 8a69a01d9..35b41766d 100644
--- a/.mock/definition/prompts/versions.yml
+++ b/.mock/definition/prompts/versions.yml
@@ -192,7 +192,7 @@ service:
             'All', 'Sample', or 'HasGT')
       response:
         docs: ''
-        type: float
+        type: root.InferenceRunCostEstimate
       examples:
         - path-parameters:
             prompt_id: 1
@@ -201,7 +201,10 @@ service:
              project_id: 1
              project_subset: 1
            response:
-             body: 1.1
+             body:
+               prompt_cost_usd: 1.1
+               completion_cost_usd: 1.1
+               total_cost_usd: 1.1
       audiences:
         - public
     refine_prompt:
diff --git a/src/label_studio_sdk/__init__.py b/src/label_studio_sdk/__init__.py
index c0829e6c8..d2c383277 100644
--- a/src/label_studio_sdk/__init__.py
+++ b/src/label_studio_sdk/__init__.py
@@ -35,6 +35,7 @@
     GcsImportStorage,
     GcsImportStorageStatus,
     InferenceRun,
+    InferenceRunCostEstimate,
     InferenceRunCreatedBy,
     InferenceRunOrganization,
     InferenceRunProjectSubset,
@@ -220,6 +221,7 @@
     "GcsImportStorageStatus",
     "ImportStorageListTypesResponseItem",
     "InferenceRun",
+    "InferenceRunCostEstimate",
     "InferenceRunCreatedBy",
     "InferenceRunOrganization",
     "InferenceRunProjectSubset",
diff --git a/src/label_studio_sdk/prompts/versions/client.py b/src/label_studio_sdk/prompts/versions/client.py
index 8769f885f..3886161a1 100644
--- a/src/label_studio_sdk/prompts/versions/client.py
+++ b/src/label_studio_sdk/prompts/versions/client.py
@@ -9,6 +9,7 @@
 from ...core.jsonable_encoder import jsonable_encoder
 from ...core.pydantic_utilities import pydantic_v1
 from ...core.request_options import RequestOptions
+from ...types.inference_run_cost_estimate import InferenceRunCostEstimate
 from ...types.prompt_version import PromptVersion
 from ...types.prompt_version_created_by import PromptVersionCreatedBy
 from ...types.prompt_version_organization import PromptVersionOrganization
@@ -343,7 +344,7 @@ def cost_estimate(
         project_id: int,
         project_subset: int,
         request_options: typing.Optional[RequestOptions] = None,
-    ) -> float:
+    ) -> InferenceRunCostEstimate:
         """
         Get cost estimate for running a prompt version on a particular project/subset

@@ -366,7 +367,7 @@ def cost_estimate(

         Returns
         -------
-        float
+        InferenceRunCostEstimate

         Examples
@@ -391,7 +392,7 @@ def cost_estimate(
         )
         try:
             if 200 <= _response.status_code < 300:
-                return pydantic_v1.parse_obj_as(float, _response.json())  # type: ignore
+                return pydantic_v1.parse_obj_as(InferenceRunCostEstimate, _response.json())  # type: ignore
             _response_json = _response.json()
         except JSONDecodeError:
             raise ApiError(status_code=_response.status_code, body=_response.text)
@@ -796,7 +797,7 @@ async def cost_estimate(
         project_id: int,
         project_subset: int,
         request_options: typing.Optional[RequestOptions] = None,
-    ) -> float:
+    ) -> InferenceRunCostEstimate:
         """
         Get cost estimate for running a prompt version on a particular project/subset

@@ -819,7 +820,7 @@ async def cost_estimate(

         Returns
         -------
-        float
+        InferenceRunCostEstimate

         Examples
@@ -844,7 +845,7 @@ async def cost_estimate(
         )
         try:
             if 200 <= _response.status_code < 300:
-                return pydantic_v1.parse_obj_as(float, _response.json())  # type: ignore
+                return pydantic_v1.parse_obj_as(InferenceRunCostEstimate, _response.json())  # type: ignore
             _response_json = _response.json()
         except JSONDecodeError:
             raise ApiError(status_code=_response.status_code, body=_response.text)
diff --git a/src/label_studio_sdk/types/__init__.py b/src/label_studio_sdk/types/__init__.py
index 63b00095e..3b19c71dd 100644
--- a/src/label_studio_sdk/types/__init__.py
+++ b/src/label_studio_sdk/types/__init__.py
@@ -34,6 +34,7 @@
 from .gcs_import_storage import GcsImportStorage
 from .gcs_import_storage_status import GcsImportStorageStatus
 from .inference_run import InferenceRun
+from .inference_run_cost_estimate import InferenceRunCostEstimate
 from .inference_run_created_by import InferenceRunCreatedBy
 from .inference_run_organization import InferenceRunOrganization
 from .inference_run_project_subset import InferenceRunProjectSubset
@@ -128,6 +129,7 @@
     "GcsImportStorage",
     "GcsImportStorageStatus",
     "InferenceRun",
+    "InferenceRunCostEstimate",
     "InferenceRunCreatedBy",
     "InferenceRunOrganization",
     "InferenceRunProjectSubset",
diff --git a/src/label_studio_sdk/types/inference_run_cost_estimate.py b/src/label_studio_sdk/types/inference_run_cost_estimate.py
new file mode 100644
index 000000000..30920a15f
--- /dev/null
+++ b/src/label_studio_sdk/types/inference_run_cost_estimate.py
@@ -0,0 +1,42 @@
+# This file was auto-generated by Fern from our API Definition.
+
+import datetime as dt
+import typing
+
+from ..core.datetime_utils import serialize_datetime
+from ..core.pydantic_utilities import deep_union_pydantic_dicts, pydantic_v1
+
+
+class InferenceRunCostEstimate(pydantic_v1.BaseModel):
+    prompt_cost_usd: typing.Optional[float] = pydantic_v1.Field(default=None)
+    """
+    Cost of the prompt (in USD)
+    """
+
+    completion_cost_usd: typing.Optional[float] = pydantic_v1.Field(default=None)
+    """
+    Cost of the completion (in USD)
+    """
+
+    total_cost_usd: typing.Optional[float] = pydantic_v1.Field(default=None)
+    """
+    Total cost of the inference (in USD)
+    """
+
+    def json(self, **kwargs: typing.Any) -> str:
+        kwargs_with_defaults: typing.Any = {"by_alias": True, "exclude_unset": True, **kwargs}
+        return super().json(**kwargs_with_defaults)
+
+    def dict(self, **kwargs: typing.Any) -> typing.Dict[str, typing.Any]:
+        kwargs_with_defaults_exclude_unset: typing.Any = {"by_alias": True, "exclude_unset": True, **kwargs}
+        kwargs_with_defaults_exclude_none: typing.Any = {"by_alias": True, "exclude_none": True, **kwargs}
+
+        return deep_union_pydantic_dicts(
+            super().dict(**kwargs_with_defaults_exclude_unset), super().dict(**kwargs_with_defaults_exclude_none)
+        )
+
+    class Config:
+        frozen = True
+        smart_union = True
+        extra = pydantic_v1.Extra.allow
+        json_encoders = {dt.datetime: serialize_datetime}
diff --git a/tests/prompts/test_versions.py b/tests/prompts/test_versions.py
index 13d6d1f51..fca8c3d78 100644
--- a/tests/prompts/test_versions.py
+++ b/tests/prompts/test_versions.py
@@ -150,8 +150,8 @@ async def test_update(client: LabelStudio, async_client: AsyncLabelStudio) -> No


 async def test_cost_estimate(client: LabelStudio, async_client: AsyncLabelStudio) -> None:
-    expected_response: typing.Any = 1.1
-    expected_types: typing.Any = None
+    expected_response: typing.Any = {"prompt_cost_usd": 1.1, "completion_cost_usd": 1.1, "total_cost_usd": 1.1}
+    expected_types: typing.Any = {"prompt_cost_usd": None, "completion_cost_usd": None, "total_cost_usd": None}
     response = client.prompts.versions.cost_estimate(prompt_id=1, version_id=1, project_id=1, project_subset=1)
     validate_response(response, expected_response, expected_types)