Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
remove changes to contrib/text/embedding.py
Browse files: browse the repository at this point in the history
  • Loading branch information
haojin2 committed Jul 31, 2019
1 parent c104695 commit 91d64d2
Showing 1 changed file with 7 additions and 22 deletions.
29 changes: 7 additions & 22 deletions python/mxnet/contrib/text/embedding.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -35,9 +35,6 @@
from ... import ndarray as nd
from ... import registry
from ... import base
from ...util import is_np_array
from ... import numpy as _mx_np
from ... import numpy_extension as _mx_npx


def register(embedding_cls):
Expand Down Expand Up @@ -298,15 +295,12 @@ def _load_embedding(self, pretrained_file_path, elem_delim, init_unknown_vec, en
tokens.add(token)

self._vec_len = vec_len
array_fn = _mx_np.array if is_np_array() else nd.array
self._idx_to_vec = array_fn(all_elems).reshape((-1, self.vec_len))
self._idx_to_vec = nd.array(all_elems).reshape((-1, self.vec_len))

if loaded_unknown_vec is None:
init_val = init_unknown_vec(shape=self.vec_len)
self._idx_to_vec[C.UNKNOWN_IDX] =\
init_val.as_np_ndarray() if is_np_array() else init_val
self._idx_to_vec[C.UNKNOWN_IDX] = init_unknown_vec(shape=self.vec_len)
else:
self._idx_to_vec[C.UNKNOWN_IDX] = array_fn(loaded_unknown_vec)
self._idx_to_vec[C.UNKNOWN_IDX] = nd.array(loaded_unknown_vec)

def _index_tokens_from_vocabulary(self, vocabulary):
self._token_to_idx = vocabulary.token_to_idx.copy() \
Expand Down Expand Up @@ -334,8 +328,7 @@ def _set_idx_to_vec_by_embeddings(self, token_embeddings, vocab_len, vocab_idx_t
"""

new_vec_len = sum(embed.vec_len for embed in token_embeddings)
zeros_fn = _mx_np.zeros if is_np_array() else nd.zeros
new_idx_to_vec = zeros_fn(shape=(vocab_len, new_vec_len))
new_idx_to_vec = nd.zeros(shape=(vocab_len, new_vec_len))

col_start = 0
# Concatenate all the embedding vectors in token_embeddings.
Expand Down Expand Up @@ -404,13 +397,7 @@ def get_vecs_by_tokens(self, tokens, lower_case_backup=False):
else self.token_to_idx.get(token.lower(), C.UNKNOWN_IDX)
for token in tokens]

if is_np_array():
embedding_fn = _mx_npx.embedding
array_fn = _mx_np.array
else:
embedding_fn = nd.Embedding
array_fn = nd.array
vecs = embedding_fn(array_fn(indices), self.idx_to_vec, self.idx_to_vec.shape[0],
vecs = nd.Embedding(nd.array(indices), self.idx_to_vec, self.idx_to_vec.shape[0],
self.idx_to_vec.shape[1])

return vecs[0] if to_reduce else vecs
Expand Down Expand Up @@ -438,8 +425,7 @@ def update_token_vectors(self, tokens, new_vectors):
if not isinstance(tokens, list):
tokens = [tokens]
if len(new_vectors.shape) == 1:
expand_dims_fn = _mx_np.expand_dims if is_np_array() else nd.expand_dims
new_vectors = expand_dims_fn(new_vectors, axis=0)
new_vectors = new_vectors.expand_dims(0)

else:
assert isinstance(new_vectors, nd.NDArray) and len(new_vectors.shape) == 2, \
Expand All @@ -458,8 +444,7 @@ def update_token_vectors(self, tokens, new_vectors):
'`unknown_token` %s in `tokens`. This is to avoid unintended '
'updates.' % (token, self.idx_to_token[C.UNKNOWN_IDX]))

array_fn = _mx_np.array if is_np_array() else nd.array
self._idx_to_vec[array_fn(indices)] = new_vectors
self._idx_to_vec[nd.array(indices)] = new_vectors

@classmethod
def _check_pretrained_file_names(cls, pretrained_file_name):
Expand Down

0 comments on commit 91d64d2

Please sign in to comment.