diff --git a/cpp/include/cuml/neighbors/knn.hpp b/cpp/include/cuml/neighbors/knn.hpp
index 32181e7b76..d468cb0e29 100644
--- a/cpp/include/cuml/neighbors/knn.hpp
+++ b/cpp/include/cuml/neighbors/knn.hpp
@@ -16,6 +16,9 @@
 
 #pragma once
 
+#include <faiss/gpu/GpuIndex.h>
+#include <faiss/gpu/GpuIndicesOptions.h>
+#include <faiss/gpu/StandardGpuResources.h>
 #include
 #include
 
@@ -36,29 +39,71 @@ enum MetricType {
   METRIC_Correlation
 };
 
+struct knnIndex {
+  faiss::gpu::StandardGpuResources *gpu_res;
+  faiss::gpu::GpuIndex *index;
+  int device;
+  ~knnIndex() {
+    delete gpu_res;
+    delete index;
+  }
+};
+
+typedef enum {
+  QT_8bit,
+  QT_4bit,
+  QT_8bit_uniform,
+  QT_4bit_uniform,
+  QT_fp16,
+  QT_8bit_direct,
+  QT_6bit
+} QuantizerType;
+
+struct knnIndexParam {
+  virtual ~knnIndexParam() {}
+};
+
+struct IVFParam : knnIndexParam {
+  int nlist;
+  int nprobe;
+};
+
+struct IVFFlatParam : IVFParam {};
+
+struct IVFPQParam : IVFParam {
+  int M;
+  int n_bits;
+  bool usePrecomputedTables;
+};
+
+struct IVFSQParam : IVFParam {
+  QuantizerType qtype;
+  bool encodeResidual;
+};
+
 /**
  * @brief Flat C++ API function to perform a brute force knn on
  * a series of input arrays and combine the results into a single
  * output array for indexes and distances.
  *
  * @param[in] handle the cuml handle to use
 * @param[in] input vector of pointers to the input arrays
 * @param[in] sizes vector of sizes of input arrays
 * @param[in] D the dimensionality of the arrays
 * @param[in] search_items array of items to search of dimensionality D
 * @param[in] n number of rows in search_items
 * @param[out] res_I the resulting index array of size n * k
 * @param[out] res_D the resulting distance array of size n * k
 * @param[in] k the number of nearest neighbors to return
 * @param[in] rowMajorIndex are the index arrays in row-major order?
 * @param[in] rowMajorQuery are the query arrays in row-major order?
 * @param[in] metric distance metric to use. Euclidean (L2) is used by
 * default
 * @param[in] metric_arg the value of `p` for Minkowski (l-p) distances. This
 *   is ignored if the metric_type is not Minkowski.
 * @param[in] expanded should lp-based distances be returned in their expanded
 *   form (e.g., without raising to the 1/p power).
 */
 void brute_force_knn(raft::handle_t &handle, std::vector<float *> &input,
                      std::vector<int> &sizes, int D, float *search_items,
                      int n, int64_t *res_I, float *res_D, int k,
@@ -66,6 +111,14 @@ void brute_force_knn(raft::handle_t &handle, std::vector<float *> &input,
                      MetricType metric = MetricType::METRIC_L2,
                      float metric_arg = 2.0f, bool expanded = false);
 
+void approx_knn_build_index(raft::handle_t &handle, ML::knnIndex *index,
+                            ML::knnIndexParam *params, int D,
+                            ML::MetricType metric, float metricArg,
+                            float *index_items, int n);
+
+void approx_knn_search(ML::knnIndex *index, int n, const float *x, int k,
+                       float *distances, int64_t *labels);
+
 /**
  * @brief Flat C++ API function to perform a knn classification using a
  * given vector of label arrays. This supports multilabel classification
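The param structs above map one-to-one onto the algo_params dictionaries accepted by the Python layer further down in this diff. A minimal sketch of the matching dictionaries (values are illustrative, not tuned):

    # One dictionary per index type; keys mirror the C++ struct fields.
    ivfflat_params = {'nlist': 8, 'nprobe': 2}
    ivfpq_params = {'nlist': 8, 'nprobe': 2, 'M': 16, 'n_bits': 4,
                    'usePrecomputedTables': False}
    ivfsq_params = {'nlist': 8, 'nprobe': 2, 'qtype': 'QT_8bit',
                    'encodeResidual': True}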
diff --git a/cpp/src/knn/knn.cu b/cpp/src/knn/knn.cu
index 16b081cbc9..4055d61a82 100644
--- a/cpp/src/knn/knn.cu
+++ b/cpp/src/knn/knn.cu
@@ -49,6 +49,19 @@ void brute_force_knn(raft::handle_t &handle, std::vector<float *> &input,
                   metric, metric_arg, expanded);
 }
 
+void approx_knn_build_index(raft::handle_t &handle, ML::knnIndex *index,
+                            ML::knnIndexParam *params, int D,
+                            ML::MetricType metric, float metricArg,
+                            float *index_items, int n) {
+  MLCommon::Selection::approx_knn_build_index(
+    index, params, D, metric, metricArg, index_items, n, handle.get_stream());
+}
+
+void approx_knn_search(ML::knnIndex *index, int n, const float *x, int k,
+                       float *distances, int64_t *labels) {
+  MLCommon::Selection::approx_knn_search(index, n, x, k, distances, labels);
+}
+
 void knn_classify(raft::handle_t &handle, int *out, int64_t *knn_indices,
                   std::vector<int *> &y, size_t n_index_rows,
                   size_t n_query_rows, int k) {
diff --git a/cpp/src_prims/selection/knn.cuh b/cpp/src_prims/selection/knn.cuh
index c7ea416911..4d3a0f5dc7 100644
--- a/cpp/src_prims/selection/knn.cuh
+++ b/cpp/src_prims/selection/knn.cuh
@@ -24,6 +24,9 @@
 #include
 #include
+#include <faiss/gpu/GpuIndexIVFFlat.h>
+#include <faiss/gpu/GpuIndexIVFPQ.h>
+#include <faiss/gpu/GpuIndexIVFScalarQuantizer.h>
 #include
 #include
 #include
@@ -41,6 +44,7 @@
 #include
 #include
+#include <raft/cudart_utils.h>
 
 namespace MLCommon {
 namespace Selection {
@@ -193,32 +197,141 @@ inline faiss::MetricType build_faiss_metric(ML::MetricType metric) {
   }
 }
 
+inline faiss::ScalarQuantizer::QuantizerType build_faiss_qtype(
+  ML::QuantizerType qtype) {
+  switch (qtype) {
+    case ML::QuantizerType::QT_8bit:
+      return faiss::ScalarQuantizer::QuantizerType::QT_8bit;
+    case ML::QuantizerType::QT_8bit_uniform:
+      return faiss::ScalarQuantizer::QuantizerType::QT_8bit_uniform;
+    case ML::QuantizerType::QT_4bit_uniform:
+      return faiss::ScalarQuantizer::QuantizerType::QT_4bit_uniform;
+    case ML::QuantizerType::QT_fp16:
+      return faiss::ScalarQuantizer::QuantizerType::QT_fp16;
+    case ML::QuantizerType::QT_8bit_direct:
+      return faiss::ScalarQuantizer::QuantizerType::QT_8bit_direct;
+    case ML::QuantizerType::QT_6bit:
+      return faiss::ScalarQuantizer::QuantizerType::QT_6bit;
+    default:
+      return (faiss::ScalarQuantizer::QuantizerType)qtype;
+  }
+}
+
+template <typename IntType = int>
+void approx_knn_ivfflat_build_index(ML::knnIndex *index, ML::IVFParam *params,
+                                    IntType D, ML::MetricType metric,
+                                    IntType n) {
+  faiss::gpu::GpuIndexIVFFlatConfig config;
+  config.device = index->device;
+  faiss::MetricType faiss_metric = build_faiss_metric(metric);
+  faiss::gpu::GpuIndexIVFFlat *faiss_index = new faiss::gpu::GpuIndexIVFFlat(
+    index->gpu_res, D, params->nlist, faiss_metric, config);
+  faiss_index->setNumProbes(params->nprobe);
+  index->index = faiss_index;
+}
+
+template <typename IntType = int>
+void approx_knn_ivfpq_build_index(ML::knnIndex *index, ML::IVFPQParam *params,
+                                  IntType D, ML::MetricType metric,
+                                  IntType n) {
+  faiss::gpu::GpuIndexIVFPQConfig config;
+  config.device = index->device;
+  config.usePrecomputedTables = params->usePrecomputedTables;
+  faiss::MetricType faiss_metric = build_faiss_metric(metric);
+  faiss::gpu::GpuIndexIVFPQ *faiss_index =
+    new faiss::gpu::GpuIndexIVFPQ(index->gpu_res, D, params->nlist, params->M,
+                                  params->n_bits, faiss_metric, config);
+  faiss_index->setNumProbes(params->nprobe);
+  index->index = faiss_index;
+}
+
+template <typename IntType = int>
+void approx_knn_ivfsq_build_index(ML::knnIndex *index, ML::IVFSQParam *params,
+                                  IntType D, ML::MetricType metric,
+                                  IntType n) {
+  faiss::gpu::GpuIndexIVFScalarQuantizerConfig config;
+  config.device = index->device;
+  faiss::MetricType faiss_metric = build_faiss_metric(metric);
+  faiss::ScalarQuantizer::QuantizerType faiss_qtype =
+    build_faiss_qtype(params->qtype);
+  faiss::gpu::GpuIndexIVFScalarQuantizer *faiss_index =
+    new faiss::gpu::GpuIndexIVFScalarQuantizer(index->gpu_res, D,
+                                               params->nlist, faiss_qtype,
+                                               faiss_metric,
+                                               params->encodeResidual);
+  faiss_index->setNumProbes(params->nprobe);
+  index->index = faiss_index;
+}
+
+template <typename IntType = int>
+void approx_knn_build_index(ML::knnIndex *index, ML::knnIndexParam *params,
+                            IntType D, ML::MetricType metric, float metricArg,
+                            float *index_items, IntType n,
+                            cudaStream_t userStream) {
+  int device;
+  CUDA_CHECK(cudaGetDevice(&device));
+
+  faiss::gpu::StandardGpuResources *gpu_res =
+    new faiss::gpu::StandardGpuResources();
+  gpu_res->noTempMemory();
+  gpu_res->setCudaMallocWarning(false);
+  gpu_res->setDefaultStream(device, userStream);
+  index->gpu_res = gpu_res;
+  index->device = device;
+  index->index = nullptr;
+
+  if (dynamic_cast<ML::IVFFlatParam *>(params)) {
+    ML::IVFFlatParam *IVFFlat_param = dynamic_cast<ML::IVFFlatParam *>(params);
+    approx_knn_ivfflat_build_index(index, IVFFlat_param, D, metric, n);
+    std::vector<float> h_index_items(n * D);
+    raft::update_host(h_index_items.data(), index_items, h_index_items.size(),
+                      userStream);
+    // The device-to-host copy above is asynchronous: wait for it to finish
+    // before FAISS reads the host buffer.
+    CUDA_CHECK(cudaStreamSynchronize(userStream));
+    index->index->train(n, h_index_items.data());
+    index->index->add(n, h_index_items.data());
+    return;
+  } else if (dynamic_cast<ML::IVFPQParam *>(params)) {
+    ML::IVFPQParam *IVFPQ_param = dynamic_cast<ML::IVFPQParam *>(params);
+    approx_knn_ivfpq_build_index(index, IVFPQ_param, D, metric, n);
+  } else if (dynamic_cast<ML::IVFSQParam *>(params)) {
+    ML::IVFSQParam *IVFSQ_param = dynamic_cast<ML::IVFSQParam *>(params);
+    approx_knn_ivfsq_build_index(index, IVFSQ_param, D, metric, n);
+  } else {
+    ASSERT(index->index, "KNN index could not be initialized");
+  }
+
+  index->index->train(n, index_items);
+  index->index->add(n, index_items);
+}
+
+template <typename IntType = int>
+void approx_knn_search(ML::knnIndex *index, IntType n, const float *x,
+                       IntType k, float *distances, int64_t *labels) {
+  index->index->search(n, x, k, distances, labels);
+}
+
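These prims follow the standard FAISS index lifecycle: construct an IVF index over nlist cells, train it (k-means over the dataset), add the vectors, then search with nprobe cells per query. A minimal CPU-side sketch of the same lifecycle with the faiss Python package, assuming it is installed (the prims above do the equivalent on GPU through the GpuIndexIVF* classes):

    import numpy as np
    import faiss

    d, nlist, k = 128, 8, 5
    xb = np.random.rand(10000, d).astype(np.float32)

    quantizer = faiss.IndexFlatL2(d)   # coarse quantizer over cell centroids
    index = faiss.IndexIVFFlat(quantizer, d, nlist, faiss.METRIC_L2)
    index.train(xb)                    # k-means -> nlist cells
    index.add(xb)                      # assign each vector to its cell
    index.nprobe = 2                   # cells visited per query
    D, I = index.search(xb[:3], k)     # (3, k) distances and labels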
 
 /**
  * Search the kNN for the k-nearest neighbors of a set of query vectors
  * @param[in] input vector of device memory array pointers to search
  * @param[in] sizes vector of memory sizes for each device array pointer in input
  * @param[in] D number of cols in input and search_items
  * @param[in] search_items set of vectors to query for neighbors
  * @param[in] n number of items in search_items
  * @param[out] res_I pointer to device memory for returning k nearest indices
  * @param[out] res_D pointer to device memory for returning k nearest distances
  * @param[in] k number of neighbors to query
  * @param[in] allocator the device memory allocator to use for temporary scratch memory
  * @param[in] userStream the main cuda stream to use
  * @param[in] internalStreams optional when n_params > 0, the index partitions can be
  *   queried in parallel using these streams. Note that n_int_streams also
  *   has to be > 0 for these to be used and their cardinality does not need
  *   to correspond to n_parts.
  * @param[in] n_int_streams size of internalStreams. When this is <= 0, only the
  *   user stream will be used.
  * @param[in] rowMajorIndex are the index arrays in row-major layout?
  * @param[in] rowMajorQuery are the query arrays in row-major layout?
  * @param[in] translations translation ids for indices when index rows represent
  *   non-contiguous partitions
  * @param[in] metric corresponds to the faiss::MetricType enum (default is euclidean)
  * @param[in] metricArg metric argument to use. Corresponds to the p arg for lp norm
  * @param[in] expanded_form whether or not lp variants should be reduced w/ lp-root
  */
 template <typename IntType = int>
 void brute_force_knn(std::vector<float *> &input, std::vector<int> &sizes,
                      IntType D, float *search_items, IntType n, int64_t *res_I,
diff --git a/python/cuml/neighbors/__init__.py b/python/cuml/neighbors/__init__.py
index 8c8a3bfba2..ad2e8cc739 100644
--- a/python/cuml/neighbors/__init__.py
+++ b/python/cuml/neighbors/__init__.py
@@ -32,4 +32,8 @@
                              "cosine", "correlation", "inner_product",
                              "sqeuclidean"
                              ]),
-    "sparse": set(["euclidean", "l2", "inner_product"])}
+    "sparse": set(["euclidean", "l2", "inner_product"]),
+    "ivfflat": set(["l2", "euclidean"]),
+    "ivfpq": set(["l2", "euclidean"]),
+    "ivfsq": set(["l2", "euclidean"])
+    }
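The new VALID_METRICS entries restrict the ANN algorithms to (squared) Euclidean distance. A quick sketch of how the table is consulted (the same lookup the estimator performs in its constructor):

    import cuml.neighbors

    for algo in ('ivfflat', 'ivfpq', 'ivfsq'):
        assert cuml.neighbors.VALID_METRICS[algo] == {'l2', 'euclidean'}
    # The brute-force path keeps its wider metric support.
    assert 'cosine' in cuml.neighbors.VALID_METRICS['brute']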
diff --git a/python/cuml/neighbors/ann.pxd b/python/cuml/neighbors/ann.pxd
new file mode 100644
index 0000000000..cf316e421b
--- /dev/null
+++ b/python/cuml/neighbors/ann.pxd
@@ -0,0 +1,173 @@
+#
+# Copyright (c) 2019-2020, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# distutils: language = c++
+
+
+from libc.stdint cimport uintptr_t
+from libcpp cimport bool
+
+cimport cuml.common.cuda
+
+
+cdef extern from "cuml/neighbors/knn.hpp" namespace "ML":
+    cdef cppclass knnIndexParam:
+        pass
+
+    ctypedef enum QuantizerType:
+        QT_8bit,
+        QT_4bit,
+        QT_8bit_uniform,
+        QT_4bit_uniform,
+        QT_fp16,
+        QT_8bit_direct,
+        QT_6bit
+
+    cdef cppclass IVFParam(knnIndexParam):
+        int nlist
+        int nprobe
+
+    cdef cppclass IVFFlatParam(IVFParam):
+        pass
+
+    cdef cppclass IVFPQParam(IVFParam):
+        int M
+        int n_bits
+        bool usePrecomputedTables
+
+    cdef cppclass IVFSQParam(IVFParam):
+        QuantizerType qtype
+        bool encodeResidual
+
+
+cdef inline check_algo_params(algo, params):
+    def check_param_list(params, param_list):
+        for param in param_list:
+            if param not in params:
+                raise ValueError('algo_params misconfigured : {} '
+                                 'parameter unset'.format(param))
+    if algo == 'ivfflat':
+        check_param_list(params, ['nlist', 'nprobe'])
+    elif algo == "ivfpq":
+        check_param_list(params, ['nlist', 'nprobe', 'M', 'n_bits',
+                                  'usePrecomputedTables'])
+    elif algo == "ivfsq":
+        check_param_list(params, ['nlist', 'nprobe', 'qtype',
+                                  'encodeResidual'])
+
+
+cdef inline build_ivfflat_algo_params(params, automated):
+    cdef IVFFlatParam* algo_params = new IVFFlatParam()
+    if automated:
+        params = {
+            'nlist': 8,
+            'nprobe': 2
+        }
+    algo_params.nlist = <int> params['nlist']
+    algo_params.nprobe = <int> params['nprobe']
+    return <uintptr_t>algo_params
+
+
+cdef inline build_ivfpq_algo_params(params, automated, additional_info):
+    cdef IVFPQParam* algo_params = new IVFPQParam()
+    if automated:
+        allowedSubquantizers = [1, 2, 3, 4, 8, 12, 16, 20, 24, 28, 32]
+        allowedSubDimSize = {1, 2, 3, 4, 6, 8, 10, 12, 16, 20, 24, 28, 32}
+        N = additional_info['n_samples']
+        D = additional_info['n_features']
+
+        params = {
+            'nlist': 8,
+            'nprobe': 3
+        }
+
+        # Prefer a subquantizer count M whose sub-vector size D/M is
+        # natively supported; fall back to precomputed tables otherwise.
+        for n_subq in allowedSubquantizers:
+            if D % n_subq == 0 and (D / n_subq) in allowedSubDimSize:
+                params['usePrecomputedTables'] = False
+                params['M'] = n_subq
+                break
+
+        if 'M' not in params:
+            for n_subq in allowedSubquantizers:
+                if D % n_subq == 0:
+                    params['usePrecomputedTables'] = True
+                    params['M'] = n_subq
+                    break
+
+        # Pick the largest n_bits the training set can support:
+        # FAISS wants roughly (2 ** n_bits) * 39 training points.
+        for i in reversed(range(1, 4)):
+            min_train_points = (2 ** i) * 39
+            if N >= min_train_points:
+                params['n_bits'] = i
+                break
+
+    algo_params.nlist = <int> params['nlist']
+    algo_params.nprobe = <int> params['nprobe']
+    algo_params.M = <int> params['M']
+    algo_params.n_bits = <int> params['n_bits']
+    algo_params.usePrecomputedTables = \
+        <bool> params['usePrecomputedTables']
+    return <uintptr_t>algo_params
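The automated subquantizer choice above is easiest to see in isolation. A pure-Python restatement of the same rule (the function name pick_m is illustrative, not part of the codebase):

    # Returns the smallest valid subquantizer count M and whether the
    # precomputed-table fallback is needed, mirroring the loop above.
    def pick_m(D,
               allowed_subquantizers=(1, 2, 3, 4, 8, 12, 16, 20, 24, 28, 32),
               allowed_subdim=(1, 2, 3, 4, 6, 8, 10, 12, 16, 20, 24, 28, 32)):
        # First pass: an M whose sub-vector dimension D/M is natively
        # supported, so precomputed tables are unnecessary.
        for m in allowed_subquantizers:
            if D % m == 0 and (D // m) in allowed_subdim:
                return m, False          # (M, usePrecomputedTables)
        # Second pass: any M that divides D, with precomputed tables.
        for m in allowed_subquantizers:
            if D % m == 0:
                return m, True
        raise ValueError("no valid subquantizer count for D=%d" % D)

    print(pick_m(128))  # -> (4, False), since 128 / 4 = 32 is supported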
+
+
+cdef inline build_ivfsq_algo_params(params, automated):
+    cdef IVFSQParam* algo_params = new IVFSQParam()
+    if automated:
+        params = {
+            'nlist': 8,
+            'nprobe': 2,
+            'qtype': 'QT_8bit',
+            'encodeResidual': True
+        }
+
+    quantizer_type = {
+        'QT_8bit': <int> QuantizerType.QT_8bit,
+        'QT_4bit': <int> QuantizerType.QT_4bit,
+        'QT_8bit_uniform': <int> QuantizerType.QT_8bit_uniform,
+        'QT_4bit_uniform': <int> QuantizerType.QT_4bit_uniform,
+        'QT_fp16': <int> QuantizerType.QT_fp16,
+        'QT_8bit_direct': <int> QuantizerType.QT_8bit_direct,
+        'QT_6bit': <int> QuantizerType.QT_6bit,
+    }
+
+    algo_params.nlist = <int> params['nlist']
+    algo_params.nprobe = <int> params['nprobe']
+    algo_params.qtype = <QuantizerType> quantizer_type[params['qtype']]
+    algo_params.encodeResidual = <bool> params['encodeResidual']
+    return <uintptr_t>algo_params
+
+
+cdef inline build_algo_params(algo, params, additional_info):
+    automated = params is None or params == 'auto'
+    if not automated:
+        check_algo_params(algo, params)
+
+    cdef knnIndexParam* algo_params = <knnIndexParam*> 0
+    if algo == 'ivfflat':
+        algo_params = <knnIndexParam*><uintptr_t> \
+            build_ivfflat_algo_params(params, automated)
+    elif algo == 'ivfpq':
+        algo_params = <knnIndexParam*><uintptr_t> \
+            build_ivfpq_algo_params(params, automated, additional_info)
+    elif algo == 'ivfsq':
+        algo_params = <knnIndexParam*><uintptr_t> \
+            build_ivfsq_algo_params(params, automated)
+
+    return <uintptr_t>algo_params
+
+
+cdef inline destroy_algo_params(ptr):
+    cdef knnIndexParam* algo_params = <knnIndexParam*><uintptr_t> ptr
+    del algo_params
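For reference, the automated defaults the three builders fall back to when algo_params is None or 'auto', collected in one place (illustrative summary, not a public API):

    AUTO_DEFAULTS = {
        'ivfflat': {'nlist': 8, 'nprobe': 2},
        'ivfpq': {'nlist': 8, 'nprobe': 3},  # M, n_bits derived from data shape
        'ivfsq': {'nlist': 8, 'nprobe': 2, 'qtype': 'QT_8bit',
                  'encodeResidual': True},
    }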
diff --git a/python/cuml/neighbors/nearest_neighbors.pyx b/python/cuml/neighbors/nearest_neighbors.pyx
index 8631ec9717..ca8ad9d3e8 100644
--- a/python/cuml/neighbors/nearest_neighbors.pyx
+++ b/python/cuml/neighbors/nearest_neighbors.pyx
@@ -34,6 +34,7 @@ from cuml.common.doc_utils import generate_docstring
 from cuml.common.doc_utils import insert_into_docstring
 from cuml.common.import_utils import has_scipy
 from cuml.common import input_to_cuml_array
+from cuml.neighbors.ann cimport *
 from cuml.common.sparse_utils import is_sparse
 from cuml.common.sparse_utils import is_dense
 
@@ -75,6 +76,9 @@ cdef extern from "cuml/neighbors/knn.hpp" namespace "ML":
         METRIC_Cosine = 100,
         METRIC_Correlation
 
+    cdef cppclass knnIndex:
+        pass
+
     void brute_force_knn(
         handle_t &handle,
         vector[float*] &inputs,
@@ -92,6 +96,26 @@ cdef extern from "cuml/neighbors/knn.hpp" namespace "ML":
         bool expanded
     ) except +
 
+    void approx_knn_build_index(
+        handle_t &handle,
+        knnIndex* index,
+        knnIndexParam* params,
+        int D,
+        MetricType metric,
+        float metricArg,
+        float *search_items,
+        int n
+    ) except +
+
+    void approx_knn_search(
+        knnIndex* index,
+        int n,
+        const float *x,
+        int k,
+        float *distances,
+        int64_t* labels
+    ) except +
+
 
 cdef extern from "cuml/neighbors/knn_sparse.hpp" namespace "ML::Sparse":
     void brute_force_knn(handle_t &handle,
                          const int *idxIndptr,
@@ -137,7 +161,19 @@ class NearestNeighbors(Base):
         handles in several streams. If it is None, a new one is created.
     algorithm : string (default='brute')
-        The query algorithm to use. Currently, only 'brute' is supported.
+        The query algorithm to use. Valid options are:
+        - 'brute': brute-force search; slow but produces exact results
+        - 'ivfflat': inverted file; divides the dataset into partitions
+          and performs the search on relevant partitions only
+        - 'ivfpq': inverted file with product quantization; same as
+          'ivfflat', but each vector is additionally split into
+          n_features/M sub-vectors encoded through intermediary k-means
+          clusterings. This encoding provides partial information that
+          allows faster distance calculations
+        - 'ivfsq': inverted file with scalar quantization; same as
+          'ivfflat', but vector components are additionally quantized
+          into a reduced binary representation allowing faster distance
+          calculations
     metric : string (default='euclidean').
         Distance metric to use. Supported distances are ['l1, 'cityblock',
         'taxicab', 'manhattan', 'euclidean', 'l2', 'braycurtis', 'canberra',
@@ -157,7 +193,27 @@ class NearestNeighbors(Base):
     metric_expanded : bool
         Can increase performance in Minkowski-based (Lp) metrics (for p > 1)
         by using the expanded form and not computing the n-th roots.
+    algo_params : dict, optional (default=None)
+        Used to configure the nearest-neighbor algorithm. If set to None,
+        parameters will be generated automatically.
+        Parameters for algorithm 'ivfflat':
+        - nlist: (int) number of cells to partition the dataset into
+        - nprobe: (int) number of cells used for search at query time
+        Parameters for algorithm 'ivfpq':
+        - nlist: (int) number of cells to partition the dataset into
+        - nprobe: (int) number of cells used for search at query time
+        - M: (int) number of subquantizers
+        - n_bits: (int) bits allocated per subquantizer
+        - usePrecomputedTables: (bool) whether to use precomputed tables
+        Parameters for algorithm 'ivfsq':
+        - nlist: (int) number of cells to partition the dataset into
+        - nprobe: (int) number of cells used for search at query time
+        - qtype: (string) quantizer type (among QT_8bit, QT_4bit,
+          QT_8bit_uniform, QT_4bit_uniform, QT_fp16, QT_8bit_direct,
+          QT_6bit)
+        - encodeResidual: (bool) whether to encode residuals
     metric_params : dict, optional (default = None)
         This is currently ignored.
+
     output_type : {'input', 'cudf', 'cupy', 'numpy', 'numba'}, default=None
         Variable to control output type of the results and attributes of
         the estimator. If None, it'll inherit the output type set at the
@@ -248,10 +304,6 @@ class NearestNeighbors(Base):
                          verbose=verbose,
                          output_type=output_type)
 
-        if algorithm != "brute":
-            raise ValueError("Algorithm %s is not valid. Only 'brute' is"
-                             "supported currently." % algorithm)
-
         if metric not in cuml.neighbors.VALID_METRICS[algorithm]:
             raise ValueError("Metric %s is not valid. "
                              "Use sorted(cuml.neighbors.VALID_METRICS[%s]) "
@@ -264,6 +316,7 @@ class NearestNeighbors(Base):
         self.algo_params = algo_params
         self.p = p
         self.algorithm = algorithm
+        self.knn_index = 0
 
     @generate_docstring()
     def fit(self, X, convert_dtype=True) -> "NearestNeighbors":
@@ -283,13 +337,42 @@ class NearestNeighbors(Base):
 
         else:
             self.X_m, self.n_rows, n_cols, dtype = \
-                input_to_cuml_array(X, order='F', check_dtype=np.float32,
+                input_to_cuml_array(X, order='C', check_dtype=np.float32,
                                     convert_to_dtype=(np.float32
                                                       if convert_dtype
                                                       else None))
 
-        self.n_indices = 1
+        cdef handle_t* handle_ = self.handle.getHandle()
+        cdef knnIndexParam* algo_params = <knnIndexParam*> 0
+        if self.algorithm in ['ivfflat', 'ivfpq', 'ivfsq']:
+            if not is_dense(X):
+                raise ValueError("Approximate Nearest Neighbors methods "
+                                 "require dense data")
+
+            additional_info = {'n_samples': self.n_rows,
+                               'n_features': n_cols}
+            knn_index = new knnIndex()
+            self.knn_index = <uintptr_t> knn_index
+            algo_params = <knnIndexParam*><uintptr_t> \
+                build_algo_params(self.algorithm, self.algo_params,
+                                  additional_info)
+            metric, expanded = self._build_metric_type(self.metric)
+
+            approx_knn_build_index(handle_[0],
+                                   knn_index,
+                                   algo_params,
+                                   <int>n_cols,
+                                   <MetricType>metric,
+                                   <float>self.p,
+                                   <float*><uintptr_t>self.X_m.ptr,
+                                   <int>self.n_rows)
+            self.handle.sync()
+
+            destroy_algo_params(<uintptr_t>algo_params)
+
+            del self.X_m
+        self.n_indices = 1
 
         return self
 
     def get_param_names(self):
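Putting the pieces together, a hypothetical end-to-end use of the new ANN path (parameter values mirror the tests below rather than tuned settings):

    import numpy as np
    from cuml.neighbors import NearestNeighbors

    X = np.random.rand(4000, 128).astype(np.float32)

    nn = NearestNeighbors(algorithm='ivfpq',
                          algo_params={'nlist': 8, 'nprobe': 2, 'M': 16,
                                       'n_bits': 4,
                                       'usePrecomputedTables': False})
    nn.fit(X)

    # Approximate indices and distances of the 8 nearest neighbors of the
    # first five vectors.
    dist, ind = nn.kneighbors(X[:5], n_neighbors=8)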
@@ -419,7 +502,7 @@
                 or n_neighbors <= 0:
             raise ValueError("k or n_neighbors must be a positive integer")
 
-        if n_neighbors > self.X_m.shape[0]:
+        if n_neighbors > self.n_rows:
             raise ValueError("n_neighbors must be <= number of "
                              "samples in index")
 
@@ -431,11 +514,11 @@
             raise ValueError("Dimensions of X need to match dimensions of "
                              "indices (%d)" % self.n_dims)
 
-        if isinstance(self.X_m, CumlArray):
+        if hasattr(self, 'X_m') and isinstance(self.X_m, SparseCumlArray):
+            D_ndarr, I_ndarr = self._kneighbors_sparse(X, n_neighbors)
+        else:
             D_ndarr, I_ndarr = self._kneighbors_dense(X, n_neighbors,
                                                       convert_dtype)
-        elif isinstance(self.X_m, SparseCumlArray):
-            D_ndarr, I_ndarr = self._kneighbors_sparse(X, n_neighbors)
 
         self.handle.sync()
 
@@ -459,14 +542,14 @@
 
     def _kneighbors_dense(self, X, n_neighbors, convert_dtype=None):
 
-        if isinstance(self.X_m, CumlArray) and not is_dense(X):
+        if not is_dense(X):
             raise ValueError("A NearestNeighbors model trained on dense "
                              "data requires dense input to kneighbors()")
 
         metric, expanded = self._build_metric_type(self.metric)
 
         X_m, N, _, dtype = \
-            input_to_cuml_array(X, order='F', check_dtype=np.float32,
+            input_to_cuml_array(X, order='C', check_dtype=np.float32,
                                 convert_to_dtype=(np.float32 if convert_dtype
                                                   else False))
 
@@ -479,35 +562,44 @@
         cdef uintptr_t I_ptr = I_ndarr.ptr
         cdef uintptr_t D_ptr = D_ndarr.ptr
 
+        cdef handle_t* handle_ = self.handle.getHandle()
         cdef vector[float*] *inputs = new vector[float*]()
         cdef vector[int] *sizes = new vector[int]()
+        cdef knnIndex* knn_index = <knnIndex*> 0
+
+        if self.algorithm == 'brute':
+            inputs.push_back(<float*><uintptr_t>self.X_m.ptr)
+            sizes.push_back(<int>self.X_m.shape[0])
+
+            brute_force_knn(
+                handle_[0],
+                deref(inputs),
+                deref(sizes),
+                <int>self.n_dims,
+                <float*><uintptr_t>X_m.ptr,
+                <int>N,
+                <int64_t*>I_ptr,
+                <float*>D_ptr,
+                <int>n_neighbors,
+                True,
+                True,
+                <MetricType>metric,
+                # minkowski order is currently the only metric argument.
+                <float>self.p,
+                <bool>expanded
+            )
+        else:
+            knn_index = <knnIndex*><uintptr_t> self.knn_index
+            approx_knn_search(
+                knn_index,
+                <int>N,
+                <float*><uintptr_t>X_m.ptr,
+                <int>n_neighbors,
+                <float*>D_ptr,
+                <int64_t*>I_ptr
+            )
 
-        cdef uintptr_t idx_ptr = self.X_m.ptr
-        inputs.push_back(<float*>idx_ptr)
-        sizes.push_back(<int>self.X_m.shape[0])
-
-        cdef handle_t* handle_ = self.handle.getHandle()
-        cdef uintptr_t x_ctype_st = X_m.ptr
-
-        brute_force_knn(
-            handle_[0],
-            deref(inputs),
-            deref(sizes),
-            <int>self.n_dims,
-            <float*>x_ctype_st,
-            <int>N,
-            <int64_t*>I_ptr,
-            <float*>D_ptr,
-            <int>n_neighbors,
-            False,
-            False,
-            <MetricType>metric,
-            # minkowski order is currently the only metric argument.
-            <float>self.p,
-            < bool > expanded
-        )
-
+        self.handle.sync()
         return D_ndarr, I_ndarr
 
     def _kneighbors_sparse(self, X, n_neighbors):
@@ -647,6 +739,16 @@
 
         return sparse_csr
 
+    def __del__(self):
+        cdef knnIndex* knn_index = <knnIndex*><uintptr_t> self.knn_index
+        if knn_index:
+            del knn_index
+
+    def _more_tags(self):
+        return {
+            'preferred_input_order': 'C'
+        }
+
 
 @cuml.internals.api_return_sparse_array()
 def kneighbors_graph(X=None, n_neighbors=5, mode='connectivity', verbose=False,
@@ -746,8 +848,3 @@ def kneighbors_graph(X=None, n_neighbors=5, mode='connectivity', verbose=False,
         query = X.X_m
 
     return X.kneighbors_graph(X=query, n_neighbors=n_neighbors, mode=mode)
-
-    def _more_tags(self):
-        return {
-            'preferred_input_order': 'F'
-        }
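The accuracy of 'ivfflat' depends directly on how many of the nlist cells are probed per query. A rough sketch of measuring recall against brute force as nprobe grows (illustrative; assumes numpy input and output and an arbitrary random dataset):

    import numpy as np
    from cuml.neighbors import NearestNeighbors

    X = np.random.rand(10000, 128).astype(np.float32)
    queries, k = X[:100], 8

    exact = NearestNeighbors(algorithm='brute').fit(X)
    true_ind = exact.kneighbors(queries, n_neighbors=k, return_distance=False)

    for nprobe in (1, 2, 4, 8):
        ann = NearestNeighbors(algorithm='ivfflat',
                               algo_params={'nlist': 8,
                                            'nprobe': nprobe}).fit(X)
        ann_ind = ann.kneighbors(queries, n_neighbors=k,
                                 return_distance=False)
        # Fraction of true neighbors recovered, averaged over queries.
        recall = np.mean([len(np.intersect1d(a, b)) / k
                          for a, b in zip(np.asarray(ann_ind),
                                          np.asarray(true_ind))])
        print(nprobe, recall)  # recall should approach 1.0 as nprobe -> nlist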
diff --git a/python/cuml/test/test_nearest_neighbors.py b/python/cuml/test/test_nearest_neighbors.py
index 2c6abc64ed..2d01eaea3d 100644
--- a/python/cuml/test/test_nearest_neighbors.py
+++ b/python/cuml/test/test_nearest_neighbors.py
@@ -35,6 +35,7 @@
 import sklearn
 import cuml
 from cuml.common import has_scipy
+import gc
 
 
 def predict(neigh_ind, _y, n_neighbors):
@@ -55,11 +56,12 @@ def valid_metrics(algo="brute", cuml_algo=None):
 
 @pytest.mark.parametrize("datatype", ["dataframe", "numpy"])
 @pytest.mark.parametrize("nrows", [500, 1000, 10000])
-@pytest.mark.parametrize("ncols", [100, 1000])
+@pytest.mark.parametrize("ncols", [128, 1024])
 @pytest.mark.parametrize("n_neighbors", [10, 50])
 @pytest.mark.parametrize("n_clusters", [2, 10])
+@pytest.mark.parametrize("algo", ["brute", "ivfflat", "ivfpq", "ivfsq"])
 def test_neighborhood_predictions(nrows, ncols, n_neighbors, n_clusters,
-                                  datatype):
+                                  datatype, algo):
     if not has_scipy():
         pytest.skip('Skipping test_neighborhood_predictions because ' +
                     'Scipy is missing')
@@ -67,15 +69,15 @@ def test_neighborhood_predictions(nrows, ncols, n_neighbors, n_clusters,
     X, y = make_blobs(n_samples=nrows, centers=n_clusters,
                       n_features=ncols, random_state=0)
 
-    X = X.astype(np.float32)
-
     if datatype == "dataframe":
         X = cudf.DataFrame(X)
 
-    knn_cu = cuKNN()
+    knn_cu = cuKNN(algorithm=algo)
     knn_cu.fit(X)
     neigh_ind = knn_cu.kneighbors(X, n_neighbors=n_neighbors,
                                   return_distance=False)
+    del knn_cu
+    gc.collect()
 
     if datatype == "dataframe":
         assert isinstance(neigh_ind, cudf.DataFrame)
@@ -88,6 +90,92 @@ def test_neighborhood_predictions(nrows, ncols, n_neighbors, n_clusters,
     assert array_equal(labels, y)
 
 
+@pytest.mark.parametrize("nlist", [4, 8])
+@pytest.mark.parametrize("nrows", [10000])
+@pytest.mark.parametrize("ncols", [128, 512])
+@pytest.mark.parametrize("n_neighbors", [8, 16])
+def test_ivfflat_pred(nrows, ncols, n_neighbors, nlist):
+    algo_params = {
+        'nlist': nlist,
+        'nprobe': int(nlist * 0.25)
+    }
+
+    X, y = make_blobs(n_samples=nrows, centers=5,
+                      n_features=ncols, random_state=0)
+
+    knn_cu = cuKNN(algorithm="ivfflat", algo_params=algo_params)
+    knn_cu.fit(X)
+    neigh_ind = knn_cu.kneighbors(X, n_neighbors=n_neighbors,
+                                  return_distance=False)
+    del knn_cu
+    gc.collect()
+
+    labels, probs = predict(neigh_ind, y, n_neighbors)
+
+    assert array_equal(labels, y)
+
+
+@pytest.mark.parametrize("nlist", [8])
+@pytest.mark.parametrize("M", [16, 32])
+@pytest.mark.parametrize("n_bits", [2, 4])
+@pytest.mark.parametrize("usePrecomputedTables", [False, True])
+@pytest.mark.parametrize("nrows", [4000])
+@pytest.mark.parametrize("ncols", [128, 512])
+@pytest.mark.parametrize("n_neighbors", [8])
+def test_ivfpq_pred(nrows, ncols, n_neighbors,
+                    nlist, M, n_bits, usePrecomputedTables):
+    algo_params = {
+        'nlist': nlist,
+        'nprobe': int(nlist * 0.2),
+        'M': M,
+        'n_bits': n_bits,
+        'usePrecomputedTables': usePrecomputedTables
+    }
+
+    X, y = make_blobs(n_samples=nrows, centers=5,
+                      n_features=ncols, random_state=0)
+
+    knn_cu = cuKNN(algorithm="ivfpq", algo_params=algo_params)
+    knn_cu.fit(X)
+    neigh_ind = knn_cu.kneighbors(X, n_neighbors=n_neighbors,
+                                  return_distance=False)
+    del knn_cu
+    gc.collect()
+
+    labels, probs = predict(neigh_ind, y, n_neighbors)
+
+    assert array_equal(labels, y)
+
+
+@pytest.mark.parametrize("nlist", [4])
+@pytest.mark.parametrize("qtype", ['QT_4bit', 'QT_8bit', 'QT_fp16'])
+@pytest.mark.parametrize("encodeResidual", [False, True])
+@pytest.mark.parametrize("nrows", [10000])
+@pytest.mark.parametrize("ncols", [128, 512])
+@pytest.mark.parametrize("n_neighbors", [8])
+def test_ivfsq_pred(nrows, ncols, n_neighbors, nlist, qtype, encodeResidual):
+    algo_params = {
+        'nlist': nlist,
+        'nprobe': int(nlist * 0.25),
+        'qtype': qtype,
+        'encodeResidual': encodeResidual
+    }
+
+    X, y = make_blobs(n_samples=nrows, centers=5,
+                      n_features=ncols, random_state=0)
+
+    knn_cu = cuKNN(algorithm="ivfsq", algo_params=algo_params)
+    knn_cu.fit(X)
+    neigh_ind = knn_cu.kneighbors(X, n_neighbors=n_neighbors,
+                                  return_distance=False)
+    del knn_cu
+    gc.collect()
+
+    labels, probs = predict(neigh_ind, y, n_neighbors)
+
+    assert array_equal(labels, y)
+
+
 def test_return_dists():
     n_samples = 50
     n_feats = 50