Error when assembling ufl.derivative with mixed function space #1723

Closed
IvanYashchuk opened this issue Jun 5, 2020 · 6 comments

@IvanYashchuk
Contributor

Here's code that fails at assembly:

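```python
from firedrake import *
import ufl

mesh = UnitSquareMesh(10, 10)
V = FunctionSpace(mesh, "P", 1)
W = V * V                 # mixed space with two P1 components
w = Function(W)
dw = TestFunction(W)
w0, _ = split(w)          # first component of the mixed coefficient
dw0, _ = split(dw)        # matching component of the mixed test function
j = inner(grad(w0), grad(w0))*dx
J = ufl.derivative(j, w0, dw0)
# this raises the error shown below
assemble(J)
```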

The following error is raised:

```
# ~/devdir/firedrake/src/firedrake/firedrake/tsfc_interface.py in _read_from_disk(cls, key, comm)
#      80         if val is None:
# ---> 81             raise KeyError("Object with key %s not found" % key)
#      82         return cls._cache.setdefault(key, pickle.loads(val))
# KeyError: 'Object with key 0da115d0286568e9535ebd02399dac10 not found'
# During handling of the above exception, another exception occurred:
# ...
# UFLException: Expecting scalar coefficient in this branch.
```

The above code works in FEniCS.
Writing the form out by hand, J = inner(grad(w0), grad(dw0))*dx (which equals ufl.derivative(j, w0, dw0) up to a factor of 2), assembles as expected, so the bug is specific to applying ufl.derivative to split arguments.
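For reference, ufl.derivative(j, w0, dw0) is the Gateaux derivative of j with respect to w0 in the direction dw0 (written $\delta w_0$ below). Expanding by hand shows it is twice the hand-written form above, so the two should assemble to results that differ only by that factor:

```latex
j(w_0) = \int_\Omega \nabla w_0 \cdot \nabla w_0 \,\mathrm{d}x

j(w_0 + \varepsilon\,\delta w_0)
  = j(w_0)
  + 2\varepsilon \int_\Omega \nabla w_0 \cdot \nabla \delta w_0 \,\mathrm{d}x
  + \varepsilon^2 \int_\Omega \nabla \delta w_0 \cdot \nabla \delta w_0 \,\mathrm{d}x

J = \left.\frac{\mathrm{d}}{\mathrm{d}\varepsilon}\, j(w_0 + \varepsilon\,\delta w_0)\right|_{\varepsilon = 0}
  = 2 \int_\Omega \nabla w_0 \cdot \nabla \delta w_0 \,\mathrm{d}x
```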
firedrake.derivative also fails in this case:

```
~/firedrake/src/firedrake/firedrake/ufl_expr.py in derivative(form, u, du, coefficient_derivatives)
    140     # TODO: What about Constant?
    141     u_is_x = isinstance(u, ufl.SpatialCoordinate)
--> 142     if not u_is_x and len(u.split()) > 1 and set(extract_coefficients(form)) & set(u.split()):
    143         raise ValueError("Taking derivative of form wrt u, but form contains coefficients from u.split()."
    144                          "\nYou probably meant to write split(u) when defining your form.")

AttributeError: 'Indexed' object has no attribute 'split'
```
@IvanYashchuk
Contributor Author

Using J = ufl.algorithms.expand_derivatives(ufl.derivative(j, w0, dw0)) assembles without error.

@wence-
Contributor

wence- commented Jun 5, 2020

Can you try this patch? The problem is a missing zero simplification in the form splitter, which leaves us with a nonsensical J[1] form. Calling expand_derivatives beforehand lets everything work itself out. Alternatively, the patch here handles that case directly.
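To make the "missing zero simplification" concrete, here is a toy model in plain Python. The classes Zero, ArgComponent and CoefficientDerivative and the function extract_block are hypothetical stand-ins, not UFL or Firedrake API. The sketch only assumes, as the patch suggests, that once the J[1] block is split out, the direction argument of the unexpanded derivative contains nothing but zeros, so the whole node should collapse to zero.

```python
from dataclasses import dataclass
from typing import Tuple, Union

# Hypothetical stand-ins; these are NOT UFL's expression classes.
@dataclass(frozen=True)
class Zero:
    """A zero expression."""

@dataclass(frozen=True)
class ArgComponent:
    """Component of the mixed test function dw living in sub-space `block`."""
    name: str
    block: int

Operand = Union[Zero, ArgComponent]

@dataclass(frozen=True)
class CoefficientDerivative:
    """An unexpanded derivative node. `direction` has one entry per component
    of the mixed coefficient w; components we don't differentiate are Zero."""
    integrand: str
    direction: Tuple[Operand, ...]

def restrict(op: Operand, block: int) -> Operand:
    # Splitting out block `block` keeps test-function parts from that block
    # and replaces everything else by zero.
    if isinstance(op, ArgComponent) and op.block == block:
        return op
    return Zero()

def extract_block(node: CoefficientDerivative, block: int,
                  propagate_zero: bool = True):
    direction = tuple(restrict(op, block) for op in node.direction)
    if propagate_zero and all(isinstance(op, Zero) for op in direction):
        # A derivative in an all-zero direction is zero, so collapse the node.
        return Zero()
    return CoefficientDerivative(node.integrand, direction)

# J = derivative(j, w0, dw0): direction is (dw0, 0) over the components (w0, w1).
J = CoefficientDerivative("inner(grad(w0), grad(w0))",
                          (ArgComponent("dw0", block=0), Zero()))

print(extract_block(J, 0))                        # J[0]: still a derivative
print(extract_block(J, 1))                        # J[1]: collapses to Zero()
print(extract_block(J, 1, propagate_zero=False))  # J[1] without the fix
```

With propagate_zero=False the J[1] block keeps a derivative node whose direction is entirely zero, which is the kind of leftover the patch's coefficient_derivative handler replaces with Zero.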

```diff
diff --git a/firedrake/formmanipulation.py b/firedrake/formmanipulation.py
index 472c4e0e..b56e13d5 100644
--- a/firedrake/formmanipulation.py
+++ b/firedrake/formmanipulation.py
@@ -69,6 +69,16 @@ class ExtractSubBlock(MultiFunction):
         # [v_0, v_2, v_3][1, 2]
         return self.expr(o, *map_expr_dags(self.index_inliner, operands))
 
+    def coefficient_derivative(self, o, expr, coefficients, arguments, cds):
+        # If we're only taking a derivative wrt part of an argument in
+        # a mixed space other bits might come back as zero. We want to
+        # propagate a zero in that case.
+        argument, = arguments
+        if all(isinstance(a, Zero) for a in argument.ufl_operands):
+            return Zero(o.ufl_shape, o.ufl_free_indices, o.ufl_index_dimensions)
+        else:
+            return self.reuse_if_untouched(o, expr, coefficients, arguments, cds)
+
     def argument(self, o):
         from ufl import split
         from firedrake import MixedFunctionSpace, FunctionSpace
diff --git a/firedrake/ufl_expr.py b/firedrake/ufl_expr.py
index a6ad68f4..690ddf96 100644
--- a/firedrake/ufl_expr.py
+++ b/firedrake/ufl_expr.py
@@ -139,7 +139,8 @@ def derivative(form, u, du=None, coefficient_derivatives=None):
     """
     # TODO: What about Constant?
     u_is_x = isinstance(u, ufl.SpatialCoordinate)
-    if not u_is_x and len(u.split()) > 1 and set(extract_coefficients(form)) & set(u.split()):
+    uc, = extract_coefficients(u)
+    if not u_is_x and len(uc.split()) > 1 and set(extract_coefficients(form)) & set(uc.split()):
         raise ValueError("Taking derivative of form wrt u, but form contains coefficients from u.split()."
                          "\nYou probably meant to write split(u) when defining your form.")
 
@@ -163,17 +164,21 @@ def derivative(form, u, du=None, coefficient_derivatives=None):
         if coefficient_derivatives is not None:
             cds.update(coefficient_derivatives)
         coefficient_derivatives = cds
-    elif isinstance(u, firedrake.Function):
-        V = u.function_space()
+    elif isinstance(uc, firedrake.Function):
+        V = uc.function_space()
         du = argument(V)
-    elif isinstance(u, firedrake.Constant):
-        if u.ufl_shape != ():
+    elif isinstance(uc, firedrake.Constant):
+        if uc.ufl_shape != ():
             raise ValueError("Real function space of vector elements not supported")
         V = firedrake.FunctionSpace(mesh, "Real", 0)
         du = argument(V)
     else:
         raise RuntimeError("Can't compute derivative for form")
 
+    if u.ufl_shape != du.ufl_shape:
+        raise ValueError("Shapes of u and du do not match.\n"
+                         "If you passed an indexed part of split(u) into "
+                         "derivative, you need to provide an appropriate du as well.")
     return ufl.derivative(form, u, du, coefficient_derivatives)
```
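The ufl_expr.py half of the patch does two things: it looks up the coefficient underlying u, so an Indexed component such as w0 no longer hits the missing .split() attribute, and it rejects a du whose shape differs from that of u. The following is a toy sketch of that control flow only. Coefficient, Indexed, Direction, underlying_coefficient and check_derivative_args are hypothetical names, not Firedrake API, and shapes are modelled as plain tuples.

```python
from dataclasses import dataclass
from typing import Optional, Tuple, Union

# Hypothetical stand-ins; these are NOT Firedrake/UFL classes.
@dataclass(frozen=True)
class Coefficient:
    """A function on a (possibly mixed) space, e.g. w on W = V * V."""
    name: str
    shape: Tuple[int, ...]

@dataclass(frozen=True)
class Indexed:
    """One scalar component of a coefficient, e.g. w0 from split(w)."""
    parent: Coefficient
    index: int
    shape: Tuple[int, ...] = ()

@dataclass(frozen=True)
class Direction:
    """The direction du of the derivative (a test function in practice)."""
    name: str
    shape: Tuple[int, ...]

def underlying_coefficient(u: Union[Coefficient, Indexed]) -> Coefficient:
    # Analogue of the patch's `uc, = extract_coefficients(u)`: an indexed
    # component is traced back to the coefficient that owns it.
    return u.parent if isinstance(u, Indexed) else u

def check_derivative_args(u: Union[Coefficient, Indexed],
                          du: Optional[Direction] = None) -> Direction:
    uc = underlying_coefficient(u)
    if du is None:
        # The default direction is built on the whole space of uc, so its
        # shape is uc's shape, not the shape of the component u.
        du = Direction("d" + uc.name, uc.shape)
    if u.shape != du.shape:
        raise ValueError("Shapes of u and du do not match; pass an "
                         "appropriate du for an indexed part of split(u).")
    return du

w = Coefficient("w", (2,))   # two scalar components, like V * V
w0 = Indexed(w, 0)

print(check_derivative_args(w))                         # default du on all of W
print(check_derivative_args(w0, Direction("dw0", ())))  # explicit du is accepted
try:
    check_derivative_args(w0)                           # default du is too big
except ValueError as err:
    print("ValueError:", err)
```

The design point mirrored here is that, for an indexed u, no sensible default du exists, so the sketch raises instead of silently differentiating in a direction of the wrong shape.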

@IvanYashchuk
Contributor Author

IvanYashchuk commented Jun 5, 2020

This patch works!

@wence-
Contributor

wence- commented Jun 5, 2020

Do you mind preparing a test and PR?

@IvanYashchuk
Contributor Author

Will do.

@wence-
Contributor

wence- commented Jun 19, 2020

And fixed.
