Skip to content

Commit

Permalink
Add checked numeric cast (#2315)
Browse files Browse the repository at this point in the history
* Add test for util_compare

* Add a checked numeric cast
  • Loading branch information
AntoinePrv authored Feb 22, 2023
1 parent 365342c commit 82d0628
Show file tree
Hide file tree
Showing 5 changed files with 301 additions and 0 deletions.
1 change: 1 addition & 0 deletions libmamba/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ set(LIBMAMBA_HEADERS
${LIBMAMBA_INCLUDE_DIR}/mamba/core/util_random.hpp
${LIBMAMBA_INCLUDE_DIR}/mamba/core/util_scope.hpp
${LIBMAMBA_INCLUDE_DIR}/mamba/core/util_string.hpp
${LIBMAMBA_INCLUDE_DIR}/mamba/core/util_cast.hpp
${LIBMAMBA_INCLUDE_DIR}/mamba/core/util_compare.hpp
${LIBMAMBA_INCLUDE_DIR}/mamba/core/validate.hpp
${LIBMAMBA_INCLUDE_DIR}/mamba/core/virtual_packages.hpp
Expand Down
117 changes: 117 additions & 0 deletions libmamba/include/mamba/core/util_cast.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
// Copyright (c) 2023, QuantStack and Mamba Contributors
//
// Distributed under the terms of the BSD 3-Clause License.
//
// The full license is in the file LICENSE, distributed with this software.

#ifndef MAMBA_CORE_UTIL_CAST_HPP
#define MAMBA_CORE_UTIL_CAST_HPP

#include <limits>
#include <stdexcept>
#include <type_traits>

#include <fmt/format.h>

#include "util_compare.hpp"

namespace mamba::util
{
/**
* A safe cast between arithmetic types.
*
* If the conversion leads to an overflow, the cast will throw an ``std::overflow_error``.
* If the conversion to a floating point type loses precision, the cast will throw a
* ``std::runtime_error``.
*/
template <typename To, typename From>
constexpr auto safe_num_cast(const From& val) -> To;

/********************
* Implementation *
********************/

namespace detail
{
template <typename To, typename From>
constexpr auto make_overflow_error(const From& val)
{
return std::overflow_error{ fmt::format(
"Value to cast ({}) is out of destination range ([{}, {}])",
val,
std::numeric_limits<To>::lowest(),
std::numeric_limits<To>::max()
) };
};
}

template <typename To, typename From>
constexpr auto safe_num_cast(const From& val) -> To
{
static_assert(std::is_arithmetic_v<From>);
static_assert(std::is_arithmetic_v<To>);

constexpr auto to_lowest = std::numeric_limits<To>::lowest();
constexpr auto to_max = std::numeric_limits<To>::max();
constexpr auto from_lowest = std::numeric_limits<From>::lowest();
constexpr auto from_max = std::numeric_limits<From>::max();

if constexpr (std::is_same_v<From, To>)
{
return val;
}
else if constexpr (std::is_integral_v<From> && std::is_integral_v<To>)
{
if constexpr (cmp_less(from_lowest, to_lowest))
{
if (cmp_less(val, to_lowest))
{
throw detail::make_overflow_error<To>(val);
}
}

if constexpr (cmp_greater(from_max, to_max))
{
if (cmp_greater(val, to_max))
{
throw detail::make_overflow_error<To>(val);
}
}

return static_cast<To>(val);
}
else
{
using float_type = std::common_type_t<From, To>;
constexpr auto float_cast = [](const auto& x) { return static_cast<float_type>(x); };

if constexpr (float_cast(from_lowest) < float_cast(to_lowest))
{
if (float_cast(val) < float_cast(to_lowest))
{
throw detail::make_overflow_error<To>(val);
}
}

if constexpr (float_cast(from_max) > float_cast(to_max))
{
if (float_cast(val) > float_cast(to_max))
{
throw detail::make_overflow_error<To>(val);
}
}

To cast = static_cast<To>(val);
From cast_back = static_cast<From>(cast);
if (cast_back != val)
{
throw std::runtime_error{
fmt::format("Casting from {} to {} loses precision", val, cast)
};
}
return cast;
}
}
}

#endif
2 changes: 2 additions & 0 deletions libmamba/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ set(LIBMAMBA_TEST_SRCS
test_validate.cpp
test_virtual_packages.cpp
test_util.cpp
test_util_cast.cpp
test_util_compare.cpp
test_util_string.cpp
test_util_graph.cpp
test_env_lockfile.cpp
Expand Down
94 changes: 94 additions & 0 deletions libmamba/tests/test_util_cast.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
// Copyright (c) 2023, QuantStack and Mamba Contributors
//
// Distributed under the terms of the BSD 3-Clause License.
//
// The full license is in the file LICENSE, distributed with this software.

#include <cmath>
#include <limits>
#include <utility>

#include <gtest/gtest.h>

#include "mamba/core/util_cast.hpp"

namespace mamba::util
{
template <typename T>
struct cast_valid : ::testing::Test
{
using First = typename T::first_type;
using Second = typename T::second_type;
};
using WidenTypes = ::testing::Types<
// integers
std::pair<char, int>,
std::pair<unsigned char, int>,
std::pair<unsigned char, unsigned int>,
std::pair<int, long long int>,
std::pair<unsigned int, long long int>,
std::pair<unsigned int, unsigned long long int>,
// floats
std::pair<float, double>,
// Mixed
std::pair<char, float>,
std::pair<unsigned char, float>,
std::pair<int, double>,
std::pair<unsigned int, double>>;
TYPED_TEST_SUITE(cast_valid, WidenTypes);

TYPED_TEST(cast_valid, checked_exact_num_cast_widen)
{
using From = typename TestFixture::First;
using To = typename TestFixture::Second;
static constexpr auto from_lowest = std::numeric_limits<From>::lowest();
static constexpr auto from_max = std::numeric_limits<From>::max();

EXPECT_EQ(safe_num_cast<To>(From(0)), To(0));
EXPECT_EQ(safe_num_cast<To>(From(1)), To(1));
EXPECT_EQ(safe_num_cast<To>(from_lowest), static_cast<To>(from_lowest));
EXPECT_EQ(safe_num_cast<To>(from_max), static_cast<To>(from_max));
}

TYPED_TEST(cast_valid, checked_exact_num_cast_narrow)
{
using From = typename TestFixture::Second; // inversed
using To = typename TestFixture::First; // inversed
EXPECT_EQ(safe_num_cast<To>(From(0)), To(0));
EXPECT_EQ(safe_num_cast<To>(From(1)), To(1));
}

template <typename T>
struct cast_overflow_lowest : ::testing::Test
{
using From = typename T::first_type;
using To = typename T::second_type;
};
using OverflowLowestTypes = ::testing::Types<
// integers
std::pair<char, unsigned char>,
std::pair<char, unsigned int>,
std::pair<int, char>,
std::pair<int, unsigned long long int>,
// floats
std::pair<double, float>,
// mixed
std::pair<double, int>,
std::pair<float, char>>;
TYPED_TEST_SUITE(cast_overflow_lowest, OverflowLowestTypes);

TYPED_TEST(cast_overflow_lowest, checked_exact_num_cast)
{
using From = typename TestFixture::From;
using To = typename TestFixture::To;
static constexpr auto from_lowest = std::numeric_limits<From>::lowest();

EXPECT_THROW(safe_num_cast<To>(from_lowest), std::overflow_error);
}

TEST(cast, precision)
{
EXPECT_THROW(safe_num_cast<int>(1.1), std::runtime_error);
EXPECT_THROW(safe_num_cast<float>(std::nextafter(double(1), 2)), std::runtime_error);
}
}
87 changes: 87 additions & 0 deletions libmamba/tests/test_util_compare.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
// Copyright (c) 2023, QuantStack and Mamba Contributors
//
// Distributed under the terms of the BSD 3-Clause License.
//
// The full license is in the file LICENSE, distributed with this software.

#include <limits>
#include <utility>

#include <gtest/gtest.h>

#include "mamba/core/util_compare.hpp"

namespace mamba::util
{
TEST(compare, equal)
{
EXPECT_TRUE(cmp_equal(char(0), char(0)));
EXPECT_TRUE(cmp_equal(char(1), char(1)));
EXPECT_TRUE(cmp_equal(char(-1), char(-1)));
EXPECT_TRUE(cmp_equal(int(0), int(0)));
EXPECT_TRUE(cmp_equal(int(1), int(1)));
EXPECT_TRUE(cmp_equal(int(-1), int(-1)));
EXPECT_TRUE(cmp_equal(std::size_t(0), std::size_t(0)));
EXPECT_TRUE(cmp_equal(std::size_t(1), std::size_t(1)));

EXPECT_TRUE(cmp_equal(char(0), int(0)));
EXPECT_TRUE(cmp_equal(char(1), int(1)));
EXPECT_TRUE(cmp_equal(char(-1), int(-1)));
EXPECT_TRUE(cmp_equal(std::size_t(0), char(0)));
EXPECT_TRUE(cmp_equal(std::size_t(1), char(1)));
EXPECT_TRUE(cmp_equal(std::size_t(0), int(0)));
EXPECT_TRUE(cmp_equal(std::size_t(1), int(1)));

EXPECT_FALSE(cmp_equal(char(0), char(1)));
EXPECT_FALSE(cmp_equal(char(1), char(-1)));
EXPECT_FALSE(cmp_equal(int(0), int(1)));
EXPECT_FALSE(cmp_equal(int(-1), int(1)));
EXPECT_FALSE(cmp_equal(std::size_t(0), std::size_t(1)));

EXPECT_FALSE(cmp_equal(char(0), int(1)));
EXPECT_FALSE(cmp_equal(char(1), int(-1)));
EXPECT_FALSE(cmp_equal(char(-1), int(1)));
EXPECT_FALSE(cmp_equal(std::size_t(1), int(-1)));
EXPECT_FALSE(cmp_equal(static_cast<std::size_t>(-1), int(-1)));
EXPECT_FALSE(cmp_equal(std::size_t(1), int(0)));
EXPECT_FALSE(cmp_equal(std::numeric_limits<std::size_t>::max(), int(0)));
}

TEST(compare, less)
{
EXPECT_TRUE(cmp_less(char(0), char(1)));
EXPECT_TRUE(cmp_less(char(-1), char(0)));
EXPECT_TRUE(cmp_less(int(0), int(1)));
EXPECT_TRUE(cmp_less(int(-1), int(1)));
EXPECT_TRUE(cmp_less(std::size_t(0), std::size_t(1)));

EXPECT_TRUE(cmp_less(char(0), int(1)));
EXPECT_TRUE(cmp_less(char(-1), int(0)));
EXPECT_TRUE(cmp_less(char(-1), int(1)));
EXPECT_TRUE(cmp_less(char(-1), std::size_t(1)));
EXPECT_TRUE(cmp_less(std::size_t(0), int(1)));
EXPECT_TRUE(cmp_less(std::numeric_limits<int>::min(), char(0)));
EXPECT_TRUE(cmp_less(std::numeric_limits<int>::min(), std::size_t(0)));
EXPECT_TRUE(cmp_less(int(-1), std::numeric_limits<std::size_t>::max()));
EXPECT_TRUE(cmp_less(std::size_t(1), std::numeric_limits<int>::max()));

EXPECT_FALSE(cmp_less(char(1), char(0)));
EXPECT_FALSE(cmp_less(char(1), char(1)));
EXPECT_FALSE(cmp_less(char(0), char(-1)));
EXPECT_FALSE(cmp_less(int(1), int(0)));
EXPECT_FALSE(cmp_less(int(1), int(-1)));
EXPECT_FALSE(cmp_less(std::size_t(1), std::size_t(0)));

EXPECT_FALSE(cmp_less(char(1), int(1)));
EXPECT_FALSE(cmp_less(char(1), int(0)));
EXPECT_FALSE(cmp_less(char(0), int(-1)));
EXPECT_FALSE(cmp_less(char(1), int(-11)));
EXPECT_FALSE(cmp_less(std::size_t(1), char(-1)));
EXPECT_FALSE(cmp_less(int(1), std::size_t(0)));
EXPECT_FALSE(cmp_less(char(0), std::numeric_limits<int>::min()));
EXPECT_FALSE(cmp_less(std::size_t(0), std::numeric_limits<int>::min()));
EXPECT_FALSE(cmp_less(std::numeric_limits<std::size_t>::max(), int(-1)));
EXPECT_FALSE(cmp_less(std::numeric_limits<int>::max(), std::size_t(1)));
}

}

0 comments on commit 82d0628

Please sign in to comment.