Commit 2c87db5

Update flash_api_sparse.cpp to support SM_120 (#73)
1 parent 1c2624e commit 2c87db5

File tree

1 file changed: +7 −13 lines changed

csrc/flash_attn/flash_api_sparse.cpp

Lines changed: 7 additions & 13 deletions
@@ -157,17 +157,14 @@ mha_fwd_sparse(at::Tensor &q, // batch_size x seqlen_q x num_heads x hea
                std::optional<at::Generator> gen_) {
 
     auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
-    bool is_sm8x = cc_major == 8 && cc_minor >= 0;
-    bool is_sm90 = cc_major == 9 && cc_minor == 0;
-    TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
-    // We will support Turing in the near future
-    // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");
+    bool is_sm8x_min = cc_major >= 8;
+    TORCH_CHECK(is_sm8x_min, "FlashAttention only supports Ampere GPUs or newer.");
 
     auto q_dtype = q.dtype();
     TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
                 "FlashAttention only support fp16 and bf16 data type");
     if (q_dtype == torch::kBFloat16) {
-        TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
+        TORCH_CHECK(is_sm8x_min, "bfloat16 is only supported on Ampere GPUs or newer");
     }
     TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
     TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
@@ -342,17 +339,14 @@ mha_varlen_fwd_sparse(at::Tensor &q, // total_q x num_heads x head_size, total_
                std::optional<at::Generator> gen_) {
 
     auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
-    bool is_sm8x = cc_major == 8 && cc_minor >= 0;
-    bool is_sm90 = cc_major == 9 && cc_minor == 0;
-    TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
-    // We will support Turing in the near future
-    // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");
+    bool is_sm8x_min = cc_major >= 8;
+    TORCH_CHECK(is_sm8x_min, "FlashAttention only supports Ampere GPUs or newer.");
 
     auto q_dtype = q.dtype();
     TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
                 "FlashAttention only support fp16 and bf16 data type");
     if (q_dtype == torch::kBFloat16) {
-        TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
+        TORCH_CHECK(is_sm8x_min, "bfloat16 is only supported on Ampere GPUs or newer");
    }
     TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
     TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
@@ -528,4 +522,4 @@ mha_varlen_fwd_sparse(at::Tensor &q, // total_q x num_heads x head_size, total_
     return {out, softmax_lse};
 }
 
-} // namespace FLASH_NAMESPACE
+} // namespace FLASH_NAMESPACE
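
In plain terms, the commit replaces the old architecture allow-list (SM 8.x, or exactly SM 9.0) with a single lower bound on the major compute capability, so SM_120 devices pass the same TORCH_CHECK gates. Below is a minimal standalone sketch of that gate written against the CUDA runtime API; it does not use the repo's get_compute_capability()/get_current_device() helpers, and the printed message is illustrative only.

#include <cuda_runtime.h>
#include <cstdio>

int main() {
    int device = 0;
    cudaGetDevice(&device);

    // Query the major/minor compute capability of the active device.
    int cc_major = 0, cc_minor = 0;
    cudaDeviceGetAttribute(&cc_major, cudaDevAttrComputeCapabilityMajor, device);
    cudaDeviceGetAttribute(&cc_minor, cudaDevAttrComputeCapabilityMinor, device);

    // Old gate: (cc_major == 8) || (cc_major == 9 && cc_minor == 0)
    //   -> rejects sm_120 even when the kernels are compiled for it.
    // New gate from this commit: any Ampere-or-newer major version.
    bool is_sm8x_min = cc_major >= 8;

    std::printf("sm_%d%d %s the Ampere-or-newer check\n",
                cc_major, cc_minor, is_sm8x_min ? "passes" : "fails");
    return 0;
}

Because the new predicate is an open-ended lower bound rather than an enumeration, it admits SM_90, SM_100, SM_120, and later major versions without further edits to these checks.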
