From 7c875ad92471475fd28bec0a69ace7bd6192b1d7 Mon Sep 17 00:00:00 2001 From: Dante Gama Dessavre Date: Mon, 16 Dec 2024 22:52:05 -0600 Subject: [PATCH] ENH expose sigmas and rhos from fitting process to enable inverse_transform --- cpp/include/cuml/manifold/umap.hpp | 6 +++++ cpp/src/umap/fuzzy_simpl_set/naive.cuh | 22 +++++++++--------- cpp/src/umap/fuzzy_simpl_set/runner.cuh | 4 +++- cpp/src/umap/runner.cuh | 18 +++++++++++---- cpp/src/umap/supervised.cuh | 4 ++++ cpp/src/umap/umap.cu | 30 +++++++++++++++---------- python/cuml/cuml/manifold/simpl_set.pyx | 10 +++++++++ python/cuml/cuml/manifold/umap.pyx | 20 +++++++++++++++-- 8 files changed, 85 insertions(+), 29 deletions(-) diff --git a/cpp/include/cuml/manifold/umap.hpp b/cpp/include/cuml/manifold/umap.hpp index 7de08c5488..42d17b3055 100644 --- a/cpp/include/cuml/manifold/umap.hpp +++ b/cpp/include/cuml/manifold/umap.hpp @@ -61,6 +61,8 @@ std::unique_ptr> get_graph(const raft::handle_t& h int d, int64_t* knn_indices, float* knn_dists, + float * sigmas, + float * rhos, UMAPParams* params); /** @@ -128,6 +130,8 @@ void fit(const raft::handle_t& handle, float* knn_dists, UMAPParams* params, float* embeddings, + float * sigmas, + float * rhos, raft::sparse::COO* graph); /** @@ -159,6 +163,8 @@ void fit_sparse(const raft::handle_t& handle, float* knn_dists, UMAPParams* params, float* embeddings, + float * sigmas, + float * rhos, raft::sparse::COO* graph); /** diff --git a/cpp/src/umap/fuzzy_simpl_set/naive.cuh b/cpp/src/umap/fuzzy_simpl_set/naive.cuh index f872b80c4b..bdf9620f16 100644 --- a/cpp/src/umap/fuzzy_simpl_set/naive.cuh +++ b/cpp/src/umap/fuzzy_simpl_set/naive.cuh @@ -291,21 +291,23 @@ void launcher(int n, int n_neighbors, raft::sparse::COO* out, UMAPParams* params, + float * sigmas, + float * rhos, cudaStream_t stream) { /** * Calculate mean distance through a parallel reduction */ - rmm::device_uvector sigmas(n, stream); - rmm::device_uvector rhos(n, stream); - RAFT_CUDA_TRY(cudaMemsetAsync(sigmas.data(), 0, n * sizeof(value_t), stream)); - RAFT_CUDA_TRY(cudaMemsetAsync(rhos.data(), 0, n * sizeof(value_t), stream)); + // rmm::device_uvector sigmas(n, stream); + // rmm::device_uvector rhos(n, stream); + // RAFT_CUDA_TRY(cudaMemsetAsync(sigmas.data(), 0, n * sizeof(value_t), stream)); + // RAFT_CUDA_TRY(cudaMemsetAsync(rhos.data(), 0, n * sizeof(value_t), stream)); smooth_knn_dist(n, knn_indices, knn_dists, - rhos.data(), - sigmas.data(), + rhos, + sigmas, params, n_neighbors, params->local_connectivity, @@ -316,9 +318,9 @@ void launcher(int n, // check for logging in order to avoid the potentially costly `arr2Str` call! if (ML::Logger::get().shouldLogFor(CUML_LEVEL_DEBUG)) { CUML_LOG_DEBUG("Smooth kNN Distances"); - auto str = raft::arr2Str(sigmas.data(), 25, "sigmas", stream); + auto str = raft::arr2Str(sigmas, 25, "sigmas", stream); CUML_LOG_DEBUG("%s", str.c_str()); - str = raft::arr2Str(rhos.data(), 25, "rhos", stream); + str = raft::arr2Str(rhos, 25, "rhos", stream); CUML_LOG_DEBUG("%s", str.c_str()); } @@ -333,8 +335,8 @@ void launcher(int n, compute_membership_strength_kernel<<>>(knn_indices, knn_dists, - sigmas.data(), - rhos.data(), + sigmas, + rhos, in.vals(), in.rows(), in.cols(), diff --git a/cpp/src/umap/fuzzy_simpl_set/runner.cuh b/cpp/src/umap/fuzzy_simpl_set/runner.cuh index 6cfd3cd58d..52c759de5a 100644 --- a/cpp/src/umap/fuzzy_simpl_set/runner.cuh +++ b/cpp/src/umap/fuzzy_simpl_set/runner.cuh @@ -45,13 +45,15 @@ void run(int n, int n_neighbors, raft::sparse::COO* coo, UMAPParams* params, + float * sigmas, + float * rhos, cudaStream_t stream, int algorithm = 0) { switch (algorithm) { case 0: Naive::launcher( - n, knn_indices, knn_dists, n_neighbors, coo, params, stream); + n, knn_indices, knn_dists, n_neighbors, coo, params, sigmas, rhos, stream); break; } } diff --git a/cpp/src/umap/runner.cuh b/cpp/src/umap/runner.cuh index 0ceeb3acaa..d51f2c41a8 100644 --- a/cpp/src/umap/runner.cuh +++ b/cpp/src/umap/runner.cuh @@ -95,6 +95,8 @@ template void _get_graph(const raft::handle_t& handle, const umap_inputs& inputs, UMAPParams* params, + float * sigmas, + float * rhos, raft::sparse::COO* graph) { raft::common::nvtx::range fun_scope("umap::supervised::_get_graph"); @@ -137,7 +139,7 @@ void _get_graph(const raft::handle_t& handle, raft::common::nvtx::push_range("umap::simplicial_set"); raft::sparse::COO fss_graph(stream); FuzzySimplSet::run( - inputs.n, knn_graph.knn_indices, knn_graph.knn_dists, k, &fss_graph, params, stream); + inputs.n, knn_graph.knn_indices, knn_graph.knn_dists, k, &fss_graph, params, sigmas, rhos, stream); CUML_LOG_DEBUG("Done. Calling remove zeros"); @@ -152,6 +154,8 @@ template void _get_graph_supervised(const raft::handle_t& handle, const umap_inputs& inputs, UMAPParams* params, + float * sigmas, + float * rhos, raft::sparse::COO* graph) { raft::common::nvtx::range fun_scope("umap::supervised::_get_graph_supervised"); @@ -206,6 +210,8 @@ void _get_graph_supervised(const raft::handle_t& handle, params->n_neighbors, &fss_graph_tmp, params, + sigmas, + rhos, stream); RAFT_CUDA_TRY(cudaPeekAtLastError()); @@ -228,7 +234,7 @@ void _get_graph_supervised(const raft::handle_t& handle, } else { CUML_LOG_DEBUG("Performing general intersection"); Supervised::perform_general_intersection( - handle, inputs.y, &fss_graph, &ci_graph, params, stream); + handle, inputs.y, &fss_graph, &ci_graph, params, sigmas, rhos, stream); } /** @@ -277,6 +283,8 @@ void _fit(const raft::handle_t& handle, const umap_inputs& inputs, UMAPParams* params, value_t* embeddings, + float * sigmas, + float * rhos, raft::sparse::COO* graph) { raft::common::nvtx::range fun_scope("umap::unsupervised::fit"); @@ -284,7 +292,7 @@ void _fit(const raft::handle_t& handle, cudaStream_t stream = handle.get_stream(); ML::Logger::get().setLevel(params->verbosity); - UMAPAlgo::_get_graph(handle, inputs, params, graph); + UMAPAlgo::_get_graph(handle, inputs, params, sigmas, rhos, graph); /** * Run initialization method @@ -313,6 +321,8 @@ void _fit_supervised(const raft::handle_t& handle, const umap_inputs& inputs, UMAPParams* params, value_t* embeddings, + float * sigmas, + float * rhos, raft::sparse::COO* graph) { raft::common::nvtx::range fun_scope("umap::supervised::fit"); @@ -321,7 +331,7 @@ void _fit_supervised(const raft::handle_t& handle, ML::Logger::get().setLevel(params->verbosity); UMAPAlgo::_get_graph_supervised( - handle, inputs, params, graph); + handle, inputs, params, sigmas, rhos, graph); /** * Initialize embeddings diff --git a/cpp/src/umap/supervised.cuh b/cpp/src/umap/supervised.cuh index 21ed42f157..8403209234 100644 --- a/cpp/src/umap/supervised.cuh +++ b/cpp/src/umap/supervised.cuh @@ -283,6 +283,8 @@ void perform_general_intersection(const raft::handle_t& handle, raft::sparse::COO* rgraph_coo, raft::sparse::COO* final_coo, UMAPParams* params, + float * sigmas, + float * rhos, cudaStream_t stream) { /** @@ -323,6 +325,8 @@ void perform_general_intersection(const raft::handle_t& handle, params->target_n_neighbors, &ygraph_coo, params, + sigmas, + rhos, stream); RAFT_CUDA_TRY(cudaPeekAtLastError()); diff --git a/cpp/src/umap/umap.cu b/cpp/src/umap/umap.cu index 899051f8de..005ad15998 100644 --- a/cpp/src/umap/umap.cu +++ b/cpp/src/umap/umap.cu @@ -44,6 +44,8 @@ std::unique_ptr> get_graph( int d, knn_indices_dense_t* knn_indices, // precomputed indices float* knn_dists, // precomputed distances + float* sigmas, + float* rhos, UMAPParams* params) { auto graph = std::make_unique>(handle.get_stream()); @@ -56,12 +58,12 @@ std::unique_ptr> get_graph( UMAPAlgo::_get_graph_supervised, - TPB_X>(handle, inputs, params, graph.get()); + TPB_X>(handle, inputs, params, sigmas, rhos, graph.get()); } else { UMAPAlgo::_get_graph, - TPB_X>(handle, inputs, params, graph.get()); + TPB_X>(handle, inputs, params, sigmas, rhos, graph.get()); } return graph; } else { @@ -69,10 +71,10 @@ std::unique_ptr> get_graph( if (y != nullptr) { UMAPAlgo:: _get_graph_supervised, TPB_X>( - handle, inputs, params, graph.get()); + handle, inputs, params, sigmas, rhos, graph.get()); } else { UMAPAlgo::_get_graph, TPB_X>( - handle, inputs, params, graph.get()); + handle, inputs, params, sigmas, rhos, graph.get()); } return graph; } @@ -115,6 +117,8 @@ void fit(const raft::handle_t& handle, float* knn_dists, UMAPParams* params, float* embeddings, + float * sigmas, + float * rhos, raft::sparse::COO* graph) { if (knn_indices != nullptr && knn_dists != nullptr) { @@ -126,22 +130,22 @@ void fit(const raft::handle_t& handle, UMAPAlgo::_fit_supervised, - TPB_X>(handle, inputs, params, embeddings, graph); + TPB_X>(handle, inputs, params, embeddings, sigmas, rhos, graph); } else { UMAPAlgo::_fit, - TPB_X>(handle, inputs, params, embeddings, graph); + TPB_X>(handle, inputs, params, embeddings, sigmas, rhos, graph); } } else { manifold_dense_inputs_t inputs(X, y, n, d); if (y != nullptr) { UMAPAlgo::_fit_supervised, TPB_X>( - handle, inputs, params, embeddings, graph); + handle, inputs, params, embeddings, sigmas, rhos, graph); } else { UMAPAlgo::_fit, TPB_X>( - handle, inputs, params, embeddings, graph); + handle, inputs, params, embeddings, sigmas, rhos, graph); } } } @@ -158,6 +162,8 @@ void fit_sparse(const raft::handle_t& handle, float* knn_dists, UMAPParams* params, float* embeddings, + float * sigmas, + float * rhos, raft::sparse::COO* graph) { if (knn_indices != nullptr && knn_dists != nullptr) { @@ -167,12 +173,12 @@ void fit_sparse(const raft::handle_t& handle, UMAPAlgo::_fit_supervised, - TPB_X>(handle, inputs, params, embeddings, graph); + TPB_X>(handle, inputs, params, embeddings, sigmas, rhos, graph); } else { UMAPAlgo::_fit, - TPB_X>(handle, inputs, params, embeddings, graph); + TPB_X>(handle, inputs, params, embeddings, sigmas, rhos, graph); } } else { manifold_sparse_inputs_t inputs(indptr, indices, data, y, nnz, n, d); @@ -180,12 +186,12 @@ void fit_sparse(const raft::handle_t& handle, UMAPAlgo::_fit_supervised, - TPB_X>(handle, inputs, params, embeddings, graph); + TPB_X>(handle, inputs, params, embeddings, sigmas, rhos, graph); } else { UMAPAlgo::_fit, - TPB_X>(handle, inputs, params, embeddings, graph); + TPB_X>(handle, inputs, params, embeddings, sigmas, rhos, graph); } } } diff --git a/python/cuml/cuml/manifold/simpl_set.pyx b/python/cuml/cuml/manifold/simpl_set.pyx index b0be2d5de7..a6f4faa08f 100644 --- a/python/cuml/cuml/manifold/simpl_set.pyx +++ b/python/cuml/cuml/manifold/simpl_set.pyx @@ -47,6 +47,8 @@ cdef extern from "cuml/manifold/umap.hpp" namespace "ML::UMAP": int d, int64_t* knn_indices, float* knn_dists, + float * sigmas, + float * rhos, UMAPParams* params) void refine(handle_t &handle, @@ -197,6 +199,12 @@ def fuzzy_simplicial_set(X, knn_indices_ptr = 0 knn_dists_ptr = 0 + sigmas = CumlArray.zeros(self.n_rows, dtype=np.float32) + rhos = CumlArray.zeros(self.n_rows, dtype=np.float32) + + cdef uintptr_t _signmas_ptr = sigmas.ptr + cdef uintptr_t _rhos_ptr = rhos.ptr + handle = Handle() cdef handle_t* handle_ = handle.getHandle() cdef unique_ptr[COO] fss_graph_ptr = get_graph( @@ -207,6 +215,8 @@ def fuzzy_simplicial_set(X, X.shape[1], knn_indices_ptr, knn_dists_ptr, + _signmas_ptr, + _rhos_ptr, umap_params) fss_graph = GraphHolder.from_ptr(fss_graph_ptr) diff --git a/python/cuml/cuml/manifold/umap.pyx b/python/cuml/cuml/manifold/umap.pyx index 2c55e49c53..c1c79bb946 100644 --- a/python/cuml/cuml/manifold/umap.pyx +++ b/python/cuml/cuml/manifold/umap.pyx @@ -90,6 +90,8 @@ IF GPUBUILD == 1: float * knn_dists, UMAPParams * params, float * embeddings, + float * sigmas, + float * rhos, COO * graph) except + void fit_sparse(handle_t &handle, @@ -104,6 +106,8 @@ IF GPUBUILD == 1: float * knn_dists, UMAPParams *params, float *embeddings, + float * sigmas, + float * rhos, COO * graph) except + void transform(handle_t & handle, @@ -313,6 +317,7 @@ class UMAP(UniversalBase, * Using a pre-computed pairwise distance matrix (under consideration for future releases) * Manual initialization of initial embedding positions + * inverse_transform function In addition to these missing features, you should expect to see the final embeddings differing between cuml.umap and the reference @@ -577,11 +582,13 @@ class UMAP(UniversalBase, convert_format=False) self.n_rows, self.n_dims = self._raw_data.shape self.sparse_fit = True + self._sparse_data = True if self.build_algo == "nn_descent": raise ValueError("NN Descent does not support sparse inputs") # Handle dense inputs else: + self._sparse_data = False if data_on_host: convert_to_mem_type = MemoryType.host else: @@ -638,11 +645,15 @@ class UMAP(UniversalBase, order="C", dtype=np.float32, index=self._raw_data.index) + self._sigmas = CumlArray.zeros(self.n_rows, dtype=np.float32) + self._rhos = CumlArray.zeros(self.n_rows, dtype=np.float32) + if self.hash_input: self._input_hash = joblib.hash(self._raw_data.to_output('numpy')) cdef uintptr_t _embed_raw_ptr = self.embedding_.ptr - + cdef uintptr_t _signmas_ptr = self._sigmas.ptr + cdef uintptr_t _rhos_ptr = self._rhos.ptr cdef uintptr_t _y_raw_ptr = 0 if y is not None: @@ -673,6 +684,8 @@ class UMAP(UniversalBase, _knn_dists_ptr, umap_params, _embed_raw_ptr, + _signmas_ptr, + _rhos_ptr, fss_graph.get()) else: @@ -685,6 +698,8 @@ class UMAP(UniversalBase, _knn_dists_ptr, umap_params, _embed_raw_ptr, + _signmas_ptr, + _rhos_ptr, fss_graph.get()) self.graph_ = fss_graph.get_cupy_coo() @@ -908,6 +923,7 @@ class UMAP(UniversalBase, self.metric_kwds, False, self.random_state) super().gpu_to_cpu() + self._cpu_model._validate_parameters() @classmethod def _get_param_names(cls): @@ -943,4 +959,4 @@ class UMAP(UniversalBase, return ['_raw_data', 'embedding_', '_input_hash', '_small_data', '_knn_dists', '_knn_indices', '_knn_search_index', '_disconnection_distance', '_n_neighbors', '_a', '_b', - '_initial_alpha'] + '_initial_alpha', '_sparse_data', '_sigmas', '_rhos']