Skip to content

Commit 345f3d0

Browse files
fajin-corp authored and ankitm3k committed
[ARM CPU] Add rotary embedding fp16 kernel (microsoft#23013)
### Description
Add an fp16 kernel for rotary embedding to boost performance.

### Motivation and Context
Part of the performance optimization work for group query attention.
1 parent ddb6e65 commit 345f3d0

File tree

13 files changed

+526
-69
lines changed

13 files changed

+526
-69
lines changed

cmake/onnxruntime_mlas.cmake

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ onnxruntime_add_static_library(onnxruntime_mlas
4141
${MLAS_SRC_DIR}/sqnbitgemm_q8_block.h
4242
${MLAS_SRC_DIR}/flashattn.cpp
4343
${MLAS_SRC_DIR}/cast.cpp
44+
${MLAS_SRC_DIR}/rotary_embedding.h
45+
${MLAS_SRC_DIR}/rotary_embedding.cpp
4446
)
4547

4648
target_sources(onnxruntime_mlas PRIVATE
@@ -88,8 +90,11 @@ function(setup_mlas_source_for_windows)
8890
${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.cpp
8991
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp
9092
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
91-
${MLAS_SRC_DIR}/fp16_neon_common.cpp
93+
${MLAS_SRC_DIR}/cast_kernel_neon.cpp
9294
${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp
95+
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.h
96+
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.cpp
97+
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp
9398
)
9499

95100
set(mlas_platform_preprocess_srcs
@@ -367,6 +372,8 @@ else()
367372
${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.cpp
368373
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp
369374
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
375+
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.h
376+
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.cpp
370377
)
371378
set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
372379
PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+dotprod")
@@ -384,8 +391,9 @@ else()
384391
${MLAS_SRC_DIR}/qgemm_kernel_smmla.cpp
385392
${MLAS_SRC_DIR}/qgemm_kernel_ummla.cpp
386393
${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp
387-
${MLAS_SRC_DIR}/fp16_neon_common.cpp
394+
${MLAS_SRC_DIR}/cast_kernel_neon.cpp
388395
${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp
396+
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp
389397
)
390398
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
391399
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ")
@@ -395,8 +403,9 @@ else()
395403
set_source_files_properties(${MLAS_SRC_DIR}/dwconv.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
396404
set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
397405
set_source_files_properties(${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ")
398-
set_source_files_properties(${MLAS_SRC_DIR}/fp16_neon_common.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
406+
set_source_files_properties(${MLAS_SRC_DIR}/cast_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
399407
set_source_files_properties(${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
408+
set_source_files_properties(${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
400409
endif()
401410

402411
if(ONNXRUNTIME_MLAS_MULTI_ARCH)

onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc

Lines changed: 7 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "contrib_ops/cpu/bert/rotary_embedding.h"
55
#include "contrib_ops/cpu/bert/rotary_embedding_helper.h"
66

7+
#include "core/mlas/inc/mlas.h"
78
#include "core/platform/threadpool.h"
89

910
using onnxruntime::concurrency::ThreadPool;
@@ -78,31 +79,12 @@ Status RunRotaryEmbedding(concurrency::ThreadPool* tp, RotaryParameters paramete
7879
const T* cos_data = cos_cache + cache_offset;
7980
const T* sin_data = sin_cache + cache_offset;
8081

81-
int cache_idx = 0;
82-
bool sign = false;
83-
int j = 0;
84-
for (int i = 0; i < rotary_emb_dim; i++) {
85-
if (interleaved) {
86-
cache_idx = (i / 2) % half_rotary_emb_dim;
87-
sign = i & 1;
88-
j = sign ? i - 1 : i + 1; // i - sign
89-
} else {
90-
cache_idx = i % half_rotary_emb_dim;
91-
sign = (i >= half_rotary_emb_dim);
92-
j = (i + half_rotary_emb_dim) % rotary_emb_dim;
93-
}
94-
float output_data_i = static_cast<float>(input_data[i]) * static_cast<float>(cos_data[cache_idx]);
95-
float input_data_j = static_cast<float>(input_data[j]);
96-
float sin_data_cache_idx = static_cast<float>(sin_data[cache_idx]);
97-
if (sign) {
98-
output_data_i += input_data_j * sin_data_cache_idx;
99-
} else {
100-
output_data_i -= input_data_j * sin_data_cache_idx;
101-
}
102-
output_data[i] = static_cast<T>(output_data_i);
103-
}
104-
for (int i = rotary_emb_dim; i < head_size; i++) {
105-
output_data[i] = input_data[i];
82+
MlasRotaryEmbedOneRow<T>(input_data, sin_data, cos_data, rotary_emb_dim, interleaved, output_data);
83+
84+
if (rotary_emb_dim < head_size) {
85+
std::memcpy(output_data + rotary_emb_dim,
86+
input_data + rotary_emb_dim,
87+
(head_size - rotary_emb_dim) * sizeof(T));
10688
}
10789
}
10890
});

onnxruntime/core/mlas/inc/mlas.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1435,6 +1435,29 @@ MLAS_FP16* Destination,
14351435
size_t Count
14361436
);
14371437

1438+
/**
1439+
* @brief Apply rotary embedding to one hidden state vector.
1440+
*
1441+
* @tparam T: data type of input, sin, cos and output. Currently only float32/16 are supported.
1442+
* @param input: input tensor, of shape [dim]
1443+
* @param sin: sin tensor, of shape [dim/2]
1444+
* @param cos: cos tensor, of shape [dim/2]
1445+
* @param dim: dimension of rotary embedding; sin and cos each hold dim/2 values
1446+
* @param interleaved: whether the real and imaginary parts are interleaved
1447+
* @param output: output tensor, of shape [dim]
1448+
*/
1449+
template <typename T>
1450+
void
1451+
MLASCALL
1452+
MlasRotaryEmbedOneRow(
1453+
const T* input,
1454+
const T* sin,
1455+
const T* cos,
1456+
size_t dim,
1457+
bool interleaved,
1458+
T* output
1459+
);
1460+
14381461
/**
14391462
* @brief Whether current CPU supports FP16 acceleration.
14401463
*/

onnxruntime/core/mlas/lib/fp16_neon_common.cpp renamed to onnxruntime/core/mlas/lib/cast_kernel_neon.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ Licensed under the MIT License.
66
77
Module Name:
88
9-
fp16_neon_common.cpp
9+
cast_kernel_neon.cpp
1010
1111
Abstract:
1212

onnxruntime/core/mlas/lib/mlasi.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1049,6 +1049,13 @@ extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512;
10491049

10501050
extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni;
10511051

1052+
//
1053+
// Rotary embedding dispatch structure.
1054+
//
1055+
struct MLAS_ROPE_DISPATCH;
1056+
extern const MLAS_ROPE_DISPATCH MlasRopeDispatchNeon;
1057+
1058+
10521059
//
10531060
// Quantized depthwise convolution kernels.
10541061
//
@@ -1208,6 +1215,8 @@ struct MLAS_PLATFORM {
12081215

12091216
MLAS_CAST_F16_TO_F32_KERNEL* CastF16ToF32Kernel;
12101217
MLAS_CAST_F32_TO_F16_KERNEL* CastF32ToF16Kernel;
1218+
1219+
const MLAS_ROPE_DISPATCH* RopeDispatch{nullptr};
12111220
};
12121221

12131222
inline

onnxruntime/core/mlas/lib/platform.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,7 @@ Return Value:
543543
this->ConvSymU8S8Dispatch = &MlasConvSymU8DispatchNeon;
544544
this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchNeon;
545545
this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchNeon;
546+
this->RopeDispatch = &MlasRopeDispatchNeon;
546547

547548
//
548549
// Check if the processor supports ASIMD dot product instructions.
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
/*++
2+
3+
Copyright (c) Intel Corporation. All rights reserved.
4+
5+
Licensed under the MIT License.
6+
7+
Module Name:
8+
9+
rotary_embedding.cpp
10+
11+
Abstract:
12+
13+
This module implements rotary embedding kernels for fp32/16.
14+
15+
--*/
16+
17+
#include "rotary_embedding.h"
18+
19+
namespace {
20+
21+
//
// Portable rotary embedding for one row, used when no platform kernel is
// registered for the element type.
//
// Every element i is combined with a partner element j and one sin/cos entry:
//   interleaved:     partners are adjacent (0,1), (2,3), ...; pair k uses cache[k]
//   non-interleaved: partners are (i, i + dim/2); both use cache[i % (dim/2)]
// The first element of each pair gets  x_i * cos - x_j * sin,
// the second element gets              x_i * cos + x_j * sin.
// Arithmetic is carried out in float and converted back to T on store.
//
// input_data / output_data hold rotary_emb_dim elements; sin_data / cos_data
// are indexed with values below rotary_emb_dim / 2. rotary_emb_dim is expected
// to be even.
//
template <typename T>
void
MLASCALL
MlasRotaryEmbedOneRow_FallBack(
    const T* input_data,
    const T* sin_data,
    const T* cos_data,
    size_t rotary_emb_dim,
    bool interleaved,
    T* output_data
) {
    const size_t half_rotary_emb_dim = rotary_emb_dim / 2;

    // With fewer than two elements there is no rotation partner and
    // half_rotary_emb_dim is zero, so the modulo operations below would divide
    // by zero (undefined behaviour). Pass the element through unchanged; the
    // sin/cos caches are not touched on this path.
    if (half_rotary_emb_dim == 0) {
        for (size_t i = 0; i < rotary_emb_dim; i++) {
            output_data[i] = input_data[i];
        }
        return;
    }

    for (size_t i = 0; i < rotary_emb_dim; i++) {
        size_t cache_idx = 0;
        bool sign = false;  // true for the second element of a pair
        size_t j = 0;       // index of the partner element

        if (interleaved) {
            cache_idx = (i / 2) % half_rotary_emb_dim;
            sign = (i & 1) != 0;
            j = sign ? i - 1 : i + 1;
        } else {
            cache_idx = i % half_rotary_emb_dim;
            sign = (i >= half_rotary_emb_dim);
            j = (i + half_rotary_emb_dim) % rotary_emb_dim;
        }

        float output_data_i = static_cast<float>(input_data[i]) * static_cast<float>(cos_data[cache_idx]);
        const float input_data_j = static_cast<float>(input_data[j]);
        const float sin_data_cache_idx = static_cast<float>(sin_data[cache_idx]);

        if (sign) {
            output_data_i += input_data_j * sin_data_cache_idx;
        } else {
            output_data_i -= input_data_j * sin_data_cache_idx;
        }

        output_data[i] = static_cast<T>(output_data_i);
    }
}
57+
58+
} // namespace
59+
60+
61+
template <>
62+
void
63+
MLASCALL
64+
MlasRotaryEmbedOneRow<float>(
65+
const float* input,
66+
const float* sin,
67+
const float* cos,
68+
size_t dim,
69+
bool interleaved,
70+
float* output
71+
) {
72+
const auto* dispatch = GetMlasPlatform().RopeDispatch;
73+
74+
if (dispatch == nullptr || dispatch->SRope == nullptr) {
75+
MlasRotaryEmbedOneRow_FallBack<float>(input, sin, cos, dim, interleaved, output);
76+
return;
77+
}
78+
79+
dispatch->SRope(input, sin, cos, dim, interleaved, output);
80+
}
81+
82+
template <>
83+
void
84+
MLASCALL
85+
MlasRotaryEmbedOneRow<MLAS_FP16>(
86+
const MLAS_FP16* input,
87+
const MLAS_FP16* sin,
88+
const MLAS_FP16* cos,
89+
size_t dim,
90+
bool interleaved,
91+
MLAS_FP16* output
92+
) {
93+
const auto* dispatch = GetMlasPlatform().RopeDispatch;
94+
95+
if (dispatch == nullptr || dispatch->HRope == nullptr) {
96+
MlasRotaryEmbedOneRow_FallBack<MLAS_FP16>(input, sin, cos, dim, interleaved, output);
97+
return;
98+
}
99+
100+
dispatch->HRope(input, sin, cos, dim, interleaved, output);
101+
}
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
/*++
2+
3+
Copyright (c) Microsoft Corporation. All rights reserved.
4+
5+
Licensed under the MIT License.
6+
7+
Module Name:
8+
9+
rotary_embedding.h
10+
11+
Abstract:
12+
13+
This module includes kernel function prototypes and helper functions for
14+
implementing rotary embedding.
15+
16+
--*/
17+
18+
#pragma once
19+
20+
#include "mlasi.h"
21+
22+
struct MLAS_ROPE_DISPATCH {
23+
// rotary embedding kernel for fp32
24+
typedef void(SRope_Fn)(
25+
const float* input,
26+
const float* sin,
27+
const float* cos,
28+
size_t dim,
29+
bool interleaved,
30+
float* output
31+
);
32+
33+
SRope_Fn* SRope = nullptr;
34+
35+
// rotary embedding kernel for fp16
36+
typedef void(HRope_Fn)(
37+
const MLAS_FP16* input,
38+
const MLAS_FP16* sin,
39+
const MLAS_FP16* cos,
40+
size_t dim,
41+
bool interleaved,
42+
MLAS_FP16* output
43+
);
44+
45+
HRope_Fn* HRope = nullptr;
46+
};
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
/*++
2+
3+
Copyright (c) Microsoft Corporation. All rights reserved.
4+
5+
Licensed under the MIT License.
6+
7+
Module Name:
8+
9+
rotary_embedding_kernel_neon.cpp
10+
11+
Abstract:
12+
13+
This module implements the rotary embedding kernels for ARM NEON.
14+
15+
--*/
16+
17+
#include "rotary_embedding.h"
18+
#include "rotary_embedding_kernel_neon.h"
19+
20+
//
21+
// Kernel dispatch structure definition.
22+
//
23+
const MLAS_ROPE_DISPATCH MlasRopeDispatchNeon = []() {
    // Every kernel pointer starts out as nullptr.
    MLAS_ROPE_DISPATCH rope_dispatch;

#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64)
    // Register the fp16 kernel only when fp16 acceleration is reported as supported.
    if (MlasFp16AccelerationSupported()) {
        rope_dispatch.HRope = rope_neon::RopeKernel_Fp16;
    }
#endif

    return rope_dispatch;
}();
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
/*++
2+
3+
Copyright (c) Microsoft Corporation. All rights reserved.
4+
5+
Licensed under the MIT License.
6+
7+
Module Name:
8+
9+
rotary_embedding_kernel_neon.h
10+
11+
Abstract:
12+
13+
This module includes function declarations and common helper functions for
14+
rotary embedding on ARM cpu.
15+
16+
--*/
17+
18+
#pragma once
19+
20+
#include <arm_neon.h>
21+
22+
#include "mlasi.h"
23+
24+
namespace rope_neon {
25+
26+
// Rotary embedding kernel for fp16. Embeds one hidden state vector; the signature matches MLAS_ROPE_DISPATCH::HRope_Fn.
27+
void
28+
RopeKernel_Fp16(
29+
const MLAS_FP16* input,
30+
const MLAS_FP16* sin,
31+
const MLAS_FP16* cos,
32+
size_t dim,
33+
bool interleaved,
34+
MLAS_FP16* output
35+
);
36+
37+
} // namespace rope_neon

0 commit comments

Comments
 (0)