Skip to content

Commit

Permalink
Fix FormSum weights
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Dec 23, 2024
1 parent eabdb00 commit 6d8989a
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 16 deletions.
21 changes: 15 additions & 6 deletions ufl/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,21 @@ def __new__(cls, *args, **kw):
if isinstance(right, (Coargument, Argument)):
return left

if isinstance(left, (FormSum, Sum)):
# Action distributes over sums on the LHS
return FormSum(*[(Action(component, right), 1) for component in left.ufl_operands])
if isinstance(right, (FormSum, Sum)):
# Action also distributes over sums on the RHS
return FormSum(*[(Action(left, component), 1) for component in right.ufl_operands])
# Action distributes over sums on the LHS
if isinstance(left, Sum):
return FormSum(*((Action(component, right), 1) for component in left.ufl_operands))
elif isinstance(left, FormSum):
return FormSum(
*((Action(c, right), w) for c, w in zip(left.components(), left.weights()))
)

# Action also distributes over sums on the RHS
if isinstance(right, Sum):
return FormSum(*((Action(left, component), 1) for component in right.ufl_operands))
elif isinstance(right, FormSum):
return FormSum(
*((Action(left, c), w) for c, w in zip(right.components(), right.weights()))
)

return super(Action, cls).__new__(cls)

Expand Down
2 changes: 1 addition & 1 deletion ufl/adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __new__(cls, *args, **kw):
return form._form
elif isinstance(form, FormSum):
# Adjoint distributes over sums
return FormSum(*[(Adjoint(component), 1) for component in form.components()])
return FormSum(*((Adjoint(c), w) for c, w in zip(form.components(), form.weights())))
elif isinstance(form, Coargument):
# The adjoint of a coargument `c: V* -> V*` is the identity
# matrix mapping from V to V (i.e. V x V* -> R).
Expand Down
2 changes: 1 addition & 1 deletion ufl/algorithms/map_integrands.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def map_integrands(function, form, only_integral_type=None):
# Simplification of `BaseForm` objects may turn `FormSum` into a sum of `Expr` objects
# that are not `BaseForm`, i.e. into a `Sum` object.
# Example: `Action(Adjoint(c*), u)` with `c*` a `Coargument` and u a `Coefficient`.
return sum([component for component, _ in nonzero_components])
return sum(component * w for component, w in nonzero_components)
return FormSum(*nonzero_components)
elif isinstance(form, Adjoint):
# Zeros are caught inside `Adjoint.__new__`
Expand Down
16 changes: 10 additions & 6 deletions ufl/form.py
Original file line number Diff line number Diff line change
Expand Up @@ -787,7 +787,7 @@ def _analyze_domains(self):

# Collect unique domains
self._domains = sort_domains(
join_domains(chain.from_iterable(e.ufl_domains() for e in self.ufl_operands))
join_domains(chain.from_iterable(c.ufl_domains() for c in self.components()))
)

def ufl_domains(self):
Expand All @@ -799,7 +799,9 @@ def ufl_domains(self):
def __hash__(self):
"""Hash."""
if self._hash is None:
self._hash = hash(tuple(hash(component) for component in self.components()))
self._hash = hash(
tuple((hash(c), hash(w)) for c, w in zip(self.components(), self.weights()))
)
return self._hash

def equals(self, other):
Expand All @@ -808,8 +810,10 @@ def equals(self, other):
return False
if self is other:
return True
return len(self.components()) == len(other.components()) and all(
a == b for a, b in zip(self.components(), other.components())
return (
len(self.components()) == len(other.components())
and all(a == b for a, b in zip(self.components(), other.components()))
and all(a == b for a, b in zip(self.weights(), other.weights()))
)

def __str__(self):
Expand All @@ -818,7 +822,7 @@ def __str__(self):
# warning("Calling str on form is potentially expensive and
# should be avoided except during debugging.")
# Not caching this because it can be huge
s = "\n + ".join(str(component) for component in self.components())
s = "\n + ".join(f"{w}*{c}" for c, w in zip(self.components(), self.weights()))
return s or "<empty FormSum>"

def __repr__(self):
Expand All @@ -827,7 +831,7 @@ def __repr__(self):
# warning("Calling repr on form is potentially expensive and
# should be avoided except during debugging.")
# Not caching this because it can be huge
itgs = ", ".join(repr(component) for component in self.components())
itgs = ", ".join(f"{w!r}*{c!r}" for c, w in zip(self.components(), self.weights()))
r = "FormSum([" + itgs + "])"
return r

Expand Down
4 changes: 2 additions & 2 deletions ufl/formoperators.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,8 +319,8 @@ def derivative(form, coefficient, argument=None, coefficient_derivatives=None):
# Distribute derivative over FormSum components
return FormSum(
*[
(derivative(component, coefficient, argument, coefficient_derivatives), 1)
for component in form.components()
(derivative(component, coefficient, argument, coefficient_derivatives), w)
for component, w in zip(form.components(), form.weights())
]
)
elif isinstance(form, Adjoint):
Expand Down

0 comments on commit 6d8989a

Please sign in to comment.