Skip to content

Commit

Permalink
Improve factor unit tests
Browse files Browse the repository at this point in the history
This patch restructures the test_evaluable.factor test suite, and adds
test_replace to assert that argument replacements result in a sound object with
correct derivatives, and test_derivative_sparsity to assert that factor
produces functions with an exposable sparsity structure. This is the case since
commit 9551525.
  • Loading branch information
gertjanvanzwieten committed Jan 22, 2025
1 parent 8f46cfb commit 515c32e
Showing 1 changed file with 68 additions and 33 deletions.
101 changes: 68 additions & 33 deletions tests/test_evaluable.py
Original file line number Diff line number Diff line change
Expand Up @@ -1352,63 +1352,98 @@ def test_grad_variable_ncoeffs(self):
class factor(TestCase):

def setUp(self):
index = evaluable.loop_index('i', 4)
values = evaluable.constant([1., 2., -2., -1.])
dofs = index + evaluable.Range(values.shape[0])
length = evaluable.constant(8)
basis = evaluable.loop_sum(evaluable._inflate(values, dofs, length, axis=0), index)
self.index = evaluable.loop_index('i', 4)
values = evaluable.constant([1., 2., -3.])
dofs = self.index + evaluable.Range(values.shape[0])
length = evaluable.constant(6)
basis = evaluable._inflate(values, dofs, length, axis=0)
self.varg = evaluable.Argument('v', basis.shape, float)
self.v = (basis * self.varg).sum(0)
self.targ = evaluable.Argument('t', (), float)
self.t = 10. * self.targ

def assertFactoredEqual(self, f, v=0, t=0):
self.assertEqual(f.argument_degree(self.varg), v)
self.assertEqual(f.argument_degree(self.targ), t)
g = evaluable.factor(f)
if not t:
tryargs = dict(v=numpy.zeros(8)), dict(v=numpy.ones(8)), dict(v=numpy.arange(8, dtype=float))
elif not v:
tryargs = dict(t=0.), dict(t=1.), dict(t=-5.)
else:
tryargs = dict(v=numpy.arange(8, dtype=float), t=0.), dict(v=numpy.zeros(8, dtype=float), t=5.), dict(v=numpy.arange(8, dtype=float), t=5.)
for vderiv in range(v):
if vderiv:
f = evaluable.derivative(f, self.varg)
g = evaluable.derivative(g, self.varg)
f_ = f
g_ = f
for tderiv in range(t):
if tderiv:
f_ = evaluable.derivative(f_, self.targ)
g_ = evaluable.derivative(g_, self.targ)
with self.subTest(f'derivative v{vderiv}, t{tderiv}'):
F = evaluable.compile(f_)
G = evaluable.compile(g_)
for args in tryargs:
self.assertAllAlmostEqual(F(**args), G(**args))
self.barg = evaluable.Argument('b', (*basis.shape, evaluable.constant(2)), float)
self.b = (evaluable.InsertAxis(basis, self.barg.shape[1]) * self.barg).sum(0)

def integral(self, f):
return evaluable.loop_sum(f, self.index)

def assertFactoredEqual(self, integrand, replacements=None, *, v=0, t=0, b=0):
orig = self.integral(integrand)
factored = evaluable.factor(orig)
if replacements:
orig = evaluable.replace_arguments(orig, replacements)
factored = evaluable.replace_arguments(factored, replacements)

for func in orig, factored:
self.assertEqual(func.argument_degree(self.varg), v)
self.assertEqual(func.argument_degree(self.targ), t)
self.assertEqual(func.argument_degree(self.barg), b)

testing_grid = [{}]
if t: testing_grid = [dict(d, t=t) for d in testing_grid for t in [0., 1., -5.]]
if v: testing_grid = [dict(d, v=v) for d in testing_grid for v in [numpy.zeros(6), numpy.ones(6), numpy.arange(6, dtype=float)]]
if b: testing_grid = [dict(d, b=b) for d in testing_grid for b in [numpy.zeros((6,2)), numpy.ones((6,2)), numpy.arange(12, dtype=float).reshape(6,2)]]

for deriv_args in [[self.targ] * i + [self.varg] * j + [self.barg] * k for i in range(t+1) for j in range(v+1) for k in range(b+1)]:
with self.subTest('f/' + ''.join(arg.name for arg in deriv_args)):
F = evaluable.compile(functools.reduce(evaluable.derivative, deriv_args, orig))
G = evaluable.compile(functools.reduce(evaluable.derivative, deriv_args, factored))
for eval_args in testing_grid:
self.assertAllAlmostEqual(F(**eval_args), G(**eval_args))

def test_linear(self):
self.assertFactoredEqual(1. + self.v, v=1)
self.assertFactoredEqual(2. * self.v - 5. * self.t, v=1, t=1)
self.assertFactoredEqual(2. * self.v - self.t, v=1, t=1)
self.assertFactoredEqual(2. * self.v * self.t, v=1, t=1)
self.assertFactoredEqual(3. * self.b * self.t, b=1, t=1)

def test_quadratic(self):
self.assertFactoredEqual(1. + self.v - self.v**2., v=2)
self.assertFactoredEqual(3. * self.v**2., v=2)
self.assertFactoredEqual(5. * self.t**2., t=2)
self.assertFactoredEqual(self.v * self.t**2. - 2. * self.v**2. * self.t, t=2, v=2)
self.assertFactoredEqual(self.v * evaluable.Sum(self.b * self.b) - 2. * self.v**2., v=2, b=2)

def test_cubic(self):
self.assertFactoredEqual(1. + self.v - self.v**2. + self.v**3., v=3)
self.assertFactoredEqual(-self.v**3., v=3)
self.assertFactoredEqual(3. * self.t**3., t=3)
self.assertFactoredEqual(self.t**2. * (self.v**2. - 2. * self.t), t=3, v=2)
self.assertFactoredEqual(self.t**2. * (self.b**2. - 2. * self.t * self.b), t=3, b=2)

def test_not_polynomial(self):
with self.assertRaisesRegex(evaluable.NotPolynomal, "nutils.evaluable.Sign<f:> is not polynomial in argument 'v'"):
evaluable.factor(evaluable.Sign(self.v))
evaluable.factor(self.integral(evaluable.Sign(self.v)))

def test_constant(self):
self.assertFactoredEqual(evaluable.constant(1.))

def test_replace(self):
self.assertFactoredEqual(self.v, dict(v=evaluable.constant(numpy.arange(.5, 6)) * self.t), t=1)
self.assertFactoredEqual(self.v * self.v, dict(v=evaluable.constant(numpy.arange(.5, 6)) * self.t), t=2)
self.assertFactoredEqual(self.b, dict(b=evaluable.constant(numpy.arange(.5, 12).reshape(6,2)) * self.t), t=1)
self.assertFactoredEqual(self.v * self.v * evaluable.Sum(self.b * self.b),
dict(b=evaluable.constant(numpy.arange(.5, 12).reshape(6,2)) * self.t, v=evaluable.constant(numpy.arange(.5, 6)) * self.t), t=4)

def test_derivative_sparsity(self):
# check that the function tree produced by factor reveals its sparsity after derivative.

v2 = evaluable.derivative(evaluable.derivative(evaluable.factor(self.integral(self.v * self.v)), self.varg), self.varg)
values, indices, shape = evaluable.eval_coo(v2)
self.assertAllEqual(values,
[2, 4, -6, 4, 10, -8, -6, -6, -8, 28, -8, -6, -6, -8, 28, -8, -6, -6, -8, 26, -12, -6, -12, 18])
self.assertAllEqual(indices,
[[0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5],
[0, 1, 2, 0, 1, 2, 3, 0, 1, 2, 3, 4, 1, 2, 3, 4, 5, 2, 3, 4, 5, 3, 4, 5]])
self.assertEqual(shape, (6, 6))

b2 = evaluable.derivative(evaluable.derivative(evaluable.factor(self.integral(evaluable.Sum(self.b * self.b))), self.barg), self.barg)
values, indices, shape = evaluable.eval_coo(b2)
self.assertAllEqual(values,
[2, 4, -6, 2, 4, -6, 4, 10, -8, -6, 4, 10, -8, -6, -6, -8, 28, -8, -6, -6, -8, 28, -8, -6, -6, -8, 28, -8, -6, -6, -8, 28, -8, -6, -6, -8, 26, -12, -6, -8, 26, -12, -6, -12, 18, -6, -12, 18])
self.assertAllEqual(indices,
[[0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5],
[0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1],
[0, 1, 2, 0, 1, 2, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 2, 3, 4, 5, 2, 3, 4, 5, 3, 4, 5, 3, 4, 5],
[0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1]])
self.assertEqual(shape, (6, 2, 6, 2))

0 comments on commit 515c32e

Please sign in to comment.