diff --git a/cpp/src/glm/qn/mg/standardization.cuh b/cpp/src/glm/qn/mg/standardization.cuh index 8721873655..5fd0b5865e 100644 --- a/cpp/src/glm/qn/mg/standardization.cuh +++ b/cpp/src/glm/qn/mg/standardization.cuh @@ -88,8 +88,7 @@ void vars(const raft::handle_t& handle, T scaled_m = weight * m * m; T diff = v - scaled_m; // avoid negative variance that is due to precision loss of floating point arithmetic - if (diff < 0) { diff += scaled_m; } - return diff; + return diff >= 0. ? diff : v; }, stream); }