
Commit 9a51e4e

Add FlexiViT models and weights, refactoring, push more weights
* push all vision_transformer*.py weights to HF hub
* finalize more pretrained tags for pushed weights
* refactor pos_embed files and module locations, move some pos embed modules to layers
* tweak hf hub helpers to aid bulk uploading and updating
1 parent 656e177 commit 9a51e4e

18 files changed: +1,190 −714 lines
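
For context on using what this commit publishes: a minimal sketch of loading one of the new FlexiViT weights through the standard timm factory. The model/tag name below is an assumption based on the commit message; verify with timm.list_models('flexivit*', pretrained=True) on your install.

import timm
import torch

# Model/tag name is an assumption; check timm.list_models('flexivit*', pretrained=True)
model = timm.create_model('flexivit_small.1200ep_in1k', pretrained=True)
model.eval()

x = torch.randn(1, 3, 240, 240)  # FlexiViT weights are trained at 240x240
with torch.no_grad():
    logits = model(x)
print(logits.shape)  # torch.Size([1, 1000])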

.gitignore  (+10)

@@ -106,6 +106,16 @@ output/
 *.tar
 *.pth
 *.pt
+*.torch
 *.gz
 Untitled.ipynb
 Testing notebook.ipynb
+
+# Root dir exclusions
+/*.csv
+/*.yaml
+/*.json
+/*.jpg
+/*.png
+/*.zip
+/*.tar.*

timm/layers/__init__.py  (+6 −1)

@@ -1,6 +1,7 @@
 from .activations import *
 from .adaptive_avgmax_pool import \
     adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
+from .attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding
 from .blur_pool import BlurPool2d
 from .classifier import ClassifierHead, create_classifier
 from .cond_conv2d import CondConv2d, get_condconv_initializer
@@ -30,8 +31,12 @@
 from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d
 from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm
 from .padding import get_padding, get_same_padding, pad_same
-from .patch_embed import PatchEmbed
+from .patch_embed import PatchEmbed, resample_patch_embed
 from .pool2d_same import AvgPool2dSame, create_pool2d
+from .pos_embed import resample_abs_pos_embed
+from .pos_embed_rel import RelPosMlp, RelPosBias, RelPosBiasTf, gen_relative_position_index, gen_relative_log_coords
+from .pos_embed_sincos import build_sincos2d_pos_embed, build_fourier_pos_embed, build_rotary_pos_embed, \
+    FourierEmbed, RotaryEmbedding
 from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
 from .selective_kernel import SelectiveKernel
 from .separable_conv import SeparableConv2d, SeparableConvNormAct
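
With these re-exports, the resampling helpers become importable straight from timm.layers. A short sketch of both: the resample_patch_embed call matches the signature added in patch_embed.py below, while the resample_abs_pos_embed arguments (new_size plus a default single class-token prefix) are an assumption about the new pos_embed module, which is not shown in this diff.

import torch
from timm.layers import resample_patch_embed, resample_abs_pos_embed

# Retarget a ViT-B/16 patch projection kernel to 8x8 patches
# (signature shown in patch_embed.py below).
proj_weight = torch.randn(768, 3, 16, 16)
proj_8 = resample_patch_embed(proj_weight, [8, 8])
print(proj_8.shape)  # torch.Size([768, 3, 8, 8])

# Resize an absolute position embedding from a 14x14 to a 30x30 token grid.
# Assumed interface: one class-token prefix is preserved by default.
pos_embed = torch.randn(1, 1 + 14 * 14, 768)
pos_30 = resample_abs_pos_embed(pos_embed, new_size=[30, 30])
print(pos_30.shape)  # torch.Size([1, 901, 768])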

timm/layers/attention_pool2d.py  (+1 −1)

@@ -13,7 +13,7 @@
 import torch.nn as nn

 from .helpers import to_2tuple
-from .pos_embed import apply_rot_embed, RotaryEmbedding
+from .pos_embed_sincos import apply_rot_embed, RotaryEmbedding
 from .weight_init import trunc_normal_
timm/layers/helpers.py  (+1 −1)

@@ -10,7 +10,7 @@
 def _ntuple(n):
     def parse(x):
         if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
-            return x
+            return tuple(x)
         return tuple(repeat(x, n))
     return parse
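
The helpers.py change is subtle but behavioral: _ntuple backs helpers like to_2tuple, and previously an iterable argument was returned as-is, so a list stayed a list. Now every iterable is normalized to a tuple. A standalone copy of the fixed helper for illustration:

import collections.abc
from itertools import repeat

def _ntuple(n):
    def parse(x):
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            return tuple(x)  # was `return x`; a list used to pass through unchanged
        return tuple(repeat(x, n))
    return parse

to_2tuple = _ntuple(2)
print(to_2tuple(16))      # (16, 16)
print(to_2tuple([8, 8]))  # (8, 8) -- previously the list [8, 8] came back as-is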

timm/layers/patch_embed.py  (+129 −1)

@@ -2,15 +2,24 @@

 A convolution based approach to patchifying a 2D image w/ embedding projection.

-Based on the impl in https://github.com/google-research/vision_transformer
+Based on code in:
+  * https://github.com/google-research/vision_transformer
+  * https://github.com/google-research/big_vision/tree/main/big_vision

 Hacked together by / Copyright 2020 Ross Wightman
 """
+import logging
+from typing import List
+
+import torch
 from torch import nn as nn
+import torch.nn.functional as F

 from .helpers import to_2tuple
 from .trace_utils import _assert

+_logger = logging.getLogger(__name__)
+

 class PatchEmbed(nn.Module):
     """ 2D Image to Patch Embedding
@@ -46,3 +55,122 @@ def forward(self, x):
         x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
         x = self.norm(x)
         return x
+
+
+def resample_patch_embed(
+        patch_embed,
+        new_size: List[int],
+        interpolation: str = 'bicubic',
+        antialias: bool = True,
+        verbose: bool = False,
+):
+    """Resample the weights of the patch embedding kernel to target resolution.
+    We resample the patch embedding kernel by approximately inverting the effect
+    of patch resizing.
+
+    Code based on:
+      https://github.com/google-research/big_vision/blob/b00544b81f8694488d5f36295aeb7972f3755ffe/big_vision/models/proj/flexi/vit.py
+
+    With this resizing, we can for example load a B/8 filter into a B/16 model
+    and, on 2x larger input image, the result will match.
+
+    Args:
+        patch_embed: original parameter to be resized.
+        new_size (tuple(int, int)): target shape (height, width)-only.
+        interpolation (str): interpolation for resize
+        antialias (bool): use anti-aliasing filter in resize
+        verbose (bool): log operation
+    Returns:
+        Resized patch embedding kernel.
+    """
+    import numpy as np
+
+    assert len(patch_embed.shape) == 4, "Four dimensions expected"
+    assert len(new_size) == 2, "New shape should only be hw"
+    old_size = patch_embed.shape[-2:]
+    if tuple(old_size) == tuple(new_size):
+        return patch_embed
+
+    if verbose:
+        _logger.info(f"Resize patch embedding {patch_embed.shape} to {new_size}, w/ {interpolation} interpolation.")
+
+    def resize(x_np, _new_size):
+        x_tf = torch.Tensor(x_np)[None, None, ...]
+        x_upsampled = F.interpolate(
+            x_tf, size=_new_size, mode=interpolation, antialias=antialias)[0, 0, ...].numpy()
+        return x_upsampled
+
+    def get_resize_mat(_old_size, _new_size):
+        mat = []
+        for i in range(np.prod(_old_size)):
+            basis_vec = np.zeros(_old_size)
+            basis_vec[np.unravel_index(i, _old_size)] = 1.
+            mat.append(resize(basis_vec, _new_size).reshape(-1))
+        return np.stack(mat).T
+
+    resize_mat = get_resize_mat(old_size, new_size)
+    resize_mat_pinv = torch.Tensor(np.linalg.pinv(resize_mat.T))
+
+    def resample_kernel(kernel):
+        resampled_kernel = resize_mat_pinv @ kernel.reshape(-1)
+        return resampled_kernel.reshape(new_size)
+
+    v_resample_kernel = torch.vmap(torch.vmap(resample_kernel, 0, 0), 1, 1)
+    return v_resample_kernel(patch_embed)
+
+
+# def divs(n, m=None):
+#     m = m or n // 2
+#     if m == 1:
+#         return [1]
+#     if n % m == 0:
+#         return [m] + divs(n, m - 1)
+#     return divs(n, m - 1)
+#
+#
+# class FlexiPatchEmbed(nn.Module):
+#     """ 2D Image to Patch Embedding w/ Flexible Patch sizes (FlexiViT)
+#     FIXME WIP
+#     """
+#     def __init__(
+#             self,
+#             img_size=240,
+#             patch_size=16,
+#             in_chans=3,
+#             embed_dim=768,
+#             base_img_size=240,
+#             base_patch_size=32,
+#             norm_layer=None,
+#             flatten=True,
+#             bias=True,
+#     ):
+#         super().__init__()
+#         self.img_size = to_2tuple(img_size)
+#         self.patch_size = to_2tuple(patch_size)
+#         self.num_patches = 0
+#
+#         # full range for 240 = (5, 6, 8, 10, 12, 14, 15, 16, 20, 24, 30, 40, 48)
+#         self.seqhw = (6, 8, 10, 12, 14, 15, 16, 20, 24, 30)
+#
+#         self.base_img_size = to_2tuple(base_img_size)
+#         self.base_patch_size = to_2tuple(base_patch_size)
+#         self.base_grid_size = tuple([i // p for i, p in zip(self.base_img_size, self.base_patch_size)])
+#         self.base_num_patches = self.base_grid_size[0] * self.base_grid_size[1]
+#
+#         self.flatten = flatten
+#         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=self.patch_size, stride=self.patch_size, bias=bias)
+#         self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+#
+#     def forward(self, x):
+#         B, C, H, W = x.shape
+#
+#         if self.patch_size == self.base_patch_size:
+#             weight = self.proj.weight
+#         else:
+#             weight = resample_patch_embed(self.proj.weight, self.patch_size)
+#         patch_size = self.patch_size
+#         x = F.conv2d(x, weight, bias=self.proj.bias, stride=patch_size)
+#         if self.flatten:
+#             x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
+#         x = self.norm(x)
+#         return x
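
A quick way to see what the pseudo-inverse construction buys: patchifying is linear in the kernel, so PI-resize picks the kernel whose response on a resized patch best matches the original kernel's response on the original patch. A minimal round-trip sketch under assumed shapes (the near-exact match holds for upsampling; torch.vmap inside resample_patch_embed needs a sufficiently recent torch):

import torch
import torch.nn.functional as F
from timm.layers import resample_patch_embed

torch.manual_seed(0)
w8 = torch.randn(768, 3, 8, 8)             # a B/8-style patch projection kernel
w16 = resample_patch_embed(w8, [16, 16])   # PI-resized for 16x16 patches

patch8 = torch.randn(1, 3, 8, 8)           # one 8x8 input patch
patch16 = F.interpolate(patch8, size=(16, 16), mode='bicubic', antialias=True)

tok8 = F.conv2d(patch8, w8)                # token from original kernel + patch
tok16 = F.conv2d(patch16, w16)             # token from resized kernel + patch
print(torch.allclose(tok8, tok16, atol=1e-3))  # True, up to numerical tolerance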
