|
22 | 22 | import pytest
|
23 | 23 | import torch
|
24 | 24 | from datasets import Audio, DatasetDict, load_dataset
|
| 25 | +from parameterized import parameterized |
25 | 26 | from transformers import (
|
26 | 27 | AutoModelForCausalLM,
|
27 | 28 | AutoModelForSeq2SeqLM,
|
@@ -697,7 +698,8 @@ def make_inputs_require_grad(module, input, output):
|
697 | 698 | per_device_eval_batch_size=8,
|
698 | 699 | generation_max_length=128,
|
699 | 700 | logging_steps=25,
|
700 |
| - remove_unused_columns=False, # required as the PeftModel forward doesn't have the signature of the wrapped model's forward |
| 701 | + remove_unused_columns=False, |
| 702 | + # required as the PeftModel forward doesn't have the signature of the wrapped model's forward |
701 | 703 | label_names=["labels"], # same reason as above
|
702 | 704 | )
|
703 | 705 |
|
@@ -933,3 +935,110 @@ def test_causal_lm_training_mutli_gpu(self):
|
933 | 935 |
|
934 | 936 | # assert loss is not None
|
935 | 937 | self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
|
| 938 | + |
| 939 | + |
| 940 | +PRECISIONS = [(torch.float32), (torch.float16), (torch.bfloat16)] |
| 941 | + |
| 942 | +LORA_PARAMS = { |
| 943 | + "r": 8, |
| 944 | + "lora_alpha": 16, |
| 945 | + "lora_dropout": 0.05, |
| 946 | +} |
| 947 | + |
| 948 | + |
| 949 | +class SimpleModel(torch.nn.Module): |
| 950 | + def __init__(self): |
| 951 | + super().__init__() |
| 952 | + |
| 953 | + self.embedding_layer = torch.nn.Embedding(1000, 768) |
| 954 | + self.layer_norm = torch.nn.LayerNorm(768) |
| 955 | + self.linear_transform = torch.nn.Linear(768, 256) |
| 956 | + |
| 957 | + def forward(self, input_ids): |
| 958 | + embedded_output = self.embedding_layer(input_ids) |
| 959 | + norm_output = self.layer_norm(embedded_output) |
| 960 | + linear_output = self.linear_transform(norm_output) |
| 961 | + |
| 962 | + return linear_output |
| 963 | + |
| 964 | + |
| 965 | +class SimpleConv2DModel(torch.nn.Module): |
| 966 | + def __init__(self): |
| 967 | + super().__init__() |
| 968 | + |
| 969 | + self.embedding_layer = torch.nn.Embedding(1000, 768) |
| 970 | + self.layer_norm = torch.nn.LayerNorm(768) |
| 971 | + self.conv2d_transform = torch.nn.Conv2d(1, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) |
| 972 | + |
| 973 | + def forward(self, input_ids): |
| 974 | + # Additional layers for your custom model |
| 975 | + embedded_output = self.embedding_layer(input_ids) |
| 976 | + norm_output = self.layer_norm(embedded_output) |
| 977 | + |
| 978 | + # Reshape for Conv2d input (add batch size dimension) |
| 979 | + norm_output = norm_output.unsqueeze(1) |
| 980 | + conv_output = self.conv2d_transform(norm_output) |
| 981 | + |
| 982 | + # Remove batch size dimension |
| 983 | + conv_output = conv_output.squeeze(1) |
| 984 | + |
| 985 | + return conv_output |
| 986 | + |
| 987 | + |
| 988 | +@require_torch_gpu |
| 989 | +class TestAutoCast(unittest.TestCase): |
| 990 | + @parameterized.expand(PRECISIONS) |
| 991 | + def test_simple_model(self, *args, **kwargs): |
| 992 | + self._test_model(SimpleModel(), *args, **kwargs) |
| 993 | + |
| 994 | + @parameterized.expand(PRECISIONS) |
| 995 | + def test_simple_lora_linear_model(self, *args, **kwargs): |
| 996 | + simple_model = SimpleModel() |
| 997 | + config = LoraConfig( |
| 998 | + **LORA_PARAMS, |
| 999 | + target_modules=["linear_transform"], |
| 1000 | + ) |
| 1001 | + |
| 1002 | + lora_model = get_peft_model(simple_model, config) |
| 1003 | + |
| 1004 | + self._test_model(lora_model, *args, **kwargs) |
| 1005 | + |
| 1006 | + @parameterized.expand(PRECISIONS) |
| 1007 | + def test_simple_lora_embedding_model(self, *args, **kwargs): |
| 1008 | + simple_model = SimpleModel() |
| 1009 | + config = LoraConfig( |
| 1010 | + **LORA_PARAMS, |
| 1011 | + target_modules=["embedding_layer"], |
| 1012 | + ) |
| 1013 | + lora_model = get_peft_model(simple_model, config) |
| 1014 | + |
| 1015 | + self._test_model(lora_model, *args, **kwargs) |
| 1016 | + |
| 1017 | + @parameterized.expand(PRECISIONS) |
| 1018 | + def test_simple_conv2d_model(self, *args, **kwargs): |
| 1019 | + self._test_model(SimpleConv2DModel(), *args, **kwargs) |
| 1020 | + |
| 1021 | + @parameterized.expand(PRECISIONS) |
| 1022 | + def test_simple_lora_conv2d_model(self, *args, **kwargs): |
| 1023 | + simple_model = SimpleConv2DModel() |
| 1024 | + config = LoraConfig( |
| 1025 | + **LORA_PARAMS, |
| 1026 | + target_modules=["conv2d_transform"], |
| 1027 | + ) |
| 1028 | + lora_model = get_peft_model(simple_model, config) |
| 1029 | + self._test_model(lora_model, *args, **kwargs) |
| 1030 | + |
| 1031 | + def _test_model(self, model, precision): |
| 1032 | + # Move model to GPU |
| 1033 | + model = model.cuda() |
| 1034 | + |
| 1035 | + # Prepare dummy inputs |
| 1036 | + input_ids = torch.randint(0, 1000, (2, 10)).cuda() |
| 1037 | + if precision == torch.bfloat16: |
| 1038 | + if not torch.cuda.is_bf16_supported(): |
| 1039 | + return |
| 1040 | + |
| 1041 | + # Forward pass with test precision |
| 1042 | + with torch.autocast(enabled=True, dtype=precision, device_type="cuda"): |
| 1043 | + outputs = model(input_ids) |
| 1044 | + self.assertEqual(outputs.dtype, precision) |
0 commit comments