[WIP] Support multiprocessing training #141
Conversation
@pytest.mark.slow
def test_run_net_distributed_multiproc_save_all_workers():
    size = 2
    os.environ["SMDEBUG_NUM_WORKERS"] = "2"
Use pytest's monkeypatch fixture instead of setting os.environ directly; see here -
monkeypatch.setenv("TF_CONFIG", json.dumps({}))
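A minimal sketch of the suggested change, assuming the test keeps its current name and body; monkeypatch undoes the environment change when the test finishes:

import pytest


@pytest.mark.slow
def test_run_net_distributed_multiproc_save_all_workers(monkeypatch):
    size = 2
    # monkeypatch restores the original environment after the test,
    # unlike assigning to os.environ directly.
    monkeypatch.setenv("SMDEBUG_NUM_WORKERS", str(size))
    # ... rest of the test body unchanged ...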
if hasattr(dist, "is_initialized") and dist.is_initialized():
    average_gradients(model)
reason for these changes?
When using the multiprocessing approach, torch.distributed is not used (init_process_group is not called), so any reference to dist.get_rank() or dist.get_world_size() will error out.
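A small sketch of the guard pattern being discussed; get_rank_or_zero is a hypothetical helper, not part of this PR:

import torch.distributed as dist


def get_rank_or_zero():
    # Only call torch.distributed APIs when a process group was initialized;
    # in the plain multiprocessing path init_process_group is never called,
    # so dist.get_rank() would raise.
    if hasattr(dist, "is_initialized") and dist.is_initialized():
        return dist.get_rank()
    return 0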
if hasattr(dist, "is_initialized") and dist.is_initialized():
    assert hook._get_worker_name() == f"worker_{dist.get_rank()}"
Same here?
if hasattr(dist, "is_initialized") and dist.is_initialized():
    average_gradients(model)
reason?
@vandanavk updates?
In the last sprint meeting, we decided to modify the example to use torch.distributed instead of multiprocessing directly. The solution in this PR doesn't fix all scenarios, for example with include_workers="one".
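For reference, a minimal sketch of what a torch.distributed-based example could look like; the backend, address, and port are illustrative assumptions, not part of this PR:

import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank, world_size):
    # With a process group initialized, smdebug can derive the worker name
    # and world size from torch.distributed instead of extra env variables.
    dist.init_process_group(
        backend="gloo",
        init_method="tcp://127.0.0.1:23456",
        rank=rank,
        world_size=world_size,
    )
    # ... build the model, register the hook, run the training loop ...
    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    mp.spawn(worker, args=(world_size,), nprocs=world_size)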
* Bugfix: Invalid Worker (#139)
* smdistributed.dataparallel environment check
* addressed comments
* Modified check_smdataparallel_env logic

Co-authored-by: Nihal Harish <nihal42harish@gmail.com>
Co-authored-by: Karan Jariwala <karankjariwala@gmail.com>
Description of changes:
Training a model by splitting the dataset across multiple processes on the same machine is considered distributed training. When using SM Debugger with a distributed training script, the user can choose to save data from all workers or from just one worker (the include_workers option in the hook).
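For context, a hedged sketch of how this option is typically set when constructing the hook in script mode; the output path is a placeholder, and the constructor arguments should be checked against the smdebug version in use:

import smdebug.pytorch as smd

# include_workers="all" saves tensors from every worker; "one" saves only
# from a single worker.
hook = smd.Hook(out_dir="/tmp/smdebug_outputs", include_workers="all")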
Currently, this option is used in the following scenarios:
The example in https://github.com/awslabs/amazon-sagemaker-examples/tree/master/sagemaker-python-sdk/dgl_kge uses Python multiprocessing for MXNet distributed training and torch.multiprocessing for PyTorch distributed training. Performing distributed training using multiprocessing is not yet handled by smdebug.
To handle this scenario, this PR introduces the environment variables SMDEBUG_NUM_WORKERS and SMDEBUG_WORKER_NAME. Users must set these if they use a multiprocessing library in the training script.
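A sketch of the proposed approach, under the assumption that each spawned process sets both variables before creating the hook; the training function body is illustrative:

import os
from multiprocessing import Process


def train(rank, num_workers):
    # Tell smdebug how many workers exist and give this worker a unique name.
    os.environ["SMDEBUG_NUM_WORKERS"] = str(num_workers)
    os.environ["SMDEBUG_WORKER_NAME"] = f"worker_{rank}"
    # ... create the smdebug hook and run the training loop ...


if __name__ == "__main__":
    num_workers = 2
    processes = [Process(target=train, args=(rank, num_workers)) for rank in range(num_workers)]
    for p in processes:
        p.start()
    for p in processes:
        p.join()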
The other alternatives considered were:
sagemaker-debugger/tests/pytorch/test_distributed_training.py, line 98 (commit 2433316)
Style and formatting:
I have run `pre-commit install` to ensure that auto-formatting happens with every commit.

Issue number, if available:
By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.