Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion include/xgboost/collective/result.h
Original file line number Diff line number Diff line change
Expand Up @@ -159,5 +159,6 @@ template <typename Fn>
return fn();
}

void SafeColl(Result const& rc);
void SafeColl(Result const& rc, char const* file = __builtin_FILE(),
std::int32_t line = __builtin_LINE());
} // namespace xgboost::collective
21 changes: 2 additions & 19 deletions python-package/xgboost/dask/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
import logging
from collections import defaultdict
from contextlib import contextmanager
from functools import cache, partial, update_wrapper
from functools import partial, update_wrapper
from threading import Thread
from typing import (
Any,
Expand Down Expand Up @@ -85,8 +85,6 @@
from dask import dataframe as dd
from dask.delayed import Delayed
from distributed import Future
from packaging.version import Version
from packaging.version import parse as parse_version

from .. import collective, config
from .._data_utils import Categories
Expand Down Expand Up @@ -124,7 +122,7 @@
from ..tracker import RabitTracker
from ..training import train as worker_train
from .data import _get_dmatrices, no_group_split
from .utils import get_address_from_user, get_n_threads
from .utils import _DASK_2024_12_1, _DASK_2025_3_0, get_address_from_user, get_n_threads

_DaskCollection: TypeAlias = Union[da.Array, dd.DataFrame, dd.Series]
_DataT: TypeAlias = Union[da.Array, dd.DataFrame] # do not use series as predictor
Expand Down Expand Up @@ -174,21 +172,6 @@
LOGGER = logging.getLogger("[xgboost.dask]")


@cache
def _DASK_VERSION() -> Version:
return parse_version(dask.__version__)


@cache
def _DASK_2024_12_1() -> bool:
return _DASK_VERSION() >= parse_version("2024.12.1")


@cache
def _DASK_2025_3_0() -> bool:
return _DASK_VERSION() >= parse_version("2025.3.0")


def _try_start_tracker(
n_workers: int,
addrs: List[Union[Optional[str], Optional[Tuple[str, int]]]],
Expand Down
20 changes: 20 additions & 0 deletions python-package/xgboost/dask/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
# pylint: disable=invalid-name
"""Utilities for the XGBoost Dask interface."""

import logging
import warnings
from functools import cache as fcache
from typing import Any, Dict, Optional, Tuple

import dask
import distributed
from packaging.version import Version
from packaging.version import parse as parse_version

from ..collective import Config

Expand Down Expand Up @@ -97,3 +102,18 @@ def get_address_from_user(
port = coll_cfg.tracker_port

return host_ip, port


@fcache
def _DASK_VERSION() -> Version:
    """Return the parsed version of the installed ``dask`` package (computed once, cached)."""
    return parse_version(dask.__version__)


@fcache
def _DASK_2024_12_1() -> bool:
    """Whether the installed dask is at least version 2024.12.1 (cached)."""
    return _DASK_VERSION() >= parse_version("2024.12.1")


@fcache
def _DASK_2025_3_0() -> bool:
    """Whether the installed dask is at least version 2025.3.0 (cached)."""
    return _DASK_VERSION() >= parse_version("2025.3.0")
126 changes: 72 additions & 54 deletions python-package/xgboost/testing/dask.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
# pylint: disable=invalid-name
"""Tests for dask shared by different test modules."""

from typing import Any, List, Literal, Tuple, cast
from typing import Any, List, Literal, Tuple, Type, cast

import numpy as np
import pandas as pd
from dask import array as da
from dask import dataframe as dd
from distributed import Client, get_worker, wait
from distributed import Client, get_worker
from packaging.version import parse as parse_version
from sklearn.datasets import make_classification

Expand All @@ -17,7 +18,8 @@

from .. import dask as dxgb
from .._typing import EvalsLog
from ..dask import _DASK_VERSION, _get_rabit_args
from ..dask import _get_rabit_args
from ..dask.utils import _DASK_VERSION
from .data import make_batches
from .data import make_categorical as make_cat_local
from .ordinal import make_recoded
Expand Down Expand Up @@ -325,61 +327,77 @@ def pack(**kwargs: Any) -> dd.DataFrame:
# pylint: disable=too-many-locals
def run_recode(client: Client, device: Device) -> None:
"""Run re-coding test with the Dask interface."""
enc, reenc, y, _, _ = make_recoded(device, n_features=96)
workers = get_client_workers(client)
denc, dreenc, dy = (
dd.from_pandas(enc, npartitions=8).persist(workers=workers),
dd.from_pandas(reenc, npartitions=8).persist(workers=workers),
da.from_array(y, chunks=(y.shape[0] // 8,)).persist(workers=workers),
)

wait([denc, dreenc, dy])
def create_dmatrix(
DMatrixT: Type[dxgb.DaskDMatrix], *args: Any, **kwargs: Any
) -> dxgb.DaskDMatrix:
if DMatrixT is dxgb.DaskQuantileDMatrix:
ref = kwargs.pop("ref", None)
return DMatrixT(*args, ref=ref, **kwargs)

if device == "cuda":
denc = denc.to_backend("cudf")
dreenc = dreenc.to_backend("cudf")
dy = dy.to_backend("cupy")
kwargs.pop("ref", None)
return DMatrixT(*args, **kwargs)

Xy = dxgb.DaskQuantileDMatrix(client, denc, dy, enable_categorical=True)
Xy_valid = dxgb.DaskQuantileDMatrix(
client, dreenc, dy, enable_categorical=True, ref=Xy
)
# Base model
results = dxgb.train(client, {"device": device}, Xy, evals=[(Xy_valid, "Valid")])
def run(DMatrixT: Type[dxgb.DaskDMatrix]) -> None:
enc, reenc, y, _, _ = make_recoded(device, n_features=96)
to = get_client_workers(client)

# Training continuation
Xy = dxgb.DaskQuantileDMatrix(client, denc, dy, enable_categorical=True)
Xy_valid = dxgb.DaskQuantileDMatrix(
client, dreenc, dy, enable_categorical=True, ref=Xy
)
results_1 = dxgb.train(
client,
{"device": device},
Xy,
evals=[(Xy_valid, "Valid")],
xgb_model=results["booster"],
)
denc, dreenc, dy = (
dd.from_pandas(enc, npartitions=8).persist(workers=to),
dd.from_pandas(reenc, npartitions=8).persist(workers=to),
da.from_array(y, chunks=(y.shape[0] // 8,)).persist(workers=to),
)

# Reversed training continuation
Xy = dxgb.DaskQuantileDMatrix(client, dreenc, dy, enable_categorical=True)
Xy_valid = dxgb.DaskQuantileDMatrix(
client, denc, dy, enable_categorical=True, ref=Xy
)
results_2 = dxgb.train(
client,
{"device": device},
Xy,
evals=[(Xy_valid, "Valid")],
xgb_model=results["booster"],
)
np.testing.assert_allclose(
results_1["history"]["Valid"]["rmse"], results_2["history"]["Valid"]["rmse"]
)
if device == "cuda":
denc = denc.to_backend("cudf")
dreenc = dreenc.to_backend("cudf")
dy = dy.to_backend("cupy")

Xy = create_dmatrix(DMatrixT, client, denc, dy, enable_categorical=True)
Xy_valid = create_dmatrix(
DMatrixT, client, dreenc, dy, enable_categorical=True, ref=Xy
)
# Base model
results = dxgb.train(
client, {"device": device}, Xy, evals=[(Xy_valid, "Valid")]
)

# Training continuation
Xy = create_dmatrix(DMatrixT, client, denc, dy, enable_categorical=True)
Xy_valid = create_dmatrix(
DMatrixT, client, dreenc, dy, enable_categorical=True, ref=Xy
)
results_1 = dxgb.train(
client,
{"device": device},
Xy,
evals=[(Xy_valid, "Valid")],
xgb_model=results["booster"],
)

# Reversed training continuation
Xy = create_dmatrix(DMatrixT, client, dreenc, dy, enable_categorical=True)
Xy_valid = create_dmatrix(
DMatrixT, client, denc, dy, enable_categorical=True, ref=Xy
)
results_2 = dxgb.train(
client,
{"device": device},
Xy,
evals=[(Xy_valid, "Valid")],
xgb_model=results["booster"],
)
np.testing.assert_allclose(
results_1["history"]["Valid"]["rmse"], results_2["history"]["Valid"]["rmse"]
)

predt_0 = dxgb.inplace_predict(client, results, denc).compute()
predt_1 = dxgb.inplace_predict(client, results, dreenc).compute()
assert_allclose(device, predt_0, predt_1)

predt_0 = dxgb.inplace_predict(client, results, denc).compute()
predt_1 = dxgb.inplace_predict(client, results, dreenc).compute()
assert_allclose(device, predt_0, predt_1)
predt_0 = dxgb.predict(client, results, Xy).compute()
predt_1 = dxgb.predict(client, results, Xy_valid).compute()
assert_allclose(device, predt_0, predt_1)

predt_0 = dxgb.predict(client, results, Xy).compute()
predt_1 = dxgb.predict(client, results, Xy_valid).compute()
assert_allclose(device, predt_0, predt_1)
for DMatrixT in [dxgb.DaskDMatrix, dxgb.DaskQuantileDMatrix]:
run(DMatrixT)
20 changes: 15 additions & 5 deletions src/collective/result.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright 2024, XGBoost Contributors
* Copyright 2024-2025, XGBoost Contributors
*/
#include "xgboost/collective/result.h"

Expand Down Expand Up @@ -65,17 +65,27 @@ void ResultImpl::Concat(std::unique_ptr<ResultImpl> rhs) {
std::string MakeMsg(std::string&& msg, char const* file, std::int32_t line) {
dmlc::DateLogger logger;
if (file && line != -1) {
auto name = std::filesystem::path{ file }.filename();
auto name = std::filesystem::path{file}.filename();
return "[" + name.string() + ":" + std::to_string(line) + "|" + logger.HumanDate() +
"]: " + std::forward<std::string>(msg);
}
return std::string{"["} + logger.HumanDate() + "]" + std::forward<std::string>(msg); // NOLINT
}
} // namespace detail

void SafeColl(Result const& rc) {
if (!rc.OK()) {
LOG(FATAL) << rc.Report();
/**
 * Abort with a fatal log if the collective result `rc` indicates failure.
 *
 * When a call site is available (`file` non-null and `line != -1`), the fatal message
 * is prefixed with the file name, line number, and a timestamp before the error
 * report from `rc`.
 */
void SafeColl(Result const& rc, char const* file, std::int32_t line) {
  if (rc.OK()) {
    return;
  }
  if (file && line != -1) {
    dmlc::DateLogger logger;
    // Keep only the file name; drop the directory components of the path.
    auto name = std::filesystem::path{file}.filename();
    LOG(FATAL) << ("[" + name.string() + ":" + std::to_string(line) + "|" + logger.HumanDate() +
                   "]:\n")
               << rc.Report();
    // Return explicitly, in case this function is called deep inside ctypes callbacks.
    return;
  }
  LOG(FATAL) << rc.Report();
}
} // namespace xgboost::collective
26 changes: 23 additions & 3 deletions src/data/cat_container.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@
#include <utility> // for move
#include <vector> // for vector

#include "../common/error_msg.h" // for NoFloatCat
#include "../encoder/types.h" // for Overloaded
#include "xgboost/json.h" // for Json
#include "../collective/allreduce.h" // for Allreduce
#include "../collective/communicator-inl.h" // for GetRank, GetWorldSize
#include "../common/error_msg.h" // for NoFloatCat
#include "../encoder/types.h" // for Overloaded
#include "xgboost/json.h" // for Json

namespace xgboost {
CatContainer::CatContainer(enc::HostColumnsView const& df, bool is_ref) : CatContainer{} {
Expand Down Expand Up @@ -293,4 +295,22 @@ void CatContainer::Sort(Context const* ctx) {
enc::SortNames(enc::Policy<EncErrorPolicy>{}, view, this->sorted_idx_.HostSpan());
}
#endif // !defined(XGBOOST_USE_CUDA)

/**
 * Verify across all workers that no worker has an empty input when categorical
 * features are in use.
 *
 * Each worker contributes a 0/1 "is empty" flag via a sum-allreduce. If any worker
 * reports an empty input while this worker's container has categorical features,
 * abort: the categories cannot be inferred from an empty input.
 *
 * No-op when not running in distributed mode.
 */
void SyncCategories(Context const* ctx, CatContainer* cats, bool is_empty) {
  CHECK(cats);
  if (!collective::IsDistributed()) {
    return;
  }

  auto rank = collective::GetRank();
  // One slot per worker; each worker records its own empty flag at its rank.
  std::vector<std::int32_t> workers(collective::GetWorldSize(), 0);
  workers[rank] = is_empty;
  collective::SafeColl(collective::Allreduce(ctx, &workers, collective::Op::kSum));
  if (cats->HasCategorical() &&
      std::any_of(workers.cbegin(), workers.cend(), [](auto v) { return v == 1; })) {
    LOG(FATAL)
        << "A worker cannot have empty input when a dataframe with categorical features is used. "
           "XGBoost cannot infer the categories if the input is empty.";
  }
}
} // namespace xgboost
4 changes: 3 additions & 1 deletion src/data/cat_container.h
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ class CatContainer {
* this method returns True.
*/
[[nodiscard]] bool Empty() const;
[[nodiscard]] bool NeedRecode() const { return !this->Empty() && !this->is_ref_; }
[[nodiscard]] bool NeedRecode() const { return this->HasCategorical() && !this->is_ref_; }

[[nodiscard]] std::size_t NumFeatures() const;
/**
Expand Down Expand Up @@ -263,6 +263,8 @@ struct NoOpAccessor {
[[nodiscard]] XGBOOST_DEVICE float operator()(Entry const& e) const { return e.fvalue; }
};

void SyncCategories(Context const* ctx, CatContainer* cats, bool is_empty);

namespace cpu_impl {
inline auto MakeCatAccessor(Context const* ctx, enc::HostColumnsView const& new_enc,
CatContainer const* orig_cats) {
Expand Down
2 changes: 2 additions & 0 deletions src/data/extmem_quantile_dmatrix.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ ExtMemQuantileDMatrix::ExtMemQuantileDMatrix(DataIterHandle iter_handle, DMatrix
}
this->batch_ = p;
this->fmat_ctx_ = ctx;

SyncCategories(&ctx, info_.Cats(), info_.num_row_ == 0);
}

ExtMemQuantileDMatrix::~ExtMemQuantileDMatrix() {
Expand Down
5 changes: 4 additions & 1 deletion src/data/iterative_dmatrix.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@
#include <utility> // for move
#include <vector> // for vector

#include "../common/categorical.h" // common::IsCat
#include "../common/categorical.h" // for IsCat
#include "../common/hist_util.h" // for HistogramCuts
#include "../tree/param.h" // FIXME(jiamingy): Find a better way to share this parameter.
#include "batch_utils.h" // for RegenGHist
#include "cat_container.h" // for SyncCategories
#include "gradient_index.h" // for GHistIndexMatrix
#include "proxy_dmatrix.h" // for DataIterProxy, DispatchAny
#include "quantile_dmatrix.h" // for GetCutsFromRef
Expand Down Expand Up @@ -50,6 +51,8 @@ IterativeDMatrix::IterativeDMatrix(DataIterHandle iter_handle, DMatrixHandle pro
this->fmat_ctx_ = ctx;
this->batch_ = p;

SyncCategories(&ctx, info_.Cats(), info_.num_row_ == 0);

LOG(INFO) << "Finished constructing the `IterativeDMatrix`: (" << this->Info().num_row_ << ", "
<< this->Info().num_col_ << ", " << this->info_.num_nonzero_ << ").";
}
Expand Down
2 changes: 2 additions & 0 deletions src/data/simple_dmatrix.cc
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,8 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread,
}
info_.num_nonzero_ = data_vec.size();

SyncCategories(&ctx, info_.Cats(), info_.num_row_ == 0);

// Sort the index for row partitioners used by various tree methods.
if (!sparse_page_->IsIndicesSorted(ctx.Threads())) {
sparse_page_->SortIndices(ctx.Threads());
Expand Down
3 changes: 2 additions & 1 deletion src/data/sparse_page_dmatrix.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,9 @@ SparsePageDMatrix::SparsePageDMatrix(DataIterHandle iter_handle, DMatrixHandle p
iter.Reset();

ext_info_.SetInfo(&ctx, true, &this->info_);

fmat_ctx_ = ctx;

SyncCategories(&ctx, info_.Cats(), info_.num_row_ == 0);
}

SparsePageDMatrix::~SparsePageDMatrix() {
Expand Down
Loading
Loading