diff --git a/src/peft/tuners/lora.py b/src/peft/tuners/lora.py
index 22d1a1e5c8..0f65cbf55c 100644
--- a/src/peft/tuners/lora.py
+++ b/src/peft/tuners/lora.py
@@ -200,6 +200,11 @@ def _replace_module(self, parent_module, child_name, new_module, old_module):
             new_module.state = old_module.state
             new_module.to(old_module.weight.device)
 
+        # dispatch to correct device
+        for name, module in new_module.named_modules():
+            if "lora_" in name:
+                module.to(old_module.weight.device)
+
     def __getattr__(self, name: str):
         """Forward missing attributes to the wrapped module."""
         try:
@@ -345,6 +350,7 @@ def forward(self, x: torch.Tensor):
                     transpose(self.lora_B.weight @ self.lora_A.weight, self.fan_in_fan_out) * self.scaling
                 )
                 self.merged = False
+            return F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
         elif self.r > 0 and not self.merged:
             result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
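
For context (not part of the patch): a minimal, self-contained sketch of what the dispatch loop added to `_replace_module` does. After a base layer is wrapped, every submodule whose name contains `"lora_"` is moved to the device of the original layer's weight, so newly created adapter parameters end up alongside the base weight instead of staying on the default device. The `ToyLoraLinear` class below is a hypothetical stand-in used only to demonstrate the `named_modules()` / `.to()` pattern; it is not the PEFT layer itself.

```python
import torch
import torch.nn as nn


class ToyLoraLinear(nn.Module):
    # Hypothetical stand-in for a LoRA-wrapped linear layer: a base weight
    # plus two low-rank adapter projections named lora_A / lora_B.
    def __init__(self, in_features, out_features, r=8):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        self.lora_A = nn.Linear(in_features, r, bias=False)
        self.lora_B = nn.Linear(r, out_features, bias=False)


old_module = nn.Linear(16, 16)      # pretend this already lives on the target device
new_module = ToyLoraLinear(16, 16)  # freshly constructed, so on the default device

# Same loop shape as the patch: move only the LoRA submodules to the device
# of the wrapped module's weight.
for name, module in new_module.named_modules():
    if "lora_" in name:
        module.to(old_module.weight.device)

# Inspect where each parameter ended up.
print({name: param.device for name, param in new_module.named_parameters()})
```

Filtering on the `"lora_"` name prefix keeps the dispatch narrow: only the adapter submodules are moved, while the base weight (which may already have been placed by the model's own loading/dispatch logic) is left untouched.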