CLIP Text Encoder #1969

Open · wants to merge 12 commits into main
55 changes: 55 additions & 0 deletions tests/torchtune/models/clip/test_clip_text_encoder.py
@@ -0,0 +1,55 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch

from torchtune.models.clip._component_builders import clip_text_encoder
from torchtune.training.seed import set_seed

VOCAB_SIZE = 512
MAX_SEQ_LEN = 77
BSZ = 2
EMBED_DIM = 4


@pytest.fixture(autouse=True)
def random():
set_seed(0)


class TestClipTextEncoder:
@pytest.fixture
def model(self):
model = clip_text_encoder(
vocab_size=VOCAB_SIZE,
max_seq_len=MAX_SEQ_LEN,
embed_dim=EMBED_DIM,
num_heads=2,
num_layers=2,
)

for param in model.parameters():
param.data.uniform_(0, 1)

return model

@pytest.fixture
def inputs(self):
return torch.randint(0, VOCAB_SIZE, (BSZ, MAX_SEQ_LEN))

def test_forward(self, model, inputs):
actual = model(inputs)
expected = torch.tensor(
[[0.2195, 1.3941, 0.6295, -0.1026], [0.2418, 1.4928, 0.6177, -0.0863]]
)
assert actual.shape == (BSZ, EMBED_DIM)
torch.testing.assert_close(actual, expected, atol=1e-4, rtol=1e-4)

def test_backward(self, model, inputs):
y = model(inputs)
loss = y.mean()
loss.backward()
61 changes: 61 additions & 0 deletions tests/torchtune/models/clip/test_clip_tokenizer.py
@@ -0,0 +1,61 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import pytest

from tests.common import ASSETS
from torchtune.models.clip._model_builders import clip_tokenizer


class TestCLIPTokenizer:
@pytest.fixture
def tokenizer(self):
return clip_tokenizer(ASSETS / "tiny_bpe_merges.txt")

def test_tokenization(self, tokenizer):
texts = [
"a cow jumping over the moon",
"a helpful AI assistant",
]
correct_tokens = [
_pad(
[
2416,
320,
66,
78,
342,
73,
669,
79,
515,
326,
1190,
337,
673,
324,
76,
819,
333,
2417,
]
),
_pad(
[2416, 320, 516, 75, 79, 69, 84, 331, 64, 328, 813, 667, 540, 339, 2417]
),
]
tokens_tensor = tokenizer(texts)
assert tokens_tensor.tolist() == correct_tokens

def test_decoding(self, tokenizer):
text = "this is torchtune"
decoded_text = "<|startoftext|>this is torchtune <|endoftext|>"
assert decoded_text == tokenizer.decode(tokenizer.encode(text))


def _pad(tokens, max_seq_len=77, pad_token=2417):
while len(tokens) < max_seq_len:
tokens.append(pad_token)
return tokens
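
For reference, a minimal usage sketch of the tokenizer built above (the merges path is hypothetical; outside of tests this would point at the BPE merges file from a real CLIP checkpoint rather than the tiny test asset):

from torchtune.models.clip._model_builders import clip_tokenizer

# hypothetical path to a CLIP BPE merges file
tokenizer = clip_tokenizer("/path/to/clip_merges.txt")

# calling the tokenizer on a list of strings returns a (batch, 77) tensor of token ids,
# padded out to max_seq_len with the pad token
tokens = tokenizer(["a cow jumping over the moon", "a helpful AI assistant"])

# encode/decode round-trips a single string, adding the <|startoftext|>/<|endoftext|> markers
print(tokenizer.decode(tokenizer.encode("this is torchtune")))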
64 changes: 59 additions & 5 deletions torchtune/models/clip/_component_builders.py
@@ -9,24 +9,23 @@

import torch
from torch import nn

from torchtune.models.clip._position_embeddings import (
TiledTokenPositionalEmbedding,
TilePositionalEmbedding,
TokenPositionalEmbedding,
)

from torchtune.models.clip._text_encoder import CLIPTextEncoder
from torchtune.modules import (
FeedForward,
Fp32LayerNorm,
FrozenNF4Linear,
MultiHeadAttention,
TransformerSelfAttentionLayer,
)

from torchtune.modules.activations import QuickGELU
from torchtune.modules.common_utils import reparametrize_as_dtype_state_dict_post_hook

from torchtune.modules.peft import DoRALinear, LORA_ATTN_MODULES, LoRALinear

from torchtune.modules.peft import LORA_ATTN_MODULES, DoRALinear, LoRALinear
from torchtune.modules.vision_transformer import CLSProjection, VisionTransformer


@@ -157,6 +156,61 @@ def clip_vision_encoder(
)


def clip_text_encoder(
vocab_size: int = 49408,
max_seq_len: int = 77,
embed_dim: int = 768,
num_heads: int = 12,
num_layers: int = 12,
norm_eps: float = 1e-5,
) -> CLIPTextEncoder:
"""
Text encoder for CLIP.

Args:
vocab_size (int): size of the vocabulary, default 49408
max_seq_len (int): context size, default 77
embed_dim (int): embedding/model dimension size, default 768
num_heads (int): number of attention heads, default 12
num_layers (int): number of transformer layers, default 12
norm_eps (float): small value added to denominator for numerical stability, default 1e-5

Returns:
CLIPTextEncoder
"""
attn = MultiHeadAttention(
embed_dim=embed_dim,
num_heads=num_heads,
num_kv_heads=num_heads,
head_dim=embed_dim // num_heads,
q_proj=nn.Linear(embed_dim, embed_dim),
k_proj=nn.Linear(embed_dim, embed_dim),
v_proj=nn.Linear(embed_dim, embed_dim),
output_proj=nn.Linear(embed_dim, embed_dim),
)
mlp = clip_mlp(
in_dim=embed_dim,
out_dim=embed_dim,
hidden_dim=embed_dim * 4,
activation=QuickGELU(),
)
encoder_layer = TransformerSelfAttentionLayer(
attn=attn,
mlp=mlp,
sa_norm=nn.LayerNorm(embed_dim, eps=norm_eps),
mlp_norm=nn.LayerNorm(embed_dim, eps=norm_eps),
)
final_norm = nn.LayerNorm(embed_dim, eps=norm_eps)
return CLIPTextEncoder(
layers=encoder_layer,
final_norm=final_norm,
vocab_size=vocab_size,
max_seq_len=max_seq_len,
embed_dim=embed_dim,
num_layers=num_layers,
)


def clip_mlp(
in_dim: int,
out_dim: int,
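
For context, a minimal sketch of building and running the text encoder with the same small configuration as the unit test, showing the expected input and output shapes (one embedding per sequence):

import torch

from torchtune.models.clip._component_builders import clip_text_encoder

# small configuration mirroring the unit test
model = clip_text_encoder(
    vocab_size=512,
    max_seq_len=77,
    embed_dim=4,
    num_heads=2,
    num_layers=2,
)

# batch of token ids in, one embedding per sequence out: (batch, embed_dim)
tokens = torch.randint(0, 512, (2, 77))
out = model(tokens)
assert out.shape == (2, 4)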
48 changes: 48 additions & 0 deletions torchtune/models/clip/_convert_weights.py
@@ -0,0 +1,48 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtune.models.convert_weights import get_mapped_key

# state dict key mappings from HF's format to torchtune's format
_FROM_HF = {
"text_model.embeddings.token_embedding.weight": "token_embedding.weight",
"text_model.embeddings.position_embedding.weight": "position_embedding",
"text_model.encoder.layers.{}.layer_norm1.weight": "layers.{}.sa_norm.weight",
"text_model.encoder.layers.{}.layer_norm1.bias": "layers.{}.sa_norm.bias",
"text_model.encoder.layers.{}.layer_norm2.weight": "layers.{}.mlp_norm.weight",
"text_model.encoder.layers.{}.layer_norm2.bias": "layers.{}.mlp_norm.bias",
"text_model.encoder.layers.{}.mlp.fc1.weight": "layers.{}.mlp.w1.weight",
"text_model.encoder.layers.{}.mlp.fc1.bias": "layers.{}.mlp.w1.bias",
"text_model.encoder.layers.{}.mlp.fc2.weight": "layers.{}.mlp.w2.weight",
"text_model.encoder.layers.{}.mlp.fc2.bias": "layers.{}.mlp.w2.bias",
"text_model.encoder.layers.{}.self_attn.q_proj.weight": "layers.{}.attn.q_proj.weight",
"text_model.encoder.layers.{}.self_attn.q_proj.bias": "layers.{}.attn.q_proj.bias",
"text_model.encoder.layers.{}.self_attn.k_proj.weight": "layers.{}.attn.k_proj.weight",
"text_model.encoder.layers.{}.self_attn.k_proj.bias": "layers.{}.attn.k_proj.bias",
"text_model.encoder.layers.{}.self_attn.v_proj.weight": "layers.{}.attn.v_proj.weight",
"text_model.encoder.layers.{}.self_attn.v_proj.bias": "layers.{}.attn.v_proj.bias",
"text_model.encoder.layers.{}.self_attn.out_proj.bias": "layers.{}.attn.output_proj.bias",
"text_model.encoder.layers.{}.self_attn.out_proj.weight": "layers.{}.attn.output_proj.weight",
"text_model.final_layer_norm.weight": "final_norm.weight",
"text_model.final_layer_norm.bias": "final_norm.bias",
}

_IGNORE = {
"logit_scale",
"text_model.embeddings.position_ids",
"text_projection.weight",
"visual_projection.weight",
}


def clip_text_hf_to_tune(state_dict):
converted_state_dict = {}
for key, value in state_dict.items():
if key.startswith("vision_model.") or key in _IGNORE:
continue
new_key = get_mapped_key(key, _FROM_HF)
converted_state_dict[new_key] = value
return converted_state_dict
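
A hedged sketch of how the converter could be used to load pretrained text-tower weights from Hugging Face (assumes the transformers package and the openai/clip-vit-large-patch14 checkpoint, neither of which is part of this PR):

from transformers import CLIPModel

from torchtune.models.clip._convert_weights import clip_text_hf_to_tune
from torchtune.models.clip._model_builders import clip_text_vit_large_patch14

# full CLIP state dict in HF format (vision + text towers, projections, logit scale)
hf_sd = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").state_dict()

# drop the vision tower and the ignored keys, remap the rest to torchtune names
tune_sd = clip_text_hf_to_tune(hf_sd)

# strict loading assumes the converted keys line up exactly with the torchtune module names
model = clip_text_vit_large_patch14()
model.load_state_dict(tune_sd)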
51 changes: 50 additions & 1 deletion torchtune/models/clip/_model_builders.py
@@ -1,4 +1,53 @@
from torchtune.models.clip._transforms import CLIPImageTransform
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from os import PathLike

from torchtune.models.clip._component_builders import clip_text_encoder
from torchtune.models.clip._text_encoder import CLIPTextEncoder
from torchtune.models.clip._tokenizer import CLIPTokenizer
from torchtune.models.clip._transform import CLIPImageTransform


def clip_tokenizer(
merges_path: PathLike,
max_seq_len: int = 77,
truncate: bool = True,
) -> CLIPTokenizer:
"""
Builder for the CLIP text tokenizer.

Args:
merges_path (PathLike): Path to the CLIP merges file
max_seq_len (int): Context length
Default: 77
truncate (bool): Truncate the token sequence if it exceeds max_seq_len (otherwise raises AssertionError)
Default: True

Returns:
CLIPTokenizer: Instantiation of the CLIP text tokenizer
"""
return CLIPTokenizer(merges_path, max_seq_len=max_seq_len, truncate=truncate)


def clip_text_vit_large_patch14() -> CLIPTextEncoder:
"""
Builder for the CLIP text encoder for CLIP-ViT-L/14.

Returns:
CLIPTextEncoder: Instantiation of the CLIP text encoder
"""
return clip_text_encoder(
vocab_size=49408,
max_seq_len=77,
embed_dim=768,
num_heads=12,
num_layers=12,
norm_eps=1e-5,
)


def clip_vit_224_transform():
image_transform = CLIPImageTransform(