diff --git a/xla/service/gpu/tests/gpu_fused_mha_test.cc b/xla/service/gpu/tests/gpu_fused_mha_test.cc index 33214758e230f..feff7c3805b9e 100644 --- a/xla/service/gpu/tests/gpu_fused_mha_test.cc +++ b/xla/service/gpu/tests/gpu_fused_mha_test.cc @@ -1545,7 +1545,8 @@ XLA_TEST_F(FlashAttentionBMMScaleSoftmaxBMMF8, custom-call.21.0 = ( f8e4m3fn[4,4,16,16]{3,1,2,0}, f32[1,1,1,1]{3,2,1,0}, - f32[1,1,1,1]{3,2,1,0} + f32[1,1,1,1]{3,2,1,0}, + u8[0]{0} ) custom-call( convert.18, convert.30, @@ -1723,7 +1724,8 @@ XLA_TEST_F(FlashAttentionBMMScaleSoftmaxBMMF8, custom-call.21.0 = ( f8e4m3fn[4,16,4,16]{3,2,1,0}, f32[1,1,1,1]{3,2,1,0}, - f32[1,1,1,1]{3,2,1,0} + f32[1,1,1,1]{3,2,1,0}, + u8[0]{0} ) custom-call( convert.18, convert.30,