Skip to content

Commit

Permalink
Merge pull request #76 from rbgirshick/datasets-refactor
Browse files Browse the repository at this point in the history
Refactor datasets package (empty __init__.py)
  • Loading branch information
rbgirshick committed Feb 9, 2016
2 parents 95918a5 + 8d448ba commit 4680720
Show file tree
Hide file tree
Showing 7 changed files with 24 additions and 55 deletions.
36 changes: 0 additions & 36 deletions lib/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,39 +4,3 @@
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------

from .imdb import imdb
from .pascal_voc import pascal_voc
from . import factory

import os.path as osp
ROOT_DIR = osp.join(osp.dirname(__file__), '..', '..')

# We assume your MATLAB binary is on your PATH and is called `matlab`.
# If either is not true, add it to your PATH and alias it as `matlab`, or
# change the MATLAB setting in this file.
MATLAB = 'matlab'

# http://stackoverflow.com/questions/377017/test-if-executable-exists-in-python
def _which(program):
    """Locate an executable, in the spirit of the shell's ``which``.

    Args:
        program: Executable name (e.g. 'matlab') or a path to an executable.

    Returns:
        The path at which an executable regular file was found, or None
        when nothing matches. If `program` has a directory component it is
        tested exactly as given and PATH is not consulted; otherwise each
        PATH entry is tried in order and the first match wins.
    """
    import os

    def is_exe(fpath):
        # Must be a regular file that the current user may execute.
        return os.path.isfile(fpath) and os.access(fpath, os.X_OK)

    fpath, _ = os.path.split(program)
    if fpath:
        if is_exe(program):
            return program
    else:
        # Fall back to the platform default search path when PATH is unset
        # rather than failing with KeyError.
        search_path = os.environ.get("PATH", os.defpath)
        for path in search_path.split(os.pathsep):
            # PATH entries are sometimes wrapped in double quotes.
            path = path.strip('"')
            exe_file = os.path.join(path, program)
            if is_exe(exe_file):
                return exe_file

    return None

# Fail at import time when the configured MATLAB command cannot be located.
if _which(MATLAB) is None:
    raise EnvironmentError(
        ("MATLAB command '{}' not found. "
         "Please add '{}' to your PATH.").format(MATLAB, MATLAB))
6 changes: 3 additions & 3 deletions lib/datasets/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@

__sets = {}

import datasets.pascal_voc
from datasets.pascal_voc import pascal_voc
import numpy as np

def _selective_search_IJCV_top_k(split, year, top_k):
"""Return an imdb that uses the top k proposals from the selective search
IJCV code.
"""
imdb = datasets.pascal_voc(split, year)
imdb = pascal_voc(split, year)
imdb.roidb_handler = imdb.selective_search_IJCV_roidb
imdb.config['top_k'] = top_k
return imdb
Expand All @@ -26,7 +26,7 @@ def _selective_search_IJCV_top_k(split, year, top_k):
for split in ['train', 'val', 'trainval', 'test']:
name = 'voc_{}_{}'.format(year, split)
__sets[name] = (lambda split=split, year=year:
datasets.pascal_voc(split, year))
pascal_voc(split, year))

# Set up voc_<year>_<split>_top_<k> using selective search "quality" mode
# but only returning the first k boxes
Expand Down
4 changes: 2 additions & 2 deletions lib/datasets/imdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from utils.cython_bbox import bbox_overlaps
import numpy as np
import scipy.sparse
import datasets
from fast_rcnn.config import cfg

class imdb(object):
"""Image database."""
Expand Down Expand Up @@ -69,7 +69,7 @@ def roidb(self):

@property
def cache_path(self):
cache_path = osp.abspath(osp.join(datasets.ROOT_DIR, 'data', 'cache'))
cache_path = osp.abspath(osp.join(cfg.DATA_DIR, 'cache'))
if not os.path.exists(cache_path):
os.makedirs(cache_path)
return cache_path
Expand Down
24 changes: 12 additions & 12 deletions lib/datasets/pascal_voc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,20 @@
# Written by Ross Girshick
# --------------------------------------------------------

import datasets
import datasets.pascal_voc
import os
import datasets.imdb
from datasets.imdb import imdb
import xml.dom.minidom as minidom
import numpy as np
import scipy.sparse
import scipy.io as sio
import utils.cython_bbox
import cPickle
import subprocess
from fast_rcnn.config import cfg

class pascal_voc(datasets.imdb):
class pascal_voc(imdb):
def __init__(self, image_set, year, devkit_path=None):
datasets.imdb.__init__(self, 'voc_' + year + '_' + image_set)
imdb.__init__(self, 'voc_' + year + '_' + image_set)
self._year = year
self._image_set = image_set
self._devkit_path = self._get_default_path() if devkit_path is None \
Expand Down Expand Up @@ -83,7 +82,7 @@ def _get_default_path(self):
"""
Return the default path where PASCAL VOC is expected to be installed.
"""
return os.path.join(datasets.ROOT_DIR, 'data', 'VOCdevkit' + self._year)
return os.path.join(cfg.DATA_DIR, 'VOCdevkit' + self._year)

def gt_roidb(self):
"""
Expand Down Expand Up @@ -125,7 +124,7 @@ def selective_search_roidb(self):
if int(self._year) == 2007 or self._image_set != 'test':
gt_roidb = self.gt_roidb()
ss_roidb = self._load_selective_search_roidb(gt_roidb)
roidb = datasets.imdb.merge_roidbs(gt_roidb, ss_roidb)
roidb = imdb.merge_roidbs(gt_roidb, ss_roidb)
else:
roidb = self._load_selective_search_roidb(None)
with open(cache_file, 'wb') as fid:
Expand All @@ -138,7 +137,7 @@ def rpn_roidb(self):
if int(self._year) == 2007 or self._image_set != 'test':
gt_roidb = self.gt_roidb()
rpn_roidb = self._load_rpn_roidb(gt_roidb)
roidb = datasets.imdb.merge_roidbs(gt_roidb, rpn_roidb)
roidb = imdb.merge_roidbs(gt_roidb, rpn_roidb)
else:
roidb = self._load_rpn_roidb(None)

Expand All @@ -154,7 +153,7 @@ def _load_rpn_roidb(self, gt_roidb):
return self.create_roidb_from_box_list(box_list, gt_roidb)

def _load_selective_search_roidb(self, gt_roidb):
filename = os.path.abspath(os.path.join(self.cache_path, '..',
filename = os.path.abspath(os.path.join(cfg.DATA_DIR,
'selective_search_data',
self.name + '.mat'))
assert os.path.exists(filename), \
Expand Down Expand Up @@ -245,10 +244,10 @@ def _write_voc_results_file(self, all_boxes):
def _do_matlab_eval(self, comp_id, output_dir='output'):
rm_results = self.config['cleanup']

path = os.path.join(os.path.dirname(__file__),
path = os.path.join(cfg.ROOT_DIR, 'lib', 'datasets',
'VOCdevkit-matlab-wrapper')
cmd = 'cd {} && '.format(path)
cmd += '{:s} -nodisplay -nodesktop '.format(datasets.MATLAB)
cmd += '{:s} -nodisplay -nodesktop '.format(cfg.MATLAB)
cmd += '-r "dbstop if error; '
cmd += 'voc_eval(\'{:s}\',\'{:s}\',\'{:s}\',\'{:s}\',{:d}); quit;"' \
.format(self._devkit_path, comp_id,
Expand All @@ -269,6 +268,7 @@ def competition_mode(self, on):
self.config['cleanup'] = True

if __name__ == '__main__':
d = datasets.pascal_voc('trainval', '2007')
from datasets.pascal_voc import pascal_voc
d = pascal_voc('trainval', '2007')
res = d.roidb
from IPython import embed; embed()
6 changes: 6 additions & 0 deletions lib/fast_rcnn/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,12 @@
# Root directory of project
__C.ROOT_DIR = osp.abspath(osp.join(osp.dirname(__file__), '..', '..'))

# Data directory
__C.DATA_DIR = osp.abspath(osp.join(__C.ROOT_DIR, 'data'))

# Name of (or path to) the MATLAB executable
__C.MATLAB = 'matlab'

# Place outputs under an experiments directory
__C.EXP_DIR = 'default'

Expand Down
1 change: 0 additions & 1 deletion tools/train_faster_rcnn_alt_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from fast_rcnn.config import cfg, cfg_from_file, cfg_from_list, get_output_dir
from datasets.factory import get_imdb
from rpn.generate import imdb_proposals
import datasets.imdb
import argparse
import pprint
import numpy as np
Expand Down
2 changes: 1 addition & 1 deletion tools/train_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def get_roidb(imdb_name):
if len(roidbs) > 1:
for r in roidbs[1:]:
roidb.extend(r)
imdb = datasets.imdb(imdb_names)
imdb = datasets.imdb.imdb(imdb_names)
else:
imdb = get_imdb(imdb_names)
return imdb, roidb
Expand Down

0 comments on commit 4680720

Please sign in to comment.