From feb43d64fdbd664b548380f10ac570776702b3d6 Mon Sep 17 00:00:00 2001
From: Vincent Jeanselme
Date: Wed, 8 Sep 2021 09:24:14 +0100
Subject: [PATCH 01/22] GPU version

---
 dsm/dsm_api.py   | 25 ++++++++++++++++++-------
 dsm/utilities.py |  5 +++--
 2 files changed, 21 insertions(+), 9 deletions(-)

diff --git a/dsm/dsm_api.py b/dsm/dsm_api.py
index 57a3623..25d240c 100644
--- a/dsm/dsm_api.py
+++ b/dsm/dsm_api.py
@@ -52,17 +52,18 @@ class DSMBase():
   """Base Class for all DSM models"""

   def __init__(self, k=3, layers=None, distribution="Weibull",
-               temp=1000., discount=1.0):
+               temp=1000., discount=1.0, cuda=False):
     self.k = k
     self.layers = layers
     self.dist = distribution
     self.temp = temp
     self.discount = discount
     self.fitted = False
+    self.cuda = cuda

   def _gen_torch_model(self, inputdim, optimizer, risks):
     """Helper function to return a torch model."""
-    return DeepSurvivalMachinesTorch(inputdim,
+    model = DeepSurvivalMachinesTorch(inputdim,
                                      k=self.k,
                                      layers=self.layers,
                                      dist=self.dist,
@@ -70,6 +71,9 @@ def _gen_torch_model(self, inputdim, optimizer, risks):
                                      discount=self.discount,
                                      optimizer=optimizer,
                                      risks=risks)
+    if self.cuda:
+      model = model.cuda()
+    return model

   def fit(self, x, t, e, vsize=0.15, val_data=None,
           iters=1, learning_rate=1e-3, batch_size=100,
@@ -167,11 +171,14 @@ def compute_nll(self, x, t, e):
     for r in range(self.torch_model.risks):
       loss += float(losses.conditional_loss(self.torch_model,
                     x_val, t_val, e_val, elbo=False,
-                    risk=str(r+1)).detach().numpy())
+                    risk=str(r+1)).item())
     return loss

   def _prepocess_test_data(self, x):
-    return torch.from_numpy(x)
+    data = torch.from_numpy(x)
+    if self.cuda:
+      data = data.cuda()
+    return data

   def _prepocess_training_data(self, x, t, e, vsize, val_data, random_state):
@@ -201,6 +208,10 @@ def _prepocess_training_data(self, x, t, e, vsize, val_data, random_state):
     t_val = torch.from_numpy(t_val).double()
     e_val = torch.from_numpy(e_val).double()

+    if self.cuda:
+      x_train, t_train, e_train = x_train.cuda(), t_train.cuda(), e_train.cuda()
+      x_val, t_val, e_val = x_val.cuda(), t_val.cuda(), e_val.cuda()
+
     return (x_train, t_train, e_train,
             x_val, t_val, e_val)

@@ -219,7 +230,7 @@ def predict_mean(self, x, risk=1):
     if self.fitted:
       x = self._prepocess_test_data(x)
-      scores = losses.predict_mean(self.torch_model, x, risk=str(risk))
+      scores = losses.predict_mean(self.torch_model, x, risk=str(risk)).detach().cpu().numpy()
       return scores
     else:
       raise Exception("The model has not been fitted yet. Please fit the " +
@@ -268,7 +279,7 @@ def predict_survival(self, x, t, risk=1):
     if not isinstance(t, list):
       t = [t]
     if self.fitted:
-      scores = losses.predict_cdf(self.torch_model, x, t, risk=str(risk))
+      scores = losses.predict_cdf(self.torch_model, x, t, risk=str(risk)).detach().cpu().numpy()
       return np.exp(np.array(scores)).T
     else:
       raise Exception("The model has not been fitted yet. Please fit the " +
@@ -294,7 +305,7 @@ def predict_pdf(self, x, t, risk=1):
     if not isinstance(t, list):
       t = [t]
     if self.fitted:
-      scores = losses.predict_pdf(self.torch_model, x, t, risk=str(risk))
+      scores = losses.predict_pdf(self.torch_model, x, t, risk=str(risk)).detach().cpu().numpy()
       return np.exp(np.array(scores)).T
     else:
       raise Exception("The model has not been fitted yet. Please fit the " +
diff --git a/dsm/utilities.py b/dsm/utilities.py
index 2c5396a..c15f85a 100644
--- a/dsm/utilities.py
+++ b/dsm/utilities.py
@@ -55,9 +55,10 @@ def pretrain_dsm(model, t_train, e_train, t_valid, e_valid,
   premodel = DeepSurvivalMachinesTorch(1, 1,
                                        dist=model.dist,
                                        risks=model.risks,
-                                       optimizer=model.optimizer)
-  premodel.double()
+                                       optimizer=model.optimizer).double()

+  if model.is_cuda:
+    premodel.cuda()

   optimizer = get_optimizer(premodel, lr)

   oldcost = float('inf')
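From the caller's side, this first patch reduces to one new constructor flag. A minimal usage sketch (illustrative only — the data below is synthetic, and `DeepSurvivalMachines` is the public wrapper the `dsm` package already exports):

    import numpy as np
    from dsm import DeepSurvivalMachines

    x = np.random.randn(1000, 10)                # covariates
    t = np.random.exponential(10, size=1000)     # event/censoring times
    e = np.random.binomial(1, 0.7, size=1000)    # 1 = event, 0 = censored

    # cuda=True moves the torch model and the training tensors to the GPU;
    # predictions still come back to the caller as CPU numpy arrays.
    model = DeepSurvivalMachines(k=3, distribution='Weibull', cuda=True)
    model.fit(x, t, e, iters=100, learning_rate=1e-3)
    mean_time = model.predict_mean(x)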
From 04125d99cb13f8f591eba7e6fc520e745566da19 Mon Sep 17 00:00:00 2001
From: Vincent Jeanselme
Date: Wed, 8 Sep 2021 09:58:18 +0100
Subject: [PATCH 02/22] Update all models for GPU

---
 dsm/dsm_api.py | 18 +++++++++++++-----
 1 file changed, 13 insertions(+), 5 deletions(-)

diff --git a/dsm/dsm_api.py b/dsm/dsm_api.py
index 25d240c..31cd72e 100644
--- a/dsm/dsm_api.py
+++ b/dsm/dsm_api.py
@@ -63,7 +63,7 @@ def __init__(self, k=3, layers=None, distribution="Weibull",

   def _gen_torch_model(self, inputdim, optimizer, risks):
     """Helper function to return a torch model."""
-    model = DeepSurvivalMachinesTorch(inputdim,
+    return DeepSurvivalMachinesTorch(inputdim,
                                      k=self.k,
                                      layers=self.layers,
                                      dist=self.dist,
@@ -71,9 +71,6 @@ def _gen_torch_model(self, inputdim, optimizer, risks):
                                      discount=self.discount,
                                      optimizer=optimizer,
                                      risks=risks)
-    if self.cuda:
-      model = model.cuda()
-    return model

   def fit(self, x, t, e, vsize=0.15, val_data=None,
           iters=1, learning_rate=1e-3, batch_size=100,
@@ -126,6 +123,10 @@ def fit(self, x, t, e, vsize=0.15, val_data=None,
     maxrisk = int(np.nanmax(e_train.cpu().numpy()))
     model = self._gen_torch_model(inputdim, optimizer, risks=maxrisk)
+
+    if self.cuda:
+      model = model.cuda()
+
     model, _ = train_dsm(model,
                          x_train, t_train, e_train,
                          x_val, t_val, e_val,
@@ -408,7 +409,10 @@ def _gen_torch_model(self, inputdim, optimizer, risks):
                      risks=risks)

   def _prepocess_test_data(self, x):
-    return torch.from_numpy(_get_padded_features(x))
+    data = torch.from_numpy(_get_padded_features(x))
+    if self.cuda:
+      data = data.cuda()
+    return data

   def _prepocess_training_data(self, x, t, e, vsize, val_data, random_state):
     """RNNs require different preprocessing for variable length sequences"""
@@ -449,6 +453,10 @@ def _prepocess_training_data(self, x, t, e, vsize, val_data, random_state):
     t_val = torch.from_numpy(t_val).double()
     e_val = torch.from_numpy(e_val).double()

+    if self.cuda:
+      x_train, t_train, e_train = x_train.cuda(), t_train.cuda(), e_train.cuda()
+      x_val, t_val, e_val = x_val.cuda(), t_val.cuda(), e_val.cuda()
+
     return (x_train, t_train, e_train,
             x_val, t_val, e_val)
From f7378520d93fb81d843b8afe39aa7100c3de0a76 Mon Sep 17 00:00:00 2001
From: Vincent Jeanselme
Date: Wed, 8 Sep 2021 10:18:09 +0100
Subject: [PATCH 03/22] Two modes for CUDA

---
 dsm/dsm_api.py   |  8 ++++----
 dsm/utilities.py | 11 ++++++++++-
 2 files changed, 14 insertions(+), 5 deletions(-)

diff --git a/dsm/dsm_api.py b/dsm/dsm_api.py
index 31cd72e..6a5e243 100644
--- a/dsm/dsm_api.py
+++ b/dsm/dsm_api.py
@@ -59,7 +59,7 @@ def __init__(self, k=3, layers=None, distribution="Weibull",
     self.temp = temp
     self.discount = discount
     self.fitted = False
-    self.cuda = cuda
+    self.cuda = cuda # Two levels: 1 full GPU, 2 batch on GPU (prefer 1 if the data fits in memory)

   def _gen_torch_model(self, inputdim, optimizer, risks):
     """Helper function to return a torch model."""
@@ -133,7 +133,7 @@ def fit(self, x, t, e, vsize=0.15, val_data=None,
                          n_iter=iters,
                          lr=learning_rate,
                          elbo=elbo,
-                         bs=batch_size)
+                         bs=batch_size, cuda=self.cuda==2)

     self.torch_model = model.eval()
     self.fitted = True
@@ -209,7 +209,7 @@ def _prepocess_training_data(self, x, t, e, vsize, val_data, random_state):
     t_val = torch.from_numpy(t_val).double()
     e_val = torch.from_numpy(e_val).double()

-    if self.cuda:
+    if self.cuda == 1:
       x_train, t_train, e_train = x_train.cuda(), t_train.cuda(), e_train.cuda()
       x_val, t_val, e_val = x_val.cuda(), t_val.cuda(), e_val.cuda()

@@ -453,7 +453,7 @@ def _prepocess_training_data(self, x, t, e, vsize, val_data, random_state):
     t_val = torch.from_numpy(t_val).double()
     e_val = torch.from_numpy(e_val).double()

-    if self.cuda:
+    if self.cuda == 1:
       x_train, t_train, e_train = x_train.cuda(), t_train.cuda(), e_train.cuda()
       x_val, t_val, e_val = x_val.cuda(), t_val.cuda(), e_val.cuda()

diff --git a/dsm/utilities.py b/dsm/utilities.py
index c15f85a..e160fd8 100644
--- a/dsm/utilities.py
+++ b/dsm/utilities.py
@@ -59,6 +59,9 @@ def pretrain_dsm(model, t_train, e_train, t_valid, e_valid,
   if model.is_cuda:
     premodel.cuda()
+    t_train, e_train = t_train.cuda(), e_train.cuda()
+    t_valid, e_valid = t_valid.cuda(), e_valid.cuda()
+
   optimizer = get_optimizer(premodel, lr)

   oldcost = float('inf')
@@ -114,7 +117,7 @@ def train_dsm(model,
               x_train, t_train, e_train,
               x_valid, t_valid, e_valid,
               n_iter=10000, lr=1e-3, elbo=True,
-              bs=100):
+              bs=100, cuda=False):
   """Function to train the torch instance of the model."""

   logging.info('Pretraining the Underlying Distributions...')
@@ -159,6 +162,9 @@ def train_dsm(model,
       if xb.shape[0] == 0:
         continue

+      if cuda:
+        xb, tb, eb = xb.cuda(), tb.cuda(), eb.cuda()
+
       optimizer.zero_grad()
       loss = 0
       for r in range(model.risks):
@@ -174,6 +180,9 @@ def train_dsm(model,
     valid_loss = 0
     for r in range(model.risks):
+      if cuda:
+        x_valid, t_valid_, e_valid_ = x_valid.cuda(), t_valid_.cuda(), e_valid_.cuda()
+
       valid_loss += conditional_loss(model,
                                      x_valid,
                                      t_valid_,

From f436ee2b87c95d60a646b5a0bfa09ec57f823c7b Mon Sep 17 00:00:00 2001
From: Vincent Jeanselme
Date: Wed, 8 Sep 2021 10:22:10 +0100
Subject: [PATCH 04/22] Fix loss computation

---
 dsm/dsm_api.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/dsm/dsm_api.py b/dsm/dsm_api.py
index 6a5e243..da1a3ab 100644
--- a/dsm/dsm_api.py
+++ b/dsm/dsm_api.py
@@ -168,6 +168,11 @@ def compute_nll(self, x, t, e):
     x_val, t_val, e_val = x_val,\
         _reshape_tensor_with_nans(t_val),\
         _reshape_tensor_with_nans(e_val)
+
+    if self.cuda == 2:
+      # Data needs to be on the GPU when the loss is computed
+      x_val, t_val, e_val = x_val.cuda(), t_val.cuda(), e_val.cuda()
+
     loss = 0
     for r in range(self.torch_model.risks):
       loss += float(losses.conditional_loss(self.torch_model,
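After patch 03, `cuda` is an integer switch rather than a boolean. A sketch of the intended trade-off, following the semantics of the comment added above (illustrative, reusing the synthetic x, t, e from earlier):

    # Level 1: the full training set is moved to the GPU once
    # (fastest, but the whole dataset must fit in GPU memory).
    model = DeepSurvivalMachines(k=3, cuda=1)

    # Level 2: data stays on the CPU and train_dsm ships one batch
    # at a time with .cuda() (slower, but memory-safe).
    model = DeepSurvivalMachines(k=3, cuda=2)
    model.fit(x, t, e, batch_size=256)

Patch 04 closes the gap this opens for `compute_nll`: in batch mode the validation tensors are still on the CPU, so they have to be moved before the loss is evaluated.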
From 31bc7de8c5e35b429c2b5d16957b371aa649f6fc Mon Sep 17 00:00:00 2001
From: Vincent Jeanselme
Date: Wed, 8 Sep 2021 11:49:53 +0100
Subject: [PATCH 05/22] Pretrain fix

---
 dsm/utilities.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/dsm/utilities.py b/dsm/utilities.py
index e160fd8..e700b65 100644
--- a/dsm/utilities.py
+++ b/dsm/utilities.py
@@ -50,14 +50,14 @@ def get_optimizer(model, lr):
                             ' is not implemented')

 def pretrain_dsm(model, t_train, e_train, t_valid, e_valid,
-                 n_iter=10000, lr=1e-2, thres=1e-4):
+                 n_iter=10000, lr=1e-2, thres=1e-4, cuda = False):

   premodel = DeepSurvivalMachinesTorch(1, 1,
                                        dist=model.dist,
                                        risks=model.risks,
                                        optimizer=model.optimizer).double()

-  if model.is_cuda:
+  if cuda:
     premodel.cuda()
     t_train, e_train = t_train.cuda(), e_train.cuda()
     t_valid, e_valid = t_valid.cuda(), e_valid.cuda()
@@ -135,7 +135,7 @@ def train_dsm(model,
                          e_valid_,
                          n_iter=10000,
                          lr=1e-2,
-                         thres=1e-4)
+                         thres=1e-4, cuda = cuda)

   for r in range(model.risks):
     model.shape[str(r+1)].data.fill_(float(premodel.shape[str(r+1)]))
@@ -182,7 +182,7 @@ def train_dsm(model,
     for r in range(model.risks):
       if cuda:
         x_valid, t_valid_, e_valid_ = x_valid.cuda(), t_valid_.cuda(), e_valid_.cuda()
-
+
       valid_loss += conditional_loss(model,
                                      x_valid,
                                      t_valid_,

From 1d1ab1891743d44b81c099975e5970bf7574b1cb Mon Sep 17 00:00:00 2001
From: Vincent Jeanselme
Date: Wed, 8 Sep 2021 12:51:18 +0100
Subject: [PATCH 06/22] Update GPU for list pdf

---
 dsm/dsm_api.py |  6 +++---
 dsm/losses.py  | 12 ++++++------
 2 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/dsm/dsm_api.py b/dsm/dsm_api.py
index da1a3ab..c8a0a62 100644
--- a/dsm/dsm_api.py
+++ b/dsm/dsm_api.py
@@ -236,7 +236,7 @@ def predict_mean(self, x, risk=1):
     if self.fitted:
       x = self._prepocess_test_data(x)
-      scores = losses.predict_mean(self.torch_model, x, risk=str(risk)).detach().cpu().numpy()
+      scores = losses.predict_mean(self.torch_model, x, risk=str(risk))
       return scores
     else:
       raise Exception("The model has not been fitted yet. Please fit the " +
@@ -285,7 +285,7 @@ def predict_survival(self, x, t, risk=1):
     if not isinstance(t, list):
       t = [t]
     if self.fitted:
-      scores = losses.predict_cdf(self.torch_model, x, t, risk=str(risk)).detach().cpu().numpy()
+      scores = losses.predict_cdf(self.torch_model, x, t, risk=str(risk))
       return np.exp(np.array(scores)).T
     else:
       raise Exception("The model has not been fitted yet. Please fit the " +
@@ -311,7 +311,7 @@ def predict_pdf(self, x, t, risk=1):
     if not isinstance(t, list):
       t = [t]
     if self.fitted:
-      scores = losses.predict_pdf(self.torch_model, x, t, risk=str(risk)).detach().cpu().numpy()
+      scores = losses.predict_pdf(self.torch_model, x, t, risk=str(risk))
       return np.exp(np.array(scores)).T
     else:
       raise Exception("The model has not been fitted yet. Please fit the " +
diff --git a/dsm/losses.py b/dsm/losses.py
index c74d744..a8f1f23 100644
--- a/dsm/losses.py
+++ b/dsm/losses.py
@@ -324,7 +324,7 @@ def _weibull_pdf(model, x, t_horizon, risk='1'):
     lpdfs = torch.stack(lpdfs, dim=1)
     lpdfs = lpdfs+logits
     lpdfs = torch.logsumexp(lpdfs, dim=1)
-    pdfs.append(lpdfs.detach().numpy())
+    pdfs.append(lpdfs.detach().cpu().numpy())

   return pdfs
@@ -357,7 +357,7 @@ def _weibull_cdf(model, x, t_horizon, risk='1'):
     lcdfs = torch.stack(lcdfs, dim=1)
     lcdfs = lcdfs+logits
     lcdfs = torch.logsumexp(lcdfs, dim=1)
-    cdfs.append(lcdfs.detach().numpy())
+    cdfs.append(lcdfs.detach().cpu().numpy())

   return cdfs
@@ -386,7 +386,7 @@ def _weibull_mean(model, x, risk='1'):
   lmeans = lmeans+logits
   lmeans = torch.logsumexp(lmeans, dim=1)

-  return torch.exp(lmeans).detach().numpy()
+  return torch.exp(lmeans).detach().cpu().numpy()

@@ -424,7 +424,7 @@ def _lognormal_cdf(model, x, t_horizon, risk='1'):
     lcdfs = torch.stack(lcdfs, dim=1)
     lcdfs = lcdfs+logits
     lcdfs = torch.logsumexp(lcdfs, dim=1)
-    cdfs.append(lcdfs.detach().numpy())
+    cdfs.append(lcdfs.detach().cpu().numpy())

   return cdfs
@@ -461,7 +461,7 @@ def _normal_cdf(model, x, t_horizon, risk='1'):
     lcdfs = torch.stack(lcdfs, dim=1)
     lcdfs = lcdfs+logits
     lcdfs = torch.logsumexp(lcdfs, dim=1)
-    cdfs.append(lcdfs.detach().numpy())
+    cdfs.append(lcdfs.detach().cpu().numpy())

   return cdfs
@@ -485,7 +485,7 @@ def _normal_mean(model, x, risk='1'):
   lmeans = lmeans*logits
   lmeans = torch.sum(lmeans, dim=1)

-  return lmeans.detach().numpy()
+  return lmeans.detach().cpu().numpy()

 def predict_mean(model, x, risk='1'):

From 6eabc561d9339c01880a24f1da5e916a0ee20680 Mon Sep 17 00:00:00 2001
From: Vincent Jeanselme
Date: Wed, 8 Sep 2021 14:33:02 +0100
Subject: [PATCH 07/22] Push t to gpu

---
 dsm/losses.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/dsm/losses.py b/dsm/losses.py
index a8f1f23..c20d1b2 100644
--- a/dsm/losses.py
+++ b/dsm/losses.py
@@ -303,7 +303,7 @@ def _weibull_pdf(model, x, t_horizon, risk='1'):
   k_ = shape
   b_ = scale

-  t_horz = torch.tensor(t_horizon).double()
+  t_horz = torch.tensor(t_horizon).double().to(x.device)
   t_horz = t_horz.repeat(shape.shape[0], 1)

   pdfs = []
@@ -338,7 +338,7 @@ def _weibull_cdf(model, x, t_horizon, risk='1'):
   k_ = shape
   b_ = scale

-  t_horz = torch.tensor(t_horizon).double()
+  t_horz = torch.tensor(t_horizon).double().to(x.device)
   t_horz = t_horz.repeat(shape.shape[0], 1)

   cdfs = []
@@ -401,7 +401,7 @@ def _lognormal_cdf(model, x, t_horizon, risk='1'):
   k_ = shape
   b_ = scale

-  t_horz = torch.tensor(t_horizon).double()
+  t_horz = torch.tensor(t_horizon).double().to(x.device)
   t_horz = t_horz.repeat(shape.shape[0], 1)

   cdfs = []

From 9edb9cf92d6ffaf66bbeab30ea28a05fd9dad637 Mon Sep 17 00:00:00 2001
From: Vincent Jeanselme
Date: Wed, 8 Sep 2021 21:22:36 +0100
Subject: [PATCH 08/22] Fix pretraining

---
 dsm/utilities.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/dsm/utilities.py b/dsm/utilities.py
index e700b65..9a04b1d 100644
--- a/dsm/utilities.py
+++ b/dsm/utilities.py
@@ -135,7 +135,7 @@ def train_dsm(model,
                          e_valid_,
                          n_iter=10000,
                          lr=1e-2,
-                         thres=1e-4, cuda = cuda)
+                         thres=1e-4, cuda = cuda or t_train.is_cuda)

   for r in range(model.risks):
     model.shape[str(r+1)].data.fill_(float(premodel.shape[str(r+1)]))
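Patches 06–08 make the prediction path device-agnostic: the time horizons are created on `x.device`, and every distribution helper in `dsm/losses.py` now detaches and copies to the CPU before converting to numpy. The upshot, sketched below with illustrative horizons, is that callers never handle torch tensors or devices themselves:

    times = [1, 5, 10]                       # evaluation horizons
    surv = model.predict_survival(x, times)  # numpy, shape (n, len(times))
    pdf = model.predict_pdf(x, times)        # numpy as well, no .cpu() needed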
From bcdd2f12dadb2c1504257aa3573f97574edc353d Mon Sep 17 00:00:00 2001
From: Vincent Jeanselme
Date: Tue, 2 Nov 2021 21:18:00 +0000
Subject: [PATCH 09/22] Typo fix

---
 dsm/dsm_api.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/dsm/dsm_api.py b/dsm/dsm_api.py
index d857614..ba3c290 100644
--- a/dsm/dsm_api.py
+++ b/dsm/dsm_api.py
@@ -180,7 +180,7 @@ def compute_nll(self, x, t, e):
                     risk=str(r+1)).item())
     return loss

-  def _prepocess_test_data(self, x):
+  def _preprocess_test_data(self, x):
     data = torch.from_numpy(x)
     if self.cuda:
       data = data.cuda()
@@ -411,7 +411,7 @@ def _gen_torch_model(self, inputdim, optimizer, risks):
                      typ=self.typ,
                      risks=risks)

-  def _prepocess_test_data(self, x):
+  def _preprocess_test_data(self, x):
     data = torch.from_numpy(_get_padded_features(x))
     if self.cuda:
       data = data.cuda()

From ad3493b69ba9beb1c123a878a814358b4dc4437e Mon Sep 17 00:00:00 2001
From: Vincent Jeanselme
Date: Wed, 3 Nov 2021 09:08:53 +0000
Subject: [PATCH 10/22] Add nll computation

---
 dsm/contrib/dcm_api.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/dsm/contrib/dcm_api.py b/dsm/contrib/dcm_api.py
index 6863bf7..6360f2a 100644
--- a/dsm/contrib/dcm_api.py
+++ b/dsm/contrib/dcm_api.py
@@ -3,7 +3,7 @@
 import numpy as np

 from dsm.contrib.dcm_torch import DeepCoxMixturesTorch
-from dsm.contrib.dcm_utilities import train_dcm, predict_survival
+from dsm.contrib.dcm_utilities import *


 class DeepCoxMixtures():
@@ -178,4 +178,14 @@ def predict_survival(self, x, t):
     else:
       raise Exception("The model has not been fitted yet. Please fit the " +
                       "model using the `fit` method on some training data " +
-                      "before calling `predict_survival`.")
\ No newline at end of file
+                      "before calling `predict_survival`.")
+
+  def compute_nll(self, x, t, e):
+    if not self.fitted:
+      raise Exception("The model has not been fitted yet. Please fit the " +
+                      "model using the `fit` method on some training data " +
+                      "before calling `_eval_nll`.")
+    processed_data = self._preprocess_training_data(x, t, e, 0, None, 0)
+    _, _, _, x_val, t_val, e_val = processed_data
+    with torch.no_grad():
+      return - get_posteriors(repair_probs(get_likelihood(self.torch_model[0], self.torch_model[1], x_val, t_val, e_val))).sum().item()
\ No newline at end of file
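The new `compute_nll` turns the Deep Cox Mixtures posterior likelihood into a single scalar, which makes it usable as a held-out model-selection criterion. A hedged sketch — it assumes `DeepCoxMixtures` is re-exported by `dsm.contrib` (the class lives in `dsm/contrib/dcm_api.py`) and that its `fit` mirrors the DSM signature:

    from dsm.contrib import DeepCoxMixtures

    model = DeepCoxMixtures(k=3)
    model.fit(x_train, t_train, e_train, iters=50)
    nll = model.compute_nll(x_dev, t_dev, e_dev)  # lower is better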
From 064dc66af2db02be800ac7f8d4876e19d17a89e9 Mon Sep 17 00:00:00 2001
From: Vincent Jeanselme
Date: Thu, 4 Nov 2021 10:58:42 +0000
Subject: [PATCH 11/22] Update datasets loading to ensure columns and allow
 competing risks

---
 dsm/datasets.py                       | 35 ++++++++++++++++-----------
 examples/DSM on SUPPORT Dataset.ipynb |  2 +-
 examples/RDSM on PBC Dataset.ipynb    |  2 +-
 3 files changed, 23 insertions(+), 16 deletions(-)

diff --git a/dsm/datasets.py b/dsm/datasets.py
index e3b4862..dc7753c 100644
--- a/dsm/datasets.py
+++ b/dsm/datasets.py
@@ -54,7 +54,7 @@ def increase_censoring(e, t, p):

   return e, t

-def _load_framingham_dataset(sequential):
+def _load_framingham_dataset(sequential, competing = False):
   """Helper function to load and preprocess the Framingham dataset.

   The Framingham Dataset is a subset of 4,434 participants of the well known,
@@ -86,25 +86,31 @@ def _load_framingham_dataset(sequential):
   dat_num = data[['TOTCHOL', 'AGE', 'SYSBP', 'DIABP', 'CIGPDAY', 'BMI',
                   'HEARTRTE', 'GLUCOSE']]

-  x1 = pd.get_dummies(dat_cat).values
-  x2 = dat_num.values
-  x = np.hstack([x1, x2])
+  x1 = pd.get_dummies(dat_cat)
+  x2 = dat_num
+  x = np.hstack([x1.values, x2.values])

   time = (data['TIMEDTH'] - data['TIME']).values
   event = data['DEATH'].values

+  if competing:
+    time_cvd = (data['TIMECVD'] - data['TIME']).values
+    event_type = np.argmin(np.vstack([time, time_cvd]), 0)
+    event[event_type == 1] = 2
+    time[event_type == 1] = time_cvd[event_type == 1]
+
   x = SimpleImputer(missing_values=np.nan, strategy='mean').fit_transform(x)
   x_ = StandardScaler().fit_transform(x)

   if not sequential:
-    return x_, time, event
+    return x_, time, event, np.concatenate([x1.columns, x2.columns])
   else:
     x, t, e = [], [], []
     for id_ in sorted(list(set(data['RANDID']))):
       x.append(x_[data['RANDID'] == id_])
       t.append(time[data['RANDID'] == id_])
       e.append(event[data['RANDID'] == id_])
-    return x, t, e
+    return x, t, e, np.concatenate([x1.columns, x2.columns])

 def _load_pbc_dataset(sequential):
   """Helper function to load and preprocess the PBC dataset
@@ -137,10 +143,10 @@ def _load_pbc_dataset(sequential):
                   'SGOT', 'platelets', 'prothrombin']]
   age = data['age'] + data['years']

-  x1 = pd.get_dummies(dat_cat).values
-  x2 = dat_num.values
-  x3 = age.values.reshape(-1, 1)
-  x = np.hstack([x1, x2, x3])
+  x1 = pd.get_dummies(dat_cat)
+  x2 = dat_num
+  x3 = age
+  x = np.hstack([x1.values, x2.values, x3.values.reshape(-1, 1)])

   time = (data['years'] - data['year']).values
   event = data['status2'].values
@@ -156,7 +162,7 @@ def _load_pbc_dataset(sequential):
       x.append(x_[data['id'] == id_])
       t.append(time[data['id'] == id_])
       e.append(event[data['id'] == id_])
-    return x, t, e
+    return x, t, e, np.concatenate([x1.columns, x2.columns, x3.columns])

 def _load_support_dataset():
   """Helper function to load and preprocess the SUPPORT dataset.
@@ -192,7 +198,7 @@ def _load_support_dataset():
   x = StandardScaler().fit_transform(x)

   remove = ~np.isnan(t)
-  return x[remove], t[remove], e[remove]
+  return x[remove], t[remove], e[remove], np.concatenate([x1.columns, x2.columns])

 def _load_mnist():
   """Helper function to load and preprocess the MNIST dataset.
@@ -222,7 +228,7 @@ def _load_mnist():
   e, t = increase_censoring(np.ones(t.shape), t, p=.5)

-  return x, t, e
+  return x, t, e, train.data.columns

 def load_dataset(dataset='SUPPORT', **kwargs):
   """Helper function to load datasets to test Survival Analysis models.
@@ -273,13 +279,14 @@ def load_dataset(dataset='SUPPORT', **kwargs):
   """
   sequential = kwargs.get('sequential', False)
+  competing = kwargs.get('competing', False)

   if dataset == 'SUPPORT':
     return _load_support_dataset()
   if dataset == 'PBC':
     return _load_pbc_dataset(sequential)
   if dataset == 'FRAMINGHAM':
-    return _load_framingham_dataset(sequential)
+    return _load_framingham_dataset(sequential, competing)
   if dataset == 'MNIST':
     return _load_mnist()
   else:
     raise NotImplementedError('Dataset '+dataset+' not implemented.')
diff --git a/examples/DSM on SUPPORT Dataset.ipynb b/examples/DSM on SUPPORT Dataset.ipynb
index 04455c5..2fb7c98 100644
--- a/examples/DSM on SUPPORT Dataset.ipynb
+++ b/examples/DSM on SUPPORT Dataset.ipynb
@@ -39,7 +39,7 @@
    "outputs": [],
    "source": [
     "from dsm import datasets\n",
-    "x, t, e = datasets.load_dataset('SUPPORT')"
+    "x, t, e, _ = datasets.load_dataset('SUPPORT')"
    ]
   },
   {
diff --git a/examples/RDSM on PBC Dataset.ipynb b/examples/RDSM on PBC Dataset.ipynb
index 83b2cc7..7c5ed05 100644
--- a/examples/RDSM on PBC Dataset.ipynb
+++ b/examples/RDSM on PBC Dataset.ipynb
@@ -36,7 +36,7 @@
    "outputs": [],
    "source": [
     "from dsm import datasets\n",
-    "x, t, e = datasets.load_dataset('PBC', sequential = True)"
+    "x, t, e, _ = datasets.load_dataset('PBC', sequential = True)"
    ]
   },
   {
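Patch 11 changes the loader contract: every `_load_*` helper, and therefore `load_dataset`, now returns a fourth element with the covariate names — which is exactly why both notebooks switch to unpacking four values. A sketch of the new contract:

    from dsm import datasets

    # Static data: x is (n, d); features names the d columns.
    x, t, e, features = datasets.load_dataset('SUPPORT')

    # Sequential data: x becomes a list of per-subject arrays.
    x, t, e, features = datasets.load_dataset('PBC', sequential=True)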
From c047098426276a683045e029d418ebf8688b5f28 Mon Sep 17 00:00:00 2001
From: Vincent Jeanselme
Date: Fri, 12 Nov 2021 22:33:42 +0000
Subject: [PATCH 12/22] Framingham evaluate at baseline

---
 dsm/datasets.py | 17 +++++++++++------
 1 file changed, 11 insertions(+), 6 deletions(-)

diff --git a/dsm/datasets.py b/dsm/datasets.py
index dc7753c..e366b10 100644
--- a/dsm/datasets.py
+++ b/dsm/datasets.py
@@ -80,6 +80,10 @@ def _load_framingham_dataset(sequential, competing = False):
   data = pkgutil.get_data(__name__, 'datasets/framingham.csv')
   data = pd.read_csv(io.BytesIO(data))

+  if not sequential:
+    # Consider only first event
+    data = data.groupby('RANDID').first()
+
   dat_cat = data[['SEX', 'CURSMOKE', 'DIABETES', 'BPMEDS',
                   'educ', 'PREVCHD', 'PREVAP', 'PREVMI',
                   'PREVSTRK', 'PREVHYP']]
@@ -103,12 +107,13 @@ def _load_framingham_dataset(sequential, competing = False):
   x_ = StandardScaler().fit_transform(x)

   if not sequential:
-    return x_, time, event, np.concatenate([x1.columns, x2.columns])
+    return x_, time + 1, event, np.concatenate([x1.columns, x2.columns])
   else:
+    x_, data, time, event = x_[time > 0], data[time > 0], time[time > 0], event[time > 0]
     x, t, e = [], [], []
     for id_ in sorted(list(set(data['RANDID']))):
       x.append(x_[data['RANDID'] == id_])
-      t.append(time[data['RANDID'] == id_])
+      t.append(time[data['RANDID'] == id_] + 1)
       e.append(event[data['RANDID'] == id_])
     return x, t, e, np.concatenate([x1.columns, x2.columns])

@@ -155,12 +160,12 @@ def _load_pbc_dataset(sequential):
   x_ = StandardScaler().fit_transform(x)

   if not sequential:
-    return x_, time, event
+    return x_, time + 1, event
   else:
     x, t, e = [], [], []
     for id_ in sorted(list(set(data['id']))):
       x.append(x_[data['id'] == id_])
-      t.append(time[data['id'] == id_])
+      t.append(time[data['id'] == id_] + 1)
       e.append(event[data['id'] == id_])
     return x, t, e, np.concatenate([x1.columns, x2.columns, x3.columns])

@@ -198,7 +203,7 @@ def _load_support_dataset():
   x = StandardScaler().fit_transform(x)

   remove = ~np.isnan(t)
-  return x[remove], t[remove], e[remove], np.concatenate([x1.columns, x2.columns])
+  return x[remove], t[remove] + 1, e[remove], np.concatenate([x1.columns, x2.columns])

 def _load_mnist():
   """Helper function to load and preprocess the MNIST dataset.
@@ -228,7 +233,7 @@ def _load_mnist():
   e, t = increase_censoring(np.ones(t.shape), t, p=.5)

-  return x, t, e, train.data.columns
+  return x, t + 1, e, train.data.columns

 def load_dataset(dataset='SUPPORT', **kwargs):
   """Helper function to load datasets to test Survival Analysis models.

From 5611a161a2424c84b3f1e666104b9bfcf781869c Mon Sep 17 00:00:00 2001
From: Vincent Jeanselme
Date: Wed, 15 Dec 2021 15:01:22 +0000
Subject: [PATCH 13/22] Update FRAMINGHAM competing risks

---
 dsm/datasets.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/dsm/datasets.py b/dsm/datasets.py
index e366b10..359d4a2 100644
--- a/dsm/datasets.py
+++ b/dsm/datasets.py
@@ -99,9 +99,9 @@ def _load_framingham_dataset(sequential, competing = False):

   if competing:
     time_cvd = (data['TIMECVD'] - data['TIME']).values
-    event_type = np.argmin(np.vstack([time, time_cvd]), 0)
-    event[event_type == 1] = 2
-    time[event_type == 1] = time_cvd[event_type == 1]
+    event *= 2
+    event[data['CVD'] == 1] = 1
+    time[data['CVD'] == 1] = time_cvd[data['CVD'] == 1]

   x = SimpleImputer(missing_values=np.nan, strategy='mean').fit_transform(x)
   x_ = StandardScaler().fit_transform(x)

From 877b2fad85ff32096821fae25ca25e90ae0cdc9b Mon Sep 17 00:00:00 2001
From: Vincent Jeanselme
Date: Tue, 21 Dec 2021 15:58:59 +0000
Subject: [PATCH 14/22] Add function for cluster assignment

---
 dsm/contrib/dcm_api.py | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/dsm/contrib/dcm_api.py b/dsm/contrib/dcm_api.py
index 6360f2a..57941e9 100644
--- a/dsm/contrib/dcm_api.py
+++ b/dsm/contrib/dcm_api.py
@@ -188,4 +188,14 @@ def compute_nll(self, x, t, e):
     processed_data = self._preprocess_training_data(x, t, e, 0, None, 0)
     _, _, _, x_val, t_val, e_val = processed_data
     with torch.no_grad():
-      return - get_posteriors(repair_probs(get_likelihood(self.torch_model[0], self.torch_model[1], x_val, t_val, e_val))).sum().item()
\ No newline at end of file
+      return - get_posteriors(repair_probs(get_likelihood(self.torch_model[0], self.torch_model[1], x_val, t_val, e_val))).sum().item()
+
+  def predict_alphas(self, x):
+    x = self._preprocess_test_data(x)
+    if self.fitted:
+      alphas, _ = self.torch_model[0](x)
+      return alphas.detach().exp().cpu().numpy()
+    else:
+      raise Exception("The model has not been fitted yet. Please fit the " +
+                      "model using the `fit` method on some training data " +
+                      "before calling `predict_alphas`.")
\ No newline at end of file
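`predict_alphas` exposes the gating network of the Cox mixture: the torch gate returns log-probabilities, so the `.exp()` recovers, for each input row, the mixture weights over the k latent components. A sketch of reading off a hard cluster assignment:

    alphas = model.predict_alphas(x)   # shape (n, k); rows sum to 1
    cluster = alphas.argmax(axis=1)    # hard assignment, if one is wanted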
From 6d8cefb4cc469de0a6418c8b21add51857da7d00 Mon Sep 17 00:00:00 2001
From: Vincent Jeanselme
Date: Tue, 25 Jan 2022 18:04:13 +0000
Subject: [PATCH 15/22] Competing risks for PBC

---
 dsm/datasets.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/dsm/datasets.py b/dsm/datasets.py
index 359d4a2..81361ee 100644
--- a/dsm/datasets.py
+++ b/dsm/datasets.py
@@ -117,7 +117,7 @@ def _load_framingham_dataset(sequential, competing = False):
       e.append(event[data['RANDID'] == id_])
     return x, t, e, np.concatenate([x1.columns, x2.columns])

-def _load_pbc_dataset(sequential):
+def _load_pbc_dataset(sequential, competing = False):
   """Helper function to load and preprocess the PBC dataset

   The Primary biliary cirrhosis (PBC) Dataset [1] is well known
@@ -155,19 +155,22 @@ def _load_pbc_dataset(sequential, competing = False):

   time = (data['years'] - data['year']).values
   event = data['status2'].values
+  if competing:
+    event *= 2 # Death is 2
+    event[data['status'] == 'transpanted'] = 1

   x = SimpleImputer(missing_values=np.nan, strategy='mean').fit_transform(x)
   x_ = StandardScaler().fit_transform(x)

   if not sequential:
-    return x_, time + 1, event
+    return x_, time + 1, event, x1.columns.tolist() + x2.columns.tolist() + [x3.name]
   else:
     x, t, e = [], [], []
     for id_ in sorted(list(set(data['id']))):
       x.append(x_[data['id'] == id_])
       t.append(time[data['id'] == id_] + 1)
       e.append(event[data['id'] == id_])
-    return x, t, e, np.concatenate([x1.columns, x2.columns, x3.columns])
+    return x, t, e, np.concatenate([x1.columns, x2.columns, x3.name])

 def _load_support_dataset():
   """Helper function to load and preprocess the SUPPORT dataset.
@@ -289,7 +292,7 @@ def load_dataset(dataset='SUPPORT', **kwargs):
   if dataset == 'SUPPORT':
     return _load_support_dataset()
   if dataset == 'PBC':
-    return _load_pbc_dataset(sequential)
+    return _load_pbc_dataset(sequential, competing)
   if dataset == 'FRAMINGHAM':
     return _load_framingham_dataset(sequential, competing)
   if dataset == 'MNIST':
From 9fb5d30bf729685598a51e2988ba781f689bb76d Mon Sep 17 00:00:00 2001
From: Vincent Jeanselme
Date: Tue, 25 Jan 2022 18:23:02 +0000
Subject: [PATCH 16/22] Fix outcome

---
 dsm/datasets.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/dsm/datasets.py b/dsm/datasets.py
index 81361ee..df05d71 100644
--- a/dsm/datasets.py
+++ b/dsm/datasets.py
@@ -157,7 +157,7 @@ def _load_pbc_dataset(sequential, competing = False):
   time = (data['years'] - data['year']).values
   event = data['status2'].values
   if competing:
     event *= 2 # Death is 2
-    event[data['status'] == 'transpanted'] = 1
+    event[data['status'] == 'transplanted'] = 1

From 02d9f50dd6a5a11de754354535ccb7fa37a9ff13 Mon Sep 17 00:00:00 2001
From: Vincent Jeanselme
Date: Thu, 17 Feb 2022 09:20:54 +0000
Subject: [PATCH 17/22] Focus on baseline evaluation

---
 dsm/datasets.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/dsm/datasets.py b/dsm/datasets.py
index df05d71..8d3cbe3 100644
--- a/dsm/datasets.py
+++ b/dsm/datasets.py
@@ -141,6 +141,10 @@ def _load_pbc_dataset(sequential, competing = False):
   data = pkgutil.get_data(__name__, 'datasets/pbc2.csv')
   data = pd.read_csv(io.BytesIO(data))

+  if not sequential:
+    # Consider only first event
+    data = data.groupby('id').first()
+
   data['histologic'] = data['histologic'].astype(str)
   dat_cat = data[['drug', 'sex', 'ascites', 'hepatomegaly',
                   'spiders', 'edema', 'histologic']]

From 3a243a2b437638438fa67c875e99f6a07779ed0b Mon Sep 17 00:00:00 2001
From: Vincent Jeanselme
Date: Tue, 5 Jul 2022 13:35:29 +0100
Subject: [PATCH 18/22] Ensure labels 1 same across experiments

---
 dsm/datasets.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/dsm/datasets.py b/dsm/datasets.py
index 8d3cbe3..00db54b 100644
--- a/dsm/datasets.py
+++ b/dsm/datasets.py
@@ -99,8 +99,7 @@ def _load_framingham_dataset(sequential, competing = False):

   if competing:
     time_cvd = (data['TIMECVD'] - data['TIME']).values
-    event *= 2
-    event[data['CVD'] == 1] = 1
+    event[data['CVD'] == 1] = 2
     time[data['CVD'] == 1] = time_cvd[data['CVD'] == 1]
@@ -158,10 +157,9 @@ def _load_pbc_dataset(sequential, competing = False):
   x = np.hstack([x1.values, x2.values,
                  x3.values.reshape(-1, 1)])

   time = (data['years'] - data['year']).values
-  event = data['status2'].values
+  event = (data['status'] == 'dead').values
   if competing:
-    event *= 2 # Death is 2
-    event[data['status'] == 'transplanted'] = 1
+    event[data['status'] == 'transplanted'] = 2
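With patches 13–18 in place the competing-risks encoding is uniform across datasets: 0 is censoring, 1 is death, and 2 is the competing event (CVD onset for FRAMINGHAM, transplantation for PBC). A sketch of working under that convention — illustrative, reusing the DSM model from earlier, which infers the number of risks from the labels:

    x, t, e, features = datasets.load_dataset('PBC', competing=True)
    # e == 0: censored, e == 1: death, e == 2: transplantation
    model.fit(x, t, e)
    surv_death = model.predict_survival(x, [1, 5, 10], risk=1)
    surv_transplant = model.predict_survival(x, [1, 5, 10], risk=2)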
From 8fd04a88681ddb5aa549f9bc8278fc19b6013f95 Mon Sep 17 00:00:00 2001
From: Vincent Jeanselme
Date: Tue, 5 Jul 2022 16:14:59 +0100
Subject: [PATCH 19/22] Update PBC

---
 dsm/datasets.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/dsm/datasets.py b/dsm/datasets.py
index 00db54b..7d39abf 100644
--- a/dsm/datasets.py
+++ b/dsm/datasets.py
@@ -157,7 +157,7 @@ def _load_pbc_dataset(sequential, competing = False):
   x = np.hstack([x1.values, x2.values,
                  x3.values.reshape(-1, 1)])

   time = (data['years'] - data['year']).values
-  event = (data['status'] == 'dead').values
+  event = (data['status'] == 'dead').values.astype(int)
   if competing:
     event[data['status'] == 'transplanted'] = 2

From e4b07b3f497f2266eaa71d0e182195e95663d367 Mon Sep 17 00:00:00 2001
From: Vincent Jeanselme
Date: Wed, 20 Jul 2022 10:00:09 +0100
Subject: [PATCH 20/22] Remove automatic normalization

---
 dsm/datasets.py | 24 +++++++++++-------------
 1 file changed, 11 insertions(+), 13 deletions(-)

diff --git a/dsm/datasets.py b/dsm/datasets.py
index 7d39abf..b463bc0 100644
--- a/dsm/datasets.py
+++ b/dsm/datasets.py
@@ -102,8 +102,7 @@ def _load_framingham_dataset(sequential, competing = False):
     event[data['CVD'] == 1] = 2
     time[data['CVD'] == 1] = time_cvd[data['CVD'] == 1]

-  x = SimpleImputer(missing_values=np.nan, strategy='mean').fit_transform(x)
-  x_ = StandardScaler().fit_transform(x)
+  x_ = SimpleImputer(missing_values=np.nan, strategy='mean').fit_transform(x)

   if not sequential:
@@ -161,8 +160,7 @@ def _load_pbc_dataset(sequential, competing = False):
   if competing:
     event[data['status'] == 'transplanted'] = 2

-  x = SimpleImputer(missing_values=np.nan, strategy='mean').fit_transform(x)
-  x_ = StandardScaler().fit_transform(x)
+  x_ = SimpleImputer(missing_values=np.nan, strategy='mean').fit_transform(x)

   if not sequential:
     return x_, time + 1, event, x1.columns.tolist() + x2.columns.tolist() + [x3.name]
@@ -205,7 +203,6 @@ def _load_support_dataset():
   e = data['death'].values

   x = SimpleImputer(missing_values=np.nan, strategy='mean').fit_transform(x)
-  x = StandardScaler().fit_transform(x)

   remove = ~np.isnan(t)
   return x[remove], t[remove] + 1, e[remove], np.concatenate([x1.columns, x2.columns])
@@ -240,7 +237,7 @@ def _load_mnist():
   return x, t + 1, e, train.data.columns

-def load_dataset(dataset='SUPPORT', **kwargs):
+def load_dataset(dataset='SUPPORT', normalize = True, **kwargs):
   """Helper function to load datasets to test Survival Analysis models.
   Currently implemented datasets include:
@@ -292,12 +289,13 @@ def load_dataset(dataset='SUPPORT', **kwargs):
   """
   sequential = kwargs.get('sequential', False)
   competing = kwargs.get('competing', False)

   if dataset == 'SUPPORT':
-    return _load_support_dataset()
-  if dataset == 'PBC':
-    return _load_pbc_dataset(sequential, competing)
-  if dataset == 'FRAMINGHAM':
-    return _load_framingham_dataset(sequential, competing)
-  if dataset == 'MNIST':
-    return _load_mnist()
+    x, t, e, covariates = _load_support_dataset()
+  elif dataset == 'PBC':
+    x, t, e, covariates = _load_pbc_dataset(sequential, competing)
+  elif dataset == 'FRAMINGHAM':
+    x, t, e, covariates = _load_framingham_dataset(sequential, competing)
+  elif dataset == 'MNIST':
+    x, t, e, covariates = _load_mnist()
   else:
     raise NotImplementedError('Dataset '+dataset+' not implemented.')
+  return StandardScaler().fit_transform(x) if normalize else x, t, e, covariates

From cedfaf6970d94e0c8073408d23d711e88ded952e Mon Sep 17 00:00:00 2001
From: Vincent Jeanselme
Date: Mon, 19 Sep 2022 11:39:09 +0100
Subject: [PATCH 21/22] Update temporal normalization

---
 dsm/datasets.py | 14 +++++++++++---
 1 file changed, 11 insertions(+), 3 deletions(-)

diff --git a/dsm/datasets.py b/dsm/datasets.py
index b463bc0..5c84706 100644
--- a/dsm/datasets.py
+++ b/dsm/datasets.py
@@ -113,7 +113,7 @@ def _load_framingham_dataset(sequential, competing = False):
       x.append(x_[data['RANDID'] == id_])
       t.append(time[data['RANDID'] == id_] + 1)
       e.append(event[data['RANDID'] == id_])
-    return x, t, e, np.concatenate([x1.columns, x2.columns])
+    return (x, x_), t, e, np.concatenate([x1.columns, x2.columns])

 def _load_pbc_dataset(sequential, competing = False):
   """Helper function to load and preprocess the PBC dataset
@@ -170,7 +170,7 @@ def _load_pbc_dataset(sequential, competing = False):
       x.append(x_[data['id'] == id_])
       t.append(time[data['id'] == id_] + 1)
       e.append(event[data['id'] == id_])
-    return x, t, e, np.concatenate([x1.columns, x2.columns, x3.name])
+    return (x, x_), t, e, x1.columns.tolist() + x2.columns.tolist() + [x3.name]

 def _load_support_dataset():
   """Helper function to load and preprocess the SUPPORT dataset.
@@ -298,4 +298,12 @@ def load_dataset(dataset='SUPPORT', normalize = True, **kwargs):
     x, t, e, covariates = _load_mnist()
   else:
     raise NotImplementedError('Dataset '+dataset+' not implemented.')
-  return StandardScaler().fit_transform(x) if normalize else x, t, e, covariates
+
+  if isinstance(x, tuple):
+    (x, x_all) = x
+    if normalize:
+      scaler = StandardScaler().fit(x_all)
+      x = [scaler.transform(x_) for x_ in x]
+  elif normalize:
+    x = StandardScaler().fit_transform(x)
+  return x, t, e, covariates
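Patches 20–21 make normalization an explicit choice: `load_dataset` standardizes by default, `normalize=False` returns the raw covariates, and for sequential datasets the scaler is now fit once on the pooled visits (the `x_all` returned by the loaders) before being applied to each subject's sequence. A sketch:

    # Raw, unstandardized covariates:
    x, t, e, features = datasets.load_dataset('SUPPORT', normalize=False)

    # Sequential loading: one StandardScaler is fit on all visits pooled
    # together, then each patient's sequence is transformed with it.
    x, t, e, features = datasets.load_dataset('PBC', sequential=True)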
From b29ab6dda51066e5c0b1b1f02e8bd7baca875094 Mon Sep 17 00:00:00 2001
From: Vincent Jeanselme
Date: Thu, 4 Apr 2024 13:29:46 -0400
Subject: [PATCH 22/22] Add cluster assignment

---
 dsm/dsm_api.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/dsm/dsm_api.py b/dsm/dsm_api.py
index ba3c290..2afd15e 100644
--- a/dsm/dsm_api.py
+++ b/dsm/dsm_api.py
@@ -318,6 +318,15 @@ def predict_pdf(self, x, t, risk=1):
                       "model using the `fit` method on some training data " +
                       "before calling `predict_survival`.")

+  def predict_alphas(self, x):
+    x = self._preprocess_test_data(x)
+    if self.fitted:
+      _, _, alphas = self.torch_model(x)
+      return torch.softmax(alphas, dim = 1).detach().cpu().numpy()
+    else:
+      raise Exception("The model has not been fitted yet. Please fit the " +
+                      "model using the `fit` method on some training data " +
+                      "before calling `predict_alphas`.")

 class DeepSurvivalMachines(DSMBase):
   """A Deep Survival Machines model.