[Issue]: failed to run the tune_flash.py #32
Comments
This seems to be due to API changes in upstream Triton. It is recommended to use the bundled Triton to do the tuning, since that is the actual compiler that generates the GPU code. We are moving to upstream Triton, but that is not going to happen immediately due to known bugs.
Thanks @xinyazhang. Is there any obvious performance gain after autotuning in our repo?
My understanding is that we need to set up a tune space like this https://github.com/michael-sandoval/aotriton/blob/b164a966c7eceac84cfda2c3719cbc3e5bcaa553/test/triton_attn_torch_function.py#L20 and then rebuild the project so that the build process tunes it automatically. Please correct me if I am wrong. If I want to tune seqlen_q and head_dim, how should I do that? Thanks
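For reference, a tune space in the style of the linked triton_attn_torch_function.py is a list of triton.Config objects attached to the kernel with triton.autotune. The sketch below is illustrative only: the knob names (BLOCK_M, BLOCK_N, pre_load_v, num_warps, num_stages) mirror the ones visible in the generated dispatcher later in this thread, but the concrete values and the kernel signature are assumptions, not the actual aotriton tune space.
# Illustrative tune-space sketch; values and signature are assumptions.
import triton
import triton.language as tl

TRITON_CONFIG_LIST = [
    triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'pre_load_v': True},
                  num_warps=4, num_stages=1),
    triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'pre_load_v': False},
                  num_warps=4, num_stages=1),
    triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'pre_load_v': True},
                  num_warps=4, num_stages=1),
]

@triton.autotune(configs=TRITON_CONFIG_LIST, key=['seqlen_q', 'seqlen_k'])
@triton.jit
def attn_fwd(Q, K, V, sm_scale, Out,
             seqlen_q, seqlen_k,
             BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
             pre_load_v: tl.constexpr):
    ...  # kernel body elided; only the tune-space wiring matters here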
Correct, and the right file to edit the tune space is
Currently it is not the case; the tuning database
Tuning against
The attached code is the generated autotuning dispatcher and hopefully will give you some idea of how it works internally:
// Copyright © 2023-2024 Advanced Micro Devices, Inc.
// SPDX-License-Identifier: MIT
// clang-format off
#define INCBIN_PREFIX g_aotriton_FAMILY_flash_KERNEL_attn_fwd_GPU_MI300X_
#define INCBIN_STYLE INCBIN_STYLE_SNAKE
#define mangle(x) g_aotriton_FAMILY_flash_KERNEL_attn_fwd_GPU_MI300X_ ## x ## _data
#define smangle(x) g_aotriton_FAMILY_flash_KERNEL_attn_fwd_GPU_MI300X_ ## x ## _size
#include "../shim.attn_fwd.h"
#include <aotriton/_internal/triton_kernel.h>
#include <incbin.h>
#include <iostream>
// ['Q', 'K', 'V', 'B', 'Out', 'encoded_softmax'] = *bf16:16 sm_scale = fp32 M = *fp32:16 ['stride_qz', 'stride_qh', 'stride_qm'] = u64:16 stride_qk = 1 ['stride_kz', 'stride_kh', 'stride_kn'] = u64:16 stride_kk = 1 ['stride_vz', 'stride_vh', 'stride_vk'] = u64:16 stride_vn = 1 ['stride_bz', 'stride_bh', 'stride_bm'] = u64:16 stride_bn = 1 ['stride_oz', 'stride_oh', 'stride_om'] = u64:16 stride_on = 1 ['seqlen_q', 'seqlen_k'] = i32 head_dim = u64 dropout_p = fp32 philox_seed = u64 philox_offset_base = u32 CAUSAL = False BLOCK_DMODEL = 128 ENABLE_DROPOUT = False RETURN_ENCODED_SOFTMAX = False PADDED_HEAD = False BIAS_TYPE = 0 ; BLOCK_M = 128 BLOCK_N = 64 pre_load_v = 1 ; num_warps=4 num_stages=1 waves_per_eu=1
#define CURRENT_ENTRY_PUBLIC Autotune_attn_fwd__A1__F208
INCBIN(F__Ptrbf16Align16_False_128_False_False_False_0__P__128_64_1__CO__warp4_stg1_wave1, "/home/xinyazha/aotriton/build/v2src/flash/gpu_kernel_image.attn_fwd/attn_fwd-Sig-F__^bf16@16_False_128_False_False_False_0__P__128_64_1__CO__warp4_stg1_wave1-Gpu-MI300X.hsaco.zst");
INCBIN(F__Ptrbf16Align16_False_128_False_False_False_0__P__128_64_0__CO__warp4_stg1_wave2, "/home/xinyazha/aotriton/build/v2src/flash/gpu_kernel_image.attn_fwd/attn_fwd-Sig-F__^bf16@16_False_128_False_False_False_0__P__128_64_0__CO__warp4_stg1_wave2-Gpu-MI300X.hsaco.zst");
INCBIN(F__Ptrbf16Align16_False_128_False_False_False_0__P__128_64_0__CO__warp4_stg1_wave1, "/home/xinyazha/aotriton/build/v2src/flash/gpu_kernel_image.attn_fwd/attn_fwd-Sig-F__^bf16@16_False_128_False_False_False_0__P__128_64_0__CO__warp4_stg1_wave1-Gpu-MI300X.hsaco.zst");
INCBIN(F__Ptrbf16Align16_False_128_False_False_False_0__P__128_64_1__CO__warp4_stg1_wave0, "/home/xinyazha/aotriton/build/v2src/flash/gpu_kernel_image.attn_fwd/attn_fwd-Sig-F__^bf16@16_False_128_False_False_False_0__P__128_64_1__CO__warp4_stg1_wave0-Gpu-MI300X.hsaco.zst");
INCBIN(F__Ptrbf16Align16_False_128_False_False_False_0__P__128_64_0__CO__warp4_stg1_wave0, "/home/xinyazha/aotriton/build/v2src/flash/gpu_kernel_image.attn_fwd/attn_fwd-Sig-F__^bf16@16_False_128_False_False_False_0__P__128_64_0__CO__warp4_stg1_wave0-Gpu-MI300X.hsaco.zst");
#ifndef NDEBUG
static const char* incbin_kernel_names[] = {
"F__Ptrbf16Align16_False_128_False_False_False_0__P__128_64_1__CO__warp4_stg1_wave1",
"F__Ptrbf16Align16_False_128_False_False_False_0__P__128_64_0__CO__warp4_stg1_wave2",
"F__Ptrbf16Align16_False_128_False_False_False_0__P__128_64_0__CO__warp4_stg1_wave1",
"F__Ptrbf16Align16_False_128_False_False_False_0__P__128_64_1__CO__warp4_stg1_wave0",
"F__Ptrbf16Align16_False_128_False_False_False_0__P__128_64_0__CO__warp4_stg1_wave0"
};
#endif
namespace { // Anonymous namespace
struct PerfFields {
int32_t BLOCK_M;
int32_t BLOCK_N;
bool pre_load_v;
};
PerfFields image_perf_list [] = {
{ .BLOCK_M = 128, .BLOCK_N = 64, .pre_load_v = 1 },
{ .BLOCK_M = 128, .BLOCK_N = 64, .pre_load_v = 0 },
{ .BLOCK_M = 128, .BLOCK_N = 64, .pre_load_v = 0 },
{ .BLOCK_M = 128, .BLOCK_N = 64, .pre_load_v = 1 },
{ .BLOCK_M = 128, .BLOCK_N = 64, .pre_load_v = 0 }
};
aotriton::TritonKernel image_list [] = {
{ mangle(F__Ptrbf16Align16_False_128_False_False_False_0__P__128_64_1__CO__warp4_stg1_wave1), smangle(F__Ptrbf16Align16_False_128_False_False_False_0__P__128_64_1__CO__warp4_stg1_wave1), { 256 , 1, 1 }, 34816 },
{ mangle(F__Ptrbf16Align16_False_128_False_False_False_0__P__128_64_0__CO__warp4_stg1_wave2), smangle(F__Ptrbf16Align16_False_128_False_False_False_0__P__128_64_0__CO__warp4_stg1_wave2), { 256 , 1, 1 }, 34816 },
{ mangle(F__Ptrbf16Align16_False_128_False_False_False_0__P__128_64_0__CO__warp4_stg1_wave1), smangle(F__Ptrbf16Align16_False_128_False_False_False_0__P__128_64_0__CO__warp4_stg1_wave1), { 256 , 1, 1 }, 34816 },
{ mangle(F__Ptrbf16Align16_False_128_False_False_False_0__P__128_64_1__CO__warp4_stg1_wave0), smangle(F__Ptrbf16Align16_False_128_False_False_False_0__P__128_64_1__CO__warp4_stg1_wave0), { 256 , 1, 1 }, 34816 },
{ mangle(F__Ptrbf16Align16_False_128_False_False_False_0__P__128_64_0__CO__warp4_stg1_wave0), smangle(F__Ptrbf16Align16_False_128_False_False_False_0__P__128_64_0__CO__warp4_stg1_wave0), { 256 , 1, 1 }, 34816 },
};
uint8_t lut[6][6] = {{0,1,2,3,2,0},
{1,2,2,4,0,0},
{2,2,1,0,0,0},
{2,0,0,3,0,3},
{3,3,3,0,0,3},
{0,0,0,3,3,3}};
}; // End of anonymous namespace
namespace aotriton::v2::flash::autotune {
// using aotriton::v2::flash::AttnFwdParams;
void CURRENT_ENTRY_PUBLIC::operator()(AttnFwdParams& params) {
auto seqlen_q_binned_index = [] (int x) {
if (x <= 64) return 0;
if (x <= 128) return 1;
if (x <= 256) return 2;
if (x <= 512) return 3;
if (x <= 1024) return 4;
if (x <= 2048) return 5;
return 5;
}(params.seqlen_q);
auto seqlen_k_binned_index = [] (int x) {
if (x <= 64) return 0;
if (x <= 128) return 1;
if (x <= 256) return 2;
if (x <= 512) return 3;
if (x <= 1024) return 4;
if (x <= 2048) return 5;
return 5;
}(params.seqlen_k);
auto kernel_index = lut[seqlen_q_binned_index][seqlen_k_binned_index];
params.selected_kernel = &image_list[kernel_index];
#ifndef NDEBUG
std::cerr << __FILE__ << " kernel_index = " << int(kernel_index) << std::endl;
params._debug_kernel_name = incbin_kernel_names[kernel_index];
#endif
const auto& perf = image_perf_list[kernel_index];
params.BLOCK_M = perf.BLOCK_M;
params.BLOCK_N = perf.BLOCK_N;
params.pre_load_v = perf.pre_load_v;
}
#undef CURRENT_ENTRY_PUBLIC
#undef mangle
#undef smangle
}
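To make the generated dispatcher above easier to follow, here is a small Python paraphrase of its selection logic (not aotriton API, just a rendering of the C++): both sequence lengths are bucketed into power-of-two bins capped at 2048, and the pair of bin indices looks up a kernel index in the lut table.
# Python paraphrase of the generated dispatcher above; not aotriton API.
LUT = [
    [0, 1, 2, 3, 2, 0],
    [1, 2, 2, 4, 0, 0],
    [2, 2, 1, 0, 0, 0],
    [2, 0, 0, 3, 0, 3],
    [3, 3, 3, 0, 0, 3],
    [0, 0, 0, 3, 3, 3],
]

def binned_index(seqlen: int) -> int:
    # Mirrors the lambdas in the C++: bins are <=64, <=128, ..., <=2048,
    # and anything larger falls into the last (2048) bin.
    for i, bound in enumerate((64, 128, 256, 512, 1024, 2048)):
        if seqlen <= bound:
            return i
    return 5

def select_kernel_index(seqlen_q: int, seqlen_k: int) -> int:
    return LUT[binned_index(seqlen_q)][binned_index(seqlen_k)]

# A 4096x4096 problem lands in bin (5, 5) and reuses a kernel tuned for the
# <=2048 bucket, which is why longer sequences benefit from retuning.
print(select_kernel_index(4096, 4096))  # -> 3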
Thanks @xinyazhang. I can use tritonsrc/tune_flash.py to autotune the database for longer sequences (4096 to 16384) and for head_dim=128, which llama-7b uses. By the way, it looks like we need to use the scripts under tritonsrc to get better performance.
An error message would be helpful. However, the most likely cause is that the database got truncated; the updated database should be roughly the same size as the original.
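If it helps to confirm truncation, here is a minimal sanity check, assuming the tuning database is a SQLite file (an assumption here; the actual file name and location in the repo are not shown, so the path is passed on the command line):
# Minimal sketch, assuming the tuning database is a SQLite file. A truncated
# file typically fails the integrity check or raises sqlite3.DatabaseError.
import sqlite3
import sys

def check_tuning_db(path: str) -> None:
    con = sqlite3.connect(path)
    try:
        (status,) = con.execute("PRAGMA integrity_check;").fetchone()
        print(f"{path}: integrity_check -> {status}")
        tables = [name for (name,) in con.execute(
            "SELECT name FROM sqlite_master WHERE type='table';")]
        for name in tables:
            (count,) = con.execute(f'SELECT COUNT(*) FROM "{name}";').fetchone()
            print(f"  table {name}: {count} rows")
    finally:
        con.close()

if __name__ == "__main__":
    check_tuning_db(sys.argv[1])  # pass the path to your tuning database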
It seems the HSACO kernel was not compiled correctly. You probably want to remove
It seems your tuning database contains an entry with
Sure. Actually, there is no need to tune seqlen_q=16384 with seqlen_k=2048; seqlen_q and seqlen_k should be equal in our cases. It looks like we need to set them up manually in tune_flash.py.
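As a rough illustration of restricting the tuned shapes manually, a hypothetical problem-size list could look like the following; the variable names and the way tune_flash.py actually consumes such a list are assumptions, not the script's real interface.
# Hypothetical sketch: emit only square (seqlen_q == seqlen_k) problems
# instead of the full cross product. Names are illustrative and do not
# match tune_flash.py's actual interface.
SEQLENS = [1024, 2048, 4096, 8192, 16384]
HEAD_DIMS = [64, 128]

def problem_sizes():
    for seqlen in SEQLENS:
        for head_dim in HEAD_DIMS:
            yield {"seqlen_q": seqlen, "seqlen_k": seqlen, "head_dim": head_dim}

for p in problem_sizes():
    print(p)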
I have tuned the database with equal seqlen_q and seqlen_k, but it gives me a similar issue to this:
The database is here:
BTW, the latest repo, after the varlen check-in, can't run tune_flash.py, failing with the following issue:
Thanks
@jinsong-mao I found the problem. For a pair of
It is possible to fix this, but our whole team got reassigned to another emergency task, and the ETA is hard to estimate.
@xinyazhang, I think it's better to fix it out of the box. The typical seqlen is larger than 2048 in most cases, and it looks like the default database only covers head_dim=64; even the smallest llama2-7b has head_dim=128, and the official flash attention (version 2.5.9) now supports head_dim=256, which gives a perf gain for larger models.
Problem Description
Hi,
I can't run tritonsrc/tune_flash.py to autotune the flash attention kernel with one specific problem size (just an example); the error message is like this:
We don't know what sequence lengths, heads, and head_dim the pre-optimized database covers, and we may need to tune our own sequence lengths with this script. However, we can't run it because of the above issue. Am I understanding correctly? Can we fine-tune the kernels for different seqlen and heads?
Thanks
Operating System
ubuntu 22
CPU
5900x
GPU
AMD Instinct MI300X
ROCm Version
ROCm 6.1.0
ROCm Component
No response
Steps to Reproduce
No response
(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support
No response
Additional Information
No response