Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rename some funcs to make the code clearer #47

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 0 additions & 87 deletions lib/fpn/proposal_assignments/proposal_assignments_gtbox.py

This file was deleted.

90 changes: 84 additions & 6 deletions lib/fpn/proposal_assignments/rel_assignments.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,13 @@

import numpy as np
import numpy.random as npr
from config import BG_THRESH_HI, BG_THRESH_LO, REL_FG_FRACTION, RELS_PER_IMG_REFINE
from lib.fpn.box_utils import bbox_overlaps
from lib.pytorch_misc import to_variable, nonintersecting_2d_inds
from collections import defaultdict
from lib.pytorch_misc import enumerate_by_image, gather_nd, random_choose
from lib.fpn.box_utils import bbox_preds, center_size, bbox_overlaps
import torch

from lib.pytorch_misc import diagonal_inds, to_variable
from config import RELS_PER_IMG, REL_FG_FRACTION
@to_variable
def rel_assignments(im_inds, rpn_rois, roi_gtlabels, gt_boxes, gt_classes, gt_rels, image_offset,
def rel_assignments_sgdet(im_inds, rpn_rois, roi_gtlabels, gt_boxes, gt_classes, gt_rels, image_offset,
fg_thresh=0.5, num_sample_per_gt=4, filter_non_overlap=True):
"""
Assign object detection proposals to ground-truth targets. Produces proposal
Expand Down Expand Up @@ -143,3 +142,82 @@ def rel_assignments(im_inds, rpn_rois, roi_gtlabels, gt_boxes, gt_classes, gt_re
rel_labels = torch.LongTensor(np.concatenate(rel_labels, 0)).cuda(rpn_rois.get_device(),
async=True)
return rel_labels

@to_variable
def rel_assignments_sgcls(rois, gt_boxes, gt_classes, gt_rels, image_offset, fg_thresh=0.5):
    """
    Assign relationship labels in SGCLS mode (ground-truth boxes are used as the
    proposals), sampling background pairs to balance the proportion of positive
    and negative relation samples.

    :param rois: [num_rois, 5] array of [img_ind, x1, y1, x2, y2]
    :param gt_boxes: [num_boxes, 4] array of [x1, y1, x2, y2]. Not needed it seems
    :param gt_classes: [num_boxes, 2] array of [img_ind, class].
        Note, the img_inds here start at image_offset
    :param gt_rels: [num_rels, 4] array of [img_ind, box_0, box_1, rel type].
        Note, the img_inds here start at image_offset
    :param image_offset: offset subtracted from the gt_rels image indices so they
        start at 0 for this minibatch
    :param fg_thresh: unused here — the ROIs are the ground-truth boxes, so every
        ROI matches its GT exactly. Accepted for interface compatibility with
        callers that pass it (object_detector.gt_boxes passes fg_thresh=0.5).
    :return:
        rois: [num_rois, 5], passed through unchanged
        labels: [num_rois] array of object class labels
        rel_labels: [num_rels, 4] (img ind, box0 ind, box1 ind, rel type);
            rel type 0 marks a sampled background (no-relation) pair
    """
    im_inds = rois[:, 0].long()

    num_im = im_inds[-1] + 1

    # Offset the image indices in fg_rels to refer to absolute indices (not just within img i)
    fg_rels = gt_rels.clone()
    fg_rels[:, 0] -= image_offset
    offset = {}
    for i, s, e in enumerate_by_image(im_inds):
        offset[i] = s
    for i, s, e in enumerate_by_image(fg_rels[:, 0]):
        fg_rels[s:e, 1:3] += offset[i]

    # Candidate (subject, object) pairs: ALL ordered pairs of distinct boxes
    # within the same image, not just intersecting ones.
    is_cand = (im_inds[:, None] == im_inds[None])
    is_cand.view(-1)[diagonal_inds(is_cand)] = 0

    # Exclude the FG pairs so the remaining candidates are background only.
    # TODO: check if this causes an error if many duplicate GTs haven't been filtered out
    is_cand.view(-1)[fg_rels[:, 1] * im_inds.size(0) + fg_rels[:, 2]] = 0
    is_bgcand = is_cand.nonzero()

    # TODO: make this sample on a per image case
    # If too many foreground relations then subsample to the FG budget.
    num_fg = min(fg_rels.size(0), int(RELS_PER_IMG * REL_FG_FRACTION * num_im))
    if num_fg < fg_rels.size(0):
        fg_rels = random_choose(fg_rels, num_fg)

    # Fill the remaining budget with background pairs. nonzero() on an all-zero
    # mask returns a 0-dim tensor in this pytorch version, hence the dim() check.
    num_bg = min(is_bgcand.size(0) if is_bgcand.dim() > 0 else 0,
                 int(RELS_PER_IMG * num_im) - num_fg)
    if num_bg > 0:
        bg_rels = torch.cat((
            im_inds[is_bgcand[:, 0]][:, None],
            is_bgcand,
            # Always-false comparison -> a zero column, i.e. relation label 0 (background).
            (is_bgcand[:, 0, None] < -10).long(),
        ), 1)

        if num_bg < is_bgcand.size(0):
            bg_rels = random_choose(bg_rels, num_bg)
        rel_labels = torch.cat((fg_rels, bg_rels), 0)
    else:
        rel_labels = fg_rels

    # Last, sort by (image, subject box, object box).
    _, perm = torch.sort(rel_labels[:, 0] * (gt_boxes.size(0) ** 2) +
                         rel_labels[:, 1] * gt_boxes.size(0) + rel_labels[:, 2])

    rel_labels = rel_labels[perm].contiguous()

    labels = gt_classes[:, 1].contiguous()
    return rois, labels, rel_labels
4 changes: 2 additions & 2 deletions lib/object_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from lib.fpn.generate_anchors import generate_anchors
from lib.fpn.box_utils import bbox_preds, center_size, bbox_overlaps
from lib.fpn.nms.functions.nms import apply_nms
from lib.fpn.proposal_assignments.proposal_assignments_gtbox import proposal_assignments_gtbox
from lib.fpn.proposal_assignments.rel_assignments import rel_assignments_sgcls
from lib.fpn.proposal_assignments.proposal_assignments_det import proposal_assignments_det

from lib.fpn.roi_align.functions.roi_align import RoIAlignFunction
Expand Down Expand Up @@ -214,7 +214,7 @@ def gt_boxes(self, fmap, im_sizes, image_offset, gt_boxes=None, gt_classes=None,
im_inds = gt_classes[:, 0] - image_offset
rois = torch.cat((im_inds.float()[:, None], gt_boxes), 1)
if gt_rels is not None and self.training:
rois, labels, rel_labels = proposal_assignments_gtbox(
rois, labels, rel_labels = rel_assignments_sgcls(
rois.data, gt_boxes.data, gt_classes.data, gt_rels.data, image_offset,
fg_thresh=0.5)
else:
Expand Down
8 changes: 4 additions & 4 deletions lib/rel_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from lib.lstm.highway_lstm_cuda.alternating_highway_lstm import AlternatingHighwayLSTM
from lib.fpn.box_utils import bbox_overlaps, center_size
from lib.get_union_boxes import UnionBoxesAndFeats
from lib.fpn.proposal_assignments.rel_assignments import rel_assignments
from lib.fpn.proposal_assignments.rel_assignments import rel_assignments_sgdet
from lib.object_detector import ObjectDetector, gather_res, load_vgg
from lib.pytorch_misc import transpose_packed_sequence_inds, to_onehot, arange, enumerate_by_image, diagonal_inds, Flattener
from lib.sparse_targets import FrequencyBias
Expand Down Expand Up @@ -464,10 +464,10 @@ def forward(self, x, im_sizes, image_offset,
be used to compute the training loss. Each (img_ind, fpn_idx)
:return: If train:
scores, boxdeltas, labels, boxes, boxtargets, rpnscores, rpnboxes, rellabels

if test:
prob dists, boxes, img inds, maxscores, classes

"""
result = self.detector(x, im_sizes, image_offset, gt_boxes, gt_classes, gt_rels, proposals,
train_anchor_inds, return_fmap=True)
Expand All @@ -479,7 +479,7 @@ def forward(self, x, im_sizes, image_offset,

if self.training and result.rel_labels is None:
assert self.mode == 'sgdet'
result.rel_labels = rel_assignments(im_inds.data, boxes.data, result.rm_obj_labels.data,
result.rel_labels = rel_assignments_sgdet(im_inds.data, boxes.data, result.rm_obj_labels.data,
gt_boxes.data, gt_classes.data, gt_rels.data,
image_offset, filter_non_overlap=True,
num_sample_per_gt=1)
Expand Down
10 changes: 5 additions & 5 deletions lib/rel_model_stanford.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from torch.autograd import Variable
from torch.nn import functional as F
from lib.surgery import filter_dets
from lib.fpn.proposal_assignments.rel_assignments import rel_assignments
from lib.fpn.proposal_assignments.rel_assignments import rel_assignments_sgdet
from lib.pytorch_misc import arange
from lib.object_detector import filter_det
from lib.rel_model import RelModel
Expand Down Expand Up @@ -63,7 +63,7 @@ def message_pass(self, rel_rep, obj_rep, rel_inds):
:param rel_rep: [num_rel, fc]
:param obj_rep: [num_obj, fc]
:param rel_inds: [num_rel, 2] of the valid relationships
:return: object prediction [num_obj, 151], bbox_prediction [num_obj, 151*4]
:return: object prediction [num_obj, 151], bbox_prediction [num_obj, 151*4]
and rel prediction [num_rel, 51]
"""
# [num_obj, num_rel] with binary!
Expand Down Expand Up @@ -125,10 +125,10 @@ def forward(self, x, im_sizes, image_offset,
be used to compute the training loss. Each (img_ind, fpn_idx)
:return: If train:
scores, boxdeltas, labels, boxes, boxtargets, rpnscores, rpnboxes, rellabels

if test:
prob dists, boxes, img inds, maxscores, classes

"""
result = self.detector(x, im_sizes, image_offset, gt_boxes, gt_classes, gt_rels, proposals,
train_anchor_inds, return_fmap=True)
Expand All @@ -141,7 +141,7 @@ def forward(self, x, im_sizes, image_offset,

if self.training and result.rel_labels is None:
assert self.mode == 'sgdet'
result.rel_labels = rel_assignments(im_inds.data, boxes.data, result.rm_obj_labels.data,
result.rel_labels = rel_assignments_sgdet(im_inds.data, boxes.data, result.rm_obj_labels.data,
gt_boxes.data, gt_classes.data, gt_rels.data,
image_offset, filter_non_overlap=True, num_sample_per_gt=1)
rel_inds = self.get_rel_inds(result.rel_labels, im_inds, boxes)
Expand Down