From 14d7499431e4e90efdebc832250f316c81dd019b Mon Sep 17 00:00:00 2001
From: Aston Zhang <22279212+astonzhang@users.noreply.github.com>
Date: Mon, 26 Mar 2018 23:03:14 -0700
Subject: [PATCH] Add gluon.text vocab/embedding demo (#18)

* Add word embedding example

* clean

* Add text descriptions
---
 example/gluon/word_embedding.ipynb   | 1049 ++++++++++++++++++++++++++
 python/mxnet/gluon/text/embedding.py |    6 +
 2 files changed, 1055 insertions(+)
 create mode 100644 example/gluon/word_embedding.ipynb

diff --git a/example/gluon/word_embedding.ipynb b/example/gluon/word_embedding.ipynb
new file mode 100644
index 000000000000..f3c321793ab6
--- /dev/null
+++ b/example/gluon/word_embedding.ipynb
@@ -0,0 +1,1049 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Using Pre-trained Word Embeddings"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Here we introduce how to use pre-trained word embeddings via `mxnet.gluon.text`.\n",
+ "\n",
+ "The GloVe and fastText word embeddings used in this tutorial come from the following sources:\n",
+ "\n",
+ "* GloVe project website: https://nlp.stanford.edu/projects/glove/\n",
+ "* fastText project website: https://fasttext.cc/\n",
+ "\n",
+ "Let us first import the following packages."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2018-03-27T00:03:34.447895Z",
+ "start_time": "2018-03-27T00:03:33.503038Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "from mxnet import gluon\n",
+ "from mxnet import nd\n",
+ "from mxnet.gluon import text\n",
+ "from collections import Counter"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Creating Vocabulary with Word Embeddings\n",
+ "\n",
+ "As a common use case, let us index words, attach pre-trained word embeddings to them, and use such embeddings in `gluon`, all in just a few lines of code.\n",
+ "\n",
+ "### Creating Vocabulary from Data Sets\n",
+ "\n",
+ "To begin with, suppose that we have a simple text data set in string format. We can count word frequencies in the data set."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2018-03-27T00:03:34.453636Z",
+ "start_time": "2018-03-27T00:03:34.449760Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "data = \" hello world \\n hello nice world \\n hi world \\n\"\n",
+ "counter = text.utils.count_tokens_from_str(data)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The obtained `counter` has key-value pairs whose keys are words and whose values are word frequencies. This allows us to filter out infrequent words via `Vocabulary` arguments such as `max_size` and `min_freq`, as sketched below. Suppose that we want to build indices for all the keys in `counter`. We need a `Vocabulary` instance with `counter` as its argument."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2018-03-27T00:03:34.459747Z",
+ "start_time": "2018-03-27T00:03:34.456473Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "vocab = text.vocab.Vocabulary(counter)"
+ ]
+ },
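+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Here is a minimal sketch of the frequency-based filtering mentioned above, assuming only the `min_freq` argument of `Vocabulary`: keeping words that appear at least twice should index only 'world' (frequency 3) and 'hello' (frequency 2), plus the special unknown token."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# A sketch of frequency-based filtering: words whose counts fall below\n",
+ "# min_freq are left out of the index, so len(vocab_min2) should be 3.\n",
+ "vocab_min2 = text.vocab.Vocabulary(counter, min_freq=2)\n",
+ "len(vocab_min2)"
+ ]
+ },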
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "To attach word embeddings to the indexed words in `vocab`, let us go on to create a fastText word embedding instance by specifying the embedding name `fasttext` and the pre-trained file name `wiki.simple.vec`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2018-03-27T00:03:53.199585Z",
+ "start_time": "2018-03-27T00:03:34.462702Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/Users/astonz/WorkDocs/Programs/git_repo/mxnet/python/mxnet/gluon/text/embedding.py:264: UserWarning: At line 1 of the pre-trained token embedding file: token 111051 with 1-dimensional vector [300.0] is likely a header and is skipped.\n",
+ " 'skipped.' % (line_num, token, elems))\n"
+ ]
+ }
+ ],
+ "source": [
+ "fasttext_simple = text.embedding.create('fasttext', file_name='wiki.simple.vec')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now we can attach the word embedding `fasttext_simple` to the indexed words in `vocab`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2018-03-27T00:03:53.214582Z",
+ "start_time": "2018-03-27T00:03:53.201953Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "vocab.set_embedding(fasttext_simple)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "To see the other pre-trained file names available under the fastText word embedding, we can use `text.embedding.get_file_names`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2018-03-27T00:03:53.240556Z",
+ "start_time": "2018-03-27T00:03:53.217839Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "['crawl-300d-2M.vec',\n",
+ " 'wiki.aa.vec',\n",
+ " 'wiki.ab.vec',\n",
+ " 'wiki.ace.vec',\n",
+ " 'wiki.ady.vec']"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "text.embedding.get_file_names('fasttext')[:5]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The created vocabulary `vocab` includes four different words and a special unknown token. Let us check the size of `vocab`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2018-03-27T00:03:53.250542Z",
+ "start_time": "2018-03-27T00:03:53.243313Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "5"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "len(vocab)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "By default, the vector of any token that is unknown to `vocab` is a zero vector. Its length is equal to the vector dimension of the fastText word embeddings: 300."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2018-03-27T00:03:53.262146Z",
+ "start_time": "2018-03-27T00:03:53.253051Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(300,)"
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "vocab.embedding['beautiful'].shape"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The first five elements of the vector of any unknown token are zeros."
+ ]
+ },
0.]\n", + "" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "vocab.embedding['beautiful'][:5]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let us check the shape of the vectors of words 'hello' and 'world' from `vocab`." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "ExecuteTime": { + "end_time": "2018-03-27T00:03:53.283862Z", + "start_time": "2018-03-27T00:03:53.276282Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(2, 300)" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "vocab.embedding['hello', 'world'].shape" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "ExecuteTime": { + "end_time": "2018-03-26T23:29:07.340108Z", + "start_time": "2018-03-26T23:29:07.334790Z" + } + }, + "source": [ + "We can access the first five elements of the vectors of 'hello' and 'world'." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "ExecuteTime": { + "end_time": "2018-03-27T00:03:53.296482Z", + "start_time": "2018-03-27T00:03:53.287022Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "\n", + "[[ 0.39567 0.21454 -0.035389 -0.24299 -0.095645 ]\n", + " [ 0.10444 -0.10858 0.27212 0.13299 -0.33164999]]\n", + "" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "vocab.embedding['hello', 'world'][:, :5]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Using Pre-trained Word Embeddings in `gluon.nn.Embedding`\n", + "\n", + "To demonstrate how to use pre-trained word embeddings in the `gluon` package, let us first obtain indices of the words 'hello' and 'world'." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "ExecuteTime": { + "end_time": "2018-03-27T00:03:53.306574Z", + "start_time": "2018-03-27T00:03:53.300400Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "[2, 1]" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "vocab['hello', 'world']" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can obtain the vectors for the words 'hello' and 'world' by specifying their indices (2 and 1) and the weight matrix `vocab.embedding.idx_to_vec` in `gluon.nn.Embedding`." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "ExecuteTime": { + "end_time": "2018-03-27T00:03:53.327785Z", + "start_time": "2018-03-27T00:03:53.309979Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "\n", + "[[ 0.39567 0.21454 -0.035389 -0.24299 -0.095645 ]\n", + " [ 0.10444 -0.10858 0.27212 0.13299 -0.33164999]]\n", + "" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "input_dim, output_dim = vocab.embedding.idx_to_vec.shape\n", + "layer = gluon.nn.Embedding(input_dim, output_dim)\n", + "layer.initialize()\n", + "layer.weight.set_data(vocab.embedding.idx_to_vec)\n", + "layer(nd.array([2, 1]))[:, :5]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Creating Vocabulary from Pre-trained Word Embeddings\n", + "\n", + "We can also create vocabulary by using vocabulary of pre-trained word embeddings, such as GloVe. Below are a few pre-trained file names under the GloVe word embedding." 
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Creating Vocabulary from Pre-trained Word Embeddings\n",
+ "\n",
+ "We can also create a vocabulary from the vocabulary of pre-trained word embeddings, such as GloVe. Below are a few pre-trained file names under the GloVe word embedding."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2018-03-27T00:03:53.338638Z",
+ "start_time": "2018-03-27T00:03:53.330822Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "['glove.42B.300d.txt',\n",
+ " 'glove.6B.50d.txt',\n",
+ " 'glove.6B.100d.txt',\n",
+ " 'glove.6B.200d.txt',\n",
+ " 'glove.6B.300d.txt']"
+ ]
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "text.embedding.get_file_names('glove')[:5]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "For simplicity of demonstration, we use a smaller word embedding file, such as the 50-dimensional one."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2018-03-27T00:04:04.229138Z",
+ "start_time": "2018-03-27T00:03:53.341827Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "glove_6b50d = text.embedding.create('glove', file_name='glove.6B.50d.txt')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now we create a vocabulary using all the tokens from `glove_6b50d`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2018-03-27T00:04:06.032364Z",
+ "start_time": "2018-03-27T00:04:04.231212Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "vocab = text.vocab.Vocabulary(Counter(glove_6b50d.idx_to_token))\n",
+ "vocab.set_embedding(glove_6b50d)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Below is the size of `vocab`, which includes a special unknown token."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2018-03-27T00:04:06.042843Z",
+ "start_time": "2018-03-27T00:04:06.034933Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "400001"
+ ]
+ },
+ "execution_count": 17,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "len(vocab.idx_to_token)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We can access attributes of `vocab`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2018-03-27T00:04:06.056449Z",
+ "start_time": "2018-03-27T00:04:06.046106Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "71421\n",
+ "beautiful\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(vocab['beautiful'])\n",
+ "print(vocab.idx_to_token[71421])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Applications of Word Embeddings\n",
+ "\n",
+ "To apply word embeddings, we first need to define cosine similarity, which measures how similar two vectors are."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2018-03-27T00:04:06.067188Z",
+ "start_time": "2018-03-27T00:04:06.059379Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "def cos_sim(x, y):\n",
+ "    return nd.dot(x, y) / (nd.norm(x) * nd.norm(y))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The cosine similarity between two vectors ranges between -1 and 1. The larger the value, the more similar the two vectors are."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2018-03-27T00:04:06.272263Z",
+ "start_time": "2018-03-27T00:04:06.070098Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "[ 1.]\n",
+ "<NDArray 1 @cpu(0)>\n",
+ "\n",
+ "[-1.]\n",
+ "<NDArray 1 @cpu(0)>\n"
+ ]
+ }
+ ],
+ "source": [
+ "x = nd.array([1, 2])\n",
+ "y = nd.array([10, 20])\n",
+ "z = nd.array([-1, -2])\n",
+ "\n",
+ "print(cos_sim(x, y))\n",
+ "print(cos_sim(x, z))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Word Similarity\n",
+ "\n",
+ "Given an input word, we can find the nearest $k$ words from the vocabulary (400,000 words excluding the unknown token) by similarity. The similarity between any pair of words can be represented by the cosine similarity of their vectors.\n",
+ "\n",
+ "Note that the implementation below normalizes only the vocabulary vectors. Since the norm of the input word vector is the same for every candidate word, ranking candidates by the resulting dot products yields the same order as ranking by cosine similarity."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2018-03-27T00:04:06.292283Z",
+ "start_time": "2018-03-27T00:04:06.274721Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "def norm_vecs_by_row(x):\n",
+ "    return x / nd.sqrt(nd.sum(x * x, axis=1)).reshape((-1,1))\n",
+ "\n",
+ "def get_knn(vocab, k, word):\n",
+ "    word_vec = vocab.embedding[word].reshape((-1, 1))\n",
+ "    vocab_vecs = norm_vecs_by_row(vocab.embedding.idx_to_vec)\n",
+ "    dot_prod = nd.dot(vocab_vecs, word_vec)\n",
+ "    indices = nd.topk(dot_prod.reshape((len(vocab), )), k=k+2, ret_typ='indices')\n",
+ "    indices = [int(i.asscalar()) for i in indices]\n",
+ "    # Remove the unknown and input tokens.\n",
+ "    return vocab.to_tokens(indices[2:])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Let us find the 5 words that are most similar to 'baby' in the vocabulary (size: 400,000 words)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2018-03-27T00:04:06.687950Z",
+ "start_time": "2018-03-27T00:04:06.295771Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "['babies', 'boy', 'girl', 'newborn', 'pregnant']"
+ ]
+ },
+ "execution_count": 22,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "get_knn(vocab, 5, 'baby')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We can verify the cosine similarity between the vectors of 'baby' and 'babies'."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2018-03-27T00:04:06.698920Z",
+ "start_time": "2018-03-27T00:04:06.691103Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "\n",
+ "[ 0.83871299]\n",
+ "<NDArray 1 @cpu(0)>"
+ ]
+ },
+ "execution_count": 23,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "cos_sim(vocab.embedding['baby'], vocab.embedding['babies'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Let us find the 5 words that are most similar to 'computers' in the vocabulary."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2018-03-27T00:04:07.084357Z",
+ "start_time": "2018-03-27T00:04:06.702292Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "['computer', 'phones', 'pcs', 'machines', 'devices']"
+ ]
+ },
+ "execution_count": 24,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "get_knn(vocab, 5, 'computers')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Let us find the 5 words that are most similar to 'run' in the vocabulary."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2018-03-27T00:04:07.504323Z",
+ "start_time": "2018-03-27T00:04:07.087221Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "['running', 'runs', 'went', 'start', 'ran']"
+ ]
+ },
+ "execution_count": 25,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "get_knn(vocab, 5, 'run')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Let us find the 5 words that are most similar to 'beautiful' in the vocabulary."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2018-03-27T00:04:07.967072Z",
+ "start_time": "2018-03-27T00:04:07.507039Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "['lovely', 'gorgeous', 'wonderful', 'charming', 'beauty']"
+ ]
+ },
+ "execution_count": 26,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "get_knn(vocab, 5, 'beautiful')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Word Analogy\n",
+ "\n",
+ "We can also apply pre-trained word embeddings to the word analogy problem. For instance, \"man : woman :: son : daughter\" is an analogy. The word analogy completion problem is defined as follows: for the analogy 'a : b :: c : d', given the first three words 'a', 'b', 'c', find 'd'. The idea is to find the word whose vector is most similar to vec('c') + vec('b') - vec('a').\n",
+ "\n",
+ "In this example, we will find words by analogy from the 400,000 indexed words in `vocab`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2018-03-27T00:04:08.040101Z",
+ "start_time": "2018-03-27T00:04:07.973776Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "def get_top_k_by_analogy(vocab, k, word1, word2, word3):\n",
+ "    word_vecs = vocab.embedding[word1, word2, word3]\n",
+ "    word_diff = (word_vecs[1] - word_vecs[0] + word_vecs[2]).reshape((-1, 1))\n",
+ "    vocab_vecs = norm_vecs_by_row(vocab.embedding.idx_to_vec)\n",
+ "    dot_prod = nd.dot(vocab_vecs, word_diff)\n",
+ "    indices = nd.topk(dot_prod.reshape((len(vocab), )), k=k+1, ret_typ='indices')\n",
+ "    indices = [int(i.asscalar()) for i in indices]\n",
+ "\n",
+ "    # Filter out the unknown token.\n",
+ "    if vocab.to_tokens(indices[0]) == vocab.unknown_token:\n",
+ "        return vocab.to_tokens(indices[1:])\n",
+ "    else:\n",
+ "        return vocab.to_tokens(indices[:-1])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Complete the word analogy 'man : woman :: son : ?'."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2018-03-27T00:04:08.519697Z",
+ "start_time": "2018-03-27T00:04:08.051060Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "['daughter']"
+ ]
+ },
+ "execution_count": 28,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "get_top_k_by_analogy(vocab, 1, 'man', 'woman', 'son')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Let us verify the cosine similarity between vec('son') + vec('woman') - vec('man') and vec('daughter')."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2018-03-27T00:04:08.535690Z",
+ "start_time": "2018-03-27T00:04:08.522548Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "\n",
+ "[ 0.9658342]\n",
+ "<NDArray 1 @cpu(0)>"
+ ]
+ },
+ "execution_count": 29,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "def cos_sim_word_analogy(vocab, word1, word2, word3, word4):\n",
+ "    words = [word1, word2, word3, word4]\n",
+ "    vecs = vocab.embedding[words]\n",
+ "    return cos_sim(vecs[1] - vecs[0] + vecs[2], vecs[3])\n",
+ "\n",
+ "cos_sim_word_analogy(vocab, 'man', 'woman', 'son', 'daughter')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Complete the word analogy 'beijing : china :: tokyo : ?'."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 30,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2018-03-27T00:04:08.939664Z",
+ "start_time": "2018-03-27T00:04:08.538918Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "['japan']"
+ ]
+ },
+ "execution_count": 30,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "get_top_k_by_analogy(vocab, 1, 'beijing', 'china', 'tokyo')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Complete the word analogy 'bad : worst :: big : ?'."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 31,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2018-03-27T00:04:09.319291Z",
+ "start_time": "2018-03-27T00:04:08.942078Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "['biggest']"
+ ]
+ },
+ "execution_count": 31,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "get_top_k_by_analogy(vocab, 1, 'bad', 'worst', 'big')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Complete the word analogy 'do : did :: go : ?'."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 32,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2018-03-27T00:04:09.735225Z",
+ "start_time": "2018-03-27T00:04:09.323663Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "['went']"
+ ]
+ },
+ "execution_count": 32,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "get_top_k_by_analogy(vocab, 1, 'do', 'did', 'go')"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.1"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/python/mxnet/gluon/text/embedding.py b/python/mxnet/gluon/text/embedding.py
index 1839212ee825..fcbc6dfb88a6 100644
--- a/python/mxnet/gluon/text/embedding.py
+++ b/python/mxnet/gluon/text/embedding.py
@@ -155,6 +155,8 @@ class TokenEmbedding(object):
 
     Properties
     ----------
+    idx_to_token : list of strs
+        A list of indexed tokens where the list indices and the token indices are aligned.
     idx_to_vec : mxnet.ndarray.NDArray
         For all the indexed tokens in this embedding, this NDArray maps each token's index to an
         embedding vector.
@@ -284,6 +286,10 @@ def _load_embedding(self, pretrained_file_path, elem_delim, init_unknown_vec, en
         else:
             self._idx_to_vec[C.UNKNOWN_IDX] = nd.array(loaded_unknown_vec)
 
+    @property
+    def idx_to_token(self):
+        return self._idx_to_token
+
     @property
     def idx_to_vec(self):
         return self._idx_to_vec