Refactor DoRA to make it work with FSDP #1797

Closed
docs/source/developer_guides/checkpoint.md: 4 changes (2 additions, 2 deletions)

@@ -105,7 +105,7 @@ class LoraLayer(BaseTunerLayer):
self._disable_adapters = False
self.merged_adapters = []
self.use_dora: dict[str, bool] = {}
self.lora_magnitude_vector: Optional[torch.nn.ParameterDict] = None # for DoRA
self.lora_magnitude_vector = torch.nn.ParameterDict()  # for DoRA
self._caches: dict[str, Any] = {}
self.kwargs = kwargs
```
@@ -148,7 +148,7 @@ If you call `save_pretrained("some/path")` and the adapter name is not `"default
In some circumstances, deciding which values to add to the checkpoint file can become a bit more complicated. For example, in PEFT, DoRA is implemented as a special case of LoRA. If you want to convert a DoRA model to PEFT, you should create a LoRA checkpoint with extra entries for DoRA. You can see this in the `__init__` of the previous `LoraLayer` code:

```python
self.lora_magnitude_vector: Optional[torch.nn.ParameterDict] = None # for DoRA
self.lora_magnitude_vector = torch.nn.ParameterDict()  # for DoRA
```

This indicates that there is an optional extra parameter per layer for DoRA.
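For illustration, here is a small, hypothetical sketch (not part of the PR or the docs above) of how those extra DoRA entries can be spotted in a saved adapter checkpoint. The `adapter_model.safetensors` filename and the key filtering are assumptions about a typical PEFT checkpoint layout:

```python
# Hypothetical sketch: list the DoRA-specific entries in a saved PEFT adapter.
# Assumes a DoRA model was already saved via `save_pretrained("some/path")`.
from safetensors.torch import load_file

state_dict = load_file("some/path/adapter_model.safetensors")
dora_keys = [key for key in state_dict if "lora_magnitude_vector" in key]
print(dora_keys)  # expect one entry per DoRA layer, alongside the usual lora_A/lora_B keys
```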
src/peft/tuners/lora/layer.py: 44 changes (28 additions, 16 deletions)

@@ -51,7 +51,7 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None:
self._disable_adapters = False
self.merged_adapters = []
self.use_dora: dict[str, bool] = {}
self.lora_magnitude_vector: Optional[torch.nn.ParameterDict] = None # for DoRA
self.lora_magnitude_vector = torch.nn.ParameterDict() # for DoRA
self._caches: dict[str, Any] = {}
self.kwargs = kwargs

@@ -215,33 +215,38 @@ def _get_weight_norm(self, weight, lora_weight, scaling) -> torch.Tensor:
return weight_norm

def dora_init(self, adapter_name: str) -> None:
lora_A = self.lora_A[adapter_name].weight
lora_B = self.lora_B[adapter_name].weight
lora_A = self.lora_A[adapter_name]
lora_B = self.lora_B[adapter_name]

# temporarily convert fp16 to fp32, as fp16 can cause trouble on CPU with PyTorch < 2.2
dtype_is_fp16 = lora_A.dtype == torch.float16
dtype_is_fp16 = lora_A.weight.dtype == torch.float16
if dtype_is_fp16:
lora_A = lora_A.float()
lora_B = lora_B.float()
lora_A.weight.data = lora_A.weight.data.float()
lora_B.weight.data = lora_B.weight.data.float()

scaling = self.scaling[adapter_name]
with gather_params_ctx(self.get_base_layer().parameters()):
base_layer = self.get_base_layer()
weight = dequantize_module_weight(base_layer)
if weight.data.ndim == 4: # For handling LoRAs applied to Conv2Ds.
lora_weight = torch.mm(lora_B.flatten(start_dim=1), lora_A.flatten(start_dim=1))
lora_weight = torch.mm(lora_B.weight.flatten(start_dim=1), lora_A.weight.flatten(start_dim=1))
lora_weight = lora_weight.reshape(weight.shape)
else:
lora_weight = lora_B @ lora_A
# Don't use `lora_weight = lora_B.weight @ lora_A.weight` because this causes errors with FSDP. Instead,
# calculate the same but using forward.
x_eye = torch.eye(lora_A.weight.shape[1], device=lora_A.weight.device)
lora_weight = lora_B(lora_A(x_eye)).T

if dtype_is_fp16:
lora_weight = lora_weight.half()
weight_norm = self._get_weight_norm(weight, lora_weight, scaling)

if self.lora_magnitude_vector is None:
self.lora_magnitude_vector = nn.ParameterDict()
if not self.lora_magnitude_vector:
# first dora layer being added
# add lora_magnitude_vector to the list of learnable parameters
self.adapter_layer_names = self.adapter_layer_names[:] + ("lora_magnitude_vector",)

self.lora_magnitude_vector[adapter_name] = nn.Parameter(weight_norm, requires_grad=True)
# add lora_magnitude_vector to the list of learnable parameters
self.adapter_layer_names = self.adapter_layer_names[:] + ("lora_magnitude_vector",)

def _cache_store(self, key: str, value: Any) -> None:
self._caches[key] = value
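To make the `x_eye` workaround above concrete: for bias-free linear layers, pushing an identity matrix through `lora_A` and then `lora_B` and transposing reproduces `lora_B.weight @ lora_A.weight`, but it does so through `forward`, which is what FSDP intercepts. A small self-contained sketch (toy dimensions, independent of PEFT) under that assumption:

```python
import torch
import torch.nn as nn

# Toy dimensions; in PEFT these correspond to (in_features, r) and (r, out_features).
in_features, r, out_features = 8, 4, 16
lora_A = nn.Linear(in_features, r, bias=False)
lora_B = nn.Linear(r, out_features, bias=False)

# Direct weight product, shape (out_features, in_features) -- what DoRA needs.
direct = lora_B.weight @ lora_A.weight

# Same quantity computed through forward calls on an identity matrix, as in the PR:
# I @ A.T @ B.T == (B @ A).T, so transposing recovers B @ A.
x_eye = torch.eye(in_features, device=lora_A.weight.device)
via_forward = lora_B(lora_A(x_eye)).T

assert torch.allclose(direct, via_forward, atol=1e-6)
```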
@@ -255,23 +260,29 @@ def _apply_dora(self, x, lora_A, lora_B, scaling, active_adapter):
For DoRA, calculate the extra output from LoRA with DoRA applied. This should be added on top of the base layer
output.
"""
lora_weight = lora_B.weight @ lora_A.weight
lora_result = lora_B(lora_A(x))

# Don't use `lora_weight = lora_B.weight @ lora_A.weight` because this causes errors with FSDP. Instead,
# calculate the same but using forward.
x_eye = torch.eye(lora_A.weight.shape[1], device=lora_A.weight.device)
lora_weight = lora_B(lora_A(x_eye)).T

magnitude = self.lora_magnitude_vector[active_adapter]
base_layer = self.get_base_layer()
weight = dequantize_module_weight(base_layer)
weight = weight.to(x.dtype)
weight_norm = self._get_weight_norm(weight, lora_weight, scaling)
# see section 4.3 of DoRA (https://arxiv.org/abs/2402.09353)
# "[...] we suggest treating ||V +∆V ||_c in
# Eq. (5) as a constant, thereby detaching it from the gradient
# graph. This means that while ||V + ∆V ||_c dynamically
# reflects the updates of ∆V , it won’t receive any gradient
# during backpropagation"
weight_norm = self._get_weight_norm(weight, lora_weight.detach(), scaling)
weight_norm = weight_norm.detach()
mag_norm_scale = (magnitude / weight_norm).view(1, -1)
result_dora = (mag_norm_scale - 1) * (
F.linear(x, transpose(weight, self.fan_in_fan_out))
) + mag_norm_scale * lora_B(lora_A(x)) * scaling
) + mag_norm_scale * lora_result * scaling

# Note: Computation could potentially be accelerated by using the code below instead of calculating X@W again.
# This is only correct if dropout=0, otherwise results will differ:
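The detach quoted from section 4.3 is the central design choice here: the norm of `W + scaling * B @ A` tracks the LoRA updates but receives no gradient, while the learnable magnitude rescales each output channel. Below is a standalone sketch of the same decomposition under simplifying assumptions (plain `nn.Linear` base layer, no dropout, `fan_in_fan_out=False`, and `_get_weight_norm` stood in by a per-output-channel L2 norm); it is illustrative only, not the PEFT API:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
base = nn.Linear(8, 16, bias=False)          # stands in for the (dequantized) base layer
lora_A = nn.Linear(8, 4, bias=False)
lora_B = nn.Linear(4, 16, bias=False)
nn.init.zeros_(lora_B.weight)                # usual LoRA init: B starts at zero
scaling = 1.0

x = torch.randn(2, 8)
lora_result = lora_B(lora_A(x))

# lora_weight via forward, as in the PR, instead of lora_B.weight @ lora_A.weight
x_eye = torch.eye(8)
lora_weight = lora_B(lora_A(x_eye)).T

# simplified stand-in for _get_weight_norm: per-output-channel L2 norm of W + scaling * BA
weight_norm = (base.weight + scaling * lora_weight.detach()).norm(p=2, dim=1)
weight_norm = weight_norm.detach()             # treated as a constant, per DoRA sec. 4.3
magnitude = nn.Parameter(weight_norm.clone())  # stands in for lora_magnitude_vector

mag_norm_scale = (magnitude / weight_norm).view(1, -1)
result_dora = (
    (mag_norm_scale - 1) * F.linear(x, base.weight)
    + mag_norm_scale * lora_result * scaling
)

# At initialization (B == 0, magnitude == weight_norm) the DoRA contribution is zero,
# so adding result_dora on top of the base output leaves the model unchanged.
assert torch.allclose(result_dora, torch.zeros_like(result_dora), atol=1e-6)
```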
@@ -1005,16 +1016,17 @@ def _apply_dora(self, x, lora_A, lora_B, scaling, active_adapter):
"""
base_layer = self.get_base_layer()
weight = base_layer.weight
# TODO: will probably not work with FSDP as forward is bypassed
lora_weight = torch.mm(lora_B.weight.flatten(start_dim=1), lora_A.weight.flatten(start_dim=1))
lora_weight = lora_weight.reshape(weight.shape)
magnitude = self.lora_magnitude_vector[active_adapter]
weight_norm = self._get_weight_norm(weight, lora_weight, scaling)
# see section 4.3 of DoRA (https://arxiv.org/abs/2402.09353)
# "[...] we suggest treating ||V +∆V ||_c in
# Eq. (5) as a constant, thereby detaching it from the gradient
# graph. This means that while ||V + ∆V ||_c dynamically
# reflects the updates of ∆V , it won’t receive any gradient
# during backpropagation"
weight_norm = self._get_weight_norm(weight, lora_weight.detach(), scaling)
weight_norm = weight_norm.detach()
mag_norm_scale = magnitude / weight_norm
result_dora = (mag_norm_scale - 1) * (
src/peft/utils/integrations.py: 3 changes (3 additions, 0 deletions)

@@ -52,6 +52,9 @@ def dequantize_module_weight(module: torch.nn.Module) -> torch.nn.Parameter:

weight = module.weight
if not isinstance(weight, torch.nn.Parameter):
if isinstance(weight, torch.Tensor):
# this is an FSDP-specific edge case
return weight
raise TypeError(f"Input weight should be of type nn.Parameter, got {type(weight)} instead")

cls_name = weight.__class__.__name__
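The three added lines handle an FSDP-specific case in which `module.weight` shows up as a plain `torch.Tensor` (e.g. a view of FSDP's flat parameter) rather than an `nn.Parameter`. As a rough, illustrative stand-in for just this type dispatch (omitting the bitsandbytes handling of the real `dequantize_module_weight`):

```python
import torch


def plain_weight_or_raise(module: torch.nn.Module) -> torch.Tensor:
    # Illustrative only; the real dequantize_module_weight also dequantizes
    # bitsandbytes 4-bit/8-bit weights before returning them.
    weight = module.weight
    if isinstance(weight, torch.nn.Parameter):
        return weight
    if isinstance(weight, torch.Tensor):
        # FSDP edge case: the gathered weight may be a plain tensor view.
        return weight
    raise TypeError(f"Input weight should be of type nn.Parameter, got {type(weight)} instead")
```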
tests/test_common_gpu.py: 4 changes (4 additions, 0 deletions)

@@ -1010,6 +1010,7 @@ def test_4bit_dora_merging(self):
# measure any differences, we need to change the magnitude vector.
for name, module in model.named_modules():
if isinstance(module, LoraLinear4bit):
module.dora_init(model.active_adapter) # dora is initialized lazily
module.lora_magnitude_vector["default"] = torch.nn.Parameter(
10 * torch.rand_like(module.lora_magnitude_vector["default"])
)
@@ -1062,6 +1063,7 @@ def test_8bit_dora_merging(self):
# measure any differences, we need to change the magnitude vector.
for name, module in model.named_modules():
if isinstance(module, LoraLinear8bitLt):
module.dora_init(model.active_adapter) # dora is initialized lazily
module.lora_magnitude_vector["default"] = torch.nn.Parameter(
10 * torch.rand_like(module.lora_magnitude_vector["default"])
)
@@ -1223,6 +1225,7 @@ def test_lora_dora_add_new_adapter_does_not_change_device(self, mlp):
# same as first test, but also using DoRA
config = LoraConfig(target_modules=["lin0"], use_dora=True)
model = get_peft_model(mlp, config)
model.lin0.dora_init("default") # dora is initialized lazily
model = model.cuda()
model.lin0.lora_A.cpu()
model.lin0.lora_B.cpu()
@@ -1235,6 +1238,7 @@ def test_lora_dora_add_new_adapter_does_not_change_device(self, mlp):
assert model.lin0.base_layer.weight.device.type == "cuda"

model.add_adapter("other", config)
model.lin0.dora_init("other") # dora is initialized lazily
# check that after adding a new adapter, the old adapter is still on CPU
assert model.lin0.lora_A.default.weight.device.type == "cpu"
assert model.lin0.lora_B.default.weight.device.type == "cpu"
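Since DoRA is now initialized lazily, code that inspects or overrides `lora_magnitude_vector` outside of a forward pass has to trigger `dora_init` first, which is what the added test lines do. A small, hypothetical usage sketch along the same lines (the MLP and its dimensions are made up for illustration):

```python
from torch import nn

from peft import LoraConfig, get_peft_model
from peft.tuners.lora import LoraLayer


class MLP(nn.Module):
    # Tiny stand-in model, similar in spirit to the `mlp` fixture used by the tests above.
    def __init__(self):
        super().__init__()
        self.lin0 = nn.Linear(16, 16)

    def forward(self, x):
        return self.lin0(x)


config = LoraConfig(target_modules=["lin0"], use_dora=True)
model = get_peft_model(MLP(), config)

for module in model.modules():
    if isinstance(module, LoraLayer):
        module.dora_init(model.active_adapter)  # DoRA is initialized lazily
        print(module.lora_magnitude_vector["default"].shape)  # one magnitude entry per output feature
```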