From af3d7908509841df3d914690ff4dcd35772aca72 Mon Sep 17 00:00:00 2001 From: Gertjan van Zwieten Date: Wed, 10 Apr 2024 16:47:51 +0200 Subject: [PATCH] separate optimize from simplify This patch limits the deep replacement of Evaluable._optimized_for_numpy1 to the operations of _optimized_for_numpy, no longer involving _simplified. This creates a logical separation between simplify and optimized and removes overlap between the two passes. The latter (.optimized) still implies the former (.simplified), but executes it strictly in a pre-processing pass, with the subsequent optimization making only local adjustments to the simplified state. The _optimized_for_numpy methods are modified to make these adjustments in a nimble fashion rather than relying on the heavy machinery of _simplified. --- nutils/evaluable.py | 53 ++++++++++++++++++++++++++++----------------- 1 file changed, 33 insertions(+), 20 deletions(-) diff --git a/nutils/evaluable.py b/nutils/evaluable.py index b39ac300e..c37c40f8f 100644 --- a/nutils/evaluable.py +++ b/nutils/evaluable.py @@ -316,7 +316,7 @@ def optimized_for_numpy(self): @util.deep_replace_property def _optimized_for_numpy1(obj): - retval = obj._simplified() or obj._optimized_for_numpy() + retval = obj._optimized_for_numpy() if retval is None: return obj if isinstance(obj, Array): @@ -1553,8 +1553,8 @@ def _optimized_for_numpy(self): i, j = sorted([i, factors.index(Sign(fi))]) return multiply(*factors[:i], *factors[i+1:j], *factors[j+1:], Absolute(fi)) if self.ndim: - args, args_idx = zip(*map(unalign, factors)) - return Einsum(args, args_idx, tuple(range(self.ndim))) + r = tuple(range(self.ndim)) + return Einsum(tuple(self.funcs), (r, r), r) evalf = staticmethod(numpy.multiply) @@ -1832,23 +1832,15 @@ def evalf(self, *args): def _node_details(self): return self._einsumfmt - def _simplified(self): + def _optimized_for_numpy(self): for i, arg in enumerate(self.args): if isinstance(arg, Transpose): # absorb `Transpose` - idx = tuple(map(self.args_idx[i].__getitem__, numpy.argsort(arg.axes))) - return Einsum(self.args[:i]+(arg.func,)+self.args[i+1:], self.args_idx[:i]+(idx,)+self.args_idx[i+1:], self.out_idx) - - def _sum(self, axis): - if not (0 <= axis < self.ndim): - raise IndexError('Axis out of range.') - return Einsum(self.args, self.args_idx, self.out_idx[:axis] + self.out_idx[axis+1:]) - - def _takediag(self, axis1, axis2): - if not (0 <= axis1 < axis2 < self.ndim): - raise IndexError('Axis out of range.') - ikeep, irm = self.out_idx[axis1], self.out_idx[axis2] - args_idx = tuple(tuple(ikeep if i == irm else i for i in idx) for idx in self.args_idx) - return Einsum(self.args, args_idx, self.out_idx[:axis1] + self.out_idx[axis1+1:axis2] + self.out_idx[axis2+1:] + (ikeep,)) + idx = tuple(self.args_idx[i][j] for j in numpy.argsort(arg.axes)) + elif isinstance(arg, InsertAxis) and any(self.args_idx[i][-1] in arg_idx for arg_idx in self.args_idx[:i] + self.args_idx[i+1:]): + idx = self.args_idx[i][:-1] + else: + continue + return Einsum(self.args[:i]+(arg.func,)+self.args[i+1:], self.args_idx[:i]+(idx,)+self.args_idx[i+1:], self.out_idx) class Sum(Array): @@ -1864,6 +1856,17 @@ def _simplified(self): return Take(self.func, constant(0)) return self.func._sum(self.ndim) + def _optimized_for_numpy(self): + func = self.func + axes = list(range(func.ndim)) + while isinstance(func, Transpose): + axes = [func.axes[i] for i in axes] + func = func.func + if isinstance(func, Einsum): + rmaxis = axes[-1] + axes = [i-(i>rmaxis) for i in axes[:-1]] + return transpose(Einsum(func.args, func.args_idx, func.out_idx[:rmaxis] + func.out_idx[rmaxis+1:]), axes) + def _sum(self, axis): trysum = self.func._sum(axis) if trysum is not None: @@ -1918,6 +1921,18 @@ def _simplified(self): return Take(self.func, constant(0)) return self.func._takediag(self.ndim-1, self.ndim) + def _optimized_for_numpy(self): + func = self.func + axes = list(range(func.ndim)) + while isinstance(func, Transpose): + axes = [func.axes[i] for i in axes] + func = func.func + if isinstance(func, Einsum): + axis, rmaxis = axes[-2:] + args_idx = tuple(tuple(func.out_idx[axis] if i == func.out_idx[rmaxis] else i for i in idx) for idx in func.args_idx) + axes = [i-(i>rmaxis) for i in axes[:-1]] + return transpose(Einsum(func.args, args_idx, func.out_idx[:rmaxis] + func.out_idx[rmaxis+1:]), axes) + @staticmethod def evalf(arr): return numpy.einsum('...kk->...k', arr, optimize=False) @@ -2018,8 +2033,6 @@ def _optimized_for_numpy(self): return Reciprocal(self.func) elif p == -2: return Reciprocal(self.func * self.func) - else: - return self._simplified() evalf = staticmethod(numpy.power)