From 9a4cbb1ca5ee2626624633306b7994897de6bc86 Mon Sep 17 00:00:00 2001
From: stephantul
Date: Wed, 26 Feb 2025 13:01:16 +0100
Subject: [PATCH] feat: styling, fixing some bugs

---
 model2vec/knowledge_distill/featurize.py     | 21 ++---
 .../knowledge_distillation.py                | 83 ++++++-------
 model2vec/knowledge_distill/utils.py         | 22 ++---
 3 files changed, 39 insertions(+), 87 deletions(-)

diff --git a/model2vec/knowledge_distill/featurize.py b/model2vec/knowledge_distill/featurize.py
index b7321de..6067d16 100644
--- a/model2vec/knowledge_distill/featurize.py
+++ b/model2vec/knowledge_distill/featurize.py
@@ -57,35 +57,25 @@ def featurize(
     seen = set()
     total_means = 0
+    base_filename = None
     for index, batch in enumerate(tqdm(batched(texts, batch_size))):
         i = index // _SAVE_INTERVAL
         base_filename = f"featurized_{i}"
-        list_batch = [x["text"].strip() for x in batch if x.get("text")]
-        if not list_batch:
-            continue  # Skip empty batches
-
+        list_batch = list(batch)
         # Encode the batch to get token embeddings
-        token_embeddings = model.encode(
-            list_batch,
-            output_value="token_embeddings",
-            convert_to_tensor=True,
-        )
+        token_embeddings = model.encode(list_batch, output_value="token_embeddings", convert_to_numpy=True)
         # Tokenize the batch to get input IDs
         tokenized_ids = model.tokenize(list_batch)["input_ids"]
         for tokenized_id, token_embedding in zip(tokenized_ids, token_embeddings):
-            # Convert token IDs to tokens (excluding special tokens)
-            token_ids = tokenized_id[1:-1]
             # Decode tokens to text
-            text = model.tokenizer.decode(tokenized_id, skip_special_tokens=True)
+            text = model.tokenizer.decode(tokenized_ids, skip_special_tokens=True)
             if text in seen:
                 continue
             seen.add(text)
             # Get the corresponding token embeddings (excluding special tokens)
             token_embeds = token_embedding[1:-1]
-            # Convert embeddings to NumPy arrays
-            token_embeds = token_embeds.detach().cpu().numpy()
             # Compute the mean of the token embeddings
             mean = np.mean(token_embeds, axis=0)
             txts.append(text)
@@ -102,6 +92,9 @@ def featurize(
             means = []
             seen = set()
     else:
+        # This happens if there are fewer than _SAVE_INTERVAL texts.
+        if base_filename is None:
+            base_filename = "featurized_0"
         if txts and means:
             save_data(means, txts, str(out_path / base_filename))
diff --git a/model2vec/knowledge_distill/knowledge_distillation.py b/model2vec/knowledge_distill/knowledge_distillation.py
index b4204c8..10c3d18 100644
--- a/model2vec/knowledge_distill/knowledge_distillation.py
+++ b/model2vec/knowledge_distill/knowledge_distillation.py
@@ -1,5 +1,6 @@
 import logging
 from pathlib import Path
+from tempfile import TemporaryDirectory
 from typing import Any, Optional, Type
 
 import lightning as pl
@@ -9,7 +10,6 @@
 from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
 from sklearn.decomposition import PCA
 from tokenizers import Tokenizer
-from torch import nn
 
 from model2vec import StaticModel
 from model2vec.knowledge_distill.utils import calculate_token_probabilities, collect_means_and_texts
@@ -17,14 +17,8 @@
 logger = logging.getLogger(__name__)
 
 
-class KnowledgeDistillationDataset(TextDataset):
-    """Dataset class for Knowledge Distillation training."""
-
-    def __init__(self, texts: list[str], targets: torch.Tensor, tokenizer: Tokenizer) -> None:
-        """Initialize a Knowledge Distillation dataset."""
-        tokenized_texts = [encoding.ids for encoding in tokenizer.encode_batch_fast(texts, add_special_tokens=False)]
-        super().__init__(tokenized_texts, targets)
+# By default, train on 512 token chunks.
+_MAX_LENGTH = 512
 
 
 class KnowledgeDistillationModel(FinetunableStaticModel, pl.LightningModule):
@@ -63,32 +57,6 @@ def __init__(
         self.mse_weight = mse_weight
         self.w = self.construct_weights()
 
-    def construct_weights(self) -> nn.Parameter:
-        """Construct the weights for the model."""
-        weights = torch.ones(len(self.vectors))  # Change from zeros to ones
-        weights[self.pad_id] = 0  # Make sure padding gets ignored
-        return nn.Parameter(weights)
-
-    def sub_forward(self, x: torch.Tensor) -> torch.Tensor:
-        """Forward pass through the mean pooling."""
-        w = self.w[x]
-        zeros = (x != self.pad_id).float()
-        length = zeros.sum(1)
-        embedded = self.embeddings(x)
-
-        # Zero out the padding
-        embedded = embedded * zeros[:, :, None]
-        embedded = (embedded * w[:, :, None]).sum(1) / (w.sum(1)[:, None])
-
-        embedded = embedded / length[:, None]
-
-        return embedded
-
-    def forward(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
-        """Forward pass through the mean pooling, and a classifier layer after."""
-        encoded = self.sub_forward(input_ids)
-        return self.head(encoded), encoded
-
     def training_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
         """The training step for the model."""
         input_ids, target_vectors = batch
@@ -151,24 +119,24 @@ def fit(
         if patience is not None:
             callbacks.append(EarlyStopping(monitor="train_loss", mode="min", patience=patience))
 
-        checkpoint_callback = ModelCheckpoint(
-            monitor="train_loss",
-            mode="min",
-            save_top_k=1,
-            dirpath="checkpoints/",
-            filename="best_model",
-        )
-        callbacks.append(checkpoint_callback)
-
-        train_loader = dataset.to_dataloader(batch_size=batch_size, shuffle=True)
-        trainer = pl.Trainer(max_epochs=max_epochs, accelerator=device, callbacks=callbacks)
-        trainer.fit(self, train_loader)
-
-        # Load the best checkpoint after training
-        best_model_path = checkpoint_callback.best_model_path
-        if best_model_path:
-            logger.info(f"Loading best model from {best_model_path}")
-            self.load_state_dict(torch.load(best_model_path)["state_dict"])
+        with TemporaryDirectory() as tempdir:
+            checkpoint_callback = ModelCheckpoint(
+                monitor="train_loss",
+                mode="min",
+                save_top_k=1,
+                dirpath=tempdir,
+            )
+            callbacks.append(checkpoint_callback)
+
+            train_loader = dataset.to_dataloader(batch_size=batch_size, shuffle=True)
+            trainer = pl.Trainer(max_epochs=max_epochs, accelerator=device, callbacks=callbacks)
+            trainer.fit(self, train_loader)
+
+            # Load the best checkpoint after training
+            best_model_path = checkpoint_callback.best_model_path
+            if best_model_path:
+                logger.info(f"Loading best model from {best_model_path}")
+                self.load_state_dict(torch.load(best_model_path, weights_only=True)["state_dict"])
 
     def apply_weighting(self, texts: list[str], alpha: float = 1e-3, pca_dims: int = 256) -> StaticModel:
         """
@@ -206,7 +174,7 @@ def save_pretrained(self, save_directory: str) -> None:
         """Convert the trained model to a StaticModel and save it."""
         final_static = self.to_static_model()
         final_static.save_pretrained(save_directory)
-        logger.info(f"Saved TokenlearnModel as a static model to '{save_directory}'")
+        logger.info(f"Saved Tokenlearn model as a static model to '{save_directory}'")
 
 
 def main() -> None:
@@ -217,7 +185,7 @@ def main() -> None:
     # Collect paths for training data
     paths = sorted(Path("../tokenlearn/data/c4_features_bgebase_test").glob("*.json"))
-    X, y = collect_means_and_texts(paths)
+    texts, y = collect_means_and_texts(paths)
 
     # Detect device
     device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -226,7 +194,8 @@ def main() -> None:
     y_tensor = torch.tensor(y, dtype=torch.float32, device=device)
 
     # Convert to TokenlearnDataset
-    dataset = KnowledgeDistillationDataset(X, y_tensor, tokenizer=model.tokenizer)
+    tokenized = model.tokenize(texts)
+    dataset = TextDataset(tokenized, y_tensor)
 
     # Create a TokenlearnModel from the StaticModel
     tokenlearn_model = KnowledgeDistillationModel.from_static_model(model, out_dim=y_tensor.shape[1])
@@ -236,7 +205,7 @@ def main() -> None:
     tokenlearn_model.fit(dataset, batch_size=256, max_epochs=50, device=device)
 
     # Apply SIF weighting + PCA to the embeddings
-    tokenlearn_model.apply_weighting(X, alpha=1e-3, pca_dims=256)
+    tokenlearn_model.apply_weighting(texts, alpha=1e-3, pca_dims=256)
 
     # Save the final static model
     tokenlearn_model.save_pretrained("models/potion-base-8M-reproduce-v1")
diff --git a/model2vec/knowledge_distill/utils.py b/model2vec/knowledge_distill/utils.py
index e3a5893..da247d3 100644
--- a/model2vec/knowledge_distill/utils.py
+++ b/model2vec/knowledge_distill/utils.py
@@ -2,10 +2,8 @@
 import logging
 from collections import Counter
 from pathlib import Path
-from typing import Any
 
 import numpy as np
-import regex
 from more_itertools import batched
 from tokenizers import Tokenizer
 from tqdm import tqdm
@@ -13,26 +11,26 @@
 logger = logging.getLogger(__name__)
 
 
-def create_vocab(texts: list[str], vocab_size: int = 56_000) -> list[str]:
+def create_vocab(texts: list[str], tokenizer: Tokenizer, vocab_size: int = 56_000) -> tuple[str, ...]:
     """
     Create a vocabulary from a list of texts.
 
     :param texts: The list of texts to create the vocabulary from.
+    :param tokenizer: The tokenizer to use.
     :param vocab_size: The size of the vocabulary. Defaults to 56,000, which is the vocab_size used for our 32M models.
     :return: The vocabulary.
     """
-    tokenizer_regex = regex.compile(r"\w+|[^\w\s]+")
-
     # Tokenize all texts
-    tokens = []
+    tokens: list[str] = []
     for text in tqdm(texts, desc="Tokenizing texts"):
-        tokens.extend(tokenizer_regex.findall(text.lower()))
+        _, toks = zip(*tokenizer.pre_tokenizer.pre_tokenize_str(text))
+        tokens.extend(toks)
 
     # Count the tokens
     token_counts = Counter(tokens)
 
     # Get the most common tokens as the vocabulary
-    vocab = [word for word, _ in token_counts.most_common(vocab_size)]
+    vocab, _ = zip(*token_counts.most_common(vocab_size))
 
     return vocab
@@ -91,11 +89,3 @@ def calculate_token_probabilities(tokenizer: Tokenizer, txt: list[str]) -> np.ndarray:
     x /= sum_id
 
     return x
-
-
-# def load_dataset(path: str):
-#     """Load a dataset from a file."""
-#     # Collect paths for training data
-#     paths = sorted(Path(path).glob("*.json"))
-#     X, y = collect_means_and_texts(paths)
-#     return X, y