From 907f241a2ceadda7cb2b8cf46f924a4573a1efff Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Fri, 13 Aug 2021 19:39:48 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=A7=AA=20TESTS:=20convert=20tests/orm/tes?= =?UTF-8?q?t=5Fquerybuilder.py=20to=20pytest?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/orm/test_querybuilder.py | 576 ++++++++++++++++----------------- 1 file changed, 282 insertions(+), 294 deletions(-) diff --git a/tests/orm/test_querybuilder.py b/tests/orm/test_querybuilder.py index df7a16e39d..a9f743d3da 100644 --- a/tests/orm/test_querybuilder.py +++ b/tests/orm/test_querybuilder.py @@ -7,26 +7,25 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=invalid-name,missing-docstring,too-many-lines +# pylint: disable=attribute-defined-outside-init,invalid-name,no-self-use,missing-docstring,too-many-lines,unused-argument """Tests for the QueryBuilder.""" +from collections import defaultdict +from datetime import date, datetime, timedelta +from itertools import chain import warnings + import pytest from aiida import orm -from aiida.backends.testbase import AiidaTestCase from aiida.common.links import LinkType from aiida.manage import configuration -class TestQueryBuilder(AiidaTestCase): - - def setUp(self): - super().setUp() - self.refurbish_db() +@pytest.mark.usefixtures('clear_database_before_test') +class TestQueryBuilder: def test_date_filters_support(self): """Verify that `datetime.date` is supported in filters.""" - from datetime import date, timedelta from aiida.common import timezone # Using timezone.now() rather than datetime.now() to get a timezone-aware object rather than a naive one @@ -34,7 +33,7 @@ def test_date_filters_support(self): orm.Data(ctime=timezone.now() - timedelta(days=1)).store() builder = orm.QueryBuilder().append(orm.Node, filters={'ctime': {'>': date.today() - timedelta(days=1)}}) - self.assertEqual(builder.count(), 1) + assert builder.count() == 1 def test_ormclass_type_classification(self): """ @@ -46,11 +45,11 @@ def test_ormclass_type_classification(self): qb = orm.QueryBuilder() # Asserting that improper declarations of the class type raise an error - with self.assertRaises(DbContentError): + with pytest.raises(DbContentError): qb._get_ormclass(None, 'data') - with self.assertRaises(DbContentError): + with pytest.raises(DbContentError): qb._get_ormclass(None, 'data.Data') - with self.assertRaises(DbContentError): + with pytest.raises(DbContentError): qb._get_ormclass(None, '.') # Asserting that the query type string and plugin type string are returned: @@ -58,34 +57,34 @@ def test_ormclass_type_classification(self): qb._get_ormclass(orm.StructureData, None), qb._get_ormclass(None, 'data.structure.StructureData.'), ): - self.assertEqual(classifiers['ormclass_type_string'], orm.StructureData._plugin_type_string) # pylint: disable=no-member + assert classifiers['ormclass_type_string'] == orm.StructureData._plugin_type_string # pylint: disable=no-member for _cls, classifiers in ( qb._get_ormclass(orm.Group, None), qb._get_ormclass(None, 'group.core'), qb._get_ormclass(None, 'Group.core'), ): - self.assertTrue(classifiers['ormclass_type_string'].startswith('group')) + assert classifiers['ormclass_type_string'].startswith('group') for _cls, classifiers in ( qb._get_ormclass(orm.User, None), qb._get_ormclass(None, 'user'), qb._get_ormclass(None, 'User'), ): - self.assertEqual(classifiers['ormclass_type_string'], 'user') + assert classifiers['ormclass_type_string'] == 'user' for _cls, classifiers in ( qb._get_ormclass(orm.Computer, None), qb._get_ormclass(None, 'computer'), qb._get_ormclass(None, 'Computer'), ): - self.assertEqual(classifiers['ormclass_type_string'], 'computer') + assert classifiers['ormclass_type_string'] == 'computer' for _cls, classifiers in ( qb._get_ormclass(orm.Data, None), qb._get_ormclass(None, 'data.Data.'), ): - self.assertEqual(classifiers['ormclass_type_string'], orm.Data._plugin_type_string) # pylint: disable=no-member + assert classifiers['ormclass_type_string'] == orm.Data._plugin_type_string # pylint: disable=no-member def test_process_type_classification(self): """ @@ -103,34 +102,34 @@ def test_process_type_classification(self): # When passing a WorkChain class, it should return the type of the corresponding Node # including the appropriate filter on the process_type _cls, classifiers = qb._get_ormclass(WorkChain, None) - self.assertEqual(classifiers['ormclass_type_string'], 'process.workflow.workchain.WorkChainNode.') - self.assertEqual(classifiers['process_type_string'], 'aiida.engine.processes.workchains.workchain.WorkChain') + assert classifiers['ormclass_type_string'] == 'process.workflow.workchain.WorkChainNode.' + assert classifiers['process_type_string'] == 'aiida.engine.processes.workchains.workchain.WorkChain' # When passing a WorkChainNode, no process_type filter is applied _cls, classifiers = qb._get_ormclass(orm.WorkChainNode, None) - self.assertEqual(classifiers['ormclass_type_string'], 'process.workflow.workchain.WorkChainNode.') - self.assertEqual(classifiers['process_type_string'], None) + assert classifiers['ormclass_type_string'] == 'process.workflow.workchain.WorkChainNode.' + assert classifiers['process_type_string'] is None # Same tests for a calculation _cls, classifiers = qb._get_ormclass(ArithmeticAdd, None) - self.assertEqual(classifiers['ormclass_type_string'], 'process.calculation.calcjob.CalcJobNode.') - self.assertEqual(classifiers['process_type_string'], 'aiida.calculations:arithmetic.add') + assert classifiers['ormclass_type_string'] == 'process.calculation.calcjob.CalcJobNode.' + assert classifiers['process_type_string'] == 'aiida.calculations:arithmetic.add' def test_get_group_type_filter(self): """Test the `aiida.orm.querybuilder.get_group_type_filter` function.""" from aiida.orm.querybuilder import get_group_type_filter classifiers = {'ormclass_type_string': 'group.core'} - self.assertEqual(get_group_type_filter(classifiers, False), {'==': 'core'}) - self.assertEqual(get_group_type_filter(classifiers, True), {'like': '%'}) + assert get_group_type_filter(classifiers, False) == {'==': 'core'} + assert get_group_type_filter(classifiers, True) == {'like': '%'} classifiers = {'ormclass_type_string': 'group.core.auto'} - self.assertEqual(get_group_type_filter(classifiers, False), {'==': 'core.auto'}) - self.assertEqual(get_group_type_filter(classifiers, True), {'like': 'core.auto%'}) + assert get_group_type_filter(classifiers, False) == {'==': 'core.auto'} + assert get_group_type_filter(classifiers, True) == {'like': 'core.auto%'} classifiers = {'ormclass_type_string': 'group.pseudo.family'} - self.assertEqual(get_group_type_filter(classifiers, False), {'==': 'pseudo.family'}) - self.assertEqual(get_group_type_filter(classifiers, True), {'like': 'pseudo.family%'}) + assert get_group_type_filter(classifiers, False) == {'==': 'pseudo.family'} + assert get_group_type_filter(classifiers, True) == {'like': 'pseudo.family%'} # Tracked in issue #4281 @pytest.mark.flaky(reruns=2) @@ -200,8 +199,8 @@ class DummyWorkChain(WorkChain): assert issubclass(w[-1].category, AiidaEntryPointWarning) # There should be one result of type WorkChainNode - self.assertEqual(qb.count(), 1) - self.assertTrue(isinstance(qb.all()[0][0], orm.WorkChainNode)) + assert qb.count() == 1 + assert isinstance(qb.all()[0][0], orm.WorkChainNode) # Query for nodes of a different type of WorkChain qb = orm.QueryBuilder() @@ -217,14 +216,14 @@ class DummyWorkChain(WorkChain): assert issubclass(w[-1].category, AiidaEntryPointWarning) # There should be no result - self.assertEqual(qb.count(), 0) + assert qb.count() == 0 # Query for all WorkChain nodes qb = orm.QueryBuilder() qb.append(WorkChain) # There should be one result - self.assertEqual(qb.count(), 1) + assert qb.count() == 1 def test_simple_query_1(self): """ @@ -266,38 +265,37 @@ def test_simple_query_1(self): qb1 = orm.QueryBuilder() qb1.append(orm.Node, filters={'attributes.foo': 1.000}) - self.assertEqual(len(qb1.all()), 2) + assert len(qb1.all()) == 2 qb2 = orm.QueryBuilder() qb2.append(orm.Data) - self.assertEqual(qb2.count(), 3) + assert qb2.count() == 3 qb2 = orm.QueryBuilder() qb2.append(entity_type='data.Data.') - self.assertEqual(qb2.count(), 3) + assert qb2.count() == 3 qb3 = orm.QueryBuilder() qb3.append(orm.Node, project='label', tag='node1') qb3.append(orm.Node, project='label', tag='node2') - self.assertEqual(qb3.count(), 4) + assert qb3.count() == 4 qb4 = orm.QueryBuilder() qb4.append(orm.CalculationNode, tag='node1') qb4.append(orm.Data, tag='node2') - self.assertEqual(qb4.count(), 2) + assert qb4.count() == 2 qb5 = orm.QueryBuilder() qb5.append(orm.Data, tag='node1') qb5.append(orm.CalculationNode, tag='node2') - self.assertEqual(qb5.count(), 2) + assert qb5.count() == 2 qb6 = orm.QueryBuilder() qb6.append(orm.Data, tag='node1') qb6.append(orm.Data, tag='node2') - self.assertEqual(qb6.count(), 0) + assert qb6.count() == 0 def test_simple_query_2(self): - from datetime import datetime from aiida.common.exceptions import MultipleObjectsError, NotExistent n0 = orm.Data() n0.label = 'hello' @@ -320,7 +318,7 @@ def test_simple_query_2(self): qb1 = orm.QueryBuilder() qb1.append(orm.Node, filters={'label': 'hello'}) - self.assertEqual(len(list(qb1.all())), 1) + assert len(list(qb1.all())) == 1 qh = { 'path': [{ @@ -352,11 +350,11 @@ def test_simple_query_2(self): qb2 = orm.QueryBuilder(**qh) resdict = qb2.dict() - self.assertEqual(len(resdict), 1) - self.assertTrue(isinstance(resdict[0]['n1']['ctime'], datetime)) + assert len(resdict) == 1 + assert isinstance(resdict[0]['n1']['ctime'], datetime) res_one = qb2.one() - self.assertTrue('bar' in res_one) + assert 'bar' in res_one qh = { 'path': [{ @@ -376,23 +374,23 @@ def test_simple_query_2(self): } } qb = orm.QueryBuilder(**qh) - self.assertEqual(qb.count(), 1) + assert qb.count() == 1 # Test the hashing: query1 = qb.get_query() qb.add_filter('n2', {'label': 'nonexistentlabel'}) - self.assertEqual(qb.count(), 0) + assert qb.count() == 0 - with self.assertRaises(NotExistent): + with pytest.raises(NotExistent): qb.one() - with self.assertRaises(MultipleObjectsError): + with pytest.raises(MultipleObjectsError): orm.QueryBuilder().append(orm.Node).one() query2 = qb.get_query() query3 = qb.get_query() - self.assertTrue(id(query1) != id(query2)) - self.assertTrue(id(query2) == id(query3)) + assert id(query1) != id(query2) + assert id(query2) == id(query3) def test_dict_multiple_projections(self): """Test that the `.dict()` accumulator with multiple projections returns the correct types.""" @@ -400,14 +398,14 @@ def test_dict_multiple_projections(self): builder = orm.QueryBuilder().append(orm.Data, project=['*', 'id']) results = builder.dict() - self.assertIsInstance(results, list) - self.assertTrue(all(isinstance(value, dict) for value in results)) + assert isinstance(results, list) + assert all(isinstance(value, dict) for value in results) dictionary = list(results[0].values())[0] # `results` should have the form [{'Data_1': {'*': Node, 'id': 1}}] - self.assertIsInstance(dictionary['*'], orm.Data) - self.assertEqual(dictionary['*'].pk, node.pk) - self.assertEqual(dictionary['id'], node.pk) + assert isinstance(dictionary['*'], orm.Data) + assert dictionary['*'].pk == node.pk + assert dictionary['id'] == node.pk def test_operators_eq_lt_gt(self): nodes = [orm.Data() for _ in range(8)] @@ -424,12 +422,12 @@ def test_operators_eq_lt_gt(self): for n in nodes: n.store() - self.assertEqual(orm.QueryBuilder().append(orm.Node, filters={'attributes.fa': {'<': 1}}).count(), 0) - self.assertEqual(orm.QueryBuilder().append(orm.Node, filters={'attributes.fa': {'==': 1}}).count(), 2) - self.assertEqual(orm.QueryBuilder().append(orm.Node, filters={'attributes.fa': {'<': 1.02}}).count(), 3) - self.assertEqual(orm.QueryBuilder().append(orm.Node, filters={'attributes.fa': {'<=': 1.02}}).count(), 4) - self.assertEqual(orm.QueryBuilder().append(orm.Node, filters={'attributes.fa': {'>': 1.02}}).count(), 4) - self.assertEqual(orm.QueryBuilder().append(orm.Node, filters={'attributes.fa': {'>=': 1.02}}).count(), 5) + assert orm.QueryBuilder().append(orm.Node, filters={'attributes.fa': {'<': 1}}).count() == 0 + assert orm.QueryBuilder().append(orm.Node, filters={'attributes.fa': {'==': 1}}).count() == 2 + assert orm.QueryBuilder().append(orm.Node, filters={'attributes.fa': {'<': 1.02}}).count() == 3 + assert orm.QueryBuilder().append(orm.Node, filters={'attributes.fa': {'<=': 1.02}}).count() == 4 + assert orm.QueryBuilder().append(orm.Node, filters={'attributes.fa': {'>': 1.02}}).count() == 4 + assert orm.QueryBuilder().append(orm.Node, filters={'attributes.fa': {'>=': 1.02}}).count() == 5 def test_subclassing(self): s = orm.StructureData() @@ -445,106 +443,106 @@ def test_subclassing(self): # Now when asking for a node with attr.cat==miau, I want 3 esults: qb = orm.QueryBuilder().append(orm.Node, filters={'attributes.cat': 'miau'}) - self.assertEqual(qb.count(), 3) + assert qb.count() == 3 qb = orm.QueryBuilder().append(orm.Data, filters={'attributes.cat': 'miau'}) - self.assertEqual(qb.count(), 3) + assert qb.count() == 3 # If I'm asking for the specific lowest subclass, I want one result for cls in (orm.StructureData, orm.Dict): qb = orm.QueryBuilder().append(cls, filters={'attributes.cat': 'miau'}) - self.assertEqual(qb.count(), 1) + assert qb.count() == 1 # Now I am not allow the subclassing, which should give 1 result for each for cls, count in ((orm.StructureData, 1), (orm.Dict, 1), (orm.Data, 1), (orm.Node, 0)): qb = orm.QueryBuilder().append(cls, filters={'attributes.cat': 'miau'}, subclassing=False) - self.assertEqual(qb.count(), count) + assert qb.count() == count # Now I am testing the subclassing with tuples: qb = orm.QueryBuilder().append(cls=(orm.StructureData, orm.Dict), filters={'attributes.cat': 'miau'}) - self.assertEqual(qb.count(), 2) + assert qb.count() == 2 qb = orm.QueryBuilder().append( entity_type=('data.structure.StructureData.', 'data.dict.Dict.'), filters={'attributes.cat': 'miau'} ) - self.assertEqual(qb.count(), 2) + assert qb.count() == 2 qb = orm.QueryBuilder().append( cls=(orm.StructureData, orm.Dict), filters={'attributes.cat': 'miau'}, subclassing=False ) - self.assertEqual(qb.count(), 2) + assert qb.count() == 2 qb = orm.QueryBuilder().append( cls=(orm.StructureData, orm.Data), filters={'attributes.cat': 'miau'}, ) - self.assertEqual(qb.count(), 3) + assert qb.count() == 3 qb = orm.QueryBuilder().append( entity_type=('data.structure.StructureData.', 'data.dict.Dict.'), filters={'attributes.cat': 'miau'}, subclassing=False ) - self.assertEqual(qb.count(), 2) + assert qb.count() == 2 qb = orm.QueryBuilder().append( entity_type=('data.structure.StructureData.', 'data.Data.'), filters={'attributes.cat': 'miau'}, subclassing=False ) - self.assertEqual(qb.count(), 2) + assert qb.count() == 2 def test_list_behavior(self): for _i in range(4): orm.Data().store() - self.assertEqual(len(orm.QueryBuilder().append(orm.Node).all()), 4) - self.assertEqual(len(orm.QueryBuilder().append(orm.Node, project='*').all()), 4) - self.assertEqual(len(orm.QueryBuilder().append(orm.Node, project=['*', 'id']).all()), 4) - self.assertEqual(len(orm.QueryBuilder().append(orm.Node, project=['id']).all()), 4) - self.assertEqual(len(orm.QueryBuilder().append(orm.Node).dict()), 4) - self.assertEqual(len(orm.QueryBuilder().append(orm.Node, project='*').dict()), 4) - self.assertEqual(len(orm.QueryBuilder().append(orm.Node, project=['*', 'id']).dict()), 4) - self.assertEqual(len(orm.QueryBuilder().append(orm.Node, project=['id']).dict()), 4) - self.assertEqual(len(list(orm.QueryBuilder().append(orm.Node).iterall())), 4) - self.assertEqual(len(list(orm.QueryBuilder().append(orm.Node, project='*').iterall())), 4) - self.assertEqual(len(list(orm.QueryBuilder().append(orm.Node, project=['*', 'id']).iterall())), 4) - self.assertEqual(len(list(orm.QueryBuilder().append(orm.Node, project=['id']).iterall())), 4) - self.assertEqual(len(list(orm.QueryBuilder().append(orm.Node).iterdict())), 4) - self.assertEqual(len(list(orm.QueryBuilder().append(orm.Node, project='*').iterdict())), 4) - self.assertEqual(len(list(orm.QueryBuilder().append(orm.Node, project=['*', 'id']).iterdict())), 4) - self.assertEqual(len(list(orm.QueryBuilder().append(orm.Node, project=['id']).iterdict())), 4) + assert len(orm.QueryBuilder().append(orm.Node).all()) == 4 + assert len(orm.QueryBuilder().append(orm.Node, project='*').all()) == 4 + assert len(orm.QueryBuilder().append(orm.Node, project=['*', 'id']).all()) == 4 + assert len(orm.QueryBuilder().append(orm.Node, project=['id']).all()) == 4 + assert len(orm.QueryBuilder().append(orm.Node).dict()) == 4 + assert len(orm.QueryBuilder().append(orm.Node, project='*').dict()) == 4 + assert len(orm.QueryBuilder().append(orm.Node, project=['*', 'id']).dict()) == 4 + assert len(orm.QueryBuilder().append(orm.Node, project=['id']).dict()) == 4 + assert len(list(orm.QueryBuilder().append(orm.Node).iterall())) == 4 + assert len(list(orm.QueryBuilder().append(orm.Node, project='*').iterall())) == 4 + assert len(list(orm.QueryBuilder().append(orm.Node, project=['*', 'id']).iterall())) == 4 + assert len(list(orm.QueryBuilder().append(orm.Node, project=['id']).iterall())) == 4 + assert len(list(orm.QueryBuilder().append(orm.Node).iterdict())) == 4 + assert len(list(orm.QueryBuilder().append(orm.Node, project='*').iterdict())) == 4 + assert len(list(orm.QueryBuilder().append(orm.Node, project=['*', 'id']).iterdict())) == 4 + assert len(list(orm.QueryBuilder().append(orm.Node, project=['id']).iterdict())) == 4 def test_append_validation(self): # So here I am giving two times the same tag - with self.assertRaises(ValueError): + with pytest.raises(ValueError): orm.QueryBuilder().append(orm.StructureData, tag='n').append(orm.StructureData, tag='n') # here I am giving a wrong filter specifications - with self.assertRaises(TypeError): + with pytest.raises(TypeError): orm.QueryBuilder().append(orm.StructureData, filters=['jajjsd']) # here I am giving a nonsensical projection: - with self.assertRaises(ValueError): + with pytest.raises(ValueError): orm.QueryBuilder().append(orm.StructureData, project=True) # here I am giving a nonsensical projection for the edge: - with self.assertRaises(ValueError): + with pytest.raises(ValueError): orm.QueryBuilder().append(orm.ProcessNode).append(orm.StructureData, edge_tag='t').add_projection('t', True) # Giving a nonsensical limit - with self.assertRaises(TypeError): + with pytest.raises(TypeError): orm.QueryBuilder().append(orm.ProcessNode).limit(2.3) # Giving a nonsensical offset - with self.assertRaises(TypeError): + with pytest.raises(TypeError): orm.QueryBuilder(offset=2.3) # So, I mess up one append, I want the QueryBuilder to clean it! - with self.assertRaises(ValueError): + with pytest.raises(ValueError): qb = orm.QueryBuilder() # This also checks if we correctly raise for wrong keywords qb.append(orm.StructureData, tag='s', randomkeyword={}) # Now I'm checking whether this keyword appears anywhere in the internal dictionaries: # pylint: disable=protected-access - self.assertTrue('s' not in qb._projections) - self.assertTrue('s' not in qb._filters) - self.assertTrue('s' not in qb.tag_to_alias_map) - self.assertTrue(len(qb._path) == 0) - self.assertTrue(orm.StructureData not in qb._cls_to_tag_map) + assert 's' not in qb._projections + assert 's' not in qb._filters + assert 's' not in qb.tag_to_alias_map + assert len(qb._path) == 0 + assert orm.StructureData not in qb._cls_to_tag_map # So this should work now: qb.append(orm.StructureData, tag='s').limit(2).dict() @@ -554,43 +552,33 @@ def test_tags(self): qb.append(orm.Node, tag='n2', edge_tag='e1', with_incoming='n1') qb.append(orm.Node, tag='n3', edge_tag='e2', with_incoming='n2') qb.append(orm.Computer, with_node='n3', tag='c1', edge_tag='nonsense') - self.assertEqual(qb.get_used_tags(), ['n1', 'n2', 'e1', 'n3', 'e2', 'c1', 'nonsense']) + assert qb.get_used_tags() == ['n1', 'n2', 'e1', 'n3', 'e2', 'c1', 'nonsense'] # Now I am testing the default tags, qb = orm.QueryBuilder().append(orm.StructureData ).append(orm.ProcessNode ).append(orm.StructureData ).append(orm.Dict, with_outgoing=orm.ProcessNode) - self.assertEqual( - qb.get_used_tags(), [ - 'StructureData_1', 'ProcessNode_1', 'StructureData_1--ProcessNode_1', 'StructureData_2', - 'ProcessNode_1--StructureData_2', 'Dict_1', 'ProcessNode_1--Dict_1' - ] - ) - self.assertEqual( - qb.get_used_tags(edges=False), [ - 'StructureData_1', - 'ProcessNode_1', - 'StructureData_2', - 'Dict_1', - ] - ) - self.assertEqual( - qb.get_used_tags(vertices=False), + assert qb.get_used_tags() == [ + 'StructureData_1', 'ProcessNode_1', 'StructureData_1--ProcessNode_1', 'StructureData_2', + 'ProcessNode_1--StructureData_2', 'Dict_1', 'ProcessNode_1--Dict_1' + ] + assert qb.get_used_tags(edges=False) == [ + 'StructureData_1', + 'ProcessNode_1', + 'StructureData_2', + 'Dict_1', + ] + assert qb.get_used_tags(vertices=False) == \ ['StructureData_1--ProcessNode_1', 'ProcessNode_1--StructureData_2', 'ProcessNode_1--Dict_1'] - ) - self.assertEqual( - qb.get_used_tags(edges=False), [ - 'StructureData_1', - 'ProcessNode_1', - 'StructureData_2', - 'Dict_1', - ] - ) - self.assertEqual( - qb.get_used_tags(vertices=False), + assert qb.get_used_tags(edges=False) == [ + 'StructureData_1', + 'ProcessNode_1', + 'StructureData_2', + 'Dict_1', + ] + assert qb.get_used_tags(vertices=False) == \ ['StructureData_1--ProcessNode_1', 'ProcessNode_1--StructureData_2', 'ProcessNode_1--Dict_1'] - ) def test_direction_keyword(self): """ @@ -618,8 +606,8 @@ def test_direction_keyword(self): qb.append(orm.CalculationNode, with_incoming='data', project='id') res2 = {_ for _, in qb.all()} - self.assertEqual(res1, res2) - self.assertEqual(res1, {c1.id}) + assert res1 == res2 + assert res1 == {c1.id} # testing direction=-1, which should return the incoming qb = orm.QueryBuilder() @@ -631,8 +619,8 @@ def test_direction_keyword(self): qb.append(orm.Data, filters={'id': d2.id}, tag='data') qb.append(orm.CalculationNode, with_outgoing='data', project='id') res2 = {_ for _, in qb.all()} - self.assertEqual(res1, res2) - self.assertEqual(res1, {c1.id}) + assert res1 == res2 + assert res1 == {c1.id} # testing direction higher than 1 qb = orm.QueryBuilder() @@ -643,18 +631,16 @@ def test_direction_keyword(self): qh = qb.queryhelp # saving query for later qb.append(orm.Data, direction=-4, project='id') res1 = {item[1] for item in qb.all()} - self.assertEqual(res1, {d1.id}) + assert res1 == {d1.id} qb = orm.QueryBuilder(**qh) qb.append(orm.Data, direction=4, project='id') res2 = {item[1] for item in qb.all()} - self.assertEqual(res2, {d2.id, d4.id}) + assert res2 == {d2.id, d4.id} @staticmethod def test_flat(): """Test the `flat` keyword for the `QueryBuilder.all()` method.""" - from itertools import chain - pks = [] uuids = [] for _ in range(10): @@ -677,7 +663,8 @@ def test_flat(): assert result == list(chain.from_iterable(zip(pks, uuids))) -class TestMultipleProjections(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test') +class TestMultipleProjections: """Unit tests for the QueryBuilder ORM class.""" def test_first_multiple_projections(self): @@ -688,13 +675,17 @@ def test_first_multiple_projections(self): result = orm.QueryBuilder().append(orm.User, tag='user', project=['email']).append(orm.Data, with_user='user', project=['*']).first() - self.assertEqual(type(result), list) - self.assertEqual(len(result), 2) - self.assertIsInstance(result[0], str) - self.assertIsInstance(result[1], orm.Data) + assert isinstance(result, list) + assert len(result) == 2 + assert isinstance(result[0], str) + assert isinstance(result[1], orm.Data) + +class TestQueryHelp: -class TestQueryHelp(AiidaTestCase): + @pytest.fixture(autouse=True) + def init_db(self, clear_database_before_test, aiida_localhost): + self.computer = aiida_localhost def test_queryhelp(self): """ @@ -721,26 +712,26 @@ def test_queryhelp(self): ): qb = orm.QueryBuilder() qb.append(cls, filters={'attributes.foo-qh2': 'bar'}, subclassing=subclassing, project='uuid') - self.assertEqual(qb.count(), expected_count) + assert qb.count() == expected_count qh = qb.queryhelp qb_new = orm.QueryBuilder(**qh) - self.assertEqual(qb_new.count(), expected_count) - self.assertEqual(sorted([uuid for uuid, in qb.all()]), sorted([uuid for uuid, in qb_new.all()])) + assert qb_new.count() == expected_count + assert sorted([uuid for uuid, in qb.all()]) == sorted([uuid for uuid, in qb_new.all()]) qb = orm.QueryBuilder().append(orm.Group, filters={'label': 'helloworld'}) - self.assertEqual(qb.count(), 1) + assert qb.count() == 1 qb = orm.QueryBuilder().append((orm.Group,), filters={'label': 'helloworld'}) - self.assertEqual(qb.count(), 1) + assert qb.count() == 1 # populate computer self.computer # pylint:disable=pointless-statement qb = orm.QueryBuilder().append(orm.Computer,) - self.assertEqual(qb.count(), 1) + assert qb.count() == 1 qb = orm.QueryBuilder().append(cls=(orm.Computer,)) - self.assertEqual(qb.count(), 1) + assert qb.count() == 1 def test_recreate_from_queryhelp(self): """Test recreating a QueryBuilder from the Query Help @@ -757,13 +748,14 @@ def test_recreate_from_queryhelp(self): qb1.append(orm.CalcJobNode) qb2 = orm.QueryBuilder(**qb1.queryhelp) - self.assertDictEqual(qb1.queryhelp, qb2.queryhelp) + assert qb1.queryhelp == qb2.queryhelp qb3 = copy.deepcopy(qb1) - self.assertDictEqual(qb1.queryhelp, qb3.queryhelp) + assert qb1.queryhelp == qb3.queryhelp -class TestQueryBuilderCornerCases(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test') +class TestQueryBuilderCornerCases: """ In this class corner cases of QueryBuilder are added. """ @@ -789,7 +781,8 @@ def test_computer_json(self): # pylint: disable=no-self-use qb.all() -class TestAttributes(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test') +class TestAttributes: def test_attribute_existence(self): # I'm storing a value under key whatever: @@ -819,7 +812,7 @@ def test_attribute_existence(self): project='uuid' ) res_query = {str(_[0]) for _ in qb.all()} - self.assertEqual(res_query, res_uuids) + assert res_query == res_uuids def test_attribute_type(self): key = 'value_test_attr_type' @@ -839,30 +832,30 @@ def test_attribute_type(self): for val in (1.0, 1): qb = orm.QueryBuilder().append(orm.Node, filters={f'attributes.{key}': val}, project='uuid') res = [str(_) for _, in qb.all()] - self.assertEqual(set(res), set((n_float.uuid, n_int.uuid))) + assert set(res) == set((n_float.uuid, n_int.uuid)) qb = orm.QueryBuilder().append(orm.Node, filters={f'attributes.{key}': {'>': 0.5}}, project='uuid') res = [str(_) for _, in qb.all()] - self.assertEqual(set(res), set((n_float.uuid, n_int.uuid))) + assert set(res) == set((n_float.uuid, n_int.uuid)) qb = orm.QueryBuilder().append(orm.Node, filters={f'attributes.{key}': {'<': 1.5}}, project='uuid') res = [str(_) for _, in qb.all()] - self.assertEqual(set(res), set((n_float.uuid, n_int.uuid))) + assert set(res) == set((n_float.uuid, n_int.uuid)) # Now I am testing the boolean value: qb = orm.QueryBuilder().append(orm.Node, filters={f'attributes.{key}': True}, project='uuid') res = [str(_) for _, in qb.all()] - self.assertEqual(set(res), set((n_bool.uuid,))) + assert set(res) == set((n_bool.uuid,)) qb = orm.QueryBuilder().append(orm.Node, filters={f'attributes.{key}': {'like': '%n%'}}, project='uuid') res = [str(_) for _, in qb.all()] - self.assertEqual(set(res), set((n_str2.uuid,))) + assert set(res) == set((n_str2.uuid,)) qb = orm.QueryBuilder().append(orm.Node, filters={f'attributes.{key}': {'ilike': 'On%'}}, project='uuid') res = [str(_) for _, in qb.all()] - self.assertEqual(set(res), set((n_str2.uuid,))) + assert set(res) == set((n_str2.uuid,)) qb = orm.QueryBuilder().append(orm.Node, filters={f'attributes.{key}': {'like': '1'}}, project='uuid') res = [str(_) for _, in qb.all()] - self.assertEqual(set(res), set((n_str.uuid,))) + assert set(res) == set((n_str.uuid,)) qb = orm.QueryBuilder().append(orm.Node, filters={f'attributes.{key}': {'==': '1'}}, project='uuid') res = [str(_) for _, in qb.all()] - self.assertEqual(set(res), set((n_str.uuid,))) + assert set(res) == set((n_str.uuid,)) if configuration.PROFILE.database_backend == 'sqlalchemy': # I can't query the length of an array with Django, # so I exclude. Not the nicest way, But I would like to keep this piece @@ -870,10 +863,11 @@ def test_attribute_type(self): # duplicated or wrapped otherwise. qb = orm.QueryBuilder().append(orm.Node, filters={f'attributes.{key}': {'of_length': 3}}, project='uuid') res = [str(_) for _, in qb.all()] - self.assertEqual(set(res), set((n_arr.uuid,))) + assert set(res) == set((n_arr.uuid,)) -class QueryBuilderLimitOffsetsTest(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test') +class QueryBuilderLimitOffsetsTest: def test_ordering_limits_offsets_of_results_general(self): # Creating 10 nodes with an attribute that can be ordered @@ -885,17 +879,17 @@ def test_ordering_limits_offsets_of_results_general(self): qb = orm.QueryBuilder().append(orm.Node, project='attributes.foo').order_by({orm.Node: 'ctime'}) res = next(zip(*qb.all())) - self.assertEqual(res, tuple(range(10))) + assert res == tuple(range(10)) # Now applying an offset: qb.offset(5) res = next(zip(*qb.all())) - self.assertEqual(res, tuple(range(5, 10))) + assert res == tuple(range(5, 10)) # Now also applying a limit: qb.limit(3) res = next(zip(*qb.all())) - self.assertEqual(res, tuple(range(5, 8))) + assert res == tuple(range(5, 8)) # Specifying the order explicitly the order: qb = orm.QueryBuilder().append(orm.Node, @@ -906,17 +900,17 @@ def test_ordering_limits_offsets_of_results_general(self): }}) res = next(zip(*qb.all())) - self.assertEqual(res, tuple(range(10))) + assert res == tuple(range(10)) # Now applying an offset: qb.offset(5) res = next(zip(*qb.all())) - self.assertEqual(res, tuple(range(5, 10))) + assert res == tuple(range(5, 10)) # Now also applying a limit: qb.limit(3) res = next(zip(*qb.all())) - self.assertEqual(res, tuple(range(5, 8))) + assert res == tuple(range(5, 8)) # Reversing the order: qb = orm.QueryBuilder().append(orm.Node, @@ -927,20 +921,21 @@ def test_ordering_limits_offsets_of_results_general(self): }}) res = next(zip(*qb.all())) - self.assertEqual(res, tuple(range(9, -1, -1))) + assert res == tuple(range(9, -1, -1)) # Now applying an offset: qb.offset(5) res = next(zip(*qb.all())) - self.assertEqual(res, tuple(range(4, -1, -1))) + assert res == tuple(range(4, -1, -1)) # Now also applying a limit: qb.limit(3) res = next(zip(*qb.all())) - self.assertEqual(res, tuple(range(4, 1, -1))) + assert res == tuple(range(4, 1, -1)) -class QueryBuilderJoinsTests(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test') +class QueryBuilderJoinsTests: def test_joins1(self): # Creating n1, who will be a parent: @@ -969,12 +964,12 @@ def test_joins1(self): qb = orm.QueryBuilder() qb.append(orm.Node, tag='parent') qb.append(orm.Node, tag='children', project='label', filters={'attributes.is_good': True}) - self.assertEqual(qb.count(), 1) + assert qb.count() == 1 qb = orm.QueryBuilder() qb.append(orm.Node, tag='parent') qb.append(orm.Node, tag='children', outerjoin=True, project='label', filters={'attributes.is_good': True}) - self.assertEqual(qb.count(), 1) + assert qb.count() == 1 def test_joins2(self): # Creating n1, who will be a parent: @@ -1003,26 +998,22 @@ def test_joins2(self): # let's add a differnt relationship than advisor: students[9].add_incoming(advisors[2], link_type=LinkType.CREATE, link_label='lover') - self.assertEqual( - orm.QueryBuilder().append( - orm.Node - ).append(orm.Node, edge_filters={ + assert orm.QueryBuilder().append( + orm.Node + ).append(orm.Node, edge_filters={ + 'label': { + 'like': 'is\\_advisor\\_%' + } + }, tag='student').count() == 7 + + for adv_id, number_students in zip(list(range(3)), (2, 2, 3)): + assert orm.QueryBuilder().append(orm.Node, filters={ + 'attributes.advisor_id': adv_id + }).append(orm.Node, edge_filters={ 'label': { 'like': 'is\\_advisor\\_%' } - }, tag='student').count(), 7 - ) - - for adv_id, number_students in zip(list(range(3)), (2, 2, 3)): - self.assertEqual( - orm.QueryBuilder().append(orm.Node, filters={ - 'attributes.advisor_id': adv_id - }).append(orm.Node, edge_filters={ - 'label': { - 'like': 'is\\_advisor\\_%' - } - }, tag='student').count(), number_students - ) + }, tag='student').count() == number_students def test_joins3_user_group(self): # Create another user @@ -1038,14 +1029,14 @@ def test_joins3_user_group(self): qb = orm.QueryBuilder() qb.append(orm.User, tag='user', filters={'id': {'==': user.id}}) qb.append(orm.Group, with_user='user', filters={'id': {'==': group.id}}) - self.assertEqual(qb.count(), 1, 'The expected group that belongs to the selected user was not found.') + assert qb.count() == 1, 'The expected group that belongs to the selected user was not found.' # Search for the user that owns a group qb = orm.QueryBuilder() qb.append(orm.Group, tag='group', filters={'id': {'==': group.id}}) qb.append(orm.User, with_group='group', filters={'id': {'==': user.id}}) - self.assertEqual(qb.count(), 1, 'The expected user that owns the selected group was not found.') + assert qb.count() == 1, 'The expected user that owns the selected group was not found.' def test_joins_group_node(self): """ @@ -1089,13 +1080,17 @@ def test_joins_group_node(self): qb = orm.QueryBuilder() qb.append(orm.Node, tag='node', project=['id']) qb.append(orm.Group, with_node='node', filters={'id': {'==': group.id}}) - self.assertEqual(qb.count(), 4, 'There should be 4 nodes in the group') + assert qb.count() == 4, 'There should be 4 nodes in the group' id_res = [_ for [_] in qb.all()] for curr_id in [n1.id, n2.id, n3.id, n4.id]: - self.assertIn(curr_id, id_res) + assert curr_id in id_res -class QueryBuilderPath(AiidaTestCase): +class QueryBuilderPath: + + @pytest.fixture(autouse=True) + def init_db(self, clear_database_before_test, backend): + self.backend = backend def test_query_path(self): # pylint: disable=too-many-statements @@ -1135,95 +1130,81 @@ def test_query_path(self): node.store() # There are no parents to n9, checking that - self.assertEqual(set([]), set(q.get_all_parents([n9.pk]))) + assert set([]) == set(q.get_all_parents([n9.pk])) # There is one parent to n6 - self.assertEqual({(_,) for _ in (n6.pk,)}, {tuple(_) for _ in q.get_all_parents([n7.pk])}) + assert {(_,) for _ in (n6.pk,)} == {tuple(_) for _ in q.get_all_parents([n7.pk])} # There are several parents to n4 - self.assertEqual({(_.pk,) for _ in (n1, n2)}, {tuple(_) for _ in q.get_all_parents([n4.pk])}) + assert {(_.pk,) for _ in (n1, n2)} == {tuple(_) for _ in q.get_all_parents([n4.pk])} # There are several parents to n5 - self.assertEqual({(_.pk,) for _ in (n1, n2, n3, n4)}, {tuple(_) for _ in q.get_all_parents([n5.pk])}) + assert {(_.pk,) for _ in (n1, n2, n3, n4)} == {tuple(_) for _ in q.get_all_parents([n5.pk])} # Yet, no links from 1 to 8 - self.assertEqual( - orm.QueryBuilder().append(orm.Node, filters={ - 'id': n1.pk - }, tag='anc').append(orm.Node, with_ancestors='anc', filters={ - 'id': n8.pk - }).count(), 0 - ) + assert orm.QueryBuilder().append(orm.Node, filters={ + 'id': n1.pk + }, tag='anc').append(orm.Node, with_ancestors='anc', filters={ + 'id': n8.pk + }).count() == 0 - self.assertEqual( - orm.QueryBuilder().append(orm.Node, filters={ - 'id': n8.pk - }, tag='desc').append(orm.Node, with_descendants='desc', filters={ - 'id': n1.pk - }).count(), 0 - ) + assert orm.QueryBuilder().append(orm.Node, filters={ + 'id': n8.pk + }, tag='desc').append(orm.Node, with_descendants='desc', filters={ + 'id': n1.pk + }).count() == 0 n6.add_incoming(n5, link_type=LinkType.CREATE, link_label='link1') # Yet, now 2 links from 1 to 8 - self.assertEqual( - orm.QueryBuilder().append(orm.Node, filters={ - 'id': n1.pk - }, tag='anc').append(orm.Node, with_ancestors='anc', filters={ - 'id': n8.pk - }).count(), 2 - ) + assert orm.QueryBuilder().append(orm.Node, filters={ + 'id': n1.pk + }, tag='anc').append(orm.Node, with_ancestors='anc', filters={ + 'id': n8.pk + }).count() == 2 - self.assertEqual( - orm.QueryBuilder().append(orm.Node, filters={ - 'id': n8.pk - }, tag='desc').append(orm.Node, with_descendants='desc', filters={ - 'id': n1.pk - }).count(), 2 - ) + assert orm.QueryBuilder().append(orm.Node, filters={ + 'id': n8.pk + }, tag='desc').append(orm.Node, with_descendants='desc', filters={ + 'id': n1.pk + }).count() == 2 - self.assertEqual( - orm.QueryBuilder().append(orm.Node, filters={ - 'id': n8.pk - }, tag='desc').append( - orm.Node, - with_descendants='desc', - filters={ - 'id': n1.pk - }, - edge_filters={ - 'depth': { - '<': 6 - } - }, - ).count(), 2 - ) - self.assertEqual( - orm.QueryBuilder().append(orm.Node, filters={ - 'id': n8.pk - }, tag='desc').append( - orm.Node, - with_descendants='desc', - filters={ - 'id': n1.pk - }, - edge_filters={ - 'depth': 5 - }, - ).count(), 2 - ) - self.assertEqual( - orm.QueryBuilder().append(orm.Node, filters={ - 'id': n8.pk - }, tag='desc').append( - orm.Node, - with_descendants='desc', - filters={ - 'id': n1.pk - }, - edge_filters={ - 'depth': { - '<': 5 - } - }, - ).count(), 0 - ) + assert orm.QueryBuilder().append(orm.Node, filters={ + 'id': n8.pk + }, tag='desc').append( + orm.Node, + with_descendants='desc', + filters={ + 'id': n1.pk + }, + edge_filters={ + 'depth': { + '<': 6 + } + }, + ).count() == 2 + assert orm.QueryBuilder().append(orm.Node, filters={ + 'id': n8.pk + }, tag='desc').append( + orm.Node, + with_descendants='desc', + filters={ + 'id': n1.pk + }, + edge_filters={ + 'depth': 5 + }, + ).count() == 2 + assert orm.QueryBuilder().append(orm.Node, filters={ + 'id': n8.pk + }, tag='desc').append( + orm.Node, + with_descendants='desc', + filters={ + 'id': n1.pk + }, + edge_filters={ + 'depth': { + '<': 5 + } + }, + ).count() == 0 # TODO write a query that can filter certain paths by traversed ID # pylint: disable=fixme qb = orm.QueryBuilder().append( @@ -1240,16 +1221,16 @@ def test_query_path(self): frozenset([n1.pk, n2.pk, n4.pk, n5.pk, n6.pk, n7.pk, n8.pk]) } - self.assertTrue(queried_path_set == paths_there_should_be) + assert queried_path_set == paths_there_should_be qb = orm.QueryBuilder().append(orm.Node, filters={ 'id': n1.pk }, tag='anc').append(orm.Node, with_ancestors='anc', filters={'id': n8.pk}, edge_project='path') - self.assertEqual({frozenset(p) for p, in qb.all()}, { + assert {frozenset(p) for p, in qb.all()} == { frozenset([n1.pk, n2.pk, n3.pk, n5.pk, n6.pk, n7.pk, n8.pk]), frozenset([n1.pk, n2.pk, n4.pk, n5.pk, n6.pk, n7.pk, n8.pk]) - }) + } # This part of the test is no longer possible as the nodes have already been stored and the previous parts of # the test rely on this, which means however, that here, no more links can be added as that will raise. @@ -1298,7 +1279,8 @@ def test_query_path(self): # self.assertTrue(set(next(zip(*qb.all()))), set([5])) -class TestConsistency(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test') +class TestConsistency: def test_create_node_and_query(self): """ @@ -1314,8 +1296,8 @@ def test_create_node_and_query(self): if idx % 10 == 10: n = orm.Data() n.store() - self.assertEqual(idx, 99) # pylint: disable=undefined-loop-variable - self.assertTrue(len(orm.QueryBuilder().append(orm.Node, project=['id', 'label']).all(batch_size=10)) > 99) + assert idx == 99 # pylint: disable=undefined-loop-variable + assert len(orm.QueryBuilder().append(orm.Node, project=['id', 'label']).all(batch_size=10)) > 99 def test_len_results(self): """ @@ -1332,10 +1314,14 @@ def test_len_results(self): qb = orm.QueryBuilder() qb.append(orm.CalculationNode, filters={'id': parent.id}, tag='parent', project=projection) qb.append(orm.Data, with_incoming='parent') - self.assertEqual(len(qb.all()), qb.count()) + assert len(qb.all()) == qb.count() -class TestManager(AiidaTestCase): +class TestManager: + + @pytest.fixture(autouse=True) + def init_db(self, clear_database_before_test, backend): + self.backend = backend def test_statistics(self): """ @@ -1343,7 +1329,6 @@ def test_statistics(self): I try to implement it in a way that does not depend on the past state. """ - from collections import defaultdict # pylint: disable=protected-access @@ -1375,7 +1360,7 @@ def store_and_add(n, statistics): k: dict(v) if isinstance(v, defaultdict) else v for k, v in expected_db_statistics.items() } - self.assertEqual(new_db_statistics, expected_db_statistics) + assert new_db_statistics == expected_db_statistics def test_statistics_default_class(self): """ @@ -1383,7 +1368,6 @@ def test_statistics_default_class(self): I try to implement it in a way that does not depend on the past state. """ - from collections import defaultdict def store_and_add(n, statistics): n.store() @@ -1412,15 +1396,19 @@ def store_and_add(n, statistics): k: dict(v) if isinstance(v, defaultdict) else v for k, v in expected_db_statistics.items() } - self.assertEqual(new_db_statistics, expected_db_statistics) + assert new_db_statistics == expected_db_statistics -class TestDoubleStar(AiidaTestCase): +class TestDoubleStar: """ In this test class we check if QueryBuilder returns the correct results when double star is provided as projection. """ + @pytest.fixture(autouse=True) + def init_db(self, clear_database_before_test, aiida_localhost): + self.computer = aiida_localhost + def test_statistics_default_class(self): # The expected result @@ -1439,20 +1427,20 @@ def test_statistics_default_class(self): qb = orm.QueryBuilder() qb.append(orm.Computer, project=['**']) # We expect one result - self.assertEqual(qb.count(), 1) + assert qb.count() == 1 # Get the one result record and check that the returned # data are correct res = list(qb.dict()[0].values())[0] - self.assertDictEqual(res, expected_dict) + assert res == expected_dict # Ask the same query as above using queryhelp qh = {'project': {'computer': ['**']}, 'path': [{'tag': 'computer', 'cls': orm.Computer}]} qb = orm.QueryBuilder(**qh) # We expect one result - self.assertEqual(qb.count(), 1) + assert qb.count() == 1 # Get the one result record and check that the returned # data are correct res = list(qb.dict()[0].values())[0] - self.assertDictEqual(res, expected_dict) + assert res == expected_dict