LinAlg impl in detail #383

Merged 27 commits · Feb 8, 2022
Changes shown are from 12 of the 27 commits.

Commits:
478ddac
working through
divyegala Oct 21, 2021
d4b72ba
working through
divyegala Nov 5, 2021
b472870
linalg detail
divyegala Nov 17, 2021
3bd9645
merging branch 22.02
divyegala Nov 17, 2021
788ffa8
style fix
divyegala Nov 17, 2021
f7d43b5
correcting include
divyegala Nov 17, 2021
282cd48
merging branch-21.12
divyegala Nov 17, 2021
37596c9
Merge remote-tracking branch 'upstream/branch-22.02' into imp-21.12-l…
divyegala Nov 17, 2021
cd4e1f9
merging upstream
divyegala Dec 14, 2021
9c0d655
removing deleted file again
divyegala Dec 14, 2021
a071d09
correcting merges and passing tests
divyegala Dec 14, 2021
db817f6
changing h extensions to hpp
divyegala Dec 14, 2021
abec4d2
cublas/cusolver only in detail, wrap up rest of linalg
divyegala Dec 22, 2021
b424cf1
merging upstream
divyegala Dec 22, 2021
34b2439
correcting doxygen build
divyegala Dec 22, 2021
897e6f7
correcting wrong docs
divyegala Dec 22, 2021
3d4b5f1
review feedback
divyegala Jan 11, 2022
4163619
merging branch-22.02
divyegala Jan 25, 2022
8ff01a9
Merge remote-tracking branch 'upstream/branch-22.04' into imp-21.12-l…
divyegala Jan 25, 2022
b6471d6
review changes
divyegala Jan 26, 2022
5d8c176
more macro renames
divyegala Jan 27, 2022
14cddfc
adding explicit stream set back to cublas and cusolver wrappers
divyegala Feb 2, 2022
a2f670f
resolving errors
divyegala Feb 2, 2022
89bf3c1
adding set stream to cublas set pointer mode
divyegala Feb 4, 2022
3c5d303
Merge branch 'branch-22.04' into imp-linalg-public
cjnolet Feb 4, 2022
5759c80
Merge branch 'branch-22.04' into imp-21.12-linalg_detail
cjnolet Feb 7, 2022
f94beef
Fixing a bad merge
cjnolet Feb 7, 2022
2 changes: 1 addition & 1 deletion cpp/include/raft/distance/detail/correlation.cuh
@@ -17,7 +17,7 @@
 #pragma once
 #include <raft/cuda_utils.cuh>
 #include <raft/distance/detail/pairwise_distance_base.cuh>
-#include <raft/linalg/reduce.cuh>
+#include <raft/linalg/reduce.hpp>

 namespace raft {
 namespace distance {
2 changes: 1 addition & 1 deletion cpp/include/raft/distance/detail/cosine.cuh
@@ -17,7 +17,7 @@
 #pragma once

 #include <raft/distance/detail/pairwise_distance_base.cuh>
-#include <raft/linalg/norm.cuh>
+#include <raft/linalg/norm.hpp>

 namespace raft {
 namespace distance {
2 changes: 1 addition & 1 deletion cpp/include/raft/distance/detail/distance.cuh
@@ -17,7 +17,6 @@
 #pragma once

 #include <cuda_runtime_api.h>
-#include <raft/linalg/distance_type.h>
 #include <raft/cuda_utils.cuh>
 #include <raft/distance/detail/canberra.cuh>
 #include <raft/distance/detail/chebyshev.cuh>
@@ -31,6 +30,7 @@
 #include <raft/distance/detail/l1.cuh>
 #include <raft/distance/detail/minkowski.cuh>
 #include <raft/distance/detail/russell_rao.cuh>
+#include <raft/linalg/distance_type.hpp>
 #include <rmm/device_uvector.hpp>

 namespace raft {
2 changes: 1 addition & 1 deletion cpp/include/raft/distance/detail/euclidean.cuh
@@ -16,7 +16,7 @@

 #pragma once
 #include <raft/distance/detail/pairwise_distance_base.cuh>
-#include <raft/linalg/norm.cuh>
+#include <raft/linalg/norm.hpp>

 namespace raft {
 namespace distance {
2 changes: 1 addition & 1 deletion cpp/include/raft/distance/detail/fused_l2_nn.cuh
@@ -21,7 +21,7 @@
 #include <limits>
 #include <raft/cuda_utils.cuh>
 #include <raft/distance/detail/pairwise_distance_base.cuh>
-#include <raft/linalg/contractions.cuh>
+#include <raft/linalg/contractions.hpp>

 namespace raft {
 namespace distance {
2 changes: 1 addition & 1 deletion cpp/include/raft/distance/detail/hellinger.cuh
@@ -16,7 +16,7 @@

 #pragma once
 #include <raft/distance/detail/pairwise_distance_base.cuh>
-#include <raft/linalg/unary_op.cuh>
+#include <raft/linalg/unary_op.hpp>

 namespace raft {
 namespace distance {
4 changes: 2 additions & 2 deletions cpp/include/raft/distance/detail/pairwise_distance_base.cuh
@@ -16,8 +16,8 @@
 #pragma once
 #include <raft/cudart_utils.h>
 #include <raft/cuda_utils.cuh>
-#include <raft/linalg/contractions.cuh>
-#include <raft/linalg/norm.cuh>
+#include <raft/linalg/contractions.hpp>
+#include <raft/linalg/norm.hpp>
 #include <raft/vectorized.cuh>

 #include <cstddef>
2 changes: 1 addition & 1 deletion cpp/include/raft/distance/distance.hpp
@@ -16,9 +16,9 @@

 #pragma once

-#include <raft/linalg/distance_type.h>
 #include <raft/distance/detail/distance.cuh>
 #include <raft/handle.hpp>
+#include <raft/linalg/distance_type.hpp>
 #include <rmm/device_uvector.hpp>

 namespace raft {
4 changes: 2 additions & 2 deletions cpp/include/raft/handle.hpp
@@ -32,10 +32,10 @@
 ///@todo: enable once we have migrated cuml-comms layer too
 //#include <common/cuml_comms_int.hpp>

-#include <raft/linalg/cublas_wrappers.h>
-#include <raft/linalg/cusolver_wrappers.h>
 #include <raft/sparse/cusparse_wrappers.h>
 #include <raft/comms/comms.hpp>
+#include <raft/linalg/cublas_wrappers.hpp>
+#include <raft/linalg/cusolver_wrappers.hpp>
 #include <rmm/cuda_stream_pool.hpp>
 #include <rmm/exec_policy.hpp>
 #include "cudart_utils.h"
2 changes: 1 addition & 1 deletion cpp/include/raft/label/classlabels.cuh
@@ -20,7 +20,7 @@

 #include <raft/cudart_utils.h>
 #include <raft/cuda_utils.cuh>
-#include <raft/linalg/unary_op.cuh>
+#include <raft/linalg/unary_op.hpp>
 #include <rmm/device_scalar.hpp>
 #include <rmm/device_uvector.hpp>

2 changes: 1 addition & 1 deletion cpp/include/raft/label/merge_labels.cuh
@@ -20,8 +20,8 @@
 #include <limits>

 #include <raft/cudart_utils.h>
-#include <raft/linalg/init.h>
 #include <raft/cuda_utils.cuh>
+#include <raft/linalg/init.hpp>

 namespace raft {
 namespace label {
cpp/include/raft/linalg/add.hpp
@@ -16,12 +16,17 @@

 #pragma once

-#include "binary_op.cuh"
-#include "unary_op.cuh"
+#include "detail/add.cuh"
+#include "detail/functional.cuh"
+
+#include "binary_op.hpp"
+#include "unary_op.hpp"

 namespace raft {
 namespace linalg {

+using detail::adds_scalar;
+
 /**
  * @brief Elementwise scalar add operation on the input buffer
  *
@@ -39,8 +44,7 @@ namespace linalg {
 template <typename InT, typename OutT = InT, typename IdxType = int>
 void addScalar(OutT* out, const InT* in, InT scalar, IdxType len, cudaStream_t stream)
 {
-  auto op = [scalar] __device__(InT in) { return OutT(in + scalar); };
-  unaryOp<InT, decltype(op), IdxType, OutT>(out, in, len, op, stream);
+  unaryOp(out, in, len, adds_scalar<InT, OutT>(scalar), stream);
 }

 /**
@@ -59,18 +63,7 @@ void addScalar(OutT* out, const InT* in, InT scalar, IdxType len, cudaStream_t stream)
 template <typename InT, typename OutT = InT, typename IdxType = int>
 void add(OutT* out, const InT* in1, const InT* in2, IdxType len, cudaStream_t stream)
 {
-  auto op = [] __device__(InT a, InT b) { return OutT(a + b); };
-  binaryOp<InT, decltype(op), OutT, IdxType>(out, in1, in2, len, op, stream);
-}
-
-template <class math_t, typename IdxType>
-__global__ void add_dev_scalar_kernel(math_t* outDev,
-                                      const math_t* inDev,
-                                      const math_t* singleScalarDev,
-                                      IdxType len)
-{
-  IdxType i = ((IdxType)blockIdx.x * (IdxType)blockDim.x) + threadIdx.x;
-  if (i < len) { outDev[i] = inDev[i] + *singleScalarDev; }
+  binaryOp(out, in1, in2, len, thrust::plus<InT>(), stream);
 }

 /** Subtract single value pointed by singleScalarDev parameter in device memory from inDev[i] and
@@ -90,11 +83,7 @@ void addDevScalar(math_t* outDev,
                   IdxType len,
                   cudaStream_t stream)
 {
-  // TODO: block dimension has not been tuned
-  dim3 block(256);
-  dim3 grid(raft::ceildiv(len, (IdxType)block.x));
-  add_dev_scalar_kernel<math_t><<<grid, block, 0, stream>>>(outDev, inDev, singleScalarDev, len);
-  RAFT_CUDA_TRY(cudaPeekAtLastError());
+  detail::addDevScalar(outDev, inDev, singleScalarDev, len, stream);
 }

 }; // end namespace linalg
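The public signatures above are unchanged; only the bodies now forward to detail:: (addScalar through the adds_scalar functor, add through thrust::plus). A minimal usage sketch of this header, where the function name, buffer sizes, and setup are illustrative assumptions rather than part of the diff:

#include <raft/linalg/add.hpp>
#include <rmm/device_uvector.hpp>
#include <cuda_runtime_api.h>

// Compile as a CUDA translation unit (.cu). All names below are illustrative.
void add_example(cudaStream_t stream)
{
  const int len = 1024;
  rmm::device_uvector<float> in1(len, stream), in2(len, stream), out(len, stream);
  // ... fill in1 and in2 on the device ...

  // out[i] = in1[i] + in2[i]; internally dispatched through thrust::plus.
  raft::linalg::add(out.data(), in1.data(), in2.data(), len, stream);

  // out[i] = in1[i] + 5.0f; internally dispatched through detail::adds_scalar.
  raft::linalg::addScalar(out.data(), in1.data(), 5.0f, len, stream);
}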
54 changes: 54 additions & 0 deletions cpp/include/raft/linalg/binary_op.hpp
@@ -0,0 +1,54 @@
+/*
+ * Copyright (c) 2018-2020, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "detail/binary_op.cuh"
+
+#include <raft/cuda_utils.cuh>
+
+namespace raft {
+namespace linalg {
+
+/**
+ * @brief perform element-wise binary operation on the input arrays
+ * @tparam InType input data-type
+ * @tparam Lambda the device-lambda performing the actual operation
+ * @tparam OutType output data-type
+ * @tparam IdxType Integer type used for addressing
+ * @tparam TPB threads-per-block in the final kernel launched
+ * @param out the output array
+ * @param in1 the first input array
+ * @param in2 the second input array
+ * @param len number of elements in the input array
+ * @param op the device-lambda
+ * @param stream cuda stream where to launch work
+ * @note Lambda must be a functor with the following signature:
+ *       `OutType func(const InType& val1, const InType& val2);`
+ */
+template <typename InType,
+          typename Lambda,
+          typename OutType = InType,
+          typename IdxType = int,
+          int TPB = 256>
+void binaryOp(
+  OutType* out, const InType* in1, const InType* in2, IdxType len, Lambda op, cudaStream_t stream)
+{
+  detail::binaryOp(out, in1, in2, len, op, stream);
+}
+
+};  // end namespace linalg
+};  // end namespace raft
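binaryOp is now a thin public wrapper over detail::binaryOp, so callers pass a device functor or lambda matching the documented signature. A short sketch, assuming nvcc with extended-lambda support; the function name and sizes are illustrative:

#include <raft/linalg/binary_op.hpp>
#include <rmm/device_uvector.hpp>
#include <cuda_runtime_api.h>

// Compile as a CUDA translation unit (.cu) with --extended-lambda.
void binary_op_example(cudaStream_t stream)
{
  const int len = 1024;
  rmm::device_uvector<float> in1(len, stream), in2(len, stream), out(len, stream);

  // Matches the documented contract: OutType func(const InType& val1, const InType& val2);
  auto op = [] __device__(const float& a, const float& b) { return a * b + 1.0f; };
  raft::linalg::binaryOp(out.data(), in1.data(), in2.data(), len, op, stream);
}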
cpp/include/raft/linalg/cholesky_r1_update.hpp
@@ -16,11 +16,7 @@

 #pragma once

-#include <raft/linalg/cublas_wrappers.h>
-#include <raft/linalg/cusolver_wrappers.h>
-#include <raft/cuda_utils.cuh>
-#include <raft/handle.hpp>
-#include <raft/linalg/binary_op.cuh>
+#include "detail/cholesky_r1_update.hpp"

 namespace raft {
 namespace linalg {
@@ -132,94 +128,7 @@ void choleskyRank1Update(const raft::handle_t& handle,
                          cudaStream_t stream,
                          math_t eps = -1)
 {
-  // The matrix A' is defined as:
-  // A' = [[A_11, A_12]
-  //       [A_21, A_22]]
-  // where:
-  // - A_11 = A, matrix of size (n-1)x(n-1)
-  // - A_21[j] = A_12.T[j] = A_new[j] j=0..n-2, vector with (n-1) elements
-  // - A_22 = A_new[n-1] scalar.
-  //
-  // Instead of calculating the Cholesky decomposition of A' from scratch,
-  // we just update L with the new row. The new Cholesky decomposition will be
-  // calculated as:
-  // L' = [[L_11, 0]
-  //       [L_12, L_22]]
-  // where L_11 is the Cholesky decomposition of A (size [n-1 x n-1]), and
-  // L_12 and L_22 are the new quantities that we need to calculate.
-
-  // We need a workspace in device memory to store a scalar. Additionally, in
-  // CUBLAS_FILL_MODE_LOWER we need space for n-1 floats.
-  const int align = 256;
-  int offset =
-    (uplo == CUBLAS_FILL_MODE_LOWER) ? raft::alignTo<int>(sizeof(math_t) * (n - 1), align) : 0;
-  if (workspace == nullptr) {
-    *n_bytes = offset + 1 * sizeof(math_t);
-    return;
-  }
-  math_t* s    = reinterpret_cast<math_t*>(((char*)workspace) + offset);
-  math_t* L_22 = L + (n - 1) * ld + n - 1;
-
-  math_t* A_new;
-  math_t* A_row;
-  if (uplo == CUBLAS_FILL_MODE_UPPER) {
-    // A_new is stored as the n-1 th column of L
-    A_new = L + (n - 1) * ld;
-  } else {
-    // If the input is lower triangular, then the new elements of A are stored
-    // as the n-th row of L. Since the matrix is column major, this is non
-    // contiguous. We copy elements from A_row to a contiguous workspace A_new.
-    A_row = L + n - 1;
-    A_new = reinterpret_cast<math_t*>(workspace);
-    RAFT_CUBLAS_TRY(
-      raft::linalg::cublasCopy(handle.get_cublas_handle(), n - 1, A_row, ld, A_new, 1, stream));
-  }
-  cublasOperation_t op = (uplo == CUBLAS_FILL_MODE_UPPER) ? CUBLAS_OP_T : CUBLAS_OP_N;
-  if (n > 1) {
-    // Calculate L_12 = x by solving equation L_11 x = A_12
-    math_t alpha = 1;
-    RAFT_CUBLAS_TRY(raft::linalg::cublastrsm(handle.get_cublas_handle(),
-                                             CUBLAS_SIDE_LEFT,
-                                             uplo,
-                                             op,
-                                             CUBLAS_DIAG_NON_UNIT,
-                                             n - 1,
-                                             1,
-                                             &alpha,
-                                             L,
-                                             ld,
-                                             A_new,
-                                             n - 1,
-                                             stream));
-
-    // A_new now stores L_12, we calculate s = L_12 * L_12
-    RAFT_CUBLAS_TRY(
-      raft::linalg::cublasdot(handle.get_cublas_handle(), n - 1, A_new, 1, A_new, 1, s, stream));
-
-    if (uplo == CUBLAS_FILL_MODE_LOWER) {
-      // Copy back the L_12 elements as the n-th row of L
-      RAFT_CUBLAS_TRY(
-        raft::linalg::cublasCopy(handle.get_cublas_handle(), n - 1, A_new, 1, A_row, ld, stream));
-    }
-  } else {  // n == 1 case
-    RAFT_CUDA_TRY(cudaMemsetAsync(s, 0, sizeof(math_t), stream));
-  }
-
-  // L_22 = sqrt(A_22 - L_12 * L_12)
-  math_t s_host;
-  math_t L_22_host;
-  raft::update_host(&s_host, s, 1, stream);
-  raft::update_host(&L_22_host, L_22, 1, stream);  // L_22 stores A_22
-  RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
-  L_22_host = std::sqrt(L_22_host - s_host);
-
-  // Check for numeric error with sqrt. If the matrix is not positive definite or
-  // the system is very ill conditioned then the A_22 - L_12 * L_12 can be
-  // negative, which would result in L_22 = NaN. A small positive eps parameter
-  // can be used to prevent this.
-  if (eps >= 0 && (std::isnan(L_22_host) || L_22_host < eps)) { L_22_host = eps; }
-  ASSERT(!std::isnan(L_22_host), "Error during Cholesky rank one update");
-  raft::update_device(L_22, &L_22_host, 1, stream);
+  detail::choleskyRank1Update(handle, L, n, ld, workspace, n_bytes, uplo, stream, eps);
 }
 }; // namespace linalg
 }; // namespace raft
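The removed body also documents the calling convention that the public wrapper keeps: when workspace == nullptr, the function only reports the required scratch size through *n_bytes. A sketch of that two-phase pattern, with the function and variable names chosen for illustration:

#include <raft/handle.hpp>
#include <raft/linalg/cholesky_r1_update.hpp>
#include <rmm/device_buffer.hpp>
#include <cublas_v2.h>

// L is an n x n column-major buffer (leading dimension ld) whose top-left
// (n-1) x (n-1) block already holds a Cholesky factor, and whose n-th
// row/column holds the new entries of A.
void rank1_update_example(
  const raft::handle_t& handle, float* L, int n, int ld, cudaStream_t stream)
{
  int n_bytes = 0;
  // Pass 1: workspace == nullptr, so only the required size is written.
  raft::linalg::choleskyRank1Update(
    handle, L, n, ld, nullptr, &n_bytes, CUBLAS_FILL_MODE_LOWER, stream);

  rmm::device_buffer workspace(n_bytes, stream);
  // Pass 2: solves L_11 x = A_12, then sets L_22 = sqrt(A_22 - x.x).
  raft::linalg::choleskyRank1Update(
    handle, L, n, ld, workspace.data(), &n_bytes, CUBLAS_FILL_MODE_LOWER, stream);
}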