S3 Dirpath + Async Uploading Support for Default Checkpoints #9045

Merged

Commits: 43 total (the diff below shows changes from the first 13 commits)
f9c242c
Add S3 dirpath and asynchronous uploading support for basic checkpoin…
alxzhang-amazon Apr 26, 2024
e51f617
Update megtron_gpt_pretraining config to support S3 checkpointing
alxzhang-amazon Apr 26, 2024
4cd8da0
Removed unused imports
alxzhang-amazon Apr 26, 2024
00dac24
move s3_checkpoint_io into callbacks. consolidate checkpoint_file_uti…
alxzhang-amazon Apr 29, 2024
f4b4709
Update setup() in nemo_model_checkpoint to broadcast checkpoint path …
alxzhang-amazon Apr 29, 2024
c23437f
Add boto3 dependency for testing
alxzhang-amazon Apr 29, 2024
f060235
Remove redundant setup() in nemo_model_checkpoint
alxzhang-amazon Apr 29, 2024
050ccca
Remove comment line from import
alxzhang-amazon Apr 29, 2024
e28fe3a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 29, 2024
6cadc08
Removed explicit CRT calls since boto[crt] automatically uses CRT for…
alxzhang-amazon Apr 30, 2024
4c934fd
Style fix
alxzhang-amazon Apr 30, 2024
a3791da
Merge branch 'main' into s3-checkpointing-support-upstream
alxzhang-amazon May 1, 2024
0667ab3
Merge branch 'main' into s3-checkpointing-support-upstream
alxzhang-amazon May 3, 2024
5a8bf4a
Merge branch 'main' into s3-checkpointing-support-upstream
alxzhang-amazon May 9, 2024
6597896
remove un-used s3transfer import
alxzhang-amazon May 9, 2024
8ddfe6d
add s3 prefix for s3-related checkpointing config
alxzhang-amazon May 13, 2024
c0f510d
dummy sleep function lowered from 1 to 0.01 seconds
alxzhang-amazon May 13, 2024
310a8a8
Remove local_rank checking for rank, and use is_global_rank_zero.
alxzhang-amazon May 13, 2024
56734c9
Style fix
alxzhang-amazon May 13, 2024
936079a
Merge branch 'main' into s3-checkpointing-support-upstream
alxzhang-amazon May 13, 2024
5a55e4e
Apply isort and black reformatting
alxzhang-amazon May 13, 2024
7a2a425
add tenacity dependency
alxzhang-amazon May 13, 2024
4045e64
Merge branch 'main' into s3-checkpointing-support-upstream
alxzhang-amazon Jun 4, 2024
0b9e9f7
Apply isort and black reformatting
alxzhang-amazon Jun 4, 2024
28c3691
Add filtering of unfinished checkpoint to non-s3 checkpoint resuming
alxzhang-amazon Jun 11, 2024
6b0d65e
isort black reformatting
alxzhang-amazon Jun 11, 2024
0ed7e5b
Merge branch 'main' into s3-checkpointing-support-upstream
alxzhang-amazon Jun 11, 2024
ef5e0fb
Apply isort and black reformatting
alxzhang-amazon Jun 11, 2024
b8066c4
Merge branch 'main' into s3-checkpointing-support-upstream
alxzhang-amazon Jun 11, 2024
d1b69cf
Remove dependency requirement for checking if dirpath is an s3 path
alxzhang-amazon Jun 12, 2024
62652db
Make dependencies fully optional; allow exp_manager to optionally imp…
alxzhang-amazon Jun 12, 2024
569e0f0
Add rst doc for s3 checkpointing
alxzhang-amazon Jun 12, 2024
1be7e40
Remove unneeded assert
alxzhang-amazon Jun 13, 2024
39761ec
Removed dependencies
alxzhang-amazon Jun 13, 2024
d2a5017
Apply isort and black reformatting
alxzhang-amazon Jun 13, 2024
4cfa8c5
Merge branch 'main' into s3-checkpointing-support-upstream
alxzhang-amazon Jun 13, 2024
7fd5d46
Updated documentation on async save to S3
alxzhang-amazon Jun 13, 2024
5228162
Apply isort and black reformatting
alxzhang-amazon Jun 13, 2024
0d8b86a
Merge branch 'main' into s3-checkpointing-support-upstream
alxzhang-amazon Jun 14, 2024
cc30150
Update S3 checkpointing doc and fix visibility on website. Update the…
alxzhang-amazon Jun 14, 2024
31aaa25
Apply isort and black reformatting
alxzhang-amazon Jun 14, 2024
776a9f5
Slight fix in s3 checkpoint doc
alxzhang-amazon Jun 14, 2024
54ffc8e
Merge branch 'main' into s3-checkpointing-support-upstream
alxzhang-amazon Jun 14, 2024
examples/nlp/language_modeling/conf/megatron_gpt_config.yaml (11 additions, 0 deletions)

@@ -24,6 +24,16 @@ trainer:
benchmark: False
enable_model_summary: False # default PTL callback for this does not support model parallelism, instead we log manually

# Used for S3 Checkpointing
checkpointing:
# write_concurrency * tp * pp * 1.15 (buffer) should be within 3500 S3 TPS limit per partition
max_write_concurrency: 10
# read_concurrency * tp * pp * 1.15 (buffer) should be within 5500 S3 TPS limit per partition
max_read_concurrency: 15
chunk_size_MB: 64
# enables asynchronous checkpoint writing to S3 dirpath
enable_async_checkpointing: False

exp_manager:
explicit_log_dir: null
exp_dir: null
@@ -45,6 +55,7 @@
resume_from_checkpoint: ${model.resume_from_checkpoint}
create_checkpoint_callback: True
checkpoint_callback_params:
dirpath: null # to use S3 checkpointing, set the dirpath in format s3://bucket/key
monitor: val_loss
save_top_k: 10
mode: min
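As a quick illustration of the budget encoded in the checkpointing block above: effective request concurrency is roughly the configured concurrency multiplied by the tensor- and pipeline-parallel degrees, with a 1.15 safety buffer, and should stay under S3's per-partition limits. A sketch of that arithmetic (not part of the PR; the helper name and parallelism values are made up for the example):

# Sketch: sanity-check the S3 concurrency budget from the config comments above.
S3_WRITE_TPS_LIMIT = 3500  # per-partition write TPS limit cited in the YAML
S3_READ_TPS_LIMIT = 5500   # per-partition read TPS limit cited in the YAML
BUFFER = 1.15              # safety buffer used in the config comments

def within_s3_tps_limits(max_write_concurrency: int, max_read_concurrency: int, tp: int, pp: int) -> bool:
    """Check write/read concurrency across tp * pp ranks against S3 per-partition limits."""
    write_load = max_write_concurrency * tp * pp * BUFFER
    read_load = max_read_concurrency * tp * pp * BUFFER
    return write_load <= S3_WRITE_TPS_LIMIT and read_load <= S3_READ_TPS_LIMIT

# With the defaults above (10 writes, 15 reads) and tp=8, pp=4:
# 10 * 8 * 4 * 1.15 = 368 <= 3500 and 15 * 8 * 4 * 1.15 = 552 <= 5500, so this passes.
print(within_s3_tps_limits(10, 15, tp=8, pp=4))  # True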
nemo/utils/callbacks/nemo_model_checkpoint.py (11 additions, 2 deletions)

@@ -170,14 +170,23 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
super().load_state_dict(state_dict)
self._remove_invalid_entries_from_topk()

def setup(self, *args, **kwargs) -> None:
def setup(self, trainer, pl_module, stage: str) -> None:
if is_global_rank_zero():
logging.debug("Removing unfinished checkpoints if any...")
NeMoModelCheckpoint._remove_unfinished_checkpoints(self.dirpath)
# Ensure that all ranks continue with unfinished checkpoints removed
if torch.distributed.is_initialized():
torch.distributed.barrier()
super().setup(*args, **kwargs)
super().setup(trainer, pl_module, stage)
# When using S3 checkpointing, only Rank 0 has the checkpoint and model path set in exp_manager.
# Sync the values across all ranks to ensure consistency.
path = trainer.strategy.broadcast(trainer.ckpt_path)
trainer.ckpt_path = path

assert (
trainer.checkpoint_callback == self
), f"This instance should be trainer.checkpoint_callback, but trainer.checkpoint_callback is {trainer.checkpoint_callback}, not {self}"
self.last_model_path = trainer.strategy.broadcast(self.last_model_path)

def on_save_checkpoint(self, trainer, pl_module, checkpoint):
output = super().on_save_checkpoint(trainer, pl_module, checkpoint)
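The broadcast added to setup() above follows a common pattern: rank 0 resolves a value (here the checkpoint path) and every other rank receives it. A bare torch.distributed sketch of the same idea, assuming an initialized process group (the function name is illustrative; the PR itself uses Lightning's trainer.strategy.broadcast):

import torch.distributed as dist

def broadcast_from_rank_zero(value):
    """Return rank 0's copy of `value` on every rank; pass through when not distributed."""
    if not dist.is_available() or not dist.is_initialized():
        return value
    container = [value]  # broadcast_object_list operates in place on a list
    dist.broadcast_object_list(container, src=0)
    return container[0]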
nemo/utils/callbacks/s3_checkpoint_io.py (272 additions, 0 deletions)

@@ -0,0 +1,272 @@
import os
import time
from concurrent.futures import ProcessPoolExecutor
from io import BytesIO
from multiprocessing import get_start_method
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Any, Callable, Dict, Optional, Union

import torch
from lightning_fabric.plugins.io.checkpoint_io import CheckpointIO

from nemo.utils import logging
from nemo.utils.s3_utils import (
DEFAULT_CHUNK_SIZE_MB,
DEFAULT_MAX_READ_CONCURRENCY,
DEFAULT_MAX_WRITE_CONCURRENCY,
SHARED_MEM_DIR,
S3Utils,
)


class S3CheckpointIO(CheckpointIO):
"""A custom S3CheckpointIO module that supports checkpoint reading/writing with s3 when filepath
mikolajblaz marked this conversation as resolved.
Show resolved Hide resolved
is a s3 url.
"""

def __init__(
self,
dirpath: str,
chunk_size_MB=DEFAULT_CHUNK_SIZE_MB,
max_read_concurrency=DEFAULT_MAX_READ_CONCURRENCY,
max_write_concurrency=DEFAULT_MAX_WRITE_CONCURRENCY,
async_checkpointing=False,
):
"""
Initialize the transfer configuration with custom values.

This method overrides the default TransferConfig values in boto3.
See https://boto3.amazonaws.com/v1/documentation/api/latest/_modules/boto3/s3/transfer.html#TransferConfig

Args:
dirpath (str): The S3 URL (s3://bucket/key) under which checkpoints are saved.
chunk_size_MB (int, optional): The size of chunks to use when transferring files.
Default is 64 (MB).
max_read_concurrency (int, optional): The maximum number of threads that will be making
requests to perform a download. Default is 15.
max_write_concurrency (int, optional): The maximum number of threads that will be making
requests to perform an upload. Default is 10.
async_checkpointing (bool, optional): Whether to upload checkpoints asynchronously in a
subprocess. Default is False.
"""
if not S3Utils.is_s3_url(dirpath):
raise AssertionError(
f"Error attempting to initialize an S3CheckpointIO when {dirpath} is not an S3 url. Please use TorchCheckpointIO when using a non-S3 dirpath."
)

self.chunk_size_MB = chunk_size_MB
self.max_read_concurrency = max_read_concurrency
self.max_write_concurrency = max_write_concurrency
self._async_checkpointing = async_checkpointing
# When using shared memory, we create a temporary file to hold the checkpoint before uploading to S3.
# This list tracks those temporary files so that any leaked files still present at teardown can be cleaned up.
self._temp_files = []

if self.async_checkpointing:
# create an executor that will asynchronously run upload functions
self._executor = ProcessPoolExecutor(max_workers=1)

# Eagerly create the subprocess now so that a forked worker does not inherit the CUDA context from the parent
if get_start_method() == 'fork' and torch.cuda.is_initialized():
raise Exception(
'torch.cuda should not be initialized when the checkpointing subprocess is created by the fork method'
)
logging.info('Creating asynchronous checkpointing subprocess')
future = self._executor.submit(dummy_func)
try:
future.result()
logging.info('Asynchronous checkpointing subprocess created successfully')
except Exception as e:
logging.error(f'Failed to create asynchronous checkpointing subprocess, exception: {e}')
raise e
self._futures = []

super().__init__()

@property
def async_checkpointing(self):
return self._async_checkpointing

def _serialize_checkpoint_to_shm(self, checkpoint: Dict, path: str) -> str:
"""
Returns:
filename of the temporary file in shared memory.
"""
start_time = time.perf_counter()
tempfile = NamedTemporaryFile(dir=SHARED_MEM_DIR, delete=False)
torch.save(checkpoint, tempfile)
logging.info(
f'Time elapsed saving checkpoint dict to {tempfile.name} for {path}: {(time.perf_counter() - start_time):.2f} seconds, rank {torch.distributed.get_rank()}'
)
del checkpoint
return tempfile.name

def _serialize_checkpoint_to_bytes(self, checkpoint: Dict, path: str) -> BytesIO:
"""
Returns:
A BytesIO buffer containing the serialized checkpoint.
"""
start_time = time.perf_counter()
buffer = BytesIO()
torch.save(checkpoint, buffer)
logging.info(
f'Time elapsed saving checkpoint dict to bytes for {path}: {(time.perf_counter() - start_time):.2f} seconds, rank {torch.distributed.get_rank()}'
)
del checkpoint
return buffer

def _check_uploading_results_so_far(self):
"""
self._future is a list of tuples of form (future, destination path, source path)
This function checks the result of all the futures, and updates the self._futures list appropriately.
It also updates the list of self._temp_files, which is used to clean up leaked temporary files in SHARED_MEM during teardown.
"""
if not self._futures:
return
start_time = time.perf_counter()
done_futures = []
in_progress_futures = []
for item in self._futures:
if item[0].done():
done_futures.append(item)
else:
in_progress_futures.append(item)

for item in done_futures:
try:
item[0].result()
except Exception as e:
logging.error(f'Failed to upload {item[2]} to {item[1]}, exception: {e}')
raise e
# Once the future is complete, the temp file itself was already removed by the upload (remove_file=True), so drop it from the tracking list.
try:
self._temp_files.remove(item[2])
except ValueError:
pass  # When not using shared memory, nothing is appended to temp_files, so remove() finds no entry.
self._futures = in_progress_futures
logging.debug(
f'Time elapsed checking uploading future results: {(time.perf_counter() - start_time):.2f} seconds'
)

def save_checkpoint(
self, checkpoint: Dict[str, Any], path: Union[str, Path], storage_options: Optional[Any] = None
) -> None:
# if we have a shared memory directory, we can serialize as a file to shared memory instead of as bytes.
if os.path.exists(SHARED_MEM_DIR):
localfile = self._serialize_checkpoint_to_shm(checkpoint, path)
self._temp_files.append(localfile)
saved_as_file = True
else:
buffer = self._serialize_checkpoint_to_bytes(checkpoint, path)
saved_as_file = False

if self.async_checkpointing:
self._check_uploading_results_so_far()
logging.info(f'Uploading checkpoint to {path} in asynchronous mode, rank {torch.distributed.get_rank()}')
if saved_as_file:
future = self._executor.submit(
_upload_file_to_s3, localfile, path, self.chunk_size_MB, self.max_write_concurrency, True
)
self._futures.append((future, path, localfile))
else:
future = self._executor.submit(
_upload_bytes_to_s3, buffer, path, self.chunk_size_MB, self.max_write_concurrency
)
self._futures.append((future, path, 'bytes'))
else:
logging.info(f'Uploading checkpoint to {path} in synchronous mode, rank {torch.distributed.get_rank()}')
if saved_as_file:
_upload_file_to_s3(localfile, path, self.chunk_size_MB, self.max_write_concurrency, True)
self._temp_files.remove(localfile)
else:
_upload_bytes_to_s3(buffer, path, self.chunk_size_MB, self.max_write_concurrency)

def load_checkpoint(
self, path: Union[str, Path], map_location: Optional[Callable] = lambda storage, loc: storage
) -> Dict[str, Any]:
if os.path.exists(SHARED_MEM_DIR):
with NamedTemporaryFile(dir=SHARED_MEM_DIR, delete=True) as tempfile:
logging.info(
f'Loading checkpoint {path} into a temp file in shared memory {tempfile.name}, rank {torch.distributed.get_rank()}'
)
S3Utils.download_s3_file_to_path(
s3_path=path,
file_path=tempfile.name,
chunk_size_MB=self.chunk_size_MB,
max_concurrency=self.max_read_concurrency,
)
checkpoint = torch.load(tempfile.name)
else:
file_stream: BytesIO = S3Utils.download_s3_file_to_stream(
s3_path=path, chunk_size_MB=self.chunk_size_MB, max_concurrency=self.max_read_concurrency
)
checkpoint = torch.load(file_stream)
return checkpoint

def remove_checkpoint(self, path: Union[str, Path]) -> None:
if S3Utils.is_s3_url(path):
S3Utils.remove_object(path)
else:
super().remove_checkpoint(path)

def teardown(self) -> None:
# this ensures we wait for the final checkpoint to finish uploading at the end of training.
rank = torch.distributed.get_rank()
if self.async_checkpointing:
logging.info(f'Entering teardown, waiting for all jobs to finish, rank {rank}')
start_time = time.perf_counter()
self._executor.shutdown(wait=True)
logging.info(f'executor shut down after {(time.perf_counter() - start_time):.2f} seconds, rank {rank}')

# self._temp_files can still be non-empty at the end of training when uploading asynchronously,
# since the final futures are not processed by _check_uploading_results_so_far.
# Therefore, check that each path still exists before trying to delete it.
if self._temp_files:
for tfile in self._temp_files:
if os.path.exists(tfile):
try:
os.remove(tfile)
except Exception as e:
logging.info(f"Error occurred while deleting file {tfile}: {e}")


def _clean_up_conflicting_checkpoint(filepath: str) -> None:
'''
Before saving to S3, clean up any existing object with the same prefix megatron_gpt+step_count.
E.g., before we save "megatron_gpt--step=1400-validation_loss=6.32-consumed_samples=55920.0-last.ckpt",
we need to clean up "megatron_gpt--step=1400-validation_loss=xxx-consumed_samples=yyy-last.ckpt"
so that, if we later resume from step 1400, there is a single checkpoint file at step 1400.
'''

if S3Utils.is_s3_url(filepath):
prefix_with_step = S3Utils.parse_prefix_with_step(filepath)
logging.info(f'Looking for conflicting checkpoint under prefix {prefix_with_step}')

conflict_last_ckpts = S3Utils.find_files_with_suffix(
base_path=prefix_with_step, suffix='last.ckpt', return_key_only=False
)
for last_ckpt in conflict_last_ckpts:
logging.info(f'Cleaning up conflicting last ckpt {last_ckpt} before saving {filepath}')
S3Utils.remove_object(last_ckpt)


def _upload_file_to_s3(localfile, path, chunk_size_MB, max_write_concurrency, remove_file):
_clean_up_conflicting_checkpoint(path)
S3Utils.upload_file(localfile, path, chunk_size_MB, max_write_concurrency, remove_file)


def _upload_bytes_to_s3(buffer, path, chunk_size_MB, max_write_concurrency):
_clean_up_conflicting_checkpoint(path)
S3Utils.upload_file_stream_to_s3(buffer, path, chunk_size_MB, max_write_concurrency)


def dummy_func():
"""No-op task submitted eagerly so the checkpointing subprocess is created before CUDA is initialized."""
time.sleep(1)
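For reference, a minimal sketch of wiring the plugin into a Lightning Trainer directly (a sketch only: the bucket/key is a placeholder, credentials with access to the bucket are assumed, and in NeMo this wiring normally happens through exp_manager and the checkpointing config shown earlier):

from pytorch_lightning import Trainer
from nemo.utils.callbacks.s3_checkpoint_io import S3CheckpointIO

# Placeholder S3 location; values mirror the defaults in the config above.
s3_io = S3CheckpointIO(
    dirpath='s3://my-bucket/my-run/checkpoints',
    chunk_size_MB=64,
    max_read_concurrency=15,
    max_write_concurrency=10,
    async_checkpointing=True,
)
trainer = Trainer(plugins=[s3_io])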