Skip to content

Commit 436e523

Browse files
authored
Refactor attention kernels (#53)
1 parent 27f1410 commit 436e523

14 files changed

+1253
-2569
lines changed
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
#pragma once
2+
3+
#include "attention_generic.cuh"
4+
#include "dtype_float16.cuh"
5+
#include "dtype_float32.cuh"
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
#pragma once
2+
3+
#include <stdint.h>
4+
5+
namespace cacheflow {
6+
7+
// A vector type to store Q, K, V elements.
// Primary template is intentionally empty: concrete layouts are expected to
// come from explicit specializations per scalar type T and vector width
// VEC_SIZE (presumably provided by the dtype_*.cuh headers — verify there).
template<typename T, int VEC_SIZE>
struct Vec {};
10+
11+
// A vector type to store FP32 accumulators.
// Maps a (possibly low-precision) vector type T to its FP32-widened
// counterpart. Primary template is intentionally empty; specializations are
// expected per supported T (presumably in the dtype_*.cuh headers).
template<typename T>
struct FloatVec {};
14+
15+
// Template vector operations.

// Element-wise multiply of a and b, returning the product as accumulator
// type Acc. Declaration only — each dtype is expected to supply matching
// specializations/overloads elsewhere.
template<typename Acc, typename A, typename B>
inline __device__ Acc mul(A a, B b);
18+
19+
// Horizontal reduction: collapses all elements of v to a single float.
// Declaration only — definitions are expected to be supplied per dtype.
template<typename T>
inline __device__ float sum(T v);
21+
22+
// Dot product of two vectors of the same type T: element-wise multiply
// (accumulated in T) followed by a horizontal sum to a single float.
template<typename T>
inline __device__ float dot(T a, T b) {
  T prod = mul<T, T, T>(a, b);
  return sum(prod);
}
26+
27+
// Dot product of two vectors of type T with an explicitly chosen
// accumulator type A (e.g. an FP32 vector for half inputs), reduced to a
// single float. Callers select A explicitly: dot<A>(a, b).
template<typename A, typename T>
inline __device__ float dot(T a, T b) {
  A prod = mul<A, T, T>(a, b);
  return sum(prod);
}
31+
32+
// Zero-initializes dst by overwriting its storage as 32-bit words through a
// union, which avoids requiring T to have an assignable element interface.
//
// Precondition: sizeof(T) must be a positive multiple of 4. Without the
// static_assert, an ill-sized T would silently truncate WORDS (WORDS == 0
// for sizeof(T) < 4 gives a zero-length array and copies uninitialized
// storage into dst; a trailing remainder of 1-3 bytes would be left
// unwritten). The assert turns both cases into compile-time errors.
template<typename T>
inline __device__ void zero(T& dst) {
  static_assert(sizeof(T) > 0 && sizeof(T) % 4 == 0,
                "zero<T>: sizeof(T) must be a positive multiple of 4 bytes");
  constexpr int WORDS = sizeof(T) / 4;
  union {
    T raw;
    uint32_t words[WORDS];
  } tmp;

#pragma unroll
  for (int ii = 0; ii < WORDS; ++ii) {
    tmp.words[ii] = 0u;
  }
  dst = tmp.raw;
}
46+
47+
} // namespace cacheflow

0 commit comments

Comments
 (0)