
Commit 12c64e2: first commit

0 parents, commit 12c64e2
18 files changed: +420 lines, -0 lines

README.md

+1
# pytorch-image-classification

eval.py

+80
import sys
import numpy as np
import torch
import torchvision
from torchvision import datasets, models, transforms
import torch.utils.data as data
import multiprocessing
from sklearn.metrics import confusion_matrix

# Paths for image directory and model
EVAL_DIR = sys.argv[1]
EVAL_MODEL = 'models/mobilenetv2.pth'

# Load the model for evaluation
model = torch.load(EVAL_MODEL)
model.eval()

# Configure batch size and number of CPUs
num_cpu = multiprocessing.cpu_count()
bs = 8

# Prepare the eval data loader
eval_transform = transforms.Compose([
    transforms.Resize(size=256),
    transforms.CenterCrop(size=224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])])

eval_dataset = datasets.ImageFolder(root=EVAL_DIR, transform=eval_transform)
eval_loader = data.DataLoader(eval_dataset, batch_size=bs, shuffle=True,
                              num_workers=num_cpu, pin_memory=True)

# Enable gpu mode, if cuda available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Number of classes and dataset size
num_classes = len(eval_dataset.classes)
dsize = len(eval_dataset)

# Initialize the prediction and label lists
predlist = torch.zeros(0, dtype=torch.long, device='cpu')
lbllist = torch.zeros(0, dtype=torch.long, device='cpu')

# Evaluate the model accuracy on the dataset
correct = 0
total = 0
with torch.no_grad():
    for images, labels in eval_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)

        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        predlist = torch.cat([predlist, predicted.view(-1).cpu()])
        lbllist = torch.cat([lbllist, labels.view(-1).cpu()])

# Overall accuracy
overall_accuracy = 100 * correct / total
print('Accuracy of the network on the {:d} test images: {:.2f}%'.format(dsize,
    overall_accuracy))

# Confusion matrix
conf_mat = confusion_matrix(lbllist.numpy(), predlist.numpy())
print('Confusion Matrix')
print('-' * 16)
print(conf_mat, '\n')

# Per-class accuracy
class_accuracy = 100 * conf_mat.diagonal() / conf_mat.sum(1)
print('Per class accuracy')
print('-' * 18)
for label, accuracy in zip(eval_dataset.classes, class_accuracy):
    print('Accuracy of %3s : %0.2f %%' % (label, accuracy))

'''
Sample run: python eval.py data/eval
'''
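Note: eval.py and test.py call torch.load() on the .pth files below and use the result directly as a model, which implies the checkpoints were saved as whole models rather than state_dicts. A rough sketch of how such a checkpoint could be produced (the MobileNetV2 backbone, the pretrained flag, and the 11-class head are assumptions; the training script is not part of this commit):

# Sketch (not part of this commit): saving a fine-tuned torchvision model
# as a whole-model checkpoint that torch.load() can restore directly.
import torch
from torchvision import models

model = models.mobilenet_v2(pretrained=True)                    # backbone assumed
model.classifier[1] = torch.nn.Linear(model.last_channel, 11)   # 11 classes assumed from test.py
torch.save(model, 'models/mobilenetv2.pth')                     # whole model, not just a state_dict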

models/mobilenetv2.pth

8.74 MB
Binary file not shown.

models/resnet18.pth

42.7 MB
Binary file not shown.

nets.py

+51
import torch
import torch.nn as nn

def vgg_block_single(in_ch, out_ch, kernel_size=3, padding=1):
    # Single conv layer followed by ReLU and 2x2 max-pooling
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, padding=padding),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2)
    )

def vgg_block_double(in_ch, out_ch, kernel_size=3, padding=1):
    # Two conv layers, each followed by ReLU, then 2x2 max-pooling
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, padding=padding),
        nn.ReLU(),
        nn.Conv2d(out_ch, out_ch, kernel_size=kernel_size, padding=padding),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2)
    )


class MyVGG11(nn.Module):
    # VGG11-style network: five conv blocks followed by three fully connected layers
    def __init__(self, in_ch, num_classes):
        super().__init__()

        self.conv_block1 = vgg_block_single(in_ch, 64)
        self.conv_block2 = vgg_block_single(64, 128)

        self.conv_block3 = vgg_block_double(128, 256)
        self.conv_block4 = vgg_block_double(256, 512)
        self.conv_block5 = vgg_block_double(512, 512)

        self.fc_layers = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096), nn.ReLU(inplace=True), nn.Dropout(),
            nn.Linear(4096, 4096), nn.ReLU(inplace=True), nn.Dropout(),
            nn.Linear(4096, num_classes)
        )

    def forward(self, x):
        x = self.conv_block1(x)
        x = self.conv_block2(x)

        x = self.conv_block3(x)
        x = self.conv_block4(x)
        x = self.conv_block5(x)

        # Flatten conv features (512 x 7 x 7 for a 224x224 input) before the classifier
        x = x.view(x.size(0), -1)

        x = self.fc_layers(x)

        return x
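The classifier in MyVGG11 assumes the convolutional stack flattens to 512 * 7 * 7 features, which holds for 224x224 inputs (five 2x2 poolings take 224 down to 7). A quick shape check, as a sketch assuming this file is importable as nets and reusing the 11-class count from test.py:

# Sanity check (not part of this commit): a 224x224 RGB batch should yield num_classes logits.
import torch
from nets import MyVGG11

net = MyVGG11(in_ch=3, num_classes=11)   # 11 classes assumed from test.py's label list
x = torch.randn(2, 3, 224, 224)          # two random 224x224 RGB images
print(net(x).shape)                      # expected: torch.Size([2, 11])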

results/mobilenetv2.png

75.1 KB

results/resnet18.png

86 KB

test.py

+54
import numpy as np
import sys, random
import torch
from torchvision import models, transforms
from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt

# Paths for image directory and model
IMDIR = sys.argv[1]
MODEL = 'models/resnet18.pth'

# Load the model for testing
model = torch.load(MODEL)
model.eval()

# Class labels for prediction
class_names = ['apple', 'atm card', 'cat', 'banana', 'bangle', 'battery', 'bottle', 'broom', 'bulb', 'calender', 'camera']

# Retrieve 9 random images from directory
files = Path(IMDIR).resolve().glob('*.*')
images = random.sample(list(files), 9)

# Configure plots
fig = plt.figure(figsize=(9, 9))
rows, cols = 3, 3

# Preprocessing transformations
preprocess = transforms.Compose([
    transforms.Resize(size=256),
    transforms.CenterCrop(size=224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

# Enable gpu mode, if cuda available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Perform prediction and plot results
with torch.no_grad():
    for num, img in enumerate(images):
        img = Image.open(img).convert('RGB')
        inputs = preprocess(img).unsqueeze(0).to(device)
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        label = class_names[preds]
        plt.subplot(rows, cols, num + 1)
        plt.title("Pred: " + label)
        plt.axis('off')
        plt.imshow(img)

'''
Sample run: python test.py test
'''
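test.py fills a 3x3 grid of predictions but never displays or saves the figure; the PNGs under results/ look like exactly such grids. One plausible way to finish the script, appended after the plotting loop (the output path is an assumption, not taken from this commit):

# Hypothetical lines appended after the loop in test.py: show the grid and write it to disk.
plt.tight_layout()
plt.savefig('results/predictions.png', dpi=150)   # output filename assumed
plt.show()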

test/apple.jpeg

5.47 KB

test/atm_card.jpeg

11 KB

test/banana.jpeg

5.72 KB

test/bangle.jpeg

5.19 KB

test/bottle.jpeg

3.5 KB

test/bulb.jpeg

4.6 KB

test/calender.jpeg

13.2 KB

test/camera.jpeg

6.71 KB

test/cat.jpeg

5.16 KB
