Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CK_TILE] Add PagedAttention kernels #1387

Merged
merged 256 commits into from
Aug 28, 2024
Merged
Show file tree
Hide file tree
Changes from 250 commits
Commits
Show all changes
256 commits
Select commit Hold shift + click to select a range
d0b9fd0
Merge branch 'develop' into feature/refactor-fmha-codegen
poyenc Jun 23, 2024
4060416
Use dictionary to config all the functions
poyenc Jun 23, 2024
bace0e5
Add init codegen logic for fmha fwd appendkv
poyenc Jun 24, 2024
342c8cf
Call HIP_CHECK_ERROR() macro to get real source info
poyenc Jun 24, 2024
eee035a
Setup meaningfull arguments
poyenc Jun 24, 2024
3449027
Sync kernel name with the codegen
poyenc Jun 24, 2024
1ac17da
Add knew/vnew tensors to the kernel argument
poyenc Jun 25, 2024
4e6c285
Fix wrong K values after appending
poyenc Jun 25, 2024
8fb567c
Fix vnew append errro
poyenc Jun 26, 2024
c40c1da
Extract common logics
poyenc Jun 26, 2024
34a3ff8
Fix Vnew tile dstr for row major case
poyenc Jun 27, 2024
efd18fa
Conditionally add fwd_splitkv API in fmha_fwd example
poyenc Jul 8, 2024
82f3b3d
Conditionally add call to fmha_fwd_splitkv()
poyenc Jul 8, 2024
1c07038
Merge branch 'feature/cond-add-splitkv' into feature/fmha-fwd-appendkv
poyenc Jul 8, 2024
3aefb56
Remove "EXAMPLE_" prefix of cmake variables
poyenc Jul 8, 2024
aba46cd
Regsiter API handlers automatically
poyenc Jul 8, 2024
be076db
Merge branch 'feature/cond-add-splitkv' into feature/fmha-fwd-appendkv
poyenc Jul 8, 2024
fe4ae5d
Early return if 0 < s_k_new is not supported
poyenc Jul 8, 2024
6ca3910
Show message if we are ignoring option
poyenc Jul 8, 2024
5d21b4d
Merge branch 'feature/cond-add-splitkv' into feature/fmha-fwd-appendkv
poyenc Jul 8, 2024
8ac6bac
Unify CMakeLists.txt coding style
poyenc Jul 8, 2024
18a3834
Set num_splits=1 if split-kv is not supported
poyenc Jul 8, 2024
dc72074
Merge branch 'develop' into feature/cond-add-splitkv
poyenc Jul 8, 2024
2e164f1
Add length/stride getters for HostTensor
poyenc Jul 9, 2024
e939082
Add RoPE example utilities
poyenc Jul 9, 2024
f2d28e8
Add reference_rotary_position_embedding() (not implemented)
poyenc Jul 9, 2024
9d29311
Finish reference_rotary_position_embedding() impl
poyenc Jul 10, 2024
03b6d99
Fix typo of HostTensor<>::get_length()
poyenc Jul 10, 2024
8c733fb
Fix compilation errors
poyenc Jul 10, 2024
52da00a
Fix wrong answer when interleaved=false
poyenc Jul 10, 2024
ee365bb
Fix wrong answer when interleaved=true
poyenc Jul 11, 2024
bbdb0a5
Merge branch 'develop' into feature/cond-add-splitkv
carlushuang Jul 11, 2024
b4306af
Merge branch 'develop' into feature/cond-add-splitkv
poyenc Jul 12, 2024
b34ddf5
Merge remote-tracking branch 'origin/feature/cond-add-splitkv' into f…
poyenc Jul 12, 2024
4107bf0
Merge remote-tracking branch 'origin/feature/cond-add-splitkv' into f…
poyenc Jul 12, 2024
3578c6f
Append K/V in the host verification code
poyenc Jul 12, 2024
e5885ca
Simplify K appending logics
poyenc Jul 12, 2024
3183b68
Simplify v_host_ref definition
poyenc Jul 12, 2024
ff75eff
Reduce input/output dimensions
poyenc Jul 12, 2024
44c9bac
Rename function: add "batched" prefix
poyenc Jul 12, 2024
83d6acc
Apply RoPE on host side
poyenc Jul 14, 2024
93e5125
Rename RoPE utility function
poyenc Jul 14, 2024
55f5502
Fix wrong tensor size
poyenc Jul 14, 2024
5ce0fec
Merge branch 'develop' into feature/cond-add-splitkv
poyenc Jul 14, 2024
8c1647d
Avoid invoking deprecated method 'find_module'
poyenc Jul 14, 2024
c6717bb
Merge branch 'feature/cond-add-splitkv' of github.com:ROCm/composable…
poyenc Jul 14, 2024
b5ad141
Merge branch 'feature/cond-add-splitkv' into feature/fmha-fwd-appendkv
poyenc Jul 14, 2024
391210e
Pass RoPE kernel args
poyenc Jul 14, 2024
b0925bb
Create Rotary Cos/Sin tile windows in kernel
poyenc Jul 14, 2024
f6850ae
Add compute data type alias for RoPE
poyenc Jul 15, 2024
ad61d9d
Randomly generate seqlen_knew if needed
poyenc Jul 15, 2024
57c6a41
Fix seqlen_knew enabling check logic
poyenc Jul 15, 2024
1a093f9
Add minimum seqlen_k to generate compliance kvcache
poyenc Jul 15, 2024
4e01307
Fix compilation error in debug mode
poyenc Jul 15, 2024
65dac9f
Fix wrong boundaries
poyenc Jul 15, 2024
879710a
Fix wrong seqlen_k for kvcache
poyenc Jul 16, 2024
b32fd8d
Rename variables used in distributio encoding
poyenc Jul 16, 2024
99f863e
Fix rotary cos/sin tensor/tile size
poyenc Jul 16, 2024
e83c3c7
Add constraint to the rotary_dim option
poyenc Jul 16, 2024
39ef09b
Remove unused inner namespace
poyenc Jul 18, 2024
85bfed0
Add dram distribution for rotary_cos/rotary_sin (interleaved)
poyenc Jul 18, 2024
2345052
Only apply interleaved RoPE on Knew for now
poyenc Jul 18, 2024
27b5141
Fix wrong thread starting offset
poyenc Jul 18, 2024
fffd679
Instantiate multiple kernels for RoPE approaches
poyenc Jul 20, 2024
01865d2
Clean-up pipeline
poyenc Jul 22, 2024
1136e6b
Fix error in RoPE host reference
poyenc Jul 22, 2024
631f29d
Handle RoPE half-rotated logics
poyenc Jul 22, 2024
d1ecfdc
Support 8x rotary_dim under half-rotated RoPE
poyenc Jul 23, 2024
df352f9
Add comment
poyenc Jul 23, 2024
c0bc097
Apply elementwise function to the loaded tiles
poyenc Jul 23, 2024
c26c60d
Unify parameter/variable naming style
poyenc Jul 23, 2024
1dbed18
Remove constness from q_ptr
poyenc Jul 23, 2024
e88253a
Add code blocks for q_tile
poyenc Jul 23, 2024
48c7072
Apply RoPE to q_tile
poyenc Jul 23, 2024
56df4d6
Remove debug print code in kernel
poyenc Jul 23, 2024
bc7c7ee
Fix wrong knew/vnew appending positions
poyenc Jul 23, 2024
0925c0e
Use better naming for tile indices
poyenc Jul 23, 2024
7124f3e
Add make_tile_window() for adding distribution only
poyenc Jul 23, 2024
0e5cb6f
Skip code if # of block is more than needed
poyenc Jul 23, 2024
eb649a2
Move thread locating logics into policy
poyenc Jul 23, 2024
b275732
Remove always true static_assert()
poyenc Jul 23, 2024
d4606cf
Rename header
poyenc Jul 23, 2024
2192bbc
Rename RotaryEmbeddingEnum
poyenc Jul 23, 2024
fb80c7b
Extract rotary embedding logic out
poyenc Jul 23, 2024
ce5e0f1
Re-order parameters
poyenc Jul 23, 2024
99c1d46
Align naming of some tile size constants
poyenc Jul 23, 2024
52b4781
Rename more tile size constants
poyenc Jul 23, 2024
ca4b208
Fix wrong grid size
poyenc Jul 23, 2024
b11f92d
Fix wrong shape of knew_host/vnew_host
poyenc Jul 23, 2024
85bac93
Fix wrong index into knew_host/vnew_host
poyenc Jul 23, 2024
eb4ea3a
Fix wrong rotary_cos/rotary_sin memory size for Q
poyenc Jul 23, 2024
47a74f2
Extract Q/Knew vector size to helper methods
poyenc Jul 24, 2024
6f95239
Use different rotary_cos/rotary_sin distr for Q/Knew
poyenc Jul 24, 2024
5ea6071
Update host/device specifiers
poyenc Jul 24, 2024
3348131
Fix wrong data type for Q rotary_cos/rotary_sin
poyenc Jul 24, 2024
251f8cf
Merge branch 'develop' into feature/fmha-fwd-appendkv
poyenc Jul 24, 2024
a4da1e7
Remove RoPEComputeDataType type alias
poyenc Jul 24, 2024
59e1d9b
Shift rotary_cos/rotary_sin by cache_seqlen_k
poyenc Jul 24, 2024
c7b7b44
Add comment for why I just 't' for all padding flags
poyenc Jul 24, 2024
29c9b65
Align commit message to the real comment
poyenc Jul 24, 2024
d59e098
Fix wrong pipeline
poyenc Jul 24, 2024
8a73d33
Rename utility function
poyenc Jul 24, 2024
d84c915
Disable host verification if API not exist
poyenc Jul 24, 2024
08b4e8a
Fix wrong rope key for fp8 pipeline
poyenc Jul 24, 2024
f7fb3fa
Allow only apply RoPE on Q (without append KV)
poyenc Jul 24, 2024
2126d4d
Add append-kv smoke tests
poyenc Jul 24, 2024
5c733dc
Remove debug statements
poyenc Jul 24, 2024
8fb015b
Remove more debug statements
poyenc Jul 24, 2024
c50c36a
Re-arrange the 'set +x' command
poyenc Jul 24, 2024
bd28e96
Remove no-longer used method in pipeline
poyenc Jul 24, 2024
f053ae2
Add missing init code
poyenc Jul 24, 2024
4280a07
Refine pipeline padding settings
poyenc Jul 24, 2024
d41ff70
Enlarge rotary_dim limit (8 -> 16)
poyenc Jul 26, 2024
c1c50ee
Enlarge KPerThread for rotary_interleaved=false
poyenc Jul 26, 2024
94f430d
Update rotary_dim range in smoke_test_fwd.sh
poyenc Jul 26, 2024
e688d99
Merge remote-tracking branch 'origin/develop' into feature/fmha-fwd-a…
poyenc Jul 26, 2024
08d82ee
Merge branch 'develop' into feature/fmha-fwd-appendkv
poyenc Jul 30, 2024
3f71998
Merge branch 'develop' into feature/fmha-fwd-appendkv
poyenc Jul 31, 2024
e7969b9
Add template argument 'kIsPagedKV' for splitkv kernels
poyenc Aug 2, 2024
db95d25
Launch splitkv kernel if given page_block_size
poyenc Aug 2, 2024
baf4a61
Fix wrong kernel name
poyenc Aug 2, 2024
381f7e9
Merge branch 'develop' into feature/fmha-fwd-appendkv
poyenc Aug 4, 2024
90d84ea
Fix seqlen_k_min for pre-fill case (1 -> 0)
poyenc Aug 4, 2024
24cb604
Add copy_const<> type trait
poyenc Aug 5, 2024
55b77cf
Add another make_tile_window()
poyenc Aug 5, 2024
1c9d77b
Introduce 'TileWindowNavigator' types
poyenc Aug 5, 2024
ecaaa6f
Simplify TileWindowNavigator interfaces
poyenc Aug 5, 2024
8fea413
Fix tile window navigation bugs
poyenc Aug 5, 2024
3fc7279
Disable calling fmha_fwd()
poyenc Aug 5, 2024
bb78353
Remove ununnecessary data members
poyenc Aug 5, 2024
ab086bd
Simplify more make_tile_window() overloads
poyenc Aug 5, 2024
77dac77
Move V tile through TileWindowNavigator
poyenc Aug 5, 2024
8779716
Fix uneven split checking logic
poyenc Aug 6, 2024
4fed268
Move code after decide seqlen_q/seqlen_k
poyenc Aug 6, 2024
f9e2baf
Make sure we always start reading complete tile
poyenc Aug 6, 2024
12da00c
Use 128 as minimus page_block_size
poyenc Aug 6, 2024
faf6b0e
Fix wrong origin for bias
poyenc Aug 6, 2024
bd0d2f3
Add batch_stride_k/batch_stride_v in group mode
poyenc Aug 6, 2024
db31475
Unify origin
poyenc Aug 6, 2024
b989852
Add missing kernel arguments for group mode
poyenc Aug 6, 2024
15d0034
Add paged-kv codegen logic for appendkv kernels
poyenc Aug 7, 2024
443a528
Add block_table kernel args for appendkv kernel
poyenc Aug 7, 2024
7789b53
Add tile navigators to the appendkv kernel
poyenc Aug 7, 2024
78209c7
Fix wrong tensor descriptor lengths
poyenc Aug 7, 2024
26ed468
Pass re-created tile window to pipeline
poyenc Aug 7, 2024
838f995
Fix wrong strides for appendkv kernel
poyenc Aug 7, 2024
40f0d01
Allow transit tile_window to another page-block
poyenc Aug 7, 2024
f265742
Handle cross-page-block write
poyenc Aug 7, 2024
1b96dc2
Donot perform write again if already in last page-block
poyenc Aug 7, 2024
eda78d1
Merge branch 'develop' into feature/fmha-fwd-appendkv
poyenc Aug 7, 2024
55ce294
Always add fmha_fwd() api
poyenc Aug 7, 2024
b6c2f2f
Add missing group mode argument
poyenc Aug 7, 2024
cef9da0
Remove debug macro usages
poyenc Aug 7, 2024
655b13b
Rename option s_k_new to s_knew
poyenc Aug 7, 2024
291e9b4
Separate splitkv/non-splitkv args/traits
poyenc Aug 8, 2024
247e135
Remove fmha_fwd_dispatch()
poyenc Aug 8, 2024
9d9c5a6
Fix compilation errors
poyenc Aug 8, 2024
a0d2163
Remove dropout code in splitkv kernel
poyenc Aug 8, 2024
2f42e44
Allow problem types without define kHasDropout attr
poyenc Aug 8, 2024
677d9b2
Use generic lambda to init traits objects
poyenc Aug 8, 2024
c8f63d4
Separate more non-splitkv & splitkv traits/args
poyenc Aug 8, 2024
3e2b69e
Display more info for specific kernels
poyenc Aug 8, 2024
d3624a0
Merge branch 'develop' into feature/fmha-fwd-appendkv
poyenc Aug 8, 2024
e3a4bfb
Show more detailed warning message
poyenc Aug 8, 2024
9dddf6e
Rename 'max_num_blocks' to 'max_num_page_blocks'
poyenc Aug 8, 2024
d2f5d09
Remove no-longer used pipeline files
poyenc Aug 8, 2024
028d898
Wrap code by #if directives
poyenc Aug 8, 2024
9206808
Move functors to the begining of validation code
poyenc Aug 8, 2024
6a399ea
Use generic lambda to init all the api traits/args
poyenc Aug 8, 2024
822d5dc
Fix wrong seqlen for kvcache
poyenc Aug 8, 2024
e8603dc
Add missing comment
poyenc Aug 8, 2024
c54de64
Rename TileWindowNavigator to PageBlockNavigator
poyenc Aug 13, 2024
19c19d8
Only expose necessary methods (not attributes)
poyenc Aug 13, 2024
3dd6ef6
Re-order pipeline paremeters
poyenc Aug 13, 2024
d96752d
Refine smoke_test_fwd.sh
poyenc Aug 13, 2024
a8a2275
Fix wrong arugment count
poyenc Aug 13, 2024
370babc
Make tile window directly via PageBlockNavigator
poyenc Aug 13, 2024
9de0f35
Remove unused template paremeter
poyenc Aug 13, 2024
5805f5a
Remove group mode from appendkv kernel
poyenc Aug 16, 2024
a4c6029
Fix skcheck logic
poyenc Aug 16, 2024
aadd3ec
Fix wrong syntax in skcheck expr
poyenc Aug 16, 2024
f2b3620
Use meaningful options in smoke test
poyenc Aug 16, 2024
095819a
Remove options
poyenc Aug 16, 2024
5728c0b
Fix formatting
poyenc Aug 16, 2024
2523c8e
Fix more format
poyenc Aug 16, 2024
e6239e1
Re-organize bash functions
poyenc Aug 16, 2024
9c904b0
Pass cache_batch_idx to kernels
poyenc Aug 16, 2024
43b8100
Support cache_batch_idx in example
poyenc Aug 16, 2024
41fdf9b
Fix compilation error
poyenc Aug 16, 2024
51062ca
Merge remote-tracking branch 'origin/develop' into feature/fmha-fwd-a…
poyenc Aug 16, 2024
d3fd64c
Add more appendkv test
poyenc Aug 16, 2024
d52278a
Add more case for appendkv
poyenc Aug 16, 2024
34fea29
Fix unexisted attribute
rocking5566 Aug 16, 2024
352f6d5
Merge branch 'feature/fmha-fwd-appendkv' of github.com:ROCm/composabl…
poyenc Aug 16, 2024
c30d7f9
Remove 0 < seqlen_knew constraint
poyenc Aug 16, 2024
6b361f5
Clarify the case in warning message
poyenc Aug 18, 2024
cc52587
Remove macro checking
poyenc Aug 18, 2024
05157bf
Force batch mode when invoking appendkv & splitkv apis
poyenc Aug 18, 2024
48b7a5b
Fix mode overriding logics
poyenc Aug 18, 2024
3d3d73b
Fix wrong parameter name
poyenc Aug 18, 2024
996f46b
Randomize seqlen_k if use kvcache
poyenc Aug 18, 2024
e5db71c
Use randomized seqlen_k for kvcache
poyenc Aug 18, 2024
4cd3432
Avoid using too small rotary_cos & rotary_sin
poyenc Aug 18, 2024
a93c5e8
Rename parameter
poyenc Aug 18, 2024
8a856f5
Add seqlen_q & seqlen_k rules
poyenc Aug 18, 2024
90c2008
Add comment
poyenc Aug 18, 2024
9d5c33d
Add more comments
poyenc Aug 18, 2024
e8cd975
Fix compilation errors
poyenc Aug 18, 2024
f37cd41
Fix typo in comment
poyenc Aug 18, 2024
3fb77a0
Remove type argument
poyenc Aug 18, 2024
21c4df8
Avoid seqlen_k=0 for kvcache
poyenc Aug 19, 2024
3f0dab6
Revert "Avoid seqlen_k=0 for kvcache"
poyenc Aug 19, 2024
8166aa5
Fix wrong uneven split checking logics
poyenc Aug 19, 2024
b9a4ab0
Only randomize kvcache seqlen_k if 1 < batch
poyenc Aug 19, 2024
40a4d96
Return earlier if split is empty
poyenc Aug 19, 2024
60fe251
Revert "Only randomize kvcache seqlen_k if 1 < batch"
poyenc Aug 19, 2024
ee1445d
Re-order seqlen_k_start adjustment logics
poyenc Aug 19, 2024
eb1c8a2
Merge branch 'develop' into feature/fmha-fwd-appendkv
poyenc Aug 20, 2024
e23e6b5
Fix compilation errors
poyenc Aug 20, 2024
e44852b
Re-format script
poyenc Aug 20, 2024
d88ccc1
Merge branch 'develop' into feature/fmha-fwd-appendkv
poyenc Aug 20, 2024
6230a78
Find executable from folder automatically
poyenc Aug 20, 2024
8745f5f
Fix kvcache seqlen_k generating logic
poyenc Aug 20, 2024
3850028
Make comment more clear
poyenc Aug 20, 2024
73378ff
Fix wrong knew/vew appending logic on host
poyenc Aug 20, 2024
d3f550f
Add s_barrier to sync threads
poyenc Aug 22, 2024
ce0624a
Revert "Add s_barrier to sync threads"
poyenc Aug 22, 2024
d7603a3
Support only using 1 row of rotary_cos/rotary_sin
poyenc Aug 22, 2024
fd939bb
Rotate Q in different way
poyenc Aug 22, 2024
1b4fcf7
Unify tensor view creation logics
poyenc Aug 22, 2024
c24fba1
Fix wrong argument
poyenc Aug 22, 2024
8e62d6b
Add mask to switch how we use the rotary_cos/sin
poyenc Aug 22, 2024
e6c179f
Move attr from traits to problem
poyenc Aug 22, 2024
bb1c4ba
Move has_mask to fmha_fwd_appendkv_args
poyenc Aug 22, 2024
fc3b275
Support use uint32_t as SAD operand in Alibi<>
poyenc Aug 26, 2024
81a5412
Use sad_u32() in splitkv kernels
poyenc Aug 26, 2024
7fedc5c
Store tensor views in PageBlockNavigator
poyenc Aug 26, 2024
7774ae5
Use stored tensor view to update tile windows
poyenc Aug 26, 2024
99c3aaf
Enlarge tensor view size
poyenc Aug 26, 2024
6883006
Remove debug code
poyenc Aug 26, 2024
ff5ca5e
Fix wrong tensor view size
poyenc Aug 26, 2024
e322bf0
Wrap tensor view into PageBlockNavigator
poyenc Aug 26, 2024
fc91cfc
Add DataType member to PageBlockNavigator
poyenc Aug 26, 2024
7c25d18
Remove unnecessary member functions
poyenc Aug 26, 2024
da4bcaf
Refind macro use
poyenc Aug 26, 2024
c906293
Fix typo
poyenc Aug 26, 2024
9623c0c
Add blank line between directives and actual code
poyenc Aug 26, 2024
3be5f28
Re-format files
poyenc Aug 27, 2024
6fea8c7
Remove type in comment
poyenc Aug 27, 2024
91a316d
Merge branch 'develop' into feature/fmha-fwd-appendkv
poyenc Aug 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 37 additions & 3 deletions example/ck_tile/01_fmha/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,27 @@
# generate a list of kernels, but not actually emit files at config stage
# validate user-specified fmha_fwd API list
set(FMHA_FWD_KNOWN_APIS "fwd;fwd_splitkv;fwd_appendkv")
set(FMHA_FWD_ENABLE_APIS "fwd" CACHE STRING
"semicolon-separated list of APIs to generate (${FMHA_FWD_KNOWN_APIS}) & link, or \"all\".")
if(FMHA_FWD_ENABLE_APIS STREQUAL "all")
set(FMHA_FWD_ENABLE_APIS ${FMHA_FWD_KNOWN_APIS})
endif()

foreach(api ${FMHA_FWD_ENABLE_APIS})
if(NOT "${api}" IN_LIST FMHA_FWD_KNOWN_APIS)
message(FATAL_ERROR "${api} isn't a known api: ${FMHA_FWD_KNOWN_APIS}.")
endif()
endforeach()

# "fwd" is a must-have api for the fmha_fwd example, add it if not specified
if(NOT "fwd" IN_LIST FMHA_FWD_ENABLE_APIS)
list(APPEND FMHA_FWD_ENABLE_APIS "fwd")
endif()

string(REPLACE ";" "," FMHA_FWD_APIS "${FMHA_FWD_ENABLE_APIS}")
# generate a list of kernels, but not actually emit files at config sta
execute_process(
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--api fwd,fwd_splitkv --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt
--api ${FMHA_FWD_APIS} --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt
)

execute_process(
Expand All @@ -17,7 +37,7 @@ file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS)
add_custom_command(
OUTPUT ${FMHA_FWD_GEN_BLOBS}
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--api fwd,fwd_splitkv --output_dir ${CMAKE_CURRENT_BINARY_DIR}
--api ${FMHA_FWD_APIS} --output_dir ${CMAKE_CURRENT_BINARY_DIR}
)

add_custom_command(
Expand Down Expand Up @@ -60,6 +80,20 @@ else()
endif()
list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -fgpu-flush-denormals-to-zero)

# conditionally enable call to the fwd_splitkv API in fmha_fwd example
if("fwd_splitkv" IN_LIST FMHA_FWD_ENABLE_APIS)
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=1)
else()
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=0)
endif()

# conditionally enable call to the fwd_appendkv API in fmha_fwd example
if("fwd_appendkv" IN_LIST FMHA_FWD_ENABLE_APIS)
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=1)
else()
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=0)
endif()

# Allow comparing floating points directly in order to check sentinel values
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal)
list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-float-equal)
Expand Down
14 changes: 13 additions & 1 deletion example/ck_tile/01_fmha/codegen/cpp_symbol_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,18 @@ def get_mask_check_map(mask : str):
"dropout_wg16_storerandval" : "t.has_dropout == true && t.is_store_randval == true",
}

ROPE_MAP = {
"no" : "ck_tile::RotaryEmbeddingEnum::NONE",
"inter" : "ck_tile::RotaryEmbeddingEnum::INTERLEAVED",
"half" : "ck_tile::RotaryEmbeddingEnum::HALF_ROTATED"
}

ROPE_CHECK_MAP = {
"no" : "rope_enum::none",
"inter" : "rope_enum::interleaved",
"half" : "rope_enum::half_rotated"
}

MODE_MAP = {
"batch" : "false",
"group" : "true"
Expand All @@ -105,4 +117,4 @@ def get_mask_check_map(mask : str):
BOOL_MAP = {
"t" : "true",
"f" : "false"
}
}
Loading