-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataset.py
125 lines (100 loc) · 4.85 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
from torch.utils.data import Dataset
import os
from PIL import Image
from torchvision import transforms
import numpy as np
import torch
# class RSDataset(Dataset):
# def __init__(self, class_name, mode=None, img_transform=img_transform, mask_transform=mask_transform, sync_transforms=None):
# # 数据相关
# self.class_names = class_name
# self.mode = mode
# self.img_transform = img_transform
# self.mask_transform = mask_transform
# self.sync_transform = sync_transforms
# self.sync_img_mask = []
# key_word = 'patches'
# if mode == "src": # for gid
# img_dir = os.path.join(root, 'rgb')
# mask_dir = os.path.join(root, 'label')
# else: # for whu-opt-sar
# img_dir = os.path.join(root, 'rgb')
# mask_dir = os.path.join(root, 'label')
# for img_filename in os.listdir(img_dir):
# img_mask_pair = (os.path.join(img_dir, img_filename),
# os.path.join(mask_dir,
# img_filename.replace("MSS1.jpg", "MSS1_label.png").replace("MSS2.jpg", "MSS2_label.png")))
# self.sync_img_mask.append(img_mask_pair)
# print(self.sync_img_mask)
# if (len(self.sync_img_mask)) == 0:
# print("Found 0 data, please check your dataset!")
# def __getitem__(self, index):
# img_path, mask_path = self.sync_img_mask[index]
# img = Image.open(img_path).convert('RGB')
# mask = Image.open(mask_path).convert('L')
# # transform
# if self.sync_transform is not None:
# img, mask = self.sync_transform(img, mask)
# if self.img_transform is not None:
# img = self.img_transform(img)
# if self.mask_transform is not None:
# mask = self.mask_transform(mask)
# return img, mask
# def __len__(self):
# return len(self.sync_img_mask)
# def classes(self):
# return self.class_names
class MaskToTensor:
    """Convert a label image into an int64 tensor, leaving pixel values unscaled.

    Unlike ``transforms.ToTensor`` no division by 255 (or any other
    normalization) happens, so label values survive unchanged — presumably
    they are class indices consumed by a loss function.
    """

    def __call__(self, img):
        # Go through int32 first, then widen to int64 ("long").
        label_array = np.array(img, dtype=np.int32)
        label_tensor = torch.from_numpy(label_array)
        return label_tensor.to(torch.int64)
# Default per-modality transforms (used as constructor defaults by the
# dataset class in this module).

# Optical image: ToTensor only. Converts HWC -> CHW float tensor; 8-bit
# inputs are scaled to [0, 1]. No mean/std normalization is applied.
img_opt_transform = transforms.Compose(
    [
        transforms.ToTensor(),
    ]
)

# ToTensor followed by per-channel (x - mean) / std using the ImageNet RGB
# channel statistics. NOTE: this does NOT map to [-1, 1]; for inputs in
# [0, 1] the output spans roughly [-2.12, 2.64].
img_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[.485, .456, .406], std=[.229, .224, .225]),
    ]
)

# SAR image: ToTensor only (same conversion as the optical default).
img_sar_transform = transforms.Compose(
    [
        transforms.ToTensor(),
    ]
)

# Label image: integer tensor with raw pixel values preserved.
mask_transform = MaskToTensor()
class WHUOPTSARDataset(Dataset):
    """Paired SAR / optical / label dataset stored as ``root/{sar,opt,lbl}``.

    Every file name listed in ``root/sar`` forms one sample together with the
    file of the same name in ``root/opt`` and ``root/lbl``. No existence
    check is done for the counterparts; it is assumed that the three folders
    hold matching file names and only image files — TODO confirm for new data.

    Args:
        class_name: Class names; returned unchanged by :meth:`classes`.
        root: Dataset root directory containing ``sar``, ``opt`` and ``lbl``.
        mode: Stored as ``self.mode``; not read anywhere in this class.
        img_sar_transform: Callable applied to the SAR image, or ``None``.
        img_opt_transform: Callable applied to the optical image, or ``None``.
        mask_transform: Callable applied to the label image, or ``None``.
        sync_transforms: Optional callable
            ``(sar, opt, mask) -> (sar, opt, mask)`` applied to all three
            images before the per-image transforms (presumably for
            augmentations that must stay spatially aligned).
    """

    def __init__(self, class_name, root, mode=None,
                 img_sar_transform=img_sar_transform,
                 img_opt_transform=img_opt_transform,
                 mask_transform=mask_transform,
                 sync_transforms=None):
        self.class_names = class_name
        self.mode = mode
        self.img_sar_transform = img_sar_transform
        self.img_opt_transform = img_opt_transform
        self.mask_transform = mask_transform
        self.sync_transform = sync_transforms

        img_sar_dir = os.path.join(root, 'sar')
        img_opt_dir = os.path.join(root, 'opt')
        mask_dir = os.path.join(root, 'lbl')
        # os.listdir returns entries in arbitrary, filesystem-dependent order;
        # sorting makes the index -> sample mapping reproducible.
        # Each entry is a (sar_path, opt_path, mask_path) triple.
        self.sync_img_mask = [
            (os.path.join(img_sar_dir, img_filename),
             os.path.join(img_opt_dir, img_filename),
             os.path.join(mask_dir, img_filename))
            for img_filename in sorted(os.listdir(img_sar_dir))
        ]
        if not self.sync_img_mask:
            print("Found 0 data, please check your dataset!")

    @staticmethod
    def _load_image(path):
        """Open the image at *path*, read its pixels, and close the file."""
        # Image.open is lazy and keeps the file handle open; copy() inside the
        # context manager reads the pixel data into a detached image so the
        # handle is released here rather than at garbage collection.
        with Image.open(path) as image:
            return image.copy()

    def __getitem__(self, index) -> tuple:
        """Return ``(img_sar, img_opt, mask)`` for sample *index*.

        The joint ``sync_transform`` runs first, then each per-image
        transform that is not ``None``. With the module's default transforms
        the outputs are tensors (float for the images, int64 for the mask).
        """
        img_sar_path, img_opt_path, mask_path = self.sync_img_mask[index]
        img_sar = self._load_image(img_sar_path)
        img_opt = self._load_image(img_opt_path)
        # Labels are already single-channel 'L' images; convert() is a
        # safeguard in case a file was saved in another mode.
        mask = self._load_image(mask_path).convert('L')

        if self.sync_transform is not None:
            img_sar, img_opt, mask = self.sync_transform(img_sar, img_opt, mask)
        # Each transform is gated on its own presence, so either modality's
        # transform can be disabled with None independently of the other.
        if self.img_sar_transform is not None:
            img_sar = self.img_sar_transform(img_sar)
        if self.img_opt_transform is not None:
            img_opt = self.img_opt_transform(img_opt)
        if self.mask_transform is not None:
            mask = self.mask_transform(mask)
        return img_sar, img_opt, mask

    def __len__(self):
        """Return the number of (sar, opt, label) samples found."""
        return len(self.sync_img_mask)

    def classes(self):
        """Return the class names passed to the constructor."""
        return self.class_names
if "__main__" == __name__:
    # Running this file directly does nothing: the module only defines
    # transforms and the dataset class for import by other code.
    pass