Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUG] tSNE Lock up #2565

Merged
merged 17 commits into from
Jul 29, 2020
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,12 @@
- PR #2535: Fix issue with incorrect docker image being used in local build script
- PR #2542: Fix small memory leak in TSNE
- PR #2552: Fixed the length argument of updateDevice calls in RF test
- PR #2565: Fix cell allocation code to avoid loops in quad-tree. Prevent NaNs causing infinite descent
- PR #2563: Update scipy call for arima gradient test
- PR #2569: Fix for cuDF update
- PR #2508: Use keyword parameters in sklearn.datasets.make_* functions
- PR #2573: Considering managed memory as device type on checking for KMeans
- PR #2574: Fixing include path in `tsvd_mg.pyx`

# cuML 0.14.0 (03 Jun 2020)

## New Features
Expand Down
9 changes: 6 additions & 3 deletions cpp/src/tsne/barnes_hut.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ void Barnes_Hut(float *VAL, const int *COL, const int *ROW, const int NNZ,
const int max_iter = 1000, const float min_grad_norm = 1e-7,
const float pre_momentum = 0.5, const float post_momentum = 0.8,
const long long random_state = -1) {
using MLCommon::device_buffer;
drobison00 marked this conversation as resolved.
Show resolved Hide resolved
auto d_alloc = handle.getDeviceAllocator();
cudaStream_t stream = handle.getStream();

Expand All @@ -70,7 +71,6 @@ void Barnes_Hut(float *VAL, const int *COL, const int *ROW, const int NNZ,
CUML_LOG_DEBUG("N_nodes = %d blocks = %d", nnodes, blocks);

// Allocate more space
//---------------------------------------------------
// MLCommon::device_buffer<unsigned> errl(d_alloc, stream, 1);
MLCommon::device_buffer<unsigned> limiter(d_alloc, stream, 1);
MLCommon::device_buffer<int> maxdepthd(d_alloc, stream, 1);
Expand Down Expand Up @@ -122,6 +122,7 @@ void Barnes_Hut(float *VAL, const int *COL, const int *ROW, const int NNZ,

// Apply
MLCommon::device_buffer<float> gains_bh(d_alloc, stream, n * 2);

thrust::device_ptr<float> begin_gains_bh =
thrust::device_pointer_cast(gains_bh.data());
thrust::fill(thrust::cuda::par.on(stream), begin_gains_bh,
Expand All @@ -147,7 +148,6 @@ void Barnes_Hut(float *VAL, const int *COL, const int *ROW, const int NNZ,
cudaFuncSetCacheConfig(TSNE::RepulsionKernel, cudaFuncCachePreferL1);
cudaFuncSetCacheConfig(TSNE::attractive_kernel_bh, cudaFuncCachePreferL1);
cudaFuncSetCacheConfig(TSNE::IntegrationKernel, cudaFuncCachePreferL1);

// Do gradient updates
//---------------------------------------------------
CUML_LOG_DEBUG("Start gradient updates!");
Expand All @@ -162,6 +162,7 @@ void Barnes_Hut(float *VAL, const int *COL, const int *ROW, const int NNZ,
CUDA_CHECK(cudaMemsetAsync(static_cast<void *>(attr_forces.data()), 0,
attr_forces.size() * sizeof(*attr_forces.data()),
stream));

TSNE::Reset_Normalization<<<1, 1, 0, stream>>>(
Z_norm.data(), radiusd_squared.data(), bottomd.data(), NNODES,
radiusd.data());
Expand Down Expand Up @@ -265,9 +266,11 @@ void Barnes_Hut(float *VAL, const int *COL, const int *ROW, const int NNZ,
}
PRINT_TIMES;

// Copy final YY into true output Y
drobison00 marked this conversation as resolved.
Show resolved Hide resolved
MLCommon::copy(Y, YY.data(), n, stream);
CUDA_CHECK(cudaPeekAtLastError());
drobison00 marked this conversation as resolved.
Show resolved Hide resolved

MLCommon::copy(Y + n, YY.data() + nnodes + 1, n, stream);
CUDA_CHECK(cudaPeekAtLastError());
}

} // namespace TSNE
Expand Down
45 changes: 36 additions & 9 deletions cpp/src/tsne/bh_kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
namespace ML {
namespace TSNE {


/**
 * Initializes the states of objects. This speeds the overall kernel up.
*/
Expand Down Expand Up @@ -171,7 +172,8 @@ __global__ __launch_bounds__(1024, 1) void ClearKernel1(int *restrict childd,
}

/**
* Build the actual KD Tree.
* Build the actual QuadTree.
* See: https://iss.oden.utexas.edu/Publications/Papers/burtscher11.pdf
*/
__global__ __launch_bounds__(
THREADS2, FACTOR2) void TreeBuildingKernel(/* int *restrict errd, */
Expand All @@ -194,6 +196,7 @@ __global__ __launch_bounds__(

int localmaxdepth = 1;
int skip = 1;

const int inc = blockDim.x * gridDim.x;
int i = threadIdx.x + blockIdx.x * blockDim.x;

Expand All @@ -206,6 +209,11 @@ __global__ __launch_bounds__(
depth = 1;
r = radius * 0.5f;

/* Select child node 'j'
rootx < px rootx > px
* rooty < py 1 -> 3 0 -> 2
* rooty > py 1 -> 1 0 -> 0
*/
x = rootx + ((rootx < (px = posxd[i])) ? (j = 1, r) : (j = 0, -r));

y = rooty + ((rooty < (py = posyd[i])) ? (j |= 2, r) : (-r));
Expand All @@ -217,41 +225,50 @@ __global__ __launch_bounds__(
depth++;
r *= 0.5f;

// determine which child to follow
x += ((x < px) ? (j = 1, r) : (j = 0, -r));

y += ((y < py) ? (j |= 2, r) : (-r));
}

// (ch)ild will be '-1' (nullptr), '-2' (locked), or an Integer corresponding to a body offset
// in the lower [0, N) blocks of childd
if (ch != -2) {
// skip if child pointer is locked and try again later
// skip if child pointer was locked when we examined it, and try again later.
locked = n * 4 + j;
// store the locked position in case we need to patch in a cell later.

if (ch == -1) {
// Child is a nullptr ('-1'), so we write our body index to the leaf, and move on to the next body.
if (atomicCAS(&childd[locked], -1, i) == -1) {
if (depth > localmaxdepth) localmaxdepth = depth;

i += inc; // move on to next body
skip = 1;
}
} else {
// Child node isn't empty, so we store the current value of the child, lock the leaf, and patch in a new cell
if (ch == atomicCAS(&childd[locked], ch, -2)) {
// try to lock
patch = -1;

while (ch >= 0) {
depth++;

const int cell = atomicSub(bottomd, 1) - 1;
if (cell <= N) {
if (cell == N) {
// atomicExch(errd, 1);
drobison00 marked this conversation as resolved.
Show resolved Hide resolved
atomicExch(bottomd, NNODES);
//printf("Cell Allocation Overflow, depth=%d, r=%f, rbound=%f, N=%d, NNODES=%d, bottomd=%d\n",
drobison00 marked this conversation as resolved.
Show resolved Hide resolved
// depth, r, radius, N, NNODES, prev_bottom);
} else if (cell < N) {
depth--;
continue;
}

if (patch != -1) childd[n * 4 + j] = cell;

if (cell > patch) patch = cell;

// Insert migrated child node
j = (x < posxd[ch]) ? 1 : 0;
if (y < posyd[ch]) j |= 2;

Expand All @@ -264,7 +281,9 @@ __global__ __launch_bounds__(
y += ((y < py) ? (j |= 2, r) : (-r));

ch = childd[n * 4 + j];
if (r <= 1e-10) break;
if (r <= 1e-10) {
break;
}
}

childd[n * 4 + j] = i;
Expand All @@ -276,6 +295,7 @@ __global__ __launch_bounds__(
}
}
}

__threadfence();

if (skip == 2) childd[locked] = patch;
Expand Down Expand Up @@ -642,12 +662,19 @@ __global__ void attractive_kernel_bh(
if (index >= NNZ) return;
const int i = ROW[index];
const int j = COL[index];
float PQ;

// TODO: Calculate Kullback-Leibler divergence
// TODO: Convert attractive forces to CSR format
const float PQ = __fdividef(
VAL[index],
norm_add1[i] + norm[j] - 2.0f * (Y1[i] * Y1[j] + Y2[i] * Y2[j])); // P*Q
// Try single precision compute first
float denominator = __fmaf_rn(-2.0f, (Y1[i] * Y1[j]), norm_add1[i]) + __fmaf_rn(-2.0f, (Y2[i] * Y2[j]), norm[j]);

if (denominator == 0) {
/* repeat with double precision */
double dbl_denominator = __fma_rn(-2.0f, Y1[i] * Y1[j], norm_add1[i]) + __fma_rn(-2.0f, Y2[i] * Y2[j], norm[j]);
denominator = (dbl_denominator != 0) ? static_cast<float>(dbl_denominator) : FLT_EPSILON;
}
PQ = __fdividef(VAL[index], denominator);

// Apply forces
atomicAdd(&attract1[i], PQ * (Y1[i] - Y1[j]));
Expand Down
8 changes: 4 additions & 4 deletions cpp/src/tsne/exact_kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ namespace ML {
namespace TSNE {

/****************************************/
/* Finds the best guassian bandwith for
/* Finds the best Gaussian bandwidth for
each row in the dataset */
__global__ void sigmas_kernel(const float *restrict distances,
float *restrict P, const float perplexity,
Expand All @@ -45,7 +45,7 @@ __global__ void sigmas_kernel(const float *restrict distances,
for (int step = 0; step < epochs; step++) {
float sum_Pi = FLT_EPSILON;

// Exponentiate to get guassian
// Exponentiate to get Gaussian
for (int j = 0; j < k; j++) {
P[ik + j] = __expf(-distances[ik + j] * beta);
sum_Pi += P[ik + j];
Expand Down Expand Up @@ -84,7 +84,7 @@ __global__ void sigmas_kernel(const float *restrict distances,
}

/****************************************/
/* Finds the best guassian bandwith for
/* Finds the best Gaussian bandwidth for
each row in the dataset */
__global__ void sigmas_kernel_2d(const float *restrict distances,
float *restrict P, const float perplexity,
Expand All @@ -101,7 +101,7 @@ __global__ void sigmas_kernel_2d(const float *restrict distances,
register const int ik = i * 2;

for (int step = 0; step < epochs; step++) {
// Exponentiate to get guassian
// Exponentiate to get Gaussian
P[ik] = __expf(-distances[ik] * beta);
P[ik + 1] = __expf(-distances[ik + 1] * beta);
const float sum_Pi = FLT_EPSILON + P[ik] + P[ik + 1];
Expand Down