Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move grouping of integral from build_integral_data to group_form_integrals #305

Merged
34 changes: 33 additions & 1 deletion test/test_indices.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import pytest

import ufl.algorithms
import ufl.classes
from ufl import (
Argument,
Coefficient,
Expand All @@ -15,6 +17,7 @@
exp,
i,
indices,
interval,
j,
k,
l,
Expand Down Expand Up @@ -305,4 +308,33 @@ def test_spatial_derivative(self):


def test_renumbering(self):
pass
"""Test that kernels with common integral data, but different index numbering,
are correctly renumbered."""
cell = interval
mesh = Mesh(FiniteElement("Lagrange", cell, 1, (2,), identity_pullback, H1))
V = FunctionSpace(mesh, FiniteElement("Lagrange", cell, 1, (2,), identity_pullback, H1))
v = TestFunction(V)
u = TrialFunction(V)
i = indices(1)
a0 = u[i].dx(0) * v[i].dx(0) * ufl.dx((1))
a1 = (
u[i].dx(0)
* v[i].dx(0)
* ufl.dx(
(
2,
3,
)
)
)
form_data = ufl.algorithms.compute_form_data(
a0 + a1,
do_apply_function_pullbacks=True,
do_apply_integral_scaling=True,
do_apply_geometry_lowering=True,
preserve_geometry_types=(ufl.classes.Jacobian,),
do_apply_restrictions=True,
do_append_everywhere_integrals=False,
)

assert len(form_data.integral_data) == 1
60 changes: 36 additions & 24 deletions ufl/algorithms/domain_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
attach_coordinate_derivatives,
strip_coordinate_derivatives,
)
from ufl.algorithms.renumbering import renumber_indices
from ufl.form import Form
from ufl.integral import Integral
from ufl.protocols import id_or_none
Expand Down Expand Up @@ -262,37 +263,16 @@ def build_integral_data(integrals):
itgs = defaultdict(list)

# --- Merge integral data that has the same integrals,
unique_integrals = defaultdict(tuple)
metadata_table = defaultdict(dict)
for integral in integrals:
integrand = integral.integrand()
integral_type = integral.integral_type()
ufl_domain = integral.ufl_domain()
metadata = integral.metadata()
meta_hash = hash(canonicalize_metadata(metadata))
subdomain_id = integral.subdomain_id()
subdomain_data = id_or_none(integral.subdomain_data())
if subdomain_id == "everywhere":
subdomain_ids = integral.subdomain_id()
if "everywhere" in subdomain_ids:
raise ValueError(
"'everywhere' not a valid subdomain id. "
"Did you forget to call group_form_integrals?"
)
unique_integrals[(integral_type, ufl_domain, meta_hash, integrand, subdomain_data)] += (
subdomain_id,
)
metadata_table[(integral_type, ufl_domain, meta_hash, integrand, subdomain_data)] = metadata

for integral_data, subdomain_ids in unique_integrals.items():
(integral_type, ufl_domain, metadata, integrand, subdomain_data) = integral_data

integral = Integral(
integrand,
integral_type,
ufl_domain,
subdomain_ids,
metadata_table[integral_data],
subdomain_data,
)
# Group for integral data (One integral data object for all
# integrals with same domain, itype, (but possibly different metadata).
itgs[(ufl_domain, integral_type, subdomain_ids)].append(integral)
Expand Down Expand Up @@ -380,7 +360,39 @@ def calc_hash(cd):
)
integral = attach_coordinate_derivatives(integral, samecd_integrals[0])
integrals.append(integral)
return Form(integrals)

# Group integrals by common subdomain id
jorgensd marked this conversation as resolved.
Show resolved Hide resolved
# u.dx(0)*dx(1) + u.dx(0)*dx(2) -> u.dx(0)*dx((1,2))
# to avoid duplicate kernels generated after geometry lowering
unique_integrals = defaultdict(tuple)
metadata_table = defaultdict(dict)
for integral in integrals:
integral_type = integral.integral_type()
ufl_domain = integral.ufl_domain()
metadata = integral.metadata()
meta_hash = hash(canonicalize_metadata(metadata))
subdomain_id = integral.subdomain_id()
subdomain_data = id_or_none(integral.subdomain_data())
integrand = renumber_indices(integral.integrand())
jorgensd marked this conversation as resolved.
Show resolved Hide resolved
unique_integrals[(integral_type, ufl_domain, meta_hash, integrand, subdomain_data)] += (
subdomain_id,
)
metadata_table[(integral_type, ufl_domain, meta_hash, integrand, subdomain_data)] = metadata

grouped_integrals = []
for integral_data, subdomain_ids in unique_integrals.items():
(integral_type, ufl_domain, metadata, integrand, subdomain_data) = integral_data
integral = Integral(
integrand,
integral_type,
ufl_domain,
subdomain_ids,
metadata_table[integral_data],
subdomain_data,
)
grouped_integrals.append(integral)

return Form(grouped_integrals)


def reconstruct_form_from_integral_data(integral_data):
Expand Down
110 changes: 47 additions & 63 deletions ufl/algorithms/renumbering.py
Original file line number Diff line number Diff line change
@@ -1,87 +1,71 @@
"""Algorithms for renumbering of counted objects, currently variables and indices."""
# Copyright (C) 2008-2016 Martin Sandve Alnæs and Anders Logg
# Copyright (C) 2008-2024 Martin Sandve Alnæs, Anders Logg, Jørgen S. Dokken and Lawerence Mitchell
jorgensd marked this conversation as resolved.
Show resolved Hide resolved
#
# This file is part of UFL (https://www.fenicsproject.org)
#
# SPDX-License-Identifier: LGPL-3.0-or-later

from ufl.algorithms.transformer import ReuseTransformer, apply_transformer
from ufl.classes import Zero
from collections import defaultdict
from itertools import count as _count

from ufl.classes import Form, Integral
from ufl.core.expr import Expr
from ufl.core.multiindex import FixedIndex, Index, MultiIndex
from ufl.variable import Label, Variable
from ufl.core.multiindex import Index
from ufl.corealg.map_dag import map_expr_dag
from ufl.corealg.multifunction import MultiFunction


class VariableRenumberingTransformer(ReuseTransformer):
"""Variable renumbering transformer."""
class IndexRelabeller(MultiFunction):
"""Renumber indices to have a consistent index numbering starting from 0."""

def __init__(self):
"""Initialise."""
ReuseTransformer.__init__(self)
self.variable_map = {}

def variable(self, o):
"""Apply to variable."""
e, l = o.ufl_operands # noqa: E741
v = self.variable_map.get(l)
if v is None:
e = self.visit(e)
l2 = Label(len(self.variable_map))
v = Variable(e, l2)
self.variable_map[l] = v
return v

"""Initialize index relabeller with a zero count."""
super().__init__()
count = _count()
self.index_cache = defaultdict(lambda: Index(next(count)))

class IndexRenumberingTransformer(VariableRenumberingTransformer):
"""Index renumbering transformer.

This is a poorly designed algorithm. It is used in some tests,
please do not use for anything else.
"""
expr = MultiFunction.reuse_if_untouched

def __init__(self):
"""Initialise."""
VariableRenumberingTransformer.__init__(self)
self.index_map = {}
def multi_index(self, o):
"""Apply to multi-indices."""
return type(o)(
tuple(self.index_cache[i] if isinstance(i, Index) else i for i in o.indices())
)

def zero(self, o):
"""Apply to zero."""
fi = o.ufl_free_indices
fid = o.ufl_index_dimensions
mapped_fi = tuple(self.index(Index(count=i)) for i in fi)
paired_fid = [(mapped_fi[pos], fid[pos]) for pos, a in enumerate(fi)]
new_fi, new_fid = zip(*tuple(sorted(paired_fid)))
return Zero(o.ufl_shape, new_fi, new_fid)

def index(self, o):
"""Apply to index."""
if isinstance(o, FixedIndex):
new_indices = [self.index_cache[Index(i)].count() for i in fi]
if fi == () and fid == ():
return o
else:
c = o._count
i = self.index_map.get(c)
if i is None:
i = Index(count=len(self.index_map))
self.index_map[c] = i
return i
new_fi, new_fid = zip(*sorted(zip(new_indices, fid), key=lambda x: x[0]))
return type(o)(o.ufl_shape, tuple(new_fi), tuple(new_fid))

def multi_index(self, o):
"""Apply to multi_index."""
new_indices = tuple(self.index(i) for i in o.indices())
return MultiIndex(new_indices)

def renumber_indices(form):
"""Renumber indices to have a consistent index numbering starting from 0.

def renumber_indices(expr):
"""Renumber indices."""
if isinstance(expr, Expr):
num_free_indices = len(expr.ufl_free_indices)
This is useful to avoid multiple kernels for the same integrand,
but with different subdomain ids.

result = apply_transformer(expr, IndexRenumberingTransformer())
Args:
form: A UFL form, integral or expression.

if isinstance(expr, Expr):
if num_free_indices != len(result.ufl_free_indices):
raise ValueError(
"The number of free indices left in expression "
"should be invariant w.r.t. renumbering."
)
return result
Returns:
A new form, integral or expression with renumbered indices.
"""
if isinstance(form, Form):
new_integrals = [renumber_indices(itg) for itg in form.integrals()]
return Form(new_integrals)
elif isinstance(form, Integral):
integral = form
reindexer = IndexRelabeller()
new_integrand = map_expr_dag(reindexer, integral.integrand())
return integral.reconstruct(new_integrand)
elif isinstance(form, Expr):
expr = form
reindexer = IndexRelabeller()
return map_expr_dag(reindexer, expr)
else:
jorgensd marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError(f"Invalid form type {form.__class__name}")
Loading