diff --git a/learning.py b/learning.py index e0d4cd26d..3bf3c7bc9 100644 --- a/learning.py +++ b/learning.py @@ -1078,8 +1078,9 @@ def cross_validation(learner, size, dataset, k=10, trials=1): fold_errV = 0 n = len(dataset.examples) examples = dataset.examples + random.shuffle(dataset.examples) for fold in range(k): - random.shuffle(dataset.examples) + train_data, val_data = train_test_split(dataset, fold * (n / k), (fold + 1) * (n / k)) dataset.examples = train_data