Skip to content

Commit

Permalink
sparsity: make allocation more flexible and save memory
Browse files Browse the repository at this point in the history
  • Loading branch information
ksagiyam committed Dec 20, 2023
1 parent 41fc742 commit ee777cc
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 20 deletions.
61 changes: 47 additions & 14 deletions firedrake/assemble.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import abc
from collections import OrderedDict
from collections import OrderedDict, defaultdict
import functools
import itertools
from itertools import product
Expand Down Expand Up @@ -160,7 +160,8 @@ def assemble_base_form(expression, tensor=None, bcs=None,
This function assembles a :class:`~ufl.classes.BaseForm` object by traversing the corresponding DAG
in a post-order fashion and evaluating the nodes on the fly.
"""

if bcs is not None and any(not isinstance(bc, (DirichletBC, EquationBCSplit)) for bc in bcs):
raise TypeError("All bcs must be DirichletBC or EquationBCSplit for direct assembly.")
# Preprocess the DAG and restructure the DAG
if not is_base_form_preprocessed and not isinstance(expression, slate.TensorBase):
# Preprocessing the form makes a new object -> current form caching mechanism
Expand Down Expand Up @@ -425,10 +426,31 @@ def allocate_matrix(expr, bcs=None, *, mat_type=None, sub_mat_type=None,
for get_map, regions in domains.items()
if regions))
try:
# Let tensor=None and bcs=None for the purpose of allocation.
parloop_builders = TwoFormAssembler(expr, None, bcs=None, form_compiler_parameters=form_compiler_parameters).parloop_builders_for_sparsity_construction
maps_and_regions = defaultdict(lambda: defaultdict(set))
if any(parloop_builder._indices == (None, None) for parloop_builder in parloop_builders):
# Handle special cases: slate or split=False
assert all(parloop_builder._indices == (None, None) for parloop_builder in parloop_builders)
for parloop_builder in parloop_builders:
_integral_type = parloop_builder._integral_type
get_map, region = mapping[_integral_type]
for i, _rmap in enumerate(get_map(test)):
for j, _cmap in enumerate(get_map(trial)):
maps_and_regions[(i, j)][(_rmap, _cmap)].update((region, ))
else:
for parloop_builder in parloop_builders:
i, j = parloop_builder._indices
_integral_type = parloop_builder._integral_type
# Make Sparsity independent of _iterset for better reusability.
get_map, region = mapping[_integral_type]
_rmap = get_map(test).split[i] if get_map(test) is not None else None
_cmap = get_map(trial).split[j] if get_map(trial) is not None else None
maps_and_regions[(i, j)][(_rmap, _cmap)].update((region, ))
maps_and_regions = {key: [k + (tuple(v), ) for k, v in val.items()] for key, val in maps_and_regions.items()}
sparsity = op2.Sparsity((test.function_space().dof_dset,
trial.function_space().dof_dset),
tuple(map_pairs),
iteration_regions=tuple(iteration_regions),
maps_and_regions,
nest=nest,
block_sparse=baij)
except SparsityFormatError:
Expand Down Expand Up @@ -644,8 +666,6 @@ class FormAssembler(abc.ABC):
"""

def __init__(self, form, tensor, bcs=(), form_compiler_parameters=None, needs_zeroing=True, weight=1.0):
assert tensor is not None

bcs = solving._extract_bcs(bcs)

self._form = form
Expand Down Expand Up @@ -682,8 +702,6 @@ def assemble(self):
self.execute_parloops()

for bc in self._bcs:
if isinstance(bc, EquationBC): # can this be lifted?
bc = bc.extract_form("F")
self._apply_bc(bc)

return self.result
Expand Down Expand Up @@ -755,12 +773,12 @@ def global_kernels(self):
)

@cached_property
def parloops(self):
loops = []
def parloop_builders(self):
out = []
for (local_kernel, subdomain_id), global_kernel in zip(
self.local_kernels, self.global_kernels
):
loops.append(
out.append(
ParloopBuilder(
self._form,
local_kernel,
Expand All @@ -770,9 +788,13 @@ def parloops(self):
self.all_integer_subdomain_ids,
diagonal=self.diagonal,
lgmaps=self.collect_lgmaps(local_kernel, self._bcs)
).build()
)
)
return tuple(loops)
return tuple(out)

@cached_property
def parloops(self):
    """Return the parloops produced by the parloop builders.

    Calls ``build()`` on every entry of ``self.parloop_builders``, in
    order, and returns the results as a tuple.
    """
    built = [builder.build() for builder in self.parloop_builders]
    return tuple(built)

def needs_unrolling(self, local_knl, bcs):
"""Do we need to address matrix elements directly rather than in
Expand Down Expand Up @@ -903,7 +925,9 @@ def trial_function_space(self):

def get_indicess(self, knl):
if all(i is None for i in knl.indices):
return numpy.ndindex(self._tensor.block_shape)
test, trial = self._form.arguments()
return numpy.ndindex((len(test.function_space()),
len(trial.function_space())))
else:
assert all(i is not None for i in knl.indices)
return knl.indices,
Expand Down Expand Up @@ -995,6 +1019,15 @@ def _apply_bcs_mat_real_block(op2tensor, i, j, component, node_set):
dat = op2.DatView(dat, component)
dat.zero(subset=node_set)

@cached_property
def parloop_builders_for_sparsity_construction(self):
    """Collect the parloop builders used when constructing the sparsity.

    Starts from ``self.parloop_builders`` and, for every boundary
    condition in ``self._bcs`` that is an ``EquationBCSplit``, appends the
    builders of a fresh assembler of the same type created for ``bc.f``
    with ``bc.bcs``. The lookup is recursive, so nested
    ``EquationBCSplit`` conditions contribute their builders as well.
    """
    builders = self.parloop_builders
    equation_bcs = (bc for bc in self._bcs if isinstance(bc, EquationBCSplit))
    for equation_bc in equation_bcs:
        sub_assembler = type(self)(
            equation_bc.f,
            self._tensor,
            equation_bc.bcs,
            self._form_compiler_params,
            needs_zeroing=False,
        )
        # ``parloop_builders`` yields a tuple, so ``+=`` rebinds the local
        # name instead of mutating the cached value.
        builders += sub_assembler.parloop_builders_for_sparsity_construction
    return builders


class MatrixFreeAssembler:
"""Stub class wrapping matrix-free assembly."""
Expand Down
2 changes: 1 addition & 1 deletion firedrake/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,7 +758,7 @@ def make_interpolator(expr, V, subset, access, bcs=None):
tensor = None
else:
sparsity = op2.Sparsity((V.dof_dset, argfs.dof_dset),
((V.cell_node_map(), argfs_map),),
[(V.cell_node_map(), argfs_map, None)], # non-mixed
name="%s_%s_sparsity" % (V.name, argfs.name),
nest=False,
block_sparse=True)
Expand Down
6 changes: 4 additions & 2 deletions firedrake/preconditioners/pmg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1528,8 +1528,10 @@ def prolongation_matrix_aij(P1, Pk, P1_bcs=[], Pk_bcs=[]):
Pk = Pk.function_space()
sp = op2.Sparsity((Pk.dof_dset,
P1.dof_dset),
(Pk.cell_node_map(),
P1.cell_node_map()))
{(i, j): [(rmap, cmap, None)]
for i, rmap in enumerate(Pk.cell_node_map())
for j, cmap in enumerate(P1.cell_node_map())
if i == j})
mat = op2.Mat(sp, PETSc.ScalarType)
mesh = Pk.mesh()

Expand Down
2 changes: 0 additions & 2 deletions tests/equation_bcs/test_equation_bcs_assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@ def test_equation_bcs_direct_assemble_one_form():

g = assemble(F, bcs=bc.extract_form('F'))
assert np.allclose(g.dat.data, [0.5, 0.5, 0, 0])
g = assemble(F, bcs=bc)
assert np.allclose(g.dat.data, [0.5, 0.5, 0, 0])


def test_equation_bcs_direct_assemble_two_form():
Expand Down
15 changes: 14 additions & 1 deletion tests/regression/test_assemble.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
import numpy as np
from firedrake import *
from firedrake.utils import ScalarType
from firedrake.utils import ScalarType, IntType


@pytest.fixture(scope='module')
Expand Down Expand Up @@ -291,3 +291,16 @@ def test_assemble_vector_rspace_one_form(mesh):
U = inner(u, u)*dx
L = derivative(U, u)
assemble(L)


def test_assemble_sparsity():
    """Check that off-diagonal blocks of a nested mass matrix allocate nothing.

    Assembles ``inner(u, v) * dx`` on a three-component mixed space with
    ``mat_type="nest"`` and asserts that every block ``(i, j)`` with
    ``i != j`` reports an all-zero ``nnz`` array.
    """
    mesh = UnitSquareMesh(2, 2, quadrilateral=True)
    V = FunctionSpace(mesh, "CG", 1)
    W = V * V * V
    u = TrialFunction(W)
    v = TestFunction(W)
    A = assemble(inner(u, v) * dx, mat_type="nest")
    nblocks = len(W)
    # Length 9 is hard-coded; presumably the CG1 node count of the 2x2
    # quadrilateral mesh -- TODO confirm.
    empty_nnz = np.zeros(9, dtype=IntType)
    off_diagonal = [(i, j)
                    for i in range(nblocks)
                    for j in range(nblocks)
                    if i != j]
    for i, j in off_diagonal:
        assert np.all(A.M.sparsity[i][j].nnz == empty_nnz)

0 comments on commit ee777cc

Please sign in to comment.