Skip to content

Commit

Permalink
🧪 TESTS: AiidaTestCase -> pytest
Browse files Browse the repository at this point in the history
Process:

- unittest2pytest -n -w path/to/file
- add `# pylint: disable=no-self-use`
- replace `AiidaTestCase` with `@pytest.mark.usefixtures('clear_database_before_test_class')`
- remove from `aiida.backends.testbase import AiidaTestCase`
- change @unittest.skip -> @pytest.mark.skip, @unittest.skipIf -> @pytest.mark.skipif

The were some issues where`assertAlmostEquals` had been used for strange cases

Annoyingly there are some limitations, one being that the class fixture is called AFTER a `setup_class` method (see
https://stackoverflow.com/questions/31484419/pytest-setup-class-after-fixture-initialization_)
and also class methods don't accept fixtures. This makes it difficult to do a 1-to-1 replacement of any `setUpClass` methods used, and also replace the use of `self.computer`
  • Loading branch information
chrisjsewell committed Feb 25, 2021
1 parent c07e3ef commit 3361712
Show file tree
Hide file tree
Showing 9 changed files with 1,878 additions and 1,886 deletions.
7 changes: 7 additions & 0 deletions aiida/manage/tests/pytest_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,13 @@ def clear_database_before_test(aiida_profile):
yield


@pytest.fixture(scope='class')
def clear_database_before_test_class(aiida_profile):
"""Clear the database before a test class."""
aiida_profile.reset_db()
yield


@pytest.fixture(scope='function')
def temporary_event_loop():
"""Create a temporary loop for independent test case"""
Expand Down
3 changes: 3 additions & 0 deletions aiida/orm/nodes/data/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,6 @@ def __float__(self):

def __int__(self):
return int(self.value)

def __abs__(self):
return abs(self.value)
371 changes: 175 additions & 196 deletions tests/engine/test_work_chain.py

Large diffs are not rendered by default.

759 changes: 369 additions & 390 deletions tests/orm/test_querybuilder.py

Large diffs are not rendered by default.

158 changes: 82 additions & 76 deletions tests/test_base_dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,30 +7,33 @@
# For further information on the license, see the LICENSE.txt file #
# For further information please visit http://www.aiida.net #
###########################################################################
# pylint: disable=no-self-use
"""Tests for AiiDA base data classes."""
import operator

from aiida.backends.testbase import AiidaTestCase
import pytest

from aiida.common.exceptions import ModificationNotAllowed
from aiida.orm import load_node, List, Bool, Float, Int, Str, NumericType
from aiida.orm.nodes.data.bool import get_true_node, get_false_node


class TestList(AiidaTestCase):
@pytest.mark.usefixtures('clear_database_before_test_class')
class TestList:
"""Test AiiDA List class."""

def test_creation(self):
node = List()
self.assertEqual(len(node), 0)
with self.assertRaises(IndexError):
assert len(node) == 0
with pytest.raises(IndexError):
node[0] # pylint: disable=pointless-statement

def test_append(self):
"""Test append() member function."""

def do_checks(node):
self.assertEqual(len(node), 1)
self.assertEqual(node[0], 4)
assert len(node) == 1
assert node[0] == 4

node = List()
node.append(4)
Expand All @@ -47,22 +50,22 @@ def test_extend(self):
lst = [1, 2, 3]

def do_checks(node):
self.assertEqual(len(node), len(lst))
assert len(node) == len(lst)
# Do an element wise comparison
for lst_, node_ in zip(lst, node):
self.assertEqual(lst_, node_)
assert lst_ == node_

node = List()
node.extend(lst)
do_checks(node)
# Further extend
node.extend(lst)
self.assertEqual(len(node), len(lst) * 2)
assert len(node) == len(lst) * 2

# Do an element wise comparison
for i, _ in enumerate(lst):
self.assertEqual(lst[i], node[i])
self.assertEqual(lst[i], node[i % len(lst)])
assert lst[i] == node[i]
assert lst[i] == node[i % len(lst)]

# Now try after storing
node = List()
Expand All @@ -77,19 +80,19 @@ def test_mutability(self):
node.store()

# Test all mutable calls are now disallowed
with self.assertRaises(ModificationNotAllowed):
with pytest.raises(ModificationNotAllowed):
node.append(5)
with self.assertRaises(ModificationNotAllowed):
with pytest.raises(ModificationNotAllowed):
node.extend([5])
with self.assertRaises(ModificationNotAllowed):
with pytest.raises(ModificationNotAllowed):
node.insert(0, 2)
with self.assertRaises(ModificationNotAllowed):
with pytest.raises(ModificationNotAllowed):
node.remove(0)
with self.assertRaises(ModificationNotAllowed):
with pytest.raises(ModificationNotAllowed):
node.pop()
with self.assertRaises(ModificationNotAllowed):
with pytest.raises(ModificationNotAllowed):
node.sort()
with self.assertRaises(ModificationNotAllowed):
with pytest.raises(ModificationNotAllowed):
node.reverse()

@staticmethod
Expand All @@ -102,144 +105,146 @@ def test_store_load():
assert node.get_list() == node_loaded.get_list()


class TestFloat(AiidaTestCase):
@pytest.mark.usefixtures('clear_database_before_test_class')
class TestFloat:
"""Test Float class."""

def setUp(self):
super().setUp()
def setup_method(self):
# pylint: disable=attribute-defined-outside-init
self.value = Float()
self.all_types = [Int, Float, Bool, Str]

def test_create(self):
"""Creating basic data objects."""
term_a = Float()
# Check that initial value is zero
self.assertAlmostEqual(term_a.value, 0.0)
assert round(abs(term_a.value - 0.0), 7) == 0

float_ = Float(6.0)
self.assertAlmostEqual(float_.value, 6.)
self.assertAlmostEqual(float_, Float(6.0))
assert round(abs(float_.value - 6.), 7) == 0
assert round(abs(float_ - Float(6.0)), 7) == 0

int_ = Int()
self.assertAlmostEqual(int_.value, 0)
assert round(abs(int_.value - 0), 7) == 0
int_ = Int(6)
self.assertAlmostEqual(int_.value, 6)
self.assertAlmostEqual(float_, int_)
assert round(abs(int_.value - 6), 7) == 0
assert round(abs(float_ - int_), 7) == 0

bool_ = Bool()
self.assertAlmostEqual(bool_.value, False)
assert round(abs(bool_.value - False), 7) == 0
bool_ = Bool(False)
self.assertAlmostEqual(bool_.value, False)
self.assertAlmostEqual(bool_.value, get_false_node())
assert bool_.value is False
assert bool_.value == get_false_node()
bool_ = Bool(True)
self.assertAlmostEqual(bool_.value, True)
self.assertAlmostEqual(bool_.value, get_true_node())
assert bool_.value is True
assert bool_.value == get_true_node()

str_ = Str()
self.assertAlmostEqual(str_.value, '')
assert str_.value == ''
str_ = Str('Hello')
self.assertAlmostEqual(str_.value, 'Hello')
assert str_.value == 'Hello'

def test_load(self):
"""Test object loading."""
for typ in self.all_types:
node = typ()
node.store()
loaded = load_node(node.pk)
self.assertAlmostEqual(node, loaded)
assert node == loaded

def test_add(self):
"""Test addition."""
term_a = Float(4)
term_b = Float(5)
# Check adding two db Floats
res = term_a + term_b
self.assertIsInstance(res, NumericType)
self.assertAlmostEqual(res, 9.0)
assert isinstance(res, NumericType)
assert round(abs(res - 9.0), 7) == 0

# Check adding db Float and native (both ways)
res = term_a + 5.0
self.assertIsInstance(res, NumericType)
self.assertAlmostEqual(res, 9.0)
assert isinstance(res, NumericType)
assert round(abs(res - 9.0), 7) == 0

res = 5.0 + term_a
self.assertIsInstance(res, NumericType)
self.assertAlmostEqual(res, 9.0)
assert isinstance(res, NumericType)
assert round(abs(res - 9.0), 7) == 0

# Inplace
term_a = Float(4)
term_a += term_b
self.assertAlmostEqual(term_a, 9.0)
assert round(abs(term_a - 9.0), 7) == 0

term_a = Float(4)
term_a += 5
self.assertAlmostEqual(term_a, 9.0)
assert round(abs(term_a - 9.0), 7) == 0

def test_mul(self):
"""Test floats multiplication."""
term_a = Float(4)
term_b = Float(5)
# Check adding two db Floats
res = term_a * term_b
self.assertIsInstance(res, NumericType)
self.assertAlmostEqual(res, 20.0)
assert isinstance(res, NumericType)
assert round(abs(res - 20.0), 7) == 0

# Check adding db Float and native (both ways)
res = term_a * 5.0
self.assertIsInstance(res, NumericType)
self.assertAlmostEqual(res, 20)
assert isinstance(res, NumericType)
assert round(abs(res - 20), 7) == 0

res = 5.0 * term_a
self.assertIsInstance(res, NumericType)
self.assertAlmostEqual(res, 20.0)
assert isinstance(res, NumericType)
assert round(abs(res - 20.0), 7) == 0

# Inplace
term_a = Float(4)
term_a *= term_b
self.assertAlmostEqual(term_a, 20)
assert round(abs(term_a - 20), 7) == 0

term_a = Float(4)
term_a *= 5
self.assertAlmostEqual(term_a, 20)
assert round(abs(term_a - 20), 7) == 0

def test_power(self):
"""Test power operator."""
term_a = Float(4)
term_b = Float(2)

res = term_a**term_b
self.assertAlmostEqual(res.value, 16.)
assert round(abs(res.value - 16.), 7) == 0

def test_division(self):
"""Test the normal division operator."""
term_a = Float(3)
term_b = Float(2)

self.assertAlmostEqual(term_a / term_b, 1.5)
self.assertIsInstance(term_a / term_b, Float)
assert round(abs(term_a / term_b - 1.5), 7) == 0
assert isinstance(term_a / term_b, Float)

def test_division_integer(self):
"""Test the integer division operator."""
term_a = Float(3)
term_b = Float(2)

self.assertAlmostEqual(term_a // term_b, 1.0)
self.assertIsInstance(term_a // term_b, Float)
assert round(abs(term_a // term_b - 1.0), 7) == 0
assert isinstance(term_a // term_b, Float)

def test_modulus(self):
"""Test modulus operator."""
term_a = Float(12.0)
term_b = Float(10.0)

self.assertAlmostEqual(term_a % term_b, 2.0)
self.assertIsInstance(term_a % term_b, NumericType)
self.assertAlmostEqual(term_a % 10.0, 2.0)
self.assertIsInstance(term_a % 10.0, NumericType)
self.assertAlmostEqual(12.0 % term_b, 2.0)
self.assertIsInstance(12.0 % term_b, NumericType)
assert round(abs(term_a % term_b - 2.0), 7) == 0
assert isinstance(term_a % term_b, NumericType)
assert round(abs(term_a % 10.0 - 2.0), 7) == 0
assert isinstance(term_a % 10.0, NumericType)
assert round(abs(12.0 % term_b - 2.0), 7) == 0
assert isinstance(12.0 % term_b, NumericType)


class TestFloatIntMix(AiidaTestCase):
@pytest.mark.usefixtures('clear_database_before_test_class')
class TestFloatIntMix:
"""Test operations between Int and Float objects."""

def test_operator(self):
Expand All @@ -254,37 +259,38 @@ def test_operator(self):
for term_x, term_y in [(term_a, term_b), (term_b, term_a)]:
res = oper(term_x, term_y)
c_val = oper(term_x.value, term_y.value)
self.assertEqual(res._type, type(c_val)) # pylint: disable=protected-access
self.assertEqual(res, oper(term_x.value, term_y.value))
assert res._type == type(c_val) # pylint: disable=protected-access
assert res == oper(term_x.value, term_y.value)


class TestInt(AiidaTestCase):
@pytest.mark.usefixtures('clear_database_before_test_class')
class TestInt:
"""Test Int class."""

def test_division(self):
"""Test the normal division operator."""
term_a = Int(3)
term_b = Int(2)

self.assertAlmostEqual(term_a / term_b, 1.5)
self.assertIsInstance(term_a / term_b, Float)
assert round(abs(term_a / term_b - 1.5), 7) == 0
assert isinstance(term_a / term_b, Float)

def test_division_integer(self):
"""Test the integer division operator."""
term_a = Int(3)
term_b = Int(2)

self.assertAlmostEqual(term_a // term_b, 1)
self.assertIsInstance(term_a // term_b, Int)
assert round(abs(term_a // term_b - 1), 7) == 0
assert isinstance(term_a // term_b, Int)

def test_modulo(self):
"""Test modulus operation."""
term_a = Int(12)
term_b = Int(10)

self.assertEqual(term_a % term_b, 2)
self.assertIsInstance(term_a % term_b, NumericType)
self.assertEqual(term_a % 10, 2)
self.assertIsInstance(term_a % 10, NumericType)
self.assertEqual(12 % term_b, 2)
self.assertIsInstance(12 % term_b, NumericType)
assert term_a % term_b == 2
assert isinstance(term_a % term_b, NumericType)
assert term_a % 10 == 2
assert isinstance(term_a % 10, NumericType)
assert 12 % term_b == 2
assert isinstance(12 % term_b, NumericType)
Loading

0 comments on commit 3361712

Please sign in to comment.