-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathdata_loaders.py
40 lines (30 loc) · 1.6 KB
/
data_loaders.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import os

import torch
import torchvision
import torchvision.transforms
def _build_transform(image_dim):
    """Return the preprocessing pipeline shared by every data split.

    Each image is converted to single-channel grayscale, padded by 4 pixels
    on every side with a constant border, resized to ``image_dim`` x
    ``image_dim``, converted to a tensor and normalized with mean 0.5 and
    std 0.5 (mapping ``ToTensor``'s [0, 1] range to [-1, 1]).

    Defining the pipeline once guarantees that train, validation and test
    images are preprocessed identically.
    """
    return torchvision.transforms.Compose([
        torchvision.transforms.Grayscale(),
        # Pad before resizing so the border scales together with the image.
        torchvision.transforms.Pad(padding=4, padding_mode="constant"),
        torchvision.transforms.Resize((image_dim, image_dim)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,), (0.5,)),
    ])


def get_image_data_train_val_loaders(image_dim, batch_size, *,
                                     data_root="output", num_workers=2):
    """Build DataLoaders for the training and validation splits.

    Args:
        image_dim: Side length, in pixels, of the square images produced.
        batch_size: Number of samples per batch for both loaders.
        data_root: Directory holding the ``train`` and ``val`` split folders,
            each laid out as ``torchvision.datasets.ImageFolder`` expects
            (one sub-directory per class). Defaults to ``"output"``.
        num_workers: Number of worker processes used by each DataLoader.

    Returns:
        A ``(train_loader, val_loader)`` tuple. Only the training loader
        shuffles; validation batches come in a fixed order so evaluation
        runs are reproducible.
    """
    transform = _build_transform(image_dim)
    training_dataset = torchvision.datasets.ImageFolder(
        os.path.join(data_root, "train"), transform=transform)
    val_dataset = torchvision.datasets.ImageFolder(
        os.path.join(data_root, "val"), transform=transform)
    train_loader = torch.utils.data.DataLoader(
        training_dataset, batch_size=batch_size, shuffle=True,
        num_workers=num_workers)
    # Shuffling evaluation data does not change aggregate metrics; it only
    # makes batch order (and per-batch outputs) non-deterministic.
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False,
        num_workers=num_workers)
    return train_loader, val_loader


def get_image_data_test_loader(image_dim, batch_size, *,
                               data_root="output", num_workers=2):
    """Build a DataLoader for the test split.

    Args:
        image_dim: Side length, in pixels, of the square images produced.
        batch_size: Number of samples per batch.
        data_root: Directory holding the ``test`` split folder, laid out as
            ``torchvision.datasets.ImageFolder`` expects (one sub-directory
            per class). Defaults to ``"output"``.
        num_workers: Number of worker processes used by the DataLoader.

    Returns:
        A non-shuffling DataLoader over the test images, preprocessed the
        same way as the training and validation splits.
    """
    test_dataset = torchvision.datasets.ImageFolder(
        os.path.join(data_root, "test"), transform=_build_transform(image_dim))
    # Fixed order keeps test-set predictions reproducible and easy to align
    # with ``test_dataset.samples``.
    return torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False,
        num_workers=num_workers)