Skip to content

Commit

Permalink
Take distance_op in pairwise_distance_base
Browse files Browse the repository at this point in the history
Fixes issue rapidsai#1323
  • Loading branch information
ahendriksen committed Mar 14, 2023
1 parent f1db5a1 commit 2621306
Show file tree
Hide file tree
Showing 6 changed files with 124 additions and 220 deletions.
2 changes: 1 addition & 1 deletion cpp/include/raft/core/kvp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

#ifdef _RAFT_HAS_CUDA
#include <cub/cub.cuh>
#include <raft/util/cuda_utils.cuh>
#include <raft/util/cuda_utils.cuh> // raft::shfl_xor
#endif
namespace raft {
/**
Expand Down
133 changes: 37 additions & 96 deletions cpp/include/raft/distance/detail/fused_l2_nn.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,20 @@

#pragma once

#include <limits>
#include <raft/core/kvp.hpp>
#include <raft/distance/detail/pairwise_distance_base.cuh>
#include <raft/linalg/contractions.cuh>
#include <raft/util/cuda_utils.cuh>
#include <stdint.h>
#include <cstddef> // size_t
#include <limits> // std::numeric_limits
#include <raft/core/kvp.hpp> // raft::KeyValuePair
#include <raft/core/operators.hpp> // raft::identity_op
#include <raft/distance/detail/distance_ops/l2_exp.cuh> // ops::l2_exp_distance_op
#include <raft/distance/detail/pairwise_distance_base.cuh> // PairwiseDistances
#include <raft/linalg/contractions.cuh> // Policy
#include <raft/util/cuda_utils.cuh> // raft::ceildiv, raft::shfl

namespace raft {
namespace distance {

namespace detail {

#if (ENABLE_MEMCPY_ASYNC == 1)
#include <cuda_pipeline.h>
using namespace nvcuda::experimental;
#endif

template <typename LabelT, typename DataT>
struct KVPMinReduceImpl {
typedef raft::KeyValuePair<LabelT, DataT> KVP;
Expand Down Expand Up @@ -124,11 +121,10 @@ DI void updateReducedVal(
template <typename DataT,
typename OutT,
typename IdxT,
bool Sqrt,
typename P,
typename ReduceOpT,
typename KVPReduceOpT,
typename CoreLambda,
typename OpT,
typename FinalLambda>
__global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min,
const DataT* x,
Expand All @@ -142,7 +138,7 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min,
int* mutex,
ReduceOpT redOp,
KVPReduceOpT pairRedOp,
CoreLambda core_op,
OpT distance_op,
FinalLambda fin_op)
{
extern __shared__ char smem[];
Expand All @@ -163,24 +159,6 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min,
IdxT gridStrideY) {
KVPReduceOpT pairRed_op(pairRedOp);

#pragma unroll
for (int i = 0; i < P::AccRowsPerTh; ++i) {
#pragma unroll
for (int j = 0; j < P::AccColsPerTh; ++j) {
acc[i][j] = regxn[i] + regyn[j] - (DataT)2.0 * acc[i][j];
}
}
if (Sqrt) {
#pragma unroll
for (int i = 0; i < P::AccRowsPerTh; ++i) {
#pragma unroll
for (int j = 0; j < P::AccColsPerTh; ++j) {
auto acc_ij = acc[i][j];
acc[i][j] = acc_ij > DataT{0} ? raft::sqrt(acc_ij) : DataT{0};
}
}
}

// intra thread reduce
const auto acccolid = threadIdx.x % P::AccThCols;
const auto accrowid = threadIdx.x / P::AccThCols;
Expand Down Expand Up @@ -229,18 +207,18 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min,
};

IdxT lda = k, ldb = k, ldd = n;
PairwiseDistances<true,
DataT,
DataT,
DataT,
constexpr bool row_major = true;
constexpr bool write_out = false;
PairwiseDistances<DataT,
DataT, // OutT (unused in PairwiseDistances)
IdxT,
P,
CoreLambda,
decltype(distance_op),
decltype(epilog_lambda),
FinalLambda,
decltype(rowEpilog_lambda),
true,
false>
row_major,
write_out>
obj(x,
y,
m,
Expand All @@ -251,9 +229,9 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min,
ldd,
xn,
yn,
nullptr,
nullptr, // Output pointer
smem,
core_op,
distance_op,
epilog_lambda,
fin_op,
rowEpilog_lambda);
Expand Down Expand Up @@ -289,9 +267,6 @@ void fusedL2NNImpl(OutT* min,
constexpr auto maxVal = std::numeric_limits<DataT>::max();
typedef KeyValuePair<IdxT, DataT> KVPair;

// Accumulation operation lambda
auto core_lambda = [] __device__(DataT & acc, DataT & x, DataT & y) { acc += x * y; };

RAFT_CUDA_TRY(cudaMemsetAsync(workspace, 0, sizeof(int) * m, stream));
if (initOutBuffer) {
initKernel<DataT, OutT, IdxT, ReduceOpT>
Expand All @@ -300,59 +275,25 @@ void fusedL2NNImpl(OutT* min,
}

constexpr size_t shmemSize = P::SmemSize + ((P::Mblk + P::Nblk) * sizeof(DataT));
if (sqrt) {
auto fusedL2NNSqrt = fusedL2NNkernel<DataT,
OutT,
IdxT,
true,
P,
ReduceOpT,
KVPReduceOpT,
decltype(core_lambda),
raft::identity_op>;
dim3 grid = launchConfigGenerator<P>(m, n, shmemSize, fusedL2NNSqrt);

fusedL2NNSqrt<<<grid, blk, shmemSize, stream>>>(min,
x,
y,
xn,
yn,
m,
n,
k,
maxVal,
workspace,
redOp,
pairRedOp,
core_lambda,
raft::identity_op{});
} else {
auto fusedL2NN = fusedL2NNkernel<DataT,
OutT,
IdxT,
false,
P,
ReduceOpT,
KVPReduceOpT,
decltype(core_lambda),
raft::identity_op>;
dim3 grid = launchConfigGenerator<P>(m, n, shmemSize, fusedL2NN);
fusedL2NN<<<grid, blk, shmemSize, stream>>>(min,
x,
y,
xn,
yn,
m,
n,
k,
maxVal,
workspace,
redOp,
pairRedOp,
core_lambda,
raft::identity_op{});
}

using AccT = DataT;
ops::l2_exp_distance_op<DataT, AccT, IdxT> distance_op{sqrt};

raft::identity_op fin_op{};

auto kernel = fusedL2NNkernel<DataT,
OutT,
IdxT,
P,
ReduceOpT,
KVPReduceOpT,
decltype(distance_op),
decltype(fin_op)>;

dim3 grid = launchConfigGenerator<P>(m, n, shmemSize, kernel);

kernel<<<grid, blk, shmemSize, stream>>>(
min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, distance_op, fin_op);
RAFT_CUDA_TRY(cudaGetLastError());
}

Expand Down
37 changes: 22 additions & 15 deletions cpp/include/raft/distance/detail/pairwise_distance_base.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,12 @@ namespace detail {

/**
* @brief Device class for L1, L2 and cosine distance metrics.
* @tparam useNorms whether norms are needed
* @tparam DataT input data-type (for A and B matrices)
* @tparam AccT accumulation data-type
* @tparam OutT output data-type (for C and D matrices)
* @tparam IdxT index data-type
* @tparam Policy struct which tunes the Contraction kernel
* @tparam CoreLambda tells how to accumulate an x and y into
acc. its signature:
template <typename AccT, typename DataT> void core_lambda(AccT& acc,
const DataT& x, const DataT& y)
* @tparam OpT A distance operation, e.g., cosine_distance_op.
* @tparam EpilogueLambda applies an elementwise function to compute final
values. Its signature is:
template <typename AccT, typename DataT> void epilogue_lambda
Expand All @@ -53,34 +49,35 @@ namespace detail {
* @param[in] yn row norms of input matrix B. Required for expanded L2, cosine
* @param[output] pD output matrix
* @param[in] smem shared mem buffer for intermediate storage of A, B, xn & yn.
* @param core_op the core accumulation operation lambda
* @param distance_op the distance operation, e.g. cosine_distance_op
* @param epilog_op the epilog operation lambda
* @param fin_op the final gemm epilogue lambda
* @param rowEpilog_op epilog lambda that executes when a full row has been processed
*/

template <bool useNorms,
typename DataT,
typename AccT,
template <typename DataT,
typename OutT,
typename IdxT,
typename Policy,
typename CoreLambda,
typename OpT,
typename EpilogueLambda,
typename FinalLambda,
typename rowEpilogueLambda,
bool isRowMajor = true,
bool writeOut = true,
typename BaseClass = raft::linalg::Contractions_NT<DataT, IdxT, Policy, isRowMajor>>
struct PairwiseDistances : public BaseClass {
// Get accumulation type from distance_op
using AccT = typename OpT::AccT;

private:
typedef Policy P;
const DataT* xn;
const DataT* yn;
const DataT* const yBase;
OutT* dOutput;
char* smem;
CoreLambda core_op;
OpT distance_op;
EpilogueLambda epilog_op;
FinalLambda fin_op;
rowEpilogueLambda rowEpilog_op;
Expand All @@ -106,7 +103,7 @@ struct PairwiseDistances : public BaseClass {
const DataT* _yn,
OutT* _dOutput,
char* _smem,
CoreLambda _core_op,
OpT _distance_op,
EpilogueLambda _epilog_op,
FinalLambda _fin_op,
rowEpilogueLambda _rowEpilog_op)
Expand All @@ -116,7 +113,7 @@ struct PairwiseDistances : public BaseClass {
yBase(_y),
dOutput(_dOutput),
smem(_smem),
core_op(_core_op),
distance_op(_distance_op),
epilog_op(_epilog_op),
fin_op(_fin_op),
rowEpilog_op(_rowEpilog_op),
Expand Down Expand Up @@ -156,15 +153,25 @@ struct PairwiseDistances : public BaseClass {
this->switch_read_buffer();

// Epilog:
if (useNorms) {
if (distance_op.use_norms) {
DataT regxn[P::AccRowsPerTh], regyn[P::AccColsPerTh];
load_norms(tile_idx_m, tile_idx_n, regxn, regyn);
// Overlap ldg with epilog computation
ldgNextGridStride(tile_idx_m, tile_idx_n);
// Calculate distance_op epilog.
// Use .template to disambiguate (See:
// https://en.cppreference.com/w/cpp/language/dependent_name)
distance_op.template epilog<Policy>(acc, regxn, regyn, tile_idx_n, tile_idx_m);
// And any possible additional epilogs
epilog_op(acc, regxn, regyn, tile_idx_n, tile_idx_m);
} else {
// Overlap ldg with epilog computation
ldgNextGridStride(tile_idx_m, tile_idx_n);
// Calculate distance_op epilog.
// Use .template to disambiguate (See:
// https://en.cppreference.com/w/cpp/language/dependent_name)
distance_op.template epilog<Policy>(acc, nullptr, nullptr, tile_idx_n, tile_idx_m);
// And any possible additional epilogs
epilog_op(acc, nullptr, nullptr, tile_idx_n, tile_idx_m);
}
if (writeOut) { store_output(tile_idx_m, tile_idx_n); }
Expand Down Expand Up @@ -209,7 +216,7 @@ struct PairwiseDistances : public BaseClass {
for (int j = 0; j < P::AccColsPerTh; ++j) {
#pragma unroll
for (int v = 0; v < P::Veclen; ++v) {
core_op(acc[i][j], this->regx[i][v], this->regy[j][v]);
distance_op.core(acc[i][j], this->regx[i][v], this->regy[j][v]);
}
}
}
Expand Down
28 changes: 6 additions & 22 deletions cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -43,36 +43,20 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void pairwise_matrix_kernel(

extern __shared__ char smem[];

using AccT = typename OpT::AccT;

// Wrap operator back into lambdas. This is temporary and should be removed.
// See: https://github.com/rapidsai/raft/issues/1323
auto core_op = [distance_op] __device__(AccT & acc, DataT & x, DataT & y) {
distance_op.core(acc, x, y);
};
auto epilog_op = [distance_op] __device__(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh],
DataT * regxn,
DataT * regyn,
IdxT gridStrideX,
IdxT gridStrideY) {
// Use .template to disambiguate (See:
// https://en.cppreference.com/w/cpp/language/dependent_name)
distance_op.template epilog<Policy>(acc, regxn, regyn, gridStrideX, gridStrideY);
};

// The epilog is already provided by distance_op. Do not provide additional
// epilogs.
auto epilog_op = raft::void_op();
// No support for row_epilog_op.
auto row_epilog_op = raft::void_op();

// Always write output
constexpr bool write_out = true;
constexpr bool use_norms = distance_op.use_norms;
PairwiseDistances<use_norms,
DataT,
AccT,
PairwiseDistances<DataT,
OutT,
IdxT,
Policy,
decltype(core_op),
decltype(distance_op),
decltype(epilog_op),
decltype(params.fin_op),
decltype(row_epilog_op),
Expand All @@ -90,7 +74,7 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void pairwise_matrix_kernel(
params.y_norm,
params.out,
smem,
core_op,
distance_op,
epilog_op,
params.fin_op,
row_epilog_op);
Expand Down
Loading

0 comments on commit 2621306

Please sign in to comment.