Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Diskann Benchmarking Wrapper #260

Open
wants to merge 103 commits into
base: branch-25.02
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 93 commits
Commits
Show all changes
103 commits
Select commit Hold shift + click to select a range
8e8d3c1
initial commit
tarang-jain Jul 29, 2024
e937ebd
merge 24.08
tarang-jain Jul 29, 2024
0bbbf0d
make build
tarang-jain Jul 29, 2024
02084e2
Merge branch 'branch-24.08' of https://github.com/rapidsai/cuvs into …
tarang-jain Jul 29, 2024
64f1d60
Merge branch 'branch-24.10' into diskann-wrapper
tarang-jain Jul 29, 2024
706f22e
update wrapper
tarang-jain Jul 31, 2024
3ea499b
Merge branch 'branch-24.08' of https://github.com/rapidsai/cuvs into …
tarang-jain Jul 31, 2024
e0aab8f
diskann_memory working
tarang-jain Aug 1, 2024
17c5510
Merge branch 'branch-24.08' of https://github.com/rapidsai/cuvs into …
tarang-jain Aug 1, 2024
f426df9
Merge branch 'branch-24.10' of https://github.com/rapidsai/cuvs into …
tarang-jain Aug 1, 2024
a7bdd33
make compile
tarang-jain Aug 3, 2024
d2442ca
Merge branch 'diskann-wrapper' of https://github.com/tarang-jain/cuvs…
tarang-jain Aug 3, 2024
dbc84cc
rm num_threads_
tarang-jain Aug 5, 2024
7e37218
FEA Add cuvs-bench to dependencies and conda environments
dantegd Aug 5, 2024
b2aef6d
FIX add missing deps
dantegd Aug 5, 2024
b9762d5
Merge branch 'fea-add-bench-deps' of https://github.com/dantegd/cuvs …
tarang-jain Aug 5, 2024
bf75242
FIX version and other improvements
dantegd Aug 6, 2024
a8bcdef
FEA Add cuvs_bench.run
dantegd Aug 6, 2024
3818da9
update patch;build command
tarang-jain Aug 6, 2024
cd8bfe5
Merge branch 'cuvsbench-run' of https://github.com/dantegd/cuvs into …
tarang-jain Aug 6, 2024
ec6d70c
FIX some cuvs_bench python build dependencies
dantegd Aug 6, 2024
c9f797a
Merge branch 'cuvsbench-run' of https://github.com/dantegd/cuvs into …
tarang-jain Aug 6, 2024
585ad53
FIX add missing algorithms.yaml
dantegd Aug 6, 2024
441ab2a
Merge branch 'cuvsbench-run' of https://github.com/dantegd/cuvs into …
tarang-jain Aug 6, 2024
33b075d
working mem index
tarang-jain Aug 7, 2024
81c92e6
Merge branch 'branch-24.10' of https://github.com/rapidsai/cuvs into …
tarang-jain Aug 7, 2024
11545c3
Merge branch 'rapidsai:branch-24.10' into diskann-wrapper
tarang-jain Aug 7, 2024
ffea663
remove base_set warning
tarang-jain Aug 8, 2024
9c1cddc
Merge branch 'branch-24.10' of https://github.com/rapidsai/cuvs into …
tarang-jain Aug 8, 2024
96d5642
Merge branch 'diskann-wrapper' of https://github.com/tarang-jain/cuvs…
tarang-jain Aug 8, 2024
a7eb787
pull upstream
tarang-jain Sep 19, 2024
4cbe7b1
merge 24.10
tarang-jain Oct 1, 2024
63621f4
revert some changes
tarang-jain Oct 1, 2024
a890ac5
Merge branch 'branch-24.12' of https://github.com/rapidsai/cuvs into …
tarang-jain Oct 1, 2024
450dcee
revert
tarang-jain Oct 1, 2024
b0f4b57
revert
tarang-jain Oct 1, 2024
8cd6c40
revert
tarang-jain Oct 1, 2024
d658856
update dependencies
tarang-jain Oct 2, 2024
d1e4101
Merge branch 'branch-24.10' of https://github.com/rapidsai/cuvs into …
tarang-jain Oct 2, 2024
2e080c6
update diff
tarang-jain Oct 3, 2024
8c6a178
Merge branch 'branch-24.10' of https://github.com/rapidsai/cuvs into …
tarang-jain Oct 3, 2024
626dc17
Merge branch 'branch-24.12' into diskann-wrapper
tarang-jain Oct 3, 2024
3d15882
style
tarang-jain Oct 3, 2024
a72165c
Merge branch 'branch-24.12' of https://github.com/rapidsai/cuvs into …
tarang-jain Oct 3, 2024
9c202f2
Merge branch 'diskann-wrapper' of https://github.com/tarang-jain/cuvs…
tarang-jain Oct 3, 2024
58a729c
diskann_memory working
tarang-jain Oct 8, 2024
2412b70
Merge branch 'branch-24.12' of https://github.com/rapidsai/cuvs into …
tarang-jain Oct 8, 2024
3a56402
ssd wrapper working, cuvs_vamana DOES NOT BUILD
tarang-jain Oct 15, 2024
0b39d5b
Merge branch 'branch-24.12' of https://github.com/rapidsai/cuvs into …
tarang-jain Oct 15, 2024
92ec474
builds now
tarang-jain Oct 15, 2024
333539f
Merge branch 'branch-24.12' of https://github.com/rapidsai/cuvs into …
tarang-jain Oct 15, 2024
a13bf1a
rm bug
tarang-jain Oct 15, 2024
df54939
add beam_width, rm dbg statements
tarang-jain Oct 15, 2024
93b2620
style
tarang-jain Oct 15, 2024
54385ab
updates after PR reviews, replace thread_pool with omp pragma
tarang-jain Oct 16, 2024
c35d899
codespell
tarang-jain Oct 16, 2024
31d846a
sync stream in vamana serialize
tarang-jain Oct 16, 2024
61e00c7
re-enable warnings
tarang-jain Oct 17, 2024
0bf43e9
Merge branch 'branch-24.12' of https://github.com/rapidsai/cuvs into …
tarang-jain Oct 17, 2024
0a6b094
deps in conda recipe
tarang-jain Oct 17, 2024
645d84b
host deps
tarang-jain Oct 17, 2024
d6897cc
Merge branch 'branch-24.12' of https://github.com/rapidsai/cuvs into …
tarang-jain Oct 17, 2024
396a589
Update conda/recipes/cuvs_bench/meta.yaml
tarang-jain Oct 17, 2024
d325698
Update conda/recipes/cuvs_bench_cpu/meta.yaml
tarang-jain Oct 17, 2024
03a1e09
arch dependendent diskann deps
tarang-jain Oct 18, 2024
f061f27
Merge branch 'branch-24.12' of https://github.com/rapidsai/cuvs into …
tarang-jain Oct 18, 2024
f95aec7
arch specific deps
tarang-jain Oct 18, 2024
b667786
CMAKE_SYSTEM_PROCESSOR check
tarang-jain Oct 18, 2024
4aa513f
update cmake flags
tarang-jain Oct 19, 2024
17f723e
CMAKE_SYSTEM_PROCESSOR regex
tarang-jain Oct 19, 2024
6532914
diskann build params
tarang-jain Oct 19, 2024
1f168a8
dbg
tarang-jain Oct 19, 2024
2a5d1fb
rename script to cuvs_vamana.cu
tarang-jain Oct 21, 2024
173df8f
Merge branch 'branch-24.12' into diskann-wrapper
tarang-jain Oct 21, 2024
46e7728
rm dbg statement
tarang-jain Oct 21, 2024
1b03cf7
Merge branch 'diskann-wrapper' of https://github.com/tarang-jain/cuvs…
tarang-jain Oct 21, 2024
c131c52
do not link cuvs for diskann only targets
tarang-jain Oct 21, 2024
3d40d2d
link aio
tarang-jain Oct 24, 2024
c3a25fc
Merge branch 'branch-24.12' of https://github.com/rapidsai/cuvs into …
tarang-jain Oct 24, 2024
6bebeb8
style
tarang-jain Oct 24, 2024
e254c9b
Merge branch 'branch-24.12' into diskann-wrapper
tarang-jain Oct 28, 2024
82e21e8
Merge branch 'branch-24.12' into diskann-wrapper
cjnolet Nov 4, 2024
de2bf84
Merge branch 'branch-24.12' of https://github.com/rapidsai/cuvs into …
tarang-jain Nov 13, 2024
6d3b32d
include utils.h
tarang-jain Nov 27, 2024
2506ef1
serialize dataset
tarang-jain Nov 28, 2024
16f35b3
debug
tarang-jain Nov 28, 2024
18f26a7
rm debug statements
tarang-jain Nov 29, 2024
663dfe0
style
tarang-jain Nov 29, 2024
4bb2b39
alloc strided dim on host
tarang-jain Nov 29, 2024
fd43711
size
tarang-jain Nov 29, 2024
30bcc6e
Merge branch 'diskann-wrapper' of https://github.com/tarang-jain/cuvs…
Dec 2, 2024
a10d834
include_dataset flag
Dec 2, 2024
4b396d7
Merge branch 'branch-24.12' into diskann-wrapper
tarang-jain Dec 2, 2024
c443944
Merge branch 'diskann-wrapper' of https://github.com/tarang-jain/cuvs…
Dec 2, 2024
ed8c9b6
Merge branch 'branch-24.12' into diskann-wrapper
tarang-jain Dec 3, 2024
f015264
rm cagra+diskann
Dec 4, 2024
6d6167d
Merge branch 'branch-24.12' into diskann-wrapper
tarang-jain Dec 4, 2024
9c2185b
Merge branch 'branch-24.12' into diskann-wrapper
tarang-jain Dec 5, 2024
2b758d3
Merge branch 'diskann-wrapper' of https://github.com/tarang-jain/cuvs…
tarang-jain Dec 10, 2024
b7ba35b
update copyright
tarang-jain Dec 10, 2024
63e02ff
Merge branch 'branch-25.02' into diskann-wrapper
tarang-jain Dec 11, 2024
48a6a9d
Merge branch 'branch-25.02' into diskann-wrapper
cjnolet Dec 12, 2024
fd429d2
style
tarang-jain Dec 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions conda/environments/bench_ann_cuda-118_arch-x86_64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ dependencies:
- gcc_linux-64=11.*
- glog>=0.6.0
- h5py>=3.8.0
- libaio
- libboost-devel
- libcublas-dev=11.11.3.6
- libcublas=11.11.3.6
- libcurand-dev=10.3.0.86
Expand All @@ -34,6 +36,7 @@ dependencies:
- libcusparse=11.7.5.86
- librmm==24.12.*,>=0.0.0a0
- matplotlib
- mkl-devel
- nccl>=2.19
- ninja
- nlohmann_json>=3.11.2
Expand Down
3 changes: 3 additions & 0 deletions conda/environments/bench_ann_cuda-125_arch-x86_64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,15 @@ dependencies:
- gcc_linux-64=11.*
- glog>=0.6.0
- h5py>=3.8.0
- libaio
- libboost-devel
- libcublas-dev
- libcurand-dev
- libcusolver-dev
- libcusparse-dev
- librmm==24.12.*,>=0.0.0a0
- matplotlib
- mkl-devel
- nccl>=2.19
- ninja
- nlohmann_json>=3.11.2
Expand Down
3 changes: 3 additions & 0 deletions conda/recipes/cuvs-bench-cpu/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ requirements:
- benchmark
- fmt {{ fmt_version }}
- glog {{ glog_version }}
- libaio
- libboost-devel
- mkl-devel # [linux64]
- nlohmann_json {{ nlohmann_json_version }}
- openblas
- python
Expand Down
3 changes: 3 additions & 0 deletions conda/recipes/cuvs-bench/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,10 @@ requirements:
- libcublas-dev
{% endif %}
- glog {{ glog_version }}
- libaio
- libboost-devel
- libcuvs {{ version }}
- mkl-devel # [linux64]
- nlohmann_json {{ nlohmann_json_version }}
- openblas
# rmm is needed to determine if package is gpu-enabled
Expand Down
25 changes: 25 additions & 0 deletions cpp/bench/ann/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ option(CUVS_ANN_BENCH_USE_CUVS_BRUTE_FORCE "Include cuVS brute force knn in benc
option(CUVS_ANN_BENCH_USE_CUVS_CAGRA_HNSWLIB "Include cuVS CAGRA with HNSW search in benchmark" ON)
option(CUVS_ANN_BENCH_USE_HNSWLIB "Include hnsw algorithm in benchmark" ON)
option(CUVS_ANN_BENCH_USE_GGNN "Include ggnn algorithm in benchmark" OFF)
option(CUVS_ANN_BENCH_USE_DISKANN "Include DISKANN search in benchmark" ON)
option(CUVS_ANN_BENCH_USE_CUVS_VAMANA "Include cuVS Vamana with DiskANN search in benchmark" ON)
if(CMAKE_SYSTEM_PROCESSOR MATCHES "(ARM|arm|aarch64)")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does MSFT DiskANN repo not support ARM?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, they have mkl-devel as a dependency, which is not meant to be installed in aarch64.

set(CUVS_ANN_BENCH_USE_DISKANN OFF)
set(CUVS_ANN_BENCH_USE_CUVS_VAMANA OFF)
endif()
option(CUVS_ANN_BENCH_USE_CUVS_MG "Include cuVS ann mg algorithm in benchmark" ${BUILD_MG_ALGOS})
option(CUVS_ANN_BENCH_SINGLE_EXE
"Make a single executable with benchmark as shared library modules" OFF
Expand All @@ -57,6 +63,7 @@ if(BUILD_CPU_ONLY)
set(CUVS_ANN_BENCH_USE_GGNN OFF)
set(CUVS_KNN_BENCH_USE_CUVS_BRUTE_FORCE OFF)
set(CUVS_ANN_BENCH_USE_CUVS_MG OFF)
set(CUVS_ANN_BENCH_USE_CUVS_VAMANA OFF)
else()
set(CUVS_FAISS_ENABLE_GPU ON)
endif()
Expand All @@ -69,6 +76,7 @@ if(CUVS_ANN_BENCH_USE_CUVS_IVF_PQ
OR CUVS_ANN_BENCH_USE_CUVS_CAGRA_HNSWLIB
OR CUVS_KNN_BENCH_USE_CUVS_BRUTE_FORCE
OR CUVS_ANN_BENCH_USE_CUVS_MG
OR CUVS_ANN_BENCH_USE_CUVS_VAMANA
)
set(CUVS_ANN_BENCH_USE_CUVS ON)
endif()
Expand All @@ -90,6 +98,10 @@ if(CUVS_ANN_BENCH_USE_FAISS)
include(cmake/thirdparty/get_faiss)
endif()

if(CUVS_ANN_BENCH_USE_DISKANN OR CUVS_ANN_BENCH_USE_CUVS_VAMANA)
include(cmake/thirdparty/get_diskann)
endif()

# ##################################################################################################
# * Target function -------------------------------------------------------------

Expand Down Expand Up @@ -288,6 +300,19 @@ if(CUVS_ANN_BENCH_USE_GGNN)
)
endif()

if(CUVS_ANN_BENCH_USE_DISKANN)
ConfigureAnnBench(
NAME DISKANN_MEMORY PATH src/diskann/diskann_benchmark.cpp LINKS diskann::diskann aio
)
ConfigureAnnBench(
NAME DISKANN_SSD PATH src/diskann/diskann_benchmark.cpp LINKS diskann::diskann aio
)
endif()

if(CUVS_ANN_BENCH_USE_CUVS_VAMANA)
ConfigureAnnBench(NAME CUVS_VAMANA PATH src/cuvs/cuvs_vamana.cu LINKS cuvs diskann::diskann aio)
endif()

# ##################################################################################################
# * Dynamically-loading ANN_BENCH executable -------------------------------------------------------
if(CUVS_ANN_BENCH_SINGLE_EXE)
Expand Down
16 changes: 14 additions & 2 deletions cpp/bench/ann/src/common/benchmark.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,12 @@ void bench_build(::benchmark::State& state,
}
}

if (index.algo == "diskann_ssd") {
make_sure_parent_dir_exists(index.file);
index.build_param["dataset_file"] = dataset->base_filename();
index.build_param["path_to_index"] = index.file;
}

std::unique_ptr<algo<T>> algo;
try {
algo = create_algo<T>(index.algo, dataset->distance(), dataset->dim(), index.build_param);
Expand All @@ -144,7 +150,8 @@ void bench_build(::benchmark::State& state,

const auto algo_property = parse_algo_property(algo->get_preference(), index.build_param);

const T* base_set = dataset->base_set(algo_property.dataset_memory_type);
const T* base_set = nullptr;
if (index.algo != "diskann_ssd") base_set = dataset->base_set(algo_property.dataset_memory_type);
std::size_t index_size = dataset->base_set_size();

cuda_timer gpu_timer{algo};
Expand Down Expand Up @@ -223,7 +230,12 @@ void bench_search(::benchmark::State& state,

const T* query_set = nullptr;

if (!file_exists(index.file)) {
std::string filename;
if (index.algo != "diskann_ssd")
filename = index.file;
else
filename = index.file + "_disk.index";
if (!file_exists(filename)) {
state.SkipWithError("Index file is missing. Run the benchmark in the build mode first.");
return;
}
Expand Down
18 changes: 14 additions & 4 deletions cpp/bench/ann/src/common/dataset.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ class bin_file {
}
}

std::string file() const { return file_; }

private:
void check_suffix();
void open_file() const;
Expand Down Expand Up @@ -253,10 +255,11 @@ class dataset {

auto name() const -> std::string { return name_; }
auto distance() const -> std::string { return distance_; }
virtual auto dim() const -> int = 0;
virtual auto max_k() const -> uint32_t = 0;
virtual auto base_set_size() const -> size_t = 0;
virtual auto query_set_size() const -> size_t = 0;
virtual auto dim() const -> int = 0;
virtual auto max_k() const -> uint32_t = 0;
virtual auto base_set_size() const -> size_t = 0;
virtual auto query_set_size() const -> size_t = 0;
virtual auto base_filename() const -> std::string = 0;

// load data lazily, so don't pay the overhead of reading unneeded set
// e.g. don't load base set when searching
Expand Down Expand Up @@ -424,6 +427,7 @@ class bin_dataset : public dataset<T> {
auto max_k() const -> uint32_t override;
auto base_set_size() const -> size_t override;
auto query_set_size() const -> size_t override;
std::string base_filename() const override;

private:
void load_base_set() const;
Expand Down Expand Up @@ -541,4 +545,10 @@ void bin_dataset<T>::map_base_set() const
this->mapped_base_set_ = base_file_.map();
}

template <typename T>
std::string bin_dataset<T>::base_filename() const
{
return base_file_.file();
}

} // namespace cuvs::bench
108 changes: 108 additions & 0 deletions cpp/bench/ann/src/cuvs/cuvs_cagra_diskann_wrapper.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/*
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include "../hnswlib/hnswlib_wrapper.h"
#include "cuvs_cagra_wrapper.h"

#include <memory>

namespace cuvs::bench {

template <typename T, typename IdxT>
class cuvs_cagra_diskann : public algo<T>, public algo_gpu {
public:
using search_param_base = typename algo<T>::search_param;
using build_param = typename cuvs_cagra<T, IdxT>::build_param;
using search_param = typename diskann_mem<T>::search_param;

cuvs_cagra_diskann(Metric metric, int dim, const build_param& param)
: algo<T>(metric, dim),
cagra_build_{metric, dim, param},
// hnsw_lib param values don't matter since we don't build with hnsw_lib
diskann_mem_search_{metric, dim, typename diskann_mem<T>::build_param{50, 100}}
{
}

void build(const T* dataset, size_t nrow) final;

void set_search_param(const search_param_base& param) override;

void search(const T* queries,
int batch_size,
int k,
algo_base::index_type* neighbors,
float* distances) const override;

[[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override
{
return cagra_build_.get_sync_stream();
}

// to enable dataset access from GPU memory
[[nodiscard]] auto get_preference() const -> algo_property override
{
algo_property property;
property.dataset_memory_type = MemoryType::kHostMmap;
property.query_memory_type = MemoryType::kHost;
return property;
}

void save(const std::string& file) const override;
void load(const std::string&) override;
std::unique_ptr<algo<T>> copy() override
{
return std::make_unique<cuvs_cagra_hnswlib<T, IdxT>>(*this);
}

private:
cuvs_cagra<T, IdxT> cagra_build_;
hnsw_lib<T> hnswlib_search_;
};

template <typename T, typename IdxT>
void cuvs_cagra_hnswlib<T, IdxT>::build(const T* dataset, size_t nrow)
{
cagra_build_.build(dataset, nrow);
}

template <typename T, typename IdxT>
void cuvs_cagra_hnswlib<T, IdxT>::set_search_param(const search_param_base& param_)
{
hnswlib_search_.set_search_param(param_);
}

template <typename T, typename IdxT>
void cuvs_cagra_hnswlib<T, IdxT>::save(const std::string& file) const
tarang-jain marked this conversation as resolved.
Show resolved Hide resolved
{
cagra_build_.save_to_hnswlib(file);
}

template <typename T, typename IdxT>
void cuvs_cagra_hnswlib<T, IdxT>::load(const std::string& file)
{
hnswlib_search_.load(file);
hnswlib_search_.set_base_layer_only();
}

template <typename T, typename IdxT>
void cuvs_cagra_hnswlib<T, IdxT>::search(
const T* queries, int batch_size, int k, algo_base::index_type* neighbors, float* distances) const
{
hnswlib_search_.search(queries, batch_size, k, neighbors, distances);
}

} // namespace cuvs::bench
93 changes: 93 additions & 0 deletions cpp/bench/ann/src/cuvs/cuvs_vamana.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
/*
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "../common/ann_types.hpp"
#include "cuvs_vamana_wrapper.h"

#include <rmm/cuda_device.hpp>
#include <rmm/mr/device/pool_memory_resource.hpp>
#include <rmm/resource_ref.hpp>

namespace cuvs::bench {

template <typename T, typename IdxT>
void parse_build_param(const nlohmann::json& conf,
typename cuvs::bench::cuvs_vamana<T, IdxT>::build_param& param)
{
if (conf.contains("graph_degree")) { param.graph_degree = conf.at("graph_degree"); }
if (conf.contains("visited_size")) { param.visited_size = conf.at("visited_size"); }
if (conf.contains("alpha")) { param.alpha = conf.at("alpha"); }
}

template <typename T, typename IdxT>
void parse_search_param(const nlohmann::json& conf,
typename cuvs::bench::cuvs_vamana<T, IdxT>::search_param& param)
{
if (conf.contains("L_search")) { param.L_search = conf.at("L_search"); }
if (conf.contains("num_threads")) { param.num_threads = conf.at("num_threads"); }
}

template <typename T>
auto create_algo(const std::string& algo_name,
const std::string& distance,
int dim,
const nlohmann::json& conf) -> std::unique_ptr<cuvs::bench::algo<T>>
{
[[maybe_unused]] cuvs::bench::Metric metric = parse_metric(distance);
std::unique_ptr<cuvs::bench::algo<T>> a;

if constexpr (std::is_same_v<T, float> or std::is_same_v<T, std::uint8_t>) {
if (algo_name == "cuvs_vamana") {
typename cuvs::bench::cuvs_vamana<T, uint32_t>::build_param param;
parse_build_param<T, uint32_t>(conf, param);
a = std::make_unique<cuvs::bench::cuvs_vamana<T, uint32_t>>(metric, dim, param);
}
}

if (!a) { throw std::runtime_error("invalid algo: '" + algo_name + "'"); }

return a;
}

template <typename T>
auto create_search_param(const std::string& algo_name, const nlohmann::json& conf)
-> std::unique_ptr<typename cuvs::bench::algo<T>::search_param>
{
if (algo_name == "cuvs_vamana") {
auto param = std::make_unique<typename cuvs::bench::cuvs_vamana<T, uint32_t>::search_param>();
parse_search_param<T, uint32_t>(conf, *param);
return param;
}

throw std::runtime_error("invalid algo: '" + algo_name + "'");
}

} // namespace cuvs::bench

REGISTER_ALGO_INSTANCE(float);

#ifdef ANN_BENCH_BUILD_MAIN
#include "../common/benchmark.hpp"
/*
[NOTE] Dear developer,

Please don't modify the content of the `main` function; this will make the behavior of the benchmark
executable differ depending on the cmake flags and will complicate the debugging. In particular,
don't try to setup an RMM memory resource here; it will anyway be modified by the memory resource
set on per-algorithm basis. For example, see `cuvs/cuvs_ann_bench_utils.h`.
*/
int main(int argc, char** argv) { return cuvs::bench::run_main(argc, argv); }
#endif
Loading
Loading