feat: styling, fixing some bugs #201

Merged (1 commit) on Feb 27, 2025
Changes from all commits
model2vec/knowledge_distill/featurize.py (7 additions, 14 deletions)

@@ -57,35 +57,25 @@ def featurize(
     seen = set()
     total_means = 0
 
+    base_filename = None
     for index, batch in enumerate(tqdm(batched(texts, batch_size))):
         i = index // _SAVE_INTERVAL
         base_filename = f"featurized_{i}"
-        list_batch = [x["text"].strip() for x in batch if x.get("text")]
-        if not list_batch:
-            continue  # Skip empty batches
-
+        list_batch = list(batch)
         # Encode the batch to get token embeddings
-        token_embeddings = model.encode(
-            list_batch,
-            output_value="token_embeddings",
-            convert_to_tensor=True,
-        )
+        token_embeddings = model.encode(list_batch, output_value="token_embeddings", convert_to_numpy=True)
 
         # Tokenize the batch to get input IDs
         tokenized_ids = model.tokenize(list_batch)["input_ids"]
 
         for tokenized_id, token_embedding in zip(tokenized_ids, token_embeddings):
-            # Convert token IDs to tokens (excluding special tokens)
-            token_ids = tokenized_id[1:-1]
             # Decode tokens to text
-            text = model.tokenizer.decode(tokenized_id, skip_special_tokens=True)
+            text = model.tokenizer.decode(tokenized_ids, skip_special_tokens=True)
             if text in seen:
                 continue
             seen.add(text)
             # Get the corresponding token embeddings (excluding special tokens)
             token_embeds = token_embedding[1:-1]
-            # Convert embeddings to NumPy arrays
-            token_embeds = token_embeds.detach().cpu().numpy()
             # Compute the mean of the token embeddings
             mean = np.mean(token_embeds, axis=0)
             txts.append(text)
@@ -102,6 +92,9 @@ def featurize(
                 means = []
                 seen = set()
     else:
+        # This happens if there are fewer than _SAVE_INTERVAL texts.
+        if base_filename is None:
+            base_filename = "featurized_0"
         if txts and means:
             save_data(means, txts, str(out_path / base_filename))

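Note: with these changes, featurize mean-pools the token embeddings of each unique text and flushes the accumulated means to disk every _SAVE_INTERVAL batches, with base_filename initialized up front so the final save cannot hit an undefined name. A minimal sketch of the pooling step in isolation, assuming a SentenceTransformer-style model (the model name and texts are illustrative, not part of this PR):

    import numpy as np
    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer("BAAI/bge-base-en-v1.5")  # illustrative model choice
    texts = ["an example sentence", "another example"]

    # One (seq_len, dim) array of token embeddings per input text
    token_embeddings = model.encode(texts, output_value="token_embeddings", convert_to_numpy=True)
    input_ids = model.tokenize(texts)["input_ids"]

    means: dict[str, np.ndarray] = {}
    for ids, emb in zip(input_ids, token_embeddings):
        # Decode this text's own ids back to a string key
        key = model.tokenizer.decode(ids, skip_special_tokens=True)
        if key not in means:
            # Drop the special-token positions before averaging
            means[key] = np.mean(emb[1:-1], axis=0)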
model2vec/knowledge_distill/knowledge_distillation.py (26 additions, 57 deletions)

@@ -1,5 +1,6 @@
 import logging
 from pathlib import Path
+from tempfile import TemporaryDirectory
 from typing import Any, Optional, Type
 
 import lightning as pl
@@ -9,22 +10,15 @@
 from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
 from sklearn.decomposition import PCA
 from tokenizers import Tokenizer
-from torch import nn
 
 from model2vec import StaticModel
 from model2vec.knowledge_distill.utils import calculate_token_probabilities, collect_means_and_texts
 from model2vec.train.base import FinetunableStaticModel, ModelType, TextDataset
 
 logger = logging.getLogger(__name__)
 
-
-class KnowledgeDistillationDataset(TextDataset):
-    """Dataset class for Knowledge Distillation training."""
-
-    def __init__(self, texts: list[str], targets: torch.Tensor, tokenizer: Tokenizer) -> None:
-        """Initialize a Knowledge Distillation dataset."""
-        tokenized_texts = [encoding.ids for encoding in tokenizer.encode_batch_fast(texts, add_special_tokens=False)]
-        super().__init__(tokenized_texts, targets)
+# By default, train on 512 token chunks.
+_MAX_LENGTH = 512
 
 
 class KnowledgeDistillationModel(FinetunableStaticModel, pl.LightningModule):
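Note: the deleted KnowledgeDistillationDataset is replaced by tokenizing up front and handing plain token-id sequences to the generic TextDataset, as main() below now does via model.tokenize. A sketch of the same construction using the tokenizers API directly (the tokenizer choice and target vectors are illustrative):

    import torch
    from tokenizers import Tokenizer

    from model2vec.train.base import TextDataset

    texts = ["an example sentence", "another example"]
    targets = torch.randn(len(texts), 768)  # illustrative target mean embeddings

    # Tokenize once, outside the dataset class
    tokenizer = Tokenizer.from_pretrained("bert-base-uncased")  # illustrative
    tokenized = [enc.ids for enc in tokenizer.encode_batch(texts)]

    dataset = TextDataset(tokenized, targets)
    loader = dataset.to_dataloader(batch_size=2, shuffle=True)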
@@ -63,32 +57,6 @@ def __init__(
         self.mse_weight = mse_weight
         self.w = self.construct_weights()
 
-    def construct_weights(self) -> nn.Parameter:
-        """Construct the weights for the model."""
-        weights = torch.ones(len(self.vectors))  # Change from zeros to ones
-        weights[self.pad_id] = 0  # Make sure padding gets ignored
-        return nn.Parameter(weights)
-
-    def sub_forward(self, x: torch.Tensor) -> torch.Tensor:
-        """Forward pass through the mean pooling."""
-        w = self.w[x]
-        zeros = (x != self.pad_id).float()
-        length = zeros.sum(1)
-        embedded = self.embeddings(x)
-
-        # Zero out the padding
-        embedded = embedded * zeros[:, :, None]
-        embedded = (embedded * w[:, :, None]).sum(1) / (w.sum(1)[:, None])
-
-        embedded = embedded / length[:, None]
-
-        return embedded
-
-    def forward(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
-        """Forward pass through the mean pooling, and a classifier layer after."""
-        encoded = self.sub_forward(input_ids)
-        return self.head(encoded), encoded
-
     def training_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
         """The training step for the model."""
         input_ids, target_vectors = batch
@@ -151,24 +119,24 @@ def fit(
         if patience is not None:
             callbacks.append(EarlyStopping(monitor="train_loss", mode="min", patience=patience))
 
-        checkpoint_callback = ModelCheckpoint(
-            monitor="train_loss",
-            mode="min",
-            save_top_k=1,
-            dirpath="checkpoints/",
-            filename="best_model",
-        )
-        callbacks.append(checkpoint_callback)
-
-        train_loader = dataset.to_dataloader(batch_size=batch_size, shuffle=True)
-        trainer = pl.Trainer(max_epochs=max_epochs, accelerator=device, callbacks=callbacks)
-        trainer.fit(self, train_loader)
-
-        # Load the best checkpoint after training
-        best_model_path = checkpoint_callback.best_model_path
-        if best_model_path:
-            logger.info(f"Loading best model from {best_model_path}")
-            self.load_state_dict(torch.load(best_model_path)["state_dict"])
+        with TemporaryDirectory() as tempdir:
+            checkpoint_callback = ModelCheckpoint(
+                monitor="train_loss",
+                mode="min",
+                save_top_k=1,
+                dirpath=tempdir,
+            )
+            callbacks.append(checkpoint_callback)
+
+            train_loader = dataset.to_dataloader(batch_size=batch_size, shuffle=True)
+            trainer = pl.Trainer(max_epochs=max_epochs, accelerator=device, callbacks=callbacks)
+            trainer.fit(self, train_loader)
+
+            # Load the best checkpoint after training
+            best_model_path = checkpoint_callback.best_model_path
+            if best_model_path:
+                logger.info(f"Loading best model from {best_model_path}")
+                self.load_state_dict(torch.load(best_model_path, weights_only=True)["state_dict"])

     def apply_weighting(self, texts: list[str], alpha: float = 1e-3, pca_dims: int = 256) -> StaticModel:
         """
@@ -206,7 +174,7 @@ def save_pretrained(self, save_directory: str) -> None:
"""Convert the trained model to a StaticModel and save it."""
final_static = self.to_static_model()
final_static.save_pretrained(save_directory)
logger.info(f"Saved TokenlearnModel as a static model to '{save_directory}'")
logger.info(f"Saved Tokenlearn model as a static model to '{save_directory}'")


 def main() -> None:
@@ -217,7 +185,7 @@ def main() -> None:

     # Collect paths for training data
     paths = sorted(Path("../tokenlearn/data/c4_features_bgebase_test").glob("*.json"))
-    X, y = collect_means_and_texts(paths)
+    texts, y = collect_means_and_texts(paths)
 
     # Detect device
     device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -226,7 +194,8 @@ def main() -> None:
     y_tensor = torch.tensor(y, dtype=torch.float32, device=device)
 
     # Convert to TokenlearnDataset
-    dataset = KnowledgeDistillationDataset(X, y_tensor, tokenizer=model.tokenizer)
+    tokenized = model.tokenize(texts)
+    dataset = TextDataset(tokenized, y_tensor)
 
     # Create a TokenlearnModel from the StaticModel
     tokenlearn_model = KnowledgeDistillationModel.from_static_model(model, out_dim=y_tensor.shape[1])
@@ -236,7 +205,7 @@ def main() -> None:
     tokenlearn_model.fit(dataset, batch_size=256, max_epochs=50, device=device)
 
     # Apply SIF weighting + PCA to the embeddings
-    tokenlearn_model.apply_weighting(X, alpha=1e-3, pca_dims=256)
+    tokenlearn_model.apply_weighting(texts, alpha=1e-3, pca_dims=256)
 
     # Save the final static model
     tokenlearn_model.save_pretrained("models/potion-base-8M-reproduce-v1")
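Note: apply_weighting implements SIF-style reweighting, w(t) = alpha / (alpha + p(t)) for a token with unigram probability p(t), combined with a PCA projection. A sketch of that arithmetic on random data (illustrative only; the method's exact ordering of PCA and weighting is not shown in this diff):

    import numpy as np
    from sklearn.decomposition import PCA


    def sif_weights(token_probs: np.ndarray, alpha: float = 1e-3) -> np.ndarray:
        """Frequent tokens get weights near 0, rare tokens weights near 1."""
        return alpha / (alpha + token_probs)


    embeddings = np.random.randn(1_000, 64).astype(np.float32)  # (vocab_size, dim) token vectors
    token_probs = np.random.dirichlet(np.ones(1_000)).astype(np.float32)  # unigram probabilities

    reduced = PCA(n_components=32).fit_transform(embeddings)  # project to pca_dims
    weighted = reduced * sif_weights(token_probs)[:, None]  # scale each token vector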
model2vec/knowledge_distill/utils.py (6 additions, 16 deletions)

@@ -2,37 +2,35 @@
 import logging
 from collections import Counter
 from pathlib import Path
-from typing import Any
 
 import numpy as np
-import regex
 from more_itertools import batched
+from tokenizers import Tokenizer
 from tqdm import tqdm
 
 logger = logging.getLogger(__name__)
 
 
-def create_vocab(texts: list[str], vocab_size: int = 56_000) -> list[str]:
+def create_vocab(texts: list[str], tokenizer: Tokenizer, vocab_size: int = 56_000) -> tuple[str, ...]:
     """
     Create a vocabulary from a list of texts.
 
     :param texts: The list of texts to create the vocabulary from.
+    :param tokenizer: The tokenizer to use.
     :param vocab_size: The size of the vocabulary. Defaults to 56,000, which is the vocab_size used for our 32M models.
     :return: The vocabulary.
     """
-    tokenizer_regex = regex.compile(r"\w+|[^\w\s]+")
-
     # Tokenize all texts
-    tokens = []
+    tokens: list[str] = []
     for text in tqdm(texts, desc="Tokenizing texts"):
-        tokens.extend(tokenizer_regex.findall(text.lower()))
+        _, toks = zip(*tokenizer.pre_tokenizer.pre_tokenize_str(text))
+        tokens.extend(toks)
 
     # Count the tokens
     token_counts = Counter(tokens)
 
     # Get the most common tokens as the vocabulary
-    vocab = [word for word, _ in token_counts.most_common(vocab_size)]
+    vocab, _ = zip(*token_counts.most_common(vocab_size))
    return vocab
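Note: pre_tokenize_str on a tokenizers pre_tokenizer returns (token, (start, end)) pairs, with the token string first in each pair. A self-contained sketch of vocabulary creation on that API (the tokenizer choice is illustrative):

    from collections import Counter

    from tokenizers import Tokenizer

    tokenizer = Tokenizer.from_pretrained("bert-base-uncased")  # illustrative
    pairs = tokenizer.pre_tokenizer.pre_tokenize_str("Hello, world! Hello again.")
    toks, _offsets = zip(*pairs)  # the token strings are the first element of each pair

    counts = Counter(toks)
    vocab = [word for word, _ in counts.most_common(56_000)]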


@@ -91,11 +89,3 @@ def calculate_token_probabilities(tokenizer: Tokenizer, txt: list[str]) -> np.ndarray:
     x /= sum_id
 
     return x


-# def load_dataset(path: str):
-#     """Load a dataset from a file."""
-#     # Collect paths for training data
-#     paths = sorted(Path(path).glob("*.json"))
-#     X, y = collect_means_and_texts(paths)
-#     return X, y