Skip to content

Commit c0986bf

Browse files
ksagiyampbrubeck
andauthored
fem: cache on instances (#4664)
Co-authored-by: Pablo Brubeck <brubeck@protonmail.com>
1 parent 6f1737c commit c0986bf

File tree

1 file changed

+25
-14
lines changed

1 file changed

+25
-14
lines changed

tsfc/fem.py

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -284,15 +284,6 @@ def get_quadrature_rule(fiat_cell, integration_dim, quadrature_degree, scheme):
284284
return make_quadrature(integration_cell, quadrature_degree, scheme=scheme)
285285

286286

287-
def make_basis_evaluation_key(ctx, finat_element, mt, entity_id):
288-
ufl_element = mt.terminal.ufl_element()
289-
domain = extract_unique_domain(mt.terminal)
290-
coordinate_element = domain.ufl_coordinate_element()
291-
# This way of caching is fragile.
292-
# Should Implement _hash_key_() for ModifiedTerminal and use the entire mt as key.
293-
return (ufl_element, mt.local_derivatives, ctx.point_set, ctx.integration_dim, entity_id, coordinate_element, mt.restriction, domain._ufl_hash_data_())
294-
295-
296287
class PointSetContext(ContextBase):
297288
"""Context for compile-time known evaluation points."""
298289

@@ -323,12 +314,32 @@ def point_expr(self):
323314
def weight_expr(self):
324315
return self.quadrature_rule.weight_expression
325316

326-
@serial_cache(hashkey=make_basis_evaluation_key)
317+
@staticmethod
318+
def _make_basis_evaluation_key(finat_element, mt, entity_id):
319+
ufl_element = mt.terminal.ufl_element()
320+
domain = extract_unique_domain(mt.terminal)
321+
coordinate_element = domain.ufl_coordinate_element()
322+
# This way of caching is fragile.
323+
# Should implement _hash_key_() in ModifiedTerminal and include the entire mt in the key,
324+
# or only pass necessary bits in mt to basis_evaluation.
325+
return (ufl_element, mt.local_derivatives, entity_id, coordinate_element, mt.restriction, domain._ufl_hash_data_())
326+
327+
@cached_property
328+
def _basis_evaluation_cache(self):
329+
return {}
330+
327331
def basis_evaluation(self, finat_element, mt, entity_id):
328-
return finat_element.basis_evaluation(mt.local_derivatives,
329-
self.point_set,
330-
(self.integration_dim, entity_id),
331-
coordinate_mapping=CoordinateMapping(mt, self))
332+
key = PointSetContext._make_basis_evaluation_key(finat_element, mt, entity_id)
333+
try:
334+
return self._basis_evaluation_cache[key]
335+
except KeyError:
336+
val = finat_element.basis_evaluation(
337+
mt.local_derivatives,
338+
self.point_set,
339+
(self.integration_dim, entity_id),
340+
coordinate_mapping=CoordinateMapping(mt, self),
341+
)
342+
return self._basis_evaluation_cache.setdefault(key, val)
332343

333344

334345
class GemPointContext(ContextBase):

0 commit comments

Comments
 (0)