Stratified Cox regression #1489
Maybe this is simpler to implement than it looks at first. Here is a toy example that hacks the generated brms code so that the model ends up fitting one M-spline baseline hazard per stratum. This is what I am currently testing. Things are still hard-coded, but I think it can be generalised easily:
## This script needs brms 2.21.0. The code below runs on DaVinci with R 4.3.1
library(simsurv)
library(survival)
library(dplyr)
library(brms)
options(brms.backend="cmdstanr", cmdstanr_write_stan_file_dir=file.path("~/brms-cache"))
# create the cache directory if it does not exist yet
dir.create(file.path("~/brms-cache"), showWarnings = FALSE)
set.seed(456456)
# simulated data from the rstanarm::stan_surv example
covs <- data.frame(id = 1:200, trt = stats::rbinom(200, 1L, 0.5), l=rnorm(200)) %>%
mutate(stratum=cut(l, c(-Inf, -0.5, 0.5, Inf)), stratum_index=as.integer(stratum))
d1 <- simsurv(lambdas = 0.1,
gammas = 1.5,
betas = c(trt = -0.5, l=-0.2),
x = covs,
maxt = 5)
d1 <- merge(d1, covs)
## now fit the Cox model with a different baseline hazard for each
## stratum:
fit_coxph <- coxph(Surv(eventtime, status) ~ trt + strata(stratum), data = d1)
summary(fit_coxph)
## the equivalent in brms would be great to have:
fit_brm <- brm(eventtime | cens(1 - status) ~ 1 + trt,
data = d1, family = brmsfamily("cox"), refresh=0, cores=4)
summary(fit_brm)
## with a bit of extra magic code this should work ok:
## define one baseline hazard (a simplex of spline weights) per stratum
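## declare the stratum index explicitly with the new Stan array syntax
sv <- stanvar(d1$stratum_index, name="stratum",
              scode=paste0("array[", nrow(d1), "] int stratum;")) +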
stanvar(name="baseline_strata_simplex",
scode="array[3] simplex[Kbhaz] sbhaz_stratum;",
block="parameters") +
stanvar(scode="lprior += dirichlet_lpdf(sbhaz_stratum[1] | con_sbhaz);", block="tparameters") +
stanvar(scode="lprior += dirichlet_lpdf(sbhaz_stratum[2] | con_sbhaz);", block="tparameters") +
stanvar(scode="lprior += dirichlet_lpdf(sbhaz_stratum[3] | con_sbhaz);", block="tparameters") +
stanvar(scode="
// make baseline function stratum specific
for(n in 1:N) {
bhaz[n] = Zbhaz[n] * sbhaz_stratum[stratum[n]];
cbhaz[n] = Zcbhaz[n] * sbhaz_stratum[stratum[n]];
}
",
block="likelihood")
fit_brm_strata <- brm(eventtime | cens(1 - status) ~ 1 + trt,
stanvars=sv,
data = d1, family = brmsfamily("cox"), refresh=0, cores=4)
summary(fit_brm_strata)
## here one can check that we did indeed estimate per stratum a
## baseline hazard function on its own:
stancode(fit_brm_strata)
The relevant Stan code then becomes:
...
parameters {
vector[Kc] b; // regression coefficients
real Intercept; // temporary intercept for centered predictors
simplex[Kbhaz] sbhaz; // baseline coefficients
array[3] simplex[Kbhaz] sbhaz_stratum;
}
transformed parameters {
real lprior = 0; // prior contributions to the log posterior
lprior += dirichlet_lpdf(sbhaz_stratum[1] | con_sbhaz);
lprior += dirichlet_lpdf(sbhaz_stratum[2] | con_sbhaz);
lprior += dirichlet_lpdf(sbhaz_stratum[3] | con_sbhaz);
lprior += student_t_lpdf(Intercept | 3, 1.5, 2.5);
lprior += dirichlet_lpdf(sbhaz | con_sbhaz);
}
model {
// likelihood including constants
if (!prior_only) {
// compute values of baseline function
vector[N] bhaz = Zbhaz * sbhaz;
// compute values of cumulative baseline function
vector[N] cbhaz = Zcbhaz * sbhaz;
// initialize linear predictor term
vector[N] mu = rep_vector(0.0, N);
// make baseline function stratum specific
for(n in 1:N) {
bhaz[n] = Zbhaz[n] * sbhaz_stratum[stratum[n]];
cbhaz[n] = Zcbhaz[n] * sbhaz_stratum[stratum[n]];
}
mu += Intercept + Xc * b;
for (n in 1:N) {
// special treatment of censored data
if (cens[n] == 0) {
target += cox_log_lpdf(Y[n] | mu[n], bhaz[n], cbhaz[n]);
} else if (cens[n] == 1) {
target += cox_log_lccdf(Y[n] | mu[n], bhaz[n], cbhaz[n]);
} else if (cens[n] == -1) {
target += cox_log_lcdf(Y[n] | mu[n], bhaz[n], cbhaz[n]);
}
}
}
// priors including constants
target += lprior;
}
...
BTW, by default stanvar() declares the stratum index as int stratum[200];, which looks like a bug to me, as it should use the new array syntax here (array[200] int stratum;).
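To make the stratified part of the model block concrete, here is a small base-R sketch (plain R, not brms or Stan code) of what the loop over n computes. The names Zbhaz, Zcbhaz, stratum and mu mirror the Stan code, but all dimensions and values are made up, and W is a hypothetical matrix holding one simplex per stratum (the role of sbhaz_stratum). The log-likelihood lines assume my reading of brms' cox_log_lpdf and cox_log_lccdf: an event contributes log(bhaz) + mu - cbhaz * exp(mu) and a right-censored observation contributes -cbhaz * exp(mu).

```r
## Base-R sketch (not brms/Stan code) of the stratified baseline loop.
## All inputs are made-up stand-ins: Zbhaz/Zcbhaz play the role of the
## spline basis matrices, W (hypothetical) holds one simplex per stratum.
set.seed(1)
N <- 6  # observations
K <- 4  # spline basis functions (Kbhaz)
S <- 3  # strata
Zbhaz  <- matrix(runif(N * K), nrow = N, ncol = K)
Zcbhaz <- matrix(runif(N * K), nrow = N, ncol = K)
W <- matrix(rexp(S * K), nrow = S, ncol = K)
W <- W / rowSums(W)  # normalise each row to sum to one, like a simplex
stratum <- c(1L, 2L, 3L, 1L, 2L, 3L)

## loop version, mirroring the Stan code observation by observation
bhaz_loop <- numeric(N)
cbhaz_loop <- numeric(N)
for (n in seq_len(N)) {
  bhaz_loop[n]  <- sum(Zbhaz[n, ]  * W[stratum[n], ])
  cbhaz_loop[n] <- sum(Zcbhaz[n, ] * W[stratum[n], ])
}

## vectorised version: pick each observation's stratum simplex, sum rowwise
W_obs <- W[stratum, , drop = FALSE]  # N x K
bhaz  <- rowSums(Zbhaz  * W_obs)
cbhaz <- rowSums(Zcbhaz * W_obs)

## log-likelihood contributions (assumed to match brms' cox_log_* functions)
mu <- rnorm(N)                 # made-up linear predictor on the log-hazard scale
status <- c(1, 0, 1, 1, 0, 1)  # 1 = event, 0 = right-censored
loglik <- ifelse(status == 1,
                 log(bhaz) + mu - cbhaz * exp(mu),  # event
                 -cbhaz * exp(mu))                  # right-censored
print(round(cbind(stratum, bhaz, cbhaz, loglik), 3))
```

The loop and the vectorised version give the same values; the only stratum-specific part is which simplex row of W each observation is multiplied with, while mu (and hence the treatment effect) is shared across strata.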
Just a quick note that we may want to add a new addition term
This feature is now implemented. Here is an example:
fit <- brm(
  time | cens(censored) + bhaz(gr = sex) ~ age * sex + disease + (1|patient),
  data = kidney,
  family = cox()
)
It would be great to allow for Cox regression models with a different baseline hazard function for each stratum of the data. Here is an example:
Created on 2023-04-27 with reprex v2.0.2
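For reference, the stratified model requested here can be written in standard proportional hazards notation as below. The spline form of the baseline follows the M-spline parametrisation brms uses, with one simplex of weights per stratum as in the stanvar hack above; $s_i$ is the stratum of observation $i$, $\alpha$ the intercept, $M_k$ the M-spline basis functions, $I_k$ their integrated (cumulative) counterparts, and $\delta_i = 1$ for an event and $0$ for a right-censored time.

$$
h_i(t) = h_{0,s_i}(t)\,\exp(\eta_i), \qquad \eta_i = \alpha + x_i^\top \beta,
$$

$$
h_{0,s}(t) = \sum_{k=1}^{K} c_{s,k}\, M_k(t), \qquad H_{0,s}(t) = \sum_{k=1}^{K} c_{s,k}\, I_k(t), \qquad c_s \in \Delta^{K-1},
$$

$$
\log L = \sum_{i} \Big[ \delta_i \big( \log h_{0,s_i}(t_i) + \eta_i \big) - H_{0,s_i}(t_i)\, \exp(\eta_i) \Big].
$$

Only the baseline weights $c_s$ vary by stratum; the regression coefficients $\beta$ are shared, which is what coxph(... + strata(stratum)) estimates in the frequentist setting.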