Skip to content

Commit

Permalink
JDBetteridge/update caching (#3730)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: David A. Ham <david.ham@imperial.ac.uk>
  • Loading branch information
JDBetteridge and dham authored Oct 9, 2024
1 parent 14b9ddf commit 4e6cba3
Show file tree
Hide file tree
Showing 12 changed files with 186 additions and 218 deletions.
6 changes: 3 additions & 3 deletions firedrake/extrusion_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import finat
from pyop2 import op2
from pyop2.caching import cached
from pyop2.caching import serial_cache
from firedrake.petsc import PETSc
from firedrake.utils import IntType, RealType, ScalarType
from tsfc.finatinterface import create_element
Expand Down Expand Up @@ -338,7 +338,7 @@ def make_offset_key(finat_element):
return entity_dofs_key(finat_element.entity_dofs()), is_real_tensor_product_element(finat_element)


@cached({}, key=make_offset_key)
@serial_cache(hashkey=make_offset_key)
def calculate_dof_offset(finat_element):
"""Return the offset between the neighbouring cells of a
column for each DoF.
Expand Down Expand Up @@ -366,7 +366,7 @@ def calculate_dof_offset(finat_element):
return dof_offset


@cached({}, key=make_offset_key)
@serial_cache(hashkey=make_offset_key)
def calculate_dof_offset_quotient(finat_element):
"""Return the offset quotient for each DoF within the base cell.
Expand Down
13 changes: 7 additions & 6 deletions firedrake/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from ufl.domain import as_domain, extract_unique_domain

from pyop2 import op2
from pyop2.caching import disk_cached
from pyop2.caching import memory_and_disk_cache

from tsfc.finatinterface import create_element, as_fiat_cell
from tsfc import compile_expression_dual_evaluation
Expand Down Expand Up @@ -1204,14 +1204,15 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None):

def _compile_expression_key(comm, expr, to_element, ufl_element, domain, parameters, log):
"""Generate a cache key suitable for :func:`tsfc.compile_expression_dual_evaluation`."""
# Since the caching is collective, this function must return a 2-tuple of
# the form (comm, key) where comm is the communicator the cache is collective over.
# FIXME FInAT elements are not safely hashable so we ignore them here
key = hash_expr(expr), hash(ufl_element), utils.tuplify(parameters), log
return comm, key
return key


@disk_cached({}, _expr_cachedir, key=_compile_expression_key, collective=True)
@memory_and_disk_cache(
hashkey=_compile_expression_key,
cachedir=tsfc_interface._cachedir
)
@PETSc.Log.EventDecorator()
def compile_expression(comm, *args, **kwargs):
return compile_expression_dual_evaluation(*args, **kwargs)

Expand Down
3 changes: 2 additions & 1 deletion firedrake/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import tsfc.logging # noqa: F401
import pyop2.logger # noqa: F401

from pyop2.configuration import configuration
from pyop2.mpi import COMM_WORLD


Expand Down Expand Up @@ -79,7 +80,7 @@ def set_log_handlers(handlers=None, comm=COMM_WORLD):
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter(fmt="%(name)s:%(levelname)s %(message)s"))

if comm is not None and comm.rank != 0:
if comm is not None and comm.rank != 0 and not configuration["spmd_strict"]:
handler = logging.NullHandler()

logger.addHandler(handler)
Expand Down
5 changes: 4 additions & 1 deletion firedrake/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,10 @@ def rename(self, name):
def __getstate__(self):
# Remove non-picklable update function slot
d = self.__dict__.copy()
del d["_update_function"]
try:
del d["_update_function"]
except KeyError:
pass
return d

def set_update_function(self, callable):
Expand Down
8 changes: 4 additions & 4 deletions firedrake/preconditioners/pmg.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from tsfc.finatinterface import create_element
from tsfc import compile_expression_dual_evaluation
from pyop2 import op2
from pyop2.caching import cached
from pyop2.caching import serial_cache
from pyop2.utils import as_tuple

import firedrake
Expand Down Expand Up @@ -589,7 +589,7 @@ def get_readonly_view(arr):
return result


@cached({}, key=generate_key_evaluate_dual)
@serial_cache(hashkey=generate_key_evaluate_dual)
def evaluate_dual(source, target, derivative=None):
"""Evaluate the action of a set of dual functionals of the target element
on the (derivative of the) basis functions of the source element.
Expand Down Expand Up @@ -627,7 +627,7 @@ def evaluate_dual(source, target, derivative=None):
return get_readonly_view(numpy.dot(A, B))


@cached({}, key=generate_key_evaluate_dual)
@serial_cache(hashkey=generate_key_evaluate_dual)
def compare_element(e1, e2):
"""Numerically compare two :class:`FIAT.elements`.
Equality is satisfied if e2.dual_basis(e1.primal_basis) == identity."""
Expand All @@ -639,7 +639,7 @@ def compare_element(e1, e2):
return numpy.allclose(B, numpy.eye(B.shape[0]), rtol=1E-14, atol=1E-14)


@cached({}, key=lambda V: V.ufl_element())
@serial_cache(hashkey=lambda V: V.ufl_element())
@PETSc.Log.EventDecorator("GetLineElements")
def get_permutation_to_line_elements(V):
"""Find DOF permutation to factor out the EnrichedElement expansion
Expand Down
41 changes: 23 additions & 18 deletions firedrake/slate/slac/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
expressions (finite element variational forms written in UFL).
"""
import time
from hashlib import md5

from firedrake_citations import Citations
from firedrake.tsfc_interface import SplitKernel, KernelInfo, TSFCKernel
Expand All @@ -28,6 +27,7 @@
from pyop2.utils import get_petsc_dir
from pyop2.mpi import COMM_WORLD
from pyop2.codegen.rep2loopy import SolveCallable, INVCallable
from pyop2.caching import memory_and_disk_cache

import firedrake.slate.slate as slate
import numpy as np
Expand Down Expand Up @@ -67,18 +67,30 @@


class SlateKernel(TSFCKernel):
@classmethod
def _cache_key(cls, expr, compiler_parameters):
return md5(
(expr.expression_hash + str(sorted(compiler_parameters.items()))).encode()).hexdigest(), expr.ufl_domains()[0].comm

def __init__(self, expr, compiler_parameters):
if self._initialized:
return
self.split_kernel = generate_loopy_kernel(expr, compiler_parameters)
self._initialized = True


def _compile_expression_hashkey(slate_expr, compiler_parameters=None):
params = copy.deepcopy(parameters)
if compiler_parameters and "slate_compiler" in compiler_parameters.keys():
params["slate_compiler"].update(compiler_parameters.pop("slate_compiler"))
if compiler_parameters:
params["form_compiler"].update(compiler_parameters)
return getattr(slate_expr, "expression_hash", "ERROR") + str(sorted(params.items()))


def _compile_expression_comm(*args, **kwargs):
# args[0] is a slate_expr
return args[0].ufl_domains()[0].comm


@memory_and_disk_cache(
hashkey=_compile_expression_hashkey,
comm_fetcher=_compile_expression_comm,
cachedir=tsfc_interface._cachedir
)
@PETSc.Log.EventDecorator()
def compile_expression(slate_expr, compiler_parameters=None):
"""Takes a Slate expression `slate_expr` and returns the appropriate
``pyop2.op2.Kernel`` object representing the Slate expression.
Expand All @@ -102,15 +114,8 @@ def compile_expression(slate_expr, compiler_parameters=None):
if compiler_parameters:
params["form_compiler"].update(compiler_parameters)

# If the expression has already been symbolically compiled, then
# simply reuse the produced kernel.
cache = slate_expr._metakernel_cache
key = str(sorted(params.items()))
try:
return cache[key]
except KeyError:
kernel = SlateKernel(slate_expr, params).split_kernel
return cache.setdefault(key, kernel)
kernel = SlateKernel(slate_expr, params).split_kernel
return kernel


def get_temp_info(loopy_kernel):
Expand Down
Loading

0 comments on commit 4e6cba3

Please sign in to comment.