Skip to content

Commit

Permalink
Merge pull request #503 from zalandoresearch/GH-502-pubmed-elmo
Browse files: browse the repository at this point in the history.
GH-502: add support for pubmed ELMo model
  • Loading branch information
Alan Akbik authored Feb 15, 2019
2 parents 4b3562e + 270d6ef commit d853ca5
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion flair/embeddings.py
Original file line number | Diff line number | Diff line change
Expand Up
@@ -333,6 +333,9 @@ def __init__(self, model: str = 'original'):
if model == 'pt' or model == 'portuguese':
options_file = 'https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/contributed/pt/elmo_pt_options.json'
weight_file = 'https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/contributed/pt/elmo_pt_weights.hdf5'
if model == 'pubmed':
options_file = 'https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/contributed/pubmed/elmo_2x4096_512_2048cnn_2xhighway_options.json'
weight_file = 'https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/contributed/pubmed/elmo_2x4096_512_2048cnn_2xhighway_weights_PubMed_only.hdf5'

# put on Cuda if available
from flair import device
Expand Down
Expand Up
@@ -508,7 +511,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]):
longest_token_in_sentence = max(chars2_length)
tokens_mask = torch.zeros((len(tokens_sorted_by_length), longest_token_in_sentence),
dtype=torch.long, device=flair.device)

for i, c in enumerate(tokens_sorted_by_length):
tokens_mask[i, :chars2_length[i]] = torch.tensor(c, dtype=torch.long, device=flair.device)

Expand Down

0 comments on commit d853ca5

Please sign in to comment.