Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Disco][QoL] Implement broadcast/scatter methods for Session #17035

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 96 additions & 6 deletions python/tvm/runtime/disco/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,17 +249,34 @@ def copy_from_worker_0(self, host_array: NDArray, remote_array: DRef) -> None:
"""
return _ffi_api.SessionCopyFromWorker0(self, host_array, remote_array) # type: ignore # pylint: disable=no-member

def copy_to_worker_0(self, host_array: NDArray, remote_array: DRef) -> None:
def copy_to_worker_0(self, host_array: NDArray, remote_array: Optional[DRef] = None) -> DRef:
"""Copy the controller-side NDArray to worker-0.

Parameters
----------
host_array : numpy.ndarray
The array to be copied from worker-0.
remote_array : NDArray
The NDArray on worker-0.
host_array : NDArray

The array to be copied to worker-0.

remote_array : Optiona[DRef]

The destination NDArray on worker-0.

Returns
-------
output_array: DRef

The DRef containing the copied data on worker0, and
NullOpt on all other workers. If `remote_array` was
provided, this return value is the same as `remote_array`.
Otherwise, it is the newly allocated space.

"""
return _ffi_api.SessionCopyToWorker0(self, host_array, remote_array) # type: ignore # pylint: disable=no-member
if remote_array is None:
remote_array = self.empty(host_array.shape, host_array.dtype, worker0_only=True)

_ffi_api.SessionCopyToWorker0(self, host_array, remote_array) # type: ignore # pylint: disable=no-member
return remote_array

def load_vm_module(
self,
Expand Down Expand Up @@ -302,6 +319,40 @@ def init_ccl(self, ccl: str, *device_ids):
_ffi_api.SessionInitCCL(self, ccl, ShapeTuple(device_ids)) # type: ignore # pylint: disable=no-member
self._clear_ipc_memory_pool()

def broadcast(self, src: Union[np.ndarray, NDArray], dst: Optional[DRef] = None) -> DRef:
"""Broadcast an array to all workers

Parameters
----------
src: Union[np.ndarray, NDArray]

The array to be broadcasted.

dst: Optional[DRef]

The output array. If None, an array matching the shape
and dtype of `src` will be allocated on each worker.

Returns
-------
output_array: DRef

The DRef containing the broadcasted data on all workers.
If `dst` was provided, this return value is the same as
`dst`. Otherwise, it is the newly allocated space.

"""
if not isinstance(src, NDArray):
src = _as_NDArray(src)

if dst is None:
dst = self.empty(src.shape, src.dtype)

src_dref = self.copy_to_worker_0(src)
self.broadcast_from_worker0(src_dref, dst)

return dst

def broadcast_from_worker0(self, src: DRef, dst: DRef) -> DRef:
"""Broadcast an array from worker-0 to all other workers.

Expand All @@ -313,6 +364,45 @@ def broadcast_from_worker0(self, src: DRef, dst: DRef) -> DRef:
func = self._get_cached_method("runtime.disco.broadcast_from_worker0")
func(src, dst)

def scatter(self, src: Union[np.ndarray, NDArray], dst: Optional[DRef] = None) -> DRef:
"""Scatter an array across all workers

Parameters
----------
src: Union[np.ndarray, NDArray]

The array to be scattered. The first dimension of this
array, `src.shape[0]`, must be equal to the number of
workers.

dst: Optional[DRef]

The output array. If None, an array with compatible shape
and the same dtype as `src` will be allocated on each
worker.

Returns
-------
output_array: DRef

The DRef containing the scattered data on all workers.
If `dst` was provided, this return value is the same as
`dst`. Otherwise, it is the newly allocated space.

"""
assert src.shape[0] == self.num_workers

if not isinstance(src, NDArray):
src = _as_NDArray(src)

if dst is None:
dst = self.empty(src.shape[1:], src.dtype)

src_dref = self.copy_to_worker_0(src)
self.scatter_from_worker0(src_dref, dst)

return dst

def scatter_from_worker0(self, from_array: DRef, to_array: DRef) -> None:
"""Scatter an array from worker-0 to all other workers.

Expand Down
70 changes: 62 additions & 8 deletions tests/python/disco/test_ccl.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,33 +103,87 @@ def test_allgather(session_kind, ccl):

@pytest.mark.parametrize("session_kind", _all_session_kinds)
@pytest.mark.parametrize("ccl", _ccl)
def test_broadcast_from_worker0(session_kind, ccl):
@pytest.mark.parametrize("use_explicit_output", [True, False])
def test_broadcast_from_worker0(session_kind, ccl, use_explicit_output):
devices = [0, 1]
sess = session_kind(num_workers=len(devices))
sess.init_ccl(ccl, *devices)

array = np.arange(12, dtype="float32").reshape(3, 4)
d_array = sess.empty((3, 4), "float32", worker0_only=True)
d_array.debug_copy_from(0, array)
dst_array = sess.empty((3, 4), "float32")
sess.broadcast_from_worker0(d_array, dst_array)

if use_explicit_output:
src_array = sess.empty((3, 4), "float32", worker0_only=True)
src_array.debug_copy_from(0, array)
dst_array = sess.empty((3, 4), "float32")
sess.broadcast_from_worker0(src_array, dst_array)
else:
dst_array = sess.broadcast(array)

result = dst_array.debug_get_from_remote(1).numpy()
np.testing.assert_equal(result, array)


@pytest.mark.parametrize("session_kind", _all_session_kinds)
@pytest.mark.parametrize("ccl", _ccl)
def test_scatter(session_kind, ccl, capfd):
@pytest.mark.parametrize("use_explicit_output", [True, False])
def test_scatter(session_kind, ccl, use_explicit_output, capfd):
devices = [0, 1]
sess = session_kind(num_workers=len(devices))
sess.init_ccl(ccl, *devices)

array = np.arange(36, dtype="float32").reshape(2, 6, 3)

if use_explicit_output:
d_src = sess.empty((2, 6, 3), "float32", worker0_only=True)
d_dst = sess.empty((6, 3), "float32")
d_src.debug_copy_from(0, array)
sess.scatter_from_worker0(d_src, d_dst)
else:
d_dst = sess.scatter(array)

np.testing.assert_equal(
d_dst.debug_get_from_remote(0).numpy(),
array[0, :, :],
)
np.testing.assert_equal(
d_dst.debug_get_from_remote(1).numpy(),
array[1, :, :],
)

captured = capfd.readouterr()
assert (
not captured.err
), "No warning messages should be generated from disco.Session.scatter_from_worker0"


@pytest.mark.parametrize("session_kind", _all_session_kinds)
@pytest.mark.parametrize("ccl", _ccl)
def test_scatter_with_implicit_reshape(session_kind, ccl, capfd):
"""Scatter may perform an implicit reshape

Scattering elements to the workers requires the total number of
elements to be divisible by the number of workers. It does not
necessarily correspond to scattering across the outermost
dimension. Here, the number of workers (2) and the outermost
dimension (3) are not divisible, but the scatter may still be
performed.

This is only allowed when the caller explicitly uses the
`sess.scatter_from_worker0` method, and is not allowed in
`sess.scatter` method. Because the `sess.scatter` method may
perform an allocation on the disco workers, it requires that the
scatter occur across the outermost dimension.

"""
devices = [0, 1]
sess = session_kind(num_workers=len(devices))
sess.init_ccl(ccl, *devices)

array = np.arange(36, dtype="float32").reshape(3, 4, 3)

d_src = sess.empty((3, 4, 3), "float32", worker0_only=True)
d_dst = sess.empty((3, 3, 2), "float32")

d_src.debug_copy_from(0, array)

sess.scatter_from_worker0(d_src, d_dst)

np.testing.assert_equal(
Expand Down
Loading