Skip to content

Commit

Permalink
Correct number of candidates for prolong/restrict on extruded meshes (#…
Browse files Browse the repository at this point in the history
…3148)

* Correct number of candidates for prolong/restrict on extruded meshes
* Test high order prolongation/restriction
  • Loading branch information
pbrubeck committed Oct 25, 2023
1 parent 8fc1f21 commit 6323219
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 10 deletions.
15 changes: 5 additions & 10 deletions firedrake/mg/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,10 +214,8 @@ def prolong_kernel(expression):
if meshc.cell_set._extruded:
idx = levelf * hierarchy.refinements_per_level
assert idx == int(idx)
level_ratio = (hierarchy._meshes[int(idx)].layers - 1) // (meshc.layers - 1)
else:
level_ratio = 1
key = (("prolong", level_ratio)
assert hierarchy._meshes[int(idx)].cell_set._extruded
key = (("prolong",)
+ expression.ufl_element().value_shape()
+ entity_dofs_key(expression.function_space().finat_element.entity_dofs())
+ entity_dofs_key(coordinates.function_space().finat_element.entity_dofs()))
Expand Down Expand Up @@ -280,7 +278,7 @@ def prolong_kernel(expression):
""" % {"to_reference": str(to_reference_kernel),
"evaluate": eval_code,
"spacedim": element.cell.get_spatial_dimension(),
"ncandidate": hierarchy.fine_to_coarse_cells[levelf].shape[1] * level_ratio,
"ncandidate": hierarchy.fine_to_coarse_cells[levelf].shape[1],
"Rdim": numpy.prod(element.value_shape),
"inside_cell": inside_check(element.cell, eps=1e-8, X="Xref"),
"celldist_l1_c_expr": celldist_l1_c_expr(element.cell, X="Xref"),
Expand All @@ -298,10 +296,7 @@ def restrict_kernel(Vf, Vc):
coordinates = Vc.ufl_domain().coordinates
if Vf.extruded:
assert Vc.extruded
level_ratio = (Vf.mesh().layers - 1) // (Vc.mesh().layers - 1)
else:
level_ratio = 1
key = (("restrict", level_ratio)
key = (("restrict",)
+ Vf.ufl_element().value_shape()
+ entity_dofs_key(Vf.finat_element.entity_dofs())
+ entity_dofs_key(Vc.finat_element.entity_dofs())
Expand Down Expand Up @@ -366,7 +361,7 @@ def restrict_kernel(Vf, Vc):
}
""" % {"to_reference": str(to_reference_kernel),
"evaluate": evaluate_code,
"ncandidate": hierarchy.fine_to_coarse_cells[levelf].shape[1]*level_ratio,
"ncandidate": hierarchy.fine_to_coarse_cells[levelf].shape[1],
"inside_cell": inside_check(element.cell, eps=1e-8, X="Xref"),
"celldist_l1_c_expr": celldist_l1_c_expr(element.cell, X="Xref"),
"Xc_cell_inc": coords_element.space_dimension(),
Expand Down
73 changes: 73 additions & 0 deletions tests/multigrid/test_grid_transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,3 +194,76 @@ def test_grid_transfer_parallel(hierarchy, transfer_type):
run_restriction(hierarchy, vector, space, degrees)
elif transfer_type == "prolongation":
run_prolongation(hierarchy, vector, space, degrees)


@pytest.fixture(params=["interval-interval",
"quadrilateral",
"quadrilateral-interval",
"hexahedron"], scope="module")
def deformed_cell(request):
return request.param


@pytest.fixture(scope="module")
def deformed_hierarchy(deformed_cell):
cells = deformed_cell.split("-")
extruded = len(cells) == 2
cube = cells[0] in ["quadrilateral", "hexahedron"]
if cells[0] == "interval":
base_dim = 1
elif cells[0] in ["triangle", "quadrilateral"]:
base_dim = 2
elif cells[0] == "hexahedron":
base_dim = 3

nx = 2
if base_dim == 1:
base = UnitIntervalMesh(nx)
elif base_dim == 2:
base = UnitSquareMesh(nx, nx, quadrilateral=cube)
elif base_dim == 3:
base = UnitCubeMesh(nx, nx, nx, hexahedral=cube)
refine = 1
hierarchy = MeshHierarchy(base, refine)
if extruded:
height = 1
hierarchy = ExtrudedMeshHierarchy(hierarchy, height, base_layer=nx)

# Deform into disk/cylinder sector
rmin = 1
rmax = 2
tmin = -pi/4
tmax = pi/4
for mesh in hierarchy:
x = mesh.coordinates.dat.data_ro
R = (rmax - rmin) * x[:, 0] + rmin
T = (tmax - tmin) * x[:, 1] + tmin
mesh.coordinates.dat.data_wo[:, 0] = R * numpy.cos(T)
mesh.coordinates.dat.data_wo[:, 1] = R * numpy.sin(T)
return hierarchy


@pytest.fixture(params=["injection", "restriction", "prolongation"])
def deformed_transfer_type(request, deformed_hierarchy):
if not deformed_hierarchy.nested and request.param == "injection":
return pytest.mark.xfail(reason="Supermesh projections not implemented yet")(request.param)
else:
return request.param


def test_grid_transfer_deformed(deformed_hierarchy, deformed_transfer_type):
space = "Lagrange"
degrees = (1, 2)
vector = False
if not deformed_hierarchy.nested and deformed_transfer_type == "injection":
pytest.skip("Not implemented")
if deformed_transfer_type == "injection":
if space in {"DG", "DQ"} and complex_mode:
with pytest.raises(NotImplementedError):
run_injection(deformed_hierarchy, vector, space, degrees[:1])
else:
run_injection(deformed_hierarchy, vector, space, degrees[:1])
elif deformed_transfer_type == "restriction":
run_restriction(deformed_hierarchy, vector, space, degrees)
elif deformed_transfer_type == "prolongation":
run_prolongation(deformed_hierarchy, vector, space, degrees)

0 comments on commit 6323219

Please sign in to comment.