from torch.utils.data import Dataset
from PIL import Image
import h5py
import numpy as np
import cv2
import random
from skimage import exposure, img_as_float


def load_data(img_path, ratio, aug, index, kernel_path='maps_adaptive_kernel'):
    # The ground-truth density map is stored as an .h5 file alongside the image,
    # in a sibling directory named by kernel_path.
    gt_path = img_path.replace('.jpg', '.h5').replace('images', kernel_path)
    img = Image.open(img_path).convert('RGB')
    with h5py.File(gt_path, 'r') as gt_file:
        target = np.asarray(gt_file['density'])

    if aug:
        crop_size = (img.size[0] // 2, img.size[1] // 2)
        if random.random() <= 0.44:
            # 4 non-overlapping quarter patches: the offset is either 0 or half
            # the image size along each axis.
            dx = int(random.randint(0, 1) * crop_size[0])
            dy = int(random.randint(0, 1) * crop_size[1])
        else:
            # 5 random patches.
            # Seed the generator so that each image always yields the same set of
            # random patches. Without the seed, the crop offsets would change on
            # every load, producing a dynamic training set.
            patch_id = random.randint(0, 4)
            random.seed(index + patch_id * 0.1)
            dx = int(random.random() * crop_size[0])
            random.seed(index + 0.5 + patch_id * 0.1)
            dy = int(random.random() * crop_size[1])

        # Crop the image and the density map to the same half-size window.
        img = img.crop((dx, dy, crop_size[0] + dx, crop_size[1] + dy))
        target = target[dy:crop_size[1] + dy, dx:crop_size[0] + dx]

        # Horizontal flip. Copy the flipped map: np.fliplr returns a view with
        # negative strides, which torch.from_numpy cannot handle.
        if random.random() > 0.5:
            target = np.fliplr(target).copy()
            img = img.transpose(Image.FLIP_LEFT_RIGHT)

        # Gamma transform: randomly darken (gamma 1.5) or brighten (gamma 0.5).
        if random.random() > 0.7:
            image = img_as_float(img)
            # gamma_img: float64 array ranging [0, 1]
            if random.random() > 0.5:
                gamma_img = exposure.adjust_gamma(image, 1.5)
            else:
                gamma_img = exposure.adjust_gamma(image, 0.5)
            gamma_img = np.uint8(gamma_img * 255)
            img = Image.fromarray(gamma_img)

        # Grayscale: convert to grayscale replicated over 3 channels.
        if random.random() > 0.9:
            img = img.convert('L').convert('RGB')

    count = target.sum()

    if ratio > 1:
        # Downsample the density map by `ratio` and rescale its values so that
        # the total count is preserved despite the integer rounding of the
        # output size.
        new_w, new_h = int(target.shape[1] / ratio), int(target.shape[0] / ratio)
        scale = (int(target.shape[0]) * int(target.shape[1])) / (new_w * new_h)
        # INTER_AREA interpolation leads to the smallest count error after resizing the target.
        target = cv2.resize(target, (new_w, new_h), interpolation=cv2.INTER_AREA) * scale

    return img, target, count


class RawDataset(Dataset):
    def __init__(self, root, transform, ratio=8, aug=False, kernel_path='maps_adaptive_kernel'):
        self.nsamples = len(root)       # root is a list of image paths
        self.aug = aug                  # enable crop/flip/gamma/grayscale augmentation
        self.root = root
        self.ratio = ratio              # density-map downsampling factor
        self.transform = transform      # transform applied to the PIL image
        self.kernel_path = kernel_path  # directory holding the .h5 density maps

    def __getitem__(self, index):
        img, target, count = load_data(self.root[index], self.ratio, self.aug,
                                       index, self.kernel_path)
        if self.transform:
            img = self.transform(img)
        return img, target, count

    def __len__(self):
        return self.nsamples
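

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module). The
# dataset root path, the ImageNet normalisation statistics, and the DataLoader
# settings below are assumptions made for the example. batch_size is kept at 1
# because augmented crops differ in size between images, so the default collate
# cannot stack larger batches.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import glob
    from torch.utils.data import DataLoader
    from torchvision import transforms

    # Hypothetical layout: <root>/images/*.jpg with matching
    # <root>/maps_adaptive_kernel/*.h5 density maps.
    img_paths = sorted(glob.glob('ShanghaiTech/part_A/train_data/images/*.jpg'))

    transform = transforms.Compose([
        transforms.ToTensor(),
        # ImageNet statistics, a common choice for ImageNet-pretrained backbones.
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    dataset = RawDataset(img_paths, transform, ratio=8, aug=True)
    loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=4)

    for img, target, count in loader:
        # img: (1, 3, H, W) tensor; target: (1, H/8, W/8) density map; count: (1,) tensor
        print(img.shape, target.shape, count.item())
        break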