Skip to content

Commit edfee9e

Browse files
author
Dan Saunders
committed
review comments; slightly DRYing up things
1 parent 2fb3ed5 commit edfee9e

File tree

4 files changed

+22
-21
lines changed

4 files changed

+22
-21
lines changed

src/axolotl/cli/__init__.py

-1
Original file line numberDiff line numberDiff line change
@@ -476,7 +476,6 @@ def load_datasets(
476476
tokenizer,
477477
processor=processor,
478478
)
479-
print(train_dataset, eval_dataset, total_num_steps)
480479

481480
if (
482481
cli_args.debug

src/axolotl/evaluate.py

+3-9
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from axolotl.train import TrainDatasetMeta
1515
from axolotl.utils.dict import DictDefault
1616
from axolotl.utils.models import load_model, load_processor, load_tokenizer
17-
from axolotl.utils.trainer import setup_trainer
17+
from axolotl.utils.trainer import set_pytorch_cuda_alloc_conf, setup_trainer
1818

1919
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
2020
src_dir = os.path.join(project_root, "src")
@@ -79,14 +79,8 @@ def evaluate(
7979
- Dictionary of evaluation metrics
8080
"""
8181
# pylint: disable=duplicate-code
82-
# Set up CUDA allocation config if using PyTorch >= 2.2
83-
torch_version = torch.__version__.split(".")
84-
torch_major, torch_minor = int(torch_version[0]), int(torch_version[1])
85-
if torch_major == 2 and torch_minor >= 2:
86-
if os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None:
87-
os.environ[
88-
"PYTORCH_CUDA_ALLOC_CONF"
89-
] = "expandable_segments:True,roundup_power2_divisions:16"
82+
# Enable expandable segments for cuda allocation to improve VRAM usage
83+
set_pytorch_cuda_alloc_conf()
9084

9185
# Load tokenizer
9286
LOG.debug(

src/axolotl/train.py

+8-11
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from axolotl.utils.dict import DictDefault
2525
from axolotl.utils.freeze import freeze_layers_except
2626
from axolotl.utils.models import load_model, load_processor, load_tokenizer
27-
from axolotl.utils.trainer import setup_trainer
27+
from axolotl.utils.trainer import set_pytorch_cuda_alloc_conf, setup_trainer
2828

2929
try:
3030
from optimum.bettertransformer import BetterTransformer
@@ -53,25 +53,22 @@ class TrainDatasetMeta:
5353
def train(
5454
*, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
5555
) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]:
56-
# enable expandable segments for cuda allocation to improve VRAM usage
57-
torch_version = torch.__version__.split(".")
58-
torch_major, torch_minor = int(torch_version[0]), int(torch_version[1])
59-
if torch_major == 2 and torch_minor >= 2:
60-
if os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None:
61-
os.environ[
62-
"PYTORCH_CUDA_ALLOC_CONF"
63-
] = "expandable_segments:True,roundup_power2_divisions:16"
64-
65-
# load the tokenizer first
56+
# Enable expandable segments for cuda allocation to improve VRAM usage
57+
set_pytorch_cuda_alloc_conf()
58+
59+
# Load tokenizer
6660
LOG.debug(
6761
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
6862
main_process_only=True,
6963
)
7064
tokenizer = load_tokenizer(cfg)
65+
66+
# Load processor for multimodal models if needed
7167
processor = None
7268
if cfg.is_multimodal:
7369
processor = load_processor(cfg, tokenizer)
7470

71+
# Get datasets
7572
train_dataset = dataset_meta.train_dataset
7673
eval_dataset = dataset_meta.eval_dataset
7774
total_num_steps = dataset_meta.total_num_steps

src/axolotl/utils/trainer.py

+11
Original file line numberDiff line numberDiff line change
@@ -512,6 +512,17 @@ def prepare_opinionated_env(cfg):
512512
os.environ["TOKENIZERS_PARALLELISM"] = "false"
513513

514514

515+
def set_pytorch_cuda_alloc_conf():
    """Configure the PyTorch CUDA caching allocator for better VRAM usage.

    On PyTorch >= 2.2, sets ``PYTORCH_CUDA_ALLOC_CONF`` to enable
    expandable segments and power-of-two size rounding (16 divisions),
    unless the user has already set the variable themselves. Does
    nothing on older PyTorch or if the version string is unparseable.
    """
    try:
        # Only the leading "major.minor" matters; suffixes such as
        # "2.2.0+cu121" or ".dev" builds are ignored by taking [:2].
        major_str, minor_str = torch.__version__.split(".")[:2]
        version = (int(major_str), int(minor_str))
    except ValueError:
        # Unrecognized version format — leave the allocator config alone.
        return
    # Tuple comparison also covers future major versions (e.g. 3.x),
    # which a literal `major == 2 and minor >= 2` check would miss.
    if version >= (2, 2):
        # setdefault preserves any user-provided setting.
        os.environ.setdefault(
            "PYTORCH_CUDA_ALLOC_CONF",
            "expandable_segments:True,roundup_power2_divisions:16",
        )
524+
525+
515526
def setup_trainer(
516527
cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps
517528
):

0 commit comments

Comments
 (0)