|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
| 15 | +import inspect |
15 | 16 | import itertools |
16 | 17 | import warnings |
17 | 18 | from collections.abc import Callable |
|
32 | 33 |
|
33 | 34 | if is_peft_available(): |
34 | 35 | import peft |
35 | | - from peft import PeftConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training |
| 36 | + from peft import PeftConfig, PeftModel, get_peft_model |
36 | 37 |
|
37 | 38 |
|
38 | 39 | if TYPE_CHECKING: |
@@ -471,6 +472,51 @@ def on_after_outer_forward(self, wrapper_module: nn.Module, original_module: nn. |
471 | 472 | pass |
472 | 473 |
|
473 | 474 |
|
| 475 | +def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True, gradient_checkpointing_kwargs=None): |
| 476 | + r""" |
| 477 | +    Prepare a k-bit quantized transformers model for QLoRA training: freeze weights, upcast norms, enable checkpointing. |
| 478 | + """ |
| 479 | + loaded_in_kbit = getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False) |
| 480 | + quant_methods = ["gptq", "aqlm", "eetq", "torchao", "hqq"] |
| 481 | + is_quantized = getattr(model, "quantization_method", None) in quant_methods or getattr( |
| 482 | + model, "hqq_quantized", False |
| 483 | + ) |
| 484 | + |
| 485 | + if gradient_checkpointing_kwargs is None: |
| 486 | + gradient_checkpointing_kwargs = {} |
| 487 | + |
| 488 | + n_upcasted = 0 |
| 489 | + for name, param in model.named_parameters(): |
| 490 | + # freeze all parameters |
| 491 | + param.requires_grad = False |
| 492 | + |
| 493 | + # upcast LayerNorm / Norm to float32 for numerical stability |
| 494 | +        if (param.dtype in [torch.float16, torch.bfloat16]) and ( |
| 495 | +            "norm" in name.lower()  # covers LayerNorm, RMSNorm, layernorm, etc. |
| 496 | +        ): |
| 497 | + param.data = param.data.to(torch.float32) |
| 498 | + n_upcasted += 1 |
| 499 | + |
| 500 | + # Enable gradient checkpointing if needed |
| 501 | + if (loaded_in_kbit or is_quantized) and use_gradient_checkpointing: |
| 502 | + if hasattr(model, "enable_input_require_grads"): |
| 503 | + model.enable_input_require_grads() |
| 504 | + else: |
| 505 | + # backward-compatible hook |
| 506 | + def make_inputs_require_grad(module, input, output): |
| 507 | + output.requires_grad_(True) |
| 508 | + |
| 509 | + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) |
| 510 | + |
| 511 | + supports_gc_kwargs = "gradient_checkpointing_kwargs" in list( |
| 512 | + inspect.signature(model.gradient_checkpointing_enable).parameters |
| 513 | + ) |
| 514 | + gc_kwargs = {"gradient_checkpointing_kwargs": gradient_checkpointing_kwargs} if supports_gc_kwargs else {} |
| 515 | + model.gradient_checkpointing_enable(**gc_kwargs) |
| 516 | + |
| 517 | + return model |
| 518 | + |
| 519 | + |
474 | 520 | def enable_gradient_checkpointing( |
475 | 521 | model: PreTrainedModel, gradient_checkpointing_kwargs: Optional[dict] |
476 | 522 | ) -> PreTrainedModel: |
|
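A minimal usage sketch for the helper added above, assuming a bitsandbytes 4-bit load; the checkpoint name and LoRA settings are illustrative, not taken from this patch:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model

# Load a 4-bit quantized base model (checkpoint chosen for illustration only).
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", quantization_config=bnb_config)

# Freeze weights, upcast norm layers to fp32, and enable gradient checkpointing.
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)

# Attach a LoRA adapter; only the adapter weights remain trainable.
lora_config = LoraConfig(r=16, lora_alpha=32, target_modules=["q_proj", "v_proj"], task_type="CAUSAL_LM")
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()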