Skip to content

Commit 07a34bc

Browse files
Add FlashMask V2 (#74729)
* add flashmask v2 Co-authored-by: starcrown001 <148410714+starcrown001@users.noreply.github.com> * support seqlenq != seqlenk in flashmask Co-authored-by: starcrown001 <148410714+starcrown001@users.noreply.github.com> * refine * fix xpu * fix codestyle * update fa submodule * fix flashmaskv2 maxmin buffer padding * fix code style --------- Co-authored-by: starcrown001 <148410714+starcrown001@users.noreply.github.com>
1 parent 3cba4c1 commit 07a34bc

File tree

22 files changed

+2792
-20
lines changed

22 files changed

+2792
-20
lines changed

cmake/external/flashattn.cmake

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,9 @@ else()
8989
set(FLASHATTN_V3_LIBRARIES
9090
"${FLASHATTN_INSTALL_DIR}/bin/libflashattnv3${CMAKE_SHARED_LIBRARY_SUFFIX}"
9191
CACHE FILEPATH "flash-attn Library" FORCE)
92+
set(FLASHMASK_V2_LIBRARIES
93+
"${FLASHATTN_INSTALL_DIR}/bin/libflashmaskv2${CMAKE_SHARED_LIBRARY_SUFFIX}"
94+
CACHE FILEPATH "flash-attn Library" FORCE)
9295
endif()
9396
else()
9497
set(FLASHATTN_LIBRARIES
@@ -98,13 +101,17 @@ else()
98101
set(FLASHATTN_V3_LIBRARIES
99102
"${FLASHATTN_INSTALL_DIR}/lib/libflashattnv3${CMAKE_SHARED_LIBRARY_SUFFIX}"
100103
CACHE FILEPATH "flash-attn Library" FORCE)
104+
set(FLASHMASK_V2_LIBRARIES
105+
"${FLASHATTN_INSTALL_DIR}/lib/libflashmaskv2${CMAKE_SHARED_LIBRARY_SUFFIX}"
106+
CACHE FILEPATH "flash-attn Library" FORCE)
101107
endif()
102108
endif()
103109

104110
set(BUILD_BYPRODUCTS_LIST ${FLASHATTN_LIBRARIES})
105111
if(WITH_FLASHATTN_V3)
106112
add_definitions(-DPADDLE_WITH_FLASHATTN_V3)
107113
list(APPEND BUILD_BYPRODUCTS_LIST ${FLASHATTN_V3_LIBRARIES})
114+
list(APPEND BUILD_BYPRODUCTS_LIST ${FLASHMASK_V2_LIBRARIES})
108115
endif()
109116

110117
if(NOT DEFINED FA_JOB_POOLS_COMPILE)
@@ -293,6 +300,7 @@ endif()
293300
message(STATUS "flash-attn library: ${FLASHATTN_LIBRARIES}")
294301
if(WITH_FLASHATTN_V3)
295302
message(STATUS "flash-attn-v3 library: ${FLASHATTN_V3_LIBRARIES}")
303+
message(STATUS "flash-mask-v2 library: ${FLASHMASK_V2_LIBRARIES}")
296304
endif()
297305
get_filename_component(FLASHATTN_LIBRARY_PATH ${FLASHATTN_LIBRARIES} DIRECTORY)
298306
include_directories(${FLASHATTN_INCLUDE_DIR})

cmake/inference_lib.cmake

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,10 @@ function(copy_part_of_third_party TARGET DST)
216216
${TARGET}
217217
SRCS ${FLASHATTN_INCLUDE_DIR} ${FLASHATTN_V3_LIBRARIES}
218218
DSTS ${dst_dir} ${dst_dir}/lib)
219+
copy(
220+
${TARGET}
221+
SRCS ${FLASHATTN_INCLUDE_DIR} ${FLASHMASK_V2_LIBRARIES}
222+
DSTS ${dst_dir} ${dst_dir}/lib)
219223
endif()
220224

221225
if(NOT PROTOBUF_FOUND OR WIN32)

paddle/phi/backends/dynload/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ endif()
9494

9595
if(WITH_FLASHATTN_V3)
9696
list(APPEND DYNLOAD_COMMON_SRCS flashattnv3.cc)
97+
list(APPEND DYNLOAD_COMMON_SRCS flashmaskv2.cc)
9798
endif()
9899

99100
if(MKL_FOUND AND WITH_ONEMKL)

paddle/phi/backends/dynload/dynamic_loader.cc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -834,6 +834,20 @@ void* GetFlashAttnV3DsoHandle() {
834834
#endif
835835
}
836836

837+
void* GetFlashMaskV2DsoHandle() {
838+
std::string flashattn_dir = "";
839+
if (!s_py_site_pkg_path.path.empty()) {
840+
flashattn_dir = s_py_site_pkg_path.path;
841+
}
842+
#if defined(__APPLE__) || defined(__OSX__)
843+
return GetDsoHandleFromSearchPath(flashattn_dir, "libflashmaskv2.dylib");
844+
#elif defined(_WIN32)
845+
return GetDsoHandleFromSearchPath(flashattn_dir, "flashmaskv2.dll");
846+
#else
847+
return GetDsoHandleFromSearchPath(flashattn_dir, "libflashmaskv2.so");
848+
#endif
849+
}
850+
837851
void* GetAfsApiDsoHandle() {
838852
std::string afsapi_dir = "";
839853
if (!s_py_site_pkg_path.path.empty()) {

paddle/phi/backends/dynload/dynamic_loader.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ void* GetWarpCTCDsoHandle();
3838
void* GetWarpRNNTDsoHandle();
3939
void* GetFlashAttnDsoHandle();
4040
void* GetFlashAttnV3DsoHandle();
41+
void* GetFlashMaskV2DsoHandle();
4142
void* GetNCCLDsoHandle();
4243
void* GetFLAGCXDsoHandle();
4344
void* GetTensorRtDsoHandle();
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/phi/backends/dynload/flashmaskv2.h"
16+
17+
namespace phi {
namespace dynload {

// Definitions of the once-flag and library handle that flashmaskv2.h declares
// extern and that every generated DynLoad__* call operator uses.
std::once_flag flashmaskv2_dso_flag;
void* flashmaskv2_dso_handle = nullptr;

// Defines the wrapper instance that the header declares with `extern`.
#define DEFINE_WRAP(__name) DynLoad__##__name __name

// flashmaskv2.h only defines FLASHMASK_V2_ROUTINE_EACH when PADDLE_WITH_CUDA
// is set, so the definitions are guarded the same way; without the guard this
// line is an ill-formed declaration in a non-CUDA build.
#ifdef PADDLE_WITH_CUDA
FLASHMASK_V2_ROUTINE_EACH(DEFINE_WRAP);
#endif

}  // namespace dynload
}  // namespace phi
Lines changed: 276 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,276 @@
1+
/* Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
17+
#include <mutex> // NOLINT
18+
19+
#include "flashattn/include/flashmaskv2_api.h"
20+
#include "paddle/phi/backends/dynload/dynamic_loader.h"
21+
#include "paddle/phi/common/port.h"
22+
23+
namespace phi {
24+
namespace dynload {
25+
26+
extern std::once_flag flashmaskv2_dso_flag;
27+
extern void *flashmaskv2_dso_handle;
28+
29+
// For the C symbol `__name`, generates a functor type DynLoad__<name> and an
// extern declaration of an object of that type carrying the symbol's own name.
// Invoking the object:
//   1. obtains the library handle from GetFlashMaskV2DsoHandle() exactly once
//      (guarded by flashmaskv2_dso_flag),
//   2. resolves the symbol with dlsym into a function-local static, and
//   3. forwards the arguments to the resolved function.
// NOTE(review): the dlsym result is not checked before it is called.
#define DYNAMIC_LOAD_FLASHMASK_V2_WRAP(__name)                            \
  struct DynLoad__##__name {                                              \
    template <typename... Args>                                           \
    auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) {      \
      using SymbolFn = decltype(&::__name);                               \
      std::call_once(flashmaskv2_dso_flag, []() {                         \
        flashmaskv2_dso_handle = phi::dynload::GetFlashMaskV2DsoHandle(); \
      });                                                                 \
      static const SymbolFn resolved_##__name = reinterpret_cast<SymbolFn>( \
          dlsym(flashmaskv2_dso_handle, #__name));                        \
      return resolved_##__name(args...);                                  \
    }                                                                     \
  };                                                                      \
  extern DynLoad__##__name __name

// Alias used by the declaration lists below.
#define DECLARE_DYNAMIC_LOAD_FLASHMASK_V2_WRAP(__name) \
  DYNAMIC_LOAD_FLASHMASK_V2_WRAP(__name)
45+
46+
#ifdef PADDLE_WITH_CUDA
47+
// X-macro list of FlashMask V2 symbol names. The caller supplies `__macro`,
// which is applied to each name in turn (each application is followed by a
// semicolon).
#define FLASHMASK_V2_ROUTINE_EACH(__macro) \
48+
__macro(flashmaskv2_create_fwd_params_handle); \
49+
__macro(flashmaskv2_clear_fwd_params_handle); \
50+
__macro(flashmaskv2_destroy_fwd_params_handle); \
51+
__macro(flashmaskv2_create_bwd_params_handle); \
52+
__macro(flashmaskv2_clear_bwd_params_handle); \
53+
__macro(flashmaskv2_destroy_bwd_params_handle); \
54+
__macro(flashmaskv2_cast_to_fwd_params_handle); \
55+
__macro(flashmaskv2_run_mha_fwd_combine); \
56+
__macro(flashmaskv2_run_mha_fwd); \
57+
__macro(flashmaskv2_run_mha_bwd); \
58+
__macro(flashmaskv2_get_pagedkv_tma); \
59+
__macro(flashmaskv2_get_pack_gqa); \
60+
__macro(flashmaskv2_get_num_splits);
61+
62+
// Declare the DynLoad__* wrapper type and extern instance for every name in
// the list above.
FLASHMASK_V2_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_FLASHMASK_V2_WRAP)
63+
64+
// Declares four dynamic-load wrappers for one parameter field `member`:
// flashmaskv2_fwd_params_get_<member>, flashmaskv2_fwd_params_set_<member>,
// flashmaskv2_bwd_params_get_<member> and flashmaskv2_bwd_params_set_<member>.
#define FLASHMASK_V2_HANDLE_ROUTINE(member) \
65+
DECLARE_DYNAMIC_LOAD_FLASHMASK_V2_WRAP(flashmaskv2_fwd_params_get_##member); \
66+
DECLARE_DYNAMIC_LOAD_FLASHMASK_V2_WRAP(flashmaskv2_fwd_params_set_##member); \
67+
DECLARE_DYNAMIC_LOAD_FLASHMASK_V2_WRAP(flashmaskv2_bwd_params_get_##member); \
68+
DECLARE_DYNAMIC_LOAD_FLASHMASK_V2_WRAP(flashmaskv2_bwd_params_set_##member);
69+
70+
// The QKV matrices.
71+
FLASHMASK_V2_HANDLE_ROUTINE(q_ptr)
72+
FLASHMASK_V2_HANDLE_ROUTINE(k_ptr)
73+
FLASHMASK_V2_HANDLE_ROUTINE(v_ptr)
74+
75+
// The stride between rows of the Q, K and V matrices.
76+
FLASHMASK_V2_HANDLE_ROUTINE(q_batch_stride)
77+
FLASHMASK_V2_HANDLE_ROUTINE(k_batch_stride)
78+
FLASHMASK_V2_HANDLE_ROUTINE(v_batch_stride)
79+
FLASHMASK_V2_HANDLE_ROUTINE(q_row_stride)
80+
FLASHMASK_V2_HANDLE_ROUTINE(k_row_stride)
81+
FLASHMASK_V2_HANDLE_ROUTINE(v_row_stride)
82+
FLASHMASK_V2_HANDLE_ROUTINE(q_head_stride)
83+
FLASHMASK_V2_HANDLE_ROUTINE(k_head_stride)
84+
FLASHMASK_V2_HANDLE_ROUTINE(v_head_stride)
85+
FLASHMASK_V2_HANDLE_ROUTINE(v_dim_stride)
86+
87+
// The number of heads.
88+
FLASHMASK_V2_HANDLE_ROUTINE(h)
89+
FLASHMASK_V2_HANDLE_ROUTINE(h_k)
90+
91+
// The O matrix (output).
92+
FLASHMASK_V2_HANDLE_ROUTINE(o_ptr)
93+
FLASHMASK_V2_HANDLE_ROUTINE(oaccum_ptr)
94+
95+
// The stride between rows of O.
96+
FLASHMASK_V2_HANDLE_ROUTINE(o_batch_stride)
97+
FLASHMASK_V2_HANDLE_ROUTINE(o_row_stride)
98+
FLASHMASK_V2_HANDLE_ROUTINE(o_head_stride)
99+
100+
// The pointer to the softmax sum.
101+
FLASHMASK_V2_HANDLE_ROUTINE(softmax_lse_ptr)
102+
FLASHMASK_V2_HANDLE_ROUTINE(softmax_lseaccum_ptr)
103+
104+
// For FP8 scaling
105+
FLASHMASK_V2_HANDLE_ROUTINE(q_descale_ptr)
106+
FLASHMASK_V2_HANDLE_ROUTINE(k_descale_ptr)
107+
FLASHMASK_V2_HANDLE_ROUTINE(v_descale_ptr)
108+
FLASHMASK_V2_HANDLE_ROUTINE(q_descale_batch_stride)
109+
FLASHMASK_V2_HANDLE_ROUTINE(q_descale_head_stride)
110+
FLASHMASK_V2_HANDLE_ROUTINE(k_descale_batch_stride)
111+
FLASHMASK_V2_HANDLE_ROUTINE(k_descale_head_stride)
112+
FLASHMASK_V2_HANDLE_ROUTINE(v_descale_batch_stride)
113+
FLASHMASK_V2_HANDLE_ROUTINE(v_descale_head_stride)
114+
115+
// The dimensions.
116+
FLASHMASK_V2_HANDLE_ROUTINE(b)
117+
FLASHMASK_V2_HANDLE_ROUTINE(seqlen_q)
118+
FLASHMASK_V2_HANDLE_ROUTINE(seqlen_k)
119+
FLASHMASK_V2_HANDLE_ROUTINE(seqlen_knew)
120+
FLASHMASK_V2_HANDLE_ROUTINE(d)
121+
FLASHMASK_V2_HANDLE_ROUTINE(seqlen_q_rounded)
122+
FLASHMASK_V2_HANDLE_ROUTINE(seqlen_k_rounded)
123+
FLASHMASK_V2_HANDLE_ROUTINE(d_rounded)
124+
FLASHMASK_V2_HANDLE_ROUTINE(rotary_dim)
125+
FLASHMASK_V2_HANDLE_ROUTINE(total_q)
126+
FLASHMASK_V2_HANDLE_ROUTINE(total_k)
127+
FLASHMASK_V2_HANDLE_ROUTINE(total_knew)
128+
FLASHMASK_V2_HANDLE_ROUTINE(b_k)
129+
FLASHMASK_V2_HANDLE_ROUTINE(dv)
130+
FLASHMASK_V2_HANDLE_ROUTINE(dv_rounded)
131+
132+
// The scaling factors for the kernel.
133+
FLASHMASK_V2_HANDLE_ROUTINE(scale_softmax)
134+
FLASHMASK_V2_HANDLE_ROUTINE(softcap)
135+
136+
// array of length b+1 holding starting offset of each sequence.
137+
FLASHMASK_V2_HANDLE_ROUTINE(cu_seqlens_q)
138+
FLASHMASK_V2_HANDLE_ROUTINE(cu_seqlens_k)
139+
FLASHMASK_V2_HANDLE_ROUTINE(cu_seqlens_knew)
140+
FLASHMASK_V2_HANDLE_ROUTINE(leftpad_k)
141+
142+
// If provided, the actual length of each q/k sequence.
143+
FLASHMASK_V2_HANDLE_ROUTINE(seqused_q)
144+
FLASHMASK_V2_HANDLE_ROUTINE(seqused_k)
145+
146+
// The stride between rows of Oaccum.
147+
FLASHMASK_V2_HANDLE_ROUTINE(oaccum_split_stride)
148+
FLASHMASK_V2_HANDLE_ROUTINE(oaccum_batch_stride)
149+
FLASHMASK_V2_HANDLE_ROUTINE(oaccum_row_stride)
150+
FLASHMASK_V2_HANDLE_ROUTINE(oaccum_head_stride)
151+
152+
// The stride between rows of LSEaccum.
153+
FLASHMASK_V2_HANDLE_ROUTINE(lseaccum_split_stride)
154+
FLASHMASK_V2_HANDLE_ROUTINE(lseaccum_batch_stride)
155+
FLASHMASK_V2_HANDLE_ROUTINE(lseaccum_head_stride)
156+
157+
// The K_new and V_new matrices.
158+
FLASHMASK_V2_HANDLE_ROUTINE(knew_ptr)
159+
FLASHMASK_V2_HANDLE_ROUTINE(vnew_ptr)
160+
161+
// The stride between rows of the Q, K and V matrices.
162+
FLASHMASK_V2_HANDLE_ROUTINE(knew_batch_stride)
163+
FLASHMASK_V2_HANDLE_ROUTINE(vnew_batch_stride)
164+
FLASHMASK_V2_HANDLE_ROUTINE(knew_row_stride)
165+
FLASHMASK_V2_HANDLE_ROUTINE(vnew_row_stride)
166+
FLASHMASK_V2_HANDLE_ROUTINE(knew_head_stride)
167+
FLASHMASK_V2_HANDLE_ROUTINE(vnew_head_stride)
168+
169+
FLASHMASK_V2_HANDLE_ROUTINE(qv_ptr)
170+
FLASHMASK_V2_HANDLE_ROUTINE(qv_batch_stride)
171+
FLASHMASK_V2_HANDLE_ROUTINE(qv_row_stride)
172+
FLASHMASK_V2_HANDLE_ROUTINE(qv_head_stride)
173+
174+
// The cos and sin matrices for rotary embedding.
175+
FLASHMASK_V2_HANDLE_ROUTINE(rotary_cos_ptr)
176+
FLASHMASK_V2_HANDLE_ROUTINE(rotary_sin_ptr)
177+
178+
// The indices to index into the KV cache.
179+
FLASHMASK_V2_HANDLE_ROUTINE(kv_batch_idx)
180+
181+
// Paged KV cache
182+
FLASHMASK_V2_HANDLE_ROUTINE(page_table)
183+
FLASHMASK_V2_HANDLE_ROUTINE(page_table_batch_stride)
184+
FLASHMASK_V2_HANDLE_ROUTINE(page_size)
185+
FLASHMASK_V2_HANDLE_ROUTINE(num_pages)
186+
FLASHMASK_V2_HANDLE_ROUTINE(pagedkv_tma)
187+
188+
// The dropout probability (probability of keeping an activation).
189+
FLASHMASK_V2_HANDLE_ROUTINE(p_dropout)
190+
FLASHMASK_V2_HANDLE_ROUTINE(p_dropout_in_uint8_t)
191+
192+
// Scale factor of 1 / (1 - p_dropout).
193+
FLASHMASK_V2_HANDLE_ROUTINE(rp_dropout)
194+
195+
// Local window size
196+
FLASHMASK_V2_HANDLE_ROUTINE(window_size_left)
197+
FLASHMASK_V2_HANDLE_ROUTINE(window_size_right)
198+
199+
// Pointer to the RNG seed (idx 0) and offset (idx 1).
200+
FLASHMASK_V2_HANDLE_ROUTINE(rng_state)
201+
202+
FLASHMASK_V2_HANDLE_ROUTINE(is_bf16)
203+
FLASHMASK_V2_HANDLE_ROUTINE(is_fp32)
204+
FLASHMASK_V2_HANDLE_ROUTINE(is_e4m3)
205+
FLASHMASK_V2_HANDLE_ROUTINE(is_causal)
206+
FLASHMASK_V2_HANDLE_ROUTINE(is_local)
207+
208+
FLASHMASK_V2_HANDLE_ROUTINE(is_rotary_interleaved)
209+
210+
FLASHMASK_V2_HANDLE_ROUTINE(num_splits) // For split-KV version
211+
FLASHMASK_V2_HANDLE_ROUTINE(pack_gqa)
212+
213+
FLASHMASK_V2_HANDLE_ROUTINE(tile_count_semaphore)
214+
FLASHMASK_V2_HANDLE_ROUTINE(num_splits_dynamic_ptr)
215+
FLASHMASK_V2_HANDLE_ROUTINE(skip_scheduler_metadata_computation)
216+
217+
FLASHMASK_V2_HANDLE_ROUTINE(arch)
218+
FLASHMASK_V2_HANDLE_ROUTINE(num_sm)
219+
220+
FLASHMASK_V2_HANDLE_ROUTINE(h_flashmask)
221+
FLASHMASK_V2_HANDLE_ROUTINE(h_h_flashmask_ratio)
222+
FLASHMASK_V2_HANDLE_ROUTINE(lt_start_ptr)
223+
FLASHMASK_V2_HANDLE_ROUTINE(lt_end_ptr)
224+
FLASHMASK_V2_HANDLE_ROUTINE(ut_start_ptr)
225+
FLASHMASK_V2_HANDLE_ROUTINE(ut_end_ptr)
226+
FLASHMASK_V2_HANDLE_ROUTINE(flashmask_maxmin_ptr)
227+
228+
// Declares dynamic-load wrappers for the backward-only symbols
// flashmaskv2_bwd_params_get_<member> and flashmaskv2_bwd_params_set_<member>
// (no fwd counterparts are declared).
// NOTE(review): `type` is not used by the expansion; call sites pass what
// looks like the field's C type, presumably for documentation only.
#define FLASHMASK_V2_BWD_HANDLE_ROUTINE(type, member) \
229+
DECLARE_DYNAMIC_LOAD_FLASHMASK_V2_WRAP(flashmaskv2_bwd_params_get_##member); \
230+
DECLARE_DYNAMIC_LOAD_FLASHMASK_V2_WRAP(flashmaskv2_bwd_params_set_##member);
231+
232+
// The dO and dQKV matrices.
233+
FLASHMASK_V2_BWD_HANDLE_ROUTINE(void *, do_ptr)
234+
FLASHMASK_V2_BWD_HANDLE_ROUTINE(void *, dq_ptr)
235+
FLASHMASK_V2_BWD_HANDLE_ROUTINE(void *, dk_ptr)
236+
FLASHMASK_V2_BWD_HANDLE_ROUTINE(void *, dv_ptr)
237+
238+
// To accumulate dQ
239+
FLASHMASK_V2_BWD_HANDLE_ROUTINE(void *, dq_accum_ptr)
240+
FLASHMASK_V2_BWD_HANDLE_ROUTINE(void *, dk_accum_ptr)
241+
FLASHMASK_V2_BWD_HANDLE_ROUTINE(void *, dv_accum_ptr)
242+
243+
// // To accumulate dK and dV in case we're splitting the bwd along seqlen_q
244+
// dimension void *__restrict__ dk_accum_ptr; void *__restrict__
245+
// dv_accum_ptr;
246+
247+
// The stride between rows of the dO, dQ, dK and dV matrices.
248+
FLASHMASK_V2_BWD_HANDLE_ROUTINE(int64_t, do_batch_stride)
249+
FLASHMASK_V2_BWD_HANDLE_ROUTINE(int64_t, do_row_stride)
250+
FLASHMASK_V2_BWD_HANDLE_ROUTINE(int64_t, do_head_stride)
251+
FLASHMASK_V2_BWD_HANDLE_ROUTINE(int64_t, dq_batch_stride)
252+
FLASHMASK_V2_BWD_HANDLE_ROUTINE(int64_t, dk_batch_stride)
253+
FLASHMASK_V2_BWD_HANDLE_ROUTINE(int64_t, dv_batch_stride)
254+
FLASHMASK_V2_BWD_HANDLE_ROUTINE(int64_t, dq_row_stride)
255+
FLASHMASK_V2_BWD_HANDLE_ROUTINE(int64_t, dk_row_stride)
256+
FLASHMASK_V2_BWD_HANDLE_ROUTINE(int64_t, dv_row_stride)
257+
FLASHMASK_V2_BWD_HANDLE_ROUTINE(int64_t, dq_head_stride)
258+
FLASHMASK_V2_BWD_HANDLE_ROUTINE(int64_t, dk_head_stride)
259+
FLASHMASK_V2_BWD_HANDLE_ROUTINE(int64_t, dv_head_stride)
260+
261+
// The pointer to the softmax d sum.
262+
FLASHMASK_V2_BWD_HANDLE_ROUTINE(void *, dsoftmax_sum)
263+
FLASHMASK_V2_BWD_HANDLE_ROUTINE(void *, softmax_lse_log2_ptr)
264+
265+
FLASHMASK_V2_BWD_HANDLE_ROUTINE(int *, dq_semaphore)
266+
FLASHMASK_V2_BWD_HANDLE_ROUTINE(int *, dk_semaphore)
267+
FLASHMASK_V2_BWD_HANDLE_ROUTINE(int *, dv_semaphore)
268+
269+
FLASHMASK_V2_BWD_HANDLE_ROUTINE(bool, deterministic)
270+
FLASHMASK_V2_BWD_HANDLE_ROUTINE(int64_t, dq_accum_split_stride)
271+
#endif
272+
273+
#undef DYNAMIC_LOAD_FLASHMASK_V2_WRAP
274+
275+
} // namespace dynload
276+
} // namespace phi

0 commit comments

Comments
 (0)