Skip to content

Commit

Permalink
refactor: try to use custom col role instead of new class
Browse files Browse the repository at this point in the history
  • Loading branch information
m-muecke committed Jan 6, 2025
1 parent dbbd151 commit 45bb4cc
Show file tree
Hide file tree
Showing 19 changed files with 77 additions and 386 deletions.
2 changes: 0 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,13 @@ LazyData: true
Roxygen: list(markdown = TRUE, r6 = TRUE)
RoxygenNote: 7.3.2
Collate:
'DataBackendTimeSeries.R'
'ForecastLearner.R'
'zzz.R'
'LearnerARIMA.R'
'LearnerFcst.R'
'MeasureDirectional.R'
'ResamplingForecastCV.R'
'ResamplingForecastHoldout.R'
'TaskFcst.R'
'TaskFcstAirpassengers.R'
'as_task_fcst.R'
'bibentries.R'
2 changes: 0 additions & 2 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,10 @@
S3method(as_task_fcst,DataBackend)
S3method(as_task_fcst,TaskFcst)
S3method(as_task_fcst,data.frame)
export(DataBackendTimeSeries)
export(ForecastLearner)
export(LearnerFcstARIMA)
export(ResamplingForecastCV)
export(ResamplingForecastHoldout)
export(TaskFcst)
export(as_task_fcst)
import(R6)
import(checkmate)
Expand Down
40 changes: 0 additions & 40 deletions R/DataBackendTimeSeries.R

This file was deleted.

18 changes: 5 additions & 13 deletions R/ForecastLearner.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,14 @@ ForecastLearner = R6::R6Class("ForecastLearner",
#' The lag
lag = NULL,

#' @field trafo ([Graph])\cr
#' The task transformation
trafo = NULL,

#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
#' @param task ([Task])\cr
#' @param learner ([Learner])\cr
#' @param lag (`integer(1)`)\cr
#' @param trafo ([Graph])\cr
initialize = function(learner, lag, trafo = NULL) {
initialize = function(learner, lag) {
self$learner = assert_learner(as_learner(learner, clone = TRUE))
self$lag = assert_integerish(lag, lower = 1L, any.missing = FALSE, coerce = TRUE)
self$trafo = trafo
# self$trafo = as_graph(trafo, clone = TRUE)

super$initialize(
id = learner$id,
Expand Down Expand Up @@ -72,7 +65,6 @@ ForecastLearner = R6::R6Class("ForecastLearner",
preds = map(row_ids, function(i) {
new_x = private$.lag_transform(dt, target)[i]
pred = self$model$learner$predict_newdata(new_x)
# set is faster with DT
dt[i, (target) := pred$response]
pred
})
Expand All @@ -90,11 +82,11 @@ ForecastLearner = R6::R6Class("ForecastLearner",
lag = self$lag
nms = sprintf("%s_lag_%s", target, lag)
dt = copy(dt)
key = private$.task$key
if (is.null(key)) {
dt[, (nms) := shift(.SD, n = lag, type = "lag"), .SDcols = target]
key_coles = private$.task$col_roles$key
if (length(key_coles) > 0L) {
dt[, (nms) := shift(.SD, n = lag, type = "lag"), by = key_coles, .SDcols = target]
} else {
dt[, (nms) := shift(.SD, n = lag, type = "lag"), by = key, .SDcols = target]
dt[, (nms) := shift(.SD, n = lag, type = "lag"), .SDcols = target]
}
dt
},
Expand Down
Empty file removed R/PipeOp
Empty file.
9 changes: 7 additions & 2 deletions R/ResamplingForecastCV.R
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,12 @@ ResamplingForecastCV = R6Class("ResamplingForecastCV",
},

.sample_new = function(ids, task, ...) {
.NotYetImplemented()
if ("ordered" %nin% task$properties) {
stopf(
"Resampling '%s' requires an ordered task, but Task '%s' has no order.",
self$id, task$id
)
}

pars = self$param_set$get_values()
horizon = pars$horizon
Expand All @@ -113,7 +118,7 @@ ResamplingForecastCV = R6Class("ResamplingForecastCV",
fixed_window = pars$fixed_window

order_cols = task$col_roles$order
key_cols = task$key
key_cols = task$col_roles$key
has_key = length(key_cols) > 0L

tab = task$backend$data(
Expand Down
20 changes: 17 additions & 3 deletions R/ResamplingForecastHoldout.R
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,13 @@ ResamplingForecastHoldout = R6Class("ResamplingForecastHoldout",

private = list(
.sample_old = function(ids, ...) {
if ("ordered" %nin% task$properties) {
stopf(
"Resampling '%s' requires an ordered task, but Task '%s' has no order.",
self$id, task$id
)
}

pars = self$param_set$get_values()
ratio = pars$ratio
n = pars$n
Expand All @@ -91,6 +98,13 @@ ResamplingForecastHoldout = R6Class("ResamplingForecastHoldout",
},

.sample = function(ids, task, ...) {
if ("ordered" %nin% task$properties) {
stopf(
"Resampling '%s' requires an ordered task, but Task '%s' has no order.",
self$id, task$id
)
}

pars = self$param_set$get_values()
ratio = pars$ratio
n = pars$n
Expand All @@ -109,10 +123,10 @@ ResamplingForecastHoldout = R6Class("ResamplingForecastHoldout",
}

order_cols = task$col_roles$order
key_cols = task$key
has_key = !is.null(key_cols)
key_cols = task$col_roles$key
has_key_cols = length(key_cols) > 0L
tab = task$backend$data(rows = ids, cols = c(task$backend$primary_key, order_cols, key_cols))
if (has_key) {
if (has_key_cols) {
setnames(tab, c("row_id", "order", "key"))
setorderv(tab, c("key", "order"))
n_groups = length(unique(tab$key))
Expand Down
53 changes: 0 additions & 53 deletions R/TaskFcst.R

This file was deleted.

2 changes: 1 addition & 1 deletion R/TaskFcstAirpassengers.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ load_task_airpassengers = function(id = "airpassengers") {
setnames(dt, c("date", "passengers"))
b = as_data_backend(dt)

task = TaskFcst$new(
task = TaskRegr$new(
id = id,
backend = b,
target = "passengers",
Expand Down
15 changes: 10 additions & 5 deletions R/as_task_fcst.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,13 @@ as_task_fcst.DataBackend = function(x, target = NULL, index = NULL, key = NULL,

assert_choice(target, x$colnames)
assert_choice(index, x$colnames)
assert_choice(key, x$colnames, null.ok = TRUE)

task = TaskFcst$new(
id = id, backend = x, target = target, target = target, key = key, label = label, ...
)
task = TaskRegr$new(id = id, backend = x, target = target, label = label, ...)
task$col_roles$order = index
if (!is.null(key)) {
task$col_roles$key = key
}
task
}

Expand All @@ -58,11 +60,14 @@ as_task_fcst.data.frame = function(x, target = NULL, index = NULL, key = NULL, i
assert_choice(key, names(x), null.ok = TRUE)

ii = which(map_lgl(keep(x, is.double), anyInfinite))
if (length(ii)) {
if (length(ii) > 0L) {
warningf("Detected columns with unsupported Inf values in data: %s", str_collapse(names(ii)))
}

task = TaskFcst$new(id = id, backend = x, target = target, key = key, label = label)
task = TaskRegr$new(id = id, backend = x, target = target, label = label, ...)
task$col_roles$order = index
if (!is.null(key)) {
task$col_roles$key = key
}
task
}
5 changes: 4 additions & 1 deletion R/zzz.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ mlr3forecast_tasks = new.env()
mlr3forecast_learners = new.env()
mlr3forecast_measures = new.env()
mlr3forecast_feature_types = c(dte = "Date")
# mlr3forecast_col_roles = "key"
mlr3forecast_col_roles = "key"

named_union = function(x, y) set_names(union(x, y), union(names(x), names(y)))

Expand All @@ -38,6 +38,9 @@ register_mlr3 = function() {
mlr_reflections$learner_predict_types$fcst = mlr_reflections$learner_predict_types$regr
mlr_reflections$learner_properties$fcst = mlr_reflections$learner_properties$regr
mlr_reflections$task_col_roles$fcst = mlr_reflections$task_col_roles$regr
mlr_reflections$task_col_roles$regr = union(
mlr_reflections$task_col_roles$regr, mlr3forecast_col_roles
)
mlr_reflections$task_feature_types = named_union(
mlr_reflections$task_feature_types, mlr3forecast_feature_types
)
Expand Down
2 changes: 2 additions & 0 deletions README.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ graph = ppl("convert_types", "Date", "POSIXct") %>>%
)
)
task = graph$train(task)[[1L]]
task$col_roles$key = "state"
flrn = ForecastLearner$new(lrn("regr.ranger"), 1:3)$train(task)
prediction = flrn$predict(task, 4460:4464)
Expand Down Expand Up @@ -232,6 +233,7 @@ task = tsibbledata::aus_livestock |>
setorder(state, month) |>
as_task_fcst(target = "count", index = "month", key = "state")
task = graph$train(task)[[1L]]
task$col_roles$key = "state"
flrn = ForecastLearner$new(lrn("regr.ranger"), 1L)$train(task)
tab = task$backend$data(
rows = task$row_ids, cols = c(task$backend$primary_key, "month.year", "state")
Expand Down
Loading

0 comments on commit 45bb4cc

Please sign in to comment.