[Enquiry] developing Flash Attention Transformer example using Hidet #281
Comments
Hi @keneoneth, very glad to see that you have tried Hidet and used it to implement flash attention, even before we have written the documentation for Hidet Script! You mentioned that you tried the "dlopen" method to load the generated "lib.so"; can I know what error you encountered? To use hidet generated kernels in other packages,
FYI: my colleague @hjjq has written a flash attention implementation in hidet (see https://github.com/hidet-org/hidet/blob/main/python/hidet/graph/ops/attention/attention.py) and it can serve as a reference. If you only want to benchmark the performance, it is recommended to do this directly in Python, which would be much easier. Let me know if you have other questions about using hidet; I am happy to help.
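(As a side note on the Python benchmarking suggestion above: below is a minimal sketch of timing a plain PyTorch attention forward pass with CUDA events, using the same B, H, N, d shapes as the example that follows. It is only a baseline-timing illustration and does not use any hidet-specific API.)
import torch

B, H, N, d = 1, 16, 512, 128
Q = torch.randn(B, H, N, d, dtype=torch.float16, device='cuda')
K = torch.randn(B, H, N, d, dtype=torch.float16, device='cuda')
V = torch.randn(B, H, N, d, dtype=torch.float16, device='cuda')

def reference_attention(Q, K, V):
    # naive attention: softmax(Q @ K^T) @ V, no 1/sqrt(d) scaling (matching gen_gold below)
    S = Q @ K.transpose(-2, -1)
    P = torch.softmax(S.float(), dim=-1).half()
    return P @ V

start = torch.cuda.Event(enable_timing=True)
stop = torch.cuda.Event(enable_timing=True)
torch.cuda.synchronize()
start.record()
O = reference_attention(Q, K, V)
stop.record()
torch.cuda.synchronize()
print(f'reference attention: {start.elapsed_time(stop):.3f} ms')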
Thanks @yaoyaoding for your help. I took a look at method 2), the source code of hidet_runtime.so, and updated my code as below. The kernels work properly now 👍. This is my updated flash_attention_example.py:
import os
import math
import time
import numpy as np
import torch
import torch.nn as nn
torch.manual_seed(123)
# NOTE: this script is a simplified implementation of the following research work using Hidet
# Dao, T., Fu, D., Ermon, S., Rudra, A., & Ré, C. (2022). Flashattention: Fast and memory-efficient exact attention with io-awareness. Advances in Neural Information Processing Systems, 35, 16344-16359.
# link to paper: https://arxiv.org/abs/2205.14135
import hidet
from hidet.ir.compute import compute, reduce
from hidet.ir.task import Task
from hidet.ir.func import IRModule
from hidet.ir.primitives.cuda.atomic import atomic_add
from hidet.lang import f16, spatial, repeat, tensor, attr, grid, printf
from hidet.lang.cuda import blockIdx, threadIdx, syncthreads
from hidet.graph.ops.definitions.utils import input_like
from hidet.ir.expr import cast, address
from hidet.ir.primitives import exp, max, printf
from hidet import driver
from hidet.runtime.module import CompiledTaskCache
HIDET_BUILD_PATH = os.path.join(os.environ['HIDET_HOME'],"build/lib")
# define Flash Attention Task
class FlashAttentionTask(Task):
def allow_epilogue(self) -> bool:
return False
def flash_attention_implement_cuda(self, working_dir: str) -> IRModule:
# override this method to use template-based scheduling
return flash_attention_schedule(self)
# Require: matrices Q, K, V of shape N x d in HBM, on-chip SRAM of size M.
# NOTE: typical SRAM size 100 kB, default to 48 kB
# NOTE: max thread num is set to 1024
def __init__(self,N=512,d=128,H=16,B=1,M=48*1024,ratio=12,max_thread_num=1024,disable_flash_attention=False):
# 1. set block sizes Bc = ceil(M/(ratio*d)), Br = min(ceil(M/(ratio*d)), d) (the paper uses ratio = 4)
Bc = math.ceil(M/(ratio*d))
Br = min(math.ceil(M/(ratio*d)),d)
Tr = math.ceil(N/Br)
Tc = math.ceil(N/Bc)
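# With the default arguments (N=512, d=128, M=48*1024, ratio=12) this works out to
# Bc = ceil(49152 / (12*128)) = 32, Br = min(32, 128) = 32,
# Tr = Tc = ceil(512 / 32) = 16, so the kernel launches Tr = 16 blocks of
# Br * Bc = 1024 threads each (the CUDA per-block thread limit).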
GLOBAL_Q = input_like(hidet.randn([N, d], dtype='float16', device='cuda'),name='GLOBAL_Q')
GLOBAL_K = input_like(hidet.randn([N, d], dtype='float16', device='cuda'),name='GLOBAL_K')
GLOBAL_V = input_like(hidet.randn([N, d], dtype='float16', device='cuda'),name='GLOBAL_V')
def normal_transformer():
matmulQK = compute(
name = 'GLOBAL_QK',
shape = [N, N],
fcompute = lambda i, j: reduce(
shape=[d],
fcompute=lambda k: GLOBAL_Q[i, k] * GLOBAL_K[j, k],
reduce_type='sum',
)
)
max_val = lambda i : reduce(shape=[N], fcompute=lambda j: matmulQK[i,j], reduce_type='max')
S = compute(
name = 'S',
shape = [N, N],
fcompute = lambda i,j: matmulQK[i,j] - max_val(i)
)
exp_s = compute(
name = 'exp_s',
shape = [N, N],
fcompute = lambda i,j: exp(S[i,j])
)
exp_sum = lambda i : reduce(shape=[N], fcompute=lambda j: exp_s[i,j], reduce_type='sum')
softmax = compute('softmax', shape=[N,N], fcompute=lambda i,j: exp_s[i,j] / exp_sum(i))
matmulPV = compute(
name = 'GLOBAL_O',
shape = [N, d],
fcompute = lambda i, j: reduce(
shape=[N],
fcompute=lambda k: softmax[i, k] * GLOBAL_V[k, j],
reduce_type='sum',
)
)
return matmulPV
super().__init__(
name='flash_attention_task',
inputs=[GLOBAL_Q,GLOBAL_K,GLOBAL_V],
outputs=[normal_transformer()],
attributes={
'B' : B,
'H' : H,
'N' : N,
'd' : d,
'Bc' : Bc,
'Br' : Br,
'Tc' : Tc,
'Tr' : Tr,
'BLK' : Tr,
'THD' : Br * Bc,
'MAX_THD' : max_thread_num
},
)
if not disable_flash_attention:
self.implement_cuda = self.flash_attention_implement_cuda
self.define = f'-DRUN_FLASH_ATTN -DHIDET_BUILD_PATH=\\"{HIDET_BUILD_PATH}\\"'
else:
self.define = f'-DHIDET_BUILD_PATH=\\"{HIDET_BUILD_PATH}\\"'
# define custom schedule
def flash_attention_schedule(task:FlashAttentionTask) -> IRModule:
print_debug = False
B = task.attrs['B']
H = task.attrs['H']
N = task.attrs['N']
d = task.attrs['d']
Bc = task.attrs['Bc']
Br = task.attrs['Br']
Tr = task.attrs['Tr']
Tc = task.attrs['Tc']
dims = ( task.attrs['BLK'] )
threads = task.attrs['THD']
assert threads <= task.attrs['MAX_THD'], f'err: {threads} exceeds {task.attrs["MAX_THD"]}'
assert d % Bc == 0, 'err: d is not divisible by Bc'
assert d % Br == 0, 'err: d is not divisible by Br'
largest_fp16_value = 65504
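# 65504 is the largest finite float16 value; -largest_fp16_value stands in
# for -inf when initializing the running row maximum m_i below.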
print(f'task.attrs {task.attrs}')
# define the tensor program
with hidet.script_module() as module:
"""Flash attention kernel."""
@hidet.script
def QK_matmul_compute(A:f16[Br,d],B:f16[d,Bc],C:f16[Br,Bc]):
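# Each thread owns one (m, n) element of the Br x Bc score tile: the tile is
# zeroed first, then every thread walks k = 0..d-1 and accumulates
# A[m, k] * B[k, n] into shared C via atomic_add.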
for m,n in spatial(Br,Bc).on(threadIdx.x):
C[m,n] = 0.0
syncthreads()
for m,k,n in spatial(Br,1,Bc).repeat(1,d,1).on(threadIdx.x):
atomic_add(~C[m,n],A[m,k] * B[k,n])
syncthreads()
@hidet.script
def PV_matmul_compute(A:f16[Br,Bc],B:f16[Bc,d],C:f16[Br,d]):
for m,n in spatial(Br,Bc).repeat(1,d//Bc).on(threadIdx.x):
C[m,n] = 0.0
syncthreads()
for m,k,n in spatial(Br,1,Bc).repeat(1,Bc,d//Bc).on(threadIdx.x):
atomic_add(~C[m,n],A[m,k] * B[k,n])
syncthreads()
@hidet.script
def rowmax_compute(A:f16[Br,Bc],M:f16[Br],T:f16[Br,Bc]):
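# Row-wise parallel tree reduction: each row's Bc entries are copied into the
# scratch tile T, then pairs at stride k = 1, 2, 4, ... are combined with max
# until column 0 holds the row maximum, which is written to M.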
for i,j in spatial(Br,Bc).on(threadIdx.x):
T.write([i,j],A[i,j],protected=True)
syncthreads()
for i,j in spatial(Br,Bc).on(threadIdx.x):
k = 1
while k < Bc:
if j % (k*2) == 0:
T.write([i,j],max(T[i,j],T[i,j+k]),protected=True)
syncthreads()
k *= 2
for i in spatial(Br).on(threadIdx.x):
if threadIdx.x < Br:
M[i] = T[i,0]
syncthreads()
@hidet.script
def rowsum_compute(A:f16[Br,Bc],L:f16[Br],T:f16[Br,Bc]):
for i,j in spatial(Br,Bc).on(threadIdx.x):
T.write([i,j],A[i,j],protected=True)
syncthreads()
for i,j in spatial(Br,Bc).on(threadIdx.x):
k = 1
while k < Bc:
if j % (k*2) == 0:
T.write([i,j],(T[i,j]+T[i,j+k]),protected=True)
syncthreads()
k *= 2
for i in spatial(Br).on(threadIdx.x):
if threadIdx.x < Br:
L[i] = T[i,0]
syncthreads()
@hidet.script
def local_softmax_compute(S:f16[Br,Bc],M:f16[Br]):
for i,j in spatial(Br,Bc).on(threadIdx.x):
if False and blockIdx.x==0:
printf("S[i,j] before %d %d %d %d : %f - %f\n",blockIdx.x,threadIdx.x,i,j,cast(S[i,j],"float32"),cast(M[i],"float32"))
S[i,j] = exp(S[i,j] - M[i])
if False and blockIdx.x==0:
printf("S[i,j] %d %d %d %d : %f\n",blockIdx.x,threadIdx.x,i,j,cast(S[i,j],"float32"))
syncthreads()
@hidet.script
def local_update_compute(M:f16[Br],M_new:f16[Br],M_local:f16[Br],L:f16[Br],L_new:f16[Br],L_local:f16[Br]):
for i in spatial(Br).on(threadIdx.x):
if threadIdx.x < Br:
M_new[i] = max(M[i],M_local[i])
L_new[i] = exp(M[i] - M_new[i]) * L[i] + exp(M_local[i] - M_new[i]) * L_local[i]
syncthreads()
@hidet.script
def global_update_compute(PV:f16[Br,d],O:f16[Br,d],M_local:f16[Br],M_new:f16[Br],M:f16[Br],L_new:f16[Br],L:f16[Br]):
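# Output update as written in the loop below:
# O[i, j] <- (1/l_new_i) * l_i * exp(m_i - m_new_i) * O[i, j]
#            + exp(m'_ij - m_new_i) * PV[i, j]
# where PV already holds the P'_ij @ V_j product for this tile.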
for i,j in spatial(Br,Bc).repeat(1,(d//Bc)).on(threadIdx.x):
O.write(
[i,j],
((L_new[i]**-1) * (L[i]*exp(M[i]-M_new[i])) * O[i,j]) + (exp(M_local[i]-M_new[i]) * PV[i,j]),
protected=True
)
syncthreads()
@hidet.script
def flash_attention_kernel(
Q: f16[N,d],
K: f16[N,d],
V: f16[N,d],
O: f16[N,d]
):
attr.cuda_grid_dim = dims
attr.cuda_block_dim = threads
# Init O=(0), N x d in HBM
for i,j in spatial(Br,Bc).repeat(1,(d//Bc)).on(threadIdx.x):
offset_i = blockIdx.x * (Br)
O[offset_i:,:].write([i,j], 0, protected=True)
syncthreads()
smem_q = tensor('shared', 'float16', [Br, d])
smem_k = tensor('shared', 'float16', [d, Bc]) # transposed
smem_v = tensor('shared', 'float16', [Bc, d])
smem_o = tensor('shared', 'float16', [Br, d])
smem_l = tensor('shared', 'float16', [Br])
smem_l_local = tensor('shared', 'float16', [Br])
smem_l_new = tensor('shared', 'float16', [Br])
smem_m = tensor('shared', 'float16', [Br])
smem_m_local = tensor('shared', 'float16', [Br])
smem_m_new = tensor('shared', 'float16', [Br])
smem_sp = tensor('shared', 'float16', [Br,Bc])
smem_pv = tensor('shared', 'float16', [Br,d])
smem_temp = tensor('shared', 'float16', [Br,Bc])
for a,b in spatial(Br,Bc).repeat(1,(d//Bc)).on(threadIdx.x):
# load Qi from HBM to on-chip SRAM
# initialization of o,l,m
offset_i = blockIdx.x * (Br)
smem_q[a,b] = Q[offset_i:,:].read([a,b],protected=True)
smem_o[a,b] = 0
smem_l[a] = 0
smem_m[a] = -largest_fp16_value
syncthreads()
if print_debug and (blockIdx.x==0 and threadIdx.x==0):
idx = 0
for i,j in grid(Br,d):
printf("idx: %d, Q val: %f\n",idx,cast(smem_q[i,j],"float32"))
idx += 1
syncthreads()
for j in grid(Tc):
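# Outer loop over the Tc key/value tiles; the query tile index is fixed per
# thread block (offset_i = blockIdx.x * Br), mirroring Algorithm 1 of the paper.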
for a,b in spatial(Bc,Br).repeat(1,(d//Br)).on(threadIdx.x):
# load Kj,Vj from HBM to on-chip SRAM
offset_j = j * (Bc)
smem_k[b,a] = K[offset_j:,:].read([a,b],protected=True)
smem_v[a,b] = V[offset_j:,:].read([a,b],protected=True)
syncthreads()
if print_debug and (blockIdx.x==0 and threadIdx.x==0):
idx = 0
for i,j in grid(d,Bc):
printf("idx: %d, K val: %f\n",idx,cast(smem_k[i,j],"float32"))
idx += 1
for i,j in grid(Bc,d):
printf("idx: %d, V val: %f\n",idx,cast(smem_v[i,j],"float32"))
idx += 1
syncthreads()
# on chip, compute Sij = Qi @ (Kj)^T, Br X Bc
QK_matmul_compute(smem_q,smem_k,smem_sp)
if print_debug and (blockIdx.x==0 and threadIdx.x==0):
idx = 0
for i,j in grid(Br,Bc):
printf("idx: %d, S val: %f\n",idx,cast(smem_sp[i,j],"float32"))
idx += 1
syncthreads()
# on chip, compute m'_ij = rowmax(Sij), Br; Pij = exp(Sij - m'_ij), Br x Bc (pointwise); l'_ij = rowsum(P'ij), Br
rowmax_compute(smem_sp,smem_m_local,smem_temp)
if print_debug and (blockIdx.x==0 and threadIdx.x==0):
for i in grid(Br):
printf("i: %d M val: %f\n",i,cast(smem_m_local[i],"float32"))
# for j in grid(Bc):
# printf("j: %d, S val: %f\n",j,cast(smem_sp[i,j],"float32"))
syncthreads()
local_softmax_compute(smem_sp,smem_m_local)
if print_debug and (blockIdx.x==0 and threadIdx.x==0):
idx = 0
for i,j in grid(Br,Bc):
printf("idx: %d, P val: %f\n",idx,cast(smem_sp[i,j],"float32"))
idx += 1
syncthreads()
rowsum_compute(smem_sp,smem_l_local,smem_temp)
if print_debug and (blockIdx.x==0 and threadIdx.x==0):
for i in grid(Br):
printf("i: %d L val: %f\n",i,cast(smem_l_local[i],"float32"))
# for j in grid(Bc):
# printf("j: %d, P val: %f\n",j,cast(smem_sp[i,j],"float32"))
syncthreads()
# on chip, compute m_new_i = max(m_i,m'_ij), Br; l_new_i = e^(m_i - m_new_i) * l_i + e^(m'_ij - m_i_new) * l'_ij, Br
local_update_compute(smem_m,smem_m_new,smem_m_local,smem_l,smem_l_new,smem_l_local)
if print_debug and (blockIdx.x==0 and threadIdx.x==0):
for i in grid(Br):
printf("i: %d smem_m val: %f\n",i,cast(smem_m[i],"float32"))
printf("i: %d smem_m_new val: %f\n",i,cast(smem_m[i],"float32"))
printf("i: %d smem_m_local val: %f\n",i,cast(smem_m[i],"float32"))
printf("i: %d smem_l val: %f\n",i,cast(smem_m[i],"float32"))
printf("i: %d smem_l_new val: %f\n",i,cast(smem_m[i],"float32"))
printf("i: %d smem_l_local val: %f\n",i,cast(smem_m[i],"float32"))
syncthreads()
# write Oi = diag(l_i_new)^-1 * (diag(l_i)*e^(m_i-m_i_new) @ Oi + e^*m'_ij-m_i_new*(P'ij @ Vj))
PV_matmul_compute(smem_sp,smem_v,smem_pv)
if print_debug and (blockIdx.x==0 and threadIdx.x==0):
idx = 0
for i,j in grid(Br,d):
printf("idx: %d, PV val: %f\n",idx,cast(smem_pv[i,j],"float32"))
idx += 1
syncthreads()
global_update_compute(smem_pv,smem_o,smem_m_local,smem_m_new,smem_m,smem_l_new,smem_l)
if j + 1 == Tc:
for i,j in spatial(Br,Bc).repeat(1,(d//Bc)).on(threadIdx.x):
offset_i = blockIdx.x * (Br)
O[offset_i:,:].write([i,j], smem_o[i,j], protected=True)
syncthreads()
# write l_i = l_i_new, m_i = m_i_new
for i in spatial(Br).on(threadIdx.x):
if threadIdx.x < Br:
smem_m[i] = smem_m_new[i]
smem_l[i] = smem_l_new[i]
syncthreads()
if print_debug and (blockIdx.x==15 and threadIdx.x==0):
idx = 0
for i,j in grid(Br,d):
offset_i = blockIdx.x * (Br)
printf("blockIdx %d : output idx: %d, val: %f\n",blockIdx.x,idx,cast(O[offset_i+i,j],"float32"))
idx += 1
syncthreads()
return
@hidet.script
def flash_attention_launch_func(
G_Q: f16[B, H, N, d],
G_K: f16[B, H, N, d],
G_V: f16[B, H, N, d],
G_O: f16[B, H, N, d]
):
# NOTE: this section needs to be written in flash_attention_main.cu
for b,h in grid(B,H):
flash_attention_kernel(
address(G_Q[b,h,0,0]),
address(G_K[b,h,0,0]),
address(G_V[b,h,0,0]),
address(G_O[b,h,0,0])
)
# build ir module
ir_module = module.ir_module()
return ir_module
# gen Python gold data as reference
def gen_gold(attrs,r1=-3,r2=3):
Q = torch.FloatTensor(attrs['B'],attrs['H'],attrs['N'],attrs['d']).uniform_(r1, r2).half()
K = torch.FloatTensor(attrs['B'],attrs['H'],attrs['N'],attrs['d']).uniform_(r1, r2).half()
V = torch.FloatTensor(attrs['B'],attrs['H'],attrs['N'],attrs['d']).uniform_(r1, r2).half()
t = time.process_time()
Q.half().numpy().tofile('mat_Q.bin')
K.half().numpy().tofile('mat_K.bin')
V.half().numpy().tofile('mat_V.bin')
S = torch.from_numpy(Q.numpy() @ torch.transpose(K, -2, -1).numpy())
row_max, _ = torch.max(S,dim=-1)
S = torch.from_numpy(np.exp((S - row_max.reshape(attrs['B'],attrs['H'],attrs['N'],1)).numpy()))
row_sum = torch.sum(S,dim=-1).reshape(attrs['B'],attrs['H'],attrs['N'],1)
P = S / row_sum
# TODO: test with softmax float precision
# P = nn.Softmax(dim=-1)(S.float()).half()
O = torch.from_numpy(P.numpy() @ V.numpy())
elapsed_time = (time.process_time() - t)*1000
print(f"Python gold gen run elapsed time {round(elapsed_time,3)} msec")
O.half().numpy().tofile('gold_mat_O.bin')
# run task
def run_task(disable_flash_attention=False):
# clear cache
driver.compiled_task_cache = CompiledTaskCache()
# define the task here
flash_attention_task = FlashAttentionTask(disable_flash_attention=disable_flash_attention)
# build the task
ret = flash_attention_task.build('cuda')
# copy source file and lib to current directory
source_path = ret.src_path
library_path = ret.lib_path
print(f'source_path {source_path} library_path {library_path}')
import shutil
shutil.move(source_path,os.path.join("./","flash_attention_"+os.path.basename(source_path)))
shutil.move(library_path,os.path.join("./","flash_attention_"+os.path.basename(library_path)))
# generate golden data
gen_gold(flash_attention_task.attrs)
def exe_f(command='', shell=True):
print(f'running {command}')
import subprocess
process = subprocess.Popen(command, shell=shell)
code = process.wait()
process.communicate()
return code
# launch testcase flash_attention_main.cu
HIDET_CUDA_INCLUDE_PATH = "../cuda-samples-master/Common/"
CUDA_SAMPLES_INCLUDE_PATH = "../../include/"
ret = exe_f(f'nvcc flash_attention_main.cu {flash_attention_task.define} -gencode arch=compute_86,code=sm_86 -I {CUDA_SAMPLES_INCLUDE_PATH} -I {HIDET_CUDA_INCLUDE_PATH} -std=c++11 -o fa.out && ./fa.out -BATCH={flash_attention_task.attrs["B"]} -HEAD={flash_attention_task.attrs["H"]} -BLK={flash_attention_task.attrs["BLK"]} -THD={flash_attention_task.attrs["THD"]}')
print('test done' if ret==0 else 'test error')
# main function
if __name__ == '__main__':
# normal approach execution
run_task(disable_flash_attention=True)
# flash attention approach execution
run_task(disable_flash_attention=False)
And here is my flash_attention_main.cu, the code that loads the generated library and calls hidet_launch:
// System includes
#include <stdio.h>
#include <sys/stat.h>
#include <dlfcn.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <assert.h>
#include <string>
#include <vector>
#include <fstream> // for std::ifstream in load_data
// CUDA runtime
#include <cuda_runtime.h>
#include <cuda_profiler_api.h>
// Helper functions and utilities to work with CUDA,
#include <helper_functions.h>
#include <helper_cuda.h>
#include <cuda_fp16.h>
typedef void (*hidet_launch_t)(int32_t num_args, int32_t * __restrict__ arg_types, void* * __restrict__ args);
typedef void (*set_cuda_stream_t)(void*);
typedef void (*register_callback_t)(const char* name, void *func_ptr);
// typedef uint8_t* (*cudaMallocCallback_t)(uint64_t nbytes);
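// Function-pointer types for the symbols resolved at runtime via dlsym:
// "hidet_launch" from the generated lib.so, and "set_cuda_stream" /
// "register_callback" from libhidet_runtime.so.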
void *cuda_lib_handle;
void *runtime_lib_handle;
hidet_launch_t hidet_launch_func;
void * allocate_cuda_storage_func;
uint8_t * cudaMallocCallback(uint64_t nbytes) {
uint8_t * buffer;
checkCudaErrors(cudaMalloc(reinterpret_cast<void **>(&buffer), nbytes));
return buffer;
}
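// Backs the "allocate_cuda_storage" callback registered in setup_kernel_run()
// below with a plain cudaMalloc, so the hidet runtime obtains device memory
// through this host program.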
void set_kernel_func() {
printf("setting kernel func ...\n");
char *error;
cuda_lib_handle = dlopen("./flash_attention_lib.so", RTLD_LAZY | RTLD_LOCAL);
if (!cuda_lib_handle)
{
fprintf(stderr, "%s\n", dlerror());
exit(EXIT_FAILURE);
}
dlerror();
hidet_launch_func = (hidet_launch_t) dlsym(cuda_lib_handle, "hidet_launch");
if ((error = dlerror()) != NULL)
{
fprintf(stderr, "%s\n", error);
exit(EXIT_FAILURE);
}
}
// #define STRINGIFY(x) #x
void setup_kernel_run(cudaStream_t & stream) {
printf("setting up kernel run ...\n");
char *error;
std::string libpath = HIDET_BUILD_PATH "/libhidet_runtime.so";
printf("hidet runtime libpath %s", libpath.c_str());
runtime_lib_handle = dlopen(libpath.c_str(), RTLD_LAZY | RTLD_LOCAL);
if (!runtime_lib_handle)
{
fprintf(stderr, "%s\n", dlerror());
exit(EXIT_FAILURE);
}
dlerror();
set_cuda_stream_t set_cuda_stream_func = (set_cuda_stream_t) dlsym(runtime_lib_handle, "set_cuda_stream");
if ((error = dlerror()) != NULL)
{
fprintf(stderr, "%s\n", error);
exit(EXIT_FAILURE);
}
// set stream
set_cuda_stream_func(stream);
register_callback_t register_callback_func = (register_callback_t) dlsym(runtime_lib_handle, "register_callback");
if ((error = dlerror()) != NULL)
{
fprintf(stderr, "%s\n", error);
exit(EXIT_FAILURE);
}
assert (register_callback_func!=nullptr);
register_callback_func("allocate_cuda_storage", (void *) (cudaMallocCallback));
register_callback_func("cuda_memset", (void *) (cudaMemset));
}
// set_cuda_stream
// test function, execute kernel, compare with gold data
int flash_attention_test(
unsigned int B, unsigned int H,
unsigned int block_size, unsigned int thread_size,
half *h_Q, unsigned int size_Q,
half *h_K, unsigned int size_K,
half *h_V, unsigned int size_V,
half *h_gold_O, unsigned int size_O)
{
// set up kernel function
set_kernel_func();
assert(hidet_launch_func!=nullptr);
cudaStream_t stream;
const unsigned int BH = B * H;
// Allocate device memory
half *d_Q, *d_K, *d_V, *d_O, *h_O;
checkCudaErrors(cudaMallocHost(&h_O, size_O * sizeof(half)));
if (h_O == NULL)
{
fprintf(stderr, "Failed to allocate host matrix O!\n");
exit(EXIT_FAILURE);
}
checkCudaErrors(cudaMalloc(reinterpret_cast<void **>(&d_Q), size_Q * sizeof(half)));
checkCudaErrors(cudaMalloc(reinterpret_cast<void **>(&d_K), size_K * sizeof(half)));
checkCudaErrors(cudaMalloc(reinterpret_cast<void **>(&d_V), size_V * sizeof(half)));
checkCudaErrors(cudaMalloc(reinterpret_cast<void **>(&d_O), size_O * sizeof(half)));
// Allocate CUDA events that we'll use for timing
cudaEvent_t start, stop;
checkCudaErrors(cudaEventCreate(&start));
checkCudaErrors(cudaEventCreate(&stop));
checkCudaErrors(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
// copy host memory to device
checkCudaErrors(
cudaMemcpyAsync(d_Q, h_Q, size_Q * sizeof(half), cudaMemcpyHostToDevice, stream));
checkCudaErrors(
cudaMemcpyAsync(d_K, h_K, size_K * sizeof(half), cudaMemcpyHostToDevice, stream));
checkCudaErrors(
cudaMemcpyAsync(d_V, h_V, size_V * sizeof(half), cudaMemcpyHostToDevice, stream));
const unsigned int k_size_Q = (size_Q / BH);
const unsigned int k_size_K = (size_K / BH);
const unsigned int k_size_V = (size_V / BH);
const unsigned int k_size_O = (size_O / BH);
printf("k_size_Q %u k_size_K %u k_size_V %u k_size_O %u\n", k_size_Q, k_size_K, k_size_V, k_size_O);
// prepare setup for kernel run
setup_kernel_run(stream);
// Record the start event
checkCudaErrors(cudaEventRecord(start, stream));
const int32_t num_args = 4;
int32_t arg_types[num_args] = {3,3,3,3};
for (unsigned int b = 0; b < B; b++)
{
for (unsigned int h = 0; h < H; h++)
{
unsigned int offset_index = (b * H) + h;
half *param[num_args] = {
d_Q + offset_index * k_size_Q,
d_K + offset_index * k_size_K,
d_V + offset_index * k_size_V,
d_O + offset_index * k_size_O};
void* args[num_args] = {param[0],param[1],param[2],param[3]};
hidet_launch_func(num_args, arg_types, args);
}
}
// stream sync
checkCudaErrors(cudaStreamSynchronize(stream));
// Record the stop event
checkCudaErrors(cudaEventRecord(stop, stream));
printf("test done !!!\n");
// Wait for the stop event to complete
checkCudaErrors(cudaEventSynchronize(stop));
float msecTotal = 0.0f;
checkCudaErrors(cudaEventElapsedTime(&msecTotal, start, stop));
// Compute and print the performance
#if RUN_FLASH_ATTN
printf("flash attention elapsed time = %.3f msec\n", msecTotal);
#else
printf("normal approach elapsed time = %.3f msec\n", msecTotal);
#endif // RUN_FLASH_ATTN
// Copy result from device to host
checkCudaErrors(
cudaMemcpyAsync(h_O, d_O, size_O * sizeof(half), cudaMemcpyDeviceToHost, stream));
checkCudaErrors(cudaStreamSynchronize(stream));
printf("Checking computed result for correctness: \n");
double eps = 0.01; // 1% error with python output
const unsigned int max_print_count = 100;
uint32_t total_count = 0;
uint32_t total_err_count = 0;
for (int i = 0; i < static_cast<int>(size_O); i++)
{
double gold_val = fabs((double)h_gold_O[i]);
double abs_val = fabs((double)h_O[i]);
double abs_err = fabs(abs_val - gold_val);
double rel_err = abs_err / abs_val;
if (rel_err > eps)
{
if (total_err_count < max_print_count)
printf("Error! Matrix[%05d]=%.8f, ref=%.8f error term %E is > %E\n",
i, (double)h_O[i], (double)h_gold_O[i], rel_err, eps);
total_err_count++;
}
total_count++;
}
double error_ratio = (double)total_err_count / (double)total_count;
bool correct = error_ratio < eps;
printf("total count %u total error count %u (%.8f %%)\n", total_count, total_err_count, error_ratio * 100);
printf("%s\n", correct ? "Result = PASS" : "Result = FAIL");
// Clean up memory
checkCudaErrors(cudaFree(d_Q));
checkCudaErrors(cudaFree(d_K));
checkCudaErrors(cudaFree(d_V));
checkCudaErrors(cudaFree(d_O));
checkCudaErrors(cudaEventDestroy(start));
checkCudaErrors(cudaEventDestroy(stop));
// close dynamic library
if(cuda_lib_handle!=nullptr) dlclose(cuda_lib_handle);
if(runtime_lib_handle!=nullptr) dlclose(runtime_lib_handle);
if (correct)
{
return EXIT_SUCCESS;
}
else
{
return EXIT_FAILURE;
}
}
inline bool file_exists(const std::string &name)
{
struct stat buffer;
return (stat(name.c_str(), &buffer) == 0);
}
void load_data(std::vector<half> &matrix, const std::string bin_file)
{
printf("loading %s\n", bin_file.c_str());
assert(file_exists(bin_file) && "Error! binary file doesn't exist");
std::ifstream fin(bin_file, std::ios::binary);
half elem;
while (fin.read(reinterpret_cast<char *>(&elem), sizeof(half)))
{
matrix.push_back(elem);
}
}
int main(int argc, char **argv)
{
printf("[Flash Attention Using CUDA] - Starting...\n");
if (checkCmdLineFlag(argc, (const char **)argv, "help") ||
checkCmdLineFlag(argc, (const char **)argv, "?"))
{
printf("Usage -device=n (n >= 0 for deviceID)\n");
printf(" -BATCH=number of Batch\n");
printf(" -HEAD=number of Head\n");
printf(" -BLK=block size\n");
printf(" -THD=thread size\n");
exit(EXIT_SUCCESS);
}
// This will pick the best possible CUDA capable device, otherwise
// override the device ID based on input provided at the command line
int dev = findCudaDevice(argc, (const char **)argv);
unsigned int batch = 1;
if (checkCmdLineFlag(argc, (const char **)argv, "BATCH"))
{
batch = getCmdLineArgumentInt(argc, (const char **)argv, "BATCH");
}
unsigned int head = 1;
if (checkCmdLineFlag(argc, (const char **)argv, "HEAD"))
{
head = getCmdLineArgumentInt(argc, (const char **)argv, "HEAD");
}
unsigned int block_size = 1;
if (checkCmdLineFlag(argc, (const char **)argv, "BLK"))
{
block_size = getCmdLineArgumentInt(argc, (const char **)argv, "BLK");
}
unsigned int thread_size = 1;
if (checkCmdLineFlag(argc, (const char **)argv, "THD"))
{
thread_size = getCmdLineArgumentInt(argc, (const char **)argv, "THD");
}
// load Q
std::vector<half> mat_Q;
load_data(mat_Q, "./mat_Q.bin");
// load K
std::vector<half> mat_K;
load_data(mat_K, "./mat_K.bin");
// load V
std::vector<half> mat_V;
load_data(mat_V, "./mat_V.bin");
// load golden data O
std::vector<half> gold_mat_O;
load_data(gold_mat_O, "./gold_mat_O.bin");
printf("batch %u head %u block_size %u thread_size %u\n", batch, head, block_size, thread_size);
printf("Q size %lu K size %lu V size %lu O size %lu\n", mat_Q.size(), mat_K.size(), mat_V.size(), gold_mat_O.size());
checkCudaErrors(cudaProfilerStart());
int result = flash_attention_test(
batch, head, block_size, thread_size,
&mat_Q[0], mat_Q.size(),
&mat_K[0], mat_K.size(),
&mat_V[0], mat_V.size(),
&gold_mat_O[0], gold_mat_O.size());
checkCudaErrors(cudaProfilerStop());
exit(result);
}
And I have some follow-up questions. Thank you very much again for your help 🙏
Hi @keneoneth, glad you ran the generated lib.so successfully. For your questions: ii) In C++, no. In Python, please refer to our tests (https://github.com/hidet-org/hidet/tree/main/tests/operators) to see how we compare the results with numpy/pytorch. iii) Deepview is currently designed to profile and predict the performance of training, and it is built on top of pytorch. You need to run the kernel on the target GPU to know its performance. That being said, the idea of Deepview also applies to hidet.
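(As an aside, here is a minimal sketch, not part of the hidet API, of how the kernel output could also be checked from Python against the gold data produced by gen_gold() above; it assumes the CUDA test dumps its result to a hypothetical mat_O.bin next to gold_mat_O.bin.)
import numpy as np

# load the reference output written by gen_gold() and a hypothetical dump from the CUDA test
gold = np.fromfile('gold_mat_O.bin', dtype=np.float16).astype(np.float32)
out = np.fromfile('mat_O.bin', dtype=np.float16).astype(np.float32)  # hypothetical dump
np.testing.assert_allclose(out, gold, rtol=1e-2, atol=1e-2)
print('outputs match within 1% tolerance')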
Closes #278. Allows calls `getitem(x, index)` where `x` is a tensor on GPU while `index` is a tensor on CPU, which is allowed in PyTorch. Also enforces the same constraint as PyTorch: `index` must be either on CPU or on the same device as `x`.
Hello guys, I really appreciate your work on Hidet. It is an awesome tool, and it really makes developers' lives easier when writing custom schedules to optimize the performance of their CUDA kernels 👍👍!
To test Hidet's features, I am currently writing an example of the Flash Attention Transformer (link to the paper: https://arxiv.org/abs/2205.14135) using the Hidet tool stack. I have written my custom testing setup (which contains my own host/device memory allocation, performance tracking, and precision comparison code) in "flash_attention_main.cu", and I am trying to call the kernel functions in the Hidet-generated CUDA dynamic library.
May I know if there is a standard way of doing this? I tried using "dlopen" to load the library and launch the kernel functions, but unfortunately it did not work properly. I therefore manually copied the Hidet-generated CUDA source code into two separate header files, "flash_attention_kernel_func.h" and "normal_transformer_kernel_func.h", and included them in "flash_attention_main.cu". Compiling "flash_attention_main.cu" directly that way, everything works properly.
Let me share some source code below for illustration.
Here is my flash_attention_example.py, which includes the flash attention custom schedule and the normal approach.
Here is my flash_attention_main.cu, which includes the performance tracking, precision comparison, and memory allocation operations, and which launches the test kernels.
Here are flash_attention_kernel_func.h and normal_transformer_func.h, respectively.
Again, really wonderful work on Hidet! Any help would be much appreciated 🙏. If any further information is needed, please let me know.