forked from TransEmbedBA/TREMBA
-
Notifications
You must be signed in to change notification settings - Fork 1
/
DataLoader.py
60 lines (44 loc) · 1.64 KB
/
DataLoader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import torch
import torch.nn as nn
import torchvision.datasets as dset
import torch.utils.data
import torchvision.transforms as transforms
import os
import json
import numpy as np
from torch.utils.data import Dataset, DataLoader
def imagenet(state):
    """Build ImageNet-style train/test loaders and return normalization stats.

    Every image is resized to 256, center-cropped to 224 and converted to a
    tensor. Normalization is NOT part of the transform: ``mean``/``std`` are
    returned to the caller instead (presumably applied inside the model
    wrapper -- confirm at the call site).

    Args:
        state: Mapping of configuration values. Keys read:
            ``train_path`` -- ImageFolder root for the training split.
            ``data_path`` -- ImageFolder root for the test split.
            ``batch_size`` -- batch size used by both loaders.
            ``defense`` (optional) -- when truthy, use 0.5/0.5 statistics
                (mapping [0, 1] pixels to [-1, 1]) instead of the standard
                ImageNet mean/std.
            ``num_workers`` (optional, default 8) -- loader worker processes.

    Returns:
        Tuple ``(train_loader, test_loader, nlabels, mean, std)`` where
        ``nlabels`` is 1000 and ``mean``/``std`` are length-3 numpy arrays.
    """
    if state.get('defense'):
        mean = np.array([0.5, 0.5, 0.5])
        std = np.array([0.5, 0.5, 0.5])
    else:
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])

    transform = transforms.Compose(
        [transforms.Resize(256),
         transforms.CenterCrop(224),
         transforms.ToTensor()])

    # Default of 8 preserves the previous hard-coded value.
    num_workers = state.get('num_workers', 8)

    def make_loader(root):
        # Both splits share identical loader settings; shuffle stays off so
        # iteration order is deterministic for both loaders.
        return torch.utils.data.DataLoader(
            dset.ImageFolder(root, transform=transform),
            batch_size=state['batch_size'], shuffle=False,
            num_workers=num_workers, pin_memory=True)

    train_loader = make_loader(state['train_path'])
    test_loader = make_loader(state['data_path'])
    nlabels = 1000
    return train_loader, test_loader, nlabels, mean, std
def gvision(state):
    """Build the single-image evaluation loader for the ``gvision`` setting.

    Images under ``state['data_path']`` (an ImageFolder root) are resized to
    256, center-cropped to 224 and converted to tensors, then served one at
    a time in a fixed order by a single worker. The standard ImageNet
    mean/std are returned rather than applied in the transform.

    Args:
        state: Mapping of configuration values; only ``data_path`` is read.

    Returns:
        Tuple ``(test_loader, nlabels, labels, mean, std)`` where ``labels``
        is the fixed list ``["Cat"]`` and ``nlabels`` is its length (1).
    """
    # Standard ImageNet channel statistics, handed back to the caller.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])
    dataset = dset.ImageFolder(state['data_path'], transform=preprocess)
    test_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=1,
        pin_memory=True,
    )

    labels = ["Cat"]
    nlabels = len(labels)
    return test_loader, nlabels, labels, mean, std