Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make LERF compatible with DINOv2 #9

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion lerf/data/lerf_datamanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ def __init__(

cache_dir = f"outputs/{self.config.dataparser.data.name}"
clip_cache_path = Path(osp.join(cache_dir, f"clip_{self.image_encoder.name}"))
dino_cache_path = Path(osp.join(cache_dir, "dino.npy"))
dino_name = DinoDataloader.dino_model_type
dino_cache_path = Path(osp.join(cache_dir, f"dino_{dino_name}.npy"))
# NOTE: cache config is sensitive to list vs. tuple, because it checks for dict equality
self.dino_dataloader = DinoDataloader(
image_list=images,
Expand Down
15 changes: 15 additions & 0 deletions lerf/data/utils/dino_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,15 @@
from lerf.data.utils.dino_extractor import ViTExtractor
from lerf.data.utils.feature_dataloader import FeatureDataloader
from tqdm import tqdm
import torchvision.transforms as transforms


class DinoDataloader(FeatureDataloader):
dino_model_type = "dino_vits8"
dino_stride = 8
# # For dinov2, use this:
# dino_model_type = "dinov2_vitb14"
# dino_stride = 14
dino_load_size = 500
dino_layer = 11
dino_facet = "key"
Expand All @@ -23,13 +27,24 @@ def __init__(
):
assert "image_shape" in cfg
super().__init__(cfg, device, image_list, cache_path)
# Do distillation-preprocessing as noted in N3F:
# The features are then L2-normalized and reduced with PCA to 64 dimensions before distillation.
data_shape = self.data.shape
self.data = self.data / self.data.norm(dim=-1, keepdim=True)
self.data = torch.pca_lowrank(self.data.reshape(-1, data_shape[-1]), q=64)[0].reshape((*data_shape[:-1], 64))

def create(self, image_list):
extractor = ViTExtractor(self.dino_model_type, self.dino_stride)
preproc_image_lst = extractor.preprocess(image_list, self.dino_load_size)[0].to(self.device)

dino_embeds = []
for image in tqdm(preproc_image_lst, desc="dino", total=len(image_list), leave=False):
# image nees to be resized s.t. H, W are divisible by dino_stride
if "dinov2" in self.dino_model_type:
image = transforms.Resize((
(image.shape[1]//self.dino_stride)*self.dino_stride,
(image.shape[2]//self.dino_stride)*self.dino_stride,
))(image)
with torch.no_grad():
descriptors = extractor.extract_descriptors(
image.unsqueeze(0),
Expand Down
11 changes: 10 additions & 1 deletion lerf/data/utils/dino_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ def __init__(self, model_type: str = 'dino_vits8', stride: int = 4, model: nn.Mo
self.model.eval()
self.model.to(self.device)
self.p = self.model.patch_embed.patch_size
if type(self.p) is tuple:
assert len(self.p) == 2 and self.p[0] == self.p[1], "only square patch sizes are supported"
self.p = self.p[0]

self.stride = self.model.patch_embed.proj.stride

self.mean = (0.485, 0.456, 0.406) if "dino" in self.model_type else (0.5, 0.5, 0.5)
Expand All @@ -64,7 +68,9 @@ def create_model(model_type: str) -> nn.Module:
vit_base_patch16_224]
:return: the model
"""
if 'dino' in model_type:
if 'dinov2' in model_type:
model = torch.hub.load('facebookresearch/dinov2:main', model_type)
elif 'dino' in model_type:
model = torch.hub.load('facebookresearch/dino:main', model_type)
else: # model from timm -- load weights from timm to dino model (enables working on arbitrary size images).
temp_model = timm.create_model(model_type, pretrained=True)
Expand Down Expand Up @@ -128,6 +134,9 @@ def patch_vit_resolution(model: nn.Module, stride: int) -> nn.Module:
patch_size = model.patch_embed.patch_size
if stride == patch_size: # nothing to do
return model
if type(patch_size) == tuple and len(patch_size) == 2:
assert patch_size[0] == patch_size[1], f'patch_size {patch_size} should be square'
patch_size = patch_size[0]

stride = nn_utils._pair(stride)
assert all([(patch_size // s_) * s_ == patch_size for s_ in
Expand Down
15 changes: 15 additions & 0 deletions lerf/lerf.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ def populate_modules(self):
clip_n_dims=self.image_encoder.embedding_dim,
)

self.n_clusters = ViewerSlider("n_clusters", 6, 2, 20, 1)

# populate some viewer logic
# TODO use the values from this code to select the scale
# def scale_cb(element):
Expand Down Expand Up @@ -159,6 +161,19 @@ def gather_fn(tens):
)
outputs["raw_relevancy"] = max_across # N x B x 1
outputs["best_scales"] = best_scales.to(self.device) # N

from fast_pytorch_kmeans import KMeans
n_clusters = self.n_clusters.value
dino_feats = self.renderer_mean(embeds=lerf_field_outputs[LERFFieldHeadNames.DINO], weights=lerf_weights.detach())
kmeans = KMeans(n_clusters=n_clusters, verbose=0)
dino_labels = kmeans.fit_predict(dino_feats)
dino_labels = dino_labels / dino_labels.max()
outputs["dino_kmeans"] = dino_labels[...,None]

dino_feats_pca = torch.pca_lowrank(dino_feats, q=3)[0]
dino_feats_pca = dino_feats_pca - dino_feats_pca.min(dim=0, keepdim=True).values
dino_feats_pca = dino_feats_pca / dino_feats_pca.max(dim=0, keepdim=True).values
outputs["dino_pca"] = dino_feats_pca

return outputs

Expand Down
4 changes: 3 additions & 1 deletion lerf/lerf_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@ def __init__(

self.dino_net = tcnn.Network(
n_input_dims=tot_out_dims,
n_output_dims=384,
n_output_dims=64,
# n_output_dims=384,
# n_output_dims=768,
network_config={
"otype": "CutlassMLP",
"activation": "ReLU",
Expand Down