-# pylint: disable=invalid-name
+# pylint: disable=invalid-name, too-many-lines
 """Utilities for data generation."""
 import multiprocessing
 import os
+import string
 import zipfile
 from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass
     List,
     NamedTuple,
     Optional,
+    Set,
     Tuple,
     Type,
     Union,
 from numpy.random import Generator as RNG
 from scipy import sparse
 
-import xgboost
-from xgboost.data import pandas_pyarrow_mapper
+from ..core import DMatrix, QuantileDMatrix
+from ..data import is_pd_cat_dtype, pandas_pyarrow_mapper
+from ..sklearn import ArrayLike, XGBRanker
+from ..training import train as train_fn
 
 if TYPE_CHECKING:
     from ..compat import DataFrame as DataFrameT
@@ -42,7 +46,7 @@ def np_dtypes(
     n_samples: int, n_features: int
 ) -> Generator[Tuple[np.ndarray, np.ndarray], None, None]:
     """Enumerate all supported dtypes from numpy."""
-    import pandas as pd
+    pd = pytest.importorskip("pandas")
 
     rng = np.random.RandomState(1994)
     # Integer and float.
@@ -99,7 +103,7 @@ def np_dtypes(
 
 def pd_dtypes() -> Generator:
     """Enumerate all supported pandas extension types."""
-    import pandas as pd
+    pd = pytest.importorskip("pandas")
 
     # Integer
     dtypes = [
@@ -162,8 +166,8 @@ def pd_dtypes() -> Generator:
 
 def pd_arrow_dtypes() -> Generator:
     """Pandas DataFrame with pyarrow backed type."""
-    import pandas as pd
-    import pyarrow as pa
+    pd = pytest.importorskip("pandas")
+    pa = pytest.importorskip("pyarrow")
 
     # Integer
     dtypes = pandas_pyarrow_mapper
@@ -225,10 +229,10 @@ def check_inf(rng: RNG) -> None:
     X[5, 2] = np.inf
 
     with pytest.raises(ValueError, match="Input data contains `inf`"):
-        xgboost.QuantileDMatrix(X, y)
+        QuantileDMatrix(X, y)
 
     with pytest.raises(ValueError, match="Input data contains `inf`"):
-        xgboost.DMatrix(X, y)
+        DMatrix(X, y)
 
 
 @memory.cache
@@ -288,8 +292,10 @@ def get_ames_housing() -> Tuple[DataFrameT, np.ndarray]:
     Number of categorical features: 10
     Number of numerical features: 10
     """
-    pytest.importorskip("pandas")
-    import pandas as pd
+    if TYPE_CHECKING:
+        import pandas as pd
+    else:
+        pd = pytest.importorskip("pandas")
 
     rng = np.random.default_rng(1994)
     n_samples = 1460
@@ -664,7 +670,7 @@ def init_rank_score(
     y_train = y_train[sorted_idx]
     qid_train = qid_train[sorted_idx]
 
-    ltr = xgboost.XGBRanker(objective="rank:ndcg", tree_method="hist")
+    ltr = XGBRanker(objective="rank:ndcg", tree_method="hist")
     ltr.fit(X_train, y_train, qid=qid_train)
 
     # Use the original order of the data.
@@ -799,9 +805,7 @@ def sort_ltr_samples(
     return data
 
 
-def run_base_margin_info(
-    DType: Callable, DMatrixT: Type[xgboost.DMatrix], device: str
-) -> None:
+def run_base_margin_info(DType: Callable, DMatrixT: Type[DMatrix], device: str) -> None:
     """Run tests for base margin."""
     rng = np.random.default_rng()
     X = DType(rng.normal(0, 1.0, size=100).astype(np.float32).reshape(50, 2))
@@ -814,7 +818,7 @@ def run_base_margin_info(
     Xy = DMatrixT(X, y, base_margin=base_margin)
     # Error at train, caused by check in predictor.
     with pytest.raises(ValueError, match=r".*base_margin.*"):
-        xgboost.train({"tree_method": "hist", "device": device}, Xy)
+        train_fn({"tree_method": "hist", "device": device}, Xy)
 
     if not hasattr(X, "iloc"):
         # column major matrix
@@ -932,3 +936,102 @@ def random_csc(t_id: int) -> sparse.csc_matrix:
         return arr, y
 
     return csr, y
+
+
+def unique_random_strings(n_strings: int, seed: int) -> List[str]:
+    """Generate n unique strings."""
+    name_len = 8  # hardcoded, should be more than enough
+    unique_strings: Set[str] = set()
+    rng = np.random.default_rng(seed)
+
+    while len(unique_strings) < n_strings:
+        random_str = "".join(
+            rng.choice(list(string.ascii_letters), size=name_len, replace=True)
+        )
+        unique_strings.add(random_str)
+
+    return list(unique_strings)
+
+
+# pylint: disable=too-many-arguments,too-many-locals,too-many-branches
+def make_categorical(
+    n_samples: int,
+    n_features: int,
+    n_categories: int,
+    *,
+    onehot: bool,
+    sparsity: float = 0.0,
+    cat_ratio: float = 1.0,
+    shuffle: bool = False,
+    random_state: int = 1994,
+    cat_dtype: np.typing.DTypeLike = np.int64,
+) -> Tuple[ArrayLike, np.ndarray]:
+    """Generate categorical features for test.
+
+    Parameters
+    ----------
+    n_categories:
+        Number of categories for categorical features.
+    onehot:
+        Should we apply one-hot encoding to the data?
+    sparsity:
+        The ratio of the amount of missing values over the number of all entries.
+    cat_ratio:
+        The ratio of features that are categorical.
+    shuffle:
+        Whether we should shuffle the columns.
+    cat_dtype:
+        The dtype for categorical features, might be string or numeric.
+
+    Returns
+    -------
+    X, y
+    """
+    pd = pytest.importorskip("pandas")
+
+    rng = np.random.RandomState(random_state)
+
+    df = pd.DataFrame()
+    for i in range(n_features):
+        choice = rng.binomial(1, cat_ratio, size=1)[0]
+        if choice == 1:
+            if np.issubdtype(cat_dtype, np.str_):
+                categories = np.array(unique_random_strings(n_categories, i))
+                c = rng.choice(categories, size=n_samples, replace=True)
+            else:
+                categories = np.arange(0, n_categories)
+                c = rng.randint(low=0, high=n_categories, size=n_samples)
+
+            df[str(i)] = pd.Series(c, dtype="category")
+            df[str(i)] = df[str(i)].cat.set_categories(categories)
+        else:
+            num = rng.randint(low=0, high=n_categories, size=n_samples)
+            df[str(i)] = pd.Series(num, dtype=num.dtype)
+
+    label = np.zeros(shape=(n_samples,))
+    for col in df.columns:
+        if isinstance(df[col].dtype, pd.CategoricalDtype):
+            label += df[col].cat.codes
+        else:
+            label += df[col]
+        label += 1
+
+    if sparsity > 0.0:
+        for i in range(n_features):
+            index = rng.randint(
+                low=0, high=n_samples - 1, size=int(n_samples * sparsity)
+            )
+            df.iloc[index, i] = np.nan
+            if is_pd_cat_dtype(df.dtypes.iloc[i]):
+                assert n_categories == np.unique(df.dtypes.iloc[i].categories).size
+
+    assert df.shape[1] == n_features
+    if onehot:
+        df = pd.get_dummies(df)
+
+    if shuffle:
+        columns = list(df.columns)
+        rng.shuffle(columns)
+        df = df[columns]
+
+    return df, label
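
A minimal usage sketch (not part of the diff above) for the newly added make_categorical helper, assuming it is imported from this testing data module; the xgboost.testing.data import path is inferred from the relative imports and may differ in the actual package layout:

import numpy as np

from xgboost import DMatrix, train
from xgboost.testing.data import make_categorical  # assumed import path

# 8 features, roughly 70% of them categorical with 5 string categories each,
# and about 10% of all entries replaced with NaN.
X, y = make_categorical(
    n_samples=256,
    n_features=8,
    n_categories=5,
    onehot=False,
    sparsity=0.1,
    cat_ratio=0.7,
    shuffle=True,
    cat_dtype=np.str_,
)

# enable_categorical lets DMatrix consume the pandas categorical columns directly.
Xy = DMatrix(X, y, enable_categorical=True)
booster = train({"tree_method": "hist"}, Xy, num_boost_round=4)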