Skip to content

Commit f865464

Browse files
djsaundeDan Saunders
and
Dan Saunders
authored
Basic evaluate CLI command / codepath (#2188)
* basic evaluate CLI command / codepath * tests for evaluate CLI command * fixes and cleanup * review comments; slightly DRYing up things --------- Co-authored-by: Dan Saunders <dan@axolotl.ai>
1 parent 3309048 commit f865464

File tree

10 files changed

+494
-112
lines changed

10 files changed

+494
-112
lines changed

src/axolotl/cli/evaluate.py

+52
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
"""
2+
CLI to run evaluation on a model
3+
"""
4+
import logging
5+
from pathlib import Path
6+
from typing import Union
7+
8+
import fire
9+
from dotenv import load_dotenv
10+
from transformers.hf_argparser import HfArgumentParser
11+
12+
from axolotl.cli import (
13+
check_accelerate_default_config,
14+
check_user_token,
15+
load_cfg,
16+
load_datasets,
17+
load_rl_datasets,
18+
print_axolotl_text_art,
19+
)
20+
from axolotl.common.cli import TrainerCliArgs
21+
from axolotl.evaluate import evaluate
22+
23+
LOG = logging.getLogger("axolotl.cli.evaluate")
24+
25+
26+
def do_evaluate(cfg, cli_args) -> None:
    """
    Run the startup checks, load the datasets and hand them to ``evaluate``.

    Args:
        cfg: Loaded configuration; ``cfg.rl`` selects which dataset loader is used
        cli_args: Parsed command line arguments, forwarded to the loader and to ``evaluate``
    """
    # pylint: disable=duplicate-code
    print_axolotl_text_art()
    check_accelerate_default_config()
    check_user_token()

    # A truthy ``cfg.rl`` picks the RL dataset loader, anything else the regular one.
    dataset_loader = load_rl_datasets if cfg.rl else load_datasets
    dataset_meta = dataset_loader(cfg=cfg, cli_args=cli_args)

    evaluate(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
38+
39+
40+
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
    """
    Load the config, parse the CLI dataclass arguments and start evaluation.

    Args:
        config: Path handed to ``load_cfg``
        **kwargs: Extra keyword arguments forwarded to ``load_cfg``
    """
    # pylint: disable=duplicate-code
    cfg = load_cfg(config, **kwargs)

    # With return_remaining_strings=True the parser also hands back the
    # arguments it did not consume; only the dataclass instance is needed here.
    parse_result = HfArgumentParser(TrainerCliArgs).parse_args_into_dataclasses(
        return_remaining_strings=True
    )
    cli_args = parse_result[0]

    do_evaluate(cfg, cli_args)
48+
49+
50+
if __name__ == "__main__":
51+
load_dotenv()
52+
fire.Fire(do_cli)

src/axolotl/cli/main.py

+26-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
build_command,
1313
fetch_from_github,
1414
)
15-
from axolotl.common.cli import PreprocessCliArgs, TrainerCliArgs
15+
from axolotl.common.cli import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
1616
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
1717

1818

@@ -60,6 +60,31 @@ def train(config: str, accelerate: bool, **kwargs):
6060
do_cli(config=config, **kwargs)
6161

6262

63+
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
    "--accelerate/--no-accelerate",
    default=True,
    help="Use accelerate launch for multi-GPU evaluation",
)
@add_options_from_dataclass(EvaluateCliArgs)
@add_options_from_config(AxolotlInputConfig)
def evaluate(config: str, accelerate: bool, **kwargs):
    """Evaluate a model."""
    # NOTE: the docstring above is kept to a single line on purpose; click
    # displays it as the command's help text.

    # Options that were not supplied arrive as None; drop them so they are
    # not forwarded.
    kwargs = {k: v for k, v in kwargs.items() if v is not None}

    if accelerate:
        # Re-launch the evaluate CLI module through `accelerate launch`.
        base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.evaluate"]
        if config:
            base_cmd.append(config)
        cmd = build_command(base_cmd, kwargs)
        subprocess.run(cmd, check=True)  # nosec B603
    else:
        # Run in-process; the import is deferred to this branch.
        from axolotl.cli.evaluate import do_cli

        do_cli(config=config, **kwargs)
86+
87+
6388
@cli.command()
6489
@click.argument("config", type=click.Path(exists=True, path_type=str))
6590
@click.option(

src/axolotl/common/cli.py

+19-6
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,19 @@
1515
LOG = logging.getLogger("axolotl.common.cli")
1616

1717

18+
@dataclass
class PreprocessCliArgs:
    """
    Dataclass holding the arguments that apply to preprocessing only.
    """

    debug: bool = False
    debug_text_only: bool = False
    debug_num_examples: int = 1
    prompter: Optional[str] = None
    download: Optional[bool] = True
29+
30+
1831
@dataclass
1932
class TrainerCliArgs:
2033
"""
@@ -31,16 +44,14 @@ class TrainerCliArgs:
3144

3245

3346
@dataclass
34-
class PreprocessCliArgs:
47+
class EvaluateCliArgs:
3548
"""
36-
dataclass representing arguments for preprocessing only
49+
dataclass representing the various evaluation arguments
3750
"""
3851

3952
debug: bool = field(default=False)
4053
debug_text_only: bool = field(default=False)
41-
debug_num_examples: int = field(default=1)
42-
prompter: Optional[str] = field(default=None)
43-
download: Optional[bool] = field(default=True)
54+
debug_num_examples: int = field(default=0)
4455

4556

4657
def load_model_and_tokenizer(
@@ -50,7 +61,9 @@ def load_model_and_tokenizer(
5061
):
5162
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
5263
tokenizer = load_tokenizer(cfg)
64+
5365
LOG.info("loading model and (optionally) peft_config...")
54-
model, _ = load_model(cfg, tokenizer, inference=cli_args.inference)
66+
inference = getattr(cli_args, "inference", False)
67+
model, _ = load_model(cfg, tokenizer, inference=inference)
5568

5669
return model, tokenizer

src/axolotl/evaluate.py

+168
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
"""Module for evaluating models."""
2+
3+
import csv
4+
import os
5+
import sys
6+
from pathlib import Path
7+
from typing import Dict, Optional
8+
9+
import torch
10+
from accelerate.logging import get_logger
11+
12+
from axolotl.common.cli import TrainerCliArgs
13+
from axolotl.logging_config import configure_logging
14+
from axolotl.train import TrainDatasetMeta
15+
from axolotl.utils.dict import DictDefault
16+
from axolotl.utils.models import load_model, load_processor, load_tokenizer
17+
from axolotl.utils.trainer import set_pytorch_cuda_alloc_conf, setup_trainer
18+
19+
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
20+
src_dir = os.path.join(project_root, "src")
21+
sys.path.insert(0, src_dir)
22+
23+
configure_logging()
24+
LOG = get_logger("axolotl.evaluate")
25+
26+
27+
def evaluate_dataset(
    trainer, dataset, dataset_type: str, flash_optimum: bool = False
) -> Optional[Dict[str, float]]:
    """Evaluate one dataset with the trainer and log each resulting metric.

    Args:
        trainer: Object whose ``evaluate`` method is called
        dataset: Dataset to evaluate; ``None`` skips evaluation
        dataset_type: Label used in log messages and as ``metric_key_prefix``
        flash_optimum: When truthy, evaluation runs inside the
            ``torch.backends.cuda.sdp_kernel`` context manager

    Returns:
        The metrics returned by ``trainer.evaluate``, or None when no dataset
        was given
    """
    if dataset is None:
        return None

    def _run_evaluation():
        # Single place for the evaluate call shared by both branches below.
        return trainer.evaluate(dataset, metric_key_prefix=dataset_type)

    LOG.info(f"Starting {dataset_type} set evaluation...")

    if not flash_optimum:
        metrics = _run_evaluation()
    else:
        with torch.backends.cuda.sdp_kernel(
            enable_flash=True,
            enable_math=True,
            enable_mem_efficient=True,
        ):
            metrics = _run_evaluation()

    label = dataset_type.capitalize()
    LOG.info(f"{label} set evaluation completed!")
    LOG.info(f"{label} Metrics:")
    for metric_name, metric_value in metrics.items():
        LOG.info(f"{metric_name}: {metric_value}")

    return metrics
62+
63+
64+
def _strip_metric_prefix(key: str, prefix: str) -> str:
    """Return ``key`` without a leading ``prefix``; later occurrences are kept."""
    return key[len(prefix) :] if key.startswith(prefix) else key


def _write_metrics_csv(metrics_file, train_metrics, eval_metrics) -> None:
    """Write one CSV row per metric name with its training and validation value."""
    train_by_name = {
        _strip_metric_prefix(key, "train_"): value
        for key, value in (train_metrics or {}).items()
    }
    eval_by_name = {
        _strip_metric_prefix(key, "eval_"): value
        for key, value in (eval_metrics or {}).items()
    }

    with metrics_file.open("w", newline="", encoding="utf-8") as file:
        writer = csv.writer(file)
        writer.writerow(["metric", "training", "validation"])
        # A metric missing on one side leaves that column empty.
        for metric_name in sorted(set(train_by_name) | set(eval_by_name)):
            writer.writerow(
                [
                    metric_name,
                    train_by_name.get(metric_name, ""),
                    eval_by_name.get(metric_name, ""),
                ]
            )


def evaluate(
    *, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
) -> Dict[str, float]:
    """
    Evaluate a model on training and validation datasets

    Args:
        cfg: Configuration dictionary
        cli_args: Command line arguments
        dataset_meta: Dataset metadata containing training and evaluation datasets

    Returns:
        Dictionary with the metrics of both datasets merged together (empty
        when neither dataset produced metrics)
    """
    # pylint: disable=duplicate-code
    # Enable expandable segments for cuda allocation to improve VRAM usage
    set_pytorch_cuda_alloc_conf()

    # Load tokenizer
    LOG.debug(
        f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
        main_process_only=True,
    )
    tokenizer = load_tokenizer(cfg)

    # Load processor for multimodal models if needed
    processor = None
    if cfg.is_multimodal:
        processor = load_processor(cfg, tokenizer)

    # Get datasets
    train_dataset = dataset_meta.train_dataset
    eval_dataset = dataset_meta.eval_dataset
    total_num_steps = dataset_meta.total_num_steps

    # Load model
    LOG.debug("loading model for evaluation...")
    # Argument dataclasses without an ``inference`` field fall back to False
    # instead of raising AttributeError.
    inference = getattr(cli_args, "inference", False)
    model, _ = load_model(cfg, tokenizer, processor=processor, inference=inference)

    # Set up trainer
    trainer = setup_trainer(
        cfg,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        model=(model, None, None),  # No need for model_ref or peft_config
        tokenizer=tokenizer,
        processor=processor,
        total_num_steps=total_num_steps,
    )

    # Evaluate datasets
    all_metrics = {}
    train_metrics = evaluate_dataset(trainer, train_dataset, "train", cfg.flash_optimum)
    eval_metrics = evaluate_dataset(trainer, eval_dataset, "eval", cfg.flash_optimum)

    if train_metrics:
        all_metrics.update(train_metrics)
    if eval_metrics:
        all_metrics.update(eval_metrics)

    # Save metrics to CSV if output directory is specified and we have metrics
    if cfg.output_dir and (train_metrics or eval_metrics):
        output_dir = Path(cfg.output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        metrics_file = output_dir / "eval_summary.csv"
        _write_metrics_csv(metrics_file, train_metrics, eval_metrics)

        LOG.info(f"Evaluation results saved to {metrics_file}")

    # Drop the local references before returning.
    del model
    del tokenizer

    return all_metrics

src/axolotl/train.py

+8-11
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from axolotl.utils.dict import DictDefault
2525
from axolotl.utils.freeze import freeze_layers_except
2626
from axolotl.utils.models import load_model, load_processor, load_tokenizer
27-
from axolotl.utils.trainer import setup_trainer
27+
from axolotl.utils.trainer import set_pytorch_cuda_alloc_conf, setup_trainer
2828

2929
try:
3030
from optimum.bettertransformer import BetterTransformer
@@ -53,25 +53,22 @@ class TrainDatasetMeta:
5353
def train(
5454
*, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
5555
) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]:
56-
# enable expandable segments for cuda allocation to improve VRAM usage
57-
torch_version = torch.__version__.split(".")
58-
torch_major, torch_minor = int(torch_version[0]), int(torch_version[1])
59-
if torch_major == 2 and torch_minor >= 2:
60-
if os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None:
61-
os.environ[
62-
"PYTORCH_CUDA_ALLOC_CONF"
63-
] = "expandable_segments:True,roundup_power2_divisions:16"
64-
65-
# load the tokenizer first
56+
# Enable expandable segments for cuda allocation to improve VRAM usage
57+
set_pytorch_cuda_alloc_conf()
58+
59+
# Load tokenizer
6660
LOG.debug(
6761
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
6862
main_process_only=True,
6963
)
7064
tokenizer = load_tokenizer(cfg)
65+
66+
# Load processor for multimodal models if needed
7167
processor = None
7268
if cfg.is_multimodal:
7369
processor = load_processor(cfg, tokenizer)
7470

71+
# Get datasets
7572
train_dataset = dataset_meta.train_dataset
7673
eval_dataset = dataset_meta.eval_dataset
7774
total_num_steps = dataset_meta.total_num_steps

src/axolotl/utils/data/sft.py

+3
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,9 @@ def prepare_dataset(cfg, tokenizer, processor=None):
119119
eval_dataset = None
120120
if cfg.dataset_exact_deduplication:
121121
LOG.info("Deduplication not available for pretrained datasets")
122+
122123
return train_dataset, eval_dataset, cfg.max_steps, prompters
124+
123125
if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False:
124126
total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
125127
if total_eval_steps == 0:
@@ -134,6 +136,7 @@ def prepare_dataset(cfg, tokenizer, processor=None):
134136
LOG.info(f"Maximum number of steps set at {total_num_steps}")
135137
else:
136138
total_num_steps = calculate_total_num_steps(cfg, train_dataset)
139+
137140
return train_dataset, eval_dataset, total_num_steps, prompters
138141

139142

src/axolotl/utils/trainer.py

+11
Original file line numberDiff line numberDiff line change
@@ -512,6 +512,17 @@ def prepare_opinionated_env(cfg):
512512
os.environ["TOKENIZERS_PARALLELISM"] = "false"
513513

514514

515+
def set_pytorch_cuda_alloc_conf():
    """Set a default ``PYTORCH_CUDA_ALLOC_CONF`` when running PyTorch >= 2.2.

    A value already present in the environment is never overwritten.
    """
    version_parts = torch.__version__.split(".")
    torch_major, torch_minor = int(version_parts[0]), int(version_parts[1])
    # Tuple comparison so that future major versions also count as ">= 2.2".
    if (torch_major, torch_minor) >= (2, 2):
        os.environ.setdefault(
            "PYTORCH_CUDA_ALLOC_CONF",
            "expandable_segments:True,roundup_power2_divisions:16",
        )
524+
525+
515526
def setup_trainer(
516527
cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps
517528
):

0 commit comments

Comments
 (0)