diff --git a/docs/source/common/intro.rst b/docs/source/common/intro.rst
index a89f1a480e5d..813783fc720b 100644
--- a/docs/source/common/intro.rst
+++ b/docs/source/common/intro.rst
@@ -11,3 +11,4 @@ The common collection contains things that could be used across all collections.
    metrics
    tokenizers
    data
+   s3_checkpointing
diff --git a/docs/source/common/s3_checkpointing.rst b/docs/source/common/s3_checkpointing.rst
new file mode 100644
index 000000000000..7a5c0bb09661
--- /dev/null
+++ b/docs/source/common/s3_checkpointing.rst
@@ -0,0 +1,96 @@
+****************
+S3 Checkpointing
+****************
+
+S3CheckpointIO
+==============
+
+This checkpoint_io is used for saving checkpoints to and loading them from S3.
+Initializing this checkpoint_io requires ``dirpath`` to be an S3 dirpath.
+
+**Example Usage:**
+
+.. code-block:: python
+
+    async_checkpointing = self.cfg.s3_checkpointing.get('enable_async_checkpointing', False)
+    chunk_size_MB = self.cfg.s3_checkpointing.get('chunk_size_MB')
+    max_read_concurrency = self.cfg.s3_checkpointing.get('max_read_concurrency')
+    max_write_concurrency = self.cfg.s3_checkpointing.get('max_write_concurrency')
+    dirpath = self.cfg.exp_manager.checkpoint_callback_params.get('dirpath')
+
+    s3_checkpoint_io = S3CheckpointIO(dirpath=dirpath, chunk_size_MB=chunk_size_MB, max_read_concurrency=max_read_concurrency, max_write_concurrency=max_write_concurrency, async_checkpointing=async_checkpointing)
+
+    strategy = NLPDDPStrategy(
+        no_ddp_communication_hook=True,
+        checkpoint_io=s3_checkpoint_io,
+        gradient_as_bucket_view=self.cfg.model.gradient_as_bucket_view,
+        find_unused_parameters=False,
+        nccl_communicator_config_path=self.cfg.model.get('nccl_communicator_config_path', None),
+        sharp=self.cfg.model.get('sharp', False),
+    )
+
+
+**Config changes:**
+
+.. code-block:: yaml
+
+    checkpoint_callback_params:
+      dirpath: s3://mstar-eks-dev-us-east-2/alxzhang/nemo123/1n/checkpoints
+
+    ...
+
+    s3_checkpointing:
+      # write_concurrency * tp * pp * 1.15 (buffer) should be within 3500 S3 TPS limit per partition
+      max_write_concurrency: 10
+      # read_concurrency * tp * pp * 1.15 (buffer) should be within 5500 S3 TPS limit per partition
+      max_read_concurrency: 15
+      chunk_size_MB: 64
+      # enables asynchronous checkpoint writing to S3
+      enable_async_checkpointing: False
+
+**Asynchronous**
+
+By default, the S3CheckpointIO class acts synchronously.
+The asynchronous mode currently does not check whether the previous async save has completed, so it is
+possible that an old checkpoint is removed even though the current save fails.
+To guard against this, use asynchronous checkpointing together with saving the top-k checkpoints (``save_top_k``).
+
+
+S3Utils and Dependencies
+========================
+
+This utility class is used by the S3CheckpointIO and the exp_manager to perform S3-related operations.
+It depends on:
+
+1. boto3[crt]
+
+2. s3fs==0.4.2
+
+3. tenacity
+
+If any of these are missing, this class cannot be used.
+
+
+
+s3_dirpath_utils
+================
+
+Operates on plain strings: it checks whether a path is an S3 dirpath and builds an S3 dirpath from a bucket and key.
+It has no reliance on the S3Utils utility class and can be used without any new dependencies.
+
+
+S3 Demands and ExpManager Details When Running at Scale
+========================================================
+
+Typically, in the ExpManager, every rank looks for the checkpoint file to load from. At large scale, thousands of ranks can query S3 for dirpaths at the same time, which can cause slowdowns or throttling errors.
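+For context, the lookup that every rank would otherwise issue against S3 looks roughly like the sketch
+below (a hypothetical illustration built on the S3Utils helpers described above; the bucket and prefix
+are placeholders). The next paragraph describes how NeMo limits this work to rank 0.
+
+.. code-block:: python
+
+    from nemo.utils.s3_utils import S3Utils
+
+    checkpoint_dir = 's3://example-bucket/experiment/checkpoints/'  # hypothetical dirpath
+
+    # One or more LIST requests per rank. With thousands of ranks resuming at once,
+    # this alone can exceed the per-prefix request limits that S3 enforces.
+    if S3Utils.s3_path_exists(checkpoint_dir, match_directory=True):
+        last_ckpts = S3Utils.find_files_with_suffix(checkpoint_dir, suffix='last.ckpt', return_key_only=False)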
+ +To avoid overloading S3 when resuming from a checkpoint only rank 0 needs to identify the checkpoint path and find the correct resumption file. Rank 0 will broadcast the checkpoint path to the other ranks. + +.. code-block:: bash + + trainer._checkpoint_connector = NeMoCheckpointConnector(trainer) + +The NeMoModelCheckpoint setup() method will automatically broadcast the checkpoint path. + +The NeMoCheckpointConnector is defined in the exp_manager.py file, and uses the broadcasted checkpoint path founds by rank 0 on all ranks when resuming training from an existing checkpoint. + +The setting of the trainer._checkpoint_connector needs to happen before the ExpManager call as the ExpManager updates the trainer's checkpoint connector. diff --git a/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml b/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml index 1f63f7742ea0..ccdddcbc2272 100755 --- a/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml +++ b/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml @@ -24,6 +24,16 @@ trainer: benchmark: False enable_model_summary: False # default PTL callback for this does not support model parallelism, instead we log manually +# Used for S3 Checkpointing +s3_checkpointing: + # write_concurrency * tp * pp * 1.15 (buffer) should be within 3500 S3 TPS limit per partition + max_write_concurrency: 10 + # read_concurrency * tp * pp * 1.15 (buffer) should be within 5500 S3 TPS limit per partition + max_read_concurrency: 15 + chunk_size_MB: 64 + # enables asynchronous checkpoint writing to S3 dirpath. the feature is experimental and currently does not check if the past save succeeded. Therefore, use in conjunction with save_top_k. + enable_async_checkpointing: False + exp_manager: explicit_log_dir: null exp_dir: null @@ -45,6 +55,7 @@ exp_manager: resume_from_checkpoint: ${model.resume_from_checkpoint} create_checkpoint_callback: True checkpoint_callback_params: + dirpath: null # to use S3 checkpointing, set the dirpath in format s3://bucket/key monitor: val_loss save_top_k: 10 mode: min diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py index 8ca010e59f70..6b356539aba9 100644 --- a/nemo/collections/nlp/parts/nlp_overrides.py +++ b/nemo/collections/nlp/parts/nlp_overrides.py @@ -195,7 +195,12 @@ def __init__( raise ImportError( "megatron-core was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt." ) - super().__init__(parallel_devices, cluster_environment, checkpoint_io, **kwargs) + super().__init__( + parallel_devices=parallel_devices, + cluster_environment=cluster_environment, + checkpoint_io=checkpoint_io, + **kwargs, + ) self.no_ddp_communication_hook = no_ddp_communication_hook self.nccl_communicator_config_path = nccl_communicator_config_path diff --git a/nemo/lightning/_strategy_lib.py b/nemo/lightning/_strategy_lib.py index cd8e38af12f2..9dd36ba54dbe 100644 --- a/nemo/lightning/_strategy_lib.py +++ b/nemo/lightning/_strategy_lib.py @@ -16,12 +16,16 @@ class SharedStateDictProtocol(Protocol): - def sharded_state_dict(self, prefix=""): - ... + def sharded_state_dict(self, prefix=""): ... 
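# Illustrative note (not part of the patch): typing.Protocol matching is structural, so any object
# that exposes a compatible ``sharded_state_dict`` method satisfies SharedStateDictProtocol without
# inheriting from it, e.g.:
#
#     class _Module:
#         def sharded_state_dict(self, prefix=""):
#             return {}
#
#     module: "SharedStateDictProtocol" = _Module()  # accepted by static type checkers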
def init_parallel_ranks( - world_size: int, global_rank: int, local_rank: int, parallel_config: "ModelParallelConfig", seed=1234, fp8=False, + world_size: int, + global_rank: int, + local_rank: int, + parallel_config: "ModelParallelConfig", + seed=1234, + fp8=False, ) -> None: """ Initializes the parallel ranks for distributed training. @@ -161,7 +165,7 @@ class GradScaler(torch.cuda.amp.GradScaler): def __init__( self, - init_scale=2.0 ** 16, + init_scale=2.0**16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, @@ -193,7 +197,9 @@ def _maybe_opt_step(self, optimizer, optimizer_state, *args, **kwargs): # Update across all model parallel instances. torch.distributed.all_reduce( - found_inf, op=torch.distributed.ReduceOp.MAX, group=parallel_state.get_model_parallel_group(), + found_inf, + op=torch.distributed.ReduceOp.MAX, + group=parallel_state.get_model_parallel_group(), ) if found_inf.item() == 0: @@ -244,7 +250,9 @@ def update(self, new_scale=None): # Update across all model parallel instances. torch.distributed.all_reduce( - found_inf_combined, op=torch.distributed.ReduceOp.MAX, group=parallel_state.get_model_parallel_group(), + found_inf_combined, + op=torch.distributed.ReduceOp.MAX, + group=parallel_state.get_model_parallel_group(), ) if len(found_infs) > 1: @@ -252,7 +260,9 @@ def update(self, new_scale=None): found_inf = found_infs[i] # Update across all model parallel instances. torch.distributed.all_reduce( - found_inf, op=torch.distributed.ReduceOp.MAX, group=parallel_state.get_model_parallel_group(), + found_inf, + op=torch.distributed.ReduceOp.MAX, + group=parallel_state.get_model_parallel_group(), ) found_inf_combined += found_inf @@ -428,7 +438,8 @@ def get_safe(param_id): for param_id, fp32_param in zip(state_group["params"], fp32_group) ] for fp32_group, state_group in zip( - optimizer_state_dict["fp32_from_fp16_params"], optimizer_state_dict["optimizer"]["param_groups"], + optimizer_state_dict["fp32_from_fp16_params"], + optimizer_state_dict["optimizer"]["param_groups"], ) ] diff --git a/nemo/lightning/pytorch/plugins/data_sampler.py b/nemo/lightning/pytorch/plugins/data_sampler.py index 1fca29ce01d3..470b7f3984f2 100644 --- a/nemo/lightning/pytorch/plugins/data_sampler.py +++ b/nemo/lightning/pytorch/plugins/data_sampler.py @@ -101,11 +101,16 @@ def on_megatron_step_end(self, trainer: pl.Trainer, pl_module: pl.LightningModul ) num_microbatch_calculator.update( - consumed_samples=consumed_samples, consistency_check=False, + consumed_samples=consumed_samples, + consistency_check=False, ) current_global_batch_size = num_microbatch_calculator.current_global_batch_size pl_module.log( - "global_batch_size", current_global_batch_size, prog_bar=True, rank_zero_only=True, batch_size=1, + "global_batch_size", + current_global_batch_size, + prog_bar=True, + rank_zero_only=True, + batch_size=1, ) self.if_first_step = 1 diff --git a/nemo/utils/callbacks/nemo_model_checkpoint.py b/nemo/utils/callbacks/nemo_model_checkpoint.py index e1d1f2e94586..9893b0806ac2 100644 --- a/nemo/utils/callbacks/nemo_model_checkpoint.py +++ b/nemo/utils/callbacks/nemo_model_checkpoint.py @@ -182,14 +182,20 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: super().load_state_dict(state_dict) self._remove_invalid_entries_from_topk() - def setup(self, *args, **kwargs) -> None: + def setup(self, trainer, pl_module, stage: str) -> None: if is_global_rank_zero(): logging.debug("Removing unfinished checkpoints if any...") 
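            # Only global rank 0 scans self.dirpath and removes any checkpoints marked as unfinished;
            # the barrier a few lines below holds the other ranks until that cleanup completes.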
NeMoModelCheckpoint._remove_unfinished_checkpoints(self.dirpath) # Ensure that all ranks continue with unfinished checkpoints removed if torch.distributed.is_initialized(): torch.distributed.barrier() - super().setup(*args, **kwargs) + super().setup(trainer, pl_module, stage) + # When using S3 checkpointing, only Rank 0 has the checkpoint and model path set in exp_manager. + # Sync the values across all ranks to ensure consistency. + path = trainer.strategy.broadcast(trainer.ckpt_path) + trainer.ckpt_path = path + + self.last_model_path = trainer.strategy.broadcast(self.last_model_path) def on_save_checkpoint(self, trainer, pl_module, checkpoint): output = super().on_save_checkpoint(trainer, pl_module, checkpoint) diff --git a/nemo/utils/callbacks/s3_checkpoint_io.py b/nemo/utils/callbacks/s3_checkpoint_io.py new file mode 100644 index 000000000000..4ded98a1b610 --- /dev/null +++ b/nemo/utils/callbacks/s3_checkpoint_io.py @@ -0,0 +1,275 @@ +import os +import time +from concurrent.futures import ProcessPoolExecutor +from io import BytesIO +from multiprocessing import get_start_method +from pathlib import Path +from tempfile import NamedTemporaryFile +from typing import Any, Callable, Dict, Optional, Union + +import torch +from lightning_fabric.plugins.io.checkpoint_io import CheckpointIO + +from nemo.utils import logging +from nemo.utils.s3_utils import ( + DEFAULT_CHUNK_SIZE_MB, + DEFAULT_MAX_READ_CONCURRENCY, + DEFAULT_MAX_WRITE_CONCURRENCY, + SHARED_MEM_DIR, + S3Utils, +) + + +class S3CheckpointIO(CheckpointIO): + """A custom S3CheckpointIO module that supports checkpoint reading/writing with s3 when filepath + is a s3 url. + """ + + def __init__( + self, + dirpath: str, + chunk_size_MB=DEFAULT_CHUNK_SIZE_MB, + max_read_concurrency=DEFAULT_MAX_READ_CONCURRENCY, + max_write_concurrency=DEFAULT_MAX_WRITE_CONCURRENCY, + async_checkpointing=False, + ): + """ + Initialize the transfer configuration with custom values. + + This method overrides the default TransferConfig values in boto3. + See https://boto3.amazonaws.com/v1/documentation/api/latest/_modules/boto3/s3/transfer.html#TransferConfig + + Args: + chunk_size_MB (int, optional): The size of chunks to use when transferring files. + Default is 64 (MB). + max_read_concurrency (int, optional): The maximum number of threads that will be making + requests to perform a download. Default is 15. + max_write_concurrency (int, optional): The maximum number of threads that will be making + requests to perform an upload. Default is 10. + async_checkpointing (bool, optional): Uses a ProcessPoolExecutor to do the main saving logic. + This feature should be used with save_top_k as it's possible a previous checkpoint is removed while + the current checkpoint write fails. + """ + if not S3Utils.is_s3_url(dirpath): + raise AssertionError( + f"Error attempting to initialize an S3CheckpointIO when {dirpath} is not an S3 url. Please use TorchCheckpointIO when using a non-S3 dirpath." + ) + + self.chunk_size_MB = chunk_size_MB + self.max_read_concurrency = max_read_concurrency + self.max_write_concurrency = max_write_concurrency + self._async_checkpointing = async_checkpointing + ''' + When using shared memory, we create a temporary file to hold the checkpoint before uploading to S3. + This list will track those temporary files, and clean up any leaked files that are still around during teardown. 
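        Entries are dropped from the list once their upload future completes (the upload itself deletes
        the file when remove_file=True); teardown() deletes any files that are still left over.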
+ ''' + self._temp_files = [] + + if self.async_checkpointing: + # create an executor that will asynchronously run functions + self._executor = ProcessPoolExecutor(max_workers=1) if self.async_checkpointing else None + + # Eager creating a subprocess now so that forked subprocess does not inherit cuda context from parent + if get_start_method() == 'fork' and torch.cuda.is_initialized() is True: + raise Exception( + f'torch.cuda should not be initialized when checkpointing subprocess is created by fork method' + ) + logging.info(f'Creating asynchronous checkpointing subprocess') + future = self._executor.submit(dummy_func) + try: + future.result() + logging.info(f'Asynchronous heckpointing subprocess created successfully') + except Exception as e: + logging.error(f'Failed to create asynchronous checkpointing subprocess, exception: {e}') + raise e + self._futures = [] + + super().__init__() + + @property + def async_checkpointing(self): + return self._async_checkpointing + + def _serialize_checkpoint_to_shm(self, checkpoint: Dict, path: str) -> str: + """ + Returns: + filename of the temporary file in shared memory. + """ + start_time = time.perf_counter() + tempfile = NamedTemporaryFile(dir=SHARED_MEM_DIR, delete=False) + torch.save(checkpoint, tempfile) + logging.info( + f'Time elapsed saving checkpoint dict to {tempfile.name} for {path}: {(time.perf_counter() - start_time):.2f} seconds, rank {torch.distributed.get_rank()}' + ) + del checkpoint + return tempfile.name + + def _serialize_checkpoint_to_bytes(self, checkpoint: Dict, path: str) -> BytesIO: + """ + Returns: + The bytestring of the checkpoint. + """ + ss = time.perf_counter() + bytes = BytesIO() + torch.save(checkpoint, bytes) + tt = time.perf_counter() - ss + logging.info( + f'Time elapsed saving checkpoint dict to bytes for {path}: {tt:.2f} seconds, rank {torch.distributed.get_rank()}' + ) + del checkpoint + return bytes + + def _check_uploading_results_so_far(self): + """ + self._future is a list of tuples of form (future, destination path, source path) + This function checks the result of all the futures, and updates the self._futures list appropriately. + It also updates the list of self._temp_files, which is used to clean up leaked temporary files in SHARED_MEM during teardown. + """ + if not self._futures: + return + start_time = time.perf_counter() + done_futures = [] + in_progress_futures = [] + for item in self._futures: + if item[0].done(): + done_futures.append(item) + else: + in_progress_futures.append(item) + + for item in done_futures: + try: + item[0].result() + except Exception as e: + logging.error(f'Failed to upload {item[2]} to {item[1]}, exception: {e}') + raise e + # If the future is complete, we can remove the temp file since we choose to clear the temp file when uploading. + try: + self._temp_files.remove(item[2]) + except: + pass # When not using shared memory, we do not append anything to the temp_files list, so remove will do nothing. + self._futures = in_progress_futures + logging.debug( + f'Time elapsed checking uploading future results: {(time.perf_counter() - start_time):.2f} seconds' + ) + + def save_checkpoint( + self, checkpoint: Dict[str, Any], path: Union[str, Path], storage_options: Optional[Any] = None + ) -> None: + # if we have a shared memory directory, we can serialize as a file to shared memory instead of as bytes. 
+ if os.path.exists(SHARED_MEM_DIR): + localfile = self._serialize_checkpoint_to_shm(checkpoint, path) + self._temp_files.append(localfile) + saved_as_file = True + else: + bytes = self._serialize_checkpoint_to_bytes(checkpoint, path) + saved_as_file = False + + if self.async_checkpointing: + self._check_uploading_results_so_far() + logging.info(f'Uploading checkpoint to {path} in asynchronous mode, rank {torch.distributed.get_rank()}') + if saved_as_file: + future = self._executor.submit( + _upload_file_to_s3, localfile, path, self.chunk_size_MB, self.max_write_concurrency, True + ) + self._futures.append((future, path, localfile)) + else: + future = self._executor.submit( + _upload_bytes_to_s3, bytes, path, self.chunk_size_MB, self.max_write_concurrency + ) + self._futures.append((future, path, 'bytes')) + else: + logging.info(f'Uploading checkpoint to {path} in synchronous mode, rank {torch.distributed.get_rank()}') + if saved_as_file: + _upload_file_to_s3(localfile, path, self.chunk_size_MB, self.max_write_concurrency, True) + self._temp_files.remove(localfile) + else: + _upload_bytes_to_s3(bytes, path, self.chunk_size_MB, self.max_write_concurrency) + + def load_checkpoint( + self, path: Union[str, Path], map_location: Optional[Callable] = lambda storage, loc: storage + ) -> Dict[str, Any]: + if os.path.exists(SHARED_MEM_DIR): + with NamedTemporaryFile(dir=SHARED_MEM_DIR, delete=True) as tempfile: + logging.info( + f'Loading checkpoint {path} into a temp file in shared memory {tempfile.name}, rank {torch.distributed.get_rank()}' + ) + S3Utils.download_s3_file_to_path( + s3_path=path, + file_path=tempfile.name, + chunk_size_MB=self.chunk_size_MB, + max_concurrency=self.max_read_concurrency, + ) + checkpoint = torch.load(tempfile.name) + else: + file_stream: BytesIO = S3Utils.download_s3_file_to_stream( + s3_path=path, chunk_size_MB=self.chunk_size_MB, max_concurrency=self.max_read_concurrency + ) + checkpoint = torch.load(file_stream) + return checkpoint + + def remove_checkpoint(self, path: Union[str, Path]) -> None: + if S3Utils.is_s3_url(path): + S3Utils.remove_object(path) + else: + super().remove_checkpoint(path) + + def teardown(self) -> None: + # this ensure we wait for final checkpoint to finish uploading at train end. + rank = torch.distributed.get_rank() + if self.async_checkpointing: + logging.info(f'Entering teardown, waiting for all jobs to finish, rank {rank}') + start_time = time.perf_counter() + self._executor.shutdown(wait=True) + logging.info(f'executor shut down after {(time.perf_counter() - start_time):.2f} seconds, rank {rank}') + + ''' + this will be non-empty at the end of training if using asynchronous uploading since the futures are not processed with _check_uploading_results_so_far. + therefore, we check that the path exists first before trying to delete. + ''' + if self._temp_files: + for tfile in self._temp_files: + if os.path.exists(tfile): + try: + os.remove(tfile) + except Exception as e: + logging.info(f"Error occurred while deleting file {tfile}: {e}") + + +def _clean_up_conflicting_checkpoint(filepath: str) -> None: + ''' + before saving to s3, clean up any existing object with the same prefix megatron_gpt+step_count + e.g. 
before we save "megatron_gpt--step=1400-validation_loss=6.32-consumed_samples=55920.0-last.ckpt" + we need to clean up "megatron_gpt--step=1400-validation_loss=xxx-consumed_samples=yyy-last.ckpt" + so that in case later we need to resume from step 1400, it has a single checkpoint file at step 1400 + ''' + + if S3Utils.is_s3_url(filepath): + prefix_with_step = S3Utils.parse_prefix_with_step(filepath) + logging.info(f'Looking for conflicting checkpoint under prefix {prefix_with_step}') + + conflict_last_ckpts = S3Utils.find_files_with_suffix( + base_path=prefix_with_step, suffix='last.ckpt', return_key_only=False + ) + for last_ckpt in conflict_last_ckpts: + logging.info(f'Cleaning up conflicting last ckpt {last_ckpt} before saving {filepath}') + S3Utils.remove_object(last_ckpt) + + +def _upload_file_to_s3(localfile, path, chunk_size_MB, max_write_concurrency, remove_file): + try: + _clean_up_conflicting_checkpoint(path) + S3Utils.upload_file(localfile, path, chunk_size_MB, max_write_concurrency, remove_file) + except Exception as e: + raise e + + +def _upload_bytes_to_s3(bytes, path, chunk_size_MB, max_write_concurrency): + try: + _clean_up_conflicting_checkpoint(path) + S3Utils.upload_file_stream_to_s3(bytes, path, chunk_size_MB, max_write_concurrency) + except Exception as e: + raise e + + +def dummy_func(): + time.sleep(0.01) diff --git a/nemo/utils/exp_manager.py b/nemo/utils/exp_manager.py index 9e8b55eade1f..44896fc51c89 100644 --- a/nemo/utils/exp_manager.py +++ b/nemo/utils/exp_manager.py @@ -35,6 +35,8 @@ from pytorch_lightning.loggers import MLFlowLogger, NeptuneLogger, TensorBoardLogger, WandbLogger from pytorch_lightning.loops import _TrainingEpochLoop from pytorch_lightning.strategies.ddp import DDPStrategy +from pytorch_lightning.trainer.connectors.checkpoint_connector import _CheckpointConnector + from nemo.collections.common.callbacks import EMA from nemo.constants import NEMO_ENV_VARNAME_TESTING, NEMO_ENV_VARNAME_VERSION @@ -606,55 +608,93 @@ def check_resume( if not log_dir: raise ValueError(f"Resuming requires the log_dir {log_dir} to be passed to exp_manager") + # is_s3_url from here has no dependency requirements + from nemo.utils.s3_dirpath_utils import is_s3_url + + try: + # when using an s3 dirpath, we rely on optional dependencies in the S3Utils class. 
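        # The S3 code path needs boto3[crt], s3fs and tenacity (see the S3Utils documentation);
        # importing lazily here keeps those packages optional for runs that do not use an S3 dirpath.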
+ if dirpath is not None and is_s3_url(dirpath): + from nemo.utils.s3_utils import S3Utils + except ImportError as err: + return False, "Detected S3 dirpath while missing required dependencies.\n{}\n".format( + err.output.decode("utf-8") + ) + checkpoint = None if resume_from_checkpoint: checkpoint = resume_from_checkpoint if resume_if_exists: - # Use /checkpoints/ unless `dirpath` is set - checkpoint_dir = Path(dirpath) if dirpath else Path(Path(log_dir) / "checkpoints") - - # when using distributed checkpointing, checkpoint_dir is a directory of directories - # we check for this here - dist_checkpoints = [d for d in list(checkpoint_dir.glob("*")) if d.is_dir()] - end_dist_checkpoints = [d for d in dist_checkpoints if d.match("*end")] - last_dist_checkpoints = [d for d in dist_checkpoints if d.match("*last")] - - end_checkpoints = end_dist_checkpoints if end_dist_checkpoints else list(checkpoint_dir.rglob("*end.ckpt")) - end_checkpoints = _filter_out_unfinished_checkpoints(end_checkpoints) - last_checkpoints = last_dist_checkpoints if last_dist_checkpoints else list(checkpoint_dir.rglob("*last.ckpt")) - last_checkpoints = _filter_out_unfinished_checkpoints(last_checkpoints) - - if not checkpoint_dir.exists() or (not len(end_checkpoints) > 0 and not len(last_checkpoints) > 0): - if resume_ignore_no_checkpoint: - warn = f"There were no checkpoints found in checkpoint_dir or no checkpoint folder at checkpoint_dir :{checkpoint_dir}. " - if checkpoint is None: - warn += "Training from scratch." - elif checkpoint == resume_from_checkpoint: - warn += f"Training from {resume_from_checkpoint}." - logging.warning(warn) - else: - raise NotFoundError( - f"There were no checkpoints found in checkpoint_dir or no checkpoint folder at checkpoint_dir :{checkpoint_dir}. Cannot resume." + ''' + attach valid checkpoint path to trainer if current rank is rank zero of any data parallel groups + this limit to only global rank 0 process calling s3, instead of all processes calling s3 + ''' + + # If we are using S3 checkpointing, we want check_resume to only execute on a single rank to avoid throttling S3. + if is_global_rank_zero() or not is_s3_url(dirpath): + checkpoint_dir_exists = False + if is_s3_url(dirpath): + checkpoint_dir = dirpath + checkpoint_dir_exists = S3Utils.s3_path_exists(checkpoint_dir, match_directory=True) + + if checkpoint_dir_exists: + # max number of last.ckpt files: save_last_k_checkpoints * tp * pp = 5*8*40. 
If optim states is saved distributedly, multiply by dp_size + all_keys = S3Utils.find_files_with_suffix(checkpoint_dir, suffix=None, return_key_only=False) + end_checkpoints = [k for k in all_keys if k.endswith('end.ckpt')] + last_checkpoints = [k for k in all_keys if k.endswith('last.ckpt')] + else: + end_checkpoints = [] + last_checkpoints = [] + else: # default non-s3 implementation + # Use /checkpoints/ unless `dirpath` is set + checkpoint_dir = Path(dirpath) if dirpath else Path(Path(log_dir) / "checkpoints") + checkpoint_dir_exists = checkpoint_dir.exists() + + # when using distributed checkpointing, checkpoint_dir is a directory of directories + # we check for this here + dist_checkpoints = [d for d in list(checkpoint_dir.glob("*")) if d.is_dir()] + end_dist_checkpoints = [d for d in dist_checkpoints if d.match("*end")] + last_dist_checkpoints = [d for d in dist_checkpoints if d.match("*last")] + + end_checkpoints = ( + end_dist_checkpoints if end_dist_checkpoints else list(checkpoint_dir.rglob("*end.ckpt")) ) - elif len(end_checkpoints) > 0: - if resume_past_end: - if len(end_checkpoints) > 1: - if 'mp_rank' in str(end_checkpoints[0]): - checkpoint = end_checkpoints[0] - else: - raise ValueError(f"Multiple checkpoints {end_checkpoints} that matches *end.ckpt.") - else: - raise ValueError( - f"Found {end_checkpoints[0]} indicating that the last training run has already completed." + end_checkpoints = _filter_out_unfinished_checkpoints(end_checkpoints) + last_checkpoints = ( + last_dist_checkpoints if last_dist_checkpoints else list(checkpoint_dir.rglob("*last.ckpt")) ) - elif len(last_checkpoints) > 1: - if any([s for s in ['mp_rank', 'tp_rank', 'fsdp_shard'] if s in str(last_checkpoints[0])]): - checkpoint = last_checkpoints[0] - checkpoint = uninject_model_parallel_rank(checkpoint) + last_checkpoints = _filter_out_unfinished_checkpoints(last_checkpoints) + + if not checkpoint_dir_exists or (not len(end_checkpoints) > 0 and not len(last_checkpoints) > 0): + if resume_ignore_no_checkpoint: + warn = f"There were no checkpoints found in checkpoint_dir or no checkpoint folder at checkpoint_dir :{checkpoint_dir}. " + if checkpoint is None: + warn += "Training from scratch." + elif checkpoint == resume_from_checkpoint: + warn += f"Training from {resume_from_checkpoint}." + logging.warning(warn) + else: + raise NotFoundError( + f"There were no checkpoints found in checkpoint_dir or no checkpoint folder at checkpoint_dir :{checkpoint_dir}. Cannot resume." + ) + elif len(end_checkpoints) > 0: + if resume_past_end: + if len(end_checkpoints) > 1: + if 'mp_rank' in str(end_checkpoints[0]): + checkpoint = end_checkpoints[0] + else: + raise ValueError(f"Multiple checkpoints {end_checkpoints} that matches *end.ckpt.") + else: + raise ValueError( + f"Found {end_checkpoints[0]} indicating that the last training run has already completed." 
+ ) + elif len(last_checkpoints) > 1: + if any([s for s in ['mp_rank', 'tp_rank', 'fsdp_shard'] if s in str(last_checkpoints[0])]): + checkpoint = last_checkpoints[0] + checkpoint = uninject_model_parallel_rank(checkpoint) + else: + raise ValueError(f"Multiple checkpoints {last_checkpoints} that matches *last.ckpt.") else: - raise ValueError(f"Multiple checkpoints {last_checkpoints} that matches *last.ckpt.") - else: - checkpoint = last_checkpoints[0] + checkpoint = last_checkpoints[0] # PTL 2.0 supports ckpt_path instead of resume_from_checkpoint as the trainer flag if checkpoint is not None: @@ -914,6 +954,24 @@ def configure_loggers( trainer._logger_connector.configure_logger(logger_list) +class NeMoCheckpointConnector(_CheckpointConnector): + """ + Wrapper around Lightning's _CheckpointConnector to use broadcasted checkpoint path in + distributed training settings to pre-load checkpoint. + """ + + def resume_start(self, checkpoint_path=None) -> None: + checkpoint_path = self.trainer.ckpt_path + if checkpoint_path is not None: + logging.info(f'Resuming from checkpoint {checkpoint_path}, rank {torch.distributed.get_rank()}') + start_time = time.perf_counter() + super().resume_start(checkpoint_path) + if checkpoint_path is not None: + logging.info( + f'Time elapsed loading checkpoint/optimizer states: {(time.perf_counter() - start_time):.2f} seconds, rank {torch.distributed.get_rank()}' + ) + + def configure_checkpointing( trainer: 'pytorch_lightning.Trainer', log_dir: Path, diff --git a/nemo/utils/s3_dirpath_utils.py b/nemo/utils/s3_dirpath_utils.py new file mode 100644 index 000000000000..fd66115d4e5d --- /dev/null +++ b/nemo/utils/s3_dirpath_utils.py @@ -0,0 +1,22 @@ +from pathlib import Path +from typing import Optional + +S3_PATH_PREFIX = 's3://' + + +def build_s3_url(bucket, key) -> str: + """ + This function constructs an s3 address given a bucket and key. + It has no reliance on any S3-related dependencies as the file pre-defines the S3 path prefix. + """ + return f'{S3_PATH_PREFIX}{bucket}/{key}' + + +def is_s3_url(path: Optional[str]) -> bool: + """ + This function checks if a path is an S3 url. + It has no reliance on any S3-related dependencies as the file pre-defines the S3 path prefix. + """ + if isinstance(path, Path): + path = str(path) + return path is not None and path.strip().startswith(S3_PATH_PREFIX) diff --git a/nemo/utils/s3_utils.py b/nemo/utils/s3_utils.py new file mode 100644 index 000000000000..3435a603b05d --- /dev/null +++ b/nemo/utils/s3_utils.py @@ -0,0 +1,342 @@ +import os +import re +import time +from io import BytesIO +from pathlib import Path +from typing import List, Optional, Tuple + +import boto3 +import botocore +from boto3.s3.transfer import TransferConfig +from botocore.exceptions import ClientError +from tenacity import before_sleep_log, retry, retry_if_exception, stop_after_delay, wait_exponential + +from nemo.utils import logging +from nemo.utils.s3_dirpath_utils import build_s3_url, is_s3_url + +try: + import awscrt + import s3transfer.crt + + crt_available = True +except ImportError as e: + crt_available = False + +MB = 1024**2 +GB = 1024**3 + +SHARED_MEM_DIR = '/dev/shm' +DEFAULT_CHUNK_SIZE_MB = 64 +DEFAULT_MAX_READ_CONCURRENCY = 15 +DEFAULT_MAX_WRITE_CONCURRENCY = 10 + + +class S3Utils: + """ + Utility class for interacting with S3. Handles downloading and uploading to S3, and parsing/formatting S3 urls. + """ + + ''' + Avoid caching boto3 client or resource as a class variable as it gets executed once during class construction. 
+ When the security token expires, the client or resouece will be no longer valid. + Create a new resource as needed. To avoid multithreading errors, use different session for each thread. + ''' + + @staticmethod + def s3_path_exists(s3_path: str, match_directory: bool = False) -> bool: + """ + :s3_path: the path + :match_directory: if the content is known to be a directory then set it to `True`. Since s3 isn't a file system, paths are funky and the concept of folders doesn't really exist. + """ + bucket_name, prefix = S3Utils.parse_s3_url(s3_path) + if not prefix: + return False + + s3 = S3Utils._get_s3_resource() + # bucket = s3.Bucket(bucket_name) + s3_client = s3.meta.client + + try: + objs = s3_client.list_objects_v2(Bucket=bucket_name, MaxKeys=1, Prefix=prefix).get('Contents', []) + except s3_client.exceptions.NoSuchBucket: + return False + + if prefix == '': # bucket only + return True + + return len(objs) > 0 and (match_directory or objs[0]['Key'].startswith(prefix)) + + @staticmethod + def remove_object(s3_path: str) -> None: + s3_client = S3Utils._get_s3_resource(get_client=True) + bucket, key = S3Utils.parse_s3_url(s3_path) + s3_client.delete_object(Bucket=bucket, Key=key) + + @staticmethod + def download_s3_file_to_stream( + s3_path: str, chunk_size_MB: int = DEFAULT_CHUNK_SIZE_MB, max_concurrency: int = DEFAULT_MAX_READ_CONCURRENCY + ) -> BytesIO: + bytes_buffer = BytesIO() + + s3_client = S3Utils._get_s3_resource(get_client=True) + bucket, key = S3Utils.parse_s3_url(s3_path) + chunk_size = chunk_size_MB * MB + config = TransferConfig(multipart_chunksize=chunk_size, max_concurrency=max_concurrency) + + start_time = time.perf_counter() + _download_fileobj_with_retry(s3_client, bucket, key, bytes_buffer, config) + logging.info( + f'Time elapsed downloading {s3_path} to file stream with chunk_size={chunk_size_MB}MB ' + f'and max_concurrency={max_concurrency}: {(time.perf_counter() - start_time):.2f} seconds' + ) + + bytes_buffer.seek(0) + return bytes_buffer + + @staticmethod + def download_s3_file_to_path( + s3_path: str, + file_path: str, + chunk_size_MB: int = DEFAULT_CHUNK_SIZE_MB, + max_concurrency: int = DEFAULT_MAX_READ_CONCURRENCY, + ) -> None: + s3_client = S3Utils._get_s3_resource(get_client=True) + bucket, key = S3Utils.parse_s3_url(s3_path) + chunk_size = chunk_size_MB * MB + config = TransferConfig(multipart_chunksize=chunk_size, max_concurrency=max_concurrency) + + logging.info( + f'Downloading {s3_path} to {file_path} with chunk_size={chunk_size_MB}MB and max_threads={max_concurrency}' + ) + start_time = time.perf_counter() + _download_file_with_retry(s3_client, bucket, key, file_path, config) + logging.info( + f'Time elapsed downloading {s3_path} to {file_path} with chunk_size={chunk_size_MB}MB ' + f'and max_concurrency={max_concurrency}: {(time.perf_counter() - start_time):.2f} seconds' + ) + + @staticmethod + def upload_file_stream_to_s3( + bytes_buffer: BytesIO, + s3_path: str, + chunk_size_MB: int = DEFAULT_CHUNK_SIZE_MB, + max_concurrency: int = DEFAULT_MAX_WRITE_CONCURRENCY, + ) -> None: + s3_client = S3Utils._get_s3_resource(get_client=True) + bucket, key = S3Utils.parse_s3_url(s3_path) + chunk_size = chunk_size_MB * MB + config = TransferConfig(multipart_chunksize=chunk_size, max_concurrency=max_concurrency) + bytes_buffer.seek(0) + + start_time = time.perf_counter() + _upload_fileobj_with_retry(s3_client, bytes_buffer, bucket, key, config) + logging.info( + f'Time elapsed uploading bytes buffer to {s3_path} with chunk_size={chunk_size_MB}MB ' + f'and 
max_concurrency={max_concurrency}: {(time.perf_counter() - start_time):.2f} seconds' + ) + + @staticmethod + def upload_file( + file_path: str, + s3_path: str, + chunk_size_MB=DEFAULT_CHUNK_SIZE_MB, + max_concurrency=DEFAULT_MAX_WRITE_CONCURRENCY, + remove_file=False, + ): + total_size = os.path.getsize(file_path) + assert total_size > 0, f"file size is zero, {file_path}" + + s3_client = S3Utils._get_s3_resource(get_client=True) + bucket, key = S3Utils.parse_s3_url(s3_path) + + chunk_size = chunk_size_MB * MB + config = TransferConfig( + multipart_threshold=chunk_size, multipart_chunksize=chunk_size, max_concurrency=max_concurrency + ) + + start_time = time.perf_counter() + _upload_file_with_retry(s3_client, file_path, bucket, key, config) + if remove_file and os.path.exists(file_path): + os.remove(file_path) + logging.info( + f'Time elapsed uploading file {file_path} of size {(total_size/GB):.1f}GB to {s3_path} with chunk_size={chunk_size_MB}MB ' + f'and max_concurrency={max_concurrency}: {(time.perf_counter() - start_time):.2f} seconds' + ) + + @staticmethod + def find_files_with_suffix( + base_path: str, + suffix: str = None, + return_key_only: bool = True, + profile: Optional[str] = None, + creds: botocore.credentials.Credentials = None, + ) -> List[str]: + """ + Returns a list of keys that have the specified suffix + :param base_path: the root of search + :param suffix: the suffix to match, case sensitive + :return: list of keys matching the suffix, relative to the base_path + """ + s3 = S3Utils._get_s3_resource(profile, creds) + bucket_name, prefix = S3Utils.parse_s3_url(base_path) + + start_time = time.perf_counter() + bucket = s3.Bucket(bucket_name) + objects_list = _scan_objects_with_retry(s3_bucket=bucket, s3_prefix=prefix) + logging.info( + f'Time elapsed reading all objects under path {base_path}: {(time.perf_counter() - start_time):.2f} seconds' + ) + + if suffix: + objects_list = list(filter(lambda o: o.key.endswith(suffix), objects_list)) + + if return_key_only: + return [o.key for o in objects_list] + else: + return [S3Utils.build_s3_url(o.bucket_name, o.key) for o in objects_list] + + @staticmethod + def _get_s3_resource( + profile: str = None, + creds: botocore.credentials.Credentials = None, + get_client: bool = False, + session=None, + config={}, + ): + config = botocore.config.Config(max_pool_connections=30, **config) + + if profile is not None and creds is not None: + raise ValueError('Please provide profile or creds or neither, not both.') + + if profile is not None: + s3 = boto3.Session(profile_name=profile).resource('s3', config=config) + elif creds is not None: + s3 = boto3.Session().resource( + 's3', + aws_access_key_id=creds["AccessKeyId"], + aws_secret_access_key=creds["SecretAccessKey"], + aws_session_token=creds["SessionToken"], + config=config, + ) + else: + s3 = ( + boto3.Session().resource('s3', config=config) if not session else session.resource('s3', config=config) + ) + + if get_client: + return s3.meta.client + else: + return s3 + + @staticmethod + def parse_s3_url(s3_url: str) -> Optional[Tuple[str, str]]: + match = re.match(r"s3://([^/]+)/(.*)", s3_url, flags=re.UNICODE) + + if match is None: + return None, None + + return match.groups()[0], match.groups()[1] + + @staticmethod + def build_s3_url(bucket, key) -> str: + return build_s3_url(bucket, key) + + @staticmethod + def is_s3_url(path: Optional[str]) -> bool: + return is_s3_url(path) + + @staticmethod + def parse_prefix_with_step(path: str) -> str: + """ + Use regex to find the pattern up to 
"-step=900-" + s3://path/to/checkpoints/tp_rank_00_pp_rank_000/megatron_gpt--step=900-validation_loss=6.47-consumed_samples=35960.0-last.ckpt + should return s3://path/to/checkpoints/tp_rank_00_pp_rank_000/megatron_gpt--step=900- + """ + match = re.search(r'(.*step=\d+-)', path) + + if match: + return match.group(1) + + return path + + +def _scan_objects_with_retry(s3_bucket, s3_prefix): + # this returns a collection https://boto3.amazonaws.com/v1/documentation/api/latest/guide/collections.html + # This collection acts as an iterable that automatically makes additional requests to retrieve more objects from S3 as needed + objects = s3_bucket.objects.filter(Prefix=s3_prefix) + return list(objects) + + +def is_slow_down_error(exception): + """ + This function checks if the error is due to slowdown or is throttling related. + If so, returns true to allow tenacity to retry the upload/download to S3. + """ + class_name = exception.__class__.__name__ + module_name = exception.__class__.__module__ + full_class_name = f"{module_name}.{class_name}" + logging.error(f'Caught exception of type {full_class_name}: {exception}') + + # 2023-12-07T05:59:25.913721576Z stdout F 2023-12-07 05:59:25,913 [ERROR] - s3_utils.py:354 - Caught exception: + # AWS_ERROR_S3_INVALID_RESPONSE_STATUS: Invalid response status from request. Body from error request is: b'\nRequestTimeoutYour socket connection to the server was not read from or written to within the timeout period. Idle connections will be closed.XPHS9896G3RJE364ZAiF3HPpUD5IgSr/mfkP2QPs7ttuvY+uTRG9MET/jZZ45MJ6bVbnvSBQLggICvPCROPP/1k85p4=' + message = str(exception) + if ( + "SlowDown" in message + or "RequestTimeout" in message + or "InternalError" in message + ): + logging.info("Identified the Retriable Error retrying the job") + return True + + if crt_available and isinstance(exception, awscrt.exceptions.AwsCrtError): + logging.error(f'Caught awscrt.exceptions.AwsCrtError: {exception.__repr__()}') + return True + + if isinstance(exception, ClientError): + logging.error(f'Caught ClientError, response is: {exception.response}') + error_code = exception.response['Error']['Code'] if exception.response else None + return error_code in ['SlowDown', 'RequestTimeout', 'InternalError'] + logging.info("Non Retriable Error - Terminating the job") + return False + + +@retry( + wait=wait_exponential(multiplier=1, min=1, max=16), + stop=stop_after_delay(2 * 60), + retry=retry_if_exception(is_slow_down_error), + before_sleep=before_sleep_log(logging, logging.ERROR), +) +def _download_fileobj_with_retry( + s3_client, bucket: str, key: str, bytes_buffer: BytesIO, config: TransferConfig = None +): + s3_client.download_fileobj(bucket, key, bytes_buffer, Config=config) + + +@retry( + wait=wait_exponential(multiplier=1, min=1, max=16), + stop=stop_after_delay(2 * 60), + retry=retry_if_exception(is_slow_down_error), + before_sleep=before_sleep_log(logging, logging.ERROR), +) +def _download_file_with_retry(s3_client, bucket: str, key: str, file_path: str, config: TransferConfig = None): + s3_client.download_file(bucket, key, file_path, Config=config) + + +@retry( + wait=wait_exponential(multiplier=1, min=1, max=16), + stop=stop_after_delay(2 * 60), + retry=retry_if_exception(is_slow_down_error), + before_sleep=before_sleep_log(logging, logging.ERROR), +) +def _upload_fileobj_with_retry(s3_client, bytes_buffer: BytesIO, bucket: str, key: str, config: TransferConfig = None): + s3_client.upload_fileobj(bytes_buffer, bucket, key, Config=config) + + +@retry( + 
wait=wait_exponential(multiplier=1, min=1, max=16), + stop=stop_after_delay(2 * 60), + retry=retry_if_exception(is_slow_down_error), + before_sleep=before_sleep_log(logging, logging.ERROR), +) +def _upload_file_with_retry(s3_client, file_path: str, bucket: str, key: str, config: TransferConfig = None): + s3_client.upload_file(file_path, bucket, key, Config=config)
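# ----------------------------------------------------------------------------------------------
# Editorial usage sketch (illustration only, not part of the patch). The bucket, key and local
# file below are hypothetical placeholders.
#
#     from nemo.utils.s3_utils import S3Utils
#
#     path = S3Utils.build_s3_url('example-bucket', 'ckpts/megatron_gpt--step=900-last.ckpt')
#     assert S3Utils.is_s3_url(path)
#
#     bucket, key = S3Utils.parse_s3_url(path)        # ('example-bucket', 'ckpts/...-last.ckpt')
#     prefix = S3Utils.parse_prefix_with_step(path)   # '.../megatron_gpt--step=900-'
#
#     # Uploads and downloads are multipart transfers, retried on SlowDown/RequestTimeout errors.
#     S3Utils.upload_file('/tmp/checkpoint.ckpt', path, chunk_size_MB=64, max_concurrency=10)
#     keys = S3Utils.find_files_with_suffix('s3://example-bucket/ckpts/', suffix='last.ckpt')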