-
Notifications
You must be signed in to change notification settings - Fork 25
/
decoderMaskedMultiheadAttentionTemplate.hpp
2274 lines (1977 loc) · 98.6 KB
/
decoderMaskedMultiheadAttentionTemplate.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
// Inspired by TRT-LLM.
// Modified by Haotian Tang and Shang Yang.
// @article{lin2024qserve,
// title={QServe: W4A8KV4 Quantization and System Co-design for Efficient LLM Serving},
// author={Lin*, Yujun and Tang*, Haotian and Yang*, Shang and Zhang, Zhekai and Xiao, Guangxuan and Gan, Chuang and Han, Song},
// journal={arXiv preprint arXiv:2405.04532},
// year={2024}
// }
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cudaTypeUtils.cuh"
#include "memoryUtils.h"
#include "decoderMaskedMultiheadAttention.h"
#include "decoderMaskedMultiheadAttentionUtils.h"
#include "kvCacheUtils.h"
#include <cuda_fp16.h>
#include <cuda_pipeline_primitives.h>
#include <assert.h>
#include <float.h>
#include <type_traits>
// Multi-block mmha kernel can only be selected when CUDA >= 11.7
#if (CUDART_VERSION >= 11070)
#define ENABLE_MULTI_BLOCK_OPTION
#endif
// L2_CACHEHINT(size) expands to the PTX ".L2::<size>B" prefetch-size qualifier that cp_async_helper
// appends to cp.async; on older compilers it expands to nothing.
// The qualifier needs nvcc >= 11.4. Compare (major, minor) as a pair: testing the minor version on
// its own (MINOR >= 4) would wrongly disable the hint on CUDA 12.0 - 12.3.
#if defined(__CUDACC_VER_MAJOR__) && \
    ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 4))
#define L2_CACHEHINT(size) ".L2::" #size "B"
#else
#define L2_CACHEHINT(size)
#endif
// #ifdef ENABLE_MULTI_BLOCK_OPTION
// #include <cub/block/block_reduce.cuh>
// #include <cuda/atomic>
// #include <cuda/std/bit>
// #endif // ENABLE_MULTI_BLOCK_OPTION
// Uncomment to let Qk_dot<uint16_t, 4> reduce packed half2 (uint32_t) dot products with an HMMA
// (qk_hmma_dot_, only on sm_75+) instead of warp shuffles.
// #define MMHA_USE_HMMA_FOR_REDUCTION
// Below are knobs to extend FP32 accumulation for higher FP16 accuracy
// Does not seem to affect the accuracy that much
// MMHA_USE_FP32_ACUM_FOR_FMA: qk_dot_ / qk_hmma_dot_ accumulate the Q*K^T partial products in the
// FP32 types given by K_vec_acum_fp32_ instead of in the input vector type.
#define MMHA_USE_FP32_ACUM_FOR_FMA
// Seems to slightly improve the accuracy
// MMHA_USE_FP32_ACUM_FOR_OUT: enables the FP32 accumulator traits for V vectors (V_vec_acum_fp32_).
#define MMHA_USE_FP32_ACUM_FOR_OUT
// The leading "0 &&" disables this block unconditionally, so MMHA_USE_FP32_ACUM_FOR_LOGITS is never
// defined by this file and the kernel keeps its logits in the kernel type Tk.
#if 0 && defined(MMHA_USE_FP32_ACUM_FOR_OUT)
// Does not seem to improve the accuracy
//#define MMHA_USE_FP32_ACUM_FOR_LOGITS
#endif
namespace mmha
{
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// We use the following terminology to describe the different dimensions.
//
// B: Batch size (number of sequences),
// L: Sequence length,
// D: Hidden dimension,
// H: Number of heads,
// Dh: Hidden dimension per head - Dh = D / H.
//
// The different kernels assign one threadblock to each (B, H) pair. In masked_multihead_attention_kernel
// below, blockIdx.x selects the head and blockIdx.y selects the batch entry. We use
// 256 threads per block to maximize occupancy and performance.
//
// Each threadblock loads Dh values from Q and its associated bias. The kernels run a loop to
// compute Q * K^T where K is loaded from a cache buffer -- except for the current timestep. The
// cache buffer helps with memory accesses and contains keys with bias.
//
// The layout of the cache buffer for the keys/values is [B, H, L, Dh]
// where the fastest moving dimension (contiguous data) is the rightmost one.
// Contiguous threads will read one hidden_dimension per LDG unless we need more than 32 threads.
//
// The different kernels use 1 to 32 threads per key (THREADS_PER_KEY). The size of the LDGs
// is always 16 bytes (8 bytes for an 8-bit cache). Each thread sums Dh / THREADS_PER_KEY elements. At
// the end of each iteration of the Q * K^T loop, we perform a reduction between lanes: with warp
// shuffles (__shfl_xor_sync) by default, or with an HMMA instruction (Tensor Core) when
// MMHA_USE_HMMA_FOR_REDUCTION is defined and the target is sm_75+. Each Q * K^T value is stored in
// shared memory in FP32.
//
// After that loop, a parallel softmax is computed across the different Q * K^T values stored in
// shared memory.
//
// The kernel ends with a loop over the values in V. We use THREADS_PER_VALUE to control how many
// timesteps are computed by loop iteration. As with the keys, the values are read from a cache
// except for the current timestep. The layout of the cache buffer for the values is same as the key,
// which is [B, H, L, Dh].
//
// Note that we have remapped key layout to make sure it shares the same pattern as value [B, H, L, Dh].
// It helps coalescing memory access, and reducing register pressure.
////////////////////////////////////////////////////////////////////////////////////////////////////
// Qk_vec_m_<T, Dh_MAX>::Type is the vector a thread uses to load its slice of Q (and of the
// current-step K) from memory ("memory-used precision"). Only uint16_t, the 16-bit storage type
// treated as FP16 elsewhere in this file, is specialised. Any other (T, Dh_MAX) pair leaves ::Type
// undefined and fails to compile at the point of use.
//   Dh_MAX  32 -> uint32_t (2 x 16-bit)
//   Dh_MAX  64 -> uint32_t (2 x 16-bit)
//   Dh_MAX 128 -> uint2    (4 x 16-bit)
//   Dh_MAX 256 -> uint4    (8 x 16-bit)
template <typename T, int Dh_MAX>
struct Qk_vec_m_
{
};

template <> struct Qk_vec_m_<uint16_t, 32>  { typedef uint32_t Type; };
template <> struct Qk_vec_m_<uint16_t, 64>  { typedef uint32_t Type; };
template <> struct Qk_vec_m_<uint16_t, 128> { typedef uint2 Type; };
template <> struct Qk_vec_m_<uint16_t, 256> { typedef uint4 Type; };
////////////////////////////////////////////////////////////////////////////////////////////////////
// Qk_vec_k_ is the kernel-side ("kernel-used precision") counterpart of Qk_vec_m_. It currently
// reuses the memory-side vector type unchanged.
template <typename T, int Dh>
struct Qk_vec_k_
{
    typedef typename Qk_vec_m_<T, Dh>::Type Type;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// V_vec_m_<T, V_VEC_SIZE>::Type packs V_VEC_SIZE 16-bit elements of a value row into a single
// register-sized vector. Only uint16_t is specialised; other combinations leave ::Type undefined.
template <typename T, int V_VEC_SIZE>
struct V_vec_m_
{
};

template <> struct V_vec_m_<uint16_t, 2> { typedef uint32_t Type; };  // 2 x 16-bit =  4 bytes
template <> struct V_vec_m_<uint16_t, 4> { typedef uint2 Type; };     // 4 x 16-bit =  8 bytes
template <> struct V_vec_m_<uint16_t, 8> { typedef uint4 Type; };     // 8 x 16-bit = 16 bytes
////////////////////////////////////////////////////////////////////////////////////////////////////
// V_vec_k_ is the kernel-side view of a V vector; it currently aliases the memory-side type given
// by V_vec_m_ unchanged.
template <typename T, int V_VEC_SIZE>
struct V_vec_k_
{
    typedef typename V_vec_m_<T, V_VEC_SIZE>::Type Type;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Keys are stored in the same [B, H, L, Dh] layout as values, so the memory-side key vector type
// simply forwards to the V_vec traits.
template <typename T, int K_VEC_SIZE>
struct K_vec_m_
{
    typedef typename V_vec_m_<T, K_VEC_SIZE>::Type Type;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Kernel-side key vector type; identical to the memory-side K_vec_m_ type.
template <typename T, int K_VEC_SIZE>
struct K_vec_k_
{
    typedef typename K_vec_m_<T, K_VEC_SIZE>::Type Type;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
// Qk_vec_acum_fp32_<V>::Type is the FP32 accumulator matching a Q/K vector type V. FP32 types map
// to themselves. Packed 16-bit vectors widen: uint32_t -> float2, uint2 -> Float4_, uint4 -> Float8_.
template <typename T>
struct Qk_vec_acum_fp32_
{
};
template <> struct Qk_vec_acum_fp32_<float>    { typedef float Type; };
template <> struct Qk_vec_acum_fp32_<float2>   { typedef float2 Type; };
template <> struct Qk_vec_acum_fp32_<float4>   { typedef float4 Type; };
// template<> struct Qk_vec_acum_fp32_<uint16_t> { using Type = float; };
template <> struct Qk_vec_acum_fp32_<uint32_t> { typedef float2 Type; };
template <> struct Qk_vec_acum_fp32_<uint2>    { typedef Float4_ Type; };
template <> struct Qk_vec_acum_fp32_<uint4>    { typedef Float8_ Type; };

////////////////////////////////////////////////////////////////////////////////////////////////////

// K_vec_acum_fp32_<V>::Type is the FP32 accumulator that qk_dot_ and qk_hmma_dot_ use for a key
// vector type V. The mapping is the same as Qk_vec_acum_fp32_, plus Float8_ -> Float8_.
template <typename T>
struct K_vec_acum_fp32_
{
};
template <> struct K_vec_acum_fp32_<float>    { typedef float Type; };
template <> struct K_vec_acum_fp32_<float2>   { typedef float2 Type; };
template <> struct K_vec_acum_fp32_<float4>   { typedef float4 Type; };
template <> struct K_vec_acum_fp32_<Float8_>  { typedef Float8_ Type; };
template <> struct K_vec_acum_fp32_<uint32_t> { typedef float2 Type; };
template <> struct K_vec_acum_fp32_<uint2>    { typedef Float4_ Type; };
template <> struct K_vec_acum_fp32_<uint4>    { typedef Float8_ Type; };
#endif // MMHA_USE_FP32_ACUM_FOR_FMA
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
// V_vec_acum_fp32_<V>::Type is the FP32 accumulator for a V vector type V. FP32 types map to
// themselves. Packed 16-bit vectors widen: uint32_t -> float2, uint2 -> Float4_, uint4 -> Float8_.
template <typename T>
struct V_vec_acum_fp32_
{
};
template <> struct V_vec_acum_fp32_<float>    { typedef float Type; };
template <> struct V_vec_acum_fp32_<float2>   { typedef float2 Type; };
template <> struct V_vec_acum_fp32_<float4>   { typedef float4 Type; };
template <> struct V_vec_acum_fp32_<uint32_t> { typedef float2 Type; };
template <> struct V_vec_acum_fp32_<uint2>    { typedef Float4_ Type; };
template <> struct V_vec_acum_fp32_<uint4>    { typedef Float8_ Type; };
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
// Identity conversion between vector types. Only Tout == Tin is supported; any other pair fails
// the static_assert at compile time.
template <typename Tout, typename Tin>
__inline__ __device__ constexpr Tout vec_conversion(const Tin &x)
{
    static_assert(std::is_same<Tout, Tin>::value, "Type mismatch");
    return static_cast<Tout>(x);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Dot product of a query slice and a key slice. The result is summed over the THREADS_PER_KEY
// neighbouring lanes that cooperate on one key.
//   q, k : the N packed vectors this thread holds for the query and the key.
// Returns the total; after the XOR butterfly every lane of the group holds it.
// The shuffles use a full-warp mask, so all 32 lanes of the warp must call this together.
template <int THREADS_PER_KEY, typename K_vec, int N>
inline __device__ float qk_dot_(const K_vec (&q)[N], const K_vec (&k)[N])
{
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
    using K_vec_acum = typename K_vec_acum_fp32_<K_vec>::Type;
#else
    using K_vec_acum = K_vec;
#endif
    // Lane-wise partial products: the vector's elements are kept apart until the end.
    K_vec_acum acc = mul<K_vec_acum, K_vec, K_vec>(q[0], k[0]);
#pragma unroll
    for (int idx = 1; idx < N; ++idx)
    {
        acc = fma(q[idx], k[idx], acc);
    }

    // Collapse the vector elements, then combine the partial sums of the cooperating threads.
    float dot = sum(acc);
#pragma unroll
    for (int offset = THREADS_PER_KEY >> 1; offset > 0; offset >>= 1)
    {
        dot += __shfl_xor_sync(uint32_t(-1), dot, offset);
    }
    return dot;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Generic Q*K^T dot-product dispatcher: forwards to qk_dot_ with THREADS_PER_KEY cooperating lanes.
// T is unused by this primary template; specialisations such as Qk_dot<uint16_t, 4> key on it to
// choose a different reduction.
template <typename T, int THREADS_PER_KEY>
struct Qk_dot
{
    template <typename K_vec, int N>
    static inline __device__ float dot(const K_vec (&q)[N], const K_vec (&k)[N])
    {
        float const result = qk_dot_<THREADS_PER_KEY>(q, k);
        return result;
    }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Runs one m16n8k8 HMMA (FP16 A/B, FP32 accumulate, C = 0) and returns this thread's four FP32
// output fragment values.
//   a : this thread's A fragment, two 32-bit registers (each a packed pair of FP16 values).
//   b : this thread's B fragment, one 32-bit register (a packed pair of FP16 values).
// Requires sm_75+ (mma.sync.aligned.m16n8k8 with .f16 inputs). Its caller qk_hmma_dot_ guards the
// call with __CUDA_ARCH__ >= 750. mma.sync is warp-collective: all 32 lanes must execute it.
inline __device__ float4 hmma_fp32(const uint2 &a, uint32_t b)
{
    float4 c;
    float zero = 0.f;
    asm volatile(
        "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 \n"
        "    {%0, %1, %2, %3}, \n"
        "    {%4, %5}, \n"
        "    {%6}, \n"
        "    {%7, %7, %7, %7}; \n"
        : "=f"(c.x), "=f"(c.y), "=f"(c.z), "=f"(c.w)
        // %4, %5 = A fragment; %6 = B fragment; %7 = zero used for every C operand.
        // Each operand must be comma-separated (the comma after "r"(a.x) was previously missing).
        : "r"(a.x), "r"(a.y), "r"(b), "f"(zero));
    return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Q*K^T dot product for packed FP16 pairs (uint32_t) that does the cross-lane reduction with an
// HMMA instead of warp shuffles.
//   q, k : the N packed FP16-pair registers this thread holds for the query and the key.
// The per-thread partial products are accumulated lane-wise. They are packed into the A fragment
// of an m16n8k8 mma together with a zero register and multiplied by a B fragment of all ones
// (0x3c00 is FP16 1.0), so the mma sums the partials. Component .x of the FP32 result is returned.
// NOTE(review): which lanes end up summed is fixed by the mma fragment layout. The only visible
// caller, Qk_dot<uint16_t, 4>, uses this in place of a 4-lane shuffle reduction; confirm against
// the PTX fragment layout before reusing it with another THREADS_PER_KEY.
// On targets below sm_75 (and in the host compilation pass) it returns 0.f.
template <int N>
inline __device__ float qk_hmma_dot_(const uint32_t (&q)[N], const uint32_t (&k)[N])
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
using K_vec_acum = typename K_vec_acum_fp32_<uint32_t>::Type;
#else
using K_vec_acum = uint32_t;
#endif
// Lane-wise partial products (FP32 or packed FP16 depending on the knob above).
K_vec_acum qk_vec = mul<K_vec_acum, uint32_t, uint32_t>(q[0], k[0]);
#pragma unroll
for (int ii = 1; ii < N; ++ii)
{
qk_vec = fma(q[ii], k[ii], qk_vec);
}
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
// The mma consumes FP16 A fragments, so the FP32 partials are narrowed back to a half pair.
uint32_t qk_vec_ = float2_to_half2(qk_vec);
return hmma_fp32(make_uint2(qk_vec_, 0u), 0x3c003c00u).x;
#else
return hmma_fp32(make_uint2(qk_vec, 0u), 0x3c003c00u).x;
#endif
#else
// No HMMA path for this target: the result is a constant 0.
return 0.f;
#endif
}
// Packed-FP16 Q*K^T dot product with a shuffle-based cross-lane reduction.
// Overloads take one packed vector per thread:
//   uint32_t = 2 halves, uint2 = 4 halves, uint4 = 8 halves.
// Products are accumulated in FP16 with fma.rn.f16x2. The final half pair is added, widened to
// float and summed over the THREADS_PER_KEY cooperating lanes. Accumulation is FP16, so this is
// less precise than qk_dot_ with MMHA_USE_FP32_ACUM_FOR_FMA.
// All 32 lanes of the warp must call it, because the shuffles use a full-warp mask.
template <int THREADS_PER_KEY, typename K_vec_k>
inline __device__ float qk_hmma_dot_simple(const K_vec_k& q, const K_vec_k& k);

// Returns a * b + c on two packed FP16 values (round-to-nearest-even).
inline __device__ uint32_t qk_hfma2_simple_(uint32_t a, uint32_t b, uint32_t c)
{
    uint32_t d;
    asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
    return d;
}

// Adds the two halves of a packed FP16 partial sum and widens the result to float. It then sums
// that value across the THREADS_PER_KEY neighbouring lanes that share one key.
template <int THREADS_PER_KEY>
inline __device__ float qk_reduce_half2_simple_(uint32_t packed)
{
    half2 const h = reinterpret_cast<half2 const &>(packed);
    float qk = __half2float(__hadd(h.x, h.y));
#pragma unroll
    for (int mask = THREADS_PER_KEY / 2; mask >= 1; mask /= 2)
    {
        qk += __shfl_xor_sync(uint32_t(-1), qk, mask);
    }
    return qk;
}

// 2-element variant. This used to be assert(0) with no return statement, which is undefined
// behaviour (and silently so when NDEBUG disables the assert).
template <int THREADS_PER_KEY>
inline __device__ float qk_hmma_dot_simple(const uint32_t& q, const uint32_t& k)
{
    uint32_t const qk_vec = mul<uint32_t, uint32_t, uint32_t>(q, k);
    return qk_reduce_half2_simple_<THREADS_PER_KEY>(qk_vec);
}

// 4-element variant (previously assert(0) with no return statement).
template <int THREADS_PER_KEY>
inline __device__ float qk_hmma_dot_simple(const uint2& q, const uint2& k)
{
    uint32_t qk_vec = mul<uint32_t, uint32_t, uint32_t>(q.x, k.x);
    qk_vec = qk_hfma2_simple_(q.y, k.y, qk_vec);
    return qk_reduce_half2_simple_<THREADS_PER_KEY>(qk_vec);
}

// 8-element variant.
template <int THREADS_PER_KEY>
inline __device__ float qk_hmma_dot_simple(const uint4& q, const uint4& k)
{
    uint32_t qk_vec = mul<uint32_t, uint32_t, uint32_t>(q.x, k.x);
    qk_vec = qk_hfma2_simple_(q.y, k.y, qk_vec);
    qk_vec = qk_hfma2_simple_(q.z, k.z, qk_vec);
    qk_vec = qk_hfma2_simple_(q.w, k.w, qk_vec);
    return qk_reduce_half2_simple_<THREADS_PER_KEY>(qk_vec);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Qk_dot specialisation for 16-bit inputs with 4 threads per key.
// For packed FP16-pair (uint32_t) vectors it uses the HMMA-based reduction (qk_hmma_dot_) when
// MMHA_USE_HMMA_FOR_REDUCTION is defined and the target is sm_75+. Every other case falls back to
// the shuffle-based qk_dot_<4>.
template <>
struct Qk_dot<uint16_t, 4>
{
    template <typename K_vec, int N>
    static inline __device__ float dot(const K_vec (&q)[N], const K_vec (&k)[N])
    {
        return qk_dot_<4>(q, k);
    }

    template <int N>
    static inline __device__ float dot(const uint32_t (&q)[N], const uint32_t (&k)[N])
    {
#if defined(MMHA_USE_HMMA_FOR_REDUCTION) && __CUDA_ARCH__ >= 750
        return qk_hmma_dot_(q, k);
#else
        return qk_dot_<4>(q, k);
#endif // HMMA reduction
    }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Block-wide sum of one float per thread; every thread gets the block total back.
//   red_smem : scratch shared buffer with room for at least WARPS_PER_BLOCK floats.
//   sum      : this thread's contribution.
// Preconditions, checked at compile time: WARPS_PER_BLOCK is a power of two and no larger than
// WARP_SIZE. The second stage reduces the per-warp totals inside one warp with an XOR butterfly,
// which is only correct for such counts.
// NOTE(review): assumes the block has exactly WARPS_PER_BLOCK * WARP_SIZE threads. A larger block
// would write red_smem past WARPS_PER_BLOCK entries (TODO confirm at call sites).
// Contains __syncthreads(), so every thread of the block must call it. There is no barrier after
// red_smem is read, so a caller that reuses the same buffer right away must synchronise first.
template <int WARPS_PER_BLOCK, int WARP_SIZE = 32>
inline __device__ float block_sum(float *red_smem, float sum)
{
    static_assert(WARPS_PER_BLOCK > 0 && (WARPS_PER_BLOCK & (WARPS_PER_BLOCK - 1)) == 0,
        "block_sum: WARPS_PER_BLOCK must be a power of two");
    static_assert(WARPS_PER_BLOCK <= WARP_SIZE, "block_sum: WARPS_PER_BLOCK must not exceed WARP_SIZE");

    // Decompose the thread index into warp / lane.
    int warp = threadIdx.x / WARP_SIZE;
    int lane = threadIdx.x % WARP_SIZE;

    // Compute the sum per warp.
#pragma unroll
    for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2)
    {
        sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
    }

    // Warp leaders store the data to shared memory.
    if (lane == 0)
    {
        red_smem[warp] = sum;
    }

    // Make sure the data is in shared memory.
    __syncthreads();

    // Lanes [0, WARPS_PER_BLOCK) load the per-warp totals. Other lanes keep their own warp total,
    // but the butterfly below never mixes them into lane 0's group.
    if (lane < WARPS_PER_BLOCK)
    {
        sum = red_smem[lane];
    }

    // Parallel reduction inside the warp.
#pragma unroll
    for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2)
    {
        sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
    }

    // Broadcast lane 0's total to the other lanes.
    return __shfl_sync(uint32_t(-1), sum, 0);
}
#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
// cast_to_float(v) widens a vector to its FP32 counterpart. It is only compiled when
// MMHA_USE_FP32_ACUM_FOR_LOGITS is defined (this file's own #define for it is disabled).
// FP32 inputs (float, float2, float4, Float4_, Float8_) come back unchanged. Packed 16-bit vectors
// are converted one pair at a time with half2_to_float2:
//   uint32_t -> float2, uint2 -> Float4_, uint4 -> Float8_.
inline __device__ float cast_to_float(float u) { return u; }
inline __device__ float2 cast_to_float(float2 u) { return u; }
inline __device__ float4 cast_to_float(float4 u) { return u; }
inline __device__ Float4_ cast_to_float(Float4_ u) { return u; }
inline __device__ Float8_ cast_to_float(Float8_ u) { return u; }

inline __device__ float2 cast_to_float(uint32_t u)
{
    float2 const widened = half2_to_float2(u);
    return widened;
}

inline __device__ Float4_ cast_to_float(uint2 u)
{
    Float4_ out;
    out.x = half2_to_float2(u.x);
    out.y = half2_to_float2(u.y);
    return out;
}

inline __device__ Float8_ cast_to_float(uint4 u)
{
    Float8_ out;
    out.x = half2_to_float2(u.x);
    out.y = half2_to_float2(u.y);
    out.z = half2_to_float2(u.z);
    out.w = half2_to_float2(u.w);
    return out;
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
// Ceiling division for non-negative m and positive n: the smallest q with q * n >= m.
// Note: m + n - 1 must not overflow T.
template <typename T>
inline __device__ __host__ T divUp(T m, T n)
{
    return (m + (n - 1)) / n;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Plain truncating integer division m / n; the round-down companion of divUp.
template <typename T>
inline __device__ __host__ T div(T m, T n)
{
    T const quotient = m / n;
    return quotient;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Maps an input element type to the type the kernel computes with. Currently the identity mapping.
template <typename T>
struct kernel_type_t
{
    typedef T Type;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Compute the largest supported head size (dh_max) for a head size dh. It is the smallest power of
// two that is >= dh, but never below 32 (one warp's worth of elements).
inline __device__ __host__ constexpr unsigned dh_max(unsigned dh)
{
    constexpr unsigned kMinHeadSize = 32u;
    return next_power_of_two(const_max(dh, kMinHeadSize));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Number of threads that share one value row. Each thread handles 16 bytes, so a row of dh_max
// elements of type T needs dh_max * sizeof(T) / 16 threads.
template <typename T>
inline __device__ __host__ constexpr unsigned threads_per_value(unsigned dh_max)
{
    constexpr unsigned kBytesPerThread = 16;
    return dh_max * sizeof(T) / kBytesPerThread;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Number of threads that cooperate on one key in the Q*K^T loop. Each thread handles 16 bytes, so
// Dh_MAX * sizeof(T) / 16 threads are needed. The result is capped at 32 so that the per-key
// reduction stays inside one warp.
// Dh_MAX * sizeof(T) / 16 must be a non-zero power of two. This is checked at compile time; the
// old runtime assert accepted 0, because 0 & (0 - 1) == 0.
template <typename T, unsigned Dh_MAX>
inline __device__ __host__ constexpr unsigned threads_per_key()
{
    constexpr unsigned threads = static_cast<unsigned>(Dh_MAX * sizeof(T) / 16u);
    static_assert(threads != 0u, "threads_per_key: Dh_MAX * sizeof(T) must be at least 16 bytes");
    static_assert((threads & (threads - 1u)) == 0u, "threads_per_key: thread count must be a power of two");
    // A ternary instead of std::min: std::min is not a __device__ function, and this helper is also
    // compiled for the device.
    return threads < 32u ? threads : 32u;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Locates thread tidx inside chunks of VECS_PER_CHUNK vectors of type T_VEC.
// Returns {index of the chunk, offset in T elements of this thread's vector within that chunk}.
template <typename T, typename T_VEC, unsigned VECS_PER_CHUNK>
__device__ inline constexpr uint2 chunk_index(unsigned tidx)
{
    static_assert(sizeof(T_VEC) % sizeof(T) == 0);
    constexpr unsigned kElemsPerVec = sizeof(T_VEC) / sizeof(T);
    unsigned const chunk = tidx / VECS_PER_CHUNK;
    unsigned const first_elem = (tidx % VECS_PER_CHUNK) * kElemsPerVec;
    return uint2{chunk, first_elem};
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Converts a generic pointer to a shared-memory object into the 32-bit shared-window address that
// PTX instructions such as cp.async expect. ptr must point into shared memory, because
// cvta.to.shared is not meaningful for other address spaces.
__inline__ __device__ uint32_t cast_smem_ptr_to_uint_helper(void const *const ptr)
{
    uint32_t shared_addr;
    asm("{.reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, "
        "smem_ptr; }\n"
        : "=r"(shared_addr)
        : "l"(ptr));
    return shared_addr;
}
// Issues one predicated 16-byte cp.async (global -> shared, .cg = cached in L2 only).
// When mask is false no copy is issued and the destination is left untouched.
// The copy is asynchronous: callers must commit and wait (e.g. __pipeline_commit /
// __pipeline_wait_prior) and synchronise before reading the destination. Requires sm_80+.
//   smem_int_ptr : destination as a 32-bit shared-window address (see cast_smem_ptr_to_uint_helper).
//   src          : global source; cp.async needs it (and the destination) 16-byte aligned.
__inline__ __device__ void
cp_async_helper(uint32_t smem_int_ptr, const uint4 *__restrict__ src, bool mask)
{
    constexpr int kCopyBytes = 16;
    int const pred = static_cast<int>(mask);
    // The L2 prefetch hint (empty on older compilers) is not expected to affect performance.
    // clang-format off
    asm volatile("{"
                 " .reg .pred p;"
                 " setp.ne.b32 p, %0, 0;"
                 " @p cp.async.cg.shared.global" L2_CACHEHINT(128) " [%1], [%2], %3;"
                 "}" ::"r"(pred),
                 "r"(smem_int_ptr),
                 "l"(src),
                 "n"(kCopyBytes));
    // clang-format on
}
// Convenience wrapper around cp_async_helper. It converts the generic shared-memory pointer dst_ptr
// to its 32-bit shared address, then issues a predicated 16-byte cp.async from src_ptr.
__inline__ __device__ void
cp_async_launch(void *dst_ptr, const uint4 *__restrict__ src_ptr, bool mask)
{
    cp_async_helper(cast_smem_ptr_to_uint_helper(dst_ptr), src_ptr, mask);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
// The type of the inputs. Supported types: float, uint16_t, nv_bfloat16.
typename T,
// The type of the cache.
typename Tcache,
// Type of struct containing KV cache
typename KVCacheBuffer,
// The hidden dimension per head.
unsigned Dh,
// The number of threads in a threadblock.
unsigned THREADS_PER_BLOCK,
// Whether enable multi-block mode for long-sequence-length.
bool DO_MULTI_BLOCK = false,
// Whether use INT4KV
bool INT4KV = false,
bool KV_WITH_ZEROS = false,
bool SMEM_PRELOAD = false,
// The number of threads per key.
unsigned THREADS_PER_KEY = mmha::threads_per_key<T, dh_max(Dh)>(),
// The number of threads per value.
unsigned THREADS_PER_VALUE = mmha::threads_per_value<T>(dh_max(Dh)),
// The unroll factor for loading from K cache.
// unsigned K_LOOP_UNROLL = 8, // 8,
// The unroll factor for loading from V cache.
// Set it default to 4 for higher occupancy (by reducing registers usage).
unsigned V_LOOP_UNROLL = 4>
__global__ void masked_multihead_attention_kernel(
Multihead_attention_params<T> params, KVCacheBuffer kvCacheBuffer)
{
constexpr unsigned K_LOOP_UNROLL = SMEM_PRELOAD ? 8 : 4;
using Tk = typename kernel_type_t<T>::Type;
// Use 8bit cache.
static constexpr bool ENABLE_8BITS_CACHE = sizeof(Tcache) == 1;
static constexpr bool ENABLE_4BITS_CACHE = (INT4KV && ENABLE_8BITS_CACHE);
static constexpr bool ENABLE_ZEROS = KV_WITH_ZEROS;
// The size of a warp.
constexpr unsigned WARP_SIZE{32};
// The number of warps in a threadblock.
constexpr unsigned WARPS_PER_BLOCK{THREADS_PER_BLOCK / WARP_SIZE};
// The maximum hidden size per head.
constexpr auto Dh_MAX = dh_max(Dh);
constexpr bool IS_Dh_MAX = Dh == Dh_MAX;
static_assert(Dh_MAX >= WARP_SIZE);
static_assert(Dh_MAX >= Dh);
// The maximum sequence length in the kv_cache, i.e., an upper bound on L.
// Note that the maximum sequence length supported by the model might be greater than this.
const auto max_seq_len = static_cast<unsigned>(params.memory_max_len);
assert(max_seq_len > 0);
// The current timestep (including paddings).
// It is only used to calculate the smem stride.
const auto timestep = static_cast<unsigned>(DO_MULTI_BLOCK ? params.timesteps_per_block : params.timestep);
constexpr bool MULTI_BLOCK_FLAG = false;
// Use smem_size_in_bytes (above) to determine the amount of shared memory.
extern __shared__ char smem_[];
// The shared memory for the Q*K^T values and partial logits in softmax.
auto qk_smem = reinterpret_cast<float *>(smem_);
__shared__ float qk_current_smem[1];
// The shared memory for the logits. For FP32, that's the same buffer as qk_smem.
char *logits_smem_ = smem_;
#ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS
if (sizeof(Tk) != 4)
{
// TODO - change to tlength
const auto max_timesteps = min(timestep, max_seq_len); // const auto max_timesteps = DO_CROSS_ATTENTION ? max_seq_len : min(timestep, max_seq_len);
logits_smem_ += divUp(max_timesteps + 1, 4u) * 16;
}
Tk *logits_smem = reinterpret_cast<Tk *>(logits_smem_);
#else
float *logits_smem = reinterpret_cast<float *>(logits_smem_);
#endif
__shared__ Tk logits_current_smem[1];
// The shared memory to do the final reduction for the output values. Reuse qk_smem.
Tk *out_smem = reinterpret_cast<Tk *>(smem_);
// The shared memory buffers for the block-wide reductions. One for max, one for sum.
__shared__ float red_smem[WARPS_PER_BLOCK * 2];
// A vector of Q or K elements for the current timestep.
using Qk_vec_m = typename Qk_vec_m_<T, Dh_MAX>::Type; // with memory-used precision
using Qk_vec_k = typename Qk_vec_k_<T, Dh_MAX>::Type; // with kernel-used precision
// Make sure the hidden dimension per head is a multiple of the number of threads per key.
static_assert(Dh_MAX % THREADS_PER_KEY == 0); // trivially satisfied since THREADS_PER_KEY in {1, 2, 4}
// The number of elements per vector.
// Each thread will handle 16 bytes.
constexpr int K_VEC_SIZE = 16u / sizeof(T);
// Make sure the hidden size per head is a multiple of the vector size.
static_assert(Dh_MAX % K_VEC_SIZE == 0);
// The type of queries and keys for the math in the Q*K^T product.
using K_vec_k = typename K_vec_k_<T, K_VEC_SIZE>::Type;
// Only used when key cache is quantized to 4 or 8 bits.
constexpr int K_VEC_M_SIZE = K_VEC_SIZE / (ENABLE_4BITS_CACHE ? 2 : 1);
using K_vec_m = typename packed_type<Tcache, K_VEC_M_SIZE>::type;
// Use alignment for safely casting the shared buffers as Qk_vec_k and K_vec_k.
// Shared memory to store Q inputs.
__shared__ __align__(const_max(sizeof(Qk_vec_k), sizeof(K_vec_k))) Tk q_smem[Dh_MAX];
// Make sure the hidden dimension per head is a multiple of the number of threads per value.
static_assert(Dh_MAX % THREADS_PER_VALUE == 0); // trivially satisfied since THREADS_PER_VALUE == Dh_MAX / p
// The number of elements per vector.
constexpr int V_VEC_SIZE = Dh_MAX / THREADS_PER_VALUE;
// A vector of V elements for the current timestep.
using V_vec_k = typename V_vec_k_<T, V_VEC_SIZE>::Type;
// Only used when value cache is quantized to 4 or 8 bits.
constexpr int V_VEC_M_SIZE = V_VEC_SIZE / (ENABLE_4BITS_CACHE ? 2 : 1);
using V_vec_m = typename packed_type<Tcache, V_VEC_M_SIZE>::type;
static_assert(V_VEC_SIZE == sizeof(V_vec_k) / sizeof(T));
// This could be one of the reasons to have a separate kernel for cross attention
// constexpr auto bias_smem_size = 1u; // constexpr auto bias_smem_size = DO_CROSS_ATTENTION ? Dh_MAX : 1u;
// __shared__ __align__(const_max(const_max(sizeof(Qk_vec_k), sizeof(K_vec_k)), sizeof(V_vec_k)))
// Tk bias_smem[bias_smem_size];
// The number of elements per vector.
constexpr unsigned QK_VEC_SIZE{sizeof(Qk_vec_m) / sizeof(T)};
// Make sure the hidden size per head is a multiple of the vector size.
static_assert(Dh_MAX % QK_VEC_SIZE == 0);
// We will use block wide reduction if needed
// The number of vectors per Dh_MAX.
constexpr unsigned QK_VECS_PER_Dh_MAX{Dh_MAX / QK_VEC_SIZE};
static_assert(THREADS_PER_BLOCK >= QK_VECS_PER_Dh_MAX);
// The batch/beam idx
const auto bi = blockIdx.y;
// half *k_scale_quant_orig_ptr = params.k_scale_quant_orig[bi];
// half *v_scale_quant_orig_ptr = params.v_scale_quant_orig[bi];
if (params.finished != nullptr && params.finished[bi])
{
return;
}
// The head.
const unsigned hi{blockIdx.x};
// The head index of keys and values adjusted for MQA/GQA.
const int qhead_per_kv{params.num_heads / params.num_kv_heads};
const unsigned hi_kv{hi / qhead_per_kv};
// The number of heads.
const auto num_heads = static_cast<unsigned>(params.num_heads);
// The number of heads for keys and values adjusted for MQA/GQA.
const auto num_heads_kv = static_cast<unsigned>(params.num_kv_heads);
// The thread in the block.
const unsigned tidx{threadIdx.x};
// The column tile along L dimension on K^T -- noted as T_c in flash-attention paper
const unsigned c_tile{0}; // const unsigned c_tile{MULTI_BLOCK_FLAG ? blockIdx.z : 0};
// Indicate if we need to compute the K/V cache element (add KV bias, IA3, RoPE, etc.) and update the cache.
// For Self-Attention, it's always required.
// For Cross-Attention, as everything is pre-computed,
// in the context phase of the encoder, it's not needed in that kernel.
// Therefore, handle_kv is !DO_CROSS_ATTENTION and irrelevant of timestep.
const bool handle_kv = true; // const bool handle_kv{!DO_CROSS_ATTENTION};
// While doing the product Q*K^T for the different keys we track the max.
float qk_max = -FLT_MAX;
float qk = 0.0F;
// Compute relative attention bias on the fly, with relative attention table [head_num/TP, num_buckets] passed in.
// num_buckets passed as params.relative_attention_bias_stride, max_distance passed as params.max_distance
bool implicit_rel_attn_bias = params.max_distance != 0;
int relative_attention_bias_stride = params.relative_attention_bias_stride; // num_buckets might be modified below, save it beforehand
int max_distance = params.max_distance;
// The actual sequence length excluding the paddings.
// minus 1 because it includes the current timestep while tlength denotes the kv cache length.
// const int tlength = DO_CROSS_ATTENTION
// ? params.memory_length_per_sample[bi] - 1
// : (params.length_per_sample ? (params.length_per_sample[bi] - 1) : static_cast<int>(timestep));
const int tlength = (params.length_per_sample ? (params.length_per_sample[bi] - 1) : static_cast<int>(timestep));
// The context length for beam searching optimization (all points to beam 0).
const int input_length = params.input_lengths[bi];
// The offset in the Q and K buffer also accounts for the batch.
const auto qk_vec_idx = tidx * QK_VEC_SIZE;
const auto is_valid_qk_vec = qk_vec_idx < Dh;
// const bool load_qkv_quant = params.qkv_scale_quant_orig != nullptr;
const bool write_attention_quant = params.attention_out_scale_orig_quant != nullptr;
// Quant/Dequant scales for 8bits kv cache.
using T_scale = typename kv_cache_scale_type_t<T, Tcache>::Type;
T_scale kv_scale_quant_orig[2];
T_scale kv_scale_orig_quant[2];
constexpr int MAX_TIMESTEP_SCALES = SMEM_PRELOAD ? 2048 : 1;
__shared__ half k_scales_history_smem[MAX_TIMESTEP_SCALES], k_zeros_history_smem[MAX_TIMESTEP_SCALES], v_scales_history_smem[MAX_TIMESTEP_SCALES], v_zeros_history_smem[MAX_TIMESTEP_SCALES];
if constexpr (SMEM_PRELOAD)
{
int cur_timestep_idx = threadIdx.x * 8;
Tcache *k_cache_ptr = reinterpret_cast<Tcache *>(kvCacheBuffer.getKBlockPtr(bi, cur_timestep_idx));
half *k_scale_quant_orig_local_ptr = reinterpret_cast<half *>(k_cache_ptr + kvCacheBuffer.mBytesPerSeq);
half *k_zeros_local_ptr = k_scale_quant_orig_local_ptr + kvCacheBuffer.mTokensPerBlock * num_heads_kv;
Tcache *v_cache_ptr = reinterpret_cast<Tcache *>(kvCacheBuffer.getVBlockPtr(bi, cur_timestep_idx));
half *v_scale_quant_orig_local_ptr = reinterpret_cast<half *>(v_cache_ptr + kvCacheBuffer.mBytesPerSeq);
half *v_zeros_local_ptr = v_scale_quant_orig_local_ptr + kvCacheBuffer.mTokensPerBlock * num_heads_kv;
// assume kscales stored as num_heads * num_tokens_per_block
int k_scale_quant_orig_local_index = hi_kv * kvCacheBuffer.mTokensPerBlock + kvCacheBuffer.getLocalIdx(cur_timestep_idx);
// if (cur_timestep_idx < tlength)
// {
// *reinterpret_cast<uint4*>(k_scales_history_smem + cur_timestep_idx) = *(uint4*)(k_scale_quant_orig_local_ptr + k_scale_quant_orig_local_index);
// *reinterpret_cast<uint4*>(k_zeros_history_smem + cur_timestep_idx) = *(uint4*)(k_zeros_local_ptr + k_scale_quant_orig_local_index);
// *reinterpret_cast<uint4*>(v_scales_history_smem + cur_timestep_idx) = *(uint4*)(v_scale_quant_orig_local_ptr + k_scale_quant_orig_local_index);
// *reinterpret_cast<uint4*>(v_zeros_history_smem + cur_timestep_idx) = *(uint4*)(v_zeros_local_ptr + k_scale_quant_orig_local_index);
// }
// else
// {
// *reinterpret_cast<uint4*>(k_scales_history_smem + cur_timestep_idx) = make_uint4(0, 0, 0, 0);
// *reinterpret_cast<uint4*>(k_zeros_history_smem + cur_timestep_idx) = make_uint4(0, 0, 0, 0);
// *reinterpret_cast<uint4*>(v_scales_history_smem + cur_timestep_idx) = make_uint4(0, 0, 0, 0);
// *reinterpret_cast<uint4*>(v_zeros_history_smem + cur_timestep_idx) = make_uint4(0, 0, 0, 0);
// }
bool ld_scale_zero_pred = cur_timestep_idx < tlength;
if (ld_scale_zero_pred)
{
cp_async_launch(k_scales_history_smem + cur_timestep_idx, (uint4*)(k_scale_quant_orig_local_ptr + k_scale_quant_orig_local_index), ld_scale_zero_pred);
cp_async_launch(k_zeros_history_smem + cur_timestep_idx, (uint4*)(k_zeros_local_ptr + k_scale_quant_orig_local_index), ld_scale_zero_pred);
cp_async_launch(v_scales_history_smem + cur_timestep_idx, (uint4*)(v_scale_quant_orig_local_ptr + k_scale_quant_orig_local_index), ld_scale_zero_pred);
cp_async_launch(v_zeros_history_smem + cur_timestep_idx, (uint4*)(v_zeros_local_ptr + k_scale_quant_orig_local_index), ld_scale_zero_pred);
__pipeline_commit();
}
// __pipeline_wait_prior(0);
// __syncthreads();
}
// #pragma unroll
// for (int i = 0; i < 2; i++)
// {
// convert_from_float(&kv_scale_quant_orig[i], (ENABLE_8BITS_CACHE ? params.kv_scale_quant_orig[i] : 1.0f));
// }
// #pragma unroll
// for (int i = 0; i < 2; i++)
// {
// convert_from_float(&kv_scale_orig_quant[i], (ENABLE_8BITS_CACHE ? params.kv_scale_orig_quant[i] : 1.0f));
// }
// Up to QK_VECS_PER_Dh_MAX threads load Q and K + the bias values for the current timestep.
// Trigger the loads from the Q and K buffers.
Qk_vec_k q, k; //, q_bias, k_bias;
zero(q);
zero(k);
// zero(q_bias);
// zero(k_bias);
float rotary_embedding_base = params.rotary_embedding_base;
float rotary_embedding_scale = params.rotary_embedding_scale;
if (is_valid_qk_vec)
{
update_rotary_base_n_scale(rotary_embedding_base, rotary_embedding_scale,
params.rotary_embedding_scale_type, params.rotary_embedding_dim, params.rotary_embedding_max_positions,
tlength);
// Query
// The stride between tokens. We may be able to always use params.stride.
uint32_t q_stride = params.stride ? static_cast<uint32_t>(params.stride) : (num_heads * Dh);
// The offset.
const auto q_offset = flat_index_strided3(bi, hi, qk_vec_idx, q_stride, Dh);
// Note (shang): Load the current qk here. Not the quantized kv cache.
{
// Removed a branch for load_qkv_quant (current step qkv)
q = vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m *>(¶ms.q[q_offset]));
}
{
// Removed DO_CROSS_ATTENTION branch
// Key