-
Notifications
You must be signed in to change notification settings - Fork 0
/
data.py
64 lines (57 loc) · 1.7 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import torch
import h5py
from torch import nn
from scipy.misc import imresize
import torch
import cv2
from PIL import Image
import os
import logging
from multiprocessing import Pool
import numpy as np
import time
import random
import time
from matplotlib import pyplot as plt
import matplotlib.image as mpimg
class Dataset(torch.utils.data.Dataset):
    """In-memory dataset of bit-packed codes loaded from ``.npz`` files.

    Every file selected for the requested phase is decoded once, up front,
    in a pool of worker processes. Each decoded array contributes
    ``self.iter`` samples: sample ``k`` of a file is ``array[k, 0, :, :, :]``,
    so the decoded arrays must be 5-D.
    """

    def __init__(self, file, phase):
        """Decode every code file that ``get_image_list`` returns.

        Args:
            file: Directory handed to ``get_image_list``.
            phase: Split name handed to ``get_image_list``; also stored on
                the instance as ``self.phase``.
        """
        super().__init__()
        # Number of samples served from each decoded file.
        self.iter = 4
        self.phase = phase
        self.file = get_image_list(file, phase)
        # The context manager releases the worker processes even when one
        # of the files fails to decode and ``map`` raises.
        with Pool() as pool:
            self.datas = pool.map(self.getitem, range(len(self.file)))

    def __getitem__(self, id):
        """Return sample ``id`` as the 3-D slice ``array[step, 0, :, :, :]``.

        ``id`` is split into a file index and a position inside that file
        using ``self.iter``, the same factor ``__len__`` uses. Floor
        division makes negative indices count from the end (``ds[-1]`` is
        the last sample of the last file).

        Raises:
            IndexError: If ``id`` addresses a file that was not loaded.
        """
        file_idx, step = divmod(id, self.iter)
        return self.datas[file_idx][step, 0, :, :, :]

    def __len__(self):
        """Return the sample count: ``self.iter`` per loaded file."""
        return len(self.file) * self.iter

    def getitem(self, id):
        """Load file number ``id`` and unpack its codes to a float32 array.

        The file must provide a ``codes`` entry (bit-packed uint8 data) and
        a ``shape`` entry (the shape of the unpacked array).

        Args:
            id: Index into ``self.file``.

        Returns:
            A ``float32`` array of shape ``shape`` holding the unpacked bits.

        Raises:
            ValueError: If the unpacked bits cannot fill ``shape``.
        """
        print(self.file[id])
        # Closing the archive right away avoids leaking one open file
        # handle per decoded file.
        with np.load(self.file[id]) as content:
            packed = content['codes']
            shape = tuple(np.atleast_1d(content['shape']).tolist())
        bits = np.unpackbits(packed)
        n_bits = int(np.prod(shape))
        # Packing pads the last byte with up to 7 zero bits; drop them so
        # shapes whose size is not a multiple of 8 still reshape. Larger
        # mismatches are left alone and make reshape raise.
        if 0 <= bits.size - n_bits < 8:
            bits = bits[:n_bits]
        return np.reshape(bits, shape).astype(np.float32)
def get_image_list(train_dir, phase):
    """Return absolute paths of the entries in ``train_dir`` used for ``phase``.

    Entries are taken in the order ``os.listdir`` reports them, which is
    arbitrary, so the selected subset is not guaranteed to be stable when
    the directory holds more entries than the cap.

    Args:
        train_dir: Directory to list. A trailing path separator is optional.
        phase: ``'train'`` keeps at most 1001 entries, ``'val'`` at most 11.
            Any other value selects nothing.

    Returns:
        A list of absolute path strings; empty for an unrecognised phase.

    Raises:
        OSError: If ``train_dir`` cannot be listed.
    """
    # Listed before the phase check so a missing directory raises for every
    # phase.
    names = os.listdir(train_dir)
    # The caps are one above the round numbers on purpose: entries were
    # historically collected until the counter exceeded 1000 / 10, and the
    # resulting dataset sizes are kept unchanged.
    if phase == 'train':
        limit = 1001
    elif phase == 'val':
        limit = 11
    else:
        return []
    # os.path.join inserts the separator that plain string concatenation
    # silently dropped when train_dir had no trailing slash.
    return [
        os.path.abspath(os.path.join(train_dir, name))
        for name in names[:limit]
    ]