-
Notifications
You must be signed in to change notification settings - Fork 371
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add test for util_compare * Add a checked numeric cast
- Loading branch information
1 parent
365342c
commit 82d0628
Showing
5 changed files
with
301 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
// Copyright (c) 2023, QuantStack and Mamba Contributors | ||
// | ||
// Distributed under the terms of the BSD 3-Clause License. | ||
// | ||
// The full license is in the file LICENSE, distributed with this software. | ||
|
||
#ifndef MAMBA_CORE_UTIL_CAST_HPP | ||
#define MAMBA_CORE_UTIL_CAST_HPP | ||
|
||
#include <limits> | ||
#include <stdexcept> | ||
#include <type_traits> | ||
|
||
#include <fmt/format.h> | ||
|
||
#include "util_compare.hpp" | ||
|
||
namespace mamba::util | ||
{ | ||
/** | ||
* A safe cast between arithmetic types. | ||
* | ||
* If the conversion leads to an overflow, the cast will throw an ``std::overflow_error``. | ||
* If the conversion to a floating point type loses precision, the cast will throw a | ||
* ``std::runtime_error``. | ||
*/ | ||
template <typename To, typename From> | ||
constexpr auto safe_num_cast(const From& val) -> To; | ||
|
||
/******************** | ||
* Implementation * | ||
********************/ | ||
|
||
namespace detail | ||
{ | ||
template <typename To, typename From> | ||
constexpr auto make_overflow_error(const From& val) | ||
{ | ||
return std::overflow_error{ fmt::format( | ||
"Value to cast ({}) is out of destination range ([{}, {}])", | ||
val, | ||
std::numeric_limits<To>::lowest(), | ||
std::numeric_limits<To>::max() | ||
) }; | ||
}; | ||
} | ||
|
||
template <typename To, typename From> | ||
constexpr auto safe_num_cast(const From& val) -> To | ||
{ | ||
static_assert(std::is_arithmetic_v<From>); | ||
static_assert(std::is_arithmetic_v<To>); | ||
|
||
constexpr auto to_lowest = std::numeric_limits<To>::lowest(); | ||
constexpr auto to_max = std::numeric_limits<To>::max(); | ||
constexpr auto from_lowest = std::numeric_limits<From>::lowest(); | ||
constexpr auto from_max = std::numeric_limits<From>::max(); | ||
|
||
if constexpr (std::is_same_v<From, To>) | ||
{ | ||
return val; | ||
} | ||
else if constexpr (std::is_integral_v<From> && std::is_integral_v<To>) | ||
{ | ||
if constexpr (cmp_less(from_lowest, to_lowest)) | ||
{ | ||
if (cmp_less(val, to_lowest)) | ||
{ | ||
throw detail::make_overflow_error<To>(val); | ||
} | ||
} | ||
|
||
if constexpr (cmp_greater(from_max, to_max)) | ||
{ | ||
if (cmp_greater(val, to_max)) | ||
{ | ||
throw detail::make_overflow_error<To>(val); | ||
} | ||
} | ||
|
||
return static_cast<To>(val); | ||
} | ||
else | ||
{ | ||
using float_type = std::common_type_t<From, To>; | ||
constexpr auto float_cast = [](const auto& x) { return static_cast<float_type>(x); }; | ||
|
||
if constexpr (float_cast(from_lowest) < float_cast(to_lowest)) | ||
{ | ||
if (float_cast(val) < float_cast(to_lowest)) | ||
{ | ||
throw detail::make_overflow_error<To>(val); | ||
} | ||
} | ||
|
||
if constexpr (float_cast(from_max) > float_cast(to_max)) | ||
{ | ||
if (float_cast(val) > float_cast(to_max)) | ||
{ | ||
throw detail::make_overflow_error<To>(val); | ||
} | ||
} | ||
|
||
To cast = static_cast<To>(val); | ||
From cast_back = static_cast<From>(cast); | ||
if (cast_back != val) | ||
{ | ||
throw std::runtime_error{ | ||
fmt::format("Casting from {} to {} loses precision", val, cast) | ||
}; | ||
} | ||
return cast; | ||
} | ||
} | ||
} | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
// Copyright (c) 2023, QuantStack and Mamba Contributors | ||
// | ||
// Distributed under the terms of the BSD 3-Clause License. | ||
// | ||
// The full license is in the file LICENSE, distributed with this software. | ||
|
||
#include <cmath> | ||
#include <limits> | ||
#include <utility> | ||
|
||
#include <gtest/gtest.h> | ||
|
||
#include "mamba/core/util_cast.hpp" | ||
|
||
namespace mamba::util | ||
{ | ||
template <typename T> | ||
struct cast_valid : ::testing::Test | ||
{ | ||
using First = typename T::first_type; | ||
using Second = typename T::second_type; | ||
}; | ||
using WidenTypes = ::testing::Types< | ||
// integers | ||
std::pair<char, int>, | ||
std::pair<unsigned char, int>, | ||
std::pair<unsigned char, unsigned int>, | ||
std::pair<int, long long int>, | ||
std::pair<unsigned int, long long int>, | ||
std::pair<unsigned int, unsigned long long int>, | ||
// floats | ||
std::pair<float, double>, | ||
// Mixed | ||
std::pair<char, float>, | ||
std::pair<unsigned char, float>, | ||
std::pair<int, double>, | ||
std::pair<unsigned int, double>>; | ||
TYPED_TEST_SUITE(cast_valid, WidenTypes); | ||
|
||
TYPED_TEST(cast_valid, checked_exact_num_cast_widen) | ||
{ | ||
using From = typename TestFixture::First; | ||
using To = typename TestFixture::Second; | ||
static constexpr auto from_lowest = std::numeric_limits<From>::lowest(); | ||
static constexpr auto from_max = std::numeric_limits<From>::max(); | ||
|
||
EXPECT_EQ(safe_num_cast<To>(From(0)), To(0)); | ||
EXPECT_EQ(safe_num_cast<To>(From(1)), To(1)); | ||
EXPECT_EQ(safe_num_cast<To>(from_lowest), static_cast<To>(from_lowest)); | ||
EXPECT_EQ(safe_num_cast<To>(from_max), static_cast<To>(from_max)); | ||
} | ||
|
||
TYPED_TEST(cast_valid, checked_exact_num_cast_narrow) | ||
{ | ||
using From = typename TestFixture::Second; // inversed | ||
using To = typename TestFixture::First; // inversed | ||
EXPECT_EQ(safe_num_cast<To>(From(0)), To(0)); | ||
EXPECT_EQ(safe_num_cast<To>(From(1)), To(1)); | ||
} | ||
|
||
template <typename T> | ||
struct cast_overflow_lowest : ::testing::Test | ||
{ | ||
using From = typename T::first_type; | ||
using To = typename T::second_type; | ||
}; | ||
using OverflowLowestTypes = ::testing::Types< | ||
// integers | ||
std::pair<char, unsigned char>, | ||
std::pair<char, unsigned int>, | ||
std::pair<int, char>, | ||
std::pair<int, unsigned long long int>, | ||
// floats | ||
std::pair<double, float>, | ||
// mixed | ||
std::pair<double, int>, | ||
std::pair<float, char>>; | ||
TYPED_TEST_SUITE(cast_overflow_lowest, OverflowLowestTypes); | ||
|
||
TYPED_TEST(cast_overflow_lowest, checked_exact_num_cast) | ||
{ | ||
using From = typename TestFixture::From; | ||
using To = typename TestFixture::To; | ||
static constexpr auto from_lowest = std::numeric_limits<From>::lowest(); | ||
|
||
EXPECT_THROW(safe_num_cast<To>(from_lowest), std::overflow_error); | ||
} | ||
|
||
TEST(cast, precision) | ||
{ | ||
EXPECT_THROW(safe_num_cast<int>(1.1), std::runtime_error); | ||
EXPECT_THROW(safe_num_cast<float>(std::nextafter(double(1), 2)), std::runtime_error); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
// Copyright (c) 2023, QuantStack and Mamba Contributors | ||
// | ||
// Distributed under the terms of the BSD 3-Clause License. | ||
// | ||
// The full license is in the file LICENSE, distributed with this software. | ||
|
||
#include <limits> | ||
#include <utility> | ||
|
||
#include <gtest/gtest.h> | ||
|
||
#include "mamba/core/util_compare.hpp" | ||
|
||
namespace mamba::util | ||
{ | ||
TEST(compare, equal) | ||
{ | ||
EXPECT_TRUE(cmp_equal(char(0), char(0))); | ||
EXPECT_TRUE(cmp_equal(char(1), char(1))); | ||
EXPECT_TRUE(cmp_equal(char(-1), char(-1))); | ||
EXPECT_TRUE(cmp_equal(int(0), int(0))); | ||
EXPECT_TRUE(cmp_equal(int(1), int(1))); | ||
EXPECT_TRUE(cmp_equal(int(-1), int(-1))); | ||
EXPECT_TRUE(cmp_equal(std::size_t(0), std::size_t(0))); | ||
EXPECT_TRUE(cmp_equal(std::size_t(1), std::size_t(1))); | ||
|
||
EXPECT_TRUE(cmp_equal(char(0), int(0))); | ||
EXPECT_TRUE(cmp_equal(char(1), int(1))); | ||
EXPECT_TRUE(cmp_equal(char(-1), int(-1))); | ||
EXPECT_TRUE(cmp_equal(std::size_t(0), char(0))); | ||
EXPECT_TRUE(cmp_equal(std::size_t(1), char(1))); | ||
EXPECT_TRUE(cmp_equal(std::size_t(0), int(0))); | ||
EXPECT_TRUE(cmp_equal(std::size_t(1), int(1))); | ||
|
||
EXPECT_FALSE(cmp_equal(char(0), char(1))); | ||
EXPECT_FALSE(cmp_equal(char(1), char(-1))); | ||
EXPECT_FALSE(cmp_equal(int(0), int(1))); | ||
EXPECT_FALSE(cmp_equal(int(-1), int(1))); | ||
EXPECT_FALSE(cmp_equal(std::size_t(0), std::size_t(1))); | ||
|
||
EXPECT_FALSE(cmp_equal(char(0), int(1))); | ||
EXPECT_FALSE(cmp_equal(char(1), int(-1))); | ||
EXPECT_FALSE(cmp_equal(char(-1), int(1))); | ||
EXPECT_FALSE(cmp_equal(std::size_t(1), int(-1))); | ||
EXPECT_FALSE(cmp_equal(static_cast<std::size_t>(-1), int(-1))); | ||
EXPECT_FALSE(cmp_equal(std::size_t(1), int(0))); | ||
EXPECT_FALSE(cmp_equal(std::numeric_limits<std::size_t>::max(), int(0))); | ||
} | ||
|
||
TEST(compare, less) | ||
{ | ||
EXPECT_TRUE(cmp_less(char(0), char(1))); | ||
EXPECT_TRUE(cmp_less(char(-1), char(0))); | ||
EXPECT_TRUE(cmp_less(int(0), int(1))); | ||
EXPECT_TRUE(cmp_less(int(-1), int(1))); | ||
EXPECT_TRUE(cmp_less(std::size_t(0), std::size_t(1))); | ||
|
||
EXPECT_TRUE(cmp_less(char(0), int(1))); | ||
EXPECT_TRUE(cmp_less(char(-1), int(0))); | ||
EXPECT_TRUE(cmp_less(char(-1), int(1))); | ||
EXPECT_TRUE(cmp_less(char(-1), std::size_t(1))); | ||
EXPECT_TRUE(cmp_less(std::size_t(0), int(1))); | ||
EXPECT_TRUE(cmp_less(std::numeric_limits<int>::min(), char(0))); | ||
EXPECT_TRUE(cmp_less(std::numeric_limits<int>::min(), std::size_t(0))); | ||
EXPECT_TRUE(cmp_less(int(-1), std::numeric_limits<std::size_t>::max())); | ||
EXPECT_TRUE(cmp_less(std::size_t(1), std::numeric_limits<int>::max())); | ||
|
||
EXPECT_FALSE(cmp_less(char(1), char(0))); | ||
EXPECT_FALSE(cmp_less(char(1), char(1))); | ||
EXPECT_FALSE(cmp_less(char(0), char(-1))); | ||
EXPECT_FALSE(cmp_less(int(1), int(0))); | ||
EXPECT_FALSE(cmp_less(int(1), int(-1))); | ||
EXPECT_FALSE(cmp_less(std::size_t(1), std::size_t(0))); | ||
|
||
EXPECT_FALSE(cmp_less(char(1), int(1))); | ||
EXPECT_FALSE(cmp_less(char(1), int(0))); | ||
EXPECT_FALSE(cmp_less(char(0), int(-1))); | ||
EXPECT_FALSE(cmp_less(char(1), int(-11))); | ||
EXPECT_FALSE(cmp_less(std::size_t(1), char(-1))); | ||
EXPECT_FALSE(cmp_less(int(1), std::size_t(0))); | ||
EXPECT_FALSE(cmp_less(char(0), std::numeric_limits<int>::min())); | ||
EXPECT_FALSE(cmp_less(std::size_t(0), std::numeric_limits<int>::min())); | ||
EXPECT_FALSE(cmp_less(std::numeric_limits<std::size_t>::max(), int(-1))); | ||
EXPECT_FALSE(cmp_less(std::numeric_limits<int>::max(), std::size_t(1))); | ||
} | ||
|
||
} |