extract module for working with cfg
tmm1 committed Aug 13, 2023
1 parent a13e45d commit 8cec513
Showing 3 changed files with 52 additions and 47 deletions.
47 changes: 2 additions & 45 deletions scripts/finetune.py
@@ -19,6 +19,7 @@

 from axolotl.logging_config import configure_logging
 from axolotl.utils.bench import log_gpu_memory_usage
+from axolotl.utils.config import normalize_config, validate_config
 from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.distributed import barrier, is_main_process
@@ -29,7 +30,6 @@
     process_datasets_for_packing,
     setup_trainer,
 )
-from axolotl.utils.validation import validate_config
 from axolotl.utils.wandb import setup_wandb_env_vars

 project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
@@ -44,27 +44,6 @@
 os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"


-def choose_device(cfg):
-    def get_device():
-        try:
-            if torch.cuda.is_available():
-                return f"cuda:{cfg.local_rank}"
-
-            if torch.backends.mps.is_available():
-                return "mps"
-
-            raise SystemError("No CUDA/mps device found")
-        except Exception:  # pylint: disable=broad-exception-caught
-            return "cpu"
-
-    cfg.device = get_device()
-    if cfg.device_map != "auto":
-        if cfg.device.startswith("cuda"):
-            cfg.device_map = {"": cfg.local_rank}
-        else:
-            cfg.device_map = {"": cfg.device}
-
-
 def get_multi_line_input() -> Optional[str]:
     print("Give me an instruction (Ctrl + D to finish): ")
     instruction = ""
@@ -194,31 +173,9 @@ def train(

     validate_config(cfg)

-    # setup some derived config / hyperparams
-    cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
-        cfg.batch_size // cfg.micro_batch_size
-    )
-    cfg.batch_size = (
-        cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
-    )
-    cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
-    cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
-    choose_device(cfg)
-    cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
-    if cfg.ddp:
-        cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
-        cfg.batch_size = cfg.batch_size * cfg.world_size
+    normalize_config(cfg)

     setup_wandb_env_vars(cfg)
-    if cfg.device == "mps":
-        cfg.load_in_8bit = False
-        cfg.tf32 = False
-        if cfg.bf16:
-            cfg.fp16 = True
-        cfg.bf16 = False
-
-    if cfg.tf32:
-        torch.backends.cuda.matmul.allow_tf32 = True

     # load the tokenizer first
     tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
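The net effect on the training script is that train() now delegates this bookkeeping to the new module. A minimal sketch of the resulting setup path (setup_cfg is a hypothetical wrapper used here only for illustration; the real calls sit inline in train()):

from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.wandb import setup_wandb_env_vars


def setup_cfg(cfg):
    # cfg is the DictDefault built from the YAML config earlier in train()
    validate_config(cfg)       # reject contradictory options up front
    normalize_config(cfg)      # derive batch_size, world_size, device/device_map, precision flags
    setup_wandb_env_vars(cfg)  # unchanged: wandb env setup still lives in utils.wandb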
50 changes: 49 additions & 1 deletion src/axolotl/utils/validation.py → src/axolotl/utils/config.py
@@ -1,12 +1,60 @@
"""Module for validating config files"""
"""Module for working with config dicts"""

import logging
import os

import torch

LOG = logging.getLogger("axolotl")


def choose_device(cfg):
def get_device():
try:
if torch.cuda.is_available():
return f"cuda:{cfg.local_rank}"

if torch.backends.mps.is_available():
return "mps"

raise SystemError("No CUDA/mps device found")
except Exception: # pylint: disable=broad-exception-caught
return "cpu"

cfg.device = get_device()
if cfg.device_map != "auto":
if cfg.device.startswith("cuda"):
cfg.device_map = {"": cfg.local_rank}
else:
cfg.device_map = {"": cfg.device}


def normalize_config(cfg):
# setup some derived config / hyperparams
cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
cfg.batch_size // cfg.micro_batch_size
)
cfg.batch_size = (
cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
)
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
choose_device(cfg)
cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
if cfg.ddp:
cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
cfg.batch_size = cfg.batch_size * cfg.world_size

if cfg.device == "mps":
cfg.load_in_8bit = False
cfg.tf32 = False
if cfg.bf16:
cfg.fp16 = True
cfg.bf16 = False
else:
torch.backends.cuda.matmul.allow_tf32 = cfg.tf32 or False


def validate_config(cfg):
if cfg.max_packed_sequence_len and cfg.sample_packing:
raise ValueError(
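A quick usage sketch of the extracted helpers (expected values assume a single-process run with WORLD_SIZE/LOCAL_RANK unset and a DictDefault cfg, which returns None for keys that were never set):

from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

cfg = DictDefault({"micro_batch_size": 2, "gradient_accumulation_steps": 4})
normalize_config(cfg)

# cfg.batch_size == 8 (micro_batch_size * gradient_accumulation_steps),
# cfg.world_size == 1, cfg.local_rank == 0, cfg.ddp is False, and cfg.device
# is "cuda:0", "mps", or "cpu" depending on what choose_device finds.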
2 changes: 1 addition & 1 deletion tests/test_validation.py
@@ -6,8 +6,8 @@

 import pytest

+from axolotl.utils.config import validate_config
 from axolotl.utils.dict import DictDefault
-from axolotl.utils.validation import validate_config


 class ValidationTest(unittest.TestCase):
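With the relocated module, the existing tests only need the import swap shown above. A hypothetical extra case for the packing conflict visible at the end of the config.py diff might look like this (the test name is invented here, and it assumes the truncated raise completes as a plain ValueError):

import pytest

from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault


def test_max_packed_sequence_len_vs_sample_packing():
    # validate_config should reject setting both options at once
    cfg = DictDefault({"max_packed_sequence_len": 2048, "sample_packing": True})
    with pytest.raises(ValueError):
        validate_config(cfg)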
