diff --git a/firedrake/__init__.py b/firedrake/__init__.py index f0f2a0901f..f5d158e837 100644 --- a/firedrake/__init__.py +++ b/firedrake/__init__.py @@ -74,7 +74,6 @@ from firedrake.checkpointing import * from firedrake.constant import * from firedrake.exceptions import * -from firedrake.fml import * from firedrake.function import * from firedrake.functionspace import * from firedrake.interpolation import * @@ -105,6 +104,7 @@ from firedrake.ensemble import * from firedrake.randomfunctiongen import * from firedrake.progress_bar import ProgressBar # noqa: F401 +from firedrake.fml import * from firedrake.logging import * # Set default log level diff --git a/firedrake/fml/__init__.py b/firedrake/fml/__init__.py index 5a424f2312..d0f951305d 100644 --- a/firedrake/fml/__init__.py +++ b/firedrake/fml/__init__.py @@ -1,2 +1,2 @@ -from gusto.fml.form_manipulation_language import * # noqa -from gusto.fml.replacement import * # noqa +from firedrake.fml.form_manipulation_language import * # noqa +from firedrake.fml.replacement import * # noqa diff --git a/firedrake/fml/form_manipulation_language.py b/firedrake/fml/form_manipulation_language.py index 1efe8c8666..850d20916f 100644 --- a/firedrake/fml/form_manipulation_language.py +++ b/firedrake/fml/form_manipulation_language.py @@ -28,10 +28,13 @@ class Term(object): def __init__(self, form, label_dict=None): """ - Args: - form (:class:`ufl.Form`): the form for this terms. - label_dict (dict, optional): dictionary of key-value pairs - corresponding to current form labels. Defaults to None. + Parameters + ---------- + form : :class:`ufl.Form` + The form for this terms. + label_dict : dict, optional + Dictionary of key-value pairs corresponding to current form labels. + Defaults to None. """ self.form = form self.labels = label_dict or {} @@ -40,10 +43,14 @@ def get(self, label): """ Returns the value of a label. - Args: - label (:class:`Label`): the label to return the value of. + Parameters + ---------- + label : :class:`Label` + The label to return the value of. - Returns: + Returns + ------- + Any type The value of a label. """ return self.labels.get(label.label) @@ -52,17 +59,19 @@ def has_label(self, *labels, return_tuple=False): """ Whether the term has the specified labels attached to it. - Args: - *labels (:class:`Label`): a label or series of labels. A tuple is - automatically returned if multiple labels are provided as - arguments. - return_tuple (bool, optional): if True, forces a tuple to be - returned even if only one label is provided as an argument. - Defaults to False. + Parameters + ---------- + *labels : :class:`Label` + A label or series of labels. A tuple is automatically returned if + multiple labels are provided as arguments. + return_tuple : bool, optional + If True, forces a tuple to be returned even if only one label is + provided as an argument. Defaults to False. - Returns: - bool or tuple: Booleans corresponding to whether the term has the - specified labels. + Returns + ------- + bool or tuple + Booleans corresponding to whether the term has the specified labels. """ if len(labels) == 1 and not return_tuple: return labels[0].label in self.labels @@ -73,14 +82,19 @@ def __add__(self, other): """ Adds a term or labelled form to this term. - Args: - other (:class:`Term` or :class:`LabelledForm`): the term or labelled - form to add to this term. + Parameters + ---------- + other : :class:`Term` or :class:`LabelledForm` + The term or labelled form to add to this term. - Returns: - :class:`LabelledForm`: a labelled form containing the terms. + Returns + ------- + :class:`LabelledForm` + A labelled form containing the terms. """ - if other is None: + if self is NullTerm: + return other + if other is None or other is NullTerm: return self elif isinstance(other, Term): return LabelledForm(self, other) @@ -95,12 +109,15 @@ def __sub__(self, other): """ Subtracts a term or labelled form from this term. - Args: - other (:class:`Term` or :class:`LabelledForm`): the term or labelled - form to subtract from this term. + Parameters + ---------- + other : :class:`Term` or :class:`LabelledForm` + The term or labelled form to subtract from this term. - Returns: - :class:`LabelledForm`: a labelled form containing the terms. + Returns + ------- + :class:`LabelledForm` + A labelled form containing the terms. """ other = other * Constant(-1.0) return self + other @@ -109,19 +126,16 @@ def __mul__(self, other): """ Multiplies this term by another quantity. - Args: - other (float, :class:`Constant` or :class:`ufl.algebra.Product`): - the quantity to multiply this term by. If it is a float or int - then it is converted to a :class:`Constant` before the - multiplication. + Parameters + ---------- + other : float, :class:`Constant` or :class:`ufl.algebra.Product` + The quantity to multiply this term by. - Returns: - :class:`Term`: the product of the term with the quantity. + Returns + ------- + :class:`Term` + The product of the term with the quantity. """ - if type(other) in (float, int): - other = Constant(other) - elif type(other) not in [Constant, ufl.algebra.Product]: - return NotImplemented return Term(other*self.form, self.labels) __rmul__ = __mul__ @@ -130,20 +144,16 @@ def __truediv__(self, other): """ Divides this term by another quantity. - Args: - other (float, :class:`Constant` or :class:`ufl.algebra.Product`): - the quantity to divide this term by. If it is a float or int - then it is converted to a :class:`Constant` before the - division. + Parameters + ---------- + other : float, :class:`Constant` or :class:`ufl.algebra.Product` + The quantity to divide this term by. - Returns: - :class:`Term`: the quotient of the term divided by the quantity. + Returns + ------- + :class:`Term`: The quotient of the term divided by the quantity. """ - if type(other) in (float, int, Constant, ufl.algebra.Product): - other = Constant(1.0 / other) - return self * other - else: - return NotImplemented + return self * (Constant(1.0) / other) # This is necessary to be the initialiser for functools.reduce @@ -165,11 +175,14 @@ class LabelledForm(object): def __init__(self, *terms): """ - Args: - *terms (:class:`Term`): terms to combine to make the `LabelledForm`. + Parameters + ---------- + *terms :class:`Term` + Terms to combine to make the `LabelledForm`. - Raises: - TypeError: _description_ + Raises + ------ + TypeError: If any argument is not a term. """ if len(terms) == 1 and isinstance(terms[0], LabelledForm): self.terms = terms[0].terms @@ -182,12 +195,15 @@ def __add__(self, other): """ Adds a form, term or labelled form to this labelled form. - Args: - other (:class:`ufl.Form`, :class:`Term` or :class:`LabelledForm`): - the form, term or labelled form to add to this labelled form. + Parameters + ---------- + other : :class:`ufl.Form`, :class:`Term` or :class:`LabelledForm` + The form, term or labelled form to add to this labelled form. - Returns: - :class:`LabelledForm`: a labelled form containing the terms. + Returns + ------- + :class:`LabelledForm` + A labelled form containing the terms. """ if isinstance(other, ufl.Form): return LabelledForm(*self, Term(other)) @@ -206,65 +222,58 @@ def __sub__(self, other): """ Subtracts a form, term or labelled form from this labelled form. - Args: - other (:class:`ufl.Form`, :class:`Term` or :class:`LabelledForm`): - the form, term or labelled form to subtract from this labelled - form. + Parameters + ---------- + other : :class:`ufl.Form`, :class:`Term` or :class:`LabelledForm` + The form, term or labelled form to subtract from this labelled form. - Returns: - :class:`LabelledForm`: a labelled form containing the terms. + Returns + ------- + :class:`LabelledForm` + A labelled form containing the terms. """ if type(other) is Term: return LabelledForm(*self, Constant(-1.)*other) elif type(other) is LabelledForm: return LabelledForm(*self, *[Constant(-1.)*t for t in other]) - elif type(other) is ufl.algebra.Product: - return LabelledForm(*self, Term(Constant(-1.)*other)) elif other is None: return self else: - return NotImplemented + # Make new Term for other and subtract it + return LabelledForm(*self, Term(Constant(-1.)*other)) def __mul__(self, other): """ Multiplies this labelled form by another quantity. - Args: - other (float, :class:`Constant` or :class:`ufl.algebra.Product`): - the quantity to multiply this labelled form by. If it is a float - or int then it is converted to a :class:`Constant` before the - multiplication. All terms in the form are multiplied. - - Returns: - :class:`LabelledForm`: the product of all terms with the quantity. - """ - if type(other) in (float, int): - other = Constant(other) - # UFL can cancel constants to a Zero type which needs treating separately - elif type(other) is ufl.constantvalue.Zero: - other = Constant(0.0) - elif type(other) not in [Constant, ufl.algebra.Product]: - return NotImplemented + Parameters + ---------- + other : float, :class:`Constant` or :class:`ufl.algebra.Product` + The quantity to multiply this labelled form by. All terms in the + form are multiplied. + + Returns + ------- + :class:`LabelledForm`: The product of all terms with the quantity. + """ return self.label_map(all_terms, lambda t: Term(other*t.form, t.labels)) def __truediv__(self, other): """ Divides this labelled form by another quantity. - Args: - other (float, :class:`Constant` or :class:`ufl.algebra.Product`): - the quantity to divide this labelled form by. If it is a float - or int then it is converted to a :class:`Constant` before the - division. All terms in the form are divided. + Parameters + ---------- + other : float, :class:`Constant` or :class:`ufl.algebra.Product` + The quantity to divide this labelled form by. All terms in the form + are divided. - Returns: - :class:`LabelledForm`: the quotient of all terms with the quantity. + Returns + ------- + :class:`LabelledForm` + The quotient of all terms with the quantity. """ - if type(other) in (float, int, Constant, ufl.algebra.Product): - other = Constant(1.0 / other) - return self * other - else: - return NotImplemented + return self * (Constant(1.0) / other) __rmul__ = __mul__ @@ -281,15 +290,21 @@ def label_map(self, term_filter, map_if_true=identity, """ Maps selected terms in the labelled form, returning a new labelled form. - Args: - term_filter (func): a function to filter the labelled form's terms. - map_if_true (func, optional): how to map the terms for which the - term_filter returns True. Defaults to identity. - map_if_false (func, optional): how to map the terms for which the - term_filter returns False. Defaults to identity. + Parameters + ---------- + term_filter : func + A function to filter the labelled form's terms. + map_if_true : func, optional + How to map the terms for which the term_filter returns True. + Defaults to identity. + map_if_false : func, optional + How to map the terms for which the term_filter returns False. + Defaults to identity. - Returns: - :class:`LabelledForm`: a new labelled form with the terms mapped. + Returns + ------- + :class:`LabelledForm` + A new labelled form with the terms mapped. """ new_labelled_form = LabelledForm( @@ -297,10 +312,11 @@ def label_map(self, term_filter, map_if_true=identity, filter(lambda t: t is not None, (map_if_true(t) if term_filter(t) else map_if_false(t) for t in self.terms)), - # TODO: Not clear what the initialiser should be! - # No initialiser means label_map can't work if everything is false - # None is a problem as cannot be added to Term - # NullTerm works but will need dropping ... + # Need to set an initialiser, otherwise the label_map + # won't work if the term_filter is False for everything + # None does not work, as then we add Terms to None + # and the addition operation is defined from None + # rather than the Term. NullTerm solves this. NullTerm)) # Drop the NullTerm @@ -314,11 +330,15 @@ def form(self): """ Provides the whole form from the labelled form. - Raises: - TypeError: if the labelled form has no terms. + Raises + ------ + TypeError + If the labelled form has no terms. - Returns: - :class:`ufl.Form`: the whole form corresponding to all the terms. + Returns + ------- + :class:`ufl.Form` + The whole form corresponding to all the terms. """ # Throw an error if there is no form if len(self.terms) == 0: @@ -334,12 +354,16 @@ class Label(object): def __init__(self, label, *, value=True, validator=None): """ - Args: - label (str): the name of the label. - value (..., optional): the value for the label to take. Can be any - type (subject to the validator). Defaults to True. - validator (func, optional): function to check the validity of any - value later passed to the label. Defaults to None. + Parameters + ---------- + label : str + The name of the label. + value : Any, optional + The value for the label to take. Can be any type (subject to the + validator). Defaults to True. + validator : func, optional + Function to check the validity of any value later passed to the + label. Defaults to None. """ self.label = label self.default_value = value @@ -349,20 +373,24 @@ def __call__(self, target, value=None): """ Applies the label to a form or term. - Args: - target (:class:`ufl.Form`, :class:`Term` or :class:`LabelledForm`): - the form, term or labelled form to be labelled. - value (..., optional): the value to attach to this label. Defaults - to None. - - Raises: - ValueError: if the `target` is not a :class:`ufl.Form`, - :class:`Term` or :class:`LabelledForm`. - - Returns: - :class:`Term` or :class:`LabelledForm`: a :class:`Term` is returned - if the target is a :class:`Term`, otherwise a - :class:`LabelledForm` is returned. + Parameters + ---------- + target : :class:`ufl.Form`, :class:`Term` or :class:`LabelledForm` + The form, term or labelled form to be labelled. + value : Any, optional + The value to attach to this label. Defaults to None. + + Raises + ------ + ValueError + If the `target` is not a :class:`ufl.Form`, :class:`Term` or + :class:`LabelledForm`. + + Returns + ------- + :class:`Term` or :class:`LabelledForm` + A :class:`Term` is returned if the target is a :class:`Term`, + otherwise a :class:`LabelledForm` is returned. """ # if value is provided, check that we have a validator function # and validate the value, otherwise use default value @@ -390,13 +418,15 @@ def remove(self, target): This removes any :class:`Label` with this `label` from `target`. If called on an :class:`LabelledForm`, it acts termwise. - Args: - target (:class:`Term` or :class:`LabelledForm`): term or labelled - form to have this label removed from. + Parameters + ---------- + target : :class:`Term` or :class:`LabelledForm` + Term or labelled form to have this label removed from. - Raises: - ValueError: if the `target` is not a :class:`Term` or a - :class:`LabelledForm`. + Raises + ------ + ValueError + If the `target` is not a :class:`Term` or a :class:`LabelledForm`. """ if isinstance(target, LabelledForm): @@ -418,14 +448,18 @@ def update_value(self, target, new): This updates the value of any :class:`Label` with this `label` from `target`. If called on an :class:`LabelledForm`, it acts termwise. - Args: - target (:class:`Term` or :class:`LabelledForm`): term or labelled - form to have this label updated. - new (...): the new value for this label to take. - - Raises: - ValueError: if the `target` is not a :class:`Term` or a - :class:`LabelledForm`. + Parameters + ---------- + target : :class:`Term` or :class:`LabelledForm` + Term or labelled form to have this label updated. + new : Any + The new value for this label to take. The type is subject to the + label's validator (if it has one). + + Raises + ------ + ValueError + If the `target` is not a :class:`Term` or a :class:`LabelledForm`. """ if isinstance(target, LabelledForm): diff --git a/firedrake/fml/replacement.py b/firedrake/fml/replacement.py index 9916204c02..2db67473e9 100644 --- a/firedrake/fml/replacement.py +++ b/firedrake/fml/replacement.py @@ -6,8 +6,7 @@ from .form_manipulation_language import Term, subject from firedrake import split, MixedElement -__all__ = ["replace_test_function", "replace_trial_function", - "replace_subject"] +__all__ = ["replace_test_function", "replace_trial_function", "replace_subject"] # ---------------------------------------------------------------------------- # @@ -15,10 +14,42 @@ # ---------------------------------------------------------------------------- # def _replace_dict(old, new, old_idx, new_idx, replace_type): """ - Build a dictionary to pass to the ufl.replace routine - The dictionary matches variables in the old term with those in the new - - Does not check types unless indexing is required (leave type-checking to ufl.replace) + Build a dictionary to pass to the ufl.replace routine. The dictionary + matches variables in the old term with those in the new. + + Does not check types unless indexing is required (leave type-checking to + ufl.replace). + + Parameters + ---------- + old : :class:`Function` or :class:`TestFunction` or :class:`TrialFunction` + The old variable to be replaced. + new : :class:`Function` or :class:`TestFunction` or :class:`TrialFunction` + The new variable to be replace with. + old_idx : int + The index of the old variable to be replaced. If the old variable is not + indexable then this should be None. + new_idx : int + The index of the new variable to replace with. If the new variable is + not indexable then this should be None. + replace_type : str + A string to use in error messages, describing the type of replacement + that is happening. + + Returns + ------- + dict + A dictionary pairing the variables in the old term to be replaced with + the new variables to replace them. + + Raises + ------ + ValueError + If the old_idx argument is not provided when an indexable variable is to + be replaced by something not of the same shape. + ValueError + If the new_idx argument is not provided when an indexable variable is to + be replace something not of the same shape. """ mixed_old = type(old.ufl_element()) is MixedElement @@ -98,11 +129,15 @@ def replace_test_function(new_test, old_idx=None, new_idx=None): """ A routine to replace the test function in a term with a new test function. - Args: - new_test (:class:`TestFunction`): the new test function. + Parameters + ---------- + new_test : :class:`TestFunction` + The new test function. - Returns: - a function that takes in t, a :class:`Term`, and returns a new + Returns + ------- + func + A function that takes in t, a :class:`Term`, and returns a new :class:`Term` with form containing the new_test and labels=t.labels """ @@ -111,11 +146,15 @@ def repl(t): Replaces the test function in a term with a new expression. This is built around the ufl replace routine. - Args: - t (:class:`Term`): the original term. + Parameters + ---------- + t : :class:`Term` + The original term. - Returns: - :class:`Term`: the new term. + Returns + ------- + :class:`Term` + The new term. """ old_test = t.form.arguments()[0] replace_dict = _replace_dict(old_test, new_test, @@ -138,11 +177,15 @@ def replace_trial_function(new_trial, old_idx=None, new_idx=None): """ A routine to replace the trial function in a term with a new expression. - Args: - new (:class:`TrialFunction` or :class:`Function`): the new function. + Parameters + ---------- + new : :class:`TrialFunction` or :class:`Function` + The new function. - Returns: - a function that takes in t, a :class:`Term`, and returns a new + Returns + ------- + func + A function that takes in t, a :class:`Term`, and returns a new :class:`Term` with form containing the new_test and labels=t.labels """ @@ -151,14 +194,20 @@ def repl(t): Replaces the trial function in a term with a new expression. This is built around the ufl replace routine. - Args: - t (:class:`Term`): the original term. + Parameters + ---------- + t (:class:`Term`) + The original term. - Raises: - TypeError: if the form is linear. + Raises + ------ + TypeError + If the form is not linear. - Returns: - :class:`Term`: the new term. + Returns + ------- + :class:`Term` + The new term. """ if len(t.form.arguments()) != 2: raise TypeError('Trying to replace trial function of a form that is not linear') @@ -183,25 +232,34 @@ def replace_subject(new_subj, old_idx=None, new_idx=None): """ A routine to replace the subject in a term with a new variable. - Args: - new (:class:`ufl.Expr`): the new expression to replace the subject. - idx (int, optional): index of the subject in the equation's - :class:`MixedFunctionSpace`. Defaults to None. + Parameters + ---------- + new : :class:`ufl.Expr` + The new expression to replace the subject. + idx : int, optional + Index of the subject in the equation's :class:`MixedFunctionSpace`. + Defaults to None. """ def repl(t): """ Replaces the subject in a term with a new expression. This is built around the ufl replace routine. - Args: - t (:class:`Term`): the original term. - - Raises: - ValueError: when the new expression and subject are not of - compatible sizes (e.g. a mixed function vs a non-mixed function) - - Returns: - :class:`Term`: the new term. + Parameters + ---------- + t : :class:`Term` + The original term. + + Raises + ------ + ValueError + When the new expression and subject are not of compatible sizes + (e.g. a mixed function vs a non-mixed function) + + Returns + ------- + :class:`Term`: + The new term. """ old_subj = t.get(subject) diff --git a/firedrake/output.py b/firedrake/output.py index c21b9dd48e..f3bcac0d66 100644 --- a/firedrake/output.py +++ b/firedrake/output.py @@ -11,10 +11,14 @@ from firedrake.petsc import PETSc from firedrake.utils import IntType -from .paraview_reordering import vtk_lagrange_tet_reorder,\ - vtk_lagrange_hex_reorder, vtk_lagrange_interval_reorder,\ - vtk_lagrange_triangle_reorder, vtk_lagrange_quad_reorder,\ - vtk_lagrange_wedge_reorder +from .paraview_reordering import ( + vtk_lagrange_tet_reorder, + vtk_lagrange_hex_reorder, + vtk_lagrange_interval_reorder, + vtk_lagrange_triangle_reorder, + vtk_lagrange_quad_reorder, + vtk_lagrange_wedge_reorder, +) __all__ = ("File", ) diff --git a/tests/regression/test_fml.py b/tests/regression/test_fml.py new file mode 100644 index 0000000000..42a92c6d7d --- /dev/null +++ b/tests/regression/test_fml.py @@ -0,0 +1,121 @@ +""" +Tests a full workflow for the Form Manipulation Language (FML). + +This uses an IMEX discretisation of the linear shallow-water equations on a +mixed function space. +""" + +from firedrake import (PeriodicUnitSquareMesh, FunctionSpace, Constant, + MixedFunctionSpace, TestFunctions, Function, split, + inner, dx, SpatialCoordinate, as_vector, pi, sin, div, + NonlinearVariationalProblem, NonlinearVariationalSolver, + subject, replace_subject, keep, drop, Label) + + +def test_fml(): + + # Define labels for shallow-water + time_derivative = Label("time_derivative") + transport = Label("transport") + pressure_gradient = Label("pressure_gradient") + explicit = Label("explicit") + implicit = Label("implicit") + + # ------------------------------------------------------------------------ # + # Set up finite element objects + # ------------------------------------------------------------------------ # + + # Two shallow-water constants + H = Constant(10000.) + g = Constant(10.) + + # Set up mesh and function spaces + dt = Constant(0.01) + Nx = 5 + mesh = PeriodicUnitSquareMesh(Nx, Nx) + spaces = [FunctionSpace(mesh, "BDM", 1), FunctionSpace(mesh, "DG", 1)] + W = MixedFunctionSpace(spaces) + + # Set up fields on a mixed function space + w, phi = TestFunctions(W) + X = Function(W) + u0, h0 = split(X) + + # Set up time derivatives + mass_form = time_derivative(subject(inner(u0, w)*dx + subject(inner(h0, phi)*dx), X)) + + # Height field transport form + transport_form = transport(subject(H*phi*div(u0)*dx, X)) + + # Pressure gradient term -- integrate by parts once + pressure_gradient_form = pressure_gradient(subject(-g*div(w)*h0*dx, X)) + + # Define IMEX scheme. Transport term explicit and pressure gradient implict. + # This is not necessarily a sensible scheme -- it's just a simple demo for + # how FML can be used. + transport_form = explicit(transport_form) + pressure_gradient_form = implicit(pressure_gradient_form) + + # Add terms together to give whole residual + residual = mass_form + transport_form + pressure_gradient_form + + # ------------------------------------------------------------------------ # + # Initial condition + # ------------------------------------------------------------------------ # + + # Constant flow but sinusoidal height field + x, _ = SpatialCoordinate(mesh) + u0, h0 = X.subfunctions + u0.interpolate(as_vector([1.0, 0.0])) + h0.interpolate(H + 0.01*H*sin(2*pi*x)) + + # ------------------------------------------------------------------------ # + # Set up time discretisation + # ------------------------------------------------------------------------ # + + X_np1 = Function(W) + + # Here we would normally set up routines for the explicit and implicit parts + # but as this is just a test, we'll do just a single explicit/implicit step + + # Explicit: just forward euler + explicit_lhs = residual.label_map(lambda t: t.has_label(time_derivative), + map_if_true=replace_subject(X_np1), + map_if_false=drop) + + explicit_rhs = residual.label_map(lambda t: t.has_label(time_derivative) + or t.has_label(explicit), + map_if_true=keep, map_if_false=drop) + explicit_rhs = explicit_rhs.label_map(lambda t: t.has_label(time_derivative), + map_if_false=lambda t: -dt*t) + + # Implicit: just backward euler + implicit_lhs = residual.label_map(lambda t: t.has_label(time_derivative) + or t.has_label(implicit), + map_if_true=replace_subject(X_np1), + map_if_false=drop) + implicit_lhs = implicit_lhs.label_map(lambda t: t.has_label(time_derivative), + map_if_false=lambda t: dt*t) + + implicit_rhs = residual.label_map(lambda t: t.has_label(time_derivative), + map_if_false=drop) + + # ------------------------------------------------------------------------ # + # Set up and solve problems + # ------------------------------------------------------------------------ # + + explicit_residual = explicit_lhs - explicit_rhs + implicit_residual = implicit_lhs - implicit_rhs + + explicit_problem = NonlinearVariationalProblem(explicit_residual.form, X_np1) + explicit_solver = NonlinearVariationalSolver(explicit_problem) + + implicit_problem = NonlinearVariationalProblem(implicit_residual.form, X_np1) + implicit_solver = NonlinearVariationalSolver(implicit_problem) + + # Solve problems and update X_np1 + # In reality this would be within a time stepping loop! + explicit_solver.solve() + X.assign(X_np1) + implicit_solver.solve() + X.assign(X_np1) diff --git a/tests/unit/test_fml/test_label.py b/tests/unit/test_fml/test_label.py index 8b7fc7215c..bfae34c7bf 100644 --- a/tests/unit/test_fml/test_label.py +++ b/tests/unit/test_fml/test_label.py @@ -2,9 +2,8 @@ Tests FML's Label objects. """ -from firedrake import IntervalMesh, FunctionSpace, Function, TestFunction, dx -from gusto.configuration import TransportEquationType -from gusto.fml import Label, LabelledForm, Term +from firedrake import (IntervalMesh, FunctionSpace, Function, TestFunction, dx, + Label, LabelledForm, Term) from ufl import Form import pytest @@ -28,12 +27,6 @@ def label_and_values(label_type): bad_value = 10 new_value = 7 - elif label_type == "other": - # A label whose value is some other type - this_label = Label("foo", validator=lambda value: type(value) == TransportEquationType) - good_value = TransportEquationType.advective - new_value = TransportEquationType.conservative - elif label_type == "function": # A label whose value is an Function this_label = Label("foo", validator=lambda value: type(value) == Function) @@ -80,8 +73,7 @@ def object_to_label(object_type): raise ValueError(f'object_type {object_type} not implemented') -@pytest.mark.parametrize("label_type", ["boolean", "integer", - "other", "function"]) +@pytest.mark.parametrize("label_type", ["boolean", "integer", "function"]) @pytest.mark.parametrize("object_type", [LabelledForm, Term, Form, int]) def test_label(label_type, object_type, label_and_values, object_to_label): diff --git a/tests/unit/test_fml/test_label_map.py b/tests/unit/test_fml/test_label_map.py index 0fb31563f7..46aff8589c 100644 --- a/tests/unit/test_fml/test_label_map.py +++ b/tests/unit/test_fml/test_label_map.py @@ -2,8 +2,8 @@ Tests FML's LabelledForm label_map routine. """ -from firedrake import IntervalMesh, FunctionSpace, Function, TestFunction, dx -from gusto.fml import Label, Term, identity, drop, all_terms +from firedrake import (IntervalMesh, FunctionSpace, Function, TestFunction, dx, + Label, Term, identity, drop, all_terms) def test_label_map(): diff --git a/tests/unit/test_fml/test_labelled_form.py b/tests/unit/test_fml/test_labelled_form.py index a0176adb3b..16caa105b3 100644 --- a/tests/unit/test_fml/test_labelled_form.py +++ b/tests/unit/test_fml/test_labelled_form.py @@ -2,9 +2,8 @@ Tests FML's LabelledForm objects. """ -from firedrake import (IntervalMesh, FunctionSpace, Function, - TestFunction, dx, Constant) -from gusto.fml import Label, Term, LabelledForm +from firedrake import (IntervalMesh, FunctionSpace, Function, TestFunction, dx, + Constant, Label, Term, LabelledForm) from ufl import Form diff --git a/tests/unit/test_fml/test_replace_perp.py b/tests/unit/test_fml/test_replace_perp.py index 953f2f04ed..dcae613512 100644 --- a/tests/unit/test_fml/test_replace_perp.py +++ b/tests/unit/test_fml/test_replace_perp.py @@ -1,9 +1,9 @@ # The perp routine should come from UFL when it is fully implemented there -from gusto import perp -from gusto.fml import subject, replace_subject, all_terms +from ufl import perp from firedrake import (UnitSquareMesh, FunctionSpace, MixedFunctionSpace, TestFunctions, Function, split, inner, dx, errornorm, - SpatialCoordinate, as_vector, TrialFunctions, solve) + SpatialCoordinate, as_vector, TrialFunctions, solve, + subject, replace_subject, all_terms) def test_replace_perp(): @@ -28,7 +28,7 @@ def test_replace_perp(): # make a function to replace the subject with and give it some values U1 = Function(W) - u1, _ = U1.split() + u1, _ = U1.subfunctions x, y = SpatialCoordinate(mesh) u1.interpolate(as_vector([1, 2])) @@ -38,9 +38,9 @@ def test_replace_perp(): U2 = Function(W) solve(a == L.form, U2) - u2, _ = U2.split() + u2, _ = U2.subfunctions U3 = Function(W) - u3, _ = U3.split() + u3, _ = U3.subfunctions u3.interpolate(as_vector([-2, 1])) assert errornorm(u2, u3) < 1e-14 diff --git a/tests/unit/test_fml/test_replacement.py b/tests/unit/test_fml/test_replacement.py index e3a93cf7f1..314dea9057 100644 --- a/tests/unit/test_fml/test_replacement.py +++ b/tests/unit/test_fml/test_replacement.py @@ -1,14 +1,12 @@ """ -Tests the replace_subject routine from labels.py +Tests the different replacement routines from replacement.py """ from firedrake import (UnitSquareMesh, FunctionSpace, Function, TestFunction, TestFunctions, TrialFunction, TrialFunctions, - Argument, - VectorFunctionSpace, dx, inner, split, grad) -from gusto.fml import (Label, subject, replace_subject, - replace_test_function, replace_trial_function, - drop, all_terms) + Argument, VectorFunctionSpace, dx, inner, split, grad, + Label, subject, replace_subject, replace_test_function, + replace_trial_function, drop, all_terms) import pytest from collections import namedtuple @@ -127,7 +125,6 @@ def mixed_test_argsets(): ReplaceTestArgs(TestFunction(W), {'old_idx': 0, 'new_idx': 0}, None), ReplaceTestArgs(TestFunctions(W), {'old_idx': 0}, ValueError), ReplaceTestArgs(TestFunction(W), {'new_idx': 0}, ValueError), - # ReplaceTestArgs(TestFunctions(W), {'old_idx': 1, 'new_idx': 1}, None), ReplaceTestArgs(TestFunction(V0), {'old_idx': 0}, None), ReplaceTestArgs(TestFunctions(V0), {'new_idx': 1}, ValueError), ReplaceTestArgs(TestFunction(W), {'old_idx': 7, 'new_idx': 7}, IndexError) @@ -140,14 +137,12 @@ def mixed_trial_argsets(): ReplaceTrialArgs(TrialFunction(W), {}, None), ReplaceTrialArgs(TrialFunctions(W), {}, None), ReplaceTrialArgs(TrialFunction(W), {'old_idx': 0, 'new_idx': 0}, None), - # ReplaceTrialArgs(TrialFunctions(W), {'old_idx': 1, 'new_idx': 1}, None), ReplaceTrialArgs(TrialFunction(V0), {'old_idx': 0}, None), ReplaceTrialArgs(TrialFunctions(V0), {'new_idx': 1}, ValueError), ReplaceTrialArgs(TrialFunction(W), {'old_idx': 7, 'new_idx': 7}, IndexError), ReplaceTrialArgs(Function(W), {}, None), ReplaceTrialArgs(split(Function(W)), {}, None), ReplaceTrialArgs(Function(W), {'old_idx': 0, 'new_idx': 0}, None), - # ReplaceTrialArgs(Function(W), {'old_idx': 1, 'new_idx': 1}, None), ReplaceTrialArgs(Function(V0), {'old_idx': 0}, None), ReplaceTrialArgs(Function(V0), {'new_idx': 0}, ValueError), ReplaceTrialArgs(Function(W), {'old_idx': 7, 'new_idx': 7}, IndexError), @@ -189,8 +184,7 @@ def vector_test_argsets(): ReplaceTestArgs(TestFunction(Wv), {'new_idx': 0}, None), ReplaceTestArgs(TestFunction(Wv), {'new_idx': 1}, ValueError), ReplaceTestArgs(TestFunctions(Wv), {'new_idx': 0}, None), - # ReplaceTestArgs(TestFunction(W), {'new_idx': 0}, None), - # ReplaceTestArgs(TestFunction(W), {'new_idx': 7}, IndexError), + ReplaceTestArgs(TestFunction(W), {'new_idx': 7}, IndexError), ] return argsets diff --git a/tests/unit/test_fml/test_term.py b/tests/unit/test_fml/test_term.py index 403a7096a9..d889d663e9 100644 --- a/tests/unit/test_fml/test_term.py +++ b/tests/unit/test_fml/test_term.py @@ -2,9 +2,8 @@ Tests FML's Term objects. A term contains a form and labels. """ -from firedrake import (IntervalMesh, FunctionSpace, Function, - TestFunction, dx, Constant) -from gusto.fml import Label, Term, LabelledForm +from firedrake import (IntervalMesh, FunctionSpace, Function, TestFunction, dx, + Constant, Label, Term, LabelledForm) import pytest