@@ -30,6 +30,7 @@ class Solver(Enum):
3030 FORWARD = 0
3131 ADJOINT = 1
3232 TLM = 2
33+ HESSIAN = 3
3334
3435
3536class GenericSolveBlock (Block ):
@@ -228,6 +229,9 @@ def _assemble_and_solve_adj_eq(self, dFdu_adj_form, dJdu, compute_bdy):
228229
229230 return adj_sol , adj_sol_bdy
230231
232+ def _hessian_solve (self , * args ):
233+ return self ._assemble_and_solve_adj_eq (* args )
234+
231235 def _compute_adj_bdy (self , adj_sol , adj_sol_bdy , dFdu_adj_form , dJdu ):
232236 adj_sol_bdy = firedrake .assemble (dJdu - firedrake .action (dFdu_adj_form , adj_sol ))
233237 return adj_sol_bdy .riesz_representation ("l2" )
@@ -386,8 +390,7 @@ def _assemble_and_solve_soa_eq(self, dFdu_form, adj_sol, hessian_input,
386390 b = self ._assemble_soa_eq_rhs (dFdu_form , adj_sol , hessian_input ,
387391 d2Fdu2 )
388392 dFdu_form = firedrake .adjoint (dFdu_form )
389- adj_sol2 , adj_sol2_bdy = self ._assemble_and_solve_adj_eq (dFdu_form , b ,
390- compute_bdy )
393+ adj_sol2 , adj_sol2_bdy = self ._hessian_solve (dFdu_form , b , compute_bdy )
391394 if self .adj2_cb is not None :
392395 self .adj2_cb (adj_sol2 )
393396 if self .adj2_bdy_cb is not None and compute_bdy :
@@ -686,6 +689,22 @@ def _adjoint_solve(self, dJdu, compute_bdy):
686689 u_sol , adj_sol_bdy , jac_adj , dJdu_copy )
687690 return u_sol , adj_sol_bdy
688691
692+ def _hessian_solve (self , adj_form , rhs , compute_bdy ):
693+ # self._ad_solver_replace_forms(Solver.HESSIAN)
694+ # self._ad_solvers["hessian_lvs"].invalidate_jacobian()
695+ self ._ad_solvers ["hessian_lvs" ]._problem .F ._components [1 ].assign (rhs )
696+ self ._ad_solvers ["hessian_lvs" ].solve ()
697+ u_sol = self ._ad_solvers ["hessian_lvs" ]._problem .u
698+
699+ adj_sol_bdy = None
700+ if compute_bdy :
701+ jac_adj = self ._ad_solvers ["hessian_lvs" ]._problem .J
702+ adj_sol_bdy = self ._compute_adj_bdy (
703+ u_sol , adj_sol_bdy , jac_adj , rhs .copy ()
704+ )
705+
706+ return u_sol , adj_sol_bdy
707+
689708 def _ad_assign_map (self , form , solver ):
690709 if solver == Solver .FORWARD :
691710 count_map = self ._ad_solvers ["forward_nlvs" ]._problem ._ad_count_map
@@ -704,8 +723,10 @@ def _ad_assign_map(self, form, solver):
704723 firedrake .Cofunction )):
705724 coeff_count = coeff .count ()
706725 if coeff_count in form_ad_count_map :
707- assign_map [form_ad_count_map [coeff_count ]] = \
708- block_variable .saved_output
726+ if solver == Solver .HESSIAN :
727+ assign_map [form_ad_count_map [coeff_count ]] = block_variable .tlm_value
728+ else :
729+ assign_map [form_ad_count_map [coeff_count ]] = block_variable .saved_output
709730
710731 if (
711732 solver == Solver .ADJOINT
@@ -716,6 +737,7 @@ def _ad_assign_map(self, form, solver):
716737 if coeff_count in form_ad_count_map :
717738 assign_map [form_ad_count_map [coeff_count ]] = \
718739 block_variable .saved_output
740+
719741 return assign_map
720742
721743 def _ad_assign_coefficients (self , form , solver ):
@@ -735,6 +757,10 @@ def _ad_solver_replace_forms(self, solver=Solver.FORWARD):
735757 self ._ad_assign_coefficients (
736758 self ._ad_solvers ["tlm_lvs" ]._problem .J , solver
737759 )
760+ elif solver == Solver .HESSIAN :
761+ self ._ad_assign_coefficients (
762+ self ._ad_solvers ["hessian_lvs" ]._problem .J , solver
763+ )
738764
739765 def prepare_evaluate_adj (self , inputs , adj_inputs , relevant_dependencies ):
740766 compute_bdy = self ._should_compute_boundary_adjoint (
@@ -858,11 +884,6 @@ def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx,
858884
859885 self ._ad_solvers ["tlm_lvs" ].solve ()
860886 return self ._ad_solvers ["tlm_lvs" ]._problem .u
861- # return self._assemble_and_solve_tlm_eq(
862- # firedrake.assemble(dFdu, bcs=bcs, **self.assemble_kwargs),
863- # dFdm, dudm, bcs
864- # )
865-
866887
867888
868889class ProjectBlock (SolveVarFormBlock ):
0 commit comments