Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[auto_parallel] Layered Implementation of load_state_dict #66925

Merged
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 76 additions & 45 deletions python/paddle/distributed/checkpoint/load_state_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,20 +63,22 @@ def get_checkpoint_files(path, use_cache=True):
return (metadata_files, local_data_files)


def get_rank_to_files(path, state_dict, process_group, use_dist):
def get_rank_to_files(
metadata_list, local_data_files, state_dict, process_group, use_dist
):
"""
Get the mapping of rank to its accessible files.
"""
metadata_files, local_data_files = get_checkpoint_files(path)

# The necessary files to be read
tensor_key_list = []
necessary_files = []
for metadata_file in metadata_files:
metadata = paddle.load(os.path.join(path, metadata_file))

for metadata in metadata_list:
for local_tensor_index, file_name in metadata.storage_metadata.items():
assert (
local_tensor_index not in tensor_key_list
), f"Duplicate tensor_key:{local_tensor_index} found. Check whether the metadata_file:{metadata_file} contains the same tensor metadata."
), f"Duplicate tensor_key:{local_tensor_index} found. Check whether the metadata."
tensor_key_list.append(local_tensor_index.tensor_key)
if local_tensor_index.tensor_key in state_dict:
necessary_files.append(file_name)
Expand All @@ -96,7 +98,7 @@ def get_rank_to_files(path, state_dict, process_group, use_dist):
global_necessary_files_set = set(global_necessary_files)
if len(global_necessary_files_set) <= 0:
logger.warning(
f"No necessary data files found in the checkpoint directory:{path}. Please check the metadata_files:{metadata_files}"
"No necessary data files found in the checkpoint directory. Please check the metadata."
)
missing_keys = set(state_dict.keys())
return {}, missing_keys
Expand All @@ -120,7 +122,7 @@ def get_rank_to_files(path, state_dict, process_group, use_dist):
assert (
global_data_files_set & global_necessary_files_set
== global_necessary_files_set
), f"The checkpoint files are not complete. Please check the checkpoint directory:{path}.global_data_files_set:{global_data_files_set}, necessary_data_files_set:{global_necessary_files_set}"
), f"The checkpoint files are not complete. Please check the checkpoint directory. global_data_files_set:{global_data_files_set}, necessary_data_files_set:{global_necessary_files_set}"
missing_keys = set(state_dict.keys()) - set(tensor_key_list)
if len(missing_keys) > 0:
logger.warning(
Expand Down Expand Up @@ -192,6 +194,13 @@ def get_read_rank_file(rank_to_not_read_files, ranks):
for rank, files in rank_to_not_read_files.items()
if rank in ranks
]
# 'ranks' refers to the ranks that have read the fewest files so far. However, the files containing the weights required
# by these ranks may already have been completely read, in which case those ranks will not read any more files.
if len(nums) == 0:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

添加一下 Note,说明一下原因

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已添加

nums = [
(rank, len(files))
for rank, files in rank_to_not_read_files.items()
]
nums = sorted(nums, key=lambda x: x[1])
rank = nums[0][0]
return (rank, rank_to_not_read_files[rank][0])
Expand Down Expand Up @@ -224,6 +233,7 @@ def update(rank_to_read_files, rank_to_not_read_files, rank_file):
logger.debug(
f"update rank_to_read_files:{rank_to_read_files}, rank_to_not_read_files:{rank_to_not_read_files}, ranks:{ranks}, rank_file:{rank_file}"
)

cur_rank = paddle.distributed.get_rank()
if cur_rank in rank_to_read_files:
return rank_to_read_files[cur_rank]
Expand All @@ -232,17 +242,16 @@ def update(rank_to_read_files, rank_to_not_read_files, rank_file):
return []


def get_load_infos(path, local_load_files, process_group, use_dist):
def get_load_infos(metadata_list, local_load_files, process_group, use_dist):
load_info = {}
metadata_files, _ = get_checkpoint_files(path)
for metadata_file in metadata_files:
metadata = paddle.load(os.path.join(path, metadata_file))
for metadata in metadata_list:
for local_tensor_index, file_name in metadata.storage_metadata.items():
if file_name in local_load_files:
load_info[local_tensor_index] = (
paddle.distributed.get_rank(),
file_name,
)

load_info_list = []
if use_dist:
paddle.distributed.all_gather_object(
Expand Down Expand Up @@ -308,18 +317,17 @@ def not_overlap(
return False


def get_read_items(path, state_dict, process_group, use_dist):
def get_read_items(metadata_list, state_dict, process_group, use_dist):
storage_state_dict_metadata = {}
metadata_files, _ = get_checkpoint_files(path)
for metadata_file in metadata_files:
metadata = paddle.load(os.path.join(path, metadata_file))
for metadata in metadata_list:
for (
tensor_key,
local_tensor_metadata,
) in metadata.state_dict_metadata.items():
if tensor_key not in storage_state_dict_metadata:
storage_state_dict_metadata[tensor_key] = []
storage_state_dict_metadata[tensor_key] += local_tensor_metadata

read_items = []
logger.debug(f"storage_state_dict_metadata:{storage_state_dict_metadata}")
for tensor_key, val in state_dict.items():
Expand Down Expand Up @@ -451,8 +459,18 @@ def load_state_dict(
# Sync to avoid the case where some ranks have not written to the path yet.
paddle.distributed.barrier(process_group)

metadata_files, local_data_files = get_checkpoint_files(path)

metadata_list = []
for file in metadata_files:
metadata_list.append(paddle.load(os.path.join(path, file)))

rank_to_files, missing_keys = get_rank_to_files(
path, flat_state_dict, process_group, use_dist
metadata_list,
local_data_files,
flat_state_dict,
process_group,
use_dist,
)

if len(missing_keys) > 0:
Expand All @@ -461,25 +479,49 @@ def load_state_dict(
)
if len(rank_to_files) <= 0:
return

local_load_files = get_local_load_files(rank_to_files)

source_state_dict = {}
for file in local_load_files:
source_state_dict[file] = paddle.load(os.path.join(path, file))

state_dict_in_cpu = []
for k, v in flat_state_dict.items():
if v.place.is_cpu_place():
state_dict_in_cpu.append(k)
flat_state_dict[k] = v.cuda()

_load_state_dict(flat_state_dict, source_state_dict, metadata_list)

for k, v in flat_state_dict.items():
if k in state_dict_in_cpu:
value = state_dict
for key in mapping[k]:
value = value[key]
paddle.assign(v.cpu(), value)


def _load_state_dict(
target_state_dict,
source_state_dict,
metadata_list,
process_group=None,
coordinator_rank=0,
) -> None:
with paddle.base.dygraph.guard():
use_dist = True if paddle.distributed.get_world_size() > 1 else False
local_load_files = list(source_state_dict.keys())
# load_infos: {LocalTensorIndex: (rank, file_name)}, i.e. which file each local tensor is located in, and which rank loads that file.
load_infos = get_load_infos(
path, local_load_files, process_group, use_dist
metadata_list, local_load_files, process_group, use_dist
)
# read_items: [ReadItem(local_tensor_index, rank, cur_offsets, storage_offsets, lengths)],
# i.e. slice the stored local tensor at (storage_offsets, lengths) and assign it to the current tensor at (cur_offsets, lengths) on the given rank.
read_items = get_read_items(
path, flat_state_dict, process_group, use_dist
)
storage_file_to_state_dict = {}
logger.debug(
f"before load, state_dict:{flat_state_dict},\n load_infos:{load_infos},\n read_items:{read_items}"
metadata_list, target_state_dict, process_group, use_dist
)
state_dict_in_cpu = []
for k, v in flat_state_dict.items():
if v.place.is_cpu_place():
state_dict_in_cpu.append(k)
flat_state_dict[k] = v.cuda()

for item in read_items:
assert (
item.local_tensor_index in load_infos
Expand All @@ -489,12 +531,8 @@ def load_state_dict(
cur_chunk_tensor = None
# The src rank needs to load the state_dict.
if src_rank == paddle.distributed.get_rank():
if file_name not in storage_file_to_state_dict:
# The value in state_dict is not a distributed tensor but a normal tensor.
storage_file_to_state_dict[file_name] = paddle.load(
os.path.join(path, file_name)
)
storage_state_dict = storage_file_to_state_dict[file_name]
assert file_name in source_state_dict
storage_state_dict = source_state_dict[file_name]
assert item.local_tensor_index.tensor_key in storage_state_dict
storage_local_tensor = storage_state_dict[
item.local_tensor_index.tensor_key
Expand All @@ -520,18 +558,18 @@ def load_state_dict(
# The read item rank needs to be assigned.
if item.rank == paddle.distributed.get_rank():
assert (
item.local_tensor_index.tensor_key in flat_state_dict
), f"item:{item}, state_dict:{flat_state_dict}"
item.local_tensor_index.tensor_key in target_state_dict
), f"item:{item}, state_dict:{target_state_dict}"

cur_local_tensor = (
flat_state_dict[
target_state_dict[
item.local_tensor_index.tensor_key
]._local_value()
if use_dist
and flat_state_dict[
and target_state_dict[
item.local_tensor_index.tensor_key
].is_dist()
else flat_state_dict[item.local_tensor_index.tensor_key]
else target_state_dict[item.local_tensor_index.tensor_key]
)

cur_offsets = item.cur_offset
Expand Down Expand Up @@ -576,10 +614,3 @@ def load_state_dict(
tmp_tensor, src=src_rank, group=process_group
)
paddle.assign(tmp_tensor, cur_chunk_tensor)

for k, v in flat_state_dict.items():
if k in state_dict_in_cpu:
value = state_dict
for key in mapping[k]:
value = value[key]
paddle.assign(v.cpu(), value)