remove sentence transformers dependency #35

Merged · 4 commits · Sep 26, 2024
22 changes: 13 additions & 9 deletions model2vec/distill/distillation.py
@@ -3,7 +3,7 @@
import numpy as np
from huggingface_hub import model_info
from sklearn.decomposition import PCA
from tokenizers import Tokenizer
from transformers import AutoModel, AutoTokenizer, PreTrainedModel, PreTrainedTokenizerFast

from model2vec.distill.inference import (
create_output_embeddings_from_model_name,
@@ -49,12 +49,15 @@ def distill(
)

# Load original tokenizer. We need to keep this to tokenize any tokens in the vocabulary.
original_tokenizer: Tokenizer = Tokenizer.from_pretrained(model_name)
original_tokenizer: PreTrainedTokenizerFast = AutoTokenizer.from_pretrained(model_name)
original_model: PreTrainedModel = AutoModel.from_pretrained(model_name)
# Make a base list of tokens.
tokens: list[str] = []
if use_subword:
# Create the subword embeddings.
tokens, embeddings = create_output_embeddings_from_model_name(model_name, device=device)
tokens, embeddings = create_output_embeddings_from_model_name(
model=original_model, tokenizer=original_tokenizer, device=device
)

# Remove any unused tokens from the tokenizer and embeddings.
wrong_tokens = [x for x in tokens if x.startswith("[unused")]
@@ -68,15 +71,17 @@
logger.info(f"Removed {len(wrong_tokens)} unused tokens from the tokenizer and embeddings.")
else:
# We need to keep the unk token in the tokenizer.
unk_token = original_tokenizer.model.unk_token
unk_token = original_tokenizer.backend_tokenizer.model.unk_token
# Remove all tokens except the UNK token.
new_tokenizer = remove_tokens(original_tokenizer, list(set(original_tokenizer.get_vocab()) - {unk_token}))
new_tokenizer = remove_tokens(
original_tokenizer.backend_tokenizer, list(set(original_tokenizer.get_vocab()) - {unk_token})
)
# We need to set embeddings to None because we don't know the dimensions of the embeddings yet.
embeddings = None

if vocabulary is not None:
# Preprocess the vocabulary with the original tokenizer.
preprocessed_vocabulary = preprocess_vocabulary(original_tokenizer, vocabulary)
preprocessed_vocabulary = preprocess_vocabulary(original_tokenizer.backend_tokenizer, vocabulary)
n_tokens_before = len(preprocessed_vocabulary)
# Clean the vocabulary by removing duplicate tokens and tokens that are in the subword vocabulary.
cleaned_vocabulary = _clean_vocabulary(preprocessed_vocabulary, tokens)
@@ -88,11 +93,10 @@
if cleaned_vocabulary:
# Create the embeddings.
_, token_embeddings = create_output_embeddings_from_model_name_and_tokens(
model_name=model_name,
model=original_model,
tokenizer=original_tokenizer,
tokens=cleaned_vocabulary,
device=device,
output_value="token_embeddings",
include_eos_bos=False,
)

# If we don't have subword tokens, we still need to create
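Taken together, the distillation.py changes load the backing model and tokenizer once through transformers and pass both objects down to the embedding helpers, reaching the raw tokenizers object via backend_tokenizer where needed. A minimal sketch of that flow, assuming only the import path and helper signature visible in this diff (the model name is a placeholder, not a recommendation):

```python
# Sketch of the refactored flow in distill(); not a drop-in copy of the function.
from transformers import AutoModel, AutoTokenizer

from model2vec.distill.inference import create_output_embeddings_from_model_name

model_name = "bert-base-uncased"  # placeholder: any Hugging Face encoder
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

# Subword path: one forward pass over the whole tokenizer vocabulary,
# with no sentence-transformers wrapper in between.
tokens, embeddings = create_output_embeddings_from_model_name(
    model=model, tokenizer=tokenizer, device="cpu"
)

# The vocabulary helpers still expect a raw `tokenizers.Tokenizer`,
# which the fast Hugging Face tokenizer exposes as `backend_tokenizer`.
backend = tokenizer.backend_tokenizer
unk_token = backend.model.unk_token
```
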
112 changes: 59 additions & 53 deletions model2vec/distill/inference.py
@@ -1,13 +1,12 @@
# -*- coding: utf-8 -*-
import logging
from pathlib import Path
from typing import Literal, Protocol, cast
from typing import Protocol

import numpy as np
import torch
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer, PreTrainedModel, PreTrainedTokenizer
from transformers import PreTrainedModel, PreTrainedTokenizer
from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions

logger = logging.getLogger(__name__)
@@ -17,93 +16,100 @@

_DEFAULT_BATCH_SIZE = 1024

OutputValue = Literal["sentence_embedding", "token_embeddings"]


class ModulewithWeights(Protocol):
weight: torch.nn.Parameter


def create_output_embeddings_from_model_name_and_tokens(
model_name: PathLike,
model: PreTrainedModel,
tokenizer: PreTrainedTokenizer,
tokens: list[str],
device: str,
output_value: Literal["sentence_embedding", "token_embeddings"],
include_eos_bos: bool,
) -> tuple[list[str], np.ndarray]:
"""
Create output embeddings for a bunch of tokens from a model name.

It does a forward pass for all tokens passed in tokens.

:param model_name: The model name to use.
:param model: The model to use.
:param tokenizer: The tokenizer to use.
:param tokens: The tokens to use.
:param device: The torch device to use.
:param output_value: The output value to pass to sentence transformers. If this is 'sentence_embedding', get pooled output, if this is 'token_embedding', get token means.
:param include_eos_bos: Whether to include the eos and bos tokens in the mean. Only applied if output_value == "token_embeddings".
:return: The tokens and output emnbeddings.
:return: The tokens and output embeddings.
"""
embedder = SentenceTransformer(str(model_name), device=device)

embedder_output_dim = _get_embedder_output_dim(output_value, embedder)
model = model.to(device)

out_weights: np.ndarray
if output_value == "token_embeddings":
intermediate_weights: list[np.ndarray] = []
# NOTE: because tokens might be really long, and we want to take the mean anyway, we need to batch.
# otherwise we could go OOM.
for batch_idx in tqdm(range(0, len(tokens), _DEFAULT_BATCH_SIZE)):
batch = tokens[batch_idx : batch_idx + _DEFAULT_BATCH_SIZE]
out: list[torch.Tensor] = cast(
list[torch.Tensor], embedder.encode(batch, show_progress_bar=False, output_value=output_value)
)
for idx, token_vectors in enumerate(out):
if not include_eos_bos:
# NOTE: remove BOS/EOS
token_vectors = token_vectors[1:-1]
if len(token_vectors) == 0:
str_repr = batch[idx]
bytes_repr = str_repr.encode("utf-8")
logger.warning(f"Got empty token vectors for word `{str_repr}` with bytes `{bytes_repr!r}`")
mean_vector = np.zeros(embedder_output_dim)
else:
mean_vector = cast(np.ndarray, token_vectors.cpu().numpy()).mean(0)
intermediate_weights.append(mean_vector)
out_weights = np.stack(intermediate_weights)
else:
out_weights = cast(
np.ndarray,
embedder.encode(tokens, show_progress_bar=True, output_value=output_value, batch_size=_DEFAULT_BATCH_SIZE),
)
intermediate_weights: list[np.ndarray] = []

for batch_idx in tqdm(range(0, len(tokens), _DEFAULT_BATCH_SIZE)):
batch = tokens[batch_idx : batch_idx + _DEFAULT_BATCH_SIZE]
out = _encode_mean_using_model(model, tokenizer, batch)
intermediate_weights.append(out.numpy())
out_weights = np.concatenate(intermediate_weights)

return tokens, out_weights


def _get_embedder_output_dim(output_value: OutputValue, embedder: SentenceTransformer) -> int:
"""Get the embeddings dimension of a sentence transformer, given an output value."""
embedder_output_dim = embedder.encode(["a"], show_progress_bar=False, output_value=output_value)[0].shape[1]
@torch.no_grad()
def _encode_mean_using_model(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, tokens: list[str]) -> torch.Tensor:
"""
Encode a batch of tokens using a model.

Note that if a token in the input batch does not have any embeddings, it will be output as a vector of zeros.
So detection of these is necessary.

:param model: The model to use.
:param tokenizer: The tokenizer to use.
:param tokens: The tokens to encode.
:return: The mean of the output for each token.
"""
encodings = tokenizer(tokens, return_tensors="pt", padding=True, truncation=True).to(model.device)
encoded: BaseModelOutputWithPoolingAndCrossAttentions = model(**encodings)
out = encoded.last_hidden_state.cpu()

mask = encodings["attention_mask"].cpu()
# NOTE: evil hack. For any batch, there will be a mask vector
# which has all 1s, because we pad to max_length. argmin returns 0
# in this case, which is wrong. But because we end up subtracting 1
# from it, we use -1, which is correct.
last_nonzero_index = mask.argmin(1) - 1
# NOTE: do not change the order of these calls. If you do, the hack
# above will no longer be evil (it will turn good), and will no longer work.
mask[torch.arange(mask.shape[0]), last_nonzero_index] = 0
mask[:, 0] = 0

# We take the mean of embeddings by first summing
result = torch.bmm(mask[:, None, :].float(), out).squeeze(1)

# Divide by the number of non-padding tokens, non-cls, etc. tokens.
divisor = mask.sum(1)
# Account for the case where divisor is 0.
divisor[divisor == 0] = 1

return embedder_output_dim
return result / divisor[:, None]


def create_output_embeddings_from_model_name(
model_name: PathLike,
model: PreTrainedModel,
tokenizer: PreTrainedTokenizer,
device: str,
) -> tuple[list[str], np.ndarray]:
"""
Create output embeddings for a bunch of tokens from a model name.

It does a forward pass for all ids in the tokenizer.

:param model_name: The model name to use.
:param model: The model to use.
:param tokenizer: The tokenizer to use.
:param device: The torch device to use.
:return: The tokens and output emnbeddings.
:return: The tokens and output embeddings.
"""
model: PreTrainedModel = AutoModel.from_pretrained(model_name).to(device)
tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(model_name)
model = model.to(device)
ids = torch.arange(tokenizer.vocab_size)

# Work-around
# Work-around to get the eos and bos token ids without having to go into tokenizer internals.
dummy_encoding = tokenizer.encode("A")
eos_token_id, bos_token_id = dummy_encoding[0], dummy_encoding[-1]

@@ -114,7 +120,7 @@ def create_output_embeddings_from_model_name(

intermediate_weights: list[np.ndarray] = []
for batch_idx in tqdm(range(0, len(stacked), _DEFAULT_BATCH_SIZE)):
batch = stacked[batch_idx : batch_idx + _DEFAULT_BATCH_SIZE]
batch = stacked[batch_idx : batch_idx + _DEFAULT_BATCH_SIZE].to(model.device)
with torch.no_grad():
encoded: BaseModelOutputWithPoolingAndCrossAttentions = model(input_ids=batch.to(device))
out: torch.Tensor = encoded.last_hidden_state
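The heart of the rewrite is _encode_mean_using_model above: instead of asking sentence-transformers for token embeddings, it runs a plain forward pass and mean-pools the last hidden state itself, editing the attention mask so that padding, the first token (CLS/BOS), and the last real token (SEP/EOS) are excluded. A small self-contained sketch of that masked mean, on made-up tensors rather than real model output:

```python
# Masked mean pooling as done in _encode_mean_using_model, on a fake batch.
# Shapes and the mask edits mirror the diff; the numbers are made up.
import torch

batch_size, seq_len, dim = 2, 5, 3
# Pretend last_hidden_state from the model: (batch, seq, dim).
out = torch.arange(batch_size * seq_len * dim, dtype=torch.float32).reshape(batch_size, seq_len, dim)
# Row 0 is fully unpadded (all ones); row 1 has two padding positions.
mask = torch.tensor([[1, 1, 1, 1, 1],
                     [1, 1, 1, 0, 0]])

# argmin on an all-ones row returns 0, so "- 1" wraps to -1, which still points
# at the last (EOS/SEP) position -- this is the "evil hack" from the diff.
last_nonzero_index = mask.argmin(1) - 1
mask[torch.arange(mask.shape[0]), last_nonzero_index] = 0  # drop EOS/SEP
mask[:, 0] = 0                                              # drop BOS/CLS

# Masked sum via batched matmul, then divide by the number of kept tokens.
summed = torch.bmm(mask[:, None, :].float(), out).squeeze(1)
divisor = mask.sum(1).clamp(min=1)
mean = summed / divisor[:, None]

# Row 0 keeps positions 1..3, row 1 keeps only position 1.
assert torch.allclose(mean[0], out[0, 1:4].mean(0))
assert torch.allclose(mean[1], out[1, 1])
```

The divisor guard here uses clamp(min=1), which has the same effect as the in-place divisor[divisor == 0] = 1 in the diff.
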
1 change: 0 additions & 1 deletion pyproject.toml
@@ -29,7 +29,6 @@ dependencies = [
"typer",
"transformers",
"torch",
"sentence_transformers",
"tokenizers",
"scikit-learn",
"setuptools",
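With sentence_transformers removed from the dependency list, distillation runs on the plain transformers/torch stack. A hedged end-to-end sketch follows; the import path matches the file edited above and the keyword names follow the snippets in this diff, but the full distill() signature is not shown here, so treat them as assumptions:

```python
# Distill a static embedding model without sentence-transformers installed.
from model2vec.distill.distillation import distill

static_model = distill(
    model_name="bert-base-uncased",  # placeholder encoder
    vocabulary=None,                 # subword-only distillation, no extra vocabulary
    device="cpu",
)
```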