diff --git a/CHANGES.md b/CHANGES.md index 4ffd2b82d..4a997035b 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added the `event_name` argument for `LRScheduler` for optional recording of LR changes inside `net.history`. NOTE: Supported only in Pytorch>=1.4 - Make it easier to add custom modules or optimizers to a neural net class by automatically registering them where necessary and by making them available to set_params - Added the `step_every` argument for `LRScheduler` to set whether the scheduler step should be taken on every epoch or on every batch. +- Added a notebook that shows how to use a pretrained BERT in skorch with the help of torchtext and huggingface transformers ### Changed diff --git a/notebooks/MNIST-torchvision.ipynb b/notebooks/MNIST-torchvision.ipynb index c62e00bd3..f454ae221 100644 --- a/notebooks/MNIST-torchvision.ipynb +++ b/notebooks/MNIST-torchvision.ipynb @@ -12,7 +12,7 @@ "\n", " Run in Google Colab \n", "\n", - "View source on GitHub" + "View source on GitHub" ] }, { diff --git a/notebooks/README.md b/notebooks/README.md index a9217007c..a8ef9aa41 100644 --- a/notebooks/README.md +++ b/notebooks/README.md @@ -5,3 +5,4 @@ * [MNIST](https://nbviewer.jupyter.org/github/skorch-dev/skorch/blob/master/notebooks/MNIST.ipynb) * [MNIST using torchvision](https://nbviewer.jupyter.org/github/skorch-dev/skorch/blob/master/notebooks/MNIST-torchvision.ipynb) * [Transfer Learning](https://nbviewer.jupyter.org/github/skorch-dev/skorch/blob/master/notebooks/Transfer_Learning.ipynb) +* [torchtext and BERT](https://nbviewer.jupyter.org/github/skorch-dev/skorch/blob/master/notebooks/torchtext_bert.ipynb) \ No newline at end of file diff --git a/notebooks/torchtext_bert.ipynb b/notebooks/torchtext_bert.ipynb new file mode 100644 index 000000000..9f993ec69 --- /dev/null +++ b/notebooks/torchtext_bert.ipynb @@ -0,0 +1,838 @@ +{ + "cells": [ + { + "cell_type":
"markdown", + "metadata": { + "colab_type": "text", + "id": "YXZDKnk92yhl" + }, + "source": [ + "# Train a sentiment classifier using torchtext and BERT using skorch" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "7YsnM31l2jzf" + }, + "source": [ + "This notebook here is based on [another notebook](https://github.com/bentrevett/pytorch-sentiment-analysis/blob/master/6%20-%20Transformers%20for%20Sentiment%20Analysis.ipynb). Please check there for more details." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + " Run in Google Colab \n", + "\n", + "View source on GitHub
" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Note**: If you are running this in [a colab notebook](https://colab.research.google.com/github/skorch-dev/skorch/blob/master/notebooks/torchtext_bert.ipynb), we recommend you enable a free GPU by going:\n", + "\n", + "> **Runtime**   →   **Change runtime type**   →   **Hardware Accelerator: GPU**\n", + "\n", + "If you are running in colab, you should install the dependencies and download the dataset by running the following cell:" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "fNoXphO66yb1" + }, + "source": [ + "## Install packages" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "! [ ! -z \"$COLAB_GPU\" ] && pip install torch torchtext transformers skorch" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "QKgAeAtn67a-" + }, + "source": [ + "## Imports" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "OPqksHrsqntn" + }, + "outputs": [], + "source": [ + "import random" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "deqPnHqqqFY9" + }, + "outputs": [], + "source": [ + "import torch\n", + "import torchtext\n", + "from torch import nn\n", + "from torchtext.data import Field, LabelField\n", + "from torchtext.data import BucketIterator\n", + "from torchtext.datasets import IMDB\n", + "from transformers import BertTokenizer\n", + "from transformers import BertModel\n", + "from skorch import NeuralNetClassifier\n", + "from skorch.callbacks import Freezer\n", + "from skorch.callbacks import ProgressBar" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "cV1wda1E6-Pe" + }, + "source": [ + "## Constants" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + 
"metadata": { + "colab": {}, + "colab_type": "code", + "id": "LEhrsBHVqoTN" + }, + "outputs": [], + "source": [ + "SEED = 0\n", + "MAX_SEQ_LEN = 512 # discard everything after this many tokens, for speed\n", + "\n", + "torch.manual_seed(SEED)\n", + "torch.cuda.manual_seed(SEED)\n", + "torch.backends.cudnn.deterministic = True" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "EaRlTexx7DRC" + }, + "source": [ + "## Load data" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "When running this notebook for the first time, loading data and the pretrained model will take a couple of minutes." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "sECMSc_aqoQe" + }, + "outputs": [], + "source": [ + "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "BKKV0YiGqoOx" + }, + "outputs": [], + "source": [ + "def tokenize_and_cut(sentence):\n", + " tokens = tokenizer.tokenize(sentence) \n", + " tokens = tokens[:MAX_SEQ_LEN - 2]\n", + " return tokens" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "54lgG62wqoLR" + }, + "outputs": [], + "source": [ + "TEXT = Field(\n", + " batch_first=True,\n", + " use_vocab=False,\n", + " tokenize=tokenize_and_cut,\n", + " preprocessing=tokenizer.convert_tokens_to_ids,\n", + " init_token=tokenizer.cls_token_id,\n", + " eos_token=tokenizer.sep_token_id,\n", + " pad_token=tokenizer.pad_token_id,\n", + " unk_token=tokenizer.unk_token_id,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "kSJfnzOnqoJz" + }, + "outputs": [], + "source": [ + "LABEL = LabelField(dtype=torch.int64)" + ] + }, + { + "cell_type": "code", + 
"execution_count": 9, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 54 + }, + "colab_type": "code", + "id": "Vf5JqfIbqoG1", + "outputId": "62ff6e11-7061-43f1-b520-7e809229a7d9" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 3min 19s, sys: 1.34 s, total: 3min 20s\n", + "Wall time: 3min 20s\n" + ] + } + ], + "source": [ + "%%time\n", + "# make splits for data\n", + "ds_train, ds_test = IMDB.splits(TEXT, LABEL)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "0vofdyXkqoEc" + }, + "outputs": [], + "source": [ + "LABEL.build_vocab(ds_train)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "5NPNcyLGqoAt" + }, + "outputs": [], + "source": [ + "bert = BertModel.from_pretrained('bert-base-uncased')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "KIz8txtS7Qbe" + }, + "source": [ + "## Model definition" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "8F6xB4ehqn-e" + }, + "outputs": [], + "source": [ + "class BERTGRUSentiment(nn.Module):\n", + " def __init__(\n", + " self,\n", + " bert,\n", + " hidden_dim,\n", + " output_dim,\n", + " n_layers,\n", + " bidirectional,\n", + " dropout\n", + " ):\n", + " super().__init__()\n", + "\n", + " self.bert = bert\n", + " embedding_dim = bert.config.to_dict()['hidden_size']\n", + " self.rnn = nn.GRU(\n", + " embedding_dim,\n", + " hidden_dim,\n", + " num_layers=n_layers,\n", + " bidirectional=bidirectional,\n", + " batch_first=True,\n", + " dropout=0 if n_layers < 2 else dropout,\n", + " )\n", + "\n", + " self.dropout = nn.Dropout(dropout)\n", + " self.out = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, output_dim)\n", + " self.sm = nn.Softmax(dim=-1)\n", + "\n", + " def 
forward(self, text):\n", + " # text = [batch size, sent len]\n", + "\n", + " with torch.no_grad():\n", + " embedded = self.bert(text)[0]\n", + " # embedded = [batch size, sent len, emb dim]\n", + "\n", + " _, hidden = self.rnn(embedded)\n", + " # hidden = [n layers * n directions, batch size, emb dim]\n", + "\n", + " if self.rnn.bidirectional:\n", + " hidden = self.dropout(torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1))\n", + " else:\n", + " hidden = self.dropout(hidden[-1, :, :])\n", + " # hidden = [batch size, hid dim]\n", + "\n", + " output = self.out(hidden)\n", + " # output = [batch size, out dim]\n", + "\n", + " return self.sm(output)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "_oxEgqGlvwaX" + }, + "outputs": [], + "source": [ + "# model hyper-parameters\n", + "HIDDEN_DIM = 256\n", + "OUTPUT_DIM = 2\n", + "N_LAYERS = 2\n", + "BIDIRECTIONAL = True\n", + "DROPOUT = 0.25" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "vQ88u81C7XzV" + }, + "source": [ + "## Custom code" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "x9Q_ZmHLqn25" + }, + "outputs": [], + "source": [ + "class SkorchBucketIterator(BucketIterator):\n", + " def __iter__(self):\n", + " for batch in super().__iter__():\n", + " # We make a small modification: Instead of just returning batch\n", + " # we return batch.text and batch.label, corresponding to X and y\n", + " yield batch.text, batch.label.long()" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "104GBSjNqnzG" + }, + "outputs": [], + "source": [ + "def my_split(dataset, y, seed=SEED):\n", + " # use 70% of the training data for skorch-interval validation\n", + " return dataset.split(random_state=random.seed(seed))" + ] + }, + { + "cell_type": "markdown", + "metadata": 
{ + "colab_type": "text", + "id": "eb8hVteK7dSH" + }, + "source": [ + "## Define and train neural net" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "MOlHycsspDdD" + }, + "outputs": [], + "source": [ + "net = NeuralNetClassifier(\n", + " module=BERTGRUSentiment,\n", + " module__bert=bert,\n", + " module__hidden_dim=HIDDEN_DIM,\n", + " module__output_dim=OUTPUT_DIM,\n", + " module__n_layers=N_LAYERS,\n", + " module__bidirectional=BIDIRECTIONAL,\n", + " module__dropout=DROPOUT,\n", + "\n", + " optimizer=torch.optim.Adam,\n", + "\n", + " iterator_train=SkorchBucketIterator,\n", + " iterator_valid=SkorchBucketIterator,\n", + " train_split=my_split,\n", + "\n", + " callbacks=[\n", + " # don't update the pretrained bert model parameters\n", + " Freezer(['bert*']),\n", + " # each epoch takes many minutes on colab, uncomment the\n", + " # next line to see a progress bar\n", + " # ProgressBar(batches_per_epoch=len(ds_train) // 128 + 1),\n", + " ],\n", + "\n", + " device='cuda',\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000 + }, + "colab_type": "code", + "id": "lfYbWgA_rlOV", + "outputId": "ef36d66c-9a88-43fb-e54d-bc7fdcc3144b" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " epoch train_loss valid_acc valid_loss dur\n", + "------- ------------ ----------- ------------ ---------\n", + " 1 \u001b[36m0.7713\u001b[0m \u001b[32m0.8348\u001b[0m \u001b[35m0.3604\u001b[0m 1145.7862\n", + " 2 \u001b[36m0.3791\u001b[0m \u001b[32m0.8864\u001b[0m \u001b[35m0.2884\u001b[0m 1150.6875\n", + " 3 \u001b[36m0.3496\u001b[0m \u001b[32m0.8881\u001b[0m \u001b[35m0.2822\u001b[0m 1148.2985\n", + " 4 0.3603 \u001b[32m0.8900\u001b[0m \u001b[35m0.2790\u001b[0m 1147.4423\n", + " 5 \u001b[36m0.3465\u001b[0m \u001b[32m0.8977\u001b[0m \u001b[35m0.2656\u001b[0m 1146.9673\n", + " 6 
0.3468 0.8861 0.2969 1149.5763\n", + " 7 \u001b[36m0.3455\u001b[0m 0.8787 0.2821 1148.0665\n", + " 8 \u001b[36m0.3348\u001b[0m 0.8929 0.2704 1147.7167\n", + " 9 0.3455 \u001b[32m0.8981\u001b[0m 0.2776 1148.3576\n", + " 10 \u001b[36m0.3329\u001b[0m \u001b[32m0.8988\u001b[0m 0.2670 1150.3590\n" + ] + }, + { + "data": { + "text/plain": [ + "[initialized](\n", + " module_=BERTGRUSentiment(\n", + " (bert): BertModel(\n", + " (embeddings): BertEmbeddings(\n", + " (word_embeddings): Embedding(30522, 768, padding_idx=0)\n", + " (position_embeddings): Embedding(512, 768)\n", + " (token_type_embeddings): Embedding(2, 768)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (encoder): BertEncoder(\n", + " (layer): ModuleList(\n", + " (0): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (1): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, 
out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (2): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (3): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " 
(value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (4): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (5): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, 
out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (6): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (7): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " 
(dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (8): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (9): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, 
inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (10): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (11): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " 
(output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (pooler): BertPooler(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (activation): Tanh()\n", + " )\n", + " )\n", + " (rnn): GRU(768, 256, num_layers=2, batch_first=True, dropout=0.25, bidirectional=True)\n", + " (dropout): Dropout(p=0.25, inplace=False)\n", + " (out): Linear(in_features=512, out_features=2, bias=True)\n", + " (sm): Softmax(dim=-1)\n", + " ),\n", + ")" + ] + }, + "execution_count": 17, + "metadata": { + "tags": [] + }, + "output_type": "execute_result" + } + ], + "source": [ + "# we can set y=None because the labels are contained inside the dataset\n", + "net.fit(ds_train, y=None)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "ES5zTENXrm3J" + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "name": "skorch-torchtext-bert.ipynb", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.7" + } + }, + "nbformat": 4, + "nbformat_minor": 1 +}