Skip to content

Commit

Permalink
[EM] Fix the max bin parameter. (#10886)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis authored Oct 11, 2024
1 parent 59e6c92 commit 24aeaf4
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 10 deletions.
1 change: 1 addition & 0 deletions python-package/xgboost/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1726,6 +1726,7 @@ def _init(
nthread=self.nthread,
cache_prefix=it.cache_prefix if it.cache_prefix else "",
on_host=it.on_host,
max_bin=self.max_bin,
)
handle = ctypes.c_void_p()
reset_callback, next_callback = it.get_callbacks(enable_categorical)
Expand Down
18 changes: 12 additions & 6 deletions python-package/xgboost/testing/updater.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,10 +195,12 @@ def check_quantile_loss_extmem(
np.testing.assert_allclose(predt, predt_it)


def check_extmem_qdm(
def check_extmem_qdm( # pylint: disable=too-many-arguments
n_samples_per_batch: int,
n_features: int,
*,
n_batches: int,
n_bins: int,
device: str,
on_host: bool,
) -> None:
Expand All @@ -212,21 +214,25 @@ def check_extmem_qdm(
on_host=on_host,
)

Xy_it = xgb.ExtMemQuantileDMatrix(it)
Xy_it = xgb.ExtMemQuantileDMatrix(it, max_bin=n_bins)
with pytest.raises(ValueError, match="Only the `hist`"):
booster_it = xgb.train(
{"device": device, "tree_method": "approx"}, Xy_it, num_boost_round=8
{"device": device, "tree_method": "approx", "max_bin": n_bins},
Xy_it,
num_boost_round=8,
)

booster_it = xgb.train({"device": device}, Xy_it, num_boost_round=8)
booster_it = xgb.train(
{"device": device, "max_bin": n_bins}, Xy_it, num_boost_round=8
)
it = tm.IteratorForTest(
*tm.make_batches(
n_samples_per_batch, n_features, n_batches, use_cupy=device != "cpu"
),
cache=None,
)
Xy = xgb.QuantileDMatrix(it)
booster = xgb.train({"device": device}, Xy, num_boost_round=8)
Xy = xgb.QuantileDMatrix(it, max_bin=n_bins)
booster = xgb.train({"device": device, "max_bin": n_bins}, Xy, num_boost_round=8)

cut_it = Xy_it.get_quantile_cut()
cut = Xy.get_quantile_cut()
Expand Down
16 changes: 14 additions & 2 deletions tests/python-gpu/test_gpu_data_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,26 @@ def test_cpu_data_iterator() -> None:
strategies.integers(1, 2048),
strategies.integers(1, 8),
strategies.integers(1, 4),
strategies.integers(2, 16),
strategies.booleans(),
)
@settings(deadline=None, max_examples=10, print_blob=True)
@pytest.mark.filterwarnings("ignore")
def test_extmem_qdm(
n_samples_per_batch: int, n_features: int, n_batches: int, on_host: bool
n_samples_per_batch: int,
n_features: int,
n_batches: int,
n_bins: int,
on_host: bool,
) -> None:
check_extmem_qdm(n_samples_per_batch, n_features, n_batches, "cuda", on_host)
check_extmem_qdm(
n_samples_per_batch,
n_features,
n_batches=n_batches,
n_bins=n_bins,
device="cuda",
on_host=on_host,
)


@pytest.mark.filterwarnings("ignore")
Expand Down
14 changes: 12 additions & 2 deletions tests/python/test_data_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,17 @@ def test_quantile_objective(
strategies.integers(1, 4096),
strategies.integers(1, 8),
strategies.integers(1, 4),
strategies.integers(2, 16),
)
@settings(deadline=None, max_examples=10, print_blob=True)
def test_extmem_qdm(n_samples_per_batch: int, n_features: int, n_batches: int) -> None:
check_extmem_qdm(n_samples_per_batch, n_features, n_batches, "cpu", False)
def test_extmem_qdm(
n_samples_per_batch: int, n_features: int, n_batches: int, n_bins: int
) -> None:
check_extmem_qdm(
n_samples_per_batch,
n_features,
n_batches=n_batches,
n_bins=n_bins,
device="cpu",
on_host=False,
)

0 comments on commit 24aeaf4

Please sign in to comment.