 from torch import distributed as dist
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.tensor import DTensor
-from torch.nn.attention import SDPBackend

 from torchtitan.config import Comm as CommConfig, TORCH_DTYPE_MAP
 from torchtitan.distributed.parallel_dims import ParallelDims
-from torchtitan.models.attention import ScaledDotProductAttention
 from torchtitan.tools.logging import logger
 from torchtitan.tools.utils import device_module, device_type

@@ -202,6 +200,10 @@ def context(cp_context: Generator[None, None, None] | None = None):
                 )

             if cp_context is not None:
+                from torch.nn.attention import SDPBackend
+
+                from torchtitan.models.attention import ScaledDotProductAttention
+
                 if SDPBackend.MATH in ScaledDotProductAttention.backends:
                     ScaledDotProductAttention.backends.remove(SDPBackend.MATH)
                 assert (
@@ -319,7 +321,7 @@ def clip_grad_norm_(
     error_if_nonfinite: bool = False,
     foreach: bool | None = None,
     pp_mesh: DeviceMesh | None = None,
-    ep_dense_params_mesh_ndim: int | None = None,
+    ep_enabled: bool = False,
 ) -> torch.Tensor:
     """
     Clip the gradient norm of an iterable of parameters.
@@ -349,15 +351,14 @@ def clip_grad_norm_(
         Total norm of the parameter gradients (viewed as a single vector).

     """
-    if ep_dense_params_mesh_ndim is not None:
+    if ep_enabled:
         return _clip_grad_norm_with_ep(
             parameters,
             max_norm,
             norm_type,
             error_if_nonfinite,
             foreach,
             pp_mesh,
-            ep_dense_params_mesh_ndim,
         )

     if isinstance(parameters, torch.Tensor):
@@ -401,7 +402,6 @@ def _clip_grad_norm_with_ep(
     error_if_nonfinite: bool,
     foreach: bool | None,
     pp_mesh: DeviceMesh | None,
-    dense_params_mesh_ndim: int,
 ) -> torch.Tensor:
     ep_params = []
     non_ep_params = []
@@ -412,12 +412,12 @@ def _clip_grad_norm_with_ep(
         if p.grad is None:
             continue
         assert isinstance(p, DTensor) and isinstance(p.grad, DTensor)
-        if p.device_mesh.ndim == dense_params_mesh_ndim:
-            non_ep_params.append(p)
-            non_ep_grads.append(p.grad)
-        else:
+        if "ep" in p.device_mesh.mesh_dim_names:
             ep_params.append(p)
             ep_grads.append(p.grad)
+        else:
+            non_ep_params.append(p)
+            non_ep_grads.append(p.grad)
     ep_grads_total_norm = torch.nn.utils.get_total_norm(
         ep_grads, norm_type, error_if_nonfinite, foreach
     ).full_tensor()
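
For context, a minimal sketch of how a call site adapts to the revised signature. It assumes clip_grad_norm_ lives in torchtitan.distributed.utils and that ParallelDims exposes pp_enabled/ep_enabled flags; the clip_grads helper and the model/mesh names are illustrative only and not part of this diff:

    import torch
    from torch.distributed.device_mesh import DeviceMesh

    from torchtitan.distributed import utils as dist_utils  # assumed module path
    from torchtitan.distributed.parallel_dims import ParallelDims


    def clip_grads(
        model: torch.nn.Module,
        parallel_dims: ParallelDims,
        pp_mesh: DeviceMesh | None,
    ) -> torch.Tensor:
        # Callers no longer pass ep_dense_params_mesh_ndim; they only flag whether
        # expert parallelism is on. The EP / non-EP split now happens inside
        # _clip_grad_norm_with_ep via each DTensor's device_mesh.mesh_dim_names.
        return dist_utils.clip_grad_norm_(
            model.parameters(),
            max_norm=1.0,
            norm_type=2.0,
            pp_mesh=pp_mesh if parallel_dims.pp_enabled else None,  # assumed flag
            ep_enabled=parallel_dims.ep_enabled,  # assumed ParallelDims attribute
        )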