diff --git a/python-package/lightgbm/dask.py b/python-package/lightgbm/dask.py
index 4251af78f8bc..d7939a7717d7 100644
--- a/python-package/lightgbm/dask.py
+++ b/python-package/lightgbm/dask.py
@@ -280,18 +280,30 @@ def _train(client, data, label, params, model_factory, sample_weight=None, group
     return results[0]
 
 
-def _predict_part(part, model, proba, **kwargs):
+def _predict_part(part, model, raw_score, pred_proba, pred_leaf, pred_contrib, **kwargs):
     data = part.values if isinstance(part, pd.DataFrame) else part
     if data.shape[0] == 0:
         result = np.array([])
-    elif proba:
-        result = model.predict_proba(data, **kwargs)
+    elif pred_proba:
+        result = model.predict_proba(
+            data,
+            raw_score=raw_score,
+            pred_leaf=pred_leaf,
+            pred_contrib=pred_contrib,
+            **kwargs
+        )
     else:
-        result = model.predict(data, **kwargs)
+        result = model.predict(
+            data,
+            raw_score=raw_score,
+            pred_leaf=pred_leaf,
+            pred_contrib=pred_contrib,
+            **kwargs
+        )
 
     if isinstance(part, pd.DataFrame):
-        if proba:
+        if pred_proba or pred_contrib:
             result = pd.DataFrame(result, index=part.index)
         else:
             result = pd.Series(result, index=part.index, name='predictions')
 
@@ -299,7 +311,8 @@ def _predict_part(part, model, proba, **kwargs):
     return result
 
 
-def _predict(model, data, proba=False, dtype=np.float32, **kwargs):
+def _predict(model, data, raw_score=False, pred_proba=False, pred_leaf=False, pred_contrib=False,
+             dtype=np.float32, **kwargs):
     """Inner predict routine.
 
     Parameters
@@ -307,20 +320,42 @@ def _predict(model, data, proba=False, dtype=np.float32, **kwargs):
     model : lightgbm.LGBMClassifier, lightgbm.LGBMRegressor, or lightgbm.LGBMRanker class
     data : dask array of shape = [n_samples, n_features]
         Input feature matrix.
-    proba : bool
-        Should method return results of predict_proba (proba == True) or predict (proba == False).
+    raw_score : bool, optional (default=False)
+        Whether to predict raw scores.
+    pred_proba : bool, optional (default=False)
+        Should method return results of ``predict_proba`` (``pred_proba=True``) or ``predict`` (``pred_proba=False``).
+    pred_leaf : bool, optional (default=False)
+        Whether to predict leaf index.
+    pred_contrib : bool, optional (default=False)
+        Whether to predict feature contributions.
     dtype : np.dtype
         Dtype of the output.
-    kwargs : other parameters passed to predict or predict_proba method
+    kwargs : dict
+        Other parameters passed to ``predict`` or ``predict_proba`` method.
     """
     if isinstance(data, dd._Frame):
-        return data.map_partitions(_predict_part, model=model, proba=proba, **kwargs).values
+        return data.map_partitions(
+            _predict_part,
+            model=model,
+            raw_score=raw_score,
+            pred_proba=pred_proba,
+            pred_leaf=pred_leaf,
+            pred_contrib=pred_contrib,
+            **kwargs
+        ).values
     elif isinstance(data, da.Array):
-        if proba:
+        if pred_proba:
             kwargs['chunks'] = (data.chunks[0], (model.n_classes_,))
         else:
             kwargs['drop_axis'] = 1
-        return data.map_blocks(_predict_part, model=model, proba=proba, dtype=dtype, **kwargs)
+        return data.map_blocks(
+            _predict_part,
+            model=model,
+            raw_score=raw_score,
+            pred_proba=pred_proba,
+            pred_leaf=pred_leaf,
+            pred_contrib=pred_contrib,
+            dtype=dtype,
+            **kwargs
+        )
     else:
         raise TypeError('Data must be either Dask array or dataframe. Got %s.'
                         % str(type(data)))
@@ -370,7 +405,7 @@ def predict(self, X, **kwargs):
 
     def predict_proba(self, X, **kwargs):
         """Docstring is inherited from the lightgbm.LGBMClassifier.predict_proba."""
-        return _predict(self.to_local(), X, proba=True, **kwargs)
+        return _predict(self.to_local(), X, pred_proba=True, **kwargs)
     predict_proba.__doc__ = LGBMClassifier.predict_proba.__doc__
 
     def to_local(self):
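Note: the `dask.py` changes above only thread `raw_score`, `pred_leaf`, and `pred_contrib` through to the underlying local model's `predict`/`predict_proba`; the actual prediction logic stays in the sklearn wrappers. Below is a minimal usage sketch of the new `pred_contrib=True` path; the `LocalCluster` setup and synthetic data are illustrative assumptions, not part of this patch, and `dlgbm` is the import alias the tests use:

```python
import dask.array as da
from distributed import Client, LocalCluster

import lightgbm.dask as dlgbm  # alias used by the tests

client = Client(LocalCluster(n_workers=2))

# synthetic regression data, chunked so partitions land on both workers
X = da.random.random((100, 10), chunks=(50, 10))
y = da.random.random((100,), chunks=(50,))

model = dlgbm.DaskLGBMRegressor(n_estimators=10, num_leaves=10)
model = model.fit(X, y, client=client)

# kwargs flow through predict() -> _predict() -> _predict_part(), so each
# row gets one contribution per feature plus a trailing base value column
contribs = model.predict(X, pred_contrib=True).compute()
assert contribs.shape == (100, 10 + 1)
```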
diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py
index 960e1a56da63..f1564955c073 100644
--- a/tests/python_package_test/test_dask.py
+++ b/tests/python_package_test/test_dask.py
@@ -235,6 +235,55 @@ def test_classifier(output, centers, client, listen_port):
     client.close()
 
 
+@pytest.mark.parametrize('output', data_output)
+@pytest.mark.parametrize('centers', data_centers)
+def test_classifier_pred_contrib(output, centers, client, listen_port):
+    X, y, w, dX, dy, dw = _create_data('classification', output=output, centers=centers)
+
+    dask_classifier = dlgbm.DaskLGBMClassifier(
+        time_out=5,
+        local_listen_port=listen_port,
+        tree_learner='data',
+        n_estimators=10,
+        num_leaves=10
+    )
+    dask_classifier = dask_classifier.fit(dX, dy, sample_weight=dw, client=client)
+    preds_with_contrib = dask_classifier.predict(dX, pred_contrib=True).compute()
+
+    local_classifier = lightgbm.LGBMClassifier(
+        n_estimators=10,
+        num_leaves=10
+    )
+    local_classifier.fit(X, y, sample_weight=w)
+    local_preds_with_contrib = local_classifier.predict(X, pred_contrib=True)
+
+    if output == 'scipy_csr_matrix':
+        preds_with_contrib = np.array(preds_with_contrib.todense())
+
+    # shape depends on whether it is binary or multiclass classification
+    num_features = dask_classifier.n_features_
+    num_classes = dask_classifier.n_classes_
+    if num_classes == 2:
+        expected_num_cols = num_features + 1
+    else:
+        expected_num_cols = (num_features + 1) * num_classes
+
+    # * shape depends on whether it is binary or multiclass classification
+    # * matrix for binary classification is of the form [feature_contrib, base_value],
+    #   for multi-class it's [feat_contrib_class1, base_value_class1, feat_contrib_class2, base_value_class2, etc.]
+    # * contrib outputs for distributed training are different from those of local training, so we can just
+    #   test that the output has the right shape and that base values are in the right position
+    assert preds_with_contrib.shape[1] == expected_num_cols
+    assert preds_with_contrib.shape == local_preds_with_contrib.shape
+
+    if num_classes == 2:
+        assert len(np.unique(preds_with_contrib[:, num_features])) == 1
+    else:
+        for i in range(num_classes):
+            base_value_col = num_features * (i + 1) + i
+            assert len(np.unique(preds_with_contrib[:, base_value_col])) == 1
+
+
 def test_training_does_not_fail_on_port_conflicts(client):
     _, _, _, dX, dy, dw = _create_data('classification', output='array')
 
@@ -315,6 +364,37 @@ def test_regressor(output, client, listen_port):
     client.close()
 
 
+@pytest.mark.parametrize('output', data_output)
+def test_regressor_pred_contrib(output, client, listen_port):
+    X, y, w, dX, dy, dw = _create_data('regression', output=output)
+
+    dask_regressor = dlgbm.DaskLGBMRegressor(
+        time_out=5,
+        local_listen_port=listen_port,
+        tree_learner='data',
+        n_estimators=10,
+        num_leaves=10
+    )
+    dask_regressor = dask_regressor.fit(dX, dy, sample_weight=dw, client=client)
+    preds_with_contrib = dask_regressor.predict(dX, pred_contrib=True).compute()
+
+    local_regressor = lightgbm.LGBMRegressor(
+        n_estimators=10,
+        num_leaves=10
+    )
+    local_regressor.fit(X, y, sample_weight=w)
+    local_preds_with_contrib = local_regressor.predict(X, pred_contrib=True)
+
+    if output == 'scipy_csr_matrix':
+        preds_with_contrib = np.array(preds_with_contrib.todense())
+
+    # contrib outputs for distributed training are different from those of local training, so we can
+    # just test that the output has the right shape: one column per feature plus a base value column
+    num_features = dX.shape[1]
+    assert preds_with_contrib.shape[1] == num_features + 1
+    assert preds_with_contrib.shape == local_preds_with_contrib.shape
+
+
 @pytest.mark.parametrize('output', data_output)
 @pytest.mark.parametrize('alpha', [.1, .5, .9])
 def test_regressor_quantile(output, client, listen_port, alpha):
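For reference, the multiclass assertions in `test_classifier_pred_contrib` depend on the contribution matrix being laid out class by class: `n_features` contribution columns followed by one base-value column per class. Here is a small NumPy sketch of that indexing; the shapes and values are made up purely for illustration:

```python
import numpy as np

n_rows, n_features, n_classes = 5, 4, 3

# full matrix: (n_features + 1) columns per class
rng = np.random.default_rng(0)
contribs = rng.normal(size=(n_rows, (n_features + 1) * n_classes))

for i in range(n_classes):
    base_value_col = n_features * (i + 1) + i  # same indexing as the test
    contribs[:, base_value_col] = -1.2         # base value repeats on every row
    # a constant column has exactly one unique value, which is what the test asserts
    assert len(np.unique(contribs[:, base_value_col])) == 1
```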