S3 Dirpath + Async Uploading Support for Default Checkpoints (NVIDIA#9045)

* Add S3 dirpath and asynchronous uploading support for basic checkpointing

Signed-off-by: Alexander Zhang <alxzhang@amazon.com>

* Update megatron_gpt_pretraining config to support S3 checkpointing

Signed-off-by: Alexander Zhang <alxzhang@amazon.com>

* Removed unused imports

Signed-off-by: Alexander Zhang <alxzhang@amazon.com>

* move s3_checkpoint_io into callbacks. consolidate checkpoint_file_utils into s3_utils.py

Signed-off-by: Alexander Zhang <alxzhang@amazon.com>

* Update setup() in nemo_model_checkpoint to broadcast checkpoint path and work with upstreamed implementation of removing unfinished checkpoints

Signed-off-by: Alexander Zhang <alxzhang@amazon.com>

* Add boto3 dependency for testing

Signed-off-by: Alexander Zhang <alxzhang@amazon.com>

* Remove redundant setup() in nemo_model_checkpoint

Signed-off-by: Alexander Zhang <alxzhang@amazon.com>

* Remove comment line from import

Signed-off-by: Alexander Zhang <alxzhang@amazon.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Removed explicit CRT calls since boto[crt] automatically uses CRT for file upload and download

Signed-off-by: Alexander Zhang <alxzhang@amazon.com>

* Style fix

Signed-off-by: Alexander Zhang <alxzhang@amazon.com>

* Remove unused s3transfer import

Signed-off-by: Alexander Zhang <alxzhang@amazon.com>

* add s3 prefix for s3-related checkpointing config

Signed-off-by: Alexander Zhang <alxzhang@amazon.com>

* dummy sleep function lowered from 1 to 0.01 seconds

Signed-off-by: Alexander Zhang <alxzhang@amazon.com>

* Remove local_rank checking for rank, and use is_global_rank_zero.

Signed-off-by: Alexander Zhang <alxzhang@amazon.com>

* Style fix

Signed-off-by: Alexander Zhang <alxzhang@amazon.com>

* Apply isort and black reformatting

Signed-off-by: alxzhang-amazon <alxzhang-amazon@users.noreply.github.com>

* add tenacity dependency

Signed-off-by: Alexander Zhang <alxzhang@amazon.com>

* Apply isort and black reformatting

Signed-off-by: alxzhang-amazon <alxzhang-amazon@users.noreply.github.com>

* Add filtering of unfinished checkpoint to non-s3 checkpoint resuming

Signed-off-by: Alexander Zhang <alxzhang@amazon.com>

* isort black reformatting

Signed-off-by: Alexander Zhang <alxzhang@amazon.com>

* Apply isort and black reformatting

Signed-off-by: alxzhang-amazon <alxzhang-amazon@users.noreply.github.com>

* Remove dependency requirement for checking if dirpath is an s3 path

Signed-off-by: Alexander Zhang <alxzhang@amazon.com>

* Make dependencies fully optional; allow exp_manager to optionally import S3Utils depending on whether dirpath is an S3 address or not

Signed-off-by: Alexander Zhang <alxzhang@amazon.com>

* Add rst doc for s3 checkpointing

Signed-off-by: Alexander Zhang <alxzhang@amazon.com>

* Remove unneeded assert

Signed-off-by: Alexander Zhang <alxzhang@amazon.com>

* Removed dependencies

Signed-off-by: Alexander Zhang <alxzhang@amazon.com>

* Apply isort and black reformatting

Signed-off-by: alxzhang-amazon <alxzhang-amazon@users.noreply.github.com>

* Updated documentation on async save to S3

Signed-off-by: Alexander Zhang <alxzhang@amazon.com>

* Apply isort and black reformatting

Signed-off-by: alxzhang-amazon <alxzhang-amazon@users.noreply.github.com>

* Update S3 checkpointing doc and fix visibility on website. Update the nlp_overrides DDP initializer to properly assign updated checkpoint io to base class.

Signed-off-by: Alexander Zhang <alxzhang@amazon.com>

* Apply isort and black reformatting

Signed-off-by: alxzhang-amazon <alxzhang-amazon@users.noreply.github.com>

* Slight fix in s3 checkpoint doc

Signed-off-by: Alexander Zhang <alxzhang@amazon.com>

---------

Signed-off-by: Alexander Zhang <alxzhang@amazon.com>
Signed-off-by: alxzhang-amazon <166076199+alxzhang-amazon@users.noreply.github.com>
Signed-off-by: alxzhang-amazon <alxzhang-amazon@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: alxzhang-amazon <alxzhang-amazon@users.noreply.github.com>
3 people authored and JesusPaz committed Jun 18, 2024
1 parent 86f5564 commit ff1e849
Showing 11 changed files with 887 additions and 55 deletions.
1 change: 1 addition & 0 deletions docs/source/common/intro.rst
@@ -11,3 +11,4 @@ The common collection contains things that could be used across all collections.
metrics
tokenizers
data
s3_checkpointing
96 changes: 96 additions & 0 deletions docs/source/common/s3_checkpointing.rst
@@ -0,0 +1,96 @@
****************
S3 Checkpointing
****************

S3CheckpointIO
==============

This checkpoint_io is used for saving and loading checkpoint files to and from S3.
Initializing this checkpoint_io requires that the dirpath be an S3 dirpath.

**Example Usage:**

.. code-block:: python

    async_checkpointing = self.cfg.s3_checkpointing.get('enable_async_checkpointing', False)
    chunk_size_MB = self.cfg.s3_checkpointing.get('chunk_size_MB')
    max_read_concurrency = self.cfg.s3_checkpointing.get('max_read_concurrency')
    max_write_concurrency = self.cfg.s3_checkpointing.get('max_write_concurrency')
    dirpath = self.cfg.exp_manager.checkpoint_callback_params.get('dirpath')

    s3_checkpoint_io = S3CheckpointIO(dirpath=dirpath, chunk_size_MB=chunk_size_MB, max_read_concurrency=max_read_concurrency, max_write_concurrency=max_write_concurrency, async_checkpointing=async_checkpointing)

    strategy = NLPDDPStrategy(
        no_ddp_communication_hook=True,
        checkpoint_io=s3_checkpoint_io,
        gradient_as_bucket_view=self.cfg.model.gradient_as_bucket_view,
        find_unused_parameters=False,
        nccl_communicator_config_path=self.cfg.model.get('nccl_communicator_config_path', None),
        sharp=self.cfg.model.get('sharp', False),
    )
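
For reference, a stand-alone round trip might look like the sketch below. The constructor arguments mirror the example above; the import path and the bucket/key are illustrative assumptions, and S3CheckpointIO is assumed to follow PyTorch Lightning's CheckpointIO interface (save_checkpoint / load_checkpoint / remove_checkpoint).

.. code-block:: python

    import torch
    from nemo.utils.callbacks.s3_checkpoint_io import S3CheckpointIO  # module path assumed, not confirmed by this commit

    s3_io = S3CheckpointIO(
        dirpath="s3://my-bucket/my-run/checkpoints",  # illustrative S3 dirpath
        chunk_size_MB=64,
        max_read_concurrency=15,
        max_write_concurrency=10,
        async_checkpointing=False,  # synchronous by default
    )

    # Save a checkpoint dict to S3, then load it back.
    ckpt_path = "s3://my-bucket/my-run/checkpoints/step0.ckpt"
    s3_io.save_checkpoint({"state_dict": {"w": torch.zeros(4)}}, ckpt_path)
    ckpt = s3_io.load_checkpoint(ckpt_path)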

**Config changes:**

.. code-block:: yaml

    checkpoint_callback_params:
      dirpath: s3://mstar-eks-dev-us-east-2/alxzhang/nemo123/1n/checkpoints
      ...
    s3_checkpointing:
      # write_concurrency * tp * pp * 1.15 (buffer) should be within 3500 S3 TPS limit per partition
      max_write_concurrency: 10
      # read_concurrency * tp * pp * 1.15 (buffer) should be within 5500 S3 TPS limit per partition
      max_read_concurrency: 15
      chunk_size_MB: 64
      # enables asynchronous checkpoint writing to S3
      enable_async_checkpointing: False
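
The inline comments above give a simple sizing rule for the concurrency settings. A minimal sketch of that arithmetic, with illustrative parallelism values, is:

.. code-block:: python

    # Illustrative check of the sizing rule in the config comments; numbers are examples only.
    tp, pp = 4, 2                      # tensor- and pipeline-parallel sizes
    max_write_concurrency = 10
    max_read_concurrency = 15
    buffer = 1.15                      # headroom factor from the comments above

    write_tps = max_write_concurrency * tp * pp * buffer   # 10 * 4 * 2 * 1.15 = 92.0
    read_tps = max_read_concurrency * tp * pp * buffer     # 15 * 4 * 2 * 1.15 = 138.0

    assert write_tps <= 3500, "reduce max_write_concurrency or parallelism"
    assert read_tps <= 5500, "reduce max_read_concurrency or parallelism"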

**Asynchronous**

By default, the S3CheckpointIO class acts synchronously.
The async feature currently does not check whether the previous async save has completed, so it is possible that an old checkpoint is removed even when the current save fails.
To mitigate this, the feature is meant to be used in conjunction with saving the top-k checkpoints (save_top_k).


S3Utils and Dependencies
========================

This utility class is used by S3CheckpointIO and the exp_manager to perform S3-related operations.
It has dependencies on:

1. boto3[crt]

2. s3fs==0.4.2

3. tenacity

If any of these are missing, this class can't be used.
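
As this commit makes the dependencies fully optional, S3Utils is only imported when the checkpoint dirpath is actually an S3 address. A minimal sketch of that pattern follows; the module path is an assumption for illustration, not confirmed by this commit.

.. code-block:: python

    # Sketch of the optional-import pattern described above; the module path is assumed.
    dirpath = "s3://my-bucket/my-run/checkpoints"  # e.g. exp_manager.checkpoint_callback_params.dirpath

    if dirpath is not None and dirpath.startswith("s3://"):
        # Importing S3Utils requires boto3[crt], s3fs==0.4.2, and tenacity.
        from nemo.utils.s3_utils import S3Utils  # assumed import path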



s3_dirpath_utils
================

Used to check whether a string is an S3 dirpath and to convert a bucket and key into an S3 dirpath.
It does not rely on the S3Utils utility class and can be used without any new dependencies.
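
A sketch of what such string helpers look like is shown below; the function names here are illustrative and may differ from the actual s3_dirpath_utils API.

.. code-block:: python

    # Illustrative helpers only; the real s3_dirpath_utils function names may differ.
    def is_s3_dirpath(path: str) -> bool:
        """Return True if the string looks like an S3 dirpath, e.g. s3://bucket/key."""
        return path is not None and path.startswith("s3://")

    def build_s3_dirpath(bucket: str, key: str) -> str:
        """Join a bucket and key into an S3 dirpath."""
        return f"s3://{bucket}/{key.lstrip('/')}"

    assert is_s3_dirpath("s3://my-bucket/ckpts")
    assert build_s3_dirpath("my-bucket", "ckpts") == "s3://my-bucket/ckpts"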


S3 Demands and ExpManager Details When Running at Scale
=======================================================

Typically, in the ExpManager, every rank looks for the checkpoint file to load from. At large scale, thousands of ranks can query S3 for dirpaths at once, which can cause slowdowns or throttling errors.

To avoid overloading S3 when resuming from a checkpoint, only rank 0 identifies the checkpoint path and finds the correct resumption file; rank 0 then broadcasts the checkpoint path to the other ranks.

.. code-block:: python

    trainer._checkpoint_connector = NeMoCheckpointConnector(trainer)

The NeMoModelCheckpoint setup() method automatically broadcasts the checkpoint path.

The NeMoCheckpointConnector is defined in exp_manager.py; when resuming training from an existing checkpoint, all ranks use the checkpoint path found and broadcast by rank 0.

The trainer._checkpoint_connector must be set before the ExpManager call, because the ExpManager updates the trainer's checkpoint connector.
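
For reference, the broadcast added to NeMoModelCheckpoint.setup() (shown in the nemo_model_checkpoint.py hunk later in this commit) boils down to the following condensed sketch:

.. code-block:: python

    # Condensed from the nemo_model_checkpoint.py change in this commit.
    def setup(self, trainer, pl_module, stage: str) -> None:
        ...
        # Only rank 0 resolved the checkpoint path against S3; sync it to every rank.
        trainer.ckpt_path = trainer.strategy.broadcast(trainer.ckpt_path)
        self.last_model_path = trainer.strategy.broadcast(self.last_model_path)
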
11 changes: 11 additions & 0 deletions examples/nlp/language_modeling/conf/megatron_gpt_config.yaml
@@ -24,6 +24,16 @@ trainer:
benchmark: False
enable_model_summary: False # default PTL callback for this does not support model parallelism, instead we log manually

# Used for S3 Checkpointing
s3_checkpointing:
# write_concurrency * tp * pp * 1.15 (buffer) should be within 3500 S3 TPS limit per partition
max_write_concurrency: 10
# read_concurrency * tp * pp * 1.15 (buffer) should be within 5500 S3 TPS limit per partition
max_read_concurrency: 15
chunk_size_MB: 64
# enables asynchronous checkpoint writing to S3 dirpath. the feature is experimental and currently does not check if the past save succeeded. Therefore, use in conjunction with save_top_k.
enable_async_checkpointing: False

exp_manager:
explicit_log_dir: null
exp_dir: null
@@ -45,6 +55,7 @@ exp_manager:
resume_from_checkpoint: ${model.resume_from_checkpoint}
create_checkpoint_callback: True
checkpoint_callback_params:
dirpath: null # to use S3 checkpointing, set the dirpath in format s3://bucket/key
monitor: val_loss
save_top_k: 10
mode: min
7 changes: 6 additions & 1 deletion nemo/collections/nlp/parts/nlp_overrides.py
@@ -195,7 +195,12 @@ def __init__(
raise ImportError(
"megatron-core was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt."
)
super().__init__(parallel_devices, cluster_environment, checkpoint_io, **kwargs)
super().__init__(
parallel_devices=parallel_devices,
cluster_environment=cluster_environment,
checkpoint_io=checkpoint_io,
**kwargs,
)

self.no_ddp_communication_hook = no_ddp_communication_hook
self.nccl_communicator_config_path = nccl_communicator_config_path
27 changes: 19 additions & 8 deletions nemo/lightning/_strategy_lib.py
@@ -16,12 +16,16 @@


class SharedStateDictProtocol(Protocol):
def sharded_state_dict(self, prefix=""):
...
def sharded_state_dict(self, prefix=""): ...


def init_parallel_ranks(
world_size: int, global_rank: int, local_rank: int, parallel_config: "ModelParallelConfig", seed=1234, fp8=False,
world_size: int,
global_rank: int,
local_rank: int,
parallel_config: "ModelParallelConfig",
seed=1234,
fp8=False,
) -> None:
"""
Initializes the parallel ranks for distributed training.
@@ -161,7 +165,7 @@ class GradScaler(torch.cuda.amp.GradScaler):

def __init__(
self,
init_scale=2.0 ** 16,
init_scale=2.0**16,
growth_factor=2.0,
backoff_factor=0.5,
growth_interval=2000,
@@ -193,7 +197,9 @@ def _maybe_opt_step(self, optimizer, optimizer_state, *args, **kwargs):

# Update across all model parallel instances.
torch.distributed.all_reduce(
found_inf, op=torch.distributed.ReduceOp.MAX, group=parallel_state.get_model_parallel_group(),
found_inf,
op=torch.distributed.ReduceOp.MAX,
group=parallel_state.get_model_parallel_group(),
)

if found_inf.item() == 0:
@@ -244,15 +250,19 @@ def update(self, new_scale=None):

# Update across all model parallel instances.
torch.distributed.all_reduce(
found_inf_combined, op=torch.distributed.ReduceOp.MAX, group=parallel_state.get_model_parallel_group(),
found_inf_combined,
op=torch.distributed.ReduceOp.MAX,
group=parallel_state.get_model_parallel_group(),
)

if len(found_infs) > 1:
for i in range(1, len(found_infs)):
found_inf = found_infs[i]
# Update across all model parallel instances.
torch.distributed.all_reduce(
found_inf, op=torch.distributed.ReduceOp.MAX, group=parallel_state.get_model_parallel_group(),
found_inf,
op=torch.distributed.ReduceOp.MAX,
group=parallel_state.get_model_parallel_group(),
)
found_inf_combined += found_inf

@@ -428,7 +438,8 @@ def get_safe(param_id):
for param_id, fp32_param in zip(state_group["params"], fp32_group)
]
for fp32_group, state_group in zip(
optimizer_state_dict["fp32_from_fp16_params"], optimizer_state_dict["optimizer"]["param_groups"],
optimizer_state_dict["fp32_from_fp16_params"],
optimizer_state_dict["optimizer"]["param_groups"],
)
]

9 changes: 7 additions & 2 deletions nemo/lightning/pytorch/plugins/data_sampler.py
@@ -101,11 +101,16 @@ def on_megatron_step_end(self, trainer: pl.Trainer, pl_module: pl.LightningModul
)

num_microbatch_calculator.update(
consumed_samples=consumed_samples, consistency_check=False,
consumed_samples=consumed_samples,
consistency_check=False,
)
current_global_batch_size = num_microbatch_calculator.current_global_batch_size
pl_module.log(
"global_batch_size", current_global_batch_size, prog_bar=True, rank_zero_only=True, batch_size=1,
"global_batch_size",
current_global_batch_size,
prog_bar=True,
rank_zero_only=True,
batch_size=1,
)
self.if_first_step = 1

Expand Down
10 changes: 8 additions & 2 deletions nemo/utils/callbacks/nemo_model_checkpoint.py
@@ -182,14 +182,20 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
super().load_state_dict(state_dict)
self._remove_invalid_entries_from_topk()

def setup(self, *args, **kwargs) -> None:
def setup(self, trainer, pl_module, stage: str) -> None:
if is_global_rank_zero():
logging.debug("Removing unfinished checkpoints if any...")
NeMoModelCheckpoint._remove_unfinished_checkpoints(self.dirpath)
# Ensure that all ranks continue with unfinished checkpoints removed
if torch.distributed.is_initialized():
torch.distributed.barrier()
super().setup(*args, **kwargs)
super().setup(trainer, pl_module, stage)
# When using S3 checkpointing, only Rank 0 has the checkpoint and model path set in exp_manager.
# Sync the values across all ranks to ensure consistency.
path = trainer.strategy.broadcast(trainer.ckpt_path)
trainer.ckpt_path = path

self.last_model_path = trainer.strategy.broadcast(self.last_model_path)

def on_save_checkpoint(self, trainer, pl_module, checkpoint):
output = super().on_save_checkpoint(trainer, pl_module, checkpoint)
