Skip to content

Commit

Permalink
fix batch size > 1, test=dygraph (PaddlePaddle#2141)
Browse files Browse the repository at this point in the history
  • Loading branch information
jerrywgz authored Jan 30, 2021
1 parent 32edc34 commit 3cd9e85
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 8 deletions.
4 changes: 2 additions & 2 deletions dygraph/ppdet/data/transform/batch_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,8 @@ def __call__(self, samples, context=None):
gt_num_max = max(gt_num)

for i, data in enumerate(samples):
gt_box_data = np.zeros([gt_num_max, 4], dtype=np.float32)
gt_class_data = np.zeros([gt_num_max], dtype=np.int32)
gt_box_data = -np.ones([gt_num_max, 4], dtype=np.float32)
gt_class_data = -np.ones([gt_num_max], dtype=np.int32)
is_crowd_data = np.ones([gt_num_max], dtype=np.int32)

if pad_mask:
Expand Down
11 changes: 5 additions & 6 deletions dygraph/ppdet/modeling/proposal_generator/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,16 +69,15 @@ def label_box(anchors, gt_boxes, positive_overlap, negative_overlap,
return default_matches, default_match_labels
matched_vals, matches = paddle.topk(iou, k=1, axis=0)
match_labels = paddle.full(matches.shape, -1, dtype='int32')

match_labels = paddle.where(matched_vals < negative_overlap,
paddle.zeros_like(match_labels), match_labels)
match_labels = paddle.where(matched_vals >= positive_overlap,
paddle.ones_like(match_labels), match_labels)
if allow_low_quality:
highest_quality_foreach_gt = iou.max(axis=1, keepdim=True)
pred_inds_with_highest_quality = (
iou == highest_quality_foreach_gt).cast('int32').sum(0,
keepdim=True)
pred_inds_with_highest_quality = paddle.logical_and(
iou > 0, iou == highest_quality_foreach_gt).cast('int32').sum(
0, keepdim=True)
match_labels = paddle.where(pred_inds_with_highest_quality > 0,
paddle.ones_like(match_labels),
match_labels)
Expand Down Expand Up @@ -151,7 +150,7 @@ def generate_proposal_target(rpn_rois,
for i, rpn_roi in enumerate(rpn_rois):
max_overlap = max_overlaps[i] if is_cascade_rcnn else None
gt_bbox = gt_boxes[i]
gt_classes = gt_classes[i]
gt_class = gt_classes[i]
if is_cascade_rcnn:
rpn_roi = filter_roi(rpn_roi, max_overlap)
bbox = paddle.concat([rpn_roi, gt_bbox])
Expand All @@ -161,7 +160,7 @@ def generate_proposal_target(rpn_rois,
bbox, gt_bbox, fg_thresh, bg_thresh, False)
# Step2: sample bbox
sampled_inds, sampled_gt_classes = sample_bbox(
matches, match_labels, gt_classes, batch_size_per_im, fg_fraction,
matches, match_labels, gt_class, batch_size_per_im, fg_fraction,
num_classes, use_random)

# Step3: make output
Expand Down

0 comments on commit 3cd9e85

Please sign in to comment.