Stratified Cox regression #1489

Closed
wds15 opened this issue Apr 27, 2023 · 3 comments
wds15 commented Apr 27, 2023

It would be great to support stratified Cox regression, i.e. Cox models with a separate baseline hazard function for each stratum of the data. Here is an example:

library(simsurv)
library(survival)
library(dplyr)
#> 
#> Attaching package: 'dplyr'
#> The following objects are masked from 'package:stats':
#> 
#>     filter, lag
#> The following objects are masked from 'package:base':
#> 
#>     intersect, setdiff, setequal, union
library(brms)
#> Loading required package: Rcpp
#> Loading 'brms' package (version 2.19.0). Useful instructions
#> can be found by typing help('brms'). A more detailed introduction
#> to the package is available through vignette('brms_overview').
#> 
#> Attaching package: 'brms'
#> The following object is masked from 'package:survival':
#> 
#>     kidney
#> The following object is masked from 'package:stats':
#> 
#>     ar
# instruct brms to use cmdstanr as backend and cache all Stan binaries
options(brms.backend="cmdstanr", cmdstanr_write_stan_file_dir=here::here("brms-cache"))
# create cache directory if not yet available
dir.create(here::here("brms-cache"), FALSE)

set.seed(456456)

# simulated data from the rstanarm::stan_surv example
covs <- data.frame(id  = 1:200, trt = stats::rbinom(200, 1L, 0.5), l=rnorm(200)) %>%
    mutate(stratum=cut(l, c(-Inf, -0.5, 0.5, Inf)))
d1 <- simsurv(lambdas = 0.1,
              gammas  = 1.5,
              betas   = c(trt = -0.5, l=-0.2),
              x       = covs,
              maxt    = 5)
d1 <- merge(d1, covs)


## now fit the cox model with a different baseline hazard for each
## stratum:
fit_coxph <- coxph(Surv(eventtime, status) ~ trt + strata(stratum), data = d1)
summary(fit_coxph)
#> Call:
#> coxph(formula = Surv(eventtime, status) ~ trt + strata(stratum), 
#>     data = d1)
#> 
#>   n= 200, number of events= 120 
#> 
#>        coef exp(coef) se(coef)      z Pr(>|z|)  
#> trt -0.4335    0.6482   0.1897 -2.285   0.0223 *
#> ---
#> Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
#> 
#>     exp(coef) exp(-coef) lower .95 upper .95
#> trt    0.6482      1.543    0.4469    0.9402
#> 
#> Concordance= 0.558  (se = 0.025 )
#> Likelihood ratio test= 5.33  on 1 df,   p=0.02
#> Wald test            = 5.22  on 1 df,   p=0.02
#> Score (logrank) test = 5.3  on 1 df,   p=0.02

## the equivalent in brms would be great to have; the fit below is not
## stratified and uses a single baseline hazard:
fit_brm <- brm(eventtime | cens(1 - status) ~ 1 + trt,
            data = d1, family = brmsfamily("cox"), refresh=0, cores=4)
#> Start sampling
#> Running MCMC with 4 parallel chains...
#> 
#> Chain 3 finished in 0.4 seconds.
#> Chain 1 finished in 0.5 seconds.
#> Chain 2 finished in 0.5 seconds.
#> Chain 4 finished in 0.5 seconds.
#> 
#> All 4 chains finished successfully.
#> Mean chain execution time: 0.5 seconds.
#> Total execution time: 0.6 seconds.
summary(fit_brm)
#>  Family: cox 
#>   Links: mu = log 
#> Formula: eventtime | cens(1 - status) ~ 1 + trt 
#>    Data: d1 (Number of observations: 200) 
#>   Draws: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
#>          total post-warmup draws = 4000
#> 
#> Population-Level Effects: 
#>           Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
#> Intercept     0.17      0.13    -0.09     0.42 1.00     3920     2592
#> trt          -0.51      0.19    -0.88    -0.14 1.00     3365     2644
#> 
#> Draws were sampled using sample(hmc). For each parameter, Bulk_ESS
#> and Tail_ESS are effective sample size measures, and Rhat is the potential
#> scale reduction factor on split chains (at convergence, Rhat = 1).

Created on 2023-04-27 with reprex v2.0.2

wds15 commented Apr 3, 2024

Maybe this is simpler to implement than it looks at first. Here is a toy example that hacks the Stan code generated by brms so that one baseline hazard M-spline is fitted per stratum. This is what I am currently testing. Things are still hard-coded, but I think it can be generalised easily:

## This script needs brms 2.21.0. The code below runs on DaVinci with R 4.3.1
library(simsurv)
library(survival)
library(dplyr)
library(brms)
options(brms.backend="cmdstanr", cmdstanr_write_stan_file_dir=file.path("~/brms-cache"))
# create cache directory if not yet available
dir.create(file.path("~/brms-cache"), FALSE)

set.seed(456456)

# simulated data from the rstanarm::stan_surv example
covs <- data.frame(id  = 1:200, trt = stats::rbinom(200, 1L, 0.5), l=rnorm(200)) %>%
    mutate(stratum=cut(l, c(-Inf, -0.5, 0.5, Inf)), stratum_index=as.integer(stratum))
d1 <- simsurv(lambdas = 0.1,
              gammas  = 1.5,
              betas   = c(trt = -0.5, l=-0.2),
              x       = covs,
              maxt    = 5)
d1 <- merge(d1, covs)


## now fit the cox model with a different baseline hazard for each
## stratum:
fit_coxph <- coxph(Surv(eventtime, status) ~ trt + strata(stratum), data = d1)
summary(fit_coxph)

## the equivalent in brms would be great to have; the fit below is not
## stratified and uses a single baseline hazard:
fit_brm <- brm(eventtime | cens(1 - status) ~ 1 + trt,
            data = d1, family = brmsfamily("cox"), refresh=0, cores=4)

summary(fit_brm)


## with a bit of extra magic code this should work ok:

## define one additional baseline hazard simplex per stratum
sv <- stanvar(d1$stratum_index, name="stratum") +
    stanvar(name="baseline_strata_simplex",
            scode="array[3] simplex[Kbhaz] sbhaz_stratum;",
            block="parameters") +
    stanvar(scode="lprior += dirichlet_lpdf(sbhaz_stratum[1] | con_sbhaz);", block="tparameters") +
    stanvar(scode="lprior += dirichlet_lpdf(sbhaz_stratum[2] | con_sbhaz);", block="tparameters") +
    stanvar(scode="lprior += dirichlet_lpdf(sbhaz_stratum[3] | con_sbhaz);", block="tparameters") +
    stanvar(scode="
// make baseline function stratum specific
for(n in 1:N) {
   bhaz[n] = Zbhaz[n] * sbhaz_stratum[stratum[n]];
   cbhaz[n] = Zcbhaz[n] * sbhaz_stratum[stratum[n]];
}
",
block="likelihood")

fit_brm_strata <- brm(eventtime | cens(1 - status) ~ 1 + trt,
                      stanvars=sv,
                      data = d1, family = brmsfamily("cox"), refresh=0, cores=4)

summary(fit_brm_strata)

## inspect the generated Stan code to confirm that a separate
## baseline hazard function is estimated for each stratum:
stancode(fit_brm_strata)

The relevant Stan code then becomes:

...
parameters {
  vector[Kc] b;  // regression coefficients
  real Intercept;  // temporary intercept for centered predictors
  simplex[Kbhaz] sbhaz;  // baseline coefficients
  array[3] simplex[Kbhaz] sbhaz_stratum;
}
transformed parameters {
  real lprior = 0;  // prior contributions to the log posterior
  lprior += dirichlet_lpdf(sbhaz_stratum[1] | con_sbhaz);
  lprior += dirichlet_lpdf(sbhaz_stratum[2] | con_sbhaz);
  lprior += dirichlet_lpdf(sbhaz_stratum[3] | con_sbhaz);
  lprior += student_t_lpdf(Intercept | 3, 1.5, 2.5);
  lprior += dirichlet_lpdf(sbhaz | con_sbhaz);
}
model {
  // likelihood including constants
  if (!prior_only) {
    // compute values of baseline function
    vector[N] bhaz = Zbhaz * sbhaz;
    // compute values of cumulative baseline function
    vector[N] cbhaz = Zcbhaz * sbhaz;
    // initialize linear predictor term
    vector[N] mu = rep_vector(0.0, N);
    
  // make baseline function stratum specific
  for(n in 1:N) {
     bhaz[n] = Zbhaz[n] * sbhaz_stratum[stratum[n]];
     cbhaz[n] = Zcbhaz[n] * sbhaz_stratum[stratum[n]];
  }

    mu += Intercept + Xc * b;
    for (n in 1:N) {
    // special treatment of censored data
      if (cens[n] == 0) {
        target += cox_log_lpdf(Y[n] | mu[n], bhaz[n], cbhaz[n]);
      } else if (cens[n] == 1) {
        target += cox_log_lccdf(Y[n] | mu[n], bhaz[n], cbhaz[n]);
      } else if (cens[n] == -1) {
        target += cox_log_lcdf(Y[n] | mu[n], bhaz[n], cbhaz[n]);
      }
    }
  }
  // priors including constants
  target += lprior;
}
...
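
For reference, the modification can be written out as follows. This is the standard stratified proportional hazards formulation and only a sketch of what the hacked code appears to compute; it is not taken from the brms documentation. The notation is introduced here: $s_n$ is the stratum of observation $n$ (stratum[n]), $\eta_n$ is the linear predictor (mu[n]), $\gamma_s$ is the simplex sbhaz_stratum[s], and $Z(t_n)$ and $Z_c(t_n)$ are the rows Zbhaz[n] and Zcbhaz[n]. The log-likelihood terms assume that cox_log_lpdf and cox_log_lccdf implement the usual proportional hazards log density and log survival function.

```latex
\begin{align*}
h(t \mid x_n, s_n) &= h_{0,s_n}(t)\,\exp(\eta_n), \\
h_{0,s}(t) &= Z(t)\,\gamma_s, \qquad H_{0,s}(t) = Z_c(t)\,\gamma_s, \\
\log L_n &=
\begin{cases}
\log h_{0,s_n}(t_n) + \eta_n - H_{0,s_n}(t_n)\,e^{\eta_n} & \text{event (cens = 0)}, \\
-\,H_{0,s_n}(t_n)\,e^{\eta_n} & \text{right-censored (cens = 1)}.
\end{cases}
\end{align*}
```

The regression coefficients are shared across strata; only the spline weights $\gamma_s$ differ.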

BTW, the statement stanvar(d1$stratum_index, name="stratum") creates a data entry with the old array syntax:

  int stratum[200];

which looks like a bug to me, as it should use the new array syntax here (array[200] int stratum;).
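
The hard-coded pieces above (three strata, 200 observations) could be generalised along the following lines. This is an untested sketch: strata_scode() and strata_stanvars() are hypothetical helpers, not part of brms, and they rely on the same internal names of the generated Stan code (Kbhaz, con_sbhaz, Zbhaz, Zcbhaz, bhaz, cbhaz) as the script above, which may change between brms versions. The data declaration is passed explicitly through scode, on the assumption that brms then uses it as given instead of inferring one.

```r
# Hypothetical helper (not part of brms): Stan snippets for a stratified
# baseline hazard with `n_strata` strata and `n_obs` observations.
strata_scode <- function(n_strata, n_obs) {
  stopifnot(n_strata >= 1, n_obs >= 1)
  list(
    # explicit declaration in the new array syntax
    data = sprintf("array[%d] int<lower=1, upper=%d> stratum;",
                   n_obs, n_strata),
    # one simplex of spline weights per stratum
    parameters = sprintf("array[%d] simplex[Kbhaz] sbhaz_stratum;",
                         n_strata),
    # same Dirichlet prior as brms uses for sbhaz, looped over strata
    tparameters = sprintf(
      "for (s in 1:%d) lprior += dirichlet_lpdf(sbhaz_stratum[s] | con_sbhaz);",
      n_strata
    ),
    # overwrite the shared baseline with the stratum-specific one
    likelihood = paste(
      "// make baseline function stratum specific",
      "for (n in 1:N) {",
      "  bhaz[n] = Zbhaz[n] * sbhaz_stratum[stratum[n]];",
      "  cbhaz[n] = Zcbhaz[n] * sbhaz_stratum[stratum[n]];",
      "}",
      sep = "\n"
    )
  )
}

# Hypothetical wrapper: turn the snippets into stanvars (needs brms).
# `stratum_index` must be an integer index running from 1 to the
# number of strata, in the row order of the model data.
strata_stanvars <- function(stratum_index) {
  stratum_index <- as.integer(stratum_index)
  sc <- strata_scode(max(stratum_index), length(stratum_index))
  brms::stanvar(stratum_index, name = "stratum", scode = sc$data) +
    brms::stanvar(name = "baseline_strata_simplex",
                  scode = sc$parameters, block = "parameters") +
    brms::stanvar(scode = sc$tparameters, block = "tparameters") +
    brms::stanvar(scode = sc$likelihood, block = "likelihood")
}
```

With these, sv <- strata_stanvars(d1$stratum_index) would replace the hand-written stanvars, and the brm() call stays the same. As in the script above, the original sbhaz parameter is still sampled from its prior but no longer enters the likelihood.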

paul-buerkner commented

Just a quick note that we may want to add a new addition term, bhaz or something similar, for this purpose. This would replace the bhaz argument in the cox() family.

@paul-buerkner paul-buerkner added this to the brms 2.22.0 milestone Apr 15, 2024
paul-buerkner added a commit that referenced this issue Sep 16, 2024
paul-buerkner commented

This feature is now implemented. Here is an example:

fit <- brm(
  time | cens(censored) + bhaz(gr = sex) ~ age * sex + disease + (1|patient),
  data = kidney, 
  family = cox()
)
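
Applied to the simulated data from the top of this issue, the same syntax would presumably look like the sketch below. It assumes a brms version that includes the bhaz() addition term (this issue was attached to the 2.22.0 milestone) and a data frame with the columns eventtime, status, trt and stratum, such as d1. fit_stratified_cox() is a hypothetical wrapper, and the code has not been run here.

```r
# The formula itself only needs base R. bhaz(gr = stratum) requests one
# baseline hazard per level of `stratum`, in the spirit of
# strata(stratum) in the coxph() call.
strat_formula <- eventtime | cens(1 - status) + bhaz(gr = stratum) ~ 1 + trt

# Hypothetical wrapper; `data` must contain eventtime, status, trt and
# stratum. Further arguments (e.g. cores, refresh) are passed to brm().
fit_stratified_cox <- function(data, ...) {
  brms::brm(strat_formula, data = data,
            family = brms::brmsfamily("cox"), ...)
}
```

A call such as fit_stratified_cox(d1, cores = 4, refresh = 0) could then be compared with fit_coxph; the trt estimates should be similar but not identical, since brms models the baseline hazards with M-splines and priors rather than the partial likelihood.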
