Move back to HTTPX (#13)
* Revert "Use niquests (#7)"

This reverts commit bfaec30.

* Leave good stuff
vrslev authored Dec 5, 2024
1 parent 75f1430 commit ec25bed
Showing 16 changed files with 224 additions and 335 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/publish.yml
@@ -11,7 +11,7 @@ jobs:
steps:
- uses: actions/checkout@v4
- uses: extractions/setup-just@v2
- uses: astral-sh/setup-uv@v4
- uses: astral-sh/setup-uv@v3
with:
enable-cache: true
cache-dependency-glob: "**/pyproject.toml"
10 changes: 0 additions & 10 deletions Justfile
@@ -10,18 +10,8 @@ lint:
uv run --group lint ruff format
uv run --group lint mypy .

_test-no-http *args:
uv run pytest --ignore tests/test_http.py {{ args }}

test *args:
#!/bin/bash
uv run litestar --app tests.testing_app:app run &
APP_PID=$!
uv run pytest {{ args }}
TEST_RESULT=$?
kill $APP_PID
wait $APP_PID 2>/dev/null
exit $TEST_RESULT

publish:
rm -rf dist
12 changes: 5 additions & 7 deletions README.md
@@ -162,25 +162,23 @@ async with any_llm_client.OpenAIClient(config, ...) as client:
#### Timeouts, proxy & other HTTP settings


Pass custom [niquests](https://niquests.readthedocs.io) kwargs to `any_llm_client.get_client()`:
Pass custom [HTTPX](https://www.python-httpx.org) kwargs to `any_llm_client.get_client()`:

```python
import urllib3
import httpx

import any_llm_client


async with any_llm_client.get_client(
...,
proxies={"https://api.openai.com": "http://localhost:8030"},
timeout=urllib3.Timeout(total=10.0, connect=5.0),
mounts={"https://api.openai.com": httpx.AsyncHTTPTransport(proxy="http://localhost:8030")},
timeout=httpx.Timeout(None, connect=5.0),
) as client:
...
```

`timeout` and `proxies` parameters are special cased here: `niquests.AsyncSession` doesn't receive them by default.

Default timeout is `urllib3.Timeout(total=None, connect=5.0)`.
Default timeout is `httpx.Timeout(None, connect=5.0)` (5 seconds on connect, unlimited on read, write or pool).

#### Retries

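The README above documents a default timeout, but the helper that applies it lives in `any_llm_client/http.py`, which is not among the files loaded in this diff. As a minimal sketch only: assuming `get_http_client_from_kwargs` (the name is taken from the imports in the client diffs below) simply fills in the documented default before constructing `httpx.AsyncClient`, it could look like this — the body is an assumption, not the commit's actual implementation:

```python
import typing

import httpx

# Documented default: 5 seconds on connect, unlimited on read, write or pool.
DEFAULT_TIMEOUT: typing.Final = httpx.Timeout(None, connect=5.0)


def get_http_client_from_kwargs(kwargs: dict[str, typing.Any]) -> httpx.AsyncClient:
    # Apply the default timeout only when the caller didn't pass `timeout=` themselves.
    kwargs_with_defaults = kwargs.copy()
    kwargs_with_defaults.setdefault("timeout", DEFAULT_TIMEOUT)
    return httpx.AsyncClient(**kwargs_with_defaults)
```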
68 changes: 39 additions & 29 deletions any_llm_client/clients/openai.py
@@ -6,7 +6,8 @@
from http import HTTPStatus

import annotated_types
import niquests
import httpx
import httpx_sse
import pydantic
import typing_extensions

@@ -19,9 +20,8 @@
OutOfTokensOrSymbolsError,
UserMessage,
)
from any_llm_client.http import HttpClient, HttpStatusError
from any_llm_client.http import get_http_client_from_kwargs, make_http_request, make_streaming_http_request
from any_llm_client.retry import RequestRetryConfig
from any_llm_client.sse import parse_sse_events


OPENAI_AUTH_TOKEN_ENV_NAME: typing.Final = "ANY_LLM_CLIENT_OPENAI_AUTH_TOKEN"
@@ -99,34 +99,31 @@ def _make_user_assistant_alternate_messages(
yield ChatCompletionsMessage(role=current_message_role, content="\n\n".join(current_message_content_chunks))


def _handle_status_error(error: HttpStatusError) -> typing.NoReturn:
if (
error.status_code == HTTPStatus.BAD_REQUEST and b"Please reduce the length of the messages" in error.content
): # vLLM
raise OutOfTokensOrSymbolsError(response_content=error.content)
raise LLMError(response_content=error.content)
def _handle_status_error(*, status_code: int, content: bytes) -> typing.NoReturn:
if status_code == HTTPStatus.BAD_REQUEST and b"Please reduce the length of the messages" in content: # vLLM
raise OutOfTokensOrSymbolsError(response_content=content)
raise LLMError(response_content=content)


@dataclasses.dataclass(slots=True, init=False)
class OpenAIClient(LLMClient):
config: OpenAIConfig
http_client: HttpClient
httpx_client: httpx.AsyncClient
request_retry: RequestRetryConfig

def __init__(
self,
config: OpenAIConfig,
*,
request_retry: RequestRetryConfig | None = None,
**niquests_kwargs: typing.Any, # noqa: ANN401
**httpx_kwargs: typing.Any, # noqa: ANN401
) -> None:
self.config = config
self.http_client = HttpClient(
request_retry=request_retry or RequestRetryConfig(), niquests_kwargs=niquests_kwargs
)
self.request_retry = request_retry or RequestRetryConfig()
self.httpx_client = get_http_client_from_kwargs(httpx_kwargs)

def _build_request(self, payload: dict[str, typing.Any]) -> niquests.Request:
return niquests.Request(
def _build_request(self, payload: dict[str, typing.Any]) -> httpx.Request:
return self.httpx_client.build_request(
method="POST",
url=str(self.config.url),
json=payload,
@@ -155,17 +152,24 @@ async def request_llm_message(
**extra or {},
).model_dump(mode="json")
try:
response: typing.Final = await self.http_client.request(self._build_request(payload))
except HttpStatusError as exception:
_handle_status_error(exception)
return ChatCompletionsNotStreamingResponse.model_validate_json(response).choices[0].message.content
response: typing.Final = await make_http_request(
httpx_client=self.httpx_client,
request_retry=self.request_retry,
build_request=lambda: self._build_request(payload),
)
except httpx.HTTPStatusError as exception:
_handle_status_error(status_code=exception.response.status_code, content=exception.response.content)
try:
return ChatCompletionsNotStreamingResponse.model_validate_json(response.content).choices[0].message.content
finally:
await response.aclose()

async def _iter_partial_responses(self, response: typing.AsyncIterable[bytes]) -> typing.AsyncIterable[str]:
async def _iter_partial_responses(self, response: httpx.Response) -> typing.AsyncIterable[str]:
text_chunks: typing.Final = []
async for one_event in parse_sse_events(response):
if one_event.data == "[DONE]":
async for event in httpx_sse.EventSource(response).aiter_sse():
if event.data == "[DONE]":
break
validated_response = ChatCompletionsStreamingEvent.model_validate_json(one_event.data)
validated_response = ChatCompletionsStreamingEvent.model_validate_json(event.data)
if not (one_chunk := validated_response.choices[0].delta.content):
continue
text_chunks.append(one_chunk)
@@ -183,13 +187,19 @@ async def stream_llm_partial_messages(
**extra or {},
).model_dump(mode="json")
try:
async with self.http_client.stream(request=self._build_request(payload)) as response:
async with make_streaming_http_request(
httpx_client=self.httpx_client,
request_retry=self.request_retry,
build_request=lambda: self._build_request(payload),
) as response:
yield self._iter_partial_responses(response)
except HttpStatusError as exception:
_handle_status_error(exception)
except httpx.HTTPStatusError as exception:
content: typing.Final = await exception.response.aread()
await exception.response.aclose()
_handle_status_error(status_code=exception.response.status_code, content=content)

async def __aenter__(self) -> typing_extensions.Self:
await self.http_client.__aenter__()
await self.httpx_client.__aenter__()
return self

async def __aexit__(
@@ -198,4 +208,4 @@ async def __aexit__(
exc_value: BaseException | None,
traceback: types.TracebackType | None,
) -> None:
await self.http_client.__aexit__(exc_type=exc_type, exc_value=exc_value, traceback=traceback)
await self.httpx_client.__aexit__(exc_type=exc_type, exc_value=exc_value, traceback=traceback)
66 changes: 38 additions & 28 deletions any_llm_client/clients/yandexgpt.py
@@ -6,12 +6,12 @@
from http import HTTPStatus

import annotated_types
import niquests
import httpx
import pydantic
import typing_extensions

from any_llm_client.core import LLMClient, LLMConfig, LLMError, Message, OutOfTokensOrSymbolsError, UserMessage
from any_llm_client.http import HttpClient, HttpStatusError
from any_llm_client.http import get_http_client_from_kwargs, make_http_request, make_streaming_http_request
from any_llm_client.retry import RequestRetryConfig


@@ -61,34 +61,34 @@ class YandexGPTResponse(pydantic.BaseModel):
result: YandexGPTResult


def _handle_status_error(error: HttpStatusError) -> typing.NoReturn:
if error.status_code == HTTPStatus.BAD_REQUEST and (
b"number of input tokens must be no more than" in error.content
or (b"text length is" in error.content and b"which is outside the range" in error.content)
def _handle_status_error(*, status_code: int, content: bytes) -> typing.NoReturn:
if status_code == HTTPStatus.BAD_REQUEST and (
b"number of input tokens must be no more than" in content
or (b"text length is" in content and b"which is outside the range" in content)
):
raise OutOfTokensOrSymbolsError(response_content=error.content)
raise LLMError(response_content=error.content)
raise OutOfTokensOrSymbolsError(response_content=content)
raise LLMError(response_content=content)


@dataclasses.dataclass(slots=True, init=False)
class YandexGPTClient(LLMClient):
config: YandexGPTConfig
http_client: HttpClient
httpx_client: httpx.AsyncClient
request_retry: RequestRetryConfig

def __init__(
self,
config: YandexGPTConfig,
*,
request_retry: RequestRetryConfig | None = None,
**niquests_kwargs: typing.Any, # noqa: ANN401
**httpx_kwargs: typing.Any, # noqa: ANN401
) -> None:
self.config = config
self.http_client = HttpClient(
request_retry=request_retry or RequestRetryConfig(), niquests_kwargs=niquests_kwargs
)
self.request_retry = request_retry or RequestRetryConfig()
self.httpx_client = get_http_client_from_kwargs(httpx_kwargs)

def _build_request(self, payload: dict[str, typing.Any]) -> niquests.Request:
return niquests.Request(
def _build_request(self, payload: dict[str, typing.Any]) -> httpx.Request:
return self.httpx_client.build_request(
method="POST",
url=str(self.config.url),
json=payload,
@@ -121,14 +121,18 @@ async def request_llm_message(
)

try:
response: typing.Final = await self.http_client.request(self._build_request(payload))
except HttpStatusError as exception:
_handle_status_error(exception)

return YandexGPTResponse.model_validate_json(response).result.alternatives[0].message.text

async def _iter_completion_messages(self, response: typing.AsyncIterable[bytes]) -> typing.AsyncIterable[str]:
async for one_line in response:
response: typing.Final = await make_http_request(
httpx_client=self.httpx_client,
request_retry=self.request_retry,
build_request=lambda: self._build_request(payload),
)
except httpx.HTTPStatusError as exception:
_handle_status_error(status_code=exception.response.status_code, content=exception.response.content)

return YandexGPTResponse.model_validate_json(response.content).result.alternatives[0].message.text

async def _iter_completion_messages(self, response: httpx.Response) -> typing.AsyncIterable[str]:
async for one_line in response.aiter_lines():
validated_response = YandexGPTResponse.model_validate_json(one_line)
yield validated_response.result.alternatives[0].message.text

@@ -141,13 +145,19 @@ async def stream_llm_partial_messages(
)

try:
async with self.http_client.stream(request=self._build_request(payload)) as response:
async with make_streaming_http_request(
httpx_client=self.httpx_client,
request_retry=self.request_retry,
build_request=lambda: self._build_request(payload),
) as response:
yield self._iter_completion_messages(response)
except HttpStatusError as exception:
_handle_status_error(exception)
except httpx.HTTPStatusError as exception:
content: typing.Final = await exception.response.aread()
await exception.response.aclose()
_handle_status_error(status_code=exception.response.status_code, content=content)

async def __aenter__(self) -> typing_extensions.Self:
await self.http_client.__aenter__()
await self.httpx_client.__aenter__()
return self

async def __aexit__(
@@ -156,4 +166,4 @@ async def __aexit__(
exc_value: BaseException | None,
traceback: types.TracebackType | None,
) -> None:
await self.http_client.__aexit__(exc_type=exc_type, exc_value=exc_value, traceback=traceback)
await self.httpx_client.__aexit__(exc_type=exc_type, exc_value=exc_value, traceback=traceback)
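Both client diffs above import `make_http_request` and `make_streaming_http_request` from `any_llm_client.http`, but that module is among the changed files not loaded on this page. A rough sketch of retry-aware helpers with matching call sites, under stated assumptions: the `attempts` field on `RequestRetryConfig` and the plain retry loop are illustrative guesses, not the commit's actual retry mechanism.

```python
import contextlib
import typing

import httpx

from any_llm_client.retry import RequestRetryConfig


async def make_http_request(
    *,
    httpx_client: httpx.AsyncClient,
    request_retry: RequestRetryConfig,
    build_request: typing.Callable[[], httpx.Request],
) -> httpx.Response:
    # Hypothetical retry loop: network-level errors are retried, while HTTP
    # status errors are raised immediately so the clients can translate them
    # into LLMError / OutOfTokensOrSymbolsError.
    attempts = getattr(request_retry, "attempts", 3)  # field name is assumed
    for attempt in range(attempts):
        try:
            response = await httpx_client.send(build_request())
            response.raise_for_status()
            return response
        except httpx.HTTPStatusError:
            raise
        except httpx.HTTPError:
            if attempt == attempts - 1:
                raise
    raise AssertionError("unreachable")


@contextlib.asynccontextmanager
async def make_streaming_http_request(
    *,
    httpx_client: httpx.AsyncClient,
    request_retry: RequestRetryConfig,  # retry handling omitted in this sketch
    build_request: typing.Callable[[], httpx.Request],
) -> typing.AsyncIterator[httpx.Response]:
    response = await httpx_client.send(build_request(), stream=True)
    # Raise on error status without closing: the clients read the error body
    # via `await exception.response.aread()` and close it themselves.
    response.raise_for_status()
    try:
        yield response
    finally:
        await response.aclose()
```

The clients in this commit call these helpers exactly as shown in the diffs above, so the sketch only fills in the module that is not displayed here.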
