@@ -69,13 +69,14 @@ def unshard_seq_load_balance(tensor, seq_dim):
 
 class RingCommunicator:
     def __init__(self, group, local_key, local_value):
-        self._k_buffer = [paddle.zeros_like(local_key) for _ in range(2)]
-        self._v_buffer = [paddle.zeros_like(local_value) for _ in range(2)]
-
-        local_key = local_key.contiguous()
-        local_value = local_value.contiguous()
-        self._k_buffer[0] = local_key.clone()
-        self._v_buffer[0] = local_value.clone()
+        self._k_buffer = [
+            local_key.clone().contiguous(),
+            local_key.clone().contiguous(),
+        ]
+        self._v_buffer = [
+            local_value.clone().contiguous(),
+            local_value.clone().contiguous(),
+        ]
 
         self._next_buffer_idx = 0
 
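Note: the constructor now builds both ring buffers as contiguous clones of the local shards up front, so later steps never have to re-establish contiguity. A minimal sketch of the double-buffer idea outside the class, with shapes and names invented purely for illustration:

import paddle

# Hypothetical local K/V shards: [batch, local_seq, num_heads, head_dim].
local_key = paddle.randn([2, 128, 8, 64], dtype="float32")
local_value = paddle.randn([2, 128, 8, 64], dtype="float32")

# Slot 0 starts as this rank's own block; slot 1 is the landing area for the
# block arriving from the neighbouring rank. Cloning up front keeps both
# slots contiguous for the whole lifetime of the ring.
k_buffer = [local_key.clone().contiguous(), local_key.clone().contiguous()]
v_buffer = [local_value.clone().contiguous(), local_value.clone().contiguous()]
next_buffer_idx = 0  # which slot currently holds the "active" K/V block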
@@ -97,8 +98,6 @@ def wait(self):
         paddle.device.synchronize()
 
     def add_to_buffers(self, key, value):
-        key = key.contiguous()
-        value = value.contiguous()
         if key.shape != self._k_buffer[self._next_buffer_idx].shape:
             self._k_buffer[self._next_buffer_idx][:, : key.shape[1], :, :].add_(
                 key
@@ -112,8 +111,8 @@ def add_to_buffers(self, key, value):
 
     def get_buffers(self):
         return (
-            self._k_buffer[self._next_buffer_idx].contiguous(),
-            self._v_buffer[self._next_buffer_idx].contiguous(),
+            self._k_buffer[self._next_buffer_idx],
+            self._v_buffer[self._next_buffer_idx],
         )
 
     def send_recv(self):
@@ -131,13 +130,13 @@ def send_recv(self):
         )
         recv_k_op = dist.P2POp(
             dist.irecv,
-            self._k_buffer[(self._next_buffer_idx + 1) % 2].contiguous(),
+            self._k_buffer[(self._next_buffer_idx + 1) % 2],
             self.recv_rank,
             self.group,
         )
         recv_v_op = dist.P2POp(
             dist.irecv,
-            self._v_buffer[(self._next_buffer_idx + 1) % 2].contiguous(),
+            self._v_buffer[(self._next_buffer_idx + 1) % 2],
             self.recv_rank,
             self.group,
         )
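Note: dropping .contiguous() from the irecv targets matters because a call that has to copy returns a new tensor, so the receive would land in a throwaway temporary instead of the buffer that get_buffers() later returns; now that both slots are contiguous from construction, the ops can point at them directly. A hedged sketch of the paired send/recv, reusing the paddle.distributed calls visible in this diff (function name and arguments are illustrative, not part of the PR):

import paddle.distributed as dist

def ring_exchange(send_block, recv_buffer, send_rank, recv_rank, group):
    # Post a matched isend/irecv pair. The recv tensor must be the real
    # pre-allocated buffer, not a .contiguous() copy of it, or the incoming
    # data never reaches the buffer the caller will read.
    send_op = dist.P2POp(dist.isend, send_block, send_rank, group)
    recv_op = dist.P2POp(dist.irecv, recv_buffer, recv_rank, group)
    tasks = dist.batch_isend_irecv([send_op, recv_op])
    return tasks  # caller waits on these before touching recv_buffer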
@@ -186,9 +185,9 @@ def concat_masks(attn_masks_list, rank, cp_size):
 
 def ring_flash_attention_forward_func(
     group,
-    query,
-    key,
-    value,
+    local_query,
+    local_key,
+    local_value,
     attn_mask=None,
     dropout=0.0,
     is_causal=False,
@@ -197,20 +196,6 @@ def ring_flash_attention_forward_func(
 ):
     cp_size = group.world_size
     group_rank = group.rank
-    mesh = dist.auto_parallel.get_mesh()
-
-    local_query = dist.auto_parallel.api.dtensor_to_local(
-        query, mesh, query.placements
-    )
-    local_key = dist.auto_parallel.api.dtensor_to_local(
-        key, mesh, key.placements
-    )
-    local_value = dist.auto_parallel.api.dtensor_to_local(
-        value, mesh, value.placements
-    )
-    local_query = local_query.contiguous()
-    local_key = local_key.contiguous()
-    local_value = local_value.contiguous()
 
     comm_buffer = RingCommunicator(group, local_key, local_value)
     local_q_seq_len = local_query.shape[1]
@@ -219,7 +204,9 @@ def ring_flash_attention_forward_func(
             attn_mask, num_or_sections=cp_size * 2, axis=3
         )
     if is_causal:
-        local_query_second_chunk = local_query[:, local_q_seq_len // 2 :, :, :]
+        local_query_second_chunk = local_query[
+            :, local_q_seq_len // 2 :, :, :
+        ].contiguous()
     for step in range(cp_size):
         block_k, block_v = comm_buffer.get_buffers()
         if step != cp_size - 1:
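Note: with causal load balancing each rank owns a front and a back chunk of the sequence, and the back half of the local query is now sliced out once and materialized contiguously before the ring loop rather than on every step. A small illustration with an invented shape:

import paddle

# Hypothetical local query shard: [batch, local_seq, num_heads, head_dim].
local_query = paddle.randn([2, 128, 8, 64], dtype="float32")
local_q_seq_len = local_query.shape[1]

# Take the second half of the local sequence once, as a contiguous tensor,
# so the per-step attention calls on it do not each pay for a copy.
local_query_second_chunk = local_query[
    :, local_q_seq_len // 2 :, :, :
].contiguous()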
@@ -315,10 +302,10 @@ def ring_flash_attention_forward_func(
 
 def ring_flash_attention_backward_func(
     group,
-    out_grad,
-    query,
-    key,
-    value,
+    local_out_grad,
+    local_query,
+    local_key,
+    local_value,
     local_out,
     lse,
     attn_mask,
@@ -328,25 +315,7 @@ def ring_flash_attention_backward_func(
 ):
     cp_size = group.world_size
     group_rank = group.rank
-    mesh = dist.auto_parallel.get_mesh()
 
-    local_query = dist.auto_parallel.api.dtensor_to_local(
-        query, mesh, query.placements
-    )
-    local_key = dist.auto_parallel.api.dtensor_to_local(
-        key, mesh, key.placements
-    )
-    local_value = dist.auto_parallel.api.dtensor_to_local(
-        value, mesh, value.placements
-    )
-    local_out_grad = dist.auto_parallel.api.dtensor_to_local(
-        out_grad, mesh, out_grad.placements
-    )
-    local_query = local_query.contiguous()
-    local_key = local_key.contiguous()
-    local_value = local_value.contiguous()
-    local_out = local_out.contiguous()
-    local_out_grad = local_out_grad.contiguous()
     lse = lse.contiguous()
 
     local_q_seq_len = local_query.shape[1]
@@ -361,14 +330,13 @@ def ring_flash_attention_backward_func(
     if is_causal:
         local_query_second_chunk = local_query[:, local_q_seq_len // 2 :, :, :]
         local_out_second_chunk = local_out[:, local_q_seq_len // 2 :, :, :]
-        lse_second_chunk = lse[:, :, local_q_seq_len // 2 :]
+        lse_second_chunk = lse[:, :, local_q_seq_len // 2 :].contiguous()
         out_grad_second_chunk = local_out_grad[:, local_q_seq_len // 2 :, :, :]
 
     if attn_mask is not None:
         attn_masks_list = paddle.split(
             attn_mask, num_or_sections=cp_size * 2, axis=3
         )
-
     for step in range(cp_size):
         block_k, block_v = kv_comm_buffer.get_buffers()
         if step != cp_size - 1:
@@ -445,13 +413,14 @@ def ring_flash_attention_backward_func(
                 )
             )
             query_grad_buffer.add_(block_q_grad)
-
         paddle.device.synchronize()
 
-        grad_comm_buffer.add_to_buffers(block_k_grad, block_v_grad)
+        grad_comm_buffer.add_to_buffers(
+            block_k_grad.contiguous(), block_v_grad.contiguous()
+        )
         grad_comm_buffer.send_recv()
 
-        grad_comm_buffer.wait()
+    grad_comm_buffer.wait()
     key_grad_buffer, value_grad_buffer = grad_comm_buffer.get_buffers()
 
     return query_grad_buffer, key_grad_buffer, value_grad_buffer
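Note: the dK/dV ring buffer is now fed explicitly contiguous block gradients, and the final wait happens once after the loop, just before the accumulated buffers are read back. A hedged skeleton of that pattern, assuming a communicator with the same add_to_buffers / send_recv / wait / get_buffers methods as RingCommunicator above (block_grads is an invented stand-in for the per-step kernel outputs):

def accumulate_dkv(grad_comm_buffer, block_grads, cp_size):
    for step in range(cp_size):
        block_k_grad, block_v_grad = block_grads[step]
        # Kernel outputs may be non-contiguous views, so normalize them
        # before adding into the pre-allocated contiguous ring buffers.
        grad_comm_buffer.add_to_buffers(
            block_k_grad.contiguous(), block_v_grad.contiguous()
        )
        grad_comm_buffer.send_recv()
    # A single wait after the last rotation is enough before reading back.
    grad_comm_buffer.wait()
    return grad_comm_buffer.get_buffers()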
@@ -481,14 +450,23 @@ def forward(
         dist.init_parallel_env()
 
         group = mesh._get_group("sep")
-
+        local_query = dist.auto_parallel.api.dtensor_to_local(
+            query, query.process_mesh, query.placements
+        )
+        local_key = dist.auto_parallel.api.dtensor_to_local(
+            key, key.process_mesh, key.placements
+        )
+        local_value = dist.auto_parallel.api.dtensor_to_local(
+            value, value.process_mesh, value.placements
+        )
         if attn_mask is not None:
             is_causal = False
+
         out, lse = ring_flash_attention_forward_func(
             group,
-            query,
-            key,
-            value,
+            local_query,
+            local_key,
+            local_value,
             attn_mask,
             dropout,
             is_causal,
@@ -500,7 +478,7 @@ def forward(
         ctx.dropout = dropout
         ctx.is_causal = is_causal
         out_dtensor = dist.auto_parallel.api.dtensor_from_local(
-            out.contiguous(), query.process_mesh, query.placements
+            out, query.process_mesh, query.placements
         )
         return out_dtensor.contiguous()
 
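Note: the dist-tensor handling now lives entirely at the PyLayer boundary: forward (and backward below) convert the inputs to local shards with dtensor_to_local, call the ring functions on plain local tensors, and wrap the results back with dtensor_from_local. A hedged helper showing the same conversion, using only the calls that appear in this diff (the helper itself is not part of the PR):

import paddle.distributed as dist

def to_local_shards(*dist_tensors):
    # Convert each distributed tensor to its local shard, keyed off the
    # tensor's own process mesh and placements.
    return tuple(
        dist.auto_parallel.api.dtensor_to_local(
            t, t.process_mesh, t.placements
        )
        for t in dist_tensors
    )

# Usage sketch inside forward():
#   local_query, local_key, local_value = to_local_shards(query, key, value)
#   out, lse = ring_flash_attention_forward_func(
#       group, local_query, local_key, local_value, ...
#   )
#   return dist.auto_parallel.api.dtensor_from_local(
#       out, query.process_mesh, query.placements
#   )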
@@ -517,27 +495,39 @@ def backward(ctx, out_grad):
         fixed_seed_offset = paddle.to_tensor(
             [0, 0], place=paddle.CPUPlace(), dtype=paddle.int64
         )
+        local_query = dist.auto_parallel.api.dtensor_to_local(
+            query, query.process_mesh, query.placements
+        )
+        local_key = dist.auto_parallel.api.dtensor_to_local(
+            key, key.process_mesh, key.placements
+        )
+        local_value = dist.auto_parallel.api.dtensor_to_local(
+            value, value.process_mesh, value.placements
+        )
+        local_out_grad = dist.auto_parallel.api.dtensor_to_local(
+            out_grad, out_grad.process_mesh, out_grad.placements
+        )
         query_grad, key_grad, value_grad = ring_flash_attention_backward_func(
             group,
-            out_grad,
-            query,
-            key,
-            value,
+            local_out_grad,
+            local_query,
+            local_key,
+            local_value,
             out,
-            lse.contiguous(),
+            lse,
             attn_mask,
             dropout,
             is_causal,
             fixed_seed_offset,
         )
         query_grad_dtensor = dist.auto_parallel.api.dtensor_from_local(
-            query_grad.contiguous(), query.process_mesh, query.placements
+            query_grad, query.process_mesh, query.placements
         )
         key_grad_dtensor = dist.auto_parallel.api.dtensor_from_local(
-            key_grad.contiguous(), key.process_mesh, key.placements
+            key_grad, key.process_mesh, key.placements
         )
         value_grad_dtensor = dist.auto_parallel.api.dtensor_from_local(
-            value_grad.contiguous(), value.process_mesh, value.placements
+            value_grad, value.process_mesh, value.placements
         )
 
         if attn_mask is not None and not attn_mask.stop_gradient: