Skip to content

Commit

Permalink
allow arbitrary cross validation fold indices (#3353)
Browse files Browse the repository at this point in the history
* allow arbitrary cross validation fold indices

 - use training indices passed to `folds` parameter in `training.cv`
 - update doc string

* add tests for arbitrary fold indices
  • Loading branch information
owlas authored and hcho3 committed Jun 30, 2018
1 parent 594bcea commit 18813a2
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 9 deletions.
40 changes: 31 additions & 9 deletions python-package/xgboost/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,22 +231,39 @@ def mknfold(dall, nfold, param, seed, evals=(), fpreproc=None, stratified=False,
np.random.seed(seed)

if stratified is False and folds is None:
# Do standard k-fold cross validation
if shuffle is True:
idx = np.random.permutation(dall.num_row())
else:
idx = np.arange(dall.num_row())
idset = np.array_split(idx, nfold)
elif folds is not None and isinstance(folds, list):
idset = [x[1] for x in folds]
nfold = len(idset)
out_idset = np.array_split(idx, nfold)
in_idset = [
np.concatenate([out_idset[i] for i in range(nfold) if k != i])
for k in range(nfold)
]
elif folds is not None:
# Use user specified custom split using indices
try:
in_idset = [x[0] for x in folds]
out_idset = [x[1] for x in folds]
except TypeError:
# Custom stratification using Sklearn KFoldSplit object
splits = list(folds.split(X=dall.get_label(), y=dall.get_label()))
in_idset = [x[0] for x in splits]
out_idset = [x[1] for x in splits]
nfold = len(out_idset)
else:
# Do standard stratefied shuffle k-fold split
sfk = XGBStratifiedKFold(n_splits=nfold, shuffle=True, random_state=seed)
idset = [x[1] for x in sfk.split(X=dall.get_label(), y=dall.get_label())]
splits = list(sfk.split(X=dall.get_label(), y=dall.get_label()))
in_idset = [x[0] for x in splits]
out_idset = [x[1] for x in splits]
nfold = len(out_idset)

ret = []
for k in range(nfold):
dtrain = dall.slice(np.concatenate([idset[i] for i in range(nfold) if k != i]))
dtest = dall.slice(idset[k])
dtrain = dall.slice(in_idset[k])
dtest = dall.slice(out_idset[k])
# run preprocessing on the data set if needed
if fpreproc is not None:
dtrain, dtest, tparam = fpreproc(dtrain, dtest, param.copy())
Expand Down Expand Up @@ -308,8 +325,13 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
Number of folds in CV.
stratified : bool
Perform stratified sampling.
folds : a KFold or StratifiedKFold instance
Sklearn KFolds or StratifiedKFolds.
folds : a KFold or StratifiedKFold instance or list of fold indices
Sklearn KFolds or StratifiedKFolds object.
Alternatively may explicitly pass sample indices for each fold.
For `n` folds, `folds` should be a length `n` list of tuples.
Each tuple is `(in,out)` where `in` is a list of indices to be used
as the training samples for the `n`th fold and `out` is a list of
indices to be used as the testing samples for the `n`th fold.
metrics : string or list of strings
Evaluation metrics to be watched in CV.
obj : function
Expand Down
61 changes: 61 additions & 0 deletions tests/python/test_basic.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,12 @@
# -*- coding: utf-8 -*-
import sys
from contextlib import contextmanager
try:
# python 2
from StringIO import StringIO
except ImportError:
# python 3
from io import StringIO
import numpy as np
import xgboost as xgb
import unittest
Expand All @@ -8,6 +16,21 @@
rng = np.random.RandomState(1994)


@contextmanager
def captured_output():
"""
Reassign stdout temporarily in order to test printed statements
Taken from: https://stackoverflow.com/questions/4219717/how-to-assert-output-with-nosetest-unittest-in-python
"""
new_out, new_err = StringIO(), StringIO()
old_out, old_err = sys.stdout, sys.stderr
try:
sys.stdout, sys.stderr = new_out, new_err
yield sys.stdout, sys.stderr
finally:
sys.stdout, sys.stderr = old_out, old_err


class TestBasic(unittest.TestCase):

def test_basic(self):
Expand Down Expand Up @@ -238,3 +261,41 @@ def test_cv_no_shuffle(self):
cv = xgb.cv(params, dm, num_boost_round=10, shuffle=False, nfold=10, as_pandas=False)
assert isinstance(cv, dict)
assert len(cv) == (4)

def test_cv_explicit_fold_indices(self):
dm = xgb.DMatrix(dpath + 'agaricus.txt.train')
params = {'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic'}
folds = [
# Train Test
([1, 3], [5, 8]),
([7, 9], [23, 43]),
]

# return np.ndarray
cv = xgb.cv(params, dm, num_boost_round=10, folds=folds, as_pandas=False)
assert isinstance(cv, dict)
assert len(cv) == (4)

def test_cv_explicit_fold_indices_labels(self):
params = {'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'reg:linear'}
N = 100
F = 3
dm = xgb.DMatrix(data=np.random.randn(N, F), label=np.arange(N))
folds = [
# Train Test
([1, 3], [5, 8]),
([7, 9], [23, 43, 11]),
]

# Use callback to log the test labels in each fold
def cb(cbackenv):
print([fold.dtest.get_label() for fold in cbackenv.cvfolds])

# Run cross validation and capture standard out to test callback result
with captured_output() as (out, err):
xgb.cv(
params, dm, num_boost_round=1, folds=folds, callbacks=[cb],
as_pandas=False
)
output = out.getvalue().strip()
assert output == '[array([5., 8.], dtype=float32), array([23., 43., 11.], dtype=float32)]'

0 comments on commit 18813a2

Please sign in to comment.