@@ -1524,3 +1524,112 @@ def test_causal_lm_training_multi_gpu(self):
1524
1524
1525
1525
# assert loss is not None
1526
1526
assert trainer .state .log_history [- 1 ]["train_loss" ] is not None
1527
+
1528
+
1529
# Autocast dtypes exercised by the TestAutoCast parameterized tests.
PRECISIONS = [torch.float32, torch.float16, torch.bfloat16]

# Shared LoRA hyperparameters for every LoRA test configuration below.
LORA_PARAMS = dict(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
)
1536
+
1537
+
1538
class SimpleModel(torch.nn.Module):
    """Minimal embedding -> layer-norm -> linear model used as a LoRA test target.

    Maps integer token ids (any integer-tensor shape ``(..., seq)``) to float
    features of width 256.
    """

    def __init__(self):
        super().__init__()
        # attribute names are referenced by LoraConfig target_modules elsewhere
        self.embedding_layer = torch.nn.Embedding(1000, 768)
        self.layer_norm = torch.nn.LayerNorm(768)
        self.linear_transform = torch.nn.Linear(768, 256)

    def forward(self, input_ids):
        # embed token ids, normalize, then project to the output width
        hidden = self.embedding_layer(input_ids)
        hidden = self.layer_norm(hidden)
        return self.linear_transform(hidden)
1552
+
1553
+
1554
class SimpleConv2DModel(torch.nn.Module):
    """Embedding -> layer-norm -> Conv2d model used as a LoRA test target.

    Token ids of shape ``(batch, seq)`` come out as a float tensor of shape
    ``(batch, 256, seq, 768)``: the 768-dim embeddings are treated as a
    single-channel 2D image and convolved up to 256 channels.
    """

    def __init__(self):
        super().__init__()
        # attribute names are referenced by LoraConfig target_modules elsewhere
        self.embedding_layer = torch.nn.Embedding(1000, 768)
        self.layer_norm = torch.nn.LayerNorm(768)
        self.conv2d_transform = torch.nn.Conv2d(1, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

    def forward(self, input_ids):
        hidden = self.embedding_layer(input_ids)
        hidden = self.layer_norm(hidden)

        # Insert a singleton channel dimension so Conv2d sees (N, 1, H, W).
        hidden = hidden.unsqueeze(1)
        conv_out = self.conv2d_transform(hidden)

        # NOTE(review): after the conv, dim 1 holds 256 channels, so this
        # squeeze(1) is a no-op; kept to preserve the original behavior.
        return conv_out.squeeze(1)
1575
+
1576
+
1577
@require_torch_gpu
class TestAutoCast(unittest.TestCase):
    """Verify that (LoRA-wrapped) model outputs carry the dtype that
    torch.autocast infers for each precision in PRECISIONS."""

    @parameterized.expand(PRECISIONS)
    def test_simple_model(self, *args, **kwargs):
        self._test_model(SimpleModel(), *args, **kwargs)

    @parameterized.expand(PRECISIONS)
    def test_simple_lora_linear_model(self, *args, **kwargs):
        base_model = SimpleModel()
        lora_config = LoraConfig(**LORA_PARAMS, target_modules=["linear_transform"])
        self._test_model(get_peft_model(base_model, lora_config), *args, **kwargs)

    @parameterized.expand(PRECISIONS)
    def test_simple_lora_embedding_model(self, *args, **kwargs):
        base_model = SimpleModel()
        lora_config = LoraConfig(**LORA_PARAMS, target_modules=["embedding_layer"])
        self._test_model(get_peft_model(base_model, lora_config), *args, **kwargs)

    @parameterized.expand(PRECISIONS)
    def test_simple_conv2d_model(self, *args, **kwargs):
        self._test_model(SimpleConv2DModel(), *args, **kwargs)

    @parameterized.expand(PRECISIONS)
    def test_simple_lora_conv2d_model(self, *args, **kwargs):
        base_model = SimpleConv2DModel()
        lora_config = LoraConfig(**LORA_PARAMS, target_modules=["conv2d_transform"])
        self._test_model(get_peft_model(base_model, lora_config), *args, **kwargs)

    def _test_model(self, model, precision):
        """Run one forward pass under torch.autocast at `precision` and assert
        the output dtype matches."""
        model = model.cuda()
        token_ids = torch.randint(0, 1000, (2, 10)).cuda()

        # bfloat16 autocast needs hardware support; skip rather than fail.
        if precision == torch.bfloat16 and not torch.cuda.is_bf16_supported():
            self.skipTest("Bfloat16 not supported on this device")

        with torch.autocast(enabled=True, dtype=precision, device_type="cuda"):
            outputs = model(token_ids)
        assert outputs.dtype == precision
0 commit comments