[Resiliency] Straggler detection #9473

Merged · 16 commits · Jun 28, 2024 · Changes from 14 commits
44 changes: 44 additions & 0 deletions docs/source/core/exp_manager.rst
@@ -203,6 +203,50 @@ file followed by a graceful exit from the run. The checkpoint saved upon preempt
This feature is useful to increase utilization on clusters.
The ``PreemptionCallback`` is enabled by default. To disable it, simply add ``create_preemption_callback: False`` under exp_manager in the config YAML file.

Straggler Detection
----------------------

.. _exp_manager_straggler_det_support-label:

.. note::
    The straggler detection feature is included in the optional NeMo resiliency package.

Distributed training can be affected by stragglers: slow workers that delay the overall training process.
NeMo provides a straggler detection feature that can identify these slower GPUs.

This feature is implemented in the ``StragglerDetectionCallback``, which is disabled by default.

The callback computes normalized GPU performance scores, which are scalar values ranging from 0.0 (worst) to 1.0 (best).
A performance score can be interpreted as the ratio of current performance to reference performance.

The callback provides two types of performance scores:

- Relative GPU performance score: the best-performing GPU in the workload is used as the reference.
- Individual GPU performance score: the GPU's own best historical performance is used as the reference.

Examples:

- A relative performance score of 0.5 means that a GPU is running at half the speed of the fastest GPU in the workload.
- An individual performance score of 0.5 means that a GPU is running at half of its best observed performance.

If a GPU performance score drops below the specified threshold, it is identified as a straggler.
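
To make the scoring concrete, the following is a minimal, illustrative sketch of how such scores could be derived from per-GPU throughput samples. It is not the ``ptl_resiliency`` implementation; the function and variable names are hypothetical.

.. code-block:: python

    # Illustrative sketch only -- not part of ptl_resiliency.
    def find_stragglers(throughput_per_rank, best_historical_per_rank, threshold=0.7):
        best_now = max(throughput_per_rank.values())
        stragglers = set()
        for rank, tput in throughput_per_rank.items():
            # Relative score: compare against the fastest GPU currently in the workload.
            relative_score = tput / best_now
            # Individual score: compare against this GPU's own best observed performance.
            individual_score = min(1.0, tput / best_historical_per_rank[rank])
            if relative_score < threshold or individual_score < threshold:
                stragglers.add(rank)
        return stragglers

    # Rank 1 runs at half the speed of rank 0 -> relative score 0.5 -> flagged as a straggler.
    assert find_stragglers({0: 100.0, 1: 50.0}, {0: 100.0, 1: 100.0}) == {1}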

To enable straggler detection, add ``create_straggler_detection_callback: True`` under exp_manager in the config YAML file.
You might also want to adjust the callback parameters:

.. code-block:: yaml

    exp_manager:
      ...
      create_straggler_detection_callback: True
      straggler_detection_params:
        report_time_interval: 300  # Interval [seconds] of the straggler check
        calc_relative_gpu_perf: True  # Calculate relative GPU performance
        calc_individual_gpu_perf: True  # Calculate individual GPU performance
        num_gpu_perf_scores_to_log: 5  # Log 5 best and 5 worst GPU performance scores, even if no stragglers are detected
        gpu_relative_perf_threshold: 0.7  # Threshold for relative GPU performance scores
        gpu_individual_perf_threshold: 0.7  # Threshold for individual GPU performance scores
        stop_if_detected: True  # Terminate the workload if stragglers are detected

Straggler detection may involve inter-rank synchronization, so it should be invoked at a reasonable frequency (e.g., every few minutes).
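
The same options can also be passed to ``exp_manager`` programmatically. A minimal sketch, assuming the optional resiliency package is installed, a ``trainer`` object already exists, and the log directory is a placeholder path:

.. code-block:: python

    from nemo.utils.exp_manager import exp_manager

    exp_manager(
        trainer,
        {
            "explicit_log_dir": "/results/straggler_demo",  # hypothetical path
            "create_straggler_detection_callback": True,
            "straggler_detection_params": {
                "report_time_interval": 300,
                "num_gpu_perf_scores_to_log": 5,
            },
        },
    )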

.. _nemo_multirun-label:

34 changes: 34 additions & 0 deletions nemo/utils/exp_manager.py
@@ -49,6 +49,14 @@
from nemo.utils.mcore_logger import add_handlers_to_mcore_logger
from nemo.utils.model_utils import uninject_model_parallel_rank

try:
    # `ptl_resiliency` is included in `gwe_resiliency_pkg` package
    from ptl_resiliency import StragglerDetectionCallback

    HAVE_STRAGGLER_DET = True
except (ImportError, ModuleNotFoundError):
    HAVE_STRAGGLER_DET = False


class NotFoundError(NeMoBaseException):
"""Raised when a file or folder is not found"""
@@ -127,6 +135,17 @@ class EMAParams:
    every_n_steps: int = 1


@dataclass
class StragglerDetectionParams:
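    """Parameters passed to ``StragglerDetectionCallback``; see the Straggler Detection section in exp_manager.rst."""
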
    report_time_interval: float = 300
    calc_relative_gpu_perf: bool = True
    calc_individual_gpu_perf: bool = True
    num_gpu_perf_scores_to_log: int = 5
    gpu_relative_perf_threshold: float = 0.7
    gpu_individual_perf_threshold: float = 0.7
    stop_if_detected: bool = True


@dataclass
class ExpManagerConfig:
"""Experiment Manager config for validation of passed arguments."""
@@ -177,6 +196,9 @@ class ExpManagerConfig:
    max_time_per_run: Optional[str] = None
    # time to sleep non 0 ranks during initialization
    seconds_to_sleep: float = 5
    # Straggler detection
    create_straggler_detection_callback: Optional[bool] = False
    straggler_detection_params: Optional[StragglerDetectionParams] = field(default_factory=StragglerDetectionParams)


class TimingCallback(Callback):
@@ -307,6 +329,7 @@ def exp_manager(trainer: 'pytorch_lightning.Trainer', cfg: Optional[Union[DictCo
See EarlyStoppingParams dataclass above.
- create_preemption_callback (bool): Flag to decide whether to enable preemption callback to save checkpoints and exit training
immediately upon preemption. Default is True.
- create_straggler_detection_callback (bool): Use straggler detection callback. Default is False.
- files_to_copy (list): A list of files to copy to the experiment logging directory. Defaults to None which
copies no files.
- log_local_rank_0_only (bool): Whether to only create log files for local rank 0. Defaults to False.
@@ -500,6 +523,17 @@ def exp_manager(trainer: 'pytorch_lightning.Trainer', cfg: Optional[Union[DictCo
trainer.max_time = cfg.max_time_per_run
trainer.callbacks.append(StatelessTimer(cfg.max_time_per_run))

    if cfg.create_straggler_detection_callback:
        if HAVE_STRAGGLER_DET:
            logging.info("Enabling straggler detection...")
            straggler_det_args_dict = dict(cfg.straggler_detection_params)
            straggler_det_callback = StragglerDetectionCallback(**straggler_det_args_dict, logger=logging)
            trainer.callbacks.append(straggler_det_callback)
        else:
            raise ValueError(
                "`create_straggler_detection_callback` is True, but the straggler detection (resiliency) package is not installed."
            )

    if is_global_rank_zero():
        # Move files_to_copy to folder and add git information if present
        if cfg.files_to_copy:
139 changes: 139 additions & 0 deletions tests/core/test_straggler_det.py
@@ -0,0 +1,139 @@
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys

import pytest
import pytorch_lightning as pl
import torch
from omegaconf import OmegaConf

from nemo.core.classes import ModelPT
from nemo.utils.exp_manager import exp_manager

try:
    # `ptl_resiliency` is included in `gwe_resiliency_pkg` package
    from ptl_resiliency import StragglerDetectionCallback

    HAVE_STRAGGLER_DET = True
except (ImportError, ModuleNotFoundError):
    HAVE_STRAGGLER_DET = False


class OnesDataset(torch.utils.data.Dataset):
    def __init__(self, dataset_len):
        super().__init__()
        self.__dataset_len = dataset_len

    def __getitem__(self, *args):
        return torch.ones(2)

    def __len__(self):
        return self.__dataset_len


class ExampleModel(ModelPT):
    def __init__(self, log_dir, **kwargs):
        cfg = OmegaConf.structured({})
        super().__init__(cfg)
        pl.seed_everything(1234)
        self.l1 = torch.nn.modules.Linear(in_features=2, out_features=1)
        self.log_dir = log_dir

    def on_train_start(self):
        super().on_train_start()
        rank = torch.distributed.get_rank()

    def train_dataloader(self):
        dataset = OnesDataset(128)
        return torch.utils.data.DataLoader(dataset, batch_size=2, num_workers=8)

    def val_dataloader(self):
        dataset = OnesDataset(128)
        return torch.utils.data.DataLoader(dataset, batch_size=2, num_workers=8)

    def forward(self, batch):
        output = self.l1(batch)
        output = torch.nn.functional.l1_loss(output, torch.zeros(output.size()).to(output.device))
        return output

    def validation_step(self, batch, batch_idx):
        self.loss = self(batch)
        return self.loss

    def training_step(self, batch, batch_idx):
        return self(batch)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.1)

    def list_available_models(self, *args, **kwargs):
        pass

    def setup_training_data(self, *args, **kwargs):
        pass

    def setup_validation_data(self, *args, **kwargs):
        pass

    def on_validation_epoch_end(self):
        self.log("val_loss", torch.stack([self.loss]).mean())


@pytest.mark.skipif(not HAVE_STRAGGLER_DET, reason="requires resiliency package to be installed.")
class TestStragglerDetection:

    @pytest.mark.run_only_on('GPU')
    def test_prints_perf_scores(self, tmp_path):
        # Run dummy 1 rank DDP training
        # Training time is limited to 3 seconds and straggler reporting is set to 1 second
        # Check if there are straggler related logs in the captured log
        max_steps = 1_000_000
        tmp_path = tmp_path / "test_1"
        print("TMP PATH", tmp_path)

        trainer = pl.Trainer(
            strategy='ddp',
            devices=1,
            accelerator='gpu',
            enable_checkpointing=False,
            logger=False,
            max_steps=max_steps,
            val_check_interval=0.33,
        )
        exp_manager(
            trainer,
            {
                "max_time_per_run": "00:00:00:03",
                "explicit_log_dir": str(tmp_path),
                "create_checkpoint_callback": False,
                "create_straggler_detection_callback": True,
                "straggler_detection_params": {
                    "report_time_interval": 1.0,
                    "calc_relative_gpu_perf": True,
                    "calc_individual_gpu_perf": True,
                    "num_gpu_perf_scores_to_log": 1,
                },
            },
        )
        model = ExampleModel(log_dir=tmp_path)
        trainer.fit(model)

        # assume that NeMo logs are written into "nemo_log_globalrank-0_localrank-0.txt"
        rank0_log_content = None
        with open(tmp_path / "nemo_log_globalrank-0_localrank-0.txt") as f:
            rank0_log_content = f.read()

        assert "GPU relative performance" in rank0_log_content
        assert "GPU individual performance" in rank0_log_content