Revert "First initialize dist with gloo (mosaicml#1133)" (mosaicml#1139)
This reverts commit 76f74b6.
dakinggg authored Apr 25, 2024
1 parent 24f65fd commit 4aef5de
Showing 1 changed file with 18 additions and 46 deletions.
64 changes: 18 additions & 46 deletions scripts/train/train.py
@@ -10,7 +10,6 @@
from typing import Any, Dict, List, Optional, Union

import torch
-import torch.distributed
from composer import Trainer
from composer.core.callback import Callback
from composer.profiler import (JSONTraceHandler, Profiler, TraceHandler,
@@ -115,33 +114,6 @@ def validate_config(cfg: DictConfig):
    )


-def _initialize_gloo_and_nccl(dist_timeout: Union[int, float]):
-    """Initialize GLOO process group (then destroyed) and device process group.
-
-    We have experienced an issue where the first barrier with NCCL does not timeout properly,
-    and can hang forever if something is wrong. To attempt to mitigate this, we will first
-    initialize with a gloo process group and test a barrier, then destroy the process group
-
-    Args:
-        dist_timeout (Union[int, float]): Timeout for initializing the process group
-    """
-    # First, initialize with a gloo process group and test a barrier
-    log.debug('Initializing dist with cpu...')
-    dist.initialize_dist('cpu', timeout=dist_timeout)
-    log.debug('Testing barrier with cpu...')
-    dist.barrier()
-    log.debug('Barrier test passed with cpu. Destroying process group...')
-    torch.distributed.destroy_process_group()
-    log.debug('Process group destroyed.')
-
-    # Now, initialize with the correct device
-    log.debug('Initializing dist with device...')
-    dist.initialize_dist(get_device(None), timeout=dist_timeout)
-    log.debug('Testing barrier with device...')
-    dist.barrier()
-    log.debug('Barrier test passed with device.')
-
-
def main(cfg: DictConfig) -> Trainer:
    # Run user provided code if specified
    code_paths = pop_config(cfg,
@@ -198,24 +170,7 @@ def main(cfg: DictConfig) -> Trainer:
                                                 'dist_timeout',
                                                 must_exist=False,
                                                 default_value=600.0)
-    python_log_level: Optional[str] = pop_config(cfg,
-                                                 'python_log_level',
-                                                 must_exist=False,
-                                                 default_value='debug')
-    # Set logging level
-    if python_log_level is not None:
-        logging.basicConfig(
-            # Example of format string
-            # 2022-06-29 11:22:26,152: rank0[822018][MainThread]: INFO: Message here
-            format=
-            f'%(asctime)s: rank{dist.get_global_rank()}[%(process)d][%(threadName)s]: %(levelname)s: %(name)s: %(message)s'
-        )
-        logging.getLogger('llmfoundry').setLevel(
-            python_log_level.upper())  # Foundry module
-        logging.getLogger(__name__).setLevel(
-            python_log_level.upper())  # Train script
-
-    _initialize_gloo_and_nccl(dist_timeout=dist_timeout)
+    dist.initialize_dist(get_device(None), timeout=dist_timeout)

    # Get global and device batch size information from distributed/single node setting
    cfg = update_batch_size_info(cfg)
@@ -343,6 +298,10 @@ def main(cfg: DictConfig) -> Trainer:
                                      'log_to_console',
                                      must_exist=False,
                                      default_value=True)
+    python_log_level: Optional[str] = pop_config(cfg,
+                                                 'python_log_level',
+                                                 must_exist=False,
+                                                 default_value='debug')
    console_log_interval: Union[int, str] = pop_config(cfg,
                                                        'console_log_interval',
                                                        must_exist=False,
@@ -432,6 +391,19 @@ def main(cfg: DictConfig) -> Trainer:
            'FSDP is not applicable for single-GPU training. Reverting to DDP.')
        fsdp_config = None

+    # set logging level
+    if python_log_level is not None:
+        logging.basicConfig(
+            # Example of format string
+            # 2022-06-29 11:22:26,152: rank0[822018][MainThread]: INFO: Message here
+            format=
+            f'%(asctime)s: rank{dist.get_global_rank()}[%(process)d][%(threadName)s]: %(levelname)s: %(name)s: %(message)s'
+        )
+        logging.getLogger('llmfoundry').setLevel(
+            python_log_level.upper())  # Foundry module
+        logging.getLogger(__name__).setLevel(
+            python_log_level.upper())  # Train script
+
    # Initialize context
    init_context = process_init_device(model_config, fsdp_config)
    logged_cfg.update({'fsdp_config': fsdp_config}, merge=True)
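For reference, the gloo-first initialization pattern that this commit removes can be summarized in a short standalone sketch. It uses only calls that appear in the diff above (Composer's dist.initialize_dist, dist.barrier, get_device, and torch.distributed.destroy_process_group); the function name, default timeout, and import path from composer.utils are assumptions for illustration, not the exact train.py code.

import logging

import torch.distributed
from composer.utils import dist, get_device

log = logging.getLogger(__name__)


def init_dist_with_gloo_smoke_test(dist_timeout: float = 600.0) -> None:
    # Bring up a CPU (gloo) process group first and test a barrier; gloo
    # barriers respect the timeout, so a misconfigured cluster fails fast
    # instead of hanging on the first NCCL barrier.
    log.debug('Initializing dist with cpu...')
    dist.initialize_dist('cpu', timeout=dist_timeout)
    dist.barrier()

    # Tear down the gloo group before creating the real one.
    torch.distributed.destroy_process_group()

    # Now initialize with the actual training device (NCCL when on GPU).
    dist.initialize_dist(get_device(None), timeout=dist_timeout)
    dist.barrier()

After the revert, train.py skips this smoke test and calls dist.initialize_dist(get_device(None), timeout=dist_timeout) directly, as shown in the diff.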
