From f8a2c866ec849f773405b38d16de95fa217afe3a Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Tue, 14 Jun 2022 14:34:39 -0700 Subject: [PATCH 01/17] tsne allow distance types --- cpp/include/cuml/manifold/tsne.h | 5 ++++ cpp/src/tsne/distances.cuh | 22 +++++++++++------- cpp/src/tsne/tsne.cu | 16 ++++++++----- cpp/src/tsne/tsne_runner.cuh | 12 ++++++---- cpp/test/sg/tsne_test.cu | 6 +++-- python/cuml/manifold/t_sne.pyx | 40 ++++++++++++++++++++++++++++---- 6 files changed, 75 insertions(+), 26 deletions(-) diff --git a/cpp/include/cuml/manifold/tsne.h b/cpp/include/cuml/manifold/tsne.h index a72ece863c..bdebca4858 100644 --- a/cpp/include/cuml/manifold/tsne.h +++ b/cpp/include/cuml/manifold/tsne.h @@ -17,6 +17,7 @@ #pragma once #include +#include namespace raft { class handle_t; @@ -118,6 +119,7 @@ struct TSNEParams { * @param[in] knn_dists Array containing nearest neighors distances. * @param[in] params Parameters for TSNE model * @param[out] kl_div (optional) KL divergence output + * @param[in] metric Distance metric * * The CUDA implementation is derived from the excellent CannyLabs open source * implementation here: https://github.com/CannyLab/tsne-cuda/. The CannyLabs @@ -134,6 +136,7 @@ void TSNE_fit(const raft::handle_t& handle, int64_t* knn_indices, float* knn_dists, TSNEParams& params, + raft::distance::DistanceType metric, float* kl_div = nullptr); /** @@ -152,6 +155,7 @@ void TSNE_fit(const raft::handle_t& handle, * @param[in] knn_dists Array containing nearest neighors distances. * @param[in] params Parameters for TSNE model * @param[out] kl_div (optional) KL divergence output + * @param[in] metric Distance metric * * The CUDA implementation is derived from the excellent CannyLabs open source * implementation here: https://github.com/CannyLab/tsne-cuda/. The CannyLabs @@ -171,6 +175,7 @@ void TSNE_fit_sparse(const raft::handle_t& handle, int* knn_indices, float* knn_dists, TSNEParams& params, + raft::distance::DistanceType metric, float* kl_div = nullptr); } // namespace ML diff --git a/cpp/src/tsne/distances.cuh b/cpp/src/tsne/distances.cuh index 828e93f3f7..e809280caf 100644 --- a/cpp/src/tsne/distances.cuh +++ b/cpp/src/tsne/distances.cuh @@ -37,8 +37,6 @@ namespace ML { namespace TSNE { -auto DEFAULT_DISTANCE_METRIC = raft::distance::DistanceType::L2SqrtExpanded; - /** * @brief Uses FAISS's KNN to find the top n_neighbors. This speeds up the attractive forces. * @param[in] input: dense/sparse manifold input @@ -46,19 +44,22 @@ auto DEFAULT_DISTANCE_METRIC = raft::distance::DistanceType::L2SqrtExpanded; * @param[out] distances: The output sorted distances from KNN. * @param[in] n_neighbors: The number of nearest neighbors you want. * @param[in] stream: The GPU stream. + * @param[in] metric: The distance metric. */ template void get_distances(const raft::handle_t& handle, tsne_input& input, knn_graph& k_graph, - cudaStream_t stream); + cudaStream_t stream, + raft::distance::DistanceType metric); // dense, int64 indices template <> void get_distances(const raft::handle_t& handle, manifold_dense_inputs_t& input, knn_graph& k_graph, - cudaStream_t stream) + cudaStream_t stream, + raft::distance::DistanceType metric) { // TODO: for TSNE transform first fit some points then transform with 1/(1+d^2) // #861 @@ -86,7 +87,7 @@ void get_distances(const raft::handle_t& handle, true, true, nullptr, - DEFAULT_DISTANCE_METRIC); + metric); } // dense, int32 indices @@ -94,7 +95,8 @@ template <> void get_distances(const raft::handle_t& handle, manifold_dense_inputs_t& input, knn_graph& k_graph, - cudaStream_t stream) + cudaStream_t stream, + raft::distance::DistanceType metric) { throw raft::exception("Dense TSNE does not support 32-bit integer indices yet."); } @@ -104,7 +106,8 @@ template <> void get_distances(const raft::handle_t& handle, manifold_sparse_inputs_t& input, knn_graph& k_graph, - cudaStream_t stream) + cudaStream_t stream, + raft::distance::DistanceType metric) { raft::sparse::selection::brute_force_knn(input.indptr, input.indices, @@ -124,7 +127,7 @@ void get_distances(const raft::handle_t& handle, handle, ML::Sparse::DEFAULT_BATCH_SIZE, ML::Sparse::DEFAULT_BATCH_SIZE, - DEFAULT_DISTANCE_METRIC); + metric); } // sparse, int64 @@ -132,7 +135,8 @@ template <> void get_distances(const raft::handle_t& handle, manifold_sparse_inputs_t& input, knn_graph& k_graph, - cudaStream_t stream) + cudaStream_t stream, + raft::distance::DistanceType metric) { throw raft::exception("Sparse TSNE does not support 64-bit integer indices yet."); } diff --git a/cpp/src/tsne/tsne.cu b/cpp/src/tsne/tsne.cu index 378e854a3e..1cb919a54f 100644 --- a/cpp/src/tsne/tsne.cu +++ b/cpp/src/tsne/tsne.cu @@ -16,6 +16,7 @@ #include "tsne_runner.cuh" #include +#include namespace ML { @@ -23,9 +24,10 @@ template value_t _fit(const raft::handle_t& handle, tsne_input& input, knn_graph& k_graph, - TSNEParams& params) + TSNEParams& params, + raft::distance::DistanceType metric) { - TSNE_runner runner(handle, input, k_graph, params); + TSNE_runner runner(handle, input, k_graph, params, metric); return runner.run(); // returns the Kullback–Leibler divergence } @@ -38,7 +40,8 @@ void TSNE_fit(const raft::handle_t& handle, int64_t* knn_indices, float* knn_dists, TSNEParams& params, - float* kl_div) + float* kl_div, + raft::distance::DistanceType metric) { ASSERT(n > 0 && p > 0 && params.dim > 0 && params.n_neighbors > 0 && X != NULL && Y != NULL, "Wrong input args"); @@ -47,7 +50,7 @@ void TSNE_fit(const raft::handle_t& handle, knn_graph k_graph(n, params.n_neighbors, knn_indices, knn_dists); float kl_div_v = _fit, knn_indices_dense_t, float>( - handle, input, k_graph, params); + handle, input, k_graph, params, metric); if (kl_div) { *kl_div = kl_div_v; } } @@ -63,7 +66,8 @@ void TSNE_fit_sparse(const raft::handle_t& handle, int* knn_indices, float* knn_dists, TSNEParams& params, - float* kl_div) + float* kl_div, + raft::distance::DistanceType metric) { ASSERT(n > 0 && p > 0 && params.dim > 0 && params.n_neighbors > 0 && indptr != NULL && indices != NULL && data != NULL && Y != NULL, @@ -73,7 +77,7 @@ void TSNE_fit_sparse(const raft::handle_t& handle, knn_graph k_graph(n, params.n_neighbors, knn_indices, knn_dists); float kl_div_v = _fit, knn_indices_sparse_t, float>( - handle, input, k_graph, params); + handle, input, k_graph, params, metric); if (kl_div) { *kl_div = kl_div_v; } } diff --git a/cpp/src/tsne/tsne_runner.cuh b/cpp/src/tsne/tsne_runner.cuh index d7a1a2453a..2724e0666d 100644 --- a/cpp/src/tsne/tsne_runner.cuh +++ b/cpp/src/tsne/tsne_runner.cuh @@ -21,6 +21,7 @@ #include #include #include +#include #include #include "barnes_hut_tsne.cuh" @@ -35,13 +36,15 @@ class TSNE_runner { TSNE_runner(const raft::handle_t& handle_, tsne_input& input_, knn_graph& k_graph_, - TSNEParams& params_) + TSNEParams& params_, + raft::distance::DistanceType metric_) : handle(handle_), input(input_), k_graph(k_graph_), params(params_), - COO_Matrix(handle_.get_stream()) - { + COO_Matrix(handle_.get_stream()), + metric(metric_) + { this->n = input.n; this->p = input.d; this->Y = input.y; @@ -117,7 +120,7 @@ class TSNE_runner { k_graph.knn_indices = indices.data(); k_graph.knn_dists = distances.data(); - TSNE::get_distances(handle, input, k_graph, stream); + TSNE::get_distances(handle, input, k_graph, stream, metric); } if (params.square_distances) { @@ -187,6 +190,7 @@ class TSNE_runner { tsne_input& input; knn_graph& k_graph; TSNEParams& params; + raft::distance::DistanceType metric; value_idx n, p; value_t* Y; diff --git a/cpp/test/sg/tsne_test.cu b/cpp/test/sg/tsne_test.cu index c7c79d8535..aeed31de23 100644 --- a/cpp/test/sg/tsne_test.cu +++ b/cpp/test/sg/tsne_test.cu @@ -17,6 +17,7 @@ #include #include #include +#include #include #include @@ -108,6 +109,7 @@ class TSNETest : public ::testing::TestWithParam { auto stream = handle.get_stream(); TSNEResults results; + auto DEFAULT_DISTANCE_METRIC = raft::distance::DistanceType::L2SqrtExpanded; // Setup parameters model_params.algorithm = algo; model_params.dim = 2; @@ -132,11 +134,11 @@ class TSNETest : public ::testing::TestWithParam { input_dists.resize(n * model_params.n_neighbors, stream); k_graph.knn_indices = input_indices.data(); k_graph.knn_dists = input_dists.data(); - TSNE::get_distances(handle, input, k_graph, stream); + TSNE::get_distances(handle, input, k_graph, stream, DEFAULT_DISTANCE_METRIC); } handle.sync_stream(stream); TSNE_runner, knn_indices_dense_t, float> runner( - handle, input, k_graph, model_params); + handle, input, k_graph, model_params, DEFAULT_DISTANCE_METRIC); results.kl_div = runner.run(); // Compute embedding's pairwise distances diff --git a/python/cuml/manifold/t_sne.pyx b/python/cuml/manifold/t_sne.pyx index deefdd4763..937c60c4e8 100644 --- a/python/cuml/manifold/t_sne.pyx +++ b/python/cuml/manifold/t_sne.pyx @@ -39,6 +39,7 @@ from cuml.common.doc_utils import generate_docstring from cuml.common import input_to_cuml_array from cuml.common.mixins import CMajorInputTagMixin from cuml.common.sparsefuncs import extract_knn_graph +from cuml.metrics.distance_type cimport DistanceType import rmm from libcpp cimport bool @@ -93,6 +94,7 @@ cdef extern from "cuml/manifold/tsne.h" namespace "ML": int64_t* knn_indices, float* knn_dists, TSNEParams ¶ms, + DistanceType metric, float* kl_div) except + cdef void TSNE_fit_sparse( @@ -107,6 +109,7 @@ cdef extern from "cuml/manifold/tsne.h" namespace "ML": int* knn_indices, float* knn_dists, TSNEParams ¶ms, + DistanceType metric, float* kl_div) except + @@ -302,11 +305,6 @@ class TSNE(Base, if n_iter <= 100: warnings.warn("n_iter = {} might cause TSNE to output wrong " "results. Set it higher.".format(n_iter)) - if metric.lower() != 'euclidean': - # TODO https://github.com/rapidsai/cuml/issues/1653 - warnings.warn("TSNE does not support {} (only Euclidean).".format( - metric)) - metric = 'euclidean' if init.lower() != 'random': # TODO https://github.com/rapidsai/cuml/issues/3458 warnings.warn("TSNE does not support {} but only random " @@ -497,6 +495,36 @@ class TSNE(Base, self._build_tsne_params(algo) cdef float kl_divergence = 0 + + # metric + metric_parsing = { + "l2": DistanceType.L2SqrtUnexpanded, + "euclidean": DistanceType.L2SqrtUnexpanded, + "sqeuclidean": DistanceType.L2Expanded, + "cityblock": DistanceType.L1, + "l1": DistanceType.L1, + "manhattan": DistanceType.L1, + "taxicab": DistanceType.L1, + "braycurtis": DistanceType.BrayCurtis, + "canberra": DistanceType.Canberra, + "minkowski": DistanceType.LpUnexpanded, + "lp": DistanceType.LpUnexpanded, + "chebyshev": DistanceType.Linf, + "linf": DistanceType.Linf, + "jensenshannon": DistanceType.JensenShannon, + "cosine": DistanceType.CosineExpanded, + "correlation": DistanceType.CorrelationExpanded, + "inner_product": DistanceType.InnerProduct, + "jaccard": DistanceType.JaccardExpanded, + "hellinger": DistanceType.HellingerExpanded, + "haversine": DistanceType.Haversine + } + if self.metric.lower() in metric_parsing: + metric = metric_parsing[self.metric.lower()] + else: + raise ValueError("Invalid value for metric: {}" + .format(self.metric)) + if self.sparse_fit: TSNE_fit_sparse(handle_[0], @@ -512,6 +540,7 @@ class TSNE(Base, knn_indices_raw, knn_dists_raw, deref(params), + metric, &kl_divergence) else: TSNE_fit(handle_[0], @@ -522,6 +551,7 @@ class TSNE(Base, knn_indices_raw, knn_dists_raw, deref(params), + metric, &kl_divergence) self.handle.sync() From dbefb19cf65f16ad11059bf116c29b1124cf50d6 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Wed, 15 Jun 2022 17:08:58 -0700 Subject: [PATCH 02/17] Added other distance metrics for UMAP --- cpp/include/cuml/manifold/umap.hpp | 1 + cpp/include/cuml/manifold/umapparams.h | 3 +++ cpp/src/umap/knn_graph/algo.cuh | 5 ++-- python/cuml/manifold/umap.pyx | 36 ++++++++++++++++++++++++++ 4 files changed, 43 insertions(+), 2 deletions(-) diff --git a/cpp/include/cuml/manifold/umap.hpp b/cpp/include/cuml/manifold/umap.hpp index 008dbb7155..6be69ff321 100644 --- a/cpp/include/cuml/manifold/umap.hpp +++ b/cpp/include/cuml/manifold/umap.hpp @@ -21,6 +21,7 @@ #include #include #include +#include namespace raft { class handle_t; diff --git a/cpp/include/cuml/manifold/umapparams.h b/cpp/include/cuml/manifold/umapparams.h index 055ab2c897..f4e192dfb6 100644 --- a/cpp/include/cuml/manifold/umapparams.h +++ b/cpp/include/cuml/manifold/umapparams.h @@ -18,6 +18,7 @@ #include #include +#include namespace ML { @@ -158,6 +159,8 @@ class UMAPParams { bool deterministic = true; Internals::GraphBasedDimRedCallback* callback = nullptr; + + raft::distance::DistanceType metric = raft::distance::DistanceType::L2SqrtExpanded; }; } // namespace ML diff --git a/cpp/src/umap/knn_graph/algo.cuh b/cpp/src/umap/knn_graph/algo.cuh index 632970cdab..e111a395b6 100644 --- a/cpp/src/umap/knn_graph/algo.cuh +++ b/cpp/src/umap/knn_graph/algo.cuh @@ -69,7 +69,8 @@ void launcher(const raft::handle_t& handle, inputsB.n, out.knn_indices, out.knn_dists, - n_neighbors); + n_neighbors, + params->metric); } // Instantiation for dense inputs, int indices @@ -112,7 +113,7 @@ void launcher(const raft::handle_t& handle, handle, ML::Sparse::DEFAULT_BATCH_SIZE, ML::Sparse::DEFAULT_BATCH_SIZE, - raft::distance::DistanceType::L2Expanded); + params->metric); } template <> diff --git a/python/cuml/manifold/umap.pyx b/python/cuml/manifold/umap.pyx index 0979b5a6d1..0c7a9749d1 100644 --- a/python/cuml/manifold/umap.pyx +++ b/python/cuml/manifold/umap.pyx @@ -50,6 +50,7 @@ from cuml.common.array import CumlArray from cuml.common.array_sparse import SparseCumlArray from cuml.common.mixins import CMajorInputTagMixin from cuml.common.sparse_utils import is_sparse +from cuml.metrics.distance_type cimport DistanceType from cuml.manifold.simpl_set import fuzzy_simplicial_set, \ simplicial_set_embedding @@ -149,6 +150,10 @@ class UMAP(Base, n_components: int (optional, default 2) The dimension of the space to embed into. This defaults to 2 to provide easy visualization, but can reasonably be set to any + metric : string (default='euclidean'). + Distance metric to use. Supported distances are ['l1, 'cityblock', + 'taxicab', 'manhattan', 'euclidean', 'l2', 'braycurtis', 'canberra', + 'minkowski', 'chebyshev', 'jensenshannon', 'cosine', 'correlation'] n_epochs: int (optional, default None) The number of training epochs to be used in optimizing the low dimensional embedding. Larger values result in more accurate @@ -298,6 +303,7 @@ class UMAP(Base, def __init__(self, *, n_neighbors=15, n_components=2, + metric="euclidean", n_epochs=None, learning_rate=1.0, min_dist=0.1, @@ -328,6 +334,7 @@ class UMAP(Base, self.n_neighbors = n_neighbors self.n_components = n_components + self.metric = metric self.n_epochs = n_epochs if n_epochs else 0 if init == "spectral" or init == "random": @@ -419,6 +426,35 @@ class UMAP(Base, umap_params.random_state = cls.random_state umap_params.deterministic = cls.deterministic + # metric + metric_parsing = { + "l2": DistanceType.L2SqrtUnexpanded, + "euclidean": DistanceType.L2SqrtUnexpanded, + "sqeuclidean": DistanceType.L2Expanded, + "cityblock": DistanceType.L1, + "l1": DistanceType.L1, + "manhattan": DistanceType.L1, + "taxicab": DistanceType.L1, + "braycurtis": DistanceType.BrayCurtis, + "canberra": DistanceType.Canberra, + "minkowski": DistanceType.LpUnexpanded, + "lp": DistanceType.LpUnexpanded, + "chebyshev": DistanceType.Linf, + "linf": DistanceType.Linf, + "jensenshannon": DistanceType.JensenShannon, + "cosine": DistanceType.CosineExpanded, + "correlation": DistanceType.CorrelationExpanded, + "inner_product": DistanceType.InnerProduct, + "jaccard": DistanceType.JaccardExpanded, + "hellinger": DistanceType.HellingerExpanded, + "haversine": DistanceType.Haversine + } + if cls.metric.lower() in metric_parsing: + umap_params.metric = metric_parsing[cls.metric.lower()] + else: + raise ValueError("Invalid value for metric: {}" + .format(cls.metric)) + cdef uintptr_t callback_ptr = 0 if cls.callback: callback_ptr = cls.callback.get_native_callback() From d934f1eda752f3ecea33e62326a6f12a6fe54541 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Thu, 16 Jun 2022 09:40:50 -0700 Subject: [PATCH 03/17] Modified UMAPPARAMS --- cpp/include/cuml/manifold/umapparams.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/include/cuml/manifold/umapparams.h b/cpp/include/cuml/manifold/umapparams.h index f4e192dfb6..41e6e63f74 100644 --- a/cpp/include/cuml/manifold/umapparams.h +++ b/cpp/include/cuml/manifold/umapparams.h @@ -157,10 +157,10 @@ class UMAPParams { higher memory usage but produce stable numeric output. */ bool deterministic = true; + + raft::distance::DistanceType metric = raft::distance::DistanceType::L2SqrtExpanded; Internals::GraphBasedDimRedCallback* callback = nullptr; - - raft::distance::DistanceType metric = raft::distance::DistanceType::L2SqrtExpanded; }; } // namespace ML From ce9def66fa4502787a021a69263b5866072b140d Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Fri, 17 Jun 2022 12:32:47 -0700 Subject: [PATCH 04/17] Restructured tsne code with supported distance metrics --- cpp/include/cuml/manifold/tsne.h | 6 +-- cpp/src/tsne/tsne.cu | 15 +++---- cpp/src/tsne/tsne_runner.cuh | 9 ++-- cpp/test/sg/tsne_test.cu | 2 +- python/cuml/manifold/t_sne.pyx | 58 +++++++++++-------------- python/cuml/tests/test_tsne.py | 72 +++++++++++++++++++++++++++++++- 6 files changed, 109 insertions(+), 53 deletions(-) diff --git a/cpp/include/cuml/manifold/tsne.h b/cpp/include/cuml/manifold/tsne.h index bdebca4858..1526803d40 100644 --- a/cpp/include/cuml/manifold/tsne.h +++ b/cpp/include/cuml/manifold/tsne.h @@ -102,6 +102,9 @@ struct TSNEParams { // behavior of Scikit-learn's T-SNE. bool square_distances = true; + //Distance metric to use. + raft::distance::DistanceType metric = raft::distance::DistanceType::L2SqrtExpanded; + // Which implementation algorithm to use. TSNE_ALGORITHM algorithm = TSNE_ALGORITHM::FFT; }; @@ -136,7 +139,6 @@ void TSNE_fit(const raft::handle_t& handle, int64_t* knn_indices, float* knn_dists, TSNEParams& params, - raft::distance::DistanceType metric, float* kl_div = nullptr); /** @@ -155,7 +157,6 @@ void TSNE_fit(const raft::handle_t& handle, * @param[in] knn_dists Array containing nearest neighors distances. * @param[in] params Parameters for TSNE model * @param[out] kl_div (optional) KL divergence output - * @param[in] metric Distance metric * * The CUDA implementation is derived from the excellent CannyLabs open source * implementation here: https://github.com/CannyLab/tsne-cuda/. The CannyLabs @@ -175,7 +176,6 @@ void TSNE_fit_sparse(const raft::handle_t& handle, int* knn_indices, float* knn_dists, TSNEParams& params, - raft::distance::DistanceType metric, float* kl_div = nullptr); } // namespace ML diff --git a/cpp/src/tsne/tsne.cu b/cpp/src/tsne/tsne.cu index 1cb919a54f..cc2d2e89dd 100644 --- a/cpp/src/tsne/tsne.cu +++ b/cpp/src/tsne/tsne.cu @@ -24,10 +24,9 @@ template value_t _fit(const raft::handle_t& handle, tsne_input& input, knn_graph& k_graph, - TSNEParams& params, - raft::distance::DistanceType metric) + TSNEParams& params) { - TSNE_runner runner(handle, input, k_graph, params, metric); + TSNE_runner runner(handle, input, k_graph, params); return runner.run(); // returns the Kullback–Leibler divergence } @@ -40,8 +39,7 @@ void TSNE_fit(const raft::handle_t& handle, int64_t* knn_indices, float* knn_dists, TSNEParams& params, - float* kl_div, - raft::distance::DistanceType metric) + float* kl_div) { ASSERT(n > 0 && p > 0 && params.dim > 0 && params.n_neighbors > 0 && X != NULL && Y != NULL, "Wrong input args"); @@ -50,7 +48,7 @@ void TSNE_fit(const raft::handle_t& handle, knn_graph k_graph(n, params.n_neighbors, knn_indices, knn_dists); float kl_div_v = _fit, knn_indices_dense_t, float>( - handle, input, k_graph, params, metric); + handle, input, k_graph, params); if (kl_div) { *kl_div = kl_div_v; } } @@ -66,8 +64,7 @@ void TSNE_fit_sparse(const raft::handle_t& handle, int* knn_indices, float* knn_dists, TSNEParams& params, - float* kl_div, - raft::distance::DistanceType metric) + float* kl_div) { ASSERT(n > 0 && p > 0 && params.dim > 0 && params.n_neighbors > 0 && indptr != NULL && indices != NULL && data != NULL && Y != NULL, @@ -77,7 +74,7 @@ void TSNE_fit_sparse(const raft::handle_t& handle, knn_graph k_graph(n, params.n_neighbors, knn_indices, knn_dists); float kl_div_v = _fit, knn_indices_sparse_t, float>( - handle, input, k_graph, params, metric); + handle, input, k_graph, params); if (kl_div) { *kl_div = kl_div_v; } } diff --git a/cpp/src/tsne/tsne_runner.cuh b/cpp/src/tsne/tsne_runner.cuh index 2724e0666d..24837c6802 100644 --- a/cpp/src/tsne/tsne_runner.cuh +++ b/cpp/src/tsne/tsne_runner.cuh @@ -36,14 +36,12 @@ class TSNE_runner { TSNE_runner(const raft::handle_t& handle_, tsne_input& input_, knn_graph& k_graph_, - TSNEParams& params_, - raft::distance::DistanceType metric_) + TSNEParams& params_) : handle(handle_), input(input_), k_graph(k_graph_), params(params_), - COO_Matrix(handle_.get_stream()), - metric(metric_) + COO_Matrix(handle_.get_stream()) { this->n = input.n; this->p = input.d; @@ -120,7 +118,7 @@ class TSNE_runner { k_graph.knn_indices = indices.data(); k_graph.knn_dists = distances.data(); - TSNE::get_distances(handle, input, k_graph, stream, metric); + TSNE::get_distances(handle, input, k_graph, stream, params.metric); } if (params.square_distances) { @@ -190,7 +188,6 @@ class TSNE_runner { tsne_input& input; knn_graph& k_graph; TSNEParams& params; - raft::distance::DistanceType metric; value_idx n, p; value_t* Y; diff --git a/cpp/test/sg/tsne_test.cu b/cpp/test/sg/tsne_test.cu index aeed31de23..f9b553816b 100644 --- a/cpp/test/sg/tsne_test.cu +++ b/cpp/test/sg/tsne_test.cu @@ -138,7 +138,7 @@ class TSNETest : public ::testing::TestWithParam { } handle.sync_stream(stream); TSNE_runner, knn_indices_dense_t, float> runner( - handle, input, k_graph, model_params, DEFAULT_DISTANCE_METRIC); + handle, input, k_graph, model_params); results.kl_div = runner.run(); // Compute embedding's pairwise distances diff --git a/python/cuml/manifold/t_sne.pyx b/python/cuml/manifold/t_sne.pyx index 937c60c4e8..551240fb9f 100644 --- a/python/cuml/manifold/t_sne.pyx +++ b/python/cuml/manifold/t_sne.pyx @@ -80,6 +80,7 @@ cdef extern from "cuml/manifold/tsne.h" namespace "ML": int verbosity, bool initialize_embeddings, bool square_distances, + DistanceType metric, TSNE_ALGORITHM algorithm @@ -94,7 +95,6 @@ cdef extern from "cuml/manifold/tsne.h" namespace "ML": int64_t* knn_indices, float* knn_dists, TSNEParams ¶ms, - DistanceType metric, float* kl_div) except + cdef void TSNE_fit_sparse( @@ -109,7 +109,6 @@ cdef extern from "cuml/manifold/tsne.h" namespace "ML": int* knn_indices, float* knn_dists, TSNEParams ¶ms, - DistanceType metric, float* kl_div) except + @@ -424,6 +423,7 @@ class TSNE(Base, convert_format=False) n, p = self.X_m.shape self.sparse_fit = True + # Handle dense inputs else: self.X_m, n, p, _ = \ @@ -496,35 +496,6 @@ class TSNE(Base, cdef float kl_divergence = 0 - # metric - metric_parsing = { - "l2": DistanceType.L2SqrtUnexpanded, - "euclidean": DistanceType.L2SqrtUnexpanded, - "sqeuclidean": DistanceType.L2Expanded, - "cityblock": DistanceType.L1, - "l1": DistanceType.L1, - "manhattan": DistanceType.L1, - "taxicab": DistanceType.L1, - "braycurtis": DistanceType.BrayCurtis, - "canberra": DistanceType.Canberra, - "minkowski": DistanceType.LpUnexpanded, - "lp": DistanceType.LpUnexpanded, - "chebyshev": DistanceType.Linf, - "linf": DistanceType.Linf, - "jensenshannon": DistanceType.JensenShannon, - "cosine": DistanceType.CosineExpanded, - "correlation": DistanceType.CorrelationExpanded, - "inner_product": DistanceType.InnerProduct, - "jaccard": DistanceType.JaccardExpanded, - "hellinger": DistanceType.HellingerExpanded, - "haversine": DistanceType.Haversine - } - if self.metric.lower() in metric_parsing: - metric = metric_parsing[self.metric.lower()] - else: - raise ValueError("Invalid value for metric: {}" - .format(self.metric)) - if self.sparse_fit: TSNE_fit_sparse(handle_[0], @@ -540,7 +511,6 @@ class TSNE(Base, knn_indices_raw, knn_dists_raw, deref(params), - metric, &kl_divergence) else: TSNE_fit(handle_[0], @@ -551,7 +521,6 @@ class TSNE(Base, knn_indices_raw, knn_dists_raw, deref(params), - metric, &kl_divergence) self.handle.sync() @@ -613,6 +582,29 @@ class TSNE(Base, params.initialize_embeddings = True params.square_distances = self.square_distances params.algorithm = algo + + # metric + metric_parsing = { + "l2": DistanceType.L2SqrtUnexpanded, + "euclidean": DistanceType.L2SqrtUnexpanded, + "sqeuclidean": DistanceType.L2Expanded, + "cityblock": DistanceType.L1, + "l1": DistanceType.L1, + "manhattan": DistanceType.L1, + "taxicab": DistanceType.L1, + "minkowski": DistanceType.LpUnexpanded, + "chebyshev": DistanceType.Linf, + "linf": DistanceType.Linf, + "cosine": DistanceType.CosineExpanded, + "correlation": DistanceType.CorrelationExpanded, + } + + if self.metric.lower() in metric_parsing: + params.metric = metric_parsing[self.metric.lower()] + else: + raise ValueError("Invalid value for metric: {}" + .format(self.metric)) + return params @property diff --git a/python/cuml/tests/test_tsne.py b/python/cuml/tests/test_tsne.py index b9b23e319c..ac3dfc9e47 100644 --- a/python/cuml/tests/test_tsne.py +++ b/python/cuml/tests/test_tsne.py @@ -26,7 +26,6 @@ from sklearn.manifold import trustworthiness from sklearn import datasets - pytestmark = pytest.mark.filterwarnings("ignore:Method 'fft' is " "experimental::") @@ -271,3 +270,74 @@ def test_tsne_knn_parameters_sparse(type_knn_graph, input_type, method): if input_type == 'cupy': Y = Y.get() validate_embedding(digits, Y, 0.85) + + +@pytest.mark.parametrize('dataset', test_datasets.values()) +@pytest.mark.parametrize('method', ['barnes_hut']) +@pytest.mark.parametrize('metric', ['l2', 'euclidean', 'sqeuclidean', 'cityblock', 'l1', + 'manhattan', 'minkowski', 'chebyshev', "cosine", "correlation"]) +def test_tsne_distance_metrics(dataset, method, metric): + """ + This tests how TSNE handles a lot of input data across time. + (1) Numpy arrays are passed in + (2) Params are changed in the TSNE class + (3) The class gets re-used across time + (4) Trustworthiness is checked + (5) Tests NAN in TSNE output for learning rate explosions + (6) Tests verbosity + """ + X = dataset.data + + tsne = TSNE(n_components=2, + random_state=1, + n_neighbors=DEFAULT_N_NEIGHBORS, + learning_rate_method='none', + method=method, + min_grad_norm=1e-12, + perplexity=DEFAULT_PERPLEXITY, + metric=metric) + + """Compares TSNE embedding trustworthiness, NAN and verbosity""" + Y = tsne.fit_transform(X) + nans = np.sum(np.isnan(Y)) + trust = trustworthiness(X, Y, n_neighbors=DEFAULT_N_NEIGHBORS, metric=metric) + + print("Trust=%s" % trust) + assert trust > 0.85 + assert nans == 0 + + +@pytest.mark.parametrize('input_type', ['cupy', 'scipy']) +@pytest.mark.parametrize('method', ['fft', 'barnes_hut']) +@pytest.mark.parametrize('metric', ['l2', 'euclidean', 'sqeuclidean', 'cityblock', 'l1', + 'manhattan', 'minkowski', 'chebyshev', "cosine", "correlation"]) +def test_tsne_fit_transform_on_digits_sparse_distance_metrics(input_type, method, metric): + + digits = test_datasets['digits'].data + + if input_type == 'cupy': + sp_prefix = cupyx.scipy.sparse + else: + sp_prefix = scipy.sparse + + fitter = TSNE(n_components=2, + random_state=1, + method=method, + min_grad_norm=1e-12, + n_neighbors=DEFAULT_N_NEIGHBORS, + learning_rate_method="none", + perplexity=DEFAULT_PERPLEXITY, + metric=metric) + + new_data = sp_prefix.csr_matrix( + scipy.sparse.csr_matrix(digits)).astype('float32') + + embedding = fitter.fit_transform(new_data, convert_dtype=True) + + if input_type == 'cupy': + embedding = embedding.get() + + trust = trustworthiness(digits, embedding, + n_neighbors=DEFAULT_N_NEIGHBORS, metric=metric) + assert trust >= 0.85 + \ No newline at end of file From 08c3669412f38918aff9d084b3ae8261bb91eb65 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Wed, 22 Jun 2022 13:43:22 -0700 Subject: [PATCH 05/17] Added minkowski distance parameter p --- cpp/include/cuml/manifold/tsne.h | 3 +++ cpp/include/cuml/manifold/umapparams.h | 2 ++ cpp/src/tsne/distances.cuh | 21 ++++++++++------ cpp/src/tsne/tsne_runner.cuh | 2 +- cpp/src/umap/knn_graph/algo.cuh | 6 +++-- cpp/test/sg/tsne_test.cu | 3 ++- python/cuml/manifold/t_sne.pyx | 9 +++++++ python/cuml/manifold/umap.pyx | 33 ++++++++++++++++---------- python/cuml/manifold/umap_utils.pxd | 4 +++- python/cuml/tests/test_umap.py | 20 ++++++++++++++++ 10 files changed, 79 insertions(+), 24 deletions(-) diff --git a/cpp/include/cuml/manifold/tsne.h b/cpp/include/cuml/manifold/tsne.h index 1526803d40..a0659cb96c 100644 --- a/cpp/include/cuml/manifold/tsne.h +++ b/cpp/include/cuml/manifold/tsne.h @@ -105,6 +105,9 @@ struct TSNEParams { //Distance metric to use. raft::distance::DistanceType metric = raft::distance::DistanceType::L2SqrtExpanded; + //Value of p for Minkowski distance + float p = 2.0; + // Which implementation algorithm to use. TSNE_ALGORITHM algorithm = TSNE_ALGORITHM::FFT; }; diff --git a/cpp/include/cuml/manifold/umapparams.h b/cpp/include/cuml/manifold/umapparams.h index 41e6e63f74..dc0fffbe55 100644 --- a/cpp/include/cuml/manifold/umapparams.h +++ b/cpp/include/cuml/manifold/umapparams.h @@ -160,6 +160,8 @@ class UMAPParams { raft::distance::DistanceType metric = raft::distance::DistanceType::L2SqrtExpanded; + float p = 2.0; + Internals::GraphBasedDimRedCallback* callback = nullptr; }; diff --git a/cpp/src/tsne/distances.cuh b/cpp/src/tsne/distances.cuh index e809280caf..1a01f4ca12 100644 --- a/cpp/src/tsne/distances.cuh +++ b/cpp/src/tsne/distances.cuh @@ -51,7 +51,8 @@ void get_distances(const raft::handle_t& handle, tsne_input& input, knn_graph& k_graph, cudaStream_t stream, - raft::distance::DistanceType metric); + raft::distance::DistanceType metric, + value_t p); // dense, int64 indices template <> @@ -59,7 +60,8 @@ void get_distances(const raft::handle_t& handle, manifold_dense_inputs_t& input, knn_graph& k_graph, cudaStream_t stream, - raft::distance::DistanceType metric) + raft::distance::DistanceType metric, + float p) { // TODO: for TSNE transform first fit some points then transform with 1/(1+d^2) // #861 @@ -87,7 +89,8 @@ void get_distances(const raft::handle_t& handle, true, true, nullptr, - metric); + metric, + p); } // dense, int32 indices @@ -96,7 +99,8 @@ void get_distances(const raft::handle_t& handle, manifold_dense_inputs_t& input, knn_graph& k_graph, cudaStream_t stream, - raft::distance::DistanceType metric) + raft::distance::DistanceType metric, + float p) { throw raft::exception("Dense TSNE does not support 32-bit integer indices yet."); } @@ -107,7 +111,8 @@ void get_distances(const raft::handle_t& handle, manifold_sparse_inputs_t& input, knn_graph& k_graph, cudaStream_t stream, - raft::distance::DistanceType metric) + raft::distance::DistanceType metric, + float p) { raft::sparse::selection::brute_force_knn(input.indptr, input.indices, @@ -127,7 +132,8 @@ void get_distances(const raft::handle_t& handle, handle, ML::Sparse::DEFAULT_BATCH_SIZE, ML::Sparse::DEFAULT_BATCH_SIZE, - metric); + metric, + p); } // sparse, int64 @@ -136,7 +142,8 @@ void get_distances(const raft::handle_t& handle, manifold_sparse_inputs_t& input, knn_graph& k_graph, cudaStream_t stream, - raft::distance::DistanceType metric) + raft::distance::DistanceType metric, + float p) { throw raft::exception("Sparse TSNE does not support 64-bit integer indices yet."); } diff --git a/cpp/src/tsne/tsne_runner.cuh b/cpp/src/tsne/tsne_runner.cuh index 24837c6802..0a5adb8bee 100644 --- a/cpp/src/tsne/tsne_runner.cuh +++ b/cpp/src/tsne/tsne_runner.cuh @@ -118,7 +118,7 @@ class TSNE_runner { k_graph.knn_indices = indices.data(); k_graph.knn_dists = distances.data(); - TSNE::get_distances(handle, input, k_graph, stream, params.metric); + TSNE::get_distances(handle, input, k_graph, stream, params.metric, params.p); } if (params.square_distances) { diff --git a/cpp/src/umap/knn_graph/algo.cuh b/cpp/src/umap/knn_graph/algo.cuh index e111a395b6..65918d88fb 100644 --- a/cpp/src/umap/knn_graph/algo.cuh +++ b/cpp/src/umap/knn_graph/algo.cuh @@ -70,7 +70,8 @@ void launcher(const raft::handle_t& handle, out.knn_indices, out.knn_dists, n_neighbors, - params->metric); + params->metric, + params->p); } // Instantiation for dense inputs, int indices @@ -113,7 +114,8 @@ void launcher(const raft::handle_t& handle, handle, ML::Sparse::DEFAULT_BATCH_SIZE, ML::Sparse::DEFAULT_BATCH_SIZE, - params->metric); + params->metric, + params->p); } template <> diff --git a/cpp/test/sg/tsne_test.cu b/cpp/test/sg/tsne_test.cu index f9b553816b..53d2590544 100644 --- a/cpp/test/sg/tsne_test.cu +++ b/cpp/test/sg/tsne_test.cu @@ -110,6 +110,7 @@ class TSNETest : public ::testing::TestWithParam { TSNEResults results; auto DEFAULT_DISTANCE_METRIC = raft::distance::DistanceType::L2SqrtExpanded; + float p = 2.0; // Setup parameters model_params.algorithm = algo; model_params.dim = 2; @@ -134,7 +135,7 @@ class TSNETest : public ::testing::TestWithParam { input_dists.resize(n * model_params.n_neighbors, stream); k_graph.knn_indices = input_indices.data(); k_graph.knn_dists = input_dists.data(); - TSNE::get_distances(handle, input, k_graph, stream, DEFAULT_DISTANCE_METRIC); + TSNE::get_distances(handle, input, k_graph, stream, DEFAULT_DISTANCE_METRIC, p); } handle.sync_stream(stream); TSNE_runner, knn_indices_dense_t, float> runner( diff --git a/python/cuml/manifold/t_sne.pyx b/python/cuml/manifold/t_sne.pyx index 551240fb9f..84966fa241 100644 --- a/python/cuml/manifold/t_sne.pyx +++ b/python/cuml/manifold/t_sne.pyx @@ -81,6 +81,7 @@ cdef extern from "cuml/manifold/tsne.h" namespace "ML": bool initialize_embeddings, bool square_distances, DistanceType metric, + float p, TSNE_ALGORITHM algorithm @@ -261,6 +262,7 @@ class TSNE(Base, n_iter_without_progress=300, min_grad_norm=1e-07, metric='euclidean', + metric_params=None, init='random', verbose=False, random_state=None, @@ -350,6 +352,7 @@ class TSNE(Base, self.n_iter_without_progress = n_iter_without_progress self.min_grad_norm = min_grad_norm self.metric = metric + self.metric_params = metric_params self.init = init self.random_state = random_state self.method = method @@ -597,6 +600,7 @@ class TSNE(Base, "linf": DistanceType.Linf, "cosine": DistanceType.CosineExpanded, "correlation": DistanceType.CorrelationExpanded, + "hellinger": DistanceType.HellingerExpanded } if self.metric.lower() in metric_parsing: @@ -604,6 +608,11 @@ class TSNE(Base, else: raise ValueError("Invalid value for metric: {}" .format(self.metric)) + + if self.metric_params is None: + params.p = 2.0 + else: + params.p = self.metric_params.get('p') return params diff --git a/python/cuml/manifold/umap.pyx b/python/cuml/manifold/umap.pyx index 0c7a9749d1..6f4724859a 100644 --- a/python/cuml/manifold/umap.pyx +++ b/python/cuml/manifold/umap.pyx @@ -152,8 +152,13 @@ class UMAP(Base, provide easy visualization, but can reasonably be set to any metric : string (default='euclidean'). Distance metric to use. Supported distances are ['l1, 'cityblock', - 'taxicab', 'manhattan', 'euclidean', 'l2', 'braycurtis', 'canberra', - 'minkowski', 'chebyshev', 'jensenshannon', 'cosine', 'correlation'] + 'taxicab', 'manhattan', 'euclidean', 'l2', 'canberra', 'minkowski', + 'chebyshev', 'linf', 'cosine', 'correlation', 'hellinger', 'hamming', + 'jaccard', 'canberra'] + Metrics that take arguments (such as minkowski) can have arguments + passed via the metric_kwds dictionary. At this time care must + be taken and dictionary elements must be ordered appropriately; + this will hopefully be fixed in the future. n_epochs: int (optional, default None) The number of training epochs to be used in optimizing the low dimensional embedding. Larger values result in more accurate @@ -304,6 +309,7 @@ class UMAP(Base, n_neighbors=15, n_components=2, metric="euclidean", + metric_kwds=None, n_epochs=None, learning_rate=1.0, min_dist=0.1, @@ -335,6 +341,7 @@ class UMAP(Base, self.n_neighbors = n_neighbors self.n_components = n_components self.metric = metric + self.metric_kwds = metric_kwds self.n_epochs = n_epochs if n_epochs else 0 if init == "spectral" or init == "random": @@ -426,35 +433,37 @@ class UMAP(Base, umap_params.random_state = cls.random_state umap_params.deterministic = cls.deterministic - # metric + # metric metric_parsing = { "l2": DistanceType.L2SqrtUnexpanded, "euclidean": DistanceType.L2SqrtUnexpanded, - "sqeuclidean": DistanceType.L2Expanded, + "sqeuclidean": DistanceType.L2Unexpanded, "cityblock": DistanceType.L1, "l1": DistanceType.L1, "manhattan": DistanceType.L1, "taxicab": DistanceType.L1, - "braycurtis": DistanceType.BrayCurtis, - "canberra": DistanceType.Canberra, "minkowski": DistanceType.LpUnexpanded, - "lp": DistanceType.LpUnexpanded, "chebyshev": DistanceType.Linf, "linf": DistanceType.Linf, - "jensenshannon": DistanceType.JensenShannon, "cosine": DistanceType.CosineExpanded, "correlation": DistanceType.CorrelationExpanded, - "inner_product": DistanceType.InnerProduct, - "jaccard": DistanceType.JaccardExpanded, "hellinger": DistanceType.HellingerExpanded, - "haversine": DistanceType.Haversine + "hamming": DistanceType.HammingUnexpanded, + "jaccard": DistanceType.JaccardExpanded, + "canberra": DistanceType.Canberra } + if cls.metric.lower() in metric_parsing: umap_params.metric = metric_parsing[cls.metric.lower()] else: raise ValueError("Invalid value for metric: {}" .format(cls.metric)) - + + if cls.metric_kwds is None: + umap_params.p = 2.0 + else: + umap_params.p = cls.metric_kwds.get('p') + cdef uintptr_t callback_ptr = 0 if cls.callback: callback_ptr = cls.callback.get_native_callback() diff --git a/python/cuml/manifold/umap_utils.pxd b/python/cuml/manifold/umap_utils.pxd index f4bebeb7b9..a9edd64ff8 100644 --- a/python/cuml/manifold/umap_utils.pxd +++ b/python/cuml/manifold/umap_utils.pxd @@ -22,7 +22,7 @@ from libcpp.memory cimport unique_ptr from libc.stdint cimport uint64_t, uintptr_t, int64_t from libcpp cimport bool from libcpp.memory cimport shared_ptr - +from cuml.metrics.distance_type cimport DistanceType cdef extern from "cuml/manifold/umapparams.h" namespace "ML::UMAPParams": @@ -58,6 +58,8 @@ cdef extern from "cuml/manifold/umapparams.h" namespace "ML": float target_weight, uint64_t random_state, bool deterministic, + DistanceType metric, + float p, GraphBasedDimRedCallback * callback cdef extern from "raft/sparse/coo.hpp": diff --git a/python/cuml/tests/test_umap.py b/python/cuml/tests/test_umap.py index e7dbf83748..f5f2ba3e78 100644 --- a/python/cuml/tests/test_umap.py +++ b/python/cuml/tests/test_umap.py @@ -584,3 +584,23 @@ def test_fuzzy_simplicial_set(n_rows, atol=0.1, rtol=0.2, threshold=0.95) + + +@pytest.mark.parametrize('metric', ['l2', 'euclidean', 'sqeuclidean', 'l1', 'manhattan', 'minkowski', + 'chebyshev', 'cosine', 'correlation', 'jaccard', 'hamming', 'canberra']) +def test_umap_distance_metrics_fit_transform_trust(metric): + data, labels = make_blobs(n_samples=1000, n_features=64, + centers=5, random_state=42) + + if metric == 'jaccard': + data = data >= 0 + + umap_model = umap.UMAP(n_neighbors=10, min_dist=0.01, metric=metric, init='random') + cuml_model = cuUMAP(n_neighbors=10, min_dist=0.01, metric=metric, init='random') + umap_embedding = umap_model.fit_transform(data) + cuml_embedding = cuml_model.fit_transform(data) + + umap_trust = trustworthiness(data, umap_embedding, n_neighbors=10, metric=metric) + cuml_trust = trustworthiness(data, cuml_embedding, n_neighbors=10, metric=metric) + + assert array_equal(umap_trust, cuml_trust, 0.05, with_sign=True) From 81fcd23d54134f75e05b8cbb0ca44f212363cf95 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Wed, 22 Jun 2022 13:45:34 -0700 Subject: [PATCH 06/17] Styling fixes --- cpp/include/cuml/manifold/tsne.h | 4 ++-- cpp/include/cuml/manifold/umap.hpp | 2 +- cpp/include/cuml/manifold/umapparams.h | 2 +- cpp/src/tsne/tsne_runner.cuh | 2 +- cpp/test/sg/tsne_test.cu | 4 ++-- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/cpp/include/cuml/manifold/tsne.h b/cpp/include/cuml/manifold/tsne.h index a0659cb96c..d9d3827bac 100644 --- a/cpp/include/cuml/manifold/tsne.h +++ b/cpp/include/cuml/manifold/tsne.h @@ -102,10 +102,10 @@ struct TSNEParams { // behavior of Scikit-learn's T-SNE. bool square_distances = true; - //Distance metric to use. + // Distance metric to use. raft::distance::DistanceType metric = raft::distance::DistanceType::L2SqrtExpanded; - //Value of p for Minkowski distance + // Value of p for Minkowski distance float p = 2.0; // Which implementation algorithm to use. diff --git a/cpp/include/cuml/manifold/umap.hpp b/cpp/include/cuml/manifold/umap.hpp index 6be69ff321..bdc704460e 100644 --- a/cpp/include/cuml/manifold/umap.hpp +++ b/cpp/include/cuml/manifold/umap.hpp @@ -20,8 +20,8 @@ #include #include -#include #include +#include namespace raft { class handle_t; diff --git a/cpp/include/cuml/manifold/umapparams.h b/cpp/include/cuml/manifold/umapparams.h index dc0fffbe55..f3e854c06e 100644 --- a/cpp/include/cuml/manifold/umapparams.h +++ b/cpp/include/cuml/manifold/umapparams.h @@ -157,7 +157,7 @@ class UMAPParams { higher memory usage but produce stable numeric output. */ bool deterministic = true; - + raft::distance::DistanceType metric = raft::distance::DistanceType::L2SqrtExpanded; float p = 2.0; diff --git a/cpp/src/tsne/tsne_runner.cuh b/cpp/src/tsne/tsne_runner.cuh index 0a5adb8bee..f4cec143e4 100644 --- a/cpp/src/tsne/tsne_runner.cuh +++ b/cpp/src/tsne/tsne_runner.cuh @@ -42,7 +42,7 @@ class TSNE_runner { k_graph(k_graph_), params(params_), COO_Matrix(handle_.get_stream()) - { + { this->n = input.n; this->p = input.d; this->Y = input.y; diff --git a/cpp/test/sg/tsne_test.cu b/cpp/test/sg/tsne_test.cu index 53d2590544..33479d85de 100644 --- a/cpp/test/sg/tsne_test.cu +++ b/cpp/test/sg/tsne_test.cu @@ -16,8 +16,8 @@ #include #include -#include #include +#include #include #include @@ -110,7 +110,7 @@ class TSNETest : public ::testing::TestWithParam { TSNEResults results; auto DEFAULT_DISTANCE_METRIC = raft::distance::DistanceType::L2SqrtExpanded; - float p = 2.0; + float p = 2.0; // Setup parameters model_params.algorithm = algo; model_params.dim = 2; From be0c6cc4825f9c0fea8948187978d8bbf928cd31 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Wed, 22 Jun 2022 13:53:28 -0700 Subject: [PATCH 07/17] Style fixes --- python/cuml/tests/test_umap.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/python/cuml/tests/test_umap.py b/python/cuml/tests/test_umap.py index f5f2ba3e78..bc774a1017 100644 --- a/python/cuml/tests/test_umap.py +++ b/python/cuml/tests/test_umap.py @@ -586,21 +586,27 @@ def test_fuzzy_simplicial_set(n_rows, threshold=0.95) -@pytest.mark.parametrize('metric', ['l2', 'euclidean', 'sqeuclidean', 'l1', 'manhattan', 'minkowski', - 'chebyshev', 'cosine', 'correlation', 'jaccard', 'hamming', 'canberra']) +@pytest.mark.parametrize('metric', ['l2', 'euclidean', 'sqeuclidean', 'l1', + 'manhattan', 'minkowski', 'chebyshev', + 'cosine', 'correlation', 'jaccard', + 'hamming', 'canberra']) def test_umap_distance_metrics_fit_transform_trust(metric): data, labels = make_blobs(n_samples=1000, n_features=64, - centers=5, random_state=42) - + centers=5, random_state=42) + if metric == 'jaccard': data = data >= 0 - umap_model = umap.UMAP(n_neighbors=10, min_dist=0.01, metric=metric, init='random') - cuml_model = cuUMAP(n_neighbors=10, min_dist=0.01, metric=metric, init='random') + umap_model = umap.UMAP(n_neighbors=10, min_dist=0.01, + metric=metric, init='random') + cuml_model = cuUMAP(n_neighbors=10, min_dist=0.01, + metric=metric, init='random') umap_embedding = umap_model.fit_transform(data) cuml_embedding = cuml_model.fit_transform(data) - umap_trust = trustworthiness(data, umap_embedding, n_neighbors=10, metric=metric) - cuml_trust = trustworthiness(data, cuml_embedding, n_neighbors=10, metric=metric) + umap_trust = trustworthiness(data, umap_embedding, + n_neighbors=10, metric=metric) + cuml_trust = trustworthiness(data, cuml_embedding, + n_neighbors=10, metric=metric) assert array_equal(umap_trust, cuml_trust, 0.05, with_sign=True) From 67d1625ad602a0cd336bc15dbfa4c085bea58de1 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Wed, 22 Jun 2022 17:23:16 -0700 Subject: [PATCH 08/17] styling and metric changes --- python/cuml/manifold/t_sne.pyx | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/python/cuml/manifold/t_sne.pyx b/python/cuml/manifold/t_sne.pyx index 84966fa241..18654d4360 100644 --- a/python/cuml/manifold/t_sne.pyx +++ b/python/cuml/manifold/t_sne.pyx @@ -426,7 +426,7 @@ class TSNE(Base, convert_format=False) n, p = self.X_m.shape self.sparse_fit = True - + # Handle dense inputs else: self.X_m, n, p, _ = \ @@ -594,13 +594,13 @@ class TSNE(Base, "cityblock": DistanceType.L1, "l1": DistanceType.L1, "manhattan": DistanceType.L1, - "taxicab": DistanceType.L1, "minkowski": DistanceType.LpUnexpanded, "chebyshev": DistanceType.Linf, - "linf": DistanceType.Linf, "cosine": DistanceType.CosineExpanded, "correlation": DistanceType.CorrelationExpanded, - "hellinger": DistanceType.HellingerExpanded + "jaccard": DistanceType.JaccardExpanded, + "canberra": DistanceType.Canberra, + "hamming": DistanceType.HammingUnexpanded } if self.metric.lower() in metric_parsing: @@ -608,7 +608,7 @@ class TSNE(Base, else: raise ValueError("Invalid value for metric: {}" .format(self.metric)) - + if self.metric_params is None: params.p = 2.0 else: From 453df40d529876b9a6a85a41700eacc725a897d7 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Wed, 22 Jun 2022 17:27:54 -0700 Subject: [PATCH 09/17] styling fixes (copyright) --- cpp/include/cuml/manifold/tsne.h | 2 +- cpp/include/cuml/manifold/umapparams.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/include/cuml/manifold/tsne.h b/cpp/include/cuml/manifold/tsne.h index d9d3827bac..ec4a577a4f 100644 --- a/cpp/include/cuml/manifold/tsne.h +++ b/cpp/include/cuml/manifold/tsne.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/include/cuml/manifold/umapparams.h b/cpp/include/cuml/manifold/umapparams.h index f3e854c06e..5bdcff6da4 100644 --- a/cpp/include/cuml/manifold/umapparams.h +++ b/cpp/include/cuml/manifold/umapparams.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. From f2be71f5aaef76146c275e8bb63cf3a69cbbfdf5 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Thu, 23 Jun 2022 10:12:26 -0700 Subject: [PATCH 10/17] update tsne tests --- python/cuml/tests/test_tsne.py | 37 +++++++++++++++++++++++----------- 1 file changed, 25 insertions(+), 12 deletions(-) diff --git a/python/cuml/tests/test_tsne.py b/python/cuml/tests/test_tsne.py index ac3dfc9e47..2ed19cc866 100644 --- a/python/cuml/tests/test_tsne.py +++ b/python/cuml/tests/test_tsne.py @@ -19,12 +19,13 @@ import cupyx from cuml.manifold import TSNE -from cuml.testing.utils import stress_param +from cuml.testing.utils import array_equal, stress_param from cuml.neighbors import NearestNeighbors as cuKNN from sklearn.datasets import make_blobs from sklearn.manifold import trustworthiness from sklearn import datasets +from sklearn.manifold import TSNE as skTSNE pytestmark = pytest.mark.filterwarnings("ignore:Method 'fft' is " "experimental::") @@ -273,9 +274,11 @@ def test_tsne_knn_parameters_sparse(type_knn_graph, input_type, method): @pytest.mark.parametrize('dataset', test_datasets.values()) -@pytest.mark.parametrize('method', ['barnes_hut']) -@pytest.mark.parametrize('metric', ['l2', 'euclidean', 'sqeuclidean', 'cityblock', 'l1', - 'manhattan', 'minkowski', 'chebyshev', "cosine", "correlation"]) +@pytest.mark.parametrize('method', ['exact', 'barnes_hut']) +@pytest.mark.parametrize('metric', ['l2', 'euclidean', 'sqeuclidean', 'cityblock', + 'l1', 'manhattan', 'minkowski', 'chebyshev', + 'cosine', 'correlation', 'jaccard', + 'hamming', 'canberra']) def test_tsne_distance_metrics(dataset, method, metric): """ This tests how TSNE handles a lot of input data across time. @@ -291,21 +294,31 @@ def test_tsne_distance_metrics(dataset, method, metric): tsne = TSNE(n_components=2, random_state=1, n_neighbors=DEFAULT_N_NEIGHBORS, - learning_rate_method='none', method=method, + learning_rate_method='none', min_grad_norm=1e-12, perplexity=DEFAULT_PERPLEXITY, metric=metric) - """Compares TSNE embedding trustworthiness, NAN and verbosity""" - Y = tsne.fit_transform(X) - nans = np.sum(np.isnan(Y)) - trust = trustworthiness(X, Y, n_neighbors=DEFAULT_N_NEIGHBORS, metric=metric) + sk_tsne = skTSNE(n_components=2, + random_state=1, + min_grad_norm=1e-12, + method=method, + perplexity=DEFAULT_PERPLEXITY, + metric=metric) - print("Trust=%s" % trust) - assert trust > 0.85 + """Compares TSNE embedding trustworthiness, NAN and verbosity""" + cuml_embedding = tsne.fit_transform(X) + sk_embedding = sk_tsne.fit_transform(X) + nans = np.sum(np.isnan(cuml_embedding)) + cuml_trust = trustworthiness(X, cuml_embedding, + n_neighbors=DEFAULT_N_NEIGHBORS, metric=metric) + sk_trust = trustworthiness(X, sk_embedding, + n_neighbors=DEFAULT_N_NEIGHBORS, metric=metric) + + assert cuml_trust > 0.85 assert nans == 0 - + assert array_equal(sk_trust, cuml_trust, 0.05, with_sign=True) @pytest.mark.parametrize('input_type', ['cupy', 'scipy']) @pytest.mark.parametrize('method', ['fft', 'barnes_hut']) From d3120a5739be8d44c8247b5457744f3a570fa795 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Thu, 23 Jun 2022 16:55:26 -0700 Subject: [PATCH 11/17] Re-evaluate supported distance metrics, update tests --- python/cuml/manifold/t_sne.pyx | 12 ++-- python/cuml/manifold/umap.pyx | 6 +- python/cuml/tests/test_tsne.py | 122 ++++++++++++++++----------------- python/cuml/tests/test_umap.py | 31 +++++++++ 4 files changed, 99 insertions(+), 72 deletions(-) diff --git a/python/cuml/manifold/t_sne.pyx b/python/cuml/manifold/t_sne.pyx index 18654d4360..b7f6c0bef6 100644 --- a/python/cuml/manifold/t_sne.pyx +++ b/python/cuml/manifold/t_sne.pyx @@ -154,9 +154,10 @@ class TSNE(Base, Used in the 'exact' and 'fft' algorithms. Consider reducing if the embeddings are unsatisfactory. It's recommended to use a smaller value for smaller datasets. - metric : str 'euclidean' only (default 'euclidean') - Currently only supports euclidean distance. Will support cosine in - a future release. + metric : str (default='euclidean'). + Distance metric to use. Supported distances are ['l1, 'cityblock', + 'manhattan', 'euclidean', 'l2', 'sqeuclidean', 'minkowski', + 'chebyshev', 'cosine', 'correlation'] init : str 'random' (default 'random') Currently supports random intialization. verbose : int or boolean, default=False @@ -597,10 +598,7 @@ class TSNE(Base, "minkowski": DistanceType.LpUnexpanded, "chebyshev": DistanceType.Linf, "cosine": DistanceType.CosineExpanded, - "correlation": DistanceType.CorrelationExpanded, - "jaccard": DistanceType.JaccardExpanded, - "canberra": DistanceType.Canberra, - "hamming": DistanceType.HammingUnexpanded + "correlation": DistanceType.CorrelationExpanded } if self.metric.lower() in metric_parsing: diff --git a/python/cuml/manifold/umap.pyx b/python/cuml/manifold/umap.pyx index 6f4724859a..4b3cf9726a 100644 --- a/python/cuml/manifold/umap.pyx +++ b/python/cuml/manifold/umap.pyx @@ -152,9 +152,9 @@ class UMAP(Base, provide easy visualization, but can reasonably be set to any metric : string (default='euclidean'). Distance metric to use. Supported distances are ['l1, 'cityblock', - 'taxicab', 'manhattan', 'euclidean', 'l2', 'canberra', 'minkowski', - 'chebyshev', 'linf', 'cosine', 'correlation', 'hellinger', 'hamming', - 'jaccard', 'canberra'] + 'taxicab', 'manhattan', 'euclidean', 'l2', 'sqeuclidean', 'canberra', + 'minkowski', 'chebyshev', 'linf', 'cosine', 'correlation', 'hellinger', + 'hamming', 'jaccard'] Metrics that take arguments (such as minkowski) can have arguments passed via the metric_kwds dictionary. At this time care must be taken and dictionary elements must be ordered appropriately; diff --git a/python/cuml/tests/test_tsne.py b/python/cuml/tests/test_tsne.py index 2ed19cc866..bd9a58107a 100644 --- a/python/cuml/tests/test_tsne.py +++ b/python/cuml/tests/test_tsne.py @@ -273,84 +273,82 @@ def test_tsne_knn_parameters_sparse(type_knn_graph, input_type, method): validate_embedding(digits, Y, 0.85) -@pytest.mark.parametrize('dataset', test_datasets.values()) -@pytest.mark.parametrize('method', ['exact', 'barnes_hut']) -@pytest.mark.parametrize('metric', ['l2', 'euclidean', 'sqeuclidean', 'cityblock', - 'l1', 'manhattan', 'minkowski', 'chebyshev', - 'cosine', 'correlation', 'jaccard', - 'hamming', 'canberra']) -def test_tsne_distance_metrics(dataset, method, metric): - """ - This tests how TSNE handles a lot of input data across time. - (1) Numpy arrays are passed in - (2) Params are changed in the TSNE class - (3) The class gets re-used across time - (4) Trustworthiness is checked - (5) Tests NAN in TSNE output for learning rate explosions - (6) Tests verbosity - """ - X = dataset.data +@pytest.mark.parametrize('metric', ['l2', 'euclidean', 'sqeuclidean', + 'cityblock', 'l1', 'manhattan', + 'minkowski', 'chebyshev', + 'cosine', 'correlation']) +def test_tsne_distance_metrics(metric): + + data, labels = make_blobs(n_samples=1000, n_features=64, + centers=5, random_state=42) tsne = TSNE(n_components=2, random_state=1, n_neighbors=DEFAULT_N_NEIGHBORS, - method=method, + method='exact', learning_rate_method='none', min_grad_norm=1e-12, perplexity=DEFAULT_PERPLEXITY, metric=metric) - - sk_tsne = skTSNE(n_components=2, - random_state=1, - min_grad_norm=1e-12, - method=method, - perplexity=DEFAULT_PERPLEXITY, - metric=metric) - """Compares TSNE embedding trustworthiness, NAN and verbosity""" - cuml_embedding = tsne.fit_transform(X) - sk_embedding = sk_tsne.fit_transform(X) + sk_tsne = skTSNE(n_components=2, + random_state=1, + min_grad_norm=1e-12, + method='exact', + perplexity=DEFAULT_PERPLEXITY, + metric=metric) + + cuml_embedding = tsne.fit_transform(data) + sk_embedding = sk_tsne.fit_transform(data) nans = np.sum(np.isnan(cuml_embedding)) - cuml_trust = trustworthiness(X, cuml_embedding, - n_neighbors=DEFAULT_N_NEIGHBORS, metric=metric) - sk_trust = trustworthiness(X, sk_embedding, - n_neighbors=DEFAULT_N_NEIGHBORS, metric=metric) + cuml_trust = trustworthiness(data, cuml_embedding, metric=metric) + sk_trust = trustworthiness(data, sk_embedding, metric=metric) assert cuml_trust > 0.85 assert nans == 0 assert array_equal(sk_trust, cuml_trust, 0.05, with_sign=True) -@pytest.mark.parametrize('input_type', ['cupy', 'scipy']) -@pytest.mark.parametrize('method', ['fft', 'barnes_hut']) -@pytest.mark.parametrize('metric', ['l2', 'euclidean', 'sqeuclidean', 'cityblock', 'l1', - 'manhattan', 'minkowski', 'chebyshev', "cosine", "correlation"]) -def test_tsne_fit_transform_on_digits_sparse_distance_metrics(input_type, method, metric): - digits = test_datasets['digits'].data +@pytest.mark.parametrize('method', ['fft', 'barnes_hut', 'exact']) +@pytest.mark.parametrize('metric', ['l2', 'euclidean', 'cityblock', + 'l1', 'manhattan', 'cosine']) +def test_tsne_distance_metrics_on_sparse_input(method, metric): + + data, labels = make_blobs(n_samples=1000, n_features=64, + centers=5, random_state=42) + data_sparse = scipy.sparse.csr_matrix(data) + + cuml_tsne = TSNE(n_components=2, + random_state=1, + n_neighbors=DEFAULT_N_NEIGHBORS, + method=method, + learning_rate_method='none', + min_grad_norm=1e-12, + perplexity=DEFAULT_PERPLEXITY, + metric=metric) + + if method == 'fft': + sk_tsne = skTSNE(n_components=2, + random_state=1, + min_grad_norm=1e-12, + method='barnes_hut', + perplexity=DEFAULT_PERPLEXITY, + metric=metric) - if input_type == 'cupy': - sp_prefix = cupyx.scipy.sparse else: - sp_prefix = scipy.sparse - - fitter = TSNE(n_components=2, - random_state=1, - method=method, - min_grad_norm=1e-12, - n_neighbors=DEFAULT_N_NEIGHBORS, - learning_rate_method="none", - perplexity=DEFAULT_PERPLEXITY, - metric=metric) - - new_data = sp_prefix.csr_matrix( - scipy.sparse.csr_matrix(digits)).astype('float32') - - embedding = fitter.fit_transform(new_data, convert_dtype=True) - - if input_type == 'cupy': - embedding = embedding.get() + sk_tsne = skTSNE(n_components=2, + random_state=1, + min_grad_norm=1e-12, + method=method, + perplexity=DEFAULT_PERPLEXITY, + metric=metric) + + cuml_embedding = cuml_tsne.fit_transform(data_sparse) + nans = np.sum(np.isnan(cuml_embedding)) + sk_embedding = sk_tsne.fit_transform(data_sparse) + cu_trust = trustworthiness(data, cuml_embedding, metric=metric) + sk_trust = trustworthiness(data, sk_embedding, metric=metric) - trust = trustworthiness(digits, embedding, - n_neighbors=DEFAULT_N_NEIGHBORS, metric=metric) - assert trust >= 0.85 - \ No newline at end of file + assert cu_trust > 0.85 + assert nans == 0 + assert array_equal(sk_trust, cu_trust, 0.05, with_sign=True) diff --git a/python/cuml/tests/test_umap.py b/python/cuml/tests/test_umap.py index 4fcd442d96..51f7ef127f 100644 --- a/python/cuml/tests/test_umap.py +++ b/python/cuml/tests/test_umap.py @@ -594,3 +594,34 @@ def test_umap_distance_metrics_fit_transform_trust(metric): n_neighbors=10, metric=metric) assert array_equal(umap_trust, cuml_trust, 0.05, with_sign=True) + + +@pytest.mark.parametrize('metric', ['euclidean', 'l1', 'manhattan', + 'minkowski', 'chebyshev', + 'cosine', 'correlation', 'jaccard', + 'hamming', 'canberra']) +def test_umap_distance_metrics_fit_transform_trust_on_sparse_input(metric): + data, labels = make_blobs(n_samples=1000, n_features=64, + centers=5, random_state=42) + + data_selection = np.random.RandomState(42).choice( + [True, False], 1000, replace=True, p=[0.75, 0.25]) + + if metric == 'jaccard': + data = data >= 0 + + new_data = scipy.sparse.csr_matrix(data[~data_selection]) + + umap_model = umap.UMAP(n_neighbors=10, min_dist=0.01, + metric=metric, init='random') + cuml_model = cuUMAP(n_neighbors=10, min_dist=0.01, + metric=metric, init='random') + umap_embedding = umap_model.fit_transform(new_data) + cuml_embedding = cuml_model.fit_transform(new_data) + + umap_trust = trustworthiness(data[~data_selection], umap_embedding, + n_neighbors=10, metric=metric) + cuml_trust = trustworthiness(data[~data_selection], cuml_embedding, + n_neighbors=10, metric=metric) + + assert array_equal(umap_trust, cuml_trust, 0.05, with_sign=True) From f9ef5dfe1490e649e49134f51012e98ced8640eb Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Thu, 7 Jul 2022 06:20:51 -0700 Subject: [PATCH 12/17] Update UMAP metric docs --- python/cuml/manifold/umap.pyx | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/python/cuml/manifold/umap.pyx b/python/cuml/manifold/umap.pyx index 4b3cf9726a..92249a0dbc 100644 --- a/python/cuml/manifold/umap.pyx +++ b/python/cuml/manifold/umap.pyx @@ -156,9 +156,7 @@ class UMAP(Base, 'minkowski', 'chebyshev', 'linf', 'cosine', 'correlation', 'hellinger', 'hamming', 'jaccard'] Metrics that take arguments (such as minkowski) can have arguments - passed via the metric_kwds dictionary. At this time care must - be taken and dictionary elements must be ordered appropriately; - this will hopefully be fixed in the future. + passed via the metric_kwds dictionary. n_epochs: int (optional, default None) The number of training epochs to be used in optimizing the low dimensional embedding. Larger values result in more accurate From a82c0819e3afc63c4743e9afbc15fc36d613736e Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Mon, 18 Jul 2022 09:02:58 -0700 Subject: [PATCH 13/17] correction in TSNE gtest --- cpp/test/sg/tsne_test.cu | 5 +++-- python/cuml/manifold/t_sne.pyx | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/cpp/test/sg/tsne_test.cu b/cpp/test/sg/tsne_test.cu index ec1002a6e3..dfd17ca0ae 100644 --- a/cpp/test/sg/tsne_test.cu +++ b/cpp/test/sg/tsne_test.cu @@ -111,7 +111,8 @@ class TSNETest : public ::testing::TestWithParam { TSNEResults results; auto DEFAULT_DISTANCE_METRIC = raft::distance::DistanceType::L2SqrtExpanded; - float p = 2.0; + float minkowski_p = 2.0; + // Setup parameters model_params.algorithm = algo; model_params.dim = 2; @@ -136,7 +137,7 @@ class TSNETest : public ::testing::TestWithParam { input_dists.resize(n * model_params.n_neighbors, stream); k_graph.knn_indices = input_indices.data(); k_graph.knn_dists = input_dists.data(); - TSNE::get_distances(handle, input, k_graph, stream, DEFAULT_DISTANCE_METRIC, p); + TSNE::get_distances(handle, input, k_graph, stream, DEFAULT_DISTANCE_METRIC, minkowski_p); } handle.sync_stream(stream); TSNE_runner, knn_indices_dense_t, float> runner( diff --git a/python/cuml/manifold/t_sne.pyx b/python/cuml/manifold/t_sne.pyx index b7f6c0bef6..64c9367cb1 100644 --- a/python/cuml/manifold/t_sne.pyx +++ b/python/cuml/manifold/t_sne.pyx @@ -589,8 +589,8 @@ class TSNE(Base, # metric metric_parsing = { - "l2": DistanceType.L2SqrtUnexpanded, - "euclidean": DistanceType.L2SqrtUnexpanded, + "l2": DistanceType.L2SqrtExpanded, + "euclidean": DistanceType.L2SqrtExpanded, "sqeuclidean": DistanceType.L2Expanded, "cityblock": DistanceType.L1, "l1": DistanceType.L1, From 628c0335deb4c6afde7ba38fd10057ec498600e4 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Mon, 18 Jul 2022 09:10:15 -0700 Subject: [PATCH 14/17] Style fix --- cpp/test/sg/tsne_test.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/test/sg/tsne_test.cu b/cpp/test/sg/tsne_test.cu index dfd17ca0ae..8c310d4400 100644 --- a/cpp/test/sg/tsne_test.cu +++ b/cpp/test/sg/tsne_test.cu @@ -111,7 +111,7 @@ class TSNETest : public ::testing::TestWithParam { TSNEResults results; auto DEFAULT_DISTANCE_METRIC = raft::distance::DistanceType::L2SqrtExpanded; - float minkowski_p = 2.0; + float minkowski_p = 2.0; // Setup parameters model_params.algorithm = algo; From e9c9195bf4c1b10377eb77f3746a1c2385858cbd Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Sat, 23 Jul 2022 12:00:38 -0700 Subject: [PATCH 15/17] Fix failing tests in CI --- python/cuml/manifold/t_sne.pyx | 1 + python/cuml/manifold/umap.pyx | 2 ++ 2 files changed, 3 insertions(+) diff --git a/python/cuml/manifold/t_sne.pyx b/python/cuml/manifold/t_sne.pyx index 64c9367cb1..5309158490 100644 --- a/python/cuml/manifold/t_sne.pyx +++ b/python/cuml/manifold/t_sne.pyx @@ -654,6 +654,7 @@ class TSNE(Base, "n_iter_without_progress", "min_grad_norm", "metric", + "metric_params", "init", "random_state", "method", diff --git a/python/cuml/manifold/umap.pyx b/python/cuml/manifold/umap.pyx index 92249a0dbc..860da1d158 100644 --- a/python/cuml/manifold/umap.pyx +++ b/python/cuml/manifold/umap.pyx @@ -811,4 +811,6 @@ class UMAP(Base, "hash_input", "random_state", "callback", + "metric", + "metric_kwds" ] From f74dc6f1401fbff1d19cb2a735850fd21c908014 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Mon, 25 Jul 2022 06:50:36 -0700 Subject: [PATCH 16/17] Update documentation based on PR Reviews --- python/cuml/manifold/t_sne.pyx | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/cuml/manifold/t_sne.pyx b/python/cuml/manifold/t_sne.pyx index 5309158490..9956a1a027 100644 --- a/python/cuml/manifold/t_sne.pyx +++ b/python/cuml/manifold/t_sne.pyx @@ -193,11 +193,13 @@ class TSNE(Base, During the late phases, less forcefully apply gradients. square_distances : boolean, default=True Whether TSNE should square the distance values. - Internally, this will be used to compute a kNN graph using 'euclidean' + Internally, this will be used to compute a kNN graph using the provided metric and then squaring it when True. If a `knn_graph` is passed to `fit` or `fit_transform` methods, all the distances will be squared when True. For example, if a `knn_graph` was obtained using 'sqeuclidean' metric, the distances will still be squared when True. + Note: This argument should likely be set to False for distance metrics + other than 'euclidean' and 'l2'. handle : cuml.Handle Specifies the cuml.handle that holds internal CUDA state for computations in this model. Most importantly, this specifies the CUDA From 43b2de5b694ca8cb292be37393b80e9894667da6 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Wed, 3 Aug 2022 16:34:36 -0700 Subject: [PATCH 17/17] Updated docs for CI failure --- cpp/include/cuml/manifold/tsne.h | 1 - 1 file changed, 1 deletion(-) diff --git a/cpp/include/cuml/manifold/tsne.h b/cpp/include/cuml/manifold/tsne.h index ec4a577a4f..127886bc9b 100644 --- a/cpp/include/cuml/manifold/tsne.h +++ b/cpp/include/cuml/manifold/tsne.h @@ -125,7 +125,6 @@ struct TSNEParams { * @param[in] knn_dists Array containing nearest neighors distances. * @param[in] params Parameters for TSNE model * @param[out] kl_div (optional) KL divergence output - * @param[in] metric Distance metric * * The CUDA implementation is derived from the excellent CannyLabs open source * implementation here: https://github.com/CannyLab/tsne-cuda/. The CannyLabs