From 79325bc3485a912bd813a1d5df5ecaa420921f78 Mon Sep 17 00:00:00 2001
From: Paul-Edouard Sarlin
Date: Sun, 3 Nov 2024 22:28:16 +0100
Subject: [PATCH] Add Salad retrieval

---
Notes (not part of the commit message; text between the "---" marker and
the diffstat is ignored by `git am`):

- hloc/extract_features.py: adds a "salad" configuration entry right after
  the entry whose model name is "eigenplaces". It writes to the output name
  "global-feats-salad", selects the model "salad", and preprocesses with
  resize_max=640.
- hloc/extractors/salad.py: new `Salad(BaseModel)` extractor.
  * `_init` loads the "dinov2_salad" entry point of the "sarlinpe/salad"
    torch.hub repository with the configured backbone and pretrained flag,
    switches it to eval mode, and builds a torchvision `Normalize` with
    mean [0.485, 0.456, 0.406] and std [0.229, 0.224, 0.225].
  * `_forward` normalizes data["image"]; when the height or width is not a
    multiple of conf["patch_size"] (default 14), it pads the bottom and
    right edges up to the next multiple before running the network.
  * The network output is returned under the key "global_descriptor".
- NOTE(review): padding is applied after normalization, so the pad value 0
  is a value in normalized space -- confirm this is intended.
- NOTE(review): the default "patch_size" of 14 presumably matches the
  "dinov2_vitb14" backbone; verify before changing either one.

 hloc/extract_features.py |  5 +++++
 hloc/extractors/salad.py | 40 ++++++++++++++++++++++++++++++++++++++++
 2 files changed, 45 insertions(+)
 create mode 100644 hloc/extractors/salad.py

diff --git a/hloc/extract_features.py b/hloc/extract_features.py
index 53130ab5..83778ab9 100644
--- a/hloc/extract_features.py
+++ b/hloc/extract_features.py
@@ -135,6 +135,11 @@
         "model": {"name": "eigenplaces"},
         "preprocessing": {"resize_max": 1024},
     },
+    "salad": {
+        "output": "global-feats-salad",
+        "model": {"name": "salad"},
+        "preprocessing": {"resize_max": 640},
+    },
 }
 
 
diff --git a/hloc/extractors/salad.py b/hloc/extractors/salad.py
new file mode 100644
index 00000000..7a5ef823
--- /dev/null
+++ b/hloc/extractors/salad.py
@@ -0,0 +1,40 @@
+import math
+
+import torch
+import torchvision.transforms as tvf
+
+from ..utils.base_model import BaseModel
+
+
+class Salad(BaseModel):
+    default_conf = {
+        "backbone": "dinov2_vitb14",
+        "pretrained": True,
+        "patch_size": 14,
+    }
+    required_inputs = ["image"]
+
+    def _init(self, conf):
+        self.net = torch.hub.load(
+            "sarlinpe/salad",
+            "dinov2_salad",
+            backbone=conf["backbone"],
+            pretrained=conf["pretrained"],
+        ).eval()
+
+        mean = [0.485, 0.456, 0.406]
+        std = [0.229, 0.224, 0.225]
+        self.norm_rgb = tvf.Normalize(mean=mean, std=std)
+
+    def _forward(self, data):
+        image = self.norm_rgb(data["image"])
+        _, _, h, w = image.shape
+        patch_size = self.conf["patch_size"]
+        if h % patch_size or w % patch_size:
+            h_inp = math.ceil(h / patch_size) * patch_size
+            w_inp = math.ceil(w / patch_size) * patch_size
+            image = torch.nn.functional.pad(image, [0, w_inp - w, 0, h_inp - h])
+        desc = self.net(image)
+        return {
+            "global_descriptor": desc,
+        }