
Commit 2426f76

Merge pull request #1 from Rganeshk/Rope-implementation
implemented rope
2 parents 21f890f + 0afc84a

File tree: 4 files changed (+255 −1 lines)

ldm/modules/encoders/modules.py

Lines changed: 45 additions & 1 deletion
@@ -1,10 +1,12 @@
 import torch
 import torch.nn as nn
 from functools import partial
-import clip
+import open_clip as clip
 from einops import rearrange, repeat
 from transformers import CLIPTokenizer, CLIPTextModel
 import kornia
+from ldm.modules.rope_utils import build_rope_cache, apply_rope
+

 from ldm.modules.x_transformer import Encoder, TransformerWrapper  # TODO: can we directly rely on lucidrains code and simply add this as a requirement? --> test

@@ -140,10 +142,17 @@ def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_l
         super().__init__()
         self.tokenizer = CLIPTokenizer.from_pretrained(version)
         self.transformer = CLIPTextModel.from_pretrained(version)
+        # === Inject RoPE into attention layers ===
+        for name, module in self.transformer.named_modules():
+            if isinstance(module, torch.nn.MultiheadAttention):  # note: HF CLIPTextModel uses CLIPAttention, so this may match no modules
+                setattr(self.transformer, name, RoPEAttentionWrapper(module))  # caveat: dotted submodule names are not resolved to their parent here
+                print(f"[RoPE] Wrapped attention module: {name}")
+
         self.device = device
         self.max_length = max_length
         self.freeze()

+
     def freeze(self):
         self.transformer = self.transformer.eval()
         for param in self.parameters():
@@ -227,6 +236,41 @@ def forward(self, x):
         # x is assumed to be in range [-1,1]
         return self.model.encode_image(self.preprocess(x))

+class RoPEAttentionWrapper(nn.Module):
+    def __init__(self, attn_layer):
+        super().__init__()
+        self.attn = attn_layer
+        self.rope_cache = None
+
+    def forward(self, x, *args, **kwargs):  # assumes self-attention; extra args/kwargs are ignored
+        B, S, C = x.shape  # batch, seq_len, channels
+        device = x.device
+        num_heads = self.attn.num_heads
+        head_dim = C // num_heads
+
+        # Linear projection to get QKV (fully qualified call; this file does not alias torch.nn.functional as F)
+        qkv = torch.nn.functional.linear(x, self.attn.in_proj_weight, self.attn.in_proj_bias)
+        qkv = qkv.view(B, S, 3, num_heads, head_dim).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv[0], qkv[1], qkv[2]
+
+        # Build the RoPE cache if missing or if the sequence length changed
+        if self.rope_cache is None or self.rope_cache[0].shape[2] != S:
+            self.rope_cache = build_rope_cache(S, head_dim, device)
+
+        # Apply RoPE to queries and keys
+        q = apply_rope(q, self.rope_cache)
+        k = apply_rope(k, self.rope_cache)
+
+        # Attention calculation
+        attn_weights = torch.matmul(q, k.transpose(-2, -1)) * (head_dim ** -0.5)
+        attn_weights = attn_weights.softmax(dim=-1)
+        attn_output = torch.matmul(attn_weights, v)
+
+        attn_output = attn_output.transpose(1, 2).reshape(B, S, C)
+        output = self.attn.out_proj(attn_output)
+
+        return output
+

 if __name__ == "__main__":
     from ldm.util import count_params
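
As a quick reference (outside the diff itself), the wrapper above can be exercised on a standalone nn.MultiheadAttention layer, the module type the injection loop targets. A minimal shape-check sketch, assuming RoPEAttentionWrapper and ldm.modules.rope_utils are importable as laid out in this commit:

# Sketch: shape check for RoPEAttentionWrapper on a standalone nn.MultiheadAttention.
import torch
import torch.nn as nn
from ldm.modules.encoders.modules import RoPEAttentionWrapper

attn = nn.MultiheadAttention(embed_dim=768, num_heads=12, batch_first=True)
wrapped = RoPEAttentionWrapper(attn)

x = torch.randn(2, 77, 768)  # (batch, seq_len, channels), matching CLIP ViT-L/14's text width
out = wrapped(x)
print(out.shape)  # expected: torch.Size([2, 77, 768]), same shape as the input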

ldm/modules/rope_utils.py

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
# ldm/modules/rope_utils.py

import torch

def build_rope_cache(seq_len, head_dim, device):
    inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
    t = torch.arange(seq_len, device=device).float()  # build on `device` so the cache matches the activations
    freqs = torch.einsum('i,j->ij', t, inv_freq)  # (seq_len, head_dim/2)
    # keep half-dim sin/cos so they broadcast against the even/odd split in apply_rope
    sin_emb = freqs.sin()[None, None, :, :]  # (1, 1, seq_len, head_dim/2)
    cos_emb = freqs.cos()[None, None, :, :]
    return sin_emb, cos_emb

def apply_rope(x, rope_cache):
    sin_emb, cos_emb = rope_cache
    x1 = x[..., ::2]   # even-indexed features
    x2 = x[..., 1::2]  # odd-indexed features
    x_out = torch.cat([x1 * cos_emb - x2 * sin_emb,
                       x1 * sin_emb + x2 * cos_emb], dim=-1)
    return x_out
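
As a sanity check (a sketch, not part of the commit), RoPE should make the query-key dot product depend only on the relative offset between positions. Using the two helpers above:

# Sketch: relative-position property of build_rope_cache / apply_rope.
import torch
from ldm.modules.rope_utils import build_rope_cache, apply_rope

head_dim, seq_len = 64, 16
rope = build_rope_cache(seq_len, head_dim, device="cpu")

# Use the same query vector and the same key vector at every position.
q = torch.randn(1, 1, 1, head_dim).expand(1, 1, seq_len, head_dim)
k = torch.randn(1, 1, 1, head_dim).expand(1, 1, seq_len, head_dim)

scores = torch.einsum('bhid,bhjd->bhij', apply_rope(q, rope), apply_rope(k, rope))[0, 0]

# Position pairs (2, 5) and (7, 10) share the offset 3, so their scores should match.
print(torch.allclose(scores[2, 5], scores[7, 10], atol=1e-5))  # expected: True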

scripts/finetune_encoder.py

Lines changed: 109 additions & 0 deletions
@@ -0,0 +1,109 @@
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from datasets import load_dataset
from sklearn.metrics import precision_recall_fscore_support
import torch.nn.functional as F

from ldm.modules.encoders.modules import FrozenCLIPEmbedder

# === Config ===
device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size = 32
epochs = 3
lr = 1e-5
max_length = 77
save_dir = "./checkpoints"
os.makedirs(save_dir, exist_ok=True)
save_every_n_steps = 1000  # Save every 1000 batches

# === Dataset ===
class CocoCountingDataset(torch.utils.data.Dataset):
    def __init__(self, split="train", tokenizer=None, max_length=77):
        self.dataset = load_dataset("conceptual_captions", split=split)
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.number_words = ['one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'ten']

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        caption = self.dataset[idx]['caption'].lower()
        label = int(any(word in caption for word in self.number_words))  # label 1 if a counting word exists

        if label == 0:
            caption = "one object."

        encoding = self.tokenizer(caption, truncation=True, max_length=self.max_length, padding="max_length", return_tensors="pt")
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)
        return input_ids, attention_mask, label

# === Model ===
model = FrozenCLIPEmbedder(version="openai/clip-vit-large-patch14", device=device, max_length=max_length)

for param in model.transformer.parameters():
    param.requires_grad = True

model = model.to(device)
optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.transformer.parameters()), lr=lr)

# === Dataloader ===
dataset = CocoCountingDataset(split="train", tokenizer=model.tokenizer, max_length=max_length)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

# === Training ===
model.train()
global_step = 0
for epoch in range(epochs):
    total_loss = 0
    preds, targets = [], []

    for batch_idx, (input_ids, attention_mask, labels) in enumerate(tqdm(dataloader)):
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        labels = labels.to(device)

        outputs = model.transformer(input_ids=input_ids, attention_mask=attention_mask)
        embeddings = outputs.last_hidden_state

        loss = torch.mean(torch.norm(embeddings, dim=-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        # Mock "classification" for precision/recall: use embedding norm as pseudo-score
        scores = torch.norm(embeddings[:, 0, :], dim=-1)  # norm of the first (BOS) token embedding
        pred_labels = (scores > scores.mean()).long()

        preds.extend(pred_labels.cpu().tolist())
        targets.extend(labels.cpu().tolist())

        global_step += 1

        # === Save checkpoint mid-epoch ===
        if global_step % save_every_n_steps == 0:
            checkpoint_path = os.path.join(save_dir, f"clip_rope_step{global_step}.pth")
            torch.save(model.transformer.state_dict(), checkpoint_path)
            print(f"[Checkpoint] Saved at step {global_step}")

    # === End of epoch logging ===
    precision, recall, f1, _ = precision_recall_fscore_support(targets, preds, average='binary')
    print(f"Epoch {epoch+1}/{epochs}: Loss={total_loss/len(dataloader):.4f}")
    print(f"Precision: {precision:.4f} Recall: {recall:.4f} F1: {f1:.4f}")

    # Save after each epoch
    checkpoint_path = os.path.join(save_dir, f"clip_rope_epoch{epoch+1}.pth")
    torch.save(model.transformer.state_dict(), checkpoint_path)
    print(f"[Checkpoint] Saved model after epoch {epoch+1}")

# === Final Save ===
torch.save(model.transformer.state_dict(), "./clip_rope_finetuned_final.pth")
print("[Final Save] Fine-tuned text encoder saved!")
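
One way to reuse a checkpoint saved by this script (a sketch, assuming the epoch-3 file produced by the naming scheme above, and that FrozenCLIPEmbedder's forward accepts a list of prompts as in the original latent-diffusion implementation):

# Sketch: reload a fine-tuned transformer checkpoint for inference.
import torch
from ldm.modules.encoders.modules import FrozenCLIPEmbedder

device = "cuda" if torch.cuda.is_available() else "cpu"
model = FrozenCLIPEmbedder(version="openai/clip-vit-large-patch14", device=device, max_length=77)

# Adjust the path to whichever checkpoint the run actually produced.
state_dict = torch.load("./checkpoints/clip_rope_epoch3.pth", map_location=device)
model.transformer.load_state_dict(state_dict)
model = model.to(device)

with torch.no_grad():
    emb = model(["three apples on a table"])
print(emb.shape)  # expected: (1, 77, 768) token embeddings for the ViT-L/14 text encoder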

scripts/train_clip_rope.py

Lines changed: 81 additions & 0 deletions
@@ -0,0 +1,81 @@
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from datasets import load_dataset

from ldm.modules.encoders.modules import FrozenCLIPEmbedder

# === Config ===
device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size = 32
epochs = 3
lr = 1e-5
max_length = 77
save_path = "./clip_rope_finetuned.pth"

# === Dataset ===
class CocoCountingDataset(torch.utils.data.Dataset):
    def __init__(self, split="train", tokenizer=None, max_length=77):
        self.dataset = load_dataset("conceptual_captions", split=split)
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.number_words = ['one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'ten']

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        caption = self.dataset[idx]['caption'].lower()

        if not any(word in caption for word in self.number_words):
            caption = "one object."  # fallback dummy caption

        encoding = self.tokenizer(caption, truncation=True, max_length=self.max_length, padding="max_length", return_tensors="pt")
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)
        return input_ids, attention_mask

# === Model ===
model = FrozenCLIPEmbedder(version="openai/clip-vit-large-patch14", device=device, max_length=max_length)

# ❗ Unfreeze only transformer parameters
for param in model.transformer.parameters():
    param.requires_grad = True

model = model.to(device)

# ❗ Optimizer on transformer parameters
optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.transformer.parameters()), lr=lr)

# === Dataloader ===
dataset = CocoCountingDataset(split="train", tokenizer=model.tokenizer, max_length=max_length)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

# === Training ===
model.train()
for epoch in range(epochs):
    total_loss = 0
    for input_ids, attention_mask in tqdm(dataloader):
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)

        outputs = model.transformer(input_ids=input_ids, attention_mask=attention_mask)
        embeddings = outputs.last_hidden_state

        # Simple L2 loss
        loss = torch.mean(torch.norm(embeddings, dim=-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1}/{epochs}: Loss={total_loss/len(dataloader):.4f}")

# === Save the fine-tuned transformer ===
torch.save(model.transformer.state_dict(), save_path)
print(f"Fine-tuned text encoder saved to {save_path}")
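
A quick way to inspect the effect of this objective (a sketch, assuming ./clip_rope_finetuned.pth was produced by the script above): compare the embedding norms of the same counting prompt before and after fine-tuning. Since the loss directly minimizes the mean token-embedding norm, the tuned norms should come out smaller.

# Sketch: compare a counting prompt's embedding norms before and after fine-tuning.
import torch
from ldm.modules.encoders.modules import FrozenCLIPEmbedder

device = "cuda" if torch.cuda.is_available() else "cpu"
prompt = ["two dogs playing in a park"]

base = FrozenCLIPEmbedder(version="openai/clip-vit-large-patch14", device=device, max_length=77).to(device)
tuned = FrozenCLIPEmbedder(version="openai/clip-vit-large-patch14", device=device, max_length=77)
tuned.transformer.load_state_dict(torch.load("./clip_rope_finetuned.pth", map_location=device))
tuned = tuned.to(device)

with torch.no_grad():
    e_base, e_tuned = base(prompt), tuned(prompt)

print(e_base.norm(dim=-1).mean().item(), e_tuned.norm(dim=-1).mean().item())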
