Choose class-candidates during inference #22

Open
Ulipenitz opened this issue Jul 26, 2023 · 0 comments

Ulipenitz commented Jul 26, 2023

Really great library @tomaarsen, thank you for this great contribution!!

It would be really convenient to be able to pass a list of candidate classes to the predict method during inference. This would be useful for Retriever-Reader systems, in which the Retriever (e.g. a SetFit model) returns text sequences for which it is already known which classes the Reader (e.g. SpanMarker) should extract, and no other classes should be extracted.

For example, a system like the one described at https://lilianweng.github.io/posts/2020-10-29-odqa/.

I was thinking of modifying the predict method along these lines:

def predict(self, ... , class_candidates: Optional[List[str]] = None):
    
    ...

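    if class_candidates is not None:
        # Convert class names to class ids
        label2id = self.config.label2id
        class_candidate_ids = [label2id[c] for c in class_candidates if c in label2id]
        # Always keep the outside (non-entity) label: masking it out would force
        # every span into one of the candidate classes (assumes it is named "O")
        class_candidate_ids.append(label2id["O"])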

    for batch_start_idx in trange(0, len(dataset), batch_size, leave=True, disable=not show_progress_bar):
        
        ...
        # Computing probabilities based on the logits
        probs = output.logits.softmax(-1)

        # Mask everything except class-candidate probabilities
        if class_candidates is not None:
            mask = torch.zeros_like(probs)
            mask[:, :, class_candidate_ids] = 1
            probs = probs * mask

        # Get the labels and the corresponding probability scores
        scores, labels = probs.max(-1)

        ...

    return all_entities
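To try the masking idea without loading a model, here is a minimal sketch of the same step on a toy logits array. It uses NumPy instead of torch so it runs standalone; the helpers candidate_ids and restrict_predictions, the toy label set, and the assumption that the non-entity label is named "O" are all hypothetical and not part of SpanMarker's API. Unlike the multiply-by-mask version above, it sets disallowed logits to -inf before the softmax, so the scores are renormalised over the allowed classes. It always keeps the outside label, since masking it out would force every candidate span into one of the candidate classes, and it raises on unknown class names instead of silently dropping them.

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


def candidate_ids(class_candidates, label2id, outside_label="O"):
    """Hypothetical helper: map class names to sorted label ids.

    The outside label is always kept (assumed to be named "O"), and
    unknown names raise instead of being silently dropped.
    """
    unknown = [c for c in class_candidates if c not in label2id]
    if unknown:
        raise ValueError(f"Unknown class candidates: {unknown}")
    ids = {label2id[c] for c in class_candidates}
    ids.add(label2id[outside_label])
    return sorted(ids)


def restrict_predictions(logits: np.ndarray, allowed_ids):
    """Hypothetical helper: best (score, label) per span over allowed ids only.

    logits has shape (batch, num_spans, num_labels). Disallowed labels get
    -inf before the softmax, so their probability is exactly 0 and the
    remaining probabilities are renormalised to sum to 1.
    """
    masked = np.full_like(logits, -np.inf)
    masked[..., allowed_ids] = logits[..., allowed_ids]
    probs = softmax(masked)
    return probs.max(axis=-1), probs.argmax(axis=-1)


# Toy example: 1 sentence, 2 candidate spans, 4 labels.
label2id = {"O": 0, "PER": 1, "ORG": 2, "LOC": 3}
logits = np.array([[[0.0, 1.0, 3.0, 0.5],    # span 0: ORG wins unrestricted
                    [2.0, 0.0, 1.0, 0.0]]])  # span 1: O wins
allowed = candidate_ids(["PER", "LOC"], label2id)
scores, labels = restrict_predictions(logits, allowed)
print(labels)  # [[1 0]]: span 0 becomes PER, span 1 stays O
```

Both variants pick the same labels, because zeroing probabilities after the softmax and masking logits before it preserve the ordering among the allowed classes. They differ only in the reported scores, which may matter if scores are later compared across spans.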

I haven't found time to dive deeper, implement and test it, but I think this could be a useful feature.
