
Commit e7716f1

Authored Dec 9, 2023
[AutoParallel] rename fuse_rope to fused_rope and polish code of flash_attention.cc (#59832)
1 parent 434f970 commit e7716f1

10 files changed (+109 −109 lines)
 

paddle/phi/api/yaml/fused_backward.yaml

Lines changed: 1 addition & 1 deletion

@@ -34,7 +34,7 @@
   optional : sin, cos, position_ids, out_k_grad, out_v_grad, k_grad, v_grad
   infer_meta :
     func : FusedRopeGradInferMeta
-    spmd_rule : FuseRopeGradInferSpmd
+    spmd_rule : FusedRopeGradInferSpmd
   kernel :
     func : fused_rotary_position_embedding_grad
     data_type : out_q_grad

paddle/phi/api/yaml/fused_ops.yaml

Lines changed: 1 addition & 1 deletion

@@ -264,7 +264,7 @@
   output : Tensor(out_q), Tensor(out_k), Tensor(out_v)
   infer_meta :
     func : FusedRopeInferMeta
-    spmd_rule : FuseRopeInferSpmd
+    spmd_rule : FusedRopeInferSpmd
   optional : k, v, sin, cos, position_ids, out_k, out_v
   kernel :
     func : fused_rotary_position_embedding

paddle/phi/infermeta/spmd_rules/flash_attention.cc

Lines changed: 12 additions & 12 deletions

@@ -95,10 +95,10 @@ SpmdInfo FlashAttInferSpmd(const DistMetaTensor& q,
       phi::errors::InvalidArgument("The Tensor k's shape must be [batch_size, "
                                    "seq_len_kv, num_heads, head_dim]"));
 
-  auto k_batch_size = q_shape[0];
-  auto k_seq_len = q_shape[1];
-  auto k_num_heads = q_shape[2];
-  auto k_head_dim = q_shape[3];
+  auto k_batch_size = k_shape[0];
+  auto k_seq_len = k_shape[1];
+  auto k_num_heads = k_shape[2];
+  auto k_head_dim = k_shape[3];
 
   PADDLE_ENFORCE_EQ(
       batch_size,
@@ -112,7 +112,7 @@ SpmdInfo FlashAttInferSpmd(const DistMetaTensor& q,
       num_heads,
       k_num_heads,
       phi::errors::InvalidArgument(
-          "The Tensor q and k's k_num_heads [%d] vs [%d] are not matched.",
+          "The Tensor q and k's num_heads [%d] vs [%d] are not matched.",
           num_heads,
           k_num_heads));
 
@@ -160,7 +160,7 @@ SpmdInfo FlashAttInferSpmd(const DistMetaTensor& q,
       num_heads,
       v_num_heads,
       phi::errors::InvalidArgument(
-          "The Tensor q and v's k_num_heads [%d] vs [%d] are not matched.",
+          "The Tensor q and v's num_heads [%d] vs [%d] are not matched.",
           num_heads,
           v_num_heads));
 
@@ -175,7 +175,7 @@ SpmdInfo FlashAttInferSpmd(const DistMetaTensor& q,
   PADDLE_ENFORCE_EQ(
       v_ndim,
       v_dims_mapping_size,
-      phi::errors::InvalidArgument("The Tensor q's rank [%d] and Its "
+      phi::errors::InvalidArgument("The Tensor v's rank [%d] and Its "
                                    "dims_mapping size [%d] are not matched.",
                                    v_ndim,
                                    v_dims_mapping_size));
@@ -324,10 +324,10 @@ SpmdInfo FlashAttGradInferSpmd(const DistMetaTensor& q,
       phi::errors::InvalidArgument("The Tensor k's shape must be [batch_size, "
                                    "seq_len_kv, num_heads, head_dim]"));
 
-  auto k_batch_size = q_shape[0];
-  auto k_seq_len = q_shape[1];
-  auto k_num_heads = q_shape[2];
-  auto k_head_dim = q_shape[3];
+  auto k_batch_size = k_shape[0];
+  auto k_seq_len = k_shape[1];
+  auto k_num_heads = k_shape[2];
+  auto k_head_dim = k_shape[3];
 
   PADDLE_ENFORCE_EQ(
       batch_size,
@@ -341,7 +341,7 @@ SpmdInfo FlashAttGradInferSpmd(const DistMetaTensor& q,
       num_heads,
       k_num_heads,
       phi::errors::InvalidArgument(
-          "The Tensor q and k's k_num_heads [%d] vs [%d] are not matched.",
+          "The Tensor q and k's num_heads [%d] vs [%d] are not matched.",
           num_heads,
           k_num_heads));
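Note that the first and fifth hunks above fix a genuine bug, not just naming: k's dimensions were read from q_shape, so the checks that follow compared q against itself and could never flag a mismatched k. A minimal standalone sketch (hypothetical shapes, not code from this commit) of why the old code masked the error:

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // Layout: [batch_size, seq_len, num_heads, head_dim]
  std::vector<int64_t> q_shape = {16, 2048, 64, 128};
  std::vector<int64_t> k_shape = {16, 2048, 32, 128};  // wrong num_heads

  // Before the fix: k_num_heads was taken from q_shape, so the
  // num_heads equality check degenerated to comparing q with itself.
  int64_t buggy_k_num_heads = q_shape[2];
  assert(q_shape[2] == buggy_k_num_heads);  // always passes; bug hidden

  // After the fix: k_num_heads comes from k_shape, so the mismatch
  // (64 vs 32) is visible to the PADDLE_ENFORCE_EQ-style check.
  int64_t fixed_k_num_heads = k_shape[2];
  assert(q_shape[2] != fixed_k_num_heads);  // mismatch now detectable
  return 0;
}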

paddle/phi/infermeta/spmd_rules/fuse_rope.h

Lines changed: 0 additions & 54 deletions
This file was deleted.

paddle/phi/infermeta/spmd_rules/fuse_rope.cc renamed to paddle/phi/infermeta/spmd_rules/fused_rope.cc

Lines changed: 32 additions & 32 deletions

@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/phi/infermeta/spmd_rules/fuse_rope.h"
+#include "paddle/phi/infermeta/spmd_rules/fused_rope.h"
 
 #include "glog/logging.h"
 
@@ -186,13 +186,13 @@ void infer_sin_cos(const DistMetaTensor& sin,
   }
 }
 
-SpmdInfo FuseRopeInferSpmd(const DistMetaTensor& q,
-                           const DistMetaTensor& k,
-                           const DistMetaTensor& v,
-                           const DistMetaTensor& sin,
-                           const DistMetaTensor& cos,
-                           const DistMetaTensor& position_ids,
-                           bool use_neox_rotary_style) {
+SpmdInfo FusedRopeInferSpmd(const DistMetaTensor& q,
+                            const DistMetaTensor& k,
+                            const DistMetaTensor& v,
+                            const DistMetaTensor& sin,
+                            const DistMetaTensor& cos,
+                            const DistMetaTensor& position_ids,
+                            bool use_neox_rotary_style) {
   check_q(q);
 
   std::vector<std::pair<std::string, std::vector<int64_t>>>
@@ -270,16 +270,16 @@ SpmdInfo FuseRopeInferSpmd(const DistMetaTensor& q,
       {q_dist_attr_dst, k_dist_attr_dst, v_dist_attr_dst}};
 }
 
-SpmdInfo FuseRopeInferSpmdReverse(const DistMetaTensor& q,
-                                  const DistMetaTensor& k,
-                                  const DistMetaTensor& v,
-                                  const DistMetaTensor& sin,
-                                  const DistMetaTensor& cos,
-                                  const DistMetaTensor& position_ids,
-                                  const DistMetaTensor& out_q,
-                                  const DistMetaTensor& out_k,
-                                  const DistMetaTensor& out_v,
-                                  bool use_neox_rotary_style) {
+SpmdInfo FusedRopeInferSpmdReverse(const DistMetaTensor& q,
+                                   const DistMetaTensor& k,
+                                   const DistMetaTensor& v,
+                                   const DistMetaTensor& sin,
+                                   const DistMetaTensor& cos,
+                                   const DistMetaTensor& position_ids,
+                                   const DistMetaTensor& out_q,
+                                   const DistMetaTensor& out_k,
+                                   const DistMetaTensor& out_v,
+                                   bool use_neox_rotary_style) {
   check_q(out_q);
   std::vector<std::pair<std::string, std::vector<int64_t>>>
       outputs_sharding_info;
@@ -366,22 +366,22 @@ SpmdInfo FuseRopeInferSpmdReverse(const DistMetaTensor& q,
       {out_q_dist_attr_dst, out_k_dist_attr_dst, out_v_dist_attr_dst}};
 }
 
-SpmdInfo FuseRopeGradInferSpmd(const DistMetaTensor& sin,
-                               const DistMetaTensor& cos,
-                               const DistMetaTensor& position_ids,
-                               const DistMetaTensor& out_q_grad,
-                               const DistMetaTensor& out_k_grad,
-                               const DistMetaTensor& out_v_grad,
-                               bool use_neox_rotary_style) {
+SpmdInfo FusedRopeGradInferSpmd(const DistMetaTensor& sin,
+                                const DistMetaTensor& cos,
+                                const DistMetaTensor& position_ids,
+                                const DistMetaTensor& out_q_grad,
+                                const DistMetaTensor& out_k_grad,
+                                const DistMetaTensor& out_v_grad,
+                                bool use_neox_rotary_style) {
   // NOTE(zhonghui): The forward and backward kernels of fuse rope are same, so
   // the spmd rules can be shared.
-  SpmdInfo spmd_info = FuseRopeInferSpmd(out_q_grad,
-                                         out_k_grad,
-                                         out_v_grad,
-                                         sin,
-                                         cos,
-                                         position_ids,
-                                         use_neox_rotary_style);
+  SpmdInfo spmd_info = FusedRopeInferSpmd(out_q_grad,
+                                          out_k_grad,
+                                          out_v_grad,
+                                          sin,
+                                          cos,
+                                          position_ids,
+                                          use_neox_rotary_style);
   std::vector<ArgDistAttr> dist_attrs;
   std::vector<int> order = {3, 4, 5, 0, 1, 2};
   for (int ind : order) {
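The NOTE above carries the key idea: since the forward and backward kernels are the same, the grad rule passes (out_q_grad, out_k_grad, out_v_grad) where the forward rule expects (q, k, v), then permutes the inferred attrs with order = {3, 4, 5, 0, 1, 2} so they match the grad op's declared input order (sin, cos, position_ids, out_q_grad, out_k_grad, out_v_grad). A self-contained sketch of that permutation, with strings standing in for real dist attrs:

#include <iostream>
#include <string>
#include <vector>

int main() {
  // FusedRopeInferSpmd was called as (out_q_grad, out_k_grad, out_v_grad,
  // sin, cos, position_ids, ...), so its inferred input attrs come back
  // in that order:
  std::vector<std::string> from_forward = {
      "attr(out_q_grad)", "attr(out_k_grad)", "attr(out_v_grad)",
      "attr(sin)",        "attr(cos)",        "attr(position_ids)"};

  // The grad op declares its inputs as (sin, cos, position_ids,
  // out_q_grad, out_k_grad, out_v_grad); this permutation lines the
  // two orders up.
  const std::vector<int> order = {3, 4, 5, 0, 1, 2};
  for (int ind : order) {
    std::cout << from_forward[ind] << "\n";
  }
  // Prints attr(sin), attr(cos), attr(position_ids), then the three grads.
  return 0;
}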
paddle/phi/infermeta/spmd_rules/fused_rope.h (new file)

Lines changed: 54 additions & 0 deletions

@@ -0,0 +1,54 @@
+/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <vector>
+
+#include "paddle/phi/common/int_array.h"
+#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h"
+#include "paddle/phi/core/distributed/type_defs.h"
+
+namespace phi {
+namespace distributed {
+
+SpmdInfo FusedRopeInferSpmd(const DistMetaTensor& q,
+                            const DistMetaTensor& k,
+                            const DistMetaTensor& v,
+                            const DistMetaTensor& sin,
+                            const DistMetaTensor& cos,
+                            const DistMetaTensor& position_ids,
+                            bool use_neox_rotary_style);
+
+SpmdInfo FusedRopeInferSpmdReverse(const DistMetaTensor& q,
+                                   const DistMetaTensor& k,
+                                   const DistMetaTensor& v,
+                                   const DistMetaTensor& sin,
+                                   const DistMetaTensor& cos,
+                                   const DistMetaTensor& position_ids,
+                                   const DistMetaTensor& out_q,
+                                   const DistMetaTensor& out_k,
+                                   const DistMetaTensor& out_v,
+                                   bool use_neox_rotary_style);
+
+SpmdInfo FusedRopeGradInferSpmd(const DistMetaTensor& sin,
+                                const DistMetaTensor& cos,
+                                const DistMetaTensor& position_ids,
+                                const DistMetaTensor& out_q_grad,
+                                const DistMetaTensor& out_k_grad,
+                                const DistMetaTensor& out_v_grad,
+                                bool use_neox_rotary_style);
+
+}  // namespace distributed
+}  // namespace phi

paddle/phi/infermeta/spmd_rules/rules.h

Lines changed: 1 addition & 1 deletion

@@ -23,7 +23,7 @@ limitations under the License. */
 #include "paddle/phi/infermeta/spmd_rules/flash_attention.h"
 #include "paddle/phi/infermeta/spmd_rules/flatten.h"
 #include "paddle/phi/infermeta/spmd_rules/full_like.h"
-#include "paddle/phi/infermeta/spmd_rules/fuse_rope.h"
+#include "paddle/phi/infermeta/spmd_rules/fused_rope.h"
 #include "paddle/phi/infermeta/spmd_rules/layer_norm.h"
 #include "paddle/phi/infermeta/spmd_rules/matmul.h"
 #include "paddle/phi/infermeta/spmd_rules/numel.h"

test/auto_parallel/semi_auto_parallel_for_fuse_rope.py renamed to test/auto_parallel/semi_auto_parallel_for_fused_rope.py

Lines changed: 2 additions & 2 deletions

@@ -26,7 +26,7 @@
     fused_rotary_position_embedding = None
 
 
-class TestFuseRopeApiForSemiAutoParallel(SemiAutoParallelTestBase):
+class TestFusedRopeApiForSemiAutoParallel(SemiAutoParallelTestBase):
     def __init__(self):
         self._dtype = os.getenv("dtype")
         self._backend = os.getenv("backend")
@@ -146,4 +146,4 @@ def run_test_case(self):
 
 
 if __name__ == '__main__':
-    TestFuseRopeApiForSemiAutoParallel().run_test_case()
+    TestFusedRopeApiForSemiAutoParallel().run_test_case()

test/auto_parallel/test_semi_auto_parallel_basic.py

Lines changed: 1 addition & 1 deletion

@@ -176,7 +176,7 @@ def test_fuse_rope_api(self):
         )
         for envs in envs_list:
             self.run_test_case(
-                "semi_auto_parallel_for_fuse_rope.py",
+                "semi_auto_parallel_for_fused_rope.py",
                 user_defined_envs=envs,
             )

test/cpp/auto_parallel/spmd_rule_test.cc

Lines changed: 5 additions & 5 deletions

@@ -1188,7 +1188,7 @@ TEST(Transpose, Ctor) {
   check_partial_dims(backward_spmd_info.second[0], {});
 }
 
-TEST(FuseRope, Ctor) {
+TEST(FusedRope, Ctor) {
   std::vector<int64_t> mesh_shape = {2, 2};
   std::vector<int64_t> process_ids = {0, 1, 2, 3};
   std::vector<std::string> dim_names = {"x", "y"};
@@ -1212,7 +1212,7 @@ TEST(FuseRope, Ctor) {
   // 1. test forward
   // 1.1 only q input
   phi::distributed::SpmdInfo forward_spmd_info =
-      phi::distributed::FuseRopeInferSpmd(
+      phi::distributed::FusedRopeInferSpmd(
           q, none, none, none, none, none, false);
   EXPECT_EQ(forward_spmd_info.first.size(), static_cast<size_t>(6));
   EXPECT_EQ(forward_spmd_info.second.size(), static_cast<size_t>(3));
@@ -1236,7 +1236,7 @@ TEST(FuseRope, Ctor) {
       build_input({1, 2048, 1, 128}, {-1, 1, -1, -1});
   phi::distributed::DistMetaTensor position_ids =
       build_input({16, 2048}, {0, 1});
-  forward_spmd_info = phi::distributed::FuseRopeInferSpmd(
+  forward_spmd_info = phi::distributed::FusedRopeInferSpmd(
       q, k, none, sin, cos, position_ids, false);
   EXPECT_EQ(forward_spmd_info.first.size(), static_cast<size_t>(6));
   EXPECT_EQ(forward_spmd_info.second.size(), static_cast<size_t>(3));
@@ -1253,7 +1253,7 @@ TEST(FuseRope, Ctor) {
   check_partial_dims(forward_spmd_info.second[1], {});
   // 2. test backward
   phi::distributed::SpmdInfo backward_spmd_info =
-      FuseRopeGradInferSpmd(sin, cos, position_ids, q, k, none, false);
+      FusedRopeGradInferSpmd(sin, cos, position_ids, q, k, none, false);
   EXPECT_EQ(backward_spmd_info.first.size(), static_cast<size_t>(6));
   EXPECT_EQ(backward_spmd_info.second.size(), static_cast<size_t>(3));
   check_dim_mapping(backward_spmd_info.first[0], {-1, -1, -1, -1});
@@ -1273,7 +1273,7 @@ TEST(FuseRope, Ctor) {
       build_input({16, 2048, 64, 128}, {0, 1, -1, -1});
   phi::distributed::DistMetaTensor out_k =
       build_input({16, 2048, 64, 128}, {-1, 1, -1, 0});
-  phi::distributed::SpmdInfo reverse_spmd_info = FuseRopeInferSpmdReverse(
+  phi::distributed::SpmdInfo reverse_spmd_info = FusedRopeInferSpmdReverse(
       q, k, none, sin, cos, position_ids, out_q, out_k, none, false);
  EXPECT_EQ(reverse_spmd_info.first.size(), static_cast<size_t>(6));
  EXPECT_EQ(reverse_spmd_info.second.size(), static_cast<size_t>(3));
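For readers outside the auto-parallel codebase: in these tests a dims_mapping such as {0, 1, -1, -1} assigns tensor dimension i to process-mesh dimension dims_mapping[i], with -1 meaning replicated, so q of global shape {16, 2048, 64, 128} on the 2x2 mesh is sharded over batch and sequence. A self-contained sketch of how a local shard shape follows from such a mapping (LocalShape is a hypothetical helper, not a Paddle API):

#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper: compute the per-process shard shape of a tensor
// sharded over a mesh according to dims_mapping (-1 = replicated).
std::vector<int64_t> LocalShape(const std::vector<int64_t>& global_shape,
                                const std::vector<int64_t>& dims_mapping,
                                const std::vector<int64_t>& mesh_shape) {
  std::vector<int64_t> local = global_shape;
  for (size_t i = 0; i < local.size(); ++i) {
    if (dims_mapping[i] >= 0) {
      local[i] /= mesh_shape[dims_mapping[i]];  // split along that mesh dim
    }
  }
  return local;
}

int main() {
  // Mirrors the test: mesh {2, 2}, q of shape {16, 2048, 64, 128}
  // with dims_mapping {0, 1, -1, -1}.
  auto local = LocalShape({16, 2048, 64, 128}, {0, 1, -1, -1}, {2, 2});
  assert((local == std::vector<int64_t>{8, 1024, 64, 128}));
  return 0;
}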
