Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove switch/case from generated code #591

Merged
merged 19 commits into from
Aug 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
304 changes: 142 additions & 162 deletions ffcx/codegeneration/C/cnodes.py

Large diffs are not rendered by default.

67 changes: 0 additions & 67 deletions ffcx/codegeneration/dofmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,54 +8,12 @@
# old implementation in FFC

import logging
import typing

import ffcx.codegeneration.dofmap_template as ufcx_dofmap

logger = logging.getLogger("ffcx")


def tabulate_entity_dofs(
L,
entity_dofs: typing.List[typing.List[typing.List[int]]],
num_dofs_per_entity: typing.List[int],
):
# Output argument array
dofs = L.Symbol("dofs")

# Input arguments
d = L.Symbol("d")
i = L.Symbol("i")

# TODO: Removed check for (d <= tdim + 1)
tdim = len(num_dofs_per_entity) - 1

# Generate cases for each dimension:
all_cases = []
for dim in range(tdim + 1):
# Ignore if no entities for this dimension
if num_dofs_per_entity[dim] == 0:
continue

# Generate cases for each mesh entity
cases = []
for entity in range(len(entity_dofs[dim])):
casebody = []
for j, dof in enumerate(entity_dofs[dim][entity]):
casebody += [L.Assign(dofs[j], dof)]
cases.append((entity, L.StatementList(casebody)))

# Generate inner switch
# TODO: Removed check for (i <= num_entities-1)
inner_switch = L.Switch(i, cases, autoscope=False)
all_cases.append((dim, inner_switch))

if all_cases:
return L.Switch(d, all_cases, autoscope=False)
else:
return L.NoOp()


def generator(ir, options):
"""Generate UFC code for a dofmap."""
logger.info("Generating code for dofmap:")
Expand All @@ -73,23 +31,6 @@ def generator(ir, options):

import ffcx.codegeneration.C.cnodes as L

num_entity_dofs = ir.num_entity_dofs + [0, 0, 0, 0]
num_entity_dofs = num_entity_dofs[:4]
d["num_entity_dofs"] = f"num_entity_dofs_{ir.name}"
d["num_entity_dofs_init"] = L.ArrayDecl(
"int", f"num_entity_dofs_{ir.name}", values=num_entity_dofs, sizes=4
)

num_entity_closure_dofs = ir.num_entity_closure_dofs + [0, 0, 0, 0]
num_entity_closure_dofs = num_entity_closure_dofs[:4]
d["num_entity_closure_dofs"] = f"num_entity_closure_dofs_{ir.name}"
d["num_entity_closure_dofs_init"] = L.ArrayDecl(
"int",
f"num_entity_closure_dofs_{ir.name}",
values=num_entity_closure_dofs,
sizes=4,
)

flattened_entity_dofs = []
entity_dof_offsets = [0]
for dim in ir.entity_dofs:
Expand Down Expand Up @@ -137,14 +78,6 @@ def generator(ir, options):

d["block_size"] = ir.block_size

# Functions
d["tabulate_entity_dofs"] = tabulate_entity_dofs(
L, ir.entity_dofs, ir.num_entity_dofs
)
d["tabulate_entity_closure_dofs"] = tabulate_entity_dofs(
L, ir.entity_closure_dofs, ir.num_entity_closure_dofs
)

if len(ir.sub_dofmaps) > 0:
d["sub_dofmaps_initialization"] = L.ArrayDecl(
"ufcx_dofmap*",
Expand Down
18 changes: 0 additions & 18 deletions ffcx/codegeneration/dofmap_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,6 @@

{sub_dofmaps_initialization}

void tabulate_entity_dofs_{factory_name}(int* restrict dofs, int d, int i)
{{
{tabulate_entity_dofs}
}}

void tabulate_entity_closure_dofs_{factory_name}(int* restrict dofs, int d, int i)
{{
{tabulate_entity_closure_dofs}
}}

{num_entity_dofs_init}

{num_entity_closure_dofs_init}

{entity_dofs_init}

{entity_dof_offsets_init}
Expand All @@ -44,10 +30,6 @@
.entity_dof_offsets = {entity_dof_offsets},
.entity_closure_dofs = {entity_closure_dofs},
.entity_closure_dof_offsets = {entity_closure_dof_offsets},
.num_entity_dofs = {num_entity_dofs},
.tabulate_entity_dofs = tabulate_entity_dofs_{factory_name},
.num_entity_closure_dofs = {num_entity_closure_dofs},
.tabulate_entity_closure_dofs = tabulate_entity_closure_dofs_{factory_name},
.num_sub_dofmaps = {num_sub_dofmaps},
.sub_dofmaps = {sub_dofmaps}
}};
Expand Down
53 changes: 1 addition & 52 deletions ffcx/codegeneration/form.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,6 @@ def generator(ir, options):
d["num_coefficients"] = ir.num_coefficients
d["num_constants"] = ir.num_constants

code = []
cases = []
for itg_type in ("cell", "interior_facet", "exterior_facet"):
cases += [(L.Symbol(itg_type), L.Return(len(ir.subdomain_ids[itg_type])))]
code += [L.Switch("integral_type", cases, default=L.Return(0))]
d["num_integrals"] = L.StatementList(code)

if len(ir.original_coefficient_position) > 0:
d["original_coefficient_position_init"] = L.ArrayDecl(
"int",
Expand Down Expand Up @@ -95,6 +88,7 @@ def generator(ir, options):
integrals = []
integral_ids = []
integral_offsets = [0]
# Note: the order of this list is defined by the enum ufcx_integral_type in ufcx.h
for itg_type in ("cell", "exterior_facet", "interior_facet"):
integrals += [L.AddressOf(L.Symbol(itg)) for itg in ir.integral_names[itg_type]]
integral_ids += ir.subdomain_ids[itg_type]
Expand Down Expand Up @@ -128,51 +122,6 @@ def generator(ir, options):
sizes=len(integral_offsets),
)

code = []
cases = []
code_ids = []
cases_ids = []
for itg_type in ("cell", "interior_facet", "exterior_facet"):
if len(ir.integral_names[itg_type]) > 0:
code += [
L.ArrayDecl(
"static ufcx_integral*",
f"integrals_{itg_type}_{ir.name}",
values=[
L.AddressOf(L.Symbol(itg))
for itg in ir.integral_names[itg_type]
],
sizes=len(ir.integral_names[itg_type]),
)
]
cases.append(
(
L.Symbol(itg_type),
L.Return(L.Symbol(f"integrals_{itg_type}_{ir.name}")),
)
)

code_ids += [
L.ArrayDecl(
"static int",
f"integral_ids_{itg_type}_{ir.name}",
values=ir.subdomain_ids[itg_type],
sizes=len(ir.subdomain_ids[itg_type]),
)
]
cases_ids.append(
(
L.Symbol(itg_type),
L.Return(L.Symbol(f"integral_ids_{itg_type}_{ir.name}")),
)
)

code += [L.Switch("integral_type", cases, default=L.Return(L.Null()))]
code_ids += [L.Switch("integral_type", cases_ids, default=L.Return(L.Null()))]
d["integrals"] = L.StatementList(code)

d["integral_ids"] = L.StatementList(code_ids)

code = []
function_name = L.Symbol("function_name")

Expand Down
20 changes: 0 additions & 20 deletions ffcx/codegeneration/form_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,21 +40,6 @@
{constant_name_map}
}}

int* integral_ids_{factory_name}(ufcx_integral_type integral_type)
{{
{integral_ids}
}}

int num_integrals_{factory_name}(ufcx_integral_type integral_type)
{{
{num_integrals}
}}

ufcx_integral** integrals_{factory_name}(ufcx_integral_type integral_type)
{{
{integrals}
}}

ufcx_form {factory_name} =
{{

Expand All @@ -70,11 +55,6 @@
.finite_elements = {finite_elements},
.dofmaps = {dofmaps},

.integral_ids = integral_ids_{factory_name},
.num_integrals = num_integrals_{factory_name},

.integrals = integrals_{factory_name},

.form_integrals = {form_integrals},
.form_integral_ids = {form_integral_ids},
.form_integral_offsets = form_integral_offsets_{factory_name}
Expand Down
9 changes: 0 additions & 9 deletions ffcx/codegeneration/ufcx.h
Original file line number Diff line number Diff line change
Expand Up @@ -456,15 +456,6 @@ extern "C"
/// Coefficient number j=i-r if r+j <= i < r+n
ufcx_dofmap** dofmaps;

/// All ids for integrals
int* (*integral_ids)(ufcx_integral_type);

/// Number of integrals
int (*num_integrals)(ufcx_integral_type);

/// Get an integral on sub domain subdomain_id
ufcx_integral** (*integrals)(ufcx_integral_type);

/// List of cell, interior facet and exterior facet integrals
ufcx_integral** form_integrals;

Expand Down
21 changes: 13 additions & 8 deletions test/test_add_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,13 @@ def test_additive_facet_integral(mode, compile_args):
ffi = module.ffi
form0 = compiled_forms[0]

assert form0.num_integrals(module.lib.exterior_facet) == 1
ids = form0.integral_ids(module.lib.exterior_facet)
assert ids[0] == -1
integral_offsets = form0.form_integral_offsets
ex = module.lib.exterior_facet
assert integral_offsets[ex + 1] - integral_offsets[ex] == 1
integral_id = form0.form_integral_ids[integral_offsets[ex]]
assert integral_id == -1

default_integral = form0.integrals(module.lib.exterior_facet)[0]
default_integral = form0.form_integrals[integral_offsets[ex]]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This here is implicitly tied to the ordering of the enum in DOLFINx, which I'm not sure is a good idea. See: #589 (comment)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The order is explicitly defined in ufcx.h

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is still an issue for whenever we generate the offsets (as that is based on strings in a for-loop, not the enum).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have annotated the point where this happens in ffcx.


np_type = cdtype_to_numpy(mode)
A = np.zeros((3, 3), dtype=np_type)
Expand Down Expand Up @@ -83,11 +85,14 @@ def test_additive_cell_integral(mode, compile_args):
ffi = module.ffi
form0 = compiled_forms[0]

assert form0.num_integrals(module.lib.cell) == 1
ids = form0.integral_ids(module.lib.cell)
assert ids[0] == -1
cell = module.lib.cell
offsets = form0.form_integral_offsets
num_integrals = offsets[cell + 1] - offsets[cell]
assert num_integrals == 1
integral_id = form0.form_integral_ids[offsets[cell]]
assert integral_id == -1

default_integral = form0.integrals(0)[0]
default_integral = form0.form_integrals[offsets[cell]]

np_type = cdtype_to_numpy(mode)
A = np.zeros((3, 3), dtype=np_type)
Expand Down
50 changes: 17 additions & 33 deletions test/test_blocked_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,11 @@ def test_finite_element(compile_args):
assert ufcx_dofmap.num_global_support_dofs == 0
assert ufcx_dofmap.num_global_support_dofs == 0
assert ufcx_dofmap.num_element_support_dofs == 3
assert ufcx_dofmap.num_entity_dofs[0] == 1
assert ufcx_dofmap.num_entity_dofs[1] == 0
assert ufcx_dofmap.num_entity_dofs[2] == 0
assert ufcx_dofmap.num_entity_dofs[3] == 0
off = np.array([ufcx_dofmap.entity_dof_offsets[i] for i in range(8)])
assert np.all(np.diff(off) == [1, 1, 1, 0, 0, 0, 0])

for v in range(3):
vals = np.zeros(1, dtype=np.int32)
vals_ptr = module.ffi.cast("int *", module.ffi.from_buffer(vals))
ufcx_dofmap.tabulate_entity_dofs(vals_ptr, 0, v)
assert vals[0] == v
assert ufcx_dofmap.entity_dofs[v] == v
assert ufcx_dofmap.num_sub_dofmaps == 0


Expand All @@ -66,15 +62,11 @@ def test_vector_element(compile_args):
assert ufcx_dofmap.num_global_support_dofs == 0
assert ufcx_dofmap.num_global_support_dofs == 0
assert ufcx_dofmap.num_element_support_dofs == 3
assert ufcx_dofmap.num_entity_dofs[0] == 1
assert ufcx_dofmap.num_entity_dofs[1] == 0
assert ufcx_dofmap.num_entity_dofs[2] == 0
assert ufcx_dofmap.num_entity_dofs[3] == 0
off = np.array([ufcx_dofmap.entity_dof_offsets[i] for i in range(8)])
assert np.all(np.diff(off) == [1, 1, 1, 0, 0, 0, 0])

for v in range(3):
vals = np.zeros(1, dtype=np.int32)
vals_ptr = module.ffi.cast("int *", module.ffi.from_buffer(vals))
ufcx_dofmap.tabulate_entity_dofs(vals_ptr, 0, v)
assert vals[0] == v
assert ufcx_dofmap.entity_dofs[v] == v
assert ufcx_dofmap.num_sub_dofmaps == 2


Expand Down Expand Up @@ -102,15 +94,11 @@ def test_tensor_element(compile_args):
assert ufcx_dofmap.num_global_support_dofs == 0
assert ufcx_dofmap.num_global_support_dofs == 0
assert ufcx_dofmap.num_element_support_dofs == 3
assert ufcx_dofmap.num_entity_dofs[0] == 1
assert ufcx_dofmap.num_entity_dofs[1] == 0
assert ufcx_dofmap.num_entity_dofs[2] == 0
assert ufcx_dofmap.num_entity_dofs[3] == 0
off = np.array([ufcx_dofmap.entity_dof_offsets[i] for i in range(8)])
assert np.all(np.diff(off) == [1, 1, 1, 0, 0, 0, 0])

for v in range(3):
vals = np.zeros(1, dtype=np.int32)
vals_ptr = module.ffi.cast("int *", module.ffi.from_buffer(vals))
ufcx_dofmap.tabulate_entity_dofs(vals_ptr, 0, v)
assert vals[0] == v
assert ufcx_dofmap.entity_dofs[v] == v
assert ufcx_dofmap.num_sub_dofmaps == 4


Expand All @@ -136,14 +124,10 @@ def test_vector_quadrature_element(compile_args):
assert ufcx_dofmap.num_global_support_dofs == 0
assert ufcx_dofmap.num_global_support_dofs == 0
assert ufcx_dofmap.num_element_support_dofs == 4
assert ufcx_dofmap.num_entity_dofs[0] == 0
assert ufcx_dofmap.num_entity_dofs[1] == 0
assert ufcx_dofmap.num_entity_dofs[2] == 0
assert ufcx_dofmap.num_entity_dofs[3] == 4

vals = np.zeros(4, dtype=np.int32)
vals_ptr = module.ffi.cast("int *", module.ffi.from_buffer(vals))
ufcx_dofmap.tabulate_entity_dofs(vals_ptr, 3, 0)
assert (vals == [0, 1, 2, 3]).all()
off = np.array([ufcx_dofmap.entity_dof_offsets[i] for i in range(16)])
assert np.all(np.diff(off) == [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4])

for i in range(4):
assert ufcx_dofmap.entity_dofs[i] == i

assert ufcx_dofmap.num_sub_dofmaps == 3
Loading