-
Notifications
You must be signed in to change notification settings - Fork 0
/
util.py
101 lines (84 loc) · 2.89 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import logging
import random
from typing import cast
import torch
import torch.utils.data
from PIL.Image import Image
from torch import nn
from torchvision import models, transforms
from torchvision.models import EfficientNet_V2_M_Weights
from torchvision.transforms import _functional_pil as F_pil
from torchvision.transforms import functional as F
# Accepted values for create_model's ``task_type`` argument.
TASK_CLASSIFICATION: str = "classification"  # one output logit per class
TASK_ORDINAL_REGRESSION: str = "ordinal-regression"  # num_classes - 1 outputs
# Type alias for a (PIL image, int label) pair — presumably the item type
# yielded by a torchvision DatasetFolder/ImageFolder; TODO confirm at call sites.
DatasetFolderItem = tuple[Image, int]
def create_model(
    device: torch.device,
    num_classes: int,
    task_type: str,
    freeze_pretrained_layers: bool = True,
) -> nn.Module:
    """Build a pretrained EfficientNetV2-M with a freshly initialized MLP head.

    The torchvision classifier is replaced by an MLP
    (in_features -> 512 -> 512 -> 128 -> outputs) with ReLU activations and
    dropout (p=0.3) in front of every Linear layer.

    Args:
        device: Device the returned model is moved to.
        num_classes: Number of target classes.
        task_type: ``TASK_CLASSIFICATION`` (``num_classes`` outputs) or
            ``TASK_ORDINAL_REGRESSION`` (``num_classes - 1`` outputs).
        freeze_pretrained_layers: If True, gradients are disabled for the whole
            pretrained network, so only the new head is trainable.

    Returns:
        The model, moved to ``device``.

    Raises:
        ValueError: If ``task_type`` is not one of the supported constants, or
            if ``num_classes`` would leave the head with fewer than one output.
    """
    # Validate before loading the pretrained weights so bad arguments fail fast
    # instead of silently producing a head of the wrong width.
    if task_type == TASK_ORDINAL_REGRESSION:
        final_output_channels = num_classes - 1
    elif task_type == TASK_CLASSIFICATION:
        final_output_channels = num_classes
    else:
        raise ValueError(
            f"Unknown task_type {task_type!r}; expected "
            f"{TASK_CLASSIFICATION!r} or {TASK_ORDINAL_REGRESSION!r}."
        )
    if final_output_channels < 1:
        raise ValueError(
            f"num_classes={num_classes} gives {final_output_channels} outputs "
            f"for task_type {task_type!r}; at least 1 is required."
        )

    model = models.efficientnet_v2_m(weights=EfficientNet_V2_M_Weights.DEFAULT)
    if freeze_pretrained_layers:
        # NOTE(review): this only disables gradients; BatchNorm running
        # statistics still update while the model is in train mode — confirm
        # that is intended.
        model.requires_grad_(False)

    # Index 1 of the stock classifier is its Linear layer; reuse its input width.
    lastconv_output_channels = cast(nn.Linear, model.classifier[1]).in_features
    model.classifier = nn.Sequential(
        nn.Dropout(0.3, inplace=True),
        nn.Linear(lastconv_output_channels, 512),
        nn.ReLU(),
        nn.Dropout(0.3),
        nn.Linear(512, 512),
        nn.ReLU(),
        nn.Dropout(0.3),
        nn.Linear(512, 128),
        nn.ReLU(),
        nn.Dropout(0.3),
        nn.Linear(128, final_output_channels),
    ).requires_grad_(True)  # the new head is always trainable
    return model.to(device)
def get_device() -> torch.device:
    """Return ``cuda:0`` if CUDA is available, otherwise the CPU device.

    The selected device is logged at INFO level.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Lazy %-style args: formatting is skipped when INFO logging is disabled.
    logging.info("Using device %s.", device)
    return device
def get_train_transform(resize_to: int) -> transforms.Compose:
    """Return the training-time augmentation and preprocessing pipeline.

    The steps are:
    1. random horizontal flip;
    2. random center-crop of the long side (``random_crop_long_side``);
    3. zero-padding to a square (``pad_to_square``);
    4. resize to exactly ``resize_to`` x ``resize_to``;
    5. tensor conversion;
    6. normalization with the standard ImageNet mean/std.

    Args:
        resize_to: Side length, in pixels, of the square output image.
    """
    return transforms.Compose(
        [
            transforms.RandomHorizontalFlip(),
            transforms.Lambda(random_crop_long_side),
            transforms.Lambda(pad_to_square),
            # Explicit (h, w), as in the validation pipeline: an int would only
            # fix the shorter edge, so any input that is not exactly square
            # (e.g. off by one pixel after odd-sized padding) would produce a
            # different output size and break batch collation.
            transforms.Resize((resize_to, resize_to)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
            ),
        ]
    )
def get_val_transform(resize_to: int) -> transforms.Compose:
    """Return the deterministic evaluation preprocessing pipeline.

    Each image is zero-padded to a square, resized to ``resize_to`` x
    ``resize_to``, converted to a tensor and normalized with the standard
    ImageNet channel statistics.

    Args:
        resize_to: Side length, in pixels, of the square output image.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    square_size = (resize_to, resize_to)

    steps = [
        transforms.Lambda(pad_to_square),
        transforms.Resize(square_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
    return transforms.Compose(steps)
def pad_to_square(img: Image) -> Image:
    """Zero-pad ``img`` along its shorter side so the result is exactly square.

    The original content stays centered. When the size difference is odd, the
    extra pixel of padding goes on the right (or bottom) edge.

    Args:
        img: PIL image to pad.

    Returns:
        ``img`` itself if it is already square; otherwise a padded image whose
        width and height both equal ``max(img.size)``.
    """
    w, h = img.size
    if w == h:
        return img
    diff = abs(w - h)
    before = diff // 2
    after = diff - before  # absorbs the odd pixel so the output is truly square
    if w < h:
        padding = [before, 0, after, 0]  # left, top, right, bottom
    else:
        padding = [0, before, 0, after]
    return F_pil.pad(img, padding, fill=0)
def random_crop_long_side(img: Image, min_scale: float = 0.9) -> Image:
    """Center-crop the longer side of ``img`` by a random factor.

    The kept fraction is drawn from ``random.triangular(min_scale, 1, 1)``, so
    values near 1 (little or no cropping) are the most likely. Only the longer
    dimension is cropped (the width when the image is square); the other
    dimension is kept as is.

    Args:
        img: PIL image to crop.
        min_scale: Smallest fraction of the long side to keep, in (0, 1].
            Defaults to 0.9, i.e. at most 10% is removed.

    Returns:
        The center-cropped image; the cropped dimension is at least 1 pixel.

    Raises:
        ValueError: If ``min_scale`` is not in (0, 1].
    """
    if not 0.0 < min_scale <= 1.0:
        raise ValueError(f"min_scale must be in (0, 1], got {min_scale!r}")
    w, h = img.size
    scale = random.triangular(min_scale, 1, 1)
    # max(1, ...) keeps tiny images from being cropped to zero size.
    if w < h:
        new_w, new_h = w, max(1, int(h * scale))
    else:
        new_w, new_h = max(1, int(w * scale)), h
    # F.center_crop is annotated for tensors but also accepts PIL images.
    return F.center_crop(img, (new_h, new_w))  # type: ignore