diff --git a/test/python/test_func1.py b/test/python/test_func1.py index 88ad105ee5..4d44dd9e4e 100644 --- a/test/python/test_func1.py +++ b/test/python/test_func1.py @@ -2,6 +2,8 @@ import cantera as ct from . import utilities +import math +import pytest class TestFunc1(utilities.CanteraTest): def test_function(self): @@ -12,6 +14,7 @@ def test_function(self): def test_lambda(self): f = ct.Func1(lambda t: np.sin(t)*np.sqrt(t)) + assert f.type == "functor" for t in [0.1, 0.7, 4.5]: self.assertNear(f(t), np.sin(t)*np.sqrt(t)) @@ -24,6 +27,7 @@ def __call__(self, t): m = Multiplier(8.1) f = ct.Func1(m) + assert f.type == "functor" for t in [0.1, 0.7, 4.5]: self.assertNear(f(t), 8.1*t) @@ -31,6 +35,7 @@ def test_constant(self): f = ct.Func1(5) for t in [0.1, 0.7, 4.5]: self.assertNear(f(t), 5) + assert f.type == "constant" def test_sequence(self): f = ct.Func1([5]) @@ -42,7 +47,9 @@ def test_sequence(self): def test_numpy(self): f = ct.Func1(np.array(5)) + assert f.type == "constant" g = ct.Func1(np.array([[5]])) + assert g.type == "constant" for t in [0.1, 0.7, 4.5]: self.assertNear(f(t), 5) self.assertNear(g(t), 5) @@ -70,11 +77,58 @@ def test_uncopyable(self): with self.assertRaises(NotImplementedError): copy.copy(f) + def test_simple(self): + functors = { + 'sin': math.sin, + 'cos': math.cos, + 'exp': math.exp, + 'log': math.log, + } + for name, fcn in functors.items(): + coeff = 2.34 + func = ct.Func1.cxx_functor(name, coeff) + assert func.type == name + for val in [.1, 1., 10.]: + assert name in func.write() + assert func(val) == pytest.approx(fcn(coeff * val)) + + def test_compound(self): + functors = { + 'sum': lambda x, y: x + y, + 'diff': lambda x, y: x - y, + 'product': lambda x, y: x * y, + 'ratio': lambda x, y: x / y, + } + f1 = ct.Func1.cxx_functor('pow', 2) + f2 = ct.Func1.cxx_functor('sin') + for name, fcn in functors.items(): + func = ct.Func1.cxx_functor(name, f1, f2) + assert func.type == name + for val in [.1, 1., 10.]: + assert name not in func.write() + assert func(val) == pytest.approx(fcn(f1(val), f2(val))) + + def test_modified(self): + functors = { + 'plus-constant': lambda x, y: x + y, + 'times-constant': lambda x, y: x * y, + } + f1 = ct.Func1.cxx_functor('sin') + constant = 2.34 + for name, fcn in functors.items(): + func = ct.Func1.cxx_functor(name, f1, constant) + assert func.type == name + for val in [.1, 1., 10.]: + assert name not in func.write() + assert func(val) == pytest.approx(fcn(f1(val), constant)) + def test_tabulated1(self): + # this implicitly probes advanced functors arr = np.array([[0, 2], [1, 1], [2, 0]]) time = arr[:, 0] fval = arr[:, 1] fcn = ct.TabulatedFunction(time, fval) + assert fcn.type == "tabulated-linear" for t, f in zip(time, fval): self.assertNear(f, fcn(t)) @@ -82,6 +136,7 @@ def test_tabulated2(self): time = [0, 1, 2] fval = [2, 1, 0] fcn = ct.TabulatedFunction(time, fval) + assert fcn.type == "tabulated-linear" for t, f in zip(time, fval): self.assertNear(f, fcn(t)) @@ -105,15 +160,16 @@ def test_tabulated5(self): time = [0, 1, 2] fval = [2, 1, 0] fcn = ct.TabulatedFunction(time, fval, method='previous') + assert fcn.type == "tabulated-previous" val = np.array([fcn(v) for v in [-0.5, 0, 0.5, 1.5, 2, 2.5]]) self.assertArrayNear(val, np.array([2.0, 2.0, 2.0, 1.0, 0.0, 0.0])) def test_tabulated_failures(self): - with self.assertRaisesRegex(ValueError, 'do not match'): + with pytest.raises(ct.CanteraError, match="even number of entries"): ct.TabulatedFunction(range(2), range(3)) - with self.assertRaisesRegex(ValueError, 'must not be empty'): + with pytest.raises(ct.CanteraError, match="at least 4 entries"): ct.TabulatedFunction([], []) - with self.assertRaisesRegex(ct.CanteraError, 'monotonically'): + with pytest.raises(ct.CanteraError, match="monotonically"): ct.TabulatedFunction((0, 1, 0.5, 2), (2, 1, 1, 0)) - with self.assertRaisesRegex(NotImplementedError, 'not implemented'): - ct.TabulatedFunction((0, 1, 1, 2), (2, 1, 1, 0), method='not implemented') + with pytest.raises(ct.CanteraError, match="No such type"): + ct.TabulatedFunction((0, 1, 1, 2), (2, 1, 1, 0), method='spam')