Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core] Add the conversion between dcp and sfpt #272

Merged
merged 2 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 43 additions & 9 deletions flagscale/runner/runner_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import collections
import os
import re
import sys
import socket
import subprocess

Expand Down Expand Up @@ -83,19 +84,30 @@ def get_host_name_or_ip():
sock.close()
return IP


def run_local_command(cmd, dryrun=False, query=False):
logger.info(f"Run the local command: {cmd}")
if dryrun:
return
if query:
result = subprocess.run(
cmd, shell=True, check=True, capture_output=True, text=True
cmd,
shell=True,
check=True,
capture_output=True,
text=True,
encoding="utf-8",
errors="replace",
)
return result
else:
result = subprocess.run(
cmd, shell=True, check=True, capture_output=True, text=True
cmd,
shell=True,
check=True,
capture_output=True,
text=True,
encoding='utf-8',
errors='replace'
)
if result.returncode != 0:
print(f"Command {cmd} failed with return code {result.returncode}.")
Expand All @@ -113,13 +125,22 @@ def run_ssh_command(host, cmd, port=None, dryrun=False, query=False):
logger.info(f"Running the ssh command: {ssh_cmd}")
if dryrun:
return
result = subprocess.run(
ssh_cmd,
shell=True,
check=True,
capture_output=True,
text=True,
encoding='utf-8',
errors='replace'
)
if result.returncode != 0:
print(f"SSH command {ssh_cmd} failed with return code {result.returncode}.")
print(f"Output: {result.stdout}")
print(f"Error: {result.stderr}")
sys.exit(result.returncode)
if query:
result = subprocess.run(
ssh_cmd, shell=True, check=True, text=True, stdout=subprocess.PIPE
)
return result
else:
subprocess.run(ssh_cmd, shell=True, check=True)


def run_scp_command(host, src, dst, port=None, dryrun=False):
Expand All @@ -130,7 +151,20 @@ def run_scp_command(host, src, dst, port=None, dryrun=False):
logger.info(f"Run the scp command: {scp_cmd}")
if dryrun:
return
subprocess.run(scp_cmd, shell=True, check=True)
result = subprocess.run(
scp_cmd,
shell=True,
check=True,
capture_output=True,
text=True,
encoding='utf-8',
errors='replace'
)
if result.returncode != 0:
print(f"SCP command {scp_cmd} failed with return code {result.returncode}.")
print(f"Output: {result.stdout}")
print(f"Error: {result.stderr}")
sys.exit(result.returncode)


def flatten_dict_to_args(config_dict, ignore_keys=[]):
Expand Down
16 changes: 12 additions & 4 deletions megatron/megatron/core/dist_checkpointing/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
loading the sharded tensors.
"""

import os
import logging
from pathlib import Path
from typing import Dict, Optional, Set, Tuple, Union
Expand Down Expand Up @@ -339,10 +340,17 @@ def save(
f'Checkpoint destination directory does not exist: {checkpoint_dir}'
)

if next(checkpoint_dir.iterdir(), None) is not None:
raise CheckpointingException(
f'Checkpoint destination directory ({checkpoint_dir}) is not empty'
)
# Skip this if the env var exists, otherwise default to False
single_file_per_tensor_ckpt = os.getenv('FS_SFPT_CKPT_SAVE', 'False').lower() in (
'true',
'1',
't',
)
if not single_file_per_tensor_ckpt:
if next(checkpoint_dir.iterdir(), None) is not None:
raise CheckpointingException(
f'Checkpoint destination directory ({checkpoint_dir}) is not empty'
)

if common_strategy is not None:
raise NotImplementedError('The only supported common strategy is torch')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,21 @@
import logging
import os
import queue
import pickle
from contextlib import contextmanager
from itertools import chain
from pathlib import Path
from time import time
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Callable, Dict, List, Optional, Tuple, Union, cast

import psutil
import torch
from torch import multiprocessing as mp
from torch.distributed.checkpoint import FileSystemWriter
from torch.distributed.checkpoint.filesystem import DEFAULT_SUFFIX, _StoragePrefix, _write_item
from torch.distributed.checkpoint.filesystem import DEFAULT_SUFFIX, _StoragePrefix, _write_item, _metadata_fn
from torch.distributed.checkpoint.planner import SavePlan, SavePlanner, WriteItem, WriteItemType
from torch.distributed.checkpoint.storage import WriteResult
from torch.distributed.checkpoint.metadata import Metadata
from torch.futures import Future

logger = logging.getLogger(__name__)
Expand All @@ -26,6 +28,40 @@

_results_queue = None

_GLOBAL_PREVIOUS_METADATA = None

_GLOBAL_PREVIOUS_COUNT = 0


def get_previous_metadata():
"""
Get the metadata from the previous save.
"""
return _GLOBAL_PREVIOUS_METADATA


def set_previous_metadata(metadata):
"""
Set the metadata from the previous save.
"""
global _GLOBAL_PREVIOUS_METADATA
_GLOBAL_PREVIOUS_METADATA = metadata


def get_previous_count():
"""
Get the count from the previous save.
"""
return _GLOBAL_PREVIOUS_COUNT


def set_previous_count(count):
"""
Set the count from the previous save.
"""
global _GLOBAL_PREVIOUS_COUNT
_GLOBAL_PREVIOUS_COUNT = count


def _get_write_results_queue():
global _results_queue
Expand Down Expand Up @@ -80,6 +116,13 @@ def __init__(self, *args, **kwargs):
self.write_buckets: Optional[List[WriteBucket]] = None
self.results_queue: Optional[mp.Queue] = None

# Get the value from the environment variable if it exists, otherwise default to False
self.single_file_per_tensor_ckpt = os.getenv('FS_SFPT_CKPT_SAVE', 'False').lower() in (
'true',
'1',
't',
)

def prepare_write_data(self, plan: SavePlan, planner: SavePlanner) -> None:
"""
First stage of async saving. Copy data to CPU and plan the local saving.
Expand All @@ -99,12 +142,17 @@ def prepare_write_data(self, plan: SavePlan, planner: SavePlanner) -> None:
start = time()
# move tensors from GPU to CPU before starting async writing
# We do D2H synchronously for now
file_count = 0
if not self.single_file_per_tensor_ckpt:
file_count = 0
else:
file_count = get_previous_count()

def gen_file():
nonlocal file_count
file_name = f"{storage_plan.prefix}{file_count}{DEFAULT_SUFFIX}"
file_count += 1
if self.single_file_per_tensor_ckpt:
set_previous_count(file_count)
return file_name

# Prepare bytes / tensor data in each bucket, which will be assigned to each writer process
Expand Down Expand Up @@ -314,6 +362,48 @@ def retrieve_write_results(self) -> List[WriteResult]:
)
return list(chain.from_iterable(write_results.values()))

def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
# Modify based on the original implementation from torch.distributed.checkpoint.filesystem.FileSystemWriter
# https://github.com/pytorch/pytorch/blob/625c24a7f98a645b6f8758a01d7163a842582ce0/torch/distributed/checkpoint/filesystem.py#L574

if not self.single_file_per_tensor_ckpt:
storage_md = {}
else:
if get_previous_count() == 1:
storage_md = {}
else:
# Get the metadata from the previous save
prev_metadata = get_previous_metadata()
prev_metadata.state_dict_metadata.update(metadata.state_dict_metadata)
metadata = prev_metadata
storage_md = metadata.storage_data

for wr_list in results:
storage_md.update({wr.index: wr.storage_data for wr in wr_list})
metadata.storage_data = storage_md

if not self.single_file_per_tensor_ckpt or get_previous_count() == 1:
metadata.storage_meta = self.storage_meta()

tmp_path = cast(Path, self.fs.concat_path(self.path, f"{_metadata_fn}.tmp"))
with self.fs.create_stream(tmp_path, "wb") as metadata_file:
pickle.dump(metadata, metadata_file)
if self.sync_files:
try:
os.fsync(metadata_file.fileno())
except AttributeError:
os.sync()

# delete in-case other checkpoints were present.
if self.fs.exists(self.metadata_path):
self.fs.rm_file(self.metadata_path)

self.fs.rename(tmp_path, self.metadata_path)

# Store the metadata for the next save
if self.single_file_per_tensor_ckpt:
set_previous_metadata(metadata)


def _split_by_size_and_type(bins: int, items: List[WriteItem]) -> List[List[WriteItem]]:
"""
Expand Down Expand Up @@ -349,7 +439,6 @@ def _split_by_size_and_type(bins: int, items: List[WriteItem]) -> List[List[Writ
idx = min(enumerate(bucket_sizes), key=lambda x: x[1])[0]
buckets[idx].append(item)
bucket_sizes[idx] += _item_size(item)

return buckets


Expand Down
41 changes: 33 additions & 8 deletions megatron/megatron/core/dist_checkpointing/strategies/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

""" Strategies using PyTorch distributed.checkpoint as an underlying format. """
import io
import os
from collections import ChainMap, defaultdict
from dataclasses import dataclass
from itertools import product
Expand Down Expand Up @@ -404,7 +405,6 @@ def _replace_sharded_keys_with_state_dict_keys(
assert len(tensors) == len(rename_mapping[k])
for ten, recovered_k in zip(tensors, rename_mapping[k]):
recovered_sd[recovered_k] = ten

return unflatten_state_dict(recovered_sd, flat_mapping)


Expand Down Expand Up @@ -734,6 +734,13 @@ def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path) -> St
Returns: loaded state dict
"""
# Get the value from the environment variable if it exists, otherwise default to True
single_file_per_tensor_ckpt = os.getenv('FS_SFPT_CKPT_LOAD', 'False').lower() in (
'true',
'1',
't',
)

# Apply N-D tensors resharding
sharded_state_dict, formulation_restore_data = apply_nd_flattened_tensors_reformulation(
sharded_state_dict, get_reformulation_metadata(sharded_state_dict, checkpoint_dir)
Expand All @@ -752,13 +759,24 @@ def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path) -> St
)
pyt_state_dict = mcore_to_pyt_state_dict(sharded_state_dict, True)
# Load PyT Distributed format
checkpoint.load_state_dict(
pyt_state_dict,
FileSystemReader(checkpoint_dir),
planner=MCoreLoadPlanner(
shapes_validation_sharded_tensors=flexible_shape_sharded_tensors
),
)
if not single_file_per_tensor_ckpt:
checkpoint.load_state_dict(
pyt_state_dict,
FileSystemReader(checkpoint_dir),
planner=MCoreLoadPlanner(
shapes_validation_sharded_tensors=flexible_shape_sharded_tensors
),
)
else:
checkpoint.load_state_dict(
pyt_state_dict,
FileSystemReader(checkpoint_dir),
planner=MCoreLoadPlanner(
shapes_validation_sharded_tensors=flexible_shape_sharded_tensors,
allow_partial_load=True,
),
)

pyt_state_dict = cast(
Dict[str, Union[TorchShardedTensor, List[io.BytesIO]]], pyt_state_dict
)
Expand All @@ -767,6 +785,13 @@ def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path) -> St
k: v if not isinstance(v, TorchShardedTensor) else _unwrap_pyt_sharded_tensor(v)
for k, v in pyt_state_dict.items()
}

if single_file_per_tensor_ckpt:
mcore_state_dict = {
k: [None] if (not isinstance(v, list) and "_extra_state" in k) else v
for k, v in mcore_state_dict.items()
}

mcore_state_dict = _replace_sharded_keys_with_state_dict_keys(
mcore_state_dict, flat_mapping, rename_mapping
)
Expand Down
12 changes: 11 additions & 1 deletion megatron/megatron/core/transformer/transformer_block.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

import os
from contextlib import nullcontext
from dataclasses import dataclass
from typing import List, Optional, Union
Expand Down Expand Up @@ -630,6 +630,16 @@ def sharded_state_dict(
non_homogeneous_layers = metadata is not None and metadata.get(
'non_homogeneous_layers', False
)

# TODO: @aoyulong - This is a temporary solution to support single-file-per-tensor ckpt
non_homogeneous_layers_env = os.getenv('FS_NON_HOMOGENEOUS_LAYERS', 'False').lower() in (
'true',
'1',
't',
)
if non_homogeneous_layers_env:
non_homogeneous_layers = True

sharded_state_dict = {}

layer_prefix = f'{prefix}layers.'
Expand Down
4 changes: 3 additions & 1 deletion tests/scripts/format_tests/test_format.sh
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ flagscale/logger.py \
flagscale/patches_utils.py \
flagscale/datasets/sft_dataset.py \
flagscale/inference/inference_*.py \
flagscale/inference/arguments.py"
flagscale/inference/arguments.py \
tools/checkpoint/sfpt_ckpt/*.py \
"

# Function to run a command and continue even if it fails
run_command() {
Expand Down
Loading
Loading