diff --git a/third_party/xla/xla/service/gpu/gpu_float_support.cc b/third_party/xla/xla/service/gpu/gpu_float_support.cc
index 1403ad021a217d..5e8a2c59179f0d 100644
--- a/third_party/xla/xla/service/gpu/gpu_float_support.cc
+++ b/third_party/xla/xla/service/gpu/gpu_float_support.cc
@@ -99,10 +99,12 @@ bool GpuFloatSupport::IsSupported(const HloInstruction& hlo) const {
     case HloOpcode::kSubtract:
     case HloOpcode::kMultiply: {
       if (LowPrecisionType() == BF16) {
-        auto* cuda_compute_capability =
-            std::get_if<se::CudaComputeCapability>(&compute_capability_);
-        return cuda_compute_capability != nullptr &&
-               cuda_compute_capability->IsAtLeastHopper();
+        if (std::holds_alternative<se::CudaComputeCapability>(compute_capability_)){
+          return std::get<se::CudaComputeCapability>(compute_capability_).IsAtLeastHopper();
+        }
+        else if (std::holds_alternative<se::RocmComputeCapability>(compute_capability_)){
+          return std::get<se::RocmComputeCapability>(compute_capability_).gfx9_mi200_or_later();
+        }
       }
       return false;
     }
diff --git a/third_party/xla/xla/service/gpu/gpu_float_support_test.cc b/third_party/xla/xla/service/gpu/gpu_float_support_test.cc
index 1d2f6c167bb090..da4d266e697268 100644
--- a/third_party/xla/xla/service/gpu/gpu_float_support_test.cc
+++ b/third_party/xla/xla/service/gpu/gpu_float_support_test.cc
@@ -253,5 +253,11 @@ TEST_F(FloatSupportTest, ShouldKeepBf16OnHopper) {
                     /*should_convert_rhs=*/false, BF16);
 }
 
+TEST_F(FloatSupportTest, ShouldKeepBf16OnMI200orLater) {
+  TestDotConversion(BF16, BF16, F32, se::RocmComputeCapability("gfx940"),
+                    /*should_convert_lhs=*/false,
+                    /*should_convert_rhs=*/false, BF16);
+ }
+
 }  // namespace
 }  // namespace xla::gpu