@@ -136,7 +136,7 @@ def parallelize_llama(
             reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward,
             dp_mod_ep_mesh=(
                 world_mesh[tuple(dp_mod_ep_mesh_dim_names)]
-                if dp_mod_ep_mesh_dim_names
+                if parallel_dims.ep_enabled
                 else None
             ),
             gradient_divide_factor=parallel_dims.fsdp_gradient_divide_factor,
@@ -295,34 +295,43 @@ def apply_fsdp(
     if cpu_offload:
         fsdp_config["offload_policy"] = CPUOffloadPolicy()

-    for layer_id, transformer_block in model.layers.items():
-        if reshard_after_forward_policy == "always":
+    match reshard_after_forward_policy:
+        case "always":
             reshard_after_forward = True
-        elif reshard_after_forward_policy == "never":
+        case "never":
             reshard_after_forward = False
-        elif reshard_after_forward_policy == "default":
-            if pp_enabled:
-                # For PP, do not reshard after forward to avoid per-microbatch
-                # all-gathers, which can be expensive and non-overlapped
-                reshard_after_forward = False
-            else:
-                # As an optimization, do not reshard after forward for the last
-                # transformer block since FSDP would prefetch it immediately
-                reshard_after_forward = int(layer_id) < len(model.layers) - 1
-        else:
+        case "default":
+            # For PP, by default do not reshard after forward to avoid per-microbatch
+            # all-gathers, which can be expensive and non-overlapped
+            reshard_after_forward = not pp_enabled
+        case _:
             raise ValueError(
                 f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}."
             )

-        # NOTE: in an MoE layer, the router and the shared experts
-        # are sharded together with the TransformerBlock
+    if model.tok_embeddings is not None:
+        fully_shard(
+            model.tok_embeddings,
+            **fsdp_config,
+            reshard_after_forward=reshard_after_forward,
+        )
+
+    for layer_id, transformer_block in model.layers.items():
+        # NOTE: When EP is enabled, in an MoE layer we use the following FSDP wrapping:
+        # - the router and the shared experts are sharded together with the TransformerBlock
+        # - the routed experts are sharded with the remaining dp_mod_ep_mesh
         if transformer_block.moe_enabled and dp_mod_ep_mesh:
             fsdp_mod_ep_config = fsdp_config.copy()
             fsdp_mod_ep_config["mesh"] = dp_mod_ep_mesh
             fully_shard(
                 transformer_block.moe.experts,
                 **fsdp_mod_ep_config,
                 reshard_after_forward=reshard_after_forward,
+                # NOTE: When dp_mod_ep * ep > num_experts, FSDP's default dim-0 sharding
+                # causes inefficiency, so we choose to do FSDP sharding on dim-1.
+                # TODO: Even when EP is not used, we may still want to
+                # shard the experts on a non-0 dim.
+                shard_placement_fn=lambda param: Shard(1),
             )
         # NOTE: Although the FSDP sharding of experts is done on a mesh of
         # a different size than other parameters, the gradient division
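
The dim-1 sharding note above is easiest to see with concrete numbers. Below is a minimal sketch; the expert count and weight shapes are illustrative assumptions, not values from the repository. Once `dp_mod_ep * ep > num_experts`, each EP rank's grouped expert weight has fewer dim-0 rows than there are FSDP ranks, so the default `Shard(0)` placement leaves some ranks with empty shards, while `Shard(1)` splits the weight evenly.

```python
# Minimal sketch with illustrative numbers (not a real torchtitan config) of
# why dim-0 sharding becomes wasteful once dp_mod_ep * ep > num_experts.
import torch

num_experts, ep, dp_mod_ep = 8, 8, 4
local_experts = num_experts // ep            # 1 expert left on each EP rank
w = torch.empty(local_experts, 4096, 11008)  # grouped expert weight on one EP rank

# FSDP's default Shard(0): only 1 of the 4 dp_mod_ep ranks gets a non-empty shard.
dim0_shards = torch.tensor_split(w, dp_mod_ep, dim=0)
print([s.shape[0] for s in dim0_shards])     # [1, 0, 0, 0]

# Shard(1), which the diff selects via `shard_placement_fn=lambda param: Shard(1)`:
# every rank receives an equal 1024-row slice of the 4096-sized dim.
dim1_shards = torch.tensor_split(w, dp_mod_ep, dim=1)
print([s.shape[1] for s in dim1_shards])     # [1024, 1024, 1024, 1024]
```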
@@ -336,7 +345,17 @@ def apply_fsdp(
             **fsdp_config,
             reshard_after_forward=reshard_after_forward,
         )
-    fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)
+
+    # As an optimization, do not reshard the last layers after forward by default,
+    # since FSDP would prefetch them immediately after the forward pass
+    if model.norm is not None and model.output is not None:
+        fully_shard(
+            [model.norm, model.output],
+            **fsdp_config,
+            reshard_after_forward=reshard_after_forward_policy == "always",
+        )
+
+    fully_shard(model, **fsdp_config)


 def apply_moe_ep_tp(
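
For readers less familiar with FSDP2's composable `fully_shard` API, here is a minimal sketch of the wrapping order the rewritten `apply_fsdp` sets up. The toy model, mesh size, and dimensions are hypothetical, and a distributed process group with 8 GPUs is assumed to be initialized; the point is that passing a list groups `norm` and `output` into a single parameter group, which is left unsharded after forward (unless the policy is `"always"`) because backward starts with exactly those parameters.

```python
# Minimal sketch (hypothetical toy model; assumes torch.distributed is already
# initialized with 8 GPUs) of the wrapping order this apply_fsdp establishes.
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard  # FSDP2 API in recent PyTorch

class ToyModel(nn.Module):
    def __init__(self, dim: int = 256, n_layers: int = 4, vocab: int = 1024):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab, dim)
        # stand-in for the transformer blocks keyed by layer id
        self.layers = nn.ModuleDict({str(i): nn.Linear(dim, dim) for i in range(n_layers)})
        self.norm = nn.LayerNorm(dim)
        self.output = nn.Linear(dim, vocab, bias=False)

mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("dp_shard",))
model = ToyModel()

# Embeddings and each block become their own FSDP parameter group.
fully_shard(model.tok_embeddings, mesh=mesh, reshard_after_forward=True)
for block in model.layers.values():
    fully_shard(block, mesh=mesh, reshard_after_forward=True)

# Passing a list groups norm and output into one parameter group. Keeping it
# unsharded after forward skips a redundant all-gather, since backward starts
# with exactly these parameters.
fully_shard([model.norm, model.output], mesh=mesh, reshard_after_forward=False)

# Root group for any parameters not claimed by a child group.
fully_shard(model, mesh=mesh)
```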
@@ -362,9 +381,18 @@ def apply_moe_ep_tp(
                 ),
                 # replicate computation for the router
                 "moe.router.gate": NoParallel(),
-                # input Replicate, output Partial
-                "moe.shared_expert": TensorParallel(),
             }
+            if transformer_block.moe.shared_experts is not None:
+                # input Replicate, output Partial
+                moe_layer_plan.update(
+                    {
+                        "moe.shared_experts.w1": ColwiseParallel(),
+                        "moe.shared_experts.w2": RowwiseParallel(
+                            output_layouts=Partial()
+                        ),
+                        "moe.shared_experts.w3": ColwiseParallel(),
+                    }
+                )
             parallelize_module(
                 module=transformer_block,
                 device_mesh=tp_mesh,
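
To make the shared-experts plan concrete, here is a sketch of the same column/row split applied to a standalone SwiGLU-style FFN. The module, dimensions, and mesh below are assumptions for illustration (with an already-initialized distributed setup), not torchtitan code: `w1`/`w3` are column-parallel so their outputs are sharded on the hidden dim, and `w2` is row-parallel with `output_layouts=Partial()`, so each rank produces a partial sum of the final projection.

```python
# Minimal sketch, assuming a hypothetical SwiGLU-style FFN named SharedExperts
# and an already-initialized distributed setup with 8 GPUs.
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Partial  # placement type in recent PyTorch
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

class SharedExperts(nn.Module):
    def __init__(self, dim: int = 4096, hidden_dim: int = 11008):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # gate projection
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)  # down projection
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)  # up projection

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

tp_mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("tp",))
parallelize_module(
    module=SharedExperts(),
    device_mesh=tp_mesh,
    parallelize_plan={
        # weights sharded on the output (hidden) dim; inputs stay Replicate
        "w1": ColwiseParallel(),
        "w3": ColwiseParallel(),
        # weight sharded on the input dim; output left as a Partial sum,
        # matching the "input Replicate, output Partial" comment in the diff
        "w2": RowwiseParallel(output_layouts=Partial()),
    },
)
```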