diff --git a/examples/pissa_finetuning/preprocess.py b/examples/pissa_finetuning/preprocess.py index c17f75e6e5..0b0d515197 100644 --- a/examples/pissa_finetuning/preprocess.py +++ b/examples/pissa_finetuning/preprocess.py @@ -21,10 +21,13 @@ from peft import LoraConfig, get_peft_model -parser = argparse.ArgumentParser( - description="Merge Adapter to Base Model", help="The name or path of the fp32/16 base model." +parser = argparse.ArgumentParser() +parser.add_argument( + "--base_model_name_or_path", + description="Merge Adapter to Base Model", + help="The name or path of the fp32/16 base model.", ) -parser.add_argument("--base_model_name_or_path", type=str, default="bf16") +parser.add_argument("--output_dir", type=str, help="The directory to save the PiSSA model.") parser.add_argument("--bits", type=str, default="bf16", choices=["bf16", "fp16", "fp32"]) parser.add_argument( "--init_lora_weights", type=str, default="pissa", help="(`['pissa', 'pissa_niter_[number of iters]']`)"