diff --git a/docs/source/developer_guides/checkpoint.md b/docs/source/developer_guides/checkpoint.md
index fa919487f6..77c40a3b39 100644
--- a/docs/source/developer_guides/checkpoint.md
+++ b/docs/source/developer_guides/checkpoint.md
@@ -105,7 +105,7 @@ class LoraLayer(BaseTunerLayer):
         self._disable_adapters = False
         self.merged_adapters = []
         self.use_dora: dict[str, bool] = {}
-        self.lora_magnitude_vector: Optional[torch.nn.ParameterDict] = None  # for DoRA
+        self.lora_magnitude_vector = torch.nn.ParameterDict()  # for DoRA
         self._caches: dict[str, Any] = {}
         self.kwargs = kwargs
 ```
@@ -148,7 +148,7 @@ If you call `save_pretrained("some/path")` and the adapter name is not `"default
 In some circumstances, deciding which values to add to the checkpoint file can become a bit more complicated. For example, in PEFT, DoRA is implemented as a special case of LoRA. If you want to convert a DoRA model to PEFT, you should create a LoRA checkpoint with extra entries for DoRA. You can see this in the `__init__` of the previous `LoraLayer` code:
 
 ```python
-self.lora_magnitude_vector: Optional[torch.nn.ParameterDict] = None  # for DoRA
+self.lora_magnitude_vector = torch.nn.ParameterDict()  # for DoRA
 ```
 
 This indicates that there is an optional extra parameter per layer for DoRA.
diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py
index 62fff4b9b9..26afe57abb 100644
--- a/src/peft/tuners/lora/layer.py
+++ b/src/peft/tuners/lora/layer.py
@@ -51,7 +51,7 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None:
         self._disable_adapters = False
         self.merged_adapters = []
         self.use_dora: dict[str, bool] = {}
-        self.lora_magnitude_vector: Optional[torch.nn.ParameterDict] = None  # for DoRA
+        self.lora_magnitude_vector = torch.nn.ParameterDict()  # for DoRA
         self._caches: dict[str, Any] = {}
         self.kwargs = kwargs
 
@@ -215,33 +215,38 @@ def _get_weight_norm(self, weight, lora_weight, scaling) -> torch.Tensor:
         return weight_norm
 
     def dora_init(self, adapter_name: str) -> None:
-        lora_A = self.lora_A[adapter_name].weight
-        lora_B = self.lora_B[adapter_name].weight
+        lora_A = self.lora_A[adapter_name]
+        lora_B = self.lora_B[adapter_name]
+
         # temporarily convert fp16 to fp32, as fp16 can cause trouble on CPU with PyTorch < 2.2
-        dtype_is_fp16 = lora_A.dtype == torch.float16
+        dtype_is_fp16 = lora_A.weight.dtype == torch.float16
         if dtype_is_fp16:
-            lora_A = lora_A.float()
-            lora_B = lora_B.float()
+            lora_A.weight.data = lora_A.weight.data.float()
+            lora_B.weight.data = lora_B.weight.data.float()
 
         scaling = self.scaling[adapter_name]
         with gather_params_ctx(self.get_base_layer().parameters()):
             base_layer = self.get_base_layer()
             weight = dequantize_module_weight(base_layer)
             if weight.data.ndim == 4:  # For handling LoRAs applied to Conv2Ds.
-                lora_weight = torch.mm(lora_B.flatten(start_dim=1), lora_A.flatten(start_dim=1))
+                lora_weight = torch.mm(lora_B.weight.flatten(start_dim=1), lora_A.weight.flatten(start_dim=1))
                 lora_weight = lora_weight.reshape(weight.shape)
             else:
-                lora_weight = lora_B @ lora_A
+                # Don't use `lora_weight = lora_B.weight @ lora_A.weight` because this causes errors with FSDP. Instead,
+                # calculate the same but using forward.
+                x_eye = torch.eye(lora_A.weight.shape[1], device=lora_A.weight.device)
+                lora_weight = lora_B(lora_A(x_eye)).T
 
             if dtype_is_fp16:
                 lora_weight = lora_weight.half()
             weight_norm = self._get_weight_norm(weight, lora_weight, scaling)
 
-        if self.lora_magnitude_vector is None:
-            self.lora_magnitude_vector = nn.ParameterDict()
+        if not self.lora_magnitude_vector:
+            # first dora layer being added
+            # add lora_magnitude_vector to the list of learnable parameters
+            self.adapter_layer_names = self.adapter_layer_names[:] + ("lora_magnitude_vector",)
+
         self.lora_magnitude_vector[adapter_name] = nn.Parameter(weight_norm, requires_grad=True)
-        # add lora_magnitude_vector to the list of learnable parameters
-        self.adapter_layer_names = self.adapter_layer_names[:] + ("lora_magnitude_vector",)
 
     def _cache_store(self, key: str, value: Any) -> None:
         self._caches[key] = value
@@ -255,23 +260,29 @@ def _apply_dora(self, x, lora_A, lora_B, scaling, active_adapter):
         For DoRA, calculate the extra output from LoRA with DoRA applied. This should be added on top of the base layer
         output.
         """
-        lora_weight = lora_B.weight @ lora_A.weight
+        lora_result = lora_B(lora_A(x))
+
+        # Don't use `lora_weight = lora_B.weight @ lora_A.weight` because this causes errors with FSDP. Instead,
+        # calculate the same but using forward.
+        x_eye = torch.eye(lora_A.weight.shape[1], device=lora_A.weight.device)
+        lora_weight = lora_B(lora_A(x_eye)).T
+
         magnitude = self.lora_magnitude_vector[active_adapter]
         base_layer = self.get_base_layer()
         weight = dequantize_module_weight(base_layer)
         weight = weight.to(x.dtype)
-        weight_norm = self._get_weight_norm(weight, lora_weight, scaling)
         # see section 4.3 of DoRA (https://arxiv.org/abs/2402.09353)
         # "[...] we suggest treating ||V +∆V ||_c in
         # Eq. (5) as a constant, thereby detaching it from the gradient
         # graph. This means that while ||V + ∆V ||_c dynamically
         # reflects the updates of ∆V , it won’t receive any gradient
         # during backpropagation"
+        weight_norm = self._get_weight_norm(weight, lora_weight.detach(), scaling)
         weight_norm = weight_norm.detach()
         mag_norm_scale = (magnitude / weight_norm).view(1, -1)
         result_dora = (mag_norm_scale - 1) * (
             F.linear(x, transpose(weight, self.fan_in_fan_out))
-        ) + mag_norm_scale * lora_B(lora_A(x)) * scaling
+        ) + mag_norm_scale * lora_result * scaling
 
         # Note: Computation could potentially be accelerated by using the code below instead of calculating X@W again.
         # This is only correct if dropout=0, otherwise results will differ:
@@ -1005,16 +1016,17 @@ def _apply_dora(self, x, lora_A, lora_B, scaling, active_adapter):
         """
         base_layer = self.get_base_layer()
         weight = base_layer.weight
+        # TODO: will probably not work with FSDP as forward is bypassed
         lora_weight = torch.mm(lora_B.weight.flatten(start_dim=1), lora_A.weight.flatten(start_dim=1))
         lora_weight = lora_weight.reshape(weight.shape)
         magnitude = self.lora_magnitude_vector[active_adapter]
-        weight_norm = self._get_weight_norm(weight, lora_weight, scaling)
         # see section 4.3 of DoRA (https://arxiv.org/abs/2402.09353)
         # "[...] we suggest treating ||V +∆V ||_c in
         # Eq. (5) as a constant, thereby detaching it from the gradient
         # graph. This means that while ||V + ∆V ||_c dynamically
         # reflects the updates of ∆V , it won’t receive any gradient
         # during backpropagation"
+        weight_norm = self._get_weight_norm(weight, lora_weight.detach(), scaling)
         weight_norm = weight_norm.detach()
         mag_norm_scale = magnitude / weight_norm
         result_dora = (mag_norm_scale - 1) * (
diff --git a/src/peft/utils/integrations.py b/src/peft/utils/integrations.py
index 14b37e32e8..10c76b68e1 100644
--- a/src/peft/utils/integrations.py
+++ b/src/peft/utils/integrations.py
@@ -52,6 +52,9 @@ def dequantize_module_weight(module: torch.nn.Module) -> torch.nn.Parameter:
 
     weight = module.weight
     if not isinstance(weight, torch.nn.Parameter):
+        if isinstance(weight, torch.Tensor):
+            # this is an FSDP-specific edge case
+            return weight
         raise TypeError(f"Input weight should be of type nn.Parameter, got {type(weight)} instead")
 
     cls_name = weight.__class__.__name__
diff --git a/tests/test_common_gpu.py b/tests/test_common_gpu.py
index 43a19bdb6d..7c869d0e47 100644
--- a/tests/test_common_gpu.py
+++ b/tests/test_common_gpu.py
@@ -1010,6 +1010,7 @@ def test_4bit_dora_merging(self):
         # measure any differences, we need to change the magnitude vector.
         for name, module in model.named_modules():
             if isinstance(module, LoraLinear4bit):
+                module.dora_init(model.active_adapter)  # dora is initialized lazily
                 module.lora_magnitude_vector["default"] = torch.nn.Parameter(
                     10 * torch.rand_like(module.lora_magnitude_vector["default"])
                 )
@@ -1062,6 +1063,7 @@ def test_8bit_dora_merging(self):
         # measure any differences, we need to change the magnitude vector.
         for name, module in model.named_modules():
             if isinstance(module, LoraLinear8bitLt):
+                module.dora_init(model.active_adapter)  # dora is initialized lazily
                 module.lora_magnitude_vector["default"] = torch.nn.Parameter(
                     10 * torch.rand_like(module.lora_magnitude_vector["default"])
                 )
@@ -1223,6 +1225,7 @@ def test_lora_dora_add_new_adapter_does_not_change_device(self, mlp):
         # same as first test, but also using DoRA
         config = LoraConfig(target_modules=["lin0"], use_dora=True)
         model = get_peft_model(mlp, config)
+        model.lin0.dora_init("default")  # dora is initialized lazily
         model = model.cuda()
         model.lin0.lora_A.cpu()
         model.lin0.lora_B.cpu()
@@ -1235,6 +1238,7 @@ def test_lora_dora_add_new_adapter_does_not_change_device(self, mlp):
         assert model.lin0.base_layer.weight.device.type == "cuda"
 
         model.add_adapter("other", config)
+        model.lin0.dora_init("other")  # dora is initialized lazily
         # check that after adding a new adapter, the old adapter is still on CPU
         assert model.lin0.lora_A.default.weight.device.type == "cpu"
         assert model.lin0.lora_B.default.weight.device.type == "cpu"
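As a quick sanity check of the identity-matrix trick used in `dora_init` and `_apply_dora` above, here is a minimal standalone sketch (not part of the diff; the layer sizes are made up for illustration). It only demonstrates the plain equivalence: routing an identity matrix through `forward` reproduces `lora_B.weight @ lora_A.weight`. Under FSDP this matters because the `forward` call goes through FSDP's hooks that gather the sharded parameters, whereas reading `.weight` directly can see sharded or flattened tensors.

```python
import torch
import torch.nn as nn

# Hypothetical sizes, chosen only for illustration.
in_features, r, out_features = 16, 4, 32
lora_A = nn.Linear(in_features, r, bias=False)
lora_B = nn.Linear(r, out_features, bias=False)

# Direct weight access, as in the old code; this is the pattern that breaks once FSDP shards the parameters.
direct = lora_B.weight @ lora_A.weight

# Forward-based equivalent from the diff: lora_A(x_eye) == lora_A.weight.T, so
# lora_B(lora_A(x_eye)).T == lora_B.weight @ lora_A.weight.
x_eye = torch.eye(in_features, device=lora_A.weight.device)
via_forward = lora_B(lora_A(x_eye)).T

assert torch.allclose(direct, via_forward, atol=1e-6)
```

The same reasoning applies inside `_apply_dora`, where the activation path `lora_B(lora_A(x))` is computed once and reused as `lora_result`, so the FSDP-safe weight reconstruction does not add a second pass over the input.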