diff --git a/cpp/src/prims/key_store.cuh b/cpp/src/prims/key_store.cuh index 56be1456d0b..b8e17145590 100644 --- a/cpp/src/prims/key_store.cuh +++ b/cpp/src/prims/key_store.cuh @@ -28,6 +28,7 @@ #include #include +#include #include #include #include @@ -322,7 +323,7 @@ class key_cuco_store_t { static_cast(static_cast(num_keys) / load_factor), static_cast(num_keys) + 1); // cuco::static_map requires at least one empty slot - auto stream_adapter = rmm::mr::make_stream_allocator_adaptor( + auto stream_adapter = rmm::mr::stream_allocator_adaptor( rmm::mr::polymorphic_allocator(rmm::mr::get_current_device_resource()), stream); cuco_store_ = std::make_unique(cuco_size, diff --git a/cpp/src/prims/kv_store.cuh b/cpp/src/prims/kv_store.cuh index de233fd583b..a4e644b361e 100644 --- a/cpp/src/prims/kv_store.cuh +++ b/cpp/src/prims/kv_store.cuh @@ -820,7 +820,7 @@ class kv_cuco_store_t { static_cast(static_cast(num_keys) / load_factor), static_cast(num_keys) + 1); // cuco::static_map requires at least one empty slot - auto stream_adapter = rmm::mr::make_stream_allocator_adaptor( + auto stream_adapter = rmm::mr::stream_allocator_adaptor( rmm::mr::polymorphic_allocator(rmm::mr::get_current_device_resource()), stream); if constexpr (std::is_arithmetic_v) { cuco_store_ =