This code is a Python implementation of the Skip-Gram model for word embeddings. It uses the PyTorch and Plotly libraries to train the Skip-Gram model on the Penn Treebank (PTB) dataset and visualize the learned word embeddings.
The code consists of several functions and classes:
- `read_ptb`: Reads the PTB dataset from a file and preprocesses it.
- `subsample`: Subsamples frequent words from the dataset to improve training efficiency.
- `batchify`: Converts examples into minibatches for efficient training.
- `get_centers_and_contexts`: Extracts center words and their corresponding context windows from the corpus.
- `RandomGenerator`: Draws random negative samples for the Skip-Gram model.
- `get_negatives`: Generates negative samples for all contexts in the dataset.
- `load_data_ptb`: Loads and preprocesses the PTB dataset for training.
- `skip_gram`: Implements the Skip-Gram model architecture.
- `SigmoidBCELoss`: Custom binary cross-entropy loss for the Skip-Gram model.
- `train`: Trains the Skip-Gram model on the PTB dataset.
- `evaluate`: Evaluates the Skip-Gram model on a validation dataset.
- `get_similar_tokens`: Retrieves the tokens most similar to a given query token based on the learned embeddings.
- `reduce_dimensions`: Reduces the dimensionality of the word embeddings with t-SNE for visualization.
- `plot_with_plotly`: Plots the reduced-dimensional word embeddings using Plotly.
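As a rough sketch, `skip_gram` and `SigmoidBCELoss` typically fit together as follows in a PyTorch negative-sampling setup. The shapes, argument names, and two-embedding design below are assumptions for illustration, not taken verbatim from the code:

```python
import torch
from torch import nn

def skip_gram(center, contexts_and_negatives, embed_v, embed_u):
    # Assumed design: separate embeddings for center words (v) and
    # context/negative words (u).
    v = embed_v(center)                      # (batch, 1, embed_dim)
    u = embed_u(contexts_and_negatives)      # (batch, max_len, embed_dim)
    # Dot product of the center word with each context/negative word.
    pred = torch.bmm(v, u.permute(0, 2, 1))  # (batch, 1, max_len)
    return pred

class SigmoidBCELoss(nn.Module):
    """Binary cross-entropy with logits; a mask zeroes out padding."""
    def forward(self, inputs, target, mask=None):
        out = nn.functional.binary_cross_entropy_with_logits(
            inputs, target, weight=mask, reduction="none")
        return out.mean(dim=1)

# Toy usage with a vocabulary of 20 words and 4-dimensional embeddings.
embed_v = nn.Embedding(20, 4)
embed_u = nn.Embedding(20, 4)
center = torch.tensor([[1], [2]])                      # (2, 1)
contexts = torch.tensor([[3, 4, 5, 0], [6, 7, 0, 0]])  # (2, 4), 0 = padding
pred = skip_gram(center, contexts, embed_v, embed_u)
print(pred.shape)  # torch.Size([2, 1, 4])
```

The mask lets variable-length context lists share one padded minibatch without padding positions contributing to the loss.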
To use this code, provide the PTB dataset as a plain-text file. The code loads and preprocesses the dataset, trains the Skip-Gram model, and saves the trained model; it also provides functions to retrieve similar tokens and to visualize the word embeddings.
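A minimal sketch of what the similar-token lookup might look like: cosine similarity between a query word's vector and every row of the embedding matrix. The `Vocab` class with `token_to_idx`/`idx_to_token` is a hypothetical stand-in for whatever vocabulary object the code uses:

```python
import torch

class Vocab:
    # Hypothetical minimal vocabulary: token <-> index mapping.
    def __init__(self, tokens):
        self.idx_to_token = tokens
        self.token_to_idx = {t: i for i, t in enumerate(tokens)}

def get_similar_tokens(query_token, k, embed, vocab):
    W = embed.weight.data                     # (vocab_size, embed_dim)
    x = W[vocab.token_to_idx[query_token]]
    # Cosine similarity of x with every row of W (1e-9 avoids div by zero).
    cos = torch.mv(W, x) / torch.sqrt(
        torch.sum(W * W, dim=1) * torch.sum(x * x) + 1e-9)
    # Take k+1 and drop the top hit, which is the query word itself.
    topk = torch.topk(cos, k=k + 1).indices[1:]
    return [vocab.idx_to_token[int(i)] for i in topk]

embed = torch.nn.Embedding(5, 3)
vocab = Vocab(["a", "b", "c", "d", "e"])
similar = get_similar_tokens("a", 2, embed, vocab)
print(similar)  # two tokens closest to "a" in embedding space
```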