
Commit: Update
Pringled committed Feb 8, 2025
1 parent c2046f6 commit 619b36e
Showing 1 changed file with 29 additions and 0 deletions.
29 changes: 29 additions & 0 deletions model2vec/train/tokenlearn.py
@@ -9,6 +9,7 @@
from lightning.pytorch.callbacks import EarlyStopping
from sklearn.decomposition import PCA
from tokenizers import Tokenizer
from torch import nn

from model2vec import StaticModel
from model2vec.train.base import FinetunableStaticModel, ModelType, TextDataset
@@ -60,6 +61,33 @@ def __init__(
self.lr_linear = lr_linear
self.cosine_weight = cosine_weight
self.mse_weight = mse_weight
self.w = self.construct_weights()

def construct_weights(self) -> nn.Parameter:
"""Construct the weights for the model."""
weights = torch.ones(len(self.vectors)) # Change from zeros to ones
weights[self.pad_id] = 0 # Make sure padding gets ignored
return nn.Parameter(weights)

def sub_forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass through the mean pooling."""
w = self.w[x]
zeros = (x != self.pad_id).float()
length = zeros.sum(1)
embedded = self.embeddings(x)

# Zero out the padding
embedded = embedded * zeros[:, :, None]
embedded = (embedded * w[:, :, None]).sum(1) / (w.sum(1)[:, None] + 1e-16)

embedded = embedded / length[:, None]

return embedded

def forward(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Forward pass through the mean pooling, and a classifier layer after."""
encoded = self.sub_forward(input_ids)
return self.head(encoded), encoded

def training_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
"""The training step for the model."""
@@ -169,6 +197,7 @@ def main() -> None:
"""Run the tokenlearn training script."""
# Initialize a StaticModel
model = StaticModel.from_pretrained("minishlab/M2V_base_output")
model.normalize = True

# Collect paths for training data
paths = sorted(Path("../tokenlearn/data/c4_features_bgebase_test").glob("*.json"))
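For context, below is a minimal standalone sketch of the learned weighted mean pooling that the added construct_weights and sub_forward methods implement. The vocabulary size, embedding dimension, padding id, and toy batch are illustrative assumptions, not values taken from the repository:

import torch
from torch import nn

# Illustrative values only; the real model derives these from its tokenizer and vectors.
vocab_size, embed_dim, pad_id = 10, 4, 0

embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_id)

# Learnable per-token weights, initialized to ones, with the padding token pinned to zero.
w = nn.Parameter(torch.ones(vocab_size))
with torch.no_grad():
    w[pad_id] = 0

# A toy batch of token ids, padded with pad_id.
x = torch.tensor([[3, 5, 7, 0], [2, 9, 0, 0]])

token_w = w[x]                          # (batch, seq): learned weight per token
mask = (x != pad_id).float()            # 1 for real tokens, 0 for padding
length = mask.sum(1)                    # number of non-padding tokens per sequence

embedded = embeddings(x) * mask[:, :, None]   # zero out padding embeddings

# Weighted mean over the sequence, followed by the additional division by length
# that the commit applies on top of it.
pooled = (embedded * token_w[:, :, None]).sum(1) / (token_w.sum(1)[:, None] + 1e-16)
pooled = pooled / length[:, None]

print(pooled.shape)  # torch.Size([2, 4])

As in the commit, the pooled embedding is divided both by the sum of the per-token weights and by the number of non-padding tokens in each sequence.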
