Describe the bug
The SequenceTagger class has an _all_scores_for_token() function that takes as input a batch of sentences, the softmax scores from the tagger, and the length of each sentence in the batch. The function computes the probability distribution over all class labels for each token of each sentence in the batch and returns it, grouped by sentence. The grouping of these per-token distributions back into sentences is incorrect in this function.
Below, I have explained this for a sample of the English OntoNotes corpus.
# import necessary modules

# 1. get the corpus
corpus = flair.datasets.ONTONOTES()

# 2. what label do we want to predict?
label_type = 'ner'

# 3. make the label dictionary from the corpus
label_dict = corpus.make_label_dictionary(label_type=label_type, add_unk=False)

# 4. initialize fine-tuneable transformer embeddings WITH document context
embeddings = TransformerWordEmbeddings(model='xlm-roberta-large',
                                       layers="-1",
                                       subtoken_pooling="first",
                                       fine_tune=True,
                                       use_context=True,
                                       )

# 5. initialize bare-bones sequence tagger (no CRF, no RNN, no reprojection)
tagger = SequenceTagger(hidden_size=256,
                        embeddings=embeddings,
                        tag_dictionary=label_dict,
                        tag_type='ner',
                        use_crf=False,
                        use_rnn=False,
                        reproject_embeddings=False,
                        )

label_name = 'ner'
sentences = corpus.train
mini_batch_size = 10
verbose = True

with torch.no_grad():
    Sentence.set_context_for_sentences(cast(List[Sentence], sentences))

    # filter empty sentences
    sentences = [sentence for sentence in sentences if len(sentence) > 0]

    # reverse sort all sequences by their length
    reordered_sentences = sorted(sentences, key=len, reverse=True)

    dataloader = DataLoader(
        dataset=FlairDatapointDataset(reordered_sentences),
        batch_size=mini_batch_size,
    )

    # progress bar for verbosity
    if verbose:
        dataloader = tqdm(dataloader, desc="Batch inference")

    overall_loss = torch.zeros(1, device=flair.device)
    label_count = 0

    for batch in dataloader:
        # stop if all sentences are empty
        if not batch:
            continue

        # get features from forward propagation
        sentence_tensor, lengths = tagger._prepare_tensors(batch)
        features = tagger.forward(sentence_tensor, lengths)

        # remove previously predicted labels of this type
        for sentence in batch:
            sentence.remove_labels(label_name)
        break
We compute the probability distributions for this batch via the _all_scores_for_token() function. The function is reproduced as a standalone function below, with print statements added, for a better understanding of the output.
def _all_scores_for_token(sentences: List[Sentence], scores: torch.Tensor, lengths: List[int]):
    """Returns all scores for each tag in tag dictionary."""
    scores = scores.numpy()
    tokens = [token for sentence in sentences for token in sentence]
    print('Number of tokens in batch:', len(tokens))
    prob_all_tags = [
        [
            Label(token, tagger.label_dictionary.get_item_for_index(score_id), score)
            for score_id, score in enumerate(score_dist)
        ]
        for score_dist, token in zip(scores, tokens)
    ]
    print('Length of prob_all_tags:', len(prob_all_tags))
    prob_tags_per_sentence = []
    previous = 0
    for i, length in enumerate(lengths):
        print(f'Length range of Sentence {i}: {previous} to {previous + length}')
        prob_tags_per_sentence.append(prob_all_tags[previous : previous + length])
        previous = length
    return prob_tags_per_sentence
softmax_batch = F.softmax(features, dim=1).cpu()
lengths = [len(sentence) for sentence in batch]
all_tags = _all_scores_for_token(batch, softmax_batch, lengths)
Output:
Number of tokens in batch: 1761
Length of prob_all_tags: 1761
Length range of Sentence 0: 0 to 210
Length range of Sentence 1: 210 to 415
Length range of Sentence 2: 205 to 394
Length range of Sentence 3: 189 to 377
Length range of Sentence 4: 188 to 361
Length range of Sentence 5: 173 to 341
Length range of Sentence 6: 168 to 335
Length range of Sentence 7: 167 to 324
Length range of Sentence 8: 157 to 313
Length range of Sentence 9: 156 to 304
Here the total number of tokens in the batch is 1761. The prob_all_tags variable contains the probability distribution for each token. But when these are split per sentence into the prob_tags_per_sentence variable, the length-range calculation is incorrect, as the output above shows: at the end of each iteration, previous is set to the current sentence's length instead of being incremented by it, so from Sentence 2 onward every slice starts at the wrong offset. The corrected length range calculation should be:
for i, length in enumerate(lengths):
    print(f'Length range of Sentence {i}: {previous} to {previous + length}')
    prob_tags_per_sentence.append(prob_all_tags[previous : previous + length])
    # previous = length should be previous += length
    previous += length
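To see the effect of this one-line fix in isolation, here is a minimal, Flair-independent sketch with dummy data (the split() helper and its inputs are invented purely for this illustration):

# split() is a hypothetical helper written only for this illustration
tokens = list(range(10))   # stand-in for the flattened per-token distributions
lengths = [4, 3, 3]        # lengths of three "sentences"

def split(tokens, lengths, fixed):
    out, previous = [], 0
    for length in lengths:
        out.append(tokens[previous : previous + length])
        # fixed: accumulate the offset; buggy: overwrite it with the last length
        previous = previous + length if fixed else length
    return out

print(split(tokens, lengths, fixed=False))  # [[0, 1, 2, 3], [4, 5, 6], [3, 4, 5]] -- last slice wrong
print(split(tokens, lengths, fixed=True))   # [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]] -- tiles the batch

Note that the buggy update still gets the second sentence right by coincidence (after one sentence, previous = lengths[0] equals the true offset), which matches the output above: the ranges only start to drift from Sentence 2 onward.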
To Reproduce
import flair
import flair.datasets
from flair.embeddings import TransformerWordEmbeddings
from flair.models import SequenceTagger
from flair.data import Sentence, Label
from flair.datasets import DataLoader, FlairDatapointDataset
from tqdm import tqdm
from typing import List, cast
import torch
import torch.nn.functional as F


def _all_scores_for_token(sentences: List[Sentence], scores: torch.Tensor, lengths: List[int]):
    """Returns all scores for each tag in tag dictionary."""
    scores = scores.numpy()
    tokens = [token for sentence in sentences for token in sentence]
    print('Number of tokens in batch:', len(tokens))
    prob_all_tags = [
        [
            Label(token, tagger.label_dictionary.get_item_for_index(score_id), score)
            for score_id, score in enumerate(score_dist)
        ]
        for score_dist, token in zip(scores, tokens)
    ]
    print('Length of prob_all_tags:', len(prob_all_tags))
    prob_tags_per_sentence = []
    previous = 0
    for i, length in enumerate(lengths):
        print(f'Length range of Sentence {i}: {previous} to {previous + length}')
        prob_tags_per_sentence.append(prob_all_tags[previous : previous + length])
        previous = length
    return prob_tags_per_sentence


# 1. get the corpus
corpus = flair.datasets.ONTONOTES()

# 2. what label do we want to predict?
label_type = 'ner'

# 3. make the label dictionary from the corpus
label_dict = corpus.make_label_dictionary(label_type=label_type, add_unk=False)

# 4. initialize fine-tuneable transformer embeddings WITH document context
embeddings = TransformerWordEmbeddings(model='xlm-roberta-large',
                                       layers="-1",
                                       subtoken_pooling="first",
                                       fine_tune=True,
                                       use_context=True,
                                       )

# 5. initialize bare-bones sequence tagger (no CRF, no RNN, no reprojection)
tagger = SequenceTagger(hidden_size=256,
                        embeddings=embeddings,
                        tag_dictionary=label_dict,
                        tag_type='ner',
                        use_crf=False,
                        use_rnn=False,
                        reproject_embeddings=False,
                        )

label_name = 'ner'
sentences = corpus.train
mini_batch_size = 10
verbose = True

with torch.no_grad():
    Sentence.set_context_for_sentences(cast(List[Sentence], sentences))

    # filter empty sentences
    sentences = [sentence for sentence in sentences if len(sentence) > 0]

    # reverse sort all sequences by their length
    reordered_sentences = sorted(sentences, key=len, reverse=True)

    dataloader = DataLoader(
        dataset=FlairDatapointDataset(reordered_sentences),
        batch_size=mini_batch_size,
    )

    # progress bar for verbosity
    if verbose:
        dataloader = tqdm(dataloader, desc="Batch inference")

    overall_loss = torch.zeros(1, device=flair.device)
    label_count = 0

    for batch in dataloader:
        # stop if all sentences are empty
        if not batch:
            continue

        # get features from forward propagation
        sentence_tensor, lengths = tagger._prepare_tensors(batch)
        features = tagger.forward(sentence_tensor, lengths)

        # remove previously predicted labels of this type
        for sentence in batch:
            sentence.remove_labels(label_name)
        break

softmax_batch = F.softmax(features, dim=1).cpu()
lengths = [len(sentence) for sentence in batch]
all_tags = _all_scores_for_token(batch, softmax_batch, lengths)
Expected behavior
If the offset is accumulated correctly in the _all_scores_for_token() function, the output should be:
Number of tokens in batch: 1761
Length of prob_all_tags: 1761
Length range of Sentence 0: 0 to 210
Length range of Sentence 1: 210 to 415
Length range of Sentence 2: 415 to 604
Length range of Sentence 3: 604 to 792
Length range of Sentence 4: 792 to 965
Length range of Sentence 5: 965 to 1133
Length range of Sentence 6: 1133 to 1300
Length range of Sentence 7: 1300 to 1457
Length range of Sentence 8: 1457 to 1613
Length range of Sentence 9: 1613 to 1761
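As a quick sanity check (again a minimal, Flair-independent sketch, using the sentence lengths implied by the expected ranges above), the corrected update produces ranges that are consecutive, non-overlapping, and cover all 1761 tokens:

# sentence lengths implied by the expected ranges above
lengths = [210, 205, 189, 188, 173, 168, 167, 157, 156, 148]

previous = 0
ranges = []
for length in lengths:
    ranges.append((previous, previous + length))
    previous += length  # the corrected, accumulating update

# each range starts where the previous one ends, and the last ends at the total
assert ranges[0][0] == 0
assert all(ranges[i][1] == ranges[i + 1][0] for i in range(len(ranges) - 1))
assert previous == sum(lengths) == 1761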
Logs and Stack traces
No response
Screenshots
No response
Additional Context
No response
Environment
Versions:
Flair: 0.13.1
Pytorch: 2.2.1+cu121
Transformers: 4.40.1
GPU: True