Merge pull request #2611 from helpmefindaname/bf/cicd_also_check_all_formattings

collect formatting tests for flair package
alanakbik authored Feb 2, 2022
2 parents c285829 + 7437dca commit 9697dc2
Showing 17 changed files with 130 additions and 127 deletions.
10 changes: 4 additions & 6 deletions .github/workflows/ci.yml
@@ -14,12 +14,10 @@ jobs:
         with:
           python-version: 3.6
       - name: Install Flair dependencies
-        run: |
-          pip install -e .
+        run: pip install -e .
       - name: Install unittest dependencies
         run: pip install -r requirements-dev.txt
+      - name: Show installed dependencies
+        run: pip freeze
       - name: Run tests
-        run: |
-          cd tests
-          pip freeze
-          pytest --runintegration -vv
+        run: pytest --runintegration -vv
2 changes: 1 addition & 1 deletion examples/ner/run_ner.py
@@ -80,7 +80,7 @@ def get_flair_corpus(data_args):
     if data_args.dataset_arguments:
         dataset_args = json.loads(data_args.dataset_arguments)
 
-    if not dataset_name in ner_task_mapping:
+    if dataset_name not in ner_task_mapping:
         raise ValueError(f"Dataset name {dataset_name} is not a valid Flair datasets name!")
 
     return ner_task_mapping[dataset_name](**dataset_args)
16 changes: 12 additions & 4 deletions flair/data.py
@@ -489,10 +489,12 @@ def __init__(self, tokens: List[Token]):
 
     @property
     def start_pos(self) -> int:
+        assert self.tokens[0].start_position is not None
         return self.tokens[0].start_position
 
     @property
     def end_pos(self) -> int:
+        assert self.tokens[-1].end_position is not None
         return self.tokens[-1].end_position
 
     @property
@@ -574,6 +576,7 @@ def clear_embeddings(self, embedding_names: List[str] = None):
         pass
 
     def add_tag(self, tag_type: str, tag_value: str, confidence=1.0):
+        assert self.tokens[0].sentence is not None
         self.tokens[0].sentence.add_complex_label(tag_type, SpanLabel(self, value=tag_value, score=confidence))
 
 
@@ -670,9 +673,6 @@ def __init__(
         self._next_sentence: Optional[Sentence] = None
         self._position_in_dataset: Optional[typing.Tuple[Dataset, int]] = None
 
-    def get_span(self, from_id: int, to_id: int) -> Span:
-        return self.tokens[from_id : to_id + 1]
-
     def get_token(self, token_id: int) -> Optional[Token]:
         for token in self.tokens:
             if token.idx == token_id:
@@ -924,7 +924,15 @@ def to_dict(self, tag_type: str = None):
 
         return {"text": self.to_original_text(), "all labels": labels}
 
-    def __getitem__(self, subscript: int) -> Union[Token, Span]:
+    @typing.overload
+    def __getitem__(self, idx: int) -> Token:
+        ...
+
+    @typing.overload
+    def __getitem__(self, s: slice) -> Span:
+        ...
+
+    def __getitem__(self, subscript):
         if isinstance(subscript, slice):
             return Span(self.tokens[subscript])
         else:
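Side note on the new overloads: a minimal sketch of how indexing a Sentence behaves for callers after this change (the example sentence is made up); slicing also covers the same ground as the removed get_span helper.

from flair.data import Sentence

# Hypothetical example, only to illustrate the typed __getitem__ overloads above.
sentence = Sentence("The grass is green .")

token = sentence[1]   # integer index -> a single Token ("grass")
span = sentence[1:3]  # slice -> a Span covering tokens 1 and 2 ("grass is")

print(type(token).__name__, token.text)
print(type(span).__name__, span.text)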
10 changes: 5 additions & 5 deletions flair/datasets/entity_linking.py
@@ -1375,7 +1375,7 @@ def __init__(
 
         super(WSD_RAGANATO_ALL, self).__init__(
             data_folder=data_folder,
-            columns=columns,
+            column_format=columns,
             train_file=train_file,
             in_memory=in_memory,
             document_separator_token="-DOCSTART-",
@@ -1453,7 +1453,7 @@ def __init__(
 
         super(WSD_SEMCOR, self).__init__(
             data_folder=data_folder,
-            columns=columns,
+            column_format=columns,
             train_file=train_file,
             test_file=test_file,
             in_memory=in_memory,
@@ -1609,7 +1609,7 @@ def __init__(
 
         super(WSD_MASC, self).__init__(
             data_folder=data_folder,
-            columns=columns,
+            column_format=columns,
             train_file=train_file,
             test_file=test_file,
             in_memory=in_memory,
@@ -1690,7 +1690,7 @@ def __init__(
 
         super(WSD_OMSTI, self).__init__(
             data_folder=data_folder,
-            columns=columns,
+            column_format=columns,
             train_file=train_file,
             test_file=test_file,
             in_memory=in_memory,
@@ -1768,7 +1768,7 @@ def __init__(
 
         super(WSD_TRAINOMATIC, self).__init__(
             data_folder=data_folder,
-            columns=columns,
+            column_format=columns,
             train_file=train_file,
             test_file=test_file,
             in_memory=in_memory,
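For orientation, the renamed keyword matches the ColumnCorpus constructor these datasets ultimately call into; a rough sketch of such a call (the folder, file name, and column mapping below are made up):

from flair.datasets import ColumnCorpus

# Hypothetical column layout: token text in column 0, sense label in column 3.
columns = {0: "text", 3: "sense"}

corpus = ColumnCorpus(
    data_folder="resources/tasks/wsd",  # made-up folder
    column_format=columns,              # the keyword the super().__init__ calls above now use
    train_file="train.tsv",             # made-up file name
    in_memory=False,
)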
17 changes: 9 additions & 8 deletions flair/datasets/sequence_labeling.py
@@ -388,7 +388,7 @@ def _convert_lines_to_sentence(
     ):
 
         sentence: Sentence = Sentence()
-        token = None
+        token: Optional[Token] = None
         filtered_lines = []
         comments = []
         for line in lines:
@@ -400,7 +400,7 @@
             filtered_lines.append(line)
 
             # otherwise, this line is a token. parse and add to sentence
-            token: Token = self._parse_token(line, word_level_tag_columns, token)
+            token = self._parse_token(line, word_level_tag_columns, token)
             sentence.add_token(token)
 
         # check if this sentence is a document boundary
@@ -415,14 +415,14 @@
                         re.split(self.column_delimiter, line.rstrip())[span_column] for line in filtered_lines
                     ]
                     predicted_spans = get_spans_from_bio(bioes_tags)
-                    for predicted_span in predicted_spans:
-                        span = Span(sentence[predicted_span[0][0] : predicted_span[0][-1] + 1])
-                        value = self._remap_label(predicted_span[2])
+                    for span_indices, score, label in predicted_spans:
+                        span = sentence[span_indices[0] : span_indices[-1] + 1]
+                        value = self._remap_label(label)
                         sentence.add_complex_label(
                             typename=span_level_tag_columns[span_column],
-                            label=SpanLabel(span=span, value=value, score=predicted_span[1]),
+                            label=SpanLabel(span=span, value=value, score=score),
                         )
-            except:
+            except Exception:
                 pass
                 # log.warning(f"--\nUnparseable sentence: {''.join(lines)}--\n")
 
@@ -452,7 +452,7 @@ def _parse_token(self, line: str, column_name_map: Dict[int, str], last_token: O
         fields: List[str] = re.split(self.column_delimiter, line.rstrip())
 
         # get head_id if exists (only in dependency parses)
-        head_id = fields[self.head_id_column] if self.head_id_column else None
+        head_id = int(fields[self.head_id_column]) if self.head_id_column else None
 
         # initialize token
         token = Token(fields[self.text_column], head_id=head_id, whitespace_after=self.default_whitespace_after)
@@ -494,6 +494,7 @@ def _parse_token(self, line: str, column_name_map: Dict[int, str], last_token: O
         if last_token is None:
             start = 0
         else:
+            assert last_token.end_pos is not None
             start = last_token.end_pos
             if last_token.whitespace_after:
                 start += 1
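For context on the unpacking used in the loop above: get_spans_from_bio yields (token_indices, score, label) triples; a small sketch (the tags and the import location are assumptions, adjust to your flair version):

from flair.data import Sentence, SpanLabel
# Assumed import location for this flair version; adjust if get_spans_from_bio lives elsewhere.
from flair.models.sequence_tagger_model import get_spans_from_bio

sentence = Sentence("George Washington went to Washington .")
bioes_tags = ["B-PER", "E-PER", "O", "O", "S-LOC", "O"]  # made-up tags, one per token

for span_indices, score, label in get_spans_from_bio(bioes_tags):
    span = sentence[span_indices[0] : span_indices[-1] + 1]  # slicing a Sentence yields a Span
    sentence.add_complex_label("ner", SpanLabel(span=span, value=label, score=score))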
4 changes: 1 addition & 3 deletions flair/embeddings/base.py
@@ -18,10 +18,8 @@
     AutoTokenizer,
     PretrainedConfig,
     PreTrainedTokenizer,
-    TransfoXLModel,
-    XLNetModel,
 )
-from transformers.tokenization_utils_base import LARGE_INTEGER, TruncationStrategy
+from transformers.tokenization_utils_base import LARGE_INTEGER
 
 import flair
 from flair.data import DT, Sentence, Token
12 changes: 3 additions & 9 deletions flair/embeddings/document.py
@@ -1,5 +1,5 @@
 import logging
-from typing import List, Optional, Union
+from typing import List, Union
 
 import torch
 from sklearn.feature_extraction.text import TfidfVectorizer
@@ -273,7 +273,7 @@ def __init__(
     def embedding_length(self) -> int:
         return self.__embedding_length
 
-    def _add_embeddings_internal(self, sentences: Union[List[Sentence], Sentence]):
+    def _add_embeddings_internal(self, sentences: List[Sentence]):
         """Add embeddings to all sentences in the given list of sentences. If embeddings are already added, update
         only if embeddings are non-static."""
 
@@ -283,9 +283,6 @@ def _add_embeddings_internal(self, sentences: Union[List[Sentence], Sentence]):
         if not hasattr(self, "word_dropout"):
             self.word_dropout = None
 
-        if type(sentences) is Sentence:
-            sentences = [sentences]
-
         self.rnn.zero_grad()
 
         # embed words in the sentence
@@ -596,7 +593,7 @@ def __init__(
     def embedding_length(self) -> int:
         return self.__embedding_length
 
-    def _add_embeddings_internal(self, sentences: Union[List[Sentence], Sentence]):
+    def _add_embeddings_internal(self, sentences: List[Sentence]):
         """Add embeddings to all sentences in the given list of sentences. If embeddings are already added, update
         only if embeddings are non-static."""
 
@@ -606,9 +603,6 @@ def _add_embeddings_internal(self, sentences: Union[List[Sentence], Sentence]):
         if not hasattr(self, "word_dropout"):
             self.word_dropout = None
 
-        if type(sentences) is Sentence:
-            sentences = [sentences]
-
         self.zero_grad() # is it necessary?
 
         # embed words in the sentence
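Dropping the single-Sentence branch here relies on the public embed() entry point already normalizing its input to a list before _add_embeddings_internal is called; a rough usage sketch (the embedding choice is illustrative):

from flair.data import Sentence
from flair.embeddings import DocumentRNNEmbeddings, WordEmbeddings

document_embeddings = DocumentRNNEmbeddings([WordEmbeddings("glove")])  # illustrative setup

sentence = Sentence("Berlin is a city .")

# embed() accepts a single Sentence or a list and passes a List[Sentence] on,
# so the removed type check is not needed inside _add_embeddings_internal.
document_embeddings.embed(sentence)

print(sentence.embedding.shape)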
2 changes: 2 additions & 0 deletions flair/models/__init__.py
@@ -3,6 +3,7 @@
 from .language_model import LanguageModel
 from .lemmatizer_model import Lemmatizer
 from .pairwise_classification_model import TextPairClassifier
+from .regexp_tagger import RegexpTagger
 from .relation_extractor_model import RelationExtractor
 from .sequence_tagger_model import MultiTagger, SequenceTagger
 from .tars_model import FewshotClassifier, TARSClassifier, TARSTagger
@@ -16,6 +17,7 @@
     "Lemmatizer",
     "TextPairClassifier",
     "RelationExtractor",
+    "RegexpTagger",
     "MultiTagger",
     "SequenceTagger",
     "WordTagger",
8 changes: 4 additions & 4 deletions flair/models/dependency_parser_model.py
@@ -79,7 +79,7 @@ def __init__(
         mlp_input_dim = self.lstm_hidden_size * 2
 
         if use_rnn == "Variational":
-            self.lstm = BiLSTM(
+            self.lstm: torch.nn.Module = BiLSTM(
                 input_size=self.lstm_input_dim,
                 hidden_size=self.lstm_hidden_size,
                 num_layers=self.lstm_layers,
@@ -142,10 +142,10 @@ def forward(self, sentences: List[Sentence]):
             sentence_tensor = self.word_dropout(sentence_tensor)
 
         if self.use_rnn:
-            sentence_tensor = pack_padded_sequence(sentence_tensor, lengths, True, False)
+            sentence_sequence = pack_padded_sequence(sentence_tensor, torch.IntTensor(lengths), True, False)
 
-            sentence_tensor, _ = self.lstm(sentence_tensor)
-            sentence_tensor, _ = pad_packed_sequence(sentence_tensor, True, total_length=seq_len)
+            sentence_sequence, _ = self.lstm(sentence_sequence)
+            sentence_tensor, _ = pad_packed_sequence(sentence_sequence, True, total_length=seq_len)
 
         # apply MLPs for arc and relations to the BiLSTM output states
         arc_h = self.mlp_arc_h(sentence_tensor)
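For context on the forward() change: a self-contained torch sketch (shapes made up) of the pack/pad round trip, with lengths passed as an explicit IntTensor as in the patched line:

import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

batch_size, seq_len, hidden = 2, 5, 8
sentence_tensor = torch.randn(batch_size, seq_len, hidden)
lengths = [5, 3]  # true (unpadded) length of each sequence in the batch

# The positional True, False arguments correspond to batch_first=True, enforce_sorted=False.
packed = pack_padded_sequence(sentence_tensor, torch.IntTensor(lengths), True, False)

lstm = torch.nn.LSTM(hidden, hidden, batch_first=True)
packed_out, _ = lstm(packed)

# Pad back to the original sequence length so downstream layers see a regular tensor again.
unpacked, _ = pad_packed_sequence(packed_out, True, total_length=seq_len)
print(unpacked.shape)  # torch.Size([2, 5, 8])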
62 changes: 32 additions & 30 deletions flair/models/regexp_tagger.py
@@ -1,44 +1,46 @@
 import re
+import typing
 from dataclasses import dataclass, field
 from typing import Dict, List, Tuple, Union
 
 from flair.data import Sentence, Span, SpanLabel, Token
 
 
-class RegexpTagger:
-    @dataclass
-    class TokenCollection:
-        """
-        A utility class for RegexpTagger to hold all tokens for a given Sentence and define some functionality
-        :param sentence: A Sentence object
-        """
+@dataclass
+class TokenCollection:
+    """
+    A utility class for RegexpTagger to hold all tokens for a given Sentence and define some functionality
+    :param sentence: A Sentence object
+    """
 
-        sentence: Sentence
-        __tokens_start_pos: List[int] = field(init=False, default_factory=list)
-        __tokens_end_pos: List[int] = field(init=False, default_factory=list)
+    sentence: Sentence
+    __tokens_start_pos: List[int] = field(init=False, default_factory=list)
+    __tokens_end_pos: List[int] = field(init=False, default_factory=list)
 
-        def __post_init__(self):
-            for token in self.tokens:
-                self.__tokens_start_pos.append(token.start_pos)
-                self.__tokens_end_pos.append(token.end_pos)
+    def __post_init__(self):
+        for token in self.tokens:
+            self.__tokens_start_pos.append(token.start_pos)
+            self.__tokens_end_pos.append(token.end_pos)
 
-        @property
-        def tokens(self) -> List[Token]:
-            return list(self.sentence)
+    @property
+    def tokens(self) -> List[Token]:
+        return list(self.sentence)
 
-        def get_token_span(self, span: Tuple[int, int]) -> Span:
-            """
-            Given an interval specified with start and end pos as tuple, this function returns a Span object
-            spanning the tokens included in the interval. If the interval is overlapping with a token span, a
-            ValueError is raised
+    def get_token_span(self, span: Tuple[int, int]) -> Span:
+        """
+        Given an interval specified with start and end pos as tuple, this function returns a Span object
+        spanning the tokens included in the interval. If the interval is overlapping with a token span, a
+        ValueError is raised
 
-            :param span: Start and end pos of the requested span as tuple
-            :return: A span object spanning the requested token interval
-            """
-            span_start: int = self.__tokens_start_pos.index(span[0])
-            span_end: int = self.__tokens_end_pos.index(span[1])
-            return Span(self.tokens[span_start : span_end + 1])
+        :param span: Start and end pos of the requested span as tuple
+        :return: A span object spanning the requested token interval
+        """
+        span_start: int = self.__tokens_start_pos.index(span[0])
+        span_end: int = self.__tokens_end_pos.index(span[1])
+        return Span(self.tokens[span_start : span_end + 1])
 
+
+class RegexpTagger:
     def __init__(self, mapping: Union[List[Tuple[str, str]], Tuple[str, str]]):
         """
         This tagger is capable of tagging sentence objects with given regexp -> label mappings.
@@ -52,7 +54,7 @@ def __init__(self, mapping: Union[List[Tuple[str, str]], Tuple[str, str]]):
 
         :param mapping: A list of tuples or a single tuple representing a mapping as regexp -> label
         """
-        self._regexp_mapping: Dict[str, re.Pattern] = {}
+        self._regexp_mapping: Dict[str, typing.Pattern] = {}
         self.register_labels(mapping=mapping)
 
     @property
@@ -112,7 +114,7 @@ def _label(self, sentence: Sentence):
         This will add a complex_label to the given sentence for every match.span() for every registered_mapping.
         If a match span overlaps with a token span an exception is raised.
         """
-        collection = RegexpTagger.TokenCollection(sentence)
+        collection = TokenCollection(sentence)
 
         for label, pattern in self._regexp_mapping.items():
             for match in pattern.finditer(sentence.to_original_text()):
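To make the refactor concrete, a small usage sketch of the newly exported tagger (the quote regexp is illustrative, and the predict() entry point is an assumption about this flair version):

from flair.data import Sentence
from flair.models import RegexpTagger  # export added in flair/models/__init__.py above

# Illustrative mapping: tag anything enclosed in quotes with the label "quote".
tagger = RegexpTagger(mapping=(r'(["\'])(?:(?=(\\?))\2.)*?\1', "quote"))

sentence = Sentence('Peter said "cats are cool" .')

# Assumption: a predict()-style call drives _label(), which now builds the
# module-level TokenCollection instead of the old nested class.
tagger.predict(sentence)

for label in sentence.get_labels("quote"):
    print(label)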
(Diffs for the remaining 7 changed files are not shown here.)
