R package ifedtree and replication code for the paper "A Tree-based Model Averaging Approach for Personalized Treatment Effect Estimation from Heterogeneous Data Sources" [paper], accepted for publication at ICML 2022.
An earlier version received the Student Research Award at the 35th New England Statistics Symposium and an Honorable Mention in the ASA Student Paper Award competition (SLDS section) at JSM 2021.
To install this package in R, run the following commands:
install.packages("devtools")
devtools::install_github("ellenxtan/ifedtree")
Example usage with the bundled simulated data:
library(ifedtree)
data(SimDataLst)
K <- length(SimDataLst)
covars <- grep("^X", names(SimDataLst[[1]]), value=TRUE)
# coordinating site
coord_id <- 1
coord_test <- GenSimData(coord_id)
coord_df <- SimDataLst[[coord_id]]
# local models (causal forest from `grf` package as an example)
fit_lst <- list()
for (k in seq_len(K)) {
  df <- SimDataLst[[k]]
  fit_lst[[k]] <- grf::causal_forest(X = as.matrix(df[, covars, with = FALSE]), Y = df$Y, W = df$Z)
}
# augmented coordinating site data
aug_df <- GenAugData(coord_id, coord_df, fit_lst, covars)
# ensemble tree
et_fit <- EnsemTree(coord_id, aug_df, "site", covars)$myfit
PlotTree(et_fit)
BestLinearProj(et_fit, coord_df, coord_id, "site", "Z", "Y", covars)
# ensemble forest
ef_fit <- EnsemForest(coord_id, aug_df, "site", covars, importance="impurity")$myfit
PlotForestImp(ef_fit)
PlotForestPred(aug_df, coord_df, coord_id, ef_fit, "site", covars, "site", "X1")
BestLinearProj(ef_fit, coord_df, coord_id, "site", "Z", "Y", covars)
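The augmentation step above can be illustrated without the package. The following is a minimal base-R sketch, not the implementation of `GenAugData`. It assumes the augmented data stack each site's predicted treatment effect, evaluated at the coordinating site's covariates, together with a site indicator. The toy data, the helper `predict_tau`, the column name `tau_hat`, and the linear local model (a stand-in for the causal forest used above) are all hypothetical.

```r
set.seed(1)

# Hypothetical toy data: K sites, each with covariates X1, X2,
# binary treatment Z, and outcome Y.
K <- 3
n <- 200
sim_site <- function(k) {
  X1 <- rnorm(n)
  X2 <- rnorm(n)
  Z <- rbinom(n, 1, 0.5)
  tau <- 1 + 0.5 * k * X1  # treatment effect differs across sites
  Y <- X2 + tau * Z + rnorm(n)
  data.frame(X1 = X1, X2 = X2, Z = Z, Y = Y)
}
site_data <- lapply(seq_len(K), sim_site)

# Stand-in local model (assumption): a linear regression with
# treatment-by-covariate interactions, fitted separately at each site.
local_fits <- lapply(site_data, function(df) lm(Y ~ Z * (X1 + X2), data = df))

# Predicted treatment effect at covariates `newx`:
# prediction under treatment minus prediction under control.
predict_tau <- function(fit, newx) {
  treated <- predict(fit, transform(newx, Z = 1))
  control <- predict(fit, transform(newx, Z = 0))
  unname(treated - control)
}

# Evaluate every local model at the coordinating site's covariates
# and stack the results, one block of rows per site.
coord_id <- 1
coord_x <- site_data[[coord_id]][, c("X1", "X2")]
aug_toy <- do.call(rbind, lapply(seq_len(K), function(k) {
  data.frame(
    site = factor(k, levels = seq_len(K)),
    coord_x,
    tau_hat = predict_tau(local_fits[[k]], coord_x)
  )
}))

str(aug_toy)
```

Under this reading, a tree or forest fitted to `tau_hat` with `site` and the covariates as predictors can split on `site` wherever the local models disagree, which is one way to view the role of the `"site"` argument passed to `EnsemTree` and `EnsemForest`.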
- Simulation code is in the folder code_for_paper
- Run the following bash script.
sbatch run_hetero.sh
- Run the following bash script.
- Real data access and preparation
If you find this repository useful, please cite:
@inproceedings{tan2022tree,
title={A tree-based model averaging approach for personalized treatment effect estimation from heterogeneous data sources},
author={Tan, Xiaoqing and Chang, Chung-Chou H and Zhou, Ling and Tang, Lu},
booktitle={International Conference on Machine Learning},
pages={21013--21036},
year={2022},
organization={PMLR}
}