Commit 8c809e8

Martin Fajcik committed: review changes
1 parent: 972da79

File tree: 2 files changed (+110, -160 lines)


tests/test_autocast_torchcompatibility_lora.py (-159)

This file was deleted.

tests/test_gpu_examples.py (+110, -1)
@@ -22,6 +22,7 @@
 import pytest
 import torch
 from datasets import Audio, DatasetDict, load_dataset
+from parameterized import parameterized
 from transformers import (
     AutoModelForCausalLM,
     AutoModelForSeq2SeqLM,
@@ -697,7 +698,8 @@ def make_inputs_require_grad(module, input, output):
             per_device_eval_batch_size=8,
             generation_max_length=128,
             logging_steps=25,
-            remove_unused_columns=False,  # required as the PeftModel forward doesn't have the signature of the wrapped model's forward
+            remove_unused_columns=False,
+            # required as the PeftModel forward doesn't have the signature of the wrapped model's forward
             label_names=["labels"],  # same reason as above
         )

@@ -933,3 +935,110 @@ def test_causal_lm_training_mutli_gpu(self):

         # assert loss is not None
         self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+
+PRECISIONS = [(torch.float32), (torch.float16), (torch.bfloat16)]
+
+LORA_PARAMS = {
+    "r": 8,
+    "lora_alpha": 16,
+    "lora_dropout": 0.05,
+}
+
+
+class SimpleModel(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+        self.embedding_layer = torch.nn.Embedding(1000, 768)
+        self.layer_norm = torch.nn.LayerNorm(768)
+        self.linear_transform = torch.nn.Linear(768, 256)
+
+    def forward(self, input_ids):
+        embedded_output = self.embedding_layer(input_ids)
+        norm_output = self.layer_norm(embedded_output)
+        linear_output = self.linear_transform(norm_output)
+
+        return linear_output
+
+
+class SimpleConv2DModel(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+        self.embedding_layer = torch.nn.Embedding(1000, 768)
+        self.layer_norm = torch.nn.LayerNorm(768)
+        self.conv2d_transform = torch.nn.Conv2d(1, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+
+    def forward(self, input_ids):
+        # Embed and normalize the input ids
+        embedded_output = self.embedding_layer(input_ids)
+        norm_output = self.layer_norm(embedded_output)
+
+        # Reshape for Conv2d input (add a channel dimension)
+        norm_output = norm_output.unsqueeze(1)
+        conv_output = self.conv2d_transform(norm_output)
+
+        # Squeeze the channel dimension (a no-op here, since out_channels is 256)
+        conv_output = conv_output.squeeze(1)
+
+        return conv_output
+
+
+@require_torch_gpu
+class TestAutoCast(unittest.TestCase):
+    @parameterized.expand(PRECISIONS)
+    def test_simple_model(self, *args, **kwargs):
+        self._test_model(SimpleModel(), *args, **kwargs)
+
+    @parameterized.expand(PRECISIONS)
+    def test_simple_lora_linear_model(self, *args, **kwargs):
+        simple_model = SimpleModel()
+        config = LoraConfig(
+            **LORA_PARAMS,
+            target_modules=["linear_transform"],
+        )
+
+        lora_model = get_peft_model(simple_model, config)
+
+        self._test_model(lora_model, *args, **kwargs)
+
+    @parameterized.expand(PRECISIONS)
+    def test_simple_lora_embedding_model(self, *args, **kwargs):
+        simple_model = SimpleModel()
+        config = LoraConfig(
+            **LORA_PARAMS,
+            target_modules=["embedding_layer"],
+        )
+        lora_model = get_peft_model(simple_model, config)
+
+        self._test_model(lora_model, *args, **kwargs)
+
+    @parameterized.expand(PRECISIONS)
+    def test_simple_conv2d_model(self, *args, **kwargs):
+        self._test_model(SimpleConv2DModel(), *args, **kwargs)
+
+    @parameterized.expand(PRECISIONS)
+    def test_simple_lora_conv2d_model(self, *args, **kwargs):
+        simple_model = SimpleConv2DModel()
+        config = LoraConfig(
+            **LORA_PARAMS,
+            target_modules=["conv2d_transform"],
+        )
+        lora_model = get_peft_model(simple_model, config)
+        self._test_model(lora_model, *args, **kwargs)
+
+    def _test_model(self, model, precision):
+        # Move model to GPU
+        model = model.cuda()
+
+        # Prepare dummy inputs
+        input_ids = torch.randint(0, 1000, (2, 10)).cuda()
+        if precision == torch.bfloat16 and not torch.cuda.is_bf16_supported():
+            self.skipTest("bfloat16 is not supported on this GPU")
+
+        # Forward pass with test precision
+        with torch.autocast(enabled=True, dtype=precision, device_type="cuda"):
+            outputs = model(input_ids)
+            self.assertEqual(outputs.dtype, precision)
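The test harness above leans on parameterized.expand, which generates one unittest method per entry of the parameter list, so every TestAutoCast test runs once per precision in PRECISIONS. A minimal standalone sketch of that mechanism (the ExpandDemo class and test_dtype name are illustrative, not part of this diff):

import unittest

import torch
from parameterized import parameterized


class ExpandDemo(unittest.TestCase):
    # parameterized.expand generates one test method per tuple,
    # named roughly test_dtype_0, test_dtype_1, ...
    @parameterized.expand([(torch.float16,), (torch.bfloat16,)])
    def test_dtype(self, dtype):
        self.assertIsInstance(dtype, torch.dtype)


if __name__ == "__main__":
    unittest.main()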

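The dtype assertion in _test_model rests on torch.autocast casting the outputs of autocast-eligible ops such as torch.nn.Linear and torch.nn.Conv2d to the requested dtype. A minimal sketch of that behavior, assuming a CUDA device with float16 support:

import torch

linear = torch.nn.Linear(768, 256).cuda()
x = torch.randn(2, 768, device="cuda")

# Inside the autocast region, the fp32-parameterized Linear emits fp16 outputs
with torch.autocast(device_type="cuda", dtype=torch.float16):
    y = linear(x)

assert y.dtype == torch.float16

On a GPU machine, something like pytest tests/test_gpu_examples.py -k TestAutoCast should select only the new test class.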