From eefa3678b528639382718717519996ef7963df74 Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Sun, 6 Mar 2022 19:50:52 +0100 Subject: [PATCH 1/3] `QueryBuilder`: add `flat` keyword to `first` method This keyword already exists for the `all` method and it will likewise be useful for `first` when only a single quantity is projected. In that case, often the caller doesn't want a list as a return value but simply the projected quantity. Allowing to get this directly from the method call as opposed to manually dereferencing the first item from the returned list often makes for cleaner code. --- aiida/orm/querybuilder.py | 18 +++++++++++++----- tests/orm/test_querybuilder.py | 17 +++++++++++++++-- 2 files changed, 28 insertions(+), 7 deletions(-) diff --git a/aiida/orm/querybuilder.py b/aiida/orm/querybuilder.py index 69582b618f..f1392b2406 100644 --- a/aiida/orm/querybuilder.py +++ b/aiida/orm/querybuilder.py @@ -989,12 +989,15 @@ def _get_aiida_entity_res(value) -> Any: except TypeError: return value - def first(self) -> Optional[List[Any]]: - """Executes the query, asking for the first row of results. + def first(self, flat: bool = False) -> Optional[Union[List[Any], Any]]: + """Return the first result of the query. - Note, this may change if several rows are valid for the query, - as persistent ordering is not guaranteed unless explicitly specified. + Calling ``first`` results in an execution of the underlying query. + Note, this may change if several rows are valid for the query, as persistent ordering is not guaranteed unless + explicitly specified. + + :param flat: if True, return just the projected quantity if there is just a single projection. :returns: One row of results as a list, or None if no result returned. """ result = self._impl.first(self.as_dict()) @@ -1002,7 +1005,12 @@ def first(self) -> Optional[List[Any]]: if result is None: return None - return [self._get_aiida_entity_res(rowitem) for rowitem in result] + result = [self._get_aiida_entity_res(rowitem) for rowitem in result] + + if flat and len(result) == 1: + return result[0] + + return result def count(self) -> int: """ diff --git a/tests/orm/test_querybuilder.py b/tests/orm/test_querybuilder.py index 9ad8cdc8b5..5b91e89f90 100644 --- a/tests/orm/test_querybuilder.py +++ b/tests/orm/test_querybuilder.py @@ -649,7 +649,7 @@ def test_direction_keyword(self): assert res2 == {d2.id, d4.id} @staticmethod - def test_flat(): + def test_all_flat(): """Test the `flat` keyword for the `QueryBuilder.all()` method.""" pks = [] uuids = [] @@ -665,13 +665,26 @@ def test_flat(): assert len(result) == 10 assert result == pks - # Mutltiple projections + # Multiple projections builder = orm.QueryBuilder().append(orm.Data, project=['id', 'uuid']).order_by({orm.Data: 'id'}) result = builder.all(flat=True) assert isinstance(result, list) assert len(result) == 20 assert result == list(chain.from_iterable(zip(pks, uuids))) + @staticmethod + def test_first_flat(): + """Test the `flat` keyword for the `QueryBuilder.first()` method.""" + node = orm.Data().store() + + # Single projected property + query = orm.QueryBuilder().append(orm.Data, project='id', filters={'id': node.pk}) + assert query.first(flat=True) == node.pk + + # Mutltiple projections + query = orm.QueryBuilder().append(orm.Data, project=['id', 'uuid'], filters={'id': node.pk}) + assert query.first(flat=True) == [node.pk, node.uuid] + def test_query_links(self): """Test querying for links""" d1, d2, d3, d4 = [orm.Data().store() for _ in range(4)] From fe0d3a93f8994218d491c859cb695935d6966338 Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Wed, 9 Mar 2022 15:04:21 +0100 Subject: [PATCH 2/3] Define `overload`ed methods for typing purposes --- aiida/orm/querybuilder.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/aiida/orm/querybuilder.py b/aiida/orm/querybuilder.py index f1392b2406..df62f60af8 100644 --- a/aiida/orm/querybuilder.py +++ b/aiida/orm/querybuilder.py @@ -19,6 +19,8 @@ An instance of one of the implementation classes becomes a member of the :func:`QueryBuilder` instance when instantiated by the user. """ +from __future__ import annotations + from copy import deepcopy from inspect import isclass as inspect_isclass from typing import ( @@ -27,6 +29,7 @@ Dict, Iterable, List, + Literal, NamedTuple, Optional, Sequence, @@ -35,6 +38,7 @@ Type, Union, cast, + overload, ) import warnings @@ -989,7 +993,15 @@ def _get_aiida_entity_res(value) -> Any: except TypeError: return value - def first(self, flat: bool = False) -> Optional[Union[List[Any], Any]]: + @overload + def first(self, flat: Literal[False]) -> Optional[list[Any]]: + ... + + @overload + def first(self, flat: Literal[True]) -> Optional[Any]: + ... + + def first(self, flat: bool = False) -> Optional[list[Any] | Any]: """Return the first result of the query. Calling ``first`` results in an execution of the underlying query. From 952ed5aa532d8fb04fc5286c3e5b2b8762e03c6e Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Wed, 9 Mar 2022 20:46:05 +0100 Subject: [PATCH 3/3] Use `first(flat=True)` where applicable and add pylint ignore where it is being a pita --- aiida/orm/nodes/data/upf.py | 3 +-- aiida/restapi/translator/nodes/node.py | 4 ++-- tests/orm/test_groups.py | 4 ++-- tests/orm/test_querybuilder.py | 11 +++++++---- tests/test_nodes.py | 4 ++-- tests/tools/archive/orm/test_computers.py | 20 ++++++++++---------- 6 files changed, 24 insertions(+), 22 deletions(-) diff --git a/aiida/orm/nodes/data/upf.py b/aiida/orm/nodes/data/upf.py index b212327ba2..6896ca5a19 100644 --- a/aiida/orm/nodes/data/upf.py +++ b/aiida/orm/nodes/data/upf.py @@ -122,7 +122,7 @@ def upload_upf_family(folder, group_label, group_description, stop_if_existing=T md5sum = md5_file(filename) builder = orm.QueryBuilder(backend=backend) builder.append(UpfData, filters={'attributes.md5': {'==': md5sum}}) - existing_upf = builder.first() + existing_upf = builder.first(flat=True) if existing_upf is None: # return the upfdata instances, not stored @@ -133,7 +133,6 @@ def upload_upf_family(folder, group_label, group_description, stop_if_existing=T else: if stop_if_existing: raise ValueError(f'A UPF with identical MD5 to {filename} cannot be added with stop_if_existing') - existing_upf = existing_upf[0] pseudo_and_created.append((existing_upf, False)) # check whether pseudo are unique per element diff --git a/aiida/restapi/translator/nodes/node.py b/aiida/restapi/translator/nodes/node.py index 2a38586afe..3f3c920ea8 100644 --- a/aiida/restapi/translator/nodes/node.py +++ b/aiida/restapi/translator/nodes/node.py @@ -254,7 +254,7 @@ def _get_content(self): return {} # otherwise ... - node = self.qbobj.first()[0] + node = self.qbobj.first()[0] # pylint: disable=unsubscriptable-object # content/attributes if self._content_type == 'attributes': @@ -643,7 +643,7 @@ def get_node_description(node): nodes = [] if qb_obj.count() > 0: - main_node = qb_obj.first()[0] + main_node = qb_obj.first(flat=True) pk = main_node.pk uuid = main_node.uuid nodetype = main_node.node_type diff --git a/tests/orm/test_groups.py b/tests/orm/test_groups.py index 08f5925b0e..fabcd29760 100644 --- a/tests/orm/test_groups.py +++ b/tests/orm/test_groups.py @@ -268,7 +268,7 @@ def test_group_uuid_hashing_for_querybuidler(self): # Search for the UUID of the stored group builder = orm.QueryBuilder() builder.append(orm.Group, project=['uuid'], filters={'label': {'==': 'test_group'}}) - [uuid] = builder.first() + uuid = builder.first(flat=True) # Look the node with the previously returned UUID builder = orm.QueryBuilder() @@ -279,7 +279,7 @@ def test_group_uuid_hashing_for_querybuidler(self): # And that the results are correct assert builder.count() == 1 - assert builder.first()[0] == group.id + assert builder.first(flat=True) == group.id @pytest.mark.usefixtures('aiida_profile_clean') diff --git a/tests/orm/test_querybuilder.py b/tests/orm/test_querybuilder.py index 5b91e89f90..2d950aacb6 100644 --- a/tests/orm/test_querybuilder.py +++ b/tests/orm/test_querybuilder.py @@ -716,13 +716,16 @@ def test_first_multiple_projections(self): orm.Data().store() orm.Data().store() - result = orm.QueryBuilder().append(orm.User, tag='user', - project=['email']).append(orm.Data, with_user='user', project=['*']).first() + query = orm.QueryBuilder() + query.append(orm.User, tag='user', project=['email']) + query.append(orm.Data, with_user='user', project=['*']) + + result = query.first() assert isinstance(result, list) assert len(result) == 2 - assert isinstance(result[0], str) - assert isinstance(result[1], orm.Data) + assert isinstance(result[0], str) # pylint: disable=unsubscriptable-object + assert isinstance(result[1], orm.Data) # pylint: disable=unsubscriptable-object class TestRepresentations: diff --git a/tests/test_nodes.py b/tests/test_nodes.py index 5fdeae14da..a0110aa6ec 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -89,7 +89,7 @@ def test_node_uuid_hashing_for_querybuidler(self): # Search for the UUID of the stored node qb = orm.QueryBuilder() qb.append(orm.Data, project=['uuid'], filters={'id': {'==': n.id}}) - [uuid] = qb.first() + uuid = qb.first(flat=True) # Look the node with the previously returned UUID qb = orm.QueryBuilder() @@ -99,7 +99,7 @@ def test_node_uuid_hashing_for_querybuidler(self): qb.all() # And that the results are correct assert qb.count() == 1 - assert qb.first()[0] == n.id + assert qb.first(flat=True) == n.id @staticmethod def create_folderdata_with_empty_file(): diff --git a/tests/tools/archive/orm/test_computers.py b/tests/tools/archive/orm/test_computers.py index 3f1513e03c..fc0a652c7d 100644 --- a/tests/tools/archive/orm/test_computers.py +++ b/tests/tools/archive/orm/test_computers.py @@ -79,17 +79,17 @@ def test_same_computer_import(tmp_path, aiida_profile_clean, aiida_localhost): builder = orm.QueryBuilder() builder.append(orm.CalcJobNode, project=['label']) assert builder.count() == 1, 'Only one calculation should be found.' - assert str(builder.first()[0]) == calc1_label, 'The calculation label is not correct.' + assert str(builder.first(flat=True)) == calc1_label, 'The calculation label is not correct.' # Check that the referenced computer is imported correctly. builder = orm.QueryBuilder() builder.append(orm.Computer, project=['label', 'uuid', 'id']) assert builder.count() == 1, 'Only one computer should be found.' - assert str(builder.first()[0]) == comp_name, 'The computer name is not correct.' - assert str(builder.first()[1]) == comp_uuid, 'The computer uuid is not correct.' + assert str(builder.first()[0]) == comp_name, 'The computer name is not correct.' # pylint: disable=unsubscriptable-object + assert str(builder.first()[1]) == comp_uuid, 'The computer uuid is not correct.' # pylint: disable=unsubscriptable-object # Store the id of the computer - comp_id = builder.first()[2] + comp_id = builder.first()[2] # pylint: disable=unsubscriptable-object # Import the second calculation import_archive(filename2) @@ -99,9 +99,9 @@ def test_same_computer_import(tmp_path, aiida_profile_clean, aiida_localhost): builder = orm.QueryBuilder() builder.append(orm.Computer, project=['label', 'uuid', 'id']) assert builder.count() == 1, f'Found {builder.count()} computersbut only one computer should be found.' - assert str(builder.first()[0]) == comp_name, 'The computer name is not correct.' - assert str(builder.first()[1]) == comp_uuid, 'The computer uuid is not correct.' - assert builder.first()[2] == comp_id, 'The computer id is not correct.' + assert str(builder.first()[0]) == comp_name, 'The computer name is not correct.' # pylint: disable=unsubscriptable-object + assert str(builder.first()[1]) == comp_uuid, 'The computer uuid is not correct.' # pylint: disable=unsubscriptable-object + assert builder.first()[2] == comp_id, 'The computer id is not correct.' # pylint: disable=unsubscriptable-object # Check that now you have two calculations attached to the same # computer. @@ -175,13 +175,13 @@ def test_same_computer_different_name_import(tmp_path, aiida_profile_clean, aiid builder = orm.QueryBuilder() builder.append(orm.CalcJobNode, project=['label']) assert builder.count() == 1, 'Only one calculation should be found.' - assert str(builder.first()[0]) == calc1_label, 'The calculation label is not correct.' + assert str(builder.first(flat=True)) == calc1_label, 'The calculation label is not correct.' # Check that the referenced computer is imported correctly. builder = orm.QueryBuilder() builder.append(orm.Computer, project=['label', 'uuid', 'id']) assert builder.count() == 1, 'Only one computer should be found.' - assert str(builder.first()[0]) == comp1_name, 'The computer name is not correct.' + assert str(builder.first()[0]) == comp1_name, 'The computer name is not correct.' # pylint: disable=unsubscriptable-object # Import the second calculation import_archive(filename2) @@ -191,7 +191,7 @@ def test_same_computer_different_name_import(tmp_path, aiida_profile_clean, aiid builder = orm.QueryBuilder() builder.append(orm.Computer, project=['label']) assert builder.count() == 1, f'Found {builder.count()} computersbut only one computer should be found.' - assert str(builder.first()[0]) == comp1_name, 'The computer name is not correct.' + assert str(builder.first(flat=True)) == comp1_name, 'The computer name is not correct.' def test_different_computer_same_name_import(tmp_path, aiida_profile_clean, aiida_localhost_factory):