Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

请问有支持更高版本的 transformers 和 pydantic 2.x 版本的计划吗? #90

Open
NagatoYuki0943 opened this issue Sep 21, 2024 · 1 comment

Comments

@NagatoYuki0943
Copy link

No description provided.

@A0nameless0man
Copy link

创建以下文件为 BCERerank.py ,并以from BCERerank import BCERerank替代 from BCEmbedding.tools.langchain import BCERerank 即可暂时解决问题:

from __future__ import annotations

from typing import Dict, Optional, Sequence, Any

from langchain_core.documents import Document

from pydantic import model_validator
from langchain.callbacks.manager import Callbacks
from langchain.retrievers.document_compressors.base import BaseDocumentCompressor

from pydantic import PrivateAttr


class BCERerank(BaseDocumentCompressor):
    """Document compressor that uses `BCEmbedding RerankerModel API`."""

    client: str = "BCEmbedding"
    top_n: int = 3
    """Number of documents to return."""
    model: str = "maidalun1020/bce-reranker-base_v1"
    """Model to use for reranking."""
    _model: Any = PrivateAttr()

    class Config:
        """Configuration for this pydantic object."""

        extra = "forbid"
        arbitrary_types_allowed = True

    def __init__(
        self,
        top_n: int = 3,
        model: str = "maidalun1020/bce-reranker-base_v1",
        device: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(top_n=top_n, model=model)
        try:
            from BCEmbedding.models import RerankerModel
        except ImportError:
            raise ImportError(
                "Cannot import `BCEmbedding` package,",
                "please `pip install BCEmbedding>=0.1.2`",
            )
        self._model = RerankerModel(model_name_or_path=model, device=device, **kwargs)

    # @model_validator(mode="before")
    # def validate_environment(cls, values: Dict) -> Dict:
    #     """Validate that api key and python package exists in environment."""
    #     values["client"] = "BCEmbedding.models.RerankerModel"
    #     return values

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """
        Compress documents using `BCEmbedding RerankerModel API`.

        Args:
            documents: A sequence of documents to compress.
            query: The query to use for compressing the documents.
            callbacks: Callbacks to run during the compression process.

        Returns:
            A sequence of compressed documents.
        """
        if len(documents) == 0:  # to avoid empty api call
            return []
        doc_list = list(documents)

        passages = []
        valid_doc_list = []
        invalid_doc_list = []
        for d in doc_list:
            passage = d.page_content
            if isinstance(passage, str) and len(passage) > 0:
                passages.append(passage.replace("\n", " "))
                valid_doc_list.append(d)
            else:
                invalid_doc_list.append(d)

        rerank_result = self._model.rerank(query, passages)
        final_results = []
        for score, doc_id in zip(
            rerank_result["rerank_scores"], rerank_result["rerank_ids"]
        ):
            doc = valid_doc_list[doc_id]
            doc.metadata["relevance_score"] = score
            final_results.append(doc)
        for doc in invalid_doc_list:
            doc.metadata["relevance_score"] = 0
            final_results.append(doc)

        final_results = final_results[: self.top_n]
        return final_results

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants