ProcessGroupBabyNCCL: support multiple streams and use event on start
d4l3k committed Jan 30, 2025
1 parent 68e1d28 commit 7d48c52
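
For orientation, a minimal usage sketch of the API this commit changes. This is a hypothetical single-rank driver, not part of the commit; the store address and tensor shape are assumptions that mirror the tests further down.

from datetime import timedelta

import torch
from torch.distributed import ReduceOp

from torchft.process_group import ProcessGroupBabyNCCL

# Spawn the worker subprocess and point it at an (assumed) TCPStore address.
pg = ProcessGroupBabyNCCL(timeout=timedelta(seconds=10))
pg.configure("localhost:29500/prefix", rank=0, world_size=1)

t = torch.randn(2, 3, device="cuda")
work = pg.allreduce([t], ReduceOp.SUM)
# After this commit, wait() receives a CUDA event recorded by the worker and makes
# the caller's current stream wait on it, rather than blocking the CPU thread.
work.wait()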
Showing 2 changed files with 180 additions and 47 deletions.
193 changes: 149 additions & 44 deletions torchft/process_group.py
@@ -19,13 +19,16 @@
import logging
import queue
import threading
from collections import defaultdict
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from datetime import timedelta
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Generator,
List,
Optional,
Tuple,
@@ -586,29 +589,59 @@ def __init__(
self._timeout = timeout

def wait(self, timeout: Optional[timedelta] = None) -> bool:
self._pg._assert_alive()

self._tx.put(("wait", self._op_id), timeout=self._timeout)
assert _get(self._rx, self._timeout) == self._op_id
op_id, event = cast(
Tuple[int, Optional[torch.cuda.Event]],
_get(self._rx, timeout or self._timeout),
)
assert op_id == self._op_id
if event is not None:
event.wait()
return True

def synchronize(self) -> None:
# TODO: No one seems to use this and NCCL wait already only waits the
# stream and is non-blocking on the CPU side so no real need for a
# separate call.
raise NotImplementedError("not implemented")

def get_future(self) -> Future[object]:
return self._pg._get_future(self._op_id)

def __del__(self) -> None:
self._tx.put(("del", self._op_id), timeout=self._timeout)


class _BabyWorkNCCL(_BabyWork):
    def wait(self, timeout: Optional[timedelta] = None) -> bool:
        self._tx.put(("synchronize", self._op_id), timeout=self._timeout)
        # pyre-fixme[23]: unable to unpack into 2 values
        op_id, event = _get(self._rx, self._timeout)
        assert op_id == self._op_id
        assert isinstance(event, torch.cuda.Event)

        # Wait on Event makes the stream wait but not the CPU thread.
        event.wait()

        return True


def _is_any_cuda(obj: object) -> bool:
    """
    Returns true if any of the tensors in the object are CUDA tensors.

    Supports lists, tuples, dicts, and tensors.
    """
    if isinstance(obj, torch.Tensor):
        return obj.is_cuda
    elif isinstance(obj, (list, tuple)):
        return any(_is_any_cuda(o) for o in obj)
    elif isinstance(obj, dict):
        return any(_is_any_cuda(o) for o in obj.values())
    else:
        return False
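
A small illustration of the helper above (hypothetical values, not part of the commit): it recurses through containers and reports whether any leaf tensor lives on a CUDA device, which is what decides below whether a stream and start event need to be forwarded to the worker.

_is_any_cuda(torch.zeros(2))                                   # False: CPU tensor
_is_any_cuda([torch.zeros(2), {"x": torch.zeros(2).cuda()}])   # True: nested CUDA tensor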

@dataclass
class _OpMetadata:
work: Work
stream: Optional[torch.cuda.Stream]

@contextmanager
def set_stream(self) -> Generator[None, None, None]:
if self.stream is not None:
with torch.cuda.stream(self.stream):
yield
else:
yield
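
_OpMetadata pairs each in-flight Work with the worker-side stream it was issued on, and set_stream() re-enters that stream (or does nothing for CPU-only ops). A sketch of the pattern as the worker's wait handler below uses it; meta is a hypothetical instance:

with meta.set_stream():
    # For NCCL, Work.wait() only inserts a dependency into the current stream,
    # so this orders the op's own stream without blocking the worker's CPU thread.
    meta.work.wait()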


class ProcessGroupBaby(ProcessGroup):
@@ -617,11 +650,8 @@ class ProcessGroupBaby(ProcessGroup):
subprocess. Since it's running in a subprocess all tensors need to be in
shared memory or will be moved to shared memory. CUDA tensors are implicitly
shareable and don't need any changes.
"""

WORK_CLASS: Type[_BabyWork] = _BabyWork

def __init__(self, timeout: Union[float, timedelta] = 60.0) -> None:
super().__init__(0, 1)

@@ -679,7 +709,14 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:

self._p = ctx.Process(
target=self._worker,
args=(store_addr, rank, world_size, self._tx, self._rx, self._future_queue),
args=(
store_addr,
rank,
world_size,
self._tx,
self._rx,
self._future_queue,
),
daemon=True,
)
self._p.start()
@@ -716,23 +753,76 @@ def _worker(
return
tx.put(None)

work = {}
streams: Dict[str, torch.cuda.Stream] = {}
work: Dict[int, _OpMetadata] = {}
next_op_id: int = 0

while True:
op = rx.get()
cmd = op[0]
if cmd == "func":
func_name, args, kwargs = op[1:]
args = _PickleSafeOptions.unsafe_args(args)
fn = getattr(pg, func_name)
work[next_op_id] = fn(*args, **kwargs)
func_name, args, kwargs, stream_device, stream_id, event = op[1:]

print(f"func {func_name=}")

# To avoid potential deadlocks we need to preserve the
# stream/synchronization behavior of the parent process.
# We allocate one Stream per stream_id to make sure that we
# don't accidentally introduce cross stream synchronization
# points.
if stream_id is not None:
stream_key = f"{stream_device}/{stream_id}"
if stream_key not in streams:
streams[stream_key] = torch.cuda.Stream(
device=stream_device
)
stream = streams[stream_key]
else:
stream = None

with (
torch.cuda.stream(stream)
if stream is not None
else nullcontext()
):
print("stream created")

# Make the stream wait on the cuda event to make sure we
# don't start the operation until the tensor is ready.
if event is not None:
event.wait()

print("waited")

args = _PickleSafeOptions.unsafe_args(args)
fn = getattr(pg, func_name)
work[next_op_id] = _OpMetadata(
work=fn(*args, **kwargs),
stream=stream,
)
tx.put(next_op_id)
next_op_id += 1
elif cmd == "wait":
op_id: int = op[1]
work[op_id].wait()
tx.put(op_id)

metadata = work[op_id]

with metadata.set_stream():
# With WorkNCCL this makes the stream wait not the CPU when
# no timeout is passed.
metadata.work.wait()

# Register event on the stream that we can pass to the main
# process.
event = (
torch.cuda.current_stream().record_event(
torch.cuda.Event(interprocess=True)
)
if metadata.stream is not None
else None
)

tx.put((op_id, event))
elif cmd == "del":
op_id: int = op[1]
del work[op_id]
Expand All @@ -746,23 +836,8 @@ def callback(fut: Future[object]) -> None:
except Exception as e:
future_queue.put((op_id, _FUTURE_EXCEPTION, e))

work[op_id].get_future().add_done_callback(callback)
work[op_id].work.get_future().add_done_callback(callback)
tx.put(op_id)
elif cmd == "synchronize":
# CUDA only, use events instead of waiting on CPU
op_id = op[1]

# With WorkNCCL this makes the stream wait not the CPU when
# no timeout is passed.
work[op_id].wait()

# Register event on the stream that we can pass to the main
# process.
event = torch.cuda.Event(interprocess=True)
event.record()

del work[op_id]
tx.put((op_id, event))
elif cmd == "num_active_work":
tx.put(len(work))
else:
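
The wait handler above hands synchronization back to the parent through an interprocess CUDA event rather than by blocking. A condensed sketch of that handshake under the same assumptions (tx/rx are the queues and op_id the identifiers used in the surrounding code; the pairing of the two sides is illustrative):

# worker side: after work.wait() on the op's stream, record an interprocess event
event = torch.cuda.current_stream().record_event(torch.cuda.Event(interprocess=True))
tx.put((op_id, event))

# parent side (see _BabyWork.wait above): make the caller's current stream depend on it
op_id, event = _get(rx, timeout)
if event is not None:
    event.wait()  # stream-level dependency only; the parent CPU thread is not blocked
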
@@ -792,6 +867,8 @@ def _future_handler(self, future_queue: mp.Queue) -> None:
logger.exception(f"got unexpected error in future handler: {e}")

def _get_future(self, op_id: int) -> Future[object]:
self._assert_alive()

with self._futures_lock:
fut = Future() # pyre-fixme[29]: is not a function
self._futures[op_id] = fut
@@ -804,22 +881,52 @@ def _get_future(self, op_id: int) -> Future[object]:
return fut

def _run_func(self, func: str, *args: object, **kwargs: object) -> Work:
self._assert_alive()

rx = self._rx
tx = self._tx
assert rx is not None
assert tx is not None

is_cuda = _is_any_cuda(args)

stream_device = torch.cuda.current_stream().device if is_cuda else None
stream_id = torch.cuda.current_stream().stream_id if is_cuda else None
event = (
torch.cuda.current_stream().record_event(
torch.cuda.Event(interprocess=True)
)
if is_cuda
else None
)

tx.put(
("func", func, _PickleSafeOptions.safe_args(args), kwargs),
(
"func",
func,
_PickleSafeOptions.safe_args(args),
kwargs,
stream_device,
stream_id,
event,
),
timeout=self._timeout,
)

op_id = _get(rx, self._timeout)
assert isinstance(op_id, int), f"invalid return {op_id}"

return self.WORK_CLASS(
pg=self, tx=tx, rx=rx, op_id=op_id, timeout=self._timeout
)
return _BabyWork(pg=self, tx=tx, rx=rx, op_id=op_id, timeout=self._timeout)
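
The stream_device/stream_id/event triple captured above is what lets the worker keep independent parent streams independent. A hypothetical illustration (pg as in the usage sketch near the top; t1 and t2 are assumed CUDA tensors):

s1 = torch.cuda.Stream()
s2 = torch.cuda.Stream()
with torch.cuda.stream(s1):
    w1 = pg.allreduce([t1], ReduceOp.SUM)  # worker maps s1 to its own dedicated stream
with torch.cuda.stream(s2):
    w2 = pg.allreduce([t2], ReduceOp.SUM)  # ...and s2 to a different one
# Each wait() only makes the caller's current stream wait on that op's event, so the
# two collectives are not serialized against each other.
w1.wait()
w2.wait()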

def _assert_alive(self) -> None:
"""
Assert that the process group is alive. This is used to ensure that
operations are not performed on a dead process group and any errors are surfaced.
"""
p = self._p
assert p is not None
if not p.is_alive():
raise RuntimeError(f"child process {p.pid=} is dead {p.exitcode=}")

def allreduce(
self,
@@ -952,8 +1059,6 @@ class ProcessGroupBabyNCCL(ProcessGroupBaby):
tensors may leak in the current PyTorch implementation. TODO fix
"""

WORK_CLASS = _BabyWorkNCCL

@classmethod
def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
# pyre-fixme[16]: no attribute ProcessGroupNCCL
34 changes: 31 additions & 3 deletions torchft/process_group_test.py
@@ -266,6 +266,31 @@ def test_baby_gloo_apis(self) -> None:

self.assertEqual(a.num_active_work(), 0)

# pyre-fixme[56]: Pyre was not able to infer the type of argument
@skipUnless(torch.cuda.is_available(), "needs CUDA")
def test_baby_nccl_apis(self) -> None:
# use device 1 if >= 2 GPUs are available
device_id = 1 % torch.cuda.device_count()
torch.cuda.set_device(device_id)

store = TCPStore(
host_name="localhost", port=0, is_master=True, wait_for_workers=False
)

store_addr = f"localhost:{store.port}/prefix"

a = ProcessGroupBabyNCCL(timeout=timedelta(seconds=10))
a.configure(store_addr, 0, 1)

_test_pg(a, torch.randn((2, 3), device="cuda"))

torch.cuda.synchronize()

# force collection to ensure no BabyWork objects remain
gc.collect()

self.assertEqual(a.num_active_work(), 0)

def test_dummy(self) -> None:
pg = ProcessGroupDummy(0, 1)
m = nn.Linear(3, 4)
@@ -282,12 +307,15 @@ def test_baby_nccl_2gpu(self) -> None:
store_addr: str = f"localhost:{store.port}/prefix"

def run(rank: int) -> Tuple[torch.Tensor, Work]:
a = ProcessGroupBabyNCCL()
a = ProcessGroupBabyNCCL(
timeout=timedelta(seconds=10.0),
)
a.configure(store_addr, rank, 2)

self.assertEqual(a.size(), 2)

at = torch.tensor([rank + 1], device=f"cuda:{rank}")
# We test using set_device to ensure stream device is correct.
torch.cuda.set_device(rank)
at = torch.tensor([rank + 1], device="cuda")

a_work = a.allreduce([at], ReduceOp.SUM)
return at, a_work
