Skip to content

Commit

Permalink
Faster matching with parallel reading and writing (#242)
Browse files Browse the repository at this point in the history
* Deprecate multi-reference matching
* Parallel match reading and writing
* Performance optimization in extract_features
* Remove unused map_tensor
  • Loading branch information
skydes authored Nov 2, 2022
1 parent 23b0c31 commit 7e6551d
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 68 deletions.
20 changes: 8 additions & 12 deletions hloc/extract_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

from . import extractors, logger
from .utils.base_model import dynamic_load
from .utils.tools import map_tensor
from .utils.parsers import parse_image_lists
from .utils.io import read_image, list_h5_names

Expand Down Expand Up @@ -211,7 +210,6 @@ def __getitem__(self, idx):
image = image / 255.

data = {
'name': name,
'image': image,
'original_size': np.array(size),
}
Expand All @@ -232,28 +230,26 @@ def main(conf: Dict,
logger.info('Extracting local features with configuration:'
f'\n{pprint.pformat(conf)}')

loader = ImageDataset(image_dir, conf['preprocessing'], image_list)
loader = torch.utils.data.DataLoader(loader, num_workers=1)

dataset = ImageDataset(image_dir, conf['preprocessing'], image_list)
if feature_path is None:
feature_path = Path(export_dir, conf['output']+'.h5')
feature_path.parent.mkdir(exist_ok=True, parents=True)
skip_names = set(list_h5_names(feature_path)
if feature_path.exists() and not overwrite else ())
if set(loader.dataset.names).issubset(set(skip_names)):
dataset.names = [n for n in dataset.names if n not in skip_names]
if len(dataset.names) == 0:
logger.info('Skipping the extraction.')
return feature_path

device = 'cuda' if torch.cuda.is_available() else 'cpu'
Model = dynamic_load(extractors, conf['model']['name'])
model = Model(conf['model']).eval().to(device)

for data in tqdm(loader):
name = data['name'][0] # remove batch dimension
if name in skip_names:
continue

pred = model(map_tensor(data, lambda x: x.to(device)))
loader = torch.utils.data.DataLoader(
dataset, num_workers=1, shuffle=False, pin_memory=True)
for idx, data in enumerate(tqdm(loader)):
name = dataset.names[idx]
pred = model({'image': data['image'].to(device, non_blocking=True)})
pred = {k: v[0].cpu().numpy() for k, v in pred.items()}

pred['image_size'] = original_size = data['original_size'][0].numpy()
Expand Down
120 changes: 81 additions & 39 deletions hloc/match_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,16 @@
from typing import Union, Optional, Dict, List, Tuple
from pathlib import Path
import pprint
import collections.abc as collections
from queue import Queue
from threading import Thread
from functools import partial
from tqdm import tqdm
import h5py
import torch

from . import matchers, logger
from .utils.base_model import dynamic_load
from .utils.parsers import names_to_pair, names_to_pair_old, parse_retrieval
from .utils.io import list_h5_names


'''
Expand Down Expand Up @@ -68,6 +69,71 @@
}


class WorkQueue():
    """Fan work items out to a small pool of consumer threads.

    Items submitted via ``put`` are handed to ``work_fn`` on one of
    ``num_threads`` background threads; ``join`` shuts the pool down
    and waits for all pending work to finish.
    """

    def __init__(self, work_fn, num_threads=1):
        # Bounded queue: producers block instead of buffering without limit.
        self.queue = Queue(num_threads)
        self.threads = []
        for _ in range(num_threads):
            worker = Thread(target=self.thread_fn, args=(work_fn,))
            worker.start()
            self.threads.append(worker)

    def join(self):
        """Send one stop sentinel per worker, then wait for each to exit."""
        for _ in self.threads:
            self.queue.put(None)
        for worker in self.threads:
            worker.join()

    def thread_fn(self, work_fn):
        """Worker loop: apply work_fn to items until a None sentinel arrives."""
        while True:
            item = self.queue.get()
            if item is None:
                break
            work_fn(item)

    def put(self, data):
        """Enqueue one work item (blocks while the queue is full)."""
        self.queue.put(data)


class FeaturePairsDataset(torch.utils.data.Dataset):
    """Dataset yielding, for each (query, reference) pair, the local
    features of both images loaded from their respective HDF5 files."""

    def __init__(self, pairs, feature_path_q, feature_path_r):
        self.pairs = pairs
        self.feature_path_q = feature_path_q
        self.feature_path_r = feature_path_r

    def _load_side(self, path, name, suffix, out):
        # Read every array stored for `name` and key it with the given
        # suffix ('0' for query, '1' for reference).
        with h5py.File(path, 'r') as fd:
            grp = fd[name]
            for key, dset in grp.items():
                out[key + suffix] = torch.from_numpy(dset.__array__()).float()
            # some matchers might expect an image but only use its size
            out['image' + suffix] = torch.empty(
                (1,) + tuple(grp['image_size'])[::-1])

    def __getitem__(self, idx):
        name_q, name_r = self.pairs[idx]
        sample = {}
        self._load_side(self.feature_path_q, name_q, '0', sample)
        self._load_side(self.feature_path_r, name_r, '1', sample)
        return sample

    def __len__(self):
        return len(self.pairs)


def writer_fn(inp, match_path):
    """Write one (pair, prediction) item into the HDF5 match file.

    Replaces any existing group for the pair, stores 'matches0' as int16
    and, when present, 'matching_scores0' as float16 to keep files small.
    """
    pair_key, prediction = inp
    with h5py.File(str(match_path), 'a', libver='latest') as hfile:
        if pair_key in hfile:
            del hfile[pair_key]
        group = hfile.create_group(pair_key)
        group.create_dataset(
            'matches0',
            data=prediction['matches0'][0].cpu().short().numpy())
        if 'matching_scores0' in prediction:
            group.create_dataset(
                'matching_scores0',
                data=prediction['matching_scores0'][0].cpu().half().numpy())


def main(conf: Dict,
pairs: Path, features: Union[Path, str],
export_dir: Optional[Path] = None,
Expand All @@ -91,11 +157,6 @@ def main(conf: Dict,

if features_ref is None:
features_ref = features_q
if isinstance(features_ref, collections.Iterable):
features_ref = list(features_ref)
else:
features_ref = [features_ref]

match_from_paths(conf, pairs, matches, features_q, features_ref, overwrite)

return matches
Expand Down Expand Up @@ -127,18 +188,15 @@ def match_from_paths(conf: Dict,
pairs_path: Path,
match_path: Path,
feature_path_q: Path,
feature_paths_refs: Path,
feature_path_ref: Path,
overwrite: bool = False) -> Path:
logger.info('Matching local features with configuration:'
f'\n{pprint.pformat(conf)}')

if not feature_path_q.exists():
raise FileNotFoundError(f'Query feature file {feature_path_q}.')
for path in feature_paths_refs:
if not path.exists():
raise FileNotFoundError(f'Reference feature file {path}.')
name2ref = {n: i for i, p in enumerate(feature_paths_refs)
for n in list_h5_names(p)}
if not feature_path_ref.exists():
raise FileNotFoundError(f'Reference feature file {feature_path_ref}.')
match_path.parent.mkdir(exist_ok=True, parents=True)

assert pairs_path.exists(), pairs_path
Expand All @@ -153,34 +211,18 @@ def match_from_paths(conf: Dict,
Model = dynamic_load(matchers, conf['model']['name'])
model = Model(conf['model']).eval().to(device)

for (name0, name1) in tqdm(pairs, smoothing=.1):
data = {}
with h5py.File(str(feature_path_q), 'r', libver='latest') as fd:
grp = fd[name0]
for k, v in grp.items():
data[k+'0'] = torch.from_numpy(v.__array__()).float().to(device)
# some matchers might expect an image but only use its size
data['image0'] = torch.empty((1,)+tuple(grp['image_size'])[::-1])
with h5py.File(str(feature_paths_refs[name2ref[name1]]), 'r', libver='latest') as fd:
grp = fd[name1]
for k, v in grp.items():
data[k+'1'] = torch.from_numpy(v.__array__()).float().to(device)
data['image1'] = torch.empty((1,)+tuple(grp['image_size'])[::-1])
data = {k: v[None] for k, v in data.items()}
dataset = FeaturePairsDataset(pairs, feature_path_q, feature_path_ref)
loader = torch.utils.data.DataLoader(
dataset, num_workers=5, batch_size=1, shuffle=False, pin_memory=True)
writer_queue = WorkQueue(partial(writer_fn, match_path=match_path), 5)

for idx, data in enumerate(tqdm(loader, smoothing=.1)):
data = {k: v if k.startswith('image')
else v.to(device, non_blocking=True) for k, v in data.items()}
pred = model(data)
pair = names_to_pair(name0, name1)
with h5py.File(str(match_path), 'a', libver='latest') as fd:
if pair in fd:
del fd[pair]
grp = fd.create_group(pair)
matches = pred['matches0'][0].cpu().short().numpy()
grp.create_dataset('matches0', data=matches)

if 'matching_scores0' in pred:
scores = pred['matching_scores0'][0].cpu().half().numpy()
grp.create_dataset('matching_scores0', data=scores)

pair = names_to_pair(*pairs[idx])
writer_queue.put((pair, pred))
writer_queue.join()
logger.info('Finished exporting matches.')


Expand Down
17 changes: 0 additions & 17 deletions hloc/utils/tools.py

This file was deleted.

0 comments on commit 7e6551d

Please sign in to comment.