Commit 3cd2b19

【auto parallel】fix cp ut (#73735)
1 parent dc0a4b9 commit 3cd2b19

File tree

5 files changed: +90 additions, -85 deletions

python/paddle/distributed/auto_parallel/intermediate/tensor_parallel.py

Lines changed: 18 additions & 12 deletions
@@ -30,7 +30,7 @@
 from paddle.nn import Layer
 
 
-def c_split(x, process_mesh, need_transpose):
+def c_split(x, process_mesh, need_transpose, split_type="sp"):
     mp_index = process_mesh.dim_names.index('mp')  # get the axis for the split
     dp_index = process_mesh.dim_names.index('dp')
     if isinstance(x, tuple):
@@ -44,17 +44,23 @@ def c_split(x, process_mesh, need_transpose):
     placements = target_x.placements
     if placements is None:
         placements = [dist.Replicate() for _ in range(len(process_mesh.shape))]
-    if placements[dp_index] == dist.Shard(0):
-        # NOTE(zhangwl):if shard(0) , input shape should be [b,s,h]
-        split_dims = dist.Shard(1)
-    elif placements[dp_index] == dist.Shard(1):
-        # NOTE(zhangwl):if shard(1) , input shape should be [s,b,h]
-        split_dims = dist.Shard(0)
+    if split_type == "sp":
+        if placements[dp_index] == dist.Shard(0):
+            # NOTE(zhangwl):if shard(0) , input shape should be [b,s,h]
+            split_dims = dist.Shard(1)
+        elif placements[dp_index] == dist.Shard(1):
+            # NOTE(zhangwl):if shard(1) , input shape should be [s,b,h]
+            split_dims = dist.Shard(0)
+        else:
+            logging.warning(
+                f"parallel api don't know {target_x.shape} which dimension is batch, default is to cut to the 0th dimension"
+            )
+            split_dims = dist.Shard(0)
+    elif split_type == "mp":
+        split_dims = dist.Shard(2)  # split h [b,s,h]
     else:
-        logging.warning(
-            f"parallel api don't know {target_x.shape} which dimension is batch, default is to cut to the 0th dimension"
-        )
-        split_dims = dist.Shard(0)
+        raise ValueError(f"Unsupported split type {split_type}")
+
     placements[mp_index] = split_dims
     target_x = dist.reshard(target_x, process_mesh, placements)
     if isinstance(x, tuple):
@@ -251,7 +257,7 @@ def __init__(self, is_input_parallel: bool = True) -> None:
 
     def split_input_hook(self, process_mesh):
         def split_hook(layer, input):
-            return c_split(input, process_mesh, False)
+            return c_split(input, process_mesh, False, split_type="mp")
 
         return split_hook
 
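Note on this file: the new split_type argument selects which activation axis is resharded along the 'mp' mesh dimension. With split_type="sp" the sequence axis is inferred from where data parallel already shards the batch ([b,s,h] vs [s,b,h]); with split_type="mp" the hidden axis is sharded, i.e. Shard(2) for a [b,s,h] layout. Below is a minimal sketch of that decision table; choose_split_dim and dp_shard_dim are illustrative stand-ins, not part of the Paddle API, and it runs without any mesh or distributed tensors.

# Illustrative sketch only: mirrors the dimension-selection logic of c_split
# using plain integers instead of a process mesh and placements.
def choose_split_dim(split_type, dp_shard_dim):
    """Return the axis that would be sharded on the 'mp' mesh dimension."""
    if split_type == "sp":
        if dp_shard_dim == 0:   # dp shards dim 0 -> layout is [b, s, h], shard s
            return 1
        if dp_shard_dim == 1:   # dp shards dim 1 -> layout is [s, b, h], shard s
            return 0
        return 0                # unknown layout: the real code warns and uses dim 0
    if split_type == "mp":
        return 2                # tensor parallel: shard the hidden dim of [b, s, h]
    raise ValueError(f"Unsupported split type {split_type}")


assert choose_split_dim("sp", 0) == 1   # [b, s, h]: sequence axis is 1
assert choose_split_dim("sp", 1) == 0   # [s, b, h]: sequence axis is 0
assert choose_split_dim("mp", 0) == 2   # hidden axis of [b, s, h]
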
python/paddle/distributed/auto_parallel/ring_attention.py

Lines changed: 61 additions & 71 deletions
@@ -69,13 +69,14 @@ def unshard_seq_load_balance(tensor, seq_dim):
 
 class RingCommunicator:
     def __init__(self, group, local_key, local_value):
-        self._k_buffer = [paddle.zeros_like(local_key) for _ in range(2)]
-        self._v_buffer = [paddle.zeros_like(local_value) for _ in range(2)]
-
-        local_key = local_key.contiguous()
-        local_value = local_value.contiguous()
-        self._k_buffer[0] = local_key.clone()
-        self._v_buffer[0] = local_value.clone()
+        self._k_buffer = [
+            local_key.clone().contiguous(),
+            local_key.clone().contiguous(),
+        ]
+        self._v_buffer = [
+            local_value.clone().contiguous(),
+            local_value.clone().contiguous(),
+        ]
 
         self._next_buffer_idx = 0
 
@@ -97,8 +98,6 @@ def wait(self):
         paddle.device.synchronize()
 
     def add_to_buffers(self, key, value):
-        key = key.contiguous()
-        value = value.contiguous()
         if key.shape != self._k_buffer[self._next_buffer_idx].shape:
             self._k_buffer[self._next_buffer_idx][:, : key.shape[1], :, :].add_(
                 key
@@ -112,8 +111,8 @@ def add_to_buffers(self, key, value):
 
     def get_buffers(self):
        return (
-            self._k_buffer[self._next_buffer_idx].contiguous(),
-            self._v_buffer[self._next_buffer_idx].contiguous(),
+            self._k_buffer[self._next_buffer_idx],
+            self._v_buffer[self._next_buffer_idx],
         )
 
     def send_recv(self):
@@ -131,13 +130,13 @@ def send_recv(self):
         )
         recv_k_op = dist.P2POp(
             dist.irecv,
-            self._k_buffer[(self._next_buffer_idx + 1) % 2].contiguous(),
+            self._k_buffer[(self._next_buffer_idx + 1) % 2],
             self.recv_rank,
             self.group,
         )
         recv_v_op = dist.P2POp(
             dist.irecv,
-            self._v_buffer[(self._next_buffer_idx + 1) % 2].contiguous(),
+            self._v_buffer[(self._next_buffer_idx + 1) % 2],
             self.recv_rank,
             self.group,
         )
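Note on the hunks above: RingCommunicator now preallocates both halves of its ping-pong buffer as contiguous clones of the local K/V block, so the later .contiguous() calls in get_buffers and send_recv become unnecessary. The current slot is read while the peer's block is received into the other slot, then the roles swap. A toy sketch of that rotation follows, with Python lists standing in for tensors; ToyRingBuffer and its methods are made-up names, not the real class, and no GPU or process group is needed.

# Toy sketch of the double-buffer rotation used for ring K/V exchange.
class ToyRingBuffer:
    def __init__(self, local_block):
        # both slots start as copies of the local block (cf. clone().contiguous())
        self._buf = [list(local_block), list(local_block)]
        self._idx = 0

    def current(self):
        # what get_buffers() would hand to the attention kernel this step
        return self._buf[self._idx]

    def recv_slot(self):
        # where the next irecv would deposit the neighbor's block
        return self._buf[(self._idx + 1) % 2]

    def rotate(self):
        # after send/recv completes, the freshly received block becomes current
        self._idx = (self._idx + 1) % 2


buf = ToyRingBuffer([1, 2, 3])
buf.recv_slot()[:] = [4, 5, 6]   # pretend a peer's K block arrived
buf.rotate()
assert buf.current() == [4, 5, 6]
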
@@ -186,9 +185,9 @@ def concat_masks(attn_masks_list, rank, cp_size):
 
 def ring_flash_attention_forward_func(
     group,
-    query,
-    key,
-    value,
+    local_query,
+    local_key,
+    local_value,
     attn_mask=None,
     dropout=0.0,
     is_causal=False,
@@ -197,20 +196,6 @@
 ):
     cp_size = group.world_size
     group_rank = group.rank
-    mesh = dist.auto_parallel.get_mesh()
-
-    local_query = dist.auto_parallel.api.dtensor_to_local(
-        query, mesh, query.placements
-    )
-    local_key = dist.auto_parallel.api.dtensor_to_local(
-        key, mesh, key.placements
-    )
-    local_value = dist.auto_parallel.api.dtensor_to_local(
-        value, mesh, value.placements
-    )
-    local_query = local_query.contiguous()
-    local_key = local_key.contiguous()
-    local_value = local_value.contiguous()
 
     comm_buffer = RingCommunicator(group, local_key, local_value)
     local_q_seq_len = local_query.shape[1]
@@ -219,7 +204,9 @@
             attn_mask, num_or_sections=cp_size * 2, axis=3
         )
     if is_causal:
-        local_query_second_chunk = local_query[:, local_q_seq_len // 2 :, :, :]
+        local_query_second_chunk = local_query[
+            :, local_q_seq_len // 2 :, :, :
+        ].contiguous()
     for step in range(cp_size):
         block_k, block_v = comm_buffer.get_buffers()
         if step != cp_size - 1:
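Note on the causal branch above: the second query chunk is now materialized with .contiguous() at the point where the slice is created, because a strided slice such as local_query[:, n // 2 :, :, :] can be a non-packed view (depending on Paddle's stride-kernel flags), and the chunk is reused across ring steps. A small local-only sketch, assuming only a plain paddle install and no distributed setup:

import paddle

# A sliced tensor may be a non-packed view; .contiguous() copies it into
# packed memory before it is reused by kernels across iterations.
x = paddle.arange(2 * 8 * 2 * 4, dtype="float32").reshape([2, 8, 2, 4])
second_half = x[:, 4:, :, :]          # analogous to local_query_second_chunk
packed = second_half.contiguous()     # safe to feed repeatedly downstream
assert packed.shape == [2, 4, 2, 4]
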
@@ -315,10 +302,10 @@
 
 def ring_flash_attention_backward_func(
     group,
-    out_grad,
-    query,
-    key,
-    value,
+    local_out_grad,
+    local_query,
+    local_key,
+    local_value,
     local_out,
     lse,
     attn_mask,
@@ -328,25 +315,7 @@
 ):
     cp_size = group.world_size
     group_rank = group.rank
-    mesh = dist.auto_parallel.get_mesh()
 
-    local_query = dist.auto_parallel.api.dtensor_to_local(
-        query, mesh, query.placements
-    )
-    local_key = dist.auto_parallel.api.dtensor_to_local(
-        key, mesh, key.placements
-    )
-    local_value = dist.auto_parallel.api.dtensor_to_local(
-        value, mesh, value.placements
-    )
-    local_out_grad = dist.auto_parallel.api.dtensor_to_local(
-        out_grad, mesh, out_grad.placements
-    )
-    local_query = local_query.contiguous()
-    local_key = local_key.contiguous()
-    local_value = local_value.contiguous()
-    local_out = local_out.contiguous()
-    local_out_grad = local_out_grad.contiguous()
     lse = lse.contiguous()
 
     local_q_seq_len = local_query.shape[1]
@@ -361,14 +330,13 @@
     if is_causal:
         local_query_second_chunk = local_query[:, local_q_seq_len // 2 :, :, :]
         local_out_second_chunk = local_out[:, local_q_seq_len // 2 :, :, :]
-        lse_second_chunk = lse[:, :, local_q_seq_len // 2 :]
+        lse_second_chunk = lse[:, :, local_q_seq_len // 2 :].contiguous()
         out_grad_second_chunk = local_out_grad[:, local_q_seq_len // 2 :, :, :]
 
     if attn_mask is not None:
         attn_masks_list = paddle.split(
             attn_mask, num_or_sections=cp_size * 2, axis=3
         )
-
     for step in range(cp_size):
         block_k, block_v = kv_comm_buffer.get_buffers()
         if step != cp_size - 1:
@@ -445,13 +413,14 @@
                 )
             )
             query_grad_buffer.add_(block_q_grad)
-
         paddle.device.synchronize()
 
-        grad_comm_buffer.add_to_buffers(block_k_grad, block_v_grad)
+        grad_comm_buffer.add_to_buffers(
+            block_k_grad.contiguous(), block_v_grad.contiguous()
+        )
         grad_comm_buffer.send_recv()
 
-        grad_comm_buffer.wait()
+    grad_comm_buffer.wait()
     key_grad_buffer, value_grad_buffer = grad_comm_buffer.get_buffers()
 
     return query_grad_buffer, key_grad_buffer, value_grad_buffer
@@ -481,14 +450,23 @@ def forward(
             dist.init_parallel_env()
 
         group = mesh._get_group("sep")
-
+        local_query = dist.auto_parallel.api.dtensor_to_local(
+            query, query.process_mesh, query.placements
+        )
+        local_key = dist.auto_parallel.api.dtensor_to_local(
+            key, key.process_mesh, key.placements
+        )
+        local_value = dist.auto_parallel.api.dtensor_to_local(
+            value, value.process_mesh, value.placements
+        )
         if attn_mask is not None:
             is_causal = False
+
         out, lse = ring_flash_attention_forward_func(
             group,
-            query,
-            key,
-            value,
+            local_query,
+            local_key,
+            local_value,
             attn_mask,
             dropout,
             is_causal,
@@ -500,7 +478,7 @@
         ctx.dropout = dropout
         ctx.is_causal = is_causal
         out_dtensor = dist.auto_parallel.api.dtensor_from_local(
-            out.contiguous(), query.process_mesh, query.placements
+            out, query.process_mesh, query.placements
         )
         return out_dtensor.contiguous()
 
@@ -517,27 +495,39 @@ def backward(ctx, out_grad):
         fixed_seed_offset = paddle.to_tensor(
             [0, 0], place=paddle.CPUPlace(), dtype=paddle.int64
         )
+        local_query = dist.auto_parallel.api.dtensor_to_local(
+            query, query.process_mesh, query.placements
+        )
+        local_key = dist.auto_parallel.api.dtensor_to_local(
+            key, key.process_mesh, key.placements
+        )
+        local_value = dist.auto_parallel.api.dtensor_to_local(
+            value, value.process_mesh, value.placements
+        )
+        local_out_grad = dist.auto_parallel.api.dtensor_to_local(
+            out_grad, out_grad.process_mesh, out_grad.placements
+        )
         query_grad, key_grad, value_grad = ring_flash_attention_backward_func(
             group,
-            out_grad,
-            query,
-            key,
-            value,
+            local_out_grad,
+            local_query,
+            local_key,
+            local_value,
             out,
-            lse.contiguous(),
+            lse,
             attn_mask,
             dropout,
             is_causal,
             fixed_seed_offset,
         )
         query_grad_dtensor = dist.auto_parallel.api.dtensor_from_local(
-            query_grad.contiguous(), query.process_mesh, query.placements
+            query_grad, query.process_mesh, query.placements
         )
         key_grad_dtensor = dist.auto_parallel.api.dtensor_from_local(
-            key_grad.contiguous(), key.process_mesh, key.placements
+            key_grad, key.process_mesh, key.placements
        )
         value_grad_dtensor = dist.auto_parallel.api.dtensor_from_local(
-            value_grad.contiguous(), value.process_mesh, value.placements
+            value_grad, value.process_mesh, value.placements
         )
 
         if attn_mask is not None and not attn_mask.stop_gradient:
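Note on the refactor in this file: the dist-to-local conversion moves out of ring_flash_attention_forward_func / ring_flash_attention_backward_func and into the surrounding autograd forward/backward, so the ring kernels only ever see local tensors and the dtensor round trip happens exactly once per call. A schematic sketch of that boundary pattern follows; to_local, from_local, ring_attention_local, and the dict "tensors" are hypothetical stand-ins so the structure can run as plain Python, and they are not the Paddle API.

# Hypothetical stand-ins: dicts play the role of distributed tensors.
def to_local(dist_t):
    return dist_t["local"]                      # cf. dtensor_to_local

def from_local(local_t, meta):
    return {"local": local_t, **meta}           # cf. dtensor_from_local

def ring_attention_local(q, k, v):
    # stands in for the ring attention kernel, which now receives only locals
    return q + k + v

def forward(dist_q, dist_k, dist_v):
    meta = {"mesh": dist_q["mesh"], "placements": dist_q["placements"]}
    out_local = ring_attention_local(
        to_local(dist_q), to_local(dist_k), to_local(dist_v)
    )
    return from_local(out_local, meta)          # convert back once, at the boundary

q = {"local": 1.0, "mesh": "mesh0", "placements": ["Shard(1)"]}
out = forward(q, dict(q, local=2.0), dict(q, local=3.0))
assert out["local"] == 6.0 and out["mesh"] == "mesh0"
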

test/auto_parallel/hybrid_strategy/CMakeLists.txt

Lines changed: 8 additions & 0 deletions
@@ -137,6 +137,14 @@ if((WITH_GPU) AND (LINUX))
   set_tests_properties(test_parallel_api_with_llama_2d
                        PROPERTIES TIMEOUT "400" LABELS "RUN_TYPE=HYBRID")
 endif()
+if((WITH_GPU) AND (LINUX))
+  py_test_modules(
+    test_parallel_api_with_llama_2d_sep MODULES
+    test_parallel_api_with_llama_2d_sep ENVS
+    "http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python")
+  set_tests_properties(test_parallel_api_with_llama_2d_sep
+                       PROPERTIES TIMEOUT "400" LABELS "RUN_TYPE=HYBRID")
+endif()
 if((WITH_GPU) AND (LINUX))
   py_test_modules(
     test_parallel_api_with_llama_3d MODULES test_parallel_api_with_llama_3d

test/auto_parallel/hybrid_strategy/test_parallel_api_with_llama_2d_sep.py

Lines changed: 2 additions & 2 deletions
@@ -37,7 +37,7 @@ def setUp(self):
             "amp_dtype": ["bfloat16"],
             "amp_master_grad": ["true"],
             "use_lazy_init": ["true"],
-            "sequence_parallel": ["true"],
+            "sequence_parallel": ["false"],
             "context_parallel": ["true"],
             "prepare_input_output": ["false"],
             "sharding_stage": ["0"],
@@ -83,7 +83,7 @@ def setUp(self):
             "amp_dtype": ["bfloat16"],
             "amp_master_grad": ["true"],
             "use_lazy_init": ["true"],
-            "sequence_parallel": ["true"],
+            "sequence_parallel": ["false"],
             "sep_parallel": ["true"],
             "context_parallel": ["false"],
             "prepare_input_output": ["false"],

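Note on this test: both setUp configurations flip sequence_parallel to "false" while keeping the context-parallel / sep-parallel coverage. Each option is a list of candidate values, and the harness expands the lists into one launched run per combination. A rough sketch of that expansion using itertools.product is shown below; the expand helper and the trimmed config dict are illustrative only, not the actual test_runner implementation.

import itertools

# Subset of the options from setUp above; every value is a list of candidates.
config = {
    "sequence_parallel": ["false"],
    "context_parallel": ["true"],
    "sharding_stage": ["0"],
}

def expand(cfg):
    keys = list(cfg)
    for values in itertools.product(*(cfg[k] for k in keys)):
        yield dict(zip(keys, values))

for env in expand(config):
    print(env)   # each combination becomes one set of env vars for a test launch
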
test/auto_parallel/hybrid_strategy/testslist.csv

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@ test_semi_auto_llama_acc_align,LINUX,GPU,300,HYBRID,test_runner.py,,,http_proxy=
 test_semi_auto_llama_save_load,LINUX,GPU,180,HYBRID,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../..;FLAGS_enable_pir_api=1,
 test_parallel_api_with_llama_1d,LINUX,GPU,400,HYBRID,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../..,
 test_parallel_api_with_llama_2d,LINUX,GPU,400,HYBRID,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../..,
+test_parallel_api_with_llama_2d_sep,LINUX,GPU,400,HYBRID,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../..,
 test_parallel_api_with_llama_3d,LINUX,GPU,400,HYBRID,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../..,
 test_parallel_api_with_llama_lora,LINUX,GPU,360,HYBRID,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../..,
 test_process_mesh,LINUX,GPU,60,HYBRID,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../..,
