Skip to content

Commit

Permalink
feat: add concept of native forecast learners
Browse files Browse the repository at this point in the history
  • Loading branch information
m-muecke committed Jan 6, 2025
1 parent 45bb4cc commit 1fb9125
Show file tree
Hide file tree
Showing 11 changed files with 348 additions and 156 deletions.
5 changes: 3 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,12 @@ RoxygenNote: 7.3.2
Collate:
'ForecastLearner.R'
'zzz.R'
'LearnerARIMA.R'
'LearnerFcst.R'
'LearnerRegrARIMA.R'
'LearnerRegrAutoARIMA.R'
'MeasureDirectional.R'
'ResamplingForecastCV.R'
'ResamplingForecastHoldout.R'
'TaskFcstAirpassengers.R'
'as_task_fcst.R'
'bibentries.R'
'utils.R'
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ S3method(as_task_fcst,TaskFcst)
S3method(as_task_fcst,data.frame)
export(ForecastLearner)
export(LearnerFcstARIMA)
export(LearnerFcstAutoARIMA)
export(ResamplingForecastCV)
export(ResamplingForecastHoldout)
export(as_task_fcst)
Expand Down
23 changes: 0 additions & 23 deletions R/LearnerFcst.R

This file was deleted.

46 changes: 27 additions & 19 deletions R/LearnerARIMA.R → R/LearnerRegrARIMA.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,17 @@ LearnerFcstARIMA = R6Class("LearnerFcstARIMA",
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function() {

ps = ps(
order = p_uty(default = c(0, 0, 0), tags = "train"),
seasonal = p_uty(default = c(0, 0, 0), tags = "train"),
param_set = ps(
order = p_uty(
default = c(0L, 0L, 0L),
tags = "train",
custom_check = crate(function(x) check_integerish(x, lower = 0L, len = 3L))
),
seasonal = p_uty(
default = c(0L, 0L, 0L),
tags = "train",
custom_check = crate(function(x) check_integerish(x, lower = 0L, len = 3L))
),
include.mean = p_lgl(default = TRUE, tags = "train"),
include.drift = p_lgl(default = FALSE, tags = "train"),
biasadj = p_lgl(default = FALSE, tags = "train"),
Expand All @@ -31,11 +38,11 @@ LearnerFcstARIMA = R6Class("LearnerFcstARIMA",

super$initialize(
id = "fcst.arima",
param_set = ps,
param_set = param_set,
feature_types = c("logical", "integer", "numeric"),
packages = c("mlr3learners", "forecast"),
packages = c("mlr3forecast", "forecast"),
label = "ARIMA",
man = "mlr3learners::mlr_learners_arima.arima"
man = "mlr3forecast::mlr_learners_fcst.arima"
)
}
),
Expand All @@ -44,41 +51,42 @@ LearnerFcstARIMA = R6Class("LearnerFcstARIMA",
.max_index = NULL,

.train = function(task) {
if (length(task$col_roles$order) == 0L) {
if ("ordered" %nin% task$properties) {
stopf("%s learner requires an ordered task.", self$id)
}
private$.max_index = max(task$data(cols = task$col_roles$order)[[1L]])
pv = self$param_set$get_values(tags = "train")
if ("weights" %in% task$properties) {
pv = insert_named(pv, list(weights = task$weights$weight))
}
if (length(task$feature_names) > 0) {
xreg = as.matrix(task$data(cols = task$feature_names))

if (is_task_featureless(task)) {
invoke(forecast::Arima,
y = task$data(rows = task$row_ids, cols = task$target_names),
xreg = xreg,
y = as.ts(task$data(cols = task$target_names)),
.args = pv
)
} else {
xreg = as.matrix(task$data(cols = task$feature_names))
invoke(forecast::Arima,
y = task$data(rows = task$row_ids, cols = task$target_names),
.args = pv)
y = as.ts(task$data(cols = task$target_names)),
xreg = xreg,
.args = pv
)
}
},

.predict = function(task) {
pv = self$param_set$get_values(tags = "predict")
if (private$.is_newdata(task)) {
if (length(task$feature_names) > 0) {
if (is_task_featureless(task)) {
prediction = invoke(forecast::forecast, self$model, h = length(task$row_ids))
} else {
newdata = as.matrix(task$data(cols = task$feature_names))
prediction = invoke(forecast::forecast, self$model, xreg = newdata)
} else {
prediction = invoke(forecast::forecast, self$model, h = length(task$row_ids))
browser()
}
list(response = prediction$mean)
} else {
prediction = stats::fitted(self$model[task$row_ids])
prediction = stats::fitted(self$model)[task$row_ids]
list(response = prediction)
}
},
Expand Down
105 changes: 105 additions & 0 deletions R/LearnerRegrAutoARIMA.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
#' @title Auto ARIMA
#'
#' @name mlr_learners_fcst.auto_arima
#'
#' @description
#' Auto ARIMA model.
#' Calls [forecast::auto.arima()] from package \CRANpkg{forecast}.
#'
#' @templateVar id fcst.auto_arima
#' @template learner
#'
#' @references
#' ...
#'
#' @export
#' @template seealso_learner
LearnerFcstAutoARIMA = R6Class("LearnerFcstAutoARIMA",
  inherit = LearnerRegr,
  public = list(
    #' @description
    #' Creates a new instance of this [R6][R6::R6Class] class.
    initialize = function() {
      # The `default` entries document the defaults of forecast::auto.arima();
      # they are kept in sync with that function's signature.
      param_set = ps(
        d = p_int(0L, default = NA, tags = "train", special_vals = list(NA)),
        D = p_int(0L, default = NA, tags = "train", special_vals = list(NA)),
        max.q = p_int(0L, default = 5L, tags = "train"),
        max.p = p_int(0L, default = 5L, tags = "train"),
        max.P = p_int(0L, default = 2L, tags = "train"),
        max.Q = p_int(0L, default = 2L, tags = "train"),
        max.order = p_int(0L, default = 5L, tags = "train"),
        max.d = p_int(0L, default = 2L, tags = "train"),
        max.D = p_int(0L, default = 1L, tags = "train"),
        start.p = p_int(0L, default = 2L, tags = "train"),
        start.q = p_int(0L, default = 2L, tags = "train"),
        start.P = p_int(0L, default = 1L, tags = "train"),
        start.Q = p_int(0L, default = 1L, tags = "train"),
        stepwise = p_lgl(default = TRUE, tags = "train"),
        allowdrift = p_lgl(default = TRUE, tags = "train"),
        seasonal = p_lgl(default = TRUE, tags = "train")
      )

      super$initialize(
        id = "fcst.auto_arima",
        param_set = param_set,
        feature_types = c("logical", "integer", "numeric"),
        packages = c("mlr3forecast", "forecast"),
        label = "Auto ARIMA",
        # Must match the roxygen @name above so the help link resolves to this
        # learner's own page.
        man = "mlr3forecast::mlr_learners_fcst.auto_arima"
      )
    }
  ),

  private = list(
    # Largest value of the task's order column seen during training; used by
    # .is_newdata() to tell in-sample rows from future rows.
    .max_index = NULL,

    # Fits forecast::auto.arima() on the target column. Feature columns, if
    # any, are passed as the `xreg` matrix. Errors if the task is not ordered.
    .train = function(task) {
      if ("ordered" %nin% task$properties) {
        stopf("%s learner requires an ordered task.", self$id)
      }
      private$.max_index = max(task$data(cols = task$col_roles$order)[[1L]])
      pv = self$param_set$get_values(tags = "train")
      # NOTE(review): forecast::auto.arima() has no `weights` argument of its
      # own; confirm that forwarding weights is supported downstream.
      if ("weights" %in% task$properties) {
        pv = insert_named(pv, list(weights = task$weights$weight))
      }

      if (is_task_featureless(task)) {
        invoke(forecast::auto.arima,
          y = as.ts(task$data(cols = task$target_names)),
          .args = pv
        )
      } else {
        xreg = as.matrix(task$data(cols = task$feature_names))
        invoke(forecast::auto.arima,
          y = as.ts(task$data(cols = task$target_names)),
          xreg = xreg,
          .args = pv
        )
      }
    },

    # Returns a list with a `response` element. Rows detected as new data are
    # forecast ahead (horizon = number of rows, or one step per `xreg` row);
    # otherwise the model's fitted values are returned.
    .predict = function(task) {
      if (private$.is_newdata(task)) {
        if (is_task_featureless(task)) {
          prediction = invoke(forecast::forecast, self$model, h = length(task$row_ids))
        } else {
          newdata = as.matrix(task$data(cols = task$feature_names))
          prediction = invoke(forecast::forecast, self$model, xreg = newdata)
        }
        list(response = prediction$mean)
      } else {
        # assumes row ids coincide with positions in the training series --
        # TODO confirm
        prediction = stats::fitted(self$model)[task$row_ids]
        list(response = prediction)
      }
    },

    # TRUE when none of the task's order values equals the maximum order value
    # recorded during training.
    .is_newdata = function(task) {
      order_cols = task$col_roles$order
      idx = task$backend$data(rows = task$row_ids, cols = order_cols)[[1L]]
      !any(private$.max_index %in% idx)
    }
  )
)

# Register the learner class under the key "fcst.auto_arima".
# register_learner() is presumably defined in zzz.R (hence the collation
# directive below) -- TODO confirm.
#' @include zzz.R
register_learner("fcst.auto_arima", LearnerFcstAutoARIMA)
4 changes: 4 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Check whether a task has no feature columns.
#
# @param task Object exposing a `feature_names` field.
# @return `TRUE` if `task$feature_names` has length zero, `FALSE` otherwise.
is_task_featureless = function(task) {
  n_features = length(task$feature_names)
  n_features == 0L
}

18 changes: 12 additions & 6 deletions README.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -285,13 +285,19 @@ newdata = generate_newdata(task, 12L, "month")
newdata
```

### Example: WIP
### Example: Native Forecasting Learners

```{r, eval = FALSE}
task = tsk("airpassengers")
task$select(setdiff(task$feature_names, "date"))
learner = LearnerFcstARIMA$new()$train(task)
```{r}
task = tsk("airpassengers")$select(setdiff(task$feature_names, "date"))
learner = lrn("fcst.arima", order = c(2L, 1L, 2L))$train(task)
prediction = learner$predict(task, 140:144)
prediction$score(msr("regr.rmse"))
newdata = generate_newdata(task, 12L, "month")
learner$predict_newdata(newdata, task)
learner = lrn("fcst.auto_arima")$train(task)
prediction = learner$predict(task, 140:144)
prediction$score(msr("regr.rmse"))
newdata = generate_newdata(task, 12L, "month")
learner$predict(task, 140:144)
learner$predict_newdata(newdata, task)
```
Loading

0 comments on commit 1fb9125

Please sign in to comment.