-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmy_dataset.py
135 lines (113 loc) · 5 KB
/
my_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import pandas as pd
from PIL import Image
import torch
from torch.utils.data import Dataset
import cv2
import os
import numpy as np
# Base directory under which get_folder_path looks up each case's image folder.
IMAGE_FOLDER = "/public_bme/data/jianght/datas/Pathology"
import os
from PIL import Image
import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from glob import glob
from einops.layers.torch import Rearrange
class MultiDataSet(Dataset):
    """Multi-task dataset yielding a stack of images per case together with its labels.

    Each row of ``data`` describes one case. Column 0 holds the case folder name,
    which ``get_folder_path`` resolves to an image directory; the columns named in
    ``tasks`` hold the per-task labels.

    Args:
        data: A ``pandas.DataFrame``, or a path (``str`` or ``os.PathLike``) to a
            CSV file that is loaded with ``pd.read_csv``.
        transforms: Optional callable applied to every image (an RGB
            ``PIL.Image``). When given it replaces the built-in pipeline and must
            return a tensor; with ``need_patch=True`` its output is split along
            its first dimension. When ``None`` the built-in pipeline is used.
        head_idx: If not ``None``, ``__getitem__`` returns only the label of
            ``tasks[head_idx]`` instead of the whole label dict.
        age: Stored as ``self.age``; not used inside this class.
        img_batch: Maximum number of ``.jpg`` files read per case (largest files
            first).
        tasks: Column name, or list of column names, used as labels. Defaults to
            ``['fungus', 'label']``.
        need_patch: If True, each image is kept at full resolution and cut into
            non-overlapping ``patch_size`` x ``patch_size`` patches (image height
            and width must be multiples of ``patch_size``); otherwise each image
            is resized to 224x224.
        patch_size: Side length of the square patches used when ``need_patch``
            is True.

    Raises:
        FileNotFoundError: If ``data`` is a path that is not an existing file.
        AssertionError: If a task name is not a column of the table.
    """

    def __init__(self, data, transforms=None, head_idx=None, age=False, img_batch=25,
                 tasks=None, need_patch=False, patch_size=256):
        if isinstance(data, (str, os.PathLike)):
            # Fail early with a clear message rather than storing the raw path.
            if not os.path.isfile(data):
                raise FileNotFoundError(f'CSV file not found: {data}')
            data = pd.read_csv(data)
        self.data = data
        # The parameter shadows the torchvision ``transforms`` module only inside
        # __init__; the other methods still resolve ``transforms`` to the module.
        self.transforms = transforms
        self.head_idx = head_idx
        self.age = age
        self.img_batch = img_batch
        self.patch_size = patch_size
        self.need_patch = need_patch
        # A None sentinel avoids sharing one mutable default list across instances.
        if tasks is None:
            tasks = ['fungus', 'label']
        elif isinstance(tasks, str):
            tasks = [tasks]
        self.tasks = list(tasks)
        for task in self.tasks:
            # Explicit raise (same exception type as before) survives ``python -O``.
            if task not in self.data.columns:
                raise AssertionError(f'task names wrong get {task} ---- ')

    def __getitem__(self, index):
        """Return ``(images, labels)`` for row ``index``.

        ``images`` stacks the transformed images of the case, or all of their
        patches when ``need_patch`` is True. ``labels`` is a dict with keys
        ``label_<i>`` (one per task), ``code`` and ``multilabel``. When
        ``head_idx`` is set, only ``labels[f'label_{head_idx}']`` is returned.

        Raises:
            RuntimeError: If the case folder contains no ``.jpg`` files.
            KeyError: If neither a ``highlabel`` nor a ``multilabel`` column exists.
        """
        folder_path = get_folder_path(self.data.iloc[index, 0])
        images_tensor = self._load_images(folder_path)
        labels = self._build_labels(index)
        if self.head_idx is not None:
            return images_tensor, labels[f'label_{self.head_idx}']
        return images_tensor, labels

    def _build_transform(self):
        """Return the per-image transform: the user-supplied one, or the default pipeline."""
        if self.transforms is not None:
            return self.transforms
        steps = [] if self.need_patch else [transforms.Resize((224, 224))]
        steps += [
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
        if self.need_patch:
            # (C, H, W) -> (num_patches, C, patch_size, patch_size)
            steps.append(Rearrange('c (h p1) (w p2) -> (h w) c p1 p2 ',
                                   p1=self.patch_size, p2=self.patch_size))
        return transforms.Compose(steps)

    def _load_images(self, folder_path):
        """Load up to ``img_batch`` ``.jpg`` files from ``folder_path`` as one tensor.

        Files are ordered by size, largest first. Every image is converted to RGB
        so that the 3-channel normalisation always applies (grayscale or CMYK
        JPEGs would otherwise fail in ``Normalize``).
        """
        image_paths = sorted(glob(f'{folder_path}/*.jpg'), key=os.path.getsize,
                             reverse=True)[:self.img_batch]
        if not image_paths:
            # torch.stack([]) would raise an uninformative RuntimeError; keep the
            # exception type but name the empty folder.
            raise RuntimeError(f'no .jpg images found in {folder_path}')
        transform = self._build_transform()  # built once per case, not per image
        images = []
        for image_path in image_paths:
            # ``with`` closes the file handle that Image.open keeps open lazily.
            with Image.open(image_path) as img:
                image = transform(img.convert('RGB'))
            if self.need_patch:
                # Each patch becomes its own entry along the stacking dimension.
                images.extend(image)
            else:
                images.append(image)
        return torch.stack(images)

    def _build_labels(self, index):
        """Collect the label dict (task labels, ``code``, ``multilabel``) for row ``index``."""
        columns = self.data.columns
        label_dict = {
            f'label_{i}': self.data.iloc[index, columns.get_loc(task)]
            for i, task in enumerate(self.tasks)
        }
        if 'code' in columns:
            label_dict['code'] = self.data.iloc[index, columns.get_loc('code')]
        else:
            # Fall back to the last path component of the folder name in column 0.
            label_dict['code'] = self.data.iloc[index, 0].split('/')[-1]
        # 'highlabel' takes precedence over 'multilabel' when both columns exist.
        multilabel_column = 'highlabel' if 'highlabel' in columns else 'multilabel'
        label_dict['multilabel'] = self.data.iloc[index, columns.get_loc(multilabel_column)]
        return label_dict

    def __len__(self):
        """Return the number of rows (cases) in the table."""
        return len(self.data)

    @staticmethod
    def collate_fn(batch):
        """Stack a list of ``(images, label)`` samples into batch tensors.

        All image tensors must share the same shape (``torch.stack``) and every
        label must be convertible by ``torch.as_tensor``, i.e. scalar labels as
        returned when ``head_idx`` is set. Dict labels are not supported here.
        For reference, see the official default_collate implementation:
        https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py
        """
        images, labels = tuple(zip(*batch))
        images = torch.stack(images, dim=0)
        labels = torch.as_tensor(labels)
        return images, labels
def get_folder_path(name, root=None):
    """Resolve a case name to the directory that holds its images.

    Candidate directories are tried in this order and the first existing one wins:

    1. ``<root>/<name>/torch``
    2. ``<root>/yangxing/Torch/<name>/torch``
    3. ``<root>/yinxing/Torch/<name>/torch``
    4. ``<root>/<name>``

    Args:
        name: Case folder name (relative path) as stored in the dataset table.
        root: Base directory to search; defaults to the module-level ``IMAGE_FOLDER``.

    Returns:
        The first candidate path that is an existing directory.

    Raises:
        AssertionError: If none of the candidates exists.
    """
    if root is None:
        root = IMAGE_FOLDER
    candidates = (
        os.path.join(root, name, 'torch'),
        os.path.join(root, 'yangxing', 'Torch', name, 'torch'),
        os.path.join(root, 'yinxing', 'Torch', name, 'torch'),
        os.path.join(root, name),
    )
    for folder_path in candidates:
        if os.path.isdir(folder_path):
            return folder_path
    # Explicit raise instead of ``assert``: under ``python -O`` the assert was
    # stripped and an empty path was returned silently.
    raise AssertionError(f'get wrong name {name}')
if __name__ == '__main__':
    # Manual smoke test: iterate the supplementary test split and print the
    # image-batch shape and labels of every batch.
    dataset = MultiDataSet(data='/public_bme/data/jianght/datas/Pathology/class2/test_supplement3.csv')
    loader = torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=4)
    for batch_images, batch_labels in loader:
        print(batch_images.size(), batch_labels)