-
Notifications
You must be signed in to change notification settings - Fork 27.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
New run_clm script #8105
New run_clm script #8105
Changes from 5 commits
1e6ab48
c8f0977
c2f7910
29a8f64
c265b35
47dfb03
7821071
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,323 @@ | ||||||||||||||||||
# coding=utf-8 | ||||||||||||||||||
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. | ||||||||||||||||||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. | ||||||||||||||||||
# | ||||||||||||||||||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||||||||||||||||||
# you may not use this file except in compliance with the License. | ||||||||||||||||||
# You may obtain a copy of the License at | ||||||||||||||||||
# | ||||||||||||||||||
# http://www.apache.org/licenses/LICENSE-2.0 | ||||||||||||||||||
# | ||||||||||||||||||
# Unless required by applicable law or agreed to in writing, software | ||||||||||||||||||
# distributed under the License is distributed on an "AS IS" BASIS, | ||||||||||||||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||||||||||||||
# See the License for the specific language governing permissions and | ||||||||||||||||||
# limitations under the License. | ||||||||||||||||||
""" | ||||||||||||||||||
Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL) on a text file or a dataset. | ||||||||||||||||||
""" | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. From experience, users will understand that only GPT, GPT-2 and CTRL are supported by that script. I would put
Suggested change
But that might be a bit too much. Maybe adding a README would be simpler. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not AutomodelWithLMHead, just CausalLM, but I can add that. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe the link is overkill, I just had an issue with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Or this link? https://huggingface.co/models?filter=lm-head There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This works too but this shows checkpoints, whereas this script can also train from scratch so showing architectures would probably be better There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No this links shows all kinds of LM. The script will only work with a model that can be loaded with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (the other one is the deprecated one, will remove soon) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Great! |
||||||||||||||||||
|
||||||||||||||||||
|
||||||||||||||||||
import logging | ||||||||||||||||||
import math | ||||||||||||||||||
import os | ||||||||||||||||||
import sys | ||||||||||||||||||
from dataclasses import dataclass, field | ||||||||||||||||||
from typing import Optional | ||||||||||||||||||
|
||||||||||||||||||
from datasets import load_dataset | ||||||||||||||||||
|
||||||||||||||||||
import transformers | ||||||||||||||||||
from transformers import ( | ||||||||||||||||||
CONFIG_MAPPING, | ||||||||||||||||||
MODEL_FOR_CAUSAL_LM_MAPPING, | ||||||||||||||||||
AutoConfig, | ||||||||||||||||||
AutoModelForCausalLM, | ||||||||||||||||||
AutoTokenizer, | ||||||||||||||||||
HfArgumentParser, | ||||||||||||||||||
Trainer, | ||||||||||||||||||
TrainingArguments, | ||||||||||||||||||
default_data_collator, | ||||||||||||||||||
set_seed, | ||||||||||||||||||
) | ||||||||||||||||||
from transformers.trainer_utils import is_main_process | ||||||||||||||||||
|
||||||||||||||||||
|
||||||||||||||||||
logger = logging.getLogger(__name__) | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No, for the script, we should use the regular one. @LysandreJik had a very long explanation of why that I don't remember. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The gist of it is that imo the In my opinion the control of logging in a user script should contain both: import logging
from transformers import logging as hf_logging
hf_logging.set_verbosity_xxx()
logger = logging.getLogger(__name__)
# then do stuff with the logger without worrying about the HF logging which has already been managed before
logger.warn("xxx") |
||||||||||||||||||
|
||||||||||||||||||
|
||||||||||||||||||
MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys()) | ||||||||||||||||||
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) | ||||||||||||||||||
|
||||||||||||||||||
|
||||||||||||||||||
@dataclass | ||||||||||||||||||
class ModelArguments: | ||||||||||||||||||
""" | ||||||||||||||||||
Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch. | ||||||||||||||||||
""" | ||||||||||||||||||
|
||||||||||||||||||
model_name_or_path: Optional[str] = field( | ||||||||||||||||||
default=None, | ||||||||||||||||||
metadata={ | ||||||||||||||||||
"help": "The model checkpoint for weights initialization. Leave None if you want to train a model from scratch." | ||||||||||||||||||
}, | ||||||||||||||||||
) | ||||||||||||||||||
model_type: Optional[str] = field( | ||||||||||||||||||
default=None, | ||||||||||||||||||
metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)}, | ||||||||||||||||||
) | ||||||||||||||||||
config_name: Optional[str] = field( | ||||||||||||||||||
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} | ||||||||||||||||||
) | ||||||||||||||||||
tokenizer_name: Optional[str] = field( | ||||||||||||||||||
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} | ||||||||||||||||||
) | ||||||||||||||||||
cache_dir: Optional[str] = field( | ||||||||||||||||||
default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"} | ||||||||||||||||||
) | ||||||||||||||||||
use_fast_tokenizer: bool = field( | ||||||||||||||||||
default=True, | ||||||||||||||||||
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, | ||||||||||||||||||
) | ||||||||||||||||||
|
||||||||||||||||||
|
||||||||||||||||||
@dataclass | ||||||||||||||||||
class DataTrainingArguments: | ||||||||||||||||||
""" | ||||||||||||||||||
Arguments pertaining to what data we are going to input our model for training and eval. | ||||||||||||||||||
""" | ||||||||||||||||||
|
||||||||||||||||||
dataset_name: Optional[str] = field( | ||||||||||||||||||
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} | ||||||||||||||||||
) | ||||||||||||||||||
dataset_config_name: Optional[str] = field( | ||||||||||||||||||
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} | ||||||||||||||||||
) | ||||||||||||||||||
train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."}) | ||||||||||||||||||
validation_file: Optional[str] = field( | ||||||||||||||||||
default=None, | ||||||||||||||||||
metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."}, | ||||||||||||||||||
) | ||||||||||||||||||
block_size: int = field( | ||||||||||||||||||
default=-1, | ||||||||||||||||||
metadata={ | ||||||||||||||||||
"help": "Optional input sequence length after tokenization." | ||||||||||||||||||
"The training dataset will be truncated in block of this size for training." | ||||||||||||||||||
"Default to the model max input length for single sentence inputs (take into account special tokens)." | ||||||||||||||||||
}, | ||||||||||||||||||
) | ||||||||||||||||||
overwrite_cache: bool = field( | ||||||||||||||||||
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} | ||||||||||||||||||
) | ||||||||||||||||||
|
||||||||||||||||||
def __post_init__(self): | ||||||||||||||||||
if self.dataset_name is None and self.train_file is None and self.validation_file is None: | ||||||||||||||||||
raise ValueError("Need either a dataset name or a training/validation file.") | ||||||||||||||||||
else: | ||||||||||||||||||
if self.train_file is not None: | ||||||||||||||||||
extension = self.train_file.split(".")[-1] | ||||||||||||||||||
assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file." | ||||||||||||||||||
if self.validation_file is not None: | ||||||||||||||||||
extension = self.validation_file.split(".")[-1] | ||||||||||||||||||
assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file." | ||||||||||||||||||
|
||||||||||||||||||
|
||||||||||||||||||
def main(): | ||||||||||||||||||
# See all possible arguments in src/transformers/training_args.py | ||||||||||||||||||
# or by passing the --help flag to this script. | ||||||||||||||||||
# We now keep distinct sets of args, for a cleaner separation of concerns. | ||||||||||||||||||
|
||||||||||||||||||
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) | ||||||||||||||||||
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): | ||||||||||||||||||
# If we pass only one argument to the script and it's the path to a json file, | ||||||||||||||||||
# let's parse it to get our arguments. | ||||||||||||||||||
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) | ||||||||||||||||||
else: | ||||||||||||||||||
model_args, data_args, training_args = parser.parse_args_into_dataclasses() | ||||||||||||||||||
|
||||||||||||||||||
if ( | ||||||||||||||||||
os.path.exists(training_args.output_dir) | ||||||||||||||||||
and os.listdir(training_args.output_dir) | ||||||||||||||||||
and training_args.do_train | ||||||||||||||||||
and not training_args.overwrite_output_dir | ||||||||||||||||||
): | ||||||||||||||||||
raise ValueError( | ||||||||||||||||||
f"Output directory ({training_args.output_dir}) already exists and is not empty." | ||||||||||||||||||
"Use --overwrite_output_dir to overcome." | ||||||||||||||||||
) | ||||||||||||||||||
|
||||||||||||||||||
# Setup logging | ||||||||||||||||||
logging.basicConfig( | ||||||||||||||||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", | ||||||||||||||||||
datefmt="%m/%d/%Y %H:%M:%S", | ||||||||||||||||||
level=logging.INFO if is_main_process(training_args.local_rank) else logging.WARN, | ||||||||||||||||||
) | ||||||||||||||||||
|
||||||||||||||||||
# Log on each process the small summary: | ||||||||||||||||||
logger.warning( | ||||||||||||||||||
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" | ||||||||||||||||||
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" | ||||||||||||||||||
) | ||||||||||||||||||
# Set the verbosity to info of the Transformers logger (on main process only): | ||||||||||||||||||
if is_main_process(training_args.local_rank): | ||||||||||||||||||
transformers.utils.logging.set_verbosity_info() | ||||||||||||||||||
logger.info("Training/evaluation parameters %s", training_args) | ||||||||||||||||||
Comment on lines
+164
to
+171
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's exactly what I'm talking about :) |
||||||||||||||||||
|
||||||||||||||||||
# Set seed before initializing model. | ||||||||||||||||||
set_seed(training_args.seed) | ||||||||||||||||||
|
||||||||||||||||||
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) | ||||||||||||||||||
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ (the dataset will be downloaded automatically from the datasets Hub | ||||||||||||||||||
# | ||||||||||||||||||
# For CSV/JSON files, this script will use the column called 'text' or the first column. You can easily tweak this | ||||||||||||||||||
# behavior (see below) | ||||||||||||||||||
# | ||||||||||||||||||
# In distributed training, the load_dataset function guarantee that only one local process can concurrently | ||||||||||||||||||
# download the dataset. | ||||||||||||||||||
if data_args.dataset_name is not None: | ||||||||||||||||||
# Downloading and loading a dataset from the hub. | ||||||||||||||||||
datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name) | ||||||||||||||||||
else: | ||||||||||||||||||
data_files = {} | ||||||||||||||||||
if data_args.train_file is not None: | ||||||||||||||||||
data_files["train"] = data_args.train_file | ||||||||||||||||||
if data_args.validation_file is not None: | ||||||||||||||||||
data_files["validation"] = data_args.train_file | ||||||||||||||||||
extension = data_args.train_file.split(".")[-1] | ||||||||||||||||||
if extension == "txt": | ||||||||||||||||||
extension = "text" | ||||||||||||||||||
datasets = load_dataset(extension, data_files=data_files) | ||||||||||||||||||
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at | ||||||||||||||||||
# https://huggingface.co/docs/datasets/loading_datasets.html. | ||||||||||||||||||
|
||||||||||||||||||
# Load pretrained model and tokenizer | ||||||||||||||||||
# | ||||||||||||||||||
# Distributed training: | ||||||||||||||||||
# The .from_pretrained methods guarantee that only one local process can concurrently | ||||||||||||||||||
# download model & vocab. | ||||||||||||||||||
|
||||||||||||||||||
if model_args.config_name: | ||||||||||||||||||
config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir) | ||||||||||||||||||
elif model_args.model_name_or_path: | ||||||||||||||||||
config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir) | ||||||||||||||||||
else: | ||||||||||||||||||
config = CONFIG_MAPPING[model_args.model_type]() | ||||||||||||||||||
logger.warning("You are instantiating a new config instance from scratch.") | ||||||||||||||||||
|
||||||||||||||||||
if model_args.tokenizer_name: | ||||||||||||||||||
tokenizer = AutoTokenizer.from_pretrained( | ||||||||||||||||||
model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer | ||||||||||||||||||
) | ||||||||||||||||||
elif model_args.model_name_or_path: | ||||||||||||||||||
tokenizer = AutoTokenizer.from_pretrained( | ||||||||||||||||||
model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer | ||||||||||||||||||
) | ||||||||||||||||||
else: | ||||||||||||||||||
raise ValueError( | ||||||||||||||||||
"You are instantiating a new tokenizer from scratch. This is not supported by this script." | ||||||||||||||||||
"You can do it from another script, save it, and load it from here, using --tokenizer_name." | ||||||||||||||||||
) | ||||||||||||||||||
|
||||||||||||||||||
if model_args.model_name_or_path: | ||||||||||||||||||
model = AutoModelForCausalLM.from_pretrained( | ||||||||||||||||||
model_args.model_name_or_path, | ||||||||||||||||||
from_tf=bool(".ckpt" in model_args.model_name_or_path), | ||||||||||||||||||
config=config, | ||||||||||||||||||
cache_dir=model_args.cache_dir, | ||||||||||||||||||
) | ||||||||||||||||||
else: | ||||||||||||||||||
logger.info("Training new model from scratch") | ||||||||||||||||||
model = AutoModelForCausalLM.from_config(config) | ||||||||||||||||||
|
||||||||||||||||||
model.resize_token_embeddings(len(tokenizer)) | ||||||||||||||||||
|
||||||||||||||||||
# Preprocessing the datasets. | ||||||||||||||||||
# First we tokenize all the texts. | ||||||||||||||||||
if training_args.do_train: | ||||||||||||||||||
column_names = datasets["train"].column_names | ||||||||||||||||||
else: | ||||||||||||||||||
column_names = datasets["validation"].column_names | ||||||||||||||||||
text_column_name = "text" if "text" in column_names else column_names[0] | ||||||||||||||||||
|
||||||||||||||||||
def tokenize_function(examples): | ||||||||||||||||||
return tokenizer(examples[text_column_name]) | ||||||||||||||||||
|
||||||||||||||||||
tokenized_datasets = datasets.map( | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In the two calls to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We could do the same thing to the |
||||||||||||||||||
tokenize_function, | ||||||||||||||||||
batched=True, | ||||||||||||||||||
remove_columns=[text_column_name], | ||||||||||||||||||
load_from_cache_file=not data_args.overwrite_cache, | ||||||||||||||||||
) | ||||||||||||||||||
|
||||||||||||||||||
if data_args.block_size <= 0: | ||||||||||||||||||
block_size = tokenizer.max_len | ||||||||||||||||||
else: | ||||||||||||||||||
block_size = min(data_args.block_size, tokenizer.max_len) | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we print a warning here to tell the user their There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I can add that. |
||||||||||||||||||
|
||||||||||||||||||
# Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size. | ||||||||||||||||||
def group_texts(examples): | ||||||||||||||||||
# Concatenate all texts. | ||||||||||||||||||
concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()} | ||||||||||||||||||
total_length = len(concatenated_examples[list(examples.keys())[0]]) | ||||||||||||||||||
# We drop the small remainder, we could add padding if the model supported it instead of this drop, you can | ||||||||||||||||||
# customize this part to your needs. | ||||||||||||||||||
total_length = (total_length // block_size) * block_size | ||||||||||||||||||
# Split by chunks of max_len. | ||||||||||||||||||
result = { | ||||||||||||||||||
k: [t[i : i + block_size] for i in range(0, total_length, block_size)] | ||||||||||||||||||
for k, t in concatenated_examples.items() | ||||||||||||||||||
} | ||||||||||||||||||
result["labels"] = result["input_ids"].copy() | ||||||||||||||||||
return result | ||||||||||||||||||
|
||||||||||||||||||
# Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder | ||||||||||||||||||
# for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower | ||||||||||||||||||
# to preprocess. | ||||||||||||||||||
lm_datasets = tokenized_datasets.map(group_texts, batched=True, load_from_cache_file=not data_args.overwrite_cache) | ||||||||||||||||||
|
||||||||||||||||||
# Initialize our Trainer | ||||||||||||||||||
trainer = Trainer( | ||||||||||||||||||
model=model, | ||||||||||||||||||
args=training_args, | ||||||||||||||||||
train_dataset=lm_datasets["train"] if training_args.do_train else None, | ||||||||||||||||||
eval_dataset=lm_datasets["validation"] if training_args.do_eval else None, | ||||||||||||||||||
tokenizer=tokenizer, | ||||||||||||||||||
# Data collator will default to DataCollatorWithPadding, so we change it. | ||||||||||||||||||
data_collator=default_data_collator, | ||||||||||||||||||
) | ||||||||||||||||||
|
||||||||||||||||||
# Training | ||||||||||||||||||
if training_args.do_train: | ||||||||||||||||||
trainer.train( | ||||||||||||||||||
model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None | ||||||||||||||||||
) | ||||||||||||||||||
trainer.save_model() # Saves the tokenizer too for easy upload | ||||||||||||||||||
|
||||||||||||||||||
# Evaluation | ||||||||||||||||||
results = {} | ||||||||||||||||||
if training_args.do_eval: | ||||||||||||||||||
logger.info("*** Evaluate ***") | ||||||||||||||||||
|
||||||||||||||||||
eval_output = trainer.evaluate() | ||||||||||||||||||
|
||||||||||||||||||
perplexity = math.exp(eval_output["eval_loss"]) | ||||||||||||||||||
results["perplexity"] = perplexity | ||||||||||||||||||
|
||||||||||||||||||
output_eval_file = os.path.join(training_args.output_dir, "eval_results_lm.txt") | ||||||||||||||||||
if trainer.is_world_process_zero(): | ||||||||||||||||||
with open(output_eval_file, "w") as writer: | ||||||||||||||||||
logger.info("***** Eval results *****") | ||||||||||||||||||
for key, value in results.items(): | ||||||||||||||||||
logger.info(f" {key} = {value}") | ||||||||||||||||||
writer.write(f"{key} = {value}\n") | ||||||||||||||||||
|
||||||||||||||||||
return results | ||||||||||||||||||
|
||||||||||||||||||
|
||||||||||||||||||
def _mp_fn(index): | ||||||||||||||||||
# For xla_spawn (TPUs) | ||||||||||||||||||
main() | ||||||||||||||||||
|
||||||||||||||||||
|
||||||||||||||||||
if __name__ == "__main__": | ||||||||||||||||||
main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need the copyright to Google AI and NVIDIA? Are there some snippets taken from their codebases?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, it's a bad copy paste.