 from omegaconf import DictConfig, ListConfig

 from torch import nn
-from torch.distributed import (
-    destroy_process_group,
-    init_device_mesh,
-    init_process_group,
-)
+from torch.distributed import destroy_process_group, init_process_group
 from torch.distributed._tensor import DTensor
 from torch.distributed.tensor.parallel import parallelize_module
 from torch.optim import Optimizer
@@ -146,20 +142,31 @@ def __init__(self, cfg: DictConfig) -> None:
         # Initialize distributed variables
         self.world_size, self.rank = utils.get_world_size_and_rank()
         self._is_rank_zero = self.rank == 0
-        self.tensor_parallel_plan = config.instantiate(
-            cfg.get("tensor_parallel_plan", None)
-        )
-        self.tensor_parallel_dim = cfg.get("tensor_parallel_dim", 1)
-        if self.tensor_parallel_dim > 1 and self.tensor_parallel_plan is None:
+        self.tp_plan = config.instantiate(cfg.get("tensor_parallel_plan", None))
+        self.tp_degree = cfg.get("tensor_parallel_dim", 1)
+        if self.tp_degree > 1 and self.tp_plan is None:
             raise ValueError(
                 "Tensor Parallel plan needs to be provided when tensor parallel is enabled."
             )
-        if self.world_size % self.tensor_parallel_dim != 0:
-            raise ValueError(
-                f"world_size {self.world_size} must be divisible by tensor_parallel_dim {self.tensor_parallel_dim}"
+        data_shard = cfg.get("data_parallel_shard_dim", -1)  # -1 means to infer
+        data_replicate = cfg.get("data_parallel_replicate_dim", 1)
+
+        # Set up n-d device mesh
+        self.parallel_dims = training.ParallelDims(
+            dp_replicate=data_replicate,
+            dp_shard=data_shard,
+            tp=self.tp_degree,
+            world_size=self.world_size,
+        )
+        self.world_mesh = self.parallel_dims.build_mesh(device_type=device_type)
+        if self.parallel_dims.dp_enabled:
+            dp_mesh = self.world_mesh["dp"]
+            self.dp_degree, self.dp_rank = (
+                dp_mesh.size(),
+                dp_mesh.get_local_rank(),
             )
-
-        self.data_parallel_dim = self.world_size // self.tensor_parallel_dim
+        else:
+            self.dp_degree, self.dp_rank = 1, 0

         # Logging attributes
         self._output_dir = cfg.output_dir
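For readers unfamiliar with the `ParallelDims` abstraction used above: `data_parallel_shard_dim=-1` is inferred from whatever is left of the world size after the replicate and tensor-parallel dimensions are fixed. A minimal sketch of that inference, assuming the constraint `world_size == dp_replicate * dp_shard * tp` that `ParallelDims` validates (the helper name is hypothetical, not part of the recipe):

```python
# Hypothetical helper mirroring how a dp_shard of -1 can be inferred,
# assuming world_size == dp_replicate * dp_shard * tp.
def infer_dp_shard(world_size: int, dp_replicate: int, dp_shard: int, tp: int) -> int:
    if dp_shard == -1:
        dp_shard = world_size // (dp_replicate * tp)
    if dp_replicate * dp_shard * tp != world_size:
        raise ValueError(
            f"dp_replicate({dp_replicate}) * dp_shard({dp_shard}) * tp({tp}) "
            f"!= world_size({world_size})"
        )
    return dp_shard


# e.g. 8 GPUs, 2-way replication, 2-way tensor parallel -> 2-way sharding
assert infer_dp_shard(world_size=8, dp_replicate=2, dp_shard=-1, tp=2) == 2
```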
@@ -538,26 +545,18 @@ def _setup_model(
         if self._compile:
             training.compile_model(model, verbose=self._is_rank_zero)

-        device_mesh = init_device_mesh(
-            self._device.type,
-            mesh_shape=(self.data_parallel_dim, self.tensor_parallel_dim),
-            mesh_dim_names=("dp", "tp"),
-        )
-        self.dp_size = device_mesh["dp"].size()
-        self.dp_rank = device_mesh["dp"].get_local_rank()
-
         # Apply tensor parallelism to the model
-        if self.tensor_parallel_dim > 1:
-            if self.data_parallel_dim == 1 and self.fsdp_cpu_offload:
+        if self.parallel_dims.tp_enabled:
+            if not self.parallel_dims.dp_enabled and self.fsdp_cpu_offload:
                 raise ValueError(
                     "Tensor parallelism is not supported with FSDP CPU offloading when data parallelism is disabled."
                 )
             # Use the local number (num_heads, num_kv_heads, embed_dim) to account for tensor parallel
-            model = training.prepare_mha_for_tp(model, device_mesh["tp"])
+            model = training.prepare_mha_for_tp(model, self.world_mesh["tp"])
             parallelize_module(
                 model,
-                device_mesh["tp"],
-                parallelize_plan=self.tensor_parallel_plan,
+                self.world_mesh["tp"],
+                parallelize_plan=self.tp_plan,
             )

         # We currently have two versions of activation checkpointing in this recipe
@@ -580,19 +579,25 @@ def _setup_model(
             )

         # Apply Fully Sharded Data Parallelism to the model
-        if self.data_parallel_dim > 1:
+        if self.parallel_dims.dp_shard_enabled:
             fsdp_shard_conditions = [
                 partial(
                     training.get_shard_conditions,
                     names_to_match=custom_sharded_layers,
                 )
             ]
+
+            if self.parallel_dims.dp_replicate_enabled:
+                dp_mesh_dim_names = ("dp_replicate", "dp_shard")
+            else:
+                dp_mesh_dim_names = ("dp_shard",)
+
             training.shard_model(
                 model=model,
                 shard_conditions=fsdp_shard_conditions,
                 cpu_offload=fsdp_cpu_offload,
                 reshard_after_forward=reshard_after_forward,
-                dp_mesh=device_mesh["dp"],
+                dp_mesh=self.world_mesh[dp_mesh_dim_names],
             )

         with training.set_default_dtype(self._dtype), self._device:
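For context on the `parallelize_module` call above: a tensor-parallel plan maps module names to parallel styles. A standalone sketch of that API (run under `torchrun --nproc_per_node=2`; the toy model and layer names are illustrative, not the recipe's actual plan):

```python
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

# Toy MLP; "w1"/"w2" are made-up module names for illustration only.
model = nn.Sequential()
model.add_module("w1", nn.Linear(16, 32, bias=False))
model.add_module("w2", nn.Linear(32, 16, bias=False))
model = model.cuda()

# 1-D mesh over the two TP ranks launched by torchrun.
tp_mesh = init_device_mesh("cuda", (2,), mesh_dim_names=("tp",))

# Shard w1 column-wise and w2 row-wise, the usual pairing for an MLP block.
plan = {"w1": ColwiseParallel(), "w2": RowwiseParallel()}
parallelize_module(model, tp_mesh, parallelize_plan=plan)
```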
@@ -629,7 +634,7 @@ def _setup_model(
             training.log_memory_stats(memory_stats)

         # synchronize before training begins
-        torch.distributed.barrier()
+        torch.distributed.barrier(device_ids=[self._device.index])

         return model

@@ -716,7 +721,7 @@ def _setup_data(
         collate_fn = _get_component_from_path(collate_fn)

         sampler = StatefulDistributedSampler(
-            ds, num_replicas=self.dp_size, rank=self.dp_rank, shuffle=shuffle
+            ds, num_replicas=self.dp_degree, rank=self.dp_rank, shuffle=shuffle, seed=0
         )
         dataloader = StatefulDataLoader(
             dataset=ds,
@@ -727,7 +732,7 @@ def _setup_data(
                     collate_fn,
                     padding_idx=self._tokenizer.pad_id,
                     ignore_idx=self._loss_fn.ignore_index,
-                    pad_to_multiple_of=self.tensor_parallel_dim,
+                    pad_to_multiple_of=self.tp_degree,
                 )
                 if not packed
                 else padded_collate_packed
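On `pad_to_multiple_of=self.tp_degree` above: batches are padded so their sequence length divides evenly by the tensor-parallel degree, which is typically needed when activations are split along the sequence dimension across TP ranks. A sketch of just the length arithmetic (not torchtune's collate implementation):

```python
# Length arithmetic implied by pad_to_multiple_of (illustration only).
def padded_length(seq_len: int, multiple_of: int) -> int:
    remainder = seq_len % multiple_of
    return seq_len if remainder == 0 else seq_len + (multiple_of - remainder)


assert padded_length(1021, multiple_of=4) == 1024  # pad 3 tokens
assert padded_length(1024, multiple_of=4) == 1024  # already aligned
```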
@@ -811,22 +816,18 @@ def train(self) -> None:
                 if self._optimizer_in_bwd:
                     torch.distributed.all_reduce(num_tokens)
                     torch.distributed.all_reduce(running_loss)
-
-                    # We multiply by world_size to undo FSDP2 gradient normalization.
-                    current_loss = current_loss * (self.dp_size / num_tokens)
+                    current_loss = current_loss * (self.dp_degree / num_tokens)

                 current_loss.backward()
-
-                # Step with optimizer
+                # Optimizer step (if not fused in backward call)
                 if (idx + 1) % self._gradient_accumulation_steps == 0:
                     if not self._optimizer_in_bwd:
                         # Get total number of tokens across all ranks to normalize gradients
                         torch.distributed.all_reduce(num_tokens)
                         # This will ensure that the logged loss matches what we're optimizing
                         torch.distributed.all_reduce(running_loss)
                         # Manually scale the gradients from unnormalized loss by total # of tokens
-                        # We multiply by world_size to undo FSDP2 gradient normalization.
-                        training.scale_grads(self._model, self.dp_size / num_tokens)
+                        training.scale_grads(self._model, self.dp_degree / num_tokens)
                         if self._clip_grad_norm is not None:
                             grad_norm = torch.nn.utils.clip_grad_norm_(
                                 self._model.parameters(),
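The `self.dp_degree / num_tokens` factor above plays the role the removed comment described: it undoes FSDP2's implicit averaging of gradients over the data-parallel group and then normalizes by the globally all-reduced token count. A small arithmetic sketch of the net effect (illustrative numbers, not recipe code):

```python
# Check that scaling by dp_degree / num_tokens, combined with FSDP's
# 1/dp_degree gradient averaging, yields a 1/num_tokens normalization.
dp_degree = 4
per_rank_token_counts = [100, 120, 90, 110]
num_tokens = sum(per_rank_token_counts)   # what all_reduce(num_tokens) produces
scale = dp_degree / num_tokens            # factor applied to the loss / grads
effective = scale / dp_degree             # after FSDP's implicit averaging
assert abs(effective - 1 / num_tokens) < 1e-12
```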