From 52b337deae33085a2570fca45d2a1a9f4d79871a Mon Sep 17 00:00:00 2001 From: xingmingyyj <135400902+xingmingyyj@users.noreply.github.com> Date: Thu, 1 Aug 2024 16:54:02 +0800 Subject: [PATCH 1/7] Update load_state_dict.py --- .../distributed/checkpoint/load_state_dict.py | 149 +++++++++++------- 1 file changed, 89 insertions(+), 60 deletions(-) diff --git a/python/paddle/distributed/checkpoint/load_state_dict.py b/python/paddle/distributed/checkpoint/load_state_dict.py index dba64277ce9695..431a4994462f2c 100644 --- a/python/paddle/distributed/checkpoint/load_state_dict.py +++ b/python/paddle/distributed/checkpoint/load_state_dict.py @@ -63,23 +63,23 @@ def get_checkpoint_files(path, use_cache=True): return (metadata_files, local_data_files) -def get_rank_to_files(path, state_dict, process_group, use_dist): +def get_rank_to_files( + metadata, local_data_files, state_dict, process_group, use_dist +): """ Get the mapping of rank to its accessible files. """ - metadata_files, local_data_files = get_checkpoint_files(path) + # The necessary files to be read tensor_key_list = [] necessary_files = [] - for metadata_file in metadata_files: - metadata = paddle.load(os.path.join(path, metadata_file)) - for local_tensor_index, file_name in metadata.storage_metadata.items(): - assert ( - local_tensor_index not in tensor_key_list - ), f"Duplicate tensor_key:{local_tensor_index} found. Check whether the metadata_file:{metadata_file} contains the same tensor metadata." - tensor_key_list.append(local_tensor_index.tensor_key) - if local_tensor_index.tensor_key in state_dict: - necessary_files.append(file_name) + for local_tensor_index, file_name in metadata.storage_metadata.items(): + assert ( + local_tensor_index not in tensor_key_list + ), f"Duplicate tensor_key:{local_tensor_index} found. Check whether the metadata." + tensor_key_list.append(local_tensor_index.tensor_key) + if local_tensor_index.tensor_key in state_dict: + necessary_files.append(file_name) all_necessary_files = [] if use_dist: @@ -96,7 +96,7 @@ def get_rank_to_files(path, state_dict, process_group, use_dist): global_necessary_files_set = set(global_necessary_files) if len(global_necessary_files_set) <= 0: logger.warning( - f"No necessary data files found in the checkpoint directory:{path}. Please check the metadata_files:{metadata_files}" + "No necessary data files found in the checkpoint directory. Please check the metadata." ) missing_keys = set(state_dict.keys()) return {}, missing_keys @@ -120,7 +120,7 @@ def get_rank_to_files(path, state_dict, process_group, use_dist): assert ( global_data_files_set & global_necessary_files_set == global_necessary_files_set - ), f"The checkpoint files are not complete. Please check the checkpoint directory:{path}.global_data_files_set:{global_data_files_set}, necessary_data_files_set:{global_necessary_files_set}" + ), f"The checkpoint files are not complete. Please check the checkpoint directory. 
global_data_files_set:{global_data_files_set}, necessary_data_files_set:{global_necessary_files_set}" missing_keys = set(state_dict.keys()) - set(tensor_key_list) if len(missing_keys) > 0: logger.warning( @@ -176,6 +176,7 @@ def get_local_load_files(rank_to_files): f"rank_to_read_files:{rank_to_read_files}, rank_to_not_read_files:{rank_to_not_read_files}" ) + # 得到目前读文件最少的rank def get_least_read_files_ranks(rank_to_read_files): nums = [ (rank, len(files)) for rank, files in rank_to_read_files.items() @@ -192,6 +193,11 @@ def get_read_rank_file(rank_to_not_read_files, ranks): for rank, files in rank_to_not_read_files.items() if rank in ranks ] + if len(nums) == 0: + nums = [ + (rank, len(files)) + for rank, files in rank_to_not_read_files.items() + ] nums = sorted(nums, key=lambda x: x[1]) rank = nums[0][0] return (rank, rank_to_not_read_files[rank][0]) @@ -224,6 +230,7 @@ def update(rank_to_read_files, rank_to_not_read_files, rank_file): logger.debug( f"update rank_to_read_files:{rank_to_read_files}, rank_to_not_read_files:{rank_to_not_read_files}, ranks:{ranks}, rank_file:{rank_file}" ) + cur_rank = paddle.distributed.get_rank() if cur_rank in rank_to_read_files: return rank_to_read_files[cur_rank] @@ -232,17 +239,15 @@ def update(rank_to_read_files, rank_to_not_read_files, rank_file): return [] -def get_load_infos(path, local_load_files, process_group, use_dist): +def get_load_infos(metadata, local_load_files, process_group, use_dist): load_info = {} - metadata_files, _ = get_checkpoint_files(path) - for metadata_file in metadata_files: - metadata = paddle.load(os.path.join(path, metadata_file)) - for local_tensor_index, file_name in metadata.storage_metadata.items(): - if file_name in local_load_files: - load_info[local_tensor_index] = ( - paddle.distributed.get_rank(), - file_name, - ) + for local_tensor_index, file_name in metadata.storage_metadata.items(): + if file_name in local_load_files: + load_info[local_tensor_index] = ( + paddle.distributed.get_rank(), + file_name, + ) + load_info_list = [] if use_dist: paddle.distributed.all_gather_object( @@ -308,18 +313,17 @@ def not_overlap( return False -def get_read_items(path, state_dict, process_group, use_dist): +def get_read_items(metadata, state_dict, process_group, use_dist): storage_state_dict_metadata = {} - metadata_files, _ = get_checkpoint_files(path) - for metadata_file in metadata_files: - metadata = paddle.load(os.path.join(path, metadata_file)) - for ( - tensor_key, - local_tensor_metadata, - ) in metadata.state_dict_metadata.items(): - if tensor_key not in storage_state_dict_metadata: - storage_state_dict_metadata[tensor_key] = [] - storage_state_dict_metadata[tensor_key] += local_tensor_metadata + + for ( + tensor_key, + local_tensor_metadata, + ) in metadata.state_dict_metadata.items(): + if tensor_key not in storage_state_dict_metadata: + storage_state_dict_metadata[tensor_key] = [] + storage_state_dict_metadata[tensor_key] += local_tensor_metadata + read_items = [] logger.debug(f"storage_state_dict_metadata:{storage_state_dict_metadata}") for tensor_key, val in state_dict.items(): @@ -451,8 +455,23 @@ def load_state_dict( # sync to avoid some ranks not write path yet paddle.distributed.barrier(process_group) + metadata_files, local_data_files = get_checkpoint_files(path) + + assert len(metadata_files) == 1 + + if coordinator_rank == paddle.distributed.get_rank(): + metadata = paddle.load(os.path.join(path, metadata_files[0])) + else: + metadata = None + + global_metadata = [] + + 
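# Standalone sketch of the collective used in the next line: all_gather_object
# replicates a picklable Python object from every rank to every rank, so loading the
# metadata only on the coordinator and then indexing the gathered list with
# coordinator_rank gives each rank the same metadata object. This assumes two ranks
# launched with `python -m paddle.distributed.launch`; the names here are illustrative.
import paddle.distributed as dist

dist.init_parallel_env()
obj = {"loaded_on": 0} if dist.get_rank() == 0 else None
gathered = []
dist.all_gather_object(gathered, obj)   # len(gathered) == world_size on every rank
shared = gathered[0]                    # every rank now holds rank 0's object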
paddle.distributed.all_gather_object(global_metadata, metadata) + + metadata = global_metadata[coordinator_rank] + rank_to_files, missing_keys = get_rank_to_files( - path, flat_state_dict, process_group, use_dist + metadata, local_data_files, flat_state_dict, process_group, use_dist ) if len(missing_keys) > 0: @@ -461,25 +480,42 @@ def load_state_dict( ) if len(rank_to_files) <= 0: return + local_load_files = get_local_load_files(rank_to_files) + + source_state_dict = {} + for file in local_load_files: + source_state_dict[file] = paddle.load(os.path.join(path, file)) + + _load_state_dict(flat_state_dict, source_state_dict, metadata) + + +def _load_state_dict( + target_state_dict, + source_state_dict, + metadata, + process_group=None, + coordinator_rank=0, +) -> None: + with paddle.base.dygraph.guard(): + use_dist = True if paddle.distributed.get_world_size() > 1 else False + local_load_files = list(source_state_dict.keys()) # load_infos: {LocalTensorIndex: (rank, file_name)}, which local tensor located in which file, and the file is load in which rank. load_infos = get_load_infos( - path, local_load_files, process_group, use_dist + metadata, local_load_files, process_group, use_dist ) # read_items: [ReadItem(local_tensor_index, rank, cur_offsets, storage_offsets, lengths)], # slice the storage local tensor in (storage_offsets, lengths) to assign the current tensor in (cur_offsets, lengths) in rank. read_items = get_read_items( - path, flat_state_dict, process_group, use_dist - ) - storage_file_to_state_dict = {} - logger.debug( - f"before load, state_dict:{flat_state_dict},\n load_infos:{load_infos},\n read_items:{read_items}" + metadata, target_state_dict, process_group, use_dist ) + state_dict_in_cpu = [] - for k, v in flat_state_dict.items(): + for k, v in target_state_dict.items(): if v.place.is_cpu_place(): state_dict_in_cpu.append(k) - flat_state_dict[k] = v.cuda() + target_state_dict[k] = v.cuda() + for item in read_items: assert ( item.local_tensor_index in load_infos @@ -489,12 +525,8 @@ def load_state_dict( cur_chunk_tensor = None # The src rank need to load the state_dict. if src_rank == paddle.distributed.get_rank(): - if file_name not in storage_file_to_state_dict: - # The value in state_dict is not distributed tensor but a normal tensor. 
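# Toy sketch (hypothetical shapes) of the chunk copy this loop performs once the
# storage tensor has been fetched: the local tensor read from a .distcp file is
# sliced at storage_offsets with the given lengths and assigned into the destination
# shard; when the overlap is only partial, the surrounding code slices the current
# local tensor with cur_offsets/lengths in the same way before assigning.
import paddle

stored = paddle.arange(16, dtype="float32").reshape([4, 4])  # tensor from one data file
dest = paddle.zeros([2, 4])                                  # shard owned by this rank
storage_offsets, lengths = (2, 0), (2, 4)                    # copy rows 2..3, all columns
axes = list(range(len(lengths)))
src_chunk = paddle.slice(
    stored,
    axes=axes,
    starts=list(storage_offsets),
    ends=[o + n for o, n in zip(storage_offsets, lengths)],
)
paddle.assign(src_chunk, dest)  # dest now equals rows 2 and 3 of `stored`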
- storage_file_to_state_dict[file_name] = paddle.load( - os.path.join(path, file_name) - ) - storage_state_dict = storage_file_to_state_dict[file_name] + assert file_name in source_state_dict + storage_state_dict = source_state_dict[file_name] assert item.local_tensor_index.tensor_key in storage_state_dict storage_local_tensor = storage_state_dict[ item.local_tensor_index.tensor_key @@ -520,18 +552,18 @@ def load_state_dict( # The read item rank need to be assigned if item.rank == paddle.distributed.get_rank(): assert ( - item.local_tensor_index.tensor_key in flat_state_dict - ), f"item:{item}, state_dict:{flat_state_dict}" + item.local_tensor_index.tensor_key in target_state_dict + ), f"item:{item}, state_dict:{target_state_dict}" cur_local_tensor = ( - flat_state_dict[ + target_state_dict[ item.local_tensor_index.tensor_key ]._local_value() if use_dist - and flat_state_dict[ + and target_state_dict[ item.local_tensor_index.tensor_key ].is_dist() - else flat_state_dict[item.local_tensor_index.tensor_key] + else target_state_dict[item.local_tensor_index.tensor_key] ) cur_offsets = item.cur_offset @@ -577,9 +609,6 @@ def load_state_dict( ) paddle.assign(tmp_tensor, cur_chunk_tensor) - for k, v in flat_state_dict.items(): - if k in state_dict_in_cpu: - value = state_dict - for key in mapping[k]: - value = value[key] - paddle.assign(v.cpu(), value) + for k, v in target_state_dict.items(): + if k in state_dict_in_cpu: + target_state_dict[k] = v.cpu() From 7c0c732130c81fc0e054ea23603752f3680a64a1 Mon Sep 17 00:00:00 2001 From: xingmingyyj <135400902+xingmingyyj@users.noreply.github.com> Date: Thu, 1 Aug 2024 17:13:45 +0800 Subject: [PATCH 2/7] Update load_state_dict.py --- .../distributed/checkpoint/load_state_dict.py | 25 +++++++++++-------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/python/paddle/distributed/checkpoint/load_state_dict.py b/python/paddle/distributed/checkpoint/load_state_dict.py index 431a4994462f2c..56e90b2f8b4091 100644 --- a/python/paddle/distributed/checkpoint/load_state_dict.py +++ b/python/paddle/distributed/checkpoint/load_state_dict.py @@ -176,7 +176,6 @@ def get_local_load_files(rank_to_files): f"rank_to_read_files:{rank_to_read_files}, rank_to_not_read_files:{rank_to_not_read_files}" ) - # 得到目前读文件最少的rank def get_least_read_files_ranks(rank_to_read_files): nums = [ (rank, len(files)) for rank, files in rank_to_read_files.items() @@ -487,9 +486,23 @@ def load_state_dict( for file in local_load_files: source_state_dict[file] = paddle.load(os.path.join(path, file)) + state_dict_in_cpu = [] + for k, v in flat_state_dict.items(): + if v.place.is_cpu_place(): + state_dict_in_cpu.append(k) + flat_state_dict[k] = v.cuda() + _load_state_dict(flat_state_dict, source_state_dict, metadata) + for k, v in flat_state_dict.items(): + if k in state_dict_in_cpu: + value = state_dict + for key in mapping[k]: + value = value[key] + paddle.assign(v.cpu(), value) + + def _load_state_dict( target_state_dict, source_state_dict, @@ -510,12 +523,6 @@ def _load_state_dict( metadata, target_state_dict, process_group, use_dist ) - state_dict_in_cpu = [] - for k, v in target_state_dict.items(): - if v.place.is_cpu_place(): - state_dict_in_cpu.append(k) - target_state_dict[k] = v.cuda() - for item in read_items: assert ( item.local_tensor_index in load_infos @@ -609,6 +616,4 @@ def _load_state_dict( ) paddle.assign(tmp_tensor, cur_chunk_tensor) - for k, v in target_state_dict.items(): - if k in state_dict_in_cpu: - target_state_dict[k] = v.cpu() + From 
3babebecbb96ea29a675bcee7245a515b919adb5 Mon Sep 17 00:00:00 2001 From: xingmingyyj <135400902+xingmingyyj@users.noreply.github.com> Date: Fri, 2 Aug 2024 15:34:11 +0800 Subject: [PATCH 3/7] Update load_state_dict.py --- .../distributed/checkpoint/load_state_dict.py | 85 +++++++++---------- 1 file changed, 41 insertions(+), 44 deletions(-) diff --git a/python/paddle/distributed/checkpoint/load_state_dict.py b/python/paddle/distributed/checkpoint/load_state_dict.py index 56e90b2f8b4091..34f5bab0caa3e7 100644 --- a/python/paddle/distributed/checkpoint/load_state_dict.py +++ b/python/paddle/distributed/checkpoint/load_state_dict.py @@ -64,7 +64,7 @@ def get_checkpoint_files(path, use_cache=True): def get_rank_to_files( - metadata, local_data_files, state_dict, process_group, use_dist + metadata_list, local_data_files, state_dict, process_group, use_dist ): """ Get the mapping of rank to its accessible files. @@ -73,13 +73,15 @@ def get_rank_to_files( # The necessary files to be read tensor_key_list = [] necessary_files = [] - for local_tensor_index, file_name in metadata.storage_metadata.items(): - assert ( - local_tensor_index not in tensor_key_list - ), f"Duplicate tensor_key:{local_tensor_index} found. Check whether the metadata." - tensor_key_list.append(local_tensor_index.tensor_key) - if local_tensor_index.tensor_key in state_dict: - necessary_files.append(file_name) + + for metadata in metadata_list: + for local_tensor_index, file_name in metadata.storage_metadata.items(): + assert ( + local_tensor_index not in tensor_key_list + ), f"Duplicate tensor_key:{local_tensor_index} found. Check whether the metadata." + tensor_key_list.append(local_tensor_index.tensor_key) + if local_tensor_index.tensor_key in state_dict: + necessary_files.append(file_name) all_necessary_files = [] if use_dist: @@ -192,6 +194,8 @@ def get_read_rank_file(rank_to_not_read_files, ranks): for rank, files in rank_to_not_read_files.items() if rank in ranks ] + # 'ranks' refer to the ranks that have read the fewest number of files so far. However, the files containing the weights required + #. by these ranks may have already been completely read. In this case, they will not read any more files. 
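# Toy illustration (hypothetical state) of why the fallback below is needed:
#   rank_to_read_files     = {0: ["0_0.distcp"], 1: []}  -> ranks == [1] (fewest files read)
#   rank_to_not_read_files = {0: ["1_0.distcp"]}         -> rank 1 has no unread file left
#   nums == []  because nothing survives the `if rank in ranks` filter above.
# The branch below rebuilds nums over every rank that still holds unread files, so
# "1_0.distcp" is assigned to rank 0 instead of never being read.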
if len(nums) == 0: nums = [ (rank, len(files)) @@ -238,14 +242,15 @@ def update(rank_to_read_files, rank_to_not_read_files, rank_file): return [] -def get_load_infos(metadata, local_load_files, process_group, use_dist): +def get_load_infos(metadata_list, local_load_files, process_group, use_dist): load_info = {} - for local_tensor_index, file_name in metadata.storage_metadata.items(): - if file_name in local_load_files: - load_info[local_tensor_index] = ( - paddle.distributed.get_rank(), - file_name, - ) + for metadata in metadata_list: + for local_tensor_index, file_name in metadata.storage_metadata.items(): + if file_name in local_load_files: + load_info[local_tensor_index] = ( + paddle.distributed.get_rank(), + file_name, + ) load_info_list = [] if use_dist: @@ -312,16 +317,16 @@ def not_overlap( return False -def get_read_items(metadata, state_dict, process_group, use_dist): +def get_read_items(metadata_list, state_dict, process_group, use_dist): storage_state_dict_metadata = {} - - for ( - tensor_key, - local_tensor_metadata, - ) in metadata.state_dict_metadata.items(): - if tensor_key not in storage_state_dict_metadata: - storage_state_dict_metadata[tensor_key] = [] - storage_state_dict_metadata[tensor_key] += local_tensor_metadata + for metadata in metadata_list: + for ( + tensor_key, + local_tensor_metadata, + ) in metadata.state_dict_metadata.items(): + if tensor_key not in storage_state_dict_metadata: + storage_state_dict_metadata[tensor_key] = [] + storage_state_dict_metadata[tensor_key] += local_tensor_metadata read_items = [] logger.debug(f"storage_state_dict_metadata:{storage_state_dict_metadata}") @@ -456,21 +461,16 @@ def load_state_dict( metadata_files, local_data_files = get_checkpoint_files(path) - assert len(metadata_files) == 1 - - if coordinator_rank == paddle.distributed.get_rank(): - metadata = paddle.load(os.path.join(path, metadata_files[0])) - else: - metadata = None - - global_metadata = [] - - paddle.distributed.all_gather_object(global_metadata, metadata) - - metadata = global_metadata[coordinator_rank] + metadata_list = [] + for file in metadata_files: + metadata_list.append(paddle.load(os.path.join(path, file))) rank_to_files, missing_keys = get_rank_to_files( - metadata, local_data_files, flat_state_dict, process_group, use_dist + metadata_list, + local_data_files, + flat_state_dict, + process_group, + use_dist, ) if len(missing_keys) > 0: @@ -492,8 +492,7 @@ def load_state_dict( state_dict_in_cpu.append(k) flat_state_dict[k] = v.cuda() - _load_state_dict(flat_state_dict, source_state_dict, metadata) - + _load_state_dict(flat_state_dict, source_state_dict, metadata_list) for k, v in flat_state_dict.items(): if k in state_dict_in_cpu: @@ -506,7 +505,7 @@ def load_state_dict( def _load_state_dict( target_state_dict, source_state_dict, - metadata, + metadata_list, process_group=None, coordinator_rank=0, ) -> None: @@ -515,12 +514,12 @@ def _load_state_dict( local_load_files = list(source_state_dict.keys()) # load_infos: {LocalTensorIndex: (rank, file_name)}, which local tensor located in which file, and the file is load in which rank. load_infos = get_load_infos( - metadata, local_load_files, process_group, use_dist + metadata_list, local_load_files, process_group, use_dist ) # read_items: [ReadItem(local_tensor_index, rank, cur_offsets, storage_offsets, lengths)], # slice the storage local tensor in (storage_offsets, lengths) to assign the current tensor in (cur_offsets, lengths) in rank. 
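# Standalone sketch (hypothetical shapes) of the per-dimension overlap a ReadItem
# describes: a tensor of global shape [8] is stored as chunk [0:4) in file A and
# [4:8) in file B, while the current rank owns shard [2:6); the overlap with A is
# [2:4) and with B is [4:6), from which storage_offsets/cur_offsets/lengths follow.
def overlap_1d(storage_offset, storage_len, cur_offset, cur_len):
    begin = max(storage_offset, cur_offset)
    end = min(storage_offset + storage_len, cur_offset + cur_len)
    if end <= begin:
        return None  # no overlap, so no ReadItem for this pair
    return (begin - storage_offset, begin - cur_offset, end - begin)

assert overlap_1d(0, 4, 2, 4) == (2, 0, 2)  # file A: storage offset 2, cur offset 0, length 2
assert overlap_1d(4, 4, 2, 4) == (0, 2, 2)  # file B: storage offset 0, cur offset 2, length 2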
read_items = get_read_items( - metadata, target_state_dict, process_group, use_dist + metadata_list, target_state_dict, process_group, use_dist ) for item in read_items: @@ -615,5 +614,3 @@ def _load_state_dict( tmp_tensor, src=src_rank, group=process_group ) paddle.assign(tmp_tensor, cur_chunk_tensor) - - From a1e5065425f0cb84acc0ce33f77215cf3171bbf3 Mon Sep 17 00:00:00 2001 From: xingmingyyj Date: Fri, 2 Aug 2024 16:56:31 +0800 Subject: [PATCH 4/7] fix codestyle --- python/paddle/distributed/checkpoint/load_state_dict.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/distributed/checkpoint/load_state_dict.py b/python/paddle/distributed/checkpoint/load_state_dict.py index 34f5bab0caa3e7..d0bfb5978d78ad 100644 --- a/python/paddle/distributed/checkpoint/load_state_dict.py +++ b/python/paddle/distributed/checkpoint/load_state_dict.py @@ -194,8 +194,8 @@ def get_read_rank_file(rank_to_not_read_files, ranks): for rank, files in rank_to_not_read_files.items() if rank in ranks ] - # 'ranks' refer to the ranks that have read the fewest number of files so far. However, the files containing the weights required - #. by these ranks may have already been completely read. In this case, they will not read any more files. + # 'ranks' refer to the ranks that have read the fewest number of files so far. However, the files containing the weights required + # . by these ranks may have already been completely read. In this case, they will not read any more files. if len(nums) == 0: nums = [ (rank, len(files)) From de56a964be01699fc7292e5d6d2ff2c6b5387dc2 Mon Sep 17 00:00:00 2001 From: xingmingyyj Date: Sat, 3 Aug 2024 22:47:42 +0800 Subject: [PATCH 5/7] fix test --- test/auto_parallel/test_dist_checkpoint_utils.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/test/auto_parallel/test_dist_checkpoint_utils.py b/test/auto_parallel/test_dist_checkpoint_utils.py index e709178d1b53fe..bebe5802231c58 100644 --- a/test/auto_parallel/test_dist_checkpoint_utils.py +++ b/test/auto_parallel/test_dist_checkpoint_utils.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import os import tempfile import unittest @@ -20,6 +21,7 @@ import paddle import paddle.distributed as dist +from paddle.distributed.checkpoint.load_state_dict import get_checkpoint_files from paddle.distributed.checkpoint.utils import ( flatten_state_dict, unflatten_state_dict, @@ -111,6 +113,13 @@ def test_get_rank_to_files(self): "w2": paddle.to_tensor([3, 4]), } dist.save_state_dict(state_dict, ckpt_dir) + + metadata_files,local_load_files= get_checkpoint_files(ckpt_dir) + metadata_list = [] + + for metadata_file in metadata_files: + metadata_list.append(paddle.load(os.path.join(ckpt_dir,metadata_file))) + new_state_dict = { "w1": paddle.to_tensor([1, 2]), "w2": paddle.to_tensor([3, 4]), @@ -119,7 +128,7 @@ def test_get_rank_to_files(self): rank_to_files, missing_keys, ) = dist.checkpoint.load_state_dict.get_rank_to_files( - ckpt_dir, new_state_dict, process_group, use_dist + metadata_list,local_load_files, new_state_dict, process_group, use_dist ) self.assertTrue(len(rank_to_files) == 1 and 0 in rank_to_files) self.assertTrue(rank_to_files[0] == ["0_0.distcp"]) @@ -133,7 +142,7 @@ def test_get_rank_to_files(self): rank_to_files, missing_keys, ) = dist.checkpoint.load_state_dict.get_rank_to_files( - ckpt_dir, new_state_dict, process_group, use_dist + metadata_list,local_load_files, new_state_dict, process_group, use_dist ) self.assertTrue(len(rank_to_files) == 1 and 0 in rank_to_files) self.assertTrue(rank_to_files[0] == ["0_0.distcp"]) @@ -148,7 +157,7 @@ def test_get_rank_to_files(self): rank_to_files, missing_keys, ) = dist.checkpoint.load_state_dict.get_rank_to_files( - ckpt_dir, new_state_dict, process_group, use_dist + metadata_list,local_load_files, new_state_dict, process_group, use_dist ) self.assertTrue(len(rank_to_files) == 0) self.assertTrue(len(missing_keys) == 2) From 1e4b443bfbebde3118510da2f61b09d08cfc3336 Mon Sep 17 00:00:00 2001 From: xingmingyyj Date: Sat, 3 Aug 2024 22:48:26 +0800 Subject: [PATCH 6/7] fix test --- .../test_dist_checkpoint_utils.py | 29 ++++++++++++++----- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/test/auto_parallel/test_dist_checkpoint_utils.py b/test/auto_parallel/test_dist_checkpoint_utils.py index bebe5802231c58..557b3141a1bdb8 100644 --- a/test/auto_parallel/test_dist_checkpoint_utils.py +++ b/test/auto_parallel/test_dist_checkpoint_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import os +import os import tempfile import unittest @@ -113,12 +113,14 @@ def test_get_rank_to_files(self): "w2": paddle.to_tensor([3, 4]), } dist.save_state_dict(state_dict, ckpt_dir) - - metadata_files,local_load_files= get_checkpoint_files(ckpt_dir) + + metadata_files, local_load_files = get_checkpoint_files(ckpt_dir) metadata_list = [] for metadata_file in metadata_files: - metadata_list.append(paddle.load(os.path.join(ckpt_dir,metadata_file))) + metadata_list.append( + paddle.load(os.path.join(ckpt_dir, metadata_file)) + ) new_state_dict = { "w1": paddle.to_tensor([1, 2]), @@ -128,7 +130,11 @@ def test_get_rank_to_files(self): rank_to_files, missing_keys, ) = dist.checkpoint.load_state_dict.get_rank_to_files( - metadata_list,local_load_files, new_state_dict, process_group, use_dist + metadata_list, + local_load_files, + new_state_dict, + process_group, + use_dist, ) self.assertTrue(len(rank_to_files) == 1 and 0 in rank_to_files) self.assertTrue(rank_to_files[0] == ["0_0.distcp"]) @@ -142,13 +148,18 @@ def test_get_rank_to_files(self): rank_to_files, missing_keys, ) = dist.checkpoint.load_state_dict.get_rank_to_files( - metadata_list,local_load_files, new_state_dict, process_group, use_dist + metadata_list, + local_load_files, + new_state_dict, + process_group, + use_dist, ) self.assertTrue(len(rank_to_files) == 1 and 0 in rank_to_files) self.assertTrue(rank_to_files[0] == ["0_0.distcp"]) self.assertTrue(len(missing_keys) == 1) self.assertTrue("w3" in missing_keys) + new_state_dict = { "w3": paddle.to_tensor([3, 4]), "w4": paddle.to_tensor([5, 6]), @@ -157,7 +168,11 @@ def test_get_rank_to_files(self): rank_to_files, missing_keys, ) = dist.checkpoint.load_state_dict.get_rank_to_files( - metadata_list,local_load_files, new_state_dict, process_group, use_dist + metadata_list, + local_load_files, + new_state_dict, + process_group, + use_dist, ) self.assertTrue(len(rank_to_files) == 0) self.assertTrue(len(missing_keys) == 2) From 2f7eb434c658cc3e43c941dd0f4aa42de88e992e Mon Sep 17 00:00:00 2001 From: xingmingyyj Date: Sat, 3 Aug 2024 22:49:03 +0800 Subject: [PATCH 7/7] fix test --- test/auto_parallel/test_dist_checkpoint_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/auto_parallel/test_dist_checkpoint_utils.py b/test/auto_parallel/test_dist_checkpoint_utils.py index 557b3141a1bdb8..4685451daf99eb 100644 --- a/test/auto_parallel/test_dist_checkpoint_utils.py +++ b/test/auto_parallel/test_dist_checkpoint_utils.py @@ -159,7 +159,6 @@ def test_get_rank_to_files(self): self.assertTrue(len(missing_keys) == 1) self.assertTrue("w3" in missing_keys) - new_state_dict = { "w3": paddle.to_tensor([3, 4]), "w4": paddle.to_tensor([5, 6]),
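With this series applied, metadata and data-file handling are decoupled from _load_state_dict, so the pieces can be composed by hand the way the updated test does. Below is a minimal single-process sketch; the checkpoint directory and tensor values are hypothetical, and process_group is left as None as in the test.

import os

import paddle
import paddle.distributed as dist
from paddle.distributed.checkpoint.load_state_dict import (
    _load_state_dict,
    get_checkpoint_files,
    get_local_load_files,
    get_rank_to_files,
)

ckpt_dir = "./tmp_ckpt"  # hypothetical checkpoint directory
dist.save_state_dict(
    {"w1": paddle.to_tensor([1, 2]), "w2": paddle.to_tensor([3, 4])}, ckpt_dir
)

metadata_files, local_data_files = get_checkpoint_files(ckpt_dir)
metadata_list = [paddle.load(os.path.join(ckpt_dir, f)) for f in metadata_files]

target = {"w1": paddle.to_tensor([0, 0]), "w2": paddle.to_tensor([0, 0])}
use_dist = paddle.distributed.get_world_size() > 1
rank_to_files, missing_keys = get_rank_to_files(
    metadata_list, local_data_files, target, None, use_dist
)
local_files = get_local_load_files(rank_to_files)
source_state_dict = {
    f: paddle.load(os.path.join(ckpt_dir, f)) for f in local_files
}
_load_state_dict(target, source_state_dict, metadata_list)
# target["w1"] / target["w2"] now hold the saved values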