Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 060073e

Browse files
committedJan 24, 2025·
tests: make more resampling tests work
1 parent 74a4882 commit 060073e

File tree

2 files changed

+18
-14
lines changed

2 files changed

+18
-14
lines changed
 

‎R/ResamplingFcstCV.R

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -106,26 +106,29 @@ ResamplingFcstCV = R6Class("ResamplingFcstCV",
106106
setnames(tab, c("row_id", "order"))
107107
setorderv(tab, "order")
108108
train_end = tab[.N - horizon, row_id]
109-
train_end = rev(seq.int(
109+
train_end = seq.int(
110110
from = train_end,
111111
by = -pars$step_size,
112112
length.out = pars$folds
113-
))
113+
)
114114
if (!pars$fixed_window) {
115115
train_ids = map(train_end, function(x) ids[1L]:x)
116116
} else {
117117
train_ids = map(train_end, function(x) (x - window_size + 1L):x)
118118
}
119-
test_ids = map(train_ids, function(x) (x[length(x)] + 1L):(x[length(x)] + horizon))
119+
test_ids = map(train_ids, function(x) {
120+
n = length(x)
121+
(x[n] + 1L):(x[n] + horizon)
122+
})
120123
} else {
121124
setnames(tab, "..row_id", "row_id")
122125
setorderv(tab, c(key_cols, order_cols))
123126
ids = tab[, {
124-
train_end = rev(seq.int(
127+
train_end = seq.int(
125128
from = .N - horizon,
126129
by = -pars$step_size,
127130
length.out = pars$folds
128-
))
131+
)
129132
if (pars$fixed_window) {
130133
train_ids = map(train_end, function(x) .SD[(x - window_size + 1L):x, row_id])
131134
} else {
@@ -137,9 +140,10 @@ ResamplingFcstCV = R6Class("ResamplingFcstCV",
137140
})
138141
list(train_ids = train_ids, test_ids = test_ids)
139142
}, by = key_cols][, .(train_ids, test_ids)]
140-
143+
train_ids = ids$train_ids
144+
test_ids = ids$test_ids
141145
}
142-
list(train = ids$train_ids, test = ids$test_ids)
146+
list(train = train_ids, test = test_ids)
143147
},
144148

145149
.sample_ids = function(ids, ...) {
@@ -149,11 +153,11 @@ ResamplingFcstCV = R6Class("ResamplingFcstCV",
149153

150154
ids = sort(ids)
151155
train_end = ids[ids <= (max(ids) - horizon) & ids >= window_size]
152-
train_end = rev(seq.int(
156+
train_end = seq.int(
153157
from = train_end[length(train_end)],
154158
by = -pars$step_size,
155159
length.out = pars$folds
156-
))
160+
)
157161
if (pars$fixed_window) {
158162
train_ids = map(train_end, function(x) (x - window_size + 1L):x)
159163
} else {

‎tests/testthat/test_ResamplingFcstCV.R

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
test_that("forecast_cv basic properties", {
2-
task = tsk("penguins")
2+
task = tsk("airpassengers")
33
resampling = rsmp("forecast_cv",
44
folds = 10L, horizon = 3L, window_size = 5L, fixed_window = FALSE
55
)
6-
expect_resampling(resampling, task)
6+
expect_resampling(resampling, task, strata = FALSE)
77
resampling$instantiate(task)
8-
expect_resampling(resampling, task)
8+
expect_resampling(resampling, task, strata = FALSE)
99
expect_identical(resampling$iters, 10L)
1010
expect_equal(intersect(resampling$test_set(1L), resampling$train_set(1L)), integer())
1111
expect_false(resampling$duplicated_ids)
@@ -33,7 +33,7 @@ test_that("forecast_cv works", {
3333
})
3434

3535
test_that("forecast_cv fixed vs. expanding window", {
36-
task = tsk("penguins")
36+
task = tsk("airpassengers")
3737
task$filter(1:30)
3838

3939
# fixed window
@@ -56,7 +56,7 @@ test_that("forecast_cv fixed vs. expanding window", {
5656
})
5757

5858
test_that("forecast_cv with various parameter combinations", {
59-
task = tsk("penguins")
59+
task = tsk("airpassengers")
6060
task$filter(1:30)
6161

6262
# small window, large step size

0 commit comments

Comments
 (0)
Please sign in to comment.