From ce0adef481f90dfe84103792f3baa779bf35b640 Mon Sep 17 00:00:00 2001 From: jorgensd Date: Thu, 19 Sep 2024 17:53:29 +0000 Subject: [PATCH 01/14] Move grouping of integral from build integral data to group form integrals --- ufl/algorithms/domain_analysis.py | 53 ++++++++++++++++--------------- 1 file changed, 28 insertions(+), 25 deletions(-) diff --git a/ufl/algorithms/domain_analysis.py b/ufl/algorithms/domain_analysis.py index 7d7c34042..7c21ceab8 100644 --- a/ufl/algorithms/domain_analysis.py +++ b/ufl/algorithms/domain_analysis.py @@ -262,37 +262,17 @@ def build_integral_data(integrals): itgs = defaultdict(list) # --- Merge integral data that has the same integrals, - unique_integrals = defaultdict(tuple) - metadata_table = defaultdict(dict) for integral in integrals: - integrand = integral.integrand() integral_type = integral.integral_type() ufl_domain = integral.ufl_domain() - metadata = integral.metadata() - meta_hash = hash(canonicalize_metadata(metadata)) - subdomain_id = integral.subdomain_id() - subdomain_data = id_or_none(integral.subdomain_data()) - if subdomain_id == "everywhere": + subdomain_ids = integral.subdomain_id() + if "everywhere" in subdomain_ids: + raise ValueError( "'everywhere' not a valid subdomain id. " "Did you forget to call group_form_integrals?" ) - unique_integrals[(integral_type, ufl_domain, meta_hash, integrand, subdomain_data)] += ( - subdomain_id, - ) - metadata_table[(integral_type, ufl_domain, meta_hash, integrand, subdomain_data)] = metadata - - for integral_data, subdomain_ids in unique_integrals.items(): - (integral_type, ufl_domain, metadata, integrand, subdomain_data) = integral_data - - integral = Integral( - integrand, - integral_type, - ufl_domain, - subdomain_ids, - metadata_table[integral_data], - subdomain_data, - ) + # Group for integral data (One integral data object for all # integrals with same domain, itype, (but possibly different metadata). itgs[(ufl_domain, integral_type, subdomain_ids)].append(integral) @@ -380,7 +360,30 @@ def calc_hash(cd): ) integral = attach_coordinate_derivatives(integral, samecd_integrals[0]) integrals.append(integral) - return Form(integrals) + + # Group integrals by common subdomain id + # u.dx(0)*dx(1) + u.dx(0)*dx(2) -> u.dx(0)*dx((1,2)) + # to avoid duplicate kernels generated after geoemtry lowering + unique_integrals = defaultdict(tuple) + metadata_table = defaultdict(dict) + for integral in integrals: + integral_type = integral.integral_type() + ufl_domain = integral.ufl_domain() + metadata = integral.metadata() + meta_hash = hash(canonicalize_metadata(metadata)) + subdomain_id = integral.subdomain_id() + subdomain_data = id_or_none(integral.subdomain_data()) + unique_integrals[(integral_type, ufl_domain, meta_hash, integrand, subdomain_data)] += (subdomain_id,) + metadata_table[(integral_type, ufl_domain, meta_hash, integrand, subdomain_data)] = metadata + + grouped_integrals = [] + for integral_data, subdomain_ids in unique_integrals.items(): + (integral_type, ufl_domain, metadata, integrand, subdomain_data) = integral_data + integral = Integral(integrand, integral_type, ufl_domain, subdomain_ids, + metadata_table[integral_data], subdomain_data) + grouped_integrals.append(integral) + + return Form(grouped_integrals) def reconstruct_form_from_integral_data(integral_data): From e4ca24d78ee96e625bb8a14d542f548ce9339ba4 Mon Sep 17 00:00:00 2001 From: jorgensd Date: Thu, 19 Sep 2024 17:56:20 +0000 Subject: [PATCH 02/14] Ruff formatting --- ufl/algorithms/domain_analysis.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/ufl/algorithms/domain_analysis.py b/ufl/algorithms/domain_analysis.py index 7c21ceab8..02be68038 100644 --- a/ufl/algorithms/domain_analysis.py +++ b/ufl/algorithms/domain_analysis.py @@ -267,12 +267,11 @@ def build_integral_data(integrals): ufl_domain = integral.ufl_domain() subdomain_ids = integral.subdomain_id() if "everywhere" in subdomain_ids: - raise ValueError( "'everywhere' not a valid subdomain id. " "Did you forget to call group_form_integrals?" ) - + # Group for integral data (One integral data object for all # integrals with same domain, itype, (but possibly different metadata). itgs[(ufl_domain, integral_type, subdomain_ids)].append(integral) @@ -373,14 +372,22 @@ def calc_hash(cd): meta_hash = hash(canonicalize_metadata(metadata)) subdomain_id = integral.subdomain_id() subdomain_data = id_or_none(integral.subdomain_data()) - unique_integrals[(integral_type, ufl_domain, meta_hash, integrand, subdomain_data)] += (subdomain_id,) + unique_integrals[(integral_type, ufl_domain, meta_hash, integrand, subdomain_data)] += ( + subdomain_id, + ) metadata_table[(integral_type, ufl_domain, meta_hash, integrand, subdomain_data)] = metadata - + grouped_integrals = [] for integral_data, subdomain_ids in unique_integrals.items(): (integral_type, ufl_domain, metadata, integrand, subdomain_data) = integral_data - integral = Integral(integrand, integral_type, ufl_domain, subdomain_ids, - metadata_table[integral_data], subdomain_data) + integral = Integral( + integrand, + integral_type, + ufl_domain, + subdomain_ids, + metadata_table[integral_data], + subdomain_data, + ) grouped_integrals.append(integral) return Form(grouped_integrals) From 136d27d0d05fac090723d44c5333e5f7c9781945 Mon Sep 17 00:00:00 2001 From: jorgensd Date: Thu, 19 Sep 2024 18:07:43 +0000 Subject: [PATCH 03/14] Add missing key --- ufl/algorithms/domain_analysis.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ufl/algorithms/domain_analysis.py b/ufl/algorithms/domain_analysis.py index 02be68038..64187e7ba 100644 --- a/ufl/algorithms/domain_analysis.py +++ b/ufl/algorithms/domain_analysis.py @@ -372,6 +372,7 @@ def calc_hash(cd): meta_hash = hash(canonicalize_metadata(metadata)) subdomain_id = integral.subdomain_id() subdomain_data = id_or_none(integral.subdomain_data()) + integrand = integral.integrand() unique_integrals[(integral_type, ufl_domain, meta_hash, integrand, subdomain_data)] += ( subdomain_id, ) From 98959717a84a0f6ef5a7ae5633f091529c36a6a3 Mon Sep 17 00:00:00 2001 From: jorgensd Date: Thu, 19 Sep 2024 18:16:22 +0000 Subject: [PATCH 04/14] Add index renumbering --- ufl/algorithms/domain_analysis.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ufl/algorithms/domain_analysis.py b/ufl/algorithms/domain_analysis.py index 64187e7ba..2acee2a93 100644 --- a/ufl/algorithms/domain_analysis.py +++ b/ufl/algorithms/domain_analysis.py @@ -15,6 +15,7 @@ attach_coordinate_derivatives, strip_coordinate_derivatives, ) +from ufl.algorithms.renumbering import renumber_indices from ufl.form import Form from ufl.integral import Integral from ufl.protocols import id_or_none @@ -372,7 +373,7 @@ def calc_hash(cd): meta_hash = hash(canonicalize_metadata(metadata)) subdomain_id = integral.subdomain_id() subdomain_data = id_or_none(integral.subdomain_data()) - integrand = integral.integrand() + integrand = renumber_indices(integral.integrand()) unique_integrals[(integral_type, ufl_domain, meta_hash, integrand, subdomain_data)] += ( subdomain_id, ) From 270801edf791f47fe9a095c543761938bda10845 Mon Sep 17 00:00:00 2001 From: jorgensd Date: Thu, 19 Sep 2024 18:29:08 +0000 Subject: [PATCH 05/14] Fix zero index-renumbering for scalar zero --- ufl/algorithms/renumbering.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ufl/algorithms/renumbering.py b/ufl/algorithms/renumbering.py index 87e08203d..f4e9b59ba 100644 --- a/ufl/algorithms/renumbering.py +++ b/ufl/algorithms/renumbering.py @@ -48,6 +48,8 @@ def zero(self, o): """Apply to zero.""" fi = o.ufl_free_indices fid = o.ufl_index_dimensions + if fi == () and fid ==(): + return o mapped_fi = tuple(self.index(Index(count=i)) for i in fi) paired_fid = [(mapped_fi[pos], fid[pos]) for pos, a in enumerate(fi)] new_fi, new_fid = zip(*tuple(sorted(paired_fid))) From fd6f814e204d34a694cd193f85f06204e902fe04 Mon Sep 17 00:00:00 2001 From: jorgensd Date: Thu, 19 Sep 2024 18:30:01 +0000 Subject: [PATCH 06/14] Ruff format --- ufl/algorithms/renumbering.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ufl/algorithms/renumbering.py b/ufl/algorithms/renumbering.py index f4e9b59ba..ecc408b73 100644 --- a/ufl/algorithms/renumbering.py +++ b/ufl/algorithms/renumbering.py @@ -48,7 +48,7 @@ def zero(self, o): """Apply to zero.""" fi = o.ufl_free_indices fid = o.ufl_index_dimensions - if fi == () and fid ==(): + if fi == () and fid == (): return o mapped_fi = tuple(self.index(Index(count=i)) for i in fi) paired_fid = [(mapped_fi[pos], fid[pos]) for pos, a in enumerate(fi)] From 55ccc7336376047f336f5af74887cb9d823a5fdf Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Tue, 24 Sep 2024 13:22:34 +0100 Subject: [PATCH 07/14] Update ufl/algorithms/domain_analysis.py --- ufl/algorithms/domain_analysis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ufl/algorithms/domain_analysis.py b/ufl/algorithms/domain_analysis.py index 2acee2a93..b90de1b64 100644 --- a/ufl/algorithms/domain_analysis.py +++ b/ufl/algorithms/domain_analysis.py @@ -363,7 +363,7 @@ def calc_hash(cd): # Group integrals by common subdomain id # u.dx(0)*dx(1) + u.dx(0)*dx(2) -> u.dx(0)*dx((1,2)) - # to avoid duplicate kernels generated after geoemtry lowering + # to avoid duplicate kernels generated after geometry lowering unique_integrals = defaultdict(tuple) metadata_table = defaultdict(dict) for integral in integrals: From 770e94e40e73d3bbbe08da83f10427e39e9c8fbc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=B8rgen=20S=2E=20Dokken?= Date: Fri, 27 Sep 2024 08:09:51 +0000 Subject: [PATCH 08/14] Add documentation and motivation --- test/test_indices.py | 24 +++++++- ufl/algorithms/renumbering.py | 107 +++++++++++++--------------------- 2 files changed, 65 insertions(+), 66 deletions(-) diff --git a/test/test_indices.py b/test/test_indices.py index ec85f7aa0..d9adbd215 100755 --- a/test/test_indices.py +++ b/test/test_indices.py @@ -15,6 +15,7 @@ exp, i, indices, + interval, j, k, l, @@ -22,6 +23,9 @@ sin, triangle, ) +import ufl.algorithms +import ufl.classes +from ufl import indices from ufl.classes import IndexSum from ufl.finiteelement import FiniteElement from ufl.pullback import identity_pullback @@ -305,4 +309,22 @@ def test_spatial_derivative(self): def test_renumbering(self): - pass + """Test that kernels with common integral data, but different index numbering, are correctly renumbered.""" + cell = interval + mesh = Mesh(FiniteElement("Lagrange", cell, 1, (2, ), identity_pullback, H1)) + V = FunctionSpace(mesh, FiniteElement("Lagrange", cell, 1, (2, ), identity_pullback, H1)) + v = TestFunction(V) + u = TrialFunction(V) + i = indices(1) + a0 = u[i].dx(0)*v[i].dx(0) * ufl.dx((1)) + a1 = u[i].dx(0)*v[i].dx(0) * ufl.dx((2,3,)) + form_data = ufl.algorithms.compute_form_data( + a0+a1, + do_apply_function_pullbacks=True, + do_apply_integral_scaling=True, + do_apply_geometry_lowering=True, + preserve_geometry_types=(ufl.classes.Jacobian,), + do_apply_restrictions=True, + do_append_everywhere_integrals=False) + + assert len(form_data.integral_data) == 1 diff --git a/ufl/algorithms/renumbering.py b/ufl/algorithms/renumbering.py index ecc408b73..44b0a3b31 100644 --- a/ufl/algorithms/renumbering.py +++ b/ufl/algorithms/renumbering.py @@ -1,89 +1,66 @@ """Algorithms for renumbering of counted objects, currently variables and indices.""" -# Copyright (C) 2008-2016 Martin Sandve Alnæs and Anders Logg +# Copyright (C) 2008-2024 Martin Sandve Alnæs, Anders Logg, Jø¶gen S. Dokken and Lawerence Mitchell # # This file is part of UFL (https://www.fenicsproject.org) # # SPDX-License-Identifier: LGPL-3.0-or-later -from ufl.algorithms.transformer import ReuseTransformer, apply_transformer -from ufl.classes import Zero +from ufl.classes import Form, Integral from ufl.core.expr import Expr -from ufl.core.multiindex import FixedIndex, Index, MultiIndex -from ufl.variable import Label, Variable +from ufl.corealg.map_dag import map_expr_dag +from ufl.core.multiindex import Index +from itertools import count as _count +from ufl.corealg.multifunction import MultiFunction +from collections import defaultdict -class VariableRenumberingTransformer(ReuseTransformer): - """Variable renumbering transformer.""" - +class IndexRelabeller(MultiFunction): + """Renumber indices to have a consistent index numbering starting from 0.""" def __init__(self): - """Initialise.""" - ReuseTransformer.__init__(self) - self.variable_map = {} - - def variable(self, o): - """Apply to variable.""" - e, l = o.ufl_operands # noqa: E741 - v = self.variable_map.get(l) - if v is None: - e = self.visit(e) - l2 = Label(len(self.variable_map)) - v = Variable(e, l2) - self.variable_map[l] = v - return v - - -class IndexRenumberingTransformer(VariableRenumberingTransformer): - """Index renumbering transformer. - - This is a poorly designed algorithm. It is used in some tests, - please do not use for anything else. - """ + super().__init__() + count = _count() + self.index_cache = defaultdict(lambda: Index(next(count))) - def __init__(self): - """Initialise.""" - VariableRenumberingTransformer.__init__(self) - self.index_map = {} + expr = MultiFunction.reuse_if_untouched + + def multi_index(self, o): + """Apply to multi-indices.""" + return type(o)(tuple(self.index_cache[i] if isinstance(i, Index) else i for i in o.indices())) def zero(self, o): """Apply to zero.""" fi = o.ufl_free_indices fid = o.ufl_index_dimensions + new_indices = [self.index_cache[Index(i)].count() for i in fi] if fi == () and fid == (): return o - mapped_fi = tuple(self.index(Index(count=i)) for i in fi) - paired_fid = [(mapped_fi[pos], fid[pos]) for pos, a in enumerate(fi)] - new_fi, new_fid = zip(*tuple(sorted(paired_fid))) - return Zero(o.ufl_shape, new_fi, new_fid) + new_fi, new_fid = zip(*sorted(zip(new_indices, fid), key=lambda x: x[0])) + return type(o)(o.ufl_shape, tuple(new_fi), tuple(new_fid)) - def index(self, o): - """Apply to index.""" - if isinstance(o, FixedIndex): - return o - else: - c = o._count - i = self.index_map.get(c) - if i is None: - i = Index(count=len(self.index_map)) - self.index_map[c] = i - return i - def multi_index(self, o): - """Apply to multi_index.""" - new_indices = tuple(self.index(i) for i in o.indices()) - return MultiIndex(new_indices) -def renumber_indices(expr): - """Renumber indices.""" - if isinstance(expr, Expr): - num_free_indices = len(expr.ufl_free_indices) +def renumber_indices(form): + """Renumber indices to have a consistent index numbering starting from 0. - result = apply_transformer(expr, IndexRenumberingTransformer()) + This is useful to avoid multiple kernels for the same integrand, but with different subdomain ids. - if isinstance(expr, Expr): - if num_free_indices != len(result.ufl_free_indices): - raise ValueError( - "The number of free indices left in expression " - "should be invariant w.r.t. renumbering." - ) - return result + Args: + form: A UFL form, integral or expression. + Returns: + A new form, integral or expression with renumbered indices. + """ + if isinstance(form, Form): + new_integrals = [renumber_indices(itg) for itg in form.integrals()] + return Form(new_integrals) + elif isinstance(form, Integral): + integral = form + reindexer = IndexRelabeller() + new_integrand = map_expr_dag(reindexer, integral.integrand()) + return integral.reconstruct(new_integrand) + elif isinstance(form, Expr): + expr = form + reindexer = IndexRelabeller() + return map_expr_dag(reindexer, expr) + else: + raise ValueError(f"Invalid form type {form.__class__name}") \ No newline at end of file From 367bede82116b5e37c485eb6e2892ced3e2c220d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=B8rgen=20S=2E=20Dokken?= Date: Fri, 27 Sep 2024 08:12:07 +0000 Subject: [PATCH 09/14] Doc fixes --- test/test_indices.py | 42 ++++++++++++++++++++++------------- ufl/algorithms/renumbering.py | 29 ++++++++++++++---------- 2 files changed, 43 insertions(+), 28 deletions(-) diff --git a/test/test_indices.py b/test/test_indices.py index d9adbd215..417648095 100755 --- a/test/test_indices.py +++ b/test/test_indices.py @@ -1,5 +1,7 @@ import pytest +import ufl.algorithms +import ufl.classes from ufl import ( Argument, Coefficient, @@ -23,9 +25,6 @@ sin, triangle, ) -import ufl.algorithms -import ufl.classes -from ufl import indices from ufl.classes import IndexSum from ufl.finiteelement import FiniteElement from ufl.pullback import identity_pullback @@ -309,22 +308,33 @@ def test_spatial_derivative(self): def test_renumbering(self): - """Test that kernels with common integral data, but different index numbering, are correctly renumbered.""" + """Test that kernels with common integral data, but different index numbering, + are correctly renumbered.""" cell = interval - mesh = Mesh(FiniteElement("Lagrange", cell, 1, (2, ), identity_pullback, H1)) - V = FunctionSpace(mesh, FiniteElement("Lagrange", cell, 1, (2, ), identity_pullback, H1)) + mesh = Mesh(FiniteElement("Lagrange", cell, 1, (2,), identity_pullback, H1)) + V = FunctionSpace(mesh, FiniteElement("Lagrange", cell, 1, (2,), identity_pullback, H1)) v = TestFunction(V) u = TrialFunction(V) i = indices(1) - a0 = u[i].dx(0)*v[i].dx(0) * ufl.dx((1)) - a1 = u[i].dx(0)*v[i].dx(0) * ufl.dx((2,3,)) + a0 = u[i].dx(0) * v[i].dx(0) * ufl.dx((1)) + a1 = ( + u[i].dx(0) + * v[i].dx(0) + * ufl.dx( + ( + 2, + 3, + ) + ) + ) form_data = ufl.algorithms.compute_form_data( - a0+a1, - do_apply_function_pullbacks=True, - do_apply_integral_scaling=True, - do_apply_geometry_lowering=True, - preserve_geometry_types=(ufl.classes.Jacobian,), - do_apply_restrictions=True, - do_append_everywhere_integrals=False) - + a0 + a1, + do_apply_function_pullbacks=True, + do_apply_integral_scaling=True, + do_apply_geometry_lowering=True, + preserve_geometry_types=(ufl.classes.Jacobian,), + do_apply_restrictions=True, + do_append_everywhere_integrals=False, + ) + assert len(form_data.integral_data) == 1 diff --git a/ufl/algorithms/renumbering.py b/ufl/algorithms/renumbering.py index 44b0a3b31..cc99016f3 100644 --- a/ufl/algorithms/renumbering.py +++ b/ufl/algorithms/renumbering.py @@ -5,27 +5,32 @@ # # SPDX-License-Identifier: LGPL-3.0-or-later +from collections import defaultdict +from itertools import count as _count + from ufl.classes import Form, Integral from ufl.core.expr import Expr -from ufl.corealg.map_dag import map_expr_dag - from ufl.core.multiindex import Index -from itertools import count as _count +from ufl.corealg.map_dag import map_expr_dag from ufl.corealg.multifunction import MultiFunction -from collections import defaultdict + class IndexRelabeller(MultiFunction): """Renumber indices to have a consistent index numbering starting from 0.""" + def __init__(self): + """Initialize index relabeller with a zero count.""" super().__init__() count = _count() self.index_cache = defaultdict(lambda: Index(next(count))) expr = MultiFunction.reuse_if_untouched - + def multi_index(self, o): """Apply to multi-indices.""" - return type(o)(tuple(self.index_cache[i] if isinstance(i, Index) else i for i in o.indices())) + return type(o)( + tuple(self.index_cache[i] if isinstance(i, Index) else i for i in o.indices()) + ) def zero(self, o): """Apply to zero.""" @@ -38,17 +43,17 @@ def zero(self, o): return type(o)(o.ufl_shape, tuple(new_fi), tuple(new_fid)) - - def renumber_indices(form): """Renumber indices to have a consistent index numbering starting from 0. - This is useful to avoid multiple kernels for the same integrand, but with different subdomain ids. + This is useful to avoid multiple kernels for the same integrand, + but with different subdomain ids. Args: form: A UFL form, integral or expression. + Returns: - A new form, integral or expression with renumbered indices. + A new form, integral or expression with renumbered indices. """ if isinstance(form, Form): new_integrals = [renumber_indices(itg) for itg in form.integrals()] @@ -57,10 +62,10 @@ def renumber_indices(form): integral = form reindexer = IndexRelabeller() new_integrand = map_expr_dag(reindexer, integral.integrand()) - return integral.reconstruct(new_integrand) + return integral.reconstruct(new_integrand) elif isinstance(form, Expr): expr = form reindexer = IndexRelabeller() return map_expr_dag(reindexer, expr) else: - raise ValueError(f"Invalid form type {form.__class__name}") \ No newline at end of file + raise ValueError(f"Invalid form type {form.__class__name}") From d405a4ada26fcd79f3f3a141f76b784787eb5dc2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=B8rgen=20Schartum=20Dokken?= Date: Fri, 27 Sep 2024 10:17:57 +0200 Subject: [PATCH 10/14] Fix name --- ufl/algorithms/renumbering.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ufl/algorithms/renumbering.py b/ufl/algorithms/renumbering.py index cc99016f3..ca1a543f2 100644 --- a/ufl/algorithms/renumbering.py +++ b/ufl/algorithms/renumbering.py @@ -1,5 +1,5 @@ """Algorithms for renumbering of counted objects, currently variables and indices.""" -# Copyright (C) 2008-2024 Martin Sandve Alnæs, Anders Logg, Jø¶gen S. Dokken and Lawerence Mitchell +# Copyright (C) 2008-2024 Martin Sandve Alnæs, Anders Logg, Jørgen S. Dokken and Lawerence Mitchell # # This file is part of UFL (https://www.fenicsproject.org) # From be8786d16a5927cd165e4d1e4fc87d7d6bdeebda Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=B8rgen=20Schartum=20Dokken?= Date: Fri, 27 Sep 2024 13:15:48 +0200 Subject: [PATCH 11/14] Fix lawrence name --- ufl/algorithms/renumbering.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ufl/algorithms/renumbering.py b/ufl/algorithms/renumbering.py index ca1a543f2..d44873928 100644 --- a/ufl/algorithms/renumbering.py +++ b/ufl/algorithms/renumbering.py @@ -1,5 +1,5 @@ """Algorithms for renumbering of counted objects, currently variables and indices.""" -# Copyright (C) 2008-2024 Martin Sandve Alnæs, Anders Logg, Jørgen S. Dokken and Lawerence Mitchell +# Copyright (C) 2008-2024 Martin Sandve Alnæs, Anders Logg, Jørgen S. Dokken and Lawrence Mitchell # # This file is part of UFL (https://www.fenicsproject.org) # From 84fb2c4b9a14ed78298cdc168a4a1c0016fb6596 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=B8rgen=20Schartum=20Dokken?= Date: Wed, 2 Oct 2024 17:14:13 +0200 Subject: [PATCH 12/14] Apply suggestions from code review Co-authored-by: Joe Dean --- ufl/algorithms/domain_analysis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ufl/algorithms/domain_analysis.py b/ufl/algorithms/domain_analysis.py index b90de1b64..19e0f4e3c 100644 --- a/ufl/algorithms/domain_analysis.py +++ b/ufl/algorithms/domain_analysis.py @@ -361,7 +361,7 @@ def calc_hash(cd): integral = attach_coordinate_derivatives(integral, samecd_integrals[0]) integrals.append(integral) - # Group integrals by common subdomain id + # Group integrals by common integrand # u.dx(0)*dx(1) + u.dx(0)*dx(2) -> u.dx(0)*dx((1,2)) # to avoid duplicate kernels generated after geometry lowering unique_integrals = defaultdict(tuple) From ac30a3e15f4a5ec548f9fd8d70e17404f2320b40 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=B8rgen=20S=2E=20Dokken?= Date: Wed, 2 Oct 2024 15:23:29 +0000 Subject: [PATCH 13/14] Simplify using Lawrence instructions --- ufl/algorithms/renumbering.py | 22 ++++------------------ 1 file changed, 4 insertions(+), 18 deletions(-) diff --git a/ufl/algorithms/renumbering.py b/ufl/algorithms/renumbering.py index d44873928..d205f5c41 100644 --- a/ufl/algorithms/renumbering.py +++ b/ufl/algorithms/renumbering.py @@ -8,12 +8,9 @@ from collections import defaultdict from itertools import count as _count -from ufl.classes import Form, Integral -from ufl.core.expr import Expr from ufl.core.multiindex import Index -from ufl.corealg.map_dag import map_expr_dag from ufl.corealg.multifunction import MultiFunction - +from ufl.algorithms.map_integrands import map_integrand_dags class IndexRelabeller(MultiFunction): """Renumber indices to have a consistent index numbering starting from 0.""" @@ -55,17 +52,6 @@ def renumber_indices(form): Returns: A new form, integral or expression with renumbered indices. """ - if isinstance(form, Form): - new_integrals = [renumber_indices(itg) for itg in form.integrals()] - return Form(new_integrals) - elif isinstance(form, Integral): - integral = form - reindexer = IndexRelabeller() - new_integrand = map_expr_dag(reindexer, integral.integrand()) - return integral.reconstruct(new_integrand) - elif isinstance(form, Expr): - expr = form - reindexer = IndexRelabeller() - return map_expr_dag(reindexer, expr) - else: - raise ValueError(f"Invalid form type {form.__class__name}") + + reindexer = IndexRelabeller() + return map_integrand_dags(reindexer, form) \ No newline at end of file From a86d573b5f2d9fb232b2b55f6d6d34d8f5a2212f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=B8rgen=20S=2E=20Dokken?= Date: Wed, 2 Oct 2024 15:24:36 +0000 Subject: [PATCH 14/14] Ruff formatting --- ufl/algorithms/renumbering.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ufl/algorithms/renumbering.py b/ufl/algorithms/renumbering.py index d205f5c41..0e0408c8d 100644 --- a/ufl/algorithms/renumbering.py +++ b/ufl/algorithms/renumbering.py @@ -8,9 +8,10 @@ from collections import defaultdict from itertools import count as _count +from ufl.algorithms.map_integrands import map_integrand_dags from ufl.core.multiindex import Index from ufl.corealg.multifunction import MultiFunction -from ufl.algorithms.map_integrands import map_integrand_dags + class IndexRelabeller(MultiFunction): """Renumber indices to have a consistent index numbering starting from 0.""" @@ -52,6 +53,5 @@ def renumber_indices(form): Returns: A new form, integral or expression with renumbered indices. """ - reindexer = IndexRelabeller() - return map_integrand_dags(reindexer, form) \ No newline at end of file + return map_integrand_dags(reindexer, form)