-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathdatautils.py
90 lines (71 loc) · 2.18 KB
/
datautils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
'''
This file defines the custom dataset classes used for data loading.
'''
import os
import sys
import numpy as np
import torch
import random
import pdb
from torch.utils.data import Dataset
from torchvision import transforms
class Single_Image_Dataset(Dataset):
    """Dataset over a collection of images that can optionally blank out a patch.

    Each item is a ``(input, target)`` pair. ``target`` is ``image[idx]``.
    ``input`` is the same image, or, when erasing is enabled, a copy with one
    small square region overwritten by ``transforms.RandomErasing``.
    """

    def __init__(self, image, erase=True):
        """
        Args:
            image: indexable collection of images. Presumably tensors in the
                layout ``transforms.RandomErasing`` expects -- TODO confirm.
            erase: when True, ``__getitem__`` corrupts the input image with
                random erasing; when False, input and target are identical.
        """
        self.image = image
        self.erase = erase  # switch for the missing-patch corruption
        if not erase:
            # No transform is built (or stored) when erasing is disabled.
            return
        # p=1.0 -> always erase. scale/ratio -> one square patch covering 0.2%
        # of the image area. value='random' -> fill with random values.
        # inplace=False -> the stored image is left untouched.
        self.erase_transform = transforms.RandomErasing(
            p=1.0,
            scale=(0.002, 0.002),
            ratio=(1, 1),
            value='random',
            inplace=False,
        )

    def __getitem__(self, idx):
        """Return ``(possibly_erased_image, original_image)`` for index ``idx``."""
        source = self.image[idx]
        corrupted = self.erase_transform(source) if self.erase else source
        return corrupted, self.image[idx]

    def __len__(self):
        """Number of images in the dataset."""
        return len(self.image)
class INR_Single_Image_Dataset(Dataset):
    """Pairs each ``grid`` entry with the ``image`` entry at the same index.

    Items are returned unmodified as ``(grid[idx], image[idx])``. The dataset
    length is taken from ``image``.
    """

    def __init__(self, grid, image):
        self.grid, self.image = grid, image

    def __getitem__(self, idx):
        """Return the ``(grid, image)`` entries stored at ``idx``."""
        grid_item = self.grid[idx]
        image_item = self.image[idx]
        return grid_item, image_item

    def __len__(self):
        """Number of entries in ``image``."""
        return len(self.image)
class INR_Multi_Image_Dataset(Dataset):
    """Dataset of ``(grid, image)`` pairs converted to tensors on access.

    At construction, the global min/max over all ``grid`` entries and over all
    ``image`` entries are computed. They are stored as ``grid_min``/``grid_max``
    and ``image_min``/``image_max``, and printed.

    On access, the image entry is converted with ``transforms.ToTensor()``. The
    grid entry is converted the same way and then mapped with ``x * 2 - 1``
    (presumably to move coordinates from [0, 1] to [-1, 1] -- TODO confirm the
    input range).
    """

    def __init__(self, grid, image):
        """
        Args:
            grid: indexable sequence of array-likes, one per sample.
            image: indexable sequence of array-likes, one per sample. Its
                length defines ``len(self)``.
        """
        self.grid = grid
        self.image = image
        # Scan each collection over its own length, so a length mismatch
        # between grid and image cannot raise IndexError here or silently
        # leave entries out of the statistics.
        self.grid_min, self.grid_max = self._global_min_max(grid)
        self.image_min, self.image_max = self._global_min_max(image)
        print("Grid min max:", self.grid_min, self.grid_max)
        print('Image min max:', self.image_min, self.image_max)

    @staticmethod
    def _global_min_max(arrays):
        """Return ``(min, max)`` over every value in every array of *arrays*.

        Returns ``(None, None)`` when *arrays* is empty.
        """
        # len() rather than truthiness: *arrays* may be a numpy array, whose
        # truth value is ambiguous.
        if len(arrays) == 0:
            return None, None
        lows = [np.min(a) for a in arrays]
        highs = [np.max(a) for a in arrays]
        return min(lows), max(highs)

    def __getitem__(self, idx):
        """Return ``(grid_tensor, image_tensor)`` for index ``idx``."""
        image = transforms.ToTensor()(np.array(self.image[idx]))
        # Affine rescale x -> 2x - 1, applied after tensor conversion.
        grid = transforms.ToTensor()(np.array(self.grid[idx])) * 2 - 1
        return grid, image

    def __len__(self):
        """Number of samples, taken from ``image``."""
        return len(self.image)
class INR_Multi_Image_Dataset_v3(Dataset):
    """Dataset of ``(grid, image)`` pairs where only the image is converted.

    ``__getitem__`` returns ``(grid[idx], image_tensor)``. The grid entry is
    passed through untouched. The image entry is wrapped in ``np.array`` and
    converted with ``transforms.ToTensor()``.
    """

    def __init__(self, grid, image):
        self.grid = grid
        self.image = image

    def __getitem__(self, idx):
        """Return the raw grid entry and the tensor-converted image at ``idx``."""
        to_tensor = transforms.ToTensor()
        image_tensor = to_tensor(np.array(self.image[idx]))
        return self.grid[idx], image_tensor

    def __len__(self):
        """Number of samples, taken from ``image``."""
        return len(self.image)