-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdataloader.py
40 lines (32 loc) · 954 Bytes
/
dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
# %%
import os
import sys
import re
import numpy as np
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import torchvision.transforms.functional as TF
import torchvision.utils as vutils
import random
import nibabel as nib
from glob import glob
import pandas as pd
class RandomRotation:
    """Rotate an input by an angle picked at random from a fixed set.

    Each call draws one entry of ``angles`` with ``random.choice`` and hands
    the input together with that angle to ``TF.rotate``.

    Args:
        angles: Candidate rotation angles, forwarded unchanged to
            ``TF.rotate`` (which interprets them as degrees). When omitted,
            a new ``[0, 90, 180, 270]`` list is created for this instance.
            A caller-supplied sequence is stored as-is, not copied.
    """

    def __init__(self, angles=None):
        # A ``None`` sentinel replaces the former mutable list default: a
        # list literal in the signature is created once and shared by every
        # instance, so mutating one instance's ``angles`` would silently
        # change the default for all instances created afterwards.
        self.angles = [0, 90, 180, 270] if angles is None else angles

    def __call__(self, x):
        """Return ``x`` rotated by one randomly chosen angle.

        Raises:
            IndexError: If ``angles`` is empty (raised by ``random.choice``).
        """
        angle = random.choice(self.angles)
        return TF.rotate(x, angle)
class SingleDataset(Dataset):
    """Dataset that loads one grayscale image per stored file path.

    Args:
        image_path: Sized, indexable collection of image locations; each
            entry is passed to ``PIL.Image.open``.
        transform: Optional callable applied to every loaded image. It is
            skipped when falsy (``None`` by default).
    """

    def __init__(self, image_path, transform=None):
        self.image_path = image_path
        self.transform = transform

    def __len__(self):
        """Return the number of entries in ``image_path``."""
        return len(self.image_path)

    def __getitem__(self, idx):
        """Load entry ``idx``, convert it to mode ``'L'``, and transform it.

        Returns:
            The converted image, or whatever ``transform`` returns for it
            when a transform is set.
        """
        # Open inside a context manager so the underlying file handle is
        # released right away instead of waiting for garbage collection;
        # ``convert`` returns a separate image, which stays valid after the
        # source image has been closed.
        with Image.open(self.image_path[idx]) as source:
            image = source.convert('L')
        # Truthiness check kept deliberately: falsy transforms are skipped,
        # exactly as before.
        if self.transform:
            image = self.transform(image)
        return image