@@ -5,21 +5,19 @@ namespace torch_tensorrt {
55namespace tests {
66namespace util {
77
8- bool checkRtol (const at::Tensor& diff, const std::vector<at::Tensor> inputs, float threshold) {
9- double maxValue = 0.0 ;
10- for (auto & tensor : inputs) {
11- maxValue = fmax (tensor.abs ().max ().item <float >(), maxValue);
12- }
13- std::cout << " Max Difference: " << diff.abs ().max ().item <float >() << std::endl;
14- std::cout << " Acceptable Threshold: " << threshold << std::endl;
15- return diff.abs ().max ().item <float >() <= threshold * maxValue;
16- }
178
18- bool almostEqual (const at::Tensor& a, const at::Tensor& b, float threshold) {
9+ bool almostEqual (const at::Tensor& a, const at::Tensor& b, float threshold, float atol= 1e-8 , float rtol= 1e-5 ) {
1910 LOG_GRAPH (a << std::endl << b << std::endl);
2011 auto a_float = a.toType (at::kFloat );
2112 auto b_float = b.toType (at::kFloat );
22- return checkRtol (a_float - b_float, {a_float, b_float}, threshold);
13+
14+ auto diff = a_float - b_float;
15+ auto result = diff.abs ().max ().item <float >() - (atol + rtol * b.abs ().max ().item <float >());
16+
17+ std::cout << " Max Difference: " << result << std::endl;
18+ std::cout << " Acceptable Threshold: " << threshold << std::endl;
19+
20+ return result <= threshold;
2321}
2422
2523bool exactlyEqual (const at::Tensor& a, const at::Tensor& b) {
0 commit comments