
Using Reranker in a multithreaded process raises "Already borrowed" RuntimeError #42

Closed
sam-bercovici opened this issue Oct 28, 2024 · 4 comments

@sam-bercovici
Contributor

I am using ColBERT.

See huggingface/tokenizers#537.

I suggest allowing `tokenizer_kwargs` and `model_kwargs` to be passed to the Reranker factory class, which would forward them to the underlying tokenizer and model loaders.

Below is an example of how to modify `ColBERTRanker.__init__`; I marked each modification with `## change`.
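For context, the linked issue comes from the Rust-backed "fast" tokenizers: a single tokenizer instance cannot be used from two threads at once, which surfaces as `RuntimeError: Already borrowed`. One common workaround, independent of the kwargs proposal below, is to give each thread its own tokenizer instance. Here is a minimal sketch of that pattern, with a stand-in loader in place of `AutoTokenizer.from_pretrained`:

```python
import threading

def load_tokenizer():
    # Stand-in for AutoTokenizer.from_pretrained(model_name); real code
    # would load the Hugging Face tokenizer here.
    return object()

_local = threading.local()

def get_tokenizer():
    # Lazily create one tokenizer per thread, so no two threads ever
    # share (and hence concurrently borrow) the same Rust-backed object.
    if not hasattr(_local, "tokenizer"):
        _local.tokenizer = load_tokenizer()
    return _local.tokenizer
```

Within a thread, repeated calls return the same cached instance; across threads, each gets its own copy, which sidesteps the borrow conflict at the cost of extra memory per thread.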

    def __init__(
        self,
        model_name: str,
        batch_size: int = 32,
        dtype: Optional[Union[str, torch.dtype]] = None,
        device: Optional[Union[str, torch.device]] = None,
        verbose: int = 1,
        query_token: str = "[unused0]",
        document_token: str = "[unused1]",
        **kwargs, ## change
    ):
        self.verbose = verbose
        self.device = get_device(device, self.verbose)
        self.dtype = get_dtype(dtype, self.device, self.verbose)
        self.batch_size = batch_size
        vprint(
            f"Loading model {model_name}, this might take a while...",
            self.verbose,
        )
        tokenizer_kwargs = kwargs.get("tokenizer_kwargs", {}) ## change
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, **tokenizer_kwargs) ## change
        model_kwargs = kwargs.get("model_kwargs", {}) ## change
        self.model = (
            ColBERTModel.from_pretrained(model_name, **model_kwargs) ## change
            .to(self.device)
            .to(self.dtype)
        )
        self.model.eval()
        self.query_max_length = 32  # Lower bound
        self.doc_max_length = (
            self.model.config.max_position_embeddings - 2
        )  # Upper bound
        self.query_token_id: int = self.tokenizer.convert_tokens_to_ids(query_token)  # type: ignore
        self.document_token_id: int = self.tokenizer.convert_tokens_to_ids(
            document_token
        )  # type: ignore
        self.normalize = True
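The forwarding pattern above can be exercised in isolation. This sketch uses hypothetical dummy loader classes in place of `AutoTokenizer` and `ColBERTModel` to show how the factory-level kwargs would reach them:

```python
class DummyTokenizer:
    """Stand-in for AutoTokenizer; records the kwargs it was loaded with."""
    def __init__(self, model_name, **kwargs):
        self.model_name = model_name
        self.kwargs = kwargs

class DummyModel:
    """Stand-in for ColBERTModel; records the kwargs it was loaded with."""
    def __init__(self, model_name, **kwargs):
        self.model_name = model_name
        self.kwargs = kwargs

class MiniRanker:
    # Mirrors the modified __init__ above: pull the two kwarg dicts out
    # of **kwargs and forward each to its respective loader.
    def __init__(self, model_name, **kwargs):
        tokenizer_kwargs = kwargs.get("tokenizer_kwargs", {})
        self.tokenizer = DummyTokenizer(model_name, **tokenizer_kwargs)
        model_kwargs = kwargs.get("model_kwargs", {})
        self.model = DummyModel(model_name, **model_kwargs)

ranker = MiniRanker(
    "colbert-ir/colbertv2.0",
    tokenizer_kwargs={"use_fast": False},  # e.g. opt out of the Rust fast tokenizer
    model_kwargs={"low_cpu_mem_usage": True},
)
```

With the real classes, `use_fast=False` would make `AutoTokenizer.from_pretrained` return the pure-Python slow tokenizer, which avoids the Rust borrow error entirely.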
@bclavie
Collaborator

bclavie commented Nov 4, 2024

Thanks for flagging! Would you be willing to submit your proposed changes as a PR? I'm happy with this logic being added to handle various kwargs situations!

@sam-bercovici
Contributor Author

> Thanks for flagging! Would you be willing to submit your proposed changes as a PR? I'm happy with this logic being added to handle various kwargs situations!

Sure.
I will try to find a couple of hours to do so in the next week or so.

@sam-bercovici
Contributor Author

See #44.

@bclavie
Collaborator

bclavie commented Nov 12, 2024

Merged, thank you! Will ship with 0.0.6 in ~30mn

@bclavie bclavie closed this as completed Nov 12, 2024