From 1e6ab4820bfe2c687dcd9fe82496dc20f2b080d9 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger Date: Tue, 27 Oct 2020 17:08:09 -0400 Subject: [PATCH 1/7] New run_clm script --- examples/language-modeling/run_clm.py | 311 ++++++++++++++++++++++++++ examples/test_examples.py | 33 +++ 2 files changed, 344 insertions(+) create mode 100644 examples/language-modeling/run_clm.py diff --git a/examples/language-modeling/run_clm.py b/examples/language-modeling/run_clm.py new file mode 100644 index 00000000000000..80537a8f465602 --- /dev/null +++ b/examples/language-modeling/run_clm.py @@ -0,0 +1,311 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL) on a text file or a dataset. +""" + + +import logging +import math +import os +import sys +from dataclasses import dataclass, field +from glob import glob +from typing import Optional + +import numpy as np +from datasets import load_dataset, load_metric + +import transformers +from transformers import ( + CONFIG_MAPPING, + MODEL_FOR_CAUSAL_LM_MAPPING, + AutoConfig, + AutoModelForCausalLM, + AutoTokenizer, + HfArgumentParser, + PreTrainedTokenizer, + Trainer, + TrainingArguments, + default_data_collator, + set_seed, +) + +from transformers.trainer_utils import is_main_process + + +logger = logging.getLogger(__name__) + + +MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys()) +MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch. + """ + + model_name_or_path: Optional[str] = field( + default=None, + metadata={ + "help": "The model checkpoint for weights initialization. Leave None if you want to train a model from scratch." + }, + ) + model_type: Optional[str] = field( + default=None, + metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)}, + ) + config_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} + ) + tokenizer_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} + ) + cache_dir: Optional[str] = field( + default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"} + ) + use_fast_tokenizer: bool = field( + default=True, + metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, + ) + + +@dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. 
+ """ + dataset_name: Optional[str] = field( + default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} + ) + dataset_config_name: Optional[str] = field( + default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} + ) + train_file: Optional[str] = field( + default=None, metadata={"help": "The input training data file (a text file)."} + ) + validation_file: Optional[str] = field( + default=None, + metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."}, + ) + block_size: int = field( + default=-1, + metadata={ + "help": "Optional input sequence length after tokenization." + "The training dataset will be truncated in block of this size for training." + "Default to the model max input length for single sentence inputs (take into account special tokens)." + }, + ) + overwrite_cache: bool = field( + default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} + ) + + def __post_init__(self): + if self.dataset_name is None and self.train_file is None and self.validation_file is None: + raise ValueError("Need either a dataset name or a training/validation file.") + else: + if self.train_file is not None: + extension = self.train_file.split(".")[-1] + assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file." + if self.validation_file is not None: + extension = self.validation_file.split(".")[-1] + assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file." + + +def main(): + # See all possible arguments in src/transformers/training_args.py + # or by passing the --help flag to this script. + # We now keep distinct sets of args, for a cleaner separation of concerns. + + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. + model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + if ( + os.path.exists(training_args.output_dir) + and os.listdir(training_args.output_dir) + and training_args.do_train + and not training_args.overwrite_output_dir + ): + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty." + "Use --overwrite_output_dir to overcome." + ) + + # Setup logging + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO if is_main_process(training_args.local_rank) else logging.WARN, + ) + + # Log on each process the small summary: + logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" + + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" + ) + # Set the verbosity to info of the Transformers logger (on main process only): + if is_main_process(training_args.local_rank): + transformers.utils.logging.set_verbosity_info() + logger.info("Training/evaluation parameters %s", training_args) + + # Set seed before initializing model. 
+ set_seed(training_args.seed) + + # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) + # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub + # + # For CSV/JSON files, this script will use the column called 'text' or the first column. You can easily tweak this + # behavior (see below) + # + # In distributed training, the load_dataset function guarantee that only one local process can concurrently + # download the dataset. + if data_args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name) + else: + data_files = {} + if data_args.train_file is not None: + data_files["train"] = data_args.train_file + if data_args.validation_file is not None: + data_files["validation"] = data_args.train_file + extension = data_args.train_file.split(".")[-1] + if extension == "txt": + extension = "text" + datasets = load_dataset(extension, data_files=data_files) + # See more about loading any type of standard or custom dataset at + # https://huggingface.co/docs/datasets/loading_datasets.html. + + # Load pretrained model and tokenizer + # + # Distributed training: + # The .from_pretrained methods guarantee that only one local process can concurrently + # download model & vocab. + + if model_args.config_name: + config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir) + elif model_args.model_name_or_path: + config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir) + else: + config = CONFIG_MAPPING[model_args.model_type]() + logger.warning("You are instantiating a new config instance from scratch.") + + if model_args.tokenizer_name: + tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer) + elif model_args.model_name_or_path: + tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer) + else: + raise ValueError( + "You are instantiating a new tokenizer from scratch. This is not supported by this script." + "You can do it from another script, save it, and load it from here, using --tokenizer_name." + ) + + if model_args.model_name_or_path: + model = AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, + from_tf=bool(".ckpt" in model_args.model_name_or_path), + config=config, + cache_dir=model_args.cache_dir, + ) + else: + logger.info("Training new model from scratch") + model = AutoModelForCausalLM.from_config(config) + + model.resize_token_embeddings(len(tokenizer)) + + # Preprocessing the datasets. + # First we tokenize all the texts. + if training_args.do_train: + column_names = datasets["train"].column_names + else: + column_names = datasets["validation"].column_names + text_column_name = "text" if "text" in column_names else column_names[0] + + def tokenize_function(examples): + return tokenizer(examples[text_column_name]) + + tokenized_datasets = datasets.map(tokenize_function, batched=True, remove_columns=[text_column_name], load_from_cache_file=not data_args.overwrite_cache) + + if data_args.block_size <= 0: + block_size = tokenizer.max_len + else: + block_size = min(data_args.block_size, tokenizer.max_len) + + def group_texts(examples): + # Concatenate all texts. 
+ concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()} + total_length = len(concatenated_examples[list(examples.keys())[0]]) + # We drop the small remainder + total_length = (total_length // block_size) * block_size + # Split by chunks of max_len + result = {k: [t[i: i+block_size] for i in range(0, total_length, block_size)] for k, t in concatenated_examples.items()} + result["labels"] = result["input_ids"].copy() + return result + + lm_datasets = tokenized_datasets.map(group_texts, batched=True, load_from_cache_file=not data_args.overwrite_cache) + + # Initialize our Trainer + trainer = Trainer( + model=model, + args=training_args, + train_dataset=lm_datasets["train"] if training_args.do_train else None, + eval_dataset=lm_datasets["validation"] if training_args.do_eval else None, + tokenizer=tokenizer, + # Data collator will default to DataCollatorWithPadding, so we change it. + data_collator=default_data_collator, + ) + + # Training + if training_args.do_train: + trainer.train( + model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None + ) + trainer.save_model() # Saves the tokenizer too for easy upload + + # Evaluation + results = {} + if training_args.do_eval: + logger.info("*** Evaluate ***") + + eval_output = trainer.evaluate() + + perplexity = math.exp(eval_output["eval_loss"]) + results["perplexity"] = perplexity + + output_eval_file = os.path.join(training_args.output_dir, "eval_results_lm.txt") + if trainer.is_world_process_zero(): + with open(output_eval_file, "w") as writer: + logger.info("***** Eval results *****") + for key, value in results.items(): + logger.info(f" {key} = {value}") + writer.write(f"{key} = {value}\n") + + return results + + +def _mp_fn(index): + # For xla_spawn (TPUs) + main() + + +if __name__ == "__main__": + main() diff --git a/examples/test_examples.py b/examples/test_examples.py index b17a73dc5123a3..b73b58745583b1 100644 --- a/examples/test_examples.py +++ b/examples/test_examples.py @@ -34,6 +34,7 @@ if SRC_DIRS is not None: + import run_clm import run_generation import run_glue import run_language_modeling @@ -127,6 +128,38 @@ def test_run_pl_glue(self): # for k, v in result.items(): # self.assertGreaterEqual(v, 0.75, f"({k})") # + + def test_run_clm(self): + stream_handler = logging.StreamHandler(sys.stdout) + logger.addHandler(stream_handler) + + tmp_dir = self.get_auto_remove_tmp_dir() + testargs = f""" + run_clm.py + --model_name_or_path distilgpt2 + --train_file ./tests/fixtures/sample_text.txt + --validation_file ./tests/fixtures/sample_text.txt + --do_train + --do_eval + --block_size 128 + --per_device_train_batch_size 5 + --per_device_eval_batch_size 5 + --num_train_epochs 2 + --output_dir {tmp_dir} + --overwrite_output_dir + --prediction_loss_only + """.split() + + if torch.cuda.device_count() > 1: + # Skipping because there are not enough batches to train the model + would need a drop_last to work. 
+ return + + if torch_device != "cuda": + testargs.append("--no_cuda") + + with patch.object(sys, "argv", testargs): + result = run_clm.main() + self.assertLess(result["perplexity"], 100) def test_run_language_modeling(self): stream_handler = logging.StreamHandler(sys.stdout) From c8f097781ef3f23e0f5fa565c1f4906fe512a012 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger Date: Tue, 27 Oct 2020 17:08:25 -0400 Subject: [PATCH 2/7] Formatting --- examples/language-modeling/run_clm.py | 30 ++++++++++++++++++--------- examples/test_examples.py | 4 ++-- 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/examples/language-modeling/run_clm.py b/examples/language-modeling/run_clm.py index 80537a8f465602..85bb354d463cf8 100644 --- a/examples/language-modeling/run_clm.py +++ b/examples/language-modeling/run_clm.py @@ -43,7 +43,6 @@ default_data_collator, set_seed, ) - from transformers.trainer_utils import is_main_process @@ -90,15 +89,14 @@ class DataTrainingArguments: """ Arguments pertaining to what data we are going to input our model for training and eval. """ + dataset_name: Optional[str] = field( default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} ) dataset_config_name: Optional[str] = field( default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} ) - train_file: Optional[str] = field( - default=None, metadata={"help": "The input training data file (a text file)."} - ) + train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."}) validation_file: Optional[str] = field( default=None, metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."}, @@ -114,7 +112,7 @@ class DataTrainingArguments: overwrite_cache: bool = field( default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} ) - + def __post_init__(self): if self.dataset_name is None and self.train_file is None and self.validation_file is None: raise ValueError("Need either a dataset name or a training/validation file.") @@ -210,9 +208,13 @@ def main(): logger.warning("You are instantiating a new config instance from scratch.") if model_args.tokenizer_name: - tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer) + tokenizer = AutoTokenizer.from_pretrained( + model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer + ) elif model_args.model_name_or_path: - tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer) + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer + ) else: raise ValueError( "You are instantiating a new tokenizer from scratch. This is not supported by this script." 
@@ -243,7 +245,12 @@ def main(): def tokenize_function(examples): return tokenizer(examples[text_column_name]) - tokenized_datasets = datasets.map(tokenize_function, batched=True, remove_columns=[text_column_name], load_from_cache_file=not data_args.overwrite_cache) + tokenized_datasets = datasets.map( + tokenize_function, + batched=True, + remove_columns=[text_column_name], + load_from_cache_file=not data_args.overwrite_cache, + ) if data_args.block_size <= 0: block_size = tokenizer.max_len @@ -257,10 +264,13 @@ def group_texts(examples): # We drop the small remainder total_length = (total_length // block_size) * block_size # Split by chunks of max_len - result = {k: [t[i: i+block_size] for i in range(0, total_length, block_size)] for k, t in concatenated_examples.items()} + result = { + k: [t[i : i + block_size] for i in range(0, total_length, block_size)] + for k, t in concatenated_examples.items() + } result["labels"] = result["input_ids"].copy() return result - + lm_datasets = tokenized_datasets.map(group_texts, batched=True, load_from_cache_file=not data_args.overwrite_cache) # Initialize our Trainer diff --git a/examples/test_examples.py b/examples/test_examples.py index b73b58745583b1..f5252cdd63d41f 100644 --- a/examples/test_examples.py +++ b/examples/test_examples.py @@ -128,7 +128,7 @@ def test_run_pl_glue(self): # for k, v in result.items(): # self.assertGreaterEqual(v, 0.75, f"({k})") # - + def test_run_clm(self): stream_handler = logging.StreamHandler(sys.stdout) logger.addHandler(stream_handler) @@ -149,7 +149,7 @@ def test_run_clm(self): --overwrite_output_dir --prediction_loss_only """.split() - + if torch.cuda.device_count() > 1: # Skipping because there are not enough batches to train the model + would need a drop_last to work. return From c2f7910bdf0da1d6d553c9d287ddfe4678a89b5f Mon Sep 17 00:00:00 2001 From: Sylvain Gugger Date: Tue, 27 Oct 2020 17:12:13 -0400 Subject: [PATCH 3/7] More comments --- examples/language-modeling/run_clm.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/examples/language-modeling/run_clm.py b/examples/language-modeling/run_clm.py index 85bb354d463cf8..57b21a5101a804 100644 --- a/examples/language-modeling/run_clm.py +++ b/examples/language-modeling/run_clm.py @@ -257,13 +257,15 @@ def tokenize_function(examples): else: block_size = min(data_args.block_size, tokenizer.max_len) + # Main function that will concatenate all texts from our dataset and generate chunks of block_size. def group_texts(examples): # Concatenate all texts. concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()} total_length = len(concatenated_examples[list(examples.keys())[0]]) - # We drop the small remainder + # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can + # customize this part to your needs. total_length = (total_length // block_size) * block_size - # Split by chunks of max_len + # Split by chunks of max_len. result = { k: [t[i : i + block_size] for i in range(0, total_length, block_size)] for k, t in concatenated_examples.items() @@ -271,6 +273,9 @@ def group_texts(examples): result["labels"] = result["input_ids"].copy() return result + # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder + # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower + # to preprocess. 
lm_datasets = tokenized_datasets.map(group_texts, batched=True, load_from_cache_file=not data_args.overwrite_cache) # Initialize our Trainer From 29a8f64f391c44503edde237c1741fdcffd171e6 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger Date: Tue, 27 Oct 2020 17:15:52 -0400 Subject: [PATCH 4/7] Remove unused imports --- examples/language-modeling/run_clm.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/examples/language-modeling/run_clm.py b/examples/language-modeling/run_clm.py index 57b21a5101a804..12f43d575efff5 100644 --- a/examples/language-modeling/run_clm.py +++ b/examples/language-modeling/run_clm.py @@ -23,11 +23,9 @@ import os import sys from dataclasses import dataclass, field -from glob import glob from typing import Optional -import numpy as np -from datasets import load_dataset, load_metric +from datasets import load_dataset import transformers from transformers import ( @@ -37,7 +35,6 @@ AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, - PreTrainedTokenizer, Trainer, TrainingArguments, default_data_collator, From c265b358de2bc19611809e0c273fd93d5d886c85 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Wed, 28 Oct 2020 09:41:43 -0400 Subject: [PATCH 5/7] Apply suggestions from code review Co-authored-by: Thomas Wolf --- examples/language-modeling/run_clm.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/language-modeling/run_clm.py b/examples/language-modeling/run_clm.py index 12f43d575efff5..686b1fdd59d1a2 100644 --- a/examples/language-modeling/run_clm.py +++ b/examples/language-modeling/run_clm.py @@ -167,7 +167,7 @@ def main(): set_seed(training_args.seed) # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) - # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub + # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ (the dataset will be downloaded automatically from the datasets Hub # # For CSV/JSON files, this script will use the column called 'text' or the first column. You can easily tweak this # behavior (see below) @@ -187,7 +187,7 @@ def main(): if extension == "txt": extension = "text" datasets = load_dataset(extension, data_files=data_files) - # See more about loading any type of standard or custom dataset at + # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at # https://huggingface.co/docs/datasets/loading_datasets.html. # Load pretrained model and tokenizer @@ -254,7 +254,7 @@ def tokenize_function(examples): else: block_size = min(data_args.block_size, tokenizer.max_len) - # Main function that will concatenate all texts from our dataset and generate chunks of block_size. + # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size. def group_texts(examples): # Concatenate all texts. 
concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()} From 47dfb039e20b332e5f9ff4e839f56fb9dff47f41 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger Date: Wed, 28 Oct 2020 10:27:21 -0400 Subject: [PATCH 6/7] Address review comments --- examples/language-modeling/run_clm.py | 36 +++++++++++++++++++----- examples/text-classification/run_glue.py | 3 +- 2 files changed, 30 insertions(+), 9 deletions(-) diff --git a/examples/language-modeling/run_clm.py b/examples/language-modeling/run_clm.py index 686b1fdd59d1a2..154f47f7222464 100644 --- a/examples/language-modeling/run_clm.py +++ b/examples/language-modeling/run_clm.py @@ -1,6 +1,5 @@ # coding=utf-8 -# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. -# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# Copyright 2020 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,9 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL) on a text file or a dataset. -""" +Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset. +Find the full list of model architectures that can be fine-tuned by this script on the documentation: +https://huggingface.co/transformers/model_doc/auto.html#transformers.AutoModelForCausalLM +""" +# You can also adapt this script on your own text classification task. Pointers for this are left as comments. import logging import math @@ -59,7 +61,8 @@ class ModelArguments: model_name_or_path: Optional[str] = field( default=None, metadata={ - "help": "The model checkpoint for weights initialization. Leave None if you want to train a model from scratch." + "help": "The model checkpoint for weights initialization." + "Don't set if you want to train a model from scratch." }, ) model_type: Optional[str] = field( @@ -109,6 +112,10 @@ class DataTrainingArguments: overwrite_cache: bool = field( default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} ) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the preprocessing."}, + ) def __post_init__(self): if self.dataset_name is None and self.train_file is None and self.validation_file is None: @@ -167,7 +174,8 @@ def main(): set_seed(training_args.seed) # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) - # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ (the dataset will be downloaded automatically from the datasets Hub + # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ + # (the dataset will be downloaded automatically from the datasets Hub # # For CSV/JSON files, this script will use the column called 'text' or the first column. 
You can easily tweak this # behavior (see below) @@ -245,6 +253,7 @@ def tokenize_function(examples): tokenized_datasets = datasets.map( tokenize_function, batched=True, + num_proc=data_args.preprocessing_num_workers, remove_columns=[text_column_name], load_from_cache_file=not data_args.overwrite_cache, ) @@ -252,6 +261,11 @@ def tokenize_function(examples): if data_args.block_size <= 0: block_size = tokenizer.max_len else: + if data_args.block_size > tokenizer.max_len: + logger.warn( + f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model" + f"({tokenizer.max_len}). Using block_size={tokenizer.max_len}." + ) block_size = min(data_args.block_size, tokenizer.max_len) # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size. @@ -273,7 +287,15 @@ def group_texts(examples): # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower # to preprocess. - lm_datasets = tokenized_datasets.map(group_texts, batched=True, load_from_cache_file=not data_args.overwrite_cache) + # + # To speed up this part, we use multiprocessing. See the documentation of the map method for more information: + # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map + lm_datasets = tokenized_datasets.map( + group_texts, + batched=True, + num_proc=data_args.preprocessing_num_workers, + load_from_cache_file=not data_args.overwrite_cache, + ) # Initialize our Trainer trainer = Trainer( diff --git a/examples/text-classification/run_glue.py b/examples/text-classification/run_glue.py index d571f286aa6a09..8c7e2cedad725a 100644 --- a/examples/text-classification/run_glue.py +++ b/examples/text-classification/run_glue.py @@ -1,6 +1,5 @@ # coding=utf-8 -# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. -# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# Copyright 2020 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From 7821071073aca0a9c21e8b8004bffff5ac7c43b7 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger Date: Wed, 28 Oct 2020 10:30:14 -0400 Subject: [PATCH 7/7] Change link to the hub --- examples/language-modeling/run_clm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/language-modeling/run_clm.py b/examples/language-modeling/run_clm.py index 154f47f7222464..024427e257abc6 100644 --- a/examples/language-modeling/run_clm.py +++ b/examples/language-modeling/run_clm.py @@ -15,8 +15,8 @@ """ Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset. -Find the full list of model architectures that can be fine-tuned by this script on the documentation: -https://huggingface.co/transformers/model_doc/auto.html#transformers.AutoModelForCausalLM +Here is the full list of checkpoints on the hub that can be fine-tuned by this script: +https://huggingface.co/models?filter=causal-lm """ # You can also adapt this script on your own text classification task. Pointers for this are left as comments.
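
---

Usage note (not part of the patches above): a typical invocation of the new run_clm.py script added by this series might look like the following. The flags shown all come from the argument dataclasses and the new test in this series; the wikitext dataset name, its config name, the worker count, and the output directory are illustrative values only, and distilgpt2 matches the checkpoint used in test_run_clm.

    python examples/language-modeling/run_clm.py \
        --model_name_or_path distilgpt2 \
        --dataset_name wikitext \
        --dataset_config_name wikitext-2-raw-v1 \
        --do_train \
        --do_eval \
        --block_size 128 \
        --preprocessing_num_workers 4 \
        --output_dir /tmp/test-clm

Alternatively, --train_file and --validation_file can point at local csv, json, or txt files instead of a hub dataset. When --do_eval is passed, the evaluation perplexity is written to eval_results_lm.txt inside the output directory.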