Skip to content

Commit

Permalink
Merge pull request #135 from NanoCode012/fix/grad-accu-readme
Browse files Browse the repository at this point in the history
Fix: Update doc for grad_accu and add validation tests for batch size
  • Loading branch information
NanoCode012 authored May 31, 2023
2 parents a6f5e5e + 3c71c8d commit 288fd62
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,7 @@ Add below flag to train command above
Please reduce any below
- `micro_batch_size`
- `eval_batch_size`
- `gradient_accumulation_steps`
- `sequence_len`

> RuntimeError: expected scalar type Float but found Half
Expand Down
6 changes: 6 additions & 0 deletions src/axolotl/utils/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@ def validate_config(cfg):
raise ValueError(
"please set only one of gradient_accumulation_steps or batch_size"
)
if cfg.batch_size:
logging.warning(
"%s\n%s",
"batch_size is not recommended. Please use gradient_accumulation_steps instead.",
"To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
)
if cfg.load_4bit:
raise ValueError(
"cfg.load_4bit parameter has been deprecated and replaced by cfg.gptq"
Expand Down
19 changes: 19 additions & 0 deletions tests/test_validation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Module for testing the validation module"""

import logging
import unittest
from typing import Optional

import pytest

Expand All @@ -13,6 +15,12 @@ class ValidationTest(unittest.TestCase):
Test the validation module
"""

_caplog: Optional[pytest.LogCaptureFixture] = None

@pytest.fixture(autouse=True)
def inject_fixtures(self, caplog):
    """Keep a reference to pytest's ``caplog`` fixture on the test instance.

    The fixture is declared with ``autouse=True``, so it needs no explicit
    request from the tests; test methods then reach the log capture through
    ``self._caplog``.
    """
    self._caplog = caplog

def test_load_4bit_deprecate(self):
cfg = DictDefault(
{
Expand All @@ -23,6 +31,17 @@ def test_load_4bit_deprecate(self):
with pytest.raises(ValueError):
validate_config(cfg)

def test_batch_size_unused_warning(self):
    """Check that a config setting ``batch_size`` triggers the warning.

    The config carries only ``batch_size``; ``validate_config`` is expected
    to log a WARNING whose text contains "batch_size is not recommended".
    """
    cfg = DictDefault(
        {
            "batch_size": 32,
        }
    )

    with self._caplog.at_level(logging.WARNING):
        validate_config(cfg)

        # Look at every captured record instead of only records[0]: the
        # test then fails with a readable assertion (not an IndexError) when
        # nothing was logged, and keeps passing if validation ever emits
        # another record before this one. getMessage() renders the
        # "%s\n%s" template with its arguments without relying on a handler
        # having already set ``record.message``.
        messages = [record.getMessage() for record in self._caplog.records]
        assert any(
            "batch_size is not recommended" in message for message in messages
        ), f"expected batch_size warning, captured: {messages!r}"

def test_qlora(self):
base_cfg = DictDefault(
{
Expand Down

0 comments on commit 288fd62

Please sign in to comment.