Skip to content

Commit

Permalink
Merge pull request #367 from jemus42/fix-collate-order
Browse files Browse the repository at this point in the history
Fix Collate: order by splitting zzz.R
  • Loading branch information
bblodfon authored Feb 19, 2024
2 parents a6badfe + 1e8df14 commit 4bc2513
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 115 deletions.
3 changes: 2 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ NeedsCompilation: no
Roxygen: list(markdown = TRUE, r6 = TRUE)
RoxygenNote: 7.3.1
Collate:
'aaa.R'
'LearnerDens.R'
'zzz.R'
'LearnerDensHistogram.R'
'LearnerDensKDE.R'
'LearnerSurv.R'
Expand Down Expand Up @@ -162,3 +162,4 @@ Collate:
'scoring_rule_erv.R'
'surv_measures.R'
'surv_return.R'
'zzz.R'
115 changes: 115 additions & 0 deletions R/aaa.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# These elements need to be at the top of the Collate: order!

mlr3proba_learners = new.env()
mlr3proba_tasks = new.env()
mlr3proba_measures = new.env()
mlr3proba_task_gens = new.env()
mlr3proba_pipeops = new.env()
mlr3proba_graphs = new.env()

register_learner = function(name, constructor) {
assert_class(constructor, "R6ClassGenerator")
if (name %in% names(mlr3proba_learners)) stopf("learner %s registered twice", name)
mlr3proba_learners[[name]] = constructor
}

register_task = function(name, constructor) {
if (name %in% names(mlr3proba_tasks)) stopf("task %s registered twice", name)
mlr3proba_tasks[[name]] = constructor
}

register_measure = function(name, constructor) {
if (name %in% names(mlr3proba_measures)) stopf("measure %s registered twice", name)
mlr3proba_measures[[name]] = constructor
}

register_task_generator = function(name, constructor) {
if (name %in% names(mlr3proba_task_gens)) stopf("task generator %s registered twice", name)
mlr3proba_task_gens[[name]] = constructor
}

register_pipeop = function(name, constructor) {
if (name %in% names(mlr3proba_pipeops)) stopf("pipeop %s registered twice", name)
mlr3proba_pipeops[[name]] = constructor
}

register_graph = function(name, constructor) {
if (name %in% names(mlr3proba_graphs)) stopf("graph %s registered twice", name)
mlr3proba_graphs[[name]] = constructor
}

register_reflections = function() {
x = utils::getFromNamespace("mlr_reflections", ns = "mlr3")

# task
x$task_types = x$task_types[!c("surv", "dens")]
x$task_types = setkeyv(rbind(x$task_types, rowwise_table(
~type, ~package, ~task, ~learner, ~prediction, ~prediction_data, ~measure,
"surv", "mlr3proba", "TaskSurv", "LearnerSurv", "PredictionSurv", "PredictionDataSurv", "MeasureSurv",
"dens", "mlr3proba", "TaskDens", "LearnerDens", "PredictionDens", "PredictionDataDens", "MeasureDens"
)), "type")

x$task_col_roles$surv = x$task_col_roles$regr
x$task_col_roles$dens = c("feature", "target", "label", "order", "group", "weight", "stratum")
x$task_properties$surv = x$task_properties$regr
x$task_properties$dens = x$task_properties$regr

# learner
x$learner_properties$surv = x$learner_properties$regr
x$learner_properties$dens = x$learner_properties$regr
x$learner_predict_types$surv = list(
crank = c("crank", "lp", "distr", "response"),
distr = c("crank", "lp", "distr", "response"),
lp = c("crank", "lp", "distr", "response"),
response = c("crank", "lp", "distr", "response")
)
x$learner_predict_types$dens = list(
pdf = c("pdf", "cdf", "distr"),
cdf = c("pdf", "cdf", "distr"),
distr = c("pdf", "cdf", "distr")
)

# measure
x$measure_properties$surv = x$measure_properties$regr
x$measure_properties$dens = x$measure_properties$regr
x$default_measures$surv = "surv.cindex"
x$default_measures$dens = "dens.logloss"
}

register_mlr3 = function() {
# reflections
register_reflections()

# tasks
mlr_tasks = utils::getFromNamespace("mlr_tasks", ns = "mlr3")
iwalk(as.list(mlr3proba_tasks), function(obj, name) mlr_tasks$add(name, obj)) # nolint

# task generators
mlr_task_gens = utils::getFromNamespace("mlr_task_generators", ns = "mlr3")
iwalk(as.list(mlr3proba_task_gens), function(obj, name) mlr_task_gens$add(name, obj)) # nolint

# learners
mlr_learners = utils::getFromNamespace("mlr_learners", ns = "mlr3")
iwalk(as.list(mlr3proba_learners), function(obj, name) mlr_learners$add(name, obj)) # nolint

# measures
mlr_measures = utils::getFromNamespace("mlr_measures", ns = "mlr3")
iwalk(as.list(mlr3proba_measures), function(obj, name) mlr_measures$add(name, obj)) # nolint
}

register_mlr3pipelines = function() {
mlr3pipelines::add_class_hierarchy_cache(c("PredictionSurv", "Prediction"))

# pipeops
mlr_pipeops = utils::getFromNamespace("mlr_pipeops", ns = "mlr3pipelines")
iwalk(as.list(mlr3proba_pipeops), function(obj, name) mlr_pipeops$add(name, obj)) # nolint

# Breslow needs another argument so we do it manually
mlr_pipeops$add("breslowcompose", PipeOpBreslow, list(R6Class("Learner",
public = list(id = "breslowcompose", task_type = "surv", predict_types = "lp",
packages = c("mlr3", "mlr3proba"), param_set = ps()))$new()))

# graphs
mlr_graphs = utils::getFromNamespace("mlr_graphs", ns = "mlr3pipelines")
iwalk(as.list(mlr3proba_graphs), function(obj, name) mlr_graphs$add(name, obj)) # nolint
}
115 changes: 1 addition & 114 deletions R/zzz.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ NULL
#' @importFrom stats reformulate model.matrix model.frame sd predict complete.cases density
#' @importFrom survival Surv
#' @importFrom mlr3viz fortify
#' @importFrom utils getFromNamespace
"_PACKAGE"
# nolint end

Expand All @@ -39,120 +40,6 @@ utils::globalVariables(c(
"ShortName", "ClassName", "missing", "task", "value", "variable", "y"
))

mlr3proba_learners = new.env()
mlr3proba_tasks = new.env()
mlr3proba_measures = new.env()
mlr3proba_task_gens = new.env()
mlr3proba_pipeops = new.env()
mlr3proba_graphs = new.env()

register_learner = function(name, constructor) {
assert_class(constructor, "R6ClassGenerator")
if (name %in% names(mlr3proba_learners)) stopf("learner %s registered twice", name)
mlr3proba_learners[[name]] = constructor
}

register_task = function(name, constructor) {
if (name %in% names(mlr3proba_tasks)) stopf("task %s registered twice", name)
mlr3proba_tasks[[name]] = constructor
}

register_measure = function(name, constructor) {
if (name %in% names(mlr3proba_measures)) stopf("measure %s registered twice", name)
mlr3proba_measures[[name]] = constructor
}

register_task_generator = function(name, constructor) {
if (name %in% names(mlr3proba_task_gens)) stopf("task generator %s registered twice", name)
mlr3proba_task_gens[[name]] = constructor
}

register_pipeop = function(name, constructor) {
if (name %in% names(mlr3proba_pipeops)) stopf("pipeop %s registered twice", name)
mlr3proba_pipeops[[name]] = constructor
}

register_graph = function(name, constructor) {
if (name %in% names(mlr3proba_graphs)) stopf("graph %s registered twice", name)
mlr3proba_graphs[[name]] = constructor
}

register_reflections = function() {
x = utils::getFromNamespace("mlr_reflections", ns = "mlr3")

# task
x$task_types = x$task_types[!c("surv", "dens")]
x$task_types = setkeyv(rbind(x$task_types, rowwise_table(
~type, ~package, ~task, ~learner, ~prediction, ~prediction_data, ~measure,
"surv", "mlr3proba", "TaskSurv", "LearnerSurv", "PredictionSurv", "PredictionDataSurv", "MeasureSurv",
"dens", "mlr3proba", "TaskDens", "LearnerDens", "PredictionDens", "PredictionDataDens", "MeasureDens"
)), "type")

x$task_col_roles$surv = x$task_col_roles$regr
x$task_col_roles$dens = c("feature", "target", "label", "order", "group", "weight", "stratum")
x$task_properties$surv = x$task_properties$regr
x$task_properties$dens = x$task_properties$regr

# learner
x$learner_properties$surv = x$learner_properties$regr
x$learner_properties$dens = x$learner_properties$regr
x$learner_predict_types$surv = list(
crank = c("crank", "lp", "distr", "response"),
distr = c("crank", "lp", "distr", "response"),
lp = c("crank", "lp", "distr", "response"),
response = c("crank", "lp", "distr", "response")
)
x$learner_predict_types$dens = list(
pdf = c("pdf", "cdf", "distr"),
cdf = c("pdf", "cdf", "distr"),
distr = c("pdf", "cdf", "distr")
)

# measure
x$measure_properties$surv = x$measure_properties$regr
x$measure_properties$dens = x$measure_properties$regr
x$default_measures$surv = "surv.cindex"
x$default_measures$dens = "dens.logloss"
}

register_mlr3 = function() {
# reflections
register_reflections()

# tasks
mlr_tasks = utils::getFromNamespace("mlr_tasks", ns = "mlr3")
iwalk(as.list(mlr3proba_tasks), function(obj, name) mlr_tasks$add(name, obj)) # nolint

# task generators
mlr_task_gens = utils::getFromNamespace("mlr_task_generators", ns = "mlr3")
iwalk(as.list(mlr3proba_task_gens), function(obj, name) mlr_task_gens$add(name, obj)) # nolint

# learners
mlr_learners = utils::getFromNamespace("mlr_learners", ns = "mlr3")
iwalk(as.list(mlr3proba_learners), function(obj, name) mlr_learners$add(name, obj)) # nolint

# measures
mlr_measures = utils::getFromNamespace("mlr_measures", ns = "mlr3")
iwalk(as.list(mlr3proba_measures), function(obj, name) mlr_measures$add(name, obj)) # nolint
}

register_mlr3pipelines = function() {
mlr3pipelines::add_class_hierarchy_cache(c("PredictionSurv", "Prediction"))

# pipeops
mlr_pipeops = utils::getFromNamespace("mlr_pipeops", ns = "mlr3pipelines")
iwalk(as.list(mlr3proba_pipeops), function(obj, name) mlr_pipeops$add(name, obj)) # nolint

# Breslow needs another argument so we do it manually
mlr_pipeops$add("breslowcompose", PipeOpBreslow, list(R6Class("Learner",
public = list(id = "breslowcompose", task_type = "surv", predict_types = "lp",
packages = c("mlr3", "mlr3proba"), param_set = ps()))$new()))

# graphs
mlr_graphs = utils::getFromNamespace("mlr_graphs", ns = "mlr3pipelines")
iwalk(as.list(mlr3proba_graphs), function(obj, name) mlr_graphs$add(name, obj)) # nolint
}

.onLoad = function(libname, pkgname) {
register_mlr3()
if (requireNamespace("mlr3pipelines", quietly = TRUE)) {
Expand Down

0 comments on commit 4bc2513

Please sign in to comment.