From 1b886e2766852bc792fa0cb9719fa83c140758a5 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Thu, 9 Sep 2021 13:04:58 +0200 Subject: [PATCH 01/13] SQLA v2 API: Update declarative_base import --- aiida/backends/sqlalchemy/models/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aiida/backends/sqlalchemy/models/base.py b/aiida/backends/sqlalchemy/models/base.py index 73a7cba6cf..dd7f6ab9ad 100644 --- a/aiida/backends/sqlalchemy/models/base.py +++ b/aiida/backends/sqlalchemy/models/base.py @@ -11,7 +11,7 @@ """Base SQLAlchemy models.""" from sqlalchemy import orm -from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import declarative_base from sqlalchemy.orm.exc import UnmappedClassError import aiida.backends.sqlalchemy From 5001a2f7734008eee927663bf5c0bc585f2ea607 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Thu, 9 Sep 2021 13:24:20 +0200 Subject: [PATCH 02/13] SQLA v2 API: .transaction.nested -> .in_nested_transaction() --- aiida/orm/implementation/sqlalchemy/backend.py | 2 +- aiida/orm/implementation/sqlalchemy/utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/aiida/orm/implementation/sqlalchemy/backend.py b/aiida/orm/implementation/sqlalchemy/backend.py index 3661ee44a7..641c352dcd 100644 --- a/aiida/orm/implementation/sqlalchemy/backend.py +++ b/aiida/orm/implementation/sqlalchemy/backend.py @@ -83,7 +83,7 @@ def transaction(self): entering. Transactions can be nested. """ session = self.get_session() - nested = session.transaction.nested + nested = session.in_nested_transaction() try: session.begin_nested() yield session diff --git a/aiida/orm/implementation/sqlalchemy/utils.py b/aiida/orm/implementation/sqlalchemy/utils.py index 6a1f9f654d..42607c31c4 100644 --- a/aiida/orm/implementation/sqlalchemy/utils.py +++ b/aiida/orm/implementation/sqlalchemy/utils.py @@ -146,7 +146,7 @@ def _in_transaction(): :return: boolean, True if currently in open transaction, False otherwise. 
""" - return get_scoped_session().transaction.nested + return get_scoped_session().in_nested_transaction() @contextlib.contextmanager From a95f54c87d2b7ca0d7042aeb1d1243de0014875a Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Thu, 9 Sep 2021 14:04:12 +0200 Subject: [PATCH 03/13] SQLA v2 API: session.query(O).get(id) -> session.get(O, id) --- aiida/orm/implementation/sqlalchemy/computers.py | 2 +- aiida/orm/implementation/sqlalchemy/groups.py | 2 +- tests/backends/aiida_sqlalchemy/test_session.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/aiida/orm/implementation/sqlalchemy/computers.py b/aiida/orm/implementation/sqlalchemy/computers.py index 30eb2339c0..14e500f05b 100644 --- a/aiida/orm/implementation/sqlalchemy/computers.py +++ b/aiida/orm/implementation/sqlalchemy/computers.py @@ -137,7 +137,7 @@ def list_names(): def delete(self, pk): try: session = get_scoped_session() - session.query(DbComputer).get(pk).delete() + session.get(DbComputer, pk).delete() session.commit() except SQLAlchemyError as exc: raise exceptions.InvalidOperation( diff --git a/aiida/orm/implementation/sqlalchemy/groups.py b/aiida/orm/implementation/sqlalchemy/groups.py index 482f264e95..d6c34e5a9f 100644 --- a/aiida/orm/implementation/sqlalchemy/groups.py +++ b/aiida/orm/implementation/sqlalchemy/groups.py @@ -367,5 +367,5 @@ def query( def delete(self, id): # pylint: disable=redefined-builtin session = sa.get_scoped_session() - session.query(DbGroup).get(id).delete() + session.get(DbGroup, id).delete() session.commit() diff --git a/tests/backends/aiida_sqlalchemy/test_session.py b/tests/backends/aiida_sqlalchemy/test_session.py index c868ea37ed..8707134cf6 100644 --- a/tests/backends/aiida_sqlalchemy/test_session.py +++ b/tests/backends/aiida_sqlalchemy/test_session.py @@ -164,7 +164,7 @@ def test_node_access_with_sessions(self): self.assertIsNot(master_session, custom_session) # Manually load the DbNode in a different session - dbnode_reloaded = 
custom_session.query(sa.models.node.DbNode).get(node.id) + dbnode_reloaded = custom_session.get(sa.models.node.DbNode, node.id) # Now, go through one by one changing the possible attributes (of the model) # and check that they're updated when the user reads them from the aiida node From 4c6d9a9f356336d1091d12006ff8f9df96abb09d Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Thu, 9 Sep 2021 14:22:22 +0200 Subject: [PATCH 04/13] SQLA v2 API: Replace .execute("...") with .execute(text("...")) --- .../migrations/versions/70c7d732f1b2_delete_dbpath.py | 4 ++-- .../versions/ea2f50e7f615_dblog_create_uuid_column.py | 4 ++-- aiida/backends/sqlalchemy/utils.py | 10 +++++++--- aiida/orm/implementation/sqlalchemy/backend.py | 3 ++- tests/backends/aiida_sqlalchemy/test_utils.py | 4 ++-- 5 files changed, 15 insertions(+), 10 deletions(-) diff --git a/aiida/backends/sqlalchemy/migrations/versions/70c7d732f1b2_delete_dbpath.py b/aiida/backends/sqlalchemy/migrations/versions/70c7d732f1b2_delete_dbpath.py index bd0ad4409f..91ba715abd 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/70c7d732f1b2_delete_dbpath.py +++ b/aiida/backends/sqlalchemy/migrations/versions/70c7d732f1b2_delete_dbpath.py @@ -31,8 +31,8 @@ def upgrade(): """Migrations for the upgrade.""" op.drop_table('db_dbpath') conn = op.get_bind() - conn.execute('DROP TRIGGER IF EXISTS autoupdate_tc ON db_dblink') - conn.execute('DROP FUNCTION IF EXISTS update_tc()') + conn.execute(sa.text('DROP TRIGGER IF EXISTS autoupdate_tc ON db_dblink')) + conn.execute(sa.text('DROP FUNCTION IF EXISTS update_tc()')) def downgrade(): diff --git a/aiida/backends/sqlalchemy/migrations/versions/ea2f50e7f615_dblog_create_uuid_column.py b/aiida/backends/sqlalchemy/migrations/versions/ea2f50e7f615_dblog_create_uuid_column.py index 6060e03ef7..b7e4a80fa6 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/ea2f50e7f615_dblog_create_uuid_column.py +++ 
b/aiida/backends/sqlalchemy/migrations/versions/ea2f50e7f615_dblog_create_uuid_column.py @@ -33,7 +33,7 @@ def set_new_uuid(connection): from aiida.common.utils import get_new_uuid # Exit if there are no rows - e.g. initial setup - id_query = connection.execute('SELECT db_dblog.id FROM db_dblog') + id_query = connection.execute(sa.text('SELECT db_dblog.id FROM db_dblog')) if id_query.rowcount == 0: return @@ -52,7 +52,7 @@ def set_new_uuid(connection): UPDATE db_dblog as t SET uuid = uuid(c.uuid) from (values {key_values}) as c(id, uuid) where c.id = t.id""" - connection.execute(update_stm) + connection.execute(sa.text(update_stm)) def upgrade(): diff --git a/aiida/backends/sqlalchemy/utils.py b/aiida/backends/sqlalchemy/utils.py index edb7369ff3..e7f690aa17 100644 --- a/aiida/backends/sqlalchemy/utils.py +++ b/aiida/backends/sqlalchemy/utils.py @@ -60,6 +60,8 @@ def install_tc(session): """ Install the transitive closure table with SqlAlchemy. """ + from sqlalchemy import text + links_table_name = 'db_dblink' links_table_input_field = 'input_id' links_table_output_field = 'output_id' @@ -68,9 +70,11 @@ def install_tc(session): closure_table_child_field = 'child_id' session.execute( - get_pg_tc( - links_table_name, links_table_input_field, links_table_output_field, closure_table_name, - closure_table_parent_field, closure_table_child_field + text( + get_pg_tc( + links_table_name, links_table_input_field, links_table_output_field, closure_table_name, + closure_table_parent_field, closure_table_child_field + ) ) ) diff --git a/aiida/orm/implementation/sqlalchemy/backend.py b/aiida/orm/implementation/sqlalchemy/backend.py index 641c352dcd..0d37d574a1 100644 --- a/aiida/orm/implementation/sqlalchemy/backend.py +++ b/aiida/orm/implementation/sqlalchemy/backend.py @@ -131,10 +131,11 @@ def execute_raw(self, query): :param query: a string containing a raw SQL statement :return: the result of the query """ + from sqlalchemy import text from sqlalchemy.exc import 
ResourceClosedError # pylint: disable=import-error,no-name-in-module with self.transaction() as session: - queryset = session.execute(query) + queryset = session.execute(text(query)) try: results = queryset.fetchall() diff --git a/tests/backends/aiida_sqlalchemy/test_utils.py b/tests/backends/aiida_sqlalchemy/test_utils.py index 1235ae4e4b..16f829ecb3 100644 --- a/tests/backends/aiida_sqlalchemy/test_utils.py +++ b/tests/backends/aiida_sqlalchemy/test_utils.py @@ -58,7 +58,7 @@ def database_exists(url): try: if engine.dialect.name == 'postgresql': - text = f"SELECT 1 FROM pg_database WHERE datname='{database}'" + text = sa.text(f"SELECT 1 FROM pg_database WHERE datname='{database}'") return bool(engine.execute(text).scalar()) raise Exception('Only PostgreSQL is supported.') @@ -98,7 +98,7 @@ def create_database(url, encoding='utf8'): from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT engine.raw_connection().set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT) - text = f"CREATE DATABASE {quote(engine, database)} ENCODING '{encoding}'" + text = sa.text(f"CREATE DATABASE {quote(engine, database)} ENCODING '{encoding}'") engine.execute(text) From 3966e7a10e65a0a267aec802056f75bc1ae5eb71 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Fri, 10 Sep 2021 05:57:03 +0200 Subject: [PATCH 05/13] SQLA v2 API: Remove branched connection context See: https://github.com/sqlalchemy/alembic/issues/908 --- aiida/backends/sqlalchemy/migrations/env.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/aiida/backends/sqlalchemy/migrations/env.py b/aiida/backends/sqlalchemy/migrations/env.py index d148bd54d2..9fa134f0f0 100644 --- a/aiida/backends/sqlalchemy/migrations/env.py +++ b/aiida/backends/sqlalchemy/migrations/env.py @@ -30,20 +30,19 @@ def run_migrations_online(): from aiida.backends.sqlalchemy.models.base import Base config = context.config # pylint: disable=no-member - connectable = config.attributes.get('connection', None) + connection = 
config.attributes.get('connection', None) - if connectable is None: + if connection is None: from aiida.common.exceptions import ConfigurationError raise ConfigurationError('An initialized connection is expected for the AiiDA online migrations.') - with connectable.connect() as connection: - context.configure( # pylint: disable=no-member - connection=connection, - target_metadata=Base.metadata, - transaction_per_migration=True, - ) + context.configure( # pylint: disable=no-member + connection=connection, + target_metadata=Base.metadata, + transaction_per_migration=True, + ) - context.run_migrations() # pylint: disable=no-member + context.run_migrations() # pylint: disable=no-member try: From 625335fa80d9e63f42b5af80e92a5eb7f0c49880 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Fri, 10 Sep 2021 06:29:57 +0200 Subject: [PATCH 06/13] SQLA v2 API: Use `Row._asdict()` for dict access --- .../migrations/versions/041a79fc615f_dblog_cleaning.py | 8 ++++---- tests/backends/aiida_sqlalchemy/test_migrations.py | 10 ++++------ 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/aiida/backends/sqlalchemy/migrations/versions/041a79fc615f_dblog_cleaning.py b/aiida/backends/sqlalchemy/migrations/versions/041a79fc615f_dblog_cleaning.py index 952bed3cac..a05796e0d5 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/041a79fc615f_dblog_cleaning.py +++ b/aiida/backends/sqlalchemy/migrations/versions/041a79fc615f_dblog_cleaning.py @@ -7,7 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=invalid-name,no-member,import-error,no-name-in-module +# pylint: disable=invalid-name,no-member,import-error,no-name-in-module,protected-access """This migration cleans the log records from non-Node entity records.
It removes from the DbLog table the legacy workflow records and records that correspond to an unknown entity and places them to corresponding files. @@ -95,7 +95,7 @@ def get_serialized_legacy_workflow_logs(connection): ) res = list() for row in query: - res.append(dict(list(zip(row.keys(), row)))) + res.append(row._asdict()) return dumps_json(res) @@ -114,7 +114,7 @@ def get_serialized_unknown_entity_logs(connection): ) res = list() for row in query: - res.append(dict(list(zip(row.keys(), row)))) + res.append(row._asdict()) return dumps_json(res) @@ -133,7 +133,7 @@ def get_serialized_logs_with_no_nodes(connection): ) res = list() for row in query: - res.append(dict(list(zip(row.keys(), row)))) + res.append(row._asdict()) return dumps_json(res) diff --git a/tests/backends/aiida_sqlalchemy/test_migrations.py b/tests/backends/aiida_sqlalchemy/test_migrations.py index 81036ece0c..707cac77dc 100644 --- a/tests/backends/aiida_sqlalchemy/test_migrations.py +++ b/tests/backends/aiida_sqlalchemy/test_migrations.py @@ -7,7 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=too-many-lines +# pylint: disable=too-many-lines,protected-access """Tests for the migration engine (Alembic) as well as for the AiiDA migrations for SQLAlchemy.""" from contextlib import contextmanager @@ -779,7 +779,7 @@ def setUpBeforeMigration(self): param_data = session.query(DbLog).filter(DbLog.objpk == param.id ).filter(DbLog.objname == 'something.else.' 
).with_entities(*cols_to_project).one() - serialized_param_data = dumps_json([(dict(list(zip(param_data.keys(), param_data))))]) + serialized_param_data = dumps_json([param_data._asdict()]) # Getting the serialized logs for the unknown entity logs (as the export migration fuction # provides them) - this should coincide to the above serialized_unknown_exp_logs = log_migration.get_serialized_unknown_entity_logs(connection) @@ -792,7 +792,7 @@ def setUpBeforeMigration(self): leg_wf = session.query(DbLog).filter(DbLog.objpk == leg_workf.id).filter( DbLog.objname == 'aiida.workflows.user.topologicalworkflows.topo.TopologicalWorkflow' ).with_entities(*cols_to_project).one() - serialized_leg_wf_logs = dumps_json([(dict(list(zip(leg_wf.keys(), leg_wf))))]) + serialized_leg_wf_logs = dumps_json([leg_wf._asdict()]) # Getting the serialized logs for the legacy workflow logs (as the export migration function # provides them) - this should coincide to the above serialized_leg_wf_exp_logs = log_migration.get_serialized_legacy_workflow_logs(connection) @@ -803,9 +803,7 @@ def setUpBeforeMigration(self): # Getting the serialized logs that don't correspond to a DbNode record logs_no_node = session.query(DbLog).filter( DbLog.id.in_([log_5.id, log_6.id])).with_entities(*cols_to_project) - logs_no_node_list = list() - for log_no_node in logs_no_node: - logs_no_node_list.append((dict(list(zip(log_no_node.keys(), log_no_node))))) + logs_no_node_list = [log_no_node._asdict() for log_no_node in logs_no_node] serialized_logs_no_node = dumps_json(logs_no_node_list) # Getting the serialized logs that don't correspond to a node (as the export migration function From c48d61752ab42cb4249a293a6aeaa937a0ea2cea Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Fri, 10 Sep 2021 06:49:58 +0200 Subject: [PATCH 07/13] SQLA v2 API: Replace Engine.execute() with Connection.execute() --- tests/backends/aiida_sqlalchemy/test_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git 
a/tests/backends/aiida_sqlalchemy/test_utils.py b/tests/backends/aiida_sqlalchemy/test_utils.py index 16f829ecb3..398e1122b5 100644 --- a/tests/backends/aiida_sqlalchemy/test_utils.py +++ b/tests/backends/aiida_sqlalchemy/test_utils.py @@ -59,7 +59,7 @@ def database_exists(url): try: if engine.dialect.name == 'postgresql': text = sa.text(f"SELECT 1 FROM pg_database WHERE datname='{database}'") - return bool(engine.execute(text).scalar()) + return bool(engine.connect().execute(text).scalar()) raise Exception('Only PostgreSQL is supported.') finally: @@ -99,8 +99,8 @@ def create_database(url, encoding='utf8'): engine.raw_connection().set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT) text = sa.text(f"CREATE DATABASE {quote(engine, database)} ENCODING '{encoding}'") - - engine.execute(text) + with engine.begin() as connection: + connection.execute(text) else: raise Exception('Only PostgreSQL with the psycopg2 driver is supported.') From 70a7018b2e9c5c56da277c1c45e915ac0c2bd7cc Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Fri, 10 Sep 2021 07:09:59 +0200 Subject: [PATCH 08/13] SQLA v2 API: Replace use of lists in `select()` --- ...798d4d3_trajectory_symbols_to_attribute.py | 4 +-- .../1b8ed3425af9_remove_legacy_workflows.py | 12 ++++---- .../1feaea71bd5a_migrate_repository.py | 2 +- .../239cea6d2452_provenance_redesign.py | 4 +-- ...84bcc35_delete_trajectory_symbols_array.py | 6 ++-- .../sqlalchemy/querybuilder/joiner.py | 30 +++++++++---------- 6 files changed, 28 insertions(+), 30 deletions(-) diff --git a/aiida/backends/sqlalchemy/migrations/versions/12536798d4d3_trajectory_symbols_to_attribute.py b/aiida/backends/sqlalchemy/migrations/versions/12536798d4d3_trajectory_symbols_to_attribute.py index d53ec44ce3..d42f7d0813 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/12536798d4d3_trajectory_symbols_to_attribute.py +++ b/aiida/backends/sqlalchemy/migrations/versions/12536798d4d3_trajectory_symbols_to_attribute.py @@ -46,7 +46,7 @@ def upgrade(): 
column('attributes', JSONB)) nodes = connection.execute( - select([DbNode.c.id, DbNode.c.uuid]).where( + select(DbNode.c.id, DbNode.c.uuid).where( DbNode.c.type == op.inline_literal('node.data.array.trajectory.TrajectoryData.'))).fetchall() for pk, uuid in nodes: @@ -64,7 +64,7 @@ def downgrade(): column('attributes', JSONB)) nodes = connection.execute( - select([DbNode.c.id, DbNode.c.uuid]).where( + select(DbNode.c.id, DbNode.c.uuid).where( DbNode.c.type == op.inline_literal('node.data.array.trajectory.TrajectoryData.'))).fetchall() for pk, _ in nodes: diff --git a/aiida/backends/sqlalchemy/migrations/versions/1b8ed3425af9_remove_legacy_workflows.py b/aiida/backends/sqlalchemy/migrations/versions/1b8ed3425af9_remove_legacy_workflows.py index 2b0eed82a1..cabb6b487a 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/1b8ed3425af9_remove_legacy_workflows.py +++ b/aiida/backends/sqlalchemy/migrations/versions/1b8ed3425af9_remove_legacy_workflows.py @@ -58,9 +58,9 @@ def export_workflow_data(connection): DbWorkflowData = table('db_dbworkflowdata') DbWorkflowStep = table('db_dbworkflowstep') - count_workflow = connection.execute(select([func.count()]).select_from(DbWorkflow)).scalar() - count_workflow_data = connection.execute(select([func.count()]).select_from(DbWorkflowData)).scalar() - count_workflow_step = connection.execute(select([func.count()]).select_from(DbWorkflowStep)).scalar() + count_workflow = connection.execute(select(func.count()).select_from(DbWorkflow)).scalar() + count_workflow_data = connection.execute(select(func.count()).select_from(DbWorkflowData)).scalar() + count_workflow_step = connection.execute(select(func.count()).select_from(DbWorkflowStep)).scalar() # Nothing to do if all tables are empty if count_workflow == 0 and count_workflow_data == 0 and count_workflow_step == 0: @@ -78,9 +78,9 @@ def export_workflow_data(connection): delete_on_close = configuration.PROFILE.is_test_profile data = { - 'workflow': [dict(row) for row in 
connection.execute(select(['*']).select_from(DbWorkflow))], - 'workflow_data': [dict(row) for row in connection.execute(select(['*']).select_from(DbWorkflowData))], - 'workflow_step': [dict(row) for row in connection.execute(select(['*']).select_from(DbWorkflowStep))], + 'workflow': [dict(row) for row in connection.execute(select('*').select_from(DbWorkflow))], + 'workflow_data': [dict(row) for row in connection.execute(select('*').select_from(DbWorkflowData))], + 'workflow_step': [dict(row) for row in connection.execute(select('*').select_from(DbWorkflowStep))], } with NamedTemporaryFile( diff --git a/aiida/backends/sqlalchemy/migrations/versions/1feaea71bd5a_migrate_repository.py b/aiida/backends/sqlalchemy/migrations/versions/1feaea71bd5a_migrate_repository.py index 24a74a6c19..304f7077de 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/1feaea71bd5a_migrate_repository.py +++ b/aiida/backends/sqlalchemy/migrations/versions/1feaea71bd5a_migrate_repository.py @@ -45,7 +45,7 @@ def upgrade(): ) profile = get_profile() - node_count = connection.execute(select([func.count()]).select_from(DbNode)).scalar() + node_count = connection.execute(select(func.count()).select_from(DbNode)).scalar() missing_repo_folder = [] shard_count = 256 diff --git a/aiida/backends/sqlalchemy/migrations/versions/239cea6d2452_provenance_redesign.py b/aiida/backends/sqlalchemy/migrations/versions/239cea6d2452_provenance_redesign.py index a0ff49e325..33f3edfaef 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/239cea6d2452_provenance_redesign.py +++ b/aiida/backends/sqlalchemy/migrations/versions/239cea6d2452_provenance_redesign.py @@ -41,7 +41,7 @@ def migrate_infer_calculation_entry_point(connection): column('process_type', String) ) - query_set = connection.execute(select([DbNode.c.type]).where(DbNode.c.type.like('calculation.%'))).fetchall() + query_set = connection.execute(select(DbNode.c.type).where(DbNode.c.type.like('calculation.%'))).fetchall() type_strings = 
set(entry[0] for entry in query_set) mapping_node_type_to_entry_point = infer_calculation_entry_point(type_strings=type_strings) @@ -54,7 +54,7 @@ def migrate_infer_calculation_entry_point(connection): # All affected entries should be logged to file that the user can consult. if ENTRY_POINT_STRING_SEPARATOR not in entry_point_string: query_set = connection.execute( - select([DbNode.c.uuid]).where(DbNode.c.type == op.inline_literal(type_string)) + select(DbNode.c.uuid).where(DbNode.c.type == op.inline_literal(type_string)) ).fetchall() uuids = [str(entry.uuid) for entry in query_set] diff --git a/aiida/backends/sqlalchemy/migrations/versions/ce56d84bcc35_delete_trajectory_symbols_array.py b/aiida/backends/sqlalchemy/migrations/versions/ce56d84bcc35_delete_trajectory_symbols_array.py index 765a4eaa6a..1c36359b36 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/ce56d84bcc35_delete_trajectory_symbols_array.py +++ b/aiida/backends/sqlalchemy/migrations/versions/ce56d84bcc35_delete_trajectory_symbols_array.py @@ -43,7 +43,7 @@ def upgrade(): column('attributes', JSONB)) nodes = connection.execute( - select([DbNode.c.id, DbNode.c.uuid]).where( + select(DbNode.c.id, DbNode.c.uuid).where( DbNode.c.type == op.inline_literal('node.data.array.trajectory.TrajectoryData.'))).fetchall() for pk, uuid in nodes: @@ -61,11 +61,11 @@ def downgrade(): column('attributes', JSONB)) nodes = connection.execute( - select([DbNode.c.id, DbNode.c.uuid]).where( + select(DbNode.c.id, DbNode.c.uuid).where( DbNode.c.type == op.inline_literal('node.data.array.trajectory.TrajectoryData.'))).fetchall() for pk, uuid in nodes: - attributes = connection.execute(select([DbNode.c.attributes]).where(DbNode.c.id == pk)).fetchone() + attributes = connection.execute(select(DbNode.c.attributes).where(DbNode.c.id == pk)).fetchone() symbols = numpy.array(attributes['symbols']) utils.store_numpy_array_in_repository(uuid, 'symbols', symbols) key = op.inline_literal('{"array|symbols"}') diff --git 
a/aiida/orm/implementation/sqlalchemy/querybuilder/joiner.py b/aiida/orm/implementation/sqlalchemy/querybuilder/joiner.py index e32e784901..e78e74380f 100644 --- a/aiida/orm/implementation/sqlalchemy/querybuilder/joiner.py +++ b/aiida/orm/implementation/sqlalchemy/querybuilder/joiner.py @@ -195,7 +195,7 @@ def _join_descendants_recursive( if expand_path: selection_walk_list.append(array((link1.input_id, link1.output_id)).label('path')) - walk = select(selection_walk_list).select_from(join(node1, link1, link1.input_id == node1.id)).where( + walk = select(*selection_walk_list).select_from(join(node1, link1, link1.input_id == node1.id)).where( and_( in_recursive_filters, # I apply filters for speed here link1.type.in_((LinkType.CREATE.value, LinkType.INPUT_CALC.value)) # I follow input and create links @@ -214,13 +214,12 @@ def _join_descendants_recursive( descendants_recursive = aliased( aliased_walk.union_all( - select(selection_union_list).select_from( - join( - aliased_walk, - link2, - link2.input_id == aliased_walk.c.descendant_id, - ) - ).where(link2.type.in_((LinkType.CREATE.value, LinkType.INPUT_CALC.value))) + select(*selection_union_list + ).select_from(join( + aliased_walk, + link2, + link2.input_id == aliased_walk.c.descendant_id, + )).where(link2.type.in_((LinkType.CREATE.value, LinkType.INPUT_CALC.value))) ) ) # .alias() @@ -259,7 +258,7 @@ def _join_ancestors_recursive( if expand_path: selection_walk_list.append(array((link1.output_id, link1.input_id)).label('path')) - walk = select(selection_walk_list).select_from(join(node1, link1, link1.output_id == node1.id)).where( + walk = select(*selection_walk_list).select_from(join(node1, link1, link1.output_id == node1.id)).where( and_(in_recursive_filters, link1.type.in_((LinkType.CREATE.value, LinkType.INPUT_CALC.value))) ).cte(recursive=True) @@ -275,13 +274,12 @@ def _join_ancestors_recursive( ancestors_recursive = aliased( aliased_walk.union_all( - select(selection_union_list).select_from( - join( - 
aliased_walk, - link2, - link2.output_id == aliased_walk.c.ancestor_id, - ) - ).where(link2.type.in_((LinkType.CREATE.value, LinkType.INPUT_CALC.value))) + select(*selection_union_list + ).select_from(join( + aliased_walk, + link2, + link2.output_id == aliased_walk.c.ancestor_id, + )).where(link2.type.in_((LinkType.CREATE.value, LinkType.INPUT_CALC.value))) # I can't follow RETURN or CALL links ) ) From 87c96a09bab631a0a3d7f7c9b77f9c6bf96c1485 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Fri, 10 Sep 2021 15:46:27 +0200 Subject: [PATCH 09/13] SQLA v2 API: Do not build empty query filters Fixes #2475 --- .../sqlalchemy/querybuilder/joiner.py | 28 ++++++++++------ .../sqlalchemy/querybuilder/main.py | 33 +++++++++++-------- tests/orm/test_querybuilder.py | 10 +++++- 3 files changed, 46 insertions(+), 25 deletions(-) diff --git a/aiida/orm/implementation/sqlalchemy/querybuilder/joiner.py b/aiida/orm/implementation/sqlalchemy/querybuilder/joiner.py index e78e74380f..2a1f996f31 100644 --- a/aiida/orm/implementation/sqlalchemy/querybuilder/joiner.py +++ b/aiida/orm/implementation/sqlalchemy/querybuilder/joiner.py @@ -79,7 +79,8 @@ class SqlaJoiner: """A class containing the logic for SQLAlchemy entities joining entities.""" def __init__( - self, entity_mapper: _EntityMapper, filter_builder: Callable[[AliasedClass, FilterType], BooleanClauseList] + self, entity_mapper: _EntityMapper, filter_builder: Callable[[AliasedClass, FilterType], + Optional[BooleanClauseList]] ): """Initialise the class""" self._entities = entity_mapper @@ -185,7 +186,13 @@ def _join_descendants_recursive( link1 = aliased(self._entities.Link) link2 = aliased(self._entities.Link) node1 = aliased(self._entities.Node) + + link_filters = link1.type.in_((LinkType.CREATE.value, LinkType.INPUT_CALC.value)) # follow input / create links in_recursive_filters = self._build_filters(node1, filter_dict) + if in_recursive_filters is None: + filters = link_filters + else: + filters = and_(in_recursive_filters, 
link_filters) selection_walk_list = [ link1.input_id.label('ancestor_id'), @@ -195,12 +202,8 @@ def _join_descendants_recursive( if expand_path: selection_walk_list.append(array((link1.input_id, link1.output_id)).label('path')) - walk = select(*selection_walk_list).select_from(join(node1, link1, link1.input_id == node1.id)).where( - and_( - in_recursive_filters, # I apply filters for speed here - link1.type.in_((LinkType.CREATE.value, LinkType.INPUT_CALC.value)) # I follow input and create links - ) - ).cte(recursive=True) + walk = select(*selection_walk_list).select_from(join(node1, link1, link1.input_id == node1.id) + ).where(filters).cte(recursive=True) aliased_walk = aliased(walk) @@ -248,7 +251,13 @@ def _join_ancestors_recursive( link1 = aliased(self._entities.Link) link2 = aliased(self._entities.Link) node1 = aliased(self._entities.Node) + + link_filters = link1.type.in_((LinkType.CREATE.value, LinkType.INPUT_CALC.value)) # follow input / create links in_recursive_filters = self._build_filters(node1, filter_dict) + if in_recursive_filters is None: + filters = link_filters + else: + filters = and_(in_recursive_filters, link_filters) selection_walk_list = [ link1.input_id.label('ancestor_id'), @@ -258,9 +267,8 @@ def _join_ancestors_recursive( if expand_path: selection_walk_list.append(array((link1.output_id, link1.input_id)).label('path')) - walk = select(*selection_walk_list).select_from(join(node1, link1, link1.output_id == node1.id)).where( - and_(in_recursive_filters, link1.type.in_((LinkType.CREATE.value, LinkType.INPUT_CALC.value))) - ).cte(recursive=True) + walk = select(*selection_walk_list).select_from(join(node1, link1, link1.output_id == node1.id) + ).where(filters).cte(recursive=True) aliased_walk = aliased(walk) diff --git a/aiida/orm/implementation/sqlalchemy/querybuilder/main.py b/aiida/orm/implementation/sqlalchemy/querybuilder/main.py index 02c43aa80c..285b171ef4 100644 --- a/aiida/orm/implementation/sqlalchemy/querybuilder/main.py +++ 
b/aiida/orm/implementation/sqlalchemy/querybuilder/main.py @@ -366,7 +366,9 @@ def _build(self) -> Query: alias = self._get_tag_alias(tag) except KeyError: raise ValueError(f'Unknown tag {tag!r} in filters, known: {list(self._tag_to_alias)}') - self._query = self._query.filter(self.build_filters(alias, filter_specs)) + filters = self.build_filters(alias, filter_specs) + if filters is not None: + self._query = self._query.filter(filters) # PROJECTIONS ########################## @@ -601,7 +603,7 @@ def get_column(colname: str, alias: AliasedClass) -> InstrumentedAttribute: '{}'.format(colname, alias, '\n'.join(alias._sa_class_manager.mapper.c.keys())) # pylint: disable=protected-access ) from exc - def build_filters(self, alias: AliasedClass, filter_spec: Dict[str, Any]) -> BooleanClauseList: + def build_filters(self, alias: AliasedClass, filter_spec: Dict[str, Any]) -> Optional[BooleanClauseList]: # pylint: disable=too-many-branches """Recurse through the filter specification and apply filter operations. 
:param alias: The alias of the ORM class the filter will be applied on @@ -612,17 +614,20 @@ def build_filters(self, alias: AliasedClass, filter_spec: Dict[str, Any]) -> Boo expressions: List[Any] = [] for path_spec, filter_operation_dict in filter_spec.items(): if path_spec in ('and', 'or', '~or', '~and', '!and', '!or'): - subexpressions = [ - self.build_filters(alias, sub_filter_spec) for sub_filter_spec in filter_operation_dict - ] - if path_spec == 'and': - expressions.append(and_(*subexpressions)) - elif path_spec == 'or': - expressions.append(or_(*subexpressions)) - elif path_spec in ('~and', '!and'): - expressions.append(not_(and_(*subexpressions))) - elif path_spec in ('~or', '!or'): - expressions.append(not_(or_(*subexpressions))) + subexpressions = [] + for sub_filter_spec in filter_operation_dict: + filters = self.build_filters(alias, sub_filter_spec) + if filters is not None: + subexpressions.append(filters) + if subexpressions: + if path_spec == 'and': + expressions.append(and_(*subexpressions)) + elif path_spec == 'or': + expressions.append(or_(*subexpressions)) + elif path_spec in ('~and', '!and'): + expressions.append(not_(and_(*subexpressions))) + elif path_spec in ('~or', '!or'): + expressions.append(not_(or_(*subexpressions))) else: column_name = path_spec.split('.')[0] @@ -650,7 +655,7 @@ def build_filters(self, alias: AliasedClass, filter_spec: Dict[str, Any]) -> Boo alias=alias ) ) - return and_(*expressions) + return and_(*expressions) if expressions else None def modify_expansions(self, alias: AliasedClass, expansions: List[str]) -> List[str]: """Modify names of projections if `**` was specified. diff --git a/tests/orm/test_querybuilder.py b/tests/orm/test_querybuilder.py index c596062aa1..43159b020e 100644 --- a/tests/orm/test_querybuilder.py +++ b/tests/orm/test_querybuilder.py @@ -798,7 +798,7 @@ class TestQueryBuilderCornerCases: In this class corner cases of QueryBuilder are added. 
""" - def test_computer_json(self): # pylint: disable=no-self-use + def test_computer_json(self): """ In this test we check the correct behavior of QueryBuilder when retrieving the _metadata with no content. @@ -818,6 +818,14 @@ def test_computer_json(self): # pylint: disable=no-self-use qb.append(orm.Computer, project=['id', 'metadata'], outerjoin=True, with_node='calc') qb.all() + def test_empty_filters(self): + """Test that an empty filter is correctly handled.""" + orm.Data().store() + qb = orm.QueryBuilder().append(orm.Data, filters={}) + assert qb.count() == 1 + qb = orm.QueryBuilder().append(orm.Data, filters={'or': [{}, {}]}) + assert qb.count() == 1 + @pytest.mark.usefixtures('clear_database_before_test') class TestAttributes: From db328dc81a461a5f969639f68d2e89db9c5460c1 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Sat, 11 Sep 2021 10:36:00 +0200 Subject: [PATCH 10/13] SQLA v2 API: Use transaction context managers --- .../orm/implementation/sqlalchemy/backend.py | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/aiida/orm/implementation/sqlalchemy/backend.py b/aiida/orm/implementation/sqlalchemy/backend.py index 0d37d574a1..89f9efffa4 100644 --- a/aiida/orm/implementation/sqlalchemy/backend.py +++ b/aiida/orm/implementation/sqlalchemy/backend.py @@ -83,18 +83,13 @@ def transaction(self): entering. Transactions can be nested. 
""" session = self.get_session() - nested = session.in_nested_transaction() - try: - session.begin_nested() - yield session - session.commit() - except Exception: - session.rollback() - raise - finally: - if not nested: - # Make sure to commit the outermost session - session.commit() + if session.in_transaction(): + with session.begin_nested(): + yield session + else: + with session.begin(): + with session.begin_nested(): + yield session @staticmethod def get_session(): From 6aa8281659dfad67efeb039725a83d012f12cd25 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Sat, 11 Sep 2021 09:12:05 +0200 Subject: [PATCH 11/13] SQLA v2 API: add engine & session `future=True` flag --- aiida/backends/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/aiida/backends/utils.py b/aiida/backends/utils.py index 0b42aa378d..30ab18ae01 100644 --- a/aiida/backends/utils.py +++ b/aiida/backends/utils.py @@ -38,14 +38,14 @@ def create_sqlalchemy_engine(profile, **kwargs): name=profile.database_name ) return create_engine( - engine_url, json_serializer=json.dumps, json_deserializer=json.loads, future=False, encoding='utf-8', **kwargs + engine_url, json_serializer=json.dumps, json_deserializer=json.loads, future=True, encoding='utf-8', **kwargs ) def create_scoped_session_factory(engine, **kwargs): """Create scoped SQLAlchemy session factory""" from sqlalchemy.orm import scoped_session, sessionmaker - return scoped_session(sessionmaker(bind=engine, **kwargs)) + return scoped_session(sessionmaker(bind=engine, future=True, **kwargs)) def delete_nodes_and_connections(pks): From 582a69377fd4e1e1d9dbbfe77a48519f62758882 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Wed, 15 Sep 2021 19:17:40 +0200 Subject: [PATCH 12/13] SQLA v2 API: Suppress known `SAWarning`s in tests --- tests/backends/aiida_sqlalchemy/test_schema.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/tests/backends/aiida_sqlalchemy/test_schema.py 
b/tests/backends/aiida_sqlalchemy/test_schema.py index 1aa5341efc..bffa84dc7e 100644 --- a/tests/backends/aiida_sqlalchemy/test_schema.py +++ b/tests/backends/aiida_sqlalchemy/test_schema.py @@ -9,6 +9,10 @@ ########################################################################### # pylint: disable=import-error,no-name-in-module """Test object relationships in the database.""" +import warnings + +from sqlalchemy import exc as sa_exc + from aiida.backends.testbase import AiidaTestCase from aiida.backends.sqlalchemy.models.user import DbUser from aiida.backends.sqlalchemy.models.node import DbNode @@ -111,9 +115,6 @@ def test_user_node_2(self): storing USER does NOT induce storage of the NODE Assert the correct storage of user and node.""" - import warnings - from sqlalchemy import exc as sa_exc - # Create user dbu1 = DbUser('tests2@schema', 'spam', 'eggs', 'monty') @@ -164,7 +165,10 @@ def test_user_node_3(self): # Add only first node and commit session.add(dbn_1) - session.commit() + with warnings.catch_warnings(): + # suppress known SAWarning that we have not added dbn_2 + warnings.simplefilter('ignore', category=sa_exc.SAWarning) + session.commit() # Check for which object a pk has been assigned, which means that # things have been at least flushed into the database @@ -200,7 +204,10 @@ def test_user_node_4(self): # Add only first node and commit session.add(dbn_1) - session.commit() + with warnings.catch_warnings(): + # suppress known SAWarning that we have not added the other nodes + warnings.simplefilter('ignore', category=sa_exc.SAWarning) + session.commit() # Check for which object a pk has been assigned, which means that # things have been at least flushed into the database From f926619e98eb74bad01b6af303d4a1267e4ebce9 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Wed, 15 Sep 2021 23:40:59 +0200 Subject: [PATCH 13/13] SQLA v2 API: Replace `Query.yield_per(x)` with `.execution_options(yield_per=x)` See: 
https://docs.sqlalchemy.org/en/14/orm/queryguide.html?highlight=yield%20per#yield-per Using statement execution also inhibits automatic uniquing of results, removing the previous need for patching of `ORMCompileState`. --- .../sqlalchemy/querybuilder/main.py | 48 +++---------------- 1 file changed, 7 insertions(+), 41 deletions(-) diff --git a/aiida/orm/implementation/sqlalchemy/querybuilder/main.py b/aiida/orm/implementation/sqlalchemy/querybuilder/main.py index 285b171ef4..1c6848156d 100644 --- a/aiida/orm/implementation/sqlalchemy/querybuilder/main.py +++ b/aiida/orm/implementation/sqlalchemy/querybuilder/main.py @@ -19,8 +19,7 @@ from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.exc import SAWarning from sqlalchemy.ext.compiler import compiles -from sqlalchemy.orm import aliased, loading -from sqlalchemy.orm.context import ORMCompileState, QueryContext +from sqlalchemy.orm import aliased from sqlalchemy.orm.query import Query from sqlalchemy.orm.session import Session from sqlalchemy.orm.util import AliasedClass @@ -72,35 +71,6 @@ def compile(element, compiler: TypeCompiler, **kwargs): # pylint: disable=funct return f'jsonb_typeof({compiler.process(element.clauses, **kwargs)})' -def _orm_setup_cursor_result( - session, - statement, - params, - execution_options, - bind_arguments, - result, -): - """Patched class method.""" - execution_context = result.context - compile_state = execution_context.compiled.compile_state - - # this is the patch required for turning off de-duplication of results - compile_state._has_mapper_entities = False # pylint: disable=protected-access - - load_options = execution_options.get('_sa_orm_load_options', QueryContext.default_load_options) - - querycontext = QueryContext( - compile_state, - statement, - params, - session, - load_options, - execution_options, - bind_arguments, - ) - return loading.instances(result, querycontext) - - class SqlaQueryBuilder(BackendQueryBuilder): """ QueryBuilder to use with 
SQLAlchemy-backend and @@ -229,7 +199,9 @@ def iterall(self, data: QueryDictType, batch_size: Optional[int]) -> Iterable[Li """Return an iterator over all the results of a list of lists.""" with self.use_query(data) as query: - for resultrow in query.yield_per(batch_size): # type: ignore[arg-type] # pylint: disable=not-an-iterable + stmt = query.statement.execution_options(yield_per=batch_size) + + for resultrow in self.get_session().execute(stmt): # we discard the first item of the result row, # which is what the query was initialised with # and not one of the requested projection (see self._build) @@ -240,7 +212,9 @@ def iterdict(self, data: QueryDictType, batch_size: Optional[int]) -> Iterable[D """Return an iterator over all the results of a list of dictionaries.""" with self.use_query(data) as query: - for row in query.yield_per(batch_size): # type: ignore[arg-type] # pylint: disable=not-an-iterable + stmt = query.statement.execution_options(yield_per=batch_size) + + for row in self.get_session().execute(stmt): # build the yield result yield_result: Dict[str, Dict[str, Any]] = {} for tag, projected_entities_dict in self._tag_to_projected_fields.items(): @@ -255,20 +229,12 @@ def iterdict(self, data: QueryDictType, batch_size: Optional[int]) -> Iterable[D @contextmanager def use_query(self, data: QueryDictType) -> Iterator[Query]: """Yield the built query.""" - # Currently, a monkey-patch is required to turn off de-duplication of results, - # carried out in the `use_query` method - # see: https://github.com/sqlalchemy/sqlalchemy/issues/4395#issuecomment-907293360 - # THIS CAN BE REMOVED WHEN MOVING TO THE VERSION 2 API - existing_func = ORMCompileState.orm_setup_cursor_result - ORMCompileState.orm_setup_cursor_result = _orm_setup_cursor_result # type: ignore[assignment] query = self._update_query(data) try: yield query except Exception: self.get_session().close() raise - finally: - ORMCompileState.orm_setup_cursor_result = existing_func # type: 
ignore[assignment] def _update_query(self, data: QueryDictType) -> Query: """Return the sqlalchemy.orm.Query instance for the current query specification.