
RF: fix performance of feature sampling for node splits #2292

Merged — 17 commits, merged into uxlfoundation:master on Mar 24, 2023

Conversation

@ahuber21 (Contributor) commented Mar 10, 2023

Description

With the merging of #1725, the chooseFeatures() helper function has become inefficient.
It calls a helper, rng.uniformWithoutReplacement(), that implements a non-standard algorithm for drawing k out of N without replacement. The algorithm takes O(N^2) compute time, and with the changes from #1725, N went from _nFeaturesPerNode to maxFeatures, where maxFeatures is the total number of features in the dataset and _nFeaturesPerNode is the number of features per node split (default value: sqrt(maxFeatures)).

It was reported in uxlfoundation/scikit-learn-intelex#1050 and uxlfoundation/scikit-learn-intelex#984 that our RF classification implementation is very slow and under some circumstances even stalls. I think this is related.

In an initial investigation, I studied the training time for different values of max_features:
[image: training time vs. max_features]

We see that for large values of max_features, the intelex implementation outperforms the stock implementation. But for smaller values (in particular including the default value sqrt(N)), we are slower.
Profiling the code confirmed that we indeed spend most of our time sampling features:
[image: profile showing feature sampling as the hotspot]

I have made the following changes in this PR:

  • Added a helper function in service_memory.h to initialize memory with sequential entries
  • Implemented a modified Fisher-Yates sampling in service_rng.h
    • We use a caller-provided buffer, so no initialization of the indices is required
    • This reduces the time complexity from O(N) to O(k)
  • Implemented a modified Fisher-Yates sampling in findBestSplitSerial()
    • With the changes from #1725 (RF: support filtering of const features), we still generate a sample of size N, so we do not benefit from the optimization in the previous bullet
    • Drawing numbers only when necessary reduces the time complexity from O(N) to O(k + number of constant features), which is the optimal solution
  • Exposed the useConstFeatures setting to the training params - it used to be hardcoded to false

With these changes, the picture improves:
[image: training time vs. max_features after the fix]
Note: This second run was performed on a system with fewer cores, which explains the longer overall training times. In the comparison against the stock implementation, however, this effect cancels out.

The profiler confirms that the random-number-generation bottleneck has disappeared:
[image: profile after the fix]

Most of the compute time is now spent creating the histograms.

Statistical correctness

I have tested the statistical correctness of my Fisher-Yates sampling technique numerically:
I draw 5 out of 10 numbers, meaning each number has p = 0.5 of being drawn.
Running the test 10,000,000 times confirms that all numbers are drawn with uniform probability.

# run with RNG seed 777
$ _results/dpcpp_intel64_a/rng.exe 10000000 777
0,1,2,3,4,5,6,7,8,9,
5000024,5000684,5000634,5002295,5000448,4999619,4998491,4998169,5000202,4999434,

# run with RNG seed 1337
$ _results/dpcpp_intel64_a/rng.exe 10000000 1337
0,1,2,3,4,5,6,7,8,9,
4998947,5002204,5000942,4997803,4999591,4999615,4999998,5000007,5001135,4999758,

Find the source code of the test file below

// file: rng.cpp
#include <cstdlib>
#include <iostream>
#include <vector>

#include "algorithms/engines/mt2203/mt2203.h"
#include "daal.h"
#include "src/algorithms/engines/engine_batch_impl.h"
#include "src/externals/service_rng.h"

static const int M = 5;  // draw M
static const int N = 10; // out of N

template <typename T> void printCounts(const std::vector<T> &counts) {
  if (N < 50) {
    for (int i = 0; i < N; ++i)
      std::cout << i << ",";
    std::cout << std::endl;
    for (int i = 0; i < N; ++i)
      std::cout << counts[i] << ",";
    std::cout << std::endl;
  }
}

int main(int argc, char *argv[]) {
  int EXPERIMENTS = 1000000;
  if (argc > 1) {
    EXPERIMENTS = atoi(argv[1]);
  }

  unsigned long seed = 777;
  if (argc > 2) {
    seed = atol(argv[2]);
  }

  auto engine = daal::algorithms::engines::mt2203::Batch<>::create(seed);
  auto engineImpl =
      dynamic_cast<daal::algorithms::engines::internal::BatchBaseImpl *>(
          engine.get());

  daal::internal::RNGs<int, daal::CpuType::avx512> rng;

  int result[N]; // holds the M drawn values after each call
  int buffer[N]; // permutation buffer, initialized once with 0..N-1

  std::vector<unsigned int> counts(N, 0);

  for (int i = 0; i < N; ++i) {
    result[i] = -1;
    buffer[i] = i;
  }

  for (int r = 0; r < EXPERIMENTS; ++r) {
    rng.uniformWithoutReplacement(M, result, buffer, engineImpl->getState(),
                                  0, N);

    for (int i = 0; i < M; ++i) {
      counts[result[i]]++;
    }
  }

  if (EXPERIMENTS < 50) {
    std::cout << "B: ";
    for (int i = 0; i < N; ++i)
      std::cout << buffer[i] << "\t";
    std::cout << std::endl;
    std::cout << "R: ";
    for (int i = 0; i < N; ++i)
      std::cout << result[i] << "\t";
    std::cout << std::endl;
  }

  printCounts(counts);

  return 0;
}

@ahuber21 (Contributor Author):
/intelci: run

@ahuber21 (Contributor Author):
The work is not fully finished. For instance, I do not know if or how we should escalate the error when rng.uniform() fails. There is a // TODO comment in the latest commit.

@@ -138,6 +138,17 @@ void service_memset_seq(T * const ptr, const T value, const size_t num)
}
}

template <typename T, CpuType cpu>
void service_memset_ser(T * const ptr, const T startValue, const size_t num)
Contributor:
The meaning of the function name is not obvious. Maybe service_memset_range?

Contributor Author:
I went with service_memset_incrementing, do you agree?

@ahuber21 changed the title from "Random forest investigation" to "RF: fix performance of feature sampling for node splits" on Mar 14, 2023
@ahuber21 (Contributor Author):
Hi @Alexsandruss and @samir-nasibli could you comment on the test failures?
As far as I can tell, we are comparing against previous versions of the algorithm, not some sort of ground truth.
Since I changed the node-splitting algorithm, it is expected that the resulting trees in the forest will have different numerics.
Is there any way I can confirm the results are statistically still correct?

Is it okay to update decision_forest_regression_batch.csv with the updated numbers?

@ahuber21 (Contributor Author):
I have created a PR in scikit-learn-intelex to follow up on the CI failures
uxlfoundation/scikit-learn-intelex#1213

@Alexsandruss could you sign off on the changes requested in this PR?

@@ -134,6 +134,7 @@ class DAAL_EXPORT Parameter
Default is 256. Increasing the number results in higher computation costs */
size_t minBinSize; /*!< Used with 'hist' split finding method only.
Minimal number of observations in a bin. Default is 5 */
bool useConstFeatures; /*!< Use or ignore constant-valued features when splitting nodes. Default is false */
Contributor:
This parameter should be added to the oneDAL interfaces too. Otherwise, it also creates an inconsistency in scikit-learn-intelex, where RandomForest is going to use the oneDAL interfaces by default.

Contributor Author:
Tried to add it here: 6ba776f

I hope I found all relevant places. Let me know if you see something obvious missing. Thank you!

@Alexsandruss (Contributor):
Failures should be investigated and fixed on the daal4py side:

FAIL: test_decision_forest_regression_default_dense_batch (test_examples.TestExNpyArray)
FAIL: test_decision_forest_regression_hist_batch (test_examples.TestExNpyArray)
FAIL: test_decision_forest_regression_default_dense_batch (test_examples.TestExPandasDF)
FAIL: test_decision_forest_regression_hist_batch (test_examples.TestExPandasDF)
=================================== FAILURES ===================================
_______________ test_classifier_class_weight_iris[class_weight5] _______________
>       assert ratio <= LOG_LOSS_RATIO, reason
E       AssertionError: Classifier class weight: scikit_log_loss=0.10065237679616006, daal4py_log_loss=0.14592264814110162
E       assert 1.4497685279367258 <= 1.4

@ahuber21 (Contributor Author):
The log loss went from 0.136 to 0.149 (+9.6 %), which took the ratio from ~1.36 to 1.47. This is indeed reason to worry, but I am convinced at this point that there is a bigger accuracy issue in DF, which is currently being investigated in a separate issue (see uxlfoundation/scikit-learn-intelex#1090).
I therefore doubt that my changes introduce an additional error; I think they merely amplify the existing one by chance.
The rmse and r2_score from https://github.com/IntelPython/scikit-learn_bench/tree/master/configs/sklearn/performance all agree very well*.

[*] All ratios are between 0.99 and 1.02, except for one measurement on the year_prediction_msd dataset, where a 7-9 % disagreement is observed. Again, I would like to tackle that in a separate issue.

Is it okay if I temporarily allow ratios <= 1.5 in the unit test?

@ahuber21 (Contributor Author):
@Alexsandruss updated the test in the other repo: uxlfoundation/scikit-learn-intelex#1213

@Alexsandruss (Contributor):
A 9 % quality-metric drop for some datasets looks unacceptable. Did you test the datasets from scikit-learn_bench/configs/testing/metrics/rf_*.json?

@ahuber21 (Contributor Author) commented Mar 24, 2023:

I'm running the tests right now. But regardless, could we split these issues? There is a critical underperformance in the current implementation: compared to the stock version, the latest intelex release has a 25 % worse MSE on the year prediction dataset (the same one where I see the +9 %), and a factor of 3 (!!) is reported in uxlfoundation/scikit-learn-intelex#1213.
Technically, the old sampling was not uniformly random, which is why some changes are expected.

@Alexsandruss left a review:
Merge when the fix on the Python side is ready.

@Vika-F left a review:
The changes look good. Special thanks for adding comments.
But it seems that the error handling needs to be improved.

Comment on lines 1126 to 1130
int errorcode = rng.uniform(1, &swapIdx, _engineImpl->getState(), 0, maxFeatures - i);
if (errorcode)
{
    return false;
}
Contributor:
The correct behavior would be to return a Status (or SafeStatus if this code runs in a parallel region), not a bool, from the findBestSplit* functions. A Status can carry a description of the error. A new error code may also need to be added to handle failures in the random number generators.

@@ -619,7 +619,7 @@ protected:
const size_t nGen = (!_par.memorySavingMode && !_maxLeafNodes && !_useConstFeatures) ? n : _nFeaturesPerNode;
*_numElems += n;
RNGs<IndexType, cpu> rng;
rng.uniformWithoutReplacement(nGen, _aFeatureIdx.get(), _aFeatureIdx.get() + nGen, _engineImpl->getState(), 0, n);
rng.drawKFromBufferWithoutReplacement(nGen, _aFeatureIdx.get(), _aFeatureIdx.get() + nGen, _engineImpl->getState(), n);
Contributor:
drawKFromBufferWithoutReplacement returns an error code, but it is not checked here.
Please see the other comment in this review, which describes a more correct way of handling errors in DAAL.

Contributor Author:
I didn't add it to this part because it is dead code; I will remove it soon.

@ahuber21 force-pushed the random-forest-investigation branch from 3efc0d0 to cb9d780 on March 24, 2023 09:58
@ahuber21 (Contributor Author) commented Mar 24, 2023:

Force-pushed after rebase. @Vika-F, I will add the status checks next.

@ahuber21 (Contributor Author) commented Mar 24, 2023:
@Vika-F I have added status checks with minimal changes and am running a few performance benchmarks, just to make sure nothing has changed.

Edit: I can confirm that nothing has changed.

@ahuber21 ahuber21 merged commit af225ad into uxlfoundation:master Mar 24, 2023
@ahuber21 ahuber21 deleted the random-forest-investigation branch March 24, 2023 19:20