
Commit cbafd66

MFajcik authored and BenjaminBossan committed
Better respect result dtype in LoRA layers (huggingface#1010)
1 parent 33be8ea commit cbafd66

File tree (2 files changed: +116 -6 lines)

src/peft/tuners/lora/layer.py
tests/test_gpu_examples.py


src/peft/tuners/lora/layer.py (+7 -6)

@@ -297,8 +297,6 @@ def get_delta_weight(self, adapter) -> torch.Tensor:
         return output_tensor
 
     def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
-        previous_dtype = x.dtype
-
         if self.disable_adapters:
             if self.merged:
                 self.unmerge()
@@ -307,6 +305,7 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
             result = self.base_layer(x, *args, **kwargs)
         else:
             result = self.base_layer(x, *args, **kwargs)
+            torch_result_dtype = result.dtype
             for active_adapter in self.active_adapters:
                 if active_adapter not in self.lora_A.keys():
                     continue
@@ -317,7 +316,7 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
                 x = x.to(lora_A.weight.dtype)
                 result += lora_B(lora_A(dropout(x))) * scaling
 
-        result = result.to(previous_dtype)
+            result = result.to(torch_result_dtype)
         return result
 
     def __repr__(self) -> str:
@@ -483,6 +482,7 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
             result = self.base_layer(x, *args, **kwargs)
         else:
             result = self.base_layer(x, *args, **kwargs)
+            torch_result_dtype = result.dtype
             for active_adapter in self.active_adapters:
                 if active_adapter not in self.lora_embedding_A:
                     continue
@@ -491,6 +491,7 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
                 scaling = self.scaling[active_adapter]
                 after_A = self._embed(x, embedding_A)
                 result += (after_A @ embedding_B) * scaling
+            result = result.to(torch_result_dtype)
 
         return result
 
@@ -650,8 +651,6 @@ def get_delta_weight(self, adapter) -> torch.Tensor:
         return output_tensor
 
     def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
-        previous_dtype = x.dtype
-
         if self.disable_adapters:
             if self.merged:
                 self.unmerge()
@@ -660,6 +659,8 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
             result = self.base_layer(x, *args, **kwargs)
         else:
             result = self.base_layer(x, *args, **kwargs)
+            torch_result_dtype = result.dtype
+
            for active_adapter in self.active_adapters:
                 if active_adapter not in self.lora_A.keys():
                     continue
@@ -670,7 +671,7 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
                 x = x.to(lora_A.weight.dtype)
                 result += lora_B(lora_A(dropout(x))) * scaling
 
-        result = result.to(previous_dtype)
+            result = result.to(torch_result_dtype)
         return result
 
     def __repr__(self) -> str:
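
The motivation for the change above: under torch.autocast, an autocast-eligible base layer such as nn.Linear returns a lower-precision output even when the input is float32, so casting the LoRA result back to the input dtype (previous_dtype) silently upcast it and broke dtype consistency with the base layer. Capturing result.dtype instead keeps the LoRA output in whatever dtype autocast chose. A minimal, standalone sketch of that behaviour (not part of the commit; it assumes a CUDA device):

import torch

# Under autocast, the output dtype of an eligible op follows the autocast
# dtype, not the input dtype.
linear = torch.nn.Linear(8, 8).cuda()
x = torch.randn(2, 8, device="cuda")  # float32 input

with torch.autocast(device_type="cuda", dtype=torch.float16):
    y = linear(x)
    print(x.dtype, y.dtype)  # torch.float32 torch.float16

    old_result = y.to(x.dtype)  # what the old code did: back to float32
    new_result = y.to(y.dtype)  # what the new code does: stays float16
    print(old_result.dtype, new_result.dtype)  # torch.float32 torch.float16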

tests/test_gpu_examples.py (+109 -0)

@@ -1524,3 +1524,112 @@ def test_causal_lm_training_multi_gpu(self):
 
         # assert loss is not None
         assert trainer.state.log_history[-1]["train_loss"] is not None
+
+
+PRECISIONS = [(torch.float32), (torch.float16), (torch.bfloat16)]
+
+LORA_PARAMS = {
+    "r": 8,
+    "lora_alpha": 16,
+    "lora_dropout": 0.05,
+}
+
+
+class SimpleModel(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+        self.embedding_layer = torch.nn.Embedding(1000, 768)
+        self.layer_norm = torch.nn.LayerNorm(768)
+        self.linear_transform = torch.nn.Linear(768, 256)
+
+    def forward(self, input_ids):
+        embedded_output = self.embedding_layer(input_ids)
+        norm_output = self.layer_norm(embedded_output)
+        linear_output = self.linear_transform(norm_output)
+
+        return linear_output
+
+
+class SimpleConv2DModel(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+        self.embedding_layer = torch.nn.Embedding(1000, 768)
+        self.layer_norm = torch.nn.LayerNorm(768)
+        self.conv2d_transform = torch.nn.Conv2d(1, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+
+    def forward(self, input_ids):
+        # Additional layers for your custom model
+        embedded_output = self.embedding_layer(input_ids)
+        norm_output = self.layer_norm(embedded_output)
+
+        # Reshape for Conv2d input (add a channel dimension)
+        norm_output = norm_output.unsqueeze(1)
+        conv_output = self.conv2d_transform(norm_output)
+
+        # Squeeze out a size-1 channel dimension (a no-op here, since out_channels=256)
+        conv_output = conv_output.squeeze(1)
+
+        return conv_output
+
+
+@require_torch_gpu
+class TestAutoCast(unittest.TestCase):
+    # This test makes sure that LoRA dtypes are consistent with the dtypes
+    # inferred by torch.autocast under the tested PRECISIONS
+    @parameterized.expand(PRECISIONS)
+    def test_simple_model(self, *args, **kwargs):
+        self._test_model(SimpleModel(), *args, **kwargs)
+
+    @parameterized.expand(PRECISIONS)
+    def test_simple_lora_linear_model(self, *args, **kwargs):
+        simple_model = SimpleModel()
+        config = LoraConfig(
+            **LORA_PARAMS,
+            target_modules=["linear_transform"],
+        )
+
+        lora_model = get_peft_model(simple_model, config)
+
+        self._test_model(lora_model, *args, **kwargs)
+
+    @parameterized.expand(PRECISIONS)
+    def test_simple_lora_embedding_model(self, *args, **kwargs):
+        simple_model = SimpleModel()
+        config = LoraConfig(
+            **LORA_PARAMS,
+            target_modules=["embedding_layer"],
+        )
+        lora_model = get_peft_model(simple_model, config)
+
+        self._test_model(lora_model, *args, **kwargs)
+
+    @parameterized.expand(PRECISIONS)
+    def test_simple_conv2d_model(self, *args, **kwargs):
+        self._test_model(SimpleConv2DModel(), *args, **kwargs)
+
+    @parameterized.expand(PRECISIONS)
+    def test_simple_lora_conv2d_model(self, *args, **kwargs):
+        simple_model = SimpleConv2DModel()
+        config = LoraConfig(
+            **LORA_PARAMS,
+            target_modules=["conv2d_transform"],
+        )
+        lora_model = get_peft_model(simple_model, config)
+        self._test_model(lora_model, *args, **kwargs)
+
+    def _test_model(self, model, precision):
+        # Move model to GPU
+        model = model.cuda()
+
+        # Prepare dummy inputs
+        input_ids = torch.randint(0, 1000, (2, 10)).cuda()
+        if precision == torch.bfloat16:
+            if not torch.cuda.is_bf16_supported():
+                self.skipTest("Bfloat16 not supported on this device")
+
+        # Forward pass with test precision
+        with torch.autocast(enabled=True, dtype=precision, device_type="cuda"):
+            outputs = model(input_ids)
+            assert outputs.dtype == precision
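
Condensed, the new test checks that a LoRA-wrapped layer now returns outputs in the dtype selected by torch.autocast. The following standalone sketch (not the committed test; it assumes a CUDA device with fp16 autocast, and the TinyModel name is made up for illustration) shows the same assertion for a single Linear layer:

import torch
from peft import LoraConfig, get_peft_model


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear_transform = torch.nn.Linear(768, 256)

    def forward(self, x):
        return self.linear_transform(x)


# Wrap the linear layer with a LoRA adapter.
config = LoraConfig(r=8, lora_alpha=16, target_modules=["linear_transform"])
model = get_peft_model(TinyModel(), config).cuda()

x = torch.randn(2, 768, device="cuda")  # float32 input
with torch.autocast(device_type="cuda", dtype=torch.float16):
    out = model(x)

# With this commit, the LoRA forward keeps the dtype chosen by autocast
# instead of casting the result back to the float32 input dtype.
assert out.dtype == torch.float16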
