Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REVIEW] Adjusted Rand Index implementation #652

1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# cuML 0.8.0 (Date TBD)

## New Features
- PR #652: Adjusted Rand Index metric ml-prim
- PR #636: Rand Index metric ml-prim
- PR #515: Added Random Projection feature
- PR #504: Contingency matrix ml-prim
Expand Down
12 changes: 11 additions & 1 deletion cpp/src/metrics/metrics.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "cuda_utils.h"
#include "metrics.hpp"

#include "metrics/adjustedRandIndex.h"
#include "metrics/randIndex.h"
#include "score/scores.h"

Expand All @@ -39,5 +40,14 @@ double randIndex(const cumlHandle &handle, const double *y, const double *y_hat,
y, y_hat, (uint64_t)n, handle.getDeviceAllocator(), handle.getStream());
}

double adjustedRandIndex(const cumlHandle &handle, const int *y,
                         const int *y_hat, const int n,
                         const int lower_class_range,
                         const int upper_class_range) {
  // Thin wrapper: delegate to the ml-prims implementation, forwarding the
  // handle's device allocator and CUDA stream.
  auto allocator = handle.getDeviceAllocator();
  auto stream = handle.getStream();
  return MLCommon::Metrics::computeAdjustedRandIndex(
    y, y_hat, n, lower_class_range, upper_class_range, allocator, stream);
}

} // namespace Metrics
} // namespace ML
} // namespace ML
21 changes: 20 additions & 1 deletion cpp/src/metrics/metrics.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,26 @@ double r2_score_py(const cumlHandle &handle, double *y, double *y_hat, int n);
* @param n: Number of elements in y and y_hat
* @return: The rand index value
*/

double randIndex(const cumlHandle &handle, double *y, double *y_hat, int n);

/**
* Calculates the "adjusted rand index"
*
* This metric is the corrected-for-chance version of the rand index
*
* @param handle: cumlHandle
* @param y: Array of response variables of the first clustering classifications
* @param y_hat: Array of response variables of the second clustering classifications
* @param n: Number of elements in y and y_hat
* @param lower_class_range: the lowest value in the range of classes
* @param upper_class_range: the highest value in the range of classes
* @return: The adjusted rand index value
*/
double adjustedRandIndex(const cumlHandle &handle, const int *y,
const int *y_hat, const int n,
const int lower_class_range,
const int upper_class_range);

} // namespace Metrics
} // namespace ML
} // namespace ML
161 changes: 161 additions & 0 deletions cpp/src_prims/metrics/adjustedRandIndex.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@


/*
* Copyright (c) 2019, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* @file adjustedRandIndex.h
* @brief The adjusted Rand index is the corrected-for-chance version of the Rand index.
* Such a correction for chance establishes a baseline by using the expected similarity
* of all pair-wise comparisons between clusterings specified by a random model.
*/

#include <math.h>
#include <cub/cub.cuh>
#include "common/cuml_allocator.hpp"
#include "common/device_buffer.hpp"
#include "cuda_utils.h"
#include "linalg/map_then_reduce.h"
#include "linalg/reduce.h"
#include "metrics/contingencyMatrix.h"

namespace MLCommon {

/**
 * @brief Lambda to calculate the number of unordered pairs (n choose 2)
 * in a given input
 *
 * @tparam Type: Data type of the input
 * @tparam IdxType : type of the indexing (by default int)
 * @param in: the input to the functional mapping (assumed non-negative count)
 * @param i: the indexing(not used in this case)
 */
template <typename Type, typename IdxType = int>
struct nCTwo {
  HDI Type operator()(Type in, IdxType i = 0) {
    // Exactly one of `in` and `in - 1` is even; divide that factor by 2
    // BEFORE multiplying so the intermediate product does not overflow
    // where `(in * (in - 1)) / 2` would (around sqrt(2^31) for int).
    // Result is identical to the naive form for all non-overflowing inputs.
    return (in % 2 == 0) ? (in / 2) * (in - 1) : in * ((in - 1) / 2);
  }
};

namespace Metrics {

/**
 * @brief Function to calculate Adjusted RandIndex
 * <a href="https://en.wikipedia.org/wiki/Rand_index">more info on rand index</a>
 * @param firstClusterArray: the array of classes of type T
 * @param secondClusterArray: the array of classes of type T
 * @param size: the size of the data points of type int
 * @param lowerLabelRange: the lower bound of the range of labels
 * @param upperLabelRange: the upper bound of the range of labels
 * @param allocator: object that takes care of temporary device memory allocation of type std::shared_ptr<MLCommon::deviceAllocator>
 * @param stream: the cudaStream object
 * @return the adjusted rand index value; 0 when the corrected-for-chance
 *         denominator is zero
 */
template <typename T>
double computeAdjustedRandIndex(
  const T* firstClusterArray, const T* secondClusterArray, const int size,
  const T lowerLabelRange, const T upperLabelRange,
  std::shared_ptr<MLCommon::deviceAllocator> allocator, cudaStream_t stream) {
  //rand index for size less than 2 is not defined
  ASSERT(size >= 2, "Rand Index for size less than 2 not defined!");

  int numUniqueClasses = upperLabelRange - lowerLabelRange + 1;

  //declaring, allocating and initializing memory for the contingency matrix
  MLCommon::device_buffer<int> dContingencyMatrix(
    allocator, stream, numUniqueClasses * numUniqueClasses);
  CUDA_CHECK(cudaMemsetAsync(dContingencyMatrix.data(), 0,
                             numUniqueClasses * numUniqueClasses * sizeof(int),
                             stream));

  //workspace allocation: go through the caller-supplied device allocator
  //(same as every other temporary here) instead of a raw cudaMalloc, so the
  //memory is stream-ordered and is released even on early exit
  size_t workspaceSz = MLCommon::Metrics::getContingencyMatrixWorkspaceSize(
    size, firstClusterArray, stream, lowerLabelRange, upperLabelRange);
  MLCommon::device_buffer<char> workspace(allocator, stream, workspaceSz);

  //calculating the contingency matrix
  MLCommon::Metrics::contingencyMatrix(
    firstClusterArray, secondClusterArray, (int)size,
    (int*)dContingencyMatrix.data(), stream, (void*)workspace.data(),
    workspaceSz, lowerLabelRange, upperLabelRange);

  //creating device buffers for all the parameters involved in ARI calculation
  //device variables: a = row-wise sums, b = column-wise sums of the
  //contingency matrix; d_* hold the three pair-count reductions
  MLCommon::device_buffer<int> a(allocator, stream, numUniqueClasses);
  MLCommon::device_buffer<int> b(allocator, stream, numUniqueClasses);
  MLCommon::device_buffer<int> d_aCTwoSum(allocator, stream, 1);
  MLCommon::device_buffer<int> d_bCTwoSum(allocator, stream, 1);
  MLCommon::device_buffer<int> d_nChooseTwoSum(allocator, stream, 1);
  //host variables
  int h_aCTwoSum;
  int h_bCTwoSum;
  int h_nChooseTwoSum;

  //initializing device memory
  CUDA_CHECK(
    cudaMemsetAsync(a.data(), 0, numUniqueClasses * sizeof(int), stream));
  CUDA_CHECK(
    cudaMemsetAsync(b.data(), 0, numUniqueClasses * sizeof(int), stream));
  CUDA_CHECK(cudaMemsetAsync(d_aCTwoSum.data(), 0, sizeof(int), stream));
  CUDA_CHECK(cudaMemsetAsync(d_bCTwoSum.data(), 0, sizeof(int), stream));
  CUDA_CHECK(cudaMemsetAsync(d_nChooseTwoSum.data(), 0, sizeof(int), stream));

  //calculating the sum of NijC2
  MLCommon::LinAlg::mapThenSumReduce<int, nCTwo<int>>(
    d_nChooseTwoSum.data(), numUniqueClasses * numUniqueClasses, nCTwo<int>(),
    stream, dContingencyMatrix.data(), dContingencyMatrix.data());

  //calculating the row-wise sums
  MLCommon::LinAlg::reduce<int, int, int>(a.data(), dContingencyMatrix.data(),
                                          numUniqueClasses, numUniqueClasses, 0,
                                          true, true, stream);

  //calculating the column-wise sums
  MLCommon::LinAlg::reduce<int, int, int>(b.data(), dContingencyMatrix.data(),
                                          numUniqueClasses, numUniqueClasses, 0,
                                          true, false, stream);

  //calculating the sum of number of unordered pairs for every element in a
  MLCommon::LinAlg::mapThenSumReduce<int, nCTwo<int>>(
    d_aCTwoSum.data(), numUniqueClasses, nCTwo<int>(), stream, a.data(),
    a.data());

  //calculating the sum of number of unordered pairs for every element of b
  MLCommon::LinAlg::mapThenSumReduce<int, nCTwo<int>>(
    d_bCTwoSum.data(), numUniqueClasses, nCTwo<int>(), stream, b.data(),
    b.data());

  //updating in the host memory
  MLCommon::updateHost(&h_nChooseTwoSum, d_nChooseTwoSum.data(), 1, stream);
  MLCommon::updateHost(&h_aCTwoSum, d_aCTwoSum.data(), 1, stream);
  MLCommon::updateHost(&h_bCTwoSum, d_bCTwoSum.data(), 1, stream);

  //the updateHost copies above are asynchronous on `stream`; we must wait
  //for them before reading h_* on the host
  CUDA_CHECK(cudaStreamSynchronize(stream));

  //calculating the ARI; accumulate in double so that size*(size-1)/2 and
  //the product of the pair sums cannot overflow int
  double nChooseTwo = (double)size * ((double)size - 1.0) / 2.0;
  double expectedIndex = ((double)h_aCTwoSum * (double)h_bCTwoSum) / nChooseTwo;
  double maxIndex = ((double)(h_bCTwoSum + h_aCTwoSum)) / 2.0;
  double index = (double)h_nChooseTwoSum;

  //checking if the denominator is zero (degenerate clustering); return 0
  //in that case, matching the original convention
  if (maxIndex - expectedIndex != 0.0)
    return (index - expectedIndex) / (maxIndex - expectedIndex);
  else
    return 0;
}

}; //end namespace Metrics
}; //end namespace MLCommon
57 changes: 32 additions & 25 deletions cpp/src_prims/metrics/contingencyMatrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@ typedef enum {
} ContingencyMatrixImplType;

template <typename T>
__global__ void devConstructContingencyMatrix(T *groundTruth, T *predicted,
int nSamples, int *outMat,
__global__ void devConstructContingencyMatrix(const T *groundTruth,
const T *predicted,
const int nSamples, int *outMat,
int outIdxOffset,
int outMatWidth) {
int elementId = threadIdx.x + blockDim.x * blockIdx.x;
Expand All @@ -48,9 +49,10 @@ __global__ void devConstructContingencyMatrix(T *groundTruth, T *predicted,
}

template <typename T>
__global__ void devConstructContingencyMatrixSmem(T *groundTruth, T *predicted,
int nSamples, int *outMat,
int outIdxOffset,
__global__ void devConstructContingencyMatrixSmem(const T *groundTruth,
const T *predicted,
const int nSamples,
int *outMat, int outIdxOffset,
int outMatWidth) {
extern __shared__ int sMemMatrix[]; // init smem to zero

Expand Down Expand Up @@ -81,8 +83,9 @@ __global__ void devConstructContingencyMatrixSmem(T *groundTruth, T *predicted,

// helper functions to launch kernel for global atomic add
template <typename T>
cudaError_t computeCMatWAtomics(T *groundTruth, T *predictedLabel, int nSamples,
int *outMat, int outIdxOffset, int outDimN,
cudaError_t computeCMatWAtomics(const T *groundTruth, const T *predictedLabel,
const int nSamples, int *outMat,
int outIdxOffset, int outDimN,
cudaStream_t stream) {
CUDA_CHECK(cudaFuncSetCacheConfig(devConstructContingencyMatrix<T>,
cudaFuncCachePreferL1));
Expand All @@ -98,9 +101,10 @@ cudaError_t computeCMatWAtomics(T *groundTruth, T *predictedLabel, int nSamples,

// helper function to launch share memory atomic add kernel
template <typename T>
cudaError_t computeCMatWSmemAtomics(T *groundTruth, T *predictedLabel,
int nSamples, int *outMat, int outIdxOffset,
int outDimN, cudaStream_t stream) {
cudaError_t computeCMatWSmemAtomics(const T *groundTruth,
const T *predictedLabel, const int nSamples,
int *outMat, int outIdxOffset, int outDimN,
cudaStream_t stream) {
dim3 block(128, 1, 1);
dim3 grid((nSamples + block.x - 1) / block.x);
size_t smemSizePerBlock = outDimN * outDimN * sizeof(int);
Expand All @@ -113,9 +117,9 @@ cudaError_t computeCMatWSmemAtomics(T *groundTruth, T *predictedLabel,

// helper function to sort and global atomic update
template <typename T>
void contingencyMatrixWSort(T *groundTruth, T *predictedLabel, int nSamples,
int *outMat, T minLabel, T maxLabel,
void *workspace, size_t workspaceSize,
void contingencyMatrixWSort(const T *groundTruth, const T *predictedLabel,
const int nSamples, int *outMat, T minLabel,
T maxLabel, void *workspace, size_t workspaceSize,
cudaStream_t stream) {
T *outKeys = reinterpret_cast<T *>(workspace);
size_t alignedBufferSz = alignTo((size_t)nSamples * sizeof(T), (size_t)256);
Expand Down Expand Up @@ -177,9 +181,10 @@ inline ContingencyMatrixImplType getImplVersion(int outDimN) {
* @param maxLabel: [out] calculated max value in input array
*/
template <typename T>
void getInputClassCardinality(T *groundTruth, int nSamples, cudaStream_t stream,
T &minLabel, T &maxLabel) {
thrust::device_ptr<T> dTrueLabel = thrust::device_pointer_cast(groundTruth);
void getInputClassCardinality(const T *groundTruth, const int nSamples,
cudaStream_t stream, T &minLabel, T &maxLabel) {
thrust::device_ptr<const T> dTrueLabel =
thrust::device_pointer_cast(groundTruth);
auto min_max = thrust::minmax_element(thrust::cuda::par.on(stream),
dTrueLabel, dTrueLabel + nSamples);
minLabel = *min_max.first;
Expand All @@ -195,15 +200,16 @@ void getInputClassCardinality(T *groundTruth, int nSamples, cudaStream_t stream,
* @param maxLabel: Optional, max value in input array
*/
template <typename T>
size_t getCMatrixWorkspaceSize(int nSamples, T *groundTruth,
cudaStream_t stream,
T minLabel = std::numeric_limits<T>::max(),
T maxLabel = std::numeric_limits<T>::max()) {
size_t getContingencyMatrixWorkspaceSize(
const int nSamples, const T *groundTruth, cudaStream_t stream,
T minLabel = std::numeric_limits<T>::max(),
T maxLabel = std::numeric_limits<T>::max()) {
size_t workspaceSize = 0;
// below is a redundant computation - can be avoided
if (minLabel == std::numeric_limits<T>::max() ||
maxLabel == std::numeric_limits<T>::max()) {
thrust::device_ptr<T> dTrueLabel = thrust::device_pointer_cast(groundTruth);
thrust::device_ptr<const T> dTrueLabel =
thrust::device_pointer_cast(groundTruth);
auto min_max = thrust::minmax_element(thrust::cuda::par.on(stream),
dTrueLabel, dTrueLabel + nSamples);
minLabel = *min_max.first;
Expand Down Expand Up @@ -233,7 +239,7 @@ size_t getCMatrixWorkspaceSize(int nSamples, T *groundTruth,
/**
 * @brief construct contingency matrix given input ground truth and prediction labels.
* Users should call function getInputClassCardinality to find and allocate memory for
* output. Similarly workspace requirements should be checked using function getCMatrixWorkspaceSize
* output. Similarly workspace requirements should be checked using function getContingencyMatrixWorkspaceSize
* @param groundTruth: device 1-d array for ground truth (num of rows)
* @param predictedLabel: device 1-d array for prediction (num of columns)
* @param nSamples: number of elements in input array
Expand All @@ -245,8 +251,8 @@ size_t getCMatrixWorkspaceSize(int nSamples, T *groundTruth,
* @param maxLabel: Optional, max value in input ground truth array
*/
template <typename T>
void contingencyMatrix(T *groundTruth, T *predictedLabel, int nSamples,
int *outMat, cudaStream_t stream,
void contingencyMatrix(const T *groundTruth, const T *predictedLabel,
const int nSamples, int *outMat, cudaStream_t stream,
void *workspace = nullptr, size_t workspaceSize = 0,
T minLabel = std::numeric_limits<T>::max(),
T maxLabel = std::numeric_limits<T>::max()) {
Expand All @@ -264,7 +270,8 @@ void contingencyMatrix(T *groundTruth, T *predictedLabel, int nSamples,

if (minLabel == std::numeric_limits<T>::max() ||
maxLabel == std::numeric_limits<T>::max()) {
thrust::device_ptr<T> dTrueLabel = thrust::device_pointer_cast(groundTruth);
thrust::device_ptr<const T> dTrueLabel =
thrust::device_pointer_cast(groundTruth);
auto min_max = thrust::minmax_element(thrust::cuda::par.on(stream),
dTrueLabel, dTrueLabel + nSamples);
minLabel = *min_max.first;
Expand Down
1 change: 1 addition & 0 deletions cpp/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ if(BUILD_PRIMS_TESTS)
add_executable(prims
prims/add.cu
prims/add_sub_dev_scalar.cu
prims/adjustedRandIndex.cu
cjnolet marked this conversation as resolved.
Show resolved Hide resolved
prims/array.cu
prims/binary_op.cu
prims/ternary_op.cu
Expand Down
Loading