Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#31 from heavengate/fix_compile_prune
Browse files Browse the repository at this point in the history
extract input variable from feed
  • Loading branch information
heavengate authored Apr 9, 2020
2 parents acd23c7 + 35e267f commit 3c5f074
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 44 deletions.
19 changes: 18 additions & 1 deletion model.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,10 +360,27 @@ def _run(self, inputs, labels=None):
metric_list, metric_splits = flatten_list(endpoints['metric'])
fetch_list = endpoints['loss'] + metric_list
num_loss = len(endpoints['loss'])

# if fetch Variable is same as input Variable, do not fetch
# from program, get it from input directly
pruned_fetch_list = []
pruned_fetch_idx_name_map = [""] * len(fetch_list)
for i, fetch_var in enumerate(fetch_list):
if fetch_var.name in feed.keys():
pruned_fetch_idx_name_map[i] = fetch_var.name
else:
pruned_fetch_list.append(fetch_var)

rets = self._executor.run(compiled_prog,
feed=feed,
fetch_list=fetch_list,
fetch_list=pruned_fetch_list,
return_numpy=False)

# restore pruned fetch_list Variable from feeds
for i, name in enumerate(pruned_fetch_idx_name_map):
if len(name) > 0:
rets.insert(i, feed[name])

# LoDTensor cannot be fetch as numpy directly
rets = [np.array(v) for v in rets]
if self.mode == 'test':
Expand Down
6 changes: 2 additions & 4 deletions models/yolov3.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def __init__(self, num_classes=80, model_mode='train'):
act='leaky_relu'))
self.route_blocks.append(route)

def forward(self, img_info, inputs):
def forward(self, img_id, img_shape, inputs):
outputs = []
boxes = []
scores = []
Expand All @@ -163,8 +163,6 @@ def forward(self, img_info, inputs):
for m in anchor_mask:
mask_anchors.append(self.anchors[2 * m])
mask_anchors.append(self.anchors[2 * m + 1])
img_shape = fluid.layers.slice(img_info, axes=[1], starts=[1], ends=[3])
img_id = fluid.layers.slice(img_info, axes=[1], starts=[0], ends=[1])
b, s = fluid.layers.yolo_box(
x=block_out,
img_size=img_shape,
Expand All @@ -181,7 +179,7 @@ def forward(self, img_info, inputs):
if self.model_mode == 'train':
return outputs

preds = [img_id[0, :],
preds = [img_id,
fluid.layers.multiclass_nms(
bboxes=fluid.layers.concat(boxes, axis=1),
scores=fluid.layers.concat(scores, axis=2),
Expand Down
21 changes: 12 additions & 9 deletions yolov3/coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,30 +186,31 @@ def _getitem_by_index(self, idx):
data = np.frombuffer(f.read(), dtype='uint8')
im = cv2.imdecode(data, 1)
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
im_info = np.array([roidb['im_id'][0], roidb['h'], roidb['w']], dtype='int32')
im_id = roidb['im_id']
im_shape = np.array([roidb['h'], roidb['w']], dtype='int32')
gt_bbox = roidb['gt_bbox']
gt_class = roidb['gt_class']
gt_score = roidb['gt_score']
return im_info, im, gt_bbox, gt_class, gt_score
return im_id, im_shape, im, gt_bbox, gt_class, gt_score

def __getitem__(self, idx):
im_info, im, gt_bbox, gt_class, gt_score = self._getitem_by_index(idx)
im_id, im_shape, im, gt_bbox, gt_class, gt_score = self._getitem_by_index(idx)

if self._mixup:
mixup_idx = idx + np.random.randint(1, self.__len__())
mixup_idx %= self.__len__()
_, mixup_im, mixup_bbox, mixup_class, _ = \
_, _, mixup_im, mixup_bbox, mixup_class, _ = \
self._getitem_by_index(mixup_idx)

im, gt_bbox, gt_class, gt_score = \
im_shape, im, gt_bbox, gt_class, gt_score = \
self._mixup_image(im, gt_bbox, gt_class, mixup_im,
mixup_bbox, mixup_class)

if self._transform:
im_info, im, gt_bbox, gt_class, gt_score = \
self._transform(im_info, im, gt_bbox, gt_class, gt_score)
im_id, im_shape, im, gt_bbox, gt_class, gt_score = \
self._transform(im_id, im_shape, im, gt_bbox, gt_class, gt_score)

return [im_info, im, gt_bbox, gt_class, gt_score]
return [im_id, im_shape, im, gt_bbox, gt_class, gt_score]

def _mixup_image(self, img1, bbox1, class1, img2, bbox2, class2):
factor = np.random.beta(self._alpha, self._beta)
Expand All @@ -234,7 +235,9 @@ def _mixup_image(self, img1, bbox1, class1, img2, bbox2, class2):
score2 = np.ones_like(class2, dtype="float32") * (1.0 - factor)
gt_score = np.concatenate((score1, score2), axis=0)

return img, gt_bbox, gt_class, gt_score
im_shape = np.array([h, w], dtype='int32')

return im_shape, img, gt_bbox, gt_class, gt_score

@property
def mixup(self):
Expand Down
3 changes: 2 additions & 1 deletion yolov3/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ def main():
device = set_device(FLAGS.device)
fluid.enable_dygraph(device) if FLAGS.dynamic else None

inputs = [Input([None, 3], 'int32', name='img_info'),
inputs = [Input([None, 1], 'int64', name='img_id'),
Input([None, 2], 'int32', name='img_shape'),
Input([None, 3, None, None], 'float32', name='image')]
labels = [Input([None, NUM_MAX_BOXES, 4], 'float32', name='gt_bbox'),
Input([None, NUM_MAX_BOXES], 'int32', name='gt_label'),
Expand Down
57 changes: 28 additions & 29 deletions yolov3/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,15 +145,15 @@ def apply_brightness(self, img):
img += delta
return img

def __call__(self, im_info, im, gt_bbox, gt_class, gt_score):
def __call__(self, im_id, im_shape, im, gt_bbox, gt_class, gt_score):
if self.random_apply:
distortions = np.random.permutation([
self.apply_brightness, self.apply_contrast,
self.apply_saturation, self.apply_hue
])
for func in distortions:
im = func(im)
return [im_info, im, gt_bbox, gt_class, gt_score]
return [im_id, im_shape, im, gt_bbox, gt_class, gt_score]

im = self.apply_brightness(im)

Expand All @@ -165,7 +165,7 @@ def __call__(self, im_info, im, gt_bbox, gt_class, gt_score):
im = self.apply_saturation(im)
im = self.apply_hue(im)
im = self.apply_contrast(im)
return [im_info, im, gt_bbox, gt_class, gt_score]
return [im_id, im_shape, im, gt_bbox, gt_class, gt_score]


class RandomExpand(object):
Expand All @@ -183,16 +183,16 @@ def __init__(self, ratio=4., prob=0.5, fill_value=[123.675, 116.28, 103.53]):
self.prob = prob
self.fill_value = fill_value

def __call__(self, im_info, im, gt_bbox, gt_class, gt_score):
def __call__(self, im_id, im_shape, im, gt_bbox, gt_class, gt_score):
if np.random.uniform(0., 1.) < self.prob:
return [im_info, im, gt_bbox, gt_class, gt_score]
return [im_id, im_shape, im, gt_bbox, gt_class, gt_score]

height, width, _ = im.shape
expand_ratio = np.random.uniform(1., self.ratio)
h = int(height * expand_ratio)
w = int(width * expand_ratio)
if not h > height or not w > width:
return [im_info, im, gt_bbox, gt_class, gt_score]
return [im_id, im_shape, im, gt_bbox, gt_class, gt_score]
y = np.random.randint(0, h - height)
x = np.random.randint(0, w - width)
canvas = np.ones((h, w, 3), dtype=np.uint8)
Expand All @@ -201,7 +201,7 @@ def __call__(self, im_info, im, gt_bbox, gt_class, gt_score):

gt_bbox += np.array([x, y, x, y], dtype=np.float32)

return [im_info, canvas, gt_bbox, gt_class, gt_score]
return [im_id, im_shape, canvas, gt_bbox, gt_class, gt_score]


class RandomCrop():
Expand Down Expand Up @@ -232,9 +232,9 @@ def __init__(self,
self.allow_no_crop = allow_no_crop
self.cover_all_box = cover_all_box

def __call__(self, im_info, im, gt_bbox, gt_class, gt_score):
def __call__(self, im_id, im_shape, im, gt_bbox, gt_class, gt_score):
if len(gt_bbox) == 0:
return [im_info, im, gt_bbox, gt_class, gt_score]
return [im_id, im_shape, im, gt_bbox, gt_class, gt_score]

# NOTE Original method attempts to generate one candidate for each
# threshold then randomly sample one from the resulting list.
Expand All @@ -251,7 +251,7 @@ def __call__(self, im_info, im, gt_bbox, gt_class, gt_score):

for thresh in thresholds:
if thresh == 'no_crop':
return [im_info, im, gt_bbox, gt_class, gt_score]
return [im_id, im_shape, im, gt_bbox, gt_class, gt_score]

h, w, _ = im.shape
found = False
Expand Down Expand Up @@ -286,9 +286,9 @@ def __call__(self, im_info, im, gt_bbox, gt_class, gt_score):
gt_bbox = np.take(cropped_box, valid_ids, axis=0)
gt_class = np.take(gt_class, valid_ids, axis=0)
gt_score = np.take(gt_score, valid_ids, axis=0)
return [im_info, im, gt_bbox, gt_class, gt_score]
return [im_id, im_shape, im, gt_bbox, gt_class, gt_score]

return [im_info, im, gt_bbox, gt_class, gt_score]
return [im_id, im_shape, im, gt_bbox, gt_class, gt_score]

def _iou_matrix(self, a, b):
tl_i = np.maximum(a[:, np.newaxis, :2], b[:, :2])
Expand Down Expand Up @@ -334,7 +334,7 @@ def __init__(self, prob=0.5, is_normalized=False):
isinstance(self.is_normalized, bool)):
raise TypeError("{}: input type is invalid.".format(self))

def __call__(self, im_info, im, gt_bbox, gt_class, gt_score):
def __call__(self, im_id, im_shape, im, gt_bbox, gt_class, gt_score):
"""Filp the image and bounding box.
Operators:
1. Flip the image numpy.
Expand Down Expand Up @@ -363,20 +363,20 @@ def __call__(self, im_info, im, gt_bbox, gt_class, gt_score):
m = "{}: invalid box, x2 should be greater than x1".format(
self)
raise ValueError(m)
return [im_info, im, gt_bbox, gt_class, gt_score]
return [im_id, im_shape, im, gt_bbox, gt_class, gt_score]


class NormalizeBox(object):
"""Transform the bounding box's coornidates to [0,1]."""

def __call__(self, im_info, im, gt_bbox, gt_class, gt_score):
def __call__(self, im_id, im_shape, im, gt_bbox, gt_class, gt_score):
height, width, _ = im.shape
for i in range(gt_bbox.shape[0]):
gt_bbox[i][0] = gt_bbox[i][0] / width
gt_bbox[i][1] = gt_bbox[i][1] / height
gt_bbox[i][2] = gt_bbox[i][2] / width
gt_bbox[i][3] = gt_bbox[i][3] / height
return [im_info, im, gt_bbox, gt_class, gt_score]
return [im_id, im_shape, im, gt_bbox, gt_class, gt_score]


class PadBox(object):
Expand All @@ -388,7 +388,7 @@ def __init__(self, num_max_boxes=50):
"""
self.num_max_boxes = num_max_boxes

def __call__(self, im_info, im, gt_bbox, gt_class, gt_score):
def __call__(self, im_id, im_shape, im, gt_bbox, gt_class, gt_score):
gt_num = min(self.num_max_boxes, len(gt_bbox))
num_max = self.num_max_boxes

Expand All @@ -406,18 +406,18 @@ def __call__(self, im_info, im, gt_bbox, gt_class, gt_score):
if gt_num > 0:
pad_score[:gt_num] = gt_score[:gt_num, 0]
gt_score = pad_score
return [im_info, im, gt_bbox, gt_class, gt_score]
return [im_id, im_shape, im, gt_bbox, gt_class, gt_score]


class BboxXYXY2XYWH(object):
"""
Convert bbox XYXY format to XYWH format.
"""

def __call__(self, im_info, im, gt_bbox, gt_class, gt_score):
def __call__(self, im_id, im_shape, im, gt_bbox, gt_class, gt_score):
gt_bbox[:, 2:4] = gt_bbox[:, 2:4] - gt_bbox[:, :2]
gt_bbox[:, :2] = gt_bbox[:, :2] + gt_bbox[:, 2:4] / 2.
return [im_info, im, gt_bbox, gt_class, gt_score]
return [im_id, im_shape, im, gt_bbox, gt_class, gt_score]


class RandomShape(object):
Expand Down Expand Up @@ -450,13 +450,13 @@ def __call__(self, samples):
method = np.random.choice(self.interps) if self.random_inter \
else cv2.INTER_NEAREST
for i in range(len(samples)):
im = samples[i][1]
im = samples[i][2]
h, w = im.shape[:2]
scale_x = float(shape) / w
scale_y = float(shape) / h
im = cv2.resize(
im, None, None, fx=scale_x, fy=scale_y, interpolation=method)
samples[i][1] = im
samples[i][2] = im
return samples


Expand Down Expand Up @@ -492,7 +492,7 @@ def __call__(self, samples):
3. (optional) permute channel
"""
for i in range(len(samples)):
im = samples[i][1]
im = samples[i][2]
im = im.astype(np.float32, copy=False)
mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
std = np.array(self.std)[np.newaxis, np.newaxis, :]
Expand All @@ -502,7 +502,7 @@ def __call__(self, samples):
im /= std
if self.channel_first:
im = im.transpose((2, 0, 1))
samples[i][1] = im
samples[i][2] = im
return samples


Expand Down Expand Up @@ -595,16 +595,15 @@ def __init__(self,
format(type(target_size)))
self.target_size = target_size

def __call__(self, im_info, im, gt_bbox, gt_class, gt_score):
def __call__(self, im_id, im_shape, im, gt_bbox, gt_class, gt_score):
""" Resize the image numpy.
"""
if not isinstance(im, np.ndarray):
raise TypeError("{}: image type is not numpy.".format(self))
if len(im.shape) != 3:
raise ImageError('{}: image is not 3-dimensional.'.format(self))
im_shape = im.shape
im_scale_x = float(self.target_size) / float(im_shape[1])
im_scale_y = float(self.target_size) / float(im_shape[0])
im_scale_x = float(self.target_size) / float(im.shape[1])
im_scale_y = float(self.target_size) / float(im.shape[0])
resize_w = self.target_size
resize_h = self.target_size

Expand All @@ -616,5 +615,5 @@ def __call__(self, im_info, im, gt_bbox, gt_class, gt_score):
fy=im_scale_y,
interpolation=self.interp)

return [im_info, im, gt_bbox, gt_class, gt_score]
return [im_id, im_shape, im, gt_bbox, gt_class, gt_score]

0 comments on commit 3c5f074

Please sign in to comment.