diff --git a/.gitignore b/.gitignore index 7b604d88c7..4d7ba15a1b 100644 --- a/.gitignore +++ b/.gitignore @@ -186,3 +186,6 @@ out/ # vim *.swp + +# symlinked to axolotl-artifacts in docker containers +outputs diff --git a/cicd/cicd.sh b/cicd/cicd.sh index 34a30db448..9556b513be 100755 --- a/cicd/cicd.sh +++ b/cicd/cicd.sh @@ -4,7 +4,6 @@ set -e python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__" pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/ -# pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/ pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/ pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/solo/ pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/ diff --git a/cicd/multigpu.py b/cicd/multigpu.py index f9bad386a3..f924646309 100644 --- a/cicd/multigpu.py +++ b/cicd/multigpu.py @@ -1,6 +1,6 @@ """ - modal application to run axolotl gpu tests in Modal - """ +modal application to run axolotl gpu tests in Modal +""" # pylint: disable=duplicate-code import os diff --git a/src/axolotl/cli/evaluate.py b/src/axolotl/cli/evaluate.py index c89715719e..9370921fd6 100644 --- a/src/axolotl/cli/evaluate.py +++ b/src/axolotl/cli/evaluate.py @@ -19,7 +19,7 @@ LOG = logging.getLogger(__name__) -def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None: +def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> dict[str, float]: """ Evaluates a `transformers` model by first loading the dataset(s) specified in the `axolotl` config, and then calling `axolotl.evaluate.evaluate`, which computes @@ -39,7 +39,7 @@ def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None: else: dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - evaluate(cfg=cfg, dataset_meta=dataset_meta) + return evaluate(cfg=cfg, dataset_meta=dataset_meta) def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None: diff --git a/src/axolotl/cli/main.py b/src/axolotl/cli/main.py index 43e2de3db6..9df619ca89 100644 --- a/src/axolotl/cli/main.py +++ b/src/axolotl/cli/main.py @@ -8,6 +8,7 @@ import axolotl from axolotl.cli.args import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs +from axolotl.cli.plugins import setup_plugin_commands from axolotl.cli.utils import ( add_options_from_config, add_options_from_dataclass, @@ -222,6 +223,9 @@ def fetch(directory: str, dest: Optional[str]) -> None: fetch_from_github(f"{directory}/", dest) +setup_plugin_commands(cli) + + def main(): cli() diff --git a/src/axolotl/cli/plugins.py b/src/axolotl/cli/plugins.py new file mode 100644 index 0000000000..7f0a4e6fd5 --- /dev/null +++ b/src/axolotl/cli/plugins.py @@ -0,0 +1,36 @@ +"""Module for adding click CLI commands from axolotl plugins.""" + +import logging + +import click + +from axolotl.cli.utils import add_options_from_config, add_options_from_dataclass +from axolotl.logging_config import configure_logging +from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig + +configure_logging() +LOG = logging.getLogger(__name__) + + +def setup_plugin_commands(cli: click.core.Group) -> None: + """ + Setup CLI commands for available plugins. + + Args: + cli: Click CLI object to add plugin CLI options to. + """ + try: + from axolotl_diff_transformer.convert_diff_transformer import do_cli + from axolotl_diff_transformer.plugin.cli import ConvertDiffTransformerCliArgs + + @cli.command() + @click.argument("config", type=click.Path(exists=True, path_type=str)) + @add_options_from_dataclass(ConvertDiffTransformerCliArgs) + @add_options_from_config(AxolotlInputConfig) + def convert_diff_transformer(config: str, **kwargs): + """Convert model attention layers to differential attention layers.""" + kwargs = {k: v for k, v in kwargs.items() if v is not None} + do_cli(config=config, **kwargs) + + except ImportError as exc: + LOG.debug("axolotl-diff-transformer not found: %s", exc) diff --git a/src/axolotl/cli/utils.py b/src/axolotl/cli/utils.py index addfa0ab9c..c57f79123a 100644 --- a/src/axolotl/cli/utils.py +++ b/src/axolotl/cli/utils.py @@ -157,6 +157,8 @@ def build_command(base_cmd: list[str], options: dict[str, Any]) -> list[str]: if isinstance(value, bool): if value: cmd.append(f"--{key}") + else: + cmd.append(f"--no{key}") else: cmd.extend([f"--{key}", str(value)]) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index d63a10e742..c7340e4f54 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -297,7 +297,7 @@ class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments): """ Training arguments for Causal trainer - This code is duplicated due to HF TrainingArguments not setting output_dir with a defaujlt value + This code is duplicated due to HF TrainingArguments not setting output_dir with a default value so it can't be used as a mixin. """ diff --git a/src/axolotl/evaluate.py b/src/axolotl/evaluate.py index 8d9ddc6abf..db8490432c 100644 --- a/src/axolotl/evaluate.py +++ b/src/axolotl/evaluate.py @@ -4,7 +4,7 @@ import os import sys from pathlib import Path -from typing import Dict, Optional +from typing import Optional import torch from accelerate.logging import get_logger @@ -26,7 +26,7 @@ def evaluate_dataset( trainer, dataset, dataset_type: str, flash_optimum: bool = False -) -> Optional[Dict[str, float]]: +) -> Optional[dict[str, float]]: """Helper function to evaluate a single dataset safely. Args: @@ -61,7 +61,7 @@ def evaluate_dataset( return metrics -def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, float]: +def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> dict[str, float]: """ Evaluate a model on training and validation datasets diff --git a/src/axolotl/integrations/config.py b/src/axolotl/integrations/config.py index b4ffd6758f..f7d35fcf89 100644 --- a/src/axolotl/integrations/config.py +++ b/src/axolotl/integrations/config.py @@ -43,10 +43,12 @@ def merge_input_args(): input_args: List[str] = plugin_manager.get_input_args() plugin_classes = [] dynamic_input = "" + for plugin_args in input_args: plugin_module, plugin_cls = plugin_args.rsplit(".", 1) dynamic_input += f"from {plugin_module} import {plugin_cls}\n" plugin_classes.append(plugin_cls) + if dynamic_input: dynamic_input += f"class AxolotlConfigWCapabilities(AxolotlConfigWCapabilitiesBase, {', '.join(plugin_classes)}):\n pass\n" dynamic_input += f"class AxolotlInputConfig(AxolotlInputConfigBase, {', '.join(plugin_classes)}):\n pass\n" @@ -62,4 +64,5 @@ def merge_input_args(): "AxolotlConfigWCapabilities" ] return AxolotlConfigWCapabilities, AxolotlInputConfig + return AxolotlConfigWCapabilitiesBase, AxolotlInputConfigBase diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index c4b8f05b98..48e2c65584 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -812,6 +812,7 @@ def _configure_zero3_memory_efficient_loading(): if self.cfg.is_multimodal: self.model_config.text_config = self.text_model_config + self.model = self.AutoModelLoader.from_pretrained( self.base_model, config=self.model_config, diff --git a/src/axolotl/utils/yaml.py b/src/axolotl/utils/yaml.py new file mode 100644 index 0000000000..c5c9e74ae4 --- /dev/null +++ b/src/axolotl/utils/yaml.py @@ -0,0 +1,157 @@ +"""Utilities for YAML files.""" + +from collections import OrderedDict +from typing import Any, Dict, List, Set, Tuple, Union + +import yaml + + +class YAMLOrderTracker: + """Tracks the order of keys and section breaks in YAML files.""" + + def __init__(self, yaml_path: str): + self.yaml_path = yaml_path + self.structure, self.needs_break = self._parse_yaml_structure() + + def _get_indentation_level(self, line: str) -> int: + """Get the indentation level of a line.""" + return len(line) - len(line.lstrip()) + + def _parse_yaml_structure( + self, + ) -> Tuple[Dict[str, Union[List[str], Dict]], Set[str]]: + """Parse the YAML file to extract structure and identify section breaks.""" + with open(self.yaml_path, "r", encoding="utf-8") as file: + contents = file.readlines() + + structure: OrderedDict = OrderedDict() + needs_break = set() # Track which keys should have a break before them + current_path = [] + last_indentation = -1 + had_empty_line = False + + for line in contents: + # Track empty lines and comments + if not line.strip() or line.strip().startswith("#"): + had_empty_line = True + continue + + # Get indentation level and content + indentation = self._get_indentation_level(line) + content = line.strip() + + # Skip lines that don't define keys + if ":" not in content: + continue + + # Extract key + key = content.split(":")[0].strip() + + # If this is a top-level key and we had an empty line, mark it + if indentation == 0: + if had_empty_line: + needs_break.add(key) + had_empty_line = False + + # Handle indentation changes + if indentation > last_indentation: + current_path.append(key) + elif indentation < last_indentation: + levels_up = (last_indentation - indentation) // 2 + current_path = current_path[:-levels_up] + current_path[-1] = key + else: + if current_path: + current_path[-1] = key + + # Update structure + current_dict = structure + for path_key in current_path[:-1]: + if path_key not in current_dict: + current_dict[path_key] = OrderedDict() + current_dict = current_dict[path_key] + + if current_path: + if current_path[-1] not in current_dict: + current_dict[current_path[-1]] = OrderedDict() + + last_indentation = indentation + + return structure, needs_break + + +class OrderedDumper(yaml.SafeDumper): + """Custom YAML dumper that maintains dictionary order.""" + + +def represent_none(self, _): + """Represent None values as empty fields.""" + return self.represent_scalar("tag:yaml.org,2002:null", "") + + +def ordered_dict_representer(dumper: OrderedDumper, data: Dict) -> Any: + """Custom representer for dictionaries that maintains order.""" + return dumper.represent_mapping("tag:yaml.org,2002:map", data.items()) + + +def reorder_dict(data: Dict, reference_structure: Dict) -> OrderedDict: + """Reorder a dictionary based on a reference structure.""" + ordered = OrderedDict() + + # First add keys that are in the reference order + for key in reference_structure: + if key in data: + if isinstance(reference_structure[key], dict) and isinstance( + data[key], dict + ): + ordered[key] = reorder_dict(data[key], reference_structure[key]) + else: + ordered[key] = data[key] + + # Then add any remaining keys that weren't in the reference + for key in data: + if key not in ordered: + ordered[key] = data[key] + + return ordered + + +def dump_yaml_preserved_order( + data: Dict, reference_yaml_path: str, output_path: str +) -> None: + """Dump YAML file while preserving nested order and normalized spacing.""" + # Get reference structure and spacing + tracker = YAMLOrderTracker(reference_yaml_path) + + # Reorder the data + ordered_data = reorder_dict(data, tracker.structure) + + # Register the custom representers + OrderedDumper.add_representer(type(None), represent_none) + OrderedDumper.add_representer(dict, ordered_dict_representer) + OrderedDumper.add_representer(OrderedDict, ordered_dict_representer) + + # First dump to string + yaml_str = yaml.dump( + ordered_data, Dumper=OrderedDumper, sort_keys=False, default_flow_style=False + ) + + # Add spacing according to reference + lines = yaml_str.split("\n") + result_lines: List[str] = [] + current_line = 0 + + while current_line < len(lines): + line = lines[current_line] + if line.strip() and ":" in line and not line.startswith(" "): # Top-level key + key = line.split(":")[0].strip() + if key in tracker.needs_break: + # Add single empty line before this key + if result_lines and result_lines[-1] != "": + result_lines.append("") + result_lines.append(line) + current_line += 1 + + # Write the final result + with open(output_path, "w", encoding="utf-8") as file: + file.write("\n".join(result_lines)) diff --git a/tests/cli/test_cli_base.py b/tests/cli/test_cli_base.py index 6dbae045f6..f8f1edfa3e 100644 --- a/tests/cli/test_cli_base.py +++ b/tests/cli/test_cli_base.py @@ -43,14 +43,12 @@ def _test_basic_execution( result = cli_runner.invoke(cli, [command, str(config_path)]) assert mock.called - assert mock.call_args.args[0] == [ + assert mock.call_args.args[0][:5] == [ "accelerate", "launch", "-m", f"axolotl.cli.{command}", str(config_path), - "--debug-num-examples", - "0", ] assert mock.call_args.kwargs == {"check": True} assert result.exit_code == 0 diff --git a/tests/cli/test_cli_interface.py b/tests/cli/test_cli_interface.py index 8b5fec17f2..935fb85b86 100644 --- a/tests/cli/test_cli_interface.py +++ b/tests/cli/test_cli_interface.py @@ -23,6 +23,7 @@ def test_build_command(): "--batch-size", "8", "--debug", + "--nouse-fp16", ]