diff --git a/CHANGES.md b/CHANGES.md
index 4ffd2b82d..4a997035b 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added the `event_name` argument for `LRScheduler` for optional recording of LR changes inside `net.history`. NOTE: Supported only in Pytorch>=1.4
- Make it easier to add custom modules or optimizers to a neural net class by automatically registering them where necessary and by making them available to set_params
- Added the `step_every` argument for `LRScheduler` to set whether the scheduler step should be taken on every epoch or on every batch.
+- Added a notebook that shows how to use a pretrained BERT in skorch with the help of torchtext and huggingface transformers
### Changed
diff --git a/notebooks/MNIST-torchvision.ipynb b/notebooks/MNIST-torchvision.ipynb
index c62e00bd3..f454ae221 100644
--- a/notebooks/MNIST-torchvision.ipynb
+++ b/notebooks/MNIST-torchvision.ipynb
@@ -12,7 +12,7 @@
"\n",
"
Run in Google Colab \n",
"
\n",
- " View source on GitHub | "
+ "
View source on GitHub"
]
},
{
diff --git a/notebooks/README.md b/notebooks/README.md
index a9217007c..a8ef9aa41 100644
--- a/notebooks/README.md
+++ b/notebooks/README.md
@@ -5,3 +5,4 @@
* [MNIST](https://nbviewer.jupyter.org/github/skorch-dev/skorch/blob/master/notebooks/MNIST.ipynb)
* [MNIST using torchvision](https://nbviewer.jupyter.org/github/skorch-dev/skorch/blob/master/notebooks/MNIST-torchvision.ipynb)
* [Transfer Learning](https://nbviewer.jupyter.org/github/skorch-dev/skorch/blob/master/notebooks/Transfer_Learning.ipynb)
+* [torchtext and BERT](https://nbviewer.jupyter.org/github/skorch-dev/skorch/blob/master/notebooks/torchtext_bert.ipynb)
\ No newline at end of file
diff --git a/notebooks/torchtext_bert.ipynb b/notebooks/torchtext_bert.ipynb
new file mode 100644
index 000000000..9f993ec69
--- /dev/null
+++ b/notebooks/torchtext_bert.ipynb
@@ -0,0 +1,838 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "YXZDKnk92yhl"
+ },
+ "source": [
+ "# Train a sentiment classifier with skorch using torchtext and BERT"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "7YsnM31l2jzf"
+ },
+ "source": [
+ "This notebook is based on [another notebook](https://github.com/bentrevett/pytorch-sentiment-analysis/blob/master/6%20-%20Transformers%20for%20Sentiment%20Analysis.ipynb). Please refer to it for more details."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**Note**: If you are running this in [a colab notebook](https://colab.research.google.com/github/skorch-dev/skorch/blob/master/notebooks/torchtext_bert.ipynb), we recommend you enable a free GPU by going to:\n",
+ "\n",
+ "> **Runtime** → **Change runtime type** → **Hardware Accelerator: GPU**\n",
+ "\n",
+ "If you are running in colab, you should install the dependencies by running the following cell (the dataset itself is downloaded later, in the **Load data** section):"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "fNoXphO66yb1"
+ },
+ "source": [
+ "## Install packages"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "! [ ! -z \"$COLAB_GPU\" ] && pip install torch torchtext transformers skorch"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "QKgAeAtn67a-"
+ },
+ "source": [
+ "## Imports"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "OPqksHrsqntn"
+ },
+ "outputs": [],
+ "source": [
+ "import random"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "deqPnHqqqFY9"
+ },
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "import torchtext\n",
+ "from torch import nn\n",
+ "from torchtext.data import Field, LabelField\n",
+ "from torchtext.data import BucketIterator\n",
+ "from torchtext.datasets import IMDB\n",
+ "from transformers import BertTokenizer\n",
+ "from transformers import BertModel\n",
+ "from skorch import NeuralNetClassifier\n",
+ "from skorch.callbacks import Freezer\n",
+ "from skorch.callbacks import ProgressBar"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "cV1wda1E6-Pe"
+ },
+ "source": [
+ "## Constants"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "LEhrsBHVqoTN"
+ },
+ "outputs": [],
+ "source": [
+ "SEED = 0\n",
+ "MAX_SEQ_LEN = 512 # discard everything after this many tokens, for speed\n",
+ "\n",
+ "torch.manual_seed(SEED)\n",
+ "torch.cuda.manual_seed(SEED)\n",
+ "torch.backends.cudnn.deterministic = True"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "EaRlTexx7DRC"
+ },
+ "source": [
+ "## Load data"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "When running this notebook for the first time, loading data and the pretrained model will take a couple of minutes."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "sECMSc_aqoQe"
+ },
+ "outputs": [],
+ "source": [
+ "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "BKKV0YiGqoOx"
+ },
+ "outputs": [],
+ "source": [
+ "def tokenize_and_cut(sentence):\n",
+ "    # keep two positions free, presumably for the [CLS]/[SEP] ids added by the Field\n",
+ "    max_tokens = MAX_SEQ_LEN - 2\n",
+ "    return tokenizer.tokenize(sentence)[:max_tokens]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "54lgG62wqoLR"
+ },
+ "outputs": [],
+ "source": [
+ "TEXT = Field(\n",
+ " batch_first=True,\n",
+ " use_vocab=False,\n",
+ " tokenize=tokenize_and_cut,\n",
+ " preprocessing=tokenizer.convert_tokens_to_ids,\n",
+ " init_token=tokenizer.cls_token_id,\n",
+ " eos_token=tokenizer.sep_token_id,\n",
+ " pad_token=tokenizer.pad_token_id,\n",
+ " unk_token=tokenizer.unk_token_id,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "kSJfnzOnqoJz"
+ },
+ "outputs": [],
+ "source": [
+ "LABEL = LabelField(dtype=torch.int64)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 54
+ },
+ "colab_type": "code",
+ "id": "Vf5JqfIbqoG1",
+ "outputId": "62ff6e11-7061-43f1-b520-7e809229a7d9"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CPU times: user 3min 19s, sys: 1.34 s, total: 3min 20s\n",
+ "Wall time: 3min 20s\n"
+ ]
+ }
+ ],
+ "source": [
+ "%%time\n",
+ "# make splits for data\n",
+ "ds_train, ds_test = IMDB.splits(TEXT, LABEL)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "0vofdyXkqoEc"
+ },
+ "outputs": [],
+ "source": [
+ "LABEL.build_vocab(ds_train)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "5NPNcyLGqoAt"
+ },
+ "outputs": [],
+ "source": [
+ "bert = BertModel.from_pretrained('bert-base-uncased')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "KIz8txtS7Qbe"
+ },
+ "source": [
+ "## Model definition"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "8F6xB4ehqn-e"
+ },
+ "outputs": [],
+ "source": [
+ "class BERTGRUSentiment(nn.Module):\n",
+ "    \"\"\"GRU sentiment classifier on top of a pretrained BERT encoder.\n",
+ "\n",
+ "    BERT is only used as a feature extractor (it runs under no_grad);\n",
+ "    the final GRU hidden state is mapped to class probabilities.\n",
+ "    \"\"\"\n",
+ "    def __init__(self, bert, hidden_dim, output_dim, n_layers, bidirectional, dropout):\n",
+ "        super().__init__()\n",
+ "        n_directions = 2 if bidirectional else 1\n",
+ "\n",
+ "        self.bert = bert\n",
+ "        self.rnn = nn.GRU(\n",
+ "            bert.config.to_dict()['hidden_size'],\n",
+ "            hidden_dim,\n",
+ "            num_layers=n_layers,\n",
+ "            bidirectional=bidirectional,\n",
+ "            batch_first=True,\n",
+ "            # nn.GRU applies dropout only between stacked layers\n",
+ "            dropout=dropout if n_layers >= 2 else 0,\n",
+ "        )\n",
+ "        self.dropout = nn.Dropout(dropout)\n",
+ "        self.out = nn.Linear(hidden_dim * n_directions, output_dim)\n",
+ "        self.sm = nn.Softmax(dim=-1)\n",
+ "\n",
+ "    def forward(self, text):\n",
+ "        \"\"\"Return class probabilities for a batch of token ids.\n",
+ "\n",
+ "        text: [batch size, sent len]\n",
+ "        \"\"\"\n",
+ "        # BERT acts as a fixed feature extractor, so no graph is built for it\n",
+ "        with torch.no_grad():\n",
+ "            embedded = self.bert(text)[0]\n",
+ "        # embedded = [batch size, sent len, emb dim]\n",
+ "\n",
+ "        hidden = self.rnn(embedded)[1]\n",
+ "        # hidden = [n layers * n directions, batch size, hid dim]\n",
+ "\n",
+ "        if self.rnn.bidirectional:\n",
+ "            # the last two entries are the top layer's forward and backward states\n",
+ "            last_hidden = torch.cat((hidden[-2], hidden[-1]), dim=1)\n",
+ "        else:\n",
+ "            last_hidden = hidden[-1]\n",
+ "        # last_hidden = [batch size, hid dim * n directions]\n",
+ "\n",
+ "        output = self.out(self.dropout(last_hidden))\n",
+ "        # output = [batch size, out dim]\n",
+ "        return self.sm(output)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "_oxEgqGlvwaX"
+ },
+ "outputs": [],
+ "source": [
+ "# model hyper-parameters\n",
+ "HIDDEN_DIM = 256\n",
+ "OUTPUT_DIM = 2\n",
+ "N_LAYERS = 2\n",
+ "BIDIRECTIONAL = True\n",
+ "DROPOUT = 0.25"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "vQ88u81C7XzV"
+ },
+ "source": [
+ "## Custom code"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "x9Q_ZmHLqn25"
+ },
+ "outputs": [],
+ "source": [
+ "class SkorchBucketIterator(BucketIterator):\n",
+ "    \"\"\"BucketIterator that yields (X, y) pairs instead of batch objects.\"\"\"\n",
+ "    def __iter__(self):\n",
+ "        batches = super().__iter__()\n",
+ "        # skorch expects (X, y): here the text and the label cast to int64\n",
+ "        return ((batch.text, batch.label.long()) for batch in batches)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "104GBSjNqnzG"
+ },
+ "outputs": [],
+ "source": [
+ "def my_split(dataset, y, seed=SEED):\n",
+ "    random.seed(seed)  # random.seed() returns None, so hand split() the seeded state instead\n",
+ "    return dataset.split(random_state=random.getstate())  # default split: 70% train / 30% valid"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "eb8hVteK7dSH"
+ },
+ "source": [
+ "## Define and train neural net"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "MOlHycsspDdD"
+ },
+ "outputs": [],
+ "source": [
+ "net = NeuralNetClassifier(\n",
+ "    module=BERTGRUSentiment,\n",
+ "    module__bert=bert,\n",
+ "    module__hidden_dim=HIDDEN_DIM,\n",
+ "    module__output_dim=OUTPUT_DIM,\n",
+ "    module__n_layers=N_LAYERS,\n",
+ "    module__bidirectional=BIDIRECTIONAL,\n",
+ "    module__dropout=DROPOUT,\n",
+ "\n",
+ "    optimizer=torch.optim.Adam,\n",
+ "\n",
+ "    iterator_train=SkorchBucketIterator,\n",
+ "    iterator_valid=SkorchBucketIterator,\n",
+ "    train_split=my_split,\n",
+ "\n",
+ "    callbacks=[\n",
+ "        # don't update the pretrained bert model parameters\n",
+ "        Freezer(['bert*']),\n",
+ "        # each epoch takes many minutes on colab, uncomment the\n",
+ "        # next line to see a progress bar\n",
+ "        # ProgressBar(batches_per_epoch=len(ds_train) // 128 + 1),\n",
+ "    ],\n",
+ "    # fall back to the CPU when no CUDA device is available\n",
+ "    device='cuda' if torch.cuda.is_available() else 'cpu',\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 1000
+ },
+ "colab_type": "code",
+ "id": "lfYbWgA_rlOV",
+ "outputId": "ef36d66c-9a88-43fb-e54d-bc7fdcc3144b"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " epoch train_loss valid_acc valid_loss dur\n",
+ "------- ------------ ----------- ------------ ---------\n",
+ " 1 \u001b[36m0.7713\u001b[0m \u001b[32m0.8348\u001b[0m \u001b[35m0.3604\u001b[0m 1145.7862\n",
+ " 2 \u001b[36m0.3791\u001b[0m \u001b[32m0.8864\u001b[0m \u001b[35m0.2884\u001b[0m 1150.6875\n",
+ " 3 \u001b[36m0.3496\u001b[0m \u001b[32m0.8881\u001b[0m \u001b[35m0.2822\u001b[0m 1148.2985\n",
+ " 4 0.3603 \u001b[32m0.8900\u001b[0m \u001b[35m0.2790\u001b[0m 1147.4423\n",
+ " 5 \u001b[36m0.3465\u001b[0m \u001b[32m0.8977\u001b[0m \u001b[35m0.2656\u001b[0m 1146.9673\n",
+ " 6 0.3468 0.8861 0.2969 1149.5763\n",
+ " 7 \u001b[36m0.3455\u001b[0m 0.8787 0.2821 1148.0665\n",
+ " 8 \u001b[36m0.3348\u001b[0m 0.8929 0.2704 1147.7167\n",
+ " 9 0.3455 \u001b[32m0.8981\u001b[0m 0.2776 1148.3576\n",
+ " 10 \u001b[36m0.3329\u001b[0m \u001b[32m0.8988\u001b[0m 0.2670 1150.3590\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "[initialized](\n",
+ " module_=BERTGRUSentiment(\n",
+ " (bert): BertModel(\n",
+ " (embeddings): BertEmbeddings(\n",
+ " (word_embeddings): Embedding(30522, 768, padding_idx=0)\n",
+ " (position_embeddings): Embedding(512, 768)\n",
+ " (token_type_embeddings): Embedding(2, 768)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (encoder): BertEncoder(\n",
+ " (layer): ModuleList(\n",
+ " (0): BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (1): BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (2): BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (3): BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (4): BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (5): BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (6): BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (7): BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (8): BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (9): BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (10): BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (11): BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (pooler): BertPooler(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (activation): Tanh()\n",
+ " )\n",
+ " )\n",
+ " (rnn): GRU(768, 256, num_layers=2, batch_first=True, dropout=0.25, bidirectional=True)\n",
+ " (dropout): Dropout(p=0.25, inplace=False)\n",
+ " (out): Linear(in_features=512, out_features=2, bias=True)\n",
+ " (sm): Softmax(dim=-1)\n",
+ " ),\n",
+ ")"
+ ]
+ },
+ "execution_count": 17,
+ "metadata": {
+ "tags": []
+ },
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# we can set y=None because the labels are contained inside the dataset\n",
+ "net.fit(ds_train, y=None)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "ES5zTENXrm3J"
+ },
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "collapsed_sections": [],
+ "name": "skorch-torchtext-bert.ipynb",
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.7"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 1
+}