From ccc5f057f27686d7724573114eddb0249310ac5b Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Wed, 6 Nov 2024 04:46:49 +0800 Subject: [PATCH] [EM] Add tests for irregular data shapes. (#10980) - More tests. - Recommend arena in the document. --- demo/dask/forward_logging.py | 3 ++- demo/guide-python/distributed_extmem_basic.py | 22 ++++++++++++++----- doc/conf.py | 1 + doc/tutorials/external_memory.rst | 18 +++++++++------ python-package/xgboost/testing/data_iter.py | 20 ++++++++++++++++- src/data/ellpack_page_source.cu | 5 +++-- tests/python-gpu/test_gpu_data_iterator.py | 6 ++++- tests/python/test_data_iterator.py | 7 +++++- 8 files changed, 64 insertions(+), 18 deletions(-) diff --git a/demo/dask/forward_logging.py b/demo/dask/forward_logging.py index d49d8c1cbfe6..37189e8a429a 100644 --- a/demo/dask/forward_logging.py +++ b/demo/dask/forward_logging.py @@ -1,4 +1,5 @@ -"""Example of forwarding evaluation logs to the client +""" +Example of forwarding evaluation logs to the client =================================================== The example runs on GPU. Two classes are defined to show how to use Dask builtins to diff --git a/demo/guide-python/distributed_extmem_basic.py b/demo/guide-python/distributed_extmem_basic.py index 2ee9b33f6684..00f4fd59c68a 100644 --- a/demo/guide-python/distributed_extmem_basic.py +++ b/demo/guide-python/distributed_extmem_basic.py @@ -13,6 +13,7 @@ If `device` is `cuda`, following are also needed: - cupy +- python-cuda - rmm """ @@ -104,11 +105,22 @@ def setup_rmm() -> None: if not xgboost.build_info()["USE_RMM"]: return - # The combination of pool and async is by design. As XGBoost needs to allocate large - # pages repeatly, it's not easy to handle fragmentation. We can use more experiments - # here. - mr = rmm.mr.PoolMemoryResource(rmm.mr.CudaAsyncMemoryResource()) - rmm.mr.set_current_device_resource(mr) + try: + from cuda import cudart + from rmm.mr import ArenaMemoryResource + + status, free, total = cudart.cudaMemGetInfo() + if status != cudart.cudaError_t.cudaSuccess: + raise RuntimeError(cudart.cudaGetErrorString(status)) + + mr = rmm.mr.CudaMemoryResource() + mr = ArenaMemoryResource(mr, arena_size=int(total * 0.9)) + except ImportError: + # The combination of pool and async is by design. As XGBoost needs to allocate + # large pages repeatly, it's not easy to handle fragmentation. We can use more + # experiments here. + mr = rmm.mr.PoolMemoryResource(rmm.mr.CudaAsyncMemoryResource()) + rmm.mr.set_current_device_resource(mr) # Set the allocator for cupy as well. cp.cuda.set_allocator(rmm_cupy_allocator) diff --git a/doc/conf.py b/doc/conf.py index a2546cbbc336..89dc0f4eaee2 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -294,6 +294,7 @@ def is_readthedocs_build(): "dask": ("https://docs.dask.org/en/stable/", None), "distributed": ("https://distributed.dask.org/en/stable/", None), "pyspark": ("https://spark.apache.org/docs/latest/api/python/", None), + "rmm": ("https://docs.rapids.ai/api/rmm/nightly/", None), } diff --git a/doc/tutorials/external_memory.rst b/doc/tutorials/external_memory.rst index c0fa7fd98769..bbdd9f20df2b 100644 --- a/doc/tutorials/external_memory.rst +++ b/doc/tutorials/external_memory.rst @@ -138,6 +138,8 @@ the GPU. Following is a snippet from :ref:`sphx_glr_python_examples_external_mem # It's important to use RMM for GPU-based external memory to improve performance. # If XGBoost is not built with RMM support, a warning will be raised. + # We use the pool memory resource here, you can also try the `ArenaMemoryResource` for + # improved memory fragmentation handling. mr = rmm.mr.PoolMemoryResource(rmm.mr.CudaAsyncMemoryResource()) rmm.mr.set_current_device_resource(mr) # Set the allocator for cupy as well. @@ -278,13 +280,15 @@ determines the time it takes to run inference, even if a C2C link is available. Xy_valid = xgboost.ExtMemQuantileDMatrix(it_valid, max_bin=n_bins, ref=Xy_train) In addition, since the GPU implementation relies on asynchronous memory pool, which is -subject to memory fragmentation even if the ``CudaAsyncMemoryResource`` is used. You might -want to start the training with a fresh pool instead of starting training right after the -ETL process. If you run into out-of-memory errors and you are convinced that the pool is -not full yet (pool memory usage can be profiled with ``nsight-system``), consider tuning -the RMM memory resource like using ``rmm.mr.CudaAsyncMemoryResource`` in conjunction with -``rmm.mr.BinningMemoryResource(mr, 21, 25)`` instead of the -``rmm.mr.PoolMemoryResource(mr)`` shown in the example. +subject to memory fragmentation even if the :py:class:`~rmm.mr.CudaAsyncMemoryResource` is +used. You might want to start the training with a fresh pool instead of starting training +right after the ETL process. If you run into out-of-memory errors and you are convinced +that the pool is not full yet (pool memory usage can be profiled with ``nsight-system``), +consider tuning the RMM memory resource like using +:py:class:`~rmm.mr.CudaAsyncMemoryResource` in conjunction with +:py:class:`BinningMemoryResource(mr, 21, 25) ` instead of +the :py:class:`~rmm.mr.PoolMemoryResource`. Alternately, the +:py:class:`~rmm.mr.ArenaMemoryResource` is also an excellent option. During CPU benchmarking, we used an NVMe connected to a PCIe-4 slot. Other types of storage can be too slow for practical usage. However, your system will likely perform some diff --git a/python-package/xgboost/testing/data_iter.py b/python-package/xgboost/testing/data_iter.py index 924282dda872..bd612e2c3e84 100644 --- a/python-package/xgboost/testing/data_iter.py +++ b/python-package/xgboost/testing/data_iter.py @@ -6,7 +6,7 @@ from xgboost import testing as tm -from ..core import DataIter, ExtMemQuantileDMatrix, QuantileDMatrix +from ..core import DataIter, DMatrix, ExtMemQuantileDMatrix, QuantileDMatrix def run_mixed_sparsity(device: str) -> None: @@ -78,6 +78,24 @@ def reset(self) -> None: ExtMemQuantileDMatrix(it, enable_categorical=True) +def check_uneven_sizes(device: str) -> None: + """Tests for having irregular data shapes.""" + batches = [ + tm.make_regression(n_samples, 16, use_cupy=device == "cuda") + for n_samples in [512, 256, 1024] + ] + unzip = list(zip(*batches)) + it = tm.IteratorForTest(unzip[0], unzip[1], None, cache="cache", on_host=True) + + Xy = DMatrix(it) + assert Xy.num_col() == 16 + assert Xy.num_row() == sum(x.shape[0] for x in unzip[0]) + + Xy = ExtMemQuantileDMatrix(it) + assert Xy.num_col() == 16 + assert Xy.num_row() == sum(x.shape[0] for x in unzip[0]) + + class CatIter(DataIter): # pylint: disable=too-many-instance-attributes """An iterator for testing categorical features.""" diff --git a/src/data/ellpack_page_source.cu b/src/data/ellpack_page_source.cu index 4901f900a7d5..5fb6fc925111 100644 --- a/src/data/ellpack_page_source.cu +++ b/src/data/ellpack_page_source.cu @@ -404,8 +404,9 @@ void ExtEllpackPageSourceImpl::Fetch() { this->GetCuts()}; this->info_->Extend(proxy_->Info(), false, true); }); - // The size of ellpack is logged in write cache. - LOG(INFO) << "Estimated batch size:" + LOG(INFO) << "Generated an Ellpack page with size: " + << common::HumanMemUnit(this->page_->Impl()->MemCostBytes()) + << " from an batch with estimated size: " << cuda_impl::Dispatch(proxy_, [](auto const& adapter) { return common::HumanMemUnit(adapter->SizeBytes()); }); diff --git a/tests/python-gpu/test_gpu_data_iterator.py b/tests/python-gpu/test_gpu_data_iterator.py index 83bf44ccdef5..b3e7254244b6 100644 --- a/tests/python-gpu/test_gpu_data_iterator.py +++ b/tests/python-gpu/test_gpu_data_iterator.py @@ -7,7 +7,7 @@ import xgboost as xgb from xgboost import testing as tm from xgboost.testing import no_cupy -from xgboost.testing.data_iter import check_invalid_cat_batches +from xgboost.testing.data_iter import check_invalid_cat_batches, check_uneven_sizes from xgboost.testing.updater import ( check_categorical_missing, check_categorical_ohe, @@ -231,3 +231,7 @@ def test_categorical_ohe(tree_method: str) -> None: @pytest.mark.skipif(**tm.no_cupy()) def test_invalid_cat_batches() -> None: check_invalid_cat_batches("cuda") + + +def test_uneven_sizes() -> None: + check_uneven_sizes("cuda") diff --git a/tests/python/test_data_iterator.py b/tests/python/test_data_iterator.py index c15ae77c1e19..545b849b4bdb 100644 --- a/tests/python/test_data_iterator.py +++ b/tests/python/test_data_iterator.py @@ -12,7 +12,7 @@ from xgboost import testing as tm from xgboost.data import SingleBatchInternalIter as SingleBatch from xgboost.testing import IteratorForTest, make_batches, non_increasing -from xgboost.testing.data_iter import check_invalid_cat_batches +from xgboost.testing.data_iter import check_invalid_cat_batches, check_uneven_sizes from xgboost.testing.updater import ( check_categorical_missing, check_categorical_ohe, @@ -375,3 +375,8 @@ def test_categorical_ohe(tree_method: str) -> None: def test_invalid_cat_batches() -> None: check_invalid_cat_batches("cpu") + + +@pytest.mark.skipif(**tm.no_cupy()) +def test_uneven_sizes() -> None: + check_uneven_sizes("cpu")