Skip to content

Commit

Permalink
fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
longcw committed May 16, 2017
1 parent 84d05e3 commit edda7d4
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 31 deletions.
14 changes: 13 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,19 @@ YOLO9000: Better, Faster, Stronger by Joseph Redmon and Ali Farhadi.

I used a Cython extension for postprocessing and
`multiprocessing.Pool` for image preprocessing.
Testing an image in VOC2007 costs about 13~20ms.
Testing an image in VOC2007 costs about 13~20ms.

**NOTE:**
This is still an experimental project.
VOC07 test mAP is about 0.71 (trained on VOC07+12 trainval,
reported by [@cory8249](https://github.com/longcw/yolo2-pytorch/issues/23)).
See https://github.com/longcw/yolo2-pytorch/issues/1 and https://github.com/longcw/yolo2-pytorch/issues/23
for more details about training.

BTW, I recommend to write your own dataloader using [torch.utils.data.Dataset](http://pytorch.org/docs/data.html)
since `multiprocessing.Pool.imap` won't stop even there is no enough memory space.



### Installation and demo
1. Clone this repository
Expand Down
48 changes: 18 additions & 30 deletions darknet.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def _process_batch(data):
inp_size = cfg.inp_size
out_size = cfg.out_size

bbox_pred_np, gt_boxes, gt_classes, dontcares = data
bbox_pred_np, gt_boxes, gt_classes, dontcares, iou_pred_np = data

# net output
hw, num_anchors, _ = bbox_pred_np.shape
Expand All @@ -61,21 +61,22 @@ def _process_batch(data):
np.ascontiguousarray(bbox_pred_np, dtype=np.float),
anchors,
H, W)
bbox_np = bbox_np[0]
bbox_np[:, :, 0::2] *= float(inp_size[0])
bbox_np[:, :, 1::2] *= float(inp_size[1])
bbox_np = bbox_np[0] # bbox_np = (hw, num_anchors, (x1, y1, x2, y2)) range: 0 ~ 1
bbox_np[:, :, 0::2] *= float(inp_size[0]) # rescale x
bbox_np[:, :, 1::2] *= float(inp_size[1]) # rescale y

# gt_boxes_b = np.asarray(gt_boxes[b], dtype=np.float)
gt_boxes_b = np.asarray(gt_boxes, dtype=np.float)

# for each cell
# for each cell, compare predicted_bbox and gt_bbox
bbox_np_b = np.reshape(bbox_np, [-1, 4])
ious = bbox_ious(
np.ascontiguousarray(bbox_np_b, dtype=np.float),
np.ascontiguousarray(gt_boxes_b, dtype=np.float)
)
best_ious = np.max(ious, axis=1).reshape(_iou_mask.shape)
_iou_mask[best_ious <= cfg.iou_thresh] = cfg.noobject_scale
iou_penalty = 0 - iou_pred_np[best_ious < cfg.iou_thresh]
_iou_mask[best_ious <= cfg.iou_thresh] = cfg.noobject_scale * iou_penalty

# locate the cell of each gt_boxe
cell_w = float(inp_size[0]) / W
Expand Down Expand Up @@ -108,7 +109,8 @@ def _process_batch(data):
continue
a = anchor_inds[i]

_iou_mask[cell_ind, a, :] = cfg.object_scale
iou_pred_cell_anchor = iou_pred_np[cell_ind, a, :] # 0 ~ 1, should be close to 1
_iou_mask[cell_ind, a, :] = cfg.object_scale * (1 - iou_pred_cell_anchor)
# _ious[cell_ind, a, :] = anchor_ious[a, i]
_ious[cell_ind, a, :] = ious_reshaped[cell_ind, a, i]

Expand All @@ -119,8 +121,8 @@ def _process_batch(data):
_class_mask[cell_ind, a, :] = cfg.class_scale
_classes[cell_ind, a, gt_classes[i]] = 1.

_boxes[:, :, 2:4] = np.maximum(_boxes[:, :, 2:4], 0.001)
_boxes[:, :, 2:4] = np.log(_boxes[:, :, 2:4])
# _boxes[:, :, 2:4] = np.maximum(_boxes[:, :, 2:4], 0.001)
# _boxes[:, :, 2:4] = np.log(_boxes[:, :, 2:4])

return _boxes, _ious, _classes, _box_mask, _iou_mask, _class_mask

Expand Down Expand Up @@ -172,14 +174,10 @@ def loss(self):

def forward(self, im_data, gt_boxes=None, gt_classes=None, dontcare=None):
conv1s = self.conv1s(im_data)

conv2 = self.conv2(conv1s)

conv3 = self.conv3(conv2)

conv1s_reorg = self.reorg(conv1s)
cat_1_3 = torch.cat([conv1s_reorg, conv3], 1)

conv4 = self.conv4(cat_1_3)
conv5 = self.conv5(conv4) # batch_size, out_channels, h, w

Expand All @@ -191,11 +189,8 @@ def forward(self, im_data, gt_boxes=None, gt_classes=None, dontcare=None):

# tx, ty, tw, th, to -> sig(tx), sig(ty), exp(tw), exp(th), sig(to)
xy_pred = F.sigmoid(conv5_reshaped[:, :, :, 0:2])

wh_pred = conv5_reshaped[:, :, :, 2:4]
wh_pred_exp = torch.exp(wh_pred)
bbox_pred = torch.cat([xy_pred, wh_pred_exp], 3)

wh_pred = torch.exp(conv5_reshaped[:, :, :, 2:4])
bbox_pred = torch.cat([xy_pred, wh_pred], 3)
iou_pred = F.sigmoid(conv5_reshaped[:, :, :, 4:5])

score_pred = conv5_reshaped[:, :, :, 5:].contiguous()
Expand All @@ -204,8 +199,9 @@ def forward(self, im_data, gt_boxes=None, gt_classes=None, dontcare=None):
# for training
if self.training:
bbox_pred_np = bbox_pred.data.cpu().numpy()
iou_pred_np = iou_pred.data.cpu().numpy()
_boxes, _ious, _classes, _box_mask, _iou_mask, _class_mask = self._build_target(
bbox_pred_np, gt_boxes, gt_classes, dontcare)
bbox_pred_np, gt_boxes, gt_classes, dontcare, iou_pred_np)

_boxes = net_utils.np_to_variable(_boxes)
_ious = net_utils.np_to_variable(_ious)
Expand All @@ -218,30 +214,23 @@ def forward(self, im_data, gt_boxes=None, gt_classes=None, dontcare=None):

# _boxes[:, :, :, 2:4] = torch.log(_boxes[:, :, :, 2:4])
box_mask = box_mask.expand_as(_boxes)
# self.bbox_loss = torch.sum(torch.pow(_boxes - bbox_pred, 2) * box_mask) / num_boxes
bbox_pred_log = torch.cat([xy_pred, wh_pred], 3)
self.bbox_loss = nn.MSELoss(size_average=False)(bbox_pred_log * box_mask, _boxes * box_mask) / num_boxes

# self.iou_loss = torch.sum(torch.pow(_ious - iou_pred, 2) * iou_mask) / num_boxes
self.bbox_loss = nn.MSELoss(size_average=False)(bbox_pred * box_mask, _boxes * box_mask) / num_boxes
self.iou_loss = nn.MSELoss(size_average=False)(iou_pred * iou_mask, _ious * iou_mask) / num_boxes

class_mask = class_mask.expand_as(prob_pred)
# self.cls_loss = torch.sum(torch.pow(_classes - prob_pred, 2) * class_mask) / num_boxes
self.cls_loss = nn.MSELoss(size_average=False)(prob_pred * class_mask, _classes * class_mask) / num_boxes

# wh_pred = torch.exp(conv5_reshaped[:, :, :, 2:4])
# bbox_pred = torch.cat([xy_pred, wh_pred], 3)

return bbox_pred, iou_pred, prob_pred

def _build_target(self, bbox_pred_np, gt_boxes, gt_classes, dontcare):
def _build_target(self, bbox_pred_np, gt_boxes, gt_classes, dontcare, iou_pred_np):
"""
:param bbox_pred: shape: (bsize, h x w, num_anchors, 4) : (sig(tx), sig(ty), exp(tw), exp(th))
"""

bsize = bbox_pred_np.shape[0]

targets = self.pool.map(_process_batch, ((bbox_pred_np[b], gt_boxes[b], gt_classes[b], dontcare[b]) for b in range(bsize)))
targets = self.pool.map(_process_batch, ((bbox_pred_np[b], gt_boxes[b], gt_classes[b], dontcare[b], iou_pred_np[b]) for b in range(bsize)))

_boxes = np.stack(tuple((row[0] for row in targets)))
_ious = np.stack(tuple((row[1] for row in targets)))
Expand All @@ -250,7 +239,6 @@ def _build_target(self, bbox_pred_np, gt_boxes, gt_classes, dontcare):
_iou_mask = np.stack(tuple((row[4] for row in targets)))
_class_mask = np.stack(tuple((row[5] for row in targets)))


return _boxes, _ious, _classes, _box_mask, _iou_mask, _class_mask

def load_from_npz(self, fname, num_conv=None):
Expand Down

0 comments on commit edda7d4

Please sign in to comment.