From 4ed938bd86932bf21340e14007210d8dc6fd72e1 Mon Sep 17 00:00:00 2001 From: James Guthrie Date: Fri, 15 Nov 2024 12:29:13 +0100 Subject: [PATCH] chore: a new database for each test In 12c17809ec235d759e37eaa0898ea3274fea6319 we made tests less fragile by giving each test class its own fresh container, which also made the tests run a bit slower. This change reuses the same container for all tests, but creates a new database for each test, achieving the same isolation, with low overhead. --- projects/pgai/tests/vectorizer/conftest.py | 2 +- .../tests/vectorizer/test_vectorizer_cli.py | 81 +++++++++++++------ 2 files changed, 56 insertions(+), 27 deletions(-) diff --git a/projects/pgai/tests/vectorizer/conftest.py b/projects/pgai/tests/vectorizer/conftest.py index 23f1b99d..84e54816 100644 --- a/projects/pgai/tests/vectorizer/conftest.py +++ b/projects/pgai/tests/vectorizer/conftest.py @@ -42,7 +42,7 @@ def vcr_(): ) -@pytest.fixture(scope="class") +@pytest.fixture(scope="session") def postgres_container(): extension_dir = os.path.abspath( os.path.join(os.path.dirname(__file__), "../../../extension/") diff --git a/projects/pgai/tests/vectorizer/test_vectorizer_cli.py b/projects/pgai/tests/vectorizer/test_vectorizer_cli.py index 66c66425..7edb2494 100644 --- a/projects/pgai/tests/vectorizer/test_vectorizer_cli.py +++ b/projects/pgai/tests/vectorizer/test_vectorizer_cli.py @@ -9,34 +9,73 @@ import psycopg import pytest from click.testing import CliRunner -from psycopg import Connection +from psycopg import Connection, sql from psycopg.rows import dict_row from testcontainers.postgres import PostgresContainer # type: ignore from pgai.cli import vectorizer_worker from tests.vectorizer import expected +count = 10000 + + +class TestDatabase: + """""" + + container: PostgresContainer + dbname: str + + def __init__(self, container: PostgresContainer): + global count + dbname = f"test_{count}" + count += 1 + self.container = container + self.dbname = dbname + url = self._create_connection_url(dbname="template1") + with psycopg.connect(url, autocommit=True) as conn: + conn.execute("CREATE EXTENSION IF NOT EXISTS ai CASCADE") + conn.execute( + sql.SQL("CREATE DATABASE {0}").format(sql.Identifier(self.dbname)) + ) + + def _create_connection_url( + self, + username: str | None = None, + password: str | None = None, + dbname: str | None = None, + ): + host = self.container._docker.host() # type: ignore + return super(PostgresContainer, self.container)._create_connection_url( # type: ignore + dialect="postgresql", + username=username or self.container.username, + password=password or self.container.password, + dbname=dbname or self.dbname, + host=host, + port=self.container.port, + ) + + def get_connection_url(self) -> str: + return self._create_connection_url() + @pytest.fixture def cli_db( postgres_container: PostgresContainer, -) -> Generator[tuple[PostgresContainer, Connection], None, None]: +) -> Generator[tuple[TestDatabase, Connection], None, None]: """Creates a test database with pgai installed""" - db_host = postgres_container._docker.host() # type: ignore - # Connect and setup initial database + test_database = TestDatabase(container=postgres_container) + + # Connect with psycopg.connect( - postgres_container.get_connection_url(host=db_host), + test_database.get_connection_url(), autocommit=True, ) as conn: - # Install pgai - conn.execute("CREATE EXTENSION IF NOT EXISTS ai CASCADE") - - yield postgres_container, conn + yield test_database, conn @pytest.fixture -def cli_db_url(cli_db: tuple[PostgresContainer, Connection]) -> str: +def cli_db_url(cli_db: tuple[TestDatabase, Connection]) -> str: """Constructs database URL from the cli_db fixture""" container, _ = cli_db return container.get_connection_url() @@ -53,7 +92,7 @@ def test_worker_no_tasks(cli_db_url: str): @pytest.fixture def configured_vectorizer_and_source_table( - cli_db: tuple[PostgresContainer, Connection], + cli_db: tuple[TestDatabase, Connection], test_params: tuple[int, int, int, str, str], ) -> int: """Creates and configures a vectorizer for testing""" @@ -61,16 +100,6 @@ def configured_vectorizer_and_source_table( _, conn = cli_db with conn.cursor(row_factory=dict_row) as cur: - # Cleanup from previous runs - cur.execute("SELECT id FROM ai.vectorizer") - for row in cur.fetchall(): - cur.execute("SELECT ai.drop_vectorizer(%s)", (row["id"],)) - - # Drop tables if they exist - cur.execute("DROP VIEW IF EXISTS blog_embedding") - cur.execute("DROP TABLE IF EXISTS blog_embedding_store") - cur.execute("DROP TABLE IF EXISTS blog") - # Create source table cur.execute(""" CREATE TABLE blog ( @@ -135,7 +164,7 @@ class TestWithConfiguredVectorizer: ) def test_process_vectorizer( self, - cli_db: tuple[PostgresContainer, Connection], + cli_db: tuple[TestDatabase, Connection], cli_db_url: str, configured_vectorizer_and_source_table: int, monkeypatch: pytest.MonkeyPatch, @@ -198,7 +227,7 @@ def test_process_vectorizer( ) def test_document_exceeds_model_context_length( self, - cli_db: tuple[PostgresContainer, Connection], + cli_db: tuple[TestDatabase, Connection], cli_db_url: str, configured_vectorizer_and_source_table: int, monkeypatch: pytest.MonkeyPatch, @@ -272,7 +301,7 @@ def test_document_exceeds_model_context_length( ) def test_invalid_api_key_error( self, - cli_db: tuple[PostgresContainer, Connection], + cli_db: tuple[TestDatabase, Connection], cli_db_url: str, configured_vectorizer_and_source_table: int, monkeypatch: pytest.MonkeyPatch, @@ -331,7 +360,7 @@ def test_invalid_api_key_error( ) def test_invalid_function_arguments( self, - cli_db: tuple[PostgresContainer, Connection], + cli_db: tuple[TestDatabase, Connection], cli_db_url: str, configured_vectorizer_and_source_table: int, monkeypatch: pytest.MonkeyPatch, @@ -407,7 +436,7 @@ def test_vectorizer_exits_when_vectorizers_specified_but_missing(cli_db_url: str def test_vectorizer_picks_up_new_vectorizer( - cli_db: tuple[PostgresContainer, Connection], + cli_db: tuple[TestDatabase, Connection], ): postgres_container, con = cli_db db_url = postgres_container.get_connection_url()