
Commit adc024b

refactor and fixing test isolation issues

Parent: 2717b97

7 files changed: +156, -149 lines


src/axolotl/cli/integrations/convert_diff_transformer.py (+3, -1)

@@ -1,4 +1,5 @@
-"""CLI to convert a transformers model's attns to diff attns."""
+"""CLI to convert a transformers model's attention layers to differential attention layers."""
+
 import logging
 import warnings
 from pathlib import Path
@@ -127,6 +128,7 @@ def convert_diff_transformer(cfg, cli_args, config_path):
     else:
         modified_cfg["plugins"] = [plugin_class]

+    # Write out the updated axolotl config while preserving original ordering / formatting
     dump_yaml_preserved_order(
         data=modified_cfg,
         reference_yaml_path=config_path,
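The added comment makes the intent of this block explicit: the updated config is written back through dump_yaml_preserved_order so that key order survives the round trip. For background (an illustrative snippet using plain PyYAML, not code from this commit), the default dumper alphabetizes mapping keys, which is exactly what the custom dump path avoids:

import yaml

config = {"base_model": "some/model", "adapter": "lora"}  # hypothetical config
print(yaml.safe_dump(config))                   # keys re-sorted: adapter first
print(yaml.safe_dump(config, sort_keys=False))  # insertion order preserved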

src/axolotl/common/cli.py (+5, -13)

@@ -12,14 +12,12 @@
 from axolotl.utils.models import load_model, load_tokenizer

 configure_logging()
-LOG = logging.getLogger("axolotl.common.cli")
+LOG = logging.getLogger(__name__)


 @dataclass
 class PreprocessCliArgs:
-    """
-    dataclass with arguments for preprocessing only
-    """
+    """dataclass with arguments for preprocessing only"""

     debug: bool = field(default=False)
     debug_text_only: bool = field(default=False)
@@ -30,9 +28,7 @@ class PreprocessCliArgs:

 @dataclass
 class TrainerCliArgs:
-    """
-    dataclass with various non-training arguments
-    """
+    """dataclass with various non-training arguments"""

     debug: bool = field(default=False)
     debug_text_only: bool = field(default=False)
@@ -45,9 +41,7 @@ class TrainerCliArgs:

 @dataclass
 class EvaluateCliArgs:
-    """
-    dataclass with various evaluation arguments
-    """
+    """dataclass with various evaluation arguments"""

     debug: bool = field(default=False)
     debug_text_only: bool = field(default=False)
@@ -56,9 +50,7 @@ class EvaluateCliArgs:

 @dataclass
 class ConvertDiffTransformerCliArgs:
-    """
-    dataclass with arguments for convert-diff-transformer CLI
-    """
+    """dataclass with arguments for convert-diff-transformer CLI"""

     debug: bool = field(default=False)
     zero_init: bool = field(default=False)
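The logger rename is behavior-preserving: inside src/axolotl/common/cli.py, __name__ evaluates to "axolotl.common.cli", so the same logger object is returned as before, but the name now tracks the module path automatically if the file is ever moved or renamed. A minimal illustration:

import logging

# Equivalent to logging.getLogger("axolotl.common.cli") when this module
# lives at axolotl/common/cli.py, with no hard-coded string to go stale.
LOG = logging.getLogger(__name__)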

src/axolotl/integrations/diff_transformer/convert.py (+7, -3)

@@ -98,9 +98,13 @@ def convert_module(module):

     # Iterate through module children, convert any attn layers to diff attn
     for name, child in module.named_children():
-        if isinstance(child, tuple(ATTENTION_MAPPING.keys())):
-            # Choose appropriate differential attention class
-            attention_class = ATTENTION_MAPPING[type(child)]
+        child_class_name = type(child).__name__
+        if child_class_name in [k.__name__ for k in ATTENTION_MAPPING]:
+            # Find matching attention class by name
+            for orig_class, diff_class in ATTENTION_MAPPING.items():
+                if orig_class.__name__ == child_class_name:
+                    attention_class = diff_class
+                    break

             layer_type = type(child).__name__
             logger.info(
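This hunk is the core of the test-isolation fix named in the commit message: when test runs re-import or reload transformers modules, the attention classes attached to a freshly loaded model are no longer the same objects as the keys of ATTENTION_MAPPING, so the old isinstance check and dict-by-type lookup silently miss them. Matching on __name__ is identity-agnostic. A more compact sketch of the same idea (the name-keyed dict and helper below are hypothetical, not part of this diff):

# Key the mapping by class name once; lookups then survive re-imports.
ATTENTION_MAPPING_BY_NAME = {k.__name__: v for k, v in ATTENTION_MAPPING.items()}

def resolve_attention_class(child):
    # Returns the differential attention class for an attention layer,
    # or None for non-attention children.
    return ATTENTION_MAPPING_BY_NAME.get(type(child).__name__)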

src/axolotl/integrations/diff_transformer/diff_attn.py (+2, -1)

@@ -21,7 +21,6 @@


 def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
-    """torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
     batch_size, n_kv_heads, slen, head_dim = x.shape
     if n_rep == 1:
         return x
@@ -249,6 +248,7 @@ def forward(
 class LlamaDifferentialSdpaAttention(DifferentialAttentionBase):
     """SDPA-based implementation of differential attention."""

+    # pylint: disable=duplicate-code
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -312,6 +312,7 @@ def forward(
 class LlamaDifferentialFlashAttention2(DifferentialAttentionBase):
     """Flash Attention 2-based implementation of differential attention."""

+    # pylint: disable=duplicate-code
     def forward(
         self,
         hidden_states: torch.Tensor,
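The deleted one-line docstring was the only note on what repeat_kv computes: it expands grouped key/value heads along dim 1 so they match the number of query heads. A behaviorally equivalent reference (illustrative, not the in-tree implementation):

import torch

def repeat_kv_reference(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (batch, n_kv_heads, seq, dim) -> (batch, n_kv_heads * n_rep, seq, dim)
    return x if n_rep == 1 else torch.repeat_interleave(x, repeats=n_rep, dim=1)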

src/axolotl/utils/yaml.py (+7, -1)

@@ -84,6 +84,11 @@ class OrderedDumper(yaml.SafeDumper):
     """Custom YAML dumper that maintains dictionary order."""


+def represent_none(self, _):
+    """Represent None values as empty fields."""
+    return self.represent_scalar("tag:yaml.org,2002:null", "")
+
+
 def ordered_dict_representer(dumper: OrderedDumper, data: Dict) -> Any:
     """Custom representer for dictionaries that maintains order."""
     return dumper.represent_mapping("tag:yaml.org,2002:map", data.items())
@@ -121,7 +126,8 @@ def dump_yaml_preserved_order(
     # Reorder the data
     ordered_data = reorder_dict(data, tracker.structure)

-    # Register the custom representer
+    # Register the custom representers
+    OrderedDumper.add_representer(type(None), represent_none)
     OrderedDumper.add_representer(dict, ordered_dict_representer)
     OrderedDumper.add_representer(OrderedDict, ordered_dict_representer)
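With represent_none registered, None-valued config keys serialize as bare "key:" lines rather than "key: null", matching how axolotl YAML configs are typically written by hand. A self-contained demonstration of the technique using only standard PyYAML calls (the dumper name here is illustrative):

import yaml

class DemoDumper(yaml.SafeDumper):
    """Stand-in for axolotl's OrderedDumper."""

def represent_none(self, _):
    return self.represent_scalar("tag:yaml.org,2002:null", "")

DemoDumper.add_representer(type(None), represent_none)

print(yaml.dump({"wandb_project": None}, Dumper=DemoDumper))
# prints "wandb_project:" rather than "wandb_project: null"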

tests/e2e/integrations/convert_diff_transformer/conftest.py (+2, -2)

@@ -4,7 +4,7 @@
 from click.testing import CliRunner


-@pytest.fixture()
+@pytest.fixture(scope="class")
 def base_config():
     """Basic config for testing."""
     return {
@@ -26,6 +26,6 @@ def base_config():
     }


-@pytest.fixture
+@pytest.fixture(scope="class")
 def cli_runner():
     return CliRunner()
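Widening the fixtures to scope="class" means pytest creates base_config and cli_runner once per test class and caches them, rather than rebuilding them for every test function. The trade-off, shown in a generic illustration below (not from this repo), is that all tests in a class share one object, so class-scoped fixtures should be treated as read-only:

import pytest

@pytest.fixture(scope="class")
def shared_state():
    return []

class TestScopeSharing:
    def test_mutates(self, shared_state):
        shared_state.append("x")

    def test_sees_mutation(self, shared_state):
        # Same cached object as in test_mutates: class scope shares it.
        assert shared_state == ["x"]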
