
Commit 0ffdc94

Merge remote-tracking branch 'upstream/main' into lwilkinson/upstream-sync-4
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
2 parents d637d89 + d836a6b commit 0ffdc94

File tree: 64 files changed, +485 −217 lines


csrc/cutlass

Submodule cutlass updated 349 files

flash_attn/modules/mha.py

Lines changed: 10 additions & 37 deletions
@@ -23,9 +23,9 @@
 flash_attn_with_kvcache = None

 try:
-    from flash_attn.ops.fused_dense import ColumnParallelLinear, FusedDense, RowParallelLinear
+    from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
 except ImportError:
-    FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None
+    ColumnParallelLinear, RowParallelLinear = None, None, None

 try:
     from flash_attn.layers.rotary import RotaryEmbedding
@@ -341,13 +341,6 @@ def forward(self, q, kv, causal=None, key_padding_mask=None):
         return output


-class LinearResidual(nn.Linear):
-    """Wrap nn.Linear to return the residual as well. For compatibility with FusedDense."""
-
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        return super().forward(input), input
-
-
 def _update_kv_cache(kv, inference_params, layer_idx):
     """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
     # Pre-allocate memory for key-values for inference.
@@ -452,13 +445,6 @@ def __init__(
                 device=device,
             )

-        if fused_bias_fc and FusedDense is None:
-            raise ImportError("fused_dense is not installed")
-        linear_cls = nn.Linear if not fused_bias_fc else FusedDense
-        linear_resid_cls = (
-            LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True)
-        )
-        wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
         inner_attn_cls = (
             partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size)
             if use_flash_attn
@@ -470,10 +456,10 @@ def __init__(
             else CrossAttention
         )
         if not self.cross_attn:
-            self.Wqkv = wqkv_cls(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs)
+            self.Wqkv = nn.Linear(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs)
         else:
-            self.Wq = linear_cls(embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs)
-            self.Wkv = wqkv_cls(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs)
+            self.Wq = nn.Linear(embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs)
+            self.Wkv = nn.Linear(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs)
         if self.dwconv:
             if self.num_heads_kv == self.num_heads:
                 self.dwconv_qkv = nn.Conv1d(
@@ -492,7 +478,7 @@ def __init__(
         self.inner_cross_attn = inner_cross_attn_cls(
             causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
         )
-        self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs)
+        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs)

     def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
         dtype = self.out_proj.weight.dtype if dtype is None else dtype
@@ -646,10 +632,7 @@ def forward(
         batch, seqlen = x.shape[:2]
         if not self.cross_attn and self.num_heads_kv == self.num_heads:
             assert x_kv is None and mixer_subset is None
-            if not self.return_residual:
-                qkv = self.Wqkv(x)
-            else:
-                qkv, x = self.Wqkv(x)
+            qkv = self.Wqkv(x)
             if self.dwconv:
                 qkv = rearrange(
                     self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
@@ -680,21 +663,11 @@ def forward(
                 )
         else:
             if self.cross_attn:
-                if not self.return_residual:
-                    q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
-                    kv = self.Wkv(x_kv if x_kv is not None else x)
-                else:
-                    if x_kv is not None:
-                        kv, x_kv = self.Wkv(x_kv)
-                    else:
-                        kv, x = self.Wkv(x)
-                    q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
+                q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
+                kv = self.Wkv(x_kv if x_kv is not None else x)
             else:
                 assert self.num_heads_kv != self.num_heads
-                if not self.return_residual:
-                    qkv = self.Wqkv(x)
-                else:
-                    qkv, x = self.Wqkv(x)
+                qkv = self.Wqkv(x)
                 q = qkv[..., : self.num_heads * self.head_dim]
                 kv = qkv[..., self.num_heads * self.head_dim :]
                 q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
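
For orientation, a minimal sketch (hypothetical sizes, not part of the diff) of the projection path this change keeps: Wqkv is now always a plain nn.Linear, and its output is sliced into q / kv exactly as in the forward() hunk above.

    import torch
    import torch.nn as nn
    from einops import rearrange

    # Hypothetical illustrative dimensions.
    embed_dim, num_heads, num_heads_kv, head_dim = 256, 8, 2, 32
    qkv_dim = head_dim * (num_heads + 2 * num_heads_kv)

    Wqkv = nn.Linear(embed_dim, qkv_dim, bias=True)   # plain Linear instead of FusedDense / LinearResidual
    x = torch.randn(4, 16, embed_dim)                 # (batch, seqlen, embed_dim)
    qkv = Wqkv(x)
    q = qkv[..., : num_heads * head_dim]
    kv = qkv[..., num_heads * head_dim :]
    q = rearrange(q, "... (h d) -> ... h d", d=head_dim)
    kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=head_dim)
    print(q.shape, kv.shape)                          # (4, 16, 8, 32) and (4, 16, 2, 2, 32)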

flash_attn/ops/fused_dense.py

Lines changed: 1 addition & 1 deletion
@@ -11,9 +11,9 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from torch import Tensor
-from torch.cuda.amp import custom_bwd, custom_fwd
 from torch.distributed import ProcessGroup

+from flash_attn.utils.torch import custom_fwd, custom_bwd
 from flash_attn.ops.activations import gelu_bwd, relu_bwd, sqrelu_bwd, sqrelu_fwd
 from flash_attn.utils.distributed import (
     all_gather_raw,

flash_attn/ops/triton/layer_norm.py

Lines changed: 8 additions & 4 deletions
@@ -10,11 +10,13 @@

 import torch
 import torch.nn.functional as F
-from torch.cuda.amp import custom_fwd, custom_bwd

 import triton
 import triton.language as tl

+from flash_attn.utils.torch import custom_fwd, custom_bwd
+
+
 def triton_autotune_configs():
     # Return configs with a valid warp count for the current device
     configs=[]
@@ -635,7 +637,9 @@ def _layer_norm_bwd(
     BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
     if N > BLOCK_N:
         raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
-    sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
+    # Increasing the multiple (e.g. 8) will allow more thread blocks to be launched and hide the
+    # latency of the gmem reads/writes, but will increase the time of summing up dw / db.
+    sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count * 8
     _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
     _db = (
         torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
@@ -1018,12 +1022,12 @@ def forward(
             norm_bias,
             eps,
             residual,
-            out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype(),
+            out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_dtype("cuda"),
             residual_dtype=residual_dtype,
             is_rms_norm=is_rms_norm,
         )
         y = y.reshape(x_shape_og)
-        dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype
+        dtype = torch.get_autocast_dtype("cuda") if torch.is_autocast_enabled() else y.dtype
         linear_weight = linear_weight.to(dtype)
         linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
         out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
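
Side note on the autocast change above: a small compatibility sketch (not part of the diff) of the device-generic query that replaces the deprecated torch.get_autocast_gpu_dtype(); it mirrors the out_dtype expression in the hunk.

    import torch

    def autocast_out_dtype():
        # Same expression as the out_dtype argument in the hunk above.
        return None if not torch.is_autocast_enabled() else torch.get_autocast_dtype("cuda")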

flash_attn/ops/triton/mlp.py

Lines changed: 1 addition & 1 deletion
@@ -4,8 +4,8 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.cuda.amp import custom_bwd, custom_fwd

+from flash_attn.utils.torch import custom_fwd, custom_bwd
 from flash_attn.ops.activations import sqrelu_bwd, sqrelu_fwd
 from flash_attn.ops.triton.linear import triton_dgrad_act, triton_linear_act

flash_attn/ops/triton/rotary.py

Lines changed: 4 additions & 3 deletions
@@ -38,8 +38,8 @@ def rotary_kernel(
     BLOCK_M: tl.constexpr,
 ):
     pid_m = tl.program_id(axis=0)
-    pid_batch = tl.program_id(axis=1)
-    pid_head = tl.program_id(axis=2)
+    pid_head = tl.program_id(axis=1)
+    pid_batch = tl.program_id(axis=2)
     rotary_dim_half = rotary_dim // 2

     if not IS_VARLEN:
@@ -193,7 +193,7 @@ def apply_rotary(
         if rotary_dim <= 32
         else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256))
     )
-    grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads)  # noqa
+    grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), nheads, batch)  # noqa
     BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 128 else 4)

     # Need this, otherwise Triton tries to launch from cuda:0 and we get
@@ -223,5 +223,6 @@
         interleaved,
         conjugate,
         BLOCK_M,
+        num_warps=2 if rotary_dim <= 64 else 4,
     )
     return output
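
A short sketch (illustrative values only) of the launch-grid convention after this reordering: axis 0 indexes seqlen blocks, axis 1 the head, and axis 2 the batch, matching the tl.program_id reads in rotary_kernel above.

    import triton

    seqlen, nheads, batch, BLOCK_M = 4096, 16, 2, 8
    grid = (triton.cdiv(seqlen, BLOCK_M), nheads, batch)   # (512, 16, 2)
    # Inside the kernel: program_id(0) -> seqlen block, program_id(1) -> head, program_id(2) -> batch.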

flash_attn/utils/torch.py

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+import torch
+from typing import Callable
+
+
+def custom_amp_decorator(dec: Callable, cuda_amp_deprecated: bool):
+    def decorator(*args, **kwargs):
+        if cuda_amp_deprecated:
+            kwargs["device_type"] = "cuda"
+        return dec(*args, **kwargs)
+    return decorator
+
+
+if hasattr(torch.amp, "custom_fwd"):  # type: ignore[attr-defined]
+    deprecated = True
+    from torch.amp import custom_fwd, custom_bwd  # type: ignore[attr-defined]
+else:
+    deprecated = False
+    from torch.cuda.amp import custom_fwd, custom_bwd
+
+custom_fwd = custom_amp_decorator(custom_fwd, deprecated)
+custom_bwd = custom_amp_decorator(custom_bwd, deprecated)
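
A usage sketch for this new shim (the _ScaledMatmul op below is hypothetical): applied bare, the wrapped decorators behave like torch.cuda.amp.custom_fwd/custom_bwd on older PyTorch and like torch.amp.custom_fwd/custom_bwd with device_type="cuda" on releases where the torch.cuda.amp variants are deprecated, which is how the other files in this commit consume them.

    import torch
    from flash_attn.utils.torch import custom_fwd, custom_bwd

    class _ScaledMatmul(torch.autograd.Function):
        @staticmethod
        @custom_fwd
        def forward(ctx, x, weight):
            # custom_fwd / custom_bwd keep autocast state consistent across forward and backward.
            ctx.save_for_backward(x, weight)
            return x @ weight.t()

        @staticmethod
        @custom_bwd
        def backward(ctx, grad_out):
            x, weight = ctx.saved_tensors
            # 2-D inputs assumed for brevity.
            return grad_out @ weight, grad_out.t() @ x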

hopper/benchmark_mla_decode.py

Lines changed: 11 additions & 10 deletions
@@ -36,15 +36,15 @@

 use_bench_cudagraph = False

-attn_variants = ["mha", "gqa", "mqa", "mla"]
-for attn_variant in attn_variants:
-# for attn_variant in attn_variants[3:]:
-    nheads_kv = nheads_q if attn_variant == "mha" else (max(nheads_q // 8, 1) if attn_variant == "gqa" else 1)
-    headdim = 64 if attn_variant == "mla" else 128
-    headdim_v = 512 if attn_variant == "mla" else headdim
-    has_qv = headdim == 64 and headdim_v == 512
+attn_variants = ["mha", "gqa", "mqa", "mla", "gla"]
+# for attn_variant in attn_variants:
+for attn_variant in attn_variants[3:5]:
+    nheads_kv = nheads_q if attn_variant == "mha" else (max(nheads_q // 8, 1) if attn_variant == "gqa" else (1 if attn_variant == "mla" else 2))
+    headdim = 64 if attn_variant in ["mla", "gla"] else 128
+    headdim_v = 512 if attn_variant == "mla" else (256 if attn_variant == "gla" else headdim)
+    has_qv = headdim == 64 and headdim_v > 64
     # page_size = None
-    page_size = 64 if attn_variant == "mla" else 128
+    page_size = 64 if attn_variant in ["mla", "gla"] else 128

     should_run_flashmla = attn_variant == "mla" and page_size == 64 and flash_mla_with_kvcache is not None

@@ -60,7 +60,7 @@
     print(f"\n{attn_variant.upper()}, nheads_q = {nheads_q}, nheads_kv = {nheads_kv}, headdim = {headdim}, headdim_v = {headdim_v}, page_size = {page_size}")

     for seqlen in [s * 1024 for s in [1, 2, 4, 8, 16, 32, 64]]:
-    # for seqlen in [s * 1024 for s in [1]]:
+    # for seqlen in [s * 1024 for s in [8]]:
         cache_seqlens = torch.tensor([seqlen] * batch_size, device=device, dtype=torch.int)
         num_splits = 0
         q = torch.randn(batch_size, seqlen_q, nheads_q, headdim, dtype=dtype, device=device)
@@ -84,6 +84,7 @@
             cache_seqlens, q.dtype, headdim_v=headdim_v, page_size=page_size, causal=True
         )
         # scheduler_metadata = None
+        # breakpoint()
         fn0 = lambda: flash_attn_with_kvcache(q, k_cache, v_cache, cache_seqlens=cache_seqlens, num_splits=num_splits, qv=qv, page_table=page_table, causal=True, scheduler_metadata=scheduler_metadata)
         time.sleep(1)  # to avoid power throttling
         # Time in ms
@@ -109,7 +110,7 @@
                 t1 = do_bench_cudagraph(fn1, rep=10)

         total_seqlen = seqlen * batch_size if cache_seqlens is None else cache_seqlens.sum().item()
-        mem_io = total_seqlen * nheads_kv * (headdim + headdim_v) * 2 + q.numel() * 2 + (qv.numel() * 2 if has_qv else 0) + q.numel() * headdim_v // headdim * 2  # last time is for the output
+        mem_io = total_seqlen * nheads_kv * (headdim + headdim_v) * 2 + q.numel() * 2 + (qv.numel() * 2 if has_qv else 0) + q.numel() * headdim_v // headdim * 2  # last term is for the output
         flops = seqlen_q * total_seqlen * nheads_q * (headdim + headdim_v * (2 if has_qv else 1)) * 2
         ideal_h100_time_mem = mem_io / 3.35e12 * 1e6
         ideal_h100_time_flop = flops / 989e12 * 1e6
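
For reference, a worked sketch (assuming a hypothetical nheads_q = 16) of the shapes the benchmark now sweeps for the two variants it actually iterates over (attn_variants[3:5]), mirroring the expressions in the hunk above.

    nheads_q = 16
    for attn_variant in ["mla", "gla"]:
        nheads_kv = 1 if attn_variant == "mla" else 2
        headdim = 64                                       # both variants use headdim 64
        headdim_v = 512 if attn_variant == "mla" else 256
        page_size = 64
        print(attn_variant, nheads_kv, headdim, headdim_v, page_size)
    # mla 1 64 512 64
    # gla 2 64 256 64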

hopper/flash.h

Lines changed: 1 addition & 0 deletions
@@ -112,6 +112,7 @@ struct Flash_fwd_params : public Qkv_params {
     // The cos and sin matrices for rotary embedding.
     void * __restrict__ rotary_cos_ptr;
     void * __restrict__ rotary_sin_ptr;
+    int *__restrict__ seqlens_rotary;

     // The indices to index into the KV cache.
     int * __restrict__ kv_batch_idx;

hopper/flash_api.cpp

Lines changed: 31 additions & 9 deletions
@@ -272,10 +272,11 @@ void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     if (params.is_bf16) {
         #ifndef FLASHATTENTION_DISABLE_HDIM64
         if (params.d <= 64) {
-            if (params.dv > 64 && Arch == 90) {
+            if (params.dv > 256 && Arch == 90) {
                 return run_mha_fwd_<Arch, cutlass::bfloat16_t, 64, 512, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
-            }
-            else {
+            } else if (params.dv > 64 && Arch == 90) {
+                return run_mha_fwd_<Arch, cutlass::bfloat16_t, 64, 256, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
+            } else {
                 return run_mha_fwd_<Arch, cutlass::bfloat16_t, 64, 64, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
             }
         }
@@ -302,10 +303,11 @@ void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
         #ifndef FLASHATTENTION_DISABLE_FP16
         #ifndef FLASHATTENTION_DISABLE_HDIM64
         if (params.d <= 64) {
-            if (params.dv > 64 && Arch == 90) {
+            if (params.dv > 256 && Arch == 90) {
                 return run_mha_fwd_<Arch, cutlass::half_t, 64, 512, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
-            }
-            else {
+            } else if (params.dv > 64 && Arch == 90) {
+                return run_mha_fwd_<Arch, cutlass::half_t, 64, 256, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
+            } else {
                 return run_mha_fwd_<Arch, cutlass::half_t, 64, 64, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
             }
         }
@@ -490,6 +492,15 @@ inline int round_up_headdim(int head_size) {
     return 256;
 }

+inline int round_up_headdimv(int head_size) {
+    if (head_size <= 64) { return 64; }
+    if (head_size <= 96) { return 96; }
+    if (head_size <= 128) { return 128; }
+    if (head_size <= 192) { return 192; }
+    if (head_size <= 256) { return 256; }
+    return 512;
+}
+
 // Only applicable to the case where seqused_k (i.e. cache_seqlens) is available
 at::Tensor
 mha_fwd_get_scheduler_metadata(
@@ -534,7 +545,7 @@ mha_fwd_get_scheduler_metadata(
     params.d = headdim;
     params.dv = headdim_v;
     params.d_rounded = round_up_headdim(headdim);
-    params.dv_rounded = round_up_headdim(headdim_v);
+    params.dv_rounded = headdim_v == headdim ? params.d_rounded : round_up_headdimv(headdim_v);
     params.seqlen_knew = max_seqlen_k_new;

     bool const is_varlen_q = cu_seqlens_q_.has_value();
@@ -640,6 +651,7 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq
         std::optional<const at::Tensor> &leftpad_k_, // b
         std::optional<const at::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2)
         std::optional<const at::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2)
+        std::optional<const at::Tensor> &seqlens_rotary_, // b
         std::optional<at::Tensor> &q_descale_, // (b, h_k), not (b, h)
         std::optional<at::Tensor> &k_descale_, // (b, h_k)
         std::optional<at::Tensor> &v_descale_, // (b, h_k)
@@ -823,7 +835,7 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq

     auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
     int const head_size_rounded = round_up_headdim(head_size);
-    int const head_size_v_rounded = round_up_headdim(head_size_v);
+    int const head_size_v_rounded = head_size_v == head_size ? head_size_rounded : round_up_headdimv(head_size_v);
     int const seqlen_q_rounded = round_multiple(seqlen_q, 128);
     int const seqlen_k_rounded = round_multiple(seqlen_k, 128);

@@ -1001,6 +1013,13 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq
         params.rotary_cos_ptr = rotary_cos.data_ptr();
         params.rotary_sin_ptr = rotary_sin.data_ptr();
         params.is_rotary_interleaved = is_rotary_interleaved;
+        if (seqlens_rotary_.has_value()) {
+            at::Tensor seqlens_rotary = seqlens_rotary_.value();
+            CHECK_DEVICE(seqlens_rotary); CHECK_CONTIGUOUS(seqlens_rotary);
+            TORCH_CHECK(seqlens_rotary.dtype() == torch::kInt32, "seqlens_rotary must have dtype torch.int32");
+            CHECK_SHAPE(seqlens_rotary, batch_size);
+            params.seqlens_rotary = seqlens_rotary.data_ptr<int>();
+        }
     } else {
         params.rotary_dim = 0;
     }
@@ -1104,7 +1123,11 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq
             // params.b = 1;
             // params.seqlen_q = total_q;
             // }
+            // This will zero out the semaphore if needed
             run_mha_fwd_combine(params, stream, true /*enable_pdl*/);
+        } else if (scheduler_needs_semaphore && params.skip_scheduler_metadata_computation) {
+            // need to zero out the semaphore in this case
+            tile_count_semaphore.index({torch::indexing::Slice(0, 1)}).zero_();
         }
     } else if (total_q > 0 && num_heads_k > 0) {
         // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
@@ -1492,7 +1515,6 @@ mha_combine(const at::Tensor &out_partial, // num_splits x batch_size x
     const int seqlen = sizes[2];
     const int num_heads = sizes[3];
     const int head_size_og = sizes[4];
-    TORCH_CHECK(head_size_og <= 512, "FlashAttention combine only supports head dimension at most 512");
     TORCH_CHECK(num_splits <= 256, "FlashAttention combine only supports num_splits at most 256");

     CHECK_SHAPE(out_partial, num_splits, batch_size, seqlen, num_heads, head_size_og);
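
To make the new rounding rule concrete, here is a Python mirror (a sketch only, not the C++ API) of round_up_headdimv and the dv_rounded selection from the hunks above; for example, headdim = 64 with headdim_v = 512 now rounds dv to 512 rather than being capped at 256 by round_up_headdim.

    def round_up_headdimv(head_size: int) -> int:
        for bound in (64, 96, 128, 192, 256):
            if head_size <= bound:
                return bound
        return 512

    def dv_rounded(headdim: int, headdim_v: int, d_rounded: int) -> int:
        # dv reuses the already-rounded d when the head dims match.
        return d_rounded if headdim_v == headdim else round_up_headdimv(headdim_v)

    assert round_up_headdimv(512) == 512
    assert dv_rounded(64, 512, 64) == 512
    assert dv_rounded(128, 128, 128) == 128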
