Skip to content

Commit 9b5b3ba

Browse files
committed
Start to move Hessian evaluation into NonlinearVariationalSolveBlock
1 parent 25bb74d commit 9b5b3ba

File tree

2 files changed

+61
-10
lines changed

2 files changed

+61
-10
lines changed

firedrake/adjoint_utils/blocks/solving.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ class Solver(Enum):
3030
FORWARD = 0
3131
ADJOINT = 1
3232
TLM = 2
33+
HESSIAN = 3
3334

3435

3536
class GenericSolveBlock(Block):
@@ -228,6 +229,9 @@ def _assemble_and_solve_adj_eq(self, dFdu_adj_form, dJdu, compute_bdy):
228229

229230
return adj_sol, adj_sol_bdy
230231

232+
def _hessian_solve(self, *args):
233+
return self._assemble_and_solve_adj_eq(*args)
234+
231235
def _compute_adj_bdy(self, adj_sol, adj_sol_bdy, dFdu_adj_form, dJdu):
232236
adj_sol_bdy = firedrake.assemble(dJdu - firedrake.action(dFdu_adj_form, adj_sol))
233237
return adj_sol_bdy.riesz_representation("l2")
@@ -386,8 +390,7 @@ def _assemble_and_solve_soa_eq(self, dFdu_form, adj_sol, hessian_input,
386390
b = self._assemble_soa_eq_rhs(dFdu_form, adj_sol, hessian_input,
387391
d2Fdu2)
388392
dFdu_form = firedrake.adjoint(dFdu_form)
389-
adj_sol2, adj_sol2_bdy = self._assemble_and_solve_adj_eq(dFdu_form, b,
390-
compute_bdy)
393+
adj_sol2, adj_sol2_bdy = self._hessian_solve(dFdu_form, b, compute_bdy)
391394
if self.adj2_cb is not None:
392395
self.adj2_cb(adj_sol2)
393396
if self.adj2_bdy_cb is not None and compute_bdy:
@@ -686,6 +689,22 @@ def _adjoint_solve(self, dJdu, compute_bdy):
686689
u_sol, adj_sol_bdy, jac_adj, dJdu_copy)
687690
return u_sol, adj_sol_bdy
688691

692+
def _hessian_solve(self, adj_form, rhs, compute_bdy):
693+
# self._ad_solver_replace_forms(Solver.HESSIAN)
694+
# self._ad_solvers["hessian_lvs"].invalidate_jacobian()
695+
self._ad_solvers["hessian_lvs"]._problem.F._components[1].assign(rhs)
696+
self._ad_solvers["hessian_lvs"].solve()
697+
u_sol = self._ad_solvers["hessian_lvs"]._problem.u
698+
699+
adj_sol_bdy = None
700+
if compute_bdy:
701+
jac_adj = self._ad_solvers["hessian_lvs"]._problem.J
702+
adj_sol_bdy = self._compute_adj_bdy(
703+
u_sol, adj_sol_bdy, jac_adj, rhs.copy()
704+
)
705+
706+
return u_sol, adj_sol_bdy
707+
689708
def _ad_assign_map(self, form, solver):
690709
if solver == Solver.FORWARD:
691710
count_map = self._ad_solvers["forward_nlvs"]._problem._ad_count_map
@@ -704,8 +723,10 @@ def _ad_assign_map(self, form, solver):
704723
firedrake.Cofunction)):
705724
coeff_count = coeff.count()
706725
if coeff_count in form_ad_count_map:
707-
assign_map[form_ad_count_map[coeff_count]] = \
708-
block_variable.saved_output
726+
if solver == Solver.HESSIAN:
727+
assign_map[form_ad_count_map[coeff_count]] = block_variable.tlm_value
728+
else:
729+
assign_map[form_ad_count_map[coeff_count]] = block_variable.saved_output
709730

710731
if (
711732
solver == Solver.ADJOINT
@@ -716,6 +737,7 @@ def _ad_assign_map(self, form, solver):
716737
if coeff_count in form_ad_count_map:
717738
assign_map[form_ad_count_map[coeff_count]] = \
718739
block_variable.saved_output
740+
719741
return assign_map
720742

721743
def _ad_assign_coefficients(self, form, solver):
@@ -735,6 +757,10 @@ def _ad_solver_replace_forms(self, solver=Solver.FORWARD):
735757
self._ad_assign_coefficients(
736758
self._ad_solvers["tlm_lvs"]._problem.J, solver
737759
)
760+
elif solver == Solver.HESSIAN:
761+
self._ad_assign_coefficients(
762+
self._ad_solvers["hessian_lvs"]._problem.J, solver
763+
)
738764

739765
def prepare_evaluate_adj(self, inputs, adj_inputs, relevant_dependencies):
740766
compute_bdy = self._should_compute_boundary_adjoint(
@@ -858,11 +884,6 @@ def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx,
858884

859885
self._ad_solvers["tlm_lvs"].solve()
860886
return self._ad_solvers["tlm_lvs"]._problem.u
861-
# return self._assemble_and_solve_tlm_eq(
862-
# firedrake.assemble(dFdu, bcs=bcs, **self.assemble_kwargs),
863-
# dFdm, dudm, bcs
864-
# )
865-
866887

867888

868889
class ProjectBlock(SolveVarFormBlock):

firedrake/adjoint_utils/variational_solver.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ def wrapper(self, *args, **kwargs):
1717
self._ad_u = self.u_restrict
1818
self._ad_bcs = self.bcs
1919
self._ad_J = self.J
20+
2021
try:
2122
# Some forms (e.g. SLATE tensors) are not currently
2223
# differentiable.
@@ -29,6 +30,7 @@ def wrapper(self, *args, **kwargs):
2930
self._ad_adj_F = adjoint(dFdu, derivatives_expanded=True)
3031
except (ValueError, TypeError, NotImplementedError):
3132
self._ad_adj_F = None
33+
3234
self._ad_kwargs = {'Jp': self.Jp, 'form_compiler_parameters': self.form_compiler_parameters, 'is_linear': self.is_linear}
3335
self._ad_count_map = {}
3436
return wrapper
@@ -49,7 +51,8 @@ def wrapper(self, problem, *args, **kwargs):
4951
self._ad_args = args
5052
self._ad_kwargs = kwargs
5153
self._ad_solvers = {"forward_nlvs": None, "adjoint_lvs": None,
52-
"recompute_count": 0, "tlm_lvs": None}
54+
"recompute_count": 0, "tlm_lvs": None,
55+
"hessian_lvs": None}
5356
self._ad_adj_cache = {}
5457

5558
return wrapper
@@ -100,6 +103,12 @@ def wrapper(self, **kwargs):
100103
if self._ad_problem._constant_jacobian:
101104
self._ad_solvers["update_adjoint"] = False
102105

106+
if not self._ad_solvers["hessian_lvs"]:
107+
with stop_annotating():
108+
self._ad_solvers["hessian_lvs"] = LinearVariationalSolver(
109+
self._ad_hessian_lvs_problem(block, problem._ad_adj_F),
110+
)
111+
103112
if not self._ad_solvers["tlm_lvs"]:
104113
with stop_annotating():
105114
self._ad_solvers["tlm_lvs"] = LinearVariationalSolver(
@@ -168,6 +177,27 @@ def _ad_adj_lvs_problem(self, block, adj_F):
168177
lvp._ad_count_map_update(_ad_count_map)
169178
return lvp
170179

180+
@no_annotations
181+
def _ad_hessian_lvs_problem(self, block, adj_dFdu):
182+
from firedrake import Function, Cofunction, LinearVariationalProblem
183+
184+
bcs = block._homogenize_bcs()
185+
adj_sol = Function(block.function_space)
186+
right_hand_side = Cofunction(block.function_space.dual())
187+
tmp_problem = LinearVariationalProblem(
188+
adj_dFdu, right_hand_side, adj_sol, bcs=bcs,
189+
constant_jacobian=self._ad_problem._constant_jacobian)
190+
191+
_ad_count_map, J_replace_map, _ = self._build_count_map(
192+
adj_dFdu, block._dependencies,
193+
)
194+
lvp = LinearVariationalProblem(
195+
replace(tmp_problem.J, J_replace_map), right_hand_side, adj_sol,
196+
bcs=tmp_problem.bcs,
197+
constant_jacobian=self._ad_problem._constant_jacobian)
198+
lvp._ad_count_map_update(_ad_count_map)
199+
return lvp
200+
171201
@no_annotations
172202
def _ad_tlm_lvs_problem(self, block, F, u):
173203
from firedrake import Function, Cofunction, LinearVariationalProblem

0 commit comments

Comments
 (0)