Skip to content

Commit

Permalink
[FEA] Masked NN for connect_components (#1445)
Browse files Browse the repository at this point in the history
Replace fused L2 Nearest Neighbors in `connect_components` with masked NN.
Closes #743
Closes #1569

Authors:
  - Tarang Jain (https://github.com/tarang-jain)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)
  - Allard Hendriksen (https://github.com/ahendriksen)

URL: #1445
  • Loading branch information
tarang-jain authored Jul 25, 2023
1 parent 202385b commit a66c3a3
Show file tree
Hide file tree
Showing 17 changed files with 1,811 additions and 461 deletions.
20 changes: 16 additions & 4 deletions cpp/include/raft/cluster/detail/mst.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
#include <raft/util/cuda_utils.cuh>
#include <raft/util/cudart_utils.hpp>

#include <raft/sparse/neighbors/connect_components.cuh>
#include <raft/sparse/neighbors/cross_component_nn.cuh>
#include <raft/sparse/op/sort.cuh>
#include <raft/sparse/solver/mst.cuh>
#include <rmm/device_uvector.hpp>
Expand Down Expand Up @@ -81,8 +81,20 @@ void connect_knn_graph(

raft::sparse::COO<value_t, value_idx> connected_edges(stream);

raft::sparse::neighbors::connect_components<value_idx, value_t>(
handle, connected_edges, X, color, m, n, reduction_op);
// default row and column batch sizes are chosen for computing cross component nearest neighbors.
// Reference: PR #1445
static constexpr size_t default_row_batch_size = 4096;
static constexpr size_t default_col_batch_size = 16;

raft::sparse::neighbors::cross_component_nn<value_idx, value_t>(handle,
connected_edges,
X,
color,
m,
n,
reduction_op,
min(m, default_row_batch_size),
min(n, default_col_batch_size));

rmm::device_uvector<value_idx> indptr2(m + 1, stream);
raft::sparse::convert::sorted_coo_to_csr(
Expand Down Expand Up @@ -192,4 +204,4 @@ void build_sorted_mst(
raft::copy_async(mst_weight, mst_coo.weights.data(), mst_coo.n_edges, stream);
}

}; // namespace raft::cluster::detail
}; // namespace raft::cluster::detail
2 changes: 1 addition & 1 deletion cpp/include/raft/cluster/detail/single_linkage.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ void single_linkage(raft::resources const& handle,
* 2. Construct MST, sorted by weights
*/
rmm::device_uvector<value_idx> color(m, stream);
raft::sparse::neighbors::FixConnectivitiesRedOp<value_idx, value_t> op(color.data(), m);
raft::sparse::neighbors::FixConnectivitiesRedOp<value_idx, value_t> op(m);
detail::build_sorted_mst<value_idx, value_t>(handle,
X,
indptr.data(),
Expand Down
12 changes: 2 additions & 10 deletions cpp/include/raft/matrix/detail/gather.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#pragma once

#include <functional>
#include <raft/core/operators.hpp>
#include <raft/util/cudart_utils.hpp>

Expand Down Expand Up @@ -135,16 +136,6 @@ void gatherImpl(const InputIteratorT in,
// stencil value type
typedef typename std::iterator_traits<StencilIteratorT>::value_type StencilValueT;

// return type of MapTransformOp, must be convertible to IndexT
typedef typename std::result_of<decltype(transform_op)(MapValueT)>::type MapTransformOpReturnT;
static_assert((std::is_convertible<MapTransformOpReturnT, IndexT>::value),
"MapTransformOp's result type must be convertible to signed integer");

// return type of UnaryPredicateOp, must be convertible to bool
typedef typename std::result_of<decltype(pred_op)(StencilValueT)>::type PredicateOpReturnT;
static_assert((std::is_convertible<PredicateOpReturnT, bool>::value),
"UnaryPredicateOp's result type must be convertible to bool type");

IndexT len = map_length * D;
constexpr int TPB = 128;
const int n_sm = raft::getMultiProcessorCount();
Expand Down Expand Up @@ -343,6 +334,7 @@ void gather_if(const InputIteratorT in,
typedef typename std::iterator_traits<MapIteratorT>::value_type MapValueT;
gatherImpl(in, D, N, map, stencil, map_length, out, pred_op, transform_op, stream);
}

} // namespace detail
} // namespace matrix
} // namespace raft
116 changes: 116 additions & 0 deletions cpp/include/raft/matrix/detail/gather_inplace.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <raft/core/device_mdarray.hpp>
#include <raft/core/resource/thrust_policy.hpp>
#include <raft/linalg/map.cuh>
#include <raft/util/fast_int_div.cuh>
#include <thrust/iterator/counting_iterator.h>

namespace raft {
namespace matrix {
namespace detail {

template <typename MatrixT, typename MapT, typename MapTransformOp, typename IndexT>
void gatherInplaceImpl(raft::resources const& handle,
raft::device_matrix_view<MatrixT, IndexT, raft::layout_c_contiguous> inout,
raft::device_vector_view<const MapT, IndexT, raft::layout_c_contiguous> map,
MapTransformOp transform_op,
IndexT batch_size)
{
IndexT m = inout.extent(0);
IndexT n = inout.extent(1);
IndexT map_length = map.extent(0);

// skip in case of 0 length input
if (map_length <= 0 || m <= 0 || n <= 0 || batch_size < 0) return;

RAFT_EXPECTS(map_length <= m, "Length of map should be <= number of rows for inplace gather");

RAFT_EXPECTS(batch_size >= 0, "batch size should be >= 0");

// re-assign batch_size for default case
if (batch_size == 0 || batch_size > n) batch_size = n;

auto exec_policy = resource::get_thrust_policy(handle);

IndexT n_batches = raft::ceildiv(n, batch_size);

auto scratch_space = raft::make_device_vector<MatrixT, IndexT>(handle, map_length * batch_size);

for (IndexT bid = 0; bid < n_batches; bid++) {
IndexT batch_offset = bid * batch_size;
IndexT cols_per_batch = min(batch_size, n - batch_offset);

auto gather_op = [inout = inout.data_handle(),
map = map.data_handle(),
transform_op,
batch_offset,
map_length,
cols_per_batch = raft::util::FastIntDiv(cols_per_batch),
n] __device__(auto idx) {
IndexT row = idx / cols_per_batch;
IndexT col = idx % cols_per_batch;
MapT map_val = map[row];

IndexT i_src = transform_op(map_val);
return inout[i_src * n + batch_offset + col];
};
raft::linalg::map_offset(
handle,
raft::make_device_vector_view(scratch_space.data_handle(), map_length * cols_per_batch),
gather_op);

auto copy_op = [inout = inout.data_handle(),
map = map.data_handle(),
scratch_space = scratch_space.data_handle(),
batch_offset,
map_length,
cols_per_batch = raft::util::FastIntDiv(cols_per_batch),
n] __device__(auto idx) {
IndexT row = idx / cols_per_batch;
IndexT col = idx % cols_per_batch;
inout[row * n + batch_offset + col] = scratch_space[idx];
return;
};
auto counting = thrust::make_counting_iterator<IndexT>(0);
thrust::for_each(exec_policy, counting, counting + map_length * cols_per_batch, copy_op);
}
}

template <typename MatrixT, typename MapT, typename MapTransformOp, typename IndexT>
void gather(raft::resources const& handle,
raft::device_matrix_view<MatrixT, IndexT, raft::layout_c_contiguous> inout,
raft::device_vector_view<const MapT, IndexT, raft::layout_c_contiguous> map,
MapTransformOp transform_op,
IndexT batch_size)
{
gatherInplaceImpl(handle, inout, map, transform_op, batch_size);
}

template <typename MatrixT, typename MapT, typename IndexT>
void gather(raft::resources const& handle,
raft::device_matrix_view<MatrixT, IndexT, raft::layout_c_contiguous> inout,
raft::device_vector_view<const MapT, IndexT, raft::layout_c_contiguous> map,
IndexT batch_size)
{
gatherInplaceImpl(handle, inout, map, raft::identity_op(), batch_size);
}

} // namespace detail
} // namespace matrix
} // namespace raft
127 changes: 127 additions & 0 deletions cpp/include/raft/matrix/detail/scatter_inplace.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <cstdint>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/resource/thrust_policy.hpp>
#include <raft/linalg/map.cuh>
#include <raft/util/cuda_dev_essentials.cuh>
#include <raft/util/fast_int_div.cuh>
#include <thrust/iterator/counting_iterator.h>

namespace raft {
namespace matrix {
namespace detail {

/**
* @brief In-place scatter elements in a row-major matrix according to a
* map. The length of the map is equal to the number of rows. The
* map specifies the destination index for each row, i.e. in the
* resulting matrix, row map[i] is assigned to row i. For example,
* the matrix [[1, 2, 3], [4, 5, 6], [7, 8, 9]] with the map [2, 0, 1] will
* be transformed to [[4, 5, 6], [7, 8, 9], [1, 2, 3]]. Batching is done on
* columns and an additional scratch space of shape n_rows * cols_batch_size
* is created. For each batch, chunks of columns from each row are copied
* into the appropriate location in the scratch space and copied back to
* the corresponding locations in the input matrix.
*
* @tparam InputIteratorT
* @tparam MapIteratorT
* @tparam IndexT
*
* @param[inout] handle raft handle
* @param[inout] inout input matrix (n_rows * n_cols)
* @param[inout] map map containing the destination index for each row (n_rows)
* @param[inout] batch_size column batch size
*/

template <typename MatrixT, typename IndexT>
void scatterInplaceImpl(
raft::resources const& handle,
raft::device_matrix_view<MatrixT, IndexT, raft::layout_c_contiguous> inout,
raft::device_vector_view<const IndexT, IndexT, raft::layout_c_contiguous> map,
IndexT batch_size)
{
IndexT m = inout.extent(0);
IndexT n = inout.extent(1);
IndexT map_length = map.extent(0);

// skip in case of 0 length input
if (map_length <= 0 || m <= 0 || n <= 0 || batch_size < 0) return;

RAFT_EXPECTS(map_length == m,
"Length of map should be equal to number of rows for inplace scatter");

RAFT_EXPECTS(batch_size >= 0, "batch size should be >= 0");

// re-assign batch_size for default case
if (batch_size == 0 || batch_size > n) batch_size = n;

auto exec_policy = resource::get_thrust_policy(handle);

IndexT n_batches = raft::ceildiv(n, batch_size);

auto scratch_space = raft::make_device_vector<MatrixT, IndexT>(handle, m * batch_size);

for (IndexT bid = 0; bid < n_batches; bid++) {
IndexT batch_offset = bid * batch_size;
IndexT cols_per_batch = min(batch_size, n - batch_offset);

auto copy_op = [inout = inout.data_handle(),
map = map.data_handle(),
batch_offset,
cols_per_batch = raft::util::FastIntDiv(cols_per_batch),
n] __device__(auto idx) {
IndexT row = idx / cols_per_batch;
IndexT col = idx % cols_per_batch;
return inout[row * n + batch_offset + col];
};
raft::linalg::map_offset(
handle,
raft::make_device_vector_view(scratch_space.data_handle(), m * cols_per_batch),
copy_op);

auto scatter_op = [inout = inout.data_handle(),
map = map.data_handle(),
scratch_space = scratch_space.data_handle(),
batch_offset,
cols_per_batch = raft::util::FastIntDiv(cols_per_batch),
n] __device__(auto idx) {
IndexT row = idx / cols_per_batch;
IndexT col = idx % cols_per_batch;
IndexT map_val = map[row];

inout[map_val * n + batch_offset + col] = scratch_space[idx];
return;
};
auto counting = thrust::make_counting_iterator<IndexT>(0);
thrust::for_each(exec_policy, counting, counting + m * cols_per_batch, scatter_op);
}
}

template <typename MatrixT, typename IndexT>
void scatter(raft::resources const& handle,
raft::device_matrix_view<MatrixT, IndexT, raft::layout_c_contiguous> inout,
raft::device_vector_view<const IndexT, IndexT, raft::layout_c_contiguous> map,
IndexT batch_size)
{
scatterInplaceImpl(handle, inout, map, batch_size);
}

} // end namespace detail
} // end namespace matrix
} // end namespace raft
41 changes: 41 additions & 0 deletions cpp/include/raft/matrix/gather.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resources.hpp>
#include <raft/matrix/detail/gather.cuh>
#include <raft/matrix/detail/gather_inplace.cuh>
#include <raft/util/itertools.hpp>

namespace raft::matrix {
Expand Down Expand Up @@ -289,6 +290,46 @@ void gather_if(const raft::resources& handle,
resource::get_cuda_stream(handle));
}

/**
* @brief In-place gather elements in a row-major matrix according to a
* map. The map specifies the new order in which rows of the input matrix are
* rearranged, i.e. for each output row, read the index in the input matrix
* from the map, apply a transformation to this input index if specified, and copy the row.
* map[i]. For example, the matrix [[1, 2, 3], [4, 5, 6], [7, 8, 9]] with the
* map [2, 0, 1] will be transformed to [[7, 8, 9], [1, 2, 3], [4, 5, 6]].
* Batching is done on columns and an additional scratch space of
* shape n_rows * cols_batch_size is created. For each batch, chunks
* of columns from each row are copied into the appropriate location
* in the scratch space and copied back to the corresponding locations
* in the input matrix.
*
* @tparam matrix_t Matrix element type
* @tparam map_t Integer type of map elements
* @tparam map_xform_t Unary lambda expression or operator type. MapTransformOp's result type must
* be convertible to idx_t.
* @tparam idx_t Integer type used for indexing
*
* @param[in] handle raft handle
* @param[inout] inout input matrix (n_rows * n_cols)
* @param[in] map Pointer to the input sequence of gather locations
* @param[in] col_batch_size (optional) column batch size. Determines the shape of the scratch space
* (map_length, col_batch_size). When set to zero (default), no batching is done and an additional
* scratch space of shape (map_lengthm, n_cols) is created.
* @param[in] transform_op (optional) Transformation to apply to map values
*/
template <typename matrix_t,
typename map_t,
typename idx_t,
typename map_xform_t = raft::identity_op>
void gather(raft::resources const& handle,
raft::device_matrix_view<matrix_t, idx_t, raft::layout_c_contiguous> inout,
raft::device_vector_view<const map_t, idx_t, raft::layout_c_contiguous> map,
idx_t col_batch_size = 0,
map_xform_t transform_op = raft::identity_op())
{
detail::gather(handle, inout, map, transform_op, col_batch_size);
}

/** @} */ // end of group matrix_gather

} // namespace raft::matrix
Loading

0 comments on commit a66c3a3

Please sign in to comment.