Really great library @tomaarsen, thank you for this great contribution!!
It would be really convenient to be able to pass a list of candidate classes to the predict method at inference time. This would be useful for Retriever-Reader systems, where the Retriever (e.g. a SetFit model) returns text sequences for which the set of classes available to the Reader (e.g. SpanMarker) is already known, and you do not want to extract any other classes.
I was thinking about modifications to the predict method like these:
    def predict(self, ..., class_candidates: Optional[List[str]] = None):
        ...
        if class_candidates is not None:
            # Convert class names to class ids, skipping unknown names
            label2id = self.config.label2id
            class_candidate_ids = [label2id[c] for c in class_candidates if c in label2id]
        for batch_start_idx in trange(0, len(dataset), batch_size, leave=True, disable=not show_progress_bar):
            ...
            # Compute probabilities from the logits
            probs = output.logits.softmax(-1)
            # Mask everything except the class-candidate probabilities
            if class_candidates is not None:
                mask = torch.zeros_like(probs)
                mask[:, :, class_candidate_ids] = 1
                probs = probs * mask
            # Get the labels and the corresponding probability scores
            scores, labels = probs.max(-1)
            ...
        return all_entities
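To make the masking step concrete, here is a minimal plain-Python sketch of the same idea on toy data. It is not SpanMarker code: the function name `mask_to_candidates`, the toy label set, and the probabilities are made up for illustration. It also assumes there is an outside label named "O" that should stay unmasked, since otherwise every span would be forced into one of the candidate classes.

```python
from typing import Dict, List, Sequence, Tuple


def mask_to_candidates(
    probs: Sequence[Sequence[float]],
    label2id: Dict[str, int],
    class_candidates: Sequence[str],
    outside_label: str = "O",  # assumption: name of the "no entity" label
) -> List[Tuple[str, float]]:
    """Pick the best label per span, considering only the candidate classes."""
    id2label = {idx: label for label, idx in label2id.items()}
    # Convert class names to ids, silently skipping unknown names
    allowed = {label2id[c] for c in class_candidates if c in label2id}
    # Keep the outside label available so a span can still be rejected
    if outside_label in label2id:
        allowed.add(label2id[outside_label])

    results = []
    for row in probs:
        # Zero out every probability that does not belong to an allowed class
        masked = [p if idx in allowed else 0.0 for idx, p in enumerate(row)]
        best = max(range(len(masked)), key=masked.__getitem__)
        results.append((id2label[best], masked[best]))
    return results


# Toy data: one row of class probabilities per candidate span
label2id = {"O": 0, "PER": 1, "ORG": 2, "LOC": 3}
span_probs = [
    [0.10, 0.20, 0.60, 0.10],  # ORG scores highest but is not a candidate
    [0.70, 0.05, 0.05, 0.20],  # most likely not an entity at all
    [0.05, 0.15, 0.10, 0.70],  # LOC is a candidate and scores highest
]
predictions = mask_to_candidates(span_probs, label2id, ["PER", "LOC"])
print(predictions)
```

Because the mask is applied after the softmax, the returned scores are the original probabilities and are not renormalized over the candidates. Masking the logits before the softmax instead would renormalize them, which changes what any score threshold means.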
I have not found the time to dig deeper, implement, and test it, but I think this could be a useful feature.
For an example of such a Retriever-Reader system, see https://lilianweng.github.io/posts/2020-10-29-odqa/.