Skip to content

Commit

Permalink
fix: filter order column in fcst learner features
Browse files Browse the repository at this point in the history
  • Loading branch information
m-muecke committed Jan 7, 2025
1 parent 2484add commit 539c07d
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 34 deletions.
6 changes: 3 additions & 3 deletions R/LearnerRegrARIMA.R
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ LearnerFcstARIMA = R6Class("LearnerFcstARIMA",
super$initialize(
id = "fcst.arima",
param_set = param_set,
feature_types = c("logical", "integer", "numeric"),
feature_types = c("Date", "logical", "integer", "numeric"),
packages = c("mlr3forecast", "forecast"),
label = "ARIMA",
man = "mlr3forecast::mlr_learners_fcst.arima"
Expand All @@ -64,7 +64,7 @@ LearnerFcstARIMA = R6Class("LearnerFcstARIMA",
.args = pv
)
} else {
xreg = as.matrix(task$data(cols = task$feature_names))
xreg = as.matrix(task$data(cols = fcst_feature_names(task)))
invoke(forecast::Arima,
y = stats::ts(task$data(cols = task$target_names)[[1L]]),
xreg = xreg,
Expand All @@ -79,7 +79,7 @@ LearnerFcstARIMA = R6Class("LearnerFcstARIMA",
if (is_task_featureless(task)) {
prediction = invoke(forecast::forecast, self$model, h = length(task$row_ids))
} else {
newdata = as.matrix(task$data(cols = task$feature_names))
newdata = as.matrix(task$data(cols = fcst_feature_names(task)))
prediction = invoke(forecast::forecast, self$model, xreg = newdata)
}
list(response = prediction$mean)
Expand Down
6 changes: 3 additions & 3 deletions R/LearnerRegrAutoARIMA.R
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ LearnerFcstAutoARIMA = R6Class("LearnerFcstAutoARIMA",
super$initialize(
id = "fcst.auto_arima",
param_set = param_set,
feature_types = c("logical", "integer", "numeric"),
feature_types = c("Date", "logical", "integer", "numeric"),
packages = c("mlr3forecast", "forecast"),
label = "Auto ARIMA",
man = "mlr3forecast::mlr_learners_fcst.arima"
Expand All @@ -69,7 +69,7 @@ LearnerFcstAutoARIMA = R6Class("LearnerFcstAutoARIMA",
.args = pv
)
} else {
xreg = as.matrix(task$data(cols = task$feature_names))
xreg = as.matrix(task$data(cols = fcst_feature_names(task)))
invoke(forecast::auto.arima,
y = stats::ts(task$data(cols = task$target_names)[[1L]]),
xreg = xreg,
Expand All @@ -84,7 +84,7 @@ LearnerFcstAutoARIMA = R6Class("LearnerFcstAutoARIMA",
if (is_task_featureless(task)) {
prediction = invoke(forecast::forecast, self$model, h = length(task$row_ids))
} else {
newdata = as.matrix(task$data(cols = task$feature_names))
newdata = as.matrix(task$data(cols = fcst_feature_names(task)))
prediction = invoke(forecast::forecast, self$model, xreg = newdata)
}
list(response = prediction$mean)
Expand Down
9 changes: 7 additions & 2 deletions R/utils.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
is_task_featureless = function(task) {
length(task$feature_names) == 0L
fcst_feature_names = function(task) {
nms = task$feature_names
nms[nms != task$col_roles$order]
}

is_task_featureless = function(task) {
nms = fcst_feature_names(task)
length(nms) == 0L
}
2 changes: 1 addition & 1 deletion R/zzz.R
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ register_mlr3 = function() {
mlr_reflections$task_types = mlr_reflections$task_types[!"fcst"]
mlr_reflections$task_types = setkeyv(rbind(mlr_reflections$task_types, rowwise_table(
~type, ~package, ~task, ~learner, ~prediction, ~prediction_data, ~measure,
"fcst", "mlr3forecast", "TaskFcst", "LearnerFcst", "PredictionFcst", "PredictionDataFcst", "MeasureFcst" # nolint
"fcst", "mlr3forecast", "TaskRegr", "LearnerRegr", "PredictionFcst", "PredictionDataFcst", "MeasureFcst" # nolint
), fill = TRUE), "type")
mlr_reflections$learner_predict_types$fcst = mlr_reflections$learner_predict_types$regr
mlr_reflections$learner_properties$fcst = mlr_reflections$learner_properties$regr
Expand Down
2 changes: 1 addition & 1 deletion README.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ newdata
### Example: Native Forecasting Learners

```{r}
task = tsk("airpassengers")$select(setdiff(task$feature_names, "date"))
task = tsk("airpassengers")
learner = lrn("fcst.arima", order = c(2L, 1L, 2L))$train(task)
prediction = learner$predict(task, 140:144)
prediction$score(msr("regr.rmse"))
Expand Down
48 changes: 24 additions & 24 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,32 +44,32 @@ prediction = flrn$predict_newdata(newdata, task)
prediction
#> <PredictionRegr> for 3 observations:
#> row_ids truth response
#> 1 NA 433.9804
#> 2 NA 436.2909
#> 3 NA 456.6843
#> 1 NA 435.9582
#> 2 NA 435.6452
#> 3 NA 455.2647
prediction = flrn$predict(task, 142:144)
prediction
#> <PredictionRegr> for 3 observations:
#> row_ids truth response
#> 1 461 458.1193
#> 2 390 411.2993
#> 3 432 432.5985
#> 1 461 456.7750
#> 2 390 412.2228
#> 3 432 431.6033
prediction$score(msr("regr.rmse"))
#> regr.rmse
#> 12.41393
#> 13.06217

flrn = ForecastLearner$new(lrn("regr.ranger"), 1:12)
resampling = rsmp("forecast_holdout", ratio = 0.9)
rr = resample(task, flrn, resampling)
rr$aggregate(msr("regr.rmse"))
#> regr.rmse
#> 48.07301
#> 48.02502

resampling = rsmp("forecast_cv")
rr = resample(task, flrn, resampling)
rr$aggregate(msr("regr.rmse"))
#> regr.rmse
#> 26.31197
#> 25.43955
```

### Multivariate
Expand All @@ -89,34 +89,34 @@ flrn = ForecastLearner$new(lrn("regr.ranger"), 1:12)$train(new_task)
prediction = flrn$predict(new_task, 142:144)
prediction$score(msr("regr.rmse"))
#> regr.rmse
#> 13.75817
#> 14.16161

row_ids = new_task$nrow - 0:2
flrn$predict_newdata(new_task$data(rows = row_ids), new_task)
#> <PredictionRegr> for 3 observations:
#> row_ids truth response
#> 1 432 435.1631
#> 2 390 432.3781
#> 3 461 454.9814
#> 1 432 433.9538
#> 2 390 431.6852
#> 3 461 457.1615
newdata = new_task$data(rows = row_ids, cols = new_task$feature_names)
flrn$predict_newdata(newdata, new_task)
#> <PredictionRegr> for 3 observations:
#> row_ids truth response
#> 1 NA 435.1631
#> 2 NA 432.3781
#> 3 NA 454.9814
#> 1 NA 433.9538
#> 2 NA 431.6852
#> 3 NA 457.1615

resampling = rsmp("forecast_holdout", ratio = 0.9)
rr = resample(new_task, flrn, resampling)
rr$aggregate(msr("regr.rmse"))
#> regr.rmse
#> 52.16247
#> 46.99669

resampling = rsmp("forecast_cv")
rr = resample(new_task, flrn, resampling)
rr$aggregate(msr("regr.rmse"))
#> regr.rmse
#> 26.3378
#> 26.67119
```

### mlr3pipelines integration
Expand All @@ -131,7 +131,7 @@ glrn = as_learner(graph %>>% flrn)$train(task)
prediction = glrn$predict(task, 142:144)
prediction$score(msr("regr.rmse"))
#> regr.rmse
#> 12.75032
#> 13.63586
```

### Example: Forecasting electricity demand
Expand Down Expand Up @@ -205,14 +205,14 @@ flrn = ForecastLearner$new(lrn("regr.ranger"), 1:3)$train(task)
prediction = flrn$predict(task, 4460:4464)
prediction$score(msr("regr.rmse"))
#> regr.rmse
#> 21184.4
#> 22628.38

flrn = ForecastLearner$new(lrn("regr.ranger"), 1:3)
resampling = rsmp("forecast_holdout", ratio = 0.9)
rr = resample(task, flrn, resampling)
rr$aggregate(msr("regr.rmse"))
#> regr.rmse
#> 90376.96
#> 93334.76
```

### Example: Global vs Local Forecasting
Expand Down Expand Up @@ -247,7 +247,7 @@ row_ids = tab[year >= 2015, row_id]
prediction = flrn$predict(task, row_ids)
prediction$score(msr("regr.rmse"))
#> regr.rmse
#> 33028.77
#> 32862.98

# global forecasting
task = tsibbledata::aus_livestock |>
Expand All @@ -268,7 +268,7 @@ row_ids = tab[year >= 2015 & state == "Western Australia", row_id]
prediction = flrn$predict(task, row_ids)
prediction$score(msr("regr.rmse"))
#> regr.rmse
#> 31142.26
#> 30804.78
```

### Example: generate new data
Expand Down Expand Up @@ -328,7 +328,7 @@ newdata
### Example: Native Forecasting Learners

``` r
task = tsk("airpassengers")$select(setdiff(task$feature_names, "date"))
task = tsk("airpassengers")
learner = lrn("fcst.arima", order = c(2L, 1L, 2L))$train(task)
#> Registered S3 method overwritten by 'quantmod':
#> method from
Expand Down

0 comments on commit 539c07d

Please sign in to comment.