Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add element access via at() to std::mdspan #302

Open
wants to merge 3 commits into
base: stable
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions include/experimental/__p0009_bits/config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,13 @@ static_assert(_MDSPAN_CPLUSPLUS >= MDSPAN_CXX_STD_14, "mdspan requires C++14 or
# endif
#endif

#ifndef _MDSPAN_USE_IF_CONSTEXPR_17
# if (defined(__cpp_if_constexpr) && __cpp_if_constexpr >= 201606) \
|| (!defined(__cpp_constexpr) && MDSPAN_HAS_CXX_17)
# define _MDSPAN_USE_IF_CONSTEXPR_17 1
# endif
#endif

#ifndef _MDSPAN_USE_INTEGER_SEQUENCE
# if defined(_MDSPAN_COMPILER_MSVC)
# if (defined(__cpp_lib_integer_sequence) && __cpp_lib_integer_sequence >= 201304)
Expand Down
6 changes: 6 additions & 0 deletions include/experimental/__p0009_bits/macros.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -697,3 +697,9 @@ struct __bools;

// </editor-fold> end Pre-C++14 constexpr }}}1
//==============================================================================

#if _MDSPAN_USE_IF_CONSTEXPR_17
# define _MDSPAN_IF_CONSTEXPR_17 constexpr
#else
# define _MDSPAN_IF_CONSTEXPR_17
#endif
86 changes: 85 additions & 1 deletion include/experimental/__p0009_bits/mdspan.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,15 @@

#include "default_accessor.hpp"
#include "layout_right.hpp"
#include "macros.hpp"
#include "extents.hpp"
#include "trait_backports.hpp"
#include "compressed_pair.hpp"

#include <stdexcept>
#include <string>
#include <type_traits>

namespace MDSPAN_IMPL_STANDARD_NAMESPACE {
template <
class ElementType,
Expand Down Expand Up @@ -219,6 +224,68 @@ class mdspan
//--------------------------------------------------------------------------------
// [mdspan.basic.mapping], mdspan mapping domain multidimensional index to access codomain element

MDSPAN_TEMPLATE_REQUIRES(
class... SizeTypes,
/* requires */ (
_MDSPAN_FOLD_AND(_MDSPAN_TRAIT(std::is_convertible, SizeTypes, index_type) /* && ... */) &&
_MDSPAN_FOLD_AND(_MDSPAN_TRAIT(std::is_nothrow_constructible, index_type, SizeTypes) /* && ... */) &&
(rank() == sizeof...(SizeTypes))
)
)
constexpr reference at(SizeTypes... indices) const
{
size_t r = 0;
for (const auto& index : {indices...}) {
if (__is_index_oor(index, __mapping_ref().extents().extent(r))) {
throw std::out_of_range(
"mdspan::at(...," + std::to_string(index) + ",...) out-of-range at rank index " + std::to_string(r) +
" for mdspan with extent {...," + std::to_string(__mapping_ref().extents().extent(r)) + ",...}");
}
++r;
}
return __accessor_ref().access(__ptr_ref(), __mapping_ref()(static_cast<index_type>(std::move(indices))...));
}

MDSPAN_TEMPLATE_REQUIRES(
class SizeType,
/* requires */ (
_MDSPAN_TRAIT(std::is_convertible, const SizeType&, index_type) &&
_MDSPAN_TRAIT(std::is_nothrow_constructible, index_type, const SizeType&)
)
)
constexpr reference at(const std::array<SizeType, rank()>& indices) const
{
for (size_t r = 0; r < indices.size(); ++r) {
if (__is_index_oor(indices[r], __mapping_ref().extents().extent(r))) {
throw std::out_of_range(
"mdspan::at({...," + std::to_string(indices[r]) + ",...}) out-of-range at rank index " + std::to_string(r) +
" for mdspan with extent {...," + std::to_string(__mapping_ref().extents().extent(r)) + ",...}");
}
}
return __impl::template __callop<reference>(*this, indices);
}

#ifdef __cpp_lib_span
MDSPAN_TEMPLATE_REQUIRES(
class SizeType,
/* requires */ (
_MDSPAN_TRAIT(std::is_convertible, const SizeType&, index_type) &&
_MDSPAN_TRAIT(std::is_nothrow_constructible, index_type, const SizeType&)
)
)
constexpr reference at(std::span<SizeType, rank()> indices) const
{
for (size_t r = 0; r < indices.size(); ++r) {
if (__is_index_oor(indices[r], __mapping_ref().extents().extent(r))) {
throw std::out_of_range(
"mdspan::at({...," + std::to_string(indices[r]) + ",...}) out-of-range at rank index " + std::to_string(r) +
" for mdspan with extent {...," + std::to_string(__mapping_ref().extents().extent(r)) + ",...}");
}
}
return __impl::template __callop<reference>(*this, indices);
}
#endif // __cpp_lib_span

#if MDSPAN_USE_BRACKET_OPERATOR
MDSPAN_TEMPLATE_REQUIRES(
class... SizeTypes,
Expand All @@ -243,7 +310,7 @@ class mdspan
)
)
MDSPAN_FORCE_INLINE_FUNCTION
constexpr reference operator[](const std::array< SizeType, rank()>& indices) const
constexpr reference operator[](const std::array<SizeType, rank()>& indices) const
{
return __impl::template __callop<reference>(*this, indices);
}
Expand Down Expand Up @@ -377,6 +444,23 @@ class mdspan
MDSPAN_FORCE_INLINE_FUNCTION _MDSPAN_CONSTEXPR_14 accessor_type& __accessor_ref() noexcept { return __members.__second().__second(); }
MDSPAN_FORCE_INLINE_FUNCTION constexpr accessor_type const& __accessor_ref() const noexcept { return __members.__second().__second(); }

MDSPAN_TEMPLATE_REQUIRES(
class SizeType,
/* requires */ (
_MDSPAN_TRAIT(std::is_convertible, const SizeType&, index_type) &&
_MDSPAN_TRAIT(std::is_nothrow_constructible, index_type, const SizeType&)
)
)
MDSPAN_FORCE_INLINE_FUNCTION constexpr bool __is_index_oor(SizeType index, index_type extent) const noexcept {
Copy link
Contributor

@nmm0 nmm0 Nov 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm of the opinion that we shouldn't add to the reserved names (i.e. __ prefix) even though it's consistent with the rest of the class since we do plan on removing them.

// Check for negative indices
if _MDSPAN_IF_CONSTEXPR_17 (_MDSPAN_TRAIT(std::is_signed, SizeType)) {
if(index < 0) {
return true;
}
}
return static_cast<index_type>(index) >= extent;
}

template <class, class, class, class>
friend class mdspan;

Expand Down
1 change: 1 addition & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ mdspan_add_test(test_layout_preconditions ENABLE_PRECONDITIONS)

mdspan_add_test(test_dims)
mdspan_add_test(test_extents)
mdspan_add_test(test_mdspan_at)
mdspan_add_test(test_mdspan_ctors)
mdspan_add_test(test_mdspan_swap)
mdspan_add_test(test_mdspan_conversion)
Expand Down
34 changes: 34 additions & 0 deletions tests/test_mdspan_at.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
//@HEADER
// ************************************************************************
//
// Kokkos v. 4.0
// Copyright (2022) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
// See https://kokkos.org/LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//@HEADER

#include <array>
#include <mdspan/mdspan.hpp>

#include <gtest/gtest.h>


TEST(TestMdspanAt, test_mdspan_at) {
std::array<double, 6> a{};
Kokkos::mdspan<double, Kokkos::extents<size_t, 2, 3>> s(a.data());

s.at(0, 0) = 3.14;
s.at(std::array<int, 2>{1, 2}) = 2.72;
ASSERT_EQ(s.at(0, 0), 3.14);
ASSERT_EQ(s.at(std::array<int, 2>{1, 2}), 2.72);

EXPECT_THROW(s.at(2, 3), std::out_of_range);
EXPECT_THROW(s.at(std::array<int, 2>{3, 1}), std::out_of_range);
}
Loading