diff --git a/smdebug/mxnet/hook.py b/smdebug/mxnet/hook.py
index ff1678a35..cb4e150a0 100644
--- a/smdebug/mxnet/hook.py
+++ b/smdebug/mxnet/hook.py
@@ -1,4 +1,7 @@
 # Third Party
+# Standard Library
+import os
+
 import mxnet as mx
 
 # First Party
@@ -67,7 +70,8 @@ def _get_worker_name(self):
                 return f"worker_{hvd.rank()}"
         except (ModuleNotFoundError, ValueError, ImportError):
             pass
-        return DEFAULT_WORKER_NAME
+        worker_name = f"worker_" + os.getenv("SMDEBUG_WORKER_RANK", str(0))
+        return worker_name
 
     def _get_num_workers(self):
         try:
@@ -77,7 +81,7 @@ def _get_num_workers(self):
                 return hvd.size()
         except (ModuleNotFoundError, ValueError, ImportError):
             pass
-        return 1
+        return int(os.getenv("SMDEBUG_NUM_WORKERS", 1))
 
     def _cleanup(self):
         # Write the gradients of the past step if the writer is still available.
diff --git a/smdebug/pytorch/hook.py b/smdebug/pytorch/hook.py
index c1393d0e9..acd1047ed 100644
--- a/smdebug/pytorch/hook.py
+++ b/smdebug/pytorch/hook.py
@@ -1,4 +1,5 @@
 # Standard Library
+import os
 
 # Third Party
 import torch
@@ -70,7 +71,7 @@ def _get_num_workers(self):
             except (ModuleNotFoundError, ValueError, ImportError):
                 pass
         # Return default
-        return 1
+        return int(os.getenv("SMDEBUG_NUM_WORKERS", 1))
 
     def _get_worker_name(self):
         """Check horovod and torch.distributed."""
@@ -87,8 +88,8 @@ def _get_worker_name(self):
                     return f"worker_{hvd.rank()}"
             except (ModuleNotFoundError, ValueError, ImportError):
                 pass
-        # Return default
-        return DEFAULT_WORKER_NAME
+        worker_name = f"worker_" + os.getenv("SMDEBUG_WORKER_RANK", str(0))
+        return worker_name
 
     def _log_params(self, module):
         module_name = module._get_name()
diff --git a/smdebug/tensorflow/base_hook.py b/smdebug/tensorflow/base_hook.py
index bedfb4405..88062d68b 100644
--- a/smdebug/tensorflow/base_hook.py
+++ b/smdebug/tensorflow/base_hook.py
@@ -175,7 +175,8 @@ def _get_worker_name(self) -> str:
         elif self.distribution_strategy == TFDistributionStrategy.PARAMETER_SERVER:
             return get_worker_id_from_tf_config(self.tf_config_json)
         elif self.distribution_strategy == TFDistributionStrategy.NONE:
-            return DEFAULT_WORKER_NAME
+            worker_name = f"worker_" + os.getenv("SMDEBUG_WORKER_RANK", str(0))
+            return worker_name
         elif self.distribution_strategy == TFDistributionStrategy.UNSUPPORTED:
             raise NotImplementedError
 
@@ -220,7 +221,7 @@ def _get_num_workers(self):
         elif self.distribution_strategy == TFDistributionStrategy.PARAMETER_SERVER:
             return get_num_workers_from_tf_config(self.tf_config_json)
         elif self.distribution_strategy == TFDistributionStrategy.NONE:
-            return 1
+            return int(os.getenv("SMDEBUG_NUM_WORKERS", 1))
         elif self.distribution_strategy == TFDistributionStrategy.UNSUPPORTED:
             raise NotImplementedError
 
diff --git a/tests/pytorch/test_distributed_training.py b/tests/pytorch/test_distributed_training.py
index 031cda791..8c825e003 100644
--- a/tests/pytorch/test_distributed_training.py
+++ b/tests/pytorch/test_distributed_training.py
@@ -63,8 +63,9 @@ def train(model, device, optimizer, num_steps=10):
         optimizer.step()
 
 
-def run(rank, size, include_workers="one", num_epochs=10, batch_size=128, num_batches=10):
+def run(monkeypatch, rank, size, include_workers="one", num_epochs=10, batch_size=128, num_batches=10):
     """Distributed function to be implemented later."""
+    monkeypatch.setenv("SMDEBUG_WORKER_RANK", str(rank))
     torch.manual_seed(1234)
     device = torch.device("cpu")
     model = Net().to(device)
@@ -90,11 +91,13 @@ def run(rank, size, include_workers="one", num_epochs=10, batch_size=128, num_ba
             loss = F.mse_loss(output, target)
             epoch_loss += loss.item()
             loss.backward()
-            average_gradients(model)
+            if hasattr(dist, "is_initialized") and dist.is_initialized():
+                average_gradients(model)
             optimizer.step()
         # print(f"Rank {dist.get_rank()}, epoch {epoch}: {epoch_loss / num_batches}")
 
-    assert hook._get_worker_name() == f"worker_{dist.get_rank()}"
+    if hasattr(dist, "is_initialized") and dist.is_initialized():
+        assert hook._get_worker_name() == f"worker_{dist.get_rank()}"
     # Race condition here where both workers attempt to move
     # /tmp/{out_dir}/END_OF_JOB.ts to {out_dir}/END_OF_JOB.ts
     try:
@@ -179,3 +182,45 @@ def test_run_net_distributed_save_one_worker():
     trial = _run_net_distributed(include_workers="one")
     assert len(trial.workers()) == 1, f"trial.workers() = {trial.workers()}"
     assert len(trial.steps()) == 3, f"trial.steps() = {trial.steps()}"
+
+
+@pytest.mark.slow
+def test_run_net_distributed_multiproc_save_all_workers(monkeypatch):
+    size = 2
+    monkeypatch.setenv("SMDEBUG_NUM_WORKERS", str(size))
+    processes = []
+    for rank in range(size):
+        p = Process(target=run, args=(monkeypatch, rank, size, "all"))
+        p.start()
+        processes.append(p)
+
+    for p in processes:
+        p.join()
+
+    out_dir = "/tmp/run"
+    trial = create_trial(path=out_dir)
+    assert len(trial.workers()) == 2, f"trial.workers() = {trial.workers()}"
+    assert len(trial.steps()) == 3, f"trial.steps() = {trial.steps()}"
+
+    del os.environ["SMDEBUG_NUM_WORKERS"]
+
+
+@pytest.mark.slow
+def test_run_net_distributed_multiproc_save_one_worker(monkeypatch):
+    size = 2
+    monkeypatch.setenv("SMDEBUG_NUM_WORKERS", str(size))
+    processes = []
+    for rank in range(size):
+        p = Process(target=run, args=(monkeypatch, rank, size, "one"))
+        p.start()
+        processes.append(p)
+
+    for p in processes:
+        p.join()
+
+    out_dir = "/tmp/run"
+    trial = create_trial(path=out_dir)
+    assert len(trial.workers()) == 1, f"trial.workers() = {trial.workers()}"
+    assert len(trial.steps()) == 3, f"trial.steps() = {trial.steps()}"
+
+    del os.environ["SMDEBUG_NUM_WORKERS"]
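
For context, a minimal sketch (illustration only, not part of the diff above) of how a launcher could exercise the new code path when neither horovod nor torch.distributed is initialized: each single-process worker is started with SMDEBUG_NUM_WORKERS and SMDEBUG_WORKER_RANK in its environment, so the hooks report worker_<rank> instead of the old DEFAULT_WORKER_NAME. The train.py entry point and the subprocess-based launch are assumptions for illustration; only the two variable names come from the patch.

import os
import subprocess

num_workers = 2
procs = []
for rank in range(num_workers):
    env = dict(os.environ)
    env["SMDEBUG_NUM_WORKERS"] = str(num_workers)  # read by the hooks' _get_num_workers()
    env["SMDEBUG_WORKER_RANK"] = str(rank)         # read by _get_worker_name(); yields "worker_<rank>"
    # "train.py" is a hypothetical single-process training script that creates an smdebug hook.
    procs.append(subprocess.Popen(["python", "train.py"], env=env))

for p in procs:
    p.wait()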