Skip to content

Commit

Permalink
refactor / add eps param to IPCW pipeop
Browse files Browse the repository at this point in the history
  • Loading branch information
studener committed Aug 10, 2024
1 parent 43e3222 commit dde3f72
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 12 deletions.
32 changes: 22 additions & 10 deletions R/PipeOpTaskSurvClassifIPCW.R
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ PipeOpTaskSurvClassifIPCW = R6Class(
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(id = "trafotask_survclassif_IPCW") {
param_set = ps(
cutoff_time = p_dbl(0, default = NULL, special_vals = list(NULL))
cutoff_time = p_dbl(lower = 0, special_vals = list()),
eps = p_dbl(lower = 0, default = 1e-6)
)
super$initialize(
id = id,
Expand Down Expand Up @@ -65,22 +66,33 @@ PipeOpTaskSurvClassifIPCW = R6Class(
},

.train = function(input) {
data_trafo = input[[1]]$data()
data = input[[1]]$data()
time_var = input[[1]]$target_names[1]
status_var = input[[1]]$target_names[2]

cutoff_time = self$param_set$values$cutoff_time
eps = self$param_set$values$eps

if (cutoff_time >= max(data[[time_var]])) {
stop("Cutoff time must be smaller than the maximum event time.")
}

# transform data and calculate weights
data_trafo$time[data_trafo$time > cutoff_time] = cutoff_time
data_trafo$status[data_trafo$time == cutoff_time] = 1
data_trafo$status = (data_trafo$status != 1) * 1
times = data[[time_var]]
times[times > cutoff_time] = cutoff_time

task_new = TaskSurv$new(id = "ipcw", time = "time", event = "status", backend = data_trafo)
pred = lrn("surv.kaplan")$train(task_new)$predict(task_new)
weights = 1 / pred$data$distr[1,]
status = data[[status_var]]
status[times == cutoff_time] = 0

cens = survival::survfit(Surv(times, 1 - status) ~ 1)
cens$surv[length(cens$surv)] = cens$surv[length(cens$surv)-1]
cens$surv[cens$surv == 0] = eps

weights = rep(1/cens$surv, table(times))

# add weights to original data
time = status = NULL
data = input[[1]]$data()
data[["ipc_weights"]] = weights[as.character(data_trafo$time)]
data[["ipc_weights"]] = weights
data[status == 0 & time < cutoff_time, "ipc_weights" := 0]
data$status = factor(data$status, levels = c("0", "1"))

Expand Down
5 changes: 3 additions & 2 deletions tests/testthat/test_pipelines.R
Original file line number Diff line number Diff line change
Expand Up @@ -187,12 +187,13 @@ skip_if_not_installed("mlr3extralearners")
test_that("survtoclassif_IPCW", {
requireNamespace("mlr3extralearners")

pipe = mlr3pipelines::ppl("survtoclassif_IPCW", learner = lrn("classif.gam"))
pipe = mlr3pipelines::ppl("survtoclassif_IPCW", learner = lrn("classif.gam"),
cutoff_time = 50)
expect_class(pipe, "Graph")

## This needs fixing
grlrn = mlr3pipelines::ppl("survtoclassif_IPCW", learner = lrn("classif.gam"),
graph_learner = TRUE)
cutoff_time = 50, graph_learner = TRUE)
expect_class(grlrn, "GraphLearner")
grlrn$train(task)
p = grlrn$predict(task)
Expand Down

0 comments on commit dde3f72

Please sign in to comment.