From 4fa4de0f35ca52c6776226a43290d732a2e1d381 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Do=C4=9Fan=20Keskin?= <98788987+DoganK01@users.noreply.github.com> Date: Sat, 23 Nov 2024 22:03:57 +0300 Subject: [PATCH] Create mocked_openai.py --- tests/fixtures/mocked_openai.py | 95 +++++++++++++++++++++++++++++++++ 1 file changed, 95 insertions(+) create mode 100644 tests/fixtures/mocked_openai.py diff --git a/tests/fixtures/mocked_openai.py b/tests/fixtures/mocked_openai.py new file mode 100644 index 0000000..7178f8e --- /dev/null +++ b/tests/fixtures/mocked_openai.py @@ -0,0 +1,95 @@ +import os +import subprocess +import sys +import threading +import time +from contextlib import contextmanager +from pathlib import Path +from typing import Dict + +import pytest +import requests # type: ignore + +THIS_DIR = Path(__file__).parent +MOCKED_OPENAI_SERVER_PY_PATH = THIS_DIR / "../mocks/openai_fastapi_mock_app.py" + + +@pytest.fixture(scope="session") +def mocked_openai(): + """Start a mocked OpenAI server running on port 5005 for testing calls to OpenAI.""" + openai_mock_process = start_mock_server( + port=5005, + ) + + with temporary_env_vars( + { + "OPENAI_BASE_URL": "http://localhost:5005", + "OPENAI_API_KEY": "mocked_key", + } + ): + yield + + # Cleanup: Terminate the OpenAI mock server process + openai_mock_process.terminate() + openai_mock_process.wait() + + +################# +# --- Utils --- # +################# + + +@contextmanager +def temporary_env_vars(env_vars: Dict[str, str]): + """Temporarily set and restore environment variables.""" + original_env_vars = os.environ.copy() + os.environ.update(env_vars) + try: + yield + finally: + os.environ.clear() + os.environ.update(original_env_vars) + + +def _stream_output(pipe, name): + """Stream output from a subprocess pipe.""" + for line in iter(pipe.readline, ""): + if line: + print(f"[{name}] {line.strip()}") + + +def start_mock_server(port: int, max_retries: int = 3, retry_delay_seconds: int = 1) -> subprocess.Popen: + """Start a mock server and verify it's running by hitting the `/` endpoint.""" + # pylint: disable=consider-using-with + process = subprocess.Popen( + [ + # launch the mocked server using python interpreter from the same virtual env used for running the tests + sys.executable, + # the mocked server is a python file in the mocks directory + str(MOCKED_OPENAI_SERVER_PY_PATH), + ], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + bufsize=1, + env={"OPENAI_MOCK_PORT": str(port)}, + ) + + # Start threads to asynchronously print stdout and stderr + stdout_thread = threading.Thread(target=_stream_output, args=(process.stdout, "stdout")) + stderr_thread = threading.Thread(target=_stream_output, args=(process.stderr, "stderr")) + stdout_thread.start() + stderr_thread.start() + + for _ in range(max_retries): + try: + response = requests.get(f"http://localhost:{port}/") + if response.status_code == 200: + return process + except requests.exceptions.ConnectionError: + time.sleep(retry_delay_seconds) + + process.terminate() + stdout_thread.join() + stderr_thread.join() + raise RuntimeError(f"Mock server at port {port} failed to start after {max_retries} attempts.")