From ccda3044a71e87eb13767279e20ace446b770ebe Mon Sep 17 00:00:00 2001
From: Paul-Edouard Sarlin
Date: Tue, 13 Sep 2022 17:54:08 +0200
Subject: [PATCH] Cleanup

---
 hloc/extract_features.py | 11 ++++++
 hloc/extractors/disk.py  | 76 +++++++++++++++++-----------------------
 2 files changed, 43 insertions(+), 44 deletions(-)

diff --git a/hloc/extract_features.py b/hloc/extract_features.py
index 7720f936..6ac014c0 100644
--- a/hloc/extract_features.py
+++ b/hloc/extract_features.py
@@ -108,6 +108,17 @@
             'resize_max': 1600,
         },
     },
+    'disk': {
+        'output': 'feats-disk',
+        'model': {
+            'name': 'disk',
+            'max_keypoints': 5000,
+        },
+        'preprocessing': {
+            'grayscale': False,
+            'resize_max': 1600,
+        },
+    },
     # Global descriptors
     'dir': {
         'output': 'global-feats-dir',
diff --git a/hloc/extractors/disk.py b/hloc/extractors/disk.py
index 4545eb0e..d668d30c 100644
--- a/hloc/extractors/disk.py
+++ b/hloc/extractors/disk.py
@@ -8,7 +8,7 @@
 disk_path = Path(__file__).parent / "../../third_party/disk"
 sys.path.append(str(disk_path))
 
-from disk import DISK as _DISK
+from disk import DISK as _DISK  # noqa E402
 
 
 class DISK(BaseModel):
@@ -22,17 +22,18 @@ class DISK(BaseModel):
     required_inputs = ['image']
 
     def _init(self, conf):
-        state_dict = torch.load(disk_path/conf['model_name'],
-                                map_location='cpu')
+        self.model = _DISK(window=8, desc_dim=conf['desc_dim'])
+        state_dict = torch.load(
+            disk_path / conf['model_name'], map_location='cpu')
         if 'extractor' in state_dict:
            weights = state_dict['extractor']
         elif 'disk' in state_dict:
            weights = state_dict['disk']
         else:
             raise KeyError('Incompatible weight file!')
 
-        self.model = _DISK(window=8, desc_dim=conf['desc_dim'])
         self.model.load_state_dict(weights)
+
         if conf['mode'] == 'nms':
             self.extract = partial(
                 self.model.features,
@@ -44,50 +45,37 @@ def _init(self, conf):
         elif conf['mode'] == 'rng':
             self.extract = partial(self.model.features, kind='rng')
         else:
-            raise KeyError('mode must be either nms or rng!')
+            raise KeyError(
+                f'mode must be `nms` or `rng`, got `{conf["mode"]}`')
 
     def _forward(self, data):
-        img = data['image'][0]
-        assert len(img.shape) == 3 and img.shape[0] == 3
-        # pad img so that its height and width be the multiple of 16
-        # as required by the original dis repo
-        orig_h, orig_w = img.shape[1:]
-        new_h = ((orig_h-1)//16 + 1) * 16
-        new_w = ((orig_w-1)//16 + 1) * 16
-        y_pad = new_h - orig_h
-        x_pad = new_w - orig_w
+        image = data['image']
+        # make sure that the dimensions of the image are multiple of 16
+        orig_h, orig_w = image.shape[-2:]
+        new_h = round(orig_h / 16) * 16
+        new_w = round(orig_w / 16) * 16
+        image = F.pad(image, (0, new_w - orig_w, 0, new_h - orig_h))
 
-        img = F.pad(img, (0, x_pad, 0, y_pad))
-        assert img.shape[1] == new_h and img.shape[2] == new_w, "Wrong Padding"
-
-        batched_features = self.extract(img[None])  # add batch dimension
+        batched_features = self.extract(image)
         assert(len(batched_features) == 1)
         features = batched_features[0]
 
-        for features in batched_features.flat:
-            features = features.to(torch.device('cpu'))
-
-            kps_crop_space = features.kp.t()
-
-            kps_img_space = kps_crop_space  # (2, N)
-            x = kps_crop_space[0, :]
-            y = kps_crop_space[1, :]
-            mask = (0 <= x) & (x <= orig_w-1) & (0 <= y) & (y <= orig_h-1)
-
-            keypoints = kps_img_space.t()[mask]
-            descriptors = features.desc[mask]
-            scores = features.kp_logp[mask]
-
-            order = torch.argsort(-scores)
-
-            keypoints = keypoints[order]
-            descriptors = descriptors[order]
-            scores = scores[order]
-
-        assert descriptors.shape[1] == self.conf['desc_dim']
-        assert keypoints.shape[1] == 2
-        pred = {'keypoints': keypoints[None],
-                'descriptors': descriptors.t()[None],
-                'scores': scores[None]}
-        return pred
+        # filter points detected in the padded areas
+        kpts = features.kp
+        valid = torch.all(kpts <= kpts.new_tensor([orig_w, orig_h]) - 1, 1)
+        kpts = kpts[valid]
+        descriptors = features.desc[valid]
+        scores = features.kp_logp[valid]
+
+        # order the keypoints
+        indices = torch.argsort(scores, descending=True)
+        kpts = kpts[indices]
+        descriptors = descriptors[indices]
+        scores = scores[indices]
+
+        return {
+            'keypoints': kpts[None],
+            'descriptors': descriptors.t()[None],
+            'scores': scores[None],
+        }
 