diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 26f4bc56cc39..87a955372886 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -1726,6 +1726,7 @@ def _init( nthread=self.nthread, cache_prefix=it.cache_prefix if it.cache_prefix else "", on_host=it.on_host, + max_bin=self.max_bin, ) handle = ctypes.c_void_p() reset_callback, next_callback = it.get_callbacks(enable_categorical) diff --git a/python-package/xgboost/testing/updater.py b/python-package/xgboost/testing/updater.py index b92d7a9fa7e3..29125f4f5e41 100644 --- a/python-package/xgboost/testing/updater.py +++ b/python-package/xgboost/testing/updater.py @@ -195,10 +195,12 @@ def check_quantile_loss_extmem( np.testing.assert_allclose(predt, predt_it) -def check_extmem_qdm( +def check_extmem_qdm( # pylint: disable=too-many-arguments n_samples_per_batch: int, n_features: int, + *, n_batches: int, + n_bins: int, device: str, on_host: bool, ) -> None: @@ -212,21 +214,25 @@ def check_extmem_qdm( on_host=on_host, ) - Xy_it = xgb.ExtMemQuantileDMatrix(it) + Xy_it = xgb.ExtMemQuantileDMatrix(it, max_bin=n_bins) with pytest.raises(ValueError, match="Only the `hist`"): booster_it = xgb.train( - {"device": device, "tree_method": "approx"}, Xy_it, num_boost_round=8 + {"device": device, "tree_method": "approx", "max_bin": n_bins}, + Xy_it, + num_boost_round=8, ) - booster_it = xgb.train({"device": device}, Xy_it, num_boost_round=8) + booster_it = xgb.train( + {"device": device, "max_bin": n_bins}, Xy_it, num_boost_round=8 + ) it = tm.IteratorForTest( *tm.make_batches( n_samples_per_batch, n_features, n_batches, use_cupy=device != "cpu" ), cache=None, ) - Xy = xgb.QuantileDMatrix(it) - booster = xgb.train({"device": device}, Xy, num_boost_round=8) + Xy = xgb.QuantileDMatrix(it, max_bin=n_bins) + booster = xgb.train({"device": device, "max_bin": n_bins}, Xy, num_boost_round=8) cut_it = Xy_it.get_quantile_cut() cut = Xy.get_quantile_cut() diff --git a/tests/python-gpu/test_gpu_data_iterator.py b/tests/python-gpu/test_gpu_data_iterator.py index 094b324b5a3f..63333579f140 100644 --- a/tests/python-gpu/test_gpu_data_iterator.py +++ b/tests/python-gpu/test_gpu_data_iterator.py @@ -65,14 +65,26 @@ def test_cpu_data_iterator() -> None: strategies.integers(1, 2048), strategies.integers(1, 8), strategies.integers(1, 4), + strategies.integers(2, 16), strategies.booleans(), ) @settings(deadline=None, max_examples=10, print_blob=True) @pytest.mark.filterwarnings("ignore") def test_extmem_qdm( - n_samples_per_batch: int, n_features: int, n_batches: int, on_host: bool + n_samples_per_batch: int, + n_features: int, + n_batches: int, + n_bins: int, + on_host: bool, ) -> None: - check_extmem_qdm(n_samples_per_batch, n_features, n_batches, "cuda", on_host) + check_extmem_qdm( + n_samples_per_batch, + n_features, + n_batches=n_batches, + n_bins=n_bins, + device="cuda", + on_host=on_host, + ) @pytest.mark.filterwarnings("ignore") diff --git a/tests/python/test_data_iterator.py b/tests/python/test_data_iterator.py index d5498b8523c3..698a644b5d55 100644 --- a/tests/python/test_data_iterator.py +++ b/tests/python/test_data_iterator.py @@ -310,7 +310,17 @@ def test_quantile_objective( strategies.integers(1, 4096), strategies.integers(1, 8), strategies.integers(1, 4), + strategies.integers(2, 16), ) @settings(deadline=None, max_examples=10, print_blob=True) -def test_extmem_qdm(n_samples_per_batch: int, n_features: int, n_batches: int) -> None: - check_extmem_qdm(n_samples_per_batch, n_features, n_batches, "cpu", False) +def test_extmem_qdm( + n_samples_per_batch: int, n_features: int, n_batches: int, n_bins: int +) -> None: + check_extmem_qdm( + n_samples_per_batch, + n_features, + n_batches=n_batches, + n_bins=n_bins, + device="cpu", + on_host=False, + )