Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update cugraph for compatibility with the latest cuco #4111

Merged
merged 16 commits into from
Feb 17, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 36 additions & 34 deletions cpp/src/prims/key_store.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -35,6 +35,8 @@ namespace cugraph {

namespace detail {

using cuco_storage_type = cuco::storage<1>; ///< cuco window storage type

template <typename KeyIterator>
struct key_binary_search_contains_op_t {
using key_type = typename thrust::iterator_traits<KeyIterator>::value_type;
Expand Down Expand Up @@ -72,7 +74,7 @@ template <typename ViewType>
struct key_cuco_store_contains_device_view_t {
using key_type = typename ViewType::key_type;
using cuco_store_device_ref_type =
typename ViewType::cuco_store_type::ref_type<cuco::experimental::contains_tag>;
typename ViewType::cuco_store_type::ref_type<cuco::contains_tag>;

static_assert(!ViewType::binary_search);

Expand All @@ -88,9 +90,8 @@ struct key_cuco_store_contains_device_view_t {

template <typename ViewType>
struct key_cuco_store_insert_device_view_t {
using key_type = typename ViewType::key_type;
using cuco_store_device_ref_type =
typename ViewType::cuco_store_type::ref_type<cuco::experimental::insert_tag>;
using key_type = typename ViewType::key_type;
using cuco_store_device_ref_type = typename ViewType::cuco_store_type::ref_type<cuco::insert_tag>;

static_assert(!ViewType::binary_search);

Expand Down Expand Up @@ -147,14 +148,15 @@ class key_cuco_store_view_t {

static constexpr bool binary_search = false;

using cuco_store_type = cuco::experimental::static_set<
key_t,
cuco::experimental::extent<std::size_t>,
cuda::thread_scope_device,
thrust::equal_to<key_t>,
cuco::experimental::linear_probing<1, // CG size
cuco::murmurhash3_32<key_t>>,
rmm::mr::stream_allocator_adaptor<rmm::mr::polymorphic_allocator<std::byte>>>;
using cuco_store_type =
PointKernel marked this conversation as resolved.
Show resolved Hide resolved
cuco::static_set<key_t,
cuco::extent<std::size_t>,
cuda::thread_scope_device,
thrust::equal_to<key_t>,
cuco::linear_probing<1, // CG size
cuco::murmurhash3_32<key_t>>,
rmm::mr::stream_allocator_adaptor<rmm::mr::polymorphic_allocator<std::byte>>,
cuco_storage_type>;

key_cuco_store_view_t(cuco_store_type const* store) : cuco_store_(store) {}

Expand All @@ -167,12 +169,9 @@ class key_cuco_store_view_t {
cuco_store_->contains(key_first, key_last, value_first, stream);
}

auto cuco_store_contains_device_ref() const
{
return cuco_store_->ref(cuco::experimental::contains);
}
auto cuco_store_contains_device_ref() const { return cuco_store_->ref(cuco::contains); }

auto cuco_store_insert_device_ref() const { return cuco_store_->ref(cuco::experimental::insert); }
auto cuco_store_insert_device_ref() const { return cuco_store_->ref(cuco::insert); }

key_t invalid_key() const { return cuco_store_->get_empty_key_sentinel(); }

Expand Down Expand Up @@ -240,14 +239,15 @@ class key_cuco_store_t {
public:
using key_type = key_t;

using cuco_store_type = cuco::experimental::static_set<
key_t,
cuco::experimental::extent<std::size_t>,
cuda::thread_scope_device,
thrust::equal_to<key_t>,
cuco::experimental::linear_probing<1, // CG size
cuco::murmurhash3_32<key_t>>,
rmm::mr::stream_allocator_adaptor<rmm::mr::polymorphic_allocator<std::byte>>>;
using cuco_store_type =
cuco::static_set<key_t,
cuco::extent<std::size_t>,
cuda::thread_scope_device,
thrust::equal_to<key_t>,
cuco::linear_probing<1, // CG size
cuco::murmurhash3_32<key_t>>,
rmm::mr::stream_allocator_adaptor<rmm::mr::polymorphic_allocator<std::byte>>,
cuco_storage_type>;

key_cuco_store_t(rmm::cuda_stream_view stream) {}

Expand Down Expand Up @@ -324,14 +324,16 @@ class key_cuco_store_t {

auto stream_adapter = rmm::mr::make_stream_allocator_adaptor(
rmm::mr::polymorphic_allocator<std::byte>(rmm::mr::get_current_device_resource()), stream);
cuco_store_ = std::make_unique<cuco_store_type>(
cuco_size,
cuco::sentinel::empty_key<key_t>{invalid_key},
thrust::equal_to<key_t>{},
cuco::experimental::linear_probing<1, // CG size
cuco::murmurhash3_32<key_t>>{},
stream_adapter,
stream.value());
cuco_store_ =
std::make_unique<cuco_store_type>(cuco_size,
cuco::sentinel::empty_key<key_t>{invalid_key},
thrust::equal_to<key_t>{},
cuco::linear_probing<1, // CG size
cuco::murmurhash3_32<key_t>>{},
cuco::thread_scope_device,
cuco_storage_type{},
stream_adapter,
stream.value());
}

std::unique_ptr<cuco_store_type> cuco_store_{nullptr};
Expand Down
Loading
Loading