diff --git a/src/hume/expression_measurement/batch/client_with_utils.py b/src/hume/expression_measurement/batch/client_with_utils.py
index 612e44ad..12f48503 100644
--- a/src/hume/expression_measurement/batch/client_with_utils.py
+++ b/src/hume/expression_measurement/batch/client_with_utils.py
@@ -1,9 +1,17 @@
 import aiofiles
 import typing
+import json as jsonlib
+from json.decoder import JSONDecodeError
 
 from ...core.request_options import RequestOptions
+from ...core.jsonable_encoder import jsonable_encoder
+from ... import core
+from .types.inference_base_request import InferenceBaseRequest
+from ...core.pydantic_utilities import parse_obj_as
+from .types.job_id import JobId
 
 from .client import AsyncBatchClient, BatchClient
+from ...core.api_error import ApiError
 
 class BatchClientWithUtils(BatchClient):
     def get_and_write_job_artifacts(
@@ -47,6 +55,69 @@ def get_and_write_job_artifacts(
         for chunk in self.get_job_artifacts(id=id, request_options=request_options):
             f.write(chunk)
 
 
+    def start_inference_job_from_local_file(
+        self,
+        *,
+        file: typing.List[core.File],
+        json: typing.Optional[InferenceBaseRequest] = None,
+        request_options: typing.Optional[RequestOptions] = None,
+    ) -> str:
+        """
+        Start a new batch inference job.
+
+        Parameters
+        ----------
+        file : typing.List[core.File]
+            See core.File for more documentation
+
+        json : typing.Optional[InferenceBaseRequest]
+            The inference job configuration.
+
+        request_options : typing.Optional[RequestOptions]
+            Request-specific configuration.
+
+        Returns
+        -------
+        str
+
+
+        Examples
+        --------
+        from hume import HumeClient
+
+        client = HumeClient(
+            api_key="YOUR_API_KEY",
+        )
+        client.expression_measurement.batch.start_inference_job_from_local_file()
+        """
+        files: typing.Dict[str, typing.Any] = {
+            "file": file,
+        }
+        if json is not None:
+            files["json"] = jsonlib.dumps(jsonable_encoder(json)).encode("utf-8")
+
+        _response = self._client_wrapper.httpx_client.request(
+            "v0/batch/jobs",
+            method="POST",
+            files=files,
+            request_options=request_options,
+        )
+        try:
+            if 200 <= _response.status_code < 300:
+                _parsed_response = typing.cast(
+                    JobId,
+                    parse_obj_as(
+                        type_=JobId,  # type: ignore
+                        object_=_response.json(),
+                    ),
+                )
+                return _parsed_response.job_id
+            _response_json = _response.json()
+        except JSONDecodeError:
+            raise ApiError(status_code=_response.status_code, body=_response.text)
+        raise ApiError(status_code=_response.status_code, body=_response_json)
+
+
 class AsyncBatchClientWithUtils(AsyncBatchClient):
     async def get_and_write_job_artifacts(
@@ -87,4 +158,66 @@ async def get_and_write_job_artifacts(
         """
         async with aiofiles.open(file_name, mode='wb') as f:
             async for chunk in self.get_job_artifacts(id=id, request_options=request_options):
-                await f.write(chunk)
\ No newline at end of file
+                await f.write(chunk)
+
+    async def start_inference_job_from_local_file(
+        self,
+        *,
+        file: typing.List[core.File],
+        json: typing.Optional[InferenceBaseRequest] = None,
+        request_options: typing.Optional[RequestOptions] = None,
+    ) -> str:
+        """
+        Start a new batch inference job.
+
+        Parameters
+        ----------
+        file : typing.List[core.File]
+            See core.File for more documentation
+
+        json : typing.Optional[InferenceBaseRequest]
+            The inference job configuration.
+
+        request_options : typing.Optional[RequestOptions]
+            Request-specific configuration.
+
+        Returns
+        -------
+        str
+
+
+        Examples
+        --------
+        from hume import HumeClient
+
+        client = HumeClient(
+            api_key="YOUR_API_KEY",
+        )
+        client.expression_measurement.batch.start_inference_job_from_local_file()
+        """
+        files: typing.Dict[str, typing.Any] = {
+            "file": file,
+        }
+        if json is not None:
+            files["json"] = jsonlib.dumps(jsonable_encoder(json)).encode("utf-8")
+
+        _response = await self._client_wrapper.httpx_client.request(
+            "v0/batch/jobs",
+            method="POST",
+            files=files,
+            request_options=request_options,
+        )
+        try:
+            if 200 <= _response.status_code < 300:
+                _parsed_response = typing.cast(
+                    JobId,
+                    parse_obj_as(
+                        type_=JobId,  # type: ignore
+                        object_=_response.json(),
+                    ),
+                )
+                return _parsed_response.job_id
+            _response_json = _response.json()
+        except JSONDecodeError:
+            raise ApiError(status_code=_response.status_code, body=_response.text)
+        raise ApiError(status_code=_response.status_code, body=_response_json)
diff --git a/tests/custom/test_client.py b/tests/custom/test_client.py
index 56919b18..dd3a501e 100644
--- a/tests/custom/test_client.py
+++ b/tests/custom/test_client.py
@@ -1,12 +1,11 @@
 import pytest
 import aiofiles
-from hume.client import AsyncHumeClient
+from hume.client import AsyncHumeClient, HumeClient
+from hume.expression_measurement.batch.types.face import Face
+from hume.expression_measurement.batch.types.inference_base_request import InferenceBaseRequest
+from hume.expression_measurement.batch.types.models import Models
 
-# Get started with writing tests with pytest at https://docs.pytest.org
-@pytest.mark.skip(reason="Unimplemented")
-def test_client() -> None:
-    assert True == True
 
 
 @pytest.mark.skip(reason="CI does not have authentication.")
 async def test_write_job_artifacts() -> None:
@@ -20,4 +19,16 @@ async def test_get_job_predictions() -> None:
     client = AsyncHumeClient(api_key="MY_API_KEY")
     await client.expression_measurement.batch.get_job_predictions(id="my-job-id", request_options={
         "max_retries": 3,
-    })
\ No newline at end of file
+    })
+
+@pytest.mark.skip(reason="CI does not have authentication.")
+async def test_start_inference_job_from_local_file() -> None:
+    client = HumeClient(api_key="MY_API_KEY")
+    client.expression_measurement.batch.start_inference_job_from_local_file(
+        file=[],
+        json=InferenceBaseRequest(
+            models=Models(
+                face=Face()
+            )
+        )
+    )