Skip to content

Commit

Permalink
chore: a new database for each test
Browse files Browse the repository at this point in the history
In 12c1780 we made tests less fragile
by giving each test class its own fresh container, which also made the
tests run a bit slower.

This change reuses the same container for all tests, but creates a new
database for each test, achieving the same isolation, with low overhead.
  • Loading branch information
JamesGuthrie committed Nov 15, 2024
1 parent 3e909b8 commit 4ed938b
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 27 deletions.
2 changes: 1 addition & 1 deletion projects/pgai/tests/vectorizer/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def vcr_():
)


@pytest.fixture(scope="class")
@pytest.fixture(scope="session")
def postgres_container():
extension_dir = os.path.abspath(
os.path.join(os.path.dirname(__file__), "../../../extension/")
Expand Down
81 changes: 55 additions & 26 deletions projects/pgai/tests/vectorizer/test_vectorizer_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,34 +9,73 @@
import psycopg
import pytest
from click.testing import CliRunner
from psycopg import Connection
from psycopg import Connection, sql
from psycopg.rows import dict_row
from testcontainers.postgres import PostgresContainer # type: ignore

from pgai.cli import vectorizer_worker
from tests.vectorizer import expected

count = 10000


class TestDatabase:
    """A freshly created, uniquely named database in a shared Postgres container.

    Constructing an instance connects to ``template1`` on the given container,
    ensures the ``ai`` extension exists there, and then issues
    ``CREATE DATABASE`` for a new database named ``test_<n>``, where ``<n>``
    comes from the module-level ``count`` counter. Each instance therefore
    refers to its own database while the container itself is reused.

    The database is not dropped by this class.
    """

    # Despite the ``Test`` prefix this is a helper, not a test case. Opting out
    # of collection keeps pytest from emitting a PytestCollectionWarning about
    # a "test class" that defines ``__init__``.
    __test__ = False

    # Container that hosts the database.
    container: PostgresContainer
    # Name of the database created for this instance.
    dbname: str

    def __init__(self, container: PostgresContainer):
        """Create a new database on ``container`` and remember its name.

        Args:
            container: Postgres container in which the database is created.
        """
        global count
        # Take the next number from the module-level counter so that every
        # instance created in this process gets a distinct database name.
        dbname = f"test_{count}"
        count += 1
        self.container = container
        self.dbname = dbname
        url = self._create_connection_url(dbname="template1")
        # NOTE(review): autocommit is presumably required because PostgreSQL
        # rejects CREATE DATABASE inside a transaction block.
        with psycopg.connect(url, autocommit=True) as conn:
            # NOTE(review): the extension is installed into template1,
            # presumably so that databases created afterwards (which
            # PostgreSQL clones from template1 by default) already contain it.
            conn.execute("CREATE EXTENSION IF NOT EXISTS ai CASCADE")
            # sql.Identifier quotes the database name safely.
            conn.execute(
                sql.SQL("CREATE DATABASE {0}").format(sql.Identifier(self.dbname))
            )

    def _create_connection_url(
        self,
        username: str | None = None,
        password: str | None = None,
        dbname: str | None = None,
    ) -> str:
        """Build a ``postgresql`` connection URL for the container.

        Args:
            username: User to connect as; falls back to the container's
                username when omitted or empty.
            password: Password to use; falls back to the container's password
                when omitted or empty.
            dbname: Database to connect to; falls back to this instance's
                database when omitted or empty.

        Returns:
            The connection URL produced by the container's base-class
            ``_create_connection_url`` helper.
        """
        host = self.container._docker.host()  # type: ignore
        # Start the method lookup *after* PostgresContainer in the MRO, i.e.
        # call the parent class's URL builder directly with our own arguments.
        base = super(PostgresContainer, self.container)
        return base._create_connection_url(  # type: ignore
            dialect="postgresql",
            username=username or self.container.username,
            password=password or self.container.password,
            dbname=dbname or self.dbname,
            host=host,
            port=self.container.port,
        )

    def get_connection_url(self) -> str:
        """Return the connection URL for this instance's own database."""
        return self._create_connection_url()


@pytest.fixture
def cli_db(
postgres_container: PostgresContainer,
) -> Generator[tuple[PostgresContainer, Connection], None, None]:
) -> Generator[tuple[TestDatabase, Connection], None, None]:
"""Creates a test database with pgai installed"""
db_host = postgres_container._docker.host() # type: ignore

# Connect and setup initial database
test_database = TestDatabase(container=postgres_container)

# Connect
with psycopg.connect(
postgres_container.get_connection_url(host=db_host),
test_database.get_connection_url(),
autocommit=True,
) as conn:
# Install pgai
conn.execute("CREATE EXTENSION IF NOT EXISTS ai CASCADE")

yield postgres_container, conn
yield test_database, conn


@pytest.fixture
def cli_db_url(cli_db: tuple[PostgresContainer, Connection]) -> str:
def cli_db_url(cli_db: tuple[TestDatabase, Connection]) -> str:
"""Constructs database URL from the cli_db fixture"""
container, _ = cli_db
return container.get_connection_url()
Expand All @@ -53,24 +92,14 @@ def test_worker_no_tasks(cli_db_url: str):

@pytest.fixture
def configured_vectorizer_and_source_table(
cli_db: tuple[PostgresContainer, Connection],
cli_db: tuple[TestDatabase, Connection],
test_params: tuple[int, int, int, str, str],
) -> int:
"""Creates and configures a vectorizer for testing"""
num_items, concurrency, batch_size, chunking, formatting = test_params
_, conn = cli_db

with conn.cursor(row_factory=dict_row) as cur:
# Cleanup from previous runs
cur.execute("SELECT id FROM ai.vectorizer")
for row in cur.fetchall():
cur.execute("SELECT ai.drop_vectorizer(%s)", (row["id"],))

# Drop tables if they exist
cur.execute("DROP VIEW IF EXISTS blog_embedding")
cur.execute("DROP TABLE IF EXISTS blog_embedding_store")
cur.execute("DROP TABLE IF EXISTS blog")

# Create source table
cur.execute("""
CREATE TABLE blog (
Expand Down Expand Up @@ -135,7 +164,7 @@ class TestWithConfiguredVectorizer:
)
def test_process_vectorizer(
self,
cli_db: tuple[PostgresContainer, Connection],
cli_db: tuple[TestDatabase, Connection],
cli_db_url: str,
configured_vectorizer_and_source_table: int,
monkeypatch: pytest.MonkeyPatch,
Expand Down Expand Up @@ -198,7 +227,7 @@ def test_process_vectorizer(
)
def test_document_exceeds_model_context_length(
self,
cli_db: tuple[PostgresContainer, Connection],
cli_db: tuple[TestDatabase, Connection],
cli_db_url: str,
configured_vectorizer_and_source_table: int,
monkeypatch: pytest.MonkeyPatch,
Expand Down Expand Up @@ -272,7 +301,7 @@ def test_document_exceeds_model_context_length(
)
def test_invalid_api_key_error(
self,
cli_db: tuple[PostgresContainer, Connection],
cli_db: tuple[TestDatabase, Connection],
cli_db_url: str,
configured_vectorizer_and_source_table: int,
monkeypatch: pytest.MonkeyPatch,
Expand Down Expand Up @@ -331,7 +360,7 @@ def test_invalid_api_key_error(
)
def test_invalid_function_arguments(
self,
cli_db: tuple[PostgresContainer, Connection],
cli_db: tuple[TestDatabase, Connection],
cli_db_url: str,
configured_vectorizer_and_source_table: int,
monkeypatch: pytest.MonkeyPatch,
Expand Down Expand Up @@ -407,7 +436,7 @@ def test_vectorizer_exits_when_vectorizers_specified_but_missing(cli_db_url: str


def test_vectorizer_picks_up_new_vectorizer(
cli_db: tuple[PostgresContainer, Connection],
cli_db: tuple[TestDatabase, Connection],
):
postgres_container, con = cli_db
db_url = postgres_container.get_connection_url()
Expand Down

0 comments on commit 4ed938b

Please sign in to comment.