Commit

[usability] add hymba lora target
wheresmyhair committed Dec 24, 2024
1 parent 3967232 commit ea0c5ca
Showing 1 changed file with 9 additions and 2 deletions.
src/lmflow/args.py (9 additions, 2 deletions)
@@ -238,8 +238,8 @@ class ModelArguments:
         metadata={
             "help": "Merging ratio between the fine-tuned model and the original. This is controlled by a parameter called alpha in the paper."},
     )
-    lora_target_modules: List[str] = field(
-        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name",}
+    lora_target_modules: str = field(
+        default=None, metadata={"help": "Model modules to apply LoRA to. Use comma to separate multiple modules."}
     )
     lora_dropout: float = field(
         default=0.1,
@@ -364,6 +364,9 @@ def __post_init__(self):
             if not is_flash_attn_available():
                 self.use_flash_attention = False
                 logger.warning("Flash attention is not available in the current environment. Disabling flash attention.")
+
+        if self.lora_target_modules is not None:
+            self.lora_target_modules: List[str] = split_args(self.lora_target_modules)
 
 
 @dataclass
@@ -1464,3 +1467,7 @@ class AutoArguments:
 
     def get_pipeline_args_class(pipeline_name: str):
         return PIPELINE_ARGUMENT_MAPPING[pipeline_name]
+
+
+def split_args(args):
+    return [elem.strip() for elem in args.split(",")] if isinstance(args, str) else args
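
What the change does: lora_target_modules is now declared as a plain string so it can be passed as a single comma-separated value on the command line, and __post_init__ converts it to a List[str] through the new split_args helper. A minimal, self-contained sketch of the resulting behavior (the ModelArguments class is trimmed down here, and the module names are illustrative placeholders rather than Hymba's actual layer names):

from dataclasses import dataclass, field
from typing import List


def split_args(args):
    # Comma-separated string -> list of stripped names; anything else passes through unchanged.
    return [elem.strip() for elem in args.split(",")] if isinstance(args, str) else args


@dataclass
class ModelArguments:
    # Trimmed-down stand-in for the real ModelArguments in src/lmflow/args.py.
    lora_target_modules: str = field(
        default=None,
        metadata={"help": "Model modules to apply LoRA to. Use comma to separate multiple modules."},
    )

    def __post_init__(self):
        if self.lora_target_modules is not None:
            self.lora_target_modules: List[str] = split_args(self.lora_target_modules)


# e.g. --lora_target_modules "q_proj, k_proj, v_proj" on the command line ends up as:
args = ModelArguments(lora_target_modules="q_proj, k_proj, v_proj")
print(args.lora_target_modules)  # ['q_proj', 'k_proj', 'v_proj']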
