Installation • Usage • Datasets • Reproducibility • Credits
This repository contains the code accompanying the paper "CoxKAN: Kolmogorov-Arnold Networks for Interpretable, High-Performance Survival Analysis".
- Paper: ArXiv.
- Installation:
pip install coxkan
- Documentation: Read-the-Docs.
- Quick-start:
tutorials/intro.ipynb
Repo Structure:
├── checkpoints/ # Results / checkpoints from paper
├── configs/ # Model configuration files
├── coxkan/ # CoxKAN package
├── data/ # Data
├── docs/ # Documentation
├── media/ # Figures used in paper
├── reprod/ # Reproducibility instructions/code
├── tutorials/ # Tutorials for CoxKAN
|
# standard stuff:
├── .gitignore
├── LICENSE
├── README.md
└── setup.py
CoxKAN can be installed via:
pip install coxkan
Alternatively, you may want the full codebase and environment that were used to produce all results in the associated paper:
git clone https://github.com/knottwill/CoxKAN.git
cd CoxKAN
pip install -r reprod/requirements.txt
Please refer to the reproducibility instructions in reprod/README.md.
Find tutorials in tutorials/ or on Read-the-Docs.
from coxkan import CoxKAN
from coxkan.datasets import metabric
df_train, df_test = metabric.load(split=True)
ckan = CoxKAN(width=[len(metabric.covariates), 1])
_ = ckan.train(
df_train,
df_test,
duration_col='duration',
event_col='event',
steps=100)
# evaluate model
ckan.cindex(df_test)
>>> 0.6441975461899737
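ckan.cindex reports a concordance index: roughly, the fraction of comparable patient pairs that the model ranks in the right order (0.5 is random ordering, 1.0 is a perfect ranking). To make the number above easier to interpret, here is a minimal pure-Python sketch of Harrell's C-index. It is an illustrative reimplementation, not CoxKAN's own code; the function name harrell_cindex is hypothetical, and the package's tie handling may differ (this sketch simply skips pairs with tied durations).

```python
from itertools import combinations


def harrell_cindex(durations, events, risk_scores):
    """Harrell's concordance index for right-censored data (illustrative sketch).

    A pair is comparable when the subject with the shorter duration had the
    event. It is concordant when that subject also has the higher predicted
    risk. Tied risk scores count as half-concordant.
    """
    concordant = 0.0
    comparable = 0
    for i, j in combinations(range(len(durations)), 2):
        # order the pair so that `a` has the shorter observed duration
        a, b = (i, j) if durations[i] < durations[j] else (j, i)
        if durations[a] == durations[b] or not events[a]:
            continue  # tied durations or censored earlier subject: not comparable
        comparable += 1
        if risk_scores[a] > risk_scores[b]:
            concordant += 1.0
        elif risk_scores[a] == risk_scores[b]:
            concordant += 0.5
    if comparable == 0:
        raise ValueError("no comparable pairs")
    return concordant / comparable


# toy example: 5 comparable pairs, 4 ranked correctly
print(harrell_cindex([5, 10, 15, 20], [1, 1, 0, 1], [0.9, 0.5, 0.7, 0.1]))  # → 0.8
```

The pairwise loop is O(n²), which is fine for illustration; library implementations typically sort by duration to scale to larger datasets.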
The coxkan/ package has 4 main components:
coxkan/
├── datasets/ # datasets subpackage
├── CoxKAN.py # CoxKAN model
├── utils.py # utility functions
└── hyperparam_search.py # hyperparameter searching
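coxkan.datasets.create_dataset makes it easy to generate synthetic survival data assuming a proportional-hazards, time-independent hazard function:
$$h(t, \mathbf{x}) = h_0 \, e^{\theta(\mathbf{x})}$$
where $h_0$ is the constant baseline hazard (the baseline_hazard argument) and $\theta(\mathbf{x})$ is the log-partial hazard. In the example below, we use the log-partial hazard
$$\theta(\mathbf{x}) = 5 \exp\left(-2\left(x_1^2 + x_2^2\right)\right)$$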
import numpy as np
from coxkan.datasets import create_dataset
# log-partial hazard theta(x1, x2)
log_partial_hazard = lambda x1, x2: 5*np.exp(-2*(x1**2 + x2**2))
df = create_dataset(log_partial_hazard, baseline_hazard=0.01, n_samples=10000)
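Because the hazard does not depend on t, the survival function is S(t | x) = exp(-h_0 e^{θ(x)} t), so event times are exponential with rate h_0 e^{θ(x)} and can be drawn by inverse-transform sampling, t = -ln(U) / (h_0 e^{θ(x)}) with U uniform. The sketch below shows that idea in plain NumPy. It is not create_dataset's implementation: the function name sample_ph_dataset, the Uniform(-1, 1) covariates, and the independent Uniform(0, max_censor_time) censoring are all assumptions made for illustration.

```python
import numpy as np


def sample_ph_dataset(log_partial_hazard, baseline_hazard, n_samples,
                      max_censor_time, seed=0):
    """Hypothetical sketch: survival data with hazard h(t | x) = h0 * exp(theta(x)).

    Assumptions (not taken from coxkan): two covariates drawn from
    Uniform(-1, 1) and independent Uniform(0, max_censor_time) censoring.
    """
    rng = np.random.default_rng(seed)
    x1 = rng.uniform(-1, 1, n_samples)
    x2 = rng.uniform(-1, 1, n_samples)
    # a time-independent hazard makes T | x exponential with this rate
    rate = baseline_hazard * np.exp(log_partial_hazard(x1, x2))
    # inverse-transform sampling; 1 - U lies in (0, 1], which avoids log(0)
    u = 1.0 - rng.uniform(size=n_samples)
    event_times = -np.log(u) / rate
    censor_times = rng.uniform(0, max_censor_time, n_samples)
    # observe whichever comes first: the event or censoring
    duration = np.minimum(event_times, censor_times)
    event = (event_times <= censor_times).astype(int)
    return {"x1": x1, "x2": x2, "duration": duration, "event": event}


theta = lambda x1, x2: 5 * np.exp(-2 * (x1**2 + x2**2))
data = sample_ph_dataset(theta, baseline_hazard=0.01, n_samples=10000,
                         max_censor_time=100.0)  # censoring horizon chosen arbitrarily
print(f"fraction of observed events: {data['event'].mean():.2f}")
```

Returning a plain dict keeps the sketch dependency-free; it can be passed to pandas.DataFrame if a table like the one create_dataset returns is preferred.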
Five clinical datasets are available in the coxkan.datasets subpackage (inspired by pycox). For example:
from coxkan.datasets import gbsg
df_train, df_test = gbsg.load(split=True)
You can decide where to store them using the COXKAN_DATA_DIR environment variable.
Dataset | Description | Source |
---|---|---|
GBSG | The Rotterdam & German Breast Cancer Study Group. | DeepSurv |
METABRIC | The Molecular Taxonomy of Breast Cancer International Consortium. | DeepSurv |
SUPPORT | Study to Understand Prognoses Preferences Outcomes and Risks of Treatment. | DeepSurv |
NWTCO | National Wilms Tumor Study. | Rdatasets |
FLCHAIN | Assay of Serum Free Light Chain. | Rdatasets |
Unfortunately, DeepSurv did not retain the column names. We restored them manually by obtaining the datasets from other sources and comparing the columns, so that we can still use the same train/test split:
- GBSG: https://www.kaggle.com/datasets/utkarshx27/breast-cancer-dataset-used-royston-and-altman
- SUPPORT: https://hbiostat.org/data/repo/support2csv.zip
- METABRIC: https://www.kaggle.com/datasets/raghadalharbi/breast-cancer-gene-expression-profiles-metabric
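Part of that column comparison can be automated. The sketch below is one hypothetical way to do it (the function name match_columns and the toy data are illustrative, not the procedure or data we used): because the two sources may list patients in a different order, it compares each unnamed column with each named column on their sorted values. A match is necessary but not sufficient, since two columns can share the same value distribution (for example, binary flags with equal counts), so results still need a manual check. If one source had been rescaled, exact matching would fail and a rank-based comparison would be needed instead.

```python
import numpy as np


def match_columns(unnamed, named, atol=1e-6):
    """Hypothetical helper: map each unnamed column index to a named column.

    `unnamed` is a 2-D array (rows = patients, columns unlabeled); `named`
    maps column names to value sequences from another source. Columns are
    compared on sorted values, so row order does not matter.
    Returns {column_index: name or None}.
    """
    mapping = {}
    for i in range(unnamed.shape[1]):
        col = np.sort(unnamed[:, i])
        mapping[i] = None
        for name, values in named.items():
            ref = np.sort(np.asarray(values, dtype=float))
            if ref.shape == col.shape and np.allclose(col, ref, atol=atol):
                mapping[i] = name
                break  # first candidate wins; ambiguous cases need manual review
    return mapping


# toy data: same three patients, different row order in each source
unnamed = np.array([[1.0, 30.0], [0.0, 45.0], [1.0, 52.0]])
named = {"age": [52.0, 30.0, 45.0], "hormon": [0.0, 1.0, 1.0]}
print(match_columns(unnamed, named))  # → {0: 'hormon', 1: 'age'}
```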
We curated 4 genomics datasets from The Cancer Genome Atlas Program (TCGA). The raw or pre-processed data is available on request; please email me at knottenbeltwill@gmail.com.
Two of the datasets (GBMLGG, KIRC) are the unaltered datasets used in Pathomic Fusion.
Dataset | Description | Source |
---|---|---|
STAD | Stomach Adenocarcinoma. | TCGA |
BRCA | Breast Invasive Carcinoma. | TCGA |
GBM/LGG | Merged dataset from two types of brain cancer: Glioblastoma Multiforme and Lower Grade Glioma. | Chen et al. |
KIRC | Kidney Renal Clear Cell Carcinoma. | Chen et al. |
All results in the associated paper can be reproduced using the code in reprod/. Please refer to the instructions in reprod/README.md.
Special thanks to:
- All authors of Kolmogorov-Arnold Networks and the incredible pykan package.
- Håvard Kvamme for pycox and torchtuples.