Feat/multimodal #45

Merged
merged 10 commits on Nov 12, 2024

Changes from all commits
360 changes: 360 additions & 0 deletions examples/reranker_images.ipynb

Large diffs are not rendered by default.

12 changes: 8 additions & 4 deletions pyproject.toml
@@ -14,7 +14,7 @@ packages = [
 name = "rerankers"

-version = "0.5.3"
+version = "0.6.0"

 description = "A unified API for various document re-ranking models."
@@ -53,22 +53,26 @@ dependencies = [

 [project.optional-dependencies]
 all = [
-    "transformers",
+    "transformers>=4.45.0",
     "torch",
     "litellm",
     "requests",
     "sentencepiece",
     "protobuf",
     "flashrank",
     "flash-attn",
+    "pillow",
+    "accelerate>=0.26.0",
+    "peft>=0.13.0",
     "nmslib-metabrainz; python_version >= '3.10'",
     "rank-llm; python_version >= '3.10'"
 ]
-transformers = ["transformers", "torch", "sentencepiece", "protobuf"]
+transformers = ["transformers>=4.45.0", "torch", "sentencepiece", "protobuf"]
 api = ["requests"]
 gpt = ["litellm"]
 flashrank = ["flashrank"]
-llmlayerwise = ["transformers", "torch", "sentencepiece", "protobuf", "flash-attn"]
+llmlayerwise = ["transformers>=4.45.0", "torch", "sentencepiece", "protobuf", "flash-attn"]
+monovlm = ["transformers>=4.45.0", "torch", "sentencepiece", "protobuf", "flash-attn", "pillow", "accelerate>=0.26.0", "peft>=0.13.0"]
 rankllm = [
     "nmslib-metabrainz; python_version >= '3.10'",
     "rank-llm; python_version >= '3.10'"
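
The new monovlm extra groups everything the VLM reranker needs (transformers at the 4.45 floor, pillow, accelerate, peft, flash-attn), so it can be pulled in with a single command, e.g. pip install "rerankers[monovlm]". Note that flash-attn generally requires a CUDA build environment to install.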
2 changes: 1 addition & 1 deletion rerankers/__init__.py
@@ -2,4 +2,4 @@
 from rerankers.documents import Document

 __all__ = ["Reranker", "Document"]
-__version__ = "0.5.3"
+__version__ = "0.6.0"
29 changes: 24 additions & 5 deletions rerankers/documents.py
@@ -1,16 +1,35 @@
-from typing import Optional, Union
-from pydantic import BaseModel
+from typing import Optional, Union, Literal
+from pydantic import BaseModel, validator


 class Document(BaseModel):
-    text: str
+    document_type: Literal["text", "image"] = "text"
+    text: Optional[str] = None
+    base64: Optional[str] = None
+    image_path: Optional[str] = None
     doc_id: Optional[Union[str, int]] = None
     metadata: Optional[dict] = None

+    @validator("text")
+    def validate_text(cls, v, values):
+        if values.get("document_type") == "text" and v is None:
+            raise ValueError("text field is required when document_type is 'text'")
+        return v
+
     def __init__(
         self,
-        text: str,
+        text: Optional[str] = None,
         doc_id: Optional[Union[str, int]] = None,
         metadata: Optional[dict] = None,
+        document_type: Literal["text", "image"] = "text",
+        image_path: Optional[str] = None,
+        base64: Optional[str] = None,
     ):
-        super().__init__(text=text, doc_id=doc_id, metadata=metadata)
+        super().__init__(
+            text=text,
+            doc_id=doc_id,
+            metadata=metadata,
+            document_type=document_type,
+            base64=base64,
+            image_path=image_path,
+        )
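
With these changes a Document can carry an image payload instead of text. A minimal sketch of constructing one (the file name is hypothetical; base64 holds the base64-encoded image bytes):

import base64

from rerankers.documents import Document

# Encode a local image file (hypothetical path) as base64
with open("page_1.png", "rb") as f:
    encoded = base64.b64encode(f.read()).decode("utf-8")

doc = Document(
    doc_id=0,
    document_type="image",
    base64=encoded,
    image_path="page_1.png",
)

# Text documents keep working exactly as before:
text_doc = Document(text="MonoQwen2-VL is a multimodal reranker.", doc_id=1)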
6 changes: 6 additions & 0 deletions rerankers/models/__init__.py
@@ -52,3 +52,9 @@
     AVAILABLE_RANKERS["LLMLayerWiseRanker"] = LLMLayerWiseRanker
 except ImportError:
     pass
+
+try:
+    from rerankers.models.monovlm_ranker import MonoVLMRanker
+    AVAILABLE_RANKERS["MonoVLMRanker"] = MonoVLMRanker
+except ImportError:
+    pass
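
Like the other optional rankers above it, MonoVLMRanker sits behind a guarded import, so it is only registered when its dependencies are present. A quick availability check (a sketch, not part of the PR):

from rerankers.models import AVAILABLE_RANKERS

if "MonoVLMRanker" not in AVAILABLE_RANKERS:
    print("MonoVLMRanker unavailable; try: pip install 'rerankers[monovlm]'")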
6 changes: 3 additions & 3 deletions rerankers/models/colbert_ranker.py
@@ -186,7 +186,7 @@ def _document_encode(self, documents: list[str]):
         return self._encode(documents, self.document_token_id)

     def _to_embs(self, encoding) -> torch.Tensor:
-        with torch.no_grad():
+        with torch.inference_mode():
             # embs = self.model(**encoding).last_hidden_state.squeeze(1)
             embs = self.model(**encoding)
         if self.normalize:
@@ -271,7 +271,7 @@ def score(self, query: str, doc: str) -> float:
         scores = self._colbert_rank(query, [doc])
         return scores[0] if scores else 0.0

-    @torch.no_grad()
+    @torch.inference_mode()
     def _colbert_rank(
         self,
         query: str,
@@ -377,7 +377,7 @@ def _encode(
         return encoding

     def _to_embs(self, encoding) -> torch.Tensor:
-        with torch.no_grad():
+        with torch.inference_mode():
             batched_embs = []
             for i in range(0, encoding["input_ids"].size(0), self.batch_size):
                 batch_encoding = {
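
The no_grad-to-inference_mode swap here (and in the rankers below) is a pure inference-time optimization: inference_mode additionally disables view tracking and version counters, at the cost that tensors created inside can never re-enter autograd. A minimal sketch of the difference:

import torch

x = torch.ones(3, requires_grad=True)

with torch.no_grad():
    y = x * 2  # grad recording is off, but y may still be used in autograd later

with torch.inference_mode():
    z = x * 2  # faster: no version counters or view tracking are maintained

# z is an inference tensor; feeding it into a later autograd computation raises an error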
4 changes: 2 additions & 2 deletions rerankers/models/llm_layerwise_ranker.py
@@ -126,7 +126,7 @@ def _get_inputs(self, pairs, max_sequence_length: int):
             return_tensors="pt",
         )

-    @torch.no_grad()
+    @torch.inference_mode()
     def rank(
         self,
         query: str,
@@ -177,7 +177,7 @@ def rank(
         ]
         return RankedResults(results=ranked_results, query=query, has_scores=True)

-    @torch.no_grad()
+    @torch.inference_mode()
     def score(self, query: str, doc: str) -> float:
         inputs = self._get_inputs(
             [(query, doc)], max_sequence_length=self.max_sequence_length
164 changes: 164 additions & 0 deletions rerankers/models/monovlm_ranker.py
@@ -0,0 +1,164 @@
import torch
from PIL import Image
import base64
import io
# TODO: Support more than Qwen
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
from rerankers.models.ranker import BaseRanker
from rerankers.documents import Document
from typing import Union, List, Optional
from rerankers.utils import vprint, get_device, get_dtype, prep_image_docs
from rerankers.results import RankedResults, Result

PREDICTION_TOKENS = {
"default": ["False", "True"],
"lightonai/MonoQwen2-VL-v0.1": ["False", "True"]
}

def _get_output_tokens(model_name_or_path, token_false: str, token_true: str):
if token_false == "auto":
if model_name_or_path in PREDICTION_TOKENS:
token_false = PREDICTION_TOKENS[model_name_or_path][0]
else:
token_false = PREDICTION_TOKENS["default"][0]
print(
f"WARNING: Model {model_name_or_path} does not have known True/False tokens. Defaulting token_false to `{token_false}`."
)
if token_true == "auto":
if model_name_or_path in PREDICTION_TOKENS:
token_true = PREDICTION_TOKENS[model_name_or_path][1]
else:
token_true = PREDICTION_TOKENS["default"][1]
print(
f"WARNING: Model {model_name_or_path} does not have known True/False tokens. Defaulting token_true to `{token_true}`."
)

return token_false, token_true

class MonoVLMRanker(BaseRanker):
def __init__(
self,
model_name_or_path: str,
processor_name: Optional[str] = None,
dtype: Optional[Union[str, torch.dtype]] = 'bf16',
device: Optional[Union[str, torch.device]] = None,
batch_size: int = 1,
verbose: int = 1,
token_false: str = "auto",
token_true: str = "auto",
return_logits: bool = False,
prompt_template: str = "Assert the relevance of the previous image document to the following query, answer True or False. The query is: {query}",
**kwargs
):
self.verbose = verbose
self.device = get_device(device, verbose=self.verbose)
if self.device == 'mps':
print("WARNING: MPS is not supported by MonoVLMRanker due to PyTorch limitations. Falling back to CPU.")
self.device = 'cpu'
self.dtype = get_dtype(dtype, self.device, self.verbose)
self.batch_size = batch_size
self.return_logits = return_logits
self.prompt_template = prompt_template

vprint(f"Loading model {model_name_or_path}, this might take a while...", self.verbose)
vprint(f"Using device {self.device}.", self.verbose)
vprint(f"Using dtype {self.dtype}.", self.verbose)

processor_name = processor_name or "Qwen/Qwen2-VL-2B-Instruct"
processor_kwargs = kwargs.get("processor_kwargs", {})
model_kwargs = kwargs.get("model_kwargs", {})
attention_implementation = kwargs.get("attention_implementation", "flash_attention_2")
self.processor = AutoProcessor.from_pretrained(processor_name, **processor_kwargs)
self.model = Qwen2VLForConditionalGeneration.from_pretrained(
model_name_or_path,
device_map=self.device,
torch_dtype=self.dtype,
attn_implementation=attention_implementation,
**model_kwargs
)
self.model.eval()

token_false, token_true = _get_output_tokens(
model_name_or_path=model_name_or_path,
token_false=token_false,
token_true=token_true,
)
self.token_false_id = self.processor.tokenizer.convert_tokens_to_ids(token_false)
self.token_true_id = self.processor.tokenizer.convert_tokens_to_ids(token_true)

vprint(f"VLM true token set to {token_true}", self.verbose)
vprint(f"VLM false token set to {token_false}", self.verbose)

@torch.inference_mode()
def _get_scores(self, query: str, docs: List[Document]) -> List[float]:
scores = []
for doc in docs:
if doc.document_type != "image" or not doc.base64:
raise ValueError("MonoVLMRanker requires image documents with base64 data")

# Convert base64 to PIL Image
image_io = io.BytesIO(base64.b64decode(doc.base64))
image_io.seek(0) # Reset file pointer to start
image = Image.open(image_io).convert('RGB')

# Prepare prompt
prompt = self.prompt_template.format(query=query)
messages = [
{
"role": "user",
"content": [
{"type": "image", "image": image},
{"type": "text", "text": prompt},
],
}
]

# Process inputs
text = self.processor.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
inputs = self.processor(
text=text,
images=image,
return_tensors="pt"
).to(self.device).to(self.dtype)

# Get model outputs
outputs = self.model(**inputs)
logits = outputs.logits[:, -1, :]

# Calculate scores
relevant_logits = logits[:, [self.token_false_id, self.token_true_id]]
if self.return_logits:
score = relevant_logits[0, 1].cpu().item() # True logit
else:
probs = torch.softmax(relevant_logits, dim=-1)
score = probs[0, 1].cpu().item() # True probability

scores.append(score)

return scores

def rank(
self,
query: str,
docs: Union[str, List[str], Document, List[Document]],
doc_ids: Optional[Union[List[str], List[int]]] = None,
metadata: Optional[List[dict]] = None,
) -> RankedResults:
docs = prep_image_docs(docs, doc_ids, metadata)
scores = self._get_scores(query, docs)
ranked_results = [
Result(document=doc, score=score, rank=idx + 1)
for idx, (doc, score) in enumerate(
sorted(zip(docs, scores), key=lambda x: x[1], reverse=True)
)
]
return RankedResults(results=ranked_results, query=query, has_scores=True)

def score(self, query: str, doc: Union[str, Document]) -> float:
scores = self._get_scores(query, [doc])
return scores[0]
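
Putting the new ranker together, a minimal scoring sketch (assumes the monovlm extras and a flash-attention-capable GPU; the image path is hypothetical). Per _get_scores, documents must be image-typed and carry base64 data:

import base64

from rerankers.documents import Document
from rerankers.models.monovlm_ranker import MonoVLMRanker

ranker = MonoVLMRanker("lightonai/MonoQwen2-VL-v0.1")

with open("slide.png", "rb") as f:
    doc = Document(
        document_type="image",
        base64=base64.b64encode(f.read()).decode("utf-8"),
    )

# Probability that the model emits its "True" token for the relevance prompt
print(ranker.score("What was Q3 revenue?", doc))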
4 changes: 2 additions & 2 deletions rerankers/models/t5ranker.py
@@ -166,7 +166,7 @@ def score(self, query: str, doc: str) -> float:
         scores = self._get_scores(query, [doc])
         return scores[0] if scores else 0.0

-    @torch.no_grad()
+    @torch.inference_mode()
     def _get_scores(
         self,
         query: str,
@@ -231,7 +231,7 @@ def _get_scores(
             return logits
         return scores

-    @torch.no_grad()
+    @torch.inference_mode()
     def _greedy_decode(
         self,
         model,
4 changes: 2 additions & 2 deletions rerankers/models/transformer_ranker.py
@@ -45,7 +45,7 @@ def tokenize(self, inputs: Union[str, List[str], List[Tuple[str, str]]]):
             inputs, return_tensors="pt", padding=True, truncation=True
         ).to(self.device)

-    @torch.no_grad()
+    @torch.inference_mode()
     def rank(
         self,
         query: str,
@@ -83,7 +83,7 @@ def rank(
         ]
         return RankedResults(results=ranked_results, query=query, has_scores=True)

-    @torch.no_grad()
+    @torch.inference_mode()
     def score(self, query: str, doc: str) -> float:
         inputs = self.tokenize((query, doc))
         outputs = self.model(**inputs)
10 changes: 9 additions & 1 deletion rerankers/reranker.py
@@ -35,6 +35,10 @@
         "en": "BAAI/bge-reranker-v2.5-gemma2-lightweight",
         "other": "BAAI/bge-reranker-v2.5-gemma2-lightweight",
     },
+    "monovlm": {
+        "en": "lightonai/MonoQwen2-VL-v0.1",
+        "other": "lightonai/MonoQwen2-VL-v0.1"
+    }
 }

 DEPS_MAPPING = {
@@ -47,6 +51,7 @@
     "FlashRankRanker": "flashrank",
     "RankLLMRanker": "rankllm",
     "LLMLayerWiseRanker": "transformers",
+    "MonoVLMRanker": "transformers"
 }

 PROVIDERS = ["cohere", "jina", "voyage", "mixedbread.ai", "text-embeddings-inference"]
@@ -84,6 +89,7 @@ def _get_model_type(model_name: str, explicit_model_type: Optional[str] = None)
             "flashrank": "FlashRankRanker",
             "rankllm": "RankLLMRanker",
             "llm-layerwise": "LLMLayerWiseRanker",
+            "monovlm": "MonoVLMRanker"
         }
         return model_mapping.get(explicit_model_type, explicit_model_type)
     else:
@@ -105,6 +111,8 @@ def _get_model_type(model_name: str, explicit_model_type: Optional[str] = None)
             "vicuna": "RankLLMRanker",
             "zephyr": "RankLLMRanker",
             "bge-reranker-v2.5-gemma2-lightweight": "LLMLayerWiseRanker",
+            "monovlm": "MonoVLMRanker",
+            "monoqwen2-vl": "MonoVLMRanker"
         }
         for key, value in model_mapping.items():
             if key in model_name:
@@ -198,7 +206,7 @@ def Reranker(
     model_type = _get_model_type(model_name, model_type)

     try:
-        print(f"Loading {model_type} model {model_name}")
+        vprint(f"Loading {model_type} model {model_name} (this message can be suppressed by setting verbose=0)", verbose)
        return AVAILABLE_RANKERS[model_type](model_name, verbose=verbose, **kwargs)
     except KeyError:
         print(
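
With these mappings, a model name containing "monovlm" or "monoqwen2-vl" resolves to MonoVLMRanker through the substring matching above, without an explicit model_type. A sketch of both routes:

from rerankers import Reranker

# Explicit model type:
r1 = Reranker("lightonai/MonoQwen2-VL-v0.1", model_type="monovlm")

# Inferred from the model name, with the loading message suppressed:
r2 = Reranker("lightonai/MonoQwen2-VL-v0.1", verbose=0)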