-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_plot.py
119 lines (103 loc) · 5.36 KB
/
train_plot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from matplotlib import pyplot as plt
from synthetic_dgp import linear_dgp
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data import DataLoader
import torch
import torch.optim as optim
import os
import numpy as np
# sys.path.append(os.path.abspath('../'))
from dirac_phi import DiracPhi
from survival import DCSurvival
from survival import sample
# Run everything in float64 — the copula log-likelihood involves
# derivatives/inversions that are fragile in float32.
# NOTE(review): torch.set_default_tensor_type is deprecated in recent PyTorch;
# torch.set_default_dtype(torch.float64) is the modern equivalent — confirm
# against the torch version pinned by this project before changing.
torch.set_default_tensor_type(torch.DoubleTensor)
torch.set_num_threads(24)  # cap intra-op CPU parallelism
device = torch.device("cuda:0")  # all tensors/models are moved to the first GPU
batch_size = 20000   # minibatch size for both the train and val loaders
num_epochs = 10000   # upper bound on epochs; early stopping usually ends sooner
copula_form = 'Frank'  # copula family passed to the synthetic DGP
sample_size = 30000  # total number of simulated observations
val_size = 10000     # last val_size observations are held out for validation
seed = 142857
rng = np.random.default_rng(seed)  # single seeded RNG so the DGP is reproducible
def main():
    """Train DCSurvival on a synthetic linear DGP and plot learned-copula samples.

    For each ground-truth dependence parameter ``theta_true`` this:
      1. simulates right-censored survival data with ``linear_dgp``,
      2. fits a DCSurvival model (``DiracPhi`` ACNet generator + survival
         marginals), optimizing the survival margins and the copula
         generator with separate optimizers,
      3. early-stops on validation log-likelihood and checkpoints the best
         model to disk,
      4. saves scatter plots of samples drawn from the learned copula every
         200 epochs and for the best checkpoint.

    Side effects: writes checkpoints under /home/DCSurvival/checkpoints and
    figures under /home/DCSurvival/sample_figs (directories are created if
    missing). Requires a CUDA device (module-level ``device``).
    """
    for theta_true in [2]:
        # ---- data ----------------------------------------------------
        X, observed_time, event_indicator, _, _, _ = linear_dgp(
            copula_name=copula_form, covariate_dim=10, theta=theta_true,
            sample_size=sample_size, rng=rng)
        times_tensor = torch.tensor(observed_time, dtype=torch.float64).to(device)
        event_indicator_tensor = torch.tensor(event_indicator, dtype=torch.float64).to(device)
        covariate_tensor = torch.tensor(X, dtype=torch.float64).to(device)

        # First (sample_size - val_size) rows train; the tail validates.
        n_train = sample_size - val_size
        train_data = TensorDataset(covariate_tensor[:n_train],
                                   times_tensor[:n_train],
                                   event_indicator_tensor[:n_train])
        val_data = TensorDataset(covariate_tensor[n_train:],
                                 times_tensor[n_train:],
                                 event_indicator_tensor[n_train:])
        train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=True)

        # ---- early-stopping state -----------------------------------
        best_val_loglikelihood = float('-inf')
        epochs_no_improve = 0
        early_stop_epochs = 100

        # ---- model (parameters for ACNet) ---------------------------
        depth = 2
        widths = [100, 100]
        lc_w_range = (0, 1.0)
        shift_w_range = (0., 2.0)
        phi = DiracPhi(depth, widths, lc_w_range, shift_w_range, device, tol=1e-10).to(device)
        model = DCSurvival(phi, device=device, num_features=10, tol=1e-10).to(device)

        # Separately optimizing copula and survival parameters is sometimes
        # helpful, but not necessary.
        optimizer_survival = optim.Adam([
            {"params": model.sumo_e.parameters(), "lr": 0.001},
            {"params": model.sumo_c.parameters(), "lr": 0.001},
        ])
        optimizer_copula = optim.SGD([{"params": model.phi.parameters(), "lr": 0.0005}])

        # Output locations (hard-coded); create them up front so torch.save /
        # plt.savefig do not fail on a fresh machine.
        checkpoint_path = ('/home/DCSurvival/checkpoints/checkpoint_experiment_'
                           + copula_form + '_' + str(theta_true) + '.pth')
        fig_dir = '/home/DCSurvival/sample_figs/' + copula_form + '/' + str(theta_true)
        os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
        os.makedirs(fig_dir, exist_ok=True)

        train_loss_per_epoch = []
        print("Start training!")
        for epoch in range(num_epochs):
            loss_per_minibatch = []
            for x, t, c in train_loader:
                optimizer_copula.zero_grad()
                optimizer_survival.zero_grad()
                # Model returns the (summed) log-likelihood of the batch.
                p = model(x, t, c, max_iter=1000)
                logloss = -p
                logloss.backward()
                scalar_loss = (logloss / p.numel()).detach().cpu().numpy().item()
                optimizer_survival.step()
                # Keep the copula frozen for the first 200 epochs so the
                # survival margins stabilize before dependence is learned.
                if epoch > 200:
                    optimizer_copula.step()
                # Normalize by the actual batch length (not the configured
                # batch_size) so a partial final batch is averaged correctly.
                loss_per_minibatch.append(scalar_loss / x.shape[0])
            train_loss_per_epoch.append(np.mean(loss_per_minibatch))
            print('Training likelihood at epoch %s: %.5f' %
                  (epoch, -train_loss_per_epoch[-1]))

            # ---- validation / early stopping ------------------------
            # val_size <= batch_size here, so this loop runs once; the last
            # batch's value is what gets compared below either way.
            for x_val, t_val, c_val in val_loader:
                val_loglikelihood = model(x_val, t_val, c_val, max_iter=10000) / val_size
                print('Validation log-likelihood at epoch %s: %s' %
                      (epoch, val_loglikelihood.cpu().detach().numpy().item()))
            if val_loglikelihood > best_val_loglikelihood:
                best_val_loglikelihood = val_loglikelihood
                epochs_no_improve = 0
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'loss': best_val_loglikelihood,
                }, checkpoint_path)
            else:
                epochs_no_improve += 1
            # Break at epoch-loop level so early stopping actually ends
            # training rather than only a nested loop.
            if epochs_no_improve == early_stop_epochs:
                print('Early stopping triggered at epoch: %s' % epoch)
                break

            # Plot samples from the learned copula every 200 epochs.
            if epoch % 200 == 0:
                print('Scatter sampling')
                samples = sample(model, 2, sample_size, device=device)
                plt.scatter(samples[:, 0].cpu(), samples[:, 1].cpu(), s=15)
                plt.savefig(fig_dir + '/epoch%s.png' % (epoch))
                plt.clf()

        # ---- final plot from the best checkpoint --------------------
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        samples = sample(model, 2, sample_size, device=device)
        plt.scatter(samples[:, 0].cpu(), samples[:, 1].cpu(), s=15)
        plt.savefig(fig_dir + '/best_epoch.png')
        plt.clf()


if __name__ == '__main__':
    main()