Skip to content

Commit

Permalink
Update CAGRA test
Browse files Browse the repository at this point in the history
  • Loading branch information
enp1s0 committed May 9, 2023
1 parent ca768b5 commit b447bdd
Showing 1 changed file with 7 additions and 9 deletions.
16 changes: 7 additions & 9 deletions cpp/test/neighbors/ann_cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -299,9 +299,7 @@ template <typename DistanceT, typename DataT, typename IdxT>
class AnnCagraSortTest : public ::testing::TestWithParam<AnnCagraInputs> {
public:
AnnCagraSortTest()
: stream_(handle_.get_stream()),
ps(::testing::TestWithParam<AnnCagraInputs>::GetParam()),
database(0, stream_)
: ps(::testing::TestWithParam<AnnCagraInputs>::GetParam()), database(0, handle_.get_stream())
{
}

Expand All @@ -313,7 +311,8 @@ class AnnCagraSortTest : public ::testing::TestWithParam<AnnCagraInputs> {
auto database_view = raft::make_device_matrix_view<const DataT, IdxT>(
(const DataT*)database.data(), ps.n_rows, ps.dim);
auto database_host = raft::make_host_matrix<DataT, IdxT>(ps.n_rows, ps.dim);
raft::copy(database_host.data_handle(), database.data(), database.size(), stream_);
raft::copy(
database_host.data_handle(), database.data(), database.size(), handle_.get_stream());
auto database_host_view = raft::make_host_matrix_view<const DataT, IdxT>(
(const DataT*)database_host.data_handle(), ps.n_rows, ps.dim);

Expand Down Expand Up @@ -342,26 +341,25 @@ class AnnCagraSortTest : public ::testing::TestWithParam<AnnCagraInputs> {
void SetUp() override
{
std::cout << "Resizing database: " << ps.n_rows * ps.dim << std::endl;
database.resize(((size_t)ps.n_rows) * ps.dim, stream_);
database.resize(((size_t)ps.n_rows) * ps.dim, handle_.get_stream());
std::cout << "Done.\nRuning rng" << std::endl;
raft::random::Rng r(1234ULL);
if constexpr (std::is_same<DataT, float>{}) {
GenerateRoundingErrorFreeDataset(database.data(), ps.n_rows, ps.dim, r, stream_);
GenerateRoundingErrorFreeDataset(database.data(), ps.n_rows, ps.dim, r, handle_.get_stream());
} else {
r.uniformInt(database.data(), ps.n_rows * ps.dim, DataT(1), DataT(20), stream_);
r.uniformInt(database.data(), ps.n_rows * ps.dim, DataT(1), DataT(20), handle_.get_stream());
}
handle_.sync_stream();
}

void TearDown() override
{
handle_.sync_stream();
database.resize(0, stream_);
database.resize(0, handle_.get_stream());
}

private:
raft::device_resources handle_;
rmm::cuda_stream_view stream_;
AnnCagraInputs ps;
rmm::device_uvector<DataT> database;
};
Expand Down

0 comments on commit b447bdd

Please sign in to comment.