
Commit ff89072

Author: redone17
Commit message: move to py3
1 parent: 1c934d1 · commit: ff89072

File tree: 2 files changed (+49, -42 lines)


iter_utils.py (+8, -8)
@@ -45,10 +45,10 @@ def train(model, train_loader, criterion, optimizer, init_lr=0.001, decay_epoch=
             if phase == 'train':
                 loss.backward()
                 optimizer.step()
-            running_loss += loss.data[0]
+            running_loss += loss.data
             _, predicted = torch.max(outputs.data, 1)
-            running_corrects += torch.sum(predicted==targets.data)
-            loss_curve.append(loss.data[0])
+            running_corrects += torch.sum(predicted==targets.data).item()
+            loss_curve.append(loss.data)
 
         epoch_loss = running_loss / len(train_loader.dataset)
         epoch_accuracy = running_corrects / len(train_loader.dataset)
@@ -68,10 +68,10 @@ def test(model, test_loader):
     corrects = 0
     model = model.cpu()
     for batch_idx, (inputs, targets) in enumerate(test_loader):
-        inputs, targets = Variable(inputs, volatile=True), Variable(targets)
-        outputs = model(inputs)
-        _, predicted = torch.max(outputs.data, 1)
-        corrects += torch.sum(predicted==targets.data)
+        with torch.no_grad():
+            inputs, targets = Variable(inputs), Variable(targets)
+            outputs = model(inputs)
+            _, predicted = torch.max(outputs.data, 1)
+            corrects += torch.sum(predicted==targets.data).item()
     accuracy = corrects / len(test_loader.dataset)
     return accuracy
-
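Both hunks follow the PyTorch >= 0.4 migration pattern that the Python 3 move forced: a 0-dim loss tensor can no longer be indexed as loss.data[0], volatile=True is deprecated in favor of the torch.no_grad() context, and .item() extracts a plain Python number so that corrects / len(test_loader.dataset) is ordinary true division. A minimal sketch of the fully migrated test loop (hypothetical evaluate name; Variable has been a no-op wrapper since 0.4 and is dropped here, and loss.item() would likewise be a cleaner replacement for loss.data in train):

import torch

def evaluate(model, test_loader):
    model.eval()              # disable dropout / batch-norm updates
    corrects = 0
    with torch.no_grad():     # replaces Variable(..., volatile=True)
        for inputs, targets in test_loader:
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            # .item() converts the 0-dim sum tensor to a Python int,
            # so the final division below is float division in Python 3
            corrects += torch.sum(predicted == targets).item()
    return corrects / len(test_loader.dataset)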

run.py (+41, -34)
@@ -20,38 +20,45 @@
 import torch.utils.data as data_utils
 from models import *
 
-## load data
-data_arr_01 = data_loader.load_data('data/pgb/SF01/vib_data_1.txt')
-# data_arr_03 = data_loader.load_data('data/pgb/SF03/vib_data_1.txt')
-# data_arr_01 = data_loader.resample_arr(data_arr_01, num=240) # add for Ince's model
-# data_arr_03 = data_loader.resample_arr(data_arr_03, num=240) # add for Ince's model
-# data_arr_01, _ = data_loader.fft_arr(data_arr_01) # add for fft wdcnn
-# data_arr_03, _ = data_loader.fft_arr(data_arr_03) # add for fft wdcnn
-# data_arr_01 = data_loader.stft_arr(data_arr_01) # add for stft-LeNet
-# data_arr_03 = data_loader.stft_arr(data_arr_03)
-label_vec = data_loader.load_label('data/pgb/SF01/label_vec.txt')
-
-trainset_01, testset_01 = data_loader.split_set(data_arr_01, label_vec)
-# trainset_03, testset_03 = data_loader.split_set(data_arr_03, label_vec)
-train_loader = data_utils.DataLoader(dataset = trainset_01, batch_size =512 , shuffle = True, num_workers = 2)
-test_loader = data_utils.DataLoader(dataset = testset_01, batch_size = 512, shuffle = True, num_workers = 2)
-print('Number of training samples: {}'.format(len(train_loader.dataset)))
-print('Number of testing samples: {}'.format(len(test_loader.dataset)))
-print( )
-
-## make models
-model = dcnn.Net('DCNN08', 1, 5)
-
-## train
-criterion = nn.CrossEntropyLoss()
-optimizer = optim.Adam(model.parameters(), weight_decay=0.0001)
-best_model, loss_curve = iter_utils.train(model, train_loader, criterion, optimizer,
-                                          init_lr=0.0001, decay_epoch=5, n_epoch=10)
-
-# test
-test_accuracy = iter_utils.test(best_model, test_loader)
-print('Test accuracy: {:.4f}%'.format(100*test_accuracy))
-
-## visualization
-# TODO
+def main():
+    ## load data
+    # data_arr_01 = data_loader.load_data(r'toydata/data.txt')
+    # label_vec = data_loader.load_label(r'toydata/label.txt')
 
+    data_arr_01 = data_loader.load_data(r'data/uestc_pgb/SF01/vib_data_1.txt')
+    # data_arr_03 = data_loader.load_data('data/pgb/SF03/vib_data_1.txt')
+    # data_arr_01 = data_loader.resample_arr(data_arr_01, num=240) # add for Ince's model
+    # data_arr_03 = data_loader.resample_arr(data_arr_03, num=240) # add for Ince's model
+    # data_arr_01, _ = data_loader.fft_arr(data_arr_01) # add for fft wdcnn
+    # data_arr_03, _ = data_loader.fft_arr(data_arr_03) # add for fft wdcnn
+    # data_arr_01 = data_loader.stft_arr(data_arr_01) # add for stft-LeNet
+    # data_arr_03 = data_loader.stft_arr(data_arr_03)
+    label_vec = data_loader.load_label(r'data/uestc_pgb/SF01/label_vec.txt')
+
+    trainset_01, testset_01 = data_loader.split_set(data_arr_01, label_vec)
+    # trainset_03, testset_03 = data_loader.split_set(data_arr_03, label_vec)
+    train_loader = data_utils.DataLoader(dataset = trainset_01, batch_size =512 , shuffle = True, num_workers = 2)
+    test_loader = data_utils.DataLoader(dataset = testset_01, batch_size = 512, shuffle = True, num_workers = 2)
+    print('Number of training samples: {}'.format(len(train_loader.dataset)))
+    print('Number of testing samples: {}'.format(len(test_loader.dataset)))
+    print( )
+
+    ## make models
+    model = wdcnn.Net(1, 5)
+
+    ## train
+    criterion = nn.CrossEntropyLoss()
+    optimizer = optim.Adam(model.parameters(), weight_decay=0.0001)
+    best_model, loss_curve = iter_utils.train(model, train_loader, criterion, optimizer,
+                                              init_lr=0.0001, decay_epoch=5, n_epoch=10, use_cuda=False)
+
+    # test
+    test_accuracy = iter_utils.test(best_model, test_loader)
+    print('Test accuracy: {:.4f}%'.format(100*test_accuracy))
+
+
+    ## visualization
+    # TODO
+
+if __name__ == '__main__':
+    main()
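Moving the script body into main() behind the if __name__ == '__main__' guard is what keeps the num_workers=2 loaders safe under Python 3: with the 'spawn' start method (the default on Windows, and on macOS since Python 3.8) each DataLoader worker re-imports the entry module, so unguarded top-level code would be re-executed in every worker. A runnable sketch of the pattern with a hypothetical stand-in dataset:

import torch
import torch.utils.data as data_utils

def main():
    # toy stand-in for trainset_01 (hypothetical sizes/shapes)
    data = torch.randn(64, 1, 1024)
    labels = torch.randint(0, 5, (64,))
    dataset = data_utils.TensorDataset(data, labels)
    loader = data_utils.DataLoader(dataset, batch_size=16,
                                   shuffle=True, num_workers=2)
    for inputs, targets in loader:
        print(inputs.shape, targets.shape)

if __name__ == '__main__':
    main()  # runs only in the parent process; worker re-imports stay inert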
