diff --git a/CHANGELOG.md b/CHANGELOG.md index f893a588b0..abe71a6ff3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,8 @@ # cuML 0.8.0 (Date TBD) ## New Features + +- PR #652: Adjusted Rand Index metric ml-prim - PR #679: Class label manipulation ml-prim - PR #636: Rand Index metric ml-prim - PR #515: Added Random Projection feature diff --git a/cpp/src/metrics/metrics.cu b/cpp/src/metrics/metrics.cu index d587777968..90e32deeda 100644 --- a/cpp/src/metrics/metrics.cu +++ b/cpp/src/metrics/metrics.cu @@ -18,6 +18,7 @@ #include "cuda_utils.h" #include "metrics.hpp" +#include "metrics/adjustedRandIndex.h" #include "metrics/randIndex.h" #include "score/scores.h" @@ -39,5 +40,14 @@ double randIndex(const cumlHandle &handle, const double *y, const double *y_hat, y, y_hat, (uint64_t)n, handle.getDeviceAllocator(), handle.getStream()); } +double adjustedRandIndex(const cumlHandle &handle, const int *y, + const int *y_hat, const int n, + const int lower_class_range, + const int upper_class_range) { + return MLCommon::Metrics::computeAdjustedRandIndex( + y, y_hat, n, lower_class_range, upper_class_range, + handle.getDeviceAllocator(), handle.getStream()); +} + } // namespace Metrics -} // namespace ML +} // namespace ML \ No newline at end of file diff --git a/cpp/src/metrics/metrics.hpp b/cpp/src/metrics/metrics.hpp index 0a26cacbb8..8c9579ad76 100644 --- a/cpp/src/metrics/metrics.hpp +++ b/cpp/src/metrics/metrics.hpp @@ -69,7 +69,26 @@ double r2_score_py(const cumlHandle &handle, double *y, double *y_hat, int n); * @param n: Number of elements in y and y_hat * @return: The rand index value */ + double randIndex(const cumlHandle &handle, double *y, double *y_hat, int n); +/** + * Calculates the "adjusted rand index" + * + * This metric is the corrected-for-chance version of the rand index + * + * @param handle: cumlHandle + * @param y: Array of response variables of the first clustering classifications + * @param y_hat: Array of response variables of the second clustering classifications + * @param n: Number of elements in y and y_hat + * @param lower_class_range: the lowest value in the range of classes + * @param upper_class_range: the highest value in the range of classes + * @return: The adjusted rand index value + */ +double adjustedRandIndex(const cumlHandle &handle, const int *y, + const int *y_hat, const int n, + const int lower_class_range, + const int upper_class_range); + } // namespace Metrics -} // namespace ML +} // namespace ML \ No newline at end of file diff --git a/cpp/src_prims/metrics/adjustedRandIndex.h b/cpp/src_prims/metrics/adjustedRandIndex.h new file mode 100644 index 0000000000..2793327e6f --- /dev/null +++ b/cpp/src_prims/metrics/adjustedRandIndex.h @@ -0,0 +1,161 @@ + + +/* + * Copyright (c) 2019, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** +* @file adjustedRandIndex.h +* @brief The adjusted Rand index is the corrected-for-chance version of the Rand index. +* Such a correction for chance establishes a baseline by using the expected similarity +* of all pair-wise comparisons between clusterings specified by a random model. +*/ + +#include +#include +#include "common/cuml_allocator.hpp" +#include "common/device_buffer.hpp" +#include "cuda_utils.h" +#include "linalg/map_then_reduce.h" +#include "linalg/reduce.h" +#include "metrics/contingencyMatrix.h" + +namespace MLCommon { + +/** +* @brief Lambda to calculate the number of unordered pairs in a given input +* +* @tparam Type: Data type of the input +* @tparam IdxType : type of the indexing (by default int) +* @param in: the input to the functional mapping +* @param i: the indexing(not used in this case) +*/ +template +struct nCTwo { + HDI Type operator()(Type in, IdxType i = 0) { return ((in) * (in - 1)) / 2; } +}; + +namespace Metrics { + +/** +* @brief Function to calculate Adjusted RandIndex +* more info on rand index +* @param firstClusterArray: the array of classes of type T +* @param secondClusterArray: the array of classes of type T +* @param size: the size of the data points of type int +* @param numUniqueClasses: number of Unique classes used for clustering +* @param lowerLabelRange: the lower bound of the range of labels +* @param upperLabelRange: the upper bound of the range of labels +* @param allocator: object that takes care of temporary device memory allocation of type std::shared_ptr +* @param stream: the cudaStream object +*/ +template +double computeAdjustedRandIndex( + const T* firstClusterArray, const T* secondClusterArray, const int size, + const T lowerLabelRange, const T upperLabelRange, + std::shared_ptr allocator, cudaStream_t stream) { + //rand index for size less than 2 is not defined + ASSERT(size >= 2, "Rand Index for size less than 2 not defined!"); + + int numUniqueClasses = upperLabelRange - lowerLabelRange + 1; + + //declaring, allocating and initializing memory for the contingency marix + MLCommon::device_buffer dContingencyMatrix( + allocator, stream, numUniqueClasses * numUniqueClasses); + CUDA_CHECK(cudaMemsetAsync(dContingencyMatrix.data(), 0, + numUniqueClasses * numUniqueClasses * sizeof(int), + stream)); + + //workspace allocation + char* pWorkspace = nullptr; + size_t workspaceSz = MLCommon::Metrics::getContingencyMatrixWorkspaceSize( + size, firstClusterArray, stream, lowerLabelRange, upperLabelRange); + if (workspaceSz != 0) MLCommon::allocate(pWorkspace, workspaceSz); + + //calculating the contingency matrix + MLCommon::Metrics::contingencyMatrix( + firstClusterArray, secondClusterArray, (int)size, + (int*)dContingencyMatrix.data(), stream, (void*)pWorkspace, workspaceSz, + lowerLabelRange, upperLabelRange); + + //creating device buffers for all the parameters involved in ARI calculation + //device variables + MLCommon::device_buffer a(allocator, stream, numUniqueClasses); + MLCommon::device_buffer b(allocator, stream, numUniqueClasses); + MLCommon::device_buffer d_aCTwoSum(allocator, stream, 1); + MLCommon::device_buffer d_bCTwoSum(allocator, stream, 1); + MLCommon::device_buffer d_nChooseTwoSum(allocator, stream, 1); + //host variables + int h_aCTwoSum; + int h_bCTwoSum; + int h_nChooseTwoSum; + + //initializing device memory + CUDA_CHECK( + cudaMemsetAsync(a.data(), 0, numUniqueClasses * sizeof(int), stream)); + CUDA_CHECK( + cudaMemsetAsync(b.data(), 0, numUniqueClasses * sizeof(int), stream)); + CUDA_CHECK(cudaMemsetAsync(d_aCTwoSum.data(), 0, sizeof(int), stream)); + CUDA_CHECK(cudaMemsetAsync(d_bCTwoSum.data(), 0, sizeof(int), stream)); + CUDA_CHECK(cudaMemsetAsync(d_nChooseTwoSum.data(), 0, sizeof(int), stream)); + + //calculating the sum of NijC2 + MLCommon::LinAlg::mapThenSumReduce>( + d_nChooseTwoSum.data(), numUniqueClasses * numUniqueClasses, nCTwo(), + stream, dContingencyMatrix.data(), dContingencyMatrix.data()); + + //calculating the row-wise sums + MLCommon::LinAlg::reduce(a.data(), dContingencyMatrix.data(), + numUniqueClasses, numUniqueClasses, 0, + true, true, stream); + + //calculating the column-wise sums + MLCommon::LinAlg::reduce(b.data(), dContingencyMatrix.data(), + numUniqueClasses, numUniqueClasses, 0, + true, false, stream); + + //calculating the sum of number of unordered pairs for every element in a + MLCommon::LinAlg::mapThenSumReduce>( + d_aCTwoSum.data(), numUniqueClasses, nCTwo(), stream, a.data(), + a.data()); + + //calculating the sum of number of unordered pairs for every element of b + MLCommon::LinAlg::mapThenSumReduce>( + d_bCTwoSum.data(), numUniqueClasses, nCTwo(), stream, b.data(), + b.data()); + + //updating in the host memory + MLCommon::updateHost(&h_nChooseTwoSum, d_nChooseTwoSum.data(), 1, stream); + MLCommon::updateHost(&h_aCTwoSum, d_aCTwoSum.data(), 1, stream); + MLCommon::updateHost(&h_bCTwoSum, d_bCTwoSum.data(), 1, stream); + + //freeing the memories in the device + if (pWorkspace) CUDA_CHECK(cudaFree(pWorkspace)); + + //calculating the ARI + int nChooseTwo = ((size) * (size - 1)) / 2; + double expectedIndex = + ((double)((h_aCTwoSum) * (h_bCTwoSum))) / ((double)(nChooseTwo)); + double maxIndex = ((double)(h_bCTwoSum + h_aCTwoSum)) / 2.0; + double index = (double)h_nChooseTwoSum; + + //checking if the denominator is zero + if (maxIndex - expectedIndex) + return (index - expectedIndex) / (maxIndex - expectedIndex); + else + return 0; +} + +}; //end namespace Metrics +}; //end namespace MLCommon diff --git a/cpp/src_prims/metrics/contingencyMatrix.h b/cpp/src_prims/metrics/contingencyMatrix.h index 8ab7049005..085a4ebed5 100644 --- a/cpp/src_prims/metrics/contingencyMatrix.h +++ b/cpp/src_prims/metrics/contingencyMatrix.h @@ -33,8 +33,9 @@ typedef enum { } ContingencyMatrixImplType; template -__global__ void devConstructContingencyMatrix(T *groundTruth, T *predicted, - int nSamples, int *outMat, +__global__ void devConstructContingencyMatrix(const T *groundTruth, + const T *predicted, + const int nSamples, int *outMat, int outIdxOffset, int outMatWidth) { int elementId = threadIdx.x + blockDim.x * blockIdx.x; @@ -48,9 +49,10 @@ __global__ void devConstructContingencyMatrix(T *groundTruth, T *predicted, } template -__global__ void devConstructContingencyMatrixSmem(T *groundTruth, T *predicted, - int nSamples, int *outMat, - int outIdxOffset, +__global__ void devConstructContingencyMatrixSmem(const T *groundTruth, + const T *predicted, + const int nSamples, + int *outMat, int outIdxOffset, int outMatWidth) { extern __shared__ int sMemMatrix[]; // init smem to zero @@ -81,8 +83,9 @@ __global__ void devConstructContingencyMatrixSmem(T *groundTruth, T *predicted, // helper functions to launch kernel for global atomic add template -cudaError_t computeCMatWAtomics(T *groundTruth, T *predictedLabel, int nSamples, - int *outMat, int outIdxOffset, int outDimN, +cudaError_t computeCMatWAtomics(const T *groundTruth, const T *predictedLabel, + const int nSamples, int *outMat, + int outIdxOffset, int outDimN, cudaStream_t stream) { CUDA_CHECK(cudaFuncSetCacheConfig(devConstructContingencyMatrix, cudaFuncCachePreferL1)); @@ -98,9 +101,10 @@ cudaError_t computeCMatWAtomics(T *groundTruth, T *predictedLabel, int nSamples, // helper function to launch share memory atomic add kernel template -cudaError_t computeCMatWSmemAtomics(T *groundTruth, T *predictedLabel, - int nSamples, int *outMat, int outIdxOffset, - int outDimN, cudaStream_t stream) { +cudaError_t computeCMatWSmemAtomics(const T *groundTruth, + const T *predictedLabel, const int nSamples, + int *outMat, int outIdxOffset, int outDimN, + cudaStream_t stream) { dim3 block(128, 1, 1); dim3 grid((nSamples + block.x - 1) / block.x); size_t smemSizePerBlock = outDimN * outDimN * sizeof(int); @@ -113,9 +117,9 @@ cudaError_t computeCMatWSmemAtomics(T *groundTruth, T *predictedLabel, // helper function to sort and global atomic update template -void contingencyMatrixWSort(T *groundTruth, T *predictedLabel, int nSamples, - int *outMat, T minLabel, T maxLabel, - void *workspace, size_t workspaceSize, +void contingencyMatrixWSort(const T *groundTruth, const T *predictedLabel, + const int nSamples, int *outMat, T minLabel, + T maxLabel, void *workspace, size_t workspaceSize, cudaStream_t stream) { T *outKeys = reinterpret_cast(workspace); size_t alignedBufferSz = alignTo((size_t)nSamples * sizeof(T), (size_t)256); @@ -177,9 +181,10 @@ inline ContingencyMatrixImplType getImplVersion(int outDimN) { * @param maxLabel: [out] calculated max value in input array */ template -void getInputClassCardinality(T *groundTruth, int nSamples, cudaStream_t stream, - T &minLabel, T &maxLabel) { - thrust::device_ptr dTrueLabel = thrust::device_pointer_cast(groundTruth); +void getInputClassCardinality(const T *groundTruth, const int nSamples, + cudaStream_t stream, T &minLabel, T &maxLabel) { + thrust::device_ptr dTrueLabel = + thrust::device_pointer_cast(groundTruth); auto min_max = thrust::minmax_element(thrust::cuda::par.on(stream), dTrueLabel, dTrueLabel + nSamples); minLabel = *min_max.first; @@ -195,15 +200,16 @@ void getInputClassCardinality(T *groundTruth, int nSamples, cudaStream_t stream, * @param maxLabel: Optional, max value in input array */ template -size_t getCMatrixWorkspaceSize(int nSamples, T *groundTruth, - cudaStream_t stream, - T minLabel = std::numeric_limits::max(), - T maxLabel = std::numeric_limits::max()) { +size_t getContingencyMatrixWorkspaceSize( + const int nSamples, const T *groundTruth, cudaStream_t stream, + T minLabel = std::numeric_limits::max(), + T maxLabel = std::numeric_limits::max()) { size_t workspaceSize = 0; // below is a redundant computation - can be avoided if (minLabel == std::numeric_limits::max() || maxLabel == std::numeric_limits::max()) { - thrust::device_ptr dTrueLabel = thrust::device_pointer_cast(groundTruth); + thrust::device_ptr dTrueLabel = + thrust::device_pointer_cast(groundTruth); auto min_max = thrust::minmax_element(thrust::cuda::par.on(stream), dTrueLabel, dTrueLabel + nSamples); minLabel = *min_max.first; @@ -233,7 +239,7 @@ size_t getCMatrixWorkspaceSize(int nSamples, T *groundTruth, /** * @brief contruct contingency matrix given input ground truth and prediction labels. * Users should call function getInputClassCardinality to find and allocate memory for - * output. Similarly workspace requirements should be checked using function getCMatrixWorkspaceSize + * output. Similarly workspace requirements should be checked using function getContingencyMatrixWorkspaceSize * @param groundTruth: device 1-d array for ground truth (num of rows) * @param predictedLabel: device 1-d array for prediction (num of columns) * @param nSamples: number of elements in input array @@ -245,8 +251,8 @@ size_t getCMatrixWorkspaceSize(int nSamples, T *groundTruth, * @param maxLabel: Optional, max value in input ground truth array */ template -void contingencyMatrix(T *groundTruth, T *predictedLabel, int nSamples, - int *outMat, cudaStream_t stream, +void contingencyMatrix(const T *groundTruth, const T *predictedLabel, + const int nSamples, int *outMat, cudaStream_t stream, void *workspace = nullptr, size_t workspaceSize = 0, T minLabel = std::numeric_limits::max(), T maxLabel = std::numeric_limits::max()) { @@ -264,7 +270,8 @@ void contingencyMatrix(T *groundTruth, T *predictedLabel, int nSamples, if (minLabel == std::numeric_limits::max() || maxLabel == std::numeric_limits::max()) { - thrust::device_ptr dTrueLabel = thrust::device_pointer_cast(groundTruth); + thrust::device_ptr dTrueLabel = + thrust::device_pointer_cast(groundTruth); auto min_max = thrust::minmax_element(thrust::cuda::par.on(stream), dTrueLabel, dTrueLabel + nSamples); minLabel = *min_max.first; diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index ea035b2f71..f92aa77a3e 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -125,6 +125,7 @@ if(BUILD_PRIMS_TESTS) add_executable(prims prims/add.cu prims/add_sub_dev_scalar.cu + prims/adjustedRandIndex.cu prims/binary_op.cu prims/ternary_op.cu prims/coalesced_reduction.cu diff --git a/cpp/test/prims/adjustedRandIndex.cu b/cpp/test/prims/adjustedRandIndex.cu new file mode 100644 index 0000000000..fd23adbbd4 --- /dev/null +++ b/cpp/test/prims/adjustedRandIndex.cu @@ -0,0 +1,182 @@ +/* + * Copyright (c) 2019, NVIDIA CORPORATION. + + + + + + + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include "common/cuml_allocator.hpp" +#include "metrics/adjustedRandIndex.h" +#include "metrics/contingencyMatrix.h" +#include "test_utils.h" + +namespace MLCommon { +namespace Metrics { + +//parameter structure definition +struct AdjustedRandIndexParam { + int nElements; + int lowerLabelRange; + int upperLabelRange; + bool sameArrays; + double tolerance; +}; + +//test fixture class +template + + + class adjustedRandIndexTest + : public ::testing::TestWithParam { + protected: + //the constructor + void SetUp() override { + //getting the parameters + params = ::testing::TestWithParam::GetParam(); + + nElements = params.nElements; + lowerLabelRange = params.lowerLabelRange; + upperLabelRange = params.upperLabelRange; + + //generating random value test input + std::vector arr1(nElements, 0); + std::vector arr2(nElements, 0); + std::random_device rd; + std::default_random_engine dre(rd()); + std::uniform_int_distribution intGenerator(lowerLabelRange, + upperLabelRange); + + std::generate(arr1.begin(), arr1.end(), + [&]() { return intGenerator(dre); }); + if (params.sameArrays) { + arr2 = arr1; + } else { + std::generate(arr2.begin(), arr2.end(), + [&]() { return intGenerator(dre); }); + } + + //generating the golden output + //calculating the contingency matrix + int numUniqueClasses = upperLabelRange - lowerLabelRange + 1; + size_t sizeOfMat = numUniqueClasses * numUniqueClasses * sizeof(int); + int *hGoldenOutput = (int *)malloc(sizeOfMat); + memset(hGoldenOutput, 0, sizeOfMat); + int i, j; + for (i = 0; i < nElements; i++) { + int row = arr1[i] - lowerLabelRange; + int column = arr2[i] - lowerLabelRange; + + hGoldenOutput[row * numUniqueClasses + column] += 1; + } + int sumOfNijCTwo = 0; + int *a = (int *)malloc(numUniqueClasses * sizeof(int)); + int *b = (int *)malloc(numUniqueClasses * sizeof(int)); + memset(a, 0, numUniqueClasses * sizeof(int)); + memset(b, 0, numUniqueClasses * sizeof(int)); + int sumOfAiCTwo = 0; + int sumOfBiCTwo = 0; + + //calculating the sum of number of pairwise points in each index + //and also the reducing contingency matrix along row and column + for (i = 0; i < numUniqueClasses; ++i) { + for (j = 0; j < numUniqueClasses; ++j) { + int Nij = hGoldenOutput[i * numUniqueClasses + j]; + sumOfNijCTwo += ((Nij) * (Nij - 1)) / 2; + a[i] += hGoldenOutput[i * numUniqueClasses + j]; + b[i] += hGoldenOutput[j * numUniqueClasses + i]; + } + } + + //claculating the sum of number pairwise points in ever column sum + //claculating the sum of number pairwise points in ever row sum + for (i = 0; i < numUniqueClasses; ++i) { + sumOfAiCTwo += ((a[i]) * (a[i] - 1)) / 2; + sumOfBiCTwo += ((b[i]) * (b[i] - 1)) / 2; + } + + //calculating the ARI + int nCTwo = ((nElements) * (nElements - 1)) / 2; + double expectedIndex = + ((double)(sumOfBiCTwo * sumOfAiCTwo)) / ((double)(nCTwo)); + double maxIndex = ((double)(sumOfAiCTwo + sumOfBiCTwo)) / 2.0; + double index = (double)sumOfNijCTwo; + + if (maxIndex - expectedIndex) + truthAdjustedRandIndex = + (index - expectedIndex) / (maxIndex - expectedIndex); + else + truthAdjustedRandIndex = 0; + + //allocating and initializing memory to the GPU + CUDA_CHECK(cudaStreamCreate(&stream)); + MLCommon::allocate(firstClusterArray, nElements, true); + MLCommon::allocate(secondClusterArray, nElements, true); + + MLCommon::updateDevice(firstClusterArray, &arr1[0], (int)nElements, stream); + MLCommon::updateDevice(secondClusterArray, &arr2[0], (int)nElements, + stream); + std::shared_ptr allocator( + new defaultDeviceAllocator); + + //calling the adjustedRandIndex CUDA implementation + computedAdjustedRandIndex = MLCommon::Metrics::computeAdjustedRandIndex( + firstClusterArray, secondClusterArray, nElements, lowerLabelRange, + upperLabelRange, allocator, stream); + } + + //the destructor + void TearDown() override { + CUDA_CHECK(cudaFree(firstClusterArray)); + CUDA_CHECK(cudaFree(secondClusterArray)); + CUDA_CHECK(cudaStreamDestroy(stream)); + } + + //declaring the data values + AdjustedRandIndexParam params; + T lowerLabelRange, upperLabelRange; + T *firstClusterArray = nullptr; + T *secondClusterArray = nullptr; + int nElements = 0; + double truthAdjustedRandIndex = 0; + double computedAdjustedRandIndex = 0; + cudaStream_t stream; +}; + +//setting test parameter values +const std::vector inputs = { + {199, 1, 10, false, 0.000001}, {200, 15, 100, false, 0.000001}, + {100, 1, 20, false, 0.000001}, {10, 1, 10, false, 0.000001}, + {198, 1, 100, false, 0.000001}, {300, 3, 99, false, 0.000001}, + {199, 1, 10, true, 0.000001}, {200, 15, 100, true, 0.000001}, + {100, 1, 20, true, 0.000001}, {10, 1, 10, true, 0.000001}, + {198, 1, 100, true, 0.000001}, {300, 3, 99, true, 0.000001}}; + +//writing the test suite +typedef adjustedRandIndexTest adjustedRandIndexTestClass; +TEST_P(adjustedRandIndexTestClass, Result) { + ASSERT_NEAR(computedAdjustedRandIndex, truthAdjustedRandIndex, + params.tolerance); +} +INSTANTIATE_TEST_CASE_P(adjustedRandIndex, adjustedRandIndexTestClass, + ::testing::ValuesIn(inputs)); + +} //end namespace Metrics +} //end namespace MLCommon diff --git a/cpp/test/prims/contingencyMatrix.cu b/cpp/test/prims/contingencyMatrix.cu index 4638f288d7..1e89d15352 100644 --- a/cpp/test/prims/contingencyMatrix.cu +++ b/cpp/test/prims/contingencyMatrix.cu @@ -90,7 +90,7 @@ class ContingencyMatrixTestImpl MLCommon::allocate(dComputedOutput, numUniqueClasses * numUniqueClasses); MLCommon::allocate(dGoldenOutput, numUniqueClasses * numUniqueClasses); - size_t workspaceSz = MLCommon::Metrics::getCMatrixWorkspaceSize( + size_t workspaceSz = MLCommon::Metrics::getContingencyMatrixWorkspaceSize( numElements, dY, stream, lowerLabelRange, upperLabelRange); if (workspaceSz != 0) MLCommon::allocate(pWorkspace, workspaceSz);