Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dialects: (dmp) Emit multiple swaps if halo is larger than core size #3238

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 42 additions & 1 deletion tests/dialects/test_dmp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

def flat_face_exchanges(
shape: ShapeAttr, dim: int
) -> tuple[ExchangeDeclarationAttr, ExchangeDeclarationAttr]:
) -> tuple[ExchangeDeclarationAttr, ...]:
# we need access to the _flat_face_exchanges_for_dim method in order to test it
# since this is a private function, and pyright will yell whenever it's accessed,
# we have this wrapper function here that takes care of making the private publicly
Expand Down Expand Up @@ -77,3 +77,44 @@ def test_decomp_flat_face_4d():
assert ex_neg_y.size == (10, 4, 10, 10)
assert ex_neg_y.source_offset == (0, 4, 0, 0)
assert ex_neg_y.neighbor == (0, -1, 0, 0)


def test_decomp_with_overflow():
shape = ShapeAttr.from_index_attrs(
(0, 0, 0), # buff lb
(2, 2, 2), # core lb
(3, 3, 102), # core ub
(5, 5, 104), # buff ub
)

exchanges = tuple(flat_face_exchanges(shape, 0))
assert len(exchanges) == 4

ex_px_1, ex_px_2, ex_nx_1, ex_nx_2 = exchanges

# all exchanges are of size (1, 1, 100)
assert all(ex.size == (1, 1, 100) for ex in exchanges)

# the first exchange is closer to the core region
assert ex_px_1.offset == (3, 2, 2)
# and has a source offset of (-1, 0, 0)
assert ex_px_1.source_offset == (-1, 0, 0)
# the second exchange is farther away
assert ex_px_2.offset == (4, 2, 2)
# and has a source offset of twice that
assert ex_px_2.source_offset == (-2, 0, 0)

# same for negative x, first exchange is closer to the core
assert ex_nx_1.offset == (1, 2, 2)
# and has a source offset of (1, 0, 0)
assert ex_nx_1.source_offset == (1, 0, 0)
# second is farther away
assert ex_nx_2.offset == (0, 2, 2)
# and has a source offset of (2, 0, 0)
assert ex_nx_2.source_offset == (2, 0, 0)

assert ex_px_1.neighbor == (1, 0, 0)
assert ex_px_2.neighbor == (2, 0, 0)

assert ex_nx_1.neighbor == (-1, 0, 0)
assert ex_nx_2.neighbor == (-2, 0, 0)
109 changes: 84 additions & 25 deletions xdsl/dialects/experimental/dmp.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,8 @@ def from_points(
sizes,
# source_offset (opposite of exchange direction)
tuple(
0 if d != dim else -1 * dir_sign * sizes[dim] for d in range(len(sizes))
0 if d != dim else -1 * dir_sign * sizes[dim] * neighbor_offset
for d in range(len(sizes))
),
# direction
tuple(
Expand Down Expand Up @@ -545,44 +546,102 @@ def comm_layout(self) -> RankTopoAttr:

def _flat_face_exchanges_for_dim(
shape: ShapeAttr, axis: int
) -> tuple[ExchangeDeclarationAttr, ExchangeDeclarationAttr]:
) -> tuple[ExchangeDeclarationAttr, ...]:
"""
Generate the two exchange delcarations to exchange the faces on the
axis "axis".
"""
dimensions = shape.dims
assert axis <= dimensions

def coords(where: Literal["start", "end"]):
for d in range(dimensions):
# for the dim we want to exchange, return either start or end halo region
if d == axis:
if where == "start":
# "start" halo goes from buffer start to core start
yield shape.buffer_start(d), shape.core_start(d)
def coords(where: Literal["start", "end"]) -> Iterable[tuple[tuple[int, int], ...]]:
"""
Generate a series of swaps that need to be performed to exchange along "axis".

A swap is a set of (lb,ub) tuples, one per axis of shape.

Takes either "start" or "end" to signify if the lower (buffer start to core start) or upper
(core end to buffer end) parts of the halo should be exchanged.

We need to make sure that if core_size is smaller than halo size, we emit multiple exchanges.

We need to make sure that we emit the exchanges in a way that the closest neighbor is emitted first.
"""
# we may need to issue multiple swaps per direction, if the core size is smaller than the
# exchanged size. This is tracked in the "slice" variable.
slice = 0

while True:
swap: list[tuple[int, int]] = []
for d in range(dimensions):
# for the dim we want to exchange, return exchanges need to exchange either start or end
# halo regions
if d == axis:
core_size = shape.core_size(d)
if where == "start":
# where == "start" halo goes from buffer start to core start
# the window of data we want to send starts here
start = shape.buffer_start(d)
# calculate where the current slice starts (lowest index, no lower than start)
slice_start = max(
start, shape.core_start(d) - (core_size * (slice + 1))
)
# calculate where the current slice ends (highest index, no higher than core_start)
# because slice >= 0
slice_end = max(
start, shape.core_start(d) - (core_size * slice)
)

# stop swapping if swap is empty
if slice_end == slice_start:
return
swap.append((slice_start, slice_end))
else:
# where == "end" halo goes from core end to buffer end

# the window of data we want to send ends here (highest index)
end = shape.buffer_end(d)
# calculate where the current slice starts (lowest index, no lower than start)
# because slice >= 0, and no higher than end
slice_start = min(end, shape.core_end(d) + (core_size * slice))
# calculate where the current slice ends (highest index, no higher than core_start)
slice_end = min(
end, shape.core_end(d) + (core_size * (slice + 1))
)

# stop swapping if swap is empty
if slice_end == slice_start:
return
swap.append((slice_start, slice_end))

else:
# "end" halo goes from core end to buffer end
yield shape.core_end(d), shape.buffer_end(d)
else:
# for the sliced regions, "extrude" from core
# this way we don't exchange edges
yield shape.core_start(d), shape.core_end(d)
# for the sliced regions, "extrude" from core
# this way we don't exchange edges
swap.append((shape.core_start(d), shape.core_end(d)))

ex1_coords = tuple(coords("end"))
ex2_coords = tuple(coords("start"))
slice += 1
yield tuple(swap)

return (
# towards positive dim:
ExchangeDeclarationAttr.from_points(
ex1_coords,
axis,
dir_sign=1,
*(
ExchangeDeclarationAttr.from_points(
ex1_coords,
axis,
dir_sign=1,
neighbor_offset=i + 1,
)
for i, ex1_coords in enumerate(coords("end"))
),
# towards negative dim:
ExchangeDeclarationAttr.from_points(
ex2_coords,
axis,
dir_sign=-1,
*(
ExchangeDeclarationAttr.from_points(
ex2_coords,
axis,
dir_sign=-1,
neighbor_offset=i + 1,
)
for i, ex2_coords in enumerate(coords("start"))
),
)

Expand Down
Loading