Skip to content

Commit

Permalink
[python] allow to pass some params as pathlib.Path objects (#4440)
Browse files Browse the repository at this point in the history
* allow to pass some params as pathlib.Path objects

* fix lint

* improve indentation
  • Loading branch information
StrikerRUS authored Jul 7, 2021
1 parent b09da43 commit 90342e9
Show file tree
Hide file tree
Showing 8 changed files with 88 additions and 95 deletions.
10 changes: 4 additions & 6 deletions python-package/lightgbm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Contributors: https://github.com/microsoft/LightGBM/graphs/contributors.
"""
import os
from pathlib import Path

from .basic import Booster, Dataset, Sequence, register_logger
from .callback import early_stopping, print_evaluation, record_evaluation, reset_parameter
Expand All @@ -23,11 +23,9 @@
pass


dir_path = os.path.dirname(os.path.realpath(__file__))

if os.path.isfile(os.path.join(dir_path, 'VERSION.txt')):
with open(os.path.join(dir_path, 'VERSION.txt')) as version_file:
__version__ = version_file.read().strip()
_version_path = Path(__file__).parent.absolute() / 'VERSION.txt'
if _version_path.is_file():
__version__ = _version_path.read_text(encoding='utf-8').strip()

__all__ = ['Dataset', 'Booster', 'CVBooster', 'Sequence',
'register_logger',
Expand Down
90 changes: 42 additions & 48 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
import abc
import ctypes
import json
import os
import warnings
from collections import OrderedDict
from copy import deepcopy
from functools import wraps
from logging import Logger
from os import SEEK_END
from os.path import getsize
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union

Expand Down Expand Up @@ -243,31 +245,25 @@ def to_string(x):
else:
return str(x)
pairs.append(f"{key}={','.join(map(to_string, val))}")
elif isinstance(val, (str, NUMERIC_TYPES)) or is_numeric(val):
elif isinstance(val, (str, Path, NUMERIC_TYPES)) or is_numeric(val):
pairs.append(f"{key}={val}")
elif val is not None:
raise TypeError(f'Unknown type of parameter:{key}, got:{type(val).__name__}')
return ' '.join(pairs)


class _TempFile:
"""Proxy class to workaround errors on Windows."""

def __enter__(self):
with NamedTemporaryFile(prefix="lightgbm_tmp_", delete=True) as f:
self.name = f.name
self.path = Path(self.name)
return self

def __exit__(self, exc_type, exc_val, exc_tb):
if os.path.isfile(self.name):
os.remove(self.name)

def readlines(self):
with open(self.name, "r+") as f:
ret = f.readlines()
return ret

def writelines(self, lines):
with open(self.name, "w+") as f:
f.writelines(lines)
if self.path.is_file():
self.path.unlink()


class LightGBMError(Exception):
Expand Down Expand Up @@ -584,12 +580,12 @@ def _load_pandas_categorical(file_name=None, model_str=None):
pandas_key = 'pandas_categorical:'
offset = -len(pandas_key)
if file_name is not None:
max_offset = -os.path.getsize(file_name)
max_offset = -getsize(file_name)
with open(file_name, 'rb') as f:
while True:
if offset < max_offset:
offset = max_offset
f.seek(offset, os.SEEK_END)
f.seek(offset, SEEK_END)
lines = f.readlines()
if len(lines) >= 2:
break
Expand Down Expand Up @@ -685,7 +681,7 @@ def __init__(self, model_file=None, booster_handle=None, pred_parameter=None):
Parameters
----------
model_file : string or None, optional (default=None)
model_file : string, pathlib.Path or None, optional (default=None)
Path to the model file.
booster_handle : object or None, optional (default=None)
Handle of Booster.
Expand All @@ -698,7 +694,7 @@ def __init__(self, model_file=None, booster_handle=None, pred_parameter=None):
"""Prediction task"""
out_num_iterations = ctypes.c_int(0)
_safe_call(_LIB.LGBM_BoosterCreateFromModelfile(
c_str(model_file),
c_str(str(model_file)),
ctypes.byref(out_num_iterations),
ctypes.byref(self.handle)))
out_num_class = ctypes.c_int(0)
Expand Down Expand Up @@ -743,9 +739,9 @@ def predict(self, data, start_iteration=0, num_iteration=-1,
Parameters
----------
data : string, numpy array, pandas DataFrame, H2O DataTable's Frame or scipy.sparse
data : string, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame or scipy.sparse
Data source for prediction.
When data type is string, it represents the path of txt file.
When data type is string or pathlib.Path, it represents the path of txt file.
start_iteration : int, optional (default=0)
Start index of the iteration to predict.
num_iteration : int, optional (default=-1)
Expand Down Expand Up @@ -780,21 +776,19 @@ def predict(self, data, start_iteration=0, num_iteration=-1,
predict_type = C_API_PREDICT_CONTRIB
int_data_has_header = 1 if data_has_header else 0

if isinstance(data, str):
if isinstance(data, (str, Path)):
with _TempFile() as f:
_safe_call(_LIB.LGBM_BoosterPredictForFile(
self.handle,
c_str(data),
c_str(str(data)),
ctypes.c_int(int_data_has_header),
ctypes.c_int(predict_type),
ctypes.c_int(start_iteration),
ctypes.c_int(num_iteration),
c_str(self.pred_parameter),
c_str(f.name)))
lines = f.readlines()
nrow = len(lines)
preds = [float(token) for line in lines for token in line.split('\t')]
preds = np.array(preds, dtype=np.float64, copy=False)
preds = np.loadtxt(f.name, dtype=np.float64)
nrow = preds.shape[0]
elif isinstance(data, scipy.sparse.csr_matrix):
preds, nrow = self.__pred_for_csr(data, start_iteration, num_iteration, predict_type)
elif isinstance(data, scipy.sparse.csc_matrix):
Expand Down Expand Up @@ -829,9 +823,9 @@ def predict(self, data, start_iteration=0, num_iteration=-1,
def __get_num_preds(self, start_iteration, num_iteration, nrow, predict_type):
"""Get size of prediction result."""
if nrow > MAX_INT32:
raise LightGBMError('LightGBM cannot perform prediction for data'
raise LightGBMError('LightGBM cannot perform prediction for data '
f'with number of rows greater than MAX_INT32 ({MAX_INT32}).\n'
'You can split your data into chunks'
'You can split your data into chunks '
'and then concatenate predictions for them')
n_preds = ctypes.c_int64(0)
_safe_call(_LIB.LGBM_BoosterCalcNumPredict(
Expand Down Expand Up @@ -1133,9 +1127,9 @@ def __init__(self, data, label=None, reference=None,
Parameters
----------
data : string, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, Sequence, list of Sequences or list of numpy arrays
data : string, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, Sequence, list of Sequences or list of numpy arrays
Data source of Dataset.
If string, it represents the path to txt file.
If string or pathlib.Path, it represents the path to txt file.
label : list, numpy 1-D array, pandas Series / one-column DataFrame or None, optional (default=None)
Label of the data.
reference : Dataset or None, optional (default=None)
Expand Down Expand Up @@ -1384,7 +1378,7 @@ def _free_handle(self):

def _set_init_score_by_predictor(self, predictor, data, used_indices=None):
data_has_header = False
if isinstance(data, str):
if isinstance(data, (str, Path)):
# check data has header or not
data_has_header = any(self.params.get(alias, False) for alias in _ConfigAliases.get("header"))
num_data = self.num_data()
Expand All @@ -1395,7 +1389,7 @@ def _set_init_score_by_predictor(self, predictor, data, used_indices=None):
is_reshape=False)
if used_indices is not None:
assert not self.need_slice
if isinstance(data, str):
if isinstance(data, (str, Path)):
sub_init_score = np.empty(num_data * predictor.num_class, dtype=np.float32)
assert num_data == len(used_indices)
for i in range(len(used_indices)):
Expand Down Expand Up @@ -1472,10 +1466,10 @@ def _lazy_init(self, data, label=None, reference=None,
elif reference is not None:
raise TypeError('Reference dataset should be None or dataset instance')
# start construct data
if isinstance(data, str):
if isinstance(data, (str, Path)):
self.handle = ctypes.c_void_p()
_safe_call(_LIB.LGBM_DatasetCreateFromFile(
c_str(data),
c_str(str(data)),
c_str(params_str),
ref_dataset,
ctypes.byref(self.handle)))
Expand Down Expand Up @@ -1775,9 +1769,9 @@ def create_valid(self, data, label=None, weight=None, group=None,
Parameters
----------
data : string, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, Sequence, list of Sequences or list of numpy arrays
data : string, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, Sequence, list of Sequences or list of numpy arrays
Data source of Dataset.
If string, it represents the path to txt file.
If string or pathlib.Path, it represents the path to txt file.
label : list, numpy 1-D array, pandas Series / one-column DataFrame or None, optional (default=None)
Label of the data.
weight : list, numpy 1-D array, pandas Series or None, optional (default=None)
Expand Down Expand Up @@ -1842,7 +1836,7 @@ def save_binary(self, filename):
Parameters
----------
filename : string
filename : string or pathlib.Path
Name of the output file.
Returns
Expand All @@ -1852,7 +1846,7 @@ def save_binary(self, filename):
"""
_safe_call(_LIB.LGBM_DatasetSaveBinary(
self.construct().handle,
c_str(filename)))
c_str(str(filename))))
return self

def _update_params(self, params):
Expand Down Expand Up @@ -2242,7 +2236,7 @@ def get_data(self):
Returns
-------
data : string, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, list of numpy arrays or None
data : string, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, list of numpy arrays or None
Raw data used in the Dataset construction.
"""
if self.handle is None:
Expand Down Expand Up @@ -2442,7 +2436,7 @@ def _dump_text(self, filename):
Parameters
----------
filename : string
filename : string or pathlib.Path
Name of the output file.
Returns
Expand All @@ -2452,7 +2446,7 @@ def _dump_text(self, filename):
"""
_safe_call(_LIB.LGBM_DatasetDumpText(
self.construct().handle,
c_str(filename)))
c_str(str(filename))))
return self


Expand All @@ -2468,7 +2462,7 @@ def __init__(self, params=None, train_set=None, model_file=None, model_str=None,
Parameters for Booster.
train_set : Dataset or None, optional (default=None)
Training dataset.
model_file : string or None, optional (default=None)
model_file : string, pathlib.Path or None, optional (default=None)
Path to the model file.
model_str : string or None, optional (default=None)
Model will be loaded from this string.
Expand Down Expand Up @@ -2561,7 +2555,7 @@ def __init__(self, params=None, train_set=None, model_file=None, model_str=None,
out_num_iterations = ctypes.c_int(0)
self.handle = ctypes.c_void_p()
_safe_call(_LIB.LGBM_BoosterCreateFromModelfile(
c_str(model_file),
c_str(str(model_file)),
ctypes.byref(out_num_iterations),
ctypes.byref(self.handle)))
out_num_class = ctypes.c_int(0)
Expand Down Expand Up @@ -3200,7 +3194,7 @@ def save_model(self, filename, num_iteration=None, start_iteration=0, importance
Parameters
----------
filename : string
filename : string or pathlib.Path
Filename to save Booster.
num_iteration : int or None, optional (default=None)
Index of the iteration that should be saved.
Expand All @@ -3226,7 +3220,7 @@ def save_model(self, filename, num_iteration=None, start_iteration=0, importance
ctypes.c_int(start_iteration),
ctypes.c_int(num_iteration),
ctypes.c_int(importance_type_int),
c_str(filename)))
c_str(str(filename))))
_dump_pandas_categorical(self.pandas_categorical, filename)
return self

Expand Down Expand Up @@ -3400,9 +3394,9 @@ def predict(self, data, start_iteration=0, num_iteration=None,
Parameters
----------
data : string, numpy array, pandas DataFrame, H2O DataTable's Frame or scipy.sparse
data : string, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame or scipy.sparse
Data source for prediction.
If string, it represents the path to txt file.
If string or pathlib.Path, it represents the path to txt file.
start_iteration : int, optional (default=0)
Start index of the iteration to predict.
If <= 0, starts from the first iteration.
Expand Down Expand Up @@ -3455,9 +3449,9 @@ def refit(self, data, label, decay_rate=0.9, **kwargs):
Parameters
----------
data : string, numpy array, pandas DataFrame, H2O DataTable's Frame or scipy.sparse
data : string, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame or scipy.sparse
Data source for refit.
If string, it represents the path to txt file.
If string or pathlib.Path, it represents the path to txt file.
label : list, numpy 1-D array or pandas Series / one-column DataFrame
Label for refit.
decay_rate : float, optional (default=0.9)
Expand Down
9 changes: 5 additions & 4 deletions python-package/lightgbm/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import collections
import copy
from operator import attrgetter
from pathlib import Path

import numpy as np

Expand Down Expand Up @@ -76,7 +77,7 @@ def train(params, train_set, num_boost_round=100,
If you want to get i-th row preds in j-th class, the access way is preds[j * num_data + i].
To ignore the default metric corresponding to the used objective,
set the ``metric`` parameter to the string ``"None"`` in ``params``.
init_model : string, Booster or None, optional (default=None)
init_model : string, pathlib.Path, Booster or None, optional (default=None)
Filename of LightGBM model or Booster instance used for continue training.
feature_name : list of strings or 'auto', optional (default="auto")
Feature names.
Expand Down Expand Up @@ -161,7 +162,7 @@ def train(params, train_set, num_boost_round=100,

if num_boost_round <= 0:
raise ValueError("num_boost_round should be greater than zero.")
if isinstance(init_model, str):
if isinstance(init_model, (str, Path)):
predictor = _InnerPredictor(model_file=init_model, pred_parameter=params)
elif isinstance(init_model, Booster):
predictor = init_model._to_predictor(dict(init_model.params, **params))
Expand Down Expand Up @@ -470,7 +471,7 @@ def cv(params, train_set, num_boost_round=100,
If you want to get i-th row preds in j-th class, the access way is preds[j * num_data + i].
To ignore the default metric corresponding to the used objective,
set ``metrics`` to the string ``"None"``.
init_model : string, Booster or None, optional (default=None)
init_model : string, pathlib.Path, Booster or None, optional (default=None)
Filename of LightGBM model or Booster instance used for continue training.
feature_name : list of strings or 'auto', optional (default="auto")
Feature names.
Expand Down Expand Up @@ -545,7 +546,7 @@ def cv(params, train_set, num_boost_round=100,

if num_boost_round <= 0:
raise ValueError("num_boost_round should be greater than zero.")
if isinstance(init_model, str):
if isinstance(init_model, (str, Path)):
predictor = _InnerPredictor(model_file=init_model, pred_parameter=params)
elif isinstance(init_model, Booster):
predictor = init_model._to_predictor(dict(init_model.params, **params))
Expand Down
Loading

0 comments on commit 90342e9

Please sign in to comment.