diff --git a/docs/source/package_reference/utilities.md b/docs/source/package_reference/utilities.md index 404d6d4da7a..9650bee10ba 100644 --- a/docs/source/package_reference/utilities.md +++ b/docs/source/package_reference/utilities.md @@ -95,6 +95,8 @@ These are classes which can be configured and passed through to the appropriate [[autodoc]] utils.BnbQuantizationConfig +[[autodoc]] utils.DataLoaderConfiguration + [[autodoc]] utils.ProjectConfiguration ## Environmental Variables diff --git a/docs/source/quicktour.md b/docs/source/quicktour.md index 29a936352d5..be41eb7ad3d 100644 --- a/docs/source/quicktour.md +++ b/docs/source/quicktour.md @@ -83,7 +83,7 @@ is shuffled the same way (if you decided to use `shuffle=True` or any kind of ra your script. For instance, training on 4 GPUs with a batch size of 16 set when creating the training dataloader will train at an actual batch size of 64 (4 * 16). If you want the batch size remain the same regardless of how many GPUs the script is run on, you can use the - option `split_batches=True` when creating and initializing [`Accelerator`]. + option `split_batches=True` when creating and initializing [`Accelerator`] by passing in a [`utils.DataLoaderConfiguration`]. Your training dataloader may change length when going through this method: if you run on X GPUs, it will have its length divided by X (since your actual batch size will be multiplied by X), unless you set `split_batches=True`. diff --git a/src/accelerate/__init__.py b/src/accelerate/__init__.py index 70b897ca916..91e84c02792 100644 --- a/src/accelerate/__init__.py +++ b/src/accelerate/__init__.py @@ -16,6 +16,7 @@ from .state import PartialState from .utils import ( AutocastKwargs, + DataLoaderConfiguration, DeepSpeedPlugin, DistributedDataParallelKwargs, DistributedType, diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index 28a5a0ee1c9..26743c089b2 100755 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -47,6 +47,7 @@ WEIGHTS_INDEX_NAME, WEIGHTS_NAME, AutocastKwargs, + DataLoaderConfiguration, DeepSpeedPlugin, DistributedDataParallelKwargs, DistributedType, @@ -150,6 +151,12 @@ logger = get_logger(__name__) +# Sentinel values for defaults +_split_batches = object() +_dispatch_batches = object() +_even_batches = object() +_use_seedable_sampler = object() + class Accelerator: """ @@ -159,11 +166,6 @@ class Accelerator: device_placement (`bool`, *optional*, defaults to `True`): Whether or not the accelerator should put objects on device (tensors yielded by the dataloader, model, etc...). - split_batches (`bool`, *optional*, defaults to `False`): - Whether or not the accelerator should split the batches yielded by the dataloaders across the devices. If - `True` the actual batch size used will be the same on any kind of distributed processes, but it must be a - round multiple of the `num_processes` you are using. If `False`, actual batch size used will be the one set - in your script multiplied by the number of processes. mixed_precision (`str`, *optional*): Whether or not to use mixed precision training. Choose from 'no','fp16','bf16 or 'fp8'. Will default to the value in the environment variable `ACCELERATE_MIXED_PRECISION`, which will use the default value in the @@ -176,6 +178,8 @@ class Accelerator: cpu (`bool`, *optional*): Whether or not to force the script to execute on CPU. Will ignore GPU available if set to `True` and force the execution on one process only. + dataloader_config (`DataLoaderConfiguration`, *optional*): + A configuration for how the dataloaders should be handled in distributed scenarios. deepspeed_plugin (`DeepSpeedPlugin`, *optional*): Tweak your DeepSpeed related args using this argument. This argument is optional and can be configured directly using *accelerate config* @@ -210,19 +214,6 @@ class Accelerator: project_dir (`str`, `os.PathLike`, *optional*): A path to a directory for storing data such as logs of locally-compatible loggers and potentially saved checkpoints. - dispatch_batches (`bool`, *optional*): - If set to `True`, the dataloader prepared by the Accelerator is only iterated through on the main process - and then the batches are split and broadcast to each process. Will default to `True` for `DataLoader` whose - underlying dataset is an `IterableDataset`, `False` otherwise. - even_batches (`bool`, *optional*, defaults to `True`): - If set to `True`, in cases where the total batch size across all processes does not exactly divide the - dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among - all workers. - use_seedable_sampler (`bool`, *optional*, defaults to `False`): - Whether or not use a fully seedable random sampler ([`~data_loader.SeedableRandomSampler`]). Ensures - training results are fully reproducable using a different sampling technique. While seed-to-seed results - may differ, on average the differences are neglible when using multiple different seeds to compare. Should - also be ran with [`~utils.set_seed`] each time for the best results. step_scheduler_with_optimizer (`bool`, *optional`, defaults to `True`): Set `True` if the learning rate scheduler is stepped at the same time as the optimizer, `False` if only done under certain circumstances (at the end of each epoch, for instance). @@ -254,10 +245,11 @@ class Accelerator: def __init__( self, device_placement: bool = True, - split_batches: bool = False, + split_batches: bool = _split_batches, mixed_precision: PrecisionType | str | None = None, gradient_accumulation_steps: int = 1, cpu: bool = False, + dataloader_config: DataLoaderConfiguration | None = None, deepspeed_plugin: DeepSpeedPlugin | None = None, fsdp_plugin: FullyShardedDataParallelPlugin | None = None, megatron_lm_plugin: MegatronLMPlugin | None = None, @@ -266,9 +258,9 @@ def __init__( project_dir: str | os.PathLike | None = None, project_config: ProjectConfiguration | None = None, gradient_accumulation_plugin: GradientAccumulationPlugin | None = None, - dispatch_batches: bool | None = None, - even_batches: bool = True, - use_seedable_sampler: bool = False, + dispatch_batches: bool | None = _dispatch_batches, + even_batches: bool = _even_batches, + use_seedable_sampler: bool = _use_seedable_sampler, step_scheduler_with_optimizer: bool = True, kwargs_handlers: list[KwargsHandler] | None = None, dynamo_backend: DynamoBackend | str | None = None, @@ -422,10 +414,32 @@ def __init__( ) self.device_placement = device_placement - self.split_batches = split_batches - self.dispatch_batches = dispatch_batches - self.even_batches = even_batches - self.use_seedable_sampler = use_seedable_sampler + if dataloader_config is None: + dataloader_config = DataLoaderConfiguration() + self.dataloader_config = dataloader_config + # Deal with deprecated args + # TODO: Remove in v1.0.0 + deprecated_dl_args = {} + if dispatch_batches is not _dispatch_batches: + deprecated_dl_args["dispatch_batches"] = dispatch_batches + self.dataloader_config.dispatch_batches = dispatch_batches + if split_batches is not _split_batches: + deprecated_dl_args["split_batches"] = split_batches + self.dataloader_config.split_batches = split_batches + if even_batches is not _even_batches: + deprecated_dl_args["even_batches"] = even_batches + self.dataloader_config.even_batches = even_batches + if use_seedable_sampler is not _use_seedable_sampler: + deprecated_dl_args["use_seedable_sampler"] = use_seedable_sampler + self.dataloader_config.use_seedable_sampler = use_seedable_sampler + if len(deprecated_dl_args) > 0: + values = ", ".join([f"{k}={v}" for k, v in deprecated_dl_args.items()]) + warnings.warn( + f"Passing the following arguments to `Accelerator` is deprecated and will be removed in version 1.0 of Accelerate: {deprecated_dl_args.keys()}. " + "Please pass an `accelerate.DataLoaderConfiguration` instead: \n" + f"dataloader_config = DataLoaderConfiguration({values})", + FutureWarning, + ) self.step_scheduler_with_optimizer = step_scheduler_with_optimizer # Mixed precision attributes @@ -515,6 +529,26 @@ def local_process_index(self): def device(self): return self.state.device + @property + def split_batches(self): + return self.dataloader_config.split_batches + + @property + def dispatch_batches(self): + return self.dataloader_config.dispatch_batches + + @property + def even_batches(self): + return self.dataloader_config.even_batches + + @even_batches.setter + def even_batches(self, value: bool): + self.dataloader_config.even_batches = value + + @property + def use_seedable_sampler(self): + return self.dataloader_config.use_seedable_sampler + @property def project_dir(self): return self.project_configuration.project_dir diff --git a/src/accelerate/test_utils/scripts/external_deps/test_metrics.py b/src/accelerate/test_utils/scripts/external_deps/test_metrics.py index aba990b3d48..d41e5cd5964 100755 --- a/src/accelerate/test_utils/scripts/external_deps/test_metrics.py +++ b/src/accelerate/test_utils/scripts/external_deps/test_metrics.py @@ -25,7 +25,7 @@ from torch.utils.data import DataLoader, IterableDataset from transformers import AutoModelForSequenceClassification, AutoTokenizer -from accelerate import Accelerator, DistributedType +from accelerate import Accelerator, DataLoaderConfiguration, DistributedType from accelerate.data_loader import DataLoaderDispatcher from accelerate.test_utils import RegressionDataset, RegressionModel, torch_device from accelerate.utils import is_torch_xla_available, set_seed @@ -81,7 +81,8 @@ def collate_fn(examples): def get_mrpc_setup(dispatch_batches, split_batches): - accelerator = Accelerator(dispatch_batches=dispatch_batches, split_batches=split_batches) + dataloader_config = DataLoaderConfiguration(dispatch_batches=dispatch_batches, split_batches=split_batches) + accelerator = Accelerator(data_loader_config=dataloader_config) dataloader = get_dataloader(accelerator, not dispatch_batches) model = AutoModelForSequenceClassification.from_pretrained( "hf-internal-testing/mrpc-bert-base-cased", return_dict=True @@ -240,7 +241,8 @@ def test_gather_for_metrics_drop_last(): def main(): - accelerator = Accelerator(split_batches=False, dispatch_batches=False) + dataloader_config = DataLoaderConfiguration(split_batches=False, dispatch_batches=False) + accelerator = Accelerator(dataloader_config=dataloader_config) if accelerator.is_local_main_process: datasets.utils.logging.set_verbosity_warning() transformers.utils.logging.set_verbosity_warning() @@ -276,7 +278,10 @@ def main(): print("**Test torch metrics**") for split_batches in [True, False]: for dispatch_batches in dispatch_batches_options: - accelerator = Accelerator(split_batches=split_batches, dispatch_batches=dispatch_batches) + dataloader_config = DataLoaderConfiguration( + split_batches=split_batches, dispatch_batches=dispatch_batches + ) + accelerator = Accelerator(dataloader_config=dataloader_config) if accelerator.is_local_main_process: print(f"With: `split_batches={split_batches}`, `dispatch_batches={dispatch_batches}`, length=99") test_torch_metrics(accelerator, 99) diff --git a/src/accelerate/test_utils/scripts/test_distributed_data_loop.py b/src/accelerate/test_utils/scripts/test_distributed_data_loop.py index 0f5e9de6028..17d577c58ac 100644 --- a/src/accelerate/test_utils/scripts/test_distributed_data_loop.py +++ b/src/accelerate/test_utils/scripts/test_distributed_data_loop.py @@ -22,7 +22,7 @@ import torch from torch.utils.data import DataLoader, IterableDataset, TensorDataset -from accelerate.accelerator import Accelerator +from accelerate.accelerator import Accelerator, DataLoaderConfiguration from accelerate.utils.dataclasses import DistributedType @@ -35,7 +35,8 @@ def __iter__(self): def create_accelerator(even_batches=True): - accelerator = Accelerator(even_batches=even_batches) + dataloader_config = DataLoaderConfiguration(even_batches=even_batches) + accelerator = Accelerator(dataloader_config=dataloader_config) assert accelerator.num_processes == 2, "this script expects that two GPUs are available" return accelerator diff --git a/src/accelerate/test_utils/scripts/test_script.py b/src/accelerate/test_utils/scripts/test_script.py index cc61cbc4d59..754c265aea8 100644 --- a/src/accelerate/test_utils/scripts/test_script.py +++ b/src/accelerate/test_utils/scripts/test_script.py @@ -30,6 +30,7 @@ from accelerate.state import AcceleratorState from accelerate.test_utils import RegressionDataset, are_the_same_tensors from accelerate.utils import ( + DataLoaderConfiguration, DistributedType, gather, is_bf16_available, @@ -355,7 +356,9 @@ def check_seedable_sampler(): set_seed(42) train_set = RegressionDataset(length=10, seed=42) train_dl = DataLoader(train_set, batch_size=2, shuffle=True) - accelerator = Accelerator(use_seedable_sampler=True) + + config = DataLoaderConfiguration(use_seedable_sampler=True) + accelerator = Accelerator(dataloader_config=config) train_dl = accelerator.prepare(train_dl) original_items = [] for _ in range(3): @@ -424,7 +427,8 @@ def training_check(use_seedable_sampler=False): accelerator.print("Training yielded the same results on one CPU or distributed setup with no batch split.") - accelerator = Accelerator(split_batches=True, use_seedable_sampler=use_seedable_sampler) + dataloader_config = DataLoaderConfiguration(split_batches=True, use_seedable_sampler=use_seedable_sampler) + accelerator = Accelerator(dataloader_config=dataloader_config) train_dl = generate_baseline_dataloader( train_set, generator, batch_size * state.num_processes, use_seedable_sampler ) @@ -452,7 +456,8 @@ def training_check(use_seedable_sampler=False): # Mostly a test that FP16 doesn't crash as the operation inside the model is not converted to FP16 print("FP16 training check.") AcceleratorState._reset_state() - accelerator = Accelerator(mixed_precision="fp16", use_seedable_sampler=use_seedable_sampler) + dataloader_config = DataLoaderConfiguration(use_seedable_sampler=use_seedable_sampler) + accelerator = Accelerator(mixed_precision="fp16", dataloader_config=dataloader_config) train_dl = generate_baseline_dataloader(train_set, generator, batch_size, use_seedable_sampler) model = RegressionModel() optimizer = torch.optim.SGD(model.parameters(), lr=0.1) @@ -492,7 +497,8 @@ def training_check(use_seedable_sampler=False): # Mostly a test that BF16 doesn't crash as the operation inside the model is not converted to BF16 print("BF16 training check.") AcceleratorState._reset_state() - accelerator = Accelerator(mixed_precision="bf16", use_seedable_sampler=use_seedable_sampler) + dataloader_config = DataLoaderConfiguration(use_seedable_sampler=use_seedable_sampler) + accelerator = Accelerator(mixed_precision="bf16", dataloader_config=dataloader_config) train_dl = generate_baseline_dataloader(train_set, generator, batch_size, use_seedable_sampler) model = RegressionModel() optimizer = torch.optim.SGD(model.parameters(), lr=0.1) @@ -516,7 +522,8 @@ def training_check(use_seedable_sampler=False): if is_ipex_available(): print("ipex BF16 training check.") AcceleratorState._reset_state() - accelerator = Accelerator(mixed_precision="bf16", cpu=True, use_seedable_sampler=use_seedable_sampler) + dataloader_config = DataLoaderConfiguration(use_seedable_sampler=use_seedable_sampler) + accelerator = Accelerator(mixed_precision="bf16", cpu=True, dataloader_config=dataloader_config) train_dl = generate_baseline_dataloader(train_set, generator, batch_size, use_seedable_sampler) model = RegressionModel() optimizer = torch.optim.SGD(model.parameters(), lr=0.1) @@ -540,7 +547,8 @@ def training_check(use_seedable_sampler=False): if is_xpu_available(): print("xpu BF16 training check.") AcceleratorState._reset_state() - accelerator = Accelerator(mixed_precision="bf16", cpu=False, use_seedable_sampler=use_seedable_sampler) + dataloader_config = DataLoaderConfiguration(use_seedable_sampler=use_seedable_sampler) + accelerator = Accelerator(mixed_precision="bf16", cpu=False, dataloader_config=dataloader_config) train_dl = generate_baseline_dataloader(train_set, generator, batch_size, use_seedable_sampler) model = RegressionModel() optimizer = torch.optim.SGD(model.parameters(), lr=0.1) diff --git a/src/accelerate/utils/__init__.py b/src/accelerate/utils/__init__.py index 233937b371e..7d3d85c9e1b 100644 --- a/src/accelerate/utils/__init__.py +++ b/src/accelerate/utils/__init__.py @@ -18,6 +18,7 @@ BnbQuantizationConfig, ComputeEnvironment, CustomDtype, + DataLoaderConfiguration, DeepSpeedPlugin, DistributedDataParallelKwargs, DistributedType, diff --git a/src/accelerate/utils/dataclasses.py b/src/accelerate/utils/dataclasses.py index 6555c2b7f14..202b69e5bc1 100644 --- a/src/accelerate/utils/dataclasses.py +++ b/src/accelerate/utils/dataclasses.py @@ -480,6 +480,48 @@ class TensorInformation: dtype: torch.dtype +@dataclass +class DataLoaderConfiguration: + """ + Configuration for dataloader-related items when calling `accelerator.prepare`. + """ + + split_batches: bool = field( + default=False, + metadata={ + "help": "Whether or not the accelerator should split the batches yielded by the dataloaders across the devices. If" + " `True` the actual batch size used will be the same on any kind of distributed processes, but it must be a" + " round multiple of the `num_processes` you are using. If `False`, actual batch size used will be the one set" + " in your script multiplied by the number of processes." + }, + ) + dispatch_batches: bool = field( + default=None, + metadata={ + "help": "If set to `True`, the dataloader prepared by the Accelerator is only iterated through on the main process" + " and then the batches are split and broadcast to each process. Will default to `True` for `DataLoader` whose" + " underlying dataset is an `IterableDataslet`, `False` otherwise." + }, + ) + even_batches: bool = field( + default=True, + metadata={ + "help": "If set to `True`, in cases where the total batch size across all processes does not exactly divide the" + " dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among" + " all workers." + }, + ) + use_seedable_sampler: bool = field( + default=False, + metadata={ + "help": "Whether or not use a fully seedable random sampler ([`data_loader.SeedableRandomSampler`])." + "Ensures training results are fully reproducable using a different sampling technique. " + "While seed-to-seed results may differ, on average the differences are neglible when using" + "multiple different seeds to compare. Should also be ran with [`~utils.set_seed`] for the best results." + }, + ) + + @dataclass class ProjectConfiguration: """ diff --git a/tests/test_accelerator.py b/tests/test_accelerator.py index a7044a5aa5f..c861f77e877 100644 --- a/tests/test_accelerator.py +++ b/tests/test_accelerator.py @@ -4,6 +4,7 @@ import tempfile from unittest.mock import patch +import pytest import torch from parameterized import parameterized from torch.utils.data import DataLoader, TensorDataset @@ -55,6 +56,43 @@ def parameterized_custom_name_func(func, param_num, param): class AcceleratorTester(AccelerateTestCase): + # Should be removed after 1.0.0 release + def test_deprecated_values(self): + # Test defaults + accelerator = Accelerator() + assert accelerator.split_batches is False, "split_batches should be False by default" + assert accelerator.dispatch_batches is None, "dispatch_batches should be None by default" + assert accelerator.even_batches is True, "even_batches should be True by default" + assert accelerator.use_seedable_sampler is False, "use_seedable_sampler should be False by default" + + # Pass some arguments only + with pytest.warns(FutureWarning) as cm: + accelerator = Accelerator( + dispatch_batches=True, + split_batches=False, + ) + deprecation_warning = str(cm.list[0].message) + assert accelerator.split_batches is False, "split_batches should be True" + assert accelerator.dispatch_batches is True, "dispatch_batches should be True" + assert accelerator.even_batches is True, "even_batches should be True by default" + assert accelerator.use_seedable_sampler is False, "use_seedable_sampler should be False by default" + assert "dispatch_batches" in deprecation_warning + assert "split_batches" in deprecation_warning + assert "even_batches" not in deprecation_warning + assert "use_seedable_sampler" not in deprecation_warning + + # Pass in some arguments, but with their defaults + with pytest.warns(FutureWarning) as cm: + accelerator = Accelerator( + even_batches=True, + use_seedable_sampler=False, + ) + deprecation_warning = str(cm.list[0].message) + assert "even_batches" in deprecation_warning + assert accelerator.even_batches is True + assert "use_seedable_sampler" in deprecation_warning + assert accelerator.use_seedable_sampler is False + @require_non_cpu def test_accelerator_can_be_reinstantiated(self): _ = Accelerator()