From efadb35c4d90b87cf95b85e3abd599d1db2b2cde Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Sun, 10 Oct 2021 08:23:06 +0200 Subject: [PATCH 01/16] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20REFACTOR:=20Remove?= =?UTF-8?q?=20direct=20access=20to=20BackendManager?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The `BackendManager` is an implementation detail of the Django/SQLA backends and should not be accessed directly. --- aiida/backends/__init__.py | 17 ------ aiida/backends/djsite/__init__.py | 6 +- aiida/backends/djsite/manager.py | 24 -------- aiida/backends/manager.py | 23 ------- aiida/backends/sqlalchemy/__init__.py | 6 +- aiida/backends/sqlalchemy/manager.py | 13 ---- aiida/cmdline/commands/cmd_database.py | 6 +- aiida/cmdline/commands/cmd_setup.py | 3 +- aiida/engine/utils.py | 6 +- aiida/manage/manager.py | 61 ++++++------------- aiida/orm/implementation/django/__init__.py | 8 --- aiida/orm/implementation/django/backend.py | 21 ++++++- aiida/orm/implementation/django/convert.py | 2 - aiida/orm/implementation/django/groups.py | 2 - aiida/orm/implementation/django/users.py | 2 - aiida/orm/implementation/sql/backends.py | 57 +++++++++++++++++ .../orm/implementation/sqlalchemy/backend.py | 16 ++++- tests/cmdline/commands/test_database.py | 6 +- tests/cmdline/commands/test_setup.py | 8 +-- tests/restapi/conftest.py | 6 +- 20 files changed, 131 insertions(+), 162 deletions(-) diff --git a/aiida/backends/__init__.py b/aiida/backends/__init__.py index 81095dac98..1e4705626e 100644 --- a/aiida/backends/__init__.py +++ b/aiida/backends/__init__.py @@ -11,20 +11,3 @@ BACKEND_DJANGO = 'django' BACKEND_SQLA = 'sqlalchemy' - - -def get_backend_manager(backend): - """Get an instance of the `BackendManager` for the current backend. - - :param backend: the type of the database backend - :return: `BackendManager` - """ - if backend == BACKEND_DJANGO: - from aiida.backends.djsite.manager import DjangoBackendManager - return DjangoBackendManager() - - if backend == BACKEND_SQLA: - from aiida.backends.sqlalchemy.manager import SqlaBackendManager - return SqlaBackendManager() - - raise Exception(f'unknown backend type `{backend}`') diff --git a/aiida/backends/djsite/__init__.py b/aiida/backends/djsite/__init__.py index 94f96ba742..e1dd0e202d 100644 --- a/aiida/backends/djsite/__init__.py +++ b/aiida/backends/djsite/__init__.py @@ -31,7 +31,7 @@ def reset_session(): SESSION_FACTORY = None -def get_scoped_session(**kwargs): +def get_scoped_session(profile=None, **kwargs): """Return a scoped session for the given profile that is exclusively to be used for the `QueryBuilder`. Since the `QueryBuilder` implementation uses SqlAlchemy to map the query onto the models in order to generate the @@ -51,6 +51,8 @@ def get_scoped_session(**kwargs): :return: :class:`sqlalchemy.orm.session.Session` instance with engine configured for the given profile. 
""" from aiida.manage.configuration import get_profile + if profile is None: + profile = get_profile() global ENGINE global SESSION_FACTORY @@ -60,7 +62,7 @@ def get_scoped_session(**kwargs): return session if ENGINE is None: - ENGINE = create_sqlalchemy_engine(get_profile(), **kwargs) + ENGINE = create_sqlalchemy_engine(profile, **kwargs) SESSION_FACTORY = create_scoped_session_factory(ENGINE) diff --git a/aiida/backends/djsite/manager.py b/aiida/backends/djsite/manager.py index b635835b22..2cfd76efed 100644 --- a/aiida/backends/djsite/manager.py +++ b/aiida/backends/djsite/manager.py @@ -9,11 +9,6 @@ ########################################################################### # pylint: disable=import-error,no-name-in-module """Utilities and configuration of the Django database schema.""" - -import os - -import django - from aiida.common import NotExistent from ..manager import SCHEMA_VERSION_DESCRIPTION, SCHEMA_VERSION_KEY, BackendManager, Setting, SettingsManager @@ -35,25 +30,6 @@ def get_settings_manager(self): return self._settings_manager - def _load_backend_environment(self, **kwargs): - """Load the backend environment. - - The scoped session is needed for the QueryBuilder only. - - :param kwargs: keyword arguments that will be passed on to :py:func:`aiida.backends.djsite.get_scoped_session`. - """ - os.environ['DJANGO_SETTINGS_MODULE'] = 'aiida.backends.djsite.settings' - django.setup() # pylint: disable=no-member - - # For QueryBuilder only - from . import get_scoped_session - get_scoped_session(**kwargs) - - def reset_backend_environment(self): - """Reset the backend environment.""" - from . import reset_session - reset_session() - def is_database_schema_ahead(self): """Determine whether the database schema version is ahead of the code schema version. diff --git a/aiida/backends/manager.py b/aiida/backends/manager.py index 5a57baf141..082b1d8f0d 100644 --- a/aiida/backends/manager.py +++ b/aiida/backends/manager.py @@ -107,29 +107,6 @@ def get_settings_manager(self): :return: `SettingsManager` """ - def load_backend_environment(self, profile, validate_schema=True, **kwargs): - """Load the backend environment. - - :param profile: the profile whose backend environment to load - :param validate_schema: boolean, if True, validate the schema first before loading the environment. - :param kwargs: keyword arguments that will be passed on to the backend specific scoped session getter function. - """ - self._load_backend_environment(**kwargs) - - if validate_schema: - self.validate_schema(profile) - - @abc.abstractmethod - def _load_backend_environment(self, **kwargs): - """Load the backend environment. - - :param kwargs: keyword arguments that will be passed on to the backend specific scoped session getter function. - """ - - @abc.abstractmethod - def reset_backend_environment(self): - """Reset the backend environment.""" - def migrate(self): """Migrate the database to the latest schema generation or version.""" try: diff --git a/aiida/backends/sqlalchemy/__init__.py b/aiida/backends/sqlalchemy/__init__.py index 232346800d..45d35d9170 100644 --- a/aiida/backends/sqlalchemy/__init__.py +++ b/aiida/backends/sqlalchemy/__init__.py @@ -31,7 +31,7 @@ def reset_session(): SESSION_FACTORY = None -def get_scoped_session(**kwargs): +def get_scoped_session(profile=None, **kwargs): """Return a scoped session According to SQLAlchemy docs, this returns always the same object within a thread, and a different object in a @@ -43,6 +43,8 @@ def get_scoped_session(**kwargs): more info. 
""" from aiida.manage.configuration import get_profile + if profile is None: + profile = get_profile() global ENGINE global SESSION_FACTORY @@ -52,7 +54,7 @@ def get_scoped_session(**kwargs): return session if ENGINE is None: - ENGINE = create_sqlalchemy_engine(get_profile(), **kwargs) + ENGINE = create_sqlalchemy_engine(profile, **kwargs) SESSION_FACTORY = create_scoped_session_factory(ENGINE, expire_on_commit=True) diff --git a/aiida/backends/sqlalchemy/manager.py b/aiida/backends/sqlalchemy/manager.py index c8ded0f617..5299858b0b 100644 --- a/aiida/backends/sqlalchemy/manager.py +++ b/aiida/backends/sqlalchemy/manager.py @@ -81,19 +81,6 @@ def get_settings_manager(self): return self._settings_manager - def _load_backend_environment(self, **kwargs): - """Load the backend environment. - - :param kwargs: keyword arguments that will be passed on to - :py:func:`aiida.backends.sqlalchemy.get_scoped_session`. - """ - get_scoped_session(**kwargs) - - def reset_backend_environment(self): - """Reset the backend environment.""" - from . import reset_session - reset_session() - def is_database_schema_ahead(self): """Determine whether the database schema version is ahead of the code schema version. diff --git a/aiida/cmdline/commands/cmd_database.py b/aiida/cmdline/commands/cmd_database.py index 4752ad0964..d6955e298e 100644 --- a/aiida/cmdline/commands/cmd_database.py +++ b/aiida/cmdline/commands/cmd_database.py @@ -33,12 +33,12 @@ def database_version(): manager = get_manager() manager._load_backend(schema_check=False) # pylint: disable=protected-access - backend_manager = manager.get_backend_manager() + backend = manager.get_backend() echo.echo('Generation: ', bold=True, nl=False) - echo.echo(backend_manager.get_schema_generation_database()) + echo.echo(backend.get_schema_generation_database()) echo.echo('Revision: ', bold=True, nl=False) - echo.echo(backend_manager.get_schema_version_database()) + echo.echo(backend.get_schema_version_database()) @verdi_database.command('migrate') diff --git a/aiida/cmdline/commands/cmd_setup.py b/aiida/cmdline/commands/cmd_setup.py index 28f52ef8ad..8034084403 100644 --- a/aiida/cmdline/commands/cmd_setup.py +++ b/aiida/cmdline/commands/cmd_setup.py @@ -92,8 +92,7 @@ def setup( # Retrieve the repository UUID from the database. If set, this means this database is associated with the repository # with that UUID and we have to make sure that the provided repository corresponds to it. 
- backend_manager = manager.get_backend_manager() - repository_uuid_database = backend_manager.get_repository_uuid() + repository_uuid_database = backend.get_repository_uuid() repository_uuid_profile = profile.get_repository().uuid if repository_uuid_database != repository_uuid_profile: diff --git a/aiida/engine/utils.py b/aiida/engine/utils.py index f57f7500eb..e815d99ed5 100644 --- a/aiida/engine/utils.py +++ b/aiida/engine/utils.py @@ -275,7 +275,7 @@ def set_process_state_change_timestamp(process: 'Process') -> None: try: manager = get_manager() - manager.get_backend_manager().get_settings_manager().set(key, value, description) + manager.get_backend().set_value(key, value, description) except UniquenessError as exception: process.logger.debug(f'could not update the {key} setting because of a UniquenessError: {exception}') @@ -294,7 +294,7 @@ def get_process_state_change_timestamp(process_type: Optional[str] = None) -> Op from aiida.common.exceptions import NotExistent from aiida.manage.manager import get_manager # pylint: disable=cyclic-import - manager = get_manager().get_backend_manager().get_settings_manager() + backend = get_manager().get_backend() valid_process_types = ['calculation', 'work'] if process_type is not None and process_type not in valid_process_types: @@ -310,7 +310,7 @@ def get_process_state_change_timestamp(process_type: Optional[str] = None) -> Op for process_type_key in process_types: key = PROCESS_STATE_CHANGE_KEY.format(process_type_key) try: - time_stamp = timezone.isoformat_to_datetime(manager.get(key).value) + time_stamp = timezone.isoformat_to_datetime(backend.get_value(key).value) if time_stamp is not None: timestamps.append(time_stamp) except NotExistent: diff --git a/aiida/manage/manager.py b/aiida/manage/manager.py index 51a9cd79f7..e526457435 100644 --- a/aiida/manage/manager.py +++ b/aiida/manage/manager.py @@ -17,7 +17,6 @@ from kiwipy.rmq import RmqThreadCommunicator from plumpy.process_comms import RemoteProcessThreadController - from aiida.backends.manager import BackendManager from aiida.engine.daemon.client import DaemonClient from aiida.engine.persistence import AiiDAPersister from aiida.engine.runners import Runner @@ -46,7 +45,6 @@ class Manager: def __init__(self) -> None: self._backend: Optional['Backend'] = None - self._backend_manager: Optional['BackendManager'] = None self._config: Optional['Config'] = None self._daemon_client: Optional['DaemonClient'] = None self._profile: Optional['Profile'] = None @@ -63,7 +61,6 @@ def close(self) -> None: self._runner.stop() self._backend = None - self._backend_manager = None self._config = None self._profile = None self._communicator = None @@ -95,8 +92,8 @@ def get_profile() -> Optional['Profile']: def unload_backend(self) -> None: """Unload the current backend and its corresponding database environment.""" - manager = self.get_backend_manager() - manager.reset_backend_environment() + backend = self.get_backend() + backend.reset_environment() self._backend = None def _load_backend(self, schema_check: bool = True, repository_check: bool = True) -> 'Backend': @@ -110,7 +107,7 @@ def _load_backend(self, schema_check: bool = True, repository_check: bool = True for the current profile. :return: the database backend. 
""" - from aiida.backends import BACKEND_DJANGO, BACKEND_SQLA, get_backend_manager + from aiida.backends import BACKEND_DJANGO, BACKEND_SQLA from aiida.common import ConfigurationError, InvalidOperation from aiida.common.log import configure_logging from aiida.manage import configuration @@ -125,19 +122,29 @@ def _load_backend(self, schema_check: bool = True, repository_check: bool = True if configuration.BACKEND_UUID is not None and configuration.BACKEND_UUID != profile.uuid: raise InvalidOperation('cannot load backend because backend of another profile is already loaded') - backend_manager = get_backend_manager(profile.database_backend) + backend_type = profile.database_backend + + # For django only, the setup module must be loaded before the class can be instantiated, + if backend_type == BACKEND_DJANGO: + from aiida.orm.implementation.django.backend import DjangoBackend + backend_cls = DjangoBackend + elif backend_type == BACKEND_SQLA: + from aiida.orm.implementation.sqlalchemy.backend import SqlaBackend + backend_cls = SqlaBackend # Do NOT reload the backend environment if already loaded, simply reload the backend instance after if configuration.BACKEND_UUID is None: - backend_manager.load_backend_environment(profile, validate_schema=schema_check) + backend_cls.load_environment(profile, validate_schema=schema_check) configuration.BACKEND_UUID = profile.uuid + backend = backend_cls() + # Perform the check on the repository compatibility. Since this is new functionality and the stability is not # yet known, we issue a warning in the case the repo and database are incompatible. In the future this might # then become an exception once we have verified that it is working reliably. if repository_check and not profile.is_test_profile: repository_uuid_config = profile.get_repository().uuid - repository_uuid_database = backend_manager.get_repository_uuid() + repository_uuid_database = backend.get_repository_uuid() from aiida.cmdline.utils import echo if repository_uuid_config != repository_uuid_database: @@ -149,15 +156,7 @@ def _load_backend(self, schema_check: bool = True, repository_check: bool = True 'Please make sure that the configuration of your profile is correct.\n' ) - backend_type = profile.database_backend - - # Can only import the backend classes after the backend has been loaded - if backend_type == BACKEND_DJANGO: - from aiida.orm.implementation.django.backend import DjangoBackend - self._backend = DjangoBackend() - elif backend_type == BACKEND_SQLA: - from aiida.orm.implementation.sqlalchemy.backend import SqlaBackend - self._backend = SqlaBackend() + self._backend = backend # Reconfigure the logging with `with_orm=True` to make sure that profile specific logging configuration options # are taken into account and the `DbLogHandler` is configured. @@ -173,32 +172,6 @@ def backend_loaded(self) -> bool: """ return self._backend is not None - def get_backend_manager(self) -> 'BackendManager': - """Return the database backend manager. - - .. note:: this is not the actual backend, but a manager class that is necessary for database operations that - go around the actual ORM. For example when the schema version has not yet been validated. 
- - :return: the database backend manager - - """ - from aiida.backends import get_backend_manager - from aiida.common import ConfigurationError - - if self._backend_manager is None: - - if self._backend is None: - self._load_backend() - - profile = self.get_profile() - if profile is None: - raise ConfigurationError( - 'Could not determine the current profile. Consider loading a profile using `aiida.load_profile()`.' - ) - self._backend_manager = get_backend_manager(profile.database_backend) - - return self._backend_manager - def get_backend(self) -> 'Backend': """Return the database backend diff --git a/aiida/orm/implementation/django/__init__.py b/aiida/orm/implementation/django/__init__.py index 5089f32237..86aa47b6c6 100644 --- a/aiida/orm/implementation/django/__init__.py +++ b/aiida/orm/implementation/django/__init__.py @@ -15,17 +15,9 @@ # pylint: disable=wildcard-import from .backend import * -from .convert import * -from .groups import * -from .users import * __all__ = ( 'DjangoBackend', - 'DjangoGroup', - 'DjangoGroupCollection', - 'DjangoUser', - 'DjangoUserCollection', - 'get_backend_entity', ) # yapf: enable diff --git a/aiida/orm/implementation/django/backend.py b/aiida/orm/implementation/django/backend.py index b6056bda35..06c38c2867 100644 --- a/aiida/orm/implementation/django/backend.py +++ b/aiida/orm/implementation/django/backend.py @@ -9,13 +9,15 @@ ########################################################################### """Django implementation of `aiida.orm.implementation.backends.Backend`.""" from contextlib import contextmanager +import os # pylint: disable=import-error,no-name-in-module +import django from django.db import models, transaction +from aiida.backends.djsite import get_scoped_session, reset_session from aiida.backends.djsite.manager import DjangoBackendManager -from . import authinfos, comments, computers, convert, groups, logs, nodes, querybuilder, users from ..sql.backends import SqlBackend __all__ = ('DjangoBackend',) @@ -26,6 +28,8 @@ class DjangoBackend(SqlBackend[models.Model]): def __init__(self): """Construct the backend instance by initializing all the collections.""" + from . import authinfos, comments, computers, groups, logs, nodes, users + super().__init__() self._authinfos = authinfos.DjangoAuthInfoCollection(self) self._comments = comments.DjangoCommentCollection(self) self._computers = computers.DjangoComputerCollection(self) @@ -35,6 +39,18 @@ def __init__(self): self._backend_manager = DjangoBackendManager() self._users = users.DjangoUserCollection(self) + @classmethod + def load_environment(cls, profile, validate_schema=True, **kwargs): + os.environ['DJANGO_SETTINGS_MODULE'] = 'aiida.backends.djsite.settings' + django.setup() # pylint: disable=no-member + # For QueryBuilder only + get_scoped_session(profile, **kwargs) + if validate_schema: + DjangoBackendManager().validate_schema(profile) + + def reset_environment(self): + reset_session() + def migrate(self): self._backend_manager.migrate() @@ -63,6 +79,7 @@ def nodes(self): return self._nodes def query(self): + from . 
import querybuilder return querybuilder.DjangoQueryBuilder(self) @property @@ -83,13 +100,13 @@ def get_session(): :return: an instance of :class:`sqlalchemy.orm.session.Session` """ - from aiida.backends.djsite import get_scoped_session return get_scoped_session() # Below are abstract methods inherited from `aiida.orm.implementation.sql.backends.SqlBackend` def get_backend_entity(self, model): """Return a `BackendEntity` instance from a `DbModel` instance.""" + from . import convert return convert.get_backend_entity(model, self) @contextmanager diff --git a/aiida/orm/implementation/django/convert.py b/aiida/orm/implementation/django/convert.py index 0bfb836ee4..d9862f1a20 100644 --- a/aiida/orm/implementation/django/convert.py +++ b/aiida/orm/implementation/django/convert.py @@ -17,8 +17,6 @@ import aiida.backends.djsite.db.models as djmodels -__all__ = ('get_backend_entity',) - @singledispatch def get_backend_entity(dbmodel, backend): # pylint: disable=unused-argument diff --git a/aiida/orm/implementation/django/groups.py b/aiida/orm/implementation/django/groups.py index 7c900f410e..50082ea707 100644 --- a/aiida/orm/implementation/django/groups.py +++ b/aiida/orm/implementation/django/groups.py @@ -21,8 +21,6 @@ from . import entities, users, utils -__all__ = ('DjangoGroup', 'DjangoGroupCollection') - class DjangoGroup(entities.DjangoModelEntity[models.DbGroup], BackendGroup): # pylint: disable=abstract-method """The Django group object""" diff --git a/aiida/orm/implementation/django/users.py b/aiida/orm/implementation/django/users.py index bbf7f37183..cddf3bc10e 100644 --- a/aiida/orm/implementation/django/users.py +++ b/aiida/orm/implementation/django/users.py @@ -17,8 +17,6 @@ from . import entities, utils -__all__ = ('DjangoUser', 'DjangoUserCollection') - class DjangoUser(entities.DjangoModelEntity[models.DbUser], BackendUser): """The Django user class""" diff --git a/aiida/orm/implementation/sql/backends.py b/aiida/orm/implementation/sql/backends.py index 2bb21f22af..1ab0711aa7 100644 --- a/aiida/orm/implementation/sql/backends.py +++ b/aiida/orm/implementation/sql/backends.py @@ -29,6 +29,11 @@ class SqlBackend(typing.Generic[ModelType], backends.Backend): if any of these assumptions do not fit then just implement a backend from :class:`aiida.orm.implementation.Backend` """ + def __init__(self) -> None: + from aiida.backends.manager import BackendManager + super().__init__() + self._backend_manager = BackendManager() + @abc.abstractmethod def get_backend_entity(self, model): """ @@ -74,3 +79,55 @@ def execute_prepared_statement(self, sql, parameters): results.append(row) return results + + @classmethod + @abc.abstractmethod + def load_environment(cls, profile, validate_schema=True, **kwargs): + """Load the backend environment. + + :param profile: the profile whose backend environment to load + :param validate_schema: boolean, if True, validate the schema after loading the environment. + :param kwargs: keyword arguments that will be passed on to the backend specific scoped session getter function. + """ + + @abc.abstractmethod + def reset_environment(self): + """Reset the backend environment.""" + + def get_repository_uuid(self): + """Return the UUID of the repository that is associated with this database. + + :return: the UUID of the repository associated with this database or None if it doesn't exist. + """ + return self._backend_manager.get_repository_uuid() + + def get_schema_generation_database(self): + """Return the database schema version. 
+ + :return: `distutils.version.LooseVersion` with schema version of the database + """ + return self._backend_manager.get_schema_generation_database() + + def get_schema_version_database(self): + """Return the database schema version. + + :return: `distutils.version.LooseVersion` with schema version of the database + """ + return self._backend_manager.get_schema_version_database() + + def set_value(self, key: str, value: typing.Any, description: typing.Optional[str] = None) -> None: + """Set a global key/value pair on the profile backend. + + :param key: the key identifying the setting + :param value: the value for the setting + :param description: optional setting description + """ + return self._backend_manager.get_settings_manager().set(key, value, description) + + def get_value(self, key: str) -> typing.Any: + """Get a global key/value pair on the profile backend. + + :param key: the key identifying the setting + :raises: `~aiida.common.exceptions.NotExistent` if the settings does not exist + """ + return self._backend_manager.get_settings_manager().get(key) diff --git a/aiida/orm/implementation/sqlalchemy/backend.py b/aiida/orm/implementation/sqlalchemy/backend.py index 64a7109bf9..516e4c2c34 100644 --- a/aiida/orm/implementation/sqlalchemy/backend.py +++ b/aiida/orm/implementation/sqlalchemy/backend.py @@ -10,6 +10,7 @@ """SqlAlchemy implementation of `aiida.orm.implementation.backends.Backend`.""" from contextlib import contextmanager +from aiida.backends.sqlalchemy import get_scoped_session, reset_session from aiida.backends.sqlalchemy.manager import SqlaBackendManager from aiida.backends.sqlalchemy.models import base @@ -24,17 +25,27 @@ class SqlaBackend(SqlBackend[base.Base]): def __init__(self): """Construct the backend instance by initializing all the collections.""" + super().__init__() self._authinfos = authinfos.SqlaAuthInfoCollection(self) self._comments = comments.SqlaCommentCollection(self) self._computers = computers.SqlaComputerCollection(self) self._groups = groups.SqlaGroupCollection(self) self._logs = logs.SqlaLogCollection(self) self._nodes = nodes.SqlaNodeCollection(self) - self._schema_manager = SqlaBackendManager() + self._backend_manager = SqlaBackendManager() self._users = users.SqlaUserCollection(self) + @classmethod + def load_environment(cls, profile, validate_schema=True, **kwargs): + get_scoped_session(profile, **kwargs) + if validate_schema: + SqlaBackendManager().validate_schema(profile) + + def reset_environment(self): # pylint: disable=no-self-use + reset_session() + def migrate(self): - self._schema_manager.migrate() + self._backend_manager.migrate() @property def authinfos(self): @@ -89,7 +100,6 @@ def get_session(): :return: an instance of :class:`sqlalchemy.orm.session.Session` """ - from aiida.backends.sqlalchemy import get_scoped_session return get_scoped_session() # Below are abstract methods inherited from `aiida.orm.implementation.sql.backends.SqlBackend` diff --git a/tests/cmdline/commands/test_database.py b/tests/cmdline/commands/test_database.py index 7a70e1c302..8d08aaeb4b 100644 --- a/tests/cmdline/commands/test_database.py +++ b/tests/cmdline/commands/test_database.py @@ -174,10 +174,10 @@ def test_detect_invalid_nodes_unknown_node_type(self): @pytest.mark.usefixtures('aiida_profile') def tests_database_version(run_cli_command, manager): """Test the ``verdi database version`` command.""" - backend_manager = manager.get_backend_manager() + backend = manager.get_backend() result = run_cli_command(cmd_database.database_version) - assert 
result.output_lines[0].endswith(backend_manager.get_schema_generation_database()) - assert result.output_lines[1].endswith(backend_manager.get_schema_version_database()) + assert result.output_lines[0].endswith(backend.get_schema_generation_database()) + assert result.output_lines[1].endswith(backend.get_schema_version_database()) @pytest.mark.usefixtures('clear_database_before_test') diff --git a/tests/cmdline/commands/test_setup.py b/tests/cmdline/commands/test_setup.py index d23853c815..3e318ecbc4 100644 --- a/tests/cmdline/commands/test_setup.py +++ b/tests/cmdline/commands/test_setup.py @@ -79,8 +79,8 @@ def test_quicksetup(self): # Check that the repository UUID was stored in the database manager = get_manager() - backend_manager = manager.get_backend_manager() - self.assertEqual(backend_manager.get_repository_uuid(), profile.get_repository().uuid) + backend = manager.get_backend() + self.assertEqual(backend.get_repository_uuid(), profile.get_repository().uuid) def test_quicksetup_from_config_file(self): """Test `verdi quicksetup` from configuration file.""" @@ -166,5 +166,5 @@ def test_setup(self): # Check that the repository UUID was stored in the database manager = get_manager() - backend_manager = manager.get_backend_manager() - self.assertEqual(backend_manager.get_repository_uuid(), profile.get_repository().uuid) + backend = manager.get_backend() + self.assertEqual(backend.get_repository_uuid(), profile.get_repository().uuid) diff --git a/tests/restapi/conftest.py b/tests/restapi/conftest.py index 2ee8628764..7ff29333d5 100644 --- a/tests/restapi/conftest.py +++ b/tests/restapi/conftest.py @@ -52,9 +52,9 @@ def restrict_sqlalchemy_queuepool(aiida_profile): """Create special SQLAlchemy engine for use with QueryBuilder - backend-agnostic""" from aiida.manage.manager import get_manager - backend_manager = get_manager().get_backend_manager() - backend_manager.reset_backend_environment() - backend_manager.load_backend_environment(aiida_profile, pool_timeout=1, max_overflow=0) + backend = get_manager().get_backend() + backend.reset_environment() + backend.load_environment(aiida_profile, pool_timeout=1, max_overflow=0) @pytest.fixture From 61bc86d6a4b9742bc22051ffa74a79ed919e60b2 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Sun, 10 Oct 2021 22:13:52 +0200 Subject: [PATCH 02/16] fix test --- tests/restapi/conftest.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/restapi/conftest.py b/tests/restapi/conftest.py index 7ff29333d5..7f3398f786 100644 --- a/tests/restapi/conftest.py +++ b/tests/restapi/conftest.py @@ -52,9 +52,10 @@ def restrict_sqlalchemy_queuepool(aiida_profile): """Create special SQLAlchemy engine for use with QueryBuilder - backend-agnostic""" from aiida.manage.manager import get_manager + manager = get_manager() backend = get_manager().get_backend() backend.reset_environment() - backend.load_environment(aiida_profile, pool_timeout=1, max_overflow=0) + backend.load_environment(manager.get_profile(), pool_timeout=1, max_overflow=0) @pytest.fixture From 49c250ffd84a55b619aeea3beb03b7ca8a57153f Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Sun, 10 Oct 2021 22:39:22 +0200 Subject: [PATCH 03/16] fix pylint --- tests/restapi/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/restapi/conftest.py b/tests/restapi/conftest.py index 7f3398f786..de991b2abe 100644 --- a/tests/restapi/conftest.py +++ b/tests/restapi/conftest.py @@ -48,7 +48,7 @@ def server_url(): @pytest.fixture -def 
restrict_sqlalchemy_queuepool(aiida_profile): +def restrict_sqlalchemy_queuepool(aiida_profile): # pylint: disable=unused-argument """Create special SQLAlchemy engine for use with QueryBuilder - backend-agnostic""" from aiida.manage.manager import get_manager From fbfe1f7080739cef6db52fe58c1af88550d50a61 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Thu, 14 Oct 2021 08:35:01 +0200 Subject: [PATCH 04/16] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20REFACTOR:=20Remove?= =?UTF-8?q?=20global=20variable=20for=20sqlalchemy=20session=20access?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- aiida/backends/djsite/__init__.py | 58 --------- aiida/backends/djsite/manage.py | 2 +- aiida/backends/djsite/manager.py | 22 +--- aiida/backends/manager.py | 4 + aiida/backends/sqlalchemy/__init__.py | 50 -------- aiida/backends/sqlalchemy/manage.py | 11 +- aiida/backends/sqlalchemy/manager.py | 21 ++-- aiida/backends/sqlalchemy/models/base.py | 82 +----------- aiida/backends/sqlalchemy/models/settings.py | 21 ++-- aiida/backends/sqlalchemy/utils.py | 10 +- aiida/backends/testbase.py | 16 ++- aiida/cmdline/commands/cmd_database.py | 6 +- aiida/cmdline/commands/cmd_setup.py | 2 +- aiida/common/lang.py | 9 +- aiida/manage/configuration/__init__.py | 2 +- aiida/manage/manager.py | 37 +++--- aiida/manage/tests/main.py | 7 +- aiida/orm/implementation/backends.py | 33 ++++- aiida/orm/implementation/django/backend.py | 53 ++++---- aiida/orm/implementation/django/comments.py | 2 +- aiida/orm/implementation/django/computers.py | 2 +- aiida/orm/implementation/django/entities.py | 2 +- aiida/orm/implementation/django/groups.py | 4 +- aiida/orm/implementation/django/logs.py | 2 +- aiida/orm/implementation/groups.py | 7 +- aiida/orm/implementation/sql/backends.py | 83 +++++++----- .../implementation/sqlalchemy/authinfos.py | 29 ++--- .../orm/implementation/sqlalchemy/backend.py | 41 +++--- .../orm/implementation/sqlalchemy/comments.py | 32 ++--- .../implementation/sqlalchemy/computers.py | 19 +-- .../orm/implementation/sqlalchemy/entities.py | 4 +- aiida/orm/implementation/sqlalchemy/groups.py | 36 +++--- aiida/orm/implementation/sqlalchemy/logs.py | 33 ++--- aiida/orm/implementation/sqlalchemy/nodes.py | 30 ++--- .../sqlalchemy/querybuilder/joiner.py | 18 +-- aiida/orm/implementation/sqlalchemy/users.py | 4 +- aiida/orm/implementation/sqlalchemy/utils.py | 107 +++++++++------- .../importexport/dbimport/backends/sqla.py | 4 +- .../migrations/test_migrations_common.py | 4 +- .../aiida_sqlalchemy/test_migrations.py | 119 +++++++++--------- tests/backends/aiida_sqlalchemy/test_nodes.py | 23 ++-- .../backends/aiida_sqlalchemy/test_schema.py | 14 ++- .../backends/aiida_sqlalchemy/test_session.py | 26 ++-- tests/orm/implementation/test_backend.py | 21 +++- tests/orm/implementation/test_comments.py | 12 +- tests/orm/implementation/test_logs.py | 12 +- tests/orm/implementation/test_nodes.py | 11 +- tests/orm/test_authinfos.py | 1 - tests/orm/test_groups.py | 2 + tests/restapi/conftest.py | 5 +- 50 files changed, 493 insertions(+), 662 deletions(-) diff --git a/aiida/backends/djsite/__init__.py b/aiida/backends/djsite/__init__.py index e1dd0e202d..b9ca8f4299 100644 --- a/aiida/backends/djsite/__init__.py +++ b/aiida/backends/djsite/__init__.py @@ -9,61 +9,3 @@ ########################################################################### # pylint: disable=global-statement """Module with implementation of the database backend using Django.""" -from aiida.backends.utils import 
create_scoped_session_factory, create_sqlalchemy_engine - -ENGINE = None -SESSION_FACTORY = None - - -def reset_session(): - """Reset the session which means setting the global engine and session factory instances to `None`.""" - global ENGINE - global SESSION_FACTORY - - if ENGINE is not None: - ENGINE.dispose() - - if SESSION_FACTORY is not None: - SESSION_FACTORY.expunge_all() # pylint: disable=no-member - SESSION_FACTORY.close() # pylint: disable=no-member - - ENGINE = None - SESSION_FACTORY = None - - -def get_scoped_session(profile=None, **kwargs): - """Return a scoped session for the given profile that is exclusively to be used for the `QueryBuilder`. - - Since the `QueryBuilder` implementation uses SqlAlchemy to map the query onto the models in order to generate the - SQL to be sent to the database, it requires a session, which is an :class:`sqlalchemy.orm.session.Session` instance. - The only purpose is for SqlAlchemy to be able to connect to the database perform the query and retrieve the results. - Even the Django backend implementation will use SqlAlchemy for its `QueryBuilder` and so also needs an SqlA session. - It is important that we do not reuse the scoped session factory in the SqlAlchemy implementation, because that runs - the risk of cross-talk once profiles can be switched dynamically in a single python interpreter. Therefore the - Django implementation of the `QueryBuilder` should keep its own SqlAlchemy engine and scoped session factory - instances that are used to provide the query builder with a session. - - :param kwargs: keyword arguments that will be passed on to :py:func:`aiida.backends.utils.create_sqlalchemy_engine`, - opening the possibility to change QueuePool time outs and more. - See https://docs.sqlalchemy.org/en/13/core/engines.html?highlight=create_engine#sqlalchemy.create_engine for - more info. - - :return: :class:`sqlalchemy.orm.session.Session` instance with engine configured for the given profile. - """ - from aiida.manage.configuration import get_profile - if profile is None: - profile = get_profile() - - global ENGINE - global SESSION_FACTORY - - if SESSION_FACTORY is not None: - session = SESSION_FACTORY() - return session - - if ENGINE is None: - ENGINE = create_sqlalchemy_engine(profile, **kwargs) - - SESSION_FACTORY = create_scoped_session_factory(ENGINE) - - return SESSION_FACTORY() diff --git a/aiida/backends/djsite/manage.py b/aiida/backends/djsite/manage.py index ec7732002a..526fcc76c3 100755 --- a/aiida/backends/djsite/manage.py +++ b/aiida/backends/djsite/manage.py @@ -24,7 +24,7 @@ def main(profile, command): # pylint: disable=unused-argument from aiida.manage.manager import get_manager manager = get_manager() - manager._load_backend(schema_check=False) # pylint: disable=protected-access + manager._load_backend(validate_db=False) # pylint: disable=protected-access # The `execute_from_command` expects a list of command line arguments where the first is the program name that one # would normally call directly. Since this is now replaced by our `click` command we just spoof a random name. 
diff --git a/aiida/backends/djsite/manager.py b/aiida/backends/djsite/manager.py index 2cfd76efed..e71c34f917 100644 --- a/aiida/backends/djsite/manager.py +++ b/aiida/backends/djsite/manager.py @@ -61,12 +61,8 @@ def get_schema_generation_database(self): """ from django.db.utils import ProgrammingError - from aiida.manage.manager import get_manager - - backend = get_manager()._load_backend(schema_check=False, repository_check=False) # pylint: disable=protected-access - try: - result = backend.execute_raw(r"""SELECT val FROM db_dbsetting WHERE key = 'schema_generation';""") + result = self._backend.execute_raw(r"""SELECT val FROM db_dbsetting WHERE key = 'schema_generation';""") except ProgrammingError: # If this value does not exist, the schema has to correspond to the first generation which didn't actually # record its value explicitly in the database until ``aiida-core>=1.0.0``. @@ -84,14 +80,10 @@ def get_schema_version_database(self): """ from django.db.utils import ProgrammingError - from aiida.manage.manager import get_manager - - backend = get_manager()._load_backend(schema_check=False, repository_check=False) # pylint: disable=protected-access - try: - result = backend.execute_raw(r"""SELECT val FROM db_dbsetting WHERE key = 'db|schemaversion';""") + result = self._backend.execute_raw(r"""SELECT val FROM db_dbsetting WHERE key = 'db|schemaversion';""") except ProgrammingError: - result = backend.execute_raw(r"""SELECT tval FROM db_dbsetting WHERE key = 'db|schemaversion';""") + result = self._backend.execute_raw(r"""SELECT tval FROM db_dbsetting WHERE key = 'db|schemaversion';""") return result[0][0] def set_schema_version_database(self, version): @@ -107,13 +99,9 @@ def _migrate_database_generation(self): For Django we also have to clear the `django_migrations` table that contains a history of all applied migrations. After clearing it, we reinsert the name of the new initial schema . """ - # pylint: disable=cyclic-import - from aiida.manage.manager import get_manager super()._migrate_database_generation() - - backend = get_manager()._load_backend(schema_check=False, repository_check=False) # pylint: disable=protected-access - backend.execute_raw(r"""DELETE FROM django_migrations WHERE app = 'db';""") - backend.execute_raw( + self._backend.execute_raw(r"""DELETE FROM django_migrations WHERE app = 'db';""") + self._backend.execute_raw( r"""INSERT INTO django_migrations (app, name, applied) VALUES ('db', '0001_initial', NOW());""" ) diff --git a/aiida/backends/manager.py b/aiida/backends/manager.py index 082b1d8f0d..c64ad9d6b5 100644 --- a/aiida/backends/manager.py +++ b/aiida/backends/manager.py @@ -100,6 +100,10 @@ class BackendManager: _settings_manager = None + def __init__(self, backend) -> None: + from aiida.backends.sqlalchemy.manager import SqlaBackendManager + self._backend: SqlaBackendManager = backend + @abc.abstractmethod def get_settings_manager(self): """Return an instance of the `SettingsManager`. 
diff --git a/aiida/backends/sqlalchemy/__init__.py b/aiida/backends/sqlalchemy/__init__.py index 45d35d9170..d27c5f9dde 100644 --- a/aiida/backends/sqlalchemy/__init__.py +++ b/aiida/backends/sqlalchemy/__init__.py @@ -9,53 +9,3 @@ ########################################################################### # pylint: disable=global-statement """Module with implementation of the database backend using SqlAlchemy.""" -from aiida.backends.utils import create_scoped_session_factory, create_sqlalchemy_engine - -ENGINE = None -SESSION_FACTORY = None - - -def reset_session(): - """Reset the session which means setting the global engine and session factory instances to `None`.""" - global ENGINE - global SESSION_FACTORY - - if ENGINE is not None: - ENGINE.dispose() - - if SESSION_FACTORY is not None: - SESSION_FACTORY.expunge_all() # pylint: disable=no-member - SESSION_FACTORY.close() # pylint: disable=no-member - - ENGINE = None - SESSION_FACTORY = None - - -def get_scoped_session(profile=None, **kwargs): - """Return a scoped session - - According to SQLAlchemy docs, this returns always the same object within a thread, and a different object in a - different thread. Moreover, since we update the session class upon forking, different session objects will be used. - - :param kwargs: keyword argument that will be passed on to :py:func:`aiida.backends.utils.create_sqlalchemy_engine`, - opening the possibility to change QueuePool time outs and more. - See https://docs.sqlalchemy.org/en/13/core/engines.html?highlight=create_engine#sqlalchemy.create_engine for - more info. - """ - from aiida.manage.configuration import get_profile - if profile is None: - profile = get_profile() - - global ENGINE - global SESSION_FACTORY - - if SESSION_FACTORY is not None: - session = SESSION_FACTORY() - return session - - if ENGINE is None: - ENGINE = create_sqlalchemy_engine(profile, **kwargs) - - SESSION_FACTORY = create_scoped_session_factory(ENGINE, expire_on_commit=True) - - return SESSION_FACTORY() diff --git a/aiida/backends/sqlalchemy/manage.py b/aiida/backends/sqlalchemy/manage.py index 1538a1b9e1..74c30c3456 100755 --- a/aiida/backends/sqlalchemy/manage.py +++ b/aiida/backends/sqlalchemy/manage.py @@ -22,11 +22,12 @@ def execute_alembic_command(command_name, **kwargs): :param command_name: the sub command name :param kwargs: parameters to pass to the command """ - from aiida.backends.sqlalchemy.manager import SqlaBackendManager + from aiida.manage.configuration import get_profile + from aiida.orm.implementation.sqlalchemy.backend import SqlaBackend - manager = SqlaBackendManager() - - with manager.alembic_config() as config: + # create a new backend which does not validate the schema version + backend = SqlaBackend(get_profile(), validate_db=False) + with backend._backend_manager.alembic_config() as config: # pylint: disable=protected-access command = getattr(alembic.command, command_name) command(config, **kwargs) @@ -40,7 +41,7 @@ def alembic_cli(profile): load_profile(profile=profile.name) manager = get_manager() - manager._load_backend(schema_check=False) # pylint: disable=protected-access + manager._load_backend(validate_db=False) # pylint: disable=protected-access @alembic_cli.command('revision') diff --git a/aiida/backends/sqlalchemy/manager.py b/aiida/backends/sqlalchemy/manager.py index 5299858b0b..972fa4ec6d 100644 --- a/aiida/backends/sqlalchemy/manager.py +++ b/aiida/backends/sqlalchemy/manager.py @@ -14,7 +14,6 @@ import sqlalchemy from sqlalchemy.orm.exc import NoResultFound -from 
aiida.backends.sqlalchemy import get_scoped_session from aiida.common import NotExistent from ..manager import BackendManager, Setting, SettingsManager @@ -28,9 +27,8 @@ class SqlaBackendManager(BackendManager): """Class to manage the database schema.""" - @staticmethod @contextlib.contextmanager - def alembic_config(): + def alembic_config(self): """Context manager to return an instance of an Alembic configuration. The current database connection is added in the `attributes` property, through which it can then also be @@ -38,9 +36,7 @@ def alembic_config(): """ from alembic.config import Config - from . import ENGINE - - with ENGINE.begin() as connection: + with self._backend.get_session().bind.begin() as connection: dir_path = os.path.dirname(os.path.realpath(__file__)) config = Config() config.set_main_option('script_location', os.path.join(dir_path, ALEMBIC_REL_PATH)) @@ -77,7 +73,7 @@ def get_settings_manager(self): :return: `SettingsManager` """ if self._settings_manager is None: - self._settings_manager = SqlaSettingsManager() + self._settings_manager = SqlaSettingsManager(self._backend) return self._settings_manager @@ -134,12 +130,15 @@ class SqlaSettingsManager(SettingsManager): table_name = 'db_dbsetting' + def __init__(self, backend) -> None: + self._backend = backend + def validate_table_existence(self): """Verify that the `DbSetting` table actually exists. :raises: `~aiida.common.exceptions.NotExistent` if the settings table does not exist """ - inspector = sqlalchemy.inspect(get_scoped_session().bind) + inspector = sqlalchemy.inspect(self._backend.get_session().bind) if self.table_name not in inspector.get_table_names(): raise NotExistent('the settings table does not exist') @@ -154,7 +153,7 @@ def get(self, key): self.validate_table_existence() try: - setting = get_scoped_session().query(DbSetting).filter_by(key=key).one() + setting = self._backend.get_session().query(DbSetting).filter_by(key=key).one() except NoResultFound: raise NotExistent(f'setting `{key}` does not exist') from NoResultFound @@ -177,7 +176,7 @@ def set(self, key, value, description=None): if description is not None: other_attribs['description'] = description - DbSetting.set_value(key, value, other_attribs=other_attribs) + DbSetting.set_value(self._backend.get_session(), key, value, other_attribs=other_attribs) def delete(self, key): """Delete the setting with the given key. 
@@ -189,7 +188,7 @@ def delete(self, key): self.validate_table_existence() try: - setting = get_scoped_session().query(DbSetting).filter_by(key=key).one() + setting = self._backend.get_session().query(DbSetting).filter_by(key=key).one() setting.delete() except NoResultFound: raise NotExistent(f'setting `{key}` does not exist') from NoResultFound diff --git a/aiida/backends/sqlalchemy/models/base.py b/aiida/backends/sqlalchemy/models/base.py index dd7f6ab9ad..ebb599df60 100644 --- a/aiida/backends/sqlalchemy/models/base.py +++ b/aiida/backends/sqlalchemy/models/base.py @@ -9,86 +9,6 @@ ########################################################################### # pylint: disable=import-error,no-name-in-module """Base SQLAlchemy models.""" - -from sqlalchemy import orm from sqlalchemy.orm import declarative_base -from sqlalchemy.orm.exc import UnmappedClassError - -import aiida.backends.sqlalchemy -from aiida.backends.sqlalchemy import get_scoped_session -from aiida.common.exceptions import InvalidOperation - -# Taken from -# https://github.com/mitsuhiko/flask-sqlalchemy/blob/master/flask_sqlalchemy/__init__.py#L491 - - -class _QueryProperty: - """Query property.""" - - def __init__(self, query_class=orm.Query): - self.query_class = query_class - - def __get__(self, obj, _type): - """Get property of a query.""" - try: - mapper = orm.class_mapper(_type) - if mapper: - return self.query_class(mapper, session=aiida.backends.sqlalchemy.get_scoped_session()) - return None - except UnmappedClassError: - return None - - -class _SessionProperty: - """Session Property""" - - def __get__(self, obj, _type): - if not aiida.backends.sqlalchemy.get_scoped_session(): - raise InvalidOperation('You need to call load_dbenv before accessing the session of SQLALchemy.') - return aiida.backends.sqlalchemy.get_scoped_session() - - -class _AiidaQuery(orm.Query): - """AiiDA query.""" - - def __iter__(self): - """Iterator.""" - from aiida.orm.implementation.sqlalchemy import convert # pylint: disable=cyclic-import - - iterator = super().__iter__() - for result in iterator: - # Allow the use of with_entities - if issubclass(type(result), Model): - yield convert.get_backend_entity(result, None) - else: - yield result - - -class Model: - """Query model.""" - query = _QueryProperty() - - session = _SessionProperty() - - def save(self, commit=True): - """Emulate the behavior of Django's save() method - - :param commit: whether to do a commit or just add to the session - :return: the SQLAlchemy instance""" - sess = get_scoped_session() - sess.add(self) - if commit: - sess.commit() - return self - - def delete(self, commit=True): - """Emulate the behavior of Django's delete() method - - :param commit: whether to do a commit or just remover from the session""" - sess = get_scoped_session() - sess.delete(self) - if commit: - sess.commit() - -Base = declarative_base(cls=Model, name='Model') # pylint: disable=invalid-name +Base = declarative_base(name='Model') # pylint: disable=invalid-name diff --git a/aiida/backends/sqlalchemy/models/settings.py b/aiida/backends/sqlalchemy/models/settings.py index 71b3cf03a4..16460c0a70 100644 --- a/aiida/backends/sqlalchemy/models/settings.py +++ b/aiida/backends/sqlalchemy/models/settings.py @@ -16,7 +16,6 @@ from sqlalchemy.schema import UniqueConstraint from sqlalchemy.types import DateTime, Integer, String -from aiida.backends import sqlalchemy as sa from aiida.backends.sqlalchemy.models.base import Base from aiida.common import timezone @@ -38,10 +37,10 @@ def __str__(self): return 
f"'{self.key}'={self.getvalue()}" @classmethod - def set_value(cls, key, value, other_attribs=None, stop_if_existing=False): + def set_value(cls, session, key, value, other_attribs=None, stop_if_existing=False): """Set a setting value.""" other_attribs = other_attribs if other_attribs is not None else {} - setting = sa.get_scoped_session().query(DbSetting).filter_by(key=key).first() + setting = session.query(DbSetting).filter_by(key=key).first() if setting is not None: if stop_if_existing: return @@ -54,7 +53,12 @@ def set_value(cls, key, value, other_attribs=None, stop_if_existing=False): setting.time = timezone.datetime.now(tz=UTC) if 'description' in other_attribs.keys(): setting.description = other_attribs['description'] - setting.save() + session.add(setting) + try: + session.commit() + except Exception: # pylint: disable=broad-except + session.rollback() + raise def getvalue(self): """This can be called on a given row and will get the corresponding value.""" @@ -63,12 +67,3 @@ def getvalue(self): def get_description(self): """This can be called on a given row and will get the corresponding description.""" return self.description - - @classmethod - def del_value(cls, key): - """Delete a setting value.""" - setting = sa.get_scoped_session().query(DbSetting).filter(key=key) - setting.val = None - setting.time = timezone.datetime.utcnow() - flag_modified(setting, 'val') - setting.save() diff --git a/aiida/backends/sqlalchemy/utils.py b/aiida/backends/sqlalchemy/utils.py index a8d76265ef..6beb576fb3 100644 --- a/aiida/backends/sqlalchemy/utils.py +++ b/aiida/backends/sqlalchemy/utils.py @@ -40,18 +40,18 @@ def delete_nodes_and_connections_sqla(pks_to_delete): # pylint: disable=invalid def flag_modified(instance, key): - """Wrapper around `sqlalchemy.orm.attributes.flag_modified` to correctly dereference utils.ModelWrapper + """Wrapper around `sqlalchemy.orm.attributes.flag_modified` to correctly dereference a StorableModel Since SqlAlchemy 1.2.12 (and maybe earlier but not in 1.0.19) the flag_modified function will check that the - key is actually present in the instance or it will except. If we pass a model instance, wrapped in the ModelWrapper + key is actually present in the instance or it will except. If we pass a model instance, wrapped in the StorableModel the call will raise an InvalidRequestError. In this function that wraps the flag_modified of SqlAlchemy, we - derefence the model instance if the passed instance is actually wrapped in the ModelWrapper. + dereference the model instance if the passed instance is actually wrapped in the StorableModel. 
""" from sqlalchemy.orm.attributes import flag_modified as flag_modified_sqla - from aiida.orm.implementation.sqlalchemy.utils import ModelWrapper + from aiida.orm.implementation.sqlalchemy.utils import StorableModel - if isinstance(instance, ModelWrapper): + if isinstance(instance, StorableModel): instance = instance._model # pylint: disable=protected-access flag_modified_sqla(instance, key) diff --git a/aiida/backends/testbase.py b/aiida/backends/testbase.py index f6307df575..efb7ba4c09 100644 --- a/aiida/backends/testbase.py +++ b/aiida/backends/testbase.py @@ -10,6 +10,7 @@ """Basic test classes.""" import os import traceback +from typing import TYPE_CHECKING import unittest from aiida import orm @@ -18,6 +19,9 @@ from aiida.manage import configuration from aiida.manage.manager import get_manager, reset_manager +if TYPE_CHECKING: + from aiida.orm.implementation import Backend + TEST_KEYWORD = 'test_' @@ -36,7 +40,6 @@ class AiidaTestCase(unittest.TestCase): _user = None # type: aiida.orm.User _class_was_setup = False __backend_instance = None - backend = None # type: aiida.orm.implementation.Backend @classmethod def get_backend_class(cls): @@ -65,6 +68,11 @@ def get_backend_class(cls): return cls.__impl_class + @classproperty + def backend(cls) -> 'Backend': # pylint: disable=no-self-argument,no-self-use + """Return the backend instance.""" + return get_manager().get_backend() + @classmethod def setUpClass(cls): """Set up test class.""" @@ -74,7 +82,7 @@ def setUpClass(cls): check_if_tests_can_run() # Force the loading of the backend which will load the required database environment - cls.backend = get_manager().get_backend() + cls.backend # pylint: disable=pointless-statement cls.__backend_instance = cls.get_backend_class()() cls._class_was_setup = True @@ -98,6 +106,10 @@ def tearDownClass(cls): def tearDown(self): reset_manager() + # the user and computer need to be reset, so that they can be set to the new backend + # pylint: disable=protected-access + self.__class__._computer = None + self.__class__._user = None ### Database/repository-related methods diff --git a/aiida/cmdline/commands/cmd_database.py b/aiida/cmdline/commands/cmd_database.py index d6955e298e..5d458112c9 100644 --- a/aiida/cmdline/commands/cmd_database.py +++ b/aiida/cmdline/commands/cmd_database.py @@ -32,7 +32,7 @@ def database_version(): from aiida.manage.manager import get_manager manager = get_manager() - manager._load_backend(schema_check=False) # pylint: disable=protected-access + manager._load_backend(validate_db=False) # pylint: disable=protected-access backend = manager.get_backend() echo.echo('Generation: ', bold=True, nl=False) @@ -54,7 +54,7 @@ def database_migrate(force): manager = get_manager() profile = manager.get_profile() - backend = manager._load_backend(schema_check=False) # pylint: disable=protected-access + backend = manager._load_backend(validate_db=False) # pylint: disable=protected-access if force: try: @@ -124,7 +124,7 @@ def detect_duplicate_uuid(table, apply_patch): from aiida.manage.manager import get_manager manager = get_manager() - manager._load_backend(schema_check=False) # pylint: disable=protected-access + manager._load_backend(validate_db=False) # pylint: disable=protected-access try: messages = deduplicate_uuids(table=table, dry_run=not apply_patch) diff --git a/aiida/cmdline/commands/cmd_setup.py b/aiida/cmdline/commands/cmd_setup.py index 8034084403..797ffb40e1 100644 --- a/aiida/cmdline/commands/cmd_setup.py +++ b/aiida/cmdline/commands/cmd_setup.py @@ -79,7 +79,7 @@ def 
setup( # Migrate the database echo.echo_report('migrating the database.') manager = get_manager() - backend = manager._load_backend(schema_check=False, repository_check=False) # pylint: disable=protected-access + backend = manager._load_backend(validate_db=False, repository_check=False) # pylint: disable=protected-access try: backend.migrate() diff --git a/aiida/common/lang.py b/aiida/common/lang.py index f2bb8906f6..37f977b8b9 100644 --- a/aiida/common/lang.py +++ b/aiida/common/lang.py @@ -11,6 +11,7 @@ import functools import inspect import keyword +from typing import Any, Callable, Generic, TypeVar def isidentifier(identifier): @@ -75,8 +76,10 @@ def wrapped_fn(self, *args, **kwargs): # pylint: disable=missing-docstring override = override_decorator(check=False) # pylint: disable=invalid-name +ReturnType = TypeVar('ReturnType') -class classproperty: # pylint: disable=invalid-name + +class classproperty(Generic[ReturnType]): # pylint: disable=invalid-name """ A class that, when used as a decorator, works as if the two decorators @property and @classmethod where applied together @@ -85,8 +88,8 @@ class classproperty: # pylint: disable=invalid-name instance as its first argument). """ - def __init__(self, getter): + def __init__(self, getter: Callable[[Any], ReturnType]) -> None: self.getter = getter - def __get__(self, instance, owner): + def __get__(self, instance: Any, owner: Any) -> ReturnType: return self.getter(owner) diff --git a/aiida/manage/configuration/__init__.py b/aiida/manage/configuration/__init__.py index e62d5f2e45..64e1602c86 100644 --- a/aiida/manage/configuration/__init__.py +++ b/aiida/manage/configuration/__init__.py @@ -304,4 +304,4 @@ def load_documentation_profile(): config = {'default_profile': profile_name, 'profiles': {profile_name: profile}} PROFILE = Profile(profile_name, profile, from_config=True) CONFIG = Config(handle.name, config) - get_manager()._load_backend(schema_check=False, repository_check=False) # pylint: disable=protected-access + get_manager()._load_backend(validate_db=False, repository_check=False) # pylint: disable=protected-access diff --git a/aiida/manage/manager.py b/aiida/manage/manager.py index e526457435..e8047b307f 100644 --- a/aiida/manage/manager.py +++ b/aiida/manage/manager.py @@ -59,6 +59,8 @@ def close(self) -> None: self._communicator.close() if self._runner is not None: self._runner.stop() + if self._backend is not None: + self._backend.close() self._backend = None self._config = None @@ -93,16 +95,16 @@ def get_profile() -> Optional['Profile']: def unload_backend(self) -> None: """Unload the current backend and its corresponding database environment.""" backend = self.get_backend() - backend.reset_environment() + backend.close() self._backend = None - def _load_backend(self, schema_check: bool = True, repository_check: bool = True) -> 'Backend': + def _load_backend(self, validate_db: bool = True, repository_check: bool = True) -> 'Backend': """Load the backend for the currently configured profile and return it. .. note:: this will reconstruct the `Backend` instance in `self._backend` so the preferred method to load the backend is to call `get_backend` which will create it only when not yet instantiated. - :param schema_check: force a database schema check if the database environment has not yet been loaded. + :param validate_db: force a database schema check if the database environment has not yet been loaded. 
:param repository_check: force a check that the database is associated with the repository that is configured for the current profile. :return: the database backend. @@ -122,29 +124,28 @@ def _load_backend(self, schema_check: bool = True, repository_check: bool = True if configuration.BACKEND_UUID is not None and configuration.BACKEND_UUID != profile.uuid: raise InvalidOperation('cannot load backend because backend of another profile is already loaded') - backend_type = profile.database_backend + # only reload the backend if not already loaded + if self._backend is None: + + backend_type = profile.database_backend - # For django only, the setup module must be loaded before the class can be instantiated, - if backend_type == BACKEND_DJANGO: - from aiida.orm.implementation.django.backend import DjangoBackend - backend_cls = DjangoBackend - elif backend_type == BACKEND_SQLA: - from aiida.orm.implementation.sqlalchemy.backend import SqlaBackend - backend_cls = SqlaBackend + if backend_type == BACKEND_DJANGO: + from aiida.orm.implementation.django.backend import DjangoBackend + backend_cls = DjangoBackend + elif backend_type == BACKEND_SQLA: + from aiida.orm.implementation.sqlalchemy.backend import SqlaBackend + backend_cls = SqlaBackend - # Do NOT reload the backend environment if already loaded, simply reload the backend instance after - if configuration.BACKEND_UUID is None: - backend_cls.load_environment(profile, validate_schema=schema_check) - configuration.BACKEND_UUID = profile.uuid + self._backend = backend_cls(profile, validate_db=validate_db) - backend = backend_cls() + configuration.BACKEND_UUID = profile.uuid # Perform the check on the repository compatibility. Since this is new functionality and the stability is not # yet known, we issue a warning in the case the repo and database are incompatible. In the future this might # then become an exception once we have verified that it is working reliably. if repository_check and not profile.is_test_profile: repository_uuid_config = profile.get_repository().uuid - repository_uuid_database = backend.get_repository_uuid() + repository_uuid_database = self._backend.get_repository_uuid() from aiida.cmdline.utils import echo if repository_uuid_config != repository_uuid_database: @@ -156,8 +157,6 @@ def _load_backend(self, schema_check: bool = True, repository_check: bool = True 'Please make sure that the configuration of your profile is correct.\n' ) - self._backend = backend - # Reconfigure the logging with `with_orm=True` to make sure that profile specific logging configuration options # are taken into account and the `DbLogHandler` is configured. 
configure_logging(with_orm=True) diff --git a/aiida/manage/tests/main.py b/aiida/manage/tests/main.py index 5557827038..ab88a2aa6c 100644 --- a/aiida/manage/tests/main.py +++ b/aiida/manage/tests/main.py @@ -145,7 +145,7 @@ def __init__(self, profile_name): try: self._profile = load_profile(profile_name) - manager.get_manager()._load_backend(schema_check=False) # pylint: disable=protected-access + manager.get_manager()._load_backend(validate_db=False) # pylint: disable=protected-access except Exception: raise TestManagerError('Unable to load test profile \'{}\'.'.format(profile_name)) check_if_tests_can_run() @@ -160,11 +160,8 @@ def _select_db_test_case(self, backend): from aiida.backends.djsite.db.testbase import DjangoTests self._test_case = DjangoTests() elif backend == BACKEND_SQLA: - from aiida.backends.sqlalchemy import get_scoped_session from aiida.backends.sqlalchemy.testbase import SqlAlchemyTests - self._test_case = SqlAlchemyTests() - self._test_case.test_session = get_scoped_session() def reset_db(self): self._test_case.clean_db() # will drop all users @@ -352,7 +349,7 @@ def create_profile(self): self._profile = profile load_profile(profile_name) - backend = manager.get_manager()._load_backend(schema_check=False) + backend = manager.get_manager()._load_backend(validate_db=False) backend.migrate() self._select_db_test_case(backend=self._profile.database_backend) diff --git a/aiida/orm/implementation/backends.py b/aiida/orm/implementation/backends.py index a0d43a7b43..37b5ce99b8 100644 --- a/aiida/orm/implementation/backends.py +++ b/aiida/orm/implementation/backends.py @@ -14,6 +14,7 @@ if TYPE_CHECKING: from sqlalchemy.orm.session import Session + from aiida.manage.configuration import Profile from aiida.orm.implementation import ( BackendAuthInfoCollection, BackendCommentCollection, @@ -29,11 +30,39 @@ class Backend(abc.ABC): - """The public interface that defines a backend factory that creates backend specific concrete objects.""" + """Abstraction for a backend to read/write persistent data for a profile's provenance graph.""" + + def __init__(self, profile: 'Profile', validate_db: bool = True) -> None: # pylint: disable=unused-argument + """Instatiate the backend. + + :param profile: the profile provides the configuration details for connecting to the persistent storage + :param validate_db: if True, the backend will perform validation tests on the database consistency + """ + self._profile = profile + + @property + def profile(self) -> 'Profile': + """Return the profile used to initialize the backend.""" + return self._profile + + @abc.abstractmethod + def close(self) -> None: + """Close the backend. + + This method is called when the backend is no longer needed, + and should be used to close any open connections to the persistent storage + """ + + @abc.abstractmethod + def reset(self) -> None: + """Reset the backend. 
+ + This method should reset any open connections to the persistent storage + """ @abc.abstractmethod def migrate(self): - """Migrate the database to the latest schema generation or version.""" + """Migrate the persistent storage to the latest schema generation or version.""" @property @abc.abstractmethod diff --git a/aiida/orm/implementation/django/backend.py b/aiida/orm/implementation/django/backend.py index 06c38c2867..71d17fec4f 100644 --- a/aiida/orm/implementation/django/backend.py +++ b/aiida/orm/implementation/django/backend.py @@ -10,46 +10,58 @@ """Django implementation of `aiida.orm.implementation.backends.Backend`.""" from contextlib import contextmanager import os +from typing import TYPE_CHECKING # pylint: disable=import-error,no-name-in-module import django from django.db import models, transaction -from aiida.backends.djsite import get_scoped_session, reset_session from aiida.backends.djsite.manager import DjangoBackendManager from ..sql.backends import SqlBackend +if TYPE_CHECKING: + from sqlalchemy.future.engine import Engine + from sqlalchemy.future.orm import Session + from sqlalchemy.orm.scoping import scoped_session + __all__ = ('DjangoBackend',) +# Django setup can only be called once, see: +# https://docs.djangoproject.com/en/2.2/topics/settings/#calling-django-setup-is-required-for-standalone-django-usage +DJANGO_SETUP_CALLED = False + class DjangoBackend(SqlBackend[models.Model]): """Django implementation of `aiida.orm.implementation.backends.Backend`.""" - def __init__(self): - """Construct the backend instance by initializing all the collections.""" + def __init__(self, profile, validate_db: bool = True): + super().__init__(profile, validate_db) + + # we have to setup Django before importing the Django models + os.environ['DJANGO_SETTINGS_MODULE'] = 'aiida.backends.djsite.settings' + global DJANGO_SETUP_CALLED # pylint: disable=global-statement + if not DJANGO_SETUP_CALLED: + django.setup() # pylint: disable=no-member + DJANGO_SETUP_CALLED = True + from . import authinfos, comments, computers, groups, logs, nodes, users - super().__init__() self._authinfos = authinfos.DjangoAuthInfoCollection(self) self._comments = comments.DjangoCommentCollection(self) self._computers = computers.DjangoComputerCollection(self) self._groups = groups.DjangoGroupCollection(self) self._logs = logs.DjangoLogCollection(self) self._nodes = nodes.DjangoNodeCollection(self) - self._backend_manager = DjangoBackendManager() + self._backend_manager = DjangoBackendManager(self) self._users = users.DjangoUserCollection(self) - @classmethod - def load_environment(cls, profile, validate_schema=True, **kwargs): - os.environ['DJANGO_SETTINGS_MODULE'] = 'aiida.backends.djsite.settings' - django.setup() # pylint: disable=no-member - # For QueryBuilder only - get_scoped_session(profile, **kwargs) - if validate_schema: - DjangoBackendManager().validate_schema(profile) + if validate_db: + self.get_session() # ensure that the database is accessible + self._backend_manager.validate_schema(profile) - def reset_environment(self): - reset_session() + @property + def backend_manager(self): + return self._backend_manager def migrate(self): self._backend_manager.migrate() @@ -91,17 +103,6 @@ def transaction(): """Open a transaction to be used as a context manager.""" return transaction.atomic() - @staticmethod - def get_session(): - """Return a database session that can be used by the `QueryBuilder` to perform its query. 
- - If there is an exception within the context then the changes will be rolled back and the state will - be as before entering. Transactions can be nested. - - :return: an instance of :class:`sqlalchemy.orm.session.Session` - """ - return get_scoped_session() - # Below are abstract methods inherited from `aiida.orm.implementation.sql.backends.SqlBackend` def get_backend_entity(self, model): diff --git a/aiida/orm/implementation/django/comments.py b/aiida/orm/implementation/django/comments.py index be7fe71b9d..39ade31879 100644 --- a/aiida/orm/implementation/django/comments.py +++ b/aiida/orm/implementation/django/comments.py @@ -131,7 +131,7 @@ def delete(self, comment_id): try: models.DbComment.objects.get(id=comment_id).delete() except ObjectDoesNotExist: - raise exceptions.NotExistent(f"Comment with id '{comment_id}' not found") + raise exceptions.NotExistent(f'Comment<{comment_id}> does not exist') def delete_all(self): """ diff --git a/aiida/orm/implementation/django/computers.py b/aiida/orm/implementation/django/computers.py index f345a68e42..874b4c3b3c 100644 --- a/aiida/orm/implementation/django/computers.py +++ b/aiida/orm/implementation/django/computers.py @@ -63,7 +63,7 @@ def store(self): @property def is_stored(self): - return self._dbmodel.id is not None + return self._dbmodel.is_saved() @property def label(self): diff --git a/aiida/orm/implementation/django/entities.py b/aiida/orm/implementation/django/entities.py index 87d182f664..e7cd437411 100644 --- a/aiida/orm/implementation/django/entities.py +++ b/aiida/orm/implementation/django/entities.py @@ -84,7 +84,7 @@ def is_stored(self): :return: True if stored, False otherwise """ - return self._dbmodel.id is not None + return self._dbmodel.is_saved() def store(self): """ diff --git a/aiida/orm/implementation/django/groups.py b/aiida/orm/implementation/django/groups.py index 50082ea707..85338c093e 100644 --- a/aiida/orm/implementation/django/groups.py +++ b/aiida/orm/implementation/django/groups.py @@ -282,5 +282,5 @@ def query( return retlist - def delete(self, id): # pylint: disable=redefined-builtin - models.DbGroup.objects.filter(id=id).delete() + def delete(self, pk): + models.DbGroup.objects.filter(id=pk).delete() diff --git a/aiida/orm/implementation/django/logs.py b/aiida/orm/implementation/django/logs.py index 7b3b725c2c..03f4bdef29 100644 --- a/aiida/orm/implementation/django/logs.py +++ b/aiida/orm/implementation/django/logs.py @@ -107,7 +107,7 @@ def delete(self, log_id): try: models.DbLog.objects.get(id=log_id).delete() except ObjectDoesNotExist: - raise exceptions.NotExistent(f"Log with id '{log_id}' not found") + raise exceptions.NotExistent(f'Log<{log_id}> does not exist') def delete_all(self): """ diff --git a/aiida/orm/implementation/groups.py b/aiida/orm/implementation/groups.py index 724a7d33e7..42d8f10622 100644 --- a/aiida/orm/implementation/groups.py +++ b/aiida/orm/implementation/groups.py @@ -274,9 +274,8 @@ def get(self, **filters): return results[0] @abc.abstractmethod - def delete(self, id): # pylint: disable=redefined-builtin, invalid-name - """ - Delete a group with the given id + def delete(self, pk): + """Delete a group with the given pk - :param id: the id of the group to delete + :param pk: the primary key of the group to delete """ diff --git a/aiida/orm/implementation/sql/backends.py b/aiida/orm/implementation/sql/backends.py index 1ab0711aa7..618c3ad48e 100644 --- a/aiida/orm/implementation/sql/backends.py +++ b/aiida/orm/implementation/sql/backends.py @@ -9,17 +9,24 @@ 
########################################################################### """Generic backend related objects""" import abc -import typing +from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar from .. import backends +if TYPE_CHECKING: + from sqlalchemy.future.engine import Engine + from sqlalchemy.future.orm import Session + from sqlalchemy.orm.scoping import scoped_session + + from aiida.backends.manager import BackendManager + __all__ = ('SqlBackend',) # The template type for the base ORM model type -ModelType = typing.TypeVar('ModelType') # pylint: disable=invalid-name +ModelType = TypeVar('ModelType') # pylint: disable=invalid-name -class SqlBackend(typing.Generic[ModelType], backends.Backend): +class SqlBackend(Generic[ModelType], backends.Backend): """ A class for SQL based backends. Assumptions are that: * there is an ORM @@ -29,10 +36,44 @@ class SqlBackend(typing.Generic[ModelType], backends.Backend): if any of these assumptions do not fit then just implement a backend from :class:`aiida.orm.implementation.Backend` """ - def __init__(self) -> None: - from aiida.backends.manager import BackendManager - super().__init__() - self._backend_manager = BackendManager() + def __init__(self, profile, validate_db: bool = True): + super().__init__(profile, validate_db) + # set variables for QueryBuilder + self._engine: Optional['Engine'] = None + self._session_factory: Optional['scoped_session'] = None + + def get_session(self, **kwargs: Any) -> 'Session': + """Return a database session that can be used by the `QueryBuilderBackend`. + + On first call (or after a reset) the session is initialised, then the same session is always returned. + + :param kwargs: keyword arguments to be passed to the engine + """ + from aiida.backends.utils import create_scoped_session_factory, create_sqlalchemy_engine + if self._session_factory is not None: + return self._session_factory() + if self._engine is None: + self._engine = create_sqlalchemy_engine(self._profile, **kwargs) + self._session_factory = create_scoped_session_factory(self._engine) + return self._session_factory() + + def close(self): + if self._session_factory is not None: + # these methods are proxied from the session + self._session_factory.expunge_all() # pylint: disable=no-member + self._session_factory.remove() # pylint: disable=no-member + self._session_factory = None + if self._engine is not None: + self._engine.dispose() + self._engine = None + + def reset(self): + self.close() # close the connection, so that it will be regenerated on get_session + + @property + @abc.abstractmethod + def backend_manager(self) -> 'BackendManager': + """Return the backend manager.""" @abc.abstractmethod def get_backend_entity(self, model): @@ -80,54 +121,40 @@ def execute_prepared_statement(self, sql, parameters): return results - @classmethod - @abc.abstractmethod - def load_environment(cls, profile, validate_schema=True, **kwargs): - """Load the backend environment. - - :param profile: the profile whose backend environment to load - :param validate_schema: boolean, if True, validate the schema after loading the environment. - :param kwargs: keyword arguments that will be passed on to the backend specific scoped session getter function. - """ - - @abc.abstractmethod - def reset_environment(self): - """Reset the backend environment.""" - def get_repository_uuid(self): """Return the UUID of the repository that is associated with this database. :return: the UUID of the repository associated with this database or None if it doesn't exist. 
""" - return self._backend_manager.get_repository_uuid() + return self.backend_manager.get_repository_uuid() def get_schema_generation_database(self): """Return the database schema version. :return: `distutils.version.LooseVersion` with schema version of the database """ - return self._backend_manager.get_schema_generation_database() + return self.backend_manager.get_schema_generation_database() def get_schema_version_database(self): """Return the database schema version. :return: `distutils.version.LooseVersion` with schema version of the database """ - return self._backend_manager.get_schema_version_database() + return self.backend_manager.get_schema_version_database() - def set_value(self, key: str, value: typing.Any, description: typing.Optional[str] = None) -> None: + def set_value(self, key: str, value: Any, description: Optional[str] = None) -> None: """Set a global key/value pair on the profile backend. :param key: the key identifying the setting :param value: the value for the setting :param description: optional setting description """ - return self._backend_manager.get_settings_manager().set(key, value, description) + return self.backend_manager.get_settings_manager().set(key, value, description) - def get_value(self, key: str) -> typing.Any: + def get_value(self, key: str) -> Any: """Get a global key/value pair on the profile backend. :param key: the key identifying the setting :raises: `~aiida.common.exceptions.NotExistent` if the settings does not exist """ - return self._backend_manager.get_settings_manager().get(key) + return self.backend_manager.get_settings_manager().get(key) diff --git a/aiida/orm/implementation/sqlalchemy/authinfos.py b/aiida/orm/implementation/sqlalchemy/authinfos.py index 5f8b266ca0..4305705d4d 100644 --- a/aiida/orm/implementation/sqlalchemy/authinfos.py +++ b/aiida/orm/implementation/sqlalchemy/authinfos.py @@ -8,8 +8,8 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Module for the SqlAlchemy backend implementation of the `AuthInfo` ORM class.""" +from sqlalchemy.orm.exc import MultipleResultsFound, NoResultFound -from aiida.backends.sqlalchemy import get_scoped_session from aiida.backends.sqlalchemy.models.authinfo import DbAuthInfo from aiida.common import exceptions from aiida.common.lang import type_check @@ -34,7 +34,9 @@ def __init__(self, backend, computer, user): super().__init__(backend) type_check(user, users.SqlaUser) type_check(computer, computers.SqlaComputer) - self._dbmodel = utils.ModelWrapper(DbAuthInfo(dbcomputer=computer.dbmodel, aiidauser=user.dbmodel)) + self._dbmodel = utils.StorableModel( + DbAuthInfo(dbcomputer=computer.dbmodel, aiidauser=user.dbmodel), self._backend + ) @property def id(self): # pylint: disable=invalid-name @@ -116,20 +118,12 @@ class SqlaAuthInfoCollection(BackendAuthInfoCollection): ENTITY_CLASS = SqlaAuthInfo def delete(self, pk): - """Delete an entry from the collection. 
- - :param pk: the pk of the entry to delete - """ - # pylint: disable=import-error,no-name-in-module - from sqlalchemy.orm.exc import NoResultFound - - session = get_scoped_session() - - try: - session.query(DbAuthInfo).filter_by(id=pk).one().delete() - session.commit() - except NoResultFound: + session = self.backend.get_session() + inst = session.get(DbAuthInfo, pk) + if inst is None: raise exceptions.NotExistent(f'AuthInfo<{pk}> does not exist') + session.delete(inst) + session.commit() def get(self, computer, user): """Return an entry from the collection that is configured for the given computer and user @@ -140,10 +134,7 @@ def get(self, computer, user): :raise aiida.common.exceptions.NotExistent: if no entry exists for the computer/user pair :raise aiida.common.exceptions.MultipleObjectsError: if multiple entries exist for the computer/user pair """ - # pylint: disable=import-error,no-name-in-module - from sqlalchemy.orm.exc import MultipleResultsFound, NoResultFound - - session = get_scoped_session() + session = self.backend.get_session() try: authinfo = session.query(DbAuthInfo).filter_by(dbcomputer_id=computer.id, aiidauser_id=user.id).one() diff --git a/aiida/orm/implementation/sqlalchemy/backend.py b/aiida/orm/implementation/sqlalchemy/backend.py index 516e4c2c34..09c19dcb95 100644 --- a/aiida/orm/implementation/sqlalchemy/backend.py +++ b/aiida/orm/implementation/sqlalchemy/backend.py @@ -10,7 +10,6 @@ """SqlAlchemy implementation of `aiida.orm.implementation.backends.Backend`.""" from contextlib import contextmanager -from aiida.backends.sqlalchemy import get_scoped_session, reset_session from aiida.backends.sqlalchemy.manager import SqlaBackendManager from aiida.backends.sqlalchemy.models import base @@ -23,26 +22,24 @@ class SqlaBackend(SqlBackend[base.Base]): """SqlAlchemy implementation of `aiida.orm.implementation.backends.Backend`.""" - def __init__(self): - """Construct the backend instance by initializing all the collections.""" - super().__init__() + def __init__(self, profile, validate_db: bool = True): # pylint: disable=missing-function-docstring + super().__init__(profile, validate_db) self._authinfos = authinfos.SqlaAuthInfoCollection(self) self._comments = comments.SqlaCommentCollection(self) self._computers = computers.SqlaComputerCollection(self) self._groups = groups.SqlaGroupCollection(self) self._logs = logs.SqlaLogCollection(self) self._nodes = nodes.SqlaNodeCollection(self) - self._backend_manager = SqlaBackendManager() + self._backend_manager = SqlaBackendManager(self) self._users = users.SqlaUserCollection(self) - @classmethod - def load_environment(cls, profile, validate_schema=True, **kwargs): - get_scoped_session(profile, **kwargs) - if validate_schema: - SqlaBackendManager().validate_schema(profile) + if validate_db: + self.get_session() # ensure that the database is accessible + self._backend_manager.validate_schema(profile) - def reset_environment(self): # pylint: disable=no-self-use - reset_session() + @property + def backend_manager(self): + return self._backend_manager def migrate(self): self._backend_manager.migrate() @@ -89,19 +86,12 @@ def transaction(self): if session.in_transaction(): with session.begin_nested(): yield session + session.commit() else: with session.begin(): with session.begin_nested(): yield session - @staticmethod - def get_session(): - """Return a database session that can be used by the `QueryBuilder` to perform its query. 
- - :return: an instance of :class:`sqlalchemy.orm.session.Session` - """ - return get_scoped_session() - # Below are abstract methods inherited from `aiida.orm.implementation.sql.backends.SqlBackend` def get_backend_entity(self, model): @@ -115,9 +105,8 @@ def cursor(self): :return: a psycopg cursor :rtype: :class:`psycopg2.extensions.cursor` """ - from aiida.backends import sqlalchemy as sa try: - connection = sa.ENGINE.raw_connection() + connection = self.get_session().bind.raw_connection() yield connection.cursor() finally: self.get_connection().close() @@ -141,11 +130,9 @@ def execute_raw(self, query): return results - @staticmethod - def get_connection(): + def get_connection(self): """Get the SQLA database connection - :return: the SQLA database connection + :return: the raw SQLA database connection """ - from aiida.backends import sqlalchemy as sa - return sa.ENGINE.raw_connection() + return self.get_session().bind.raw_connection() diff --git a/aiida/orm/implementation/sqlalchemy/comments.py b/aiida/orm/implementation/sqlalchemy/comments.py index da100140dd..d5a99ecd5e 100644 --- a/aiida/orm/implementation/sqlalchemy/comments.py +++ b/aiida/orm/implementation/sqlalchemy/comments.py @@ -8,13 +8,8 @@ # For further information please visit http://www.aiida.net # ########################################################################### """SQLA implementations for the Comment entity and collection.""" -# pylint: disable=import-error,no-name-in-module - from datetime import datetime -from sqlalchemy.orm.exc import NoResultFound - -from aiida.backends.sqlalchemy import get_scoped_session from aiida.backends.sqlalchemy.models import comment as models from aiida.common import exceptions, lang @@ -56,7 +51,7 @@ def __init__(self, backend, node, user, content=None, ctime=None, mtime=None): lang.type_check(mtime, datetime, f'the given mtime is of type {type(mtime)}') arguments['mtime'] = mtime - self._dbmodel = utils.ModelWrapper(models.DbComment(**arguments)) + self._dbmodel = utils.StorableModel(models.DbComment(**arguments), self._backend) def store(self): """Can only store if both the node and user are stored as well.""" @@ -113,26 +108,15 @@ def create(self, node, user, content=None, **kwargs): return SqlaComment(self.backend, node, user, content, **kwargs) # pylint: disable=abstract-class-instantiated def delete(self, comment_id): - """ - Remove a Comment from the collection with the given id - - :param comment_id: the id of the comment to delete - :type comment_id: int - - :raises TypeError: if ``comment_id`` is not an `int` - :raises `~aiida.common.exceptions.NotExistent`: if Comment with ID ``comment_id`` is not found - """ if not isinstance(comment_id, int): raise TypeError('comment_id must be an int') - session = get_scoped_session() - - try: - session.query(models.DbComment).filter_by(id=comment_id).one().delete() - session.commit() - except NoResultFound: - session.rollback() - raise exceptions.NotExistent(f"Comment with id '{comment_id}' not found") + session = self.backend.get_session() + inst = session.get(models.DbComment, comment_id) + if inst is None: + raise exceptions.NotExistent(f'Comment<{comment_id}> does not exist') + session.delete(inst) + session.commit() def delete_all(self): """ @@ -140,7 +124,7 @@ def delete_all(self): :raises `~aiida.common.exceptions.IntegrityError`: if all Comments could not be deleted """ - session = get_scoped_session() + session = self.backend.get_session() try: session.query(models.DbComment).delete() diff --git 
a/aiida/orm/implementation/sqlalchemy/computers.py b/aiida/orm/implementation/sqlalchemy/computers.py index 525fae15b4..58804c2d6d 100644 --- a/aiida/orm/implementation/sqlalchemy/computers.py +++ b/aiida/orm/implementation/sqlalchemy/computers.py @@ -15,7 +15,6 @@ from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm.session import make_transient -from aiida.backends.sqlalchemy import get_scoped_session from aiida.backends.sqlalchemy.models.computer import DbComputer from aiida.common import exceptions from aiida.orm.implementation.computers import BackendComputer, BackendComputerCollection @@ -32,7 +31,7 @@ class SqlaComputer(entities.SqlaModelEntity[DbComputer], BackendComputer): def __init__(self, backend, **kwargs): super().__init__(backend) - self._dbmodel = utils.ModelWrapper(DbComputer(**kwargs)) + self._dbmodel = utils.StorableModel(DbComputer(**kwargs), self._backend) @property def uuid(self): @@ -48,11 +47,11 @@ def id(self): # pylint: disable=invalid-name @property def is_stored(self): - return self._dbmodel.id is not None + return self._dbmodel.is_saved() def copy(self): """Create an unstored clone of an already stored `Computer`.""" - session = get_scoped_session() + session = self.backend.get_session() if not self.is_stored: raise exceptions.InvalidOperation('You can copy a computer only after having stored it') @@ -128,15 +127,17 @@ class SqlaComputerCollection(BackendComputerCollection): ENTITY_CLASS = SqlaComputer - @staticmethod - def list_names(): - session = get_scoped_session() + def list_names(self): + session = self.backend.get_session() return session.query(DbComputer.label).all() def delete(self, pk): + session = self.backend.get_session() + inst = session.get(DbComputer, pk) + if inst is None: + raise exceptions.NotExistent(f'AuthInfo<{pk}> does not exist') try: - session = get_scoped_session() - session.get(DbComputer, pk).delete() + session.delete(inst) session.commit() except SQLAlchemyError as exc: raise exceptions.InvalidOperation( diff --git a/aiida/orm/implementation/sqlalchemy/entities.py b/aiida/orm/implementation/sqlalchemy/entities.py index 7f0ed60295..17a4840801 100644 --- a/aiida/orm/implementation/sqlalchemy/entities.py +++ b/aiida/orm/implementation/sqlalchemy/entities.py @@ -45,7 +45,7 @@ def from_dbmodel(cls, dbmodel, backend): type_check(backend, SqlaBackend) entity = cls.__new__(cls) super(SqlaModelEntity, entity).__init__(backend) - entity._dbmodel = utils.ModelWrapper(dbmodel) # pylint: disable=protected-access + entity._dbmodel = utils.StorableModel(dbmodel, backend) # pylint: disable=protected-access return entity @classmethod @@ -92,7 +92,7 @@ def is_stored(self): :return: True if stored, False otherwise """ - return self._dbmodel.id is not None + return self._dbmodel.is_saved() def store(self): """ diff --git a/aiida/orm/implementation/sqlalchemy/groups.py b/aiida/orm/implementation/sqlalchemy/groups.py index 8b8e991c3b..e36869a08b 100644 --- a/aiida/orm/implementation/sqlalchemy/groups.py +++ b/aiida/orm/implementation/sqlalchemy/groups.py @@ -8,14 +8,12 @@ # For further information please visit http://www.aiida.net # ########################################################################### """SQLA groups""" - from collections.abc import Iterable import logging -from aiida.backends import sqlalchemy as sa from aiida.backends.sqlalchemy.models.group import DbGroup, table_groups_nodes from aiida.backends.sqlalchemy.models.node import DbNode -from aiida.common.exceptions import UniquenessError +from aiida.common.exceptions import 
NotExistent, UniquenessError from aiida.common.lang import type_check from aiida.orm.implementation.groups import BackendGroup, BackendGroupCollection @@ -47,7 +45,7 @@ def __init__(self, backend, label, user, description='', type_string=''): super().__init__(backend) dbgroup = DbGroup(label=label, description=description, user=user.dbmodel, type_string=type_string) - self._dbmodel = utils.ModelWrapper(dbgroup) + self._dbmodel = utils.StorableModel(dbgroup, self._backend) @property def label(self): @@ -113,7 +111,7 @@ def __int__(self): @property def is_stored(self): - return self.pk is not None + return self._dbmodel.is_saved() def store(self): self._dbmodel.save() @@ -124,15 +122,15 @@ def count(self): :return: integer number of entities contained within the group """ - from aiida.backends.sqlalchemy import get_scoped_session - session = get_scoped_session() + + session = self.backend.get_session() return session.query(self.MODEL_CLASS).join(self.MODEL_CLASS.dbnodes).filter(DbGroup.id == self.pk).count() def clear(self): """Remove all the nodes from this group.""" - from aiida.backends.sqlalchemy import get_scoped_session - session = get_scoped_session() - # Note we have to call `dbmodel` and `_dbmodel` to circumvent the `ModelWrapper` + + session = self.backend.get_session() + # Note we have to call `dbmodel` and `_dbmodel` to circumvent the `StorableModel` self.dbmodel.dbnodes = [] session.commit() @@ -184,7 +182,6 @@ def add_nodes(self, nodes, **kwargs): from sqlalchemy.dialects.postgresql import insert # pylint: disable=import-error, no-name-in-module from sqlalchemy.exc import IntegrityError # pylint: disable=import-error, no-name-in-module - from aiida.backends.sqlalchemy import get_scoped_session from aiida.backends.sqlalchemy.models.base import Base from aiida.orm.implementation.sqlalchemy.nodes import SqlaNode @@ -199,7 +196,7 @@ def check_node(given_node): if not given_node.is_stored: raise ValueError('At least one of the provided nodes is unstored, stopping...') - with utils.disable_expire_on_commit(get_scoped_session()) as session: + with utils.disable_expire_on_commit(self.backend.get_session()) as session: if not skip_orm: # Get dbnodes here ONCE, otherwise each call to dbnodes will re-read the current value in the database dbnodes = self._dbmodel.dbnodes @@ -241,7 +238,6 @@ def remove_nodes(self, nodes, **kwargs): """ from sqlalchemy import and_ - from aiida.backends.sqlalchemy import get_scoped_session from aiida.backends.sqlalchemy.models.base import Base from aiida.orm.implementation.sqlalchemy.nodes import SqlaNode @@ -260,7 +256,7 @@ def check_node(node): list_nodes = [] - with utils.disable_expire_on_commit(get_scoped_session()) as session: + with utils.disable_expire_on_commit(self.backend.get_session()) as session: if not skip_orm: for node in nodes: check_node(node) @@ -303,7 +299,7 @@ def query( # pylint: disable=too-many-branches from aiida.orm.implementation.sqlalchemy.nodes import SqlaNode - session = sa.get_scoped_session() + session = self.backend.get_session() filters = [] @@ -365,8 +361,10 @@ def query( return [SqlaGroup.from_dbmodel(group, self._backend) for group in groups] # pylint: disable=no-member - def delete(self, id): # pylint: disable=redefined-builtin - session = sa.get_scoped_session() - - session.get(DbGroup, id).delete() + def delete(self, pk): + session = self.backend.get_session() + inst = session.get(DbGroup, pk) + if inst is None: + raise NotExistent(f'Group<{pk}> does not exist') + session.delete(inst) session.commit() diff --git 
a/aiida/orm/implementation/sqlalchemy/logs.py b/aiida/orm/implementation/sqlalchemy/logs.py index b4d75ad6ac..18acf29068 100644 --- a/aiida/orm/implementation/sqlalchemy/logs.py +++ b/aiida/orm/implementation/sqlalchemy/logs.py @@ -9,10 +9,6 @@ ########################################################################### """SQLA Log and LogCollection module""" # pylint: disable=import-error,no-name-in-module - -from sqlalchemy.orm.exc import NoResultFound - -from aiida.backends.sqlalchemy import get_scoped_session from aiida.backends.sqlalchemy.models import log as models from aiida.common import exceptions @@ -28,7 +24,7 @@ class SqlaLog(entities.SqlaModelEntity[models.DbLog], BackendLog): def __init__(self, backend, time, loggername, levelname, dbnode_id, message='', metadata=None): # pylint: disable=too-many-arguments super().__init__(backend) - self._dbmodel = utils.ModelWrapper( + self._dbmodel = utils.StorableModel( models.DbLog( time=time, loggername=loggername, @@ -36,7 +32,7 @@ def __init__(self, backend, time, loggername, levelname, dbnode_id, message='', dbnode_id=dbnode_id, message=message, metadata=metadata - ) + ), self._backend ) @property @@ -95,26 +91,15 @@ class SqlaLogCollection(BackendLogCollection): ENTITY_CLASS = SqlaLog def delete(self, log_id): - """ - Remove a Log entry from the collection with the given id - - :param log_id: id of the Log to delete - :type log_id: int - - :raises TypeError: if ``log_id`` is not an `int` - :raises `~aiida.common.exceptions.NotExistent`: if Log with ID ``log_id`` is not found - """ if not isinstance(log_id, int): raise TypeError('log_id must be an int') - session = get_scoped_session() - - try: - session.query(models.DbLog).filter_by(id=log_id).one().delete() - session.commit() - except NoResultFound: - session.rollback() - raise exceptions.NotExistent(f"Log with id '{log_id}' not found") + session = self.backend.get_session() + inst = session.get(models.DbLog, log_id) + if inst is None: + raise exceptions.NotExistent(f'Log<{log_id}> does not exist') + session.delete(inst) + session.commit() def delete_all(self): """ @@ -122,7 +107,7 @@ def delete_all(self): :raises `~aiida.common.exceptions.IntegrityError`: if all Logs could not be deleted """ - session = get_scoped_session() + session = self.backend.get_session() try: session.query(models.DbLog).delete() diff --git a/aiida/orm/implementation/sqlalchemy/nodes.py b/aiida/orm/implementation/sqlalchemy/nodes.py index db324ba280..b5683edbcd 100644 --- a/aiida/orm/implementation/sqlalchemy/nodes.py +++ b/aiida/orm/implementation/sqlalchemy/nodes.py @@ -15,7 +15,6 @@ from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm.exc import NoResultFound -from aiida.backends.sqlalchemy import get_scoped_session from aiida.backends.sqlalchemy.models import node as models from aiida.common import exceptions from aiida.common.lang import type_check @@ -83,7 +82,7 @@ def __init__( type_check(mtime, datetime, f'the given mtime is of type {type(mtime)}') arguments['mtime'] = mtime - self._dbmodel = sqla_utils.ModelWrapper(models.DbNode(**arguments)) + self._dbmodel = sqla_utils.StorableModel(models.DbNode(**arguments), self._backend) def clone(self): """Return an unstored clone of ourselves. 
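The `delete` rewrites above all follow the same `Session.get` pattern. Below is a minimal, self-contained sketch of that pattern against a throw-away SQLAlchemy model; the `Entry` model, the in-memory SQLite URL and the `delete_entry` helper are illustrative assumptions, not AiiDA code.

from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Entry(Base):
    __tablename__ = 'entry'
    id = Column(Integer, primary_key=True)
    message = Column(String)


engine = create_engine('sqlite:///:memory:', future=True)
Base.metadata.create_all(engine)


def delete_entry(session: Session, pk: int) -> None:
    """Delete the row with the given primary key, or raise if it does not exist."""
    instance = session.get(Entry, pk)  # returns None when the row is missing, instead of raising
    if instance is None:
        raise ValueError(f'Entry<{pk}> does not exist')
    session.delete(instance)
    session.commit()


with Session(engine) as session:
    session.add(Entry(message='example'))
    session.commit()
    delete_entry(session, 1)

Compared with the previous `query(...).one().delete()` calls, this avoids catching `NoResultFound` and keeps the existence check explicit.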
@@ -103,7 +102,7 @@ def clone(self): clone = self.__class__.__new__(self.__class__) # pylint: disable=no-value-for-parameter clone.__init__(self.backend, self.node_type, self.user) - clone._dbmodel = sqla_utils.ModelWrapper(models.DbNode(**arguments)) # pylint: disable=protected-access + clone._dbmodel = sqla_utils.StorableModel(models.DbNode(**arguments), self.backend) # pylint: disable=protected-access return clone @property @@ -158,7 +157,7 @@ def add_incoming(self, source, link_type, link_label): :return: True if the proposed link is allowed, False otherwise :raise aiida.common.ModificationNotAllowed: if either source or target node is not stored """ - session = get_scoped_session() + session = self.backend.get_session() type_check(source, SqlaNode) @@ -180,7 +179,7 @@ def _add_link(self, source, link_type, link_label): """ from aiida.backends.sqlalchemy.models.node import DbLink - session = get_scoped_session() + session = self.backend.get_session() try: with session.begin_nested(): @@ -200,7 +199,7 @@ def store(self, links=None, with_transaction=True, clean=True): # pylint: disab :param with_transaction: if False, do not use a transaction because the caller will already have opened one. :param clean: boolean, if True, will clean the attributes and extras before attempting to store """ - session = get_scoped_session() + session = self.backend.get_session() if clean: self.clean_values() @@ -231,7 +230,7 @@ def get(self, pk): :param pk: id of the node """ - session = get_scoped_session() + session = self.backend.get_session() try: return self.ENTITY_CLASS.from_dbmodel(session.query(models.DbNode).filter_by(id=pk).one(), self.backend) @@ -239,14 +238,9 @@ def get(self, pk): raise exceptions.NotExistent(f"Node with pk '{pk}' not found") from NoResultFound def delete(self, pk): - """Remove a Node entry from the collection with the given id - - :param pk: id of the node to delete - """ - session = get_scoped_session() - - try: - session.query(models.DbNode).filter_by(id=pk).one().delete() - session.commit() - except NoResultFound: - raise exceptions.NotExistent(f"Node with pk '{pk}' not found") from NoResultFound + session = self.backend.get_session() + inst = session.get(models.DbNode, pk) + if inst is None: + raise exceptions.NotExistent(f'Node<{pk}> does not exist') + session.delete(inst) + session.commit() diff --git a/aiida/orm/implementation/sqlalchemy/querybuilder/joiner.py b/aiida/orm/implementation/sqlalchemy/querybuilder/joiner.py index 08dd41ee58..3a81923ab6 100644 --- a/aiida/orm/implementation/sqlalchemy/querybuilder/joiner.py +++ b/aiida/orm/implementation/sqlalchemy/querybuilder/joiner.py @@ -19,7 +19,7 @@ from sqlalchemy.sql.schema import Table from sqlalchemy.types import Integer -from aiida.backends.sqlalchemy.models.base import Model +from aiida.backends.sqlalchemy.models.base import Base from aiida.common.links import LinkType try: @@ -34,31 +34,31 @@ class _EntityMapper(Protocol): # pylint: disable=invalid-name @property - def Node(self) -> Type[Model]: + def Node(self) -> Type[Base]: ... @property - def Group(self) -> Type[Model]: + def Group(self) -> Type[Base]: ... @property - def Link(self) -> Type[Model]: + def Link(self) -> Type[Base]: ... @property - def User(self) -> Type[Model]: + def User(self) -> Type[Base]: ... @property - def Computer(self) -> Type[Model]: + def Computer(self) -> Type[Base]: ... @property - def Comment(self) -> Type[Model]: + def Comment(self) -> Type[Base]: ... @property - def Log(self) -> Type[Model]: + def Log(self) -> Type[Base]: ... 
@property @@ -72,7 +72,7 @@ class JoinReturn(NamedTuple): FilterType = Dict[str, Any] # pylint: disable=invalid-name -JoinFuncType = Callable[[Query, Type[Model], Type[Model], bool, FilterType, bool], JoinReturn] # pylint: disable=invalid-name +JoinFuncType = Callable[[Query, Type[Base], Type[Base], bool, FilterType, bool], JoinReturn] # pylint: disable=invalid-name class SqlaJoiner: diff --git a/aiida/orm/implementation/sqlalchemy/users.py b/aiida/orm/implementation/sqlalchemy/users.py index f6ddbe2e76..83cb77c74f 100644 --- a/aiida/orm/implementation/sqlalchemy/users.py +++ b/aiida/orm/implementation/sqlalchemy/users.py @@ -24,8 +24,8 @@ class SqlaUser(entities.SqlaModelEntity[DbUser], BackendUser): def __init__(self, backend, email, first_name, last_name, institution): # pylint: disable=too-many-arguments super().__init__(backend) - self._dbmodel = utils.ModelWrapper( - DbUser(email=email, first_name=first_name, last_name=last_name, institution=institution) + self._dbmodel = utils.StorableModel( + DbUser(email=email, first_name=first_name, last_name=last_name, institution=institution), self._backend ) @property diff --git a/aiida/orm/implementation/sqlalchemy/utils.py b/aiida/orm/implementation/sqlalchemy/utils.py index 42607c31c4..52baa2affc 100644 --- a/aiida/orm/implementation/sqlalchemy/utils.py +++ b/aiida/orm/implementation/sqlalchemy/utils.py @@ -8,34 +8,43 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Utilities for the implementation of the SqlAlchemy backend.""" - import contextlib +from typing import TYPE_CHECKING, Tuple -# pylint: disable=import-error,no-name-in-module from sqlalchemy import inspect -from sqlalchemy.exc import IntegrityError +from sqlalchemy.exc import IntegrityError, InvalidRequestError +from sqlalchemy.orm import Session from sqlalchemy.orm.attributes import flag_modified -from aiida.backends.sqlalchemy import get_scoped_session from aiida.common import exceptions +if TYPE_CHECKING: + from aiida.backends.sqlalchemy.models.base import Base + + from .backend import SqlaBackend + IMMUTABLE_MODEL_FIELDS = {'id', 'pk', 'uuid', 'node_type'} -class ModelWrapper: - """Wrap a database model instance to correctly update and flush the data model when getting or setting a field. +class StorableModel: + """Class to create a storable ORM Model. - If the model is not stored, the behavior of the get and set attributes is unaltered. However, if the model is - stored, which is to say, it has a primary key, the `getattr` and `setattr` are modified as follows: + This class takes as input an SQLAlchemy ORM model instance (defining a row of a DB table), + and an AiiDA backend (providing an interface to a single database via a session), + and then provides methods to store and update the model in the database. - * `getattr`: if the item corresponds to a mutable model field, the model instance is refreshed first - * `setattr`: if the item corresponds to a mutable model field, changes are flushed after performing the change + An instance is deemed as stored in the database if it has been added to the session of the backend. + Note, this does not mean it has actually been saved in the database, + for example in a transaction when it will only be committed at the end of the transaction. + + Once an instance is stored, AiiDA enforces that certain DB fields are immutable, + and so changes to these fields are not flushed/refreshed to/from the database. 
""" # pylint: disable=too-many-instance-attributes - def __init__(self, model, auto_flush=()): - """Construct the ModelWrapper. + def __init__(self, model: 'Base', backend: 'SqlaBackend', auto_flush: Tuple[str, ...] = ()): + """Initialize the storable model. :param model: the database model instance to wrap :param auto_flush: an optional tuple of database model fields that are always to be flushed, in addition to @@ -44,6 +53,7 @@ def __init__(self, model, auto_flush=()): super().__init__() # Have to do it this way because we overwrite __setattr__ object.__setattr__(self, '_model', model) + object.__setattr__(self, '_backend', backend) object.__setattr__(self, '_auto_flush', auto_flush) def __getattr__(self, item): @@ -55,10 +65,10 @@ def __getattr__(self, item): :param item: the name of the model field :return: the value of the model's attribute """ - # Python 3's implementation of copy.copy does not call __init__ on the new object - # but manually restores attributes instead. Make sure we never get into a recursive - # loop by protecting the only special variable here: _model - if item == '_model': + # Python 3's implementation of copy.copy does not call __init__ on the new object but + # manually restores attributes instead. + # Make sure we never get into a recursive loop by protecting special variables + if item in ('_model', '_backend', '_auto_flush'): raise AttributeError() if self.is_saved() and self._is_mutable_model_field(item) and not self._in_transaction(): @@ -79,29 +89,52 @@ def __setattr__(self, key, value): fields = set((key,) + self._auto_flush) self._flush(fields=fields) + @property + def _session(self) -> Session: + """Return the session of the backend. + + :return: the session of the backend + """ + return self._backend.get_session() + def is_saved(self): """Return whether the wrapped model instance is saved in the database. :return: boolean, True if the model is saved in the database, False otherwise """ - # we should not flush here since it may lead to IntegrityErrors - # which are handled later in the save method - with self._model.session.no_autoflush: - return self._model.id is not None + return self._model in self._session + + def _in_transaction(self): + """Return whether the backend session is within a SAVEPOINT database transaction. + + :return: boolean, True if currently in SAVEPOINT transaction, False otherwise. + """ + return self._session.in_nested_transaction() def save(self): """Store the model instance. - .. note:: If one is currently in a transaction, this method is a no-op. + .. note:: + + If one is currently in a transaction, this method will only add the model to the session but not commit it. :raises `aiida.common.IntegrityError`: if a database integrity error is raised during the save. """ + session = self._session try: - commit = not self._in_transaction() - self._model.save(commit=commit) - except IntegrityError as exception: - self._model.session.rollback() - raise exceptions.IntegrityError(str(exception)) + session.add(self._model) + except InvalidRequestError as exc: + msg = ( + f'The entity {type(self._model)} could not be stored, ' + f'because it, or a joined entity, is already associated with another backend: {exc}' + ) + raise exceptions.StoringNotAllowed(msg) from exc + if not self._in_transaction(): + try: + session.commit() + except IntegrityError as exception: + session.rollback() + raise exceptions.IntegrityError(str(exception)) def _is_mutable_model_field(self, field): """Return whether the field is a mutable field of the model. 
@@ -129,8 +162,8 @@ def _flush(self, fields=()): """ if self.is_saved(): for field in fields: - flag_modified(self._model, field) - + if field in self._model.__dict__: + flag_modified(self._model, field) self.save() def _ensure_model_uptodate(self, fields=None): @@ -138,24 +171,12 @@ def _ensure_model_uptodate(self, fields=None): :param fields: optionally refresh only these fields, if `None` all fields are refreshed. """ - self._model.session.expire(self._model, attribute_names=fields) - - @staticmethod - def _in_transaction(): - """Return whether the current scope is within an open database transaction. - - :return: boolean, True if currently in open transaction, False otherwise. - """ - return get_scoped_session().in_nested_transaction() + self._session.expire(self._model, attribute_names=fields) @contextlib.contextmanager -def disable_expire_on_commit(session): - """Context manager that disables expire_on_commit and restores the original value on exit - - :param session: The SQLA session - :type session: :class:`sqlalchemy.orm.session.Session` - """ +def disable_expire_on_commit(session: Session): + """Context manager that disables expire_on_commit and restores the original value on exit""" current_value = session.expire_on_commit session.expire_on_commit = False try: diff --git a/aiida/tools/importexport/dbimport/backends/sqla.py b/aiida/tools/importexport/dbimport/backends/sqla.py index 59614a30b4..62883f84cc 100644 --- a/aiida/tools/importexport/dbimport/backends/sqla.py +++ b/aiida/tools/importexport/dbimport/backends/sqla.py @@ -288,8 +288,8 @@ def import_data_sqla( @contextmanager def sql_transaction(): """A context manager for the transaction.""" - import aiida.backends.sqlalchemy - session = aiida.backends.sqlalchemy.get_scoped_session() + from aiida.manage.manager import get_manager + session = get_manager().get_backend().get_session() try: yield session IMPORT_LOGGER.debug('COMMITTING EVERYTHING...') diff --git a/tests/backends/aiida_django/migrations/test_migrations_common.py b/tests/backends/aiida_django/migrations/test_migrations_common.py index 0784750bd3..939b535e51 100644 --- a/tests/backends/aiida_django/migrations/test_migrations_common.py +++ b/tests/backends/aiida_django/migrations/test_migrations_common.py @@ -35,7 +35,7 @@ def app(self): def setUp(self): """Go to a specific schema version before running tests.""" - from aiida.backends.djsite import get_scoped_session + from aiida.manage.manager import get_manager from aiida.orm import autogroup self.current_autogroup = autogroup.CURRENT_AUTOGROUP @@ -51,7 +51,7 @@ def setUp(self): # Before running the migration, make sure we close the querybuilder session which may still contain references # to objects whose mapping may be invalidated after resetting the schema to an older version. This can block # the migrations so we first expunge those objects by closing the session. 
- get_scoped_session().close() + get_manager().get_backend().reset() # Reverse to the original migration with Capturing(): diff --git a/tests/backends/aiida_sqlalchemy/test_migrations.py b/tests/backends/aiida_sqlalchemy/test_migrations.py index 035311b0aa..c67739f5f1 100644 --- a/tests/backends/aiida_sqlalchemy/test_migrations.py +++ b/tests/backends/aiida_sqlalchemy/test_migrations.py @@ -17,12 +17,12 @@ from alembic.config import Config from sqlalchemy import column -from aiida.backends import sqlalchemy as sa from aiida.backends.general.migrations import utils from aiida.backends.sqlalchemy import manager from aiida.backends.sqlalchemy.models.base import Base from aiida.backends.sqlalchemy.utils import flag_modified from aiida.backends.testbase import AiidaTestCase +from aiida.manage.configuration import get_profile from .test_utils import new_database @@ -42,20 +42,13 @@ class TestMigrationsSQLA(AiidaTestCase): migrate_from = None migrate_to = None - @classmethod - def setUpClass(cls, *args, **kwargs): - """ - Prepare the test class with the alembivc configuration - """ - super().setUpClass(*args, **kwargs) - cls.manager = manager.SqlaBackendManager() - def setUp(self): """ Go to the migrate_from revision, apply setUpBeforeMigration, then run the migration. """ super().setUp() + self.engine = self.backend.get_session().bind # pylint: disable=no-member from aiida.orm import autogroup self.current_autogroup = autogroup.CURRENT_AUTOGROUP @@ -80,23 +73,32 @@ def _perform_actual_migration(self): """ self.migrate_db_up(self.migrate_to) - def migrate_db_up(self, destination): + @staticmethod + def migrate_db_up(destination): """ Perform a migration upwards (upgrade) with alembic :param destination: the name of the destination migration """ - # Undo all previous real migration of the database - with self.manager.alembic_config() as config: + from aiida.orm.implementation.sqlalchemy.backend import SqlaBackend + + # create a new backend which does not validate the schema version + backend = SqlaBackend(get_profile(), validate_db=False) + with backend._backend_manager.alembic_config() as config: command.upgrade(config, destination) - def migrate_db_down(self, destination): + @staticmethod + def migrate_db_down(destination): """ Perform a migration downwards (downgrade) with alembic :param destination: the name of the destination migration """ - with self.manager.alembic_config() as config: + from aiida.orm.implementation.sqlalchemy.backend import SqlaBackend + + # create a new backend which does not validate the schema version + backend = SqlaBackend(get_profile(), validate_db=False) + with backend._backend_manager.alembic_config() as config: command.downgrade(config, destination) def tearDown(self): @@ -131,13 +133,12 @@ def current_rev(self): Utility method to get the current revision string """ from alembic.migration import MigrationContext # pylint: disable=import-error - with sa.ENGINE.begin() as connection: + with self.engine.begin() as connection: context = MigrationContext.configure(connection) current_rev = context.get_current_revision() return current_rev - @staticmethod - def get_auto_base(): + def get_auto_base(self): """ Return the automap_base class that automatically inspects the current database and return SQLAlchemy Models. 
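For readers unfamiliar with `automap_base`, the reflection performed by `get_auto_base` looks roughly like the following self-contained sketch (SQLAlchemy 1.4 assumed; the `sample` table and class names are illustrative, not the AiiDA schema).

from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.ext.automap import automap_base
from sqlalchemy.orm import Session, declarative_base

Declared = declarative_base()


class Sample(Declared):
    __tablename__ = 'sample'
    id = Column(Integer, primary_key=True)
    name = Column(String)


engine = create_engine('sqlite:///:memory:', future=True)
Declared.metadata.create_all(engine)

# Reflect the existing tables into mapped classes without importing the model definitions.
Auto = automap_base()
Auto.prepare(autoload_with=engine)
ReflectedSample = Auto.classes.sample

with Session(engine) as session:
    session.add(ReflectedSample(name='example'))
    session.commit()
    assert session.query(ReflectedSample).count() == 1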
@@ -148,7 +149,7 @@ def get_auto_base(): from alembic.migration import MigrationContext # pylint: disable=import-error from sqlalchemy.ext.automap import automap_base # pylint: disable=import-error,no-name-in-module - with sa.ENGINE.begin() as connection: + with self.engine.begin() as connection: context = MigrationContext.configure(connection) bind = context.bind @@ -158,15 +159,14 @@ def get_auto_base(): return base - @staticmethod @contextmanager - def get_session(): + def get_session(self): """ Return a session that is properly closed after use. """ from sqlalchemy.orm import Session # pylint: disable=import-error,no-name-in-module - with sa.ENGINE.begin() as connection: + with self.engine.begin() as connection: session = Session(connection.engine) yield session session.close() @@ -266,6 +266,13 @@ class TestMigrationSchemaVsModelsSchema(AiidaTestCase): db_url_left = None db_url_right = None + @property + def engine(self): + """ + Return the engine of the current backend + """ + return self.backend.get_session().bind # pylint: disable=no-member + def setUp(self): from sqlalchemydiff.util import get_temporary_uri @@ -289,7 +296,7 @@ def setUp(self): # The correction URL to the SQLA database of the current # AiiDA connection - curr_db_url = sa.ENGINE.url + curr_db_url = self.engine.url # Create new urls for the two new databases self.db_url_left = get_temporary_uri(str(curr_db_url)) @@ -346,7 +353,7 @@ def setUpBeforeMigration(self): DbNode = self.get_auto_base().classes.db_dbnode # pylint: disable=invalid-name DbUser = self.get_auto_base().classes.db_dbuser # pylint: disable=invalid-name - with sa.ENGINE.begin() as connection: + with self.engine.begin() as connection: try: session = Session(connection.engine) @@ -394,7 +401,7 @@ def test_verify_migration(self): DbNode = self.get_auto_base().classes.db_dbnode # pylint: disable=invalid-name - with sa.ENGINE.begin() as connection: + with self.engine.begin() as connection: try: session = Session(connection.engine) @@ -536,7 +543,7 @@ def setUpBeforeMigration(self): DbNode = self.get_auto_base().classes.db_dbnode # pylint: disable=invalid-name DbUser = self.get_auto_base().classes.db_dbuser # pylint: disable=invalid-name - with sa.ENGINE.begin() as connection: + with self.engine.begin() as connection: try: session = Session(connection.engine) @@ -581,7 +588,7 @@ def test_attribute_key_changes(self): not_found = tuple([0]) - with sa.ENGINE.begin() as connection: + with self.engine.begin() as connection: try: session = Session(connection.engine) @@ -635,7 +642,7 @@ def tearDown(self): DbWorkflow = self.get_auto_base().classes.db_dbworkflow # pylint: disable=invalid-name - with sa.ENGINE.begin() as connection: + with self.engine.begin() as connection: try: session = Session(connection.engine) session.query(DbWorkflow).delete() @@ -662,7 +669,7 @@ def setUpBeforeMigration(self): DbWorkflow = self.get_auto_base().classes.db_dbworkflow # pylint: disable=invalid-name DbLog = self.get_auto_base().classes.db_dblog # pylint: disable=invalid-name - with sa.ENGINE.begin() as connection: + with self.engine.begin() as connection: try: session = Session(connection.engine) @@ -830,7 +837,7 @@ def test_dblog_calculation_node(self): DbLog = self.get_auto_base().classes.db_dblog # pylint: disable=invalid-name - with sa.ENGINE.begin() as connection: + with self.engine.begin() as connection: try: session = Session(connection.engine) @@ -874,7 +881,7 @@ def test_metadata_correctness(self): DbLog = self.get_auto_base().classes.db_dblog # pylint: 
disable=invalid-name - with sa.ENGINE.begin() as connection: + with self.engine.begin() as connection: try: session = Session(connection.engine) metadata = list(session.query(DbLog).with_entities(getattr(DbLog, 'metadata')).all()) @@ -901,7 +908,7 @@ def setUpBeforeMigration(self): DbNode = self.get_auto_base().classes.db_dbnode # pylint: disable=invalid-name DbLog = self.get_auto_base().classes.db_dblog # pylint: disable=invalid-name - with sa.ENGINE.begin() as connection: + with self.engine.begin() as connection: try: session = Session(connection.engine) @@ -964,7 +971,7 @@ def test_objpk_objname(self): DbLog = self.get_auto_base().classes.db_dblog # pylint: disable=invalid-name - with sa.ENGINE.begin() as connection: + with self.engine.begin() as connection: try: session = Session(connection.engine) @@ -1010,7 +1017,7 @@ def setUpBeforeMigration(self): DbNode = self.get_auto_base().classes.db_dbnode # pylint: disable=invalid-name DbLog = self.get_auto_base().classes.db_dblog # pylint: disable=invalid-name - with sa.ENGINE.begin() as connection: + with self.engine.begin() as connection: try: session = Session(connection.engine) @@ -1046,7 +1053,7 @@ def test_dblog_unique_uuids(self): DbLog = self.get_auto_base().classes.db_dblog # pylint: disable=invalid-name - with sa.ENGINE.begin() as connection: + with self.engine.begin() as connection: try: session = Session(connection.engine) l_uuids = list(session.query(DbLog).with_entities(getattr(DbLog, 'uuid')).all()) @@ -1067,7 +1074,7 @@ def setUpBeforeMigration(self): DbNode = self.get_auto_base().classes.db_dbnode # pylint: disable=invalid-name DbUser = self.get_auto_base().classes.db_dbuser # pylint: disable=invalid-name - with sa.ENGINE.begin() as connection: + with self.engine.begin() as connection: try: session = Session(connection.engine) @@ -1096,7 +1103,7 @@ def test_data_node_type_string(self): DbNode = self.get_auto_base().classes.db_dbnode # pylint: disable=invalid-name - with sa.ENGINE.begin() as connection: + with self.engine.begin() as connection: try: session = Session(connection.engine) @@ -1133,7 +1140,7 @@ def setUpBeforeMigration(self): DbNode = self.get_auto_base().classes.db_dbnode # pylint: disable=invalid-name DbUser = self.get_auto_base().classes.db_dbuser # pylint: disable=invalid-name - with sa.ENGINE.begin() as connection: + with self.engine.begin() as connection: try: session = Session(connection.engine) @@ -1165,7 +1172,7 @@ def test_trajectory_symbols(self): DbNode = self.get_auto_base().classes.db_dbnode # pylint: disable=invalid-name - with sa.ENGINE.begin() as connection: + with self.engine.begin() as connection: try: session = Session(connection.engine) @@ -1193,7 +1200,7 @@ def setUpBeforeMigration(self): DbNode = self.get_auto_base().classes.db_dbnode # pylint: disable=invalid-name DbUser = self.get_auto_base().classes.db_dbuser # pylint: disable=invalid-name - with sa.ENGINE.begin() as connection: + with self.engine.begin() as connection: try: session = Session(connection.engine) @@ -1222,7 +1229,7 @@ def test_data_node_type_string(self): DbNode = self.get_auto_base().classes.db_dbnode # pylint: disable=invalid-name - with sa.ENGINE.begin() as connection: + with self.engine.begin() as connection: try: session = Session(connection.engine) @@ -1248,7 +1255,7 @@ def setUpBeforeMigration(self): DbNode = self.get_auto_base().classes.db_dbnode # pylint: disable=invalid-name DbUser = self.get_auto_base().classes.db_dbuser # pylint: disable=invalid-name - with sa.ENGINE.begin() as connection: + with 
self.engine.begin() as connection: try: session = Session(connection.engine) @@ -1274,7 +1281,7 @@ def test_type_string(self): DbNode = self.get_auto_base().classes.db_dbnode # pylint: disable=invalid-name - with sa.ENGINE.begin() as connection: + with self.engine.begin() as connection: try: session = Session(connection.engine) @@ -1301,7 +1308,7 @@ def setUpBeforeMigration(self): DbNode = self.get_auto_base().classes.db_dbnode # pylint: disable=invalid-name DbUser = self.get_auto_base().classes.db_dbuser # pylint: disable=invalid-name - with sa.ENGINE.begin() as connection: + with self.engine.begin() as connection: try: session = Session(connection.engine) @@ -1331,7 +1338,7 @@ def test_data_migrated(self): DbNode = self.get_auto_base().classes.db_dbnode # pylint: disable=invalid-name - with sa.ENGINE.begin() as connection: + with self.engine.begin() as connection: try: session = Session(connection.engine) for state, pk in self.nodes.items(): @@ -1367,7 +1374,7 @@ def setUpBeforeMigration(self): DbNode = self.get_auto_base().classes.db_dbnode # pylint: disable=invalid-name DbUser = self.get_auto_base().classes.db_dbuser # pylint: disable=invalid-name - with sa.ENGINE.begin() as connection: + with self.engine.begin() as connection: try: session = Session(connection.engine) @@ -1396,7 +1403,7 @@ def test_data_migrated(self): DbNode = self.get_auto_base().classes.db_dbnode # pylint: disable=invalid-name - with sa.ENGINE.begin() as connection: + with self.engine.begin() as connection: try: session = Session(connection.engine) node = session.query(DbNode).filter(DbNode.id == self.node_id).one() @@ -1419,7 +1426,7 @@ def setUpBeforeMigration(self): DbNode = self.get_auto_base().classes.db_dbnode # pylint: disable=invalid-name DbUser = self.get_auto_base().classes.db_dbuser # pylint: disable=invalid-name - with sa.ENGINE.begin() as connection: + with self.engine.begin() as connection: try: session = Session(connection.engine) @@ -1484,7 +1491,7 @@ def test_data_migrated(self): DbNode = self.get_auto_base().classes.db_dbnode # pylint: disable=invalid-name - with sa.ENGINE.begin() as connection: + with self.engine.begin() as connection: try: session = Session(connection.engine) deleted_keys = ['_sealed', '_finished', '_failed', '_aborted', '_do_abort'] @@ -1520,7 +1527,7 @@ def setUpBeforeMigration(self): DbNode = self.get_auto_base().classes.db_dbnode # pylint: disable=invalid-name DbUser = self.get_auto_base().classes.db_dbuser # pylint: disable=invalid-name - with sa.ENGINE.begin() as connection: + with self.engine.begin() as connection: try: session = Session(connection.engine) @@ -1582,7 +1589,7 @@ def test_data_migrated(self): DbNode = self.get_auto_base().classes.db_dbnode # pylint: disable=invalid-name - with sa.ENGINE.begin() as connection: + with self.engine.begin() as connection: try: session = Session(connection.engine) @@ -1615,7 +1622,7 @@ def setUpBeforeMigration(self): DbNode = self.get_auto_base().classes.db_dbnode # pylint: disable=invalid-name DbUser = self.get_auto_base().classes.db_dbuser # pylint: disable=invalid-name - with sa.ENGINE.begin() as connection: + with self.engine.begin() as connection: try: session = Session(connection.engine) @@ -1647,7 +1654,7 @@ def test_data_migrated(self): DbLink = self.get_auto_base().classes.db_dblink # pylint: disable=invalid-name - with sa.ENGINE.begin() as connection: + with self.engine.begin() as connection: try: session = Session(connection.engine) link = session.query(DbLink).filter(DbLink.id == self.link_id).one() @@ -1945,7 +1952,7 
@@ def setUpBeforeMigration(self): DbComputer = self.get_auto_base().classes.db_dbcomputer # pylint: disable=invalid-name - with sa.ENGINE.begin() as connection: + with self.engine.begin() as connection: try: session = Session(connection.engine) @@ -1967,7 +1974,7 @@ def test_migration(self): DbComputer = self.get_auto_base().classes.db_dbcomputer # pylint: disable=invalid-name - with sa.ENGINE.begin() as connection: + with self.engine.begin() as connection: try: session = Session(connection.engine) @@ -1991,7 +1998,7 @@ def setUpBeforeMigration(self): DbNode = self.get_auto_base().classes.db_dbnode # pylint: disable=invalid-name DbUser = self.get_auto_base().classes.db_dbuser # pylint: disable=invalid-name - with sa.ENGINE.begin() as connection: + with self.engine.begin() as connection: try: session = Session(connection.engine) @@ -2031,7 +2038,7 @@ def test_migration(self): DbComputer = self.get_auto_base().classes.db_dbcomputer # pylint: disable=invalid-name DbNode = self.get_auto_base().classes.db_dbnode # pylint: disable=invalid-name - with sa.ENGINE.begin() as connection: + with self.engine.begin() as connection: try: session = Session(connection.engine) diff --git a/tests/backends/aiida_sqlalchemy/test_nodes.py b/tests/backends/aiida_sqlalchemy/test_nodes.py index 349a96c289..3f58ed6562 100644 --- a/tests/backends/aiida_sqlalchemy/test_nodes.py +++ b/tests/backends/aiida_sqlalchemy/test_nodes.py @@ -20,19 +20,19 @@ class TestNodeBasicSQLA(AiidaTestCase): def test_settings(self): """Test the settings table (similar to Attributes, but without the key.""" - from aiida.backends.sqlalchemy import get_scoped_session + from aiida.backends.sqlalchemy.models.settings import DbSetting - session = get_scoped_session() + session = self.backend.get_session() # pylint: disable=no-member from pytz import UTC from sqlalchemy.exc import IntegrityError from aiida.common import timezone - DbSetting.set_value(key='pippo', value=[1, 2, 3]) + DbSetting.set_value(session, key='pippo', value=[1, 2, 3]) # s_1 = DbSetting.objects.get(key='pippo') - s_1 = DbSetting.query.filter_by(key='pippo').first() # pylint: disable=no-member + s_1 = session.query(DbSetting).filter_by(key='pippo').first() # pylint: disable=no-member self.assertEqual(s_1.getvalue(), [1, 2, 3]) @@ -44,14 +44,14 @@ def test_settings(self): session.add(s_2) # Should replace pippo - DbSetting.set_value(key='pippo', value='a') - s_1 = DbSetting.query.filter_by(key='pippo').first() # pylint: disable=no-member + DbSetting.set_value(session, key='pippo', value='a') + s_1 = session.query(DbSetting).filter_by(key='pippo').first() # pylint: disable=no-member self.assertEqual(s_1.getvalue(), 'a') def test_load_nodes(self): """Test for load_node() function.""" - from aiida.backends.sqlalchemy import get_scoped_session + from aiida.orm import load_node a_obj = Data() @@ -62,7 +62,7 @@ def test_load_nodes(self): self.assertEqual(a_obj.pk, load_node(pk=a_obj.pk).pk) self.assertEqual(a_obj.pk, load_node(uuid=a_obj.uuid).pk) - session = get_scoped_session() + session = self.backend.get_session() # pylint: disable=no-member try: session.begin_nested() @@ -105,19 +105,16 @@ def test_multiple_node_creation(self): (and subsequently committed) when a user is in the session. 
It tests the fix for the issue #234 """ - import aiida.backends.sqlalchemy from aiida.backends.sqlalchemy.models.node import DbNode from aiida.common.utils import get_new_uuid - backend = self.backend - # Get the automatic user - dbuser = backend.users.create(f'{self.id()}@aiida.net').store().dbmodel + dbuser = self.backend.users.create(f'{self.id()}@aiida.net').store().dbmodel # pylint: disable=no-member # Create a new node but don't add it to the session node_uuid = get_new_uuid() DbNode(user=dbuser, uuid=node_uuid, node_type=None) - session = aiida.backends.sqlalchemy.get_scoped_session() + session = self.backend.get_session() # pylint: disable=no-member # Query the session before commit res = session.query(DbNode.uuid).filter(DbNode.uuid == node_uuid).all() diff --git a/tests/backends/aiida_sqlalchemy/test_schema.py b/tests/backends/aiida_sqlalchemy/test_schema.py index 92f35e8a95..e6271f78c8 100644 --- a/tests/backends/aiida_sqlalchemy/test_schema.py +++ b/tests/backends/aiida_sqlalchemy/test_schema.py @@ -13,7 +13,6 @@ from sqlalchemy import exc as sa_exc -import aiida from aiida.backends.sqlalchemy.models.node import DbNode from aiida.backends.sqlalchemy.models.user import DbUser from aiida.backends.testbase import AiidaTestCase @@ -31,6 +30,11 @@ class TestRelationshipsSQLA(AiidaTestCase): 2)tests on many-to-many relationships: test__ (none is capitalized).""" + @property + def session(self): + """Return the backend session""" + return self.backend.get_session() # pylint: disable=no-member + def test_outputs_children_relationship(self): """This test checks that the outputs_q, children_q relationship and the corresponding properties work as expected.""" @@ -98,7 +102,7 @@ def test_user_node_1(self): self.assertIsNone(dbu1.id) self.assertIsNone(dbn_1.id) - session = aiida.backends.sqlalchemy.get_scoped_session() + session = self.session # Add only the node and commit session.add(dbn_1) session.commit() @@ -124,7 +128,7 @@ def test_user_node_2(self): self.assertIsNone(dbu1.id) self.assertIsNone(dbn_1.id) - session = aiida.backends.sqlalchemy.get_scoped_session() + session = self.session # Catch all the SQLAlchemy warnings generated by the following code with warnings.catch_warnings(): # pylint: disable=no-member @@ -159,7 +163,7 @@ def test_user_node_3(self): self.assertIsNone(dbn_1.id) self.assertIsNone(dbn_2.id) - session = aiida.backends.sqlalchemy.get_scoped_session() + session = self.session # Add only first node and commit session.add(dbn_1) @@ -198,7 +202,7 @@ def test_user_node_4(self): self.assertIsNone(dbu1.id) self.assertIsNone(dbn_1.id) - session = aiida.backends.sqlalchemy.get_scoped_session() + session = self.session # Add only first node and commit session.add(dbn_1) diff --git a/tests/backends/aiida_sqlalchemy/test_session.py b/tests/backends/aiida_sqlalchemy/test_session.py index d6c3e48da2..f5cd2b020a 100644 --- a/tests/backends/aiida_sqlalchemy/test_session.py +++ b/tests/backends/aiida_sqlalchemy/test_session.py @@ -39,23 +39,18 @@ def set_connection(self, expire_on_commit=True): # Cleaning the database self.clean_db() - aiida.backends.sqlalchemy.get_scoped_session().expunge_all() + self.backend.get_session().expunge_all() - @staticmethod - def drop_connection(): + def drop_connection(self): """Drop connection to a database.""" - session = aiida.backends.sqlalchemy.get_scoped_session() - session.expunge_all() - session.close() - aiida.backends.sqlalchemy.sessionfactory = None + self.backend.reset() def test_session_update_and_expiration_1(self): 
"""expire_on_commit=True & adding manually and committing computer and code objects.""" self.set_connection(expire_on_commit=True) - - session = aiida.backends.sqlalchemy.get_scoped_session() + session = self.backend.get_session() email = get_manager().get_profile().default_user user = self.backend.users.create(email=email) @@ -79,9 +74,8 @@ def test_session_update_and_expiration_2(self): """expire_on_commit=True & committing computer and code objects with their built-in store function.""" - session = aiida.backends.sqlalchemy.get_scoped_session() - self.set_connection(expire_on_commit=True) + session = self.backend.get_session() email = get_manager().get_profile().default_user user = self.backend.users.create(email=email) @@ -103,8 +97,7 @@ def test_session_update_and_expiration_3(self): computer and code objects. """ self.set_connection(expire_on_commit=False) - - session = aiida.backends.sqlalchemy.get_scoped_session() + session = self.backend.get_session() email = get_manager().get_profile().default_user user = self.backend.users.create(email=email) @@ -129,8 +122,7 @@ def test_session_update_and_expiration_4(self): their built-in store function.""" self.set_connection(expire_on_commit=False) - - session = aiida.backends.sqlalchemy.get_scoped_session() + session = self.backend.get_session() email = get_manager().get_profile().default_user user = self.backend.users.create(email=email) @@ -155,12 +147,12 @@ def test_node_access_with_sessions(self): import aiida.backends.sqlalchemy as sa from aiida.common import timezone - session = sessionmaker(bind=sa.ENGINE) + session = sessionmaker(bind=self.backend.get_session().bind) custom_session = session() user = self.backend.users.create(email='test@localhost').store() node = self.backend.nodes.create(node_type='', user=user).store() - master_session = node.dbmodel.session + master_session = self.backend.get_session() self.assertIsNot(master_session, custom_session) # Manually load the DbNode in a different session diff --git a/tests/orm/implementation/test_backend.py b/tests/orm/implementation/test_backend.py index b7b89e7fe5..c270fa538e 100644 --- a/tests/orm/implementation/test_backend.py +++ b/tests/orm/implementation/test_backend.py @@ -16,13 +16,17 @@ class TestBackend(AiidaTestCase): """Test backend.""" + @property + def transaction(self): + return self.backend.transaction() # pylint: disable=no-member + def test_transaction_nesting(self): """Test that transaction nesting works.""" user = orm.User('initial@email.com').store() - with self.backend.transaction(): + with self.transaction: user.email = 'pre-failure@email.com' try: - with self.backend.transaction(): + with self.transaction: user.email = 'failure@email.com' self.assertEqual(user.email, 'failure@email.com') raise RuntimeError @@ -37,7 +41,7 @@ def test_transaction(self): user2 = orm.User('user2@email.com').store() try: - with self.backend.transaction(): + with self.transaction: user1.email = 'broken1@email.com' user2.email = 'broken2@email.com' raise RuntimeError @@ -49,14 +53,21 @@ def test_transaction(self): def test_store_in_transaction(self): """Test that storing inside a transaction is correctly dealt with.""" user1 = orm.User('user_store@email.com') - with self.backend.transaction(): + assert user1.is_stored is False + + with self.transaction: + assert user1.is_stored is False user1.store() + assert user1.is_stored is True + + assert user1.is_stored is True + # the following shouldn't raise orm.User.objects.get(email='user_store@email.com') user2 = 
orm.User('user_store_fail@email.com') try: - with self.backend.transaction(): + with self.transaction: user2.store() raise RuntimeError except RuntimeError: diff --git a/tests/orm/implementation/test_comments.py b/tests/orm/implementation/test_comments.py index b19bd483db..cbab826718 100644 --- a/tests/orm/implementation/test_comments.py +++ b/tests/orm/implementation/test_comments.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=no-member """Unit tests for the BackendComment and BackendCommentCollection classes.""" from datetime import datetime @@ -20,14 +21,11 @@ class TestBackendComment(AiidaTestCase): """Test BackendComment.""" - @classmethod - def setUpClass(cls, *args, **kwargs): - super().setUpClass(*args, **kwargs) - cls.computer = cls.computer.backend_entity # Unwrap the `Computer` instance to `BackendComputer` - cls.user = cls.backend.users.create(email='tester@localhost').store() - def setUp(self): super().setUp() + # get backend instances of the entities + self.computer = self.computer.backend_entity + self.user = self.user.backend_entity self.node = self.backend.nodes.create( node_type='', user=self.user, computer=self.computer, label='label', description='description' ).store() @@ -334,7 +332,7 @@ def test_deleting_non_existent_entities(self): # NotExistent should be raised, since no entities are found with self.assertRaises(exceptions.NotExistent) as exc: self.backend.comments.delete(comment_id=id_) - self.assertIn(f"Comment with id '{id_}' not found", str(exc.exception)) + self.assertIn(f'Comment<{id_}> does not exist', str(exc.exception)) # Try to delete existing and non-existing Comment - using delete_many # delete_many should return a list that *only* includes the existing Comment diff --git a/tests/orm/implementation/test_logs.py b/tests/orm/implementation/test_logs.py index 0d9f928806..d25923c06f 100644 --- a/tests/orm/implementation/test_logs.py +++ b/tests/orm/implementation/test_logs.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=no-member """Unit tests for the BackendLog and BackendLogCollection classes.""" from datetime import datetime @@ -22,14 +23,11 @@ class TestBackendLog(AiidaTestCase): """Test BackendLog.""" - @classmethod - def setUpClass(cls, *args, **kwargs): - super().setUpClass(*args, **kwargs) - cls.computer = cls.computer.backend_entity # Unwrap the `Computer` instance to `BackendComputer` - cls.user = cls.backend.users.create(email='tester@localhost').store() - def setUp(self): super().setUp() + # get backend instances of the entities + self.computer = self.computer.backend_entity + self.user = self.user.backend_entity self.node = self.backend.nodes.create( node_type='', user=self.user, computer=self.computer, label='label', description='description' ).store() @@ -285,7 +283,7 @@ def test_deleting_non_existent_entities(self): # NotExistent should be raised, since no entities are found with self.assertRaises(exceptions.NotExistent) as exc: self.backend.logs.delete(log_id=id_) - self.assertIn(f"Log with id '{id_}' not found", str(exc.exception)) + self.assertIn(f'Log<{id_}> does not exist', str(exc.exception)) # Try to delete existing and non-existing Log - using 
delete_many
         # delete_many should return a list that *only* includes the existing Logs
diff --git a/tests/orm/implementation/test_nodes.py b/tests/orm/implementation/test_nodes.py
index e4d017a1f3..563ded560e 100644
--- a/tests/orm/implementation/test_nodes.py
+++ b/tests/orm/implementation/test_nodes.py
@@ -7,7 +7,7 @@
 # For further information on the license, see the LICENSE.txt file        #
 # For further information please visit http://www.aiida.net               #
 ###########################################################################
-# pylint: disable=too-many-public-methods
+# pylint: disable=no-member,too-many-public-methods
 """Unit tests for the BackendNode and BackendNodeCollection classes."""
 from collections import OrderedDict
@@ -21,14 +21,11 @@ class TestBackendNode(AiidaTestCase):
     """Test BackendNode."""

-    @classmethod
-    def setUpClass(cls, *args, **kwargs):
-        super().setUpClass(*args, **kwargs)
-        cls.computer = cls.computer.backend_entity  # Unwrap the `Computer` instance to `BackendComputer`
-        cls.user = cls.backend.users.create(email='tester@localhost').store()
-
     def setUp(self):
         super().setUp()
+        # unwrap the default user/computer instances to their backend entities
+        self.user = self.user.backend_entity
+        self.computer = self.computer.backend_entity
         self.node_type = ''
         self.node_label = 'label'
         self.node_description = 'description'
diff --git a/tests/orm/test_authinfos.py b/tests/orm/test_authinfos.py
index 07051662d4..c170e14d6f 100644
--- a/tests/orm/test_authinfos.py
+++ b/tests/orm/test_authinfos.py
@@ -21,7 +21,6 @@ def setUp(self):
         super().setUp()
         for auth_info in authinfos.AuthInfo.objects.all():
             authinfos.AuthInfo.objects.delete(auth_info.pk)
-
         self.auth_info = self.computer.configure()  # pylint: disable=no-member

     def test_set_auth_params(self):
diff --git a/tests/orm/test_groups.py b/tests/orm/test_groups.py
index 3e52f4d0f7..942a84935a 100644
--- a/tests/orm/test_groups.py
+++ b/tests/orm/test_groups.py
@@ -566,7 +566,9 @@ def test_clear_extras(self):
         self.assertEqual(self.group.extras, {})

         # Repeat for stored group
+        self.group.set_extra_many(extras)
         self.group.store()
+        self.assertEqual(self.group.extras, extras)
         self.group.clear_extras()
         self.assertEqual(orm.load_group(self.group.pk).extras, {})
diff --git a/tests/restapi/conftest.py b/tests/restapi/conftest.py
index de991b2abe..3222d7e0ba 100644
--- a/tests/restapi/conftest.py
+++ b/tests/restapi/conftest.py
@@ -52,10 +52,9 @@ def restrict_sqlalchemy_queuepool(aiida_profile):  # pylint: disable=unused-argu
     """Create special SQLAlchemy engine for use with QueryBuilder - backend-agnostic"""
     from aiida.manage.manager import get_manager

-    manager = get_manager()
     backend = get_manager().get_backend()
-    backend.reset_environment()
-    backend.load_environment(manager.get_profile(), pool_timeout=1, max_overflow=0)
+    backend.close()
+    backend.get_session(pool_timeout=1, max_overflow=0)


 @pytest.fixture

From 96c6c905e4a79d500ea9d46145d3c43c4578c544 Mon Sep 17 00:00:00 2001
From: Chris Sewell
Date: Thu, 14 Oct 2021 09:00:53 +0200
Subject: [PATCH 05/16] revert django is_stored change

---
 aiida/orm/implementation/django/computers.py | 2 +-
 aiida/orm/implementation/django/entities.py  | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/aiida/orm/implementation/django/computers.py b/aiida/orm/implementation/django/computers.py
index 874b4c3b3c..f345a68e42 100644
--- a/aiida/orm/implementation/django/computers.py
+++ b/aiida/orm/implementation/django/computers.py
@@ -63,7 +63,7 @@ def store(self):

     @property
     def is_stored(self):
-        return self._dbmodel.is_saved()
+        return self._dbmodel.id is not None

     @property
     def label(self):
diff --git a/aiida/orm/implementation/django/entities.py b/aiida/orm/implementation/django/entities.py
index e7cd437411..87d182f664 100644
--- a/aiida/orm/implementation/django/entities.py
+++ b/aiida/orm/implementation/django/entities.py
@@ -84,7 +84,7 @@ def is_stored(self):

         :return: True if stored, False otherwise
         """
-        return self._dbmodel.is_saved()
+        return self._dbmodel.id is not None

     def store(self):
         """

From f95d118eda514df010c808df0aa4774ebc45ebe6 Mon Sep 17 00:00:00 2001
From: Chris Sewell
Date: Thu, 14 Oct 2021 16:57:55 +0200
Subject: [PATCH 06/16] Fix cmdline tests

---
 aiida/common/__init__.py                      |   1 +
 aiida/common/exceptions.py                    |   7 +-
 aiida/orm/implementation/sqlalchemy/utils.py  |  12 +-
 .../aiida_sqlalchemy/test_migrations.py       |  14 +-
 tests/cmdline/commands/test_calcjob.py        | 180 +++++----
 tests/cmdline/commands/test_code.py           |  20 +-
 tests/cmdline/commands/test_computer.py       |  36 +-
 tests/cmdline/commands/test_data.py           |  11 -
 tests/cmdline/commands/test_node.py           | 349 ++++++++----------
 .../cmdline/params/types/test_calculation.py  |  47 ++-
 tests/cmdline/params/types/test_data.py       |  39 +-
 tests/cmdline/params/types/test_node.py       |  40 +-
 tests/conftest.py                             |   6 +-
 13 files changed, 363 insertions(+), 399 deletions(-)

diff --git a/aiida/common/__init__.py b/aiida/common/__init__.py
index 5a4963a697..946e197eca 100644
--- a/aiida/common/__init__.py
+++ b/aiida/common/__init__.py
@@ -30,6 +30,7 @@
     'AIIDA_LOGGER',
     'AiidaException',
     'AttributeDict',
+    'BackendClosedError',
     'CalcInfo',
     'CalcJobState',
     'CodeInfo',
diff --git a/aiida/common/exceptions.py b/aiida/common/exceptions.py
index fbae709d02..7bc0933209 100644
--- a/aiida/common/exceptions.py
+++ b/aiida/common/exceptions.py
@@ -17,7 +17,8 @@
     'PluginInternalError', 'ValidationError', 'ConfigurationError', 'ProfileConfigurationError',
     'MissingConfigurationError', 'ConfigurationVersionError', 'IncompatibleDatabaseSchema', 'DbContentError',
     'InputValidationError', 'FeatureNotAvailable', 'FeatureDisabled', 'LicensingException', 'TestsNotAllowedError',
-    'UnsupportedSpeciesError', 'TransportTaskException', 'OutputParsingError', 'HashingError', 'DatabaseMigrationError'
+    'UnsupportedSpeciesError', 'TransportTaskException', 'OutputParsingError', 'HashingError', 'DatabaseMigrationError',
+    'BackendClosedError'
 )
@@ -260,3 +261,7 @@ class HashingError(AiidaException):
     """
     Raised when an attempt to hash an object fails via a known failure mode
     """
+
+
+class BackendClosedError(AiidaException):
+    """Raised when trying to manipulate an entity whose backend is closed"""
diff --git a/aiida/orm/implementation/sqlalchemy/utils.py b/aiida/orm/implementation/sqlalchemy/utils.py
index 52baa2affc..106ca84231 100644
--- a/aiida/orm/implementation/sqlalchemy/utils.py
+++ b/aiida/orm/implementation/sqlalchemy/utils.py
@@ -15,6 +15,7 @@
 from sqlalchemy.exc import IntegrityError, InvalidRequestError
 from sqlalchemy.orm import Session
 from sqlalchemy.orm.attributes import flag_modified
+from sqlalchemy.orm.exc import DetachedInstanceError

 from aiida.common import exceptions

@@ -74,7 +75,11 @@ def __getattr__(self, item):
         if self.is_saved() and self._is_mutable_model_field(item) and not self._in_transaction():
             self._ensure_model_uptodate(fields=(item,))

-        return getattr(self._model, item)
+        try:
+            _attr = getattr(self._model, item)
+        except DetachedInstanceError as exc:
+            raise exceptions.BackendClosedError(f'The backend for this instance has been closed: {exc}') from exc
+        return _attr

     def __setattr__(self, key, value):
         """Set the attribute on the model instance.
@@ -84,7 +89,10 @@ def __setattr__(self, key, value):
         :param key: the name of the model field
         :param value: the value to set
         """
-        setattr(self._model, key, value)
+        try:
+            setattr(self._model, key, value)
+        except DetachedInstanceError as exc:
+            raise exceptions.BackendClosedError(f'The backend for this instance has been closed: {exc}') from exc
         if self.is_saved() and self._is_mutable_model_field(key):
             fields = set((key,) + self._auto_flush)
             self._flush(fields=fields)
diff --git a/tests/backends/aiida_sqlalchemy/test_migrations.py b/tests/backends/aiida_sqlalchemy/test_migrations.py
index c67739f5f1..5a9d4b797f 100644
--- a/tests/backends/aiida_sqlalchemy/test_migrations.py
+++ b/tests/backends/aiida_sqlalchemy/test_migrations.py
@@ -84,8 +84,11 @@ def migrate_db_up(destination):
         # create a new backend which does not validate the schema version
         backend = SqlaBackend(get_profile(), validate_db=False)
-        with backend._backend_manager.alembic_config() as config:
-            command.upgrade(config, destination)
+        try:
+            with backend._backend_manager.alembic_config() as config:
+                command.upgrade(config, destination)
+        finally:
+            backend.close()

     @staticmethod
     def migrate_db_down(destination):
@@ -98,8 +101,11 @@ def migrate_db_down(destination):
         # create a new backend which does not validate the schema version
         backend = SqlaBackend(get_profile(), validate_db=False)
-        with backend._backend_manager.alembic_config() as config:
-            command.downgrade(config, destination)
+        try:
+            with backend._backend_manager.alembic_config() as config:
+                command.downgrade(config, destination)
+        finally:
+            backend.close()

     def tearDown(self):
         """
diff --git a/tests/cmdline/commands/test_calcjob.py b/tests/cmdline/commands/test_calcjob.py
index 5b3bead7af..aa02e2b9e3 100644
--- a/tests/cmdline/commands/test_calcjob.py
+++ b/tests/cmdline/commands/test_calcjob.py
@@ -12,9 +12,9 @@
 import io

 from click.testing import CliRunner
+import pytest

 from aiida import orm
-from aiida.backends.testbase import AiidaTestCase
 from aiida.cmdline.commands import cmd_calcjob as command
 from aiida.common.datastructures import CalcJobState
 from aiida.plugins import CalculationFactory
@@ -26,16 +26,17 @@ def get_result_lines(result):
     return [e for e in result.output.split('\n') if e]


-class TestVerdiCalculation(AiidaTestCase):
+class TestVerdiCalculation:
     """Tests for `verdi calcjob`."""

-    @classmethod
-    def setUpClass(cls, *args, **kwargs):
-        super().setUpClass(*args, **kwargs)
+    @pytest.fixture(autouse=True)
+    def init_db(self, clear_database_before_test):  # pylint: disable=unused-argument
+        """Setup database before each test"""
+        # pylint: disable=attribute-defined-outside-init
         from aiida.common.links import LinkType
         from aiida.engine import ProcessState

-        cls.computer = orm.Computer(
+        self.computer = orm.Computer(
             label='comp',
             hostname='localhost',
             transport_type='core.local',
@@ -43,13 +44,13 @@ def setUpClass(cls, *args, **kwargs):
             workdir='/tmp/aiida'
         ).store()

-        cls.code = orm.Code(remote_computer_exec=(cls.computer, '/bin/true')).store()
-        cls.group = orm.Group(label='test_group').store()
-        cls.node = orm.Data().store()
-        cls.calcs = []
+        self.code = orm.Code(remote_computer_exec=(self.computer, '/bin/true')).store()
+        self.group = orm.Group(label='test_group').store()
+        self.node = orm.Data().store()
+        self.calcs = []

         user = orm.User.objects.get_default()
authinfo = orm.AuthInfo(computer=self.computer, user=user) authinfo.store() process_class = CalculationFactory('core.templatereplacer') @@ -58,40 +59,40 @@ def setUpClass(cls, *args, **kwargs): # Create 5 CalcJobNodes (one for each CalculationState) for calculation_state in CalcJobState: - calc = orm.CalcJobNode(computer=cls.computer, process_type=process_type) + calc = orm.CalcJobNode(computer=self.computer, process_type=process_type) calc.set_option('resources', {'num_machines': 1, 'num_mpiprocs_per_machine': 1}) calc.store() calc.set_process_state(ProcessState.RUNNING) - cls.calcs.append(calc) + self.calcs.append(calc) if calculation_state == CalcJobState.PARSING: - cls.KEY_ONE = 'key_one' - cls.KEY_TWO = 'key_two' - cls.VAL_ONE = 'val_one' - cls.VAL_TWO = 'val_two' + self.KEY_ONE = 'key_one' + self.KEY_TWO = 'key_two' + self.VAL_ONE = 'val_one' + self.VAL_TWO = 'val_two' output_parameters = orm.Dict(dict={ - cls.KEY_ONE: cls.VAL_ONE, - cls.KEY_TWO: cls.VAL_TWO, + self.KEY_ONE: self.VAL_ONE, + self.KEY_TWO: self.VAL_TWO, }).store() output_parameters.add_incoming(calc, LinkType.CREATE, 'output_parameters') # Create shortcut for easy dereferencing - cls.result_job = calc + self.result_job = calc # Add a single calc to a group - cls.group.add_nodes([calc]) + self.group.add_nodes([calc]) # Create a single failed CalcJobNode - cls.EXIT_STATUS = 100 - calc = orm.CalcJobNode(computer=cls.computer) + self.EXIT_STATUS = 100 + calc = orm.CalcJobNode(computer=self.computer) calc.set_option('resources', {'num_machines': 1, 'num_mpiprocs_per_machine': 1}) calc.store() - calc.set_exit_status(cls.EXIT_STATUS) + calc.set_exit_status(self.EXIT_STATUS) calc.set_process_state(ProcessState.FINISHED) - cls.calcs.append(calc) + self.calcs.append(calc) # Load the fixture containing a single ArithmeticAddCalculation node import_archive('calcjob/arithmetic.add.aiida') @@ -99,101 +100,98 @@ def setUpClass(cls, *args, **kwargs): # Get the imported ArithmeticAddCalculation node ArithmeticAddCalculation = CalculationFactory('core.arithmetic.add') calculations = orm.QueryBuilder().append(ArithmeticAddCalculation).all()[0] - cls.arithmetic_job = calculations[0] - print(cls.arithmetic_job.repository_metadata) + self.arithmetic_job = calculations[0] - def setUp(self): - super().setUp() self.cli_runner = CliRunner() def test_calcjob_res(self): """Test verdi calcjob res""" options = [str(self.result_job.uuid)] result = self.cli_runner.invoke(command.calcjob_res, options) - self.assertIsNone(result.exception, result.output) - self.assertIn(self.KEY_ONE, result.output) - self.assertIn(self.VAL_ONE, result.output) - self.assertIn(self.KEY_TWO, result.output) - self.assertIn(self.VAL_TWO, result.output) + assert result.exception is None, result.output + assert self.KEY_ONE in result.output + assert self.VAL_ONE in result.output + assert self.KEY_TWO in result.output + assert self.VAL_TWO in result.output for flag in ['-k', '--keys']: options = [flag, self.KEY_ONE, '--', str(self.result_job.uuid)] result = self.cli_runner.invoke(command.calcjob_res, options) - self.assertIsNone(result.exception, result.output) - self.assertIn(self.KEY_ONE, result.output) - self.assertIn(self.VAL_ONE, result.output) - self.assertNotIn(self.KEY_TWO, result.output) - self.assertNotIn(self.VAL_TWO, result.output) + assert result.exception is None, result.output + assert self.KEY_ONE in result.output + assert self.VAL_ONE in result.output + assert self.KEY_TWO not in result.output + assert self.VAL_TWO not in result.output def 
test_calcjob_inputls(self): """Test verdi calcjob inputls""" options = [] result = self.cli_runner.invoke(command.calcjob_inputls, options) - self.assertIsNotNone(result.exception) + assert result.exception is not None options = [self.arithmetic_job.uuid] result = self.cli_runner.invoke(command.calcjob_inputls, options) - self.assertIsNone(result.exception, result.output) + assert result.exception is None, result.output # There is also an additional fourth file added by hand to test retrieval of binary content # see comments in test_calcjob_inputcat - self.assertEqual(len(get_result_lines(result)), 4) - self.assertIn('.aiida', get_result_lines(result)) - self.assertIn('aiida.in', get_result_lines(result)) - self.assertIn('_aiidasubmit.sh', get_result_lines(result)) - self.assertIn('in_gzipped_data', get_result_lines(result)) + assert len(get_result_lines(result)) == 4 + assert '.aiida' in get_result_lines(result) + assert 'aiida.in' in get_result_lines(result) + assert '_aiidasubmit.sh' in get_result_lines(result) + assert 'in_gzipped_data' in get_result_lines(result) options = [self.arithmetic_job.uuid, '.aiida'] result = self.cli_runner.invoke(command.calcjob_inputls, options) - self.assertIsNone(result.exception, result.output) - self.assertEqual(len(get_result_lines(result)), 2) - self.assertIn('calcinfo.json', get_result_lines(result)) - self.assertIn('job_tmpl.json', get_result_lines(result)) + assert result.exception is None, result.output + assert len(get_result_lines(result)) == 2 + assert 'calcinfo.json' in get_result_lines(result) + assert 'job_tmpl.json' in get_result_lines(result) options = [self.arithmetic_job.uuid, 'non-existing-folder'] result = self.cli_runner.invoke(command.calcjob_inputls, options) - self.assertIsNotNone(result.exception) - self.assertIn('does not exist for the given node', result.output) + assert result.exception is not None + assert 'does not exist for the given node' in result.output def test_calcjob_outputls(self): """Test verdi calcjob outputls""" options = [] result = self.cli_runner.invoke(command.calcjob_outputls, options) - self.assertIsNotNone(result.exception, msg=result.output) + assert result.exception is not None, result.output options = [self.arithmetic_job.uuid] result = self.cli_runner.invoke(command.calcjob_outputls, options) - self.assertIsNone(result.exception, result.output) + assert result.exception is None, result.output # There is also an additional fourth file added by hand to test retrieval of binary content # see comments in test_calcjob_outputcat - self.assertEqual(len(get_result_lines(result)), 4) - self.assertIn('_scheduler-stderr.txt', get_result_lines(result)) - self.assertIn('_scheduler-stdout.txt', get_result_lines(result)) - self.assertIn('aiida.out', get_result_lines(result)) - self.assertIn('gzipped_data', get_result_lines(result)) + assert len(get_result_lines(result)) == 4 + assert '_scheduler-stderr.txt' in get_result_lines(result) + assert '_scheduler-stdout.txt' in get_result_lines(result) + assert 'aiida.out' in get_result_lines(result) + assert 'gzipped_data' in get_result_lines(result) options = [self.arithmetic_job.uuid, 'non-existing-folder'] result = self.cli_runner.invoke(command.calcjob_inputls, options) - self.assertIsNotNone(result.exception) - self.assertIn('does not exist for the given node', result.output) + assert result.exception is not None + assert 'does not exist for the given node' in result.output def test_calcjob_inputcat(self): """Test verdi calcjob inputcat""" options = [] result = 
self.cli_runner.invoke(command.calcjob_inputcat, options) - self.assertIsNotNone(result.exception, msg=result.output) + assert result.exception is not None, result.output options = [self.arithmetic_job.uuid] result = self.cli_runner.invoke(command.calcjob_inputcat, options) - self.assertIsNone(result.exception, msg=result.output) - self.assertEqual(len(get_result_lines(result)), 1) - self.assertEqual(get_result_lines(result)[0], '2 3') + assert result.exception is None, result.output + assert len(get_result_lines(result)) == 1 + assert get_result_lines(result)[0] == '2 3' options = [self.arithmetic_job.uuid, 'aiida.in'] result = self.cli_runner.invoke(command.calcjob_inputcat, options) - self.assertIsNone(result.exception, result.output) - self.assertEqual(len(get_result_lines(result)), 1) - self.assertEqual(get_result_lines(result)[0], '2 3') + assert result.exception is None, result.output + assert len(get_result_lines(result)) == 1 + assert get_result_lines(result)[0] == '2 3' # Test cat binary files self.arithmetic_job._repository.put_object_from_filelike(io.BytesIO(b'COMPRESS'), 'aiida.in') @@ -212,19 +210,19 @@ def test_calcjob_outputcat(self): options = [] result = self.cli_runner.invoke(command.calcjob_outputcat, options) - self.assertIsNotNone(result.exception) + assert result.exception is not None options = [self.arithmetic_job.uuid] result = self.cli_runner.invoke(command.calcjob_outputcat, options) - self.assertIsNone(result.exception, result.output) - self.assertEqual(len(get_result_lines(result)), 1) - self.assertEqual(get_result_lines(result)[0], '5') + assert result.exception is None, result.output + assert len(get_result_lines(result)) == 1 + assert get_result_lines(result)[0] == '5' options = [self.arithmetic_job.uuid, 'aiida.out'] result = self.cli_runner.invoke(command.calcjob_outputcat, options) - self.assertIsNone(result.exception, result.output) - self.assertEqual(len(get_result_lines(result)), 1) - self.assertEqual(get_result_lines(result)[0], '5') + assert result.exception is None, result.output + assert len(get_result_lines(result)) == 1 + assert get_result_lines(result)[0] == '5' # Test cat binary files retrieved = self.arithmetic_job.outputs.retrieved @@ -245,24 +243,24 @@ def test_calcjob_cleanworkdir(self): # Specifying no filtering options and no explicit calcjobs should exit with non-zero status options = [] result = self.cli_runner.invoke(command.calcjob_cleanworkdir, options) - self.assertIsNotNone(result.exception) + assert result.exception is not None # Cannot specify both -p and -o options for flag_p in ['-p', '--past-days']: for flag_o in ['-o', '--older-than']: options = [flag_p, '5', flag_o, '1'] result = self.cli_runner.invoke(command.calcjob_cleanworkdir, options) - self.assertIsNotNone(result.exception) + assert result.exception is not None # Without the force flag it should fail options = [str(self.result_job.uuid)] result = self.cli_runner.invoke(command.calcjob_cleanworkdir, options) - self.assertIsNotNone(result.exception) + assert result.exception is not None # With force flag we should find one calcjob options = ['-f', str(self.result_job.uuid)] result = self.cli_runner.invoke(command.calcjob_cleanworkdir, options) - self.assertIsNone(result.exception, result.output) + assert result.exception is None, result.output def test_calcjob_inoutputcat_old(self): """Test most recent process class / plug-in can be successfully used to find filenames""" @@ -271,31 +269,29 @@ def test_calcjob_inoutputcat_old(self): 
import_archive('calcjob/arithmetic.add_old.aiida') ArithmeticAddCalculation = CalculationFactory('core.arithmetic.add') calculations = orm.QueryBuilder().append(ArithmeticAddCalculation).all() + add_job = None for job in calculations: if job[0].uuid == self.arithmetic_job.uuid: continue add_job = job[0] - return + + assert add_job is not None # Make sure add_job does not specify options 'input_filename' and 'output_filename' - self.assertIsNone( - add_job.get_option('input_filename'), msg=f"'input_filename' should not be an option for {add_job}" - ) - self.assertIsNone( - add_job.get_option('output_filename'), msg=f"'output_filename' should not be an option for {add_job}" - ) + assert add_job.get_option('input_filename') is None, f"'input_filename' should not be an option for {add_job}" + assert add_job.get_option('output_filename') is None, f"'output_filename' should not be an option for {add_job}" # Run `verdi calcjob inputcat add_job` options = [add_job.uuid] result = self.cli_runner.invoke(command.calcjob_inputcat, options) - self.assertIsNone(result.exception, result.output) - self.assertEqual(len(get_result_lines(result)), 1) - self.assertEqual(get_result_lines(result)[0], '2 3') + assert result.exception is None, result.output + assert len(get_result_lines(result)) == 1 + assert get_result_lines(result)[0] == '2 3' # Run `verdi calcjob outputcat add_job` options = [add_job.uuid] result = self.cli_runner.invoke(command.calcjob_outputcat, options) - self.assertIsNone(result.exception, result.output) - self.assertEqual(len(get_result_lines(result)), 1) - self.assertEqual(get_result_lines(result)[0], '5') + assert result.exception is None, result.output + assert len(get_result_lines(result)) == 1 + assert get_result_lines(result)[0] == '5' diff --git a/tests/cmdline/commands/test_code.py b/tests/cmdline/commands/test_code.py index 63e19e88f2..3c6ea0aa9e 100644 --- a/tests/cmdline/commands/test_code.py +++ b/tests/cmdline/commands/test_code.py @@ -26,18 +26,14 @@ class TestVerdiCodeSetup(AiidaTestCase): """Tests for the 'verdi code setup' command.""" - @classmethod - def setUpClass(cls, *args, **kwargs): - super().setUpClass(*args, **kwargs) - cls.computer = orm.Computer( + def setUp(self): + self.computer = orm.Computer.objects.get_or_create( label='comp', hostname='localhost', transport_type='core.local', scheduler_type='core.direct', workdir='/tmp/aiida' - ).store() - - def setUp(self): + ) self.cli_runner = CliRunner() self.this_folder = os.path.dirname(__file__) self.this_file = os.path.basename(__file__) @@ -118,18 +114,14 @@ class TestVerdiCodeCommands(AiidaTestCase): Testing everything besides `code setup`.""" - @classmethod - def setUpClass(cls, *args, **kwargs): - super().setUpClass(*args, **kwargs) - cls.computer = orm.Computer( + def setUp(self): + self.computer = orm.Computer.objects.get_or_create( label='comp', hostname='localhost', transport_type='core.local', scheduler_type='core.direct', workdir='/tmp/aiida' - ).store() - - def setUp(self): + ) try: code = orm.Code.get_from_string('code') except NotExistent: diff --git a/tests/cmdline/commands/test_computer.py b/tests/cmdline/commands/test_computer.py index e036360019..f08824b90f 100644 --- a/tests/cmdline/commands/test_computer.py +++ b/tests/cmdline/commands/test_computer.py @@ -529,33 +529,27 @@ class TestVerdiComputerCommands(AiidaTestCase): Testing everything besides `computer setup`. 
""" - @classmethod - def setUpClass(cls, *args, **kwargs): - """Create a new computer> I create a new one because I want to configure it and I don't want to - interfere with other tests""" - super().setUpClass(*args, **kwargs) - cls.computer_name = 'comp_cli_test_computer' - cls.comp = orm.Computer( - label=cls.computer_name, - hostname='localhost', - transport_type='core.local', - scheduler_type='core.direct', - workdir='/tmp/aiida' - ) - cls.comp.set_default_mpiprocs_per_machine(1) - cls.comp.set_prepend_text('text to prepend') - cls.comp.set_append_text('text to append') - cls.comp.store() - def setUp(self): """ Prepare the computer and user """ self.user = orm.User.objects.get_default() - # I need to configure the computer here; being 'core.local', - # there should not be any options asked here - self.comp.configure() + self.computer_name = 'comp_cli_test_computer' + + created, self.comp = orm.Computer.objects.get_or_create( + label=self.computer_name, + hostname='localhost', + transport_type='core.local', + scheduler_type='core.direct', + workdir='/tmp/aiida' + ) + if created: + self.comp.set_default_mpiprocs_per_machine(1) + self.comp.set_prepend_text('text to prepend') + self.comp.set_append_text('text to append') + self.comp.store() + self.comp.configure() assert self.comp.is_user_configured(self.user), 'There was a problem configuring the test computer' self.cli_runner = CliRunner() diff --git a/tests/cmdline/commands/test_data.py b/tests/cmdline/commands/test_data.py index f64edd5baf..94338c91e3 100644 --- a/tests/cmdline/commands/test_data.py +++ b/tests/cmdline/commands/test_data.py @@ -193,13 +193,6 @@ def data_listing_test(self, datatype, search_string, ids): class TestVerdiData(AiidaTestCase): """Testing reachability of the verdi data subcommands.""" - @classmethod - def setUpClass(cls): # pylint: disable=arguments-differ - super().setUpClass() - - def setUp(self): - pass - def test_reachable(self): """Testing reachability of the following commands: verdi data array @@ -222,10 +215,6 @@ def test_reachable(self): class TestVerdiDataArray(AiidaTestCase): """Testing verdi data array.""" - @classmethod - def setUpClass(cls): # pylint: disable=arguments-differ - super().setUpClass() - def setUp(self): self.arr = ArrayData() self.arr.set_array('test_array', np.array([0, 1, 3])) diff --git a/tests/cmdline/commands/test_node.py b/tests/cmdline/commands/test_node.py index 6b0d455c2a..911d3d35cc 100644 --- a/tests/cmdline/commands/test_node.py +++ b/tests/cmdline/commands/test_node.py @@ -7,6 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=attribute-defined-outside-init,invalid-name,unused-argument """Tests for verdi node""" import errno import gzip @@ -15,11 +16,9 @@ import pathlib import tempfile -from click.testing import CliRunner import pytest from aiida import orm -from aiida.backends.testbase import AiidaTestCase from aiida.cmdline.commands import cmd_node @@ -27,35 +26,33 @@ def get_result_lines(result): return [e for e in result.output.split('\n') if e] -class TestVerdiNode(AiidaTestCase): +class TestVerdiNode: """Tests for `verdi node`.""" - @classmethod - def setUpClass(cls, *args, **kwargs): - super().setUpClass(*args, **kwargs) - + @pytest.fixture(autouse=True) + def init_db(self, clear_database_before_test, run_cli_command): + """Setup the database before each test""" node = orm.Data() - 
cls.ATTR_KEY_ONE = 'a' - cls.ATTR_VAL_ONE = '1' - cls.ATTR_KEY_TWO = 'b' - cls.ATTR_VAL_TWO = 'test' + self.ATTR_KEY_ONE = 'a' + self.ATTR_VAL_ONE = '1' + self.ATTR_KEY_TWO = 'b' + self.ATTR_VAL_TWO = 'test' - node.set_attribute_many({cls.ATTR_KEY_ONE: cls.ATTR_VAL_ONE, cls.ATTR_KEY_TWO: cls.ATTR_VAL_TWO}) + node.set_attribute_many({self.ATTR_KEY_ONE: self.ATTR_VAL_ONE, self.ATTR_KEY_TWO: self.ATTR_VAL_TWO}) - cls.EXTRA_KEY_ONE = 'x' - cls.EXTRA_VAL_ONE = '2' - cls.EXTRA_KEY_TWO = 'y' - cls.EXTRA_VAL_TWO = 'other' + self.EXTRA_KEY_ONE = 'x' + self.EXTRA_VAL_ONE = '2' + self.EXTRA_KEY_TWO = 'y' + self.EXTRA_VAL_TWO = 'other' - node.set_extra_many({cls.EXTRA_KEY_ONE: cls.EXTRA_VAL_ONE, cls.EXTRA_KEY_TWO: cls.EXTRA_VAL_TWO}) + node.set_extra_many({self.EXTRA_KEY_ONE: self.EXTRA_VAL_ONE, self.EXTRA_KEY_TWO: self.EXTRA_VAL_TWO}) node.store() - cls.node = node + self.node = node - def setUp(self): - self.cli_runner = CliRunner() + self.cli_runner = run_cli_command @classmethod def get_unstored_folder_node(cls): @@ -77,17 +74,15 @@ def test_node_show(self): node = orm.Data().store() node.label = 'SOMELABEL' options = [str(node.pk)] - result = self.cli_runner.invoke(cmd_node.node_show, options) - self.assertClickResultNoException(result) + result = self.cli_runner(cmd_node.node_show, options) # Let's check some content in the output. At least the UUID and the label should be in there - self.assertIn(node.label, result.output) - self.assertIn(node.uuid, result.output) + assert node.label in result.output + assert node.uuid in result.output # Let's now test the '--print-groups' option options.append('--print-groups') - result = self.cli_runner.invoke(cmd_node.node_show, options) - self.assertClickResultNoException(result) + result = self.cli_runner(cmd_node.node_show, options) # I don't check the list of groups - it might be in an autogroup # Let's create a group and put the node in there @@ -95,96 +90,94 @@ def test_node_show(self): group = orm.Group(group_name).store() group.add_nodes(node) - result = self.cli_runner.invoke(cmd_node.node_show, options) - self.assertClickResultNoException(result) + result = self.cli_runner(cmd_node.node_show, options) # Now the group should be in there - self.assertIn(group_name, result.output) + assert group_name in result.output def test_node_attributes(self): """Test verdi node attributes""" options = [str(self.node.uuid)] - result = self.cli_runner.invoke(cmd_node.attributes, options) - self.assertIsNone(result.exception, result.output) - self.assertIn(self.ATTR_KEY_ONE, result.output) - self.assertIn(self.ATTR_VAL_ONE, result.output) - self.assertIn(self.ATTR_KEY_TWO, result.output) - self.assertIn(self.ATTR_VAL_TWO, result.output) + result = self.cli_runner(cmd_node.attributes, options) + assert result.exception is None, result.output + assert self.ATTR_KEY_ONE in result.output + assert self.ATTR_VAL_ONE in result.output + assert self.ATTR_KEY_TWO in result.output + assert self.ATTR_VAL_TWO in result.output for flag in ['-k', '--keys']: options = [flag, self.ATTR_KEY_ONE, '--', str(self.node.uuid)] - result = self.cli_runner.invoke(cmd_node.attributes, options) - self.assertIsNone(result.exception, result.output) - self.assertIn(self.ATTR_KEY_ONE, result.output) - self.assertIn(self.ATTR_VAL_ONE, result.output) - self.assertNotIn(self.ATTR_KEY_TWO, result.output) - self.assertNotIn(self.ATTR_VAL_TWO, result.output) + result = self.cli_runner(cmd_node.attributes, options) + assert result.exception is None, result.output + assert self.ATTR_KEY_ONE in 
result.output + assert self.ATTR_VAL_ONE in result.output + assert self.ATTR_KEY_TWO not in result.output + assert self.ATTR_VAL_TWO not in result.output for flag in ['-r', '--raw']: options = [flag, str(self.node.uuid)] - result = self.cli_runner.invoke(cmd_node.attributes, options) - self.assertIsNone(result.exception, result.output) + result = self.cli_runner(cmd_node.attributes, options) + assert result.exception is None, result.output for flag in ['-f', '--format']: for fmt in ['json+date', 'yaml', 'yaml_expanded']: options = [flag, fmt, str(self.node.uuid)] - result = self.cli_runner.invoke(cmd_node.attributes, options) - self.assertIsNone(result.exception, result.output) + result = self.cli_runner(cmd_node.attributes, options) + assert result.exception is None, result.output for flag in ['-i', '--identifier']: for fmt in ['pk', 'uuid']: options = [flag, fmt, str(self.node.uuid)] - result = self.cli_runner.invoke(cmd_node.attributes, options) - self.assertIsNone(result.exception, result.output) + result = self.cli_runner(cmd_node.attributes, options) + assert result.exception is None, result.output def test_node_extras(self): """Test verdi node extras""" options = [str(self.node.uuid)] - result = self.cli_runner.invoke(cmd_node.extras, options) - self.assertIsNone(result.exception, result.output) - self.assertIn(self.EXTRA_KEY_ONE, result.output) - self.assertIn(self.EXTRA_VAL_ONE, result.output) - self.assertIn(self.EXTRA_KEY_TWO, result.output) - self.assertIn(self.EXTRA_VAL_TWO, result.output) + result = self.cli_runner(cmd_node.extras, options) + assert result.exception is None, result.output + assert self.EXTRA_KEY_ONE in result.output + assert self.EXTRA_VAL_ONE in result.output + assert self.EXTRA_KEY_TWO in result.output + assert self.EXTRA_VAL_TWO in result.output for flag in ['-k', '--keys']: options = [flag, self.EXTRA_KEY_ONE, '--', str(self.node.uuid)] - result = self.cli_runner.invoke(cmd_node.extras, options) - self.assertIsNone(result.exception, result.output) - self.assertIn(self.EXTRA_KEY_ONE, result.output) - self.assertIn(self.EXTRA_VAL_ONE, result.output) - self.assertNotIn(self.EXTRA_KEY_TWO, result.output) - self.assertNotIn(self.EXTRA_VAL_TWO, result.output) + result = self.cli_runner(cmd_node.extras, options) + assert result.exception is None, result.output + assert self.EXTRA_KEY_ONE in result.output + assert self.EXTRA_VAL_ONE in result.output + assert self.EXTRA_KEY_TWO not in result.output + assert self.EXTRA_VAL_TWO not in result.output for flag in ['-r', '--raw']: options = [flag, str(self.node.uuid)] - result = self.cli_runner.invoke(cmd_node.extras, options) - self.assertIsNone(result.exception, result.output) + result = self.cli_runner(cmd_node.extras, options) + assert result.exception is None, result.output for flag in ['-f', '--format']: for fmt in ['json+date', 'yaml', 'yaml_expanded']: options = [flag, fmt, str(self.node.uuid)] - result = self.cli_runner.invoke(cmd_node.extras, options) - self.assertIsNone(result.exception, result.output) + result = self.cli_runner(cmd_node.extras, options) + assert result.exception is None, result.output for flag in ['-i', '--identifier']: for fmt in ['pk', 'uuid']: options = [flag, fmt, str(self.node.uuid)] - result = self.cli_runner.invoke(cmd_node.extras, options) - self.assertIsNone(result.exception, result.output) + result = self.cli_runner(cmd_node.extras, options) + assert result.exception is None, result.output def test_node_repo_ls(self): """Test 'verdi node repo ls' command.""" folder_node = 
self.get_unstored_folder_node().store() options = [str(folder_node.pk), 'some/nested/folder'] - result = self.cli_runner.invoke(cmd_node.repo_ls, options, catch_exceptions=False) - self.assertClickResultNoException(result) - self.assertIn('filename.txt', result.output) + result = self.cli_runner(cmd_node.repo_ls, options, catch_exceptions=False) + assert 'filename.txt' in result.output options = [str(folder_node.pk), 'some/non-existing-folder'] - result = self.cli_runner.invoke(cmd_node.repo_ls, options, catch_exceptions=False) - self.assertIsNotNone(result.exception) - self.assertIn('does not exist for the given node', result.output) + result = self.cli_runner(cmd_node.repo_ls, options, catch_exceptions=False) + assert result.exception is not None + assert 'does not exist for the given node' in result.output def test_node_repo_cat(self): """Test 'verdi node repo cat' command.""" @@ -195,7 +188,7 @@ def test_node_repo_cat(self): folder_node.store() options = [str(folder_node.pk), 'filename.txt.gz'] - result = self.cli_runner.invoke(cmd_node.repo_cat, options) + result = self.cli_runner(cmd_node.repo_cat, options) assert gzip.decompress(result.stdout_bytes) == b'COMPRESS' def test_node_repo_dump(self): @@ -205,16 +198,16 @@ def test_node_repo_dump(self): with tempfile.TemporaryDirectory() as tmp_dir: out_path = pathlib.Path(tmp_dir) / 'out_dir' options = [str(folder_node.uuid), str(out_path)] - res = self.cli_runner.invoke(cmd_node.repo_dump, options, catch_exceptions=False) - self.assertFalse(res.stdout) + res = self.cli_runner(cmd_node.repo_dump, options, catch_exceptions=False) + assert not res.stdout for file_key, content in [(self.key_file1, self.content_file1), (self.key_file2, self.content_file2)]: curr_path = out_path for key_part in file_key.split('/'): curr_path /= key_part - self.assertTrue(curr_path.exists()) + assert curr_path.exists() with curr_path.open('r') as res_file: - self.assertEqual(res_file.read(), content) + assert res_file.read() == content def test_node_repo_dump_to_nested_folder(self): """Test 'verdi node repo dump' command, with an output folder whose parent does not exist.""" @@ -223,16 +216,16 @@ def test_node_repo_dump_to_nested_folder(self): with tempfile.TemporaryDirectory() as tmp_dir: out_path = pathlib.Path(tmp_dir) / 'out_dir' / 'nested' / 'path' options = [str(folder_node.uuid), str(out_path)] - res = self.cli_runner.invoke(cmd_node.repo_dump, options, catch_exceptions=False) - self.assertFalse(res.stdout) + res = self.cli_runner(cmd_node.repo_dump, options, catch_exceptions=False) + assert not res.stdout for file_key, content in [(self.key_file1, self.content_file1), (self.key_file2, self.content_file2)]: curr_path = out_path for key_part in file_key.split('/'): curr_path /= key_part - self.assertTrue(curr_path.exists()) + assert curr_path.exists() with curr_path.open('r') as res_file: - self.assertEqual(res_file.read(), content) + assert res_file.read() == content def test_node_repo_existing_out_dir(self): """Test 'verdi node repo dump' command, check that an existing output directory is not overwritten.""" @@ -247,9 +240,9 @@ def test_node_repo_existing_out_dir(self): with some_file.open('w') as file_handle: file_handle.write(some_file_content) options = [str(folder_node.uuid), str(out_path)] - res = self.cli_runner.invoke(cmd_node.repo_dump, options, catch_exceptions=False) - self.assertIn('exists', res.stdout) - self.assertIn('Critical:', res.stdout) + res = self.cli_runner(cmd_node.repo_dump, options, catch_exceptions=False) + assert 'exists' in 
res.stdout + assert 'Critical:' in res.stdout # Make sure the directory content is still there with some_file.open('r') as file_handle: @@ -272,29 +265,24 @@ def delete_temporary_file(filepath): pass -class TestVerdiGraph(AiidaTestCase): +class TestVerdiGraph: """Tests for the ``verdi node graph`` command.""" - @classmethod - def setUpClass(cls): - super().setUpClass() + @pytest.fixture(autouse=True) + def init_db(self, clear_database_before_test, tmp_path, run_cli_command): + """Setup the database before each test""" from aiida.orm import Data - cls.node = Data().store() + self.node = Data().store() + self.cli_runner = run_cli_command # some of the export tests write in the current directory, # make sure it is writeable and we don't pollute the current one - cls.old_cwd = os.getcwd() - cls.cwd = tempfile.mkdtemp(__name__) - os.chdir(cls.cwd) - - @classmethod - def tearDownClass(cls): - os.chdir(cls.old_cwd) - os.rmdir(cls.cwd) - - def setUp(self): - self.cli_runner = CliRunner() + self.old_cwd = os.getcwd() + self.cwd = str(tmp_path) + os.chdir(self.cwd) + yield + os.chdir(self.old_cwd) def test_generate_graph(self): """ @@ -306,9 +294,9 @@ def test_generate_graph(self): filename = f'{root_node}.dot.pdf' options = [root_node] try: - result = self.cli_runner.invoke(cmd_node.graph_generate, options) - self.assertIsNone(result.exception, result.output) - self.assertTrue(os.path.isfile(filename)) + result = self.cli_runner(cmd_node.graph_generate, options) + assert result.exception is None, result.output + assert os.path.isfile(filename) finally: delete_temporary_file(filename) @@ -325,9 +313,8 @@ def test_catch_bad_pk(self): options = [root_node] filename = f'{root_node}.dot.pdf' try: - result = self.cli_runner.invoke(cmd_node.graph_generate, options) - self.assertIsNotNone(result.exception) - self.assertFalse(os.path.isfile(filename)) + self.cli_runner(cmd_node.graph_generate, options, raises=True) + assert not os.path.isfile(filename) finally: delete_temporary_file(filename) @@ -337,16 +324,15 @@ def test_catch_bad_pk(self): root_node = 123456789 try: node = load_node(pk=root_node) - self.assertIsNone(node) + assert node is None except NotExistent: pass # Make sure verdi graph rejects this non-existant pk try: filename = f'{str(root_node)}.dot.pdf' options = [str(root_node)] - result = self.cli_runner.invoke(cmd_node.graph_generate, options) - self.assertIsNotNone(result.exception) - self.assertFalse(os.path.isfile(filename)) + self.cli_runner(cmd_node.graph_generate, options, raises=True) + assert not os.path.isfile(filename) finally: delete_temporary_file(filename) @@ -363,9 +349,9 @@ def test_check_recursion_flags(self): for opt in ['-a', '--ancestor-depth', '-d', '--descendant-depth']: options = [opt, None, root_node] try: - result = self.cli_runner.invoke(cmd_node.graph_generate, options) - self.assertIsNone(result.exception, result.output) - self.assertTrue(os.path.isfile(filename)) + result = self.cli_runner(cmd_node.graph_generate, options) + assert result.exception is None, result.output + assert os.path.isfile(filename) finally: delete_temporary_file(filename) @@ -374,9 +360,9 @@ def test_check_recursion_flags(self): for value in ['0', '1']: options = [opt, value, root_node] try: - result = self.cli_runner.invoke(cmd_node.graph_generate, options) - self.assertIsNone(result.exception, result.output) - self.assertTrue(os.path.isfile(filename)) + result = self.cli_runner(cmd_node.graph_generate, options) + assert result.exception is None, result.output + assert 
os.path.isfile(filename) finally: delete_temporary_file(filename) @@ -385,9 +371,8 @@ def test_check_recursion_flags(self): for badvalue in ['xyz', '3.14', '-5']: options = [flag, badvalue, root_node] try: - result = self.cli_runner.invoke(cmd_node.graph_generate, options) - self.assertIsNotNone(result.exception) - self.assertFalse(os.path.isfile(filename)) + result = self.cli_runner(cmd_node.graph_generate, options, raises=True) + assert not os.path.isfile(filename) finally: delete_temporary_file(filename) @@ -401,9 +386,9 @@ def test_check_io_flags(self): for flag in ['-i', '--process-in', '-o', '--process-out']: options = [flag, root_node] try: - result = self.cli_runner.invoke(cmd_node.graph_generate, options) - self.assertIsNone(result.exception, result.output) - self.assertTrue(os.path.isfile(filename)) + result = self.cli_runner(cmd_node.graph_generate, options) + assert result.exception is None, result.output + assert os.path.isfile(filename) finally: delete_temporary_file(filename) @@ -423,9 +408,9 @@ def test_output_format(self): filename = f'{root_node}.dot.{fileformat}' options = [option, fileformat, root_node] try: - result = self.cli_runner.invoke(cmd_node.graph_generate, options) - self.assertIsNone(result.exception, result.output) - self.assertTrue(os.path.isfile(filename)) + result = self.cli_runner(cmd_node.graph_generate, options) + assert result.exception is None, result.output + assert os.path.isfile(filename) finally: delete_temporary_file(filename) @@ -439,9 +424,9 @@ def test_node_id_label_format(self): for id_label_type in ['uuid', 'pk', 'label']: options = ['--identifier', id_label_type, root_node] try: - result = self.cli_runner.invoke(cmd_node.graph_generate, options) - self.assertIsNone(result.exception, result.output) - self.assertTrue(os.path.isfile(filename)) + result = self.cli_runner(cmd_node.graph_generate, options) + assert result.exception is None, result.output + assert os.path.isfile(filename) finally: delete_temporary_file(filename) @@ -449,153 +434,143 @@ def test_node_id_label_format(self): COMMENT = 'Well I never...' 
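
For reference, a minimal sketch of the pytest-style layout that the converted test classes in these hunks follow. It is illustrative only: it assumes the `run_cli_command` fixture defined in tests/conftest.py (updated later in this series) and the `clear_database_before_test` fixture from AiiDA's pytest fixtures; the class and test names are hypothetical.

import pytest

from aiida.cmdline.commands import cmd_node


class TestExamplePattern:
    """Hypothetical test class showing the fixture-based setup used in this patch series."""

    @pytest.fixture(autouse=True)
    def init_db(self, clear_database_before_test, run_cli_command):  # pylint: disable=unused-argument
        """Provide a clean database and the CLI runner for every test."""
        # pylint: disable=attribute-defined-outside-init
        self.cli_runner = run_cli_command

    def test_success(self):
        # By default the fixture asserts that no exception was raised and the exit code is zero.
        result = self.cli_runner(cmd_node.comment_show, [])
        assert result.exit_code == 0

    def test_failure(self):
        # With raises=True the fixture instead asserts a non-zero exit code.
        self.cli_runner(cmd_node.rehash, ['-f', '-e', 'data:core.structure'], raises=True)
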
-class TestVerdiUserCommand(AiidaTestCase): +class TestVerdiUserCommand: """Tests for the ``verdi node comment`` command.""" - def setUp(self): - self.cli_runner = CliRunner() + @pytest.fixture(autouse=True) + def init_db(self, clear_database_before_test, run_cli_command): + """Setup the database before each test""" + self.cli_runner = run_cli_command self.node = orm.Data().store() def test_comment_show_simple(self): """Test simply calling the show command (without data to show).""" - result = self.cli_runner.invoke(cmd_node.comment_show, [], catch_exceptions=False) - self.assertEqual(result.output, '') - self.assertEqual(result.exit_code, 0) + result = self.cli_runner(cmd_node.comment_show, [], catch_exceptions=False) + assert result.output == '' + assert result.exit_code == 0 def test_comment_show(self): """Test showing an existing comment.""" self.node.add_comment(COMMENT) options = [str(self.node.pk)] - result = self.cli_runner.invoke(cmd_node.comment_show, options, catch_exceptions=False) - self.assertNotEqual(result.output.find(COMMENT), -1) - self.assertEqual(result.exit_code, 0) + result = self.cli_runner(cmd_node.comment_show, options, catch_exceptions=False) + assert result.output.find(COMMENT) != -1 + assert result.exit_code == 0 def test_comment_add(self): """Test adding a comment.""" options = ['-N', str(self.node.pk), '--', f'{COMMENT}'] - result = self.cli_runner.invoke(cmd_node.comment_add, options, catch_exceptions=False) - self.assertEqual(result.exit_code, 0) + result = self.cli_runner(cmd_node.comment_add, options, catch_exceptions=False) + assert result.exit_code == 0 comment = self.node.get_comments() - self.assertEqual(len(comment), 1) - self.assertEqual(comment[0].content, COMMENT) + assert len(comment) == 1 + assert comment[0].content == COMMENT def test_comment_remove(self): """Test removing a comment.""" comment = self.node.add_comment(COMMENT) - self.assertEqual(len(self.node.get_comments()), 1) + assert len(self.node.get_comments()) == 1 options = [str(comment.pk), '--force'] - result = self.cli_runner.invoke(cmd_node.comment_remove, options, catch_exceptions=False) - self.assertEqual(result.exit_code, 0, result.output) - self.assertEqual(len(self.node.get_comments()), 0) + result = self.cli_runner(cmd_node.comment_remove, options, catch_exceptions=False) + assert result.exit_code == 0, result.output + assert len(self.node.get_comments()) == 0 -class TestVerdiRehash(AiidaTestCase): +class TestVerdiRehash: """Tests for the ``verdi node rehash`` command.""" - @classmethod - def setUpClass(cls, *args, **kwargs): - super().setUpClass(*args, **kwargs) + @pytest.fixture(autouse=True) + def init_db(self, clear_database_before_test, run_cli_command): + """Setup the database before each test""" from aiida.orm import Bool, Data, Float, Int - cls.node_base = Data().store() - cls.node_bool_true = Bool(True).store() - cls.node_bool_false = Bool(False).store() - cls.node_float = Float(1.0).store() - cls.node_int = Int(1).store() + self.node_base = Data().store() + self.node_bool_true = Bool(True).store() + self.node_bool_false = Bool(False).store() + self.node_float = Float(1.0).store() + self.node_int = Int(1).store() - def setUp(self): - self.cli_runner = CliRunner() + self.cli_runner = run_cli_command def test_rehash_interactive_yes(self): """Passing no options and answering 'Y' to the command will rehash all 5 nodes.""" expected_node_count = 5 options = [] # no option, will ask in the prompt - result = self.cli_runner.invoke(cmd_node.rehash, options, input='y') - 
self.assertClickResultNoException(result) - self.assertTrue(f'{expected_node_count} nodes' in result.output) + result = self.cli_runner(cmd_node.rehash, options, input='y') + assert f'{expected_node_count} nodes' in result.output def test_rehash_interactive_no(self): """Passing no options and answering 'N' to the command will abort the command.""" options = [] # no option, will ask in the prompt - result = self.cli_runner.invoke(cmd_node.rehash, options, input='n') - self.assertIsInstance(result.exception, SystemExit) - self.assertIn('ExitCode.CRITICAL', str(result.exception)) + result = self.cli_runner(cmd_node.rehash, options, raises=True, input='n') + assert isinstance(result.exception, SystemExit) + assert 'ExitCode.CRITICAL' in str(result.exception) def test_rehash(self): """Passing no options to the command will rehash all 5 nodes.""" expected_node_count = 5 options = ['-f'] # force, so no questions are asked - result = self.cli_runner.invoke(cmd_node.rehash, options) - self.assertClickResultNoException(result) - self.assertTrue(f'{expected_node_count} nodes' in result.output) + result = self.cli_runner(cmd_node.rehash, options) + assert f'{expected_node_count} nodes' in result.output def test_rehash_bool(self): """Limiting the queryset by defining an entry point, in this case bool, should limit nodes to 2.""" expected_node_count = 2 options = ['-f', '-e', 'aiida.data:core.bool'] - result = self.cli_runner.invoke(cmd_node.rehash, options) - self.assertClickResultNoException(result) - self.assertTrue(f'{expected_node_count} nodes' in result.output) + result = self.cli_runner(cmd_node.rehash, options) + assert f'{expected_node_count} nodes' in result.output def test_rehash_float(self): """Limiting the queryset by defining an entry point, in this case float, should limit nodes to 1.""" expected_node_count = 1 options = ['-f', '-e', 'aiida.data:core.float'] - result = self.cli_runner.invoke(cmd_node.rehash, options) - self.assertClickResultNoException(result) - self.assertTrue(f'{expected_node_count} nodes' in result.output) + result = self.cli_runner(cmd_node.rehash, options) + assert f'{expected_node_count} nodes' in result.output def test_rehash_int(self): """Limiting the queryset by defining an entry point, in this case int, should limit nodes to 1.""" expected_node_count = 1 options = ['-f', '-e', 'aiida.data:core.int'] - result = self.cli_runner.invoke(cmd_node.rehash, options) - self.assertClickResultNoException(result) - self.assertTrue(f'{expected_node_count} nodes' in result.output) + result = self.cli_runner(cmd_node.rehash, options) + assert f'{expected_node_count} nodes' in result.output def test_rehash_explicit_pk(self): """Limiting the queryset by defining explicit identifiers, should limit nodes to 2 in this example.""" expected_node_count = 2 options = ['-f', str(self.node_bool_true.pk), str(self.node_float.uuid)] - result = self.cli_runner.invoke(cmd_node.rehash, options) - self.assertClickResultNoException(result) - self.assertTrue(f'{expected_node_count} nodes' in result.output) + result = self.cli_runner(cmd_node.rehash, options) + assert f'{expected_node_count} nodes' in result.output def test_rehash_explicit_pk_and_entry_point(self): """Limiting the queryset by defining explicit identifiers and entry point, should limit nodes to 1.""" expected_node_count = 1 options = ['-f', '-e', 'aiida.data:core.bool', str(self.node_bool_true.pk), str(self.node_float.uuid)] - result = self.cli_runner.invoke(cmd_node.rehash, options) - self.assertClickResultNoException(result) - 
self.assertTrue(f'{expected_node_count} nodes' in result.output) + result = self.cli_runner(cmd_node.rehash, options) + assert f'{expected_node_count} nodes' in result.output def test_rehash_entry_point_no_matches(self): """Limiting the queryset by defining explicit entry point, with no nodes should exit with non-zero status.""" options = ['-f', '-e', 'aiida.data:core.structure'] - result = self.cli_runner.invoke(cmd_node.rehash, options) - self.assertIsNotNone(result.exception) + self.cli_runner(cmd_node.rehash, options, raises=True) def test_rehash_invalid_entry_point(self): """Passing an invalid entry point should exit with non-zero status.""" # Incorrect entry point group options = ['-f', '-e', 'data:core.structure'] - result = self.cli_runner.invoke(cmd_node.rehash, options) - self.assertIsNotNone(result.exception) + self.cli_runner(cmd_node.rehash, options, raises=True) # Non-existent entry point name options = ['-f', '-e', 'aiida.data:inexistant'] - result = self.cli_runner.invoke(cmd_node.rehash, options) - self.assertIsNotNone(result.exception) + self.cli_runner(cmd_node.rehash, options, raises=True) # Incorrect syntax, no colon to join entry point group and name options = ['-f', '-e', 'aiida.data.structure'] - result = self.cli_runner.invoke(cmd_node.rehash, options) - self.assertIsNotNone(result.exception) + self.cli_runner(cmd_node.rehash, options, raises=True) @pytest.mark.parametrize( diff --git a/tests/cmdline/params/types/test_calculation.py b/tests/cmdline/params/types/test_calculation.py index 893900193d..04dfaf5f7d 100644 --- a/tests/cmdline/params/types/test_calculation.py +++ b/tests/cmdline/params/types/test_calculation.py @@ -8,38 +8,37 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Tests for the `CalculationParamType`.""" +import pytest -from aiida.backends.testbase import AiidaTestCase from aiida.cmdline.params.types import CalculationParamType from aiida.orm import CalcFunctionNode, CalcJobNode, CalculationNode, WorkChainNode, WorkFunctionNode from aiida.orm.utils.loaders import OrmEntityLoader -class TestCalculationParamType(AiidaTestCase): +class TestCalculationParamType: """Tests for the `CalculationParamType`.""" - @classmethod - def setUpClass(cls, *args, **kwargs): + @pytest.fixture(autouse=True) + def init_db(self, clear_database_before_test): # pylint: disable=unused-argument """ Create some code to test the CalculationParamType parameter type for the command line infrastructure We create an initial code with a random name and then on purpose create two code with a name that matches exactly the ID and UUID, respectively, of the first one. 
This allows us to test the rules implemented to solve ambiguities that arise when determing the identifier type """ - super().setUpClass(*args, **kwargs) + # pylint: disable=attribute-defined-outside-init + self.param = CalculationParamType() + self.entity_01 = CalculationNode().store() + self.entity_02 = CalculationNode().store() + self.entity_03 = CalculationNode().store() + self.entity_04 = WorkFunctionNode() + self.entity_05 = CalcFunctionNode() + self.entity_06 = CalcJobNode() + self.entity_07 = WorkChainNode() - cls.param = CalculationParamType() - cls.entity_01 = CalculationNode().store() - cls.entity_02 = CalculationNode().store() - cls.entity_03 = CalculationNode().store() - cls.entity_04 = WorkFunctionNode() - cls.entity_05 = CalcFunctionNode() - cls.entity_06 = CalcJobNode() - cls.entity_07 = WorkChainNode() - - cls.entity_01.label = 'calculation_01' - cls.entity_02.label = str(cls.entity_01.pk) - cls.entity_03.label = str(cls.entity_01.uuid) + self.entity_01.label = 'calculation_01' + self.entity_02.label = str(self.entity_01.pk) + self.entity_03.label = str(self.entity_01.uuid) def test_get_by_id(self): """ @@ -47,7 +46,7 @@ def test_get_by_id(self): """ identifier = f'{self.entity_01.pk}' result = self.param.convert(identifier, None, None) - self.assertEqual(result.uuid, self.entity_01.uuid) + assert result.uuid == self.entity_01.uuid def test_get_by_uuid(self): """ @@ -55,7 +54,7 @@ def test_get_by_uuid(self): """ identifier = f'{self.entity_01.uuid}' result = self.param.convert(identifier, None, None) - self.assertEqual(result.uuid, self.entity_01.uuid) + assert result.uuid == self.entity_01.uuid def test_get_by_label(self): """ @@ -63,7 +62,7 @@ def test_get_by_label(self): """ identifier = f'{self.entity_01.label}' result = self.param.convert(identifier, None, None) - self.assertEqual(result.uuid, self.entity_01.uuid) + assert result.uuid == self.entity_01.uuid def test_ambiguous_label_pk(self): """ @@ -74,11 +73,11 @@ def test_ambiguous_label_pk(self): """ identifier = f'{self.entity_02.label}' result = self.param.convert(identifier, None, None) - self.assertEqual(result.uuid, self.entity_01.uuid) + assert result.uuid == self.entity_01.uuid identifier = f'{self.entity_02.label}{OrmEntityLoader.label_ambiguity_breaker}' result = self.param.convert(identifier, None, None) - self.assertEqual(result.uuid, self.entity_02.uuid) + assert result.uuid == self.entity_02.uuid def test_ambiguous_label_uuid(self): """ @@ -89,8 +88,8 @@ def test_ambiguous_label_uuid(self): """ identifier = f'{self.entity_03.label}' result = self.param.convert(identifier, None, None) - self.assertEqual(result.uuid, self.entity_01.uuid) + assert result.uuid == self.entity_01.uuid identifier = f'{self.entity_03.label}{OrmEntityLoader.label_ambiguity_breaker}' result = self.param.convert(identifier, None, None) - self.assertEqual(result.uuid, self.entity_03.uuid) + assert result.uuid == self.entity_03.uuid diff --git a/tests/cmdline/params/types/test_data.py b/tests/cmdline/params/types/test_data.py index e541574bbf..b477aa4345 100644 --- a/tests/cmdline/params/types/test_data.py +++ b/tests/cmdline/params/types/test_data.py @@ -8,34 +8,33 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Tests for the `DataParamType`.""" +import pytest -from aiida.backends.testbase import AiidaTestCase from aiida.cmdline.params.types import DataParamType from aiida.orm import Data from aiida.orm.utils.loaders import 
OrmEntityLoader -class TestDataParamType(AiidaTestCase): +class TestDataParamType: """Tests for the `DataParamType`.""" - @classmethod - def setUpClass(cls, *args, **kwargs): + @pytest.fixture(autouse=True) + def init_db(self, clear_database_before_test): # pylint: disable=unused-argument """ Create some code to test the DataParamType parameter type for the command line infrastructure We create an initial code with a random name and then on purpose create two code with a name that matches exactly the ID and UUID, respectively, of the first one. This allows us to test the rules implemented to solve ambiguities that arise when determing the identifier type """ - super().setUpClass(*args, **kwargs) + # pylint: disable=attribute-defined-outside-init + self.param = DataParamType() + self.entity_01 = Data().store() + self.entity_02 = Data().store() + self.entity_03 = Data().store() - cls.param = DataParamType() - cls.entity_01 = Data().store() - cls.entity_02 = Data().store() - cls.entity_03 = Data().store() - - cls.entity_01.label = 'data_01' - cls.entity_02.label = str(cls.entity_01.pk) - cls.entity_03.label = str(cls.entity_01.uuid) + self.entity_01.label = 'data_01' + self.entity_02.label = str(self.entity_01.pk) + self.entity_03.label = str(self.entity_01.uuid) def test_get_by_id(self): """ @@ -43,7 +42,7 @@ def test_get_by_id(self): """ identifier = f'{self.entity_01.pk}' result = self.param.convert(identifier, None, None) - self.assertEqual(result.uuid, self.entity_01.uuid) + assert result.uuid == self.entity_01.uuid def test_get_by_uuid(self): """ @@ -51,7 +50,7 @@ def test_get_by_uuid(self): """ identifier = f'{self.entity_01.uuid}' result = self.param.convert(identifier, None, None) - self.assertEqual(result.uuid, self.entity_01.uuid) + assert result.uuid == self.entity_01.uuid def test_get_by_label(self): """ @@ -59,7 +58,7 @@ def test_get_by_label(self): """ identifier = f'{self.entity_01.label}' result = self.param.convert(identifier, None, None) - self.assertEqual(result.uuid, self.entity_01.uuid) + assert result.uuid == self.entity_01.uuid def test_ambiguous_label_pk(self): """ @@ -70,11 +69,11 @@ def test_ambiguous_label_pk(self): """ identifier = f'{self.entity_02.label}' result = self.param.convert(identifier, None, None) - self.assertEqual(result.uuid, self.entity_01.uuid) + assert result.uuid == self.entity_01.uuid identifier = f'{self.entity_02.label}{OrmEntityLoader.label_ambiguity_breaker}' result = self.param.convert(identifier, None, None) - self.assertEqual(result.uuid, self.entity_02.uuid) + assert result.uuid == self.entity_02.uuid def test_ambiguous_label_uuid(self): """ @@ -85,8 +84,8 @@ def test_ambiguous_label_uuid(self): """ identifier = f'{self.entity_03.label}' result = self.param.convert(identifier, None, None) - self.assertEqual(result.uuid, self.entity_01.uuid) + assert result.uuid == self.entity_01.uuid identifier = f'{self.entity_03.label}{OrmEntityLoader.label_ambiguity_breaker}' result = self.param.convert(identifier, None, None) - self.assertEqual(result.uuid, self.entity_03.uuid) + assert result.uuid == self.entity_03.uuid diff --git a/tests/cmdline/params/types/test_node.py b/tests/cmdline/params/types/test_node.py index ecd3b53d72..1cc3eba42a 100644 --- a/tests/cmdline/params/types/test_node.py +++ b/tests/cmdline/params/types/test_node.py @@ -8,33 +8,33 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Tests for the `NodeParamType`.""" -from 
aiida.backends.testbase import AiidaTestCase +import pytest + from aiida.cmdline.params.types import NodeParamType from aiida.orm import Data from aiida.orm.utils.loaders import OrmEntityLoader -class TestNodeParamType(AiidaTestCase): +class TestNodeParamType: """Tests for the `NodeParamType`.""" - @classmethod - def setUpClass(cls, *args, **kwargs): + @pytest.fixture(autouse=True) + def init_db(self, clear_database_before_test): # pylint: disable=unused-argument """ Create some code to test the NodeParamType parameter type for the command line infrastructure We create an initial code with a random name and then on purpose create two code with a name that matches exactly the ID and UUID, respectively, of the first one. This allows us to test the rules implemented to solve ambiguities that arise when determing the identifier type """ - super().setUpClass(*args, **kwargs) - - cls.param = NodeParamType() - cls.entity_01 = Data().store() - cls.entity_02 = Data().store() - cls.entity_03 = Data().store() + # pylint: disable=attribute-defined-outside-init + self.param = NodeParamType() + self.entity_01 = Data().store() + self.entity_02 = Data().store() + self.entity_03 = Data().store() - cls.entity_01.label = 'data_01' - cls.entity_02.label = str(cls.entity_01.pk) - cls.entity_03.label = str(cls.entity_01.uuid) + self.entity_01.label = 'data_01' + self.entity_02.label = str(self.entity_01.pk) + self.entity_03.label = str(self.entity_01.uuid) def test_get_by_id(self): """ @@ -42,7 +42,7 @@ def test_get_by_id(self): """ identifier = f'{self.entity_01.pk}' result = self.param.convert(identifier, None, None) - self.assertEqual(result.uuid, self.entity_01.uuid) + assert result.uuid == self.entity_01.uuid def test_get_by_uuid(self): """ @@ -50,7 +50,7 @@ def test_get_by_uuid(self): """ identifier = f'{self.entity_01.uuid}' result = self.param.convert(identifier, None, None) - self.assertEqual(result.uuid, self.entity_01.uuid) + assert result.uuid == self.entity_01.uuid def test_get_by_label(self): """ @@ -58,7 +58,7 @@ def test_get_by_label(self): """ identifier = f'{self.entity_01.label}' result = self.param.convert(identifier, None, None) - self.assertEqual(result.uuid, self.entity_01.uuid) + assert result.uuid == self.entity_01.uuid def test_ambiguous_label_pk(self): """ @@ -69,11 +69,11 @@ def test_ambiguous_label_pk(self): """ identifier = f'{self.entity_02.label}' result = self.param.convert(identifier, None, None) - self.assertEqual(result.uuid, self.entity_01.uuid) + assert result.uuid == self.entity_01.uuid identifier = f'{self.entity_02.label}{OrmEntityLoader.label_ambiguity_breaker}' result = self.param.convert(identifier, None, None) - self.assertEqual(result.uuid, self.entity_02.uuid) + assert result.uuid == self.entity_02.uuid def test_ambiguous_label_uuid(self): """ @@ -84,8 +84,8 @@ def test_ambiguous_label_uuid(self): """ identifier = f'{self.entity_03.label}' result = self.param.convert(identifier, None, None) - self.assertEqual(result.uuid, self.entity_01.uuid) + assert result.uuid == self.entity_01.uuid identifier = f'{self.entity_03.label}{OrmEntityLoader.label_ambiguity_breaker}' result = self.param.convert(identifier, None, None) - self.assertEqual(result.uuid, self.entity_03.uuid) + assert result.uuid == self.entity_03.uuid diff --git a/tests/conftest.py b/tests/conftest.py index 41167de2a2..a18b90028f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -428,7 +428,7 @@ def run_cli_command(reset_log_level): # pylint: disable=unused-argument """ from click.testing import 
Result - def _run_cli_command(command: click.Command, options: list = None, raises: bool = False) -> Result: + def _run_cli_command(command: click.Command, options: list = None, raises: bool = False, **kwargs) -> Result: """Run the command and check the result. .. note:: the `output_lines` attribute is added to return value containing list of stripped output lines. @@ -457,12 +457,12 @@ def _run_cli_command(command: click.Command, options: list = None, raises: bool command = VerdiCommandGroup.add_verbosity_option(command) runner = click.testing.CliRunner() - result = runner.invoke(command, options, obj=obj) + result = runner.invoke(command, options, obj=obj, **kwargs) if raises: assert result.exception is not None, result.output assert result.exit_code != 0 - else: + elif kwargs.get('catch_exceptions', True): assert result.exception is None, ''.join(traceback.format_exception(*result.exc_info)) assert result.exit_code == 0, result.output From f971083dd63cb7a4054929ba4d838ab4dfe74f60 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Thu, 14 Oct 2021 19:11:12 +0200 Subject: [PATCH 07/16] Update test_code.py --- tests/cmdline/commands/test_code.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/cmdline/commands/test_code.py b/tests/cmdline/commands/test_code.py index 3c6ea0aa9e..1883a31ffc 100644 --- a/tests/cmdline/commands/test_code.py +++ b/tests/cmdline/commands/test_code.py @@ -27,13 +27,15 @@ class TestVerdiCodeSetup(AiidaTestCase): """Tests for the 'verdi code setup' command.""" def setUp(self): - self.computer = orm.Computer.objects.get_or_create( + created, self.computer = orm.Computer.objects.get_or_create( label='comp', hostname='localhost', transport_type='core.local', scheduler_type='core.direct', workdir='/tmp/aiida' ) + if created: + self.computer.store() self.cli_runner = CliRunner() self.this_folder = os.path.dirname(__file__) self.this_file = os.path.basename(__file__) @@ -115,13 +117,15 @@ class TestVerdiCodeCommands(AiidaTestCase): Testing everything besides `code setup`.""" def setUp(self): - self.computer = orm.Computer.objects.get_or_create( + created, self.computer = orm.Computer.objects.get_or_create( label='comp', hostname='localhost', transport_type='core.local', scheduler_type='core.direct', workdir='/tmp/aiida' ) + if created: + self.computer.store() try: code = orm.Code.get_from_string('code') except NotExistent: From 9e17f32851088326eee8dd7c98b002073eed83f5 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Fri, 15 Oct 2021 01:31:16 +0200 Subject: [PATCH 08/16] Ensure all backends are closed on AiidaTestCase.tearDown --- aiida/backends/testbase.py | 2 ++ aiida/orm/implementation/backends.py | 17 +++++++++++++ aiida/orm/implementation/sql/backends.py | 31 +++++++++++------------- 3 files changed, 33 insertions(+), 17 deletions(-) diff --git a/aiida/backends/testbase.py b/aiida/backends/testbase.py index efb7ba4c09..bbcad0a59f 100644 --- a/aiida/backends/testbase.py +++ b/aiida/backends/testbase.py @@ -105,7 +105,9 @@ def tearDownClass(cls): cls.clean_repository() def tearDown(self): + from aiida.orm.implementation.backends import close_all_backends reset_manager() + close_all_backends() # the user and computer need to be reset, so that they can be set to the new backend # pylint: disable=protected-access self.__class__._computer = None diff --git a/aiida/orm/implementation/backends.py b/aiida/orm/implementation/backends.py index 37b5ce99b8..8907223de5 100644 --- a/aiida/orm/implementation/backends.py +++ 
b/aiida/orm/implementation/backends.py @@ -10,6 +10,7 @@ """Generic backend related objects""" import abc from typing import TYPE_CHECKING +import weakref if TYPE_CHECKING: from sqlalchemy.orm.session import Session @@ -28,6 +29,20 @@ __all__ = ('Backend',) +_backends = weakref.WeakValueDictionary() +"""Weak-referencing dictionary of loaded Backend. +""" + + +def close_all_backends(): + """Close all loaded backends. + + This function is not for general use but may be useful for test suites + within the teardown scheme. + """ + for backend in _backends.values(): + backend.close() + class Backend(abc.ABC): """Abstraction for a backend to read/write persistent data for a profile's provenance graph.""" @@ -39,6 +54,8 @@ def __init__(self, profile: 'Profile', validate_db: bool = True) -> None: # pyl :param validate_db: if True, the backend will perform validation tests on the database consistency """ self._profile = profile + self._hashkey = 1 if not _backends else max(_backends.keys()) + 1 + _backends[self._hashkey] = self @property def profile(self) -> 'Profile': diff --git a/aiida/orm/implementation/sql/backends.py b/aiida/orm/implementation/sql/backends.py index 618c3ad48e..d27a2783cb 100644 --- a/aiida/orm/implementation/sql/backends.py +++ b/aiida/orm/implementation/sql/backends.py @@ -11,13 +11,12 @@ import abc from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar +from sqlalchemy.future.engine import Engine +from sqlalchemy.orm import Session + from .. import backends if TYPE_CHECKING: - from sqlalchemy.future.engine import Engine - from sqlalchemy.future.orm import Session - from sqlalchemy.orm.scoping import scoped_session - from aiida.backends.manager import BackendManager __all__ = ('SqlBackend',) @@ -39,30 +38,28 @@ class SqlBackend(Generic[ModelType], backends.Backend): def __init__(self, profile, validate_db: bool = True): super().__init__(profile, validate_db) # set variables for QueryBuilder - self._engine: Optional['Engine'] = None - self._session_factory: Optional['scoped_session'] = None + self._engine: Optional[Engine] = None + self._session: Optional[Session] = None def get_session(self, **kwargs: Any) -> 'Session': - """Return a database session that can be used by the `QueryBuilderBackend`. + """Return an SQLAlchemy database session. On first call (or after a reset) the session is initialised, then the same session is always returned. 
:param kwargs: keyword arguments to be passed to the engine """ - from aiida.backends.utils import create_scoped_session_factory, create_sqlalchemy_engine - if self._session_factory is not None: - return self._session_factory() + from aiida.backends.utils import create_sqlalchemy_engine + if self._engine is None: self._engine = create_sqlalchemy_engine(self._profile, **kwargs) - self._session_factory = create_scoped_session_factory(self._engine) - return self._session_factory() + if self._session is None: + self._session = Session(bind=self._engine, future=True) + return self._session def close(self): - if self._session_factory is not None: - # these methods are proxied from the session - self._session_factory.expunge_all() # pylint: disable=no-member - self._session_factory.remove() # pylint: disable=no-member - self._session_factory = None + if self._session is not None: + self._session.close() + self._session = None if self._engine is not None: self._engine.dispose() self._engine = None From 3d816ec6e62a84829214878735719c667463b7ac Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Fri, 15 Oct 2021 02:26:57 +0200 Subject: [PATCH 09/16] fix tests --- aiida/manage/tests/main.py | 4 + aiida/orm/implementation/backends.py | 4 +- aiida/orm/implementation/sql/backends.py | 5 +- docs/source/nitpick-exceptions | 3 + .../processes/calcjobs/test_calc_job.py | 95 +++++++++---------- tests/engine/test_manager.py | 4 +- tests/restapi/conftest.py | 3 +- 7 files changed, 61 insertions(+), 57 deletions(-) diff --git a/aiida/manage/tests/main.py b/aiida/manage/tests/main.py index ab88a2aa6c..9f266059dc 100644 --- a/aiida/manage/tests/main.py +++ b/aiida/manage/tests/main.py @@ -164,8 +164,12 @@ def _select_db_test_case(self, backend): self._test_case = SqlAlchemyTests() def reset_db(self): + """Reset the test database state.""" + from aiida.orm.implementation.backends import close_all_backends + self._test_case.clean_db() # will drop all users manager.reset_manager() + close_all_backends() self.init_db() def init_db(self): diff --git a/aiida/orm/implementation/backends.py b/aiida/orm/implementation/backends.py index 8907223de5..da82771c9a 100644 --- a/aiida/orm/implementation/backends.py +++ b/aiida/orm/implementation/backends.py @@ -30,7 +30,7 @@ __all__ = ('Backend',) _backends = weakref.WeakValueDictionary() -"""Weak-referencing dictionary of loaded Backend. +"""Weak-referencing dictionary of loaded Backends. """ @@ -71,7 +71,7 @@ def close(self) -> None: """ @abc.abstractmethod - def reset(self) -> None: + def reset(self, **kwargs) -> None: """Reset the backend. 
This method should reset any open connections to the persistent storage diff --git a/aiida/orm/implementation/sql/backends.py b/aiida/orm/implementation/sql/backends.py index d27a2783cb..81ddb64a2e 100644 --- a/aiida/orm/implementation/sql/backends.py +++ b/aiida/orm/implementation/sql/backends.py @@ -64,8 +64,9 @@ def close(self): self._engine.dispose() self._engine = None - def reset(self): - self.close() # close the connection, so that it will be regenerated on get_session + def reset(self, **kwargs): + self.close() + self.get_session(**kwargs) @property @abc.abstractmethod diff --git a/docs/source/nitpick-exceptions b/docs/source/nitpick-exceptions index 99af999ad5..39e6d8a8c3 100644 --- a/docs/source/nitpick-exceptions +++ b/docs/source/nitpick-exceptions @@ -46,6 +46,7 @@ py:class aiida.tools.groups.paths.WalkNodeResult py:class Backend py:class BackendEntity py:class BackendNode +py:class Base py:class AuthInfo py:class CalcJob py:class CalcJobImporter @@ -66,11 +67,13 @@ py:class Process py:class ProcessBuilder py:class ProcessNode py:class ProcessSpec +py:class Profile py:class Port py:class PortNamespace py:class Repository py:class Runner py:class Scheduler +py:class SqlaBackend py:class Transport py:class TransportQueue py:class WorkChain diff --git a/tests/engine/processes/calcjobs/test_calc_job.py b/tests/engine/processes/calcjobs/test_calc_job.py index 26fed5984f..ef1853d19f 100644 --- a/tests/engine/processes/calcjobs/test_calc_job.py +++ b/tests/engine/processes/calcjobs/test_calc_job.py @@ -20,7 +20,6 @@ import pytest from aiida import orm -from aiida.backends.testbase import AiidaTestCase from aiida.common import CalcJobState, LinkType, StashMode, exceptions from aiida.engine import CalcJob, CalcJobImporter, ExitCode, Process, launch from aiida.engine.processes.calcjobs.calcjob import validate_stash_options @@ -168,16 +167,17 @@ def test_multi_codes_run_withmpi(aiida_local_code_factory, file_regression, calc @pytest.mark.requires_rmq -class TestCalcJob(AiidaTestCase): +class TestCalcJob: """Test for the `CalcJob` process sub class.""" - @classmethod - def setUpClass(cls, *args, **kwargs): - super().setUpClass(*args, **kwargs) - cls.computer.configure() # pylint: disable=no-member - cls.remote_code = orm.Code(remote_computer_exec=(cls.computer, '/bin/bash')).store() - cls.local_code = orm.Code(local_executable='bash', files=['/bin/bash']).store() - cls.inputs = {'x': orm.Int(1), 'y': orm.Int(2), 'metadata': {'options': {}}} + @pytest.fixture(autouse=True) + def init_db(self, clear_database_before_test, aiida_localhost): # pylint: disable=unused-argument + """Setup database before each test""" + # pylint: disable=attribute-defined-outside-init + self.computer = aiida_localhost + self.remote_code = orm.Code(remote_computer_exec=(self.computer, '/bin/bash')).store() + self.local_code = orm.Code(local_executable='bash', files=['/bin/bash']).store() + self.inputs = {'x': orm.Int(1), 'y': orm.Int(2), 'metadata': {'options': {}}} def instantiate_process(self, state=CalcJobState.PARSING): """Instantiate a process with default inputs and return the `Process` instance.""" @@ -196,29 +196,27 @@ def instantiate_process(self, state=CalcJobState.PARSING): return process - def setUp(self): - super().setUp() - self.assertIsNone(Process.current()) + def setup_method(self): + assert Process.current() is None - def tearDown(self): - super().tearDown() - self.assertIsNone(Process.current()) + def teardown_method(self): + assert Process.current() is None def test_run_base_class(self): 
"""Verify that it is impossible to run, submit or instantiate a base `CalcJob` class.""" - with self.assertRaises(exceptions.InvalidOperation): + with pytest.raises(exceptions.InvalidOperation): CalcJob() - with self.assertRaises(exceptions.InvalidOperation): + with pytest.raises(exceptions.InvalidOperation): launch.run(CalcJob) - with self.assertRaises(exceptions.InvalidOperation): + with pytest.raises(exceptions.InvalidOperation): launch.run_get_node(CalcJob) - with self.assertRaises(exceptions.InvalidOperation): + with pytest.raises(exceptions.InvalidOperation): launch.run_get_pk(CalcJob) - with self.assertRaises(exceptions.InvalidOperation): + with pytest.raises(exceptions.InvalidOperation): launch.submit(CalcJob) def test_define_not_calling_super(self): @@ -234,13 +232,13 @@ def define(cls, spec): def prepare_for_submission(self, folder): pass - with self.assertRaises(AssertionError): + with pytest.raises(AssertionError): launch.run(IncompleteDefineCalcJob) def test_spec_options_property(self): """`CalcJob.spec_options` should return the options port namespace of its spec.""" - self.assertIsInstance(CalcJob.spec_options, PortNamespace) - self.assertEqual(CalcJob.spec_options, CalcJob.spec().inputs['metadata']['options']) + assert isinstance(CalcJob.spec_options, PortNamespace) + assert CalcJob.spec_options == CalcJob.spec().inputs['metadata']['options'] def test_invalid_options_type(self): """Verify that passing an invalid type to `metadata.options` raises a `TypeError`.""" @@ -256,7 +254,7 @@ def prepare_for_submission(self, folder): pass # The `metadata.options` input expects a plain dict and not a node `Dict` - with self.assertRaises(TypeError): + with pytest.raises(TypeError): launch.run(SimpleCalcJob, code=self.remote_code, metadata={'options': orm.Dict(dict={'a': 1})}) def test_remote_code_set_computer_implicit(self): @@ -269,8 +267,8 @@ def test_remote_code_set_computer_implicit(self): inputs['code'] = self.remote_code process = ArithmeticAddCalculation(inputs=inputs) - self.assertTrue(process.node.is_stored) - self.assertEqual(process.node.computer.uuid, self.remote_code.computer.uuid) + assert process.node.is_stored + assert process.node.computer.uuid == self.remote_code.computer.uuid def test_remote_code_unstored_computer(self): """Test launching a `CalcJob` with an unstored computer which should raise.""" @@ -278,7 +276,7 @@ def test_remote_code_unstored_computer(self): inputs['code'] = self.remote_code inputs['metadata']['computer'] = orm.Computer('different', 'localhost', 'desc', 'core.local', 'core.direct') - with self.assertRaises(ValueError): + with pytest.raises(ValueError): ArithmeticAddCalculation(inputs=inputs) def test_remote_code_set_computer_explicit(self): @@ -291,7 +289,7 @@ def test_remote_code_set_computer_explicit(self): inputs['code'] = self.remote_code # Setting explicitly a computer that is not the same as that of the `code` should raise - with self.assertRaises(ValueError): + with pytest.raises(ValueError): inputs['metadata']['computer'] = orm.Computer( 'different', 'localhost', 'desc', 'core.local', 'core.direct' ).store() @@ -300,8 +298,8 @@ def test_remote_code_set_computer_explicit(self): # Setting the same computer as that of the `code` effectively accomplishes nothing but should be fine inputs['metadata']['computer'] = self.remote_code.computer process = ArithmeticAddCalculation(inputs=inputs) - self.assertTrue(process.node.is_stored) - self.assertEqual(process.node.computer.uuid, self.remote_code.computer.uuid) + assert process.node.is_stored + 
assert process.node.computer.uuid == self.remote_code.computer.uuid def test_local_code_set_computer(self): """Test launching a `CalcJob` with a local code *with* explicitly defining a computer, which should work.""" @@ -310,15 +308,15 @@ def test_local_code_set_computer(self): inputs['metadata']['computer'] = self.computer process = ArithmeticAddCalculation(inputs=inputs) - self.assertTrue(process.node.is_stored) - self.assertEqual(process.node.computer.uuid, self.computer.uuid) # pylint: disable=no-member + assert process.node.is_stored + assert process.node.computer.uuid == self.computer.uuid # pylint: disable=no-member def test_local_code_no_computer(self): """Test launching a `CalcJob` with a local code *without* explicitly defining a computer, which should raise.""" inputs = deepcopy(self.inputs) inputs['code'] = self.local_code - with self.assertRaises(ValueError): + with pytest.raises(ValueError): ArithmeticAddCalculation(inputs=inputs) def test_invalid_parser_name(self): @@ -327,7 +325,7 @@ def test_invalid_parser_name(self): inputs['code'] = self.remote_code inputs['metadata']['options']['parser_name'] = 'invalid_parser' - with self.assertRaises(ValueError): + with pytest.raises(ValueError): ArithmeticAddCalculation(inputs=inputs) def test_invalid_resources(self): @@ -336,7 +334,7 @@ def test_invalid_resources(self): inputs['code'] = self.remote_code inputs['metadata']['options']['resources'] = {'num_machines': 'invalid_type'} - with self.assertRaises(ValueError): + with pytest.raises(ValueError): ArithmeticAddCalculation(inputs=inputs) def test_par_env_resources_computer(self): @@ -369,11 +367,9 @@ def test_exception_presubmit(self): """ from aiida.engine.processes.calcjobs.tasks import PreSubmitException - with self.assertRaises(PreSubmitException) as context: + with pytest.raises(PreSubmitException, match='exception occurred in presubmit call'): launch.run(ArithmeticAddCalculation, code=self.remote_code, **self.inputs) - self.assertIn('exception occurred in presubmit call', str(context.exception)) - @pytest.mark.usefixtures('chdir_tmp_path') def test_run_local_code(self): """Run a dry-run with local code.""" @@ -388,7 +384,7 @@ def test_run_local_code(self): # Since the repository will only contain files on the top-level due to `Code.set_files` we only check those for filename in self.local_code.list_object_names(): - self.assertTrue(filename in uploaded_files) + assert filename in uploaded_files @pytest.mark.usefixtures('chdir_tmp_path') def test_rerunnable(self): @@ -446,13 +442,13 @@ def test_provenance_exclude_list(self): # written to the node's repository so we can check it contains the expected contents. _, node = launch.run_get_node(FileCalcJob, **inputs) - self.assertIn('folder', node.dry_run_info) + assert 'folder' in node.dry_run_info # Verify that the folder (representing the node's repository) indeed do not contain the input files. 
Note, # however, that the directory hierarchy should be there, albeit empty - self.assertIn('base', node.list_object_names()) - self.assertEqual(sorted(['b']), sorted(node.list_object_names(os.path.join('base')))) - self.assertEqual(['two'], node.list_object_names(os.path.join('base', 'b'))) + assert 'base' in node.list_object_names() + assert sorted(['b']) == sorted(node.list_object_names(os.path.join('base'))) + assert ['two'] == node.list_object_names(os.path.join('base', 'b')) def test_parse_no_retrieved_folder(self): """Test the `CalcJob.parse` method when there is no retrieved folder.""" @@ -784,14 +780,15 @@ def test_validate_stash_options(stash_options, expected): assert expected in validate_stash_options(stash_options, None) -class TestImport(AiidaTestCase): +class TestImport: """Test the functionality to import existing calculations completed outside of AiiDA.""" - @classmethod - def setUpClass(cls, *args, **kwargs): - super().setUpClass(*args, **kwargs) - cls.computer.configure() # pylint: disable=no-member - cls.inputs = { + @pytest.fixture(autouse=True) + def init_db(self, clear_database_before_test, aiida_localhost): # pylint: disable=unused-argument + """Setup database before each test""" + # pylint: disable=attribute-defined-outside-init + self.computer = aiida_localhost + self.inputs = { 'x': orm.Int(1), 'y': orm.Int(2), 'metadata': { diff --git a/tests/engine/test_manager.py b/tests/engine/test_manager.py index 94e07305a4..d9de4c877c 100644 --- a/tests/engine/test_manager.py +++ b/tests/engine/test_manager.py @@ -30,8 +30,8 @@ def setUp(self): self.manager = JobManager(self.transport_queue) def tearDown(self): - super().tearDown() AuthInfo.objects.delete(self.auth_info.pk) + super().tearDown() def test_get_jobs_list(self): """Test the `JobManager.get_jobs_list` method.""" @@ -59,8 +59,8 @@ def setUp(self): self.jobs_list = JobsList(self.auth_info, self.transport_queue) def tearDown(self): - super().tearDown() AuthInfo.objects.delete(self.auth_info.pk) + super().tearDown() def test_get_minimum_update_interval(self): """Test the `JobsList.get_minimum_update_interval` method.""" diff --git a/tests/restapi/conftest.py b/tests/restapi/conftest.py index 3222d7e0ba..ed6264e6a1 100644 --- a/tests/restapi/conftest.py +++ b/tests/restapi/conftest.py @@ -53,8 +53,7 @@ def restrict_sqlalchemy_queuepool(aiida_profile): # pylint: disable=unused-argu from aiida.manage.manager import get_manager backend = get_manager().get_backend() - backend.close() - backend.get_session(pool_timeout=1, max_overflow=0) + backend.reset(pool_timeout=1, max_overflow=0) @pytest.fixture From 43072763846dae01beeff1da4d9ded688f11b094 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Fri, 15 Oct 2021 03:00:11 +0200 Subject: [PATCH 10/16] Move create_sqlalchemy_engine to SqlBackend --- aiida/backends/utils.py | 35 ------------------------ aiida/orm/implementation/sql/backends.py | 30 ++++++++++++++++++-- 2 files changed, 28 insertions(+), 37 deletions(-) diff --git a/aiida/backends/utils.py b/aiida/backends/utils.py index 234412e1f1..549b3e2878 100644 --- a/aiida/backends/utils.py +++ b/aiida/backends/utils.py @@ -14,41 +14,6 @@ AIIDA_ATTRIBUTE_SEP = '.' -def create_sqlalchemy_engine(profile, **kwargs): - """Create SQLAlchemy engine (to be used for QueryBuilder queries) - - :param kwargs: keyword arguments that will be passed on to `sqlalchemy.create_engine`. - See https://docs.sqlalchemy.org/en/13/core/engines.html?highlight=create_engine#sqlalchemy.create_engine for - more info. 
- """ - from sqlalchemy import create_engine - - from aiida.common import json - - # The hostname may be `None`, which is a valid value in the case of peer authentication for example. In this case - # it should be converted to an empty string, because otherwise the `None` will be converted to string literal "None" - hostname = profile.database_hostname or '' - separator = ':' if profile.database_port else '' - - engine_url = 'postgresql://{user}:{password}@{hostname}{separator}{port}/{name}'.format( - separator=separator, - user=profile.database_username, - password=profile.database_password, - hostname=hostname, - port=profile.database_port, - name=profile.database_name - ) - return create_engine( - engine_url, json_serializer=json.dumps, json_deserializer=json.loads, future=True, encoding='utf-8', **kwargs - ) - - -def create_scoped_session_factory(engine, **kwargs): - """Create scoped SQLAlchemy session factory""" - from sqlalchemy.orm import scoped_session, sessionmaker - return scoped_session(sessionmaker(bind=engine, future=True, **kwargs)) - - def delete_nodes_and_connections(pks): """Backend-agnostic function to delete Nodes and connections""" if configuration.PROFILE.database_backend == BACKEND_DJANGO: diff --git a/aiida/orm/implementation/sql/backends.py b/aiida/orm/implementation/sql/backends.py index 81ddb64a2e..e49e085cb6 100644 --- a/aiida/orm/implementation/sql/backends.py +++ b/aiida/orm/implementation/sql/backends.py @@ -11,9 +11,12 @@ import abc from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar +from sqlalchemy import create_engine from sqlalchemy.future.engine import Engine from sqlalchemy.orm import Session +from aiida.common import json + from .. import backends if TYPE_CHECKING: @@ -25,6 +28,31 @@ ModelType = TypeVar('ModelType') # pylint: disable=invalid-name +def create_sqlalchemy_engine(profile, **kwargs): + """Create SQLAlchemy engine. + + :param kwargs: keyword arguments that will be passed on to `sqlalchemy.create_engine`. + See https://docs.sqlalchemy.org/en/13/core/engines.html?highlight=create_engine#sqlalchemy.create_engine for + more info. + """ + # The hostname may be `None`, which is a valid value in the case of peer authentication for example. In this case + # it should be converted to an empty string, because otherwise the `None` will be converted to string literal "None" + hostname = profile.database_hostname or '' + separator = ':' if profile.database_port else '' + + engine_url = 'postgresql://{user}:{password}@{hostname}{separator}{port}/{name}'.format( + separator=separator, + user=profile.database_username, + password=profile.database_password, + hostname=hostname, + port=profile.database_port, + name=profile.database_name + ) + return create_engine( + engine_url, json_serializer=json.dumps, json_deserializer=json.loads, future=True, encoding='utf-8', **kwargs + ) + + class SqlBackend(Generic[ModelType], backends.Backend): """ A class for SQL based backends. 
Assumptions are that: @@ -48,8 +76,6 @@ def get_session(self, **kwargs: Any) -> 'Session': :param kwargs: keyword arguments to be passed to the engine """ - from aiida.backends.utils import create_sqlalchemy_engine - if self._engine is None: self._engine = create_sqlalchemy_engine(self._profile, **kwargs) if self._session is None: From de12c96c441c0753431cae62fe23557c2a18cb9a Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Fri, 15 Oct 2021 08:43:11 +0200 Subject: [PATCH 11/16] remove unused attributes --- aiida/manage/manager.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/aiida/manage/manager.py b/aiida/manage/manager.py index e8047b307f..32c269a5ab 100644 --- a/aiida/manage/manager.py +++ b/aiida/manage/manager.py @@ -45,9 +45,7 @@ class Manager: def __init__(self) -> None: self._backend: Optional['Backend'] = None - self._config: Optional['Config'] = None self._daemon_client: Optional['DaemonClient'] = None - self._profile: Optional['Profile'] = None self._communicator: Optional['RmqThreadCommunicator'] = None self._process_controller: Optional['RemoteProcessThreadController'] = None self._persister: Optional['AiiDAPersister'] = None @@ -63,8 +61,6 @@ def close(self) -> None: self._backend.close() self._backend = None - self._config = None - self._profile = None self._communicator = None self._daemon_client = None self._process_controller = None From f66a9504381c1d2ee727431909ef7d37af0926a0 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Tue, 19 Oct 2021 14:30:54 +0200 Subject: [PATCH 12/16] fix import regression --- aiida/orm/implementation/django/backend.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/aiida/orm/implementation/django/backend.py b/aiida/orm/implementation/django/backend.py index c1175c284c..9ae1ef14c4 100644 --- a/aiida/orm/implementation/django/backend.py +++ b/aiida/orm/implementation/django/backend.py @@ -19,7 +19,6 @@ from django.db import models from django.db import transaction as django_transaction -from aiida.backends.djsite.db import models as dbm from aiida.backends.djsite.manager import DjangoBackendManager from aiida.common.exceptions import IntegrityError from aiida.orm.entities import EntityTypes @@ -117,6 +116,8 @@ def _get_model_from_entity(entity_type: EntityTypes, with_pk: bool): """ from sqlalchemy import inspect + from aiida.backends.djsite.db import models as dbm + model = { EntityTypes.AUTHINFO: dbm.DbAuthInfo, EntityTypes.COMMENT: dbm.DbComment, @@ -134,6 +135,7 @@ def _get_model_from_entity(entity_type: EntityTypes, with_pk: bool): return model, keys def bulk_insert(self, entity_type: EntityTypes, rows: List[dict], allow_defaults: bool = False) -> List[int]: + from aiida.backends.djsite.db import models as dbm model, keys = self._get_model_from_entity(entity_type, False) if allow_defaults: for row in rows: @@ -178,6 +180,7 @@ def bulk_update(self, entity_type: EntityTypes, rows: List[dict]) -> None: model.objects.bulk_update(objects, fields) def delete_nodes_and_connections(self, pks_to_delete: Sequence[int]) -> None: + from aiida.backends.djsite.db import models as dbm if not self.in_transaction: raise AssertionError('Cannot delete nodes and links outside a transaction') # Delete all links pointing to or from a given node From 7219b8d7f115bf8826fb82749a4aa22ed7ade988 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Tue, 19 Oct 2021 15:00:30 +0200 Subject: [PATCH 13/16] fix linting --- aiida/backends/testbase.py | 2 +- aiida/manage/manager.py | 1 + aiida/orm/implementation/backends.py | 4 +++- 
aiida/orm/implementation/sqlalchemy/backend.py | 11 ++--------- 4 files changed, 7 insertions(+), 11 deletions(-) diff --git a/aiida/backends/testbase.py b/aiida/backends/testbase.py index bbcad0a59f..a65595964c 100644 --- a/aiida/backends/testbase.py +++ b/aiida/backends/testbase.py @@ -17,7 +17,7 @@ from aiida.common.exceptions import ConfigurationError, InternalError, TestsNotAllowedError from aiida.common.lang import classproperty from aiida.manage import configuration -from aiida.manage.manager import get_manager, reset_manager +from aiida.manage.manager import reset_manager if TYPE_CHECKING: from aiida.orm.implementation import Backend diff --git a/aiida/manage/manager.py b/aiida/manage/manager.py index 32c269a5ab..6dd8eb95f8 100644 --- a/aiida/manage/manager.py +++ b/aiida/manage/manager.py @@ -125,6 +125,7 @@ def _load_backend(self, validate_db: bool = True, repository_check: bool = True) backend_type = profile.database_backend + backend_cls: 'Backend' if backend_type == BACKEND_DJANGO: from aiida.orm.implementation.django.backend import DjangoBackend backend_cls = DjangoBackend diff --git a/aiida/orm/implementation/backends.py b/aiida/orm/implementation/backends.py index 200cc12c35..ab8b4bbf2e 100644 --- a/aiida/orm/implementation/backends.py +++ b/aiida/orm/implementation/backends.py @@ -28,9 +28,11 @@ BackendUserCollection, ) + BackendCacheType = weakref.WeakValueDictionary[int, 'Backend'] # pylint: disable=unsubscriptable-object + __all__ = ('Backend',) -_backends = weakref.WeakValueDictionary() +_backends: 'BackendCacheType' = weakref.WeakValueDictionary() """Weak-referencing dictionary of loaded Backends. """ diff --git a/aiida/orm/implementation/sqlalchemy/backend.py b/aiida/orm/implementation/sqlalchemy/backend.py index a732463b32..8189800c93 100644 --- a/aiida/orm/implementation/sqlalchemy/backend.py +++ b/aiida/orm/implementation/sqlalchemy/backend.py @@ -196,11 +196,11 @@ def get_backend_entity(self, model): @contextmanager def cursor(self): + connection = self.get_session().bind.raw_connection() # type: ignore[union-attr] try: - connection = self.get_session().bind.raw_connection() yield connection.cursor() finally: - self._get_connection().close() + connection.close() def execute_raw(self, query): from sqlalchemy import text @@ -215,10 +215,3 @@ def execute_raw(self, query): return None return results - - def _get_connection(self): - """Get the SQLA database connection - - :return: the raw SQLA database connection - """ - return self.get_session().bind.raw_connection() From 8781d0f46b93e7d4dc733b2209e105aa6225101e Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Tue, 19 Oct 2021 15:03:47 +0200 Subject: [PATCH 14/16] fix tests --- tests/orm/implementation/test_backend.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/orm/implementation/test_backend.py b/tests/orm/implementation/test_backend.py index 12a91ada6d..f90b85864b 100644 --- a/tests/orm/implementation/test_backend.py +++ b/tests/orm/implementation/test_backend.py @@ -16,12 +16,11 @@ from aiida.orm.entities import EntityTypes -@pytest.mark.usefixtures('clear_database_before_test') class TestBackend: """Test backend.""" @pytest.fixture(autouse=True) - def init_test(self, backend): + def init_test(self, clear_database_before_test, backend): # pylint: disable=unused-argument """Set up the backend.""" self.backend = backend # pylint: disable=attribute-defined-outside-init From 89d1a78a7f13118e7f7310b5d866ad2073cfee04 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" 
<66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 19 Oct 2021 13:05:58 +0000 Subject: [PATCH 15/16] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/orm/implementation/test_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/orm/implementation/test_backend.py b/tests/orm/implementation/test_backend.py index f90b85864b..10d2b54070 100644 --- a/tests/orm/implementation/test_backend.py +++ b/tests/orm/implementation/test_backend.py @@ -20,7 +20,7 @@ class TestBackend: """Test backend.""" @pytest.fixture(autouse=True) - def init_test(self, clear_database_before_test, backend): # pylint: disable=unused-argument + def init_test(self, clear_database_before_test, backend): # pylint: disable=unused-argument """Set up the backend.""" self.backend = backend # pylint: disable=attribute-defined-outside-init From 89f83f53390ebe212b978a6dedba779c1bab935c Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Tue, 19 Oct 2021 15:14:37 +0200 Subject: [PATCH 16/16] Update testbase.py --- aiida/backends/testbase.py | 1 + 1 file changed, 1 insertion(+) diff --git a/aiida/backends/testbase.py b/aiida/backends/testbase.py index a65595964c..6a6a8f21b2 100644 --- a/aiida/backends/testbase.py +++ b/aiida/backends/testbase.py @@ -71,6 +71,7 @@ def get_backend_class(cls): @classproperty def backend(cls) -> 'Backend': # pylint: disable=no-self-argument,no-self-use """Return the backend instance.""" + from aiida.manage.manager import get_manager return get_manager().get_backend() @classmethod
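
For illustration, a minimal, self-contained sketch of how the weak-reference registry behind close_all_backends() (introduced earlier in this series in aiida/orm/implementation/backends.py) behaves. The DummyBackend class and the close_all() name are hypothetical stand-ins; only the use of weakref.WeakValueDictionary and the integer key-assignment scheme mirror the patch.

import weakref

_registry = weakref.WeakValueDictionary()


class DummyBackend:
    """Stand-in for Backend: registers itself under an increasing integer key."""

    def __init__(self):
        # Mirrors Backend.__init__: pick the next free integer key and register the instance.
        self._hashkey = 1 if not _registry else max(_registry.keys()) + 1
        _registry[self._hashkey] = self
        self.closed = False

    def close(self):
        self.closed = True


def close_all():
    """Close every backend that is still alive (cf. close_all_backends)."""
    # Entries whose backend has been garbage collected vanish from the
    # WeakValueDictionary automatically, so only live backends are closed here.
    for backend in list(_registry.values()):
        backend.close()


live = DummyBackend()
DummyBackend()  # no strong reference kept, so it may already have dropped out of the registry
close_all()
assert live.closed
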