From 82d062892477515a1a6038f0001b7b9749f1afae Mon Sep 17 00:00:00 2001 From: Antoine Prouvost Date: Wed, 22 Feb 2023 15:41:17 +0100 Subject: [PATCH] Add checked numeric cast (#2315) * Add test for util_compare * Add a checked numeric cast --- libmamba/CMakeLists.txt | 1 + libmamba/include/mamba/core/util_cast.hpp | 117 ++++++++++++++++++++++ libmamba/tests/CMakeLists.txt | 2 + libmamba/tests/test_util_cast.cpp | 94 +++++++++++++++++ libmamba/tests/test_util_compare.cpp | 87 ++++++++++++++++ 5 files changed, 301 insertions(+) create mode 100644 libmamba/include/mamba/core/util_cast.hpp create mode 100644 libmamba/tests/test_util_cast.cpp create mode 100644 libmamba/tests/test_util_compare.cpp diff --git a/libmamba/CMakeLists.txt b/libmamba/CMakeLists.txt index 0759ea9159..8d059cf3e8 100644 --- a/libmamba/CMakeLists.txt +++ b/libmamba/CMakeLists.txt @@ -228,6 +228,7 @@ set(LIBMAMBA_HEADERS ${LIBMAMBA_INCLUDE_DIR}/mamba/core/util_random.hpp ${LIBMAMBA_INCLUDE_DIR}/mamba/core/util_scope.hpp ${LIBMAMBA_INCLUDE_DIR}/mamba/core/util_string.hpp + ${LIBMAMBA_INCLUDE_DIR}/mamba/core/util_cast.hpp ${LIBMAMBA_INCLUDE_DIR}/mamba/core/util_compare.hpp ${LIBMAMBA_INCLUDE_DIR}/mamba/core/validate.hpp ${LIBMAMBA_INCLUDE_DIR}/mamba/core/virtual_packages.hpp diff --git a/libmamba/include/mamba/core/util_cast.hpp b/libmamba/include/mamba/core/util_cast.hpp new file mode 100644 index 0000000000..e7ab60da69 --- /dev/null +++ b/libmamba/include/mamba/core/util_cast.hpp @@ -0,0 +1,117 @@ +// Copyright (c) 2023, QuantStack and Mamba Contributors +// +// Distributed under the terms of the BSD 3-Clause License. +// +// The full license is in the file LICENSE, distributed with this software. + +#ifndef MAMBA_CORE_UTIL_CAST_HPP +#define MAMBA_CORE_UTIL_CAST_HPP + +#include +#include +#include + +#include + +#include "util_compare.hpp" + +namespace mamba::util +{ + /** + * A safe cast between arithmetic types. + * + * If the conversion leads to an overflow, the cast will throw an ``std::overflow_error``. + * If the conversion to a floating point type loses precision, the cast will throw a + * ``std::runtime_error``. + */ + template + constexpr auto safe_num_cast(const From& val) -> To; + + /******************** + * Implementation * + ********************/ + + namespace detail + { + template + constexpr auto make_overflow_error(const From& val) + { + return std::overflow_error{ fmt::format( + "Value to cast ({}) is out of destination range ([{}, {}])", + val, + std::numeric_limits::lowest(), + std::numeric_limits::max() + ) }; + }; + } + + template + constexpr auto safe_num_cast(const From& val) -> To + { + static_assert(std::is_arithmetic_v); + static_assert(std::is_arithmetic_v); + + constexpr auto to_lowest = std::numeric_limits::lowest(); + constexpr auto to_max = std::numeric_limits::max(); + constexpr auto from_lowest = std::numeric_limits::lowest(); + constexpr auto from_max = std::numeric_limits::max(); + + if constexpr (std::is_same_v) + { + return val; + } + else if constexpr (std::is_integral_v && std::is_integral_v) + { + if constexpr (cmp_less(from_lowest, to_lowest)) + { + if (cmp_less(val, to_lowest)) + { + throw detail::make_overflow_error(val); + } + } + + if constexpr (cmp_greater(from_max, to_max)) + { + if (cmp_greater(val, to_max)) + { + throw detail::make_overflow_error(val); + } + } + + return static_cast(val); + } + else + { + using float_type = std::common_type_t; + constexpr auto float_cast = [](const auto& x) { return static_cast(x); }; + + if constexpr (float_cast(from_lowest) < float_cast(to_lowest)) + { + if (float_cast(val) < float_cast(to_lowest)) + { + throw detail::make_overflow_error(val); + } + } + + if constexpr (float_cast(from_max) > float_cast(to_max)) + { + if (float_cast(val) > float_cast(to_max)) + { + throw detail::make_overflow_error(val); + } + } + + To cast = static_cast(val); + From cast_back = static_cast(cast); + if (cast_back != val) + { + throw std::runtime_error{ + fmt::format("Casting from {} to {} loses precision", val, cast) + }; + } + return cast; + } + } +} + +#endif diff --git a/libmamba/tests/CMakeLists.txt b/libmamba/tests/CMakeLists.txt index 9a0c5c6b9c..905af43f5b 100644 --- a/libmamba/tests/CMakeLists.txt +++ b/libmamba/tests/CMakeLists.txt @@ -31,6 +31,8 @@ set(LIBMAMBA_TEST_SRCS test_validate.cpp test_virtual_packages.cpp test_util.cpp + test_util_cast.cpp + test_util_compare.cpp test_util_string.cpp test_util_graph.cpp test_env_lockfile.cpp diff --git a/libmamba/tests/test_util_cast.cpp b/libmamba/tests/test_util_cast.cpp new file mode 100644 index 0000000000..3c8ab36d26 --- /dev/null +++ b/libmamba/tests/test_util_cast.cpp @@ -0,0 +1,94 @@ +// Copyright (c) 2023, QuantStack and Mamba Contributors +// +// Distributed under the terms of the BSD 3-Clause License. +// +// The full license is in the file LICENSE, distributed with this software. + +#include +#include +#include + +#include + +#include "mamba/core/util_cast.hpp" + +namespace mamba::util +{ + template + struct cast_valid : ::testing::Test + { + using First = typename T::first_type; + using Second = typename T::second_type; + }; + using WidenTypes = ::testing::Types< + // integers + std::pair, + std::pair, + std::pair, + std::pair, + std::pair, + std::pair, + // floats + std::pair, + // Mixed + std::pair, + std::pair, + std::pair, + std::pair>; + TYPED_TEST_SUITE(cast_valid, WidenTypes); + + TYPED_TEST(cast_valid, checked_exact_num_cast_widen) + { + using From = typename TestFixture::First; + using To = typename TestFixture::Second; + static constexpr auto from_lowest = std::numeric_limits::lowest(); + static constexpr auto from_max = std::numeric_limits::max(); + + EXPECT_EQ(safe_num_cast(From(0)), To(0)); + EXPECT_EQ(safe_num_cast(From(1)), To(1)); + EXPECT_EQ(safe_num_cast(from_lowest), static_cast(from_lowest)); + EXPECT_EQ(safe_num_cast(from_max), static_cast(from_max)); + } + + TYPED_TEST(cast_valid, checked_exact_num_cast_narrow) + { + using From = typename TestFixture::Second; // inversed + using To = typename TestFixture::First; // inversed + EXPECT_EQ(safe_num_cast(From(0)), To(0)); + EXPECT_EQ(safe_num_cast(From(1)), To(1)); + } + + template + struct cast_overflow_lowest : ::testing::Test + { + using From = typename T::first_type; + using To = typename T::second_type; + }; + using OverflowLowestTypes = ::testing::Types< + // integers + std::pair, + std::pair, + std::pair, + std::pair, + // floats + std::pair, + // mixed + std::pair, + std::pair>; + TYPED_TEST_SUITE(cast_overflow_lowest, OverflowLowestTypes); + + TYPED_TEST(cast_overflow_lowest, checked_exact_num_cast) + { + using From = typename TestFixture::From; + using To = typename TestFixture::To; + static constexpr auto from_lowest = std::numeric_limits::lowest(); + + EXPECT_THROW(safe_num_cast(from_lowest), std::overflow_error); + } + + TEST(cast, precision) + { + EXPECT_THROW(safe_num_cast(1.1), std::runtime_error); + EXPECT_THROW(safe_num_cast(std::nextafter(double(1), 2)), std::runtime_error); + } +} diff --git a/libmamba/tests/test_util_compare.cpp b/libmamba/tests/test_util_compare.cpp new file mode 100644 index 0000000000..feea3e3a03 --- /dev/null +++ b/libmamba/tests/test_util_compare.cpp @@ -0,0 +1,87 @@ +// Copyright (c) 2023, QuantStack and Mamba Contributors +// +// Distributed under the terms of the BSD 3-Clause License. +// +// The full license is in the file LICENSE, distributed with this software. + +#include +#include + +#include + +#include "mamba/core/util_compare.hpp" + +namespace mamba::util +{ + TEST(compare, equal) + { + EXPECT_TRUE(cmp_equal(char(0), char(0))); + EXPECT_TRUE(cmp_equal(char(1), char(1))); + EXPECT_TRUE(cmp_equal(char(-1), char(-1))); + EXPECT_TRUE(cmp_equal(int(0), int(0))); + EXPECT_TRUE(cmp_equal(int(1), int(1))); + EXPECT_TRUE(cmp_equal(int(-1), int(-1))); + EXPECT_TRUE(cmp_equal(std::size_t(0), std::size_t(0))); + EXPECT_TRUE(cmp_equal(std::size_t(1), std::size_t(1))); + + EXPECT_TRUE(cmp_equal(char(0), int(0))); + EXPECT_TRUE(cmp_equal(char(1), int(1))); + EXPECT_TRUE(cmp_equal(char(-1), int(-1))); + EXPECT_TRUE(cmp_equal(std::size_t(0), char(0))); + EXPECT_TRUE(cmp_equal(std::size_t(1), char(1))); + EXPECT_TRUE(cmp_equal(std::size_t(0), int(0))); + EXPECT_TRUE(cmp_equal(std::size_t(1), int(1))); + + EXPECT_FALSE(cmp_equal(char(0), char(1))); + EXPECT_FALSE(cmp_equal(char(1), char(-1))); + EXPECT_FALSE(cmp_equal(int(0), int(1))); + EXPECT_FALSE(cmp_equal(int(-1), int(1))); + EXPECT_FALSE(cmp_equal(std::size_t(0), std::size_t(1))); + + EXPECT_FALSE(cmp_equal(char(0), int(1))); + EXPECT_FALSE(cmp_equal(char(1), int(-1))); + EXPECT_FALSE(cmp_equal(char(-1), int(1))); + EXPECT_FALSE(cmp_equal(std::size_t(1), int(-1))); + EXPECT_FALSE(cmp_equal(static_cast(-1), int(-1))); + EXPECT_FALSE(cmp_equal(std::size_t(1), int(0))); + EXPECT_FALSE(cmp_equal(std::numeric_limits::max(), int(0))); + } + + TEST(compare, less) + { + EXPECT_TRUE(cmp_less(char(0), char(1))); + EXPECT_TRUE(cmp_less(char(-1), char(0))); + EXPECT_TRUE(cmp_less(int(0), int(1))); + EXPECT_TRUE(cmp_less(int(-1), int(1))); + EXPECT_TRUE(cmp_less(std::size_t(0), std::size_t(1))); + + EXPECT_TRUE(cmp_less(char(0), int(1))); + EXPECT_TRUE(cmp_less(char(-1), int(0))); + EXPECT_TRUE(cmp_less(char(-1), int(1))); + EXPECT_TRUE(cmp_less(char(-1), std::size_t(1))); + EXPECT_TRUE(cmp_less(std::size_t(0), int(1))); + EXPECT_TRUE(cmp_less(std::numeric_limits::min(), char(0))); + EXPECT_TRUE(cmp_less(std::numeric_limits::min(), std::size_t(0))); + EXPECT_TRUE(cmp_less(int(-1), std::numeric_limits::max())); + EXPECT_TRUE(cmp_less(std::size_t(1), std::numeric_limits::max())); + + EXPECT_FALSE(cmp_less(char(1), char(0))); + EXPECT_FALSE(cmp_less(char(1), char(1))); + EXPECT_FALSE(cmp_less(char(0), char(-1))); + EXPECT_FALSE(cmp_less(int(1), int(0))); + EXPECT_FALSE(cmp_less(int(1), int(-1))); + EXPECT_FALSE(cmp_less(std::size_t(1), std::size_t(0))); + + EXPECT_FALSE(cmp_less(char(1), int(1))); + EXPECT_FALSE(cmp_less(char(1), int(0))); + EXPECT_FALSE(cmp_less(char(0), int(-1))); + EXPECT_FALSE(cmp_less(char(1), int(-11))); + EXPECT_FALSE(cmp_less(std::size_t(1), char(-1))); + EXPECT_FALSE(cmp_less(int(1), std::size_t(0))); + EXPECT_FALSE(cmp_less(char(0), std::numeric_limits::min())); + EXPECT_FALSE(cmp_less(std::size_t(0), std::numeric_limits::min())); + EXPECT_FALSE(cmp_less(std::numeric_limits::max(), int(-1))); + EXPECT_FALSE(cmp_less(std::numeric_limits::max(), std::size_t(1))); + } + +}