From 12b90c9e0c1304979c095ffbef5236d1e9ebd7c8 Mon Sep 17 00:00:00 2001 From: Cong Ma Date: Thu, 7 Dec 2023 02:16:22 +0000 Subject: [PATCH] Update CPU reference 1. Revert the default threshold of relative difference to (100 * std::numeric_limits::epsilon()) 2. Update CPU reference to make the difference between CPU reference and output of contraction instance is less than (100 * std::numeric_limits::epsilon()). --- .../contraction_cpu_reference_impl.hpp | 29 ++++++++++++++----- .../contraction_cpu_reference_instances.cpp | 14 +++++++++ .../configs/bilinear_test_params.yaml | 2 +- .../configs/scale_test_params.yaml | 2 +- test/utils.hpp | 9 +++--- 5 files changed, 42 insertions(+), 14 deletions(-) diff --git a/library/src/contraction/contraction_cpu_reference_impl.hpp b/library/src/contraction/contraction_cpu_reference_impl.hpp index d21df2d3..2e3d0cbe 100644 --- a/library/src/contraction/contraction_cpu_reference_impl.hpp +++ b/library/src/contraction/contraction_cpu_reference_impl.hpp @@ -45,19 +45,25 @@ namespace hiptensor { // hardcoded for NumDimM == NumDimN == NumDimK == 2 + // + // ck::bhalf_t is ushort, cannot perform bhalf_t * bhalf_t + // CK does not use ck::bhalf_t as AccDataType. But we still + // add this guard here template < ck::index_t NumDimM, ck::index_t NumDimN, ck::index_t NumDimK, typename ADataType, typename BDataType, + typename AccDataType, typename DsDataType, typename EDataType, typename AElementwiseOperation, typename BElementwiseOperation, typename CDEElementwiseOperation, typename ComputeDataType = ADataType, - ck::enable_if_t, bool> = false> struct ReferenceContraction_M2_N2_K2 @@ -151,7 +157,7 @@ namespace hiptensor }; auto f_ms_ns = [&](auto m0, auto m1, auto n0, auto n1) { - float accum = 0.0f; + AccDataType accum = 0; auto K0 = arg.mA_ms_ks_lengths[2]; auto K1 = arg.mA_ms_ks_lengths[3]; @@ -165,16 +171,19 @@ namespace hiptensor auto indexB = offset(std::vector{n0, n1, k0, k1}, arg.mB_ns_ks_strides); - ADataType valA; - BDataType valB; + AccDataType valA; + AccDataType valB; // Element-wise ops - arg.mOpA(valA, ((ADataType*)arg.mA)[indexA]); - arg.mOpB(valB, ((BDataType*)arg.mB)[indexB]); + arg.mOpA( + valA, + ck::type_convert(((ADataType*)arg.mA)[indexA])); + arg.mOpB( + valB, + ck::type_convert(((BDataType*)arg.mB)[indexB])); // Mult / accum - accum += ck::type_convert(ck::type_convert( - ck::type_convert(valA) * ck::type_convert(valB))); + accum += valA * valB; } } @@ -322,6 +331,7 @@ namespace hiptensor ck::index_t NumDimsK, typename ADataType, typename BDataType, + typename AccDataType, typename DsDataType, typename EDataType, typename AElementwiseOperation, @@ -333,6 +343,7 @@ namespace hiptensor NumDimsK, ADataType, BDataType, + AccDataType, DsDataType, EDataType, AElementwiseOperation, @@ -359,6 +370,7 @@ namespace hiptensor ck::index_t NumDimK, typename ADataType, typename BDataType, + typename AccDataType, typename DsDataType, typename EDataType, typename AElementwiseOperation, @@ -372,6 +384,7 @@ namespace hiptensor NumDimK, ADataType, BDataType, + AccDataType, DsDataType, EDataType, AElementwiseOperation, diff --git a/library/src/contraction/contraction_cpu_reference_instances.cpp b/library/src/contraction/contraction_cpu_reference_instances.cpp index 173a49e9..31fb0191 100644 --- a/library/src/contraction/contraction_cpu_reference_instances.cpp +++ b/library/src/contraction/contraction_cpu_reference_instances.cpp @@ -39,6 +39,7 @@ namespace hiptensor 2, ck::half_t, ck::half_t, + float, ck::Tuple, ck::half_t, ck::tensor_operation::element_wise::PassThrough, @@ -53,6 +54,7 @@ namespace hiptensor 2, ck::bhalf_t, ck::bhalf_t, + float, ck::Tuple, ck::bhalf_t, ck::tensor_operation::element_wise::PassThrough, @@ -67,6 +69,7 @@ namespace hiptensor 2, float, float, + float, ck::Tuple, float, ck::tensor_operation::element_wise::PassThrough, @@ -80,6 +83,7 @@ namespace hiptensor 2, float, float, + float, ck::Tuple, float, ck::tensor_operation::element_wise::PassThrough, @@ -93,6 +97,7 @@ namespace hiptensor 2, float, float, + float, ck::Tuple, float, ck::tensor_operation::element_wise::PassThrough, @@ -107,6 +112,7 @@ namespace hiptensor 2, double, double, + float, ck::Tuple, double, ck::tensor_operation::element_wise::PassThrough, @@ -120,6 +126,7 @@ namespace hiptensor 2, double, double, + double, ck::Tuple, double, ck::tensor_operation::element_wise::PassThrough, @@ -134,6 +141,7 @@ namespace hiptensor 2, ck::half_t, ck::half_t, + float, ck::Tuple<>, ck::half_t, ck::tensor_operation::element_wise::PassThrough, @@ -148,6 +156,7 @@ namespace hiptensor 2, ck::bhalf_t, ck::bhalf_t, + float, ck::Tuple<>, ck::bhalf_t, ck::tensor_operation::element_wise::PassThrough, @@ -162,6 +171,7 @@ namespace hiptensor 2, float, float, + float, ck::Tuple<>, float, ck::tensor_operation::element_wise::PassThrough, @@ -175,6 +185,7 @@ namespace hiptensor 2, float, float, + float, ck::Tuple<>, float, ck::tensor_operation::element_wise::PassThrough, @@ -188,6 +199,7 @@ namespace hiptensor 2, float, float, + float, ck::Tuple<>, float, ck::tensor_operation::element_wise::PassThrough, @@ -202,6 +214,7 @@ namespace hiptensor 2, double, double, + float, ck::Tuple<>, double, ck::tensor_operation::element_wise::PassThrough, @@ -215,6 +228,7 @@ namespace hiptensor 2, double, double, + double, ck::Tuple<>, double, ck::tensor_operation::element_wise::PassThrough, diff --git a/test/01_contraction/configs/bilinear_test_params.yaml b/test/01_contraction/configs/bilinear_test_params.yaml index eee5d7f1..f4be1a88 100644 --- a/test/01_contraction/configs/bilinear_test_params.yaml +++ b/test/01_contraction/configs/bilinear_test_params.yaml @@ -29,7 +29,7 @@ Betas: Lengths: - [ 5, 6, 3, 4, 3, 4 ] - [ 4, 3, 4, 3, 6, 5 ] - - [ 24, 18, 2, 4, 9, 1 ] + - [ 24, 18, 2, 4, 9, 2 ] Strides: - [] ... diff --git a/test/01_contraction/configs/scale_test_params.yaml b/test/01_contraction/configs/scale_test_params.yaml index eee5d7f1..f4be1a88 100644 --- a/test/01_contraction/configs/scale_test_params.yaml +++ b/test/01_contraction/configs/scale_test_params.yaml @@ -29,7 +29,7 @@ Betas: Lengths: - [ 5, 6, 3, 4, 3, 4 ] - [ 4, 3, 4, 3, 6, 5 ] - - [ 24, 18, 2, 4, 9, 1 ] + - [ 24, 18, 2, 4, 9, 2 ] Strides: - [] ... diff --git a/test/utils.hpp b/test/utils.hpp index f39f0fb5..ad4bb565 100644 --- a/test/utils.hpp +++ b/test/utils.hpp @@ -140,7 +140,7 @@ template std::pair compareEqual(DDataType const* deviceD, DDataType const* hostD, std::size_t elementsD, - double tolerance = 0.001) + double tolerance = 100.0) { bool retval = true; double max_relative_error = 0.0; @@ -202,7 +202,7 @@ std::pair compareEqual(DDataType const* deviceD, retval = false; max_relative_error = std::numeric_limits::signaling_NaN(); } - else if(max_relative_error > tolerance) + else if(max_relative_error > (eps * tolerance)) { retval = false; } @@ -214,7 +214,7 @@ template std::pair compareEqualLaunchKernel(DDataType* deviceD, DDataType* hostD, std::size_t elementsD, - double tolerance = 0.001) + double tolerance = 100.0) { auto blockDim = dim3(1024, 1, 1); auto gridDim = dim3(ceilDiv(elementsD, blockDim.x), 1, 1); @@ -276,12 +276,13 @@ std::pair compareEqualLaunchKernel(DDataType* deviceD, auto toDouble = [](DDataType const& val) { return static_cast(static_cast(val)); }; + auto eps = toDouble(std::numeric_limits::epsilon()); if(isNaN) { retval = false; maxRelativeError = std::numeric_limits::signaling_NaN(); } - else if(maxRelativeError > tolerance) + else if(maxRelativeError > (eps * tolerance)) { retval = false; }