Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JDBetteridge/update caching #3730

Merged
merged 21 commits into from
Oct 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions firedrake/extrusion_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import finat
from pyop2 import op2
from pyop2.caching import cached
from pyop2.caching import serial_cache
from firedrake.petsc import PETSc
from firedrake.utils import IntType, RealType, ScalarType
from tsfc.finatinterface import create_element
Expand Down Expand Up @@ -338,7 +338,7 @@ def make_offset_key(finat_element):
return entity_dofs_key(finat_element.entity_dofs()), is_real_tensor_product_element(finat_element)


@cached({}, key=make_offset_key)
@serial_cache(hashkey=make_offset_key)
def calculate_dof_offset(finat_element):
"""Return the offset between the neighbouring cells of a
column for each DoF.
Expand Down Expand Up @@ -366,7 +366,7 @@ def calculate_dof_offset(finat_element):
return dof_offset


@cached({}, key=make_offset_key)
@serial_cache(hashkey=make_offset_key)
def calculate_dof_offset_quotient(finat_element):
"""Return the offset quotient for each DoF within the base cell.

Expand Down
13 changes: 7 additions & 6 deletions firedrake/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from ufl.domain import as_domain, extract_unique_domain

from pyop2 import op2
from pyop2.caching import disk_cached
from pyop2.caching import memory_and_disk_cache

from tsfc.finatinterface import create_element, as_fiat_cell
from tsfc import compile_expression_dual_evaluation
Expand Down Expand Up @@ -1204,14 +1204,15 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None):

def _compile_expression_key(comm, expr, to_element, ufl_element, domain, parameters, log):
"""Generate a cache key suitable for :func:`tsfc.compile_expression_dual_evaluation`."""
# Since the caching is collective, this function must return a 2-tuple of
# the form (comm, key) where comm is the communicator the cache is collective over.
# FIXME FInAT elements are not safely hashable so we ignore them here
key = hash_expr(expr), hash(ufl_element), utils.tuplify(parameters), log
return comm, key
return key


@disk_cached({}, _expr_cachedir, key=_compile_expression_key, collective=True)
@memory_and_disk_cache(
hashkey=_compile_expression_key,
cachedir=tsfc_interface._cachedir
)
@PETSc.Log.EventDecorator()
def compile_expression(comm, *args, **kwargs):
return compile_expression_dual_evaluation(*args, **kwargs)

Expand Down
3 changes: 2 additions & 1 deletion firedrake/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import tsfc.logging # noqa: F401
import pyop2.logger # noqa: F401

from pyop2.configuration import configuration
from pyop2.mpi import COMM_WORLD


Expand Down Expand Up @@ -79,7 +80,7 @@ def set_log_handlers(handlers=None, comm=COMM_WORLD):
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter(fmt="%(name)s:%(levelname)s %(message)s"))

if comm is not None and comm.rank != 0:
if comm is not None and comm.rank != 0 and not configuration["spmd_strict"]:
handler = logging.NullHandler()

logger.addHandler(handler)
Expand Down
5 changes: 4 additions & 1 deletion firedrake/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,10 @@ def rename(self, name):
def __getstate__(self):
# Remove non-picklable update function slot
d = self.__dict__.copy()
del d["_update_function"]
try:
del d["_update_function"]
except KeyError:
pass
return d

def set_update_function(self, callable):
Expand Down
8 changes: 4 additions & 4 deletions firedrake/preconditioners/pmg.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from tsfc.finatinterface import create_element
from tsfc import compile_expression_dual_evaluation
from pyop2 import op2
from pyop2.caching import cached
from pyop2.caching import serial_cache
from pyop2.utils import as_tuple

import firedrake
Expand Down Expand Up @@ -589,7 +589,7 @@ def get_readonly_view(arr):
return result


@cached({}, key=generate_key_evaluate_dual)
@serial_cache(hashkey=generate_key_evaluate_dual)
def evaluate_dual(source, target, derivative=None):
"""Evaluate the action of a set of dual functionals of the target element
on the (derivative of the) basis functions of the source element.
Expand Down Expand Up @@ -627,7 +627,7 @@ def evaluate_dual(source, target, derivative=None):
return get_readonly_view(numpy.dot(A, B))


@cached({}, key=generate_key_evaluate_dual)
@serial_cache(hashkey=generate_key_evaluate_dual)
def compare_element(e1, e2):
"""Numerically compare two :class:`FIAT.elements`.
Equality is satisfied if e2.dual_basis(e1.primal_basis) == identity."""
Expand All @@ -639,7 +639,7 @@ def compare_element(e1, e2):
return numpy.allclose(B, numpy.eye(B.shape[0]), rtol=1E-14, atol=1E-14)


@cached({}, key=lambda V: V.ufl_element())
@serial_cache(hashkey=lambda V: V.ufl_element())
@PETSc.Log.EventDecorator("GetLineElements")
def get_permutation_to_line_elements(V):
"""Find DOF permutation to factor out the EnrichedElement expansion
Expand Down
41 changes: 23 additions & 18 deletions firedrake/slate/slac/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
expressions (finite element variational forms written in UFL).
"""
import time
from hashlib import md5

from firedrake_citations import Citations
from firedrake.tsfc_interface import SplitKernel, KernelInfo, TSFCKernel
Expand All @@ -28,6 +27,7 @@
from pyop2.utils import get_petsc_dir
from pyop2.mpi import COMM_WORLD
from pyop2.codegen.rep2loopy import SolveCallable, INVCallable
from pyop2.caching import memory_and_disk_cache

import firedrake.slate.slate as slate
import numpy as np
Expand Down Expand Up @@ -67,18 +67,30 @@


class SlateKernel(TSFCKernel):
@classmethod
def _cache_key(cls, expr, compiler_parameters):
return md5(
(expr.expression_hash + str(sorted(compiler_parameters.items()))).encode()).hexdigest(), expr.ufl_domains()[0].comm

def __init__(self, expr, compiler_parameters):
if self._initialized:
return
self.split_kernel = generate_loopy_kernel(expr, compiler_parameters)
self._initialized = True


def _compile_expression_hashkey(slate_expr, compiler_parameters=None):
params = copy.deepcopy(parameters)
if compiler_parameters and "slate_compiler" in compiler_parameters.keys():
params["slate_compiler"].update(compiler_parameters.pop("slate_compiler"))
if compiler_parameters:
params["form_compiler"].update(compiler_parameters)
return getattr(slate_expr, "expression_hash", "ERROR") + str(sorted(params.items()))


def _compile_expression_comm(*args, **kwargs):
# args[0] is a slate_expr
return args[0].ufl_domains()[0].comm


@memory_and_disk_cache(
hashkey=_compile_expression_hashkey,
comm_fetcher=_compile_expression_comm,
cachedir=tsfc_interface._cachedir
)
@PETSc.Log.EventDecorator()
def compile_expression(slate_expr, compiler_parameters=None):
"""Takes a Slate expression `slate_expr` and returns the appropriate
``pyop2.op2.Kernel`` object representing the Slate expression.
Expand All @@ -102,15 +114,8 @@ def compile_expression(slate_expr, compiler_parameters=None):
if compiler_parameters:
params["form_compiler"].update(compiler_parameters)

# If the expression has already been symbolically compiled, then
# simply reuse the produced kernel.
cache = slate_expr._metakernel_cache
key = str(sorted(params.items()))
try:
return cache[key]
except KeyError:
kernel = SlateKernel(slate_expr, params).split_kernel
return cache.setdefault(key, kernel)
kernel = SlateKernel(slate_expr, params).split_kernel
return kernel


def get_temp_info(loopy_kernel):
Expand Down
Loading
Loading