Commit 04eb395

fix
1 parent b9d117e · commit 04eb395

File tree: 2 files changed (+20, -15 lines)

python/paddle/distributed/flex_checkpoint/dcp/load_state_dict.py

Lines changed: 18 additions & 15 deletions
@@ -475,6 +475,7 @@ def get_read_items(metadata_list, state_dict, process_group, use_dist):
             storage_state_dict_metadata[tensor_key] += local_tensor_metadata
 
     read_items = []
+    global_shape = None
     logger.debug(f"storage_state_dict_metadata:{storage_state_dict_metadata}")
     for tensor_key, val in state_dict.items():
         tensor_name = None
@@ -493,13 +494,15 @@ def get_read_items(metadata_list, state_dict, process_group, use_dist):
                     if len(val.shape) > 0
                     else ((), ())
                 )
+                global_shape = val.shape
                 if local_shape is None or global_offset is None:
                     continue
             else:
                 local_shape = tuple(val.shape)
                 global_offset = (
                     tuple([0] * len(val.shape)) if len(val.shape) > 0 else ()
                 )
+                global_shape = local_shape
             dtype = str(val.dtype).split(".")[1]
             tensor_name = tensor_key
         elif isinstance(val, ShardedWeight):
@@ -512,13 +515,14 @@ def get_read_items(metadata_list, state_dict, process_group, use_dist):
             tensor_name = (
                 tensor_key[0] if isinstance(tensor_key, tuple) else tensor_key
             )
+            global_shape = val.global_shape
         else:
             raise ValueError(
                 f"Only support paddle.Tensor., val type:{type(val)}"
             )
 
         cur_chunk_metadata = LocalTensorMetadata(
-            global_offset, local_shape, dtype
+            global_offset, local_shape, dtype, global_shape
         )
         assert tensor_name in storage_state_dict_metadata, (
             f"tensor_key:{tensor_key} not found in storage_state_dict_metadata:{storage_state_dict_metadata}."
@@ -636,7 +640,6 @@ def _handle_aoa(
     assert len(metadata_files) == 1, "Only support one metadata file now."
     metadata = paddle.load(os.path.join(path, metadata_files[0]))
     state_dict_metadata = metadata.state_dict_metadata
-
     source_state_shard_info = {
         param_name: [
             ShardedWeightDesc(
@@ -794,9 +797,9 @@ def load_state_dict(
     if not use_dist:
         load_dict = {}
         for key, val in state_dict.items():
-            assert (
-                val.local_shape == val.global_shape
-            ), f"{key} is not replicated!"
+            assert val.local_shape == val.global_shape, (
+                f"{key} is not replicated!"
+            )
             load_dict[key] = val
         load_state_dict_impl(
             load_dict,
@@ -850,16 +853,16 @@ def load_state_dict_impl(
     mw_name_compatibility: bool = True,
 ) -> None:
     with paddle.base.dygraph.guard():
-        assert isinstance(
-            state_dict, dict
-        ), "The state_dict should be a dictionary."
+        assert isinstance(state_dict, dict), (
+            "The state_dict should be a dictionary."
+        )
         first_key = next(iter(state_dict), None)
         if isinstance(first_key, tuple):
             flat_state_dict = state_dict
             mapping = {}
         else:
             flat_state_dict, mapping = flatten_state_dict(state_dict)
-
+
         if len(flat_state_dict) > 0:
             for val in flat_state_dict.values():
                 assert isinstance(val, (paddle.Tensor, ShardedWeight)), (
@@ -998,9 +1001,9 @@ def _load_state_dict(
     idx = 0
     assert not any(
        isinstance(k, tuple) for k in copied_target_state_dict
-    ) or all(
-        isinstance(k, tuple) for k in copied_target_state_dict
-    ), "target_state_dict contains a mix of tuple and non-tuple keys. Please ensure key types are consistent."
+    ) or all(isinstance(k, tuple) for k in copied_target_state_dict), (
+        "target_state_dict contains a mix of tuple and non-tuple keys. Please ensure key types are consistent."
+    )
 
     for item in read_items:
         if any(isinstance(k, tuple) for k in copied_target_state_dict):
@@ -1055,9 +1058,9 @@ def _load_state_dict(
                 storage_chunk_tensor = storage_local_tensor
             # The read item rank need to be assigned
            if item.rank == paddle.distributed.get_rank():
-                assert (
-                    key in copied_target_state_dict
-                ), f"item:{item}, state_dict:{copied_target_state_dict}"
+                assert key in copied_target_state_dict, (
+                    f"item:{item}, state_dict:{copied_target_state_dict}"
+                )
 
                 cur_local_tensor = (
                     copied_target_state_dict[key]._local_value()
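
In the load path, the hunks above thread a tensor's global shape through get_read_items so that each LocalTensorMetadata now records the full tensor shape next to the shard's local shape and offset; the remaining hunks are assertion-formatting cleanups. A minimal sketch of that pattern follows, assuming a dataclass-style metadata record; the field names are illustrative only, not the actual LocalTensorMetadata definition, which lives elsewhere in flex_checkpoint.

# Note: illustrative sketch, not Paddle's LocalTensorMetadata.
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class LocalTensorMetadataSketch:
    global_offset: tuple[int, ...]   # where this shard starts in the full tensor
    local_shape: tuple[int, ...]     # shape of the locally held shard
    dtype: str                       # e.g. "float32"
    global_shape: tuple[int, ...] | None = None  # new field: full tensor shape


def replicated_metadata(shape: tuple[int, ...], dtype: str) -> LocalTensorMetadataSketch:
    # Mirrors the non-distributed branch above: the shard is the whole tensor,
    # so global_shape is simply local_shape and the offset is all zeros.
    return LocalTensorMetadataSketch(
        global_offset=(0,) * len(shape),
        local_shape=shape,
        dtype=dtype,
        global_shape=shape,
    )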

python/paddle/distributed/flex_checkpoint/dcp/save_state_dict.py

Lines changed: 2 additions & 0 deletions
@@ -366,6 +366,7 @@ def save_state_dict_impl(
                     if len(val.shape) > 0
                     else ((), ())
                 )
+                global_shape = val.shape
                 if local_shape is None or global_offset is None:
                     continue
             else:
@@ -376,6 +377,7 @@ def save_state_dict_impl(
                     else ()
                 )
                 local_tensor = val
+                global_shape = local_shape
             elif isinstance(val, ShardedWeight):
                 local_tensor = val.local_tensor
                 local_shape = val.local_shape
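
The save path mirrors the load path: for a plain paddle.Tensor, global_shape is taken from val.shape in the distributed branch and set equal to local_shape in the replicated branch. A hedged sketch of that branch logic, with the sharding helper passed in as a stand-in rather than Paddle's real API:

# Condensed sketch of the shape bookkeeping both files share after this
# commit; `is_dist` and `shard_fn` are illustrative stand-ins, not Paddle APIs.
def resolve_shapes(shape, is_dist, shard_fn):
    if is_dist:
        # Distributed tensor: the shard's shape/offset come from the sharding
        # helper, while the global shape is the tensor's full shape.
        local_shape, global_offset = shard_fn(shape) if len(shape) > 0 else ((), ())
        global_shape = tuple(shape)
    else:
        # Replicated tensor: the local shard is the whole tensor.
        local_shape = tuple(shape)
        global_offset = (0,) * len(shape)
        global_shape = local_shape
    return local_shape, global_offset, global_shape


# Example: a replicated [4, 8] tensor yields identical local and global shapes.
assert resolve_shapes((4, 8), False, None) == ((4, 8), (0, 0), (4, 8))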
