Skip to content

Commit

Permalink
[EM] Improve external iterator. (#10876)
Browse files Browse the repository at this point in the history
- Use bool instead of int in the Python interface.
- Unify the internal advance operator for the external-memory QuantileDMatrix (ext QDM) between CPU and GPU.
- Avoid the use of `n_batches` calculated from the external iterator and use the cache size instead. We will add support for internal page concatenation for ellpack to handle irregular batch sizes from distributed frameworks.
  • Loading branch information
trivialfis authored Oct 8, 2024
1 parent d9bb4fb commit 56d155c
Show file tree
Hide file tree
Showing 21 changed files with 137 additions and 113 deletions.
10 changes: 5 additions & 5 deletions demo/guide-python/external_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,21 +72,21 @@ def load_file(self) -> Tuple[np.ndarray, np.ndarray]:
assert X.shape[0] == y.shape[0]
return X, y

def next(self, input_data: Callable) -> int:
def next(self, input_data: Callable) -> bool:
"""Advance the iterator by 1 step and pass the data to XGBoost. This function
is called by XGBoost during the construction of ``DMatrix``
"""
if self._it == len(self._file_paths):
# return 0 to let XGBoost know this is the end of iteration
return 0
# return False to let XGBoost know this is the end of iteration
return False

# input_data is a function passed in by XGBoost and has a similar signature to
# the ``DMatrix`` constructor.
X, y = self.load_file()
input_data(data=X, label=y)
self._it += 1
return 1
return True

def reset(self) -> None:
"""Reset the iterator to its beginning"""
Expand Down Expand Up @@ -153,7 +153,7 @@ def main(tmpdir: str, args: argparse.Namespace) -> None:

# It's important to use RMM for GPU-based external memory to improve performance.
# If XGBoost is not built with RMM support, a warning will be raised.
mr = rmm.mr.PoolMemoryResource(rmm.mr.CudaAsyncMemoryResource())
mr = rmm.mr.CudaAsyncMemoryResource()
rmm.mr.set_current_device_resource(mr)
# Set the allocator for cupy as well.
cp.cuda.set_allocator(rmm_cupy_allocator)
Expand Down
2 changes: 1 addition & 1 deletion doc/tutorials/external_memory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ the GPU. This is a current limitation we aim to address in the future.
# It's important to use RMM for GPU-based external memory to improve performance.
# If XGBoost is not built with RMM support, a warning will be raised.
mr = rmm.mr.PoolMemoryResource(rmm.mr.CudaAsyncMemoryResource())
mr = rmm.mr.CudaAsyncMemoryResource()
rmm.mr.set_current_device_resource(mr)
# Set the allocator for cupy as well.
cp.cuda.set_allocator(rmm_cupy_allocator)
Expand Down
6 changes: 3 additions & 3 deletions python-package/xgboost/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,15 +665,15 @@ def input_data(
if self._release:
self._temporary_data = None
# pylint: disable=not-callable
return self._handle_exception(lambda: self.next(input_data), 0)
return self._handle_exception(lambda: int(self.next(input_data)), 0)

@abstractmethod
def reset(self) -> None:
"""Reset the data iterator. Prototype for user defined function."""
raise NotImplementedError()

@abstractmethod
def next(self, input_data: Callable) -> int:
def next(self, input_data: Callable) -> bool:
"""Set the next batch of data.
Parameters
Expand All @@ -685,7 +685,7 @@ def next(self, input_data: Callable) -> int:
Returns
-------
0 if there's no more batch, otherwise 1.
False if there's no more batch, otherwise True.
"""
raise NotImplementedError()
Expand Down
8 changes: 4 additions & 4 deletions python-package/xgboost/dask/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,11 +636,11 @@ def reset(self) -> None:
"""Reset the iterator"""
self._iter = 0

def next(self, input_data: Callable) -> int:
def next(self, input_data: Callable) -> bool:
"""Yield next batch of data"""
if self._iter == len(self._data):
# Return 0 when there's no more batch.
return 0
# Return False when there's no more batch.
return False

input_data(
data=self.data(),
Expand All @@ -656,7 +656,7 @@ def next(self, input_data: Callable) -> int:
feature_weights=self._feature_weights,
)
self._iter += 1
return 1
return True


class DaskQuantileDMatrix(DaskDMatrix):
Expand Down
6 changes: 3 additions & 3 deletions python-package/xgboost/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1472,12 +1472,12 @@ def __init__(self, **kwargs: Any) -> None:
# might use memory.
super().__init__(release_data=False)

def next(self, input_data: Callable) -> int:
def next(self, input_data: Callable) -> bool:
if self.it == 1:
return 0
return False
self.it += 1
input_data(**self.kwargs)
return 1
return True

def reset(self) -> None:
self.it = 0
Expand Down
6 changes: 3 additions & 3 deletions python-package/xgboost/spark/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,9 @@ def _fetch(self, data: Optional[Sequence[pd.DataFrame]]) -> Optional[pd.DataFram

return data[self._iter]

def next(self, input_data: Callable) -> int:
def next(self, input_data: Callable) -> bool:
if self._iter == len(self._data[alias.data]):
return 0
return False
input_data(
data=self._fetch(self._data[alias.data]),
label=self._fetch(self._data.get(alias.label, None)),
Expand All @@ -106,7 +106,7 @@ def next(self, input_data: Callable) -> int:
**self._kwargs,
)
self._iter += 1
return 1
return True

def reset(self) -> None:
self._iter = 0
Expand Down
6 changes: 3 additions & 3 deletions python-package/xgboost/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,9 +235,9 @@ def __init__( # pylint: disable=too-many-arguments
self.it = 0
super().__init__(cache_prefix=cache, on_host=on_host)

def next(self, input_data: Callable) -> int:
def next(self, input_data: Callable) -> bool:
if self.it == len(self.X):
return 0
return False

with pytest.raises(TypeError, match="Keyword argument"):
input_data(self.X[self.it], self.y[self.it], None)
Expand All @@ -250,7 +250,7 @@ def next(self, input_data: Callable) -> int:
)
gc.collect() # clear up the copy, see if XGBoost access freed memory.
self.it += 1
return 1
return True

def reset(self) -> None:
self.it = 0
Expand Down
7 changes: 7 additions & 0 deletions src/data/device_adapter.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ class CudfAdapter : public detail::SingleBatchDataIter<CudfAdapterBatch> {
dh::safe_cuda(cudaSetDevice(device_.ordinal));
for (auto& json_col : json_columns) {
auto column = ArrayInterface<1>(get<Object const>(json_col));
n_bytes_ += column.ElementSize() * column.Shape(0);
columns.push_back(column);
num_rows_ = std::max(num_rows_, column.Shape(0));
CHECK_EQ(device_.ordinal, dh::CudaGetPointerDevice(column.data))
Expand All @@ -145,11 +146,13 @@ class CudfAdapter : public detail::SingleBatchDataIter<CudfAdapterBatch> {
[[nodiscard]] std::size_t NumRows() const { return num_rows_; }
[[nodiscard]] std::size_t NumColumns() const { return columns_.size(); }
[[nodiscard]] DeviceOrd Device() const { return device_; }
[[nodiscard]] bst_idx_t SizeBytes() const { return this->n_bytes_; }

private:
CudfAdapterBatch batch_;
dh::device_vector<ArrayInterface<1>> columns_;
size_t num_rows_{0};
bst_idx_t n_bytes_{0};
DeviceOrd device_{DeviceOrd::CPU()};
};

Expand Down Expand Up @@ -189,6 +192,8 @@ class CupyAdapter : public detail::SingleBatchDataIter<CupyAdapterBatch> {
return;
}
device_ = DeviceOrd::CUDA(dh::CudaGetPointerDevice(array_interface_.data));
this->n_bytes_ =
array_interface_.Shape(0) * array_interface_.Shape(1) * array_interface_.ElementSize();
CHECK(device_.IsCUDA());
}
explicit CupyAdapter(std::string cuda_interface_str)
Expand All @@ -198,10 +203,12 @@ class CupyAdapter : public detail::SingleBatchDataIter<CupyAdapterBatch> {
[[nodiscard]] std::size_t NumRows() const { return array_interface_.Shape(0); }
[[nodiscard]] std::size_t NumColumns() const { return array_interface_.Shape(1); }
[[nodiscard]] DeviceOrd Device() const { return device_; }
[[nodiscard]] bst_idx_t SizeBytes() const { return this->n_bytes_; }

private:
ArrayInterface<2> array_interface_;
CupyAdapterBatch batch_;
bst_idx_t n_bytes_{0};
DeviceOrd device_{DeviceOrd::CPU()};
};

Expand Down
28 changes: 13 additions & 15 deletions src/data/ellpack_page_source.cu
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
/**
* Copyright 2019-2024, XGBoost contributors
*/
#include <thrust/host_vector.h> // for host_vector

#include <cstddef> // for size_t
#include <cstdint> // for int8_t, uint64_t, uint32_t
#include <memory> // for shared_ptr, make_unique, make_shared
Expand Down Expand Up @@ -46,7 +44,7 @@ EllpackPageImpl const* EllpackHostCache::Get(std::int32_t k) {
*/
class EllpackHostCacheStreamImpl {
std::shared_ptr<EllpackHostCache> cache_;
std::int32_t ptr_;
std::int32_t ptr_{0};

public:
explicit EllpackHostCacheStreamImpl(std::shared_ptr<EllpackHostCache> cache)
Expand Down Expand Up @@ -204,21 +202,22 @@ template <typename F>
void EllpackPageSourceImpl<F>::Fetch() {
curt::SetDevice(this->Device().ordinal);
if (!this->ReadCache()) {
if (this->count_ != 0 && !this->sync_) {
if (this->Iter() != 0 && !this->sync_) {
// source is initialized to be the 0th page during construction, so when Iter() is 0
// there's no need to increment the source.
++(*this->source_);
}
// This is not read from cache so we still need it to be synced with sparse page source.
CHECK_EQ(this->count_, this->source_->Iter());
CHECK_EQ(this->Iter(), this->source_->Iter());
auto const& csr = this->source_->Page();
this->page_.reset(new EllpackPage{});
auto* impl = this->page_->Impl();
Context ctx = Context{}.MakeCUDA(this->Device().ordinal);
*impl = EllpackPageImpl{&ctx, this->GetCuts(), *csr, is_dense_, row_stride_, feature_types_};
this->page_->SetBaseRowId(csr->base_rowid);
LOG(INFO) << "Generated an Ellpack page with size: " << impl->MemCostBytes()
<< " from a SparsePage with size:" << csr->MemCostBytes();
LOG(INFO) << "Generated an Ellpack page with size: "
<< common::HumanMemUnit(impl->MemCostBytes())
<< " from a SparsePage with size:" << common::HumanMemUnit(csr->MemCostBytes());
this->WriteCache();
}
}
Expand All @@ -239,9 +238,7 @@ void ExtEllpackPageSourceImpl<F>::Fetch() {
curt::SetDevice(this->Device().ordinal);
if (!this->ReadCache()) {
auto iter = this->source_->Iter();
CHECK_EQ(this->count_, iter);
++(*this->source_);
CHECK_GE(this->source_->Iter(), 1);
CHECK_EQ(this->Iter(), iter);
cuda_impl::Dispatch(proxy_, [this](auto const& value) {
CHECK(this->proxy_->Ctx()->IsCUDA()) << "All batches must use the same device type.";
proxy_->Info().feature_types.SetDevice(dh::GetDevice(this->ctx_));
Expand All @@ -250,11 +247,7 @@ void ExtEllpackPageSourceImpl<F>::Fetch() {

dh::device_vector<size_t> row_counts(n_samples + 1, 0);
common::Span<size_t> row_counts_span(row_counts.data().get(), row_counts.size());
cuda_impl::Dispatch(proxy_, [=](auto const& value) {
return GetRowCounts(this->ctx_, value, row_counts_span, dh::GetDevice(this->ctx_),
this->missing_);
});

GetRowCounts(this->ctx_, value, row_counts_span, dh::GetDevice(this->ctx_), this->missing_);
this->page_.reset(new EllpackPage{});
*this->page_->Impl() = EllpackPageImpl{this->ctx_,
value,
Expand All @@ -267,6 +260,11 @@ void ExtEllpackPageSourceImpl<F>::Fetch() {
this->GetCuts()};
this->info_->Extend(proxy_->Info(), false, true);
});
// The size of ellpack is logged in write cache.
LOG(INFO) << "Estimated batch size:"
<< cuda_impl::Dispatch<false>(proxy_, [](auto const& adapter) {
return common::HumanMemUnit(adapter->SizeBytes());
});
this->page_->SetBaseRowId(this->ext_info_.base_rows.at(iter));
this->WriteCache();
}
Expand Down
21 changes: 9 additions & 12 deletions src/data/ellpack_page_source.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,9 @@ class EllpackFormatPolicy {
curt::DrVersion(&major, &minor);
if (!(major >= 12 && minor >= 7) && curt::SupportsAts()) {
// Use ATS, but with an old kernel driver.
LOG(WARNING) << "Using an old kernel driver with supported CTK<12.7." << msg;
LOG(WARNING) << "Using an old kernel driver with supported CTK<12.7."
<< "The latest version of CTK supported by the current driver: " << major << "."
<< minor << "." << msg;
}
}
// For testing with the HMM flag.
Expand Down Expand Up @@ -206,28 +208,23 @@ class ExtEllpackPageSourceImpl : public ExtQantileSourceMixin<EllpackPage, Forma
MetaInfo* info_;
ExternalDataInfo ext_info_;

std::vector<bst_idx_t> base_rows_;

public:
ExtEllpackPageSourceImpl(
Context const* ctx, float missing, MetaInfo* info, ExternalDataInfo ext_info,
std::shared_ptr<Cache> cache, BatchParam param, std::shared_ptr<common::HistogramCuts> cuts,
std::shared_ptr<DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext>> source,
DMatrixProxy* proxy, std::vector<bst_idx_t> base_rows)
: Super{missing,
ctx->Threads(),
static_cast<bst_feature_t>(info->num_col_),
ext_info.n_batches,
source,
cache},
DMatrixProxy* proxy)
: Super{missing, ctx->Threads(), static_cast<bst_feature_t>(info->num_col_), source, cache},
ctx_{ctx},
p_{std::move(param)},
proxy_{proxy},
info_{info},
ext_info_{std::move(ext_info)},
base_rows_{std::move(base_rows)} {
ext_info_{std::move(ext_info)} {
cuts->SetDevice(ctx->Device());
this->SetCuts(std::move(cuts), ctx->Device());
CHECK(!this->cache_info_->written);
this->source_->Reset();
CHECK(this->source_->Next());
this->Fetch();
}

Expand Down
3 changes: 1 addition & 2 deletions src/data/extmem_quantile_dmatrix.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,7 @@ void ExtMemQuantileDMatrix::InitFromCPU(
*/
auto id = MakeCache(this, ".gradient_index.page", false, cache_prefix_, &cache_info_);
this->ghist_index_source_ = std::make_unique<ExtGradientIndexPageSource>(
ctx, missing, &this->Info(), ext_info.n_batches, cache_info_.at(id), p, cuts, iter, proxy,
ext_info.base_rows);
ctx, missing, &this->Info(), cache_info_.at(id), p, cuts, iter, proxy, ext_info.base_rows);

/**
* Force initialize the cache and do some sanity checks along the way
Expand Down
2 changes: 1 addition & 1 deletion src/data/extmem_quantile_dmatrix.cu
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ void ExtMemQuantileDMatrix::InitFromCUDA(
[&](auto &&ptr) {
using SourceT = typename std::remove_reference_t<decltype(ptr)>::element_type;
ptr = std::make_shared<SourceT>(ctx, missing, &this->Info(), ext_info, cache_info_.at(id),
p, cuts, iter, proxy, ext_info.base_rows);
p, cuts, iter, proxy);
},
ellpack_page_source_);

Expand Down
8 changes: 3 additions & 5 deletions src/data/gradient_index_page_source.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@ void GradientIndexPageSource::Fetch() {
void ExtGradientIndexPageSource::Fetch() {
if (!this->ReadCache()) {
CHECK_EQ(count_, source_->Iter());
++(*source_);
CHECK_GE(source_->Iter(), 1);
CHECK_NE(cuts_.Values().size(), 0);
HostAdapterDispatch(proxy_, [this](auto const& value) {
CHECK(this->proxy_->Ctx()->IsCPU()) << "All batches must use the same device type.";
Expand All @@ -49,9 +47,9 @@ void ExtGradientIndexPageSource::Fetch() {
// FIXME(jiamingy): For now, we use the `info->IsDense()` to represent all batches
// similar to the sparse DMatrix source. We should use per-batch property with proxy
// DMatrix info instead. This requires more fine-grained tests.
this->page_ = std::make_shared<GHistIndexMatrix>(
value.NumRows(), this->base_rows_.at(source_->Iter() - 1), std::move(cuts),
this->p_.max_bin, info_->IsDense());
this->page_ =
std::make_shared<GHistIndexMatrix>(value.NumRows(), this->base_rows_.at(source_->Iter()),
std::move(cuts), this->p_.max_bin, info_->IsDense());
bst_idx_t prev_sum = 0;
bst_idx_t rbegin = 0;
// Use `value.NumRows()` for the size of a single batch. Unlike the
Expand Down
13 changes: 8 additions & 5 deletions src/data/gradient_index_page_source.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class GradientIndexPageSource

public:
GradientIndexPageSource(float missing, std::int32_t nthreads, bst_feature_t n_features,
size_t n_batches, std::shared_ptr<Cache> cache, BatchParam param,
bst_idx_t n_batches, std::shared_ptr<Cache> cache, BatchParam param,
common::HistogramCuts cuts, bool is_dense,
common::Span<FeatureType const> feature_types,
std::shared_ptr<SparsePageSource> source)
Expand Down Expand Up @@ -81,18 +81,21 @@ class ExtGradientIndexPageSource

public:
ExtGradientIndexPageSource(
Context const* ctx, float missing, MetaInfo* info, bst_idx_t n_batches,
std::shared_ptr<Cache> cache, BatchParam param, common::HistogramCuts cuts,
Context const* ctx, float missing, MetaInfo* info, std::shared_ptr<Cache> cache,
BatchParam param, common::HistogramCuts cuts,
std::shared_ptr<DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext>> source,
DMatrixProxy* proxy, std::vector<bst_idx_t> base_rows)
: ExtQantileSourceMixin{missing, ctx->Threads(), static_cast<bst_feature_t>(info->num_col_),
n_batches, source, cache},
: ExtQantileSourceMixin{missing, ctx->Threads(), static_cast<bst_feature_t>(info->num_col_),
source, cache},
p_{std::move(param)},
ctx_{ctx},
proxy_{proxy},
info_{info},
feature_types_{info_->feature_types.ConstHostSpan()},
base_rows_{std::move(base_rows)} {
CHECK(!this->cache_info_->written);
this->source_->Reset();
CHECK(this->source_->Next());
this->SetCuts(std::move(cuts));
this->Fetch();
}
Expand Down
Loading

0 comments on commit 56d155c

Please sign in to comment.