From 62cdfe1684e6265a329378ab26f817712a51367d Mon Sep 17 00:00:00 2001 From: Zhicharevich Date: Thu, 7 Nov 2019 09:31:24 +0200 Subject: [PATCH 01/10] fixed get_keras_embedding, now accepts word mapping --- docs/notebooks/keras_wrapper.ipynb | 582 ++++------------------------- gensim/models/keyedvectors.py | 90 +++-- gensim/test/test_keyedvectors.py | 30 +- 3 files changed, 155 insertions(+), 547 deletions(-) diff --git a/docs/notebooks/keras_wrapper.ipynb b/docs/notebooks/keras_wrapper.ipynb index 28f2a3e00a..a99ed262b5 100644 --- a/docs/notebooks/keras_wrapper.ipynb +++ b/docs/notebooks/keras_wrapper.ipynb @@ -29,189 +29,6 @@ "### Word2Vec" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "To use Word2Vec, we import the corresponding module." - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "scrolled": true - }, - "outputs": [], - "source": [ - "from gensim.models import word2vec" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Next we create a dummy set of sentences to train our Word2Vec model." - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "sentences = [\n", - " ['human', 'interface', 'computer'],\n", - " ['survey', 'user', 'computer', 'system', 'response', 'time'],\n", - " ['eps', 'user', 'interface', 'system'],\n", - " ['system', 'human', 'system', 'eps'],\n", - " ['user', 'response', 'time'],\n", - " ['trees'],\n", - " ['graph', 'trees'],\n", - " ['graph', 'minors', 'trees'],\n", - " ['graph', 'minors', 'survey']\n", - "]" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Then, we create the Word2Vec model by passing appropriate parameters." - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "model = word2vec.Word2Vec(sentences, size=100, min_count=1, hs=1)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "collapsed": true - }, - "source": [ - "#### Integration with Keras : Cosine Similarity Task" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "As an example of integration of Gensim's Word2Vec model with Keras, we consider a word similarity task where we compute the cosine distance as a measure of similarity between the two words." - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Using TensorFlow backend.\n" - ] - } - ], - "source": [ - "import numpy as np\n", - "from keras.engine import Input\n", - "from keras.models import Model\n", - "from keras.layers.merge import dot" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We would use the layer returned by the function `get_keras_embedding` in the Keras model." - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "wv = model.wv\n", - "embedding_layer = wv.get_keras_embedding()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Next, we construct the Keras model. " - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "WARNING:tensorflow:From /home/misha/envs/gensim/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n", - "Instructions for updating:\n", - "Colocations handled automatically by placer.\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/misha/envs/gensim/lib/python3.7/site-packages/ipykernel_launcher.py:7: UserWarning: Update your `Model` call to the Keras 2 API: `Model(inputs=[ 0:\n", - " file_data = file_data[i:]\n", - " try:\n", - " curr_str = str(file_data)\n", - " sentence_list = curr_str.split('\\n')\n", - " for sentence in sentence_list:\n", - " sentence = (sentence.strip()).lower()\n", - " texts.append(sentence)\n", - " texts_w2v.append(sentence.split(' '))\n", - " labels.append(label_id)\n", - " except:\n", - " None" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Then, we format our text samples and labels into tensors that can be fed into a neural network. To do this, we rely on Keras utilities `keras.preprocessing.text.Tokenizer` and `keras.preprocessing.sequence.pad_sequences`." + "We first load the training data.\n", + "Then, we format our text samples and labels into tensors that can be fed into a neural network. To do this, we rely on Keras utilities `keras.preprocessing.text.Tokenizer`, `keras.preprocessing.sequence.pad_sequences` and `from keras.utils.np_utils import to_categorical`.\n" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 164, "metadata": {}, "outputs": [], "source": [ + "dataset = fetch_20newsgroups(subset='train', categories=['alt.atheism', 'comp.graphics', 'sci.space'])\n", + "\n", "MAX_SEQUENCE_LENGTH = 1000\n", "\n", "# Vectorize the text samples into a 2D integer tensor\n", "tokenizer = Tokenizer()\n", - "tokenizer.fit_on_texts(texts)\n", - "sequences = tokenizer.texts_to_sequences(texts)\n", - "\n", - "# word_index = tokenizer.word_index\n", - "data = pad_sequences(sequences, maxlen=MAX_SEQUENCE_LENGTH)\n", - "labels = to_categorical(np.asarray(labels))\n", + "tokenizer.fit_on_texts(dataset.data)\n", + "sequences = tokenizer.texts_to_sequences(dataset.data)\n", "\n", - "x_train = data\n", - "y_train = labels" + "x_train = pad_sequences(sequences, maxlen=MAX_SEQUENCE_LENGTH)\n", + "y_train = to_categorical(np.asarray(dataset.target))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "As the next step, we prepare the embedding layer to be used in our actual Keras model." + "Now we train a Word2Vec model from the documents we have.\n", + "From the word2vec model we construct the embedding layer to be used in our actual Keras model.\n", + "\n", + "The Keras tokenizer object maintains an internal vocabulary (a token to index mapping), which might be different from the vocabulary gensim builds when training the word2vec model. To align the vocabularies we pass the Keras tokenizer vocabulary to the `get_keras_embedding` function" ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 165, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/misha/envs/gensim/lib/python3.7/site-packages/ipykernel_launcher.py:3: DeprecationWarning: Call to deprecated `iter` (Attribute will be removed in 4.0.0, use self.epochs instead).\n", - " This is separate from the ipykernel package so we can avoid doing imports until\n" - ] - } - ], + "outputs": [], "source": [ - "Keras_w2v = word2vec.Word2Vec(min_count=1)\n", - "Keras_w2v.build_vocab(texts_w2v)\n", - "Keras_w2v.train(texts, total_examples=Keras_w2v.corpus_count, epochs=Keras_w2v.iter)\n", - "Keras_w2v_wv = Keras_w2v.wv\n", - "embedding_layer = Keras_w2v_wv.get_keras_embedding()" + "keras_w2v = word2vec.Word2Vec([text_to_word_sequence(doc) for doc in dataset.data],min_count=0)\n", + "embedding_layer = keras_w2v.wv.get_keras_embedding(word_index = tokenizer.word_index,train_embeddings=True)" ] }, { @@ -366,35 +126,29 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 166, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "WARNING:tensorflow:From /home/misha/envs/gensim/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py:3066: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n", - "Instructions for updating:\n", - "Use tf.cast instead.\n", - "Epoch 1/5\n", - "96/96 [==============================] - 1s 9ms/step - loss: 1.0589 - acc: 0.4896\n", - "Epoch 2/5\n", - "96/96 [==============================] - 1s 7ms/step - loss: 0.8601 - acc: 0.6354\n", - "Epoch 3/5\n", - "96/96 [==============================] - 1s 7ms/step - loss: 0.9060 - acc: 0.6354\n", - "Epoch 4/5\n", - "96/96 [==============================] - 1s 6ms/step - loss: 0.8576 - acc: 0.6354\n", - "Epoch 5/5\n", - "96/96 [==============================] - 1s 6ms/step - loss: 0.8527 - acc: 0.6250\n" + "Train on 1491 samples, validate on 166 samples\n", + "Epoch 1/3\n", + "1491/1491 [==============================] - 16s 11ms/step - loss: 1.0239 - acc: 0.5017 - val_loss: 0.9306 - val_acc: 0.5663\n", + "Epoch 2/3\n", + "1491/1491 [==============================] - 15s 10ms/step - loss: 0.6941 - acc: 0.7015 - val_loss: 0.6612 - val_acc: 0.7048\n", + "Epoch 3/3\n", + "1491/1491 [==============================] - 15s 10ms/step - loss: 0.4270 - acc: 0.8404 - val_loss: 0.5119 - val_acc: 0.7892\n" ] }, { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 12, + "execution_count": 166, "metadata": {}, "output_type": "execute_result" } @@ -413,275 +167,85 @@ "preds = Dense(y_train.shape[1], activation='softmax')(x)\n", "\n", "model = Model(sequence_input, preds)\n", - "model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['acc'])\n", + "model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['acc'])\n", "\n", - "model.fit(x_train, y_train, epochs=5)" + "model.fit(x_train, y_train, epochs=3, validation_split= 0.1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "As can be seen from the results above, the accuracy obtained is not that high. This is because of the small size of training data used and we could expect to obtain better accuracy for training data of larger size." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "collapsed": true - }, - "source": [ - "#### Integration with Keras : Another classification task" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "In this task, we train our model to predict the category of the input text. We start by importing the relevant modules and libraries : " - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [], - "source": [ - "from keras.models import Sequential\n", - "from keras.layers import Dropout\n", - "from keras.regularizers import l2\n", - "from keras.models import Model\n", - "from keras.engine import Input\n", - "from keras.preprocessing.sequence import pad_sequences\n", - "from keras.preprocessing.text import Tokenizer\n", - "from gensim.models import keyedvectors\n", - "from collections import defaultdict\n", + "We see that the model learns to reaches a reasonable accuracy, considering the small dataset.\n", "\n", - "import pandas as pd" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We now define some global variables and utility functions which would be used in the code further : " + "Alternatively, we can use embeddings pretrained on a different larger corpus (Glove), to see if performance impoves" ] }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 167, "metadata": {}, "outputs": [], "source": [ - "# global variables\n", - "\n", - "nb_filters = 1200 # number of filters\n", - "n_gram = 2 # n-gram, or window size of CNN/ConvNet\n", - "maxlen = 15 # maximum number of words in a sentence\n", - "vecsize = 300 # length of the embedded vectors in the model \n", - "cnn_dropout = 0.0 # dropout rate for CNN/ConvNet\n", - "final_activation = 'softmax' # activation function. Options: softplus, softsign, relu, tanh, sigmoid, hard_sigmoid, linear.\n", - "dense_wl2reg = 0.0 # dense_wl2reg: L2 regularization coefficient\n", - "dense_bl2reg = 0.0 # dense_bl2reg: L2 regularization coefficient for bias\n", - "optimizer = 'adam' # optimizer for gradient descent. Options: sgd, rmsprop, adagrad, adadelta, adam, adamax, nadam\n", - "\n", - "# utility functions\n", - "\n", - "def retrieve_csvdata_as_dict(filepath):\n", - " \"\"\"\n", - " Retrieve the training data in a CSV file, with the first column being the\n", - " class labels, and second column the text data. It returns a dictionary with\n", - " the class labels as keys, and a list of short texts as the value for each key.\n", - " \"\"\"\n", - " df = pd.read_csv(filepath)\n", - " category_col, descp_col = df.columns.values.tolist()\n", - " shorttextdict = dict()\n", - " for category, descp in zip(df[category_col], df[descp_col]):\n", - " if type(descp) == str:\n", - " shorttextdict.setdefault(category, []).append(descp)\n", - " return shorttextdict\n", - "\n", - "def subjectkeywords():\n", - " \"\"\"\n", - " Return an example data set, with three subjects and corresponding keywords.\n", - " This is in the format of the training input.\n", - " \"\"\"\n", - " data_path = os.path.join(os.getcwd(), 'datasets/keras_classifier_training_data.csv')\n", - " return retrieve_csvdata_as_dict(data_path)\n", - "\n", - "def convert_trainingdata(classdict):\n", - " \"\"\"\n", - " Convert the training data into format put into the neural networks.\n", - " \"\"\"\n", - " classlabels = classdict.keys()\n", - " lblidx_dict = dict(zip(classlabels, range(len(classlabels))))\n", - "\n", - " # tokenize the words, and determine the word length\n", - " phrases = []\n", - " indices = []\n", - " for label in classlabels:\n", - " for shorttext in classdict[label]:\n", - " shorttext = shorttext if type(shorttext) == str else ''\n", - " category_bucket = [0]*len(classlabels)\n", - " category_bucket[lblidx_dict[label]] = 1\n", - " indices.append(category_bucket)\n", - " phrases.append(shorttext)\n", - "\n", - " return classlabels, phrases, indices\n", + "import gensim.downloader as api\n", "\n", - "def process_text(text):\n", - " \"\"\" \n", - " Process the input text by tokenizing and padding it.\n", - " \"\"\"\n", - " tokenizer = Tokenizer()\n", - " tokenizer.fit_on_texts(text)\n", - " x_train = tokenizer.texts_to_sequences(text)\n", - "\n", - " x_train = pad_sequences(x_train, maxlen=maxlen)\n", - " return x_train" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We create our word2vec model first. We could either train our model or user pre-trained vectors." + "glove_embeddings = api.load(\"glove-wiki-gigaword-100\")" ] }, { "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [], - "source": [ - "# we are training our Word2Vec model here\n", - "w2v_training_data_path = os.path.join(os.getcwd(), 'datasets/word_vectors_training_data.txt')\n", - "input_data = word2vec.LineSentence(w2v_training_data_path)\n", - "w2v_model = word2vec.Word2Vec(input_data, size=300)\n", - "w2v_model_wv = w2v_model.wv\n", - "\n", - "# Alternatively we could have imported pre-trained word-vectors like : \n", - "# w2v_model_wv = keyedvectors.KeyedVectors.load_word2vec_format('GoogleNews-vectors-negative300.bin.gz', binary=True)\n", - "# The dataset 'GoogleNews-vectors-negative300.bin.gz' can be downloaded from https://drive.google.com/file/d/0B7XkCwpI5KDYNlNUTTlSS21pQmM/edit" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We load the training data for the Keras model." - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [], - "source": [ - "trainclassdict = subjectkeywords()\n", - "\n", - "nb_labels = len(trainclassdict) # number of class labels" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Next, we create out Keras model." - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [], - "source": [ - "# get embedding layer corresponding to our trained Word2Vec model\n", - "embedding_layer = w2v_model_wv.get_keras_embedding()\n", - "\n", - "# create a convnet to solve our classification task\n", - "sequence_input = Input(shape=(maxlen,), dtype='int32')\n", - "embedded_sequences = embedding_layer(sequence_input)\n", - "x = Conv1D(filters=nb_filters, kernel_size=n_gram, padding='valid', activation='relu', input_shape=(maxlen, vecsize))(embedded_sequences)\n", - "x = MaxPooling1D(pool_size=maxlen - n_gram + 1)(x)\n", - "x = Flatten()(x)\n", - "preds = Dense(nb_labels, activation=final_activation, kernel_regularizer=l2(dense_wl2reg), bias_regularizer=l2(dense_bl2reg))(x)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Next, we train the classifier." - ] - }, - { - "cell_type": "code", - "execution_count": 18, + "execution_count": 168, "metadata": {}, "outputs": [ { - "ename": "ValueError", - "evalue": "Error when checking model target: the list of Numpy arrays that you are passing to your model is not the size the model expected. Expected to see 1 array(s), but instead got the following list of 45 arrays: [array([[1],\n [0],\n [0]]), array([[1],\n [0],\n [0]]), array([[1],\n [0],\n [0]]), array([[1],\n [0],\n [0]]), array([[1],\n [0],\n [0]]), array([[1...", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mModel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msequence_input\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpreds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompile\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mloss\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'categorical_crossentropy'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'rmsprop'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmetrics\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'acc'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 11\u001b[0;31m \u001b[0mfit_ret_val\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mepochs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m~/envs/gensim/lib/python3.7/site-packages/keras/engine/training.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, **kwargs)\u001b[0m\n\u001b[1;32m 950\u001b[0m \u001b[0msample_weight\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0msample_weight\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 951\u001b[0m \u001b[0mclass_weight\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mclass_weight\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 952\u001b[0;31m batch_size=batch_size)\n\u001b[0m\u001b[1;32m 953\u001b[0m \u001b[0;31m# Prepare validation data.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 954\u001b[0m \u001b[0mdo_validation\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/envs/gensim/lib/python3.7/site-packages/keras/engine/training.py\u001b[0m in \u001b[0;36m_standardize_user_data\u001b[0;34m(self, x, y, sample_weight, class_weight, check_array_lengths, batch_size)\u001b[0m\n\u001b[1;32m 787\u001b[0m \u001b[0mfeed_output_shapes\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 788\u001b[0m \u001b[0mcheck_batch_axis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;31m# Don't enforce the batch size.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 789\u001b[0;31m exception_prefix='target')\n\u001b[0m\u001b[1;32m 790\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 791\u001b[0m \u001b[0;31m# Generate sample-wise weight values given the `sample_weight` and\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/envs/gensim/lib/python3.7/site-packages/keras/engine/training_utils.py\u001b[0m in \u001b[0;36mstandardize_input_data\u001b[0;34m(data, names, shapes, check_batch_axis, exception_prefix)\u001b[0m\n\u001b[1;32m 100\u001b[0m \u001b[0;34m'Expected to see '\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnames\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m' array(s), '\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 101\u001b[0m \u001b[0;34m'but instead got the following list of '\u001b[0m \u001b[0;34m+\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 102\u001b[0;31m str(len(data)) + ' arrays: ' + str(data)[:200] + '...')\n\u001b[0m\u001b[1;32m 103\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnames\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 104\u001b[0m raise ValueError(\n", - "\u001b[0;31mValueError\u001b[0m: Error when checking model target: the list of Numpy arrays that you are passing to your model is not the size the model expected. Expected to see 1 array(s), but instead got the following list of 45 arrays: [array([[1],\n [0],\n [0]]), array([[1],\n [0],\n [0]]), array([[1],\n [0],\n [0]]), array([[1],\n [0],\n [0]]), array([[1],\n [0],\n [0]]), array([[1..." + "name": "stdout", + "output_type": "stream", + "text": [ + "Train on 1491 samples, validate on 166 samples\n", + "Epoch 1/3\n", + "1491/1491 [==============================] - 17s 11ms/step - loss: 1.0564 - acc: 0.4514 - val_loss: 0.9083 - val_acc: 0.4578\n", + "Epoch 2/3\n", + "1491/1491 [==============================] - 16s 11ms/step - loss: 0.5122 - acc: 0.7901 - val_loss: 0.3278 - val_acc: 0.8855\n", + "Epoch 3/3\n", + "1491/1491 [==============================] - 16s 10ms/step - loss: 0.0902 - acc: 0.9718 - val_loss: 0.2187 - val_acc: 0.9398\n" ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 168, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "classlabels, x_train, y_train = convert_trainingdata(trainclassdict)\n", - "\n", - "tokenizer = Tokenizer()\n", - "tokenizer.fit_on_texts(x_train)\n", - "x_train = tokenizer.texts_to_sequences(x_train)\n", + "glove_embedding_layer = glove_embeddings.get_keras_embedding(word_index = tokenizer.word_index,train_embeddings=True)\n", "\n", - "x_train = pad_sequences(x_train, maxlen=maxlen)\n", + "embedded_sequences = glove_embedding_layer(sequence_input)\n", + "x = Conv1D(128, 5, activation='relu')(embedded_sequences)\n", + "x = MaxPooling1D(5)(x)\n", + "x = Conv1D(128, 5, activation='relu')(x)\n", + "x = MaxPooling1D(5)(x)\n", + "x = Conv1D(128, 5, activation='relu')(x)\n", + "x = MaxPooling1D(35)(x) # global max pooling\n", + "x = Flatten()(x)\n", + "x = Dense(128, activation='relu')(x)\n", + "preds = Dense(y_train.shape[1], activation='softmax')(x)\n", "\n", "model = Model(sequence_input, preds)\n", - "model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['acc'])\n", - "fit_ret_val = model.fit(x_train, y_train, epochs=10)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Our classifier is now ready to predict classes for input data." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "input_text = 'artificial intelligence'\n", - "\n", - "matrix = process_text(input_text)\n", - "\n", - "predictions = model.predict(matrix)\n", - "\n", - "# get the actual categories from output\n", - "scoredict = {}\n", - "for idx, classlabel in zip(range(len(classlabels)), classlabels):\n", - " scoredict[classlabel] = predictions[0][idx]\n", + "model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['acc'])\n", "\n", - "print scoredict" + "model.fit(x_train, y_train, epochs=3, validation_split= 0.1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "The result above clearly suggests (~ 98% probability!) that the input `artificial intelligence` should belong to the category `mathematics`, which conforms very well with the expected output in this case.\n", - "In general, the output could depend on several factors including the number of filters for the conv-net, the training data for the word-vectors, the training data for the classifier etc." + "We see that pretrained embeddings result in a faster convergence" ] } ], @@ -701,7 +265,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.1" + "version": "3.6.4" } }, "nbformat": 4, diff --git a/gensim/models/keyedvectors.py b/gensim/models/keyedvectors.py index 467b740ed7..5937f1d6a0 100644 --- a/gensim/models/keyedvectors.py +++ b/gensim/models/keyedvectors.py @@ -1383,6 +1383,59 @@ def relative_cosine_similarity(self, wa, wb, topn=10): return rcs + def get_keras_embedding(self, train_embeddings=False, word_index = None): + """Get a Keras 'Embedding' layer with weights set as the Word2Vec model's learned word embeddings. + + Parameters + ---------- + train_embeddings : bool + If False, the weights are frozen and stopped from being updated. + If True, the weights can/will be further trained/updated. + + word_index : {str : int} + A mapping from tokens to their indices the way they will be provided in the input to the embedding layer. + The embedding of each token will be placed in the corresponding index in the embedding matrix. + Tokens not in the index are ignored and not placed in the returned layer embedding matrix. + This is useful when the token indices are produced by a process that is not coupled with the embedding + model, e.x. an Keras Tokenizer object. + If None, the embedding matrix in the embedding layer will be indexed according to self.vocab + + Returns + ------- + `keras.layers.Embedding` + Embedding layer. + + Raises + ------ + ImportError + If `Keras `_ not installed. + + Warnings + -------- + Current method work only if `Keras `_ installed. + + """ + try: + from keras.layers import Embedding + except ImportError: + raise ImportError("Please install Keras to use this function") + if word_index is None: + weights = self.vectors + else: + max_index = max(word_index.values()) + weights = np.random.normal(size=(max_index + 1, self.vectors.shape[1])) + for word, index in word_index.items(): + if word in self.vocab: + weights[index] = self.get_vector(word) + + # set `trainable` as `False` to use the pretrained word embedding + layer = Embedding( + input_dim=weights.shape[0], output_dim=weights.shape[1], + weights=[weights], trainable=train_embeddings + ) + return layer + + class WordEmbeddingSimilarityIndex(TermSimilarityIndex): """ @@ -1497,43 +1550,6 @@ def load_word2vec_format(cls, fname, fvocab=None, binary=False, encoding='utf8', cls, fname, fvocab=fvocab, binary=binary, encoding=encoding, unicode_errors=unicode_errors, limit=limit, datatype=datatype) - def get_keras_embedding(self, train_embeddings=False): - """Get a Keras 'Embedding' layer with weights set as the Word2Vec model's learned word embeddings. - - Parameters - ---------- - train_embeddings : bool - If False, the weights are frozen and stopped from being updated. - If True, the weights can/will be further trained/updated. - - Returns - ------- - `keras.layers.Embedding` - Embedding layer. - - Raises - ------ - ImportError - If `Keras `_ not installed. - - Warnings - -------- - Current method work only if `Keras `_ installed. - - """ - try: - from keras.layers import Embedding - except ImportError: - raise ImportError("Please install Keras to use this function") - weights = self.vectors - - # set `trainable` as `False` to use the pretrained word embedding - # No extra mem usage here as `Embedding` layer doesn't create any new matrix for weights - layer = Embedding( - input_dim=weights.shape[0], output_dim=weights.shape[1], - weights=[weights], trainable=train_embeddings - ) - return layer @classmethod def load(cls, fname_or_handle, **kwargs): diff --git a/gensim/test/test_keyedvectors.py b/gensim/test/test_keyedvectors.py index b002050cc6..b751cd616d 100644 --- a/gensim/test/test_keyedvectors.py +++ b/gensim/test/test_keyedvectors.py @@ -11,6 +11,7 @@ import logging import unittest +from mock import patch import numpy as np @@ -21,7 +22,6 @@ import gensim.models.keyedvectors - logger = logging.getLogger(__name__) @@ -400,6 +400,34 @@ def test_load_model_and_vocab_file_ignore(self): model.get_vector(u'どういたしまして'), np.array([.1, .2, .3], dtype=np.float32))) +class WordEmbeddingsKeyedVectorsTest(unittest.TestCase): + def setUp(self): + self.vectors = EuclideanKeyedVectors.load_word2vec_format( + datapath('euclidean_vectors.bin'), binary=True, datatype=np.float64) + + def test_get_keras_embedding_word_index_none(self): + embedding_layer = self.vectors.get_keras_embedding() + self.assertEqual(self.vectors.vectors.shape, embedding_layer._initial_weights[0].shape) + self.assertTrue(np.array_equal( + self.vectors['is'], embedding_layer._initial_weights[0][self.vectors.vocab['is'].index, :])) + + def test_get_keras_embedding_word_index_passed(self): + word_index = {'is': 1, 'to': 2} + embedding_layer = self.vectors.get_keras_embedding(word_index=word_index) + self.assertEqual( embedding_layer._initial_weights[0].shape, (3, self.vectors.vectors.shape[1])) + self.assertTrue(np.array_equal( + self.vectors['is'], embedding_layer._initial_weights[0][1, :])) + + @patch('numpy.random.normal') + def test_get_keras_embedding_word_index_passed_with_oov_word(self, normal_func): + normal_func.return_value = np.zeros((3, self.vectors.vectors.shape[1])) + word_index = {'is': 1, 'not_a_real_word': 2} + embedding_layer = self.vectors.get_keras_embedding(word_index=word_index) + self.assertEqual( embedding_layer._initial_weights[0].shape, (3, self.vectors.vectors.shape[1])) + self.assertTrue(np.array_equal(embedding_layer._initial_weights[0][2, :], + np.zeros(self.vectors.vectors.shape[1]))) + + if __name__ == '__main__': logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.DEBUG) unittest.main() From 59e1c50de5ea82d63d619e7b281cca8e70005a2f Mon Sep 17 00:00:00 2001 From: Zhicharevich Date: Wed, 11 Dec 2019 14:57:55 +0200 Subject: [PATCH 02/10] skip tests if keras not installed --- gensim/models/keyedvectors.py | 4 +--- gensim/test/test_keyedvectors.py | 20 +++++++++++++++----- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/gensim/models/keyedvectors.py b/gensim/models/keyedvectors.py index 5937f1d6a0..9abaffc29f 100644 --- a/gensim/models/keyedvectors.py +++ b/gensim/models/keyedvectors.py @@ -1383,7 +1383,7 @@ def relative_cosine_similarity(self, wa, wb, topn=10): return rcs - def get_keras_embedding(self, train_embeddings=False, word_index = None): + def get_keras_embedding(self, train_embeddings=False, word_index=None): """Get a Keras 'Embedding' layer with weights set as the Word2Vec model's learned word embeddings. Parameters @@ -1436,7 +1436,6 @@ def get_keras_embedding(self, train_embeddings=False, word_index = None): return layer - class WordEmbeddingSimilarityIndex(TermSimilarityIndex): """ Computes cosine similarities between word embeddings and retrieves the closest word embeddings @@ -1550,7 +1549,6 @@ def load_word2vec_format(cls, fname, fvocab=None, binary=False, encoding='utf8', cls, fname, fvocab=fvocab, binary=binary, encoding=encoding, unicode_errors=unicode_errors, limit=limit, datatype=datatype) - @classmethod def load(cls, fname_or_handle, **kwargs): model = super(WordEmbeddingsKeyedVectors, cls).load(fname_or_handle, **kwargs) diff --git a/gensim/test/test_keyedvectors.py b/gensim/test/test_keyedvectors.py index b751cd616d..9d5acca752 100644 --- a/gensim/test/test_keyedvectors.py +++ b/gensim/test/test_keyedvectors.py @@ -67,7 +67,7 @@ def test_most_similar(self): first_similarities = np.array([similarity for term, similarity in index.most_similar(u"holiday", topn=10)]) index = WordEmbeddingSimilarityIndex(self.vectors, exponent=2.0) second_similarities = np.array([similarity for term, similarity in index.most_similar(u"holiday", topn=10)]) - self.assertTrue(np.allclose(first_similarities**2.0, second_similarities)) + self.assertTrue(np.allclose(first_similarities ** 2.0, second_similarities)) class TestEuclideanKeyedVectors(unittest.TestCase): @@ -120,7 +120,7 @@ def test_relative_cosine_similarity(self): 'skillful', 'skilful', 'dear', 'near', 'dependable', 'safe', 'secure', 'right', 'ripe', 'well', 'effective', 'in_effect', 'in_force', 'serious', 'sound', 'salutary', 'honest', 'undecomposed', 'unspoiled', 'unspoilt', 'thoroughly', 'soundly' - ] # synonyms for "good" as per wordnet + ] # synonyms for "good" as per wordnet cos_sim = [] for i in range(len(wordnet_syn)): if wordnet_syn[i] in self.vectors.vocab: @@ -400,32 +400,42 @@ def test_load_model_and_vocab_file_ignore(self): model.get_vector(u'どういたしまして'), np.array([.1, .2, .3], dtype=np.float32))) +try: + import keras # noqa:F401 + keras_installed = True +except ImportError: + keras_installed = False + + class WordEmbeddingsKeyedVectorsTest(unittest.TestCase): def setUp(self): self.vectors = EuclideanKeyedVectors.load_word2vec_format( datapath('euclidean_vectors.bin'), binary=True, datatype=np.float64) + @unittest.skipIf(not keras_installed, 'keras needs to be installed for this test') def test_get_keras_embedding_word_index_none(self): embedding_layer = self.vectors.get_keras_embedding() self.assertEqual(self.vectors.vectors.shape, embedding_layer._initial_weights[0].shape) self.assertTrue(np.array_equal( self.vectors['is'], embedding_layer._initial_weights[0][self.vectors.vocab['is'].index, :])) + @unittest.skipIf(not keras_installed, 'keras needs to be installed for this test') def test_get_keras_embedding_word_index_passed(self): word_index = {'is': 1, 'to': 2} embedding_layer = self.vectors.get_keras_embedding(word_index=word_index) - self.assertEqual( embedding_layer._initial_weights[0].shape, (3, self.vectors.vectors.shape[1])) + self.assertEqual(embedding_layer._initial_weights[0].shape, (3, self.vectors.vectors.shape[1])) self.assertTrue(np.array_equal( self.vectors['is'], embedding_layer._initial_weights[0][1, :])) + @unittest.skipIf(not keras_installed, 'keras needs to be installed for this test') @patch('numpy.random.normal') def test_get_keras_embedding_word_index_passed_with_oov_word(self, normal_func): normal_func.return_value = np.zeros((3, self.vectors.vectors.shape[1])) word_index = {'is': 1, 'not_a_real_word': 2} embedding_layer = self.vectors.get_keras_embedding(word_index=word_index) - self.assertEqual( embedding_layer._initial_weights[0].shape, (3, self.vectors.vectors.shape[1])) + self.assertEqual(embedding_layer._initial_weights[0].shape, (3, self.vectors.vectors.shape[1])) self.assertTrue(np.array_equal(embedding_layer._initial_weights[0][2, :], - np.zeros(self.vectors.vectors.shape[1]))) + np.zeros(self.vectors.vectors.shape[1]))) if __name__ == '__main__': From 74c932f532a991baed23b4ba057cacee144b1106 Mon Sep 17 00:00:00 2001 From: Zhicharevich Date: Wed, 11 Dec 2019 15:06:55 +0200 Subject: [PATCH 03/10] removed unnessecary comment from test_keyed_vectors --- gensim/test/test_keyedvectors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gensim/test/test_keyedvectors.py b/gensim/test/test_keyedvectors.py index 9d5acca752..a498f62491 100644 --- a/gensim/test/test_keyedvectors.py +++ b/gensim/test/test_keyedvectors.py @@ -401,7 +401,7 @@ def test_load_model_and_vocab_file_ignore(self): try: - import keras # noqa:F401 + import keras keras_installed = True except ImportError: keras_installed = False From a0d7ce7a5e6e58ed266337b14db3fb80c50d977e Mon Sep 17 00:00:00 2001 From: Zhicharevich Date: Tue, 14 Jan 2020 17:27:11 +0200 Subject: [PATCH 04/10] fixed indentation --- gensim/test/test_keyedvectors.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/gensim/test/test_keyedvectors.py b/gensim/test/test_keyedvectors.py index a498f62491..6a61dce562 100644 --- a/gensim/test/test_keyedvectors.py +++ b/gensim/test/test_keyedvectors.py @@ -401,7 +401,8 @@ def test_load_model_and_vocab_file_ignore(self): try: - import keras + import keras + keras_installed = True except ImportError: keras_installed = False @@ -434,8 +435,8 @@ def test_get_keras_embedding_word_index_passed_with_oov_word(self, normal_func): word_index = {'is': 1, 'not_a_real_word': 2} embedding_layer = self.vectors.get_keras_embedding(word_index=word_index) self.assertEqual(embedding_layer._initial_weights[0].shape, (3, self.vectors.vectors.shape[1])) - self.assertTrue(np.array_equal(embedding_layer._initial_weights[0][2, :], - np.zeros(self.vectors.vectors.shape[1]))) + self.assertTrue( + np.array_equal(embedding_layer._initial_weights[0][2, :], np.zeros(self.vectors.vectors.shape[1]))) if __name__ == '__main__': From 9416e552fdcc7f79b57e7f56e889c07f4964ca67 Mon Sep 17 00:00:00 2001 From: Zhicharevich Date: Tue, 14 Jan 2020 23:26:29 +0200 Subject: [PATCH 05/10] fixed flake import error --- gensim/test/test_keyedvectors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gensim/test/test_keyedvectors.py b/gensim/test/test_keyedvectors.py index 6a61dce562..395b16c09e 100644 --- a/gensim/test/test_keyedvectors.py +++ b/gensim/test/test_keyedvectors.py @@ -401,7 +401,7 @@ def test_load_model_and_vocab_file_ignore(self): try: - import keras + import keras # noqa: F401 keras_installed = True except ImportError: From 8f6fd101a6acfc4362f0527e52a84780cc3a1e29 Mon Sep 17 00:00:00 2001 From: Zhicharevich Date: Thu, 23 Jan 2020 10:33:14 +0200 Subject: [PATCH 06/10] moved skip test decorator to class --- gensim/test/test_keyedvectors.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/gensim/test/test_keyedvectors.py b/gensim/test/test_keyedvectors.py index 395b16c09e..46d29c89fa 100644 --- a/gensim/test/test_keyedvectors.py +++ b/gensim/test/test_keyedvectors.py @@ -408,19 +408,18 @@ def test_load_model_and_vocab_file_ignore(self): keras_installed = False +@unittest.skipUnless(keras_installed, 'keras needs to be installed for this test') class WordEmbeddingsKeyedVectorsTest(unittest.TestCase): def setUp(self): self.vectors = EuclideanKeyedVectors.load_word2vec_format( datapath('euclidean_vectors.bin'), binary=True, datatype=np.float64) - @unittest.skipIf(not keras_installed, 'keras needs to be installed for this test') def test_get_keras_embedding_word_index_none(self): embedding_layer = self.vectors.get_keras_embedding() self.assertEqual(self.vectors.vectors.shape, embedding_layer._initial_weights[0].shape) self.assertTrue(np.array_equal( self.vectors['is'], embedding_layer._initial_weights[0][self.vectors.vocab['is'].index, :])) - @unittest.skipIf(not keras_installed, 'keras needs to be installed for this test') def test_get_keras_embedding_word_index_passed(self): word_index = {'is': 1, 'to': 2} embedding_layer = self.vectors.get_keras_embedding(word_index=word_index) @@ -428,7 +427,6 @@ def test_get_keras_embedding_word_index_passed(self): self.assertTrue(np.array_equal( self.vectors['is'], embedding_layer._initial_weights[0][1, :])) - @unittest.skipIf(not keras_installed, 'keras needs to be installed for this test') @patch('numpy.random.normal') def test_get_keras_embedding_word_index_passed_with_oov_word(self, normal_func): normal_func.return_value = np.zeros((3, self.vectors.vectors.shape[1])) From 6853bae65600e3507dd68e75dc50f6fb3dc6f842 Mon Sep 17 00:00:00 2001 From: Hamekoded Date: Thu, 30 Jan 2020 12:05:49 +0200 Subject: [PATCH 07/10] Update gensim/models/keyedvectors.py Co-Authored-By: Michael Penkov --- gensim/models/keyedvectors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gensim/models/keyedvectors.py b/gensim/models/keyedvectors.py index 9abaffc29f..23e83d1bd7 100644 --- a/gensim/models/keyedvectors.py +++ b/gensim/models/keyedvectors.py @@ -1394,7 +1394,7 @@ def get_keras_embedding(self, train_embeddings=False, word_index=None): word_index : {str : int} A mapping from tokens to their indices the way they will be provided in the input to the embedding layer. - The embedding of each token will be placed in the corresponding index in the embedding matrix. + The embedding of each token will be placed at the corresponding index in the returned matrix. Tokens not in the index are ignored and not placed in the returned layer embedding matrix. This is useful when the token indices are produced by a process that is not coupled with the embedding model, e.x. an Keras Tokenizer object. From bf5ebc041896d1c0f44358dabfd829314f00161e Mon Sep 17 00:00:00 2001 From: Hamekoded Date: Thu, 30 Jan 2020 12:06:39 +0200 Subject: [PATCH 08/10] Update gensim/models/keyedvectors.py Co-Authored-By: Michael Penkov --- gensim/models/keyedvectors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gensim/models/keyedvectors.py b/gensim/models/keyedvectors.py index 23e83d1bd7..fe32ee8ff1 100644 --- a/gensim/models/keyedvectors.py +++ b/gensim/models/keyedvectors.py @@ -1395,7 +1395,7 @@ def get_keras_embedding(self, train_embeddings=False, word_index=None): word_index : {str : int} A mapping from tokens to their indices the way they will be provided in the input to the embedding layer. The embedding of each token will be placed at the corresponding index in the returned matrix. - Tokens not in the index are ignored and not placed in the returned layer embedding matrix. + Tokens not in the index are ignored. This is useful when the token indices are produced by a process that is not coupled with the embedding model, e.x. an Keras Tokenizer object. If None, the embedding matrix in the embedding layer will be indexed according to self.vocab From 0be76ab628df0f2bd289a802f0e63621cbeda0b0 Mon Sep 17 00:00:00 2001 From: Hamekoded Date: Thu, 30 Jan 2020 12:06:57 +0200 Subject: [PATCH 09/10] Update gensim/models/keyedvectors.py Co-Authored-By: Michael Penkov --- gensim/models/keyedvectors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gensim/models/keyedvectors.py b/gensim/models/keyedvectors.py index fe32ee8ff1..4c1fc0c0f5 100644 --- a/gensim/models/keyedvectors.py +++ b/gensim/models/keyedvectors.py @@ -1412,7 +1412,7 @@ def get_keras_embedding(self, train_embeddings=False, word_index=None): Warnings -------- - Current method work only if `Keras `_ installed. + Current method works only if `Keras `_ installed. """ try: From 44e13c1fc2348aec56110f6c1249129ba616067b Mon Sep 17 00:00:00 2001 From: Zhicharevich Date: Sun, 2 Feb 2020 10:11:10 +0200 Subject: [PATCH 10/10] renamed keras_installed flag to upper case, removed unneeded comment --- gensim/models/keyedvectors.py | 1 - gensim/test/test_keyedvectors.py | 6 +++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/gensim/models/keyedvectors.py b/gensim/models/keyedvectors.py index 4c1fc0c0f5..ae370dc1d7 100644 --- a/gensim/models/keyedvectors.py +++ b/gensim/models/keyedvectors.py @@ -1428,7 +1428,6 @@ def get_keras_embedding(self, train_embeddings=False, word_index=None): if word in self.vocab: weights[index] = self.get_vector(word) - # set `trainable` as `False` to use the pretrained word embedding layer = Embedding( input_dim=weights.shape[0], output_dim=weights.shape[1], weights=[weights], trainable=train_embeddings diff --git a/gensim/test/test_keyedvectors.py b/gensim/test/test_keyedvectors.py index 46d29c89fa..13256e7003 100644 --- a/gensim/test/test_keyedvectors.py +++ b/gensim/test/test_keyedvectors.py @@ -403,12 +403,12 @@ def test_load_model_and_vocab_file_ignore(self): try: import keras # noqa: F401 - keras_installed = True + KERAS_INSTALLED = True except ImportError: - keras_installed = False + KERAS_INSTALLED = False -@unittest.skipUnless(keras_installed, 'keras needs to be installed for this test') +@unittest.skipUnless(KERAS_INSTALLED, 'keras needs to be installed for this test') class WordEmbeddingsKeyedVectorsTest(unittest.TestCase): def setUp(self): self.vectors = EuclideanKeyedVectors.load_word2vec_format(