From 28ee0935b52daf3ed7bf94a89599e4a7db0616ce Mon Sep 17 00:00:00 2001 From: Chung Min Kim Date: Thu, 20 Apr 2023 22:56:33 +0000 Subject: [PATCH 1/3] Add initial dinov2 compatibility --- lerf/data/lerf_datamanager.py | 3 ++- lerf/data/utils/dino_dataloader.py | 10 ++++++++++ lerf/data/utils/dino_extractor.py | 11 ++++++++++- 3 files changed, 22 insertions(+), 2 deletions(-) diff --git a/lerf/data/lerf_datamanager.py b/lerf/data/lerf_datamanager.py index cb57462..2c56b0d 100644 --- a/lerf/data/lerf_datamanager.py +++ b/lerf/data/lerf_datamanager.py @@ -81,7 +81,8 @@ def __init__( cache_dir = f"outputs/{self.config.dataparser.data.name}" clip_cache_path = Path(osp.join(cache_dir, f"clip_{self.image_encoder.name}")) - dino_cache_path = Path(osp.join(cache_dir, "dino.npy")) + dino_name = DinoDataloader.dino_model_type + dino_cache_path = Path(osp.join(cache_dir, f"dino_{dino_name}.npy")) # NOTE: cache config is sensitive to list vs. tuple, because it checks for dict equality self.dino_dataloader = DinoDataloader( image_list=images, diff --git a/lerf/data/utils/dino_dataloader.py b/lerf/data/utils/dino_dataloader.py index b5a6cfa..2242377 100644 --- a/lerf/data/utils/dino_dataloader.py +++ b/lerf/data/utils/dino_dataloader.py @@ -4,11 +4,15 @@ from lerf.data.utils.dino_extractor import ViTExtractor from lerf.data.utils.feature_dataloader import FeatureDataloader from tqdm import tqdm +import torchvision.transforms as transforms class DinoDataloader(FeatureDataloader): dino_model_type = "dino_vits8" dino_stride = 8 + # # For dinov2, use this: + # dino_model_type = "dinov2_vitb14" + # dino_stride = 14 dino_load_size = 500 dino_layer = 11 dino_facet = "key" @@ -30,6 +34,12 @@ def create(self, image_list): dino_embeds = [] for image in tqdm(preproc_image_lst, desc="dino", total=len(image_list), leave=False): + # image needs to be resized s.t. 
H, W are divisible by dino_stride + if "dinov2" in self.dino_model_type: + image = transforms.Resize(( + (image.shape[1]//self.dino_stride)*self.dino_stride, + (image.shape[2]//self.dino_stride)*self.dino_stride, + ))(image) with torch.no_grad(): descriptors = extractor.extract_descriptors( image.unsqueeze(0), diff --git a/lerf/data/utils/dino_extractor.py b/lerf/data/utils/dino_extractor.py index 8b5a9cc..33c3a47 100644 --- a/lerf/data/utils/dino_extractor.py +++ b/lerf/data/utils/dino_extractor.py @@ -46,6 +46,10 @@ def __init__(self, model_type: str = 'dino_vits8', stride: int = 4, model: nn.Mo self.model.eval() self.model.to(self.device) self.p = self.model.patch_embed.patch_size + if type(self.p) is tuple: + assert len(self.p) == 2 and self.p[0] == self.p[1], "only square patch sizes are supported" + self.p = self.p[0] + self.stride = self.model.patch_embed.proj.stride self.mean = (0.485, 0.456, 0.406) if "dino" in self.model_type else (0.5, 0.5, 0.5) @@ -64,7 +68,9 @@ def create_model(model_type: str) -> nn.Module: vit_base_patch16_224] :return: the model """ - if 'dino' in model_type: + if 'dinov2' in model_type: + model = torch.hub.load('facebookresearch/dinov2:main', model_type) + elif 'dino' in model_type: model = torch.hub.load('facebookresearch/dino:main', model_type) else: # model from timm -- load weights from timm to dino model (enables working on arbitrary size images). 
temp_model = timm.create_model(model_type, pretrained=True) @@ -128,6 +134,9 @@ def patch_vit_resolution(model: nn.Module, stride: int) -> nn.Module: patch_size = model.patch_embed.patch_size if stride == patch_size: # nothing to do return model + if type(patch_size) == tuple and len(patch_size) == 2: + assert patch_size[0] == patch_size[1], f'patch_size {patch_size} should be square' + patch_size = patch_size[0] stride = nn_utils._pair(stride) assert all([(patch_size // s_) * s_ == patch_size for s_ in From 89935da213399b77e261dff50643f30749e9d8da Mon Sep 17 00:00:00 2001 From: Chung Min Kim Date: Thu, 20 Apr 2023 22:57:02 +0000 Subject: [PATCH 2/3] Add distillation-preprocessing (PCA, N3F-style) --- lerf/data/utils/dino_dataloader.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/lerf/data/utils/dino_dataloader.py b/lerf/data/utils/dino_dataloader.py index 2242377..4a61dca 100644 --- a/lerf/data/utils/dino_dataloader.py +++ b/lerf/data/utils/dino_dataloader.py @@ -27,6 +27,11 @@ def __init__( ): assert "image_shape" in cfg super().__init__(cfg, device, image_list, cache_path) + # Do distillation-preprocessing as noted in N3F: + # The features are then L2-normalized and reduced with PCA to 64 dimensions before distillation. 
+ data_shape = self.data.shape + self.data = self.data / self.data.norm(dim=-1, keepdim=True) + self.data = torch.pca_lowrank(self.data.reshape(-1, data_shape[-1]), q=64)[0].reshape((*data_shape[:-1], 64)) def create(self, image_list): extractor = ViTExtractor(self.dino_model_type, self.dino_stride) From dfc582aeee196afd7f95f412e4abc94f7a82660a Mon Sep 17 00:00:00 2001 From: Chung Min Kim Date: Thu, 20 Apr 2023 23:13:41 +0000 Subject: [PATCH 3/3] Add kmeans+pca visualizations for dino --- lerf/lerf.py | 15 +++++++++++++++ lerf/lerf_field.py | 4 +++- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/lerf/lerf.py b/lerf/lerf.py index ba628e8..835581a 100644 --- a/lerf/lerf.py +++ b/lerf/lerf.py @@ -51,6 +51,8 @@ def populate_modules(self): clip_n_dims=self.image_encoder.embedding_dim, ) + self.n_clusters = ViewerSlider("n_clusters", 6, 2, 20, 1) + # populate some viewer logic # TODO use the values from this code to select the scale # def scale_cb(element): @@ -159,6 +161,19 @@ def gather_fn(tens): ) outputs["raw_relevancy"] = max_across # N x B x 1 outputs["best_scales"] = best_scales.to(self.device) # N + + from fast_pytorch_kmeans import KMeans + n_clusters = self.n_clusters.value + dino_feats = self.renderer_mean(embeds=lerf_field_outputs[LERFFieldHeadNames.DINO], weights=lerf_weights.detach()) + kmeans = KMeans(n_clusters=n_clusters, verbose=0) + dino_labels = kmeans.fit_predict(dino_feats) + dino_labels = dino_labels / dino_labels.max() + outputs["dino_kmeans"] = dino_labels[...,None] + + dino_feats_pca = torch.pca_lowrank(dino_feats, q=3)[0] + dino_feats_pca = dino_feats_pca - dino_feats_pca.min(dim=0, keepdim=True).values + dino_feats_pca = dino_feats_pca / dino_feats_pca.max(dim=0, keepdim=True).values + outputs["dino_pca"] = dino_feats_pca return outputs diff --git a/lerf/lerf_field.py b/lerf/lerf_field.py index 68178bc..e0b11df 100644 --- a/lerf/lerf_field.py +++ b/lerf/lerf_field.py @@ -58,7 +58,9 @@ def __init__( self.dino_net = tcnn.Network( 
n_input_dims=tot_out_dims, - n_output_dims=384, + n_output_dims=64, + # n_output_dims=384, + # n_output_dims=768, network_config={ "otype": "CutlassMLP", "activation": "ReLU",