Skip to content

Commit

Permalink
rm torhvision
Browse files Browse the repository at this point in the history
  • Loading branch information
Zhang committed Aug 24, 2017
1 parent 3101fe4 commit 0ff21f4
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 2 deletions.
4 changes: 2 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import net
import utils
from option import Options
from data import datasets
import data

def train(args):
np.random.seed(args.seed)
Expand All @@ -26,7 +26,7 @@ def train(args):
utils.CenterCrop(args.image_size),
utils.ToTensor(ctx),
])
train_dataset = datasets.ImageFolder(args.dataset, transform)
train_dataset = data.ImageFolder(args.dataset, transform)
train_loader = gluon.data.DataLoader(train_dataset, batch_size=args.batch_size, last_batch='discard')
style_loader = utils.StyleLoader(args.style_folder, args.style_size, ctx=ctx)
print('len(style_loader):',style_loader.size())
Expand Down
41 changes: 41 additions & 0 deletions utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import numbers
from PIL import Image

import numpy as np
Expand Down Expand Up @@ -116,6 +117,46 @@ def __call__(self, img):
return img


class Scale(object):
"""Rescale the input PIL.Image to the given size.
Args:
size (sequence or int): Desired output size. If size is a sequence like
(w, h), output size will be matched to this. If size is an int,
smaller edge of the image will be matched to this number.
i.e, if height > width, then image will be rescaled to
(size * height / width, size)
interpolation (int, optional): Desired interpolation. Default is
``PIL.Image.BILINEAR``
"""

def __init__(self, size, interpolation=Image.BILINEAR):
assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)
self.size = size
self.interpolation = interpolation

def __call__(self, img):
"""
Args:
img (PIL.Image): Image to be scaled.
Returns:
PIL.Image: Rescaled image.
"""
if isinstance(self.size, int):
w, h = img.size
if (w <= h and w == self.size) or (h <= w and h == self.size):
return img
if w < h:
ow = self.size
oh = int(self.size * h / w)
return img.resize((ow, oh), self.interpolation)
else:
oh = self.size
ow = int(self.size * w / h)
return img.resize((ow, oh), self.interpolation)
else:
return img.resize(self.size, self.interpolation)


class CenterCrop(object):
"""Crops the given PIL.Image at the center.
Args:
Expand Down

0 comments on commit 0ff21f4

Please sign in to comment.