diff --git a/tests/testthat/test_PredictionSurv.R b/tests/testthat/test_PredictionSurv.R index d07d53207..ba7ba5171 100644 --- a/tests/testthat/test_PredictionSurv.R +++ b/tests/testthat/test_PredictionSurv.R @@ -33,39 +33,55 @@ test_that("c", { expect_true(length(times1) == length(times2)) expect_false(all(times1 == times2)) - # so internal prediction data must be a distribution and not a survival matrix pred = do.call(c, preds) expect_prediction_surv(pred) - expect_class(pred$data$distr, "Distribution") + surv_mat = pred$data$distr + expect_class(surv_mat, "matrix") + # check that time points are properly combined + times = as.integer(colnames(surv_mat)) + expect_true(all(times == sort(union(times1, times2), decreasing = F))) # data.table conversion dt = as.data.table(pred) expect_data_table(dt, nrows = task$nrow, ncols = 5L, any.missing = FALSE) # different number of time points + # add extra time point on the 2nd prediction object preds2 = rr$predictions() preds2[[2]]$data$distr = cbind(distr2, matrix(data = rep(0.3, 10), ncol = 1, dimnames = list(NULL, 108))) - expect_false(ncol(preds2[[1]]$data$distr) == ncol(preds2[[2]]$data$distr)) + distr1 = preds2[[1]]$data$distr + distr2 = preds2[[2]]$data$distr + times1 = as.integer(colnames(distr1)) + times2 = as.integer(colnames(distr2)) + expect_false(length(times1) == length(times2)) + pred2 = do.call(c, preds2) expect_prediction_surv(pred2) - expect_class(pred2$data$distr, "Distribution") + surv_mat2 = pred2$data$distr + expect_class(surv_mat2, "matrix") + # check that time points are properly combined + times = as.integer(colnames(surv_mat2)) + expect_true(all(times == sort(union(times1, times2), decreasing = F))) # combining survival arrays - arr_preds = mlr3misc::map(preds, reshape_distr_to_3d) + arr_preds = mlr3misc::map(preds2, reshape_distr_to_3d) arr_pred = do.call(c, arr_preds) expect_prediction_surv(arr_pred) expect_class(arr_pred$data$distr, "array") expect_class(arr_pred$distr, "Arrdist") - # check that time points are combined properly + # check that time points are properly combined + times1 = as.integer(colnames(arr_preds[[1]]$data$distr)) + times2 = as.integer(colnames(arr_preds[[2]]$data$distr)) + times = as.integer(colnames(arr_pred$data$distr)) expect_equal(as.integer(colnames(arr_pred$data$distr)), - unique(sort(c(times1, times2)))) + sort(union(times1, times2), decreasing = F)) p1 = lrn("surv.kaplan")$train(task)$predict(task) p2 = suppressWarnings(lrn("surv.coxph")$train(task))$predict(task) expect_error(c(p1, p2), "Cannot combine") - # combining distr predictions with exactly the same time points + # combining predictions with exactly the same time points p1 = lrn("surv.kaplan")$train(task)$predict(task) p2 = p1$clone() expect_equal(length(c(p1, p2, keep_duplicates = TRUE)$row_ids), 40) @@ -81,6 +97,55 @@ test_that("c", { arr_pred = do.call(c, arr_preds) expect_class(arr_pred$data$distr, "array") # combination is an array expect_equal(colnames(arr_pred$data$distr), colnames(arr_p1$data$distr)) # same time points + + # combining distr6::Distribution objects of the same type + # Matdist + p1$data$distr = p1$distr + p2$data$distr = p2$distr + preds2 = list(p1, p2) + pred2 = do.call(c, preds2) + expect_class(pred2$data$distr, "matrix") + expect_true(all(pred2$data$distr == p12$data$distr)) + + # Arrdist + arr_p1$data$distr = arr_p1$distr + arr_p2$data$distr = arr_p2$distr + arr_preds2 = list(arr_p1, arr_p2) + arr_pred2 = do.call(c, arr_preds2) + expect_class(arr_pred2$data$distr, "array") + expect_true(all(arr_pred2$data$distr == arr_pred$data$distr)) + + # combining distr6::Distribution objects of different types + mix_preds = list(p1, arr_p2) # Matdist and Arrdist + expect_error(supressWarnings(do.call(c, mix_preds))) + + # combine survival matrix and Matdist (matrix converts to a Matdist) + p1 = lrn("surv.kaplan")$train(task)$predict(task) + p2 = p1$clone() + p2$data$distr = p2$distr + preds = list(p1, p2) + expect_prediction_surv(do.call(c, preds)) + + # combine survival array and Matdist (array converts to an Arrdist) + preds = list(reshape_distr_to_3d(p1), p2) + expect_error(do.call(c, preds)) + + # combine survival matrix and Arrdist (matrix converts to a Matdist) + p2 = p1$clone() + p2 = reshape_distr_to_3d(p2) + p2$data$distr = p2$distr + preds = list(p1, p2) + expect_error(supressWarnings(do.call(c, preds))) + + # combine survival array and Arrdist (array converts to an Arrdist) + preds = list(reshape_distr_to_3d(p1), p2) + expect_prediction_surv(do.call(c, preds)) + + # combine survival matrix and array + p2 = p1$clone() + p2 = reshape_distr_to_3d(p2) + preds = list(p1, p2) + expect_error(do.call(c, preds), "Cannot combine") }) test_that("data.frame roundtrip", {