Skip to content

Commit

Permalink
Use spawn instead of forkserver for MTIA
Browse files Browse the repository at this point in the history
Summary:
MTIA runtime doesn't seem to work with forkserver, because folly::Singleton cannot survive a fork. Switch to `spawn` when running on MTIA.

Error paste: P1684403488

Differential Revision: D66351758
  • Loading branch information
yvonne-lab authored and facebook-github-bot committed Jan 19, 2025
1 parent 9dfdfb8 commit 731284a
Showing 1 changed file with 12 additions and 7 deletions.
19 changes: 12 additions & 7 deletions torchrec/distributed/test_utils/multi_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

#!/usr/bin/env python3

import logging
import multiprocessing
import os
import unittest
Expand All @@ -24,11 +25,6 @@
)


# AMD's HIP runtime doesn't seem to work with forkserver; hipMalloc will fail
# Therefore we use spawn for HIP runtime until AMD fixes the issue
_MP_INIT_MODE = "forkserver" if torch.version.hip is None else "spawn"


class MultiProcessContext:
def __init__(
self,
Expand Down Expand Up @@ -98,6 +94,15 @@ def __exit__(self, exc_type, exc_instance, traceback) -> None:


class MultiProcessTestBase(unittest.TestCase):
    def __init__(
        self, methodName: str = "runTest", mp_init_mode: str = "forkserver"
    ) -> None:
        """Base class for tests that launch per-rank worker processes.

        Args:
            methodName: Name of the test method to run (standard
                ``unittest.TestCase`` convention).
            mp_init_mode: Requested multiprocessing start method. Ignored on
                AMD HIP, where ``"spawn"`` is always used. Tests running on
                MTIA should pass ``"spawn"`` explicitly, since the MTIA
                runtime does not work with ``forkserver``.
        """
        super().__init__(methodName)

        # AMD's HIP runtime doesn't seem to work with forkserver; hipMalloc will fail
        # Therefore we use spawn for HIP runtime until AMD fixes the issue
        self._mp_init_mode: str = mp_init_mode if torch.version.hip is None else "spawn"
        # Lazy %-style args: the message is only formatted if INFO is enabled.
        logging.info("Using %s for multiprocessing", self._mp_init_mode)

@seed_and_log
def setUp(self) -> None:
Expand Down Expand Up @@ -131,7 +136,7 @@ def _run_multi_process_test(
# pyre-ignore
**kwargs,
) -> None:
ctx = multiprocessing.get_context(_MP_INIT_MODE)
ctx = multiprocessing.get_context(self._mp_init_mode)
processes = []
for rank in range(world_size):
kwargs["rank"] = rank
Expand All @@ -157,7 +162,7 @@ def _run_multi_process_test_per_rank(
world_size: int,
kwargs_per_rank: List[Dict[str, Any]],
) -> None:
ctx = multiprocessing.get_context(_MP_INIT_MODE)
ctx = multiprocessing.get_context(self._mp_init_mode)
processes = []
for rank in range(world_size):
kwargs = {}
Expand Down

0 comments on commit 731284a

Please sign in to comment.