diff --git a/torchrec/distributed/test_utils/multi_process.py b/torchrec/distributed/test_utils/multi_process.py
index f3233e9b0..ac003d02b 100644
--- a/torchrec/distributed/test_utils/multi_process.py
+++ b/torchrec/distributed/test_utils/multi_process.py
@@ -9,6 +9,7 @@
 #!/usr/bin/env python3
 
+import logging
 import multiprocessing
 import os
 import unittest
@@ -24,11 +25,6 @@
 )
 
 
-# AMD's HIP runtime doesn't seem to work with forkserver; hipMalloc will fail
-# Therefore we use spawn for HIP runtime until AMD fixes the issue
-_MP_INIT_MODE = "forkserver" if torch.version.hip is None else "spawn"
-
-
 class MultiProcessContext:
     def __init__(
         self,
@@ -98,6 +94,15 @@ def __exit__(self, exc_type, exc_instance, traceback) -> None:
 
 
 class MultiProcessTestBase(unittest.TestCase):
+    def __init__(
+        self, methodName: str = "runTest", mp_init_mode: str = "forkserver"
+    ) -> None:
+        super().__init__(methodName)
+
+        # AMD's HIP runtime doesn't seem to work with forkserver; hipMalloc will fail
+        # Therefore we use spawn for HIP runtime until AMD fixes the issue
+        self._mp_init_mode: str = mp_init_mode if torch.version.hip is None else "spawn"
+        logging.info(f"Using {self._mp_init_mode} for multiprocessing")
 
     @seed_and_log
     def setUp(self) -> None:
@@ -131,7 +136,7 @@ def _run_multi_process_test(
         # pyre-ignore
         **kwargs,
     ) -> None:
-        ctx = multiprocessing.get_context(_MP_INIT_MODE)
+        ctx = multiprocessing.get_context(self._mp_init_mode)
         processes = []
         for rank in range(world_size):
             kwargs["rank"] = rank
@@ -157,7 +162,7 @@ def _run_multi_process_test_per_rank(
         world_size: int,
         kwargs_per_rank: List[Dict[str, Any]],
     ) -> None:
-        ctx = multiprocessing.get_context(_MP_INIT_MODE)
+        ctx = multiprocessing.get_context(self._mp_init_mode)
         processes = []
         for rank in range(world_size):
             kwargs = {}
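
For reference, a minimal sketch of how a test case might opt into a different start method under this change. The subclass and test names below are hypothetical illustrations, not part of the patch; only the mp_init_mode parameter and the HIP fallback come from the diff above.

    # Hypothetical subclass illustrating the new mp_init_mode parameter.
    # Assumes: from torchrec.distributed.test_utils.multi_process import MultiProcessTestBase
    # On ROCm builds (torch.version.hip is not None), the base class still
    # forces "spawn" regardless of the value passed here.
    class SpawnModeMultiProcessTest(MultiProcessTestBase):
        def __init__(self, methodName: str = "runTest") -> None:
            super().__init__(methodName, mp_init_mode="spawn")

        def test_placeholder(self) -> None:
            # _run_multi_process_test(...) now creates workers with
            # multiprocessing.get_context(self._mp_init_mode) instead of the
            # removed module-level _MP_INIT_MODE constant.
            pass

Moving the start-method choice from a module-level constant into the constructor lets individual test classes override it per test suite while keeping "forkserver" as the default and preserving the existing "spawn" fallback for HIP.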