Only load data on main process (#1255)
* fix: only load data on main process

* define is_main_process once

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* avoid re-initializing PartialState on train dataset check

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* avoid re-initializing PartialState on eval dataset check

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* process dataset on main first to take advantage of caching

* fix typo in docs

* use decorator to manage state

* Revert "fix typo in docs"

This reverts commit 0880a18.

* Revert "Revert "fix typo in docs""

This reverts commit ff7ee33.

* Revert "use decorator to manage state"

This reverts commit 7ac7a45.

* use is_local_main_process instead of is_main_process

* fix: use context manager instead of attribute

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update trl/trainer/sft_trainer.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
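For context, the `accelerate` primitive this commit adopts lets each node's local main process run dataset preparation first while the remaining ranks wait, then reuse the 🤗 Datasets cache. A minimal standalone sketch of that pattern (the `imdb` dataset and `tokenize` function are illustrative stand-ins, not part of this PR):

```python
# demo.py -- sketch of the local_main_process_first() caching pattern.
# Run with: accelerate launch --num_processes 2 demo.py
from accelerate.state import PartialState
from datasets import load_dataset

state = PartialState()

def tokenize(example):
    # Stand-in for real preprocessing (tokenization, packing, ...).
    return {"n_chars": len(example["text"])}

# The local main process of each node enters the block first and pays the
# preprocessing cost; the other ranks wait at a barrier, then re-run `map`
# and reuse the Arrow cache files the main process already wrote.
with state.local_main_process_first():
    dataset = load_dataset("imdb", split="train").map(tokenize)

print(f"rank {state.process_index}: {len(dataset)} examples ready")
```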
JohnGiorgi and younesbelkada authored Jan 26, 2024
1 parent 29d439a commit 4edc688
Showing 2 changed files with 27 additions and 24 deletions.
2 changes: 1 addition & 1 deletion docs/source/sft_trainer.mdx
@@ -251,7 +251,7 @@ trainer = SFTTrainer(
 
 trainer.train()
 ```
-To preperly format your input make sure to process all the examples by looping over them and returning a list of processed text. Check out a full example on how to use SFTTrainer on alpaca dataset [here](https://github.com/huggingface/trl/pull/444#issue-1760952763)
+To properly format your input make sure to process all the examples by looping over them and returning a list of processed text. Check out a full example on how to use SFTTrainer on alpaca dataset [here](https://github.com/huggingface/trl/pull/444#issue-1760952763)
 
 ### Packing dataset ([`ConstantLengthDataset`])
 
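The sentence fixed above describes `formatting_func`-style processing: loop over the batch and return one processed string per example. A sketch of such a function, assuming alpaca-style `instruction`/`output` columns (the column names are assumptions, not part of this diff):

```python
# Hypothetical formatting function in the style the doc describes: iterate
# over the batched examples and return one formatted string per example.
def formatting_prompts_func(examples):
    output_texts = []
    for i in range(len(examples["instruction"])):
        text = (
            f"### Instruction: {examples['instruction'][i]}\n"
            f"### Response: {examples['output'][i]}"
        )
        output_texts.append(text)
    return output_texts
```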
49 changes: 26 additions & 23 deletions trl/trainer/sft_trainer.py
@@ -19,6 +19,7 @@
 
 import torch
 import torch.nn as nn
+from accelerate.state import PartialState
 from datasets import Dataset
 from datasets.arrow_writer import SchemaInferenceError
 from datasets.builder import DatasetGenerationError
@@ -252,27 +253,13 @@ def make_inputs_require_grad(module, input, output):
         if data_collator is None:
             data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
 
-        if dataset_kwargs is None:
-            dataset_kwargs = {}
-        if train_dataset is not None:
-            train_dataset = self._prepare_dataset(
-                train_dataset,
-                tokenizer,
-                packing,
-                dataset_text_field,
-                max_seq_length,
-                formatting_func,
-                num_of_sequences,
-                chars_per_token,
-                remove_unused_columns=args.remove_unused_columns if args is not None else True,
-                **dataset_kwargs,
-            )
-        if eval_dataset is not None:
-            _multiple = isinstance(eval_dataset, dict)
-            _eval_datasets = eval_dataset if _multiple else {"singleton": eval_dataset}
-            for _eval_dataset_name, _eval_dataset in _eval_datasets.items():
-                _eval_datasets[_eval_dataset_name] = self._prepare_dataset(
-                    _eval_dataset,
+        # Pre-process the datasets only once per node. The remaining processes will use the cache.
+        with PartialState().local_main_process_first():
+            if dataset_kwargs is None:
+                dataset_kwargs = {}
+            if train_dataset is not None:
+                train_dataset = self._prepare_dataset(
+                    train_dataset,
                     tokenizer,
                     packing,
                     dataset_text_field,
@@ -283,8 +270,24 @@ def make_inputs_require_grad(module, input, output):
                     remove_unused_columns=args.remove_unused_columns if args is not None else True,
                     **dataset_kwargs,
                 )
-            if not _multiple:
-                eval_dataset = _eval_datasets["singleton"]
+            if eval_dataset is not None:
+                _multiple = isinstance(eval_dataset, dict)
+                _eval_datasets = eval_dataset if _multiple else {"singleton": eval_dataset}
+                for _eval_dataset_name, _eval_dataset in _eval_datasets.items():
+                    _eval_datasets[_eval_dataset_name] = self._prepare_dataset(
+                        _eval_dataset,
+                        tokenizer,
+                        packing,
+                        dataset_text_field,
+                        max_seq_length,
+                        formatting_func,
+                        num_of_sequences,
+                        chars_per_token,
+                        remove_unused_columns=args.remove_unused_columns if args is not None else True,
+                        **dataset_kwargs,
+                    )
+                if not _multiple:
+                    eval_dataset = _eval_datasets["singleton"]
 
         if tokenizer.padding_side is not None and tokenizer.padding_side != "right":
             warnings.warn(
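With this change, an `accelerate`-launched training script prepares each dataset once per node instead of once per process; nothing changes at the call site. A sketch of a typical caller (model and dataset names are illustrative, not from this PR):

```python
# train_sft.py -- run with: accelerate launch train_sft.py
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTTrainer

dataset = load_dataset("imdb", split="train")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")

# SFTTrainer.__init__ now wraps _prepare_dataset in
# PartialState().local_main_process_first(): rank 0 of each node tokenizes
# first, and the remaining ranks reuse the cached result.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=512,
)
trainer.train()
```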
