{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "4ggSD_VLEFZm"
   },
   "source": [
    "# **Installing the packages**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 108102,
     "status": "ok",
     "timestamp": 1710843824570,
     "user": {
      "displayName": "Vageesh Saxena",
      "userId": "08383190130410030928"
     },
     "user_tz": -60
    },
    "id": "TkGiG4nYCm1G",
    "outputId": "a5a00ed1-1971-4036-c7d9-ac42c50e2775"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (1.5.3)\n",
      "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (1.25.2)\n",
      "Requirement already satisfied: plotly in /usr/local/lib/python3.10/dist-packages (5.15.0)\n",
      "Requirement already satisfied: scikit-learn in /usr/local/lib/python3.10/dist-packages (1.2.2)\n",
      "Requirement already satisfied: matplotlib in /usr/local/lib/python3.10/dist-packages (3.7.1)\n",
      "Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (2.2.1+cu121)\n",
      "Collecting lightning\n",
      "  Downloading lightning-2.2.1-py3-none-any.whl (2.1 MB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.1/2.1 MB\u001b[0m \u001b[31m13.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hRequirement already satisfied: spacy in /usr/local/lib/python3.10/dist-packages (3.7.4)\n",
      "Collecting torchtext==0.6\n",
      "  Downloading torchtext-0.6.0-py3-none-any.whl (64 kB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m64.2/64.2 kB\u001b[0m \u001b[31m3.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hRequirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from torchtext==0.6) (4.66.2)\n",
      "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from torchtext==0.6) (2.31.0)\n",
      "Requirement already satisfied: six in /usr/local/lib/python3.10/dist-packages (from torchtext==0.6) (1.16.0)\n",
      "Requirement already satisfied: sentencepiece in /usr/local/lib/python3.10/dist-packages (from torchtext==0.6) (0.1.99)\n",
      "Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from pandas) (2.8.2)\n",
      "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas) (2023.4)\n",
      "Requirement already satisfied: tenacity>=6.2.0 in /usr/local/lib/python3.10/dist-packages (from plotly) (8.2.3)\n",
      "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from plotly) (24.0)\n",
      "Requirement already satisfied: scipy>=1.3.2 in /usr/local/lib/python3.10/dist-packages (from scikit-learn) (1.11.4)\n",
      "Requirement already satisfied: joblib>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from scikit-learn) (1.3.2)\n",
      "Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn) (3.3.0)\n",
      "Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib) (1.2.0)\n",
      "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.10/dist-packages (from matplotlib) (0.12.1)\n",
      "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib) (4.49.0)\n",
      "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib) (1.4.5)\n",
      "Requirement already satisfied: pillow>=6.2.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib) (9.4.0)\n",
      "Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib) (3.1.2)\n",
      "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch) (3.13.1)\n",
      "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch) (4.10.0)\n",
      "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch) (1.12)\n",
      "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch) (3.2.1)\n",
      "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch) (3.1.3)\n",
      "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch) (2023.6.0)\n",
      "Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch)\n",
      "  Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m23.7/23.7 MB\u001b[0m \u001b[31m46.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hCollecting nvidia-cuda-runtime-cu12==12.1.105 (from torch)\n",
      "  Downloading nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m823.6/823.6 kB\u001b[0m \u001b[31m56.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hCollecting nvidia-cuda-cupti-cu12==12.1.105 (from torch)\n",
      "  Downloading nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m14.1/14.1 MB\u001b[0m \u001b[31m61.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hCollecting nvidia-cudnn-cu12==8.9.2.26 (from torch)\n",
      "  Downloading nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m731.7/731.7 MB\u001b[0m \u001b[31m1.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hCollecting nvidia-cublas-cu12==12.1.3.1 (from torch)\n",
      "  Downloading nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m410.6/410.6 MB\u001b[0m \u001b[31m2.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hCollecting nvidia-cufft-cu12==11.0.2.54 (from torch)\n",
      "  Downloading nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl (121.6 MB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m121.6/121.6 MB\u001b[0m \u001b[31m8.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hCollecting nvidia-curand-cu12==10.3.2.106 (from torch)\n",
      "  Downloading nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl (56.5 MB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m56.5/56.5 MB\u001b[0m \u001b[31m15.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hCollecting nvidia-cusolver-cu12==11.4.5.107 (from torch)\n",
      "  Downloading nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl (124.2 MB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m124.2/124.2 MB\u001b[0m \u001b[31m8.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hCollecting nvidia-cusparse-cu12==12.1.0.106 (from torch)\n",
      "  Downloading nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl (196.0 MB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m196.0/196.0 MB\u001b[0m \u001b[31m6.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hCollecting nvidia-nccl-cu12==2.19.3 (from torch)\n",
      "  Downloading nvidia_nccl_cu12-2.19.3-py3-none-manylinux1_x86_64.whl (166.0 MB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m166.0/166.0 MB\u001b[0m \u001b[31m5.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hCollecting nvidia-nvtx-cu12==12.1.105 (from torch)\n",
      "  Downloading nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (99 kB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m99.1/99.1 kB\u001b[0m \u001b[31m13.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hRequirement already satisfied: triton==2.2.0 in /usr/local/lib/python3.10/dist-packages (from torch) (2.2.0)\n",
      "Collecting nvidia-nvjitlink-cu12 (from nvidia-cusolver-cu12==11.4.5.107->torch)\n",
      "  Downloading nvidia_nvjitlink_cu12-12.4.99-py3-none-manylinux2014_x86_64.whl (21.1 MB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m21.1/21.1 MB\u001b[0m \u001b[31m79.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hRequirement already satisfied: PyYAML<8.0,>=5.4 in /usr/local/lib/python3.10/dist-packages (from lightning) (6.0.1)\n",
      "Collecting lightning-utilities<2.0,>=0.8.0 (from lightning)\n",
      "  Downloading lightning_utilities-0.11.0-py3-none-any.whl (25 kB)\n",
      "Collecting torchmetrics<3.0,>=0.7.0 (from lightning)\n",
      "  Downloading torchmetrics-1.3.2-py3-none-any.whl (841 kB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m841.5/841.5 kB\u001b[0m \u001b[31m66.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hCollecting pytorch-lightning (from lightning)\n",
      "  Downloading pytorch_lightning-2.2.1-py3-none-any.whl (801 kB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m801.6/801.6 kB\u001b[0m \u001b[31m65.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hRequirement already satisfied: spacy-legacy<3.1.0,>=3.0.11 in /usr/local/lib/python3.10/dist-packages (from spacy) (3.0.12)\n",
      "Requirement already satisfied: spacy-loggers<2.0.0,>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from spacy) (1.0.5)\n",
      "Requirement already satisfied: murmurhash<1.1.0,>=0.28.0 in /usr/local/lib/python3.10/dist-packages (from spacy) (1.0.10)\n",
      "Requirement already satisfied: cymem<2.1.0,>=2.0.2 in /usr/local/lib/python3.10/dist-packages (from spacy) (2.0.8)\n",
      "Requirement already satisfied: preshed<3.1.0,>=3.0.2 in /usr/local/lib/python3.10/dist-packages (from spacy) (3.0.9)\n",
      "Requirement already satisfied: thinc<8.3.0,>=8.2.2 in /usr/local/lib/python3.10/dist-packages (from spacy) (8.2.3)\n",
      "Requirement already satisfied: wasabi<1.2.0,>=0.9.1 in /usr/local/lib/python3.10/dist-packages (from spacy) (1.1.2)\n",
      "Requirement already satisfied: srsly<3.0.0,>=2.4.3 in /usr/local/lib/python3.10/dist-packages (from spacy) (2.4.8)\n",
      "Requirement already satisfied: catalogue<2.1.0,>=2.0.6 in /usr/local/lib/python3.10/dist-packages (from spacy) (2.0.10)\n",
      "Requirement already satisfied: weasel<0.4.0,>=0.1.0 in /usr/local/lib/python3.10/dist-packages (from spacy) (0.3.4)\n",
      "Requirement already satisfied: typer<0.10.0,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from spacy) (0.9.0)\n",
      "Requirement already satisfied: smart-open<7.0.0,>=5.2.1 in /usr/local/lib/python3.10/dist-packages (from spacy) (6.4.0)\n",
      "Requirement already satisfied: pydantic!=1.8,!=1.8.1,<3.0.0,>=1.7.4 in /usr/local/lib/python3.10/dist-packages (from spacy) (2.6.4)\n",
      "Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (from spacy) (67.7.2)\n",
      "Requirement already satisfied: langcodes<4.0.0,>=3.2.0 in /usr/local/lib/python3.10/dist-packages (from spacy) (3.3.0)\n",
      "Requirement already satisfied: aiohttp!=4.0.0a0,!=4.0.0a1 in /usr/local/lib/python3.10/dist-packages (from fsspec->torch) (3.9.3)\n",
      "Requirement already satisfied: annotated-types>=0.4.0 in /usr/local/lib/python3.10/dist-packages (from pydantic!=1.8,!=1.8.1,<3.0.0,>=1.7.4->spacy) (0.6.0)\n",
      "Requirement already satisfied: pydantic-core==2.16.3 in /usr/local/lib/python3.10/dist-packages (from pydantic!=1.8,!=1.8.1,<3.0.0,>=1.7.4->spacy) (2.16.3)\n",
      "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->torchtext==0.6) (3.3.2)\n",
      "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->torchtext==0.6) (3.6)\n",
      "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->torchtext==0.6) (2.0.7)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->torchtext==0.6) (2024.2.2)\n",
      "Requirement already satisfied: blis<0.8.0,>=0.7.8 in /usr/local/lib/python3.10/dist-packages (from thinc<8.3.0,>=8.2.2->spacy) (0.7.11)\n",
      "Requirement already satisfied: confection<1.0.0,>=0.0.1 in /usr/local/lib/python3.10/dist-packages (from thinc<8.3.0,>=8.2.2->spacy) (0.1.4)\n",
      "Requirement already satisfied: click<9.0.0,>=7.1.1 in /usr/local/lib/python3.10/dist-packages (from typer<0.10.0,>=0.3.0->spacy) (8.1.7)\n",
      "Requirement already satisfied: cloudpathlib<0.17.0,>=0.7.0 in /usr/local/lib/python3.10/dist-packages (from weasel<0.4.0,>=0.1.0->spacy) (0.16.0)\n",
      "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch) (2.1.5)\n",
      "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch) (1.3.0)\n",
      "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec->torch) (1.3.1)\n",
      "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec->torch) (23.2.0)\n",
      "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec->torch) (1.4.1)\n",
      "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec->torch) (6.0.5)\n",
      "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec->torch) (1.9.4)\n",
      "Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec->torch) (4.0.3)\n",
      "Installing collected packages: nvidia-nvtx-cu12, nvidia-nvjitlink-cu12, nvidia-nccl-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, lightning-utilities, nvidia-cusparse-cu12, nvidia-cudnn-cu12, nvidia-cusolver-cu12, torchtext, torchmetrics, pytorch-lightning, lightning\n",
      "  Attempting uninstall: torchtext\n",
      "    Found existing installation: torchtext 0.17.1\n",
      "    Uninstalling torchtext-0.17.1:\n",
      "      Successfully uninstalled torchtext-0.17.1\n",
      "Successfully installed lightning-2.2.1 lightning-utilities-0.11.0 nvidia-cublas-cu12-12.1.3.1 nvidia-cuda-cupti-cu12-12.1.105 nvidia-cuda-nvrtc-cu12-12.1.105 nvidia-cuda-runtime-cu12-12.1.105 nvidia-cudnn-cu12-8.9.2.26 nvidia-cufft-cu12-11.0.2.54 nvidia-curand-cu12-10.3.2.106 nvidia-cusolver-cu12-11.4.5.107 nvidia-cusparse-cu12-12.1.0.106 nvidia-nccl-cu12-2.19.3 nvidia-nvjitlink-cu12-12.4.99 nvidia-nvtx-cu12-12.1.105 pytorch-lightning-2.2.1 torchmetrics-1.3.2 torchtext-0.6.0\n",
      "Collecting en-core-web-sm==3.7.1\n",
      "  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl (12.8 MB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m12.8/12.8 MB\u001b[0m \u001b[31m45.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hRequirement already satisfied: spacy<3.8.0,>=3.7.2 in /usr/local/lib/python3.10/dist-packages (from en-core-web-sm==3.7.1) (3.7.4)\n",
      "Requirement already satisfied: spacy-legacy<3.1.0,>=3.0.11 in /usr/local/lib/python3.10/dist-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (3.0.12)\n",
      "Requirement already satisfied: spacy-loggers<2.0.0,>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (1.0.5)\n",
      "Requirement already satisfied: murmurhash<1.1.0,>=0.28.0 in /usr/local/lib/python3.10/dist-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (1.0.10)\n",
      "Requirement already satisfied: cymem<2.1.0,>=2.0.2 in /usr/local/lib/python3.10/dist-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (2.0.8)\n",
      "Requirement already satisfied: preshed<3.1.0,>=3.0.2 in /usr/local/lib/python3.10/dist-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (3.0.9)\n",
      "Requirement already satisfied: thinc<8.3.0,>=8.2.2 in /usr/local/lib/python3.10/dist-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (8.2.3)\n",
      "Requirement already satisfied: wasabi<1.2.0,>=0.9.1 in /usr/local/lib/python3.10/dist-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (1.1.2)\n",
      "Requirement already satisfied: srsly<3.0.0,>=2.4.3 in /usr/local/lib/python3.10/dist-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (2.4.8)\n",
      "Requirement already satisfied: catalogue<2.1.0,>=2.0.6 in /usr/local/lib/python3.10/dist-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (2.0.10)\n",
      "Requirement already satisfied: weasel<0.4.0,>=0.1.0 in /usr/local/lib/python3.10/dist-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (0.3.4)\n",
      "Requirement already satisfied: typer<0.10.0,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (0.9.0)\n",
      "Requirement already satisfied: smart-open<7.0.0,>=5.2.1 in /usr/local/lib/python3.10/dist-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (6.4.0)\n",
      "Requirement already satisfied: tqdm<5.0.0,>=4.38.0 in /usr/local/lib/python3.10/dist-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (4.66.2)\n",
      "Requirement already satisfied: requests<3.0.0,>=2.13.0 in /usr/local/lib/python3.10/dist-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (2.31.0)\n",
      "Requirement already satisfied: pydantic!=1.8,!=1.8.1,<3.0.0,>=1.7.4 in /usr/local/lib/python3.10/dist-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (2.6.4)\n",
      "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (3.1.3)\n",
      "Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (67.7.2)\n",
      "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (24.0)\n",
      "Requirement already satisfied: langcodes<4.0.0,>=3.2.0 in /usr/local/lib/python3.10/dist-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (3.3.0)\n",
      "Requirement already satisfied: numpy>=1.19.0 in /usr/local/lib/python3.10/dist-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (1.25.2)\n",
      "Requirement already satisfied: annotated-types>=0.4.0 in /usr/local/lib/python3.10/dist-packages (from pydantic!=1.8,!=1.8.1,<3.0.0,>=1.7.4->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (0.6.0)\n",
      "Requirement already satisfied: pydantic-core==2.16.3 in /usr/local/lib/python3.10/dist-packages (from pydantic!=1.8,!=1.8.1,<3.0.0,>=1.7.4->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (2.16.3)\n",
      "Requirement already satisfied: typing-extensions>=4.6.1 in /usr/local/lib/python3.10/dist-packages (from pydantic!=1.8,!=1.8.1,<3.0.0,>=1.7.4->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (4.10.0)\n",
      "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests<3.0.0,>=2.13.0->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (3.3.2)\n",
      "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests<3.0.0,>=2.13.0->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (3.6)\n",
      "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests<3.0.0,>=2.13.0->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (2.0.7)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests<3.0.0,>=2.13.0->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (2024.2.2)\n",
      "Requirement already satisfied: blis<0.8.0,>=0.7.8 in /usr/local/lib/python3.10/dist-packages (from thinc<8.3.0,>=8.2.2->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (0.7.11)\n",
      "Requirement already satisfied: confection<1.0.0,>=0.0.1 in /usr/local/lib/python3.10/dist-packages (from thinc<8.3.0,>=8.2.2->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (0.1.4)\n",
      "Requirement already satisfied: click<9.0.0,>=7.1.1 in /usr/local/lib/python3.10/dist-packages (from typer<0.10.0,>=0.3.0->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (8.1.7)\n",
      "Requirement already satisfied: cloudpathlib<0.17.0,>=0.7.0 in /usr/local/lib/python3.10/dist-packages (from weasel<0.4.0,>=0.1.0->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (0.16.0)\n",
      "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (2.1.5)\n",
      "\u001b[38;5;2m✔ Download and installation successful\u001b[0m\n",
      "You can now load the package via spacy.load('en_core_web_sm')\n",
      "\u001b[38;5;3m⚠ Restart to reload dependencies\u001b[0m\n",
      "If you are in a Jupyter or Colab notebook, you may need to restart Python in\n",
      "order to load all the package's dependencies. You can do this by selecting the\n",
      "'Restart kernel' or 'Restart runtime' option.\n"
     ]
    }
   ],
   "source": [
    "! pip install pandas numpy plotly scikit-learn matplotlib torch lightning spacy torchtext==0.6\n",
    "!python -m spacy download en_core_web_sm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "y492wv7nEqWr"
   },
   "source": [
    "# **Importing the libraries**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 72166,
     "status": "ok",
     "timestamp": 1710843896713,
     "user": {
      "displayName": "Vageesh Saxena",
      "userId": "08383190130410030928"
     },
     "user_tz": -60
    },
    "id": "NnXWOPX2Eo-E",
    "outputId": "a78dc5c0-662c-4434-ecaa-17dfa2fa1425"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mounted at /content/drive/\n"
     ]
    }
   ],
   "source": [
     "# Mount the drive if not mounted\n",
     "from google.colab import drive\n",
     "drive.mount(\"/content/drive/\")\n",
     "\n",
     "import time\n",
     "import re\n",
     "import string\n",
     "from collections import Counter\n",
     "\n",
     "import numpy as np\n",
     "import pandas as pd\n",
     "from tqdm import tqdm\n",
     "import spacy\n",
     "\n",
     "from sklearn.model_selection import train_test_split\n",
     "from sklearn.preprocessing import LabelEncoder\n",
     "from sklearn.metrics import f1_score, balanced_accuracy_score, classification_report\n",
     "\n",
     "import torch\n",
     "from torch import nn\n",
     "import torch.nn.functional as F\n",
     "from torch.optim import Adam\n",
     "from torch.utils.data import DataLoader, Dataset\n",
     "from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence\n",
     "\n",
     "from torchtext import data\n",
     "from torchtext.data import Field\n",
     "# NOTE(review): this `Dataset` shadows torch.utils.data.Dataset imported above.\n",
     "# The notebook only uses the torchtext one (via data.Dataset), but keep the\n",
     "# shadowing in mind if torch's Dataset is ever needed directly in this namespace.\n",
     "from torchtext.data import Dataset, Example\n",
     "from torchtext.data import BucketIterator\n",
     "from torchtext.vocab import FastText\n",
     "from torchtext.vocab import CharNGram\n",
     "\n",
     "#import wandb\n",
     "# wandb.login(relogin=True)\n",
     "\n",
     "import warnings\n",
     "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "executionInfo": {
     "elapsed": 1193,
     "status": "ok",
     "timestamp": 1710843897902,
     "user": {
      "displayName": "Vageesh Saxena",
      "userId": "08383190130410030928"
     },
     "user_tz": -60
    },
    "id": "_9-FVKNOglW8"
   },
   "outputs": [],
   "source": [
    "# Load spaCy's English model\n",
    "spacy_en = spacy.load('en_core_web_sm')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "executionInfo": {
     "elapsed": 14,
     "status": "ok",
     "timestamp": 1710843897903,
     "user": {
      "displayName": "Vageesh Saxena",
      "userId": "08383190130410030928"
     },
     "user_tz": -60
    },
    "id": "QcDuDu_yfaEi"
   },
   "outputs": [],
   "source": [
    "SEED = 42\n",
    "\n",
    "torch.manual_seed(SEED)\n",
    "torch.backends.cudnn.deterministic = True\n",
    "torch.backends.cudnn.benchmark = False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ZSRJH7TlMuYL"
   },
   "source": [
    "# **Loading and Pre-processing data**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "executionInfo": {
     "elapsed": 2714,
     "status": "ok",
     "timestamp": 1710843900605,
     "user": {
      "displayName": "Vageesh Saxena",
      "userId": "08383190130410030928"
     },
     "user_tz": -60
    },
    "id": "RGfbGRD-MkbN"
   },
   "outputs": [],
   "source": [
    "df = pd.read_csv(\"/content/drive/MyDrive/AA-Tutorial/data/Agora.csv\", encoding='ISO-8859-1')\n",
    "# Renaming all the features of the dataframe\n",
    "df = df.rename(str.strip, axis='columns')\n",
    "# Merging the Item and Item Description using a [SEP] token\n",
    "separator = ' [SEP] '\n",
    "df['TEXT'] = df.apply(lambda row: f\"{row['Item']}{separator}{row['Item Description']}\", axis=1)\n",
     "# Dropping unnecessary columns\n",
    "df.drop(columns=[\"Item\", \"Item Description\", \"Category\", \"Price\", \"Origin\", \"Destination\", \"Rating\", \"Remarks\"], inplace=True)\n",
    "# Assuming that vendors Amsterdam100 and amsterdam100 are the same vendors\n",
    "df.Vendor = df.Vendor.apply(lambda x: x.lower())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "3GtjBjHKM2SI"
   },
   "source": [
     "Due to the extensive time required to train on over 100K samples, we limit our analysis to a subset of 5K samples. Within this subset, we keep vendors that have 5+ advertisements and merge every vendor with fewer than 5 ads into a single class, \"others\"."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "executionInfo": {
     "elapsed": 23,
     "status": "ok",
     "timestamp": 1710843900606,
     "user": {
      "displayName": "Vageesh Saxena",
      "userId": "08383190130410030928"
     },
     "user_tz": -60
    },
    "id": "7c1GXGNUMzbX"
   },
   "outputs": [],
   "source": [
    "df = df.iloc[:5000]\n",
    "# Calculate advertisement frequency for each vendor\n",
    "ad_freq = df['Vendor'].value_counts()\n",
    "# Filter vendors with ad frequency less than 5\n",
    "vendors_to_replace = ad_freq[ad_freq < 5].index\n",
    "# Update DataFrame: Replace vendor names with 'others' where ad frequency is less than 5\n",
    "df['Vendor'] = df['Vendor'].apply(lambda x: 'others' if x in vendors_to_replace else x)\n",
    "# Getting all unique vendor handles from the 'Vendor' column.\n",
    "unique_vendors = df['Vendor'].unique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "executionInfo": {
     "elapsed": 21,
     "status": "ok",
     "timestamp": 1710843900606,
     "user": {
      "displayName": "Vageesh Saxena",
      "userId": "08383190130410030928"
     },
     "user_tz": -60
    },
    "id": "mVFyR15hcMo1"
   },
   "outputs": [],
   "source": [
    "# Creating a dictionary to map each unique vendor name to a unique integer.\n",
    "# The `enumerate` function is used to generate integer indices starting from 0 for each unique label found in `df['Vendor']`.\n",
    "# This effectively creates a label-to-index mapping.\n",
    "vendor2idx = {l: i for i, l in enumerate(df['Vendor'].unique())}\n",
    "\n",
     "# Applying the mapping to convert all categorical labels in the 'Vendor' column to integers.\n",
     "# The `apply` method goes through each label in `df['Vendor']`, and the lambda function uses the\n",
     "# mapping `vendor2idx` to find the corresponding integer. The result is a column of integer labels.\n",
    "df['Vendor'] = df['Vendor'].apply(lambda y: vendor2idx[y])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "6aj76lYZM7B5"
   },
   "source": [
    "# **Splitting data**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 21,
     "status": "ok",
     "timestamp": 1710843900607,
     "user": {
      "displayName": "Vageesh Saxena",
      "userId": "08383190130410030928"
     },
     "user_tz": -60
    },
    "id": "zVAXsHy1M5bx",
    "outputId": "5ffde6ed-2698-438f-f356-87ad688198e5"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training set size: 3750\n",
      "Test set size: 250\n",
      "Validation set size: 1000\n"
     ]
    }
   ],
   "source": [
    "train_df, temp_df = train_test_split(df, test_size=0.25, random_state=1111)\n",
     "# Fraction of the temporary set passed as `test_size` below. Note that the\n",
     "# *second* value returned by train_test_split (assigned to val_df) receives this\n",
     "# 0.8 share, giving train 75% / test 5% / validation 20% of the data.\n",
     "# TODO(review): the 0.20/0.05 naming suggests test 20% / val 5% may have been\n",
     "# intended -- confirm the current split is the desired one.\n",
    "test_size_in_temp = 0.20 / (0.20 + 0.05)\n",
    "# Now split the temporary set into test and validation sets\n",
    "test_df, val_df = train_test_split(temp_df, test_size=test_size_in_temp, random_state=1111)\n",
    "\n",
    "print(f\"Training set size: {len(train_df)}\")\n",
    "print(f\"Test set size: {len(test_df)}\")\n",
    "print(f\"Validation set size: {len(val_df)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "dwWJO9IbSxzu"
   },
   "source": [
    "# **Pytorch Dataset**\n",
    "\n",
     "The `AgoraDataset` class inherits from torchtext's `data.Dataset` and builds a dataset directly from a pandas DataFrame for use in model training. It relies on torchtext's `Example` and `Field` abstractions to process the text column, which differs from the plain PyTorch `Dataset` usage pattern."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "executionInfo": {
     "elapsed": 16,
     "status": "ok",
     "timestamp": 1710843900607,
     "user": {
      "displayName": "Vageesh Saxena",
      "userId": "08383190130410030928"
     },
     "user_tz": -60
    },
    "id": "38OHN_jfS0Wy"
   },
   "outputs": [],
   "source": [
     "class AgoraDataset(data.Dataset):\n",
     "    \"\"\"torchtext (0.6, legacy API) Dataset built from a pandas DataFrame\n",
     "    whose rows provide a `TEXT` string and a `Vendor` label.\"\"\"\n",
     "\n",
     "    # Constructor for initializing the dataset object.\n",
     "    # `fields` is the [(name, Field), ...] list that Example.fromlist expects.\n",
     "    def __init__(self, df, fields, is_test=False, **kwargs):\n",
     "        examples = []\n",
     "        # Iterating over each row in the DataFrame to construct dataset examples.\n",
     "        for i, row in df.iterrows():\n",
     "            # Assigning 'Vendor' as the label when is_test is False, and None otherwise.\n",
     "            label = row.Vendor if not is_test else None\n",
     "            # Extracting the text data from the row.\n",
     "            text = row.TEXT\n",
     "            # Creating an Example object for each row and appending it to the examples list.\n",
     "            examples.append(data.Example.fromlist([text, label], fields))\n",
     "\n",
     "        # Calling the constructor of the parent class (Dataset) with the examples and fields.\n",
     "        super().__init__(examples, fields, **kwargs)\n",
     "\n",
     "    # A static method to define the sort key used for sorting examples, based on their text length.\n",
     "    # BucketIterator uses this key to batch examples of similar length together.\n",
     "    @staticmethod\n",
     "    def sort_key(ex):\n",
     "        return len(ex.text)\n",
     "\n",
     "    # A class method to create dataset splits for training, validation, and testing.\n",
     "    # Returns only the datasets that were actually built, in (train, val, test) order.\n",
     "    @classmethod\n",
     "    def splits(cls, fields, train_df, val_df=None, test_df=None, **kwargs):\n",
     "        train_data, val_data, test_data = (None, None, None)\n",
     "\n",
     "        # Creating dataset objects for training, validation, and testing dataframes if they are provided.\n",
     "        if train_df is not None:\n",
     "            train_data = cls(train_df.copy(), fields, **kwargs)\n",
     "        if val_df is not None:\n",
     "            val_data = cls(val_df.copy(), fields, **kwargs)\n",
     "        if test_df is not None:\n",
     "            # NOTE(review): is_test is passed positionally as False here, so test\n",
     "            # examples keep their Vendor labels (needed for evaluation); the\n",
     "            # label=None branch in __init__ is never exercised by this notebook.\n",
     "            test_data = cls(test_df.copy(), fields, False, **kwargs)\n",
     "\n",
     "        # Returning the dataset objects as a tuple.\n",
     "        return tuple(d for d in (train_data, val_data, test_data) if d is not None)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "q3SHi216cz6n"
   },
   "source": [
    "# **Preparing text and label fields**\n",
    "\n",
    "This approach allows for the definition, preprocessing, and numericalization (token-to-index mapping) of text data, along with the integration of pre-trained word embeddings. For this project, we are using [Fasttext](https://fasttext.cc/) simple 300 dimension word embeddings.\n",
    "\n",
    "--- \n",
    "\n",
    "### FastText Embedding\n",
    "- **Word Representation:** FastText is an embedding technique that represents words in a high-dimensional space, capturing the semantic meaning of words by considering both the word itself and its subword components (n-grams).\n",
    "- **Handling OOV Words:** One of the strengths of FastText is its ability to generate representations for out-of-vocabulary (OOV) words by using the embeddings of subword n-grams, making it robust in handling rare or unseen words in the training data.\n",
    "- **Pre-trained Models:** FastText comes with pre-trained models on large corpora, allowing for the transfer of semantic knowledge to specific tasks without the need for extensive training data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 21915,
     "status": "ok",
     "timestamp": 1710843940099,
     "user": {
      "displayName": "Vageesh Saxena",
      "userId": "08383190130410030928"
     },
     "user_tz": -60
    },
    "id": "8aLpS4e_hMvP",
    "outputId": "e8cce298-82ac-4f27-d3b5-fad37b127e62"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      ".vector_cache/wiki.simple.vec: 293MB [00:01, 179MB/s]                           \n",
      "  0%|          | 0/111051 [00:00<?, ?it/s]WARNING:torchtext.vocab:Skipping token b'111051' with 1-dimensional vector [b'300']; likely a header\n",
      "100%|██████████| 111051/111051 [00:16<00:00, 6805.18it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Size of TEXT vocabulary: 14801\n",
      "Size of LABEL vocabulary: 153\n",
      "[('*', 11139), (' ', 7792), ('.', 4329), ('[', 3892), (']', 3888), ('-', 3883), ('SEP', 3750), ('...', 3004), (':', 2600), ('the', 2398)]\n"
     ]
    }
   ],
   "source": [
     "# Tokenizer: split text into token strings using the spaCy English pipeline\n",
     "# (spacy_en - presumably loaded in an earlier cell; confirm).\n",
     "def tokenize_en(text):\n",
     "    # spaCy Token objects -> plain strings\n",
     "    return [tok.text for tok in spacy_en.tokenizer(text)]\n",
     "\n",
     "# Field definitions: how raw strings become padded, numericalized tensors.\n",
     "# include_lengths=True makes each batch a (padded_tensor, lengths) pair, which the\n",
     "# model needs for pack_padded_sequence; batch_first puts the batch dimension first.\n",
     "TEXT = data.Field(tokenize = tokenize_en, batch_first=True, include_lengths = True)\n",
     "LABEL = data.LabelField(batch_first=True) # Single class label (vendor) per example\n",
     "\n",
     "# Column-to-Field mapping consumed by AgoraDataset / Example.fromlist.\n",
     "fields = [('text',TEXT), ('label',LABEL)]\n",
     "\n",
     "# Wrap the three DataFrames as torchtext datasets.\n",
     "train_ds, val_ds, test_ds = AgoraDataset.splits(fields, train_df=train_df, val_df=val_df, test_df=test_df)\n",
     "\n",
     "# Vocabulary is built from the TRAINING split only, so val/test tokens unseen in\n",
     "# training map to <unk> (initialized to the zero vector below).\n",
     "TEXT.build_vocab(train_ds,\n",
     "                 max_size = 100000, # Cap on vocabulary size\n",
     "                 vectors = 'fasttext.simple.300d', # Pre-trained FastText embeddings (300-dim)\n",
     "                 unk_init = torch.Tensor.zero_) # Unknown tokens start at the zero vector\n",
     "\n",
     "# Label vocabulary: maps each vendor label to a class index.\n",
     "LABEL.build_vocab(train_ds)\n",
     "\n",
     "print(\"Size of TEXT vocabulary:\",len(TEXT.vocab)) # Number of unique tokens in text\n",
     "print(\"Size of LABEL vocabulary:\",len(LABEL.vocab)) # Number of distinct label classes\n",
     "\n",
     "# Peek at the 10 most frequent tokens (punctuation dominates - a normalization candidate).\n",
     "print(TEXT.vocab.freqs.most_common(10)) # Commonly used words"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "_ftSudnzmtb_"
   },
   "source": [
    "# **Defining the model**\n",
    "\n",
     "The GRU_net class utilizes a [Bidirectional Gated Recurrent Unit (BiGRU)](https://en.wikipedia.org/wiki/Gated_recurrent_unit) for processing sequences. This class initializes its embedding layer with pre-trained FastText vectors (the embeddings remain trainable and are fine-tuned during training), processes sequences with a BiGRU to capture temporal dependencies, and applies a linear transformation to produce outputs for classification. The architecture is designed for tasks where leveraging pre-trained word embeddings can significantly enhance model performance by providing a rich, pre-learned representation of the input vocabulary.\n",
    "\n",
    "---\n",
    "\n",
    "### Bidirectional GRU\n",
    "- **Gated Recurrent Units (GRUs):** GRUs are a type of recurrent neural network (RNN) architecture that efficiently captures dependencies in sequential data, such as text, by processing input data both forwards and backwards.\n",
    "- **Bidirectional Processing:** A Bidirectional GRU processes data in both directions (forward and backward), allowing it to capture context from both past and future within the sequence. This is especially beneficial for understanding the context and meaning in text data.\n",
    "- **Handling Sequential Data:** Well-suited for tasks that involve sequential data, such as language modeling, text classification, and sentiment analysis, where the order of words is crucial for understanding the overall meaning.\n",
    "\n",
    "### Bidirectional GRU with FastText Embedding Classifier\n",
    "- **Combining Strengths:** By combining Bidirectional GRU with FastText embeddings, this classifier leverages both the contextual awareness of bidirectional processing and the rich semantic representations of FastText. This results in improved performance on text classification tasks.\n",
    "- **Suitable for Complex Tasks:** Especially effective for complex natural language processing tasks that require an understanding of nuanced context and semantics, such as sentiment analysis, question answering, and topic classification.\n",
    "- **Flexibility and Adaptability:** The approach is adaptable to various languages and domains, benefiting from FastText's robust handling of word representations and GRU's efficient processing of sequential data.\n",
    "\n",
    "Overall, a Bidirectional GRU with FastText embedding classifier represents a powerful combination for tackling a wide range of text classification challenges, offering both deep contextual understanding and rich semantic representation of words.\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "executionInfo": {
     "elapsed": 406,
     "status": "ok",
     "timestamp": 1710844043073,
     "user": {
      "displayName": "Vageesh Saxena",
      "userId": "08383190130410030928"
     },
     "user_tz": -60
    },
    "id": "xud08-Pxvznw"
   },
   "outputs": [],
   "source": [
    "class GRU_net(nn.Module):\n",
    "    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, n_layers,\n",
    "                 bidirectional, dropout, pad_idx):\n",
    "        # Initialize the parent class (nn.Module)\n",
    "        super().__init__()\n",
    "\n",
    "        # Embedding layer to transform indices into dense vectors of a fixed size\n",
    "        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_idx)\n",
    "\n",
    "        # Replace LSTM with GRU\n",
    "        self.rnn = nn.GRU(embedding_dim,\n",
    "                          hidden_dim,\n",
    "                          num_layers=n_layers,\n",
    "                          bidirectional=bidirectional,\n",
    "                          dropout=dropout,\n",
    "                          batch_first=True)\n",
    "\n",
    "        # Linear layer to map from hidden state space to hidden space\n",
    "        self.fc1 = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, hidden_dim)\n",
    "\n",
    "        # Linear layer to map from hidden space to output dimension\n",
    "        self.fc2 = nn.Linear(hidden_dim, output_dim)\n",
    "\n",
    "        # Dropout for regularization\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "\n",
    "    def forward(self, text, text_lengths):\n",
    "        # text: tensor of [batch size, sentence length]\n",
    "\n",
    "        # Pass text through embedding layer\n",
    "        embedded = self.embedding(text)\n",
    "        # embedded: tensor of [batch size, sentence length, embedding dimension]\n",
    "\n",
    "        # Pack sequence\n",
    "        packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, text_lengths.cpu(), batch_first=True)\n",
    "\n",
    "        # Pass packed sequence through GRU\n",
    "        packed_output, hidden = self.rnn(packed_embedded)\n",
    "\n",
    "        # Unpack sequence (if needed, not done in this code as GRU's output is not directly used after unpacking)\n",
    "\n",
    "        # Concatenate the final forward and backward hidden state\n",
    "        if self.rnn.bidirectional:\n",
    "            hidden = self.dropout(torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim=1))\n",
    "        else:\n",
    "            hidden = self.dropout(hidden[-1,:,:])\n",
    "\n",
    "        # Pass the output through the first fully connected layer\n",
    "        output = self.fc1(hidden)\n",
    "        # Apply dropout\n",
    "        output = self.dropout(output)\n",
    "        # Pass the output through the second fully connected layer\n",
    "        output = self.fc2(output)\n",
    "\n",
    "        # Return the final output\n",
    "        return output"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "hIBTIsrPQIQ5"
   },
   "source": [
    "# **Loading Model**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 248,
     "status": "ok",
     "timestamp": 1710845502036,
     "user": {
      "displayName": "Vageesh Saxena",
      "userId": "08383190130410030928"
     },
     "user_tz": -60
    },
    "id": "yIRRA2BWvB8F",
    "outputId": "ecca31d3-5da0-4a82-f311-24ae1dfcd1a2"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GRU_net(\n",
      "  (embedding): Embedding(14801, 300, padding_idx=1)\n",
      "  (rnn): GRU(300, 256, num_layers=2, batch_first=True, dropout=0.2, bidirectional=True)\n",
      "  (fc1): Linear(in_features=512, out_features=256, bias=True)\n",
      "  (fc2): Linear(in_features=256, out_features=153, bias=True)\n",
      "  (dropout): Dropout(p=0.2, inplace=False)\n",
      ")\n",
      "The model has 6,650,757 trainable parameters\n"
     ]
    }
   ],
   "source": [
    "# Set the batch size for training and evaluation\n",
    "BATCH_SIZE = 32\n",
    "\n",
    "# Determine the computing device based on the availability of CUDA\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "# Create iterators for the training and validation datasets\n",
    "train_iterator, valid_iterator = BucketIterator.splits(\n",
    "    (train_ds, val_ds), # Datasets for training and validation\n",
    "    batch_size = BATCH_SIZE, # Batch size for both datasets\n",
    "    sort_within_batch = True, # Sort examples within each batch by their lengths\n",
    "    device = device) # Specify the computing device\n",
    "\n",
    "# Create an iterator for the test dataset\n",
    "_, test_iterator = BucketIterator.splits(\n",
    "    (train_ds, test_ds), # Reuse train_ds to keep the structure, but focus is on test_ds\n",
    "    batch_size = BATCH_SIZE, # Batch size for the dataset\n",
    "    sort_within_batch = True, # Sort examples within each batch by their lengths\n",
    "    device = device) # Specify the computing device\n",
    "\n",
    "# Set hyperparameters for the model training\n",
    "num_epochs = 10 # Number of epochs to train for\n",
    "learning_rate = 0.001 # Learning rate for the optimizer\n",
    "\n",
    "# Define the model architecture parameters\n",
    "INPUT_DIM = len(TEXT.vocab) # Vocabulary size\n",
    "EMBEDDING_DIM = 300 # Size of each embedding vector\n",
    "HIDDEN_DIM = 256 # Size of hidden layers\n",
    "OUTPUT_DIM = 153 # Size of the output layer; Change this accordingly as you increase the size of the dataset\n",
    "N_LAYERS = 2 # Number of recurrent layers\n",
    "BIDIRECTIONAL = True # Use a bidirectional model\n",
    "DROPOUT = 0.2 # Dropout rate for regularization\n",
    "PAD_IDX = TEXT.vocab.stoi[TEXT.pad_token] # Index of the padding token for embedding layer\n",
    "\n",
    "# Instantiate the model with the specified architecture parameters\n",
    "model = GRU_net(INPUT_DIM,\n",
    "                EMBEDDING_DIM,\n",
    "                HIDDEN_DIM,\n",
    "                OUTPUT_DIM,\n",
    "                N_LAYERS,\n",
    "                BIDIRECTIONAL,\n",
    "                DROPOUT,\n",
    "                PAD_IDX)\n",
    "\n",
    "# Print the model architecture for review\n",
    "print(model)\n",
    "\n",
    "# Function to count the number of trainable parameters in the model\n",
    "def count_parameters(model):\n",
    "    # Sum the number of elements in all parameters that require gradient computation\n",
    "    return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "\n",
    "# Print the total number of trainable parameters in the model\n",
    "print(f'The model has {count_parameters(model):,} trainable parameters')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {
    "executionInfo": {
     "elapsed": 5,
     "status": "ok",
     "timestamp": 1710845502283,
     "user": {
      "displayName": "Vageesh Saxena",
      "userId": "08383190130410030928"
     },
     "user_tz": -60
    },
    "id": "CbIfwMT7vzuf"
   },
   "outputs": [],
   "source": [
    "# Load pre-trained embeddings from the TEXT field's vocabulary into a variable\n",
    "pretrained_embeddings = TEXT.vocab.vectors\n",
    "\n",
    "# Copy the pre-trained embeddings into the model's embedding layer weights\n",
    "model.embedding.weight.data.copy_(pretrained_embeddings)\n",
    "\n",
    "# Initialize the embedding vector for the padding index (PAD_IDX) to all zeros\n",
    "# This is done to ensure that the padding token does not contribute to the model's predictions\n",
    "model.embedding.weight.data[PAD_IDX] = torch.zeros(EMBEDDING_DIM)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "OcQXiwFcRhvw"
   },
   "source": [
    "# **Helper functions**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {
    "executionInfo": {
     "elapsed": 258,
     "status": "ok",
     "timestamp": 1710845504278,
     "user": {
      "displayName": "Vageesh Saxena",
      "userId": "08383190130410030928"
     },
     "user_tz": -60
    },
    "id": "5qwEsMSdvzws"
   },
   "outputs": [],
   "source": [
    "def train(model, iterator, optimizer, criterion):\n",
    "    epoch_loss = 0\n",
    "    all_predictions = []\n",
    "    all_labels = []\n",
    "\n",
    "    model.train()\n",
    "\n",
    "    for batch in iterator:\n",
    "        text, text_lengths = batch.text\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        predictions = model(text, text_lengths)\n",
    "        # For multi-class classification, predictions are not squeezed\n",
    "        # predictions shape is [batch size, n_classes]\n",
    "\n",
    "        loss = criterion(predictions, batch.label)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        epoch_loss += loss.item()\n",
    "\n",
    "        # Convert predictions to actual class numbers\n",
    "        _, predicted_labels = torch.max(predictions, 1)\n",
    "\n",
    "        # Collect all labels and predictions for metric calculation\n",
    "        all_labels.extend(batch.label.cpu().numpy())\n",
    "        all_predictions.extend(predicted_labels.cpu().numpy())\n",
    "\n",
    "    # Calculate metrics for multi-class classification\n",
    "    balanced_acc = balanced_accuracy_score(all_labels, all_predictions)\n",
    "    weighted_f1 = f1_score(all_labels, all_predictions, average='weighted')\n",
    "    micro_f1 = f1_score(all_labels, all_predictions, average='micro')\n",
    "    macro_f1 = f1_score(all_labels, all_predictions, average='macro')\n",
    "\n",
    "    return {\n",
    "        \"loss\": epoch_loss / len(iterator),\n",
    "        \"balanced_accuracy\": balanced_acc,\n",
    "        \"weighted_f1\": weighted_f1,\n",
    "        \"micro_f1\": micro_f1,\n",
    "        \"macro_f1\": macro_f1\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {
    "executionInfo": {
     "elapsed": 3,
     "status": "ok",
     "timestamp": 1710845504991,
     "user": {
      "displayName": "Vageesh Saxena",
      "userId": "08383190130410030928"
     },
     "user_tz": -60
    },
    "id": "XqQ7PI9TvzyM"
   },
   "outputs": [],
   "source": [
    "def evaluate(model, iterator, criterion):\n",
    "    # Initialize variables to accumulate loss and store predictions and labels\n",
    "    epoch_loss = 0\n",
    "    all_predictions = []\n",
    "    all_labels = []\n",
    "\n",
    "    # Set model to evaluation mode (turns off dropout and batch normalization)\n",
    "    model.eval()\n",
    "\n",
    "    # Disable gradient calculations to speed up the process\n",
    "    with torch.no_grad():\n",
    "        for batch in iterator:\n",
    "            # Extract text and its lengths from the current batch\n",
    "            text, text_lengths = batch.text\n",
    "\n",
    "            # Generate predictions using the model\n",
    "            predictions = model(text, text_lengths).squeeze(1)\n",
    "\n",
    "            # Compute loss for the current batch\n",
    "            loss = criterion(predictions, batch.label)\n",
    "\n",
    "            # Accumulate the loss over all batches\n",
    "            epoch_loss += loss.item()\n",
    "\n",
    "\n",
    "            # Convert predictions to actual class numbers\n",
    "            _, predicted_labels = torch.max(predictions, 1)\n",
    "\n",
    "            # Collect all labels and predictions for metric calculation\n",
    "            all_labels.extend(batch.label.cpu().numpy())\n",
    "            all_predictions.extend(predicted_labels.cpu().numpy())\n",
    "\n",
    "    # Calculate metrics using accumulated predictions and true labels\n",
    "    balanced_acc = balanced_accuracy_score(all_labels, all_predictions)\n",
    "    weighted_f1 = f1_score(all_labels, all_predictions, average='weighted')\n",
    "    micro_f1 = f1_score(all_labels, all_predictions, average='micro')\n",
    "    macro_f1 = f1_score(all_labels, all_predictions, average='macro')\n",
    "\n",
    "    # Return loss and calculated metrics\n",
    "    return {\n",
    "        \"loss\": epoch_loss / len(iterator),\n",
    "        \"balanced_accuracy\": balanced_acc,\n",
    "        \"weighted_f1\": weighted_f1,\n",
    "        \"micro_f1\": micro_f1,\n",
    "        \"macro_f1\": macro_f1\n",
    "    }"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "fYD4o_bcV3IA"
   },
   "source": [
    "# **Training Model**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 17235,
     "status": "ok",
     "timestamp": 1710845522989,
     "user": {
      "displayName": "Vageesh Saxena",
      "userId": "08383190130410030928"
     },
     "user_tz": -60
    },
    "id": "_QSsFIG8vz0J",
    "outputId": "9015c82a-2062-4a09-d874-ce1a45937316"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 1\n",
      "\tTrain Loss: 3.8779 | Train Balanced Acc: 0.0317\n",
      "\tValid Loss: 2.8381 | Valid Balanced Acc: 0.1062\n",
      "\tTrain Weighted F1: 0.1368 | Valid Weighted F1: 0.2793\n",
      "\tTrain Micro F1: 0.1971 | Valid Micro F1: 0.3850\n",
      "\tTrain Macro F1: 0.0292 | Valid Macro F1: 0.0753\n",
      "Epoch: 2\n",
      "\tTrain Loss: 2.2533 | Train Balanced Acc: 0.1625\n",
      "\tValid Loss: 1.9536 | Valid Balanced Acc: 0.2607\n",
      "\tTrain Weighted F1: 0.4234 | Valid Weighted F1: 0.4969\n",
      "\tTrain Micro F1: 0.4912 | Valid Micro F1: 0.5600\n",
      "\tTrain Macro F1: 0.1497 | Valid Macro F1: 0.2192\n",
      "Epoch: 3\n",
      "\tTrain Loss: 1.2362 | Train Balanced Acc: 0.3692\n",
      "\tValid Loss: 1.7194 | Valid Balanced Acc: 0.4139\n",
      "\tTrain Weighted F1: 0.6604 | Valid Weighted F1: 0.5899\n",
      "\tTrain Micro F1: 0.6949 | Valid Micro F1: 0.6150\n",
      "\tTrain Macro F1: 0.3609 | Valid Macro F1: 0.3760\n",
      "Epoch: 4\n",
      "\tTrain Loss: 0.5824 | Train Balanced Acc: 0.6204\n",
      "\tValid Loss: 1.8392 | Valid Balanced Acc: 0.5031\n",
      "\tTrain Weighted F1: 0.8353 | Valid Weighted F1: 0.6169\n",
      "\tTrain Micro F1: 0.8461 | Valid Micro F1: 0.6320\n",
      "\tTrain Macro F1: 0.6295 | Valid Macro F1: 0.4596\n",
      "Epoch: 5\n",
      "\tTrain Loss: 0.3035 | Train Balanced Acc: 0.7922\n",
      "\tValid Loss: 1.8745 | Valid Balanced Acc: 0.5800\n",
      "\tTrain Weighted F1: 0.9142 | Valid Weighted F1: 0.6727\n",
      "\tTrain Micro F1: 0.9176 | Valid Micro F1: 0.6780\n",
      "\tTrain Macro F1: 0.7958 | Valid Macro F1: 0.5422\n",
      "Epoch: 6\n",
      "\tTrain Loss: 0.1190 | Train Balanced Acc: 0.9271\n",
      "\tValid Loss: 2.1219 | Valid Balanced Acc: 0.5902\n",
      "\tTrain Weighted F1: 0.9735 | Valid Weighted F1: 0.6575\n",
      "\tTrain Micro F1: 0.9741 | Valid Micro F1: 0.6660\n",
      "\tTrain Macro F1: 0.9322 | Valid Macro F1: 0.5446\n",
      "Epoch: 7\n",
      "\tTrain Loss: 0.0689 | Train Balanced Acc: 0.9691\n",
      "\tValid Loss: 2.1048 | Valid Balanced Acc: 0.5941\n",
      "\tTrain Weighted F1: 0.9848 | Valid Weighted F1: 0.6816\n",
      "\tTrain Micro F1: 0.9848 | Valid Micro F1: 0.6880\n",
      "\tTrain Macro F1: 0.9702 | Valid Macro F1: 0.5696\n",
      "Epoch: 8\n",
      "\tTrain Loss: 0.0381 | Train Balanced Acc: 0.9903\n",
      "\tValid Loss: 2.1507 | Valid Balanced Acc: 0.6149\n",
      "\tTrain Weighted F1: 0.9923 | Valid Weighted F1: 0.6815\n",
      "\tTrain Micro F1: 0.9923 | Valid Micro F1: 0.6930\n",
      "\tTrain Macro F1: 0.9898 | Valid Macro F1: 0.5875\n",
      "Epoch: 9\n",
      "\tTrain Loss: 0.0223 | Train Balanced Acc: 0.9951\n",
      "\tValid Loss: 2.2115 | Valid Balanced Acc: 0.6162\n",
      "\tTrain Weighted F1: 0.9957 | Valid Weighted F1: 0.6837\n",
      "\tTrain Micro F1: 0.9957 | Valid Micro F1: 0.6880\n",
      "\tTrain Macro F1: 0.9958 | Valid Macro F1: 0.5693\n",
      "Epoch: 10\n",
      "\tTrain Loss: 0.0141 | Train Balanced Acc: 0.9971\n",
      "\tValid Loss: 2.2186 | Valid Balanced Acc: 0.6223\n",
      "\tTrain Weighted F1: 0.9973 | Valid Weighted F1: 0.6858\n",
      "\tTrain Micro F1: 0.9973 | Valid Micro F1: 0.6950\n",
      "\tTrain Macro F1: 0.9971 | Valid Macro F1: 0.5870\n",
      "Time taken: 17.311 seconds\n"
     ]
    }
   ],
   "source": [
    "# Record start time\n",
    "t = time.time()\n",
    "\n",
    "# Initialize the best validation loss to infinity\n",
    "best_valid_loss = float('inf')\n",
    "\n",
    "# Move the model to the appropriate device (GPU or CPU)\n",
    "model.to(device)\n",
    "\n",
    "# Define the loss function\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "batch\n",
    "# Define the optimizer\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n",
    "\n",
    "# Training and evaluation loop\n",
    "for epoch in range(num_epochs):\n",
    "    # Train the model and retrieve metrics\n",
    "    train_metrics = train(model, train_iterator, optimizer, criterion)\n",
    "\n",
    "    # Evaluate the model on the validation set and retrieve metrics\n",
    "    valid_metrics = evaluate(model, valid_iterator, criterion)\n",
    "\n",
    "    # Print training and validation metrics\n",
    "    print(f'Epoch: {epoch+1}')\n",
    "    print(f'\\tTrain Loss: {train_metrics[\"loss\"]:.4f} | Train Balanced Acc: {train_metrics[\"balanced_accuracy\"]:.4f}')\n",
    "    print(f'\\tValid Loss: {valid_metrics[\"loss\"]:.4f} | Valid Balanced Acc: {valid_metrics[\"balanced_accuracy\"]:.4f}')\n",
    "    print(f'\\tTrain Weighted F1: {train_metrics[\"weighted_f1\"]:.4f} | Valid Weighted F1: {valid_metrics[\"weighted_f1\"]:.4f}')\n",
    "    print(f'\\tTrain Micro F1: {train_metrics[\"micro_f1\"]:.4f} | Valid Micro F1: {valid_metrics[\"micro_f1\"]:.4f}')\n",
    "    print(f'\\tTrain Macro F1: {train_metrics[\"macro_f1\"]:.4f} | Valid Macro F1: {valid_metrics[\"macro_f1\"]:.4f}')\n",
    "\n",
    "    # Update lists to track loss and accuracy (if necessary for later analysis)\n",
    "    # loss.append(train_metrics[\"loss\"])\n",
    "    # acc.append(train_metrics[\"balanced_accuracy\"])\n",
    "    # val_acc.append(valid_metrics[\"balanced_accuracy\"])\n",
    "\n",
    "    # Check if the current model is the best one based on validation loss\n",
    "    # if valid_metrics[\"loss\"] < best_valid_loss:\n",
    "        # best_valid_loss = valid_metrics[\"loss\"]\n",
    "        # Save the current best model\n",
    "        # torch.save(model.state_dict(), 'best_model.pt')\n",
    "\n",
    "# Calculate and print the total time taken for training and evaluation\n",
    "print(f'Time taken: {time.time()-t:.3f} seconds')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "i-aN90EWV628"
   },
   "source": [
    "# **Testing on Test Dataset**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 310,
     "status": "ok",
     "timestamp": 1710845608163,
     "user": {
      "displayName": "Vageesh Saxena",
      "userId": "08383190130410030928"
     },
     "user_tz": -60
    },
    "id": "gbzy83dAvz2S",
    "outputId": "0bb342d4-f61b-4289-ef23-77eefeb04a43"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'loss': 1.9752449542284012,\n",
       " 'balanced_accuracy': 0.667471480425201,\n",
       " 'weighted_f1': 0.7113831447711159,\n",
       " 'micro_f1': 0.708,\n",
       " 'macro_f1': 0.5553573957836322}"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "evaluate(model, test_iterator, criterion)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "myPAbhwcWXgv"
   },
   "source": [
    "# **Loading the Results Dataframe**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {
    "executionInfo": {
     "elapsed": 745,
     "status": "ok",
     "timestamp": 1710845697781,
     "user": {
      "displayName": "Vageesh Saxena",
      "userId": "08383190130410030928"
     },
     "user_tz": -60
    },
    "id": "RLnMzu_jvz4L"
   },
   "outputs": [],
   "source": [
    "results_df = pd.read_csv(\"/content/drive/MyDrive/AA-Tutorial/data/results.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {
    "executionInfo": {
     "elapsed": 7,
     "status": "ok",
     "timestamp": 1710845763305,
     "user": {
      "displayName": "Vageesh Saxena",
      "userId": "08383190130410030928"
     },
     "user_tz": -60
    },
    "id": "QEdf0o0dvz6t"
   },
   "outputs": [],
   "source": [
    "results_df[\"Bi-GRU\"] = [0.6674714, 0.7113831, 0.708, 0.5553573]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 175
    },
    "executionInfo": {
     "elapsed": 290,
     "status": "ok",
     "timestamp": 1710845767608,
     "user": {
      "displayName": "Vageesh Saxena",
      "userId": "08383190130410030928"
     },
     "user_tz": -60
    },
    "id": "I1ipazv8vz8c",
    "outputId": "0a77784b-05d1-4491-8c2c-2bcf1145fef9"
   },
   "outputs": [
    {
     "data": {
      "application/vnd.google.colaboratory.intrinsic+json": {
       "summary": "{\n  \"name\": \"results_df\",\n  \"rows\": 4,\n  \"fields\": [\n    {\n      \"column\": \"Metrics\",\n      \"properties\": {\n        \"dtype\": \"string\",\n        \"num_unique_values\": 4,\n        \"samples\": [\n          \"Weighted-F1\",\n          \"Macro-F1\",\n          \"Accuracy\"\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"MultinomialNB\",\n      \"properties\": {\n        \"dtype\": \"number\",\n        \"std\": 0.1336458759990484,\n        \"min\": 0.331296992481203,\n        \"max\": 0.596,\n        \"num_unique_values\": 4,\n        \"samples\": [\n          0.5244050525636352,\n          0.3374360632962008,\n          0.331296992481203\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"SVC\",\n      \"properties\": {\n        \"dtype\": \"number\",\n        \"std\": 0.016578898600449716,\n        \"min\": 0.6719450270183899,\n        \"max\": 0.7084807843862542,\n        \"num_unique_values\": 4,\n        \"samples\": [\n          0.7084807843862542,\n          0.6719450270183899,\n          0.6823223509702638\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"RandomForest\",\n      \"properties\": {\n        \"dtype\": \"number\",\n        \"std\": 0.052421048068768586,\n        \"min\": 0.283832389946941,\n        \"max\": 0.4121224397199897,\n        \"num_unique_values\": 4,\n        \"samples\": [\n          0.3442897309776566,\n          0.283832389946941,\n          0.4121224397199897\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"LogisticRegression\",\n      \"properties\": {\n        \"dtype\": \"number\",\n        \"std\": 0.07654824206503479,\n        \"min\": 0.5900274131375507,\n        \"max\": 0.7757557438024769,\n        
\"num_unique_values\": 4,\n        \"samples\": [\n          0.6622835815165227,\n          0.5900274131375507,\n          0.7757557438024769\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"MLP\",\n      \"properties\": {\n        \"dtype\": \"number\",\n        \"std\": 0.05481574796019103,\n        \"min\": 0.6466153885550759,\n        \"max\": 0.756,\n        \"num_unique_values\": 4,\n        \"samples\": [\n          0.7522883095118389,\n          0.6466153885550759,\n          0.6766820067409904\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"All-miniLM\",\n      \"properties\": {\n        \"dtype\": \"number\",\n        \"std\": 0.06650668636432978,\n        \"min\": 0.552629,\n        \"max\": 0.703,\n        \"num_unique_values\": 4,\n        \"samples\": [\n          0.6820589,\n          0.552629,\n          0.6475543\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"DistilBERT\",\n      \"properties\": {\n        \"dtype\": \"number\",\n        \"std\": 0.04528406566353656,\n        \"min\": 0.70647042,\n        \"max\": 0.808,\n        \"num_unique_values\": 4,\n        \"samples\": [\n          0.7948137,\n          0.70647042,\n          0.77843189\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"DistilRoBERTa\",\n      \"properties\": {\n        \"dtype\": \"number\",\n        \"std\": 0.04692494802859787,\n        \"min\": 0.707505643,\n        \"max\": 0.81099999,\n        \"num_unique_values\": 4,\n        \"samples\": [\n          0.801566,\n          0.707505643,\n          0.7820527\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"Bi-GRU\",\n      \"properties\": {\n        \"dtype\": 
\"number\",\n        \"std\": 0.07291299937151123,\n        \"min\": 0.5553573,\n        \"max\": 0.7113831,\n        \"num_unique_values\": 4,\n        \"samples\": [\n          0.7113831,\n          0.5553573,\n          0.6674714\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    }\n  ]\n}",
       "type": "dataframe",
       "variable_name": "results_df"
      },
      "text/html": [
       "\n",
       "  <div id=\"df-f873d53f-dcb7-4114-9f5a-bb4cbae59983\" class=\"colab-df-container\">\n",
       "    <div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Metrics</th>\n",
       "      <th>MultinomialNB</th>\n",
       "      <th>SVC</th>\n",
       "      <th>RandomForest</th>\n",
       "      <th>LogisticRegression</th>\n",
       "      <th>MLP</th>\n",
       "      <th>All-miniLM</th>\n",
       "      <th>DistilBERT</th>\n",
       "      <th>DistilRoBERTa</th>\n",
       "      <th>Bi-GRU</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Accuracy</td>\n",
       "      <td>0.331297</td>\n",
       "      <td>0.682322</td>\n",
       "      <td>0.412122</td>\n",
       "      <td>0.775756</td>\n",
       "      <td>0.676682</td>\n",
       "      <td>0.647554</td>\n",
       "      <td>0.778432</td>\n",
       "      <td>0.782053</td>\n",
       "      <td>0.667471</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Weighted-F1</td>\n",
       "      <td>0.524405</td>\n",
       "      <td>0.708481</td>\n",
       "      <td>0.344290</td>\n",
       "      <td>0.662284</td>\n",
       "      <td>0.752288</td>\n",
       "      <td>0.682059</td>\n",
       "      <td>0.794814</td>\n",
       "      <td>0.801566</td>\n",
       "      <td>0.711383</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Micro-F1</td>\n",
       "      <td>0.596000</td>\n",
       "      <td>0.700000</td>\n",
       "      <td>0.344000</td>\n",
       "      <td>0.668000</td>\n",
       "      <td>0.756000</td>\n",
       "      <td>0.703000</td>\n",
       "      <td>0.808000</td>\n",
       "      <td>0.811000</td>\n",
       "      <td>0.708000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Macro-F1</td>\n",
       "      <td>0.337436</td>\n",
       "      <td>0.671945</td>\n",
       "      <td>0.283832</td>\n",
       "      <td>0.590027</td>\n",
       "      <td>0.646615</td>\n",
       "      <td>0.552629</td>\n",
       "      <td>0.706470</td>\n",
       "      <td>0.707506</td>\n",
       "      <td>0.555357</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>\n",
       "    <div class=\"colab-df-buttons\">\n",
       "\n",
       "  <div class=\"colab-df-container\">\n",
       "    <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-f873d53f-dcb7-4114-9f5a-bb4cbae59983')\"\n",
       "            title=\"Convert this dataframe to an interactive table.\"\n",
       "            style=\"display:none;\">\n",
       "\n",
       "  <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\" viewBox=\"0 -960 960 960\">\n",
       "    <path d=\"M120-120v-720h720v720H120Zm60-500h600v-160H180v160Zm220 220h160v-160H400v160Zm0 220h160v-160H400v160ZM180-400h160v-160H180v160Zm440 0h160v-160H620v160ZM180-180h160v-160H180v160Zm440 0h160v-160H620v160Z\"/>\n",
       "  </svg>\n",
       "    </button>\n",
       "\n",
       "  <style>\n",
       "    .colab-df-container {\n",
       "      display:flex;\n",
       "      gap: 12px;\n",
       "    }\n",
       "\n",
       "    .colab-df-convert {\n",
       "      background-color: #E8F0FE;\n",
       "      border: none;\n",
       "      border-radius: 50%;\n",
       "      cursor: pointer;\n",
       "      display: none;\n",
       "      fill: #1967D2;\n",
       "      height: 32px;\n",
       "      padding: 0 0 0 0;\n",
       "      width: 32px;\n",
       "    }\n",
       "\n",
       "    .colab-df-convert:hover {\n",
       "      background-color: #E2EBFA;\n",
       "      box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
       "      fill: #174EA6;\n",
       "    }\n",
       "\n",
       "    .colab-df-buttons div {\n",
       "      margin-bottom: 4px;\n",
       "    }\n",
       "\n",
       "    [theme=dark] .colab-df-convert {\n",
       "      background-color: #3B4455;\n",
       "      fill: #D2E3FC;\n",
       "    }\n",
       "\n",
       "    [theme=dark] .colab-df-convert:hover {\n",
       "      background-color: #434B5C;\n",
       "      box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
       "      filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
       "      fill: #FFFFFF;\n",
       "    }\n",
       "  </style>\n",
       "\n",
       "    <script>\n",
       "      const buttonEl =\n",
       "        document.querySelector('#df-f873d53f-dcb7-4114-9f5a-bb4cbae59983 button.colab-df-convert');\n",
       "      buttonEl.style.display =\n",
       "        google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
       "\n",
       "      async function convertToInteractive(key) {\n",
       "        const element = document.querySelector('#df-f873d53f-dcb7-4114-9f5a-bb4cbae59983');\n",
       "        const dataTable =\n",
       "          await google.colab.kernel.invokeFunction('convertToInteractive',\n",
       "                                                    [key], {});\n",
       "        if (!dataTable) return;\n",
       "\n",
       "        const docLinkHtml = 'Like what you see? Visit the ' +\n",
       "          '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
       "          + ' to learn more about interactive tables.';\n",
       "        element.innerHTML = '';\n",
       "        dataTable['output_type'] = 'display_data';\n",
       "        await google.colab.output.renderOutput(dataTable, element);\n",
       "        const docLink = document.createElement('div');\n",
       "        docLink.innerHTML = docLinkHtml;\n",
       "        element.appendChild(docLink);\n",
       "      }\n",
       "    </script>\n",
       "  </div>\n",
       "\n",
       "\n",
       "<div id=\"df-367e0c3e-50ca-4031-b71d-979879e012c8\">\n",
       "  <button class=\"colab-df-quickchart\" onclick=\"quickchart('df-367e0c3e-50ca-4031-b71d-979879e012c8')\"\n",
       "            title=\"Suggest charts\"\n",
       "            style=\"display:none;\">\n",
       "\n",
       "<svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\" viewBox=\"0 0 24 24\"\n",
       "     width=\"24px\">\n",
       "    <g>\n",
       "        <path d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/>\n",
       "    </g>\n",
       "</svg>\n",
       "  </button>\n",
       "\n",
       "<style>\n",
       "  .colab-df-quickchart {\n",
       "      --bg-color: #E8F0FE;\n",
       "      --fill-color: #1967D2;\n",
       "      --hover-bg-color: #E2EBFA;\n",
       "      --hover-fill-color: #174EA6;\n",
       "      --disabled-fill-color: #AAA;\n",
       "      --disabled-bg-color: #DDD;\n",
       "  }\n",
       "\n",
       "  [theme=dark] .colab-df-quickchart {\n",
       "      --bg-color: #3B4455;\n",
       "      --fill-color: #D2E3FC;\n",
       "      --hover-bg-color: #434B5C;\n",
       "      --hover-fill-color: #FFFFFF;\n",
       "      --disabled-bg-color: #3B4455;\n",
       "      --disabled-fill-color: #666;\n",
       "  }\n",
       "\n",
       "  .colab-df-quickchart {\n",
       "    background-color: var(--bg-color);\n",
       "    border: none;\n",
       "    border-radius: 50%;\n",
       "    cursor: pointer;\n",
       "    display: none;\n",
       "    fill: var(--fill-color);\n",
       "    height: 32px;\n",
       "    padding: 0;\n",
       "    width: 32px;\n",
       "  }\n",
       "\n",
       "  .colab-df-quickchart:hover {\n",
       "    background-color: var(--hover-bg-color);\n",
       "    box-shadow: 0 1px 2px rgba(60, 64, 67, 0.3), 0 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
       "    fill: var(--hover-fill-color);\n",
       "  }\n",
       "\n",
       "  .colab-df-quickchart-complete:disabled,\n",
       "  .colab-df-quickchart-complete:disabled:hover {\n",
       "    background-color: var(--disabled-bg-color);\n",
       "    fill: var(--disabled-fill-color);\n",
       "    box-shadow: none;\n",
       "  }\n",
       "\n",
       "  .colab-df-spinner {\n",
       "    border: 2px solid var(--fill-color);\n",
       "    border-color: transparent;\n",
       "    border-bottom-color: var(--fill-color);\n",
       "    animation:\n",
       "      spin 1s steps(1) infinite;\n",
       "  }\n",
       "\n",
       "  @keyframes spin {\n",
       "    0% {\n",
       "      border-color: transparent;\n",
       "      border-bottom-color: var(--fill-color);\n",
       "      border-left-color: var(--fill-color);\n",
       "    }\n",
       "    20% {\n",
       "      border-color: transparent;\n",
       "      border-left-color: var(--fill-color);\n",
       "      border-top-color: var(--fill-color);\n",
       "    }\n",
       "    30% {\n",
       "      border-color: transparent;\n",
       "      border-left-color: var(--fill-color);\n",
       "      border-top-color: var(--fill-color);\n",
       "      border-right-color: var(--fill-color);\n",
       "    }\n",
       "    40% {\n",
       "      border-color: transparent;\n",
       "      border-right-color: var(--fill-color);\n",
       "      border-top-color: var(--fill-color);\n",
       "    }\n",
       "    60% {\n",
       "      border-color: transparent;\n",
       "      border-right-color: var(--fill-color);\n",
       "    }\n",
       "    80% {\n",
       "      border-color: transparent;\n",
       "      border-right-color: var(--fill-color);\n",
       "      border-bottom-color: var(--fill-color);\n",
       "    }\n",
       "    90% {\n",
       "      border-color: transparent;\n",
       "      border-bottom-color: var(--fill-color);\n",
       "    }\n",
       "  }\n",
       "</style>\n",
       "\n",
       "  <script>\n",
       "    async function quickchart(key) {\n",
       "      const quickchartButtonEl =\n",
       "        document.querySelector('#' + key + ' button');\n",
       "      quickchartButtonEl.disabled = true;  // To prevent multiple clicks.\n",
       "      quickchartButtonEl.classList.add('colab-df-spinner');\n",
       "      try {\n",
       "        const charts = await google.colab.kernel.invokeFunction(\n",
       "            'suggestCharts', [key], {});\n",
       "      } catch (error) {\n",
       "        console.error('Error during call to suggestCharts:', error);\n",
       "      }\n",
       "      quickchartButtonEl.classList.remove('colab-df-spinner');\n",
       "      quickchartButtonEl.classList.add('colab-df-quickchart-complete');\n",
       "    }\n",
       "    (() => {\n",
       "      let quickchartButtonEl =\n",
       "        document.querySelector('#df-367e0c3e-50ca-4031-b71d-979879e012c8 button');\n",
       "      quickchartButtonEl.style.display =\n",
       "        google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
       "    })();\n",
       "  </script>\n",
       "</div>\n",
       "    </div>\n",
       "  </div>\n"
      ],
      "text/plain": [
       "       Metrics  MultinomialNB       SVC  RandomForest  LogisticRegression  \\\n",
       "0     Accuracy       0.331297  0.682322      0.412122            0.775756   \n",
       "1  Weighted-F1       0.524405  0.708481      0.344290            0.662284   \n",
       "2     Micro-F1       0.596000  0.700000      0.344000            0.668000   \n",
       "3     Macro-F1       0.337436  0.671945      0.283832            0.590027   \n",
       "\n",
       "        MLP  All-miniLM  DistilBERT  DistilRoBERTa    Bi-GRU  \n",
       "0  0.676682    0.647554    0.778432       0.782053  0.667471  \n",
       "1  0.752288    0.682059    0.794814       0.801566  0.711383  \n",
       "2  0.756000    0.703000    0.808000       0.811000  0.708000  \n",
       "3  0.646615    0.552629    0.706470       0.707506  0.555357  "
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
    "source": [
     "# Show the final comparison table: Accuracy / Weighted-F1 / Micro-F1 / Macro-F1 per model\n",
     "results_df"
    ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "authorship_tag": "ABX9TyMpYQZwCb4kuFr2uvupvA1F",
   "gpuType": "T4",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}