diff --git a/firedrake/__init__.py b/firedrake/__init__.py index 6ac952e0ad..b1a4d94e11 100644 --- a/firedrake/__init__.py +++ b/firedrake/__init__.py @@ -105,6 +105,7 @@ from firedrake.ensemble import * from firedrake.randomfunctiongen import * from firedrake.progress_bar import ProgressBar # noqa: F401 +from firedrake.fml import * from firedrake.logging import * # Set default log level diff --git a/firedrake/fml/__init__.py b/firedrake/fml/__init__.py new file mode 100644 index 0000000000..d0f951305d --- /dev/null +++ b/firedrake/fml/__init__.py @@ -0,0 +1,2 @@ +from firedrake.fml.form_manipulation_language import * # noqa +from firedrake.fml.replacement import * # noqa diff --git a/firedrake/fml/form_manipulation_language.py b/firedrake/fml/form_manipulation_language.py new file mode 100644 index 0000000000..8eebc94e49 --- /dev/null +++ b/firedrake/fml/form_manipulation_language.py @@ -0,0 +1,595 @@ +"""A language for manipulating forms using labels.""" + +import ufl +import functools +import operator +from firedrake import Constant, Function +from typing import Any, Callable, Mapping, Optional, Sequence, Tuple, Union + + +__all__ = ["Label", "Term", "LabelledForm", "identity", "drop", "all_terms", + "keep", "subject", "name_label"] + +# ---------------------------------------------------------------------------- # +# Core routines for filtering terms +# ---------------------------------------------------------------------------- # + + +def identity(t: "Term") -> "Term": + """ The identity map. + + Parameters + ---------- + t + A term. + + Returns + ------- + Term + The same term. + + """ + return t + + +def drop(t: "Term") -> None: + """Map all terms to ``None``. + + Parameters + ---------- + t + A term. + + Returns + ------- + None + None. + + """ + return None + + +def keep(t: "Term") -> "Term": + """Keep all terms. + + Functionally equivalent to identity. + + Parameters + ---------- + t + A term. + + Returns + ------- + Term + The same term. + + """ + return t + + +def all_terms(t: "Term") -> bool: + """Map all terms to ``True``. + + Parameters + ---------- + t + A term. + + Returns + ------- + bool + True. + + """ + return True + + +# ---------------------------------------------------------------------------- # +# Term class +# ---------------------------------------------------------------------------- # +class Term(object): + """A Term object contains a form and its labels.""" + + __slots__ = ["form", "labels"] + + def __init__(self, form: ufl.Form, label_dict: Mapping = None): + """ + + Parameters + ---------- + form + The form for this terms. + label_dict + Dictionary of key-value pairs corresponding to current form labels. + Defaults to None. + + """ + self.form = form + self.labels = label_dict or {} + + def get(self, label: "Label") -> Any: + """Return the value of a label. + + Parameters + ---------- + label + The label to return the value of. + + Returns + ------- + Any + The value of a label. + + """ + return self.labels.get(label.label) + + def has_label( + self, + *labels: "Label", + return_tuple: bool = False + ) -> Union[Tuple[bool], bool]: + """Return whether the specified labels are attached to this term. + + Parameters + ---------- + *labels + A label or series of labels. A tuple is automatically returned if + multiple labels are provided as arguments. + return_tuple + If True, forces a tuple to be returned even if only one label is + provided as an argument. Defaults to False. + + Returns + ------- + bool + Booleans corresponding to whether the term has the specified labels. + + """ + if len(labels) == 1 and not return_tuple: + return labels[0].label in self.labels + else: + return tuple(self.has_label(l) for l in labels) + + def __add__(self, other: Union["Term", "LabelledForm"]) -> "LabelledForm": + """Add a term or labelled form to this term. + + Parameters + ---------- + other + The term or labelled form to add to this term. + + Returns + ------- + LabelledForm + A labelled form containing the terms. + + """ + if self is NullTerm: + return other + if other is None or other is NullTerm: + return self + elif isinstance(other, Term): + return LabelledForm(self, other) + elif isinstance(other, LabelledForm): + return LabelledForm(self, *other.terms) + else: + return NotImplemented + + __radd__ = __add__ + + def __sub__(self, other: Union["Term", "LabelledForm"]) -> "LabelledForm": + """Subtract a term or labelled form from this term. + + Parameters + ---------- + other + The term or labelled form to subtract from this term. + + Returns + ------- + LabelledForm + A labelled form containing the terms. + + """ + other = other * Constant(-1.0) + return self + other + + def __mul__( + self, + other: Union[float, Constant, ufl.algebra.Product] + ) -> "Term": + """Multiply this term by another quantity. + + Parameters + ---------- + other + The quantity to multiply this term by. + + Returns + ------- + Term + The product of the term with the quantity. + + """ + return Term(other*self.form, self.labels) + + __rmul__ = __mul__ + + def __truediv__( + self, + other: Union[float, Constant, ufl.algebra.Product] + ) -> "Term": + """Divide this term by another quantity. + + Parameters + ---------- + other + The quantity to divide this term by. + + Returns + ------- + Term + The quotient of the term divided by the quantity. + + """ + return self * (Constant(1.0) / other) + + +# This is necessary to be the initialiser for functools.reduce +NullTerm = Term(None) + + +# ---------------------------------------------------------------------------- # +# Labelled form class +# ---------------------------------------------------------------------------- # +class LabelledForm(object): + """ + A form, broken down into terms that pair individual forms with labels. + + The LabelledForm object holds a list of terms, which pair + :class:`ufl.Form` objects with :class:`Label` s. The label_map + routine allows the terms to be manipulated or selected based on particular + filters. + """ + __slots__ = ["terms"] + + def __init__(self, *terms: Sequence[Term]): + """ + Parameters + ---------- + *terms : Term + Terms to combine to make the LabelledForm. + + Raises + ------ + TypeError: If any argument is not a term. + """ + if len(terms) == 1 and isinstance(terms[0], LabelledForm): + self.terms = terms[0].terms + else: + if any([type(term) is not Term for term in list(terms)]): + raise TypeError('Can only pass terms or a LabelledForm to LabelledForm') + self.terms = list(terms) + + def __add__( + self, + other: Union[ufl.Form, Term, "LabelledForm"] + ) -> "LabelledForm": + """Add a form, term or labelled form to this labelled form. + + Parameters + ---------- + other + The form, term or labelled form to add to this labelled form. + + Returns + ------- + LabelledForm + A labelled form containing the terms. + + """ + if isinstance(other, ufl.Form): + return LabelledForm(*self, Term(other)) + elif type(other) is Term: + return LabelledForm(*self, other) + elif type(other) is LabelledForm: + return LabelledForm(*self, *other) + elif other is None: + return self + else: + return NotImplemented + + __radd__ = __add__ + + def __sub__( + self, + other: Union[ufl.Form, Term, "LabelledForm"] + ) -> "LabelledForm": + """Subtract a form, term or labelled form from this labelled form. + + Parameters + ---------- + other + The form, term or labelled form to subtract from this labelled form. + + Returns + ------- + LabelledForm + A labelled form containing the terms. + + """ + if type(other) is Term: + return LabelledForm(*self, Constant(-1.)*other) + elif type(other) is LabelledForm: + return LabelledForm(*self, *[Constant(-1.)*t for t in other]) + elif other is None: + return self + else: + # Make new Term for other and subtract it + return LabelledForm(*self, Term(Constant(-1.)*other)) + + def __mul__( + self, + other: Union[float, Constant, ufl.algebra.Product] + ) -> "LabelledForm": + """Multiply this labelled form by another quantity. + + Parameters + ---------- + other + The quantity to multiply this labelled form by. All terms in the + form are multiplied. + + Returns + ------- + LabelledForm + The product of all terms with the quantity. + + """ + return self.label_map(all_terms, lambda t: Term(other*t.form, t.labels)) + + def __truediv__( + self, + other: Union[float, Constant, ufl.algebra.Product] + ) -> "LabelledForm": + """Divide this labelled form by another quantity. + + Parameters + ---------- + other + The quantity to divide this labelled form by. All terms in the form + are divided. + + Returns + ------- + LabelledForm + The quotient of all terms with the quantity. + + """ + return self * (Constant(1.0) / other) + + __rmul__ = __mul__ + + def __iter__(self) -> Sequence: + """Iterable of the terms in the labelled form.""" + return iter(self.terms) + + def __len__(self) -> int: + """Number of terms in the labelled form.""" + return len(self.terms) + + def label_map( + self, + term_filter: Callable[[Term], bool], + map_if_true: Callable[[Term], Optional[Term]] = identity, + map_if_false: Callable[[Term], Optional[Term]] = identity + ) -> "LabelledForm": + """Map selected terms in the labelled form, returning a new labelled form. + + Parameters + ---------- + term_filter + A function to filter the labelled form's terms. + map_if_true + How to map the terms for which the term_filter returns True. + Defaults to identity. + map_if_false + How to map the terms for which the term_filter returns False. + Defaults to identity. + + Returns + ------- + LabelledForm + A new labelled form with the terms mapped. + + """ + # FIXME: The rendered docstring for this method is a mess, the lambda + # hackery at the top goes some way to fix this, but this is probably a + # bug in napoleon. + + new_labelled_form = LabelledForm( + functools.reduce(operator.add, + filter(lambda t: t is not None, + (map_if_true(t) if term_filter(t) else + map_if_false(t) for t in self.terms)), + # Need to set an initialiser, otherwise the label_map + # won't work if the term_filter is False for everything + # None does not work, as then we add Terms to None + # and the addition operation is defined from None + # rather than the Term. NullTerm solves this. + NullTerm)) + + # Drop the NullTerm + new_labelled_form.terms = list(filter(lambda t: t is not NullTerm, + new_labelled_form.terms)) + + return new_labelled_form + + @property + def form(self) -> ufl.Form: + """Provide the whole form from the labelled form. + + Raises + ------ + TypeError + If the labelled form has no terms. + + Returns + ------- + ufl.Form + The whole form corresponding to all the terms. + + """ + # Throw an error if there is no form + if len(self.terms) == 0: + raise TypeError('The labelled form cannot return a form as it has no terms') + else: + return functools.reduce(operator.add, (t.form for t in self.terms)) + + +class Label(object): + """Object for tagging forms, allowing them to be manipulated.""" + + __slots__ = ["label", "default_value", "value", "validator"] + + def __init__( + self, + label, + *, + value: Any = True, + validator: Optional[Callable] = None + ): + """ + Parameters + ---------- + label + The name of the label. + value + The value for the label to take. Can be any type (subject to the + validator). Defaults to True. + validator + Function to check the validity of any value later passed to the + label. Defaults to None. + + """ + self.label = label + self.default_value = value + self.validator = validator + + def __call__( + self, + target: Union[ufl.Form, Term, LabelledForm], + value: Any = None + ) -> Union[Term, LabelledForm]: + """Apply the label to a form or term. + + Parameters + ---------- + target + The form, term or labelled form to be labelled. + value + The value to attach to this label. Defaults to None. + + Raises + ------ + ValueError + If the `target` is not a ufl.Form, Term or + LabelledForm. + + Returns + ------- + Union[Term, LabelledForm] + A Term is returned if the target is a Term, + otherwise a LabelledForm is returned. + + """ + # if value is provided, check that we have a validator function + # and validate the value, otherwise use default value + if value is not None: + assert self.validator, f'Label {self.label} requires a validator' + assert self.validator(value), f'Value {value} for label {self.label} does not satisfy validator' + self.value = value + else: + self.value = self.default_value + if isinstance(target, LabelledForm): + return LabelledForm(*(self(t, value) for t in target.terms)) + elif isinstance(target, ufl.Form): + return LabelledForm(Term(target, {self.label: self.value})) + elif isinstance(target, Term): + new_labels = target.labels.copy() + new_labels.update({self.label: self.value}) + return Term(target.form, new_labels) + else: + raise ValueError("Unable to label %s" % target) + + def remove(self, target: Union[Term, LabelledForm]): + """Remove a label from a term or labelled form. + + This removes any Label with this ``label`` from + ``target``. If called on an LabelledForm, it acts term-wise. + + Parameters + ---------- + target + Term or labelled form to have this label removed from. + + Raises + ------ + ValueError + If the `target` is not a Term or a LabelledForm. + + """ + + if isinstance(target, LabelledForm): + return LabelledForm(*(self.remove(t) for t in target.terms)) + elif isinstance(target, Term): + try: + d = target.labels.copy() + d.pop(self.label) + return Term(target.form, d) + except KeyError: + return target + else: + raise ValueError("Unable to unlabel %s" % target) + + def update_value(self, target: Union[Term, LabelledForm], new: Any): + """Update the label of a term or labelled form. + + This updates the value of any Label with this ``label`` from + ``target``. If called on an LabelledForm, it acts term-wise. + + Parameters + ---------- + target + Term or labelled form to have this label updated. + new + The new value for this label to take. The type is subject to the + label's validator (if it has one). + + Raises + ------ + ValueError + If the `target` is not a Term or a LabelledForm. + + """ + + if isinstance(target, LabelledForm): + return LabelledForm(*(self.update_value(t, new) for t in target.terms)) + elif isinstance(target, Term): + try: + d = target.labels.copy() + d[self.label] = new + return Term(target.form, d) + except KeyError: + return target + else: + raise ValueError("Unable to relabel %s" % target) + + +# ---------------------------------------------------------------------------- # +# Some common labels +# ---------------------------------------------------------------------------- # + +subject = Label("subject", validator=lambda value: type(value) == Function) +name_label = Label("name", validator=lambda value: type(value) == str) diff --git a/firedrake/fml/replacement.py b/firedrake/fml/replacement.py new file mode 100644 index 0000000000..08610fddd6 --- /dev/null +++ b/firedrake/fml/replacement.py @@ -0,0 +1,327 @@ +""" +Generic routines for replacing functions using FML. +""" + +import ufl +from .form_manipulation_language import Term, subject +from firedrake import split, MixedElement, Function, Argument +from typing import Callable, Optional, Union + +__all__ = ["replace_test_function", "replace_trial_function", "replace_subject"] + + +# ---------------------------------------------------------------------------- # +# A general routine for building the replacement dictionary +# ---------------------------------------------------------------------------- # +def _replace_dict( + old: Union[Function, Argument], + new: Union[Function, Argument], + old_idx: Optional[int], + new_idx: Optional[int], + replace_type: str +) -> dict: + """Build a dictionary to pass to the ufl.replace routine. + + The dictionary matches variables in the old term with those in the new. + + Does not check types unless indexing is required (leave type-checking to + ufl.replace). + + Parameters + ---------- + old + The old variable to be replaced. + (Function or TestFunction or TrialFunction) + new + The new variable to be replace with. + (Function or TestFunction or TrialFunction) + old_idx + The index of the old variable to be replaced. If the old variable is not + indexable then this should be None. + new_idx + The index of the new variable to replace with. If the new variable is + not indexable then this should be None. + replace_type + A string to use in error messages, describing the type of replacement + that is happening. + + Returns + ------- + dict + A dictionary pairing the variables in the old term to be replaced with + the new variables to replace them. + + Raises + ------ + ValueError + If the old_idx argument is not provided when an indexable variable is to + be replaced by something not of the same shape. + ValueError + If the new_idx argument is not provided when an indexable variable is to + be replace something not of the same shape. + + """ + + mixed_old = type(old.ufl_element()) is MixedElement + mixed_new = hasattr(new, "ufl_element") and type(new.ufl_element()) is MixedElement + + indexable_old = mixed_old + indexable_new = mixed_new or type(new) is tuple + + if mixed_old: + split_old = split(old) + if indexable_new: + split_new = new if type(new) is tuple else split(new) + + # check indices arguments are valid + if not indexable_old and old_idx is not None: + raise ValueError(f"old_idx should not be specified to replace_{replace_type}" + + f" when replaced {replace_type} of type {old} is not mixed.") + + if not indexable_new and new_idx is not None: + raise ValueError(f"new_idx should not be specified to replace_{replace_type} when" + + f" new {replace_type} of type {new} is not mixed or indexable.") + + if indexable_old and not indexable_new: + if old_idx is None: + raise ValueError(f"old_idx must be specified to replace_{replace_type} when replaced" + + f" {replace_type} of type {old} is mixed and new {replace_type}" + + f" of type {new} is not mixed or indexable.") + + if indexable_new and not indexable_old: + if new_idx is None: + raise ValueError(f"new_idx must be specified to replace_{replace_type} when new" + + f" {replace_type} of type {new} is mixed or indexable and" + + f" old {replace_type} of type {old} is not mixed.") + + if indexable_old and indexable_new: + # must be both True or both False + if (old_idx is None) ^ (new_idx is None): + raise ValueError("both or neither old_idx and new_idx must be specified to" + + f" replace_{replace_type} when old {replace_type} of type" + + f" {old} is mixed and new {replace_type} of type {new} is" + + " mixed or indexable.") + if old_idx is None: # both indexes are none + if len(split_old) != len(split_new): + raise ValueError(f"if neither index is specified to replace_{replace_type}" + + f" and both old {replace_type} of type {old} and new" + + f" {replace_type} of type {new} are mixed or indexable" + + f" then old of length {len(split_old)} and new of length {len(split_new)}" + + " must be the same length.") + + # make the replace_dict + + replace_dict = {} + + if not indexable_old and not indexable_new: + replace_dict[old] = new + + elif not indexable_old and indexable_new: + replace_dict[old] = split_new[new_idx] + + elif indexable_old and not indexable_new: + replace_dict[split_old[old_idx]] = new + + elif indexable_old and indexable_new: + if old_idx is None: # replace everything + for k, v in zip(split_old, split_new): + replace_dict[k] = v + else: # idxs are given + replace_dict[split_old[old_idx]] = split_new[new_idx] + + return replace_dict + + +# ---------------------------------------------------------------------------- # +# Replacement routines +# ---------------------------------------------------------------------------- # +def replace_test_function( + new_test: Argument, + old_idx: Optional[int] = None, + new_idx: Optional[int] = None +) -> Callable[[Term], Term]: + """Replace the test function in a term with a new test function. + + Parameters + ---------- + new_test + The new test function. + old_idx + The index of the old TestFunction to be replaced. If the old + variable is not indexable then this should be None. + new_idx + The index of the new TestFunction to replace with. If the new + variable is not indexable then this should be None. + + Returns + ------- + Callable + A function that takes in t, a .Term, and returns a new + .Term with form containing the ``new_test`` and + ``labels=t.labels`` + + """ + + def repl(t: Term) -> Term: + """Replace the test function in a term with a new expression. + + This is built around the UFL replace routine. + + Parameters + ---------- + t + The original term. + + Returns + ------- + Term + The new term. + + """ + old_test = t.form.arguments()[0] + replace_dict = _replace_dict(old_test, new_test, + old_idx=old_idx, new_idx=new_idx, + replace_type='test') + + try: + new_form = ufl.replace(t.form, replace_dict) + except Exception as err: + error_message = f"{type(err)} raised by ufl.replace when trying to" \ + + f" replace_test_function with {new_test}" + raise type(err)(error_message) from err + + return Term(new_form, t.labels) + + return repl + + +def replace_trial_function( + new_trial: Union[Argument, Function], + old_idx: Optional[int] = None, + new_idx: Optional[int] = None +) -> Callable[[Term], Term]: + """Replace the trial function in a term with a new expression. + + Parameters + ---------- + new + The new function. + old_idx + The index of the old Function or TrialFunction to be replaced. + If the old variable is not indexable then this should be None. + new_idx + The index of the new Function or TrialFunction to replace with. + If the new variable is not indexable then this should be None. + + Returns + ------- + Callable + A function that takes in t, a Term, and returns a new + Term with form containing the ``new_test`` and + ``labels=t.labels`` + + """ + + def repl(t: Term) -> Term: + """Replace the trial function in a term with a new expression. + + This is built around the UFL replace routine. + + Parameters + ---------- + t + The original term. + + Raises + ------ + TypeError + If the form is not linear. + + Returns + ------- + Term + The new term. + + """ + if len(t.form.arguments()) != 2: + raise TypeError('Trying to replace trial function of a form that is not linear') + old_trial = t.form.arguments()[1] + replace_dict = _replace_dict(old_trial, new_trial, + old_idx=old_idx, new_idx=new_idx, + replace_type='trial') + + try: + new_form = ufl.replace(t.form, replace_dict) + except Exception as err: + error_message = f"{type(err)} raised by ufl.replace when trying to" \ + + f" replace_trial_function with {new_trial}" + raise type(err)(error_message) from err + + return Term(new_form, t.labels) + + return repl + + +def replace_subject( + new_subj: ufl.core.expr.Expr, + old_idx: Optional[int] = None, + new_idx: Optional[int] = None +) -> Callable[[Term], Term]: + """Replace the subject in a term with a new variable. + + Parameters + ---------- + new + The new expression to replace the subject. + old_idx + The index of the old subject to be replaced. If the old + variable is not indexable then this should be None. + new_idx + The index of the new subject to replace with. If the new + variable is not indexable then this should be None. + + Returns + ------- + Callable + A function that takes in t, a Term, and returns a new Term with + form containing the ``new_test`` and ``labels=t.labels`` + + """ + def repl(t: Term) -> Term: + """Replace the subject in a term with a new expression. + + This is built around the UFL replace routine. + + Parameters + ---------- + t + The original term. + + Raises + ------ + ValueError + When the new expression and subject are not of compatible sizes + (e.g. a mixed function vs a non-mixed function) + + Returns + ------- + Term + The new term. + + """ + + old_subj = t.get(subject) + replace_dict = _replace_dict(old_subj, new_subj, + old_idx=old_idx, new_idx=new_idx, + replace_type='subject') + + try: + new_form = ufl.replace(t.form, replace_dict) + except Exception as err: + error_message = f"{type(err)} raised by ufl.replace when trying to" \ + + f" replace_subject with {new_subj}" + raise type(err)(error_message) from err + + return Term(new_form, t.labels) + + return repl diff --git a/tests/regression/test_fml.py b/tests/regression/test_fml.py new file mode 100644 index 0000000000..f177e1da24 --- /dev/null +++ b/tests/regression/test_fml.py @@ -0,0 +1,121 @@ +""" +Tests a full workflow for the Form Manipulation Language (FML). + +This uses an IMEX discretisation of the linear shallow-water equations on a +mixed function space. +""" + +from firedrake import (PeriodicUnitSquareMesh, FunctionSpace, Constant, + MixedFunctionSpace, TestFunctions, Function, split, + inner, dx, SpatialCoordinate, as_vector, pi, sin, div, + NonlinearVariationalProblem, NonlinearVariationalSolver, + subject, replace_subject, keep, drop, Label) + + +def test_fml(): + + # Define labels for shallow-water + time_derivative = Label("time_derivative") + transport = Label("transport") + pressure_gradient = Label("pressure_gradient") + explicit = Label("explicit") + implicit = Label("implicit") + + # ------------------------------------------------------------------------ # + # Set up finite element objects + # ------------------------------------------------------------------------ # + + # Two shallow-water constants + H = Constant(10000.) + g = Constant(10.) + + # Set up mesh and function spaces + dt = Constant(0.01) + Nx = 5 + mesh = PeriodicUnitSquareMesh(Nx, Nx) + spaces = [FunctionSpace(mesh, "BDM", 1), FunctionSpace(mesh, "DG", 1)] + W = MixedFunctionSpace(spaces) + + # Set up fields on a mixed function space + w, phi = TestFunctions(W) + X = Function(W) + u0, h0 = split(X) + + # Set up time derivatives + mass_form = time_derivative(subject(inner(u0, w)*dx + subject(inner(h0, phi)*dx), X)) + + # Height field transport form + transport_form = transport(subject(H*inner(div(u0), phi)*dx, X)) + + # Pressure gradient term -- integrate by parts once + pressure_gradient_form = pressure_gradient(subject(-g*inner(h0, div(w))*dx, X)) + + # Define IMEX scheme. Transport term explicit and pressure gradient implict. + # This is not necessarily a sensible scheme -- it's just a simple demo for + # how FML can be used. + transport_form = explicit(transport_form) + pressure_gradient_form = implicit(pressure_gradient_form) + + # Add terms together to give whole residual + residual = mass_form + transport_form + pressure_gradient_form + + # ------------------------------------------------------------------------ # + # Initial condition + # ------------------------------------------------------------------------ # + + # Constant flow but sinusoidal height field + x, _ = SpatialCoordinate(mesh) + u0, h0 = X.subfunctions + u0.interpolate(as_vector([1.0, 0.0])) + h0.interpolate(H + 0.01*H*sin(2*pi*x)) + + # ------------------------------------------------------------------------ # + # Set up time discretisation + # ------------------------------------------------------------------------ # + + X_np1 = Function(W) + + # Here we would normally set up routines for the explicit and implicit parts + # but as this is just a test, we'll do just a single explicit/implicit step + + # Explicit: just forward euler + explicit_lhs = residual.label_map(lambda t: t.has_label(time_derivative), + map_if_true=replace_subject(X_np1), + map_if_false=drop) + + explicit_rhs = residual.label_map(lambda t: t.has_label(time_derivative) + or t.has_label(explicit), + map_if_true=keep, map_if_false=drop) + explicit_rhs = explicit_rhs.label_map(lambda t: t.has_label(time_derivative), + map_if_false=lambda t: -dt*t) + + # Implicit: just backward euler + implicit_lhs = residual.label_map(lambda t: t.has_label(time_derivative) + or t.has_label(implicit), + map_if_true=replace_subject(X_np1), + map_if_false=drop) + implicit_lhs = implicit_lhs.label_map(lambda t: t.has_label(time_derivative), + map_if_false=lambda t: dt*t) + + implicit_rhs = residual.label_map(lambda t: t.has_label(time_derivative), + map_if_false=drop) + + # ------------------------------------------------------------------------ # + # Set up and solve problems + # ------------------------------------------------------------------------ # + + explicit_residual = explicit_lhs - explicit_rhs + implicit_residual = implicit_lhs - implicit_rhs + + explicit_problem = NonlinearVariationalProblem(explicit_residual.form, X_np1) + explicit_solver = NonlinearVariationalSolver(explicit_problem) + + implicit_problem = NonlinearVariationalProblem(implicit_residual.form, X_np1) + implicit_solver = NonlinearVariationalSolver(implicit_problem) + + # Solve problems and update X_np1 + # In reality this would be within a time stepping loop! + explicit_solver.solve() + X.assign(X_np1) + implicit_solver.solve() + X.assign(X_np1) diff --git a/tests/unit/test_fml/test_label.py b/tests/unit/test_fml/test_label.py new file mode 100644 index 0000000000..bfae34c7bf --- /dev/null +++ b/tests/unit/test_fml/test_label.py @@ -0,0 +1,182 @@ +""" +Tests FML's Label objects. +""" + +from firedrake import (IntervalMesh, FunctionSpace, Function, TestFunction, dx, + Label, LabelledForm, Term) +from ufl import Form +import pytest + + +@pytest.fixture +def label_and_values(label_type): + # Returns labels with different value validation + + bad_value = "bar" + + if label_type == "boolean": + # A label that is simply a string, whose value is Boolean + this_label = Label("foo") + good_value = True + new_value = False + + elif label_type == "integer": + # A label whose value is an integer + this_label = Label("foo", validator=lambda value: (type(value) == int and value < 9)) + good_value = 5 + bad_value = 10 + new_value = 7 + + elif label_type == "function": + # A label whose value is an Function + this_label = Label("foo", validator=lambda value: type(value) == Function) + good_value, _ = setup_form() + new_value = Function(good_value.function_space()) + + return this_label, good_value, bad_value, new_value + + +def setup_form(): + # Create mesh and function space + L = 3.0 + n = 3 + mesh = IntervalMesh(n, L) + V = FunctionSpace(mesh, "DG", 0) + f = Function(V) + g = TestFunction(V) + form = f*g*dx + + return f, form + + +@pytest.fixture +def object_to_label(object_type): + # A series of different objects to be labelled + + if object_type == int: + return 10 + + else: + _, form = setup_form() + term = Term(form) + + if object_type == Form: + return form + + elif object_type == Term: + return term + + elif object_type == LabelledForm: + return LabelledForm(term) + + else: + raise ValueError(f'object_type {object_type} not implemented') + + +@pytest.mark.parametrize("label_type", ["boolean", "integer", "function"]) +@pytest.mark.parametrize("object_type", [LabelledForm, Term, Form, int]) +def test_label(label_type, object_type, label_and_values, object_to_label): + + label, good_value, bad_value, new_value = label_and_values + + # ------------------------------------------------------------------------ # + # Check label has correct name + # ------------------------------------------------------------------------ # + + assert label.label == "foo", "Label has incorrect name" + + # ------------------------------------------------------------------------ # + # Check we can't label unsupported objects + # ------------------------------------------------------------------------ # + + if object_type == int: + # Can't label integers, so check this fails and force end + try: + labelled_object = label(object_to_label) + except ValueError: + # Appropriate error has been returned so end the test + return + + # If we get here there has been an error + assert False, "Labelling an integer should throw an error" + + # ------------------------------------------------------------------------ # + # Test application of labels + # ------------------------------------------------------------------------ # + + if label_type == "boolean": + labelled_object = label(object_to_label) + + else: + # Check that passing an inappropriate label gives the correct error + try: + labelled_object = label(object_to_label, bad_value) + # If we get here the validator has not worked + assert False, 'The labelling validator has not worked for ' \ + + f'label_type {label_type} and object_type {object_type}' + + except AssertionError: + # Now label object properly + labelled_object = label(object_to_label, good_value) + + # ------------------------------------------------------------------------ # + # Check labelled form or term has been returned + # ------------------------------------------------------------------------ # + + if object_type == Term: + assert type(labelled_object) == Term, 'Labelled Term should be a ' \ + + f'be a Term and not type {type(labelled_object)}' + else: + assert type(labelled_object) == LabelledForm, 'Labelled Form should ' \ + + f'be a Labelled Form and not type {type(labelled_object)}' + + # ------------------------------------------------------------------------ # + # Test that the values are correct + # ------------------------------------------------------------------------ # + + if object_type == Term: + assert labelled_object.get(label) == good_value, 'Value of label ' \ + + f'should be {good_value} and not {labelled_object.get(label)}' + else: + assert labelled_object.terms[0].get(label) == good_value, 'Value of ' \ + + f'label should be {good_value} and not ' \ + + f'{labelled_object.terms[0].get(label)}' + + # ------------------------------------------------------------------------ # + # Test updating of values + # ------------------------------------------------------------------------ # + + # Check that passing an inappropriate label gives the correct error + try: + labelled_object = label.update_value(labelled_object, bad_value) + # If we get here the validator has not worked + assert False, 'The validator has not worked for updating label of ' \ + + f'label_type {label_type} and object_type {object_type}' + except AssertionError: + # Update new value + labelled_object = label.update_value(labelled_object, new_value) + + # Check that new value is correct + if object_type == Term: + assert labelled_object.get(label) == new_value, 'Updated value of ' \ + + f'label should be {new_value} and not {labelled_object.get(label)}' + else: + assert labelled_object.terms[0].get(label) == new_value, 'Updated ' \ + + f'value of label should be {new_value} and not ' \ + + f'{labelled_object.terms[0].get(label)}' + + # ------------------------------------------------------------------------ # + # Test removal of values + # ------------------------------------------------------------------------ # + + labelled_object = label.remove(labelled_object) + + # Try to see if object still has that label + if object_type == Term: + label_value = labelled_object.get(label) + else: + label_value = labelled_object.terms[0].get(label) + + # If we get here then the label has been extracted but it shouldn't have + assert label_value is None, f'The label {label_type} appears has not to ' \ + + f'have been removed for object_type {object_type}' diff --git a/tests/unit/test_fml/test_label_map.py b/tests/unit/test_fml/test_label_map.py new file mode 100644 index 0000000000..46aff8589c --- /dev/null +++ b/tests/unit/test_fml/test_label_map.py @@ -0,0 +1,74 @@ +""" +Tests FML's LabelledForm label_map routine. +""" + +from firedrake import (IntervalMesh, FunctionSpace, Function, TestFunction, dx, + Label, Term, identity, drop, all_terms) + + +def test_label_map(): + + # ------------------------------------------------------------------------ # + # Set up labelled forms + # ------------------------------------------------------------------------ # + + # Some basic labels + foo_label = Label("foo") + bar_label = Label("bar", validator=lambda value: type(value) == int) + + # Create mesh, function space and forms + L = 3.0 + n = 3 + mesh = IntervalMesh(n, L) + V = FunctionSpace(mesh, "DG", 0) + f = Function(V) + g = Function(V) + test = TestFunction(V) + form_1 = f*test*dx + form_2 = g*test*dx + term_1 = foo_label(Term(form_1)) + term_2 = bar_label(Term(form_2), 5) + + labelled_form = term_1 + term_2 + + # ------------------------------------------------------------------------ # + # Test all_terms + # ------------------------------------------------------------------------ # + + # Passing all_terms should return the same labelled form + new_labelled_form = labelled_form.label_map(all_terms) + assert len(new_labelled_form) == len(labelled_form), \ + 'new_labelled_form should be the same as labelled_form' + for new_term, term in zip(new_labelled_form.terms, labelled_form.terms): + assert new_term == term, 'terms in new_labelled_form should be the ' + \ + 'same as those in labelled_form' + + # ------------------------------------------------------------------------ # + # Test identity and drop + # ------------------------------------------------------------------------ # + + # Get just the first term, which has the foo label + new_labelled_form = labelled_form.label_map( + lambda t: t.has_label(foo_label), map_if_true=identity, map_if_false=drop + ) + assert len(new_labelled_form) == 1, 'new_labelled_form should be length 1' + for new_term in new_labelled_form.terms: + assert new_term.has_label(foo_label), 'All terms in ' + \ + 'new_labelled_form should have foo_label' + + # Give term_1 the bar label + new_labelled_form = labelled_form.label_map( + lambda t: t.has_label(bar_label), map_if_true=identity, + map_if_false=lambda t: bar_label(t, 0) + ) + assert len(new_labelled_form) == 2, 'new_labelled_form should be length 2' + for new_term in new_labelled_form.terms: + assert new_term.has_label(bar_label), 'All terms in ' + \ + 'new_labelled_form should have bar_label' + + # Test with a more complex filter, which should give an empty labelled_form + new_labelled_form = labelled_form.label_map( + lambda t: (t.has_label(bar_label) and t.get(bar_label) > 10), + map_if_true=identity, map_if_false=drop + ) + assert len(new_labelled_form) == 0, 'new_labelled_form should be length 0' diff --git a/tests/unit/test_fml/test_labelled_form.py b/tests/unit/test_fml/test_labelled_form.py new file mode 100644 index 0000000000..16caa105b3 --- /dev/null +++ b/tests/unit/test_fml/test_labelled_form.py @@ -0,0 +1,133 @@ +""" +Tests FML's LabelledForm objects. +""" + +from firedrake import (IntervalMesh, FunctionSpace, Function, TestFunction, dx, + Constant, Label, Term, LabelledForm) +from ufl import Form + + +def test_labelled_form(): + + # ------------------------------------------------------------------------ # + # Set up labelled forms + # ------------------------------------------------------------------------ # + + # Some basic labels + lorem_label = Label("lorem", validator=lambda value: type(value) == str) + ipsum_label = Label("ipsum", validator=lambda value: type(value) == int) + + # Create mesh, function space and forms + L = 3.0 + n = 3 + mesh = IntervalMesh(n, L) + V = FunctionSpace(mesh, "DG", 0) + f = Function(V) + g = Function(V) + test = TestFunction(V) + form_1 = f*test*dx + form_2 = g*test*dx + term_1 = lorem_label(Term(form_1), 'i_have_lorem') + term_2 = ipsum_label(Term(form_2), 5) + + # ------------------------------------------------------------------------ # + # Test labelled forms have the correct number of terms + # ------------------------------------------------------------------------ # + + # Create from a single term + labelled_form_1 = LabelledForm(term_1) + assert len(labelled_form_1) == 1, 'LabelledForm should have 1 term' + + # Create from multiple terms + labelled_form_2 = LabelledForm(*[term_1, term_2]) + assert len(labelled_form_2) == 2, 'LabelledForm should have 2 terms' + + # Trying to create from two LabelledForms should give an error + try: + labelled_form_3 = LabelledForm(labelled_form_1, labelled_form_2) + # If we get here something has gone wrong + assert False, 'We should not be able to create LabelledForm ' + \ + 'from two LabelledForms' + except TypeError: + pass + + # Create from a single LabelledForm + labelled_form_3 = LabelledForm(labelled_form_1) + assert len(labelled_form_3) == 1, 'LabelledForm should have 1 term' + + # ------------------------------------------------------------------------ # + # Test getting form + # ------------------------------------------------------------------------ # + + assert type(labelled_form_1.form) is Form, 'The form belonging to the ' + \ + f'LabelledForm must be a Form, and not {type(labelled_form_1.form)}' + + assert type(labelled_form_2.form) is Form, 'The form belonging to the ' + \ + f'LabelledForm must be a Form, and not {type(labelled_form_2.form)}' + + assert type(labelled_form_3.form) is Form, 'The form belonging to the ' + \ + f'LabelledForm must be a Form, and not {type(labelled_form_3.form)}' + + # ------------------------------------------------------------------------ # + # Test addition and subtraction of labelled forms + # ------------------------------------------------------------------------ # + + # Add a Form to a LabelledForm + new_labelled_form = labelled_form_1 + form_2 + assert len(new_labelled_form) == 2, 'LabelledForm should have 2 terms' + + # Add a Term to a LabelledForm + new_labelled_form = labelled_form_1 + term_2 + assert len(new_labelled_form) == 2, 'LabelledForm should have 2 terms' + + # Add a LabelledForm to a LabelledForm + new_labelled_form = labelled_form_1 + labelled_form_2 + assert len(new_labelled_form) == 3, 'LabelledForm should have 3 terms' + + # Adding None to a LabelledForm should give the same LabelledForm + new_labelled_form = labelled_form_1 + None + assert new_labelled_form == labelled_form_1, 'Two LabelledForms should be equal' + + # Subtract a Form from a LabelledForm + new_labelled_form = labelled_form_1 - form_2 + assert len(new_labelled_form) == 2, 'LabelledForm should have 2 terms' + + # Subtract a Term from a LabelledForm + new_labelled_form = labelled_form_1 - term_2 + assert len(new_labelled_form) == 2, 'LabelledForm should have 2 terms' + + # Subtract a LabelledForm from a LabelledForm + new_labelled_form = labelled_form_1 - labelled_form_2 + assert len(new_labelled_form) == 3, 'LabelledForm should have 3 terms' + + # Subtracting None from a LabelledForm should give the same LabelledForm + new_labelled_form = labelled_form_1 - None + assert new_labelled_form == labelled_form_1, 'Two LabelledForms should be equal' + + # ------------------------------------------------------------------------ # + # Test multiplication and division of labelled forms + # ------------------------------------------------------------------------ # + + # Multiply by integer + new_labelled_form = labelled_form_1 * -4 + assert len(new_labelled_form) == 1, 'LabelledForm should have 1 term' + + # Multiply by float + new_labelled_form = labelled_form_1 * 12.4 + assert len(new_labelled_form) == 1, 'LabelledForm should have 1 term' + + # Multiply by Constant + new_labelled_form = labelled_form_1 * Constant(5.0) + assert len(new_labelled_form) == 1, 'LabelledForm should have 1 term' + + # Divide by integer + new_labelled_form = labelled_form_1 / (-8) + assert len(new_labelled_form) == 1, 'LabelledForm should have 1 term' + + # Divide by float + new_labelled_form = labelled_form_1 / (-6.2) + assert len(new_labelled_form) == 1, 'LabelledForm should have 1 term' + + # Divide by Constant + new_labelled_form = labelled_form_1 / Constant(0.01) + assert len(new_labelled_form) == 1, 'LabelledForm should have 1 term' diff --git a/tests/unit/test_fml/test_replace_perp.py b/tests/unit/test_fml/test_replace_perp.py new file mode 100644 index 0000000000..c94cc7c617 --- /dev/null +++ b/tests/unit/test_fml/test_replace_perp.py @@ -0,0 +1,46 @@ +# The perp routine should come from UFL when it is fully implemented there +from ufl import perp +from firedrake import (UnitSquareMesh, FunctionSpace, MixedFunctionSpace, + TestFunctions, Function, split, inner, dx, errornorm, + SpatialCoordinate, as_vector, TrialFunctions, solve, + subject, replace_subject, all_terms) + + +def test_replace_perp(): + + # The test checks that if the perp operator is applied to the + # subject of a labelled form, the perp of the subject is found and + # replaced by the replace_subject function. This gave particular problems + # before the perp operator was defined + + # set up mesh and function spaces - the subject is defined on a + # mixed function space because the problem didn't occur otherwise + Nx = 5 + mesh = UnitSquareMesh(Nx, Nx) + spaces = [FunctionSpace(mesh, "BDM", 1), FunctionSpace(mesh, "DG", 1)] + W = MixedFunctionSpace(spaces) + + # set up labelled form with subject u + w, p = TestFunctions(W) + U0 = Function(W) + u0, _ = split(U0) + form = subject(inner(perp(u0), w)*dx, U0) + + # make a function to replace the subject with and give it some values + U1 = Function(W) + u1, _ = U1.subfunctions + x, y = SpatialCoordinate(mesh) + u1.interpolate(as_vector([1, 2])) + + u, D = TrialFunctions(W) + a = inner(u, w)*dx + inner(D, p)*dx + L = form.label_map(all_terms, replace_subject(U1, old_idx=0, new_idx=0)) + U2 = Function(W) + solve(a == L.form, U2) + + u2, _ = U2.subfunctions + U3 = Function(W) + u3, _ = U3.subfunctions + u3.interpolate(as_vector([-2, 1])) + + assert errornorm(u2, u3) < 1e-14 diff --git a/tests/unit/test_fml/test_replacement.py b/tests/unit/test_fml/test_replacement.py new file mode 100644 index 0000000000..314dea9057 --- /dev/null +++ b/tests/unit/test_fml/test_replacement.py @@ -0,0 +1,372 @@ +""" +Tests the different replacement routines from replacement.py +""" + +from firedrake import (UnitSquareMesh, FunctionSpace, Function, TestFunction, + TestFunctions, TrialFunction, TrialFunctions, + Argument, VectorFunctionSpace, dx, inner, split, grad, + Label, subject, replace_subject, replace_test_function, + replace_trial_function, drop, all_terms) +import pytest + +from collections import namedtuple + +ReplaceSubjArgs = namedtuple("ReplaceSubjArgs", "new_subj idxs error") +ReplaceArgsArgs = namedtuple("ReplaceArgsArgs", "new_arg idxs error replace_function arg_idx") + + +def ReplaceTestArgs(*args): + return ReplaceArgsArgs(*args, replace_test_function, 0) + + +def ReplaceTrialArgs(*args): + return ReplaceArgsArgs(*args, replace_trial_function, 1) + + +# some dummy labels +foo_label = Label("foo") +bar_label = Label("bar") + +nx = 2 +mesh = UnitSquareMesh(nx, nx) +V0 = FunctionSpace(mesh, 'CG', 1) +V1 = FunctionSpace(mesh, 'DG', 1) +W = V0*V1 +Vv = VectorFunctionSpace(mesh, 'CG', 1) +Wv = Vv*V1 + + +@pytest.fixture() +def primal_form(): + primal_subj = Function(V0) + primal_test = TestFunction(V0) + + primal_term1 = foo_label(subject(primal_subj*primal_test*dx, primal_subj)) + primal_term2 = bar_label(inner(grad(primal_subj), grad(primal_test))*dx) + + return primal_term1 + primal_term2 + + +def primal_subj_argsets(): + argsets = [ + ReplaceSubjArgs(Function(V0), {}, None), + ReplaceSubjArgs(Function(V0), {'new_idx': 0}, ValueError), + ReplaceSubjArgs(Function(V0), {'old_idx': 0}, ValueError), + ReplaceSubjArgs(Function(W), {'new_idx': 0}, None), + ReplaceSubjArgs(Function(W), {'new_idx': 1}, None), + ReplaceSubjArgs(split(Function(W)), {'new_idx': 1}, None), + ReplaceSubjArgs(Function(W), {'old_idx': 0}, ValueError), + ReplaceSubjArgs(Function(W), {'new_idx': 7}, IndexError) + ] + return argsets + + +def primal_test_argsets(): + argsets = [ + ReplaceTestArgs(TestFunction(V0), {}, None), + ReplaceTestArgs(TestFunction(V0), {'new_idx': 0}, ValueError), + ReplaceTestArgs(TestFunction(W), {'new_idx': 0}, None), + ReplaceTestArgs(TestFunction(W), {'new_idx': 1}, None), + ReplaceTestArgs(TestFunctions(W), {'new_idx': 1}, None), + ReplaceTestArgs(TestFunction(W), {'new_idx': 7}, IndexError) + ] + return argsets + + +def primal_trial_argsets(): + argsets = [ + ReplaceTrialArgs(TrialFunction(V0), {}, None), + ReplaceTrialArgs(TrialFunction(V0), {'new_idx': 0}, ValueError), + ReplaceTrialArgs(TrialFunction(W), {'new_idx': 0}, None), + ReplaceTrialArgs(TrialFunction(W), {'new_idx': 1}, None), + ReplaceTrialArgs(TrialFunctions(W), {'new_idx': 1}, None), + ReplaceTrialArgs(TrialFunction(W), {'new_idx': 7}, IndexError), + ReplaceTrialArgs(Function(V0), {}, None), + ReplaceTrialArgs(Function(V0), {'new_idx': 0}, ValueError), + ReplaceTrialArgs(Function(W), {'new_idx': 0}, None), + ReplaceTrialArgs(Function(W), {'new_idx': 1}, None), + ReplaceTrialArgs(split(Function(W)), {'new_idx': 1}, None), + ReplaceTrialArgs(Function(W), {'new_idx': 7}, IndexError), + ] + return argsets + + +@pytest.fixture +def mixed_form(): + mixed_subj = Function(W) + mixed_test = TestFunction(W) + + mixed_subj0, mixed_subj1 = split(mixed_subj) + mixed_test0, mixed_test1 = split(mixed_test) + + mixed_term1 = foo_label(subject(mixed_subj0*mixed_test0*dx, mixed_subj)) + mixed_term2 = bar_label(inner(grad(mixed_subj1), grad(mixed_test1))*dx) + + return mixed_term1 + mixed_term2 + + +def mixed_subj_argsets(): + argsets = [ + ReplaceSubjArgs(Function(W), {}, None), + ReplaceSubjArgs(Function(W), {'new_idx': 0, 'old_idx': 0}, None), + ReplaceSubjArgs(Function(W), {'old_idx': 0}, ValueError), + ReplaceSubjArgs(Function(W), {'new_idx': 0}, ValueError), + ReplaceSubjArgs(Function(V0), {'old_idx': 0}, None), + ReplaceSubjArgs(Function(V0), {'new_idx': 0}, ValueError), + ReplaceSubjArgs(split(Function(W)), {'new_idx': 0, 'old_idx': 0}, None), + ] + return argsets + + +def mixed_test_argsets(): + argsets = [ + ReplaceTestArgs(TestFunction(W), {}, None), + ReplaceTestArgs(TestFunctions(W), {}, None), + ReplaceTestArgs(TestFunction(W), {'old_idx': 0, 'new_idx': 0}, None), + ReplaceTestArgs(TestFunctions(W), {'old_idx': 0}, ValueError), + ReplaceTestArgs(TestFunction(W), {'new_idx': 0}, ValueError), + ReplaceTestArgs(TestFunction(V0), {'old_idx': 0}, None), + ReplaceTestArgs(TestFunctions(V0), {'new_idx': 1}, ValueError), + ReplaceTestArgs(TestFunction(W), {'old_idx': 7, 'new_idx': 7}, IndexError) + ] + return argsets + + +def mixed_trial_argsets(): + argsets = [ + ReplaceTrialArgs(TrialFunction(W), {}, None), + ReplaceTrialArgs(TrialFunctions(W), {}, None), + ReplaceTrialArgs(TrialFunction(W), {'old_idx': 0, 'new_idx': 0}, None), + ReplaceTrialArgs(TrialFunction(V0), {'old_idx': 0}, None), + ReplaceTrialArgs(TrialFunctions(V0), {'new_idx': 1}, ValueError), + ReplaceTrialArgs(TrialFunction(W), {'old_idx': 7, 'new_idx': 7}, IndexError), + ReplaceTrialArgs(Function(W), {}, None), + ReplaceTrialArgs(split(Function(W)), {}, None), + ReplaceTrialArgs(Function(W), {'old_idx': 0, 'new_idx': 0}, None), + ReplaceTrialArgs(Function(V0), {'old_idx': 0}, None), + ReplaceTrialArgs(Function(V0), {'new_idx': 0}, ValueError), + ReplaceTrialArgs(Function(W), {'old_idx': 7, 'new_idx': 7}, IndexError), + ] + return argsets + + +@pytest.fixture +def vector_form(): + vector_subj = Function(Vv) + vector_test = TestFunction(Vv) + + vector_term1 = foo_label(subject(inner(vector_subj, vector_test)*dx, vector_subj)) + vector_term2 = bar_label(inner(grad(vector_subj), grad(vector_test))*dx) + + return vector_term1 + vector_term2 + + +def vector_subj_argsets(): + argsets = [ + ReplaceSubjArgs(Function(Vv), {}, None), + ReplaceSubjArgs(Function(V0), {}, ValueError), + ReplaceSubjArgs(Function(Vv), {'new_idx': 0}, ValueError), + ReplaceSubjArgs(Function(Vv), {'old_idx': 0}, ValueError), + ReplaceSubjArgs(Function(Wv), {'new_idx': 0}, None), + ReplaceSubjArgs(Function(Wv), {'new_idx': 1}, ValueError), + ReplaceSubjArgs(split(Function(Wv)), {'new_idx': 0}, None), + ReplaceSubjArgs(Function(W), {'old_idx': 0}, ValueError), + ReplaceSubjArgs(Function(W), {'new_idx': 7}, IndexError), + ] + return argsets + + +def vector_test_argsets(): + argsets = [ + ReplaceTestArgs(TestFunction(Vv), {}, None), + ReplaceTestArgs(TestFunction(V0), {}, ValueError), + ReplaceTestArgs(TestFunction(Vv), {'new_idx': 0}, ValueError), + ReplaceTestArgs(TestFunction(Wv), {'new_idx': 0}, None), + ReplaceTestArgs(TestFunction(Wv), {'new_idx': 1}, ValueError), + ReplaceTestArgs(TestFunctions(Wv), {'new_idx': 0}, None), + ReplaceTestArgs(TestFunction(W), {'new_idx': 7}, IndexError), + ] + return argsets + + +@pytest.mark.parametrize('argset', primal_subj_argsets()) +def test_replace_subject_primal(primal_form, argset): + new_subj = argset.new_subj + idxs = argset.idxs + error = argset.error + + if error is None: + old_subj = primal_form.form.coefficients()[0] + + new_form = primal_form.label_map( + lambda t: t.has_label(foo_label), + map_if_true=replace_subject(new_subj, **idxs), + map_if_false=drop) + + # what if we only replace part of the subject? + if 'new_idx' in idxs: + split_new = new_subj if type(new_subj) is tuple else split(new_subj) + new_subj = split_new[idxs['new_idx']].ufl_operands[0] + + assert new_subj in new_form.form.coefficients() + assert old_subj not in new_form.form.coefficients() + + else: + with pytest.raises(error): + new_form = primal_form.label_map( + lambda t: t.has_label(foo_label), + map_if_true=replace_subject(new_subj, **idxs)) + + +@pytest.mark.parametrize('argset', mixed_subj_argsets()) +def test_replace_subject_mixed(mixed_form, argset): + new_subj = argset.new_subj + idxs = argset.idxs + error = argset.error + + if error is None: + old_subj = mixed_form.form.coefficients()[0] + + new_form = mixed_form.label_map( + lambda t: t.has_label(foo_label), + map_if_true=replace_subject(new_subj, **idxs), + map_if_false=drop) + + # what if we only replace part of the subject? + if 'new_idx' in idxs: + split_new = new_subj if type(new_subj) is tuple else split(new_subj) + new_subj = split_new[idxs['new_idx']].ufl_operands[0] + + assert new_subj in new_form.form.coefficients() + assert old_subj not in new_form.form.coefficients() + + else: + with pytest.raises(error): + new_form = mixed_form.label_map( + lambda t: t.has_label(foo_label), + map_if_true=replace_subject(new_subj, **idxs)) + + +@pytest.mark.parametrize('argset', vector_subj_argsets()) +def test_replace_subject_vector(vector_form, argset): + new_subj = argset.new_subj + idxs = argset.idxs + error = argset.error + + if error is None: + old_subj = vector_form.form.coefficients()[0] + + new_form = vector_form.label_map( + lambda t: t.has_label(foo_label), + map_if_true=replace_subject(new_subj, **idxs), + map_if_false=drop) + + # what if we only replace part of the subject? + if 'new_idx' in idxs: + split_new = new_subj if type(new_subj) is tuple else split(new_subj) + new_subj = split_new[idxs['new_idx']].ufl_operands[0].ufl_operands[0] + + assert new_subj in new_form.form.coefficients() + assert old_subj not in new_form.form.coefficients() + + else: + with pytest.raises(error): + new_form = vector_form.label_map( + lambda t: t.has_label(foo_label), + map_if_true=replace_subject(new_subj, **idxs)) + + +@pytest.mark.parametrize('argset', primal_test_argsets() + primal_trial_argsets()) +def test_replace_arg_primal(primal_form, argset): + new_arg = argset.new_arg + idxs = argset.idxs + error = argset.error + replace_function = argset.replace_function + arg_idx = argset.arg_idx + primal_form = primal_form.label_map(lambda t: t.has_label(subject), + replace_subject(TrialFunction(V0)), + drop) + + if error is None: + new_form = primal_form.label_map( + all_terms, + map_if_true=replace_function(new_arg, **idxs)) + + if 'new_idx' in idxs: + split_arg = new_arg if type(new_arg) is tuple else split(new_arg) + new_arg = split_arg[idxs['new_idx']].ufl_operands[0] + + if isinstance(new_arg, Argument): + assert new_form.form.arguments()[arg_idx] is new_arg + elif type(new_arg) is Function: + assert new_form.form.coefficients()[0] is new_arg + + else: + with pytest.raises(error): + new_form = primal_form.label_map( + all_terms, + map_if_true=replace_function(new_arg, **idxs)) + + +@pytest.mark.parametrize('argset', mixed_test_argsets() + mixed_trial_argsets()) +def test_replace_arg_mixed(mixed_form, argset): + new_arg = argset.new_arg + idxs = argset.idxs + error = argset.error + replace_function = argset.replace_function + arg_idx = argset.arg_idx + mixed_form = mixed_form.label_map(lambda t: t.has_label(subject), + replace_subject(TrialFunction(W)), + drop) + + if error is None: + new_form = mixed_form.label_map( + all_terms, + map_if_true=replace_function(new_arg, **idxs)) + + if 'new_idx' in idxs: + split_arg = new_arg if type(new_arg) is tuple else split(new_arg) + new_arg = split_arg[idxs['new_idx']].ufl_operands[0] + + if isinstance(new_arg, Argument): + assert new_form.form.arguments()[arg_idx] is new_arg + elif type(new_arg) is Function: + assert new_form.form.coefficients()[0] is new_arg + + else: + with pytest.raises(error): + new_form = mixed_form.label_map( + all_terms, + map_if_true=replace_function(new_arg, **idxs)) + + +@pytest.mark.parametrize('argset', vector_test_argsets()) +def test_replace_arg_vector(vector_form, argset): + new_arg = argset.new_arg + idxs = argset.idxs + error = argset.error + replace_function = argset.replace_function + arg_idx = argset.arg_idx + vector_form = vector_form.label_map(lambda t: t.has_label(subject), + replace_subject(TrialFunction(Vv)), + drop) + + if error is None: + new_form = vector_form.label_map( + all_terms, + map_if_true=replace_function(new_arg, **idxs)) + + if 'new_idx' in idxs: + split_arg = new_arg if type(new_arg) is tuple else split(new_arg) + new_arg = split_arg[idxs['new_idx']].ufl_operands[0] + + if isinstance(new_arg, Argument): + assert new_form.form.arguments()[arg_idx] is new_arg + elif type(new_arg) is Function: + assert new_form.form.coefficients()[0] is new_arg + + else: + with pytest.raises(error): + new_form = vector_form.label_map( + all_terms, + map_if_true=replace_function(new_arg, **idxs)) diff --git a/tests/unit/test_fml/test_term.py b/tests/unit/test_fml/test_term.py new file mode 100644 index 0000000000..d889d663e9 --- /dev/null +++ b/tests/unit/test_fml/test_term.py @@ -0,0 +1,169 @@ +""" +Tests FML's Term objects. A term contains a form and labels. +""" + +from firedrake import (IntervalMesh, FunctionSpace, Function, TestFunction, dx, + Constant, Label, Term, LabelledForm) +import pytest + + +# Two methods of making a Term with Labels. Either pass them as a dict +# at the initialisation of the Term, or apply them afterwards +@pytest.mark.parametrize("initialise", ["from_dicts", "apply_labels"]) +def test_term(initialise): + + # ------------------------------------------------------------------------ # + # Set up terms + # ------------------------------------------------------------------------ # + + # Some basic labels + foo_label = Label("foo", validator=lambda value: type(value) == bool) + lorem_label = Label("lorem", validator=lambda value: type(value) == str) + ipsum_label = Label("ipsum", validator=lambda value: type(value) == int) + + # Dict for matching the label names to the label objects + all_labels = [foo_label, lorem_label, ipsum_label] + all_label_dict = {label.label: label for label in all_labels} + + # Create mesh, function space and forms + L = 3.0 + n = 3 + mesh = IntervalMesh(n, L) + V = FunctionSpace(mesh, "DG", 0) + f = Function(V) + g = Function(V) + h = Function(V) + test = TestFunction(V) + form = f*test*dx + + # Declare what the labels will be + label_dict = {'foo': True, 'lorem': 'etc', 'ipsum': 1} + + # Make terms + if initialise == "from_dicts": + term = Term(form, label_dict) + else: + term = Term(form) + + # Apply labels + for label_name, value in label_dict.items(): + term = all_label_dict[label_name](term, value) + + # ------------------------------------------------------------------------ # + # Test Term.get routine + # ------------------------------------------------------------------------ # + + for label in all_labels: + if label.label in label_dict.keys(): + # Check if label is attached to Term and it has correct value + assert term.get(label) == label_dict[label.label], \ + f'term should have label {label.label} with value equal ' + \ + f'to {label_dict[label.label]} and not {term.get(label)}' + else: + # Labelled shouldn't be attached to Term so this should return None + assert term.get(label) is None, 'term should not have ' + \ + f'label {label.label} but term.get(label) returns ' + \ + f'{term.get(label)}' + + # ------------------------------------------------------------------------ # + # Test Term.has_label routine + # ------------------------------------------------------------------------ # + + # Test has_label for each label one by one + for label in all_labels: + assert term.has_label(label) == (label.label in label_dict.keys()), \ + f'term.has_label giving incorrect value for {label.label}' + + # Test has_labels by passing all labels at once + has_labels = term.has_label(*all_labels, return_tuple=True) + for i, label in enumerate(all_labels): + assert has_labels[i] == (label.label in label_dict.keys()), \ + f'has_label for label {label.label} returning wrong value' + + # Check the return_tuple option is correct when only one label is passed + has_labels = term.has_label(*[foo_label], return_tuple=True) + assert len(has_labels) == 1, 'Length returned by has_label is ' + \ + f'incorrect, it is {len(has_labels)} but should be 1' + assert has_labels[0] == (label.label in label_dict.keys()), \ + f'has_label for label {label.label} returning wrong value' + + # ------------------------------------------------------------------------ # + # Test Term addition and subtraction + # ------------------------------------------------------------------------ # + + form_2 = g*test*dx + term_2 = ipsum_label(Term(form_2), 2) + + labelled_form_1 = term_2 + term + labelled_form_2 = term + term_2 + + # Adding two Terms should return a LabelledForm containing the Terms + assert type(labelled_form_1) is LabelledForm, 'The sum of two Terms ' + \ + f'should be a LabelledForm, not {type(labelled_form_1)}' + assert type(labelled_form_2) is LabelledForm, 'The sum of two Terms ' + \ + f'should be a LabelledForm, not {type(labelled_form_1)}' + + # Adding a LabelledForm to a Term should return a LabelledForm + labelled_form_3 = term + labelled_form_2 + assert type(labelled_form_3) is LabelledForm, 'The sum of a Term and ' + \ + f'Labelled Form should be a LabelledForm, not {type(labelled_form_3)}' + + labelled_form_1 = term_2 - term + labelled_form_2 = term - term_2 + + # Subtracting two Terms should return a LabelledForm containing the Terms + assert type(labelled_form_1) is LabelledForm, 'The difference of two ' + \ + f'Terms should be a LabelledForm, not {type(labelled_form_1)}' + assert type(labelled_form_2) is LabelledForm, 'The difference of two ' + \ + f'Terms should be a LabelledForm, not {type(labelled_form_1)}' + + # Subtracting a LabelledForm from a Term should return a LabelledForm + labelled_form_3 = term - labelled_form_2 + assert type(labelled_form_3) is LabelledForm, 'The differnce of a Term ' + \ + f'and a Labelled Form should be a LabelledForm, not {type(labelled_form_3)}' + + # Adding None to a Term should return the Term + new_term = term + None + assert term == new_term, 'Adding None to a Term should give the same Term' + + # ------------------------------------------------------------------------ # + # Test Term multiplication and division + # ------------------------------------------------------------------------ # + + # Multiplying a term by an integer should give a Term + new_term = term*3 + assert type(new_term) is Term, 'Multiplying a Term by an integer ' + \ + f'give a Term, not a {type(new_term)}' + + # Multiplying a term by a float should give a Term + new_term = term*19.0 + assert type(new_term) is Term, 'Multiplying a Term by a float ' + \ + f'give a Term, not a {type(new_term)}' + + # Multiplying a term by a Constant should give a Term + new_term = term*Constant(-4.0) + assert type(new_term) is Term, 'Multiplying a Term by a Constant ' + \ + f'give a Term, not a {type(new_term)}' + + # Dividing a term by an integer should give a Term + new_term = term/3 + assert type(new_term) is Term, 'Dividing a Term by an integer ' + \ + f'give a Term, not a {type(new_term)}' + + # Dividing a term by a float should give a Term + new_term = term/19.0 + assert type(new_term) is Term, 'Dividing a Term by a float ' + \ + f'give a Term, not a {type(new_term)}' + + # Dividing a term by a Constant should give a Term + new_term = term/Constant(-4.0) + assert type(new_term) is Term, 'Dividing a Term by a Constant ' + \ + f'give a Term, not a {type(new_term)}' + + # Multiplying a term by a Function should fail + try: + new_term = term*h + # If we get here we have failed + assert False, 'Multiplying a Term by a Function should fail' + except TypeError: + pass