@@ -475,6 +475,7 @@ def get_read_items(metadata_list, state_dict, process_group, use_dist):
             storage_state_dict_metadata[tensor_key] += local_tensor_metadata
 
     read_items = []
+    global_shape = None
    logger.debug(f"storage_state_dict_metadata:{storage_state_dict_metadata}")
    for tensor_key, val in state_dict.items():
        tensor_name = None
@@ -493,13 +494,15 @@ def get_read_items(metadata_list, state_dict, process_group, use_dist):
                    if len(val.shape) > 0
                    else ((), ())
                )
+                global_shape = val.shape
                if local_shape is None or global_offset is None:
                    continue
            else:
                local_shape = tuple(val.shape)
                global_offset = (
                    tuple([0] * len(val.shape)) if len(val.shape) > 0 else ()
                )
+                global_shape = local_shape
            dtype = str(val.dtype).split(".")[1]
            tensor_name = tensor_key
        elif isinstance(val, ShardedWeight):
@@ -512,13 +515,14 @@ def get_read_items(metadata_list, state_dict, process_group, use_dist):
            tensor_name = (
                tensor_key[0] if isinstance(tensor_key, tuple) else tensor_key
            )
+            global_shape = val.global_shape
        else:
            raise ValueError(
                f"Only support paddle.Tensor., val type:{type(val)}"
            )
 
        cur_chunk_metadata = LocalTensorMetadata(
-            global_offset, local_shape, dtype
+            global_offset, local_shape, dtype, global_shape
        )
        assert tensor_name in storage_state_dict_metadata, (
            f"tensor_key:{tensor_key} not found in storage_state_dict_metadata:{storage_state_dict_metadata}."
@@ -636,7 +640,6 @@ def _handle_aoa(
    assert len(metadata_files) == 1, "Only support one metadata file now."
    metadata = paddle.load(os.path.join(path, metadata_files[0]))
    state_dict_metadata = metadata.state_dict_metadata
-
    source_state_shard_info = {
        param_name: [
            ShardedWeightDesc(
@@ -794,9 +797,9 @@ def load_state_dict(
    if not use_dist:
        load_dict = {}
        for key, val in state_dict.items():
-            assert (
-                val.local_shape == val.global_shape
-            ), f"{key} is not replicated!"
+            assert val.local_shape == val.global_shape, (
+                f"{key} is not replicated!"
+            )
            load_dict[key] = val
        load_state_dict_impl(
            load_dict,
@@ -850,16 +853,16 @@ def load_state_dict_impl(
850853 mw_name_compatibility : bool = True ,
851854) -> None :
852855 with paddle .base .dygraph .guard ():
853- assert isinstance (
854- state_dict , dict
855- ), "The state_dict should be a dictionary."
856+ assert isinstance (state_dict , dict ), (
857+ "The state_dict should be a dictionary."
858+ )
856859 first_key = next (iter (state_dict ), None )
857860 if isinstance (first_key , tuple ):
858861 flat_state_dict = state_dict
859862 mapping = {}
860863 else :
861864 flat_state_dict , mapping = flatten_state_dict (state_dict )
862-
865+
863866 if len (flat_state_dict ) > 0 :
864867 for val in flat_state_dict .values ():
865868 assert isinstance (val , (paddle .Tensor , ShardedWeight )), (
@@ -998,9 +1001,9 @@ def _load_state_dict(
    idx = 0
    assert not any(
        isinstance(k, tuple) for k in copied_target_state_dict
-    ) or all(
-        isinstance(k, tuple) for k in copied_target_state_dict
-    ), "target_state_dict contains a mix of tuple and non-tuple keys. Please ensure key types are consistent."
+    ) or all(isinstance(k, tuple) for k in copied_target_state_dict), (
+        "target_state_dict contains a mix of tuple and non-tuple keys. Please ensure key types are consistent."
+    )
 
    for item in read_items:
        if any(isinstance(k, tuple) for k in copied_target_state_dict):
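
The reworked assertion above enforces an all-or-nothing rule: `copied_target_state_dict` may use either all tuple keys (flattened structured keys) or all plain keys, never a mix. A standalone sketch of that condition, with a hypothetical helper name and example keys chosen only for illustration:

```python
def _keys_are_consistent(state_dict: dict) -> bool:
    # Mirrors the `not any(...) or all(...)` check in the diff above:
    # True when keys are all tuples or all non-tuples.
    has_tuple_key = any(isinstance(k, tuple) for k in state_dict)
    all_tuple_keys = all(isinstance(k, tuple) for k in state_dict)
    return (not has_tuple_key) or all_tuple_keys


# Hypothetical keys, for illustration only.
assert _keys_are_consistent({"linear.weight": 0, "linear.bias": 1})
assert _keys_are_consistent({("opt", "linear.weight"): 0})
assert not _keys_are_consistent({"linear.weight": 0, ("opt", "linear.bias"): 1})
```
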
@@ -1055,9 +1058,9 @@ def _load_state_dict(
                storage_chunk_tensor = storage_local_tensor
        # The read item rank need to be assigned
        if item.rank == paddle.distributed.get_rank():
-            assert (
-                key in copied_target_state_dict
-            ), f"item:{item}, state_dict:{copied_target_state_dict}"
+            assert key in copied_target_state_dict, (
+                f"item:{item}, state_dict:{copied_target_state_dict}"
+            )
 
            cur_local_tensor = (
                copied_target_state_dict[key]._local_value()