From 164193fa7ea196fde7a87406f0123c1be0c26356 Mon Sep 17 00:00:00 2001
From: Zach Mueller <muellerzr@gmail.com>
Date: Wed, 14 Feb 2024 13:26:02 -0500
Subject: [PATCH] [Big deprecation] Introduces a `DataLoaderConfig` (#2441)

* Deprecate and introduce dataloader_config

* Update docs

* Doc nits

* More tests, adjust based on PR review

* Fixup tests

* Nits

* Update docs/source/quicktour.md

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* Clean

* Actually create one

* Forgot to change one

* Use pytest

---------

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
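
Migration sketch (a minimal example based on the deprecation shim added in
`accelerator.py` below; the old keyword arguments keep working until v1.0.0
but emit a `FutureWarning`):

```python
from accelerate import Accelerator
from accelerate.utils import DataLoaderConfiguration

# Before (deprecated):
#   accelerator = Accelerator(split_batches=True, dispatch_batches=True)

# After: group all dataloader-related flags in one config object.
dataloader_config = DataLoaderConfiguration(split_batches=True, dispatch_batches=True)
accelerator = Accelerator(dataloader_config=dataloader_config)
```

The old attributes remain readable as properties (`accelerator.split_batches`,
`accelerator.dispatch_batches`, `accelerator.even_batches`,
`accelerator.use_seedable_sampler`); `even_batches` also keeps its setter.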
---
 docs/source/package_reference/utilities.md    |  2 +
 docs/source/quicktour.md                      |  2 +-
 src/accelerate/__init__.py                    |  1 +
 src/accelerate/accelerator.py                 | 86 +++++++++++++------
 .../scripts/external_deps/test_metrics.py     | 13 ++-
 .../scripts/test_distributed_data_loop.py     |  5 +-
 .../test_utils/scripts/test_script.py         | 20 +++--
 src/accelerate/utils/__init__.py              |  1 +
 src/accelerate/utils/dataclasses.py           | 52 ++++++++++++
 tests/test_accelerator.py                     | 38 ++++++++
 10 files changed, 181 insertions(+), 39 deletions(-)

diff --git a/docs/source/package_reference/utilities.md b/docs/source/package_reference/utilities.md
index 404d6d4da7a..9650bee10ba 100644
--- a/docs/source/package_reference/utilities.md
+++ b/docs/source/package_reference/utilities.md
@@ -95,6 +95,8 @@ These are classes which can be configured and passed through to the appropriate
 
 [[autodoc]] utils.BnbQuantizationConfig
 
+[[autodoc]] utils.DataLoaderConfiguration
+
 [[autodoc]] utils.ProjectConfiguration
 
 ## Environmental Variables
diff --git a/docs/source/quicktour.md b/docs/source/quicktour.md
index 29a936352d5..be41eb7ad3d 100644
--- a/docs/source/quicktour.md
+++ b/docs/source/quicktour.md
@@ -83,7 +83,7 @@ is shuffled the same way (if you decided to use `shuffle=True` or any kind of ra
     your script. For instance, training on 4 GPUs with a batch size of 16 set when creating the training dataloader will
     train at an actual batch size of 64 (4 * 16).
     If you want the batch size to remain the same regardless of how many GPUs the script is run on, you can use the 
-    option `split_batches=True` when creating and initializing [`Accelerator`].
+    option `split_batches=True` in a [`utils.DataLoaderConfiguration`] passed when creating and initializing the [`Accelerator`].
     Your training dataloader may change length when going through this method: if you run on X GPUs, it will have its
     length divided by X (since your actual batch size will be multiplied by X), unless you set
     `split_batches=True`.
diff --git a/src/accelerate/__init__.py b/src/accelerate/__init__.py
index 70b897ca916..91e84c02792 100644
--- a/src/accelerate/__init__.py
+++ b/src/accelerate/__init__.py
@@ -16,6 +16,7 @@
 from .state import PartialState
 from .utils import (
     AutocastKwargs,
+    DataLoaderConfiguration,
     DeepSpeedPlugin,
     DistributedDataParallelKwargs,
     DistributedType,
diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index 28a5a0ee1c9..26743c089b2 100755
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -47,6 +47,7 @@
     WEIGHTS_INDEX_NAME,
     WEIGHTS_NAME,
     AutocastKwargs,
+    DataLoaderConfiguration,
     DeepSpeedPlugin,
     DistributedDataParallelKwargs,
     DistributedType,
@@ -150,6 +151,12 @@
 
 logger = get_logger(__name__)
 
+# Sentinel values for defaults
+_split_batches = object()
+_dispatch_batches = object()
+_even_batches = object()
+_use_seedable_sampler = object()
+
 
 class Accelerator:
     """
@@ -159,11 +166,6 @@ class Accelerator:
         device_placement (`bool`, *optional*, defaults to `True`):
             Whether or not the accelerator should put objects on device (tensors yielded by the dataloader, model,
             etc...).
-        split_batches (`bool`, *optional*, defaults to `False`):
-            Whether or not the accelerator should split the batches yielded by the dataloaders across the devices. If
-            `True` the actual batch size used will be the same on any kind of distributed processes, but it must be a
-            round multiple of the `num_processes` you are using. If `False`, actual batch size used will be the one set
-            in your script multiplied by the number of processes.
         mixed_precision (`str`, *optional*):
             Whether or not to use mixed precision training. Choose from 'no', 'fp16', 'bf16' or 'fp8'. Will default to the
             value in the environment variable `ACCELERATE_MIXED_PRECISION`, which will use the default value in the
@@ -176,6 +178,8 @@ class Accelerator:
         cpu (`bool`, *optional*):
             Whether or not to force the script to execute on CPU. Will ignore GPU available if set to `True` and force
             the execution on one process only.
+        dataloader_config (`DataLoaderConfiguration`, *optional*):
+            A configuration for how the dataloaders should be handled in distributed scenarios.
         deepspeed_plugin (`DeepSpeedPlugin`, *optional*):
             Tweak your DeepSpeed related args using this argument. This argument is optional and can be configured
             directly using *accelerate config*
@@ -210,19 +214,6 @@ class Accelerator:
         project_dir (`str`, `os.PathLike`, *optional*):
             A path to a directory for storing data such as logs of locally-compatible loggers and potentially saved
             checkpoints.
-        dispatch_batches (`bool`, *optional*):
-            If set to `True`, the dataloader prepared by the Accelerator is only iterated through on the main process
-            and then the batches are split and broadcast to each process. Will default to `True` for `DataLoader` whose
-            underlying dataset is an `IterableDataset`, `False` otherwise.
-        even_batches (`bool`, *optional*, defaults to `True`):
-            If set to `True`, in cases where the total batch size across all processes does not exactly divide the
-            dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among
-            all workers.
-        use_seedable_sampler (`bool`, *optional*, defaults to `False`):
-            Whether or not use a fully seedable random sampler ([`~data_loader.SeedableRandomSampler`]). Ensures
-            training results are fully reproducable using a different sampling technique. While seed-to-seed results
-            may differ, on average the differences are neglible when using multiple different seeds to compare. Should
-            also be ran with [`~utils.set_seed`] each time for the best results.
         step_scheduler_with_optimizer (`bool`, *optional*, defaults to `True`):
             Set `True` if the learning rate scheduler is stepped at the same time as the optimizer, `False` if only
             done under certain circumstances (at the end of each epoch, for instance).
@@ -254,10 +245,11 @@ class Accelerator:
     def __init__(
         self,
         device_placement: bool = True,
-        split_batches: bool = False,
+        split_batches: bool = _split_batches,
         mixed_precision: PrecisionType | str | None = None,
         gradient_accumulation_steps: int = 1,
         cpu: bool = False,
+        dataloader_config: DataLoaderConfiguration | None = None,
         deepspeed_plugin: DeepSpeedPlugin | None = None,
         fsdp_plugin: FullyShardedDataParallelPlugin | None = None,
         megatron_lm_plugin: MegatronLMPlugin | None = None,
@@ -266,9 +258,9 @@ def __init__(
         project_dir: str | os.PathLike | None = None,
         project_config: ProjectConfiguration | None = None,
         gradient_accumulation_plugin: GradientAccumulationPlugin | None = None,
-        dispatch_batches: bool | None = None,
-        even_batches: bool = True,
-        use_seedable_sampler: bool = False,
+        dispatch_batches: bool | None = _dispatch_batches,
+        even_batches: bool = _even_batches,
+        use_seedable_sampler: bool = _use_seedable_sampler,
         step_scheduler_with_optimizer: bool = True,
         kwargs_handlers: list[KwargsHandler] | None = None,
         dynamo_backend: DynamoBackend | str | None = None,
@@ -422,10 +414,32 @@ def __init__(
                 )
 
         self.device_placement = device_placement
-        self.split_batches = split_batches
-        self.dispatch_batches = dispatch_batches
-        self.even_batches = even_batches
-        self.use_seedable_sampler = use_seedable_sampler
+        if dataloader_config is None:
+            dataloader_config = DataLoaderConfiguration()
+        self.dataloader_config = dataloader_config
+        # Deal with deprecated args
+        # TODO: Remove in v1.0.0
+        deprecated_dl_args = {}
+        if dispatch_batches is not _dispatch_batches:
+            deprecated_dl_args["dispatch_batches"] = dispatch_batches
+            self.dataloader_config.dispatch_batches = dispatch_batches
+        if split_batches is not _split_batches:
+            deprecated_dl_args["split_batches"] = split_batches
+            self.dataloader_config.split_batches = split_batches
+        if even_batches is not _even_batches:
+            deprecated_dl_args["even_batches"] = even_batches
+            self.dataloader_config.even_batches = even_batches
+        if use_seedable_sampler is not _use_seedable_sampler:
+            deprecated_dl_args["use_seedable_sampler"] = use_seedable_sampler
+            self.dataloader_config.use_seedable_sampler = use_seedable_sampler
+        if len(deprecated_dl_args) > 0:
+            values = ", ".join([f"{k}={v}" for k, v in deprecated_dl_args.items()])
+            warnings.warn(
+                f"Passing the following arguments to `Accelerator` is deprecated and will be removed in version 1.0 of Accelerate: {deprecated_dl_args.keys()}. "
+                "Please pass an `accelerate.DataLoaderConfiguration` instead: \n"
+                f"dataloader_config = DataLoaderConfiguration({values})",
+                FutureWarning,
+            )
         self.step_scheduler_with_optimizer = step_scheduler_with_optimizer
 
         # Mixed precision attributes
@@ -515,6 +529,26 @@ def local_process_index(self):
     def device(self):
         return self.state.device
 
+    @property
+    def split_batches(self):
+        return self.dataloader_config.split_batches
+
+    @property
+    def dispatch_batches(self):
+        return self.dataloader_config.dispatch_batches
+
+    @property
+    def even_batches(self):
+        return self.dataloader_config.even_batches
+
+    @even_batches.setter
+    def even_batches(self, value: bool):
+        self.dataloader_config.even_batches = value
+
+    @property
+    def use_seedable_sampler(self):
+        return self.dataloader_config.use_seedable_sampler
+
     @property
     def project_dir(self):
         return self.project_configuration.project_dir
diff --git a/src/accelerate/test_utils/scripts/external_deps/test_metrics.py b/src/accelerate/test_utils/scripts/external_deps/test_metrics.py
index aba990b3d48..d41e5cd5964 100755
--- a/src/accelerate/test_utils/scripts/external_deps/test_metrics.py
+++ b/src/accelerate/test_utils/scripts/external_deps/test_metrics.py
@@ -25,7 +25,7 @@
 from torch.utils.data import DataLoader, IterableDataset
 from transformers import AutoModelForSequenceClassification, AutoTokenizer
 
-from accelerate import Accelerator, DistributedType
+from accelerate import Accelerator, DataLoaderConfiguration, DistributedType
 from accelerate.data_loader import DataLoaderDispatcher
 from accelerate.test_utils import RegressionDataset, RegressionModel, torch_device
 from accelerate.utils import is_torch_xla_available, set_seed
@@ -81,7 +81,8 @@ def collate_fn(examples):
 
 
 def get_mrpc_setup(dispatch_batches, split_batches):
-    accelerator = Accelerator(dispatch_batches=dispatch_batches, split_batches=split_batches)
+    dataloader_config = DataLoaderConfiguration(dispatch_batches=dispatch_batches, split_batches=split_batches)
+    accelerator = Accelerator(dataloader_config=dataloader_config)
     dataloader = get_dataloader(accelerator, not dispatch_batches)
     model = AutoModelForSequenceClassification.from_pretrained(
         "hf-internal-testing/mrpc-bert-base-cased", return_dict=True
@@ -240,7 +241,8 @@ def test_gather_for_metrics_drop_last():
 
 
 def main():
-    accelerator = Accelerator(split_batches=False, dispatch_batches=False)
+    dataloader_config = DataLoaderConfiguration(split_batches=False, dispatch_batches=False)
+    accelerator = Accelerator(dataloader_config=dataloader_config)
     if accelerator.is_local_main_process:
         datasets.utils.logging.set_verbosity_warning()
         transformers.utils.logging.set_verbosity_warning()
@@ -276,7 +278,10 @@ def main():
             print("**Test torch metrics**")
         for split_batches in [True, False]:
             for dispatch_batches in dispatch_batches_options:
-                accelerator = Accelerator(split_batches=split_batches, dispatch_batches=dispatch_batches)
+                dataloader_config = DataLoaderConfiguration(
+                    split_batches=split_batches, dispatch_batches=dispatch_batches
+                )
+                accelerator = Accelerator(dataloader_config=dataloader_config)
                 if accelerator.is_local_main_process:
                     print(f"With: `split_batches={split_batches}`, `dispatch_batches={dispatch_batches}`, length=99")
                 test_torch_metrics(accelerator, 99)
diff --git a/src/accelerate/test_utils/scripts/test_distributed_data_loop.py b/src/accelerate/test_utils/scripts/test_distributed_data_loop.py
index 0f5e9de6028..17d577c58ac 100644
--- a/src/accelerate/test_utils/scripts/test_distributed_data_loop.py
+++ b/src/accelerate/test_utils/scripts/test_distributed_data_loop.py
@@ -22,7 +22,7 @@
 import torch
 from torch.utils.data import DataLoader, IterableDataset, TensorDataset
 
-from accelerate.accelerator import Accelerator
+from accelerate.accelerator import Accelerator, DataLoaderConfiguration
 from accelerate.utils.dataclasses import DistributedType
 
 
@@ -35,7 +35,8 @@ def __iter__(self):
 
 
 def create_accelerator(even_batches=True):
-    accelerator = Accelerator(even_batches=even_batches)
+    dataloader_config = DataLoaderConfiguration(even_batches=even_batches)
+    accelerator = Accelerator(dataloader_config=dataloader_config)
     assert accelerator.num_processes == 2, "this script expects that two GPUs are available"
     return accelerator
 
diff --git a/src/accelerate/test_utils/scripts/test_script.py b/src/accelerate/test_utils/scripts/test_script.py
index cc61cbc4d59..754c265aea8 100644
--- a/src/accelerate/test_utils/scripts/test_script.py
+++ b/src/accelerate/test_utils/scripts/test_script.py
@@ -30,6 +30,7 @@
 from accelerate.state import AcceleratorState
 from accelerate.test_utils import RegressionDataset, are_the_same_tensors
 from accelerate.utils import (
+    DataLoaderConfiguration,
     DistributedType,
     gather,
     is_bf16_available,
@@ -355,7 +356,9 @@ def check_seedable_sampler():
     set_seed(42)
     train_set = RegressionDataset(length=10, seed=42)
     train_dl = DataLoader(train_set, batch_size=2, shuffle=True)
-    accelerator = Accelerator(use_seedable_sampler=True)
+
+    config = DataLoaderConfiguration(use_seedable_sampler=True)
+    accelerator = Accelerator(dataloader_config=config)
     train_dl = accelerator.prepare(train_dl)
     original_items = []
     for _ in range(3):
@@ -424,7 +427,8 @@ def training_check(use_seedable_sampler=False):
 
     accelerator.print("Training yielded the same results on one CPU or distributed setup with no batch split.")
 
-    accelerator = Accelerator(split_batches=True, use_seedable_sampler=use_seedable_sampler)
+    dataloader_config = DataLoaderConfiguration(split_batches=True, use_seedable_sampler=use_seedable_sampler)
+    accelerator = Accelerator(dataloader_config=dataloader_config)
     train_dl = generate_baseline_dataloader(
         train_set, generator, batch_size * state.num_processes, use_seedable_sampler
     )
@@ -452,7 +456,8 @@ def training_check(use_seedable_sampler=False):
         # Mostly a test that FP16 doesn't crash as the operation inside the model is not converted to FP16
         print("FP16 training check.")
         AcceleratorState._reset_state()
-        accelerator = Accelerator(mixed_precision="fp16", use_seedable_sampler=use_seedable_sampler)
+        dataloader_config = DataLoaderConfiguration(use_seedable_sampler=use_seedable_sampler)
+        accelerator = Accelerator(mixed_precision="fp16", dataloader_config=dataloader_config)
         train_dl = generate_baseline_dataloader(train_set, generator, batch_size, use_seedable_sampler)
         model = RegressionModel()
         optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
@@ -492,7 +497,8 @@ def training_check(use_seedable_sampler=False):
         # Mostly a test that BF16 doesn't crash as the operation inside the model is not converted to BF16
         print("BF16 training check.")
         AcceleratorState._reset_state()
-        accelerator = Accelerator(mixed_precision="bf16", use_seedable_sampler=use_seedable_sampler)
+        dataloader_config = DataLoaderConfiguration(use_seedable_sampler=use_seedable_sampler)
+        accelerator = Accelerator(mixed_precision="bf16", dataloader_config=dataloader_config)
         train_dl = generate_baseline_dataloader(train_set, generator, batch_size, use_seedable_sampler)
         model = RegressionModel()
         optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
@@ -516,7 +522,8 @@ def training_check(use_seedable_sampler=False):
     if is_ipex_available():
         print("ipex BF16 training check.")
         AcceleratorState._reset_state()
-        accelerator = Accelerator(mixed_precision="bf16", cpu=True, use_seedable_sampler=use_seedable_sampler)
+        dataloader_config = DataLoaderConfiguration(use_seedable_sampler=use_seedable_sampler)
+        accelerator = Accelerator(mixed_precision="bf16", cpu=True, dataloader_config=dataloader_config)
         train_dl = generate_baseline_dataloader(train_set, generator, batch_size, use_seedable_sampler)
         model = RegressionModel()
         optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
@@ -540,7 +547,8 @@ def training_check(use_seedable_sampler=False):
     if is_xpu_available():
         print("xpu BF16 training check.")
         AcceleratorState._reset_state()
-        accelerator = Accelerator(mixed_precision="bf16", cpu=False, use_seedable_sampler=use_seedable_sampler)
+        dataloader_config = DataLoaderConfiguration(use_seedable_sampler=use_seedable_sampler)
+        accelerator = Accelerator(mixed_precision="bf16", cpu=False, dataloader_config=dataloader_config)
         train_dl = generate_baseline_dataloader(train_set, generator, batch_size, use_seedable_sampler)
         model = RegressionModel()
         optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
diff --git a/src/accelerate/utils/__init__.py b/src/accelerate/utils/__init__.py
index 233937b371e..7d3d85c9e1b 100644
--- a/src/accelerate/utils/__init__.py
+++ b/src/accelerate/utils/__init__.py
@@ -18,6 +18,7 @@
     BnbQuantizationConfig,
     ComputeEnvironment,
     CustomDtype,
+    DataLoaderConfiguration,
     DeepSpeedPlugin,
     DistributedDataParallelKwargs,
     DistributedType,
diff --git a/src/accelerate/utils/dataclasses.py b/src/accelerate/utils/dataclasses.py
index 6555c2b7f14..202b69e5bc1 100644
--- a/src/accelerate/utils/dataclasses.py
+++ b/src/accelerate/utils/dataclasses.py
@@ -480,6 +480,58 @@ class TensorInformation:
     dtype: torch.dtype
 
 
+@dataclass
+class DataLoaderConfiguration:
+    """
+    Configuration for dataloader-related items when calling `accelerator.prepare`.
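+
+    Example (a minimal usage sketch):
+
+    ```python
+    from accelerate import Accelerator
+    from accelerate.utils import DataLoaderConfiguration
+
+    dataloader_config = DataLoaderConfiguration(split_batches=True)
+    accelerator = Accelerator(dataloader_config=dataloader_config)
+    ```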
+    """
+
+    split_batches: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether or not the accelerator should split the batches yielded by the dataloaders across the devices. If"
+            " `True` the actual batch size used will be the same on any kind of distributed processes, but it must be a"
+            " round multiple of the `num_processes` you are using. If `False`, actual batch size used will be the one set"
+            " in your script multiplied by the number of processes."
+        },
+    )
+    dispatch_batches: bool = field(
+        default=None,
+        metadata={
+            "help": "If set to `True`, the dataloader prepared by the Accelerator is only iterated through on the main process"
+            " and then the batches are split and broadcast to each process. Will default to `True` for `DataLoader` whose"
+            " underlying dataset is an `IterableDataslet`, `False` otherwise."
+        },
+    )
+    even_batches: bool = field(
+        default=True,
+        metadata={
+            "help": "If set to `True`, in cases where the total batch size across all processes does not exactly divide the"
+            " dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among"
+            " all workers."
+        },
+    )
+    use_seedable_sampler: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether or not use a fully seedable random sampler ([`data_loader.SeedableRandomSampler`])."
+            "Ensures training results are fully reproducable using a different sampling technique. "
+            "While seed-to-seed results may differ, on average the differences are neglible when using"
+            "multiple different seeds to compare. Should also be ran with [`~utils.set_seed`] for the best results."
+        },
+    )
+
+
 @dataclass
 class ProjectConfiguration:
     """
diff --git a/tests/test_accelerator.py b/tests/test_accelerator.py
index a7044a5aa5f..c861f77e877 100644
--- a/tests/test_accelerator.py
+++ b/tests/test_accelerator.py
@@ -4,6 +4,7 @@
 import tempfile
 from unittest.mock import patch
 
+import pytest
 import torch
 from parameterized import parameterized
 from torch.utils.data import DataLoader, TensorDataset
@@ -55,6 +56,43 @@ def parameterized_custom_name_func(func, param_num, param):
 
 
 class AcceleratorTester(AccelerateTestCase):
+    # Should be removed after 1.0.0 release
+    def test_deprecated_values(self):
+        # Test defaults
+        accelerator = Accelerator()
+        assert accelerator.split_batches is False, "split_batches should be False by default"
+        assert accelerator.dispatch_batches is None, "dispatch_batches should be None by default"
+        assert accelerator.even_batches is True, "even_batches should be True by default"
+        assert accelerator.use_seedable_sampler is False, "use_seedable_sampler should be False by default"
+
+        # Pass some arguments only
+        with pytest.warns(FutureWarning) as cm:
+            accelerator = Accelerator(
+                dispatch_batches=True,
+                split_batches=False,
+            )
+            deprecation_warning = str(cm.list[0].message)
+            assert accelerator.split_batches is False, "split_batches should be False"
+            assert accelerator.dispatch_batches is True, "dispatch_batches should be True"
+            assert accelerator.even_batches is True, "even_batches should be True by default"
+            assert accelerator.use_seedable_sampler is False, "use_seedable_sampler should be False by default"
+            assert "dispatch_batches" in deprecation_warning
+            assert "split_batches" in deprecation_warning
+            assert "even_batches" not in deprecation_warning
+            assert "use_seedable_sampler" not in deprecation_warning
+
+        # Pass in some arguments, but with their defaults
+        with pytest.warns(FutureWarning) as cm:
+            accelerator = Accelerator(
+                even_batches=True,
+                use_seedable_sampler=False,
+            )
+            deprecation_warning = str(cm.list[0].message)
+            assert "even_batches" in deprecation_warning
+            assert accelerator.even_batches is True
+            assert "use_seedable_sampler" in deprecation_warning
+            assert accelerator.use_seedable_sampler is False
+
     @require_non_cpu
     def test_accelerator_can_be_reinstantiated(self):
         _ = Accelerator()