diff --git a/examples/python-guide/dask/ranking.py b/examples/python-guide/dask/ranking.py index 570e74462cfd..0e80cfb9f5a9 100644 --- a/examples/python-guide/dask/ranking.py +++ b/examples/python-guide/dask/ranking.py @@ -30,7 +30,7 @@ # make this array dense because we're splitting across # a sparse boundary to partition the data - X = X.todense() + X = X.toarray() dX = da.from_array( x=X, diff --git a/include/LightGBM/utils/random.h b/include/LightGBM/utils/random.h index 98b1efc24df8..6f89f935b310 100644 --- a/include/LightGBM/utils/random.h +++ b/include/LightGBM/utils/random.h @@ -82,13 +82,10 @@ class Random { ret.push_back(i); } } - } else if (K == 1) { - int v = NextInt(0, N); - ret.push_back(v); } else { std::set sample_set; for (int r = N - K; r < N; ++r) { - int v = NextInt(0, r); + int v = NextInt(0, r + 1); if (!sample_set.insert(v).second) { sample_set.insert(r); } diff --git a/python-package/lightgbm/__init__.py b/python-package/lightgbm/__init__.py index dc2143674be2..8af6253b46a2 100644 --- a/python-package/lightgbm/__init__.py +++ b/python-package/lightgbm/__init__.py @@ -3,7 +3,7 @@ Contributors: https://github.com/microsoft/LightGBM/graphs/contributors. 
""" -import os +from pathlib import Path from .basic import Booster, Dataset, Sequence, register_logger from .callback import early_stopping, print_evaluation, record_evaluation, reset_parameter @@ -23,11 +23,9 @@ pass -dir_path = os.path.dirname(os.path.realpath(__file__)) - -if os.path.isfile(os.path.join(dir_path, 'VERSION.txt')): - with open(os.path.join(dir_path, 'VERSION.txt')) as version_file: - __version__ = version_file.read().strip() +_version_path = Path(__file__).absolute().parent / 'VERSION.txt' +if _version_path.is_file(): + __version__ = _version_path.read_text(encoding='utf-8').strip() __all__ = ['Dataset', 'Booster', 'CVBooster', 'Sequence', 'register_logger', diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index 1ad3c7da0a3b..c1a6ec02f16d 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -3,12 +3,14 @@ import abc import ctypes import json -import os import warnings from collections import OrderedDict from copy import deepcopy from functools import wraps from logging import Logger +from os import SEEK_END +from os.path import getsize +from pathlib import Path from tempfile import NamedTemporaryFile from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union @@ -243,7 +245,7 @@ def to_string(x): else: return str(x) pairs.append(f"{key}={','.join(map(to_string, val))}") - elif isinstance(val, (str, NUMERIC_TYPES)) or is_numeric(val): + elif isinstance(val, (str, Path, NUMERIC_TYPES)) or is_numeric(val): pairs.append(f"{key}={val}") elif val is not None: raise TypeError(f'Unknown type of parameter:{key}, got:{type(val).__name__}') @@ -251,23 +253,17 @@ def to_string(x): class _TempFile: + """Proxy class to workaround errors on Windows.""" + def __enter__(self): with NamedTemporaryFile(prefix="lightgbm_tmp_", delete=True) as f: self.name = f.name + self.path = Path(self.name) return self def __exit__(self, exc_type, exc_val, exc_tb): - if os.path.isfile(self.name): - 
os.remove(self.name) - - def readlines(self): - with open(self.name, "r+") as f: - ret = f.readlines() - return ret - - def writelines(self, lines): - with open(self.name, "w+") as f: - f.writelines(lines) + if self.path.is_file(): + self.path.unlink() class LightGBMError(Exception): @@ -584,12 +580,12 @@ def _load_pandas_categorical(file_name=None, model_str=None): pandas_key = 'pandas_categorical:' offset = -len(pandas_key) if file_name is not None: - max_offset = -os.path.getsize(file_name) + max_offset = -getsize(file_name) with open(file_name, 'rb') as f: while True: if offset < max_offset: offset = max_offset - f.seek(offset, os.SEEK_END) + f.seek(offset, SEEK_END) lines = f.readlines() if len(lines) >= 2: break @@ -685,7 +681,7 @@ def __init__(self, model_file=None, booster_handle=None, pred_parameter=None): Parameters ---------- - model_file : string or None, optional (default=None) + model_file : string, pathlib.Path or None, optional (default=None) Path to the model file. booster_handle : object or None, optional (default=None) Handle of Booster. @@ -698,7 +694,7 @@ def __init__(self, model_file=None, booster_handle=None, pred_parameter=None): """Prediction task""" out_num_iterations = ctypes.c_int(0) _safe_call(_LIB.LGBM_BoosterCreateFromModelfile( - c_str(model_file), + c_str(str(model_file)), ctypes.byref(out_num_iterations), ctypes.byref(self.handle))) out_num_class = ctypes.c_int(0) @@ -743,9 +739,9 @@ def predict(self, data, start_iteration=0, num_iteration=-1, Parameters ---------- - data : string, numpy array, pandas DataFrame, H2O DataTable's Frame or scipy.sparse + data : string, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame or scipy.sparse Data source for prediction. - When data type is string, it represents the path of txt file. + When data type is string or pathlib.Path, it represents the path of txt file. start_iteration : int, optional (default=0) Start index of the iteration to predict. 
num_iteration : int, optional (default=-1) @@ -780,21 +776,19 @@ def predict(self, data, start_iteration=0, num_iteration=-1, predict_type = C_API_PREDICT_CONTRIB int_data_has_header = 1 if data_has_header else 0 - if isinstance(data, str): + if isinstance(data, (str, Path)): with _TempFile() as f: _safe_call(_LIB.LGBM_BoosterPredictForFile( self.handle, - c_str(data), + c_str(str(data)), ctypes.c_int(int_data_has_header), ctypes.c_int(predict_type), ctypes.c_int(start_iteration), ctypes.c_int(num_iteration), c_str(self.pred_parameter), c_str(f.name))) - lines = f.readlines() - nrow = len(lines) - preds = [float(token) for line in lines for token in line.split('\t')] - preds = np.array(preds, dtype=np.float64, copy=False) + preds = np.loadtxt(f.name, dtype=np.float64) + nrow = preds.shape[0] elif isinstance(data, scipy.sparse.csr_matrix): preds, nrow = self.__pred_for_csr(data, start_iteration, num_iteration, predict_type) elif isinstance(data, scipy.sparse.csc_matrix): @@ -829,9 +823,9 @@ def predict(self, data, start_iteration=0, num_iteration=-1, def __get_num_preds(self, start_iteration, num_iteration, nrow, predict_type): """Get size of prediction result.""" if nrow > MAX_INT32: - raise LightGBMError('LightGBM cannot perform prediction for data' + raise LightGBMError('LightGBM cannot perform prediction for data ' f'with number of rows greater than MAX_INT32 ({MAX_INT32}).\n' - 'You can split your data into chunks' + 'You can split your data into chunks ' 'and then concatenate predictions for them') n_preds = ctypes.c_int64(0) _safe_call(_LIB.LGBM_BoosterCalcNumPredict( @@ -1133,9 +1127,9 @@ def __init__(self, data, label=None, reference=None, Parameters ---------- - data : string, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, Sequence, list of Sequences or list of numpy arrays + data : string, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, Sequence, list of Sequences or list of numpy arrays Data source 
of Dataset. - If string, it represents the path to txt file. + If string or pathlib.Path, it represents the path to txt file. label : list, numpy 1-D array, pandas Series / one-column DataFrame or None, optional (default=None) Label of the data. reference : Dataset or None, optional (default=None) @@ -1384,7 +1378,7 @@ def _free_handle(self): def _set_init_score_by_predictor(self, predictor, data, used_indices=None): data_has_header = False - if isinstance(data, str): + if isinstance(data, (str, Path)): # check data has header or not data_has_header = any(self.params.get(alias, False) for alias in _ConfigAliases.get("header")) num_data = self.num_data() @@ -1395,7 +1389,7 @@ def _set_init_score_by_predictor(self, predictor, data, used_indices=None): is_reshape=False) if used_indices is not None: assert not self.need_slice - if isinstance(data, str): + if isinstance(data, (str, Path)): sub_init_score = np.empty(num_data * predictor.num_class, dtype=np.float32) assert num_data == len(used_indices) for i in range(len(used_indices)): @@ -1472,10 +1466,10 @@ def _lazy_init(self, data, label=None, reference=None, elif reference is not None: raise TypeError('Reference dataset should be None or dataset instance') # start construct data - if isinstance(data, str): + if isinstance(data, (str, Path)): self.handle = ctypes.c_void_p() _safe_call(_LIB.LGBM_DatasetCreateFromFile( - c_str(data), + c_str(str(data)), c_str(params_str), ref_dataset, ctypes.byref(self.handle))) @@ -1775,9 +1769,9 @@ def create_valid(self, data, label=None, weight=None, group=None, Parameters ---------- - data : string, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, Sequence, list of Sequences or list of numpy arrays + data : string, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, Sequence, list of Sequences or list of numpy arrays Data source of Dataset. - If string, it represents the path to txt file. 
+ If string or pathlib.Path, it represents the path to txt file. label : list, numpy 1-D array, pandas Series / one-column DataFrame or None, optional (default=None) Label of the data. weight : list, numpy 1-D array, pandas Series or None, optional (default=None) @@ -1842,7 +1836,7 @@ def save_binary(self, filename): Parameters ---------- - filename : string + filename : string or pathlib.Path Name of the output file. Returns @@ -1852,7 +1846,7 @@ def save_binary(self, filename): """ _safe_call(_LIB.LGBM_DatasetSaveBinary( self.construct().handle, - c_str(filename))) + c_str(str(filename)))) return self def _update_params(self, params): @@ -2242,7 +2236,7 @@ def get_data(self): Returns ------- - data : string, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, list of numpy arrays or None + data : string, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, list of numpy arrays or None Raw data used in the Dataset construction. """ if self.handle is None: @@ -2442,7 +2436,7 @@ def _dump_text(self, filename): Parameters ---------- - filename : string + filename : string or pathlib.Path Name of the output file. Returns @@ -2452,7 +2446,7 @@ def _dump_text(self, filename): """ _safe_call(_LIB.LGBM_DatasetDumpText( self.construct().handle, - c_str(filename))) + c_str(str(filename)))) return self @@ -2468,7 +2462,7 @@ def __init__(self, params=None, train_set=None, model_file=None, model_str=None, Parameters for Booster. train_set : Dataset or None, optional (default=None) Training dataset. - model_file : string or None, optional (default=None) + model_file : string, pathlib.Path or None, optional (default=None) Path to the model file. model_str : string or None, optional (default=None) Model will be loaded from this string. 
@@ -2561,7 +2555,7 @@ def __init__(self, params=None, train_set=None, model_file=None, model_str=None, out_num_iterations = ctypes.c_int(0) self.handle = ctypes.c_void_p() _safe_call(_LIB.LGBM_BoosterCreateFromModelfile( - c_str(model_file), + c_str(str(model_file)), ctypes.byref(out_num_iterations), ctypes.byref(self.handle))) out_num_class = ctypes.c_int(0) @@ -3200,7 +3194,7 @@ def save_model(self, filename, num_iteration=None, start_iteration=0, importance Parameters ---------- - filename : string + filename : string or pathlib.Path Filename to save Booster. num_iteration : int or None, optional (default=None) Index of the iteration that should be saved. @@ -3226,7 +3220,7 @@ def save_model(self, filename, num_iteration=None, start_iteration=0, importance ctypes.c_int(start_iteration), ctypes.c_int(num_iteration), ctypes.c_int(importance_type_int), - c_str(filename))) + c_str(str(filename)))) _dump_pandas_categorical(self.pandas_categorical, filename) return self @@ -3400,9 +3394,9 @@ def predict(self, data, start_iteration=0, num_iteration=None, Parameters ---------- - data : string, numpy array, pandas DataFrame, H2O DataTable's Frame or scipy.sparse + data : string, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame or scipy.sparse Data source for prediction. - If string, it represents the path to txt file. + If string or pathlib.Path, it represents the path to txt file. start_iteration : int, optional (default=0) Start index of the iteration to predict. If <= 0, starts from the first iteration. @@ -3455,9 +3449,9 @@ def refit(self, data, label, decay_rate=0.9, **kwargs): Parameters ---------- - data : string, numpy array, pandas DataFrame, H2O DataTable's Frame or scipy.sparse + data : string, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame or scipy.sparse Data source for refit. - If string, it represents the path to txt file. + If string or pathlib.Path, it represents the path to txt file. 
label : list, numpy 1-D array or pandas Series / one-column DataFrame Label for refit. decay_rate : float, optional (default=0.9) diff --git a/python-package/lightgbm/compat.py b/python-package/lightgbm/compat.py index 068a667b2603..52726622f076 100644 --- a/python-package/lightgbm/compat.py +++ b/python-package/lightgbm/compat.py @@ -125,6 +125,8 @@ class _LGBMRegressorBase: # type: ignore try: from dask import delayed from dask.array import Array as dask_Array + from dask.array import from_delayed as dask_array_from_delayed + from dask.bag import from_delayed as dask_bag_from_delayed from dask.dataframe import DataFrame as dask_DataFrame from dask.dataframe import Series as dask_Series from dask.distributed import Client, default_client, wait @@ -132,6 +134,8 @@ class _LGBMRegressorBase: # type: ignore except ImportError: DASK_INSTALLED = False + dask_array_from_delayed = None + dask_bag_from_delayed = None delayed = None default_client = None wait = None diff --git a/python-package/lightgbm/dask.py b/python-package/lightgbm/dask.py index 39919d96ad58..107cf218d861 100644 --- a/python-package/lightgbm/dask.py +++ b/python-package/lightgbm/dask.py @@ -10,6 +10,7 @@ from collections import defaultdict from copy import deepcopy from enum import Enum, auto +from functools import partial from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union from urllib.parse import urlparse @@ -18,7 +19,8 @@ from .basic import _LIB, LightGBMError, _choose_param_value, _ConfigAliases, _log_info, _log_warning, _safe_call from .compat import (DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED, Client, LGBMNotFittedError, concat, - dask_Array, dask_DataFrame, dask_Series, default_client, delayed, pd_DataFrame, pd_Series, wait) + dask_Array, dask_array_from_delayed, dask_bag_from_delayed, dask_DataFrame, dask_Series, + default_client, delayed, pd_DataFrame, pd_Series, wait) from .sklearn import (LGBMClassifier, LGBMModel, LGBMRanker, LGBMRegressor, 
_lgbmmodel_doc_custom_eval_note, _lgbmmodel_doc_fit, _lgbmmodel_doc_predict) @@ -842,7 +844,7 @@ def _predict( pred_contrib: bool = False, dtype: _PredictionDtype = np.float32, **kwargs: Any -) -> dask_Array: +) -> Union[dask_Array, List[dask_Array]]: """Inner predict routine. Parameters @@ -870,7 +872,7 @@ def _predict( The predicted values. X_leaves : Dask Array of shape = [n_samples, n_trees] or shape = [n_samples, n_trees * n_classes] If ``pred_leaf=True``, the predicted leaf of every tree for each sample. - X_SHAP_values : Dask Array of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes] + X_SHAP_values : Dask Array of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes] or (if multi-class and using sparse inputs) a list of ``n_classes`` Dask Arrays of shape = [n_samples, n_features + 1] If ``pred_contrib=True``, the feature contributions for each sample. """ if not all((DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED)): @@ -886,6 +888,74 @@ def _predict( **kwargs ).values elif isinstance(data, dask_Array): + # for multi-class classification with sparse matrices, pred_contrib predictions + # are returned as a list of sparse matrices (one per class) + num_classes = model._n_classes or -1 + + if ( + num_classes > 2 + and pred_contrib + and isinstance(data._meta, ss.spmatrix) + ): + + predict_function = partial( + _predict_part, + model=model, + raw_score=False, + pred_proba=pred_proba, + pred_leaf=False, + pred_contrib=True, + **kwargs + ) + + delayed_chunks = data.to_delayed() + bag = dask_bag_from_delayed(delayed_chunks[:, 0]) + + @delayed + def _extract(items: List[Any], i: int) -> Any: + return items[i] + + preds = bag.map_partitions(predict_function) + + # pred_contrib output will have one column per feature, + # plus one more for the base value + num_cols = model.n_features_ + 1 + + nrows_per_chunk = data.chunks[0] + out = [[] for _ in range(num_classes)] + + # need to tell Dask 
the expected type and shape of individual preds + pred_meta = data._meta + + for j, partition in enumerate(preds.to_delayed()): + for i in range(num_classes): + part = dask_array_from_delayed( + value=_extract(partition, i), + shape=(nrows_per_chunk[j], num_cols), + meta=pred_meta + ) + out[i].append(part) + + # by default, dask.array.concatenate() concatenates sparse arrays into a COO matrix + # the code below is used instead to ensure that the sparse type is preserved during concatenation + if isinstance(pred_meta, ss.csr_matrix): + concat_fn = partial(ss.vstack, format='csr') + elif isinstance(pred_meta, ss.csc_matrix): + concat_fn = partial(ss.vstack, format='csc') + else: + concat_fn = ss.vstack + + # At this point, `out` is a list of lists of delayeds (each of which points to a matrix). + # Concatenate them to return a list of Dask Arrays. + for i in range(num_classes): + out[i] = dask_array_from_delayed( + value=delayed(concat_fn)(out[i]), + shape=(data.shape[0], num_cols), + meta=pred_meta + ) + + return out + return data.map_blocks( _predict_part, model=model, @@ -1140,7 +1210,7 @@ def predict(self, X: _DaskMatrixLike, **kwargs: Any) -> dask_Array: output_name="predicted_result", predicted_result_shape="Dask Array of shape = [n_samples] or shape = [n_samples, n_classes]", X_leaves_shape="Dask Array of shape = [n_samples, n_trees] or shape = [n_samples, n_trees * n_classes]", - X_SHAP_values_shape="Dask Array of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes]" + X_SHAP_values_shape="Dask Array of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes] or (if multi-class and using sparse inputs) a list of ``n_classes`` Dask Arrays of shape = [n_samples, n_features + 1]" ) def predict_proba(self, X: _DaskMatrixLike, **kwargs: Any) -> dask_Array: @@ -1158,7 +1228,7 @@ def predict_proba(self, X: _DaskMatrixLike, **kwargs: Any) -> dask_Array: output_name="predicted_probability", 
predicted_result_shape="Dask Array of shape = [n_samples] or shape = [n_samples, n_classes]", X_leaves_shape="Dask Array of shape = [n_samples, n_trees] or shape = [n_samples, n_trees * n_classes]", - X_SHAP_values_shape="Dask Array of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes]" + X_SHAP_values_shape="Dask Array of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes] or (if multi-class and using sparse inputs) a list of ``n_classes`` Dask Arrays of shape = [n_samples, n_features + 1]" ) def to_local(self) -> LGBMClassifier: diff --git a/python-package/lightgbm/engine.py b/python-package/lightgbm/engine.py index fade2d925c2f..fc77ff7a9d4a 100644 --- a/python-package/lightgbm/engine.py +++ b/python-package/lightgbm/engine.py @@ -3,6 +3,7 @@ import collections import copy from operator import attrgetter +from pathlib import Path import numpy as np @@ -76,7 +77,7 @@ def train(params, train_set, num_boost_round=100, If you want to get i-th row preds in j-th class, the access way is preds[j * num_data + i]. To ignore the default metric corresponding to the used objective, set the ``metric`` parameter to the string ``"None"`` in ``params``. - init_model : string, Booster or None, optional (default=None) + init_model : string, pathlib.Path, Booster or None, optional (default=None) Filename of LightGBM model or Booster instance used for continue training. feature_name : list of strings or 'auto', optional (default="auto") Feature names. 
@@ -161,7 +162,7 @@ def train(params, train_set, num_boost_round=100, if num_boost_round <= 0: raise ValueError("num_boost_round should be greater than zero.") - if isinstance(init_model, str): + if isinstance(init_model, (str, Path)): predictor = _InnerPredictor(model_file=init_model, pred_parameter=params) elif isinstance(init_model, Booster): predictor = init_model._to_predictor(dict(init_model.params, **params)) @@ -470,7 +471,7 @@ def cv(params, train_set, num_boost_round=100, If you want to get i-th row preds in j-th class, the access way is preds[j * num_data + i]. To ignore the default metric corresponding to the used objective, set ``metrics`` to the string ``"None"``. - init_model : string, Booster or None, optional (default=None) + init_model : string, pathlib.Path, Booster or None, optional (default=None) Filename of LightGBM model or Booster instance used for continue training. feature_name : list of strings or 'auto', optional (default="auto") Feature names. @@ -545,7 +546,7 @@ def cv(params, train_set, num_boost_round=100, if num_boost_round <= 0: raise ValueError("num_boost_round should be greater than zero.") - if isinstance(init_model, str): + if isinstance(init_model, (str, Path)): predictor = _InnerPredictor(model_file=init_model, pred_parameter=params) elif isinstance(init_model, Booster): predictor = init_model._to_predictor(dict(init_model.params, **params)) diff --git a/python-package/lightgbm/libpath.py b/python-package/lightgbm/libpath.py index 653379288706..7ad1c65e1c59 100644 --- a/python-package/lightgbm/libpath.py +++ b/python-package/lightgbm/libpath.py @@ -1,6 +1,7 @@ # coding: utf-8 """Find the path to LightGBM dynamic library files.""" -import os +from os import environ +from pathlib import Path from platform import system from typing import List @@ -13,27 +14,26 @@ def find_lib_path() -> List[str]: lib_path: list of strings List of all found library paths to LightGBM. 
""" - if os.environ.get('LIGHTGBM_BUILD_DOC', False): + if environ.get('LIGHTGBM_BUILD_DOC', False): # we don't need lib_lightgbm while building docs return [] - curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) + curr_path = Path(__file__).absolute().parent dll_path = [curr_path, - os.path.join(curr_path, '../../'), - os.path.join(curr_path, 'compile'), - os.path.join(curr_path, '../compile'), - os.path.join(curr_path, '../../lib/')] + curr_path.parents[1], + curr_path / 'compile', + curr_path.parent / 'compile', + curr_path.parents[1] / 'lib'] if system() in ('Windows', 'Microsoft'): - dll_path.append(os.path.join(curr_path, '../compile/Release/')) - dll_path.append(os.path.join(curr_path, '../compile/windows/x64/DLL/')) - dll_path.append(os.path.join(curr_path, '../../Release/')) - dll_path.append(os.path.join(curr_path, '../../windows/x64/DLL/')) - dll_path = [os.path.join(p, 'lib_lightgbm.dll') for p in dll_path] + dll_path.append(curr_path.parent / 'compile' / 'Release') + dll_path.append(curr_path.parent / 'compile' / 'windows' / 'x64' / 'DLL') + dll_path.append(curr_path.parents[1] / 'Release') + dll_path.append(curr_path.parents[1] / 'windows' / 'x64' / 'DLL') + dll_path = [p / 'lib_lightgbm.dll' for p in dll_path] else: - dll_path = [os.path.join(p, 'lib_lightgbm.so') for p in dll_path] - lib_path = [p for p in dll_path if os.path.exists(p) and os.path.isfile(p)] + dll_path = [p / 'lib_lightgbm.so' for p in dll_path] + lib_path = [str(p) for p in dll_path if p.is_file()] if not lib_path: - dll_path = [os.path.realpath(p) for p in dll_path] - new_line = "\n" - raise Exception(f'Cannot find lightgbm library file in following paths:{new_line}{new_line.join(dll_path)}') + dll_path_joined = '\n'.join(map(str, dll_path)) + raise Exception(f'Cannot find lightgbm library file in following paths:\n{dll_path_joined}') return lib_path diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py index 
9e4aee82af4b..565ed8c10c9d 100644 --- a/python-package/lightgbm/sklearn.py +++ b/python-package/lightgbm/sklearn.py @@ -256,7 +256,7 @@ def __call__(self, preds, dataset): callbacks : list of callback functions or None, optional (default=None) List of callback functions that are applied at each iteration. See Callbacks in Python API for more information. - init_model : string, Booster, LGBMModel or None, optional (default=None) + init_model : string, pathlib.Path, Booster, LGBMModel or None, optional (default=None) Filename of LightGBM model, Booster instance or LGBMModel instance used for continue training. Returns diff --git a/python-package/setup.py b/python-package/setup.py index b82d4a9e0a94..45bc340c8540 100644 --- a/python-package/setup.py +++ b/python-package/setup.py @@ -4,6 +4,7 @@ import struct import subprocess import sys +from os import chdir from pathlib import Path from platform import system from shutil import copyfile, copytree, rmtree @@ -79,17 +80,13 @@ def copy_files_helper(folder_name: Union[str, Path]) -> None: CURRENT_DIR / "compile" / "cmake" / "IntegratedOpenCL.cmake") -def clear_path(path: str) -> None: - import os - path = str(path) - if os.path.isdir(path): - contents = os.listdir(path) - for file_name in contents: - file_path = os.path.join(path, file_name) - if os.path.isfile(file_path): - os.remove(file_path) +def clear_path(path: Path) -> None: + if path.is_dir(): + for file_name in path.iterdir(): + if file_name.is_dir(): + rmtree(file_name) else: - rmtree(file_path) + file_name.unlink() def silent_call(cmd: List[str], raise_error: bool = False, error_msg: str = '') -> int: @@ -122,6 +119,8 @@ def compile_cpp( build_dir = CURRENT_DIR / "build_cpp" rmtree(build_dir, ignore_errors=True) build_dir.mkdir(parents=True) + original_dir = Path.cwd() + chdir(build_dir) logger.info("Starting to compile the library.") @@ -199,6 +198,7 @@ def compile_cpp( silent_call(cmake_cmd, raise_error=True, error_msg='Please install CMake and all required 
dependencies first') silent_call(["make", "_lightgbm", f"-I{build_dir}", "-j4"], raise_error=True, error_msg='An error has occurred while building lightgbm library file') + chdir(original_dir) class CustomInstallLib(install_lib): diff --git a/tests/distributed/_test_distributed.py b/tests/distributed/_test_distributed.py index 64ffa2b22399..9e1dd8e4f5a4 100644 --- a/tests/distributed/_test_distributed.py +++ b/tests/distributed/_test_distributed.py @@ -1,9 +1,9 @@ import copy import io -import os import socket import subprocess from concurrent.futures import ThreadPoolExecutor +from pathlib import Path from typing import Any, Dict, Generator, List import numpy as np @@ -11,7 +11,7 @@ from sklearn.datasets import make_blobs, make_regression from sklearn.metrics import accuracy_score -TESTS_DIR = os.path.abspath(os.path.dirname(__file__)) +TESTS_DIR = Path(__file__).absolute().parent @pytest.fixture(scope='module') @@ -57,7 +57,7 @@ class DistributedMockup: default_train_config = { 'task': 'train', 'pre_partition': True, - 'machine_list_file': os.path.join(TESTS_DIR, 'mlist.txt'), + 'machine_list_file': TESTS_DIR / 'mlist.txt', 'tree_learner': 'data', 'force_row_wise': True, 'verbose': 0, @@ -68,9 +68,9 @@ class DistributedMockup: default_predict_config = { 'task': 'predict', - 'data': os.path.join(TESTS_DIR, 'train.txt'), - 'input_model': os.path.join(TESTS_DIR, 'model0.txt'), - 'output_result': os.path.join(TESTS_DIR, 'predictions.txt'), + 'data': TESTS_DIR / 'train.txt', + 'input_model': TESTS_DIR / 'model0.txt', + 'output_result': TESTS_DIR / 'predictions.txt', } def __init__(self, executable: str): @@ -78,7 +78,7 @@ def __init__(self, executable: str): def worker_train(self, i: int) -> subprocess.CompletedProcess: """Start the training process on the `i`-th worker.""" - config_path = os.path.join(TESTS_DIR, f'train{i}.conf') + config_path = TESTS_DIR / f'train{i}.conf' cmd = [self.executable, f'config={config_path}'] return subprocess.run(cmd) @@ -95,16 +95,16 
@@ def _set_ports(self) -> None: if i == max_tries: raise RuntimeError('Unable to find non-colliding ports.') self.listen_ports = list(ports) - with open(os.path.join(TESTS_DIR, 'mlist.txt'), 'wt') as f: + with open(TESTS_DIR / 'mlist.txt', 'wt') as f: for port in self.listen_ports: f.write(f'127.0.0.1 {port}\n') def _write_data(self, partitions: List[np.ndarray]) -> None: """Write all training data as train.txt and each training partition as train{i}.txt.""" all_data = np.vstack(partitions) - np.savetxt(os.path.join(TESTS_DIR, 'train.txt'), all_data, delimiter=',') + np.savetxt(str(TESTS_DIR / 'train.txt'), all_data, delimiter=',') for i, partition in enumerate(partitions): - np.savetxt(os.path.join(TESTS_DIR, f'train{i}.txt'), partition, delimiter=',') + np.savetxt(str(TESTS_DIR / f'train{i}.txt'), partition, delimiter=',') def fit(self, partitions: List[np.ndarray], train_config: Dict = {}) -> None: """Run the distributed training process on a single machine. @@ -142,14 +142,14 @@ def predict(self, predict_config: Dict[str, Any] = {}) -> np.ndarray: """ self.predict_config = copy.deepcopy(self.default_predict_config) self.predict_config.update(predict_config) - config_path = os.path.join(TESTS_DIR, 'predict.conf') + config_path = TESTS_DIR / 'predict.conf' with open(config_path, 'wt') as file: _write_dict(self.predict_config, file) cmd = [self.executable, f'config={config_path}'] result = subprocess.run(cmd) if result.returncode != 0: - raise RuntimeError - y_pred = np.loadtxt(os.path.join(TESTS_DIR, 'predictions.txt')) + raise RuntimeError('Error in prediction') + y_pred = np.loadtxt(str(TESTS_DIR / 'predictions.txt')) return y_pred def write_train_config(self, i: int) -> None: @@ -158,9 +158,9 @@ def write_train_config(self, i: int) -> None: Each worker gets a different port and piece of the data, the rest are the model parameters contained in `self.config`. 
""" - with open(os.path.join(TESTS_DIR, f'train{i}.conf'), 'wt') as file: - output_model = os.path.join(TESTS_DIR, f'model{i}.txt') - data = os.path.join(TESTS_DIR, f'train{i}.txt') + with open(TESTS_DIR / f'train{i}.conf', 'wt') as file: + output_model = TESTS_DIR / f'model{i}.txt' + data = TESTS_DIR / f'train{i}.txt' file.write(f'output_model = {output_model}\n') file.write(f'local_listen_port = {self.listen_ports[i]}\n') file.write(f'data = {data}\n') diff --git a/tests/distributed/conftest.py b/tests/distributed/conftest.py index d5db71c69513..9df13e8202d6 100644 --- a/tests/distributed/conftest.py +++ b/tests/distributed/conftest.py @@ -1,8 +1,7 @@ -import os +from pathlib import Path -TESTS_DIR = os.path.dirname(__file__) -default_exec_file = os.path.abspath(os.path.join(TESTS_DIR, '..', '..', 'lightgbm')) +default_exec_file = Path(__file__).absolute().parents[2] / 'lightgbm' def pytest_addoption(parser): - parser.addoption('--execfile', action='store', default=default_exec_file) + parser.addoption('--execfile', action='store', default=str(default_exec_file)) diff --git a/tests/python_package_test/test_basic.py b/tests/python_package_test/test_basic.py index c32b2ad091d4..445205528cad 100644 --- a/tests/python_package_test/test_basic.py +++ b/tests/python_package_test/test_basic.py @@ -49,8 +49,8 @@ def test_basic(tmp_path): assert bst.lower_bound() == pytest.approx(-2.9040190126976606) assert bst.upper_bound() == pytest.approx(3.3182142872462883) - tname = str(tmp_path / "svm_light.dat") - model_file = str(tmp_path / "model.txt") + tname = tmp_path / "svm_light.dat" + model_file = tmp_path / "model.txt" bst.save_model(model_file) pred_from_matr = bst.predict(X_test) @@ -153,8 +153,8 @@ def test_sequence(tmpdir, sample_count, batch_size, include_0_and_nan, num_seq): X = data[:, :-1] Y = data[:, -1] - npy_bin_fname = str(tmpdir / 'data_from_npy.bin') - seq_bin_fname = str(tmpdir / 'data_from_seq.bin') + npy_bin_fname = tmpdir / 'data_from_npy.bin' + 
seq_bin_fname = tmpdir / 'data_from_seq.bin' # Create dataset from numpy array directly. ds = lgb.Dataset(X, label=Y, params=params) @@ -175,9 +175,9 @@ def test_sequence(tmpdir, sample_count, batch_size, include_0_and_nan, num_seq): valid_X = valid_data[:, :-1] valid_Y = valid_data[:, -1] - valid_npy_bin_fname = str(tmpdir / 'valid_data_from_npy.bin') - valid_seq_bin_fname = str(tmpdir / 'valid_data_from_seq.bin') - valid_seq2_bin_fname = str(tmpdir / 'valid_data_from_seq2.bin') + valid_npy_bin_fname = tmpdir / 'valid_data_from_npy.bin' + valid_seq_bin_fname = tmpdir / 'valid_data_from_seq.bin' + valid_seq2_bin_fname = tmpdir / 'valid_data_from_seq2.bin' valid_ds = lgb.Dataset(valid_X, label=valid_Y, params=params, reference=ds) valid_ds.save_binary(valid_npy_bin_fname) @@ -268,10 +268,10 @@ def test_add_features_equal_data_on_alternating_used_unused(tmp_path): d1 = lgb.Dataset(X[:, :j], feature_name=names[:j]).construct() d2 = lgb.Dataset(X[:, j:], feature_name=names[j:]).construct() d1.add_features_from(d2) - d1name = str(tmp_path / "d1.txt") + d1name = tmp_path / "d1.txt" d1._dump_text(d1name) d = lgb.Dataset(X, feature_name=names).construct() - dname = str(tmp_path / "d.txt") + dname = tmp_path / "d.txt" d._dump_text(dname) with open(d1name, 'rt') as d1f: d1txt = d1f.read() @@ -297,8 +297,8 @@ def test_add_features_same_booster_behaviour(tmp_path): for k in range(10): b.update() b1.update() - dname = str(tmp_path / "d.txt") - d1name = str(tmp_path / "d1.txt") + dname = tmp_path / "d.txt" + d1name = tmp_path / "d1.txt" b1.save_model(d1name) b.save_model(dname) with open(dname, 'rt') as df: @@ -352,7 +352,7 @@ def test_cegb_affects_behavior(tmp_path): base = lgb.Booster(train_set=ds) for k in range(10): base.update() - basename = str(tmp_path / "basename.txt") + basename = tmp_path / "basename.txt" base.save_model(basename) with open(basename, 'rt') as f: basetxt = f.read() @@ -364,7 +364,7 @@ def test_cegb_affects_behavior(tmp_path): booster = 
lgb.Booster(train_set=ds, params=case) for k in range(10): booster.update() - casename = str(tmp_path / "casename.txt") + casename = tmp_path / "casename.txt" booster.save_model(casename) with open(casename, 'rt') as f: casetxt = f.read() @@ -391,13 +391,13 @@ def test_cegb_scaling_equalities(tmp_path): for k in range(10): booster1.update() booster2.update() - p1name = str(tmp_path / "p1.txt") + p1name = tmp_path / "p1.txt" # Reset booster1's parameters to p2, so the parameter section of the file matches. booster1.reset_parameter(p2) booster1.save_model(p1name) with open(p1name, 'rt') as f: p1txt = f.read() - p2name = str(tmp_path / "p2.txt") + p2name = tmp_path / "p2.txt" booster2.save_model(p2name) with open(p2name, 'rt') as f: p2txt = f.read() diff --git a/tests/python_package_test/test_consistency.py b/tests/python_package_test/test_consistency.py index 4f9bc89e7fc7..a9cb8436a847 100644 --- a/tests/python_package_test/test_consistency.py +++ b/tests/python_package_test/test_consistency.py @@ -24,7 +24,7 @@ def __init__(self, directory, prefix, config_file='train.conf'): self.params[key] = value if key != 'num_trees' else int(value) def load_dataset(self, suffix, is_sparse=False): - filename = self.path(suffix) + filename = str(self.path(suffix)) if is_sparse: X, Y = load_svmlight_file(filename, dtype=np.float64, zero_based=True) return X, Y, filename @@ -62,7 +62,7 @@ def file_load_check(self, lgb_train, name): assert a == b, f def path(self, suffix): - return str(self.directory / f'{self.prefix}{suffix}') + return self.directory / f'{self.prefix}{suffix}' def test_binary(): diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py index 98f738ddcb30..2c0d4089c990 100644 --- a/tests/python_package_test/test_dask.py +++ b/tests/python_package_test/test_dask.py @@ -28,7 +28,7 @@ from dask.array.utils import assert_eq from dask.distributed import Client, LocalCluster, default_client, wait from pkg_resources import parse_version 
-from scipy.sparse import csr_matrix +from scipy.sparse import csc_matrix, csr_matrix from scipy.stats import spearmanr from sklearn import __version__ as sk_version from sklearn.datasets import make_blobs, make_regression @@ -198,6 +198,12 @@ def _create_data(objective, n_samples=1_000, output='array', chunk_size=500, **k dX = da.from_array(X, chunks=(chunk_size, X.shape[1])).map_blocks(csr_matrix) dy = da.from_array(y, chunks=chunk_size) dw = da.from_array(weights, chunk_size) + X = csr_matrix(X) + elif output == 'scipy_csc_matrix': + dX = da.from_array(X, chunks=(chunk_size, X.shape[1])).map_blocks(csc_matrix) + dy = da.from_array(y, chunks=chunk_size) + dw = da.from_array(weights, chunk_size) + X = csc_matrix(X) else: raise ValueError(f"Unknown output type '{output}'") @@ -344,7 +350,7 @@ def test_classifier(output, task, boosting_type, tree_learner, cluster): assert tree_df.loc[node_uses_cat_col, "decision_type"].unique()[0] == '==' -@pytest.mark.parametrize('output', data_output) +@pytest.mark.parametrize('output', data_output + ['scipy_csc_matrix']) @pytest.mark.parametrize('task', ['binary-classification', 'multiclass-classification']) def test_classifier_pred_contrib(output, task, cluster): with Client(cluster) as client: @@ -365,14 +371,52 @@ def test_classifier_pred_contrib(output, task, cluster): **params ) dask_classifier = dask_classifier.fit(dX, dy, sample_weight=dw) - preds_with_contrib = dask_classifier.predict(dX, pred_contrib=True).compute() + preds_with_contrib = dask_classifier.predict(dX, pred_contrib=True) local_classifier = lgb.LGBMClassifier(**params) local_classifier.fit(X, y, sample_weight=w) local_preds_with_contrib = local_classifier.predict(X, pred_contrib=True) - if output == 'scipy_csr_matrix': - preds_with_contrib = np.array(preds_with_contrib.todense()) + # shape depends on whether it is binary or multiclass classification + num_features = dask_classifier.n_features_ + num_classes = dask_classifier.n_classes_ + if num_classes == 2: 
+ expected_num_cols = num_features + 1 + else: + expected_num_cols = (num_features + 1) * num_classes + + # in the special case of multi-class classification using scipy sparse matrices, + # the output of `.predict(..., pred_contrib=True)` is a list of sparse matrices (one per class) + # + # since that case is so different than all other cases, check the relevant things here + # and then return early + if output.startswith('scipy') and task == 'multiclass-classification': + if output == 'scipy_csr_matrix': + expected_type = csr_matrix + elif output == 'scipy_csc_matrix': + expected_type = csc_matrix + else: + raise ValueError(f"Unrecognized output type: {output}") + assert isinstance(preds_with_contrib, list) + assert all(isinstance(arr, da.Array) for arr in preds_with_contrib) + assert all(isinstance(arr._meta, expected_type) for arr in preds_with_contrib) + assert len(preds_with_contrib) == num_classes + assert len(preds_with_contrib) == len(local_preds_with_contrib) + for i in range(num_classes): + computed_preds = preds_with_contrib[i].compute() + assert isinstance(computed_preds, expected_type) + assert computed_preds.shape[1] == num_classes + assert computed_preds.shape == local_preds_with_contrib[i].shape + assert len(np.unique(computed_preds[:, -1])) == 1 + # raw scores will probably be different, but at least check that all predicted classes are the same + pred_classes = np.argmax(computed_preds.toarray(), axis=1) + local_pred_classes = np.argmax(local_preds_with_contrib[i].toarray(), axis=1) + np.testing.assert_array_equal(pred_classes, local_pred_classes) + return + + preds_with_contrib = preds_with_contrib.compute() + if output.startswith('scipy'): + preds_with_contrib = preds_with_contrib.toarray() # be sure LightGBM actually used at least one categorical column, # and that it was correctly treated as a categorical feature @@ -386,14 +430,6 @@ def test_classifier_pred_contrib(output, task, cluster): assert node_uses_cat_col.sum() > 0 assert 
tree_df.loc[node_uses_cat_col, "decision_type"].unique()[0] == '==' - # shape depends on whether it is binary or multiclass classification - num_features = dask_classifier.n_features_ - num_classes = dask_classifier.n_classes_ - if num_classes == 2: - expected_num_cols = num_features + 1 - else: - expected_num_cols = (num_features + 1) * num_classes - # * shape depends on whether it is binary or multiclass classification # * matrix for binary classification is of the form [feature_contrib, base_value], # for multi-class it's [feat_contrib_class1, base_value_class1, feat_contrib_class2, base_value_class2, etc.] @@ -403,7 +439,7 @@ def test_classifier_pred_contrib(output, task, cluster): assert preds_with_contrib.shape == local_preds_with_contrib.shape if num_classes == 2: - assert len(np.unique(preds_with_contrib[:, num_features]) == 1) + assert len(np.unique(preds_with_contrib[:, num_features])) == 1 else: for i in range(num_classes): base_value_col = num_features * (i + 1) + i @@ -585,7 +621,7 @@ def test_regressor_pred_contrib(output, cluster): local_preds_with_contrib = local_regressor.predict(X, pred_contrib=True) if output == "scipy_csr_matrix": - preds_with_contrib = np.array(preds_with_contrib.todense()) + preds_with_contrib = preds_with_contrib.toarray() # contrib outputs for distributed training are different than from local training, so we can just test # that the output has the right shape and base values are in the right position diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py index ed45fa1554f4..20c4436149b7 100644 --- a/tests/python_package_test/test_engine.py +++ b/tests/python_package_test/test_engine.py @@ -1087,7 +1087,7 @@ def test_contribs_sparse_multiclass(): # convert data to dense and get back same contribs contribs_dense = gbm.predict(X_test.toarray(), pred_contrib=True) # validate the values are the same - contribs_csr_array = np.swapaxes(np.array([sparse_array.todense() for sparse_array in 
contribs_csr]), 0, 1) + contribs_csr_array = np.swapaxes(np.array([sparse_array.toarray() for sparse_array in contribs_csr]), 0, 1) contribs_csr_arr_re = contribs_csr_array.reshape((contribs_csr_array.shape[0], contribs_csr_array.shape[1] * contribs_csr_array.shape[2])) if platform.machine() == 'aarch64': @@ -1103,7 +1103,7 @@ def test_contribs_sparse_multiclass(): for perclass_contribs_csc in contribs_csc: assert isspmatrix_csc(perclass_contribs_csc) # validate the values are the same - contribs_csc_array = np.swapaxes(np.array([sparse_array.todense() for sparse_array in contribs_csc]), 0, 1) + contribs_csc_array = np.swapaxes(np.array([sparse_array.toarray() for sparse_array in contribs_csc]), 0, 1) contribs_csc_array = contribs_csc_array.reshape((contribs_csc_array.shape[0], contribs_csc_array.shape[1] * contribs_csc_array.shape[2])) if platform.machine() == 'aarch64': @@ -2261,7 +2261,7 @@ def test_forced_bins(): x[:, 0] = np.arange(0, 1, 0.01) x[:, 1] = -np.arange(0, 1, 0.01) y = np.arange(0, 1, 0.01) - forcedbins_filename = str( + forcedbins_filename = ( Path(__file__).absolute().parents[2] / 'examples' / 'regression' / 'forced_bins.json' ) params = {'objective': 'regression_l1', @@ -2285,7 +2285,7 @@ def test_forced_bins(): est = lgb.train(params, lgb_x, num_boost_round=20) predicted = est.predict(new_x) assert len(np.unique(predicted)) == 3 - params['forcedbins_filename'] = str( + params['forcedbins_filename'] = ( Path(__file__).absolute().parents[2] / 'examples' / 'regression' / 'forced_bins2.json' ) params['max_bin'] = 11