diff --git a/README.md b/README.md index 67592454a..c355e7279 100644 --- a/README.md +++ b/README.md @@ -397,6 +397,7 @@ Add below flag to train command above Please reduce any below - `micro_batch_size` - `eval_batch_size` + - `gradient_accumulation_steps` - `sequence_len` > RuntimeError: expected scalar type Float but found Half diff --git a/src/axolotl/utils/validation.py b/src/axolotl/utils/validation.py index 0d9610aae..38e0b9819 100644 --- a/src/axolotl/utils/validation.py +++ b/src/axolotl/utils/validation.py @@ -8,6 +8,12 @@ def validate_config(cfg): raise ValueError( "please set only one of gradient_accumulation_steps or batch_size" ) + if cfg.batch_size: + logging.warning( + "%s\n%s", + "batch_size is not recommended. Please use gradient_accumulation_steps instead.", + "To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.", + ) if cfg.load_4bit: raise ValueError( "cfg.load_4bit parameter has been deprecated and replaced by cfg.gptq" diff --git a/tests/test_validation.py b/tests/test_validation.py index 93ec15269..ce744f762 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -1,6 +1,8 @@ """Module for testing the validation module""" +import logging import unittest +from typing import Optional import pytest @@ -13,6 +15,12 @@ class ValidationTest(unittest.TestCase): Test the validation module """ + _caplog: Optional[pytest.LogCaptureFixture] = None + + @pytest.fixture(autouse=True) + def inject_fixtures(self, caplog): + self._caplog = caplog + def test_load_4bit_deprecate(self): cfg = DictDefault( { @@ -23,6 +31,17 @@ def test_load_4bit_deprecate(self): with pytest.raises(ValueError): validate_config(cfg) + def test_batch_size_unused_warning(self): + cfg = DictDefault( + { + "batch_size": 32, + } + ) + + with self._caplog.at_level(logging.WARNING): + validate_config(cfg) + assert "batch_size is not recommended" in self._caplog.records[0].message + def test_qlora(self): base_cfg = DictDefault( {