diff --git a/firedrake/mg/utils.py b/firedrake/mg/utils.py index c716a99676..353741e464 100644 --- a/firedrake/mg/utils.py +++ b/firedrake/mg/utils.py @@ -8,23 +8,28 @@ from firedrake.cython import mgimpl as impl -def get_or_set_mg_hierarchy_map_cache(cache_dict, entity_dofs_key, - create_map_on_cpu): +def get_or_set_mg_hierarchy_map_cache(cache_dict, entity_dofs_key, create_map_on_cpu): """ :arg cache_dict: An instance of :class:`dict` that maps from tuple ``(entity_dofs_key, compute_backend)`` to the corresponding map. - :arg create_host_map: A callable that takes no argument and returns the map + :arg create_map_on_cpu: A callable that takes no argument and returns the map on the CPU backend. :returns map: An instance of :class:`pyop2.base.Map`. """ try: return cache_dict[(entity_dofs_key, op2.compute_backend)] except KeyError: + from pyop2.sequential import cpu_backend - host_map = cache_dict.setdefault((entity_dofs_key, - cpu_backend), create_map_on_cpu()) - return cache_dict.setdefault((entity_dofs_key, op2.compute_backend), - op2.compute_backend.Map(host_map)) + if (entity_dofs_key, cpu_backend) not in cache_dict: + cache_dict[(entity_dofs_key, cpu_backend)] = create_map_on_cpu() + + map_on_cpu = cache_dict[(entity_dofs_key, cpu_backend)] + + if (entity_dofs_key, op2.compute_backend) not in cache_dict: + cache_dict[(entity_dofs_key, op2.compute_backend)] = op2.compute_backend.Map(map_on_cpu) + + return cache_dict[(entity_dofs_key, op2.compute_backend)] def fine_node_to_coarse_node_map(Vf, Vc):