-
Notifications
You must be signed in to change notification settings - Fork 0
/
plot.py
89 lines (72 loc) · 2.71 KB
/
plot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import matplotlib.pyplot as plt
import numpy as np
import json
def plot_predictions(y_train, pred_train, y_valid,
                     pred_valid, model_name, epochs,
                     ylim=(9, 16), n_train_points=500):
    '''Plot model predictions against ground truth and save two figures.

    Two PNG files are written under ``fig/`` (the directory must already
    exist):

    * ``fig/predictions_{model_name}_epochs{epochs}.png`` -- side-by-side
      point plots of targets (black) and predictions (red) for the
      validation set and for the first ``n_train_points`` training samples.
    * ``fig/predictions_scatter_{model_name}_epochs{epochs}.png`` --
      predicted vs. measured values for both sets in one scatter plot.

    Interactive mode is switched on (``plt.ion()``) and the figures are
    left open; callers producing many plots may want to close them.

    Parameters
    ----------
    y_train, pred_train : array-like
        Training-set targets and predictions (must support slicing).
    y_valid, pred_valid : array-like
        Validation-set targets and predictions.
    model_name : str
        Used in the figure titles and output file names.
    epochs : int or str
        Only used in the output file names.
    ylim : (float, float), optional
        Value range for the y-axis of every plot and the x-axis of the
        scatter plot. Defaults to (9, 16), the range previously hard-coded.
    n_train_points : int, optional
        Number of leading training samples drawn in the training panel
        (the scatter plot still uses the full training set). Default 500.
    '''
    plt.ion()

    # Figure 1: index-vs-value plots, validation on the left, training on
    # the right (training is truncated to keep the panel readable).
    fig = plt.figure(figsize=(20, 5))
    fig.suptitle(f'{model_name} - predictions')
    panels = (
        (121, y_valid, pred_valid, 'Validation set'),
        (122, y_train[:n_train_points], pred_train[:n_train_points],
         'Training set'),
    )
    for position, y_true, y_pred, panel_title in panels:
        ax = fig.add_subplot(position)
        ax.plot(y_true, '.k', label='y')
        ax.plot(y_pred, '.r', label='pred')
        ax.set_ylim(ylim)
        ax.set_title(panel_title)
        # The series carry labels, so actually display them.
        ax.legend()
    fig.savefig(f'fig/predictions_{model_name}_epochs{epochs}.png')

    # Figure 2: predicted vs. measured for both sets; identical x/y limits
    # so a perfect model would lie on the diagonal.
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.scatter(y_train, pred_train, marker='.', label='training set')
    ax.scatter(y_valid, pred_valid, marker='.', label='validation set')
    ax.legend()
    ax.set_ylim(ylim)
    ax.set_xlim(ylim)
    ax.set_ylabel('Predicted affinity')
    ax.set_xlabel('Measured affinity (ground truth)')
    ax.set_title(f'{model_name} - predictions vs ground truth')
    fig.savefig(f'fig/predictions_scatter_{model_name}_epochs{epochs}.png')
def plot_losses(history, model_name, ax=None, ylim=(0, 1)):
    '''Plot training and validation loss curves for one training run.

    Parameters
    ----------
    history : dict
        Must contain ``'train_loss'`` and ``'valid_loss'`` sequences with
        one value per epoch. The epoch axis is built from ``train_loss``,
        so both are presumably the same length.
    model_name : str
        Used in the title and in the output file name.
    ax : matplotlib Axes, optional
        Axes to draw into. When omitted, a new 7x5 figure is created and
        saved to ``fig/training_history_{model_name}_epochs{n}.png``
        (``n`` = number of epochs; ``fig/`` must exist). When given,
        nothing is saved.
    ylim : (float, float), optional
        y-axis range. Defaults to (0, 1), the range previously hard-coded.

    Raises
    ------
    ValueError
        If ``history['train_loss']`` or ``history['valid_loss']`` is empty.
    '''
    train_loss = history['train_loss']
    valid_loss = history['valid_loss']
    # len() rather than truthiness so numpy arrays are accepted too.
    if len(train_loss) == 0 or len(valid_loss) == 0:
        raise ValueError("history must contain at least one epoch of "
                         "'train_loss' and 'valid_loss'")

    # Only save when we own the figure; a caller-supplied Axes belongs to
    # a figure the caller manages.
    save = ax is None
    if save:
        fig = plt.figure(figsize=(7, 5))
        ax = fig.add_subplot(111)

    n_epochs = len(train_loss)
    epochs = range(1, n_epochs + 1)
    ax.plot(epochs, train_loss, label='train')
    ax.plot(epochs, valid_loss, label='valid')
    ax.legend()
    # Annotation in data coordinates at 55% of the epoch count and 70% of
    # the y-range; with the default ylim this is y=0.7 as before.
    text_y = ylim[0] + 0.7 * (ylim[1] - ylim[0])
    ax.text(n_epochs * .55, text_y,
            f"min validation loss: {np.min(valid_loss):.4}")
    ax.set_ylim(ylim)
    ax.set_ylabel('loss (MSE)')
    ax.set_xlabel('epoch')
    ax.set_title(f'{model_name} - training and validation loss')
    if save:
        fig.savefig(f'fig/training_history_{model_name}_'
                    f'epochs{n_epochs}.png')
def plot_loss_comparison(fnames, names, title, baseline=0.71, ylim=(0, 1)):
    '''Overlay the loss histories of several runs in one figure and save it.

    Each JSON file must contain ``'train_loss'`` and ``'valid_loss'``
    lists. Training loss is drawn dotted and validation loss solid, in the
    same colour for each run. The figure is saved to
    ``fig/history_comparision_{title}.png`` (``fig/`` must exist).

    Parameters
    ----------
    fnames : iterable of str or path-like
        JSON history files, one per run.
    names : sequence of str
        Legend label for each file, in the same order. Must have at least
        as many entries as ``fnames``; extra names are ignored.
    title : str
        Figure title; also part of the output file name.
    baseline : float or None, optional
        Loss value drawn as a horizontal black reference line (labelled
        "Predicting the dataset average") spanning the longest history.
        Defaults to 0.71, the value previously hard-coded; pass None to
        omit the line.
    ylim : (float, float), optional
        y-axis range. Defaults to (0, 1), the range previously hard-coded.

    Raises
    ------
    ValueError
        If fewer names than files are given.
    '''
    # Materialize so generators (e.g. Path.glob) work and len() is defined.
    fnames = list(fnames)
    if len(names) < len(fnames):
        raise ValueError(f'got {len(fnames)} history files but only '
                         f'{len(names)} names')

    fig = plt.figure(figsize=(7, 5))
    ax = fig.add_subplot(111)
    max_epochs = 0
    for fname, name in zip(fnames, names):
        # Context manager so each history file is closed promptly.
        with open(fname, encoding='utf-8') as fh:
            history = json.load(fh)
        n_epochs = len(history['train_loss'])
        max_epochs = max(max_epochs, n_epochs)
        epochs = range(1, n_epochs + 1)
        ax.plot(epochs, history['train_loss'], ls=':',
                label=f'{name} (train)')
        # Reuse the colour matplotlib just assigned to the train curve so
        # both curves of one run are visually paired.
        color = ax.get_lines()[-1].get_color()
        ax.plot(epochs, history['valid_loss'], color=color, ls='-',
                label=f'{name} (valid)')

    # Span the longest run (not just the last one); skip when nothing was
    # plotted, which previously crashed on an undefined/empty `epochs`.
    if baseline is not None and max_epochs > 0:
        ax.plot([1, max_epochs], [baseline, baseline], 'k',
                label='Predicting the dataset average')
    ax.set_ylim(ylim)
    ax.legend()
    ax.set_title(title)
    ax.set_ylabel('loss (MSE)')
    ax.set_xlabel('Epoch')
    # NOTE: 'comparision' (sic) is kept so existing output paths are stable.
    fig.savefig(f'fig/history_comparision_{title}.png')