diff --git a/log10/_httpx_utils.py b/log10/_httpx_utils.py index 69cecccb..dfe26f93 100644 --- a/log10/_httpx_utils.py +++ b/log10/_httpx_utils.py @@ -340,22 +340,20 @@ class _RequestHooks: def __init__(self): logger.debug("LOG10: initializing request hooks") self.event_hooks = { - "request": [self.get_completion_id, self.log_request], + "request": [self.log_request], } - self.completion_id = "" self.log_row = {} - def get_completion_id(self, request: httpx.Request): - logger.debug("LOG10: generating completion id") - self.completion_id = get_completion_id(request) - def log_request(self, request: httpx.Request): - logger.debug("LOG10: sending sync request") - if not self.completion_id: + logger.debug("LOG10: generating completion id") + completion_id = get_completion_id(request) + if not completion_id: logger.debug("LOG10: completion id is not generated. Skipping") return + + logger.debug("LOG10: sending sync request") self.log_row = _init_log_row(request) - _try_post_request(url=f"{base_url}/api/completions/{self.completion_id}", payload=self.log_row) + _try_post_request(url=f"{base_url}/api/completions/{completion_id}", payload=self.log_row) class _AsyncRequestHooks: @@ -369,23 +367,21 @@ class _AsyncRequestHooks: def __init__(self): logger.debug("LOG10: initializing async request hooks") self.event_hooks = { - "request": [self.get_completion_id, self.log_request], + "request": [self.log_request], } - self.completion_id = "" self.log_row = {} - def get_completion_id(self, request: httpx.Request): - logger.debug("LOG10: generating completion id") - self.completion_id = get_completion_id(request) - async def log_request(self, request: httpx.Request): - logger.debug("LOG10: sending async request") - if not self.completion_id: + logger.debug("LOG10: generating completion id") + completion_id = get_completion_id(request) + if not completion_id: logger.debug("LOG10: completion id is not generated. Skipping") return + + logger.debug("LOG10: sending async request") self.log_row = _init_log_row(request) asyncio.create_task( - _try_post_request_async(url=f"{base_url}/api/completions/{self.completion_id}", payload=self.log_row) + _try_post_request_async(url=f"{base_url}/api/completions/{completion_id}", payload=self.log_row) ) diff --git a/tests/test_request_hooks.py b/tests/test_request_hooks.py deleted file mode 100644 index 53aa2031..00000000 --- a/tests/test_request_hooks.py +++ /dev/null @@ -1,78 +0,0 @@ -import asyncio -from unittest.mock import AsyncMock, MagicMock, call, patch - -import httpx -import pytest - -from log10._httpx_utils import _AsyncRequestHooks, finalize - - -@pytest.mark.asyncio -async def test_async_log_request_not_called_with_empty_completion_id(): - hooks = _AsyncRequestHooks() - hooks.completion_id = "" - request = httpx.Request("GET", "http://test.com") - - with patch("log10._httpx_utils._try_post_request_async", new_callable=AsyncMock) as mock_post_request: - mock_post_request.assert_not_called() - - await hooks.log_request(request) - await finalize() - - mock_post_request.assert_not_called() - - -@pytest.mark.asyncio -async def test_async_log_request_not_called_with_completion_id_none(): - hooks = _AsyncRequestHooks() - hooks.completion_id = None - request = httpx.Request("GET", "http://test.com") - - with patch("log10._httpx_utils._try_post_request_async", new_callable=AsyncMock) as mock_post_request: - mock_post_request.assert_not_called() - - await hooks.log_request(request) - await finalize() - - mock_post_request.assert_not_called() - - -@pytest.mark.asyncio -async def test_async_log_request_called_once(): - hooks = _AsyncRequestHooks() - hooks.completion_id = "abc-123" - request = httpx.Request("GET", "http://test.com") - - # Assert that log_request was not called during initialization - with patch("log10._httpx_utils._try_post_request_async", new_callable=AsyncMock) as mock_post_request: - mock_post_request.assert_not_called() - - await hooks.log_request(request) - await finalize() - - mock_post_request.assert_called_once() - - -@pytest.mark.asyncio -async def test_request_hooks_call_order(): - hooks = _AsyncRequestHooks() - hooks.get_completion_id = MagicMock() - hooks.log_request = AsyncMock() - request = httpx.Request(method="GET", url="https://example.com") - hooks.event_hooks["request"] = [hooks.get_completion_id, hooks.log_request] - - with patch("log10._httpx_utils._try_post_request_async", new=AsyncMock()): - for hook in hooks.event_hooks["request"]: - if asyncio.iscoroutinefunction(hook): - await hook(request) - else: - hook(request) - - hooks.get_completion_id.assert_called_once_with(request) - hooks.log_request.assert_awaited_once_with(request) - - # Check that get_completion_id was called before log_request - assert hooks.get_completion_id.call_args_list[0] == call(request) - assert hooks.log_request.await_args_list[0] == call(request) - call_order = list(hooks.get_completion_id.call_args_list + hooks.log_request.await_args_list) - assert call_order == [call(request), call(request)], "get_completion_id was not called before log_request"