10 changes: 8 additions & 2 deletions python-package/xgboost/_data_utils.py
@@ -397,7 +397,7 @@ def array_interface_dict(data: np.ndarray) -> ArrayInf:
     return cast(ArrayInf, ainf)
 
 
-def pd_cats_inf(  # pylint: disable=too-many-locals
+def pd_cat_inf(  # pylint: disable=too-many-locals
     cats: DfCatAccessor, codes: "pd.Series"
 ) -> Tuple[Union[StringArray, ArrayInf], ArrayInf, Tuple]:
     """Get the array interface representation of pandas category accessor."""
@@ -665,12 +665,18 @@ def to_arrow(self) -> ArrowCatList:
         )
         return self._arrow_arrays
 
+    def empty(self) -> bool:
+        """Returns True if there's no category."""
+        return self._handle.value is None
+
     def get_handle(self) -> int:
         """Internal method for retrieving the handle."""
         assert self._handle.value
         return self._handle.value
 
     def __del__(self) -> None:
+        if self._handle.value is None:
+            return
         self._free()
 
 
@@ -718,7 +724,7 @@ class TransformedDf(ABC):
 
     def __init__(self, ref_categories: Optional[Categories], aitfs: AifType) -> None:
         self.ref_categories = ref_categories
-        if ref_categories is not None:
+        if ref_categories is not None and ref_categories.get_handle() is not None:
             aif = ref_categories.get_handle()
             self.ref_aif: Optional[int] = aif
         else:
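The _data_utils.py changes above add an empty() probe and guard __del__ so the C-side free routine is never invoked on a NULL handle. Below is a minimal runnable sketch of that pattern; the _Handle class is hypothetical, standing in for the real Categories wrapper:

import ctypes
from typing import Optional


class _Handle:
    """Hypothetical stand-in for a native-object wrapper, for illustration only."""

    def __init__(self, value: Optional[int] = None) -> None:
        # A NULL ctypes pointer models "no native object was ever created".
        self._handle = ctypes.c_void_p(value)

    def empty(self) -> bool:
        """Returns True if there's no underlying native object."""
        return self._handle.value is None

    def get_handle(self) -> int:
        assert self._handle.value
        return self._handle.value

    def _free(self) -> None:
        # Real code would call into the C API here to release the object.
        self._handle.value = None

    def __del__(self) -> None:
        # Guard: never call the free routine on a NULL handle.
        if self._handle.value is None:
            return
        self._free()


empty = _Handle()
assert empty.empty()          # no native object; __del__ will skip _free()
live = _Handle(value=0x1234)  # pretend address of a native object
assert not live.empty()
assert live.get_handle() == 0x1234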
115 changes: 101 additions & 14 deletions python-package/xgboost/compat.py
@@ -5,11 +5,15 @@
 import logging
 import sys
 import types
-from typing import Any, Sequence, cast
+from typing import TYPE_CHECKING, Any, Sequence, TypeGuard, cast
 
 import numpy as np
 
-from ._typing import _T
+from ._typing import _T, DataType
 
+if TYPE_CHECKING:
+    import pandas as pd
+    import pyarrow as pa
+
 assert sys.version_info[0] == 3, "Python 2 is no longer supported."
 
@@ -31,17 +35,6 @@ def lazy_isinstance(instance: Any, module: str, name: str) -> bool:
     return is_same_module and has_same_name
 
 
-# pandas
-try:
-    from pandas import DataFrame, Series
-
-    PANDAS_INSTALLED = True
-except ImportError:
-    DataFrame = object
-    Series = object
-    PANDAS_INSTALLED = False
-
-
 # sklearn
 try:
     from sklearn import __version__ as _sklearn_version
@@ -139,6 +132,14 @@ def import_pyarrow() -> types.ModuleType:
     return pa
 
 
+@functools.cache
+def import_pandas() -> types.ModuleType:
+    """Import pandas with memory cache."""
+    import pandas as pd
+
+    return pd
+
+
 @functools.cache
 def import_polars() -> types.ModuleType:
     """Import polars with memory cache."""
@@ -147,6 +148,14 @@ def import_polars() -> types.ModuleType:
     return pl
 
 
+@functools.cache
+def is_pandas_available() -> bool:
+    """Check the pandas package is available or not."""
+    if importlib.util.find_spec("pandas") is None:
+        return False
+    return True
+
+
 try:
     import scipy.sparse as scipy_sparse
     from scipy.sparse import csr_matrix as scipy_csr
@@ -155,6 +164,84 @@ def import_polars() -> types.ModuleType:
     scipy_csr = object
 
 
+def _is_polars_lazyframe(data: DataType) -> bool:
+    return lazy_isinstance(data, "polars.lazyframe.frame", "LazyFrame")
+
+
+def _is_polars_series(data: DataType) -> bool:
+    return lazy_isinstance(data, "polars.series.series", "Series")
+
+
+def _is_polars(data: DataType) -> bool:
+    lf = _is_polars_lazyframe(data)
+    df = lazy_isinstance(data, "polars.dataframe.frame", "DataFrame")
+    return lf or df
+
+
+def _is_arrow(data: DataType) -> TypeGuard["pa.Table"]:
+    return lazy_isinstance(data, "pyarrow.lib", "Table")
+
+
+def _is_cudf_df(data: DataType) -> bool:
+    return lazy_isinstance(data, "cudf.core.dataframe", "DataFrame")
+
+
+def _is_cudf_ser(data: DataType) -> bool:
+    return lazy_isinstance(data, "cudf.core.series", "Series")
+
+
+def _is_cudf_pandas(data: DataType) -> bool:
+    """Must go before both pandas and cudf checks."""
+    return (_is_pandas_df(data) or _is_pandas_series(data)) and lazy_isinstance(
+        type(data), "cudf.pandas.fast_slow_proxy", "_FastSlowProxyMeta"
+    )
+
+
+def _is_pandas_df(data: DataType) -> TypeGuard["pd.DataFrame"]:
+    return lazy_isinstance(data, "pandas.core.frame", "DataFrame")
+
+
+def _is_pandas_series(data: DataType) -> TypeGuard["pd.Series"]:
+    return lazy_isinstance(data, "pandas.core.series", "Series")
+
+
+def _is_modin_df(data: DataType) -> bool:
+    return lazy_isinstance(data, "modin.pandas.dataframe", "DataFrame")
+
+
+def _is_modin_series(data: DataType) -> bool:
+    return lazy_isinstance(data, "modin.pandas.series", "Series")
+
+
+def is_dataframe(data: DataType) -> bool:
+    """Whether the input is a dataframe. Currently supported dataframes:
+
+    - pandas
+    - cudf
+    - cudf.pandas
+    - polars
+    - pyarrow
+    - modin
+
+
+    """
+    return any(
+        p(data)
+        for p in (
+            _is_polars,
+            _is_polars_series,
+            _is_arrow,
+            _is_cudf_df,
+            _is_cudf_ser,
+            _is_cudf_pandas,
+            _is_pandas_df,
+            _is_pandas_series,
+            _is_modin_df,
+            _is_modin_series,
+        )
+    )
+
+
 def concat(value: Sequence[_T]) -> _T:  # pylint: disable=too-many-return-statements
     """Concatenate row-wise."""
     if isinstance(value[0], np.ndarray):
@@ -167,7 +254,7 @@ def concat(value: Sequence[_T]) -> _T:  # pylint: disable=too-many-return-statements
     if scipy_sparse and isinstance(value[0], scipy_sparse.spmatrix):
         # other sparse format will be converted to CSR.
         return scipy_sparse.vstack(value, format="csr")
-    if PANDAS_INSTALLED and isinstance(value[0], (DataFrame, Series)):
+    if _is_pandas_df(value[0]) or _is_pandas_series(value[0]):
         from pandas import concat as pd_concat
 
         return pd_concat(value, axis=0)
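The new _is_* predicates in compat.py all lean on lazy_isinstance, which matches a value's type by module path and class name so that optional libraries (cudf, modin, polars) are never imported just to run a check. A sketch of the idea follows; lazy_isinstance_sketch is an illustrative re-implementation, not the real helper, and the demo assumes only that pandas is installed:

from typing import Any


def lazy_isinstance_sketch(instance: Any, module: str, name: str) -> bool:
    """Illustrative re-implementation of the name-based type check."""
    cls = type(instance)
    is_same_module = cls.__module__ == module
    has_same_name = cls.__name__ == name
    return is_same_module and has_same_name


import pandas as pd  # assumes pandas is installed for this demo

df = pd.DataFrame({"a": [1, 2]})
print(lazy_isinstance_sketch(df, "pandas.core.frame", "DataFrame"))    # True
print(lazy_isinstance_sketch(df, "cudf.core.dataframe", "DataFrame"))  # False

Because the check never imports the target module, a predicate such as _is_cudf_df stays free on systems without cudf installed.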
30 changes: 20 additions & 10 deletions python-package/xgboost/core.py
@@ -1,6 +1,7 @@
 # pylint: disable=too-many-arguments, too-many-branches, invalid-name
 # pylint: disable=too-many-lines, too-many-locals
 """Core XGBoost Library."""
+
 import copy
 import ctypes
 import json
@@ -69,15 +70,17 @@
     c_bst_ulong,
 )
 from .compat import (
-    PANDAS_INSTALLED,
-    DataFrame,
     import_polars,
     import_pyarrow,
+    is_pandas_available,
     is_pyarrow_available,
     py_str,
 )
 from .libpath import find_lib_path, is_sphinx_build
 
+if TYPE_CHECKING:
+    from pandas import DataFrame as PdDataFrame
+
 
 class XGBoostError(ValueError):
     """Error thrown by xgboost trainer."""
@@ -782,7 +785,7 @@ def _get_categories(
     cfn: Callable[[ctypes.c_char_p], int],
     feature_names: FeatureNames,
     n_features: int,
-) -> Optional[ArrowCatList]:
+) -> ArrowCatList:
     if not is_pyarrow_available():
         raise ImportError(
             "`pyarrow` is required for exporting categories to arrow arrays."
@@ -797,7 +800,9 @@
 
     ret = ctypes.c_char_p()
     _check_call(cfn(ret))
-    assert ret.value is not None
+    if ret.value is None:
+        results = [(feature_names[i], None) for i in range(n_features)]
+        return results
 
     retstr = ret.value.decode()  # pylint: disable=no-member
     jcats = json.loads(retstr)
@@ -3201,7 +3206,8 @@ def get_score(
         """Get feature importance of each feature.
         For tree model Importance type can be defined as:
 
-        * 'weight': the number of times a feature is used to split the data across all trees.
+        * 'weight': the number of times a feature is used to split the data across all
+          trees.
         * 'gain': the average gain across all splits the feature is used in.
         * 'cover': the average coverage across all splits the feature is used in.
         * 'total_gain': the total gain across all splits the feature is used in.
@@ -3261,7 +3267,7 @@ def get_score(
         return results
 
     # pylint: disable=too-many-statements
-    def trees_to_dataframe(self, fmap: PathLike = "") -> DataFrame:
+    def trees_to_dataframe(self, fmap: PathLike = "") -> "PdDataFrame":
        """Parse a boosted tree model text dump into a pandas DataFrame structure.
 
        This feature is only defined when the decision tree model is chosen as base
@@ -3274,8 +3280,10 @@ def trees_to_dataframe(self, fmap: PathLike = "") -> DataFrame:
            The name of feature map file.
        """
        # pylint: disable=too-many-locals
+        from pandas import DataFrame
+
        fmap = os.fspath(os.path.expanduser(fmap))
-        if not PANDAS_INSTALLED:
+        if not is_pandas_available():
            raise ImportError(
                (
                    "pandas must be available to use this method."
@@ -3426,7 +3434,7 @@ def get_split_value_histogram(
        fmap: PathLike = "",
        bins: Optional[int] = None,
        as_pandas: bool = True,
-    ) -> Union[np.ndarray, DataFrame]:
+    ) -> Union[np.ndarray, "PdDataFrame"]:
        """Get split value histogram of a feature
 
        Parameters
@@ -3482,9 +3490,11 @@
                "Split value historgam doesn't support categorical split."
            )
 
-        if as_pandas and PANDAS_INSTALLED:
+        if as_pandas and is_pandas_available():
+            from pandas import DataFrame
+
            return DataFrame(nph_stacked, columns=["SplitValue", "Count"])
-        if as_pandas and not PANDAS_INSTALLED:
+        if as_pandas and not is_pandas_available():
            warnings.warn(
                "Returning histogram as ndarray"
                " (as_pandas == True, but pandas is not installed).",
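core.py now probes for pandas with is_pandas_available() (backed by importlib.util.find_spec) and defers the actual from pandas import DataFrame into the method bodies, so importing xgboost.core no longer requires pandas at module-import time. A self-contained sketch of the pattern; is_pkg_available and to_frame are illustrative names, not xgboost API:

import functools
import importlib.util
from typing import Any


@functools.cache
def is_pkg_available(name: str) -> bool:
    # find_spec only consults import metadata; it does not import the package.
    return importlib.util.find_spec(name) is not None


def to_frame(rows: Any, as_pandas: bool = True) -> Any:
    if as_pandas and is_pkg_available("pandas"):
        from pandas import DataFrame  # imported only when actually needed

        return DataFrame(rows, columns=["SplitValue", "Count"])
    return rows  # fall back to the raw structure


print(to_frame([[0.5, 3], [1.5, 7]]))

The functools.cache decorator keeps repeated availability checks cheap, which matters when the probe runs once per call rather than once per import.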
12 changes: 8 additions & 4 deletions python-package/xgboost/dask/__init__.py
@@ -94,7 +94,7 @@
 from ..collective import Config as CollConfig
 from ..collective import _Args as CollArgs
 from ..collective import _ArgVals as CollArgsVals
-from ..compat import DataFrame, lazy_isinstance
+from ..compat import _is_cudf_df
 from ..core import (
     Booster,
     DMatrix,
@@ -942,7 +942,7 @@ def _maybe_dataframe(
     # In older versions of dask, the partition is actually a numpy array when input
     # is dataframe.
     index = getattr(data, "index", None)
-    if lazy_isinstance(data, "cudf.core.dataframe", "DataFrame"):
+    if _is_cudf_df(data):
         import cudf
 
         if prediction.size == 0:
@@ -952,10 +952,14 @@
             prediction, columns=columns, dtype=numpy.float32, index=index
         )
     else:
+        import pandas as pd
+
         if prediction.size == 0:
-            return DataFrame({}, columns=columns, dtype=numpy.float32, index=index)
+            return pd.DataFrame(
+                {}, columns=columns, dtype=numpy.float32, index=index
+            )
 
-        prediction = DataFrame(
+        prediction = pd.DataFrame(
             prediction, columns=columns, dtype=numpy.float32, index=index
         )
     return prediction
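In dask/__init__.py, _maybe_dataframe keeps its shape: wrap the raw prediction array in a dataframe of the same flavor as the input partition, importing cudf or pandas only inside the branch that needs it. A rough, hypothetical sketch of that dispatch (wrap_prediction is not the real function):

from typing import Any, List

import numpy


def wrap_prediction(
    data: Any, prediction: numpy.ndarray, columns: List[str]
) -> Any:
    index = getattr(data, "index", None)
    # Cheap name-based check in the spirit of _is_cudf_df; avoids importing cudf.
    cls = type(data)
    if cls.__module__ == "cudf.core.dataframe" and cls.__name__ == "DataFrame":
        import cudf  # only reached when the partition really is cudf

        return cudf.DataFrame(
            prediction, columns=columns, dtype=numpy.float32, index=index
        )
    import pandas as pd

    return pd.DataFrame(
        prediction, columns=columns, dtype=numpy.float32, index=index
    )


# Example with a plain pandas partition:
import pandas as pd

part = pd.DataFrame({"x": [1.0, 2.0]})
print(wrap_prediction(part, numpy.array([[0.1], [0.9]]), ["prediction"]))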