Skip to content

Commit

Permalink
update tests (combining different prediction types)
Browse files Browse the repository at this point in the history
  • Loading branch information
bblodfon committed Oct 13, 2023
1 parent 0ae56d0 commit 707482e
Showing 1 changed file with 73 additions and 8 deletions.
81 changes: 73 additions & 8 deletions tests/testthat/test_PredictionSurv.R
Original file line number Diff line number Diff line change
Expand Up @@ -33,39 +33,55 @@ test_that("c", {
expect_true(length(times1) == length(times2))
expect_false(all(times1 == times2))

# so internal prediction data must be a distribution and not a survival matrix
pred = do.call(c, preds)
expect_prediction_surv(pred)
expect_class(pred$data$distr, "Distribution")
surv_mat = pred$data$distr
expect_class(surv_mat, "matrix")
# check that time points are properly combined
times = as.integer(colnames(surv_mat))
expect_true(all(times == sort(union(times1, times2), decreasing = F)))

# data.table conversion
dt = as.data.table(pred)
expect_data_table(dt, nrows = task$nrow, ncols = 5L, any.missing = FALSE)

# different number of time points
# add extra time point on the 2nd prediction object
preds2 = rr$predictions()
preds2[[2]]$data$distr = cbind(distr2,
matrix(data = rep(0.3, 10), ncol = 1, dimnames = list(NULL, 108)))
expect_false(ncol(preds2[[1]]$data$distr) == ncol(preds2[[2]]$data$distr))
distr1 = preds2[[1]]$data$distr
distr2 = preds2[[2]]$data$distr
times1 = as.integer(colnames(distr1))
times2 = as.integer(colnames(distr2))
expect_false(length(times1) == length(times2))

pred2 = do.call(c, preds2)
expect_prediction_surv(pred2)
expect_class(pred2$data$distr, "Distribution")
surv_mat2 = pred2$data$distr
expect_class(surv_mat2, "matrix")
# check that time points are properly combined
times = as.integer(colnames(surv_mat2))
expect_true(all(times == sort(union(times1, times2), decreasing = F)))

# combining survival arrays
arr_preds = mlr3misc::map(preds, reshape_distr_to_3d)
arr_preds = mlr3misc::map(preds2, reshape_distr_to_3d)
arr_pred = do.call(c, arr_preds)
expect_prediction_surv(arr_pred)
expect_class(arr_pred$data$distr, "array")
expect_class(arr_pred$distr, "Arrdist")
# check that time points are combined properly
# check that time points are properly combined
times1 = as.integer(colnames(arr_preds[[1]]$data$distr))
times2 = as.integer(colnames(arr_preds[[2]]$data$distr))
times = as.integer(colnames(arr_pred$data$distr))
expect_equal(as.integer(colnames(arr_pred$data$distr)),
unique(sort(c(times1, times2))))
sort(union(times1, times2), decreasing = F))

p1 = lrn("surv.kaplan")$train(task)$predict(task)
p2 = suppressWarnings(lrn("surv.coxph")$train(task))$predict(task)
expect_error(c(p1, p2), "Cannot combine")

# combining distr predictions with exactly the same time points
# combining predictions with exactly the same time points
p1 = lrn("surv.kaplan")$train(task)$predict(task)
p2 = p1$clone()
expect_equal(length(c(p1, p2, keep_duplicates = TRUE)$row_ids), 40)
Expand All @@ -81,6 +97,55 @@ test_that("c", {
arr_pred = do.call(c, arr_preds)
expect_class(arr_pred$data$distr, "array") # combination is an array
expect_equal(colnames(arr_pred$data$distr), colnames(arr_p1$data$distr)) # same time points

# combining distr6::Distribution objects of the same type
# Matdist
p1$data$distr = p1$distr
p2$data$distr = p2$distr
preds2 = list(p1, p2)
pred2 = do.call(c, preds2)
expect_class(pred2$data$distr, "matrix")
expect_true(all(pred2$data$distr == p12$data$distr))

# Arrdist
arr_p1$data$distr = arr_p1$distr
arr_p2$data$distr = arr_p2$distr
arr_preds2 = list(arr_p1, arr_p2)
arr_pred2 = do.call(c, arr_preds2)
expect_class(arr_pred2$data$distr, "array")
expect_true(all(arr_pred2$data$distr == arr_pred$data$distr))

# combining distr6::Distribution objects of different types
mix_preds = list(p1, arr_p2) # Matdist and Arrdist
expect_error(supressWarnings(do.call(c, mix_preds)))

# combine survival matrix and Matdist (matrix converts to a Matdist)
p1 = lrn("surv.kaplan")$train(task)$predict(task)
p2 = p1$clone()
p2$data$distr = p2$distr
preds = list(p1, p2)
expect_prediction_surv(do.call(c, preds))

# combine survival array and Matdist (array converts to an Arrdist)
preds = list(reshape_distr_to_3d(p1), p2)
expect_error(do.call(c, preds))

# combine survival matrix and Arrdist (matrix converts to a Matdist)
p2 = p1$clone()
p2 = reshape_distr_to_3d(p2)
p2$data$distr = p2$distr
preds = list(p1, p2)
expect_error(supressWarnings(do.call(c, preds)))

# combine survival array and Arrdist (array converts to an Arrdist)
preds = list(reshape_distr_to_3d(p1), p2)
expect_prediction_surv(do.call(c, preds))

# combine survival matrix and array
p2 = p1$clone()
p2 = reshape_distr_to_3d(p2)
preds = list(p1, p2)
expect_error(do.call(c, preds), "Cannot combine")
})

test_that("data.frame roundtrip", {
Expand Down

0 comments on commit 707482e

Please sign in to comment.