-
Notifications
You must be signed in to change notification settings - Fork 0
/
predictions.py
82 lines (67 loc) · 3.01 KB
/
predictions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
"""
Utility functions to make predictions.
Main reference for code creation: https://www.learnpytorch.io/06_pytorch_transfer_learning/#6-make-predictions-on-images-from-the-test-set
"""
from typing import Callable, List, Optional, Tuple

import matplotlib.pyplot as plt
import torch
import torchvision
from PIL import Image
from torchvision import transforms
# Choose the default compute device: the CUDA GPU when PyTorch reports one as available, otherwise the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Predict on a target image with a target model
# Function created in: https://www.learnpytorch.io/06_pytorch_transfer_learning/#6-make-predictions-on-images-from-the-test-set
def pred_and_plot_image(
    model: torch.nn.Module,
    class_names: List[str],
    image_path: str,
    image_size: Tuple[int, int] = (224, 224),
    transform: Optional[Callable] = None,
    device: torch.device = device,
):
    """Predicts on a target image with a target model and plots the result.

    The image is opened from disk, transformed into a single-sample batch,
    passed through ``model`` in evaluation/inference mode, and then shown in a
    new matplotlib figure titled with the predicted class name and its softmax
    probability.

    Args:
        model (torch.nn.Module): A trained (or untrained) PyTorch model to predict on an image.
            The model is moved to ``device`` and put in eval mode; both changes persist after the call.
        class_names (List[str]): A list of target classes to map predictions to
            (indexed by the position of the highest-probability model output).
        image_path (str): Filepath to target image to predict on.
        image_size (Tuple[int, int], optional): Size the default transform resizes the image to.
            Only used when ``transform`` is None. Defaults to (224, 224).
        transform (Callable, optional): Transform applied to the opened PIL image to produce the
            model input (before the batch dimension is added). Defaults to None, which uses
            Resize -> ToTensor -> ImageNet normalization; on that default path the image is first
            converted to RGB. A custom transform receives the image in its original mode.
        device (torch.device, optional): Target device to perform prediction on. Defaults to device.

    Returns:
        None. The result is drawn on a new matplotlib figure.
    """
    # Image.open is lazy: force the pixel data to be read while the file is open
    # so the handle can be released here and the image is still usable below.
    with Image.open(image_path) as img:
        img.load()

    # Create transformation for image (if one doesn't exist)
    if transform is not None:
        image_transform = transform
        model_input_image = img
    else:
        image_transform = transforms.Compose(
            [
                transforms.Resize(image_size),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )
        # The default Normalize has three channel statistics, so grayscale,
        # palette, RGBA, CMYK, ... images must become RGB first or it fails.
        model_input_image = img.convert("RGB")

    ### Predict on image ###

    # Make sure the model is on the target device
    model.to(device)

    # Turn on model evaluation mode and inference mode
    model.eval()
    with torch.inference_mode():
        # Transform and add an extra dimension to image
        # (model requires samples in [batch_size, color_channels, height, width])
        transformed_image = image_transform(model_input_image).unsqueeze(dim=0)

        # Make a prediction on the batched image after sending it to the target device
        target_image_pred = model(transformed_image.to(device))

    # Convert logits -> prediction probabilities (using torch.softmax() for multi-class classification)
    target_image_pred_probs = torch.softmax(target_image_pred, dim=1)

    # Convert prediction probabilities -> plain Python label index and probability,
    # so list indexing and string formatting do not depend on implicit tensor conversion
    predicted_index = torch.argmax(target_image_pred_probs, dim=1).item()
    predicted_prob = target_image_pred_probs.max().item()

    # Plot image (as it was opened) with predicted label and probability
    plt.figure()
    plt.imshow(img)
    plt.title(
        f"Pred: {class_names[predicted_index]} | Prob: {predicted_prob:.3f}"
    )
    plt.axis(False)