add DocumentCNNEmbeddings #2141

Merged · 1 commit · Mar 13, 2021
1 change: 1 addition & 0 deletions flair/embeddings/__init__.py
@@ -26,6 +26,7 @@
from .document import DocumentTFIDFEmbeddings
from .document import DocumentRNNEmbeddings
from .document import DocumentLMEmbeddings
from .document import DocumentCNNEmbeddings
from .document import SentenceTransformerDocumentEmbeddings

# Expose image embedding classes
159 changes: 159 additions & 0 deletions flair/embeddings/document.py
@@ -688,3 +688,162 @@ def _add_embeddings_to_sentences(self, sentences: List[Sentence]):
def embedding_length(self) -> int:
"""Returns the length of the embedding vector."""
return self.model.get_sentence_embedding_dimension()


class DocumentCNNEmbeddings(DocumentEmbeddings):
def __init__(
self,
embeddings: List[TokenEmbeddings],
kernels=((100, 3), (100, 4), (100, 5)),
reproject_words: bool = True,
reproject_words_dimension: int = None,
dropout: float = 0.5,
word_dropout: float = 0.0,
locked_dropout: float = 0.0,
fine_tune: bool = True,
):
"""The constructor takes a list of embeddings to be combined.
:param embeddings: a list of token embeddings
:param kernels: list of (number of kernels, kernel size)
:param reproject_words: boolean value, indicating whether to reproject the token embeddings in a separate linear
        layer before putting them into the convolutional network or not
:param reproject_words_dimension: output dimension of reprojecting token embeddings. If None the same output
dimension as before will be taken.
:param dropout: the dropout value to be used
:param word_dropout: the word dropout value to be used, if 0.0 word dropout is not used
        :param locked_dropout: the locked dropout value to be used, if 0.0 locked dropout is not used
        :param fine_tune: if True (default), gradients are propagated into the token embeddings during training;
        if False, the token embeddings are kept static
        """
super().__init__()

self.embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=embeddings)
self.length_of_all_token_embeddings: int = self.embeddings.embedding_length

self.kernels = kernels
self.reproject_words = reproject_words

        self.static_embeddings = not fine_tune

self.embeddings_dimension: int = self.length_of_all_token_embeddings
if self.reproject_words and reproject_words_dimension is not None:
self.embeddings_dimension = reproject_words_dimension

self.word_reprojection_map = torch.nn.Linear(
self.length_of_all_token_embeddings, self.embeddings_dimension
)

# CNN
self.__embedding_length: int = sum([kernel_num for kernel_num, kernel_size in self.kernels])
self.convs = torch.nn.ModuleList(
[
torch.nn.Conv1d(self.embeddings_dimension, kernel_num, kernel_size) for kernel_num, kernel_size in self.kernels
]
)
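        # AdaptiveMaxPool1d(1) collapses the time dimension to a single value per filter
        # (max-over-time pooling), so the document embedding has one dimension per convolutional filter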
self.pool = torch.nn.AdaptiveMaxPool1d(1)

self.name = "document_cnn"

# dropouts
self.dropout = torch.nn.Dropout(dropout) if dropout > 0.0 else None
self.locked_dropout = (
LockedDropout(locked_dropout) if locked_dropout > 0.0 else None
)
self.word_dropout = WordDropout(word_dropout) if word_dropout > 0.0 else None

torch.nn.init.xavier_uniform_(self.word_reprojection_map.weight)

self.to(flair.device)

self.eval()

@property
def embedding_length(self) -> int:
return self.__embedding_length

def _add_embeddings_internal(self, sentences: Union[List[Sentence], Sentence]):
"""Add embeddings to all sentences in the given list of sentences. If embeddings are already added, update
only if embeddings are non-static."""

# TODO: remove in future versions
if not hasattr(self, "locked_dropout"):
self.locked_dropout = None
if not hasattr(self, "word_dropout"):
self.word_dropout = None

if type(sentences) is Sentence:
sentences = [sentences]

self.zero_grad() # is it necessary?

# embed words in the sentence
self.embeddings.embed(sentences)

lengths: List[int] = [len(sentence.tokens) for sentence in sentences]
longest_token_sequence_in_batch: int = max(lengths)

pre_allocated_zero_tensor = torch.zeros(
self.embeddings.embedding_length * longest_token_sequence_in_batch,
dtype=torch.float,
device=flair.device,
)

all_embs: List[torch.Tensor] = list()
for sentence in sentences:
all_embs += [
emb for token in sentence for emb in token.get_each_embedding()
]
nb_padding_tokens = longest_token_sequence_in_batch - len(sentence)

if nb_padding_tokens > 0:
t = pre_allocated_zero_tensor[
: self.embeddings.embedding_length * nb_padding_tokens
]
all_embs.append(t)

sentence_tensor = torch.cat(all_embs).view(
[
len(sentences),
longest_token_sequence_in_batch,
self.embeddings.embedding_length,
]
)
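        # padded batch tensor of shape (batch size, longest sentence in batch, token embedding length)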

        # before-CNN dropout
if self.dropout:
sentence_tensor = self.dropout(sentence_tensor)
if self.locked_dropout:
sentence_tensor = self.locked_dropout(sentence_tensor)
if self.word_dropout:
sentence_tensor = self.word_dropout(sentence_tensor)

# reproject if set
if self.reproject_words:
sentence_tensor = self.word_reprojection_map(sentence_tensor)

# push CNN
x = sentence_tensor
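        # Conv1d expects (batch, channels, sequence length), so move the embedding dimension onto the channel axis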
x = x.permute(0, 2, 1)

rep = [self.pool(torch.nn.functional.relu(conv(x))) for conv in self.convs]
outputs = torch.cat(rep, 1)

outputs = outputs.reshape(outputs.size(0), -1)

# after-CNN dropout
if self.dropout:
outputs = self.dropout(outputs)
if self.locked_dropout:
outputs = self.locked_dropout(outputs)

# extract embeddings from CNN
for sentence_no, length in enumerate(lengths):
embedding = outputs[sentence_no]

if self.static_embeddings:
embedding = embedding.detach()

sentence = sentences[sentence_no]
sentence.set_embedding(self.name, embedding)

def _apply(self, fn):
for child_module in self.children():
child_module._apply(fn)
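For reference, a minimal usage sketch of the new class (not part of the diff). It mirrors the added test further down and assumes GloVe token embeddings are available via WordEmbeddings("glove"):

from flair.data import Sentence
from flair.embeddings import WordEmbeddings, DocumentCNNEmbeddings

# any token embeddings can be stacked; GloVe is used here purely for illustration
glove = WordEmbeddings("glove")

# default kernels: 100 filters each of width 3, 4 and 5, giving a 300-dimensional document vector
document_embeddings = DocumentCNNEmbeddings([glove])

sentence = Sentence("Berlin is a great place to live.")
document_embeddings.embed(sentence)

print(sentence.get_embedding().shape)  # torch.Size([300])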
2 changes: 1 addition & 1 deletion flair/models/text_classification_model.py
@@ -53,7 +53,7 @@ def __init__(

super(TextClassifier, self).__init__()

self.document_embeddings: flair.embeddings.DocumentRNNEmbeddings = document_embeddings
self.document_embeddings: flair.embeddings.DocumentEmbeddings = document_embeddings
self.label_dictionary: Dictionary = label_dictionary
self.label_type = label_type
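
With the type hint relaxed from DocumentRNNEmbeddings to the DocumentEmbeddings base class, any document embedding, including the new DocumentCNNEmbeddings, can back a TextClassifier. A hedged training sketch under that assumption; the corpus, embeddings and output path are illustrative only:

from flair.datasets import TREC_6
from flair.embeddings import WordEmbeddings, DocumentCNNEmbeddings
from flair.models import TextClassifier
from flair.trainers import ModelTrainer

corpus = TREC_6()  # any flair corpus works here
label_dict = corpus.make_label_dictionary()

document_embeddings = DocumentCNNEmbeddings([WordEmbeddings("glove")])
classifier = TextClassifier(document_embeddings, label_dictionary=label_dict)

trainer = ModelTrainer(classifier, corpus)
trainer.train("resources/classifiers/cnn-text-classifier", max_epochs=5)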

18 changes: 18 additions & 0 deletions tests/test_embeddings.py
@@ -9,6 +9,7 @@
FlairEmbeddings,
DocumentRNNEmbeddings,
DocumentLMEmbeddings, TransformerWordEmbeddings, TransformerDocumentEmbeddings,
DocumentCNNEmbeddings,
)

from flair.data import Sentence, Dictionary
@@ -287,4 +288,21 @@ def test_transformer_document_embeddings():

sentence.clear_embeddings()

del embeddings

def test_document_cnn_embeddings():
sentence: Sentence = Sentence("I love Berlin. Berlin is a great place to live.")

embeddings: DocumentCNNEmbeddings = DocumentCNNEmbeddings(
[glove, flair_embedding], kernels=((50, 2), (50, 3))
)

embeddings.embed(sentence)

assert len(sentence.get_embedding()) == 100
assert len(sentence.get_embedding()) == embeddings.embedding_length

sentence.clear_embeddings()

assert len(sentence.get_embedding()) == 0
del embeddings