Add RepLLaMA style models #1223

Merged · 6 commits · Sep 20, 2024
Changes from 3 commits
4 changes: 4 additions & 0 deletions mteb/models/__init__.py
@@ -20,6 +20,8 @@
mxbai_models,
nomic_models,
openai_models,
promptriever_models,
repllama_models,
ru_sentence_models,
salesforce_models,
sentence_transformers_models,
@@ -133,6 +135,8 @@ def model_meta_from_sentence_transformers(model: SentenceTransformer) -> ModelMeta
mxbai_models,
nomic_models,
openai_models,
promptriever_models,
repllama_models,
ru_sentence_models,
salesforce_models,
sentence_transformers_models,
121 changes: 121 additions & 0 deletions mteb/models/promptriever_models.py
@@ -0,0 +1,121 @@
from __future__ import annotations

import logging
from typing import Any, Callable, Literal

import numpy as np
import torch

from mteb.encoder_interface import Encoder
from mteb.model_meta import ModelMeta

from .repllama_models import RepLLaMAWrapper

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)

EncodeTypes = Literal["query", "passage"]


class PromptrieverWrapper(RepLLaMAWrapper):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def encode_queries(self, queries: list[str], **kwargs: Any) -> np.ndarray:
queries = [f"query: {query}" for query in queries]
if "instruction" in kwargs:
end_punct_list = [
"?" if query.strip()[-1] not in ["?", ".", "!"] else ""
for query in queries
]
queries = [
f"{query}{end_punct_list[i]} {kwargs['instruction']}"
for i, query in enumerate(queries)
]
return self.encode(queries, **kwargs)


def _loader(wrapper: type[PromptrieverWrapper], **kwargs) -> Callable[..., Encoder]:
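    # freeze the kwargs supplied at ModelMeta definition time; they are merged
    # with any call-time kwargs when the loader is invoked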
_kwargs = kwargs

def loader_inner(**kwargs: Any) -> Encoder:
return wrapper(**_kwargs, **kwargs)

return loader_inner


promptriever_llama2 = ModelMeta(
loader=_loader(
RepLLaMAWrapper,
base_model_name_or_path="meta-llama/Llama-2-7b-hf",
peft_model_name_or_path="samaya-ai/promptriever-llama2-7b-v1",
device_map="auto",
torch_dtype=torch.bfloat16,
),
name="samaya-ai/promptriever-llama2-7b-v1",
languages=["eng_Latn"],
open_source=True,
revision=None, # TODO: Not sure what to put here as a model is made of two peft repos, each with a different revision
release_date="2024-09-15",
)

promptriever_llama3 = ModelMeta(
loader=_loader(
RepLLaMAWrapper,
base_model_name_or_path="meta-llama/Meta-Llama-3.1-8B",
peft_model_name_or_path="samaya-ai/promptriever-llama3.1-8b-v1",
device_map="auto",
torch_dtype=torch.bfloat16,
),
name="samaya-ai/promptriever-llama3.1-8b-v1",
languages=["eng_Latn"],
open_source=True,
revision=None, # TODO: Not sure what to put here as a model is made of two peft repos, each with a different revision
release_date="2024-09-15",
)


promptriever_llama3_instruct = ModelMeta(
loader=_loader(
RepLLaMAWrapper,
base_model_name_or_path="meta-llama/Meta-Llama-3.1-8B-Instruct",
peft_model_name_or_path="samaya-ai/promptriever-llama3.1-8b-instruct-v1",
device_map="auto",
torch_dtype=torch.bfloat16,
),
name="samaya-ai/promptriever-llama3.1-8b-instruct-v1",
languages=["eng_Latn"],
open_source=True,
revision=None, # TODO: Not sure what to put here as a model is made of two peft repos, each with a different revision
release_date="2024-09-15",
)

promptriever_mistral_v1 = ModelMeta(
loader=_loader(
RepLLaMAWrapper,
base_model_name_or_path="mistralai/Mistral-7B-v0.1",
peft_model_name_or_path="samaya-ai/promptriever-mistral-v0.1-7b-v1",
device_map="auto",
torch_dtype=torch.bfloat16,
),
name="samaya-ai/promptriever-mistral-v0.1-7b-v1",
languages=["eng_Latn"],
open_source=True,
revision=None, # TODO: Not sure what to put here as a model is made of two peft repos, each with a different revision
release_date="2024-09-15",
)

promptriever_mistral_v3 = ModelMeta(
loader=_loader(
RepLLaMAWrapper,
base_model_name_or_path="mistralai/Mistral-7B-v0.3",
peft_model_name_or_path="samaya-ai/promptriever-mistral-v0.3-7b-v1",
device_map="auto",
torch_dtype=torch.bfloat16,
),
name="samaya-ai/promptriever-mistral-v0.3-7b-v1",
languages=["eng_Latn"],
open_source=True,
revision=None, # TODO: Not sure what to put here as a model is made of two peft repos, each with a different revision
release_date="2024-09-15",
)
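
## Debug code (sketch mirroring the RepLLaMA debug block below; the query and
## instruction strings here are illustrative only, not from the original PR)
# import mteb
# model = mteb.get_model("samaya-ai/promptriever-llama2-7b-v1")
# embeddings = model.encode_queries(
#     ["what makes promptriever different from repllama"],
#     instruction="Prefer passages that describe the training data.",
# )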
174 changes: 174 additions & 0 deletions mteb/models/repllama_models.py
@@ -0,0 +1,174 @@
from __future__ import annotations

import logging
from typing import Any, Callable, Literal

import numpy as np
import torch
import torch.nn.functional as F
import tqdm
from transformers import AutoModel, AutoTokenizer

from mteb.encoder_interface import Encoder
from mteb.model_meta import ModelMeta
from mteb.models.text_formatting_utils import corpus_to_texts

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)

EncodeTypes = Literal["query", "passage"]


class RepLLaMAWrapper:
def __init__(self, *args, **kwargs):
try:
from peft import PeftModel
except ImportError:
raise ImportError(
"To use the RepLLaMA based models `peft` is required. Please install it with `pip install peft`."
)

self.base_model = AutoModel.from_pretrained(
kwargs["base_model_name_or_path"],
torch_dtype=kwargs["torch_dtype"],
device_map=kwargs["device_map"],
)
self.model = PeftModel.from_pretrained(
self.base_model, kwargs["peft_model_name_or_path"]
)
self.model = self.model.merge_and_unload()

self.tokenizer = AutoTokenizer.from_pretrained(
kwargs["base_model_name_or_path"]
)
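        # LLaMA-style tokenizers ship without a pad token; reuse EOS for padding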
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
self.tokenizer.pad_token = self.tokenizer.eos_token
self.tokenizer.padding_side = "right"
        # cap max_length at 512 for the evals, matching the original RepLLaMA setup, although the model can handle longer
self.model.config.max_length = 512
self.tokenizer.model_max_length = 512

def create_batch_dict(self, tokenizer, input_texts):
max_length = self.model.config.max_length
batch_dict = tokenizer(
input_texts,
max_length=max_length - 1,
return_token_type_ids=False,
return_attention_mask=False,
padding=False,
truncation=True,
)
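        # append EOS to every sequence ourselves (tokenization above reserved
        # one slot for it by truncating to max_length - 1); its hidden state is
        # later used as the embedding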
batch_dict["input_ids"] = [
input_ids + [tokenizer.eos_token_id]
for input_ids in batch_dict["input_ids"]
]
return tokenizer.pad(
batch_dict,
padding=True,
pad_to_multiple_of=8,
return_attention_mask=True,
return_tensors="pt",
)

def encode(
self,
sentences: list[str],
*,
        prompt_name: str | None = None,
**kwargs: Any, # noqa
) -> np.ndarray:
        batch_size = kwargs.pop("batch_size", 16)
all_embeddings = []
for i in tqdm.tqdm(range(0, len(sentences), batch_size)):
batch_texts = sentences[i : i + batch_size]

batch_dict = self.create_batch_dict(self.tokenizer, batch_texts)
batch_dict = {
key: value.to(self.model.device) for key, value in batch_dict.items()
}

with torch.cuda.amp.autocast():
with torch.no_grad():
outputs = self.model(**batch_dict)
last_hidden_state = outputs.last_hidden_state
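                    # last-token pooling: index the hidden state of the final
                    # non-pad token, i.e. the EOS appended in create_batch_dict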
sequence_lengths = batch_dict["attention_mask"].sum(dim=1) - 1
batch_size = last_hidden_state.shape[0]
reps = last_hidden_state[
torch.arange(batch_size, device=last_hidden_state.device),
sequence_lengths,
]
embeddings = F.normalize(reps, p=2, dim=-1)
all_embeddings.append(embeddings.cpu().numpy())

return np.concatenate(all_embeddings, axis=0)

def encode_corpus(
self,
corpus: list[dict[str, str]] | dict[str, list[str]] | list[str],
        prompt_name: str | None = None,
**kwargs: Any,
) -> np.ndarray:
sentences = corpus_to_texts(corpus, sep=" ")
if "request_qid" in kwargs:
kwargs.pop("request_qid")
        # NOTE: two spaces after the colon
        sentences = [f"passage:  {sentence}".strip() for sentence in sentences]
        logger.info(f"Encoding corpus of length {len(sentences)}")
        logger.info(f"First sentence: {sentences[0]}")
        return self.encode(sentences, **kwargs)

def encode_queries(self, queries: list[str], **kwargs: Any) -> np.ndarray:
        # NOTE: two spaces after the colon
        queries = [f"query:  {query.strip()}".strip() for query in queries]
        logger.info(f"Encoding queries of length {len(queries)}")
        logger.info(f"First query: {queries[0]}")
        return self.encode(queries, **kwargs)


def _loader(wrapper: type[RepLLaMAWrapper], **kwargs) -> Callable[..., Encoder]:
_kwargs = kwargs

def loader_inner(**kwargs: Any) -> Encoder:
return wrapper(**_kwargs, **kwargs)

return loader_inner


repllama_llama2_original = ModelMeta(
loader=_loader(
RepLLaMAWrapper,
base_model_name_or_path="meta-llama/Llama-2-7b-hf",
peft_model_name_or_path="castorini/repllama-v1-7b-lora-passage",
device_map="auto",
torch_dtype=torch.bfloat16,
),
name="castorini/repllama-v1-7b-lora-passage",
languages=["eng_Latn"],
open_source=True,
revision=None, # TODO: Not sure what to put here as a model is made of two peft repos, each with a different revision
release_date="2023-10-11",
)


repllama_llama2_reproduced = ModelMeta(
loader=_loader(
RepLLaMAWrapper,
base_model_name_or_path="meta-llama/Llama-2-7b-hf",
peft_model_name_or_path="samaya-ai/RepLLaMA-reproduced",
device_map="auto",
torch_dtype=torch.bfloat16,
),
name="samaya-ai/RepLLaMA-reproduced",
languages=["eng_Latn"],
open_source=True,
revision=None, # TODO: Not sure what to put here as a model is made of two peft repos, each with a different revision
release_date="2024-09-15",
)


## Debug code
# import mteb
# model = mteb.get_model("samaya-ai/RepLLaMA-reproduced")
# tasks = mteb.get_tasks(tasks=["SciFact"], languages=["eng"])
# evaluation = mteb.MTEB(tasks=tasks)
# evaluation.run(model)