8 changes: 6 additions & 2 deletions smdebug/mxnet/hook.py
@@ -1,4 +1,7 @@
# Third Party
# Standard Library
import os

import mxnet as mx

# First Party
@@ -67,7 +70,8 @@ def _get_worker_name(self):
return f"worker_{hvd.rank()}"
except (ModuleNotFoundError, ValueError, ImportError):
pass
return DEFAULT_WORKER_NAME
worker_name = f"worker_" + os.getenv("SMDEBUG_WORKER_RANK", str(0))
return worker_name

def _get_num_workers(self):
try:
@@ -77,7 +81,7 @@ def _get_num_workers(self):
return hvd.size()
except (ModuleNotFoundError, ValueError, ImportError):
pass
return 1
return int(os.getenv("SMDEBUG_NUM_WORKERS", 1))

def _cleanup(self):
# Write the gradients of the past step if the writer is still available.
7 changes: 4 additions & 3 deletions smdebug/pytorch/hook.py
@@ -1,4 +1,5 @@
# Standard Library
import os

# Third Party
import torch
@@ -70,7 +71,7 @@ def _get_num_workers(self):
except (ModuleNotFoundError, ValueError, ImportError):
pass
# Return default
return 1
return int(os.getenv("SMDEBUG_NUM_WORKERS", 1))

def _get_worker_name(self):
"""Check horovod and torch.distributed."""
@@ -87,8 +88,8 @@ def _get_worker_name(self):
return f"worker_{hvd.rank()}"
except (ModuleNotFoundError, ValueError, ImportError):
pass
# Return default
return DEFAULT_WORKER_NAME
worker_name = f"worker_" + os.getenv("SMDEBUG_WORKER_RANK", str(0))
return worker_name

def _log_params(self, module):
module_name = module._get_name()
5 changes: 3 additions & 2 deletions smdebug/tensorflow/base_hook.py
@@ -175,7 +175,8 @@ def _get_worker_name(self) -> str:
elif self.distribution_strategy == TFDistributionStrategy.PARAMETER_SERVER:
return get_worker_id_from_tf_config(self.tf_config_json)
elif self.distribution_strategy == TFDistributionStrategy.NONE:
return DEFAULT_WORKER_NAME
worker_name = f"worker_" + os.getenv("SMDEBUG_WORKER_RANK", str(0))
return worker_name
elif self.distribution_strategy == TFDistributionStrategy.UNSUPPORTED:
raise NotImplementedError

@@ -220,7 +221,7 @@ def _get_num_workers(self):
elif self.distribution_strategy == TFDistributionStrategy.PARAMETER_SERVER:
return get_num_workers_from_tf_config(self.tf_config_json)
elif self.distribution_strategy == TFDistributionStrategy.NONE:
return 1
return int(os.getenv("SMDEBUG_NUM_WORKERS", 1))
elif self.distribution_strategy == TFDistributionStrategy.UNSUPPORTED:
raise NotImplementedError

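
With these changes, all three hooks fall back to the SMDEBUG_WORKER_RANK and SMDEBUG_NUM_WORKERS environment variables when no Horovod / torch.distributed / TF distribution strategy is detected, and to worker "worker_0" with 1 worker when the variables are unset, matching the previous defaults. As a minimal sketch that is not part of this PR, here is how a plain multiprocessing launcher might set those variables per process before the hook is constructed (launch_workers and _child are hypothetical names used only for illustration):

# Illustrative only: each child process sets the smdebug fallback
# variables before building its hook, so _get_worker_name() and
# _get_num_workers() resolve without a distributed backend.
import os
from multiprocessing import Process


def _child(rank, num_workers, target):
    os.environ["SMDEBUG_WORKER_RANK"] = str(rank)        # -> worker name "worker_<rank>"
    os.environ["SMDEBUG_NUM_WORKERS"] = str(num_workers)
    target()


def launch_workers(target, num_workers=2):
    procs = [Process(target=_child, args=(rank, num_workers, target))
             for rank in range(num_workers)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
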
51 changes: 48 additions & 3 deletions tests/pytorch/test_distributed_training.py
@@ -63,8 +63,9 @@ def train(model, device, optimizer, num_steps=10):
optimizer.step()


def run(rank, size, include_workers="one", num_epochs=10, batch_size=128, num_batches=10):
def run(monkeypatch, rank, size, include_workers="one", num_epochs=10, batch_size=128, num_batches=10):
"""Distributed function to be implemented later."""
monkeypatch.setenv("SMDEBUG_WORKER_RANK", str(rank))
torch.manual_seed(1234)
device = torch.device("cpu")
model = Net().to(device)
@@ -90,11 +91,13 @@ def run(rank, size, include_workers="one", num_epochs=10, batch_size=128, num_ba
loss = F.mse_loss(output, target)
epoch_loss += loss.item()
loss.backward()
average_gradients(model)
if hasattr(dist, "is_initialized") and dist.is_initialized():
average_gradients(model)
Comment on lines +94 to +95
Contributor: reason for these changes?

Contributor Author (@vandanavk, Jan 17, 2020): When using the multiprocessing approach, torch.distributed is not used (init_process_group is not called), so any reference to dist.get_rank or dist.get_world_size() will error out.

Comment on lines +94 to +95
Contributor: reason?

optimizer.step()
# print(f"Rank {dist.get_rank()}, epoch {epoch}: {epoch_loss / num_batches}")

assert hook._get_worker_name() == f"worker_{dist.get_rank()}"
if hasattr(dist, "is_initialized") and dist.is_initialized():
assert hook._get_worker_name() == f"worker_{dist.get_rank()}"
Comment on lines +99 to +100
Contributor: same here?

# Race condition here where both workers attempt to move
# /tmp/{out_dir}/END_OF_JOB.ts to {out_dir}/END_OF_JOB.ts
try:
@@ -179,3 +182,45 @@ def test_run_net_distributed_save_one_worker():
trial = _run_net_distributed(include_workers="one")
assert len(trial.workers()) == 1, f"trial.workers() = {trial.workers()}"
assert len(trial.steps()) == 3, f"trial.steps() = {trial.steps()}"


@pytest.mark.slow
def test_run_net_distributed_multiproc_save_all_workers(monkeypatch):
size = 2
monkeypatch.setenv("SMDEBUG_NUM_WORKERS", str(size))
processes = []
for rank in range(size):
p = Process(target=run, args=(monkeypatch, rank, size, "all"))
p.start()
processes.append(p)

for p in processes:
p.join()

out_dir = "/tmp/run"
trial = create_trial(path=out_dir)
assert len(trial.workers()) == 2, f"trial.workers() = {trial.workers()}"
assert len(trial.steps()) == 3, f"trial.steps() = {trial.steps()}"

del os.environ["SMDEBUG_NUM_WORKERS"]


@pytest.mark.slow
def test_run_net_distributed_multiproc_save_one_worker(monkeypatch):
size = 2
monkeypatch.setenv("SMDEBUG_NUM_WORKERS", str(size))
processes = []
for rank in range(size):
p = Process(target=run, args=(monkeypatch, rank, size, "one"))
p.start()
processes.append(p)

for p in processes:
p.join()

out_dir = "/tmp/run"
trial = create_trial(path=out_dir)
assert len(trial.workers()) == 1, f"trial.workers() = {trial.workers()}"
assert len(trial.steps()) == 3, f"trial.steps() = {trial.steps()}"

del os.environ["SMDEBUG_NUM_WORKERS"]
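
On the reviewer exchange above about guarding average_gradients and the worker-name assertion: a minimal sketch, assuming the default process group was never initialized (as in the plain multiprocessing tests added here), of the failure mode the hasattr/is_initialized guard avoids. safe_rank is a hypothetical helper, not code from this PR:

# Illustrative only: rank queries and collectives in torch.distributed
# raise when init_process_group() was never called, which is the case
# in the Process-based tests above.
import torch.distributed as dist


def safe_rank(default=0):
    if hasattr(dist, "is_initialized") and dist.is_initialized():
        return dist.get_rank()
    # Without an initialized process group, dist.get_rank() raises a
    # RuntimeError, so fall back to a local default instead.
    return default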