2121import paddle
2222from paddle import pir
2323from paddle .base .libpaddle import DataType
24- from paddle .distributed .flex_checkpoint .dcp .sharded_tensor import (
24+ from paddle .distributed .flex_checkpoint .dcp .sharded_weight import (
2525 ShardedStateDict ,
26- ShardedTensor ,
27- create_sharded_tensor_with_new_local ,
26+ ShardedWeight ,
27+ create_sharded_weight_with_new_local ,
2828)
2929from paddle .pir import Value
3030
@@ -749,7 +749,7 @@ def sharded_state_dict(
749749 model_sharded_state_dict (dict): Sharded state dict of the model, containing tensor metadata.
750750
751751 Returns:
752- dict: A new optimizer state dict where tensors are wrapped as ShardedTensor .
752+ dict: A new optimizer state dict where weights are wrapped as ShardedWeight .
753753 """
754754
755755 _FP32_MASTER = "fp32_master_0"
@@ -785,19 +785,19 @@ def _generate_base_static_name(vname):
785785 for key , tensor in optimizer_state_dict .items ():
786786 static_name , optim_state_type = _generate_base_static_name (key )
787787 struct_name = static_to_struct_mapping [static_name ]
788- sharded_tensor = model_sharded_state_dict [struct_name ]
788+ sharded_weight = model_sharded_state_dict [struct_name ]
789789
790790 unified_name = f"{ struct_name } .{ optim_state_type } "
791791
792792 # Determine tensor partitioning scheme
793793 if _MOMENT_NAME in optim_state_type :
794794 optimizer_sharded_state_dict [unified_name ] = (
795- create_sharded_tensor_with_new_local (
796- unified_name , tensor , sharded_tensor
795+ create_sharded_weight_with_new_local (
796+ unified_name , tensor , sharded_weight
797797 )
798798 )
799799 else : # Non-momentum parameters
800- optimizer_sharded_state_dict [unified_name ] = ShardedTensor (
800+ optimizer_sharded_state_dict [unified_name ] = ShardedWeight (
801801 key = unified_name ,
802802 local_tensor = tensor ,
803803 local_shape = (1 ,),
@@ -809,11 +809,11 @@ def _generate_base_static_name(vname):
809809 if master_weights is not None :
810810 for key , tensor in master_weights .items ():
811811 struct_name = static_to_struct_mapping [key ]
812- sharded_tensor = model_sharded_state_dict [struct_name ]
812+ sharded_weight = model_sharded_state_dict [struct_name ]
813813 unified_name = f"{ struct_name } .w_0"
814814 optimizer_sharded_state_dict [unified_name ] = (
815- create_sharded_tensor_with_new_local (
816- unified_name , tensor , sharded_tensor
815+ create_sharded_weight_with_new_local (
816+ unified_name , tensor , sharded_weight
817817 )
818818 )
819819
0 commit comments