Skip to content

Minor improvements #33

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 19, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions torchvision/datasets/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ def download(self):
import gzip

if self._check_exists():
print('Files already downloaded')
return

# download files
Expand All @@ -98,8 +97,8 @@ def download(self):
os.unlink(file_path)

# process and save as torch files
print('Processing')
print('Processing...')

training_set = (
read_image_file(os.path.join(self.root, self.raw_folder, 'train-images-idx3-ubyte')),
read_label_file(os.path.join(self.root, self.raw_folder, 'train-labels-idx1-ubyte'))
Expand Down
30 changes: 18 additions & 12 deletions torchvision/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,16 @@
import types

class Compose(object):
""" Composes several transforms together.
For example:
>>> transforms.Compose([
>>> transforms.CenterCrop(10),
>>> transforms.ToTensor(),
>>> ])
"""Composes several transforms together.

Args:
transforms (List[Transform]): list of transforms to compose.

Example:
>>> transforms.Compose([
>>> transforms.CenterCrop(10),
>>> transforms.ToTensor(),
>>> ])
"""
def __init__(self, transforms):
self.transforms = transforms
Expand All @@ -25,8 +29,9 @@ def __call__(self, img):


class ToTensor(object):
""" Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255]
to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """
"""Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range
[0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
"""
def __call__(self, pic):
if isinstance(pic, np.ndarray):
# handle numpy array
Expand All @@ -40,8 +45,9 @@ def __call__(self, pic):
img = img.transpose(0, 1).transpose(0, 2).contiguous()
return img.float().div(255)


class ToPILImage(object):
""" Converts a torch.*Tensor of range [0, 1] and shape C x H x W
"""Converts a torch.*Tensor of range [0, 1] and shape C x H x W
or numpy ndarray of dtype=uint8, range[0, 255] and shape H x W x C
to a PIL.Image of range [0, 255]
"""
Expand All @@ -56,7 +62,7 @@ def __call__(self, pic):
return img

class Normalize(object):
""" Given mean: (R, G, B) and std: (R, G, B),
"""Given mean: (R, G, B) and std: (R, G, B),
will normalize each channel of the torch.*Tensor, i.e.
channel = (channel - mean) / std
"""
Expand All @@ -72,7 +78,7 @@ def __call__(self, tensor):


class Scale(object):
""" Rescales the input PIL.Image to the given 'size'.
"""Rescales the input PIL.Image to the given 'size'.
'size' will be the size of the smaller edge.
For example, if height > width, then image will be
rescaled to (size * height / width, size)
Expand Down Expand Up @@ -128,7 +134,7 @@ def __call__(self, img):
return ImageOps.expand(img, border=self.padding, fill=self.fill)

class Lambda(object):
"""Applies a lambda as a transform"""
"""Applies a lambda as a transform."""
def __init__(self, lambd):
assert type(lambd) is types.LambdaType
self.lambd = lambd
Expand Down