Adding support for fine-tune CLIP LAION-2B image tower weights for B/32, L/14, H/14 and g/14. Still WIP
rwightman committed Sep 16, 2022
1 parent a520da9 commit 9709dba
Showing 6 changed files with 214 additions and 28 deletions.
2 changes: 1 addition & 1 deletion tests/test_models.py
@@ -38,7 +38,7 @@
EXCLUDE_FILTERS = [
'*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*50x3_bitm',
'*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', '*efficientnetv2_xl*',
'*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', 'vit_huge*', 'vit_gi*', 'swin*huge*',
'*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', 'vit_huge*', 'vit_g*', 'swin*huge*',
'swin*giant*']
NON_STD_EXCLUDE_FILTERS = ['vit_huge*', 'vit_gi*', 'swin*giant*']
else:
8 changes: 7 additions & 1 deletion timm/models/helpers.py
@@ -138,6 +138,9 @@ def _resolve_pretrained_source(pretrained_cfg):
# hf-hub available as alternate weight source in default_cfg
load_from = 'hf-hub'
pretrained_loc = hf_hub_id
if load_from == 'hf-hub' and 'hf_hub_filename' in pretrained_cfg:
# if a filename override is set, return tuple for location w/ (hub_id, filename)
pretrained_loc = pretrained_loc, pretrained_cfg['hf_hub_filename']
return load_from, pretrained_loc


@@ -246,7 +249,10 @@ def load_pretrained(
pretrained_loc, map_location='cpu', progress=_DOWNLOAD_PROGRESS, check_hash=_CHECK_HASH)
elif load_from == 'hf-hub':
_logger.info(f'Loading pretrained weights from Hugging Face hub ({pretrained_loc})')
state_dict = load_state_dict_from_hf(pretrained_loc)
if isinstance(pretrained_loc, (list, tuple)):
state_dict = load_state_dict_from_hf(*pretrained_loc)
else:
state_dict = load_state_dict_from_hf(pretrained_loc)
else:
_logger.warning("No pretrained weights exist or were found for this model. Using random initialization.")
return
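The two hunks above work together: when a pretrained_cfg carries both hf_hub_id and hf_hub_filename, the resolved location becomes a (hub_id, filename) tuple that load_pretrained later unpacks into load_state_dict_from_hf. A minimal, self-contained sketch of that resolution path, mirroring the diff with a hypothetical repo and filename:

def resolve_hf_source(pretrained_cfg):
    # simplified stand-in for _resolve_pretrained_source, covering only the hf-hub branch
    load_from, pretrained_loc = '', ''
    hf_hub_id = pretrained_cfg.get('hf_hub_id')
    if hf_hub_id:
        load_from = 'hf-hub'
        pretrained_loc = hf_hub_id
    if load_from == 'hf-hub' and 'hf_hub_filename' in pretrained_cfg:
        # filename override -> tuple of (hub_id, filename)
        pretrained_loc = pretrained_loc, pretrained_cfg['hf_hub_filename']
    return load_from, pretrained_loc

cfg = {'hf_hub_id': 'some-org/some-clip-vit', 'hf_hub_filename': 'open_clip_pytorch_model.bin'}
print(resolve_hf_source(cfg))  # ('hf-hub', ('some-org/some-clip-vit', 'open_clip_pytorch_model.bin'))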
8 changes: 4 additions & 4 deletions timm/models/hub.py
@@ -55,7 +55,7 @@ def download_cached_file(url, check_hash=True, progress=False):

def has_hf_hub(necessary=False):
if not _has_hf_hub and necessary:
# if no HF Hub module installed and it is necessary to continue, raise error
# if no HF Hub module installed, and it is necessary to continue, raise error
raise RuntimeError(
'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
return _has_hf_hub
@@ -78,7 +78,7 @@ def load_cfg_from_json(json_file: Union[str, os.PathLike]):

def _download_from_hf(model_id: str, filename: str):
hf_model_id, hf_revision = hf_split(model_id)
return hf_hub_download(hf_model_id, filename, revision=hf_revision, cache_dir=get_cache_dir('hf'))
return hf_hub_download(hf_model_id, filename, revision=hf_revision)


def load_model_config_from_hf(model_id: str):
@@ -91,9 +91,9 @@ def load_model_config_from_hf(model_id: str):
return pretrained_cfg, model_name


def load_state_dict_from_hf(model_id: str):
def load_state_dict_from_hf(model_id: str, filename: str = 'pytorch_model.bin'):
assert has_hf_hub(True)
cached_file = _download_from_hf(model_id, 'pytorch_model.bin')
cached_file = _download_from_hf(model_id, filename)
state_dict = torch.load(cached_file, map_location='cpu')
return state_dict

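With the new filename parameter, callers can pull a checkpoint other than the default pytorch_model.bin from a hub repo, and downloads now land in the standard huggingface_hub cache rather than timm's own 'hf' cache dir. A usage sketch, assuming huggingface_hub is installed and using a hypothetical repo and filename:

from timm.models.hub import load_state_dict_from_hf

# default filename remains 'pytorch_model.bin'
sd = load_state_dict_from_hf('some-org/some-model')

# filename override, e.g. an open_clip-style checkpoint stored in the same repo
sd = load_state_dict_from_hf('some-org/some-model', 'open_clip_pytorch_model.bin')
print(len(sd), 'tensors loaded')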
13 changes: 11 additions & 2 deletions timm/models/layers/patch_embed.py
@@ -15,7 +15,16 @@
class PatchEmbed(nn.Module):
""" 2D Image to Patch Embedding
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
def __init__(
self,
img_size=224,
patch_size=16,
in_chans=3,
embed_dim=768,
norm_layer=None,
flatten=True,
bias=True,
):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
@@ -25,7 +34,7 @@ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten

self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

def forward(self, x):
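The new bias flag exists so the ViT changes below can drop the patch-projection bias when a pre-norm layer follows it, as in CLIP image towers. A quick sketch, assuming a timm checkout that includes this commit:

import torch
from timm.models.layers import PatchEmbed

embed = PatchEmbed(img_size=224, patch_size=32, in_chans=3, embed_dim=768, bias=False)
x = torch.randn(2, 3, 224, 224)
tokens = embed(x)
print(tokens.shape)  # torch.Size([2, 49, 768]): 7x7 patches of size 32 at 224x224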
198 changes: 180 additions & 18 deletions timm/models/vision_transformer.py
@@ -25,6 +25,7 @@
from collections import OrderedDict
from typing import Optional

import huggingface_hub.file_download
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -106,7 +107,7 @@ def _cfg(url='', **kwargs):
'vit_large_patch14_224': _cfg(url=''),
'vit_huge_patch14_224': _cfg(url=''),
'vit_giant_patch14_224': _cfg(url=''),
'vit_gigantic_patch14_224': _cfg(url=''),
'vit_gee_patch14_224': _cfg(url=''),


# patch models, imagenet21k (weights from official Google JAX impl)
@@ -177,6 +178,20 @@ def _cfg(url='', **kwargs):
'vit_small_patch16_36x1_224': _cfg(url=''),
'vit_small_patch16_18x2_224': _cfg(url=''),
'vit_base_patch16_18x2_224': _cfg(url=''),

'vit_base_patch32_224_clip_laion2b': _cfg(
hf_hub_id='',
num_classes=512),
'vit_large_patch14_224_clip_laion2b': _cfg(
hf_hub_id='',
num_classes=768),
'vit_huge_patch14_224_clip_laion2b': _cfg(
hf_hub_id='',
num_classes=1024),
'vit_giant_patch14_224_clip_laion2b': _cfg(
hf_hub_id='',
num_classes=1024),

}


@@ -221,8 +236,18 @@ def forward(self, x):
class Block(nn.Module):

def __init__(
self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
self,
dim,
num_heads,
mlp_ratio=4.,
qkv_bias=False,
drop=0.,
attn_drop=0.,
init_values=None,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm
):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
@@ -244,8 +269,18 @@ def forward(self, x):
class ResPostBlock(nn.Module):

def __init__(
self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
self,
dim,
num_heads,
mlp_ratio=4.,
qkv_bias=False,
drop=0.,
attn_drop=0.,
init_values=None,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm
):
super().__init__()
self.init_values = init_values

@@ -274,8 +309,19 @@ def forward(self, x):
class ParallelBlock(nn.Module):

def __init__(
self, dim, num_heads, num_parallel=2, mlp_ratio=4., qkv_bias=False, init_values=None,
drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
self,
dim,
num_heads,
num_parallel=2,
mlp_ratio=4.,
qkv_bias=False,
init_values=None,
drop=0.,
attn_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm
):
super().__init__()
self.num_parallel = num_parallel
self.attns = nn.ModuleList()
@@ -320,10 +366,31 @@ class VisionTransformer(nn.Module):
"""

def __init__(
self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token',
embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=None,
class_token=True, no_embed_class=False, fc_norm=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
weight_init='', embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block):
self,
img_size=224,
patch_size=16,
in_chans=3,
num_classes=1000,
global_pool='token',
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.,
qkv_bias=True,
init_values=None,
class_token=True,
no_embed_class=False,
pre_norm=False,
fc_norm=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
weight_init='',
embed_layer=PatchEmbed,
norm_layer=None,
act_layer=None,
block_fn=Block,
):
"""
Args:
img_size (int, tuple): input image size
@@ -362,19 +429,34 @@ def __init__(
self.grad_checkpointing = False

self.patch_embed = embed_layer(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP)
)
num_patches = self.patch_embed.num_patches

self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
self.pos_drop = nn.Dropout(p=drop_rate)
self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()

dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
self.blocks = nn.Sequential(*[
block_fn(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, init_values=init_values,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
init_values=init_values,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
act_layer=act_layer
)
for i in range(depth)])
self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()

@@ -445,6 +527,7 @@ def _pos_embed(self, x):
def forward_features(self, x):
x = self.patch_embed(x)
x = self._pos_embed(x)
x = self.norm_pre(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.blocks, x)
else:
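Together with the patch-embed bias change, pre_norm=True inserts norm_pre between the position embedding and the transformer blocks, matching the CLIP/open_clip visual-tower layout. A minimal construction sketch, assuming a checkout containing this commit:

import torch
from timm.models.vision_transformer import VisionTransformer

# CLIP-style B/32 image tower: pre-norm enabled, patch projection bias disabled internally
model = VisionTransformer(
    img_size=224, patch_size=32, embed_dim=768, depth=12, num_heads=12,
    pre_norm=True, num_classes=512)
out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 512])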
@@ -623,6 +706,40 @@ def resize_pos_embed(posemb, posemb_new, num_prefix_tokens=1, gs_new=()):
return posemb


def _convert_openai_clip(state_dict, model):
out_dict = {}
swaps = [
('visual.', ''), ('conv1', 'patch_embed.proj'), ('positional_embedding', 'pos_embed'),
('transformer.resblocks.', 'blocks.'), ('ln_pre', 'norm_pre'), ('ln_post', 'norm'), ('ln_', 'norm'),
('in_proj_', 'qkv.'), ('out_proj', 'proj'), ('mlp.c_fc', 'mlp.fc1'), ('mlp.c_proj', 'mlp.fc2'),
]
for k, v in state_dict.items():
if not k.startswith('visual.'):
continue
for sp in swaps:
k = k.replace(sp[0], sp[1])

if k == 'proj':
k = 'head.weight'
v = v.transpose(0, 1)
out_dict['head.bias'] = torch.zeros(v.shape[0])
elif k == 'class_embedding':
k = 'cls_token'
v = v.unsqueeze(0).unsqueeze(1)
elif k == 'pos_embed':
v = v.unsqueeze(0)
if v.shape[1] != model.pos_embed.shape[1]:
# To resize pos embedding when using model at different size from pretrained weights
v = resize_pos_embed(
v,
model.pos_embed,
0 if getattr(model, 'no_embed_class') else getattr(model, 'num_prefix_tokens', 1),
model.patch_embed.grid_size
)
out_dict[k] = v
return out_dict


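A self-contained sketch of the key remapping performed by _convert_openai_clip above, showing how open_clip/OpenAI visual-tower parameter names land on timm ViT names (tensor shapes, the proj/head transpose and the pos_embed resize are omitted here):

swaps = [
    ('visual.', ''), ('conv1', 'patch_embed.proj'), ('positional_embedding', 'pos_embed'),
    ('transformer.resblocks.', 'blocks.'), ('ln_pre', 'norm_pre'), ('ln_post', 'norm'), ('ln_', 'norm'),
    ('in_proj_', 'qkv.'), ('out_proj', 'proj'), ('mlp.c_fc', 'mlp.fc1'), ('mlp.c_proj', 'mlp.fc2'),
]

def rename(k):
    # apply the same ordered string substitutions used in _convert_openai_clip
    for old, new in swaps:
        k = k.replace(old, new)
    return k

print(rename('visual.conv1.weight'))                                 # patch_embed.proj.weight
print(rename('visual.transformer.resblocks.0.ln_1.weight'))          # blocks.0.norm1.weight
print(rename('visual.transformer.resblocks.0.attn.in_proj_weight'))  # blocks.0.attn.qkv.weight
print(rename('visual.ln_post.weight'))                                # norm.weight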
def checkpoint_filter_fn(state_dict, model, adapt_layer_scale=False):
""" convert patch embedding weight from manual patchify + linear proj to conv"""
import re
@@ -631,6 +748,9 @@ def checkpoint_filter_fn(state_dict, model, adapt_layer_scale=False):
# For deit models
state_dict = state_dict['model']

if 'visual.class_embedding' in state_dict:
return _convert_openai_clip(state_dict, model)

for k, v in state_dict.items():
if 'patch_embed.proj.weight' in k and len(v.shape) < 4:
# For old models that I trained prior to conv based patchification
@@ -833,19 +953,20 @@ def vit_huge_patch14_224(pretrained=False, **kwargs):

@register_model
def vit_giant_patch14_224(pretrained=False, **kwargs):
""" ViT-Giant model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
""" ViT-Giant (little-g) model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
"""
model_kwargs = dict(patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16, **kwargs)
model = _create_vision_transformer('vit_giant_patch14_224', pretrained=pretrained, **model_kwargs)
return model


@register_model
def vit_gigantic_patch14_224(pretrained=False, **kwargs):
""" ViT-Gigantic model (ViT-G/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
def vit_gee_patch14_224(pretrained=False, **kwargs):
""" ViT-GEE (big-G) model (ViT-G/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
As per https://twitter.com/wightmanr/status/1570549064667889666
"""
model_kwargs = dict(patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16, **kwargs)
model = _create_vision_transformer('vit_gigantic_patch14_224', pretrained=pretrained, **model_kwargs)
model = _create_vision_transformer('vit_gee_patch14_224', pretrained=pretrained, **model_kwargs)
return model


@@ -1085,3 +1206,44 @@ def vit_base_patch16_18x2_224(pretrained=False, **kwargs):
patch_size=16, embed_dim=768, depth=18, num_heads=12, init_values=1e-5, block_fn=ParallelBlock, **kwargs)
model = _create_vision_transformer('vit_base_patch16_18x2_224', pretrained=pretrained, **model_kwargs)
return model


@register_model
def vit_base_patch32_224_clip_laion2b(pretrained=False, **kwargs):
""" ViT-B/32
Pretrained weights from CLIP image tower trained on LAION-2B image-text pairs.
"""
model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, **kwargs)
model = _create_vision_transformer('vit_base_patch32_224_clip_laion2b', pretrained=pretrained, **model_kwargs)
return model


@register_model
def vit_large_patch14_224_clip_laion2b(pretrained=False, **kwargs):
""" ViT-Large model (ViT-L/14)
Pretrained weights from CLIP image tower trained on LAION-2B image-text pairs.
"""
model_kwargs = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, **kwargs)
model = _create_vision_transformer('vit_large_patch14_224_clip_laion2b', pretrained=pretrained, **model_kwargs)
return model


@register_model
def vit_huge_patch14_224_clip_laion2b(pretrained=False, **kwargs):
""" ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
Pretrained weights from CLIP image tower trained on LAION-2B image-text pairs.
"""
model_kwargs = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, **kwargs)
model = _create_vision_transformer('vit_huge_patch14_224_clip_laion2b', pretrained=pretrained, **model_kwargs)
return model


@register_model
def vit_giant_patch14_224_clip_laion2b(pretrained=False, **kwargs):
""" ViT-Giant (little-g) model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
Pretrained weights from CLIP image tower trained on LAION-2B image-text pairs.
"""
model_kwargs = dict(
patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16, pre_norm=True, **kwargs)
model = _create_vision_transformer('vit_giant_patch14_224_clip_laion2b', pretrained=pretrained, **model_kwargs)
return model
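A usage sketch for the new registrations. The hf_hub_id entries above are still empty placeholders (the commit message says WIP), so pretrained=True will not resolve weights yet; this creates the architecture with random init:

import timm
import torch

model = timm.create_model('vit_base_patch32_224_clip_laion2b', pretrained=False)
model.eval()
with torch.no_grad():
    features = model(torch.randn(1, 3, 224, 224))
print(features.shape)  # expected torch.Size([1, 512]); num_classes defaults from the cfg above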