Skip to content

Commit

Permalink
[FEA] Update MST Reduction Op (#5386)
Browse files Browse the repository at this point in the history
Make HDBSCAN MST reduction operation compatible with [#1445](rapidsai/raft#1445) raft PR. This PR updates the Mutual Reachability reduction op in HDBSCAN in the following ways:
1. Colors information is no longer needed in the reduction op because the new `mst` implementation in raft ensures that distances between points of the same color are not computed.
2. Adds gather and scatter functions to rearrange the core distances within the reduction op so that they are aligned with the sort-plan wherein rows in the input matrix and core distances are rearranged so that the training data points are sorted by color.
Closes #5456

Authors:
  - Tarang Jain (https://github.com/tarang-jain)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Divye Gala (https://github.com/divyegala)
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #5386
  • Loading branch information
tarang-jain authored Jul 26, 2023
1 parent b2a781f commit cfa5020
Showing 1 changed file with 32 additions and 8 deletions.
40 changes: 32 additions & 8 deletions cpp/src/hdbscan/runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#pragma once

#include <raft/core/resource/thrust_policy.hpp>
#include <raft/util/cudart_utils.hpp>

#include <raft/core/handle.hpp>
Expand All @@ -36,6 +37,8 @@

#include <thrust/device_ptr.h>
#include <thrust/extrema.h>
#include <thrust/gather.h>
#include <thrust/scatter.h>
#include <thrust/transform.h>

namespace ML {
Expand All @@ -50,20 +53,17 @@ namespace HDBSCAN {
*/
template <typename value_idx, typename value_t>
struct FixConnectivitiesRedOp {
value_idx* colors;
value_t* core_dists;
value_idx m;

DI FixConnectivitiesRedOp() : colors(0), m(0) {}
DI FixConnectivitiesRedOp() : m(0) {}

FixConnectivitiesRedOp(value_idx* colors_, value_t* core_dists_, value_idx m_)
: colors(colors_), core_dists(core_dists_), m(m_){};
FixConnectivitiesRedOp(value_t* core_dists_, value_idx m_) : core_dists(core_dists_), m(m_){};

typedef typename raft::KeyValuePair<value_idx, value_t> KVP;
DI void operator()(value_idx rit, KVP* out, const KVP& other) const
{
if (rit < m && other.value < std::numeric_limits<value_t>::max() &&
colors[rit] != colors[other.key]) {
if (rit < m && other.value < std::numeric_limits<value_t>::max()) {
value_t core_dist_rit = core_dists[rit];
value_t core_dist_other = max(core_dist_rit, max(core_dists[other.key], other.value));

Expand All @@ -82,7 +82,7 @@ struct FixConnectivitiesRedOp {

DI KVP operator()(value_idx rit, const KVP& a, const KVP& b) const
{
if (rit < m && a.key > -1 && colors[rit] != colors[a.key]) {
if (rit < m && a.key > -1) {
value_t core_dist_rit = core_dists[rit];
value_t core_dist_a = max(core_dist_rit, max(core_dists[a.key], a.value));

Expand Down Expand Up @@ -111,6 +111,30 @@ struct FixConnectivitiesRedOp {

DI value_t get_value(KVP& out) const { return out.value; }
DI value_t get_value(value_t& out) const { return out; }

void gather(const raft::resources& handle, value_idx* map)
{
auto tmp_core_dists = raft::make_device_vector<value_t>(handle, m);
thrust::gather(raft::resource::get_thrust_policy(handle),
map,
map + m,
core_dists,
tmp_core_dists.data_handle());
raft::copy_async(
core_dists, tmp_core_dists.data_handle(), m, raft::resource::get_cuda_stream(handle));
}

void scatter(const raft::resources& handle, value_idx* map)
{
auto tmp_core_dists = raft::make_device_vector<value_t>(handle, m);
thrust::scatter(raft::resource::get_thrust_policy(handle),
core_dists,
core_dists + m,
map,
tmp_core_dists.data_handle());
raft::copy_async(
core_dists, tmp_core_dists.data_handle(), m, raft::resource::get_cuda_stream(handle));
}
};

/**
Expand Down Expand Up @@ -167,7 +191,7 @@ void build_linkage(const raft::handle_t& handle,
*/

rmm::device_uvector<value_idx> color(m, stream);
FixConnectivitiesRedOp<value_idx, value_t> red_op(color.data(), core_dists, m);
FixConnectivitiesRedOp<value_idx, value_t> red_op(core_dists, m);
// during knn graph connection
raft::cluster::detail::build_sorted_mst(handle,
X,
Expand Down

0 comments on commit cfa5020

Please sign in to comment.