Introduce AcceleratorConfig dataclass
muellerzr committed Jan 23, 2024
1 parent 83f9196 commit d9e825b
Showing 4 changed files with 245 additions and 14 deletions.
13 changes: 11 additions & 2 deletions src/transformers/trainer.py
@@ -76,6 +76,7 @@
TrainerState,
)
from .trainer_pt_utils import (
AcceleratorConfig,
DistributedTensorGatherer,
IterableDatasetShard,
LabelSmoother,
@@ -3966,11 +3967,19 @@ def create_accelerator_and_postprocess(self):
gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs)

# create accelerator object
accelerator_kwargs = {}
if self.args.accelerator_config is not None:
accelerator_kwargs = self.args.accelerator_config
# dicts and AcceleratorConfig instances can be passed through directly; JSON files must be loaded first
if isinstance(accelerator_kwargs, AcceleratorConfig):
accelerator_kwargs = accelerator_kwargs.to_kwargs()
elif not isinstance(accelerator_kwargs, dict):
accelerator_kwargs = AcceleratorConfig.from_json_file(accelerator_kwargs).to_kwargs()

self.accelerator = Accelerator(
dispatch_batches=self.args.dispatch_batches,
split_batches=self.args.split_batches,
deepspeed_plugin=self.args.deepspeed_plugin,
gradient_accumulation_plugin=gradient_accumulation_plugin,
**accelerator_kwargs,
)
# some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag
self.gather_function = self.accelerator.gather_for_metrics
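As a reading aid for the branch above: `accelerator_config` may arrive as an `AcceleratorConfig`, a plain `dict`, or a path to a JSON file, and each form is reduced to keyword arguments for `Accelerator`. A minimal standalone sketch of that normalization (the helper name `resolve_accelerator_kwargs` is hypothetical, not part of the commit):

```python
from transformers.trainer_pt_utils import AcceleratorConfig


def resolve_accelerator_kwargs(accelerator_config):
    """Hypothetical helper mirroring the branch above: turn an AcceleratorConfig,
    a plain dict, or a path to a JSON file into keyword arguments for Accelerator()."""
    if accelerator_config is None:
        return {}
    if isinstance(accelerator_config, AcceleratorConfig):
        # Only values that differ from the dataclass defaults are forwarded.
        return accelerator_config.to_kwargs()
    if isinstance(accelerator_config, dict):
        # A dict is forwarded as-is, so its keys must already match Accelerator's kwargs.
        return dict(accelerator_config)
    # Anything else is treated as a path to a JSON config file.
    return AcceleratorConfig.from_json_file(accelerator_config).to_kwargs()
```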
104 changes: 103 additions & 1 deletion src/transformers/trainer_pt_utils.py
@@ -16,15 +16,17 @@
Torch utilities for the Trainer class.
"""

import copy
import datetime
import io
import json
import math
import os
import sys
import warnings
from collections.abc import Mapping
from contextlib import contextmanager
from dataclasses import dataclass
from dataclasses import dataclass, field
from logging import StreamHandler
from typing import Any, Dict, Iterator, List, Optional, Union

@@ -1140,3 +1142,103 @@ def smp_nested_concat(tensor):
# It doesn't seem possible to check here if `tensor` is a StepOutput because StepOutput lives in `smp.step`
# which is also the name of the decorator so Python is confused.
return tensor.concat().detach().cpu()


@dataclass
class AcceleratorConfig:
"""
A subset of arguments relating to the underlying [`accelerate.Accelerator`]
implementation utilized in the `Trainer` that can be customized.
Mostly relating to data.
Parameters:
split_batches (`bool`, *optional*, defaults to `False`):
Whether or not the accelerator should split the batches yielded by the dataloaders across the devices. If
`True`, the actual batch size used will be the same on any kind of distributed process, but it must be a
round multiple of the `num_processes` you are using. If `False`, the actual batch size used will be the one set
in your script multiplied by the number of processes.
dispatch_batches (`bool`, *optional*):
If set to `True`, the dataloader prepared by the Accelerator is only iterated through on the main process
and then the batches are split and broadcast to each process. Will default to `True` for `DataLoader` whose
underlying dataset is an `IterableDataset`, `False` otherwise.
even_batches (`bool`, *optional*, defaults to `True`):
If set to `True`, in cases where the total batch size across all processes does not exactly divide the
dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among
all workers.
use_seedable_sampler (`bool`, *optional*, defaults to `True`):
Whether or not to use a fully seedable random sampler ([`accelerate.data_loader.SeedableRandomSampler`]). Ensures
training results are fully reproducible using a different sampling technique. While seed-to-seed results
may differ, on average the differences are negligible when using multiple different seeds to compare. Should
also be run with [`~utils.set_seed`] for the best results.
"""

# Data related arguments
split_batches: bool = field(
default=False,
metadata={
"help": "Whether or not the accelerator should split the batches yielded by the dataloaders across the devices. If"
" `True` the actual batch size used will be the same on any kind of distributed processes, but it must be a"
" round multiple of the `num_processes` you are using. If `False`, actual batch size used will be the one set"
" in your script multiplied by the number of processes."
},
)
dispatch_batches: bool = field(
default=None,
metadata={
"help": "If set to `True`, the dataloader prepared by the Accelerator is only iterated through on the main process"
" and then the batches are split and broadcast to each process. Will default to `True` for `DataLoader` whose"
" underlying dataset is an `IterableDataslet`, `False` otherwise."
},
)
even_batches: bool = field(
default=True,
metadata={
"help": "If set to `True`, in cases where the total batch size across all processes does not exactly divide the"
" dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among"
" all workers."
},
)
use_seedable_sampler: bool = field(
default=True,
metadata={
"help": "Whether or not use a fully seedable random sampler ([`accelerate.data_loader.SeedableRandomSampler`])."
"Ensures training results are fully reproducable using a different sampling technique. "
"While seed-to-seed results may differ, on average the differences are neglible when using"
"multiple different seeds to compare. Should also be ran with [`~utils.set_seed`] for the best results."
},
)

@classmethod
def from_json_file(cls, json_file):
# Check if exists
open_file = io.open if os.path.exists(json_file) else open
with open_file(json_file, "r", encoding="utf-8") as f:
config_dict = json.load(f)
# Check for keys and load sensible defaults
if "split_batches" not in config_dict:
config_dict["split_batches"] = False
if "dispatch_batches" not in config_dict:
config_dict["dispatch_batches"] = None
if "even_batches" not in config_dict:
config_dict["even_batches"] = True
if "use_seedable_sampler" not in config_dict:
config_dict["use_seedable_sampler"] = True
extra_keys = sorted(set(config_dict.keys()) - set(cls.__dataclass_fields__.keys()))
if len(extra_keys) > 0:
raise ValueError(
f"The config file at {json_file} had unknown keys ({extra_keys}), please try upgrading your `transformers`"
" version or fix (and potentially remove these keys) from your config file."
)
return cls(**config_dict)

def to_dict(self):
return copy.deepcopy(self.__dict__)

def to_kwargs(self):
"""
Returns a dictionary containing the attributes with values different from the default of this class.
"""
default_dict = self.__class__().to_dict()
this_dict = self.to_dict()
return {k: v for k, v in this_dict.items() if default_dict[k] != v}
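A short usage sketch of the new dataclass, assuming only what the diff above defines; `to_kwargs()` forwards just the fields that differ from the defaults, so an untouched config leaves `Accelerator`'s own defaults in place:

```python
import json
import os
import tempfile

from transformers.trainer_pt_utils import AcceleratorConfig

# A default config forwards nothing.
assert AcceleratorConfig().to_kwargs() == {}
assert AcceleratorConfig(even_batches=False).to_kwargs() == {"even_batches": False}

# Round-trip through a JSON file; unknown keys raise a ValueError.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "accelerator_config.json")
    with open(path, "w") as f:
        json.dump({"split_batches": True}, f)
    config = AcceleratorConfig.from_json_file(path)
    assert config.to_kwargs() == {"split_batches": True}
```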
75 changes: 64 additions & 11 deletions src/transformers/training_args.py
@@ -28,6 +28,7 @@
from packaging import version

from .debug_utils import DebugOption
from .trainer_pt_utils import AcceleratorConfig
from .trainer_utils import (
EvaluationStrategy,
FSDPOption,
@@ -487,6 +488,32 @@ class TrainingArguments:
Use [Deepspeed](https://github.com/microsoft/deepspeed). This is an experimental feature and its API may
evolve in the future. The value is either the location of a DeepSpeed json config file (e.g.,
`ds_config.json`) or an already loaded json file as a `dict`.
accelerator_config (`str`, `dict`, or `AcceleratorConfig`, *optional*):
Config to be used with the internal `Accelerator` implementation. The value is either a location of
accelerator json config file (e.g., `accelerator_config.json`), an already loaded json file as `dict`,
or an instance of [`~trainer_pt_utils.AcceleratorConfig`].
The supported config options are:
- split_batches (`bool`, *optional*, defaults to `False`):
Whether or not the accelerator should split the batches yielded by the dataloaders across the devices. If
`True`, the actual batch size used will be the same on any kind of distributed process, but it must be a
round multiple of the `num_processes` you are using. If `False`, the actual batch size used will be the one set
in your script multiplied by the number of processes.
- dispatch_batches (`bool`, *optional*):
If set to `True`, the dataloader prepared by the Accelerator is only iterated through on the main process
and then the batches are split and broadcast to each process. Will default to `True` for `DataLoader` whose
underlying dataset is an `IterableDataset`, `False` otherwise.
- even_batches (`bool`, *optional*, defaults to `True`):
If set to `True`, in cases where the total batch size across all processes does not exactly divide the
dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among
all workers.
- use_seedable_sampler (`bool`, *optional*, defaults to `True`):
Whether or not to use a fully seedable random sampler ([`accelerate.data_loader.SeedableRandomSampler`]). Ensures
training results are fully reproducible using a different sampling technique. While seed-to-seed results
may differ, on average the differences are negligible when using multiple different seeds to compare. Should
also be run with [`~utils.set_seed`] for the best results.
label_smoothing_factor (`float`, *optional*, defaults to 0.0):
The label smoothing factor to use. Zero means no label smoothing, otherwise the underlying onehot-encoded
labels are changed from 0s and 1s to `label_smoothing_factor/num_labels` and `1 - label_smoothing_factor +
@@ -1073,6 +1100,16 @@ class TrainingArguments:
},
)
# Do not touch this type annotation or it will stop working in CLI
accelerator_config: Optional[str] = field(
default=None,
metadata={
"help": (
"Config to be used with the internal Accelerator object initializtion. The value is either a "
"accelerator json config file (e.g., `accelerator_config.json`) or an already loaded json file as `dict`."
)
},
)
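A hedged sketch of the three accepted forms for the new `accelerator_config` argument (output paths and option values are placeholders; the JSON variant assumes the file exists so it can be loaded):

```python
from transformers import TrainingArguments
from transformers.trainer_pt_utils import AcceleratorConfig

# 1. As a plain dict, forwarded to Accelerator by the Trainer.
args = TrainingArguments(output_dir="out", accelerator_config={"even_batches": False})

# 2. As the dataclass itself.
args = TrainingArguments(
    output_dir="out",
    accelerator_config=AcceleratorConfig(split_batches=True, even_batches=False),
)

# 3. As a path to a JSON file (converted to an AcceleratorConfig in __post_init__).
args = TrainingArguments(output_dir="out", accelerator_config="accelerator_config.json")
```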
# Do not touch this type annotation or it will stop working in CLI
deepspeed: Optional[str] = field(
default=None,
metadata={
@@ -1270,20 +1307,12 @@ class TrainingArguments:

dispatch_batches: Optional[bool] = field(
default=None,
metadata={
"help": "Whether to dispatch batches across devices in distributed training. If set to `True`, the dataloader prepared by the Accelerator is only iterated through on the main process "
"and then the batches are split and broadcast to each process. Will default to `True` for `DataLoader` whose"
"underlying dataset is an `IterableDataset`, `False` otherwise."
},
metadata={"help": "Deprecated. Pass {'dispatch_batches':VALUE} to `accelerator_config`."},
)

split_batches: Optional[bool] = field(
default=False,
metadata={
"help": "Whether or not the accelerator should split the batches yielded by the dataloaders across the devices during distributed training. If"
"set to `True`, the actual batch size used will be the same on any kind of distributed processes, but it must be a"
"round multiple of the number of processes you are using (such as GPUs)."
},
default=None,
metadata={"help": "Deprecated. Pass {'split_batches':True} to `accelerator_config`."},
)

include_tokens_per_second: Optional[bool] = field(
@@ -1690,6 +1719,24 @@ def __post_init__(self):
os.environ[f"{prefix}SYNC_MODULE_STATES"] = self.fsdp_config.get("sync_module_states", "true")
os.environ[f"{prefix}USE_ORIG_PARAMS"] = self.fsdp_config.get("use_orig_params", "true")

if isinstance(self.accelerator_config, str):
self.accelerator_config = AcceleratorConfig.from_json_file(self.accelerator_config)
if self.dispatch_batches:
warnings.warn(
"Using `--dispatch_batches` is deprecated and will be removed in version 5 of 🤗 Transformers. Use"
" `--accelerator_config {'dispatch_batches':VALUE} instead",
FutureWarning,
)
self.accelerator_config["dispatch_batches"] = self.dispatch_batches

if self.split_batches:
warnings.warn(
"Using `--split_batches` is deprecated and will be removed in version 5 of 🤗 Transformers. Use"
" `--accelerator_config {'split_batches':VALUE} instead",
FutureWarning,
)
self.accelerator_config["split_batches"] = self.split_batches

if self.tpu_metrics_debug:
warnings.warn(
"using `--tpu_metrics_debug` is deprecated and will be removed in version 5 of 🤗 Transformers. Use"
@@ -1771,6 +1818,12 @@ def __post_init__(self):
f"{self.hub_model_id}).",
FutureWarning,
)
if self.split_batches is not None:
warnings.warn(
"using `split_batches` is deprecated and will be removed in version 5"
" of 🤗 Transformers. Use the `accelerator_config` argument instead.",
FutureWarning,
)

def __str__(self):
self_as_dict = asdict(self)
67 changes: 67 additions & 0 deletions tests/trainer/test_trainer.py
@@ -82,6 +82,7 @@
torch_device,
)
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.trainer_pt_utils import AcceleratorConfig
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, HPSearchBackend, get_last_checkpoint
from transformers.training_args import OptimizerNames
from transformers.utils import (
@@ -2409,6 +2410,72 @@ def test_end_to_end_example(self):
execute_subprocess_async(command)
# successful return here == success - any errors would have caused an error or a timeout in the sub-call

def test_accelerator_config_from_dict(self):
# Checks that accelerator kwargs can be passed through
# and the accelerator is initialized accordingly
with tempfile.TemporaryDirectory() as tmp_dir:
config = RegressionModelConfig(a=1.5, b=2.5)
model = RegressionPreTrainedModel(config)
eval_dataset = SampleIterableDataset()

# Leaves all options as something *not* basic
args = RegressionTrainingArguments(
output_dir=tmp_dir,
accelerator_config={
"split_batches": True,
"dispatch_batches": True,
"even_batches": False,
"use_seedable_sampler": True,
},
)
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
self.assertEqual(trainer.accelerator.split_batches, True)
self.assertEqual(trainer.accelerator.dispatch_batches, True)
self.assertEqual(trainer.accelerator.even_batches, False)
self.assertEqual(trainer.accelerator.use_seedable_sampler, True)

def test_accelerator_config_from_yaml(self):
# Checks that accelerator kwargs can be passed through
# and the accelerator is initialized accordingly
with tempfile.TemporaryDirectory() as tmp_dir:
path_file = Path(tmp_dir) / "accelerator_config.json"
with open(path_file, "w") as f:
accelerator_config = {
"split_batches": True,
"dispatch_batches": True,
"even_batches": False,
"use_seedable_sampler": False,
}
json.dump(accelerator_config, f)
config = RegressionModelConfig(a=1.5, b=2.5)
model = RegressionPreTrainedModel(config)
eval_dataset = SampleIterableDataset()

# Leaves all options as something *not* basic
args = RegressionTrainingArguments(output_dir=tmp_dir, accelerator_config=path_file)
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
self.assertEqual(trainer.accelerator.split_batches, True)
self.assertEqual(trainer.accelerator.dispatch_batches, True)
self.assertEqual(trainer.accelerator.even_batches, False)
self.assertEqual(trainer.accelerator.use_seedable_sampler, False)

def test_accelerator_config_from_dataclass(self):
# Checks that accelerator kwargs can be passed through
# and the accelerator is initialized accordingly
accelerator_config = AcceleratorConfig(
split_batches=True, dispatch_batches=True, even_batches=False, use_seedable_sampler=False
)
config = RegressionModelConfig(a=1.5, b=2.5)
model = RegressionPreTrainedModel(config)
eval_dataset = SampleIterableDataset()
with tempfile.TemporaryDirectory() as tmp_dir:
args = RegressionTrainingArguments(output_dir=tmp_dir, accelerator_config=accelerator_config)
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
self.assertEqual(trainer.accelerator.split_batches, True)
self.assertEqual(trainer.accelerator.dispatch_batches, True)
self.assertEqual(trainer.accelerator.even_batches, False)
self.assertEqual(trainer.accelerator.use_seedable_sampler, False)


@require_torch
@is_staging_test
