Skip to content

Commit d073767

Browse files
committed
fix: import get_hybrid_communicate_group from paddle.distributed.fleet and rename ShardedTensor/create_sharded_tensor_with_new_local to ShardedWeight/create_sharded_weight_with_new_local in adamw.py
1 parent 3dabad7 commit d073767

File tree

2 files changed

+12
-12
lines changed

2 files changed

+12
-12
lines changed

python/paddle/distributed/flex_checkpoint/dcp/sharded_weight.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def make_tp_sharded_weight_for_checkpoint(
146146
Returns:
147147
A ShardedWeight configured for tensor parallel checkpointing.
148148
"""
149-
from ...fleet.fleet import get_hybrid_communicate_group
149+
from paddle.distributed.fleet import get_hybrid_communicate_group
150150

151151
hcg = get_hybrid_communicate_group()
152152
tensor_parallel_group = hcg.get_model_parallel_group()

python/paddle/optimizer/adamw.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,10 @@
2121
import paddle
2222
from paddle import pir
2323
from paddle.base.libpaddle import DataType
24-
from paddle.distributed.flex_checkpoint.dcp.sharded_tensor import (
24+
from paddle.distributed.flex_checkpoint.dcp.sharded_weight import (
2525
ShardedStateDict,
26-
ShardedTensor,
27-
create_sharded_tensor_with_new_local,
26+
ShardedWeight,
27+
create_sharded_weight_with_new_local,
2828
)
2929
from paddle.pir import Value
3030

@@ -749,7 +749,7 @@ def sharded_state_dict(
749749
model_sharded_state_dict (dict): Sharded state dict of the model, containing tensor metadata.
750750
751751
Returns:
752-
dict: A new optimizer state dict where tensors are wrapped as ShardedTensor.
752+
dict: A new optimizer state dict where weights are wrapped as ShardedWeight.
753753
"""
754754

755755
_FP32_MASTER = "fp32_master_0"
@@ -785,19 +785,19 @@ def _generate_base_static_name(vname):
785785
for key, tensor in optimizer_state_dict.items():
786786
static_name, optim_state_type = _generate_base_static_name(key)
787787
struct_name = static_to_struct_mapping[static_name]
788-
sharded_tensor = model_sharded_state_dict[struct_name]
788+
sharded_weight = model_sharded_state_dict[struct_name]
789789

790790
unified_name = f"{struct_name}.{optim_state_type}"
791791

792792
# Determine tensor partitioning scheme
793793
if _MOMENT_NAME in optim_state_type:
794794
optimizer_sharded_state_dict[unified_name] = (
795-
create_sharded_tensor_with_new_local(
796-
unified_name, tensor, sharded_tensor
795+
create_sharded_weight_with_new_local(
796+
unified_name, tensor, sharded_weight
797797
)
798798
)
799799
else: # Non-momentum parameters
800-
optimizer_sharded_state_dict[unified_name] = ShardedTensor(
800+
optimizer_sharded_state_dict[unified_name] = ShardedWeight(
801801
key=unified_name,
802802
local_tensor=tensor,
803803
local_shape=(1,),
@@ -809,11 +809,11 @@ def _generate_base_static_name(vname):
809809
if master_weights is not None:
810810
for key, tensor in master_weights.items():
811811
struct_name = static_to_struct_mapping[key]
812-
sharded_tensor = model_sharded_state_dict[struct_name]
812+
sharded_weight = model_sharded_state_dict[struct_name]
813813
unified_name = f"{struct_name}.w_0"
814814
optimizer_sharded_state_dict[unified_name] = (
815-
create_sharded_tensor_with_new_local(
816-
unified_name, tensor, sharded_tensor
815+
create_sharded_weight_with_new_local(
816+
unified_name, tensor, sharded_weight
817817
)
818818
)
819819

0 commit comments

Comments (0)