From d476a81ba09308d192f0ec57c80a5d7202f20f48 Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Mon, 12 Aug 2024 15:00:00 +0100 Subject: [PATCH 01/21] Draft changes to pyop2.caching --- firedrake/parameters.py | 5 +- firedrake/slate/slac/compiler.py | 22 +++-- firedrake/tsfc_interface.py | 159 ++++++++++++------------------- tests/test_tsfc_interface.py | 24 +++-- 4 files changed, 94 insertions(+), 116 deletions(-) diff --git a/firedrake/parameters.py b/firedrake/parameters.py index 9d379daab0..5863e76a77 100644 --- a/firedrake/parameters.py +++ b/firedrake/parameters.py @@ -37,7 +37,10 @@ def rename(self, name): def __getstate__(self): # Remove non-picklable update function slot d = self.__dict__.copy() - del d["_update_function"] + try: + del d["_update_function"] + except KeyError: + pass return d def set_update_function(self, callable): diff --git a/firedrake/slate/slac/compiler.py b/firedrake/slate/slac/compiler.py index 2fccb2cfca..b9e5bd005e 100644 --- a/firedrake/slate/slac/compiler.py +++ b/firedrake/slate/slac/compiler.py @@ -33,6 +33,7 @@ import numpy as np import loopy import gem +import os from gem import indices as make_indices from tsfc.kernel_args import OutputKernelArg, CoefficientKernelArg from tsfc.loopy import generate as generate_loopy @@ -66,17 +67,22 @@ cell_to_facets_dtype = np.dtype(np.int8) -class SlateKernel(TSFCKernel): - @classmethod - def _cache_key(cls, expr, compiler_parameters): - return md5( - (expr.expression_hash + str(sorted(compiler_parameters.items()))).encode()).hexdigest(), expr.ufl_domains()[0].comm +try: + _cachedir = os.environ["FIREDRAKE_TSFC_KERNEL_CACHE_DIR"] +except KeyError: + _cachedir = os.path.join(tempfile.gettempdir(), + f"firedrake-tsfc-expression-kernel-cache-uid{os.getuid()}") + + +def _cache_key(expr, compiler_parameters): + return expr.ufl_domains()[0].comm, md5( + (expr.expression_hash + str(sorted(compiler_parameters.items()))).encode() + ).hexdigest() + +class SlateKernel(TSFCKernel, 
cachedir=_cachedir, key=_cache_key): def __init__(self, expr, compiler_parameters): - if self._initialized: - return self.split_kernel = generate_loopy_kernel(expr, compiler_parameters) - self._initialized = True def compile_expression(slate_expr, compiler_parameters=None): diff --git a/firedrake/tsfc_interface.py b/firedrake/tsfc_interface.py index a9e84a8827..91c0fed777 100644 --- a/firedrake/tsfc_interface.py +++ b/firedrake/tsfc_interface.py @@ -4,15 +4,11 @@ passing to the backends. """ -import pickle - from hashlib import md5 from os import path, environ, getuid, makedirs -import gzip -import os -import zlib import tempfile import collections +import cachetools import ufl import finat.ufl @@ -24,8 +20,8 @@ from tsfc.ufl_utils import extract_firedrake_constants from pyop2 import op2 -from pyop2.caching import Cached -from pyop2.mpi import COMM_WORLD, MPI +from pyop2.caching import MemoryAndDiskCachedObject, parallel_memory_only_cache +from pyop2.mpi import COMM_WORLD from firedrake.formmanipulation import split_form from firedrake.parameters import parameters as default_parameters @@ -52,75 +48,40 @@ "events"]) -class TSFCKernel(Cached): - - _cache = {} - - _cachedir = environ.get('FIREDRAKE_TSFC_KERNEL_CACHE_DIR', - path.join(tempfile.gettempdir(), - 'firedrake-tsfc-kernel-cache-uid%d' % getuid())) - - @classmethod - def _cache_lookup(cls, key): - key, comm = key - # comm has to be part of the in memory key so that when - # compiling the same code on different subcommunicators we - # don't get deadlocks. But MPI_Comm objects are not hashable, - # so use comm.py2f() since this is an internal communicator and - # hence the C handle is stable. 
- commkey = comm.py2f() - assert commkey != MPI.COMM_NULL.py2f() - return cls._cache.get((key, commkey)) or cls._read_from_disk(key, comm) - - @classmethod - def _read_from_disk(cls, key, comm): - if comm.rank == 0: - cache = cls._cachedir - shard, disk_key = key[:2], key[2:] - filepath = os.path.join(cache, shard, disk_key) - val = None - if os.path.exists(filepath): - try: - with gzip.open(filepath, 'rb') as f: - val = f.read() - except zlib.error: - pass - - comm.bcast(val, root=0) - else: - val = comm.bcast(None, root=0) - - if val is None: - raise KeyError(f"Object with key {key} not found") - return cls._cache.setdefault((key, comm.py2f()), pickle.loads(val)) - - @classmethod - def _cache_store(cls, key, val): - key, comm = key - cls._cache[(key, comm.py2f())] = val - _ensure_cachedir(comm=comm) - if comm.rank == 0: - val._key = key - shard, disk_key = key[:2], key[2:] - filepath = os.path.join(cls._cachedir, shard, disk_key) - tempfile = os.path.join(cls._cachedir, shard, "%s_p%d.tmp" % (disk_key, os.getpid())) - # No need for a barrier after this, since non root - # processes will never race on this file. 
- os.makedirs(os.path.join(cls._cachedir, shard), exist_ok=True) - with gzip.open(tempfile, 'wb') as f: - pickle.dump(val, f, 0) - os.rename(tempfile, filepath) - comm.barrier() - - @classmethod - def _cache_key(cls, form, name, parameters, coefficient_numbers, constant_numbers, interface, diagonal=False): - return md5((form.signature() + name - + str(sorted(parameters.items())) - + str(coefficient_numbers) - + str(constant_numbers) - + str(type(interface)) - + str(diagonal)).encode()).hexdigest(), form.ufl_domains()[0].comm - +_cachedir = environ.get( + 'FIREDRAKE_TSFC_KERNEL_CACHE_DIR', + path.join(tempfile.gettempdir(), f'firedrake-tsfc-kernel-cache-uid{getuid()}') +) + + +def TSFCKernel_hashkey(*args, **kwargs): + arg_dict = dict(kwargs) + arg_names = [ + "form", "name", "parameters", "coefficient_numbers", + "constant_numbers", "interface", "diagonal" + ] + for k, v in zip(arg_names, args): + arg_dict[k] = v + arg_dict.setdefault("diagonal", False) + comm = arg_dict["form"].ufl_domains()[0].comm + if isinstance(arg_dict["form"], str): + signature = arg_dict["form"] + else: + signature = arg_dict["form"].signature() + parts = ( + signature, + arg_dict["name"], + sorted(arg_dict["parameters"].items()), + arg_dict["coefficient_numbers"], + arg_dict["constant_numbers"], + type(arg_dict["interface"]), + arg_dict["diagonal"] + ) + key = md5(" ".join(map(str, parts)).encode()).hexdigest() + return comm, key + + +class TSFCKernel(MemoryAndDiskCachedObject, cachedir=_cachedir, key=TSFCKernel_hashkey): def __init__( self, form, @@ -141,8 +102,6 @@ def __init__( :arg interface: the KernelBuilder interface for TSFC (may be None) :arg diagonal: If assembling a matrix is it diagonal? 
""" - if self._initialized: - return tree = tsfc_compile_form(form, prefix=name, parameters=parameters, interface=interface, diagonal=diagonal, log=PETSc.Log.isActive()) @@ -179,14 +138,22 @@ def __init__( arguments=kernel.arguments, events=events)) self.kernels = tuple(kernels) - self._initialized = True -SplitKernel = collections.namedtuple("SplitKernel", ["indices", - "kinfo"]) +SplitKernel = collections.namedtuple("SplitKernel", ["indices", "kinfo"]) + + +def compile_form_hashkey(*args, **kwargs): + # form, name, parameters, split, diagonal + comm = args[0].ufl_domains()[0].comm + parameters = kwargs.pop("parameters", None) + key = cachetools.keys.hashkey(*args, utils.tuplify(parameters), **kwargs) + kwargs.setdefault("parameters", parameters) + return comm, key @PETSc.Log.EventDecorator() +@parallel_memory_only_cache(key=compile_form_hashkey) def compile_form(form, name, parameters=None, split=True, interface=None, diagonal=False): """Compile a form using TSFC. @@ -222,16 +189,6 @@ def compile_form(form, name, parameters=None, split=True, interface=None, diagon parameters = default_parameters["form_compiler"].copy() parameters.update(_) - # We stash the compiled kernels on the form so we don't have to recompile - # if we assemble the same form again with the same optimisations - cache = form._cache.setdefault("firedrake_kernels", {}) - - key = (name, utils.tuplify(parameters), split, diagonal) - try: - return cache[key] - except KeyError: - pass - kernels = [] numbering = form.terminal_numbering() if split: @@ -258,15 +215,19 @@ def compile_form(form, name, parameters=None, split=True, interface=None, diagon numbering[c] for c in extract_firedrake_constants(f) ) prefix = name + "".join(map(str, (i for i in idx if i is not None))) - kinfos = TSFCKernel(f, prefix, parameters, - coefficient_numbers, - constant_numbers, - interface, diagonal).kernels - for kinfo in kinfos: + tsfc_kernel = TSFCKernel( + f, + prefix, + parameters, + coefficient_numbers, + 
constant_numbers, + interface, diagonal + ) + for kinfo in tsfc_kernel.kernels: kernels.append(SplitKernel(idx, kinfo)) kernels = tuple(kernels) - return cache.setdefault(key, kernels) + return kernels def _real_mangle(form): @@ -291,7 +252,7 @@ def clear_cache(comm=None): comm = comm or COMM_WORLD if comm.rank == 0: import shutil - shutil.rmtree(TSFCKernel._cachedir, ignore_errors=True) + shutil.rmtree(_cachedir, ignore_errors=True) _ensure_cachedir(comm=comm) @@ -299,7 +260,7 @@ def _ensure_cachedir(comm=None): """Ensure that the TSFC kernel cache directory exists.""" comm = comm or COMM_WORLD if comm.rank == 0: - makedirs(TSFCKernel._cachedir, exist_ok=True) + makedirs(_cachedir, exist_ok=True) def gather_integer_subdomain_ids(knls): diff --git a/tests/test_tsfc_interface.py b/tests/test_tsfc_interface.py index 39218e4e64..7c439a21d8 100644 --- a/tests/test_tsfc_interface.py +++ b/tests/test_tsfc_interface.py @@ -1,5 +1,6 @@ import pytest from firedrake import * +from pyop2.caching import _disk_cache_get, _as_hexdigest import os import subprocess import sys @@ -50,7 +51,11 @@ def rhs2(fs): @pytest.fixture def cache_key(mass): - return tsfc_interface.TSFCKernel(mass, 'mass', parameters["form_compiler"], (), (), None).cache_key + key = tsfc_interface.TSFCKernel_hashkey( + mass, 'mass', parameters["form_compiler"], (), (), None + )[1] + disk_key = _as_hexdigest((key, tsfc_interface.TSFCKernel.__qualname__)) + return disk_key class TestTSFCCache: @@ -60,13 +65,17 @@ class TestTSFCCache: def test_cache_key_persistent_across_invocations(self, tmpdir): code = """ from firedrake import * +from firedrake.tsfc_interface import TSFCKernel, TSFCKernel_hashkey +from pyop2.caching import _as_hexdigest mesh = UnitSquareMesh(1, 1) V = FunctionSpace(mesh, "CG", 1) u = TrialFunction(V) v = TestFunction(V) -key = tsfc_interface.TSFCKernel(inner(u,v)*dx, "mass", parameters["form_compiler"], (), (), None).cache_key +obj = tsfc_interface.TSFCKernel(inner(u,v)*dx, "mass", 
parameters["form_compiler"], (), (), None) +key = tsfc_interface.TSFCKernel_hashkey(inner(u,v)*dx, "mass", parameters["form_compiler"], (), (), None)[1] +disk_key = _as_hexdigest((key, obj.__class__.__qualname__)) with open("{file}", "w") as f: - f.write(key) + f.write(disk_key) """ filea = tmpdir.join("a") fileb = tmpdir.join("b") @@ -78,16 +87,15 @@ def test_cache_key_persistent_across_invocations(self, tmpdir): key2 = f.read() assert key1 == key2 - def test_tsfc_cache_persist_on_disk(self, cache_key): + def test_tsfc_cache_persist_on_disk(self, mass, cache_key): """TSFCKernel should be persisted on disk.""" + tsfc_interface.TSFCKernel(mass, "mass", parameters["form_compiler"], (), (), None) shard, key = cache_key[:2], cache_key[2:] - assert os.path.exists( - os.path.join(tsfc_interface.TSFCKernel._cachedir, shard, key)) + assert os.path.exists(os.path.join(tsfc_interface._cachedir, shard, key)) def test_tsfc_cache_read_from_disk(self, cache_key): """Loading an TSFCKernel from disk should yield the right object.""" - assert tsfc_interface.TSFCKernel._read_from_disk( - cache_key, COMM_WORLD).cache_key == cache_key + assert _disk_cache_get(tsfc_interface._cachedir, cache_key)._cache_key.value == cache_key def test_tsfc_same_form(self, mass): """Compiling the same form twice should load kernels from cache.""" From eafaa52cfaf7791b55f378b0f2bb6b370761730a Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Mon, 12 Aug 2024 22:07:27 +0100 Subject: [PATCH 02/21] WIP --- firedrake/slate/slac/compiler.py | 31 ++++++++++++++++++++----------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/firedrake/slate/slac/compiler.py b/firedrake/slate/slac/compiler.py index b9e5bd005e..3949b9aced 100644 --- a/firedrake/slate/slac/compiler.py +++ b/firedrake/slate/slac/compiler.py @@ -8,6 +8,7 @@ expressions (finite element variational forms written in UFL). 
""" import time +import tempfile from hashlib import md5 from firedrake_citations import Citations @@ -28,6 +29,7 @@ from pyop2.utils import get_petsc_dir from pyop2.mpi import COMM_WORLD from pyop2.codegen.rep2loopy import SolveCallable, INVCallable +from pyop2.caching import parallel_memory_only_cache import firedrake.slate.slate as slate import numpy as np @@ -70,8 +72,10 @@ try: _cachedir = os.environ["FIREDRAKE_TSFC_KERNEL_CACHE_DIR"] except KeyError: - _cachedir = os.path.join(tempfile.gettempdir(), - f"firedrake-tsfc-expression-kernel-cache-uid{os.getuid()}") + _cachedir = os.path.join( + tempfile.gettempdir(), + f"firedrake-tsfc-expression-kernel-cache-uid{os.getuid()}" + ) def _cache_key(expr, compiler_parameters): @@ -85,6 +89,18 @@ def __init__(self, expr, compiler_parameters): self.split_kernel = generate_loopy_kernel(expr, compiler_parameters) +def compile_expression_hashkey(slate_expr, compiler_parameters=None): + comm = slate_expr.ufl_domains()[0].comm + params = copy.deepcopy(parameters) + if compiler_parameters and "slate_compiler" in compiler_parameters.keys(): + params["slate_compiler"].update(compiler_parameters.pop("slate_compiler")) + if compiler_parameters: + params["form_compiler"].update(compiler_parameters) + + return comm, getattr(slate_expr, "expression_hash", "ERROR") + str(sorted(params.items())) + + +@parallel_memory_only_cache(key=compile_expression_hashkey) def compile_expression(slate_expr, compiler_parameters=None): """Takes a Slate expression `slate_expr` and returns the appropriate ``pyop2.op2.Kernel`` object representing the Slate expression. @@ -108,15 +124,8 @@ def compile_expression(slate_expr, compiler_parameters=None): if compiler_parameters: params["form_compiler"].update(compiler_parameters) - # If the expression has already been symbolically compiled, then - # simply reuse the produced kernel. 
- cache = slate_expr._metakernel_cache - key = str(sorted(params.items())) - try: - return cache[key] - except KeyError: - kernel = SlateKernel(slate_expr, params).split_kernel - return cache.setdefault(key, kernel) + kernel = SlateKernel(slate_expr, params).split_kernel + return kernel def get_temp_info(loopy_kernel): From 80efdf4efaa0b908cfcb7369088a73b7f5ce5ec2 Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Mon, 12 Aug 2024 22:35:43 +0100 Subject: [PATCH 03/21] Change package branch --- .github/workflows/build.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 3c26f4d7a9..792d28cddf 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -83,6 +83,7 @@ jobs: --install defcon \ --install gadopt \ --install asQ \ + --package-branch PyOP2 JDBetteridge/remove_comm_hash \ || (cat firedrake-install.log && /bin/false) - name: Install test dependencies run: | From bd0360740dbf6ca4b0a6e373e9c09025687d52e1 Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Tue, 13 Aug 2024 17:31:43 +0100 Subject: [PATCH 04/21] Just notes --- firedrake/interpolation.py | 2 +- firedrake/slate/slac/compiler.py | 1 + firedrake/tsfc_interface.py | 2 ++ 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index a7442d7820..f2d4275386 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -1211,7 +1211,7 @@ def _compile_expression_key(comm, expr, to_element, ufl_element, domain, paramet return comm, key -@disk_cached({}, _expr_cachedir, key=_compile_expression_key, collective=True) +@disk_cached(None, _expr_cachedir, key=_compile_expression_key, collective=True) def compile_expression(comm, *args, **kwargs): return compile_expression_dual_evaluation(*args, **kwargs) diff --git a/firedrake/slate/slac/compiler.py b/firedrake/slate/slac/compiler.py index 3949b9aced..8da0c95636 100644 --- 
a/firedrake/slate/slac/compiler.py +++ b/firedrake/slate/slac/compiler.py @@ -100,6 +100,7 @@ def compile_expression_hashkey(slate_expr, compiler_parameters=None): return comm, getattr(slate_expr, "expression_hash", "ERROR") + str(sorted(params.items())) +# TODO: Decorate this with a disk/memory cache instead @parallel_memory_only_cache(key=compile_expression_hashkey) def compile_expression(slate_expr, compiler_parameters=None): """Takes a Slate expression `slate_expr` and returns the appropriate diff --git a/firedrake/tsfc_interface.py b/firedrake/tsfc_interface.py index 91c0fed777..e38f6de652 100644 --- a/firedrake/tsfc_interface.py +++ b/firedrake/tsfc_interface.py @@ -102,6 +102,7 @@ def __init__( :arg interface: the KernelBuilder interface for TSFC (may be None) :arg diagonal: If assembling a matrix is it diagonal? """ + # TODO: wrap tsfc_compile_form in a cache tree = tsfc_compile_form(form, prefix=name, parameters=parameters, interface=interface, diagonal=diagonal, log=PETSc.Log.isActive()) @@ -152,6 +153,7 @@ def compile_form_hashkey(*args, **kwargs): return comm, key +# TODO: This should be disk cached @PETSc.Log.EventDecorator() @parallel_memory_only_cache(key=compile_form_hashkey) def compile_form(form, name, parameters=None, split=True, interface=None, diagonal=False): From 8941366441d339bf83a867398d60fa8d6c947e6f Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Thu, 15 Aug 2024 20:36:08 +0100 Subject: [PATCH 05/21] WIP --- firedrake/extrusion_utils.py | 6 +++--- firedrake/interpolation.py | 7 +++++-- firedrake/preconditioners/pmg.py | 8 ++++---- firedrake/slate/slac/compiler.py | 18 ++++++------------ firedrake/tsfc_interface.py | 19 +++++++++++++------ 5 files changed, 31 insertions(+), 27 deletions(-) diff --git a/firedrake/extrusion_utils.py b/firedrake/extrusion_utils.py index 0a69eecb41..78b2db701f 100644 --- a/firedrake/extrusion_utils.py +++ b/firedrake/extrusion_utils.py @@ -5,7 +5,7 @@ import finat from pyop2 import op2 -from pyop2.caching 
import cached +from pyop2.caching import serial_cache, DEFAULT_CACHE from firedrake.petsc import PETSc from firedrake.utils import IntType, RealType, ScalarType from tsfc.finatinterface import create_element @@ -338,7 +338,7 @@ def make_offset_key(finat_element): return entity_dofs_key(finat_element.entity_dofs()), is_real_tensor_product_element(finat_element) -@cached({}, key=make_offset_key) +@serial_cache(hashkey=make_offset_key) def calculate_dof_offset(finat_element): """Return the offset between the neighbouring cells of a column for each DoF. @@ -366,7 +366,7 @@ def calculate_dof_offset(finat_element): return dof_offset -@cached({}, key=make_offset_key) +@serial_cache(hashkey=make_offset_key) def calculate_dof_offset_quotient(finat_element): """Return the offset quotient for each DoF within the base cell. diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index f2d4275386..a8b41ccdd8 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -12,7 +12,7 @@ from ufl.domain import as_domain, extract_unique_domain from pyop2 import op2 -from pyop2.caching import disk_cached +from pyop2.caching import memory_and_disk_cache from tsfc.finatinterface import create_element, as_fiat_cell from tsfc import compile_expression_dual_evaluation @@ -1211,7 +1211,10 @@ def _compile_expression_key(comm, expr, to_element, ufl_element, domain, paramet return comm, key -@disk_cached(None, _expr_cachedir, key=_compile_expression_key, collective=True) +@memory_and_disk_cache( + hashkey=_compile_expression_key, + cachedir=tsfc_interface._cachedir +) def compile_expression(comm, *args, **kwargs): return compile_expression_dual_evaluation(*args, **kwargs) diff --git a/firedrake/preconditioners/pmg.py b/firedrake/preconditioners/pmg.py index 7673403017..04bcba42b6 100644 --- a/firedrake/preconditioners/pmg.py +++ b/firedrake/preconditioners/pmg.py @@ -12,7 +12,7 @@ from tsfc.finatinterface import create_element from tsfc import 
compile_expression_dual_evaluation from pyop2 import op2 -from pyop2.caching import cached +from pyop2.caching import serial_cache, DEFAULT_CACHE from pyop2.utils import as_tuple import firedrake @@ -589,7 +589,7 @@ def get_readonly_view(arr): return result -@cached({}, key=generate_key_evaluate_dual) +@serial_cache(hashkey=generate_key_evaluate_dual) def evaluate_dual(source, target, derivative=None): """Evaluate the action of a set of dual functionals of the target element on the (derivative of the) basis functions of the source element. @@ -627,7 +627,7 @@ def evaluate_dual(source, target, derivative=None): return get_readonly_view(numpy.dot(A, B)) -@cached({}, key=generate_key_evaluate_dual) +@serial_cache(hashkey=generate_key_evaluate_dual) def compare_element(e1, e2): """Numerically compare two :class:`FIAT.elements`. Equality is satisfied if e2.dual_basis(e1.primal_basis) == identity.""" @@ -639,7 +639,7 @@ def compare_element(e1, e2): return numpy.allclose(B, numpy.eye(B.shape[0]), rtol=1E-14, atol=1E-14) -@cached({}, key=lambda V: V.ufl_element()) +@serial_cache(hashkey=lambda V: V.ufl_element()) @PETSc.Log.EventDecorator("GetLineElements") def get_permutation_to_line_elements(V): """Find DOF permutation to factor out the EnrichedElement expansion diff --git a/firedrake/slate/slac/compiler.py b/firedrake/slate/slac/compiler.py index 8da0c95636..1b82e79c9a 100644 --- a/firedrake/slate/slac/compiler.py +++ b/firedrake/slate/slac/compiler.py @@ -29,7 +29,7 @@ from pyop2.utils import get_petsc_dir from pyop2.mpi import COMM_WORLD from pyop2.codegen.rep2loopy import SolveCallable, INVCallable -from pyop2.caching import parallel_memory_only_cache +from pyop2.caching import memory_and_disk_cache import firedrake.slate.slate as slate import numpy as np @@ -69,22 +69,13 @@ cell_to_facets_dtype = np.dtype(np.int8) -try: - _cachedir = os.environ["FIREDRAKE_TSFC_KERNEL_CACHE_DIR"] -except KeyError: - _cachedir = os.path.join( - tempfile.gettempdir(), - 
f"firedrake-tsfc-expression-kernel-cache-uid{os.getuid()}" - ) - - def _cache_key(expr, compiler_parameters): return expr.ufl_domains()[0].comm, md5( (expr.expression_hash + str(sorted(compiler_parameters.items()))).encode() ).hexdigest() -class SlateKernel(TSFCKernel, cachedir=_cachedir, key=_cache_key): +class SlateKernel(TSFCKernel): def __init__(self, expr, compiler_parameters): self.split_kernel = generate_loopy_kernel(expr, compiler_parameters) @@ -101,7 +92,10 @@ def compile_expression_hashkey(slate_expr, compiler_parameters=None): # TODO: Decorate this with a disk/memory cache instead -@parallel_memory_only_cache(key=compile_expression_hashkey) +@memory_and_disk_cache( + hashkey=compile_expression_hashkey, + cachedir=tsfc_interface._cachedir +) def compile_expression(slate_expr, compiler_parameters=None): """Takes a Slate expression `slate_expr` and returns the appropriate ``pyop2.op2.Kernel`` object representing the Slate expression. diff --git a/firedrake/tsfc_interface.py b/firedrake/tsfc_interface.py index e38f6de652..15af0bf584 100644 --- a/firedrake/tsfc_interface.py +++ b/firedrake/tsfc_interface.py @@ -20,7 +20,7 @@ from tsfc.ufl_utils import extract_firedrake_constants from pyop2 import op2 -from pyop2.caching import MemoryAndDiskCachedObject, parallel_memory_only_cache +from pyop2.caching import memory_and_disk_cache from pyop2.mpi import COMM_WORLD from firedrake.formmanipulation import split_form @@ -81,7 +81,7 @@ def TSFCKernel_hashkey(*args, **kwargs): return comm, key -class TSFCKernel(MemoryAndDiskCachedObject, cachedir=_cachedir, key=TSFCKernel_hashkey): +class TSFCKernel: def __init__( self, form, @@ -144,18 +144,25 @@ def __init__( SplitKernel = collections.namedtuple("SplitKernel", ["indices", "kinfo"]) -def compile_form_hashkey(*args, **kwargs): +def _compile_form_hashkey(*args, **kwargs): # form, name, parameters, split, diagonal - comm = args[0].ufl_domains()[0].comm parameters = kwargs.pop("parameters", None) key = 
cachetools.keys.hashkey(*args, utils.tuplify(parameters), **kwargs) kwargs.setdefault("parameters", parameters) - return comm, key + return key + + +def _compile_form_comm(*args, **kwargs): + return args[0].ufl_domains()[0].comm # TODO: This should be disk cached @PETSc.Log.EventDecorator() -@parallel_memory_only_cache(key=compile_form_hashkey) +@memory_and_disk_cache( + hashkey=_compile_form_hashkey, + comm_fetcher=_compile_form_comm, + cachedir=_cachedir +) def compile_form(form, name, parameters=None, split=True, interface=None, diagonal=False): """Compile a form using TSFC. From c97e7264fb43e61a26295861d676e873258a1a7f Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Fri, 16 Aug 2024 00:13:58 +0100 Subject: [PATCH 06/21] WIP: fixes --- firedrake/slate/slac/compiler.py | 12 ++++++++---- tests/test_tsfc_interface.py | 10 +++++----- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/firedrake/slate/slac/compiler.py b/firedrake/slate/slac/compiler.py index 1b82e79c9a..6c2066b435 100644 --- a/firedrake/slate/slac/compiler.py +++ b/firedrake/slate/slac/compiler.py @@ -80,20 +80,24 @@ def __init__(self, expr, compiler_parameters): self.split_kernel = generate_loopy_kernel(expr, compiler_parameters) -def compile_expression_hashkey(slate_expr, compiler_parameters=None): - comm = slate_expr.ufl_domains()[0].comm +def _compile_expression_hashkey(slate_expr, compiler_parameters=None): params = copy.deepcopy(parameters) if compiler_parameters and "slate_compiler" in compiler_parameters.keys(): params["slate_compiler"].update(compiler_parameters.pop("slate_compiler")) if compiler_parameters: params["form_compiler"].update(compiler_parameters) + return getattr(slate_expr, "expression_hash", "ERROR") + str(sorted(params.items())) - return comm, getattr(slate_expr, "expression_hash", "ERROR") + str(sorted(params.items())) + +def _compile_expression_comm(*args, **kwargs): + # args[0] is a slate_expr + return args[0].ufl_domains()[0].comm # TODO: Decorate this 
with a disk/memory cache instead @memory_and_disk_cache( - hashkey=compile_expression_hashkey, + hashkey=_compile_expression_hashkey, + comm_fetcher=_compile_expression_comm, cachedir=tsfc_interface._cachedir ) def compile_expression(slate_expr, compiler_parameters=None): diff --git a/tests/test_tsfc_interface.py b/tests/test_tsfc_interface.py index 7c439a21d8..836e98912f 100644 --- a/tests/test_tsfc_interface.py +++ b/tests/test_tsfc_interface.py @@ -1,6 +1,6 @@ import pytest from firedrake import * -from pyop2.caching import _disk_cache_get, _as_hexdigest +from pyop2.caching import DictLikeDiskAccess, _as_hexdigest import os import subprocess import sys @@ -54,14 +54,13 @@ def cache_key(mass): key = tsfc_interface.TSFCKernel_hashkey( mass, 'mass', parameters["form_compiler"], (), (), None )[1] - disk_key = _as_hexdigest((key, tsfc_interface.TSFCKernel.__qualname__)) + disk_key = _as_hexdigest(key, tsfc_interface.TSFCKernel.__qualname__) return disk_key class TestTSFCCache: - """TSFC code generation cache tests.""" - + # TODO: The first three tests no longer make sense, rewrite or delete def test_cache_key_persistent_across_invocations(self, tmpdir): code = """ from firedrake import * @@ -95,7 +94,8 @@ def test_tsfc_cache_persist_on_disk(self, mass, cache_key): def test_tsfc_cache_read_from_disk(self, cache_key): """Loading an TSFCKernel from disk should yield the right object.""" - assert _disk_cache_get(tsfc_interface._cachedir, cache_key)._cache_key.value == cache_key + disk = DictLikeDiskAccess(tsfc_interface._cachedir) + assert disk[cache_key]._cache_key.value == cache_key def test_tsfc_same_form(self, mass): """Compiling the same form twice should load kernels from cache.""" From 7bbbc5bae64db3a387395b3fed4b4036eb17c688 Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Fri, 16 Aug 2024 00:25:15 +0100 Subject: [PATCH 07/21] pytest-mpi is now a dependency of PyOP2 and causing conflicts, because that's how dependencies are supposed to work... 
--- requirements-git.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/requirements-git.txt b/requirements-git.txt index 8bf05aad21..b6e9e8e1dd 100644 --- a/requirements-git.txt +++ b/requirements-git.txt @@ -5,4 +5,3 @@ git+https://github.com/firedrakeproject/tsfc.git#egg=tsfc git+https://github.com/OP2/PyOP2.git#egg=pyop2 git+https://github.com/dolfin-adjoint/pyadjoint.git#egg=pyadjoint git+https://github.com/firedrakeproject/petsc.git@firedrake#egg=petsc -git+https://github.com/firedrakeproject/pytest-mpi.git@main#egg=pytest-mpi From 123e12911fc3c10684e66cc0a3e832aaea184c0e Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Fri, 16 Aug 2024 00:27:02 +0100 Subject: [PATCH 08/21] Linting --- firedrake/extrusion_utils.py | 2 +- firedrake/preconditioners/pmg.py | 2 +- firedrake/slate/slac/compiler.py | 2 -- 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/firedrake/extrusion_utils.py b/firedrake/extrusion_utils.py index 78b2db701f..b038d904af 100644 --- a/firedrake/extrusion_utils.py +++ b/firedrake/extrusion_utils.py @@ -5,7 +5,7 @@ import finat from pyop2 import op2 -from pyop2.caching import serial_cache, DEFAULT_CACHE +from pyop2.caching import serial_cache from firedrake.petsc import PETSc from firedrake.utils import IntType, RealType, ScalarType from tsfc.finatinterface import create_element diff --git a/firedrake/preconditioners/pmg.py b/firedrake/preconditioners/pmg.py index 04bcba42b6..426e71c030 100644 --- a/firedrake/preconditioners/pmg.py +++ b/firedrake/preconditioners/pmg.py @@ -12,7 +12,7 @@ from tsfc.finatinterface import create_element from tsfc import compile_expression_dual_evaluation from pyop2 import op2 -from pyop2.caching import serial_cache, DEFAULT_CACHE +from pyop2.caching import serial_cache from pyop2.utils import as_tuple import firedrake diff --git a/firedrake/slate/slac/compiler.py b/firedrake/slate/slac/compiler.py index 6c2066b435..5b415b0c04 100644 --- a/firedrake/slate/slac/compiler.py +++ 
b/firedrake/slate/slac/compiler.py @@ -8,7 +8,6 @@ expressions (finite element variational forms written in UFL). """ import time -import tempfile from hashlib import md5 from firedrake_citations import Citations @@ -35,7 +34,6 @@ import numpy as np import loopy import gem -import os from gem import indices as make_indices from tsfc.kernel_args import OutputKernelArg, CoefficientKernelArg from tsfc.loopy import generate as generate_loopy From be11df1db0c00f3d112cdb4304f24d4f01948c2c Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Fri, 16 Aug 2024 21:34:50 +0100 Subject: [PATCH 09/21] WIP: More fixes --- firedrake/interpolation.py | 2 +- firedrake/tsfc_interface.py | 7 ++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index a8b41ccdd8..42f1c9b73c 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -1208,7 +1208,7 @@ def _compile_expression_key(comm, expr, to_element, ufl_element, domain, paramet # the form (comm, key) where comm is the communicator the cache is collective over. 
# FIXME FInAT elements are not safely hashable so we ignore them here key = hash_expr(expr), hash(ufl_element), utils.tuplify(parameters), log - return comm, key + return key @memory_and_disk_cache( diff --git a/firedrake/tsfc_interface.py b/firedrake/tsfc_interface.py index 15af0bf584..2d7aa0e029 100644 --- a/firedrake/tsfc_interface.py +++ b/firedrake/tsfc_interface.py @@ -147,7 +147,12 @@ def __init__( def _compile_form_hashkey(*args, **kwargs): # form, name, parameters, split, diagonal parameters = kwargs.pop("parameters", None) - key = cachetools.keys.hashkey(*args, utils.tuplify(parameters), **kwargs) + key = cachetools.keys.hashkey( + args[0].signature(), + *args[1:], + utils.tuplify(parameters), + **kwargs + ) kwargs.setdefault("parameters", parameters) return key From c30fdddfc03ff3bf0db538fc78bbdd77181cb15d Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Sun, 18 Aug 2024 15:09:13 +0100 Subject: [PATCH 10/21] Fix TSFC interface tests --- firedrake/tsfc_interface.py | 1 + tests/test_tsfc_interface.py | 144 ++++++++++++----------------------- 2 files changed, 50 insertions(+), 95 deletions(-) diff --git a/firedrake/tsfc_interface.py b/firedrake/tsfc_interface.py index 2d7aa0e029..5fbbab3f72 100644 --- a/firedrake/tsfc_interface.py +++ b/firedrake/tsfc_interface.py @@ -54,6 +54,7 @@ ) +# TODO: Remove this as TSFCKernel is no longer cached to disk! 
def TSFCKernel_hashkey(*args, **kwargs): arg_dict = dict(kwargs) arg_names = [ diff --git a/tests/test_tsfc_interface.py b/tests/test_tsfc_interface.py index 836e98912f..97a2525e14 100644 --- a/tests/test_tsfc_interface.py +++ b/tests/test_tsfc_interface.py @@ -1,9 +1,5 @@ import pytest from firedrake import * -from pyop2.caching import DictLikeDiskAccess, _as_hexdigest -import os -import subprocess -import sys import loopy @@ -49,94 +45,52 @@ def rhs2(fs): return inner(f, v) * dx + inner(g, v) * ds -@pytest.fixture -def cache_key(mass): - key = tsfc_interface.TSFCKernel_hashkey( - mass, 'mass', parameters["form_compiler"], (), (), None - )[1] - disk_key = _as_hexdigest(key, tsfc_interface.TSFCKernel.__qualname__) - return disk_key - - -class TestTSFCCache: - """TSFC code generation cache tests.""" - # TODO: The first three tests no longer make sense, rewrite or delete - def test_cache_key_persistent_across_invocations(self, tmpdir): - code = """ -from firedrake import * -from firedrake.tsfc_interface import TSFCKernel, TSFCKernel_hashkey -from pyop2.caching import _as_hexdigest -mesh = UnitSquareMesh(1, 1) -V = FunctionSpace(mesh, "CG", 1) -u = TrialFunction(V) -v = TestFunction(V) -obj = tsfc_interface.TSFCKernel(inner(u,v)*dx, "mass", parameters["form_compiler"], (), (), None) -key = tsfc_interface.TSFCKernel_hashkey(inner(u,v)*dx, "mass", parameters["form_compiler"], (), (), None)[1] -disk_key = _as_hexdigest((key, obj.__class__.__qualname__)) -with open("{file}", "w") as f: - f.write(disk_key) - """ - filea = tmpdir.join("a") - fileb = tmpdir.join("b") - subprocess.check_call([sys.executable, "-c", code.format(file=filea)]) - subprocess.check_call([sys.executable, "-c", code.format(file=fileb)]) - with filea.open("r") as f: - key1 = f.read() - with fileb.open("r") as f: - key2 = f.read() - assert key1 == key2 - - def test_tsfc_cache_persist_on_disk(self, mass, cache_key): - """TSFCKernel should be persisted on disk.""" - tsfc_interface.TSFCKernel(mass, "mass", 
parameters["form_compiler"], (), (), None) - shard, key = cache_key[:2], cache_key[2:] - assert os.path.exists(os.path.join(tsfc_interface._cachedir, shard, key)) - - def test_tsfc_cache_read_from_disk(self, cache_key): - """Loading an TSFCKernel from disk should yield the right object.""" - disk = DictLikeDiskAccess(tsfc_interface._cachedir) - assert disk[cache_key]._cache_key.value == cache_key - - def test_tsfc_same_form(self, mass): - """Compiling the same form twice should load kernels from cache.""" - k1 = tsfc_interface.compile_form(mass, 'mass') - k2 = tsfc_interface.compile_form(mass, 'mass') - - assert k1 is k2 - assert all(k1_[-1] is k2_[-1] for k1_, k2_ in zip(k1, k2)) - - def test_tsfc_same_mixed_form(self, mixed_mass): - """Compiling a mixed form twice should load kernels from cache.""" - k1 = tsfc_interface.compile_form(mixed_mass, 'mixed_mass') - k2 = tsfc_interface.compile_form(mixed_mass, 'mixed_mass') - - assert k1 is k2 - assert all(k1_[-1] is k2_[-1] for k1_, k2_ in zip(k1, k2)) - - def test_tsfc_different_forms(self, mass, laplace): - """Compiling different forms should not load kernels from cache.""" - k1, = tsfc_interface.compile_form(mass, 'mass') - k2, = tsfc_interface.compile_form(laplace, 'mass') - - assert k1[-1] is not k2[-1] - - def test_tsfc_different_names(self, mass): - """Compiling different forms should not load kernels from cache.""" - k1, = tsfc_interface.compile_form(mass, 'mass') - k2, = tsfc_interface.compile_form(mass, 'laplace') - - assert k1[-1] is not k2[-1] - - def test_tsfc_cell_kernel(self, mass): - k = tsfc_interface.compile_form(mass, 'mass') - assert len(k) == 1 and 'cell_integral' in loopy.generate_code_v2(k[0][1][0].code).device_code() - - def test_tsfc_exterior_facet_kernel(self, rhs): - k = tsfc_interface.compile_form(rhs, 'rhs') - assert len(k) == 1 and 'exterior_facet_integral' in loopy.generate_code_v2(k[0][1][0].code).device_code() - - def test_tsfc_cell_exterior_facet_kernel(self, rhs2): - k = 
tsfc_interface.compile_form(rhs2, 'rhs2') - kernel_name = sorted(k_[1][0].name for k_ in k) - assert len(k) == 2 and 'cell_integral' in kernel_name[0] and \ - 'exterior_facet_integral' in kernel_name[1] +def test_tsfc_same_form(mass): + """Compiling the same form twice should load kernels from cache.""" + k1 = tsfc_interface.compile_form(mass, 'mass') + k2 = tsfc_interface.compile_form(mass, 'mass') + + assert k1 is k2 + assert all(k1_[-1] is k2_[-1] for k1_, k2_ in zip(k1, k2)) + + +def test_tsfc_same_mixed_form(mixed_mass): + """Compiling a mixed form twice should load kernels from cache.""" + k1 = tsfc_interface.compile_form(mixed_mass, 'mixed_mass') + k2 = tsfc_interface.compile_form(mixed_mass, 'mixed_mass') + + assert k1 is k2 + assert all(k1_[-1] is k2_[-1] for k1_, k2_ in zip(k1, k2)) + + +def test_tsfc_different_forms(mass, laplace): + """Compiling different forms should not load kernels from cache.""" + k1, = tsfc_interface.compile_form(mass, 'mass') + k2, = tsfc_interface.compile_form(laplace, 'mass') + + assert k1[-1] is not k2[-1] + + +def test_tsfc_different_names(mass): + """Compiling different forms should not load kernels from cache.""" + k1, = tsfc_interface.compile_form(mass, 'mass') + k2, = tsfc_interface.compile_form(mass, 'laplace') + + assert k1[-1] is not k2[-1] + + +def test_tsfc_cell_kernel(mass): + k = tsfc_interface.compile_form(mass, 'mass') + assert len(k) == 1 and 'cell_integral' in loopy.generate_code_v2(k[0][1][0].code).device_code() + + +def test_tsfc_exterior_facet_kernel(rhs): + k = tsfc_interface.compile_form(rhs, 'rhs') + assert len(k) == 1 and 'exterior_facet_integral' in loopy.generate_code_v2(k[0][1][0].code).device_code() + + +def test_tsfc_cell_exterior_facet_kernel(rhs2): + k = tsfc_interface.compile_form(rhs2, 'rhs2') + kernel_name = sorted(k_[1][0].name for k_ in k) + assert len(k) == 2 and 'cell_integral' in kernel_name[0] and \ + 'exterior_facet_integral' in kernel_name[1] From 409cd5958e14ab0d53c1b7a00aa55cf9816bf62f 
Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Sun, 18 Aug 2024 15:10:51 +0100 Subject: [PATCH 11/21] Remove TSFC kernel hash --- firedrake/tsfc_interface.py | 29 ----------------------------- 1 file changed, 29 deletions(-) diff --git a/firedrake/tsfc_interface.py b/firedrake/tsfc_interface.py index 5fbbab3f72..483df0e494 100644 --- a/firedrake/tsfc_interface.py +++ b/firedrake/tsfc_interface.py @@ -4,7 +4,6 @@ passing to the backends. """ -from hashlib import md5 from os import path, environ, getuid, makedirs import tempfile import collections @@ -54,34 +53,6 @@ ) -# TODO: Remove this as TSFCKernel is no longer cached to disk! -def TSFCKernel_hashkey(*args, **kwargs): - arg_dict = dict(kwargs) - arg_names = [ - "form", "name", "parameters", "coefficient_numbers", - "constant_numbers", "interface", "diagonal" - ] - for k, v in zip(arg_names, args): - arg_dict[k] = v - arg_dict.setdefault("diagonal", False) - comm = arg_dict["form"].ufl_domains()[0].comm - if isinstance(arg_dict["form"], str): - signature = arg_dict["form"] - else: - signature = arg_dict["form"].signature() - parts = ( - signature, - arg_dict["name"], - sorted(arg_dict["parameters"].items()), - arg_dict["coefficient_numbers"], - arg_dict["constant_numbers"], - type(arg_dict["interface"]), - arg_dict["diagonal"] - ) - key = md5(" ".join(map(str, parts)).encode()).hexdigest() - return comm, key - - class TSFCKernel: def __init__( self, From ac169f7fd780cd2af461f408fd0e50c9ef88acd2 Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Sun, 18 Aug 2024 15:27:42 +0100 Subject: [PATCH 12/21] Do TODOs (that I added...) 
--- firedrake/slate/slac/compiler.py | 1 - firedrake/tsfc_interface.py | 24 ++++++++++++++++++++---- 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/firedrake/slate/slac/compiler.py b/firedrake/slate/slac/compiler.py index 5b415b0c04..8e6486d7a3 100644 --- a/firedrake/slate/slac/compiler.py +++ b/firedrake/slate/slac/compiler.py @@ -92,7 +92,6 @@ def _compile_expression_comm(*args, **kwargs): return args[0].ufl_domains()[0].comm -# TODO: Decorate this with a disk/memory cache instead @memory_and_disk_cache( hashkey=_compile_expression_hashkey, comm_fetcher=_compile_expression_comm, diff --git a/firedrake/tsfc_interface.py b/firedrake/tsfc_interface.py index 483df0e494..c20d1d89cd 100644 --- a/firedrake/tsfc_interface.py +++ b/firedrake/tsfc_interface.py @@ -14,12 +14,12 @@ from ufl import Form, conj from .ufl_expr import TestFunction -from tsfc import compile_form as tsfc_compile_form +from tsfc import compile_form as original_tsfc_compile_form from tsfc.parameters import PARAMETERS as tsfc_default_parameters from tsfc.ufl_utils import extract_firedrake_constants from pyop2 import op2 -from pyop2.caching import memory_and_disk_cache +from pyop2.caching import memory_and_disk_cache, default_parallel_hashkey from pyop2.mpi import COMM_WORLD from firedrake.formmanipulation import split_form @@ -53,6 +53,24 @@ ) +def tsfc_compile_form_hashkey(form, prefix, parameters, interface, diagonal, log): + # Drop prefix as it's only used for naming and log + return default_parallel_hashkey(form.signature(), parameters, interface, diagonal) + + +def tsfc_compile_form_comm_fetcher(*args, **kwargs): + # args[0] is a form + return args[0].ufl_domains()[0].comm + + +# Decorate the original tsfc.compile_form with a cache +tsfc_compile_form = memory_and_disk_cache( + hashkey=tsfc_compile_form_hashkey, + comm_fetcher=tsfc_compile_form_comm_fetcher, + cachedir=_cachedir +)(original_tsfc_compile_form) + + class TSFCKernel: def __init__( self, @@ -74,7 +92,6 @@ def __init__( 
:arg interface: the KernelBuilder interface for TSFC (may be None) :arg diagonal: If assembling a matrix is it diagonal? """ - # TODO: wrap tsfc_compile_form in a cache tree = tsfc_compile_form(form, prefix=name, parameters=parameters, interface=interface, diagonal=diagonal, log=PETSc.Log.isActive()) @@ -133,7 +150,6 @@ def _compile_form_comm(*args, **kwargs): return args[0].ufl_domains()[0].comm -# TODO: This should be disk cached @PETSc.Log.EventDecorator() @memory_and_disk_cache( hashkey=_compile_form_hashkey, From 1edbed3a2f79005d57f642c3422221c3200d9197 Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Sun, 18 Aug 2024 17:01:44 +0100 Subject: [PATCH 13/21] Remove incorrect comment --- firedrake/interpolation.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 42f1c9b73c..24bc7ab45d 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -1204,9 +1204,6 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None): def _compile_expression_key(comm, expr, to_element, ufl_element, domain, parameters, log): """Generate a cache key suitable for :func:`tsfc.compile_expression_dual_evaluation`.""" - # Since the caching is collective, this function must return a 2-tuple of - # the form (comm, key) where comm is the communicator the cache is collective over. 
- # FIXME FInAT elements are not safely hashable so we ignore them here key = hash_expr(expr), hash(ufl_element), utils.tuplify(parameters), log return key From e2d2f6d9d0dc937538aa6a3a798c0417bf6438d6 Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Sun, 18 Aug 2024 17:02:39 +0100 Subject: [PATCH 14/21] Add prefix back to the hash as it's being abused by SLATE --- firedrake/tsfc_interface.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/firedrake/tsfc_interface.py b/firedrake/tsfc_interface.py index c20d1d89cd..4532604128 100644 --- a/firedrake/tsfc_interface.py +++ b/firedrake/tsfc_interface.py @@ -55,7 +55,9 @@ def tsfc_compile_form_hashkey(form, prefix, parameters, interface, diagonal, log): # Drop prefix as it's only used for naming and log - return default_parallel_hashkey(form.signature(), parameters, interface, diagonal) + # JBTODO: Can't drop prefix as tests/slate/test_optimise.py::test_partially_optimised fails, investigate + # it looks like the prefix is being used to create different subkernels, which conflicts with the docstring below + return default_parallel_hashkey(form.signature(), prefix, parameters, interface, diagonal) def tsfc_compile_form_comm_fetcher(*args, **kwargs): From 68bb658f28b91a5a8b7649574018c926f0a0278e Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Fri, 23 Aug 2024 15:59:53 +0100 Subject: [PATCH 15/21] Add event decorators for compilation functions --- firedrake/interpolation.py | 1 + firedrake/slate/slac/compiler.py | 1 + firedrake/tsfc_interface.py | 2 +- 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 24bc7ab45d..f1e9fa93d0 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -1212,6 +1212,7 @@ def _compile_expression_key(comm, expr, to_element, ufl_element, domain, paramet hashkey=_compile_expression_key, cachedir=tsfc_interface._cachedir ) +@PETSc.Log.EventDecorator() def compile_expression(comm, 
*args, **kwargs): return compile_expression_dual_evaluation(*args, **kwargs) diff --git a/firedrake/slate/slac/compiler.py b/firedrake/slate/slac/compiler.py index 8e6486d7a3..46611aaf01 100644 --- a/firedrake/slate/slac/compiler.py +++ b/firedrake/slate/slac/compiler.py @@ -97,6 +97,7 @@ def _compile_expression_comm(*args, **kwargs): comm_fetcher=_compile_expression_comm, cachedir=tsfc_interface._cachedir ) +@PETSc.Log.EventDecorator() def compile_expression(slate_expr, compiler_parameters=None): """Takes a Slate expression `slate_expr` and returns the appropriate ``pyop2.op2.Kernel`` object representing the Slate expression. diff --git a/firedrake/tsfc_interface.py b/firedrake/tsfc_interface.py index 4532604128..245663d8c1 100644 --- a/firedrake/tsfc_interface.py +++ b/firedrake/tsfc_interface.py @@ -152,12 +152,12 @@ def _compile_form_comm(*args, **kwargs): return args[0].ufl_domains()[0].comm -@PETSc.Log.EventDecorator() @memory_and_disk_cache( hashkey=_compile_form_hashkey, comm_fetcher=_compile_form_comm, cachedir=_cachedir ) +@PETSc.Log.EventDecorator() def compile_form(form, name, parameters=None, split=True, interface=None, diagonal=False): """Compile a form using TSFC. 
From 137e301eeb748a374c9aff8752c932f42c5540d6 Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Sun, 18 Aug 2024 14:53:46 +0100 Subject: [PATCH 16/21] Use simple partitioner in ensemble tests --- tests/regression/test_ensembleparallelism.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/regression/test_ensembleparallelism.py b/tests/regression/test_ensembleparallelism.py index dd7a8a7157..faa3db99dc 100644 --- a/tests/regression/test_ensembleparallelism.py +++ b/tests/regression/test_ensembleparallelism.py @@ -67,7 +67,7 @@ def ensemble(): def mesh(ensemble): if COMM_WORLD.size == 1: return - return UnitSquareMesh(10, 10, comm=ensemble.comm) + return UnitSquareMesh(10, 10, comm=ensemble.comm, distribution_parameters={"partitioner_type": "simple"}) # mixed function space From 532c0f9ab6909703e8f5a26f6292d65921c85cf4 Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Tue, 27 Aug 2024 18:43:46 +0100 Subject: [PATCH 17/21] Log on all ranks when PYOP2_SPMD_STRICT is enabled --- firedrake/logging.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/firedrake/logging.py b/firedrake/logging.py index 7b524fbe9c..1841f7a4e4 100644 --- a/firedrake/logging.py +++ b/firedrake/logging.py @@ -5,6 +5,7 @@ import tsfc.logging # noqa: F401 import pyop2.logger # noqa: F401 +from pyop2.configuration import configuration from pyop2.mpi import COMM_WORLD @@ -79,7 +80,7 @@ def set_log_handlers(handlers=None, comm=COMM_WORLD): handler = logging.StreamHandler() handler.setFormatter(logging.Formatter(fmt="%(name)s:%(levelname)s %(message)s")) - if comm is not None and comm.rank != 0: + if comm is not None and comm.rank != 0 and not configuration["spmd_strict"]: handler = logging.NullHandler() logger.addHandler(handler) From 46f2ea199c4d72cd265352effc7233f51c7acaa7 Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Tue, 27 Aug 2024 18:44:22 +0100 Subject: [PATCH 18/21] Fix some non-SPMD dat accesses in VOM tests ---
tests/vertexonly/test_vertex_only_fs.py | 7 +++++ .../test_vertex_only_mesh_generation.py | 31 +++++++++++++++++-- 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/tests/vertexonly/test_vertex_only_fs.py b/tests/vertexonly/test_vertex_only_fs.py index 866a31b810..a44791d0ea 100644 --- a/tests/vertexonly/test_vertex_only_fs.py +++ b/tests/vertexonly/test_vertex_only_fs.py @@ -323,10 +323,14 @@ def test_input_ordering_missing_point(): # put data on the input ordering P0DG_input_ordering = FunctionSpace(vm.input_ordering, "DG", 0) data_input_ordering = Function(P0DG_input_ordering) + if vm.comm.rank == 0: data_input_ordering.dat.data_wo[:] = data + # Accessing data_ro [*here] is collective, hence this redundant call + _ = len(data_input_ordering.dat.data_ro) else: data_input_ordering.dat.data_wo[:] = [] + # [*here] assert not len(data_input_ordering.dat.data_ro) # shouldn't have any halos @@ -348,6 +352,9 @@ def test_input_ordering_missing_point(): data_input_ordering.interpolate(data_on_vm) if vm.comm.rank == 0: assert np.allclose(data_input_ordering.dat.data_ro[0:3], 2*data[0:3]) + # [*here] assert np.allclose(data_input_ordering.dat.data_ro[3], data[3]) else: assert not len(data_input_ordering.dat.data_ro) + # Accessing data_ro [*here] is collective, hence this redundant call + _ = len(data_input_ordering.dat.data_ro) diff --git a/tests/vertexonly/test_vertex_only_mesh_generation.py b/tests/vertexonly/test_vertex_only_mesh_generation.py index c3fe5dda12..3ac5bb43ba 100644 --- a/tests/vertexonly/test_vertex_only_mesh_generation.py +++ b/tests/vertexonly/test_vertex_only_mesh_generation.py @@ -144,24 +144,40 @@ def verify_vertexonly_mesh(m, vm, inputvertexcoords, name): total_cells = MPI.COMM_WORLD.allreduce(len(vm.coordinates.dat.data_ro), op=MPI.SUM) total_in_bounds = MPI.COMM_WORLD.allreduce(len(in_bounds), op=MPI.SUM) skip_in_bounds_checks = False + local_cells = len(vm.coordinates.dat.data_ro) if total_cells != total_in_bounds: assert 
MPI.COMM_WORLD.size > 1 # i.e. we're in parallel assert total_cells < total_in_bounds # i.e. some points are duplicated - local_cells = len(vm.coordinates.dat.data_ro) local_in_bounds = len(in_bounds) if not local_cells == local_in_bounds and local_in_bounds > 0: - assert max(ref_cell_dists_l1) > 0.5*m.tolerance + # This assertion needs to happen in parallel! + assertion = (max(ref_cell_dists_l1) > 0.5*m.tolerance) skip_in_bounds_checks = True + else: + assertion = True + else: + assertion = True + # FIXME: Replace with parallel assert when it's merged into pytest-mpi + assert min(MPI.COMM_WORLD.allgather([assertion])) + # Correct local coordinates (though not guaranteed to be in same order) if not skip_in_bounds_checks: # Correct local coordinates (though not guaranteed to be in same order) + # [*here] np.allclose(np.sort(vm.coordinates.dat.data_ro), np.sort(inputvertexcoords[in_bounds])) + else: + # Accessing data_ro [*here] is collective, hence this redundant call + _ = len(vm.coordinates.dat.data_ro) # Correct parent topology assert vm._parent_mesh is m assert vm.topology._parent_mesh is m.topology # Correct generic cell properties if not skip_in_bounds_checks: + # [*here] assert vm.cell_closure.shape == (len(vm.coordinates.dat.data_ro_with_halos), 1) + else: + # Accessing data_ro [*here] is collective, hence this redundant call + _ = len(vm.coordinates.dat.data_ro_with_halos) with pytest.raises(AttributeError): vm.exterior_facets() with pytest.raises(AttributeError): @@ -169,8 +185,13 @@ def verify_vertexonly_mesh(m, vm, inputvertexcoords, name): with pytest.raises(AttributeError): vm.cell_to_facets if not skip_in_bounds_checks: + # [*here] assert vm.num_cells() == vm.cell_closure.shape[0] == len(vm.coordinates.dat.data_ro_with_halos) == vm.cell_set.total_size assert vm.cell_set.size == len(inputvertexcoords[in_bounds]) == len(vm.coordinates.dat.data_ro) + else: + # Accessing data_ro and data_ro_with_halos [*here] is collective, hence this redundant call + _ = 
len(vm.coordinates.dat.data_ro_with_halos) + _ = len(vm.coordinates.dat.data_ro) assert vm.num_facets() == 0 assert vm.num_faces() == vm.num_entities(2) == 0 assert vm.num_edges() == vm.num_entities(1) == 0 @@ -257,11 +278,17 @@ def test_generate_cell_midpoints(parentmesh, redundant): out_of_mesh_point = np.full((1, parentmesh.geometric_dimension()), np.inf) for i in range(max_len): if i < len(vm.coordinates.dat.data_ro): + # [*here] cell_num = parentmesh.locate_cell(vm.coordinates.dat.data_ro[i]) else: cell_num = parentmesh.locate_cell(out_of_mesh_point) # should return None + # Accessing data_ro [*here] is collective, hence this redundant call + _ = len(vm.coordinates.dat.data_ro) if cell_num is not None: assert (f.dat.data_ro[cell_num] == vm.coordinates.dat.data_ro[i]).all() + else: + _ = len(f.dat.data_ro) + _ = len(vm.coordinates.dat.data_ro) # Have correct pyop2 labels as implied by cell set sizes if parentmesh.extruded: From 785b22c32254dcef5c0e07ef56649cbefc6a603c Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Tue, 8 Oct 2024 14:58:58 +0100 Subject: [PATCH 19/21] Apply suggestions from code review --- firedrake/slate/slac/compiler.py | 6 ------ firedrake/tsfc_interface.py | 2 -- 2 files changed, 8 deletions(-) diff --git a/firedrake/slate/slac/compiler.py b/firedrake/slate/slac/compiler.py index 46611aaf01..7dedd21e10 100644 --- a/firedrake/slate/slac/compiler.py +++ b/firedrake/slate/slac/compiler.py @@ -67,12 +67,6 @@ cell_to_facets_dtype = np.dtype(np.int8) -def _cache_key(expr, compiler_parameters): - return expr.ufl_domains()[0].comm, md5( - (expr.expression_hash + str(sorted(compiler_parameters.items()))).encode() - ).hexdigest() - - class SlateKernel(TSFCKernel): def __init__(self, expr, compiler_parameters): self.split_kernel = generate_loopy_kernel(expr, compiler_parameters) diff --git a/firedrake/tsfc_interface.py b/firedrake/tsfc_interface.py index 245663d8c1..a4a57ae0cb 100644 --- a/firedrake/tsfc_interface.py +++ 
b/firedrake/tsfc_interface.py @@ -55,8 +55,6 @@ def tsfc_compile_form_hashkey(form, prefix, parameters, interface, diagonal, log): # Drop prefix as it's only used for naming and log - # JBTODO: Can't drop prefix as tests/slate/test_optimise.py::test_partially_optimised fails, investigate - # it looks like the prefix is being used to create different subkernels, which conflicts with the docstring below return default_parallel_hashkey(form.signature(), prefix, parameters, interface, diagonal) From 67bbaa20b2ce854b8e295a32ee26ad867d19e6ab Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Tue, 8 Oct 2024 15:00:55 +0100 Subject: [PATCH 20/21] Don't need md5 --- firedrake/slate/slac/compiler.py | 1 - 1 file changed, 1 deletion(-) diff --git a/firedrake/slate/slac/compiler.py b/firedrake/slate/slac/compiler.py index 7dedd21e10..1333a04039 100644 --- a/firedrake/slate/slac/compiler.py +++ b/firedrake/slate/slac/compiler.py @@ -8,7 +8,6 @@ expressions (finite element variational forms written in UFL). """ import time -from hashlib import md5 from firedrake_citations import Citations from firedrake.tsfc_interface import SplitKernel, KernelInfo, TSFCKernel From d082a6d223d68ba411226b10b4e292b3190b6603 Mon Sep 17 00:00:00 2001 From: "David A. Ham" Date: Wed, 9 Oct 2024 16:45:44 +0100 Subject: [PATCH 21/21] Update .github/workflows/build.yml Co-authored-by: Jack Betteridge <43041811+JDBetteridge@users.noreply.github.com> --- .github/workflows/build.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 792d28cddf..3c26f4d7a9 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -83,7 +83,6 @@ jobs: --install defcon \ --install gadopt \ --install asQ \ - --package-branch PyOP2 JDBetteridge/remove_comm_hash \ || (cat firedrake-install.log && /bin/false) - name: Install test dependencies run: |