Skip to content

Commit

Permalink
Speed up startup, massively decrease ram usage
Browse files Browse the repository at this point in the history
By only instanciating the `counterfitted_GLOVE_embedding` when
necessary, the startup time gets cut by two thirds, while the ram usage
decreases by at least two gigabytes.
  • Loading branch information
duesenfranz committed Feb 9, 2022
1 parent 33c9873 commit 613362e
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ class ThoughtVector(SentenceEncoder):
"""

def __init__(
self, embedding=WordEmbedding.counterfitted_GLOVE_embedding(), **kwargs
self, embedding=None, **kwargs
):
if embedding is None:
embedding = WordEmbedding.counterfitted_GLOVE_embedding()
if not isinstance(embedding, AbstractWordEmbedding):
raise ValueError(
"`embedding` object must be of type `textattack.shared.AbstractWordEmbedding`."
Expand Down
4 changes: 3 additions & 1 deletion textattack/constraints/semantics/word_embedding_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,16 @@ class WordEmbeddingDistance(Constraint):

def __init__(
self,
embedding=WordEmbedding.counterfitted_GLOVE_embedding(),
embedding=None,
include_unknown_words=True,
min_cos_sim=None,
max_mse_dist=None,
cased=False,
compare_against_original=True,
):
super().__init__(compare_against_original)
if embedding is None:
embedding = WordEmbedding.counterfitted_GLOVE_embedding()
self.include_unknown_words = include_unknown_words
self.cased = cased

Expand Down
4 changes: 3 additions & 1 deletion textattack/transformations/word_swaps/word_swap_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,12 @@ class WordSwapEmbedding(WordSwap):
def __init__(
self,
max_candidates=15,
embedding=WordEmbedding.counterfitted_GLOVE_embedding(),
embedding=None,
**kwargs
):
super().__init__(**kwargs)
if embedding is None:
embedding = WordEmbedding.counterfitted_GLOVE_embedding()
self.max_candidates = max_candidates
if not isinstance(embedding, AbstractWordEmbedding):
raise ValueError(
Expand Down

0 comments on commit 613362e

Please sign in to comment.