From f05f63c1eaf57c0e77b5fd79d2163844efde6cc4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Wed, 14 Aug 2024 12:01:03 +0200 Subject: [PATCH] `PartialState().local_main_process_first()` when `map` in examples (#1926) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * `PartialState().local_main_process_first()` when map in examples * allow load from cache --------- Co-authored-by: Quentin Gallouédec --- examples/datasets/anthropic_hh.py | 6 +---- .../datasets/sentiment_descriptiveness.py | 9 ++----- examples/datasets/tldr_preference.py | 12 ++------- examples/datasets/tokenize_ds.py | 6 +---- examples/scripts/bco.py | 7 +++--- examples/scripts/cpo.py | 11 ++++---- examples/scripts/dpo.py | 12 ++++----- examples/scripts/dpo_visual.py | 3 +++ examples/scripts/kto.py | 6 ++++- examples/scripts/online_dpo.py | 19 ++++++++------ examples/scripts/orpo.py | 11 ++++---- examples/scripts/ppo.py | 14 +++++++---- examples/scripts/ppo/ppo.py | 12 ++++++--- examples/scripts/ppo/ppo_tldr.py | 15 +++++++---- examples/scripts/ppo_multi_adapter.py | 6 ++++- examples/scripts/reward_modeling.py | 25 +++++++++++-------- examples/scripts/rloo/rloo.py | 12 ++++++--- examples/scripts/rloo/rloo_tldr.py | 15 +++++++---- 18 files changed, 115 insertions(+), 86 deletions(-) diff --git a/examples/datasets/anthropic_hh.py b/examples/datasets/anthropic_hh.py index 6a753b6768..854391e209 100644 --- a/examples/datasets/anthropic_hh.py +++ b/examples/datasets/anthropic_hh.py @@ -79,11 +79,7 @@ def process(row): row["prompt"] = row["chosen"][0]["content"] return row - ds = ds.map( - process, - load_from_cache_file=False, - num_proc=args.dataset_num_proc, - ) + ds = ds.map(process, num_proc=args.dataset_num_proc) if args.push_to_hub: revisions = ["main"] if args.update_main_revision else [] revisions.append(args.revision) diff --git a/examples/datasets/sentiment_descriptiveness.py 
b/examples/datasets/sentiment_descriptiveness.py index b970842750..79b524d982 100644 --- a/examples/datasets/sentiment_descriptiveness.py +++ b/examples/datasets/sentiment_descriptiveness.py @@ -109,7 +109,7 @@ def filter(row): return True print("=== Before filtering ===", ds) - ds = ds.filter(filter, load_from_cache_file=False, num_proc=args.dataset_num_proc) + ds = ds.filter(filter, num_proc=args.dataset_num_proc) print("=== After filtering ===", ds) # here we simply take the preferred sample as the chosen one and the first non-preferred sample as the rejected one @@ -146,12 +146,7 @@ def process(row): assert chosen_sample != rejected_sample return row - ds = ds.map( - process, - batched=True, - load_from_cache_file=False, - num_proc=args.dataset_num_proc, - ) + ds = ds.map(process, batched=True, num_proc=args.dataset_num_proc) for key in ds: # reorder columns ds[key] = ds[key].select_columns(["prompt", "chosen", "rejected"]) if args.push_to_hub: diff --git a/examples/datasets/tldr_preference.py b/examples/datasets/tldr_preference.py index 29799429f2..5074823cec 100644 --- a/examples/datasets/tldr_preference.py +++ b/examples/datasets/tldr_preference.py @@ -76,11 +76,7 @@ def process(row): row["rejected"] = [{"role": "user", "content": row["prompt"]}, {"role": "assistant", "content": rejected}] return row - ds = ds.map( - process, - load_from_cache_file=False, - num_proc=args.dataset_num_proc, - ) + ds = ds.map(process, num_proc=args.dataset_num_proc) for key in ds: # reorder columns ds[key] = ds[key].select_columns( ["prompt", "chosen", "rejected", "info", "summaries", "choice", "worker", "batch", "split", "extra"] @@ -145,11 +141,7 @@ def sft_process(row): ] return row - sft_ds = sft_ds.map( - sft_process, - load_from_cache_file=False, - num_proc=args.dataset_num_proc, - ) + sft_ds = sft_ds.map(sft_process, num_proc=args.dataset_num_proc) for key in sft_ds: # reorder columns sft_ds[key] = sft_ds[key].select_columns(["prompt", "messages", "id", "subreddit", 
"title", "post", "summary"]) if args.push_to_hub: diff --git a/examples/datasets/tokenize_ds.py b/examples/datasets/tokenize_ds.py index 2f34b9ffc7..9351455a37 100644 --- a/examples/datasets/tokenize_ds.py +++ b/examples/datasets/tokenize_ds.py @@ -38,9 +38,5 @@ def process(row): row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False) return row - ds = ds.map( - process, - load_from_cache_file=False, - num_proc=args.dataset_num_proc, - ) + ds = ds.map(process, num_proc=args.dataset_num_proc) print(ds["train"][0]["chosen"]) diff --git a/examples/scripts/bco.py b/examples/scripts/bco.py index b8d8f78121..3d6b9ac453 100644 --- a/examples/scripts/bco.py +++ b/examples/scripts/bco.py @@ -183,9 +183,6 @@ def mean_pooling(model_output, attention_mask): if tokenizer.chat_template is None: model, tokenizer = setup_chat_format(model, tokenizer) - # Load the dataset - dataset = build_helpfulness_dataset(script_args.llm_name, num_proc=bco_args.dataset_num_proc) - # Apply chat template def format_dataset(example): example["prompt"] = tokenizer.apply_chat_template( @@ -193,7 +190,11 @@ def format_dataset(example): ) return example + # Compute that only on the main process for faster data processing. 
+ # see: https://github.com/huggingface/trl/pull/1255 with PartialState().local_main_process_first(): + # Load the dataset + dataset = build_helpfulness_dataset(script_args.llm_name, num_proc=bco_args.dataset_num_proc) formatted_dataset = dataset.map(format_dataset, batched=False, num_proc=bco_args.dataset_num_proc) accelerator = Accelerator() diff --git a/examples/scripts/cpo.py b/examples/scripts/cpo.py index f8cfd6f7c4..aefc7bc309 100644 --- a/examples/scripts/cpo.py +++ b/examples/scripts/cpo.py @@ -54,6 +54,7 @@ from dataclasses import dataclass, field +from accelerate import PartialState from datasets import load_dataset from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser @@ -99,11 +100,11 @@ def process(row): row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False) return row - ds = ds.map( - process, - load_from_cache_file=False, - num_proc=cpo_args.dataset_num_proc, - ) + # Compute that only on the main process for faster data processing. + # see: https://github.com/huggingface/trl/pull/1255 + with PartialState().local_main_process_first(): + ds = ds.map(process, num_proc=cpo_args.dataset_num_proc) + train_dataset = ds["train"] eval_dataset = ds["test"] diff --git a/examples/scripts/dpo.py b/examples/scripts/dpo.py index 1564434674..2704df97f1 100644 --- a/examples/scripts/dpo.py +++ b/examples/scripts/dpo.py @@ -70,7 +70,7 @@ import torch from datasets import load_dataset from transformers import AutoModelForCausalLM, AutoTokenizer - +from accelerate import PartialState from trl import ( DPOConfig, DPOTrainer, @@ -159,11 +159,11 @@ def process(row): row["rejected"] = tokenizer.apply_chat_template([row["rejected"][-1]], tokenize=False) return row - ds = ds.map( - process, - load_from_cache_file=False, - num_proc=training_args.dataset_num_proc, - ) + # Compute that only on the main process for faster data processing. 
+ # see: https://github.com/huggingface/trl/pull/1255 + with PartialState().local_main_process_first(): + ds = ds.map(process, num_proc=training_args.dataset_num_proc) + train_dataset = ds[args.dataset_train_split] eval_dataset = ds[args.dataset_test_split] diff --git a/examples/scripts/dpo_visual.py b/examples/scripts/dpo_visual.py index 9665b5ca69..1d4687f5fb 100644 --- a/examples/scripts/dpo_visual.py +++ b/examples/scripts/dpo_visual.py @@ -149,8 +149,11 @@ def process(row): row["rejected"] = processor.apply_chat_template(row["rejected"], tokenize=False) return row + # Compute that only on the main process for faster data processing. + # see: https://github.com/huggingface/trl/pull/1255 with PartialState().local_main_process_first(): ds = ds.map(process, num_proc=training_args.dataset_num_proc) + train_dataset = ds[args.dataset_train_split] eval_dataset = ds[args.dataset_test_split] diff --git a/examples/scripts/kto.py b/examples/scripts/kto.py index 40dbb8a596..e05ab40a02 100644 --- a/examples/scripts/kto.py +++ b/examples/scripts/kto.py @@ -55,6 +55,7 @@ from dataclasses import dataclass +from accelerate import PartialState from datasets import load_dataset from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser @@ -102,7 +103,10 @@ def format_dataset(example): example["completion"] = tokenizer.apply_chat_template(example["completion"], tokenize=False) return example - formatted_dataset = dataset.map(format_dataset, num_proc=kto_args.dataset_num_proc) + # Compute that only on the main process for faster data processing. 
+ # see: https://github.com/huggingface/trl/pull/1255 + with PartialState().local_main_process_first(): + formatted_dataset = dataset.map(format_dataset, num_proc=kto_args.dataset_num_proc) # Initialize the KTO trainer kto_trainer = KTOTrainer( diff --git a/examples/scripts/online_dpo.py b/examples/scripts/online_dpo.py index b39f6e8c95..2d229d45fc 100644 --- a/examples/scripts/online_dpo.py +++ b/examples/scripts/online_dpo.py @@ -1,6 +1,7 @@ from dataclasses import dataclass from typing import Optional +from accelerate import PartialState from datasets import load_dataset from transformers import ( AutoModelForCausalLM, @@ -74,7 +75,6 @@ def tokenize(element): remove_columns=dataset.column_names, batched=True, num_proc=num_proc, - load_from_cache_file=False, ) @@ -105,13 +105,18 @@ def tokenize(element): for key in raw_datasets: raw_datasets[key] = raw_datasets[key].select(range(1024)) train_dataset = raw_datasets[args.dataset_train_split] - train_dataset = prepare_dataset(train_dataset, tokenizer, args.dataset_text_field, config.dataset_num_proc) - if args.dataset_test_split is not None: - eval_dataset = raw_datasets[args.dataset_test_split] - eval_dataset = prepare_dataset(eval_dataset, tokenizer, args.dataset_text_field, config.dataset_num_proc) - else: - eval_dataset = None + # Compute that only on the main process for faster data processing. 
+ # see: https://github.com/huggingface/trl/pull/1255 + with PartialState().local_main_process_first(): + train_dataset = prepare_dataset(train_dataset, tokenizer, args.dataset_text_field, config.dataset_num_proc) + + if args.dataset_test_split is not None: + eval_dataset = raw_datasets[args.dataset_test_split] + eval_dataset = prepare_dataset(eval_dataset, tokenizer, args.dataset_text_field, config.dataset_num_proc) + else: + eval_dataset = None + ################ # Training ################ diff --git a/examples/scripts/orpo.py b/examples/scripts/orpo.py index d3769acaf8..076dc036e4 100644 --- a/examples/scripts/orpo.py +++ b/examples/scripts/orpo.py @@ -54,6 +54,7 @@ from dataclasses import dataclass, field +from accelerate import PartialState from datasets import load_dataset from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser @@ -100,11 +101,11 @@ def process(row): row["rejected"] = tokenizer.apply_chat_template([row["rejected"][-1]], tokenize=False) return row - ds = ds.map( - process, - load_from_cache_file=False, - num_proc=orpo_args.dataset_num_proc, - ) + # Compute that only on the main process for faster data processing. 
+ # see: https://github.com/huggingface/trl/pull/1255 + with PartialState().local_main_process_first(): + ds = ds.map(process, num_proc=orpo_args.dataset_num_proc) + train_dataset = ds["train"] eval_dataset = ds["test"] diff --git a/examples/scripts/ppo.py b/examples/scripts/ppo.py index 88b7d20fa1..8788c562c5 100644 --- a/examples/scripts/ppo.py +++ b/examples/scripts/ppo.py @@ -19,7 +19,7 @@ from typing import Optional import torch -from accelerate import Accelerator +from accelerate import Accelerator, PartialState from datasets import load_dataset from peft import LoraConfig from tqdm import tqdm @@ -53,11 +53,14 @@ class ScriptArguments: trl_model_class = AutoModelForCausalLMWithValueHead if not args.use_seq2seq else AutoModelForSeq2SeqLMWithValueHead +tokenizer = AutoTokenizer.from_pretrained(ppo_config.model_name) +tokenizer.pad_token = tokenizer.eos_token + # Below is an example function to build the dataset. In our case, we use the IMDB dataset # from the `datasets` library. One should customize this function to train the model on # its own dataset. -def build_dataset(config, query_dataset, input_min_text_length=2, input_max_text_length=8): +def build_dataset(query_dataset, input_min_text_length=2, input_max_text_length=8): """ Build dataset for training. This builds the dataset from `load_dataset`, one should customize this function to train the model on its own dataset. @@ -70,8 +73,6 @@ def build_dataset(config, query_dataset, input_min_text_length=2, input_max_text dataloader (`torch.utils.data.DataLoader`): The dataloader for the dataset. """ - tokenizer = AutoTokenizer.from_pretrained(config.model_name) - tokenizer.pad_token = tokenizer.eos_token # load imdb with datasets ds = load_dataset(query_dataset, split="train") ds = ds.rename_columns({"text": "review"}) @@ -90,7 +91,10 @@ def tokenize(sample): # We retrieve the dataloader by calling the `build_dataset` function. 
-dataset = build_dataset(ppo_config, ppo_config.query_dataset) +# Compute that only on the main process for faster data processing. +# see: https://github.com/huggingface/trl/pull/1255 +with PartialState().local_main_process_first(): + dataset = build_dataset(ppo_config.query_dataset) def collator(data): diff --git a/examples/scripts/ppo/ppo.py b/examples/scripts/ppo/ppo.py index 864ed72e1d..d75af791b1 100644 --- a/examples/scripts/ppo/ppo.py +++ b/examples/scripts/ppo/ppo.py @@ -1,5 +1,6 @@ import shutil +from accelerate import PartialState from datasets import load_dataset from transformers import ( AutoModelForCausalLM, @@ -93,10 +94,15 @@ def tokenize(element): tokenize, batched=True, remove_columns=dataset.column_names, - load_from_cache_file=False, num_proc=config.dataset_num_proc, ) + # Compute that only on the main process for faster data processing. + # see: https://github.com/huggingface/trl/pull/1255 + with PartialState().local_main_process_first(): + train_dataset = prepare_dataset(train_dataset, tokenizer) + eval_dataset = prepare_dataset(eval_dataset, tokenizer) + ################ # Training ################ @@ -107,8 +113,8 @@ def tokenize(element): ref_policy=ref_policy, reward_model=reward_model, value_model=value_model, - train_dataset=prepare_dataset(train_dataset, tokenizer), - eval_dataset=prepare_dataset(eval_dataset, tokenizer), + train_dataset=train_dataset, + eval_dataset=eval_dataset, ) trainer.train() trainer.save_model(config.output_dir) diff --git a/examples/scripts/ppo/ppo_tldr.py b/examples/scripts/ppo/ppo_tldr.py index 2e1f325b30..146194e3d8 100644 --- a/examples/scripts/ppo/ppo_tldr.py +++ b/examples/scripts/ppo/ppo_tldr.py @@ -1,5 +1,6 @@ import shutil +from accelerate import PartialState from datasets import load_dataset from transformers import ( AutoModelForCausalLM, @@ -101,11 +102,15 @@ def tokenize(element): num_proc=config.dataset_num_proc, ) - train_dataset = prepare_dataset(train_dataset, tokenizer) - 
eval_dataset = prepare_dataset(eval_dataset, tokenizer) - # filtering - train_dataset = train_dataset.filter(lambda x: x["lengths"] <= 512, num_proc=config.dataset_num_proc) - eval_dataset = eval_dataset.filter(lambda x: x["lengths"] <= 512, num_proc=config.dataset_num_proc) + # Compute that only on the main process for faster data processing. + # see: https://github.com/huggingface/trl/pull/1255 + with PartialState().local_main_process_first(): + train_dataset = prepare_dataset(train_dataset, tokenizer) + eval_dataset = prepare_dataset(eval_dataset, tokenizer) + # filtering + train_dataset = train_dataset.filter(lambda x: x["lengths"] <= 512, num_proc=config.dataset_num_proc) + eval_dataset = eval_dataset.filter(lambda x: x["lengths"] <= 512, num_proc=config.dataset_num_proc) + assert train_dataset[0]["input_ids"][-1] != tokenizer.eos_token_id, "The last token should not be an EOS token" ################ # Training diff --git a/examples/scripts/ppo_multi_adapter.py b/examples/scripts/ppo_multi_adapter.py index 466d42bbd1..d9f916a8ee 100644 --- a/examples/scripts/ppo_multi_adapter.py +++ b/examples/scripts/ppo_multi_adapter.py @@ -15,6 +15,7 @@ from typing import Optional import torch +from accelerate import PartialState from datasets import load_dataset from peft import LoraConfig from tqdm import tqdm @@ -95,7 +96,10 @@ def tokenize(example): tokenizer.pad_token = tokenizer.eos_token -dataset = create_and_prepare_dataset(tokenizer, script_args.dataset_num_proc) +# Compute that only on the main process for faster data processing. 
+# see: https://github.com/huggingface/trl/pull/1255 +with PartialState().local_main_process_first(): + dataset = create_and_prepare_dataset(tokenizer, script_args.dataset_num_proc) def collator(data): diff --git a/examples/scripts/reward_modeling.py b/examples/scripts/reward_modeling.py index c33f297080..5b36de765c 100644 --- a/examples/scripts/reward_modeling.py +++ b/examples/scripts/reward_modeling.py @@ -31,6 +31,7 @@ import warnings import torch +from accelerate import PartialState from datasets import load_dataset from tqdm import tqdm from transformers import AutoModelForSequenceClassification, AutoTokenizer, HfArgumentParser @@ -99,16 +100,20 @@ def preprocess_function(examples): return new_examples # Preprocess the dataset and filter out examples that are longer than args.max_length - raw_datasets = raw_datasets.map( - preprocess_function, - batched=True, - num_proc=config.dataset_num_proc, - ) - raw_datasets = raw_datasets.filter( - lambda x: len(x["input_ids_chosen"]) <= config.max_length - and len(x["input_ids_rejected"]) <= config.max_length, - num_proc=config.dataset_num_proc, - ) + # Compute that only on the main process for faster data processing. 
+ # see: https://github.com/huggingface/trl/pull/1255 + with PartialState().local_main_process_first(): + raw_datasets = raw_datasets.map( + preprocess_function, + batched=True, + num_proc=config.dataset_num_proc, + ) + raw_datasets = raw_datasets.filter( + lambda x: len(x["input_ids_chosen"]) <= config.max_length + and len(x["input_ids_rejected"]) <= config.max_length, + num_proc=config.dataset_num_proc, + ) + train_dataset = raw_datasets["train"] eval_dataset = raw_datasets["test"] diff --git a/examples/scripts/rloo/rloo.py b/examples/scripts/rloo/rloo.py index 2173789b9c..d1fbd45e83 100644 --- a/examples/scripts/rloo/rloo.py +++ b/examples/scripts/rloo/rloo.py @@ -1,5 +1,6 @@ import shutil +from accelerate import PartialState from datasets import load_dataset from transformers import ( AutoModelForCausalLM, @@ -92,10 +93,15 @@ def tokenize(element): tokenize, batched=True, remove_columns=dataset.column_names, - load_from_cache_file=False, num_proc=config.dataset_num_proc, ) + # Compute that only on the main process for faster data processing. 
+ # see: https://github.com/huggingface/trl/pull/1255 + with PartialState().local_main_process_first(): + train_dataset = prepare_dataset(train_dataset, tokenizer) + eval_dataset = prepare_dataset(eval_dataset, tokenizer) + ################ # Training ################ @@ -105,8 +111,8 @@ def tokenize(element): policy=policy, ref_policy=ref_policy, reward_model=reward_model, - train_dataset=prepare_dataset(train_dataset, tokenizer), - eval_dataset=prepare_dataset(eval_dataset, tokenizer), + train_dataset=train_dataset, + eval_dataset=eval_dataset, ) trainer.train() trainer.save_model(config.output_dir) diff --git a/examples/scripts/rloo/rloo_tldr.py b/examples/scripts/rloo/rloo_tldr.py index 600c5a1a04..f57072197c 100644 --- a/examples/scripts/rloo/rloo_tldr.py +++ b/examples/scripts/rloo/rloo_tldr.py @@ -1,5 +1,6 @@ import shutil +from accelerate import PartialState from datasets import load_dataset from transformers import ( AutoModelForCausalLM, @@ -100,11 +101,15 @@ def tokenize(element): num_proc=config.dataset_num_proc, ) - train_dataset = prepare_dataset(train_dataset, tokenizer) - eval_dataset = prepare_dataset(eval_dataset, tokenizer) - # filtering - train_dataset = train_dataset.filter(lambda x: x["lengths"] <= 512, num_proc=config.dataset_num_proc) - eval_dataset = eval_dataset.filter(lambda x: x["lengths"] <= 512, num_proc=config.dataset_num_proc) + # Compute that only on the main process for faster data processing. 
+ # see: https://github.com/huggingface/trl/pull/1255 + with PartialState().local_main_process_first(): + train_dataset = prepare_dataset(train_dataset, tokenizer) + eval_dataset = prepare_dataset(eval_dataset, tokenizer) + # filtering + train_dataset = train_dataset.filter(lambda x: x["lengths"] <= 512, num_proc=config.dataset_num_proc) + eval_dataset = eval_dataset.filter(lambda x: x["lengths"] <= 512, num_proc=config.dataset_num_proc) + assert train_dataset[0]["input_ids"][-1] != tokenizer.eos_token_id, "The last token should not be an EOS token" ################ # Training