diff --git a/aiida/manage/tests/pytest_fixtures.py b/aiida/manage/tests/pytest_fixtures.py index 08c203b358..586d0cdac2 100644 --- a/aiida/manage/tests/pytest_fixtures.py +++ b/aiida/manage/tests/pytest_fixtures.py @@ -70,6 +70,13 @@ def clear_database_before_test(aiida_profile): yield +@pytest.fixture(scope='class') +def clear_database_before_test_class(aiida_profile): + """Clear the database before a test class.""" + aiida_profile.reset_db() + yield + + @pytest.fixture(scope='function') def temporary_event_loop(): """Create a temporary loop for independent test case""" diff --git a/aiida/orm/nodes/data/numeric.py b/aiida/orm/nodes/data/numeric.py index 6e34f812d7..702801b7d2 100644 --- a/aiida/orm/nodes/data/numeric.py +++ b/aiida/orm/nodes/data/numeric.py @@ -124,3 +124,6 @@ def __float__(self): def __int__(self): return int(self.value) + + def __abs__(self): + return abs(self.value) diff --git a/tests/engine/test_work_chain.py b/tests/engine/test_work_chain.py index 49166c9872..0d550439b0 100644 --- a/tests/engine/test_work_chain.py +++ b/tests/engine/test_work_chain.py @@ -10,14 +10,12 @@ # pylint: disable=too-many-lines,missing-function-docstring,invalid-name,missing-class-docstring,no-self-use """Tests for the `WorkChain` class.""" import inspect -import unittest import asyncio import plumpy import pytest from aiida import orm -from aiida.backends.testbase import AiidaTestCase from aiida.common import exceptions from aiida.common.links import LinkType from aiida.common.utils import Capturing @@ -188,7 +186,8 @@ def success(self): @pytest.mark.requires_rmq -class TestExitStatus(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test_class') +class TestExitStatus: """ This class should test the various ways that one can exit from the outline flow of a WorkChain, other than it running it all the way through. Currently this can be done directly in the outline by calling the `return_` @@ -197,53 +196,49 @@ class TestExitStatus(AiidaTestCase): def test_failing_workchain_through_integer(self): _, node = launch.run.get_node(PotentialFailureWorkChain, success=Bool(False)) - self.assertEqual(node.exit_status, PotentialFailureWorkChain.EXIT_STATUS) - self.assertEqual(node.exit_message, None) - self.assertEqual(node.is_finished, True) - self.assertEqual(node.is_finished_ok, False) - self.assertEqual(node.is_failed, True) - self.assertNotIn(PotentialFailureWorkChain.OUTPUT_LABEL, node.get_outgoing().all_link_labels()) + assert node.exit_status == PotentialFailureWorkChain.EXIT_STATUS + assert node.exit_message is None + assert node.is_finished is True + assert node.is_finished_ok is False + assert node.is_failed is True + assert PotentialFailureWorkChain.OUTPUT_LABEL not in node.get_outgoing().all_link_labels() def test_failing_workchain_through_exit_code(self): _, node = launch.run.get_node(PotentialFailureWorkChain, success=Bool(False), through_exit_code=Bool(True)) - self.assertEqual(node.exit_status, PotentialFailureWorkChain.EXIT_STATUS) - self.assertEqual(node.exit_message, PotentialFailureWorkChain.EXIT_MESSAGE) - self.assertEqual(node.is_finished, True) - self.assertEqual(node.is_finished_ok, False) - self.assertEqual(node.is_failed, True) - self.assertNotIn(PotentialFailureWorkChain.OUTPUT_LABEL, node.get_outgoing().all_link_labels()) + assert node.exit_status == PotentialFailureWorkChain.EXIT_STATUS + assert node.exit_message == PotentialFailureWorkChain.EXIT_MESSAGE + assert node.is_finished is True + assert node.is_finished_ok is False + assert node.is_failed is True + assert PotentialFailureWorkChain.OUTPUT_LABEL not in node.get_outgoing().all_link_labels() def test_successful_workchain_through_integer(self): _, node = launch.run.get_node(PotentialFailureWorkChain, success=Bool(True)) - self.assertEqual(node.exit_status, 0) - self.assertEqual(node.is_finished, True) - self.assertEqual(node.is_finished_ok, True) - self.assertEqual(node.is_failed, False) - self.assertIn(PotentialFailureWorkChain.OUTPUT_LABEL, node.get_outgoing().all_link_labels()) - self.assertEqual( - node.get_outgoing().get_node_by_label(PotentialFailureWorkChain.OUTPUT_LABEL), + assert node.exit_status == 0 + assert node.is_finished is True + assert node.is_finished_ok is True + assert node.is_failed is False + assert PotentialFailureWorkChain.OUTPUT_LABEL in node.get_outgoing().all_link_labels() + assert node.get_outgoing().get_node_by_label(PotentialFailureWorkChain.OUTPUT_LABEL) == \ PotentialFailureWorkChain.OUTPUT_VALUE - ) def test_successful_workchain_through_exit_code(self): _, node = launch.run.get_node(PotentialFailureWorkChain, success=Bool(True), through_exit_code=Bool(True)) - self.assertEqual(node.exit_status, 0) - self.assertEqual(node.is_finished, True) - self.assertEqual(node.is_finished_ok, True) - self.assertEqual(node.is_failed, False) - self.assertIn(PotentialFailureWorkChain.OUTPUT_LABEL, node.get_outgoing().all_link_labels()) - self.assertEqual( - node.get_outgoing().get_node_by_label(PotentialFailureWorkChain.OUTPUT_LABEL), + assert node.exit_status == 0 + assert node.is_finished is True + assert node.is_finished_ok is True + assert node.is_failed is False + assert PotentialFailureWorkChain.OUTPUT_LABEL in node.get_outgoing().all_link_labels() + assert node.get_outgoing().get_node_by_label(PotentialFailureWorkChain.OUTPUT_LABEL) == \ PotentialFailureWorkChain.OUTPUT_VALUE - ) def test_return_out_of_outline(self): _, node = launch.run.get_node(PotentialFailureWorkChain, success=Bool(True), through_return=Bool(True)) - self.assertEqual(node.exit_status, PotentialFailureWorkChain.EXIT_STATUS) - self.assertEqual(node.is_finished, True) - self.assertEqual(node.is_finished_ok, False) - self.assertEqual(node.is_failed, True) - self.assertNotIn(PotentialFailureWorkChain.OUTPUT_LABEL, node.get_outgoing().all_link_labels()) + assert node.exit_status == PotentialFailureWorkChain.EXIT_STATUS + assert node.is_finished is True + assert node.is_finished_ok is False + assert node.is_failed is True + assert PotentialFailureWorkChain.OUTPUT_LABEL not in node.get_outgoing().all_link_labels() class IfTest(WorkChain): @@ -270,55 +265,54 @@ def step2(self): @pytest.mark.requires_rmq -class TestContext(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test_class') +class TestContext: def test_attributes(self): wc = IfTest() wc.ctx.new_attr = 5 - self.assertEqual(wc.ctx.new_attr, 5) + assert wc.ctx.new_attr == 5 del wc.ctx.new_attr - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): wc.ctx.new_attr # pylint: disable=pointless-statement def test_dict(self): wc = IfTest() wc.ctx['new_attr'] = 5 - self.assertEqual(wc.ctx['new_attr'], 5) + assert wc.ctx['new_attr'] == 5 del wc.ctx['new_attr'] - with self.assertRaises(KeyError): + with pytest.raises(KeyError): wc.ctx['new_attr'] # pylint: disable=pointless-statement @pytest.mark.requires_rmq -class TestWorkchain(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test_class') +class TestWorkchain: # pylint: disable=too-many-public-methods - def setUp(self): - super().setUp() - self.assertIsNone(Process.current()) - - def tearDown(self): - super().tearDown() - self.assertIsNone(Process.current()) + def setup_method(self): + assert Process.current() is None + yield + assert Process.current() is None def test_run_base_class(self): """Verify that it is impossible to run, submit or instantiate a base `WorkChain` class.""" - with self.assertRaises(exceptions.InvalidOperation): + with pytest.raises(exceptions.InvalidOperation): WorkChain() - with self.assertRaises(exceptions.InvalidOperation): + with pytest.raises(exceptions.InvalidOperation): launch.run(WorkChain) - with self.assertRaises(exceptions.InvalidOperation): + with pytest.raises(exceptions.InvalidOperation): launch.run.get_node(WorkChain) - with self.assertRaises(exceptions.InvalidOperation): + with pytest.raises(exceptions.InvalidOperation): launch.run.get_pk(WorkChain) - with self.assertRaises(exceptions.InvalidOperation): + with pytest.raises(exceptions.InvalidOperation): launch.submit(WorkChain) def test_run(self): @@ -332,21 +326,21 @@ def test_run(self): # Check the steps that should have been run for step, finished in Wf.finished_steps.items(): if step not in ['step3', 'step4', 'is_b']: - self.assertTrue(finished, f'Step {step} was not called by workflow') + assert finished, f'Step {step} was not called by workflow' # Try the elif(..) part finished_steps = launch.run(Wf, value=B, n=three) # Check the steps that should have been run for step, finished in finished_steps.items(): if step not in ['is_a', 'step2', 'step4']: - self.assertTrue(finished, f'Step {step} was not called by workflow') + assert finished, f'Step {step} was not called by workflow' # Try the else... part finished_steps = launch.run(Wf, value=C, n=three) # Check the steps that should have been run for step, finished in finished_steps.items(): if step not in ['is_a', 'step2', 'is_b', 'step3']: - self.assertTrue(finished, f'Step {step} was not called by workflow') + assert finished, f'Step {step} was not called by workflow' def test_incorrect_outline(self): @@ -358,7 +352,7 @@ def define(cls, spec): # Try defining an invalid outline spec.outline(5) - with self.assertRaises(TypeError): + with pytest.raises(TypeError): IncorrectOutline.spec() def test_define_not_calling_super(self): @@ -370,7 +364,7 @@ class IncompleteDefineWorkChain(WorkChain): def define(cls, spec): pass - with self.assertRaises(AssertionError): + with pytest.raises(AssertionError): launch.run(IncompleteDefineWorkChain) def test_out_unstored(self): @@ -390,7 +384,7 @@ def define(cls, spec): def illegal(self): self.out('not_allowed', orm.Int(2)) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): launch.run(IllegalWorkChain) def test_same_input_node(self): @@ -416,8 +410,6 @@ def test_context(self): A = Str('a').store() B = Str('b').store() - test_case = self - class ReturnA(WorkChain): @classmethod @@ -451,15 +443,15 @@ def s1(self): return ToContext(r1=self.submit(ReturnA), r2=self.submit(ReturnB)) def s2(self): - test_case.assertEqual(self.ctx.r1.outputs.res, A) - test_case.assertEqual(self.ctx.r2.outputs.res, B) + assert self.ctx.r1.outputs.res == A + assert self.ctx.r2.outputs.res == B # Try overwriting r1 return ToContext(r1=self.submit(ReturnB)) def s3(self): - test_case.assertEqual(self.ctx.r1.outputs.res, B) - test_case.assertEqual(self.ctx.r2.outputs.res, B) + assert self.ctx.r1.outputs.res == B + assert self.ctx.r2.outputs.res == B run_and_check_success(OverrideContextWorkChain) @@ -482,7 +474,7 @@ def read_context(self): run_and_check_success(TestWorkChain) def test_str(self): - self.assertIsInstance(str(Wf.spec()), str) + assert isinstance(str(Wf.spec()), str) def test_malformed_outline(self): """ @@ -492,11 +484,11 @@ def test_malformed_outline(self): spec = WorkChainSpec() - with self.assertRaises(TypeError): + with pytest.raises(TypeError): spec.outline(5) # Test a function with wrong number of args - with self.assertRaises(TypeError): + with pytest.raises(TypeError): spec.outline(lambda x, y: None) def test_checkpointing(self): @@ -510,21 +502,21 @@ def test_checkpointing(self): # Check the steps that should have been run for step, finished in finished_steps.items(): if step not in ['step3', 'step4', 'is_b']: - self.assertTrue(finished, f'Step {step} was not called by workflow') + assert finished, f'Step {step} was not called by workflow' # Try the elif(..) part finished_steps = self._run_with_checkpoints(Wf, inputs={'value': B, 'n': three}) # Check the steps that should have been run for step, finished in finished_steps.items(): if step not in ['is_a', 'step2', 'step4']: - self.assertTrue(finished, f'Step {step} was not called by workflow') + assert finished, f'Step {step} was not called by workflow' # Try the else... part finished_steps = self._run_with_checkpoints(Wf, inputs={'value': C, 'n': three}) # Check the steps that should have been run for step, finished in finished_steps.items(): if step not in ['is_a', 'step2', 'is_b', 'step3']: - self.assertTrue(finished, f'Step {step} was not called by workflow') + assert finished, f'Step {step} was not called by workflow' def test_return(self): @@ -580,11 +572,11 @@ class SubWorkChain(WorkChain): # Verify that the `CALL` link of the calculation function is there with the correct label link_triple = process.node.get_outgoing(link_type=LinkType.CALL_CALC, link_label_filter=label_calcfunction).one() - self.assertIsInstance(link_triple.node, orm.CalcFunctionNode) + assert isinstance(link_triple.node, orm.CalcFunctionNode) # Verify that the `CALL` link of the work chain is there with the correct label link_triple = process.node.get_outgoing(link_type=LinkType.CALL_WORK, link_label_filter=label_workchain).one() - self.assertIsInstance(link_triple.node, orm.WorkChainNode) + assert isinstance(link_triple.node, orm.WorkChainNode) def test_tocontext_submit_workchain_no_daemon(self): @@ -692,8 +684,8 @@ async def run_async(workchain): # run the original workchain until paused await run_until_paused(workchain) - self.assertTrue(workchain.ctx.s1) - self.assertFalse(workchain.ctx.s2) + assert workchain.ctx.s1 + assert not workchain.ctx.s2 # Now bundle the workchain bundle = plumpy.Bundle(workchain) @@ -702,19 +694,19 @@ async def run_async(workchain): # Load from saved state workchain2 = bundle.unbundle() - self.assertTrue(workchain2.ctx.s1) - self.assertFalse(workchain2.ctx.s2) + assert workchain2.ctx.s1 + assert not workchain2.ctx.s2 # check bundling again creates the same saved state bundle2 = plumpy.Bundle(workchain2) - self.assertDictEqual(bundle, bundle2) + assert bundle == bundle2 # run the loaded workchain to completion runner.schedule(workchain2) workchain2.play() await workchain2.future() - self.assertTrue(workchain2.ctx.s1) - self.assertTrue(workchain2.ctx.s2) + assert workchain2.ctx.s1 + assert workchain2.ctx.s2 # ensure the original paused workchain future is finalised # to avoid warnings @@ -751,8 +743,6 @@ def check(self): def test_to_context(self): val = Int(5).store() - test_case = self - class SimpleWc(WorkChain): @classmethod @@ -776,8 +766,8 @@ def begin(self): return ToContext(result_b=self.submit(SimpleWc)) def result(self): - test_case.assertEqual(self.ctx.result_a.outputs.result, val) - test_case.assertEqual(self.ctx.result_b.outputs.result, val) + assert self.ctx.result_a.outputs.result == val + assert self.ctx.result_b.outputs.result == val run_and_check_success(Workchain) @@ -839,21 +829,21 @@ def run(self): wc = ExitCodeWorkChain() # The exit code can be gotten by calling it with the status or label, as well as using attribute dereferencing - self.assertEqual(wc.exit_codes(status).status, status) - self.assertEqual(wc.exit_codes(label).status, status) - self.assertEqual(wc.exit_codes.SOME_EXIT_CODE.status, status) + assert wc.exit_codes(status).status == status + assert wc.exit_codes(label).status == status + assert wc.exit_codes.SOME_EXIT_CODE.status == status - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): wc.exit_codes.NON_EXISTENT_ERROR # pylint: disable=pointless-statement - self.assertEqual(ExitCodeWorkChain.exit_codes.SOME_EXIT_CODE.status, status) # pylint: disable=no-member - self.assertEqual(ExitCodeWorkChain.exit_codes.SOME_EXIT_CODE.message, message) # pylint: disable=no-member + assert ExitCodeWorkChain.exit_codes.SOME_EXIT_CODE.status == status # pylint: disable=no-member + assert ExitCodeWorkChain.exit_codes.SOME_EXIT_CODE.message == message # pylint: disable=no-member - self.assertEqual(ExitCodeWorkChain.exit_codes['SOME_EXIT_CODE'].status, status) # pylint: disable=unsubscriptable-object - self.assertEqual(ExitCodeWorkChain.exit_codes['SOME_EXIT_CODE'].message, message) # pylint: disable=unsubscriptable-object + assert ExitCodeWorkChain.exit_codes['SOME_EXIT_CODE'].status == status # pylint: disable=unsubscriptable-object + assert ExitCodeWorkChain.exit_codes['SOME_EXIT_CODE'].message == message # pylint: disable=unsubscriptable-object - self.assertEqual(ExitCodeWorkChain.exit_codes[label].status, status) # pylint: disable=unsubscriptable-object - self.assertEqual(ExitCodeWorkChain.exit_codes[label].message, message) # pylint: disable=unsubscriptable-object + assert ExitCodeWorkChain.exit_codes[label].status == status # pylint: disable=unsubscriptable-object + assert ExitCodeWorkChain.exit_codes[label].message == message # pylint: disable=unsubscriptable-object @staticmethod def _run_with_checkpoints(wf_class, inputs=None): @@ -864,18 +854,16 @@ def _run_with_checkpoints(wf_class, inputs=None): @pytest.mark.requires_rmq -class TestWorkChainAbort(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test_class') +class TestWorkChainAbort: """ Test the functionality to abort a workchain """ - def setUp(self): - super().setUp() - self.assertIsNone(Process.current()) - - def tearDown(self): - super().tearDown() - self.assertIsNone(Process.current()) + def setup_method(self): + assert Process.current() is None + yield + assert Process.current() is None class AbortableWorkChain(WorkChain): @@ -904,15 +892,15 @@ async def run_async(): process.play() with Capturing(): - with self.assertRaises(RuntimeError): + with pytest.raises(RuntimeError): await process.future() runner.schedule(process) runner.loop.run_until_complete(run_async()) - self.assertEqual(process.node.is_finished_ok, False) - self.assertEqual(process.node.is_excepted, True) - self.assertEqual(process.node.is_killed, False) + assert process.node.is_finished_ok is False + assert process.node.is_excepted is True + assert process.node.is_killed is False def test_simple_kill_through_process(self): """ @@ -926,34 +914,32 @@ def test_simple_kill_through_process(self): async def run_async(): await run_until_paused(process) - self.assertTrue(process.paused) + assert process.paused process.kill() - with self.assertRaises(plumpy.ClosedError): + with pytest.raises(plumpy.ClosedError): launch.run(process) runner.schedule(process) runner.loop.run_until_complete(run_async()) - self.assertEqual(process.node.is_finished_ok, False) - self.assertEqual(process.node.is_excepted, False) - self.assertEqual(process.node.is_killed, True) + assert process.node.is_finished_ok is False + assert process.node.is_excepted is False + assert process.node.is_killed is True @pytest.mark.requires_rmq -class TestWorkChainAbortChildren(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test_class') +class TestWorkChainAbortChildren: """ Test the functionality to abort a workchain and verify that children are also aborted appropriately """ - def setUp(self): - super().setUp() - self.assertIsNone(Process.current()) - - def tearDown(self): - super().tearDown() - self.assertIsNone(Process.current()) + def setup_method(self): + assert Process.current() is None + yield + assert Process.current() is None class SubWorkChain(WorkChain): @@ -995,12 +981,12 @@ def test_simple_run(self): process = TestWorkChainAbortChildren.MainWorkChain() with Capturing(): - with self.assertRaises(RuntimeError): + with pytest.raises(RuntimeError): launch.run(process) - self.assertEqual(process.node.is_finished_ok, False) - self.assertEqual(process.node.is_excepted, True) - self.assertEqual(process.node.is_killed, False) + assert process.node.is_finished_ok is False + assert process.node.is_excepted is True + assert process.node.is_killed is False def test_simple_kill_through_process(self): """ @@ -1017,41 +1003,38 @@ async def run_async(): if asyncio.isfuture(result): await result - with self.assertRaises(plumpy.KilledError): + with pytest.raises(plumpy.KilledError): await process.future() runner.schedule(process) runner.loop.run_until_complete(run_async()) child = process.node.get_outgoing(link_type=LinkType.CALL_WORK).first().node - self.assertEqual(child.is_finished_ok, False) - self.assertEqual(child.is_excepted, False) - self.assertEqual(child.is_killed, True) + assert child.is_finished_ok is False + assert child.is_excepted is False + assert child.is_killed is True - self.assertEqual(process.node.is_finished_ok, False) - self.assertEqual(process.node.is_excepted, False) - self.assertEqual(process.node.is_killed, True) + assert process.node.is_finished_ok is False + assert process.node.is_excepted is False + assert process.node.is_killed is True @pytest.mark.requires_rmq -class TestImmutableInputWorkchain(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test_class') +class TestImmutableInputWorkchain: """ Test that inputs cannot be modified """ - def setUp(self): - super().setUp() - self.assertIsNone(Process.current()) - - def tearDown(self): - super().tearDown() - self.assertIsNone(Process.current()) + def setup_method(self): + assert Process.current() is None + yield + assert Process.current() is None def test_immutable_input(self): """ Check that from within the WorkChain self.inputs returns an AttributesFrozendict which should be immutable """ - test_class = self class FrozenDictWorkChain(WorkChain): @@ -1067,19 +1050,19 @@ def define(cls, spec): def step_one(self): # Attempt to manipulate the inputs dictionary which since it is a AttributesFrozendict should raise - with test_class.assertRaises(TypeError): + with pytest.raises(TypeError): self.inputs['a'] = Int(3) - with test_class.assertRaises(AttributeError): + with pytest.raises(AttributeError): self.inputs.pop('b') - with test_class.assertRaises(TypeError): + with pytest.raises(TypeError): self.inputs['c'] = Int(4) def step_two(self): # Verify that original inputs are still there with same value and no inputs were added - test_class.assertIn('a', self.inputs) - test_class.assertIn('b', self.inputs) - test_class.assertNotIn('c', self.inputs) - test_class.assertEqual(self.inputs['a'].value, 1) + assert 'a' in self.inputs + assert 'b' in self.inputs + assert 'c' not in self.inputs + assert self.inputs['a'].value == 1 run_and_check_success(FrozenDictWorkChain, a=Int(1), b=Int(2)) @@ -1087,7 +1070,6 @@ def test_immutable_input_groups(self): """ Check that namespaced inputs also return AttributeFrozendicts and are hence immutable """ - test_class = self class ImmutableGroups(WorkChain): @@ -1102,19 +1084,19 @@ def define(cls, spec): def step_one(self): # Attempt to manipulate the namespaced inputs dictionary which should raise - with test_class.assertRaises(TypeError): + with pytest.raises(TypeError): self.inputs.subspace['one'] = Int(3) - with test_class.assertRaises(AttributeError): + with pytest.raises(AttributeError): self.inputs.subspace.pop('two') - with test_class.assertRaises(TypeError): + with pytest.raises(TypeError): self.inputs.subspace['four'] = Int(4) def step_two(self): # Verify that original inputs are still there with same value and no inputs were added - test_class.assertIn('one', self.inputs.subspace) - test_class.assertIn('two', self.inputs.subspace) - test_class.assertNotIn('four', self.inputs.subspace) - test_class.assertEqual(self.inputs.subspace['one'].value, 1) + assert 'one' in self.inputs.subspace + assert 'two' in self.inputs.subspace + assert 'four' not in self.inputs.subspace + assert self.inputs.subspace['one'].value == 1 run_and_check_success(ImmutableGroups, subspace={'one': Int(1), 'two': Int(2)}) @@ -1140,18 +1122,16 @@ def do_test(self): @pytest.mark.requires_rmq -class TestSerializeWorkChain(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test_class') +class TestSerializeWorkChain: """ Test workchains with serialized input / output. """ - def setUp(self): - super().setUp() - self.assertIsNone(Process.current()) - - def tearDown(self): - super().tearDown() - self.assertIsNone(Process.current()) + def setup_method(self): + assert Process.current() is None + yield + assert Process.current() is None @staticmethod def test_serialize(): @@ -1267,7 +1247,8 @@ def do_run(self): @pytest.mark.requires_rmq -class TestWorkChainExpose(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test_class') +class TestWorkChainExpose: """ Test the expose inputs / outputs functionality """ @@ -1287,23 +1268,21 @@ def test_expose(self): } }, ) - self.assertEqual( - res, { - 'a': Float(2.2), - 'sub_1': { - 'b': Float(2.3), - 'c': Bool(True) - }, - 'sub_2': { - 'b': Float(1.2), - 'sub_3': { - 'c': Bool(False) - } + assert res == { + 'a': Float(2.2), + 'sub_1': { + 'b': Float(2.3), + 'c': Bool(True) + }, + 'sub_2': { + 'b': Float(1.2), + 'sub_3': { + 'c': Bool(False) } } - ) + } - @unittest.skip('Functionality of `Process.exposed_outputs` is broken for nested namespaces, see issue #3533.') + @pytest.mark.skip('Functionality of `Process.exposed_outputs` is broken for nested namespaces, see issue #3533.') def test_nested_expose(self): res = launch.run( GrandParentExposeWorkChain, @@ -1323,25 +1302,23 @@ def test_nested_expose(self): ) ) ) - self.assertEqual( - res, { + assert res == { + 'sub': { 'sub': { - 'sub': { - 'a': Float(2.2), - 'sub_1': { - 'b': Float(2.3), - 'c': Bool(True) - }, - 'sub_2': { - 'b': Float(1.2), - 'sub_3': { - 'c': Bool(False) - } + 'a': Float(2.2), + 'sub_1': { + 'b': Float(2.3), + 'c': Bool(True) + }, + 'sub_2': { + 'b': Float(1.2), + 'sub_3': { + 'c': Bool(False) } } } } - ) + } @pytest.mark.filterwarnings('ignore::UserWarning') def test_issue_1741_expose_inputs(self): @@ -1379,7 +1356,8 @@ def step1(self): @pytest.mark.requires_rmq -class TestWorkChainMisc(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test_class') +class TestWorkChainMisc: class PointlessWorkChain(WorkChain): @@ -1411,12 +1389,13 @@ def test_run_pointless_workchain(): def test_global_submit_raises(self): """Using top-level submit should raise.""" - with self.assertRaises(exceptions.InvalidOperation): + with pytest.raises(exceptions.InvalidOperation): launch.run(TestWorkChainMisc.IllegalSubmitWorkChain) @pytest.mark.requires_rmq -class TestDefaultUniqueness(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test_class') +class TestDefaultUniqueness: """Test that default inputs of exposed nodes will get unique UUIDS.""" class Parent(WorkChain): @@ -1464,4 +1443,4 @@ def test_unique_default_inputs(self): # Trying to load one of the inputs through the UUID should fail, # as both `child_one.a` and `child_two.a` should have the same UUID. node = load_node(uuid=node.get_incoming().get_node_by_label('child_one__a').uuid) - self.assertEqual(len(uuids), len(nodes), f'Only {len(uuids)} unique UUIDS for {len(nodes)} input nodes') + assert len(uuids) == len(nodes), f'Only {len(uuids)} unique UUIDS for {len(nodes)} input nodes' diff --git a/tests/orm/test_querybuilder.py b/tests/orm/test_querybuilder.py index 62bc8925c4..030ec164b9 100644 --- a/tests/orm/test_querybuilder.py +++ b/tests/orm/test_querybuilder.py @@ -8,21 +8,18 @@ # For further information please visit http://www.aiida.net # ########################################################################### # pylint: disable=invalid-name,missing-docstring,too-many-lines +# pylint: disable=no-self-use,unused-argument """Tests for the QueryBuilder.""" import warnings import pytest from aiida import orm -from aiida.backends.testbase import AiidaTestCase from aiida.common.links import LinkType from aiida.manage import configuration -class TestQueryBuilder(AiidaTestCase): - - def setUp(self): - super().setUp() - self.refurbish_db() +@pytest.mark.usefixtures('clear_database_before_test') +class TestQueryBuilder: def test_date_filters_support(self): """Verify that `datetime.date` is supported in filters.""" @@ -34,7 +31,7 @@ def test_date_filters_support(self): orm.Data(ctime=timezone.now() - timedelta(days=1)).store() builder = orm.QueryBuilder().append(orm.Node, filters={'ctime': {'>': date.today() - timedelta(days=1)}}) - self.assertEqual(builder.count(), 1) + assert builder.count() == 1 def test_ormclass_type_classification(self): """ @@ -46,11 +43,11 @@ def test_ormclass_type_classification(self): qb = orm.QueryBuilder() # Asserting that improper declarations of the class type raise an error - with self.assertRaises(DbContentError): + with pytest.raises(DbContentError): qb._get_ormclass(None, 'data') - with self.assertRaises(DbContentError): + with pytest.raises(DbContentError): qb._get_ormclass(None, 'data.Data') - with self.assertRaises(DbContentError): + with pytest.raises(DbContentError): qb._get_ormclass(None, '.') # Asserting that the query type string and plugin type string are returned: @@ -58,34 +55,34 @@ def test_ormclass_type_classification(self): qb._get_ormclass(orm.StructureData, None), qb._get_ormclass(None, 'data.structure.StructureData.'), ): - self.assertEqual(classifiers['ormclass_type_string'], orm.StructureData._plugin_type_string) # pylint: disable=no-member + assert classifiers['ormclass_type_string'] == orm.StructureData._plugin_type_string # pylint: disable=no-member for _cls, classifiers in ( qb._get_ormclass(orm.Group, None), qb._get_ormclass(None, 'group.core'), qb._get_ormclass(None, 'Group.core'), ): - self.assertTrue(classifiers['ormclass_type_string'].startswith('group')) + assert classifiers['ormclass_type_string'].startswith('group') for _cls, classifiers in ( qb._get_ormclass(orm.User, None), qb._get_ormclass(None, 'user'), qb._get_ormclass(None, 'User'), ): - self.assertEqual(classifiers['ormclass_type_string'], 'user') + assert classifiers['ormclass_type_string'] == 'user' for _cls, classifiers in ( qb._get_ormclass(orm.Computer, None), qb._get_ormclass(None, 'computer'), qb._get_ormclass(None, 'Computer'), ): - self.assertEqual(classifiers['ormclass_type_string'], 'computer') + assert classifiers['ormclass_type_string'] == 'computer' for _cls, classifiers in ( qb._get_ormclass(orm.Data, None), qb._get_ormclass(None, 'data.Data.'), ): - self.assertEqual(classifiers['ormclass_type_string'], orm.Data._plugin_type_string) # pylint: disable=no-member + assert classifiers['ormclass_type_string'] == orm.Data._plugin_type_string # pylint: disable=no-member def test_process_type_classification(self): """ @@ -103,34 +100,34 @@ def test_process_type_classification(self): # When passing a WorkChain class, it should return the type of the corresponding Node # including the appropriate filter on the process_type _cls, classifiers = qb._get_ormclass(WorkChain, None) - self.assertEqual(classifiers['ormclass_type_string'], 'process.workflow.workchain.WorkChainNode.') - self.assertEqual(classifiers['process_type_string'], 'aiida.engine.processes.workchains.workchain.WorkChain') + assert classifiers['ormclass_type_string'] == 'process.workflow.workchain.WorkChainNode.' + assert classifiers['process_type_string'] == 'aiida.engine.processes.workchains.workchain.WorkChain' # When passing a WorkChainNode, no process_type filter is applied _cls, classifiers = qb._get_ormclass(orm.WorkChainNode, None) - self.assertEqual(classifiers['ormclass_type_string'], 'process.workflow.workchain.WorkChainNode.') - self.assertEqual(classifiers['process_type_string'], None) + assert classifiers['ormclass_type_string'] == 'process.workflow.workchain.WorkChainNode.' + assert classifiers['process_type_string'] is None # Same tests for a calculation _cls, classifiers = qb._get_ormclass(ArithmeticAdd, None) - self.assertEqual(classifiers['ormclass_type_string'], 'process.calculation.calcjob.CalcJobNode.') - self.assertEqual(classifiers['process_type_string'], 'aiida.calculations:arithmetic.add') + assert classifiers['ormclass_type_string'] == 'process.calculation.calcjob.CalcJobNode.' + assert classifiers['process_type_string'] == 'aiida.calculations:arithmetic.add' def test_get_group_type_filter(self): """Test the `aiida.orm.querybuilder.get_group_type_filter` function.""" from aiida.orm.querybuilder import get_group_type_filter classifiers = {'ormclass_type_string': 'group.core'} - self.assertEqual(get_group_type_filter(classifiers, False), {'==': 'core'}) - self.assertEqual(get_group_type_filter(classifiers, True), {'like': '%'}) + assert get_group_type_filter(classifiers, False) == {'==': 'core'} + assert get_group_type_filter(classifiers, True) == {'like': '%'} classifiers = {'ormclass_type_string': 'group.core.auto'} - self.assertEqual(get_group_type_filter(classifiers, False), {'==': 'core.auto'}) - self.assertEqual(get_group_type_filter(classifiers, True), {'like': 'core.auto%'}) + assert get_group_type_filter(classifiers, False) == {'==': 'core.auto'} + assert get_group_type_filter(classifiers, True) == {'like': 'core.auto%'} classifiers = {'ormclass_type_string': 'group.pseudo.family'} - self.assertEqual(get_group_type_filter(classifiers, False), {'==': 'pseudo.family'}) - self.assertEqual(get_group_type_filter(classifiers, True), {'like': 'pseudo.family%'}) + assert get_group_type_filter(classifiers, False) == {'==': 'pseudo.family'} + assert get_group_type_filter(classifiers, True) == {'like': 'pseudo.family%'} # Tracked in issue #4281 @pytest.mark.flaky(reruns=2) @@ -200,8 +197,8 @@ class DummyWorkChain(WorkChain): assert issubclass(w[-1].category, AiidaEntryPointWarning) # There should be one result of type WorkChainNode - self.assertEqual(qb.count(), 1) - self.assertTrue(isinstance(qb.all()[0][0], orm.WorkChainNode)) + assert qb.count() == 1 + assert isinstance(qb.all()[0][0], orm.WorkChainNode) # Query for nodes of a different type of WorkChain qb = orm.QueryBuilder() @@ -217,14 +214,14 @@ class DummyWorkChain(WorkChain): assert issubclass(w[-1].category, AiidaEntryPointWarning) # There should be no result - self.assertEqual(qb.count(), 0) + assert qb.count() == 0 # Query for all WorkChain nodes qb = orm.QueryBuilder() qb.append(WorkChain) # There should be one result - self.assertEqual(qb.count(), 1) + assert qb.count() == 1 def test_simple_query_1(self): """ @@ -266,35 +263,35 @@ def test_simple_query_1(self): qb1 = orm.QueryBuilder() qb1.append(orm.Node, filters={'attributes.foo': 1.000}) - self.assertEqual(len(qb1.all()), 2) + assert len(qb1.all()) == 2 qb2 = orm.QueryBuilder() qb2.append(orm.Data) - self.assertEqual(qb2.count(), 3) + assert qb2.count() == 3 qb2 = orm.QueryBuilder() qb2.append(entity_type='data.Data.') - self.assertEqual(qb2.count(), 3) + assert qb2.count() == 3 qb3 = orm.QueryBuilder() qb3.append(orm.Node, project='label', tag='node1') qb3.append(orm.Node, project='label', tag='node2') - self.assertEqual(qb3.count(), 4) + assert qb3.count() == 4 qb4 = orm.QueryBuilder() qb4.append(orm.CalculationNode, tag='node1') qb4.append(orm.Data, tag='node2') - self.assertEqual(qb4.count(), 2) + assert qb4.count() == 2 qb5 = orm.QueryBuilder() qb5.append(orm.Data, tag='node1') qb5.append(orm.CalculationNode, tag='node2') - self.assertEqual(qb5.count(), 2) + assert qb5.count() == 2 qb6 = orm.QueryBuilder() qb6.append(orm.Data, tag='node1') qb6.append(orm.Data, tag='node2') - self.assertEqual(qb6.count(), 0) + assert qb6.count() == 0 def test_simple_query_2(self): from datetime import datetime @@ -320,7 +317,7 @@ def test_simple_query_2(self): qb1 = orm.QueryBuilder() qb1.append(orm.Node, filters={'label': 'hello'}) - self.assertEqual(len(list(qb1.all())), 1) + assert len(list(qb1.all())) == 1 qh = { 'path': [{ @@ -352,11 +349,11 @@ def test_simple_query_2(self): qb2 = orm.QueryBuilder(**qh) resdict = qb2.dict() - self.assertEqual(len(resdict), 1) - self.assertTrue(isinstance(resdict[0]['n1']['ctime'], datetime)) + assert len(resdict) == 1 + assert isinstance(resdict[0]['n1']['ctime'], datetime) res_one = qb2.one() - self.assertTrue('bar' in res_one) + assert 'bar' in res_one qh = { 'path': [{ @@ -376,23 +373,23 @@ def test_simple_query_2(self): } } qb = orm.QueryBuilder(**qh) - self.assertEqual(qb.count(), 1) + assert qb.count() == 1 # Test the hashing: query1 = qb.get_query() qb.add_filter('n2', {'label': 'nonexistentlabel'}) - self.assertEqual(qb.count(), 0) + assert qb.count() == 0 - with self.assertRaises(NotExistent): + with pytest.raises(NotExistent): qb.one() - with self.assertRaises(MultipleObjectsError): + with pytest.raises(MultipleObjectsError): orm.QueryBuilder().append(orm.Node).one() query2 = qb.get_query() query3 = qb.get_query() - self.assertTrue(id(query1) != id(query2)) - self.assertTrue(id(query2) == id(query3)) + assert id(query1) != id(query2) + assert id(query2) == id(query3) def test_dict_multiple_projections(self): """Test that the `.dict()` accumulator with multiple projections returns the correct types.""" @@ -400,14 +397,14 @@ def test_dict_multiple_projections(self): builder = orm.QueryBuilder().append(orm.Data, project=['*', 'id']) results = builder.dict() - self.assertIsInstance(results, list) - self.assertTrue(all(isinstance(value, dict) for value in results)) + assert isinstance(results, list) + assert all(isinstance(value, dict) for value in results) dictionary = list(results[0].values())[0] # `results` should have the form [{'Data_1': {'*': Node, 'id': 1}}] - self.assertIsInstance(dictionary['*'], orm.Data) - self.assertEqual(dictionary['*'].pk, node.pk) - self.assertEqual(dictionary['id'], node.pk) + assert isinstance(dictionary['*'], orm.Data) + assert dictionary['*'].pk == node.pk + assert dictionary['id'] == node.pk def test_operators_eq_lt_gt(self): nodes = [orm.Data() for _ in range(8)] @@ -424,12 +421,12 @@ def test_operators_eq_lt_gt(self): for n in nodes: n.store() - self.assertEqual(orm.QueryBuilder().append(orm.Node, filters={'attributes.fa': {'<': 1}}).count(), 0) - self.assertEqual(orm.QueryBuilder().append(orm.Node, filters={'attributes.fa': {'==': 1}}).count(), 2) - self.assertEqual(orm.QueryBuilder().append(orm.Node, filters={'attributes.fa': {'<': 1.02}}).count(), 3) - self.assertEqual(orm.QueryBuilder().append(orm.Node, filters={'attributes.fa': {'<=': 1.02}}).count(), 4) - self.assertEqual(orm.QueryBuilder().append(orm.Node, filters={'attributes.fa': {'>': 1.02}}).count(), 4) - self.assertEqual(orm.QueryBuilder().append(orm.Node, filters={'attributes.fa': {'>=': 1.02}}).count(), 5) + assert orm.QueryBuilder().append(orm.Node, filters={'attributes.fa': {'<': 1}}).count() == 0 + assert orm.QueryBuilder().append(orm.Node, filters={'attributes.fa': {'==': 1}}).count() == 2 + assert orm.QueryBuilder().append(orm.Node, filters={'attributes.fa': {'<': 1.02}}).count() == 3 + assert orm.QueryBuilder().append(orm.Node, filters={'attributes.fa': {'<=': 1.02}}).count() == 4 + assert orm.QueryBuilder().append(orm.Node, filters={'attributes.fa': {'>': 1.02}}).count() == 4 + assert orm.QueryBuilder().append(orm.Node, filters={'attributes.fa': {'>=': 1.02}}).count() == 5 def test_subclassing(self): s = orm.StructureData() @@ -445,107 +442,107 @@ def test_subclassing(self): # Now when asking for a node with attr.cat==miau, I want 3 esults: qb = orm.QueryBuilder().append(orm.Node, filters={'attributes.cat': 'miau'}) - self.assertEqual(qb.count(), 3) + assert qb.count() == 3 qb = orm.QueryBuilder().append(orm.Data, filters={'attributes.cat': 'miau'}) - self.assertEqual(qb.count(), 3) + assert qb.count() == 3 # If I'm asking for the specific lowest subclass, I want one result for cls in (orm.StructureData, orm.Dict): qb = orm.QueryBuilder().append(cls, filters={'attributes.cat': 'miau'}) - self.assertEqual(qb.count(), 1) + assert qb.count() == 1 # Now I am not allow the subclassing, which should give 1 result for each for cls, count in ((orm.StructureData, 1), (orm.Dict, 1), (orm.Data, 1), (orm.Node, 0)): qb = orm.QueryBuilder().append(cls, filters={'attributes.cat': 'miau'}, subclassing=False) - self.assertEqual(qb.count(), count) + assert qb.count() == count # Now I am testing the subclassing with tuples: qb = orm.QueryBuilder().append(cls=(orm.StructureData, orm.Dict), filters={'attributes.cat': 'miau'}) - self.assertEqual(qb.count(), 2) + assert qb.count() == 2 qb = orm.QueryBuilder().append( entity_type=('data.structure.StructureData.', 'data.dict.Dict.'), filters={'attributes.cat': 'miau'} ) - self.assertEqual(qb.count(), 2) + assert qb.count() == 2 qb = orm.QueryBuilder().append( cls=(orm.StructureData, orm.Dict), filters={'attributes.cat': 'miau'}, subclassing=False ) - self.assertEqual(qb.count(), 2) + assert qb.count() == 2 qb = orm.QueryBuilder().append( cls=(orm.StructureData, orm.Data), filters={'attributes.cat': 'miau'}, ) - self.assertEqual(qb.count(), 3) + assert qb.count() == 3 qb = orm.QueryBuilder().append( entity_type=('data.structure.StructureData.', 'data.dict.Dict.'), filters={'attributes.cat': 'miau'}, subclassing=False ) - self.assertEqual(qb.count(), 2) + assert qb.count() == 2 qb = orm.QueryBuilder().append( entity_type=('data.structure.StructureData.', 'data.Data.'), filters={'attributes.cat': 'miau'}, subclassing=False ) - self.assertEqual(qb.count(), 2) + assert qb.count() == 2 def test_list_behavior(self): for _i in range(4): orm.Data().store() - self.assertEqual(len(orm.QueryBuilder().append(orm.Node).all()), 4) - self.assertEqual(len(orm.QueryBuilder().append(orm.Node, project='*').all()), 4) - self.assertEqual(len(orm.QueryBuilder().append(orm.Node, project=['*', 'id']).all()), 4) - self.assertEqual(len(orm.QueryBuilder().append(orm.Node, project=['id']).all()), 4) - self.assertEqual(len(orm.QueryBuilder().append(orm.Node).dict()), 4) - self.assertEqual(len(orm.QueryBuilder().append(orm.Node, project='*').dict()), 4) - self.assertEqual(len(orm.QueryBuilder().append(orm.Node, project=['*', 'id']).dict()), 4) - self.assertEqual(len(orm.QueryBuilder().append(orm.Node, project=['id']).dict()), 4) - self.assertEqual(len(list(orm.QueryBuilder().append(orm.Node).iterall())), 4) - self.assertEqual(len(list(orm.QueryBuilder().append(orm.Node, project='*').iterall())), 4) - self.assertEqual(len(list(orm.QueryBuilder().append(orm.Node, project=['*', 'id']).iterall())), 4) - self.assertEqual(len(list(orm.QueryBuilder().append(orm.Node, project=['id']).iterall())), 4) - self.assertEqual(len(list(orm.QueryBuilder().append(orm.Node).iterdict())), 4) - self.assertEqual(len(list(orm.QueryBuilder().append(orm.Node, project='*').iterdict())), 4) - self.assertEqual(len(list(orm.QueryBuilder().append(orm.Node, project=['*', 'id']).iterdict())), 4) - self.assertEqual(len(list(orm.QueryBuilder().append(orm.Node, project=['id']).iterdict())), 4) + assert len(orm.QueryBuilder().append(orm.Node).all()) == 4 + assert len(orm.QueryBuilder().append(orm.Node, project='*').all()) == 4 + assert len(orm.QueryBuilder().append(orm.Node, project=['*', 'id']).all()) == 4 + assert len(orm.QueryBuilder().append(orm.Node, project=['id']).all()) == 4 + assert len(orm.QueryBuilder().append(orm.Node).dict()) == 4 + assert len(orm.QueryBuilder().append(orm.Node, project='*').dict()) == 4 + assert len(orm.QueryBuilder().append(orm.Node, project=['*', 'id']).dict()) == 4 + assert len(orm.QueryBuilder().append(orm.Node, project=['id']).dict()) == 4 + assert len(list(orm.QueryBuilder().append(orm.Node).iterall())) == 4 + assert len(list(orm.QueryBuilder().append(orm.Node, project='*').iterall())) == 4 + assert len(list(orm.QueryBuilder().append(orm.Node, project=['*', 'id']).iterall())) == 4 + assert len(list(orm.QueryBuilder().append(orm.Node, project=['id']).iterall())) == 4 + assert len(list(orm.QueryBuilder().append(orm.Node).iterdict())) == 4 + assert len(list(orm.QueryBuilder().append(orm.Node, project='*').iterdict())) == 4 + assert len(list(orm.QueryBuilder().append(orm.Node, project=['*', 'id']).iterdict())) == 4 + assert len(list(orm.QueryBuilder().append(orm.Node, project=['id']).iterdict())) == 4 def test_append_validation(self): from aiida.common.exceptions import InputValidationError # So here I am giving two times the same tag - with self.assertRaises(InputValidationError): + with pytest.raises(InputValidationError): orm.QueryBuilder().append(orm.StructureData, tag='n').append(orm.StructureData, tag='n') # here I am giving a wrong filter specifications - with self.assertRaises(InputValidationError): + with pytest.raises(InputValidationError): orm.QueryBuilder().append(orm.StructureData, filters=['jajjsd']) # here I am giving a nonsensical projection: - with self.assertRaises(InputValidationError): + with pytest.raises(InputValidationError): orm.QueryBuilder().append(orm.StructureData, project=True) # here I am giving a nonsensical projection for the edge: - with self.assertRaises(InputValidationError): + with pytest.raises(InputValidationError): orm.QueryBuilder().append(orm.ProcessNode).append(orm.StructureData, edge_tag='t').add_projection('t', True) # Giving a nonsensical limit - with self.assertRaises(InputValidationError): + with pytest.raises(InputValidationError): orm.QueryBuilder().append(orm.ProcessNode).limit(2.3) # Giving a nonsensical offset - with self.assertRaises(InputValidationError): + with pytest.raises(InputValidationError): orm.QueryBuilder(offset=2.3) # So, I mess up one append, I want the QueryBuilder to clean it! - with self.assertRaises(InputValidationError): + with pytest.raises(InputValidationError): qb = orm.QueryBuilder() # This also checks if we correctly raise for wrong keywords qb.append(orm.StructureData, tag='s', randomkeyword={}) # Now I'm checking whether this keyword appears anywhere in the internal dictionaries: # pylint: disable=protected-access - self.assertTrue('s' not in qb._projections) - self.assertTrue('s' not in qb._filters) - self.assertTrue('s' not in qb.tag_to_alias_map) - self.assertTrue(len(qb._path) == 0) - self.assertTrue(orm.StructureData not in qb._cls_to_tag_map) + assert 's' not in qb._projections + assert 's' not in qb._filters + assert 's' not in qb.tag_to_alias_map + assert len(qb._path) == 0 + assert orm.StructureData not in qb._cls_to_tag_map # So this should work now: qb.append(orm.StructureData, tag='s').limit(2).dict() @@ -555,43 +552,33 @@ def test_tags(self): qb.append(orm.Node, tag='n2', edge_tag='e1', with_incoming='n1') qb.append(orm.Node, tag='n3', edge_tag='e2', with_incoming='n2') qb.append(orm.Computer, with_node='n3', tag='c1', edge_tag='nonsense') - self.assertEqual(qb.get_used_tags(), ['n1', 'n2', 'e1', 'n3', 'e2', 'c1', 'nonsense']) + assert qb.get_used_tags() == ['n1', 'n2', 'e1', 'n3', 'e2', 'c1', 'nonsense'] # Now I am testing the default tags, qb = orm.QueryBuilder().append(orm.StructureData ).append(orm.ProcessNode ).append(orm.StructureData ).append(orm.Dict, with_outgoing=orm.ProcessNode) - self.assertEqual( - qb.get_used_tags(), [ - 'StructureData_1', 'ProcessNode_1', 'StructureData_1--ProcessNode_1', 'StructureData_2', - 'ProcessNode_1--StructureData_2', 'Dict_1', 'ProcessNode_1--Dict_1' - ] - ) - self.assertEqual( - qb.get_used_tags(edges=False), [ - 'StructureData_1', - 'ProcessNode_1', - 'StructureData_2', - 'Dict_1', - ] - ) - self.assertEqual( - qb.get_used_tags(vertices=False), + assert qb.get_used_tags() == [ + 'StructureData_1', 'ProcessNode_1', 'StructureData_1--ProcessNode_1', 'StructureData_2', + 'ProcessNode_1--StructureData_2', 'Dict_1', 'ProcessNode_1--Dict_1' + ] + assert qb.get_used_tags(edges=False) == [ + 'StructureData_1', + 'ProcessNode_1', + 'StructureData_2', + 'Dict_1', + ] + assert qb.get_used_tags(vertices=False) == \ ['StructureData_1--ProcessNode_1', 'ProcessNode_1--StructureData_2', 'ProcessNode_1--Dict_1'] - ) - self.assertEqual( - qb.get_used_tags(edges=False), [ - 'StructureData_1', - 'ProcessNode_1', - 'StructureData_2', - 'Dict_1', - ] - ) - self.assertEqual( - qb.get_used_tags(vertices=False), + assert qb.get_used_tags(edges=False) == [ + 'StructureData_1', + 'ProcessNode_1', + 'StructureData_2', + 'Dict_1', + ] + assert qb.get_used_tags(vertices=False) == \ ['StructureData_1--ProcessNode_1', 'ProcessNode_1--StructureData_2', 'ProcessNode_1--Dict_1'] - ) def test_direction_keyword(self): """ @@ -619,8 +606,8 @@ def test_direction_keyword(self): qb.append(orm.CalculationNode, with_incoming='data', project='id') res2 = {_ for _, in qb.all()} - self.assertEqual(res1, res2) - self.assertEqual(res1, {c1.id}) + assert res1 == res2 + assert res1 == {c1.id} # testing direction=-1, which should return the incoming qb = orm.QueryBuilder() @@ -632,8 +619,8 @@ def test_direction_keyword(self): qb.append(orm.Data, filters={'id': d2.id}, tag='data') qb.append(orm.CalculationNode, with_outgoing='data', project='id') res2 = {_ for _, in qb.all()} - self.assertEqual(res1, res2) - self.assertEqual(res1, {c1.id}) + assert res1 == res2 + assert res1 == {c1.id} # testing direction higher than 1 qb = orm.QueryBuilder() @@ -644,12 +631,12 @@ def test_direction_keyword(self): qh = qb.queryhelp # saving query for later qb.append(orm.Data, direction=-4, project='id') res1 = {item[1] for item in qb.all()} - self.assertEqual(res1, {d1.id}) + assert res1 == {d1.id} qb = orm.QueryBuilder(**qh) qb.append(orm.Data, direction=4, project='id') res2 = {item[1] for item in qb.all()} - self.assertEqual(res2, {d2.id, d4.id}) + assert res2 == {d2.id, d4.id} @staticmethod def test_flat(): @@ -678,7 +665,8 @@ def test_flat(): assert result == list(chain.from_iterable(zip(pks, uuids))) -class TestMultipleProjections(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test_class') +class TestMultipleProjections: """Unit tests for the QueryBuilder ORM class.""" def test_first_multiple_projections(self): @@ -689,108 +677,105 @@ def test_first_multiple_projections(self): result = orm.QueryBuilder().append(orm.User, tag='user', project=['email']).append(orm.Data, with_user='user', project=['*']).first() - self.assertEqual(type(result), list) - self.assertEqual(len(result), 2) - self.assertIsInstance(result[0], str) - self.assertIsInstance(result[1], orm.Data) + assert isinstance(result, list) + assert len(result) == 2 + assert isinstance(result[0], str) + assert isinstance(result[1], orm.Data) -class TestQueryHelp(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test') +def test_queryhelp(aiida_localhost): + """ + Here I test the queryhelp by seeing whether results are the same as using the append method. + I also check passing of tuples. + """ + g = orm.Group(label='helloworld').store() + for cls in (orm.StructureData, orm.Dict, orm.Data): + obj = cls() + obj.set_attribute('foo-qh2', 'bar') + obj.store() + g.add_nodes(obj) + + for cls, expected_count, subclassing in ( + (orm.StructureData, 1, True), + (orm.Dict, 1, True), + (orm.Data, 3, True), + (orm.Data, 1, False), + ((orm.Dict, orm.StructureData), 2, True), + ((orm.Dict, orm.StructureData), 2, False), + ((orm.Dict, orm.Data), 2, False), + ((orm.Dict, orm.Data), 3, True), + ((orm.Dict, orm.Data, orm.StructureData), 3, False), + ): + qb = orm.QueryBuilder() + qb.append(cls, filters={'attributes.foo-qh2': 'bar'}, subclassing=subclassing, project='uuid') + assert qb.count() == expected_count - def test_queryhelp(self): - """ - Here I test the queryhelp by seeing whether results are the same as using the append method. - I also check passing of tuples. - """ - g = orm.Group(label='helloworld').store() - for cls in (orm.StructureData, orm.Dict, orm.Data): - obj = cls() - obj.set_attribute('foo-qh2', 'bar') - obj.store() - g.add_nodes(obj) - - for cls, expected_count, subclassing in ( - (orm.StructureData, 1, True), - (orm.Dict, 1, True), - (orm.Data, 3, True), - (orm.Data, 1, False), - ((orm.Dict, orm.StructureData), 2, True), - ((orm.Dict, orm.StructureData), 2, False), - ((orm.Dict, orm.Data), 2, False), - ((orm.Dict, orm.Data), 3, True), - ((orm.Dict, orm.Data, orm.StructureData), 3, False), - ): - qb = orm.QueryBuilder() - qb.append(cls, filters={'attributes.foo-qh2': 'bar'}, subclassing=subclassing, project='uuid') - self.assertEqual(qb.count(), expected_count) + qh = qb.queryhelp + qb_new = orm.QueryBuilder(**qh) + assert qb_new.count() == expected_count + assert sorted([uuid for uuid, in qb.all()]) == sorted([uuid for uuid, in qb_new.all()]) - qh = qb.queryhelp - qb_new = orm.QueryBuilder(**qh) - self.assertEqual(qb_new.count(), expected_count) - self.assertEqual(sorted([uuid for uuid, in qb.all()]), sorted([uuid for uuid, in qb_new.all()])) + qb = orm.QueryBuilder().append(orm.Group, filters={'label': 'helloworld'}) + assert qb.count() == 1 - qb = orm.QueryBuilder().append(orm.Group, filters={'label': 'helloworld'}) - self.assertEqual(qb.count(), 1) + qb = orm.QueryBuilder().append((orm.Group,), filters={'label': 'helloworld'}) + assert qb.count() == 1 - qb = orm.QueryBuilder().append((orm.Group,), filters={'label': 'helloworld'}) - self.assertEqual(qb.count(), 1) + # populate computer + qb = orm.QueryBuilder().append(orm.Computer,) + assert qb.count() == 1 - # populate computer - self.computer # pylint:disable=pointless-statement - qb = orm.QueryBuilder().append(orm.Computer,) - self.assertEqual(qb.count(), 1) + qb = orm.QueryBuilder().append(cls=(orm.Computer,)) + assert qb.count() == 1 - qb = orm.QueryBuilder().append(cls=(orm.Computer,)) - self.assertEqual(qb.count(), 1) - def test_recreate_from_queryhelp(self): - """Test recreating a QueryBuilder from the Query Help +@pytest.mark.usefixtures('clear_database_before_test') +def test_recreate_from_queryhelp(): + """Test recreating a QueryBuilder from the Query Help - We test appending a Data node and a Process node for variety, as well - as a generic Node specifically because it translates to `entity_type` - as an empty string (which can potentially cause problems). - """ - import copy + We test appending a Data node and a Process node for variety, as well + as a generic Node specifically because it translates to `entity_type` + as an empty string (which can potentially cause problems). + """ + import copy - qb1 = orm.QueryBuilder() - qb1.append(orm.Node) - qb1.append(orm.Data) - qb1.append(orm.CalcJobNode) + qb1 = orm.QueryBuilder() + qb1.append(orm.Node) + qb1.append(orm.Data) + qb1.append(orm.CalcJobNode) - qb2 = orm.QueryBuilder(**qb1.queryhelp) - self.assertDictEqual(qb1.queryhelp, qb2.queryhelp) + qb2 = orm.QueryBuilder(**qb1.queryhelp) + assert qb1.queryhelp == qb2.queryhelp - qb3 = copy.deepcopy(qb1) - self.assertDictEqual(qb1.queryhelp, qb3.queryhelp) + qb3 = copy.deepcopy(qb1) + assert qb1.queryhelp == qb3.queryhelp -class TestQueryBuilderCornerCases(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test') +def test_corner_case_computer_json(aiida_localhost): """ - In this class corner cases of QueryBuilder are added. + In this test we check the correct behavior of QueryBuilder when + retrieving the _metadata with no content. + Note that they are in JSON format in both backends. Forcing the + decoding of a None value leads to an exception (this was the case + under Django). """ + n1 = orm.CalculationNode() + n1.label = 'node2' + n1.set_attribute('foo', 1) + n1.store() - def test_computer_json(self): # pylint: disable=no-self-use - """ - In this test we check the correct behavior of QueryBuilder when - retrieving the _metadata with no content. - Note that they are in JSON format in both backends. Forcing the - decoding of a None value leads to an exception (this was the case - under Django). - """ - n1 = orm.CalculationNode() - n1.label = 'node2' - n1.set_attribute('foo', 1) - n1.store() - - # Checking the correct retrieval of _metadata which is - # a JSON field (in both backends). - qb = orm.QueryBuilder() - qb.append(orm.CalculationNode, project=['id'], tag='calc') - qb.append(orm.Computer, project=['id', 'metadata'], outerjoin=True, with_node='calc') - qb.all() + # Checking the correct retrieval of _metadata which is + # a JSON field (in both backends). + qb = orm.QueryBuilder() + qb.append(orm.CalculationNode, project=['id'], tag='calc') + qb.append(orm.Computer, project=['id', 'metadata'], outerjoin=True, with_node='calc') + qb.all() -class TestAttributes(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test_class') +class TestAttributes: def test_attribute_existence(self): # I'm storing a value under key whatever: @@ -820,7 +805,7 @@ def test_attribute_existence(self): project='uuid' ) res_query = {str(_[0]) for _ in qb.all()} - self.assertEqual(res_query, res_uuids) + assert res_query == res_uuids def test_attribute_type(self): key = 'value_test_attr_type' @@ -840,30 +825,30 @@ def test_attribute_type(self): for val in (1.0, 1): qb = orm.QueryBuilder().append(orm.Node, filters={f'attributes.{key}': val}, project='uuid') res = [str(_) for _, in qb.all()] - self.assertEqual(set(res), set((n_float.uuid, n_int.uuid))) + assert set(res) == set((n_float.uuid, n_int.uuid)) qb = orm.QueryBuilder().append(orm.Node, filters={f'attributes.{key}': {'>': 0.5}}, project='uuid') res = [str(_) for _, in qb.all()] - self.assertEqual(set(res), set((n_float.uuid, n_int.uuid))) + assert set(res) == set((n_float.uuid, n_int.uuid)) qb = orm.QueryBuilder().append(orm.Node, filters={f'attributes.{key}': {'<': 1.5}}, project='uuid') res = [str(_) for _, in qb.all()] - self.assertEqual(set(res), set((n_float.uuid, n_int.uuid))) + assert set(res) == set((n_float.uuid, n_int.uuid)) # Now I am testing the boolean value: qb = orm.QueryBuilder().append(orm.Node, filters={f'attributes.{key}': True}, project='uuid') res = [str(_) for _, in qb.all()] - self.assertEqual(set(res), set((n_bool.uuid,))) + assert set(res) == set((n_bool.uuid,)) qb = orm.QueryBuilder().append(orm.Node, filters={f'attributes.{key}': {'like': '%n%'}}, project='uuid') res = [str(_) for _, in qb.all()] - self.assertEqual(set(res), set((n_str2.uuid,))) + assert set(res) == set((n_str2.uuid,)) qb = orm.QueryBuilder().append(orm.Node, filters={f'attributes.{key}': {'ilike': 'On%'}}, project='uuid') res = [str(_) for _, in qb.all()] - self.assertEqual(set(res), set((n_str2.uuid,))) + assert set(res) == set((n_str2.uuid,)) qb = orm.QueryBuilder().append(orm.Node, filters={f'attributes.{key}': {'like': '1'}}, project='uuid') res = [str(_) for _, in qb.all()] - self.assertEqual(set(res), set((n_str.uuid,))) + assert set(res) == set((n_str.uuid,)) qb = orm.QueryBuilder().append(orm.Node, filters={f'attributes.{key}': {'==': '1'}}, project='uuid') res = [str(_) for _, in qb.all()] - self.assertEqual(set(res), set((n_str.uuid,))) + assert set(res) == set((n_str.uuid,)) if configuration.PROFILE.database_backend == 'sqlalchemy': # I can't query the length of an array with Django, # so I exclude. Not the nicest way, But I would like to keep this piece @@ -871,10 +856,11 @@ def test_attribute_type(self): # duplicated or wrapped otherwise. qb = orm.QueryBuilder().append(orm.Node, filters={f'attributes.{key}': {'of_length': 3}}, project='uuid') res = [str(_) for _, in qb.all()] - self.assertEqual(set(res), set((n_arr.uuid,))) + assert set(res) == set((n_arr.uuid,)) -class QueryBuilderLimitOffsetsTest(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test_class') +class QueryBuilderLimitOffsetsTest: def test_ordering_limits_offsets_of_results_general(self): # Creating 10 nodes with an attribute that can be ordered @@ -886,17 +872,17 @@ def test_ordering_limits_offsets_of_results_general(self): qb = orm.QueryBuilder().append(orm.Node, project='attributes.foo').order_by({orm.Node: 'ctime'}) res = next(zip(*qb.all())) - self.assertEqual(res, tuple(range(10))) + assert res == tuple(range(10)) # Now applying an offset: qb.offset(5) res = next(zip(*qb.all())) - self.assertEqual(res, tuple(range(5, 10))) + assert res == tuple(range(5, 10)) # Now also applying a limit: qb.limit(3) res = next(zip(*qb.all())) - self.assertEqual(res, tuple(range(5, 8))) + assert res == tuple(range(5, 8)) # Specifying the order explicitly the order: qb = orm.QueryBuilder().append(orm.Node, @@ -907,17 +893,17 @@ def test_ordering_limits_offsets_of_results_general(self): }}) res = next(zip(*qb.all())) - self.assertEqual(res, tuple(range(10))) + assert res == tuple(range(10)) # Now applying an offset: qb.offset(5) res = next(zip(*qb.all())) - self.assertEqual(res, tuple(range(5, 10))) + assert res == tuple(range(5, 10)) # Now also applying a limit: qb.limit(3) res = next(zip(*qb.all())) - self.assertEqual(res, tuple(range(5, 8))) + assert res == tuple(range(5, 8)) # Reversing the order: qb = orm.QueryBuilder().append(orm.Node, @@ -928,20 +914,21 @@ def test_ordering_limits_offsets_of_results_general(self): }}) res = next(zip(*qb.all())) - self.assertEqual(res, tuple(range(9, -1, -1))) + assert res == tuple(range(9, -1, -1)) # Now applying an offset: qb.offset(5) res = next(zip(*qb.all())) - self.assertEqual(res, tuple(range(4, -1, -1))) + assert res == tuple(range(4, -1, -1)) # Now also applying a limit: qb.limit(3) res = next(zip(*qb.all())) - self.assertEqual(res, tuple(range(4, 1, -1))) + assert res == tuple(range(4, 1, -1)) -class QueryBuilderJoinsTests(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test_class') +class QueryBuilderJoinsTests: def test_joins1(self): # Creating n1, who will be a parent: @@ -970,12 +957,12 @@ def test_joins1(self): qb = orm.QueryBuilder() qb.append(orm.Node, tag='parent') qb.append(orm.Node, tag='children', project='label', filters={'attributes.is_good': True}) - self.assertEqual(qb.count(), 1) + assert qb.count() == 1 qb = orm.QueryBuilder() qb.append(orm.Node, tag='parent') qb.append(orm.Node, tag='children', outerjoin=True, project='label', filters={'attributes.is_good': True}) - self.assertEqual(qb.count(), 1) + assert qb.count() == 1 def test_joins2(self): # Creating n1, who will be a parent: @@ -1004,26 +991,22 @@ def test_joins2(self): # let's add a differnt relationship than advisor: students[9].add_incoming(advisors[2], link_type=LinkType.CREATE, link_label='lover') - self.assertEqual( - orm.QueryBuilder().append( - orm.Node - ).append(orm.Node, edge_filters={ + assert orm.QueryBuilder().append( + orm.Node + ).append(orm.Node, edge_filters={ + 'label': { + 'like': 'is\\_advisor\\_%' + } + }, tag='student').count() == 7 + + for adv_id, number_students in zip(list(range(3)), (2, 2, 3)): + assert orm.QueryBuilder().append(orm.Node, filters={ + 'attributes.advisor_id': adv_id + }).append(orm.Node, edge_filters={ 'label': { 'like': 'is\\_advisor\\_%' } - }, tag='student').count(), 7 - ) - - for adv_id, number_students in zip(list(range(3)), (2, 2, 3)): - self.assertEqual( - orm.QueryBuilder().append(orm.Node, filters={ - 'attributes.advisor_id': adv_id - }).append(orm.Node, edge_filters={ - 'label': { - 'like': 'is\\_advisor\\_%' - } - }, tag='student').count(), number_students - ) + }, tag='student').count() == number_students def test_joins3_user_group(self): # Create another user @@ -1039,14 +1022,14 @@ def test_joins3_user_group(self): qb = orm.QueryBuilder() qb.append(orm.User, tag='user', filters={'id': {'==': user.id}}) qb.append(orm.Group, with_user='user', filters={'id': {'==': group.id}}) - self.assertEqual(qb.count(), 1, 'The expected group that belongs to the selected user was not found.') + assert qb.count() == 1, 'The expected group that belongs to the selected user was not found.' # Search for the user that owns a group qb = orm.QueryBuilder() qb.append(orm.Group, tag='group', filters={'id': {'==': group.id}}) qb.append(orm.User, with_group='group', filters={'id': {'==': user.id}}) - self.assertEqual(qb.count(), 1, 'The expected user that owns the selected group was not found.') + assert qb.count() == 1, 'The expected user that owns the selected group was not found.' def test_joins_group_node(self): """ @@ -1090,18 +1073,22 @@ def test_joins_group_node(self): qb = orm.QueryBuilder() qb.append(orm.Node, tag='node', project=['id']) qb.append(orm.Group, with_node='node', filters={'id': {'==': group.id}}) - self.assertEqual(qb.count(), 4, 'There should be 4 nodes in the group') + assert qb.count() == 4, 'There should be 4 nodes in the group' id_res = [_ for [_] in qb.all()] for curr_id in [n1.id, n2.id, n3.id, n4.id]: - self.assertIn(curr_id, id_res) + assert curr_id in id_res -class QueryBuilderPath(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test_class') +class QueryBuilderPath: def test_query_path(self): # pylint: disable=too-many-statements + from aiida.manage.manager import get_manager + + backend = get_manager().get_backend() - q = self.backend.query_manager + q = backend.query_manager n1 = orm.Data() n1.label = 'n1' n2 = orm.CalculationNode() @@ -1136,95 +1123,81 @@ def test_query_path(self): node.store() # There are no parents to n9, checking that - self.assertEqual(set([]), set(q.get_all_parents([n9.pk]))) + assert set([]) == set(q.get_all_parents([n9.pk])) # There is one parent to n6 - self.assertEqual({(_,) for _ in (n6.pk,)}, {tuple(_) for _ in q.get_all_parents([n7.pk])}) + assert {(_,) for _ in (n6.pk,)} == {tuple(_) for _ in q.get_all_parents([n7.pk])} # There are several parents to n4 - self.assertEqual({(_.pk,) for _ in (n1, n2)}, {tuple(_) for _ in q.get_all_parents([n4.pk])}) + assert {(_.pk,) for _ in (n1, n2)} == {tuple(_) for _ in q.get_all_parents([n4.pk])} # There are several parents to n5 - self.assertEqual({(_.pk,) for _ in (n1, n2, n3, n4)}, {tuple(_) for _ in q.get_all_parents([n5.pk])}) + assert {(_.pk,) for _ in (n1, n2, n3, n4)} == {tuple(_) for _ in q.get_all_parents([n5.pk])} # Yet, no links from 1 to 8 - self.assertEqual( - orm.QueryBuilder().append(orm.Node, filters={ - 'id': n1.pk - }, tag='anc').append(orm.Node, with_ancestors='anc', filters={ - 'id': n8.pk - }).count(), 0 - ) + assert orm.QueryBuilder().append(orm.Node, filters={ + 'id': n1.pk + }, tag='anc').append(orm.Node, with_ancestors='anc', filters={ + 'id': n8.pk + }).count() == 0 - self.assertEqual( - orm.QueryBuilder().append(orm.Node, filters={ - 'id': n8.pk - }, tag='desc').append(orm.Node, with_descendants='desc', filters={ - 'id': n1.pk - }).count(), 0 - ) + assert orm.QueryBuilder().append(orm.Node, filters={ + 'id': n8.pk + }, tag='desc').append(orm.Node, with_descendants='desc', filters={ + 'id': n1.pk + }).count() == 0 n6.add_incoming(n5, link_type=LinkType.CREATE, link_label='link1') # Yet, now 2 links from 1 to 8 - self.assertEqual( - orm.QueryBuilder().append(orm.Node, filters={ - 'id': n1.pk - }, tag='anc').append(orm.Node, with_ancestors='anc', filters={ - 'id': n8.pk - }).count(), 2 - ) + assert orm.QueryBuilder().append(orm.Node, filters={ + 'id': n1.pk + }, tag='anc').append(orm.Node, with_ancestors='anc', filters={ + 'id': n8.pk + }).count() == 2 - self.assertEqual( - orm.QueryBuilder().append(orm.Node, filters={ - 'id': n8.pk - }, tag='desc').append(orm.Node, with_descendants='desc', filters={ - 'id': n1.pk - }).count(), 2 - ) + assert orm.QueryBuilder().append(orm.Node, filters={ + 'id': n8.pk + }, tag='desc').append(orm.Node, with_descendants='desc', filters={ + 'id': n1.pk + }).count() == 2 - self.assertEqual( - orm.QueryBuilder().append(orm.Node, filters={ - 'id': n8.pk - }, tag='desc').append( - orm.Node, - with_descendants='desc', - filters={ - 'id': n1.pk - }, - edge_filters={ - 'depth': { - '<': 6 - } - }, - ).count(), 2 - ) - self.assertEqual( - orm.QueryBuilder().append(orm.Node, filters={ - 'id': n8.pk - }, tag='desc').append( - orm.Node, - with_descendants='desc', - filters={ - 'id': n1.pk - }, - edge_filters={ - 'depth': 5 - }, - ).count(), 2 - ) - self.assertEqual( - orm.QueryBuilder().append(orm.Node, filters={ - 'id': n8.pk - }, tag='desc').append( - orm.Node, - with_descendants='desc', - filters={ - 'id': n1.pk - }, - edge_filters={ - 'depth': { - '<': 5 - } - }, - ).count(), 0 - ) + assert orm.QueryBuilder().append(orm.Node, filters={ + 'id': n8.pk + }, tag='desc').append( + orm.Node, + with_descendants='desc', + filters={ + 'id': n1.pk + }, + edge_filters={ + 'depth': { + '<': 6 + } + }, + ).count() == 2 + assert orm.QueryBuilder().append(orm.Node, filters={ + 'id': n8.pk + }, tag='desc').append( + orm.Node, + with_descendants='desc', + filters={ + 'id': n1.pk + }, + edge_filters={ + 'depth': 5 + }, + ).count() == 2 + assert orm.QueryBuilder().append(orm.Node, filters={ + 'id': n8.pk + }, tag='desc').append( + orm.Node, + with_descendants='desc', + filters={ + 'id': n1.pk + }, + edge_filters={ + 'depth': { + '<': 5 + } + }, + ).count() == 0 # TODO write a query that can filter certain paths by traversed ID # pylint: disable=fixme qb = orm.QueryBuilder().append( @@ -1241,16 +1214,16 @@ def test_query_path(self): frozenset([n1.pk, n2.pk, n4.pk, n5.pk, n6.pk, n7.pk, n8.pk]) } - self.assertTrue(queried_path_set == paths_there_should_be) + assert queried_path_set == paths_there_should_be qb = orm.QueryBuilder().append(orm.Node, filters={ 'id': n1.pk }, tag='anc').append(orm.Node, with_ancestors='anc', filters={'id': n8.pk}, edge_project='path') - self.assertEqual({frozenset(p) for p, in qb.all()}, { + assert {frozenset(p) for p, in qb.all()} == { frozenset([n1.pk, n2.pk, n3.pk, n5.pk, n6.pk, n7.pk, n8.pk]), frozenset([n1.pk, n2.pk, n4.pk, n5.pk, n6.pk, n7.pk, n8.pk]) - }) + } # This part of the test is no longer possible as the nodes have already been stored and the previous parts of # the test rely on this, which means however, that here, no more links can be added as that will raise. @@ -1299,7 +1272,8 @@ def test_query_path(self): # self.assertTrue(set(next(zip(*qb.all()))), set([5])) -class TestConsistency(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test_class') +class TestConsistency: def test_create_node_and_query(self): """ @@ -1315,8 +1289,8 @@ def test_create_node_and_query(self): if idx % 10 == 10: n = orm.Data() n.store() - self.assertEqual(idx, 99) # pylint: disable=undefined-loop-variable - self.assertTrue(len(orm.QueryBuilder().append(orm.Node, project=['id', 'label']).all(batch_size=10)) > 99) + assert idx == 99 # pylint: disable=undefined-loop-variable + assert len(orm.QueryBuilder().append(orm.Node, project=['id', 'label']).all(batch_size=10)) > 99 def test_len_results(self): """ @@ -1333,10 +1307,11 @@ def test_len_results(self): qb = orm.QueryBuilder() qb.append(orm.CalculationNode, filters={'id': parent.id}, tag='parent', project=projection) qb.append(orm.Data, with_incoming='parent') - self.assertEqual(len(qb.all()), qb.count()) + assert len(qb.all()) == qb.count() -class TestManager(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test_class') +class TestManager: def test_statistics(self): """ @@ -1345,6 +1320,9 @@ def test_statistics(self): I try to implement it in a way that does not depend on the past state. """ from collections import defaultdict + from aiida.manage.manager import get_manager + + backend = get_manager().get_backend() # pylint: disable=protected-access @@ -1354,7 +1332,7 @@ def store_and_add(n, statistics): statistics['types'][n._plugin_type_string] += 1 # pylint: disable=no-member statistics['ctime_by_day'][n.ctime.strftime('%Y-%m-%d')] += 1 - qmanager = self.backend.query_manager + qmanager = backend.query_manager current_db_statistics = qmanager.get_creation_statistics() types = defaultdict(int) types.update(current_db_statistics['types']) @@ -1376,7 +1354,7 @@ def store_and_add(n, statistics): k: dict(v) if isinstance(v, defaultdict) else v for k, v in expected_db_statistics.items() } - self.assertEqual(new_db_statistics, expected_db_statistics) + assert new_db_statistics == expected_db_statistics def test_statistics_default_class(self): """ @@ -1385,6 +1363,9 @@ def test_statistics_default_class(self): I try to implement it in a way that does not depend on the past state. """ from collections import defaultdict + from aiida.manage.manager import get_manager + + backend = get_manager().get_backend() def store_and_add(n, statistics): n.store() @@ -1392,7 +1373,7 @@ def store_and_add(n, statistics): statistics['types'][n._plugin_type_string] += 1 # pylint: disable=no-member,protected-access statistics['ctime_by_day'][n.ctime.strftime('%Y-%m-%d')] += 1 - current_db_statistics = self.backend.query_manager.get_creation_statistics() + current_db_statistics = backend.query_manager.get_creation_statistics() types = defaultdict(int) types.update(current_db_statistics['types']) ctime_by_day = defaultdict(int) @@ -1405,7 +1386,7 @@ def store_and_add(n, statistics): store_and_add(orm.Dict(), expected_db_statistics) store_and_add(orm.CalculationNode(), expected_db_statistics) - new_db_statistics = self.backend.query_manager.get_creation_statistics() + new_db_statistics = backend.query_manager.get_creation_statistics() # I only check a few fields new_db_statistics = {k: v for k, v in new_db_statistics.items() if k in expected_db_statistics} @@ -1413,47 +1394,45 @@ def store_and_add(n, statistics): k: dict(v) if isinstance(v, defaultdict) else v for k, v in expected_db_statistics.items() } - self.assertEqual(new_db_statistics, expected_db_statistics) + assert new_db_statistics == expected_db_statistics -class TestDoubleStar(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test') +def test_statistics_default_class(aiida_localhost): """ - In this test class we check if QueryBuilder returns the correct results + In this test case we check if QueryBuilder returns the correct results when double star is provided as projection. """ - - def test_statistics_default_class(self): - - # The expected result - # pylint: disable=no-member - expected_dict = { - 'description': self.computer.description, - 'scheduler_type': self.computer.scheduler_type, - 'hostname': self.computer.hostname, - 'uuid': self.computer.uuid, - 'name': self.computer.label, - 'transport_type': self.computer.transport_type, - 'id': self.computer.id, - 'metadata': self.computer.metadata, - } - - qb = orm.QueryBuilder() - qb.append(orm.Computer, project=['**']) - # We expect one result - self.assertEqual(qb.count(), 1) - - # Get the one result record and check that the returned - # data are correct - res = list(qb.dict()[0].values())[0] - self.assertDictEqual(res, expected_dict) - - # Ask the same query as above using queryhelp - qh = {'project': {'computer': ['**']}, 'path': [{'tag': 'computer', 'cls': orm.Computer}]} - qb = orm.QueryBuilder(**qh) - # We expect one result - self.assertEqual(qb.count(), 1) - - # Get the one result record and check that the returned - # data are correct - res = list(qb.dict()[0].values())[0] - self.assertDictEqual(res, expected_dict) + # The expected result + # pylint: disable=no-member + expected_dict = { + 'description': aiida_localhost.description, + 'scheduler_type': aiida_localhost.scheduler_type, + 'hostname': aiida_localhost.hostname, + 'uuid': aiida_localhost.uuid, + 'name': aiida_localhost.label, + 'transport_type': aiida_localhost.transport_type, + 'id': aiida_localhost.id, + 'metadata': aiida_localhost.metadata, + } + + qb = orm.QueryBuilder() + qb.append(orm.Computer, project=['**']) + # We expect one result + assert qb.count() == 1 + + # Get the one result record and check that the returned + # data are correct + res = list(qb.dict()[0].values())[0] + assert res == expected_dict + + # Ask the same query as above using queryhelp + qh = {'project': {'computer': ['**']}, 'path': [{'tag': 'computer', 'cls': orm.Computer}]} + qb = orm.QueryBuilder(**qh) + # We expect one result + assert qb.count() == 1 + + # Get the one result record and check that the returned + # data are correct + res = list(qb.dict()[0].values())[0] + assert res == expected_dict diff --git a/tests/test_base_dataclasses.py b/tests/test_base_dataclasses.py index dca5850f36..2e4ac7a89d 100644 --- a/tests/test_base_dataclasses.py +++ b/tests/test_base_dataclasses.py @@ -7,30 +7,33 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=no-self-use """Tests for AiiDA base data classes.""" import operator -from aiida.backends.testbase import AiidaTestCase +import pytest + from aiida.common.exceptions import ModificationNotAllowed from aiida.orm import load_node, List, Bool, Float, Int, Str, NumericType from aiida.orm.nodes.data.bool import get_true_node, get_false_node -class TestList(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test_class') +class TestList: """Test AiiDA List class.""" def test_creation(self): node = List() - self.assertEqual(len(node), 0) - with self.assertRaises(IndexError): + assert len(node) == 0 + with pytest.raises(IndexError): node[0] # pylint: disable=pointless-statement def test_append(self): """Test append() member function.""" def do_checks(node): - self.assertEqual(len(node), 1) - self.assertEqual(node[0], 4) + assert len(node) == 1 + assert node[0] == 4 node = List() node.append(4) @@ -47,22 +50,22 @@ def test_extend(self): lst = [1, 2, 3] def do_checks(node): - self.assertEqual(len(node), len(lst)) + assert len(node) == len(lst) # Do an element wise comparison for lst_, node_ in zip(lst, node): - self.assertEqual(lst_, node_) + assert lst_ == node_ node = List() node.extend(lst) do_checks(node) # Further extend node.extend(lst) - self.assertEqual(len(node), len(lst) * 2) + assert len(node) == len(lst) * 2 # Do an element wise comparison for i, _ in enumerate(lst): - self.assertEqual(lst[i], node[i]) - self.assertEqual(lst[i], node[i % len(lst)]) + assert lst[i] == node[i] + assert lst[i] == node[i % len(lst)] # Now try after storing node = List() @@ -77,19 +80,19 @@ def test_mutability(self): node.store() # Test all mutable calls are now disallowed - with self.assertRaises(ModificationNotAllowed): + with pytest.raises(ModificationNotAllowed): node.append(5) - with self.assertRaises(ModificationNotAllowed): + with pytest.raises(ModificationNotAllowed): node.extend([5]) - with self.assertRaises(ModificationNotAllowed): + with pytest.raises(ModificationNotAllowed): node.insert(0, 2) - with self.assertRaises(ModificationNotAllowed): + with pytest.raises(ModificationNotAllowed): node.remove(0) - with self.assertRaises(ModificationNotAllowed): + with pytest.raises(ModificationNotAllowed): node.pop() - with self.assertRaises(ModificationNotAllowed): + with pytest.raises(ModificationNotAllowed): node.sort() - with self.assertRaises(ModificationNotAllowed): + with pytest.raises(ModificationNotAllowed): node.reverse() @staticmethod @@ -102,11 +105,12 @@ def test_store_load(): assert node.get_list() == node_loaded.get_list() -class TestFloat(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test_class') +class TestFloat: """Test Float class.""" - def setUp(self): - super().setUp() + def setup_method(self): + # pylint: disable=attribute-defined-outside-init self.value = Float() self.all_types = [Int, Float, Bool, Str] @@ -114,31 +118,31 @@ def test_create(self): """Creating basic data objects.""" term_a = Float() # Check that initial value is zero - self.assertAlmostEqual(term_a.value, 0.0) + assert round(abs(term_a.value - 0.0), 7) == 0 float_ = Float(6.0) - self.assertAlmostEqual(float_.value, 6.) - self.assertAlmostEqual(float_, Float(6.0)) + assert round(abs(float_.value - 6.), 7) == 0 + assert round(abs(float_ - Float(6.0)), 7) == 0 int_ = Int() - self.assertAlmostEqual(int_.value, 0) + assert round(abs(int_.value - 0), 7) == 0 int_ = Int(6) - self.assertAlmostEqual(int_.value, 6) - self.assertAlmostEqual(float_, int_) + assert round(abs(int_.value - 6), 7) == 0 + assert round(abs(float_ - int_), 7) == 0 bool_ = Bool() - self.assertAlmostEqual(bool_.value, False) + assert round(abs(bool_.value - False), 7) == 0 bool_ = Bool(False) - self.assertAlmostEqual(bool_.value, False) - self.assertAlmostEqual(bool_.value, get_false_node()) + assert bool_.value is False + assert bool_.value == get_false_node() bool_ = Bool(True) - self.assertAlmostEqual(bool_.value, True) - self.assertAlmostEqual(bool_.value, get_true_node()) + assert bool_.value is True + assert bool_.value == get_true_node() str_ = Str() - self.assertAlmostEqual(str_.value, '') + assert str_.value == '' str_ = Str('Hello') - self.assertAlmostEqual(str_.value, 'Hello') + assert str_.value == 'Hello' def test_load(self): """Test object loading.""" @@ -146,7 +150,7 @@ def test_load(self): node = typ() node.store() loaded = load_node(node.pk) - self.assertAlmostEqual(node, loaded) + assert node == loaded def test_add(self): """Test addition.""" @@ -154,26 +158,26 @@ def test_add(self): term_b = Float(5) # Check adding two db Floats res = term_a + term_b - self.assertIsInstance(res, NumericType) - self.assertAlmostEqual(res, 9.0) + assert isinstance(res, NumericType) + assert round(abs(res - 9.0), 7) == 0 # Check adding db Float and native (both ways) res = term_a + 5.0 - self.assertIsInstance(res, NumericType) - self.assertAlmostEqual(res, 9.0) + assert isinstance(res, NumericType) + assert round(abs(res - 9.0), 7) == 0 res = 5.0 + term_a - self.assertIsInstance(res, NumericType) - self.assertAlmostEqual(res, 9.0) + assert isinstance(res, NumericType) + assert round(abs(res - 9.0), 7) == 0 # Inplace term_a = Float(4) term_a += term_b - self.assertAlmostEqual(term_a, 9.0) + assert round(abs(term_a - 9.0), 7) == 0 term_a = Float(4) term_a += 5 - self.assertAlmostEqual(term_a, 9.0) + assert round(abs(term_a - 9.0), 7) == 0 def test_mul(self): """Test floats multiplication.""" @@ -181,26 +185,26 @@ def test_mul(self): term_b = Float(5) # Check adding two db Floats res = term_a * term_b - self.assertIsInstance(res, NumericType) - self.assertAlmostEqual(res, 20.0) + assert isinstance(res, NumericType) + assert round(abs(res - 20.0), 7) == 0 # Check adding db Float and native (both ways) res = term_a * 5.0 - self.assertIsInstance(res, NumericType) - self.assertAlmostEqual(res, 20) + assert isinstance(res, NumericType) + assert round(abs(res - 20), 7) == 0 res = 5.0 * term_a - self.assertIsInstance(res, NumericType) - self.assertAlmostEqual(res, 20.0) + assert isinstance(res, NumericType) + assert round(abs(res - 20.0), 7) == 0 # Inplace term_a = Float(4) term_a *= term_b - self.assertAlmostEqual(term_a, 20) + assert round(abs(term_a - 20), 7) == 0 term_a = Float(4) term_a *= 5 - self.assertAlmostEqual(term_a, 20) + assert round(abs(term_a - 20), 7) == 0 def test_power(self): """Test power operator.""" @@ -208,38 +212,39 @@ def test_power(self): term_b = Float(2) res = term_a**term_b - self.assertAlmostEqual(res.value, 16.) + assert round(abs(res.value - 16.), 7) == 0 def test_division(self): """Test the normal division operator.""" term_a = Float(3) term_b = Float(2) - self.assertAlmostEqual(term_a / term_b, 1.5) - self.assertIsInstance(term_a / term_b, Float) + assert round(abs(term_a / term_b - 1.5), 7) == 0 + assert isinstance(term_a / term_b, Float) def test_division_integer(self): """Test the integer division operator.""" term_a = Float(3) term_b = Float(2) - self.assertAlmostEqual(term_a // term_b, 1.0) - self.assertIsInstance(term_a // term_b, Float) + assert round(abs(term_a // term_b - 1.0), 7) == 0 + assert isinstance(term_a // term_b, Float) def test_modulus(self): """Test modulus operator.""" term_a = Float(12.0) term_b = Float(10.0) - self.assertAlmostEqual(term_a % term_b, 2.0) - self.assertIsInstance(term_a % term_b, NumericType) - self.assertAlmostEqual(term_a % 10.0, 2.0) - self.assertIsInstance(term_a % 10.0, NumericType) - self.assertAlmostEqual(12.0 % term_b, 2.0) - self.assertIsInstance(12.0 % term_b, NumericType) + assert round(abs(term_a % term_b - 2.0), 7) == 0 + assert isinstance(term_a % term_b, NumericType) + assert round(abs(term_a % 10.0 - 2.0), 7) == 0 + assert isinstance(term_a % 10.0, NumericType) + assert round(abs(12.0 % term_b - 2.0), 7) == 0 + assert isinstance(12.0 % term_b, NumericType) -class TestFloatIntMix(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test_class') +class TestFloatIntMix: """Test operations between Int and Float objects.""" def test_operator(self): @@ -254,11 +259,12 @@ def test_operator(self): for term_x, term_y in [(term_a, term_b), (term_b, term_a)]: res = oper(term_x, term_y) c_val = oper(term_x.value, term_y.value) - self.assertEqual(res._type, type(c_val)) # pylint: disable=protected-access - self.assertEqual(res, oper(term_x.value, term_y.value)) + assert res._type == type(c_val) # pylint: disable=protected-access + assert res == oper(term_x.value, term_y.value) -class TestInt(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test_class') +class TestInt: """Test Int class.""" def test_division(self): @@ -266,25 +272,25 @@ def test_division(self): term_a = Int(3) term_b = Int(2) - self.assertAlmostEqual(term_a / term_b, 1.5) - self.assertIsInstance(term_a / term_b, Float) + assert round(abs(term_a / term_b - 1.5), 7) == 0 + assert isinstance(term_a / term_b, Float) def test_division_integer(self): """Test the integer division operator.""" term_a = Int(3) term_b = Int(2) - self.assertAlmostEqual(term_a // term_b, 1) - self.assertIsInstance(term_a // term_b, Int) + assert round(abs(term_a // term_b - 1), 7) == 0 + assert isinstance(term_a // term_b, Int) def test_modulo(self): """Test modulus operation.""" term_a = Int(12) term_b = Int(10) - self.assertEqual(term_a % term_b, 2) - self.assertIsInstance(term_a % term_b, NumericType) - self.assertEqual(term_a % 10, 2) - self.assertIsInstance(term_a % 10, NumericType) - self.assertEqual(12 % term_b, 2) - self.assertIsInstance(12 % term_b, NumericType) + assert term_a % term_b == 2 + assert isinstance(term_a % term_b, NumericType) + assert term_a % 10 == 2 + assert isinstance(term_a % 10, NumericType) + assert 12 % term_b == 2 + assert isinstance(12 % term_b, NumericType) diff --git a/tests/test_dataclasses.py b/tests/test_dataclasses.py index 7c4e939848..a6cc9eb5b9 100644 --- a/tests/test_dataclasses.py +++ b/tests/test_dataclasses.py @@ -8,14 +8,14 @@ # For further information please visit http://www.aiida.net # ########################################################################### # pylint: disable=too-many-lines,invalid-name +# pylint: disable=no-self-use """Tests for specific subclasses of Data.""" import os import tempfile -import unittest +from numpy.testing import assert_allclose import pytest -from aiida.backends.testbase import AiidaTestCase from aiida.common.exceptions import ModificationNotAllowed from aiida.common.utils import Capturing from aiida.orm import load_node @@ -51,7 +51,8 @@ def simplify(string): return '\n'.join(s.strip() for s in string.split()) -class TestCifData(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test_class') +class TestCifData: """Tests for CifData class.""" from distutils.version import StrictVersion from aiida.orm.nodes.data.cif import has_pycifrw @@ -96,7 +97,7 @@ class TestCifData(AiidaTestCase): O 0.5 0.5 0.5 . ''' - @unittest.skipIf(not has_pycifrw(), 'Unable to import PyCifRW') + @pytest.mark.skipif(not has_pycifrw(), reason='Unable to import PyCifRW') def test_reload_cifdata(self): """Test `CifData` cycle.""" file_content = 'data_test _cell_length_a 10(1)' @@ -108,35 +109,35 @@ def test_reload_cifdata(self): a = CifData(file=filename, source={'version': '1234', 'db_name': 'COD', 'id': '0000001'}) # Key 'db_kind' is not allowed in source description: - with self.assertRaises(KeyError): + with pytest.raises(KeyError): a.source = {'db_kind': 'small molecule'} the_uuid = a.uuid - self.assertEqual(a.list_object_names(), [basename]) + assert a.list_object_names() == [basename] with a.open() as fhandle: - self.assertEqual(fhandle.read(), file_content) + assert fhandle.read() == file_content a.store() - self.assertEqual(a.source, { + assert a.source == { 'db_name': 'COD', 'id': '0000001', 'version': '1234', - }) + } with a.open() as fhandle: - self.assertEqual(fhandle.read(), file_content) - self.assertEqual(a.list_object_names(), [basename]) + assert fhandle.read() == file_content + assert a.list_object_names() == [basename] b = load_node(the_uuid) # I check the retrieved object - self.assertTrue(isinstance(b, CifData)) - self.assertEqual(b.list_object_names(), [basename]) + assert isinstance(b, CifData) + assert b.list_object_names() == [basename] with b.open() as fhandle: - self.assertEqual(fhandle.read(), file_content) + assert fhandle.read() == file_content # Checking the get_or_create() method: with tempfile.NamedTemporaryFile(mode='w+') as tmpf: @@ -144,9 +145,9 @@ def test_reload_cifdata(self): tmpf.flush() c, created = CifData.get_or_create(tmpf.name, store_cif=False) - self.assertTrue(isinstance(c, CifData)) - self.assertTrue(not created) - self.assertEqual(c.get_content(), file_content) + assert isinstance(c, CifData) + assert not created + assert c.get_content() == file_content other_content = 'data_test _cell_length_b 10(1)' with tempfile.NamedTemporaryFile(mode='w+') as tmpf: @@ -154,11 +155,11 @@ def test_reload_cifdata(self): tmpf.flush() c, created = CifData.get_or_create(tmpf.name, store_cif=False) - self.assertTrue(isinstance(c, CifData)) - self.assertTrue(created) - self.assertEqual(c.get_content(), other_content) + assert isinstance(c, CifData) + assert created + assert c.get_content() == other_content - @unittest.skipIf(not has_pycifrw(), 'Unable to import PyCifRW') + @pytest.mark.skipif(not has_pycifrw(), reason='Unable to import PyCifRW') def test_parse_cifdata(self): """Test parsing a CIF file.""" file_content = 'data_test _cell_length_a 10(1)' @@ -167,9 +168,9 @@ def test_parse_cifdata(self): tmpf.flush() a = CifData(file=tmpf.name) - self.assertEqual(list(a.values.keys()), ['test']) + assert list(a.values.keys()) == ['test'] - @unittest.skipIf(not has_pycifrw(), 'Unable to import PyCifRW') + @pytest.mark.skipif(not has_pycifrw(), reason='Unable to import PyCifRW') def test_change_cifdata_file(self): """Test changing file for `CifData` before storing.""" file_content_1 = 'data_test _cell_length_a 10(1)' @@ -179,17 +180,17 @@ def test_change_cifdata_file(self): tmpf.flush() a = CifData(file=tmpf.name) - self.assertEqual(a.values['test']['_cell_length_a'], '10(1)') + assert a.values['test']['_cell_length_a'] == '10(1)' with tempfile.NamedTemporaryFile(mode='w+') as tmpf: tmpf.write(file_content_2) tmpf.flush() a.set_file(tmpf.name) - self.assertEqual(a.values['test']['_cell_length_a'], '11(1)') + assert a.values['test']['_cell_length_a'] == '11(1)' - @unittest.skipIf(not has_ase(), 'Unable to import ase') - @unittest.skipIf(not has_pycifrw(), 'Unable to import PyCifRW') + @pytest.mark.skipif(not has_ase(), reason='Unable to import ase') + @pytest.mark.skipif(not has_pycifrw(), reason='Unable to import PyCifRW') @pytest.mark.requires_rmq def test_get_structure(self): """Test `CifData.get_structure`.""" @@ -219,15 +220,15 @@ def test_get_structure(self): tmpf.flush() a = CifData(file=tmpf.name) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): a.get_structure(converter='none') c = a.get_structure() - self.assertEqual(c.get_kind_names(), ['C', 'O']) + assert c.get_kind_names() == ['C', 'O'] - @unittest.skipIf(not has_ase(), 'Unable to import ase') - @unittest.skipIf(not has_pycifrw(), 'Unable to import PyCifRW') + @pytest.mark.skipif(not has_ase(), reason='Unable to import ase') + @pytest.mark.skipif(not has_pycifrw(), reason='Unable to import PyCifRW') @pytest.mark.requires_rmq def test_ase_primitive_and_conventional_cells_ase(self): """Checking the number of atoms per primitive/conventional cell @@ -263,17 +264,17 @@ def test_ase_primitive_and_conventional_cells_ase(self): c = CifData(file=tmpf.name) ase = c.get_structure(converter='ase', primitive_cell=False).get_ase() - self.assertEqual(ase.get_global_number_of_atoms(), 15) + assert ase.get_global_number_of_atoms() == 15 ase = c.get_structure(converter='ase').get_ase() - self.assertEqual(ase.get_global_number_of_atoms(), 15) + assert ase.get_global_number_of_atoms() == 15 ase = c.get_structure(converter='ase', primitive_cell=True, subtrans_included=False).get_ase() - self.assertEqual(ase.get_global_number_of_atoms(), 5) + assert ase.get_global_number_of_atoms() == 5 - @unittest.skipIf(not has_ase(), 'Unable to import ase') - @unittest.skipIf(not has_pycifrw(), 'Unable to import PyCifRW') - @unittest.skipIf(not has_pymatgen(), 'Unable to import pymatgen') + @pytest.mark.skipif(not has_ase(), reason='Unable to import ase') + @pytest.mark.skipif(not has_pycifrw(), reason='Unable to import PyCifRW') + @pytest.mark.skipif(not has_pymatgen(), reason='Unable to import pymatgen') @pytest.mark.requires_rmq def test_ase_primitive_and_conventional_cells_pymatgen(self): """Checking the number of atoms per primitive/conventional cell @@ -319,15 +320,15 @@ def test_ase_primitive_and_conventional_cells_pymatgen(self): c = CifData(file=tmpf.name) ase = c.get_structure(converter='pymatgen', primitive_cell=False).get_ase() - self.assertEqual(ase.get_global_number_of_atoms(), 15) + assert ase.get_global_number_of_atoms() == 15 ase = c.get_structure(converter='pymatgen').get_ase() - self.assertEqual(ase.get_global_number_of_atoms(), 15) + assert ase.get_global_number_of_atoms() == 15 ase = c.get_structure(converter='pymatgen', primitive_cell=True).get_ase() - self.assertEqual(ase.get_global_number_of_atoms(), 5) + assert ase.get_global_number_of_atoms() == 5 - @unittest.skipIf(not has_pycifrw(), 'Unable to import PyCifRW') + @pytest.mark.skipif(not has_pycifrw(), reason='Unable to import PyCifRW') def test_pycifrw_from_datablocks(self): """ Tests CifData.pycifrw_from_cif() @@ -346,8 +347,7 @@ def test_pycifrw_from_datablocks(self): for line in lines: if not re.search('^#', line): non_comments.append(line) - self.assertEqual( - simplify('\n'.join(non_comments)), + assert simplify('\n'.join(non_comments)) == \ simplify( ''' data_0 @@ -366,7 +366,6 @@ def test_pycifrw_from_datablocks(self): _publ_section_title 'Test CIF' ''' ) - ) loops = {'_atom_site': ['_atom_site_label', '_atom_site_occupancy']} with Capturing(): @@ -375,8 +374,7 @@ def test_pycifrw_from_datablocks(self): for line in lines: if not re.search('^#', line): non_comments.append(line) - self.assertEqual( - simplify('\n'.join(non_comments)), + assert simplify('\n'.join(non_comments)) == \ simplify( ''' data_0 @@ -390,9 +388,8 @@ def test_pycifrw_from_datablocks(self): _publ_section_title 'Test CIF' ''' ) - ) - @unittest.skipIf(not has_pycifrw(), 'Unable to import PyCifRW') + @pytest.mark.skipif(not has_pycifrw(), reason='Unable to import PyCifRW') def test_pycifrw_syntax(self): """Tests CifData.pycifrw_from_cif() - check syntax pb in PyCifRW 3.6.""" from aiida.orm.nodes.data.cif import pycifrw_from_cif @@ -407,15 +404,13 @@ def test_pycifrw_syntax(self): for line in lines: if not re.search('^#', line): non_comments.append(line) - self.assertEqual( - simplify('\n'.join(non_comments)), + assert simplify('\n'.join(non_comments)) == \ simplify(''' data_0 _tag '[value]' ''') - ) - @unittest.skipIf(not has_pycifrw(), 'Unable to import PyCifRW') + @pytest.mark.skipif(not has_pycifrw(), reason='Unable to import PyCifRW') @staticmethod def test_cif_with_long_line(): """Tests CifData - check that long lines (longer than 2048 characters) are supported. @@ -428,8 +423,8 @@ def test_cif_with_long_line(): tmpf.flush() _ = CifData(file=tmpf.name) - @unittest.skipIf(not has_ase(), 'Unable to import ase') - @unittest.skipIf(not has_pycifrw(), 'Unable to import PyCifRW') + @pytest.mark.skipif(not has_ase(), reason='Unable to import ase') + @pytest.mark.skipif(not has_pycifrw(), reason='Unable to import PyCifRW') def test_cif_roundtrip(self): """Test the `CifData` roundtrip.""" with tempfile.NamedTemporaryFile(mode='w+') as tmpf: @@ -458,26 +453,24 @@ def test_cif_roundtrip(self): b = CifData(values=a.values) c = CifData(values=b.values) - self.assertEqual(b._prepare_cif(), c._prepare_cif()) # pylint: disable=protected-access + assert b._prepare_cif() == c._prepare_cif() # pylint: disable=protected-access b = CifData(ase=a.ase) c = CifData(ase=b.ase) - self.assertEqual(b._prepare_cif(), c._prepare_cif()) # pylint: disable=protected-access + assert b._prepare_cif() == c._prepare_cif() # pylint: disable=protected-access def test_symop_string_from_symop_matrix_tr(self): """Test symmetry operations.""" from aiida.tools.data.cif import symop_string_from_symop_matrix_tr - self.assertEqual(symop_string_from_symop_matrix_tr([[1, 0, 0], [0, 1, 0], [0, 0, 1]]), 'x,y,z') + assert symop_string_from_symop_matrix_tr([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) == 'x,y,z' - self.assertEqual(symop_string_from_symop_matrix_tr([[1, 0, 0], [0, -1, 0], [0, 1, 1]]), 'x,-y,y+z') + assert symop_string_from_symop_matrix_tr([[1, 0, 0], [0, -1, 0], [0, 1, 1]]) == 'x,-y,y+z' - self.assertEqual( - symop_string_from_symop_matrix_tr([[-1, 0, 0], [0, 1, 0], [0, 0, 1]], [1, -1, 0]), '-x+1,y-1,z' - ) + assert symop_string_from_symop_matrix_tr([[-1, 0, 0], [0, 1, 0], [0, 0, 1]], [1, -1, 0]) == '-x+1,y-1,z' - @unittest.skipIf(not has_ase(), 'Unable to import ase') - @unittest.skipIf(not has_pycifrw(), 'Unable to import PyCifRW') + @pytest.mark.skipif(not has_ase(), reason='Unable to import ase') + @pytest.mark.skipif(not has_pycifrw(), reason='Unable to import PyCifRW') def test_attached_hydrogens(self): """Test parsing of file with attached hydrogens.""" with tempfile.NamedTemporaryFile(mode='w+') as tmpf: @@ -504,7 +497,7 @@ def test_attached_hydrogens(self): tmpf.flush() a = CifData(file=tmpf.name) - self.assertEqual(a.has_attached_hydrogens, False) + assert a.has_attached_hydrogens is False with tempfile.NamedTemporaryFile(mode='w+') as tmpf: tmpf.write( @@ -530,11 +523,11 @@ def test_attached_hydrogens(self): tmpf.flush() a = CifData(file=tmpf.name) - self.assertEqual(a.has_attached_hydrogens, True) + assert a.has_attached_hydrogens is True - @unittest.skipIf(not has_ase(), 'Unable to import ase') - @unittest.skipIf(not has_pycifrw(), 'Unable to import PyCifRW') - @unittest.skipIf(not has_spglib(), 'Unable to import spglib') + @pytest.mark.skipif(not has_ase(), reason='Unable to import ase') + @pytest.mark.skipif(not has_pycifrw(), reason='Unable to import PyCifRW') + @pytest.mark.skipif(not has_spglib(), reason='Unable to import spglib') @pytest.mark.requires_rmq def test_refine(self): """ @@ -568,14 +561,12 @@ def test_refine(self): ret_dict = refine_inline(a) b = ret_dict['cif'] - self.assertEqual(list(b.values.keys()), ['test']) - self.assertEqual(b.values['test']['_chemical_formula_sum'], 'C O2') - self.assertEqual( - b.values['test']['_symmetry_equiv_pos_as_xyz'], [ - 'x,y,z', '-x,-y,-z', '-y,x,z', 'y,-x,-z', '-x,-y,z', 'x,y,-z', 'y,-x,z', '-y,x,-z', 'x,-y,-z', '-x,y,z', - '-y,-x,-z', 'y,x,z', '-x,y,-z', 'x,-y,z', 'y,x,-z', '-y,-x,z' - ] - ) + assert list(b.values.keys()) == ['test'] + assert b.values['test']['_chemical_formula_sum'] == 'C O2' + assert b.values['test']['_symmetry_equiv_pos_as_xyz'] == [ + 'x,y,z', '-x,-y,-z', '-y,x,z', 'y,-x,-z', '-x,-y,z', 'x,y,-z', 'y,-x,z', '-y,x,-z', 'x,-y,-z', '-x,y,z', + '-y,-x,-z', 'y,x,z', '-x,y,-z', 'x,-y,z', 'y,x,-z', '-y,-x,z' + ] with tempfile.NamedTemporaryFile(mode='w+') as tmpf: tmpf.write(''' @@ -585,10 +576,10 @@ def test_refine(self): tmpf.flush() c = CifData(file=tmpf.name) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): ret_dict = refine_inline(c) - @unittest.skipIf(not has_pycifrw(), 'Unable to import PyCifRW') + @pytest.mark.skipif(not has_pycifrw(), reason='Unable to import PyCifRW') def test_scan_type(self): """Check that different scan_types of PyCifRW produce the same result.""" with tempfile.NamedTemporaryFile(mode='w+') as tmpf: @@ -597,12 +588,12 @@ def test_scan_type(self): default = CifData(file=tmpf.name) default2 = CifData(file=tmpf.name, scan_type='standard') - self.assertEqual(default._prepare_cif(), default2._prepare_cif()) # pylint: disable=protected-access + assert default._prepare_cif() == default2._prepare_cif() # pylint: disable=protected-access flex = CifData(file=tmpf.name, scan_type='flex') - self.assertEqual(default._prepare_cif(), flex._prepare_cif()) # pylint: disable=protected-access + assert default._prepare_cif() == flex._prepare_cif() # pylint: disable=protected-access - @unittest.skipIf(not has_pycifrw(), 'Unable to import PyCifRW') + @pytest.mark.skipif(not has_pycifrw(), reason='Unable to import PyCifRW') def test_empty_cif(self): """Test empty CifData @@ -615,7 +606,7 @@ def test_empty_cif(self): a = CifData() # but it does not have a file - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): _ = a.filename #now it has @@ -624,7 +615,7 @@ def test_empty_cif(self): a.store() - @unittest.skipIf(not has_pycifrw(), 'Unable to import PyCifRW') + @pytest.mark.skipif(not has_pycifrw(), reason='Unable to import PyCifRW') def test_parse_policy(self): """Test that loading of CIF file occurs as defined by parse_policy.""" with tempfile.NamedTemporaryFile(mode='w+') as tmpf: @@ -633,20 +624,20 @@ def test_parse_policy(self): # this will parse the cif eager = CifData(file=tmpf.name, parse_policy='eager') - self.assertIsNot(eager._values, None) # pylint: disable=protected-access + assert eager._values is not None # pylint: disable=protected-access # this should not parse the cif lazy = CifData(file=tmpf.name, parse_policy='lazy') - self.assertIs(lazy._values, None) # pylint: disable=protected-access + assert lazy._values is None # pylint: disable=protected-access # also lazy-loaded nodes should be storable lazy.store() # this should parse the cif _ = lazy.values - self.assertIsNot(lazy._values, None) # pylint: disable=protected-access + assert lazy._values is not None # pylint: disable=protected-access - @unittest.skipIf(not has_pycifrw(), 'Unable to import PyCifRW') + @pytest.mark.skipif(not has_pycifrw(), reason='Unable to import PyCifRW') def test_set_file(self): """Test that setting a new file clears formulae and spacegroups.""" with tempfile.NamedTemporaryFile(mode='w+') as tmpf: @@ -655,7 +646,7 @@ def test_set_file(self): a = CifData(file=tmpf.name) f1 = a.get_formulae() - self.assertIsNot(f1, None) + assert f1 is not None with tempfile.NamedTemporaryFile(mode='w+') as tmpf: tmpf.write(self.valid_sample_cif_str_2) @@ -663,27 +654,27 @@ def test_set_file(self): # this should reset formulae and spacegroup_numbers a.set_file(tmpf.name) - self.assertIs(a.get_attribute('formulae'), None) - self.assertIs(a.get_attribute('spacegroup_numbers'), None) + assert a.get_attribute('formulae') is None + assert a.get_attribute('spacegroup_numbers') is None # this should populate formulae a.parse() f2 = a.get_formulae() - self.assertIsNot(f2, None) + assert f2 is not None # empty cifdata should be possible a = CifData() # but it does not have a file - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): _ = a.filename # now it has a.set_file(tmpf.name) a.parse() _ = a.filename - self.assertNotEqual(f1, f2) + assert f1 != f2 - @unittest.skipIf(not has_pycifrw(), 'Unable to import PyCifRW') + @pytest.mark.skipif(not has_pycifrw(), reason='Unable to import PyCifRW') def test_has_partial_occupancies(self): """Test structure with partial occupancies.""" tests = [ @@ -713,9 +704,9 @@ def test_has_partial_occupancies(self): ) handle.flush() cif = CifData(file=handle.name) - self.assertEqual(cif.has_partial_occupancies, result) + assert cif.has_partial_occupancies == result - @unittest.skipIf(not has_pycifrw(), 'Unable to import PyCifRW') + @pytest.mark.skipif(not has_pycifrw(), reason='Unable to import PyCifRW') def test_has_unknown_species(self): """Test structure with unknown species.""" tests = [ @@ -731,9 +722,9 @@ def test_has_unknown_species(self): handle.write("""data_test\n{}\n""".format(formula_string)) handle.flush() cif = CifData(file=handle.name) - self.assertEqual(cif.has_unknown_species, result, formula_string) + assert cif.has_unknown_species == result, formula_string - @unittest.skipIf(not has_pycifrw(), 'Unable to import PyCifRW') + @pytest.mark.skipif(not has_pycifrw(), reason='Unable to import PyCifRW') def test_has_undefined_atomic_sites(self): """Test structure with undefined atomic sites.""" tests = [ @@ -749,20 +740,21 @@ def test_has_undefined_atomic_sites(self): handle.write("""data_test\n{}\n""".format(atomic_site_string)) handle.flush() cif = CifData(file=handle.name) - self.assertEqual(cif.has_undefined_atomic_sites, result) + assert cif.has_undefined_atomic_sites == result -class TestKindValidSymbols(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test_class') +class TestKindValidSymbols: """Tests the symbol validation of the aiida.orm.nodes.data.structure.Kind class.""" def test_bad_symbol(self): """Should not accept a non-existing symbol.""" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): Kind(symbols='Hxx') def test_empty_list_symbols(self): """Should not accept an empty list.""" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): Kind(symbols=[]) @staticmethod @@ -776,35 +768,36 @@ def test_unknown_symbol(): Kind(symbols=['X']) -class TestSiteValidWeights(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test_class') +class TestSiteValidWeights: """Tests valid weight lists.""" def test_isnot_list(self): """Should not accept a non-list, non-number weight.""" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): Kind(symbols='Ba', weights='aaa') def test_empty_list_weights(self): """Should not accept an empty list.""" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): Kind(symbols='Ba', weights=[]) def test_symbol_weight_mismatch(self): """Should not accept a size mismatch of the symbols and weights list.""" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): Kind(symbols=['Ba', 'C'], weights=[1.]) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): Kind(symbols=['Ba'], weights=[0.1, 0.2]) def test_negative_value(self): """Should not accept a negative weight.""" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): Kind(symbols=['Ba', 'C'], weights=[-0.1, 0.3]) def test_sum_greater_one(self): """Should not accept a sum of weights larger than one.""" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): Kind(symbols=['Ba', 'C'], weights=[0.5, 0.6]) @staticmethod @@ -823,24 +816,25 @@ def test_none(): Kind(symbols='Ba', weights=None) -class TestKindTestGeneral(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test_class') +class TestKindTestGeneral: """Tests the creation of Kind objects and their methods.""" def test_sum_one_general(self): """Should accept a sum equal to one.""" a = Kind(symbols=['Ba', 'C'], weights=[1. / 3., 2. / 3.]) - self.assertTrue(a.is_alloy) - self.assertFalse(a.has_vacancies) + assert a.is_alloy + assert not a.has_vacancies def test_sum_less_one_general(self): """Should accept a sum equal less than one.""" a = Kind(symbols=['Ba', 'C'], weights=[1. / 3., 1. / 3.]) - self.assertTrue(a.is_alloy) - self.assertTrue(a.has_vacancies) + assert a.is_alloy + assert a.has_vacancies def test_no_position(self): """Should not accept a 'positions' parameter.""" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): Kind(position=[0., 0., 0.], symbols=['Ba'], weights=[1.]) def test_simple(self): @@ -848,45 +842,46 @@ def test_simple(self): Should recognize a simple element. """ a = Kind(symbols='Ba') - self.assertFalse(a.is_alloy) - self.assertFalse(a.has_vacancies) + assert not a.is_alloy + assert not a.has_vacancies b = Kind(symbols='Ba', weights=1.) - self.assertFalse(b.is_alloy) - self.assertFalse(b.has_vacancies) + assert not b.is_alloy + assert not b.has_vacancies c = Kind(symbols='Ba', weights=None) - self.assertFalse(c.is_alloy) - self.assertFalse(c.has_vacancies) + assert not c.is_alloy + assert not c.has_vacancies def test_automatic_name(self): """ Check the automatic name generator. """ a = Kind(symbols='Ba') - self.assertEqual(a.name, 'Ba') + assert a.name == 'Ba' a = Kind(symbols='X') - self.assertEqual(a.name, 'X') + assert a.name == 'X' a = Kind(symbols=('Si', 'Ge'), weights=(1. / 3., 2. / 3.)) - self.assertEqual(a.name, 'GeSi') + assert a.name == 'GeSi' a = Kind(symbols=('Si', 'X'), weights=(1. / 3., 2. / 3.)) - self.assertEqual(a.name, 'SiX') + assert a.name == 'SiX' a = Kind(symbols=('Si', 'Ge'), weights=(0.4, 0.5)) - self.assertEqual(a.name, 'GeSiX') + assert a.name == 'GeSiX' a = Kind(symbols=('Si', 'X'), weights=(0.4, 0.5)) - self.assertEqual(a.name, 'SiXX') + assert a.name == 'SiXX' # Manually setting the name of the species a.name = 'newstring' - self.assertEqual(a.name, 'newstring') + assert a.name == 'newstring' -class TestKindTestMasses(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test_class') +class TestKindTestMasses: """ Tests the management of masses during the creation of Kind objects. """ @@ -898,7 +893,7 @@ def test_auto_mass_one(self): from aiida.orm.nodes.data.structure import _atomic_masses a = Kind(symbols=['Ba', 'C'], weights=[1. / 3., 2. / 3.]) - self.assertAlmostEqual(a.mass, (_atomic_masses['Ba'] + 2. * _atomic_masses['C']) / 3.) + assert round(abs(a.mass - (_atomic_masses['Ba'] + 2. * _atomic_masses['C']) / 3.), 7) == 0 def test_sum_less_one_masses(self): """ @@ -907,7 +902,7 @@ def test_sum_less_one_masses(self): from aiida.orm.nodes.data.structure import _atomic_masses a = Kind(symbols=['Ba', 'C'], weights=[1. / 3., 1. / 3.]) - self.assertAlmostEqual(a.mass, (_atomic_masses['Ba'] + _atomic_masses['C']) / 2.) + assert round(abs(a.mass - (_atomic_masses['Ba'] + _atomic_masses['C']) / 2.), 7) == 0 def test_sum_less_one_singleelem(self): """ @@ -916,17 +911,18 @@ def test_sum_less_one_singleelem(self): from aiida.orm.nodes.data.structure import _atomic_masses a = Kind(symbols=['Ba']) - self.assertAlmostEqual(a.mass, _atomic_masses['Ba']) + assert round(abs(a.mass - _atomic_masses['Ba']), 7) == 0 def test_manual_mass(self): """ mass set manually """ a = Kind(symbols=['Ba', 'C'], weights=[1. / 3., 1. / 3.], mass=1000.) - self.assertAlmostEqual(a.mass, 1000.) + assert round(abs(a.mass - 1000.), 7) == 0 -class TestStructureDataInit(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test_class') +class TestStructureDataInit: """ Tests the creation of StructureData objects (cell and pbc). """ @@ -935,28 +931,28 @@ def test_cell_wrong_size_1(self): """ Wrong cell size (not 3x3) """ - with self.assertRaises(ValueError): + with pytest.raises(ValueError): StructureData(cell=((1., 2., 3.),)) def test_cell_wrong_size_2(self): """ Wrong cell size (not 3x3) """ - with self.assertRaises(ValueError): + with pytest.raises(ValueError): StructureData(cell=((1., 0., 0.), (0., 0., 3.), (0., 3.))) def test_cell_zero_vector(self): """ Wrong cell (one vector has zero length) """ - with self.assertRaises(ValueError): + with pytest.raises(ValueError): StructureData(cell=((0., 0., 0.), (0., 1., 0.), (0., 0., 1.))) def test_cell_zero_volume(self): """ Wrong cell (volume is zero) """ - with self.assertRaises(ValueError): + with pytest.raises(ValueError): StructureData(cell=((1., 0., 0.), (0., 1., 0.), (1., 1., 0.))) def test_cell_ok_init(self): @@ -969,20 +965,20 @@ def test_cell_ok_init(self): for i in range(3): for j in range(3): - self.assertAlmostEqual(cell[i][j], out_cell[i][j]) + assert round(abs(cell[i][j] - out_cell[i][j]), 7) == 0 def test_volume(self): """ Check the volume calculation """ a = StructureData(cell=((1., 0., 0.), (0., 2., 0.), (0., 0., 3.))) - self.assertAlmostEqual(a.get_cell_volume(), 6.) + assert round(abs(a.get_cell_volume() - 6.), 7) == 0 def test_wrong_pbc_1(self): """ Wrong pbc parameter (not bool or iterable) """ - with self.assertRaises(ValueError): + with pytest.raises(ValueError): cell = ((1., 0., 0.), (0., 2., 0.), (0., 0., 3.)) StructureData(cell=cell, pbc=1) @@ -990,7 +986,7 @@ def test_wrong_pbc_2(self): """ Wrong pbc parameter (iterable but with wrong len) """ - with self.assertRaises(ValueError): + with pytest.raises(ValueError): cell = ((1., 0., 0.), (0., 2., 0.), (0., 0., 3.)) StructureData(cell=cell, pbc=[True, True]) @@ -998,7 +994,7 @@ def test_wrong_pbc_3(self): """ Wrong pbc parameter (iterable but with wrong len) """ - with self.assertRaises(ValueError): + with pytest.raises(ValueError): cell = ((1., 0., 0.), (0., 2., 0.), (0., 0., 3.)) StructureData(cell=cell, pbc=[]) @@ -1008,10 +1004,10 @@ def test_ok_pbc_1(self): """ cell = ((1., 0., 0.), (0., 2., 0.), (0., 0., 3.)) a = StructureData(cell=cell, pbc=True) - self.assertEqual(a.pbc, tuple([True, True, True])) + assert a.pbc == tuple([True, True, True]) a = StructureData(cell=cell, pbc=False) - self.assertEqual(a.pbc, tuple([False, False, False])) + assert a.pbc == tuple([False, False, False]) def test_ok_pbc_2(self): """ @@ -1019,10 +1015,10 @@ def test_ok_pbc_2(self): """ cell = ((1., 0., 0.), (0., 2., 0.), (0., 0., 3.)) a = StructureData(cell=cell, pbc=[True]) - self.assertEqual(a.pbc, tuple([True, True, True])) + assert a.pbc == tuple([True, True, True]) a = StructureData(cell=cell, pbc=[False]) - self.assertEqual(a.pbc, tuple([False, False, False])) + assert a.pbc == tuple([False, False, False]) def test_ok_pbc_3(self): """ @@ -1030,10 +1026,11 @@ def test_ok_pbc_3(self): """ cell = ((1., 0., 0.), (0., 2., 0.), (0., 0., 3.)) a = StructureData(cell=cell, pbc=[True, False, True]) - self.assertEqual(a.pbc, tuple([True, False, True])) + assert a.pbc == tuple([True, False, True]) -class TestStructureData(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test_class') +class TestStructureData: """ Tests the creation of StructureData objects (cell and pbc). """ @@ -1049,29 +1046,29 @@ def test_cell_ok_and_atoms(self): a = StructureData(cell=cell) out_cell = a.cell - self.assertAlmostEqual(cell, out_cell) + assert_allclose(cell, out_cell) a.append_atom(position=(0., 0., 0.), symbols=['Ba']) a.append_atom(position=(1., 1., 1.), symbols=['Ti']) a.append_atom(position=(1.2, 1.4, 1.6), symbols=['Ti']) - self.assertFalse(a.is_alloy) - self.assertFalse(a.has_vacancies) + assert not a.is_alloy + assert not a.has_vacancies # There should be only two kinds! (two atoms of kind Ti should # belong to the same kind) - self.assertEqual(len(a.kinds), 2) + assert len(a.kinds) == 2 a.append_atom(position=(0.5, 1., 1.5), symbols=['O', 'C'], weights=[0.5, 0.5]) - self.assertTrue(a.is_alloy) - self.assertFalse(a.has_vacancies) + assert a.is_alloy + assert not a.has_vacancies a.append_atom(position=(0.5, 1., 1.5), symbols=['O'], weights=[0.5]) - self.assertTrue(a.is_alloy) - self.assertTrue(a.has_vacancies) + assert a.is_alloy + assert a.has_vacancies a.clear_kinds() a.append_atom(position=(0.5, 1., 1.5), symbols=['O'], weights=[0.5]) - self.assertFalse(a.is_alloy) - self.assertTrue(a.has_vacancies) + assert not a.is_alloy + assert a.has_vacancies def test_cell_ok_and_unknown_atoms(self): """ @@ -1082,29 +1079,29 @@ def test_cell_ok_and_unknown_atoms(self): a = StructureData(cell=cell) out_cell = a.cell - self.assertAlmostEqual(cell, out_cell) + assert_allclose(cell, out_cell) a.append_atom(position=(0., 0., 0.), symbols=['Ba']) a.append_atom(position=(1., 1., 1.), symbols=['X']) a.append_atom(position=(1.2, 1.4, 1.6), symbols=['X']) - self.assertFalse(a.is_alloy) - self.assertFalse(a.has_vacancies) + assert not a.is_alloy + assert not a.has_vacancies # There should be only two kinds! (two atoms of kind X should # belong to the same kind) - self.assertEqual(len(a.kinds), 2) + assert len(a.kinds) == 2 a.append_atom(position=(0.5, 1., 1.5), symbols=['O', 'C'], weights=[0.5, 0.5]) - self.assertTrue(a.is_alloy) - self.assertFalse(a.has_vacancies) + assert a.is_alloy + assert not a.has_vacancies a.append_atom(position=(0.5, 1., 1.5), symbols=['O'], weights=[0.5]) - self.assertTrue(a.is_alloy) - self.assertTrue(a.has_vacancies) + assert a.is_alloy + assert a.has_vacancies a.clear_kinds() a.append_atom(position=(0.5, 1., 1.5), symbols=['X'], weights=[0.5]) - self.assertFalse(a.is_alloy) - self.assertTrue(a.has_vacancies) + assert not a.is_alloy + assert a.has_vacancies def test_kind_1(self): """ @@ -1117,9 +1114,9 @@ def test_kind_1(self): a.append_atom(position=(0.5, 0.5, 0.5), symbols=['Ba']) a.append_atom(position=(1., 1., 1.), symbols=['Ti']) - self.assertEqual(len(a.kinds), 2) # I should only have two types + assert len(a.kinds) == 2 # I should only have two types # I check for the default names of kinds - self.assertEqual(set(k.name for k in a.kinds), set(('Ba', 'Ti'))) + assert set(k.name for k in a.kinds) == set(('Ba', 'Ti')) def test_kind_1_unknown(self): """ @@ -1132,9 +1129,9 @@ def test_kind_1_unknown(self): a.append_atom(position=(0.5, 0.5, 0.5), symbols=['X']) a.append_atom(position=(1., 1., 1.), symbols=['Ti']) - self.assertEqual(len(a.kinds), 2) # I should only have two types + assert len(a.kinds) == 2 # I should only have two types # I check for the default names of kinds - self.assertEqual(set(k.name for k in a.kinds), set(('X', 'Ti'))) + assert set(k.name for k in a.kinds) == set(('X', 'Ti')) def test_kind_2(self): """ @@ -1147,8 +1144,8 @@ def test_kind_2(self): a.append_atom(position=(1., 1., 1.), symbols=['Ti']) kind_list = a.kinds - self.assertEqual(len(kind_list), 3) # I should have now three kinds - self.assertEqual(set(k.name for k in kind_list), set(('Ba1', 'Ba2', 'Ti'))) + assert len(kind_list) == 3 # I should have now three kinds + assert set(k.name for k in kind_list) == set(('Ba1', 'Ba2', 'Ti')) def test_kind_2_unknown(self): """ @@ -1162,8 +1159,8 @@ def test_kind_2_unknown(self): a.append_atom(position=(1., 1., 1.), symbols=['Ti']) kind_list = a.kinds - self.assertEqual(len(kind_list), 3) # I should have now three kinds - self.assertEqual(set(k.name for k in kind_list), set(('X1', 'X2', 'Ti'))) + assert len(kind_list) == 3 # I should have now three kinds + assert set(k.name for k in kind_list) == set(('X1', 'X2', 'Ti')) def test_kind_3(self): """ @@ -1172,7 +1169,7 @@ def test_kind_3(self): a = StructureData(cell=((2., 0., 0.), (0., 2., 0.), (0., 0., 2.))) a.append_atom(position=(0., 0., 0.), symbols=['Ba'], mass=100.) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): # Shouldn't allow, I am adding two sites with the same name 'Ba' a.append_atom(position=(0.5, 0.5, 0.5), symbols=['Ba'], mass=101., name='Ba') @@ -1181,9 +1178,9 @@ def test_kind_3(self): a.append_atom(position=(1., 1., 1.), symbols=['Ti']) - self.assertEqual(len(a.kinds), 3) # I should have now three types - self.assertEqual(len(a.sites), 3) # and 3 sites - self.assertEqual(set(k.name for k in a.kinds), set(('Ba', 'Ba2', 'Ti'))) + assert len(a.kinds) == 3 # I should have now three types + assert len(a.sites) == 3 # and 3 sites + assert set(k.name for k in a.kinds) == set(('Ba', 'Ba2', 'Ti')) def test_kind_3_unknown(self): """ @@ -1193,7 +1190,7 @@ def test_kind_3_unknown(self): a = StructureData(cell=((2., 0., 0.), (0., 2., 0.), (0., 0., 2.))) a.append_atom(position=(0., 0., 0.), symbols=['X'], mass=100.) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): # Shouldn't allow, I am adding two sites with the same name 'Ba' a.append_atom(position=(0.5, 0.5, 0.5), symbols=['X'], mass=101., name='X') @@ -1202,9 +1199,9 @@ def test_kind_3_unknown(self): a.append_atom(position=(1., 1., 1.), symbols=['Ti']) - self.assertEqual(len(a.kinds), 3) # I should have now three types - self.assertEqual(len(a.sites), 3) # and 3 sites - self.assertEqual(set(k.name for k in a.kinds), set(('X', 'X2', 'Ti'))) + assert len(a.kinds) == 3 # I should have now three types + assert len(a.sites) == 3 # and 3 sites + assert set(k.name for k in a.kinds) == set(('X', 'X2', 'Ti')) def test_kind_4(self): """ @@ -1215,26 +1212,26 @@ def test_kind_4(self): a.append_atom(position=(0., 0., 0.), symbols=['Ba', 'Ti'], weights=(1., 0.), name='mytype') - with self.assertRaises(ValueError): + with pytest.raises(ValueError): # Shouldn't allow, different weights a.append_atom(position=(0.5, 0.5, 0.5), symbols=['Ba', 'Ti'], weights=(0.9, 0.1), name='mytype') - with self.assertRaises(ValueError): + with pytest.raises(ValueError): # Shouldn't allow, different weights (with vacancy) a.append_atom(position=(0.5, 0.5, 0.5), symbols=['Ba', 'Ti'], weights=(0.8, 0.1), name='mytype') - with self.assertRaises(ValueError): + with pytest.raises(ValueError): # Shouldn't allow, different symbols list a.append_atom(position=(0.5, 0.5, 0.5), symbols=['Ba'], name='mytype') - with self.assertRaises(ValueError): + with pytest.raises(ValueError): # Shouldn't allow, different symbols list a.append_atom(position=(0.5, 0.5, 0.5), symbols=['Si', 'Ti'], weights=(1., 0.), name='mytype') # should allow because every property is identical a.append_atom(position=(0., 0., 0.), symbols=['Ba', 'Ti'], weights=(1., 0.), name='mytype') - self.assertEqual(len(a.kinds), 1) + assert len(a.kinds) == 1 def test_kind_4_unknown(self): """ @@ -1245,26 +1242,26 @@ def test_kind_4_unknown(self): a.append_atom(position=(0., 0., 0.), symbols=['X', 'Ti'], weights=(1., 0.), name='mytype') - with self.assertRaises(ValueError): + with pytest.raises(ValueError): # Shouldn't allow, different weights a.append_atom(position=(0.5, 0.5, 0.5), symbols=['X', 'Ti'], weights=(0.9, 0.1), name='mytype') - with self.assertRaises(ValueError): + with pytest.raises(ValueError): # Shouldn't allow, different weights (with vacancy) a.append_atom(position=(0.5, 0.5, 0.5), symbols=['X', 'Ti'], weights=(0.8, 0.1), name='mytype') - with self.assertRaises(ValueError): + with pytest.raises(ValueError): # Shouldn't allow, different symbols list a.append_atom(position=(0.5, 0.5, 0.5), symbols=['X'], name='mytype') - with self.assertRaises(ValueError): + with pytest.raises(ValueError): # Shouldn't allow, different symbols list a.append_atom(position=(0.5, 0.5, 0.5), symbols=['Si', 'Ti'], weights=(1., 0.), name='mytype') # should allow because every property is identical a.append_atom(position=(0., 0., 0.), symbols=['X', 'Ti'], weights=(1., 0.), name='mytype') - self.assertEqual(len(a.kinds), 1) + assert len(a.kinds) == 1 def test_kind_5(self): """ @@ -1280,15 +1277,15 @@ def test_kind_5(self): # The name already exists, but the properties are identical => OK a.append_atom(position=(1., 1., 1.), symbols='Ti', name='Ti2') # The name already exists, but the properties are different! - with self.assertRaises(ValueError): + with pytest.raises(ValueError): a.append_atom(position=(1., 1., 1.), symbols='Ti', mass=100., name='Ti2') # Should not complain, should create a new type a.append_atom(position=(0., 0., 0.), symbols='Ba', mass=150.) # There should be 4 kinds, the automatic name for the last one # should be Ba1 - self.assertEqual([k.name for k in a.kinds], ['Ba', 'Ti', 'Ti2', 'Ba1']) - self.assertEqual(len(a.sites), 5) + assert [k.name for k in a.kinds] == ['Ba', 'Ti', 'Ti2', 'Ba1'] + assert len(a.sites) == 5 def test_kind_5_unknown(self): """ @@ -1305,15 +1302,15 @@ def test_kind_5_unknown(self): # The name already exists, but the properties are identical => OK a.append_atom(position=(1., 1., 1.), symbols='Ti', name='Ti2') # The name already exists, but the properties are different! - with self.assertRaises(ValueError): + with pytest.raises(ValueError): a.append_atom(position=(1., 1., 1.), symbols='Ti', mass=100., name='Ti2') # Should not complain, should create a new type a.append_atom(position=(0., 0., 0.), symbols='X', mass=150.) # There should be 4 kinds, the automatic name for the last one # should be Ba1 - self.assertEqual([k.name for k in a.kinds], ['X', 'Ti', 'Ti2', 'X1']) - self.assertEqual(len(a.sites), 5) + assert [k.name for k in a.kinds] == ['X', 'Ti', 'Ti2', 'X1'] + assert len(a.sites) == 5 def test_kind_5_bis(self): """Test the management of kinds (automatic creation of new kind @@ -1332,10 +1329,10 @@ def test_kind_5_bis(self): # I expect only two species, the first one with name 'Fe', mass 12, # and referencing the first three atoms; the second with name # 'Fe1', mass = elements[26]['mass'], and referencing the last two atoms - self.assertEqual({(k.name, k.mass) for k in s.kinds}, {('Fe', 12.0), ('Fe1', elements[26]['mass'])}) + assert {(k.name, k.mass) for k in s.kinds} == {('Fe', 12.0), ('Fe1', elements[26]['mass'])} kind_of_each_site = [site.kind_name for site in s.sites] - self.assertEqual(kind_of_each_site, ['Fe', 'Fe', 'Fe', 'Fe1', 'Fe1']) + assert kind_of_each_site == ['Fe', 'Fe', 'Fe', 'Fe1', 'Fe1'] def test_kind_5_bis_unknown(self): """Test the management of kinds (automatic creation of new kind @@ -1355,12 +1352,12 @@ def test_kind_5_bis_unknown(self): # I expect only two species, the first one with name 'X', mass 12, # and referencing the first three atoms; the second with name # 'X', mass = elements[0]['mass'], and referencing the last two atoms - self.assertEqual({(k.name, k.mass) for k in s.kinds}, {('X', 12.0), ('X1', elements[0]['mass'])}) + assert {(k.name, k.mass) for k in s.kinds} == {('X', 12.0), ('X1', elements[0]['mass'])} kind_of_each_site = [site.kind_name for site in s.sites] - self.assertEqual(kind_of_each_site, ['X', 'X', 'X', 'X1', 'X1']) + assert kind_of_each_site == ['X', 'X', 'X', 'X1', 'X1'] - @unittest.skipIf(not has_ase(), 'Unable to import ase') + @pytest.mark.skipif(not has_ase(), reason='Unable to import ase') def test_kind_5_bis_ase(self): """ Same test as test_kind_5_bis, but using ase @@ -1385,12 +1382,12 @@ def test_kind_5_bis_ase(self): # I expect only two species, the first one with name 'Fe', mass 12, # and referencing the first three atoms; the second with name # 'Fe1', mass = elements[26]['mass'], and referencing the last two atoms - self.assertEqual({(k.name, k.mass) for k in s.kinds}, {('Fe', 12.0), ('Fe1', asecell[3].mass)}) + assert {(k.name, k.mass) for k in s.kinds} == {('Fe', 12.0), ('Fe1', asecell[3].mass)} kind_of_each_site = [site.kind_name for site in s.sites] - self.assertEqual(kind_of_each_site, ['Fe', 'Fe', 'Fe', 'Fe1', 'Fe1']) + assert kind_of_each_site == ['Fe', 'Fe', 'Fe', 'Fe1', 'Fe1'] - @unittest.skipIf(not has_ase(), 'Unable to import ase') + @pytest.mark.skipif(not has_ase(), reason='Unable to import ase') def test_kind_5_bis_ase_unknown(self): """ Same test as test_kind_5_bis_unknown, but using ase @@ -1415,10 +1412,10 @@ def test_kind_5_bis_ase_unknown(self): # I expect only two species, the first one with name 'X', mass 12, # and referencing the first three atoms; the second with name # 'X1', mass = elements[26]['mass'], and referencing the last two atoms - self.assertEqual({(k.name, k.mass) for k in s.kinds}, {('X', 12.0), ('X1', asecell[3].mass)}) + assert {(k.name, k.mass) for k in s.kinds} == {('X', 12.0), ('X1', asecell[3].mass)} kind_of_each_site = [site.kind_name for site in s.sites] - self.assertEqual(kind_of_each_site, ['X', 'X', 'X', 'X1', 'X1']) + assert kind_of_each_site == ['X', 'X', 'X', 'X1', 'X1'] def test_kind_6(self): """ @@ -1437,15 +1434,15 @@ def test_kind_6(self): a.append_atom(position=(0., 0., 0.), symbols='Ba', mass=150.) # There should be 4 kinds, the automatic name for the last one # should be Ba1 (same check of test_kind_5 - self.assertEqual([k.name for k in a.kinds], ['Ba', 'Ti', 'Ti2', 'Ba1']) + assert [k.name for k in a.kinds] == ['Ba', 'Ti', 'Ti2', 'Ba1'] ############################# # Here I start the real tests # No such kind - with self.assertRaises(ValueError): + with pytest.raises(ValueError): a.get_kind('Ti3') k = a.get_kind('Ba1') - self.assertEqual(k.symbols, ('Ba',)) - self.assertAlmostEqual(k.mass, 150.) + assert k.symbols == ('Ba',) + assert round(abs(k.mass - 150.), 7) == 0 def test_kind_6_unknown(self): """ @@ -1464,15 +1461,15 @@ def test_kind_6_unknown(self): a.append_atom(position=(0., 0., 0.), symbols='X', mass=150.) # There should be 4 kinds, the automatic name for the last one # should be Ba1 (same check of test_kind_5 - self.assertEqual([k.name for k in a.kinds], ['X', 'Ti', 'Ti2', 'X1']) + assert [k.name for k in a.kinds] == ['X', 'Ti', 'Ti2', 'X1'] ############################# # Here I start the real tests # No such kind - with self.assertRaises(ValueError): + with pytest.raises(ValueError): a.get_kind('Ti3') k = a.get_kind('X1') - self.assertEqual(k.symbols, ('X',)) - self.assertAlmostEqual(k.mass, 150.) + assert k.symbols == ('X',) + assert round(abs(k.mass - 150.), 7) == 0 def test_kind_7(self): """ @@ -1487,7 +1484,7 @@ def test_kind_7(self): # The name already exists, but the properties are identical => OK a.append_atom(position=(0., 0., 0.), symbols=['O', 'H'], weights=[0.9, 0.1], mass=15.) - self.assertEqual(a.get_symbols_set(), set(['Ba', 'Ti', 'O', 'H'])) + assert a.get_symbols_set() == set(['Ba', 'Ti', 'O', 'H']) def test_kind_7_unknown(self): """ @@ -1503,10 +1500,10 @@ def test_kind_7_unknown(self): # The name already exists, but the properties are identical => OK a.append_atom(position=(0., 0., 0.), symbols=['O', 'H'], weights=[0.9, 0.1], mass=15.) - self.assertEqual(a.get_symbols_set(), set(['Ba', 'X', 'O', 'H'])) + assert a.get_symbols_set() == set(['Ba', 'X', 'O', 'H']) - @unittest.skipIf(not has_ase(), 'Unable to import ase') - @unittest.skipIf(not has_spglib(), 'Unable to import spglib') + @pytest.mark.skipif(not has_ase(), reason='Unable to import ase') + @pytest.mark.skipif(not has_spglib(), reason='Unable to import spglib') def test_kind_8(self): """ Test the ase_refine_cell() function @@ -1523,9 +1520,9 @@ def test_kind_8(self): b, sym = ase_refine_cell(a) sym.pop('rotations') sym.pop('translations') - self.assertEqual(b.get_chemical_symbols(), ['C']) - self.assertEqual(b.cell.tolist(), [[10, 0, 0], [0, 10, 0], [0, 0, 5]]) - self.assertEqual(sym, {'hall': '-P 4 2', 'hm': 'P4/mmm', 'tables': 123}) + assert b.get_chemical_symbols() == ['C'] + assert b.cell.tolist() == [[10, 0, 0], [0, 10, 0], [0, 0, 5]] + assert sym == {'hall': '-P 4 2', 'hm': 'P4/mmm', 'tables': 123} a = ase.Atoms(cell=[10, 2 * math.sqrt(75), 10]) a.append(ase.Atom('C', [0, 0, 0])) @@ -1533,9 +1530,9 @@ def test_kind_8(self): b, sym = ase_refine_cell(a) sym.pop('rotations') sym.pop('translations') - self.assertEqual(b.get_chemical_symbols(), ['C']) - self.assertEqual(numpy.round(b.cell, 2).tolist(), [[10, 0, 0], [-5, 8.66, 0], [0, 0, 10]]) - self.assertEqual(sym, {'hall': '-P 6 2', 'hm': 'P6/mmm', 'tables': 191}) + assert b.get_chemical_symbols() == ['C'] + assert numpy.round(b.cell, 2).tolist() == [[10, 0, 0], [-5, 8.66, 0], [0, 0, 10]] + assert sym == {'hall': '-P 6 2', 'hm': 'P6/mmm', 'tables': 191} a = ase.Atoms(cell=[[10, 0, 0], [-10, 10, 0], [0, 0, 10]]) a.append(ase.Atom('C', [5, 5, 5])) @@ -1543,10 +1540,10 @@ def test_kind_8(self): b, sym = ase_refine_cell(a) sym.pop('rotations') sym.pop('translations') - self.assertEqual(b.get_chemical_symbols(), ['C', 'F']) - self.assertEqual(b.cell.tolist(), [[10, 0, 0], [0, 10, 0], [0, 0, 10]]) - self.assertEqual(b.get_scaled_positions().tolist(), [[0.5, 0.5, 0.5], [0, 0, 0]]) - self.assertEqual(sym, {'hall': '-P 4 2 3', 'hm': 'Pm-3m', 'tables': 221}) + assert b.get_chemical_symbols() == ['C', 'F'] + assert b.cell.tolist() == [[10, 0, 0], [0, 10, 0], [0, 0, 10]] + assert b.get_scaled_positions().tolist() == [[0.5, 0.5, 0.5], [0, 0, 0]] + assert sym == {'hall': '-P 4 2 3', 'hm': 'Pm-3m', 'tables': 221} a = ase.Atoms(cell=[[10, 0, 0], [-10, 10, 0], [0, 0, 10]]) a.append(ase.Atom('C', [0, 0, 0])) @@ -1554,18 +1551,18 @@ def test_kind_8(self): b, sym = ase_refine_cell(a) sym.pop('rotations') sym.pop('translations') - self.assertEqual(b.get_chemical_symbols(), ['C', 'F']) - self.assertEqual(b.cell.tolist(), [[10, 0, 0], [0, 10, 0], [0, 0, 10]]) - self.assertEqual(b.get_scaled_positions().tolist(), [[0, 0, 0], [0.5, 0.5, 0.5]]) - self.assertEqual(sym, {'hall': '-P 4 2 3', 'hm': 'Pm-3m', 'tables': 221}) + assert b.get_chemical_symbols() == ['C', 'F'] + assert b.cell.tolist() == [[10, 0, 0], [0, 10, 0], [0, 0, 10]] + assert b.get_scaled_positions().tolist() == [[0, 0, 0], [0.5, 0.5, 0.5]] + assert sym == {'hall': '-P 4 2 3', 'hm': 'Pm-3m', 'tables': 221} a = ase.Atoms(cell=[[12.132, 0, 0], [0, 6.0606, 0], [0, 0, 8.0956]]) a.append(ase.Atom('Ba', [1.5334848, 1.3999986, 2.00042276])) b, sym = ase_refine_cell(a) sym.pop('rotations') sym.pop('translations') - self.assertEqual(b.cell.tolist(), [[6.0606, 0, 0], [0, 8.0956, 0], [0, 0, 12.132]]) - self.assertEqual(b.get_scaled_positions().tolist(), [[0, 0, 0]]) + assert b.cell.tolist() == [[6.0606, 0, 0], [0, 8.0956, 0], [0, 0, 12.132]] + assert b.get_scaled_positions().tolist() == [[0, 0, 0]] a = ase.Atoms(cell=[10, 10, 10]) a.append(ase.Atom('C', [5, 5, 5])) @@ -1574,8 +1571,8 @@ def test_kind_8(self): b, sym = ase_refine_cell(a) sym.pop('rotations') sym.pop('translations') - self.assertEqual(b.get_chemical_symbols(), ['C', 'O']) - self.assertEqual(sym, {'hall': '-P 4 2', 'hm': 'P4/mmm', 'tables': 123}) + assert b.get_chemical_symbols() == ['C', 'O'] + assert sym == {'hall': '-P 4 2', 'hm': 'P4/mmm', 'tables': 123} # Generated from COD entry 1507756 # (http://www.crystallography.net/cod/1507756.cif@87343) @@ -1586,8 +1583,8 @@ def test_kind_8(self): b, sym = ase_refine_cell(a) sym.pop('rotations') sym.pop('translations') - self.assertEqual(b.get_chemical_symbols(), ['Ba', 'Ti', 'O', 'O']) - self.assertEqual(sym, {'hall': 'P 4 -2', 'hm': 'P4mm', 'tables': 99}) + assert b.get_chemical_symbols() == ['Ba', 'Ti', 'O', 'O'] + assert sym == {'hall': 'P 4 -2', 'hm': 'P4mm', 'tables': 99} def test_get_formula(self): """ @@ -1595,28 +1592,28 @@ def test_get_formula(self): """ from aiida.orm.nodes.data.structure import get_formula - self.assertEqual(get_formula(['Ba', 'Ti'] + ['O'] * 3), 'BaO3Ti') - self.assertEqual(get_formula(['Ba', 'Ti', 'C'] + ['O'] * 3, separator=' '), 'C Ba O3 Ti') - self.assertEqual(get_formula(['H'] * 6 + ['C'] * 6), 'C6H6') - self.assertEqual(get_formula(['H'] * 6 + ['C'] * 6, mode='hill_compact'), 'CH') - self.assertEqual(get_formula((['Ba', 'Ti'] + ['O'] * 3) * 2 + \ + assert get_formula(['Ba', 'Ti'] + ['O'] * 3) == 'BaO3Ti' + assert get_formula(['Ba', 'Ti', 'C'] + ['O'] * 3, separator=' ') == 'C Ba O3 Ti' + assert get_formula(['H'] * 6 + ['C'] * 6) == 'C6H6' + assert get_formula(['H'] * 6 + ['C'] * 6, mode='hill_compact') == 'CH' + assert get_formula((['Ba', 'Ti'] + ['O'] * 3) * 2 + \ ['Ba'] + ['Ti'] * 2 + ['O'] * 3, - mode='group'), - '(BaTiO3)2BaTi2O3') - self.assertEqual(get_formula((['Ba', 'Ti'] + ['O'] * 3) * 2 + \ + mode='group') == \ + '(BaTiO3)2BaTi2O3' + assert get_formula((['Ba', 'Ti'] + ['O'] * 3) * 2 + \ ['Ba'] + ['Ti'] * 2 + ['O'] * 3, - mode='group', separator=' '), - '(Ba Ti O3)2 Ba Ti2 O3') - self.assertEqual(get_formula((['Ba', 'Ti'] + ['O'] * 3) * 2 + \ + mode='group', separator=' ') == \ + '(Ba Ti O3)2 Ba Ti2 O3' + assert get_formula((['Ba', 'Ti'] + ['O'] * 3) * 2 + \ ['Ba'] + ['Ti'] * 2 + ['O'] * 3, - mode='reduce'), - 'BaTiO3BaTiO3BaTi2O3') - self.assertEqual(get_formula((['Ba', 'Ti'] + ['O'] * 3) * 2 + \ + mode='reduce') == \ + 'BaTiO3BaTiO3BaTi2O3' + assert get_formula((['Ba', 'Ti'] + ['O'] * 3) * 2 + \ ['Ba'] + ['Ti'] * 2 + ['O'] * 3, - mode='reduce', separator=', '), - 'Ba, Ti, O3, Ba, Ti, O3, Ba, Ti2, O3') - self.assertEqual(get_formula((['Ba', 'Ti'] + ['O'] * 3) * 2, mode='count'), 'Ba2Ti2O6') - self.assertEqual(get_formula((['Ba', 'Ti'] + ['O'] * 3) * 2, mode='count_compact'), 'BaTiO3') + mode='reduce', separator=', ') == \ + 'Ba, Ti, O3, Ba, Ti, O3, Ba, Ti2, O3' + assert get_formula((['Ba', 'Ti'] + ['O'] * 3) * 2, mode='count') == 'Ba2Ti2O6' + assert get_formula((['Ba', 'Ti'] + ['O'] * 3) * 2, mode='count_compact') == 'BaTiO3' def test_get_formula_unknown(self): """ @@ -1624,31 +1621,31 @@ def test_get_formula_unknown(self): """ from aiida.orm.nodes.data.structure import get_formula - self.assertEqual(get_formula(['Ba', 'Ti'] + ['X'] * 3), 'BaTiX3') - self.assertEqual(get_formula(['Ba', 'Ti', 'C'] + ['X'] * 3, separator=' '), 'C Ba Ti X3') - self.assertEqual(get_formula(['X'] * 6 + ['C'] * 6), 'C6X6') - self.assertEqual(get_formula(['X'] * 6 + ['C'] * 6, mode='hill_compact'), 'CX') - self.assertEqual(get_formula((['Ba', 'Ti'] + ['X'] * 3) * 2 + \ + assert get_formula(['Ba', 'Ti'] + ['X'] * 3) == 'BaTiX3' + assert get_formula(['Ba', 'Ti', 'C'] + ['X'] * 3, separator=' ') == 'C Ba Ti X3' + assert get_formula(['X'] * 6 + ['C'] * 6) == 'C6X6' + assert get_formula(['X'] * 6 + ['C'] * 6, mode='hill_compact') == 'CX' + assert get_formula((['Ba', 'Ti'] + ['X'] * 3) * 2 + \ ['Ba'] + ['X'] * 2 + ['O'] * 3, - mode='group'), - '(BaTiX3)2BaX2O3') - self.assertEqual(get_formula((['Ba', 'Ti'] + ['X'] * 3) * 2 + \ + mode='group') == \ + '(BaTiX3)2BaX2O3' + assert get_formula((['Ba', 'Ti'] + ['X'] * 3) * 2 + \ ['Ba'] + ['X'] * 2 + ['O'] * 3, - mode='group', separator=' '), - '(Ba Ti X3)2 Ba X2 O3') - self.assertEqual(get_formula((['Ba', 'Ti'] + ['X'] * 3) * 2 + \ + mode='group', separator=' ') == \ + '(Ba Ti X3)2 Ba X2 O3' + assert get_formula((['Ba', 'Ti'] + ['X'] * 3) * 2 + \ ['Ba'] + ['Ti'] * 2 + ['X'] * 3, - mode='reduce'), - 'BaTiX3BaTiX3BaTi2X3') - self.assertEqual(get_formula((['Ba', 'Ti'] + ['X'] * 3) * 2 + \ + mode='reduce') == \ + 'BaTiX3BaTiX3BaTi2X3' + assert get_formula((['Ba', 'Ti'] + ['X'] * 3) * 2 + \ ['Ba'] + ['Ti'] * 2 + ['X'] * 3, - mode='reduce', separator=', '), - 'Ba, Ti, X3, Ba, Ti, X3, Ba, Ti2, X3') - self.assertEqual(get_formula((['Ba', 'Ti'] + ['O'] * 3) * 2, mode='count'), 'Ba2Ti2O6') - self.assertEqual(get_formula((['Ba', 'Ti'] + ['X'] * 3) * 2, mode='count_compact'), 'BaTiX3') + mode='reduce', separator=', ') == \ + 'Ba, Ti, X3, Ba, Ti, X3, Ba, Ti2, X3' + assert get_formula((['Ba', 'Ti'] + ['O'] * 3) * 2, mode='count') == 'Ba2Ti2O6' + assert get_formula((['Ba', 'Ti'] + ['X'] * 3) * 2, mode='count_compact') == 'BaTiX3' - @unittest.skipIf(not has_ase(), 'Unable to import ase') - @unittest.skipIf(not has_pycifrw(), 'Unable to import PyCifRW') + @pytest.mark.skipif(not has_ase(), reason='Unable to import ase') + @pytest.mark.skipif(not has_pycifrw(), reason='Unable to import PyCifRW') @pytest.mark.requires_rmq def test_get_cif(self): """ @@ -1668,8 +1665,7 @@ def test_get_cif(self): for line in lines: if not re.search('^#', line): non_comments.append(line) - self.assertEqual( - simplify('\n'.join(non_comments)), + assert simplify('\n'.join(non_comments)) == \ simplify( """ data_0 @@ -1700,7 +1696,6 @@ def test_get_cif(self): _chemical_formula_sum 'Ba2 Ti' """ ) - ) def test_xyz_parser(self): """Test XYZ parser.""" @@ -1732,13 +1727,13 @@ def test_xyz_parser(self): s._parse_xyz(xyz_string) # pylint: disable=protected-access # Making sure that the periodic boundary condition are not True # because I cannot parse a cell! - self.assertTrue(not any(s.pbc)) + assert not any(s.pbc) # Making sure that the structure has sites, kinds and a cell - self.assertTrue(s.sites) - self.assertTrue(s.kinds) - self.assertTrue(s.cell) + assert s.sites + assert s.kinds + assert s.cell # The default cell is given in these cases: - self.assertEqual(s.cell, np.diag([1, 1, 1]).tolist()) + assert s.cell == np.diag([1, 1, 1]).tolist() # Testing a case where 1 xyz_string4 = """ @@ -1764,11 +1759,12 @@ def test_xyz_parser(self): # The above cases have to fail because the number of atoms is wrong for xyz_string in (xyz_string4, xyz_string5, xyz_string6): - with self.assertRaises(TypeError): + with pytest.raises(TypeError): StructureData()._parse_xyz(xyz_string) # pylint: disable=protected-access -class TestStructureDataLock(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test_class') +class TestStructureDataLock: """Tests that the structure is locked after storage.""" def test_lock(self): @@ -1789,17 +1785,17 @@ def test_lock(self): k2 = Kind(symbols='Ba', name='Ba') # Nothing should be changed after store() - with self.assertRaises(ModificationNotAllowed): + with pytest.raises(ModificationNotAllowed): a.append_kind(k2) - with self.assertRaises(ModificationNotAllowed): + with pytest.raises(ModificationNotAllowed): a.append_site(s) - with self.assertRaises(ModificationNotAllowed): + with pytest.raises(ModificationNotAllowed): a.clear_sites() - with self.assertRaises(ModificationNotAllowed): + with pytest.raises(ModificationNotAllowed): a.clear_kinds() - with self.assertRaises(ModificationNotAllowed): + with pytest.raises(ModificationNotAllowed): a.cell = cell - with self.assertRaises(ModificationNotAllowed): + with pytest.raises(ModificationNotAllowed): a.pbc = [True, True, True] _ = a.get_cell_volume() @@ -1811,12 +1807,13 @@ def test_lock(self): b.append_site(s) b.clear_sites() # I check that the original did not change - self.assertNotEqual(len(a.sites), 0) + assert len(a.sites) != 0 b.cell = cell b.pbc = [True, True, True] -class TestStructureDataReload(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test_class') +class TestStructureDataReload: """ Tests the creation of StructureData, converting it to a raw format and converting it back. @@ -1840,32 +1837,32 @@ def test_reload(self): for i in range(3): for j in range(3): - self.assertAlmostEqual(cell[i][j], b.cell[i][j]) + assert round(abs(cell[i][j] - b.cell[i][j]), 7) == 0 - self.assertEqual(b.pbc, (False, True, True)) - self.assertEqual(len(b.sites), 2) - self.assertEqual(b.kinds[0].symbols[0], 'Ba') - self.assertEqual(b.kinds[1].symbols[0], 'Ti') + assert b.pbc == (False, True, True) + assert len(b.sites) == 2 + assert b.kinds[0].symbols[0] == 'Ba' + assert b.kinds[1].symbols[0] == 'Ti' for i in range(3): - self.assertAlmostEqual(b.sites[0].position[i], 0.) + assert round(abs(b.sites[0].position[i] - 0.), 7) == 0 for i in range(3): - self.assertAlmostEqual(b.sites[1].position[i], 1.) + assert round(abs(b.sites[1].position[i] - 1.), 7) == 0 # Fully reload from UUID b = load_node(a.uuid, sub_classes=(StructureData,)) for i in range(3): for j in range(3): - self.assertAlmostEqual(cell[i][j], b.cell[i][j]) + assert round(abs(cell[i][j] - b.cell[i][j]), 7) == 0 - self.assertEqual(b.pbc, (False, True, True)) - self.assertEqual(len(b.sites), 2) - self.assertEqual(b.kinds[0].symbols[0], 'Ba') - self.assertEqual(b.kinds[1].symbols[0], 'Ti') + assert b.pbc == (False, True, True) + assert len(b.sites) == 2 + assert b.kinds[0].symbols[0] == 'Ba' + assert b.kinds[1].symbols[0] == 'Ti' for i in range(3): - self.assertAlmostEqual(b.sites[0].position[i], 0.) + assert round(abs(b.sites[0].position[i] - 0.), 7) == 0 for i in range(3): - self.assertAlmostEqual(b.sites[1].position[i], 1.) + assert round(abs(b.sites[1].position[i] - 1.), 7) == 0 def test_clone(self): """ @@ -1883,17 +1880,17 @@ def test_clone(self): for i in range(3): for j in range(3): - self.assertAlmostEqual(cell[i][j], b.cell[i][j]) + assert round(abs(cell[i][j] - b.cell[i][j]), 7) == 0 - self.assertEqual(b.pbc, (False, True, True)) - self.assertEqual(len(b.kinds), 2) - self.assertEqual(len(b.sites), 2) - self.assertEqual(b.kinds[0].symbols[0], 'Ba') - self.assertEqual(b.kinds[1].symbols[0], 'Ti') + assert b.pbc == (False, True, True) + assert len(b.kinds) == 2 + assert len(b.sites) == 2 + assert b.kinds[0].symbols[0] == 'Ba' + assert b.kinds[1].symbols[0] == 'Ti' for i in range(3): - self.assertAlmostEqual(b.sites[0].position[i], 0.) + assert round(abs(b.sites[0].position[i] - 0.), 7) == 0 for i in range(3): - self.assertAlmostEqual(b.sites[1].position[i], 1.) + assert round(abs(b.sites[1].position[i] - 1.), 7) == 0 a.store() @@ -1901,24 +1898,25 @@ def test_clone(self): c = a.clone() for i in range(3): for j in range(3): - self.assertAlmostEqual(cell[i][j], c.cell[i][j]) + assert round(abs(cell[i][j] - c.cell[i][j]), 7) == 0 - self.assertEqual(c.pbc, (False, True, True)) - self.assertEqual(len(c.kinds), 2) - self.assertEqual(len(c.sites), 2) - self.assertEqual(c.kinds[0].symbols[0], 'Ba') - self.assertEqual(c.kinds[1].symbols[0], 'Ti') + assert c.pbc == (False, True, True) + assert len(c.kinds) == 2 + assert len(c.sites) == 2 + assert c.kinds[0].symbols[0] == 'Ba' + assert c.kinds[1].symbols[0] == 'Ti' for i in range(3): - self.assertAlmostEqual(c.sites[0].position[i], 0.) + assert round(abs(c.sites[0].position[i] - 0.), 7) == 0 for i in range(3): - self.assertAlmostEqual(c.sites[1].position[i], 1.) + assert round(abs(c.sites[1].position[i] - 1.), 7) == 0 -class TestStructureDataFromAse(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test_class') +class TestStructureDataFromAse: """Tests the creation of Sites from/to a ASE object.""" from aiida.orm.nodes.data.structure import has_ase - @unittest.skipIf(not has_ase(), 'Unable to import ase') + @pytest.mark.skipif(not has_ase(), reason='Unable to import ase') def test_ase(self): """Tests roundtrip ASE -> StructureData -> ASE.""" import ase @@ -1933,17 +1931,17 @@ def test_ase(self): b = StructureData(ase=a) c = b.get_ase() - self.assertEqual(a[0].symbol, c[0].symbol) - self.assertEqual(a[1].symbol, c[1].symbol) + assert a[0].symbol == c[0].symbol + assert a[1].symbol == c[1].symbol for i in range(3): - self.assertAlmostEqual(a[0].position[i], c[0].position[i]) + assert round(abs(a[0].position[i] - c[0].position[i]), 7) == 0 for i in range(3): for j in range(3): - self.assertAlmostEqual(a.cell[i][j], c.cell[i][j]) + assert round(abs(a.cell[i][j] - c.cell[i][j]), 7) == 0 - self.assertAlmostEqual(c[1].mass, 110.2) + assert round(abs(c[1].mass - 110.2), 7) == 0 - @unittest.skipIf(not has_ase(), 'Unable to import ase') + @pytest.mark.skipif(not has_ase(), reason='Unable to import ase') def test_conversion_of_types_1(self): """ Tests roundtrip ASE -> StructureData -> ASE, with tags @@ -1965,14 +1963,14 @@ def test_conversion_of_types_1(self): a.set_tags((0, 1, 2, 3, 4, 5, 6, 7)) b = StructureData(ase=a) - self.assertEqual([k.name for k in b.kinds], ['Si', 'Si1', 'Si2', 'Si3', 'Ge4', 'Ge5', 'Ge6', 'Ge7']) + assert [k.name for k in b.kinds] == ['Si', 'Si1', 'Si2', 'Si3', 'Ge4', 'Ge5', 'Ge6', 'Ge7'] c = b.get_ase() a_tags = list(a.get_tags()) c_tags = list(c.get_tags()) - self.assertEqual(a_tags, c_tags) + assert a_tags == c_tags - @unittest.skipIf(not has_ase(), 'Unable to import ase') + @pytest.mark.skipif(not has_ase(), reason='Unable to import ase') def test_conversion_of_types_2(self): """ Tests roundtrip ASE -> StructureData -> ASE, with tags, and @@ -1996,7 +1994,7 @@ def test_conversion_of_types_2(self): # This will give funny names to the kinds, because I am using # both tags and different properties (mass). I just check to have # 4 kinds - self.assertEqual(len(b.kinds), 4) + assert len(b.kinds) == 4 # Do I get the same tags after one full iteration back and forth? c = b.get_ase() @@ -2004,9 +2002,9 @@ def test_conversion_of_types_2(self): e = d.get_ase() c_tags = list(c.get_tags()) e_tags = list(e.get_tags()) - self.assertEqual(c_tags, e_tags) + assert c_tags == e_tags - @unittest.skipIf(not has_ase(), 'Unable to import ase') + @pytest.mark.skipif(not has_ase(), reason='Unable to import ase') def test_conversion_of_types_3(self): """ Tests StructureData -> ASE, with all sorts of kind names @@ -2028,13 +2026,13 @@ def test_conversion_of_types_3(self): # Just to be sure that the species were saved with the correct name # in the first place - self.assertEqual([k.name for k in a.kinds], ['Ba', 'Ba1', 'Cu', 'Cu2', 'Cu_my', 'a_name', 'Fe', 'cu1']) + assert [k.name for k in a.kinds] == ['Ba', 'Ba1', 'Cu', 'Cu2', 'Cu_my', 'a_name', 'Fe', 'cu1'] b = a.get_ase() - self.assertEqual(b.get_chemical_symbols(), ['Ba', 'Ba', 'Cu', 'Cu', 'Cu', 'Cu', 'Cu', 'Cu']) - self.assertEqual(list(b.get_tags()), [0, 1, 0, 2, 3, 4, 5, 6]) + assert b.get_chemical_symbols() == ['Ba', 'Ba', 'Cu', 'Cu', 'Cu', 'Cu', 'Cu', 'Cu'] + assert list(b.get_tags()) == [0, 1, 0, 2, 3, 4, 5, 6] - @unittest.skipIf(not has_ase(), 'Unable to import ase') + @pytest.mark.skipif(not has_ase(), reason='Unable to import ase') def test_conversion_of_types_4(self): """ Tests ASE -> StructureData -> ASE, in particular conversion tags / kind names @@ -2048,14 +2046,14 @@ def test_conversion_of_types_4(self): atoms.set_cell([1, 1, 1]) s = StructureData(ase=atoms) kindnames = {k.name for k in s.kinds} - self.assertEqual(kindnames, set(['Fe', 'Fe1', 'Fe4'])) + assert kindnames == set(['Fe', 'Fe1', 'Fe4']) # check roundtrip ASE -> StructureData -> ASE atoms2 = s.get_ase() - self.assertEqual(list(atoms2.get_tags()), list(atoms.get_tags())) - self.assertEqual(list(atoms2.get_chemical_symbols()), list(atoms.get_chemical_symbols())) - self.assertEqual(atoms2.get_chemical_formula(), 'Fe5') + assert list(atoms2.get_tags()) == list(atoms.get_tags()) + assert list(atoms2.get_chemical_symbols()) == list(atoms.get_chemical_symbols()) + assert atoms2.get_chemical_formula() == 'Fe5' - @unittest.skipIf(not has_ase(), 'Unable to import ase') + @pytest.mark.skipif(not has_ase(), reason='Unable to import ase') def test_conversion_of_types_5(self): """ Tests ASE -> StructureData -> ASE, in particular conversion tags / kind names @@ -2070,14 +2068,14 @@ def test_conversion_of_types_5(self): atoms.set_cell([1, 1, 1]) s = StructureData(ase=atoms) kindnames = {k.name for k in s.kinds} - self.assertEqual(kindnames, set(['Fe', 'Fe1', 'Fe4'])) + assert kindnames == set(['Fe', 'Fe1', 'Fe4']) # check roundtrip ASE -> StructureData -> ASE atoms2 = s.get_ase() - self.assertEqual(list(atoms2.get_tags()), list(atoms.get_tags())) - self.assertEqual(list(atoms2.get_chemical_symbols()), list(atoms.get_chemical_symbols())) - self.assertEqual(atoms2.get_chemical_formula(), 'Fe5') + assert list(atoms2.get_tags()) == list(atoms.get_tags()) + assert list(atoms2.get_chemical_symbols()) == list(atoms.get_chemical_symbols()) + assert atoms2.get_chemical_formula() == 'Fe5' - @unittest.skipIf(not has_ase(), 'Unable to import ase') + @pytest.mark.skipif(not has_ase(), reason='Unable to import ase') def test_conversion_of_types_6(self): """ Tests roundtrip StructureData -> ASE -> StructureData, with tags/kind names @@ -2089,23 +2087,24 @@ def test_conversion_of_types_6(self): a.append_atom(position=(1, 3, 1), symbols='Cl', name='Cl') b = a.get_ase() - self.assertEqual(b.get_chemical_symbols(), ['Ni', 'Ni', 'Cl', 'Cl']) - self.assertEqual(list(b.get_tags()), [1, 2, 0, 0]) + assert b.get_chemical_symbols() == ['Ni', 'Ni', 'Cl', 'Cl'] + assert list(b.get_tags()) == [1, 2, 0, 0] c = StructureData(ase=b) - self.assertEqual(c.get_site_kindnames(), ['Ni1', 'Ni2', 'Cl', 'Cl']) - self.assertEqual([k.symbol for k in c.kinds], ['Ni', 'Ni', 'Cl']) - self.assertEqual([s.position for s in c.sites], [(0., 0., 0.), (2., 2., 2.), (1., 0., 1.), (1., 3., 1.)]) + assert c.get_site_kindnames() == ['Ni1', 'Ni2', 'Cl', 'Cl'] + assert [k.symbol for k in c.kinds] == ['Ni', 'Ni', 'Cl'] + assert [s.position for s in c.sites] == [(0., 0., 0.), (2., 2., 2.), (1., 0., 1.), (1., 3., 1.)] -class TestStructureDataFromPymatgen(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test_class') +class TestStructureDataFromPymatgen: """ Tests the creation of StructureData from a pymatgen Structure and Molecule objects. """ from aiida.orm.nodes.data.structure import has_pymatgen, get_pymatgen_version - @unittest.skipIf(not has_pymatgen(), 'Unable to import pymatgen') + @pytest.mark.skipif(not has_pymatgen(), reason='Unable to import pymatgen') def test_1(self): """ Tests roundtrip pymatgen -> StructureData -> pymatgen @@ -2149,17 +2148,17 @@ def test_1(self): structs_to_test = [StructureData(pymatgen=pymatgen_struct), StructureData(pymatgen_structure=pymatgen_struct)] for struct in structs_to_test: - self.assertEqual(struct.get_site_kindnames(), ['Bi', 'Bi', 'SeTe', 'SeTe', 'SeTe']) + assert struct.get_site_kindnames() == ['Bi', 'Bi', 'SeTe', 'SeTe', 'SeTe'] # Pymatgen's Composition does not guarantee any particular ordering of the kinds, # see the definition of its internal datatype at # pymatgen/core/composition.py#L135 (d4fe64c18a52949a4e22bfcf7b45de5b87242c51) - self.assertEqual([sorted(x.symbols) for x in struct.kinds], [[ + assert [sorted(x.symbols) for x in struct.kinds] == [[ 'Bi', - ], ['Se', 'Te']]) - self.assertEqual([sorted(x.weights) for x in struct.kinds], [[ + ], ['Se', 'Te']] + assert [sorted(x.weights) for x in struct.kinds] == [[ 1.0, - ], [0.33333, 0.66667]]) + ], [0.33333, 0.66667]] struct = StructureData(pymatgen_structure=pymatgen_struct) @@ -2190,7 +2189,7 @@ def recursively_compare_values(left, right): recursively_compare_values(dict1, dict2) - @unittest.skipIf(not has_pymatgen(), 'Unable to import pymatgen') + @pytest.mark.skipif(not has_pymatgen(), reason='Unable to import pymatgen') def test_2(self): """ Tests xyz -> pymatgen -> StructureData @@ -2213,15 +2212,15 @@ def test_2(self): pymatgen_mol = pymatgen_xyz.molecule for struct in [StructureData(pymatgen=pymatgen_mol), StructureData(pymatgen_molecule=pymatgen_mol)]: - self.assertEqual(struct.get_site_kindnames(), ['H', 'H', 'H', 'H', 'C']) - self.assertEqual(struct.pbc, (False, False, False)) - self.assertEqual([round(x, 2) for x in list(struct.sites[0].position)], [5.77, 5.89, 6.81]) - self.assertEqual([round(x, 2) for x in list(struct.sites[1].position)], [6.8, 5.89, 5.36]) - self.assertEqual([round(x, 2) for x in list(struct.sites[2].position)], [5.26, 5.0, 5.36]) - self.assertEqual([round(x, 2) for x in list(struct.sites[3].position)], [5.26, 6.78, 5.36]) - self.assertEqual([round(x, 2) for x in list(struct.sites[4].position)], [5.77, 5.89, 5.73]) - - @unittest.skipIf(not has_pymatgen(), 'Unable to import pymatgen') + assert struct.get_site_kindnames() == ['H', 'H', 'H', 'H', 'C'] + assert struct.pbc == (False, False, False) + assert [round(x, 2) for x in list(struct.sites[0].position)] == [5.77, 5.89, 6.81] + assert [round(x, 2) for x in list(struct.sites[1].position)] == [6.8, 5.89, 5.36] + assert [round(x, 2) for x in list(struct.sites[2].position)] == [5.26, 5.0, 5.36] + assert [round(x, 2) for x in list(struct.sites[3].position)] == [5.26, 6.78, 5.36] + assert [round(x, 2) for x in list(struct.sites[4].position)] == [5.77, 5.89, 5.73] + + @pytest.mark.skipif(not has_pymatgen(), reason='Unable to import pymatgen') def test_partial_occ_and_spin(self): """ Tests pymatgen -> StructureData, with partial occupancies and spins. @@ -2239,7 +2238,7 @@ def test_partial_occ_and_spin(self): lattice=[[4, 0, 0], [0, 4, 0], [0, 0, 4]], species=[FeMn1, FeMn2], coords=[[0, 0, 0], [0.5, 0.5, 0.5]] ) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): StructureData(pymatgen=a) # same, with vacancies @@ -2249,10 +2248,10 @@ def test_partial_occ_and_spin(self): lattice=[[4, 0, 0], [0, 4, 0], [0, 0, 4]], species=[Fe1, Fe2], coords=[[0, 0, 0], [0.5, 0.5, 0.5]] ) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): StructureData(pymatgen=a) - @unittest.skipIf(not has_pymatgen(), 'Unable to import pymatgen') + @pytest.mark.skipif(not has_pymatgen(), reason='Unable to import pymatgen') @staticmethod def test_multiple_kinds_partial_occupancies(): """Tests that a structure with multiple sites with the same element but different @@ -2268,7 +2267,7 @@ def test_multiple_kinds_partial_occupancies(): StructureData(pymatgen=a) - @unittest.skipIf(not has_pymatgen(), 'Unable to import pymatgen') + @pytest.mark.skipif(not has_pymatgen(), reason='Unable to import pymatgen') @staticmethod def test_multiple_kinds_alloy(): """ @@ -2289,12 +2288,13 @@ def test_multiple_kinds_alloy(): StructureData(pymatgen=a) -class TestPymatgenFromStructureData(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test_class') +class TestPymatgenFromStructureData: """Tests the creation of pymatgen Structure and Molecule objects from StructureData.""" from aiida.orm.nodes.data.structure import has_ase, has_pymatgen, get_pymatgen_version - @unittest.skipIf(not has_pymatgen(), 'Unable to import pymatgen') + @pytest.mark.skipif(not has_pymatgen(), reason='Unable to import pymatgen') def test_1(self): """Tests the check of periodic boundary conditions.""" struct = StructureData() @@ -2303,11 +2303,11 @@ def test_1(self): struct.get_pymatgen_structure() struct.pbc = [True, True, False] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): struct.get_pymatgen_structure() - @unittest.skipIf(not has_ase(), 'Unable to import ase') - @unittest.skipIf(not has_pymatgen(), 'Unable to import pymatgen') + @pytest.mark.skipif(not has_ase(), reason='Unable to import ase') + @pytest.mark.skipif(not has_pymatgen(), reason='Unable to import pymatgen') def test_2(self): """Tests ASE -> StructureData -> pymatgen.""" import ase @@ -2328,10 +2328,10 @@ def test_2(self): for i, _ in enumerate(coord_array): coord_array[i] = [round(x, 2) for x in coord_array[i]] - self.assertEqual(coord_array, [[0.0, 0.0, 0.0], [0.1, 0.1, 0.1], [0.2, 0.2, 0.2], [0.3, 0.3, 0.3]]) + assert coord_array == [[0.0, 0.0, 0.0], [0.1, 0.1, 0.1], [0.2, 0.2, 0.2], [0.3, 0.3, 0.3]] - @unittest.skipIf(not has_ase(), 'Unable to import ase') - @unittest.skipIf(not has_pymatgen(), 'Unable to import pymatgen') + @pytest.mark.skipif(not has_ase(), reason='Unable to import ase') + @pytest.mark.skipif(not has_pymatgen(), reason='Unable to import pymatgen') def test_3(self): """ Tests the conversion of StructureData to pymatgen's Molecule @@ -2351,10 +2351,10 @@ def test_3(self): p_mol = a_struct.get_pymatgen_molecule() p_mol_dict = p_mol.as_dict() - self.assertEqual([x['xyz'] for x in p_mol_dict['sites']], - [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]]) + assert [x['xyz'] for x in p_mol_dict['sites']] == \ + [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]] - @unittest.skipIf(not has_pymatgen(), 'Unable to import pymatgen') + @pytest.mark.skipif(not has_pymatgen(), reason='Unable to import pymatgen') def test_roundtrip(self): """ Tests roundtrip StructureData -> pymatgen -> StructureData @@ -2372,12 +2372,12 @@ def test_roundtrip(self): b = a.get_pymatgen() c = StructureData(pymatgen=b) - self.assertEqual(c.get_site_kindnames(), ['Cl', 'Cl', 'Cl', 'Cl', 'Na', 'Na', 'Na', 'Na']) - self.assertEqual([k.symbol for k in c.kinds], ['Cl', 'Na']) - self.assertEqual([s.position for s in c.sites], [(0., 0., 0.), (2.8, 0, 2.8), (0, 2.8, 2.8), (2.8, 2.8, 0), - (2.8, 2.8, 2.8), (2.8, 0, 0), (0, 2.8, 0), (0, 0, 2.8)]) + assert c.get_site_kindnames() == ['Cl', 'Cl', 'Cl', 'Cl', 'Na', 'Na', 'Na', 'Na'] + assert [k.symbol for k in c.kinds] == ['Cl', 'Na'] + assert [s.position for s in c.sites] == [(0., 0., 0.), (2.8, 0, 2.8), (0, 2.8, 2.8), (2.8, 2.8, 0), + (2.8, 2.8, 2.8), (2.8, 0, 0), (0, 2.8, 0), (0, 0, 2.8)] - @unittest.skipIf(not has_pymatgen(), 'Unable to import pymatgen') + @pytest.mark.skipif(not has_pymatgen(), reason='Unable to import pymatgen') def test_roundtrip_kindnames(self): """ Tests roundtrip StructureData -> pymatgen -> StructureData @@ -2394,16 +2394,16 @@ def test_roundtrip_kindnames(self): a.append_atom(position=(0, 0, 2.8), symbols='Na', name='Na4') b = a.get_pymatgen() - self.assertEqual([site.properties['kind_name'] for site in b.sites], - ['Cl', 'Cl10', 'Cla', 'cl_x', 'Na1', 'Na2', 'Na_Na', 'Na4']) + assert [site.properties['kind_name'] for site in b.sites] == \ + ['Cl', 'Cl10', 'Cla', 'cl_x', 'Na1', 'Na2', 'Na_Na', 'Na4'] c = StructureData(pymatgen=b) - self.assertEqual(c.get_site_kindnames(), ['Cl', 'Cl10', 'Cla', 'cl_x', 'Na1', 'Na2', 'Na_Na', 'Na4']) - self.assertEqual(c.get_symbols_set(), set(['Cl', 'Na'])) - self.assertEqual([s.position for s in c.sites], [(0., 0., 0.), (2.8, 0, 2.8), (0, 2.8, 2.8), (2.8, 2.8, 0), - (2.8, 2.8, 2.8), (2.8, 0, 0), (0, 2.8, 0), (0, 0, 2.8)]) + assert c.get_site_kindnames() == ['Cl', 'Cl10', 'Cla', 'cl_x', 'Na1', 'Na2', 'Na_Na', 'Na4'] + assert c.get_symbols_set() == set(['Cl', 'Na']) + assert [s.position for s in c.sites] == [(0., 0., 0.), (2.8, 0, 2.8), (0, 2.8, 2.8), (2.8, 2.8, 0), + (2.8, 2.8, 2.8), (2.8, 0, 0), (0, 2.8, 0), (0, 0, 2.8)] - @unittest.skipIf(not has_pymatgen(), 'Unable to import pymatgen') + @pytest.mark.skipif(not has_pymatgen(), reason='Unable to import pymatgen') def test_roundtrip_spins(self): """ Tests roundtrip StructureData -> pymatgen -> StructureData @@ -2421,15 +2421,15 @@ def test_roundtrip_spins(self): b = a.get_pymatgen(add_spin=True) # check the spins - self.assertEqual([s.as_dict()['properties']['spin'] for s in b.species], [-1, -1, -1, -1, 1, 1, 1, 1]) + assert [s.as_dict()['properties']['spin'] for s in b.species] == [-1, -1, -1, -1, 1, 1, 1, 1] # back to StructureData c = StructureData(pymatgen=b) - self.assertEqual(c.get_site_kindnames(), ['Mn1', 'Mn1', 'Mn1', 'Mn1', 'Mn2', 'Mn2', 'Mn2', 'Mn2']) - self.assertEqual([k.symbol for k in c.kinds], ['Mn', 'Mn']) - self.assertEqual([s.position for s in c.sites], [(0., 0., 0.), (2.8, 0, 2.8), (0, 2.8, 2.8), (2.8, 2.8, 0), - (2.8, 2.8, 2.8), (2.8, 0, 0), (0, 2.8, 0), (0, 0, 2.8)]) + assert c.get_site_kindnames() == ['Mn1', 'Mn1', 'Mn1', 'Mn1', 'Mn2', 'Mn2', 'Mn2', 'Mn2'] + assert [k.symbol for k in c.kinds] == ['Mn', 'Mn'] + assert [s.position for s in c.sites] == [(0., 0., 0.), (2.8, 0, 2.8), (0, 2.8, 2.8), (2.8, 2.8, 0), + (2.8, 2.8, 2.8), (2.8, 0, 0), (0, 2.8, 0), (0, 0, 2.8)] - @unittest.skipIf(not has_pymatgen(), 'Unable to import pymatgen') + @pytest.mark.skipif(not has_pymatgen(), reason='Unable to import pymatgen') def test_roundtrip_partial_occ(self): """ Tests roundtrip StructureData -> pymatgen -> StructureData @@ -2452,13 +2452,13 @@ def test_roundtrip_partial_occ(self): a.append_atom(position=(2., 1., 9.5), symbols='N') # a few checks on the structure kinds and symbols - self.assertEqual(a.get_symbols_set(), set(['Mn', 'Si', 'N'])) - self.assertEqual(a.get_site_kindnames(), ['Mn', 'Mn', 'Mn', 'Mn', 'MnX', 'MnX', 'Si', 'Si', 'N', 'N', 'N', 'N']) - self.assertEqual(a.get_formula(), 'Mn4N4Si2{Mn0.80X0.20}2') + assert a.get_symbols_set() == set(['Mn', 'Si', 'N']) + assert a.get_site_kindnames() == ['Mn', 'Mn', 'Mn', 'Mn', 'MnX', 'MnX', 'Si', 'Si', 'N', 'N', 'N', 'N'] + assert a.get_formula() == 'Mn4N4Si2{Mn0.80X0.20}2' b = a.get_pymatgen() # check the partial occupancies - self.assertEqual([s.as_dict() for s in b.species_and_occu], [{ + assert [s.as_dict() for s in b.species_and_occu] == [{ 'Mn': 1.0 }, { 'Mn': 1.0 @@ -2482,20 +2482,20 @@ def test_roundtrip_partial_occ(self): 'N': 1.0 }, { 'N': 1.0 - }]) + }] # back to StructureData c = StructureData(pymatgen=b) - self.assertEqual(c.cell, [[4., 0.0, 0.0], [-2., 3.5, 0.0], [0.0, 0.0, 16.]]) - self.assertEqual(c.get_symbols_set(), set(['Mn', 'Si', 'N'])) - self.assertEqual(c.get_site_kindnames(), ['Mn', 'Mn', 'Mn', 'Mn', 'MnX', 'MnX', 'Si', 'Si', 'N', 'N', 'N', 'N']) - self.assertEqual(c.get_formula(), 'Mn4N4Si2{Mn0.80X0.20}2') + assert c.cell == [[4., 0.0, 0.0], [-2., 3.5, 0.0], [0.0, 0.0, 16.]] + assert c.get_symbols_set() == set(['Mn', 'Si', 'N']) + assert c.get_site_kindnames() == ['Mn', 'Mn', 'Mn', 'Mn', 'MnX', 'MnX', 'Si', 'Si', 'N', 'N', 'N', 'N'] + assert c.get_formula() == 'Mn4N4Si2{Mn0.80X0.20}2' testing.assert_allclose([s.position for s in c.sites], [(0.0, 0.0, 13.5), (0.0, 0.0, 2.6), (0.0, 0.0, 5.5), (0.0, 0.0, 11.), (2., 1., 12.), (0.0, 2.2, 4.), (0.0, 2.2, 12.), (2., 1., 4.), (2., 1., 15.), (0.0, 2.2, 1.5), (0.0, 2.2, 7.), (2., 1., 9.5)]) - @unittest.skipIf(not has_pymatgen(), 'Unable to import pymatgen') + @pytest.mark.skipif(not has_pymatgen(), reason='Unable to import pymatgen') def test_partial_occ_and_spin(self): """Tests StructureData -> pymatgen, with partial occupancies and spins. This should raise a ValueError.""" @@ -2504,11 +2504,11 @@ def test_partial_occ_and_spin(self): a.append_atom(position=(2, 2, 2), symbols=('Fe', 'Al'), weights=(0.8, 0.2), name='FeAl2') # a few checks on the structure kinds and symbols - self.assertEqual(a.get_symbols_set(), set(['Fe', 'Al'])) - self.assertEqual(a.get_site_kindnames(), ['FeAl1', 'FeAl2']) - self.assertEqual(a.get_formula(), '{Al0.20Fe0.80}2') + assert a.get_symbols_set() == set(['Fe', 'Al']) + assert a.get_site_kindnames() == ['FeAl1', 'FeAl2'] + assert a.get_formula() == '{Al0.20Fe0.80}2' - with self.assertRaises(ValueError): + with pytest.raises(ValueError): a.get_pymatgen(add_spin=True) # same, with vacancies @@ -2517,15 +2517,16 @@ def test_partial_occ_and_spin(self): a.append_atom(position=(2, 2, 2), symbols='Fe', weights=0.8, name='FeX2') # a few checks on the structure kinds and symbols - self.assertEqual(a.get_symbols_set(), set(['Fe'])) - self.assertEqual(a.get_site_kindnames(), ['FeX1', 'FeX2']) - self.assertEqual(a.get_formula(), '{Fe0.80X0.20}2') + assert a.get_symbols_set() == set(['Fe']) + assert a.get_site_kindnames() == ['FeX1', 'FeX2'] + assert a.get_formula() == '{Fe0.80X0.20}2' - with self.assertRaises(ValueError): + with pytest.raises(ValueError): a.get_pymatgen(add_spin=True) -class TestArrayData(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test_class') +class TestArrayData: """Tests the ArrayData objects.""" def test_creation(self): @@ -2548,20 +2549,20 @@ def test_creation(self): n.set_array('third', third) # Check if the arrays are there - self.assertEqual(set(['first', 'second', 'third']), set(n.get_arraynames())) - self.assertAlmostEqual(abs(first - n.get_array('first')).max(), 0.) - self.assertAlmostEqual(abs(second - n.get_array('second')).max(), 0.) - self.assertAlmostEqual(abs(third - n.get_array('third')).max(), 0.) - self.assertEqual(first.shape, n.get_shape('first')) - self.assertEqual(second.shape, n.get_shape('second')) - self.assertEqual(third.shape, n.get_shape('third')) - - with self.assertRaises(KeyError): + assert set(['first', 'second', 'third']) == set(n.get_arraynames()) + assert round(abs(abs(first - n.get_array('first')).max() - 0.), 7) == 0 + assert round(abs(abs(second - n.get_array('second')).max() - 0.), 7) == 0 + assert round(abs(abs(third - n.get_array('third')).max() - 0.), 7) == 0 + assert first.shape == n.get_shape('first') + assert second.shape == n.get_shape('second') + assert third.shape == n.get_shape('third') + + with pytest.raises(KeyError): n.get_array('nonexistent_array') # Delete an array, and try to delete a non-existing one n.delete_array('third') - with self.assertRaises(KeyError): + with pytest.raises(KeyError): n.delete_array('nonexistent_array') # Overwrite an array @@ -2569,57 +2570,57 @@ def test_creation(self): n.set_array('first', first) # Check if the arrays are there, and if I am getting the new one - self.assertEqual(set(['first', 'second']), set(n.get_arraynames())) - self.assertAlmostEqual(abs(first - n.get_array('first')).max(), 0.) - self.assertAlmostEqual(abs(second - n.get_array('second')).max(), 0.) - self.assertEqual(first.shape, n.get_shape('first')) - self.assertEqual(second.shape, n.get_shape('second')) + assert set(['first', 'second']) == set(n.get_arraynames()) + assert round(abs(abs(first - n.get_array('first')).max() - 0.), 7) == 0 + assert round(abs(abs(second - n.get_array('second')).max() - 0.), 7) == 0 + assert first.shape == n.get_shape('first') + assert second.shape == n.get_shape('second') n.store() # Same checks, after storing - self.assertEqual(set(['first', 'second']), set(n.get_arraynames())) - self.assertAlmostEqual(abs(first - n.get_array('first')).max(), 0.) - self.assertAlmostEqual(abs(second - n.get_array('second')).max(), 0.) - self.assertEqual(first.shape, n.get_shape('first')) - self.assertEqual(second.shape, n.get_shape('second')) + assert set(['first', 'second']) == set(n.get_arraynames()) + assert round(abs(abs(first - n.get_array('first')).max() - 0.), 7) == 0 + assert round(abs(abs(second - n.get_array('second')).max() - 0.), 7) == 0 + assert first.shape == n.get_shape('first') + assert second.shape == n.get_shape('second') # Same checks, again (this is checking the caching features) - self.assertEqual(set(['first', 'second']), set(n.get_arraynames())) - self.assertAlmostEqual(abs(first - n.get_array('first')).max(), 0.) - self.assertAlmostEqual(abs(second - n.get_array('second')).max(), 0.) - self.assertEqual(first.shape, n.get_shape('first')) - self.assertEqual(second.shape, n.get_shape('second')) + assert set(['first', 'second']) == set(n.get_arraynames()) + assert round(abs(abs(first - n.get_array('first')).max() - 0.), 7) == 0 + assert round(abs(abs(second - n.get_array('second')).max() - 0.), 7) == 0 + assert first.shape == n.get_shape('first') + assert second.shape == n.get_shape('second') # Same checks, after reloading n2 = load_node(uuid=n.uuid) - self.assertEqual(set(['first', 'second']), set(n2.get_arraynames())) - self.assertAlmostEqual(abs(first - n2.get_array('first')).max(), 0.) - self.assertAlmostEqual(abs(second - n2.get_array('second')).max(), 0.) - self.assertEqual(first.shape, n2.get_shape('first')) - self.assertEqual(second.shape, n2.get_shape('second')) + assert set(['first', 'second']) == set(n2.get_arraynames()) + assert round(abs(abs(first - n2.get_array('first')).max() - 0.), 7) == 0 + assert round(abs(abs(second - n2.get_array('second')).max() - 0.), 7) == 0 + assert first.shape == n2.get_shape('first') + assert second.shape == n2.get_shape('second') # Same checks, after reloading with UUID n2 = load_node(n.uuid, sub_classes=(ArrayData,)) - self.assertEqual(set(['first', 'second']), set(n2.get_arraynames())) - self.assertAlmostEqual(abs(first - n2.get_array('first')).max(), 0.) - self.assertAlmostEqual(abs(second - n2.get_array('second')).max(), 0.) - self.assertEqual(first.shape, n2.get_shape('first')) - self.assertEqual(second.shape, n2.get_shape('second')) + assert set(['first', 'second']) == set(n2.get_arraynames()) + assert round(abs(abs(first - n2.get_array('first')).max() - 0.), 7) == 0 + assert round(abs(abs(second - n2.get_array('second')).max() - 0.), 7) == 0 + assert first.shape == n2.get_shape('first') + assert second.shape == n2.get_shape('second') # Check that I cannot modify the node after storing - with self.assertRaises(ModificationNotAllowed): + with pytest.raises(ModificationNotAllowed): n.delete_array('first') - with self.assertRaises(ModificationNotAllowed): + with pytest.raises(ModificationNotAllowed): n.set_array('second', first) # Again same checks, to verify that the attempts to delete/overwrite # arrays did not damage the node content - self.assertEqual(set(['first', 'second']), set(n.get_arraynames())) - self.assertAlmostEqual(abs(first - n.get_array('first')).max(), 0.) - self.assertAlmostEqual(abs(second - n.get_array('second')).max(), 0.) - self.assertEqual(first.shape, n.get_shape('first')) - self.assertEqual(second.shape, n.get_shape('second')) + assert set(['first', 'second']) == set(n.get_arraynames()) + assert round(abs(abs(first - n.get_array('first')).max() - 0.), 7) == 0 + assert round(abs(abs(second - n.get_array('second')).max() - 0.), 7) == 0 + assert first.shape == n.get_shape('first') + assert second.shape == n.get_shape('second') def test_iteration(self): """ @@ -2640,14 +2641,15 @@ def test_iteration(self): for name, array in n.get_iterarrays(): if name == 'first': - self.assertAlmostEqual(abs(first - array).max(), 0.) + assert round(abs(abs(first - array).max() - 0.), 7) == 0 if name == 'second': - self.assertAlmostEqual(abs(second - array).max(), 0.) + assert round(abs(abs(second - array).max() - 0.), 7) == 0 if name == 'third': - self.assertAlmostEqual(abs(third - array).max(), 0.) + assert round(abs(abs(third - array).max() - 0.), 7) == 0 -class TestTrajectoryData(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test_class') +class TestTrajectoryData: """Tests the TrajectoryData objects.""" def test_creation(self): @@ -2698,27 +2700,27 @@ def test_creation(self): ) # Generic checks - self.assertEqual(n.numsites, 3) - self.assertEqual(n.numsteps, 2) - self.assertAlmostEqual(abs(stepids - n.get_stepids()).sum(), 0.) - self.assertAlmostEqual(abs(times - n.get_times()).sum(), 0.) - self.assertAlmostEqual(abs(cells - n.get_cells()).sum(), 0.) - self.assertEqual(symbols, n.symbols) - self.assertAlmostEqual(abs(positions - n.get_positions()).sum(), 0.) - self.assertAlmostEqual(abs(velocities - n.get_velocities()).sum(), 0.) + assert n.numsites == 3 + assert n.numsteps == 2 + assert round(abs(abs(stepids - n.get_stepids()).sum() - 0.), 7) == 0 + assert round(abs(abs(times - n.get_times()).sum() - 0.), 7) == 0 + assert round(abs(abs(cells - n.get_cells()).sum() - 0.), 7) == 0 + assert symbols == n.symbols + assert round(abs(abs(positions - n.get_positions()).sum() - 0.), 7) == 0 + assert round(abs(abs(velocities - n.get_velocities()).sum() - 0.), 7) == 0 # get_step_data function check data = n.get_step_data(1) - self.assertEqual(data[0], stepids[1]) - self.assertAlmostEqual(data[1], times[1]) - self.assertAlmostEqual(abs(cells[1] - data[2]).sum(), 0.) - self.assertEqual(symbols, data[3]) - self.assertAlmostEqual(abs(data[4] - positions[1]).sum(), 0.) - self.assertAlmostEqual(abs(data[5] - velocities[1]).sum(), 0.) + assert data[0] == stepids[1] + assert round(abs(data[1] - times[1]), 7) == 0 + assert round(abs(abs(cells[1] - data[2]).sum() - 0.), 7) == 0 + assert symbols == data[3] + assert round(abs(abs(data[4] - positions[1]).sum() - 0.), 7) == 0 + assert round(abs(abs(data[5] - velocities[1]).sum() - 0.), 7) == 0 # Step 70 has index 1 - self.assertEqual(1, n.get_index_from_stepid(70)) - with self.assertRaises(ValueError): + assert n.get_index_from_stepid(70) == 1 + with pytest.raises(ValueError): # Step 66 does not exist n.get_index_from_stepid(66) @@ -2726,79 +2728,79 @@ def test_creation(self): # I set the node, this time without times or velocities (the same node) n.set_trajectory(stepids=stepids, cells=cells, symbols=symbols, positions=positions) # Generic checks - self.assertEqual(n.numsites, 3) - self.assertEqual(n.numsteps, 2) - self.assertAlmostEqual(abs(stepids - n.get_stepids()).sum(), 0.) - self.assertIsNone(n.get_times()) - self.assertAlmostEqual(abs(cells - n.get_cells()).sum(), 0.) - self.assertEqual(symbols, n.symbols) - self.assertAlmostEqual(abs(positions - n.get_positions()).sum(), 0.) - self.assertIsNone(n.get_velocities()) + assert n.numsites == 3 + assert n.numsteps == 2 + assert round(abs(abs(stepids - n.get_stepids()).sum() - 0.), 7) == 0 + assert n.get_times() is None + assert round(abs(abs(cells - n.get_cells()).sum() - 0.), 7) == 0 + assert symbols == n.symbols + assert round(abs(abs(positions - n.get_positions()).sum() - 0.), 7) == 0 + assert n.get_velocities() is None # Same thing, but for a new node n = TrajectoryData() n.set_trajectory(stepids=stepids, cells=cells, symbols=symbols, positions=positions) # Generic checks - self.assertEqual(n.numsites, 3) - self.assertEqual(n.numsteps, 2) - self.assertAlmostEqual(abs(stepids - n.get_stepids()).sum(), 0.) - self.assertIsNone(n.get_times()) - self.assertAlmostEqual(abs(cells - n.get_cells()).sum(), 0.) - self.assertEqual(symbols, n.symbols) - self.assertAlmostEqual(abs(positions - n.get_positions()).sum(), 0.) - self.assertIsNone(n.get_velocities()) + assert n.numsites == 3 + assert n.numsteps == 2 + assert round(abs(abs(stepids - n.get_stepids()).sum() - 0.), 7) == 0 + assert n.get_times() is None + assert round(abs(abs(cells - n.get_cells()).sum() - 0.), 7) == 0 + assert symbols == n.symbols + assert round(abs(abs(positions - n.get_positions()).sum() - 0.), 7) == 0 + assert n.get_velocities() is None ######################################################## # I set the node, this time without velocities (the same node) n.set_trajectory(stepids=stepids, cells=cells, symbols=symbols, positions=positions, times=times) # Generic checks - self.assertEqual(n.numsites, 3) - self.assertEqual(n.numsteps, 2) - self.assertAlmostEqual(abs(stepids - n.get_stepids()).sum(), 0.) - self.assertAlmostEqual(abs(times - n.get_times()).sum(), 0.) - self.assertAlmostEqual(abs(cells - n.get_cells()).sum(), 0.) - self.assertEqual(symbols, n.symbols) - self.assertAlmostEqual(abs(positions - n.get_positions()).sum(), 0.) - self.assertIsNone(n.get_velocities()) + assert n.numsites == 3 + assert n.numsteps == 2 + assert round(abs(abs(stepids - n.get_stepids()).sum() - 0.), 7) == 0 + assert round(abs(abs(times - n.get_times()).sum() - 0.), 7) == 0 + assert round(abs(abs(cells - n.get_cells()).sum() - 0.), 7) == 0 + assert symbols == n.symbols + assert round(abs(abs(positions - n.get_positions()).sum() - 0.), 7) == 0 + assert n.get_velocities() is None # Same thing, but for a new node n = TrajectoryData() n.set_trajectory(stepids=stepids, cells=cells, symbols=symbols, positions=positions, times=times) # Generic checks - self.assertEqual(n.numsites, 3) - self.assertEqual(n.numsteps, 2) - self.assertAlmostEqual(abs(stepids - n.get_stepids()).sum(), 0.) - self.assertAlmostEqual(abs(times - n.get_times()).sum(), 0.) - self.assertAlmostEqual(abs(cells - n.get_cells()).sum(), 0.) - self.assertEqual(symbols, n.symbols) - self.assertAlmostEqual(abs(positions - n.get_positions()).sum(), 0.) - self.assertIsNone(n.get_velocities()) + assert n.numsites == 3 + assert n.numsteps == 2 + assert round(abs(abs(stepids - n.get_stepids()).sum() - 0.), 7) == 0 + assert round(abs(abs(times - n.get_times()).sum() - 0.), 7) == 0 + assert round(abs(abs(cells - n.get_cells()).sum() - 0.), 7) == 0 + assert symbols == n.symbols + assert round(abs(abs(positions - n.get_positions()).sum() - 0.), 7) == 0 + assert n.get_velocities() is None n.store() # Again same checks, but after storing # Generic checks - self.assertEqual(n.numsites, 3) - self.assertEqual(n.numsteps, 2) - self.assertAlmostEqual(abs(stepids - n.get_stepids()).sum(), 0.) - self.assertAlmostEqual(abs(times - n.get_times()).sum(), 0.) - self.assertAlmostEqual(abs(cells - n.get_cells()).sum(), 0.) - self.assertEqual(symbols, n.symbols) - self.assertAlmostEqual(abs(positions - n.get_positions()).sum(), 0.) - self.assertIsNone(n.get_velocities()) + assert n.numsites == 3 + assert n.numsteps == 2 + assert round(abs(abs(stepids - n.get_stepids()).sum() - 0.), 7) == 0 + assert round(abs(abs(times - n.get_times()).sum() - 0.), 7) == 0 + assert round(abs(abs(cells - n.get_cells()).sum() - 0.), 7) == 0 + assert symbols == n.symbols + assert round(abs(abs(positions - n.get_positions()).sum() - 0.), 7) == 0 + assert n.get_velocities() is None # get_step_data function check data = n.get_step_data(1) - self.assertEqual(data[0], stepids[1]) - self.assertAlmostEqual(data[1], times[1]) - self.assertAlmostEqual(abs(cells[1] - data[2]).sum(), 0.) - self.assertEqual(symbols, data[3]) - self.assertAlmostEqual(abs(data[4] - positions[1]).sum(), 0.) - self.assertIsNone(data[5]) + assert data[0] == stepids[1] + assert round(abs(data[1] - times[1]), 7) == 0 + assert round(abs(abs(cells[1] - data[2]).sum() - 0.), 7) == 0 + assert symbols == data[3] + assert round(abs(abs(data[4] - positions[1]).sum() - 0.), 7) == 0 + assert data[5] is None # Step 70 has index 1 - self.assertEqual(1, n.get_index_from_stepid(70)) - with self.assertRaises(ValueError): + assert n.get_index_from_stepid(70) == 1 + with pytest.raises(ValueError): # Step 66 does not exist n.get_index_from_stepid(66) @@ -2806,27 +2808,27 @@ def test_creation(self): # Again, but after reloading from uuid n = load_node(n.uuid, sub_classes=(TrajectoryData,)) # Generic checks - self.assertEqual(n.numsites, 3) - self.assertEqual(n.numsteps, 2) - self.assertAlmostEqual(abs(stepids - n.get_stepids()).sum(), 0.) - self.assertAlmostEqual(abs(times - n.get_times()).sum(), 0.) - self.assertAlmostEqual(abs(cells - n.get_cells()).sum(), 0.) - self.assertEqual(symbols, n.symbols) - self.assertAlmostEqual(abs(positions - n.get_positions()).sum(), 0.) - self.assertIsNone(n.get_velocities()) + assert n.numsites == 3 + assert n.numsteps == 2 + assert round(abs(abs(stepids - n.get_stepids()).sum() - 0.), 7) == 0 + assert round(abs(abs(times - n.get_times()).sum() - 0.), 7) == 0 + assert round(abs(abs(cells - n.get_cells()).sum() - 0.), 7) == 0 + assert symbols == n.symbols + assert round(abs(abs(positions - n.get_positions()).sum() - 0.), 7) == 0 + assert n.get_velocities() is None # get_step_data function check data = n.get_step_data(1) - self.assertEqual(data[0], stepids[1]) - self.assertAlmostEqual(data[1], times[1]) - self.assertAlmostEqual(abs(cells[1] - data[2]).sum(), 0.) - self.assertEqual(symbols, data[3]) - self.assertAlmostEqual(abs(data[4] - positions[1]).sum(), 0.) - self.assertIsNone(data[5]) + assert data[0] == stepids[1] + assert round(abs(data[1] - times[1]), 7) == 0 + assert round(abs(abs(cells[1] - data[2]).sum() - 0.), 7) == 0 + assert symbols == data[3] + assert round(abs(abs(data[4] - positions[1]).sum() - 0.), 7) == 0 + assert data[5] is None # Step 70 has index 1 - self.assertEqual(1, n.get_index_from_stepid(70)) - with self.assertRaises(ValueError): + assert n.get_index_from_stepid(70) == 1 + with pytest.raises(ValueError): # Step 66 does not exist n.get_index_from_stepid(66) @@ -2883,12 +2885,12 @@ def test_conversion_to_structure(self): from_get_structure = n.get_structure(index=1) for struc in [from_step, from_get_structure]: - self.assertEqual(len(struc.sites), 3) # 3 sites - self.assertAlmostEqual(abs(numpy.array(struc.cell) - cells[1]).sum(), 0) + assert len(struc.sites) == 3 # 3 sites + assert round(abs(abs(numpy.array(struc.cell) - cells[1]).sum() - 0), 7) == 0 newpos = numpy.array([s.position for s in struc.sites]) - self.assertAlmostEqual(abs(newpos - positions[1]).sum(), 0) + assert round(abs(abs(newpos - positions[1]).sum() - 0), 7) == 0 newkinds = [s.kind_name for s in struc.sites] - self.assertEqual(newkinds, symbols) + assert newkinds == symbols # Weird assignments (nobody should ever do this, but it is possible in # principle and we want to check @@ -2897,19 +2899,19 @@ def test_conversion_to_structure(self): k3 = Kind(name='O', symbols='Os', mass=100.) k4 = Kind(name='Ge', symbols='Ge') - with self.assertRaises(ValueError): + with pytest.raises(ValueError): # Not enough kinds struc = n.get_step_structure(1, custom_kinds=[k1, k2]) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): # Too many kinds struc = n.get_step_structure(1, custom_kinds=[k1, k2, k3, k4]) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): # Wrong kinds struc = n.get_step_structure(1, custom_kinds=[k1, k2, k4]) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): # Two kinds with the same name struc = n.get_step_structure(1, custom_kinds=[k1, k2, k3, k3]) @@ -2917,18 +2919,18 @@ def test_conversion_to_structure(self): struc = n.get_step_structure(1, custom_kinds=[k1, k2, k3]) # Checks - self.assertEqual(len(struc.sites), 3) # 3 sites - self.assertAlmostEqual(abs(numpy.array(struc.cell) - cells[1]).sum(), 0) + assert len(struc.sites) == 3 # 3 sites + assert round(abs(abs(numpy.array(struc.cell) - cells[1]).sum() - 0), 7) == 0 newpos = numpy.array([s.position for s in struc.sites]) - self.assertAlmostEqual(abs(newpos - positions[1]).sum(), 0) + assert round(abs(abs(newpos - positions[1]).sum() - 0), 7) == 0 newkinds = [s.kind_name for s in struc.sites] # Kinds are in the same order as given in the custm_kinds list - self.assertEqual(newkinds, symbols) + assert newkinds == symbols newatomtypes = [struc.get_kind(s.kind_name).symbols[0] for s in struc.sites] # Atoms remain in the same order as given in the positions list - self.assertEqual(newatomtypes, ['He', 'Os', 'Cu']) + assert newatomtypes == ['He', 'Os', 'Cu'] # Check the mass of the kind of the second atom ('O' _> symbol Os, mass 100) - self.assertAlmostEqual(struc.get_kind(struc.sites[1].kind_name).mass, 100.) + assert round(abs(struc.get_kind(struc.sites[1].kind_name).mass - 100.), 7) == 0 def test_conversion_from_structurelist(self): """ @@ -2971,9 +2973,9 @@ def test_conversion_from_structurelist(self): structurelist.append(struct) td = TrajectoryData(structurelist=structurelist) - self.assertEqual(td.get_cells().tolist(), cells) - self.assertEqual(td.symbols, symbols[0]) - self.assertEqual(td.get_positions().tolist(), positions) + assert td.get_cells().tolist() == cells + assert td.symbols == symbols[0] + assert td.get_positions().tolist() == positions symbols = [['H', 'O', 'C'], ['H', 'O', 'P']] structurelist = [] @@ -2983,7 +2985,7 @@ def test_conversion_from_structurelist(self): struct.append_atom(symbols=symbol, position=positions[i][j]) structurelist.append(struct) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): td = TrajectoryData(structurelist=structurelist) @staticmethod @@ -3058,7 +3060,8 @@ def test_export_to_file(): os.remove(file) -class TestKpointsData(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test_class') +class TestKpointsData: """Tests the KpointsData objects.""" def test_mesh(self): @@ -3072,27 +3075,27 @@ def test_mesh(self): input_mesh = [4, 4, 4] k.set_kpoints_mesh(input_mesh) mesh, offset = k.get_kpoints_mesh() - self.assertEqual(mesh, input_mesh) - self.assertEqual(offset, [0., 0., 0.]) # must be a tuple of three 0 by default + assert mesh == input_mesh + assert offset == [0., 0., 0.] # must be a tuple of three 0 by default # a too long list should fail - with self.assertRaises(ValueError): + with pytest.raises(ValueError): k.set_kpoints_mesh([4, 4, 4, 4]) # now try to put explicitely an offset input_offset = [0.5, 0.5, 0.5] k.set_kpoints_mesh(input_mesh, input_offset) mesh, offset = k.get_kpoints_mesh() - self.assertEqual(mesh, input_mesh) - self.assertEqual(offset, input_offset) + assert mesh == input_mesh + assert offset == input_offset # verify the same but after storing k.store() - self.assertEqual(mesh, input_mesh) - self.assertEqual(offset, input_offset) + assert mesh == input_mesh + assert offset == input_offset # cannot modify it after storage - with self.assertRaises(ModificationNotAllowed): + with pytest.raises(ModificationNotAllowed): k.set_kpoints_mesh(input_mesh) def test_list(self): @@ -3115,33 +3118,33 @@ def test_list(self): klist = k.get_kpoints() # try to get the same - self.assertTrue(numpy.array_equal(input_klist, klist)) + assert numpy.array_equal(input_klist, klist) # if no cell is set, cannot convert into cartesian - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): _ = k.get_kpoints(cartesian=True) # try to set also weights # should fail if the weights length do not match kpoints input_weights = numpy.ones(6) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): k.set_kpoints(input_klist, weights=input_weights) # try a right one input_weights = numpy.ones(4) k.set_kpoints(input_klist, weights=input_weights) klist, weights = k.get_kpoints(also_weights=True) - self.assertTrue(numpy.array_equal(weights, input_weights)) - self.assertTrue(numpy.array_equal(klist, input_klist)) + assert numpy.array_equal(weights, input_weights) + assert numpy.array_equal(klist, input_klist) # verify the same, but after storing k.store() klist, weights = k.get_kpoints(also_weights=True) - self.assertTrue(numpy.array_equal(weights, input_weights)) - self.assertTrue(numpy.array_equal(klist, input_klist)) + assert numpy.array_equal(weights, input_weights) + assert numpy.array_equal(klist, input_klist) # cannot modify it after storage - with self.assertRaises(ModificationNotAllowed): + with pytest.raises(ModificationNotAllowed): k.set_kpoints(input_klist) def test_kpoints_to_cartesian(self): @@ -3175,13 +3178,13 @@ def test_kpoints_to_cartesian(self): # verify that it is not the same of the input # (at least I check that there something has been done) klist = k.get_kpoints(cartesian=True) - self.assertFalse(numpy.array_equal(klist, input_klist)) + assert not numpy.array_equal(klist, input_klist) # put the kpoints in cartesian and get them back, they should be equal # internally it is doing two matrix transforms k.set_kpoints(input_klist, cartesian=True) klist = k.get_kpoints(cartesian=True) - self.assertTrue(numpy.allclose(klist, input_klist, atol=1e-16)) + assert numpy.allclose(klist, input_klist, atol=1e-16) def test_path_wrapper_legacy(self): """ @@ -3193,7 +3196,7 @@ def test_path_wrapper_legacy(self): from aiida.tools.data.array.kpoints import get_explicit_kpoints_path # Shouldn't get anything without having set the cell - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): get_explicit_kpoints_path(None) # Define a cell @@ -3221,12 +3224,12 @@ def test_path_wrapper_legacy(self): ]) # at least 2 points per segment - with self.assertRaises(ValueError): + with pytest.raises(ValueError): get_explicit_kpoints_path(structure, method='legacy', value=[ ('G', 'M', 1), ]) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): get_explicit_kpoints_path(structure, method='legacy', value=[ ('G', (0., 0., 0.), 'M', (1., 1., 1.), 1), ]) @@ -3248,15 +3251,16 @@ def test_tetra_z_wrapper_legacy(self): s = StructureData(cell=cell_x) result = get_kpoints_path(s, method='legacy', cartesian=True) - self.assertIsInstance(result['parameters'], Dict) + assert isinstance(result['parameters'], Dict) point_coords = result['parameters'].dict.point_coords - self.assertAlmostEqual(point_coords['Z'][2], numpy.pi / alat) - self.assertAlmostEqual(point_coords['Z'][0], 0.) + assert round(abs(point_coords['Z'][2] - numpy.pi / alat), 7) == 0 + assert round(abs(point_coords['Z'][0] - 0.), 7) == 0 -class TestSpglibTupleConversion(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test_class') +class TestSpglibTupleConversion: """Tests for conversion of Spglib tuples.""" def test_simple_to_aiida(self): @@ -3277,11 +3281,11 @@ def test_simple_to_aiida(self): struc = spglib_tuple_to_structure((cell, relcoords, numbers)) - self.assertAlmostEqual(np.sum(np.abs(np.array(struc.cell) - np.array(cell))), 0.) - self.assertAlmostEqual( - np.sum(np.abs(np.array([site.position for site in struc.sites]) - np.array(abscoords))), 0. - ) - self.assertEqual([site.kind_name for site in struc.sites], ['Ba', 'Ti', 'O', 'O', 'O']) + assert round(abs(np.sum(np.abs(np.array(struc.cell) - np.array(cell))) - 0.), 7) == 0 + assert round( + abs(np.sum(np.abs(np.array([site.position for site in struc.sites]) - np.array(abscoords))) - 0.), 7 + ) == 0 + assert [site.kind_name for site in struc.sites] == ['Ba', 'Ti', 'O', 'O', 'O'] def test_complex1_to_aiida(self): """Test conversion of a tuple to an AiiDA structure when passing also information on the kinds.""" @@ -3323,26 +3327,26 @@ def test_complex1_to_aiida(self): ] # Must specify also kind_info and kinds - with self.assertRaises(ValueError): + with pytest.raises(ValueError): struc = spglib_tuple_to_structure((cell, relcoords, numbers),) # There is no kind_info for one of the numbers - with self.assertRaises(ValueError): + with pytest.raises(ValueError): struc = spglib_tuple_to_structure((cell, relcoords, numbers), kind_info=kind_info_wrong, kinds=kinds) # There is no kind in the kinds for one of the labels # specified in kind_info - with self.assertRaises(ValueError): + with pytest.raises(ValueError): struc = spglib_tuple_to_structure((cell, relcoords, numbers), kind_info=kind_info, kinds=kinds_wrong) struc = spglib_tuple_to_structure((cell, relcoords, numbers), kind_info=kind_info, kinds=kinds) - self.assertAlmostEqual(np.sum(np.abs(np.array(struc.cell) - np.array(cell))), 0.) - self.assertAlmostEqual( - np.sum(np.abs(np.array([site.position for site in struc.sites]) - np.array(abscoords))), 0. - ) - self.assertEqual([site.kind_name for site in struc.sites], - ['Ba', 'Ti', 'O', 'O', 'O', 'Ba2', 'BaTi', 'BaTi2', 'Ba3']) + assert round(abs(np.sum(np.abs(np.array(struc.cell) - np.array(cell))) - 0.), 7) == 0 + assert round( + abs(np.sum(np.abs(np.array([site.position for site in struc.sites]) - np.array(abscoords))) - 0.), 7 + ) == 0 + assert [site.kind_name for site in struc.sites] == \ + ['Ba', 'Ti', 'O', 'O', 'O', 'Ba2', 'BaTi', 'BaTi2', 'Ba3'] def test_from_aiida(self): """Test conversion of an AiiDA structure to a spglib tuple.""" @@ -3365,11 +3369,11 @@ def test_from_aiida(self): abscoords = np.array([_.position for _ in struc.sites]) struc_relpos = np.dot(np.linalg.inv(cell.T), abscoords.T).T - self.assertAlmostEqual(np.sum(np.abs(np.array(struc.cell) - np.array(struc_tuple[0]))), 0.) - self.assertAlmostEqual(np.sum(np.abs(np.array(struc_tuple[1]) - struc_relpos)), 0.) + assert round(abs(np.sum(np.abs(np.array(struc.cell) - np.array(struc_tuple[0]))) - 0.), 7) == 0 + assert round(abs(np.sum(np.abs(np.array(struc_tuple[1]) - struc_relpos)) - 0.), 7) == 0 expected_kind_info = [kind_info[site.kind_name] for site in struc.sites] - self.assertEqual(struc_tuple[2], expected_kind_info) + assert struc_tuple[2] == expected_kind_info def test_aiida_roundtrip(self): """ @@ -3392,22 +3396,19 @@ def test_aiida_roundtrip(self): struc_tuple, kind_info, kinds = structure_to_spglib_tuple(struc) roundtrip_struc = spglib_tuple_to_structure(struc_tuple, kind_info, kinds) - self.assertAlmostEqual(np.sum(np.abs(np.array(struc.cell) - np.array(roundtrip_struc.cell))), 0.) - self.assertEqual(struc.get_attribute('kinds'), roundtrip_struc.get_attribute('kinds')) - self.assertEqual([_.kind_name for _ in struc.sites], [_.kind_name for _ in roundtrip_struc.sites]) - self.assertEqual( - np.sum( - np.abs( - np.array([_.position for _ in struc.sites]) - np.array([_.position for _ in roundtrip_struc.sites]) - ) - ), 0. - ) + assert round(abs(np.sum(np.abs(np.array(struc.cell) - np.array(roundtrip_struc.cell))) - 0.), 7) == 0 + assert struc.get_attribute('kinds') == roundtrip_struc.get_attribute('kinds') + assert [_.kind_name for _ in struc.sites] == [_.kind_name for _ in roundtrip_struc.sites] + assert np.sum( + np.abs(np.array([_.position for _ in struc.sites]) - np.array([_.position for _ in roundtrip_struc.sites])) + ) == 0. -class TestSeekpathExplicitPath(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test_class') +class TestSeekpathExplicitPath: """Tests for the `get_explicit_kpoints_path` from SeeK-path.""" - @unittest.skipIf(not has_seekpath(), 'No seekpath available') + @pytest.mark.skipif(not has_seekpath(), reason='No seekpath available') def test_simple(self): """Test a simple case.""" import numpy as np @@ -3427,70 +3428,64 @@ def test_simple(self): return_value = get_explicit_kpoints_path(structure, method='seekpath', **params) retdict = return_value['parameters'].get_dict() - self.assertTrue(retdict['has_inversion_symmetry']) - self.assertFalse(retdict['augmented_path']) - self.assertAlmostEqual(retdict['volume_original_wrt_prim'], 1.0) - self.assertEqual( - to_list_of_lists(retdict['explicit_segments']), + assert retdict['has_inversion_symmetry'] + assert not retdict['augmented_path'] + assert round(abs(retdict['volume_original_wrt_prim'] - 1.0), 7) == 0 + assert to_list_of_lists(retdict['explicit_segments']) == \ [[0, 31], [30, 61], [60, 104], [103, 123], [122, 153], [152, 183], [182, 226], [226, 246], [246, 266]] - ) ret_k = return_value['explicit_kpoints'] - self.assertEqual( - to_list_of_lists(ret_k.labels), [[0, 'GAMMA'], [30, 'X'], [60, 'M'], [103, 'GAMMA'], [122, 'Z'], [152, 'R'], - [182, 'A'], [225, 'Z'], [226, 'X'], [245, 'R'], [246, 'M'], [265, 'A']] - ) + assert to_list_of_lists(ret_k.labels) == [[0, 'GAMMA'], [30, 'X'], [60, 'M'], [103, 'GAMMA'], [122, 'Z'], + [152, 'R'], [182, 'A'], [225, 'Z'], [226, 'X'], [245, 'R'], + [246, 'M'], [265, 'A']] kpts = ret_k.get_kpoints(cartesian=False) highsympoints_relcoords = [kpts[idx] for idx, label in ret_k.labels] - self.assertAlmostEqual( - np.sum( - np.abs( - np.array([ - [0., 0., 0.], # Gamma - [0., 0.5, 0.], # X - [0.5, 0.5, 0.], # M - [0., 0., 0.], # Gamma - [0., 0., 0.5], # Z - [0., 0.5, 0.5], # R - [0.5, 0.5, 0.5], # A - [0., 0., 0.5], # Z - [0., 0.5, 0.], # X - [0., 0.5, 0.5], # R - [0.5, 0.5, 0.], # M - [0.5, 0.5, 0.5], # A - ]) - np.array(highsympoints_relcoords) - ) + assert round( + abs( + np.sum( + np.abs( + np.array([ + [0., 0., 0.], # Gamma + [0., 0.5, 0.], # X + [0.5, 0.5, 0.], # M + [0., 0., 0.], # Gamma + [0., 0., 0.5], # Z + [0., 0.5, 0.5], # R + [0.5, 0.5, 0.5], # A + [0., 0., 0.5], # Z + [0., 0.5, 0.], # X + [0., 0.5, 0.5], # R + [0.5, 0.5, 0.], # M + [0.5, 0.5, 0.5], # A + ]) - np.array(highsympoints_relcoords) + ) + ) - 0. ), - 0. - ) + 7 + ) == 0 ret_prims = return_value['primitive_structure'] ret_convs = return_value['conv_structure'] # The primitive structure should be the same as the one I input - self.assertAlmostEqual(np.sum(np.abs(np.array(structure.cell) - np.array(ret_prims.cell))), 0.) - self.assertEqual([_.kind_name for _ in structure.sites], [_.kind_name for _ in ret_prims.sites]) - self.assertEqual( - np.sum( - np. - abs(np.array([_.position for _ in structure.sites]) - np.array([_.position for _ in ret_prims.sites])) - ), 0. - ) + assert round(abs(np.sum(np.abs(np.array(structure.cell) - np.array(ret_prims.cell))) - 0.), 7) == 0 + assert [_.kind_name for _ in structure.sites] == [_.kind_name for _ in ret_prims.sites] + assert np.sum( + np.abs(np.array([_.position for _ in structure.sites]) - np.array([_.position for _ in ret_prims.sites])) + ) == 0. # Also the conventional structure should be the same as the one I input - self.assertAlmostEqual(np.sum(np.abs(np.array(structure.cell) - np.array(ret_convs.cell))), 0.) - self.assertEqual([_.kind_name for _ in structure.sites], [_.kind_name for _ in ret_convs.sites]) - self.assertEqual( - np.sum( - np. - abs(np.array([_.position for _ in structure.sites]) - np.array([_.position for _ in ret_convs.sites])) - ), 0. - ) + assert round(abs(np.sum(np.abs(np.array(structure.cell) - np.array(ret_convs.cell))) - 0.), 7) == 0 + assert [_.kind_name for _ in structure.sites] == [_.kind_name for _ in ret_convs.sites] + assert np.sum( + np.abs(np.array([_.position for _ in structure.sites]) - np.array([_.position for _ in ret_convs.sites])) + ) == 0. -class TestSeekpathPath(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test_class') +class TestSeekpathPath: """Test Seekpath.""" - @unittest.skipIf(not has_seekpath(), 'No seekpath available') + @pytest.mark.skipif(not has_seekpath(), reason='No seekpath available') def test_simple(self): """Test SeekPath for BaTiO3 structure.""" import numpy as np @@ -3509,61 +3504,54 @@ def test_simple(self): return_value = get_kpoints_path(structure, method='seekpath', **params) retdict = return_value['parameters'].get_dict() - self.assertTrue(retdict['has_inversion_symmetry']) - self.assertFalse(retdict['augmented_path']) - self.assertAlmostEqual(retdict['volume_original_wrt_prim'], 1.0) - self.assertAlmostEqual(retdict['volume_original_wrt_conv'], 1.0) - self.assertEqual(retdict['bravais_lattice'], 'tP') - self.assertEqual(retdict['bravais_lattice_extended'], 'tP1') - self.assertEqual( - to_list_of_lists(retdict['path']), [['GAMMA', 'X'], ['X', 'M'], ['M', 'GAMMA'], ['GAMMA', 'Z'], ['Z', 'R'], - ['R', 'A'], ['A', 'Z'], ['X', 'R'], ['M', 'A']] - ) - - self.assertEqual( - retdict['point_coords'], { - 'A': [0.5, 0.5, 0.5], - 'M': [0.5, 0.5, 0.0], - 'R': [0.0, 0.5, 0.5], - 'X': [0.0, 0.5, 0.0], - 'Z': [0.0, 0.0, 0.5], - 'GAMMA': [0.0, 0.0, 0.0] - } - ) - - self.assertAlmostEqual( - np.sum( - np.abs( - np.array(retdict['inverse_primitive_transformation_matrix']) - - np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) - ) - ), 0. - ) + assert retdict['has_inversion_symmetry'] + assert not retdict['augmented_path'] + assert round(abs(retdict['volume_original_wrt_prim'] - 1.0), 7) == 0 + assert round(abs(retdict['volume_original_wrt_conv'] - 1.0), 7) == 0 + assert retdict['bravais_lattice'] == 'tP' + assert retdict['bravais_lattice_extended'] == 'tP1' + assert to_list_of_lists(retdict['path']) == [['GAMMA', 'X'], ['X', 'M'], ['M', 'GAMMA'], ['GAMMA', 'Z'], + ['Z', 'R'], ['R', 'A'], ['A', 'Z'], ['X', 'R'], ['M', 'A']] + + assert retdict['point_coords'] == { + 'A': [0.5, 0.5, 0.5], + 'M': [0.5, 0.5, 0.0], + 'R': [0.0, 0.5, 0.5], + 'X': [0.0, 0.5, 0.0], + 'Z': [0.0, 0.0, 0.5], + 'GAMMA': [0.0, 0.0, 0.0] + } + + assert round( + abs( + np.sum( + np.abs( + np.array(retdict['inverse_primitive_transformation_matrix']) - + np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) + ) + ) - 0. + ), 7 + ) == 0 ret_prims = return_value['primitive_structure'] ret_convs = return_value['conv_structure'] # The primitive structure should be the same as the one I input - self.assertAlmostEqual(np.sum(np.abs(np.array(structure.cell) - np.array(ret_prims.cell))), 0.) - self.assertEqual([_.kind_name for _ in structure.sites], [_.kind_name for _ in ret_prims.sites]) - self.assertEqual( - np.sum( - np. - abs(np.array([_.position for _ in structure.sites]) - np.array([_.position for _ in ret_prims.sites])) - ), 0. - ) + assert round(abs(np.sum(np.abs(np.array(structure.cell) - np.array(ret_prims.cell))) - 0.), 7) == 0 + assert [_.kind_name for _ in structure.sites] == [_.kind_name for _ in ret_prims.sites] + assert np.sum( + np.abs(np.array([_.position for _ in structure.sites]) - np.array([_.position for _ in ret_prims.sites])) + ) == 0. # Also the conventional structure should be the same as the one I input - self.assertAlmostEqual(np.sum(np.abs(np.array(structure.cell) - np.array(ret_convs.cell))), 0.) - self.assertEqual([_.kind_name for _ in structure.sites], [_.kind_name for _ in ret_convs.sites]) - self.assertEqual( - np.sum( - np. - abs(np.array([_.position for _ in structure.sites]) - np.array([_.position for _ in ret_convs.sites])) - ), 0. - ) + assert round(abs(np.sum(np.abs(np.array(structure.cell) - np.array(ret_convs.cell))) - 0.), 7) == 0 + assert [_.kind_name for _ in structure.sites] == [_.kind_name for _ in ret_convs.sites] + assert np.sum( + np.abs(np.array([_.position for _ in structure.sites]) - np.array([_.position for _ in ret_convs.sites])) + ) == 0. -class TestBandsData(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test_class') +class TestBandsData: """ Tests the BandsData objects. """ @@ -3588,7 +3576,7 @@ def test_band(self): b = BandsData() b.set_kpointsdata(k) - self.assertTrue(numpy.array_equal(b.cell, k.cell)) + assert numpy.array_equal(b.cell, k.cell) input_bands = numpy.array([numpy.ones(4) for i in range(k.get_kpoints().shape[0])]) input_occupations = input_bands @@ -3596,18 +3584,18 @@ def test_band(self): b.set_bands(input_bands, occupations=input_occupations, units='ev') b.set_bands(input_bands, units='ev') b.set_bands(input_bands, occupations=input_occupations) - with self.assertRaises(TypeError): + with pytest.raises(TypeError): b.set_bands(occupations=input_occupations, units='ev') # pylint: disable=no-value-for-parameter b.set_bands(input_bands, occupations=input_occupations, units='ev') bands, occupations = b.get_bands(also_occupations=True) - self.assertTrue(numpy.array_equal(bands, input_bands)) - self.assertTrue(numpy.array_equal(occupations, input_occupations)) - self.assertTrue(b.units == 'ev') + assert numpy.array_equal(bands, input_bands) + assert numpy.array_equal(occupations, input_occupations) + assert b.units == 'ev' b.store() - with self.assertRaises(ModificationNotAllowed): + with pytest.raises(ModificationNotAllowed): b.set_bands(bands) @staticmethod diff --git a/tests/test_dbimporters.py b/tests/test_dbimporters.py index f004f72642..1487af60fa 100644 --- a/tests/test_dbimporters.py +++ b/tests/test_dbimporters.py @@ -7,14 +7,15 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=no-self-use """Tests for subclasses of DbImporter, DbSearchResults and DbEntry""" -import unittest +import pytest -from aiida.backends.testbase import AiidaTestCase from tests.static import STATIC_DIR -class TestCodDbImporter(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test_class') +class TestCodDbImporter: """Test the CodDbImporter class.""" from aiida.orm.nodes.data.cif import has_pycifrw @@ -44,28 +45,28 @@ def test_query_construction_1(self): q_sql = re.sub(r'(\d\.\d{6})\d+', r'\1', q_sql) q_sql = re.sub(r'(120.00)39+', r'\g<1>4', q_sql) - self.assertEqual(q_sql, \ - 'SELECT file, svnrevision FROM data WHERE ' - "(status IS NULL OR status != 'retracted') AND " - '(a BETWEEN 3.332333 AND 3.334333 OR ' - 'a BETWEEN 0.999 AND 1.001) AND ' - '(alpha BETWEEN 1.665666 AND 1.667666 OR ' - 'alpha BETWEEN -0.001 AND 0.001) AND ' - "(chemname LIKE '%caffeine%' OR " - "chemname LIKE '%serotonine%') AND " - "(method IN ('single crystal') OR method IS NULL) AND " - "(formula REGEXP ' C[0-9 ]' AND " - "formula REGEXP ' H[0-9 ]' AND " - "formula REGEXP ' Cl[0-9 ]') AND " - "(formula IN ('- C6 H6 -')) AND " - '(file IN (1000000, 3000000)) AND ' - '(cellpressure BETWEEN 999 AND 1001 OR ' - 'cellpressure BETWEEN 1000 AND 1002) AND ' - '(celltemp BETWEEN -0.001 AND 0.001 OR ' - 'celltemp BETWEEN 10.499 AND 10.501) AND ' - "(nel IN (5)) AND (sg IN ('P -1')) AND " - '(vol BETWEEN 99.999 AND 100.001 OR ' - 'vol BETWEEN 120.004 AND 120.006)') + assert q_sql == \ + 'SELECT file, svnrevision FROM data WHERE ' \ + "(status IS NULL OR status != 'retracted') AND " \ + '(a BETWEEN 3.332333 AND 3.334333 OR ' \ + 'a BETWEEN 0.999 AND 1.001) AND ' \ + '(alpha BETWEEN 1.665666 AND 1.667666 OR ' \ + 'alpha BETWEEN -0.001 AND 0.001) AND ' \ + "(chemname LIKE '%caffeine%' OR " \ + "chemname LIKE '%serotonine%') AND " \ + "(method IN ('single crystal') OR method IS NULL) AND " \ + "(formula REGEXP ' C[0-9 ]' AND " \ + "formula REGEXP ' H[0-9 ]' AND " \ + "formula REGEXP ' Cl[0-9 ]') AND " \ + "(formula IN ('- C6 H6 -')) AND " \ + '(file IN (1000000, 3000000)) AND ' \ + '(cellpressure BETWEEN 999 AND 1001 OR ' \ + 'cellpressure BETWEEN 1000 AND 1002) AND ' \ + '(celltemp BETWEEN -0.001 AND 0.001 OR ' \ + 'celltemp BETWEEN 10.499 AND 10.501) AND ' \ + "(nel IN (5)) AND (sg IN ('P -1')) AND " \ + '(vol BETWEEN 99.999 AND 100.001 OR ' \ + 'vol BETWEEN 120.004 AND 120.006)' def test_datatype_checks(self): """Rather complicated, but wide-coverage test for data types, accepted @@ -99,7 +100,7 @@ def test_datatype_checks(self): methods[i]('test', 'test', [values[j]]) except ValueError as exc: message = str(exc) - self.assertEqual(message, messages[results[i][j]]) + assert message == messages[results[i][j]] def test_dbentry_creation(self): """Tests the creation of CodEntry from CodSearchResults.""" @@ -116,25 +117,23 @@ def test_dbentry_creation(self): 'id': '2000000', 'svnrevision': '1234' }]) - self.assertEqual(len(results), 3) - self.assertEqual( - results.at(1).source, { - 'db_name': 'Crystallography Open Database', - 'db_uri': 'http://www.crystallography.net/cod', - 'extras': {}, - 'id': '1000001', - 'license': 'CC0', - 'source_md5': None, - 'uri': 'http://www.crystallography.net/cod/1000001.cif@1234', - 'version': '1234', - } - ) - self.assertEqual([x.source['uri'] for x in results], [ + assert len(results) == 3 + assert results.at(1).source == { + 'db_name': 'Crystallography Open Database', + 'db_uri': 'http://www.crystallography.net/cod', + 'extras': {}, + 'id': '1000001', + 'license': 'CC0', + 'source_md5': None, + 'uri': 'http://www.crystallography.net/cod/1000001.cif@1234', + 'version': '1234', + } + assert [x.source['uri'] for x in results] == [ 'http://www.crystallography.net/cod/1000000.cif', 'http://www.crystallography.net/cod/1000001.cif@1234', 'http://www.crystallography.net/cod/2000000.cif@1234' - ]) + ] - @unittest.skipIf(not has_pycifrw(), 'Unable to import PyCifRW') + @pytest.mark.skipif(not has_pycifrw(), reason='Unable to import PyCifRW') def test_dbentry_to_cif_node(self): """Tests the creation of CifData node from CodEntry.""" from aiida.orm import CifData @@ -144,23 +143,22 @@ def test_dbentry_to_cif_node(self): entry.cif = "data_test _publ_section_title 'Test structure'" cif = entry.get_cif_node() - self.assertEqual(isinstance(cif, CifData), True) - self.assertEqual(cif.get_attribute('md5'), '070711e8e99108aade31d20cd5c94c48') - self.assertEqual( - cif.source, { - 'db_name': 'Crystallography Open Database', - 'db_uri': 'http://www.crystallography.net/cod', - 'id': None, - 'version': None, - 'extras': {}, - 'source_md5': '070711e8e99108aade31d20cd5c94c48', - 'uri': 'http://www.crystallography.net/cod/1000000.cif', - 'license': 'CC0', - } - ) - - -class TestTcodDbImporter(AiidaTestCase): + assert isinstance(cif, CifData) is True + assert cif.get_attribute('md5') == '070711e8e99108aade31d20cd5c94c48' + assert cif.source == { + 'db_name': 'Crystallography Open Database', + 'db_uri': 'http://www.crystallography.net/cod', + 'id': None, + 'version': None, + 'extras': {}, + 'source_md5': '070711e8e99108aade31d20cd5c94c48', + 'uri': 'http://www.crystallography.net/cod/1000000.cif', + 'license': 'CC0', + } + + +@pytest.mark.usefixtures('clear_database_before_test_class') +class TestTcodDbImporter: """Test the TcodDbImporter class.""" def test_dbentry_creation(self): @@ -177,26 +175,25 @@ def test_dbentry_creation(self): 'id': '20000000', 'svnrevision': '1234' }]) - self.assertEqual(len(results), 3) - self.assertEqual( - results.at(1).source, { - 'db_name': 'Theoretical Crystallography Open Database', - 'db_uri': 'http://www.crystallography.net/tcod', - 'extras': {}, - 'id': '10000001', - 'license': 'CC0', - 'source_md5': None, - 'uri': 'http://www.crystallography.net/tcod/10000001.cif@1234', - 'version': '1234', - } - ) - self.assertEqual([x.source['uri'] for x in results], [ + assert len(results) == 3 + assert results.at(1).source == { + 'db_name': 'Theoretical Crystallography Open Database', + 'db_uri': 'http://www.crystallography.net/tcod', + 'extras': {}, + 'id': '10000001', + 'license': 'CC0', + 'source_md5': None, + 'uri': 'http://www.crystallography.net/tcod/10000001.cif@1234', + 'version': '1234', + } + assert [x.source['uri'] for x in results] == [ 'http://www.crystallography.net/tcod/10000000.cif', 'http://www.crystallography.net/tcod/10000001.cif@1234', 'http://www.crystallography.net/tcod/20000000.cif@1234' - ]) + ] -class TestPcodDbImporter(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test_class') +class TestPcodDbImporter: """Test the PcodDbImporter class.""" def test_dbentry_creation(self): @@ -204,22 +201,21 @@ def test_dbentry_creation(self): from aiida.tools.dbimporters.plugins.pcod import PcodSearchResults results = PcodSearchResults([{'id': '12345678'}]) - self.assertEqual(len(results), 1) - self.assertEqual( - results.at(0).source, { - 'db_name': 'Predicted Crystallography Open Database', - 'db_uri': 'http://www.crystallography.net/pcod', - 'extras': {}, - 'id': '12345678', - 'license': 'CC0', - 'source_md5': None, - 'uri': 'http://www.crystallography.net/pcod/cif/1/123/12345678.cif', - 'version': None, - } - ) - - -class TestMpodDbImporter(AiidaTestCase): + assert len(results) == 1 + assert results.at(0).source == { + 'db_name': 'Predicted Crystallography Open Database', + 'db_uri': 'http://www.crystallography.net/pcod', + 'extras': {}, + 'id': '12345678', + 'license': 'CC0', + 'source_md5': None, + 'uri': 'http://www.crystallography.net/pcod/cif/1/123/12345678.cif', + 'version': None, + } + + +@pytest.mark.usefixtures('clear_database_before_test_class') +class TestMpodDbImporter: """Test the MpodDbImporter class.""" def test_dbentry_creation(self): @@ -227,22 +223,21 @@ def test_dbentry_creation(self): from aiida.tools.dbimporters.plugins.mpod import MpodSearchResults results = MpodSearchResults([{'id': '1234567'}]) - self.assertEqual(len(results), 1) - self.assertEqual( - results.at(0).source, { - 'db_name': 'Material Properties Open Database', - 'db_uri': 'http://mpod.cimav.edu.mx', - 'extras': {}, - 'id': '1234567', - 'license': None, - 'source_md5': None, - 'uri': 'http://mpod.cimav.edu.mx/datafiles/1234567.mpod', - 'version': None, - } - ) - - -class TestNnincDbImporter(AiidaTestCase): + assert len(results) == 1 + assert results.at(0).source == { + 'db_name': 'Material Properties Open Database', + 'db_uri': 'http://mpod.cimav.edu.mx', + 'extras': {}, + 'id': '1234567', + 'license': None, + 'source_md5': None, + 'uri': 'http://mpod.cimav.edu.mx/datafiles/1234567.mpod', + 'version': None, + } + + +@pytest.mark.usefixtures('clear_database_before_test_class') +class TestNnincDbImporter: """Test the UpfEntry class.""" def test_upfentry_creation(self): @@ -262,7 +257,7 @@ def test_upfentry_creation(self): entry._contents = fpntr.read() # pylint: disable=protected-access upfnode = entry.get_upf_node() - self.assertEqual(upfnode.element, 'Ba') + assert upfnode.element == 'Ba' entry.source = {'id': 'O.pbesol-n-rrkjus_psl.0.1-tested-pslib030.UPF'} @@ -270,5 +265,5 @@ def test_upfentry_creation(self): # thus UpfData parser will complain about the mismatch of chemical # element, mentioned in file name, and the one described in the # pseudopotential file. - with self.assertRaises(ParsingError): + with pytest.raises(ParsingError): upfnode = entry.get_upf_node() diff --git a/tests/test_generic.py b/tests/test_generic.py index 0a05834d67..731dc0a074 100644 --- a/tests/test_generic.py +++ b/tests/test_generic.py @@ -7,96 +7,98 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=no-self-use """Generic tests that need the use of the DB.""" +import pytest -from aiida.backends.testbase import AiidaTestCase from aiida import orm -class TestCode(AiidaTestCase): - """Test the Code class.""" +@pytest.mark.usefixtures('clear_database_before_test') +def test_code_local(aiida_localhost): + """Test local code.""" + import tempfile - def test_code_local(self): - """Test local code.""" - import tempfile + from aiida.orm import Code + from aiida.common.exceptions import ValidationError - from aiida.orm import Code - from aiida.common.exceptions import ValidationError - - code = Code(local_executable='test.sh') - with self.assertRaises(ValidationError): - # No file with name test.sh - code.store() + code = Code(local_executable='test.sh') + with pytest.raises(ValidationError): + # No file with name test.sh + code.store() - with tempfile.NamedTemporaryFile(mode='w+') as fhandle: - fhandle.write('#/bin/bash\n\necho test run\n') - fhandle.flush() - code.put_object_from_filelike(fhandle, 'test.sh') + with tempfile.NamedTemporaryFile(mode='w+') as fhandle: + fhandle.write('#/bin/bash\n\necho test run\n') + fhandle.flush() + code.put_object_from_filelike(fhandle, 'test.sh') - code.store() - self.assertTrue(code.can_run_on(self.computer)) - self.assertTrue(code.get_local_executable(), 'test.sh') - self.assertTrue(code.get_execname(), 'stest.sh') + code.store() + assert code.can_run_on(aiida_localhost) + assert code.get_local_executable(), 'test.sh' + assert code.get_execname(), 'stest.sh' - def test_remote(self): - """Test remote code.""" - import tempfile - from aiida.orm import Code - from aiida.common.exceptions import ValidationError +@pytest.mark.usefixtures('clear_database_before_test') +def test_code_remote(aiida_localhost): + """Test remote code.""" + import tempfile - with self.assertRaises(ValueError): - # remote_computer_exec has length 2 but is not a list or tuple - Code(remote_computer_exec='ab') + from aiida.orm import Code + from aiida.common.exceptions import ValidationError - # invalid code path - with self.assertRaises(ValueError): - Code(remote_computer_exec=(self.computer, '')) + with pytest.raises(ValueError): + # remote_computer_exec has length 2 but is not a list or tuple + Code(remote_computer_exec='ab') - # Relative path is invalid for remote code - with self.assertRaises(ValueError): - Code(remote_computer_exec=(self.computer, 'subdir/run.exe')) + # invalid code path + with pytest.raises(ValueError): + Code(remote_computer_exec=(aiida_localhost, '')) - # first argument should be a computer, not a string - with self.assertRaises(TypeError): - Code(remote_computer_exec=('localhost', '/bin/ls')) + # Relative path is invalid for remote code + with pytest.raises(ValueError): + Code(remote_computer_exec=(aiida_localhost, 'subdir/run.exe')) - code = Code(remote_computer_exec=(self.computer, '/bin/ls')) - with tempfile.NamedTemporaryFile(mode='w+') as fhandle: - fhandle.write('#/bin/bash\n\necho test run\n') - fhandle.flush() - code.put_object_from_filelike(fhandle, 'test.sh') + # first argument should be a computer, not a string + with pytest.raises(TypeError): + Code(remote_computer_exec=('localhost', '/bin/ls')) - with self.assertRaises(ValidationError): - # There are files inside - code.store() + code = Code(remote_computer_exec=(aiida_localhost, '/bin/ls')) + with tempfile.NamedTemporaryFile(mode='w+') as fhandle: + fhandle.write('#/bin/bash\n\necho test run\n') + fhandle.flush() + code.put_object_from_filelike(fhandle, 'test.sh') - # If there are no files, I can store - code.delete_object('test.sh') + with pytest.raises(ValidationError): + # There are files inside code.store() - self.assertEqual(code.get_remote_computer().pk, self.computer.pk) # pylint: disable=no-member - self.assertEqual(code.get_remote_exec_path(), '/bin/ls') - self.assertEqual(code.get_execname(), '/bin/ls') + # If there are no files, I can store + code.delete_object('test.sh') + code.store() + + assert code.get_remote_computer().pk == aiida_localhost.pk # pylint: disable=no-member + assert code.get_remote_exec_path() == '/bin/ls' + assert code.get_execname() == '/bin/ls' - self.assertTrue(code.can_run_on(self.computer)) - othercomputer = orm.Computer( - label='another_localhost', - hostname='localhost', - transport_type='local', - scheduler_type='pbspro', - workdir='/tmp/aiida' - ).store() - self.assertFalse(code.can_run_on(othercomputer)) + assert code.can_run_on(aiida_localhost) + othercomputer = orm.Computer( + label='another_localhost', + hostname='localhost', + transport_type='local', + scheduler_type='pbspro', + workdir='/tmp/aiida' + ).store() + assert not code.can_run_on(othercomputer) -class TestBool(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test_class') +class TestBool: """Test AiiDA Bool class.""" def test_bool_conversion(self): for val in [True, False]: - self.assertEqual(val, bool(orm.Bool(val))) + assert val == bool(orm.Bool(val)) def test_int_conversion(self): for val in [True, False]: - self.assertEqual(int(val), int(orm.Bool(val))) + assert int(val) == int(orm.Bool(val)) diff --git a/tests/test_nodes.py b/tests/test_nodes.py index 530ecad83b..f7bb23c0a1 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -7,29 +7,31 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=too-many-lines,invalid-name,protected-access +# pylint: disable=no-self-use,too-many-lines,invalid-name,protected-access # pylint: disable=missing-docstring,too-many-locals,too-many-statements # pylint: disable=too-many-public-methods import copy import io import tempfile +import pytest + from aiida import orm -from aiida.backends.testbase import AiidaTestCase from aiida.common.exceptions import InvalidOperation, ModificationNotAllowed, StoringNotAllowed, ValidationError from aiida.common.links import LinkType from aiida.tools import delete_nodes, delete_group_nodes -class TestNodeIsStorable(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test_class') +class TestNodeIsStorable: """Test that checks on storability of certain node sub classes work correctly.""" def test_base_classes(self): """Test storability of `Node` base sub classes.""" - with self.assertRaises(StoringNotAllowed): + with pytest.raises(StoringNotAllowed): orm.Node().store() - with self.assertRaises(StoringNotAllowed): + with pytest.raises(StoringNotAllowed): orm.ProcessNode().store() # The following base classes are storable @@ -43,27 +45,29 @@ def test_unregistered_sub_class(self): class SubData(orm.Data): pass - with self.assertRaises(StoringNotAllowed): + with pytest.raises(StoringNotAllowed): SubData().store() -class TestNodeCopyDeepcopy(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test_class') +class TestNodeCopyDeepcopy: """Test that calling copy and deepcopy on a Node does the right thing.""" def test_copy_not_supported(self): """Copying a base Node instance is not supported.""" node = orm.Node() - with self.assertRaises(InvalidOperation): + with pytest.raises(InvalidOperation): copy.copy(node) def test_deepcopy_not_supported(self): """Deep copying a base Node instance is not supported.""" node = orm.Node() - with self.assertRaises(InvalidOperation): + with pytest.raises(InvalidOperation): copy.deepcopy(node) -class TestNodeHashing(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test_class') +class TestNodeHashing: """ Tests the functionality of hashing a node """ @@ -95,8 +99,8 @@ def test_node_uuid_hashing_for_querybuidler(self): # Check that the query doesn't fail qb.all() # And that the results are correct - self.assertEqual(qb.count(), 1) - self.assertEqual(qb.first()[0], n.id) + assert qb.count() == 1 + assert qb.first()[0] == n.id @staticmethod def create_folderdata_with_empty_file(): @@ -127,11 +131,12 @@ def test_updatable_attributes(self): hash1 = node.get_hash() node.set_process_state('finished') hash2 = node.get_hash() - self.assertNotEqual(hash1, None) - self.assertEqual(hash1, hash2) + assert hash1 is not None + assert hash1 == hash2 -class TestTransitiveNoLoops(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test_class') +class TestTransitiveNoLoops: """ Test the transitive closure functionality """ @@ -148,11 +153,12 @@ def test_loop_not_allowed(self): c2.add_incoming(d2, link_type=LinkType.INPUT_CALC, link_label='link') c2.store() - with self.assertRaises(ValueError): # This would generate a loop + with pytest.raises(ValueError): # This would generate a loop d1.add_incoming(c2, link_type=LinkType.CREATE, link_label='link') -class TestTypes(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test_class') +class TestTypes: """ Generic test class to test types """ @@ -166,68 +172,68 @@ def test_uuid_type(self): results = orm.QueryBuilder().append(orm.Data, project=('uuid', '*')).all() for uuid, data in results: - self.assertTrue(isinstance(uuid, str)) - self.assertTrue(isinstance(data.uuid, str)) + assert isinstance(uuid, str) + assert isinstance(data.uuid, str) -class TestQueryWithAiidaObjects(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test') +def test_query_with_subclasses(aiida_localhost): """ Test if queries work properly also with aiida.orm.Node classes instead of aiida.backends.djsite.db.models.DbNode objects. """ - - def test_with_subclasses(self): - from aiida.plugins import DataFactory - - extra_name = f'{self.__class__.__name__}/test_with_subclasses' - - Dict = DataFactory('dict') - - a1 = orm.CalcJobNode(computer=self.computer) - a1.set_option('resources', {'num_machines': 1, 'num_mpiprocs_per_machine': 1}) - a1.store() - # To query only these nodes later - a1.set_extra(extra_name, True) - a3 = orm.Data().store() - a3.set_extra(extra_name, True) - a4 = Dict(dict={'a': 'b'}).store() - a4.set_extra(extra_name, True) - # I don't set the extras, just to be sure that the filtering works - # The filtering is needed because other tests will put stuff int he DB - a6 = orm.CalcJobNode(computer=self.computer) - a6.set_option('resources', {'num_machines': 1, 'num_mpiprocs_per_machine': 1}) - a1.store() - a7 = orm.Data() - a7.store() - - # Query by calculation - qb = orm.QueryBuilder() - qb.append(orm.CalcJobNode, filters={'extras': {'has_key': extra_name}}) - results = qb.all(flat=True) - # a3, a4 should not be found because they are not CalcJobNodes. - # a6, a7 should not be found because they have not the attribute set. - self.assertEqual({i.pk for i in results}, {a1.pk}) - - # Same query, but by the generic Node class - qb = orm.QueryBuilder() - qb.append(orm.Node, filters={'extras': {'has_key': extra_name}}) - results = qb.all(flat=True) - self.assertEqual({i.pk for i in results}, {a1.pk, a3.pk, a4.pk}) - - # Same query, but by the Data class - qb = orm.QueryBuilder() - qb.append(orm.Data, filters={'extras': {'has_key': extra_name}}) - results = qb.all(flat=True) - self.assertEqual({i.pk for i in results}, {a3.pk, a4.pk}) - - # Same query, but by the Dict subclass - qb = orm.QueryBuilder() - qb.append(orm.Dict, filters={'extras': {'has_key': extra_name}}) - results = qb.all(flat=True) - self.assertEqual({i.pk for i in results}, {a4.pk}) - - -class TestNodeBasic(AiidaTestCase): + from aiida.plugins import DataFactory + + extra_name = 'test_query_with_subclasses' + + Dict = DataFactory('dict') + + a1 = orm.CalcJobNode(computer=aiida_localhost) + a1.set_option('resources', {'num_machines': 1, 'num_mpiprocs_per_machine': 1}) + a1.store() + # To query only these nodes later + a1.set_extra(extra_name, True) + a3 = orm.Data().store() + a3.set_extra(extra_name, True) + a4 = Dict(dict={'a': 'b'}).store() + a4.set_extra(extra_name, True) + # I don't set the extras, just to be sure that the filtering works + # The filtering is needed because other tests will put stuff int he DB + a6 = orm.CalcJobNode(computer=aiida_localhost) + a6.set_option('resources', {'num_machines': 1, 'num_mpiprocs_per_machine': 1}) + a1.store() + a7 = orm.Data() + a7.store() + + # Query by calculation + qb = orm.QueryBuilder() + qb.append(orm.CalcJobNode, filters={'extras': {'has_key': extra_name}}) + results = qb.all(flat=True) + # a3, a4 should not be found because they are not CalcJobNodes. + # a6, a7 should not be found because they have not the attribute set. + assert {i.pk for i in results} == {a1.pk} + + # Same query, but by the generic Node class + qb = orm.QueryBuilder() + qb.append(orm.Node, filters={'extras': {'has_key': extra_name}}) + results = qb.all(flat=True) + assert {i.pk for i in results} == {a1.pk, a3.pk, a4.pk} + + # Same query, but by the Data class + qb = orm.QueryBuilder() + qb.append(orm.Data, filters={'extras': {'has_key': extra_name}}) + results = qb.all(flat=True) + assert {i.pk for i in results} == {a3.pk, a4.pk} + + # Same query, but by the Dict subclass + qb = orm.QueryBuilder() + qb.append(orm.Dict, filters={'extras': {'has_key': extra_name}}) + results = qb.all(flat=True) + assert {i.pk for i in results} == {a4.pk} + + +@pytest.mark.usefixtures('clear_database_before_test_class') +class TestNodeBasic: """ These tests check the basic features of nodes (setting of attributes, copying of files, ...) @@ -258,6 +264,20 @@ class TestNodeBasic(AiidaTestCase): emptydict = {} emptylist = [] + def setup_method(self): + """Add a computer.""" + # pylint: disable=attribute-defined-outside-init + if not hasattr(self, 'computer'): + created, self.computer = orm.Computer.objects.get_or_create( + label='localhost', + hostname='localhost', + transport_type='local', + scheduler_type='direct', + workdir='/tmp/aiida', + ) + if created: + self.computer.store() + def test_uuid_uniquess(self): """ A uniqueness constraint on the UUID column of the Node model should prevent multiple nodes with identical UUID @@ -270,7 +290,7 @@ def test_uuid_uniquess(self): b.backend_entity.dbmodel.uuid = a.uuid a.store() - with self.assertRaises((DjIntegrityError, SqlaIntegrityError)): + with pytest.raises((DjIntegrityError, SqlaIntegrityError)): b.store() def test_attribute_mutability(self): @@ -284,10 +304,10 @@ def test_attribute_mutability(self): a.store() # After storing attributes should now be immutable - with self.assertRaises(ModificationNotAllowed): + with pytest.raises(ModificationNotAllowed): a.delete_attribute('bool') - with self.assertRaises(ModificationNotAllowed): + with pytest.raises(ModificationNotAllowed): a.set_attribute('integer', self.intval) def test_attr_before_storing(self): @@ -303,15 +323,15 @@ def test_attr_before_storing(self): a.set_attribute('k9', None) # Now I check if I can retrieve them, before the storage - self.assertEqual(self.boolval, a.get_attribute('k1')) - self.assertEqual(self.intval, a.get_attribute('k2')) - self.assertEqual(self.floatval, a.get_attribute('k3')) - self.assertEqual(self.stringval, a.get_attribute('k4')) - self.assertEqual(self.dictval, a.get_attribute('k5')) - self.assertEqual(self.listval, a.get_attribute('k6')) - self.assertEqual(self.emptydict, a.get_attribute('k7')) - self.assertEqual(self.emptylist, a.get_attribute('k8')) - self.assertIsNone(a.get_attribute('k9')) + assert self.boolval == a.get_attribute('k1') + assert self.intval == a.get_attribute('k2') + assert self.floatval == a.get_attribute('k3') + assert self.stringval == a.get_attribute('k4') + assert self.dictval == a.get_attribute('k5') + assert self.listval == a.get_attribute('k6') + assert self.emptydict == a.get_attribute('k7') + assert self.emptylist == a.get_attribute('k8') + assert a.get_attribute('k9') is None # And now I try to delete the keys a.delete_attribute('k1') @@ -324,19 +344,19 @@ def test_attr_before_storing(self): a.delete_attribute('k8') a.delete_attribute('k9') - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): # I delete twice the same attribute a.delete_attribute('k1') - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): # I delete a non-existing attribute a.delete_attribute('nonexisting') - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): # I get a deleted attribute a.get_attribute('k1') - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): # I get a non-existing attribute a.get_attribute('nonexisting') @@ -365,7 +385,7 @@ def test_get_attrs_before_storing(self): } # Now I check if I can retrieve them, before the storage - self.assertEqual(a.attributes, target_attrs) + assert a.attributes == target_attrs # And now I try to delete the keys a.delete_attribute('k1') @@ -378,7 +398,7 @@ def test_get_attrs_before_storing(self): a.delete_attribute('k8') a.delete_attribute('k9') - self.assertEqual(a.attributes, {}) + assert a.attributes == {} def test_get_attrs_after_storing(self): a = orm.Data() @@ -407,19 +427,19 @@ def test_get_attrs_after_storing(self): } # Now I check if I can retrieve them, before the storage - self.assertEqual(a.attributes, target_attrs) + assert a.attributes == target_attrs def test_store_object(self): """Trying to set objects as attributes should fail, because they are not json-serializable.""" a = orm.Data() a.set_attribute('object', object()) - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): a.store() b = orm.Data() b.set_attribute('object_list', [object(), object()]) - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): b.store() def test_attributes_on_clone(self): @@ -450,9 +470,9 @@ def test_attributes_on_clone(self): b_expected_attributes['new'] = 'cvb' # I check before storing that the attributes are ok - self.assertEqual(b.attributes, b_expected_attributes) + assert b.attributes == b_expected_attributes # Note that during copy, I do not copy the extras! - self.assertEqual(b.extras, {}) + assert b.extras == {} # I store now b.store() @@ -461,11 +481,11 @@ def test_attributes_on_clone(self): b_expected_extras = {'meta': 'textofext', '_aiida_hash': AnyValue()} # Now I check that the attributes of the original node have not changed - self.assertEqual(a.attributes, attrs_to_set) + assert a.attributes == attrs_to_set # I check then on the 'b' copy - self.assertEqual(b.attributes, b_expected_attributes) - self.assertEqual(b.extras, b_expected_extras) + assert b.attributes == b_expected_attributes + assert b.extras == b_expected_extras def test_files(self): a = orm.Data() @@ -479,21 +499,21 @@ def test_files(self): a.put_object_from_file(handle.name, 'file1.txt') a.put_object_from_file(handle.name, 'file2.txt') - self.assertEqual(set(a.list_object_names()), set(['file1.txt', 'file2.txt'])) + assert set(a.list_object_names()) == set(['file1.txt', 'file2.txt']) with a.open('file1.txt') as fhandle: - self.assertEqual(fhandle.read(), file_content) + assert fhandle.read() == file_content with a.open('file2.txt') as fhandle: - self.assertEqual(fhandle.read(), file_content) + assert fhandle.read() == file_content b = a.clone() - self.assertNotEqual(a.uuid, b.uuid) + assert a.uuid != b.uuid # Check that the content is there - self.assertEqual(set(b.list_object_names()), set(['file1.txt', 'file2.txt'])) + assert set(b.list_object_names()) == set(['file1.txt', 'file2.txt']) with b.open('file1.txt') as handle: - self.assertEqual(handle.read(), file_content) + assert handle.read() == file_content with b.open('file2.txt') as handle: - self.assertEqual(handle.read(), file_content) + assert handle.read() == file_content # I overwrite a file and create a new one in the clone only with tempfile.NamedTemporaryFile(mode='w+') as handle: @@ -503,18 +523,18 @@ def test_files(self): b.put_object_from_file(handle.name, 'file3.txt') # I check the new content, and that the old one has not changed - self.assertEqual(set(a.list_object_names()), set(['file1.txt', 'file2.txt'])) + assert set(a.list_object_names()) == set(['file1.txt', 'file2.txt']) with a.open('file1.txt') as handle: - self.assertEqual(handle.read(), file_content) + assert handle.read() == file_content with a.open('file2.txt') as handle: - self.assertEqual(handle.read(), file_content) - self.assertEqual(set(b.list_object_names()), set(['file1.txt', 'file2.txt', 'file3.txt'])) + assert handle.read() == file_content + assert set(b.list_object_names()) == set(['file1.txt', 'file2.txt', 'file3.txt']) with b.open('file1.txt') as handle: - self.assertEqual(handle.read(), file_content) + assert handle.read() == file_content with b.open('file2.txt') as handle: - self.assertEqual(handle.read(), file_content_different) + assert handle.read() == file_content_different with b.open('file3.txt') as handle: - self.assertEqual(handle.read(), file_content_different) + assert handle.read() == file_content_different # This should in principle change the location of the files, # so I recheck @@ -529,19 +549,19 @@ def test_files(self): c.put_object_from_file(handle.name, 'file1.txt') c.put_object_from_file(handle.name, 'file4.txt') - self.assertEqual(set(a.list_object_names()), set(['file1.txt', 'file2.txt'])) + assert set(a.list_object_names()) == set(['file1.txt', 'file2.txt']) with a.open('file1.txt') as handle: - self.assertEqual(handle.read(), file_content) + assert handle.read() == file_content with a.open('file2.txt') as handle: - self.assertEqual(handle.read(), file_content) + assert handle.read() == file_content - self.assertEqual(set(c.list_object_names()), set(['file1.txt', 'file2.txt', 'file4.txt'])) + assert set(c.list_object_names()) == set(['file1.txt', 'file2.txt', 'file4.txt']) with c.open('file1.txt') as handle: - self.assertEqual(handle.read(), file_content_different) + assert handle.read() == file_content_different with c.open('file2.txt') as handle: - self.assertEqual(handle.read(), file_content) + assert handle.read() == file_content with c.open('file4.txt') as handle: - self.assertEqual(handle.read(), file_content_different) + assert handle.read() == file_content_different def test_folders(self): """ @@ -580,30 +600,30 @@ def test_folders(self): a.put_object_from_tree(tree_1, 'tree_1') # verify if the node has the structure I expect - self.assertEqual(set(a.list_object_names()), set(['tree_1'])) - self.assertEqual(set(a.list_object_names('tree_1')), set(['file1.txt', 'dir1'])) - self.assertEqual(set(a.list_object_names(os.path.join('tree_1', 'dir1'))), set(['dir2', 'file2.txt'])) + assert set(a.list_object_names()) == set(['tree_1']) + assert set(a.list_object_names('tree_1')) == set(['file1.txt', 'dir1']) + assert set(a.list_object_names(os.path.join('tree_1', 'dir1'))) == set(['dir2', 'file2.txt']) with a.open(os.path.join('tree_1', 'file1.txt')) as fhandle: - self.assertEqual(fhandle.read(), file_content) + assert fhandle.read() == file_content with a.open(os.path.join('tree_1', 'dir1', 'file2.txt')) as fhandle: - self.assertEqual(fhandle.read(), file_content) + assert fhandle.read() == file_content # try to exit from the folder - with self.assertRaises(ValueError): + with pytest.raises(ValueError): a.list_object_names('..') # clone into a new node b = a.clone() - self.assertNotEqual(a.uuid, b.uuid) + assert a.uuid != b.uuid # Check that the content is there - self.assertEqual(set(b.list_object_names('.')), set(['tree_1'])) - self.assertEqual(set(b.list_object_names('tree_1')), set(['file1.txt', 'dir1'])) - self.assertEqual(set(b.list_object_names(os.path.join('tree_1', 'dir1'))), set(['dir2', 'file2.txt'])) + assert set(b.list_object_names('.')) == set(['tree_1']) + assert set(b.list_object_names('tree_1')) == set(['file1.txt', 'dir1']) + assert set(b.list_object_names(os.path.join('tree_1', 'dir1'))) == set(['dir2', 'file2.txt']) with b.open(os.path.join('tree_1', 'file1.txt')) as fhandle: - self.assertEqual(fhandle.read(), file_content) + assert fhandle.read() == file_content with b.open(os.path.join('tree_1', 'dir1', 'file2.txt')) as fhandle: - self.assertEqual(fhandle.read(), file_content) + assert fhandle.read() == file_content # I overwrite a file and create a new one in the copy only dir3 = os.path.join(directory, 'dir3') @@ -611,28 +631,28 @@ def test_folders(self): b.put_object_from_tree(dir3, os.path.join('tree_1', 'dir3')) # no absolute path here - with self.assertRaises(ValueError): + with pytest.raises(ValueError): b.put_object_from_tree('dir3', os.path.join('tree_1', 'dir3')) stream = io.StringIO(file_content_different) b.put_object_from_filelike(stream, 'file3.txt') # I check the new content, and that the old one has not changed old - self.assertEqual(set(a.list_object_names('.')), set(['tree_1'])) - self.assertEqual(set(a.list_object_names('tree_1')), set(['file1.txt', 'dir1'])) - self.assertEqual(set(a.list_object_names(os.path.join('tree_1', 'dir1'))), set(['dir2', 'file2.txt'])) + assert set(a.list_object_names('.')) == set(['tree_1']) + assert set(a.list_object_names('tree_1')) == set(['file1.txt', 'dir1']) + assert set(a.list_object_names(os.path.join('tree_1', 'dir1'))) == set(['dir2', 'file2.txt']) with a.open(os.path.join('tree_1', 'file1.txt')) as fhandle: - self.assertEqual(fhandle.read(), file_content) + assert fhandle.read() == file_content with a.open(os.path.join('tree_1', 'dir1', 'file2.txt')) as fhandle: - self.assertEqual(fhandle.read(), file_content) + assert fhandle.read() == file_content # new - self.assertEqual(set(b.list_object_names('.')), set(['tree_1', 'file3.txt'])) - self.assertEqual(set(b.list_object_names('tree_1')), set(['file1.txt', 'dir1', 'dir3'])) - self.assertEqual(set(b.list_object_names(os.path.join('tree_1', 'dir1'))), set(['dir2', 'file2.txt'])) + assert set(b.list_object_names('.')) == set(['tree_1', 'file3.txt']) + assert set(b.list_object_names('tree_1')) == set(['file1.txt', 'dir1', 'dir3']) + assert set(b.list_object_names(os.path.join('tree_1', 'dir1'))) == set(['dir2', 'file2.txt']) with b.open(os.path.join('tree_1', 'file1.txt')) as fhandle: - self.assertEqual(fhandle.read(), file_content) + assert fhandle.read() == file_content with b.open(os.path.join('tree_1', 'dir1', 'file2.txt')) as fhandle: - self.assertEqual(fhandle.read(), file_content) + assert fhandle.read() == file_content # This should in principle change the location of the files, so I recheck a.store() @@ -647,22 +667,22 @@ def test_folders(self): c.delete_object(os.path.join('tree_1', 'dir1', 'dir2')) # check old - self.assertEqual(set(a.list_object_names('.')), set(['tree_1'])) - self.assertEqual(set(a.list_object_names('tree_1')), set(['file1.txt', 'dir1'])) - self.assertEqual(set(a.list_object_names(os.path.join('tree_1', 'dir1'))), set(['dir2', 'file2.txt'])) + assert set(a.list_object_names('.')) == set(['tree_1']) + assert set(a.list_object_names('tree_1')) == set(['file1.txt', 'dir1']) + assert set(a.list_object_names(os.path.join('tree_1', 'dir1'))) == set(['dir2', 'file2.txt']) with a.open(os.path.join('tree_1', 'file1.txt')) as fhandle: - self.assertEqual(fhandle.read(), file_content) + assert fhandle.read() == file_content with a.open(os.path.join('tree_1', 'dir1', 'file2.txt')) as fhandle: - self.assertEqual(fhandle.read(), file_content) + assert fhandle.read() == file_content # check new - self.assertEqual(set(c.list_object_names('.')), set(['tree_1'])) - self.assertEqual(set(c.list_object_names('tree_1')), set(['file1.txt', 'dir1'])) - self.assertEqual(set(c.list_object_names(os.path.join('tree_1', 'dir1'))), set(['file2.txt', 'file4.txt'])) + assert set(c.list_object_names('.')) == set(['tree_1']) + assert set(c.list_object_names('tree_1')) == set(['file1.txt', 'dir1']) + assert set(c.list_object_names(os.path.join('tree_1', 'dir1'))) == set(['file2.txt', 'file4.txt']) with c.open(os.path.join('tree_1', 'file1.txt')) as fhandle: - self.assertEqual(fhandle.read(), file_content_different) + assert fhandle.read() == file_content_different with c.open(os.path.join('tree_1', 'dir1', 'file2.txt')) as fhandle: - self.assertEqual(fhandle.read(), file_content) + assert fhandle.read() == file_content # garbage cleaning shutil.rmtree(directory) @@ -680,13 +700,13 @@ def test_attr_after_storing(self): a.store() # Now I check if I can retrieve them, before the storage - self.assertIsNone(a.get_attribute('none')) - self.assertEqual(self.boolval, a.get_attribute('bool')) - self.assertEqual(self.intval, a.get_attribute('integer')) - self.assertEqual(self.floatval, a.get_attribute('float')) - self.assertEqual(self.stringval, a.get_attribute('string')) - self.assertEqual(self.dictval, a.get_attribute('dict')) - self.assertEqual(self.listval, a.get_attribute('list')) + assert a.get_attribute('none') is None + assert self.boolval == a.get_attribute('bool') + assert self.intval == a.get_attribute('integer') + assert self.floatval == a.get_attribute('float') + assert self.stringval == a.get_attribute('string') + assert self.dictval == a.get_attribute('dict') + assert self.listval == a.get_attribute('list') def test_attr_with_reload(self): a = orm.Data() @@ -701,13 +721,13 @@ def test_attr_with_reload(self): a.store() b = orm.load_node(uuid=a.uuid) - self.assertIsNone(a.get_attribute('none')) - self.assertEqual(self.boolval, b.get_attribute('bool')) - self.assertEqual(self.intval, b.get_attribute('integer')) - self.assertEqual(self.floatval, b.get_attribute('float')) - self.assertEqual(self.stringval, b.get_attribute('string')) - self.assertEqual(self.dictval, b.get_attribute('dict')) - self.assertEqual(self.listval, b.get_attribute('list')) + assert a.get_attribute('none') is None + assert self.boolval == b.get_attribute('bool') + assert self.intval == b.get_attribute('integer') + assert self.floatval == b.get_attribute('float') + assert self.stringval == b.get_attribute('string') + assert self.dictval == b.get_attribute('dict') + assert self.listval == b.get_attribute('list') def test_extra_with_reload(self): a = orm.Data() @@ -720,42 +740,42 @@ def test_extra_with_reload(self): a.set_extra('list', self.listval) # Check before storing - self.assertEqual(self.boolval, a.get_extra('bool')) - self.assertEqual(self.intval, a.get_extra('integer')) - self.assertEqual(self.floatval, a.get_extra('float')) - self.assertEqual(self.stringval, a.get_extra('string')) - self.assertEqual(self.dictval, a.get_extra('dict')) - self.assertEqual(self.listval, a.get_extra('list')) + assert self.boolval == a.get_extra('bool') + assert self.intval == a.get_extra('integer') + assert self.floatval == a.get_extra('float') + assert self.stringval == a.get_extra('string') + assert self.dictval == a.get_extra('dict') + assert self.listval == a.get_extra('list') a.store() # Check after storing - self.assertEqual(self.boolval, a.get_extra('bool')) - self.assertEqual(self.intval, a.get_extra('integer')) - self.assertEqual(self.floatval, a.get_extra('float')) - self.assertEqual(self.stringval, a.get_extra('string')) - self.assertEqual(self.dictval, a.get_extra('dict')) - self.assertEqual(self.listval, a.get_extra('list')) + assert self.boolval == a.get_extra('bool') + assert self.intval == a.get_extra('integer') + assert self.floatval == a.get_extra('float') + assert self.stringval == a.get_extra('string') + assert self.dictval == a.get_extra('dict') + assert self.listval == a.get_extra('list') b = orm.load_node(uuid=a.uuid) - self.assertIsNone(a.get_extra('none')) - self.assertEqual(self.boolval, b.get_extra('bool')) - self.assertEqual(self.intval, b.get_extra('integer')) - self.assertEqual(self.floatval, b.get_extra('float')) - self.assertEqual(self.stringval, b.get_extra('string')) - self.assertEqual(self.dictval, b.get_extra('dict')) - self.assertEqual(self.listval, b.get_extra('list')) + assert a.get_extra('none') is None + assert self.boolval == b.get_extra('bool') + assert self.intval == b.get_extra('integer') + assert self.floatval == b.get_extra('float') + assert self.stringval == b.get_extra('string') + assert self.dictval == b.get_extra('dict') + assert self.listval == b.get_extra('list') def test_get_extras_with_default(self): a = orm.Data() a.store() a.set_extra('a', 'b') - self.assertEqual(a.get_extra('a'), 'b') - with self.assertRaises(AttributeError): + assert a.get_extra('a') == 'b' + with pytest.raises(AttributeError): a.get_extra('c') - self.assertEqual(a.get_extra('c', 'def'), 'def') + assert a.get_extra('c', 'def') == 'def' @staticmethod def test_attr_and_extras_multikey(): @@ -799,12 +819,12 @@ def test_attr_listing(self): all_extras = dict(_aiida_hash=AnyValue(), **extras_to_set) - self.assertEqual(set(list(a.attributes.keys())), set(attrs_to_set.keys())) - self.assertEqual(set(list(a.extras.keys())), set(all_extras.keys())) + assert set(list(a.attributes.keys())) == set(attrs_to_set.keys()) + assert set(list(a.extras.keys())) == set(all_extras.keys()) - self.assertEqual(a.attributes, attrs_to_set) + assert a.attributes == attrs_to_set - self.assertEqual(a.extras, all_extras) + assert a.extras == all_extras def test_delete_extras(self): """ @@ -828,7 +848,7 @@ def test_delete_extras(self): for k, v in extras_to_set.items(): a.set_extra(k, v) - self.assertEqual(a.extras, all_extras) + assert a.extras == all_extras # I pregenerate it, it cannot change during iteration list_keys = list(extras_to_set.keys()) @@ -837,7 +857,7 @@ def test_delete_extras(self): # performed correctly a.delete_extra(k) del all_extras[k] - self.assertEqual(a.extras, all_extras) + assert a.extras == all_extras def test_replace_extras_1(self): """ @@ -880,7 +900,7 @@ def test_replace_extras_1(self): for k, v in extras_to_set.items(): a.set_extra(k, v) - self.assertEqual(a.extras, all_extras) + assert a.extras == all_extras for k, v in new_extras.items(): # I delete one by one the keys and check if the operation is @@ -890,7 +910,7 @@ def test_replace_extras_1(self): # I update extras_to_set with the new entries, and do the comparison # again all_extras.update(new_extras) - self.assertEqual(a.extras, all_extras) + assert a.extras == all_extras def test_basetype_as_attr(self): """ @@ -908,28 +928,28 @@ def test_basetype_as_attr(self): # Manages to store, and value is converted to its base type p = orm.Dict(dict={'b': orm.Str('sometext'), 'c': l1}) p.store() - self.assertEqual(p.get_attribute('b'), 'sometext') - self.assertIsInstance(p.get_attribute('b'), str) - self.assertEqual(p.get_attribute('c'), ['b', [1, 2]]) - self.assertIsInstance(p.get_attribute('c'), (list, tuple)) + assert p.get_attribute('b') == 'sometext' + assert isinstance(p.get_attribute('b'), str) + assert p.get_attribute('c') == ['b', [1, 2]] + assert isinstance(p.get_attribute('c'), (list, tuple)) # Check also before storing n = orm.Data() n.set_attribute('a', orm.Str('sometext2')) n.set_attribute('b', l2) - self.assertEqual(n.get_attribute('a').value, 'sometext2') - self.assertIsInstance(n.get_attribute('a'), orm.Str) - self.assertEqual(n.get_attribute('b').get_list(), ['f', True, {'gg': None}]) - self.assertIsInstance(n.get_attribute('b'), orm.List) + assert n.get_attribute('a').value == 'sometext2' + assert isinstance(n.get_attribute('a'), orm.Str) + assert n.get_attribute('b').get_list() == ['f', True, {'gg': None}] + assert isinstance(n.get_attribute('b'), orm.List) # Check also deep in a dictionary/list n = orm.Data() n.set_attribute('a', {'b': [orm.Str('sometext3')]}) - self.assertEqual(n.get_attribute('a')['b'][0].value, 'sometext3') - self.assertIsInstance(n.get_attribute('a')['b'][0], orm.Str) + assert n.get_attribute('a')['b'][0].value == 'sometext3' + assert isinstance(n.get_attribute('a')['b'][0], orm.Str) n.store() - self.assertEqual(n.get_attribute('a')['b'][0], 'sometext3') - self.assertIsInstance(n.get_attribute('a')['b'][0], str) + assert n.get_attribute('a')['b'][0] == 'sometext3' + assert isinstance(n.get_attribute('a')['b'][0], str) def test_basetype_as_extra(self): """ @@ -950,19 +970,19 @@ def test_basetype_as_extra(self): n.set_extra('a', orm.Str('sometext2')) n.set_extra('c', l1) n.set_extra('d', l2) - self.assertEqual(n.get_extra('a'), 'sometext2') - self.assertIsInstance(n.get_extra('a'), str) - self.assertEqual(n.get_extra('c'), ['b', [1, 2]]) - self.assertIsInstance(n.get_extra('c'), (list, tuple)) - self.assertEqual(n.get_extra('d'), ['f', True, {'gg': None}]) - self.assertIsInstance(n.get_extra('d'), (list, tuple)) + assert n.get_extra('a') == 'sometext2' + assert isinstance(n.get_extra('a'), str) + assert n.get_extra('c') == ['b', [1, 2]] + assert isinstance(n.get_extra('c'), (list, tuple)) + assert n.get_extra('d') == ['f', True, {'gg': None}] + assert isinstance(n.get_extra('d'), (list, tuple)) # Check also deep in a dictionary/list n = orm.Data() n.store() n.set_extra('a', {'b': [orm.Str('sometext3')]}) - self.assertEqual(n.get_extra('a')['b'][0], 'sometext3') - self.assertIsInstance(n.get_extra('a')['b'][0], str) + assert n.get_extra('a')['b'][0] == 'sometext3' + assert isinstance(n.get_extra('a')['b'][0], str) def test_comments(self): # This is the best way to compare dates with the stored ones, instead @@ -975,11 +995,11 @@ def test_comments(self): user = orm.User.objects.get_default() a = orm.Data() - with self.assertRaises(ModificationNotAllowed): + with pytest.raises(ModificationNotAllowed): a.add_comment('text', user=user) a.store() - self.assertEqual(a.get_comments(), []) + assert a.get_comments() == [] before = timezone.now() - timedelta(seconds=1) a.add_comment('text', user=user) @@ -994,13 +1014,13 @@ def test_comments(self): times = [i.ctime for i in comments] for time in times: - self.assertTrue(time > before) - self.assertTrue(time < after) + assert time > before + assert time < after - self.assertEqual([(i.user.email, i.content) for i in comments], [ - (self.user_email, 'text'), - (self.user_email, 'text2'), - ]) + assert [(i.user.email, i.content) for i in comments] == [ + (user.email, 'text'), + (user.email, 'text2'), + ] def test_code_loading_from_string(self): """ @@ -1021,22 +1041,22 @@ def test_code_loading_from_string(self): # Test that the code1 can be loaded correctly with its label q_code_1 = orm.Code.get_from_string(code1.label) - self.assertEqual(q_code_1.id, code1.id) - self.assertEqual(q_code_1.label, code1.label) - self.assertEqual(q_code_1.get_remote_exec_path(), code1.get_remote_exec_path()) + assert q_code_1.id == code1.id + assert q_code_1.label == code1.label + assert q_code_1.get_remote_exec_path() == code1.get_remote_exec_path() # Test that the code2 can be loaded correctly with its label q_code_2 = orm.Code.get_from_string(f'{code2.label}@{self.computer.label}') # pylint: disable=no-member - self.assertEqual(q_code_2.id, code2.id) - self.assertEqual(q_code_2.label, code2.label) - self.assertEqual(q_code_2.get_remote_exec_path(), code2.get_remote_exec_path()) + assert q_code_2.id == code2.id + assert q_code_2.label == code2.label + assert q_code_2.get_remote_exec_path() == code2.get_remote_exec_path() # Calling get_from_string for a non string type raises exception - with self.assertRaises(InputValidationError): + with pytest.raises(InputValidationError): orm.Code.get_from_string(code1.id) # Test that the lookup of a nonexistent code works as expected - with self.assertRaises(NotExistent): + with pytest.raises(NotExistent): orm.Code.get_from_string('nonexistent_code') # Add another code with the label of code1 @@ -1046,7 +1066,7 @@ def test_code_loading_from_string(self): code3.store() # Query with the common label - with self.assertRaises(MultipleObjectsError): + with pytest.raises(MultipleObjectsError): orm.Code.get_from_string(code3.label) def test_code_loading_using_get(self): @@ -1068,30 +1088,30 @@ def test_code_loading_using_get(self): # Test that the code1 can be loaded correctly with its label only q_code_1 = orm.Code.get(label=code1.label) - self.assertEqual(q_code_1.id, code1.id) - self.assertEqual(q_code_1.label, code1.label) - self.assertEqual(q_code_1.get_remote_exec_path(), code1.get_remote_exec_path()) + assert q_code_1.id == code1.id + assert q_code_1.label == code1.label + assert q_code_1.get_remote_exec_path() == code1.get_remote_exec_path() # Test that the code1 can be loaded correctly with its id/pk q_code_1 = orm.Code.get(code1.id) - self.assertEqual(q_code_1.id, code1.id) - self.assertEqual(q_code_1.label, code1.label) - self.assertEqual(q_code_1.get_remote_exec_path(), code1.get_remote_exec_path()) + assert q_code_1.id == code1.id + assert q_code_1.label == code1.label + assert q_code_1.get_remote_exec_path() == code1.get_remote_exec_path() # Test that the code2 can be loaded correctly with its label and computername q_code_2 = orm.Code.get(label=code2.label, machinename=self.computer.label) # pylint: disable=no-member - self.assertEqual(q_code_2.id, code2.id) - self.assertEqual(q_code_2.label, code2.label) - self.assertEqual(q_code_2.get_remote_exec_path(), code2.get_remote_exec_path()) + assert q_code_2.id == code2.id + assert q_code_2.label == code2.label + assert q_code_2.get_remote_exec_path() == code2.get_remote_exec_path() # Test that the code2 can be loaded correctly with its id/pk q_code_2 = orm.Code.get(code2.id) - self.assertEqual(q_code_2.id, code2.id) - self.assertEqual(q_code_2.label, code2.label) - self.assertEqual(q_code_2.get_remote_exec_path(), code2.get_remote_exec_path()) + assert q_code_2.id == code2.id + assert q_code_2.label == code2.label + assert q_code_2.get_remote_exec_path() == code2.get_remote_exec_path() # Test that the lookup of a nonexistent code works as expected - with self.assertRaises(NotExistent): + with pytest.raises(NotExistent): orm.Code.get(label='nonexistent_code') # Add another code with the label of code1 @@ -1101,7 +1121,7 @@ def test_code_loading_using_get(self): code3.store() # Query with the common label - with self.assertRaises(MultipleObjectsError): + with pytest.raises(MultipleObjectsError): orm.Code.get(label=code3.label) # Add another code whose label is equal to pk of another code @@ -1115,9 +1135,9 @@ def test_code_loading_using_get(self): # Code.get(pk_label_duplicate) should return code1, as the pk takes # precedence q_code_4 = orm.Code.get(code4.label) - self.assertEqual(q_code_4.id, code1.id) - self.assertEqual(q_code_4.label, code1.label) - self.assertEqual(q_code_4.get_remote_exec_path(), code1.get_remote_exec_path()) + assert q_code_4.id == code1.id + assert q_code_4.label == code1.label + assert q_code_4.get_remote_exec_path() == code1.get_remote_exec_path() def test_code_description(self): """ @@ -1132,10 +1152,10 @@ def test_code_description(self): code.store() q_code1 = orm.Code.get(label=code.label) - self.assertEqual(code.description, str(q_code1.description)) + assert code.description == str(q_code1.description) q_code2 = orm.Code.get(code.id) - self.assertEqual(code.description, str(q_code2.description)) + assert code.description == str(q_code2.description) def test_list_for_plugin(self): """ @@ -1154,10 +1174,10 @@ def test_list_for_plugin(self): code2.store() retrieved_pks = set(orm.Code.list_for_plugin('plugin_name', labels=False)) - self.assertEqual(retrieved_pks, set([code1.pk, code2.pk])) + assert retrieved_pks == set([code1.pk, code2.pk]) retrieved_labels = set(orm.Code.list_for_plugin('plugin_name', labels=True)) - self.assertEqual(retrieved_labels, set([code1.label, code2.label])) + assert retrieved_labels == set([code1.label, code2.label]) def test_load_node(self): """ @@ -1169,38 +1189,53 @@ def test_load_node(self): node = orm.Data().store() uuid_stored = node.uuid # convenience to store the uuid # Simple test to see whether I load correctly from the pk: - self.assertEqual(uuid_stored, orm.load_node(pk=node.pk).uuid) + assert uuid_stored == orm.load_node(pk=node.pk).uuid # Testing the loading with the uuid: - self.assertEqual(uuid_stored, orm.load_node(uuid=uuid_stored).uuid) + assert uuid_stored == orm.load_node(uuid=uuid_stored).uuid # Here I'm testing whether loading the node with the beginnings of a uuid works for i in range(10, len(uuid_stored), 2): start_uuid = uuid_stored[:i] - self.assertEqual(uuid_stored, orm.load_node(uuid=start_uuid).uuid) + assert uuid_stored == orm.load_node(uuid=start_uuid).uuid # Testing whether loading the node with part of UUID works, removing the dashes for i in range(10, len(uuid_stored), 2): start_uuid = uuid_stored[:i].replace('-', '') - self.assertEqual(uuid_stored, orm.load_node(uuid=start_uuid).uuid) + assert uuid_stored == orm.load_node(uuid=start_uuid).uuid # If I don't allow load_node to fix the dashes, this has to raise: - with self.assertRaises(NotExistent): + with pytest.raises(NotExistent): orm.load_node(uuid=start_uuid, query_with_dashes=False) # Now I am reverting the order of the uuid, this will raise a NotExistent error: - with self.assertRaises(NotExistent): + with pytest.raises(NotExistent): orm.load_node(uuid=uuid_stored[::-1]) # I am giving a non-sensical pk, this should also raise - with self.assertRaises(NotExistent): + with pytest.raises(NotExistent): orm.load_node(-1) # Last check, when asking for specific subclass, this should raise: for spec in (node.pk, uuid_stored): - with self.assertRaises(NotExistent): + with pytest.raises(NotExistent): orm.load_node(spec, sub_classes=(orm.ArrayData,)) -class TestSubNodesAndLinks(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test_class') +class TestSubNodesAndLinks: + + def setup_method(self): + """Add a computer.""" + # pylint: disable=attribute-defined-outside-init + if not hasattr(self, 'computer'): + created, self.computer = orm.Computer.objects.get_or_create( + label='localhost', + hostname='localhost', + transport_type='local', + scheduler_type='direct', + workdir='/tmp/aiida', + ) + if created: + self.computer.store() def test_cachelink(self): """Test the proper functionality of the links cache, with different scenarios.""" @@ -1215,29 +1250,28 @@ def test_cachelink(self): # Try also reverse storage endcalc.add_incoming(n2, LinkType.INPUT_CALC, 'N2') - self.assertEqual({(i.link_label, i.node.uuid) for i in endcalc.get_incoming()}, {('N1', n1.uuid), - ('N2', n2.uuid)}) + assert {(i.link_label, i.node.uuid) for i in endcalc.get_incoming()} == {('N1', n1.uuid), ('N2', n2.uuid)} # Endnode not stored yet, n3 and n4 already stored endcalc.add_incoming(n3, LinkType.INPUT_CALC, 'N3') # Try also reverse storage endcalc.add_incoming(n4, LinkType.INPUT_CALC, 'N4') - self.assertEqual({(i.link_label, i.node.uuid) for i in endcalc.get_incoming()}, - {('N1', n1.uuid), ('N2', n2.uuid), ('N3', n3.uuid), ('N4', n4.uuid)}) + assert {(i.link_label, i.node.uuid) for i in endcalc.get_incoming()} == \ + {('N1', n1.uuid), ('N2', n2.uuid), ('N3', n3.uuid), ('N4', n4.uuid)} # Some parent nodes are not stored yet - with self.assertRaises(ModificationNotAllowed): + with pytest.raises(ModificationNotAllowed): endcalc.store() - self.assertEqual({(i.link_label, i.node.uuid) for i in endcalc.get_incoming()}, - {('N1', n1.uuid), ('N2', n2.uuid), ('N3', n3.uuid), ('N4', n4.uuid)}) + assert {(i.link_label, i.node.uuid) for i in endcalc.get_incoming()} == \ + {('N1', n1.uuid), ('N2', n2.uuid), ('N3', n3.uuid), ('N4', n4.uuid)} # This will also store n1 and n2! endcalc.store_all() - self.assertEqual({(i.link_label, i.node.uuid) for i in endcalc.get_incoming()}, - {('N1', n1.uuid), ('N2', n2.uuid), ('N3', n3.uuid), ('N4', n4.uuid)}) + assert {(i.link_label, i.node.uuid) for i in endcalc.get_incoming()} == \ + {('N1', n1.uuid), ('N2', n2.uuid), ('N3', n3.uuid), ('N4', n4.uuid)} def test_store_with_unstored_parents(self): """ @@ -1251,15 +1285,14 @@ def test_store_with_unstored_parents(self): endcalc.add_incoming(n2, LinkType.INPUT_CALC, 'N2') # Some parent nodes are not stored yet - with self.assertRaises(ModificationNotAllowed): + with pytest.raises(ModificationNotAllowed): endcalc.store() n1.store() # Now I can store endcalc.store() - self.assertEqual({(i.link_label, i.node.uuid) for i in endcalc.get_incoming()}, {('N1', n1.uuid), - ('N2', n2.uuid)}) + assert {(i.link_label, i.node.uuid) for i in endcalc.get_incoming()} == {('N1', n1.uuid), ('N2', n2.uuid)} def test_storeall_with_unstored_grandparents(self): """ @@ -1273,7 +1306,7 @@ def test_storeall_with_unstored_grandparents(self): endcalc.add_incoming(n2, LinkType.INPUT_CALC, 'N2') # Grandparents are unstored - with self.assertRaises(ModificationNotAllowed): + with pytest.raises(ModificationNotAllowed): endcalc.store_all() n1.store() @@ -1281,10 +1314,9 @@ def test_storeall_with_unstored_grandparents(self): endcalc.store_all() # Check the parents... - self.assertEqual({(i.link_label, i.node.uuid) for i in n2.get_incoming()}, {('N1', n1.uuid)}) - self.assertEqual({(i.link_label, i.node.uuid) for i in endcalc.get_incoming()}, {('N2', n2.uuid)}) + assert {(i.link_label, i.node.uuid) for i in n2.get_incoming()} == {('N1', n1.uuid)} + assert {(i.link_label, i.node.uuid) for i in endcalc.get_incoming()} == {('N2', n2.uuid)} - # pylint: disable=unused-variable,no-member,no-self-use def test_calculation_load(self): from aiida.orm import CalcJobNode @@ -1293,7 +1325,7 @@ def test_calculation_load(self): calc.set_option('resources', {'num_machines': 1, 'num_mpiprocs_per_machine': 1}) calc.store() - with self.assertRaises(Exception): + with pytest.raises(Exception): # I should get an error if I ask for a computer id/pk that doesn't exist CalcJobNode(computer=self.computer.id + 100000).store() @@ -1308,7 +1340,7 @@ def test_links_label_constraints(self): calc2b = orm.CalculationNode() calc.add_incoming(d1, LinkType.INPUT_CALC, link_label='label1') - with self.assertRaises(ValueError): + with pytest.raises(ValueError): calc.add_incoming(d1bis, LinkType.INPUT_CALC, link_label='label1') calc.store() @@ -1320,7 +1352,7 @@ def test_links_label_constraints(self): # This shouldn't be allowed, it's an output CREATE link with # the same same of an existing output CREATE link - with self.assertRaises(ValueError): + with pytest.raises(ValueError): d4.add_incoming(calc, LinkType.CREATE, link_label='label2') # instead, for outputs, I can have multiple times the same label @@ -1344,29 +1376,29 @@ def test_link_with_unstored(self): n3.add_incoming(n2, link_type=LinkType.CALL_CALC, link_label='l2') # Twice the same link name - with self.assertRaises(ValueError): + with pytest.raises(ValueError): n3.add_incoming(n4, link_type=LinkType.INPUT_CALC, link_label='l3') n2.store_all() n3.store_all() n2_in_links = [(n.link_label, n.node.uuid) for n in n2.get_incoming()] - self.assertEqual(sorted(n2_in_links), sorted([ + assert sorted(n2_in_links) == sorted([ ('l1', n1.uuid), - ])) + ]) n3_in_links = [(n.link_label, n.node.uuid) for n in n3.get_incoming()] - self.assertEqual(sorted(n3_in_links), sorted([ + assert sorted(n3_in_links) == sorted([ ('l2', n2.uuid), ('l3', n1.uuid), - ])) + ]) n1_out_links = [(entry.link_label, entry.node.pk) for entry in n1.get_outgoing()] - self.assertEqual(sorted(n1_out_links), sorted([ + assert sorted(n1_out_links) == sorted([ ('l1', n2.pk), ('l3', n3.pk), - ])) + ]) n2_out_links = [(entry.link_label, entry.node.pk) for entry in n2.get_outgoing()] - self.assertEqual(sorted(n2_out_links), sorted([('l2', n3.pk)])) + assert sorted(n2_out_links) == sorted([('l2', n3.pk)]) def test_multiple_create_links(self): """ @@ -1378,7 +1410,7 @@ def test_multiple_create_links(self): # Caching the links n3.add_incoming(n1, link_type=LinkType.CREATE, link_label='CREATE') - with self.assertRaises(ValueError): + with pytest.raises(ValueError): n3.add_incoming(n2, link_type=LinkType.CREATE, link_label='CREATE') def test_valid_links(self): @@ -1395,7 +1427,7 @@ def test_valid_links(self): label='localhost2', hostname='localhost', scheduler_type='direct', transport_type='local' ) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): # I need to save the localhost entry first orm.CalcJobNode(computer=unsavedcomputer).store() @@ -1411,17 +1443,17 @@ def test_valid_links(self): calc.add_incoming(d2, link_type=LinkType.INPUT_CALC, link_label='some_label') # Cannot link to itself - with self.assertRaises(ValueError): + with pytest.raises(ValueError): d1.add_incoming(d1, link_type=LinkType.INPUT_CALC, link_label='link') # I try to add wrong links (data to data, calc to calc, etc.) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): d2.add_incoming(d1, link_type=LinkType.INPUT_CALC, link_label='link') - with self.assertRaises(ValueError): + with pytest.raises(ValueError): d1.add_incoming(d2, link_type=LinkType.INPUT_CALC, link_label='link') - with self.assertRaises(ValueError): + with pytest.raises(ValueError): calc.add_incoming(calc2, link_type=LinkType.INPUT_CALC, link_label='link') calc.store() @@ -1436,13 +1468,13 @@ def test_valid_links(self): data_node = orm.Data().store() data_node.add_incoming(calc_a, link_type=LinkType.CREATE, link_label='link') # A data cannot have two input calculations - with self.assertRaises(ValueError): + with pytest.raises(ValueError): data_node.add_incoming(calc_b, link_type=LinkType.CREATE, link_label='link') calculation_inputs = calc.get_incoming().all() # This calculation has two data inputs - self.assertEqual(len(calculation_inputs), 2) + assert len(calculation_inputs) == 2 def test_check_single_calc_source(self): """ @@ -1460,7 +1492,7 @@ def test_check_single_calc_source(self): d1.add_incoming(calc, link_type=LinkType.CREATE, link_label='link') # more than one input to the same data object! - with self.assertRaises(ValueError): + with pytest.raises(ValueError): d1.add_incoming(calc2, link_type=LinkType.CREATE, link_label='link') def test_node_get_incoming_outgoing_links(self): @@ -1491,19 +1523,19 @@ def test_node_get_incoming_outgoing_links(self): node_return.add_incoming(node_origin, link_type=LinkType.RETURN, link_label='return2') # All incoming and outgoing - self.assertEqual(len(node_origin.get_incoming().all()), 2) - self.assertEqual(len(node_origin.get_outgoing().all()), 3) + assert len(node_origin.get_incoming().all()) == 2 + assert len(node_origin.get_outgoing().all()) == 3 # Link specific incoming - self.assertEqual(len(node_origin.get_incoming(link_type=LinkType.CALL_WORK).all()), 1) - self.assertEqual(len(node_origin2.get_incoming(link_type=LinkType.CALL_WORK).all()), 1) - self.assertEqual(len(node_origin.get_incoming(link_type=LinkType.INPUT_WORK).all()), 1) - self.assertEqual(len(node_origin.get_incoming(link_label_filter='in_ut%').all()), 1) - self.assertEqual(len(node_origin.get_incoming(node_class=orm.Node).all()), 2) + assert len(node_origin.get_incoming(link_type=LinkType.CALL_WORK).all()) == 1 + assert len(node_origin2.get_incoming(link_type=LinkType.CALL_WORK).all()) == 1 + assert len(node_origin.get_incoming(link_type=LinkType.INPUT_WORK).all()) == 1 + assert len(node_origin.get_incoming(link_label_filter='in_ut%').all()) == 1 + assert len(node_origin.get_incoming(node_class=orm.Node).all()) == 2 # Link specific outgoing - self.assertEqual(len(node_origin.get_outgoing(link_type=LinkType.CALL_WORK).all()), 1) - self.assertEqual(len(node_origin.get_outgoing(link_type=LinkType.RETURN).all()), 2) + assert len(node_origin.get_outgoing(link_type=LinkType.CALL_WORK).all()) == 1 + assert len(node_origin.get_outgoing(link_type=LinkType.RETURN).all()) == 2 class AnyValue: @@ -1515,7 +1547,8 @@ def __eq__(self, other): return True -class TestNodeDeletion(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test_class') +class TestNodeDeletion: def _check_existence(self, uuids_check_existence, uuids_check_deleted): """ @@ -1536,7 +1569,7 @@ def _check_existence(self, uuids_check_existence, uuids_check_deleted): orm.load_node(uuid) for uuid in uuids_check_deleted: # I check that it raises - with self.assertRaises(NotExistent): + with pytest.raises(NotExistent): orm.load_node(uuid) @staticmethod @@ -1550,8 +1583,8 @@ def test_deletion_dry_run_true(self): node = orm.Data().store() node_pk = node.pk deleted_pks, was_deleted = delete_nodes([node_pk], dry_run=True) - self.assertTrue(not was_deleted) - self.assertSetEqual(deleted_pks, {node_pk}) + assert not was_deleted + assert deleted_pks == {node_pk} orm.load_node(node_pk) def test_deletion_dry_run_callback(self): @@ -1566,11 +1599,11 @@ def _callback(pks): return False deleted_pks, was_deleted = delete_nodes([node_pk], dry_run=_callback) - self.assertTrue(was_deleted) - self.assertSetEqual(deleted_pks, {node_pk}) - with self.assertRaises(NotExistent): + assert was_deleted + assert deleted_pks == {node_pk} + with pytest.raises(NotExistent): orm.load_node(node_pk) - self.assertListEqual(callback_pks, [node_pk]) + assert callback_pks == [node_pk] # TEST BASIC CASES @@ -2054,8 +2087,8 @@ def test_delete_group_nodes(self): node_uuids = {node.uuid for node in nodes} group.add_nodes(nodes) deleted_pks, was_deleted = delete_group_nodes([group.pk], dry_run=False) - self.assertTrue(was_deleted) - self.assertSetEqual(deleted_pks, node_pks) + assert was_deleted + assert deleted_pks == node_pks self._check_existence([], node_uuids) def test_delete_group_nodes_dry_run_true(self): @@ -2066,6 +2099,6 @@ def test_delete_group_nodes_dry_run_true(self): node_uuids = {node.uuid for node in nodes} group.add_nodes(nodes) deleted_pks, was_deleted = delete_group_nodes([group.pk], dry_run=True) - self.assertTrue(not was_deleted) - self.assertSetEqual(deleted_pks, node_pks) + assert not was_deleted + assert deleted_pks == node_pks self._check_existence(node_uuids, [])