@@ -157,17 +157,14 @@ mha_fwd_sparse(at::Tensor &q, // batch_size x seqlen_q x num_heads x hea
                std::optional<at::Generator> gen_) {
 
     auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
-    bool is_sm8x = cc_major == 8 && cc_minor >= 0;
-    bool is_sm90 = cc_major == 9 && cc_minor == 0;
-    TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
-    // We will support Turing in the near future
-    // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");
+    bool is_sm8x_min = cc_major >= 8;
+    TORCH_CHECK(is_sm8x_min, "FlashAttention only supports Ampere GPUs or newer.");
 
     auto q_dtype = q.dtype();
     TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
                 "FlashAttention only support fp16 and bf16 data type");
     if (q_dtype == torch::kBFloat16) {
-        TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
+        TORCH_CHECK(is_sm8x_min, "bfloat16 is only supported on Ampere GPUs or newer");
     }
     TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
     TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
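
Reviewer note: the old gate accepted exactly SM 8.x and SM 9.0, while the new `cc_major >= 8` check also admits SM 9.x parts with a nonzero minor revision and any future major architecture. A minimal standalone sketch contrasting the two predicates (the function names and `main` harness are illustrative only, not part of this patch):

```cpp
// Sketch: old vs. new compute-capability gate (illustrative, not part of the patch).
#include <cstdio>

// Old gate: exactly SM 8.x or SM 9.0.
static bool old_gate(int cc_major, int cc_minor) {
    bool is_sm8x = cc_major == 8 && cc_minor >= 0;
    bool is_sm90 = cc_major == 9 && cc_minor == 0;
    return is_sm90 || is_sm8x;
}

// New gate: any SM 8.0+ architecture, including future majors.
static bool new_gate(int cc_major, int /*cc_minor*/) {
    return cc_major >= 8;
}

int main() {
    const int ccs[][2] = {{7, 5}, {8, 0}, {8, 6}, {9, 0}, {9, 1}, {10, 0}};
    for (const auto &cc : ccs) {
        std::printf("sm_%d%d: old=%d new=%d\n",
                    cc[0], cc[1], old_gate(cc[0], cc[1]), new_gate(cc[0], cc[1]));
    }
    return 0;
}
```

Both gates agree on sm_80, sm_86, and sm_90 and reject sm_75; they differ only on SM 9.x with x > 0 and on future majors (e.g. sm_100), which the simplified check accepts.
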
@@ -342,17 +339,14 @@ mha_varlen_fwd_sparse(at::Tensor &q, // total_q x num_heads x head_size, total_
                       std::optional<at::Generator> gen_) {
 
     auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
-    bool is_sm8x = cc_major == 8 && cc_minor >= 0;
-    bool is_sm90 = cc_major == 9 && cc_minor == 0;
-    TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
-    // We will support Turing in the near future
-    // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");
+    bool is_sm8x_min = cc_major >= 8;
+    TORCH_CHECK(is_sm8x_min, "FlashAttention only supports Ampere GPUs or newer.");
 
     auto q_dtype = q.dtype();
     TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
                 "FlashAttention only support fp16 and bf16 data type");
     if (q_dtype == torch::kBFloat16) {
-        TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
+        TORCH_CHECK(is_sm8x_min, "bfloat16 is only supported on Ampere GPUs or newer");
     }
     TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
     TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
@@ -528,4 +522,4 @@ mha_varlen_fwd_sparse(at::Tensor &q, // total_q x num_heads x head_size, total_
     return {out, softmax_lse};
 }
 
-} // namespace FLASH_NAMESPACE
+} // namespace FLASH_NAMESPACE