Skip to content

Commit

Permalink
Initial commit.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jan 5, 2022
1 parent 0e4bab3 commit 9dd110d
Show file tree
Hide file tree
Showing 7 changed files with 333 additions and 2 deletions.
5 changes: 3 additions & 2 deletions build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ BUILD_STATIC_FAISS=OFF
SINGLEGPU=""
NVTX=OFF
CLEAN=0
BUILD_TYPE=Release
BUILD_DISABLE_DEPRECATION_WARNING=ON

# Set defaults for vars that may not have been defined externally
Expand Down Expand Up @@ -153,8 +154,8 @@ if (( ${NUMARGS} == 0 )) || hasArg cppraft; then
-DNVTX=${NVTX} \
-DDISABLE_DEPRECATION_WARNING=${BUILD_DISABLE_DEPRECATION_WARNING} \
-DBUILD_GTEST=${BUILD_GTEST} \
-DBUILD_STATIC_FAISS=${BUILD_STATIC_FAISS}

-DBUILD_STATIC_FAISS=${BUILD_STATIC_FAISS} \
-DCMAKE_BUILD_TYPE=${BUILD_TYPE}

# Run all c++ targets at once
cmake --build ${CPP_RAFT_BUILD_DIR} -j${PARALLEL_LEVEL} ${MAKE_TARGETS} ${VERBOSE_FLAG}
Expand Down
2 changes: 2 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ include(cmake/thirdparty/get_thrust.cmake)
include(cmake/thirdparty/get_rmm.cmake)
include(cmake/thirdparty/get_libcudacxx.cmake)
include(cmake/thirdparty/get_cuco.cmake)
include(cmake/thirdparty/get_mdspan.cmake)

if(BUILD_TESTS)
include(cmake/thirdparty/get_faiss.cmake)
Expand Down Expand Up @@ -148,6 +149,7 @@ set(RAFT_LINK_LIBRARIES
$<$<BOOL:${NVTX}>:CUDA::nvToolsExt>
rmm::rmm
cuco::cuco
std::mdspan
)

target_link_libraries(raft INTERFACE ${RAFT_LINK_LIBRARIES})
Expand Down
15 changes: 15 additions & 0 deletions cpp/cmake/thirdparty/get_mdspan.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Fetch (or reuse an already-provided) mdspan at the given VERSION via CPM and
# register the exported `std::mdspan` target with raft's build/install exports.
function(find_and_configure_mdspan VERSION)
rapids_cpm_find(
mdspan ${VERSION}
GLOBAL_TARGETS std::mdspan
BUILD_EXPORT_SET raft-exports
INSTALL_EXPORT_SET raft-exports
CPM_ARGS
# NOTE(review): pinned to a personal fork at a fixed commit (presumably for
# CUDA support) -- switch back to the upstream kokkos/mdspan repository once
# the needed changes are merged.
GIT_REPOSITORY https://github.com/trivialfis/mdspan
GIT_TAG 0193f075e977cc5f3c957425fd899e53d598f524
OPTIONS "MDSPAN_ENABLE_CUDA ON"
"MDSPAN_CXX_STANDARD ON"
)
endfunction()

find_and_configure_mdspan(0.2.0)
6 changes: 6 additions & 0 deletions cpp/include/raft/cudart_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,12 @@ void ASSERT_DEVICE_MEM(T* ptr, std::string name)
<< ", err=" << s_err << std::endl;
}

// Stringize helper for DEVICE_ASSERT (expands the argument before quoting it).
#define __ASSERT_STR_HELPER(x) #x

// Assertion usable from device code: on failure calls __assert_fail, which
// traps the kernel and reports the condition text, file, line and function.
// The whole expansion is wrapped in parentheses so the macro behaves as a
// single expression -- without them the bare `?:` mis-parses when the macro is
// used inside another conditional or a comma expression.
#define DEVICE_ASSERT(cond)      \
  ((cond) ? static_cast<void>(0) \
          : __assert_fail(__ASSERT_STR_HELPER((cond)), __FILE__, __LINE__, __PRETTY_FUNCTION__))

inline uint32_t curTimeMillis()
{
auto now = std::chrono::high_resolution_clock::now();
Expand Down
1 change: 1 addition & 0 deletions cpp/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ add_executable(test_raft
test/linalg/map.cu
test/linalg/map_then_reduce.cu
test/linalg/matrix_vector_op.cu
test/linalg/mdspan.cu
test/linalg/multiply.cu
test/linalg/norm.cu
test/linalg/reduce.cu
Expand Down
225 changes: 225 additions & 0 deletions cpp/test/linalg/mdarray.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <experimental/mdspan>
#include <raft/cudart_utils.h>
#include <rmm/cuda_stream_view.hpp>
#include <rmm/device_uvector.hpp>
#include <thrust/device_ptr.h>

namespace raft::linalg {
namespace detail {

/**
 * @brief A simplified version of thrust::device_reference with support for CUDA stream.
 *
 * Proxy to a single element residing in device memory: converting to the value
 * type performs a device-to-host copy, assigning performs a host-to-device
 * copy.  Both operations synchronize the stored stream before returning, so
 * callers may immediately use the result and may pass short-lived host values
 * to operator=.
 */
template <typename T>
class device_ref {
 public:
  using value_type    = typename std::remove_cv_t<T>;
  using pointer       = thrust::device_ptr<T>;
  using const_pointer = thrust::device_ptr<T const>;

 private:
  // A const element type gets a const_pointer, making device_ref<T const> read-only.
  std::conditional_t<std::is_const<T>::value, const_pointer, pointer> ptr_;
  rmm::cuda_stream_view stream_;

 public:
  device_ref(thrust::device_ptr<T> ptr, rmm::cuda_stream_view stream) : ptr_{ptr}, stream_{stream}
  {
  }

  operator value_type() const  // NOLINT
  {
    auto* raw = ptr_.get();
    value_type v{};
    RAFT_CUDA_TRY(cudaMemcpyAsync(&v, raw, sizeof(v), cudaMemcpyDeviceToHost, stream_.value()));
    // The async copy targets pageable host memory (`v` lives on the stack);
    // without this synchronization `v` could be returned before the copy has
    // completed, yielding garbage.
    stream_.synchronize();
    return v;
  }
  auto operator=(T const& other) -> device_ref&
  {
    auto* raw = ptr_.get();
    RAFT_CUDA_TRY(cudaMemcpyAsync(raw, &other, sizeof(T), cudaMemcpyHostToDevice, stream_.value()));
    // `other` may be a temporary on the caller's stack; synchronize so that it
    // is safe to destroy as soon as this assignment returns.
    stream_.synchronize();
    return *this;
  }
};

/**
 * \brief A thin wrapper over rmm::device_uvector for implementing the mdarray container policy.
 *
 * Element access returns a stream-aware device_ref proxy instead of a raw
 * reference, since the storage lives in device memory.
 */
template <typename T>
class device_vector {
  rmm::device_uvector<T> data_;

 public:
  using value_type = T;
  using size_type  = std::size_t;

  using reference       = device_ref<T>;
  using const_reference = device_ref<T const>;

  using pointer       = value_type*;
  using const_pointer = value_type const*;

  using iterator       = pointer;
  using const_iterator = const_pointer;

 public:
  ~device_vector()                        = default;
  device_vector(device_vector&&) noexcept = default;
  // Rule of five: the user-declared destructor/copy ctor would otherwise
  // suppress the implicit move assignment operator.
  auto operator=(device_vector&&) noexcept -> device_vector& = default;
  // Copy construction reuses the source's stream.  rmm::device_uvector has no
  // copy assignment, so it is deleted explicitly here for clarity.
  device_vector(device_vector const& that) : data_{that.data_, that.data_.stream()} {}
  auto operator=(device_vector const&) -> device_vector& = delete;

  /**
   * @brief Default ctor is deleted as it doesn't accept stream.
   */
  device_vector() = delete;
  /**
   * @brief Ctor that accepts a size, stream and an optional mr.
   */
  explicit device_vector(
    std::size_t size,
    rmm::cuda_stream_view stream,
    rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource())
    : data_{size, stream, mr}
  {
  }
  /**
   * @brief Index operator that returns a proxy to the actual data.
   */
  template <typename Index>
  auto operator[](Index i) -> reference
  {
    return device_ref<T>{thrust::device_ptr<T>{data_.data() + i}, data_.stream()};
  }
  /**
   * @brief Index operator that returns a read-only proxy to the actual data.
   */
  template <typename Index>
  auto operator[](Index i) const
  {
    return device_ref<T const>{thrust::device_ptr<T const>{data_.data() + i}, data_.stream()};
  }

  [[nodiscard]] auto data() -> pointer { return data_.data(); }
  [[nodiscard]] auto data() const -> const_pointer { return data_.data(); }
};

/**
 * @brief Container policy for mdarray, backed by device_vector<float>.
 *
 * Bridges mdarray's container-policy concept and mdspan's accessor policy:
 * `create` allocates storage on the stored stream and `access` hands out a
 * stream-aware proxy reference.
 */
class uvector_policy {
// Default-constructed view refers to CUDA's default stream.
rmm::cuda_stream_view stream_;

public:
using container_type = device_vector<float>;
using pointer = typename container_type::pointer;
using const_pointer = typename container_type::const_pointer;
using reference = device_ref<float>;
using const_reference = device_ref<float const>;
using accessor_policy = std::experimental::default_accessor<float>;

public:
// Allocate storage for n floats on this policy's stream.
auto create(size_t n) -> container_type { return container_type(n, stream_); }
/**
 * @brief A quick hack to pass the policy to `stdex::mdspan` without passing down the stream.
 */
operator std::experimental::default_accessor<float>() const { return {}; } // NOLINT

uvector_policy() = default;
// Return a stream-aware proxy reference to element n of container c.
[[nodiscard]] auto access(container_type& c, size_t n) const { return c[n]; }
};
} // namespace detail

namespace stdex = std::experimental;

/**
 * @brief Modified from the c++ mdarray proposal, with the differences listed below.
 *
 * - Layout policy is different, the mdarray in raft uses `stdex::extent` directly just
 *   like `mdspan`, while the `mdarray` in the reference implementation uses variadic
 *   template.
 *
 * - Most of the constructors from the reference implementation are removed to make sure
 *   CUDA stream is honored.
 */
template <class ElementType,
          class Extents,
          class LayoutPolicy   = stdex::layout_right,
          class AccessorPolicy = stdex::default_accessor<ElementType>>
class mdarray {
 public:
  using element_type = ElementType;
  using extents_type = Extents;
  using layout_type  = LayoutPolicy;
  using mapping_type = typename layout_type::template mapping<extents_type>;

  using index_type            = std::size_t;
  using difference_type       = std::ptrdiff_t;
  using container_policy_type = AccessorPolicy;
  using container_type        = typename container_policy_type::container_type;

  static_assert(!std::is_const<ElementType>::value,
                "Element type for container must not be const.");

  using pointer         = typename container_policy_type::pointer;
  using const_pointer   = typename container_policy_type::const_pointer;
  using reference       = typename container_policy_type::reference;
  using const_reference = typename container_policy_type::const_reference;
  using view_type       = stdex::mdspan<element_type,
                                        extents_type,
                                        layout_type,
                                        typename container_policy_type::accessor_policy>;
  using const_view_type = stdex::mdspan<element_type const,
                                        extents_type,
                                        layout_type,
                                        typename container_policy_type::accessor_policy>;

 public:
  constexpr mdarray() noexcept(std::is_nothrow_default_constructible<container_type>::value) =
    default;
  constexpr mdarray(mdarray const&) noexcept(
    std::is_nothrow_copy_constructible<container_type>::value) = default;
  constexpr mdarray(mdarray&&) noexcept(std::is_nothrow_move_constructible<container_type>::value) =
    default;

  auto operator=(mdarray&&) noexcept(std::is_nothrow_move_assignable<container_type>::value)
    -> mdarray& = default;

  auto operator=(mdarray const&) noexcept(std::is_nothrow_copy_assignable<container_type>::value)
    -> mdarray& = default;

  ~mdarray() noexcept(std::is_nothrow_destructible<container_type>::value) = default;

  /**
   * @brief Construct from a layout mapping and a container policy; the policy
   *        allocates storage large enough for the mapping's span.
   */
  constexpr mdarray(mapping_type const& m, container_policy_type const& cp)
    : cp_(cp), map_(m), c_(cp_.create(map_.required_span_size()))
  {
  }

  /** @brief Non-owning mdspan over the underlying storage. */
  auto view() noexcept { return view_type(c_.data(), map_, cp_); }

  // The raw pointer comes from the container, not the container policy: the
  // policy (e.g. uvector_policy) owns no storage and has no data() member, so
  // the previous `cp_.data()` could never compile when instantiated.
  auto data() noexcept -> pointer { return c_.data(); }
  constexpr auto data() const noexcept -> const_pointer { return c_.data(); }

  /**
   * @brief Element access through the layout mapping; returns the policy's
   *        (possibly proxy) reference type.  Enabled only when the number of
   *        indices matches the rank.
   */
  template <typename... IndexType>
  auto operator()(IndexType&&... indices)
    -> std::enable_if_t<sizeof...(IndexType) == extents_type::rank(), reference>
  {
    return cp_.access(c_, map_(index_type(indices)...));
  }

 private:
  container_policy_type cp_;
  mapping_type map_;
  container_type c_;
};
} // namespace raft::linalg
81 changes: 81 additions & 0 deletions cpp/test/linalg/mdspan.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "mdarray.h"

#include <experimental/mdspan>

#include <gtest/gtest.h>
#include <raft/cudart_utils.h>
#include <rmm/device_uvector.hpp>
#include <rmm/exec_policy.hpp>
#include <thrust/device_vector.h>
#include <thrust/for_each.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/sequence.h>

namespace {
namespace stdex = std::experimental;
// Smoke test for stdex::mdspan over RMM-owned device memory: verifies element
// access through a 4x4 row-major span and through a submdspan row slice, from
// inside a __device__ lambda.
void test_mdspan()
{
auto it = thrust::make_counting_iterator(0ul);

// nullptr == CUDA legacy default stream.
cudaStream_t stream = nullptr;
rmm::device_uvector<float> a{16ul, stream};
// Fill with 0..15 so element (r, c) of the row-major 4x4 view equals 4*r + c.
thrust::sequence(rmm::exec_policy(stream), a.begin(), a.end());

stdex::mdspan<float, stdex::extents<stdex::dynamic_extent, stdex::dynamic_extent>> span{
a.data(), 4, 4};

// NOTE(review): no execution policy given -- thrust dispatches on the iterator
// system; presumably this runs on the device given the __device__ lambda, and
// implicitly synchronizes before returning.  Confirm if it becomes flaky.
thrust::for_each(it, it + 4, [=] __device__(size_t i) {
// Row 0, column i holds value i (see the sequence fill above).
auto v = span(0, i);
DEVICE_ASSERT(v == i);
// Slice out row 0 and index the resulting 1-D span.
auto k = stdex::submdspan(span, 0, stdex::full_extent);
DEVICE_ASSERT(k(i) == i);
});
}

} // namespace

namespace raft::linalg {
// End-to-end exercise of the mdarray prototype: host-side proxy reads/writes,
// a device-side read through the mdspan view, and direct use of the
// device_vector / device_ref helpers alongside thrust's own proxy container.
void test_mdarray()
{
detail::uvector_policy policy;
using matrix_extent = stdex::extents<stdex::dynamic_extent, stdex::dynamic_extent>;
stdex::layout_right::mapping<matrix_extent> layout{matrix_extent{4, 4}};
mdarray<float, matrix_extent, stdex::layout_right, detail::uvector_policy> array{layout, policy};

// Host-side element access; reads/writes go through detail::device_ref.
std::cout << array(0, 0) << std::endl;
array(0, 3) = 1;
std::cout << array(0, 3) << std::endl;
ASSERT_EQ(array(0, 3), 1);
// The view is an mdspan usable from device code; the host write above must be
// visible to the kernel launched below.
auto view = array.view();
auto it = thrust::make_counting_iterator(0ul);
thrust::for_each(
rmm::exec_policy(rmm::cuda_stream_default), it, it + 1, [view] __device__(auto i) {
DEVICE_ASSERT(view(0, 3) == 1);
});

// Baseline: thrust's device_vector uses the same proxy-reference pattern.
thrust::device_vector<float> vec(10);
vec[0] = 1.0;
auto b = vec[0];
std::cout << b << std::endl;

// Same pattern with our device_vector / device_ref.
detail::device_vector<float> dvec(10, rmm::cuda_stream_default);
auto a = dvec[2];
float c = a; // device-to-host conversion via device_ref
std::cout << c << std::endl;
}
} // namespace raft::linalg

// gtest entry points for the two drivers above.
TEST(MDSpan, Kernel) { test_mdspan(); }
TEST(MDArray, Basic) { raft::linalg::test_mdarray(); }

0 comments on commit 9dd110d

Please sign in to comment.