Decouple DinoV2 for semantic segmentation #4136

Merged: 7 commits, Dec 4, 2024
167 changes: 158 additions & 9 deletions src/otx/algo/classification/backbones/vision_transformer.py
@@ -5,6 +5,7 @@
"""Copy from mmpretrain/models/backbones/vision_transformer.py."""
from __future__ import annotations

import math
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Literal

@@ -46,6 +47,7 @@
"vit-huge",
"dinov2-s",
"dinov2-small",
"dinov2-small-seg",
"dinov2-b",
"dinov2-base",
"dinov2-l",
@@ -87,6 +89,7 @@
norm_layer: Normalization layer.
act_layer: MLP activation layer.
block_fn: Transformer block layer.
interpolate_offset: Workaround offset applied when interpolating positional embeddings.
lora: Enable LoRA training.
"""

@@ -147,6 +150,17 @@
"num_heads": 6,
"reg_tokens": 4,
"no_embed_class": True,
},
),
**dict.fromkeys(
["dinov2-small-seg"], # segmentation
{
"patch_size": 14,
"embed_dim": 384,
"depth": 12,
"num_heads": 6,
"reg_tokens": 0,
"no_embed_class": False,
"init_values": 1e-5,
},
),
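For orientation, the new alias can be instantiated like any other entry in the arch zoo. A minimal sketch, assuming the class in this file is exported as VisionTransformer and using 518 px inputs (37 x 14, the canonical DINOv2 resolution):

from otx.algo.classification.backbones.vision_transformer import VisionTransformer

# "dinov2-small-seg" resolves to patch size 14, embed dim 384, depth 12,
# no register tokens, and layer scale enabled via init_values=1e-5.
backbone = VisionTransformer(arch="dinov2-small-seg", img_size=518)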
@@ -193,9 +207,9 @@

def __init__( # noqa: PLR0913
self,
arch: VIT_ARCH_TYPE = "vit-base",
arch: VIT_ARCH_TYPE | str = "vit-base",
img_size: int | tuple[int, int] = 224,
patch_size: int | tuple[int, int] | None = None,
patch_size: int | None = None,
in_chans: int = 3,
num_classes: int = 1000,
embed_dim: int | None = None,
@@ -221,6 +235,7 @@
mlp_layer: nn.Module | None = None,
act_layer: LayerType | None = None,
norm_layer: LayerType | None = None,
interpolate_offset: float = 0.1,
lora: bool = False,
) -> None:
super().__init__()
@@ -231,7 +246,7 @@
arch_settings: dict[str, Any] = self.arch_zoo[arch]

self.img_size: int | tuple[int, int] = img_size
self.patch_size: int | tuple[int, int] = patch_size or arch_settings.get("patch_size", 16)
self.patch_size: int = patch_size or arch_settings.get("patch_size", 16)
self.embed_dim = embed_dim or arch_settings.get("embed_dim", 768)
depth = depth or arch_settings.get("depth", 12)
num_heads = num_heads or arch_settings.get("num_heads", 12)
@@ -251,6 +266,7 @@
self.no_embed_class = no_embed_class # don't embed prefix positions (includes reg)
self.dynamic_img_size = dynamic_img_size
self.grad_checkpointing = False
self.interpolate_offset = interpolate_offset

embed_args = {}
if dynamic_img_size:
@@ -353,15 +369,17 @@
# convert dinov2 pretrained weights
state_dict = torch.load(checkpoint_path)
state_dict.pop("mask_token", None)
state_dict["reg_token"] = state_dict.pop("register_tokens")
if "reg_token" in state_dict:
state_dict["reg_token"] = state_dict.pop("register_tokens")

state_dict["cls_token"] = state_dict.pop("cls_token") + state_dict["pos_embed"][:, 0]

img_size = (self.img_size, self.img_size) if isinstance(self.img_size, int) else self.img_size
patch_size = (self.patch_size, self.patch_size)
if state_dict["pos_embed"].shape != self.pos_embed.shape:
    state_dict["pos_embed"] = resize_positional_embeddings(
        state_dict.pop("pos_embed")[:, 1:],
        (img_size[0] // patch_size[0], img_size[1] // patch_size[1]),
    )
self.load_state_dict(state_dict, strict=False)
else:
msg = f"Unsupported `checkpoint_extension` {checkpoint_ext}, please choose from 'npz' or 'pth'."
@@ -401,6 +419,137 @@

return self.pos_drop(x)

def interpolate_pos_encoding(self, x: torch.Tensor, w: int, h: int) -> torch.Tensor:
"""Interpolates the positional encoding to match the input dimensions.

Args:
x (torch.Tensor): Input tensor.
w (int): Width of the input image.
h (int): Height of the input image.

Returns:
torch.Tensor: Tensor with interpolated positional encoding.
"""
previous_dtype = x.dtype
npatch = x.shape[1] - 1
n = self.pos_embed.shape[1] - 1
if npatch == n and w == h:
return self.pos_embed
pos_embed = self.pos_embed.float()
class_pos_embed = pos_embed[:, 0]
patch_pos_embed = pos_embed[:, 1:]
dim = x.shape[-1]
w0 = w // self.patch_size
h0 = h // self.patch_size
m = int(math.sqrt(n)) # Recover the number of patches in each dimension
if m * m != n:
msg = f"Expected m * m to equal n, but got m={m}, n={n}"
raise ValueError(msg)
kwargs = {}
if self.interpolate_offset:
# fix float error by introducing small offset
sx = float(w0 + self.interpolate_offset) / m
sy = float(h0 + self.interpolate_offset) / m
kwargs["scale_factor"] = (sx, sy)
else:
# Simply specify an output size instead of a scale factor
kwargs["size"] = (w0, h0)
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed.reshape(1, m, m, dim).permute(0, 3, 1, 2),
mode="bicubic",
**kwargs,
)
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
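A numeric check of the offset workaround, with illustrative values: given a pretrained 37 x 37 grid (n = 1369) and a 644 px input at patch size 14, the target grid is 46 per side.

m, w0, offset = 37, 644 // 14, 0.1  # pretrained grid 37, target grid 46
sx = (w0 + offset) / m  # 46.1 / 37, roughly 1.246
# interpolate(..., scale_factor=(sx, sx)) on a 37 x 37 map floors 37 * sx = 46.1
# down to a 46 x 46 output, matching size=(46, 46) while dodging exact-float rounding.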

def prepare_tokens_with_masks(self, x: torch.Tensor, masks: torch.Tensor | None = None) -> torch.Tensor:
"""Prepare tokens with optional masks.

Args:
x (torch.Tensor): Input tensor.
masks (torch.Tensor | None): Optional masks tensor.

Returns:
torch.Tensor: Tensor with prepared tokens.
"""
_, _, w, h = x.shape
x = self.patch_embed(x)
if masks is not None:
x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)

x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
x = x + self.interpolate_pos_encoding(x, w, h)

if self.reg_token is not None:
x = torch.cat(
(
x[:, :1],
self.reg_token.expand(x.shape[0], -1, -1),
x[:, 1:],
),
dim=1,
)

return x
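The resulting layout is [CLS, register tokens if any, patch tokens], with the positional encoding added before the register tokens are spliced in. A shape sketch, assuming the dinov2-small-seg instance from the earlier example (no register tokens):

tokens = backbone.prepare_tokens_with_masks(torch.randn(1, 3, 518, 518))
# tokens.shape == (1, 1 + 37 * 37, 384): index 0 is CLS, the rest are patch tokens.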

def _get_intermediate_layers_not_chunked(self, x: torch.Tensor, n: int = 1) -> list[torch.Tensor]:
"""Get intermediate layers without chunking.

Args:
x (torch.Tensor): Input tensor.
n (int): Number of last blocks to take. If it's a list, take the specified blocks.

Returns:
list[torch.Tensor]: List of intermediate layer outputs.
"""
x = self.prepare_tokens_with_masks(x)
# If n is an int, take the n last blocks. If it's a list, take them
output, total_block_len = [], len(self.blocks)
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
for i, blk in enumerate(self.blocks):
x = blk(x)
if i in blocks_to_take:
output.append(x)
if len(output) != len(blocks_to_take):
msg = f"only {len(output)} / {len(blocks_to_take)} blocks found"
raise RuntimeError(msg)
return output
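Despite the int annotation, n is duck-typed here: an int selects the last n blocks, while a sequence is used as explicit block indices. For a 12-block model:

depth, n = 12, 4
blocks_to_take = list(range(depth - n, depth))  # [8, 9, 10, 11], the last four blocks
# A hypothetical list such as [2, 5, 8, 11] would be taken verbatim instead.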

def get_intermediate_layers(
self,
x: torch.Tensor,
n: int = 1, # Layers or n last layers to take
reshape: bool = False,
return_class_token: bool = False,
norm: bool = True,
) -> tuple:
"""Get intermediate layers of the VisionTransformer.

Args:
x (torch.Tensor): Input tensor.
n (int): Number of last blocks to take. If it's a list, take the specified blocks.
reshape (bool): Whether to reshape the output feature maps.
return_class_token (bool): Whether to return the class token.
norm (bool): Whether to apply normalization to the outputs.

Returns:
tuple: A tuple containing the intermediate layer outputs.
"""
outputs = self._get_intermediate_layers_not_chunked(x, n)
if norm:
outputs = [self.norm(out) for out in outputs]
class_tokens = [out[:, 0] for out in outputs]
outputs = [out[:, 1 + self.num_reg_tokens :] for out in outputs]
if reshape:
b, _, w, h = x.shape
outputs = [
out.reshape(b, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
for out in outputs
]
if return_class_token:
return tuple(zip(outputs, class_tokens))
return tuple(outputs)
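This is the hook a segmentation decode head would consume. A hedged usage sketch, continuing the dinov2-small-seg example (names and sizes illustrative, not taken from the diff):

feats = backbone.get_intermediate_layers(
    torch.randn(1, 3, 518, 518),
    n=4,  # last four blocks
    reshape=True,  # (B, C, H/14, W/14) feature maps instead of token sequences
    norm=True,
)
# feats is a 4-tuple of (1, 384, 37, 37) tensors for a 518 px input.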

def forward(
self,
x: torch.Tensor,
3 changes: 1 addition & 2 deletions src/otx/algo/segmentation/backbones/__init__.py
@@ -3,8 +3,7 @@
#
"""Backbone modules for OTX segmentation model."""

from .litehrnet import LiteHRNetBackbone
from .mscan import MSCAN

__all__ = ["LiteHRNetBackbone", "MSCAN"]
98 changes: 0 additions & 98 deletions src/otx/algo/segmentation/backbones/dinov2.py

This file was deleted.
