Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add checked numeric cast #2315

Merged
merged 2 commits into from
Feb 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions libmamba/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ set(LIBMAMBA_HEADERS
${LIBMAMBA_INCLUDE_DIR}/mamba/core/util_random.hpp
${LIBMAMBA_INCLUDE_DIR}/mamba/core/util_scope.hpp
${LIBMAMBA_INCLUDE_DIR}/mamba/core/util_string.hpp
${LIBMAMBA_INCLUDE_DIR}/mamba/core/util_cast.hpp
${LIBMAMBA_INCLUDE_DIR}/mamba/core/util_compare.hpp
${LIBMAMBA_INCLUDE_DIR}/mamba/core/validate.hpp
${LIBMAMBA_INCLUDE_DIR}/mamba/core/virtual_packages.hpp
Expand Down
117 changes: 117 additions & 0 deletions libmamba/include/mamba/core/util_cast.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
// Copyright (c) 2023, QuantStack and Mamba Contributors
//
// Distributed under the terms of the BSD 3-Clause License.
//
// The full license is in the file LICENSE, distributed with this software.

#ifndef MAMBA_CORE_UTIL_CAST_HPP
#define MAMBA_CORE_UTIL_CAST_HPP

#include <limits>
#include <stdexcept>
#include <type_traits>

#include <fmt/format.h>

#include "util_compare.hpp"

namespace mamba::util
{
/**
* A safe cast between arithmetic types.
*
* If the conversion leads to an overflow, the cast will throw an ``std::overflow_error``.
* If the conversion to a floating point type loses precision, the cast will throw a
* ``std::runtime_error``.
*/
template <typename To, typename From>
constexpr auto safe_num_cast(const From& val) -> To;

/********************
* Implementation *
********************/

namespace detail
{
template <typename To, typename From>
constexpr auto make_overflow_error(const From& val)
{
return std::overflow_error{ fmt::format(
"Value to cast ({}) is out of destination range ([{}, {}])",
val,
std::numeric_limits<To>::lowest(),
std::numeric_limits<To>::max()
) };
};
}

template <typename To, typename From>
constexpr auto safe_num_cast(const From& val) -> To
{
static_assert(std::is_arithmetic_v<From>);
static_assert(std::is_arithmetic_v<To>);

constexpr auto to_lowest = std::numeric_limits<To>::lowest();
constexpr auto to_max = std::numeric_limits<To>::max();
constexpr auto from_lowest = std::numeric_limits<From>::lowest();
constexpr auto from_max = std::numeric_limits<From>::max();

if constexpr (std::is_same_v<From, To>)
{
return val;
}
else if constexpr (std::is_integral_v<From> && std::is_integral_v<To>)
{
if constexpr (cmp_less(from_lowest, to_lowest))
{
if (cmp_less(val, to_lowest))
{
throw detail::make_overflow_error<To>(val);
}
}

if constexpr (cmp_greater(from_max, to_max))
{
if (cmp_greater(val, to_max))
{
throw detail::make_overflow_error<To>(val);
}
}

return static_cast<To>(val);
}
else
{
using float_type = std::common_type_t<From, To>;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we do a lot of casting between float types? Seems to me we should rather get rid of it if possible.

The integer ones are obviously necessary since libsolv uses 32 bit integers in quite some places.

Copy link
Member Author

@AntoinePrv AntoinePrv Feb 22, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we do a lot of casting between float types? Seems to me we should rather get rid of it if possible.

I think I've seen a couple casting from integer to float but perhaps they need not be exact.

The integer ones are obviously necessary since libsolv uses 32 bit integers in quite some places.

Yes, even between signed and unsigned.

constexpr auto float_cast = [](const auto& x) { return static_cast<float_type>(x); };

if constexpr (float_cast(from_lowest) < float_cast(to_lowest))
{
if (float_cast(val) < float_cast(to_lowest))
{
throw detail::make_overflow_error<To>(val);
}
}

if constexpr (float_cast(from_max) > float_cast(to_max))
{
if (float_cast(val) > float_cast(to_max))
{
throw detail::make_overflow_error<To>(val);
}
}

To cast = static_cast<To>(val);
From cast_back = static_cast<From>(cast);
if (cast_back != val)
{
throw std::runtime_error{
fmt::format("Casting from {} to {} loses precision", val, cast)
};
}
return cast;
}
}
}

#endif
2 changes: 2 additions & 0 deletions libmamba/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ set(LIBMAMBA_TEST_SRCS
test_validate.cpp
test_virtual_packages.cpp
test_util.cpp
test_util_cast.cpp
test_util_compare.cpp
test_util_string.cpp
test_util_graph.cpp
test_env_lockfile.cpp
Expand Down
94 changes: 94 additions & 0 deletions libmamba/tests/test_util_cast.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
// Copyright (c) 2023, QuantStack and Mamba Contributors
//
// Distributed under the terms of the BSD 3-Clause License.
//
// The full license is in the file LICENSE, distributed with this software.

#include <cmath>
#include <limits>
#include <utility>

#include <gtest/gtest.h>

#include "mamba/core/util_cast.hpp"

namespace mamba::util
{
template <typename T>
struct cast_valid : ::testing::Test
{
using First = typename T::first_type;
using Second = typename T::second_type;
};
using WidenTypes = ::testing::Types<
// integers
std::pair<char, int>,
std::pair<unsigned char, int>,
std::pair<unsigned char, unsigned int>,
std::pair<int, long long int>,
std::pair<unsigned int, long long int>,
std::pair<unsigned int, unsigned long long int>,
// floats
std::pair<float, double>,
// Mixed
std::pair<char, float>,
std::pair<unsigned char, float>,
std::pair<int, double>,
std::pair<unsigned int, double>>;
TYPED_TEST_SUITE(cast_valid, WidenTypes);

TYPED_TEST(cast_valid, checked_exact_num_cast_widen)
{
using From = typename TestFixture::First;
using To = typename TestFixture::Second;
static constexpr auto from_lowest = std::numeric_limits<From>::lowest();
static constexpr auto from_max = std::numeric_limits<From>::max();

EXPECT_EQ(safe_num_cast<To>(From(0)), To(0));
EXPECT_EQ(safe_num_cast<To>(From(1)), To(1));
EXPECT_EQ(safe_num_cast<To>(from_lowest), static_cast<To>(from_lowest));
EXPECT_EQ(safe_num_cast<To>(from_max), static_cast<To>(from_max));
}

TYPED_TEST(cast_valid, checked_exact_num_cast_narrow)
{
using From = typename TestFixture::Second; // inversed
using To = typename TestFixture::First; // inversed
EXPECT_EQ(safe_num_cast<To>(From(0)), To(0));
EXPECT_EQ(safe_num_cast<To>(From(1)), To(1));
}

template <typename T>
struct cast_overflow_lowest : ::testing::Test
{
using From = typename T::first_type;
using To = typename T::second_type;
};
using OverflowLowestTypes = ::testing::Types<
// integers
std::pair<char, unsigned char>,
std::pair<char, unsigned int>,
std::pair<int, char>,
std::pair<int, unsigned long long int>,
// floats
std::pair<double, float>,
// mixed
std::pair<double, int>,
std::pair<float, char>>;
TYPED_TEST_SUITE(cast_overflow_lowest, OverflowLowestTypes);

TYPED_TEST(cast_overflow_lowest, checked_exact_num_cast)
{
using From = typename TestFixture::From;
using To = typename TestFixture::To;
static constexpr auto from_lowest = std::numeric_limits<From>::lowest();

EXPECT_THROW(safe_num_cast<To>(from_lowest), std::overflow_error);
}

TEST(cast, precision)
{
EXPECT_THROW(safe_num_cast<int>(1.1), std::runtime_error);
EXPECT_THROW(safe_num_cast<float>(std::nextafter(double(1), 2)), std::runtime_error);
}
}
87 changes: 87 additions & 0 deletions libmamba/tests/test_util_compare.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
// Copyright (c) 2023, QuantStack and Mamba Contributors
//
// Distributed under the terms of the BSD 3-Clause License.
//
// The full license is in the file LICENSE, distributed with this software.

#include <limits>
#include <utility>

#include <gtest/gtest.h>

#include "mamba/core/util_compare.hpp"

namespace mamba::util
{
TEST(compare, equal)
{
EXPECT_TRUE(cmp_equal(char(0), char(0)));
EXPECT_TRUE(cmp_equal(char(1), char(1)));
EXPECT_TRUE(cmp_equal(char(-1), char(-1)));
EXPECT_TRUE(cmp_equal(int(0), int(0)));
EXPECT_TRUE(cmp_equal(int(1), int(1)));
EXPECT_TRUE(cmp_equal(int(-1), int(-1)));
EXPECT_TRUE(cmp_equal(std::size_t(0), std::size_t(0)));
EXPECT_TRUE(cmp_equal(std::size_t(1), std::size_t(1)));

EXPECT_TRUE(cmp_equal(char(0), int(0)));
EXPECT_TRUE(cmp_equal(char(1), int(1)));
EXPECT_TRUE(cmp_equal(char(-1), int(-1)));
EXPECT_TRUE(cmp_equal(std::size_t(0), char(0)));
EXPECT_TRUE(cmp_equal(std::size_t(1), char(1)));
EXPECT_TRUE(cmp_equal(std::size_t(0), int(0)));
EXPECT_TRUE(cmp_equal(std::size_t(1), int(1)));

EXPECT_FALSE(cmp_equal(char(0), char(1)));
EXPECT_FALSE(cmp_equal(char(1), char(-1)));
EXPECT_FALSE(cmp_equal(int(0), int(1)));
EXPECT_FALSE(cmp_equal(int(-1), int(1)));
EXPECT_FALSE(cmp_equal(std::size_t(0), std::size_t(1)));

EXPECT_FALSE(cmp_equal(char(0), int(1)));
EXPECT_FALSE(cmp_equal(char(1), int(-1)));
EXPECT_FALSE(cmp_equal(char(-1), int(1)));
EXPECT_FALSE(cmp_equal(std::size_t(1), int(-1)));
EXPECT_FALSE(cmp_equal(static_cast<std::size_t>(-1), int(-1)));
EXPECT_FALSE(cmp_equal(std::size_t(1), int(0)));
EXPECT_FALSE(cmp_equal(std::numeric_limits<std::size_t>::max(), int(0)));
}

TEST(compare, less)
{
EXPECT_TRUE(cmp_less(char(0), char(1)));
EXPECT_TRUE(cmp_less(char(-1), char(0)));
EXPECT_TRUE(cmp_less(int(0), int(1)));
EXPECT_TRUE(cmp_less(int(-1), int(1)));
EXPECT_TRUE(cmp_less(std::size_t(0), std::size_t(1)));

EXPECT_TRUE(cmp_less(char(0), int(1)));
EXPECT_TRUE(cmp_less(char(-1), int(0)));
EXPECT_TRUE(cmp_less(char(-1), int(1)));
EXPECT_TRUE(cmp_less(char(-1), std::size_t(1)));
EXPECT_TRUE(cmp_less(std::size_t(0), int(1)));
EXPECT_TRUE(cmp_less(std::numeric_limits<int>::min(), char(0)));
EXPECT_TRUE(cmp_less(std::numeric_limits<int>::min(), std::size_t(0)));
EXPECT_TRUE(cmp_less(int(-1), std::numeric_limits<std::size_t>::max()));
EXPECT_TRUE(cmp_less(std::size_t(1), std::numeric_limits<int>::max()));

EXPECT_FALSE(cmp_less(char(1), char(0)));
EXPECT_FALSE(cmp_less(char(1), char(1)));
EXPECT_FALSE(cmp_less(char(0), char(-1)));
EXPECT_FALSE(cmp_less(int(1), int(0)));
EXPECT_FALSE(cmp_less(int(1), int(-1)));
EXPECT_FALSE(cmp_less(std::size_t(1), std::size_t(0)));

EXPECT_FALSE(cmp_less(char(1), int(1)));
EXPECT_FALSE(cmp_less(char(1), int(0)));
EXPECT_FALSE(cmp_less(char(0), int(-1)));
EXPECT_FALSE(cmp_less(char(1), int(-11)));
EXPECT_FALSE(cmp_less(std::size_t(1), char(-1)));
EXPECT_FALSE(cmp_less(int(1), std::size_t(0)));
EXPECT_FALSE(cmp_less(char(0), std::numeric_limits<int>::min()));
EXPECT_FALSE(cmp_less(std::size_t(0), std::numeric_limits<int>::min()));
EXPECT_FALSE(cmp_less(std::numeric_limits<std::size_t>::max(), int(-1)));
EXPECT_FALSE(cmp_less(std::numeric_limits<int>::max(), std::size_t(1)));
}

}