Fine-tuning BERT to classify textual entailment on the Stanford NLI corpus.
This project is a PyTorch- and HuggingFace-based toolkit for the Natural Language Inference (NLI) task. It uses the Stanford NLI (SNLI) corpus, a collection of 570k human-written English sentence pairs, each labeled as 'entailment', 'contradiction', or 'neutral'.
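For reference, SNLI is also distributed through the HuggingFace hub; a minimal loading sketch follows. It assumes the `datasets` library is installed (it is not listed among this repo's dependencies), and the label mapping shown is the hub release's, not necessarily this repo's pipeline:

```python
from datasets import load_dataset

# Load the Stanford NLI corpus from the HuggingFace hub.
snli = load_dataset("snli")

# Hub labels: 0 = entailment, 1 = neutral, 2 = contradiction.
# Pairs without a gold label are marked -1 and are usually filtered out.
train = snli["train"].filter(lambda ex: ex["label"] != -1)
print(train[0])  # {'premise': ..., 'hypothesis': ..., 'label': ...}
```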
The main model is a BERT transformer fine-tuned on the task. Other models are also available for experimentation:
- Pooled Logistic Regression
- Shallow Neural Network
- Deep Neural Network
Each model is trained on token embeddings and outputs classification scores for the three NLI labels.
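As a reference point, a minimal BERT setup for three-way NLI classification with HuggingFace Transformers might look like the sketch below. The checkpoint name `bert-base-uncased` and the example sentences are assumptions; the repo's own model classes may differ:

```python
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased", num_labels=3  # entailment / contradiction / neutral
)

# NLI inputs are premise-hypothesis pairs, encoded as one sequence
# separated by [SEP].
inputs = tokenizer(
    "A man is playing a guitar.",   # premise (illustrative)
    "A person is making music.",    # hypothesis (illustrative)
    return_tensors="pt", truncation=True,
)
with torch.no_grad():
    logits = model(**inputs).logits  # shape: (1, 3)
print(logits.softmax(dim=-1))
```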
Run the main script with command-line arguments to specify the model, number of epochs, and other settings:

```
python main.py --model=shallow --epochs=5 --device=cuda --batch_size=64 --embedding_dim=128
```
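The flags above suggest an argparse interface along these lines. This is a sketch: apart from `shallow`, the model choices and all defaults are assumptions, and the actual main.py may expose additional options:

```python
import argparse

parser = argparse.ArgumentParser(description="Train an NLI classifier on SNLI")
parser.add_argument("--model", choices=["bert", "logreg", "shallow", "deep"],
                    default="bert", help="which classifier to train")
parser.add_argument("--epochs", type=int, default=5)
parser.add_argument("--device", default="cuda")
parser.add_argument("--batch_size", type=int, default=64)
parser.add_argument("--embedding_dim", type=int, default=128)
args = parser.parse_args()
```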
Dependencies:
- PyTorch — model building and training
- scikit-learn — additional machine learning utilities
- HuggingFace Transformers — loading and fine-tuning BERT
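These can typically be installed with pip (unpinned here; if the repo ships a requirements file, that takes precedence):

```
pip install torch scikit-learn transformers
```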