Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
strengejacke committed Mar 26, 2021
1 parent 1ba93e2 commit 314ab79
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 4 deletions.
20 changes: 19 additions & 1 deletion R/get_predictions_survival.R
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ get_predictions_survival <- function(model, fitfram, ci.lvl, type, terms, ...) {
clean_terms <- .clean_terms(terms)
ff <- fitfram[clean_terms]

do.call(rbind, lapply(seq_len(nrow(ff)), function(i) {
out <- do.call(rbind, lapply(seq_len(nrow(ff)), function(i) {
dat <- data.frame(
time = prdat$time,
predicted = pr[, i],
Expand All @@ -60,4 +60,22 @@ get_predictions_survival <- function(model, fitfram, ci.lvl, type, terms, ...) {

cbind(dat[, 1, drop = FALSE], dat2, dat[, 2:4])
}))

if (min(out$time, na.rm = TRUE) > 1) {
time <- 1
predicted <- ifelse(type == "surv", 1, 0)
conf.low <- ifelse(type == "surv", 1, 0)
conf.high <- ifelse(type == "surv", 1, 0)

dat <- expand.grid(lapply(out[clean_terms], unique))
names(dat) <- clean_terms

out <- rbind(
out,
cbind(time = 1, dat, predicted = predicted,
conf.low = conf.low, conf.high = conf.high)
)
}

out
}
2 changes: 1 addition & 1 deletion R/ggpredict.R
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ ggpredict <- function(model,
"zero_inflated_random" = "re.zi",
"zi_prob" = "zi.prob",
"survival" = "surv",
"cumulative_hazard" = "cumhaz" ,
"cumulative_hazard" = "cumhaz",
type
)

Expand Down
3 changes: 2 additions & 1 deletion R/post_processing_labels.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
n.trials = attr(data_grid, "n.trials", exact = TRUE),
prediction.interval = prediction.interval,
condition = condition,
ci.lvl = ci.lvl
ci.lvl = ci.lvl,
type = type
)
}
3 changes: 2 additions & 1 deletion R/utils_set_attr.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#' @importFrom insight link_inverse link_function
.set_attributes_and_class <- function(data, model, t.title, x.title, y.title, l.title, legend.labels, x.axis.labels, model_info, constant.values = NULL, terms = NULL, original_terms = NULL, at_list = NULL, n.trials = NULL, prediction.interval = NULL, condition = NULL, ci.lvl = .95) {
.set_attributes_and_class <- function(data, model, t.title, x.title, y.title, l.title, legend.labels, x.axis.labels, model_info, constant.values = NULL, terms = NULL, original_terms = NULL, at_list = NULL, n.trials = NULL, prediction.interval = NULL, condition = NULL, ci.lvl = .95, type = NULL) {
# check correct labels
if (!is.null(x.axis.labels) && length(x.axis.labels) != length(stats::na.omit(unique(data$x))))
x.axis.labels <- as.vector(sort(stats::na.omit(unique(data$x))))
Expand All @@ -23,6 +23,7 @@
attr(data, "prediction.interval") <- prediction.interval
attr(data, "condition") <- condition
attr(data, "ci.lvl") <- ci.lvl
attr(data, "type") <- type
attr(data, "response.name") <- insight::find_response(model)

# remember fit family
Expand Down

0 comments on commit 314ab79

Please sign in to comment.