diff --git a/cpp/test/neighbors/ann_cagra.cuh b/cpp/test/neighbors/ann_cagra.cuh index 524e0b16e4..526143d8f9 100644 --- a/cpp/test/neighbors/ann_cagra.cuh +++ b/cpp/test/neighbors/ann_cagra.cuh @@ -299,9 +299,7 @@ template class AnnCagraSortTest : public ::testing::TestWithParam { public: AnnCagraSortTest() - : stream_(handle_.get_stream()), - ps(::testing::TestWithParam::GetParam()), - database(0, stream_) + : ps(::testing::TestWithParam::GetParam()), database(0, handle_.get_stream()) { } @@ -313,7 +311,8 @@ class AnnCagraSortTest : public ::testing::TestWithParam { auto database_view = raft::make_device_matrix_view( (const DataT*)database.data(), ps.n_rows, ps.dim); auto database_host = raft::make_host_matrix(ps.n_rows, ps.dim); - raft::copy(database_host.data_handle(), database.data(), database.size(), stream_); + raft::copy( + database_host.data_handle(), database.data(), database.size(), handle_.get_stream()); auto database_host_view = raft::make_host_matrix_view( (const DataT*)database_host.data_handle(), ps.n_rows, ps.dim); @@ -342,13 +341,13 @@ class AnnCagraSortTest : public ::testing::TestWithParam { void SetUp() override { std::cout << "Resizing database: " << ps.n_rows * ps.dim << std::endl; - database.resize(((size_t)ps.n_rows) * ps.dim, stream_); + database.resize(((size_t)ps.n_rows) * ps.dim, handle_.get_stream()); std::cout << "Done.\nRuning rng" << std::endl; raft::random::Rng r(1234ULL); if constexpr (std::is_same{}) { - GenerateRoundingErrorFreeDataset(database.data(), ps.n_rows, ps.dim, r, stream_); + GenerateRoundingErrorFreeDataset(database.data(), ps.n_rows, ps.dim, r, handle_.get_stream()); } else { - r.uniformInt(database.data(), ps.n_rows * ps.dim, DataT(1), DataT(20), stream_); + r.uniformInt(database.data(), ps.n_rows * ps.dim, DataT(1), DataT(20), handle_.get_stream()); } handle_.sync_stream(); } @@ -356,12 +355,11 @@ class AnnCagraSortTest : public ::testing::TestWithParam { void TearDown() override { handle_.sync_stream(); - database.resize(0, stream_); + database.resize(0, handle_.get_stream()); } private: raft::device_resources handle_; - rmm::cuda_stream_view stream_; AnnCagraInputs ps; rmm::device_uvector database; };