[FEA] Accept NumPy (and CuPy) RandomState objects as estimator random_state #4753

Open
beckernick opened this issue May 23, 2022 · 6 comments · May be fixed by #6150
Assignees
Labels
cuml-cpu Cython / Python Cython or Python issue feature request New feature or request good first issue Good for newcomers

Comments

@beckernick
Member

Many estimators provide a random_state parameter to let users provide seeds for random number generators. Scikit-learn estimators can accept either an integer or a numpy.random.RandomState for random_state, and some PyData ecosystem tools (e.g. Boruta) pass RandomStates to estimators, so it would be nice if we could accept these as well.

import cuml
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
import numpy as np


rs = np.random.RandomState(seed=10)

X, y = make_classification(n_samples=1000)

clf = RandomForestClassifier(random_state=rs)
clf.fit(X, y)
clf = cuml.ensemble.RandomForestClassifier(random_state=rs)
clf.fit(X, y)
/home/nicholasb/miniconda3/envs/rapids-22.06/lib/python3.9/site-packages/cuml/internals/api_decorators.py:794: UserWarning: For reproducible results in Random Forest Classifier or for almost reproducible results in Random Forest Regressor, n_streams=1 is recommended. If n_streams is > 1, results may vary due to stream/thread timing differences, even when random_state is set
  return func(**kwargs)
/home/nicholasb/miniconda3/envs/rapids-22.06/lib/python3.9/site-packages/cuml/internals/api_decorators.py:567: UserWarning: To use pickling or GPU-based prediction first train using float32 data to fit the estimator
  ret_val = func(*args, **kwargs)

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Input In [1], in <cell line: 14>()
     12 clf.fit(X, y)
     13 clf = cuml.ensemble.RandomForestClassifier(random_state=rs)
---> 14 clf.fit(X, y)

File ~/miniconda3/envs/rapids-22.06/lib/python3.9/contextlib.py:79, in ContextDecorator.__call__.<locals>.inner(*args, **kwds)
     76 @wraps(func)
     77 def inner(*args, **kwds):
     78     with self._recreate_cm():
---> 79         return func(*args, **kwds)

File ~/miniconda3/envs/rapids-22.06/lib/python3.9/site-packages/cuml/internals/api_decorators.py:409, in BaseReturnAnyDecorator.__call__.<locals>.inner_with_setters(*args, **kwargs)
    402 self_val, input_val, target_val = \
    403     self.get_arg_values(*args, **kwargs)
    405 self.do_setters(self_val=self_val,
    406                 input_val=input_val,
    407                 target_val=target_val)
--> 409 return func(*args, **kwargs)

File cuml/ensemble/randomforestclassifier.pyx:452, in cuml.ensemble.randomforestclassifier.RandomForestClassifier.fit()

TypeError: an integer is required

Conceptually, it's possible to generate an integer from a RandomState object fairly efficiently. Perhaps this might be a path forward?

import numpy as np
%timeit rs = np.random.RandomState(); rs.randint(0, 100000)
113 µs ± 989 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

import numpy as np
[np.random.RandomState(seed=12).randint(0, 100000) for x in range(10)]
[79691, 79691, 79691, 79691, 79691, 79691, 79691, 79691, 79691, 79691]
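A minimal sketch of that idea, assuming the estimator only needs a plain integer at the point where it currently fails. The helper name `seed_from_random_state` and the upper bound `high` are hypothetical and not part of cuML; a `cupy.random.RandomState` also exposes `randint`, so it could presumably be handled the same way, but that is not exercised here.

```python
import numbers

import numpy as np


def seed_from_random_state(random_state, high=2**31 - 1):
    """Hypothetical helper: reduce `random_state` to a Python int (or None)."""
    if random_state is None:
        return None
    if isinstance(random_state, numbers.Integral):
        # Plain ints and NumPy integer scalars pass through unchanged.
        return int(random_state)
    if isinstance(random_state, np.random.RandomState):
        # Drawing advances the RandomState, so a second estimator that
        # receives the same object gets the next value in the sequence.
        return int(random_state.randint(0, high))
    raise TypeError(
        f"random_state must be None, an int, or a RandomState, "
        f"got {type(random_state).__name__}"
    )


rs = np.random.RandomState(seed=10)
seed = seed_from_random_state(rs)  # an int the estimator could consume
```

Two freshly constructed `RandomState(seed=10)` objects yield the same derived seed, which preserves reproducibility across runs.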

cc @teju85 , as we were chatting about Boruta earlier.

@beckernick beckernick added feature request New feature or request Cython / Python Cython or Python issue good first issue Good for newcomers labels May 23, 2022
@jakirkham
Member

jakirkham commented May 25, 2022

It is worth noting that NumPy is moving away from RandomState. From the docs:

The RandomState provides access to legacy generators. This generator is considered frozen and will have no further improvements.

The new API is Generator.

CuPy has implemented this ( cupy/cupy#4177 ) in 9.0.0. Dask is working on it ( dask/dask#9038 ). Scikit-learn is moving in this direction ( scikit-learn/scikit-learn#22271 ) ( scikit-learn/scikit-learn#22327 )

This isn't to say there isn't value in supporting RandomState, just that there is a new API undergoing adoption, and it may be worthwhile to design the code so that it also works with this new API.
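One way to keep both APIs working is to branch on the object's type and use the integer-drawing method each one provides: `Generator.integers` for the new API and `RandomState.randint` for the legacy one. The sketch below is illustrative only; `draw_seed` and `_MAX_SEED` are hypothetical names, and the 32-bit bound is an assumption about what the backend wants.

```python
import numpy as np

# Assumed exclusive upper bound for a 32-bit unsigned seed.
_MAX_SEED = 2**32


def draw_seed(rng):
    """Hypothetical helper: draw an integer seed from either NumPy RNG API."""
    if isinstance(rng, np.random.Generator):
        # New API: Generator.integers replaces RandomState.randint.
        return int(rng.integers(0, _MAX_SEED, dtype=np.int64))
    if isinstance(rng, np.random.RandomState):
        # Legacy API.
        return int(rng.randint(0, _MAX_SEED, dtype=np.int64))
    raise TypeError(f"unsupported RNG type: {type(rng).__name__}")


seed_new = draw_seed(np.random.default_rng(12))
seed_legacy = draw_seed(np.random.RandomState(12))
```

The two APIs use different bit generators by default, so the same integer seed is not expected to give the same derived value in both branches.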

@cjnolet
Member

cjnolet commented May 25, 2022

This would happen after we just refactored the RAFT RNG API to follow the random_state model.

@beckernick
Member Author

beckernick commented May 25, 2022

Good to know, thanks for sharing the new API and adoption progress John.

Sounds like unfortunate timing. Perhaps it makes sense to support RandomState for now in cuML Python to unblock users (as legacy usage will probably persist for a while) and then eventually design for / explore the Generator API in RAFT? There is still significant usage of RandomState (via a utility function) in scikit-learn, for example.

@jakirkham
Member

Would suggest looking at the scikit-learn PRs linked (in particular the changes therein). They seem to be able to gracefully handle both APIs. So think it should be possible for us to do the same

@tarang-jain
Contributor

tarang-jain commented May 31, 2022

I have made an initial draft here. There are some issues that I would like to point out.

  • In addition to integers, even float random_state values are accepted and are forcibly converted to <uintptr_t>. (This is also happening in the original implementation outside of the PR).
  • Currently we do not have support for Generator (the new API). I am able to retrieve the value of the seed if the Generator object has an integer seed declared. However, if the Generator is declared without a seed, the default/random value of the seed is too large to be stored in the <uintptr_t> datatype (into which it is being cast). Furthermore, the Generator also allows a numpy array as the seed. I am able to retrieve the numpy array, but I do not know how this array can then be cast into <uintptr_t>.
  • Another thing to note is the comparison between cuml and scikit-learn: In the check_random_state function here, a seed is accepted and then converted into a RandomState object if it is not already a RandomState object. On the other hand, we accept numpy.random.RandomState and cupy.random.RandomState objects and attempt to extract the seed from them, so that the seed is later used as the hash for fnv1a32 in builder_kernels.cuh.
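For the oversized and array-valued seeds mentioned in the second point, one possible approach (not what the draft PR does) is to avoid casting the raw seed at all and instead let `numpy.random.SeedSequence` mix arbitrary entropy down to a fixed-width word. The sketch below assumes a single 64-bit value is what the backend ultimately needs; `entropy_to_uint64` is a hypothetical name.

```python
import numpy as np


def entropy_to_uint64(entropy=None):
    """Hypothetical helper: reduce any SeedSequence-compatible entropy
    (None, an arbitrarily large int, or a sequence of ints) to one
    integer that fits in 64 bits."""
    seed_seq = np.random.SeedSequence(entropy)
    # generate_state returns an array of n_words values of the given dtype.
    return int(seed_seq.generate_state(1, dtype=np.uint64)[0])


small = entropy_to_uint64(10)
huge = entropy_to_uint64(2**200)           # far wider than 64 bits
from_sequence = entropy_to_uint64([1, 2, 3])
unseeded = entropy_to_uint64()             # fresh OS entropy on each call
```

The result is deterministic for a given entropy value, so reproducibility is kept, but it is a derived value rather than the user's original seed.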

@betatim betatim self-assigned this Nov 27, 2024
@betatim
Member

betatim commented Nov 27, 2024

I'll see if we can reboot the PR or do something else to fix this
