Skip to content

Commit

Permalink
add arguments to activate ilab plugin
Browse files Browse the repository at this point in the history
Signed-off-by: 1000960000 user <aaron.chew1@ibm.com>
  • Loading branch information
achew010 committed Aug 5, 2024
1 parent d35a139 commit 2310f32
Show file tree
Hide file tree
Showing 8 changed files with 162 additions and 23 deletions.
2 changes: 1 addition & 1 deletion tests/acceleration/spying_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def augmentation(

def get_callbacks_and_ready_for_train(self, *args, **kwargs):
spy["get_ready_for_train_calls"] += 1
return plugin_cls.get_callbacks_and_ready_for_train(self, args, **kwargs)
return plugin_cls.get_callbacks_and_ready_for_train(self, *args, **kwargs)

attributes = {
"model_loader": model_loader,
Expand Down
17 changes: 17 additions & 0 deletions tests/acceleration/test_acceleration_dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@
FastKernelsConfig,
FusedLoraConfig,
)
from tuning.config.acceleration_configs.instruct_lab_config import (
InstructLabConfig,
PaddingFree,
)
from tuning.config.acceleration_configs.quantized_lora_config import (
AutoGPTQLoraConfig,
BNBQLoraConfig,
Expand Down Expand Up @@ -65,6 +69,13 @@ def test_dataclass_parse_successfully():
assert cfg.auto_gptq is None
assert isinstance(cfg.bnb_qlora, BNBQLoraConfig)

# 3. Specifing "--padding_free" will parse a PaddingFree class
parser = transformers.HfArgumentParser(dataclass_types=InstructLabConfig)
(cfg,) = parser.parse_args_into_dataclasses(
["--padding_free", "huggingface"],
)
assert isinstance(cfg.padding_free, PaddingFree)


def test_two_dataclasses_parse_successfully_together():
"""Ensure that the two dataclasses can parse arguments successfully
Expand Down Expand Up @@ -133,3 +144,9 @@ def test_dataclass_will_fail_to_accept_illegal_args():
ValueError, match="quant_type can only be either 'nf4' or 'fp4."
):
BNBQLoraConfig(quant_type="fake-quant-type")

# 3 padding-free plugin only supports huggingface models
with pytest.raises(
ValueError, match="only 'huggingface' method currently supported."
):
PaddingFree(method="invalid-method")
105 changes: 89 additions & 16 deletions tests/acceleration/test_acceleration_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from tests.test_sft_trainer import DATA_ARGS, MODEL_ARGS, PEFT_LORA_ARGS, TRAIN_ARGS

# Local
from ..data import TWITTER_COMPLAINTS_TOKENIZED
from .spying_utils import create_mock_plugin_class_and_spy
from tuning import sft_trainer
from tuning.config.acceleration_configs import (
Expand All @@ -41,6 +42,10 @@
FastKernelsConfig,
FusedLoraConfig,
)
from tuning.config.acceleration_configs.instruct_lab_config import (
InstructLabConfig,
PaddingFree,
)
from tuning.config.acceleration_configs.quantized_lora_config import (
AutoGPTQLoraConfig,
BNBQLoraConfig,
Expand All @@ -51,7 +56,10 @@
if is_fms_accelerate_available():

# Third Party
from fms_acceleration.utils.test_utils import build_framework_and_maybe_instantiate
from fms_acceleration.utils.test_utils import (
build_framework_and_maybe_instantiate,
instantiate_model_patcher,
)

if is_fms_accelerate_available(plugins="peft"):
# Third Party
Expand All @@ -64,6 +72,10 @@
# Third Party
from fms_acceleration_foak import FastQuantizedPeftAccelerationPlugin

if is_fms_accelerate_available(plugins="ilab"):
# Third Party
from fms_acceleration_ilab import PaddingFreeAccelerationPlugin


# There are more extensive unit tests in the
# https://github.com/foundation-model-stack/fms-acceleration
Expand Down Expand Up @@ -351,6 +363,8 @@ def test_framework_intialized_properly_peft(
train_args.output_dir = tempdir
train_args.save_strategy = "no"
train_args.fp16 = True
peft_args = copy.deepcopy(PEFT_LORA_ARGS)
peft_args.target_modules = ["q_proj", "k_proj"]

installation_path, (MockedPlugin, spy) = mock_and_spy

Expand All @@ -361,13 +375,14 @@ def test_framework_intialized_properly_peft(
[([installation_path], MockedPlugin)],
instantiate=False,
):
sft_trainer.train(
model_args,
DATA_ARGS,
train_args,
PEFT_LORA_ARGS,
quantized_lora_config=quantized_lora_config,
)
with instantiate_model_patcher():
sft_trainer.train(
model_args,
DATA_ARGS,
train_args,
peft_args,
quantized_lora_config=quantized_lora_config,
)

# spy inside the train to ensure that the acceleration plugin
# was called. In the context of the AutoGPTQ plugin
Expand Down Expand Up @@ -399,6 +414,8 @@ def test_framework_intialized_properly_foak():
train_args.output_dir = tempdir
train_args.save_strategy = "no"
train_args.fp16 = True
peft_args = copy.deepcopy(PEFT_LORA_ARGS)
peft_args.target_modules = ["q_proj", "k_proj"]

# setup default quantized lora args dataclass
# - with auth gptq as the quantized method
Expand Down Expand Up @@ -428,14 +445,15 @@ def test_framework_intialized_properly_foak():
],
instantiate=False,
):
sft_trainer.train(
model_args,
DATA_ARGS,
train_args,
PEFT_LORA_ARGS,
quantized_lora_config=quantized_lora_config,
fusedops_kernels_config=fusedops_kernels_config,
)
with instantiate_model_patcher():
sft_trainer.train(
model_args,
DATA_ARGS,
train_args,
peft_args,
quantized_lora_config=quantized_lora_config,
fusedops_kernels_config=fusedops_kernels_config,
)

# spy inside the train to ensure that the AutoGPTQ plugin is called
assert spy["model_loader_calls"] == 1
Expand All @@ -446,3 +464,58 @@ def test_framework_intialized_properly_foak():
assert spy2["model_loader_calls"] == 0
assert spy2["augmentation_calls"] == 1
assert spy2["get_ready_for_train_calls"] == 1


@pytest.mark.skipif(
not is_fms_accelerate_available(plugins="ilab"),
reason="Only runs if fms-accelerate is installed along with instruct-lab plugin",
)
def test_framework_initialize_and_trains_with_ilab():
"""
Ensure that a properly configured ilab dataclass is
correctly activated in train.
"""

with tempfile.TemporaryDirectory() as tempdir:

model_args = copy.deepcopy(MODEL_ARGS)
model_args.model_name_or_path = "TinyLlama/TinyLlama-1.1B-Chat-v0.3"
train_args = copy.deepcopy(TRAIN_ARGS)
train_args.output_dir = tempdir
train_args.save_strategy = "no"
data_args = copy.deepcopy(DATA_ARGS)
data_args.training_data_path = TWITTER_COMPLAINTS_TOKENIZED
data_args.response_template = None
data_args.dataset_text_field = None

# initialize a config
instruct_lab_config = InstructLabConfig(
padding_free=PaddingFree(method="huggingface")
)

# create mocked plugin class for spying
MockedPlugin1, spy = create_mock_plugin_class_and_spy(
"PaddingFreeMock", PaddingFreeAccelerationPlugin
)

# 1. mock a plugin class
# 2. register the mocked plugins
# 3. call sft_trainer.train
with build_framework_and_maybe_instantiate(
[
(["training.attention.padding_free"], MockedPlugin1),
],
instantiate=False,
):
with instantiate_model_patcher():
sft_trainer.train(
model_args,
data_args,
train_args,
instruct_lab_config=instruct_lab_config,
)

# spy inside the train to ensure that the ilab plugin is called
assert spy["model_loader_calls"] == 0
assert spy["augmentation_calls"] == 1
assert spy["get_ready_for_train_calls"] == 1
7 changes: 4 additions & 3 deletions tests/test_sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def test_parse_arguments(job_config):
_,
_,
_,
_,
) = sft_trainer.parse_arguments(parser, job_config_copy)
assert str(model_args.torch_dtype) == "torch.bfloat16"
assert data_args.dataset_text_field == "output"
Expand All @@ -136,7 +137,7 @@ def test_parse_arguments_defaults(job_config):
assert "torch_dtype" not in job_config_defaults
assert job_config_defaults["use_flash_attn"] is False
assert "save_strategy" not in job_config_defaults
model_args, _, training_args, _, _, _, _, _, _, _ = sft_trainer.parse_arguments(
model_args, _, training_args, _, _, _, _, _, _, _, _ = sft_trainer.parse_arguments(
parser, job_config_defaults
)
assert str(model_args.torch_dtype) == "torch.bfloat16"
Expand All @@ -148,14 +149,14 @@ def test_parse_arguments_peft_method(job_config):
parser = sft_trainer.get_parser()
job_config_pt = copy.deepcopy(job_config)
job_config_pt["peft_method"] = "pt"
_, _, _, _, tune_config, _, _, _, _, _ = sft_trainer.parse_arguments(
_, _, _, _, tune_config, _, _, _, _, _, _ = sft_trainer.parse_arguments(
parser, job_config_pt
)
assert isinstance(tune_config, peft_config.PromptTuningConfig)

job_config_lora = copy.deepcopy(job_config)
job_config_lora["peft_method"] = "lora"
_, _, _, _, tune_config, _, _, _, _, _ = sft_trainer.parse_arguments(
_, _, _, _, tune_config, _, _, _, _, _, _ = sft_trainer.parse_arguments(
parser, job_config_lora
)
assert isinstance(tune_config, peft_config.LoraConfig)
Expand Down
1 change: 1 addition & 0 deletions tuning/config/acceleration_configs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@
# Local
from .acceleration_framework_config import AccelerationFrameworkConfig
from .fused_ops_and_kernels import FusedOpsAndKernelsConfig
from .instruct_lab_config import InstructLabConfig
from .quantized_lora_config import QuantizedLoraConfig
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

# Local
from .fused_ops_and_kernels import FastKernelsConfig, FusedLoraConfig
from .instruct_lab_config import PaddingFree
from .quantized_lora_config import AutoGPTQLoraConfig, BNBQLoraConfig
from tuning.utils.import_utils import is_fms_accelerate_available

Expand Down Expand Up @@ -98,6 +99,15 @@ class AccelerationFrameworkConfig:
),
] = None

padding_free: Annotated[
PaddingFree,
ConfigAnnotation(
path="training.attention",
experimental=True,
required_packages=["ilab"],
),
] = None

@staticmethod
def from_dataclasses(*dataclasses: Type):
"Convert one or many FMS config dataclasses to a monolithic AccelerationConfig"
Expand Down
25 changes: 25 additions & 0 deletions tuning/config/acceleration_configs/instruct_lab_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Standard
from dataclasses import dataclass

# Local
from .utils import ensure_nested_dataclasses_initialized, parsable_dataclass


@parsable_dataclass
@dataclass
class PaddingFree:

method: str = "huggingface"

def __post_init__(self):
if self.method != "huggingface":
raise ValueError("only 'huggingface' method currently supported.")

@dataclass
class InstructLabConfig:

padding_free: PaddingFree = None

def __post_init__(self):
# ensure nested dataclasses initialized
ensure_nested_dataclasses_initialized(self)
Loading

0 comments on commit 2310f32

Please sign in to comment.