Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

convert-diff-transformer CLI command / codepath #2197

Open
wants to merge 44 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
7a4b296
Basic evaluate CLI command / codepath (#2188)
djsaunde Dec 16, 2024
13cdffa
initial diff attn layer / model conversion implementation (support fo…
djsaunde Dec 11, 2024
7be0d74
Adding script for doing conversion; fixes and updates
djsaunde Dec 12, 2024
df1504a
adding CLI command for convert-diff-transformer
djsaunde Dec 12, 2024
e484ec7
training fixes, patching, minor cleanup
djsaunde Dec 13, 2024
849bc94
various improvemnents
djsaunde Dec 13, 2024
2f9fa4c
various improvemnents
djsaunde Dec 13, 2024
6665acf
fix model save / load logic
djsaunde Dec 17, 2024
4c050ce
pre-commit fix
djsaunde Dec 17, 2024
41ebd93
moving monkeypatch
djsaunde Dec 17, 2024
bda1eed
differential flash attention 2; cleanup
djsaunde Dec 17, 2024
63b8e42
duplicate code ignore
djsaunde Dec 17, 2024
d22e113
convert-differential-transformer test coverage
djsaunde Dec 17, 2024
ea07a70
plugin implementation
djsaunde Dec 18, 2024
0b382c8
fixes post-rebase
djsaunde Dec 18, 2024
505321a
isolating problematic test
djsaunde Dec 18, 2024
66176b3
adding split_heads argument for retaining original (Q, K) dimensionan…
djsaunde Dec 18, 2024
1d935f6
moving tests around for flash_attn install
djsaunde Dec 18, 2024
390cb57
removing extra pytest xdist args
djsaunde Dec 19, 2024
0d56582
adding yaml dumper preserving input config format
djsaunde Dec 20, 2024
fcbfa86
refactor and fixing test isolation issues
djsaunde Dec 21, 2024
5b90da0
added modeling code; cleanup + refactor
Dec 23, 2024
a3fd507
fix duplicate-code warnings
Dec 23, 2024
4ff3328
updated custom modeling code
Dec 24, 2024
eb6611d
progress on modeling code
djsaunde Dec 24, 2024
3bc568e
adding registration function
Dec 27, 2024
78e0ec0
changes
djsaunde Dec 27, 2024
e5fa842
update
djsaunde Dec 27, 2024
332ce0a
fixes and cleanup
djsaunde Dec 28, 2024
2a7f139
pre-commit fix
djsaunde Dec 28, 2024
70c4e6f
updates and cleanup
djsaunde Jan 6, 2025
443327c
CLI build_command bugfix
djsaunde Jan 8, 2025
4f804f6
adding diff attn callback, adding documentation
djsaunde Jan 10, 2025
7aca08f
adding guard statements
djsaunde Jan 10, 2025
6dd47ed
fire CLI fixes
djsaunde Jan 10, 2025
661d71a
adding diff attn negative component warmup (in progress)
djsaunde Jan 10, 2025
fd8ad6f
fixing negative component mixing
djsaunde Jan 13, 2025
2869421
inline comment change
djsaunde Jan 14, 2025
7145d52
moving diff attn code to separate repo
djsaunde Jan 23, 2025
016ba12
README update
djsaunde Jan 23, 2025
66262c3
moving out all diff attn code to plugin repo
djsaunde Jan 24, 2025
ef38f10
merging into main
djsaunde Jan 24, 2025
0e9bfa6
small fixes, improvements
djsaunde Jan 24, 2025
2daa940
Merge branch 'main' into diff-transformer
djsaunde Jan 27, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -186,3 +186,6 @@ out/

# vim
*.swp

# symlinked to axolotl-artifacts in docker containers
outputs
1 change: 0 additions & 1 deletion cicd/cicd.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ set -e
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"

pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
# pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/solo/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/
Expand Down
4 changes: 2 additions & 2 deletions cicd/multigpu.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""
modal application to run axolotl gpu tests in Modal
"""
modal application to run axolotl gpu tests in Modal
"""
# pylint: disable=duplicate-code

import os
Expand Down
4 changes: 2 additions & 2 deletions src/axolotl/cli/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
LOG = logging.getLogger(__name__)


def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> dict[str, float]:
"""
Evaluates a `transformers` model by first loading the dataset(s) specified in the
`axolotl` config, and then calling `axolotl.evaluate.evaluate`, which computes
Expand All @@ -39,7 +39,7 @@ def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

evaluate(cfg=cfg, dataset_meta=dataset_meta)
return evaluate(cfg=cfg, dataset_meta=dataset_meta)


def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
Expand Down
4 changes: 4 additions & 0 deletions src/axolotl/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import axolotl
from axolotl.cli.args import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
from axolotl.cli.plugins import setup_plugin_commands
from axolotl.cli.utils import (
add_options_from_config,
add_options_from_dataclass,
Expand Down Expand Up @@ -222,6 +223,9 @@ def fetch(directory: str, dest: Optional[str]) -> None:
fetch_from_github(f"{directory}/", dest)


setup_plugin_commands(cli)


def main():
cli()

Expand Down
36 changes: 36 additions & 0 deletions src/axolotl/cli/plugins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""Module for adding click CLI commands from axolotl plugins."""

import logging

import click

from axolotl.cli.utils import add_options_from_config, add_options_from_dataclass
from axolotl.logging_config import configure_logging
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig

configure_logging()
LOG = logging.getLogger(__name__)


def setup_plugin_commands(cli: click.core.Group) -> None:
"""
Setup CLI commands for available plugins.

Args:
cli: Click CLI object to add plugin CLI options to.
"""
try:
from axolotl_diff_transformer.convert_diff_transformer import do_cli
from axolotl_diff_transformer.plugin.cli import ConvertDiffTransformerCliArgs

@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@add_options_from_dataclass(ConvertDiffTransformerCliArgs)
@add_options_from_config(AxolotlInputConfig)
def convert_diff_transformer(config: str, **kwargs):
"""Convert model attention layers to differential attention layers."""
kwargs = {k: v for k, v in kwargs.items() if v is not None}
do_cli(config=config, **kwargs)

except ImportError as exc:
LOG.debug("axolotl-diff-transformer not found: %s", exc)
2 changes: 2 additions & 0 deletions src/axolotl/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,8 @@ def build_command(base_cmd: list[str], options: dict[str, Any]) -> list[str]:
if isinstance(value, bool):
if value:
cmd.append(f"--{key}")
else:
cmd.append(f"--no{key}")
else:
cmd.extend([f"--{key}", str(value)])

Expand Down
2 changes: 1 addition & 1 deletion src/axolotl/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
"""
Training arguments for Causal trainer

This code is duplicated due to HF TrainingArguments not setting output_dir with a defaujlt value
This code is duplicated due to HF TrainingArguments not setting output_dir with a default value
so it can't be used as a mixin.
"""

Expand Down
6 changes: 3 additions & 3 deletions src/axolotl/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os
import sys
from pathlib import Path
from typing import Dict, Optional
from typing import Optional

import torch
from accelerate.logging import get_logger
Expand All @@ -26,7 +26,7 @@

def evaluate_dataset(
trainer, dataset, dataset_type: str, flash_optimum: bool = False
) -> Optional[Dict[str, float]]:
) -> Optional[dict[str, float]]:
"""Helper function to evaluate a single dataset safely.

Args:
Expand Down Expand Up @@ -61,7 +61,7 @@ def evaluate_dataset(
return metrics


def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, float]:
def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> dict[str, float]:
"""
Evaluate a model on training and validation datasets

Expand Down
3 changes: 3 additions & 0 deletions src/axolotl/integrations/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,12 @@ def merge_input_args():
input_args: List[str] = plugin_manager.get_input_args()
plugin_classes = []
dynamic_input = ""

for plugin_args in input_args:
plugin_module, plugin_cls = plugin_args.rsplit(".", 1)
dynamic_input += f"from {plugin_module} import {plugin_cls}\n"
plugin_classes.append(plugin_cls)

if dynamic_input:
dynamic_input += f"class AxolotlConfigWCapabilities(AxolotlConfigWCapabilitiesBase, {', '.join(plugin_classes)}):\n pass\n"
dynamic_input += f"class AxolotlInputConfig(AxolotlInputConfigBase, {', '.join(plugin_classes)}):\n pass\n"
Expand All @@ -62,4 +64,5 @@ def merge_input_args():
"AxolotlConfigWCapabilities"
]
return AxolotlConfigWCapabilities, AxolotlInputConfig

return AxolotlConfigWCapabilitiesBase, AxolotlInputConfigBase
1 change: 1 addition & 0 deletions src/axolotl/utils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -812,6 +812,7 @@ def _configure_zero3_memory_efficient_loading():

if self.cfg.is_multimodal:
self.model_config.text_config = self.text_model_config

self.model = self.AutoModelLoader.from_pretrained(
self.base_model,
config=self.model_config,
Expand Down
157 changes: 157 additions & 0 deletions src/axolotl/utils/yaml.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
"""Utilities for YAML files."""

from collections import OrderedDict
from typing import Any, Dict, List, Set, Tuple, Union

import yaml


class YAMLOrderTracker:
"""Tracks the order of keys and section breaks in YAML files."""

def __init__(self, yaml_path: str):
self.yaml_path = yaml_path
self.structure, self.needs_break = self._parse_yaml_structure()

def _get_indentation_level(self, line: str) -> int:
"""Get the indentation level of a line."""
return len(line) - len(line.lstrip())

def _parse_yaml_structure(
self,
) -> Tuple[Dict[str, Union[List[str], Dict]], Set[str]]:
"""Parse the YAML file to extract structure and identify section breaks."""
with open(self.yaml_path, "r", encoding="utf-8") as file:
contents = file.readlines()

structure: OrderedDict = OrderedDict()
needs_break = set() # Track which keys should have a break before them
current_path = []
last_indentation = -1
had_empty_line = False

for line in contents:
# Track empty lines and comments
if not line.strip() or line.strip().startswith("#"):
had_empty_line = True
continue

# Get indentation level and content
indentation = self._get_indentation_level(line)
content = line.strip()

# Skip lines that don't define keys
if ":" not in content:
continue

# Extract key
key = content.split(":")[0].strip()

# If this is a top-level key and we had an empty line, mark it
if indentation == 0:
if had_empty_line:
needs_break.add(key)
had_empty_line = False

# Handle indentation changes
if indentation > last_indentation:
current_path.append(key)
elif indentation < last_indentation:
levels_up = (last_indentation - indentation) // 2
current_path = current_path[:-levels_up]
current_path[-1] = key
else:
if current_path:
current_path[-1] = key

# Update structure
current_dict = structure
for path_key in current_path[:-1]:
if path_key not in current_dict:
current_dict[path_key] = OrderedDict()
current_dict = current_dict[path_key]

if current_path:
if current_path[-1] not in current_dict:
current_dict[current_path[-1]] = OrderedDict()

last_indentation = indentation

return structure, needs_break


class OrderedDumper(yaml.SafeDumper):
"""Custom YAML dumper that maintains dictionary order."""


def represent_none(self, _):
"""Represent None values as empty fields."""
return self.represent_scalar("tag:yaml.org,2002:null", "")


def ordered_dict_representer(dumper: OrderedDumper, data: Dict) -> Any:
"""Custom representer for dictionaries that maintains order."""
return dumper.represent_mapping("tag:yaml.org,2002:map", data.items())


def reorder_dict(data: Dict, reference_structure: Dict) -> OrderedDict:
"""Reorder a dictionary based on a reference structure."""
ordered = OrderedDict()

# First add keys that are in the reference order
for key in reference_structure:
if key in data:
if isinstance(reference_structure[key], dict) and isinstance(
data[key], dict
):
ordered[key] = reorder_dict(data[key], reference_structure[key])
else:
ordered[key] = data[key]

# Then add any remaining keys that weren't in the reference
for key in data:
if key not in ordered:
ordered[key] = data[key]

return ordered


def dump_yaml_preserved_order(
data: Dict, reference_yaml_path: str, output_path: str
) -> None:
"""Dump YAML file while preserving nested order and normalized spacing."""
# Get reference structure and spacing
tracker = YAMLOrderTracker(reference_yaml_path)

# Reorder the data
ordered_data = reorder_dict(data, tracker.structure)

# Register the custom representers
OrderedDumper.add_representer(type(None), represent_none)
OrderedDumper.add_representer(dict, ordered_dict_representer)
OrderedDumper.add_representer(OrderedDict, ordered_dict_representer)

# First dump to string
yaml_str = yaml.dump(
ordered_data, Dumper=OrderedDumper, sort_keys=False, default_flow_style=False
)

# Add spacing according to reference
lines = yaml_str.split("\n")
result_lines: List[str] = []
current_line = 0

while current_line < len(lines):
line = lines[current_line]
if line.strip() and ":" in line and not line.startswith(" "): # Top-level key
key = line.split(":")[0].strip()
if key in tracker.needs_break:
# Add single empty line before this key
if result_lines and result_lines[-1] != "":
result_lines.append("")
result_lines.append(line)
current_line += 1

# Write the final result
with open(output_path, "w", encoding="utf-8") as file:
file.write("\n".join(result_lines))
4 changes: 1 addition & 3 deletions tests/cli/test_cli_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,12 @@ def _test_basic_execution(
result = cli_runner.invoke(cli, [command, str(config_path)])

assert mock.called
assert mock.call_args.args[0] == [
assert mock.call_args.args[0][:5] == [
"accelerate",
"launch",
"-m",
f"axolotl.cli.{command}",
str(config_path),
"--debug-num-examples",
"0",
]
assert mock.call_args.kwargs == {"check": True}
assert result.exit_code == 0
Expand Down
1 change: 1 addition & 0 deletions tests/cli/test_cli_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def test_build_command():
"--batch-size",
"8",
"--debug",
"--nouse-fp16",
]


Expand Down
Loading