Skip to content

Commit

Permalink
Merge pull request #14 from dmarx/more_loaders
Browse files Browse the repository at this point in the history
Add CLOOB and KELIP
  • Loading branch information
dmarx authored May 1, 2022
2 parents 71acb15 + ac0110a commit 633b6af
Show file tree
Hide file tree
Showing 8 changed files with 210 additions and 3 deletions.
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# mmc

# installation

```
pip install poetry
poetry build
pip install dist/mmc*.whl
poe napm_installs
```
8 changes: 7 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@ Pillow = "^9.1.0"
kornia = "^0.6.4"
open-clip-torch = "^0.2.1"
declip = {git = "https://github.com/pytti-tools/DeCLIP", branch = "installable"}
#cloob....
kelip = {git = "https://github.com/navervision/KELIP.git", branch = "master"}
sentence-transformers = "^2.2.0"
napm="0.1.1"

[tool.poetry.dev-dependencies]
black = "^22.3.0"
Expand All @@ -27,3 +28,8 @@ poethepoet = "^0.13.1"
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"


# https://github.com/nat-n/poethepoet
[tool.poe.tasks]
napm_installs = { "script" = "mmc.napm_installs:all" }
4 changes: 3 additions & 1 deletion src/mmc/loaders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,6 @@
from .openaicliploader import *
from .mlfcliploader import *
from .sbertclibloader import *
from .clipfaloader import *
from .clipfaloader import *
from .cloobloader import *
from .keliploader import *
71 changes: 71 additions & 0 deletions src/mmc/loaders/cloobloader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@

"""
Loaders for pretrained CLOOB model by crowsonkb
https://github.com/crowsonkb/cloob-training
"""

# importing this first is necessary for cloob to be available
import napm

from cloob.cloob_training import pretrained # this should probably be isolated somehow
from loguru import logger
import torch

from .basemmcloader import BaseMmcLoader
from ..modalities import TEXT, IMAGE
from ..multimodalcomparator import MultiModalComparator
from ..registry import REGISTRY, register_model

from torchvision.transforms import ToTensor

DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

from typing import TYPE_CHECKING

if TYPE_CHECKING:
import PIL

class KatCloobLoader(BaseMmcLoader):
    """
    Loader for the pretrained CLOOB checkpoints released by crowsonkb
    (initially trained on LAION datasets).
    https://github.com/crowsonkb/cloob-training
    """
    def __init__(
        self,
        id='cloob_laion_400m_vit_b_16_32_epochs',
    ):
        # `id` is the config name later passed to `pretrained.get_config` in `load`.
        self.architecture = 'cloob' # should this be a type too?
        self.publisher = 'crowsonkb'
        self.id = id
        self.modalities = (TEXT, IMAGE)

    def load(self, device=DEVICE):
        """
        Build the CLOOB model named by `self.id` and return it wrapped in an MMC.

        The config is looked up by id, the checkpoint is obtained through
        `pretrained.download_checkpoint`, and the model is put in eval mode
        with gradients disabled before being moved to `device`. The returned
        `MultiModalComparator` has TEXT and IMAGE modalities registered.
        """
        from cloob.cloob_training import model_pt, pretrained

        cfg = pretrained.get_config(self.id)
        cloob_model = model_pt.get_pt_model(cfg)
        ckpt = pretrained.download_checkpoint(cfg)
        state = model_pt.get_pt_params(cfg, ckpt)
        cloob_model.load_state_dict(state)
        frozen = cloob_model.eval().requires_grad_(False)
        frozen.to(device)
        side = cfg['image_encoder']['image_size']

        def _preprocess_closure(img: "PIL.Image.Image") -> torch.Tensor:
            # Square-resize to the image encoder's configured size and force RGB.
            rgb = img.resize((side, side)).convert('RGB')
            batch = ToTensor()(rgb)
            # A 3-dim tensor is a single image: add a leading dimension.
            if batch.ndim == 3:
                batch = batch.unsqueeze(0)
            return cloob_model.normalize(batch.to(device))

        mmc = MultiModalComparator(name=str(self), device=device)
        mmc.register_modality(
            modality=TEXT,
            projector=cloob_model.text_encoder,
            preprocessor=cloob_model.tokenize,
        )
        mmc.register_modality(
            modality=IMAGE,
            projector=cloob_model.image_encoder,
            preprocessor=_preprocess_closure,
        )
        mmc._model = cloob_model
        return mmc

# Register one loader instance for every config name reported by `pretrained.list_configs()`.
for model_name in pretrained.list_configs():
    register_model(KatCloobLoader(id=model_name))
51 changes: 51 additions & 0 deletions src/mmc/loaders/keliploader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@

"""
Loaders for pretrained Korean CLIP (KELIP) published by navervision
https://github.com/navervision/KELIP
"""

#import clip # this should probably be isolated somehow
from loguru import logger
import torch

from .basemmcloader import BaseMmcLoader
from ..modalities import TEXT, IMAGE
from ..multimodalcomparator import MultiModalComparator
from ..registry import REGISTRY, register_model

DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


#class ClipFaLoader(BaseMmcLoader):
class ClipKelipLoader(BaseMmcLoader):
    """
    CLIP model trained for Korean and English languages
    https://github.com/navervision/KELIP
    """
    def __init__(
        self,
        id='kelip_ViT-B/32',
    ):
        # `id` is the KELIP model name prefixed with 'kelip_'; the prefix is
        # stripped in `load` before the name is handed to `kelip.build_model`.
        self.architecture = 'clip' # should this be a type too?
        self.publisher = 'navervision'
        self.id = id
        self.modalities = (TEXT, IMAGE)

    def load(self, device=DEVICE):
        """
        Returns the MMC associated with this loader.

        The KELIP model is built from `self.id` (minus the 'kelip_' prefix),
        put in eval mode with gradients disabled, and moved to `device`, so
        that it lives on the same device the returned `MultiModalComparator`
        is created with. TEXT and IMAGE modalities are registered on the MMC.
        """
        import kelip
        _id = self.id.replace('kelip_','')
        model, preprocess_img, tokenizer = kelip.build_model(_id)
        # `build_model` is not told which device to use, so place the model
        # explicitly and freeze it for inference.
        model.eval().requires_grad_(False).to(device)

        mmc = MultiModalComparator(name=str(self), device=device)
        mmc.register_modality(modality=TEXT, projector=model.encode_text, preprocessor=tokenizer)
        mmc.register_modality(modality=IMAGE, projector=model.encode_image, preprocessor=preprocess_img)
        mmc._model = model
        return mmc


# KELIP has no systematic way of listing its weights; for now only ViT-B/32 is supported.
register_model(ClipKelipLoader(id='kelip_ViT-B/32'))
28 changes: 28 additions & 0 deletions src/mmc/napm_installs/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import napm
from loguru import logger


def napm_pi_katcloob():
    """
    Pseudo-install crowsonkb's cloob-training repo via napm under the package name `cloob`.

    Usage once installed:
        import cloob
        from cloob.cloob_training import model_pt, pretrained
        config = pretrained.get_config('cloob_laion_400m_vit_b_16_16_epochs')
        model = model_pt.get_pt_model(config)
        checkpoint = pretrained.download_checkpoint(config)
        model.load_state_dict(model_pt.get_pt_params(config, checkpoint))
        model.eval().requires_grad_(False).to('cuda')
    """
    logger.debug('using napm to "install" katCLOOB')
    napm.pseudoinstall_git_repo(
        "https://github.com/crowsonkb/cloob-training",
        package_name='cloob',
    )


def all():
    """
    Run every napm pseudo-install defined in this module (currently only katCLOOB).

    NOTE: the name shadows the builtin `all` inside this module, but it is the
    entry point referenced by the `napm_installs` poe task, so it must stay.
    """
    napm_pi_katcloob()


# Running this module as a script performs all napm pseudo-installs.
if __name__ == '__main__':
    all()
29 changes: 29 additions & 0 deletions tests/test_mmc_katcloob.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import pytest
from mmc.loaders import KatCloobLoader as loader
import PIL
from loguru import logger
import torch


#loader_args = {'id':'RN50--cc12m'}
loader_args = {}

def test_project_text():
    """Projecting a text prompt through the loaded perceptor yields a torch tensor."""
    perceptor = loader(**loader_args).load()
    text_embedding = perceptor.project_text("foo bar baz")
    assert isinstance(text_embedding, torch.Tensor)

def test_project_img():
    """Projecting a PIL image through the loaded perceptor yields a torch tensor."""
    # Import the submodule explicitly: a bare `import PIL` does not by itself
    # guarantee that `PIL.Image` has been loaded.
    from PIL import Image

    ldr = loader(**loader_args)
    perceptor = ldr.load()
    # The context manager closes the underlying file; `resize` returns a new
    # image, so `img` stays usable after the source file is closed.
    with Image.open("./tests/assets/marley_birthday.jpg") as src:
        img = src.resize((250,200))
    projection = perceptor.project_image(img)
    assert isinstance(projection, torch.Tensor)

def test_supported_modalities():
    """The CLOOB perceptor reports text and image support, and no audio support."""
    perceptor = loader(**loader_args).load()
    assert perceptor.supports_text
    assert perceptor.supports_image
    assert not perceptor.supports_audio
assert not perceptor.supports_audio
12 changes: 11 additions & 1 deletion tests/test_mmc_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,14 @@ def test_load_sbert_mclip():
def test_load_clipfa():
    """Smoke test: the ClipFa loader can be constructed and loaded without raising."""
    from mmc.loaders import ClipFaLoader
    ldr = ClipFaLoader()
    # Load exactly once; a repeated `ldr.load()` would only load the model a second time.
    farsi_clip = ldr.load()

def test_load_katcloob():
    """Smoke test: the default KatCloobLoader can be constructed and loaded without raising."""
    from mmc.loaders import KatCloobLoader
    KatCloobLoader().load()

def test_load_kelip():
    """Smoke test: the default ClipKelipLoader can be constructed and loaded without raising."""
    from mmc.loaders import ClipKelipLoader
    ClipKelipLoader().load()

0 comments on commit 633b6af

Please sign in to comment.