Skip to content

Commit

Permalink
Merge pull request #131 from KennethEnevoldsen/mistral-instructions
Browse files Browse the repository at this point in the history
Added instructions for all tasks in Mistral E5
  • Loading branch information
x-tabdeveloping authored Feb 6, 2024
2 parents 16a4a81 + 07c4c95 commit 006c253
Showing 1 changed file with 64 additions and 7 deletions.
71 changes: 64 additions & 7 deletions src/seb/registered_models/e5_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,63 @@ def batched(iterable: Iterable[T], n: int) -> Iterable[tuple[T, ...]]:
yield batch


# Task types that use one fixed instruction regardless of the task name.
_TASK_TYPE_INSTRUCTIONS: dict[str, str] = {
    "STS": "Retrieve semantically similar text",
    "Summarization": "Given a news summary, retrieve other semantically similar summaries",
    "Reranking": "Retrieve semantically similar passages.",
}

# Task types whose instruction depends on the task name.
# Maps task type -> (default instruction, {task name: instruction}).
_TASK_NAME_INSTRUCTIONS: dict[str, tuple[str, dict[str, str]]] = {
    "BitextMining": (
        "Retrieve parallel sentences.",
        {
            "Bornholm Parallel": "Retrieve parallel sentences in Danish and Bornholmsk",
            "Norwegian courts": "Retrieve parallel sentences in Norwegian Bokmål and Nynorsk",
        },
    ),
    "Classification": (
        "Classify user passages",
        {
            "Angry Tweets": "Classify Danish tweets by sentiment. (positive, negative, neutral)",
            "DKHate": "Classify Danish tweets based on offensiveness (offensive, not offensive)",
            "Da Political Comments": "Classify Danish political comments for sentiment",
            "DaLAJ": "Classify texts based on linguistic acceptability in Swedish",
            "LCC": "Classify texts based on sentiment",
            "Language Identification": "Classify texts based on language",
            "Massive Intent": "Given a user utterance as query, find the user intents",
            "Massive Scenario": "Given a user utterance as query, find the user scenarios",
            "NoReC": "Classify Norwegian reviews by sentiment",
            "SweReC": "Classify Swedish reviews by sentiment",
            "Norwegian parliament": "Classify parliament speeches in Norwegian based on political affiliation",
            "ScaLA": "Classify passages in Scandinavian Languages based on linguistic acceptability",
        },
    ),
    "Clustering": (
        "Identify categories in user passages",
        {
            "ArxivClusteringP2P": "Identify the main and secondary category of Arxiv papers based on the titles and abstracts",
            "VG Clustering": "Identify the categories (e.g. sports) of given articles in Norwegian",
            "SNL Clustering": "Identify categories in a Norwegian lexicon",
            "SwednClustering": "Identify news categories in Swedish passages",
        },
    ),
    "Retrieval": (
        "Retrieve text based on user query.",
        {
            "Twitterhjerne": "Retrieve answers to questions asked in Danish tweets",
            "SwednRetrieval": "Retrieve summaries of Swedish news articles",
            "TV2Nord Retrieval": "Retrieve summaries of Danish news articles",
            "DanFEVER": "Given a claim in Danish, retrieve documents that support or refute the claim",
            "SNL Retrieval": "Given a lexicon article in Norwegian, retrieve its headline",
            "NorQuad": "Given a question in Norwegian, retrieve the answer from Wikipedia articles",
            "SweFAQ": "Retrieve answers given questions in Swedish",
            "ArguAna": "Given a claim, find documents that refute the claim",
            "ClimateFEVER": "Given a claim about climate change, retrieve documents that support or refute the claim",
        },
    ),
}


def task_to_instruction(task: Task) -> str:
    """Return the natural-language instruction to use for ``task``.

    The instruction is selected from ``task.task_type`` and, for task types
    with task-specific wording, from ``task.name``.

    Args:
        task: Object exposing ``task_type`` and ``name`` attributes.

    Returns:
        The fixed instruction for "STS", "Summarization" and "Reranking";
        the task-specific instruction (or that type's default when the task
        name is not listed) for "BitextMining", "Classification",
        "Clustering" and "Retrieval"; an empty string for any other task type.
    """
    task_type = task.task_type

    # Types with a single instruction: task.name is never consulted.
    if task_type in _TASK_TYPE_INSTRUCTIONS:
        return _TASK_TYPE_INSTRUCTIONS[task_type]

    # Types with per-task wording: fall back to the type-level default.
    if task_type in _TASK_NAME_INSTRUCTIONS:
        default_instruction, instruction_by_name = _TASK_NAME_INSTRUCTIONS[task_type]
        return instruction_by_name.get(task.name, default_instruction)

    # Unknown task type: no instruction.
    return ""


class E5Mistral(Encoder):
max_length = 4096

Expand All @@ -34,12 +91,8 @@ def load_model(self):
self.tokenizer = AutoTokenizer.from_pretrained("intfloat/e5-mistral-7b-instruct")
self.model = AutoModel.from_pretrained("intfloat/e5-mistral-7b-instruct")

def preprocess(self, sentences: Sequence[str]) -> BatchEncoding:
# following the documentation we should also add "Instruction: " to the instruction, but for now I will just create this naive approach
# task = "" # Could e.g. be: "Given a web search query, retrieve relevant passages that answer the query" # noqa
# And then:
# f"Instruction: {task} Query: {sentence}" for sentence in sentences
sentences = ["Query: " + sentence for sentence in sentences]
def preprocess(self, sentences: Sequence[str], instruction: str) -> BatchEncoding:
sentences = [f"Instruction: {instruction} Query: {sentence}" for sentence in sentences]
batch_dict = self.tokenizer(
sentences,
max_length=self.max_length - 1,
Expand Down Expand Up @@ -80,8 +133,12 @@ def encode(
**kwargs: Any, # noqa
) -> ArrayLike:
batched_embeddings = []
if task is not None:
instruction = task_to_instruction(task)
else:
instruction = ""
for batch in batched(sentences, batch_size):
batch_dict = self.preprocess(batch)
batch_dict = self.preprocess(batch, instruction=instruction)

outputs = self.model(**batch_dict)
embeddings = self.last_token_pool(
Expand Down

0 comments on commit 006c253

Please sign in to comment.