diff --git a/doc/tutorials/dask.rst b/doc/tutorials/dask.rst index 4254d698007b..d98a19aed87d 100644 --- a/doc/tutorials/dask.rst +++ b/doc/tutorials/dask.rst @@ -108,8 +108,9 @@ computation a bit faster when meta information like ``base_margin`` is not neede prediction = xgb.dask.inplace_predict(client, output, X) Here ``prediction`` is a dask ``Array`` object containing predictions from model if input -is a ``DaskDMatrix`` or ``da.Array``. For ``dd.DataFrame``, the return value is a -``dd.Series``. +is a ``DaskDMatrix`` or ``da.Array``. When putting dask collection directly into the +``predict`` function or using ``inplace_predict``, the output type depends on input data. +See next section for details. Alternatively, XGBoost also implements the Scikit-Learn interface with ``DaskXGBClassifier`` and ``DaskXGBRegressor``. See ``xgboost/demo/dask`` for more examples. @@ -143,9 +144,23 @@ Also for inplace prediction: .. code-block:: python booster.set_param({'predictor': 'gpu_predictor'}) - # where X is a dask DataFrame or dask Array. + # where X is a dask DataFrame or dask Array containing cupy or cuDF backed data. prediction = xgb.dask.inplace_predict(client, booster, X) +When input is ``da.Array`` object, output is always ``da.Array``. However, if the input +type is ``dd.DataFrame``, output can be ``dd.Series``, ``dd.DataFrame`` or ``da.Array``, +depending on output shape. For example, when shap based prediction is used, the return +value can have 3 or 4 dimensions , in such cases an ``Array`` is always returned. + +The performance of running prediction, either using ``predict`` or ``inplace_predict``, is +sensitive to number of blocks. Internally, it's implemented using ``da.map_blocks`` or +``dd.map_partitions``. When number of partitions is large and each of them have only +small amount of data, the overhead of calling predict becomes visible. On the other hand, +if not using GPU, the number of threads used for prediction on each block matters. Right +now, xgboost uses single thread for each partition. If the number of blocks on each +workers is smaller than number of cores, then the CPU workers might not be fully utilized. + + *************************** Working with other clusters diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py index 1927ba54ad9a..5a9c334e84bc 100644 --- a/python-package/xgboost/dask.py +++ b/python-package/xgboost/dask.py @@ -112,14 +112,15 @@ def _start_tracker(n_workers: int) -> Dict[str, Any]: def _assert_dask_support() -> None: try: - import dask # pylint: disable=W0621,W0611 + import dask # pylint: disable=W0621,W0611 except ImportError as e: raise ImportError( - 'Dask needs to be installed in order to use this module') from e + "Dask needs to be installed in order to use this module" + ) from e - if platform.system() == 'Windows': - msg = 'Windows is not officially supported for dask/xgboost,' - msg += ' contribution are welcomed.' + if platform.system() == "Windows": + msg = "Windows is not officially supported for dask/xgboost," + msg += " contribution are welcomed." LOGGER.warning(msg) @@ -252,6 +253,7 @@ def __init__( if not isinstance(label, (dd.DataFrame, da.Array, dd.Series, type(None))): raise TypeError(_expect((dd.DataFrame, da.Array, dd.Series), type(label))) + self._n_cols = data.shape[1] self.worker_map: Dict[str, "distributed.Future"] = defaultdict(list) self.is_quantile: bool = False @@ -403,6 +405,9 @@ def create_fn_args(self, worker_addr: str) -> Dict[str, Any]: 'parts': self.worker_map.get(worker_addr, None), 'is_quantile': self.is_quantile} + def num_col(self) -> int: + return self._n_cols + _DataParts = List[Tuple[Any, Optional[Any], Optional[Any], Optional[Any], Optional[Any], Optional[Any], Optional[Any]]] @@ -930,27 +935,90 @@ def train( callbacks=callbacks) +def _can_output_df(data: _DaskCollection, output_shape: Tuple) -> bool: + return isinstance(data, dd.DataFrame) and len(output_shape) <= 2 + + async def _direct_predict_impl( client: "distributed.Client", + mapped_predict: Callable, + booster: Booster, data: _DaskCollection, - predict_fn: Callable + base_margin: Optional[_DaskCollection], + output_shape: Tuple[int, ...], + meta: Dict[int, str], ) -> _DaskCollection: - if isinstance(data, da.Array): - predictions = await client.submit( - da.map_blocks, - predict_fn, data, False, drop_axis=1, - dtype=numpy.float32 - ).result() - return predictions - if isinstance(data, dd.DataFrame): - predictions = await client.submit( - dd.map_partitions, - predict_fn, data, True, - meta=dd.utils.make_meta({'prediction': 'f4'}) - ).result() - return predictions.iloc[:, 0] - raise TypeError('data of type: ' + str(type(data)) + - ' is not supported by direct prediction') + columns = list(meta.keys()) + booster_f = await client.scatter(data=booster, broadcast=True) + if _can_output_df(data, output_shape): + if base_margin is not None and isinstance(base_margin, da.Array): + base_margin_df: Optional[dd.DataFrame] = base_margin.to_dask_dataframe() + else: + base_margin_df = base_margin + predictions = dd.map_partitions( + mapped_predict, + booster_f, + data, + True, + columns, + base_margin_df, + meta=dd.utils.make_meta(meta), + ) + # classification can return a dataframe, drop 1 dim when it's reg/binary + if len(output_shape) == 1: + predictions = predictions.iloc[:, 0] + else: + if base_margin is not None and isinstance( + base_margin, (dd.Series, dd.DataFrame) + ): + base_margin_array: Optional[da.Array] = base_margin.to_dask_array() + else: + base_margin_array = base_margin + # Input data is 2-dim array, output can be 1(reg, binary)/2(multi-class, + # contrib)/3(contrib)/4(interaction) dims. + if len(output_shape) == 1: + drop_axis: Union[int, List[int]] = [1] # drop from 2 to 1 dim. + new_axis: Union[int, List[int]] = [] + else: + drop_axis = [] + new_axis = [i + 2 for i in range(len(output_shape) - 2)] + predictions = da.map_blocks( + mapped_predict, + booster_f, + data, + False, + columns, + base_margin_array, + drop_axis=drop_axis, + new_axis=new_axis, + dtype=numpy.float32, + ) + return predictions + + +def _infer_predict_output( + booster: Booster, data: _DaskCollection, inplace: bool, **kwargs: Any +) -> Tuple[Tuple[int, ...], Dict[int, str]]: + """Create a dummy test sample to infer output shape for prediction.""" + if isinstance(data, DaskDMatrix): + features = data.num_col() + else: + features = data.shape[1] + rng = numpy.random.RandomState(1994) + test_sample = rng.randn(1, features) + if inplace: + # clear the state to avoid gpu_id, gpu_predictor + booster = Booster(model_file=booster.save_raw()) + test_predt = booster.inplace_predict(test_sample, **kwargs) + else: + m = DMatrix(test_sample) + test_predt = booster.predict(m, **kwargs) + n_columns = test_predt.shape[1] if len(test_predt.shape) > 1 else 1 + meta: Dict[int, str] = {} + if _can_output_df(data, test_predt.shape): + for i in range(n_columns): + meta[i] = "f4" + return test_predt.shape, meta # pylint: disable=too-many-statements @@ -968,19 +1036,19 @@ async def _predict_async( validate_features: bool, ) -> _DaskCollection: if isinstance(model, Booster): - booster = model + _booster = model elif isinstance(model, dict): - booster = model["booster"] + _booster = model["booster"] else: raise TypeError(_expect([Booster, dict], type(model))) if not isinstance(data, (DaskDMatrix, da.Array, dd.DataFrame)): raise TypeError(_expect([DaskDMatrix, da.Array, dd.DataFrame], type(data))) - def mapped_predict(partition: Any, is_df: bool) -> Any: - worker = distributed.get_worker() + def mapped_predict( + booster: Booster, partition: Any, is_df: bool, columns: List[int], _: Any + ) -> Any: with config.config_context(**global_config): - booster.set_param({"nthread": worker.nthreads}) - m = DMatrix(data=partition, missing=missing, nthread=worker.nthreads) + m = DMatrix(data=partition, missing=missing) predt = booster.predict( data=m, output_margin=output_margin, @@ -990,167 +1058,115 @@ def mapped_predict(partition: Any, is_df: bool) -> Any: pred_interactions=pred_interactions, validate_features=validate_features, ) - if is_df: + if is_df and len(predt.shape) <= 2: if lazy_isinstance(partition, "cudf", "core.dataframe.DataFrame"): import cudf - predt = cudf.DataFrame(predt, columns=["prediction"]) + + predt = cudf.DataFrame(predt, columns=columns) else: - predt = DataFrame(predt, columns=["prediction"]) + predt = DataFrame(predt, columns=columns) return predt # Predict on dask collection directly. if isinstance(data, (da.Array, dd.DataFrame)): - return await _direct_predict_impl(client, data, mapped_predict) - + _output_shape, meta = _infer_predict_output( + _booster, + data, + inplace=False, + output_margin=output_margin, + pred_leaf=pred_leaf, + pred_contribs=pred_contribs, + approx_contribs=approx_contribs, + pred_interactions=pred_interactions, + validate_features=False, + ) + return await _direct_predict_impl( + client, mapped_predict, _booster, data, None, _output_shape, meta + ) + output_shape, _ = _infer_predict_output( + booster=_booster, + data=data, + inplace=False, + output_margin=output_margin, + pred_leaf=pred_leaf, + pred_contribs=pred_contribs, + approx_contribs=approx_contribs, + pred_interactions=pred_interactions, + validate_features=False, + ) # Prediction on dask DMatrix. - worker_map = data.worker_map partition_order = data.partition_order feature_names = data.feature_names feature_types = data.feature_types missing = data.missing meta_names = data.meta_names - def dispatched_predict( - worker_id: int, list_of_orders: List[int], list_of_parts: _DataParts - ) -> List[Tuple[List[Union["dask.delayed.Delayed", int]], int]]: - """Perform prediction on each worker.""" - LOGGER.debug("Predicting on %d", worker_id) + def dispatched_predict(booster: Booster, part: Any) -> numpy.ndarray: + data = part[0] + assert isinstance(part, tuple), type(part) + base_margin = None + for i, blob in enumerate(part[1:]): + if meta_names[i] == "base_margin": + base_margin = blob + worker = distributed.get_worker() with config.config_context(**global_config): - worker = distributed.get_worker() - list_of_parts = _get_worker_parts_ordered(meta_names, list_of_parts) - predictions = [] - - booster.set_param({"nthread": worker.nthreads}) - for i, parts in enumerate(list_of_parts): - (data, _, _, base_margin, _, _, _) = parts - order = list_of_orders[i] - local_part = DMatrix( - data, - base_margin=base_margin, - feature_names=feature_names, - feature_types=feature_types, - missing=missing, - nthread=worker.nthreads, - ) - predt = booster.predict( - data=local_part, - output_margin=output_margin, - pred_leaf=pred_leaf, - pred_contribs=pred_contribs, - approx_contribs=approx_contribs, - pred_interactions=pred_interactions, - validate_features=validate_features, - ) - if pred_contribs and predt.size != local_part.num_row(): - assert len(predt.shape) in (2, 3) - if len(predt.shape) == 2: - groups = 1 - columns = predt.shape[1] - else: - groups = predt.shape[1] - columns = predt.shape[2] - # pylint: disable=no-member - ret = ( - [dask.delayed(predt), groups, columns], - order, - ) - elif pred_interactions and predt.size != local_part.num_row(): - assert len(predt.shape) in (3, 4) - if len(predt.shape) == 3: - groups = 1 - columns = predt.shape[1] - else: - groups = predt.shape[1] - columns = predt.shape[2] - # pylint: disable=no-member - ret = ( - [dask.delayed(predt), groups, columns], - order, - ) - else: - assert len(predt.shape) == 1 or len(predt.shape) == 2 - columns = 1 if len(predt.shape) == 1 else predt.shape[1] - # pylint: disable=no-member - ret = ( - [dask.delayed(predt), columns], - order, - ) - predictions.append(ret) - - return predictions - - def dispatched_get_shape( - worker_id: int, list_of_orders: List[int], list_of_parts: _DataParts - ) -> List[Tuple[int, int]]: - """Get shape of data in each worker.""" - LOGGER.debug("Get shape on %d", worker_id) - list_of_parts = _get_worker_parts_ordered(meta_names, list_of_parts) - shapes = [] - for i, parts in enumerate(list_of_parts): - (data, _, _, _, _, _, _) = parts - shapes.append((data.shape, list_of_orders[i])) - return shapes - - async def map_function( - func: Callable[[int, List[int], _DataParts], Any] - ) -> List[Any]: - """Run function for each part of the data.""" - futures = [] - workers_address = list(worker_map.keys()) - for wid, worker_addr in enumerate(workers_address): - worker_addr = workers_address[wid] - list_of_parts = worker_map[worker_addr] - list_of_orders = [partition_order[part.key] for part in list_of_parts] - - f = client.submit( - func, - worker_id=wid, - list_of_orders=list_of_orders, - list_of_parts=list_of_parts, - pure=True, - workers=[worker_addr], + m = DMatrix( + data, + nthread=worker.nthreads, + missing=missing, + base_margin=base_margin, + feature_names=feature_names, + feature_types=feature_types, ) - assert isinstance(f, distributed.client.Future) - futures.append(f) - # Get delayed objects - results = await client.gather(futures) - # flatten into 1 dim list - results = [t for list_per_worker in results for t in list_per_worker] - # sort by order, l[0] is the delayed object, l[1] is its order - results = sorted(results, key=lambda l: l[1]) - results = [predt for predt, order in results] # remove order - return results - - results = await map_function(dispatched_predict) - shapes = await map_function(dispatched_get_shape) + predt = booster.predict( + m, + output_margin=output_margin, + pred_leaf=pred_leaf, + pred_contribs=pred_contribs, + approx_contribs=approx_contribs, + pred_interactions=pred_interactions, + validate_features=validate_features, + ) + return predt + + all_parts = [] + all_orders = [] + all_shapes = [] + workers_address = list(data.worker_map.keys()) + for worker_addr in workers_address: + list_of_parts = data.worker_map[worker_addr] + all_parts.extend(list_of_parts) + all_orders.extend([partition_order[part.key] for part in list_of_parts]) + for part in all_parts: + s = client.submit(lambda part: part[0].shape[0], part) + all_shapes.append(s) + all_shapes = await client.gather(all_shapes) + + parts_with_order = list(zip(all_parts, all_shapes, all_orders)) + parts_with_order = sorted(parts_with_order, key=lambda p: p[2]) + all_parts = [part for part, shape, order in parts_with_order] + all_shapes = [shape for part, shape, order in parts_with_order] + + futures = [] + booster_f = await client.scatter(data=_booster, broadcast=True) + for part in all_parts: + f = client.submit(dispatched_predict, booster_f, part) + futures.append(f) # Constructing a dask array from list of numpy arrays # See https://docs.dask.org/en/latest/array-creation.html arrays = [] - for i, shape in enumerate(shapes): - if pred_contribs: - out_shape = ( - (shape[0], results[i][2]) - if results[i][1] == 1 - else (shape[0], results[i][1], results[i][2]) - ) - elif pred_interactions: - out_shape = ( - (shape[0], results[i][2], results[i][2]) - if results[i][1] == 1 - else (shape[0], results[i][1], results[i][2]) - ) - else: - out_shape = (shape[0],) if results[i][1] == 1 else (shape[0], results[i][1]) + for i, rows in enumerate(all_shapes): arrays.append( - da.from_delayed(results[i][0], shape=out_shape, dtype=numpy.float32) + da.from_delayed( + futures[i], shape=(rows,) + output_shape[1:], dtype=numpy.float32 + ) ) - predictions = await da.concatenate(arrays, axis=0) return predictions -def predict( +def predict( # pylint: disable=unused-argument client: "distributed.Client", model: Union[TrainReturnT, Booster], data: Union[DaskDMatrix, _DaskCollection], @@ -1190,22 +1206,15 @@ def predict( ------- prediction: dask.array.Array/dask.dataframe.Series When input data is ``dask.array.Array`` or ``DaskDMatrix``, the return value is an - array, when input data is ``dask.dataframe.DataFrame``, return value is - ``dask.dataframe.Series`` + array, when input data is ``dask.dataframe.DataFrame``, return value can be + ``dask.dataframe.Series``, ``dask.dataframe.DataFrame`` or ``dask.array.Array``, + depending on the output shape. ''' _assert_dask_support() client = _xgb_get_client(client) - global_config = config.get_config() return client.sync( - _predict_async, client, global_config, model, data, - output_margin=output_margin, - missing=missing, - pred_leaf=pred_leaf, - pred_contribs=pred_contribs, - approx_contribs=approx_contribs, - pred_interactions=pred_interactions, - validate_features=validate_features + _predict_async, global_config=config.get_config(), **locals() ) @@ -1228,30 +1237,38 @@ async def _inplace_predict_async( if not isinstance(data, (da.Array, dd.DataFrame)): raise TypeError(_expect([da.Array, dd.DataFrame], type(data))) - def mapped_predict(data: Any, is_df: bool) -> Any: - worker = distributed.get_worker() - config.set_config(**global_config) - booster.set_param({'nthread': worker.nthreads}) - prediction = booster.inplace_predict( - data, - iteration_range=iteration_range, - predict_type=predict_type, - missing=missing) - if is_df: + def mapped_predict( + booster: Booster, data: Any, is_df: bool, columns: List[int], _: Any + ) -> Any: + with config.config_context(**global_config): + prediction = booster.inplace_predict( + data, + iteration_range=iteration_range, + predict_type=predict_type, + missing=missing + ) + if is_df and len(prediction.shape) <= 2: if lazy_isinstance(data, 'cudf.core.dataframe', 'DataFrame'): import cudf - prediction = cudf.DataFrame({'prediction': prediction}, - dtype=numpy.float32) + prediction = cudf.DataFrame( + prediction, columns=columns, dtype=numpy.float32 + ) else: # If it's from pandas, the partition is a numpy array - prediction = DataFrame(prediction, columns=['prediction'], - dtype=numpy.float32) + prediction = DataFrame( + prediction, columns=columns, dtype=numpy.float32 + ) return prediction - return await _direct_predict_impl(client, data, mapped_predict) + shape, meta = _infer_predict_output( + booster, data, True, predict_type=predict_type, iteration_range=iteration_range + ) + return await _direct_predict_impl( + client, mapped_predict, booster, data, None, shape, meta + ) -def inplace_predict( +def inplace_predict( # pylint: disable=unused-argument client: "distributed.Client", model: Union[TrainReturnT, Booster], data: _DaskCollection, @@ -1281,16 +1298,17 @@ def inplace_predict( Returns ------- - prediction + prediction : + When input data is ``dask.array.Array`` or ``DaskDMatrix``, the return value is an + array, when input data is ``dask.dataframe.DataFrame``, return value can be + ``dask.dataframe.Series``, ``dask.dataframe.DataFrame`` or ``dask.array.Array``, + depending on the output shape. ''' _assert_dask_support() client = _xgb_get_client(client) - global_config = config.get_config() - return client.sync(_inplace_predict_async, client, global_config, model=model, - data=data, - iteration_range=iteration_range, - predict_type=predict_type, - missing=missing) + return client.sync( + _inplace_predict_async, global_config=config.get_config(), **locals() + ) async def _async_wrap_evaluation_matrices( diff --git a/tests/python/test_with_dask.py b/tests/python/test_with_dask.py index d58404960f7e..43b0c33b5e1e 100644 --- a/tests/python/test_with_dask.py +++ b/tests/python/test_with_dask.py @@ -24,7 +24,6 @@ if tm.no_dask()['condition']: pytest.skip(msg=tm.no_dask()['reason'], allow_module_level=True) -import distributed from distributed import LocalCluster, Client from distributed.utils_test import client, loop, cluster_fixture import dask.dataframe as dd @@ -130,24 +129,34 @@ def test_from_dask_array() -> None: assert np.all(single_node_predt == from_arr.compute()) -def test_dask_predict_shape_infer() -> None: - with LocalCluster(n_workers=kWorkers) as cluster: - with Client(cluster) as client: - X, y = make_classification(n_samples=1000, n_informative=5, - n_classes=3) - X_ = dd.from_array(X, chunksize=100) - y_ = dd.from_array(y, chunksize=100) - dtrain = xgb.dask.DaskDMatrix(client, data=X_, label=y_) - - model = xgb.dask.train( - client, - {"objective": "multi:softprob", "num_class": 3}, - dtrain=dtrain - ) +def test_dask_predict_shape_infer(client: "Client") -> None: + X, y = make_classification(n_samples=1000, n_informative=5, n_classes=3) + X_ = dd.from_array(X, chunksize=100) + y_ = dd.from_array(y, chunksize=100) + dtrain = xgb.dask.DaskDMatrix(client, data=X_, label=y_) + + model = xgb.dask.train( + client, {"objective": "multi:softprob", "num_class": 3}, dtrain=dtrain + ) + + preds = xgb.dask.predict(client, model, dtrain) + assert preds.shape[0] == preds.compute().shape[0] + assert preds.shape[1] == preds.compute().shape[1] - preds = xgb.dask.predict(client, model, dtrain) - assert preds.shape[0] == preds.compute().shape[0] - assert preds.shape[1] == preds.compute().shape[1] + prediction = xgb.dask.predict(client, model, X_, output_margin=True) + assert isinstance(prediction, dd.DataFrame) + + prediction = prediction.compute() + assert prediction.ndim == 2 + assert prediction.shape[0] == kRows + assert prediction.shape[1] == 3 + + prediction = xgb.dask.inplace_predict(client, model, X_, predict_type="margin") + assert isinstance(prediction, dd.DataFrame) + prediction = prediction.compute() + assert prediction.ndim == 2 + assert prediction.shape[0] == kRows + assert prediction.shape[1] == 3 @pytest.mark.parametrize("tree_method", ["hist", "approx"]) @@ -340,7 +349,7 @@ def test_dask_classifier(model: str, client: "Client") -> None: classifier.fit(X_d, y_d) assert classifier.n_classes_ == 10 - prediction = classifier.predict(X_d) + prediction = classifier.predict(X_d).compute() assert prediction.ndim == 1 assert prediction.shape[0] == kRows @@ -541,6 +550,9 @@ async def run_dask_regressor_asyncio(scheduler_address: str) -> None: assert list(history['validation_0'].keys())[0] == 'rmse' assert len(history['validation_0']['rmse']) == 2 + awaited = await client.compute(prediction) + assert awaited.shape[0] == kRows + async def run_dask_classifier_asyncio(scheduler_address: str) -> None: async with Client(scheduler_address, asynchronous=True) as client: @@ -578,7 +590,7 @@ async def run_dask_classifier_asyncio(scheduler_address: str) -> None: await classifier.fit(X_d, y_d) assert classifier.n_classes_ == 10 - prediction = await classifier.predict(X_d) + prediction = await client.compute(await classifier.predict(X_d)) assert prediction.ndim == 1 assert prediction.shape[0] == kRows @@ -1019,6 +1031,17 @@ def test_data_initialization(self, client: "Client") -> None: run_data_initialization(xgb.dask.DaskDMatrix, xgb.dask.DaskXGBClassifier, X, y) def run_shap(self, X: Any, y: Any, params: Dict[str, Any], client: "Client") -> None: + rows = X.shape[0] + cols = X.shape[1] + + def assert_shape(shape): + assert shape[0] == rows + if "num_class" in params.keys(): + assert shape[1] == params["num_class"] + assert shape[2] == cols + 1 + else: + assert shape[1] == cols + 1 + X, y = da.from_array(X, chunks=(32, -1)), da.from_array(y, chunks=32) Xy = xgb.dask.DaskDMatrix(client, X, y) booster = xgb.dask.train(client, params, Xy, num_boost_round=10)['booster'] @@ -1027,15 +1050,17 @@ def run_shap(self, X: Any, y: Any, params: Dict[str, Any], client: "Client") -> shap = xgb.dask.predict(client, booster, test_Xy, pred_contribs=True).compute() margin = xgb.dask.predict(client, booster, test_Xy, output_margin=True).compute() + assert_shape(shap.shape) assert np.allclose(np.sum(shap, axis=len(shap.shape) - 1), margin, 1e-5, 1e-5) shap = xgb.dask.predict(client, booster, X, pred_contribs=True).compute() margin = xgb.dask.predict(client, booster, X, output_margin=True).compute() + assert_shape(shap.shape) assert np.allclose(np.sum(shap, axis=len(shap.shape) - 1), margin, 1e-5, 1e-5) def run_shap_cls_sklearn(self, X: Any, y: Any, client: "Client") -> None: X, y = da.from_array(X, chunks=(32, -1)), da.from_array(y, chunks=32) - cls = xgb.dask.DaskXGBClassifier() + cls = xgb.dask.DaskXGBClassifier(n_estimators=4) cls.client = client cls.fit(X, y) booster = cls.get_booster() @@ -1072,6 +1097,8 @@ def run_shap_interactions( params: Dict[str, Any], client: "Client" ) -> None: + rows = X.shape[0] + cols = X.shape[1] X, y = da.from_array(X, chunks=(32, -1)), da.from_array(y, chunks=32) Xy = xgb.dask.DaskDMatrix(client, X, y) @@ -1082,6 +1109,12 @@ def run_shap_interactions( shap = xgb.dask.predict( client, booster, test_Xy, pred_interactions=True ).compute() + + assert len(shap.shape) == 3 + assert shap.shape[0] == rows + assert shap.shape[1] == cols + 1 + assert shap.shape[2] == cols + 1 + margin = xgb.dask.predict(client, booster, test_Xy, output_margin=True).compute() assert np.allclose(np.sum(shap, axis=(len(shap.shape) - 1, len(shap.shape) - 2)), margin,