Merge pull request #14 from dmarx/more_loaders
Add CLOOB and KELIP
Showing 8 changed files with 210 additions and 3 deletions.
@@ -0,0 +1,10 @@

````
# mmc

# installation

```
pip install poetry
poetry build
pip install dist/mmc*.whl
poe napm_installs
```
````
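For context on what the new loaders provide once installed, here is a minimal usage sketch (not part of the diff): the `mmc.loaders` import path and the `project_text` / `project_image` methods are taken from the test file in this PR, while the caption, the `photo.jpg` path, and the cosine-similarity step are illustrative assumptions.

```python
# Minimal usage sketch (assumptions noted in comments; not part of this PR's diff).
import PIL.Image
import torch

from mmc.loaders import KatCloobLoader  # import path as used in the tests below

perceptor = KatCloobLoader().load()  # downloads the CLOOB checkpoint on first use
text_emb = perceptor.project_text("a red bicycle")               # caption is illustrative
img_emb = perceptor.project_image(PIL.Image.open("photo.jpg"))   # 'photo.jpg' is a placeholder path

# Assuming both projections come back as (1, d) tensors, compare them directly.
print(torch.cosine_similarity(text_emb, img_emb))
```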
@@ -0,0 +1,71 @@

```python
"""
Loaders for pretrained CLOOB model by crowsonkb
https://github.com/crowsonkb/cloob-training
"""

# importing this first is necessary for cloob to be available
import napm

from cloob.cloob_training import pretrained  # this should probably be isolated somehow
from loguru import logger
import torch

from .basemmcloader import BaseMmcLoader
from ..modalities import TEXT, IMAGE
from ..multimodalcomparator import MultiModalComparator
from ..registry import REGISTRY, register_model

from torchvision.transforms import ToTensor

DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import PIL


class KatCloobLoader(BaseMmcLoader):
    """
    CLOOB models by crowsonkb, initially trained on LAION datasets
    https://github.com/crowsonkb/cloob-training
    """
    def __init__(
        self,
        id='cloob_laion_400m_vit_b_16_32_epochs',
    ):
        self.architecture = 'cloob'  # should this be a type too?
        self.publisher = 'crowsonkb'
        self.id = id
        self.modalities = (TEXT, IMAGE)

    def load(self, device=DEVICE):
        """
        Returns the MMC associated with this loader.
        """
        from cloob.cloob_training import model_pt, pretrained

        config = pretrained.get_config(self.id)
        model = model_pt.get_pt_model(config)
        checkpoint = pretrained.download_checkpoint(config)
        model.load_state_dict(model_pt.get_pt_params(config, checkpoint))
        model.eval().requires_grad_(False).to(device)
        d_im = config['image_encoder']['image_size']

        def _preprocess_closure(img: "PIL.Image.Image") -> torch.Tensor:
            img = img.resize((d_im, d_im)).convert('RGB')
            t_img = ToTensor()(img)
            if t_img.ndim == 3:
                t_img = t_img.unsqueeze(0)
            t_img = t_img.to(device)
            return model.normalize(t_img)

        mmc = MultiModalComparator(name=str(self), device=device)
        mmc.register_modality(modality=TEXT, projector=model.text_encoder, preprocessor=model.tokenize)
        mmc.register_modality(modality=IMAGE, projector=model.image_encoder, preprocessor=_preprocess_closure)
        mmc._model = model
        return mmc


for model_name in pretrained.list_configs():
    register_model(
        KatCloobLoader(id=model_name)
    )
```
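Because the registration loop above creates one loader per CLOOB config, any checkpoint name reported by `pretrained.list_configs()` can be passed as `id`. A small sketch of that, where the specific checkpoint name is taken from the napm docstring later in this diff and the caption is illustrative:

```python
# Sketch: pick a specific CLOOB checkpoint rather than the default.
# Assumes the cloob package has already been pseudo-installed via napm (see below).
from cloob.cloob_training import pretrained

from mmc.loaders import KatCloobLoader  # import path as used in the tests in this PR

print(pretrained.list_configs())  # includes e.g. 'cloob_laion_400m_vit_b_16_16_epochs'

perceptor = KatCloobLoader(id='cloob_laion_400m_vit_b_16_16_epochs').load()
print(perceptor.project_text("a photograph of a dog").shape)  # caption is illustrative
```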
@@ -0,0 +1,51 @@

```python
"""
Loaders for pretrained Korean CLIP (KELIP) published by navervision
https://github.com/navervision/KELIP
"""

#import clip # this should probably be isolated somehow
from loguru import logger
import torch

from .basemmcloader import BaseMmcLoader
from ..modalities import TEXT, IMAGE
from ..multimodalcomparator import MultiModalComparator
from ..registry import REGISTRY, register_model

DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


#class ClipFaLoader(BaseMmcLoader):
class ClipKelipLoader(BaseMmcLoader):
    """
    CLIP model trained for Korean and English languages
    https://github.com/navervision/KELIP
    """
    def __init__(
        self,
        id='kelip_ViT-B/32',
    ):
        self.architecture = 'clip'  # should this be a type too?
        self.publisher = 'navervision'
        self.id = id
        self.modalities = (TEXT, IMAGE)

    def load(self, device=DEVICE):
        """
        Returns the MMC associated with this loader.
        """
        import kelip
        _id = self.id.replace('kelip_', '')
        model, preprocess_img, tokenizer = kelip.build_model(_id)

        mmc = MultiModalComparator(name=str(self), device=device)
        mmc.register_modality(modality=TEXT, projector=model.encode_text, preprocessor=tokenizer)
        mmc.register_modality(modality=IMAGE, projector=model.encode_image, preprocessor=preprocess_img)
        mmc._model = model
        return mmc


register_model(
    # They don't provide a systematic way of listing their weights; for now, only ViT-B/32 is supported.
    ClipKelipLoader(id='kelip_ViT-B/32')
)
```
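Since KELIP is bilingual, a natural smoke test is to project the same caption in English and Korean and compare the embeddings. A sketch, assuming `ClipKelipLoader` is exported from `mmc.loaders` the same way `KatCloobLoader` is in the tests, and that the `kelip` package (imported lazily inside `load()` above) is installed:

```python
# Sketch: Korean/English caption comparison with the KELIP comparator.
import torch

from mmc.loaders import ClipKelipLoader  # export from mmc.loaders is assumed

perceptor = ClipKelipLoader().load()  # only 'kelip_ViT-B/32' is registered for now

en = perceptor.project_text("a cat sitting on a sofa")
ko = perceptor.project_text("소파에 앉아 있는 고양이")  # the same caption in Korean
print(torch.cosine_similarity(en, ko))
```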
@@ -0,0 +1,28 @@

```python
import napm
from loguru import logger


def napm_pi_katcloob():
    """
    Usage:
        import cloob
        from cloob.cloob_training import model_pt, pretrained
        config = pretrained.get_config('cloob_laion_400m_vit_b_16_16_epochs')
        model = model_pt.get_pt_model(config)
        checkpoint = pretrained.download_checkpoint(config)
        model.load_state_dict(model_pt.get_pt_params(config, checkpoint))
        model.eval().requires_grad_(False).to('cuda')
    """
    logger.debug('using napm to "install" katCLOOB')
    url = "https://github.com/crowsonkb/cloob-training"
    napm.pseudoinstall_git_repo(url, package_name='cloob')


def all():
    napm_pi_katcloob()


if __name__ == '__main__':
    all()
```
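A quick way to confirm the pseudo-install worked is to import `cloob` and list the shipped configs. The sketch below only relies on calls that already appear elsewhere in this diff:

```python
# Sketch: sanity-check the napm pseudo-install of crowsonkb's cloob-training repo.
import cloob  # importable only after napm_pi_katcloob() / `poe napm_installs` has run
from cloob.cloob_training import pretrained

print(pretrained.list_configs())  # should include the CLOOB checkpoints used by KatCloobLoader
```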
@@ -0,0 +1,29 @@

```python
import pytest
import PIL.Image
from loguru import logger
import torch

from mmc.loaders import KatCloobLoader as loader


#loader_args = {'id':'RN50--cc12m'}
loader_args = {}


def test_project_text():
    ldr = loader(**loader_args)
    perceptor = ldr.load()
    projection = perceptor.project_text("foo bar baz")
    assert isinstance(projection, torch.Tensor)


def test_project_img():
    ldr = loader(**loader_args)
    perceptor = ldr.load()
    img = PIL.Image.open("./tests/assets/marley_birthday.jpg").resize((250, 200))
    projection = perceptor.project_image(img)
    assert isinstance(projection, torch.Tensor)


def test_supported_modalities():
    ldr = loader(**loader_args)
    perceptor = ldr.load()
    assert perceptor.supports_text
    assert perceptor.supports_image
    assert not perceptor.supports_audio
```
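The test module above only exercises the CLOOB loader. A hypothetical companion test for the KELIP loader might look like the sketch below, assuming `ClipKelipLoader` is exported from `mmc.loaders` the same way and the `kelip` dependency is available:

```python
# Hypothetical KELIP test sketch (not part of this PR's diff).
import torch

from mmc.loaders import ClipKelipLoader  # export from mmc.loaders is assumed


def test_kelip_project_text():
    perceptor = ClipKelipLoader().load()
    projection = perceptor.project_text("foo bar baz")
    assert isinstance(projection, torch.Tensor)
```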