This repository has been archived by the owner on Jul 11, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 15
/
dataloader.py
47 lines (42 loc) · 1.63 KB
/
dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import numpy as np
import torch
import torchvision
def get_loader(batch_size, num_workers, use_gpu):
    """Create the CIFAR-10 training and test data loaders.

    Both splits are read from (and, if missing, downloaded to)
    ``~/.torchvision/datasets/CIFAR10``.

    Training images are randomly cropped to 32x32 after 4 pixels of padding
    and randomly flipped horizontally before being converted to tensors and
    normalized. Test images are only converted and normalized.

    Args:
        batch_size: Number of samples per batch for both loaders.
        num_workers: Number of worker processes used by each loader.
        use_gpu: Forwarded as ``pin_memory`` to both loaders.

    Returns:
        A ``(train_loader, test_loader)`` tuple. The training loader shuffles
        and drops the last incomplete batch; the test loader does neither.
    """
    transforms = torchvision.transforms

    # Per-channel normalization constants (presumably the CIFAR-10
    # training-set statistics -- confirm before reusing for other data).
    channel_mean = np.array([0.4914, 0.4822, 0.4465])
    channel_std = np.array([0.2470, 0.2435, 0.2616])

    # Augmentation is applied to the training split only.
    augmentation = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]
    train_pipeline = transforms.Compose(augmentation + [
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])
    eval_pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])

    root = '~/.torchvision/datasets/CIFAR10'
    cifar10 = torchvision.datasets.CIFAR10
    train_set = cifar10(
        root, train=True, transform=train_pipeline, download=True)
    test_set = cifar10(
        root, train=False, transform=eval_pipeline, download=True)

    # Options common to both loaders; only shuffling and the handling of the
    # final partial batch differ between the splits.
    shared_options = {
        'batch_size': batch_size,
        'num_workers': num_workers,
        'pin_memory': use_gpu,
    }
    make_loader = torch.utils.data.DataLoader
    train_loader = make_loader(
        train_set, shuffle=True, drop_last=True, **shared_options)
    test_loader = make_loader(
        test_set, shuffle=False, drop_last=False, **shared_options)

    return train_loader, test_loader