Commit f05f63c

`PartialState().local_main_process_first()` when map in examples (#1926)

* Use `PartialState().local_main_process_first()` when mapping datasets in the examples

* Allow loading results from the `datasets` cache (drop `load_from_cache_file=False`)

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
qgallouedec authored Aug 14, 2024
1 parent 54f806b commit f05f63c
Showing 18 changed files with 115 additions and 86 deletions.
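
For reference, a minimal sketch of the pattern this commit applies in each example script (the dataset name, the process function, and num_proc=4 are placeholder choices for illustration, not taken from the diff):

from accelerate import PartialState
from datasets import load_dataset


def process(row):
    # Placeholder transform; the real example scripts apply chat templates here.
    row["text"] = row["text"].lower()
    return row


ds = load_dataset("imdb", split="train")

# Before: every process re-ran the map, and load_from_cache_file=False bypassed the cache.
# ds = ds.map(process, load_from_cache_file=False, num_proc=4)

# After: the local main process enters the block first, runs the map, and writes the
# Arrow cache; the other local processes enter afterwards and reuse the cached result,
# because load_from_cache_file is left at its default.
with PartialState().local_main_process_first():
    ds = ds.map(process, num_proc=4)

Dropping load_from_cache_file=False is what makes the context manager worthwhile: with caching disabled, the non-main processes would wait for the main process and then redo the same work anyway.
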
6 changes: 1 addition & 5 deletions examples/datasets/anthropic_hh.py
@@ -79,11 +79,7 @@ def process(row):
row["prompt"] = row["chosen"][0]["content"]
return row

ds = ds.map(
process,
load_from_cache_file=False,
num_proc=args.dataset_num_proc,
)
ds = ds.map(process, num_proc=args.dataset_num_proc)
if args.push_to_hub:
revisions = ["main"] if args.update_main_revision else []
revisions.append(args.revision)
9 changes: 2 additions & 7 deletions examples/datasets/sentiment_descriptiveness.py
@@ -109,7 +109,7 @@ def filter(row):
return True

print("=== Before filtering ===", ds)
ds = ds.filter(filter, load_from_cache_file=False, num_proc=args.dataset_num_proc)
ds = ds.filter(filter, num_proc=args.dataset_num_proc)
print("=== After filtering ===", ds)

# here we simply take the preferred sample as the chosen one and the first non-preferred sample as the rejected one
@@ -146,12 +146,7 @@ def process(row):
assert chosen_sample != rejected_sample
return row

ds = ds.map(
process,
batched=True,
load_from_cache_file=False,
num_proc=args.dataset_num_proc,
)
ds = ds.map(process, batched=True, num_proc=args.dataset_num_proc)
for key in ds: # reorder columns
ds[key] = ds[key].select_columns(["prompt", "chosen", "rejected"])
if args.push_to_hub:
12 changes: 2 additions & 10 deletions examples/datasets/tldr_preference.py
@@ -76,11 +76,7 @@ def process(row):
row["rejected"] = [{"role": "user", "content": row["prompt"]}, {"role": "assistant", "content": rejected}]
return row

ds = ds.map(
process,
load_from_cache_file=False,
num_proc=args.dataset_num_proc,
)
ds = ds.map(process, num_proc=args.dataset_num_proc)
for key in ds: # reorder columns
ds[key] = ds[key].select_columns(
["prompt", "chosen", "rejected", "info", "summaries", "choice", "worker", "batch", "split", "extra"]
@@ -145,11 +141,7 @@ def sft_process(row):
]
return row

sft_ds = sft_ds.map(
sft_process,
load_from_cache_file=False,
num_proc=args.dataset_num_proc,
)
sft_ds = sft_ds.map(sft_process, num_proc=args.dataset_num_proc)
for key in sft_ds: # reorder columns
sft_ds[key] = sft_ds[key].select_columns(["prompt", "messages", "id", "subreddit", "title", "post", "summary"])
if args.push_to_hub:
6 changes: 1 addition & 5 deletions examples/datasets/tokenize_ds.py
@@ -38,9 +38,5 @@ def process(row):
row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False)
return row

ds = ds.map(
process,
load_from_cache_file=False,
num_proc=args.dataset_num_proc,
)
ds = ds.map(process, num_proc=args.dataset_num_proc)
print(ds["train"][0]["chosen"])
7 changes: 4 additions & 3 deletions examples/scripts/bco.py
@@ -183,17 +183,18 @@ def mean_pooling(model_output, attention_mask):
if tokenizer.chat_template is None:
model, tokenizer = setup_chat_format(model, tokenizer)

# Load the dataset
dataset = build_helpfulness_dataset(script_args.llm_name, num_proc=bco_args.dataset_num_proc)

# Apply chat template
def format_dataset(example):
example["prompt"] = tokenizer.apply_chat_template(
example["prompt"], tokenize=False, add_generation_prompt=True
)
return example

# Compute that only on the main process for faster data processing.
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
# Load the dataset
dataset = build_helpfulness_dataset(script_args.llm_name, num_proc=bco_args.dataset_num_proc)
formatted_dataset = dataset.map(format_dataset, batched=False, num_proc=bco_args.dataset_num_proc)

accelerator = Accelerator()
11 changes: 6 additions & 5 deletions examples/scripts/cpo.py
@@ -54,6 +54,7 @@

from dataclasses import dataclass, field

from accelerate import PartialState
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser

@@ -99,11 +100,11 @@ def process(row):
row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False)
return row

ds = ds.map(
process,
load_from_cache_file=False,
num_proc=cpo_args.dataset_num_proc,
)
# Compute that only on the main process for faster data processing.
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
ds = ds.map(process, num_proc=cpo_args.dataset_num_proc)

train_dataset = ds["train"]
eval_dataset = ds["test"]

12 changes: 6 additions & 6 deletions examples/scripts/dpo.py
@@ -70,7 +70,7 @@
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from accelerate import PartialState
from trl import (
DPOConfig,
DPOTrainer,
@@ -159,11 +159,11 @@ def process(row):
row["rejected"] = tokenizer.apply_chat_template([row["rejected"][-1]], tokenize=False)
return row

ds = ds.map(
process,
load_from_cache_file=False,
num_proc=training_args.dataset_num_proc,
)
# Compute that only on the main process for faster data processing.
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
ds = ds.map(process, num_proc=training_args.dataset_num_proc)

train_dataset = ds[args.dataset_train_split]
eval_dataset = ds[args.dataset_test_split]

3 changes: 3 additions & 0 deletions examples/scripts/dpo_visual.py
@@ -149,8 +149,11 @@ def process(row):
row["rejected"] = processor.apply_chat_template(row["rejected"], tokenize=False)
return row

# Compute that only on the main process for faster data processing.
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
ds = ds.map(process, num_proc=training_args.dataset_num_proc)

train_dataset = ds[args.dataset_train_split]
eval_dataset = ds[args.dataset_test_split]

6 changes: 5 additions & 1 deletion examples/scripts/kto.py
@@ -55,6 +55,7 @@

from dataclasses import dataclass

from accelerate import PartialState
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser

@@ -102,7 +103,10 @@ def format_dataset(example):
example["completion"] = tokenizer.apply_chat_template(example["completion"], tokenize=False)
return example

formatted_dataset = dataset.map(format_dataset, num_proc=kto_args.dataset_num_proc)
# Compute that only on the main process for faster data processing.
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
formatted_dataset = dataset.map(format_dataset, num_proc=kto_args.dataset_num_proc)

# Initialize the KTO trainer
kto_trainer = KTOTrainer(
19 changes: 12 additions & 7 deletions examples/scripts/online_dpo.py
@@ -1,6 +1,7 @@
from dataclasses import dataclass
from typing import Optional

from accelerate import PartialState
from datasets import load_dataset
from transformers import (
AutoModelForCausalLM,
@@ -74,7 +75,6 @@ def tokenize(element):
remove_columns=dataset.column_names,
batched=True,
num_proc=num_proc,
load_from_cache_file=False,
)


@@ -105,13 +105,18 @@ def tokenize(element):
for key in raw_datasets:
raw_datasets[key] = raw_datasets[key].select(range(1024))
train_dataset = raw_datasets[args.dataset_train_split]
train_dataset = prepare_dataset(train_dataset, tokenizer, args.dataset_text_field, config.dataset_num_proc)

if args.dataset_test_split is not None:
eval_dataset = raw_datasets[args.dataset_test_split]
eval_dataset = prepare_dataset(eval_dataset, tokenizer, args.dataset_text_field, config.dataset_num_proc)
else:
eval_dataset = None
# Compute that only on the main process for faster data processing.
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
train_dataset = prepare_dataset(train_dataset, tokenizer, args.dataset_text_field, config.dataset_num_proc)

if args.dataset_test_split is not None:
eval_dataset = raw_datasets[args.dataset_test_split]
eval_dataset = prepare_dataset(eval_dataset, tokenizer, args.dataset_text_field, config.dataset_num_proc)
else:
eval_dataset = None

################
# Training
################
11 changes: 6 additions & 5 deletions examples/scripts/orpo.py
@@ -54,6 +54,7 @@

from dataclasses import dataclass, field

from accelerate import PartialState
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser

@@ -100,11 +101,11 @@ def process(row):
row["rejected"] = tokenizer.apply_chat_template([row["rejected"][-1]], tokenize=False)
return row

ds = ds.map(
process,
load_from_cache_file=False,
num_prc=orpo_args.dataset_num_proc,
)
# Compute that only on the main process for faster data processing.
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
ds = ds.map(process, num_proc=orpo_args.dataset_num_proc)

train_dataset = ds["train"]
eval_dataset = ds["test"]

14 changes: 9 additions & 5 deletions examples/scripts/ppo.py
@@ -19,7 +19,7 @@
from typing import Optional

import torch
from accelerate import Accelerator
from accelerate import Accelerator, PartialState
from datasets import load_dataset
from peft import LoraConfig
from tqdm import tqdm
@@ -53,11 +53,14 @@ class ScriptArguments:

trl_model_class = AutoModelForCausalLMWithValueHead if not args.use_seq2seq else AutoModelForSeq2SeqLMWithValueHead

tokenizer = AutoTokenizer.from_pretrained(ppo_config.model_name)
tokenizer.pad_token = tokenizer.eos_token


# Below is an example function to build the dataset. In our case, we use the IMDB dataset
# from the `datasets` library. One should customize this function to train the model on
# its own dataset.
def build_dataset(config, query_dataset, input_min_text_length=2, input_max_text_length=8):
def build_dataset(query_dataset, input_min_text_length=2, input_max_text_length=8):
"""
Build dataset for training. This builds the dataset from `load_dataset`, one should
customize this function to train the model on its own dataset.
@@ -70,8 +73,6 @@ def build_dataset(config, query_dataset, input_min_text_length=2, input_max_text
dataloader (`torch.utils.data.DataLoader`):
The dataloader for the dataset.
"""
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
tokenizer.pad_token = tokenizer.eos_token
# load imdb with datasets
ds = load_dataset(query_dataset, split="train")
ds = ds.rename_columns({"text": "review"})
@@ -90,7 +91,10 @@ def tokenize(sample):


# We retrieve the dataloader by calling the `build_dataset` function.
dataset = build_dataset(ppo_config, ppo_config.query_dataset)
# Compute that only on the main process for faster data processing.
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
dataset = build_dataset(ppo_config.query_dataset)


def collator(data):
12 changes: 9 additions & 3 deletions examples/scripts/ppo/ppo.py
@@ -1,5 +1,6 @@
import shutil

from accelerate import PartialState
from datasets import load_dataset
from transformers import (
AutoModelForCausalLM,
@@ -93,10 +94,15 @@ def tokenize(element):
tokenize,
batched=True,
remove_columns=dataset.column_names,
load_from_cache_file=False,
num_proc=config.dataset_num_proc,
)

# Compute that only on the main process for faster data processing.
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
train_dataset = prepare_dataset(train_dataset, tokenizer)
eval_dataset = prepare_dataset(eval_dataset, tokenizer)

################
# Training
################
@@ -107,8 +113,8 @@ def tokenize(element):
ref_policy=ref_policy,
reward_model=reward_model,
value_model=value_model,
train_dataset=prepare_dataset(train_dataset, tokenizer),
eval_dataset=prepare_dataset(eval_dataset, tokenizer),
train_dataset=train_dataset,
eval_dataset=eval_dataset,
)
trainer.train()
trainer.save_model(config.output_dir)
15 changes: 10 additions & 5 deletions examples/scripts/ppo/ppo_tldr.py
@@ -1,5 +1,6 @@
import shutil

from accelerate import PartialState
from datasets import load_dataset
from transformers import (
AutoModelForCausalLM,
@@ -101,11 +102,15 @@ def tokenize(element):
num_proc=config.dataset_num_proc,
)

train_dataset = prepare_dataset(train_dataset, tokenizer)
eval_dataset = prepare_dataset(eval_dataset, tokenizer)
# filtering
train_dataset = train_dataset.filter(lambda x: x["lengths"] <= 512, num_proc=config.dataset_num_proc)
eval_dataset = eval_dataset.filter(lambda x: x["lengths"] <= 512, num_proc=config.dataset_num_proc)
# Compute that only on the main process for faster data processing.
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
train_dataset = prepare_dataset(train_dataset, tokenizer)
eval_dataset = prepare_dataset(eval_dataset, tokenizer)
# filtering
train_dataset = train_dataset.filter(lambda x: x["lengths"] <= 512, num_proc=config.dataset_num_proc)
eval_dataset = eval_dataset.filter(lambda x: x["lengths"] <= 512, num_proc=config.dataset_num_proc)

assert train_dataset[0]["input_ids"][-1] != tokenizer.eos_token_id, "The last token should not be an EOS token"
################
# Training
6 changes: 5 additions & 1 deletion examples/scripts/ppo_multi_adapter.py
@@ -15,6 +15,7 @@
from typing import Optional

import torch
from accelerate import PartialState
from datasets import load_dataset
from peft import LoraConfig
from tqdm import tqdm
@@ -95,7 +96,10 @@ def tokenize(example):

tokenizer.pad_token = tokenizer.eos_token

dataset = create_and_prepare_dataset(tokenizer, script_args.dataset_num_proc)
# Compute that only on the main process for faster data processing.
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
dataset = create_and_prepare_dataset(tokenizer, script_args.dataset_num_proc)


def collator(data):