Skip to content

Commit 594c83e

Browse files
omar-selomz2
authored andcommitted
Refactor tests, use db migrations, correct session, rollback automatically (#12)
* Refactor tests, use db migrations, correct session, rollback automatically * Remove unnecessary test * Remove unnecessary package * Fix rollback to be global * Remove unused functions and improve a test case * Return accidentally deleted test case * Bring back a couple of tests and fix one * Apply small review suggestions
1 parent 19e3d84 commit 594c83e

8 files changed

+125
-213
lines changed

backend/poetry.lock

+2-20
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

backend/pyproject.toml

-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ fallback_version = "0.0.0"
2121

2222
[tool.poetry.group.test.dependencies]
2323
pytest = "^7.3.1"
24-
pytest-mock = "^3.10.0"
2524
requests-mock = "^1.10.0"
2625
python-multipart = "^0.0.6"
2726
httpx = "^0.24.0"

backend/src/main.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,8 @@ async def get_version():
7373
@app.put("/snapmanager")
7474
def snap_manager(db: Session = Depends(get_db)):
7575
try:
76-
session = sessionmaker(autocommit=False, autoflush=False, bind=engine)
77-
with session() as sess:
78-
processed_artefacts = snap_manager_controller(sess)
79-
logger.info("INFO: Processed artefacts %s", processed_artefacts)
76+
processed_artefacts = snap_manager_controller(db)
77+
logger.info("INFO: Processed artefacts %s", processed_artefacts)
8078
if False in processed_artefacts.values():
8179
return JSONResponse(
8280
status_code=500,

backend/src/repository.py

+1-33
Original file line numberDiff line numberDiff line change
@@ -22,26 +22,6 @@
2222
from .data_access.models import Family, Stage, Artefact
2323

2424

25-
def get_stages_by_family_name(session: Session, family_name: str) -> list | None:
26-
"""
27-
Fetch stages objects related to specific family
28-
29-
:session: DB session
30-
:family_name: name of the family
31-
:return: list of stages
32-
"""
33-
family = session.query(Family).filter(Family.name == family_name).first()
34-
if family is None:
35-
return []
36-
stages = (
37-
session.query(Stage)
38-
.filter(Stage.family_id == family.id)
39-
.options(joinedload(Stage.artefacts))
40-
.all()
41-
)
42-
return stages
43-
44-
4525
def get_stage_by_name(session: Session, stage_name: str, family: Family) -> Stage:
4626
"""
4727
Get the stage object by its name
@@ -59,21 +39,9 @@ def get_stage_by_name(session: Session, stage_name: str, family: Family) -> Stag
5939
return stage
6040

6141

62-
def get_family_by_name(session: Session, family_name: str):
63-
"""
64-
Get the family object by its name
65-
66-
:session: DB session
67-
:family_name: Name of the family
68-
:return: Family
69-
"""
70-
family = session.query(Family).filter(Family.name == family_name).first()
71-
return family
72-
73-
7442
def get_artefacts_by_family_name(
7543
session: Session, family_name: str, is_archived: bool = None
76-
):
44+
) -> list[Artefact]:
7745
"""
7846
Get all the artefacts in a family
7947

backend/tests/conftest.py

+34-79
Original file line numberDiff line numberDiff line change
@@ -16,103 +16,58 @@
1616
#
1717
# Written by:
1818
# Nadzeya Hutsko <nadzeya.hutsko@canonical.com>
19+
# Omar Selo <omar.selo@canonical.com>
1920
"""Fixtures for testing"""
2021

2122

2223
import pytest
23-
from sqlalchemy import create_engine
24-
from sqlalchemy.orm import sessionmaker, Session
25-
from sqlalchemy_utils import database_exists, create_database, drop_database
24+
from alembic import command
25+
from alembic.config import Config
2626
from fastapi.testclient import TestClient
27-
from src.main import app
27+
from sqlalchemy import Engine, create_engine
28+
from sqlalchemy.orm import Session, sessionmaker
29+
from sqlalchemy_utils import create_database, database_exists, drop_database
2830
from src.data_access import Base
29-
from src.data_access.models import Family, Stage, Artefact
31+
from src.data_access.models import Artefact, Stage
32+
from src.main import app, get_db
3033

3134

32-
# Setup Test Database
33-
SQLALCHEMY_DATABASE_URL = (
34-
"postgresql+pg8000://postgres:password@test-observer-db:5432/test"
35-
)
36-
engine = create_engine(SQLALCHEMY_DATABASE_URL)
37-
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
35+
@pytest.fixture(scope="session")
36+
def db_engine():
37+
db_uri = "postgresql+pg8000://postgres:password@test-observer-db:5432/test"
3838

39+
if not database_exists(db_uri):
40+
create_database(db_uri)
3941

40-
@pytest.fixture
41-
def seed_db(db_session: Session):
42-
"""Populate database with fake data"""
43-
# Snap family
44-
family = Family(name="snap")
45-
db_session.add(family)
46-
# Edge stage
47-
stage = Stage(name="edge", family=family, position=10)
48-
db_session.add(stage)
49-
artefact = Artefact(
50-
name="core20", stage=stage, version="1.1.1", source={}, artefact_group=None
51-
)
52-
db_session.add(artefact)
53-
artefact = Artefact(
54-
name="docker",
55-
stage=stage,
56-
version="1.1.1",
57-
source={},
58-
artefact_group=None,
59-
is_archived=True,
60-
)
61-
db_session.add(artefact)
62-
# Beta stage
63-
stage = Stage(name="beta", family=family, position=20)
64-
db_session.add(stage)
65-
artefact = Artefact(
66-
name="core22", stage=stage, version="1.1.0", source={}, artefact_group=None
67-
)
68-
db_session.add(artefact)
42+
engine = create_engine(db_uri)
6943

70-
# Deb family
71-
family = Family(name="deb")
72-
db_session.add(family)
73-
# Proposed stage
74-
stage = Stage(name="proposed", family=family, position=10)
75-
db_session.add(stage)
76-
artefact = Artefact(
77-
name="jammy", stage=stage, version="2.1.1", source={}, artefact_group=None
78-
)
79-
db_session.add(artefact)
80-
# Updates stage
81-
stage = Stage(name="updates", family=family, position=10)
82-
db_session.add(stage)
83-
artefact = Artefact(
84-
name="raspi", stage=stage, version="2.1.0", source={}, artefact_group=None
85-
)
86-
db_session.add(artefact)
87-
db_session.commit()
44+
alembic_config = Config("alembic.ini")
45+
alembic_config.set_main_option("sqlalchemy.url", db_uri)
46+
command.upgrade(alembic_config, "head")
8847

89-
yield
48+
yield engine
9049

91-
# Cleanup
92-
db_session.query(Artefact).delete()
93-
db_session.query(Stage).delete()
94-
db_session.query(Family).delete()
95-
db_session.commit()
50+
Base.metadata.drop_all(engine)
51+
engine.dispose()
52+
drop_database(db_uri)
9653

9754

98-
@pytest.fixture(scope="session")
99-
def db_session():
100-
"""Set up and tear down the test database"""
101-
if not database_exists(SQLALCHEMY_DATABASE_URL):
102-
create_database(SQLALCHEMY_DATABASE_URL)
55+
@pytest.fixture(scope="function")
56+
def db_session(db_engine: Engine):
57+
connection = db_engine.connect()
58+
# Start transaction and not commit it to rollback automatically
59+
transaction = connection.begin()
60+
session = sessionmaker(autocommit=False, autoflush=False, bind=connection)()
10361

104-
Base.metadata.create_all(bind=engine)
105-
session = TestingSessionLocal()
10662
yield session
10763

108-
# Cleanup
10964
session.close()
110-
Base.metadata.drop_all(bind=engine)
111-
drop_database(SQLALCHEMY_DATABASE_URL)
65+
transaction.close()
66+
connection.close()
11267

11368

114-
@pytest.fixture(scope="session")
115-
def test_app():
116-
"""Create a pytest fixture for the app"""
117-
client = TestClient(app)
118-
yield client
69+
@pytest.fixture(scope="function")
70+
def test_client(db_session: Session) -> TestClient:
71+
"""Create a test http client"""
72+
app.dependency_overrides[get_db] = lambda: db_session
73+
return TestClient(app)

backend/tests/helpers.py

+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from sqlalchemy.orm import Session
2+
from src.data_access.models import Artefact, Stage
3+
4+
5+
def create_artefact(db_session: Session, stage_name: str, **kwargs):
6+
"""Create a dummy artefact"""
7+
stage = db_session.query(Stage).filter(Stage.name == stage_name).first()
8+
artefact = Artefact(
9+
name=kwargs.get("name", ""),
10+
stage=stage,
11+
version=kwargs.get("version", "1.1.1"),
12+
source=kwargs.get("source", {}),
13+
artefact_group=None,
14+
is_archived=kwargs.get("is_archived", False),
15+
)
16+
db_session.add(artefact)
17+
db_session.commit()
18+
return artefact

0 commit comments

Comments
 (0)