More ViTamin changes #2193

Merged · 8 commits · Jun 7, 2024
2 changes: 1 addition & 1 deletion tests/test_models.py
@@ -60,7 +60,7 @@
'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
'convit_*', 'levit*', 'visformer*', 'deit*', 'xcit_*', 'crossvit_*', 'beit*',
'poolformer_*', 'volo_*', 'sequencer2d_*', 'mvitv2*', 'gcvit*', 'efficientformer*',
'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*', 'hiera_*'
'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*', 'hiera_*', 'vitamin*'
]
NUM_NON_STD = len(NON_STD_FILTERS)

1 change: 1 addition & 0 deletions timm/models/__init__.py
@@ -71,6 +71,7 @@
from .vision_transformer_hybrid import *
from .vision_transformer_relpos import *
from .vision_transformer_sam import *
from .vitamin import *
from .volo import *
from .vovnet import *
from .xception import *
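
The wildcard import added above registers the ViTamin architectures with timm's model registry, so they become discoverable through the standard factory helpers. A minimal sketch; the exact variant names are whatever the vitamin module registers, so they are listed rather than assumed:

    import timm  # assumes a timm build that includes this PR's vitamin module

    # The `from .vitamin import *` line registers the ViTamin models,
    # making them visible to the usual registry helpers.
    vitamin_names = timm.list_models('vitamin*')
    print(vitamin_names)

    # Build the first registered variant; substitute a specific name once known.
    model = timm.create_model(vitamin_names[0], pretrained=False)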
22 changes: 18 additions & 4 deletions timm/models/vision_transformer.py
@@ -409,6 +409,7 @@ def __init__(
qk_norm: bool = False,
init_values: Optional[float] = None,
class_token: bool = True,
pos_embed: str = 'learn',
no_embed_class: bool = False,
reg_tokens: int = 0,
pre_norm: bool = False,
@@ -460,6 +461,7 @@ def __init__(
super().__init__()
assert global_pool in ('', 'avg', 'token', 'map')
assert class_token or global_pool != 'token'
assert pos_embed in ('', 'none', 'learn')
use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
act_layer = get_act_layer(act_layer) or nn.GELU
@@ -494,7 +496,10 @@ def __init__(
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
self.reg_token = nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None
embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
if not pos_embed or pos_embed == 'none':
self.pos_embed = None
else:
self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
self.pos_drop = nn.Dropout(p=pos_drop_rate)
if patch_drop_rate > 0:
self.patch_drop = PatchDropout(
@@ -556,7 +561,8 @@ def rescale(param, _layer_id):
def init_weights(self, mode: str = '') -> None:
assert mode in ('jax', 'jax_nlhb', 'moco', '')
head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
trunc_normal_(self.pos_embed, std=.02)
if self.pos_embed is not None:
trunc_normal_(self.pos_embed, std=.02)
if self.cls_token is not None:
nn.init.normal_(self.cls_token, std=1e-6)
named_apply(get_init_weights_vit(mode, head_bias), self)
@@ -583,6 +589,8 @@ def group_matcher(self, coarse: bool = False) -> Dict:
@torch.jit.ignore
def set_grad_checkpointing(self, enable: bool = True) -> None:
self.grad_checkpointing = enable
if hasattr(self.patch_embed, 'set_grad_checkpointing'):
self.patch_embed.set_grad_checkpointing(enable)

@torch.jit.ignore
def get_classifier(self) -> nn.Module:
@@ -600,6 +608,9 @@ def reset_classifier(self, num_classes: int, global_pool = None) -> None:
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
if self.pos_embed is None:
return x

if self.dynamic_img_size:
B, H, W, C = x.shape
pos_embed = resample_abs_pos_embed(
@@ -1066,10 +1077,13 @@ def checkpoint_filter_fn(
# IJEPA, vit in an 'encoder' submodule
state_dict = state_dict['encoder']
prefix = 'module.'
elif 'visual.trunk.pos_embed' in state_dict:
elif 'visual.trunk.pos_embed' in state_dict or 'visual.trunk.blocks.0.norm1.weight' in state_dict:
# OpenCLIP model with timm vision encoder
# FIXME remap final nn.Linear if it exists outside of the timm .trunk (ie in visual.head.proj)
prefix = 'visual.trunk.'
if 'visual.head.proj.weight' in state_dict and isinstance(model.head, nn.Linear):
# remap final nn.Linear if it exists outside of the timm .trunk (ie in visual.head.proj)
out_dict['head.weight'] = state_dict['visual.head.proj.weight']
out_dict['head.bias'] = torch.zeros(state_dict['visual.head.proj.weight'].shape[0])

if prefix:
# filter on & remove prefix string from keys
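
The pos_embed argument introduced in this file lets a VisionTransformer be built without a learned absolute position embedding: self.pos_embed stays None and _pos_embed() becomes a no-op. A rough sketch of exercising the new argument; the specific vit variant name is only illustrative, and kwargs passed to create_model are forwarded to the VisionTransformer constructor:

    import timm
    import torch

    # pos_embed='none' skips creation of self.pos_embed entirely, so no
    # absolute position embedding is added to the patch tokens.
    model = timm.create_model(
        'vit_small_patch16_224',   # illustrative variant name
        pretrained=False,
        pos_embed='none',
    )
    assert model.pos_embed is None
    _ = model(torch.randn(1, 3, 224, 224))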
44 changes: 35 additions & 9 deletions timm/models/vision_transformer_hybrid.py
@@ -38,14 +38,15 @@ class HybridEmbed(nn.Module):

def __init__(
self,
backbone,
img_size=224,
patch_size=1,
feature_size=None,
feature_ratio=None,
in_chans=3,
embed_dim=768,
bias=True,
backbone: nn.Module,
img_size: Union[int, Tuple[int, int]] = 224,
patch_size: Union[int, Tuple[int, int]] = 1,
feature_size: Optional[Union[int, Tuple[int, int]]] = None,
feature_ratio: Optional[Union[int, Tuple[int, int]]] = None,
in_chans: int = 3,
embed_dim: int = 768,
bias: bool = True,
proj: bool = True,
flatten: bool = True,
output_fmt: Optional[str] = None,
strict_img_size: bool = True,
@@ -95,7 +96,18 @@ def __init__(
self.strict_img_size = strict_img_size
self.dynamic_img_pad = dynamic_img_pad

self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
if proj:
self.proj = nn.Conv2d(
feature_dim,
embed_dim,
kernel_size=patch_size,
stride=patch_size,
bias=bias,
)
else:
assert feature_dim == embed_dim,\
f'The feature dim ({feature_dim}) must match embed dim ({embed_dim}) when projection disabled.'
self.proj = nn.Identity()

def feat_ratio(self, as_scalar=True) -> Union[Tuple[int, int], int]:
total_reduction = (
@@ -116,6 +128,13 @@ def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]:
else:
return feat_size[0] // self.patch_size[0], feat_size[1] // self.patch_size[1]

@torch.jit.ignore
def set_grad_checkpointing(self, enable: bool = True):
if hasattr(self.backbone, 'set_grad_checkpointing'):
self.backbone.set_grad_checkpointing(enable=enable)
elif hasattr(self.backbone, 'grad_checkpointing'):
self.backbone.grad_checkpointing = enable

def forward(self, x):
x = self.backbone(x)
if isinstance(x, (list, tuple)):
@@ -157,6 +176,13 @@ def __init__(
bias=bias,
)

@torch.jit.ignore
def set_grad_checkpointing(self, enable: bool = True):
if hasattr(self.backbone, 'set_grad_checkpointing'):
self.backbone.set_grad_checkpointing(enable=enable)
elif hasattr(self.backbone, 'grad_checkpointing'):
self.backbone.grad_checkpointing = enable

def forward(self, x) -> Tuple[torch.Tensor, List[int]]:
x = self.backbone(x)
if isinstance(x, (list, tuple)):
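
Together with the VisionTransformer.set_grad_checkpointing change earlier in this diff, the new hooks on the hybrid embed classes let one call on the model toggle activation checkpointing in the CNN backbone wrapped by the patch embed as well as in the transformer blocks (and with proj disabled, HybridEmbed simply requires the backbone's feature dim to equal embed_dim). A small sketch, assuming a hybrid variant name purely for illustration:

    import timm

    # Illustrative hybrid variant (ResNet stem + ViT); any vit hybrid model applies.
    model = timm.create_model('vit_small_r26_s32_224', pretrained=False)

    # With the new hooks, this also flips grad checkpointing on the backbone
    # inside model.patch_embed, not just on the transformer blocks.
    model.set_grad_checkpointing(True)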