Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DO NOT SQUASH: Specify sparsity block-wise #3288

Merged
merged 17 commits into from
Mar 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,9 @@ The source is injected at the center of the unit square::
ricker = Constant(0.0)
ricker.assign(RickerWavelet(t, freq))

We also create a function `R` to save the assembled RHS vector::
We also create a cofunction `R` to save the assembled RHS vector::

R = Function(V)
R = Cofunction(V.dual())

Finally, we define the whole variational form :math:`F`, assemble it, and then create a cached PETSc `LinearSolver` object to efficiently timestep with::

Expand Down
4 changes: 2 additions & 2 deletions firedrake/adjoint_utils/assembly.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def wrapper(form, *args, **kwargs):
ad_block_tag = kwargs.pop("ad_block_tag", None)
annotate = annotate_tape(kwargs)
with stop_annotating():
from firedrake.assemble import preprocess_base_form
from firedrake.assemble import BaseFormAssembler
from firedrake.slate import slate
if not isinstance(form, slate.TensorBase):
# Preprocess the form at the annotation stage so that the `AssembleBlock`
Expand All @@ -25,7 +25,7 @@ def wrapper(form, *args, **kwargs):
# -> `interp = Action(Interpolate(v1, v0), f)` with `v1` and `v0` being respectively `Argument`
# and `Coargument`. Differentiating `interp` is not currently supported as the action's left slot
# is a 2-form. However, after preprocessing, we obtain `Interpolate(f, v0)`, which can be differentiated.
form = preprocess_base_form(form)
form = BaseFormAssembler.preprocess_base_form(form)
kwargs['is_base_form_preprocessed'] = True
output = assemble(form, *args, **kwargs)

Expand Down
2,409 changes: 1,301 additions & 1,108 deletions firedrake/assemble.py

Large diffs are not rendered by default.

5 changes: 2 additions & 3 deletions firedrake/external_operators/abstract_external_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from ufl.argument import BaseArgument

import firedrake.ufl_expr as ufl_expr
from firedrake.assemble import allocate_matrix
from firedrake.assemble import get_assembler
from firedrake.function import Function
from firedrake.cofunction import Cofunction
from firedrake.matrix import MatrixBase
Expand Down Expand Up @@ -202,7 +202,6 @@ def _matrix_builder(self, bcs, opts, integral_types):

This helper function provides a way to allocate matrices that can then be populated
in the assembly method(s) of the external operator subclass.
This function relies on the :func:`firedrake.assemble.allocate_matrix` function.

Parameters
----------
Expand All @@ -222,7 +221,7 @@ def _matrix_builder(self, bcs, opts, integral_types):
# Remove `diagonal` keyword argument
opts.pop('diagonal', None)
# Allocate the matrix associated with `self`
return allocate_matrix(self, bcs=bcs, integral_types=integral_types, **opts)
return get_assembler(self, bcs=bcs, allocation_integral_types=integral_types, **opts).allocate()

def _ufl_expr_reconstruct_(self, *operands, function_space=None, derivatives=None,
argument_slots=None, operator_data=None, add_kwargs={}):
Expand Down
2 changes: 1 addition & 1 deletion firedrake/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -938,7 +938,7 @@ def make_interpolator(expr, V, subset, access, bcs=None):
tensor = None
else:
sparsity = op2.Sparsity((V.dof_dset, argfs.dof_dset),
((V.cell_node_map(), argfs_map),),
[(V.cell_node_map(), argfs_map, None)], # non-mixed
name="%s_%s_sparsity" % (V.name, argfs.name),
nest=False,
block_sparse=True)
Expand Down
6 changes: 3 additions & 3 deletions firedrake/linear_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,19 +112,19 @@ def trial_space(self):

@cached_property
def _rhs(self):
from firedrake.assemble import OneFormAssembler
from firedrake.assemble import get_assembler

u = function.Function(self.trial_space)
b = cofunction.Cofunction(self.test_space.dual())
expr = -action(self.A.a, u)
return u, OneFormAssembler(expr, tensor=b).assemble, b
return u, get_assembler(expr).assemble, b

def _lifted(self, b):
u, update, blift = self._rhs
u.dat.zero()
for bc in self.A.bcs:
bc.apply(u)
update()
update(tensor=blift)
# blift contains -A u_bc
blift += b
if isinstance(blift, cofunction.Cofunction):
Expand Down
35 changes: 17 additions & 18 deletions firedrake/matrix_free/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ class ImplicitMatrixContext(object):
@PETSc.Log.EventDecorator()
def __init__(self, a, row_bcs=[], col_bcs=[],
fc_params=None, appctx=None):
from firedrake.assemble import get_form_assembler
from firedrake.assemble import get_assembler

self.a = a
self.aT = adjoint(a)
Expand Down Expand Up @@ -144,10 +144,10 @@ def __init__(self, a, row_bcs=[], col_bcs=[],
elif isinstance(bc, EquationBCSplit):
self.bcs_action.append(bc.reconstruct(action_x=self._x))

self._assemble_action = get_form_assembler(self.action, tensor=self._ystar,
bcs=self.bcs_action,
form_compiler_parameters=self.fc_params,
zero_bc_nodes=True)
self._assemble_action = get_assembler(self.action,
bcs=self.bcs_action,
form_compiler_parameters=self.fc_params,
zero_bc_nodes=True).assemble

# For assembling action(adjoint(f), self._y)
# Sorted list of equation bcs
Expand All @@ -161,13 +161,12 @@ def __init__(self, a, row_bcs=[], col_bcs=[],
for bc in self.bcs:
for ebc in bc.sorted_equation_bcs():
self._assemble_actionT.append(
get_form_assembler(action(adjoint(ebc.f), self._y), tensor=self._xbc,
form_compiler_parameters=self.fc_params))
get_assembler(action(adjoint(ebc.f), self._y),
form_compiler_parameters=self.fc_params).assemble)
# Domain last
self._assemble_actionT.append(
get_form_assembler(self.actionT,
tensor=self._xstar if len(self.bcs) == 0 else self._xbc,
form_compiler_parameters=self.fc_params))
get_assembler(self.actionT,
form_compiler_parameters=self.fc_params).assemble)

@cached_property
def _diagonal(self):
Expand All @@ -177,13 +176,13 @@ def _diagonal(self):

@cached_property
def _assemble_diagonal(self):
from firedrake.assemble import get_form_assembler
return get_form_assembler(self.a, tensor=self._diagonal,
form_compiler_parameters=self.fc_params,
diagonal=True)
from firedrake.assemble import get_assembler
return get_assembler(self.a,
form_compiler_parameters=self.fc_params,
diagonal=True).assemble

def getDiagonal(self, mat, vec):
self._assemble_diagonal()
self._assemble_diagonal(tensor=self._diagonal)
diagonal_func = self._diagonal.riesz_representation(riesz_map="l2")
for bc in self.bcs:
# Operator is identity on boundary nodes
Expand Down Expand Up @@ -212,7 +211,7 @@ def mult(self, mat, X, Y):
# If we are not, then the matrix just has 0s in the rows and columns.
for bc in self.col_bcs:
bc.zero(self._x)
self._assemble_action()
self._assemble_action(tensor=self._ystar)
# This sets the essential boundary condition values on the
# result.
if self.on_diag:
Expand Down Expand Up @@ -307,14 +306,14 @@ def multTranspose(self, mat, Y, X):
# zero columns associated with DirichletBCs/EquationBCs
for obc in obj.bcs:
obc.zero(self._y)
aT()
aT(tensor=self._xbc)
self._xstar += self._xbc
else:
# No DirichletBC/EquationBC
# There is only a single element in the list (for the domain equation).
# Save to self._x directly
aT, = self._assemble_actionT
aT()
aT(tensor=self._xstar)

if self.on_diag:
if len(self.col_bcs) > 0:
Expand Down
15 changes: 6 additions & 9 deletions firedrake/preconditioners/assembled.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class AssembledPC(PCBase):
_prefix = "assembled_"

def initialize(self, pc):
from firedrake.assemble import allocate_matrix, TwoFormAssembler
from firedrake.assemble import get_assembler
A, P = pc.getOperators()

if pc.getType() != "python":
Expand Down Expand Up @@ -51,13 +51,10 @@ def initialize(self, pc):

(a, bcs) = self.form(pc, test, trial)

self.P = allocate_matrix(a, bcs=bcs,
form_compiler_parameters=fcp,
mat_type=mat_type,
options_prefix=options_prefix)
self._assemble_P = TwoFormAssembler(a, tensor=self.P, bcs=bcs,
form_compiler_parameters=fcp).assemble
self._assemble_P()
form_assembler = get_assembler(a, bcs=bcs, form_compiler_parameters=fcp, mat_type=mat_type, options_prefix=options_prefix)
self.P = form_assembler.allocate()
self._assemble_P = form_assembler.assemble
self._assemble_P(tensor=self.P)

# Transfer nullspace over
Pmat = self.P.petscmat
Expand Down Expand Up @@ -87,7 +84,7 @@ def initialize(self, pc):
pc.setFromOptions()

def update(self, pc):
self._assemble_P()
self._assemble_P(tensor=self.P)

def form(self, pc, test, trial):
_, P = pc.getOperators()
Expand Down
17 changes: 6 additions & 11 deletions firedrake/preconditioners/facet_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def get_permutation(self, V, W):
def initialize(self, pc):
from finat.ufl import RestrictedElement, MixedElement, TensorElement, VectorElement
from firedrake import FunctionSpace, TestFunctions, TrialFunctions
from firedrake.assemble import allocate_matrix, TwoFormAssembler
from firedrake.assemble import get_assembler

_, P = pc.getOperators()
appctx = self.get_appctx(pc)
Expand Down Expand Up @@ -88,15 +88,10 @@ def restrict(ele, restriction_domain):
self.iperm = self.perm.invertPermutation()

if mat_type != "submatrix":
self.mixed_op = allocate_matrix(mixed_operator,
bcs=mixed_bcs,
form_compiler_parameters=fcp,
mat_type=mat_type,
options_prefix=options_prefix)
self._assemble_mixed_op = TwoFormAssembler(mixed_operator, tensor=self.mixed_op,
form_compiler_parameters=fcp,
bcs=mixed_bcs).assemble
self._assemble_mixed_op()
form_assembler = get_assembler(mixed_operator, bcs=mixed_bcs, form_compiler_parameters=fcp, mat_type=mat_type, options_prefix=options_prefix)
self.mixed_op = form_assembler.allocate()
self._assemble_mixed_op = form_assembler.assemble
self._assemble_mixed_op(tensor=self.mixed_op)
mixed_opmat = self.mixed_op.petscmat

def _permute_nullspace(nsp):
Expand Down Expand Up @@ -147,7 +142,7 @@ def _permute_nullspace(nsp):

def update(self, pc):
if hasattr(self, "mixed_op"):
self._assemble_mixed_op()
self._assemble_mixed_op(tensor=self.mixed_op)
elif hasattr(self, "_permute_op"):
for mat in self.pc.getOperators():
mat.destroy()
Expand Down
38 changes: 15 additions & 23 deletions firedrake/preconditioners/fdm.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,12 +149,11 @@ def initialize(self, pc):
self.work_vec_x = Amat.createVecLeft()
self.work_vec_y = Amat.createVecRight()
if use_amat:
from firedrake.assemble import allocate_matrix, TwoFormAssembler
self.A = allocate_matrix(J_fdm, bcs=bcs_fdm, form_compiler_parameters=fcp,
mat_type=mat_type, options_prefix=options_prefix)
self._assemble_A = TwoFormAssembler(J_fdm, tensor=self.A, bcs=bcs_fdm,
form_compiler_parameters=fcp).assemble
self._assemble_A()
from firedrake.assemble import get_assembler
form_assembler = get_assembler(J_fdm, bcs=bcs_fdm, form_compiler_parameters=fcp, mat_type=mat_type, options_prefix=options_prefix)
self.A = form_assembler.allocate()
self._assemble_A = form_assembler.assemble
self._assemble_A(tensor=self.A)
Amat = self.A.petscmat

if len(bcs) > 0:
Expand Down Expand Up @@ -333,7 +332,7 @@ def _assemble_P(self):
@PETSc.Log.EventDecorator("FDMUpdate")
def update(self, pc):
if hasattr(self, "A"):
self._assemble_A()
self._assemble_A(tensor=self.A)
self._assemble_P()

def apply(self, pc, x, y):
Expand Down Expand Up @@ -539,17 +538,15 @@ def assemble_coefficients(self, J, fcp, block_diagonal=False):
W = MixedFunctionSpace([c.function_space() for c in bdiags])
tensor = Function(W, val=op2.MixedDat([c.dat for c in bdiags]))
else:
from firedrake.assemble import OneFormAssembler
from firedrake.assemble import get_assembler
tensor = Function(Z.dual())
assembly_callables.append(OneFormAssembler(mixed_form, tensor=tensor, diagonal=True,
form_compiler_parameters=fcp).assemble)
assembly_callables.append(partial(get_assembler(mixed_form, form_compiler_parameters=fcp, diagonal=True).assemble, tensor=tensor))
coefficients = {"cell": tensor}
facet_integrals = [i for i in J.integrals() if "facet" in i.integral_type()]
J_facet = expand_indices(expand_derivatives(ufl.Form(facet_integrals)))
if len(J_facet.integrals()) > 0:
gamma = coefficients.setdefault("facet", Function(V.dual()))
assembly_callables.append(OneFormAssembler(J_facet, tensor=gamma, diagonal=True,
form_compiler_parameters=fcp).assemble)
assembly_callables.append(partial(get_assembler(J_facet, form_compiler_parameters=fcp, tensor=gamma, diagonal=True).assemble, tensor=gamma))
return coefficients, assembly_callables

@PETSc.Log.EventDecorator("FDMRefTensor")
Expand Down Expand Up @@ -2071,7 +2068,7 @@ def condense(self, A, J, bcs, fcp):

@PETSc.Log.EventDecorator("FDMCoefficients")
def assemble_coefficients(self, J, fcp):
from firedrake.assemble import OneFormAssembler
from firedrake.assemble import get_assembler
coefficients = {}
assembly_callables = []

Expand Down Expand Up @@ -2112,8 +2109,7 @@ def assemble_coefficients(self, J, fcp):
if not isinstance(alpha, ufl.constantvalue.Zero):
Q = FunctionSpace(mesh, finat.ufl.TensorElement(DG, shape=alpha.ufl_shape))
tensor = coefficients.setdefault("alpha", Function(Q.dual()))
assembly_callables.append(OneFormAssembler(ufl.inner(TestFunction(Q), alpha)*dx, tensor=tensor,
form_compiler_parameters=fcp).assemble)
assembly_callables.append(partial(get_assembler(ufl.inner(TestFunction(Q), alpha)*dx, form_compiler_parameters=fcp).assemble, tensor=tensor))

# get zero-th order coefficent
ref_val = [ufl.variable(t) for t in args_J]
Expand All @@ -2134,8 +2130,7 @@ def assemble_coefficients(self, J, fcp):
beta = ufl.diag_vector(beta)
Q = FunctionSpace(mesh, finat.ufl.TensorElement(DG, shape=beta.ufl_shape) if beta.ufl_shape else DG)
tensor = coefficients.setdefault("beta", Function(Q.dual()))
assembly_callables.append(OneFormAssembler(ufl.inner(TestFunction(Q), beta)*dx, tensor=tensor,
form_compiler_parameters=fcp).assemble)
assembly_callables.append(partial(get_assembler(ufl.inner(TestFunction(Q), beta)*dx, form_compiler_parameters=fcp).assemble, tensor=tensor))

family = "CG" if tdim == 1 else "DGT"
degree = 1 if tdim == 1 else 0
Expand All @@ -2157,13 +2152,11 @@ def assemble_coefficients(self, J, fcp):

Q = FunctionSpace(mesh, finat.ufl.TensorElement(DGT, shape=G.ufl_shape))
tensor = coefficients.setdefault("Gq_facet", Function(Q.dual()))
assembly_callables.append(OneFormAssembler(ifacet_inner(TestFunction(Q), G), tensor=tensor,
form_compiler_parameters=fcp).assemble)
assembly_callables.append(partial(get_assembler(ifacet_inner(TestFunction(Q), G), form_compiler_parameters=fcp).assemble, tensor=tensor))
PT = Piola.T
Q = FunctionSpace(mesh, finat.ufl.TensorElement(DGT, shape=PT.ufl_shape))
tensor = coefficients.setdefault("PT_facet", Function(Q.dual()))
assembly_callables.append(OneFormAssembler(ifacet_inner(TestFunction(Q), PT), tensor=tensor,
form_compiler_parameters=fcp).assemble)
assembly_callables.append(partial(get_assembler(ifacet_inner(TestFunction(Q), PT), form_compiler_parameters=fcp).assemble, tensor=tensor))

# make DGT functions with BC flags
shape = V.ufl_element().reference_value_shape
Expand All @@ -2189,8 +2182,7 @@ def assemble_coefficients(self, J, fcp):
if len(forms):
form = sum(forms)
if len(form.arguments()) == 1:
assembly_callables.append(OneFormAssembler(form, tensor=tensor,
form_compiler_parameters=fcp).assemble)
assembly_callables.append(partial(get_assembler(form, form_compiler_parameters=fcp).assemble, tensor=tensor))
# set arbitrary non-zero coefficients for preallocation
for coef in coefficients.values():
with coef.dat.vec as cvec:
Expand Down
Loading
Loading