From 1853bf516f28e5166a4f6e927754f76caab2cc9e Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Thu, 3 Oct 2024 11:01:50 +0100 Subject: [PATCH 1/3] dialects: (dmp) Emit multiple swaps if halo is larger than core size --- xdsl/dialects/experimental/dmp.py | 102 ++++++++++++++++++++++-------- 1 file changed, 77 insertions(+), 25 deletions(-) diff --git a/xdsl/dialects/experimental/dmp.py b/xdsl/dialects/experimental/dmp.py index f0767c4a1a..70cdd5546b 100644 --- a/xdsl/dialects/experimental/dmp.py +++ b/xdsl/dialects/experimental/dmp.py @@ -13,7 +13,7 @@ from abc import ABC from collections.abc import Iterable, Sequence from math import prod -from typing import Literal, cast +from typing import Literal, cast, TypeAlias from xdsl.dialects import builtin, stencil from xdsl.ir import Attribute, Dialect, Operation, ParametrizedAttribute, SSAValue @@ -545,7 +545,7 @@ def comm_layout(self) -> RankTopoAttr: def _flat_face_exchanges_for_dim( shape: ShapeAttr, axis: int -) -> tuple[ExchangeDeclarationAttr, ExchangeDeclarationAttr]: +) -> tuple[ExchangeDeclarationAttr, ...]: """ Generate the two exchange delcarations to exchange the faces on the axis "axis". @@ -553,36 +553,88 @@ def _flat_face_exchanges_for_dim( dimensions = shape.dims assert axis <= dimensions - def coords(where: Literal["start", "end"]): - for d in range(dimensions): - # for the dim we want to exchange, return either start or end halo region - if d == axis: - if where == "start": - # "start" halo goes from buffer start to core start - yield shape.buffer_start(d), shape.core_start(d) + def coords(where: Literal["start", "end"]) -> Iterable[tuple[tuple[int, int], ...]]: + """ + Generate a series of swaps that need to be performed to exchange along "axis". + + A swap is a set of (lb,ub) tuples, one per axis of shape. + + Takes either "start" or "end" to signify if the lower (buffer start to core start) or upper + (core end to buffer end) parts of the halo should be exchanged. 
+ + We need to make sure that if core_size is smaller than halo size, we emit multiple exchanges. + + We need to make sure that we emit the exchanges in a way that the closest neighbor is emitted first. + """ + # we may need to issue multiple swaps per direction, if the core size is smaller than the + # exchanged size. This is tracked in the "slice" variable. + slice = 0 + + while True: + swap: list[tuple[int, int]] = [] + for d in range(dimensions): + # for the dim we want to exchange, we need to exchange either the start or the end + # halo region + if d == axis: + core_size = shape.core_size(d) + if where == "start": + # where == "start" halo goes from buffer start to core start + # the window of data we want to send starts here + start = shape.buffer_start(d) + # calculate where the current slice starts (lowest index, no lower than start) + slice_start = max( + start, shape.core_start(d) - (core_size * (slice + 1)) + ) + # calculate where the current slice ends (highest index, no higher than core_start) + # because slice >= 0 + slice_end = max( + start, shape.core_start(d) - (core_size * slice) + ) + + # stop swapping if swap is empty + if slice_end == slice_start: + return + swap.append((slice_start, slice_end)) + else: + # where == "end" halo goes from core end to buffer end + + # the window of data we want to send ends here (highest index) + end = shape.buffer_end(d) + # calculate where the current slice starts (lowest index, no lower than core_end) + # because slice >= 0, and no higher than end + slice_start = min(end, shape.core_end(d) + (core_size * slice)) + # calculate where the current slice ends (highest index, no higher than end) + slice_end = min( + end, shape.core_end(d) + (core_size * (slice + 1)) + ) + + # stop swapping if swap is empty + if slice_end == slice_start: + return + swap.append((slice_start, slice_end)) + else: - # "end" halo goes from core end to buffer end - yield shape.core_end(d), shape.buffer_end(d) - else: - # for the 
sliced regions, "extrude" from core - # this way we don't exchange edges - yield shape.core_start(d), shape.core_end(d) + # for the sliced regions, "extrude" from core + # this way we don't exchange edges + swap.append((shape.core_start(d), shape.core_end(d))) - ex1_coords = tuple(coords("end")) - ex2_coords = tuple(coords("start")) + slice += 1 + yield tuple(swap) return ( # towards positive dim: - ExchangeDeclarationAttr.from_points( - ex1_coords, - axis, - dir_sign=1, + *( + ExchangeDeclarationAttr.from_points( + ex1_coords, axis, dir_sign=1, neighbor_offset=i + 1 + ) + for i, ex1_coords in enumerate(coords("end")) ), # towards negative dim: - ExchangeDeclarationAttr.from_points( - ex2_coords, - axis, - dir_sign=-1, + *( + ExchangeDeclarationAttr.from_points( + ex2_coords, axis, dir_sign=-1, neighbor_offset=i + 1 + ) + for i, ex2_coords in enumerate(coords("start")) ), ) From 9f79087f75ab4f34852f73389f27a169fac626ad Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Thu, 3 Oct 2024 11:09:46 +0100 Subject: [PATCH 2/3] misc: formatting --- xdsl/dialects/experimental/dmp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xdsl/dialects/experimental/dmp.py b/xdsl/dialects/experimental/dmp.py index 70cdd5546b..38fa1736ad 100644 --- a/xdsl/dialects/experimental/dmp.py +++ b/xdsl/dialects/experimental/dmp.py @@ -13,7 +13,7 @@ from abc import ABC from collections.abc import Iterable, Sequence from math import prod -from typing import Literal, cast, TypeAlias +from typing import Literal, cast from xdsl.dialects import builtin, stencil from xdsl.ir import Attribute, Dialect, Operation, ParametrizedAttribute, SSAValue From 868811152531fa9c5b28cef3496a38c2ed35d7d6 Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Thu, 3 Oct 2024 11:38:47 +0100 Subject: [PATCH 3/3] tests: Add tests, and fix minor bugs --- tests/dialects/test_dmp.py | 43 ++++++++++++++++++++++++++++++- xdsl/dialects/experimental/dmp.py | 13 +++++++--- 2 files changed, 52 insertions(+), 4 
deletions(-) diff --git a/tests/dialects/test_dmp.py b/tests/dialects/test_dmp.py index 0b95e5195e..24772ccf0b 100644 --- a/tests/dialects/test_dmp.py +++ b/tests/dialects/test_dmp.py @@ -4,7 +4,7 @@ def flat_face_exchanges( shape: ShapeAttr, dim: int -) -> tuple[ExchangeDeclarationAttr, ExchangeDeclarationAttr]: +) -> tuple[ExchangeDeclarationAttr, ...]: # we need access to the _flat_face_exchanges_for_dim method in order to test it # since this is a private function, and pyright will yell whenever it's accessed, # we have this wrapper function here that takes care of making the private publicly @@ -77,3 +77,44 @@ def test_decomp_flat_face_4d(): assert ex_neg_y.size == (10, 4, 10, 10) assert ex_neg_y.source_offset == (0, 4, 0, 0) assert ex_neg_y.neighbor == (0, -1, 0, 0) + + +def test_decomp_with_overflow(): + shape = ShapeAttr.from_index_attrs( + (0, 0, 0), # buff lb + (2, 2, 2), # core lb + (3, 3, 102), # core ub + (5, 5, 104), # buff ub + ) + + exchanges = tuple(flat_face_exchanges(shape, 0)) + assert len(exchanges) == 4 + + ex_px_1, ex_px_2, ex_nx_1, ex_nx_2 = exchanges + + # all exchanges are of size (1, 1, 100) + assert all(ex.size == (1, 1, 100) for ex in exchanges) + + # the first exchange is closer to the core region + assert ex_px_1.offset == (3, 2, 2) + # and has a source offset of (-1, 0, 0) + assert ex_px_1.source_offset == (-1, 0, 0) + # the second exchange is farther away + assert ex_px_2.offset == (4, 2, 2) + # and has a source offset of twice that + assert ex_px_2.source_offset == (-2, 0, 0) + + # same for negative x, first exchange is closer to the core + assert ex_nx_1.offset == (1, 2, 2) + # and has a source offset of (1, 0, 0) + assert ex_nx_1.source_offset == (1, 0, 0) + # second is farther away + assert ex_nx_2.offset == (0, 2, 2) + # and has a source offset of (2, 0, 0) + assert ex_nx_2.source_offset == (2, 0, 0) + + assert ex_px_1.neighbor == (1, 0, 0) + assert ex_px_2.neighbor == (2, 0, 0) + + assert ex_nx_1.neighbor == (-1, 0, 0) + 
assert ex_nx_2.neighbor == (-2, 0, 0) diff --git a/xdsl/dialects/experimental/dmp.py b/xdsl/dialects/experimental/dmp.py index 38fa1736ad..07ae2336ec 100644 --- a/xdsl/dialects/experimental/dmp.py +++ b/xdsl/dialects/experimental/dmp.py @@ -116,7 +116,8 @@ def from_points( sizes, # source_offset (opposite of exchange direction) tuple( - 0 if d != dim else -1 * dir_sign * sizes[dim] for d in range(len(sizes)) + 0 if d != dim else -1 * dir_sign * sizes[dim] * neighbor_offset + for d in range(len(sizes)) ), # direction tuple( @@ -625,14 +626,20 @@ def coords(where: Literal["start", "end"]) -> Iterable[tuple[tuple[int, int], .. # towards positive dim: *( ExchangeDeclarationAttr.from_points( - ex1_coords, axis, dir_sign=1, neighbor_offset=i + 1 + ex1_coords, + axis, + dir_sign=1, + neighbor_offset=i + 1, ) for i, ex1_coords in enumerate(coords("end")) ), # towards negative dim: *( ExchangeDeclarationAttr.from_points( - ex2_coords, axis, dir_sign=-1, neighbor_offset=i + 1 + ex2_coords, + axis, + dir_sign=-1, + neighbor_offset=i + 1, ) for i, ex2_coords in enumerate(coords("start")) ),