Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bare sqlalchemy session + tests #3522

Merged
merged 7 commits into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ venv
requirements.txt
.pyi_generator_last_run
.pyi_generator_diff
reflex.db
16 changes: 14 additions & 2 deletions reflex/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from reflex.utils.compat import sqlmodel


def get_engine(url: str | None = None):
def get_engine(url: str | None = None) -> sqlalchemy.engine.Engine:
"""Get the database engine.

Args:
Expand Down Expand Up @@ -396,7 +396,7 @@ def select(cls):


def session(url: str | None = None) -> sqlmodel.Session:
"""Get a session to interact with the database.
"""Get a sqlmodel session to interact with the database.

Args:
url: The database url.
Expand All @@ -405,3 +405,15 @@ def session(url: str | None = None) -> sqlmodel.Session:
A database session.
"""
return sqlmodel.Session(get_engine(url))


def sqla_session(url: str | None = None) -> sqlalchemy.orm.Session:
"""Get a bare sqlalchemy session to interact with the database.

Args:
url: The database url.

Returns:
A database session.
"""
return sqlalchemy.orm.Session(get_engine(url))
14 changes: 13 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@
import platform
import uuid
from pathlib import Path
from typing import Dict, Generator
from typing import Dict, Generator, Type
from unittest import mock

import pytest

from reflex.app import App
from reflex.event import EventSpec
from reflex.model import ModelRegistry
from reflex.utils import prerequisites

from .states import (
Expand Down Expand Up @@ -247,3 +248,14 @@ def token() -> str:
A fresh/unique token string.
"""
return str(uuid.uuid4())


@pytest.fixture
def model_registry() -> Generator[Type[ModelRegistry], None, None]:
"""Create a model registry.

Yields:
A fresh model registry.
"""
yield ModelRegistry
ModelRegistry._metadata = None
46 changes: 27 additions & 19 deletions tests/test_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Optional
from pathlib import Path
from typing import Optional, Type
from unittest import mock

import pytest
Expand All @@ -7,7 +8,7 @@

import reflex.constants
import reflex.model
from reflex.model import Model
from reflex.model import Model, ModelRegistry


@pytest.fixture
Expand Down Expand Up @@ -39,7 +40,7 @@ class ChildModel(Model):
return ChildModel(name="name")


def test_default_primary_key(model_default_primary):
def test_default_primary_key(model_default_primary: Model):
"""Test that if a primary key is not defined a default is added.

Args:
Expand All @@ -48,7 +49,7 @@ def test_default_primary_key(model_default_primary):
assert "id" in model_default_primary.__class__.__fields__


def test_custom_primary_key(model_custom_primary):
def test_custom_primary_key(model_custom_primary: Model):
"""Test that if a primary key is defined no default key is added.

Args:
Expand All @@ -60,12 +61,17 @@ def test_custom_primary_key(model_custom_primary):
@pytest.mark.filterwarnings(
"ignore:This declarative base already contains a class with the same class name",
)
def test_automigration(tmp_working_dir, monkeypatch):
def test_automigration(
tmp_working_dir: Path,
monkeypatch: pytest.MonkeyPatch,
model_registry: Type[ModelRegistry],
):
"""Test alembic automigration with add and drop table and column.

Args:
tmp_working_dir: directory where database and migrations are stored
monkeypatch: pytest fixture to overwrite attributes
model_registry: clean reflex ModelRegistry
"""
alembic_ini = tmp_working_dir / "alembic.ini"
versions = tmp_working_dir / "alembic" / "versions"
Expand All @@ -84,8 +90,10 @@ class AlembicThing(Model, table=True): # type: ignore
t1: str

with Model.get_db_engine().connect() as connection:
Model.alembic_autogenerate(connection=connection, message="Initial Revision")
Model.migrate()
assert Model.alembic_autogenerate(
connection=connection, message="Initial Revision"
)
assert Model.migrate()
version_scripts = list(versions.glob("*.py"))
assert len(version_scripts) == 1
assert version_scripts[0].name.endswith("initial_revision.py")
Expand All @@ -94,14 +102,14 @@ class AlembicThing(Model, table=True): # type: ignore
session.add(AlembicThing(id=None, t1="foo"))
session.commit()

sqlmodel.SQLModel.metadata.clear()
model_registry.get_metadata().clear()

# Create column t2, mark t1 as optional with default
class AlembicThing(Model, table=True): # type: ignore
t1: Optional[str] = "default"
t2: str = "bar"

Model.migrate(autogenerate=True)
assert Model.migrate(autogenerate=True)
assert len(list(versions.glob("*.py"))) == 2

with reflex.model.session() as session:
Expand All @@ -114,13 +122,13 @@ class AlembicThing(Model, table=True): # type: ignore
assert result[1].t1 == "default"
assert result[1].t2 == "baz"

sqlmodel.SQLModel.metadata.clear()
model_registry.get_metadata().clear()

# Drop column t1
class AlembicThing(Model, table=True): # type: ignore
t2: str = "bar"

Model.migrate(autogenerate=True)
assert Model.migrate(autogenerate=True)
assert len(list(versions.glob("*.py"))) == 3

with reflex.model.session() as session:
Expand All @@ -134,7 +142,7 @@ class AlembicSecond(Model, table=True): # type: ignore
a: int = 42
b: float = 4.2

Model.migrate(autogenerate=True)
assert Model.migrate(autogenerate=True)
assert len(list(versions.glob("*.py"))) == 4

with reflex.model.session() as session:
Expand All @@ -146,16 +154,16 @@ class AlembicSecond(Model, table=True): # type: ignore
assert result[0].b == 4.2

# No-op
Model.migrate(autogenerate=True)
assert Model.migrate(autogenerate=True)
assert len(list(versions.glob("*.py"))) == 4

# drop table (AlembicSecond)
sqlmodel.SQLModel.metadata.clear()
model_registry.get_metadata().clear()

class AlembicThing(Model, table=True): # type: ignore
t2: str = "bar"

Model.migrate(autogenerate=True)
assert Model.migrate(autogenerate=True)
assert len(list(versions.glob("*.py"))) == 5

with reflex.model.session() as session:
Expand All @@ -168,18 +176,18 @@ class AlembicThing(Model, table=True): # type: ignore
assert result[0].t2 == "bar"
assert result[1].t2 == "baz"

sqlmodel.SQLModel.metadata.clear()
model_registry.get_metadata().clear()

class AlembicThing(Model, table=True): # type: ignore
# changing column type not supported by default
t2: int = 42

Model.migrate(autogenerate=True)
assert Model.migrate(autogenerate=True)
assert len(list(versions.glob("*.py"))) == 5

# clear all metadata to avoid influencing subsequent tests
sqlmodel.SQLModel.metadata.clear()
model_registry.get_metadata().clear()

# drop remaining tables
Model.migrate(autogenerate=True)
assert Model.migrate(autogenerate=True)
assert len(list(versions.glob("*.py"))) == 6
166 changes: 166 additions & 0 deletions tests/test_sqlalchemy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
from pathlib import Path
from typing import Optional, Type
from unittest import mock

import pytest
from sqlalchemy import select
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import (
DeclarativeBase,
Mapped,
MappedAsDataclass,
declared_attr,
mapped_column,
)

import reflex.constants
import reflex.model
from reflex.model import Model, ModelRegistry, sqla_session


@pytest.mark.filterwarnings(
"ignore:This declarative base already contains a class with the same class name",
)
def test_automigration(
tmp_working_dir: Path,
monkeypatch: pytest.MonkeyPatch,
model_registry: Type[ModelRegistry],
):
"""Test alembic automigration with add and drop table and column.

Args:
tmp_working_dir: directory where database and migrations are stored
monkeypatch: pytest fixture to overwrite attributes
model_registry: clean reflex ModelRegistry
"""
alembic_ini = tmp_working_dir / "alembic.ini"
versions = tmp_working_dir / "alembic" / "versions"
monkeypatch.setattr(reflex.constants, "ALEMBIC_CONFIG", str(alembic_ini))

config_mock = mock.Mock()
config_mock.db_url = f"sqlite:///{tmp_working_dir}/reflex.db"
monkeypatch.setattr(reflex.model, "get_config", mock.Mock(return_value=config_mock))

assert alembic_ini.exists() is False
assert versions.exists() is False
Model.alembic_init()
assert alembic_ini.exists()
assert versions.exists()

class Base(DeclarativeBase):
@declared_attr.directive
def __tablename__(cls) -> str:
return cls.__name__.lower()

assert model_registry.register(Base)

class ModelBase(Base, MappedAsDataclass):
__abstract__ = True
id: Mapped[Optional[int]] = mapped_column(primary_key=True, default=None)

# initial table
class AlembicThing(ModelBase): # pyright: ignore[reportGeneralTypeIssues]
t1: Mapped[str] = mapped_column(default="")

with Model.get_db_engine().connect() as connection:
assert Model.alembic_autogenerate(
connection=connection, message="Initial Revision"
)
assert Model.migrate()
version_scripts = list(versions.glob("*.py"))
assert len(version_scripts) == 1
assert version_scripts[0].name.endswith("initial_revision.py")

with sqla_session() as session:
session.add(AlembicThing(t1="foo"))
session.commit()

model_registry.get_metadata().clear()

# Create column t2, mark t1 as optional with default
class AlembicThing(ModelBase): # pyright: ignore[reportGeneralTypeIssues]
t1: Mapped[Optional[str]] = mapped_column(default="default")
t2: Mapped[str] = mapped_column(default="bar")

assert Model.migrate(autogenerate=True)
assert len(list(versions.glob("*.py"))) == 2

with sqla_session() as session:
session.add(AlembicThing(t2="baz"))
session.commit()
result = session.scalars(select(AlembicThing)).all()
assert len(result) == 2
assert result[0].t1 == "foo"
assert result[0].t2 == "bar"
assert result[1].t1 == "default"
assert result[1].t2 == "baz"

model_registry.get_metadata().clear()

# Drop column t1
class AlembicThing(ModelBase): # pyright: ignore[reportGeneralTypeIssues]
t2: Mapped[str] = mapped_column(default="bar")

assert Model.migrate(autogenerate=True)
assert len(list(versions.glob("*.py"))) == 3

with sqla_session() as session:
result = session.scalars(select(AlembicThing)).all()
assert len(result) == 2
assert result[0].t2 == "bar"
assert result[1].t2 == "baz"

# Add table
class AlembicSecond(ModelBase):
a: Mapped[int] = mapped_column(default=42)
b: Mapped[float] = mapped_column(default=4.2)

assert Model.migrate(autogenerate=True)
assert len(list(versions.glob("*.py"))) == 4

with reflex.model.session() as session:
session.add(AlembicSecond(id=None))
session.commit()
result = session.scalars(select(AlembicSecond)).all()
assert len(result) == 1
assert result[0].a == 42
assert result[0].b == 4.2

# No-op
# assert Model.migrate(autogenerate=True)
# assert len(list(versions.glob("*.py"))) == 4

# drop table (AlembicSecond)
model_registry.get_metadata().clear()

class AlembicThing(ModelBase): # pyright: ignore[reportGeneralTypeIssues]
t2: Mapped[str] = mapped_column(default="bar")

assert Model.migrate(autogenerate=True)
assert len(list(versions.glob("*.py"))) == 5

with reflex.model.session() as session:
with pytest.raises(OperationalError) as errctx:
_ = session.scalars(select(AlembicSecond)).all()
assert errctx.match(r"no such table: alembicsecond")
# first table should still exist
result = session.scalars(select(AlembicThing)).all()
assert len(result) == 2
assert result[0].t2 == "bar"
assert result[1].t2 == "baz"

model_registry.get_metadata().clear()

class AlembicThing(ModelBase):
# changing column type not supported by default
t2: Mapped[int] = mapped_column(default=42)

assert Model.migrate(autogenerate=True)
assert len(list(versions.glob("*.py"))) == 5

# clear all metadata to avoid influencing subsequent tests
model_registry.get_metadata().clear()

# drop remaining tables
assert Model.migrate(autogenerate=True)
assert len(list(versions.glob("*.py"))) == 6
Loading