diff --git a/test/test_indices.py b/test/test_indices.py index ec85f7aa0..417648095 100755 --- a/test/test_indices.py +++ b/test/test_indices.py @@ -1,5 +1,7 @@ import pytest +import ufl.algorithms +import ufl.classes from ufl import ( Argument, Coefficient, @@ -15,6 +17,7 @@ exp, i, indices, + interval, j, k, l, @@ -305,4 +308,33 @@ def test_spatial_derivative(self): def test_renumbering(self): - pass + """Test that kernels with common integral data, but different index numbering, + are correctly renumbered.""" + cell = interval + mesh = Mesh(FiniteElement("Lagrange", cell, 1, (2,), identity_pullback, H1)) + V = FunctionSpace(mesh, FiniteElement("Lagrange", cell, 1, (2,), identity_pullback, H1)) + v = TestFunction(V) + u = TrialFunction(V) + i = indices(1) + a0 = u[i].dx(0) * v[i].dx(0) * ufl.dx((1)) + a1 = ( + u[i].dx(0) + * v[i].dx(0) + * ufl.dx( + ( + 2, + 3, + ) + ) + ) + form_data = ufl.algorithms.compute_form_data( + a0 + a1, + do_apply_function_pullbacks=True, + do_apply_integral_scaling=True, + do_apply_geometry_lowering=True, + preserve_geometry_types=(ufl.classes.Jacobian,), + do_apply_restrictions=True, + do_append_everywhere_integrals=False, + ) + + assert len(form_data.integral_data) == 1 diff --git a/ufl/algorithms/domain_analysis.py b/ufl/algorithms/domain_analysis.py index 7d7c34042..19e0f4e3c 100644 --- a/ufl/algorithms/domain_analysis.py +++ b/ufl/algorithms/domain_analysis.py @@ -15,6 +15,7 @@ attach_coordinate_derivatives, strip_coordinate_derivatives, ) +from ufl.algorithms.renumbering import renumber_indices from ufl.form import Form from ufl.integral import Integral from ufl.protocols import id_or_none @@ -262,37 +263,16 @@ def build_integral_data(integrals): itgs = defaultdict(list) # --- Merge integral data that has the same integrals, - unique_integrals = defaultdict(tuple) - metadata_table = defaultdict(dict) for integral in integrals: - integrand = integral.integrand() integral_type = integral.integral_type() ufl_domain = integral.ufl_domain() - metadata = integral.metadata() - meta_hash = hash(canonicalize_metadata(metadata)) - subdomain_id = integral.subdomain_id() - subdomain_data = id_or_none(integral.subdomain_data()) - if subdomain_id == "everywhere": + subdomain_ids = integral.subdomain_id() + if "everywhere" in subdomain_ids: raise ValueError( "'everywhere' not a valid subdomain id. " "Did you forget to call group_form_integrals?" ) - unique_integrals[(integral_type, ufl_domain, meta_hash, integrand, subdomain_data)] += ( - subdomain_id, - ) - metadata_table[(integral_type, ufl_domain, meta_hash, integrand, subdomain_data)] = metadata - - for integral_data, subdomain_ids in unique_integrals.items(): - (integral_type, ufl_domain, metadata, integrand, subdomain_data) = integral_data - integral = Integral( - integrand, - integral_type, - ufl_domain, - subdomain_ids, - metadata_table[integral_data], - subdomain_data, - ) # Group for integral data (One integral data object for all # integrals with same domain, itype, (but possibly different metadata). itgs[(ufl_domain, integral_type, subdomain_ids)].append(integral) @@ -380,7 +360,39 @@ def calc_hash(cd): ) integral = attach_coordinate_derivatives(integral, samecd_integrals[0]) integrals.append(integral) - return Form(integrals) + + # Group integrals by common integrand + # u.dx(0)*dx(1) + u.dx(0)*dx(2) -> u.dx(0)*dx((1,2)) + # to avoid duplicate kernels generated after geometry lowering + unique_integrals = defaultdict(tuple) + metadata_table = defaultdict(dict) + for integral in integrals: + integral_type = integral.integral_type() + ufl_domain = integral.ufl_domain() + metadata = integral.metadata() + meta_hash = hash(canonicalize_metadata(metadata)) + subdomain_id = integral.subdomain_id() + subdomain_data = id_or_none(integral.subdomain_data()) + integrand = renumber_indices(integral.integrand()) + unique_integrals[(integral_type, ufl_domain, meta_hash, integrand, subdomain_data)] += ( + subdomain_id, + ) + metadata_table[(integral_type, ufl_domain, meta_hash, integrand, subdomain_data)] = metadata + + grouped_integrals = [] + for integral_data, subdomain_ids in unique_integrals.items(): + (integral_type, ufl_domain, metadata, integrand, subdomain_data) = integral_data + integral = Integral( + integrand, + integral_type, + ufl_domain, + subdomain_ids, + metadata_table[integral_data], + subdomain_data, + ) + grouped_integrals.append(integral) + + return Form(grouped_integrals) def reconstruct_form_from_integral_data(integral_data): diff --git a/ufl/algorithms/renumbering.py b/ufl/algorithms/renumbering.py index 87e08203d..0e0408c8d 100644 --- a/ufl/algorithms/renumbering.py +++ b/ufl/algorithms/renumbering.py @@ -1,87 +1,57 @@ """Algorithms for renumbering of counted objects, currently variables and indices.""" -# Copyright (C) 2008-2016 Martin Sandve Alnæs and Anders Logg +# Copyright (C) 2008-2024 Martin Sandve Alnæs, Anders Logg, Jørgen S. Dokken and Lawrence Mitchell # # This file is part of UFL (https://www.fenicsproject.org) # # SPDX-License-Identifier: LGPL-3.0-or-later -from ufl.algorithms.transformer import ReuseTransformer, apply_transformer -from ufl.classes import Zero -from ufl.core.expr import Expr -from ufl.core.multiindex import FixedIndex, Index, MultiIndex -from ufl.variable import Label, Variable +from collections import defaultdict +from itertools import count as _count +from ufl.algorithms.map_integrands import map_integrand_dags +from ufl.core.multiindex import Index +from ufl.corealg.multifunction import MultiFunction -class VariableRenumberingTransformer(ReuseTransformer): - """Variable renumbering transformer.""" - - def __init__(self): - """Initialise.""" - ReuseTransformer.__init__(self) - self.variable_map = {} - - def variable(self, o): - """Apply to variable.""" - e, l = o.ufl_operands # noqa: E741 - v = self.variable_map.get(l) - if v is None: - e = self.visit(e) - l2 = Label(len(self.variable_map)) - v = Variable(e, l2) - self.variable_map[l] = v - return v +class IndexRelabeller(MultiFunction): + """Renumber indices to have a consistent index numbering starting from 0.""" -class IndexRenumberingTransformer(VariableRenumberingTransformer): - """Index renumbering transformer. + def __init__(self): + """Initialize index relabeller with a zero count.""" + super().__init__() + count = _count() + self.index_cache = defaultdict(lambda: Index(next(count))) - This is a poorly designed algorithm. It is used in some tests, - please do not use for anything else. - """ + expr = MultiFunction.reuse_if_untouched - def __init__(self): - """Initialise.""" - VariableRenumberingTransformer.__init__(self) - self.index_map = {} + def multi_index(self, o): + """Apply to multi-indices.""" + return type(o)( + tuple(self.index_cache[i] if isinstance(i, Index) else i for i in o.indices()) + ) def zero(self, o): """Apply to zero.""" fi = o.ufl_free_indices fid = o.ufl_index_dimensions - mapped_fi = tuple(self.index(Index(count=i)) for i in fi) - paired_fid = [(mapped_fi[pos], fid[pos]) for pos, a in enumerate(fi)] - new_fi, new_fid = zip(*tuple(sorted(paired_fid))) - return Zero(o.ufl_shape, new_fi, new_fid) - - def index(self, o): - """Apply to index.""" - if isinstance(o, FixedIndex): + new_indices = [self.index_cache[Index(i)].count() for i in fi] + if fi == () and fid == (): return o - else: - c = o._count - i = self.index_map.get(c) - if i is None: - i = Index(count=len(self.index_map)) - self.index_map[c] = i - return i + new_fi, new_fid = zip(*sorted(zip(new_indices, fid), key=lambda x: x[0])) + return type(o)(o.ufl_shape, tuple(new_fi), tuple(new_fid)) - def multi_index(self, o): - """Apply to multi_index.""" - new_indices = tuple(self.index(i) for i in o.indices()) - return MultiIndex(new_indices) +def renumber_indices(form): + """Renumber indices to have a consistent index numbering starting from 0. -def renumber_indices(expr): - """Renumber indices.""" - if isinstance(expr, Expr): - num_free_indices = len(expr.ufl_free_indices) + This is useful to avoid multiple kernels for the same integrand, + but with different subdomain ids. - result = apply_transformer(expr, IndexRenumberingTransformer()) + Args: + form: A UFL form, integral or expression. - if isinstance(expr, Expr): - if num_free_indices != len(result.ufl_free_indices): - raise ValueError( - "The number of free indices left in expression " - "should be invariant w.r.t. renumbering." - ) - return result + Returns: + A new form, integral or expression with renumbered indices. + """ + reindexer = IndexRelabeller() + return map_integrand_dags(reindexer, form)