Skip to content

Commit

Permalink
Update predict() / predict_proba() of RF to match sklearn (#3609)
Browse files Browse the repository at this point in the history
Closes #3347.

Make the `predict()` and `predict_proba()` functions of RF match those in the scikit-learn RF.

* Eliminate the parameter `output_class`. Instead, `predict()` will always produce class predictions, and `predict_proba()` will always produce probability predictions. (This applies to binary and multi-class classifiers. Regressors will only have `predict()`.)
* Remove the `threshold` parameter from `predict_proba()`.

Authors:
  - Philip Hyunsu Cho (@hcho3)

Approvers:
  - John Zedlewski (@JohnZed)

URL: #3609
  • Loading branch information
hcho3 authored Mar 15, 2021
1 parent 5aa3c72 commit f5d86b9
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 53 deletions.
18 changes: 2 additions & 16 deletions python/cuml/dask/ensemble/randomforestclassifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def fit(self, X, y, convert_dtype=False):
convert_dtype=convert_dtype)
return self

def predict(self, X, output_class=True, algo='auto', threshold=0.5,
def predict(self, X, algo='auto', threshold=0.5,
convert_dtype=True, predict_model="GPU",
fil_sparse_format='auto', delayed=True):
"""
Expand Down Expand Up @@ -304,12 +304,6 @@ def predict(self, X, output_class=True, algo='auto', threshold=0.5,
X : Dask cuDF dataframe or CuPy backed Dask Array (n_rows, n_features)
Distributed dense matrix (floats or doubles) of shape
(n_samples, n_features).
output_class : boolean (default = True)
This is optional and required only while performing the
predict operation on the GPU.
If true, return a 1 or 0 depending on whether the raw
prediction exceeds the threshold. If False, just return
the raw prediction.
algo : string (default = 'auto')
This is optional and required only while performing the
predict operation on the GPU.
Expand All @@ -325,7 +319,6 @@ def predict(self, X, output_class=True, algo='auto', threshold=0.5,
Threshold used for classification. Optional and required only
while performing the predict operation on the GPU, that is for,
predict_model='GPU'.
It is applied if output_class == True, else it is ignored
convert_dtype : bool, optional (default = True)
When set to True, the predict method will, when necessary, convert
the input to the data type which was used to train the model. This
Expand Down Expand Up @@ -358,7 +351,7 @@ def predict(self, X, output_class=True, algo='auto', threshold=0.5,
convert_dtype=convert_dtype)
else:
preds = \
self._predict_using_fil(X, output_class=output_class,
self._predict_using_fil(X,
algo=algo,
threshold=threshold,
convert_dtype=convert_dtype,
Expand Down Expand Up @@ -456,12 +449,6 @@ def predict_proba(self, X,
be used if the model was trained on float32 data and `X` is float32
or convert_dtype is set to True. Also the 'GPU' should only be
used for binary classification problems.
output_class : boolean (default = True)
This is optional and required only while performing the
predict operation on the GPU.
If true, return a 1 or 0 depending on whether the raw
prediction exceeds the threshold. If False, just return
the raw prediction.
algo : string (default = 'auto')
This is optional and required only while performing the
predict operation on the GPU.
Expand All @@ -476,7 +463,6 @@ def predict_proba(self, X,
threshold : float (default = 0.5)
Threshold used for classification. Optional and required only
while performing the predict operation on the GPU.
It is applied if output_class == True, else it is ignored
convert_dtype : bool, optional (default = True)
When set to True, the predict method will, when necessary, convert
the input to the data type which was used to train the model. This
Expand Down
28 changes: 4 additions & 24 deletions python/cuml/ensemble/randomforestclassifier.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -567,8 +567,7 @@ class RandomForestClassifier(BaseRandomForestModel,

@insert_into_docstring(parameters=[('dense', '(n_samples, n_features)')],
return_values=[('dense', '(n_samples, 1)')])
def predict(self, X, predict_model="GPU",
output_class=True, threshold=0.5,
def predict(self, X, predict_model="GPU", threshold=0.5,
algo='auto', num_classes=None,
convert_dtype=True,
fil_sparse_format='auto') -> CumlArray:
Expand All @@ -583,12 +582,6 @@ class RandomForestClassifier(BaseRandomForestModel,
be used if the model was trained on float32 data and `X` is float32
or convert_dtype is set to True. Also the 'GPU' should only be
used for binary classification problems.
output_class : boolean (default = True)
This is optional and required only while performing the
predict operation on the GPU.
If true, return a 1 or 0 depending on whether the raw
prediction exceeds the threshold. If False, just return
the raw prediction.
algo : string (default = 'auto')
This is optional and required only while performing the
predict operation on the GPU.
Expand All @@ -603,7 +596,6 @@ class RandomForestClassifier(BaseRandomForestModel,
threshold : float (default = 0.5)
Threshold used for classification. Optional and required only
while performing the predict operation on the GPU.
It is applied if output_class == True, else it is ignored
num_classes : int (default = None)
number of different classes present in the dataset.

Expand Down Expand Up @@ -651,7 +643,7 @@ class RandomForestClassifier(BaseRandomForestModel,

else:
preds = \
self._predict_model_on_gpu(X=X, output_class=output_class,
self._predict_model_on_gpu(X=X, output_class=True,
threshold=threshold,
algo=algo,
convert_dtype=convert_dtype,
Expand Down Expand Up @@ -722,8 +714,7 @@ class RandomForestClassifier(BaseRandomForestModel,

@insert_into_docstring(parameters=[('dense', '(n_samples, n_features)')],
return_values=[('dense', '(n_samples, 1)')])
def predict_proba(self, X, output_class=True,
threshold=0.5, algo='auto',
def predict_proba(self, X, algo='auto',
num_classes=None, convert_dtype=True,
fil_sparse_format='auto') -> CumlArray:
"""
Expand All @@ -734,12 +725,6 @@ class RandomForestClassifier(BaseRandomForestModel,
Parameters
----------
X : {}
output_class: boolean (default = True)
This is optional and required only while performing the
predict operation on the GPU.
If true, return a 1 or 0 depending on whether the raw
prediction exceeds the threshold. If False, just return
the raw prediction.
algo : string (default = 'auto')
This is optional and required only while performing the
predict operation on the GPU.
Expand All @@ -751,10 +736,6 @@ class RandomForestClassifier(BaseRandomForestModel,
`auto` - choose the algorithm automatically. Currently
'batch_tree_reorg' is used for dense storage
and 'naive' for sparse storage
threshold : float (default = 0.5)
Threshold used for classification. Optional and required only
while performing the predict operation on the GPU.
It is applied if output_class == True, else it is ignored
num_classes : int (default = None)
number of different classes present in the dataset.

Expand Down Expand Up @@ -799,8 +780,7 @@ class RandomForestClassifier(BaseRandomForestModel,
"training dataset.")

preds_proba = \
self._predict_model_on_gpu(X, output_class=output_class,
threshold=threshold,
self._predict_model_on_gpu(X, output_class=True,
algo=algo,
convert_dtype=convert_dtype,
fil_sparse_format=fil_sparse_format,
Expand Down
13 changes: 6 additions & 7 deletions python/cuml/test/dask/test_random_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,7 @@ def test_rf_regression_dask_fil(partitions_per_worker,


@pytest.mark.parametrize('partitions_per_worker', [5])
@pytest.mark.parametrize('output_class', [True, False])
def test_rf_classification_dask_array(partitions_per_worker, client,
output_class):
def test_rf_classification_dask_array(partitions_per_worker, client):
n_workers = len(client.scheduler_info()['workers'])

X, y = make_classification(n_samples=n_workers * 2000, n_features=30,
Expand All @@ -199,10 +197,7 @@ def test_rf_classification_dask_array(partitions_per_worker, client,
X_test_dask_array = from_array(X_test)
cuml_mod = cuRFC_mg(**cu_rf_params)
cuml_mod.fit(X_train_df, y_train_df)
cuml_mod_predict = cuml_mod.predict(X_test_dask_array,
output_class).compute()
if not output_class:
cuml_mod_predict = np.round(cuml_mod_predict)
cuml_mod_predict = cuml_mod.predict(X_test_dask_array).compute()

acc_score = accuracy_score(cuml_mod_predict, y_test, normalize=True)

Expand Down Expand Up @@ -279,8 +274,12 @@ def test_rf_classification_dask_fil_predict_proba(partitions_per_worker,
cu_rf_mg = cuRFC_mg(**cu_rf_params)
cu_rf_mg.fit(X_train_df, y_train_df)

fil_preds = cu_rf_mg.predict(X_test_df).compute()
fil_preds = fil_preds.to_array()
fil_preds_proba = cu_rf_mg.predict_proba(X_test_df).compute()
fil_preds_proba = cp.asnumpy(fil_preds_proba.as_gpu_matrix())
np.testing.assert_equal(fil_preds, np.argmax(fil_preds_proba, axis=1))

y_proba = np.zeros(np.shape(fil_preds_proba))
y_proba[:, 1] = y_test
y_proba[:, 0] = 1.0 - y_test
Expand Down
10 changes: 4 additions & 6 deletions python/cuml/test/test_random_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,6 @@ def test_rf_classification(small_clf, datatype, split_algo,
assert captured_stdout == ''
fil_preds = cuml_model.predict(X_test,
predict_model="GPU",
output_class=True,
threshold=0.5,
algo='auto')
cu_preds = cuml_model.predict(X_test, predict_model="CPU")
Expand Down Expand Up @@ -428,14 +427,15 @@ def rf_classification(datatype, array_type, max_features, max_samples,
.as_gpu_matrix())
cu_preds_cpu = cuml_model.predict(X_test_df,
predict_model="CPU").to_array()
cu_preds_gpu = cuml_model.predict(X_test_df, output_class=True,
cu_preds_gpu = cuml_model.predict(X_test_df,
predict_model="GPU").to_array()
else:
cuml_model.fit(X_train, y_train)
cu_proba_gpu = cuml_model.predict_proba(X_test)
cu_preds_cpu = cuml_model.predict(X_test, predict_model="CPU")
cu_preds_gpu = cuml_model.predict(X_test, predict_model="GPU",
output_class=True)
cu_preds_gpu = cuml_model.predict(X_test, predict_model="GPU")
np.testing.assert_array_equal(cu_preds_gpu,
np.argmax(cu_proba_gpu, axis=1))

cu_acc_cpu = accuracy_score(y_test, cu_preds_cpu)
cu_acc_gpu = accuracy_score(y_test, cu_preds_gpu)
Expand Down Expand Up @@ -507,14 +507,12 @@ def test_rf_classification_sparse(small_clf, datatype,
with pytest.raises(ValueError):
fil_preds = cuml_model.predict(X_test,
predict_model="GPU",
output_class=True,
threshold=0.5,
fil_sparse_format=fil_sparse_format,
algo=algo)
else:
fil_preds = cuml_model.predict(X_test,
predict_model="GPU",
output_class=True,
threshold=0.5,
fil_sparse_format=fil_sparse_format,
algo=algo)
Expand Down

0 comments on commit f5d86b9

Please sign in to comment.