From 07c4c9572eff110605629dd6bbfab95df9c3786e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?M=C3=A1rton=20Kardos?=
Date: Mon, 5 Feb 2024 10:13:18 +0100
Subject: [PATCH] Added instructions for all tasks in Mistral E5

---
NOTE(review): this copy of the patch had been collapsed onto a few long lines.
Line breaks and indentation below are reconstructed. Per-hunk line counts match
the @@ headers and the diffstat, but the indentation of context/removed lines,
the spacing before inline "# noqa" comments and the blank context lines are
inferred - verify with `git apply --check` before relying on it.
NOTE(review): the author's e-mail address is missing from the From: header.
NOTE(review): preprocess() now builds "Instruction: <instruction> Query: <text>".
Please confirm this prompt template against the intfloat/e5-mistral-7b-instruct
model card before merging.
NOTE(review): when task is None, or task.task_type is not one of the handled
types, the instruction is "" and the prompt becomes "Instruction:  Query: <text>"
(empty instruction, two spaces).

 src/seb/registered_models/e5_mistral.py | 71 ++++++++++++++++++++++---
 1 file changed, 64 insertions(+), 7 deletions(-)

diff --git a/src/seb/registered_models/e5_mistral.py b/src/seb/registered_models/e5_mistral.py
index f907d945..183ffda2 100644
--- a/src/seb/registered_models/e5_mistral.py
+++ b/src/seb/registered_models/e5_mistral.py
@@ -24,6 +24,63 @@ def batched(iterable: Iterable[T], n: int) -> Iterable[tuple[T, ...]]:
         yield batch
 
 
+def task_to_instruction(task: Task) -> str:
+    if task.task_type in ["STS"]:
+        return "Retrieve semantically similar text"
+    if task.task_type in ["Summarization"]:
+        return "Given a news summary, retrieve other semantically similar summaries"
+    if task.task_type in ["BitextMining"]:
+        task_name_to_instruct: dict[str, str] = {
+            "Bornholm Parallel": "Retrieve parallel sentences in Danish and Bornholmsk",
+            "Norwegian courts": "Retrieve parallel sentences in Norwegian Bokmål and Nynorsk",
+        }
+        default_instruction = "Retrieve parallel sentences."
+        return task_name_to_instruct.get(task.name, default_instruction)
+    if task.task_type in ["Classification"]:
+        task_name_to_instruct: dict[str, str] = {
+            "Angry Tweets": "Classify Danish tweets by sentiment. (positive, negative, neutral)",
+            "DKHate": "Classify Danish tweets based on offensiveness (offensive, not offensive)",
+            "Da Political Comments": "Classify Danish political comments for sentiment",
+            "DaLAJ": "Classify texts based on linguistic acceptability in Swedish",
+            "LCC": "Classify texts based on sentiment",
+            "Language Identification": "Classify texts based on language",
+            "Massive Intent": "Given a user utterance as query, find the user intents",
+            "Massive Scenario": "Given a user utterance as query, find the user scenarios",
+            "NoReC": "Classify Norwegian reviews by sentiment",
+            "SweReC": "Classify Swedish reviews by sentiment",
+            "Norwegian parliament": "Classify parliament speeches in Norwegian based on political affiliation",
+            "ScaLA": "Classify passages in Scandinavian Languages based on linguistic acceptability",
+        }
+        default_instruction = "Classify user passages"
+        return task_name_to_instruct.get(task.name, default_instruction)
+    if task.task_type in ["Clustering"]:
+        task_name_to_instruct: dict[str, str] = {
+            "ArxivClusteringP2P": "Identify the main and secondary category of Arxiv papers based on the titles and abstracts",
+            "VG Clustering": "Identify the categories (e.g. sports) of given articles in Norwegian",
+            "SNL Clustering": "Identify categories in a Norwegian lexicon",
+            "SwednClustering": "Identify news categories in Swedish passages",
+        }
+        default_instruction = "Identify categories in user passages"
+        return task_name_to_instruct.get(task.name, default_instruction)
+    if task.task_type in ["Reranking"]:
+        return "Retrieve semantically similar passages."
+    if task.task_type in ["Retrieval"]:
+        task_name_to_instruct: dict[str, str] = {
+            "Twitterhjerne": "Retrieve answers to questions asked in Danish tweets",
+            "SwednRetrieval": "Retrieve summaries of Swedish news articles",
+            "TV2Nord Retrieval": "Retrieve summaries of Danish news articles",
+            "DanFEVER": "Given a claim in Danish, retrieve documents that support or refute the claim",
+            "SNL Retrieval": "Given a lexicon article in Norwegian, retrieve its headline",
+            "NorQuad": "Given a question in Norwegian, retrieve the answer from Wikipedia articles",
+            "SweFAQ": "Retrieve answers given questions in Swedish",
+            "ArguAna": "Given a claim, find documents that refute the claim",
+            "ClimateFEVER": "Given a claim about climate change, retrieve documents that support or refute the claim",
+        }
+        default_instruction = "Retrieve text based on user query."
+        return task_name_to_instruct.get(task.name, default_instruction)
+    return ""
+
+
 class E5Mistral(Encoder):
     max_length = 4096
 
@@ -34,12 +91,8 @@ def load_model(self):
         self.tokenizer = AutoTokenizer.from_pretrained("intfloat/e5-mistral-7b-instruct")
         self.model = AutoModel.from_pretrained("intfloat/e5-mistral-7b-instruct")
 
-    def preprocess(self, sentences: Sequence[str]) -> BatchEncoding:
-        # following the documentation we should also add "Instruction: " to the instruction, but for now I will just create this naive approach
-        # task = "" # Could e.g. be: "Given a web search query, retrieve relevant passages that answer the query" # noqa
-        # And then:
-        # f"Instruction: {task} Query: {sentence}" for sentence in sentences
-        sentences = ["Query: " + sentence for sentence in sentences]
+    def preprocess(self, sentences: Sequence[str], instruction: str) -> BatchEncoding:
+        sentences = [f"Instruction: {instruction} Query: {sentence}" for sentence in sentences]
         batch_dict = self.tokenizer(
             sentences,
             max_length=self.max_length - 1,
@@ -80,8 +133,12 @@ def encode(
         **kwargs: Any,  # noqa
     ) -> ArrayLike:
         batched_embeddings = []
+        if task is not None:
+            instruction = task_to_instruction(task)
+        else:
+            instruction = ""
         for batch in batched(sentences, batch_size):
-            batch_dict = self.preprocess(batch)
+            batch_dict = self.preprocess(batch, instruction=instruction)
             outputs = self.model(**batch_dict)
 
             embeddings = self.last_token_pool(