diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py
index ededbe43e..8cc27f1df 100644
--- a/ldm/modules/encoders/modules.py
+++ b/ldm/modules/encoders/modules.py
@@ -1,10 +1,13 @@
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from functools import partial
-import clip
+import open_clip as clip
 from einops import rearrange, repeat
 from transformers import CLIPTokenizer, CLIPTextModel
 import kornia
 
+from ldm.modules.rope_utils import build_rope_cache, apply_rope
+
 from ldm.modules.x_transformer import Encoder, TransformerWrapper  # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test
 
@@ -140,10 +143,20 @@ def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_l
         super().__init__()
         self.tokenizer = CLIPTokenizer.from_pretrained(version)
         self.transformer = CLIPTextModel.from_pretrained(version)
+        # === Inject RoPE into attention layers ===
+        # Assign via the parent module: dotted names from named_modules() cannot be set with a plain setattr.
+        for name, module in self.transformer.named_modules():
+            if isinstance(module, torch.nn.MultiheadAttention):
+                parent_name, _, attr_name = name.rpartition(".")
+                parent = self.transformer.get_submodule(parent_name) if parent_name else self.transformer
+                setattr(parent, attr_name, RoPEAttentionWrapper(module))
+                print(f"[RoPE] Wrapped attention module: {name}")
+
         self.device = device
         self.max_length = max_length
         self.freeze()
 
+
     def freeze(self):
         self.transformer = self.transformer.eval()
         for param in self.parameters():
@@ -227,6 +240,41 @@ def forward(self, x):
         # x is assumed to be in range [-1,1]
         return self.model.encode_image(self.preprocess(x))
 
+class RoPEAttentionWrapper(nn.Module):
+    def __init__(self, attn_layer):
+        super().__init__()
+        self.attn = attn_layer
+        self.rope_cache = None
+
+    def forward(self, x, *args, **kwargs):
+        B, S, C = x.shape  # batch, seq_len, channels
+        device = x.device
+        num_heads = self.attn.num_heads
+        head_dim = C // num_heads
+
+        # Linear projection to get QKV
+        qkv = F.linear(x, self.attn.in_proj_weight, self.attn.in_proj_bias)
+        qkv = qkv.view(B, S, 3, num_heads, head_dim).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv[0], qkv[1], qkv[2]
+
+        # Build rope cache if not existing
+        if self.rope_cache is None or self.rope_cache[0].shape[2] != S:
+            self.rope_cache = build_rope_cache(S, head_dim, device)
+
+        # Apply RoPE
+        q = apply_rope(q, self.rope_cache)
+        k = apply_rope(k, self.rope_cache)
+
+        # Attention calculation
+        attn_weights = torch.matmul(q, k.transpose(-2, -1)) * (head_dim ** -0.5)
+        attn_weights = attn_weights.softmax(dim=-1)
+        attn_output = torch.matmul(attn_weights, v)
+
+        attn_output = attn_output.transpose(1, 2).reshape(B, S, C)
+        output = self.attn.out_proj(attn_output)
+
+        return output
+
 
 if __name__ == "__main__":
     from ldm.util import count_params
diff --git a/ldm/modules/rope_utils.py b/ldm/modules/rope_utils.py
new file mode 100644
index 000000000..d15263cf1
--- /dev/null
+++ b/ldm/modules/rope_utils.py
@@ -0,0 +1,21 @@
+# ldm/modules/rope_utils.py
+
+import torch
+
+def build_rope_cache(seq_len, head_dim, device):
+    # One rotation frequency per (even, odd) channel pair -> head_dim // 2 frequencies.
+    inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
+    t = torch.arange(seq_len, device=device).type_as(inv_freq)
+    freqs = torch.einsum('i,j->ij', t, inv_freq)  # (seq_len, head_dim/2)
+    sin_emb = freqs.sin()[None, None, :, :]  # (1, 1, seq_len, head_dim/2)
+    cos_emb = freqs.cos()[None, None, :, :]
+    return sin_emb, cos_emb
+
+def apply_rope(x, rope_cache):
+    # x: (batch, heads, seq_len, head_dim); channels (2i, 2i+1) form the rotated pairs.
+    sin_emb, cos_emb = rope_cache
+    x1 = x[..., ::2]
+    x2 = x[..., 1::2]
+    x_out = torch.cat([x1 * cos_emb - x2 * sin_emb,
+                       x1 * sin_emb + x2 * cos_emb], dim=-1)
+    return x_out
diff --git a/scripts/finetune_encoder.py b/scripts/finetune_encoder.py
new file mode 100644
index 000000000..fe2fec130
--- /dev/null
+++ b/scripts/finetune_encoder.py
@@ -0,0 +1,109 @@
+import sys
+import os
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
+import torch
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+from datasets import load_dataset
+from sklearn.metrics import precision_recall_fscore_support
+import torch.nn.functional as F
+
+from ldm.modules.encoders.modules import FrozenCLIPEmbedder
+
+# === Config ===
+device = "cuda" if torch.cuda.is_available() else "cpu"
+batch_size = 32
+epochs = 3
+lr = 1e-5
+max_length = 77
+save_dir = "./checkpoints"
+os.makedirs(save_dir, exist_ok=True)
+save_every_n_steps = 1000  # Save every 1000 batches
+
+# === Dataset ===
+class CocoCountingDataset(torch.utils.data.Dataset):
+    def __init__(self, split="train", tokenizer=None, max_length=77):
+        self.dataset = load_dataset("conceptual_captions", split=split)
+        self.tokenizer = tokenizer
+        self.max_length = max_length
+        self.number_words = ['one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'ten']
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def __getitem__(self, idx):
+        caption = self.dataset[idx]['caption'].lower()
+        label = int(any(word in caption for word in self.number_words))  # label 1 if counting word exists
+
+        if label == 0:
+            caption = "one object."
+
+        encoding = self.tokenizer(caption, truncation=True, max_length=self.max_length, padding="max_length", return_tensors="pt")
+        input_ids = encoding["input_ids"].squeeze(0)
+        attention_mask = encoding["attention_mask"].squeeze(0)
+        return input_ids, attention_mask, label
+
+# === Model ===
+model = FrozenCLIPEmbedder(version="openai/clip-vit-large-patch14", device=device, max_length=max_length)
+
+for param in model.transformer.parameters():
+    param.requires_grad = True
+
+model = model.to(device)
+optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.transformer.parameters()), lr=lr)
+
+# === Dataloader ===
+dataset = CocoCountingDataset(split="train", tokenizer=model.tokenizer, max_length=max_length)
+dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)
+
+# === Training ===
+model.train()
+global_step = 0
+for epoch in range(epochs):
+    total_loss = 0
+    preds, targets = [], []
+
+    for batch_idx, (input_ids, attention_mask, labels) in enumerate(tqdm(dataloader)):
+        input_ids = input_ids.to(device)
+        attention_mask = attention_mask.to(device)
+        labels = labels.to(device)
+
+        outputs = model.transformer(input_ids=input_ids, attention_mask=attention_mask)
+        embeddings = outputs.last_hidden_state
+
+        loss = torch.mean(torch.norm(embeddings, dim=-1))
+
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+        total_loss += loss.item()
+
+        # Mock "classification" for precision/recall: use embedding norm as pseudo-score
+        scores = torch.norm(embeddings[:, 0, :], dim=-1)  # CLS token norm
+        pred_labels = (scores > scores.mean()).long()
+
+        preds.extend(pred_labels.cpu().tolist())
+        targets.extend(labels.cpu().tolist())
+
+        global_step += 1
+
+        # === Save checkpoint mid-epoch
+        if global_step % save_every_n_steps == 0:
+            checkpoint_path = os.path.join(save_dir, f"clip_rope_step{global_step}.pth")
+            torch.save(model.transformer.state_dict(), checkpoint_path)
+            print(f"[Checkpoint] Saved at step {global_step}")
+
+    # === End of epoch logging ===
+    precision, recall, f1, _ = precision_recall_fscore_support(targets, preds, average='binary')
+    print(f"Epoch {epoch+1}/{epochs}: Loss={total_loss/len(dataloader):.4f}")
+    print(f"Precision: {precision:.4f} Recall: {recall:.4f} F1: {f1:.4f}")
+
+    # Save after each epoch
+    checkpoint_path = os.path.join(save_dir, f"clip_rope_epoch{epoch+1}.pth")
+    torch.save(model.transformer.state_dict(), checkpoint_path)
+    print(f"[Checkpoint] Saved model after epoch {epoch+1}")
+
+# === Final Save ===
+torch.save(model.transformer.state_dict(), "./clip_rope_finetuned_final.pth")
+print("[Final Save] Fine-tuned text encoder saved!")
diff --git a/scripts/train_clip_rope.py b/scripts/train_clip_rope.py
new file mode 100644
index 000000000..481ec2558
--- /dev/null
+++ b/scripts/train_clip_rope.py
@@ -0,0 +1,81 @@
+import sys
+import os
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
+import torch
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+from datasets import load_dataset
+
+from ldm.modules.encoders.modules import FrozenCLIPEmbedder
+
+# === Config ===
+device = "cuda" if torch.cuda.is_available() else "cpu"
+batch_size = 32
+epochs = 3
+lr = 1e-5
+max_length = 77
+save_path = "./clip_rope_finetuned.pth"
+
+# === Dataset ===
+class CocoCountingDataset(torch.utils.data.Dataset):
+    def __init__(self, split="train", tokenizer=None, max_length=77):
+        self.dataset = load_dataset("conceptual_captions", split=split)
+        self.tokenizer = tokenizer
+        self.max_length = max_length
+        self.number_words = ['one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'ten']
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def __getitem__(self, idx):
+        caption = self.dataset[idx]['caption'].lower()
+
+        if not any(word in caption for word in self.number_words):
+            caption = "one object."  # fallback dummy caption
+
+        encoding = self.tokenizer(caption, truncation=True, max_length=self.max_length, padding="max_length", return_tensors="pt")
+        input_ids = encoding["input_ids"].squeeze(0)
+        attention_mask = encoding["attention_mask"].squeeze(0)
+        return input_ids, attention_mask
+
+# === Model ===
+model = FrozenCLIPEmbedder(version="openai/clip-vit-large-patch14", device=device, max_length=max_length)
+
+# ❗ Unfreeze only transformer parameters
+for param in model.transformer.parameters():
+    param.requires_grad = True
+
+model = model.to(device)
+
+# ❗ Optimizer on transformer parameters
+optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.transformer.parameters()), lr=lr)
+
+# === Dataloader ===
+dataset = CocoCountingDataset(split="train", tokenizer=model.tokenizer, max_length=max_length)
+dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)
+
+# === Training ===
+model.train()
+for epoch in range(epochs):
+    total_loss = 0
+    for input_ids, attention_mask in tqdm(dataloader):
+        input_ids = input_ids.to(device)
+        attention_mask = attention_mask.to(device)
+
+        outputs = model.transformer(input_ids=input_ids, attention_mask=attention_mask)
+        embeddings = outputs.last_hidden_state
+
+        # Simple L2 loss
+        loss = torch.mean(torch.norm(embeddings, dim=-1))
+
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+        total_loss += loss.item()
+
+    print(f"Epoch {epoch+1}/{epochs}: Loss={total_loss/len(dataloader):.4f}")
+
+# === Save the fine-tuned transformer
+torch.save(model.transformer.state_dict(), save_path)
+print(f"Fine-tuned text encoder saved to {save_path}")
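
A quick way to exercise the new rotary-embedding utilities in isolation is a small standalone script along the following lines. This is only an illustrative sketch (the file name check_rope.py and the shape constants are assumptions, not part of the patch); it calls build_rope_cache / apply_rope with the (batch, heads, seq_len, head_dim) layout used by RoPEAttentionWrapper and checks that the rotation preserves shapes and per-token norms.

# check_rope.py -- illustrative sanity check, not part of the patch
import torch

from ldm.modules.rope_utils import build_rope_cache, apply_rope

B, H, S, D = 2, 12, 77, 64  # ViT-L/14 text encoder: 12 heads, head_dim 64, 77 tokens
device = "cuda" if torch.cuda.is_available() else "cpu"

sin_emb, cos_emb = build_rope_cache(S, D, device)   # each of shape (1, 1, S, D // 2)
q = torch.randn(B, H, S, D, device=device)
q_rot = apply_rope(q, (sin_emb, cos_emb))

assert q_rot.shape == q.shape                       # rotation keeps the tensor shape
# A rotary embedding is a pure rotation of channel pairs, so per-token norms are unchanged.
assert torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-3)
print("rope cache:", sin_emb.shape, "| rotated q:", q_rot.shape)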
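
For completeness, a rough sketch of how the weights written by scripts/finetune_encoder.py could be loaded back into the patched FrozenCLIPEmbedder for inference. This is hypothetical usage rather than part of the patch; the checkpoint path matches the final save in that script, and strict=False is used defensively in case the checkpoint keys drift from the freshly constructed model.

# Illustrative reload of the fine-tuned text encoder -- not part of the patch.
import torch

from ldm.modules.encoders.modules import FrozenCLIPEmbedder

device = "cuda" if torch.cuda.is_available() else "cpu"
embedder = FrozenCLIPEmbedder(device=device).to(device)

state_dict = torch.load("./clip_rope_finetuned_final.pth", map_location=device)
missing, unexpected = embedder.transformer.load_state_dict(state_dict, strict=False)
print(f"missing keys: {len(missing)} | unexpected keys: {len(unexpected)}")

with torch.no_grad():
    z = embedder.encode(["three apples on a table"])
print(z.shape)  # (1, 77, 768) for the ViT-L/14 text transformer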