diff --git a/CHANGELOG.md b/CHANGELOG.md
index 01335629db..463376b771 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -49,6 +49,8 @@
 - PR #3155: Eliminate unnecessary warnings from random projection test
 - PR #3176: Add probabilistic SVM tests with various input array types
 - PR #3180: FIL: `blocks_per_sm` support in Python
+- PR #3219: Update CI to use XGBoost 1.3.0 RCs
+- PR #3221: Update contributing doc for label support
 - PR #3177: Make Multinomial Naive Bayes inherit from `ClassifierMixin` and use it for score
 
 ## Bug Fixes
@@ -91,6 +93,7 @@
 - PR #3185: Add documentation for Distributed TFIDF Transformer
 - PR #3190: Fix Attribute error on ICPA #3183 and PCA input type
 - PR #3208: Fix EXITCODE override in notebook test script
+- PR #3216: Ignore splits that do not satisfy constraints
 
 # cuML 0.16.0 (23 Oct 2020)
 
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index afb28bc23e..b372f24e96 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -41,10 +41,21 @@ into three categories:
 ### A note related to our CI process
 After you have started a PR (refer to step 6 in the previous section), every time you do a `git push`, it triggers a new CI run on all the commits thus far. Even though GPUCI has mechanisms to deal with this to a certain extent, if you keep `push`ing too frequently, it might just clog our GPUCI servers and slow down every PR and conda package generation! So, please be mindful of this and try not to do many frequent pushes.
 
-To quantify this, the average check in our CI takes between 25 and 32 minutes on our servers. The GPUCI infrastructure has limited resources, so if the servers get overwhelmed, every current active PR will not be able to correctly schedule CI.
+To quantify this, the average check in our CI takes between 80 and 90 minutes on our servers. The GPUCI infrastructure has limited resources, so if the servers get overwhelmed, no active PR will be able to schedule CI correctly.
 
 Remember, if you are unsure about anything, don't hesitate to comment on issues and ask for clarifications!
 
+### Managing PR labels
+
+Each PR must be labeled according to whether it is a "breaking" or "non-breaking" change (using GitHub labels). This is used to highlight changes that users should know about when upgrading.
+
+For cuML, a "breaking" change is one that modifies the public, non-experimental, Python API in a
+non-backward-compatible way. The C++ API does not have an expectation of backward compatibility at this
+time, so changes to it are not typically considered breaking. Backward-compatible API changes to the Python
+API (such as adding a new keyword argument to a function) do not need to be labeled.
+
+Additional labels must be applied to indicate whether the change is a feature, improvement, bugfix, or documentation change. See the shared RAPIDS documentation for these labels: https://github.com/rapidsai/kb/issues/42.
+
 ### Seasoned developers
 
 Once you have gotten your feet wet and are more comfortable with the code, you
diff --git a/ci/gpu/build.sh b/ci/gpu/build.sh
index 34910e2011..b944361e68 100755
--- a/ci/gpu/build.sh
+++ b/ci/gpu/build.sh
@@ -53,7 +53,7 @@ gpuci_conda_retry install -c conda-forge -c rapidsai -c rapidsai-nightly -c nvidia \
     "dask-cudf=${MINOR_VERSION}" \
     "dask-cuda=${MINOR_VERSION}" \
     "ucx-py=${MINOR_VERSION}" \
-    "xgboost=1.2.0dev.rapidsai${MINOR_VERSION}" \
+    "xgboost=1.3.0dev.rapidsai${MINOR_VERSION}" \
     "rapids-build-env=${MINOR_VERSION}.*" \
     "rapids-notebook-env=${MINOR_VERSION}.*" \
     "rapids-doc-env=${MINOR_VERSION}.*"
diff --git a/cpp/src/decisiontree/batched-levelalgo/builder_base.cuh b/cpp/src/decisiontree/batched-levelalgo/builder_base.cuh
index 07be5da0ac..8fd495d1e8 100644
--- a/cpp/src/decisiontree/batched-levelalgo/builder_base.cuh
+++ b/cpp/src/decisiontree/batched-levelalgo/builder_base.cuh
@@ -402,6 +402,7 @@ struct ClsTraits {
     computeSplitClassificationKernel<DataT, LabelT, IdxT, TPB_DEFAULT>
       <<<grid, TPB_DEFAULT, smemSize, s>>>(
        b.hist, b.params.n_bins, b.params.max_depth, b.params.min_samples_split,
+       b.params.min_samples_leaf, b.params.min_impurity_decrease,
        b.params.max_leaves, b.input, b.curr_nodes, col, b.done_count, b.mutex,
        b.n_leaves, b.splits, splitType);
   }
@@ -479,9 +480,10 @@ struct RegTraits {
     computeSplitRegressionKernel<DataT, DataT, IdxT, TPB_DEFAULT>
      <<<grid, TPB_DEFAULT, smemSize, s>>>(
        b.pred, b.pred2, b.pred2P, b.pred_count, b.params.n_bins,
-       b.params.max_depth, b.params.min_samples_split, b.params.max_leaves,
-       b.input, b.curr_nodes, col, b.done_count, b.mutex, b.n_leaves, b.splits,
-       b.block_sync, splitType);
+       b.params.max_depth, b.params.min_samples_split,
+       b.params.min_samples_leaf, b.params.min_impurity_decrease,
+       b.params.max_leaves, b.input, b.curr_nodes, col, b.done_count, b.mutex,
+       b.n_leaves, b.splits, b.block_sync, splitType);
   }
 
   /**
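Reviewer note, not part of the patch: the two new b.params.* arguments threaded through both launches above come from the decision-tree parameter struct. The sketch below is a trimmed, hypothetical C++ subset for orientation only; the field names follow the b.params.* usage in these hunks, while the types and the struct's real name and location are assumptions.

    // Illustrative subset of the builder parameters referenced above.
    // Only the fields that appear in these kernel launches are shown.
    struct DecisionTreeParamsSketch {
      int n_bins;                   // number of quantile bins per feature
      int max_depth;                // depth at which nodes stop splitting
      int min_samples_split;        // samples a node needs before it may split
      int min_samples_leaf;         // new: minimum samples each child must keep
      float min_impurity_decrease;  // new: minimum gain for a split to count
      int max_leaves;               // cap on the total number of leaves
    };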
diff --git a/cpp/src/decisiontree/batched-levelalgo/kernels.cuh b/cpp/src/decisiontree/batched-levelalgo/kernels.cuh
index f7cf3b67b7..eb2cc0f945 100644
--- a/cpp/src/decisiontree/batched-levelalgo/kernels.cuh
+++ b/cpp/src/decisiontree/batched-levelalgo/kernels.cuh
@@ -267,10 +267,10 @@ __device__ OutT* alignPointer(InT input) {
 template <typename DataT, typename LabelT, typename IdxT, int TPB>
 __global__ void computeSplitClassificationKernel(
   int* hist, IdxT nbins, IdxT max_depth, IdxT min_samples_split,
-  IdxT max_leaves, Input<DataT, LabelT, IdxT> input,
-  const Node<DataT, LabelT, IdxT>* nodes, IdxT colStart, int* done_count,
-  int* mutex, const IdxT* n_leaves, Split<DataT, IdxT>* splits,
-  CRITERION splitType) {
+  IdxT min_samples_leaf, DataT min_impurity_decrease, IdxT max_leaves,
+  Input<DataT, LabelT, IdxT> input, const Node<DataT, LabelT, IdxT>* nodes,
+  IdxT colStart, int* done_count, int* mutex, const IdxT* n_leaves,
+  Split<DataT, IdxT>* splits, CRITERION splitType) {
   extern __shared__ char smem[];
   IdxT nid = blockIdx.z;
   auto node = nodes[nid];
@@ -326,9 +326,11 @@ __global__ void computeSplitClassificationKernel(
   sp.init();
   __syncthreads();
   if (splitType == CRITERION::GINI) {
-    giniGain(shist, sbins, sp, col, range_len, nbins, nclasses);
+    giniGain(shist, sbins, sp, col, range_len, nbins, nclasses,
+             min_samples_leaf, min_impurity_decrease);
   } else {
-    entropyGain(shist, sbins, sp, col, range_len, nbins, nclasses);
+    entropyGain(shist, sbins, sp, col, range_len, nbins, nclasses,
+                min_samples_leaf, min_impurity_decrease);
   }
   __syncthreads();
   sp.evalBestSplit(smem, splits + nid, mutex + nid);
@@ -337,7 +339,8 @@
 template <typename DataT, typename LabelT, typename IdxT, int TPB>
 __global__ void computeSplitRegressionKernel(
   DataT* pred, DataT* pred2, DataT* pred2P, IdxT* count, IdxT nbins,
-  IdxT max_depth, IdxT min_samples_split, IdxT max_leaves,
+  IdxT max_depth, IdxT min_samples_split, IdxT min_samples_leaf,
+  DataT min_impurity_decrease, IdxT max_leaves,
   Input<DataT, LabelT, IdxT> input, const Node<DataT, LabelT, IdxT>* nodes,
   IdxT colStart, int* done_count, int* mutex, const IdxT* n_leaves,
   Split<DataT, IdxT>* splits, void* workspace, CRITERION splitType) {
@@ -471,7 +474,8 @@ __global__ void computeSplitRegressionKernel(
       scount[i] = count[gcOffset + i];
     }
     __syncthreads();
-    mseGain(spred, scount, sbins, sp, col, range_len, nbins);
+    mseGain(spred, scount, sbins, sp, col, range_len, nbins, min_samples_leaf,
+            min_impurity_decrease);
   } else {
     for (IdxT i = threadIdx.x; i < len; i += blockDim.x) {
       spred2[i] = pred2[gOffset + i];
@@ -480,7 +484,8 @@ __global__ void computeSplitRegressionKernel(
       spred2P[i] = pred2P[gcOffset + i];
     }
     __syncthreads();
-    maeGain(spred2, spred2P, scount, sbins, sp, col, range_len, nbins);
+    maeGain(spred2, spred2P, scount, sbins, sp, col, range_len, nbins,
+            min_samples_leaf, min_impurity_decrease);
   }
   __syncthreads();
   sp.evalBestSplit(smem, splits + nid, mutex + nid);
diff --git a/cpp/src/decisiontree/batched-levelalgo/metrics.cuh b/cpp/src/decisiontree/batched-levelalgo/metrics.cuh
index d0a536b007..81af415bd2 100644
--- a/cpp/src/decisiontree/batched-levelalgo/metrics.cuh
+++ b/cpp/src/decisiontree/batched-levelalgo/metrics.cuh
@@ -22,25 +22,55 @@
 #include "node.cuh"
 #include "split.cuh"
 
+namespace {
+
+template <typename T>
+class NumericLimits;
+
+template <>
+class NumericLimits<float> {
+ public:
+  static constexpr double kMax = __FLT_MAX__;
+};
+
+template <>
+class NumericLimits<double> {
+ public:
+  static constexpr double kMax = __DBL_MAX__;
+};
+
+}  // anonymous namespace
+
 namespace ML {
 namespace DecisionTree {
 
 /**
  * @brief Compute gain based on gini impurity metric
  *
- * @param[in] shist left/right class histograms for all bins
- *                  [dim = nbins x 2 x nclasses]
- * @param[in] sbins quantiles for the current column [len = nbins]
- * @param[inout] sp will contain the per-thread best split so far
- * @param[in] col current column
- * @param[in] len total number of samples for the current node to be
- *                split
- * @param[in] nbins number of bins
- * @param[in] nclasses number of classes
+ * @param[in]    shist                 left/right class histograms for all bins
+ *                                     [dim = nbins x 2 x nclasses]
+ * @param[in]    sbins                 quantiles for the current column
+ *                                     [len = nbins]
+ * @param[inout] sp                    will contain the per-thread best split
+ *                                     so far
+ * @param[in]    col                   current column
+ * @param[in]    len                   total number of samples for the current
+ *                                     node to be split
+ * @param[in]    nbins                 number of bins
+ * @param[in]    nclasses              number of classes
+ * @param[in]    min_samples_leaf      minimum number of samples in each leaf.
+ *                                     Any split that leads to a leaf node with
+ *                                     fewer samples than min_samples_leaf
+ *                                     will be ignored.
+ * @param[in]    min_impurity_decrease minimum improvement in the Gini impurity
+ *                                     metric. Any split that does not improve
+ *                                     (decrease) the Gini impurity by at least
+ *                                     this amount will be ignored.
  */
 template <typename DataT, typename IdxT>
 DI void giniGain(int* shist, DataT* sbins, Split<DataT, IdxT>& sp, IdxT col,
-                 IdxT len, IdxT nbins, IdxT nclasses) {
+                 IdxT len, IdxT nbins, IdxT nclasses, IdxT min_samples_leaf,
+                 DataT min_impurity_decrease) {
   constexpr DataT One = DataT(1.0);
   DataT invlen = One / len;
   for (IdxT i = threadIdx.x; i < nbins; i += blockDim.x) {
@@ -49,25 +79,32 @@ DI void giniGain(int* shist, DataT* sbins, Split<DataT, IdxT>& sp, IdxT col,
       nLeft += shist[i * 2 * nclasses + j];
     }
     auto nRight = len - nLeft;
-    auto invLeft = One / nLeft;
-    auto invRight = One / nRight;
     auto gain = DataT(0.0);
-    for (IdxT j = 0; j < nclasses; ++j) {
-      int val_i = 0;
-      if (nLeft != 0) {
+    // if there aren't enough samples in this split, don't bother!
+    if (nLeft < min_samples_leaf || nRight < min_samples_leaf) {
+      gain = -NumericLimits<DataT>::kMax;
+    } else {
+      auto invLeft = One / nLeft;
+      auto invRight = One / nRight;
+      for (IdxT j = 0; j < nclasses; ++j) {
+        int val_i = 0;
         auto lval_i = shist[i * 2 * nclasses + j];
         auto lval = DataT(lval_i);
         gain += lval * invLeft * lval * invlen;
+        val_i += lval_i;
-      }
-      if (nRight != 0) {
         auto rval_i = shist[i * 2 * nclasses + nclasses + j];
         auto rval = DataT(rval_i);
         gain += rval * invRight * rval * invlen;
+        val_i += rval_i;
+        auto val = DataT(val_i) * invlen;
+        gain -= val * val;
       }
-      auto val = DataT(val_i) * invlen;
-      gain -= val * val;
+    }
+    // if the gain is not "enough", don't bother!
+    if (gain <= min_impurity_decrease) {
+      gain = -NumericLimits<DataT>::kMax;
     }
     sp.update({sbins[i], col, gain, nLeft});
   }
 }
 
@@ -76,18 +113,30 @@
 /**
  * @brief Compute gain based on entropy
  *
- * @param[in] shist left/right class histograms for all bins
- *                  [dim = nbins x 2 x nclasses]
- * @param[in] sbins quantiles for the current column [len = nbins]
- * @param[inout] sp will contain the per-thread best split so far
- * @param[in] col current column
- * @param[in] len total number of samples for the current node to be split
- * @param[in] nbins number of bins
- * @param[in] nclasses number of classes
+ * @param[in]    shist                 left/right class histograms for all bins
+ *                                     [dim = nbins x 2 x nclasses]
+ * @param[in]    sbins                 quantiles for the current column
+ *                                     [len = nbins]
+ * @param[inout] sp                    will contain the per-thread best split
+ *                                     so far
+ * @param[in]    col                   current column
+ * @param[in]    len                   total number of samples for the current
+ *                                     node to be split
+ * @param[in]    nbins                 number of bins
+ * @param[in]    nclasses              number of classes
+ * @param[in]    min_samples_leaf      minimum number of samples in each leaf.
+ *                                     Any split that leads to a leaf node with
+ *                                     fewer samples than min_samples_leaf
+ *                                     will be ignored.
+ * @param[in]    min_impurity_decrease minimum improvement in the entropy
+ *                                     metric. Any split that does not improve
+ *                                     (decrease) the entropy by at least this
+ *                                     amount will be ignored.
  */
 template <typename DataT, typename IdxT>
 DI void entropyGain(int* shist, DataT* sbins, Split<DataT, IdxT>& sp, IdxT col,
-                    IdxT len, IdxT nbins, IdxT nclasses) {
+                    IdxT len, IdxT nbins, IdxT nclasses, IdxT min_samples_leaf,
+                    DataT min_impurity_decrease) {
   constexpr DataT One = DataT(1.0);
   DataT invlen = One / len;
   for (IdxT i = threadIdx.x; i < nbins; i += blockDim.x) {
@@ -96,32 +145,41 @@ DI void entropyGain(int* shist, DataT* sbins, Split<DataT, IdxT>& sp, IdxT col,
       nLeft += shist[i * 2 * nclasses + j];
     }
     auto nRight = len - nLeft;
-    auto invLeft = One / nLeft;
-    auto invRight = One / nRight;
     auto gain = DataT(0.0);
-    for (IdxT j = 0; j < nclasses; ++j) {
-      int val_i = 0;
-      if (nLeft != 0) {
+    // if there aren't enough samples in this split, don't bother!
+    if (nLeft < min_samples_leaf || nRight < min_samples_leaf) {
+      gain = -NumericLimits<DataT>::kMax;
+    } else {
+      auto invLeft = One / nLeft;
+      auto invRight = One / nRight;
+      for (IdxT j = 0; j < nclasses; ++j) {
+        int val_i = 0;
         auto lval_i = shist[i * 2 * nclasses + j];
         if (lval_i != 0) {
           auto lval = DataT(lval_i);
-          gain += raft::myLog(lval * invLeft) * lval * invlen;
+          gain +=
+            raft::myLog(lval * invLeft) / raft::myLog(DataT(2)) * lval * invlen;
         }
+        val_i += lval_i;
-      }
-      if (nRight != 0) {
         auto rval_i = shist[i * 2 * nclasses + nclasses + j];
         if (rval_i != 0) {
           auto rval = DataT(rval_i);
-          gain += raft::myLog(rval * invRight) * rval * invlen;
+          gain += raft::myLog(rval * invRight) / raft::myLog(DataT(2)) * rval *
+                  invlen;
         }
+        val_i += rval_i;
-      }
-      if (val_i != 0) {
-        auto val = DataT(val_i) * invlen;
-        gain -= val * raft::myLog(val);
+        if (val_i != 0) {
+          auto val = DataT(val_i) * invlen;
+          gain -= val * raft::myLog(val) / raft::myLog(DataT(2));
+        }
       }
     }
+    // if the gain is not "enough", don't bother!
+    if (gain <= min_impurity_decrease) {
+      gain = -NumericLimits<DataT>::kMax;
+    }
     sp.update({sbins[i], col, gain, nLeft});
   }
 }
 
@@ -129,35 +187,54 @@
 /**
  * @brief Compute gain based on MSE
  *
- * @param[in] spred left/right child mean prediction for all bins
- *                  [dim = 2 x bins]
- * @param[in] scount left child count for all bins [len = nbins]
- * @param[in] sbins quantiles for the current column [len = nbins]
- * @param[inout] sp will contain the per-thread best split so far
- * @param[in] col current column
- * @param[in] len total number of samples for the current node to be split
- * @param[in] nbins number of bins
+ * @param[in]    spred                 left/right child mean prediction for all
+ *                                     bins [dim = 2 x bins]
+ * @param[in]    scount                left child count for all bins
+ *                                     [len = nbins]
+ * @param[in]    sbins                 quantiles for the current column
+ *                                     [len = nbins]
+ * @param[inout] sp                    will contain the per-thread best split
+ *                                     so far
+ * @param[in]    col                   current column
+ * @param[in]    len                   total number of samples for the current
+ *                                     node to be split
+ * @param[in]    nbins                 number of bins
+ * @param[in]    min_samples_leaf      minimum number of samples in each leaf.
+ *                                     Any split that leads to a leaf node with
+ *                                     fewer samples than min_samples_leaf
+ *                                     will be ignored.
+ * @param[in]    min_impurity_decrease minimum improvement in the MSE metric.
+ *                                     Any split that does not improve
+ *                                     (decrease) the MSE by at least this
+ *                                     amount will be ignored.
  */
 template <typename DataT, typename IdxT>
 DI void mseGain(DataT* spred, IdxT* scount, DataT* sbins,
-                Split<DataT, IdxT>& sp, IdxT col, IdxT len, IdxT nbins) {
+                Split<DataT, IdxT>& sp, IdxT col, IdxT len, IdxT nbins,
+                IdxT min_samples_leaf, DataT min_impurity_decrease) {
   auto invlen = DataT(1.0) / len;
   for (IdxT i = threadIdx.x; i < nbins; i += blockDim.x) {
     auto nLeft = scount[i];
     auto nRight = len - nLeft;
-    auto invLeft = DataT(1.0) / nLeft;
-    auto invRight = DataT(1.0) / nRight;
-    auto valL = spred[i];
-    auto valR = spred[nbins + i];
-    // parent sum is basically sum of its left and right children
-    auto valP = (valL + valR) * invlen;
-    DataT gain = -valP * valP;
-    if (nLeft != 0) {
+    DataT gain;
+    // if there aren't enough samples in this split, don't bother!
+    if (nLeft < min_samples_leaf || nRight < min_samples_leaf) {
+      gain = -NumericLimits<DataT>::kMax;
+    } else {
+      auto invLeft = DataT(1.0) / nLeft;
+      auto invRight = DataT(1.0) / nRight;
+      auto valL = spred[i];
+      auto valR = spred[nbins + i];
+      // parent sum is basically sum of its left and right children
+      auto valP = (valL + valR) * invlen;
+      gain = -valP * valP;
       gain += valL * invlen * valL * invLeft;
-    }
-    if (nRight != 0) {
       gain += valR * invlen * valR * invRight;
     }
+    // if the gain is not "enough", don't bother!
+    if (gain <= min_impurity_decrease) {
+      gain = -NumericLimits<DataT>::kMax;
+    }
     sp.update({sbins[i], col, gain, nLeft});
   }
 }
 
@@ -165,32 +242,51 @@
 /**
  * @brief Compute gain based on MAE
  *
- * @param[in] spred left/right child sum of abs diff of prediction for all
- *                  bins [dim = 2 x bins]
- * @param[in] spredP parent's sum of abs diff of prediction for all bins
- *                   [dim = 2 x bins]
- * @param[in] scount left child count for all bins [len = nbins]
- * @param[in] sbins quantiles for the current column [len = nbins]
- * @param[inout] sp will contain the per-thread best split so far
- * @param[in] col current column
- * @param[in] len total number of samples for current node to be split
- * @param[in] nbins number of bins
+ * @param[in]    spred                 left/right child sum of abs diff of
+ *                                     prediction for all bins [dim = 2 x bins]
+ * @param[in]    spredP                parent's sum of abs diff of prediction
+ *                                     for all bins [dim = 2 x bins]
+ * @param[in]    scount                left child count for all bins
+ *                                     [len = nbins]
+ * @param[in]    sbins                 quantiles for the current column
+ *                                     [len = nbins]
+ * @param[inout] sp                    will contain the per-thread best split
+ *                                     so far
+ * @param[in]    col                   current column
+ * @param[in]    len                   total number of samples for current node
+ *                                     to be split
+ * @param[in]    nbins                 number of bins
+ * @param[in]    min_samples_leaf      minimum number of samples in each leaf.
+ *                                     Any split that leads to a leaf node with
+ *                                     fewer samples than min_samples_leaf
+ *                                     will be ignored.
+ * @param[in]    min_impurity_decrease minimum improvement in the MAE metric.
+ *                                     Any split that does not improve
+ *                                     (decrease) the MAE by at least this
+ *                                     amount will be ignored.
 */
 template <typename DataT, typename IdxT>
 DI void maeGain(DataT* spred, DataT* spredP, IdxT* scount, DataT* sbins,
-                Split<DataT, IdxT>& sp, IdxT col, IdxT len, IdxT nbins) {
+                Split<DataT, IdxT>& sp, IdxT col, IdxT len, IdxT nbins,
+                IdxT min_samples_leaf, DataT min_impurity_decrease) {
   auto invlen = DataT(1.0) / len;
   for (IdxT i = threadIdx.x; i < nbins; i += blockDim.x) {
     auto nLeft = scount[i];
     auto nRight = len - nLeft;
-    DataT gain = spredP[i];
-    if (nLeft != 0) {
+    DataT gain;
+    // if there aren't enough samples in this split, don't bother!
+    if (nLeft < min_samples_leaf || nRight < min_samples_leaf) {
+      gain = -NumericLimits<DataT>::kMax;
+    } else {
+      gain = spredP[i];
       gain -= spred[i];
-    }
-    if (nRight != 0) {
       gain -= spred[i + nbins];
+      gain *= invlen;
+    }
+    // if the gain is not "enough", don't bother!
+    if (gain <= min_impurity_decrease) {
+      gain = -NumericLimits<DataT>::kMax;
     }
-    gain *= invlen;
     sp.update({sbins[i], col, gain, nLeft});
   }
 }
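Reviewer note, not part of the patch: all four gain functions above now share the same two-step gating, using -NumericLimits<DataT>::kMax as a sentinel so that sp.update() can never prefer a rejected candidate over any real one. Below is a minimal host-side C++ sketch of that rule; gatedGain is a hypothetical stand-in for the checks inlined into giniGain/entropyGain/mseGain/maeGain, and std::numeric_limits stands in for the patch's hand-rolled NumericLimits (which presumably exists to avoid relying on std::numeric_limits in device code).

    // gate_sketch.cpp: mirrors the two rejection rules the patch inlines
    // into each gain function (illustrative only, not code from the patch).
    #include <cstdio>
    #include <limits>

    template <typename DataT, typename IdxT>
    DataT gatedGain(DataT rawGain, IdxT nLeft, IdxT nRight,
                    IdxT min_samples_leaf, DataT min_impurity_decrease) {
      // Rule 1: the split may not leave fewer than min_samples_leaf samples
      // on either side.
      if (nLeft < min_samples_leaf || nRight < min_samples_leaf) {
        return -std::numeric_limits<DataT>::max();  // sentinel: never chosen
      }
      // Rule 2: the gain must exceed min_impurity_decrease (note the <=:
      // with a threshold of 0, zero-gain splits are rejected too).
      if (rawGain <= min_impurity_decrease) {
        return -std::numeric_limits<DataT>::max();
      }
      return rawGain;
    }

    int main() {
      // A 100-sample node split 97/3 with raw gain 0.02 is rejected when
      // min_samples_leaf = 5, but kept when min_samples_leaf = 1.
      std::printf("%g\n", gatedGain(0.02, 97, 3, 5, 0.0));  // ~ -1.8e308
      std::printf("%g\n", gatedGain(0.02, 97, 3, 1, 0.0));  // 0.02
      return 0;
    }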
diff --git a/cpp/test/sg/decisiontree_batchedlevel_unittest.cu b/cpp/test/sg/decisiontree_batchedlevel_unittest.cu
index 74597e395f..7e3f19420c 100644
--- a/cpp/test/sg/decisiontree_batchedlevel_unittest.cu
+++ b/cpp/test/sg/decisiontree_batchedlevel_unittest.cu
@@ -301,7 +301,8 @@ TEST_P(TestMetric, MSEGain) {
   computeSplitRegressionKernel<DataT, DataT, IdxT, TPB_DEFAULT>
     <<<grid, TPB_DEFAULT, smemSize, 0>>>(
      pred, nullptr, nullptr, pred_count, n_bins, params.max_depth,
-     params.min_samples_split, params.max_leaves, input, curr_nodes, 0,
+     params.min_samples_split, params.min_samples_leaf,
+     params.min_impurity_decrease, params.max_leaves, input, curr_nodes, 0,
      done_count, mutex, n_new_leaves, splits, nullptr, params.split_criterion);
   raft::update_host(h_splits.data(), splits, 1, 0);
   CUDA_CHECK(cudaGetLastError());
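Reviewer note, not part of the patch: besides the constraint gating, the entropyGain hunk above also divides every raft::myLog term by raft::myLog(DataT(2)), switching the entropy from nats (base e) to bits (base 2) via the identity log2(x) = ln(x) / ln(2). A quick standalone C++ check of that conversion, with std::log standing in for raft::myLog:

    // base2_check.cpp: verifies the nats-to-bits rescaling used above.
    #include <cmath>
    #include <cstdio>

    int main() {
      // A pure 50/50 two-class node has entropy of exactly 1 bit,
      // but only ~0.693 nats; dividing by ln(2) converts nats to bits.
      double p = 0.5;
      double nats = -2.0 * (p * std::log(p));  // 0.693147...
      double bits = nats / std::log(2.0);      // 1.000000
      std::printf("nats = %f, bits = %f\n", nats, bits);
      return 0;
    }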