diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py
index bea06864307d..20a4728e3bf4 100644
--- a/python-package/xgboost/core.py
+++ b/python-package/xgboost/core.py
@@ -312,15 +312,18 @@ def data_handle(data, label=None, weight=None, base_margin=None,
                 data, feature_names, feature_types
             )
             dispatch_device_quantile_dmatrix_set_data(self.proxy, data)
-            self.proxy.set_info(label=label, weight=weight,
-                                base_margin=base_margin,
-                                group=group,
-                                qid=qid,
-                                label_lower_bound=label_lower_bound,
-                                label_upper_bound=label_upper_bound,
-                                feature_names=feature_names,
-                                feature_types=feature_types,
-                                feature_weights=feature_weights)
+            self.proxy.set_info(
+                label=label,
+                weight=weight,
+                base_margin=base_margin,
+                group=group,
+                qid=qid,
+                label_lower_bound=label_lower_bound,
+                label_upper_bound=label_upper_bound,
+                feature_names=feature_names,
+                feature_types=feature_types,
+                feature_weights=feature_weights
+            )
         try:
             # Differ the exception in order to return 0 and stop the iteration.
             # Exception inside a ctype callback function has no effect except
@@ -408,7 +411,7 @@ def inner_f(*args, **kwargs):
     return inner_f


-class DMatrix:                  # pylint: disable=too-many-instance-attributes
+class DMatrix:  # pylint: disable=too-many-instance-attributes
     """Data Matrix used in XGBoost.

     DMatrix is an internal data structure that is used by XGBoost,
@@ -416,13 +419,26 @@ class DMatrix:                  # pylint: disable=too-many-instance-attributes
     You can construct DMatrix from multiple different sources of data.
     """

-    def __init__(self, data, label=None, weight=None, base_margin=None,
-                 missing=None,
-                 silent=False,
-                 feature_names=None,
-                 feature_types=None,
-                 nthread=None,
-                 enable_categorical=False):
+    @_deprecate_positional_args
+    def __init__(
+        self,
+        data,
+        label=None,
+        *,
+        weight=None,
+        base_margin=None,
+        missing: Optional[float] = None,
+        silent=False,
+        feature_names=None,
+        feature_types=None,
+        nthread: Optional[int] = None,
+        group=None,
+        qid=None,
+        label_lower_bound=None,
+        label_upper_bound=None,
+        feature_weights=None,
+        enable_categorical: bool = False,
+    ) -> None:
         """Parameters
         ----------
         data : os.PathLike/string/numpy.array/scipy.sparse/pd.DataFrame/
@@ -432,12 +448,9 @@ def __init__(self, data, label=None, weight=None, base_margin=None,
             libsvm format txt file, csv file (by specifying uri parameter
             'path_to_csv?format=csv'), or binary file that xgboost can read
             from.
-        label : list, numpy 1-D array or cudf.DataFrame, optional
+        label : array_like
             Label of the training data.
-        missing : float, optional
-            Value in the input data which needs to be present as a missing
-            value. If None, defaults to np.nan.
-        weight : list, numpy 1-D array or cudf.DataFrame , optional
+        weight : array_like
             Weight for each instance.

             .. note:: For ranking task, weights are per-group.
@@ -447,6 +460,11 @@ def __init__(self, data, label=None, weight=None, base_margin=None,
             ordering of data points within each group, so it doesn't make
             sense to assign weights to individual data points.

+        base_margin: array_like
+            Base margin used for boosting from existing model.
+        missing : float, optional
+            Value in the input data which needs to be present as a missing
+            value. If None, defaults to np.nan.
         silent : boolean, optional
             Whether print messages during construction
         feature_names : list, optional
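A minimal usage sketch of the constructor after this change (synthetic data, illustrative values only): arguments after `label` become keyword-only, and ranking metadata such as `qid` can now be supplied at construction time.

    import numpy as np
    import xgboost as xgb

    X = np.random.randn(100, 10)
    y = np.random.randn(100)
    # Rows must be grouped by query id for ranking tasks.
    qid = np.sort(np.random.randint(0, 10, size=100))

    dtrain = xgb.DMatrix(X, y, qid=qid, missing=np.nan, nthread=-1)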
@@ -456,7 +474,16 @@ def __init__(self, data, label=None, weight=None, base_margin=None,
         nthread : integer, optional
             Number of threads to use for loading data when parallelization is
             applicable. If -1, uses maximum threads available on the system.
-
+        group : array_like
+            Group size for all ranking group.
+        qid : array_like
+            Query ID for data samples, used for ranking.
+        label_lower_bound : array_like
+            Lower bound for survival training.
+        label_upper_bound : array_like
+            Upper bound for survival training.
+        feature_weights : array_like, optional
+            Set feature weights for column sampling.
         enable_categorical: boolean, optional

             .. versionadded:: 1.3.0
@@ -469,7 +496,9 @@ def __init__(self, data, label=None, weight=None, base_margin=None,

         """
         if isinstance(data, list):
-            raise TypeError('Input data can not be a list.')
+            raise TypeError("Input data can not be a list.")
+        if group is not None and qid is not None:
+            raise ValueError("Either one of `group` or `qid` should be None.")

         self.missing = missing if missing is not None else np.nan
         self.nthread = nthread if nthread is not None else -1
@@ -481,16 +510,28 @@ def __init__(self, data, label=None, weight=None, base_margin=None,
             return

         from .data import dispatch_data_backend
+
         handle, feature_names, feature_types = dispatch_data_backend(
-            data, missing=self.missing,
+            data,
+            missing=self.missing,
             threads=self.nthread,
             feature_names=feature_names,
             feature_types=feature_types,
-            enable_categorical=enable_categorical)
+            enable_categorical=enable_categorical,
+        )
         assert handle is not None
         self.handle = handle

-        self.set_info(label=label, weight=weight, base_margin=base_margin)
+        self.set_info(
+            label=label,
+            weight=weight,
+            base_margin=base_margin,
+            group=group,
+            qid=qid,
+            label_lower_bound=label_lower_bound,
+            label_upper_bound=label_upper_bound,
+            feature_weights=feature_weights,
+        )

         if feature_names is not None:
             self.feature_names = feature_names
@@ -503,17 +544,23 @@ def __del__(self):
         self.handle = None

     @_deprecate_positional_args
-    def set_info(self, *,
-                 label=None, weight=None, base_margin=None,
-                 group=None,
-                 qid=None,
-                 label_lower_bound=None,
-                 label_upper_bound=None,
-                 feature_names=None,
-                 feature_types=None,
-                 feature_weights=None):
-        '''Set meta info for DMatrix.'''
+    def set_info(
+        self,
+        *,
+        label=None,
+        weight=None,
+        base_margin=None,
+        group=None,
+        qid=None,
+        label_lower_bound=None,
+        label_upper_bound=None,
+        feature_names=None,
+        feature_types=None,
+        feature_weights=None
+    ) -> None:
+        """Set meta info for DMatrix. See the doc string of the DMatrix constructor."""
        from .data import dispatch_meta_backend
+
        if label is not None:
            self.set_label(label)
        if weight is not None:
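A short sketch of the keyword-only `set_info` call (illustrative values); positional use now triggers the `_deprecate_positional_args` warning.

    import numpy as np
    import xgboost as xgb

    dtrain = xgb.DMatrix(np.random.randn(8, 3), label=np.random.randn(8))
    # All meta information is passed by keyword.
    dtrain.set_info(
        weight=np.ones(8),
        feature_weights=np.array([0.5, 0.3, 0.2]),
    )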
@@ -918,39 +965,67 @@ class DeviceQuantileDMatrix(DMatrix):
     information may be lost in quantisation. This DMatrix is primarily designed
     to save memory in training from device memory inputs by avoiding
     intermediate storage. Set max_bin to control the number of bins during
-    quantisation.
+    quantisation. See the doc string of `DMatrix` for documentation on meta info.

     You can construct DeviceQuantileDMatrix from cupy/cudf/dlpack.

     .. versionadded:: 1.1.0

     """
-
-    def __init__(self, data, label=None, weight=None,  # pylint: disable=W0231
-                 base_margin=None,
-                 missing=None,
-                 silent=False,
-                 feature_names=None,
-                 feature_types=None,
-                 nthread=None, max_bin=256):
+    @_deprecate_positional_args
+    def __init__(  # pylint: disable=super-init-not-called
+        self,
+        data,
+        label=None,
+        *,
+        weight=None,
+        base_margin=None,
+        missing=None,
+        silent=False,
+        feature_names=None,
+        feature_types=None,
+        nthread: Optional[int] = None,
+        max_bin: int = 256,
+        group=None,
+        qid=None,
+        label_lower_bound=None,
+        label_upper_bound=None,
+        feature_weights=None,
+        enable_categorical: bool = False,
+    ):
         self.max_bin = max_bin
         self.missing = missing if missing is not None else np.nan
         self.nthread = nthread if nthread is not None else 1
+        self._silent = silent  # unused, kept for compatibility

         if isinstance(data, ctypes.c_void_p):
             self.handle = data
             return
         from .data import init_device_quantile_dmatrix
         handle, feature_names, feature_types = init_device_quantile_dmatrix(
-            data, missing=self.missing, threads=self.nthread,
-            max_bin=self.max_bin,
+            data,
             label=label,
             weight=weight,
             base_margin=base_margin,
-            group=None,
-            label_lower_bound=None,
-            label_upper_bound=None,
+            group=group,
+            qid=qid,
+            missing=self.missing,
+            label_lower_bound=label_lower_bound,
+            label_upper_bound=label_upper_bound,
+            feature_weights=feature_weights,
             feature_names=feature_names,
-            feature_types=feature_types)
+            feature_types=feature_types,
+            threads=self.nthread,
+            max_bin=self.max_bin,
+        )
+        if enable_categorical:
+            raise NotImplementedError(
+                'categorical support is not enabled on DeviceQuantileDMatrix.'
+            )
         self.handle = handle
+        if qid is not None and group is not None:
+            raise ValueError(
+                'Only one of the eval_qid or eval_group for each evaluation '
+                'dataset should be provided.'
+            )

         self.feature_names = feature_names
         self.feature_types = feature_types
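A sketch of the intended single-node GPU usage, assuming a CUDA device and cupy are available (synthetic data):

    import cupy as cp
    import xgboost as xgb

    X = cp.random.randn(1000, 10)
    y = cp.random.randn(1000)
    # Data stays on the device; max_bin bounds the quantile sketch size.
    dtrain = xgb.DeviceQuantileDMatrix(X, y, max_bin=256)
    booster = xgb.train({"tree_method": "gpu_hist"}, dtrain, num_boost_round=10)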
diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py
index 6c40a8c97246..64d13bd800ed 100644
--- a/python-package/xgboost/dask.py
+++ b/python-package/xgboost/dask.py
@@ -38,8 +38,9 @@
 from .core import _deprecate_positional_args
 from .training import train as worker_train
 from .tracker import RabitTracker, get_host_ip
-from .sklearn import XGBModel, XGBRegressorBase, XGBClassifierBase
-from .sklearn import xgboost_model_doc, _objective_decorator
+from .sklearn import XGBModel, XGBRegressorBase, XGBClassifierBase, _objective_decorator
+from .sklearn import XGBRankerMixIn
+from .sklearn import xgboost_model_doc
 from .sklearn import _cls_predict_proba
 from .sklearn import XGBRanker

@@ -180,10 +181,12 @@ def _xgb_get_client(client: Optional["distributed.Client"]) -> "distributed.Clie

 class DaskDMatrix:
     # pylint: disable=missing-docstring, too-many-instance-attributes
-    '''DMatrix holding on references to Dask DataFrame or Dask Array.  Constructing
-    a `DaskDMatrix` forces all lazy computation to be carried out.  Wait for
-    the input data explicitly if you want to see actual computation of
-    constructing `DaskDMatrix`.
+    '''DMatrix holding on references to Dask DataFrame or Dask Array. Constructing a
+    `DaskDMatrix` forces all lazy computation to be carried out. Wait for the input data
+    explicitly if you want to see actual computation of constructing `DaskDMatrix`.
+
+    See the doc string of the DMatrix constructor for other parameters. DaskDMatrix
+    accepts only dask collections.

     .. note::

@@ -197,29 +200,6 @@ class DaskDMatrix:
     client :
         Specify the dask client used for training.  Use default client
         returned from dask if it's set to None.
-    data :
-        data source of DMatrix.
-    label :
-        label used for trainin.
-    missing :
-        Value in the input data (e.g. `numpy.ndarray`) which needs to be present as a
-        missing value. If None, defaults to np.nan.
-    weight :
-        Weight for each instance.
-    base_margin :
-        Global bias for each instance.
-    qid :
-        Query ID for ranking.
-    label_lower_bound :
-        Upper bound for survival training.
-    label_upper_bound :
-        Lower bound for survival training.
-    feature_weights :
-        Weight for features used in column sampling.
-    feature_names :
-        Set names for features.
-    feature_types :
-        Set types for features

     '''

@@ -230,15 +210,18 @@ def __init__(
         self,
         data: _DaskCollection,
         label: Optional[_DaskCollection] = None,
         *,
-        missing: float = None,
         weight: Optional[_DaskCollection] = None,
         base_margin: Optional[_DaskCollection] = None,
+        missing: float = None,
+        silent: bool = False,  # pylint: disable=unused-argument
+        feature_names: Optional[Union[str, List[str]]] = None,
+        feature_types: Optional[Union[Any, List[Any]]] = None,
+        group: Optional[_DaskCollection] = None,
         qid: Optional[_DaskCollection] = None,
         label_lower_bound: Optional[_DaskCollection] = None,
         label_upper_bound: Optional[_DaskCollection] = None,
         feature_weights: Optional[_DaskCollection] = None,
-        feature_names: Optional[Union[str, List[str]]] = None,
-        feature_types: Optional[Union[Any, List[Any]]] = None
+        enable_categorical: bool = False
     ) -> None:
         _assert_dask_support()
         client = _xgb_get_client(client)
@@ -248,30 +231,41 @@ def __init__(
         self.missing = missing

         if qid is not None and weight is not None:
-            raise NotImplementedError('per-group weight is not implemented.')
+            raise NotImplementedError("per-group weight is not implemented.")
+        if group is not None:
+            raise NotImplementedError(
+                "group structure is not implemented, use qid instead."
+            )
+        if enable_categorical:
+            raise NotImplementedError(
+                "categorical support is not enabled on `DaskDMatrix`."
+            )

         if len(data.shape) != 2:
             raise ValueError(
-                'Expecting 2 dimensional input, got: {shape}'.format(
-                    shape=data.shape))
+                "Expecting 2 dimensional input, got: {shape}".format(shape=data.shape)
+            )

         if not isinstance(data, (dd.DataFrame, da.Array)):
             raise TypeError(_expect((dd.DataFrame, da.Array), type(data)))

-        if not isinstance(label, (dd.DataFrame, da.Array, dd.Series,
-                                  type(None))):
-            raise TypeError(
-                _expect((dd.DataFrame, da.Array, dd.Series), type(label)))
+        if not isinstance(label, (dd.DataFrame, da.Array, dd.Series, type(None))):
+            raise TypeError(_expect((dd.DataFrame, da.Array, dd.Series), type(label)))

         self.worker_map: Dict[str, "distributed.Future"] = defaultdict(list)
         self.is_quantile: bool = False

-        self._init = client.sync(self.map_local_data,
-                                 client, data, label=label, weights=weight,
-                                 base_margin=base_margin,
-                                 qid=qid,
-                                 feature_weights=feature_weights,
-                                 label_lower_bound=label_lower_bound,
-                                 label_upper_bound=label_upper_bound)
+        self._init = client.sync(
+            self.map_local_data,
+            client,
+            data,
+            label=label,
+            weights=weight,
+            base_margin=base_margin,
+            qid=qid,
+            feature_weights=feature_weights,
+            label_lower_bound=label_lower_bound,
+            label_upper_bound=label_upper_bound,
+        )

     def __await__(self) -> Generator:
         return self._init.__await__()
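A sketch of DaskDMatrix usage on a local cluster (cluster size and chunking are arbitrary here):

    from dask import array as da
    from dask.distributed import Client, LocalCluster
    from xgboost import dask as dxgb

    with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
        X = da.random.random((1000, 10), chunks=(100, 10))
        y = da.random.random(1000, chunks=100)
        # Construction forces the lazy computation of X and y.
        dtrain = dxgb.DaskDMatrix(client, X, y)
        output = dxgb.train(client, {"tree_method": "hist"}, dtrain,
                            num_boost_round=10)
        booster = output["booster"]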
@@ -571,11 +565,11 @@ def next(self, input_data: Callable) -> int:


 class DaskDeviceQuantileDMatrix(DaskDMatrix):
-    '''Specialized data type for `gpu_hist` tree method.  This class is used to
-    reduce the memory usage by eliminating data copies.  Internally the all
-    partitions/chunks of data are merged by weighted GK sketching.  So the
-    number of partitions from dask may affect training accuracy as GK generates
-    bounded error for each merge.
+    '''Specialized data type for `gpu_hist` tree method. This class is used to reduce
+    the memory usage by eliminating data copies. Internally, all partitions/chunks of
+    data are merged by weighted GK sketching, so the number of partitions from dask may
+    affect training accuracy as GK generates bounded error for each merge. See the doc
+    strings of `DeviceQuantileDMatrix` and `DMatrix` for other parameters.

     .. versionadded:: 1.2.0

@@ -584,42 +578,50 @@ class DaskDeviceQuantileDMatrix(DaskDMatrix):
     max_bin :
         Number of bins for histogram construction.

     '''
+    @_deprecate_positional_args
     def __init__(
         self,
         client: "distributed.Client",
         data: _DaskCollection,
         label: Optional[_DaskCollection] = None,
-        missing: float = None,
+        *,
         weight: Optional[_DaskCollection] = None,
         base_margin: Optional[_DaskCollection] = None,
+        missing: float = None,
+        silent: bool = False,
+        feature_names: Optional[Union[str, List[str]]] = None,
+        feature_types: Optional[Union[Any, List[Any]]] = None,
+        max_bin: int = 256,
+        group: Optional[_DaskCollection] = None,
         qid: Optional[_DaskCollection] = None,
         label_lower_bound: Optional[_DaskCollection] = None,
         label_upper_bound: Optional[_DaskCollection] = None,
         feature_weights: Optional[_DaskCollection] = None,
-        feature_names: Optional[Union[str, List[str]]] = None,
-        feature_types: Optional[Union[Any, List[Any]]] = None,
-        max_bin: int = 256
+        enable_categorical: bool = False,
     ) -> None:
         super().__init__(
             client=client,
             data=data,
             label=label,
-            missing=missing,
-            feature_weights=feature_weights,
             weight=weight,
             base_margin=base_margin,
+            group=group,
             qid=qid,
             label_lower_bound=label_lower_bound,
             label_upper_bound=label_upper_bound,
+            missing=missing,
+            silent=silent,
+            feature_weights=feature_weights,
             feature_names=feature_names,
-            feature_types=feature_types
+            feature_types=feature_types,
+            enable_categorical=enable_categorical,
         )
         self.max_bin = max_bin
         self.is_quantile = True

     def create_fn_args(self, worker_addr: str) -> Dict[str, Any]:
         args = super().create_fn_args(worker_addr)
-        args['max_bin'] = self.max_bin
+        args["max_bin"] = self.max_bin
         return args
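A sketch of the GPU variant, assuming dask_cuda and cupy are installed; `map_blocks(cp.asarray)` is one way to move chunks onto the device (chunk sizes are arbitrary):

    import cupy as cp
    from dask import array as da
    from dask.distributed import Client
    from dask_cuda import LocalCUDACluster
    from xgboost import dask as dxgb

    with LocalCUDACluster() as cluster, Client(cluster) as client:
        # DaskDeviceQuantileDMatrix expects device-resident partitions.
        X = da.random.random((1000, 10), chunks=(100, 10)).map_blocks(cp.asarray)
        y = da.random.random(1000, chunks=100).map_blocks(cp.asarray)
        dtrain = dxgb.DaskDeviceQuantileDMatrix(client, X, y, max_bin=128)
        output = dxgb.train(client, {"tree_method": "gpu_hist"}, dtrain,
                            num_boost_round=10)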
@@ -630,35 +632,49 @@ def _create_device_quantile_dmatrix(
     meta_names: List[str],
     missing: float,
     parts: Optional[_DataParts],
-    max_bin: int
+    max_bin: int,
 ) -> DeviceQuantileDMatrix:
     worker = distributed.get_worker()
     if parts is None:
-        msg = 'worker {address} has an empty DMatrix.  '.format(
-            address=worker.address)
+        msg = "worker {address} has an empty DMatrix.".format(address=worker.address)
         LOGGER.warning(msg)
         import cupy
-        d = DeviceQuantileDMatrix(cupy.zeros((0, 0)),
-                                  feature_names=feature_names,
-                                  feature_types=feature_types,
-                                  max_bin=max_bin)
+
+        d = DeviceQuantileDMatrix(
+            cupy.zeros((0, 0)),
+            feature_names=feature_names,
+            feature_types=feature_types,
+            max_bin=max_bin,
+        )
         return d

-    (data, labels, weights, base_margin, qid,
-     label_lower_bound, label_upper_bound) = _get_worker_parts(
-         parts, meta_names)
-    it = DaskPartitionIter(data=data, label=labels, weight=weights,
-                           base_margin=base_margin,
-                           qid=qid,
-                           label_lower_bound=label_lower_bound,
-                           label_upper_bound=label_upper_bound)
-
-    dmatrix = DeviceQuantileDMatrix(it,
-                                    missing=missing,
-                                    feature_names=feature_names,
-                                    feature_types=feature_types,
-                                    nthread=worker.nthreads,
-                                    max_bin=max_bin)
+    (
+        data,
+        labels,
+        weights,
+        base_margin,
+        qid,
+        label_lower_bound,
+        label_upper_bound,
+    ) = _get_worker_parts(parts, meta_names)
+    it = DaskPartitionIter(
+        data=data,
+        label=labels,
+        weight=weights,
+        base_margin=base_margin,
+        qid=qid,
+        label_lower_bound=label_lower_bound,
+        label_upper_bound=label_upper_bound,
+    )
+
+    dmatrix = DeviceQuantileDMatrix(
+        it,
+        missing=missing,
+        feature_names=feature_names,
+        feature_types=feature_types,
+        nthread=worker.nthreads,
+        max_bin=max_bin,
+    )
     dmatrix.set_info(feature_weights=feature_weights)
     return dmatrix

@@ -712,13 +728,15 @@ def concat_or_none(data: Tuple[Optional[T], ...]) -> Optional[T]:
         missing=missing,
         feature_names=feature_names,
         feature_types=feature_types,
-        nthread=worker.nthreads
+        nthread=worker.nthreads,
     )
     dmatrix.set_info(
-        base_margin=_base_margin, qid=_qid, weight=_weights,
+        base_margin=_base_margin,
+        qid=_qid,
+        weight=_weights,
         label_lower_bound=_label_lower_bound,
         label_upper_bound=_label_upper_bound,
-        feature_weights=feature_weights
+        feature_weights=feature_weights,
     )
     return dmatrix

@@ -753,6 +771,8 @@ def _get_workers_from_data(
     for e in evals:
         assert len(e) == 2
         assert isinstance(e[0], DaskDMatrix) and isinstance(e[1], str)
+        if e[0] is dtrain:
+            continue
         worker_map = set(e[0].worker_map.keys())
         X_worker_map = X_worker_map.union(worker_map)
     return X_worker_map
@@ -960,7 +980,7 @@ def mapped_predict(partition: Any, is_df: bool) -> Any:
         worker = distributed.get_worker()
         with config.config_context(**global_config):
             booster.set_param({'nthread': worker.nthreads})
-            m = DMatrix(partition, missing=missing, nthread=worker.nthreads)
+            m = DMatrix(data=partition, missing=missing, nthread=worker.nthreads)
             predt = booster.predict(
                 data=m,
                 output_margin=output_margin,
@@ -1587,7 +1607,7 @@ async def _predict_async(
     For dask implementation, group is not supported, use qid instead.
 """,
 )
-class DaskXGBRanker(DaskScikitLearnBase):
+class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
     @_deprecate_positional_args
     def __init__(self, *, objective: str = "rank:pairwise", **kwargs: Any):
         if callable(objective):
@@ -1632,11 +1652,10 @@ async def _fit_async(
         if eval_metric is not None:
             if callable(eval_metric):
                 raise ValueError(
-                    'Custom evaluation metric is not yet supported for XGBRanker.')
+                    "Custom evaluation metric is not yet supported for XGBRanker."
+                )
         model, metric, params = self._configure_fit(
-            booster=xgb_model,
-            eval_metric=eval_metric,
-            params=params
+            booster=xgb_model, eval_metric=eval_metric, params=params
         )
         results = await train(
             client=self.client,
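A hedged sketch of fitting DaskXGBRanker with `qid` (synthetic data; in practice the qid partitioning must line up with the feature partitions, and the estimator picks up the active client from the context):

    import numpy as np
    from dask import array as da
    from dask.distributed import Client, LocalCluster
    from xgboost import dask as dxgb

    with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
        X = da.random.random((1000, 10), chunks=(100, 10))
        y = da.random.randint(0, 5, size=1000, chunks=100)
        # Rows belonging to one query must be contiguous and sorted by qid.
        qid = da.from_array(np.sort(np.random.randint(0, 50, 1000)), chunks=100)
        ranker = dxgb.DaskXGBRanker(objective="rank:pairwise", n_estimators=10)
        ranker.fit(X, y, qid=qid)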
diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py
index cde96118ec4d..555d066f61cf 100644
--- a/python-package/xgboost/data.py
+++ b/python-package/xgboost/data.py
@@ -737,16 +737,28 @@ class SingleBatchInternalIter(DataIter):  # pylint: disable=R0902
     area for meta info.

     '''
-    def __init__(self, data, label, weight, base_margin, group,
-                 label_lower_bound, label_upper_bound,
-                 feature_names, feature_types):
+    def __init__(
+        self, data,
+        label,
+        weight,
+        base_margin,
+        group,
+        qid,
+        label_lower_bound,
+        label_upper_bound,
+        feature_weights,
+        feature_names,
+        feature_types
+    ):
         self.data = data
         self.label = label
         self.weight = weight
         self.base_margin = base_margin
         self.group = group
+        self.qid = qid
         self.label_lower_bound = label_lower_bound
         self.label_upper_bound = label_upper_bound
+        self.feature_weights = feature_weights
         self.feature_names = feature_names
         self.feature_types = feature_types
         self.it = 0  # pylint: disable=invalid-name
@@ -759,8 +771,10 @@ def next(self, input_data):
         input_data(data=self.data, label=self.label,
                    weight=self.weight,
                    base_margin=self.base_margin,
                    group=self.group,
+                   qid=self.qid,
                    label_lower_bound=self.label_lower_bound,
                    label_upper_bound=self.label_upper_bound,
+                   feature_weights=self.feature_weights,
                    feature_names=self.feature_names,
                    feature_types=self.feature_types)
         return 1
@@ -770,7 +784,8 @@ def reset(self):


 def init_device_quantile_dmatrix(
-        data, missing, max_bin, threads, feature_names, feature_types, **meta):
+    data, missing, max_bin, threads, feature_names, feature_types, **meta
+):
     '''Constructor for DeviceQuantileDMatrix.'''
     if not any([_is_cudf_df(data), _is_cudf_ser(data), _is_cupy_array(data),
                 _is_dlpack(data), _is_iter(data)]):
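SingleBatchInternalIter simply replays one in-memory batch through the DataIter callback protocol: `next` pushes a batch via `input_data` and returns 1, returns 0 when exhausted, and `reset` rewinds. A toy iterator of the same shape, assuming the `DataIter` base class is importable from `xgboost.core` in this version:

    import cupy as cp
    import xgboost as xgb
    from xgboost.core import DataIter

    class TwoBatchIter(DataIter):
        '''Feed two cupy batches to DeviceQuantileDMatrix one at a time.'''
        def __init__(self, batches):
            self.batches = batches
            self.it = 0
            super().__init__()

        def next(self, input_data):
            if self.it == len(self.batches):
                return 0  # signal the end of iteration
            X, y = self.batches[self.it]
            input_data(data=X, label=y)
            self.it += 1
            return 1

        def reset(self):
            self.it = 0

    batches = [(cp.random.randn(64, 4), cp.random.randn(64)) for _ in range(2)]
    m = xgb.DeviceQuantileDMatrix(TwoBatchIter(batches))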
diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py
index 3fcbcc0edbf9..ebf552e1bd11 100644
--- a/python-package/xgboost/sklearn.py
+++ b/python-package/xgboost/sklearn.py
@@ -556,7 +556,7 @@ def load_model(self, fname):

     def _configure_fit(
         self,
-        booster: Optional[Booster],
+        booster: Optional[Union[Booster, "XGBModel"]],
         eval_metric: Optional[Union[Callable, str, List[str]]],
         params: Dict[str, Any],
     ) -> Tuple[Booster, Optional[Metric], Dict[str, Any]]:
@@ -631,7 +631,7 @@ def fit(self, X, y, *, sample_weight=None, base_margin=None,
         verbose : bool
             If `verbose` and an evaluation set is used, writes the evaluation
             metric measured on the validation set to stderr.
-        xgb_model : str
+        xgb_model : Union[str, Booster, XGBModel]
             file name of stored XGBoost model or 'Booster' instance XGBoost
             model to be loaded before training (allows training continuation).
         sample_weight_eval_set : list, optional
@@ -942,10 +942,22 @@ def __init__(self, *, objective="binary:logistic", use_label_encoder=True, **kwa
         super().__init__(objective=objective, **kwargs)

     @_deprecate_positional_args
-    def fit(self, X, y, *, sample_weight=None, base_margin=None,
-            eval_set=None, eval_metric=None,
-            early_stopping_rounds=None, verbose=True, xgb_model=None,
-            sample_weight_eval_set=None, feature_weights=None, callbacks=None):
+    def fit(
+        self,
+        X,
+        y,
+        *,
+        sample_weight=None,
+        base_margin=None,
+        eval_set=None,
+        eval_metric=None,
+        early_stopping_rounds=None,
+        verbose=True,
+        xgb_model=None,
+        sample_weight_eval_set=None,
+        feature_weights=None,
+        callbacks=None
+    ):
         # pylint: disable = attribute-defined-outside-init,arguments-differ,too-many-statements

         can_use_label_encoder = True
@@ -1283,7 +1295,10 @@ def __init__(self, *, objective='rank:pairwise', **kwargs):

     @_deprecate_positional_args
     def fit(
-        self, X, y, *,
+        self,
+        X,
+        y,
+        *,
         group=None,
         qid=None,
         sample_weight=None,
@@ -1372,7 +1387,7 @@ def fit(
         verbose : bool
             If `verbose` and an evaluation set is used, writes the evaluation
             metric measured on the validation set to stderr.
-        xgb_model : str
+        xgb_model : Union[str, Booster, XGBModel]
             file name of stored XGBoost model or 'Booster' instance XGBoost
             model to be loaded before training (allows training continuation).
         feature_weights: array_like
@@ -1391,9 +1406,8 @@ def fit(
                 save_best=True)]

         """
-        # check if group information is provided
-        if group is None:
-            raise ValueError("group is required for ranking task")
+        if group is None and qid is None:
+            raise ValueError("group or qid is required for ranking task")

         if eval_set is not None:
             if eval_group is None and eval_qid is None:
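With this relaxation, a sketch of the sklearn ranker accepting either form of group information (synthetic data; 12 queries of 10 rows each):

    import numpy as np
    import xgboost as xgb

    X = np.random.randn(120, 5)
    y = np.random.randint(0, 4, size=120)
    ranker = xgb.XGBRanker(objective="rank:pairwise", n_estimators=10)

    # Equivalent ways to describe the same query structure:
    ranker.fit(X, y, group=np.full(12, 10))             # sizes per query
    ranker.fit(X, y, qid=np.repeat(np.arange(12), 10))  # id per row
    # Passing neither now raises a ValueError.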
diff --git a/tests/python-gpu/test_device_quantile_dmatrix.py b/tests/python-gpu/test_device_quantile_dmatrix.py
index 4f90480f90cd..2695a1168380 100644
--- a/tests/python-gpu/test_device_quantile_dmatrix.py
+++ b/tests/python-gpu/test_device_quantile_dmatrix.py
@@ -34,3 +34,25 @@ def test_dmatrix_cupy_init(self):
         import cupy as cp
         data = cp.random.randn(5, 5)
         xgb.DeviceQuantileDMatrix(data, cp.ones(5, dtype=np.float64))
+
+    @pytest.mark.skipif(**tm.no_cupy())
+    def test_metainfo(self) -> None:
+        import cupy as cp
+        rng = cp.random.RandomState(1994)
+
+        rows = 10
+        cols = 3
+        data = rng.randn(rows, cols)
+
+        labels = rng.randn(rows)
+
+        fw = rng.randn(rows)
+        fw -= fw.min()
+
+        m = xgb.DeviceQuantileDMatrix(data=data, label=labels, feature_weights=fw)
+
+        got_fw = m.get_float_info("feature_weights")
+        got_labels = m.get_label()
+
+        cp.testing.assert_allclose(fw, got_fw)
+        cp.testing.assert_allclose(labels, got_labels)
diff --git a/tests/python-gpu/test_gpu_with_dask.py b/tests/python-gpu/test_gpu_with_dask.py
index 476a9651a258..da8bd6298595 100644
--- a/tests/python-gpu/test_gpu_with_dask.py
+++ b/tests/python-gpu/test_gpu_with_dask.py
@@ -6,7 +6,9 @@
 import asyncio
 import xgboost
 import subprocess
-from hypothesis import given, strategies, settings, note, HealthCheck
+from collections import OrderedDict
+from inspect import signature
+from hypothesis import given, strategies, settings, note
 from hypothesis._settings import duration
 from test_gpu_updaters import parameter_strategy
@@ -18,13 +20,15 @@ from test_with_dask import run_empty_dmatrix_cls  # noqa
 from test_with_dask import _get_client_workers  # noqa
 from test_with_dask import generate_array  # noqa
-from test_with_dask import suppress
+from test_with_dask import kCols as random_cols  # noqa
+from test_with_dask import suppress  # noqa
 import testing as tm  # noqa

 try:
     import dask.dataframe as dd
     from xgboost import dask as dxgb
+    import xgboost as xgb
     from dask.distributed import Client
     from dask import array as da
     from dask_cuda import LocalCUDACluster
@@ -252,6 +256,64 @@ def test_empty_dmatrix(self, local_cuda_cluster: LocalCUDACluster) -> None:
             run_empty_dmatrix_reg(client, parameters)
             run_empty_dmatrix_cls(client, parameters)

+    def test_data_initialization(self, local_cuda_cluster: LocalCUDACluster) -> None:
+        with Client(local_cuda_cluster) as client:
+            X, y, _ = generate_array()
+            fw = da.random.random((random_cols, ))
+            fw = fw - fw.min()
+            m = dxgb.DaskDMatrix(client, X, y, feature_weights=fw)
+
+            workers = list(_get_client_workers(client).keys())
+            rabit_args = client.sync(dxgb._get_rabit_args, len(workers), client)
+
+            def worker_fn(worker_addr: str, data_ref: Dict) -> None:
+                with dxgb.RabitContext(rabit_args):
+                    local_dtrain = dxgb._dmatrix_from_list_of_parts(**data_ref)
+                    fw_rows = local_dtrain.get_float_info("feature_weights").shape[0]
+                    assert fw_rows == local_dtrain.num_col()
+
+            futures = []
+            for i in range(len(workers)):
+                futures.append(client.submit(worker_fn, workers[i],
+                                             m.create_fn_args(workers[i]), pure=False,
+                                             workers=[workers[i]]))
+            client.gather(futures)
+
+    def test_interface_consistency(self) -> None:
+        sig = OrderedDict(signature(dxgb.DaskDMatrix).parameters)
+        del sig["client"]
+        ddm_names = list(sig.keys())
+        sig = OrderedDict(signature(dxgb.DaskDeviceQuantileDMatrix).parameters)
+        del sig["client"]
+        del sig["max_bin"]
+        ddqdm_names = list(sig.keys())
+        assert len(ddm_names) == len(ddqdm_names)
+
+        # between dask
+        for i in range(len(ddm_names)):
+            assert ddm_names[i] == ddqdm_names[i]
+
+        sig = OrderedDict(signature(xgb.DMatrix).parameters)
+        del sig["nthread"]  # no nthread in dask
+        dm_names = list(sig.keys())
+        sig = OrderedDict(signature(xgb.DeviceQuantileDMatrix).parameters)
+        del sig["nthread"]
+        del sig["max_bin"]
+        dqdm_names = list(sig.keys())
+
+        # between single node
+        assert len(dm_names) == len(dqdm_names)
+        for i in range(len(dm_names)):
+            assert dm_names[i] == dqdm_names[i]
+
+        # ddm <-> dm
+        for i in range(len(ddm_names)):
+            assert ddm_names[i] == dm_names[i]
+
+        # dqdm <-> ddqdm
+        for i in range(len(ddqdm_names)):
+            assert ddqdm_names[i] == dqdm_names[i]
+
     def run_quantile(self, name: str, local_cuda_cluster: LocalCUDACluster) -> None:
         if sys.platform.startswith("win"):
             pytest.skip("Skipping dask tests on Windows")
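The same parity check as test_interface_consistency can be run standalone; a minimal sketch against the patched signatures of the two single-node constructors:

    from inspect import signature
    import xgboost as xgb

    dm = list(signature(xgb.DMatrix).parameters)
    dqdm = [p for p in signature(xgb.DeviceQuantileDMatrix).parameters
            if p != "max_bin"]
    # Apart from max_bin, the constructors should agree in names and order.
    assert dm == dqdm, (dm, dqdm)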