Add ia3 and adalora support (huggingface#809)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
sywangyi authored and vivekgoe committed Jun 5, 2024
1 parent 345f957 commit a43f5e8
Showing 6 changed files with 132 additions and 18 deletions.
2 changes: 1 addition & 1 deletion examples/language-modeling/requirements.txt
@@ -4,4 +4,4 @@ sentencepiece != 0.1.92
 protobuf
 evaluate
 scikit-learn
-peft == 0.6.2
+peft == 0.10.0
87 changes: 76 additions & 11 deletions examples/language-modeling/run_lora_clm.py
@@ -30,7 +30,7 @@
 import torch
 import transformers
 from datasets import load_dataset
-from peft import LoraConfig, TaskType, get_peft_model, tuners
+from peft import AdaLoraConfig, IA3Config, LoraConfig, TaskType, get_peft_model, tuners
 from peft.utils.other import fsdp_auto_wrap_policy
 from transformers import (
     AutoConfig,
@@ -289,12 +289,53 @@ class FinetuneArguments:
     )
     lora_target_modules: List[str] = field(
         default_factory=lambda: None,
-        metadata={"help": "Target modules for the LoRA method."},
+        metadata={"help": "Target modules for the LoRA/AdaLoRA method."},
     )
     train_on_inputs: bool = field(
         default=True,
         metadata={"help": "if False, masks out inputs in loss"},
     )
+    adalora_init_r: int = field(
+        default=12,
+        metadata={"help": "Initial AdaLoRA rank"},
+    )
+    adalora_target_r: int = field(
+        default=4,
+        metadata={"help": "Target AdaLoRA rank"},
+    )
+    adalora_tinit: int = field(
+        default=50,
+        metadata={"help": "Number of warmup steps for AdaLoRA wherein no pruning is performed"},
+    )
+    adalora_tfinal: int = field(
+        default=100,
+        metadata={
+            "help": "Fix the resulting budget distribution and fine-tune the model for tfinal steps when using AdaLoRA"
+        },
+    )
+    adalora_delta_t: int = field(
+        default=10,
+        metadata={"help": "Interval of steps for AdaLoRA to update rank"},
+    )
+    adalora_orth_reg_weight: float = field(
+        default=0.5,
+        metadata={"help": "Orthogonal regularization weight for AdaLoRA"},
+    )
+    peft_type: str = field(
+        default="lora",
+        metadata={
+            "help": ("The PEFT type to use."),
+            "choices": ["lora", "ia3", "adalora"],
+        },
+    )
+    ia3_target_modules: List[str] = field(
+        default_factory=lambda: None,
+        metadata={"help": "Target modules for the IA3 method."},
+    )
+    feedforward_modules: List[str] = field(
+        default_factory=lambda: None,
+        metadata={"help": "Target feedforward modules for the IA3 method."},
+    )


 PROMPT_DICT = {
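For orientation, the sketch below illustrates how these new AdaLoRA knobs interact, using the cubic budget schedule described in the AdaLoRA paper; the function, step values, and total_step are illustrative assumptions, not code from this repository.

# Illustrative only: AdaLoRA's rank budget as a function of the flags above.
def adalora_budget(step, init_r=12, target_r=4, tinit=50, tfinal=100, total_step=1000):
    if step <= tinit:                   # warmup: keep the initial rank, no pruning yet
        return init_r
    if step > total_step - tfinal:      # final phase: budget fixed at the target rank
        return target_r
    progress = (step - tinit) / (total_step - tfinal - tinit)
    return int(target_r + (init_r - target_r) * (1 - progress) ** 3)

# In between, ranks are re-allocated every `adalora_delta_t` steps.
for step in range(0, 1001, 100):
    print(step, adalora_budget(step))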
@@ -684,22 +725,46 @@ def compute_metrics(eval_preds):

     if training_args.do_train or training_args.do_eval:
         # PEFT settings
-        peft_config = LoraConfig(
-            r=finetune_args.lora_rank,
-            lora_alpha=finetune_args.lora_alpha,
-            lora_dropout=finetune_args.lora_dropout,
-            target_modules=finetune_args.lora_target_modules,
-            bias="none",
-            task_type=TaskType.CAUSAL_LM,
-        )
+        if finetune_args.peft_type == "lora":
+            peft_config = LoraConfig(
+                r=finetune_args.lora_rank,
+                lora_alpha=finetune_args.lora_alpha,
+                lora_dropout=finetune_args.lora_dropout,
+                target_modules=finetune_args.lora_target_modules,
+                bias="none",
+                task_type=TaskType.CAUSAL_LM,
+            )
+        elif finetune_args.peft_type == "adalora":
+            peft_config = AdaLoraConfig(
+                init_r=finetune_args.adalora_init_r,
+                target_r=finetune_args.adalora_target_r,
+                tinit=finetune_args.adalora_tinit,
+                tfinal=finetune_args.adalora_tfinal,
+                deltaT=finetune_args.adalora_delta_t,
+                lora_alpha=finetune_args.lora_alpha,
+                lora_dropout=finetune_args.lora_dropout,
+                target_modules=finetune_args.lora_target_modules,
+                orth_reg_weight=finetune_args.adalora_orth_reg_weight,
+                bias="none",
+                task_type=TaskType.CAUSAL_LM,
+            )
+            from optimum.habana.peft.layer import GaudiAdaloraLayerSVDLinearForward
+
+            tuners.adalora.layer.SVDLinear.forward = GaudiAdaloraLayerSVDLinearForward
+        elif finetune_args.peft_type == "ia3":
+            peft_config = IA3Config(
+                target_modules=finetune_args.ia3_target_modules,
+                feedforward_modules=finetune_args.feedforward_modules,
+                task_type=TaskType.CAUSAL_LM,
+            )
         if training_args.gradient_checkpointing:
             model.enable_input_require_grads()
         if training_args.torch_compile:
             from optimum.habana.peft.layer import GaudiLoraLayerLinearForward

             tuners.lora.layer.Linear.forward = GaudiLoraLayerLinearForward
         lora_model = get_peft_model(model, peft_config)
-        if training_args.bf16:
+        if training_args.bf16 and finetune_args.peft_type != "ia3":
             lora_model = lora_model.to(torch.bfloat16)
         lora_model.print_trainable_parameters()
         gaudi_config = GaudiConfig()
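For reference, a minimal standalone sketch of what the new ia3 branch builds once it reaches get_peft_model. The checkpoint name and the module names below are assumptions chosen for illustration (they target an OPT model); the example script instead takes them from --ia3_target_modules and --feedforward_modules.

from peft import IA3Config, TaskType, get_peft_model
from transformers import AutoModelForCausalLM

# Arbitrary small checkpoint, used only to make the sketch runnable.
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
peft_config = IA3Config(
    target_modules=["k_proj", "v_proj", "fc2"],  # modules rescaled by learned (IA)^3 vectors
    feedforward_modules=["fc2"],                 # must be a subset of target_modules
    task_type=TaskType.CAUSAL_LM,
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
# Note that the script change above deliberately skips the .to(torch.bfloat16) cast for ia3.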
5 changes: 4 additions & 1 deletion examples/text-generation/utils.py
@@ -326,7 +326,10 @@ def peft_model(args, model_dtype, logger, **model_kwargs):
     model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, torch_dtype=model_dtype, **model_kwargs)
     model = PeftModel.from_pretrained(model, args.peft_model, torch_dtype=model_dtype, **model_kwargs)

-    return model.merge_and_unload()
+    model = model.merge_and_unload()
+    if model_dtype == torch.bfloat16:
+        model = model.to(torch.bfloat16)
+    return model


 def setup_tokenizer(args, model):
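Sketched end to end, the pattern this change implements looks roughly as follows; the model identifier and adapter path are placeholders, and the comment states the presumed motivation rather than anything asserted by the diff.

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("base-model-id", torch_dtype=torch.bfloat16)
model = PeftModel.from_pretrained(base, "path/to/adapter", torch_dtype=torch.bfloat16)
model = model.merge_and_unload()
# Merging can leave some weights in a higher precision when the adapter parameters were
# kept in float32 (e.g. the ia3 path), so the merged model is cast back to bfloat16.
model = model.to(torch.bfloat16)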
2 changes: 1 addition & 1 deletion optimum/habana/peft/__init__.py
@@ -1 +1 @@
-from .layer import GaudiLoraLayerLinearForward
+from .layer import GaudiAdaloraLayerSVDLinearForward, GaudiLoraLayerLinearForward
30 changes: 30 additions & 0 deletions optimum/habana/peft/layer.py
@@ -29,3 +29,33 @@ def GaudiLoraLayerLinearForward(self, x: torch.Tensor, *args: Any, **kwargs: Any

     result = result.to(previous_dtype)
     return result
+
+
+def GaudiAdaloraLayerSVDLinearForward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
+    """
+    Copied from SVDLinear.forward: https://github.com/huggingface/peft/blob/v0.9.0/src/peft/tuners/adalora/layer.py#L158
+    The only differences are:
+    - fix batch_gemm failure for BF16 case
+    """
+    if self.disable_adapters:
+        if self.merged:
+            self.unmerge()
+        result = self.base_layer(x, *args, **kwargs)
+    elif self.merged:
+        result = self.base_layer(x, *args, **kwargs)
+    else:
+        result = self.base_layer(x, *args, **kwargs)
+        for active_adapter in self.active_adapters:
+            if active_adapter not in self.lora_A.keys():
+                continue
+            lora_A = self.lora_A[active_adapter]
+            lora_B = self.lora_B[active_adapter]
+            lora_E = self.lora_E[active_adapter]
+            dropout = self.lora_dropout[active_adapter]
+            scaling = self.scaling[active_adapter]
+            ranknum = self.ranknum[active_adapter] + 1e-5
+
+            x = x.to(lora_A.dtype)
+            result += (dropout(x) @ (lora_A * lora_E).T @ lora_B.T) * (scaling / ranknum)
+
+    return result
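A tiny standalone illustration of the SVD-style update computed in the loop above, with arbitrary shapes and plain CPU tensors (this is not the Gaudi kernel path, just the algebra):

import torch

batch, in_features, out_features, r = 2, 16, 32, 4
x = torch.randn(batch, in_features)
lora_A = torch.randn(r, in_features)   # right factor
lora_B = torch.randn(out_features, r)  # left factor
lora_E = torch.randn(r, 1)             # per-rank "singular values" that AdaLoRA prunes
scaling, ranknum = 16.0, float(r)

base_out = torch.zeros(batch, out_features)  # stands in for self.base_layer(x)
delta = (x @ (lora_A * lora_E).T @ lora_B.T) * (scaling / ranknum)
print((base_out + delta).shape)  # torch.Size([2, 32])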
24 changes: 20 additions & 4 deletions optimum/habana/transformers/trainer.py
@@ -41,7 +41,11 @@
 from transformers.data.data_collator import DataCollator
 from transformers.debug_utils import DebugOption, DebugUnderflowOverflow
 from transformers.integrations import hp_params
-from transformers.integrations.deepspeed import deepspeed_load_checkpoint, is_deepspeed_available
+from transformers.integrations.deepspeed import (
+    deepspeed_load_checkpoint,
+    is_deepspeed_available,
+    is_deepspeed_zero3_enabled,
+)
 from transformers.modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase
 from transformers.trainer import _get_fsdp_ckpt_kwargs
@@ -116,6 +120,7 @@

 if is_peft_available():
     from peft import PeftModel
+    from peft.utils import PeftType


 if is_deepspeed_available():
@@ -849,6 +854,10 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args):
         hb_profiler.start()

         total_batched_samples = 0
+        if _is_peft_model(self.model) and self.model.peft_type == PeftType.ADALORA:
+            self.model.base_model.peft_config[self.model.trainable_adapter_name].total_step = max_steps
+            if max_steps < self.model.base_model.peft_config[self.model.trainable_adapter_name].tfinal:
+                self.model.base_model.peft_config[self.model.trainable_adapter_name].tfinal = 0
         for epoch in range(epochs_trained, num_train_epochs):
             epoch_iterator = train_dataloader
             if hasattr(epoch_iterator, "set_epoch"):
@@ -990,7 +999,6 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args):
                         # Delay optimizer scheduling until metrics are generated
                         if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                             self.lr_scheduler.step()
-
                     self._zero_model_grad(model)

                     self.state.global_step += 1
@@ -1539,8 +1547,16 @@ def training_step(self, model: torch.nn.Module, inputs: Dict[str, Union[torch.Te
         if self.args.use_lazy_mode and self.args.pipelining_fwd_bwd:
             self.htcore.mark_step()

-        self.accelerator.backward(loss)
-
+        if _is_peft_model(self.model) and self.model.peft_type == PeftType.ADALORA:
+            if self.is_deepspeed_enabled and not is_deepspeed_zero3_enabled():
+                self.accelerator.deepspeed_engine_wrapped.engine.backward(loss)
+                self.model.base_model.update_and_allocate(self.state.global_step)
+                self.accelerator.deepspeed_engine_wrapped.engine.step()
+            else:
+                self.accelerator.backward(loss)
+                self.model.base_model.update_and_allocate(self.state.global_step)
+        else:
+            self.accelerator.backward(loss)
         return loss.detach() / self.args.gradient_accumulation_steps

     def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
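Outside the trainer, the same AdaLoRA bookkeeping looks roughly like the loop below. The checkpoint, target modules, learning rate, and toy batch are placeholder assumptions; the point is that total_step is set before training and that update_and_allocate runs after backward (it reads the gradients) and before the gradients are zeroed, mirroring the training_step change above.

import torch
from peft import AdaLoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer

max_steps = 60  # illustrative; the trainer writes its real max_steps into the adapter config
model = get_peft_model(
    AutoModelForCausalLM.from_pretrained("facebook/opt-125m"),  # arbitrary small checkpoint
    AdaLoraConfig(
        init_r=12, target_r=4, tinit=50, tfinal=0, deltaT=10,   # tfinal zeroed since max_steps < 100, as in the hunk above
        total_step=max_steps, target_modules=["q_proj", "v_proj"],
        task_type=TaskType.CAUSAL_LM,
    ),
)
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

batch = tokenizer("hello world", return_tensors="pt")
for step in range(max_steps):
    loss = model(**batch, labels=batch["input_ids"]).loss
    loss.backward()
    optimizer.step()
    # Needs the gradients from backward, so it must run before zero_grad();
    # the trainer change above calls it right after accelerator.backward(loss).
    model.base_model.update_and_allocate(step)
    optimizer.zero_grad()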