From 09cb6cb70df024143ac5a3720e4034b609536b02 Mon Sep 17 00:00:00 2001
From: Sebastiaan Huber
Date: Thu, 19 Dec 2019 19:24:53 +0100
Subject: [PATCH] Implement `Backend.get_session` to retrieve scoped session

The scoped session is an instance of SqlAlchemy's `Session` class that is used
by the query builder to connect to the database, both for the SqlAlchemy
database backend and for Django. Both database backends need to maintain their
own scoped session factory, which can be called to get a session instance.

Certain applications need access to the session. For example, applications
that run AiiDA in a threaded way, such as a REST API server, need to manually
close the session after the query has finished, because this is not done
automatically when the thread ends. The associated database connection remains
open, causing an eventual timeout when a new request comes in. The method
`Backend.get_session` provides an official API to access the global scoped
session instance, which can then be closed.

Additionally, a lot of code that was duplicated across the two implementations
of the `QueryBuilder` for the two database backends has been moved to the
abstract `BackendQueryBuilder`. Normally this code does indeed belong in the
implementations, but since the current implementation for both backends is
based on SqlAlchemy, they are nearly identical. When, in the future, a new
backend is implemented that does not use SqlAlchemy, the current code can be
factored out to a specific `SqlAlchemyQueryBuilder` that can be used for both
database backends.
---
 aiida/backends/djsite/__init__.py              |  54 +++++
 aiida/backends/djsite/manager.py               |   2 +-
 aiida/backends/djsite/queries.py               |  13 +-
 aiida/backends/sqlalchemy/__init__.py          |  83 ++++---
 aiida/backends/sqlalchemy/manager.py           |   9 +-
 aiida/backends/utils.py                        |  21 ++
 aiida/orm/implementation/backends.py           |   7 +
 aiida/orm/implementation/django/backend.py     |  13 +-
 .../orm/implementation/django/querybuilder.py  | 219 +-----------------
 aiida/orm/implementation/querybuilder.py       | 154 ++++++++++--
 .../orm/implementation/sqlalchemy/backend.py   |  22 +-
 .../implementation/sqlalchemy/querybuilder.py  | 177 +-------------
 aiida/orm/querybuilder.py                      |   2 +-
 .../migrations/test_migrations_common.py       |   5 +-
 14 files changed, 314 insertions(+), 467 deletions(-)

diff --git a/aiida/backends/djsite/__init__.py b/aiida/backends/djsite/__init__.py
index 2776a55f97..c4b438964f 100644
--- a/aiida/backends/djsite/__init__.py
+++ b/aiida/backends/djsite/__init__.py
@@ -7,3 +7,57 @@
 # For further information on the license, see the LICENSE.txt file          #
 # For further information please visit http://www.aiida.net                 #
 ###########################################################################
+# pylint: disable=global-statement
+"""Module with implementation of the database backend using Django."""
+from aiida.backends.utils import create_sqlalchemy_engine, create_scoped_session_factory
+
+ENGINE = None
+SESSION_FACTORY = None
+
+
+def reset_session():
+    """Reset the session which means setting the global engine and session factory instances to `None`."""
+    global ENGINE
+    global SESSION_FACTORY
+
+    if ENGINE is not None:
+        ENGINE.dispose()
+
+    if SESSION_FACTORY is not None:
+        SESSION_FACTORY.expunge_all()  # pylint: disable=no-member
+        SESSION_FACTORY.close()  # pylint: disable=no-member
+
+    ENGINE = None
+    SESSION_FACTORY = None
+
+
+def get_scoped_session():
+    """Return a scoped session for the given profile that is exclusively to be used for the `QueryBuilder`.
+ + Since the `QueryBuilder` implementation uses SqlAlchemy to map the query onto the models in order to generate the + SQL to be sent to the database, it requires a session, which is an :class:`sqlalchemy.orm.session.Session` instance. + The only purpose is for SqlAlchemy to be able to connect to the database perform the query and retrieve the results. + Even the Django backend implementation will use SqlAlchemy for its `QueryBuilder` and so also needs an SqlA session. + It is important that we do not reuse the scoped session factory in the SqlAlchemy implementation, because that runs + the risk of cross-talk once profiles can be switched dynamically in a single python interpreter. Therefore the + Django implementation of the `QueryBuilder` should keep its own SqlAlchemy engine and scoped session factory + instances that are used to provide the query builder with a session. + + :param profile: :class:`aiida.manage.configuration.profile.Profile` for which to configure the engine. + :return: :class:`sqlalchemy.orm.session.Session` instance with engine configured for the given profile. + """ + from aiida.manage.configuration import get_profile + + global ENGINE + global SESSION_FACTORY + + if SESSION_FACTORY is not None: + session = SESSION_FACTORY() + return session + + if ENGINE is None: + ENGINE = create_sqlalchemy_engine(get_profile()) + + SESSION_FACTORY = create_scoped_session_factory(ENGINE) + + return SESSION_FACTORY() diff --git a/aiida/backends/djsite/manager.py b/aiida/backends/djsite/manager.py index c9b765fc99..f1003a4318 100644 --- a/aiida/backends/djsite/manager.py +++ b/aiida/backends/djsite/manager.py @@ -32,7 +32,7 @@ def _load_backend_environment(self): def reset_backend_environment(self): """Reset the backend environment.""" - from aiida.orm.implementation.django.querybuilder import reset_session + from . 
import reset_session reset_session() def is_database_schema_ahead(self): diff --git a/aiida/backends/djsite/queries.py b/aiida/backends/djsite/queries.py index 61c6622af7..209d646306 100644 --- a/aiida/backends/djsite/queries.py +++ b/aiida/backends/djsite/queries.py @@ -39,16 +39,19 @@ def get_creation_statistics(self, user_pk=None): # pylint: disable=no-member import sqlalchemy as sa import aiida.backends.djsite.db.models as djmodels - from aiida.orm.implementation.django.querybuilder import DjangoQueryBuilder + from aiida.manage.manager import get_manager + backend = get_manager().get_backend() # Get the session (uses internally aldjemy - so, sqlalchemy) also for the Djsite backend - sssn = DjangoQueryBuilder.get_session() + session = backend.get_session() retdict = {} - total_query = sssn.query(djmodels.DbNode.sa) - types_query = sssn.query(djmodels.DbNode.sa.node_type.label('typestring'), sa.func.count(djmodels.DbNode.sa.id)) - stat_query = sssn.query( + total_query = session.query(djmodels.DbNode.sa) + types_query = session.query( + djmodels.DbNode.sa.node_type.label('typestring'), sa.func.count(djmodels.DbNode.sa.id) + ) + stat_query = session.query( sa.func.date_trunc('day', djmodels.DbNode.sa.ctime).label('cday'), sa.func.count(djmodels.DbNode.sa.id) ) diff --git a/aiida/backends/sqlalchemy/__init__.py b/aiida/backends/sqlalchemy/__init__.py index 5d3e284179..0711f03118 100644 --- a/aiida/backends/sqlalchemy/__init__.py +++ b/aiida/backends/sqlalchemy/__init__.py @@ -9,13 +9,26 @@ ########################################################################### # pylint: disable=import-error,no-name-in-module,global-statement """Module with implementation of the database backend using SqlAlchemy.""" -from sqlalchemy import create_engine -from sqlalchemy.orm import scoped_session -from sqlalchemy.orm import sessionmaker +from aiida.backends.utils import create_sqlalchemy_engine, create_scoped_session_factory -# The next two serve as global variables, set in the `load_dbenv` call and should be properly reset upon forking. ENGINE = None -SCOPED_SESSION_CLASS = None +SESSION_FACTORY = None + + +def reset_session(): + """Reset the session which means setting the global engine and session factory instances to `None`.""" + global ENGINE + global SESSION_FACTORY + + if ENGINE is not None: + ENGINE.dispose() + + if SESSION_FACTORY is not None: + SESSION_FACTORY.expunge_all() # pylint: disable=no-member + SESSION_FACTORY.close() # pylint: disable=no-member + + ENGINE = None + SESSION_FACTORY = None def get_scoped_session(): @@ -24,55 +37,37 @@ def get_scoped_session(): According to SQLAlchemy docs, this returns always the same object within a thread, and a different object in a different thread. Moreover, since we update the session class upon forking, different session objects will be used. """ - global SCOPED_SESSION_CLASS - - if SCOPED_SESSION_CLASS is None: - reset_session() + from multiprocessing.util import register_after_fork + from aiida.manage.configuration import get_profile - return SCOPED_SESSION_CLASS() + global ENGINE + global SESSION_FACTORY + if SESSION_FACTORY is not None: + session = SESSION_FACTORY() + return session -def recreate_after_fork(engine): # pylint: disable=unused-argument - """Callback called after a fork. + if ENGINE is None: + ENGINE = create_sqlalchemy_engine(get_profile()) - Not only disposes the engine, but also recreates a new scoped session to use independent sessions in the fork. 
+ SESSION_FACTORY = create_scoped_session_factory(ENGINE, expire_on_commit=True) + register_after_fork(ENGINE, recreate_after_fork) - :param engine: the engine that will be used by the sessionmaker - """ - global ENGINE - global SCOPED_SESSION_CLASS + return SESSION_FACTORY() - ENGINE.dispose() - SCOPED_SESSION_CLASS = scoped_session(sessionmaker(bind=ENGINE, expire_on_commit=True)) +def recreate_after_fork(engine): + """Callback called after a fork. -def reset_session(profile=None): - """ - Resets (global) engine and sessionmaker classes, to create a new one - (or creates a new one from scratch if not already available) + Not only disposes the engine, but also recreates a new scoped session to use independent sessions in the fork. - :param profile: the profile whose configuration to use to connect to the database + :param engine: the engine that will be used by the sessionmaker """ - from multiprocessing.util import register_after_fork - from aiida.common import json from aiida.manage.configuration import get_profile global ENGINE - global SCOPED_SESSION_CLASS - - if profile is None: - profile = get_profile() - - separator = ':' if profile.database_port else '' - engine_url = 'postgresql://{user}:{password}@{hostname}{separator}{port}/{name}'.format( - separator=separator, - user=profile.database_username, - password=profile.database_password, - hostname=profile.database_hostname, - port=profile.database_port, - name=profile.database_name - ) - - ENGINE = create_engine(engine_url, json_serializer=json.dumps, json_deserializer=json.loads, encoding='utf-8') - SCOPED_SESSION_CLASS = scoped_session(sessionmaker(bind=ENGINE, expire_on_commit=True)) - register_after_fork(ENGINE, recreate_after_fork) + global SESSION_FACTORY + + engine.dispose() + ENGINE = create_sqlalchemy_engine(get_profile()) + SESSION_FACTORY = create_scoped_session_factory(ENGINE, expire_on_commit=True) diff --git a/aiida/backends/sqlalchemy/manager.py b/aiida/backends/sqlalchemy/manager.py index be6d5bf1b3..83819858fa 100644 --- a/aiida/backends/sqlalchemy/manager.py +++ b/aiida/backends/sqlalchemy/manager.py @@ -54,15 +54,12 @@ def get_settings_manager(self): def _load_backend_environment(self): """Load the backend environment.""" - from . import reset_session - reset_session() + get_scoped_session() def reset_backend_environment(self): """Reset the backend environment.""" - from aiida.backends import sqlalchemy - if sqlalchemy.ENGINE is not None: - sqlalchemy.ENGINE.dispose() - sqlalchemy.SCOPED_SESSION_CLASS = None + from . import reset_session + reset_session() def is_database_schema_ahead(self): """Determine whether the database schema version is ahead of the code schema version. diff --git a/aiida/backends/utils.py b/aiida/backends/utils.py index 19d2d99da9..3fc9362f4e 100644 --- a/aiida/backends/utils.py +++ b/aiida/backends/utils.py @@ -13,6 +13,27 @@ AIIDA_ATTRIBUTE_SEP = '.' 
+def create_sqlalchemy_engine(profile): + from sqlalchemy import create_engine + from aiida.common import json + + separator = ':' if profile.database_port else '' + engine_url = 'postgresql://{user}:{password}@{hostname}{separator}{port}/{name}'.format( + separator=separator, + user=profile.database_username, + password=profile.database_password, + hostname=profile.database_hostname, + port=profile.database_port, + name=profile.database_name + ) + return create_engine(engine_url, json_serializer=json.dumps, json_deserializer=json.loads, encoding='utf-8') + + +def create_scoped_session_factory(engine, **kwargs): + from sqlalchemy.orm import scoped_session, sessionmaker + return scoped_session(sessionmaker(bind=engine, **kwargs)) + + def delete_nodes_and_connections(pks): if configuration.PROFILE.database_backend == BACKEND_DJANGO: from aiida.backends.djsite.utils import delete_nodes_and_connections_django as delete_nodes_backend diff --git a/aiida/orm/implementation/backends.py b/aiida/orm/implementation/backends.py index ca5423b17d..6b1c4025be 100644 --- a/aiida/orm/implementation/backends.py +++ b/aiida/orm/implementation/backends.py @@ -114,6 +114,13 @@ def transaction(self): :return: a context manager to group database operations """ + @abc.abstractmethod + def get_session(self): + """Return a database session that can be used by the `QueryBuilder` to perform its query. + + :return: an instance of :class:`sqlalchemy.orm.session.Session` + """ + class BackendEntity(abc.ABC): """An first-class entity in the backend""" diff --git a/aiida/orm/implementation/django/backend.py b/aiida/orm/implementation/django/backend.py index b4ca0fbacd..fbf9bb0ad5 100644 --- a/aiida/orm/implementation/django/backend.py +++ b/aiida/orm/implementation/django/backend.py @@ -8,7 +8,6 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Django implementation of `aiida.orm.implementation.backends.Backend`.""" - from contextlib import contextmanager # pylint: disable=import-error,no-name-in-module @@ -89,6 +88,18 @@ def transaction(): """Open a transaction to be used as a context manager.""" return transaction.atomic() + @staticmethod + def get_session(): + """Return a database session that can be used by the `QueryBuilder` to perform its query. + + If there is an exception within the context then the changes will be rolled back and the state will + be as before entering. Transactions can be nested. 
+ + :return: an instance of :class:`sqlalchemy.orm.session.Session` + """ + from aiida.backends.djsite import get_scoped_session + return get_scoped_session() + # Below are abstract methods inherited from `aiida.orm.implementation.sql.backends.SqlBackend` def get_backend_entity(self, model): diff --git a/aiida/orm/implementation/django/querybuilder.py b/aiida/orm/implementation/django/querybuilder.py index 69ae115f4c..d6b2f1876e 100644 --- a/aiida/orm/implementation/django/querybuilder.py +++ b/aiida/orm/implementation/django/querybuilder.py @@ -8,77 +8,19 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Django query builder""" -import uuid - from aldjemy import core # Remove when https://github.com/PyCQA/pylint/issues/1931 is fixed # pylint: disable=no-name-in-module, import-error -from sqlalchemy_utils.types.choice import Choice from sqlalchemy import and_, or_, not_, case from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.ext.compiler import compiles from sqlalchemy.sql.expression import FunctionElement -from sqlalchemy.types import Integer, Float, Boolean, DateTime +from sqlalchemy.types import Float, Boolean from aiida.backends.djsite.db import models from aiida.common.exceptions import InputValidationError from aiida.orm.implementation.querybuilder import BackendQueryBuilder -ENGINE = None -SESSION_FACTORY = None - - -def reset_session(): - """Reset the session which means setting the global engine and session factory instances to `None`.""" - global ENGINE # pylint: disable=global-statement - global SESSION_FACTORY # pylint: disable=global-statement - - ENGINE = None - SESSION_FACTORY = None - - -def get_scoped_session(profile): - """Return a scoped session for the given profile that is exclusively to be used for the `QueryBuilder`. - - Since the `QueryBuilder` implementation uses SqlAlchemy to map the query onto the models in order to generate the - SQL to be sent to the database, it requires a session, which is an :class:`sqlalchemy.orm.session.Session` instance. - The only purpose is for SqlAlchemy to be able to connect to the database perform the query and retrieve the results. - Even the Django backend implementation will use SqlAlchemy for its `QueryBuilder` and so also needs an SqlA session. - It is important that we do not reuse the scoped session factory in the SqlAlchemy implementation, because that runs - the risk of cross-talk once profiles can be switched dynamically in a single python interpreter. Therefore the - Django implementation of the `QueryBuilder` should keep its own SqlAlchemy engine and scoped session factory - instances that are used to provide the query builder with a session. - - :param profile: :class:`aiida.manage.configuration.profile.Profile` for which to configure the engine. - :return: :class:`sqlalchemy.orm.session.Session` instance with engine configured for the given profile. 
- """ - from sqlalchemy import create_engine - from sqlalchemy.orm import scoped_session, sessionmaker - - global ENGINE # pylint: disable=global-statement - global SESSION_FACTORY # pylint: disable=global-statement - - if SESSION_FACTORY is not None: - session = SESSION_FACTORY() - return session - - if ENGINE is None: - from aiida.common import json - separator = ':' if profile.database_port else '' - engine_url = 'postgresql://{user}:{password}@{hostname}{separator}{port}/{name}'.format( - separator=separator, - user=profile.database_username, - password=profile.database_password, - hostname=profile.database_hostname, - port=profile.database_port, - name=profile.database_name - ) - ENGINE = create_engine(engine_url, json_serializer=json.dumps, json_deserializer=json.loads, encoding='utf-8') - - SESSION_FACTORY = scoped_session(sessionmaker(bind=ENGINE)) - session = SESSION_FACTORY() - return session - class jsonb_array_length(FunctionElement): # pylint: disable=invalid-name # pylint: disable=too-few-public-methods @@ -303,13 +245,6 @@ def get_filter_expr(self, operator, value, attr_key, is_attribute, alias=None, c return not_(expr) return expr - @staticmethod - def get_session(): - from aiida.manage.configuration import get_config - config = get_config() - profile = config.current_profile - return get_scoped_session(profile) - def modify_expansions(self, alias, expansions): """ For django, there are no additional expansions for now, so @@ -407,154 +342,10 @@ def cast_according_to_type(path_in_json, value): raise InputValidationError('Unknown operator {} for filters in JSON field'.format(operator)) return expr - def get_projectable_attribute(self, alias, column_name, attrpath, cast=None, **kwargs): # pylint: disable=redefined-outer-name - """ - :returns: An attribute store in a JSON field of the give column - """ - entity = self.get_column(column_name, alias)[attrpath] - if cast is None: - pass - elif cast == 'f': - entity = entity.astext.cast(Float) - elif cast == 'i': - entity = entity.astext.cast(Integer) - elif cast == 'b': - entity = entity.astext.cast(Boolean) - elif cast == 't': - entity = entity.astext - elif cast == 'j': - entity = entity.astext.cast(JSONB) - elif cast == 'd': - entity = entity.astext.cast(DateTime) - else: - raise InputValidationError('Unkown casting key {}'.format(cast)) - return entity - - def get_aiida_res(self, key, res): - """ - Some instance returned by ORM (django or SA) need to be converted - to Aiida instances (eg nodes). Choice (sqlalchemy_utils) - will return their value - - :param key: The key - :param res: the result returned by the query - - :returns: an aiida-compatible instance - """ - if isinstance(res, Choice): - result = res.value - elif isinstance(res, uuid.UUID): - result = str(res) - else: - try: - result = self._backend.get_backend_entity(res) - except TypeError: - result = res - - return result - - def yield_per(self, query, batch_size): - """ - :param count: Number of rows to yield per step - - Yields *count* rows at a time - - :returns: a generator - """ - from django.db import transaction - with transaction.atomic(): - return query.yield_per(batch_size) - - def count(self, query): - - from django.db import transaction - with transaction.atomic(): - return query.count() - - def first(self, query): - """ - Executes query in the backend asking for one instance. 
- - :returns: One row of aiida results - """ - from django.db import transaction - with transaction.atomic(): - return query.first() - - def iterall(self, query, batch_size, tag_to_index_dict): - from django.db import transaction - if not tag_to_index_dict: - raise ValueError('Got an empty dictionary: {}'.format(tag_to_index_dict)) - - with transaction.atomic(): - results = query.yield_per(batch_size) - - if len(tag_to_index_dict) == 1: - # Sqlalchemy, for some strange reason, does not return a list of lists - # if you have provided an ormclass - - if list(tag_to_index_dict.values()) == ['*']: - for rowitem in results: - yield [self.get_aiida_res(tag_to_index_dict[0], rowitem)] - else: - for rowitem, in results: - yield [self.get_aiida_res(tag_to_index_dict[0], rowitem)] - elif len(tag_to_index_dict) > 1: - for resultrow in results: - yield [ - self.get_aiida_res(tag_to_index_dict[colindex], rowitem) - for colindex, rowitem in enumerate(resultrow) - ] - - def iterdict(self, query, batch_size, tag_to_projected_properties_dict, tag_to_alias_map): - from django.db import transaction - - def get_table_name(aliased_class): - """ Returns the table name given an Aliased class based on Aldjemy""" - return aliased_class._aliased_insp._target.table.name # pylint: disable=protected-access - - nr_items = sum(len(v) for v in tag_to_projected_properties_dict.values()) - - if not nr_items: - raise ValueError('Got an empty dictionary') - - # Wrapping everything in an atomic transaction: - with transaction.atomic(): - results = query.yield_per(batch_size) - # Two cases: If one column was asked, the database returns a matrix of rows * columns: - if nr_items > 1: - for this_result in results: - yield { - tag: { - self.get_corresponding_property( - get_table_name(tag_to_alias_map[tag]), attrkey, self.inner_to_outer_schema - ): self.get_aiida_res(attrkey, this_result[index_in_sql_result]) - for attrkey, index_in_sql_result in projected_entities_dict.items() - } for tag, projected_entities_dict in tag_to_projected_properties_dict.items() - } - elif nr_items == 1: - # I this case, sql returns a list, where each listitem is the result - # for one row. 
Here I am converting it to a list of lists (of length 1) - if [v for entityd in tag_to_projected_properties_dict.values() for v in entityd.keys()] == ['*']: - for this_result in results: - yield { - tag: { - self.get_corresponding_property( - get_table_name(tag_to_alias_map[tag]), attrkey, self.inner_to_outer_schema - ): self.get_aiida_res(attrkey, this_result) - for attrkey, position in projected_entities_dict.items() - } for tag, projected_entities_dict in tag_to_projected_properties_dict.items() - } - else: - for this_result, in results: - yield { - tag: { - self.get_corresponding_property( - get_table_name(tag_to_alias_map[tag]), attrkey, self.inner_to_outer_schema - ): self.get_aiida_res(attrkey, this_result) - for attrkey, position in projected_entities_dict.items() - } for tag, projected_entities_dict in tag_to_projected_properties_dict.items() - } + @staticmethod + def get_table_name(aliased_class): + """Returns the table name given an Aliased class based on Aldjemy""" + return aliased_class._aliased_insp._target.table.name # pylint: disable=protected-access def get_column_names(self, alias): """ diff --git a/aiida/orm/implementation/querybuilder.py b/aiida/orm/implementation/querybuilder.py index 9463ac4103..6ef8f33342 100644 --- a/aiida/orm/implementation/querybuilder.py +++ b/aiida/orm/implementation/querybuilder.py @@ -7,8 +7,21 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -"""Backend query implementation classes""" +"""Abstract `QueryBuilder` definition. + +Note that this abstract class actually contains parts of the implementation, which are tightly coupled to SqlAlchemy. +This is done because currently, both database backend implementations, both Django and SqlAlchemy, directly use the +SqlAlchemy library to implement the query builder. If there ever is another database backend to be implemented that does +not go through SqlAlchemy, this class will have to be refactored. The SqlAlchemy specific implementations should most +likely be moved to a `SqlAlchemyBasedQueryBuilder` class and restore this abstract class to being a pure agnostic one. 
+""" import abc +import uuid + +# pylint: disable=no-name-in-module, import-error +from sqlalchemy_utils.types.choice import Choice +from sqlalchemy.types import Integer, Float, Boolean, DateTime +from sqlalchemy.dialects.postgresql import JSONB from aiida.common import exceptions from aiida.common.lang import abstractclassmethod, type_check @@ -100,11 +113,11 @@ def AiidaNode(self): from aiida.orm import Node return Node - @abc.abstractmethod def get_session(self): """ - :returns: a valid session, an instance of sqlalchemy.orm.session.Session + :returns: a valid session, an instance of :class:`sqlalchemy.orm.session.Session` """ + return self._backend.get_session() @abc.abstractmethod def modify_expansions(self, alias, expansions): @@ -210,23 +223,51 @@ def get_filter_expr_from_column(cls, operator, value, column): raise InputValidationError('Unknown operator {} for filters on columns'.format(operator)) return expr - @abc.abstractmethod def get_projectable_attribute(self, alias, column_name, attrpath, cast=None, **kwargs): - pass + """ + :returns: An attribute store in a JSON field of the give column + """ + # pylint: disable=unused-argument + entity = self.get_column(column_name, alias)[attrpath] + if cast is None: + pass + elif cast == 'f': + entity = entity.astext.cast(Float) + elif cast == 'i': + entity = entity.astext.cast(Integer) + elif cast == 'b': + entity = entity.astext.cast(Boolean) + elif cast == 't': + entity = entity.astext + elif cast == 'j': + entity = entity.astext.cast(JSONB) + elif cast == 'd': + entity = entity.astext.cast(DateTime) + else: + raise InputValidationError('Unkown casting key {}'.format(cast)) + return entity - @abc.abstractmethod - def get_aiida_res(self, key, res): + def get_aiida_res(self, res): """ Some instance returned by ORM (django or SA) need to be converted - to Aiida instances (eg nodes) + to AiiDA instances (eg nodes). Choice (sqlalchemy_utils) + will return their value - :param key: the key that this entry would be returned with :param res: the result returned by the query :returns: an aiida-compatible instance """ + if isinstance(res, Choice): + return res.value + + if isinstance(res, uuid.UUID): + return str(res) + + try: + return self._backend.get_backend_entity(res) + except TypeError: + return res - @abc.abstractmethod def yield_per(self, query, batch_size): """ :param int batch_size: Number of rows to yield per step @@ -235,32 +276,117 @@ def yield_per(self, query, batch_size): :returns: a generator """ + try: + return query.yield_per(batch_size) + except Exception: + self.get_session().close() + raise - @abc.abstractmethod def count(self, query): """ :returns: the number of results """ + try: + return query.count() + except Exception: + self.get_session().close() + raise - @abc.abstractmethod def first(self, query): """ Executes query in the backend asking for one instance. :returns: One row of aiida results """ + try: + return query.first() + except Exception: + self.get_session().close() + raise - @abc.abstractmethod def iterall(self, query, batch_size, tag_to_index_dict): """ :return: An iterator over all the results of a list of lists. 
""" + try: + if not tag_to_index_dict: + raise Exception('Got an empty dictionary: {}'.format(tag_to_index_dict)) + + results = query.yield_per(batch_size) + + if len(tag_to_index_dict) == 1: + # Sqlalchemy, for some strange reason, does not return a list of lsits + # if you have provided an ormclass + + if list(tag_to_index_dict.values()) == ['*']: + for rowitem in results: + yield [self.get_aiida_res(rowitem)] + else: + for rowitem, in results: + yield [self.get_aiida_res(rowitem)] + elif len(tag_to_index_dict) > 1: + for resultrow in results: + yield [self.get_aiida_res(rowitem) for colindex, rowitem in enumerate(resultrow)] + else: + raise ValueError('Got an empty dictionary') + except Exception: + self.get_session().close() + raise - @abc.abstractmethod def iterdict(self, query, batch_size, tag_to_projected_properties_dict, tag_to_alias_map): """ :returns: An iterator over all the results of a list of dictionaries. """ + try: + nr_items = sum(len(v) for v in tag_to_projected_properties_dict.values()) + + if not nr_items: + raise ValueError('Got an empty dictionary') + + results = query.yield_per(batch_size) + if nr_items > 1: + for this_result in results: + yield { + tag: { + self.get_corresponding_property( + self.get_table_name(tag_to_alias_map[tag]), attrkey, self.inner_to_outer_schema + ): self.get_aiida_res(this_result[index_in_sql_result]) + for attrkey, index_in_sql_result in projected_entities_dict.items() + } for tag, projected_entities_dict in tag_to_projected_properties_dict.items() + } + elif nr_items == 1: + # I this case, sql returns a list, where each listitem is the result + # for one row. Here I am converting it to a list of lists (of length 1) + if [v for entityd in tag_to_projected_properties_dict.values() for v in entityd.keys()] == ['*']: + for this_result in results: + yield { + tag: { + self.get_corresponding_property( + self.get_table_name(tag_to_alias_map[tag]), attrkey, self.inner_to_outer_schema + ): self.get_aiida_res(this_result) + for attrkey, position in projected_entities_dict.items() + } for tag, projected_entities_dict in tag_to_projected_properties_dict.items() + } + else: + for this_result, in results: + yield { + tag: { + self.get_corresponding_property( + self.get_table_name(tag_to_alias_map[tag]), attrkey, self.inner_to_outer_schema + ): self.get_aiida_res(this_result) + for attrkey, position in projected_entities_dict.items() + } for tag, projected_entities_dict in tag_to_projected_properties_dict.items() + } + else: + raise ValueError('Got an empty dictionary') + except Exception: + self.get_session().close() + raise + + @staticmethod + @abstractclassmethod + def get_table_name(aliased_class): + """Returns the table name given an Aliased class.""" @abc.abstractmethod def get_column_names(self, alias): diff --git a/aiida/orm/implementation/sqlalchemy/backend.py b/aiida/orm/implementation/sqlalchemy/backend.py index 41ccf1f5cf..224d4933df 100644 --- a/aiida/orm/implementation/sqlalchemy/backend.py +++ b/aiida/orm/implementation/sqlalchemy/backend.py @@ -8,10 +8,8 @@ # For further information please visit http://www.aiida.net # ########################################################################### """SqlAlchemy implementation of `aiida.orm.implementation.backends.Backend`.""" - from contextlib import contextmanager -from aiida.backends.sqlalchemy import get_scoped_session from aiida.backends.sqlalchemy.models import base from aiida.backends.sqlalchemy.queries import SqlaQueryManager from aiida.backends.sqlalchemy.manager import 
SqlaBackendManager @@ -83,11 +81,14 @@ def query(self): def users(self): return self._users - @staticmethod @contextmanager - def transaction(): - """Open a transaction to be used as a context manager.""" - session = get_scoped_session() + def transaction(self): + """Open a transaction to be used as a context manager. + + If there is an exception within the context then the changes will be rolled back and the state will be as before + entering. Transactions can be nested. + """ + session = self.get_session() nested = session.transaction.nested try: session.begin_nested() @@ -101,6 +102,15 @@ def transaction(): # Make sure to commit the outermost session session.commit() + @staticmethod + def get_session(): + """Return a database session that can be used by the `QueryBuilder` to perform its query. + + :return: an instance of :class:`sqlalchemy.orm.session.Session` + """ + from aiida.backends.sqlalchemy import get_scoped_session + return get_scoped_session() + # Below are abstract methods inherited from `aiida.orm.implementation.sql.backends.SqlBackend` def get_backend_entity(self, model): diff --git a/aiida/orm/implementation/sqlalchemy/querybuilder.py b/aiida/orm/implementation/sqlalchemy/querybuilder.py index 4d221a0800..8c7d145f65 100644 --- a/aiida/orm/implementation/sqlalchemy/querybuilder.py +++ b/aiida/orm/implementation/sqlalchemy/querybuilder.py @@ -8,18 +8,13 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Sqla query builder implementation""" - -import uuid - # pylint: disable=no-name-in-module, import-error -from sqlalchemy_utils.types.choice import Choice from sqlalchemy import and_, or_, not_ -from sqlalchemy.types import Integer, Float, Boolean, DateTime +from sqlalchemy.types import Float, Boolean from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.sql.expression import case, FunctionElement from sqlalchemy.ext.compiler import compiles -import aiida.backends.sqlalchemy from aiida.common.exceptions import InputValidationError from aiida.common.exceptions import NotExistent from aiida.orm.implementation.querybuilder import BackendQueryBuilder @@ -126,9 +121,6 @@ def table_groups_nodes(self): import aiida.backends.sqlalchemy.models.group return aiida.backends.sqlalchemy.models.group.table_groups_nodes - def get_session(self): - return aiida.backends.sqlalchemy.get_scoped_session() - def modify_expansions(self, alias, expansions): """ In SQLA, the metadata should be changed to _metadata to be in-line with the database schema @@ -374,169 +366,10 @@ def cast_according_to_type(path_in_json, value): raise InputValidationError('Unknown operator {} for filters in JSON field'.format(operator)) return expr - def get_projectable_attribute(self, alias, column_name, attrpath, cast=None, **kwargs): - """ - :returns: An attribute store in a JSON field of the give column - """ - entity = self.get_column(column_name, alias)[attrpath] - if cast is None: - pass - elif cast == 'f': - entity = entity.astext.cast(Float) - elif cast == 'i': - entity = entity.astext.cast(Integer) - elif cast == 'b': - entity = entity.astext.cast(Boolean) - elif cast == 't': - entity = entity.astext - elif cast == 'j': - entity = entity.astext.cast(JSONB) - elif cast == 'd': - entity = entity.astext.cast(DateTime) - else: - raise InputValidationError('Unkown casting key {}'.format(cast)) - return entity - - def get_aiida_res(self, key, res): - """ - Some instance returned by ORM (django or SA) need to be 
converted - to AiiDA instances (eg nodes). Choice (sqlalchemy_utils) - will return their value - - :param key: The key - :param res: the result returned by the query - - :returns: an aiida-compatible instance - """ - if isinstance(res, Choice): - returnval = res.value - elif isinstance(res, uuid.UUID): - returnval = str(res) - else: - try: - returnval = self._backend.get_backend_entity(res) - except TypeError: - returnval = res - - return returnval - - def yield_per(self, query, batch_size): - """ - :param count: Number of rows to yield per step - - Yields *count* rows at a time - - :returns: a generator - """ - try: - return query.yield_per(batch_size) - except Exception: - # exception was raised. Rollback the session - self.get_session().rollback() - raise - - def count(self, query): - try: - return query.count() - except Exception: - # exception was raised. Rollback the session - self.get_session().rollback() - raise - - def first(self, query): - """ - Executes query in the backend asking for one instance. - - :returns: One row of aiida results - """ - try: - return query.first() - except Exception: - # exception was raised. Rollback the session - self.get_session().rollback() - raise - - def iterall(self, query, batch_size, tag_to_index_dict): - if not tag_to_index_dict: - raise Exception('Got an empty dictionary: {}'.format(tag_to_index_dict)) - - try: - results = query.yield_per(batch_size) - - if len(tag_to_index_dict) == 1: - # Sqlalchemy, for some strange reason, does not return a list of lsits - # if you have provided an ormclass - - if list(tag_to_index_dict.values()) == ['*']: - for rowitem in results: - yield [self.get_aiida_res(tag_to_index_dict[0], rowitem)] - else: - for rowitem, in results: - yield [self.get_aiida_res(tag_to_index_dict[0], rowitem)] - elif len(tag_to_index_dict) > 1: - for resultrow in results: - yield [ - self.get_aiida_res(tag_to_index_dict[colindex], rowitem) - for colindex, rowitem in enumerate(resultrow) - ] - else: - raise ValueError('Got an empty dictionary') - except Exception: - self.get_session().rollback() - raise - - def iterdict(self, query, batch_size, tag_to_projected_properties_dict, tag_to_alias_map): - - def get_table_name(aliased_class): - """ Returns the table name given an Aliased class""" - return aliased_class.__tablename__ - - nr_items = sum(len(v) for v in tag_to_projected_properties_dict.values()) - - if not nr_items: - raise ValueError('Got an empty dictionary') - - # Wrapping everything in an atomic transaction: - try: - results = query.yield_per(batch_size) - if nr_items > 1: - for this_result in results: - yield { - tag: { - self.get_corresponding_property( - get_table_name(tag_to_alias_map[tag]), attrkey, self.inner_to_outer_schema - ): self.get_aiida_res(attrkey, this_result[index_in_sql_result]) - for attrkey, index_in_sql_result in projected_entities_dict.items() - } for tag, projected_entities_dict in tag_to_projected_properties_dict.items() - } - elif nr_items == 1: - # I this case, sql returns a list, where each listitem is the result - # for one row. 
Here I am converting it to a list of lists (of length 1) - if [v for entityd in tag_to_projected_properties_dict.values() for v in entityd.keys()] == ['*']: - for this_result in results: - yield { - tag: { - self.get_corresponding_property( - get_table_name(tag_to_alias_map[tag]), attrkey, self.inner_to_outer_schema - ): self.get_aiida_res(attrkey, this_result) - for attrkey, position in projected_entities_dict.items() - } for tag, projected_entities_dict in tag_to_projected_properties_dict.items() - } - else: - for this_result, in results: - yield { - tag: { - self.get_corresponding_property( - get_table_name(tag_to_alias_map[tag]), attrkey, self.inner_to_outer_schema - ): self.get_aiida_res(attrkey, this_result) - for attrkey, position in projected_entities_dict.items() - } for tag, projected_entities_dict in tag_to_projected_properties_dict.items() - } - else: - raise ValueError('Got an empty dictionary') - except Exception: - self.get_session().rollback() - raise + @staticmethod + def get_table_name(aliased_class): + """ Returns the table name given an Aliased class""" + return aliased_class.__tablename__ def get_column_names(self, alias): """ diff --git a/aiida/orm/querybuilder.py b/aiida/orm/querybuilder.py index 9eb0face6a..7f622bed06 100644 --- a/aiida/orm/querybuilder.py +++ b/aiida/orm/querybuilder.py @@ -2067,7 +2067,7 @@ def first(self): raise Exception('length of query result does not match the number of specified projections') return [ - self.get_aiida_entity_res(self._impl.get_aiida_res(self._attrkeys_as_in_sql_result[colindex], rowitem)) + self.get_aiida_entity_res(self._impl.get_aiida_res(rowitem)) for colindex, rowitem in enumerate(result) ] diff --git a/tests/backends/aiida_django/migrations/test_migrations_common.py b/tests/backends/aiida_django/migrations/test_migrations_common.py index f536c2a378..f8de61f9a6 100644 --- a/tests/backends/aiida_django/migrations/test_migrations_common.py +++ b/tests/backends/aiida_django/migrations/test_migrations_common.py @@ -15,7 +15,6 @@ from aiida.backends.testbase import AiidaTestCase from aiida.common.utils import Capturing -from aiida.manage.configuration import get_profile class TestMigrations(AiidaTestCase): @@ -36,7 +35,7 @@ def app(self): def setUp(self): """Go to a specific schema version before running tests.""" - from aiida.orm.implementation.django.querybuilder import get_scoped_session + from aiida.backends.djsite import get_scoped_session from aiida.orm import autogroup self.current_autogroup = autogroup.current_autogroup @@ -52,7 +51,7 @@ def setUp(self): # Before running the migration, make sure we close the querybuilder session which may still contain references # to objects whose mapping may be invalidated after resetting the schema to an older version. This can block # the migrations so we first expunge those objects by closing the session. - get_scoped_session(get_profile()).close() + get_scoped_session().close() # Reverse to the original migration with Capturing():
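
As a usage sketch of the new `Backend.get_session` API: the snippet below shows how a
threaded application, such as the REST API worker mentioned in the commit message, could
explicitly close the scoped session once a query has finished. This is illustrative
application code, not part of the patch; it relies only on `get_manager().get_backend()`
and `QueryBuilder`, both of which appear above.

    from aiida.manage.manager import get_manager
    from aiida.orm import Node, QueryBuilder


    def handle_request():
        """Run a query inside a request thread and release the database connection afterwards."""
        backend = get_manager().get_backend()
        try:
            builder = QueryBuilder()
            builder.append(Node, project=['id'])
            return builder.all()
        finally:
            # The scoped session is not cleaned up automatically when the thread ends,
            # so close it explicitly to return the connection to the pool.
            backend.get_session().close()

Similarly, a minimal sketch of how a backend module is now expected to build its engine
and scoped session factory with the helpers added to `aiida/backends/utils.py`, mirroring
the pattern used in `aiida/backends/djsite/__init__.py` and
`aiida/backends/sqlalchemy/__init__.py` above:

    from aiida.backends.utils import create_scoped_session_factory, create_sqlalchemy_engine
    from aiida.manage.configuration import get_profile

    engine = create_sqlalchemy_engine(get_profile())
    factory = create_scoped_session_factory(engine, expire_on_commit=True)
    session = factory()  # thread-local instance of sqlalchemy.orm.session.Session
    session.close()      # closing the session does not destroy the factory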