Skip to content

Commit

Permalink
separate optimize from simplify (#870)
Browse files Browse the repository at this point in the history
This PR limits the deep replacement of Evaluable._optimized_for_numpy1 to the
operations of _optimized_for_numpy, no longer involving _simplified. This
creates a logical separation between simplify and optimize and removes overlap
between the two passes. The latter (.optimized) still implies the former
(.simplified), but executes it strictly in a pre-processing pass, with the
subsequent optimization making only local adjustments to the simplified state.
The _optimized_for_numpy methods are modified to make these adjustments in a
nimble fashion rather than relying on the heavy machinery of _simplified.
  • Loading branch information
gertjanvanzwieten committed May 1, 2024
2 parents 510a284 + af3d790 commit 5c24646
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 21 deletions.
2 changes: 1 addition & 1 deletion nutils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
'Numerical Utilities for Finite Element Analysis'

__version__ = version = '9a26'
__version__ = version = '9a27'
version_name = 'jook-sing'
53 changes: 33 additions & 20 deletions nutils/evaluable.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ def optimized_for_numpy(self):

@util.deep_replace_property
def _optimized_for_numpy1(obj):
retval = obj._simplified() or obj._optimized_for_numpy()
retval = obj._optimized_for_numpy()
if retval is None:
return obj
if isinstance(obj, Array):
Expand Down Expand Up @@ -1553,8 +1553,8 @@ def _optimized_for_numpy(self):
i, j = sorted([i, factors.index(Sign(fi))])
return multiply(*factors[:i], *factors[i+1:j], *factors[j+1:], Absolute(fi))
if self.ndim:
args, args_idx = zip(*map(unalign, factors))
return Einsum(args, args_idx, tuple(range(self.ndim)))
r = tuple(range(self.ndim))
return Einsum(tuple(self.funcs), (r, r), r)

evalf = staticmethod(numpy.multiply)

Expand Down Expand Up @@ -1832,23 +1832,15 @@ def evalf(self, *args):
def _node_details(self):
return self._einsumfmt

def _simplified(self):
def _optimized_for_numpy(self):
for i, arg in enumerate(self.args):
if isinstance(arg, Transpose): # absorb `Transpose`
idx = tuple(map(self.args_idx[i].__getitem__, numpy.argsort(arg.axes)))
return Einsum(self.args[:i]+(arg.func,)+self.args[i+1:], self.args_idx[:i]+(idx,)+self.args_idx[i+1:], self.out_idx)

def _sum(self, axis):
if not (0 <= axis < self.ndim):
raise IndexError('Axis out of range.')
return Einsum(self.args, self.args_idx, self.out_idx[:axis] + self.out_idx[axis+1:])

def _takediag(self, axis1, axis2):
if not (0 <= axis1 < axis2 < self.ndim):
raise IndexError('Axis out of range.')
ikeep, irm = self.out_idx[axis1], self.out_idx[axis2]
args_idx = tuple(tuple(ikeep if i == irm else i for i in idx) for idx in self.args_idx)
return Einsum(self.args, args_idx, self.out_idx[:axis1] + self.out_idx[axis1+1:axis2] + self.out_idx[axis2+1:] + (ikeep,))
idx = tuple(self.args_idx[i][j] for j in numpy.argsort(arg.axes))
elif isinstance(arg, InsertAxis) and any(self.args_idx[i][-1] in arg_idx for arg_idx in self.args_idx[:i] + self.args_idx[i+1:]):
idx = self.args_idx[i][:-1]
else:
continue
return Einsum(self.args[:i]+(arg.func,)+self.args[i+1:], self.args_idx[:i]+(idx,)+self.args_idx[i+1:], self.out_idx)


class Sum(Array):
Expand All @@ -1864,6 +1856,17 @@ def _simplified(self):
return Take(self.func, constant(0))
return self.func._sum(self.ndim)

# Fold this sum into its operand when that operand is an Einsum, optionally
# wrapped in one or more Transpose nodes. When the unwrapped operand is not
# an Einsum, no return statement is reached (implicit None).
# NOTE(review): indentation is lost in this rendering; the nesting described
# in these comments is inferred from the statements -- confirm against the
# repository file.
def _optimized_for_numpy(self):
func = self.func
# Start from the identity permutation over the operand's axes.
axes = list(range(func.ndim))
# Unwrap nested Transpose nodes, composing each node's `axes` permutation
# into `axes` so it keeps referring to axes of the current `func`.
while isinstance(func, Transpose):
axes = [func.axes[i] for i in axes]
func = func.func
if isinstance(func, Einsum):
# Axis of the Einsum that corresponds to the operand's last axis;
# presumably the axis this node sums over -- TODO confirm.
rmaxis = axes[-1]
# Drop that axis and shift every index above it down by one, so that
# `axes` addresses the axes of the reduced Einsum.
axes = [i-(i>rmaxis) for i in axes[:-1]]
# Build an Einsum with the same arguments but with the removed axis'
# label left out of the output, then re-apply the remaining permutation.
return transpose(Einsum(func.args, func.args_idx, func.out_idx[:rmaxis] + func.out_idx[rmaxis+1:]), axes)

def _sum(self, axis):
trysum = self.func._sum(axis)
if trysum is not None:
Expand Down Expand Up @@ -1918,6 +1921,18 @@ def _simplified(self):
return Take(self.func, constant(0))
return self.func._takediag(self.ndim-1, self.ndim)

# Fold this diagonal-taking node into its operand when that operand is an
# Einsum, optionally wrapped in one or more Transpose nodes. When the
# unwrapped operand is not an Einsum, no return statement is reached
# (implicit None).
# NOTE(review): indentation is lost in this rendering; the nesting described
# in these comments is inferred from the statements -- confirm against the
# repository file.
def _optimized_for_numpy(self):
func = self.func
# Start from the identity permutation over the operand's axes.
axes = list(range(func.ndim))
# Unwrap nested Transpose nodes, composing each node's `axes` permutation
# into `axes` so it keeps referring to axes of the current `func`.
while isinstance(func, Transpose):
axes = [func.axes[i] for i in axes]
func = func.func
if isinstance(func, Einsum):
# Einsum axes corresponding to the operand's last two axes: `axis` is
# kept, `rmaxis` is removed from the output.
axis, rmaxis = axes[-2:]
# In every argument's index tuple, replace the removed axis' label by the
# kept axis' label, so both positions share a single index.
args_idx = tuple(tuple(func.out_idx[axis] if i == func.out_idx[rmaxis] else i for i in idx) for idx in func.args_idx)
# Drop the removed axis and shift every index above it down by one.
axes = [i-(i>rmaxis) for i in axes[:-1]]
# Build the relabelled Einsum without the removed output label, then
# re-apply the remaining permutation.
return transpose(Einsum(func.args, args_idx, func.out_idx[:rmaxis] + func.out_idx[rmaxis+1:]), axes)

@staticmethod
def evalf(arr):
return numpy.einsum('...kk->...k', arr, optimize=False)
Expand Down Expand Up @@ -2018,8 +2033,6 @@ def _optimized_for_numpy(self):
return Reciprocal(self.func)
elif p == -2:
return Reciprocal(self.func * self.func)
else:
return self._simplified()

evalf = staticmethod(numpy.power)

Expand Down

0 comments on commit 5c24646

Please sign in to comment.