-
Notifications
You must be signed in to change notification settings - Fork 81
/
Copy pathutils.py
28 lines (24 loc) · 927 Bytes
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import torch
from torchvision import datasets, transforms
# Default root directory handed to torchvision's MNIST dataset by get_data().
# NOTE: this is a relative path, so it resolves against the current working
# directory of the running process, not against this file's location.
data_dir = "../data"
def prep_data(loader):
    """Fetch the first batch of images from *loader* and print its type and shape.

    Args:
        loader: Any iterable yielding ``(images, labels)`` pairs, such as a
            ``torch.utils.data.DataLoader``. The labels are discarded.

    Returns:
        The ``images`` element of the first batch.

    Raises:
        ValueError: If *loader* yields no batches at all.
    """
    # ``iterator.next()`` was removed from PyTorch DataLoader iterators (and
    # never existed on plain Python 3 iterators); the builtin ``next()`` works
    # for every iterator.
    try:
        first_batch = next(iter(loader))
    except StopIteration:
        raise ValueError("prep_data: loader yielded no batches") from None
    images, _ = first_batch
    print("prep_data: images.type:", type(images))
    print("prep_data: images.shape:", images.shape)
    return images
def get_data(is_train=True, batch_size=128, *, root=None, **loader_kwargs):
    """Return a DataLoader over the MNIST dataset, downloading it if needed.

    Images are converted to tensors and normalized with the constants
    (0.1307,) / (0.3081,), the commonly used MNIST mean / std.

    Args:
        is_train: Load the training split if True, otherwise the test split.
        batch_size: Number of samples per batch.
        root: Directory where MNIST is stored/downloaded. Defaults to the
            module-level ``data_dir``.
        **loader_kwargs: Extra keyword arguments forwarded to
            ``torch.utils.data.DataLoader`` (e.g. ``num_workers``,
            ``pin_memory``). ``shuffle`` defaults to True for both splits,
            matching the previous behavior.

    Returns:
        A ``torch.utils.data.DataLoader`` over the requested split.
    """
    if root is None:
        root = data_dir
    # Preserve the historical default of shuffling, but let callers override it.
    loader_kwargs.setdefault("shuffle", True)
    dataset = datasets.MNIST(
        root,
        train=is_train,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ]),
    )
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, **loader_kwargs)
def display_predictions(data, max_rows=10):
    """Print the leading rows of *data* as a psql-style text table.

    Args:
        data: Anything accepted by the ``pandas.DataFrame`` constructor
            (for example a dict of columns or a list of records).
        max_rows: How many leading rows to show. Defaults to 10, the
            previously hard-coded limit. ``None`` prints every row.
    """
    # Imported lazily so the rest of the module works without these packages.
    import pandas as pd
    from tabulate import tabulate

    df = pd.DataFrame(data)
    if max_rows is not None:
        df = df.head(max_rows)
    print(tabulate(df, headers="keys", tablefmt="psql", showindex=False))