-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Accelerating image processing for CNN #668
Changes from 5 commits
9d72cab
8c9a967
4d99782
fe073d1
ae06deb
9d2f49c
84d47ac
978d6e8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,273 @@ | ||
import os, sys | ||
import numpy as np | ||
from PIL import Image | ||
from cStringIO import StringIO | ||
import multiprocessing | ||
from functools import partial | ||
|
||
from paddle.utils.image_util import * | ||
from paddle.trainer.config_parser import logger | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 加入 |
||
try: | ||
import cv2 | ||
except ImportError: | ||
logger.warning("OpenCV2 is not installed, using PIL to prcoess") | ||
cv2 = None | ||
|
||
|
||
class CvTransfomer(ImageTransformer): | ||
""" | ||
CvTransfomer used python-opencv to process image. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
min_size=None, | ||
crop_size=None, | ||
transpose=(2, 0, 1), # transpose to C * H * W | ||
channel_swap=None, | ||
mean=None, | ||
is_train=True, | ||
is_color=True): | ||
ImageTransformer.__init__(self, transpose, channel_swap, mean, is_color) | ||
self.min_size = min_size | ||
self.crop_size = crop_size | ||
self.is_train = is_train | ||
|
||
def resize(self, im, min_size): | ||
row, col = im.shape[:2] | ||
new_row, new_col = min_size, min_size | ||
if row > col: | ||
new_row = min_size * row / col | ||
else: | ||
new_col = min_size * col / row | ||
im = cv2.resize(im, (new_row, new_col), interpolation=cv2.INTER_CUBIC) | ||
return im | ||
|
||
def crop_and_flip(self, im): | ||
""" | ||
Return cropped image. | ||
The size of the cropped image is inner_size * inner_size. | ||
im: (H x W x K) ndarrays | ||
""" | ||
row, col = im.shape[:2] | ||
start_h, start_w = 0, 0 | ||
if self.is_train: | ||
start_h = np.random.randint(0, row - self.crop_size + 1) | ||
start_w = np.random.randint(0, col - self.crop_size + 1) | ||
else: | ||
start_h = (row - self.crop_size) / 2 | ||
start_w = (col - self.crop_size) / 2 | ||
end_h, end_w = start_h + self.crop_size, start_w + self.crop_size | ||
if self.is_color: | ||
im = im[start_h:end_h, start_w:end_w, :] | ||
else: | ||
im = im[start_h:end_h, start_w:end_w] | ||
if (self.is_train) and (np.random.randint(2) == 0): | ||
if self.is_color: | ||
im = im[:, ::-1, :] | ||
else: | ||
im = im[:, ::-1] | ||
return im | ||
|
||
def transform(self, im): | ||
im = self.resize(im, self.min_size) | ||
im = self.crop_and_flip(im) | ||
# transpose, swap channel, sub mean | ||
im = im.astype('float32') | ||
ImageTransformer.transformer(self, im) | ||
return im | ||
|
||
def load_image_from_string(self, data): | ||
flag = cv2.CV_LOAD_IMAGE_COLOR if self.is_color else cv2.CV_LOAD_IMAGE_GRAYSCALE | ||
im = cv2.imdecode(np.fromstring(data, np.uint8), flag) | ||
return im | ||
|
||
def transform_from_string(self, data): | ||
im = self.load_image_from_string(data) | ||
return self.transform(im) | ||
|
||
def load_image_from_file(self, file): | ||
flag = cv2.CV_LOAD_IMAGE_COLOR if self.is_color else cv2.CV_LOAD_IMAGE_GRAYSCALE | ||
im = cv2.imread(file, flag) | ||
return im | ||
|
||
def transform_from_file(self, file): | ||
im = self.load_image_from_file(file) | ||
return self.transform(im) | ||
|
||
|
||
class PILTransfomer(ImageTransformer): | ||
""" | ||
PILTransfomer used PIL to process image. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
min_size=None, | ||
crop_size=None, | ||
transpose=(2, 0, 1), # transpose to C * H * W | ||
channel_swap=None, | ||
mean=None, | ||
is_train=True, | ||
is_color=True): | ||
ImageTransformer.__init__(self, transpose, channel_swap, mean, is_color) | ||
self.min_size = min_size | ||
self.crop_size = crop_size | ||
self.is_train = is_train | ||
|
||
def resize(self, im, min_size): | ||
row, col = im.size[:2] | ||
new_row, new_col = min_size, min_size | ||
if row > col: | ||
new_row = min_size * row / col | ||
else: | ||
new_col = min_size * col / row | ||
im = im.resize((new_row, new_col), Image.ANTIALIAS) | ||
return im | ||
|
||
def crop_and_flip(self, im): | ||
""" | ||
Return cropped image. | ||
The size of the cropped image is inner_size * inner_size. | ||
""" | ||
row, col = im.size[:2] | ||
start_h, start_w = 0, 0 | ||
if self.is_train: | ||
start_h = np.random.randint(0, row - self.crop_size + 1) | ||
start_w = np.random.randint(0, col - self.crop_size + 1) | ||
else: | ||
start_h = (row - self.crop_size) / 2 | ||
start_w = (col - self.crop_size) / 2 | ||
end_h, end_w = start_h + self.crop_size, start_w + self.crop_size | ||
im = im.crop((start_h, start_w, end_h, end_w)) | ||
if (self.is_train) and (np.random.randint(2) == 0): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里不需要加括号 |
||
im = im.transpose(Image.FLIP_LEFT_RIGHT) | ||
return im | ||
|
||
def transform(self, im): | ||
im = self.resize(im, self.min_size) | ||
im = self.crop_and_flip(im) | ||
im = np.array(im, dtype=np.float32) # convert to numpy.array | ||
# transpose, swap channel, sub mean | ||
ImageTransformer.transformer(self, im) | ||
return im | ||
|
||
def load_image_from_string(self, data): | ||
im = Image.open(StringIO(data)) | ||
return im | ||
|
||
def transform_from_string(self, data): | ||
im = self.load_image_from_string(data) | ||
return self.transform(im) | ||
|
||
def load_image_from_file(self, file): | ||
im = Image.open(file) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 最好不要使用file作为参数名,file在python里面是个函数 |
||
return im | ||
|
||
def transform_from_file(self, file): | ||
im = self.load_image_from_file(file) | ||
return self.transform(im) | ||
|
||
|
||
def warpper(cls, (dat, label)): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 删了吧,没用 |
||
return cls.job(dat, label) | ||
|
||
|
||
class MultiProcessImageTransformer(object): | ||
def __init__(self, | ||
procnum=10, | ||
resize_size=None, | ||
crop_size=None, | ||
transpose=(2, 0, 1), | ||
channel_swap=None, | ||
mean=None, | ||
is_train=True, | ||
is_color=True, | ||
is_img_string=True): | ||
""" | ||
Processing image with multi-process. If it is used in PyDataProvider, | ||
the simple usage for CNN is as follows: | ||
|
||
.. code-block:: python | ||
|
||
def hool(settings, is_train, **kwargs): | ||
settings.is_train = is_train | ||
settings.mean_value = np.array([103.939,116.779,123.68], dtype=np.float32) | ||
settings.input_types = [ | ||
dense_vector(3 * 224 * 224), | ||
integer_value(1)] | ||
settings.transformer = MultiProcessImageTransformer( | ||
procnum=10, | ||
resize_size=256, | ||
crop_size=224, | ||
transpose=(2, 0, 1), | ||
mean=settings.mean_values, | ||
is_train=settings.is_train) | ||
|
||
|
||
@provider(init_hook=hook, pool_size=20480) | ||
def process(settings, file_list): | ||
with open(file_list, 'r') as fdata: | ||
for line in fdata: | ||
data_dic = np.load(line.strip()) # load the data batch pickled by Pickle. | ||
data = data_dic['data'] | ||
labels = data_dic['label'] | ||
labels = np.array(labels, dtype=np.float32) | ||
for im, lab in settings.dp.run(data, labels): | ||
yield [im.astype('float32'), int(lab)] | ||
|
||
:param procnum: processor number. | ||
:type procnum: int | ||
:param resize_size: the shorter edge size of image after resizing. | ||
:type resize_size: int | ||
:param crop_size: the croping size. | ||
:type crop_size: int | ||
:param transpose: the transpose order, Paddle only allow C * H * W order. | ||
:type transpose: tuple or list | ||
:param channel_swap: the channel swap order, RGB or BRG. | ||
:type channel_swap: tuple or list | ||
:param mean: the mean values of image, per-channel mean or element-wise mean. | ||
:type mean: array, The dimension is 1 for per-channel mean. | ||
The dimension is 3 for element-wise mean. | ||
:param is_train: training peroid or testing peroid. | ||
:type is_train: bool. | ||
:param is_color: the image is color or gray. | ||
:type is_color: bool. | ||
:param is_img_string: The input can be the file name of image or image string. | ||
:type is_img_string: bool. | ||
""" | ||
|
||
self.pool = multiprocessing.Pool(procnum) | ||
self.is_img_string = is_img_string | ||
if cv2 is not None: | ||
self.transformer = CvTransfomer(resize_size, crop_size, transpose, | ||
channel_swap, mean, is_train, | ||
is_color) | ||
else: | ||
self.transformer = PILTransfomer(resize_size, crop_size, transpose, | ||
channel_swap, mean, is_train, | ||
is_color) | ||
|
||
def run(self, data, label): | ||
try: | ||
fun = partial(warpper, self) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 另外,最好不要 |
||
return self.pool.imap_unordered(fun, zip(data, label), chunksize=5) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 有几个问题:
|
||
except KeyboardInterrupt: | ||
self.pool.terminate() | ||
except Exception, e: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里不是阻塞的,所以这里try...except一点用应该都没有 |
||
self.pool.terminate() | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @staticmethod
def __job__(is_img_string, transformer, data, label):
if is_img_string:
transformer. transform_from_string(data), label
else:
transformer. transform_from_file(data), label |
||
def job(self, data, label): | ||
if self.is_img_string: | ||
return self.transformer.transform_from_string(data), label | ||
else: | ||
return self.transformer.transform_from_file(data), label | ||
|
||
def __getstate__(self): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 不要加这个函数了 |
||
self_dict = self.__dict__.copy() | ||
del self_dict['pool'] | ||
return self_dict | ||
|
||
def __setstate__(self, state): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 不要加这个函数了 |
||
self.__dict__.update(state) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
加上__all__字段,把需要Export的东西,Export出来。