Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose local scaling parameters in UMAP from fitting process to enable inverse_transform and improve CPU interoperability #6185

Draft
wants to merge 1 commit into
base: branch-25.02
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions cpp/include/cuml/manifold/umap.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ std::unique_ptr<raft::sparse::COO<float, int>> get_graph(const raft::handle_t& h
int d,
int64_t* knn_indices,
float* knn_dists,
float * sigmas,
float * rhos,
UMAPParams* params);

/**
Expand Down Expand Up @@ -128,6 +130,8 @@ void fit(const raft::handle_t& handle,
float* knn_dists,
UMAPParams* params,
float* embeddings,
float * sigmas,
float * rhos,
raft::sparse::COO<float, int>* graph);

/**
Expand Down Expand Up @@ -159,6 +163,8 @@ void fit_sparse(const raft::handle_t& handle,
float* knn_dists,
UMAPParams* params,
float* embeddings,
float * sigmas,
float * rhos,
raft::sparse::COO<float, int>* graph);

/**
Expand Down
22 changes: 12 additions & 10 deletions cpp/src/umap/fuzzy_simpl_set/naive.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -291,21 +291,23 @@ void launcher(int n,
int n_neighbors,
raft::sparse::COO<value_t>* out,
UMAPParams* params,
float * sigmas,
float * rhos,
cudaStream_t stream)
{
/**
* Calculate mean distance through a parallel reduction
*/
rmm::device_uvector<value_t> sigmas(n, stream);
rmm::device_uvector<value_t> rhos(n, stream);
RAFT_CUDA_TRY(cudaMemsetAsync(sigmas.data(), 0, n * sizeof(value_t), stream));
RAFT_CUDA_TRY(cudaMemsetAsync(rhos.data(), 0, n * sizeof(value_t), stream));
// rmm::device_uvector<value_t> sigmas(n, stream);
// rmm::device_uvector<value_t> rhos(n, stream);
// RAFT_CUDA_TRY(cudaMemsetAsync(sigmas.data(), 0, n * sizeof(value_t), stream));
// RAFT_CUDA_TRY(cudaMemsetAsync(rhos.data(), 0, n * sizeof(value_t), stream));

smooth_knn_dist<TPB_X, value_idx, value_t>(n,
knn_indices,
knn_dists,
rhos.data(),
sigmas.data(),
rhos,
sigmas,
params,
n_neighbors,
params->local_connectivity,
Expand All @@ -316,9 +318,9 @@ void launcher(int n,
// check for logging in order to avoid the potentially costly `arr2Str` call!
if (ML::Logger::get().shouldLogFor(CUML_LEVEL_DEBUG)) {
CUML_LOG_DEBUG("Smooth kNN Distances");
auto str = raft::arr2Str(sigmas.data(), 25, "sigmas", stream);
auto str = raft::arr2Str(sigmas, 25, "sigmas", stream);
CUML_LOG_DEBUG("%s", str.c_str());
str = raft::arr2Str(rhos.data(), 25, "rhos", stream);
str = raft::arr2Str(rhos, 25, "rhos", stream);
CUML_LOG_DEBUG("%s", str.c_str());
}

Expand All @@ -333,8 +335,8 @@ void launcher(int n,

compute_membership_strength_kernel<TPB_X><<<grid_elm, blk_elm, 0, stream>>>(knn_indices,
knn_dists,
sigmas.data(),
rhos.data(),
sigmas,
rhos,
in.vals(),
in.rows(),
in.cols(),
Expand Down
4 changes: 3 additions & 1 deletion cpp/src/umap/fuzzy_simpl_set/runner.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,15 @@ void run(int n,
int n_neighbors,
raft::sparse::COO<T>* coo,
UMAPParams* params,
float * sigmas,
float * rhos,
cudaStream_t stream,
int algorithm = 0)
{
switch (algorithm) {
case 0:
Naive::launcher<TPB_X, value_idx, T>(
n, knn_indices, knn_dists, n_neighbors, coo, params, stream);
n, knn_indices, knn_dists, n_neighbors, coo, params, sigmas, rhos, stream);
break;
}
}
Expand Down
18 changes: 14 additions & 4 deletions cpp/src/umap/runner.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ template <typename value_idx, typename value_t, typename umap_inputs, int TPB_X>
void _get_graph(const raft::handle_t& handle,
const umap_inputs& inputs,
UMAPParams* params,
float * sigmas,
float * rhos,
raft::sparse::COO<value_t, int>* graph)
{
raft::common::nvtx::range fun_scope("umap::supervised::_get_graph");
Expand Down Expand Up @@ -137,7 +139,7 @@ void _get_graph(const raft::handle_t& handle,
raft::common::nvtx::push_range("umap::simplicial_set");
raft::sparse::COO<value_t> fss_graph(stream);
FuzzySimplSet::run<TPB_X, value_idx, value_t>(
inputs.n, knn_graph.knn_indices, knn_graph.knn_dists, k, &fss_graph, params, stream);
inputs.n, knn_graph.knn_indices, knn_graph.knn_dists, k, &fss_graph, params, sigmas, rhos, stream);

CUML_LOG_DEBUG("Done. Calling remove zeros");

Expand All @@ -152,6 +154,8 @@ template <typename value_idx, typename value_t, typename umap_inputs, int TPB_X>
void _get_graph_supervised(const raft::handle_t& handle,
const umap_inputs& inputs,
UMAPParams* params,
float * sigmas,
float * rhos,
raft::sparse::COO<value_t, int>* graph)
{
raft::common::nvtx::range fun_scope("umap::supervised::_get_graph_supervised");
Expand Down Expand Up @@ -206,6 +210,8 @@ void _get_graph_supervised(const raft::handle_t& handle,
params->n_neighbors,
&fss_graph_tmp,
params,
sigmas,
rhos,
stream);
RAFT_CUDA_TRY(cudaPeekAtLastError());

Expand All @@ -228,7 +234,7 @@ void _get_graph_supervised(const raft::handle_t& handle,
} else {
CUML_LOG_DEBUG("Performing general intersection");
Supervised::perform_general_intersection<TPB_X, value_idx, value_t>(
handle, inputs.y, &fss_graph, &ci_graph, params, stream);
handle, inputs.y, &fss_graph, &ci_graph, params, sigmas, rhos, stream);
}

/**
Expand Down Expand Up @@ -277,14 +283,16 @@ void _fit(const raft::handle_t& handle,
const umap_inputs& inputs,
UMAPParams* params,
value_t* embeddings,
float * sigmas,
float * rhos,
raft::sparse::COO<float, int>* graph)
{
raft::common::nvtx::range fun_scope("umap::unsupervised::fit");

cudaStream_t stream = handle.get_stream();
ML::Logger::get().setLevel(params->verbosity);

UMAPAlgo::_get_graph<value_idx, value_t, umap_inputs, TPB_X>(handle, inputs, params, graph);
UMAPAlgo::_get_graph<value_idx, value_t, umap_inputs, TPB_X>(handle, inputs, params, sigmas, rhos, graph);

/**
* Run initialization method
Expand Down Expand Up @@ -313,6 +321,8 @@ void _fit_supervised(const raft::handle_t& handle,
const umap_inputs& inputs,
UMAPParams* params,
value_t* embeddings,
float * sigmas,
float * rhos,
raft::sparse::COO<float, int>* graph)
{
raft::common::nvtx::range fun_scope("umap::supervised::fit");
Expand All @@ -321,7 +331,7 @@ void _fit_supervised(const raft::handle_t& handle,
ML::Logger::get().setLevel(params->verbosity);

UMAPAlgo::_get_graph_supervised<value_idx, value_t, umap_inputs, TPB_X>(
handle, inputs, params, graph);
handle, inputs, params, sigmas, rhos, graph);

/**
* Initialize embeddings
Expand Down
4 changes: 4 additions & 0 deletions cpp/src/umap/supervised.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,8 @@ void perform_general_intersection(const raft::handle_t& handle,
raft::sparse::COO<value_t>* rgraph_coo,
raft::sparse::COO<value_t>* final_coo,
UMAPParams* params,
float * sigmas,
float * rhos,
cudaStream_t stream)
{
/**
Expand Down Expand Up @@ -323,6 +325,8 @@ void perform_general_intersection(const raft::handle_t& handle,
params->target_n_neighbors,
&ygraph_coo,
params,
sigmas,
rhos,
stream);
RAFT_CUDA_TRY(cudaPeekAtLastError());

Expand Down
30 changes: 18 additions & 12 deletions cpp/src/umap/umap.cu
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ std::unique_ptr<raft::sparse::COO<float, int>> get_graph(
int d,
knn_indices_dense_t* knn_indices, // precomputed indices
float* knn_dists, // precomputed distances
float* sigmas,
float* rhos,
UMAPParams* params)
{
auto graph = std::make_unique<raft::sparse::COO<float>>(handle.get_stream());
Expand All @@ -56,23 +58,23 @@ std::unique_ptr<raft::sparse::COO<float, int>> get_graph(
UMAPAlgo::_get_graph_supervised<knn_indices_dense_t,
float,
manifold_precomputed_knn_inputs_t<knn_indices_dense_t, float>,
TPB_X>(handle, inputs, params, graph.get());
TPB_X>(handle, inputs, params, sigmas, rhos, graph.get());
} else {
UMAPAlgo::_get_graph<knn_indices_dense_t,
float,
manifold_precomputed_knn_inputs_t<knn_indices_dense_t, float>,
TPB_X>(handle, inputs, params, graph.get());
TPB_X>(handle, inputs, params, sigmas, rhos, graph.get());
}
return graph;
} else {
manifold_dense_inputs_t<float> inputs(X, y, n, d);
if (y != nullptr) {
UMAPAlgo::
_get_graph_supervised<knn_indices_dense_t, float, manifold_dense_inputs_t<float>, TPB_X>(
handle, inputs, params, graph.get());
handle, inputs, params, sigmas, rhos, graph.get());
} else {
UMAPAlgo::_get_graph<knn_indices_dense_t, float, manifold_dense_inputs_t<float>, TPB_X>(
handle, inputs, params, graph.get());
handle, inputs, params, sigmas, rhos, graph.get());
}
return graph;
}
Expand Down Expand Up @@ -115,6 +117,8 @@ void fit(const raft::handle_t& handle,
float* knn_dists,
UMAPParams* params,
float* embeddings,
float * sigmas,
float * rhos,
raft::sparse::COO<float, int>* graph)
{
if (knn_indices != nullptr && knn_dists != nullptr) {
Expand All @@ -126,22 +130,22 @@ void fit(const raft::handle_t& handle,
UMAPAlgo::_fit_supervised<knn_indices_dense_t,
float,
manifold_precomputed_knn_inputs_t<knn_indices_dense_t, float>,
TPB_X>(handle, inputs, params, embeddings, graph);
TPB_X>(handle, inputs, params, embeddings, sigmas, rhos, graph);
} else {
UMAPAlgo::_fit<knn_indices_dense_t,
float,
manifold_precomputed_knn_inputs_t<knn_indices_dense_t, float>,
TPB_X>(handle, inputs, params, embeddings, graph);
TPB_X>(handle, inputs, params, embeddings, sigmas, rhos, graph);
}

} else {
manifold_dense_inputs_t<float> inputs(X, y, n, d);
if (y != nullptr) {
UMAPAlgo::_fit_supervised<knn_indices_dense_t, float, manifold_dense_inputs_t<float>, TPB_X>(
handle, inputs, params, embeddings, graph);
handle, inputs, params, embeddings, sigmas, rhos, graph);
} else {
UMAPAlgo::_fit<knn_indices_dense_t, float, manifold_dense_inputs_t<float>, TPB_X>(
handle, inputs, params, embeddings, graph);
handle, inputs, params, embeddings, sigmas, rhos, graph);
}
}
}
Expand All @@ -158,6 +162,8 @@ void fit_sparse(const raft::handle_t& handle,
float* knn_dists,
UMAPParams* params,
float* embeddings,
float * sigmas,
float * rhos,
raft::sparse::COO<float, int>* graph)
{
if (knn_indices != nullptr && knn_dists != nullptr) {
Expand All @@ -167,25 +173,25 @@ void fit_sparse(const raft::handle_t& handle,
UMAPAlgo::_fit_supervised<knn_indices_sparse_t,
float,
manifold_precomputed_knn_inputs_t<knn_indices_sparse_t, float>,
TPB_X>(handle, inputs, params, embeddings, graph);
TPB_X>(handle, inputs, params, embeddings, sigmas, rhos, graph);
} else {
UMAPAlgo::_fit<knn_indices_sparse_t,
float,
manifold_precomputed_knn_inputs_t<knn_indices_sparse_t, float>,
TPB_X>(handle, inputs, params, embeddings, graph);
TPB_X>(handle, inputs, params, embeddings, sigmas, rhos, graph);
}
} else {
manifold_sparse_inputs_t<int, float> inputs(indptr, indices, data, y, nnz, n, d);
if (y != nullptr) {
UMAPAlgo::_fit_supervised<knn_indices_sparse_t,
float,
manifold_sparse_inputs_t<knn_indices_sparse_t, float>,
TPB_X>(handle, inputs, params, embeddings, graph);
TPB_X>(handle, inputs, params, embeddings, sigmas, rhos, graph);
} else {
UMAPAlgo::_fit<knn_indices_sparse_t,
float,
manifold_sparse_inputs_t<knn_indices_sparse_t, float>,
TPB_X>(handle, inputs, params, embeddings, graph);
TPB_X>(handle, inputs, params, embeddings, sigmas, rhos, graph);
}
}
}
Expand Down
10 changes: 10 additions & 0 deletions python/cuml/cuml/manifold/simpl_set.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ cdef extern from "cuml/manifold/umap.hpp" namespace "ML::UMAP":
int d,
int64_t* knn_indices,
float* knn_dists,
float * sigmas,
float * rhos,
UMAPParams* params)

void refine(handle_t &handle,
Expand Down Expand Up @@ -197,6 +199,12 @@ def fuzzy_simplicial_set(X,
knn_indices_ptr = 0
knn_dists_ptr = 0

sigmas = CumlArray.zeros(self.n_rows, dtype=np.float32)
rhos = CumlArray.zeros(self.n_rows, dtype=np.float32)

cdef uintptr_t _signmas_ptr = sigmas.ptr
cdef uintptr_t _rhos_ptr = rhos.ptr

handle = Handle()
cdef handle_t* handle_ = <handle_t*><size_t>handle.getHandle()
cdef unique_ptr[COO] fss_graph_ptr = get_graph(
Expand All @@ -207,6 +215,8 @@ def fuzzy_simplicial_set(X,
<int> X.shape[1],
<int64_t*><uintptr_t> knn_indices_ptr,
<float*><uintptr_t> knn_dists_ptr,
<float*> _signmas_ptr,
<float*> _rhos_ptr,
<UMAPParams*> umap_params)
fss_graph = GraphHolder.from_ptr(fss_graph_ptr)

Expand Down
Loading
Loading