Skip to content

Commit

Permalink
[Disco][QoL] Implement broadcast/scatter methods for Session (#17035)
Browse files Browse the repository at this point in the history
* [Disco][QoL] Implement broadcast/scatter methods for Session

Prior to this commit, use of the `disco.Session` API to broadcast or
scatter an array required several steps from the caller.

1. Allocate memory on worker0
2. Transfer data from the controller to worker0
3. Allocate memory on each worker
4. Broadcast/scatter data from worker0 to all workers

While exposing these steps is necessary for performance, especially
when used repeatedly, it can be tedious/error-prone to use for
initialization that is only performed once.

This commit adds utility methods `Session.broadcast` and
`Session.scatter`, which are implemented in terms of the existing
lower-level methods `Session.broadcast_from_worker0` and
`Session.scatter_from_worker0`.  These methods perform the transfer
from the controller to worker0, and from worker0 to all other
workers.

* lint fix
  • Loading branch information
Lunderberg authored May 30, 2024
1 parent f6aab98 commit 7c2c0d9
Show file tree
Hide file tree
Showing 2 changed files with 158 additions and 14 deletions.
102 changes: 96 additions & 6 deletions python/tvm/runtime/disco/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,17 +249,34 @@ def copy_from_worker_0(self, host_array: NDArray, remote_array: DRef) -> None:
"""
return _ffi_api.SessionCopyFromWorker0(self, host_array, remote_array) # type: ignore # pylint: disable=no-member

def copy_to_worker_0(self, host_array: NDArray, remote_array: DRef) -> None:
def copy_to_worker_0(self, host_array: NDArray, remote_array: Optional[DRef] = None) -> DRef:
"""Copy the controller-side NDArray to worker-0.
Parameters
----------
host_array : numpy.ndarray
The array to be copied from worker-0.
remote_array : NDArray
The NDArray on worker-0.
host_array : NDArray
The array to be copied to worker-0.
remote_array : Optional[DRef]
The destination NDArray on worker-0.
Returns
-------
output_array: DRef
The DRef containing the copied data on worker0, and
NullOpt on all other workers. If `remote_array` was
provided, this return value is the same as `remote_array`.
Otherwise, it is the newly allocated space.
"""
return _ffi_api.SessionCopyToWorker0(self, host_array, remote_array) # type: ignore # pylint: disable=no-member
if remote_array is None:
remote_array = self.empty(host_array.shape, host_array.dtype, worker0_only=True)

_ffi_api.SessionCopyToWorker0(self, host_array, remote_array) # type: ignore # pylint: disable=no-member
return remote_array

def load_vm_module(
self,
Expand Down Expand Up @@ -302,6 +319,40 @@ def init_ccl(self, ccl: str, *device_ids):
_ffi_api.SessionInitCCL(self, ccl, ShapeTuple(device_ids)) # type: ignore # pylint: disable=no-member
self._clear_ipc_memory_pool()

def broadcast(self, src: Union[np.ndarray, NDArray], dst: Optional[DRef] = None) -> DRef:
    """Send a controller-side array to every worker.

    The data is first copied from the controller to worker-0 (via
    `copy_to_worker_0`) and then distributed from worker-0 to all
    workers (via `broadcast_from_worker0`).

    Parameters
    ----------
    src: Union[np.ndarray, NDArray]
        The array to be broadcast.  Anything that is not already
        an NDArray is converted with `_as_NDArray` first.

    dst: Optional[DRef]
        The output array.  When omitted, an array with the same
        shape and dtype as `src` is allocated on each worker.

    Returns
    -------
    output_array: DRef
        The DRef holding the broadcast data on all workers.  This
        is `dst` itself when `dst` was supplied, and the newly
        allocated array otherwise.
    """
    host_array = src if isinstance(src, NDArray) else _as_NDArray(src)

    # The per-worker destination is allocated before the worker-0
    # staging copy, so remote allocations happen in the same order
    # regardless of how this method is written.
    output = self.empty(host_array.shape, host_array.dtype) if dst is None else dst

    # No destination is passed here, so `copy_to_worker_0` allocates
    # the staging buffer itself.
    staged_on_worker0 = self.copy_to_worker_0(host_array)
    self.broadcast_from_worker0(staged_on_worker0, output)

    return output

def broadcast_from_worker0(self, src: DRef, dst: DRef) -> DRef:
"""Broadcast an array from worker-0 to all other workers.
Expand All @@ -313,6 +364,45 @@ def broadcast_from_worker0(self, src: DRef, dst: DRef) -> DRef:
func = self._get_cached_method("runtime.disco.broadcast_from_worker0")
func(src, dst)

def scatter(self, src: Union[np.ndarray, NDArray], dst: Optional[DRef] = None) -> DRef:
    """Scatter an array across all workers.

    The data is first copied from the controller to worker-0 (via
    `copy_to_worker_0`) and then scattered from worker-0 to all
    workers (via `scatter_from_worker0`).

    Parameters
    ----------
    src: Union[np.ndarray, NDArray]
        The array to be scattered.  The first dimension of this
        array, `src.shape[0]`, must be equal to the number of
        workers.

    dst: Optional[DRef]
        The output array.  If None, an array of shape
        `src.shape[1:]` and the same dtype as `src` will be
        allocated on each worker.

    Returns
    -------
    output_array: DRef
        The DRef containing the scattered data on all workers.
        If `dst` was provided, this return value is the same as
        `dst`.  Otherwise, it is the newly allocated space.

    Raises
    ------
    ValueError
        If `src` has no dimensions, or if its first dimension
        differs from `self.num_workers`.
    """
    # Validate with an explicit exception rather than `assert`:
    # assertions are stripped under `python -O`, which would let a
    # mismatched array through and allocate a wrongly shaped `dst`.
    src_shape = tuple(src.shape)
    if not src_shape or src_shape[0] != self.num_workers:
        raise ValueError(
            "Session.scatter requires the first dimension of `src` to equal "
            f"num_workers ({self.num_workers}), but `src` has shape {src_shape}"
        )

    if not isinstance(src, NDArray):
        src = _as_NDArray(src)

    if dst is None:
        # Each worker receives one slice along the outermost dimension.
        dst = self.empty(src.shape[1:], src.dtype)

    src_dref = self.copy_to_worker_0(src)
    self.scatter_from_worker0(src_dref, dst)

    return dst

def scatter_from_worker0(self, from_array: DRef, to_array: DRef) -> None:
"""Scatter an array from worker-0 to all other workers.
Expand Down
70 changes: 62 additions & 8 deletions tests/python/disco/test_ccl.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,33 +103,87 @@ def test_allgather(session_kind, ccl):

@pytest.mark.parametrize("session_kind", _all_session_kinds)
@pytest.mark.parametrize("ccl", _ccl)
def test_broadcast_from_worker0(session_kind, ccl):
@pytest.mark.parametrize("use_explicit_output", [True, False])
def test_broadcast_from_worker0(session_kind, ccl, use_explicit_output):
devices = [0, 1]
sess = session_kind(num_workers=len(devices))
sess.init_ccl(ccl, *devices)

array = np.arange(12, dtype="float32").reshape(3, 4)
d_array = sess.empty((3, 4), "float32", worker0_only=True)
d_array.debug_copy_from(0, array)
dst_array = sess.empty((3, 4), "float32")
sess.broadcast_from_worker0(d_array, dst_array)

if use_explicit_output:
src_array = sess.empty((3, 4), "float32", worker0_only=True)
src_array.debug_copy_from(0, array)
dst_array = sess.empty((3, 4), "float32")
sess.broadcast_from_worker0(src_array, dst_array)
else:
dst_array = sess.broadcast(array)

result = dst_array.debug_get_from_remote(1).numpy()
np.testing.assert_equal(result, array)


@pytest.mark.parametrize("session_kind", _all_session_kinds)
@pytest.mark.parametrize("ccl", _ccl)
def test_scatter(session_kind, ccl, capfd):
@pytest.mark.parametrize("use_explicit_output", [True, False])
def test_scatter(session_kind, ccl, use_explicit_output, capfd):
devices = [0, 1]
sess = session_kind(num_workers=len(devices))
sess.init_ccl(ccl, *devices)

array = np.arange(36, dtype="float32").reshape(2, 6, 3)

if use_explicit_output:
d_src = sess.empty((2, 6, 3), "float32", worker0_only=True)
d_dst = sess.empty((6, 3), "float32")
d_src.debug_copy_from(0, array)
sess.scatter_from_worker0(d_src, d_dst)
else:
d_dst = sess.scatter(array)

np.testing.assert_equal(
d_dst.debug_get_from_remote(0).numpy(),
array[0, :, :],
)
np.testing.assert_equal(
d_dst.debug_get_from_remote(1).numpy(),
array[1, :, :],
)

captured = capfd.readouterr()
assert (
not captured.err
), "No warning messages should be generated from disco.Session.scatter_from_worker0"


@pytest.mark.parametrize("session_kind", _all_session_kinds)
@pytest.mark.parametrize("ccl", _ccl)
def test_scatter_with_implicit_reshape(session_kind, ccl, capfd):
"""Scatter may perform an implicit reshape
Scattering elements to the workers requires the total number of
elements to be divisible by the number of workers. It does not
necessarily correspond to scattering across the outermost
dimension. Here, the number of workers (2) and the outermost
dimension (3) are not divisible, but the scatter may still be
performed.
This is only allowed when the caller explicitly uses the
`sess.scatter_from_worker0` method, and is not allowed in
`sess.scatter` method. Because the `sess.scatter` method may
perform an allocation on the disco workers, it requires that the
scatter occur across the outermost dimension.
"""
devices = [0, 1]
sess = session_kind(num_workers=len(devices))
sess.init_ccl(ccl, *devices)

array = np.arange(36, dtype="float32").reshape(3, 4, 3)

d_src = sess.empty((3, 4, 3), "float32", worker0_only=True)
d_dst = sess.empty((3, 3, 2), "float32")

d_src.debug_copy_from(0, array)

sess.scatter_from_worker0(d_src, d_dst)

np.testing.assert_equal(
Expand Down

0 comments on commit 7c2c0d9

Please sign in to comment.