diff --git a/README.md b/README.md index 6405f189..7d27448e 100644 --- a/README.md +++ b/README.md @@ -200,7 +200,7 @@ Explore the explanations of your trained model using the DIANNA dashboard (for n _Dianna dashboard screenshot here_ ## Datasets @@ -260,9 +260,10 @@ And here are links to notebooks showing how we created our models on the benchma ### Text -| Models | Generation | -| :---------------------------------------------------- | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| [Movie reviews model](https://zenodo.org/record/5910598) | [Stanford sentiment treebank model generation](https://github.com/dianna-ai/dianna-exploration/blob/main/example_data/model_generation/movie_reviews/generate_model.ipynb) | +| Models | Generation | +|:---------------------------------------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| [Movie reviews model](https://zenodo.org/record/5910598) | [Stanford sentiment treebank model generation](https://github.com/dianna-ai/dianna-exploration/blob/main/example_data/model_generation/movie_reviews/generate_model.ipynb) | +| [Regulatory statement classifier](https://zenodo.org/record/8200001) | [EU-law regulatory-statement-classification](https://github.com/nature-of-eu-rules/regulatory-statement-classification) | ### Time series @@ -306,13 +307,13 @@ Also [GradCAM](https://openaccess.thecvf.com/content_ICCV_2017/papers/Selvaraju_ Our goal is that the scientific community embraces XAI as a source for novel and unexplored perspectives on scientific problems. 
Here, we offer [tutorials](./tutorials) on specific scientific use-cases of using XAI: -| Use-case (data) \ XAI | [RISE](http://bmvc2018.org/contents/papers/1064.pdf) | [LIME](https://www.kdd.org/kdd2016/papers/files/rfp0573-ribeiroA.pdf) | [KernelSHAP](https://proceedings.neurips.cc/paper/2017/file/8a20a8621978632d76c43dfd28b67767-Paper.pdf) | -| :----------------------------------------------------------------| :----------------------------------------------------| :---------------------------------------------------------------------| :-------------------------------------------------------------------------------------------------------| -| Biology (Phytomorphology): Tree Leaves classification (images) | | ✅ | | -| Astronomy: Fast Radio Burst detection (timeseries) | ✅ | | | -| Land-atmosphere modeling: Latent heat flux prediction (tabular) | | | ✅ | -| Social sciences (text) | work in progress | ... |... | -| Climate | planned | ... | ... | +| Use-case (data) \ XAI | [RISE](http://bmvc2018.org/contents/papers/1064.pdf) | [LIME](https://www.kdd.org/kdd2016/papers/files/rfp0573-ribeiroA.pdf) | [KernelSHAP](https://proceedings.neurips.cc/paper/2017/file/8a20a8621978632d76c43dfd28b67767-Paper.pdf) | +|:-------------------------------------------------------------------|:-----------------------------------------------------| :---------------------------------------------------------------------| :-------------------------------------------------------------------------------------------------------| +| Biology (Phytomorphology): Tree Leaves classification (images) | | ✅ | | +| Astronomy: Fast Radio Burst detection (timeseries) | ✅ | | | +| Land-atmosphere modeling: Latent heat flux prediction (tabular) | | | ✅ | +| Social sciences: EU-law regulatory statement classification (text) | | ✅ | | +| Climate | planned | ... | ... 
| ## Reference documentation diff --git a/dianna/utils/downloader.py b/dianna/utils/downloader.py index e4a9cef6..ab1a02d3 100644 --- a/dianna/utils/downloader.py +++ b/dianna/utils/downloader.py @@ -59,6 +59,10 @@ "doi:10.5281/zenodo.10656613/apertif_frb_dynamic_spectrum_model.onnx", "sha256:3c87db3c6257d7f251a7bdceb3197d5bb482b8edc19870219fb7ca7c204dd257" ], + "inlegal_bert_xgboost_classifier.json": [ + "doi:10.5281/zenodo.8200001/inlegal_bert_xgboost_classifier.json", + "68a672f29aac4a19c404c24f4c5da82a1ce7f704ccce701b0a1073c63730e127" + ], "stemmus_scope_emulator_model_LEtot.onnx": [ "doi:10.5281/zenodo.12623256/stemmus_scope_emulator_model_LEtot.onnx", "sha256:8c8f34ad5a2c519b1f3c67a4eb0c645c96cac1de166277bfb24e7887c2ce83be" diff --git a/docs/tutorials/9-lime_text_eulaw.nblink b/docs/tutorials/9-lime_text_eulaw.nblink new file mode 100644 index 00000000..36a5f475 --- /dev/null +++ b/docs/tutorials/9-lime_text_eulaw.nblink @@ -0,0 +1,3 @@ +{ + "path": "../../tutorials/explainers/LIME/lime_text_eulaw.ipynb" +} diff --git a/setup.cfg b/setup.cfg index c6f808b4..5c30f7c6 100644 --- a/setup.cfg +++ b/setup.cfg @@ -112,6 +112,9 @@ notebooks = torch torchvision ipywidgets + freetype-py + transformers + xgboost [options.entry_points] console_scripts = diff --git a/tutorials/README.md b/tutorials/README.md index 16fd2ac5..9890ce6c 100644 --- a/tutorials/README.md +++ b/tutorials/README.md @@ -18,23 +18,23 @@ pip install .[notebooks] #### Illustrative (Simple) -|*Data modality*|Dataset|*Task*|Logo| -|:------------|:------|:---|:----| -|*Images*|Binary MNIST | Binary digit *classification*| mnist_zero_and_one_half_size| -||[Simple Geometric (circles and triangles)](https://doi.org/10.5281/zenodo.5012824)| Binary shape *classificaiton* |SimpleGeometric Logo| -||[Imagenet](https://image-net.org/download.php) |$1000$ classes natural images *classificaiton* | ImageNet_autocrop| -|*Text*| [Stanford sentiment treebank](https://nlp.stanford.edu/sentiment/index.html) |Positive 
or negative movie reviews sentiment *classificaiton* | nlp-logo_half_size| -|*Timeseries* | [Coffee dataset](https://www.timeseriesclassification.com/description.php?Dataset=Coffee) | Binary *classificaiton* of Robusta and Aribica coffee beans| Coffe Logo| -| | [Weather dataset](https://zenodo.org/record/7525955) |Binary *classification* (warm/cold season) of temperature time-series |Weather Logo| -|*Tabular*| [Penguin dataset](https://www.kaggle.com/code/parulpandey/penguin-dataset-the-new-iris)| $3$ penguin spicies (Adele, Chinstrap, Gentoo) *classificaiton* | Penguin Logo | | -| | [Weather dataset](https://zenodo.org/record/7525955) | Next day sunshine hours prediction (*regression*) | Weather Logo| +|*Data modality*|Dataset| *Task* |Logo| +|:------------|:------|:----------------------------------------------------------------------|:----| +|*Images*|Binary MNIST | Binary digit *classification* | mnist_zero_and_one_half_size| +||[Simple Geometric (circles and triangles)](https://doi.org/10.5281/zenodo.5012824)| Binary shape *classificaiton* |SimpleGeometric Logo| +||[Imagenet](https://image-net.org/download.php) | $1000$ classes natural images *classificaiton* | ImageNet_autocrop| +|*Text*| [Stanford sentiment treebank](https://nlp.stanford.edu/sentiment/index.html) | Positive or negative movie reviews sentiment *classification* | nlp-logo_half_size| +|*Timeseries* | [Coffee dataset](https://www.timeseriesclassification.com/description.php?Dataset=Coffee) | Binary *classificaiton* of Robusta and Aribica coffee beans | Coffe Logo| +| | [Weather dataset](https://zenodo.org/record/7525955) | Binary *classification* (warm/cold season) of temperature time-series |Weather Logo| +|*Tabular*| [Penguin dataset](https://www.kaggle.com/code/parulpandey/penguin-dataset-the-new-iris)| $3$ penguin spicies (Adele, Chinstrap, Gentoo) *classificaiton* | Penguin Logo | | +| | [Weather dataset](https://zenodo.org/record/7525955) | Next day sunshine hours prediction (*regression*) 
| Weather Logo| #### Scientific use-cases |*Data modality*|Dataset|*Task*|Logo| |:------------|:------|:---|:----| |*Images*|[Simple Scientific (LeafSnap30)](https://zenodo.org/record/5061353/)| $30$ tree species leaves *classification* | LeafSnap30 Logo | -|*Text*| | | | +|*Text*| [EU-law statements](https://zenodo.org/records/8200001) | Regulatory or non-regulatory *classification* | nlp-logo_half_size| |*Timeseries* | Fast Radio Burst (FRB) dataset (not publicly available) | Binary *classificaiton* of Fast Radio Burst (FRB) timeseries data : noise or a real FRB. | FRB logo| |*Tabular*| [Land atmosphere dataset](https://zenodo.org/records/12623257)| Prediction of "latent heat flux" (*regression*). The random forest model is used as an [emulator](https://github.com/EcoExtreML/Emulator) to replace the physical model [STEMMUS_SCOPE](https://github.com/EcoExtreML/STEMMUS_SCOPE) to predict global maps of latent heat flux. | Atmosphere Logo | @@ -59,12 +59,13 @@ To learn more about how we aproach the masking for time-series data, please read #### Scientific use-cases -|*Modality* \ Method|RISE|[LIME](https://youtu.be/d6j6bofhj2M)|Kernel[SHAP](https://youtu.be/9haIOplEIGM)| -|:-----|:---|:---|:---| -|*Images*| | [LeafSnap30 Logo](./explainers/LIME/lime_images.ipynb) or [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dianna-ai/dianna/blob/main/tutorials/explainers/LIME/lime_images.ipynb) || -|*Text* | | | | -| *Time series*| [FRB logo](./explainers/RISE/rise_timeseries_frb.ipynb) or [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dianna-ai/dianna/blob/main/tutorials/explainers/RISE/rise_timeseries_frb.ipynb) | | -| *Tabular* | | |[Atmosphere Logo](./explainers/KernelSHAP/kernelshap_tabular_land_atmosphere.ipynb) or [![Open In 
Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dianna-ai/dianna/blob/main/tutorials/explainers/KernelSHAP/kernelshap_tabular_land_atmosphere.ipynb)| +| *Modality* \ Method |RISE| [LIME](https://youtu.be/d6j6bofhj2M) |Kernel[SHAP](https://youtu.be/9haIOplEIGM)| +|:--------------------|:---|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:---| +| *Images* | | [LeafSnap30 Logo](./explainers/LIME/lime_images.ipynb) or [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dianna-ai/dianna/blob/main/tutorials/explainers/LIME/lime_images.ipynb) || +| | | [FRB logo](./explainers/RISE/rise_timeseries_frb.ipynb) or [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dianna-ai/dianna/blob/main/tutorials/explainers/RISE/rise_timeseries_frb.ipynb) | | +| *Text* | | [nlp-logo_half_size](./explainers/LIME/lime_text_eulaw.ipynb) or [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dianna-ai/dianna/blob/main/tutorials/explainers/LIME/lime_text_eulaw.ipynb) | | +| *Time series* | [FRB logo](./explainers/RISE/rise_timeseries_frb.ipynb) or [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dianna-ai/dianna/blob/main/tutorials/explainers/RISE/rise_timeseries_frb.ipynb) | | +| *Tabular* | | |[Atmosphere Logo](./explainers/KernelSHAP/kernelshap_tabular_land_atmosphere.ipynb) or [![Open In 
Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dianna-ai/dianna/blob/main/tutorials/explainers/KernelSHAP/kernelshap_tabular_land_atmosphere.ipynb)| ### IMPORTANT: Hyperparameters The XAI methods (explainers) are sensitive to the choice of their hyperparameters! In this [master Thesis](https://staff.fnwi.uva.nl/a.s.z.belloum/MSctheses/MScthesis_Willem_van_der_Spec.pdf), this sensitivity is researched and useful conclusions are drawn. @@ -85,13 +86,13 @@ Also the main conclusions (🠊) from the thesis (on images and text) about the 🠊 Larger $n_masks$ will return more consistent results at the cost of computation time. If 2 identical runs yield (very) different results, these will likely contain a lot of (or even mostly) noise and a higher value for $n_masks$ should be used instead. #### LIME -| Hyperparameter | Default value | LeafSnap30 Logo (*i*) |Weather Logo (*ts*)| Coffe Logo(*ts*)| -| ------------- | ------------- |--------| -----| -----| -| $n_{samples}$ | **$5000$** | $1000$ | $10 000$| $500$| -| *Kernel Width* | **$25$**| default | default| default| -| $n_{features}$ | **$10$** | $30$ | default| default| +| Hyperparameter | Default value | LeafSnap30 Logo (*i*) |Weather Logo (*ts*)| Coffe Logo(*ts*)| [nlp-logo_half_size](./explainers/LIME/lime_text_eulaw.ipynb) | +| ------------- | ------------- |--------| -----| -----|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| $n_{samples}$ | **$5000$** | $1000$ | $10 000$| $500$| 2000 | +| *Kernel Width* | **$25$**| default | default| default| default | +| $n_{features}$ | **$10$** | $30$ | default| default| 999 | -🠊 The most crucial parameter is the *Kernel width*: low values cause high sensitivity, however that observaiton was dependant on the evaluaiton metric. 
+🠊 The most crucial parameter is the *Kernel width*: low values cause high sensitivity, however that observation was dependent on the evaluation metric. #### KernelSHAP | Hyperparameter | Default value | mnist_zero_and_one_half_size (*i*)| SimpleGeometric Logo (*i*) | Atmosphere Logo (*tab*) | diff --git a/tutorials/explainers/LIME/lime_text_eulaw.ipynb b/tutorials/explainers/LIME/lime_text_eulaw.ipynb new file mode 100644 index 00000000..a4c1780c --- /dev/null +++ b/tutorials/explainers/LIME/lime_text_eulaw.ipynb @@ -0,0 +1,1048 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "african-verse", + "metadata": {}, + "source": [ + "\"Logo_ER10\"\n", + "\n", + "### Interpreting the sentence classification of legal texts using LIME\n", + "\n", + "LIME (Local Interpretable Model-agnostic Explanations) is an explainable-AI method that aims to create an interpretable model that locally represents the classifier. For more details see the [LIME paper](https://arxiv.org/abs/1602.04938).\n", + "\n", + "This notebook demonstrates how to use the LIME explainable-AI method in [DIANNA](https://github.com/dianna-ai/dianna) to explain a text classification model created as part of the [Nature of EU Rules project](https://research-software-directory.org/projects/the-nature-of-eu-rules-strict-and-detailed-or-lacking-bite). The model is used to perform binary classification of individual sentences from EU legislation to determine whether they specify a regulation or not (i.e., whether they specify a legal obligation or prohibition that some legal entity should comply with). 
[Here's an example](https://eur-lex.europa.eu/legal-content/EN/TXT/HTML/?uri=CELEX:32012R1215&qid=1724343987254) of what an EU legislative document looks like.\n", + "\n", + "##### Regulatory sentence example:\n", + "\n", + "```Citizens of all Member States shall separate their recyclables before disposing of trash, or else face a fine.```\n", + "\n", + "##### Non-regulatory (constitutive) sentence example:\n", + "\n", + "```This Regulation shall apply in civil and commercial matters whatever the nature of the court or tribunal.```\n", + "\n", + "**Note:** while the occurrence of words like ``shall`` and ``must`` (which are called **deontic** words) is a necessary condition for a sentence to be classified as regulatory, it is not a sufficient condition." + ] + }, + { + "cell_type": "markdown", + "id": "fa6d17b0", + "metadata": {}, + "source": [ + "#### Colab Setup" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "471630ff", + "metadata": { + "ExecuteTime": { + "end_time": "2024-07-10T11:13:59.624761Z", + "start_time": "2024-07-10T11:13:59.563650Z" + } + }, + "outputs": [], + "source": [ + "running_in_colab = 'google.colab' in str(get_ipython())\n", + "if running_in_colab:\n", + " # install dianna\n", + " !python3 -m pip install dianna[notebooks]" + ] + }, + { + "cell_type": "markdown", + "id": "a5cf6f82-c1c7-4814-ae0f-5a1c0b8578f6", + "metadata": {}, + "source": [ + "#### 0. 
Imports and paths" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "34b556d8-5337-44dc-8efe-14d1dff6f011", + "metadata": { + "ExecuteTime": { + "end_time": "2024-07-10T11:13:59.625762Z", + "start_time": "2024-07-10T11:13:59.568658Z" + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "C:\\Users\\ChristiaanMeijer\\anaconda3\\envs\\dianna3112\\Lib\\site-packages\\torchtext\\data\\__init__.py:4: UserWarning: \n", + "/!\\ IMPORTANT WARNING ABOUT TORCHTEXT STATUS /!\\ \n", + "Torchtext is deprecated and the last released version will be 0.18 (this one). You can silence this warning by calling the following at the beginnign of your scripts: `import torchtext; torchtext.disable_torchtext_deprecation_warning()`\n", + " warnings.warn(torchtext._TORCHTEXT_DEPRECATION_MSG)\n" + ] + } + ], + "source": [ + "from torch.utils.data import DataLoader\n", + "from typing import Iterable\n", + "from tqdm import tqdm\n", + "\n", + "import numpy as np\n", + "import pandas as pd\n", + "import torch\n", + "import xgboost\n", + "\n", + "import dianna\n", + "from dianna import visualization\n", + "from dianna.utils.downloader import download\n", + "from dianna.utils.tokenizers import SpacyTokenizer" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "c616916c-78ef-48d0-a744-b25b37b62a3f", + "metadata": { + "ExecuteTime": { + "end_time": "2024-07-10T11:13:59.666220Z", + "start_time": "2024-07-10T11:13:59.576245Z" + } + }, + "outputs": [], + "source": [ + "class_names = ['constitutive', 'regulatory']\n", + "model_path = download('inlegal_bert_xgboost_classifier.json', 'model')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "# 1 - Define test data" + ], + "id": "156805d0e3de333b" + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "ExecuteTime": { + "end_time": "2024-07-10T11:13:59.666220Z", + "start_time": "2024-07-10T11:13:59.590692Z" + }, + 
"collapsed": false + }, + "outputs": [], + "source": [ + "constitutive_statement_0 = \"The purchase, import or transport from Syria of crude oil and petroleum products shall be prohibited.\"\n", + "constitutive_statement_1 = \"This Decision shall enter into force on the twentieth day following that of its publication in the Official Journal of the European Union.\"\n", + "regulatory_statement_0 = \"Where observations are submitted, or where substantial new evidence is presented, the Council shall review its decision and inform the person or entity concerned accordingly.\"\n", + "regulatory_statement_1 = \"The relevant Member State shall inform the other Member States of any authorisation granted under this Article.\"\n", + "regulatory_statement_2 = \"Member States shall cooperate, in accordance with their national legislation, with inspections and disposals undertaken pursuant to paragraphs 1 and 2.\"" + ], + "id": "200b1ed56f4ddf5f" + }, + { + "cell_type": "markdown", + "id": "bad4f5b1-6097-4ef3-98c4-78432ad640b0", + "metadata": { + "ExecuteTime": { + "end_time": "2024-03-21T10:28:34.466985Z", + "start_time": "2024-03-21T10:28:34.456937Z" + } + }, + "source": [ + "# 2 - Load and prepare the model\n", + "\n", + "The model is a combination of a pretrained transformer used as a feature extractor, with an XGBoost model trained on top. The following cells load the model into the variable called `model_runner`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "ExecuteTime": { + "end_time": "2024-07-10T11:13:59.666763Z", + "start_time": "2024-07-10T11:13:59.598468Z" + }, + "collapsed": false + }, + "outputs": [], + "source": [ + "from transformers import AutoTokenizer, AutoModel\n", + "def create_features(texts: list[str], model_tag=\"law-ai/InLegalBERT\") -> torch.Tensor:\n", + " \"\"\"Create features for a list of texts.\"\"\"\n", + " max_length = 512\n", + " tokenizer = AutoTokenizer.from_pretrained(model_tag)\n", + " model = AutoModel.from_pretrained(model_tag)\n", + "\n", + " def process_batch(batch: Iterable[str]):\n", + " cropped_texts = [text[:max_length] for text in batch]\n", + " encoded_inputs = tokenizer(cropped_texts, padding='longest', truncation=True, max_length=max_length,\n", + " return_tensors=\"pt\")\n", + " with torch.no_grad():\n", + " outputs = model(**encoded_inputs)\n", + " last_hidden_states = outputs.last_hidden_state\n", + " sentence_features = last_hidden_states.mean(dim=1)\n", + " return sentence_features\n", + "\n", + " dataloader = DataLoader(texts, batch_size=1) # batch size of 1 was quickest for my development machine\n", + " features = [process_batch(batch) for batch in tqdm(dataloader, desc=f'Creating features')]\n", + " return np.array(torch.cat(features, dim=0))\n", + "\n" + ], + "id": "e1a16e955d860d89" + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "ExecuteTime": { + "end_time": "2024-07-10T11:13:59.672763Z", + "start_time": "2024-07-10T11:13:59.606453Z" + }, + "collapsed": false + }, + "outputs": [], + "source": [ + "models={}\n", + "def classify_texts(texts: list[str], model_path, return_proba: bool = False):\n", + " \"\"\"Classifies every text in a list of texts using the xgboost model stored in model_path.\n", + "\n", + " The xgboost model will be loaded and used to classify the texts. 
The texts however will first be processed by a\n", + " large language model which will do the feature extraction for every text. The classifications of the\n", + " xgboost model will be returned.\n", + " For training the xgboost model, see train_legalbert_xgboost.py.\n", + "\n", + " Parameters\n", + " ----------\n", + " texts\n", + " A list of strings of which each needs to be classified.\n", + " model_path\n", + " The path to a stored xgboost model\n", + " return_proba\n", + " return the probabilities of the model\n", + "\n", + " Returns\n", + " -------\n", + " List of classifications, one for every text in the list\n", + "\n", + " \"\"\"\n", + " features = create_features(texts)\n", + " if model_path not in models:\n", + " print(f'Loading model from {model_path}.')\n", + " model = xgboost.XGBClassifier()\n", + " model.load_model(model_path)\n", + " models[model_path] = model\n", + "\n", + " model = models[model_path]\n", + " if return_proba:\n", + " return model.predict_proba(features)\n", + " return model.predict(features)" + ], + "id": "9d8346a5f0e78d5d" + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "555842c5-3f82-4f63-93bb-696645d4b447", + "metadata": { + "ExecuteTime": { + "end_time": "2024-07-10T11:14:00.114564Z", + "start_time": "2024-07-10T11:13:59.617764Z" + } + }, + "outputs": [], + "source": [ + "class StatementClassifier:\n", + " def __init__(self):\n", + " self.tokenizer = SpacyTokenizer(name='en_core_web_sm')\n", + "\n", + " def __call__(self, sentences):\n", + " # ensure the input has a batch axis\n", + " if isinstance(sentences, str):\n", + " sentences = [sentences]\n", + "\n", + " probs = classify_texts(sentences, model_path, return_proba=True)\n", + "\n", + " return np.transpose([(probs[:, 0]), (1 - probs[:, 0])])\n", + "\n", + "model_runner = StatementClassifier()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "## Test the model" + ], + "id": "e6ae5f1011540179" + }, + { + 
"cell_type": "code", + "execution_count": 8, + "metadata": { + "ExecuteTime": { + "end_time": "2024-07-10T11:14:01.218604Z", + "start_time": "2024-07-10T11:14:00.117563Z" + }, + "collapsed": false + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Creating features: 100%|██████████| 5/5 [00:01<00:00, 3.63it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading model from C:\\Users\\ChristiaanMeijer\\AppData\\Local\\dianna\\dianna\\Cache\\inlegal_bert_xgboost_classifier.json.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/html": [ + "
[HTML rendering of the results table (columns: statement, prediction, actual; 5 rows); same content as the text/plain output that follows]
" + ], + "text/plain": [ + " statement prediction \\\n", + "0 The purchase, import or transport from Syria o... constitutive \n", + "1 This Decision shall enter into force on the tw... constitutive \n", + "2 Where observations are submitted, or where sub... regulatory \n", + "3 The relevant Member State shall inform the oth... regulatory \n", + "4 Member States shall cooperate, in accordance w... regulatory \n", + "\n", + " actual \n", + "0 constitutive \n", + "1 constitutive \n", + "2 regulatory \n", + "3 regulatory \n", + "4 regulatory " + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "statements = [constitutive_statement_0, constitutive_statement_1, regulatory_statement_0, regulatory_statement_1,\n", + " regulatory_statement_2]\n", + "actual_classes = [class_names[c] for c in [0,0,1,1,1]]\n", + "model_outputs = model_runner(statements)\n", + "predictioned_classes = [class_names[m] for m in np.argmax(model_outputs, axis=1)]\n", + "\n", + "pd.DataFrame({'statement': statements, 'prediction': predictioned_classes, 'actual': actual_classes})" + ], + "id": "d4a307a3437d15c" + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "# 3 - Explain the model" + ], + "id": "fe8e935b18285991" + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "## Set parameters for DIANNA" + ], + "id": "62e20ed610533a98" + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "ExecuteTime": { + "end_time": "2024-07-10T11:14:01.225117Z", + "start_time": "2024-07-10T11:14:01.221604Z" + }, + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "label_of_interest is regulatory\n" + ] + } + ], + "source": [ + "label_of_interest = 1\n", + "print('label_of_interest is', class_names[label_of_interest])\n", + "statement = regulatory_statement_1\n", + "num_samples = 2000\n", + "num_features = 
999 # top n number of words to include in the attribution map\n", + "\n", + "def run_dianna(input_text):\n", + " return dianna.explain_text(model_runner, input_text, model_runner.tokenizer,\n", + " 'LIME', labels=[label_of_interest], num_samples=num_samples, num_features=num_features, )[0]" + ], + "id": "ada613b3b918ee2f" + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "## Are the results stable with current parameters?\n", + "A crucial hyperparameter is `num_samples`, which is set above. Too few samples result in a noisy explanation; too many are computationally expensive. If repeated runs yield (very) different results, the number of samples is too low for the current setting (which includes the data, model, sentence length and other XAI parameters)." + ], + "id": "286e22d0c9d8f486" + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "ExecuteTime": { + "end_time": "2024-03-21T11:57:43.971896Z", + "start_time": "2024-03-21T11:57:43.921805Z" + }, + "collapsed": false + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Creating features: 100%|██████████| 2000/2000 [02:17<00:00, 14.55it/s]\n", + "Creating features: 100%|██████████| 2000/2000 [02:10<00:00, 15.32it/s]\n", + "Creating features: 100%|██████████| 2000/2000 [02:11<00:00, 15.16it/s]\n" + ] + }, + { + "data": { + "text/html": [ + "
[HTML rendering of the per-token summary statistics table (rows: count, mean, std, min, 25%, 50%, 75%, max over 3 runs; one column per token of the statement); same content as the text/plain output that follows]
" + ], + "text/plain": [ + " The relevant Member State shall inform the \\\n", + "count 3.000000 3.000000 3.000000 3.000000 3.000000 3.000000 3.000000 \n", + "mean 0.066543 0.040031 0.060046 0.105916 0.192736 0.145457 0.078587 \n", + "std 0.027477 0.007484 0.022590 0.000665 0.008612 0.014756 0.004525 \n", + "min 0.049605 0.034720 0.038425 0.105155 0.182833 0.128693 0.073430 \n", + "25% 0.050691 0.035751 0.048323 0.105681 0.189868 0.139946 0.076934 \n", + "50% 0.051778 0.036782 0.058221 0.106208 0.196902 0.151198 0.080438 \n", + "75% 0.075011 0.042686 0.070857 0.106296 0.197688 0.153839 0.081166 \n", + "max 0.098245 0.048591 0.083493 0.106385 0.198473 0.156479 0.081894 \n", + "\n", + " other Member States of any authorisation \\\n", + "count 3.000000 3.000000 3.000000 3.000000 3.000000 3.000000 \n", + "mean 0.064191 0.089121 0.142354 0.114781 0.042535 0.065503 \n", + "std 0.004821 0.011775 0.022780 0.011397 0.013324 0.010638 \n", + "min 0.060816 0.075847 0.124168 0.102572 0.028173 0.059008 \n", + "25% 0.061430 0.084527 0.129578 0.109601 0.036556 0.059365 \n", + "50% 0.062044 0.093207 0.134989 0.116630 0.044939 0.059722 \n", + "75% 0.065878 0.095758 0.151447 0.120885 0.049716 0.068751 \n", + "max 0.069713 0.098308 0.167905 0.125140 0.054494 0.077780 \n", + "\n", + " granted under this Article . 
\n", + "count 3.000000 3.000000 3.000000 3.000000 3.000000 \n", + "mean 0.039873 0.091866 0.080870 0.054341 0.111533 \n", + "std 0.009031 0.003086 0.010070 0.020705 0.005437 \n", + "min 0.029511 0.088517 0.069802 0.038332 0.108313 \n", + "25% 0.036776 0.090501 0.076558 0.042649 0.108395 \n", + "50% 0.044042 0.092485 0.083314 0.046967 0.108477 \n", + "75% 0.045054 0.093540 0.086404 0.062345 0.113144 \n", + "max 0.046066 0.094595 0.089493 0.077724 0.117810 " + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "explanation_relevances = [run_dianna(statement) for _ in range(3)]\n", + "sorted_relevances = [sorted(r, key=lambda t: t[1]) for r in explanation_relevances]\n", + "\n", + "pd.DataFrame([[r[2] for r in sr] for sr in sorted_relevances], columns=[r[0] for r in sorted_relevances[0]]).describe()" + ], + "id": "a9d3404a5edfe5c6" + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "The attributions seem quite stable with 2000 LIME samples: across the three runs, the standard deviation stays below 0.03 for every token. We can now run DIANNA knowing that the results will contain mostly signal and not just noise." + ], + "id": "b9ecfcd2400ffc26" + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "## Running the actual explainer\n", + "We run the explainer one more time (convenient for whoever skipped the last cell) and show the attributions. This time we run it on a single example statement from the test data. The output table displays an attribution score for each word in the sentence. Each score represents how important or relevant that particular word was for the model to assign the sentence to the specified class. The closer the attribution score is to 1, the more important the word is to the classification." 
+ ], + "id": "847554ae856aa9bc" + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "ExecuteTime": { + "end_time": "2024-07-10T10:39:53.877421Z", + "start_time": "2024-07-10T10:38:41.095511Z" + }, + "collapsed": false + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Creating features: 100%|██████████| 2000/2000 [02:28<00:00, 13.44it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "attributions for class regulatory\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "
    0              1   2
0   The            0   0.040727
1   relevant       1   0.063082
2   Member         2   0.065679
3   State          3   0.080698
4   shall          4   0.171471
5   inform         5   0.142825
6   the            6   0.073943
7   other          7   0.078633
8   Member         8   0.070353
9   States         9   0.132782
10  of             10  0.115537
11  any            11  0.059245
12  authorisation  12  0.057643
13  granted        13  0.057985
14  under          14  0.118352
15  this           15  0.085583
16  Article        16  0.062079
17  .              17  0.120736
\n", + "
" + ], + "text/plain": [ + " 0 1 2\n", + "0 The 0 0.040727\n", + "1 relevant 1 0.063082\n", + "2 Member 2 0.065679\n", + "3 State 3 0.080698\n", + "4 shall 4 0.171471\n", + "5 inform 5 0.142825\n", + "6 the 6 0.073943\n", + "7 other 7 0.078633\n", + "8 Member 8 0.070353\n", + "9 States 9 0.132782\n", + "10 of 10 0.115537\n", + "11 any 11 0.059245\n", + "12 authorisation 12 0.057643\n", + "13 granted 13 0.057985\n", + "14 under 14 0.118352\n", + "15 this 15 0.085583\n", + "16 Article 16 0.062079\n", + "17 . 17 0.120736" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "explanation_relevance = run_dianna(statement)\n", + "print('attributions for class', class_names[label_of_interest])\n", + "pd.DataFrame(explanation_relevance)" + ], + "id": "b5b2395041aa2577" + }, + { + "cell_type": "markdown", + "id": "7e177746-3654-4518-9c1c-b7047f922273", + "metadata": { + "ExecuteTime": { + "end_time": "2024-03-21T10:38:43.263330Z", + "start_time": "2024-03-21T10:38:41.425881Z" + } + }, + "source": [ + "# 4 - Visualization\n", + "DIANNA includes a visualization package that can highlight the relevance of each word in the text for a chosen class. The visualization is in HTML format.\n", + "Words in favour of the selected class are highlighted in red, while words against the selected class are highlighted in blue.\n", + "\n", + "Below we see a plot showing which words (colored with a higher intensity of red) contributed the most to the classification of the sentence ``The relevant Member State shall inform the other Member States of any authorisation granted under this Article.`` as ``regulatory``. \n", + "\n", + "We can see that the most important words are: ``shall``, ``inform``, ``States``, ``of`` and ``under``. ``shall`` is a deontic word, a necessary ingredient of regulatory sentences, so its high attribution makes sense. ``inform`` is an action word or verb. 
Action verbs like this are typical of regulatory sentences, which often have to state what a particular party should or should not do to comply with a regulation. It therefore also makes sense that ``inform`` has a high attribution. \n", + "\n", + "Another component that often occurs in EU regulatory sentences is a reference to the actual party or legal entity being regulated. In this case, the party being regulated is ``Member States`` (all EU member countries). However, ``Member`` receives a low attribution score relative to ``States`` (0.070353 versus 0.132782). It could be that the model prefers ``States of`` as an indicator of a party to be regulated, as in ``United States of America`` (``of`` has a similar attribution score to ``States``). Perhaps the model has learned that ``States of`` is usually a subphrase of an agent of some kind (being an agent is the precondition for being a party to be regulated)." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "0136005d-a22f-43a0-80da-4ec1f283f870", + "metadata": { + "ExecuteTime": { + "end_time": "2024-07-10T10:39:54.163294Z", + "start_time": "2024-07-10T10:39:53.872085Z" + } + }, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "_ = visualization.highlight_text(explanation_relevance, model_runner.tokenizer.tokenize(statement))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + }, + "execution": { + "timeout": 1800 + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}
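
The ranking of words discussed in the visualization cell can be reproduced directly from the attribution table. The following is a minimal sketch, not part of the notebook: the triples are copied from the table output for the `regulatory` class, and `top_words` is a hypothetical helper introduced here for illustration. It skips pure punctuation tokens, since the full stop scores 0.120736 and would otherwise rank fourth.

```python
# (token, position, score) triples copied from the attribution table
# printed by the notebook for the class "regulatory".
attributions = [
    ("The", 0, 0.040727), ("relevant", 1, 0.063082), ("Member", 2, 0.065679),
    ("State", 3, 0.080698), ("shall", 4, 0.171471), ("inform", 5, 0.142825),
    ("the", 6, 0.073943), ("other", 7, 0.078633), ("Member", 8, 0.070353),
    ("States", 9, 0.132782), ("of", 10, 0.115537), ("any", 11, 0.059245),
    ("authorisation", 12, 0.057643), ("granted", 13, 0.057985),
    ("under", 14, 0.118352), ("this", 15, 0.085583), ("Article", 16, 0.062079),
    (".", 17, 0.120736),
]


def top_words(attributions, k=5):
    """Hypothetical helper: return the k highest-scoring tokens.

    Tokens without any alphanumeric character (pure punctuation)
    are dropped before ranking.
    """
    words = [a for a in attributions if any(ch.isalnum() for ch in a[0])]
    return sorted(words, key=lambda a: a[2], reverse=True)[:k]


top = top_words(attributions)
for token, position, score in top:
    print(f"{position:>2}  {token:<8} {score:.6f}")
```

Ranking by score rather than reading the colors off the plot makes it easier to compare runs, for example to check whether the same five words stay on top when LIME is rerun.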