Skip to content

Commit

Permalink
Merge pull request #1467 from rwightman/clip_laion2b
Browse files: browse the repository at this point in the history
Adding support for fine-tuning CLIP LAION-2B image tower weights for B/32, L/14, H/14, and g/14.
  • Loading branch information
rwightman authored Sep 23, 2022
2 parents a520da9 + 33e30f8 commit d199f66
Show file tree
Hide file tree
Showing 9 changed files with 233 additions and 33 deletions.
5 changes: 3 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
torch>=1.4.0
torchvision>=0.5.0
torch>=1.7
torchvision
pyyaml
huggingface_hub
9 changes: 6 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,15 @@
# 3 - Alpha
# 4 - Beta
# 5 - Production/Stable
'Development Status :: 3 - Alpha',
'Development Status :: 4 - Beta',
'Intended Audience :: Education',
'Intended Audience :: Science/Research',
'License :: OSI Approved :: Apache Software License',
'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9',
'Programming Language :: Python :: 3.10',
'Topic :: Scientific/Engineering',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'Topic :: Software Development',
Expand All @@ -40,9 +42,10 @@
],

# Note that this is a string of words separated by whitespace, not a list.
keywords='pytorch pretrained models efficientnet mobilenetv3 mnasnet',
keywords='pytorch pretrained models efficientnet mobilenetv3 mnasnet resnet vision transformer vit',
packages=find_packages(exclude=['convert', 'tests', 'results']),
include_package_data=True,
install_requires=['torch >= 1.4', 'torchvision'],
install_requires=['torch >= 1.7', 'torchvision', 'pyyaml', 'huggingface_hub'],
python_requires='>=3.6',
)

2 changes: 2 additions & 0 deletions timm/data/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,5 @@
IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255)
IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3)
OPENAI_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_CLIP_STD = (0.26862954, 0.26130258, 0.27577711)
8 changes: 7 additions & 1 deletion timm/models/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,9 @@ def _resolve_pretrained_source(pretrained_cfg):
# hf-hub available as alternate weight source in default_cfg
load_from = 'hf-hub'
pretrained_loc = hf_hub_id
if load_from == 'hf-hub' and 'hf_hub_filename' in pretrained_cfg:
# if a filename override is set, return tuple for location w/ (hub_id, filename)
pretrained_loc = pretrained_loc, pretrained_cfg['hf_hub_filename']
return load_from, pretrained_loc


Expand Down Expand Up @@ -246,7 +249,10 @@ def load_pretrained(
pretrained_loc, map_location='cpu', progress=_DOWNLOAD_PROGRESS, check_hash=_CHECK_HASH)
elif load_from == 'hf-hub':
_logger.info(f'Loading pretrained weights from Hugging Face hub ({pretrained_loc})')
state_dict = load_state_dict_from_hf(pretrained_loc)
if isinstance(pretrained_loc, (list, tuple)):
state_dict = load_state_dict_from_hf(*pretrained_loc)
else:
state_dict = load_state_dict_from_hf(pretrained_loc)
else:
_logger.warning("No pretrained weights exist or were found for this model. Using random initialization.")
return
Expand Down
9 changes: 5 additions & 4 deletions timm/models/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from torch.hub import _get_torch_home as get_dir

from timm import __version__

try:
from huggingface_hub import HfApi, HfFolder, Repository, hf_hub_download, hf_hub_url
hf_hub_download = partial(hf_hub_download, library_name="timm", library_version=__version__)
Expand Down Expand Up @@ -55,7 +56,7 @@ def download_cached_file(url, check_hash=True, progress=False):

def has_hf_hub(necessary=False):
if not _has_hf_hub and necessary:
# if no HF Hub module installed and it is necessary to continue, raise error
# if no HF Hub module installed, and it is necessary to continue, raise error
raise RuntimeError(
'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
return _has_hf_hub
Expand All @@ -78,7 +79,7 @@ def load_cfg_from_json(json_file: Union[str, os.PathLike]):

def _download_from_hf(model_id: str, filename: str):
hf_model_id, hf_revision = hf_split(model_id)
return hf_hub_download(hf_model_id, filename, revision=hf_revision, cache_dir=get_cache_dir('hf'))
return hf_hub_download(hf_model_id, filename, revision=hf_revision)


def load_model_config_from_hf(model_id: str):
Expand All @@ -91,9 +92,9 @@ def load_model_config_from_hf(model_id: str):
return pretrained_cfg, model_name


def load_state_dict_from_hf(model_id: str):
def load_state_dict_from_hf(model_id: str, filename: str = 'pytorch_model.bin'):
assert has_hf_hub(True)
cached_file = _download_from_hf(model_id, 'pytorch_model.bin')
cached_file = _download_from_hf(model_id, filename)
state_dict = torch.load(cached_file, map_location='cpu')
return state_dict

Expand Down
13 changes: 11 additions & 2 deletions timm/models/layers/patch_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,16 @@
class PatchEmbed(nn.Module):
""" 2D Image to Patch Embedding
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
def __init__(
self,
img_size=224,
patch_size=16,
in_chans=3,
embed_dim=768,
norm_layer=None,
flatten=True,
bias=True,
):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
Expand All @@ -25,7 +34,7 @@ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten

self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

def forward(self, x):
Expand Down
Loading

0 comments on commit d199f66

Please sign in to comment.