diff --git a/firedrake/adjoint_utils/blocks/solving.py b/firedrake/adjoint_utils/blocks/solving.py
index 1ace4c5222..02870f1ea0 100644
--- a/firedrake/adjoint_utils/blocks/solving.py
+++ b/firedrake/adjoint_utils/blocks/solving.py
@@ -29,6 +29,8 @@ class Solver(Enum):
     """Enum for solver types."""
     FORWARD = 0
     ADJOINT = 1
+    TLM = 2
+    HESSIAN = 3
 
 
 class GenericSolveBlock(Block):
@@ -227,6 +229,14 @@ def _assemble_and_solve_adj_eq(self, dFdu_adj_form, dJdu, compute_bdy):
 
         return adj_sol, adj_sol_bdy
 
+    def _hessian_solve(self, *args):
+        """Solve the second-order adjoint (Hessian) linear system.
+
+        Forwards all arguments to ``_assemble_and_solve_adj_eq``; kept as a
+        separate method so that the solve step can be overridden.
+        """
+        return self._assemble_and_solve_adj_eq(*args)
+
     def _compute_adj_bdy(self, adj_sol, adj_sol_bdy, dFdu_adj_form, dJdu):
         adj_sol_bdy = firedrake.assemble(dJdu - firedrake.action(dFdu_adj_form, adj_sol))
         return adj_sol_bdy.riesz_representation("l2")
@@ -385,8 +395,7 @@ def _assemble_and_solve_soa_eq(self, dFdu_form, adj_sol, hessian_input,
         b = self._assemble_soa_eq_rhs(dFdu_form, adj_sol, hessian_input,
                                       d2Fdu2)
         dFdu_form = firedrake.adjoint(dFdu_form)
-        adj_sol2, adj_sol2_bdy = self._assemble_and_solve_adj_eq(dFdu_form, b,
-                                                                 compute_bdy)
+        adj_sol2, adj_sol2_bdy = self._hessian_solve(dFdu_form, b, compute_bdy)
         if self.adj2_cb is not None:
             self.adj2_cb(adj_sol2)
         if self.adj2_bdy_cb is not None and compute_bdy:
@@ -685,11 +694,57 @@ def _adjoint_solve(self, dJdu, compute_bdy):
                 u_sol, adj_sol_bdy, jac_adj, dJdu_copy)
         return u_sol, adj_sol_bdy
 
+    def _hessian_solve(self, adj_form, rhs, compute_bdy):
+        """Solve the second-order adjoint system with the cached solver.
+
+        Parameters
+        ----------
+        adj_form
+            Not read here: the operator is the one held by the cached
+            ``hessian_lvs`` problem.
+        rhs
+            Right-hand side, assigned into the cached problem before
+            solving.
+        compute_bdy
+            Whether the boundary term is computed as well.
+
+        Returns
+        -------
+        tuple
+            The solution of the cached problem (not a copy) and the
+            boundary term, which is ``None`` when ``compute_bdy`` is false.
+        """
+        hessian_lvs = self._ad_solvers["hessian_lvs"]
+        # NOTE(review): unlike the TLM path, the operator coefficients are
+        # not refreshed (``_ad_solver_replace_forms(Solver.HESSIAN)``) and
+        # the Jacobian is not invalidated before this solve - confirm that
+        # this is intended.
+        # ``F._components[1]`` is the right-hand side of the problem.
+        hessian_lvs._problem.F._components[1].assign(rhs)
+        hessian_lvs.solve()
+        u_sol = hessian_lvs._problem.u
+
+        adj_sol_bdy = None
+        if compute_bdy:
+            jac_adj = hessian_lvs._problem.J
+            adj_sol_bdy = self._compute_adj_bdy(
+                u_sol, adj_sol_bdy, jac_adj, rhs.copy()
+            )
+
+        return u_sol, adj_sol_bdy
+
     def _ad_assign_map(self, form, solver):
         if solver == Solver.FORWARD:
             count_map = self._ad_solvers["forward_nlvs"]._problem._ad_count_map
-        else:
+        elif solver == Solver.ADJOINT:
             count_map = self._ad_solvers["adjoint_lvs"]._problem._ad_count_map
+        elif solver == Solver.TLM:
+            count_map = self._ad_solvers["tlm_lvs"]._problem._ad_count_map
+        elif solver == Solver.HESSIAN:
+            count_map = self._ad_solvers["hessian_lvs"]._problem._ad_count_map
+        else:
+            raise ValueError(f"Unknown solver type: {solver!r}")
+
         assign_map = {}
         form_ad_count_map = dict((count_map[coeff], coeff)
                                  for coeff in form.coefficients())
@@ -700,8 +755,13 @@ def _ad_assign_map(self, form, solver):
                            firedrake.Cofunction)):
                 coeff_count = coeff.count()
                 if coeff_count in form_ad_count_map:
-                    assign_map[form_ad_count_map[coeff_count]] = \
-                        block_variable.saved_output
+                    if solver == Solver.HESSIAN:
+                        # NOTE(review): the Hessian operator is given the
+                        # tangent-linear values here rather than the saved
+                        # forward values - confirm that this is intended.
+                        assign_map[form_ad_count_map[coeff_count]] = block_variable.tlm_value
+                    else:
+                        assign_map[form_ad_count_map[coeff_count]] = block_variable.saved_output
 
         if (
             solver == Solver.ADJOINT
@@ -712,6 +772,7 @@ def _ad_assign_map(self, form, solver):
             if coeff_count in form_ad_count_map:
                 assign_map[form_ad_count_map[coeff_count]] = \
                     block_variable.saved_output
+
         return assign_map
 
     def _ad_assign_coefficients(self, form, solver):
@@ -724,9 +785,17 @@ def _ad_solver_replace_forms(self, solver=Solver.FORWARD):
             problem = self._ad_solvers["forward_nlvs"]._problem
             self._ad_assign_coefficients(problem.F, solver)
             self._ad_assign_coefficients(problem.J, solver)
-        else:
+        elif solver == Solver.ADJOINT:
             self._ad_assign_coefficients(
                 self._ad_solvers["adjoint_lvs"]._problem.J, solver)
+        elif solver == Solver.TLM:
+            self._ad_assign_coefficients(
+                self._ad_solvers["tlm_lvs"]._problem.J, solver
+            )
+        elif solver == Solver.HESSIAN:
+            self._ad_assign_coefficients(
+                self._ad_solvers["hessian_lvs"]._problem.J, solver
+            )
 
     def prepare_evaluate_adj(self, inputs, adj_inputs, relevant_dependencies):
         compute_bdy = self._should_compute_boundary_adjoint(
@@ -803,6 +872,76 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
 
         return dFdm
 
+    def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx,
+                               prepared=None):
+        """Solve the tangent-linear system with the cached ``tlm_lvs`` solver.
+
+        Parameters
+        ----------
+        inputs, tlm_inputs, block_variable, idx
+            Not read here: the tangent-linear values are taken from the
+            block dependencies.
+        prepared
+            Mapping holding the residual form under ``"form"`` and its
+            derivative with respect to the solution under ``"dFdu"``.
+
+        Returns
+        -------
+        The solution of the cached tangent-linear problem (not a copy).
+        """
+        F_form = prepared["form"]
+        dFdu = prepared["dFdu"]
+
+        # TODO: ``bcs`` is collected but never handed to the solver, so
+        # tangent-linear values of boundary conditions are ignored.
+        bcs = []
+        dFdm = 0.
+        # ``dep`` rather than ``block_variable`` so that the loop does not
+        # rebind the parameter of the same name.
+        for dep in self.get_dependencies():
+            tlm_value = dep.tlm_value
+            c = dep.output
+            c_rep = dep.saved_output
+
+            if isinstance(c, firedrake.DirichletBC):
+                if tlm_value is None:
+                    bcs.append(c.reconstruct(g=0))
+                else:
+                    bcs.append(tlm_value)
+                continue
+            elif isinstance(c, firedrake.MeshGeometry):
+                X = firedrake.SpatialCoordinate(c)
+                c_rep = X
+
+            if tlm_value is None:
+                continue
+
+            if c == self.func and not self.linear:
+                continue
+
+            dFdm += firedrake.derivative(-F_form, c_rep, tlm_value)
+
+        if isinstance(dFdm, float):
+            # No dependency contributed: use an explicit zero form.
+            v = dFdu.arguments()[0]
+            dFdm = firedrake.inner(
+                firedrake.Constant(numpy.zeros(v.ufl_shape)), v
+            ) * firedrake.dx
+
+        dFdm = ufl.algorithms.expand_derivatives(dFdm)
+        dFdm = firedrake.assemble(dFdm)
+
+        tlm_lvs = self._ad_solvers["tlm_lvs"]
+        # Refresh the coefficients of the cached operator and mark it for
+        # reassembly before solving.
+        self._ad_solver_replace_forms(Solver.TLM)
+        tlm_lvs.invalidate_jacobian()
+        # ``F._components[1]`` is the right-hand side of the problem.
+        tlm_lvs._problem.F._components[1].assign(dFdm)
+
+        tlm_lvs.solve()
+        return tlm_lvs._problem.u
+
 
 class ProjectBlock(SolveVarFormBlock):
     def __init__(self, v, V, output, bcs=[], *args, **kwargs):
diff --git a/firedrake/adjoint_utils/variational_solver.py b/firedrake/adjoint_utils/variational_solver.py
index 21dded9839..4f08f238d3 100644
--- a/firedrake/adjoint_utils/variational_solver.py
+++ b/firedrake/adjoint_utils/variational_solver.py
@@ -17,6 +17,7 @@ def wrapper(self, *args, **kwargs):
             self._ad_u = self.u_restrict
             self._ad_bcs = self.bcs
             self._ad_J = self.J
+
             try:
                 # Some forms (e.g. SLATE tensors) are not currently
                 # differentiable.
@@ -29,6 +30,7 @@ def wrapper(self, *args, **kwargs):
                 self._ad_adj_F = adjoint(dFdu, derivatives_expanded=True)
             except (ValueError, TypeError, NotImplementedError):
                 self._ad_adj_F = None
+
             self._ad_kwargs = {'Jp': self.Jp, 'form_compiler_parameters': self.form_compiler_parameters, 'is_linear': self.is_linear}
             self._ad_count_map = {}
         return wrapper
@@ -49,7 +51,8 @@ def wrapper(self, problem, *args, **kwargs):
             self._ad_args = args
             self._ad_kwargs = kwargs
             self._ad_solvers = {"forward_nlvs": None, "adjoint_lvs": None,
-                                "recompute_count": 0}
+                                "recompute_count": 0, "tlm_lvs": None,
+                                "hessian_lvs": None}
             self._ad_adj_cache = {}
 
         return wrapper
@@ -100,6 +103,20 @@ def wrapper(self, **kwargs):
                     if self._ad_problem._constant_jacobian:
                         self._ad_solvers["update_adjoint"] = False
 
+                if not self._ad_solvers["hessian_lvs"]:
+                    with stop_annotating():
+                        self._ad_solvers["hessian_lvs"] = LinearVariationalSolver(
+                            self._ad_hessian_lvs_problem(block, problem._ad_adj_F),
+                        )
+
+                if not self._ad_solvers["tlm_lvs"]:
+                    with stop_annotating():
+                        self._ad_solvers["tlm_lvs"] = LinearVariationalSolver(
+                            self._ad_tlm_lvs_problem(block, problem.F, problem.u_restrict)
+                        )
+                    if self._ad_problem._constant_jacobian:
+                        self._ad_solvers["update_tlm"] = False
+
                 block._ad_solvers = self._ad_solvers
 
                 tape.add_block(block)
@@ -151,7 +168,8 @@ def _ad_adj_lvs_problem(self, block, adj_F):
         # linear variational problem is created with a deep copy of the
         # `block.adj_F` coefficients.
         _ad_count_map, J_replace_map, _ = self._build_count_map(
-            adj_F, block._dependencies)
+            adj_F, block._dependencies,
+        )
         lvp = LinearVariationalProblem(
             replace(tmp_problem.J, J_replace_map), right_hand_side, adj_sol,
             bcs=tmp_problem.bcs,
@@ -159,6 +177,57 @@ def _ad_adj_lvs_problem(self, block, adj_F):
         lvp._ad_count_map_update(_ad_count_map)
         return lvp
 
+    @no_annotations
+    def _ad_hessian_lvs_problem(self, block, adj_dFdu):
+        """Create the linear problem used for second-order adjoint solves.
+
+        The coefficients of ``adj_dFdu`` are swapped for the replacements
+        returned by ``_build_count_map`` and the corresponding count map is
+        recorded on the returned problem.
+        """
+        from firedrake import Function, Cofunction, LinearVariationalProblem
+
+        bcs = block._homogenize_bcs()
+        adj_sol = Function(block.function_space)
+        right_hand_side = Cofunction(block.function_space.dual())
+        tmp_problem = LinearVariationalProblem(
+            adj_dFdu, right_hand_side, adj_sol, bcs=bcs,
+            constant_jacobian=self._ad_problem._constant_jacobian)
+
+        _ad_count_map, J_replace_map, _ = self._build_count_map(
+            adj_dFdu, block._dependencies,
+        )
+        lvp = LinearVariationalProblem(
+            replace(tmp_problem.J, J_replace_map), right_hand_side, adj_sol,
+            bcs=tmp_problem.bcs,
+            constant_jacobian=self._ad_problem._constant_jacobian)
+        lvp._ad_count_map_update(_ad_count_map)
+        return lvp
+
+    @no_annotations
+    def _ad_tlm_lvs_problem(self, block, F, u):
+        """Create the linear problem used for tangent-linear solves.
+
+        The operator is the derivative of ``F`` with respect to ``u``, with
+        its coefficients swapped for the replacements returned by
+        ``_build_count_map``; boundary conditions are homogenised.
+        """
+        from firedrake import Function, Cofunction, LinearVariationalProblem
+
+        lhs = derivative(F, u)
+        _ad_count_map, F_replace_map, _ = self._build_count_map(lhs, block._dependencies)
+        sol = Function(block.function_space)
+        rhs = Cofunction(block.function_space.dual())
+        lvp = LinearVariationalProblem(
+            replace(lhs, F_replace_map),
+            rhs,
+            sol,
+            bcs=block._homogenize_bcs(),
+            constant_jacobian=self._ad_problem._constant_jacobian,
+        )
+        lvp._ad_count_map_update(_ad_count_map)
+        return lvp
+
     def _build_count_map(self, J, dependencies, F=None):
         from firedrake import Function
 