Commit c97f72c

Better arg for dp sharding and set defaults in ParallelDims for unsupported parallel schemes.
Signed-off-by: Nathan Azrak <nathan.azrak@gmail.com>
1 parent 5831c0b

File tree

1 file changed: +4 -1

recipes/full_finetune_distributed.py

Lines changed: 4 additions & 1 deletion
@@ -152,14 +152,17 @@ def __init__(self, cfg: DictConfig) -> None:
                 f"world_size {self.world_size} must be divisible by tensor_parallel_dim {self.tensor_parallel_dim}"
             )
 
-        data_shard = cfg.get("dp", self.world_size // self.tensor_parallel_dim)
+        data_shard = cfg.get("dp_shard", self.world_size // self.tensor_parallel_dim)
         data_replicate = cfg.get("dp_replicate", 1)
 
         self.parallel_dims = training.ParallelDims(
             dp_replicate=data_replicate,
             dp_shard=data_shard,
+            cp=1,
             tp=self.tensor_parallel_dim,
+            pp=1,
             world_size=self.world_size,
+            enable_loss_parallel=False,
         )
         self.world_mesh = self.parallel_dims.build_mesh(device_type=device_type)
 
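For illustration, a minimal sketch of how the renamed dp_shard key and the pinned defaults resolve when the config omits the data-parallel settings. The concrete values and the plain dict standing in for the recipe's DictConfig are assumptions for the example, not part of the commit.

# Minimal sketch with assumed values (8 GPUs, tensor-parallel degree 2);
# a plain dict stands in for the recipe's DictConfig.
world_size = 8
tensor_parallel_dim = 2
cfg = {"dp_replicate": 1}  # "dp_shard" omitted, so the default applies

# Mirrors the changed line: dp_shard defaults to world_size // tp.
data_shard = cfg.get("dp_shard", world_size // tensor_parallel_dim)  # -> 4
data_replicate = cfg.get("dp_replicate", 1)                          # -> 1

# The unsupported schemes keep their no-op defaults, as in the diff:
# cp=1 (no context parallel), pp=1 (no pipeline parallel),
# enable_loss_parallel=False.
assert data_replicate * data_shard * tensor_parallel_dim == world_size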