@@ -1538,18 +1538,48 @@ def _native_attention(
 ) -> torch.Tensor:
     if return_lse:
         raise ValueError("Native attention backend does not support setting `return_lse=True`.")
-    query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
-    out = torch.nn.functional.scaled_dot_product_attention(
-        query=query,
-        key=key,
-        value=value,
-        attn_mask=attn_mask,
-        dropout_p=dropout_p,
-        is_causal=is_causal,
-        scale=scale,
-        enable_gqa=enable_gqa,
-    )
-    out = out.permute(0, 2, 1, 3)
+    if _parallel_config is None:
+        query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+        out = torch.nn.functional.scaled_dot_product_attention(
+            query=query,
+            key=key,
+            value=value,
+            attn_mask=attn_mask,
+            dropout_p=dropout_p,
+            is_causal=is_causal,
+            scale=scale,
+            enable_gqa=enable_gqa,
+        )
+        out = out.permute(0, 2, 1, 3)
+    elif _parallel_config.context_parallel_config.ring_degree == 1:
+        ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh
+        world_size = _parallel_config.context_parallel_config.ulysses_degree
+        group = ulysses_mesh.get_group()
+
+        B, S_Q_LOCAL, H, D = query.shape
+        _, S_KV_LOCAL, _, _ = key.shape
+        H_LOCAL = H // world_size
+        query = query.reshape(B, S_Q_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
+        key = key.reshape(B, S_KV_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
+        value = value.reshape(B, S_KV_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
+        query, key, value = (_all_to_all_single(x, group) for x in (query, key, value))
+        query, key, value = (x.flatten(0, 1).permute(1, 2, 0, 3).contiguous() for x in (query, key, value))
+        out = torch.nn.functional.scaled_dot_product_attention(
+            query=query,
+            key=key,
+            value=value,
+            attn_mask=attn_mask,
+            dropout_p=dropout_p,
+            is_causal=is_causal,
+            scale=scale,
+            enable_gqa=enable_gqa,
+        )
+        out = out.reshape(B, H_LOCAL, world_size, S_Q_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
+        out = _all_to_all_single(out, group)
+        out = out.flatten(0, 1).permute(1, 2, 0, 3).contiguous()
+        return out
+    else:
+        raise ValueError("Native attention backend does not support context parallelism with ring_degree > 1; consider using the Ulysses attention backend instead.")
     return out
15541584
15551585
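Note: `_all_to_all_single`, called twice in the new branch, is not shown in this hunk. A minimal sketch of what it is assumed to do, namely scatter dim 0 of the input across the ranks of `group` and gather the received shards back along dim 0 via `torch.distributed.all_to_all_single`. The body below is an illustrative assumption, not the repository's implementation:

import torch
import torch.distributed as dist

def _all_to_all_single(x: torch.Tensor, group: dist.ProcessGroup) -> torch.Tensor:
    # Assumed helper: dim 0 of `x` is the rank axis (size == world_size).
    # Each rank sends slice i of dim 0 to rank i and receives one slice
    # from every rank, so the output has the same shape as the input.
    x = x.contiguous()
    out = torch.empty_like(x)
    dist.all_to_all_single(out, x, group=group)
    return out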
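For readers tracing the Ulysses exchange: before attention, each rank holds all H heads for a local sequence shard; the reshape/permute plus all-to-all swaps this for the full sequence over H_LOCAL heads, and the steps after SDPA undo the swap. A single-process trace of just the shape bookkeeping (the all-to-all preserves shapes, so it is omitted; names mirror the diff, sizes are made up):

import torch

B, S_Q_LOCAL, H, D, world_size = 2, 16, 8, 64, 4
H_LOCAL = H // world_size

q = torch.randn(B, S_Q_LOCAL, H, D)
# (B, S_local, H, D) -> (world_size, S_local, B, H_local, D); dim 0 is the
# rank axis the all-to-all scatters, trading sequence shards for head shards.
q = q.reshape(B, S_Q_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
# After the (shape-preserving) all-to-all, fold the rank axis into the
# sequence and move to SDPA's expected (B, H, S, D) layout.
q = q.flatten(0, 1).permute(1, 2, 0, 3).contiguous()
assert q.shape == (B, H_LOCAL, world_size * S_Q_LOCAL, D)

# SDPA returns (B, H_local, world_size * S_local, D); reversing the exchange
# recovers the caller's (B, S_local, H, D) layout.
out = torch.randn(B, H_LOCAL, world_size * S_Q_LOCAL, D)
out = out.reshape(B, H_LOCAL, world_size, S_Q_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
out = out.flatten(0, 1).permute(1, 2, 0, 3).contiguous()
assert out.shape == (B, S_Q_LOCAL, H, D)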