Skip to content

Commit 025cc59

Browse files
hfricksimonpcouch
andauthored
Enable fit_xy() for censored regression (#829)
* enable `fit_xy()` for censored regression * update news * Update NEWS.md Co-authored-by: Simon P. Couch <simonpatrickcouch@gmail.com> Co-authored-by: Simon P. Couch <simonpatrickcouch@gmail.com>
1 parent ed34674 commit 025cc59

File tree

6 files changed

+10
-25
lines changed

6 files changed

+10
-25
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: parsnip
22
Title: A Common API to Modeling and Analysis Functions
3-
Version: 1.0.2.9000
3+
Version: 1.0.2.9001
44
Authors@R: c(
55
person("Max", "Kuhn", , "max@rstudio.com", role = c("aut", "cre")),
66
person("Davis", "Vaughan", , "davis@rstudio.com", role = "aut"),

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# parsnip (development version)
22

3+
* The matrix interface for fitting `fit_xy()` now works for the `"censored regression"` mode (#829).
4+
35
# parsnip 1.0.2
46

57
* A bagged neural network model was added (`bag_mlp()`). Engine implementations will live in the baguette package.

R/fit.R

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -228,10 +228,6 @@ fit_xy.model_spec <-
228228
rlang::abort("Please set the mode in the model specification.")
229229
}
230230

231-
if (object$mode == "censored regression") {
232-
rlang::abort("Models for censored regression must use the formula interface.")
233-
}
234-
235231
if (inherits(object, "surv_reg")) {
236232
rlang::abort("Survival models must use the formula interface.")
237233
}
@@ -408,9 +404,10 @@ check_xy_interface <- function(x, y, cl, model) {
408404
inher(x, c("data.frame", "matrix"), cl)
409405
}
410406

411-
# `y` can be a vector (which is not a class), or a factor (which is not a vector)
407+
# `y` can be a vector (which is not a class), or a factor or
408+
# Surv object (which are not vectors)
412409
if (!is.null(y) && !is.vector(y))
413-
inher(y, c("data.frame", "matrix", "factor"), cl)
410+
inher(y, c("data.frame", "matrix", "factor", "Surv"), cl)
414411

415412
# rule out spark data sets that don't use the formula interface
416413
if (inherits(x, "tbl_spark") | inherits(y, "tbl_spark"))

R/misc.R

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,10 @@ check_outcome <- function(y, spec) {
330330
if (!all(map_lgl(y, is.factor))) {
331331
rlang::abort("For a classification model, the outcome should be a factor.")
332332
}
333+
} else if (spec$mode == "censored regression") {
334+
if (!inherits(y, "Surv")) {
335+
rlang::abort("For a censored regression model, the outcome should be a `Surv` object.")
336+
}
333337
}
334338
invisible(NULL)
335339
}

tests/testthat/test_proportional_hazards.R

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,3 @@ test_that("updating", {
99
test_that("bad input", {
1010
expect_error(proportional_hazards(mode = ", classification"))
1111
})
12-
13-
test_that("wrong fit interface", {
14-
expect_error(
15-
expect_message(
16-
proportional_hazards() %>% fit_xy()
17-
),
18-
"must use the formula interface"
19-
)
20-
})

tests/testthat/test_survival_reg.R

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,3 @@ test_that("updating", {
3030
test_that("bad input", {
3131
expect_error(survival_reg(mode = ", classification"))
3232
})
33-
34-
test_that("wrong fit interface", {
35-
expect_error(
36-
expect_message(
37-
survival_reg() %>% fit_xy()
38-
),
39-
"must use the formula interface"
40-
)
41-
})

0 commit comments

Comments
 (0)