From d2d29adad17ae1fc48a294dfd5c4fa4d3e63e809 Mon Sep 17 00:00:00 2001 From: Daniel King <43149077+dakinggg@users.noreply.github.com> Date: Tue, 23 Jul 2024 04:00:18 -0700 Subject: [PATCH] Use utils to get shared fs safe signal file name (#1381) --- llmfoundry/data/finetuning/dataloader.py | 55 +++++++++--------------- llmfoundry/data/finetuning/tasks.py | 2 +- llmfoundry/models/hf/hf_causal_lm.py | 2 +- llmfoundry/utils/builders.py | 2 +- 4 files changed, 23 insertions(+), 38 deletions(-) diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py index 11104ac706..60052acdc5 100644 --- a/llmfoundry/data/finetuning/dataloader.py +++ b/llmfoundry/data/finetuning/dataloader.py @@ -534,42 +534,27 @@ def _download_remote_hf_dataset(remote_path: str, split: str) -> str: # Since we don't know exactly what the extension will be, since it is one of a list # use a signal file to wait for instead of the desired file - signal_file_path = os.path.join( - finetune_dir, - f'.node_{dist.get_node_rank()}_local_rank0_completed', - ) - if dist.get_local_rank() == 0: - try: - get_file(path=name, destination=destination, overwrite=True) - except FileNotFoundError as e: - if extension == SUPPORTED_EXTENSIONS[-1]: - files_searched = [ - f'{name}/{split}{ext}' for ext in SUPPORTED_EXTENSIONS - ] - raise FileNotFoundError( - f'Could not find a file with any of ' + \ - f'the supported extensions: {SUPPORTED_EXTENSIONS}\n' + \ - f'at {files_searched}', - ) from e - else: - log.debug( - f'Could not find {name}, looking for another extension', - ) - continue - - os.makedirs(os.path.dirname(signal_file_path), exist_ok=True) - with open(signal_file_path, 'wb') as f: - f.write(b'local_rank0_completed_download') - - # Avoid the collective call until the local rank zero has finished trying to download the dataset - # so that we don't timeout for large downloads. This syncs all processes on the node - with dist.local_rank_zero_download_and_wait(signal_file_path): - # Then, wait to ensure every node has finished trying to download the dataset - dist.barrier() + with dist.busy_wait_for_local_rank_zero(finetune_dir): + if dist.get_local_rank() == 0: + try: + get_file(path=name, destination=destination, overwrite=True) + except FileNotFoundError as e: + if extension == SUPPORTED_EXTENSIONS[-1]: + files_searched = [ + f'{name}/{split}{ext}' + for ext in SUPPORTED_EXTENSIONS + ] + raise FileNotFoundError( + f'Could not find a file with any of ' + \ + f'the supported extensions: {SUPPORTED_EXTENSIONS}\n' + \ + f'at {files_searched}', + ) from e + else: + log.debug( + f'Could not find {name}, looking for another extension', + ) + continue - # clean up signal file - if dist.get_local_rank() == 0: - os.remove(signal_file_path) dist.barrier() break return finetune_dir diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index 78bfb9c74c..d5af632952 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -827,7 +827,7 @@ def build_from_hf( Returns: Dataset: The tokenized dataset. """ - signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_data_prep_completed' + signal_file_path = dist.get_node_signal_file_name() # Non local rank 0 ranks will wait here for local rank 0 to finish the data processing. # Once local rank 0 is done, the datasets are all cached on disk, and all other ranks diff --git a/llmfoundry/models/hf/hf_causal_lm.py b/llmfoundry/models/hf/hf_causal_lm.py index e3fa8d03a3..a15429aa06 100644 --- a/llmfoundry/models/hf/hf_causal_lm.py +++ b/llmfoundry/models/hf/hf_causal_lm.py @@ -363,7 +363,7 @@ def _autoset_attn_implementation_monkeypatch( f'init_device="{init_device}" must be either "cpu" or "meta".', ) - signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_completed' + signal_file_path = dist.get_node_signal_file_name() if dist.get_local_rank() == 0: with open(signal_file_path, 'wb') as f: f.write(b'local_rank0_completed_download') diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index 0437736f74..b889155be0 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -498,7 +498,7 @@ def build_tokenizer( os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = '1' os.environ['TOKENIZERS_PARALLELISM'] = 'false' - signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_completed_tokenizer_setup' + signal_file_path = dist.get_node_signal_file_name() if dist.is_available() and dist.is_initialized( ) and dist.get_world_size() > 1: