diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 3db6600b1c..227a3f5b4f 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -14,12 +14,10 @@ jobs:
       with:
         python-version: 3.6
     - name: Install Flair dependencies
-      run: |
-        pip install -e .
+      run: pip install -e .
    - name: Install unittest dependencies
      run: pip install -r requirements-dev.txt
+    - name: Show installed dependencies
+      run: pip freeze
    - name: Run tests
-      run: |
-        cd tests
-        pip freeze
-        pytest --runintegration -vv
+      run: pytest --runintegration -vv
diff --git a/examples/ner/run_ner.py b/examples/ner/run_ner.py
index c1790c10c4..cc7d277337 100644
--- a/examples/ner/run_ner.py
+++ b/examples/ner/run_ner.py
@@ -80,7 +80,7 @@ def get_flair_corpus(data_args):
     if data_args.dataset_arguments:
         dataset_args = json.loads(data_args.dataset_arguments)
 
-    if not dataset_name in ner_task_mapping:
+    if dataset_name not in ner_task_mapping:
         raise ValueError(f"Dataset name {dataset_name} is not a valid Flair datasets name!")
 
     return ner_task_mapping[dataset_name](**dataset_args)
diff --git a/flair/data.py b/flair/data.py
index d025cbd88a..a64098a811 100644
--- a/flair/data.py
+++ b/flair/data.py
@@ -489,10 +489,12 @@ def __init__(self, tokens: List[Token]):
 
     @property
     def start_pos(self) -> int:
+        assert self.tokens[0].start_position is not None
         return self.tokens[0].start_position
 
     @property
     def end_pos(self) -> int:
+        assert self.tokens[-1].end_position is not None
         return self.tokens[-1].end_position
 
     @property
@@ -574,6 +576,7 @@ def clear_embeddings(self, embedding_names: List[str] = None):
         pass
 
     def add_tag(self, tag_type: str, tag_value: str, confidence=1.0):
+        assert self.tokens[0].sentence is not None
         self.tokens[0].sentence.add_complex_label(tag_type, SpanLabel(self, value=tag_value, score=confidence))
 
 
@@ -670,9 +673,6 @@ def __init__(
         self._next_sentence: Optional[Sentence] = None
         self._position_in_dataset: Optional[typing.Tuple[Dataset, int]] = None
 
-    def get_span(self, from_id: int, to_id: int) -> Span:
-        return self.tokens[from_id : to_id + 1]
-
     def get_token(self, token_id: int) -> Optional[Token]:
         for token in self.tokens:
             if token.idx == token_id:
@@ -924,7 +924,15 @@ def to_dict(self, tag_type: str = None):
 
         return {"text": self.to_original_text(), "all labels": labels}
 
-    def __getitem__(self, subscript: int) -> Union[Token, Span]:
+    @typing.overload
+    def __getitem__(self, idx: int) -> Token:
+        ...
+
+    @typing.overload
+    def __getitem__(self, s: slice) -> Span:
+        ...
+
+    def __getitem__(self, subscript):
         if isinstance(subscript, slice):
             return Span(self.tokens[subscript])
         else:
diff --git a/flair/datasets/entity_linking.py b/flair/datasets/entity_linking.py
index 18dcb1453d..8138a95a26 100644
--- a/flair/datasets/entity_linking.py
+++ b/flair/datasets/entity_linking.py
@@ -1375,7 +1375,7 @@ def __init__(
 
         super(WSD_RAGANATO_ALL, self).__init__(
             data_folder=data_folder,
-            columns=columns,
+            column_format=columns,
             train_file=train_file,
             in_memory=in_memory,
             document_separator_token="-DOCSTART-",
@@ -1453,7 +1453,7 @@ def __init__(
 
         super(WSD_SEMCOR, self).__init__(
             data_folder=data_folder,
-            columns=columns,
+            column_format=columns,
             train_file=train_file,
             test_file=test_file,
             in_memory=in_memory,
@@ -1609,7 +1609,7 @@ def __init__(
 
         super(WSD_MASC, self).__init__(
             data_folder=data_folder,
-            columns=columns,
+            column_format=columns,
             train_file=train_file,
             test_file=test_file,
             in_memory=in_memory,
@@ -1690,7 +1690,7 @@ def __init__(
 
         super(WSD_OMSTI, self).__init__(
             data_folder=data_folder,
-            columns=columns,
+            column_format=columns,
             train_file=train_file,
             test_file=test_file,
             in_memory=in_memory,
@@ -1768,7 +1768,7 @@ def __init__(
 
         super(WSD_TRAINOMATIC, self).__init__(
             data_folder=data_folder,
-            columns=columns,
+            column_format=columns,
             train_file=train_file,
             test_file=test_file,
             in_memory=in_memory,
diff --git a/flair/datasets/sequence_labeling.py b/flair/datasets/sequence_labeling.py
index 35e226d849..12b6732274 100644
--- a/flair/datasets/sequence_labeling.py
+++ b/flair/datasets/sequence_labeling.py
@@ -388,7 +388,7 @@ def _convert_lines_to_sentence(
     ):
         sentence: Sentence = Sentence()
 
-        token = None
+        token: Optional[Token] = None
         filtered_lines = []
         comments = []
         for line in lines:
@@ -400,7 +400,7 @@
                 filtered_lines.append(line)
 
                 # otherwise, this line is a token. parse and add to sentence
-                token: Token = self._parse_token(line, word_level_tag_columns, token)
+                token = self._parse_token(line, word_level_tag_columns, token)
                 sentence.add_token(token)
 
         # check if this sentence is a document boundary
@@ -415,14 +415,14 @@
                         re.split(self.column_delimiter, line.rstrip())[span_column] for line in filtered_lines
                     ]
                     predicted_spans = get_spans_from_bio(bioes_tags)
-                    for predicted_span in predicted_spans:
-                        span = Span(sentence[predicted_span[0][0] : predicted_span[0][-1] + 1])
-                        value = self._remap_label(predicted_span[2])
+                    for span_indices, score, label in predicted_spans:
+                        span = sentence[span_indices[0] : span_indices[-1] + 1]
+                        value = self._remap_label(label)
                         sentence.add_complex_label(
                             typename=span_level_tag_columns[span_column],
-                            label=SpanLabel(span=span, value=value, score=predicted_span[1]),
+                            label=SpanLabel(span=span, value=value, score=score),
                         )
-                except:
+                except Exception:
                     pass
                     # log.warning(f"--\nUnparseable sentence: {''.join(lines)}--\n")
 
@@ -452,7 +452,7 @@ def _parse_token(self, line: str, column_name_map: Dict[int, str], last_token: O
        fields: List[str] = re.split(self.column_delimiter, line.rstrip())
 
        # get head_id if exists (only in dependency parses)
-        head_id = fields[self.head_id_column] if self.head_id_column else None
+        head_id = int(fields[self.head_id_column]) if self.head_id_column else None
 
        # initialize token
        token = Token(fields[self.text_column], head_id=head_id, whitespace_after=self.default_whitespace_after)
@@ -494,6 +494,7 @@
        if last_token is None:
            start = 0
        else:
+            assert last_token.end_pos is not None
            start = last_token.end_pos
            if last_token.whitespace_after:
                start += 1
diff --git a/flair/embeddings/base.py b/flair/embeddings/base.py
index 25175c231a..b92adfcba0 100644
--- a/flair/embeddings/base.py
+++ b/flair/embeddings/base.py
@@ -18,10 +18,8 @@
     AutoTokenizer,
     PretrainedConfig,
     PreTrainedTokenizer,
-    TransfoXLModel,
-    XLNetModel,
 )
-from transformers.tokenization_utils_base import LARGE_INTEGER, TruncationStrategy
+from transformers.tokenization_utils_base import LARGE_INTEGER
 
 import flair
 from flair.data import DT, Sentence, Token
diff --git a/flair/embeddings/document.py b/flair/embeddings/document.py
index 288f99148d..d41549e984 100644
--- a/flair/embeddings/document.py
+++ b/flair/embeddings/document.py
@@ -1,5 +1,5 @@
 import logging
-from typing import List, Optional, Union
+from typing import List, Union
 
 import torch
 from sklearn.feature_extraction.text import TfidfVectorizer
@@ -273,7 +273,7 @@ def __init__(
     def embedding_length(self) -> int:
         return self.__embedding_length
 
-    def _add_embeddings_internal(self, sentences: Union[List[Sentence], Sentence]):
+    def _add_embeddings_internal(self, sentences: List[Sentence]):
         """Add embeddings to all sentences in the given list of sentences.
         If embeddings are already added, update only if embeddings are non-static."""
 
@@ -283,9 +283,6 @@ def _add_embeddings_internal(self, sentences: Union[List[Sentence], Sentence]):
         if not hasattr(self, "word_dropout"):
             self.word_dropout = None
 
-        if type(sentences) is Sentence:
-            sentences = [sentences]
-
         self.rnn.zero_grad()
 
         # embed words in the sentence
@@ -596,7 +593,7 @@ def __init__(
     def embedding_length(self) -> int:
         return self.__embedding_length
 
-    def _add_embeddings_internal(self, sentences: Union[List[Sentence], Sentence]):
+    def _add_embeddings_internal(self, sentences: List[Sentence]):
         """Add embeddings to all sentences in the given list of sentences.
         If embeddings are already added, update only if embeddings are non-static."""
 
@@ -606,9 +603,6 @@ def _add_embeddings_internal(self, sentences: Union[List[Sentence], Sentence]):
         if not hasattr(self, "word_dropout"):
             self.word_dropout = None
 
-        if type(sentences) is Sentence:
-            sentences = [sentences]
-
         self.zero_grad()  # is it necessary?
 
         # embed words in the sentence
diff --git a/flair/models/__init__.py b/flair/models/__init__.py
index 6341c59455..4dbd759a61 100644
--- a/flair/models/__init__.py
+++ b/flair/models/__init__.py
@@ -3,6 +3,7 @@
 from .language_model import LanguageModel
 from .lemmatizer_model import Lemmatizer
 from .pairwise_classification_model import TextPairClassifier
+from .regexp_tagger import RegexpTagger
 from .relation_extractor_model import RelationExtractor
 from .sequence_tagger_model import MultiTagger, SequenceTagger
 from .tars_model import FewshotClassifier, TARSClassifier, TARSTagger
@@ -16,6 +17,7 @@
     "Lemmatizer",
     "TextPairClassifier",
     "RelationExtractor",
+    "RegexpTagger",
     "MultiTagger",
     "SequenceTagger",
     "WordTagger",
diff --git a/flair/models/dependency_parser_model.py b/flair/models/dependency_parser_model.py
index 8ca0ef86ab..4754e28790 100644
--- a/flair/models/dependency_parser_model.py
+++ b/flair/models/dependency_parser_model.py
@@ -79,7 +79,7 @@ def __init__(
         mlp_input_dim = self.lstm_hidden_size * 2
 
         if use_rnn == "Variational":
-            self.lstm = BiLSTM(
+            self.lstm: torch.nn.Module = BiLSTM(
                 input_size=self.lstm_input_dim,
                 hidden_size=self.lstm_hidden_size,
                 num_layers=self.lstm_layers,
@@ -142,10 +142,10 @@ def forward(self, sentences: List[Sentence]):
             sentence_tensor = self.word_dropout(sentence_tensor)
 
         if self.use_rnn:
-            sentence_tensor = pack_padded_sequence(sentence_tensor, lengths, True, False)
+            sentence_sequence = pack_padded_sequence(sentence_tensor, torch.IntTensor(lengths), True, False)
 
-            sentence_tensor, _ = self.lstm(sentence_tensor)
-            sentence_tensor, _ = pad_packed_sequence(sentence_tensor, True, total_length=seq_len)
+            sentence_sequence, _ = self.lstm(sentence_sequence)
+            sentence_tensor, _ = pad_packed_sequence(sentence_sequence, True, total_length=seq_len)
 
         # apply MLPs for arc and relations to the BiLSTM output states
         arc_h = self.mlp_arc_h(sentence_tensor)
diff --git a/flair/models/regexp_tagger.py b/flair/models/regexp_tagger.py
index 569987a99a..95c26df95a 100644
--- a/flair/models/regexp_tagger.py
+++ b/flair/models/regexp_tagger.py
@@ -1,44 +1,46 @@
 import re
+import typing
 from dataclasses import dataclass, field
 from typing import Dict, List, Tuple, Union
 
 from flair.data import Sentence, Span, SpanLabel, Token
 
 
-class RegexpTagger:
-    @dataclass
-    class TokenCollection:
-        """
-        A utility class for RegexpTagger to hold all tokens for a given Sentence and define some functionality
-        :param sentence: A Sentence object
-        """
+@dataclass
+class TokenCollection:
+    """
+    A utility class for RegexpTagger to hold all tokens for a given Sentence and define some functionality
+    :param sentence: A Sentence object
+    """
+
+    sentence: Sentence
+    __tokens_start_pos: List[int] = field(init=False, default_factory=list)
+    __tokens_end_pos: List[int] = field(init=False, default_factory=list)
 
-        sentence: Sentence
-        __tokens_start_pos: List[int] = field(init=False, default_factory=list)
-        __tokens_end_pos: List[int] = field(init=False, default_factory=list)
+    def __post_init__(self):
+        for token in self.tokens:
+            self.__tokens_start_pos.append(token.start_pos)
+            self.__tokens_end_pos.append(token.end_pos)
 
-        def __post_init__(self):
-            for token in self.tokens:
-                self.__tokens_start_pos.append(token.start_pos)
-                self.__tokens_end_pos.append(token.end_pos)
+    @property
+    def tokens(self) -> List[Token]:
+        return list(self.sentence)
 
-        @property
-        def tokens(self) -> List[Token]:
-            return list(self.sentence)
+    def get_token_span(self, span: Tuple[int, int]) -> Span:
+        """
+        Given an interval specified with start and end pos as tuple, this function returns a Span object
+        spanning the tokens included in the interval. If the interval is overlapping with a token span, a
+        ValueError is raised
 
-        def get_token_span(self, span: Tuple[int, int]) -> Span:
-            """
-            Given an interval specified with start and end pos as tuple, this function returns a Span object
-            spanning the tokens included in the interval. If the interval is overlapping with a token span, a
-            ValueError is raised
+        :param span: Start and end pos of the requested span as tuple
+        :return: A span object spanning the requested token interval
+        """
+        span_start: int = self.__tokens_start_pos.index(span[0])
+        span_end: int = self.__tokens_end_pos.index(span[1])
+        return Span(self.tokens[span_start : span_end + 1])
 
-            :param span: Start and end pos of the requested span as tuple
-            :return: A span object spanning the requested token interval
-            """
-            span_start: int = self.__tokens_start_pos.index(span[0])
-            span_end: int = self.__tokens_end_pos.index(span[1])
-            return Span(self.tokens[span_start : span_end + 1])
 
+
+class RegexpTagger:
     def __init__(self, mapping: Union[List[Tuple[str, str]], Tuple[str, str]]):
         """
         This tagger is capable of tagging sentence objects with given regexp -> label mappings.
@@ -52,7 +54,7 @@ def __init__(self, mapping: Union[List[Tuple[str, str]], Tuple[str, str]]):
 
         :param mapping: A list of tuples or a single tuple representing a mapping as regexp -> label
         """
-        self._regexp_mapping: Dict[str, re.Pattern] = {}
+        self._regexp_mapping: Dict[str, typing.Pattern] = {}
         self.register_labels(mapping=mapping)
 
     @property
@@ -112,7 +114,7 @@ def _label(self, sentence: Sentence):
         This will add a complex_label to the given sentence for every match.span() for every registered_mapping.
         If a match span overlaps with a token span an exception is raised.
         """
-        collection = RegexpTagger.TokenCollection(sentence)
+        collection = TokenCollection(sentence)
 
         for label, pattern in self._regexp_mapping.items():
             for match in pattern.finditer(sentence.to_original_text()):
diff --git a/flair/models/sequence_tagger_model.py b/flair/models/sequence_tagger_model.py
index 83fda66208..0e9297487a 100644
--- a/flair/models/sequence_tagger_model.py
+++ b/flair/models/sequence_tagger_model.py
@@ -11,7 +11,7 @@
 from tqdm import tqdm
 
 import flair.nn
-from flair.data import DataPoint, Dictionary, Label, Sentence, Span, SpanLabel
+from flair.data import Dictionary, Label, Sentence, Span, SpanLabel
 from flair.datasets import DataLoader, FlairDatapointDataset
 from flair.embeddings import StackedEmbeddings, TokenEmbeddings
 from flair.file_utils import cached_path, unzip_file
@@ -31,7 +31,7 @@ def __init__(
         tag_dictionary: Dictionary,
         tag_type: str,
         use_rnn: bool = True,
-        rnn: Optional[torch.nn.Module] = None,
+        rnn: Optional[torch.nn.RNN] = None,
         rnn_type: str = "LSTM",
         tag_format: str = "BIOES",
         hidden_size: int = 256,
@@ -153,7 +153,7 @@ def __init__(
         # ----- RNN layer -----
         if use_rnn:
             # If shared RNN provided, else create one for model
-            self.rnn = (
+            self.rnn: torch.nn.RNN = (
                 rnn
                 if rnn
                 else self.RNN(
@@ -200,7 +200,7 @@ def label_type(self):
         return self.tag_type
 
-    def _init_loss_weights(self, loss_weights: dict) -> torch.tensor:
+    def _init_loss_weights(self, loss_weights: Dict[str, float]) -> torch.Tensor:
         """
         Intializes the loss weights based on given dictionary:
         :param loss_weights: dictionary - contains loss weights
@@ -220,11 +220,11 @@ def _init_initial_hidden_state(self, num_directions: int):
         """
         hs_initializer = torch.nn.init.xavier_normal_
         lstm_init_h = torch.nn.Parameter(
-            torch.randn(self.nlayers * num_directions, self.hidden_size),
+            torch.randn(self.rnn.num_layers * num_directions, self.hidden_size),
             requires_grad=True,
         )
         lstm_init_c = torch.nn.Parameter(
-            torch.randn(self.nlayers * num_directions, self.hidden_size),
+            torch.randn(self.rnn.num_layers * num_directions, self.hidden_size),
             requires_grad=True,
         )
 
@@ -237,7 +237,7 @@ def RNN(
         hidden_size: int,
         bidirectional: bool,
         rnn_input_dim: int,
-    ) -> torch.nn.Module:
+    ) -> torch.nn.RNN:
         """
         Static wrapper function returning an RNN instance from PyTorch
         :param rnn_type: Type of RNN from torch.nn
@@ -277,15 +277,17 @@ def forward(self, sentences: Union[List[Sentence], Sentence]):
         Forward propagation through network. Returns gold labels of batch in addition.
         :param sentences: Batch of current sentences
         """
+        if not isinstance(sentences, list):
+            sentences = [sentences]
         self.embeddings.embed(sentences)
 
         # make a zero-padded tensor for the whole sentence
         lengths, sentence_tensor = self._make_padded_tensor_for_batch(sentences)
 
         # sort tensor in decreasing order based on lengths of sentences in batch
-        lengths = lengths.sort(dim=0, descending=True)
-        sentences = [sentences[i] for i in lengths.indices]
-        sentence_tensor = sentence_tensor[lengths.indices]
+        sorted_lengths, length_indices = lengths.sort(dim=0, descending=True)
+        sentences = [sentences[i] for i in length_indices]
+        sentence_tensor = sentence_tensor[length_indices]
 
         # ----- Forward Propagation -----
         if self.use_dropout:
@@ -299,7 +301,7 @@ def forward(self, sentences: Union[List[Sentence], Sentence]):
             sentence_tensor = self.embedding2nn(sentence_tensor)
 
         if self.use_rnn:
-            packed = pack_padded_sequence(sentence_tensor, lengths.values, batch_first=True, enforce_sorted=False)
+            packed = pack_padded_sequence(sentence_tensor, sorted_lengths, batch_first=True, enforce_sorted=False)
             rnn_output, hidden = self.rnn(packed)
             sentence_tensor, output_lengths = pad_packed_sequence(rnn_output, batch_first=True)
 
@@ -316,9 +318,9 @@ def forward(self, sentences: Union[List[Sentence], Sentence]):
         # -- A tensor of shape (aggregated sequence length for all sentences in batch, tagset size) for linear layer
         if self.use_crf:
             features = self.crf(features)
-            scores = (features, lengths, self.crf.transitions)
+            scores = (features, sorted_lengths, self.crf.transitions)
         else:
-            scores = self._get_scores_from_features(features, lengths)
+            scores = self._get_scores_from_features(features, sorted_lengths)
 
         # get the gold labels
         gold_labels = self._get_gold_labels(sentences)
@@ -368,11 +370,10 @@ def _make_padded_tensor_for_batch(self, sentences: List[Sentence]) -> Tuple[torc
                 self.embeddings.embedding_length,
             ]
         )
-        lengths: torch.Tensor = torch.tensor(lengths, dtype=torch.long)
-        return lengths, sentence_tensor
+        return torch.tensor(lengths, dtype=torch.long), sentence_tensor
 
     @staticmethod
-    def _get_scores_from_features(features: torch.tensor, lengths: torch.tensor):
+    def _get_scores_from_features(features: torch.Tensor, lengths: torch.Tensor):
         """
         Trims current batch tensor in shape (batch size, sequence length, tagset size) in such a way that all
         pads are going to be removed.
@@ -380,13 +381,13 @@ def _get_scores_from_features(features: torch.tensor, lengths: torch.tensor):
         :param lengths: length from each sentence in batch in order to trim padding tokens
         """
         features_formatted = []
-        for feat, length in zip(features, lengths.values):
+        for feat, length in zip(features, lengths):
             features_formatted.append(feat[:length])
 
         scores = torch.cat(features_formatted)
 
         return scores
 
-    def _get_gold_labels(self, sentences: Union[List[DataPoint], DataPoint]):
+    def _get_gold_labels(self, sentences: Union[List[Sentence], Sentence]):
         """
         Extracts gold labels from each sentence.
         :param sentences: List of sentences in batch
@@ -459,18 +460,15 @@ def predict(
             )
 
             # progress bar for verbosity
             if verbose:
-                dataloader = tqdm(dataloader)
+                dataloader = tqdm(dataloader, desc="Batch inference")
 
-            overall_loss = 0
+            overall_loss = torch.zeros(1, device=flair.device)
             batch_no = 0
             label_count = 0
             for batch in dataloader:
                 batch_no += 1
 
-                if verbose:
-                    dataloader.set_description(f"Inferencing on batch {batch_no}")
-
                 # stop if all sentences are empty
                 if not batch:
                     continue
@@ -490,8 +488,8 @@
 
                 # Sort batch in same way as forward propagation
                 lengths = torch.LongTensor([len(sentence) for sentence in batch])
-                lengths = lengths.sort(dim=0, descending=True)
-                batch = [batch[i] for i in lengths.indices]
+                _, sort_indices = lengths.sort(dim=0, descending=True)
+                batch = [batch[i] for i in sort_indices]
 
                 # make predictions
                 if self.use_crf:
@@ -530,7 +528,7 @@
         if return_loss:
             return overall_loss, label_count
 
-    def _standard_inference(self, features: torch.tensor, batch: list, probabilities_for_all_classes: bool):
+    def _standard_inference(self, features: torch.Tensor, batch: List[Sentence], probabilities_for_all_classes: bool):
         """
         Softmax over emission scores from forward propagation.
         :param features: sentence tensor from forward propagation
@@ -560,7 +558,7 @@
 
         return predictions, all_tags
 
-    def _all_scores_for_token(self, scores: torch.tensor, lengths: list):
+    def _all_scores_for_token(self, scores: torch.Tensor, lengths: List[int]):
         """
         Returns all scores for each tag in tag dictionary.
         :param scores: Scores for current sentence.
@@ -922,15 +920,15 @@ def _fetch_model(model_name) -> str:
                 library_version=flair.__version__,
                 cache_dir=flair.cache_root / "models" / model_folder,
             )
-        except HTTPError as e:
+        except HTTPError:
             # output information
             log.error("-" * 80)
             log.error(
                 f"ACHTUNG: The key '{model_name}' was neither found on the ModelHub nor is this a valid path to a file on your system!"
             )
             # log.error(f" - Error message: {e}")
-            log.error(f" -> Please check https://huggingface.co/models?filter=flair for all available models.")
-            log.error(f" -> Alternatively, point to a model file on your local drive.")
+            log.error(" -> Please check https://huggingface.co/models?filter=flair for all available models.")
+            log.error(" -> Alternatively, point to a model file on your local drive.")
             log.error("-" * 80)
             Path(flair.cache_root / "models" / model_folder).rmdir()  # remove folder again if not valid
 
@@ -1050,7 +1048,7 @@ def load(cls, model_names: Union[List[str], str]):
             model_names = [model_names]
 
         taggers = {}
-        models = []
+        models: List[SequenceTagger] = []
 
         # load each model
         for model_name in model_names:
diff --git a/flair/models/sequence_tagger_utils/crf.py b/flair/models/sequence_tagger_utils/crf.py
index 9b65e16ffa..3b37d1aeec 100644
--- a/flair/models/sequence_tagger_utils/crf.py
+++ b/flair/models/sequence_tagger_utils/crf.py
@@ -32,15 +32,14 @@ def __init__(self, tag_dictionary, tagset_size: int, init_from_state_dict: bool)
         self.transitions.detach()[:, tag_dictionary.get_idx_for_item(STOP_TAG)] = -10000
         self.to(flair.device)
 
-    def forward(self, features: torch.tensor) -> torch.tensor:
+    def forward(self, features: torch.Tensor) -> torch.Tensor:
         """
         Forward propagation of Conditional Random Field.
         :param features: output from RNN / Linear layer in shape (batch size, seq len, hidden size)
         :return: CRF scores (emission scores for each token + transitions prob from previous state) in
         shape (batch_size, seq len, tagset size, tagset size)
         """
-        batch_size = features.size(0)
-        seq_len = features.size(1)
+        batch_size, seq_len = features.size()[:2]
 
         emission_scores = features
         emission_scores = emission_scores.unsqueeze(-1).expand(batch_size, seq_len, self.tagset_size, self.tagset_size)
diff --git a/flair/models/sequence_tagger_utils/viterbi.py b/flair/models/sequence_tagger_utils/viterbi.py
index 9618b97ac4..793f528d0e 100644
--- a/flair/models/sequence_tagger_utils/viterbi.py
+++ b/flair/models/sequence_tagger_utils/viterbi.py
@@ -1,3 +1,5 @@
+from typing import Tuple
+
 import numpy as np
 import torch
 import torch.nn
@@ -26,7 +28,7 @@ def __init__(self, tag_dictionary: Dictionary):
         self.start_tag = tag_dictionary.get_idx_for_item(START_TAG)
         self.stop_tag = tag_dictionary.get_idx_for_item(STOP_TAG)
 
-    def forward(self, features_tuple: tuple, targets: torch.tensor) -> torch.tensor:
+    def forward(self, features_tuple: tuple, targets: torch.Tensor) -> torch.Tensor:
         """
         Forward propagation of Viterbi Loss
 
@@ -46,18 +48,18 @@ def forward(self, features_tuple: tuple, targets: torch.tensor) -> torch.tensor:
         # scores_at_targets[range(features.shape[0]), lengths.values -1]
         # Squeeze crf scores matrices in 1-dim shape and gather scores at targets by matrix indices
         scores_at_targets = torch.gather(features.view(batch_size, seq_len, -1), 2, targets_matrix_indices)
-        scores_at_targets = pack_padded_sequence(scores_at_targets, lengths.values, batch_first=True)[0]
+        scores_at_targets = pack_padded_sequence(scores_at_targets, lengths, batch_first=True)[0]
         transitions_to_stop = transitions[
             np.repeat(self.stop_tag, features.shape[0]),
-            [target[length - 1] for target, length in zip(targets, lengths.values)],
+            [target[length - 1] for target, length in zip(targets, lengths)],
         ]
         gold_score = scores_at_targets.sum() + transitions_to_stop.sum()
 
         scores_upto_t = torch.zeros(batch_size, self.tagset_size, device=flair.device)
 
-        for t in range(max(lengths.values)):
+        for t in range(max(lengths)):
             batch_size_t = sum(
-                [l > t for l in lengths.values]
+                [length > t for length in lengths]
             )  # since batch is ordered, we can save computation time by reducing our effective batch_size
 
             if t == 0:
@@ -91,7 +93,7 @@ def _log_sum_exp(tensor, dim):
         m_expanded = m.unsqueeze(dim).expand_as(tensor)
         return m + torch.log(torch.sum(torch.exp(tensor - m_expanded), dim))
 
-    def _format_targets(self, targets: torch.tensor, lengths: torch.tensor):
+    def _format_targets(self, targets: torch.Tensor, lengths: torch.IntTensor):
         """
         Formats targets into matrix indices.
         CRF scores contain per sentence, per token a (tagset_size x tagset_size) matrix, containing emission score for
@@ -106,12 +108,12 @@
         targets_per_sentence = []
 
         targets_list = targets.tolist()
-        for cut in lengths.values:
+        for cut in lengths:
             targets_per_sentence.append(targets_list[:cut])
             targets_list = targets_list[cut:]
 
         for t in targets_per_sentence:
-            t += [self.tag_dictionary.get_idx_for_item(STOP_TAG)] * (max(lengths.values) - len(t))
+            t += [self.tag_dictionary.get_idx_for_item(STOP_TAG)] * (int(lengths.max().item()) - len(t))
 
         matrix_indices = list(
             map(
@@ -138,7 +140,7 @@ def __init__(self, tag_dictionary: Dictionary):
         self.start_tag = tag_dictionary.get_idx_for_item(START_TAG)
         self.stop_tag = tag_dictionary.get_idx_for_item(STOP_TAG)
 
-    def decode(self, features_tuple: tuple, probabilities_for_all_classes: bool) -> (List, List):
+    def decode(self, features_tuple: tuple, probabilities_for_all_classes: bool) -> Tuple[List, List]:
         """
         Decoding function returning the most likely sequence of tags.
         :param features_tuple: CRF scores from forward method in shape (batch size, seq len, tagset size, tagset size),
@@ -163,8 +165,8 @@
         )
 
         for t in range(seq_len):
-            batch_size_t = sum([l > t for l in lengths.values])  # effective batch size (sans pads) at this timestep
-            terminates = [i for i, l in enumerate(lengths.values) if l == t + 1]
+            batch_size_t = sum([length > t for length in lengths])  # effective batch size (sans pads) at this timestep
+            terminates = [i for i, length in enumerate(lengths) if length == t + 1]
 
             if t == 0:
                 scores_upto_t[:batch_size_t, t] = features[:batch_size_t, t, :, self.start_tag]
@@ -206,7 +208,7 @@
         confidences = torch.max(scores, dim=2)
 
         tags = []
-        for tag_seq, tag_seq_conf, length_seq in zip(decoded, confidences.values, lengths.values):
+        for tag_seq, tag_seq_conf, length_seq in zip(decoded, confidences.values, lengths):
            tags.append(
                [
                    Label(self.tag_dictionary.get_item_for_index(tag), conf.item())
@@ -219,14 +221,14 @@
 
         return tags, all_tags
 
-    def _all_scores_for_token(self, scores: torch.tensor, lengths: torch.tensor):
+    def _all_scores_for_token(self, scores: torch.Tensor, lengths: torch.IntTensor):
        """
        Returns all scores for each tag in tag dictionary.
        :param scores: Scores for current sentence.
        """
        scores = scores.numpy()
        prob_tags_per_sentence = []
-        for scores_sentence, length in zip(scores, lengths.values):
+        for scores_sentence, length in zip(scores, lengths):
            scores_sentence = scores_sentence[:length]
            prob_tags_per_sentence.append(
                [
diff --git a/flair/models/tars_model.py b/flair/models/tars_model.py
index 326286dee8..790bdc3ea3 100644
--- a/flair/models/tars_model.py
+++ b/flair/models/tars_model.py
@@ -541,7 +541,7 @@ def predict(
             if most_probable_first:
                 import operator
 
-                already_set_indices = []
+                already_set_indices: List[int] = []
 
                 sorted_x = sorted(all_detected.items(), key=operator.itemgetter(1))
                 sorted_x.reverse()
diff --git a/flair/nn/model.py b/flair/nn/model.py
index b61d15dbaa..d78a9315dc 100644
--- a/flair/nn/model.py
+++ b/flair/nn/model.py
@@ -14,7 +14,7 @@
 
 import flair
 from flair import file_utils
-from flair.data import DT, Dictionary, Label, Sentence, _iter_dataset
+from flair.data import DT, Dictionary, Label, Sentence
 from flair.datasets import DataLoader, FlairDatapointDataset
 from flair.file_utils import Tqdm
 from flair.training_utils import Result, store_embeddings
@@ -642,7 +642,7 @@ def predict(
                 return sentences
 
             if len(sentences) > mini_batch_size:
-                batches = DataLoader(
+                batches: Union[DataLoader, List[List[DT]]] = DataLoader(
                     dataset=FlairDatapointDataset(reordered_sentences),
                     batch_size=mini_batch_size,
                 )
@@ -654,15 +654,15 @@
            else:
                batches = [reordered_sentences]
 
-            overall_loss = 0
+            overall_loss = torch.zeros(1, device=flair.device)
            label_count = 0
            for batch in batches:
                # stop if all sentences are empty
                if not batch:
                    continue
 
-                scores, gold_labels, data_points, label_candidates = self.forward_pass(
-                    batch, return_label_candidates=True  # type: ignore
+                scores, gold_labels, data_points, label_candidates = self.forward_pass(  # type: ignore
+                    batch, return_label_candidates=True
                )
                # remove previously predicted labels of this type
                for sentence in data_points:
diff --git a/flair/training_utils.py b/flair/training_utils.py
index bc28435922..caf43472a4 100644
--- a/flair/training_utils.py
+++ b/flair/training_utils.py
@@ -13,7 +13,7 @@
 from torch.utils.data import Dataset
 
 import flair
-from flair.data import DataPoint, Dictionary, Sentence, _iter_dataset
+from flair.data import DT, DataPoint, Dictionary, Sentence, _iter_dataset
 
 
 class Result(object):
@@ -366,7 +366,7 @@ def add_file_handler(log, output_file):
 
 
 def store_embeddings(
-    data_points: Union[List[DataPoint], Dataset], storage_mode: str, dynamic_embeddings: Optional[List[str]] = None
+    data_points: Union[List[DT], Dataset], storage_mode: str, dynamic_embeddings: Optional[List[str]] = None
 ):
 
     if isinstance(data_points, Dataset):
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 0d45fd004b..100eaddd71 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -5,5 +5,6 @@ pytest-flake8
 flake8-black
 flake8<4.0.0
 types-Deprecated
+types-dataclasses
 types-tabulate
-types-requests
\ No newline at end of file
+types-requests