-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: train_custom_model.py
117 lines (100 loc) · 4.11 KB
/
train_custom_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import torchvision
import argparse
import os
import torch
from torch import nn, optim
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss, Precision, Recall
from tqdm import tqdm
from model import Model
from data_loader import load_data
import json
# Run-time environment and default hyper-parameters for this training script.
#
# `setdefault` (instead of plain assignment) keeps these as defaults only, so a
# caller can still choose e.g. the GPU from the shell:
#     CUDA_VISIBLE_DEVICES=0 python train_custom_model.py
# Unconditional assignment silently overwrote whatever the user had exported.
#
# NOTE(review): `torch` is imported above this point, so whether its OpenMP
# thread pool still honours an OMP_NUM_THREADS set here depends on when the
# runtime reads the variable; if one thread is required, consider
# torch.set_num_threads(1) or setting this before importing torch. Confirm.
os.environ.setdefault('OMP_NUM_THREADS', '1')
# Default to GPU index 3 when the caller has not selected a device.
# Presumably effective because no CUDA call has been made yet at import time — verify.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '3')

NUM_EPOCH = 5       # default for the --epoch option
BATCH_SIZE = 50     # default for the --batch option
log_interval = 10   # iterations between progress-bar updates

# Per-epoch metric histories, filled by the epoch-completed handlers and
# written to training.json at the end of a training run.
train_accuracy = []
val_accuracy = []
train_loss = []
val_loss = []
def _parse_args():
    """Parse and validate the command-line options.

    Returns:
        argparse.Namespace with ``test_only``, ``test_file``, ``num_epochs``
        and ``batch_size``.
    """
    parser = argparse.ArgumentParser(description='Binary classifier')
    parser.add_argument('--test', dest='test_only', action='store_true', help='test model', default=False)
    parser.add_argument('--file', dest='test_file', help='test model file')
    parser.add_argument('--epoch', dest='num_epochs', help='number of epochs', type=int, default=NUM_EPOCH)
    parser.add_argument('--batch', dest='batch_size', type=int, help='batch size', default=BATCH_SIZE)
    args = parser.parse_args()
    if args.test_only and args.test_file is None:
        parser.error("--test requires --file")
    return args


def _record_epoch_metrics(evaluator, loader, accuracy_history, loss_history, label, epoch):
    """Evaluate on *loader*, append accuracy/loss to the histories and log one line.

    Args:
        evaluator: ignite evaluator configured with 'accuracy' and 'loss' metrics.
        loader: data loader to evaluate on.
        accuracy_history: list the epoch's accuracy is appended to.
        loss_history: list the epoch's loss is appended to.
        label: "Training" or "Validation"; prefixes the log line.
        epoch: epoch number shown in the log line.
    """
    evaluator.run(loader)
    metrics = evaluator.state.metrics
    avg_accuracy = metrics['accuracy']
    avg_loss = metrics['loss']
    accuracy_history.append(avg_accuracy)
    loss_history.append(avg_loss)
    tqdm.write("{} Results - Epoch: {} Avg accuracy: {:.2f} Avg loss: {:.2f}"
               .format(label, epoch, avg_accuracy, avg_loss))


def _train(model, criterion, optimizer, device, train_loader, val_loader, num_epochs):
    """Train *model* for *num_epochs*, logging train and validation metrics each epoch.

    Results are appended to the module-level train_/val_ accuracy and loss lists.
    """
    trainer = create_supervised_trainer(model, optimizer, criterion, device=device)
    evaluator = create_supervised_evaluator(
        model,
        metrics={'accuracy': Accuracy(), 'loss': Loss(criterion)},
        device=device,
    )

    desc = "ITERATION - loss: {:.2f}"
    pbar = tqdm(initial=0, leave=False, total=len(train_loader), desc=desc.format(0))

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        # engine.state.iteration keeps counting across epochs; map it back to a
        # 1-based position inside the current epoch.
        iteration_in_epoch = (engine.state.iteration - 1) % len(train_loader) + 1
        if iteration_in_epoch % log_interval == 0:
            pbar.desc = desc.format(engine.state.output)
            pbar.update(log_interval)

    # The two EPOCH_COMPLETED handlers are registered in this order on purpose:
    # training metrics are logged first, then validation resets the progress bar.
    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        pbar.refresh()
        _record_epoch_metrics(evaluator, train_loader, train_accuracy, train_loss,
                              "Training", engine.state.epoch)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        _record_epoch_metrics(evaluator, val_loader, val_accuracy, val_loss,
                              "Validation", engine.state.epoch)
        # Rewind the progress bar for the next epoch.
        pbar.n = pbar.last_print_n = 0

    try:
        trainer.run(train_loader, max_epochs=num_epochs)
    finally:
        # Close the bar even if training raises, so the terminal is left clean.
        pbar.close()


def _test(model, criterion, device, test_loader):
    """Evaluate *model* on *test_loader*, print the metrics and return (accuracy, loss)."""
    # NOTE(review): Precision is averaged (average=True) but Recall uses
    # average=False, which per the ignite docs yields per-class values —
    # confirm this asymmetry is intended.
    tester = create_supervised_evaluator(
        model,
        metrics={
            'accuracy': Accuracy(),
            'loss': Loss(criterion),
            'pre': Precision(average=True),
            'recall': Recall(average=False),
        },
        device=device,
    )
    tester.run(test_loader)
    metrics = tester.state.metrics
    test_accuracy = metrics['accuracy']
    test_loss = metrics['loss']
    print("Precision", metrics['pre'])
    print("Recall", metrics['recall'])
    print("Test Results - Avg accuracy: {:.2f} Avg loss: {:.2f}".format(test_accuracy, test_loss))
    return test_accuracy, test_loss


def main():
    """Train the model (or load it with ``--test --file PATH``) and report test metrics.

    After a training run, the per-epoch train/validation histories and the final
    test metrics are written to ``training.json`` in the working directory.
    """
    args = _parse_args()

    model = Model()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999))
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    train_loader, val_loader, test_loader = load_data(batch_size=args.batch_size)
    model.to(device)

    if args.test_only:
        # Previously --test/--file were validated and then ignored, so the
        # script always retrained. Now the given weights are evaluated instead.
        # NOTE(review): assumes --file holds a state_dict saved with
        # torch.save(model.state_dict(), path) — confirm.
        model.load_state_dict(torch.load(args.test_file, map_location=device))
    else:
        _train(model, criterion, optimizer, device, train_loader, val_loader, args.num_epochs)

    test_accuracy, test_loss = _test(model, criterion, device, test_loader)

    if args.test_only:
        # Nothing was trained; don't overwrite a previous run's training.json
        # with empty histories.
        return

    stats = {
        'train_accuracy': train_accuracy,
        'train_loss': train_loss,
        'val_accuracy': val_accuracy,
        'val_loss': val_loss,
        'test_accuracy': test_accuracy,
        'test_loss': test_loss,
    }
    with open('training.json', 'w') as json_f:
        json.dump(stats, json_f)


if __name__ == '__main__':
    main()