Skip to content

Commit

Permalink
[Core] Add the conversion between dcp and sfpt
Browse files Browse the repository at this point in the history
  • Loading branch information
aoyulong committed Nov 20, 2024
1 parent f92de37 commit 916aa20
Show file tree
Hide file tree
Showing 8 changed files with 386 additions and 18 deletions.
16 changes: 12 additions & 4 deletions megatron/megatron/core/dist_checkpointing/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
loading the sharded tensors.
"""

import os
import logging
from pathlib import Path
from typing import Dict, Optional, Set, Tuple, Union
Expand Down Expand Up @@ -339,10 +340,17 @@ def save(
f'Checkpoint destination directory does not exist: {checkpoint_dir}'
)

if next(checkpoint_dir.iterdir(), None) is not None:
raise CheckpointingException(
f'Checkpoint destination directory ({checkpoint_dir}) is not empty'
)
# Skip this if the env var exists, otherwise default to False
single_file_per_tensor_ckpt = os.getenv('FS_SFPT_CKPT_SAVE', 'False').lower() in (
'true',
'1',
't',
)
if not single_file_per_tensor_ckpt:
if next(checkpoint_dir.iterdir(), None) is not None:
raise CheckpointingException(
f'Checkpoint destination directory ({checkpoint_dir}) is not empty'
)

if common_strategy is not None:
raise NotImplementedError('The only supported common strategy is torch')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,21 @@
import logging
import os
import queue
import pickle
from contextlib import contextmanager
from itertools import chain
from pathlib import Path
from time import time
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Callable, Dict, List, Optional, Tuple, Union, cast

import psutil
import torch
from torch import multiprocessing as mp
from torch.distributed.checkpoint import FileSystemWriter
from torch.distributed.checkpoint.filesystem import DEFAULT_SUFFIX, _StoragePrefix, _write_item
from torch.distributed.checkpoint.filesystem import DEFAULT_SUFFIX, _StoragePrefix, _write_item, _metadata_fn
from torch.distributed.checkpoint.planner import SavePlan, SavePlanner, WriteItem, WriteItemType
from torch.distributed.checkpoint.storage import WriteResult
from torch.distributed.checkpoint.metadata import Metadata
from torch.futures import Future

logger = logging.getLogger(__name__)
Expand All @@ -26,6 +28,40 @@

_results_queue = None

_GLOBAL_PREVIOUS_METADATA = None

_GLOBAL_PREVIOUS_COUNT = 0


def get_previous_metadata():
"""
Get the metadata from the previous save.
"""
return _GLOBAL_PREVIOUS_METADATA


def set_previous_metadata(metadata):
"""
Set the metadata from the previous save.
"""
global _GLOBAL_PREVIOUS_METADATA
_GLOBAL_PREVIOUS_METADATA = metadata


def get_previous_count():
"""
Get the count from the previous save.
"""
return _GLOBAL_PREVIOUS_COUNT


def set_previous_count(count):
"""
Set the count from the previous save.
"""
global _GLOBAL_PREVIOUS_COUNT
_GLOBAL_PREVIOUS_COUNT = count


def _get_write_results_queue():
global _results_queue
Expand Down Expand Up @@ -80,6 +116,13 @@ def __init__(self, *args, **kwargs):
self.write_buckets: Optional[List[WriteBucket]] = None
self.results_queue: Optional[mp.Queue] = None

# Get the value from the environment variable if it exists, otherwise default to False
self.single_file_per_tensor_ckpt = os.getenv('FS_SFPT_CKPT_SAVE', 'False').lower() in (
'true',
'1',
't',
)

def prepare_write_data(self, plan: SavePlan, planner: SavePlanner) -> None:
"""
First stage of async saving. Copy data to CPU and plan the local saving.
Expand All @@ -99,12 +142,17 @@ def prepare_write_data(self, plan: SavePlan, planner: SavePlanner) -> None:
start = time()
# move tensors from GPU to CPU before starting async writing
# We do D2H synchronously for now
file_count = 0
if not self.single_file_per_tensor_ckpt:
file_count = 0
else:
file_count = get_previous_count()

def gen_file():
nonlocal file_count
file_name = f"{storage_plan.prefix}{file_count}{DEFAULT_SUFFIX}"
file_count += 1
if self.single_file_per_tensor_ckpt:
set_previous_count(file_count)
return file_name

# Prepare bytes / tensor data in each bucket, which will be assigned to each writer process
Expand Down Expand Up @@ -314,6 +362,48 @@ def retrieve_write_results(self) -> List[WriteResult]:
)
return list(chain.from_iterable(write_results.values()))

def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
# Modify based on the original implementation from torch.distributed.checkpoint.filesystem.FileSystemWriter
# https://github.com/pytorch/pytorch/blob/625c24a7f98a645b6f8758a01d7163a842582ce0/torch/distributed/checkpoint/filesystem.py#L574

if not self.single_file_per_tensor_ckpt:
storage_md = {}
else:
if get_previous_count() == 1:
storage_md = {}
else:
# Get the metadata from the previous save
prev_metadata = get_previous_metadata()
prev_metadata.state_dict_metadata.update(metadata.state_dict_metadata)
metadata = prev_metadata
storage_md = metadata.storage_data

for wr_list in results:
storage_md.update({wr.index: wr.storage_data for wr in wr_list})
metadata.storage_data = storage_md

if not self.single_file_per_tensor_ckpt or get_previous_count() == 1:
metadata.storage_meta = self.storage_meta()

tmp_path = cast(Path, self.fs.concat_path(self.path, f"{_metadata_fn}.tmp"))
with self.fs.create_stream(tmp_path, "wb") as metadata_file:
pickle.dump(metadata, metadata_file)
if self.sync_files:
try:
os.fsync(metadata_file.fileno())
except AttributeError:
os.sync()

# delete in-case other checkpoints were present.
if self.fs.exists(self.metadata_path):
self.fs.rm_file(self.metadata_path)

self.fs.rename(tmp_path, self.metadata_path)

# Store the metadata for the next save
if self.single_file_per_tensor_ckpt:
set_previous_metadata(metadata)


def _split_by_size_and_type(bins: int, items: List[WriteItem]) -> List[List[WriteItem]]:
"""
Expand Down Expand Up @@ -349,7 +439,6 @@ def _split_by_size_and_type(bins: int, items: List[WriteItem]) -> List[List[Writ
idx = min(enumerate(bucket_sizes), key=lambda x: x[1])[0]
buckets[idx].append(item)
bucket_sizes[idx] += _item_size(item)

return buckets


Expand Down
41 changes: 33 additions & 8 deletions megatron/megatron/core/dist_checkpointing/strategies/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

""" Strategies using PyTorch distributed.checkpoint as an underlying format. """
import io
import os
from collections import ChainMap, defaultdict
from dataclasses import dataclass
from itertools import product
Expand Down Expand Up @@ -404,7 +405,6 @@ def _replace_sharded_keys_with_state_dict_keys(
assert len(tensors) == len(rename_mapping[k])
for ten, recovered_k in zip(tensors, rename_mapping[k]):
recovered_sd[recovered_k] = ten

return unflatten_state_dict(recovered_sd, flat_mapping)


Expand Down Expand Up @@ -734,6 +734,13 @@ def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path) -> St
Returns: loaded state dict
"""
# Get the value from the environment variable if it exists, otherwise default to True
single_file_per_tensor_ckpt = os.getenv('FS_SFPT_CKPT_LOAD', 'False').lower() in (
'true',
'1',
't',
)

# Apply N-D tensors resharding
sharded_state_dict, formulation_restore_data = apply_nd_flattened_tensors_reformulation(
sharded_state_dict, get_reformulation_metadata(sharded_state_dict, checkpoint_dir)
Expand All @@ -752,13 +759,24 @@ def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path) -> St
)
pyt_state_dict = mcore_to_pyt_state_dict(sharded_state_dict, True)
# Load PyT Distributed format
checkpoint.load_state_dict(
pyt_state_dict,
FileSystemReader(checkpoint_dir),
planner=MCoreLoadPlanner(
shapes_validation_sharded_tensors=flexible_shape_sharded_tensors
),
)
if not single_file_per_tensor_ckpt:
checkpoint.load_state_dict(
pyt_state_dict,
FileSystemReader(checkpoint_dir),
planner=MCoreLoadPlanner(
shapes_validation_sharded_tensors=flexible_shape_sharded_tensors
),
)
else:
checkpoint.load_state_dict(
pyt_state_dict,
FileSystemReader(checkpoint_dir),
planner=MCoreLoadPlanner(
shapes_validation_sharded_tensors=flexible_shape_sharded_tensors,
allow_partial_load=True,
),
)

pyt_state_dict = cast(
Dict[str, Union[TorchShardedTensor, List[io.BytesIO]]], pyt_state_dict
)
Expand All @@ -767,6 +785,13 @@ def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path) -> St
k: v if not isinstance(v, TorchShardedTensor) else _unwrap_pyt_sharded_tensor(v)
for k, v in pyt_state_dict.items()
}

if single_file_per_tensor_ckpt:
mcore_state_dict = {
k: [None] if (not isinstance(v, list) and "_extra_state" in k) else v
for k, v in mcore_state_dict.items()
}

mcore_state_dict = _replace_sharded_keys_with_state_dict_keys(
mcore_state_dict, flat_mapping, rename_mapping
)
Expand Down
12 changes: 11 additions & 1 deletion megatron/megatron/core/transformer/transformer_block.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

import os
from contextlib import nullcontext
from dataclasses import dataclass
from typing import List, Optional, Union
Expand Down Expand Up @@ -630,6 +630,16 @@ def sharded_state_dict(
non_homogeneous_layers = metadata is not None and metadata.get(
'non_homogeneous_layers', False
)

# TODO: @aoyulong - This is a temporary solution to support single-file-per-tensor ckpt
non_homogeneous_layers_env = os.getenv('FS_NON_HOMOGENEOUS_LAYERS', 'False').lower() in (
'true',
'1',
't',
)
if non_homogeneous_layers_env:
non_homogeneous_layers = True

sharded_state_dict = {}

layer_prefix = f'{prefix}layers.'
Expand Down
4 changes: 3 additions & 1 deletion tests/scripts/format_tests/test_format.sh
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ flagscale/logger.py \
flagscale/patches_utils.py \
flagscale/datasets/sft_dataset.py \
flagscale/inference/inference_*.py \
flagscale/inference/arguments.py"
flagscale/inference/arguments.py \
tools/checkpoint/sfpt_ckpt/*.py \
"

# Function to run a command and continue even if it fails
run_command() {
Expand Down
73 changes: 73 additions & 0 deletions tools/checkpoint/sfpt_ckpt/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# README

This directory contains scripts for converting checkpoints between DCP (Distributed Checkpoint) and SFPT (Single File Per Tensor) formats.

## Scripts

- `dcp_to_sfpt.py` - Converts a DCP checkpoint to SFPT format.
- `sfpt_to_dcp.py` - Converts an SFPT checkpoint to DCP format.

## Usage

**Convert DCP to SFPT:**
1. Get the DCP checkpoint non-homogeneous layers from the training run.
* Add the environment variable to experiment-level configuration file:
```yaml
envs:
FS_NON_HOMOGENEOUS_LAYERS: True
```
* Add the following to the task-level configuration file:
```yaml
use_dist_ckpt: True
ckpt_format: torch_dist
ckpt_fully_parallel_save: True
ckpt_fully_parallel_load: True
```
2. Set the `PYTHONPATH` environment variable:

```bash
# FlagScale_ROOT is the root directory of the FlagScale repository
export PYTHONPATH=$FlagScale_ROOT/megatron:$FlagScale_ROOT
```

3. Run the conversion script:
```bash
torchrun --nnodes 1 --node_rank 0 --nproc_per_node 1 \
--master_addr localhost --master_port 1234 \
dcp_to_sfpt.py --input_dir /path/to/dcp_checkpoint --output_dir /path/to/output_sfpt_checkpoint
```

**Convert SFPT to DCP:**

1. Set the `PYTHONPATH` environment variable:
```bash
# FlagScale_ROOT is the root directory of the FlagScale repository
export PYTHONPATH=$FlagScale_ROOT/megatron:$FlagScale_ROOT
```

2. Run the conversion script:
```bash
FS_SFPT_CKPT_SAVE=1 torchrun --nnodes 1 --node_rank 0 --nproc_per_node 1 \
--master_addr localhost --master_port 1234 \
sfpt_to_dcp.py --input_dir /path/to/sfpt_checkpoint --output_dir /path/to/output_dcp_checkpoint
```

3. Use the DCP checkpoint for further fine-tuning.
* Add the environment variables to experiment-level configuration file:
```yaml
envs:
FS_NON_HOMOGENEOUS_LAYERS: True
FS_SFPT_CKPT_LOAD: True
```

* Add the following to the task-level configuration file:
```yaml
use_dist_ckpt: True
ckpt_format: torch_dist
ckpt_fully_parallel_save: True
ckpt_fully_parallel_load: True
finetune: True
load: /path/to/output_dcp_checkpoint
```
Loading

0 comments on commit 916aa20

Please sign in to comment.