[BUG] Confusing behavior of predict() / predict_proba() of RF when output_class is specified #3347

Closed
hcho3 opened this issue Jan 6, 2021 · 4 comments · Fixed by #3609
Labels: bug (Something isn't working), Cython / Python (Cython or Python issue)

Comments

@hcho3
Contributor

hcho3 commented Jan 6, 2021

Describe the bug

The predict() and predict_proba() methods of RandomForestClassifier behave confusingly when the output_class parameter is specified.

Code to reproduce bug

Binary classification

from cuml.ensemble import RandomForestClassifier as rfc
from sklearn.datasets import make_classification
import numpy as np

X, y = make_classification(n_samples=10000, n_features=20, n_informative=10, n_classes=2)
X, y = X.astype(np.float32), y.astype(np.int32)

model = rfc(n_estimators=800, 
            max_depth=16, 
            random_state=0,
            min_samples_leaf=50)
model.fit(X, y)

## predict() works as expected
out1 = model.predict(X)                          # Outputs class predictions (0 or 1)
out2 = model.predict(X, output_class=True)       # Outputs class predictions (0 or 1)
out3 = model.predict(X, output_class=False)      # Outputs probability for the positive class
np.testing.assert_array_equal(out1, out2)        # out1 is identical to out2
np.testing.assert_array_equal(out3 > 0.5, out1)  # Thresholding out3 with 0.5 yields the correct
                                                 # class predictions

## predict_proba() behaves weirdly when output_class is specified
out4 = model.predict_proba(X)                     # Outputs probability for the positive class
out5 = model.predict_proba(X, output_class=True)  # Outputs probability for the positive class (*BAD*)
np.testing.assert_array_equal(out4, out5)         # out4 is identical to out5 (*BAD*)
out6 = model.predict_proba(X, output_class=False) # *BAD* This line causes an exception
  • Setting output_class=True has no effect on predict_proba(): it still returns probabilities, even though output_class=True suggests class predictions.
  • Setting output_class=False causes predict_proba() to crash:
NotImplementedError: Predict_proba function is not available for Regression models.
If you are  using a Classification model, please set `output_class=True` while creating the FIL model.

Multi-class classification

from cuml.ensemble import RandomForestClassifier as rfc
from sklearn.datasets import make_classification
import numpy as np

X, y = make_classification(n_samples=10000, n_features=20, n_informative=10, n_classes=3)
X, y = X.astype(np.float32), y.astype(np.int32)

model = rfc(n_estimators=800, 
            max_depth=16, 
            random_state=0,
            min_samples_leaf=50)
model.fit(X, y)

## predict() works as expected if output_class is not given or is True
out1 = model.predict(X)                          # Outputs class predictions (0, 1, or 2)
out2 = model.predict(X, output_class=True)       # Outputs class predictions (0, 1, or 2)
np.testing.assert_array_equal(out1, out2)        # out1 is identical to out2

## predict_proba() behaves weirdly when output_class=True is specified
out4 = model.predict_proba(X)                     # Outputs probabilities for all classes
out5 = model.predict_proba(X, output_class=True)  # Outputs probabilities for all classes (*BAD*)
np.testing.assert_array_equal(out4, out5)         # out4 is identical to out5 (*BAD*)

# Each of the following lines causes an unrecoverable crash
out3 = model.predict(X, output_class=False)       # *BAD* causes crash
out6 = model.predict_proba(X, output_class=False) # *BAD* causes crash
  • Setting output_class=True has no effect on predict_proba(): it still returns probabilities, even though output_class=True suggests class predictions.
  • Setting output_class=False causes predict() to crash:
terminate called after throwing an instance of 'raft::exception'
  what():  exception occured! file=/opt/conda/envs/rapids/conda-bld/libcuml_1607637233004/work/cpp/src/fil/fil.cu line=699:
      output_class==true is required for multi-class models
  • Setting output_class=False causes predict_proba() to crash:
terminate called after throwing an instance of 'raft::exception'
  what():  exception occured! file=/opt/conda/envs/rapids/conda-bld/libcuml_1607637233004/work/cpp/src/fil/fil.cu line=699:
      output_class==true is required for multi-class models

Suggested fix

This behavior makes cuML RF highly inconsistent with sklearn RF. I suggest the following modifications:

  • Do away with the output_class parameter entirely. Instead, the predict() method should always produce class predictions, and the predict_proba() method should always produce probability predictions.
  • Also remove the threshold parameter from the predict_proba() method. The predict() method can keep it.
  • The predict_proba() method should always return an array of shape (n_samples, n_classes), whether n_classes is 2 (binary classification) or more than 2 (multi-class classification). Right now, the predict_proba() method in cuML RF returns an array of shape (n_samples,) for binary classification.
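
The shape contract proposed in the last bullet can be sketched with plain NumPy. This is only an illustration: the helper names `as_two_column_proba` and `classes_from_proba` are hypothetical and are not part of cuML or scikit-learn, and the sketch assumes a 1-D binary output holds the positive-class probability, as in the reproduction above.

```python
import numpy as np


def as_two_column_proba(proba):
    """Return probabilities with shape (n_samples, n_classes).

    Hypothetical helper: assumes a 1-D input holds P(class == 1) for a
    binary problem; a 2-D input is returned unchanged.
    """
    proba = np.asarray(proba, dtype=np.float32)
    if proba.ndim == 1:
        # Column 0 is P(class 0), column 1 is P(class 1).
        return np.column_stack([1.0 - proba, proba])
    return proba


def classes_from_proba(proba):
    """Pick the index of the most probable class for each row."""
    return np.argmax(as_two_column_proba(proba), axis=1)


binary_positive = np.array([0.1, 0.8, 0.5], dtype=np.float32)
proba = as_two_column_proba(binary_positive)
labels = classes_from_proba(binary_positive)
```

With this layout, binary and multi-class outputs can be handled by the same downstream code, which is the consistency the suggestion is after.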

Consult the scikit-learn RF documentation. [Screenshot: Screen Shot 2021-01-05 at 7 50 49 PM]

Environment details

  • Environment location: Bare-metal
  • Linux Distro/Architecture: Ubuntu 18.04 amd64
  • GPU Model/Driver: Quadro RTX 8000 and driver 440.33.01
  • CUDA: 11.0
  • Method of cuDF & cuML install: conda 0.17 stable
Conda environment

# packages in environment at /home/phcho/miniconda3/envs/rapids-0.17:
#
# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                 conda_forge    conda-forge
_openmp_mutex             4.5                       1_gnu    conda-forge
abseil-cpp                20200225.2           he1b5a44_2    conda-forge
aiohttp                   3.7.3            py37h5e8e339_0    conda-forge
alsa-lib                  1.2.3                h516909a_0    conda-forge
appdirs                   1.4.4              pyh9f0ad1d_0    conda-forge
argon2-cffi               20.1.0           py37h4abf009_2    conda-forge
arrow-cpp                 1.0.1           py37h2318771_14_cuda    conda-forge
arrow-cpp-proc            2.0.0                      cuda    conda-forge
async-timeout             3.0.1                   py_1000    conda-forge
async_generator           1.10                       py_0    conda-forge
attrs                     20.3.0             pyhd3deb0d_0    conda-forge
aws-c-common              0.4.59               h36c2ea0_1    conda-forge
aws-c-event-stream        0.1.6                had2084c_6    conda-forge
aws-checksums             0.1.10               h4e93380_0    conda-forge
aws-sdk-cpp               1.8.63               h9b98462_0    conda-forge
backcall                  0.2.0              pyh9f0ad1d_0    conda-forge
backports                 1.0                        py_2    conda-forge
backports.functools_lru_cache 1.6.1                      py_0    conda-forge
blazingsql                0.17.0                   pypi_0    pypi
bleach                    3.2.1              pyh9f0ad1d_0    conda-forge
bokeh                     2.2.3            py37h89c1867_0    conda-forge
boost                     1.72.0           py37h48f8a5e_1    conda-forge
boost-cpp                 1.72.0               h8e57a91_0    conda-forge
brotli                    1.0.9                he1b5a44_3    conda-forge
brotlipy                  0.7.0           py37hb5d75c8_1001    conda-forge
bzip2                     1.0.8                h7f98852_4    conda-forge
c-ares                    1.17.1               h36c2ea0_0    conda-forge
ca-certificates           2020.12.5            ha878542_0    conda-forge
cairo                     1.16.0            hcf35c78_1003    conda-forge
certifi                   2020.12.5        py37h89c1867_0    conda-forge
cffi                      1.14.4           py37hc58025e_1    conda-forge
cfitsio                   3.470                hb418390_7    conda-forge
chardet                   3.0.4           py37he5f6b98_1008    conda-forge
click                     7.1.2              pyh9f0ad1d_0    conda-forge
click-plugins             1.1.1                      py_0    conda-forge
cligj                     0.7.1              pyhd8ed1ab_0    conda-forge
cloudpickle               1.6.0                      py_0    conda-forge
colorcet                  2.0.1                      py_0    conda-forge
cryptography              3.3.1            py37h7f0c10b_0    conda-forge
cudatoolkit               11.0.221             h6bb024c_0    nvidia
cudf                      0.17.0          cuda_11.0_py37_gf56ef850e6_0    rapidsai
cudf_kafka                0.17.0          py37_gf56ef850e6_0    rapidsai
cudnn                     8.0.0                cuda11.0_0    nvidia
cugraph                   0.17.0          py37_gb58e49e8_0    rapidsai
cuml                      0.17.0          cuda11.0_py37_ga48e0ffbf_0    rapidsai
cupy                      8.0.0            py37h0ce7dbb_0    rapidsai
curl                      7.71.1               he644dc0_8    conda-forge
cusignal                  0.17.0          py38_ge853242_0    rapidsai
cuspatial                 0.17.0          py37_g897b304_0    rapidsai
custreamz                 0.17.0          py37_gf56ef850e6_0    rapidsai
cuxfilter                 0.17.0          py37_g340129d_0    rapidsai
cyrus-sasl                2.1.27               h063b49f_1    conda-forge
cytoolz                   0.11.0           py37h4abf009_1    conda-forge
dask                      2020.12.0          pyhd8ed1ab_0    conda-forge
dask-core                 2020.12.0          pyhd8ed1ab_0    conda-forge
dask-cuda                 0.17.0                   py37_0    rapidsai
dask-cudf                 0.17.0          py37_gf56ef850e6_0    rapidsai
datashader                0.11.1             pyh9f0ad1d_0    conda-forge
datashape                 0.5.4                      py_1    conda-forge
decorator                 4.4.2                      py_0    conda-forge
defusedxml                0.6.0                      py_0    conda-forge
distributed               2020.12.0        py37h89c1867_0    conda-forge
dlpack                    0.3                  he1b5a44_1    conda-forge
entrypoints               0.3             pyhd8ed1ab_1003    conda-forge
expat                     2.2.9                he1b5a44_2    conda-forge
faiss-proc                1.0.0                      cuda    conda-forge
fastavro                  1.2.3            py37h5e8e339_0    conda-forge
fastrlock                 0.5              py37h3340039_1    conda-forge
fiona                     1.8.13           py37h0492a4a_1    conda-forge
fontconfig                2.13.1            h86ecdb6_1001    conda-forge
freetype                  2.10.4               h7ca028e_0    conda-forge
freexl                    1.0.5             h516909a_1002    conda-forge
fsspec                    0.8.5              pyhd8ed1ab_0    conda-forge
future                    0.18.2           py37h89c1867_2    conda-forge
gdal                      3.0.4           py37h4b180d9_10    conda-forge
geopandas                 0.8.1                      py_0    conda-forge
geos                      3.8.1                he1b5a44_0    conda-forge
geotiff                   1.6.0                h05acad5_0    conda-forge
gettext                   0.19.8.1          h0b5b191_1005    conda-forge
gflags                    2.2.2             he1b5a44_1004    conda-forge
giflib                    5.2.1                h36c2ea0_2    conda-forge
glib                      2.66.4               hcd2ae1e_1    conda-forge
glog                      0.4.0                h49b9bf7_3    conda-forge
google-cloud-cpp          1.16.0               he4a878c_2    conda-forge
google-cloud-cpp-common   0.25.0               he83eced_7    conda-forge
googleapis-cpp            0.10.0               h6b1abdc_4    conda-forge
graphite2                 1.3.13            h58526e2_1001    conda-forge
grpc-cpp                  1.32.0               h7997a97_1    conda-forge
gtest                     1.10.0               h4bd325d_5    conda-forge
harfbuzz                  2.4.0                h9f30f68_3    conda-forge
hdf4                      4.2.13            h10796ff_1004    conda-forge
hdf5                      1.10.6          nompi_h6a2412b_1113    conda-forge
heapdict                  1.0.1                      py_0    conda-forge
icu                       64.2                 he1b5a44_1    conda-forge
idna                      2.10               pyh9f0ad1d_0    conda-forge
importlib-metadata        3.3.0            py37h89c1867_2    conda-forge
importlib_metadata        3.3.0                hd8ed1ab_2    conda-forge
ipykernel                 5.4.2            py37h888b3d9_0    conda-forge
ipython                   7.19.0           py37h888b3d9_0    conda-forge
ipython_genutils          0.2.0                      py_1    conda-forge
ipywidgets                7.6.2              pyhd3deb0d_1    conda-forge
jedi                      0.18.0           py37h89c1867_1    conda-forge
jinja2                    2.11.2             pyh9f0ad1d_0    conda-forge
joblib                    1.0.0              pyhd8ed1ab_0    conda-forge
jpeg                      9d                   h36c2ea0_0    conda-forge
jpype1                    1.2.1            py37h2527ec5_0    conda-forge
json-c                    0.13.1            hbfbb72e_1002    conda-forge
jsonschema                3.2.0                      py_2    conda-forge
jupyter-server-proxy      1.5.2              pyhd8ed1ab_0    conda-forge
jupyter_client            6.1.7                      py_0    conda-forge
jupyter_core              4.7.0            py37h89c1867_0    conda-forge
jupyterlab_pygments       0.1.2              pyh9f0ad1d_0    conda-forge
jupyterlab_widgets        1.0.0              pyhd8ed1ab_1    conda-forge
kealib                    1.4.14               h0042707_0    conda-forge
krb5                      1.17.2               h926e7f8_0    conda-forge
lcms2                     2.11                 hcbb858e_1    conda-forge
ld_impl_linux-64          2.35.1               hea4e1c9_1    conda-forge
libblas                   3.8.0               17_openblas    conda-forge
libcblas                  3.8.0               17_openblas    conda-forge
libcrc32c                 1.1.1                he1b5a44_2    conda-forge
libcudf                   0.17.0          cuda11.0_gf56ef850e6_0    rapidsai
libcudf_kafka             0.17.0            gf56ef850e6_0    rapidsai
libcugraph                0.17.0          cuda11.0_gb58e49e8_0    rapidsai
libcuml                   0.17.0          cuda11.0_ga48e0ffbf_0    rapidsai
libcumlprims              0.17.0          cuda11.0_g8216947_0    nvidia
libcurl                   7.71.1               hcdd3856_8    conda-forge
libcuspatial              0.17.0          cuda11.0_g897b304_0    rapidsai
libdap4                   3.20.6               h1d1bd15_1    conda-forge
libedit                   3.1.20191231         he28a2e2_2    conda-forge
libev                     4.33                 h516909a_1    conda-forge
libevent                  2.1.10               hcdb4288_3    conda-forge
libfaiss                  1.6.3           h328c4c8_1_cuda    rapidsai
libffi                    3.3                  h58526e2_2    conda-forge
libgcc-ng                 9.3.0               h5dbcf3e_17    conda-forge
libgcrypt                 1.8.7                h36c2ea0_0    conda-forge
libgdal                   3.0.4               he6a97d6_10    conda-forge
libgfortran-ng            9.3.0               he4bcb1c_17    conda-forge
libgfortran5              9.3.0               he4bcb1c_17    conda-forge
libglib                   2.66.4               h164308a_1    conda-forge
libgomp                   9.3.0               h5dbcf3e_17    conda-forge
libgpg-error              1.41                 h9c3ff4c_0    conda-forge
libgsasl                  1.8.0                         2    conda-forge
libhwloc                  2.3.0                h5e5b7d1_1    conda-forge
libiconv                  1.16                 h516909a_0    conda-forge
libkml                    1.3.0             hd79254b_1012    conda-forge
liblapack                 3.8.0               17_openblas    conda-forge
libllvm10                 10.0.1               he513fc3_3    conda-forge
libnetcdf                 4.7.4           nompi_h56d31a8_107    conda-forge
libnghttp2                1.41.0               h8cfc5f6_2    conda-forge
libntlm                   1.4               h7f98852_1002    conda-forge
libopenblas               0.3.10          pthreads_h4812303_5    conda-forge
libpng                    1.6.37               h21135ba_2    conda-forge
libpq                     12.3                 h255efa7_3    conda-forge
libprotobuf               3.13.0.1             h8b12597_0    conda-forge
librdkafka                1.5.3                h54cafa9_0    conda-forge
librmm                    0.17.0          cuda11.0_gc4cc945_0    rapidsai
libsodium                 1.0.18               h36c2ea0_1    conda-forge
libspatialindex           1.9.3                he1b5a44_3    conda-forge
libspatialite             4.3.0a            h2482549_1038    conda-forge
libssh2                   1.9.0                hab1572f_5    conda-forge
libstdcxx-ng              9.3.0               h2ae2ef3_17    conda-forge
libthrift                 0.13.0               h5aa387f_6    conda-forge
libtiff                   4.2.0                hdc55705_0    conda-forge
libutf8proc               2.6.1                h7f98852_0    conda-forge
libuuid                   2.32.1            h7f98852_1000    conda-forge
libuv                     1.34.0               h516909a_0    conda-forge
libwebp                   1.1.0                h76fa15c_4    conda-forge
libwebp-base              1.1.0                h36c2ea0_3    conda-forge
libxcb                    1.13              h14c3975_1002    conda-forge
libxgboost                1.3.0dev.rapidsai0.17      cuda11.0_0    rapidsai
libxml2                   2.9.10               hee79883_0    conda-forge
llvmlite                  0.35.0           py37h9d7f4d0_0    conda-forge
locket                    0.2.0                      py_2    conda-forge
lz4-c                     1.9.2                he1b5a44_3    conda-forge
markdown                  3.3.3              pyh9f0ad1d_0    conda-forge
markupsafe                1.1.1            py37hb5d75c8_2    conda-forge
mistune                   0.8.4           py37h4abf009_1002    conda-forge
msgpack-python            1.0.2            py37h2527ec5_0    conda-forge
multidict                 5.1.0            py37h5e8e339_0    conda-forge
multipledispatch          0.6.0                      py_0    conda-forge
munch                     2.5.0                      py_0    conda-forge
nbclient                  0.5.1                      py_0    conda-forge
nbconvert                 6.0.7            py37h89c1867_3    conda-forge
nbformat                  5.0.8                      py_0    conda-forge
nccl                      2.7.8.1            h4962215_100    nvidia
ncurses                   6.2                  h58526e2_4    conda-forge
nest-asyncio              1.4.3              pyhd8ed1ab_0    conda-forge
netifaces                 0.10.9          py37h8f50634_1003    conda-forge
networkx                  2.5                        py_0    conda-forge
nodejs                    13.13.0              hf5d1a2b_0    conda-forge
notebook                  6.1.6            py37h89c1867_0    conda-forge
numba                     0.52.0           py37hdc94413_0    conda-forge
numpy                     1.19.5           py37haa41c4c_0    conda-forge
nvtx                      0.2.1            py37h8f50634_2    conda-forge
olefile                   0.46               pyh9f0ad1d_1    conda-forge
openjdk                   11.0.8               hacce0ff_0    conda-forge
openjpeg                  2.3.1                hf7af979_3    conda-forge
openssl                   1.1.1i               h7f98852_0    conda-forge
orc                       1.6.5                hd3605a7_0    conda-forge
packaging                 20.8               pyhd3deb0d_0    conda-forge
pandas                    1.1.5            py37hdc94413_0    conda-forge
pandoc                    2.11.3.2             h7f98852_0    conda-forge
pandocfilters             1.4.2                      py_1    conda-forge
panel                     0.9.7                      py_0    conda-forge
param                     1.10.0                     py_0    conda-forge
parquet-cpp               1.5.1                         2    conda-forge
parso                     0.8.1              pyhd8ed1ab_0    conda-forge
partd                     1.1.0                      py_0    conda-forge
pcre                      8.44                 he1b5a44_0    conda-forge
pexpect                   4.8.0              pyh9f0ad1d_2    conda-forge
pickle5                   0.0.11           py37h8f50634_0    conda-forge
pickleshare               0.7.5                   py_1003    conda-forge
pillow                    8.1.0            py37he6b4880_0    conda-forge
pip                       20.3.3             pyhd8ed1ab_0    conda-forge
pixman                    0.38.0            h516909a_1003    conda-forge
poppler                   0.87.0               h4190859_1    conda-forge
poppler-data              0.4.10                        0    conda-forge
postgresql                12.3                 hc2f5b80_3    conda-forge
proj                      7.0.0                h966b41f_5    conda-forge
prometheus_client         0.9.0              pyhd3deb0d_0    conda-forge
prompt-toolkit            3.0.8              pyha770c72_0    conda-forge
protobuf                  3.13.0.1         py37h745909e_1    conda-forge
psutil                    5.8.0            py37h5e8e339_0    conda-forge
pthread-stubs             0.4               h36c2ea0_1001    conda-forge
ptyprocess                0.7.0              pyhd3deb0d_0    conda-forge
py-xgboost                1.3.0dev.rapidsai0.17  cuda11.0py37_0    rapidsai
pyarrow                   1.0.1           py37hbeecfa9_14_cuda    conda-forge
pycparser                 2.20               pyh9f0ad1d_2    conda-forge
pyct                      0.4.6                      py_0    conda-forge
pyct-core                 0.4.6                      py_0    conda-forge
pydeck                    0.5.0              pyh9f0ad1d_0    conda-forge
pyee                      7.0.4              pyh9f0ad1d_0    conda-forge
pygments                  2.7.3              pyhd8ed1ab_0    conda-forge
pyhive                    0.6.2              pyh9f0ad1d_0    conda-forge
pynvml                    8.0.4                      py_1    conda-forge
pyopenssl                 20.0.1             pyhd8ed1ab_0    conda-forge
pyparsing                 2.4.7              pyh9f0ad1d_0    conda-forge
pyppeteer                 0.2.2                      py_1    conda-forge
pyproj                    2.6.1.post1      py37h34dd122_0    conda-forge
pyrsistent                0.17.3           py37h4abf009_1    conda-forge
pysocks                   1.7.1            py37he5f6b98_2    conda-forge
python                    3.7.9           hffdb5ce_0_cpython    conda-forge
python-confluent-kafka    1.5.0            py37h8f50634_0    conda-forge
python-dateutil           2.8.1                      py_0    conda-forge
python_abi                3.7                     1_cp37m    conda-forge
pytz                      2020.5             pyhd8ed1ab_0    conda-forge
pyviz_comms               2.0.1              pyhd3deb0d_0    conda-forge
pyyaml                    5.3.1            py37hb5d75c8_1    conda-forge
pyzmq                     20.0.0           py37h5a562af_1    conda-forge
rapids                    0.17.0          cuda11.0_py37_g180f433_165    rapidsai
rapids-blazing            0.17.0          cuda11.0_py37_g180f433_165    rapidsai
rapids-xgboost            0.17.0          cuda11.0_py37_g180f433_165    rapidsai
re2                       2020.10.01           he1b5a44_0    conda-forge
readline                  8.0                  he28a2e2_2    conda-forge
requests                  2.25.1             pyhd3deb0d_0    conda-forge
rmm                       0.17.0          cuda_11.0_py37_gc4cc945_0    rapidsai
rtree                     0.9.7            py37h0b55af0_0    conda-forge
sasl                      0.2.1           py37h3340039_1002    conda-forge
scikit-learn              0.24.0           py37h69acf81_0    conda-forge
scipy                     1.6.0            py37h14a347d_0    conda-forge
send2trash                1.5.0                      py_0    conda-forge
setuptools                49.6.0           py37he5f6b98_2    conda-forge
shapely                   1.7.1            py37hedb1597_1    conda-forge
simpervisor               0.4                pyhd8ed1ab_0    conda-forge
six                       1.15.0             pyh9f0ad1d_0    conda-forge
snappy                    1.1.8                he1b5a44_3    conda-forge
sortedcontainers          2.3.0              pyhd8ed1ab_0    conda-forge
spdlog                    1.7.0                hc9558a2_2    conda-forge
sqlalchemy                1.3.22           py37h5e8e339_0    conda-forge
sqlite                    3.34.0               h74cdb3f_0    conda-forge
streamz                   0.6.1              pyhd3deb0d_1    conda-forge
tbb                       2020.2               h4bd325d_2    conda-forge
tblib                     1.6.0                      py_0    conda-forge
terminado                 0.9.2            py37h89c1867_0    conda-forge
testpath                  0.4.4                      py_0    conda-forge
threadpoolctl             2.1.0              pyh5ca1d4c_0    conda-forge
thrift                    0.13.0           py37h3340039_2    conda-forge
thrift_sasl               0.4.2            py37h8f50634_0    conda-forge
tiledb                    1.7.7                h8efa9f0_3    conda-forge
tk                        8.6.10               h21135ba_1    conda-forge
toolz                     0.11.1                     py_0    conda-forge
tornado                   6.1              py37h4abf009_0    conda-forge
tqdm                      4.55.1             pyhd8ed1ab_0    conda-forge
traitlets                 5.0.5                      py_0    conda-forge
treelite                  0.93             py37h745909e_3    conda-forge
treelite-runtime          0.93                     pypi_0    pypi
typing-extensions         3.7.4.3                       0    conda-forge
typing_extensions         3.7.4.3                    py_0    conda-forge
tzcode                    2020f                h7f98852_0    conda-forge
ucx                       1.8.1+g6b29558       cuda11.0_0    rapidsai
ucx-proc                  1.0.0                       gpu    rapidsai
ucx-py                    0.17.0          py37_g6b29558_0    rapidsai
urllib3                   1.26.2             pyhd8ed1ab_0    conda-forge
wcwidth                   0.2.5              pyh9f0ad1d_2    conda-forge
webencodings              0.5.1                      py_1    conda-forge
websockets                8.1              py37h8f50634_2    conda-forge
wheel                     0.36.2             pyhd3deb0d_0    conda-forge
widgetsnbextension        3.5.1            py37h89c1867_4    conda-forge
xarray                    0.16.2             pyhd8ed1ab_0    conda-forge
xerces-c                  3.2.2             h8412b87_1004    conda-forge
xgboost                   1.3.0dev.rapidsai0.17  cuda11.0py37_0    rapidsai
xorg-fixesproto           5.0               h14c3975_1002    conda-forge
xorg-inputproto           2.3.2             h14c3975_1002    conda-forge
xorg-kbproto              1.0.7             h14c3975_1002    conda-forge
xorg-libice               1.0.10               h516909a_0    conda-forge
xorg-libsm                1.2.3             h84519dc_1000    conda-forge
xorg-libx11               1.6.12               h516909a_0    conda-forge
xorg-libxau               1.0.9                h14c3975_0    conda-forge
xorg-libxdmcp             1.1.3                h516909a_0    conda-forge
xorg-libxext              1.3.4                h516909a_0    conda-forge
xorg-libxfixes            5.0.3             h516909a_1004    conda-forge
xorg-libxi                1.7.10               h516909a_0    conda-forge
xorg-libxrender           0.9.10            h516909a_1002    conda-forge
xorg-libxtst              1.2.3             h516909a_1002    conda-forge
xorg-recordproto          1.14.2            h516909a_1002    conda-forge
xorg-renderproto          0.11.1            h14c3975_1002    conda-forge
xorg-xextproto            7.3.0             h14c3975_1002    conda-forge
xorg-xproto               7.0.31            h7f98852_1007    conda-forge
xz                        5.2.5                h516909a_1    conda-forge
yaml                      0.2.5                h516909a_0    conda-forge
yarl                      1.6.3            py37h5e8e339_0    conda-forge
zeromq                    4.3.3                h58526e2_3    conda-forge
zict                      2.0.0                      py_0    conda-forge
zipp                      3.4.0                      py_0    conda-forge
zlib                      1.2.11            h516909a_1010    conda-forge
zstd                      1.4.8                hdf46e1d_0    conda-forge

hcho3 added the "? - Needs Triage" (Need team to review and classify) and "bug" (Something isn't working) labels on Jan 6, 2021
hcho3 added the "Cython / Python" (Cython or Python issue) label and removed the "? - Needs Triage" (Need team to review and classify) label on Jan 6, 2021
@github-actions

This issue has been marked stale due to no recent activity in the past 30d. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed. This issue will be marked rotten if there is no activity in the next 60d.

@hcho3
Contributor Author

hcho3 commented Mar 12, 2021

#3609 will fix this issue.

@jameslamb Would you like to try out the updated versions of predict() and predict_proba() in #3609? In general, we try to closely match the behavior of scikit-learn; any mismatch is considered a bug.

rapids-bot bot pushed a commit that referenced this issue Mar 15, 2021
Closes #3347.

Make the `predict()` and `predict_proba()` functions of RF match those in the scikit-learn RF.

* Eliminate the parameter `output_class`. Instead, `predict()` will always produce class predictions, and `predict_proba()` will always produce probability predictions. (This applies to binary and multi-class classifiers. Regressors will only have `predict()`.)
* Remove the `threshold` parameter from `predict_proba()`.

Authors:
  - Philip Hyunsu Cho (@hcho3)

Approvers:
  - John Zedlewski (@JohnZed)

URL: #3609
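
The target behavior described in this commit message can be illustrated with scikit-learn's own random forest, which is the reference the change aligns with. This is a minimal CPU-only sketch using scikit-learn and NumPy; it does not exercise cuML, and the dataset size and number of trees are arbitrary choices for illustration.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

X, y = make_classification(n_samples=200, n_features=20, n_informative=10,
                           n_classes=2, random_state=0)

reference = RandomForestClassifier(n_estimators=10, random_state=0).fit(X, y)
labels = reference.predict(X)         # class labels, shape (n_samples,)
proba = reference.predict_proba(X)    # probabilities, shape (n_samples, n_classes)

# predict() agrees with the most probable column of predict_proba().
labels_from_proba = reference.classes_[np.argmax(proba, axis=1)]
```

Note that even for a binary problem, `predict_proba()` returns two columns, one per class.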
@jameslamb
Member

> #3609 will fix this issue.
>
> @jameslamb Would you like to try out the updated versions of predict() and predict_proba() in #3609?

Sorry I wasn't able to get back to you before this was merged. I'd be happy to look this week and see if there are other differences with the scikit-learn API for classifiers.

> we try to closely match the behavior of scikit-learn; any mismatch is considered a bug.

Related to this, I looked tonight and saw that this project doesn't currently run the scikit-learn estimator checks (documented in "Rolling Your Own Estimator").

For an example of how those checks are used, you can see lightgbm's tests:

Would you be open to a PR that tries to add those checks for RandomForestClassifier?
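
As a rough idea of what such a test could look like, here is a minimal sketch using scikit-learn's `parametrize_with_checks` decorator, which requires pytest. It uses scikit-learn's own RandomForestClassifier as a stand-in; passing cuML's estimator instead is the assumption behind the proposal and is not shown here. The test function name is arbitrary.

```python
from sklearn.ensemble import RandomForestClassifier
from sklearn.utils.estimator_checks import parametrize_with_checks


# Stand-in estimator: the proposal would pass cuML's RandomForestClassifier
# here instead of scikit-learn's.
@parametrize_with_checks([RandomForestClassifier(n_estimators=5, random_state=0)])
def test_sklearn_compatible_estimator(estimator, check):
    # Each generated test case runs one scikit-learn API check.
    check(estimator)
```

Run under pytest, this expands into one test case per estimator check, so each API mismatch shows up as a separate failure.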

@hcho3
Contributor Author

hcho3 commented Mar 16, 2021

@jameslamb Yes, a PR would be great. Thanks!
