Skip to content

Commit

Permalink
Introduce evaluable.Monomial
Browse files Browse the repository at this point in the history
This patch introduces evaluable.Monomial to represent the sparse tensor
contraction underlying evaluable.factor. This improves the structure of
factored functions that involve arguments of multiple dimensions, which in the
old form had their sparsity shielded after differentiation. As a side effect
the SymProd operation is removed as it is no longer needed.
  • Loading branch information
gertjanvanzwieten committed Jan 22, 2025
1 parent 20f98e8 commit 8f46cfb
Showing 1 changed file with 85 additions and 67 deletions.
152 changes: 85 additions & 67 deletions nutils/evaluable.py
Original file line number Diff line number Diff line change
Expand Up @@ -5393,6 +5393,87 @@ def zero_all_arguments(value):
return zeros_like(value)


class Monomial(Array):
    '''Helper object for factor.

    Performs a sparse tensor multiplication, without summation, and returns the
    result as a dense vector; inflation and reshape is the responsibility of
    factor. The factors of the multiplication are ``values`` and all the
    ``dependencies``, which are scattered into values via the ``indices``. The
    ``powers`` argument contains multiplicities, for the following reason.

    With reference to the example of the factor doc string, derivative will
    generate an evaluable of the form array'(arg) darg = darray_darg(0) darg +
    .5 arg d2array_darg2 darg + .5 darg d2array_darg2 arg. By Schwarz's theorem
    d2array_darg2 is symmetric, and the latter two terms are equal. However,
    since the row and column indices of d2array_darg2 differ, we cannot detect
    this equality but rather need to embed the information explicitly.

    In this situation, the ``powers`` argument contains the value 2 to indicate
    that its position is symmetric with the next. The derivative with respect
    to the first of these positions is therefore multiplied by 2 and the
    second position is skipped. With that, the derivative takes the desired
    form of array'(arg) == darray_darg(0) + d2array_darg2 arg.

    Fields
    ------
    values
        One-dimensional coefficient data; its length is the length of the
        resulting vector.
    dependencies
        The evaluable factors; a factor that occurs n times in the product is
        listed n times in adjacent positions.
    indices
        For every dependency, one index array per axis of that dependency,
        each of the same shape as ``values``.
    powers
        For every dependency, a positive integer n stating that the n - 1
        positions that follow hold an equal dependency.
    '''

    values: types.arraydata
    dependencies: typing.Tuple[Array, ...]
    indices: typing.Tuple[typing.Tuple[types.arraydata, ...], ...]
    # One entry per dependency (see the length assertion in __post_init__),
    # hence the variadic tuple annotation.
    powers: typing.Tuple[int, ...]

    def __post_init__(self):
        # The result is a flat vector with one entry per stored value.
        assert self.values.ndim == 1, self.values.shape
        # Dependencies, index tuples and powers are parallel sequences.
        assert len(self.dependencies) == len(self.indices) == len(self.powers)
        # Every index array selects exactly one item per stored value.
        assert all(index.shape == self.values.shape for indices in self.indices for index in indices), (self.values.shape, self.indices)
        # Every dependency is indexed along all of its axes.
        assert all(len(indices) == arg.ndim for arg, indices in zip(self.dependencies, self.indices)), (len(self.indices), self.dependencies)
        assert all(isinstance(power, int) and power > 0 for power in self.powers), self.powers
        # A power n at position i promises that positions i+1 .. i+n-1 hold
        # the same dependency as position i.
        assert all(self.dependencies[j] == self.dependencies[i] for i, n in enumerate(self.powers) for j in range(i+1, i+n))

    def _simplified(self):
        # Without dependencies the product reduces to the constant values.
        # Falls through (returning None) when there is nothing to simplify.
        if not self.dependencies:
            return Constant(self.values)

    @property
    def shape(self):
        # One-tuple: the result is a vector with the length of ``values``.
        return constant(self.values.shape[0]),

    @property
    def dtype(self):
        return self.values.dtype

    def evalf(self, *args):
        # Start from a fresh array so that the in-place products below do not
        # touch the stored values.
        v = numpy.array(self.values)
        for arg, indices in zip(args, self.indices):
            # Gather one item of the evaluated dependency per stored value
            # and multiply it in.
            v *= arg[indices]
        return v

    def _derivative(self, var, seen):
        # Product rule: accumulate one term per group of equal dependencies.
        deriv = Zeros(self.shape + var.shape, self.dtype)
        iarg = 0
        while iarg < len(self.dependencies):
            dep = self.dependencies[iarg]
            # The remaining product, with the occurrence at position iarg
            # taken out.
            m = Monomial(self.values,
                         self.dependencies[:iarg] + self.dependencies[iarg+1:],
                         self.indices[:iarg] + self.indices[iarg+1:],
                         self.powers[:iarg] + self.powers[iarg+1:])
            if dep.ndim:
                # Combine the per-axis indices of dep into a single row-major
                # (C order) flat index, working from the last axis backwards.
                *indices, ravel_index = map(Constant, self.indices[iarg])
                *lengths, ravel_length = dep.shape
                while indices:
                    ravel_index += indices.pop() * ravel_length
                    ravel_length *= lengths.pop()
                # Scatter the remaining product over the flattened items of
                # dep, then restore the shape of dep on the trailing axes.
                m = unravel(Inflate(Diagonalize(m), ravel_index, ravel_length), -1, dep.shape)
            # Chain rule: contract the axes of dep with the derivative of dep
            # to var.
            m = einsum('aB,BC->aC', m, derivative(dep, var, seen))
            power = self.powers[iarg]
            if power > 1:
                # The next power - 1 positions hold the same dependency and
                # contribute equal terms, so scale this term instead of
                # repeating the construction.
                m *= m.dtype(power)
            deriv += m
            # Skip over the occurrences accounted for by the scaling.
            iarg += power
        assert iarg == len(self.dependencies)
        return deriv

    def _argument_degree(self, argument):
        # Degrees add up under multiplication; repeated dependencies are
        # listed repeatedly and hence counted with their multiplicity.
        return builtins.sum(dep.argument_degree(argument) for dep in self.dependencies)


@log.withcontext
def factor(array):
'''Convert array to a sparse polynomial.
Expand Down Expand Up @@ -5470,27 +5551,10 @@ def factor(array):
if not len(values):
continue

offsets = numpy.cumsum([array.ndim] + [arg.ndim for arg in args])
factors = [Take(_flat(arg), constant(numpy.ravel_multi_index(indices[s], shape[s]))) if arg.ndim else arg
for arg, s in zip(args, map(slice, offsets, offsets[1:]))]

# With reference to the example of the doc string, derivative will
# generate an evaluable of the form array'(arg) darg = darray_darg(0)
# darg + .5 arg d2array_darg2 darg + .5 darg d2array_darg2 arg. By
# Schwarz's theorem d2array_darg2 is symmetric, and the latter two
# terms are equal. However, since the row and column indices of
# d2array_darg2 differ, we cannot detect this equality but rather need
# to embed the information explicitly. To this end, factor produces an
# evaluable of the form array(arg) == array(0) + darray_darg(0) arg +
# .5 SymProd(arg, d2array_darg2 arg), which produces the desired
# derivative array'(arg) == darray_darg(0) + d2array_darg2 arg.
for shift, i in enumerate(i for i, (a, b) in enumerate(util.pairwise(args)) if a == b):
i -= shift
factors[i] = SymProd(factors[i], factors.pop(i+1))

monomial = constant(values)
for factor in factors:
monomial *= factor
indexmap = map(types.arraydata, indices[array.ndim:])
monomial = Monomial(types.arraydata(values), args,
indices=tuple(tuple(next(indexmap) for _ in range(arg.ndim)) for arg in args),
powers=tuple(args[i:].count(arg) for i, arg in enumerate(args)))

if array.ndim == 0:
term = Sum(monomial)
Expand All @@ -5506,52 +5570,6 @@ def factor(array):
return util.sum(polynomial) if polynomial else zeros_like(array)


class SymProd(Array):
    '''Symmetric product.

    ``SymProd(a, b)`` evaluates to the same values as ``Multiply(a, b)``, but
    differentiates to ``2 a db`` instead of ``a db + da b``. The operation
    thereby records that both terms of the product rule are equal, and must
    only be used in situations where that is indeed the case.

    Products of more than two factors are formed by nesting in the first
    operand; the second operand is never a ``SymProd`` itself.'''

    a: Array
    b: Array

    def __post_init__(self):
        lhs, rhs = self.a, self.b
        assert not _any_certainly_different(lhs.shape, rhs.shape)
        assert lhs.dtype == rhs.dtype
        assert lhs.dtype != bool
        # Nesting is permitted in the first operand only.
        assert not isinstance(rhs, SymProd)

    @property
    def shape(self):
        return self.a.shape

    @property
    def dtype(self):
        return self.a.dtype

    @property
    def dependencies(self):
        return (self.a, self.b)

    @property
    def power(self):
        # Number of factors in the (possibly nested) symmetric product.
        if isinstance(self.a, SymProd):
            return self.a.power + 1
        return 2

    def evalf(self, a, b):
        return a * b

    def _derivative(self, var, seen):
        # All product rule terms are equal by construction, so differentiate
        # the second operand only and scale by the number of factors.
        scale = self.dtype(self.power)
        first = appendaxes(self.a, var.shape)
        dsecond = self.b._derivative(var, seen)
        return scale * first * dsecond

    def _optimized_for_numpy(self):
        # For evaluation purposes this is an ordinary product.
        return self.a * self.b

    def _argument_degree(self, argument):
        degree_a = self.a.argument_degree(argument)
        degree_b = self.b.argument_degree(argument)
        return degree_a + degree_b


# AUXILIARY FUNCTIONS (FOR INTERNAL USE)


Expand Down

0 comments on commit 8f46cfb

Please sign in to comment.