Skip to content

Commit

Permalink
Merge pull request #10 from pytorch/cifarfix
Browse files Browse the repository at this point in the history
making cifar data loader also return PIL Image
  • Loading branch information
soumith authored Nov 29, 2016
2 parents 98b9aa5 + 05bcb18 commit 2d55b9d
Show file tree
Hide file tree
Showing 3 changed files with 183 additions and 6 deletions.
173 changes: 168 additions & 5 deletions test/sanity_checks.ipynb

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions torchvision/datasets/cifar.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ def __getitem__(self, index):
img, target = self.train_data[index], self.train_labels[index]
else:
img, target = self.test_data[index], self.test_labels[index]

# doing this so that it is consistent with all other datasets
# to return a PIL Image
img = Image.fromarray(np.transpose(img, (1,2,0)))

if self.transform is not None:
img = self.transform(img)
Expand Down
12 changes: 11 additions & 1 deletion torchvision/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,20 @@
import torch
import math

def make_grid(tensor, nrow=8, padding=2):
"""
Given a 4D mini-batch Tensor of shape (B x C x H x W),
or a list of images all of the same size,
makes a grid of images
"""
import math
tensorlist = None
if isinstance(tensor, list):
tensorlist = tensor
numImages = len(tensorlist)
size = torch.Size(torch.Size([long(numImages)]) + tensorlist[0].size())
tensor = tensorlist[0].new(size)
for i in range(numImages):
tensor[i].copy_(tensorlist[i])
if tensor.dim() == 3: # single image
return tensor
# make the mini-batch of images into a grid
Expand Down

0 comments on commit 2d55b9d

Please sign in to comment.