PartialState().local_main_process_first() when map in examples #1926

Merged · 2 commits · Aug 14, 2024
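This PR wraps the dataset preprocessing (`map` / `filter`) in the TRL example scripts in a `PartialState().local_main_process_first()` context and drops `load_from_cache_file=False`. Under a multi-GPU launch, the local main process runs the preprocessing once and writes the `datasets` Arrow cache; the other ranks wait at the context manager and then reuse that cache instead of tokenizing the data again. Below is a minimal sketch of the pattern, not taken from the diff itself: the dataset name, the `process` body, and the `num_proc` value are placeholders.

```python
from accelerate import PartialState
from datasets import load_dataset


def process(row):
    # Placeholder preprocessing; the real example scripts apply chat templates here.
    row["text"] = row["text"].lower()
    return row


ds = load_dataset("imdb", split="train")

# The local main process enters the block first and writes the map results to
# the datasets cache; the remaining processes wait at the context manager, then
# run the same map call and load the cached Arrow files instead of recomputing.
with PartialState().local_main_process_first():
    ds = ds.map(process, num_proc=4)
```

The speedup only shows up under a multi-process launch, e.g. `accelerate launch --num_processes 8 examples/scripts/dpo.py ...`; in a single-process run the context manager is effectively a no-op.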
6 changes: 1 addition & 5 deletions examples/datasets/anthropic_hh.py
@@ -79,11 +79,7 @@ def process(row):
row["prompt"] = row["chosen"][0]["content"]
return row

ds = ds.map(
process,
load_from_cache_file=False,
num_proc=args.dataset_num_proc,
)
ds = ds.map(process, num_proc=args.dataset_num_proc)
if args.push_to_hub:
revisions = ["main"] if args.update_main_revision else []
revisions.append(args.revision)
9 changes: 2 additions & 7 deletions examples/datasets/sentiment_descriptiveness.py
@@ -109,7 +109,7 @@ def filter(row):
return True

print("=== Before filtering ===", ds)
ds = ds.filter(filter, load_from_cache_file=False, num_proc=args.dataset_num_proc)
ds = ds.filter(filter, num_proc=args.dataset_num_proc)
print("=== After filtering ===", ds)

# here we simply take the preferred sample as the chosen one and the first non-preferred sample as the rejected one
@@ -146,12 +146,7 @@ def process(row):
assert chosen_sample != rejected_sample
return row

ds = ds.map(
process,
batched=True,
load_from_cache_file=False,
num_proc=args.dataset_num_proc,
)
ds = ds.map(process, batched=True, num_proc=args.dataset_num_proc)
for key in ds: # reorder columns
ds[key] = ds[key].select_columns(["prompt", "chosen", "rejected"])
if args.push_to_hub:
12 changes: 2 additions & 10 deletions examples/datasets/tldr_preference.py
@@ -76,11 +76,7 @@ def process(row):
row["rejected"] = [{"role": "user", "content": row["prompt"]}, {"role": "assistant", "content": rejected}]
return row

ds = ds.map(
process,
load_from_cache_file=False,
num_proc=args.dataset_num_proc,
)
ds = ds.map(process, num_proc=args.dataset_num_proc)
for key in ds: # reorder columns
ds[key] = ds[key].select_columns(
["prompt", "chosen", "rejected", "info", "summaries", "choice", "worker", "batch", "split", "extra"]
@@ -145,11 +141,7 @@ def sft_process(row):
]
return row

sft_ds = sft_ds.map(
sft_process,
load_from_cache_file=False,
num_proc=args.dataset_num_proc,
)
sft_ds = sft_ds.map(sft_process, num_proc=args.dataset_num_proc)
for key in sft_ds: # reorder columns
sft_ds[key] = sft_ds[key].select_columns(["prompt", "messages", "id", "subreddit", "title", "post", "summary"])
if args.push_to_hub:
6 changes: 1 addition & 5 deletions examples/datasets/tokenize_ds.py
@@ -38,9 +38,5 @@ def process(row):
row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False)
return row

ds = ds.map(
process,
load_from_cache_file=False,
num_proc=args.dataset_num_proc,
)
ds = ds.map(process, num_proc=args.dataset_num_proc)
print(ds["train"][0]["chosen"])
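The four `examples/datasets/*` builders above only drop `load_from_cache_file=False` and do not add the context manager, since they run as standalone single-process scripts. Removing the argument restores the default caching behaviour of `datasets.Dataset.map` (reuse the cache when caching is enabled), sketched here with placeholder inputs; the dataset name and `process` body are not from these scripts.

```python
from datasets import load_dataset


def process(row):
    # Stand-in for the process()/filter() functions in the scripts above.
    row["n_chars"] = len(row["text"])
    return row


ds = load_dataset("imdb", split="train")

# Default (load_from_cache_file left at its default): the first run computes
# the map and writes an Arrow cache file; identical re-runs reload that file.
ds = ds.map(process, num_proc=4)

# Previous behaviour in these scripts (load_from_cache_file=False): the map
# was recomputed on every run, ignoring any existing cache file.
# ds = ds.map(process, num_proc=4, load_from_cache_file=False)
```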
7 changes: 4 additions & 3 deletions examples/scripts/bco.py
@@ -183,17 +183,18 @@ def mean_pooling(model_output, attention_mask):
if tokenizer.chat_template is None:
model, tokenizer = setup_chat_format(model, tokenizer)

# Load the dataset
dataset = build_helpfulness_dataset(script_args.llm_name, num_proc=bco_args.dataset_num_proc)

# Apply chat template
def format_dataset(example):
example["prompt"] = tokenizer.apply_chat_template(
example["prompt"], tokenize=False, add_generation_prompt=True
)
return example

# Compute that only on the main process for faster data processing.
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
# Load the dataset
dataset = build_helpfulness_dataset(script_args.llm_name, num_proc=bco_args.dataset_num_proc)
formatted_dataset = dataset.map(format_dataset, batched=False, num_proc=bco_args.dataset_num_proc)

accelerator = Accelerator()
11 changes: 6 additions & 5 deletions examples/scripts/cpo.py
@@ -54,6 +54,7 @@

from dataclasses import dataclass, field

from accelerate import PartialState
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser

@@ -99,11 +100,11 @@ def process(row):
row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False)
return row

ds = ds.map(
process,
load_from_cache_file=False,
num_proc=cpo_args.dataset_num_proc,
)
# Compute that only on the main process for faster data processing.
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
ds = ds.map(process, num_proc=cpo_args.dataset_num_proc)

train_dataset = ds["train"]
eval_dataset = ds["test"]

12 changes: 6 additions & 6 deletions examples/scripts/dpo.py
@@ -70,7 +70,7 @@
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from accelerate import PartialState
from trl import (
DPOConfig,
DPOTrainer,
@@ -159,11 +159,11 @@ def process(row):
row["rejected"] = tokenizer.apply_chat_template([row["rejected"][-1]], tokenize=False)
return row

ds = ds.map(
process,
load_from_cache_file=False,
num_proc=training_args.dataset_num_proc,
)
# Compute that only on the main process for faster data processing.
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
ds = ds.map(process, num_proc=training_args.dataset_num_proc)

train_dataset = ds[args.dataset_train_split]
eval_dataset = ds[args.dataset_test_split]

3 changes: 3 additions & 0 deletions examples/scripts/dpo_visual.py
@@ -149,8 +149,11 @@ def process(row):
row["rejected"] = processor.apply_chat_template(row["rejected"], tokenize=False)
return row

# Compute that only on the main process for faster data processing.
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
ds = ds.map(process, num_proc=training_args.dataset_num_proc)

train_dataset = ds[args.dataset_train_split]
eval_dataset = ds[args.dataset_test_split]

6 changes: 5 additions & 1 deletion examples/scripts/kto.py
@@ -55,6 +55,7 @@

from dataclasses import dataclass

from accelerate import PartialState
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser

@@ -102,7 +103,10 @@ def format_dataset(example):
example["completion"] = tokenizer.apply_chat_template(example["completion"], tokenize=False)
return example

formatted_dataset = dataset.map(format_dataset, num_proc=kto_args.dataset_num_proc)
# Compute that only on the main process for faster data processing.
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
formatted_dataset = dataset.map(format_dataset, num_proc=kto_args.dataset_num_proc)

# Initialize the KTO trainer
kto_trainer = KTOTrainer(
19 changes: 12 additions & 7 deletions examples/scripts/online_dpo.py
@@ -1,6 +1,7 @@
from dataclasses import dataclass
from typing import Optional

from accelerate import PartialState
from datasets import load_dataset
from transformers import (
AutoModelForCausalLM,
@@ -74,7 +75,6 @@ def tokenize(element):
remove_columns=dataset.column_names,
batched=True,
num_proc=num_proc,
load_from_cache_file=False,
)


@@ -105,13 +105,18 @@ def tokenize(element):
for key in raw_datasets:
raw_datasets[key] = raw_datasets[key].select(range(1024))
train_dataset = raw_datasets[args.dataset_train_split]
train_dataset = prepare_dataset(train_dataset, tokenizer, args.dataset_text_field, config.dataset_num_proc)

if args.dataset_test_split is not None:
eval_dataset = raw_datasets[args.dataset_test_split]
eval_dataset = prepare_dataset(eval_dataset, tokenizer, args.dataset_text_field, config.dataset_num_proc)
else:
eval_dataset = None
# Compute that only on the main process for faster data processing.
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
train_dataset = prepare_dataset(train_dataset, tokenizer, args.dataset_text_field, config.dataset_num_proc)

if args.dataset_test_split is not None:
eval_dataset = raw_datasets[args.dataset_test_split]
eval_dataset = prepare_dataset(eval_dataset, tokenizer, args.dataset_text_field, config.dataset_num_proc)
else:
eval_dataset = None

################
# Training
################
11 changes: 6 additions & 5 deletions examples/scripts/orpo.py
@@ -54,6 +54,7 @@

from dataclasses import dataclass, field

from accelerate import PartialState
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser

@@ -100,11 +101,11 @@ def process(row):
row["rejected"] = tokenizer.apply_chat_template([row["rejected"][-1]], tokenize=False)
return row

ds = ds.map(
process,
load_from_cache_file=False,
num_proc=orpo_args.dataset_num_proc,
)
# Compute that only on the main process for faster data processing.
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
ds = ds.map(process, num_proc=orpo_args.dataset_num_proc)

train_dataset = ds["train"]
eval_dataset = ds["test"]

14 changes: 9 additions & 5 deletions examples/scripts/ppo.py
@@ -19,7 +19,7 @@
from typing import Optional

import torch
from accelerate import Accelerator
from accelerate import Accelerator, PartialState
from datasets import load_dataset
from peft import LoraConfig
from tqdm import tqdm
@@ -53,11 +53,14 @@ class ScriptArguments:

trl_model_class = AutoModelForCausalLMWithValueHead if not args.use_seq2seq else AutoModelForSeq2SeqLMWithValueHead

tokenizer = AutoTokenizer.from_pretrained(ppo_config.model_name)
tokenizer.pad_token = tokenizer.eos_token


# Below is an example function to build the dataset. In our case, we use the IMDB dataset
# from the `datasets` library. One should customize this function to train the model on
# its own dataset.
def build_dataset(config, query_dataset, input_min_text_length=2, input_max_text_length=8):
def build_dataset(query_dataset, input_min_text_length=2, input_max_text_length=8):
"""
Build dataset for training. This builds the dataset from `load_dataset`, one should
customize this function to train the model on its own dataset.
@@ -70,8 +73,6 @@ def build_dataset(config, query_dataset, input_min_text_length=2, input_max_text
dataloader (`torch.utils.data.DataLoader`):
The dataloader for the dataset.
"""
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
tokenizer.pad_token = tokenizer.eos_token
# load imdb with datasets
ds = load_dataset(query_dataset, split="train")
ds = ds.rename_columns({"text": "review"})
@@ -90,7 +91,10 @@ def tokenize(sample):


# We retrieve the dataloader by calling the `build_dataset` function.
dataset = build_dataset(ppo_config, ppo_config.query_dataset)
# Compute that only on the main process for faster data processing.
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
dataset = build_dataset(ppo_config.query_dataset)


def collator(data):
12 changes: 9 additions & 3 deletions examples/scripts/ppo/ppo.py
@@ -1,5 +1,6 @@
import shutil

from accelerate import PartialState
from datasets import load_dataset
from transformers import (
AutoModelForCausalLM,
@@ -93,10 +94,15 @@ def tokenize(element):
tokenize,
batched=True,
remove_columns=dataset.column_names,
load_from_cache_file=False,
num_proc=config.dataset_num_proc,
)

# Compute that only on the main process for faster data processing.
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
train_dataset = prepare_dataset(train_dataset, tokenizer)
eval_dataset = prepare_dataset(eval_dataset, tokenizer)

################
# Training
################
@@ -107,8 +113,8 @@ def tokenize(element):
ref_policy=ref_policy,
reward_model=reward_model,
value_model=value_model,
train_dataset=prepare_dataset(train_dataset, tokenizer),
eval_dataset=prepare_dataset(eval_dataset, tokenizer),
train_dataset=train_dataset,
eval_dataset=eval_dataset,
)
trainer.train()
trainer.save_model(config.output_dir)
15 changes: 10 additions & 5 deletions examples/scripts/ppo/ppo_tldr.py
@@ -1,5 +1,6 @@
import shutil

from accelerate import PartialState
from datasets import load_dataset
from transformers import (
AutoModelForCausalLM,
@@ -101,11 +102,15 @@ def tokenize(element):
num_proc=config.dataset_num_proc,
)

train_dataset = prepare_dataset(train_dataset, tokenizer)
eval_dataset = prepare_dataset(eval_dataset, tokenizer)
# filtering
train_dataset = train_dataset.filter(lambda x: x["lengths"] <= 512, num_proc=config.dataset_num_proc)
eval_dataset = eval_dataset.filter(lambda x: x["lengths"] <= 512, num_proc=config.dataset_num_proc)
# Compute that only on the main process for faster data processing.
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
train_dataset = prepare_dataset(train_dataset, tokenizer)
eval_dataset = prepare_dataset(eval_dataset, tokenizer)
# filtering
train_dataset = train_dataset.filter(lambda x: x["lengths"] <= 512, num_proc=config.dataset_num_proc)
eval_dataset = eval_dataset.filter(lambda x: x["lengths"] <= 512, num_proc=config.dataset_num_proc)

assert train_dataset[0]["input_ids"][-1] != tokenizer.eos_token_id, "The last token should not be an EOS token"
################
# Training
6 changes: 5 additions & 1 deletion examples/scripts/ppo_multi_adapter.py
@@ -15,6 +15,7 @@
from typing import Optional

import torch
from accelerate import PartialState
from datasets import load_dataset
from peft import LoraConfig
from tqdm import tqdm
@@ -95,7 +96,10 @@ def tokenize(example):

tokenizer.pad_token = tokenizer.eos_token

dataset = create_and_prepare_dataset(tokenizer, script_args.dataset_num_proc)
# Compute that only on the main process for faster data processing.
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
dataset = create_and_prepare_dataset(tokenizer, script_args.dataset_num_proc)


def collator(data):