-
Notifications
You must be signed in to change notification settings - Fork 99
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
请问有支持更高版本的 transformers 和 pydantic 2.x 版本的计划吗? #90
Comments
创建以下文件为:
from __future__ import annotations
from typing import Dict, Optional, Sequence, Any

from langchain_core.documents import Document
from langchain.callbacks.manager import Callbacks
from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
from pydantic import ConfigDict
from pydantic import PrivateAttr
from pydantic import model_validator
class BCERerank(BaseDocumentCompressor):
    """Document compressor that reranks documents with BCEmbedding's ``RerankerModel``.

    Documents whose ``page_content`` is a non-empty string are scored by the
    reranker and returned in the order the reranker reports. Documents with
    empty or non-string content are appended after them with a relevance score
    of 0. At most ``top_n`` documents are returned, each a copy of the input
    document whose metadata carries ``"relevance_score"``.
    """

    # pydantic 2 style configuration (replaces the deprecated inner ``class Config``).
    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)

    client: str = "BCEmbedding"
    top_n: int = 3
    """Number of documents to return."""
    model: str = "maidalun1020/bce-reranker-base_v1"
    """Model to use for reranking."""
    # Loaded ``RerankerModel`` instance; a private attribute so pydantic does not
    # try to validate or serialize it.
    _model: Any = PrivateAttr()

    def __init__(
        self,
        top_n: int = 3,
        model: str = "maidalun1020/bce-reranker-base_v1",
        device: Optional[str] = None,
        **kwargs: Any,
    ):
        """Initialise the compressor and load the BCEmbedding reranker.

        Args:
            top_n: Maximum number of documents returned by ``compress_documents``.
            model: Model name or path passed to ``RerankerModel`` as
                ``model_name_or_path``.
            device: Forwarded unchanged to ``RerankerModel``.
            **kwargs: Extra keyword arguments forwarded to ``RerankerModel``.

        Raises:
            ImportError: If the ``BCEmbedding`` package cannot be imported.
        """
        super().__init__(top_n=top_n, model=model)
        try:
            from BCEmbedding.models import RerankerModel
        except ImportError as err:
            # A single message string; passing two positional args would make
            # the error text render as a tuple.
            raise ImportError(
                "Cannot import `BCEmbedding` package, "
                "please `pip install BCEmbedding>=0.1.2`"
            ) from err
        self._model = RerankerModel(model_name_or_path=model, device=device, **kwargs)

    @staticmethod
    def _with_score(doc: Document, score: Any) -> Document:
        """Return a copy of ``doc`` whose metadata also holds ``relevance_score``.

        Copying (instead of mutating ``doc.metadata``) leaves the caller's input
        documents untouched.
        """
        return doc.model_copy(
            update={"metadata": {**doc.metadata, "relevance_score": score}}
        )

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """Rerank ``documents`` against ``query`` and keep the best ``top_n``.

        Args:
            documents: Documents to rerank.
            query: Query the documents are scored against.
            callbacks: Accepted for interface compatibility; not used here.

        Returns:
            Up to ``top_n`` document copies. Reranked documents come first, in
            the order returned by the reranker. Documents with empty or
            non-string content follow with ``relevance_score`` 0.
        """
        if not documents:  # avoid an empty model call
            return []

        passages: list[str] = []
        valid_docs: list[Document] = []
        invalid_docs: list[Document] = []
        for doc in documents:
            content = doc.page_content
            if isinstance(content, str) and content:
                # Newlines are flattened before scoring.
                passages.append(content.replace("\n", " "))
                valid_docs.append(doc)
            else:
                invalid_docs.append(doc)

        final_results: list[Document] = []
        if passages:  # skip the model entirely when nothing is scoreable
            rerank_result = self._model.rerank(query, passages)
            # NOTE(review): read keys defensively in case the reranker omits
            # them for degenerate input (e.g. an empty query) — confirm against
            # the BCEmbedding version in use.
            scores = rerank_result.get("rerank_scores", [])
            ids = rerank_result.get("rerank_ids", [])
            # ``rerank_ids`` index into the ``passages`` list, which is aligned
            # with ``valid_docs``.
            final_results.extend(
                self._with_score(valid_docs[doc_id], score)
                for score, doc_id in zip(scores, ids)
            )

        final_results.extend(self._with_score(doc, 0) for doc in invalid_docs)
        return final_results[: self.top_n]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
No description provided.
The text was updated successfully, but these errors were encountered: