Skip to content

Commit

Permalink
Add Salad retrieval
Browse files Browse the repository at this point in the history
  • Loading branch information
sarlinpe committed Nov 3, 2024
1 parent b21ff20 commit 79325bc
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 0 deletions.
5 changes: 5 additions & 0 deletions hloc/extract_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,11 @@
"model": {"name": "eigenplaces"},
"preprocessing": {"resize_max": 1024},
},
"salad": {
"output": "global-feats-salad",
"model": {"name": "salad"},
"preprocessing": {"resize_max": 640},
},
}


Expand Down
40 changes: 40 additions & 0 deletions hloc/extractors/salad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import math

import torch
import torchvision.transforms as tvf

from ..utils.base_model import BaseModel


class Salad(BaseModel):
    """Global image descriptor extractor backed by the ``dinov2_salad`` hub model.

    The network is fetched through ``torch.hub`` from the ``sarlinpe/salad``
    repository. Input images are channel-normalized and, when needed, padded
    so that both spatial sides are multiples of ``patch_size`` before being
    fed to the network.

    Configuration keys (see ``default_conf``):
        backbone: backbone identifier forwarded to the hub entry point.
        pretrained: forwarded to the hub entry point.
        patch_size: spatial granularity the padded input must be a multiple of.
    """

    default_conf = {
        "backbone": "dinov2_vitb14",
        "pretrained": True,
        "patch_size": 14,
    }
    required_inputs = ["image"]

    def _init(self, conf):
        """Load the hub model in eval mode and build the RGB normalizer."""
        hub_model = torch.hub.load(
            "sarlinpe/salad",
            "dinov2_salad",
            backbone=conf["backbone"],
            pretrained=conf["pretrained"],
        )
        self.net = hub_model.eval()

        # Per-channel statistics; these values match the widely used
        # ImageNet mean / standard deviation.
        self.norm_rgb = tvf.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        )

    def _forward(self, data):
        """Compute the descriptor for ``data["image"]``.

        The image tensor must be 4-dimensional; the last two dimensions are
        treated as height and width. Returns a dict holding the network
        output under the key ``"global_descriptor"``.
        """
        image = self.norm_rgb(data["image"])
        _batch, _channels, height, width = image.shape

        # Number of extra rows / columns needed to reach the next multiple of
        # the patch size (zero when the side is already a multiple).
        patch = self.conf["patch_size"]
        extra_rows = -height % patch
        extra_cols = -width % patch
        if extra_rows or extra_cols:
            # Pad only on the right and bottom, after normalization.
            image = torch.nn.functional.pad(
                image, [0, extra_cols, 0, extra_rows]
            )

        return {"global_descriptor": self.net(image)}

0 comments on commit 79325bc

Please sign in to comment.