From 770e94e40e73d3bbbe08da83f10427e39e9c8fbc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=B8rgen=20S=2E=20Dokken?= Date: Fri, 27 Sep 2024 08:09:51 +0000 Subject: [PATCH] Add documentation and motivation --- test/test_indices.py | 24 +++++++- ufl/algorithms/renumbering.py | 107 +++++++++++++--------------------- 2 files changed, 65 insertions(+), 66 deletions(-) diff --git a/test/test_indices.py b/test/test_indices.py index ec85f7aa0..d9adbd215 100755 --- a/test/test_indices.py +++ b/test/test_indices.py @@ -15,6 +15,7 @@ exp, i, indices, + interval, j, k, l, @@ -22,6 +23,9 @@ sin, triangle, ) +import ufl.algorithms +import ufl.classes +from ufl import indices from ufl.classes import IndexSum from ufl.finiteelement import FiniteElement from ufl.pullback import identity_pullback @@ -305,4 +309,22 @@ def test_spatial_derivative(self): def test_renumbering(self): - pass + """Test that kernels with common integral data, but different index numbering, are correctly renumbered.""" + cell = interval + mesh = Mesh(FiniteElement("Lagrange", cell, 1, (2, ), identity_pullback, H1)) + V = FunctionSpace(mesh, FiniteElement("Lagrange", cell, 1, (2, ), identity_pullback, H1)) + v = TestFunction(V) + u = TrialFunction(V) + i = indices(1) + a0 = u[i].dx(0)*v[i].dx(0) * ufl.dx((1)) + a1 = u[i].dx(0)*v[i].dx(0) * ufl.dx((2,3,)) + form_data = ufl.algorithms.compute_form_data( + a0+a1, + do_apply_function_pullbacks=True, + do_apply_integral_scaling=True, + do_apply_geometry_lowering=True, + preserve_geometry_types=(ufl.classes.Jacobian,), + do_apply_restrictions=True, + do_append_everywhere_integrals=False) + + assert len(form_data.integral_data) == 1 diff --git a/ufl/algorithms/renumbering.py b/ufl/algorithms/renumbering.py index ecc408b73..44b0a3b31 100644 --- a/ufl/algorithms/renumbering.py +++ b/ufl/algorithms/renumbering.py @@ -1,89 +1,66 @@ """Algorithms for renumbering of counted objects, currently variables and indices.""" -# Copyright (C) 2008-2016 Martin Sandve Alnæs and Anders Logg +# Copyright (C) 2008-2024 Martin Sandve Alnæs, Anders Logg, Jø¶gen S. Dokken and Lawerence Mitchell # # This file is part of UFL (https://www.fenicsproject.org) # # SPDX-License-Identifier: LGPL-3.0-or-later -from ufl.algorithms.transformer import ReuseTransformer, apply_transformer -from ufl.classes import Zero +from ufl.classes import Form, Integral from ufl.core.expr import Expr -from ufl.core.multiindex import FixedIndex, Index, MultiIndex -from ufl.variable import Label, Variable +from ufl.corealg.map_dag import map_expr_dag +from ufl.core.multiindex import Index +from itertools import count as _count +from ufl.corealg.multifunction import MultiFunction +from collections import defaultdict -class VariableRenumberingTransformer(ReuseTransformer): - """Variable renumbering transformer.""" - +class IndexRelabeller(MultiFunction): + """Renumber indices to have a consistent index numbering starting from 0.""" def __init__(self): - """Initialise.""" - ReuseTransformer.__init__(self) - self.variable_map = {} - - def variable(self, o): - """Apply to variable.""" - e, l = o.ufl_operands # noqa: E741 - v = self.variable_map.get(l) - if v is None: - e = self.visit(e) - l2 = Label(len(self.variable_map)) - v = Variable(e, l2) - self.variable_map[l] = v - return v - - -class IndexRenumberingTransformer(VariableRenumberingTransformer): - """Index renumbering transformer. - - This is a poorly designed algorithm. It is used in some tests, - please do not use for anything else. - """ + super().__init__() + count = _count() + self.index_cache = defaultdict(lambda: Index(next(count))) - def __init__(self): - """Initialise.""" - VariableRenumberingTransformer.__init__(self) - self.index_map = {} + expr = MultiFunction.reuse_if_untouched + + def multi_index(self, o): + """Apply to multi-indices.""" + return type(o)(tuple(self.index_cache[i] if isinstance(i, Index) else i for i in o.indices())) def zero(self, o): """Apply to zero.""" fi = o.ufl_free_indices fid = o.ufl_index_dimensions + new_indices = [self.index_cache[Index(i)].count() for i in fi] if fi == () and fid == (): return o - mapped_fi = tuple(self.index(Index(count=i)) for i in fi) - paired_fid = [(mapped_fi[pos], fid[pos]) for pos, a in enumerate(fi)] - new_fi, new_fid = zip(*tuple(sorted(paired_fid))) - return Zero(o.ufl_shape, new_fi, new_fid) + new_fi, new_fid = zip(*sorted(zip(new_indices, fid), key=lambda x: x[0])) + return type(o)(o.ufl_shape, tuple(new_fi), tuple(new_fid)) - def index(self, o): - """Apply to index.""" - if isinstance(o, FixedIndex): - return o - else: - c = o._count - i = self.index_map.get(c) - if i is None: - i = Index(count=len(self.index_map)) - self.index_map[c] = i - return i - def multi_index(self, o): - """Apply to multi_index.""" - new_indices = tuple(self.index(i) for i in o.indices()) - return MultiIndex(new_indices) -def renumber_indices(expr): - """Renumber indices.""" - if isinstance(expr, Expr): - num_free_indices = len(expr.ufl_free_indices) +def renumber_indices(form): + """Renumber indices to have a consistent index numbering starting from 0. - result = apply_transformer(expr, IndexRenumberingTransformer()) + This is useful to avoid multiple kernels for the same integrand, but with different subdomain ids. - if isinstance(expr, Expr): - if num_free_indices != len(result.ufl_free_indices): - raise ValueError( - "The number of free indices left in expression " - "should be invariant w.r.t. renumbering." - ) - return result + Args: + form: A UFL form, integral or expression. + Returns: + A new form, integral or expression with renumbered indices. + """ + if isinstance(form, Form): + new_integrals = [renumber_indices(itg) for itg in form.integrals()] + return Form(new_integrals) + elif isinstance(form, Integral): + integral = form + reindexer = IndexRelabeller() + new_integrand = map_expr_dag(reindexer, integral.integrand()) + return integral.reconstruct(new_integrand) + elif isinstance(form, Expr): + expr = form + reindexer = IndexRelabeller() + return map_expr_dag(reindexer, expr) + else: + raise ValueError(f"Invalid form type {form.__class__name}") \ No newline at end of file