Skip to content

Commit

Permalink
Learn heuristic to pick fastest select_k algorithm (#1523)
Browse files Browse the repository at this point in the history
This uses the select_k dataset from #1497 to learn a heuristic for picking the fastest select_k variant based on the rows/cols/k of the input. The heuristic is modelled as a DecisionTree, which is automatically exported as C++ code that is compiled into RAFT. This lets us learn a function that picks the fastest select_k method while requiring only a few if statements in C++ code, making it very cheap to evaluate.

Authors:
  - Ben Frederickson (https://github.com/benfred)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #1523
  • Loading branch information
benfred authored May 17, 2023
1 parent 5392a91 commit 618dc23
Show file tree
Hide file tree
Showing 7 changed files with 1,620 additions and 11 deletions.
14 changes: 12 additions & 2 deletions cpp/bench/prims/matrix/select_k.cu
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ using namespace raft::bench; // NOLINT
template <typename KeyT, typename IdxT, select::Algo Algo>
struct selection : public fixture {
explicit selection(const select::params& p)
: fixture(true),
: fixture(p.use_memory_pool),
params_(p),
in_dists_(p.batch_size * p.len, stream),
in_ids_(p.batch_size * p.len, stream),
Expand Down Expand Up @@ -193,7 +193,8 @@ SELECTION_REGISTER(double, int64_t, kWarpDistributedShm); // NOLINT
using SelectK = selection<KeyT, IdxT, select::Algo::A>; \
std::stringstream name; \
name << "SelectKDataset/" << #KeyT "/" #IdxT "/" #A << "/" << input.batch_size << "/" \
<< input.len << "/" << input.k << "/" << input.use_index_input; \
<< input.len << "/" << input.k << "/" << input.use_index_input << "/" \
<< input.use_memory_pool; \
auto* b = ::benchmark::internal::RegisterBenchmarkInternal( \
new raft::bench::internal::Fixture<SelectK, select::params>(name.str(), input)); \
b->UseManualTime(); \
Expand Down Expand Up @@ -266,5 +267,14 @@ void add_select_k_dataset_benchmarks()
SELECTION_REGISTER_INPUT(float, int64_t, input);
SELECTION_REGISTER_INPUT(float, uint32_t, input);
}

// also try again without a memory pool to see if there are significant differences
for (auto input : inputs) {
input.use_memory_pool = false;
SELECTION_REGISTER_INPUT(double, int64_t, input);
SELECTION_REGISTER_INPUT(double, uint32_t, input);
SELECTION_REGISTER_INPUT(float, int64_t, input);
SELECTION_REGISTER_INPUT(float, uint32_t, input);
}
}
} // namespace raft::matrix
120 changes: 111 additions & 9 deletions cpp/include/raft/matrix/detail/select_k-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,101 @@

#include <raft/core/nvtx.hpp>

#include <raft/neighbors/detail/selection_faiss.cuh>
#include <rmm/cuda_stream_view.hpp>
#include <rmm/mr/device/device_memory_resource.hpp>

namespace raft::matrix::detail {

// The select_k implementations the heuristic below can dispatch to. This is
// only a subset of the available algorithms; it was picked by running the
// algorithm_selection notebook in cpp/scripts/heuristics/select_k
enum class Algo { kRadix11bits, kWarpDistributedShm, kFaissBlockSelect };

/**
 * Pick the select_k algorithm predicted to be fastest for a given problem size.
 *
 * The thresholds used here come from a DecisionTree that was trained on
 * thousands of timed trial runs over different rows/cols/k combinations and
 * then converted to c++ code. The branches are written as guard clauses,
 * ordered from the largest values of k down to the smallest.
 *
 * NOTE: The generator lives in cpp/scripts/heuristics/select_k; running the
 * 'generate_heuristic' notebook there replaces the body of this function
 * with the latest learned heuristic.
 *
 * @param rows number of rows in the input
 * @param cols number of columns in the input
 * @param k    number of elements to select
 * @return the Algo value predicted by the learned heuristic
 */
inline Algo choose_select_k_algorithm(size_t rows, size_t cols, int k)
{
  // k > 256: radix everywhere except one mid-k pocket with many rows and wide columns
  if (k > 256) {
    const bool prefer_faiss = (k <= 809) && (rows > 124) && (cols > 63488);
    return prefer_faiss ? Algo::kFaissBlockSelect : Algo::kRadix11bits;
  }

  // 134 < k <= 256: the column count alone decides
  if (k > 134) { return (cols > 678736) ? Algo::kWarpDistributedShm : Algo::kRadix11bits; }

  // From here on k <= 134.

  // Narrow inputs
  if (cols <= 13776) {
    if (k <= 1) { return Algo::kFaissBlockSelect; }
    const bool prefer_radix = (rows <= 188) && (k > 72);
    return prefer_radix ? Algo::kRadix11bits : Algo::kWarpDistributedShm;
  }

  // Wide inputs (cols > 13776) with many rows
  if (rows > 335) {
    if (k <= 1) { return Algo::kFaissBlockSelect; }
    const bool prefer_faiss = (rows <= 546) && (k <= 17);
    return prefer_faiss ? Algo::kFaissBlockSelect : Algo::kWarpDistributedShm;
  }

  // Wide inputs (cols > 13776) with few rows (rows <= 335)
  const bool prefer_radix = (k > 44) && (cols <= 1031051) && (rows <= 22);
  return prefer_radix ? Algo::kRadix11bits : Algo::kWarpDistributedShm;
}

/**
* Select k smallest or largest key/values from each row in the input data.
*
Expand Down Expand Up @@ -77,15 +167,27 @@ void select_k(const T* in_val,
{
common::nvtx::range<common::nvtx::domain::raft> fun_scope(
"matrix::select_k(batch_size = %zu, len = %zu, k = %d)", batch_size, len, k);
// TODO (achirkin): investigate the trade-off for a wider variety of inputs.
const bool radix_faster = batch_size >= 64 && len >= 102400 && k >= 128;
if (k <= select::warpsort::kMaxCapacity && !radix_faster) {
select::warpsort::select_k<T, IdxT>(
in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr);
} else {
select::radix::select_k<T, IdxT, (sizeof(T) >= 4 ? 11 : 8), 512>(
in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, true, stream, mr);

auto algo = choose_select_k_algorithm(batch_size, len, k);
switch (algo) {
case Algo::kRadix11bits:
return detail::select::radix::select_k<T, IdxT, 11, 512>(in_val,
in_idx,
batch_size,
len,
k,
out_val,
out_idx,
select_min,
true, // fused_last_filter
stream);
case Algo::kWarpDistributedShm:
return detail::select::warpsort::
select_k_impl<T, IdxT, detail::select::warpsort::warp_sort_distributed_ext>(
in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream);
case Algo::kFaissBlockSelect:
return neighbors::detail::select_k(
in_val, in_idx, batch_size, len, out_val, out_idx, select_min, k, stream);
}
}

} // namespace raft::matrix::detail
1 change: 1 addition & 0 deletions cpp/internal/raft_internal/matrix/select_k.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ struct params {
bool select_min;
bool use_index_input = true;
bool use_same_leading_bits = false;
bool use_memory_pool = true;
};

inline auto operator<<(std::ostream& os, const params& ss) -> std::ostream&
Expand Down
Loading

0 comments on commit 618dc23

Please sign in to comment.