Skip to content

Commit

Permalink
[EM] Add tests for irregular data shapes.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Nov 4, 2024
1 parent 27eb304 commit ac08cea
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 5 deletions.
20 changes: 19 additions & 1 deletion python-package/xgboost/testing/data_iter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from xgboost import testing as tm

from ..core import DataIter, ExtMemQuantileDMatrix, QuantileDMatrix
from ..core import DataIter, DMatrix, ExtMemQuantileDMatrix, QuantileDMatrix


def run_mixed_sparsity(device: str) -> None:
Expand Down Expand Up @@ -78,6 +78,24 @@ def reset(self) -> None:
ExtMemQuantileDMatrix(it, enable_categorical=True)


def check_uneven_sizes(device: str) -> None:
    """Check that batches with differing row counts are handled correctly.

    Three regression batches (512, 256 and 1024 samples, 16 features each) are
    served by one test iterator.  Both a ``DMatrix`` and an
    ``ExtMemQuantileDMatrix`` built from that iterator must report 16 columns
    and a row count equal to the sum of the batch sizes.

    Parameters
    ----------
    device :
        When equal to ``"cuda"`` the batches are requested with
        ``use_cupy=True``; otherwise with ``use_cupy=False``.

    """
    on_gpu = device == "cuda"
    batches = []
    for n_samples in (512, 256, 1024):
        batches.append(tm.make_regression(n_samples, 16, use_cupy=on_gpu))

    # Transpose the per-batch tuples into per-field tuples; only the features
    # and labels are used, any trailing fields are ignored.
    features, labels, *_ = zip(*batches)
    expected_rows = sum(x.shape[0] for x in features)

    it = tm.IteratorForTest(features, labels, None, cache="cache", on_host=True)

    # The same iterator is consumed by both matrix types in turn.
    for matrix_cls in (DMatrix, ExtMemQuantileDMatrix):
        Xy = matrix_cls(it)
        assert Xy.num_col() == 16
        assert Xy.num_row() == expected_rows


class CatIter(DataIter): # pylint: disable=too-many-instance-attributes
"""An iterator for testing categorical features."""

Expand Down
5 changes: 3 additions & 2 deletions src/data/ellpack_page_source.cu
Original file line number Diff line number Diff line change
Expand Up @@ -404,8 +404,9 @@ void ExtEllpackPageSourceImpl<F>::Fetch() {
this->GetCuts()};
this->info_->Extend(proxy_->Info(), false, true);
});
// The size of ellpack is logged in write cache.
LOG(INFO) << "Estimated batch size:"
LOG(INFO) << "Generated an Ellpack page with size: "
<< common::HumanMemUnit(this->page_->Impl()->MemCostBytes())
<< " from an batch with estimated size: "
<< cuda_impl::Dispatch<false>(proxy_, [](auto const& adapter) {
return common::HumanMemUnit(adapter->SizeBytes());
});
Expand Down
6 changes: 5 additions & 1 deletion tests/python-gpu/test_gpu_data_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import xgboost as xgb
from xgboost import testing as tm
from xgboost.testing import no_cupy
from xgboost.testing.data_iter import check_invalid_cat_batches
from xgboost.testing.data_iter import check_invalid_cat_batches, check_uneven_sizes
from xgboost.testing.updater import (
check_categorical_missing,
check_categorical_ohe,
Expand Down Expand Up @@ -231,3 +231,7 @@ def test_categorical_ohe(tree_method: str) -> None:
@pytest.mark.skipif(**tm.no_cupy())
def test_invalid_cat_batches() -> None:
    """Run the shared invalid-categorical-batch check for the ``"cuda"`` device."""
    check_invalid_cat_batches("cuda")


@pytest.mark.skipif(**tm.no_cupy())
def test_uneven_sizes() -> None:
    """Run the shared irregular-batch-size check for the ``"cuda"`` device.

    The shared check requests cupy-backed batches when the device is
    ``"cuda"``, so the test is skipped when cupy is unavailable instead of
    failing on import.
    """
    check_uneven_sizes("cuda")
7 changes: 6 additions & 1 deletion tests/python/test_data_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from xgboost import testing as tm
from xgboost.data import SingleBatchInternalIter as SingleBatch
from xgboost.testing import IteratorForTest, make_batches, non_increasing
from xgboost.testing.data_iter import check_invalid_cat_batches
from xgboost.testing.data_iter import check_invalid_cat_batches, check_uneven_sizes
from xgboost.testing.updater import (
check_categorical_missing,
check_categorical_ohe,
Expand Down Expand Up @@ -375,3 +375,8 @@ def test_categorical_ohe(tree_method: str) -> None:

def test_invalid_cat_batches() -> None:
    """Run the shared invalid-categorical-batch check for the ``"cpu"`` device."""
    check_invalid_cat_batches("cpu")


def test_uneven_sizes() -> None:
    """Run the shared irregular-batch-size check for the ``"cpu"`` device.

    No cupy skip guard is applied: the shared check only requests cupy-backed
    batches when the device is ``"cuda"``, so the CPU variant should run even
    when cupy is not installed.
    """
    check_uneven_sizes("cpu")

0 comments on commit ac08cea

Please sign in to comment.