-
Notifications
You must be signed in to change notification settings - Fork 54
/
dataloaders.py
60 lines (51 loc) · 2.23 KB
/
dataloaders.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
def get_mnist_dataloaders(batch_size=128, path_to_data='../data'):
    """MNIST dataloaders with (32, 32) sized images.

    Parameters
    ----------
    batch_size : int
        Number of images per batch, used for both the train and test loader.
    path_to_data : str
        Root directory under which torchvision stores / downloads MNIST.
        Defaults to '../data' (relative to the current working directory),
        which was previously hard-coded.

    Returns
    -------
    tuple of (DataLoader, DataLoader)
        ``(train_loader, test_loader)``; both shuffle their samples.
    """
    # Resize images so their sides are a power of 2 (32).
    all_transforms = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),
    ])
    # Get train and test data. Only the train split requests a download when
    # the files are missing; the test split is constructed without
    # download=True from the same root.
    train_data = datasets.MNIST(path_to_data, train=True, download=True,
                                transform=all_transforms)
    test_data = datasets.MNIST(path_to_data, train=False,
                               transform=all_transforms)
    # Create dataloaders.
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    # NOTE(review): shuffling is not needed for evaluation; kept as-is so
    # callers that sample "random" test batches keep working.
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=True)
    return train_loader, test_loader
def get_fashion_mnist_dataloaders(batch_size=128, path_to_data='../fashion_data'):
    """Fashion MNIST dataloaders with (32, 32) sized images.

    Parameters
    ----------
    batch_size : int
        Number of images per batch, used for both the train and test loader.
    path_to_data : str
        Root directory under which torchvision stores / downloads
        FashionMNIST. Defaults to '../fashion_data' (relative to the current
        working directory), which was previously hard-coded.

    Returns
    -------
    tuple of (DataLoader, DataLoader)
        ``(train_loader, test_loader)``; both shuffle their samples.
    """
    # Resize images so their sides are a power of 2 (32).
    all_transforms = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),
    ])
    # Get train and test data. Only the train split requests a download when
    # the files are missing; the test split is constructed without
    # download=True from the same root.
    train_data = datasets.FashionMNIST(path_to_data, train=True, download=True,
                                       transform=all_transforms)
    test_data = datasets.FashionMNIST(path_to_data, train=False,
                                      transform=all_transforms)
    # Create dataloaders.
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    # NOTE(review): shuffling is not needed for evaluation; kept as-is so
    # callers that sample "random" test batches keep working.
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=True)
    return train_loader, test_loader
def get_lsun_dataloader(path_to_data='../lsun', dataset='bedroom_train',
                        batch_size=64):
    """LSUN dataloader with (128, 128) sized images.

    Parameters
    ----------
    path_to_data : str
        Root directory handed to torchvision's ``LSUN`` dataset.
    dataset : str
        LSUN category/split to load, e.g. 'bedroom_val' or 'bedroom_train'.
        It is passed as the single entry of ``classes``.
    batch_size : int
        Number of images per batch.

    Returns
    -------
    DataLoader
        A shuffled loader over the selected LSUN category.
    """
    # Compose transforms: resize, then center-crop to a 128x128 square.
    transform = transforms.Compose([
        transforms.Resize(128),
        transforms.CenterCrop(128),
        transforms.ToTensor(),
    ])
    # Get dataset. The root is passed positionally because older torchvision
    # releases named this argument `db_path` while newer ones name it `root`;
    # the `db_path=` keyword raises TypeError on current torchvision.
    lsun_dset = datasets.LSUN(path_to_data, classes=[dataset],
                              transform=transform)
    # Create dataloader.
    return DataLoader(lsun_dset, batch_size=batch_size, shuffle=True)