From bdd5e5913b2a5ff92fa98d77b6e76924d5afc05a Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Tue, 24 Sep 2024 19:45:55 +0800 Subject: [PATCH] [EM] Concatenate ellpack pages for ExtMemQdm. The concatenation only happens if ExtMemQdm is used with host cache. We will work on other types of cache for consistency in the future. For now, the host cache in ExtMemQdm is the most important use case. --- demo/guide-python/external_memory.py | 32 +++- doc/tutorials/external_memory.rst | 7 +- include/xgboost/c_api.h | 3 + include/xgboost/data.h | 26 ++- python-package/xgboost/core.py | 20 +++ python-package/xgboost/testing/__init__.py | 7 +- src/c_api/c_api.cc | 123 ++++++++----- src/common/cuda_rt_utils.cc | 7 + src/common/cuda_rt_utils.h | 5 + src/common/device_helpers.cuh | 8 - src/common/device_vector.cuh | 3 +- src/common/hist_util.cu | 3 +- src/data/batch_utils.h | 7 + src/data/data.cc | 34 ++-- src/data/ellpack_page.cu | 31 ++-- src/data/ellpack_page.cuh | 22 ++- src/data/ellpack_page_raw_format.cu | 8 +- src/data/ellpack_page_source.cu | 167 +++++++++++++----- src/data/ellpack_page_source.h | 139 +++++++++++---- src/data/extmem_quantile_dmatrix.cc | 24 ++- src/data/extmem_quantile_dmatrix.cu | 31 +++- src/data/extmem_quantile_dmatrix.h | 8 +- src/data/proxy_dmatrix.h | 3 +- src/data/quantile_dmatrix.cc | 4 +- src/data/quantile_dmatrix.cu | 16 +- src/data/sparse_page_dmatrix.cc | 55 +++--- src/data/sparse_page_dmatrix.cu | 24 ++- src/data/sparse_page_dmatrix.h | 9 +- src/data/sparse_page_source.h | 13 +- tests/cpp/c_api/test_c_api.cc | 19 +- tests/cpp/data/test_ellpack_page.cu | 5 +- .../cpp/data/test_ellpack_page_raw_format.cu | 29 ++- .../cpp/data/test_extmem_quantile_dmatrix.cu | 71 ++++++++ tests/cpp/data/test_sparse_page_dmatrix.cc | 24 +-- tests/cpp/data/test_sparse_page_dmatrix.cu | 5 +- tests/cpp/helpers.cc | 25 ++- tests/cpp/helpers.h | 10 +- tests/python-gpu/test_gpu_data_iterator.py | 30 +++- 38 files changed, 750 insertions(+), 307 deletions(-) diff --git a/demo/guide-python/external_memory.py b/demo/guide-python/external_memory.py index 3f5e0a892b8b..f32abef5b902 100644 --- a/demo/guide-python/external_memory.py +++ b/demo/guide-python/external_memory.py @@ -142,21 +142,35 @@ def main(tmpdir: str, args: argparse.Namespace) -> None: approx_train(it) +def setup_rmm() -> None: + """Setup RMM for GPU-based external memory training.""" + import rmm + from cuda import cudart + from rmm.allocators.cupy import rmm_cupy_allocator + + if not xgboost.build_info()["USE_RMM"]: + return + + # The combination of pool and async is by design. As XGBoost needs to allocate large + # pages repeatly, it's not easy to handle fragmentation. We can use more experiments + # here. + mr = rmm.mr.PoolMemoryResource(rmm.mr.CudaAsyncMemoryResource()) + rmm.mr.set_current_device_resource(mr) + # Set the allocator for cupy as well. + cp.cuda.set_allocator(rmm_cupy_allocator) + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--device", choices=["cpu", "cuda"], default="cpu") args = parser.parse_args() if args.device == "cuda": import cupy as cp - import rmm - from rmm.allocators.cupy import rmm_cupy_allocator - - # It's important to use RMM for GPU-based external memory to improve performance. - # If XGBoost is not built with RMM support, a warning will be raised. - mr = rmm.mr.CudaAsyncMemoryResource() - rmm.mr.set_current_device_resource(mr) - # Set the allocator for cupy as well. - cp.cuda.set_allocator(rmm_cupy_allocator) + + # It's important to use RMM with `CudaAsyncMemoryResource`. for GPU-based + # external memory to improve performance. If XGBoost is not built with RMM + # support, a warning is raised when constructing the `DMatrix`. + setup_rmm() # Make sure XGBoost is using RMM for all allocations. with xgboost.config_context(use_rmm=True): with tempfile.TemporaryDirectory() as tmpdir: diff --git a/doc/tutorials/external_memory.rst b/doc/tutorials/external_memory.rst index 80b91775c1a9..f12f8f60bf87 100644 --- a/doc/tutorials/external_memory.rst +++ b/doc/tutorials/external_memory.rst @@ -134,7 +134,7 @@ the GPU. This is a current limitation we aim to address in the future. # It's important to use RMM for GPU-based external memory to improve performance. # If XGBoost is not built with RMM support, a warning will be raised. - mr = rmm.mr.CudaAsyncMemoryResource() + mr = rmm.mr.PoolMemoryResource(rmm.mr.CudaAsyncMemoryResource()) rmm.mr.set_current_device_resource(mr) # Set the allocator for cupy as well. cp.cuda.set_allocator(rmm_cupy_allocator) @@ -159,9 +159,8 @@ the GPU. This is a current limitation we aim to address in the future. It's crucial to use `RAPIDS Memory Manager (RMM) `__ for all memory allocation when training with external memory. XGBoost relies on the memory -pool to reduce the overhead for data fetching. The size of each batch should be slightly -smaller than a quarter of the available GPU memory. In addition, the open source `NVIDIA -Linux driver +pool to reduce the overhead for data fetching. In addition, the open source `NVIDIA Linux +driver `__ is required for ``Heterogeneous memory management (HMM)`` support. diff --git a/include/xgboost/c_api.h b/include/xgboost/c_api.h index c4ab4f2467c0..b2eca8dafb3d 100644 --- a/include/xgboost/c_api.h +++ b/include/xgboost/c_api.h @@ -523,6 +523,9 @@ XGB_DLL int XGQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHand * - max_bin (optional): Maximum number of bins for building histogram. Must be consistent with the corresponding booster training parameter. * - on_host (optional): Whether the data should be placed on host memory. Used by GPU inputs. + * - min_cache_page_bytes (optional): The minimum number of bytes for each internal GPU + * page. Set to 0 to disable page concatenation. Automatic configuration if the + * parameter is not provided or set to None. * @param out The created Quantile DMatrix. * * @return 0 when success, -1 when failure happens diff --git a/include/xgboost/data.h b/include/xgboost/data.h index 2254a348b282..d8922b28cb29 100644 --- a/include/xgboost/data.h +++ b/include/xgboost/data.h @@ -517,6 +517,20 @@ class BatchSet { struct XGBAPIThreadLocalEntry; +struct ExtMemConfig { + // Cache prefix, not used if the cache is in the host memory. + std::string cache; + // Whether the ellpack page is stored in the host memory. + bool on_host{true}; + // Minimum number of of bytes for each ellpack page in cache. Only used for in-host + // ExtMemQdm. + bst_idx_t min_cache_page_bytes{0}; + // Missing value. + float missing{std::numeric_limits::quiet_NaN()}; + // The number of threads for CPU. + std::int32_t n_threads{0}; +}; + /** * @brief Internal data structured used by XGBoost to hold all external data. * @@ -637,18 +651,14 @@ class DMatrix { * @param proxy A hanlde to ProxyDMatrix * @param reset Callback for reset * @param next Callback for next - * @param missing Value that should be treated as missing. - * @param nthread number of threads used for initialization. - * @param cache Prefix of cache file path. - * @param on_host Used for GPU, whether the data should be cached on host memory. + * @param config Configuration for the cache. * * @return A created external memory DMatrix. */ template static DMatrix* Create(DataIterHandle iter, DMatrixHandle proxy, DataIterResetCallback* reset, - XGDMatrixCallbackNext* next, float missing, std::int32_t nthread, - std::string cache, bool on_host); + XGDMatrixCallbackNext* next, ExtMemConfig const& config); /** * @brief Create an external memory quantile DMatrix with callbacks. @@ -660,8 +670,8 @@ class DMatrix { template static DMatrix* Create(DataIterHandle iter, DMatrixHandle proxy, std::shared_ptr ref, - DataIterResetCallback* reset, XGDMatrixCallbackNext* next, float missing, - std::int32_t nthread, bst_bin_t max_bin, std::string cache, bool on_host); + DataIterResetCallback* reset, XGDMatrixCallbackNext* next, + bst_bin_t max_bin, ExtMemConfig const& config); virtual DMatrix *Slice(common::Span ridxs) = 0; diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 87a955372886..c182f723cfe4 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -536,16 +536,34 @@ class DataIter(ABC): # pylint: disable=too-many-instance-attributes This is an experimental parameter. + min_cache_page_bytes : + Only used for on-host cache with GPU. The minimum number of bytes of each cached + pages. When using GPU-based external memory with data cached in the host memory, + XGBoost can concatenate the pages internally to increase the batch size for the + GPU. The page concatenation is enabled by default when all conditions are + satisfied and is set to about a 1/8 of the total device memory. Users can + manually set the values based on the actual hardware. If it's set to 0, then no + page concatenation is performed. + + .. versionadded:: 3.0.0 + + .. warning:: + + This is an experimental parameter. + """ def __init__( self, cache_prefix: Optional[str] = None, release_data: bool = True, + *, on_host: bool = True, + min_cache_page_bytes: Optional[int] = None, ) -> None: self.cache_prefix = cache_prefix self.on_host = on_host + self.min_cache_page_bytes = min_cache_page_bytes self._handle = _ProxyDMatrix() self._exception: Optional[Exception] = None @@ -940,6 +958,7 @@ def _init_from_iter(self, it: DataIter, enable_categorical: bool) -> None: nthread=self.nthread, cache_prefix=it.cache_prefix if it.cache_prefix else "", on_host=it.on_host, + min_cache_page_bytes=it.min_cache_page_bytes, ) handle = ctypes.c_void_p() reset_callback, next_callback = it.get_callbacks(enable_categorical) @@ -1727,6 +1746,7 @@ def _init( cache_prefix=it.cache_prefix if it.cache_prefix else "", on_host=it.on_host, max_bin=self.max_bin, + min_cache_page_bytes=it.min_cache_page_bytes, ) handle = ctypes.c_void_p() reset_callback, next_callback = it.get_callbacks(enable_categorical) diff --git a/python-package/xgboost/testing/__init__.py b/python-package/xgboost/testing/__init__.py index 9d64616ea9f6..f707cfb1f0a5 100644 --- a/python-package/xgboost/testing/__init__.py +++ b/python-package/xgboost/testing/__init__.py @@ -227,13 +227,18 @@ def __init__( # pylint: disable=too-many-arguments *, cache: Optional[str], on_host: bool = False, + min_cache_page_bytes: Optional[int] = None, ) -> None: assert len(X) == len(y) self.X = X self.y = y self.w = w self.it = 0 - super().__init__(cache_prefix=cache, on_host=on_host) + super().__init__( + cache_prefix=cache, + on_host=on_host, + min_cache_page_bytes=min_cache_page_bytes, + ) def next(self, input_data: Callable) -> bool: if self.it == len(self.X): diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index d2610ff1586b..49b03731ad07 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -3,47 +3,48 @@ */ #include "xgboost/c_api.h" -#include // for copy, transform -#include // for strtoimax -#include // for nan -#include // for strcmp -#include // for numeric_limits -#include // for operator!=, _Rb_tree_const_iterator, _Rb_tre... -#include // for shared_ptr, allocator, __shared_ptr_access -#include // for char_traits, basic_string, operator==, string -#include // for errc -#include // for pair -#include // for vector - -#include "../common/api_entry.h" // for XGBAPIThreadLocalEntry -#include "../common/charconv.h" // for from_chars, to_chars, NumericLimits, from_ch... -#include "../common/error_msg.h" // for NoFederated -#include "../common/hist_util.h" // for HistogramCuts -#include "../common/io.h" // for FileExtension, LoadSequentialFile, MemoryBuf... -#include "../common/threading_utils.h" // for OmpGetNumThreads, ParallelFor -#include "../data/adapter.h" // for ArrayAdapter, DenseAdapter, RecordBatchesIte... -#include "../data/ellpack_page.h" // for EllpackPage -#include "../data/proxy_dmatrix.h" // for DMatrixProxy -#include "../data/simple_dmatrix.h" // for SimpleDMatrix -#include "c_api_error.h" // for xgboost_CHECK_C_ARG_PTR, API_END, API_BEGIN -#include "c_api_utils.h" // for RequiredArg, OptionalArg, GetMissing, CastDM... -#include "dmlc/base.h" // for BeginPtr -#include "dmlc/io.h" // for Stream -#include "dmlc/parameter.h" // for FieldAccessEntry, FieldEntry, ParamManager -#include "dmlc/thread_local.h" // for ThreadLocalStore -#include "xgboost/base.h" // for bst_ulong, bst_float, GradientPair, bst_feat... -#include "xgboost/context.h" // for Context -#include "xgboost/data.h" // for DMatrix, MetaInfo, DataType, ExtSparsePage -#include "xgboost/feature_map.h" // for FeatureMap -#include "xgboost/global_config.h" // for GlobalConfiguration, GlobalConfigThreadLocal... -#include "xgboost/host_device_vector.h" // for HostDeviceVector -#include "xgboost/json.h" // for Json, get, Integer, IsA, Boolean, String -#include "xgboost/learner.h" // for Learner, PredictionType -#include "xgboost/logging.h" // for LOG_FATAL, LogMessageFatal, CHECK, LogCheck_EQ -#include "xgboost/predictor.h" // for PredictionCacheEntry -#include "xgboost/span.h" // for Span -#include "xgboost/string_view.h" // for StringView, operator<< -#include "xgboost/version_config.h" // for XGBOOST_VER_MAJOR, XGBOOST_VER_MINOR, XGBOOS... +#include // for copy, transform +#include // for strtoimax +#include // for nan +#include // for strcmp +#include // for numeric_limits +#include // for operator!=, _Rb_tree_const_iterator, _Rb_tre... +#include // for shared_ptr, allocator, __shared_ptr_access +#include // for char_traits, basic_string, operator==, string +#include // for errc +#include // for pair +#include // for vector + +#include "../common/api_entry.h" // for XGBAPIThreadLocalEntry +#include "../common/charconv.h" // for from_chars, to_chars, NumericLimits, from_ch... +#include "../common/error_msg.h" // for NoFederated +#include "../common/hist_util.h" // for HistogramCuts +#include "../common/io.h" // for FileExtension, LoadSequentialFile, MemoryBuf... +#include "../common/threading_utils.h" // for OmpGetNumThreads, ParallelFor +#include "../data/adapter.h" // for ArrayAdapter, DenseAdapter, RecordBatchesIte... +#include "../data/batch_utils.h" // for MatchingPageBytes, CachePageRatio +#include "../data/ellpack_page.h" // for EllpackPage +#include "../data/proxy_dmatrix.h" // for DMatrixProxy +#include "../data/simple_dmatrix.h" // for SimpleDMatrix +#include "c_api_error.h" // for xgboost_CHECK_C_ARG_PTR, API_END, API_BEGIN +#include "c_api_utils.h" // for RequiredArg, OptionalArg, GetMissing, CastDM... +#include "dmlc/base.h" // for BeginPtr +#include "dmlc/io.h" // for Stream +#include "dmlc/parameter.h" // for FieldAccessEntry, FieldEntry, ParamManager +#include "dmlc/thread_local.h" // for ThreadLocalStore +#include "xgboost/base.h" // for bst_ulong, bst_float, GradientPair, bst_feat... +#include "xgboost/context.h" // for Context +#include "xgboost/data.h" // for DMatrix, MetaInfo, DataType, ExtSparsePage +#include "xgboost/feature_map.h" // for FeatureMap +#include "xgboost/global_config.h" // for GlobalConfiguration, GlobalConfigThreadLocal... +#include "xgboost/host_device_vector.h" // for HostDeviceVector +#include "xgboost/json.h" // for Json, get, Integer, IsA, Boolean, String +#include "xgboost/learner.h" // for Learner, PredictionType +#include "xgboost/logging.h" // for LOG_FATAL, LogMessageFatal, CHECK, LogCheck_EQ +#include "xgboost/predictor.h" // for PredictionCacheEntry +#include "xgboost/span.h" // for Span +#include "xgboost/string_view.h" // for StringView, operator<< +#include "xgboost/version_config.h" // for XGBOOST_VER_MAJOR, XGBOOST_VER_MINOR, XGBOOS... using namespace xgboost; // NOLINT(*); @@ -286,6 +287,20 @@ XGB_DLL int XGDMatrixCreateFromCudaArrayInterface(char const *, char const *, DM #endif +namespace { +// Default value for the ratio between a single cached page and the device memory size. +[[nodiscard]] bst_idx_t DftMinCachePageBytes(Json const &jconfig) { + // Set to 0 if it should match the user input size. + auto min_cache_page_bytes = + OptionalArg(jconfig, "min_cache_page_bytes", -1); + if (min_cache_page_bytes == -1) { + double n_total_bytes = curt::TotalMemory(); + min_cache_page_bytes = n_total_bytes * xgboost::cuda_impl::CachePageRatio(); + } + return min_cache_page_bytes; +} +} // namespace + // Create from data iterator XGB_DLL int XGDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy, DataIterResetCallback *reset, XGDMatrixCallbackNext *next, @@ -296,15 +311,25 @@ XGB_DLL int XGDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy auto jconfig = Json::Load(StringView{config}); auto missing = GetMissing(jconfig); std::string cache = RequiredArg(jconfig, "cache_prefix", __func__); - auto n_threads = OptionalArg(jconfig, "nthread", 0); + std::int32_t n_threads = OptionalArg(jconfig, "nthread", 0); auto on_host = OptionalArg(jconfig, "on_host", false); + bst_idx_t min_cache_page_bytes = OptionalArg( + jconfig, "min_cache_page_bytes", cuda_impl::MatchingPageBytes()); + CHECK_EQ(min_cache_page_bytes, cuda_impl::MatchingPageBytes()) + << "Page concatenation is not supported by the DMatrix yet."; xgboost_CHECK_C_ARG_PTR(next); xgboost_CHECK_C_ARG_PTR(reset); xgboost_CHECK_C_ARG_PTR(out); + ExtMemConfig config{.cache = cache, + .on_host = on_host, + .min_cache_page_bytes = min_cache_page_bytes, + .missing = missing, + .n_threads = n_threads}; + *out = new std::shared_ptr{ - xgboost::DMatrix::Create(iter, proxy, reset, next, missing, n_threads, cache, on_host)}; + xgboost::DMatrix::Create(iter, proxy, reset, next, config)}; API_END(); } @@ -368,17 +393,23 @@ XGB_DLL int XGExtMemQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatr xgboost_CHECK_C_ARG_PTR(config); auto jconfig = Json::Load(StringView{config}); auto missing = GetMissing(jconfig); - auto n_threads = OptionalArg(jconfig, "nthread", 0); + std::int32_t n_threads = OptionalArg(jconfig, "nthread", 0); auto max_bin = OptionalArg(jconfig, "max_bin", 256); auto on_host = OptionalArg(jconfig, "on_host", false); std::string cache = RequiredArg(jconfig, "cache_prefix", __func__); + bst_idx_t min_cache_page_bytes = DftMinCachePageBytes(jconfig); xgboost_CHECK_C_ARG_PTR(next); xgboost_CHECK_C_ARG_PTR(reset); xgboost_CHECK_C_ARG_PTR(out); - *out = new std::shared_ptr{xgboost::DMatrix::Create( - iter, proxy, p_ref, reset, next, missing, n_threads, max_bin, cache, on_host)}; + ExtMemConfig config{.cache = cache, + .on_host = on_host, + .min_cache_page_bytes = min_cache_page_bytes, + .missing = missing, + .n_threads = n_threads}; + *out = new std::shared_ptr{ + xgboost::DMatrix::Create(iter, proxy, p_ref, reset, next, max_bin, config)}; API_END(); } diff --git a/src/common/cuda_rt_utils.cc b/src/common/cuda_rt_utils.cc index 53a4105dcb5f..66de8dd4d311 100644 --- a/src/common/cuda_rt_utils.cc +++ b/src/common/cuda_rt_utils.cc @@ -65,6 +65,13 @@ void SetDevice(std::int32_t device) { } } +[[nodiscard]] std::size_t TotalMemory() { + std::size_t device_free = 0; + std::size_t device_total = 0; + dh::safe_cuda(cudaMemGetInfo(&device_free, &device_total)); + return device_total; +} + namespace { template void GetVersionImpl(Fn&& fn, std::int32_t* major, std::int32_t* minor) { diff --git a/src/common/cuda_rt_utils.h b/src/common/cuda_rt_utils.h index 0fac7e35ef3e..cb57090f59db 100644 --- a/src/common/cuda_rt_utils.h +++ b/src/common/cuda_rt_utils.h @@ -24,6 +24,11 @@ void CheckComputeCapability(); void SetDevice(std::int32_t device); +/** + * @brief Total device memory size. + */ +[[nodiscard]] std::size_t TotalMemory(); + // Returns the CUDA Runtime version. void RtVersion(std::int32_t* major, std::int32_t* minor); diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh index 1678c8786010..b15b1f28c23a 100644 --- a/src/common/device_helpers.cuh +++ b/src/common/device_helpers.cuh @@ -120,14 +120,6 @@ inline auto GetDevice(xgboost::Context const *ctx) { return d; } -inline size_t TotalMemory(int device_idx) { - size_t device_free = 0; - size_t device_total = 0; - safe_cuda(cudaSetDevice(device_idx)); - dh::safe_cuda(cudaMemGetInfo(&device_free, &device_total)); - return device_total; -} - /** * \fn inline int MaxSharedMemory(int device_idx) * diff --git a/src/common/device_vector.cuh b/src/common/device_vector.cuh index 6daa4f565ced..308efd6d8004 100644 --- a/src/common/device_vector.cuh +++ b/src/common/device_vector.cuh @@ -228,8 +228,7 @@ class GrowOnlyVirtualMemVec { auto CreatePhysicalMem(std::size_t size) const { CUmemGenericAllocationHandle alloc_handle; auto padded_size = RoundUp(size, this->granularity_); - CUresult status = this->cu_.cuMemCreate(&alloc_handle, padded_size, &this->prop_, 0); - CHECK_EQ(status, CUDA_SUCCESS); + safe_cu(this->cu_.cuMemCreate(&alloc_handle, padded_size, &this->prop_, 0)); return alloc_handle; } void Reserve(std::size_t new_size); diff --git a/src/common/hist_util.cu b/src/common/hist_util.cu index 8b80e86ec2b7..a22e8263ce43 100644 --- a/src/common/hist_util.cu +++ b/src/common/hist_util.cu @@ -86,8 +86,9 @@ bst_idx_t SketchBatchNumElements(bst_idx_t sketch_batch_num_elements, SketchShap size_t num_cuts, bool has_weight, std::size_t container_bytes) { auto constexpr kIntMax = static_cast(std::numeric_limits::max()); #if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1 + (void)device; // Device available memory is not accurate when rmm is used. - double total_mem = dh::TotalMemory(device) - container_bytes; + double total_mem = curt::TotalMemory() - container_bytes; double total_f32 = total_mem / sizeof(float); double n_max_used_f32 = std::max(total_f32 / 16.0, 1.0); // a quarter if (shape.nnz > shape.Size()) { diff --git a/src/data/batch_utils.h b/src/data/batch_utils.h index dbd0b3fb9614..73b0fbcc6d90 100644 --- a/src/data/batch_utils.h +++ b/src/data/batch_utils.h @@ -35,4 +35,11 @@ inline bool RegenGHist(BatchParam old, BatchParam p) { */ void CheckParam(BatchParam const& init, BatchParam const& param); } // namespace xgboost::data::detail + +namespace xgboost::cuda_impl { +// Indicator for XGBoost to not concatenate any page. +constexpr bst_idx_t MatchingPageBytes() { return 0; } +// Default size of the cached page +constexpr double CachePageRatio() { return 0.125; } +} // namespace xgboost::cuda_impl #endif // XGBOOST_DATA_BATCH_UTILS_H_ diff --git a/src/data/data.cc b/src/data/data.cc index 3f78697e7c37..14d056fe2c62 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -35,6 +35,7 @@ #include "../data/iterative_dmatrix.h" // for IterativeDMatrix #include "./sparse_page_dmatrix.h" // for SparsePageDMatrix #include "array_interface.h" // for ArrayInterfaceHandler, ArrayInterface, Dispa... +#include "batch_utils.h" // for MatchingPageBytes #include "dmlc/base.h" // for BeginPtr #include "dmlc/common.h" // for OMPException #include "dmlc/data.h" // for Parser @@ -914,14 +915,13 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s CHECK(data_split_mode != DataSplitMode::kCol) << "Column-wise data split is not supported for external memory."; data::FileIterator iter{fname, static_cast(partid), static_cast(npart)}; - dmat = new data::SparsePageDMatrix{&iter, - iter.Proxy(), - data::fileiter::Reset, - data::fileiter::Next, - std::numeric_limits::quiet_NaN(), - 1, - cache_file, - false}; + auto config = ExtMemConfig{.cache = cache_file, + .on_host = false, + .min_cache_page_bytes = cuda_impl::MatchingPageBytes(), + .missing = std::numeric_limits::quiet_NaN(), + .n_threads = 1}; + dmat = new data::SparsePageDMatrix{&iter, iter.Proxy(), data::fileiter::Reset, + data::fileiter::Next, config}; } return dmat; @@ -938,18 +938,16 @@ DMatrix* DMatrix::Create(DataIterHandle iter, DMatrixHandle proxy, std::shared_p template DMatrix* DMatrix::Create(DataIterHandle iter, DMatrixHandle proxy, DataIterResetCallback* reset, - XGDMatrixCallbackNext* next, float missing, int32_t n_threads, - std::string cache, bool on_host) { - return new data::SparsePageDMatrix{iter, proxy, reset, next, missing, n_threads, cache, on_host}; + XGDMatrixCallbackNext* next, ExtMemConfig const& config) { + return new data::SparsePageDMatrix{iter, proxy, reset, next, config}; } template DMatrix* DMatrix::Create(DataIterHandle iter, DMatrixHandle proxy, std::shared_ptr ref, - DataIterResetCallback* reset, XGDMatrixCallbackNext* next, float missing, - std::int32_t nthread, bst_bin_t max_bin, std::string cache, bool on_host) { - return new data::ExtMemQuantileDMatrix{ - iter, proxy, ref, reset, next, missing, nthread, std::move(cache), max_bin, on_host}; + DataIterResetCallback* reset, XGDMatrixCallbackNext* next, + bst_bin_t max_bin, ExtMemConfig const& config) { + return new data::ExtMemQuantileDMatrix{iter, proxy, ref, reset, next, max_bin, config}; } template DMatrix* DMatrix::Create(DataIterHandle iter, DMatrixHandle proxy, DataIterResetCallback* reset, - XGDMatrixCallbackNext* next, float missing, - int32_t n_threads, std::string, bool); + XGDMatrixCallbackNext* next, + ExtMemConfig const&); template DMatrix* DMatrix::Create( DataIterHandle, DMatrixHandle, std::shared_ptr, DataIterResetCallback*, - XGDMatrixCallbackNext*, float, std::int32_t, bst_bin_t, std::string, bool); + XGDMatrixCallbackNext*, bst_bin_t, ExtMemConfig const&); template DMatrix* DMatrix::Create(AdapterT* adapter, float missing, int nthread, const std::string&, diff --git a/src/data/ellpack_page.cu b/src/data/ellpack_page.cu index 7d5b06049002..9673a33e1055 100644 --- a/src/data/ellpack_page.cu +++ b/src/data/ellpack_page.cu @@ -127,7 +127,6 @@ __global__ void CompressBinEllpackKernel( wr.AtomicWriteSymbol(buffer, bin, (irow + base_row) * row_stride + cpr_fidx); } -namespace { // Calculate the number of symbols for the compressed ellpack. Similar to what the CPU // implementation does, we compress the dense data by subtracting the bin values with the // starting bin of its feature if it's dense. In addition, we treat the data as dense if @@ -176,7 +175,6 @@ namespace { return {row_stride, n_symbols}; } } -} // namespace // Construct an ELLPACK matrix with the given number of empty rows. EllpackPageImpl::EllpackPageImpl(Context const* ctx, @@ -486,7 +484,14 @@ EllpackPageImpl::EllpackPageImpl(Context const* ctx, GHistIndexMatrix const& pag EllpackPageImpl::~EllpackPageImpl() noexcept(false) { // Sync the stream to make sure all running CUDA kernels finish before deallocation. - dh::DefaultStream().Sync(); + auto status = dh::DefaultStream().Sync(false); + if (status != cudaSuccess) { + auto str = cudaGetErrorString(status); + // For external-memory, throwing here can trigger a series of calls to + // `std::terminate` by various destructors. For now, we just log the error. + LOG(WARNING) << "Ran into CUDA error:" << str << "\nXGBoost is likely to abort."; + } + dh::safe_cuda(status); } // A functor that copies the data from one EllpackPage to another. @@ -503,7 +508,7 @@ struct CopyPage { src_iterator_d{src->gidx_buffer.data(), src->NumSymbols()}, offset{offset} {} - __device__ void operator()(size_t element_id) { + __device__ void operator()(std::size_t element_id) { cbw.AtomicWriteSymbol(dst_data_d, src_iterator_d[element_id], element_id + offset); } }; @@ -511,17 +516,15 @@ struct CopyPage { // Copy the data from the given EllpackPage to the current page. bst_idx_t EllpackPageImpl::Copy(Context const* ctx, EllpackPageImpl const* page, bst_idx_t offset) { monitor_.Start(__func__); - bst_idx_t num_elements = page->n_rows * page->info.row_stride; + bst_idx_t n_elements = page->n_rows * page->info.row_stride; + CHECK_NE(this, page); CHECK_EQ(this->info.row_stride, page->info.row_stride); - CHECK_EQ(NumSymbols(), page->NumSymbols()); - CHECK_GE(this->n_rows * this->info.row_stride, offset + num_elements); - if (page == this) { - LOG(FATAL) << "Concatenating the same Ellpack."; - return this->n_rows * this->info.row_stride; - } - dh::LaunchN(num_elements, ctx->CUDACtx()->Stream(), CopyPage{this, page, offset}); + CHECK_EQ(this->NumSymbols(), page->NumSymbols()); + CHECK_GE(this->n_rows * this->info.row_stride, offset + n_elements); + thrust::for_each_n(ctx->CUDACtx()->CTP(), thrust::make_counting_iterator(0ul), n_elements, + CopyPage{this, page, offset}); monitor_.Stop(__func__); - return num_elements; + return n_elements; } // A functor that compacts the rows from one EllpackPage into another. @@ -608,7 +611,7 @@ void EllpackPageImpl::CreateHistIndices(Context const* ctx, const SparsePage& ro // bin and compress entries in batches of rows size_t gpu_batch_nrows = - std::min(dh::TotalMemory(ctx->Ordinal()) / (16 * this->info.row_stride * sizeof(Entry)), + std::min(curt::TotalMemory() / (16 * this->info.row_stride * sizeof(Entry)), static_cast(row_batch.Size())); size_t gpu_nbatches = common::DivRoundUp(row_batch.Size(), gpu_batch_nrows); diff --git a/src/data/ellpack_page.cuh b/src/data/ellpack_page.cuh index 85e412f86398..1caa38b0e59a 100644 --- a/src/data/ellpack_page.cuh +++ b/src/data/ellpack_page.cuh @@ -211,7 +211,6 @@ class EllpackPageImpl { * @returns The number of elements copied. */ bst_idx_t Copy(Context const* ctx, EllpackPageImpl const* page, bst_idx_t offset); - /** * @brief Compact the given ELLPACK page into the current page. * @@ -252,6 +251,17 @@ class EllpackPageImpl { */ [[nodiscard]] auto NumSymbols() const { return this->info.n_symbols; } void SetNumSymbols(bst_idx_t n_symbols) { this->info.n_symbols = n_symbols; } + /** + * @brief Copy basic shape from another page. + */ + void CopyInfo(EllpackPageImpl const* page) { + CHECK_NE(this, page); + this->n_rows = page->Size(); + this->is_dense = page->IsDense(); + this->info.row_stride = page->info.row_stride; + this->SetBaseRowId(page->base_rowid); + this->SetNumSymbols(page->NumSymbols()); + } /** * @brief Get an accessor that can be passed into CUDA kernels. */ @@ -310,9 +320,11 @@ class EllpackPageImpl { }; [[nodiscard]] inline bst_idx_t GetRowStride(DMatrix* dmat) { - if (dmat->IsDense()) return dmat->Info().num_col_; + if (dmat->IsDense()) { + return dmat->Info().num_col_; + } - size_t row_stride = 0; + bst_idx_t row_stride = 0; for (const auto& batch : dmat->GetBatches()) { const auto& row_offset = batch.offset.ConstHostVector(); for (auto i = 1ull; i < row_offset.size(); i++) { @@ -321,6 +333,10 @@ class EllpackPageImpl { } return row_stride; } + +[[nodiscard]] EllpackPageImpl::Info CalcNumSymbols( + Context const* ctx, bst_idx_t row_stride, bool is_dense, + std::shared_ptr cuts); } // namespace xgboost #endif // XGBOOST_DATA_ELLPACK_PAGE_CUH_ diff --git a/src/data/ellpack_page_raw_format.cu b/src/data/ellpack_page_raw_format.cu index 262b9c8d3796..2907174a0920 100644 --- a/src/data/ellpack_page_raw_format.cu +++ b/src/data/ellpack_page_raw_format.cu @@ -113,10 +113,14 @@ template EllpackHostCacheStream* fo) const { xgboost_NVTX_FN_RANGE(); - fo->Write(page); + bool new_page = fo->Write(page); dh::DefaultStream().Sync(); - return page.Impl()->MemCostBytes(); + if (new_page) { + return fo->Share()->pages.back()->MemCostBytes(); + } else { + return InvalidPageSize(); + } } #undef RET_IF_NOT diff --git a/src/data/ellpack_page_source.cu b/src/data/ellpack_page_source.cu index 358e91db00f0..0432f17d36fd 100644 --- a/src/data/ellpack_page_source.cu +++ b/src/data/ellpack_page_source.cu @@ -8,6 +8,7 @@ #include // for move #include "../common/common.h" // for safe_cuda +#include "../common/common.h" // for HumanMemUnit #include "../common/cuda_rt_utils.h" // for SetDevice #include "../common/device_helpers.cuh" // for CUDAStreamView, DefaultStream #include "../common/ref_resource_view.cuh" // for MakeFixedVecWithCudaMalloc @@ -23,21 +24,22 @@ namespace xgboost::data { /** * Cache */ -EllpackHostCache::EllpackHostCache() = default; +EllpackHostCache::EllpackHostCache(EllpackCacheInfo info) + : cache_mapping{std::move(info.cache_mapping)}, + buffer_bytes{std::move(info.buffer_bytes)}, + buffer_rows{std::move(info.buffer_rows)} { + CHECK_EQ(buffer_bytes.size(), buffer_rows.size()); +} + EllpackHostCache::~EllpackHostCache() = default; -[[nodiscard]] std::size_t EllpackHostCache::Size() const { +[[nodiscard]] std::size_t EllpackHostCache::SizeBytes() const { auto it = common::MakeIndexTransformIter([&](auto i) { return pages.at(i)->MemCostBytes(); }); - return std::accumulate(it, it + pages.size(), 0l); + using T = std::iterator_traits::value_type; + return std::accumulate(it, it + pages.size(), static_cast(0)); } -void EllpackHostCache::Push(std::unique_ptr page) { - this->pages.emplace_back(std::move(page)); -} - -EllpackPageImpl const* EllpackHostCache::Get(std::int32_t k) { - return this->pages.at(k).get(); -} +EllpackPageImpl const* EllpackHostCache::At(std::int32_t k) { return this->pages.at(k).get(); } /** * Cache stream. @@ -69,28 +71,79 @@ class EllpackHostCacheStreamImpl { ptr_ = k; } - void Write(EllpackPage const& page) { + [[nodiscard]] bool Write(EllpackPage const& page) { auto impl = page.Impl(); + auto ctx = Context{}.MakeCUDA(dh::CurrentDevice()); + + this->cache_->sizes_orig.push_back(page.Impl()->MemCostBytes()); + auto orig_ptr = this->cache_->sizes_orig.size() - 1; + + CHECK_LT(orig_ptr, this->cache_->NumBatchesOrig()); + auto cache_idx = this->cache_->cache_mapping.at(orig_ptr); + // Wrap up the previous page if this is a new page, or this is the last page. + auto new_page = cache_idx == this->cache_->pages.size(); + + auto last_page = (orig_ptr + 1) == this->cache_->NumBatchesOrig(); + bool no_concat = this->cache_->NumBatchesOrig() == this->cache_->buffer_rows.size(); + + auto commit_page = [&ctx](EllpackPageImpl const* old_impl) { + CHECK_EQ(old_impl->gidx_buffer.Resource()->Type(), common::ResourceHandler::kCudaMalloc); + + auto new_impl = std::make_unique(); + new_impl->CopyInfo(old_impl); + new_impl->gidx_buffer = common::MakeFixedVecWithPinnedMalloc( + old_impl->gidx_buffer.size()); + dh::safe_cuda(cudaMemcpyAsync(new_impl->gidx_buffer.data(), old_impl->gidx_buffer.data(), + old_impl->gidx_buffer.size_bytes(), cudaMemcpyDefault)); + LOG(INFO) << "Create cache page with size:" << common::HumanMemUnit(new_impl->MemCostBytes()); + return new_impl; + }; + + if (no_concat) { + // Avoid a device->device->host copy. + CHECK(new_page); + auto commited = commit_page(page.Impl()); + this->cache_->offsets.push_back(commited->n_rows * commited->info.row_stride); + this->cache_->pages.push_back(std::move(commited)); + return new_page; + } - auto new_impl = std::make_unique(); - auto new_cache = std::make_shared(); - new_impl->gidx_buffer = - common::MakeFixedVecWithPinnedMalloc(impl->gidx_buffer.size()); - new_impl->n_rows = impl->Size(); - new_impl->is_dense = impl->IsDense(); - new_impl->info.row_stride = impl->info.row_stride; - new_impl->base_rowid = impl->base_rowid; - new_impl->SetNumSymbols(impl->NumSymbols()); - - dh::safe_cuda(cudaMemcpyAsync(new_impl->gidx_buffer.data(), impl->gidx_buffer.data(), - impl->gidx_buffer.size_bytes(), cudaMemcpyDefault)); - - this->cache_->Push(std::move(new_impl)); - ptr_ += 1; + if (new_page) { + if (!this->cache_->pages.empty()) { + // New to wrap up the previous page. + auto commited = commit_page(this->cache_->pages.back().get()); + this->cache_->pages.back() = std::move(commited); + } + // Push a new page + auto n_bytes = this->cache_->buffer_bytes.at(this->cache_->pages.size()); + auto n_samples = this->cache_->buffer_rows.at(this->cache_->pages.size()); + auto new_impl = std::make_unique(&ctx, impl->CutsShared(), impl->IsDense(), + impl->info.row_stride, n_samples); + new_impl->SetBaseRowId(impl->base_rowid); + new_impl->SetNumSymbols(impl->NumSymbols()); + new_impl->gidx_buffer = + common::MakeFixedVecWithCudaMalloc(&ctx, n_bytes, 0); + auto offset = new_impl->Copy(&ctx, impl, 0); + + this->cache_->offsets.push_back(offset); + this->cache_->pages.push_back(std::move(new_impl)); + } else { + CHECK(!this->cache_->pages.empty()); + CHECK_EQ(cache_idx, this->cache_->pages.size() - 1); + auto& new_impl = this->cache_->pages.back(); + auto offset = new_impl->Copy(&ctx, impl, this->cache_->offsets.back()); + this->cache_->offsets.back() += offset; + if (last_page) { + auto commited = commit_page(this->cache_->pages.back().get()); + this->cache_->pages.back() = std::move(commited); + } + } + + return new_page; } void Read(EllpackPage* out, bool prefetch_copy) const { - auto page = this->cache_->Get(ptr_); + auto page = this->cache_->At(ptr_); auto impl = out->Impl(); if (prefetch_copy) { @@ -104,24 +157,21 @@ class EllpackHostCacheStreamImpl { res->DataAs(), page->gidx_buffer.size(), res}; } - impl->n_rows = page->Size(); - impl->is_dense = page->IsDense(); - impl->info.row_stride = page->info.row_stride; - impl->base_rowid = page->base_rowid; - impl->SetNumSymbols(page->NumSymbols()); + impl->CopyInfo(page); } }; /** * EllpackHostCacheStream */ - EllpackHostCacheStream::EllpackHostCacheStream(std::shared_ptr cache) : p_impl_{std::make_unique(std::move(cache))} {} EllpackHostCacheStream::~EllpackHostCacheStream() = default; -std::shared_ptr EllpackHostCacheStream::Share() { return p_impl_->Share(); } +std::shared_ptr EllpackHostCacheStream::Share() const { + return p_impl_->Share(); +} void EllpackHostCacheStream::Seek(bst_idx_t offset_bytes) { this->p_impl_->Seek(offset_bytes); } @@ -129,24 +179,24 @@ void EllpackHostCacheStream::Read(EllpackPage* page, bool prefetch_copy) const { this->p_impl_->Read(page, prefetch_copy); } -void EllpackHostCacheStream::Write(EllpackPage const& page) { this->p_impl_->Write(page); } +[[nodiscard]] bool EllpackHostCacheStream::Write(EllpackPage const& page) { + return this->p_impl_->Write(page); +} /** * EllpackCacheStreamPolicy */ - -template typename F> -EllpackCacheStreamPolicy::EllpackCacheStreamPolicy() - : p_cache_{std::make_shared()} {} - template typename F> [[nodiscard]] std::unique_ptr::WriterT> EllpackCacheStreamPolicy::CreateWriter(StringView, std::uint32_t iter) { + if (!this->p_cache_) { + this->p_cache_ = std::make_unique(this->CacheInfo()); + } auto fo = std::make_unique(this->p_cache_); if (iter == 0) { CHECK(this->p_cache_->Empty()); } else { - fo->Seek(this->p_cache_->Size()); + fo->Seek(this->p_cache_->SizeBytes()); } return fo; } @@ -160,8 +210,6 @@ EllpackCacheStreamPolicy::CreateReader(StringView, bst_idx_t offset, bst_i } // Instantiation -template EllpackCacheStreamPolicy::EllpackCacheStreamPolicy(); - template std::unique_ptr< typename EllpackCacheStreamPolicy::WriterT> EllpackCacheStreamPolicy::CreateWriter(StringView name, @@ -195,6 +243,37 @@ EllpackMmapStreamPolicy::CreateReader(StringVi bst_idx_t offset, bst_idx_t length) const; +void CalcCacheMapping(Context const* ctx, bool is_dense, + std::shared_ptr cuts, double min_page_bytes, + ExternalDataInfo const& ext_info, EllpackCacheInfo* cinfo) { + CHECK(cinfo->param.Initialized()) << "Need to initialize scalar fields first."; + auto ell_info = CalcNumSymbols(ctx, ext_info.row_stride, is_dense, cuts); + std::vector cache_bytes; + std::vector cache_mapping(ext_info.n_batches, 0); + std::vector cache_rows; + + for (std::size_t i = 0; i < ext_info.n_batches; ++i) { + auto n_samples = ext_info.base_rowids.at(i + 1) - ext_info.base_rowids[i]; + auto n_bytes = common::CompressedBufferWriter::CalculateBufferSize( + ext_info.row_stride * n_samples, ell_info.n_symbols); + if (cache_bytes.empty()) { + cache_bytes.push_back(n_bytes); + cache_rows.push_back(n_samples); + } else if (cache_bytes.back() < min_page_bytes) { + cache_bytes.back() += n_bytes; + cache_rows.back() += n_samples; + } else { + cache_bytes.push_back(n_bytes); + cache_rows.push_back(n_samples); + } + cache_mapping[i] = cache_bytes.size() - 1; + } + + cinfo->cache_mapping = std::move(cache_mapping); + cinfo->buffer_bytes = std::move(cache_bytes); + cinfo->buffer_rows = std::move(cache_rows); +} + /** * EllpackPageSourceImpl */ @@ -265,7 +344,7 @@ void ExtEllpackPageSourceImpl::Fetch() { << cuda_impl::Dispatch(proxy_, [](auto const& adapter) { return common::HumanMemUnit(adapter->SizeBytes()); }); - this->page_->SetBaseRowId(this->ext_info_.base_rows.at(iter)); + this->page_->SetBaseRowId(this->ext_info_.base_rowids.at(iter)); this->WriteCache(); } } diff --git a/src/data/ellpack_page_source.h b/src/data/ellpack_page_source.h index 573d9c34a702..71556b56c78f 100644 --- a/src/data/ellpack_page_source.h +++ b/src/data/ellpack_page_source.h @@ -6,6 +6,7 @@ #define XGBOOST_DATA_ELLPACK_PAGE_SOURCE_H_ #include // for int32_t +#include // for numeric_limits #include // for shared_ptr #include // for move #include // for vector @@ -21,40 +22,84 @@ #include "xgboost/span.h" // for Span namespace xgboost::data { +struct EllpackCacheInfo { + BatchParam param; + float missing{std::numeric_limits::quiet_NaN()}; + std::vector cache_mapping; + std::vector buffer_bytes; + std::vector buffer_rows; + + EllpackCacheInfo() = default; + EllpackCacheInfo(BatchParam param, float missing) : param{std::move(param)}, missing{missing} {} +}; + // We need to decouple the storage and the view of the storage so that we can implement // concurrent read. As a result, there are two classes, one for cache storage, another one // for stream. struct EllpackHostCache { - std::vector> pages; - - EllpackHostCache(); + std::vector> pages; + std::vector offsets; + // Size of each batch before concatenation. + std::vector sizes_orig; + // Mapping of pages before concatenation to after concatenation. + std::vector const cache_mapping; + // Cache info + std::vector const buffer_bytes; + std::vector const buffer_rows; + + explicit EllpackHostCache(EllpackCacheInfo info); ~EllpackHostCache(); - [[nodiscard]] std::size_t Size() const; + // The number of bytes for the entire cache. + [[nodiscard]] std::size_t SizeBytes() const; - bool Empty() const { return this->Size() == 0; } + bool Empty() const { return this->SizeBytes() == 0; } - void Push(std::unique_ptr page); - EllpackPageImpl const* Get(std::int32_t k); + [[nodiscard]] bst_idx_t NumBatchesOrig() const { return cache_mapping.size(); } + EllpackPageImpl const* At(std::int32_t k); }; // Pimpl to hide CUDA calls from the host compiler. class EllpackHostCacheStreamImpl; -// A view onto the actual cache implemented by `EllpackHostCache`. +/** + * @brief A view of the actual cache implemented by `EllpackHostCache`. + */ class EllpackHostCacheStream { std::unique_ptr p_impl_; public: explicit EllpackHostCacheStream(std::shared_ptr cache); ~EllpackHostCacheStream(); - - std::shared_ptr Share(); - + /** + * @brief Get a shared handler to the cache. + */ + std::shared_ptr Share() const; + /** + * @brief Stream seek. + * + * @param offset_bytes This must align to the actual cached page size. + */ void Seek(bst_idx_t offset_bytes); - + /** + * @brief Read a page from the cache. + * + * The read page might be concatenated during page write. + * + * @param page[out] The returned page. + * @param prefetch_copy[in] Does the stream need to copy the page? + */ void Read(EllpackPage* page, bool prefetch_copy) const; - void Write(EllpackPage const& page); + /** + * @brief Add a new page to the host cache. + * + * This method might append the input page to a previously stored page to increase + * individual page size. + * + * @return Whether a new cache page is create. False if the new page is appended to the + * previous one. + */ + [[nodiscard]] bool Write(EllpackPage const& page); }; template @@ -63,6 +108,9 @@ class EllpackFormatPolicy { DeviceOrd device_; bool has_hmm_{curt::SupportsPageableMem()}; + EllpackCacheInfo cache_info_; + static_assert(std::is_same_v); + public: using FormatT = EllpackPageRawFormat; @@ -82,7 +130,7 @@ class EllpackFormatPolicy { } std::int32_t major{0}, minor{0}; curt::DrVersion(&major, &minor); - if (!(major >= 12 && minor >= 7) && curt::SupportsAts()) { + if ((major < 12 || (major <= 12 && minor < 7)) && curt::SupportsAts()) { // Use ATS, but with an old kernel driver. LOG(WARNING) << "Using an old kernel driver with supported CTK<12.7." << "The latest version of CTK supported by the current driver: " << major << "." @@ -97,18 +145,19 @@ class EllpackFormatPolicy { std::unique_ptr fmt{new EllpackPageRawFormat{cuts_, device_, param, has_hmm_}}; return fmt; } - - void SetCuts(std::shared_ptr cuts, DeviceOrd device) { - std::swap(cuts_, cuts); - device_ = device; + void SetCuts(std::shared_ptr cuts, DeviceOrd device, + EllpackCacheInfo cinfo) { + std::swap(this->cuts_, cuts); + this->device_ = device; CHECK(this->device_.IsCUDA()); + this->cache_info_ = std::move(cinfo); } - [[nodiscard]] auto GetCuts() { + [[nodiscard]] auto GetCuts() const { CHECK(cuts_); return cuts_; } - - [[nodiscard]] auto Device() const { return device_; } + [[nodiscard]] auto Device() const { return this->device_; } + [[nodiscard]] auto const& CacheInfo() { return this->cache_info_; } }; template typename F> @@ -120,7 +169,7 @@ class EllpackCacheStreamPolicy : public F { using ReaderT = EllpackHostCacheStream; public: - EllpackCacheStreamPolicy(); + EllpackCacheStreamPolicy() = default; [[nodiscard]] std::unique_ptr CreateWriter(StringView name, std::uint32_t iter); [[nodiscard]] std::unique_ptr CreateReader(StringView name, bst_idx_t offset, @@ -156,6 +205,10 @@ class EllpackMmapStreamPolicy : public F { bst_idx_t length) const; }; +void CalcCacheMapping(Context const* ctx, bool is_dense, + std::shared_ptr cuts, double min_page_bytes, + ExternalDataInfo const& ext_info, EllpackCacheInfo* cinfo); + /** * @brief Ellpack source with sparse pages as the underlying source. */ @@ -168,19 +221,19 @@ class EllpackPageSourceImpl : public PageSourceIncMixIn { common::Span feature_types_; public: - EllpackPageSourceImpl(float missing, std::int32_t nthreads, bst_feature_t n_features, - std::size_t n_batches, std::shared_ptr cache, BatchParam param, - std::shared_ptr cuts, bool is_dense, - bst_idx_t row_stride, common::Span feature_types, - std::shared_ptr source, DeviceOrd device) - : Super{missing, nthreads, n_features, n_batches, cache, false}, + EllpackPageSourceImpl(Context const* ctx, bst_feature_t n_features, std::size_t n_batches, + std::shared_ptr cache, std::shared_ptr cuts, + bool is_dense, bst_idx_t row_stride, + common::Span feature_types, + std::shared_ptr source, EllpackCacheInfo const& cinfo) + : Super{cinfo.missing, ctx->Threads(), n_features, n_batches, cache, false}, is_dense_{is_dense}, row_stride_{row_stride}, - param_{std::move(param)}, + param_{std::move(cinfo.param)}, feature_types_{feature_types} { this->source_ = source; - cuts->SetDevice(device); - this->SetCuts(std::move(cuts), device); + cuts->SetDevice(ctx->Device()); + this->SetCuts(std::move(cuts), ctx->Device(), cinfo); this->Fetch(); } @@ -210,18 +263,19 @@ class ExtEllpackPageSourceImpl : public ExtQantileSourceMixin cache, BatchParam param, std::shared_ptr cuts, + Context const* ctx, MetaInfo* info, ExternalDataInfo ext_info, std::shared_ptr cache, + std::shared_ptr cuts, std::shared_ptr> source, - DMatrixProxy* proxy) - : Super{missing, ctx->Threads(), static_cast(info->num_col_), source, cache}, + DMatrixProxy* proxy, EllpackCacheInfo const& cinfo) + : Super{cinfo.missing, ctx->Threads(), static_cast(info->num_col_), source, + cache}, ctx_{ctx}, - p_{std::move(param)}, + p_{cinfo.param}, proxy_{proxy}, info_{info}, ext_info_{std::move(ext_info)} { cuts->SetDevice(ctx->Device()); - this->SetCuts(std::move(cuts), ctx->Device()); + this->SetCuts(std::move(cuts), ctx->Device(), cinfo); CHECK(!this->cache_info_->written); this->source_->Reset(); CHECK(this->source_->Next()); @@ -229,6 +283,17 @@ class ExtEllpackPageSourceImpl : public ExtQantileSourceMixincache_info_->written) { + CHECK_EQ(this->Iter(), this->cache_info_->Size()); + } else { + CHECK_LE(this->cache_info_->Size(), this->ext_info_.n_batches); + } + this->cache_info_->Commit(); + CHECK_GE(this->count_, 1); + this->count_ = 0; + } }; // Cache to host diff --git a/src/data/extmem_quantile_dmatrix.cc b/src/data/extmem_quantile_dmatrix.cc index d1321a511eee..9026411eec85 100644 --- a/src/data/extmem_quantile_dmatrix.cc +++ b/src/data/extmem_quantile_dmatrix.cc @@ -3,10 +3,9 @@ */ #include "extmem_quantile_dmatrix.h" -#include // for shared_ptr -#include // for string -#include // for move -#include // for vector +#include // for shared_ptr +#include // for string +#include // for vector #include "../tree/param.h" // FIXME(jiamingy): Find a better way to share this parameter. #include "batch_utils.h" // for CheckParam, RegenGHist @@ -23,10 +22,9 @@ namespace xgboost::data { ExtMemQuantileDMatrix::ExtMemQuantileDMatrix(DataIterHandle iter_handle, DMatrixHandle proxy, std::shared_ptr ref, DataIterResetCallback *reset, - XGDMatrixCallbackNext *next, float missing, - std::int32_t n_threads, std::string cache, - bst_bin_t max_bin, bool on_host) - : cache_prefix_{std::move(cache)}, on_host_{on_host} { + XGDMatrixCallbackNext *next, bst_bin_t max_bin, + ExtMemConfig const &config) + : cache_prefix_{config.cache}, on_host_{config.on_host} { cache_prefix_ = MakeCachePrefix(cache_prefix_); auto iter = std::make_shared>( iter_handle, reset, next); @@ -37,13 +35,13 @@ ExtMemQuantileDMatrix::ExtMemQuantileDMatrix(DataIterHandle iter_handle, DMatrix auto pctx = MakeProxy(proxy)->Ctx(); Context ctx; - ctx.Init(Args{{"nthread", std::to_string(n_threads)}, {"device", pctx->DeviceName()}}); + ctx.Init(Args{{"nthread", std::to_string(config.n_threads)}, {"device", pctx->DeviceName()}}); BatchParam p{max_bin, tree::TrainParam::DftSparseThreshold()}; if (ctx.IsCPU()) { - this->InitFromCPU(&ctx, iter, proxy, p, missing, ref); + this->InitFromCPU(&ctx, iter, proxy, p, config.missing, ref); } else { - this->InitFromCUDA(&ctx, iter, proxy, p, missing, ref); + this->InitFromCUDA(&ctx, iter, proxy, p, ref, config); } this->batch_ = p; this->fmat_ctx_ = ctx; @@ -92,7 +90,7 @@ void ExtMemQuantileDMatrix::InitFromCPU( */ auto id = MakeCache(this, ".gradient_index.page", false, cache_prefix_, &cache_info_); this->ghist_index_source_ = std::make_unique( - ctx, missing, &this->Info(), cache_info_.at(id), p, cuts, iter, proxy, ext_info.base_rows); + ctx, missing, &this->Info(), cache_info_.at(id), p, cuts, iter, proxy, ext_info.base_rowids); /** * Force initialize the cache and do some sanity checks along the way @@ -101,7 +99,7 @@ void ExtMemQuantileDMatrix::InitFromCPU( bst_idx_t n_total_samples = 0; for (auto const &page : this->GetGradientIndexImpl()) { n_total_samples += page.Size(); - CHECK_EQ(page.base_rowid, ext_info.base_rows[k]); + CHECK_EQ(page.base_rowid, ext_info.base_rowids[k]); CHECK_EQ(page.Features(), this->Info().num_col_); ++k, ++batch_cnt; } diff --git a/src/data/extmem_quantile_dmatrix.cu b/src/data/extmem_quantile_dmatrix.cu index 2508a7e90a7e..051f2873c5be 100644 --- a/src/data/extmem_quantile_dmatrix.cu +++ b/src/data/extmem_quantile_dmatrix.cu @@ -2,7 +2,7 @@ * Copyright 2024, XGBoost Contributors */ #include // for shared_ptr -#include // for visit +#include // for visit, get_if #include "../common/cuda_rt_utils.h" // for xgboost_NVTX_FN_RANGE #include "batch_utils.h" // for CheckParam, RegenGHist @@ -11,13 +11,13 @@ #include "proxy_dmatrix.h" // for DataIterProxy #include "xgboost/context.h" // for Context #include "xgboost/data.h" // for BatchParam -#include "../common/cuda_rt_utils.h" namespace xgboost::data { void ExtMemQuantileDMatrix::InitFromCUDA( Context const *ctx, std::shared_ptr> iter, - DMatrixHandle proxy_handle, BatchParam const &p, float missing, std::shared_ptr ref) { + DMatrixHandle proxy_handle, BatchParam const &p, std::shared_ptr ref, + ExtMemConfig const &config) { xgboost_NVTX_FN_RANGE(); // A handle passed to external iterator. @@ -29,9 +29,20 @@ void ExtMemQuantileDMatrix::InitFromCUDA( */ auto cuts = std::make_shared(); ExternalDataInfo ext_info; - cuda_impl::MakeSketches(ctx, iter.get(), proxy, ref, p, missing, cuts, this->Info(), &ext_info); + cuda_impl::MakeSketches(ctx, iter.get(), proxy, ref, p, config.missing, cuts, this->Info(), + &ext_info); ext_info.SetInfo(ctx, &this->info_); + /** + * Calculate cache info + */ + auto cinfo = EllpackCacheInfo{p, config.missing}; + CalcCacheMapping(ctx, this->Info().IsDense(), cuts, config.min_cache_page_bytes, ext_info, + &cinfo); + CHECK_EQ(cinfo.cache_mapping.size(), ext_info.n_batches); + auto n_batches = cinfo.buffer_rows.size(); + LOG(INFO) << "Number of batches after concatenation:" << n_batches; + /** * Generate gradient index */ @@ -43,8 +54,8 @@ void ExtMemQuantileDMatrix::InitFromCUDA( std::visit( [&](auto &&ptr) { using SourceT = typename std::remove_reference_t::element_type; - ptr = std::make_shared(ctx, missing, &this->Info(), ext_info, cache_info_.at(id), - p, cuts, iter, proxy); + ptr = std::make_shared(ctx, &this->Info(), ext_info, cache_info_.at(id), cuts, + iter, proxy, cinfo); }, ellpack_page_source_); @@ -55,13 +66,17 @@ void ExtMemQuantileDMatrix::InitFromCUDA( bst_idx_t n_total_samples = 0; for (auto const &page : this->GetEllpackPageImpl()) { n_total_samples += page.Size(); - CHECK_EQ(page.Impl()->base_rowid, ext_info.base_rows[k]); + CHECK_EQ(page.Impl()->base_rowid, ext_info.base_rowids[k]); CHECK_EQ(page.Impl()->info.row_stride, ext_info.row_stride); ++k, ++batch_cnt; } CHECK_EQ(batch_cnt, ext_info.n_batches); CHECK_EQ(n_total_samples, ext_info.accumulated_rows); - this->n_batches_ = ext_info.n_batches; + + if (this->on_host_) { + CHECK_EQ(this->cache_info_.at(id)->Size(), n_batches); + } + this->n_batches_ = this->cache_info_.at(id)->Size(); } [[nodiscard]] BatchSet ExtMemQuantileDMatrix::GetEllpackPageImpl() { diff --git a/src/data/extmem_quantile_dmatrix.h b/src/data/extmem_quantile_dmatrix.h index 842bfd2d49d3..1c9a9013b6ba 100644 --- a/src/data/extmem_quantile_dmatrix.h +++ b/src/data/extmem_quantile_dmatrix.h @@ -29,8 +29,7 @@ class ExtMemQuantileDMatrix : public QuantileDMatrix { public: ExtMemQuantileDMatrix(DataIterHandle iter_handle, DMatrixHandle proxy, std::shared_ptr ref, DataIterResetCallback *reset, - XGDMatrixCallbackNext *next, float missing, std::int32_t n_threads, - std::string cache, bst_bin_t max_bin, bool on_host); + XGDMatrixCallbackNext *next, bst_bin_t max_bin, ExtMemConfig const &config); ~ExtMemQuantileDMatrix() override; [[nodiscard]] std::int32_t NumBatches() const override { return n_batches_; } @@ -43,7 +42,8 @@ class ExtMemQuantileDMatrix : public QuantileDMatrix { void InitFromCUDA( Context const *ctx, std::shared_ptr> iter, - DMatrixHandle proxy_handle, BatchParam const &p, float missing, std::shared_ptr ref); + DMatrixHandle proxy_handle, BatchParam const &p, std::shared_ptr ref, + ExtMemConfig const &config); [[nodiscard]] BatchSet GetGradientIndexImpl(); BatchSet GetGradientIndex(Context const *ctx, BatchParam const ¶m) override; @@ -63,7 +63,7 @@ class ExtMemQuantileDMatrix : public QuantileDMatrix { std::map> cache_info_; std::string cache_prefix_; - bool on_host_; + bool const on_host_; BatchParam batch_; bst_idx_t n_batches_{0}; diff --git a/src/data/proxy_dmatrix.h b/src/data/proxy_dmatrix.h index ea4b104f5d3d..ff92eed74e76 100644 --- a/src/data/proxy_dmatrix.h +++ b/src/data/proxy_dmatrix.h @@ -150,7 +150,7 @@ struct ExternalDataInfo { bst_idx_t nnz = 0; // The number of non-missing values std::vector column_sizes; // The nnz for each column std::vector batch_nnz; // nnz for each batch - std::vector base_rows{0}; // base_rowid + std::vector base_rowids{0}; // base_rowid bst_idx_t row_stride{0}; // Used by ellpack void Validate() const { @@ -159,6 +159,7 @@ struct ExternalDataInfo { })) << "Something went wrong during iteration."; CHECK_GE(this->n_features, 1) << "Data must has at least 1 column."; + CHECK_EQ(this->base_rowids.size(), this->n_batches + 1); } void SetInfo(Context const* ctx, MetaInfo* p_info) { diff --git a/src/data/quantile_dmatrix.cc b/src/data/quantile_dmatrix.cc index 082e1ac2a909..27bce54050f0 100644 --- a/src/data/quantile_dmatrix.cc +++ b/src/data/quantile_dmatrix.cc @@ -131,14 +131,14 @@ void GetDataShape(Context const* ctx, DMatrixProxy* proxy, } bst_idx_t batch_size = BatchSamples(proxy); info.batch_nnz.push_back(nnz_cnt()); - info.base_rows.push_back(batch_size); + info.base_rowids.push_back(batch_size); info.nnz += info.batch_nnz.back(); info.accumulated_rows += batch_size; info.n_batches++; } while (iter.Next()); iter.Reset(); - std::partial_sum(info.base_rows.cbegin(), info.base_rows.cend(), info.base_rows.begin()); + std::partial_sum(info.base_rowids.cbegin(), info.base_rowids.cend(), info.base_rowids.begin()); } void MakeSketches(Context const* ctx, diff --git a/src/data/quantile_dmatrix.cu b/src/data/quantile_dmatrix.cu index b41ab046d7c8..47d81c8c31ed 100644 --- a/src/data/quantile_dmatrix.cu +++ b/src/data/quantile_dmatrix.cu @@ -48,6 +48,13 @@ void MakeSketches(Context const* ctx, data::BatchSamples(proxy), dh::GetDevice(ctx)), 0); }; + auto total_capacity = [&] { + bst_idx_t n_bytes = 0; + for (auto const& sk : sketches) { + n_bytes += sk.first->MemCapacityBytes(); + } + return n_bytes; + }; // Workaround empty input with CPU ctx. Context new_ctx; @@ -102,6 +109,7 @@ void MakeSketches(Context const* ctx, sketches.back().first.get()); sketches.back().second++; }); + LOG(DEBUG) << "Total capacity:" << common::HumanMemUnit(total_capacity()); } /** @@ -115,13 +123,13 @@ void MakeSketches(Context const* ctx, })); ext_info.nnz += thrust::reduce(ctx->CUDACtx()->CTP(), row_counts.begin(), row_counts.end()); ext_info.n_batches++; - ext_info.base_rows.push_back(batch_rows); + ext_info.base_rowids.push_back(batch_rows); } while (iter->Next()); iter->Reset(); CHECK_GE(ext_info.n_features, 1) << "Data must has at least 1 column."; - std::partial_sum(ext_info.base_rows.cbegin(), ext_info.base_rows.cend(), - ext_info.base_rows.begin()); + std::partial_sum(ext_info.base_rowids.cbegin(), ext_info.base_rowids.cend(), + ext_info.base_rowids.begin()); // Get reference curt::SetDevice(dh::GetDevice(ctx).ordinal); @@ -151,6 +159,8 @@ void MakeSketches(Context const* ctx, } else { GetCutsFromRef(ctx, ref, ext_info.n_features, p, cuts.get()); } + + ctx->CUDACtx()->Stream().Sync(); } } // namespace cuda_impl } // namespace xgboost::data diff --git a/src/data/sparse_page_dmatrix.cc b/src/data/sparse_page_dmatrix.cc index e2313806666c..c85102dcd526 100644 --- a/src/data/sparse_page_dmatrix.cc +++ b/src/data/sparse_page_dmatrix.cc @@ -24,40 +24,41 @@ const MetaInfo &SparsePageDMatrix::Info() const { return info_; } SparsePageDMatrix::SparsePageDMatrix(DataIterHandle iter_handle, DMatrixHandle proxy_handle, DataIterResetCallback *reset, XGDMatrixCallbackNext *next, - float missing, int32_t nthreads, std::string cache_prefix, - bool on_host) + ExtMemConfig const &config) : proxy_{proxy_handle}, iter_{iter_handle}, reset_{reset}, next_{next}, - missing_{missing}, - cache_prefix_{std::move(cache_prefix)}, - on_host_{on_host} { + missing_{config.missing}, + cache_prefix_{config.cache}, + on_host_{config.on_host}, + min_cache_page_bytes_{config.min_cache_page_bytes} { Context ctx; - ctx.Init(Args{{"nthread", std::to_string(nthreads)}}); + ctx.Init(Args{{"nthread", std::to_string(config.n_threads)}}); cache_prefix_ = MakeCachePrefix(cache_prefix_); DMatrixProxy *proxy = MakeProxy(proxy_); auto iter = DataIterProxy{ iter_, reset_, next_}; - ExternalDataInfo ext_info; - // The proxy is iterated together with the sparse page source so we can obtain all // information in 1 pass. for (auto const &page : this->GetRowBatchesImpl(&ctx)) { this->info_.Extend(std::move(proxy->Info()), false, false); - ext_info.n_features = - std::max(static_cast(ext_info.n_features), BatchColumns(proxy)); - ext_info.accumulated_rows += BatchSamples(proxy); - ext_info.nnz += page.data.Size(); - ext_info.n_batches++; + ext_info_.n_features = + std::max(static_cast(ext_info_.n_features), BatchColumns(proxy)); + ext_info_.accumulated_rows += BatchSamples(proxy); + ext_info_.nnz += page.data.Size(); + ext_info_.n_batches++; + ext_info_.base_rowids.push_back(page.Size()); + ext_info_.batch_nnz.push_back(page.data.Size()); } + std::partial_sum(ext_info_.base_rowids.cbegin(), ext_info_.base_rowids.cend(), + ext_info_.base_rowids.begin()); iter.Reset(); - this->n_batches_ = ext_info.n_batches; - ext_info.SetInfo(&ctx, &this->info_); + ext_info_.SetInfo(&ctx, &this->info_); fmat_ctx_ = ctx; } @@ -86,11 +87,11 @@ void SparsePageDMatrix::InitializeSparsePage(Context const *ctx) { auto iter = DataIterProxy{iter_, reset_, next_}; DMatrixProxy *proxy = MakeProxy(proxy_); sparse_page_source_.reset(); // clear before creating new one to prevent conflicts. - // During initialization, the n_batches_ is 0. - CHECK_EQ(this->n_batches_, static_castn_batches_)>(0)); - sparse_page_source_ = std::make_shared(iter, proxy, this->missing_, - ctx->Threads(), this->info_.num_col_, - this->n_batches_, cache_info_.at(id)); + // During initialization, the n_batches is 0. + CHECK_EQ(this->ext_info_.n_batches, static_castext_info_.n_batches)>(0)); + sparse_page_source_ = std::make_shared( + iter, proxy, this->missing_, ctx->Threads(), this->info_.num_col_, this->ext_info_.n_batches, + cache_info_.at(id)); } BatchSet SparsePageDMatrix::GetRowBatchesImpl(Context const *ctx) { @@ -108,9 +109,9 @@ BatchSet SparsePageDMatrix::GetColumnBatches(Context const *ctx) { CHECK_NE(this->Info().num_col_, 0); this->InitializeSparsePage(ctx); if (!column_source_) { - column_source_ = - std::make_shared(this->missing_, ctx->Threads(), this->Info().num_col_, - this->n_batches_, cache_info_.at(id), sparse_page_source_); + column_source_ = std::make_shared(this->missing_, ctx->Threads(), + this->Info().num_col_, this->NumBatches(), + cache_info_.at(id), sparse_page_source_); } else { column_source_->Reset({}); } @@ -123,8 +124,8 @@ BatchSet SparsePageDMatrix::GetSortedColumnBatches(Context const this->InitializeSparsePage(ctx); if (!sorted_column_source_) { sorted_column_source_ = std::make_shared( - this->missing_, ctx->Threads(), this->Info().num_col_, this->n_batches_, cache_info_.at(id), - sparse_page_source_); + this->missing_, ctx->Threads(), this->Info().num_col_, this->NumBatches(), + cache_info_.at(id), sparse_page_source_); } else { sorted_column_source_->Reset({}); } @@ -153,8 +154,8 @@ BatchSet SparsePageDMatrix::GetGradientIndex(Context const *ct CHECK_NE(cuts.Values().size(), 0); auto ft = this->info_.feature_types.ConstHostSpan(); ghist_index_source_.reset(new GradientIndexPageSource( - this->missing_, ctx->Threads(), this->Info().num_col_, this->n_batches_, cache_info_.at(id), - param, std::move(cuts), this->IsDense(), ft, sparse_page_source_)); + this->missing_, ctx->Threads(), this->Info().num_col_, this->NumBatches(), + cache_info_.at(id), param, std::move(cuts), this->IsDense(), ft, sparse_page_source_)); } else { CHECK(ghist_index_source_); ghist_index_source_->Reset(param); diff --git a/src/data/sparse_page_dmatrix.cu b/src/data/sparse_page_dmatrix.cu index 89806953383f..f22e5136d558 100644 --- a/src/data/sparse_page_dmatrix.cu +++ b/src/data/sparse_page_dmatrix.cu @@ -4,10 +4,11 @@ #include // for shared_ptr #include // for move #include // for visit +#include // for vector #include "../common/hist_util.cuh" #include "../common/hist_util.h" // for HistogramCuts -#include "batch_utils.h" // for CheckEmpty, RegenGHist +#include "batch_utils.h" // for CheckEmpty, RegenGHist, CachePageRatio #include "ellpack_page.cuh" #include "sparse_page_dmatrix.h" #include "xgboost/context.h" // for Context @@ -23,7 +24,6 @@ BatchSet SparsePageDMatrix::GetEllpackBatches(Context const* ctx, detail::CheckEmpty(batch_param_, param); auto id = MakeCache(this, ".ellpack.page", on_host_, cache_prefix_, &cache_info_); - bst_idx_t row_stride = 0; if (!cache_info_.at(id)->written || detail::RegenGHist(batch_param_, param)) { this->InitializeSparsePage(ctx); // reinitialize the cache @@ -40,23 +40,31 @@ BatchSet SparsePageDMatrix::GetEllpackBatches(Context const* ctx, } this->InitializeSparsePage(ctx); // reset after use. - row_stride = GetRowStride(this); + std::vector base_rowids, nnz; + if (this->ext_info_.row_stride == 0) { + this->ext_info_.row_stride = GetRowStride(this); + } + this->InitializeSparsePage(ctx); // reset after use. - CHECK_NE(row_stride, 0); batch_param_ = param; auto ft = this->Info().feature_types.ConstDeviceSpan(); if (on_host_ && std::get_if(&ellpack_page_source_) == nullptr) { ellpack_page_source_.emplace(nullptr); } + + auto cinfo = EllpackCacheInfo{param, this->missing_}; + CalcCacheMapping(ctx, this->IsDense(), cuts, min_cache_page_bytes_, this->ext_info_, &cinfo); + CHECK_EQ(cinfo.cache_mapping.size(), this->ext_info_.n_batches) + << "Page concatenation is only supported by the `ExtMemQuantileDMatrix`."; std::visit( [&](auto&& ptr) { ptr.reset(); // make sure resource is released before making new ones. using SourceT = typename std::remove_reference_t::element_type; - ptr = std::make_shared(this->missing_, ctx->Threads(), this->Info().num_col_, - this->n_batches_, cache_info_.at(id), param, - std::move(cuts), this->IsDense(), row_stride, ft, - this->sparse_page_source_, ctx->Device()); + ptr = std::make_shared(ctx, this->Info().num_col_, this->ext_info_.n_batches, + cache_info_.at(id), std::move(cuts), this->IsDense(), + this->ext_info_.row_stride, ft, this->sparse_page_source_, + cinfo); }, ellpack_page_source_); } else { diff --git a/src/data/sparse_page_dmatrix.h b/src/data/sparse_page_dmatrix.h index 9f2eed9187ff..eea303cb8a85 100644 --- a/src/data/sparse_page_dmatrix.h +++ b/src/data/sparse_page_dmatrix.h @@ -74,7 +74,9 @@ class SparsePageDMatrix : public DMatrix { Context fmat_ctx_; std::string cache_prefix_; bool const on_host_; - std::uint32_t n_batches_{0}; + bst_idx_t const min_cache_page_bytes_; + ExternalDataInfo ext_info_; + // sparse page is the source to other page types, we make a special member function. void InitializeSparsePage(Context const *ctx); // Non-virtual version that can be used in constructor @@ -82,15 +84,14 @@ class SparsePageDMatrix : public DMatrix { public: explicit SparsePageDMatrix(DataIterHandle iter, DMatrixHandle proxy, DataIterResetCallback *reset, - XGDMatrixCallbackNext *next, float missing, int32_t nthreads, - std::string cache_prefix, bool on_host); + XGDMatrixCallbackNext *next, ExtMemConfig const &config); ~SparsePageDMatrix() override; [[nodiscard]] MetaInfo &Info() override; [[nodiscard]] const MetaInfo &Info() const override; [[nodiscard]] Context const *Ctx() const override { return &fmat_ctx_; } - [[nodiscard]] std::int32_t NumBatches() const override { return n_batches_; } + [[nodiscard]] std::int32_t NumBatches() const override { return ext_info_.n_batches; } DMatrix *Slice(common::Span) override { LOG(FATAL) << "Slicing DMatrix is not supported for external memory."; return nullptr; diff --git a/src/data/sparse_page_source.h b/src/data/sparse_page_source.h index 64acf5b4885b..a81920f12a5a 100644 --- a/src/data/sparse_page_source.h +++ b/src/data/sparse_page_source.h @@ -9,6 +9,7 @@ #include // for atomic #include // for uint64_t #include // for future +#include // for numeric_limits #include // for map #include // for unique_ptr #include // for mutex @@ -35,6 +36,8 @@ void TryDeleteCacheFile(const std::string& file); std::string MakeCachePrefix(std::string cache_prefix); +auto constexpr InvalidPageSize() { return std::numeric_limits::max(); } + /** * @brief Information about the cache including path and page offsets. */ @@ -339,10 +342,12 @@ class SparsePageSourceImpl : public BatchIteratorImpl, public FormatStreamPol auto bytes = fmt->Write(*page_, fo.get()); timer.Stop(); - // Not entirely accurate, the kernels doesn't have to flush the data. - LOG(INFO) << common::HumanMemUnit(bytes) << " written in " << timer.ElapsedSeconds() - << " seconds."; - cache_info_->Push(bytes); + if (bytes != InvalidPageSize()) { + // Not entirely accurate, the kernels doesn't have to flush the data. + LOG(INFO) << common::HumanMemUnit(bytes) << " written in " << timer.ElapsedSeconds() + << " seconds."; + cache_info_->Push(bytes); + } } virtual void Fetch() = 0; diff --git a/tests/cpp/c_api/test_c_api.cc b/tests/cpp/c_api/test_c_api.cc index 0117cc8f2218..9c5d601e8a42 100644 --- a/tests/cpp/c_api/test_c_api.cc +++ b/tests/cpp/c_api/test_c_api.cc @@ -8,17 +8,18 @@ #include #include -#include // for array -#include // std::size_t -#include // std::filesystem -#include // std::numeric_limits -#include // std::string +#include // for array +#include // std::size_t +#include // std::filesystem +#include // std::numeric_limits +#include // std::string #include #include "../../../src/c_api/c_api_error.h" #include "../../../src/common/io.h" #include "../../../src/data/adapter.h" // for ArrayAdapter #include "../../../src/data/array_interface.h" // for ArrayInterface +#include "../../../src/data/batch_utils.h" // for MatchingPageBytes #include "../../../src/data/gradient_index.h" // for GHistIndexMatrix #include "../../../src/data/iterative_dmatrix.h" // for IterativeDMatrix #include "../../../src/data/sparse_page_dmatrix.h" // for SparsePageDMatrix @@ -495,8 +496,12 @@ auto MakeExtMemForTest(bst_idx_t n_samples, bst_feature_t n_features, Json dconf 0); NumpyArrayIterForTest iter_1{0.0f, n_samples, n_features, n_batches}; - auto Xy = std::make_shared( - &iter_1, iter_1.Proxy(), Reset, Next, std::numeric_limits::quiet_NaN(), 0, "", false); + auto config = ExtMemConfig{.cache = "", + .on_host = false, + .min_cache_page_bytes = cuda_impl::MatchingPageBytes(), + .missing = std::numeric_limits::quiet_NaN(), + .n_threads = 0}; + auto Xy = std::make_shared(&iter_1, iter_1.Proxy(), Reset, Next, config); MakeLabelForTest(Xy, p_fmat); return std::pair{p_fmat, Xy}; } diff --git a/tests/cpp/data/test_ellpack_page.cu b/tests/cpp/data/test_ellpack_page.cu index 0ccb79def2ec..3a8d0b8d0b17 100644 --- a/tests/cpp/data/test_ellpack_page.cu +++ b/tests/cpp/data/test_ellpack_page.cu @@ -7,7 +7,8 @@ #include "../../../src/common/categorical.h" #include "../../../src/common/hist_util.h" -#include "../../../src/data/device_adapter.cuh" // for CupyAdapter +#include "../../../src/common/ref_resource_view.cuh" // for MakeCudaGrowOnly +#include "../../../src/data/device_adapter.cuh" // for CupyAdapter #include "../../../src/data/ellpack_page.cuh" #include "../../../src/data/ellpack_page.h" #include "../../../src/data/gradient_index.h" // for GHistIndexMatrix @@ -433,5 +434,5 @@ TEST_P(SparseEllpack, FromGHistIndex) { this->TestFromGHistIndex(GetParam()); } TEST_P(SparseEllpack, NumNonMissing) { this->TestNumNonMissing(this->GetParam()); } -INSTANTIATE_TEST_SUITE_P(EllpackPage, SparseEllpack, testing::Values(.0f, .2f, .4f, .8f)); +INSTANTIATE_TEST_SUITE_P(EllpackPage, SparseEllpack, ::testing::Values(.0f, .2f, .4f, .8f)); } // namespace xgboost diff --git a/tests/cpp/data/test_ellpack_page_raw_format.cu b/tests/cpp/data/test_ellpack_page_raw_format.cu index 87fd6db5fa05..f8415208c277 100644 --- a/tests/cpp/data/test_ellpack_page_raw_format.cu +++ b/tests/cpp/data/test_ellpack_page_raw_format.cu @@ -4,7 +4,7 @@ #include #include -#include "../../../src/data/ellpack_page.cuh" // for EllpackPage +#include "../../../src/data/ellpack_page.cuh" // for EllpackPage, GetRowStride #include "../../../src/data/ellpack_page_raw_format.h" // for EllpackPageRawFormat #include "../../../src/data/ellpack_page_source.h" // for EllpackFormatStreamPolicy #include "../../../src/tree/param.h" // for TrainParam @@ -13,6 +13,20 @@ namespace xgboost::data { namespace { +[[nodiscard]] EllpackCacheInfo CInfoForTest(Context const *ctx, DMatrix *Xy, bst_idx_t row_stride, + BatchParam param, + std::shared_ptr cuts) { + EllpackCacheInfo cinfo{param, std::numeric_limits::quiet_NaN()}; + ExternalDataInfo ext_info; + ext_info.n_batches = 1; + ext_info.row_stride = row_stride; + ext_info.base_rowids.push_back(Xy->Info().num_row_); + + CalcCacheMapping(ctx, Xy->IsDense(), cuts, 0, ext_info, &cinfo); + CHECK_EQ(ext_info.n_batches, cinfo.cache_mapping.size()); + return cinfo; +} + class TestEllpackPageRawFormat : public ::testing::TestWithParam { public: template @@ -33,7 +47,10 @@ class TestEllpackPageRawFormat : public ::testing::TestWithParam { ASSERT_EQ(cuts->cut_values_.Device(), ctx.Device()); ASSERT_TRUE(cuts->cut_values_.DeviceCanRead()); - policy.SetCuts(cuts, ctx.Device()); + + auto row_stride = GetRowStride(m.get()); + EllpackCacheInfo cinfo = CInfoForTest(&ctx, m.get(), row_stride, param, cuts); + policy.SetCuts(cuts, ctx.Device(), cinfo); std::unique_ptr format{policy.CreatePageFormat(param)}; @@ -98,7 +115,13 @@ TEST_P(TestEllpackPageRawFormat, HostIO) { auto p_fmat = RandomDataGenerator{100, 14, 0.5}.Seed(i).GenerateDMatrix(); for (auto const &page : p_fmat->GetBatches(&ctx, param)) { if (!format) { - policy.SetCuts(page.Impl()->CutsShared(), ctx.Device()); + EllpackCacheInfo cinfo{param, std::numeric_limits::quiet_NaN()}; + for (std::size_t i = 0; i < 3; ++i) { + cinfo.cache_mapping.push_back(i); + cinfo.buffer_bytes.push_back(page.Impl()->MemCostBytes()); + cinfo.buffer_rows.push_back(page.Impl()->n_rows); + } + policy.SetCuts(page.Impl()->CutsShared(), ctx.Device(), cinfo); format = policy.CreatePageFormat(param); } auto writer = policy.CreateWriter({}, i); diff --git a/tests/cpp/data/test_extmem_quantile_dmatrix.cu b/tests/cpp/data/test_extmem_quantile_dmatrix.cu index 6d91ac7f60b1..94419f9574e7 100644 --- a/tests/cpp/data/test_extmem_quantile_dmatrix.cu +++ b/tests/cpp/data/test_extmem_quantile_dmatrix.cu @@ -47,4 +47,75 @@ TEST_P(ExtMemQuantileDMatrixGpu, Basic) { INSTANTIATE_TEST_SUITE_P(ExtMemQuantileDMatrix, ExtMemQuantileDMatrixGpu, ::testing::Combine(::testing::Values(0.0f, 0.2f, 0.4f, 0.8f), ::testing::Bool())); + +class EllpackHostCacheTest : public ::testing::TestWithParam> { + public: + static constexpr bst_idx_t NumSamples() { return 8192; } + static constexpr bst_idx_t NumFeatures() { return 4; } + static constexpr bst_bin_t NumBins() { return 256; } + // Assumes dense + static constexpr bst_idx_t NumBytes() { return NumFeatures() * NumSamples(); } + + void Run(float sparsity, bool is_concat) { + auto ctx = MakeCUDACtx(0); + auto param = BatchParam{NumBins(), tree::TrainParam::DftSparseThreshold()}; + auto n_batches = 4; + auto p_fmat = RandomDataGenerator{NumSamples(), NumFeatures(), sparsity} + .Device(ctx.Device()) + .GenerateDMatrix(); + bst_idx_t min_page_cache_bytes = 0; + if (is_concat) { + min_page_cache_bytes = + p_fmat->GetBatches(&ctx, param).begin().Page()->Impl()->MemCostBytes() / 3; + } + + auto p_ext_fmat = RandomDataGenerator{NumSamples(), NumFeatures(), sparsity} + .Batches(n_batches) + .Bins(param.max_bin) + .Device(ctx.Device()) + .OnHost(true) + .MinPageCacheBytes(min_page_cache_bytes) + .GenerateExtMemQuantileDMatrix("temp", true); + if (!is_concat) { + ASSERT_EQ(p_ext_fmat->NumBatches(), n_batches); + } else { + ASSERT_EQ(p_ext_fmat->NumBatches(), n_batches / 2); + } + ASSERT_EQ(p_fmat->Info().num_row_, p_ext_fmat->Info().num_row_); + for (auto const& page_s : p_fmat->GetBatches(&ctx, param)) { + auto impl_s = page_s.Impl(); + auto cuts = impl_s->CutsShared(); + auto new_impl = std::make_unique(&ctx, cuts, sparsity == 0.0, + impl_s->info.row_stride, impl_s->n_rows); + new_impl->CopyInfo(impl_s); + bst_idx_t offset = 0; + for (auto const& page_m : p_ext_fmat->GetBatches(&ctx, param)) { + auto impl_m = page_m.Impl(); + auto cuts_m = page_m.Impl()->CutsShared(); + ASSERT_EQ(cuts->min_vals_.ConstHostVector(), cuts_m->min_vals_.ConstHostVector()); + ASSERT_EQ(cuts->cut_values_.ConstHostVector(), cuts_m->cut_values_.ConstHostVector()); + ASSERT_EQ(cuts->cut_ptrs_.ConstHostVector(), cuts_m->cut_ptrs_.ConstHostVector()); + offset += new_impl->Copy(&ctx, impl_m, offset); + } + std::vector buffer_s; + auto acc_s = impl_s->GetHostAccessor(&ctx, &buffer_s, {}); + std::vector buffer_m; + auto acc_m = new_impl->GetHostAccessor(&ctx, &buffer_m, {}); + ASSERT_EQ(acc_m.row_stride * acc_m.n_rows, acc_s.row_stride * acc_s.n_rows); + for (std::size_t i = 0; i < acc_m.row_stride * acc_m.n_rows; ++i) { + ASSERT_EQ(acc_s.gidx_iter[i], acc_m.gidx_iter[i]); + } + } + } +}; + +TEST_P(EllpackHostCacheTest, Basic) { + auto ctx = MakeCUDACtx(0); + auto [sparsity, min_page_cache_bytes] = this->GetParam(); + this->Run(sparsity, min_page_cache_bytes); +} + +INSTANTIATE_TEST_SUITE_P(ExtMemQuantileDMatrix, EllpackHostCacheTest, + ::testing::Combine(::testing::Values(0.0f, 0.2f, 0.4f, 0.8f), + ::testing::Bool())); } // namespace xgboost::data diff --git a/tests/cpp/data/test_sparse_page_dmatrix.cc b/tests/cpp/data/test_sparse_page_dmatrix.cc index a7c1bb3afb90..4b71bb989f66 100644 --- a/tests/cpp/data/test_sparse_page_dmatrix.cc +++ b/tests/cpp/data/test_sparse_page_dmatrix.cc @@ -11,6 +11,7 @@ #include "../../../src/data/adapter.h" #include "../../../src/data/file_iterator.h" #include "../../../src/data/simple_dmatrix.h" +#include "../../../src/data/batch_utils.h" // for MatchingPageBytes #include "../../../src/data/sparse_page_dmatrix.h" #include "../../../src/tree/param.h" // for TrainParam #include "../filesystem.h" // dmlc::TemporaryDirectory @@ -31,14 +32,13 @@ void TestSparseDMatrixLoadFile(Context const* ctx) { opath += "?indexing_mode=1&format=libsvm"; data::FileIterator iter{opath, 0, 1}; auto n_threads = 0; - data::SparsePageDMatrix m{&iter, - iter.Proxy(), - data::fileiter::Reset, - data::fileiter::Next, - std::numeric_limits::quiet_NaN(), - n_threads, - tmpdir.path + "cache", - false}; + auto config = ExtMemConfig{.cache = tmpdir.path + "cache", + .on_host = false, + .min_cache_page_bytes = cuda_impl::MatchingPageBytes(), + .missing = std::numeric_limits::quiet_NaN(), + .n_threads = n_threads}; + data::SparsePageDMatrix m{&iter, iter.Proxy(), data::fileiter::Reset, data::fileiter::Next, + config}; ASSERT_EQ(AllThreadsForTest(), m.Ctx()->Threads()); ASSERT_EQ(m.Info().num_col_, 5); ASSERT_EQ(m.Info().num_row_, 64); @@ -365,9 +365,13 @@ auto TestSparsePageDMatrixDeterminism(int32_t threads) { CreateBigTestData(filename, 1 << 16); data::FileIterator iter(filename + "?format=libsvm", 0, 1); + auto config = ExtMemConfig{.cache = filename, + .on_host = false, + .min_cache_page_bytes = cuda_impl::MatchingPageBytes(), + .missing = std::numeric_limits::quiet_NaN(), + .n_threads = threads}; std::unique_ptr sparse{new data::SparsePageDMatrix{ - &iter, iter.Proxy(), data::fileiter::Reset, data::fileiter::Next, - std::numeric_limits::quiet_NaN(), threads, filename, false}}; + &iter, iter.Proxy(), data::fileiter::Reset, data::fileiter::Next, config}}; CHECK(sparse->Ctx()->Threads() == threads || sparse->Ctx()->Threads() == AllThreadsForTest()); DMatrixToCSR(sparse.get(), &sparse_data, &sparse_rptr, &sparse_cids); diff --git a/tests/cpp/data/test_sparse_page_dmatrix.cu b/tests/cpp/data/test_sparse_page_dmatrix.cu index dd26eca609fc..110d0f898758 100644 --- a/tests/cpp/data/test_sparse_page_dmatrix.cu +++ b/tests/cpp/data/test_sparse_page_dmatrix.cu @@ -196,6 +196,7 @@ class TestEllpackPageExt : public ::testing::TestWithParambase_rowid, 0); ASSERT_EQ(impl->n_rows, kRows); ASSERT_EQ(impl->IsDense(), is_dense); - ASSERT_EQ(impl->info.row_stride, 2); - ASSERT_EQ(impl->Cuts().TotalBins(), 4); + ASSERT_EQ(impl->info.row_stride, kCols); + ASSERT_EQ(impl->Cuts().TotalBins(), param.max_bin * kCols); std::unique_ptr impl_ext; size_t offset = 0; diff --git a/tests/cpp/helpers.cc b/tests/cpp/helpers.cc index 78a6b3b03994..05d8fb0c8221 100644 --- a/tests/cpp/helpers.cc +++ b/tests/cpp/helpers.cc @@ -448,9 +448,15 @@ void MakeLabels(DeviceOrd device, bst_idx_t n_samples, bst_target_t n_classes, #endif // defined(XGBOOST_USE_CUDA) } - std::shared_ptr p_fmat{DMatrix::Create( - static_cast(iter.get()), iter->Proxy(), Reset, Next, - std::numeric_limits::quiet_NaN(), Context{}.Threads(), prefix, on_host_)}; + ExtMemConfig config{ + .cache = prefix, + .on_host = this->on_host_, + .min_cache_page_bytes = this->min_cache_page_bytes_, + .missing = std::numeric_limits::quiet_NaN(), + .n_threads = Context{}.Threads(), + }; + std::shared_ptr p_fmat{ + DMatrix::Create(static_cast(iter.get()), iter->Proxy(), Reset, Next, config)}; auto row_page_path = data::MakeId(prefix, dynamic_cast(p_fmat.get())) + ".row.page"; @@ -491,9 +497,16 @@ void MakeLabels(DeviceOrd device, bst_idx_t n_samples, bst_target_t n_classes, } CHECK(iter); - std::shared_ptr p_fmat{DMatrix::Create( - static_cast(iter.get()), iter->Proxy(), nullptr, Reset, Next, - std::numeric_limits::quiet_NaN(), 0, this->bins_, prefix, this->on_host_)}; + ExtMemConfig config{ + .cache = prefix, + .on_host = this->on_host_, + .min_cache_page_bytes = this->min_cache_page_bytes_, + .missing = std::numeric_limits::quiet_NaN(), + .n_threads = Context{}.Threads(), + }; + std::shared_ptr p_fmat{DMatrix::Create(static_cast(iter.get()), + iter->Proxy(), nullptr, Reset, Next, this->bins_, + config)}; auto page_path = data::MakeId(prefix, p_fmat.get()); page_path += device_.IsCPU() ? ".gradient_index.page" : ".ellpack.page"; diff --git a/tests/cpp/helpers.h b/tests/cpp/helpers.h index 7137d0d51fda..0d0ad2bbb41c 100644 --- a/tests/cpp/helpers.h +++ b/tests/cpp/helpers.h @@ -54,7 +54,7 @@ class GradientBooster; template Float RelError(Float l, Float r) { - static_assert(std::is_floating_point::value); + static_assert(std::is_floating_point_v); return std::abs(1.0f - l / r); } @@ -166,8 +166,7 @@ class SimpleRealUniformDistribution { /*! \brief Over-simplified version of std::generate_canonical. */ template ResultT GenerateCanonical(GeneratorT* rng) const { - static_assert(std::is_floating_point::value, - "Result type must be floating point."); + static_assert(std::is_floating_point_v, "Result type must be floating point."); long double const r = (static_cast(rng->Max()) - static_cast(rng->Min())) + 1.0L; auto const log2r = static_cast(std::log(r) / std::log(2.0L)); @@ -240,6 +239,7 @@ class RandomDataGenerator { std::vector ft_; bst_cat_t max_cat_{32}; bool on_host_{false}; + bst_idx_t min_cache_page_bytes_{0}; Json ArrayInterfaceImpl(HostDeviceVector* storage, size_t rows, size_t cols) const; @@ -269,6 +269,10 @@ class RandomDataGenerator { on_host_ = on_host; return *this; } + RandomDataGenerator& MinPageCacheBytes(bst_idx_t min_cache_page_bytes) { + this->min_cache_page_bytes_ = min_cache_page_bytes; + return *this; + } RandomDataGenerator& Seed(uint64_t s) { seed_ = s; lcg_.Seed(seed_); diff --git a/tests/python-gpu/test_gpu_data_iterator.py b/tests/python-gpu/test_gpu_data_iterator.py index 63333579f140..bed4880bb8ff 100644 --- a/tests/python-gpu/test_gpu_data_iterator.py +++ b/tests/python-gpu/test_gpu_data_iterator.py @@ -1,5 +1,6 @@ import sys +import numpy as np import pytest from hypothesis import given, settings, strategies @@ -12,6 +13,9 @@ from test_data_iterator import run_data_iterator from test_data_iterator import test_single_batch as cpu_single_batch +# There are lots of warnings if XGBoost is not running on ATS-enabled systems. +pytestmark = pytest.mark.filterwarnings("ignore") + def test_gpu_single_batch() -> None: cpu_single_batch("hist", "cuda") @@ -69,7 +73,6 @@ def test_cpu_data_iterator() -> None: strategies.booleans(), ) @settings(deadline=None, max_examples=10, print_blob=True) -@pytest.mark.filterwarnings("ignore") def test_extmem_qdm( n_samples_per_batch: int, n_features: int, @@ -87,7 +90,6 @@ def test_extmem_qdm( ) -@pytest.mark.filterwarnings("ignore") def test_invalid_device_extmem_qdm() -> None: it = tm.IteratorForTest( *tm.make_batches(16, 4, 2, use_cupy=False), cache="cache", on_host=True @@ -104,7 +106,7 @@ def test_invalid_device_extmem_qdm() -> None: xgb.train({"device": "cpu"}, Xy) -def test_concat_pages() -> None: +def test_concat_pages_invalid() -> None: it = tm.IteratorForTest(*tm.make_batches(64, 16, 4, use_cupy=True), cache=None) Xy = xgb.ExtMemQuantileDMatrix(it) with pytest.raises(ValueError, match="can not be used with concatenated pages"): @@ -120,6 +122,28 @@ def test_concat_pages() -> None: ) +def test_concat_pages() -> None: + boosters = [] + for min_cache_page_bytes in [0, 256, 386, np.iinfo(np.int64).max]: + it = tm.IteratorForTest( + *tm.make_batches(64, 16, 4, use_cupy=True), + cache=None, + min_cache_page_bytes=min_cache_page_bytes, + on_host=True, + ) + Xy = xgb.ExtMemQuantileDMatrix(it) + booster = xgb.train( + { + "device": "cuda", + "objective": "reg:absoluteerror", + }, + Xy, + ) + boosters.append(booster.save_raw(raw_format="json")) + + for model in boosters[1:]: + assert str(model) == str(boosters[0]) + @given( strategies.integers(1, 64), strategies.integers(1, 8),