Reduce compile times of distance specializations (#1307)
Following the findings in https://github.com/ahendriksen/raft/tree/investigate-compile-time-reduction-strategies#investigation-of-compile-times, this PR reduces the compile times of the pairwise distance specializations.
This is achieved by:
1. Reducing the number of included files in the translation units where kernels are instantiated; specifically, `spdlog` and `rmm` are avoided.
2. Limiting loop unrolling in kernels with expensive operations in the inner loop.

Additional improvements geared towards iterative development:
1. The tests do not have to be recompiled when the internals of a pairwise distance kernel change. Before, a rebuild was triggered due to an include of `raft/distance/distance.cuh`.
2. Addition of a fine-tuning benchmark for the pairwise distance kernels that separates building the kernel from the benchmark code. This dramatically speeds up development. Compiling an empty benchmark takes roughly 18 seconds on my machine, whereas recompiling a kernel takes ~3.8 seconds. Without this addition, verifying that a commit like 35a2ad4 does not degrade performance would take substantially more time.

![image](https://user-images.githubusercontent.com/4172822/225383120-5f8a82f9-0b46-4c39-bc1d-7b2a0551e881.png)

```
Parallel build time before: 270 seconds (6 cores, SMT, 12 jobs)
Parallel build time after:  147 seconds (6 cores, SMT, 12 jobs)

Sum of compile times before: 3022.6 seconds
Sum of compile times after:  1816.2 seconds

Comparison of compile times between headers and compiled: 
path                                                                         before (s)     after (s)  change (s) change (%)
pairwise_test                                                                    None        0.486       None     None
ance/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu.o  101.1          10.3       -90.8     -89.8%
src/distance/distance/specializations/detail/canberra_float_float_float_int.cu.o   52.9           6.3       -46.6     -88.0%
/distance/distance/specializations/detail/canberra_double_double_double_int.cu.o   48.5           6.4       -42.1     -86.8%
stance/distance/specializations/detail/jensen_shannon_float_float_float_int.cu.o   65.3          10.4       -55.0     -84.1%
istance/distance/specializations/detail/kl_divergence_float_float_float_int.cu.o   70.2          12.6       -57.6     -82.0%
stance/distance/specializations/detail/correlation_double_double_double_int.cu.o   46.7           8.9       -37.8     -80.9%
distance/specializations/detail/hellinger_expanded_double_double_double_int.cu.o   41.6           8.1       -33.5     -80.6%
nce/distance/specializations/detail/jensen_shannon_double_double_double_int.cu.o   74.6          15.1       -59.5     -79.7%
ir/src/distance/distance/specializations/detail/l1_double_double_double_int.cu.o   40.9           8.4       -32.5     -79.4%
ance/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu.o   40.7           8.6       -32.1     -78.8%
distance/specializations/detail/hamming_unexpanded_double_double_double_int.cu.o   40.8           9.0       -31.7     -77.8%
istance/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu.o   45.9          10.2       -35.7     -77.8%
src/distance/distance/specializations/detail/l_inf_double_double_double_int.cu.o   41.2           9.5       -31.8     -77.0%
istance/distance/specializations/detail/russel_rao_double_double_double_int.cu.o   29.5           7.2       -22.3     -75.6%
t.dir/src/distance/distance/specializations/detail/l1_float_float_float_int.cu.o   47.3          13.2       -34.1     -72.2%
ce/distance/specializations/detail/hamming_unexpanded_float_float_float_int.cu.o   47.0          13.3       -33.7     -71.6%
/distance/distance/specializations/detail/correlation_float_float_float_int.cu.o   49.4          14.1       -35.3     -71.5%
ce/distance/specializations/detail/hellinger_expanded_float_float_float_int.cu.o   43.6          12.5       -31.1     -71.4%
c/distance/distance/specializations/detail/russel_rao_float_float_float_int.cu.o   28.5           8.2       -20.3     -71.2%
ance/distance/specializations/detail/kl_divergence_double_double_double_int.cu.o   75.8          21.9       -53.9     -71.1%
istance/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu.o   46.2          13.5       -32.7     -70.7%
ir/src/distance/distance/specializations/detail/l_inf_float_float_float_int.cu.o   43.1          12.7       -30.4     -70.6%
stance/distance/specializations/detail/l2_expanded_double_double_double_int.cu.o   52.3          24.9       -27.3     -52.3%
/distance/distance/specializations/detail/l2_expanded_float_float_float_int.cu.o   75.8          40.3       -35.5     -46.8%
rc/distance/distance/specializations/detail/cosine_double_double_double_int.cu.o   53.5          28.7       -24.8     -46.4%
r/src/distance/distance/specializations/detail/cosine_float_float_float_int.cu.o   83.9          50.1       -33.8     -40.3%
CMakeFiles/pairwise_test.dir/test/distance/fused_l2_nn.cu.o                        85.1          64.1       -21.1     -24.7%
wise_test.dir/src/distance/distance/specializations/fused_l2_nn_float_int64.cu.o   56.2          42.9       -13.3     -23.6%
irwise_test.dir/src/distance/distance/specializations/fused_l2_nn_float_int.cu.o   52.5          40.2       -12.3     -23.5%
CMakeFiles/pairwise_test.dir/test/distance/dist_lp_unexp.cu.o                      56.3          43.3       -13.0     -23.1%
CMakeFiles/pairwise_test.dir/test/distance/dist_russell_rao.cu.o                   55.7          44.0       -11.7     -21.0%
rwise_test.dir/src/distance/distance/specializations/fused_l2_nn_double_int.cu.o   45.3          36.4        -9.0     -19.8%
CMakeFiles/pairwise_test.dir/test/distance/dist_l2_unexp.cu.o                      54.6          44.1       -10.6     -19.3%
CMakeFiles/pairwise_test.dir/test/distance/dist_canberra.cu.o                      51.6          42.1        -9.6     -18.6%
CMakeFiles/pairwise_test.dir/test/distance/dist_l2_exp.cu.o                        53.1          43.4        -9.6     -18.2%
CMakeFiles/pairwise_test.dir/test/distance/dist_l_inf.cu.o                         53.2          43.9        -9.3     -17.5%
CMakeFiles/pairwise_test.dir/test/distance/dist_hellinger.cu.o                     53.1          44.0        -9.0     -17.0%
CMakeFiles/pairwise_test.dir/test/distance/dist_hamming.cu.o                       52.3          43.4        -8.9     -17.0%
CMakeFiles/pairwise_test.dir/test/distance/dist_l2_sqrt_exp.cu.o                   54.0          45.6        -8.4     -15.6%
CMakeFiles/pairwise_test.dir/test/distance/dist_l1.cu.o                            52.6          44.5        -8.1     -15.4%
CMakeFiles/pairwise_test.dir/test/distance/dist_kl_divergence.cu.o                 52.4          44.7        -7.7     -14.8%
ise_test.dir/src/distance/distance/specializations/fused_l2_nn_double_int64.cu.o   43.5          37.2        -6.4     -14.7%
CMakeFiles/pairwise_test.dir/test/distance/dist_cos.cu.o                           52.4          44.8        -7.6     -14.5%
CMakeFiles/pairwise_test.dir/test/distance/dist_jensen_shannon.cu.o                53.2          45.7        -7.6     -14.2%
CMakeFiles/pairwise_test.dir/test/distance/dist_inner_product.cu.o                 51.1          44.8        -6.3     -12.4%
istance/distance/specializations/detail/inner_product_float_float_float_int.cu.o   39.5          35.1        -4.5     -11.3%
CMakeFiles/pairwise_test.dir/test/distance/dist_correlation.cu.o                   51.7          46.8        -4.9     -9.5%
ance/distance/specializations/detail/inner_product_double_double_double_int.cu.o   37.1          33.9        -3.1     -8.5%
src/distance/distance/specializations/detail/kernels/gram_matrix_base_float.cu.o   45.3          41.7        -3.6     -8.0%
rc/distance/distance/specializations/detail/kernels/gram_matrix_base_double.cu.o   42.5          39.6        -2.9     -6.8%
stance/distance/specializations/detail/kernels/polynomial_kernel_double_int.cu.o   40.4          38.5        -1.9     -4.8%
CMakeFiles/pairwise_test.dir/test/distance/dist_adj.cu.o                          123.3         117.8        -5.4     -4.4%
CMakeFiles/pairwise_test.dir/test/distance/gram.cu.o                               55.3          53.4        -1.9     -3.5%
build.ninja                                                                         4.0           4.0        +0.0     +0.1%
istance/distance/specializations/detail/kernels/polynomial_kernel_float_int.cu.o   45.2          45.6        +0.4     +0.8%
.dir/src/distance/distance/specializations/detail/kernels/tanh_kernel_float.cu.o   45.2          46.0        +0.8     +1.7%
dir/src/distance/distance/specializations/detail/kernels/tanh_kernel_double.cu.o   39.0          39.8        +0.8     +2.1%
CMakeFiles/pairwise_test.dir/src/distance/distance/pairwise_distance.cu.o          39.6          50.1       +10.5     +26.6%
```
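The `change (s)` and `change (%)` columns follow from the `before` and `after` columns by simple arithmetic. The sketch below reproduces that calculation; the helper names (`change_seconds`, `change_percent`, `round1`) are hypothetical and are not part of the PR or of the script that produced the table. Applied to the totals above, the same formula gives a reduction of roughly 45.6% in parallel build time (270 s to 147 s) and roughly 39.9% in summed compile time (3022.6 s to 1816.2 s).

```cpp
#include <cassert>
#include <cmath>

// Hypothetical helper: absolute change in seconds ("change (s)" column).
double change_seconds(double before, double after) { return after - before; }

// Hypothetical helper: relative change in percent ("change (%)" column).
// Negative values mean the compile time went down.
double change_percent(double before, double after)
{
  assert(before > 0.0);
  return 100.0 * (after - before) / before;
}

// Round to one decimal place, matching the precision used in the table.
double round1(double x) { return std::round(x * 10.0) / 10.0; }
```

Because the table lists values that are already rounded to one decimal, recomputing a row from its printed `before` and `after` values can differ from the printed change in the last digit.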

Authors:
  - Allard Hendriksen (https://github.com/ahendriksen)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)
  - Divye Gala (https://github.com/divyegala)

URL: #1307
ahendriksen authored Mar 23, 2023
1 parent a7e619c commit 08e7012
Showing 87 changed files with 2,057 additions and 2,000 deletions.
4 changes: 0 additions & 4 deletions cpp/CMakeLists.txt
@@ -312,10 +312,6 @@ if(RAFT_COMPILE_LIBRARY)
src/distance/specializations/detail/l1_double_double_double_int.cu
src/distance/specializations/detail/l2_expanded_float_float_float_int.cu
src/distance/specializations/detail/l2_expanded_double_double_double_int.cu
-src/distance/specializations/detail/l2_sqrt_expanded_float_float_float_int.cu
-src/distance/specializations/detail/l2_sqrt_expanded_double_double_double_int.cu
-src/distance/specializations/detail/l2_sqrt_unexpanded_float_float_float_int.cu
-src/distance/specializations/detail/l2_sqrt_unexpanded_double_double_double_int.cu
src/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu
src/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu
src/distance/specializations/detail/l_inf_double_double_double_int.cu
5 changes: 5 additions & 0 deletions cpp/bench/CMakeLists.txt
@@ -72,6 +72,11 @@ if(BUILD_BENCH)
OPTIONAL LIB
)

ConfigureBench(
NAME TUNE_DISTANCE PATH bench/distance/tune_pairwise/kernel.cu
bench/distance/tune_pairwise/bench.cu bench/main.cpp
)

ConfigureBench(
NAME
DISTANCE_BENCH
151 changes: 151 additions & 0 deletions cpp/bench/distance/tune_pairwise/bench.cu
@@ -0,0 +1,151 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

// Tuning benchmarks.
//
// Goals:
//
// 1. Fast compile times to maintain iteration speed.
// 2. Create benchmarks that can inform the design of the kernels.
//
// Non-goals:
//
// 1. Measure every distance operation. Instead, this benchmark measures just
//    one distance operation at a time.
// 2. Be useful for finding performance regressions. This is handled by the
//    normal benchmarks.
//
// So far, both goals are partly achieved.
//
// RE (1), compile times: kernel.cu is fast to compile. This file is not.
// When the internals of a pairwise distance kernel are changed, this file is
// not recompiled.
//
// RE (2), benchmarks with intent: this file contains a benchmark to check the
// maximal throughput of a kernel. Measuring other things, like performance on
// skinny or wide matrices, is not yet implemented.

#include "kernel.cuh" // launch_kernel
#include <algorithm> // std::min
#include <common/benchmark.hpp> // RAFT_BENCH_REGISTER
#include <raft/distance/detail/pairwise_matrix/params.cuh> // pairwise_matrix_params
#include <rmm/device_uvector.hpp> // rmm::device_uvector
#include <vector> // std::vector

namespace raft::bench::distance::tune {

// Max throughput benchmark.
//
// Goal: Measure the maximum distances/sec that can be computed.
//
// To achieve this, we make sure that:
//
// - Input data size is a multiple of the block tile size.
//
// - Perfect distribution of work between SMs, i.e. the number of block tiles is
// a large multiple (num_waves) of the number of blocks (#SMs * occupancy).
//
// - Multiple iterations over Kblk are executed (num_k_iters).
struct throughput_param {
  int num_waves;
  int occupancy;
  int num_k_iters;
};

const std::vector<throughput_param> throughput_params{
  // 32 waves, requested occupancy of 4, and 32 k iterations typically achieves
  // maximum throughput. No need to pick higher values.
  {32, 4, 32},
};

struct throughput_bench : public fixture {
  const throughput_param p;

  throughput_bench(const throughput_param& p_) : p(p_) {}

  void run_benchmark(::benchmark::State& state) override
  {
    // Get block size:
    int block_m, block_n, block_k;
    get_block_size(block_m, block_n, block_k);

    // Determine number of blocks that will be launched. This informs the size
    // of the inputs as well as the grid size.
    const int num_sms = raft::getMultiProcessorCount();
    const int max_occupancy = get_max_occupancy();
    const int occupancy = std::min(p.occupancy, max_occupancy);
    const int num_blocks = occupancy * num_sms;
    dim3 grid(num_blocks);

    // Create input sizes that are a multiple of the block tile size.
    size_t m = block_m;
    size_t n = block_n * p.num_waves * num_blocks;
    size_t k = block_k * p.num_k_iters;

    // DataT, OutT, IdxT, etc, are defined in kernel.cuh
    rmm::device_uvector<DataT> x_vec(m * k, stream);
    rmm::device_uvector<DataT> y_vec(n * k, stream);
    rmm::device_uvector<DataT> x_norm_vec(m, stream);
    rmm::device_uvector<DataT> y_norm_vec(n, stream);
    rmm::device_uvector<OutT> out_vec(m * n, stream);

    auto x = x_vec.data();
    auto y = y_vec.data();
    auto x_norm = x_norm_vec.data();
    auto y_norm = y_norm_vec.data();
    auto out = out_vec.data();
    FinOpT fin_op{};

    // Create kernel parameter struct. Flip x and y if column major.
    IdxT ldx = row_major ? k : m;
    IdxT ldy = row_major ? k : n;
    IdxT ld_out = row_major ? n : m;

    // Template parameters of pairwise_matrix_params are defined in kernel.cuh
    pairwise_matrix_params kparams{
      IdxT(m), IdxT(n), IdxT(k), ldx, ldy, ld_out, x, y, x_norm, y_norm, out, fin_op, row_major};

    // Run benchmark
    loop_on_state(state, [&]() { launch_kernel(kparams, grid, stream); });

    // Report metrics. We don't report flop/s because we do not know for each
    // distance operation how many flops it costs. For L2_unexp and l1, we can
    // double this number to get the flop/s. For l2 expanded, core_ops/s should
    // equal flop/s (modulo the sqrt and subtracting from the norm).
    size_t num_core_ops = m * n * k;
    size_t read_elts = n * k + m * k;
    size_t write_elts = m * n;

    state.counters["m"] = benchmark::Counter(m);
    state.counters["n"] = benchmark::Counter(n);
    state.counters["k"] = benchmark::Counter(k);
    state.counters["occupancy"] = benchmark::Counter(occupancy);
    state.counters["# waves"] = benchmark::Counter(p.num_waves);
    state.counters["# k iters"] = benchmark::Counter(p.num_k_iters);

    state.counters["core_ops/s"] = benchmark::Counter(num_core_ops,
                                                      benchmark::Counter::kIsIterationInvariantRate,
                                                      benchmark::Counter::OneK::kIs1000);

    state.counters["BW"] = benchmark::Counter(write_elts * sizeof(OutT) + read_elts * sizeof(DataT),
                                              benchmark::Counter::kIsIterationInvariantRate,
                                              benchmark::Counter::OneK::kIs1000);
  }
};

RAFT_BENCH_REGISTER(throughput_bench, "", throughput_params);

} // namespace raft::bench::distance::tune
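The sizing arithmetic in `run_benchmark` can be checked on the host without a GPU. The following is a minimal sketch, not code from the PR: `ProblemSize`, `throughput_problem_size`, `core_ops`, and `bytes_moved` are hypothetical names, and the tile sizes, SM count, and occupancy limit are passed in as plain integers instead of being queried from the policy and the device. Unlike the benchmark, the sketch widens to `std::size_t` before multiplying.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>

// Hypothetical container for the benchmark's input dimensions.
struct ProblemSize {
  std::size_t m, n, k;
};

// Mirrors the sizing logic of run_benchmark: one block tile in m, num_waves
// tiles per launched block in n, and num_k_iters tiles in k.
ProblemSize throughput_problem_size(int block_m, int block_n, int block_k,
                                    int num_sms, int requested_occupancy,
                                    int max_occupancy, int num_waves,
                                    int num_k_iters)
{
  // The requested occupancy is clamped to what the kernel can achieve.
  const int occupancy = std::min(requested_occupancy, max_occupancy);
  const std::size_t num_blocks = static_cast<std::size_t>(occupancy) * num_sms;
  return {static_cast<std::size_t>(block_m),
          static_cast<std::size_t>(block_n) * num_waves * num_blocks,
          static_cast<std::size_t>(block_k) * num_k_iters};
}

// Number of inner-loop ("core") operations, as reported in core_ops/s.
std::size_t core_ops(const ProblemSize& s) { return s.m * s.n * s.k; }

// Bytes read (x and y) plus bytes written (out), as reported in BW.
std::size_t bytes_moved(const ProblemSize& s, std::size_t in_elt_bytes,
                        std::size_t out_elt_bytes)
{
  return (s.n * s.k + s.m * s.k) * in_elt_bytes + s.m * s.n * out_elt_bytes;
}
```

For example, with assumed values of a 64 x 64 x 32 tile, 100 SMs, and an achievable occupancy of 2, the parameters `{32, 4, 32}` give `m = 64`, `n = 64 * 32 * 200 = 409600`, and `k = 32 * 32 = 1024`.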
88 changes: 88 additions & 0 deletions cpp/bench/distance/tune_pairwise/kernel.cu
@@ -0,0 +1,88 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "kernel.cuh"
#include <raft/distance/detail/pairwise_matrix/kernel_sm60.cuh> // pairwise_matrix_sm60_wrapper
#include <raft/linalg/contractions.cuh> // raft::linalg::Policy4x4
#include <raft/util/arch.cuh> // raft::util::arch::SM_compute_arch

namespace raft::bench::distance::tune {

// Distance op
using OpT = raft::distance::detail::ops::lp_unexp_distance_op<DataT, AccT, IdxT>;
constexpr float metric_arg = 2.0;
OpT distance_op{metric_arg};

// Kernel policy
constexpr int vec_len = 1;
using Policy = typename raft::linalg::Policy4x4<DataT, vec_len>::Policy;

// Architecture
namespace arch = raft::util::arch;
constexpr auto sm_compat_range = arch::SM_range(arch::SM_min(), arch::SM_future());

void launch_kernel(pairwise_matrix_params params, dim3 grid, cudaStream_t stream)
{
  dim3 block(Policy::Nthreads);
  int smem_size = OpT::shared_mem_size<Policy>();

  // Obtain function pointer to kernel
  auto kernel = raft::distance::detail::pairwise_matrix_kernel<Policy,
                                                               row_major,
                                                               decltype(sm_compat_range),
                                                               OpT,
                                                               IdxT,
                                                               DataT,
                                                               OutT,
                                                               FinOpT>;

  kernel<<<grid, block, smem_size, stream>>>(distance_op, params);
  RAFT_CUDA_TRY(cudaGetLastError());
}

void get_block_size(int& m, int& n, int& k)
{
  m = Policy::Mblk;
  n = Policy::Nblk;
  k = Policy::Kblk;
}

void* get_kernel_ptr()
{
  auto kernel = raft::distance::detail::pairwise_matrix_kernel<Policy,
                                                               row_major,
                                                               decltype(sm_compat_range),
                                                               OpT,
                                                               IdxT,
                                                               DataT,
                                                               OutT,
                                                               FinOpT>;
  return reinterpret_cast<void*>(kernel);
}

int get_max_occupancy()
{
  void* kernel_ptr = get_kernel_ptr();
  int max_occupancy;
  int smem_size = OpT::shared_mem_size<Policy>();

  RAFT_CUDA_TRY(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
    &max_occupancy, kernel_ptr, Policy::Nthreads, smem_size));

  return max_occupancy;
}

} // namespace raft::bench::distance::tune
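Both `launch_kernel` and `get_kernel_ptr` name every template argument of the kernel, which selects exactly one instantiation and yields an ordinary function pointer. A host-only sketch of that pattern is shown below; `toy_kernel`, `toy_kernel_fn`, `get_toy_kernel`, and `get_toy_kernel_ptr` are hypothetical stand-ins, and the real kernel is a `__global__` function rather than a plain C++ function.

```cpp
#include <cassert>

// Hypothetical stand-in for a kernel template.
template <typename DataT, int Scale>
void toy_kernel(const DataT* in, DataT* out, int n)
{
  for (int i = 0; i < n; ++i) { out[i] = in[i] * DataT(Scale); }
}

using toy_kernel_fn = void (*)(const float*, float*, int);

// Naming all template arguments picks one instantiation; the function name
// then decays to a pointer to that instantiation.
toy_kernel_fn get_toy_kernel() { return toy_kernel<float, 2>; }

// Type-erased variant, analogous to get_kernel_ptr. Converting a function
// pointer to void* is only conditionally supported by the C++ standard, but
// it is accepted by the mainstream compilers on common platforms.
void* get_toy_kernel_ptr() { return reinterpret_cast<void*>(get_toy_kernel()); }
```

Keeping the pointer behind a non-template function means that callers do not need to see the template definition, only the function's declaration.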
44 changes: 44 additions & 0 deletions cpp/bench/distance/tune_pairwise/kernel.cuh
@@ -0,0 +1,44 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/distance/detail/distance_ops/all_ops.cuh> // lp_unexp_distance_op
#include <raft/distance/detail/pairwise_matrix/params.cuh> // pairwise_matrix_params

namespace raft::bench::distance::tune {

// Launch one specific kernel with the following template parameters
constexpr bool row_major = true;
using DataT = float;
using AccT = float;
using OutT = DataT;
using IdxT = int;

using FinOpT = raft::identity_op;

using pairwise_matrix_params =
  raft::distance::detail::pairwise_matrix_params<IdxT, DataT, OutT, FinOpT>;

// Launches kernel
void launch_kernel(pairwise_matrix_params, dim3, cudaStream_t);

// Describes the block size that is decided by the policy
void get_block_size(int& m, int& n, int& k);

int get_max_occupancy();

} // namespace raft::bench::distance::tune
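The split between `kernel.cuh`, `kernel.cu`, and `bench.cu` works because the header exposes only fixed type aliases and non-template function declarations, so the expensive template is instantiated in a single translation unit. The sketch below illustrates the idea in one file, with comments marking where each part would live. All names in `tune_sketch` are hypothetical and do not come from the PR.

```cpp
#include <cassert>
#include <vector>

// ---- Would live in a small header (compare kernel.cuh) ----
namespace tune_sketch {
using DataT = float;  // fixed type choice, visible to the benchmark code

struct params {
  const DataT* x;
  DataT* out;
  int n;
};

// Declarations only: including this header instantiates no templates.
void launch(const params& p);
int block_size();
}  // namespace tune_sketch

// ---- Would live in its own translation unit (compare kernel.cu) ----
namespace tune_sketch {
namespace {
constexpr int kBlock = 4;  // stand-in for a policy-defined tile size

// Stand-in for an expensive-to-compile template.
template <typename T, int Block>
void heavy_template(const T* x, T* out, int n)
{
  for (int i = 0; i < n; ++i) { out[i] = x[i] * x[i]; }
}
}  // namespace

void launch(const params& p) { heavy_template<DataT, kBlock>(p.x, p.out, p.n); }
int block_size() { return kBlock; }
}  // namespace tune_sketch
```

With this layout, a change to `heavy_template` only requires recompiling the translation unit that defines `launch`; code that merely calls `launch` through the header is relinked, not recompiled.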
2 changes: 1 addition & 1 deletion cpp/include/raft/core/kvp.hpp
@@ -20,7 +20,7 @@

#ifdef _RAFT_HAS_CUDA
#include <cub/cub.cuh>
-#include <raft/util/cuda_utils.cuh>
+#include <raft/util/cuda_utils.cuh>  // raft::shfl_xor
#endif
namespace raft {
/**