diff --git a/examples/README.md b/examples/README.md index 99ec4f118731..249f5134acea 100644 --- a/examples/README.md +++ b/examples/README.md @@ -37,7 +37,7 @@ git checkout tags/v3.4.0 |---|---|:---:|:---:|:---:|:---:| | [**`language-modeling`**](https://github.com/huggingface/transformers/tree/master/examples/language-modeling) | Raw text | ✅ | - | ✅ | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/blog/blob/master/notebooks/01_how_to_train.ipynb) | [**`text-classification`**](https://github.com/huggingface/transformers/tree/master/examples/text-classification) | GLUE, XNLI | ✅ | ✅ | ✅ | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://github.com/huggingface/notebooks/blob/master/examples/text_classification.ipynb) -| [**`token-classification`**](https://github.com/huggingface/transformers/tree/master/examples/token-classification) | CoNLL NER | ✅ | ✅ | - | - +| [**`token-classification`**](https://github.com/huggingface/transformers/tree/master/examples/token-classification) | CoNLL NER | ✅ | ✅ | ✅ | - | [**`multiple-choice`**](https://github.com/huggingface/transformers/tree/master/examples/multiple-choice) | SWAG, RACE, ARC | ✅ | ✅ | - | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ViktorAlm/notebooks/blob/master/MPC_GPU_Demo_for_TF_and_PT.ipynb) | [**`question-answering`**](https://github.com/huggingface/transformers/tree/master/examples/question-answering) | SQuAD | ✅ | ✅ | - | - | [**`text-generation`**](https://github.com/huggingface/transformers/tree/master/examples/text-generation) | - | n/a | n/a | - | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/blog/blob/master/notebooks/02_how_to_generate.ipynb) diff --git a/examples/test_examples.py b/examples/test_examples.py index 4eda398537d7..ad4c5ffe27e2 100644 --- a/examples/test_examples.py +++ b/examples/test_examples.py @@ -28,7 +28,13 @@ SRC_DIRS = [ os.path.join(os.path.dirname(__file__), dirname) - for dirname in ["text-generation", "text-classification", "language-modeling", "question-answering"] + for dirname in [ + "text-generation", + "text-classification", + "token-classification", + "language-modeling", + "question-answering", + ] ] sys.path.extend(SRC_DIRS) @@ -38,6 +44,7 @@ import run_generation import run_glue import run_mlm + import run_ner import run_pl_glue import run_squad @@ -185,6 +192,36 @@ def test_run_mlm(self): result = run_mlm.main() self.assertLess(result["perplexity"], 42) + def test_run_ner(self): + stream_handler = logging.StreamHandler(sys.stdout) + logger.addHandler(stream_handler) + + tmp_dir = self.get_auto_remove_tmp_dir() + testargs = f""" + run_ner.py + --model_name_or_path bert-base-uncased + --train_file tests/fixtures/tests_samples/conll/sample.json + --validation_file tests/fixtures/tests_samples/conll/sample.json + --output_dir {tmp_dir} + --overwrite_output_dir + --do_train + --do_eval + --warmup_steps=2 + --learning_rate=2e-4 + --per_gpu_train_batch_size=2 + --per_gpu_eval_batch_size=2 + --num_train_epochs=2 + """.split() + + if torch_device != "cuda": + testargs.append("--no_cuda") + + with patch.object(sys, "argv", testargs): + result = run_ner.main() + self.assertGreaterEqual(result["eval_accuracy_score"], 0.75) + self.assertGreaterEqual(result["eval_precision"], 0.75) + self.assertLess(result["eval_loss"], 0.5) + def test_run_squad(self): stream_handler = logging.StreamHandler(sys.stdout) logger.addHandler(stream_handler) diff --git a/examples/token-classification/README.md b/examples/token-classification/README.md index 8b2c2335acbd..7c9e160650e5 100644 --- a/examples/token-classification/README.md +++ b/examples/token-classification/README.md @@ -1,6 +1,40 @@ -## Named Entity Recognition +## Token classification -Based on the scripts [`run_ner.py`](https://github.com/huggingface/transformers/blob/master/examples/token-classification/run_ner.py) for Pytorch and +Fine-tuning the library models for token classification task such as Named Entity Recognition (NER) or Parts-of-speech +tagging (POS). The main scrip `run_ner.py` leverages the 🤗 Datasets library and the Trainer API. You can easily +customize it to your needs if you need extra processing on your datasets. + +It will either run on a datasets hosted on our [hub](https://huggingface.co/datasets) or with your own text files for +training and validation. + +The following example fine-tunes BERT on CoNLL-2003: + +```bash +python run_ner.py \ + --model_name_or_path bert-base-uncased \ + --dataset_name conll2003 \ + --output_dir /tmp/test-ner \ + --do_train \ + --do_eval +``` + +or just can just run the bash script `run.sh`. + +To run on your own training and validation files, use the following command: + +```bash +python run_ner.py \ + --model_name_or_path bert-base-uncased \ + --train_file path_to_train_file \ + --validation_file path_to_validation_file \ + --output_dir /tmp/test-ner \ + --do_train \ + --do_eval +``` + +## Old version of the script + +Based on the scripts [`run_ner_old.py`](https://github.com/huggingface/transformers/blob/master/examples/token-classification/run_ner_old.py) for Pytorch and [`run_tf_ner.py`](https://github.com/huggingface/transformers/blob/master/examples/token-classification/run_tf_ner.py) for Tensorflow 2. The following examples are covered in this section: @@ -69,7 +103,7 @@ export SEED=1 To start training, just run: ```bash -python3 run_ner.py --data_dir ./ \ +python3 run_ner_old.py --data_dir ./ \ --labels ./labels.txt \ --model_name_or_path $BERT_MODEL \ --output_dir $OUTPUT_DIR \ @@ -87,7 +121,7 @@ If your GPU supports half-precision training, just add the `--fp16` flag. After #### JSON-based configuration file -Instead of passing all parameters via commandline arguments, the `run_ner.py` script also supports reading parameters from a json-based configuration file: +Instead of passing all parameters via commandline arguments, the `run_ner_old.py` script also supports reading parameters from a json-based configuration file: ```json { @@ -106,7 +140,7 @@ Instead of passing all parameters via commandline arguments, the `run_ner.py` sc } ``` -It must be saved with a `.json` extension and can be used by running `python3 run_ner.py config.json`. +It must be saved with a `.json` extension and can be used by running `python3 run_ner_old.py config.json`. #### Evaluation @@ -250,7 +284,7 @@ cat data_wnut_17/train.txt data_wnut_17/dev.txt data_wnut_17/test.txt | cut -d " #### Run the Pytorch version -Fine-tuning with the PyTorch version can be started using the `run_ner.py` script. In this example we use a JSON-based configuration file. +Fine-tuning with the PyTorch version can be started using the `run_ner_old.py` script. In this example we use a JSON-based configuration file. This configuration file looks like: @@ -274,7 +308,7 @@ This configuration file looks like: If your GPU supports half-precision training, please set `fp16` to `true`. -Save this JSON-based configuration under `wnut_17.json`. The fine-tuning can be started with `python3 run_ner.py wnut_17.json`. +Save this JSON-based configuration under `wnut_17.json`. The fine-tuning can be started with `python3 run_ner_old.py wnut_17.json`. #### Evaluation diff --git a/examples/token-classification/run.sh b/examples/token-classification/run.sh index f5cbf0d50e02..6c46a813974c 100755 --- a/examples/token-classification/run.sh +++ b/examples/token-classification/run.sh @@ -1,36 +1,6 @@ -## The relevant files are currently on a shared Google -## drive at https://drive.google.com/drive/folders/1kC0I2UGl2ltrluI9NqDjaQJGw5iliw_J -## Monitor for changes and eventually migrate to nlp dataset -curl -L 'https://drive.google.com/uc?export=download&id=1Jjhbal535VVz2ap4v4r_rN1UEHTdLK5P' \ -| grep -v "^#" | cut -f 2,3 | tr '\t' ' ' > train.txt.tmp -curl -L 'https://drive.google.com/uc?export=download&id=1ZfRcQThdtAR5PPRjIDtrVP7BtXSCUBbm' \ -| grep -v "^#" | cut -f 2,3 | tr '\t' ' ' > dev.txt.tmp -curl -L 'https://drive.google.com/uc?export=download&id=1u9mb7kNJHWQCWyweMDRMuTFoOHOfeBTH' \ -| grep -v "^#" | cut -f 2,3 | tr '\t' ' ' > test.txt.tmp - -export MAX_LENGTH=128 -export BERT_MODEL=bert-base-multilingual-cased -python3 scripts/preprocess.py train.txt.tmp $BERT_MODEL $MAX_LENGTH > train.txt -python3 scripts/preprocess.py dev.txt.tmp $BERT_MODEL $MAX_LENGTH > dev.txt -python3 scripts/preprocess.py test.txt.tmp $BERT_MODEL $MAX_LENGTH > test.txt -cat train.txt dev.txt test.txt | cut -d " " -f 2 | grep -v "^$"| sort | uniq > labels.txt -export OUTPUT_DIR=germeval-model -export BATCH_SIZE=32 -export NUM_EPOCHS=3 -export SAVE_STEPS=750 -export SEED=1 - python3 run_ner.py \ ---task_type NER \ ---data_dir . \ ---labels ./labels.txt \ ---model_name_or_path $BERT_MODEL \ ---output_dir $OUTPUT_DIR \ ---max_seq_length $MAX_LENGTH \ ---num_train_epochs $NUM_EPOCHS \ ---per_gpu_train_batch_size $BATCH_SIZE \ ---save_steps $SAVE_STEPS \ ---seed $SEED \ ---do_train \ ---do_eval \ ---do_predict + --model_name_or_path bert-base-uncased \ + --dataset_name conll2003 \ + --output_dir /tmp/test-ner \ + --do_train \ + --do_eval diff --git a/examples/token-classification/run_chunk.sh b/examples/token-classification/run_chunk.sh index 13341555b699..3dbb03306d96 100755 --- a/examples/token-classification/run_chunk.sh +++ b/examples/token-classification/run_chunk.sh @@ -21,7 +21,7 @@ export NUM_EPOCHS=3 export SAVE_STEPS=750 export SEED=1 -python3 run_ner.py \ +python3 run_ner_old.py \ --task_type Chunk \ --data_dir . \ --model_name_or_path $BERT_MODEL \ diff --git a/examples/token-classification/run_ner.py b/examples/token-classification/run_ner.py index a2981415f690..3eed7098a5aa 100644 --- a/examples/token-classification/run_ner.py +++ b/examples/token-classification/run_ner.py @@ -1,6 +1,5 @@ # coding=utf-8 -# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. -# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# Copyright 2020 The HuggingFace Team All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,29 +12,33 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" Fine-tuning the library models for named entity recognition on CoNLL-2003. """ +""" +Fine-tuning the library models for token classification. +""" +# You can also adapt this script on your own token classification task and datasets. Pointers for this are left as comments. + import logging import os import sys from dataclasses import dataclass, field -from importlib import import_module -from typing import Dict, List, Optional, Tuple +from typing import Optional import numpy as np +from datasets import load_dataset from seqeval.metrics import accuracy_score, f1_score, precision_score, recall_score -from torch import nn +import transformers from transformers import ( AutoConfig, AutoModelForTokenClassification, AutoTokenizer, - EvalPrediction, + DataCollatorForTokenClassification, HfArgumentParser, Trainer, TrainingArguments, set_seed, ) -from utils_ner import Split, TokenClassificationDataset, TokenClassificationTask +from transformers.trainer_utils import is_main_process logger = logging.getLogger(__name__) @@ -53,15 +56,9 @@ class ModelArguments: config_name: Optional[str] = field( default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} ) - task_type: Optional[str] = field( - default="NER", metadata={"help": "Task type to fine tune in training (e.g. NER, POS, etc)"} - ) tokenizer_name: Optional[str] = field( default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} ) - use_fast: bool = field(default=False, metadata={"help": "Set this flag to use fast tokenization."}) - # If you want to tweak more attributes on your tokenizer, you should do it in a distinct script, - # or just modify its tokenizer_config.json. cache_dir: Optional[str] = field( default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"} ) @@ -73,23 +70,58 @@ class DataTrainingArguments: Arguments pertaining to what data we are going to input our model for training and eval. """ - data_dir: str = field( - metadata={"help": "The input data dir. Should contain the .txt files for a CoNLL-2003-formatted task."} + task_name: Optional[str] = field(default="ner", metadata={"help": "The name of the task (ner, pos...)."}) + dataset_name: Optional[str] = field( + default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} + ) + dataset_config_name: Optional[str] = field( + default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} ) - labels: Optional[str] = field( + train_file: Optional[str] = field( + default=None, metadata={"help": "The input training data file (a csv or JSON file)."} + ) + validation_file: Optional[str] = field( default=None, - metadata={"help": "Path to a file containing all labels. If not specified, CoNLL-2003 labels are used."}, + metadata={"help": "An optional input evaluation data file to evaluate on (a csv or JSON file)."}, ) - max_seq_length: int = field( - default=128, - metadata={ - "help": "The maximum total input sequence length after tokenization. Sequences longer " - "than this will be truncated, sequences shorter will be padded." - }, + test_file: Optional[str] = field( + default=None, + metadata={"help": "An optional input test data file to predict on (a csv or JSON file)."}, ) overwrite_cache: bool = field( default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} ) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the preprocessing."}, + ) + pad_to_max_length: bool = field( + default=False, + metadata={ + "help": "Whether to pad all samples to model maximum sentence length. " + "If False, will pad the samples dynamically when batching to the maximum length in the batch. More " + "efficient on GPU but very bad for TPU." + }, + ) + label_all_tokens: bool = field( + default=False, + metadata={ + "help": "Whether to put the label for one word on all tokens of generated by that word or just on the " + "one (in which case the other tokens will have a padding index)." + }, + ) + + def __post_init__(self): + if self.dataset_name is None and self.train_file is None and self.validation_file is None: + raise ValueError("Need either a dataset name or a training/validation file.") + else: + if self.train_file is not None: + extension = self.train_file.split(".")[-1] + assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." + if self.validation_file is not None: + extension = self.validation_file.split(".")[-1] + assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." + self.task_name = self.task_name.lower() def main(): @@ -112,60 +144,90 @@ def main(): and not training_args.overwrite_output_dir ): raise ValueError( - f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome." - ) - - module = import_module("tasks") - try: - token_classification_task_clazz = getattr(module, model_args.task_type) - token_classification_task: TokenClassificationTask = token_classification_task_clazz() - except AttributeError: - raise ValueError( - f"Task {model_args.task_type} needs to be defined as a TokenClassificationTask subclass in {module}. " - f"Available tasks classes are: {TokenClassificationTask.__subclasses__()}" + f"Output directory ({training_args.output_dir}) already exists and is not empty." + "Use --overwrite_output_dir to overcome." ) # Setup logging logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", - level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN, + level=logging.INFO if is_main_process(training_args.local_rank) else logging.WARN, ) + + # Log on each process the small summary: logger.warning( - "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", - training_args.local_rank, - training_args.device, - training_args.n_gpu, - bool(training_args.local_rank != -1), - training_args.fp16, + f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" + + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" ) + # Set the verbosity to info of the Transformers logger (on main process only): + if is_main_process(training_args.local_rank): + transformers.utils.logging.set_verbosity_info() logger.info("Training/evaluation parameters %s", training_args) - # Set seed + # Set seed before initializing model. set_seed(training_args.seed) - # Prepare CONLL-2003 task - labels = token_classification_task.get_labels(data_args.labels) - label_map: Dict[int, str] = {i: label for i, label in enumerate(labels)} - num_labels = len(labels) + # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) + # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ + # (the dataset will be downloaded automatically from the datasets Hub). + # + # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called + # 'text' is found. You can easily tweak this behavior (see below). + # + # In distributed training, the load_dataset function guarantee that only one local process can concurrently + # download the dataset. + if data_args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name) + else: + data_files = {} + if data_args.train_file is not None: + data_files["train"] = data_args.train_file + if data_args.validation_file is not None: + data_files["validation"] = data_args.validation_file + if data_args.test_file is not None: + data_files["test"] = data_args.test_file + extension = data_args.train_file.split(".")[-1] + datasets = load_dataset(extension, data_files=data_files) + # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at + # https://huggingface.co/docs/datasets/loading_datasets.html. + + if training_args.do_train: + column_names = datasets["train"].column_names + else: + column_names = datasets["validation"].column_names + text_column_name = "words" if "words" in column_names else column_names[0] + label_column_name = data_args.task_name if data_args.task_name in column_names else column_names[1] + + # Labeling (this part will be easier when https://github.com/huggingface/datasets/issues/797 is solved) + def get_label_list(labels): + unique_labels = set() + for label in labels: + unique_labels = unique_labels | set(label) + label_list = list(unique_labels) + label_list.sort() + return label_list + + label_list = get_label_list(datasets["train"][label_column_name]) + label_to_id = {l: i for i, l in enumerate(label_list)} + num_labels = len(label_list) # Load pretrained model and tokenizer # # Distributed training: # The .from_pretrained methods guarantee that only one local process can concurrently # download model & vocab. - config = AutoConfig.from_pretrained( model_args.config_name if model_args.config_name else model_args.model_name_or_path, num_labels=num_labels, - id2label=label_map, - label2id={label: i for i, label in enumerate(labels)}, + finetuning_task=data_args.task_name, cache_dir=model_args.cache_dir, ) tokenizer = AutoTokenizer.from_pretrained( model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, cache_dir=model_args.cache_dir, - use_fast=model_args.use_fast, + use_fast=True, ) model = AutoModelForTokenClassification.from_pretrained( model_args.model_name_or_path, @@ -174,67 +236,85 @@ def main(): cache_dir=model_args.cache_dir, ) - # Get datasets - train_dataset = ( - TokenClassificationDataset( - token_classification_task=token_classification_task, - data_dir=data_args.data_dir, - tokenizer=tokenizer, - labels=labels, - model_type=config.model_type, - max_seq_length=data_args.max_seq_length, - overwrite_cache=data_args.overwrite_cache, - mode=Split.train, + # Preprocessing the dataset + # Padding strategy + padding = "max_length" if data_args.pad_to_max_length else False + + # Tokenize all texts and align the labels with them. + def tokenize_and_align_labels(examples): + tokenized_inputs = tokenizer( + examples[text_column_name], + padding=padding, + truncation=True, + # We use this argument because the texts in our dataset are lists of words (with a label for each word). + is_split_into_words=True, + return_offsets_mapping=True, ) - if training_args.do_train - else None + offset_mappings = tokenized_inputs.pop("offset_mapping") + labels = [] + for label, offset_mapping in zip(examples[label_column_name], offset_mappings): + label_index = 0 + current_label = -100 + label_ids = [] + for offset in offset_mapping: + # We set the label for the first token of each word. Special characters will have an offset of (0, 0) + # so the test ignores them. + if offset[0] == 0 and offset[1] != 0: + current_label = label_to_id[label[label_index]] + label_index += 1 + label_ids.append(current_label) + # For special tokens, we set the label to -100 so it's automatically ignored in the loss function. + elif offset[0] == 0 and offset[1] == 0: + label_ids.append(-100) + # For the other tokens in a word, we set the label to either the current label or -100, depending on + # the label_all_tokens flag. + else: + label_ids.append(current_label if data_args.label_all_tokens else -100) + + labels.append(label_ids) + tokenized_inputs["labels"] = labels + return tokenized_inputs + + tokenized_datasets = datasets.map( + tokenize_and_align_labels, + batched=True, + num_proc=data_args.preprocessing_num_workers, + load_from_cache_file=not data_args.overwrite_cache, ) - eval_dataset = ( - TokenClassificationDataset( - token_classification_task=token_classification_task, - data_dir=data_args.data_dir, - tokenizer=tokenizer, - labels=labels, - model_type=config.model_type, - max_seq_length=data_args.max_seq_length, - overwrite_cache=data_args.overwrite_cache, - mode=Split.dev, - ) - if training_args.do_eval - else None - ) - - def align_predictions(predictions: np.ndarray, label_ids: np.ndarray) -> Tuple[List[int], List[int]]: - preds = np.argmax(predictions, axis=2) - batch_size, seq_len = preds.shape + # Data collator + data_collator = DataCollatorForTokenClassification(tokenizer) - out_label_list = [[] for _ in range(batch_size)] - preds_list = [[] for _ in range(batch_size)] + # Metrics + def compute_metrics(p): + predictions, labels = p + predictions = np.argmax(predictions, axis=2) - for i in range(batch_size): - for j in range(seq_len): - if label_ids[i, j] != nn.CrossEntropyLoss().ignore_index: - out_label_list[i].append(label_map[label_ids[i][j]]) - preds_list[i].append(label_map[preds[i][j]]) + # Remove ignored index (special tokens) + true_predictions = [ + [label_list[p] for (p, l) in zip(prediction, label) if l != -100] + for prediction, label in zip(predictions, labels) + ] + true_labels = [ + [label_list[l] for (p, l) in zip(prediction, label) if l != -100] + for prediction, label in zip(predictions, labels) + ] - return preds_list, out_label_list - - def compute_metrics(p: EvalPrediction) -> Dict: - preds_list, out_label_list = align_predictions(p.predictions, p.label_ids) return { - "accuracy_score": accuracy_score(out_label_list, preds_list), - "precision": precision_score(out_label_list, preds_list), - "recall": recall_score(out_label_list, preds_list), - "f1": f1_score(out_label_list, preds_list), + "accuracy_score": accuracy_score(true_labels, true_predictions), + "precision": precision_score(true_labels, true_predictions), + "recall": recall_score(true_labels, true_predictions), + "f1": f1_score(true_labels, true_predictions), } # Initialize our Trainer trainer = Trainer( model=model, args=training_args, - train_dataset=train_dataset, - eval_dataset=eval_dataset, + train_dataset=tokenized_datasets["train"] if training_args.do_train else None, + eval_dataset=tokenized_datasets["validation"] if training_args.do_eval else None, + tokenizer=tokenizer, + data_collator=data_collator, compute_metrics=compute_metrics, ) @@ -243,58 +323,50 @@ def compute_metrics(p: EvalPrediction) -> Dict: trainer.train( model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None ) - trainer.save_model() - # For convenience, we also re-save the tokenizer to the same directory, - # so that you can share your model easily on huggingface.co/models =) - if trainer.is_world_master(): - tokenizer.save_pretrained(training_args.output_dir) + trainer.save_model() # Saves the tokenizer too for easy upload # Evaluation results = {} if training_args.do_eval: logger.info("*** Evaluate ***") - result = trainer.evaluate() + results = trainer.evaluate() - output_eval_file = os.path.join(training_args.output_dir, "eval_results.txt") - if trainer.is_world_master(): + output_eval_file = os.path.join(training_args.output_dir, "eval_results_ner.txt") + if trainer.is_world_process_zero(): with open(output_eval_file, "w") as writer: logger.info("***** Eval results *****") - for key, value in result.items(): - logger.info(" %s = %s", key, value) - writer.write("%s = %s\n" % (key, value)) - - results.update(result) + for key, value in results.items(): + logger.info(f" {key} = {value}") + writer.write(f"{key} = {value}\n") # Predict if training_args.do_predict: - test_dataset = TokenClassificationDataset( - token_classification_task=token_classification_task, - data_dir=data_args.data_dir, - tokenizer=tokenizer, - labels=labels, - model_type=config.model_type, - max_seq_length=data_args.max_seq_length, - overwrite_cache=data_args.overwrite_cache, - mode=Split.test, - ) + logger.info("*** Predict ***") + + test_dataset = datasets["test"] + predictions, labels, metrics = trainer.predict(test_dataset) + predictions = np.argmax(predictions, axis=2) - predictions, label_ids, metrics = trainer.predict(test_dataset) - preds_list, _ = align_predictions(predictions, label_ids) + # Remove ignored index (special tokens) + true_predictions = [ + [label_list[p] for (p, l) in zip(prediction, label) if l != -100] + for prediction, label in zip(predictions, labels) + ] output_test_results_file = os.path.join(training_args.output_dir, "test_results.txt") if trainer.is_world_master(): with open(output_test_results_file, "w") as writer: for key, value in metrics.items(): - logger.info(" %s = %s", key, value) - writer.write("%s = %s\n" % (key, value)) + logger.info(f" {key} = {value}") + writer.write(f"{key} = {value}\n") # Save predictions output_test_predictions_file = os.path.join(training_args.output_dir, "test_predictions.txt") if trainer.is_world_master(): with open(output_test_predictions_file, "w") as writer: - with open(os.path.join(data_args.data_dir, "test.txt"), "r") as f: - token_classification_task.write_predictions_to_file(writer, f, preds_list) + for prediction in true_predictions: + writer.write(" ".join(prediction) + "\n") return results diff --git a/examples/token-classification/run_ner_old.py b/examples/token-classification/run_ner_old.py new file mode 100644 index 000000000000..a2981415f690 --- /dev/null +++ b/examples/token-classification/run_ner_old.py @@ -0,0 +1,308 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Fine-tuning the library models for named entity recognition on CoNLL-2003. """ +import logging +import os +import sys +from dataclasses import dataclass, field +from importlib import import_module +from typing import Dict, List, Optional, Tuple + +import numpy as np +from seqeval.metrics import accuracy_score, f1_score, precision_score, recall_score +from torch import nn + +from transformers import ( + AutoConfig, + AutoModelForTokenClassification, + AutoTokenizer, + EvalPrediction, + HfArgumentParser, + Trainer, + TrainingArguments, + set_seed, +) +from utils_ner import Split, TokenClassificationDataset, TokenClassificationTask + + +logger = logging.getLogger(__name__) + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. + """ + + model_name_or_path: str = field( + metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} + ) + config_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} + ) + task_type: Optional[str] = field( + default="NER", metadata={"help": "Task type to fine tune in training (e.g. NER, POS, etc)"} + ) + tokenizer_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} + ) + use_fast: bool = field(default=False, metadata={"help": "Set this flag to use fast tokenization."}) + # If you want to tweak more attributes on your tokenizer, you should do it in a distinct script, + # or just modify its tokenizer_config.json. + cache_dir: Optional[str] = field( + default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"} + ) + + +@dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + """ + + data_dir: str = field( + metadata={"help": "The input data dir. Should contain the .txt files for a CoNLL-2003-formatted task."} + ) + labels: Optional[str] = field( + default=None, + metadata={"help": "Path to a file containing all labels. If not specified, CoNLL-2003 labels are used."}, + ) + max_seq_length: int = field( + default=128, + metadata={ + "help": "The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + }, + ) + overwrite_cache: bool = field( + default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} + ) + + +def main(): + # See all possible arguments in src/transformers/training_args.py + # or by passing the --help flag to this script. + # We now keep distinct sets of args, for a cleaner separation of concerns. + + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. + model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + if ( + os.path.exists(training_args.output_dir) + and os.listdir(training_args.output_dir) + and training_args.do_train + and not training_args.overwrite_output_dir + ): + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome." + ) + + module = import_module("tasks") + try: + token_classification_task_clazz = getattr(module, model_args.task_type) + token_classification_task: TokenClassificationTask = token_classification_task_clazz() + except AttributeError: + raise ValueError( + f"Task {model_args.task_type} needs to be defined as a TokenClassificationTask subclass in {module}. " + f"Available tasks classes are: {TokenClassificationTask.__subclasses__()}" + ) + + # Setup logging + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN, + ) + logger.warning( + "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", + training_args.local_rank, + training_args.device, + training_args.n_gpu, + bool(training_args.local_rank != -1), + training_args.fp16, + ) + logger.info("Training/evaluation parameters %s", training_args) + + # Set seed + set_seed(training_args.seed) + + # Prepare CONLL-2003 task + labels = token_classification_task.get_labels(data_args.labels) + label_map: Dict[int, str] = {i: label for i, label in enumerate(labels)} + num_labels = len(labels) + + # Load pretrained model and tokenizer + # + # Distributed training: + # The .from_pretrained methods guarantee that only one local process can concurrently + # download model & vocab. + + config = AutoConfig.from_pretrained( + model_args.config_name if model_args.config_name else model_args.model_name_or_path, + num_labels=num_labels, + id2label=label_map, + label2id={label: i for i, label in enumerate(labels)}, + cache_dir=model_args.cache_dir, + ) + tokenizer = AutoTokenizer.from_pretrained( + model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + use_fast=model_args.use_fast, + ) + model = AutoModelForTokenClassification.from_pretrained( + model_args.model_name_or_path, + from_tf=bool(".ckpt" in model_args.model_name_or_path), + config=config, + cache_dir=model_args.cache_dir, + ) + + # Get datasets + train_dataset = ( + TokenClassificationDataset( + token_classification_task=token_classification_task, + data_dir=data_args.data_dir, + tokenizer=tokenizer, + labels=labels, + model_type=config.model_type, + max_seq_length=data_args.max_seq_length, + overwrite_cache=data_args.overwrite_cache, + mode=Split.train, + ) + if training_args.do_train + else None + ) + eval_dataset = ( + TokenClassificationDataset( + token_classification_task=token_classification_task, + data_dir=data_args.data_dir, + tokenizer=tokenizer, + labels=labels, + model_type=config.model_type, + max_seq_length=data_args.max_seq_length, + overwrite_cache=data_args.overwrite_cache, + mode=Split.dev, + ) + if training_args.do_eval + else None + ) + + def align_predictions(predictions: np.ndarray, label_ids: np.ndarray) -> Tuple[List[int], List[int]]: + preds = np.argmax(predictions, axis=2) + + batch_size, seq_len = preds.shape + + out_label_list = [[] for _ in range(batch_size)] + preds_list = [[] for _ in range(batch_size)] + + for i in range(batch_size): + for j in range(seq_len): + if label_ids[i, j] != nn.CrossEntropyLoss().ignore_index: + out_label_list[i].append(label_map[label_ids[i][j]]) + preds_list[i].append(label_map[preds[i][j]]) + + return preds_list, out_label_list + + def compute_metrics(p: EvalPrediction) -> Dict: + preds_list, out_label_list = align_predictions(p.predictions, p.label_ids) + return { + "accuracy_score": accuracy_score(out_label_list, preds_list), + "precision": precision_score(out_label_list, preds_list), + "recall": recall_score(out_label_list, preds_list), + "f1": f1_score(out_label_list, preds_list), + } + + # Initialize our Trainer + trainer = Trainer( + model=model, + args=training_args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + compute_metrics=compute_metrics, + ) + + # Training + if training_args.do_train: + trainer.train( + model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None + ) + trainer.save_model() + # For convenience, we also re-save the tokenizer to the same directory, + # so that you can share your model easily on huggingface.co/models =) + if trainer.is_world_master(): + tokenizer.save_pretrained(training_args.output_dir) + + # Evaluation + results = {} + if training_args.do_eval: + logger.info("*** Evaluate ***") + + result = trainer.evaluate() + + output_eval_file = os.path.join(training_args.output_dir, "eval_results.txt") + if trainer.is_world_master(): + with open(output_eval_file, "w") as writer: + logger.info("***** Eval results *****") + for key, value in result.items(): + logger.info(" %s = %s", key, value) + writer.write("%s = %s\n" % (key, value)) + + results.update(result) + + # Predict + if training_args.do_predict: + test_dataset = TokenClassificationDataset( + token_classification_task=token_classification_task, + data_dir=data_args.data_dir, + tokenizer=tokenizer, + labels=labels, + model_type=config.model_type, + max_seq_length=data_args.max_seq_length, + overwrite_cache=data_args.overwrite_cache, + mode=Split.test, + ) + + predictions, label_ids, metrics = trainer.predict(test_dataset) + preds_list, _ = align_predictions(predictions, label_ids) + + output_test_results_file = os.path.join(training_args.output_dir, "test_results.txt") + if trainer.is_world_master(): + with open(output_test_results_file, "w") as writer: + for key, value in metrics.items(): + logger.info(" %s = %s", key, value) + writer.write("%s = %s\n" % (key, value)) + + # Save predictions + output_test_predictions_file = os.path.join(training_args.output_dir, "test_predictions.txt") + if trainer.is_world_master(): + with open(output_test_predictions_file, "w") as writer: + with open(os.path.join(data_args.data_dir, "test.txt"), "r") as f: + token_classification_task.write_predictions_to_file(writer, f, preds_list) + + return results + + +def _mp_fn(index): + # For xla_spawn (TPUs) + main() + + +if __name__ == "__main__": + main() diff --git a/examples/token-classification/run_old.sh b/examples/token-classification/run_old.sh new file mode 100755 index 000000000000..90cb4484d0a6 --- /dev/null +++ b/examples/token-classification/run_old.sh @@ -0,0 +1,36 @@ +## The relevant files are currently on a shared Google +## drive at https://drive.google.com/drive/folders/1kC0I2UGl2ltrluI9NqDjaQJGw5iliw_J +## Monitor for changes and eventually migrate to nlp dataset +curl -L 'https://drive.google.com/uc?export=download&id=1Jjhbal535VVz2ap4v4r_rN1UEHTdLK5P' \ +| grep -v "^#" | cut -f 2,3 | tr '\t' ' ' > train.txt.tmp +curl -L 'https://drive.google.com/uc?export=download&id=1ZfRcQThdtAR5PPRjIDtrVP7BtXSCUBbm' \ +| grep -v "^#" | cut -f 2,3 | tr '\t' ' ' > dev.txt.tmp +curl -L 'https://drive.google.com/uc?export=download&id=1u9mb7kNJHWQCWyweMDRMuTFoOHOfeBTH' \ +| grep -v "^#" | cut -f 2,3 | tr '\t' ' ' > test.txt.tmp + +export MAX_LENGTH=128 +export BERT_MODEL=bert-base-multilingual-cased +python3 scripts/preprocess.py train.txt.tmp $BERT_MODEL $MAX_LENGTH > train.txt +python3 scripts/preprocess.py dev.txt.tmp $BERT_MODEL $MAX_LENGTH > dev.txt +python3 scripts/preprocess.py test.txt.tmp $BERT_MODEL $MAX_LENGTH > test.txt +cat train.txt dev.txt test.txt | cut -d " " -f 2 | grep -v "^$"| sort | uniq > labels.txt +export OUTPUT_DIR=germeval-model +export BATCH_SIZE=32 +export NUM_EPOCHS=3 +export SAVE_STEPS=750 +export SEED=1 + +python3 run_ner_old.py \ +--task_type NER \ +--data_dir . \ +--labels ./labels.txt \ +--model_name_or_path $BERT_MODEL \ +--output_dir $OUTPUT_DIR \ +--max_seq_length $MAX_LENGTH \ +--num_train_epochs $NUM_EPOCHS \ +--per_gpu_train_batch_size $BATCH_SIZE \ +--save_steps $SAVE_STEPS \ +--seed $SEED \ +--do_train \ +--do_eval \ +--do_predict diff --git a/examples/token-classification/run_pos.sh b/examples/token-classification/run_pos.sh index 7d76ed8a2a8a..f4e05058a3dc 100755 --- a/examples/token-classification/run_pos.sh +++ b/examples/token-classification/run_pos.sh @@ -21,7 +21,7 @@ export NUM_EPOCHS=3 export SAVE_STEPS=750 export SEED=1 -python3 run_ner.py \ +python3 _old.py \ --task_type POS \ --data_dir . \ --model_name_or_path $BERT_MODEL \ diff --git a/examples/token-classification/test_ner_examples.py b/examples/token-classification/test_ner_examples.py index d6bb0b25fa3b..6ecb421a7dbb 100644 --- a/examples/token-classification/test_ner_examples.py +++ b/examples/token-classification/test_ner_examples.py @@ -3,7 +3,7 @@ import unittest from unittest.mock import patch -import run_ner +import run_ner_old as run_ner from transformers.testing_utils import slow diff --git a/model_cards/mrm8488/RuPERTa-base-finetuned-ner/README.md b/model_cards/mrm8488/RuPERTa-base-finetuned-ner/README.md index f31b0e37c104..5b4524001ede 100644 --- a/model_cards/mrm8488/RuPERTa-base-finetuned-ner/README.md +++ b/model_cards/mrm8488/RuPERTa-base-finetuned-ner/README.md @@ -17,7 +17,7 @@ This model is a fine-tuned on [NER-C](https://www.kaggle.com/nltkdata/conll-corp | Dev | 40 K | -- [Fine-tune on NER script provided by Huggingface](https://github.com/huggingface/transformers/blob/master/examples/token-classification/run_ner.py) +- [Fine-tune on NER script provided by Huggingface](https://github.com/huggingface/transformers/blob/master/examples/token-classification/run_ner_old.py) - Labels covered: diff --git a/model_cards/mrm8488/RuPERTa-base-finetuned-pos/README.md b/model_cards/mrm8488/RuPERTa-base-finetuned-pos/README.md index e101381f521e..26865503ff4f 100644 --- a/model_cards/mrm8488/RuPERTa-base-finetuned-pos/README.md +++ b/model_cards/mrm8488/RuPERTa-base-finetuned-pos/README.md @@ -16,7 +16,7 @@ This model is a fine-tuned on [CONLL CORPORA](https://www.kaggle.com/nltkdata/co | Train | 445 K | | Dev | 55 K | -- [Fine-tune on NER script provided by Huggingface](https://github.com/huggingface/transformers/blob/master/examples/token-classification/run_ner.py) +- [Fine-tune on NER script provided by Huggingface](https://github.com/huggingface/transformers/blob/master/examples/token-classification/run_ner_old.py) - Labels covered: diff --git a/model_cards/mrm8488/TinyBERT-spanish-uncased-finetuned-ner/README.md b/model_cards/mrm8488/TinyBERT-spanish-uncased-finetuned-ner/README.md index 1a595fe1ba5a..7f2f6a9d2f68 100644 --- a/model_cards/mrm8488/TinyBERT-spanish-uncased-finetuned-ner/README.md +++ b/model_cards/mrm8488/TinyBERT-spanish-uncased-finetuned-ner/README.md @@ -19,7 +19,7 @@ I preprocessed the dataset and split it as train / dev (80/20) | Dev | 2.2 K | -- [Fine-tune on NER script provided by Huggingface](https://github.com/huggingface/transformers/blob/master/examples/token-classification/run_ner.py) +- [Fine-tune on NER script provided by Huggingface](https://github.com/huggingface/transformers/blob/master/examples/token-classification/run_ner_old.py) - Labels covered: diff --git a/model_cards/mrm8488/bert-base-german-finetuned-ler/README.md b/model_cards/mrm8488/bert-base-german-finetuned-ler/README.md index 2374fcfc3cfd..dfe02be656bf 100644 --- a/model_cards/mrm8488/bert-base-german-finetuned-ler/README.md +++ b/model_cards/mrm8488/bert-base-german-finetuned-ler/README.md @@ -18,7 +18,7 @@ Court decisions from 2017 and 2018 were selected for the dataset, published onli | Train | 1657048 | | Eval | 500000 | -- Training script: [Fine-tuning script for NER provided by Huggingface](https://github.com/huggingface/transformers/blob/master/examples/token-classification/run_ner.py) +- Training script: [Fine-tuning script for NER provided by Huggingface](https://github.com/huggingface/transformers/blob/master/examples/token-classification/run_ner_old.py) Colab: [How to fine-tune a model for NER using HF scripts](https://colab.research.google.com/drive/156Qrd7NsUHwA3nmQ6gXdZY0NzOvqk9AT?usp=sharing) - Labels covered (and its distribution): diff --git a/model_cards/mrm8488/bert-small-finetuned-typo-detection/README.md b/model_cards/mrm8488/bert-small-finetuned-typo-detection/README.md index 8b9c4649923e..1e2c83436ad3 100644 --- a/model_cards/mrm8488/bert-small-finetuned-typo-detection/README.md +++ b/model_cards/mrm8488/bert-small-finetuned-typo-detection/README.md @@ -11,7 +11,7 @@ thumbnail: - Dataset: [GitHub Typo Corpus](https://github.com/mhagiwara/github-typo-corpus) 📚 -- [Fine-tune script on NER dataset provided by Huggingface](https://github.com/huggingface/transformers/blob/master/examples/token-classification/run_ner.py) 🏋️‍♂️ +- [Fine-tune script on NER dataset provided by Huggingface](https://github.com/huggingface/transformers/blob/master/examples/token-classification/run_ner_old.py) 🏋️‍♂️ ## Metrics on test set 📋 diff --git a/model_cards/mrm8488/bert-spanish-cased-finetuned-ner/README.md b/model_cards/mrm8488/bert-spanish-cased-finetuned-ner/README.md index 67465c9ea847..4468b57f978d 100644 --- a/model_cards/mrm8488/bert-spanish-cased-finetuned-ner/README.md +++ b/model_cards/mrm8488/bert-spanish-cased-finetuned-ner/README.md @@ -19,7 +19,7 @@ I preprocessed the dataset and split it as train / dev (80/20) | Dev | 2.2 K | -- [Fine-tune on NER script provided by Huggingface](https://github.com/huggingface/transformers/blob/master/examples/token-classification/run_ner.py) +- [Fine-tune on NER script provided by Huggingface](https://github.com/huggingface/transformers/blob/master/examples/token-classification/run_ner_old.py) - Labels covered: diff --git a/model_cards/mrm8488/bert-spanish-cased-finetuned-pos-syntax/README.md b/model_cards/mrm8488/bert-spanish-cased-finetuned-pos-syntax/README.md index 266906a532b4..54bb61e2b2ad 100644 --- a/model_cards/mrm8488/bert-spanish-cased-finetuned-pos-syntax/README.md +++ b/model_cards/mrm8488/bert-spanish-cased-finetuned-pos-syntax/README.md @@ -11,7 +11,7 @@ This model is a fine-tuned version of the Spanish BERT [(BETO)](https://github.c - [Dataset: CONLL Corpora ES](https://www.kaggle.com/nltkdata/conll-corpora) -#### [Fine-tune script on NER dataset provided by Huggingface](https://github.com/huggingface/transformers/blob/master/examples/token-classification/run_ner.py) +#### [Fine-tune script on NER dataset provided by Huggingface](https://github.com/huggingface/transformers/blob/master/examples/token-classification/run_ner_old.py) #### 21 Syntax annotations (Labels) covered: diff --git a/model_cards/mrm8488/bert-spanish-cased-finetuned-pos/README.md b/model_cards/mrm8488/bert-spanish-cased-finetuned-pos/README.md index e1827e4effa6..356dd0f5ab93 100644 --- a/model_cards/mrm8488/bert-spanish-cased-finetuned-pos/README.md +++ b/model_cards/mrm8488/bert-spanish-cased-finetuned-pos/README.md @@ -19,7 +19,7 @@ I preprocessed the dataset and split it as train / dev (80/20) | Dev | 50 K | -- [Fine-tune on NER script provided by Huggingface](https://github.com/huggingface/transformers/blob/master/examples/token-classification/run_ner.py) +- [Fine-tune on NER script provided by Huggingface](https://github.com/huggingface/transformers/blob/master/examples/token-classification/run_ner_old.py) - **60** Labels covered: diff --git a/model_cards/mrm8488/distilbert-base-multi-cased-finetuned-typo-detection/README.md b/model_cards/mrm8488/distilbert-base-multi-cased-finetuned-typo-detection/README.md index 354a25df84e7..009bc1522c38 100644 --- a/model_cards/mrm8488/distilbert-base-multi-cased-finetuned-typo-detection/README.md +++ b/model_cards/mrm8488/distilbert-base-multi-cased-finetuned-typo-detection/README.md @@ -11,7 +11,7 @@ thumbnail: - Dataset: [GitHub Typo Corpus](https://github.com/mhagiwara/github-typo-corpus) 📚 for 15 languages -- [Fine-tune script on NER dataset provided by Huggingface](https://github.com/huggingface/transformers/blob/master/examples/token-classification/run_ner.py) 🏋️‍♂️ +- [Fine-tune script on NER dataset provided by Huggingface](https://github.com/huggingface/transformers/blob/master/examples/token-classification/run_ner_old.py) 🏋️‍♂️ ## Metrics on test set 📋 diff --git a/model_cards/savasy/bert-base-turkish-ner-cased/README.md b/model_cards/savasy/bert-base-turkish-ner-cased/README.md index 815ec0a95717..079575985cb1 100644 --- a/model_cards/savasy/bert-base-turkish-ner-cased/README.md +++ b/model_cards/savasy/bert-base-turkish-ner-cased/README.md @@ -32,7 +32,7 @@ export SEED=1 ``` Then run pre-training: ``` -python3 run_ner.py --data_dir ./tr-data3 \ +python3 run_ner_old.py --data_dir ./tr-data3 \ --model_type bert \ --labels ./tr-data/labels.txt \ --model_name_or_path $BERT_MODEL \ diff --git a/tests/fixtures/tests_samples/conll/sample.json b/tests/fixtures/tests_samples/conll/sample.json new file mode 100644 index 000000000000..0bc42a92fe8c --- /dev/null +++ b/tests/fixtures/tests_samples/conll/sample.json @@ -0,0 +1,10 @@ +{"words": ["He", "was", "the", "27th", "pitcher", "used", "by", "the", "Angels", "this", "season", ",", "tying", "a", "major-league", "record", "."], "ner": ["O", "O", "O", "O", "O", "O", "O", "O", "B-ORG", "O", "O", "O", "O", "O", "O", "O", "O"]} +{"words": ["CHICAGO", "AT", "ATLANTA"], "ner": ["B-ORG", "O", "B-LOC"]} +{"words": ["President", "Bill", "Clinton", "earlier", "this", "month", "invoked", "special", "powers", "to", "appoint", "Fowler", "during", "the", "congressional", "recess", "because", "the", "Senate", "delayed", "confirming", "his", "nomination", "."], "ner": ["O", "B-PER", "I-PER", "O", "O", "O", "O", "O", "O", "O", "O", "B-PER", "O", "O", "O", "O", "O", "O", "B-ORG", "O", "O", "O", "O", "O"]} +{"words": ["goals", "for", ",", "goals", "against", ",", "points", ")", "."], "ner": ["O", "O", "O", "O", "O", "O", "O", "O", "O"]} +{"words": ["\"", "It", "is", "one", "step", "short", "of", "an", "emergency", "situation", ",", "\"", "a", "police", "spokesman", "said", "via", "telephone", "from", "a", "command", "post", "in", "the", "bush", "."], "ner": ["O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O"]} +{"words": ["U.S.", "Ambassador", "Myles", "Frechette", "applauded", "the", "move", ",", "saying", "it", "could", "prompt", "the", "Clinton", "administration", "to", "remove", "Colombia", "from", "a", "list", "of", "outcast", "nations", "that", "have", "failed", "to", "cooperate", "in", "U.S.", "counternarcotics", "efforts", "."], "ner": ["B-LOC", "O", "B-PER", "I-PER", "O", "O", "O", "O", "O", "O", "O", "O", "O", "B-PER", "O", "O", "O", "B-LOC", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "B-LOC", "O", "O", "O"]} +{"words": ["Halftime"], "ner": ["O"]} +{"words": ["It", "has", "manufacturing", "plants", "in", "San", "Diego", ";", "Creedmoor", ",", "N.C.", ";", "Hampshire", ",", "England", ";", "and", "Tijuana", ",", "Mexico", ",", "and", "distributes", "its", "prodcuts", "in", "more", "than", "120", "countries", "."], "ner": ["O", "O", "O", "O", "O", "B-LOC", "I-LOC", "O", "B-LOC", "O", "B-LOC", "O", "B-LOC", "O", "B-LOC", "O", "O", "B-LOC", "O", "B-LOC", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O"]} +{"words": ["Scotland", "manager", "Craig", "Brown", "said", "on", "Thursday", ":", "\"", "I", "'ve", "watched", "Duncan", "Ferguson", "in", "action", "twice", "recently", "and", "he", "'s", "bang", "in", "form", "."], "ner": ["B-LOC", "O", "B-PER", "I-PER", "O", "O", "O", "O", "O", "O", "O", "O", "B-PER", "I-PER", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O"]} +{"words": ["Clinton", "flew", "in", "by", "helicopter", "from", "Michigan", "City", ",", "Indiana", ",", "after", "ending", "a", "four-day", ",", "559-mile", "trip", "aboard", "a", "campaign", "train", "from", "Washington", "."], "ner": ["B-PER", "O", "O", "O", "O", "O", "B-LOC", "I-LOC", "O", "B-LOC", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "B-LOC", "O"]} \ No newline at end of file