diff --git a/include/experimental/__p0009_bits/config.hpp b/include/experimental/__p0009_bits/config.hpp index e8cacf40..ce2e6f20 100644 --- a/include/experimental/__p0009_bits/config.hpp +++ b/include/experimental/__p0009_bits/config.hpp @@ -176,6 +176,13 @@ static_assert(_MDSPAN_CPLUSPLUS >= MDSPAN_CXX_STD_14, "mdspan requires C++14 or # endif #endif +#ifndef _MDSPAN_USE_IF_CONSTEXPR_17 +# if (defined(__cpp_if_constexpr) && __cpp_if_constexpr >= 201606) \ + || (!defined(__cpp_constexpr) && MDSPAN_HAS_CXX_17) +# define _MDSPAN_USE_IF_CONSTEXPR_17 1 +# endif +#endif + #ifndef _MDSPAN_USE_INTEGER_SEQUENCE # if defined(_MDSPAN_COMPILER_MSVC) # if (defined(__cpp_lib_integer_sequence) && __cpp_lib_integer_sequence >= 201304) diff --git a/include/experimental/__p0009_bits/macros.hpp b/include/experimental/__p0009_bits/macros.hpp index b60c4261..91c7817e 100644 --- a/include/experimental/__p0009_bits/macros.hpp +++ b/include/experimental/__p0009_bits/macros.hpp @@ -697,3 +697,9 @@ struct __bools; // end Pre-C++14 constexpr }}}1 //============================================================================== + +#if _MDSPAN_USE_IF_CONSTEXPR_17 +# define _MDSPAN_IF_CONSTEXPR_17 constexpr +#else +# define _MDSPAN_IF_CONSTEXPR_17 +#endif diff --git a/include/experimental/__p0009_bits/mdspan.hpp b/include/experimental/__p0009_bits/mdspan.hpp index 23114aa5..9fb6f947 100644 --- a/include/experimental/__p0009_bits/mdspan.hpp +++ b/include/experimental/__p0009_bits/mdspan.hpp @@ -18,10 +18,15 @@ #include "default_accessor.hpp" #include "layout_right.hpp" +#include "macros.hpp" #include "extents.hpp" #include "trait_backports.hpp" #include "compressed_pair.hpp" +#include +#include +#include + namespace MDSPAN_IMPL_STANDARD_NAMESPACE { template < class ElementType, @@ -219,6 +224,68 @@ class mdspan //-------------------------------------------------------------------------------- // [mdspan.basic.mapping], mdspan mapping domain multidimensional index to access codomain element + MDSPAN_TEMPLATE_REQUIRES( + class... SizeTypes, + /* requires */ ( + _MDSPAN_FOLD_AND(_MDSPAN_TRAIT(std::is_convertible, SizeTypes, index_type) /* && ... */) && + _MDSPAN_FOLD_AND(_MDSPAN_TRAIT(std::is_nothrow_constructible, index_type, SizeTypes) /* && ... */) && + (rank() == sizeof...(SizeTypes)) + ) + ) + constexpr reference at(SizeTypes... indices) const + { + size_t r = 0; + for (const auto& index : {indices...}) { + if (__is_index_oor(index, __mapping_ref().extents().extent(r))) { + throw std::out_of_range( + "mdspan::at(...," + std::to_string(index) + ",...) out-of-range at rank index " + std::to_string(r) + + " for mdspan with extent {...," + std::to_string(__mapping_ref().extents().extent(r)) + ",...}"); + } + ++r; + } + return __accessor_ref().access(__ptr_ref(), __mapping_ref()(static_cast(std::move(indices))...)); + } + + MDSPAN_TEMPLATE_REQUIRES( + class SizeType, + /* requires */ ( + _MDSPAN_TRAIT(std::is_convertible, const SizeType&, index_type) && + _MDSPAN_TRAIT(std::is_nothrow_constructible, index_type, const SizeType&) + ) + ) + constexpr reference at(const std::array& indices) const + { + for (size_t r = 0; r < indices.size(); ++r) { + if (__is_index_oor(indices[r], __mapping_ref().extents().extent(r))) { + throw std::out_of_range( + "mdspan::at({...," + std::to_string(indices[r]) + ",...}) out-of-range at rank index " + std::to_string(r) + + " for mdspan with extent {...," + std::to_string(__mapping_ref().extents().extent(r)) + ",...}"); + } + } + return __impl::template __callop(*this, indices); + } + + #ifdef __cpp_lib_span + MDSPAN_TEMPLATE_REQUIRES( + class SizeType, + /* requires */ ( + _MDSPAN_TRAIT(std::is_convertible, const SizeType&, index_type) && + _MDSPAN_TRAIT(std::is_nothrow_constructible, index_type, const SizeType&) + ) + ) + constexpr reference at(std::span indices) const + { + for (size_t r = 0; r < indices.size(); ++r) { + if (__is_index_oor(indices[r], __mapping_ref().extents().extent(r))) { + throw std::out_of_range( + "mdspan::at({...," + std::to_string(indices[r]) + ",...}) out-of-range at rank index " + std::to_string(r) + + " for mdspan with extent {...," + std::to_string(__mapping_ref().extents().extent(r)) + ",...}"); + } + } + return __impl::template __callop(*this, indices); + } + #endif // __cpp_lib_span + #if MDSPAN_USE_BRACKET_OPERATOR MDSPAN_TEMPLATE_REQUIRES( class... SizeTypes, @@ -243,7 +310,7 @@ class mdspan ) ) MDSPAN_FORCE_INLINE_FUNCTION - constexpr reference operator[](const std::array< SizeType, rank()>& indices) const + constexpr reference operator[](const std::array& indices) const { return __impl::template __callop(*this, indices); } @@ -377,6 +444,23 @@ class mdspan MDSPAN_FORCE_INLINE_FUNCTION _MDSPAN_CONSTEXPR_14 accessor_type& __accessor_ref() noexcept { return __members.__second().__second(); } MDSPAN_FORCE_INLINE_FUNCTION constexpr accessor_type const& __accessor_ref() const noexcept { return __members.__second().__second(); } + MDSPAN_TEMPLATE_REQUIRES( + class SizeType, + /* requires */ ( + _MDSPAN_TRAIT(std::is_convertible, const SizeType&, index_type) && + _MDSPAN_TRAIT(std::is_nothrow_constructible, index_type, const SizeType&) + ) + ) + MDSPAN_FORCE_INLINE_FUNCTION constexpr bool __is_index_oor(SizeType index, index_type extent) const noexcept { + // Check for negative indices + if _MDSPAN_IF_CONSTEXPR_17 (_MDSPAN_TRAIT(std::is_signed, SizeType)) { + if(index < 0) { + return true; + } + } + return static_cast(index) >= extent; + } + template friend class mdspan; diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 14d61b2f..76799fb3 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -88,6 +88,7 @@ mdspan_add_test(test_layout_preconditions ENABLE_PRECONDITIONS) mdspan_add_test(test_dims) mdspan_add_test(test_extents) +mdspan_add_test(test_mdspan_at) mdspan_add_test(test_mdspan_ctors) mdspan_add_test(test_mdspan_swap) mdspan_add_test(test_mdspan_conversion) diff --git a/tests/test_mdspan_at.cpp b/tests/test_mdspan_at.cpp new file mode 100644 index 00000000..55496812 --- /dev/null +++ b/tests/test_mdspan_at.cpp @@ -0,0 +1,34 @@ +//@HEADER +// ************************************************************************ +// +// Kokkos v. 4.0 +// Copyright (2022) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions. +// See https://kokkos.org/LICENSE for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//@HEADER + +#include +#include + +#include + + +TEST(TestMdspanAt, test_mdspan_at) { + std::array a{}; + Kokkos::mdspan> s(a.data()); + + s.at(0, 0) = 3.14; + s.at(std::array{1, 2}) = 2.72; + ASSERT_EQ(s.at(0, 0), 3.14); + ASSERT_EQ(s.at(std::array{1, 2}), 2.72); + + EXPECT_THROW(s.at(2, 3), std::out_of_range); + EXPECT_THROW(s.at(std::array{3, 1}), std::out_of_range); +}