-
Notifications
You must be signed in to change notification settings - Fork 17
/
mvtecad_dataloader.py
85 lines (77 loc) · 3.2 KB
/
mvtecad_dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import json
import cv2
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
# Per-channel mean and standard deviation passed to transforms.Normalize
# by the image pipelines in this module.
# NOTE(review): these look like the widely used ImageNet RGB statistics - confirm.
mean_train = [0.485, 0.456, 0.406]
std_train = [0.229, 0.224, 0.225]
def data_transforms(size):
    """Build the preprocessing pipeline applied to input images.

    The returned ``transforms.Compose`` applies, in order:
    ``Resize((size, size))``, ``ToTensor()``, ``CenterCrop(size)`` and
    ``Normalize`` with the module-level ``mean_train`` / ``std_train``.

    Args:
        size: Side length handed to ``Resize`` (as ``(size, size)``) and to
            ``CenterCrop``.

    Returns:
        The composed torchvision transform.
    """
    steps = [
        transforms.Resize((size, size)),
        transforms.ToTensor(),
        transforms.CenterCrop(size),
        transforms.Normalize(mean=mean_train, std=std_train),
    ]
    return transforms.Compose(steps)
def gt_transforms(size):
    """Build the preprocessing pipeline applied to ground-truth masks.

    The returned ``transforms.Compose`` applies, in order:
    ``Resize((size, size))``, ``CenterCrop(size)`` and ``ToTensor()``.
    Unlike ``data_transforms`` there is no normalization step.

    Args:
        size: Side length handed to ``Resize`` (as ``(size, size)``) and to
            ``CenterCrop``.

    Returns:
        The composed torchvision transform.
    """
    steps = [
        transforms.Resize((size, size)),
        transforms.CenterCrop(size),
        transforms.ToTensor(),
    ]
    return transforms.Compose(steps)
class MVTecDataset(Dataset):
    """Dataset over MVTec-AD samples listed in a JSON-lines metadata file.

    Every non-blank line of the metadata file is parsed as one JSON object.
    The keys read by this class are ``filename``, ``label``, ``clsname`` and,
    optionally, ``maskname``.
    """

    def __init__(self, type, root):
        """Load the metadata records for the requested split.

        Args:
            type: ``'train'`` selects ``./training/MVTec-AD/train.json``; any
                other value selects ``./training/MVTec-AD/test.json``. The
                name shadows the builtin but is kept so existing callers
                (including keyword callers) keep working.
            root: Prefix that is string-concatenated with each record's
                ``filename`` / ``maskname`` to form the path on disk.
        """
        if type == 'train':
            meta_path = './training/MVTec-AD/train.json'
        else:
            meta_path = './training/MVTec-AD/test.json'
        with open(meta_path, 'rt') as f:
            # Skip blank lines so a stray empty line does not crash json.loads.
            self.data = [json.loads(line) for line in f if line.strip()]
        # Class name -> index; values are kept as strings and converted with
        # int() when a sample is built.
        self.label_to_idx = {
            'bottle': '0', 'cable': '1', 'capsule': '2', 'carpet': '3',
            'grid': '4', 'hazelnut': '5', 'leather': '6', 'metal_nut': '7',
            'pill': '8', 'screw': '9', 'tile': '10', 'toothbrush': '11',
            'transistor': '12', 'wood': '13', 'zipper': '14',
        }
        self.image_size = (256, 256)
        self.root = root

    def __len__(self):
        """Return the number of metadata records."""
        return len(self.data)

    @staticmethod
    def _read_image(path, flags):
        """Read ``path`` with OpenCV, raising if the file cannot be decoded."""
        img = cv2.imread(path, flags)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError("Could not read image: " + str(path))
        return img

    def _fallback_mask(self, label):
        """Build a whole-image uint8 mask for a record without a mask file."""
        if label == 0:  # good
            return np.zeros(self.image_size).astype(np.uint8)
        if label == 1:  # defective
            # 255 (not 1) so that ToTensor's /255 scaling yields 1.0, matching
            # the 0/255 convention of masks read from disk.
            return (np.ones(self.image_size) * 255).astype(np.uint8)
        raise ValueError("Labels must be [None, 0, 1]!")

    def _load_image(self, path):
        """Read an image and return it as a resized, normalized RGB tensor."""
        bgr = self._read_image(path, cv2.IMREAD_COLOR)
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        pil_img = Image.fromarray(rgb, "RGB")
        pil_img = transforms.Resize(self.image_size)(pil_img)
        tensor = transforms.ToTensor()(pil_img)
        return transforms.Normalize(mean=mean_train, std=std_train)(tensor)

    def __getitem__(self, idx):
        """Build the sample dict for record ``idx``.

        Returns:
            A dict with keys ``jpg`` and ``hint`` (two independent tensors
            holding the same resized, normalized image), ``txt`` (empty
            prompt string), ``mask`` (single-channel tensor with values in
            [0, 1]), ``filename``, ``clsname`` and ``label`` (integer class
            index looked up from ``clsname``).

        Raises:
            ValueError: The record has no ``maskname`` and its ``label`` is
                neither 0 nor 1.
            FileNotFoundError: The image or mask file could not be read.
            KeyError: ``clsname`` is not present in ``label_to_idx``.
        """
        item = self.data[idx]
        filename = item['filename']
        label = item["label"]

        if item.get("maskname", None):
            mask = self._read_image(self.root + item['maskname'],
                                    cv2.IMREAD_GRAYSCALE)
        else:
            mask = self._fallback_mask(label)

        # Source and target are the same file, so decode it once and clone.
        target = self._load_image(self.root + filename)
        source = target.clone()

        mask = Image.fromarray(mask, "L")
        mask = transforms.Resize(self.image_size)(mask)
        mask = transforms.ToTensor()(mask)

        clsname = item["clsname"]
        image_idx = self.label_to_idx[clsname]
        return dict(
            jpg=target,
            txt="",
            hint=source,
            mask=mask,
            filename=filename,
            clsname=clsname,
            label=int(image_idx),
        )