-
Notifications
You must be signed in to change notification settings - Fork 0
/
Dilated_CNN.py
187 lines (152 loc) · 6.54 KB
/
Dilated_CNN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from torchvision import datasets
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader, Subset
from torch import nn
from torch import optim
import STRDataset
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
import sys
import STRDataset as dataset
import os
import time
#Design CNN model
# Use dilated model
class Dialated_CNN(nn.Module):
    """Dilated 1-D CNN regressor over one-hot sequence windows plus metadata.

    A dilated conv/pool stage followed by a plain conv/pool stage summarises
    the sequence; the flattened features are concatenated with a per-sample
    metadata vector and fed through a small fully connected head that emits
    one value per sample.
    """

    def __init__(self, size):
        """``size`` is the width of the metadata vector appended before fc1."""
        super(Dialated_CNN, self).__init__()
        # (batch, 5, L) -> (batch, 320, L-50); dilation=10 widens the receptive field
        self.conv1 = nn.Conv1d(5, 320, kernel_size=6, dilation=10)
        self.pool1 = nn.MaxPool1d(5)
        # (batch, 320, L') -> (batch, 480, L'-4), then pooled by 5
        self.conv2 = nn.Conv1d(320, 480, kernel_size=5)
        self.pool2 = nn.MaxPool1d(5)
        # fc1 consumes the flattened conv features (480*17, i.e. input length 495)
        # plus `size` metadata columns appended in forward()
        self.fc1 = nn.Linear((480 * 17) + size, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 1)

    def forward(self, x, meta):
        """Run the network.

        x    : (batch, channels=5, length) one-hot encoded sequence window.
        meta : (batch, size) metadata vector, concatenated after flattening.

        Returns a (batch, 1) non-negative prediction (the final layer output
        is passed through ReLU as well).
        """
        features = self.pool1(torch.relu(self.conv1(x)))
        features = self.pool2(torch.relu(self.conv2(features)))
        features = torch.flatten(features, start_dim=1)   # (batch, 480*17)
        combined = torch.cat((features, meta), 1)          # append metadata
        hidden = torch.relu(self.fc1(combined))
        hidden = torch.relu(self.fc2(hidden))
        return torch.relu(self.fc3(hidden))
if __name__ == "__main__":
    # ------------------------------------------------------------------
    # Data: build the STR dataset and split train/validation using the
    # "split" column of the metadata file (0 = train, 1 = validation).
    # One-hot files follow the naming convention samplename_locus.npy.
    # ------------------------------------------------------------------
    ohe_dir = "/nfs/turbo/dcmb-class/bioinf593/groups/group_05/STRonvoli/data/ohe/"
    meta_file = "/nfs/turbo/dcmb-class/bioinf593/groups/group_05/STRonvoli/data/metadata.tsv"
    # FIX: renamed from `dataset` — the old name shadowed the module alias
    # `import STRDataset as dataset`, making the module unusable afterwards.
    str_dataset = dataset.STRDataset(ohe_dir, meta_file)
    meta = pd.read_csv(meta_file, sep='\t')
    train_indices = meta.index[meta["split"] == 0]
    test_indices = meta.index[meta["split"] == 1]
    trainset = Subset(str_dataset, train_indices)
    validset = Subset(str_dataset, test_indices)
    num = 6  # worker count chosen via the sweep utility below
    trainloader = DataLoader(trainset, batch_size=64, shuffle=True, num_workers=num)
    validloader = DataLoader(validset, batch_size=64, shuffle=True, num_workers=num)
    print("Data process ready")
    # Test if the loaders work
    '''
    for (batch_idx, batch) in enumerate(trainloader):
        #print(batch)
        (X, labels, meta) = batch
        print(X.shape)
        print(labels.shape)
        print(meta.shape)
    '''
    # Sweep through num_workers to find the ideal value
    '''
    for num_workers in range(2, os.cpu_count(), 2):
        sub = Subset(str_dataset, train_indices[0:1000])
        train_loader = DataLoader(sub, shuffle=True, num_workers=num_workers, batch_size=64, pin_memory=True)
        start = time.time()
        for i, data in enumerate(train_loader):
            pass
        end = time.time()
        print("Finish with:{} second, num_workers={}".format(end - start, num_workers))
    '''
    # ------------------------------------------------------------------
    # Train
    # ------------------------------------------------------------------
    # Parameters (to change):
    pos_count = 500   # expected number of sequence positions per window (documentation only)
    meta_count = 79   # width of the metadata vector fed to the model
    max_epochs = 10
    learning = 0.001
    net = Dialated_CNN(meta_count).to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(net.parameters(), lr=learning)
    print("Model set up ready")
    losses = []    # mean training loss per epoch
    losses_t = []  # mean validation loss per epoch
    for e in range(max_epochs):
        # ---- training pass -------------------------------------------
        net.train()
        start_time = time.time()
        total_loss = 0.0  # running training loss over this epoch's batches
        for (batch_idx, batch) in enumerate(trainloader):
            (X, labels, meta) = batch
            # Loader yields (batch, length, channels); Conv1d wants
            # (batch, channels, length), so transpose dims 1 and 2.
            X = torch.transpose(X.to(device), 1, 2).float()
            meta = meta.to(device).float()
            labels = labels.to(device).float()
            # (batch,) -> (batch, 1) to match the network's output shape
            labels = torch.reshape(labels, (labels.size(dim=0), 1))
            optimizer.zero_grad()
            output = net(X, meta)  # (batch_size, 1)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()
            # FIX: .item() detaches and converts to a Python float in one
            # step (was loss.cpu().detach().numpy()).
            total_loss += loss.item()
        losses.append(total_loss / len(trainloader))
        print(f'Epoch {e} done')
        # ---- checkpoint ----------------------------------------------
        savedir = "/nfs/turbo/dcmb-class/bioinf593/groups/group_05/STRonvoli/models/dilation/"
        path = os.path.join(savedir, f'epoch-{e}-model-dilated.pth')
        torch.save({
            'epoch': e,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': total_loss}, path)  # saving model
        # ---- validation pass -----------------------------------------
        net.eval()
        loss_l = 0.0  # running validation loss over this epoch's batches
        # FIX: validation must not build autograd graphs — wrap in no_grad()
        # to save memory and time.
        with torch.no_grad():
            for (Y, y, m) in validloader:
                Y = torch.transpose(Y.to(device), 1, 2).float()
                y = y.to(device).float()
                y = torch.reshape(y, (y.size(dim=0), 1))
                m = m.to(device).float()
                out = net(Y, m)
                lossl = criterion(out, y)
                loss_l += lossl.item()  # FIX: .item() instead of legacy .data.item()
        losses_t.append(loss_l / len(validloader))
        print(f'Validation loss after epoch {e} is')
        print(loss_l / len(validloader))
    # ------------------------------------------------------------------
    # Plot per-epoch loss curves
    # ------------------------------------------------------------------
    plt.figure(0)
    plt.plot(losses_t)
    plt.xlabel('Epochs')
    plt.ylabel('Average MSE loss over batches in each epoch')
    plt.title('Validation Loss over epochs for STR prediction using Dilated CNN')
    plt.savefig('Average Validation loss (across all batches) over epochs - Dilated')
    plt.figure(1)
    plt.plot(losses)
    plt.xlabel('Epochs')
    plt.ylabel('Average MSE loss over batches in each epoch')
    plt.title('Training Loss over epochs for STR prediction using Dilated CNN')
    plt.savefig('Average Training loss (across all batches) over epochs - Dilated')