From de5437ae710c738a0481b13dc9d266dd558c43a4 Mon Sep 17 00:00:00 2001 From: jeswan <57466294+jeswan@users.noreply.github.com> Date: Tue, 4 May 2021 13:19:24 -0400 Subject: [PATCH] Merge easy_add_model feature branch (#1309) * Update to Transformers v4.3.3 (#1266) * use default return_dict in taskmodels and remove hidden state context manager in models. * return hidden states in output of model wrapper * Switch to task model/head factories instead of embedded if-else statements (#1268) * Use jiant transformers model wrapper instead of if-else. Use taskmodel and head factory instead of if-else. * switch to ModelArchitectures enum instead of strings * Refactor get_output_from_encoder() to be member of JiantTaskModel (#1283) * refactor getting output from encoder to be member function of jiant model * switch to explicit encode() in jiant transformers model * fix simple runscript test * update to tokenizer 0.10.1 * Add tests for flat_strip() (#1289) * add flat_strip test * add list to test cases flat_strip * mlm_weights(), feat_spec(), flat_strip() if-else refactors (#1288) * moves remaining if-else statments to jiant model or replaces with model agnostic method * switch from jiant_transformers_model to encoder * fix bug in flat_strip() * Move tokenization logic to central JiantModelTransformers method (#1290) * move model specific tokenization logic to JiantTransformerModels * implement abstract methods for JiantTransformerModels * fix tasks circular import (#1296) * Add DeBERTa (#1295) * Add DeBERTa with sanity test * fix tasks circular import * [WIP] add deberta tests * Revert "fix tasks circular import" This reverts commit f92464020133d09d7a7d8b416762ffc2f9067a64. * deberta tests passing with transformers 6472d8 * switch to deberta-v2 * fix get_mlm_weights_dict() for deberta-v2 * update to transformers 4.5.0 * mark deberta test_export as slow * Update test_tokenization_normalization.py * add guide to add a model * fix test_expor_model tests * minor pytest fixes (add num_labels for rte, overnight flag fix) * bugfix for simple api notebook * bugfix for #1310 * bugfix for #1306: simple api notebook path name * squad running * 2nd bugfix for #1310: not all tasks have num_labels property * simple api notebook back to roberta-base * run test matrix for more steps to compare to master * save last/best model test fix Co-authored-by: Jesse Swanson --- conftest.py | 6 +- .../notebooks/simple_api_fine_tuning.ipynb | 442 +++++++-------- guides/README.md | 3 + guides/models/adding_models.md | 96 ++++ jiant/proj/main/components/container_setup.py | 9 +- jiant/proj/main/export_model.py | 2 +- jiant/proj/main/modeling/heads.py | 155 +++++- jiant/proj/main/modeling/model_setup.py | 310 ++--------- jiant/proj/main/modeling/primary.py | 522 +++++++++++++++++- jiant/proj/main/modeling/taskmodels.py | 371 ++++++------- jiant/proj/main/tokenize_and_cache.py | 16 +- jiant/proj/simple/runscript.py | 1 - jiant/shared/model_resolution.py | 314 ++--------- jiant/tasks/__init__.py | 2 - jiant/tasks/core.py | 24 +- jiant/tasks/evaluate/core.py | 181 +++--- jiant/tasks/lib/ccg.py | 2 +- jiant/tasks/lib/ropes.py | 4 +- jiant/tasks/lib/rte.py | 4 + .../templates/hacky_tokenization_matching.py | 79 +-- jiant/tasks/lib/templates/squad_style/core.py | 34 +- .../tasks/lib/templates/squad_style/utils.py | 2 +- jiant/utils/python/datastructures.py | 41 ++ jiant/utils/tokenization_normalization.py | 89 +-- jiant/utils/tokenization_utils.py | 49 ++ jiant/utils/transformer_utils.py | 32 -- pyproject.toml | 5 + requirements-no-torch.txt | 6 +- setup.py | 8 +- tests/proj/main/test_export_model.py | 24 +- tests/proj/simple/test_runscript.py | 52 +- .../test_hacky_tokenization_matching.py | 24 + tests/tasks/lib/test_mlm_premasked.py | 10 +- tests/tasks/lib/test_mlm_pretokenized.py | 11 +- tests/tasks/lib/test_mnli.py | 11 +- tests/tasks/lib/test_spr1.py | 7 +- tests/tasks/lib/test_sst.py | 2 +- .../utils/test_tokenization_normalization.py | 76 +-- 38 files changed, 1606 insertions(+), 1420 deletions(-) create mode 100644 guides/models/adding_models.md create mode 100644 jiant/utils/tokenization_utils.py delete mode 100644 jiant/utils/transformer_utils.py create mode 100644 tests/tasks/lib/templates/test_hacky_tokenization_matching.py diff --git a/conftest.py b/conftest.py index 51b07d922..045aa3e00 100644 --- a/conftest.py +++ b/conftest.py @@ -8,13 +8,15 @@ def pytest_addoption(parser): parser.addoption("--runslow", action="store_true", default=False, help="run slow tests") parser.addoption("--rungpu", action="store_true", default=False, help="run gpu tests") - parser.addoption("--runovernight", action="store_true", - default=False, help="run overnight tests") + parser.addoption( + "--runovernight", action="store_true", default=False, help="run overnight tests" + ) def pytest_configure(config): config.addinivalue_line("markers", "slow: mark test as slow to run") config.addinivalue_line("markers", "gpu: mark test as gpu required to run") + config.addinivalue_line("markers", "overnight: mark test as gpu required to run") def pytest_collection_modifyitems(config, items): diff --git a/examples/notebooks/simple_api_fine_tuning.ipynb b/examples/notebooks/simple_api_fine_tuning.ipynb index 55aad6dff..f29fac33d 100644 --- a/examples/notebooks/simple_api_fine_tuning.ipynb +++ b/examples/notebooks/simple_api_fine_tuning.ipynb @@ -1,236 +1,210 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "simple_api_fine_tuning.ipynb", - "provenance": [], - "collapsed_sections": [], - "toc_visible": true - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "accelerator": "GPU" - }, - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "xT4_NpW-TotE", - "colab_type": "text" - }, - "source": [ - "# Welcome to `jiant`\n", - "This notebook contains an example of fine-tuning a `roberta-base` model on the MRPC task using the simple `jiant` API." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "JlM8H-WCoh9k", - "colab_type": "text" - }, - "source": [ - "# Install dependencies" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "rY-7weGtIEUX", - "colab_type": "code", - "colab": {} - }, - "source": [ - "%%capture\n", - "!git clone https://github.com/nyu-mll/jiant.git\n", - "%cd jiant && git checkout tags/2.0.0", - "\n", - "# This Colab notebook already has its CUDA-runtime compatible versions of torch and torchvision installed\n", - "!pip install -r jiant/requirements-no-torch.txt\n", - "# Install pyarrow for nlp (no longer necessary after nlp>0.3.0)\n", - "!pip install pyarrow==0.16.0" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "j78UDhA7UMzi", - "colab_type": "text" - }, - "source": [ - "# Imports" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "5hsmmr9eIJJt", - "colab_type": "code", - "colab": {} - }, - "source": [ - "# To be removed when jiant installed with pip\n", - "import sys\n", - "sys.path.insert(0, \"/content/jiant\")" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "UXYZHyyNGayI", - "colab_type": "code", - "colab": {} - }, - "source": [ - "import os\n", - "\n", - "import jiant.utils.python.io as py_io\n", - "import jiant.proj.simple.runscript as simple_run\n", - "import jiant.scripts.download_data.runscript as downloader" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "aow1wIIDUS4h", - "colab_type": "text" - }, - "source": [ - "# Define task and model" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "AEtgbtkRHDJE", - "colab_type": "code", - "colab": {} - }, - "source": [ - "# See https://github.com/nyu-mll/jiant/blob/master/guides/tasks/supported_tasks.md for supported tasks\n", - "TASK_NAME = \"mrpc\"\n", - "\n", - "# See https://huggingface.co/models for supported models\n", - "MODEL_TYPE = \"roberta-base\"" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "xDjg1tRNUk9r", - "colab_type": "text" - }, - "source": [ - "# Create directories for task data and experiment" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "Gxy_csM9UhhA", - "colab_type": "code", - "colab": {} - }, - "source": [ - "RUN_NAME = f\"simple_{TASK_NAME}_{MODEL_TYPE}\"\n", - "EXP_DIR = \"/content/exp\"\n", - "DATA_DIR = \"/content/exp/tasks\"\n", - "\n", - "os.makedirs(DATA_DIR, exist_ok=True)\n", - "os.makedirs(EXP_DIR, exist_ok=True)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "8tXDQ1P2Unfa", - "colab_type": "text" - }, - "source": [ - "#Download data (uses `nlp` or direct download depending on task)" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "_HCQG8fEU4CU", - "colab_type": "code", - "colab": {} - }, - "source": [ - "downloader.download_data([TASK_NAME], DATA_DIR)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ZsInlNWLU5ZU", - "colab_type": "text" - }, - "source": [ - "#Run simple `jiant` pipeline (train and evaluate on MRPC)" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "Po0N521IHAjj", - "colab_type": "code", - "colab": {} - }, - "source": [ - "args = simple_run.RunConfiguration(\n", - " run_name=RUN_NAME,\n", - " exp_dir=EXP_DIR,\n", - " data_dir=DATA_DIR,\n", - " model_type=MODEL_TYPE,\n", - " tasks=TASK_NAME,\n", - " train_batch_size=16,\n", - " num_train_epochs=1\n", - ")\n", - "simple_run.run_simple(args)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "2GLUP22PoowE", - "colab_type": "text" - }, - "source": [ - "The simple API `RunConfiguration` object is saved as `simple_run_config.json`. `simple_run_config.json` can be loaded and used as inputs to repeat experiments as follows." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "ckhYG6Ijh1nC", - "colab_type": "code", - "colab": {} - }, - "source": [ - "args = simple_run.RunConfiguration.from_json_path(os.path.join(EXP_DIR, \"runs\", RUN_NAME, \"simple_run_config.json\"))\n", - "simple_run.run_simple(args)" - ], - "execution_count": null, - "outputs": [] - } - ] + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "simple_api_fine_tuning.ipynb", + "provenance": [], + "collapsed_sections": [], + "toc_visible": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "xT4_NpW-TotE" + }, + "source": [ + "# Welcome to `jiant`\n", + "This notebook contains an example of fine-tuning a `roberta-base` model on the MRPC task using the simple `jiant` API." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "JlM8H-WCoh9k" + }, + "source": [ + "# Install dependencies" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "rY-7weGtIEUX" + }, + "source": [ + "%%capture\n", + "!git clone https://github.com/nyu-mll/jiant.git\n", + "# This Colab notebook already has its CUDA-runtime compatible versions of torch and torchvision installed\n", + "!pip install -r jiant/requirements-no-torch.txt" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "j78UDhA7UMzi" + }, + "source": [ + "# Imports" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "5hsmmr9eIJJt" + }, + "source": [ + "import sys\n", + "sys.path.insert(0, \"/content/jiant\")" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "UXYZHyyNGayI" + }, + "source": [ + "import os\n", + "\n", + "import jiant.utils.python.io as py_io\n", + "import jiant.proj.simple.runscript as simple_run\n", + "import jiant.scripts.download_data.runscript as downloader" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "aow1wIIDUS4h" + }, + "source": [ + "# Define task and model" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "AEtgbtkRHDJE" + }, + "source": [ + "# See https://github.com/nyu-mll/jiant/blob/master/guides/tasks/supported_tasks.md for supported tasks\n", + "TASK_NAME = \"mrpc\"\n", + "\n", + "# See https://huggingface.co/models for supported models\n", + "HF_PRETRAINED_MODEL_NAME = \"roberta-base\"" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "xDjg1tRNUk9r" + }, + "source": [ + "# Create directories for task data and experiment" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "Gxy_csM9UhhA" + }, + "source": [ + "# Remove forward slashes so RUN_NAME can be used as path\n", + "MODEL_NAME = HF_PRETRAINED_MODEL_NAME.split(\"/\")[-1]\n", + "RUN_NAME = f\"simple_{TASK_NAME}_{MODEL_NAME}\"\n", + "EXP_DIR = \"/content/exp\"\n", + "DATA_DIR = \"/content/exp/tasks\"\n", + "\n", + "os.makedirs(DATA_DIR, exist_ok=True)\n", + "os.makedirs(EXP_DIR, exist_ok=True)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8tXDQ1P2Unfa" + }, + "source": [ + "#Download data (uses `nlp` or direct download depending on task)" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "_HCQG8fEU4CU" + }, + "source": [ + "downloader.download_data([TASK_NAME], DATA_DIR)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ZsInlNWLU5ZU" + }, + "source": [ + "#Run simple `jiant` pipeline (train and evaluate on MRPC)" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "Po0N521IHAjj" + }, + "source": [ + "args = simple_run.RunConfiguration(\n", + " run_name=RUN_NAME,\n", + " exp_dir=EXP_DIR,\n", + " data_dir=DATA_DIR,\n", + " hf_pretrained_model_name_or_path=HF_PRETRAINED_MODEL_NAME,\n", + " tasks=TASK_NAME,\n", + " train_batch_size=16,\n", + " num_train_epochs=1\n", + ")\n", + "simple_run.run_simple(args)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2GLUP22PoowE" + }, + "source": [ + "The simple API `RunConfiguration` object is saved as `simple_run_config.json`. `simple_run_config.json` can be loaded and used as inputs to repeat experiments as follows." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "ckhYG6Ijh1nC", + "collapsed": true + }, + "source": [ + "args = simple_run.RunConfiguration.from_json_path(os.path.join(EXP_DIR, \"runs\", RUN_NAME, \"simple_run_config.json\"))\n", + "simple_run.run_simple(args)" + ], + "execution_count": null, + "outputs": [] + } + ] } \ No newline at end of file diff --git a/guides/README.md b/guides/README.md index 3d51476a2..17ab2244d 100644 --- a/guides/README.md +++ b/guides/README.md @@ -50,6 +50,9 @@ These are more specific guides about running experiments in `jiant`: * [My Experiment and Me](experiments/my_experiment_and_me.md): More info about a `jiant` training/eval run * [Tips for Large-scale Experiments](experiments/large_scale_experiments.md) +## Adding a Model +* [Guide for adding a model to `jiant`](models/adding_models.md) + ## Tasks These are notes on the tasks supported in `jiant`: diff --git a/guides/models/adding_models.md b/guides/models/adding_models.md new file mode 100644 index 000000000..dfda250c8 --- /dev/null +++ b/guides/models/adding_models.md @@ -0,0 +1,96 @@ + # Adding a model + +`jiant` supports or can easily be extended to support Hugging Face's [Transformer models](https://huggingface.co/transformers/viewer/) since `jiant` utilizes [Auto Classes](https://huggingface.co/transformers/model_doc/auto.html) to determine the architecture of the model used based on the name of the [pretrained model](https://huggingface.co/models). Although `jiant` uses AutoModels to reolve model classes, the `jiant` pipeline requires additional information (such as matching the correct tokenizer for the models). Furthermore, there are subtle differences in the models that `jiant` must abstract and additional steps are required to add a Hugging Face model to `jiant`. To add a model not currently supported in `jiant`, follow the following steps: + +## 1. Add to ModelArchitectures enum +Add the model to the ModelArchitectures enum in [`model_resolution.py`](../../jiant/tasks/model_resolution.py) as a member-string mapping. For example, adding the field DEBERTAV2 = "deberta-v2" would add Deberta V2 to the ModelArchitectures enum. + +## 2. Add to the TOKENIZER_CLASS_DICT +Add the model to the TOKENIZER_CLASS_DICT in [`model_resolution.py`](../../jiant/tasks/model_resolution.py). This dictionary maps the ModelArchitectures to Hugging Face tokenizer classes. + +## 3. Subclass JiantTransformersModel +Create a subclass of JiantTransformersModel in ['jiant/proj/main/modeling/primary.py'](../../jiant/proj/main/modeling/primary.py). The JiantTransformersModel is used to wrap Hugging Face Transformer models to abstract any inconsistencies in the model fields. JiantTransformersModel is an abstract class with several methods that must be implemented as well as several methods that can be optionally overridden. + + +```python +class JiantTransformersModel(metaclass=abc.ABCMeta): + def __init__(self, baseObject): + self.__class__ = type( + baseObject.__class__.__name__, (self.__class__, baseObject.__class__), {} + ) + self.__dict__ = baseObject.__dict__ + + @classmethod + @abc.abstractmethod + def normalize_tokenizations(cls, tokenizer, space_tokenization, target_tokenization): + """Abstract method to tag space_tokenization and process target_tokenization with + the relevant tokenization method for the model.""" + pass + + @abc.abstractmethod + def get_mlm_weights_dict(self, weights_dict): + """Abstract method to get the pre-trained masked-language modeling head weights + from the pretrained model from the weights_dict""" + pass + + @abc.abstractmethod + def get_feat_spec(self, weights_dict): + """Abstract method that should return a FeaturizationSpec specifying the + tokenization details used for the model""" + pass + + def get_hidden_size(self): + ... + + def get_hidden_dropout_prob(self): + ... + + def encode(self, input_ids, segment_ids, input_mask, output_hidden_states=True): + ... +``` + +`jiant` uses a dynamic registry for supported models. To add your model to the dynamic registry add a decorator to the new model class registering the new model class with the corresponding ModelArchitecture enum added in Step 1. + + +```python +@JiantTransformersModelFactory.register(ModelArchitectures.DEBERTAV2) +class JiantDebertaV2Model(JiantTransformersModel): + def __init__(self, baseObject): + super().__init__(baseObject) +``` + +## (Optional) 4. Add/Register additional task heads +Specific task heads may require model-specific implementation such as the MLM task heads. These model-specific task heads should be added and registered with their respective factory in (jiant/proj/main/modeling/heads.py)[../../jiant/proj/main/modeling/heads.py] if applicable. For example, MLM heads require a factory since their implementation differs across models: + +```python +@JiantMLMHeadFactory.register([ModelArchitectures.DEBERTAV2]) +class DebertaV2MLMHead(BaseMLMHead): + ... +```` + +## 5. Fine-tune the model +You should now be able to use the new model with the following simple fine-tuning example (Deberta-V2 used as an example below): + +```python +from jiant.proj.simple import runscript as run +import jiant.scripts.download_data.runscript as downloader + +EXP_DIR = "/path/to/exp" + +# Download the Data +downloader.download_data(["mrpc"], f"{EXP_DIR}/tasks") + +# Set up the arguments for the Simple API +args = run.RunConfiguration( + run_name="simple", + exp_dir=EXP_DIR, + data_dir=f"{EXP_DIR}/tasks", + hf_pretrained_model_name_or_path="microsoft/deberta-v2-xlarge", + tasks="mrpc", + train_batch_size=16, + num_train_epochs=3 +) + +# Run! +run.run_simple(args) +``` diff --git a/jiant/proj/main/components/container_setup.py b/jiant/proj/main/components/container_setup.py index 02d9c0eb8..54bb12396 100644 --- a/jiant/proj/main/components/container_setup.py +++ b/jiant/proj/main/components/container_setup.py @@ -4,7 +4,8 @@ import jiant.proj.main.components.task_sampler as jiant_task_sampler import jiant.shared.caching as caching -import jiant.tasks as tasks +from jiant.tasks.core import Task +from jiant.tasks.retrieval import create_task_from_config_path import jiant.utils.python.io as py_io from jiant.utils.python.datastructures import ExtendedDataClassMixin @@ -48,7 +49,7 @@ class TaskRunConfig(ExtendedDataClassMixin): @dataclass class JiantTaskContainer: - task_dict: Dict[str, tasks.Task] + task_dict: Dict[str, Task] task_sampler: jiant_task_sampler.BaseMultiTaskSampler task_cache_dict: Dict global_train_config: GlobalTrainConfig @@ -58,7 +59,7 @@ class JiantTaskContainer: metrics_aggregator: jiant_task_sampler.BaseMetricAggregator -def create_task_dict(task_config_dict: dict, verbose: bool = True) -> Dict[str, tasks.Task]: +def create_task_dict(task_config_dict: dict, verbose: bool = True) -> Dict[str, Task]: """Make map of task name to task instances from map of task name to task config file paths. Args: @@ -71,7 +72,7 @@ def create_task_dict(task_config_dict: dict, verbose: bool = True) -> Dict[str, """ task_dict = {} for task_name, task_config_path in task_config_dict.items(): - task = tasks.create_task_from_config_path(config_path=task_config_path, verbose=False) + task = create_task_from_config_path(config_path=task_config_path, verbose=False) if not task.name == task_name: warnings.warn( "task {} from {} has conflicting names: {}/{}. Using {}".format( diff --git a/jiant/proj/main/export_model.py b/jiant/proj/main/export_model.py index 1f65c0bb8..8369f7dfa 100644 --- a/jiant/proj/main/export_model.py +++ b/jiant/proj/main/export_model.py @@ -52,7 +52,7 @@ def export_model( torch.save(model.state_dict(), model_path) py_io.write_json(model.config.to_dict(), model_config_path) - tokenizer = AutoTokenizer.from_pretrained(hf_pretrained_model_name_or_path) + tokenizer = AutoTokenizer.from_pretrained(hf_pretrained_model_name_or_path, use_fast=False) tokenizer.save_pretrained(tokenizer_fol_path) config = { "hf_pretrained_model_name_or_path": hf_pretrained_model_name_or_path, diff --git a/jiant/proj/main/modeling/heads.py b/jiant/proj/main/modeling/heads.py index 38b637c99..75ef11143 100644 --- a/jiant/proj/main/modeling/heads.py +++ b/jiant/proj/main/modeling/heads.py @@ -1,10 +1,17 @@ +from __future__ import annotations + import abc import torch import torch.nn as nn import transformers + from jiant.ext.allennlp import SelfAttentiveSpanExtractor +from jiant.shared.model_resolution import ModelArchitectures +from jiant.tasks.core import TaskTypes +from typing import Callable +from typing import List """ In HuggingFace/others, these heads differ slightly across different encoder models. @@ -12,18 +19,75 @@ """ +class JiantHeadFactory: + """This factory is used to create task-specific heads for the supported Transformer encoders. + + Attributes: + registry (dict): Dynamic registry mapping task types to task heads + """ + + registry = {} + + @classmethod + def register(cls, task_type_list: List[TaskTypes]) -> Callable: + """Register each TaskType in task_type_list as a key mapping to a BaseHead task head + + Args: + task_type_list (List[TaskType]): List of TaskTypes that are associated to a + BaseHead task head + + Returns: + Callable: inner_wrapper() wrapping task head constructor or task head factory + """ + + def inner_wrapper(wrapped_class: BaseHead) -> Callable: + """Summary + + Args: + wrapped_class (BaseHead): Task head class + + Returns: + Callable: Task head constructor or factory + """ + for task_type in task_type_list: + assert task_type not in cls.registry + cls.registry[task_type] = wrapped_class + return wrapped_class + + return inner_wrapper + + def __call__(self, task, **kwargs) -> BaseHead: + """Summary + + Args: + task (Task): A task head will be created based on the task type + **kwargs: Arguments required for task head initialization + + Returns: + BaseHead: Initialized task head + """ + head_class = self.registry[task.TASK_TYPE] + head = head_class(task, **kwargs) + return head + + class BaseHead(nn.Module, metaclass=abc.ABCMeta): - pass + """Absract class for task heads""" + + @abc.abstractmethod + def __init__(self): + super().__init__() +@JiantHeadFactory.register([TaskTypes.CLASSIFICATION]) class ClassificationHead(BaseHead): - def __init__(self, hidden_size, hidden_dropout_prob, num_labels): + def __init__(self, task, hidden_size, hidden_dropout_prob, **kwargs): """From RobertaClassificationHead""" super().__init__() self.dense = nn.Linear(hidden_size, hidden_size) self.dropout = nn.Dropout(hidden_dropout_prob) - self.out_proj = nn.Linear(hidden_size, num_labels) - self.num_labels = num_labels + self.out_proj = nn.Linear(hidden_size, len(task.LABELS)) + self.num_labels = len(task.LABELS) def forward(self, pooled): x = self.dropout(pooled) @@ -34,8 +98,9 @@ def forward(self, pooled): return logits +@JiantHeadFactory.register([TaskTypes.REGRESSION, TaskTypes.MULTIPLE_CHOICE]) class RegressionHead(BaseHead): - def __init__(self, hidden_size, hidden_dropout_prob): + def __init__(self, task, hidden_size, hidden_dropout_prob, **kwargs): """From RobertaClassificationHead""" super().__init__() self.dense = nn.Linear(hidden_size, hidden_size) @@ -51,12 +116,13 @@ def forward(self, pooled): return scores +@JiantHeadFactory.register([TaskTypes.SPAN_COMPARISON_CLASSIFICATION]) class SpanComparisonHead(BaseHead): - def __init__(self, hidden_size, hidden_dropout_prob, num_spans, num_labels): + def __init__(self, task, hidden_size, hidden_dropout_prob, **kwargs): """From RobertaForSpanComparisonClassification""" super().__init__() - self.num_spans = num_spans - self.num_labels = num_labels + self.num_spans = task.num_spans + self.num_labels = len(task.LABELS) self.hidden_size = hidden_size self.dropout = nn.Dropout(hidden_dropout_prob) self.span_attention_extractor = SelfAttentiveSpanExtractor(hidden_size) @@ -70,13 +136,14 @@ def forward(self, unpooled, spans): return logits +@JiantHeadFactory.register([TaskTypes.TAGGING]) class TokenClassificationHead(BaseHead): - def __init__(self, hidden_size, num_labels, hidden_dropout_prob): + def __init__(self, task, hidden_size, hidden_dropout_prob, **kwargs): """From RobertaForTokenClassification""" super().__init__() - self.num_labels = num_labels + self.num_labels = len(task.LABELS) self.dropout = nn.Dropout(hidden_dropout_prob) - self.classifier = nn.Linear(hidden_size, num_labels) + self.classifier = nn.Linear(hidden_size, self.num_labels) def forward(self, unpooled): unpooled = self.dropout(unpooled) @@ -84,8 +151,9 @@ def forward(self, unpooled): return logits +@JiantHeadFactory.register([TaskTypes.SQUAD_STYLE_QA]) class QAHead(BaseHead): - def __init__(self, hidden_size): + def __init__(self, task, hidden_size, **kwargs): """From RobertaForQuestionAnswering""" super().__init__() self.qa_outputs = nn.Linear(hidden_size, 2) @@ -98,18 +166,65 @@ def forward(self, unpooled): return logits +@JiantHeadFactory.register([TaskTypes.MASKED_LANGUAGE_MODELING]) +class JiantMLMHeadFactory: + """This factory is used to create masked language modeling (MLM) task heads. + This is required due to Transformers implementing different MLM heads for + different encoders. + + Attributes: + registry (dict): Dynamic registry mapping model architectures to MLM task heads + """ + + registry = {} + + @classmethod + def register(cls, model_arch_list: List[ModelArchitectures]) -> Callable: + """Registers the ModelArchitectures in model_arch_list as keys mapping to a MLMHead + + Args: + model_arch_list (List[ModelArchitectures]): List of ModelArchitectures mapping to + an MLM task head. + + Returns: + Callable: MLMHead class + """ + + def inner_wrapper(wrapped_class: BaseMLMHead) -> Callable: + for model_arch in model_arch_list: + assert model_arch not in cls.registry + cls.registry[model_arch] = wrapped_class + return wrapped_class + + return inner_wrapper + + def __call__(self, task, **kwargs): + """Summary + + Args: + task (Task): Task used to initialize task head + **kwargs: Additional arguments required to initialize task head + """ + mlm_head_class = self.registry[task.TASK_TYPE] + mlm_head = mlm_head_class(task, **kwargs) + return mlm_head + + class BaseMLMHead(BaseHead, metaclass=abc.ABCMeta): pass +@JiantMLMHeadFactory.register([ModelArchitectures.BERT]) class BertMLMHead(BaseMLMHead): """From BertOnlyMLMHead, BertLMPredictionHead, BertPredictionHeadTransform""" def __init__(self, hidden_size, vocab_size, layer_norm_eps=1e-12, hidden_act="gelu"): super().__init__() self.dense = nn.Linear(hidden_size, hidden_size) - self.transform_act_fn = transformers.modeling_bert.ACT2FN[hidden_act] - self.LayerNorm = transformers.modeling_bert.BertLayerNorm(hidden_size, eps=layer_norm_eps) + self.transform_act_fn = transformers.models.bert.modeling_bert.ACT2FN[hidden_act] + self.LayerNorm = transformers.models.bert.modeling_bert.BertLayerNorm( + hidden_size, eps=layer_norm_eps + ) self.decoder = nn.Linear(hidden_size, vocab_size, bias=False) self.bias = nn.Parameter(torch.zeros(vocab_size), requires_grad=True) @@ -126,13 +241,16 @@ def forward(self, unpooled): return logits +@JiantMLMHeadFactory.register([ModelArchitectures.ROBERTA, ModelArchitectures.XLM_ROBERTA]) class RobertaMLMHead(BaseMLMHead): """From RobertaLMHead""" def __init__(self, hidden_size, vocab_size, layer_norm_eps=1e-12): super().__init__() self.dense = nn.Linear(hidden_size, hidden_size) - self.layer_norm = transformers.modeling_bert.BertLayerNorm(hidden_size, eps=layer_norm_eps) + self.layer_norm = transformers.models.bert.modeling_bert.BertLayerNorm( + hidden_size, eps=layer_norm_eps + ) self.decoder = nn.Linear(hidden_size, vocab_size, bias=False) self.bias = nn.Parameter(torch.zeros(vocab_size), requires_grad=True) @@ -143,7 +261,7 @@ def __init__(self, hidden_size, vocab_size, layer_norm_eps=1e-12): def forward(self, unpooled): x = self.dense(unpooled) - x = transformers.modeling_bert.gelu(x) + x = transformers.models.bert.modeling_bert.gelu(x) x = self.layer_norm(x) # project back to size of vocabulary with bias @@ -151,7 +269,8 @@ def forward(self, unpooled): return logits -class AlbertMLMHead(nn.Module): +@JiantMLMHeadFactory.register([ModelArchitectures.ALBERT]) +class AlbertMLMHead(BaseMLMHead): """From AlbertMLMHead""" def __init__(self, hidden_size, embedding_size, vocab_size, hidden_act="gelu"): @@ -161,7 +280,7 @@ def __init__(self, hidden_size, embedding_size, vocab_size, hidden_act="gelu"): self.bias = nn.Parameter(torch.zeros(vocab_size), requires_grad=True) self.dense = nn.Linear(hidden_size, embedding_size) self.decoder = nn.Linear(embedding_size, vocab_size) - self.activation = transformers.modeling_bert.ACT2FN[hidden_act] + self.activation = transformers.models.bert.modeling_bert.ACT2FN[hidden_act] # Need a link between the two variables so that the bias is correctly resized with # `resize_token_embeddings` diff --git a/jiant/proj/main/modeling/model_setup.py b/jiant/proj/main/modeling/model_setup.py index 6ff546e72..b2c1a6925 100644 --- a/jiant/proj/main/modeling/model_setup.py +++ b/jiant/proj/main/modeling/model_setup.py @@ -2,22 +2,21 @@ from typing import Any from typing import Dict from typing import List -from typing import Optional import torch import torch.nn as nn import transformers - +import warnings import jiant.proj.main.components.container_setup as container_setup -import jiant.proj.main.modeling.heads as heads import jiant.proj.main.modeling.primary as primary -import jiant.proj.main.modeling.taskmodels as taskmodels import jiant.utils.python.strings as strings +from jiant.proj.main.modeling.heads import JiantHeadFactory +from jiant.proj.main.modeling.taskmodels import JiantTaskModelFactory, Taskmodel, MLMModel + from jiant.shared.model_resolution import ModelArchitectures -from jiant.tasks import Task -from jiant.tasks import TaskTypes +from jiant.tasks.core import Task def setup_jiant_model( @@ -52,18 +51,14 @@ def setup_jiant_model( JiantModel nn.Module. """ - model = transformers.AutoModel.from_pretrained(hf_pretrained_model_name_or_path) - model_arch = ModelArchitectures.from_model_type(model.base_model_prefix) - transformers_class_spec = TRANSFORMERS_CLASS_SPEC_DICT[model_arch] - tokenizer = transformers.AutoTokenizer.from_pretrained(hf_pretrained_model_name_or_path) - ancestor_model = get_ancestor_model( - transformers_class_spec=transformers_class_spec, model_config_path=model_config_path, + hf_model = transformers.AutoModel.from_pretrained(hf_pretrained_model_name_or_path) + tokenizer = transformers.AutoTokenizer.from_pretrained( + hf_pretrained_model_name_or_path, use_fast=False ) - encoder = get_encoder(model_arch=model_arch, ancestor_model=ancestor_model) + encoder = primary.JiantTransformersModelFactory()(hf_model) taskmodels_dict = { taskmodel_name: create_taskmodel( task=task_dict[task_name_list[0]], # Take the first task - model_arch=model_arch, encoder=encoder, taskmodel_kwargs=taskmodels_config.get_taskmodel_kwargs(taskmodel_name), ) @@ -161,46 +156,32 @@ def load_encoder_from_transformers_weights( """ remainder_weights_dict = {} load_weights_dict = {} - model_arch = ModelArchitectures.from_encoder(encoder=encoder) - encoder_prefix = MODEL_PREFIX[model_arch] + "." + model_arch = ModelArchitectures.from_model_type(model_type=encoder.config.model_type) + encoder_prefix = model_arch.value + "." # Encoder for k, v in weights_dict.items(): if k.startswith(encoder_prefix): load_weights_dict[strings.remove_prefix(k, encoder_prefix)] = v + elif k.startswith(encoder_prefix.split("-")[0]): + # workaround for deberta-v2 + # remove "-v2" suffix. weight names are prefixed with "deberta" and not "deberta-v2" + load_weights_dict[strings.remove_prefix(k, encoder_prefix.split("-")[0] + ".")] = v else: remainder_weights_dict[k] = v - encoder.load_state_dict(load_weights_dict) + encoder.load_state_dict(load_weights_dict, strict=False) + if remainder_weights_dict: + warnings.warn( + "The following weights were not loaded: {}".format(remainder_weights_dict.keys()) + ) if return_remainder: return remainder_weights_dict def load_lm_heads_from_transformers_weights(jiant_model, weights_dict): - model_arch = get_model_arch_from_jiant_model(jiant_model=jiant_model) - if model_arch == ModelArchitectures.BERT: - mlm_weights_map = { - "bias": "cls.predictions.bias", - "dense.weight": "cls.predictions.transform.dense.weight", - "dense.bias": "cls.predictions.transform.dense.bias", - "LayerNorm.weight": "cls.predictions.transform.LayerNorm.weight", - "LayerNorm.bias": "cls.predictions.transform.LayerNorm.bias", - "decoder.weight": "cls.predictions.decoder.weight", - "decoder.bias": "cls.predictions.bias", # <-- linked directly to bias - } - mlm_weights_dict = {new_k: weights_dict[old_k] for new_k, old_k in mlm_weights_map.items()} - elif model_arch in (ModelArchitectures.ROBERTA, ModelArchitectures.XLM_ROBERTA): - mlm_weights_dict = { - strings.remove_prefix(k, "lm_head."): v for k, v in weights_dict.items() - } - mlm_weights_dict["decoder.bias"] = mlm_weights_dict["bias"] - elif model_arch == ModelArchitectures.ALBERT: - mlm_weights_dict = { - strings.remove_prefix(k, "predictions."): v for k, v in weights_dict.items() - } - else: - raise KeyError(model_arch) + mlm_weights_dict = jiant_model.encoder.get_mlm_weights_dict(weights_dict) missed = set() for taskmodel_name, taskmodel in jiant_model.taskmodels_dict.items(): - if not isinstance(taskmodel, taskmodels.MLMModel): + if not isinstance(taskmodel, MLMModel): continue mismatch = taskmodel.mlm_head.load_state_dict(mlm_weights_dict) assert not mismatch.missing_keys @@ -273,196 +254,37 @@ def load_partial_heads( return result -def create_taskmodel( - task, model_arch, encoder, taskmodel_kwargs: Optional[Dict] = None -) -> taskmodels.Taskmodel: +def create_taskmodel(task, encoder, **taskmodel_kwargs) -> Taskmodel: """Creates, initializes and returns the task model for a given task type and encoder. Args: task (Task): Task object associated with the taskmodel being created. - model_arch (ModelArchitectures.Any): Model architecture (e.g., ModelArchitectures.BERT). - encoder (PreTrainedModel): Transformer w/o heads (embedding layer + self-attention layer). - taskmodel_kwargs (Optional[Dict]): map containing any kwargs needed for taskmodel setup. + encoder (JiantTransformersModel): Transformer w/o heads + (embedding layer + self-attention layer). + **taskmodel_kwargs: Additional args for taskmodel setup Raises: KeyError if task does not have valid TASK_TYPE. Returns: - Taskmodel (e.g., ClassificationModel) appropriate for the task type and encoder. + Taskmodel """ - if model_arch in [ - ModelArchitectures.BERT, - ModelArchitectures.ROBERTA, - ModelArchitectures.ALBERT, - ModelArchitectures.XLM_ROBERTA, - ModelArchitectures.ELECTRA, - ]: - hidden_size = encoder.config.hidden_size - hidden_dropout_prob = encoder.config.hidden_dropout_prob - elif model_arch in [ - ModelArchitectures.BART, - ModelArchitectures.MBART, - ]: - hidden_size = encoder.config.d_model - hidden_dropout_prob = encoder.config.dropout - else: - raise KeyError() - - if task.TASK_TYPE == TaskTypes.CLASSIFICATION: - assert taskmodel_kwargs is None - classification_head = heads.ClassificationHead( - hidden_size=hidden_size, - hidden_dropout_prob=hidden_dropout_prob, - num_labels=len(task.LABELS), - ) - taskmodel = taskmodels.ClassificationModel( - encoder=encoder, classification_head=classification_head, - ) - elif task.TASK_TYPE == TaskTypes.REGRESSION: - assert taskmodel_kwargs is None - regression_head = heads.RegressionHead( - hidden_size=hidden_size, hidden_dropout_prob=hidden_dropout_prob, - ) - taskmodel = taskmodels.RegressionModel(encoder=encoder, regression_head=regression_head) - elif task.TASK_TYPE == TaskTypes.MULTIPLE_CHOICE: - assert taskmodel_kwargs is None - choice_scoring_head = heads.RegressionHead( - hidden_size=hidden_size, hidden_dropout_prob=hidden_dropout_prob, - ) - taskmodel = taskmodels.MultipleChoiceModel( - encoder=encoder, num_choices=task.NUM_CHOICES, choice_scoring_head=choice_scoring_head, - ) - elif task.TASK_TYPE == TaskTypes.SPAN_PREDICTION: - assert taskmodel_kwargs is None - span_prediction_head = heads.TokenClassificationHead( - hidden_size=hidden_size, - hidden_dropout_prob=encoder.config.hidden_dropout_prob, - num_labels=2, - ) - taskmodel = taskmodels.SpanPredictionModel( - encoder=encoder, span_prediction_head=span_prediction_head, - ) - elif task.TASK_TYPE == TaskTypes.SPAN_COMPARISON_CLASSIFICATION: - assert taskmodel_kwargs is None - span_comparison_head = heads.SpanComparisonHead( - hidden_size=hidden_size, - hidden_dropout_prob=hidden_dropout_prob, - num_spans=task.num_spans, - num_labels=len(task.LABELS), - ) - taskmodel = taskmodels.SpanComparisonModel( - encoder=encoder, span_comparison_head=span_comparison_head, - ) - elif task.TASK_TYPE == TaskTypes.MULTI_LABEL_SPAN_CLASSIFICATION: - assert taskmodel_kwargs is None - span_comparison_head = heads.SpanComparisonHead( - hidden_size=hidden_size, - hidden_dropout_prob=hidden_dropout_prob, - num_spans=task.num_spans, - num_labels=len(task.LABELS), - ) - taskmodel = taskmodels.MultiLabelSpanComparisonModel( - encoder=encoder, span_comparison_head=span_comparison_head, - ) - elif task.TASK_TYPE == TaskTypes.TAGGING: - assert taskmodel_kwargs is None - token_classification_head = heads.TokenClassificationHead( - hidden_size=hidden_size, - hidden_dropout_prob=hidden_dropout_prob, - num_labels=len(task.LABELS), - ) - taskmodel = taskmodels.TokenClassificationModel( - encoder=encoder, token_classification_head=token_classification_head, - ) - elif task.TASK_TYPE == TaskTypes.SQUAD_STYLE_QA: - assert taskmodel_kwargs is None - qa_head = heads.QAHead(hidden_size=hidden_size) - taskmodel = taskmodels.QAModel(encoder=encoder, qa_head=qa_head) - elif task.TASK_TYPE == TaskTypes.MASKED_LANGUAGE_MODELING: - assert taskmodel_kwargs is None - if model_arch == ModelArchitectures.BERT: - mlm_head = heads.BertMLMHead( - hidden_size=hidden_size, - vocab_size=encoder.config.vocab_size, - layer_norm_eps=encoder.config.layer_norm_eps, - hidden_act=encoder.config.hidden_act, - ) - elif model_arch == ModelArchitectures.ROBERTA: - mlm_head = heads.RobertaMLMHead( - hidden_size=hidden_size, - vocab_size=encoder.config.vocab_size, - layer_norm_eps=encoder.config.layer_norm_eps, - ) - elif model_arch == ModelArchitectures.ALBERT: - mlm_head = heads.AlbertMLMHead( - hidden_size=hidden_size, - embedding_size=encoder.config.embedding_size, - vocab_size=encoder.config.vocab_size, - hidden_act=encoder.config.hidden_act, - ) - elif model_arch == ModelArchitectures.XLM_ROBERTA: - mlm_head = heads.RobertaMLMHead( - hidden_size=hidden_size, - vocab_size=encoder.config.vocab_size, - layer_norm_eps=encoder.config.layer_norm_eps, - ) - elif model_arch in ( - ModelArchitectures.BART, - ModelArchitectures.MBART, - ModelArchitectures.ELECTRA, - ): - raise NotImplementedError() - else: - raise KeyError(model_arch) - taskmodel = taskmodels.MLMModel(encoder=encoder, mlm_head=mlm_head) - elif task.TASK_TYPE == TaskTypes.EMBEDDING: - if taskmodel_kwargs["pooler_type"] == "mean": - pooler_head = heads.MeanPoolerHead() - elif taskmodel_kwargs["pooler_type"] == "first": - pooler_head = heads.FirstPoolerHead() - else: - raise KeyError(taskmodel_kwargs["pooler_type"]) - taskmodel = taskmodels.EmbeddingModel( - encoder=encoder, pooler_head=pooler_head, layer=taskmodel_kwargs["layer"], - ) - else: - raise KeyError(task.TASK_TYPE) - return taskmodel - + head_kwargs = {} + head_kwargs["hidden_size"] = encoder.get_hidden_size() + head_kwargs["hidden_dropout_prob"] = encoder.get_hidden_dropout_prob() + head_kwargs["vocab_size"] = encoder.config.vocab_size + head_kwargs["model_arch"] = ModelArchitectures(encoder.config.model_type) -def get_encoder(model_arch, ancestor_model): - """From model architecture, get the encoder (encoder = embedding layer + self-attention layer). - - This function will return the "The bare Bert Model transformer outputting raw hidden-states - without any specific head on top", when provided with ModelArchitectures and BertForPreTraining - model. See Hugging Face's BertForPreTraining and BertModel documentation for more info. - - Args: - model_arch: Model architecture. - ancestor_model: Model with pretraining heads attached. - - Raises: - KeyError if ModelArchitectures + if hasattr(encoder, "hidden_act"): + head_kwargs["hidden_act"] = encoder.config.hidden_act + if hasattr(encoder, "layer_norm_eps"): + head_kwargs["layer_norm_eps"] = encoder.config.layer_norm_eps - Returns: - Bare pretrained model outputting raw hidden-states without a specific head on top. + head = JiantHeadFactory()(task, **head_kwargs) - """ - if model_arch == ModelArchitectures.BERT: - return ancestor_model.bert - elif model_arch == ModelArchitectures.ROBERTA: - return ancestor_model.roberta - elif model_arch == ModelArchitectures.ALBERT: - return ancestor_model.albert - elif model_arch == ModelArchitectures.XLM_ROBERTA: - return ancestor_model.roberta - elif model_arch in (ModelArchitectures.BART, ModelArchitectures.MBART): - return ancestor_model.model - elif model_arch == ModelArchitectures.ELECTRA: - return ancestor_model.electra - else: - raise KeyError(model_arch) + taskmodel = JiantTaskModelFactory()(task, encoder, head, **taskmodel_kwargs) + return taskmodel @dataclass @@ -472,45 +294,6 @@ class TransformersClassSpec: model_class: Any -TRANSFORMERS_CLASS_SPEC_DICT = { - ModelArchitectures.BERT: TransformersClassSpec( - config_class=transformers.BertConfig, - tokenizer_class=transformers.BertTokenizer, - model_class=transformers.BertForPreTraining, - ), - ModelArchitectures.ROBERTA: TransformersClassSpec( - config_class=transformers.RobertaConfig, - tokenizer_class=transformers.RobertaTokenizer, - model_class=transformers.RobertaForMaskedLM, - ), - ModelArchitectures.ALBERT: TransformersClassSpec( - config_class=transformers.AlbertConfig, - tokenizer_class=transformers.AlbertTokenizer, - model_class=transformers.AlbertForMaskedLM, - ), - ModelArchitectures.XLM_ROBERTA: TransformersClassSpec( - config_class=transformers.XLMRobertaConfig, - tokenizer_class=transformers.XLMRobertaTokenizer, - model_class=transformers.XLMRobertaForMaskedLM, - ), - ModelArchitectures.BART: TransformersClassSpec( - config_class=transformers.BartConfig, - tokenizer_class=transformers.BartTokenizer, - model_class=transformers.BartForConditionalGeneration, - ), - ModelArchitectures.MBART: TransformersClassSpec( - config_class=transformers.BartConfig, - tokenizer_class=transformers.MBartTokenizer, - model_class=transformers.BartForConditionalGeneration, - ), - ModelArchitectures.ELECTRA: TransformersClassSpec( - config_class=transformers.ElectraConfig, - tokenizer_class=transformers.ElectraTokenizer, - model_class=transformers.ElectraForPreTraining, - ), -} - - def get_taskmodel_and_task_names(task_to_taskmodel_map: Dict[str, str]) -> Dict[str, List[str]]: """Get mapping from task model name to the list of task names associated with that task model. @@ -529,21 +312,6 @@ def get_taskmodel_and_task_names(task_to_taskmodel_map: Dict[str, str]) -> Dict[ return taskmodel_and_task_names -def get_model_arch_from_jiant_model(jiant_model: nn.Module) -> ModelArchitectures: - return ModelArchitectures.from_encoder(encoder=jiant_model.encoder) - - -MODEL_PREFIX = { - ModelArchitectures.BERT: "bert", - ModelArchitectures.ROBERTA: "roberta", - ModelArchitectures.ALBERT: "albert", - ModelArchitectures.XLM_ROBERTA: "xlm-roberta", - ModelArchitectures.BART: "model", - ModelArchitectures.MBART: "model", - ModelArchitectures.ELECTRA: "electra", -} - - def get_ancestor_model(transformers_class_spec, model_config_path): """Load the model config from a file, configure the model, and return the model. diff --git a/jiant/proj/main/modeling/primary.py b/jiant/proj/main/modeling/primary.py index bf5f398a4..eb2c2ca78 100644 --- a/jiant/proj/main/modeling/primary.py +++ b/jiant/proj/main/modeling/primary.py @@ -1,18 +1,44 @@ -from typing import Dict, Union +import abc +from dataclasses import dataclass + +from typing import Any +from typing import Callable +from typing import Dict +from typing import Union + +import torch import torch.nn as nn -import jiant.proj.main.modeling.taskmodels as taskmodels -import jiant.tasks as tasks +import jiant.utils.python.strings as strings +from jiant.tasks.core import BatchMixin +from jiant.tasks.core import FeaturizationSpec +from jiant.tasks.core import Task + from jiant.proj.main.components.outputs import construct_output_from_dict +from jiant.proj.main.modeling.taskmodels import Taskmodel +from jiant.shared.model_resolution import ModelArchitectures + +from jiant.utils.tokenization_utils import bow_tag_tokens +from jiant.utils.tokenization_utils import eow_tag_tokens +from jiant.utils.tokenization_utils import process_bytebpe_tokens +from jiant.utils.tokenization_utils import process_wordpiece_tokens +from jiant.utils.tokenization_utils import process_sentencepiece_tokens + + +@dataclass +class JiantModelOutput: + pooled: torch.Tensor + unpooled: torch.Tensor + other: Any = None class JiantModel(nn.Module): def __init__( self, - task_dict: Dict[str, tasks.Task], + task_dict: Dict[str, Task], encoder: nn.Module, - taskmodels_dict: Dict[str, taskmodels.Taskmodel], + taskmodels_dict: Dict[str, Taskmodel], task_to_taskmodel_map: Dict[str, str], tokenizer, ): @@ -23,15 +49,15 @@ def __init__( self.task_to_taskmodel_map = task_to_taskmodel_map self.tokenizer = tokenizer - def forward(self, batch: tasks.BatchMixin, task: tasks.Task, compute_loss: bool = False): + def forward(self, batch: BatchMixin, task: Task, compute_loss: bool = False): """Calls to this forward method are delegated to the forward of the appropriate taskmodel. When JiantModel forward is called, the task name from the task argument is used as a key to select the appropriate submodule/taskmodel, and that taskmodel's forward is called. Args: - batch (tasks.BatchMixin): model input. - task (tasks.Task): task to which to delegate the forward call. + batch (BatchMixin): model input. + task (Task): task to which to delegate the forward call. compute_loss (bool): whether to calculate and return the loss. Returns: @@ -49,26 +75,26 @@ def forward(self, batch: tasks.BatchMixin, task: tasks.Task, compute_loss: bool taskmodel_key = self.task_to_taskmodel_map[task_name] taskmodel = self.taskmodels_dict[taskmodel_key] return taskmodel( - batch=batch, task=task, tokenizer=self.tokenizer, compute_loss=compute_loss, + batch=batch, tokenizer=self.tokenizer, compute_loss=compute_loss, ).to_dict() def wrap_jiant_forward( jiant_model: Union[JiantModel, nn.DataParallel], - batch: tasks.BatchMixin, - task: tasks.Task, + batch: BatchMixin, + task: Task, compute_loss: bool = False, ): """Wrapper to repackage model inputs using dictionaries for compatibility with DataParallel. - Wrapper that converts batches (type tasks.BatchMixin) to dictionaries before delegating to + Wrapper that converts batches (type BatchMixin) to dictionaries before delegating to JiantModel's forward method, and then converts the resulting model output dict into the appropriate model output dataclass. Args: jiant_model (Union[JiantModel, nn.DataParallel]): - batch (tasks.BatchMixin): model input batch. - task (tasks.Task): Task object passed for access in the taskmodel. + batch (BatchMixin): model input batch. + task (Task): Task object passed for access in the taskmodel. compute_loss (bool): True if loss should be computed, False otherwise. Returns: @@ -85,3 +111,471 @@ def wrap_jiant_forward( if is_multi_gpu and compute_loss: model_output.loss = model_output.loss.mean() return model_output + + +class JiantTransformersModelFactory: + """This factory is used to create JiantTransformersModels based on Huggingface's models. + A wrapper class around Huggingface's Transformer models is used to abstract any inconsistencies + in the classes. + + Attributes: + registry (dict): Dynamic registry mapping ModelArchitectures to JiantTransformersModels + """ + + registry = {} + + @classmethod + def get_registry(cls): + return cls.registry + + @classmethod + def build_featurization_spec(cls, model_type, max_seq_length): + model_arch = ModelArchitectures.from_model_type(model_type) + model_class = cls.get_registry()[model_arch] + return model_class.get_feat_spec(model_type, max_seq_length) + + @classmethod + def register(cls, model_arch: ModelArchitectures) -> Callable: + """Register model_arch as a key mapping to a TaskModel + + Args: + model_arch (ModelArchitectures): ModelArchitecture key mapping to a + JiantTransformersModel + + Returns: + Callable: inner_wrapper() wrapping TaskModel constructor + """ + + def inner_wrapper(wrapped_class: JiantTransformersModel) -> Callable: + assert model_arch not in cls.registry + cls.registry[model_arch] = wrapped_class + return wrapped_class + + return inner_wrapper + + def __call__(cls, hf_model): + """Returns the JiantTransformersModel wrapper class for the corresponding Hugging Face + Transformer model. + + Args: + hf_model (PreTrainedModel): Hugging Face model to convert to JiantTransformersModel + + Returns: + JiantTransformersModel: Jiant wrapper class for Hugging Face model + """ + encoder_class = cls.registry[ModelArchitectures(hf_model.config.model_type)] + encoder = encoder_class(hf_model) + return encoder + + +class JiantTransformersModel(metaclass=abc.ABCMeta): + def __init__(self, baseObject): + self.__class__ = type( + baseObject.__class__.__name__, (self.__class__, baseObject.__class__), {} + ) + self.__dict__ = baseObject.__dict__ + + @classmethod + @abc.abstractmethod + def normalize_tokenizations(cls, tokenizer, space_tokenization, target_tokenization): + """Abstract method to tag space_tokenization and process target_tokenization with + the relevant tokenization method for the model. + """ + pass + + @abc.abstractmethod + def get_mlm_weights_dict(self, weights_dict): + """Abstract method to get the pre-trained masked-language modeling head weights + from the pretrained model from the weights_dict + """ + pass + + @abc.abstractmethod + def get_feat_spec(self, weights_dict): + """Abstract method that should return a FeaturizationSpec specifying the + tokenization details used for the model + """ + pass + + def get_hidden_size(self): + return self.config.hidden_size + + def get_hidden_dropout_prob(self): + return self.config.hidden_dropout_prob + + def encode(self, input_ids, segment_ids, input_mask, output_hidden_states=True): + output = self.forward( + input_ids=input_ids, + token_type_ids=segment_ids, + attention_mask=input_mask, + output_hidden_states=output_hidden_states, + ) + return JiantModelOutput( + pooled=output.pooler_output, + unpooled=output.last_hidden_state, + other=output.hidden_states, + ) + + +@JiantTransformersModelFactory.register(ModelArchitectures.BERT) +class JiantBertModel(JiantTransformersModel): + def __init__(self, baseObject): + super().__init__(baseObject) + + @classmethod + def normalize_tokenizations(cls, tokenizer, space_tokenization, target_tokenization): + """See tokenization_normalization.py for details""" + if tokenizer.init_kwargs.get("do_lower_case", False): + space_tokenization = [token.lower() for token in space_tokenization] + modifed_space_tokenization = bow_tag_tokens(space_tokenization) + modifed_target_tokenization = process_wordpiece_tokens(target_tokenization) + + return modifed_space_tokenization, modifed_target_tokenization + + def get_feat_spec(self, max_seq_length): + return FeaturizationSpec( + max_seq_length=max_seq_length, + cls_token_at_end=False, + pad_on_left=False, + cls_token_segment_id=0, + pad_token_segment_id=0, + pad_token_id=0, + pad_token_mask_id=0, + sequence_a_segment_id=0, + sequence_b_segment_id=1, + sep_token_extra=False, + ) + + def get_mlm_weights_dict(self, weights_dict): + mlm_weights_map = { + "bias": "cls.predictions.bias", + "dense.weight": "cls.predictions.transform.dense.weight", + "dense.bias": "cls.predictions.transform.dense.bias", + "LayerNorm.weight": "cls.predictions.transform.LayerNorm.weight", + "LayerNorm.bias": "cls.predictions.transform.LayerNorm.bias", + "decoder.weight": "cls.predictions.decoder.weight", + "decoder.bias": "cls.predictions.bias", # <-- linked directly to bias + } + mlm_weights_dict = {new_k: weights_dict[old_k] for new_k, old_k in mlm_weights_map.items()} + return mlm_weights_dict + + +@JiantTransformersModelFactory.register(ModelArchitectures.ROBERTA) +class JiantRobertaModel(JiantTransformersModel): + def __init__(self, baseObject): + super().__init__(baseObject) + + @classmethod + def normalize_tokenizations(cls, tokenizer, space_tokenization, target_tokenization): + """See tokenization_normalization.py for details""" + modifed_space_tokenization = bow_tag_tokens(space_tokenization) + modifed_target_tokenization = ["Ġ" + target_tokenization[0]] + target_tokenization[1:] + modifed_target_tokenization = process_bytebpe_tokens(modifed_target_tokenization) + + return modifed_space_tokenization, modifed_target_tokenization + + def get_mlm_weights_dict(self, weights_dict): + mlm_weights_dict = { + strings.remove_prefix(k, "lm_head."): v for k, v in weights_dict.items() + } + mlm_weights_dict["decoder.bias"] = mlm_weights_dict["bias"] + return mlm_weights_dict + + def get_feat_spec(self, max_seq_length): + # RoBERTa is weird + # token 0 = '' which is the cls_token + # token 1 = '' which is the sep_token + # Also two ''s are used between sentences. Yes, not ''. + return FeaturizationSpec( + max_seq_length=max_seq_length, + cls_token_at_end=False, + pad_on_left=False, + cls_token_segment_id=0, + pad_token_segment_id=0, + pad_token_id=1, # Roberta uses pad_token_id = 1 + pad_token_mask_id=0, + sequence_a_segment_id=0, + sequence_b_segment_id=0, # RoBERTa has no token_type_ids + sep_token_extra=True, + ) + + +@JiantTransformersModelFactory.register(ModelArchitectures.DEBERTAV2) +class JiantDebertaV2Model(JiantTransformersModel): + def __init__(self, baseObject): + super().__init__(baseObject) + + @classmethod + def normalize_tokenizations(cls, tokenizer, space_tokenization, target_tokenization): + """See tokenization_normalization.py for details""" + space_tokenization = [token for token in space_tokenization] + modifed_space_tokenization = bow_tag_tokens(space_tokenization) + modifed_target_tokenization = process_sentencepiece_tokens(target_tokenization) + + return modifed_space_tokenization, modifed_target_tokenization + + def encode(self, input_ids, segment_ids, input_mask, output_hidden_states=True): + output = self.forward( + input_ids=input_ids, + token_type_ids=segment_ids, + attention_mask=input_mask, + output_hidden_states=output_hidden_states, + ) + return JiantModelOutput( + pooled=output.last_hidden_state[:, 0, :], + unpooled=output.last_hidden_state, + other=output.hidden_states, + ) + + def get_mlm_weights_dict(self, weights_dict): + mlm_weights_map = { + "bias": "cls.predictions.bias", + "dense.weight": "cls.predictions.transform.dense.weight", + "dense.bias": "cls.predictions.transform.dense.bias", + "LayerNorm.weight": "cls.predictions.transform.LayerNorm.weight", + "LayerNorm.bias": "cls.predictions.transform.LayerNorm.bias", + "decoder.weight": "cls.predictions.decoder.weight", + "decoder.bias": "cls.predictions.bias", # <-- linked directly to bias + } + mlm_weights_dict = {new_k: weights_dict[old_k] for new_k, old_k in mlm_weights_map.items()} + return mlm_weights_dict + + def get_feat_spec(self, max_seq_length): + return FeaturizationSpec( + max_seq_length=max_seq_length, + cls_token_at_end=False, + pad_on_left=False, + cls_token_segment_id=0, + pad_token_segment_id=0, + pad_token_id=0, + pad_token_mask_id=0, + sequence_a_segment_id=0, + sequence_b_segment_id=1, + sep_token_extra=False, + ) + + +@JiantTransformersModelFactory.register(ModelArchitectures.XLM_ROBERTA) +class JiantXLMRobertaModel(JiantTransformersModel): + def __init__(self, baseObject): + super().__init__(baseObject) + + @classmethod + def normalize_tokenizations(cls, tokenizer, space_tokenization, target_tokenization): + """See tokenization_normalization.py for details""" + space_tokenization = [token.lower() for token in space_tokenization] + modifed_space_tokenization = bow_tag_tokens(space_tokenization) + modifed_target_tokenization = process_sentencepiece_tokens(target_tokenization) + + return modifed_space_tokenization, modifed_target_tokenization + + def get_feat_spec(self, max_seq_length): + # XLM-RoBERTa is weird + # token 0 = '' which is the cls_token + # token 1 = '' which is the sep_token + # Also two ''s are used between sentences. Yes, not ''. + return FeaturizationSpec( + max_seq_length=max_seq_length, + cls_token_at_end=False, + pad_on_left=False, + cls_token_segment_id=0, + pad_token_segment_id=0, + pad_token_id=1, # XLM-RoBERTa uses pad_token_id = 1 + pad_token_mask_id=0, + sequence_a_segment_id=0, + sequence_b_segment_id=0, # XLM-RoBERTa has no token_type_ids + sep_token_extra=True, + ) + + def get_mlm_weights_dict(self, weights_dict): + mlm_weights_dict = { + strings.remove_prefix(k, "lm_head."): v for k, v in weights_dict.items() + } + mlm_weights_dict["decoder.bias"] = mlm_weights_dict["bias"] + return mlm_weights_dict + + +@JiantTransformersModelFactory.register(ModelArchitectures.XLM) +class JiantXLMModel(JiantTransformersModel): + def __init__(self, baseObject): + super().__init__(baseObject) + + @classmethod + def normalize_tokenizations(cls, tokenizer, space_tokenization, target_tokenization): + """See tokenization_normalization.py for details""" + if tokenizer.init_kwargs.get("do_lowercase_and_remove_accent", False): + space_tokenization = [token.lower() for token in space_tokenization] + modifed_space_tokenization = eow_tag_tokens(space_tokenization) + modifed_target_tokenization = target_tokenization + + return modifed_space_tokenization, modifed_target_tokenization + + def get_feat_spec(self, max_seq_length): + return FeaturizationSpec( + max_seq_length=max_seq_length, + cls_token_at_end=False, + pad_on_left=False, + cls_token_segment_id=0, + pad_token_segment_id=0, + pad_token_id=0, + pad_token_mask_id=0, + sequence_a_segment_id=0, + sequence_b_segment_id=0, # RoBERTa has no token_type_ids + sep_token_extra=False, + ) + + +@JiantTransformersModelFactory.register(ModelArchitectures.ALBERT) +class JiantAlbertModel(JiantTransformersModel): + def __init__(self, baseObject): + super().__init__(baseObject) + + @classmethod + def normalize_tokenizations(cls, tokenizer, space_tokenization, target_tokenization): + """See tokenization_normalization.py for details""" + space_tokenization = [token.lower() for token in space_tokenization] + modifed_space_tokenization = bow_tag_tokens(space_tokenization) + modifed_target_tokenization = process_sentencepiece_tokens(target_tokenization) + + return modifed_space_tokenization, modifed_target_tokenization + + def get_mlm_weights_dict(self, weights_dict): + mlm_weights_dict = { + strings.remove_prefix(k, "predictions."): v for k, v in weights_dict.items() + } + return mlm_weights_dict + + def get_feat_spec(self, max_seq_length): + return FeaturizationSpec( + max_seq_length=max_seq_length, + cls_token_at_end=False, # ? + pad_on_left=False, # ok + cls_token_segment_id=0, # ok + pad_token_segment_id=0, # ok + pad_token_id=0, # I think? + pad_token_mask_id=0, # I think? + sequence_a_segment_id=0, # I think? + sequence_b_segment_id=1, # I think? + sep_token_extra=False, + ) + + +@JiantTransformersModelFactory.register(ModelArchitectures.ELECTRA) +class JiantElectraModel(JiantTransformersModel): + def __init__(self, baseObject): + super().__init__(baseObject) + + def encode(self, input_ids, segment_ids, input_mask): + output = self.forward( + input_ids=input_ids, token_type_ids=segment_ids, attention_mask=input_mask + ) + unpooled = output.hidden_states + pooled = unpooled[:, 0, :] + return JiantModelOutput(pooled=pooled, unpooled=unpooled, other=output.hidden_states) + + def get_feat_spec(self, max_seq_length): + return FeaturizationSpec( + max_seq_length=max_seq_length, + cls_token_at_end=False, + pad_on_left=False, + cls_token_segment_id=0, + pad_token_segment_id=0, + pad_token_id=0, + pad_token_mask_id=0, + sequence_a_segment_id=0, + sequence_b_segment_id=1, + sep_token_extra=False, + ) + + @classmethod + def normalize_tokenizations(cls, tokenizer, space_tokenization, target_tokenization): + raise NotImplementedError() + + def get_mlm_weights_dict(self, weights_dict): + raise NotImplementedError() + + +@JiantTransformersModelFactory.register(ModelArchitectures.BART) +class JiantBartModel(JiantTransformersModel): + def __init__(self, baseObject): + super().__init__(baseObject) + + def get_hidden_size(self): + return self.config.d_model + + def get_hidden_dropout_prob(self): + return self.config.dropout + + def get_feat_spec(self, max_seq_length): + # BART is weird + # token 0 = '' which is the cls_token + # token 1 = '' which is the sep_token + # Also two ''s are used between sentences. Yes, not ''. + return FeaturizationSpec( + max_seq_length=max_seq_length, + cls_token_at_end=False, + pad_on_left=False, + cls_token_segment_id=0, + pad_token_segment_id=0, + pad_token_id=1, # BART uses pad_token_id = 1 + pad_token_mask_id=0, + sequence_a_segment_id=0, + sequence_b_segment_id=0, # BART has no token_type_ids + sep_token_extra=True, + ) + + def encode(self, input_ids, input_mask, *args): + # BART and mBART and encoder-decoder architectures. + # As described in the BART paper and implemented in Transformers, + # for single input tasks, the encoder input is the sequence, + # the decode input is 1-shifted sequence, and the resulting + # sentence representation is the final decoder state. + # That's what we use for `unpooled` here. + dec_last, dec_all, enc_last, enc_all = super().__call__( + input_ids=input_ids, + attention_mask=input_mask, + output_hidden_states=True, + return_dict=True, + ) + unpooled = dec_last + other = (enc_all + dec_all,) + + bsize, slen = input_ids.shape + batch_idx = torch.arange(bsize).to(input_ids.device) + # Get last non-pad index + pooled = unpooled[batch_idx, slen - input_ids.eq(self.config.pad_token_id).sum(1) - 1] + return JiantModelOutput(pooled=pooled, unpooled=unpooled, other=other) + + def get_mlm_weights_dict(self, weights_dict): + raise NotImplementedError() + + +@JiantTransformersModelFactory.register(ModelArchitectures.MBART) +class JiantMBartModel(JiantBartModel): + def __init__(self, baseObject): + super().__init__(baseObject) + + @classmethod + def normalize_tokenizations(cls, tokenizer, space_tokenization, target_tokenization): + raise NotImplementedError() + + def get_feat_spec(self, max_seq_length): + # mBART is weird + # token 0 = '' which is the cls_token + # token 1 = '' which is the sep_token + # Also two ''s are used between sentences. Yes, not ''. + return FeaturizationSpec( + max_seq_length=max_seq_length, + cls_token_at_end=False, + pad_on_left=False, + cls_token_segment_id=0, + pad_token_segment_id=0, + pad_token_id=1, # mBART uses pad_token_id = 1 + pad_token_mask_id=0, + sequence_a_segment_id=0, + sequence_b_segment_id=0, # mBART has no token_type_ids + sep_token_extra=True, + ) + + def get_mlm_weights_dict(self, weights_dict): + raise NotImplementedError() diff --git a/jiant/proj/main/modeling/taskmodels.py b/jiant/proj/main/modeling/taskmodels.py index 53d5e70c7..af9e62417 100644 --- a/jiant/proj/main/modeling/taskmodels.py +++ b/jiant/proj/main/modeling/taskmodels.py @@ -1,53 +1,106 @@ import abc -from dataclasses import dataclass -from typing import Any import torch import torch.nn as nn +from typing import Callable + import jiant.proj.main.modeling.heads as heads -import jiant.utils.transformer_utils as transformer_utils -from jiant.proj.main.components.outputs import LogitsOutput, LogitsAndLossOutput + +from jiant.proj.main.components.outputs import LogitsAndLossOutput +from jiant.proj.main.components.outputs import LogitsOutput from jiant.utils.python.datastructures import take_one -from jiant.shared.model_resolution import ModelArchitectures + +from jiant.tasks.core import TaskTypes + + +class JiantTaskModelFactory: + """This factory is used to create task models bundling the task, + encoder, and task head within the task model. + + Attributes: + registry (dict): Dynamic registry mapping task types to task models + """ + + registry = {} + + @classmethod + def register(cls, task_type: TaskTypes) -> Callable: + """Register task_type as a key mapping to a TaskModel + + Args: + task_type (TaskTypes): TaskType key mapping to a BaseHead task head + + Returns: + Callable: inner_wrapper() wrapping TaskModel constructor + """ + + def inner_wrapper(wrapped_class: Taskmodel) -> Callable: + assert task_type not in cls.registry + cls.registry[task_type] = wrapped_class + return wrapped_class + + return inner_wrapper + + def __call__(cls, task, encoder, head, **kwargs): + """This creates the TaskModel corresponding to the Task, abc.abstractmethod, + and encoder used. + + Args: + task (Task): Task + encoder (JiantTransformersModel): encoder + head (BaseHead): Task head + **kwargs: Additional arguments for initializing TaskModel + + Returns: + TaskModel: Initialized task model bundling task, encoder, and head + """ + taskmodel_class = cls.registry[task.TASK_TYPE] + taskmodel = taskmodel_class(task, encoder, head, **kwargs) + return taskmodel class Taskmodel(nn.Module, metaclass=abc.ABCMeta): - def __init__(self, encoder): + def __init__(self, task, encoder, head): super().__init__() + self.task = task self.encoder = encoder + self.head = head - def forward(self, batch, task, tokenizer, compute_loss: bool = False): + def forward(self, batch, tokenizer, compute_loss: bool = False): raise NotImplementedError +@JiantTaskModelFactory.register(TaskTypes.CLASSIFICATION) class ClassificationModel(Taskmodel): - def __init__(self, encoder, classification_head: heads.ClassificationHead): - super().__init__(encoder=encoder) - self.classification_head = classification_head + def __init__(self, task, encoder, head: heads.ClassificationHead, **kwargs): - def forward(self, batch, task, tokenizer, compute_loss: bool = False): - encoder_output = get_output_from_encoder_and_batch(encoder=self.encoder, batch=batch) - logits = self.classification_head(pooled=encoder_output.pooled) + super().__init__(task=task, encoder=encoder, head=head) + + def forward(self, batch, tokenizer, compute_loss: bool = False): + encoder_output = self.encoder.encode( + input_ids=batch.input_ids, segment_ids=batch.segment_ids, input_mask=batch.input_mask, + ) + logits = self.head(pooled=encoder_output.pooled) if compute_loss: loss_fct = nn.CrossEntropyLoss() - loss = loss_fct( - logits.view(-1, self.classification_head.num_labels), batch.label_id.view(-1), - ) + loss = loss_fct(logits.view(-1, self.head.num_labels), batch.label_id.view(-1),) return LogitsAndLossOutput(logits=logits, loss=loss, other=encoder_output.other) else: return LogitsOutput(logits=logits, other=encoder_output.other) +@JiantTaskModelFactory.register(TaskTypes.REGRESSION) class RegressionModel(Taskmodel): - def __init__(self, encoder, regression_head: heads.RegressionHead): - super().__init__(encoder=encoder) - self.regression_head = regression_head + def __init__(self, task, encoder, head: heads.RegressionHead, **kwargs): + super().__init__(task=task, encoder=encoder, head=head) - def forward(self, batch, task, tokenizer, compute_loss: bool = False): - encoder_output = get_output_from_encoder_and_batch(encoder=self.encoder, batch=batch) + def forward(self, batch, tokenizer, compute_loss: bool = False): + encoder_output = self.encoder.encode( + input_ids=batch.input_ids, segment_ids=batch.segment_ids, input_mask=batch.input_mask, + ) # TODO: Abuse of notation - these aren't really logits (issue #1187) - logits = self.regression_head(pooled=encoder_output.pooled) + logits = self.head(pooled=encoder_output.pooled) if compute_loss: loss_fct = nn.MSELoss() loss = loss_fct(logits.view(-1), batch.label.view(-1)) @@ -56,27 +109,22 @@ def forward(self, batch, task, tokenizer, compute_loss: bool = False): return LogitsOutput(logits=logits, other=encoder_output.other) +@JiantTaskModelFactory.register(TaskTypes.MULTIPLE_CHOICE) class MultipleChoiceModel(Taskmodel): - def __init__(self, encoder, num_choices: int, choice_scoring_head: heads.RegressionHead): - super().__init__(encoder=encoder) - self.num_choices = num_choices - self.choice_scoring_head = choice_scoring_head - - def forward(self, batch, task, tokenizer, compute_loss: bool = False): - input_ids = batch.input_ids - segment_ids = batch.segment_ids - input_mask = batch.input_mask + def __init__(self, task, encoder, head: heads.RegressionHead, **kwargs): + super().__init__(task=task, encoder=encoder, head=head) + self.num_choices = task.NUM_CHOICES + def forward(self, batch, tokenizer, compute_loss: bool = False): choice_score_list = [] encoder_output_other_ls = [] for i in range(self.num_choices): - encoder_output = get_output_from_encoder( - encoder=self.encoder, - input_ids=input_ids[:, i], - segment_ids=segment_ids[:, i], - input_mask=input_mask[:, i], + encoder_output = self.encoder.encode( + input_ids=batch.input_ids[:, i], + segment_ids=batch.segment_ids[:, i], + input_mask=batch.input_mask[:, i], ) - choice_score = self.choice_scoring_head(pooled=encoder_output.pooled) + choice_score = self.head(pooled=encoder_output.pooled) choice_score_list.append(choice_score) encoder_output_other_ls.append(encoder_output.other) @@ -85,7 +133,7 @@ def forward(self, batch, task, tokenizer, compute_loss: bool = False): for j in range(len(encoder_output_other_ls[0])): reshaped_outputs.append( [ - torch.stack([misc[j][layer_i] for misc in encoder_output_other_ls], dim=1) + torch.stack([misc[j][layer_i] for misc in encoder_output_other_ls], dim=1,) for layer_i in range(len(encoder_output_other_ls[0][0])) ] ) @@ -103,36 +151,48 @@ def forward(self, batch, task, tokenizer, compute_loss: bool = False): return LogitsOutput(logits=logits, other=reshaped_outputs) +@JiantTaskModelFactory.register(TaskTypes.SPAN_COMPARISON_CLASSIFICATION) class SpanComparisonModel(Taskmodel): - def __init__(self, encoder, span_comparison_head: heads.SpanComparisonHead): - super().__init__(encoder=encoder) - self.span_comparison_head = span_comparison_head - - def forward(self, batch, task, tokenizer, compute_loss: bool = False): - encoder_output = get_output_from_encoder_and_batch(encoder=self.encoder, batch=batch) - logits = self.span_comparison_head(unpooled=encoder_output.unpooled, spans=batch.spans) + def __init__(self, task, encoder, head: heads.SpanComparisonHead, **kwargs): + super().__init__(task=task, encoder=encoder, head=head) + + def forward(self, batch, tokenizer, compute_loss: bool = False): + """Summary + + Args: + batch (TYPE): Description + tokenizer (TYPE): Description + compute_loss (bool, optional): Description + + Returns: + TYPE: Description + """ + encoder_output = self.encoder.encode( + input_ids=batch.input_ids, segment_ids=batch.segment_ids, input_mask=batch.input_mask, + ) + logits = self.head(unpooled=encoder_output.unpooled, spans=batch.spans) if compute_loss: loss_fct = nn.CrossEntropyLoss() - loss = loss_fct( - logits.view(-1, self.span_comparison_head.num_labels), batch.label_id.view(-1), - ) + loss = loss_fct(logits.view(-1, self.head.num_labels), batch.label_id.view(-1),) return LogitsAndLossOutput(logits=logits, loss=loss, other=encoder_output.other) else: return LogitsOutput(logits=logits, other=encoder_output.other) +@JiantTaskModelFactory.register(TaskTypes.SPAN_PREDICTION) class SpanPredictionModel(Taskmodel): - def __init__(self, encoder, span_prediction_head: heads.TokenClassificationHead): - super().__init__(encoder=encoder) + def __init__(self, task, encoder, head: heads.TokenClassificationHead, **kwargs): + super().__init__(task=task, encoder=encoder, head=head) self.offset_margin = 1000 # 1000 is a big enough number that exp(-1000) will be strict 0 in float32. # So that if we add 1000 to the valid dimensions in the input of softmax, # we can guarantee the output distribution will only be non-zero at those dimensions. - self.span_prediction_head = span_prediction_head - def forward(self, batch, task, tokenizer, compute_loss: bool = False): - encoder_output = get_output_from_encoder_and_batch(encoder=self.encoder, batch=batch) - logits = self.span_prediction_head(unpooled=encoder_output.unpooled) + def forward(self, batch, tokenizer, compute_loss: bool = False): + encoder_output = self.encoder.encode( + input_ids=batch.input_ids, segment_ids=batch.segment_ids, input_mask=batch.input_mask, + ) + logits = self.head(unpooled=encoder_output.unpooled) # Ensure logits in valid range is at least self.offset_margin higher than others logits_offset = logits.max() - logits.min() + self.offset_margin logits = logits + logits_offset * batch.selection_token_mask.unsqueeze(dim=2) @@ -146,38 +206,40 @@ def forward(self, batch, task, tokenizer, compute_loss: bool = False): return LogitsOutput(logits=logits, other=encoder_output.other) +@JiantTaskModelFactory.register(TaskTypes.MULTI_LABEL_SPAN_CLASSIFICATION) class MultiLabelSpanComparisonModel(Taskmodel): - def __init__(self, encoder, span_comparison_head: heads.SpanComparisonHead): - super().__init__(encoder=encoder) - self.span_comparison_head = span_comparison_head + def __init__(self, task, encoder, head: heads.SpanComparisonHead, **kwargs): + super().__init__(task=task, encoder=encoder, head=head) - def forward(self, batch, task, tokenizer, compute_loss: bool = False): - encoder_output = get_output_from_encoder_and_batch(encoder=self.encoder, batch=batch) - logits = self.span_comparison_head(unpooled=encoder_output.unpooled, spans=batch.spans) + def forward(self, batch, tokenizer, compute_loss: bool = False): + encoder_output = self.encoder.encode( + input_ids=batch.input_ids, segment_ids=batch.segment_ids, input_mask=batch.input_mask, + ) + logits = self.head(unpooled=encoder_output.unpooled, spans=batch.spans) if compute_loss: loss_fct = nn.BCEWithLogitsLoss() - loss = loss_fct( - logits.view(-1, self.span_comparison_head.num_labels), batch.label_ids.float(), - ) + loss = loss_fct(logits.view(-1, self.head.num_labels), batch.label_ids.float(),) return LogitsAndLossOutput(logits=logits, loss=loss, other=encoder_output.other) else: return LogitsOutput(logits=logits, other=encoder_output.other) +@JiantTaskModelFactory.register(TaskTypes.TAGGING) class TokenClassificationModel(Taskmodel): """From RobertaForTokenClassification""" - def __init__(self, encoder, token_classification_head: heads.TokenClassificationHead): - super().__init__(encoder=encoder) - self.token_classification_head = token_classification_head + def __init__(self, task, encoder, head: heads.TokenClassificationHead, **kwargs): + super().__init__(task=task, encoder=encoder, head=head) - def forward(self, batch, task, tokenizer, compute_loss: bool = False): - encoder_output = get_output_from_encoder_and_batch(encoder=self.encoder, batch=batch) - logits = self.token_classification_head(unpooled=encoder_output.unpooled) + def forward(self, batch, tokenizer, compute_loss: bool = False): + encoder_output = self.encoder.encode( + input_ids=batch.input_ids, segment_ids=batch.segment_ids, input_mask=batch.input_mask, + ) + logits = self.head(unpooled=encoder_output.unpooled) if compute_loss: loss_fct = nn.CrossEntropyLoss() active_loss = batch.label_mask.view(-1) == 1 - active_logits = logits.view(-1, self.token_classification_head.num_labels)[active_loss] + active_logits = logits.view(-1, self.head.num_labels)[active_loss] active_labels = batch.label_ids.view(-1)[active_loss] loss = loss_fct(active_logits, active_labels) return LogitsAndLossOutput(logits=logits, loss=loss, other=encoder_output.other) @@ -185,14 +247,16 @@ def forward(self, batch, task, tokenizer, compute_loss: bool = False): return LogitsOutput(logits=logits, other=encoder_output.other) +@JiantTaskModelFactory.register(TaskTypes.SQUAD_STYLE_QA) class QAModel(Taskmodel): - def __init__(self, encoder, qa_head: heads.QAHead): - super().__init__(encoder=encoder) - self.qa_head = qa_head + def __init__(self, task, encoder, head: heads.QAHead, **kwargs): + super().__init__(task=task, encoder=encoder, head=head) - def forward(self, batch, task, tokenizer, compute_loss: bool = False): - encoder_output = get_output_from_encoder_and_batch(encoder=self.encoder, batch=batch) - logits = self.qa_head(unpooled=encoder_output.unpooled) + def forward(self, batch, tokenizer, compute_loss: bool = False): + encoder_output = self.encoder.encode( + input_ids=batch.input_ids, segment_ids=batch.segment_ids, input_mask=batch.input_mask, + ) + logits = self.head(unpooled=encoder_output.unpooled) if compute_loss: loss = compute_qa_loss( logits=logits, @@ -204,22 +268,23 @@ def forward(self, batch, task, tokenizer, compute_loss: bool = False): return LogitsOutput(logits=logits, other=encoder_output.other) +@JiantTaskModelFactory.register(TaskTypes.MASKED_LANGUAGE_MODELING) class MLMModel(Taskmodel): - def __init__(self, encoder, mlm_head: heads.BaseMLMHead): - super().__init__(encoder=encoder) - self.mlm_head = mlm_head + def __init__(self, task, encoder, head: heads.BaseMLMHead, **kwargs): + super().__init__(task=task, encoder=encoder, head=head) - def forward(self, batch, task, tokenizer, compute_loss: bool = False): + def forward(self, batch, tokenizer, compute_loss: bool = False): masked_batch = batch.get_masked( - mlm_probability=task.mlm_probability, tokenizer=tokenizer, do_mask=task.do_mask, + mlm_probability=self.task.mlm_probability, + tokenizer=tokenizer, + do_mask=self.task.do_mask, ) - encoder_output = get_output_from_encoder( - encoder=self.encoder, - input_ids=masked_batch.masked_input_ids, + encoder_output = self.encoder.encode( + input_ids=masked_batch.input_ids, segment_ids=masked_batch.segment_ids, input_mask=masked_batch.input_mask, ) - logits = self.mlm_head(unpooled=encoder_output.unpooled) + logits = self.head(unpooled=encoder_output.unpooled) if compute_loss: loss = compute_mlm_loss(logits=logits, masked_lm_labels=masked_batch.masked_lm_labels) return LogitsAndLossOutput(logits=logits, loss=loss, other=encoder_output.other) @@ -227,25 +292,27 @@ def forward(self, batch, task, tokenizer, compute_loss: bool = False): return LogitsOutput(logits=logits, other=encoder_output.other) +@JiantTaskModelFactory.register(TaskTypes.EMBEDDING) class EmbeddingModel(Taskmodel): - def __init__(self, encoder, pooler_head: heads.AbstractPoolerHead, layer): - super().__init__(encoder=encoder) - self.pooler_head = pooler_head - self.layer = layer - - def forward(self, batch, task, tokenizer, compute_loss: bool = False): - with transformer_utils.output_hidden_states_context(self.encoder): - encoder_output = get_output_from_encoder_and_batch(encoder=self.encoder, batch=batch) + def __init__(self, task, encoder, head: heads.AbstractPoolerHead, **kwargs): + super().__init__(task=task, encoder=encoder, head=head) + self.layer = kwargs["layer"] + + def forward(self, batch, tokenizer, compute_loss: bool = False): + encoder_output = self.encoder.encode( + input_ids=batch.input_ids, segment_ids=batch.segment_ids, input_mask=batch.input_mask, + ) + # A tuple of layers of hidden states hidden_states = take_one(encoder_output.other) layer_hidden_states = hidden_states[self.layer] - if isinstance(self.pooler_head, heads.MeanPoolerHead): - logits = self.pooler_head(unpooled=layer_hidden_states, input_mask=batch.input_mask) - elif isinstance(self.pooler_head, heads.FirstPoolerHead): - logits = self.pooler_head(layer_hidden_states) + if isinstance(self.head, heads.MeanPoolerHead): + logits = self.head(unpooled=layer_hidden_states, input_mask=batch.input_mask) + elif isinstance(self.head, heads.FirstPoolerHead): + logits = self.head(layer_hidden_states) else: - raise TypeError(type(self.pooler_head)) + raise TypeError(type(self.head)) # TODO: Abuse of notation - these aren't really logits (issue #1187) if compute_loss: @@ -259,114 +326,6 @@ def forward(self, batch, task, tokenizer, compute_loss: bool = False): return LogitsOutput(logits=logits, other=encoder_output.other) -@dataclass -class EncoderOutput: - pooled: torch.Tensor - unpooled: torch.Tensor - other: Any = None - # Extend later with attention, hidden_acts, etc - - -def get_output_from_encoder_and_batch(encoder, batch) -> EncoderOutput: - """Pass batch to encoder, return encoder model output. - - Args: - encoder: bare model outputting raw hidden-states without any specific head. - batch: Batch object (containing token indices, token type ids, and attention mask). - - Returns: - EncoderOutput containing pooled and unpooled model outputs as well as any other outputs. - - """ - return get_output_from_encoder( - encoder=encoder, - input_ids=batch.input_ids, - segment_ids=batch.segment_ids, - input_mask=batch.input_mask, - ) - - -def get_output_from_encoder(encoder, input_ids, segment_ids, input_mask) -> EncoderOutput: - """Pass inputs to encoder, return encoder output. - - Args: - encoder: bare model outputting raw hidden-states without any specific head. - input_ids: token indices (see huggingface.co/transformers/glossary.html#input-ids). - segment_ids: token type ids (see huggingface.co/transformers/glossary.html#token-type-ids). - input_mask: attention mask (see huggingface.co/transformers/glossary.html#attention-mask). - - Raises: - RuntimeError if encoder output contains less than 2 elements. - - Returns: - EncoderOutput containing pooled and unpooled model outputs as well as any other outputs. - - """ - model_arch = ModelArchitectures.from_encoder(encoder) - if model_arch in [ - ModelArchitectures.BERT, - ModelArchitectures.ROBERTA, - ModelArchitectures.ALBERT, - ModelArchitectures.XLM_ROBERTA, - ]: - pooled, unpooled, other = get_output_from_standard_transformer_models( - encoder=encoder, input_ids=input_ids, segment_ids=segment_ids, input_mask=input_mask, - ) - elif model_arch == ModelArchitectures.ELECTRA: - pooled, unpooled, other = get_output_from_electra( - encoder=encoder, input_ids=input_ids, segment_ids=segment_ids, input_mask=input_mask, - ) - elif model_arch in [ - ModelArchitectures.BART, - ModelArchitectures.MBART, - ]: - pooled, unpooled, other = get_output_from_bart_models( - encoder=encoder, input_ids=input_ids, input_mask=input_mask, - ) - else: - raise KeyError(model_arch) - - # Extend later with attention, hidden_acts, etc - if other: - return EncoderOutput(pooled=pooled, unpooled=unpooled, other=other) - else: - return EncoderOutput(pooled=pooled, unpooled=unpooled) - - -def get_output_from_standard_transformer_models(encoder, input_ids, segment_ids, input_mask): - output = encoder(input_ids=input_ids, token_type_ids=segment_ids, attention_mask=input_mask) - pooled, unpooled, other = output[1], output[0], output[2:] - return pooled, unpooled, other - - -def get_output_from_bart_models(encoder, input_ids, input_mask): - # BART and mBART and encoder-decoder architectures. - # As described in the BART paper and implemented in Transformers, - # for single input tasks, the encoder input is the sequence, - # the decode input is 1-shifted sequence, and the resulting - # sentence representation is the final decoder state. - # That's what we use for `unpooled` here. - dec_last, dec_all, enc_last, enc_all = encoder( - input_ids=input_ids, attention_mask=input_mask, output_hidden_states=True, - ) - unpooled = dec_last - - other = (enc_all + dec_all,) - - bsize, slen = input_ids.shape - batch_idx = torch.arange(bsize).to(input_ids.device) - # Get last non-pad index - pooled = unpooled[batch_idx, slen - input_ids.eq(encoder.config.pad_token_id).sum(1) - 1] - return pooled, unpooled, other - - -def get_output_from_electra(encoder, input_ids, segment_ids, input_mask): - output = encoder(input_ids=input_ids, token_type_ids=segment_ids, attention_mask=input_mask) - unpooled = output[0] - pooled = unpooled[:, 0, :] - return pooled, unpooled, output - - def compute_mlm_loss(logits, masked_lm_labels): vocab_size = logits.shape[-1] loss_fct = nn.CrossEntropyLoss() diff --git a/jiant/proj/main/tokenize_and_cache.py b/jiant/proj/main/tokenize_and_cache.py index 51688b735..2fb3bffd6 100644 --- a/jiant/proj/main/tokenize_and_cache.py +++ b/jiant/proj/main/tokenize_and_cache.py @@ -1,15 +1,17 @@ import os -from transformers import AutoConfig, AutoTokenizer +from transformers import AutoConfig +from transformers import AutoTokenizer import jiant.proj.main.preprocessing as preprocessing import jiant.shared.caching as shared_caching -import jiant.shared.model_resolution as model_resolution -import jiant.tasks as tasks import jiant.tasks.evaluate as evaluate -import jiant.utils.zconf as zconf import jiant.utils.python.io as py_io +import jiant.utils.zconf as zconf + +from jiant.proj.main.modeling.primary import JiantTransformersModelFactory from jiant.shared.constants import PHASE +from jiant.tasks.retrieval import create_task_from_config_path @zconf.run_config @@ -145,11 +147,11 @@ def main(args: RunConfiguration): config = AutoConfig.from_pretrained(args.hf_pretrained_model_name_or_path) model_type = config.model_type - task = tasks.create_task_from_config_path(config_path=args.task_config_path, verbose=True) - feat_spec = model_resolution.build_featurization_spec( + task = create_task_from_config_path(config_path=args.task_config_path, verbose=True) + feat_spec = JiantTransformersModelFactory.build_featurization_spec( model_type=model_type, max_seq_length=args.max_seq_length, ) - tokenizer = AutoTokenizer.from_pretrained(args.hf_pretrained_model_name_or_path) + tokenizer = AutoTokenizer.from_pretrained(args.hf_pretrained_model_name_or_path, use_fast=False) if isinstance(args.phases, str): phases = args.phases.split(",") else: diff --git a/jiant/proj/simple/runscript.py b/jiant/proj/simple/runscript.py index e7bace5fd..c46d0630c 100644 --- a/jiant/proj/simple/runscript.py +++ b/jiant/proj/simple/runscript.py @@ -104,7 +104,6 @@ def create_and_write_task_configs(task_name_list, data_dir, task_config_base_pat def run_simple(args: RunConfiguration, with_continue: bool = False): hf_config = AutoConfig.from_pretrained(args.hf_pretrained_model_name_or_path) - model_cache_path = replace_none( args.model_cache_path, default=os.path.join(args.exp_dir, "models") ) diff --git a/jiant/shared/model_resolution.py b/jiant/shared/model_resolution.py index 7ec446137..5a993ddad 100644 --- a/jiant/shared/model_resolution.py +++ b/jiant/shared/model_resolution.py @@ -1,150 +1,39 @@ from dataclasses import dataclass from enum import Enum +from jiant.utils.python.datastructures import BiDict import transformers -from jiant.tasks.core import FeaturizationSpec - class ModelArchitectures(Enum): - BERT = 1 - XLM = 2 - ROBERTA = 3 - ALBERT = 4 - XLM_ROBERTA = 5 - BART = 6 - MBART = 7 - ELECTRA = 8 + BERT = "bert" + XLM = "xlm" + ROBERTA = "roberta" + ALBERT = "albert" + XLM_ROBERTA = "xlm-roberta" + BART = "bart" + MBART = "mbart" + ELECTRA = "electra" + DEBERTAV2 = "deberta-v2" @classmethod def from_model_type(cls, model_type: str): - """Get the model architecture for the provided shortcut name. - - Args: - model_type (str): model shortcut name. - - Returns: - Model architecture associated with the provided shortcut name. - - """ - if model_type.startswith("bert"): - return cls.BERT - elif model_type.startswith("xlm") and not model_type.startswith("xlm-roberta"): - return cls.XLM - elif model_type.startswith("roberta"): - return cls.ROBERTA - elif model_type.startswith("albert"): - return cls.ALBERT - elif model_type == "glove_lstm": - return cls.GLOVE_LSTM - elif model_type.startswith("xlm-roberta"): - return cls.XLM_ROBERTA - elif model_type.startswith("bart"): - return cls.BART - elif model_type.startswith("mbart"): - return cls.MBART - elif model_type.startswith("electra"): - return cls.ELECTRA - else: - raise KeyError(model_type) - - @classmethod - def from_transformers_model(cls, transformers_model): - if isinstance( - transformers_model, transformers.BertPreTrainedModel - ) and transformers_model.__class__.__name__.startswith("Bert"): - return cls.BERT - elif isinstance(transformers_model, transformers.XLMPreTrainedModel): - return cls.XLM - elif isinstance( - transformers_model, transformers.BertPreTrainedModel - ) and transformers_model.__class__.__name__.startswith("Robert"): - return cls.ROBERTA - elif isinstance( - transformers_model, transformers.BertPreTrainedModel - ) and transformers_model.__class__.__name__.startswith("XLMRoberta"): - return cls.XLM_ROBERTA - elif isinstance(transformers_model, transformers.modeling_albert.AlbertPreTrainedModel): - return cls.ALBERT - elif isinstance(transformers_model, transformers.modeling_bart.PretrainedBartModel): - return bart_or_mbart_model_heuristic(model_config=transformers_model.config) - elif isinstance(transformers_model, transformers.modeling_electra.ElectraPreTrainedModel): - return cls.ELECTRA - else: - raise KeyError(str(transformers_model)) + return cls(model_type) - @classmethod - def from_tokenizer_class(cls, tokenizer_class): - if isinstance(tokenizer_class, transformers.BertTokenizer): - return cls.BERT - elif isinstance(tokenizer_class, transformers.XLMTokenizer): - return cls.XLM - elif isinstance(tokenizer_class, transformers.RobertaTokenizer): - return cls.ROBERTA - elif isinstance(tokenizer_class, transformers.XLMRobertaTokenizer): - return cls.XLM_ROBERTA - elif isinstance(tokenizer_class, transformers.AlbertTokenizer): - return cls.ALBERT - elif isinstance(tokenizer_class, transformers.BartTokenizer): - return cls.BART - elif isinstance(tokenizer_class, transformers.MBartTokenizer): - return cls.MBART - elif isinstance(tokenizer_class, transformers.ElectraTokenizer): - return cls.ELECTRA - else: - raise KeyError(str(tokenizer_class)) - - @classmethod - def is_transformers_model_arch(cls, model_arch): - return model_arch in [ - cls.BERT, - cls.XLM, - cls.ROBERTA, - cls.ALBERT, - cls.XLM_ROBERTA, - cls.BART, - cls.MBART, - cls.ELECTRA, - ] - @classmethod - def from_encoder(cls, encoder): - if ( - isinstance(encoder, transformers.BertModel) - and encoder.__class__.__name__ == "BertModel" - ): - return cls.BERT - elif ( - isinstance(encoder, transformers.XLMModel) and encoder.__class__.__name__ == "XLMModel" - ): - return cls.XLM - elif ( - isinstance(encoder, transformers.RobertaModel) - and encoder.__class__.__name__ == "RobertaModel" - ): - return cls.ROBERTA - elif ( - isinstance(encoder, transformers.AlbertModel) - and encoder.__class__.__name__ == "AlbertModel" - ): - return cls.ALBERT - elif ( - isinstance(encoder, transformers.XLMRobertaModel) - and encoder.__class__.__name__ == "XlmRobertaModel" - ): - return cls.XLM_ROBERTA - elif ( - isinstance(encoder, transformers.BartModel) - and encoder.__class__.__name__ == "BartModel" - ): - return bart_or_mbart_model_heuristic(model_config=encoder.config) - elif ( - isinstance(encoder, transformers.ElectraModel) - and encoder.__class__.__name__ == "ElectraModel" - ): - return cls.ELECTRA - else: - raise KeyError(type(encoder)) +TOKENIZER_CLASS_DICT = BiDict( + { + ModelArchitectures.BERT: transformers.BertTokenizer, + ModelArchitectures.XLM: transformers.XLMTokenizer, + ModelArchitectures.ROBERTA: transformers.RobertaTokenizer, + ModelArchitectures.XLM_ROBERTA: transformers.XLMRobertaTokenizer, + ModelArchitectures.ALBERT: transformers.AlbertTokenizer, + ModelArchitectures.BART: transformers.BartTokenizer, + ModelArchitectures.MBART: transformers.MBartTokenizer, + ModelArchitectures.ELECTRA: transformers.ElectraTokenizer, + ModelArchitectures.DEBERTAV2: transformers.DebertaV2Tokenizer, + } +) @dataclass @@ -154,145 +43,6 @@ class ModelClassSpec: model_class: type -def build_featurization_spec(model_type, max_seq_length): - model_arch = ModelArchitectures.from_model_type(model_type) - if model_arch == ModelArchitectures.BERT: - return FeaturizationSpec( - max_seq_length=max_seq_length, - cls_token_at_end=False, - pad_on_left=False, - cls_token_segment_id=0, - pad_token_segment_id=0, - pad_token_id=0, - pad_token_mask_id=0, - sequence_a_segment_id=0, - sequence_b_segment_id=1, - sep_token_extra=False, - ) - elif model_arch == ModelArchitectures.XLM: - return FeaturizationSpec( - max_seq_length=max_seq_length, - cls_token_at_end=False, - pad_on_left=False, - cls_token_segment_id=0, - pad_token_segment_id=0, - pad_token_id=0, - pad_token_mask_id=0, - sequence_a_segment_id=0, - sequence_b_segment_id=0, # RoBERTa has no token_type_ids - sep_token_extra=False, - ) - elif model_arch == ModelArchitectures.ROBERTA: - # RoBERTa is weird - # token 0 = '' which is the cls_token - # token 1 = '' which is the sep_token - # Also two ''s are used between sentences. Yes, not ''. - return FeaturizationSpec( - max_seq_length=max_seq_length, - cls_token_at_end=False, - pad_on_left=False, - cls_token_segment_id=0, - pad_token_segment_id=0, - pad_token_id=1, # Roberta uses pad_token_id = 1 - pad_token_mask_id=0, - sequence_a_segment_id=0, - sequence_b_segment_id=0, # RoBERTa has no token_type_ids - sep_token_extra=True, - ) - elif model_arch == ModelArchitectures.ALBERT: - # - return FeaturizationSpec( - max_seq_length=max_seq_length, - cls_token_at_end=False, # ? - pad_on_left=False, # ok - cls_token_segment_id=0, # ok - pad_token_segment_id=0, # ok - pad_token_id=0, # I think? - pad_token_mask_id=0, # I think? - sequence_a_segment_id=0, # I think? - sequence_b_segment_id=1, # I think? - sep_token_extra=False, - ) - elif model_arch == ModelArchitectures.XLM_ROBERTA: - # XLM-RoBERTa is weird - # token 0 = '' which is the cls_token - # token 1 = '' which is the sep_token - # Also two ''s are used between sentences. Yes, not ''. - return FeaturizationSpec( - max_seq_length=max_seq_length, - cls_token_at_end=False, - pad_on_left=False, - cls_token_segment_id=0, - pad_token_segment_id=0, - pad_token_id=1, # XLM-RoBERTa uses pad_token_id = 1 - pad_token_mask_id=0, - sequence_a_segment_id=0, - sequence_b_segment_id=0, # XLM-RoBERTa has no token_type_ids - sep_token_extra=True, - ) - elif model_arch == ModelArchitectures.BART: - # BART is weird - # token 0 = '' which is the cls_token - # token 1 = '' which is the sep_token - # Also two ''s are used between sentences. Yes, not ''. - return FeaturizationSpec( - max_seq_length=max_seq_length, - cls_token_at_end=False, - pad_on_left=False, - cls_token_segment_id=0, - pad_token_segment_id=0, - pad_token_id=1, # BART uses pad_token_id = 1 - pad_token_mask_id=0, - sequence_a_segment_id=0, - sequence_b_segment_id=0, # BART has no token_type_ids - sep_token_extra=True, - ) - elif model_arch == ModelArchitectures.MBART: - # mBART is weird - # token 0 = '' which is the cls_token - # token 1 = '' which is the sep_token - # Also two ''s are used between sentences. Yes, not ''. - return FeaturizationSpec( - max_seq_length=max_seq_length, - cls_token_at_end=False, - pad_on_left=False, - cls_token_segment_id=0, - pad_token_segment_id=0, - pad_token_id=1, # mBART uses pad_token_id = 1 - pad_token_mask_id=0, - sequence_a_segment_id=0, - sequence_b_segment_id=0, # mBART has no token_type_ids - sep_token_extra=True, - ) - elif model_arch == ModelArchitectures.ELECTRA: - return FeaturizationSpec( - max_seq_length=max_seq_length, - cls_token_at_end=False, - pad_on_left=False, - cls_token_segment_id=0, - pad_token_segment_id=0, - pad_token_id=0, - pad_token_mask_id=0, - sequence_a_segment_id=0, - sequence_b_segment_id=1, - sep_token_extra=False, - ) - else: - raise KeyError(model_arch) - - -TOKENIZER_CLASS_DICT = { - ModelArchitectures.BERT: transformers.BertTokenizer, - ModelArchitectures.XLM: transformers.XLMTokenizer, - ModelArchitectures.ROBERTA: transformers.RobertaTokenizer, - ModelArchitectures.XLM_ROBERTA: transformers.XLMRobertaTokenizer, - ModelArchitectures.ALBERT: transformers.AlbertTokenizer, - ModelArchitectures.BART: transformers.BartTokenizer, - ModelArchitectures.MBART: transformers.MBartTokenizer, - ModelArchitectures.ELECTRA: transformers.ElectraTokenizer, -} - - def resolve_tokenizer_class(model_type): """Get tokenizer class for a given model architecture. @@ -303,7 +53,21 @@ def resolve_tokenizer_class(model_type): Tokenizer associated with the given model. """ - return TOKENIZER_CLASS_DICT[ModelArchitectures.from_model_type(model_type)] + return TOKENIZER_CLASS_DICT[ModelArchitectures(model_type)] + + +def resolve_model_arch_tokenizer(tokenizer): + """Get the model architecture for a given tokenizer. + + Args: + tokenizer + + Returns: + ModelArchitecture + + """ + assert len(TOKENIZER_CLASS_DICT.inverse[tokenizer.__class__]) == 1 + return TOKENIZER_CLASS_DICT.inverse[tokenizer.__class__][0] def resolve_is_lower_case(tokenizer): diff --git a/jiant/tasks/__init__.py b/jiant/tasks/__init__.py index 28093f838..e69de29bb 100644 --- a/jiant/tasks/__init__.py +++ b/jiant/tasks/__init__.py @@ -1,2 +0,0 @@ -from .retrieval import * # noqa: F401,F403 -from .core import BatchMixin, TaskTypes # noqa: F401 diff --git a/jiant/tasks/core.py b/jiant/tasks/core.py index 59471f163..85cc0e0e4 100644 --- a/jiant/tasks/core.py +++ b/jiant/tasks/core.py @@ -85,18 +85,18 @@ def data_row_collate_fn(batch): class TaskTypes(Enum): - CLASSIFICATION = 1 - REGRESSION = 2 - SPAN_COMPARISON_CLASSIFICATION = 3 - MULTIPLE_CHOICE = 4 - SPAN_CHOICE_PROB_TASK = 5 - SQUAD_STYLE_QA = 6 - TAGGING = 7 - MASKED_LANGUAGE_MODELING = 8 - EMBEDDING = 9 - MULTI_LABEL_SPAN_CLASSIFICATION = 10 - SPAN_PREDICTION = 11 - UNDEFINED = -1 + CLASSIFICATION = "classification" + REGRESSION = "regression" + SPAN_COMPARISON_CLASSIFICATION = "span_comparison_classification" + MULTIPLE_CHOICE = "multiple_choice" + SPAN_CHOICE_PROB_TASK = "span_choice_prob_task" + SQUAD_STYLE_QA = "squad_style_qa" + TAGGING = "tagging" + MASKED_LANGUAGE_MODELING = "masked_language_modeling" + EMBEDDING = "embedding" + MULTI_LABEL_SPAN_CLASSIFICATION = "multi_label_span_classification" + SPAN_PREDICTION = "span_prediction" + UNDEFINED = "undefined" class BatchTuple(NamedTuple): diff --git a/jiant/tasks/evaluate/core.py b/jiant/tasks/evaluate/core.py index fbd535692..5e07224eb 100644 --- a/jiant/tasks/evaluate/core.py +++ b/jiant/tasks/evaluate/core.py @@ -1,26 +1,35 @@ import itertools import json + from dataclasses import dataclass import numpy as np import pandas as pd import seqeval.metrics as seqeval_metrics import torch -from sklearn.metrics import f1_score, matthews_corrcoef -from scipy.stats import pearsonr, spearmanr -from typing import Dict, List + +from scipy.stats import pearsonr +from scipy.stats import spearmanr +from sklearn.metrics import f1_score +from sklearn.metrics import matthews_corrcoef +from typing import Dict +from typing import List import jiant.shared.model_resolution as model_resolution -import jiant.tasks as tasks -import jiant.tasks.lib.templates.squad_style.core as squad_style -import jiant.tasks.lib.templates.squad_style.utils as squad_style_utils -import jiant.tasks.lib.mlqa as mlqa_lib + import jiant.tasks.lib.bucc2018 as bucc2018_lib +import jiant.tasks.lib.mlqa as mlqa_lib import jiant.tasks.lib.tatoeba as tatoeba_lib +import jiant.tasks.lib.templates.squad_style.core as squad_style +import jiant.tasks.lib.templates.squad_style.utils as squad_style_utils + +import jiant.tasks.retrieval as tasks_retrieval + from jiant.tasks.lib.templates import mlm as mlm_template from jiant.utils.python.datastructures import ExtendedDataClassMixin from jiant.utils.python.io import read_json -from jiant.utils.string_comparing import string_f1_score, exact_match_score +from jiant.utils.string_comparing import exact_match_score +from jiant.utils.string_comparing import string_f1_score @dataclass @@ -752,7 +761,7 @@ def compute_metrics_from_accumulator( self, task, accumulator: BaseAccumulator, tokenizer, labels ) -> Metrics: logits = accumulator.get_accumulated() - assert isinstance(task, (tasks.TyDiQATask, tasks.XquadTask)) + assert isinstance(task, (tasks_retrieval.TyDiQATask, tasks_retrieval.XquadTask)) lang = task.language results, predictions = squad_style.compute_predictions_logits_v3( data_rows=labels, @@ -954,119 +963,119 @@ def get_evaluation_scheme_for_task(task) -> BaseEvaluationScheme: if isinstance( task, ( - tasks.AdversarialNliTask, - tasks.AbductiveNliTask, - tasks.AcceptabilityDefinitenessTask, - tasks.AcceptabilityCoordTask, - tasks.AcceptabilityEOSTask, - tasks.AcceptabilityWHwordsTask, - tasks.BoolQTask, - tasks.CopaTask, - tasks.FeverNliTask, - tasks.MnliTask, - tasks.PawsXTask, - tasks.QnliTask, - tasks.RaceTask, - tasks.RteTask, - tasks.SciTailTask, - tasks.SentEvalBigramShiftTask, - tasks.SentEvalCoordinationInversionTask, - tasks.SentEvalObjNumberTask, - tasks.SentEvalOddManOutTask, - tasks.SentEvalPastPresentTask, - tasks.SentEvalSentenceLengthTask, - tasks.SentEvalSubjNumberTask, - tasks.SentEvalTopConstituentsTask, - tasks.SentEvalTreeDepthTask, - tasks.SentEvalWordContentTask, - tasks.SnliTask, - tasks.SstTask, - tasks.WiCTask, - tasks.WnliTask, - tasks.WSCTask, - tasks.XnliTask, - tasks.MCScriptTask, - tasks.ArctTask, - tasks.PiqaTask, + tasks_retrieval.AdversarialNliTask, + tasks_retrieval.AbductiveNliTask, + tasks_retrieval.AcceptabilityDefinitenessTask, + tasks_retrieval.AcceptabilityCoordTask, + tasks_retrieval.AcceptabilityEOSTask, + tasks_retrieval.AcceptabilityWHwordsTask, + tasks_retrieval.BoolQTask, + tasks_retrieval.CopaTask, + tasks_retrieval.FeverNliTask, + tasks_retrieval.MnliTask, + tasks_retrieval.PawsXTask, + tasks_retrieval.QnliTask, + tasks_retrieval.RaceTask, + tasks_retrieval.RteTask, + tasks_retrieval.SciTailTask, + tasks_retrieval.SentEvalBigramShiftTask, + tasks_retrieval.SentEvalCoordinationInversionTask, + tasks_retrieval.SentEvalObjNumberTask, + tasks_retrieval.SentEvalOddManOutTask, + tasks_retrieval.SentEvalPastPresentTask, + tasks_retrieval.SentEvalSentenceLengthTask, + tasks_retrieval.SentEvalSubjNumberTask, + tasks_retrieval.SentEvalTopConstituentsTask, + tasks_retrieval.SentEvalTreeDepthTask, + tasks_retrieval.SentEvalWordContentTask, + tasks_retrieval.SnliTask, + tasks_retrieval.SstTask, + tasks_retrieval.WiCTask, + tasks_retrieval.WnliTask, + tasks_retrieval.WSCTask, + tasks_retrieval.XnliTask, + tasks_retrieval.MCScriptTask, + tasks_retrieval.ArctTask, + tasks_retrieval.PiqaTask, ), ): return SimpleAccuracyEvaluationScheme() - elif isinstance(task, tasks.MCTACOTask): + elif isinstance(task, tasks_retrieval.MCTACOTask): return MCTACOEvaluationScheme() - elif isinstance(task, tasks.CCGTask): + elif isinstance(task, tasks_retrieval.CCGTask): return CCGEvaluationScheme() - elif isinstance(task, tasks.CommitmentBankTask): + elif isinstance(task, tasks_retrieval.CommitmentBankTask): return CommitmentBankEvaluationScheme() - elif isinstance(task, tasks.ColaTask): + elif isinstance(task, tasks_retrieval.ColaTask): return MCCEvaluationScheme() elif isinstance( task, ( - tasks.ArcEasyTask, - tasks.ArcChallengeTask, - tasks.CommonsenseQATask, - tasks.CosmosQATask, - tasks.SWAGTask, - tasks.HellaSwagTask, - tasks.MutualTask, - tasks.MutualPlusTask, - tasks.QuailTask, - tasks.SocialIQATask, - tasks.WinograndeTask, - tasks.MCTestTask, + tasks_retrieval.ArcEasyTask, + tasks_retrieval.ArcChallengeTask, + tasks_retrieval.CommonsenseQATask, + tasks_retrieval.CosmosQATask, + tasks_retrieval.SWAGTask, + tasks_retrieval.HellaSwagTask, + tasks_retrieval.MutualTask, + tasks_retrieval.MutualPlusTask, + tasks_retrieval.QuailTask, + tasks_retrieval.SocialIQATask, + tasks_retrieval.WinograndeTask, + tasks_retrieval.MCTestTask, ), ): return MultipleChoiceAccuracyEvaluationScheme() - elif isinstance(task, (tasks.MrpcTask, tasks.QqpTask)): + elif isinstance(task, (tasks_retrieval.MrpcTask, tasks_retrieval.QqpTask)): return AccAndF1EvaluationScheme() elif isinstance( task, ( - tasks.Spr1Task, - tasks.Spr2Task, - tasks.SemevalTask, - tasks.SrlTask, - tasks.NerTask, - tasks.CorefTask, - tasks.DprTask, - tasks.DepTask, - tasks.PosTask, - tasks.NonterminalTask, + tasks_retrieval.Spr1Task, + tasks_retrieval.Spr2Task, + tasks_retrieval.SemevalTask, + tasks_retrieval.SrlTask, + tasks_retrieval.NerTask, + tasks_retrieval.CorefTask, + tasks_retrieval.DprTask, + tasks_retrieval.DepTask, + tasks_retrieval.PosTask, + tasks_retrieval.NonterminalTask, ), ): return MultiLabelAccAndF1EvaluationScheme() - elif isinstance(task, tasks.ReCoRDTask): + elif isinstance(task, tasks_retrieval.ReCoRDTask): return ReCordEvaluationScheme() elif isinstance( task, ( - tasks.SquadTask, - tasks.RopesTask, - tasks.QuorefTask, - tasks.NewsQATask, - tasks.MrqaNaturalQuestionsTask, + tasks_retrieval.SquadTask, + tasks_retrieval.RopesTask, + tasks_retrieval.QuorefTask, + tasks_retrieval.NewsQATask, + tasks_retrieval.MrqaNaturalQuestionsTask, ), ): return SQuADEvaluationScheme() - elif isinstance(task, (tasks.TyDiQATask, tasks.XquadTask)): + elif isinstance(task, (tasks_retrieval.TyDiQATask, tasks_retrieval.XquadTask)): return XlingQAEvaluationScheme() - elif isinstance(task, tasks.MlqaTask): + elif isinstance(task, tasks_retrieval.MlqaTask): return MLQAEvaluationScheme() - elif isinstance(task, tasks.MultiRCTask): + elif isinstance(task, tasks_retrieval.MultiRCTask): return MultiRCEvaluationScheme() - elif isinstance(task, tasks.StsbTask): + elif isinstance(task, tasks_retrieval.StsbTask): return PearsonAndSpearmanEvaluationScheme() - elif isinstance(task, tasks.MLMSimpleTask): + elif isinstance(task, tasks_retrieval.MLMSimpleTask): return MLMEvaluationScheme() - elif isinstance(task, (tasks.MLMPremaskedTask, tasks.MLMPretokenizedTask)): + elif isinstance(task, (tasks_retrieval.MLMPremaskedTask, tasks_retrieval.MLMPretokenizedTask)): return MLMPremaskedEvaluationScheme() - elif isinstance(task, (tasks.QAMRTask, tasks.QASRLTask)): + elif isinstance(task, (tasks_retrieval.QAMRTask, tasks_retrieval.QASRLTask)): return SpanPredictionF1andEMScheme() - elif isinstance(task, (tasks.UdposTask, tasks.PanxTask)): + elif isinstance(task, (tasks_retrieval.UdposTask, tasks_retrieval.PanxTask)): return F1TaggingEvaluationScheme() - elif isinstance(task, tasks.Bucc2018Task): + elif isinstance(task, tasks_retrieval.Bucc2018Task): return Bucc2018EvaluationScheme() - elif isinstance(task, tasks.TatoebaTask): + elif isinstance(task, tasks_retrieval.TatoebaTask): return TatoebaEvaluationScheme() else: raise KeyError(task) diff --git a/jiant/tasks/lib/ccg.py b/jiant/tasks/lib/ccg.py index c468c03e8..6eebf1f02 100644 --- a/jiant/tasks/lib/ccg.py +++ b/jiant/tasks/lib/ccg.py @@ -31,7 +31,7 @@ def tokenize(self, tokenizer): tokenized = tokenizer.tokenize(self.text) split_text = self.text.split(" ") # CCG data is space-tokenized input_flat_stripped = tokenization_utils.input_flat_strip(split_text) - flat_stripped, indices = tokenization_utils.delegate_flat_strip( + flat_stripped, indices = tokenization_utils.flat_strip( tokens=tokenized, tokenizer=tokenizer, return_indices=True, ) assert flat_stripped == input_flat_stripped diff --git a/jiant/tasks/lib/ropes.py b/jiant/tasks/lib/ropes.py index 74c823ad7..8c022bf0e 100644 --- a/jiant/tasks/lib/ropes.py +++ b/jiant/tasks/lib/ropes.py @@ -90,7 +90,9 @@ def to_feature_list( # (This may not apply for future added models that don't start with a CLS token, # such as XLNet/GPT-2) sequence_added_tokens = 1 - sequence_pair_added_tokens = tokenizer.max_len - tokenizer.max_len_sentences_pair + sequence_pair_added_tokens = ( + tokenizer.model_max_length - tokenizer.model_max_length_sentences_pair + ) span_doc_tokens = all_doc_tokens while len(spans) * doc_stride < len(all_doc_tokens): diff --git a/jiant/tasks/lib/rte.py b/jiant/tasks/lib/rte.py index 923c546dc..c62a02c6f 100644 --- a/jiant/tasks/lib/rte.py +++ b/jiant/tasks/lib/rte.py @@ -81,6 +81,10 @@ class RteTask(SuperGlueMixin, GlueMixin, Task): LABELS = ["entailment", "not_entailment"] LABEL_TO_ID, ID_TO_LABEL = labels_to_bimap(LABELS) + @property + def num_labels(self): + return len(self.LABELS) + def get_train_examples(self): return self._create_examples(lines=read_jsonl(self.train_path), set_type="train") diff --git a/jiant/tasks/lib/templates/hacky_tokenization_matching.py b/jiant/tasks/lib/templates/hacky_tokenization_matching.py index 96ed3be64..5ce0f900c 100644 --- a/jiant/tasks/lib/templates/hacky_tokenization_matching.py +++ b/jiant/tasks/lib/templates/hacky_tokenization_matching.py @@ -1,5 +1,4 @@ """TODO: Remove when Tokenizers gets better (issue #1189)""" -import transformers from jiant.tasks.utils import ExclusiveSpan @@ -31,80 +30,8 @@ def input_flat_strip(tokens): return "".join(tokens).lower() -def delegate_flat_strip(tokens, tokenizer, return_indices=False): - if isinstance(tokenizer, transformers.BertTokenizer): - return bert_flat_strip(tokens=tokens, return_indices=return_indices) - elif isinstance(tokenizer, transformers.RobertaTokenizer): - return roberta_flat_strip(tokens=tokens, return_indices=return_indices) - elif isinstance(tokenizer, transformers.AlbertTokenizer): - return albert_flat_strip(tokens=tokens, return_indices=return_indices) - elif isinstance(tokenizer, transformers.XLMRobertaTokenizer): - return xlm_roberta_flat_strip(tokens=tokens, return_indices=return_indices) - else: - raise KeyError(type(tokenizer)) - - -def bert_flat_strip(tokens, return_indices=False): - ls = [] - count = 0 - indices = [] - for token in tokens: - if token.startswith("##"): - token = token.replace("##", "") - else: - pass - ls.append(token) - indices += [count] * len(token) - count += 1 - string = "".join(ls).lower() - if return_indices: - return string, indices - else: - return string - - -def roberta_flat_strip(tokens, return_indices=False): - ls = [] - count = 0 - indices = [] - for token in tokens: - if token.startswith("Ġ"): - token = token.replace("Ġ", "") - else: - pass - ls.append(token) - indices += [count] * len(token) - count += 1 - string = "".join(ls).lower() - if return_indices: - return string, indices - else: - return string - - -def xlm_roberta_flat_strip(tokens, return_indices=False): - # TODO: Refactor to use general SentencePiece function (issue #1181) - return albert_flat_strip(tokens=tokens, return_indices=return_indices) - - -def albert_flat_strip(tokens, return_indices=False): - ls = [] - count = 0 - indices = [] - for token in tokens: - token = token.replace('"', "``") - if token.startswith("▁"): - token = token[1:] - else: - pass - ls.append(token) - indices += [count] * len(token) - count += 1 - string = "".join(ls).lower() - if return_indices: - return string, indices - else: - return string +def flat_strip(tokens, tokenizer, return_indices=False): + return tokenizer.convert_tokens_to_string(tokens).replace(" ", "").lower() def starts_with(ls, prefix): @@ -118,7 +45,7 @@ def get_token_span(sentence, span: ExclusiveSpan, tokenizer): assert starts_with(tokenized, tokenized_start1) # assert starts_with(tokenized, tokenized_start2) # <- fails because of "does" in "doesn't" word = sentence[span.to_slice()] - assert word.lower().replace(" ", "") in delegate_flat_strip( + assert word.lower().replace(" ", "") in flat_strip( tokenized_start2[len(tokenized_start1) :], tokenizer=tokenizer, ) token_span = ExclusiveSpan(start=len(tokenized_start1), end=len(tokenized_start2)) diff --git a/jiant/tasks/lib/templates/squad_style/core.py b/jiant/tasks/lib/templates/squad_style/core.py index b19e216f1..9231643a5 100644 --- a/jiant/tasks/lib/templates/squad_style/core.py +++ b/jiant/tasks/lib/templates/squad_style/core.py @@ -5,7 +5,9 @@ from dataclasses import dataclass from typing import Union, List, Dict, Optional -from transformers.tokenization_bert import whitespace_tokenize +from transformers.models.bert.tokenization_bert import whitespace_tokenize +from transformers.tokenization_utils_base import TruncationStrategy + from jiant.tasks.lib.templates.squad_style import utils as squad_utils from jiant.shared.constants import PHASE @@ -144,20 +146,30 @@ def to_feature_list( # in the way they compute mask of added tokens. tokenizer_type = type(tokenizer).__name__.replace("Tokenizer", "").lower() sequence_added_tokens = ( - tokenizer.max_len - tokenizer.max_len_single_sentence + 1 + tokenizer.model_max_length - tokenizer.max_len_single_sentence + 1 if tokenizer_type in MULTI_SEP_TOKENS_TOKENIZERS_SET - else tokenizer.max_len - tokenizer.max_len_single_sentence + else tokenizer.model_max_length - tokenizer.max_len_single_sentence ) - sequence_pair_added_tokens = tokenizer.max_len - tokenizer.max_len_sentences_pair + sequence_pair_added_tokens = tokenizer.model_max_length - tokenizer.max_len_sentences_pair span_doc_tokens = all_doc_tokens while len(spans) * doc_stride < len(all_doc_tokens): + # Define the side we want to truncate / pad and the text/pair sorting + if tokenizer.padding_side == "right": + texts = truncated_query + pairs = span_doc_tokens + truncation = TruncationStrategy.ONLY_SECOND.value + else: + texts = span_doc_tokens + pairs = truncated_query + truncation = TruncationStrategy.ONLY_FIRST.value + encoded_dict = tokenizer.encode_plus( # TODO(thom) update this logic - truncated_query if tokenizer.padding_side == "right" else span_doc_tokens, - span_doc_tokens if tokenizer.padding_side == "right" else truncated_query, - truncation="only_second" if tokenizer.padding_side == "right" else "only_first", - pad_to_max_length=True, + texts, + pairs, + truncation=truncation, + padding="max_length", max_length=max_seq_length, return_overflowing_tokens=True, stride=max_seq_length @@ -234,9 +246,9 @@ def to_feature_list( # Identify the position of the CLS token cls_index = span["input_ids"].index(tokenizer.cls_token_id) - # p_mask: mask with 1 for token than cannot be in the answer - # (0 for token which can be in an answer) - # Original TF implem also keep the classification token (set to 0) (not sure why...) + # p_mask: mask with 1 for token than cannot be in the answer (0 for token + # which can be in an answer) + # Original TF implementation also keep the classification token (set to 0) p_mask = np.ones_like(span["token_type_ids"]) if tokenizer.padding_side == "right": p_mask[len(truncated_query) + sequence_added_tokens :] = 0 diff --git a/jiant/tasks/lib/templates/squad_style/utils.py b/jiant/tasks/lib/templates/squad_style/utils.py index bc84fc9eb..9cd35daa0 100644 --- a/jiant/tasks/lib/templates/squad_style/utils.py +++ b/jiant/tasks/lib/templates/squad_style/utils.py @@ -6,7 +6,7 @@ from dataclasses import dataclass from typing import List, Dict -from transformers.tokenization_bert import BasicTokenizer +from transformers.models.bert.tokenization_bert import BasicTokenizer from jiant.utils.display import maybe_tqdm diff --git a/jiant/utils/python/datastructures.py b/jiant/utils/python/datastructures.py index 8707261ab..eefc6e902 100644 --- a/jiant/utils/python/datastructures.py +++ b/jiant/utils/python/datastructures.py @@ -279,3 +279,44 @@ def get_maps(self) -> Tuple[Dict, Dict]: """ return self.a_to_b, self.b_to_a + + +class BiDict(dict): + """Maintains bidirectional dict + + Example: + bd = BiDict({'a': 1, 'b': 2}) + print(bd) # {'a': 1, 'b': 2} + print(bd.inverse) # {1: ['a'], 2: ['b']} + bd['c'] = 1 # Now two keys have the same value (= 1) + print(bd) # {'a': 1, 'c': 1, 'b': 2} + print(bd.inverse) # {1: ['a', 'c'], 2: ['b']} + del bd['c'] + print(bd) # {'a': 1, 'b': 2} + print(bd.inverse) # {1: ['a'], 2: ['b']} + del bd['a'] + print(bd) # {'b': 2} + print(bd.inverse) # {2: ['b']} + bd['b'] = 3 + print(bd) # {'b': 3} + print(bd.inverse) # {2: [], 3: ['b']} + + """ + + def __init__(self, *args, **kwargs): + super(BiDict, self).__init__(*args, **kwargs) + self.inverse = {} + for key, value in self.items(): + self.inverse.setdefault(value, []).append(key) + + def __setitem__(self, key, value): + if key in self: + self.inverse[self[key]].remove(key) + super(BiDict, self).__setitem__(key, value) + self.inverse.setdefault(value, []).append(key) + + def __delitem__(self, key): + self.inverse.setdefault(self[key], []).remove(key) + if self[key] in self.inverse and not self.inverse[self[key]]: + del self.inverse[self[key]] + super(BiDict, self).__delitem__(key) diff --git a/jiant/utils/tokenization_normalization.py b/jiant/utils/tokenization_normalization.py index 8a07be2c9..95c59ac5e 100644 --- a/jiant/utils/tokenization_normalization.py +++ b/jiant/utils/tokenization_normalization.py @@ -8,11 +8,12 @@ """ -import re import transformers from typing import Sequence from jiant.utils.testing import utils as test_utils +from jiant.shared.model_resolution import resolve_model_arch_tokenizer +from jiant.proj.main.modeling.primary import JiantTransformersModelFactory def normalize_tokenizations( @@ -49,80 +50,24 @@ def normalize_tokenizations( if len(space_tokenization) == 0 or len(target_tokenization) == 0: raise ValueError("Empty token sequence.") - if isinstance(tokenizer, transformers.BertTokenizer): - if tokenizer.init_kwargs.get("do_lower_case", False): - space_tokenization = [token.lower() for token in space_tokenization] - modifed_space_tokenization = bow_tag_tokens(space_tokenization) - modifed_target_tokenization = _process_wordpiece_tokens(target_tokenization) - elif isinstance(tokenizer, transformers.XLMTokenizer): - if tokenizer.init_kwargs.get("do_lowercase_and_remove_accent", False): - space_tokenization = [token.lower() for token in space_tokenization] - modifed_space_tokenization = eow_tag_tokens(space_tokenization) - modifed_target_tokenization = target_tokenization - elif isinstance(tokenizer, transformers.RobertaTokenizer): - modifed_space_tokenization = bow_tag_tokens(space_tokenization) - modifed_target_tokenization = ["Ġ" + target_tokenization[0]] + target_tokenization[1:] - modifed_target_tokenization = _process_bytebpe_tokens(modifed_target_tokenization) - elif isinstance(tokenizer, (transformers.AlbertTokenizer, transformers.XLMRobertaTokenizer)): - space_tokenization = [token.lower() for token in space_tokenization] - modifed_space_tokenization = bow_tag_tokens(space_tokenization) - modifed_target_tokenization = _process_sentencepiece_tokens(target_tokenization) - else: - if test_utils.is_pytest(): - from jiant.utils.testing.tokenizer import SimpleSpaceTokenizer - - if isinstance(tokenizer, SimpleSpaceTokenizer): - return space_tokenization, target_tokenization - raise ValueError("Tokenizer not supported.") + if test_utils.is_pytest(): + from jiant.utils.testing.tokenizer import SimpleSpaceTokenizer + + if isinstance(tokenizer, SimpleSpaceTokenizer): + return space_tokenization, target_tokenization + + model_arch = resolve_model_arch_tokenizer(tokenizer) + print(model_arch) + jiant_transformer_model_class = JiantTransformersModelFactory.get_registry()[model_arch] + ( + modifed_space_tokenization, + modifed_target_tokenization, + ) = jiant_transformer_model_class.normalize_tokenizations( + tokenizer, space_tokenization, target_tokenization + ) # safety check: if normalization changed sequence length, alignment is likely to break. assert len(modifed_space_tokenization) == len(space_tokenization) assert len(modifed_target_tokenization) == len(target_tokenization) return modifed_space_tokenization, modifed_target_tokenization - - -def bow_tag_tokens(tokens: Sequence[str], bow_tag: str = ""): - """Applies a beginning of word (BoW) marker to every token in the tokens sequence.""" - return [bow_tag + t for t in tokens] - - -def eow_tag_tokens(tokens: Sequence[str], eow_tag: str = ""): - """Applies a end of word (EoW) marker to every token in the tokens sequence.""" - return [t + eow_tag for t in tokens] - - -def _process_wordpiece_tokens(tokens: Sequence[str]): - return [_process_wordpiece_token_for_alignment(token) for token in tokens] - - -def _process_sentencepiece_tokens(tokens: Sequence[str]): - return [_process_sentencepiece_token_for_alignment(token) for token in tokens] - - -def _process_bytebpe_tokens(tokens: Sequence[str]): - return [_process_bytebpe_token_for_alignment(token) for token in tokens] - - -def _process_wordpiece_token_for_alignment(t): - """Add word boundary markers, removes token prefix (no-space meta-symbol — '##' for BERT).""" - if t.startswith("##"): - return re.sub(r"^##", "", t) - else: - return "" + t - - -def _process_sentencepiece_token_for_alignment(t): - """Add word boundary markers, removes token prefix (space meta-symbol).""" - if t.startswith("▁"): - return "" + re.sub(r"^▁", "", t) - else: - return t - - -def _process_bytebpe_token_for_alignment(t): - """Add word boundary markers, removes token prefix (space meta-symbol).""" - if t.startswith("Ġ"): - return "" + re.sub(r"^Ġ", "", t) - else: - return t diff --git a/jiant/utils/tokenization_utils.py b/jiant/utils/tokenization_utils.py new file mode 100644 index 000000000..7e250348d --- /dev/null +++ b/jiant/utils/tokenization_utils.py @@ -0,0 +1,49 @@ +import re + +from typing import Sequence + + +def bow_tag_tokens(tokens: Sequence[str], bow_tag: str = ""): + """Applies a beginning of word (BoW) marker to every token in the tokens sequence.""" + return [bow_tag + t for t in tokens] + + +def eow_tag_tokens(tokens: Sequence[str], eow_tag: str = ""): + """Applies a end of word (EoW) marker to every token in the tokens sequence.""" + return [t + eow_tag for t in tokens] + + +def process_wordpiece_tokens(tokens: Sequence[str]): + return [process_wordpiece_token_for_alignment(token) for token in tokens] + + +def process_sentencepiece_tokens(tokens: Sequence[str]): + return [process_sentencepiece_token_for_alignment(token) for token in tokens] + + +def process_bytebpe_tokens(tokens: Sequence[str]): + return [process_bytebpe_token_for_alignment(token) for token in tokens] + + +def process_wordpiece_token_for_alignment(t): + """Add word boundary markers, removes token prefix (no-space meta-symbol — '##' for BERT).""" + if t.startswith("##"): + return re.sub(r"^##", "", t) + else: + return "" + t + + +def process_sentencepiece_token_for_alignment(t): + """Add word boundary markers, removes token prefix (space meta-symbol).""" + if t.startswith("▁"): + return "" + re.sub(r"^▁", "", t) + else: + return t + + +def process_bytebpe_token_for_alignment(t): + """Add word boundary markers, removes token prefix (space meta-symbol).""" + if t.startswith("Ġ"): + return "" + re.sub(r"^Ġ", "", t) + else: + return t diff --git a/jiant/utils/transformer_utils.py b/jiant/utils/transformer_utils.py deleted file mode 100644 index 844b6c368..000000000 --- a/jiant/utils/transformer_utils.py +++ /dev/null @@ -1,32 +0,0 @@ -import contextlib - -from jiant.shared.model_resolution import ModelArchitectures - - -@contextlib.contextmanager -def output_hidden_states_context(encoder): - model_arch = ModelArchitectures.from_encoder(encoder) - if model_arch in ( - ModelArchitectures.BERT, - ModelArchitectures.ROBERTA, - ModelArchitectures.ALBERT, - ModelArchitectures.XLM_ROBERTA, - ModelArchitectures.ELECTRA, - ): - if hasattr(encoder.encoder, "output_hidden_states"): - # Transformers < v2 - modified_obj = encoder.encoder - elif hasattr(encoder.encoder.config, "output_hidden_states"): - # Transformers >= v3 - modified_obj = encoder.encoder.config - else: - raise RuntimeError(f"Failed to convert model {type(encoder)} to output hidden states") - old_value = modified_obj.output_hidden_states - modified_obj.output_hidden_states = True - yield - modified_obj.output_hidden_states = old_value - elif model_arch in (ModelArchitectures.BART, ModelArchitectures.MBART): - yield - return - else: - raise KeyError(model_arch) diff --git a/pyproject.toml b/pyproject.toml index 79f3550b9..cc81630b4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,11 @@ [tool.black] line-length = 100 +[tool.pytest.ini_options] +filterwarnings = [ + "ignore::UserWarning", +] + include = '\.pyi?$' exclude = ''' diff --git a/requirements-no-torch.txt b/requirements-no-torch.txt index 0b45593c0..8143a32ac 100644 --- a/requirements-no-torch.txt +++ b/requirements-no-torch.txt @@ -12,7 +12,7 @@ sacremoses==0.0.43 seqeval==0.0.12 scikit-learn==0.22.2.post1 scipy==1.4.1 -sentencepiece==0.1.86 -tokenizers==0.8.1.rc2 +sentencepiece==0.1.91 +tokenizers==0.10.1 tqdm==4.46.0 -transformers==3.1.0 +transformers==4.5.0 diff --git a/setup.py b/setup.py index 1e426a3ee..baee1405c 100644 --- a/setup.py +++ b/setup.py @@ -71,11 +71,11 @@ "seqeval == 0.0.12", "scikit-learn == 0.22.2.post1", "scipy == 1.4.1", - "sentencepiece == 0.1.86", - "tokenizers == 0.8.1.rc2", - "torch >= 1.8.1", + "sentencepiece == 0.1.91", + "tokenizers == 0.10.1", "tqdm == 4.46.0", - "transformers == 3.1.0", + "transformers == 4.5.0", + "torch >= 1.8.1", "torchvision == 0.9.1", ], extras_require=extras, diff --git a/tests/proj/main/test_export_model.py b/tests/proj/main/test_export_model.py index dc259c407..14255f0df 100644 --- a/tests/proj/main/test_export_model.py +++ b/tests/proj/main/test_export_model.py @@ -1,8 +1,14 @@ import os import pytest -from transformers import BertPreTrainedModel, BertTokenizer, RobertaForMaskedLM, RobertaTokenizer + +from transformers import BertPreTrainedModel +from transformers import BertTokenizer +from transformers import DebertaV2ForMaskedLM +from transformers import RobertaForMaskedLM +from transformers import RobertaTokenizer import jiant.utils.python.io as py_io + from jiant.proj.main.export_model import export_model @@ -22,3 +28,19 @@ def test_export_model(tmp_path, model_type, model_class, hf_pretrained_model_nam assert read_config["hf_pretrained_model_name_or_path"] == hf_pretrained_model_name_or_path assert read_config["model_path"] == os.path.join(tmp_path, "model", "model.p") assert read_config["model_config_path"] == os.path.join(tmp_path, "model", "config.json") + + +@pytest.mark.slow +@pytest.mark.parametrize( + "model_type, model_class, hf_pretrained_model_name_or_path", + [("deberta-v2-xlarge", DebertaV2ForMaskedLM, "microsoft/deberta-v2-xlarge",), ], +) +def test_export_model_large(tmp_path, model_type, model_class, hf_pretrained_model_name_or_path): + export_model( + hf_pretrained_model_name_or_path=hf_pretrained_model_name_or_path, + output_base_path=tmp_path, + ) + read_config = py_io.read_json(os.path.join(tmp_path, f"config.json")) + assert read_config["hf_pretrained_model_name_or_path"] == hf_pretrained_model_name_or_path + assert read_config["model_path"] == os.path.join(tmp_path, "model", "model.p") + assert read_config["model_config_path"] == os.path.join(tmp_path, "model", "config.json") diff --git a/tests/proj/simple/test_runscript.py b/tests/proj/simple/test_runscript.py index 92ee288db..ed048e5f6 100644 --- a/tests/proj/simple/test_runscript.py +++ b/tests/proj/simple/test_runscript.py @@ -8,15 +8,30 @@ import jiant.scripts.download_data.runscript as downloader import jiant.utils.torch_utils as torch_utils -EXPECTED_AGG_VAL_METRICS = {"bert-base-cased": {"rte": 0.5740072202166066, "commonsenseqa": 0.4258804258804259, "squad_v1": 29.071789929086883}, - "roberta-base": {"rte": 0.49458483754512633, "commonsenseqa": 0.23013923013923013, "squad_v1": 48.222444172918955}, - "xlm-roberta-base": {"rte": 0.4729241877256318, "commonsenseqa": 0.22686322686322685, "squad_v1": 10.30104037978786}} - - +EXPECTED_AGG_VAL_METRICS = { + "bert-base-cased": { + "rte": 0.5956678700361011, + "commonsenseqa": 0.5176085176085176, + "squad_v1": 54.045103183650156, + }, + "roberta-base": { + "rte": 0.6967509025270758, + "commonsenseqa": 0.44963144963144963, + "squad_v1": 68.66217365509084, + }, + "xlm-roberta-base": { + "rte": 0.5956678700361011, + "commonsenseqa": 0.24242424242424243, + "squad_v1": 42.86723254466678, + }, +} + + +@pytest.mark.slow @pytest.mark.parametrize("task_name", ["copa"]) -@pytest.mark.parametrize("model_type", ["bert-base-cased"]) -def test_simple_runscript(tmpdir, task_name, model_type): - RUN_NAME = f"{test_simple_runscript.__name__}_{task_name}_{model_type}" +@pytest.mark.parametrize("model_type", ["bert-base-uncased", "microsoft/deberta-v2-xlarge"]) +def test_simple_runscript_sanity(tmpdir, task_name, model_type): + RUN_NAME = f"{test_simple_runscript_sanity.__name__}_{task_name}_{model_type.replace('/','_')}" data_dir = str(tmpdir.mkdir("data")) exp_dir = str(tmpdir.mkdir("exp")) @@ -38,8 +53,12 @@ def test_simple_runscript(tmpdir, task_name, model_type): assert val_metrics["aggregated"] > 0 +@pytest.mark.gpu @pytest.mark.overnight -@pytest.mark.parametrize(("task_name", "train_examples_cap"), [("rte", 1024), ("commonsenseqa", 1024), ("squad_v1", 2048)]) +@pytest.mark.parametrize( + ("task_name", "train_examples_cap"), + [("rte", 4096), ("commonsenseqa", 4096), ("squad_v1", 4096)], +) @pytest.mark.parametrize("model_type", ["bert-base-cased", "roberta-base", "xlm-roberta-base"]) def test_simple_runscript(tmpdir, task_name, train_examples_cap, model_type): RUN_NAME = f"{test_simple_runscript.__name__}_{task_name}_{model_type}" @@ -63,7 +82,10 @@ def test_simple_runscript(tmpdir, task_name, train_examples_cap, model_type): run.run_simple(args) val_metrics = py_io.read_json(os.path.join(exp_dir, "runs", RUN_NAME, "val_metrics.json")) - assert math.isclose(val_metrics["aggregated"], EXPECTED_AGG_VAL_METRICS[model_type][task_name]) + assert ( + math.isclose(val_metrics["aggregated"], EXPECTED_AGG_VAL_METRICS[model_type][task_name]) + or val_metrics["aggregated"] >= EXPECTED_AGG_VAL_METRICS[model_type][task_name] + ) torch.use_deterministic_algorithms(False) @@ -83,12 +105,12 @@ def test_simple_runscript_save(tmpdir, task_name, model_type): data_dir=data_dir, hf_pretrained_model_name_or_path=model_type, tasks=task_name, - max_steps=1, + train_examples_cap=64, train_batch_size=32, do_save=True, - eval_every_steps=10, + eval_every_steps=1, learning_rate=0.01, - num_train_epochs=5, + num_train_epochs=2, ) run.run_simple(args) @@ -114,7 +136,7 @@ def test_simple_runscript_save(tmpdir, task_name, model_type): data_dir=data_dir, hf_pretrained_model_name_or_path=model_type, tasks=task_name, - max_steps=1, + train_examples_cap=32, train_batch_size=16, do_save_best=True, ) @@ -134,7 +156,7 @@ def test_simple_runscript_save(tmpdir, task_name, model_type): data_dir=data_dir, hf_pretrained_model_name_or_path=model_type, tasks=task_name, - max_steps=1, + train_examples_cap=32, train_batch_size=16, do_save_last=True, ) diff --git a/tests/tasks/lib/templates/test_hacky_tokenization_matching.py b/tests/tasks/lib/templates/test_hacky_tokenization_matching.py new file mode 100644 index 000000000..6dd70bd65 --- /dev/null +++ b/tests/tasks/lib/templates/test_hacky_tokenization_matching.py @@ -0,0 +1,24 @@ +import pytest + +from transformers import RobertaTokenizer, BertTokenizer + +from jiant.tasks.lib.templates.hacky_tokenization_matching import flat_strip + +from jiant.utils.testing.tokenizer import SimpleSpaceTokenizer +import jiant.shared.model_resolution as model_resolution + + +TEST_STRINGS = ["Hi, my name is Bob Roberts."] +FLAT_STRIP_EXPECTED_STRINGS = ["hi,mynameisbobroberts."] + + +@pytest.mark.parametrize("model_type", ["albert-base-v2", "roberta-base", "bert-base-uncased"]) +def test_delegate_flat_strip(model_type): + tokenizer = model_resolution.resolve_tokenizer_class(model_type.split("-")[0]).from_pretrained( + model_type + ) + for test_string, target_string in zip(TEST_STRINGS, FLAT_STRIP_EXPECTED_STRINGS): + flat_strip_result = flat_strip( + tokenizer.tokenize(test_string), tokenizer, return_indices=False + ) + assert flat_strip_result == target_string diff --git a/tests/tasks/lib/test_mlm_premasked.py b/tests/tasks/lib/test_mlm_premasked.py index 17a5dd0fe..a270dd095 100644 --- a/tests/tasks/lib/test_mlm_premasked.py +++ b/tests/tasks/lib/test_mlm_premasked.py @@ -1,11 +1,13 @@ import transformers import jiant.shared.model_resolution as model_resolution -import jiant.tasks as tasks + +from jiant.proj.main.modeling.primary import JiantTransformersModelFactory +from jiant.tasks.retrieval import MLMPremaskedTask def test_tokenization_and_featurization(): - task = tasks.MLMPremaskedTask(name="mlm_premasked", path_dict={}) + task = MLMPremaskedTask(name="mlm_premasked", path_dict={}) tokenizer = transformers.RobertaTokenizer.from_pretrained("roberta-base") example = task.Example(guid=None, text="Hi, my name is Bob Roberts.", masked_spans=[[15, 18]],) tokenized_example = example.tokenize(tokenizer=tokenizer) @@ -34,8 +36,8 @@ def test_tokenization_and_featurization(): data_row = tokenized_example.featurize( tokenizer=tokenizer, - feat_spec=model_resolution.build_featurization_spec( - model_type="roberta-base", max_seq_length=16, + feat_spec=JiantTransformersModelFactory.build_featurization_spec( + model_type="roberta", max_seq_length=16, ), ) assert list(data_row.masked_input_ids) == [ diff --git a/tests/tasks/lib/test_mlm_pretokenized.py b/tests/tasks/lib/test_mlm_pretokenized.py index 2b682f234..86d0c81dc 100644 --- a/tests/tasks/lib/test_mlm_pretokenized.py +++ b/tests/tasks/lib/test_mlm_pretokenized.py @@ -1,11 +1,14 @@ import transformers import jiant.shared.model_resolution as model_resolution -import jiant.tasks as tasks +from jiant.tasks.retrieval import MLMPretokenizedTask + +from jiant.shared.model_resolution import ModelArchitectures +from jiant.proj.main.modeling.primary import JiantTransformersModelFactory def test_tokenization_and_featurization(): - task = tasks.MLMPretokenizedTask(name="mlm_pretokenized", path_dict={}) + task = MLMPretokenizedTask(name="mlm_pretokenized", path_dict={}) tokenizer = transformers.RobertaTokenizer.from_pretrained("roberta-base") example = task.Example( guid=None, @@ -36,8 +39,8 @@ def test_tokenization_and_featurization(): data_row = tokenized_example.featurize( tokenizer=tokenizer, - feat_spec=model_resolution.build_featurization_spec( - model_type="roberta-base", max_seq_length=16, + feat_spec=JiantTransformersModelFactory.build_featurization_spec( + model_type=ModelArchitectures.ROBERTA.value, max_seq_length=16, ), ) assert list(data_row.masked_input_ids) == [ diff --git a/tests/tasks/lib/test_mnli.py b/tests/tasks/lib/test_mnli.py index 2b564e41d..a4f2438a0 100644 --- a/tests/tasks/lib/test_mnli.py +++ b/tests/tasks/lib/test_mnli.py @@ -1,10 +1,13 @@ +import numpy as np import os + from collections import Counter -import numpy as np from jiant.shared import model_resolution -from jiant.tasks import create_task_from_config_path +from jiant.shared.model_resolution import ModelArchitectures +from jiant.tasks.retrieval import create_task_from_config_path from jiant.utils.testing.tokenizer import SimpleSpaceTokenizer +from jiant.proj.main.modeling.primary import JiantTransformersModelFactory TRAIN_EXAMPLES = [ @@ -300,8 +303,8 @@ def test_featurization_of_task_data(): train_example_0_length = len(tokenized_examples[0].premise) + len( tokenized_examples[0].hypothesis ) - feat_spec = model_resolution.build_featurization_spec( - model_type="bert-", max_seq_length=train_example_0_length + feat_spec = JiantTransformersModelFactory.build_featurization_spec( + model_type=ModelArchitectures.BERT.value, max_seq_length=train_example_0_length ) featurized_examples = [ tokenized_example.featurize(tokenizer=tokenizer, feat_spec=feat_spec) diff --git a/tests/tasks/lib/test_spr1.py b/tests/tasks/lib/test_spr1.py index fb8aefb2a..13e7ed331 100644 --- a/tests/tasks/lib/test_spr1.py +++ b/tests/tasks/lib/test_spr1.py @@ -6,8 +6,9 @@ from unittest.mock import Mock from jiant.shared import model_resolution -from jiant.tasks import create_task_from_config_path +from jiant.tasks.retrieval import create_task_from_config_path from jiant.utils.testing.tokenizer import SimpleSpaceTokenizer +from jiant.proj.main.modeling.primary import JiantTransformersModelFactory TRAIN_EXAMPLES = [ @@ -314,8 +315,8 @@ def test_featurization_of_task_data(): # Testing conversion of a tokenized example to a featurized example train_example_0_length = len(tokenized_examples[0].tokens) + 4 - feat_spec = model_resolution.build_featurization_spec( - model_type="bert-", max_seq_length=train_example_0_length + feat_spec = JiantTransformersModelFactory.build_featurization_spec( + model_type="bert", max_seq_length=train_example_0_length ) featurized_examples = [ tokenized_example.featurize(tokenizer=tokenizer, feat_spec=feat_spec) diff --git a/tests/tasks/lib/test_sst.py b/tests/tasks/lib/test_sst.py index 3a76f28f1..afaa20eb8 100644 --- a/tests/tasks/lib/test_sst.py +++ b/tests/tasks/lib/test_sst.py @@ -3,7 +3,7 @@ import numpy as np -from jiant.tasks import create_task_from_config_path +from jiant.tasks.retrieval import create_task_from_config_path from jiant.utils.testing.tokenizer import SimpleSpaceTokenizer diff --git a/tests/utils/test_tokenization_normalization.py b/tests/utils/test_tokenization_normalization.py index c07ae44bf..538a7fac5 100644 --- a/tests/utils/test_tokenization_normalization.py +++ b/tests/utils/test_tokenization_normalization.py @@ -1,8 +1,13 @@ import pytest import jiant.utils.tokenization_normalization as tn +import jiant.utils.tokenization_utils as tu -from transformers import BertTokenizer, XLMTokenizer, RobertaTokenizer, AlbertTokenizer +from transformers import AlbertTokenizer +from transformers import BertTokenizer +from transformers import DebertaV2Tokenizer +from transformers import RobertaTokenizer +from transformers import XLMTokenizer def test_process_wordpiece_token_sequence(): @@ -52,7 +57,7 @@ def test_process_wordpiece_token_sequence(): "rules", ".", ] - adjusted_wordpiece_tokens = tn._process_wordpiece_tokens(original_wordpiece_tokens) + adjusted_wordpiece_tokens = tu.process_wordpiece_tokens(original_wordpiece_tokens) assert adjusted_wordpiece_tokens == expected_adjusted_wordpiece_tokens @@ -103,7 +108,7 @@ def test_process_sentencepiece_token_sequence(): "▁rules", ".", ] - adjusted_sentencepiece_tokens = tn._process_sentencepiece_tokens(original_sentencepiece_tokens) + adjusted_sentencepiece_tokens = tu.process_sentencepiece_tokens(original_sentencepiece_tokens) assert adjusted_sentencepiece_tokens == expected_adjusted_sentencepiece_tokens @@ -144,7 +149,7 @@ def test_process_bytebpe_token_sequence(): "Ġrules", ".", ] - adjusted_bytebpe_tokens = tn._process_bytebpe_tokens(original_bytebpe_tokens) + adjusted_bytebpe_tokens = tu.process_bytebpe_tokens(original_bytebpe_tokens) assert adjusted_bytebpe_tokens == expected_adjusted_bytebpe_tokens @@ -155,58 +160,21 @@ def test_process_bytebpe_token_sequence(): @pytest.mark.slow -def test_space_tokenization_and_bert_uncased_tokenization_normalization(): +@pytest.mark.parametrize( + "hf_tokenizer, hf_model", + [ + (BertTokenizer, "bert-base-uncased"), + (BertTokenizer, "bert-base-cased"), + (XLMTokenizer, "xlm-mlm-en-2048"), + (RobertaTokenizer, "roberta-base"), + (AlbertTokenizer, "albert-base-v1"), + (DebertaV2Tokenizer, "microsoft/deberta-v2-xlarge"), + ], +) +def test_space_tokenization_tokenization_normalization(hf_tokenizer, hf_model): text = "Jeff Immelt chose to focus on the incomprehensibility of accounting rules ." space_tokenized = text.split(" ") - tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") - target_tokenized = tokenizer.tokenize(text) - normed_space_tokenized, normed_target_tokenized = tn.normalize_tokenizations( - space_tokenized, target_tokenized, tokenizer - ) - assert "".join(normed_space_tokenized) == "".join(normed_target_tokenized) - - -@pytest.mark.slow -def test_space_tokenization_and_bert_cased_tokenization_normalization(): - text = "Jeff Immelt chose to focus on the incomprehensibility of accounting rules ." - space_tokenized = text.split(" ") - tokenizer = BertTokenizer.from_pretrained("bert-base-cased") - target_tokenized = tokenizer.tokenize(text) - normed_space_tokenized, normed_target_tokenized = tn.normalize_tokenizations( - space_tokenized, target_tokenized, tokenizer - ) - assert "".join(normed_space_tokenized) == "".join(normed_target_tokenized) - - -@pytest.mark.slow -def test_space_tokenization_and_xlm_uncased_tokenization_normalization(): - text = "Jeff Immelt chose to focus on the incomprehensibility of accounting rules ." - space_tokenized = text.split(" ") - tokenizer = XLMTokenizer.from_pretrained("xlm-mlm-en-2048") - target_tokenized = tokenizer.tokenize(text) - normed_space_tokenized, normed_target_tokenized = tn.normalize_tokenizations( - space_tokenized, target_tokenized, tokenizer - ) - assert "".join(normed_space_tokenized) == "".join(normed_target_tokenized) - - -@pytest.mark.slow -def test_space_tokenization_and_roberta_tokenization_normalization(): - text = "Jeff Immelt chose to focus on the incomprehensibility of accounting rules ." - space_tokenized = text.split(" ") - tokenizer = RobertaTokenizer.from_pretrained("roberta-base") - target_tokenized = tokenizer.tokenize(text) - normed_space_tokenized, normed_target_tokenized = tn.normalize_tokenizations( - space_tokenized, target_tokenized, tokenizer - ) - assert "".join(normed_space_tokenized) == "".join(normed_target_tokenized) - - -@pytest.mark.slow -def test_space_tokenization_and_albert_tokenization_normalization(): - text = "Jeff Immelt chose to focus on the incomprehensibility of accounting rules ." - space_tokenized = text.split(" ") - tokenizer = AlbertTokenizer.from_pretrained("albert-base-v1") + tokenizer = hf_tokenizer.from_pretrained(hf_model) target_tokenized = tokenizer.tokenize(text) normed_space_tokenized, normed_target_tokenized = tn.normalize_tokenizations( space_tokenized, target_tokenized, tokenizer