diff --git a/dsm/contrib/dcm_api.py b/dsm/contrib/dcm_api.py
index 6863bf7..57941e9 100644
--- a/dsm/contrib/dcm_api.py
+++ b/dsm/contrib/dcm_api.py
@@ -3,7 +3,7 @@
 import numpy as np
 
 from dsm.contrib.dcm_torch import DeepCoxMixturesTorch
-from dsm.contrib.dcm_utilities import train_dcm, predict_survival
+from dsm.contrib.dcm_utilities import *
 
 
 class DeepCoxMixtures():
@@ -178,4 +178,24 @@ def predict_survival(self, x, t):
     else:
       raise Exception("The model has not been fitted yet. Please fit the " +
                       "model using the `fit` method on some training data " +
-                      "before calling `predict_survival`.")
\ No newline at end of file
+                      "before calling `predict_survival`.")
+
+  def compute_nll(self, x, t, e):
+    if not self.fitted:
+      raise Exception("The model has not been fitted yet. Please fit the " +
+                      "model using the `fit` method on some training data " +
+                      "before calling `compute_nll`.")
+    processed_data = self._preprocess_training_data(x, t, e, 0, None, 0)
+    _, _, _, x_val, t_val, e_val = processed_data
+    with torch.no_grad():
+      return -get_posteriors(repair_probs(get_likelihood(self.torch_model[0], self.torch_model[1], x_val, t_val, e_val))).sum().item()
+
+  def predict_alphas(self, x):
+    x = self._preprocess_test_data(x)
+    if self.fitted:
+      alphas, _ = self.torch_model[0](x)
+      return alphas.detach().exp().cpu().numpy()
+    else:
+      raise Exception("The model has not been fitted yet. Please fit the " +
+                      "model using the `fit` method on some training data " +
+                      "before calling `predict_alphas`.")
\ No newline at end of file
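The two methods added above expose the mixture model's held-out fit quality and its gating weights. A minimal usage sketch, assuming defaults for `fit` and toy arrays (none of this is taken from the patch itself):

import numpy as np
from dsm.contrib.dcm_api import DeepCoxMixtures

# Toy data: 100 samples, 5 covariates, event times, binary event indicators.
x = np.random.randn(100, 5)
t = np.random.exponential(10, size=100)
e = np.random.binomial(1, 0.7, size=100)

model = DeepCoxMixtures(k=3)
model.fit(x, t, e)
nll = model.compute_nll(x, t, e)   # negative log-likelihood of (x, t, e) under the fit
alphas = model.predict_alphas(x)   # per-sample mixture membership probabilities, shape (100, k)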
diff --git a/dsm/datasets.py b/dsm/datasets.py
index e3b4862..5c84706 100644
--- a/dsm/datasets.py
+++ b/dsm/datasets.py
@@ -54,7 +54,7 @@ def increase_censoring(e, t, p):
 
   return e, t
 
-def _load_framingham_dataset(sequential):
+def _load_framingham_dataset(sequential, competing=False):
   """Helper function to load and preprocess the Framingham dataset.
 
   The Framingham Dataset is a subset of 4,434 participants of the well known,
@@ -80,33 +80,42 @@ def _load_framingham_dataset(sequential):
   data = pkgutil.get_data(__name__, 'datasets/framingham.csv')
   data = pd.read_csv(io.BytesIO(data))
 
+  if not sequential:
+    # Consider only first event
+    data = data.groupby('RANDID').first()
+
   dat_cat = data[['SEX', 'CURSMOKE', 'DIABETES', 'BPMEDS',
                   'educ', 'PREVCHD', 'PREVAP', 'PREVMI',
                   'PREVSTRK', 'PREVHYP']]
   dat_num = data[['TOTCHOL', 'AGE', 'SYSBP', 'DIABP',
                   'CIGPDAY', 'BMI', 'HEARTRTE', 'GLUCOSE']]
 
-  x1 = pd.get_dummies(dat_cat).values
-  x2 = dat_num.values
-  x = np.hstack([x1, x2])
+  x1 = pd.get_dummies(dat_cat)
+  x2 = dat_num
+  x = np.hstack([x1.values, x2.values])
 
   time = (data['TIMEDTH'] - data['TIME']).values
   event = data['DEATH'].values
 
-  x = SimpleImputer(missing_values=np.nan, strategy='mean').fit_transform(x)
-  x_ = StandardScaler().fit_transform(x)
+  if competing:
+    time_cvd = (data['TIMECVD'] - data['TIME']).values
+    event[data['CVD'] == 1] = 2
+    time[data['CVD'] == 1] = time_cvd[data['CVD'] == 1]
+
+  x_ = SimpleImputer(missing_values=np.nan, strategy='mean').fit_transform(x)
 
   if not sequential:
-    return x_, time, event
+    return x_, time + 1, event, np.concatenate([x1.columns, x2.columns])
   else:
+    x_, data, time, event = x_[time > 0], data[time > 0], time[time > 0], event[time > 0]
     x, t, e = [], [], []
     for id_ in sorted(list(set(data['RANDID']))):
       x.append(x_[data['RANDID'] == id_])
-      t.append(time[data['RANDID'] == id_])
+      t.append(time[data['RANDID'] == id_] + 1)
       e.append(event[data['RANDID'] == id_])
-    return x, t, e
+    return (x, x_), t, e, np.concatenate([x1.columns, x2.columns])
 
-def _load_pbc_dataset(sequential):
+def _load_pbc_dataset(sequential, competing=False):
   """Helper function to load and preprocess the PBC dataset
 
   The Primary biliary cirrhosis (PBC) Dataset [1] is well known
@@ -130,6 +139,10 @@ def _load_pbc_dataset(sequential):
   data = pkgutil.get_data(__name__, 'datasets/pbc2.csv')
   data = pd.read_csv(io.BytesIO(data))
 
+  if not sequential:
+    # Consider only first event
+    data = data.groupby('id').first()
+
   data['histologic'] = data['histologic'].astype(str)
   dat_cat = data[['drug', 'sex', 'ascites', 'hepatomegaly',
                   'spiders', 'edema', 'histologic']]
@@ -137,26 +150,27 @@ def _load_pbc_dataset(sequential):
                   'SGOT', 'platelets', 'prothrombin']]
   age = data['age'] + data['years']
 
-  x1 = pd.get_dummies(dat_cat).values
-  x2 = dat_num.values
-  x3 = age.values.reshape(-1, 1)
-  x = np.hstack([x1, x2, x3])
+  x1 = pd.get_dummies(dat_cat)
+  x2 = dat_num
+  x3 = age
+  x = np.hstack([x1.values, x2.values, x3.values.reshape(-1, 1)])
 
   time = (data['years'] - data['year']).values
-  event = data['status2'].values
+  event = (data['status'] == 'dead').values.astype(int)
+
+  if competing:
+    event[data['status'] == 'transplanted'] = 2
 
-  x = SimpleImputer(missing_values=np.nan, strategy='mean').fit_transform(x)
-  x_ = StandardScaler().fit_transform(x)
+  x_ = SimpleImputer(missing_values=np.nan, strategy='mean').fit_transform(x)
 
   if not sequential:
-    return x_, time, event
+    return x_, time + 1, event, x1.columns.tolist() + x2.columns.tolist() + [x3.name]
   else:
     x, t, e = [], [], []
     for id_ in sorted(list(set(data['id']))):
       x.append(x_[data['id'] == id_])
-      t.append(time[data['id'] == id_])
+      t.append(time[data['id'] == id_] + 1)
       e.append(event[data['id'] == id_])
-    return (x, x_), t, e, x1.columns.tolist() + x2.columns.tolist() + [x3.name]
 
 def _load_support_dataset():
   """Helper function to load and preprocess the SUPPORT dataset.
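With `competing=True`, both loaders above encode events as 0 (censored), 1 (death), and 2 (the competing event: CVD for Framingham, transplantation for PBC). A quick sketch of how downstream code might split such a vector into cause-specific indicators (the array here is illustrative, not from the patch):

import numpy as np

e = np.array([0, 1, 2, 1, 0, 2])    # stands in for a loader's event vector
e_death = (e == 1).astype(int)      # cause-specific indicator for death
e_competing = (e == 2).astype(int)  # cause-specific indicator for CVD / transplantation
print(e_death.sum(), e_competing.sum(), (e == 0).sum())  # events per cause, plus censored count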
@@ -189,10 +203,9 @@ def _load_support_dataset():
   e = data['death'].values
 
   x = SimpleImputer(missing_values=np.nan, strategy='mean').fit_transform(x)
-  x = StandardScaler().fit_transform(x)
 
   remove = ~np.isnan(t)
-  return x[remove], t[remove], e[remove]
+  return x[remove], t[remove] + 1, e[remove], np.concatenate([x1.columns, x2.columns])
 
 def _load_mnist():
   """Helper function to load and preprocess the MNIST dataset.
@@ -222,9 +235,9 @@ def _load_mnist():
 
   e, t = increase_censoring(np.ones(t.shape), t, p=.5)
 
-  return x, t, e
+  return x, t + 1, e, train.data.columns
 
-def load_dataset(dataset='SUPPORT', **kwargs):
+def load_dataset(dataset='SUPPORT', normalize=True, **kwargs):
   """Helper function to load datasets to test Survival Analysis models.
 
   Currently implemented datasets include:
@@ -273,14 +286,24 @@ def load_dataset(dataset='SUPPORT', **kwargs):
 
   """
   sequential = kwargs.get('sequential', False)
+  competing = kwargs.get('competing', False)
 
   if dataset == 'SUPPORT':
-    return _load_support_dataset()
-  if dataset == 'PBC':
-    return _load_pbc_dataset(sequential)
-  if dataset == 'FRAMINGHAM':
-    return _load_framingham_dataset(sequential)
-  if dataset == 'MNIST':
-    return _load_mnist()
+    x, t, e, covariates = _load_support_dataset()
+  elif dataset == 'PBC':
+    x, t, e, covariates = _load_pbc_dataset(sequential, competing)
+  elif dataset == 'FRAMINGHAM':
+    x, t, e, covariates = _load_framingham_dataset(sequential, competing)
+  elif dataset == 'MNIST':
+    x, t, e, covariates = _load_mnist()
   else:
     raise NotImplementedError('Dataset '+dataset+' not implemented.')
+
+  if isinstance(x, tuple):
+    (x, x_all) = x
+    if normalize:
+      scaler = StandardScaler().fit(x_all)
+      x = [scaler.transform(x_) for x_ in x]
+  elif normalize:
+    x = StandardScaler().fit_transform(x)
+
+  return x, t, e, covariates
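After this change, every loader returns the covariate names as a fourth value, standardization is opt-out rather than hard-coded, and sequential data is scaled with statistics fit on the full, unpadded matrix. A sketch of the new call patterns (keyword values shown are the new defaults or illustrative choices):

from dsm import datasets

# Fourth return value lists the covariate names matching x's columns.
x, t, e, covariates = datasets.load_dataset('FRAMINGHAM', competing=True)

# Sequential mode returns a list of per-subject arrays, already scaled.
x_seq, t_seq, e_seq, covariates = datasets.load_dataset('PBC', sequential=True)

# Skip standardization entirely, e.g. to fit a custom scaler downstream.
x_raw, t_raw, e_raw, covariates = datasets.load_dataset('SUPPORT', normalize=False)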
diff --git a/dsm/dsm_api.py b/dsm/dsm_api.py
index 59da5a7..2afd15e 100644
--- a/dsm/dsm_api.py
+++ b/dsm/dsm_api.py
@@ -52,13 +52,14 @@ class DSMBase():
   """Base Class for all DSM models"""
 
   def __init__(self, k=3, layers=None, distribution="Weibull",
-               temp=1000., discount=1.0):
+               temp=1000., discount=1.0, cuda=False):
     self.k = k
     self.layers = layers
     self.dist = distribution
     self.temp = temp
     self.discount = discount
     self.fitted = False
+    self.cuda = cuda  # Two levels: 1 keeps all data on the GPU, 2 moves one batch at a time (prefer 1 if the data fits in GPU memory)
 
   def _gen_torch_model(self, inputdim, optimizer, risks):
     """Helper function to return a torch model."""
@@ -122,13 +123,17 @@ def fit(self, x, t, e, vsize=0.15, val_data=None,
 
     maxrisk = int(np.nanmax(e_train.cpu().numpy()))
     model = self._gen_torch_model(inputdim, optimizer, risks=maxrisk)
+
+    if self.cuda:
+      model = model.cuda()
+
     model, _ = train_dsm(model,
                          x_train, t_train, e_train,
                          x_val, t_val, e_val,
                          n_iter=iters,
                          lr=learning_rate,
                          elbo=elbo,
-                         bs=batch_size)
+                         bs=batch_size, cuda=self.cuda==2)
 
     self.torch_model = model.eval()
     self.fitted = True
@@ -163,15 +168,23 @@ def compute_nll(self, x, t, e):
     x_val, t_val, e_val = x_val,\
         _reshape_tensor_with_nans(t_val),\
         _reshape_tensor_with_nans(e_val)
+
+    if self.cuda == 2:
+      # Data needs to be on the GPU for the loss to be computed
+      x_val, t_val, e_val = x_val.cuda(), t_val.cuda(), e_val.cuda()
+
     loss = 0
     for r in range(self.torch_model.risks):
       loss += float(losses.conditional_loss(self.torch_model,
                     x_val, t_val, e_val, elbo=False,
-                    risk=str(r+1)).detach().numpy())
+                    risk=str(r+1)).item())
     return loss
 
   def _preprocess_test_data(self, x):
-    return torch.from_numpy(x)
+    data = torch.from_numpy(x)
+    if self.cuda:
+      data = data.cuda()
+    return data
 
   def _preprocess_training_data(self, x, t, e, vsize, val_data, random_state):
@@ -201,7 +214,12 @@ def _preprocess_training_data(self, x, t, e, vsize, val_data, random_state):
       t_val = torch.from_numpy(t_val).double()
       e_val = torch.from_numpy(e_val).double()
 
-    return (x_train, t_train, e_train, x_val, t_val, e_val)
+    if self.cuda == 1:
+      x_train, t_train, e_train = x_train.cuda(), t_train.cuda(), e_train.cuda()
+      x_val, t_val, e_val = x_val.cuda(), t_val.cuda(), e_val.cuda()
+
+    return (x_train, t_train, e_train,
+            x_val, t_val, e_val)
 
   def predict_mean(self, x, risk=1):
@@ -300,6 +318,15 @@ def predict_pdf(self, x, t, risk=1):
                       "model using the `fit` method on some training data " +
                       "before calling `predict_survival`.")
 
+  def predict_alphas(self, x):
+    x = self._preprocess_test_data(x)
+    if self.fitted:
+      _, _, alphas = self.torch_model(x)
+      return torch.softmax(alphas, dim=1).detach().cpu().numpy()
+    else:
+      raise Exception("The model has not been fitted yet. Please fit the " +
+                      "model using the `fit` method on some training data " +
+                      "before calling `predict_alphas`.")
 
 class DeepSurvivalMachines(DSMBase):
   """A Deep Survival Machines model.
@@ -394,7 +421,10 @@ def _gen_torch_model(self, inputdim, optimizer, risks):
                                            risks=risks)
 
   def _preprocess_test_data(self, x):
-    return torch.from_numpy(_get_padded_features(x))
+    data = torch.from_numpy(_get_padded_features(x))
+    if self.cuda:
+      data = data.cuda()
+    return data
 
   def _preprocess_training_data(self, x, t, e, vsize, val_data, random_state):
     """RNNs require different preprocessing for variable length sequences"""
@@ -435,7 +465,12 @@ def _preprocess_training_data(self, x, t, e, vsize, val_data, random_state):
      t_val = torch.from_numpy(t_val).double()
      e_val = torch.from_numpy(e_val).double()
 
-    return (x_train, t_train, e_train, x_val, t_val, e_val)
+    if self.cuda == 1:
+      x_train, t_train, e_train = x_train.cuda(), t_train.cuda(), e_train.cuda()
+      x_val, t_val, e_val = x_val.cuda(), t_val.cuda(), e_val.cuda()
+
+    return (x_train, t_train, e_train,
+            x_val, t_val, e_val)
 
 class DeepConvolutionalSurvivalMachines(DSMBase):
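A sketch of the two GPU levels wired in above, assuming a CUDA device is available: `cuda=1` moves the model and the full training set to the GPU once, while `cuda=2` copies one minibatch at a time inside `train_dsm`, which is slower but fits datasets larger than GPU memory.

from dsm import DeepSurvivalMachines, datasets

x, t, e, _ = datasets.load_dataset('SUPPORT')

# cuda=1: model and the full training set live on the GPU.
model = DeepSurvivalMachines(k=3, layers=[100], cuda=1)
model.fit(x, t, e, iters=100, learning_rate=1e-3)

# cuda=2: only the current minibatch is transferred, trading speed for memory.
big_model = DeepSurvivalMachines(k=3, layers=[100], cuda=2)
big_model.fit(x, t, e, iters=100, batch_size=256)

alphas = model.predict_alphas(x)  # softmax mixture weights, returned on the CPU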
diff --git a/dsm/losses.py b/dsm/losses.py
index c74d744..c20d1b2 100644
--- a/dsm/losses.py
+++ b/dsm/losses.py
@@ -303,7 +303,7 @@ def _weibull_pdf(model, x, t_horizon, risk='1'):
   k_ = shape
   b_ = scale
 
-  t_horz = torch.tensor(t_horizon).double()
+  t_horz = torch.tensor(t_horizon).double().to(x.device)
   t_horz = t_horz.repeat(shape.shape[0], 1)
 
   pdfs = []
@@ -324,7 +324,7 @@ def _weibull_pdf(model, x, t_horizon, risk='1'):
     lpdfs = torch.stack(lpdfs, dim=1)
     lpdfs = lpdfs+logits
     lpdfs = torch.logsumexp(lpdfs, dim=1)
-    pdfs.append(lpdfs.detach().numpy())
+    pdfs.append(lpdfs.detach().cpu().numpy())
 
   return pdfs
 
@@ -338,7 +338,7 @@ def _weibull_cdf(model, x, t_horizon, risk='1'):
   k_ = shape
   b_ = scale
 
-  t_horz = torch.tensor(t_horizon).double()
+  t_horz = torch.tensor(t_horizon).double().to(x.device)
   t_horz = t_horz.repeat(shape.shape[0], 1)
 
   cdfs = []
@@ -357,7 +357,7 @@ def _weibull_cdf(model, x, t_horizon, risk='1'):
     lcdfs = torch.stack(lcdfs, dim=1)
     lcdfs = lcdfs+logits
     lcdfs = torch.logsumexp(lcdfs, dim=1)
-    cdfs.append(lcdfs.detach().numpy())
+    cdfs.append(lcdfs.detach().cpu().numpy())
 
   return cdfs
 
@@ -386,7 +386,7 @@ def _weibull_mean(model, x, risk='1'):
     lmeans = lmeans+logits
     lmeans = torch.logsumexp(lmeans, dim=1)
 
-  return torch.exp(lmeans).detach().numpy()
+  return torch.exp(lmeans).detach().cpu().numpy()
 
@@ -401,7 +401,7 @@ def _lognormal_cdf(model, x, t_horizon, risk='1'):
   k_ = shape
   b_ = scale
 
-  t_horz = torch.tensor(t_horizon).double()
+  t_horz = torch.tensor(t_horizon).double().to(x.device)
   t_horz = t_horz.repeat(shape.shape[0], 1)
 
   cdfs = []
@@ -424,7 +424,7 @@ def _lognormal_cdf(model, x, t_horizon, risk='1'):
     lcdfs = torch.stack(lcdfs, dim=1)
     lcdfs = lcdfs+logits
     lcdfs = torch.logsumexp(lcdfs, dim=1)
-    cdfs.append(lcdfs.detach().numpy())
+    cdfs.append(lcdfs.detach().cpu().numpy())
 
   return cdfs
 
@@ -461,7 +461,7 @@ def _normal_cdf(model, x, t_horizon, risk='1'):
     lcdfs = torch.stack(lcdfs, dim=1)
     lcdfs = lcdfs+logits
     lcdfs = torch.logsumexp(lcdfs, dim=1)
-    cdfs.append(lcdfs.detach().numpy())
+    cdfs.append(lcdfs.detach().cpu().numpy())
 
   return cdfs
 
@@ -485,7 +485,7 @@ def _normal_mean(model, x, risk='1'):
     lmeans = lmeans*logits
     lmeans = torch.sum(lmeans, dim=1)
 
-  return lmeans.detach().numpy()
+  return lmeans.detach().cpu().numpy()
 
 def predict_mean(model, x, risk='1'):
diff --git a/dsm/utilities.py b/dsm/utilities.py
index 6a8d4cc..880ab62 100644
--- a/dsm/utilities.py
+++ b/dsm/utilities.py
@@ -50,13 +50,17 @@ def get_optimizer(model, lr):
                             ' is not implemented')
 
 def pretrain_dsm(model, t_train, e_train, t_valid, e_valid,
-                 n_iter=10000, lr=1e-2, thres=1e-4):
+                 n_iter=10000, lr=1e-2, thres=1e-4, cuda=False):
 
   premodel = DeepSurvivalMachinesTorch(1, 1,
                                        dist=model.dist,
                                        risks=model.risks,
-                                       optimizer=model.optimizer)
-  premodel.double()
+                                       optimizer=model.optimizer).double()
+
+  if cuda:
+    premodel.cuda()
+    t_train, e_train = t_train.cuda(), e_train.cuda()
+    t_valid, e_valid = t_valid.cuda(), e_valid.cuda()
 
   optimizer = get_optimizer(premodel, lr)
 
@@ -113,7 +117,7 @@ def train_dsm(model,
               x_train, t_train, e_train,
               x_valid, t_valid, e_valid,
               n_iter=10000, lr=1e-3, elbo=True,
-              bs=100):
+              bs=100, cuda=False):
   """Function to train the torch instance of the model."""
 
   logging.info('Pretraining the Underlying Distributions...')
@@ -131,7 +135,7 @@ def train_dsm(model,
                          e_valid_,
                          n_iter=10000,
                          lr=1e-2,
-                         thres=1e-4)
+                         thres=1e-4, cuda=cuda or t_train.is_cuda)
 
   for r in range(model.risks):
     model.shape[str(r+1)].data.fill_(float(premodel.shape[str(r+1)]))
@@ -158,6 +162,9 @@ def train_dsm(model,
       if xb.shape[0] == 0:
         continue
 
+      if cuda:
+        xb, tb, eb = xb.cuda(), tb.cuda(), eb.cuda()
+
       optimizer.zero_grad()
       loss = 0
       for r in range(model.risks):
@@ -173,6 +180,9 @@ def train_dsm(model,
 
       valid_loss = 0
       for r in range(model.risks):
+        if cuda:
+          x_valid, t_valid_, e_valid_ = x_valid.cuda(), t_valid_.cuda(), e_valid_.cuda()
+
         valid_loss += conditional_loss(model,
                                        x_valid,
                                        t_valid_,
diff --git a/examples/DSM on SUPPORT Dataset.ipynb b/examples/DSM on SUPPORT Dataset.ipynb
index 04455c5..2fb7c98 100644
--- a/examples/DSM on SUPPORT Dataset.ipynb
+++ b/examples/DSM on SUPPORT Dataset.ipynb
@@ -39,7 +39,7 @@
    "outputs": [],
    "source": [
     "from dsm import datasets\n",
-    "x, t, e = datasets.load_dataset('SUPPORT')"
+    "x, t, e, _ = datasets.load_dataset('SUPPORT')"
    ]
   },
   {
diff --git a/examples/RDSM on PBC Dataset.ipynb b/examples/RDSM on PBC Dataset.ipynb
index 83b2cc7..7c5ed05 100644
--- a/examples/RDSM on PBC Dataset.ipynb
+++ b/examples/RDSM on PBC Dataset.ipynb
@@ -36,7 +36,7 @@
    "outputs": [],
    "source": [
     "from dsm import datasets\n",
-    "x, t, e = datasets.load_dataset('PBC', sequential = True)"
+    "x, t, e, _ = datasets.load_dataset('PBC', sequential = True)"
    ]
   },
  {
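All of the losses.py edits above reduce to one device-safety pattern, distilled here on standalone tensors (the variable names are illustrative):

import torch

x = torch.zeros(4, 3)  # stands in for the input batch; may live on CPU or GPU
t_horizon = [1.0, 5.0, 10.0]

# Allocate new tensors on the same device as the input they will be combined with...
t_horz = torch.tensor(t_horizon).double().to(x.device)

# ...and move results back to the CPU before NumPy conversion,
# since .numpy() raises on CUDA tensors.
out = t_horz.detach().cpu().numpy()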