Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEA] Masked NN for connect_components #1445

Merged
Merged
Show file tree
Hide file tree
Changes from 28 commits
Commits
Show all changes
88 commits
Select commit Hold shift + click to select a range
02189e8
first commit
tarang-jain Mar 23, 2023
51f1c39
Merge branch 'branch-23.04' of https://github.com/tarang-jain/raft in…
tarang-jain Mar 23, 2023
7c86e6e
more changes
tarang-jain Mar 30, 2023
d0c795c
Merge branch 'branch-23.06' of https://github.com/rapidsai/raft into …
tarang-jain Apr 19, 2023
27210b9
Merge branch 'branch-23.06' of https://github.com/rapidsai/raft into …
tarang-jain Apr 20, 2023
b26cda2
sorting impl
tarang-jain Apr 20, 2023
e48c486
Gather function
tarang-jain Apr 21, 2023
89caefb
Updated with batch
tarang-jain Apr 22, 2023
3b0e186
Merge branch 'branch-23.06' of https://github.com/rapidsai/raft into …
tarang-jain Apr 24, 2023
4ef8731
Working impl
tarang-jain Apr 25, 2023
aebbb3f
Remove debug
tarang-jain Apr 26, 2023
aa699d0
Merge branch 'branch-23.06' of https://github.com/rapidsai/raft into …
tarang-jain Apr 26, 2023
6dbd6a4
change api
tarang-jain Apr 27, 2023
435e408
Merge branch 'branch-23.06' of https://github.com/rapidsai/raft into …
tarang-jain Apr 27, 2023
7ea89c2
Merge branch 'branch-23.06' of https://github.com/rapidsai/raft into …
tarang-jain May 3, 2023
681da7a
Benchmarking
tarang-jain May 3, 2023
81a3d60
Remove nvtx
tarang-jain May 3, 2023
d7aec0b
bm
tarang-jain May 5, 2023
8949881
Row batch_size
tarang-jain May 11, 2023
a51966a
Merge branch 'branch-23.06' of https://github.com/rapidsai/raft into …
tarang-jain May 11, 2023
9ecd782
Changes after PR Reviews
tarang-jain May 12, 2023
b371c63
Merge branch 'branch-23.06' of https://github.com/rapidsai/raft into …
tarang-jain May 12, 2023
3b99d35
Some updates after new PR Reviews
tarang-jain May 16, 2023
8e99668
Merge branch 'branch-23.06' of https://github.com/rapidsai/raft into …
tarang-jain May 16, 2023
9c7dcef
Docstring change
tarang-jain May 16, 2023
e050d4c
Styling changes
tarang-jain May 16, 2023
57e081a
remove device resources
tarang-jain May 16, 2023
0323a95
Resolve merge conflicts
tarang-jain May 16, 2023
2162de3
Some changes after PR reviews
tarang-jain May 18, 2023
c0fa543
Merge branch 'branch-23.06' of https://github.com/rapidsai/raft into …
tarang-jain May 18, 2023
1385573
rbug fixes
tarang-jain May 19, 2023
0aee3d5
Merge branch 'branch-23.06' of https://github.com/rapidsai/raft into …
tarang-jain May 19, 2023
a7ba987
Bug fixes
tarang-jain May 19, 2023
10d5d9d
Remove unnecessary imports
tarang-jain May 19, 2023
0f73a8c
Merge branch 'branch-23.06' of https://github.com/rapidsai/raft into …
tarang-jain May 20, 2023
76031d1
Refactor based on Allard's comments
tarang-jain May 21, 2023
87c62a6
Debugging differences between fused and masked
tarang-jain May 22, 2023
55d6d49
Merge branch 'branch-23.06' of https://github.com/rapidsai/raft into …
tarang-jain May 22, 2023
f34f0bc
cleanup
tarang-jain May 22, 2023
e0b4118
Update copyright
tarang-jain May 22, 2023
116ee3d
Merge branch 'branch-23.06' of https://github.com/rapidsai/raft into …
tarang-jain May 23, 2023
f1b3bf4
Working gtest
tarang-jain May 24, 2023
c433d49
scatter gtest and refactoring
tarang-jain May 26, 2023
81ca1fb
bug free
tarang-jain Jun 3, 2023
d0eb626
Merge branch 'branch-23.08' of https://github.com/rapidsai/raft into …
tarang-jain Jun 3, 2023
1e22636
style changes
tarang-jain Jun 5, 2023
e3dbdb0
Remove unnecessary imports
tarang-jain Jun 5, 2023
12a2f5c
Merge branch 'branch-23.08' of https://github.com/rapidsai/raft into …
tarang-jain Jun 5, 2023
85abf92
Merge branch 'branch-23.08' into update-connected-components
cjnolet Jun 6, 2023
7120e58
merge
tarang-jain Jun 7, 2023
f53ad4b
Merge branch 'update-connected-components' of https://github.com/tara…
tarang-jain Jun 7, 2023
fefaac2
Merge branch 'branch-23.08' of https://github.com/rapidsai/raft into …
tarang-jain Jun 8, 2023
6ae1081
some updates after pr reviews
tarang-jain Jun 9, 2023
d178104
Merge branch 'branch-23.08' of https://github.com/rapidsai/raft into …
tarang-jain Jun 9, 2023
b7be24c
Updates after PR reviews
tarang-jain Jun 13, 2023
2aca710
Merge branch 'branch-23.08' of https://github.com/rapidsai/raft into …
tarang-jain Jun 13, 2023
7f4f9f3
Updates after PR reviews
tarang-jain Jun 13, 2023
be74d60
nit
tarang-jain Jun 13, 2023
b94844b
Resolve typos
tarang-jain Jun 14, 2023
8df0e00
add libcudacxx dependency
tarang-jain Jun 20, 2023
e3121c5
Update with libcudacxx type_traits header
tarang-jain Jun 21, 2023
d929378
Update todo
tarang-jain Jun 21, 2023
a0169c3
add proclaim_return_type to predicate
tarang-jain Jun 21, 2023
69e393a
remove libcudacxx dependency and rename api
tarang-jain Jun 22, 2023
c4958ac
Merge branch 'branch-23.08' of https://github.com/rapidsai/raft into …
tarang-jain Jun 22, 2023
d95350f
updates
tarang-jain Jun 23, 2023
2a0a491
dbg
tarang-jain Jun 23, 2023
6ef2b92
fix failing test
tarang-jain Jun 23, 2023
c600728
Merge branch 'branch-23.08' of https://github.com/rapidsai/raft into …
tarang-jain Jun 23, 2023
cad9b0e
Link issue
tarang-jain Jun 23, 2023
c02d67b
fix docs
tarang-jain Jun 23, 2023
43778ae
Merge branch 'branch-23.08' of https://github.com/rapidsai/raft into …
tarang-jain Jul 17, 2023
f714918
doc fix
tarang-jain Jul 17, 2023
17a79a6
Added more tests for gather
tarang-jain Jul 17, 2023
a61df96
Scatter tests
tarang-jain Jul 17, 2023
91ee0c7
Update tests, reenaming API headers
tarang-jain Jul 20, 2023
0efa56e
Merge branch 'branch-23.08' of https://github.com/rapidsai/raft into …
tarang-jain Jul 20, 2023
5bfbcb6
Merge branch 'branch-23.08' into update-connected-components
tarang-jain Jul 20, 2023
9990062
Stylegi
tarang-jain Jul 20, 2023
738fbbc
Merge branch 'update-connected-components' of https://github.com/tara…
tarang-jain Jul 20, 2023
6ea871d
remove EXPLICIT_INSTANTIATE
tarang-jain Jul 21, 2023
8d618a3
Merge branch 'branch-23.08' of https://github.com/rapidsai/raft into …
tarang-jain Jul 21, 2023
4437644
Merge branch 'branch-23.08' into update-connected-components
tarang-jain Jul 21, 2023
92b8697
remove todo
tarang-jain Jul 25, 2023
fd9fa62
Merge branch 'branch-23.08' of https://github.com/rapidsai/raft into …
tarang-jain Jul 25, 2023
7e24fb0
Merge branch 'update-connected-components' of https://github.com/tara…
tarang-jain Jul 25, 2023
03d9ef4
revert
tarang-jain Jul 25, 2023
2da5b8c
Update todo
tarang-jain Jul 25, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion cpp/bench/prims/distance/masked_nn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,6 @@ struct masked_l2_nn : public fixture {
dim3 block(32, 32);
dim3 grid(10, 10);
init_adj<<<grid, block, 0, stream>>>(p.pattern, p.n, adj.view(), group_idxs.view());

RAFT_CUDA_TRY(cudaGetLastError());
}

Expand Down
13 changes: 9 additions & 4 deletions cpp/include/raft/cluster/detail/mst.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
#include <raft/util/cuda_utils.cuh>
#include <raft/util/cudart_utils.hpp>

#include <raft/sparse/neighbors/connect_components.cuh>
#include <raft/sparse/neighbors/cross_component_nn.cuh>
#include <raft/sparse/op/sort.cuh>
#include <raft/sparse/solver/mst.cuh>
#include <rmm/device_uvector.hpp>
Expand Down Expand Up @@ -81,15 +81,20 @@ void connect_knn_graph(

raft::sparse::COO<value_t, value_idx> connected_edges(stream);

raft::sparse::neighbors::connect_components<value_idx, value_t>(handle,
// default row and column batch sizes are chosen for computing cross component nearest neighbors.
// Reference: PR #1445
static constexpr size_t default_row_batch_size = 4096;
static constexpr size_t default_col_batch_size = 16;

raft::sparse::neighbors::cross_component_nn<value_idx, value_t>(handle,
connected_edges,
X,
color,
m,
n,
reduction_op,
min(m, (size_t)4096),
min(n, (size_t)16));
min(m, default_row_batch_size),
min(n, default_col_batch_size));

rmm::device_uvector<value_idx> indptr2(m + 1, stream);
raft::sparse::convert::sorted_coo_to_csr(
Expand Down
2 changes: 1 addition & 1 deletion cpp/include/raft/cluster/detail/single_linkage.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ void single_linkage(raft::resources const& handle,
* 2. Construct MST, sorted by weights
*/
rmm::device_uvector<value_idx> color(m, stream);
raft::sparse::neighbors::FixConnectivitiesRedOp<value_idx, value_t> op(color.data(), m);
raft::sparse::neighbors::FixConnectivitiesRedOp<value_idx, value_t> op(m);
detail::build_sorted_mst<value_idx, value_t>(handle,
X,
indptr.data(),
Expand Down
92 changes: 0 additions & 92 deletions cpp/include/raft/matrix/batched_rearrange.cuh
tarang-jain marked this conversation as resolved.
Outdated
Show resolved Hide resolved

This file was deleted.

103 changes: 1 addition & 102 deletions cpp/include/raft/matrix/detail/gather.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,9 @@

#pragma once

#include <raft/core/device_mdarray.hpp>
#include <functional>
#include <raft/core/operators.hpp>
#include <raft/core/resource/thrust_policy.hpp>
#include <raft/linalg/map.cuh>
#include <raft/util/cuda_dev_essentials.cuh>
#include <raft/util/cudart_utils.hpp>
#include <raft/util/fast_int_div.cuh>
#include <thrust/iterator/counting_iterator.h>

namespace raft {
namespace matrix {
Expand Down Expand Up @@ -141,16 +136,6 @@ void gatherImpl(const InputIteratorT in,
// stencil value type
typedef typename std::iterator_traits<StencilIteratorT>::value_type StencilValueT;

// return type of MapTransformOp, must be convertible to IndexT
typedef typename std::result_of<decltype(transform_op)(MapValueT)>::type MapTransformOpReturnT;
static_assert((std::is_convertible<MapTransformOpReturnT, IndexT>::value),
"MapTransformOp's result type must be convertible to signed integer");

// return type of UnaryPredicateOp, must be convertible to bool
typedef typename std::result_of<decltype(pred_op)(StencilValueT)>::type PredicateOpReturnT;
static_assert((std::is_convertible<PredicateOpReturnT, bool>::value),
"UnaryPredicateOp's result type must be convertible to bool type");

IndexT len = map_length * D;
constexpr int TPB = 128;
const int n_sm = raft::getMultiProcessorCount();
Expand Down Expand Up @@ -350,92 +335,6 @@ void gather_if(const InputIteratorT in,
gatherImpl(in, D, N, map, stencil, map_length, out, pred_op, transform_op, stream);
}

template <typename MatrixT, typename MapT, typename MapTransformOp, typename IndexT>
void gatherInplaceImpl(raft::resources const& handle,
raft::device_matrix_view<MatrixT, IndexT, raft::layout_c_contiguous> inout,
raft::device_vector_view<const MapT, IndexT, raft::layout_c_contiguous> map,
MapTransformOp transform_op,
IndexT batch_size)
{
// return type of MapTransformOp, must be convertible to IndexT
typedef typename std::result_of<decltype(transform_op)(MapT)>::type MapTransformOpReturnT;
RAFT_EXPECTS((std::is_convertible<MapTransformOpReturnT, IndexT>::value),
"MapTransformOp's result type must be convertible to signed integer");

IndexT m = inout.extent(0);
IndexT n = inout.extent(1);
IndexT map_length = map.extent(0);

// skip in case of 0 length input
if (map_length <= 0 || m <= 0 || n <= 0 || batch_size < 0) return;

RAFT_EXPECTS(map_length <= m, "Length of map should be <= number of rows for inplace gather");

// re-assign batch_size for default case
if (batch_size == 0) batch_size = n;

RAFT_EXPECTS(batch_size <= n, "batch size should be <= number of columns");

auto exec_policy = resource::get_thrust_policy(handle);
IndexT n_batches = raft::ceildiv(n, batch_size);
for (IndexT bid = 0; bid < n_batches; bid++) {
IndexT batch_offset = bid * batch_size;
IndexT cols_per_batch = min(batch_size, n - batch_offset);
auto scratch_space =
raft::make_device_vector<MatrixT, IndexT>(handle, map_length * cols_per_batch);

auto gather_op = [inout = inout.data_handle(),
map = map.data_handle(),
transform_op,
batch_offset,
map_length,
cols_per_batch = raft::util::FastIntDiv(cols_per_batch),
n] __device__(auto idx) {
IndexT row = idx / cols_per_batch;
IndexT col = idx % cols_per_batch;
MapT map_val = map[row];

IndexT i_src = transform_op(map_val);
return inout[i_src * n + batch_offset + col];
};
raft::linalg::map_offset(handle, scratch_space.view(), gather_op);

auto copy_op = [inout = inout.data_handle(),
map = map.data_handle(),
scratch_space = scratch_space.data_handle(),
batch_offset,
map_length,
cols_per_batch = raft::util::FastIntDiv(cols_per_batch),
n] __device__(auto idx) {
IndexT row = idx / cols_per_batch;
IndexT col = idx % cols_per_batch;
inout[row * n + batch_offset + col] = scratch_space[idx];
return;
};
auto counting = thrust::make_counting_iterator<IndexT>(0);
thrust::for_each(exec_policy, counting, counting + map_length * cols_per_batch, copy_op);
}
}

template <typename MatrixT, typename MapT, typename MapTransformOp, typename IndexT>
void gather(raft::resources const& handle,
raft::device_matrix_view<MatrixT, IndexT, raft::layout_c_contiguous> inout,
raft::device_vector_view<const MapT, IndexT, raft::layout_c_contiguous> map,
MapTransformOp transform_op,
IndexT batch_size)
{
gatherInplaceImpl(handle, inout, map, transform_op, batch_size);
}

template <typename MatrixT, typename MapT, typename IndexT>
void gather(raft::resources const& handle,
raft::device_matrix_view<MatrixT, IndexT, raft::layout_c_contiguous> inout,
raft::device_vector_view<const MapT, IndexT, raft::layout_c_contiguous> map,
IndexT batch_size)
{
gatherInplaceImpl(handle, inout, map, raft::identity_op(), batch_size);
}

} // namespace detail
} // namespace matrix
} // namespace raft
116 changes: 116 additions & 0 deletions cpp/include/raft/matrix/detail/gather_inplace.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <raft/core/device_mdarray.hpp>
#include <raft/core/resource/thrust_policy.hpp>
#include <raft/linalg/map.cuh>
#include <raft/util/fast_int_div.cuh>
#include <thrust/iterator/counting_iterator.h>

namespace raft {
namespace matrix {
namespace detail {

template <typename MatrixT, typename MapT, typename MapTransformOp, typename IndexT>
void gatherInplaceImpl(raft::resources const& handle,
raft::device_matrix_view<MatrixT, IndexT, raft::layout_c_contiguous> inout,
raft::device_vector_view<const MapT, IndexT, raft::layout_c_contiguous> map,
MapTransformOp transform_op,
IndexT batch_size)
{
IndexT m = inout.extent(0);
IndexT n = inout.extent(1);
IndexT map_length = map.extent(0);

// skip in case of 0 length input
if (map_length <= 0 || m <= 0 || n <= 0 || batch_size < 0) return;

RAFT_EXPECTS(map_length <= m, "Length of map should be <= number of rows for inplace gather");

RAFT_EXPECTS(batch_size >= 0, "batch size should be >= 0");

// re-assign batch_size for default case
if (batch_size == 0 || batch_size > n) batch_size = n;

auto exec_policy = resource::get_thrust_policy(handle);

IndexT n_batches = raft::ceildiv(n, batch_size);

auto scratch_space = raft::make_device_vector<MatrixT, IndexT>(handle, map_length * batch_size);

for (IndexT bid = 0; bid < n_batches; bid++) {
IndexT batch_offset = bid * batch_size;
IndexT cols_per_batch = min(batch_size, n - batch_offset);

auto gather_op = [inout = inout.data_handle(),
map = map.data_handle(),
transform_op,
batch_offset,
map_length,
cols_per_batch = raft::util::FastIntDiv(cols_per_batch),
n] __device__(auto idx) {
IndexT row = idx / cols_per_batch;
IndexT col = idx % cols_per_batch;
MapT map_val = map[row];

IndexT i_src = transform_op(map_val);
return inout[i_src * n + batch_offset + col];
};
raft::linalg::map_offset(
handle,
raft::make_device_vector_view(scratch_space.data_handle(), map_length * cols_per_batch),
gather_op);

auto copy_op = [inout = inout.data_handle(),
map = map.data_handle(),
scratch_space = scratch_space.data_handle(),
batch_offset,
map_length,
cols_per_batch = raft::util::FastIntDiv(cols_per_batch),
n] __device__(auto idx) {
IndexT row = idx / cols_per_batch;
IndexT col = idx % cols_per_batch;
inout[row * n + batch_offset + col] = scratch_space[idx];
return;
};
auto counting = thrust::make_counting_iterator<IndexT>(0);
thrust::for_each(exec_policy, counting, counting + map_length * cols_per_batch, copy_op);
}
}

template <typename MatrixT, typename MapT, typename MapTransformOp, typename IndexT>
void gather(raft::resources const& handle,
raft::device_matrix_view<MatrixT, IndexT, raft::layout_c_contiguous> inout,
raft::device_vector_view<const MapT, IndexT, raft::layout_c_contiguous> map,
MapTransformOp transform_op,
IndexT batch_size)
{
gatherInplaceImpl(handle, inout, map, transform_op, batch_size);
}

template <typename MatrixT, typename MapT, typename IndexT>
void gather(raft::resources const& handle,
raft::device_matrix_view<MatrixT, IndexT, raft::layout_c_contiguous> inout,
raft::device_vector_view<const MapT, IndexT, raft::layout_c_contiguous> map,
IndexT batch_size)
{
gatherInplaceImpl(handle, inout, map, raft::identity_op(), batch_size);
}

} // namespace detail
} // namespace matrix
} // namespace raft
Loading