Skip to content

Commit

Permalink
Add documentation and motivation
Browse files Browse the repository at this point in the history
  • Loading branch information
jorgensd committed Sep 27, 2024
1 parent 35ac5d2 commit 770e94e
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 66 deletions.
24 changes: 23 additions & 1 deletion test/test_indices.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,17 @@
exp,
i,
indices,
interval,
j,
k,
l,
outer,
sin,
triangle,
)
import ufl.algorithms
import ufl.classes
from ufl import indices
from ufl.classes import IndexSum
from ufl.finiteelement import FiniteElement
from ufl.pullback import identity_pullback
Expand Down Expand Up @@ -305,4 +309,22 @@ def test_spatial_derivative(self):


def test_renumbering(self):
pass
"""Test that kernels with common integral data, but different index numbering, are correctly renumbered."""
cell = interval
mesh = Mesh(FiniteElement("Lagrange", cell, 1, (2, ), identity_pullback, H1))
V = FunctionSpace(mesh, FiniteElement("Lagrange", cell, 1, (2, ), identity_pullback, H1))
v = TestFunction(V)
u = TrialFunction(V)
i = indices(1)
a0 = u[i].dx(0)*v[i].dx(0) * ufl.dx((1))
a1 = u[i].dx(0)*v[i].dx(0) * ufl.dx((2,3,))
form_data = ufl.algorithms.compute_form_data(
a0+a1,
do_apply_function_pullbacks=True,
do_apply_integral_scaling=True,
do_apply_geometry_lowering=True,
preserve_geometry_types=(ufl.classes.Jacobian,),
do_apply_restrictions=True,
do_append_everywhere_integrals=False)

assert len(form_data.integral_data) == 1
107 changes: 42 additions & 65 deletions ufl/algorithms/renumbering.py
Original file line number Diff line number Diff line change
@@ -1,89 +1,66 @@
"""Algorithms for renumbering of counted objects, currently variables and indices."""
# Copyright (C) 2008-2016 Martin Sandve Alnæs and Anders Logg
# Copyright (C) 2008-2024 Martin Sandve Alnæs, Anders Logg, Jø¶gen S. Dokken and Lawerence Mitchell
#
# This file is part of UFL (https://www.fenicsproject.org)
#
# SPDX-License-Identifier: LGPL-3.0-or-later

from ufl.algorithms.transformer import ReuseTransformer, apply_transformer
from ufl.classes import Zero
from ufl.classes import Form, Integral
from ufl.core.expr import Expr
from ufl.core.multiindex import FixedIndex, Index, MultiIndex
from ufl.variable import Label, Variable
from ufl.corealg.map_dag import map_expr_dag

from ufl.core.multiindex import Index
from itertools import count as _count
from ufl.corealg.multifunction import MultiFunction
from collections import defaultdict

class VariableRenumberingTransformer(ReuseTransformer):
"""Variable renumbering transformer."""

class IndexRelabeller(MultiFunction):
"""Renumber indices to have a consistent index numbering starting from 0."""
def __init__(self):
"""Initialise."""
ReuseTransformer.__init__(self)
self.variable_map = {}

def variable(self, o):
"""Apply to variable."""
e, l = o.ufl_operands # noqa: E741
v = self.variable_map.get(l)
if v is None:
e = self.visit(e)
l2 = Label(len(self.variable_map))
v = Variable(e, l2)
self.variable_map[l] = v
return v


class IndexRenumberingTransformer(VariableRenumberingTransformer):
"""Index renumbering transformer.
This is a poorly designed algorithm. It is used in some tests,
please do not use for anything else.
"""
super().__init__()
count = _count()
self.index_cache = defaultdict(lambda: Index(next(count)))

def __init__(self):
"""Initialise."""
VariableRenumberingTransformer.__init__(self)
self.index_map = {}
expr = MultiFunction.reuse_if_untouched

def multi_index(self, o):
"""Apply to multi-indices."""
return type(o)(tuple(self.index_cache[i] if isinstance(i, Index) else i for i in o.indices()))

def zero(self, o):
"""Apply to zero."""
fi = o.ufl_free_indices
fid = o.ufl_index_dimensions
new_indices = [self.index_cache[Index(i)].count() for i in fi]
if fi == () and fid == ():
return o
mapped_fi = tuple(self.index(Index(count=i)) for i in fi)
paired_fid = [(mapped_fi[pos], fid[pos]) for pos, a in enumerate(fi)]
new_fi, new_fid = zip(*tuple(sorted(paired_fid)))
return Zero(o.ufl_shape, new_fi, new_fid)
new_fi, new_fid = zip(*sorted(zip(new_indices, fid), key=lambda x: x[0]))
return type(o)(o.ufl_shape, tuple(new_fi), tuple(new_fid))

def index(self, o):
"""Apply to index."""
if isinstance(o, FixedIndex):
return o
else:
c = o._count
i = self.index_map.get(c)
if i is None:
i = Index(count=len(self.index_map))
self.index_map[c] = i
return i

def multi_index(self, o):
"""Apply to multi_index."""
new_indices = tuple(self.index(i) for i in o.indices())
return MultiIndex(new_indices)


def renumber_indices(expr):
"""Renumber indices."""
if isinstance(expr, Expr):
num_free_indices = len(expr.ufl_free_indices)
def renumber_indices(form):
"""Renumber indices to have a consistent index numbering starting from 0.
result = apply_transformer(expr, IndexRenumberingTransformer())
This is useful to avoid multiple kernels for the same integrand, but with different subdomain ids.
if isinstance(expr, Expr):
if num_free_indices != len(result.ufl_free_indices):
raise ValueError(
"The number of free indices left in expression "
"should be invariant w.r.t. renumbering."
)
return result
Args:
form: A UFL form, integral or expression.
Returns:
A new form, integral or expression with renumbered indices.
"""
if isinstance(form, Form):
new_integrals = [renumber_indices(itg) for itg in form.integrals()]
return Form(new_integrals)
elif isinstance(form, Integral):
integral = form
reindexer = IndexRelabeller()
new_integrand = map_expr_dag(reindexer, integral.integrand())
return integral.reconstruct(new_integrand)
elif isinstance(form, Expr):
expr = form
reindexer = IndexRelabeller()
return map_expr_dag(reindexer, expr)
else:
raise ValueError(f"Invalid form type {form.__class__name}")

0 comments on commit 770e94e

Please sign in to comment.