Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add: Add index_gt::merge() #572

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
246 changes: 186 additions & 60 deletions cpp/bench.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ struct running_stats_printer_t {
timestamp_t time = std::chrono::high_resolution_clock::now();
std::size_t duration = std::chrono::duration_cast<std::chrono::nanoseconds>(time - start_time).count();
float vectors_per_second = count * 1e9 / duration;
std::printf("\r\33[2K100 %% completed, %.0f vectors/s\n", vectors_per_second);
std::printf("\r\33[2K100 %% completed, %.0f vectors/s, %.1f s\n", vectors_per_second, duration / 1e9);
}

void refresh(std::size_t step = 1024 * 32) {
Expand Down Expand Up @@ -286,8 +286,9 @@ struct running_stats_printer_t {
}
};

template <typename index_at, typename vector_id_at, typename scalar_at>
void index_many(index_at& index, std::size_t n, vector_id_at const* ids, scalar_at const* vectors, std::size_t dims) {
template <typename index_at, typename vector_id_at, typename scalar_at, typename add_at>
void index_many(index_at& index, std::size_t n, vector_id_at const* ids, scalar_at const* vectors, std::size_t dims,
add_at&& add) {

running_stats_printer_t printer{n, "Indexing"};

Expand All @@ -300,7 +301,7 @@ void index_many(index_at& index, std::size_t n, vector_id_at const* ids, scalar_
config.thread = omp_get_thread_num();
#endif
span_gt<scalar_at const> vector{vectors + dims * i, dims};
index.add(ids[i], vector, config.thread);
add(index, ids[i], vector, config);
printer.progress++;
if (config.thread == 0)
printer.refresh();
Expand All @@ -310,10 +311,10 @@ void index_many(index_at& index, std::size_t n, vector_id_at const* ids, scalar_
printer.print();
}

template <typename index_at, typename vector_id_at, typename scalar_at, typename distance_at>
template <typename index_at, typename vector_id_at, typename scalar_at, typename distance_at, typename search_at>
void search_many( //
index_at& index, std::size_t n, scalar_at const* vectors, std::size_t dims, std::size_t wanted, vector_id_at* ids,
distance_at* distances) {
distance_at* distances, search_at&& search) {

std::string name = "Search " + std::to_string(wanted);
running_stats_printer_t printer{n, name.c_str()};
Expand All @@ -327,7 +328,8 @@ void search_many( //
config.thread = omp_get_thread_num();
#endif
span_gt<scalar_at const> vector{vectors + dims * i, dims};
index.search(vector, wanted, config.thread).dump_to(ids + wanted * i, distances + wanted * i);
typename index_at::search_result_t search_result = search(index, vector, wanted, config);
search_result.dump_to(ids + wanted * i, distances + wanted * i);
printer.progress++;
if (config.thread == 0)
printer.refresh();
Expand All @@ -337,8 +339,8 @@ void search_many( //
printer.print();
}

template <typename dataset_at, typename index_at> //
static void single_shot(dataset_at& dataset, index_at& index, bool construct = true) {
template <typename dataset_at, typename index_at, typename add_at, typename search_at> //
static void single_shot(dataset_at& dataset, index_at& index, bool construct, add_at&& add, search_at&& search) {
using distance_t = typename index_at::distance_t;
constexpr default_key_t missing_key = std::numeric_limits<default_key_t>::max();

Expand All @@ -348,14 +350,14 @@ static void single_shot(dataset_at& dataset, index_at& index, bool construct = t
// Perform insertions, evaluate speed
std::vector<default_key_t> ids(dataset.vectors_count());
std::iota(ids.begin(), ids.end(), 0);
index_many(index, dataset.vectors_count(), ids.data(), dataset.vector(0), dataset.dimensions());
index_many(index, dataset.vectors_count(), ids.data(), dataset.vector(0), dataset.dimensions(), add);
}

// Perform search, evaluate speed
std::vector<default_key_t> found_neighbors(dataset.queries_count() * dataset.neighborhood_size());
std::vector<distance_t> found_distances(dataset.queries_count() * dataset.neighborhood_size());
search_many(index, dataset.queries_count(), dataset.query(0), dataset.dimensions(), dataset.neighborhood_size(),
found_neighbors.data(), found_distances.data());
found_neighbors.data(), found_distances.data(), search);

// Evaluate search quality
std::size_t recall_at_1 = 0, recall_full = 0;
Expand All @@ -369,43 +371,45 @@ static void single_shot(dataset_at& dataset, index_at& index, bool construct = t
std::printf("Recall@1 %.2f %%\n", recall_at_1 * 100.f / dataset.queries_count());
std::printf("Recall %.2f %%\n", recall_full * 100.f / dataset.queries_count());

// Perform joins
std::vector<default_key_t> man_to_woman(dataset.vectors_count());
std::vector<default_key_t> woman_to_man(dataset.vectors_count());
std::size_t join_attempts = 0;
{
index_at& men = index;
index_at women = index.copy();
std::fill(man_to_woman.begin(), man_to_woman.end(), missing_key);
std::fill(woman_to_man.begin(), woman_to_man.end(), missing_key);
if constexpr (!std::is_same_v<index_at, index_gt<>>) {
// Perform joins
std::vector<default_key_t> man_to_woman(dataset.vectors_count());
std::vector<default_key_t> woman_to_man(dataset.vectors_count());
std::size_t join_attempts = 0;
{
executor_default_t executor(index.limits().threads());
running_stats_printer_t printer{1, "Join"};
join_result_t result = join( //
men, women, index_join_config_t{executor.size()}, //
man_to_woman.data(), woman_to_man.data(), //
executor, [&](std::size_t progress, std::size_t total) {
if (progress % 1000 == 0)
printer.print(progress, total);
return true;
});
// Refresh once again to show 100% completion
printer.print();
join_attempts = result.visited_members;
index_at& men = index;
index_at women = index.copy();
std::fill(man_to_woman.begin(), man_to_woman.end(), missing_key);
std::fill(woman_to_man.begin(), woman_to_man.end(), missing_key);
{
executor_default_t executor(index.limits().threads());
running_stats_printer_t printer{1, "Join"};
join_result_t result = join( //
men, women, index_join_config_t{executor.size()}, //
man_to_woman.data(), woman_to_man.data(), //
executor, [&](std::size_t progress, std::size_t total) {
if (progress % 1000 == 0)
printer.print(progress, total);
return true;
});
// Refresh once again to show 100% completion
printer.print();
join_attempts = result.visited_members;
}
}
}
// Evaluate join quality
std::size_t recall_join = 0, unmatched_count = 0;
for (std::size_t i = 0; i != index.size(); ++i) {
recall_join += man_to_woman[i] == static_cast<default_key_t>(i);
unmatched_count += man_to_woman[i] == missing_key;
}
std::printf("Recall Joins %.2f %%\n", recall_join * 100.f / index.size());
std::printf("Unmatched %.2f %% (%zu items)\n", unmatched_count * 100.f / index.size(), unmatched_count);
std::printf("Proposals %.2f / man (%zu total)\n", join_attempts * 1.f / index.size(), join_attempts);
// Evaluate join quality
std::size_t recall_join = 0, unmatched_count = 0;
for (std::size_t i = 0; i != index.size(); ++i) {
recall_join += man_to_woman[i] == static_cast<default_key_t>(i);
unmatched_count += man_to_woman[i] == missing_key;
}
std::printf("Recall Joins %.2f %%\n", recall_join * 100.f / index.size());
std::printf("Unmatched %.2f %% (%zu items)\n", unmatched_count * 100.f / index.size(), unmatched_count);
std::printf("Proposals %.2f / man (%zu total)\n", join_attempts * 1.f / index.size(), join_attempts);

std::printf("------------\n");
std::printf("\n");
std::printf("------------\n");
std::printf("\n");
}
}

void handler(int sig) {
Expand Down Expand Up @@ -468,6 +472,9 @@ struct args_t {

bool big = false;

bool index_gt_api = false;
std::size_t chunks_to_merge = 0;

bool quantize_bf16 = false;
bool quantize_f16 = false;
bool quantize_i8 = false;
Expand Down Expand Up @@ -516,6 +523,8 @@ struct args_t {
template <typename index_at, typename dataset_at> //
void run_punned(dataset_at& dataset, args_t const& args, index_config_t config, index_limits_t limits) {

using scalar_t = typename dataset_at::scalar_t;

scalar_kind_t quantization = args.quantization();
std::printf("-- Quantization: %s\n", scalar_kind_name(quantization));

Expand All @@ -528,31 +537,141 @@ void run_punned(dataset_at& dataset, args_t const& args, index_config_t config,
std::printf("-- Hardware acceleration: %s\n", index.metric().isa_name());
std::printf("Will benchmark in-memory\n");

single_shot(dataset, index, true);
auto add = [&](index_at& index, default_key_t id, span_gt<scalar_t const> vector, index_update_config_t config) {
index.add(id, vector, config.thread);
};
auto search = [&](index_at& index, span_gt<scalar_t const> vector, std::size_t wanted,
index_search_config_t config) { return index.search(vector, wanted, config.thread); };

single_shot(dataset, index, true, add, search);
index.save(args.path_output.c_str());

std::printf("Will benchmark an on-disk view\n");

index_at index_view = index.fork();
index_view.view(args.path_output.c_str());
single_shot(dataset, index_view, false);
single_shot(dataset, index_view, false, add, search);
}

template <typename index_at, typename dataset_at> //
void run_typed(dataset_at& dataset, args_t const& args, index_config_t config, index_limits_t limits) {
using distance_t = typename index_at::distance_t;
using member_ref_t = typename index_at::member_ref_t;
using member_cref_t = typename index_at::member_cref_t;
using member_citerator_t = typename index_at::member_citerator_t;

using scalar_t = typename dataset_at::scalar_t;

scalar_kind_t quantization = args.quantization();
std::printf("-- Quantization: %s\n", scalar_kind_name(quantization));

metric_kind_t kind = args.metric();
std::printf("-- Metric: %s\n", metric_kind_name(kind));

metric_punned_t metric_punned(dataset.dimensions(), kind, quantization);
buffer_gt<byte_t const*> values(limits.members);

std::printf("-- Hardware acceleration: %s\n", metric_punned.isa_name());

// Distance functor handed to the `index.add`, `index.search` and `index.merge` calls below.
// It pairs a `metric_punned_t` with a table of raw vector pointers: index members are
// resolved to their vector bytes via `values_[get_slot(member)]` before the metric is called.
class metric_t {
metric_punned_t metric_;
// Held by reference, not owned: the buffer passed to the constructor must outlive this
// object. Entries are written elsewhere (by the `on_success` callbacks), not by this class.
buffer_gt<byte_t const*>& values_;

public:
// Copies `metric` and binds `values` by reference.
metric_t(metric_punned_t metric, buffer_gt<byte_t const*>& values) noexcept
: metric_(metric), values_(values) {}

// Raw vector vs. member reference, and member reference vs. member reference.
inline distance_t operator()(byte_t const* a, member_cref_t b) const noexcept { return f(a, v(b)); }
inline distance_t operator()(member_cref_t a, member_cref_t b) const noexcept { return f(v(a), v(b)); }

// Same two combinations, with members addressed through const iterators.
inline distance_t operator()(byte_t const* a, member_citerator_t b) const noexcept { return f(a, v(b)); }
inline distance_t operator()(member_citerator_t a, member_citerator_t b) const noexcept {
return f(v(a), v(b));
}

// Two raw vectors: forwarded to the wrapped metric without any lookup.
inline distance_t operator()(byte_t const* a, byte_t const* b) const noexcept { return f(a, b); }

// `v` maps a member (reference or iterator) to the vector pointer stored for its slot.
inline byte_t const* v(member_cref_t m) const noexcept { return values_[get_slot(m)]; }
inline byte_t const* v(member_citerator_t m) const noexcept { return values_[get_slot(m)]; }
// `f` evaluates the wrapped punned metric on two raw vectors.
inline distance_t f(byte_t const* a, byte_t const* b) const noexcept { return metric_(a, b); }
};
metric_t metric{metric_punned, values};

auto add = [&](index_at& index, default_key_t id, span_gt<scalar_t const> vector, index_update_config_t& config) {
byte_t const* vector_data = reinterpret_cast<byte_t const*>(vector.data());
auto on_success = [&](member_ref_t member) { values[member.slot] = vector_data; };
index.add(id, vector_data, metric, config, on_success);
};
auto search = [&](index_at& index, span_gt<scalar_t const> vector, std::size_t wanted,
index_search_config_t config) {
byte_t const* vector_data = reinterpret_cast<byte_t const*>(vector.data());
return index.search(vector_data, wanted, metric, config);
};

index_at index(config);
index.reserve(limits);
std::printf("Will benchmark in-memory\n");
if (args.chunks_to_merge > 0) {
index.save(args.path_output.c_str());
memory_mapped_file_t output{args.path_output.c_str(), true};
index.load(std::move(output));
index.reserve(limits);
std::printf("Will benchmark merge: %zu\n", args.chunks_to_merge);

single_shot(dataset, index, true);
index.save(args.path_output.c_str());
std::printf("\n");
std::printf("------------\n");

std::printf("Will benchmark an on-disk view\n");
{
    // Build `args.chunks_to_merge` independent sub-indexes over consecutive slices of the
    // dataset, then fold each of them into `index` with `index.merge()`.
    std::vector<default_key_t> ids(dataset.vectors_count());
    std::iota(ids.begin(), ids.end(), 0);

    std::vector<index_at> subindexes;
    std::vector<buffer_gt<byte_t const*>> subvalues;
    // Reserve up-front so that appending a chunk never relocates the ones already built.
    subindexes.reserve(args.chunks_to_merge);
    subvalues.reserve(args.chunks_to_merge);

    std::size_t offset = 0;
    std::size_t const chunk_rows = dataset.vectors_count() / args.chunks_to_merge;
    for (std::size_t i = 0; i != args.chunks_to_merge; ++i) {
        subindexes.emplace_back(config);
        index_at& subindex = subindexes[i];

        // The division above truncates, so the last chunk also takes the remainder.
        // Without this, the trailing `vectors_count() % chunks_to_merge` vectors would
        // never be indexed, while recall is still evaluated against the full dataset.
        bool const is_last_chunk = i + 1 == args.chunks_to_merge;
        std::size_t const n = is_last_chunk ? dataset.vectors_count() - offset : chunk_rows;

        subindex.reserve(n);
        subvalues.emplace_back(n);
        buffer_gt<byte_t const*>& subvs = subvalues[i];
        // The chunk-local metric resolves members through this chunk's own slot table.
        metric_t submetric{metric_punned, subvs};
        auto subadd = [&](index_at& target, default_key_t id, span_gt<scalar_t const> vector,
                          index_update_config_t& update_config) {
            byte_t const* vector_data = reinterpret_cast<byte_t const*>(vector.data());
            // Remember where the vector lives, keyed by the slot the sub-index assigned to it.
            auto on_success = [&](member_ref_t member) { subvs[member.slot] = vector_data; };
            target.add(id, vector_data, submetric, update_config, on_success);
        };
        index_many(subindex, n, ids.data() + offset, dataset.vector(offset), dataset.dimensions(), subadd);
        offset += n;
    }

    {
        running_stats_printer_t printer{dataset.vectors_count(), "Merging"};
        // Record the vector pointer under the slot assigned by the merged `index`.
        auto merge_on_success = [&](member_ref_t member, byte_t const* value) { values[member.slot] = value; };
        for (std::size_t i = 0; i != args.chunks_to_merge; ++i) {
            buffer_gt<byte_t const*>& subvs = subvalues[i];
            // Looks up a sub-index member's vector in that chunk's slot table.
            auto get_value = [&](member_cref_t member) { return subvs[member.slot]; };
            index.merge(subindexes[i], get_value, metric, {}, merge_on_success);
            printer.progress += subindexes[i].size();
            printer.refresh();
        }
        printer.print();
    }
}
// Perform searches, evaluate speed
single_shot(dataset, index, false, add, search);
} else {
index.reserve(limits);
std::printf("Will benchmark in-memory\n");

index_at index_view = index.fork();
index_view.view(args.path_output.c_str());
single_shot(dataset, index_view, false);
single_shot(dataset, index, true, add, search);
index.save(args.path_output.c_str());

std::printf("Will benchmark an on-disk view\n");

index_at index_view = index.fork();
index_view.view(args.path_output.c_str());
single_shot(dataset, index_view, false, add, search);
}
}

template <typename dataset_scalar_at> void bench_with_args(args_t const& args) {
Expand Down Expand Up @@ -583,14 +702,18 @@ template <typename dataset_scalar_at> void bench_with_args(args_t const& args) {
std::printf("-- Expansion @ Add: %zu\n", config.expansion_add);
std::printf("-- Expansion @ Search: %zu\n", config.expansion_search);

if (args.big)
if (args.index_gt_api) {
run_typed<index_gt<>>(dataset, args, config, limits);
} else {
if (args.big)
#ifdef USEARCH_64BIT_ENV
run_punned<index_dense_gt<default_key_t, uint40_t>>(dataset, args, config, limits);
run_punned<index_dense_gt<default_key_t, uint40_t>>(dataset, args, config, limits);
#else
std::printf("Error: Don't use 40 bit identifiers in 32bit environment\n");
std::printf("Error: Don't use 40 bit identifiers in 32bit environment\n");
#endif
else
run_punned<index_dense_gt<default_key_t, std::uint32_t>>(dataset, args, config, limits);
else
run_punned<index_dense_gt<default_key_t, std::uint32_t>>(dataset, args, config, limits);
}
}

int main(int argc, char** argv) {
Expand All @@ -613,6 +736,9 @@ int main(int argc, char** argv) {
(option("--expansion-search") & value("integer", args.expansion_search)).doc("Affects search depth"),
(option("--rows-skip") & value("integer", args.vectors_to_skip)).doc("Number of vectors to skip"),
(option("--rows-take") & value("integer", args.vectors_to_take)).doc("Number of vectors to take"),
option("--index-gt-api").set(args.index_gt_api).doc("Use index_gt<> API not index_dense_gt API"),
(option("--chunks-to-merge") & value("integer", args.chunks_to_merge))
.doc("Number of chunks to merge. This requires --index-gt-api"),
( //
option("-bf16", "--bf16quant").set(args.quantize_bf16).doc("Enable `bf16_t` quantization") |
option("-f16", "--f16quant").set(args.quantize_f16).doc("Enable `f16_t` quantization") |
Expand Down
Loading