From 9b7cbf65deedcac6aa0f9d385a527b4f37b46831 Mon Sep 17 00:00:00 2001 From: Benedikt Bartscher Date: Wed, 19 Jun 2024 18:25:04 +0200 Subject: [PATCH 1/7] add bare sqlalchemy session, Closes #3512 --- reflex/model.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/reflex/model.py b/reflex/model.py index 088601cae4..2679058ab1 100644 --- a/reflex/model.py +++ b/reflex/model.py @@ -396,7 +396,7 @@ def select(cls): def session(url: str | None = None) -> sqlmodel.Session: - """Get a session to interact with the database. + """Get a sqlmodel session to interact with the database. Args: url: The database url. @@ -405,3 +405,15 @@ def session(url: str | None = None) -> sqlmodel.Session: A database session. """ return sqlmodel.Session(get_engine(url)) + + +def sqla_session(url: str | None = None) -> sqlalchemy.orm.Session: + """Get a bare sqlalchemy session to interact with the database. + + Args: + url: The database url. + + Returns: + A database session. + """ + return sqlalchemy.orm.Session(get_engine(url)) From fceccdfefe98fa193cb296ddfb5f249aa2dfd8a8 Mon Sep 17 00:00:00 2001 From: Benedikt Bartscher Date: Wed, 19 Jun 2024 19:27:58 +0200 Subject: [PATCH 2/7] expose sqla_session at module level, add tests, improve typing --- reflex.db | 0 reflex/__init__.py | 2 +- tests/test_model.py | 30 +++++--- tests/test_sqlalchemy.py | 162 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 181 insertions(+), 13 deletions(-) create mode 100644 reflex.db create mode 100644 tests/test_sqlalchemy.py diff --git a/reflex.db b/reflex.db new file mode 100644 index 0000000000..e69de29bb2 diff --git a/reflex/__init__.py b/reflex/__init__.py index a71a6cc46c..18d2b2efe8 100644 --- a/reflex/__init__.py +++ b/reflex/__init__.py @@ -304,7 +304,7 @@ "window_alert", ], "middleware": ["middleware", "Middleware"], - "model": ["session", "Model"], + "model": ["session", "sqla_session", "Model"], "state": [ "var", "Cookie", diff --git a/tests/test_model.py b/tests/test_model.py index ee0336b373..8272f9b2c6 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,3 +1,4 @@ +from pathlib import Path from typing import Optional from unittest import mock @@ -39,7 +40,7 @@ class ChildModel(Model): return ChildModel(name="name") -def test_default_primary_key(model_default_primary): +def test_default_primary_key(model_default_primary: Model): """Test that if a primary key is not defined a default is added. Args: @@ -48,7 +49,7 @@ def test_default_primary_key(model_default_primary): assert "id" in model_default_primary.__class__.__fields__ -def test_custom_primary_key(model_custom_primary): +def test_custom_primary_key(model_custom_primary: Model): """Test that if a primary key is defined no default key is added. Args: @@ -60,7 +61,10 @@ def test_custom_primary_key(model_custom_primary): @pytest.mark.filterwarnings( "ignore:This declarative base already contains a class with the same class name", ) -def test_automigration(tmp_working_dir, monkeypatch): +def test_automigration( + tmp_working_dir: Path, + monkeypatch: pytest.MonkeyPatch, +): """Test alembic automigration with add and drop table and column. Args: @@ -84,8 +88,10 @@ class AlembicThing(Model, table=True): # type: ignore t1: str with Model.get_db_engine().connect() as connection: - Model.alembic_autogenerate(connection=connection, message="Initial Revision") - Model.migrate() + assert Model.alembic_autogenerate( + connection=connection, message="Initial Revision" + ) + assert Model.migrate() version_scripts = list(versions.glob("*.py")) assert len(version_scripts) == 1 assert version_scripts[0].name.endswith("initial_revision.py") @@ -101,7 +107,7 @@ class AlembicThing(Model, table=True): # type: ignore t1: Optional[str] = "default" t2: str = "bar" - Model.migrate(autogenerate=True) + assert Model.migrate(autogenerate=True) assert len(list(versions.glob("*.py"))) == 2 with reflex.model.session() as session: @@ -120,7 +126,7 @@ class AlembicThing(Model, table=True): # type: ignore class AlembicThing(Model, table=True): # type: ignore t2: str = "bar" - Model.migrate(autogenerate=True) + assert Model.migrate(autogenerate=True) assert len(list(versions.glob("*.py"))) == 3 with reflex.model.session() as session: @@ -134,7 +140,7 @@ class AlembicSecond(Model, table=True): # type: ignore a: int = 42 b: float = 4.2 - Model.migrate(autogenerate=True) + assert Model.migrate(autogenerate=True) assert len(list(versions.glob("*.py"))) == 4 with reflex.model.session() as session: @@ -146,7 +152,7 @@ class AlembicSecond(Model, table=True): # type: ignore assert result[0].b == 4.2 # No-op - Model.migrate(autogenerate=True) + assert Model.migrate(autogenerate=True) assert len(list(versions.glob("*.py"))) == 4 # drop table (AlembicSecond) @@ -155,7 +161,7 @@ class AlembicSecond(Model, table=True): # type: ignore class AlembicThing(Model, table=True): # type: ignore t2: str = "bar" - Model.migrate(autogenerate=True) + assert Model.migrate(autogenerate=True) assert len(list(versions.glob("*.py"))) == 5 with reflex.model.session() as session: @@ -174,12 +180,12 @@ class AlembicThing(Model, table=True): # type: ignore # changing column type not supported by default t2: int = 42 - Model.migrate(autogenerate=True) + assert Model.migrate(autogenerate=True) assert len(list(versions.glob("*.py"))) == 5 # clear all metadata to avoid influencing subsequent tests sqlmodel.SQLModel.metadata.clear() # drop remaining tables - Model.migrate(autogenerate=True) + assert Model.migrate(autogenerate=True) assert len(list(versions.glob("*.py"))) == 6 diff --git a/tests/test_sqlalchemy.py b/tests/test_sqlalchemy.py new file mode 100644 index 0000000000..21932662b8 --- /dev/null +++ b/tests/test_sqlalchemy.py @@ -0,0 +1,162 @@ +from pathlib import Path +from typing import Optional +from unittest import mock + +import pytest +import sqlalchemy +from sqlalchemy import select +from sqlalchemy.orm import ( + DeclarativeBase, + Mapped, + MappedAsDataclass, + declared_attr, + mapped_column, +) + +import reflex.constants +import reflex.model +from reflex.model import Model, ModelRegistry, sqla_session + + +@pytest.mark.filterwarnings( + "ignore:This declarative base already contains a class with the same class name", +) +def test_automigration( + tmp_working_dir: Path, + monkeypatch: pytest.MonkeyPatch, +): + """Test alembic automigration with add and drop table and column. + + Args: + tmp_working_dir: directory where database and migrations are stored + monkeypatch: pytest fixture to overwrite attributes + """ + alembic_ini = tmp_working_dir / "alembic.ini" + versions = tmp_working_dir / "alembic" / "versions" + monkeypatch.setattr(reflex.constants, "ALEMBIC_CONFIG", str(alembic_ini)) + + config_mock = mock.Mock() + config_mock.db_url = f"sqlite:///{tmp_working_dir}/reflex.db" + monkeypatch.setattr(reflex.model, "get_config", mock.Mock(return_value=config_mock)) + + Model.alembic_init() + assert alembic_ini.exists() + assert versions.exists() + + class Base(DeclarativeBase): + @declared_attr.directive + def __tablename__(cls) -> str: + return f"{cls.__module__}_{cls.__name__}".lower() + + assert ModelRegistry.register(Base) + + class ModelBase(Base, MappedAsDataclass): + __abstract__ = True + id: Mapped[Optional[int]] = mapped_column(primary_key=True, default=None) + + # initial table + class AlembicThing(ModelBase): # pyright: ignore[reportGeneralTypeIssues] + t1: Mapped[str] = mapped_column(default="") + + with Model.get_db_engine().connect() as connection: + assert Model.alembic_autogenerate( + connection=connection, message="Initial Revision" + ) + assert Model.migrate() + version_scripts = list(versions.glob("*.py")) + assert len(version_scripts) == 1 + assert version_scripts[0].name.endswith("initial_revision.py") + + with sqla_session() as session: + session.add(AlembicThing(t1="foo")) + session.commit() + + ModelRegistry.get_metadata().clear() + + # Create column t2, mark t1 as optional with default + class AlembicThing(ModelBase): # pyright: ignore[reportGeneralTypeIssues] + t1: Mapped[Optional[str]] = mapped_column(default="default") + t2: Mapped[str] = mapped_column(default="bar") + + assert Model.migrate(autogenerate=True) + assert len(list(versions.glob("*.py"))) == 2 + + with sqla_session() as session: + session.add(AlembicThing(t2="baz")) + session.commit() + result = session.scalars(select(AlembicThing)).all() + assert len(result) == 2 + assert result[0].t1 == "foo" + assert result[0].t2 == "bar" + assert result[1].t1 == "default" + assert result[1].t2 == "baz" + + ModelRegistry.get_metadata().clear() + + # Drop column t1 + class AlembicThing(ModelBase): # pyright: ignore[reportGeneralTypeIssues] + t2: Mapped[str] = mapped_column(default="bar") + + assert Model.migrate(autogenerate=True) + assert len(list(versions.glob("*.py"))) == 3 + + with sqla_session() as session: + result = session.scalars(select(AlembicThing)).all() + assert len(result) == 2 + assert result[0].t2 == "bar" + assert result[1].t2 == "baz" + + # Add table + class AlembicSecond(ModelBase): + a: Mapped[int] = mapped_column(default=42) + b: Mapped[float] = mapped_column(default=4.2) + + assert Model.migrate(autogenerate=True) + assert len(list(versions.glob("*.py"))) == 4 + + with reflex.model.session() as session: + session.add(AlembicSecond(id=None)) + session.commit() + result = session.scalars(select(AlembicSecond)).all() + assert len(result) == 1 + assert result[0].a == 42 + assert result[0].b == 4.2 + + # No-op + # assert Model.migrate(autogenerate=True) + # assert len(list(versions.glob("*.py"))) == 4 + + # drop table (AlembicSecond) + ModelRegistry.get_metadata().clear() + + class AlembicThing(ModelBase): # pyright: ignore[reportGeneralTypeIssues] + t2: Mapped[str] = mapped_column(default="bar") + + assert Model.migrate(autogenerate=True) + assert len(list(versions.glob("*.py"))) == 5 + + with reflex.model.session() as session: + with pytest.raises(sqlalchemy.exc.OperationalError) as errctx: + _ = session.scalars(select(AlembicSecond)).all() + assert errctx.match(r"no such table: tests.test_session_alembicsecond") + # first table should still exist + result = session.scalars(select(AlembicThing)).all() + assert len(result) == 2 + assert result[0].t2 == "bar" + assert result[1].t2 == "baz" + + ModelRegistry.get_metadata().clear() + + class AlembicThing(ModelBase): + # changing column type not supported by default + t2: Mapped[int] = mapped_column(default=42) + + assert Model.migrate(autogenerate=True) + assert len(list(versions.glob("*.py"))) == 5 + + # clear all metadata to avoid influencing subsequent tests + ModelRegistry.get_metadata().clear() + + # drop remaining tables + assert Model.migrate(autogenerate=True) + assert len(list(versions.glob("*.py"))) == 6 From 390df29bb235cb1e5f6f87663b87d7cb950abb2d Mon Sep 17 00:00:00 2001 From: Benedikt Bartscher Date: Wed, 19 Jun 2024 19:33:32 +0200 Subject: [PATCH 3/7] fix table name --- tests/test_sqlalchemy.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_sqlalchemy.py b/tests/test_sqlalchemy.py index 21932662b8..4144076381 100644 --- a/tests/test_sqlalchemy.py +++ b/tests/test_sqlalchemy.py @@ -46,7 +46,7 @@ def test_automigration( class Base(DeclarativeBase): @declared_attr.directive def __tablename__(cls) -> str: - return f"{cls.__module__}_{cls.__name__}".lower() + return cls.__name__.lower() assert ModelRegistry.register(Base) @@ -138,7 +138,7 @@ class AlembicThing(ModelBase): # pyright: ignore[reportGeneralTypeIssues] with reflex.model.session() as session: with pytest.raises(sqlalchemy.exc.OperationalError) as errctx: _ = session.scalars(select(AlembicSecond)).all() - assert errctx.match(r"no such table: tests.test_session_alembicsecond") + assert errctx.match(r"no such table: alembicsecond") # first table should still exist result = session.scalars(select(AlembicThing)).all() assert len(result) == 2 From b2a35a1ed68739ff10aa149c8b4fe43e854440ca Mon Sep 17 00:00:00 2001 From: Benedikt Bartscher Date: Wed, 19 Jun 2024 20:05:06 +0200 Subject: [PATCH 4/7] add model_registry fixture, improve typing --- reflex/model.py | 2 +- tests/conftest.py | 9 ++++++++- tests/test_model.py | 16 +++++++++------- tests/test_sqlalchemy.py | 22 +++++++++++++--------- 4 files changed, 31 insertions(+), 18 deletions(-) diff --git a/reflex/model.py b/reflex/model.py index 2679058ab1..71e26f76a7 100644 --- a/reflex/model.py +++ b/reflex/model.py @@ -24,7 +24,7 @@ from reflex.utils.compat import sqlmodel -def get_engine(url: str | None = None): +def get_engine(url: str | None = None) -> sqlalchemy.engine.Engine: """Get the database engine. Args: diff --git a/tests/conftest.py b/tests/conftest.py index b896edba4a..1c4d452e9a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,13 +5,14 @@ import platform import uuid from pathlib import Path -from typing import Dict, Generator +from typing import Dict, Generator, Type from unittest import mock import pytest from reflex.app import App from reflex.event import EventSpec +from reflex.model import ModelRegistry from reflex.utils import prerequisites from .states import ( @@ -247,3 +248,9 @@ def token() -> str: A fresh/unique token string. """ return str(uuid.uuid4()) + + +@pytest.fixture +def model_registry() -> Generator[Type[ModelRegistry], None, None]: + yield ModelRegistry + ModelRegistry._metadata = None diff --git a/tests/test_model.py b/tests/test_model.py index 8272f9b2c6..ac8187e031 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Optional +from typing import Optional, Type from unittest import mock import pytest @@ -8,7 +8,7 @@ import reflex.constants import reflex.model -from reflex.model import Model +from reflex.model import Model, ModelRegistry @pytest.fixture @@ -64,12 +64,14 @@ def test_custom_primary_key(model_custom_primary: Model): def test_automigration( tmp_working_dir: Path, monkeypatch: pytest.MonkeyPatch, + model_registry: Type[ModelRegistry], ): """Test alembic automigration with add and drop table and column. Args: tmp_working_dir: directory where database and migrations are stored monkeypatch: pytest fixture to overwrite attributes + model_registry: clean reflex ModelRegistry """ alembic_ini = tmp_working_dir / "alembic.ini" versions = tmp_working_dir / "alembic" / "versions" @@ -100,7 +102,7 @@ class AlembicThing(Model, table=True): # type: ignore session.add(AlembicThing(id=None, t1="foo")) session.commit() - sqlmodel.SQLModel.metadata.clear() + model_registry.get_metadata().clear() # Create column t2, mark t1 as optional with default class AlembicThing(Model, table=True): # type: ignore @@ -120,7 +122,7 @@ class AlembicThing(Model, table=True): # type: ignore assert result[1].t1 == "default" assert result[1].t2 == "baz" - sqlmodel.SQLModel.metadata.clear() + model_registry.get_metadata().clear() # Drop column t1 class AlembicThing(Model, table=True): # type: ignore @@ -156,7 +158,7 @@ class AlembicSecond(Model, table=True): # type: ignore assert len(list(versions.glob("*.py"))) == 4 # drop table (AlembicSecond) - sqlmodel.SQLModel.metadata.clear() + model_registry.get_metadata().clear() class AlembicThing(Model, table=True): # type: ignore t2: str = "bar" @@ -174,7 +176,7 @@ class AlembicThing(Model, table=True): # type: ignore assert result[0].t2 == "bar" assert result[1].t2 == "baz" - sqlmodel.SQLModel.metadata.clear() + model_registry.get_metadata().clear() class AlembicThing(Model, table=True): # type: ignore # changing column type not supported by default @@ -184,7 +186,7 @@ class AlembicThing(Model, table=True): # type: ignore assert len(list(versions.glob("*.py"))) == 5 # clear all metadata to avoid influencing subsequent tests - sqlmodel.SQLModel.metadata.clear() + model_registry.get_metadata().clear() # drop remaining tables assert Model.migrate(autogenerate=True) diff --git a/tests/test_sqlalchemy.py b/tests/test_sqlalchemy.py index 4144076381..b18799e0c1 100644 --- a/tests/test_sqlalchemy.py +++ b/tests/test_sqlalchemy.py @@ -1,10 +1,10 @@ from pathlib import Path -from typing import Optional +from typing import Optional, Type from unittest import mock import pytest -import sqlalchemy from sqlalchemy import select +from sqlalchemy.exc import OperationalError from sqlalchemy.orm import ( DeclarativeBase, Mapped, @@ -24,12 +24,14 @@ def test_automigration( tmp_working_dir: Path, monkeypatch: pytest.MonkeyPatch, + model_registry: Type[ModelRegistry], ): """Test alembic automigration with add and drop table and column. Args: tmp_working_dir: directory where database and migrations are stored monkeypatch: pytest fixture to overwrite attributes + model_registry: clean reflex ModelRegistry """ alembic_ini = tmp_working_dir / "alembic.ini" versions = tmp_working_dir / "alembic" / "versions" @@ -39,6 +41,8 @@ def test_automigration( config_mock.db_url = f"sqlite:///{tmp_working_dir}/reflex.db" monkeypatch.setattr(reflex.model, "get_config", mock.Mock(return_value=config_mock)) + assert alembic_ini.exists() is False + assert versions.exists() is False Model.alembic_init() assert alembic_ini.exists() assert versions.exists() @@ -48,7 +52,7 @@ class Base(DeclarativeBase): def __tablename__(cls) -> str: return cls.__name__.lower() - assert ModelRegistry.register(Base) + assert model_registry.register(Base) class ModelBase(Base, MappedAsDataclass): __abstract__ = True @@ -71,7 +75,7 @@ class AlembicThing(ModelBase): # pyright: ignore[reportGeneralTypeIssues] session.add(AlembicThing(t1="foo")) session.commit() - ModelRegistry.get_metadata().clear() + model_registry.get_metadata().clear() # Create column t2, mark t1 as optional with default class AlembicThing(ModelBase): # pyright: ignore[reportGeneralTypeIssues] @@ -91,7 +95,7 @@ class AlembicThing(ModelBase): # pyright: ignore[reportGeneralTypeIssues] assert result[1].t1 == "default" assert result[1].t2 == "baz" - ModelRegistry.get_metadata().clear() + model_registry.get_metadata().clear() # Drop column t1 class AlembicThing(ModelBase): # pyright: ignore[reportGeneralTypeIssues] @@ -127,7 +131,7 @@ class AlembicSecond(ModelBase): # assert len(list(versions.glob("*.py"))) == 4 # drop table (AlembicSecond) - ModelRegistry.get_metadata().clear() + model_registry.get_metadata().clear() class AlembicThing(ModelBase): # pyright: ignore[reportGeneralTypeIssues] t2: Mapped[str] = mapped_column(default="bar") @@ -136,7 +140,7 @@ class AlembicThing(ModelBase): # pyright: ignore[reportGeneralTypeIssues] assert len(list(versions.glob("*.py"))) == 5 with reflex.model.session() as session: - with pytest.raises(sqlalchemy.exc.OperationalError) as errctx: + with pytest.raises(OperationalError) as errctx: _ = session.scalars(select(AlembicSecond)).all() assert errctx.match(r"no such table: alembicsecond") # first table should still exist @@ -145,7 +149,7 @@ class AlembicThing(ModelBase): # pyright: ignore[reportGeneralTypeIssues] assert result[0].t2 == "bar" assert result[1].t2 == "baz" - ModelRegistry.get_metadata().clear() + model_registry.get_metadata().clear() class AlembicThing(ModelBase): # changing column type not supported by default @@ -155,7 +159,7 @@ class AlembicThing(ModelBase): assert len(list(versions.glob("*.py"))) == 5 # clear all metadata to avoid influencing subsequent tests - ModelRegistry.get_metadata().clear() + model_registry.get_metadata().clear() # drop remaining tables assert Model.migrate(autogenerate=True) From 1b7da3136e2b65900f3e7921f1460c3b61dee579 Mon Sep 17 00:00:00 2001 From: Benedikt Bartscher Date: Wed, 19 Jun 2024 20:16:26 +0200 Subject: [PATCH 5/7] did not meant to push this --- .gitignore | 1 + reflex.db | 0 2 files changed, 1 insertion(+) delete mode 100644 reflex.db diff --git a/.gitignore b/.gitignore index a570ed353f..bbaa8f0c92 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,4 @@ venv requirements.txt .pyi_generator_last_run .pyi_generator_diff +reflex.db diff --git a/reflex.db b/reflex.db deleted file mode 100644 index e69de29bb2..0000000000 From e726ae13ab36677e9fa3e6595d5466ca4e54e943 Mon Sep 17 00:00:00 2001 From: Benedikt Bartscher Date: Wed, 19 Jun 2024 20:18:10 +0200 Subject: [PATCH 6/7] add docstring to model_registry --- tests/conftest.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/conftest.py b/tests/conftest.py index 1c4d452e9a..71815ca9ad 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -252,5 +252,10 @@ def token() -> str: @pytest.fixture def model_registry() -> Generator[Type[ModelRegistry], None, None]: + """Create a model registry. + + Yields: + A fresh model registry. + """ yield ModelRegistry ModelRegistry._metadata = None From e7c6f96c2623b2d7348678767dcdf55081cbcd3a Mon Sep 17 00:00:00 2001 From: Benedikt Bartscher Date: Tue, 25 Jun 2024 13:12:51 +0200 Subject: [PATCH 7/7] do not expose sqla_session in reflex namespace --- reflex/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/reflex/__init__.py b/reflex/__init__.py index 18d2b2efe8..a71a6cc46c 100644 --- a/reflex/__init__.py +++ b/reflex/__init__.py @@ -304,7 +304,7 @@ "window_alert", ], "middleware": ["middleware", "Middleware"], - "model": ["session", "sqla_session", "Model"], + "model": ["session", "Model"], "state": [ "var", "Cookie",