TST Enable more tests in XPU (#2036)
faaany authored Aug 26, 2024
1 parent 900f96c commit 5996d39
Showing 1 changed file with 13 additions and 9 deletions.
tests/test_gpu_examples.py (22 changes: 13 additions & 9 deletions)
@@ -58,7 +58,7 @@
     replace_lora_weights_loftq,
 )
 from peft.tuners import boft
-from peft.utils import SAFETENSORS_WEIGHTS_NAME
+from peft.utils import SAFETENSORS_WEIGHTS_NAME, infer_device
 from peft.utils.loftq_utils import NFQuantizer
 from peft.utils.other import fsdp_auto_wrap_policy

@@ -69,6 +69,7 @@
     require_bitsandbytes,
     require_eetq,
     require_hqq,
+    require_non_cpu,
     require_optimum,
     require_torch_gpu,
     require_torch_multi_gpu,
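
For context, `infer_device` returns the best available accelerator string ("cuda", "xpu", ...) and falls back to "cpu", while `require_non_cpu` skips a test when no accelerator is present, which is what lets these tests run on XPU as well as CUDA. A minimal sketch of the idea (hypothetical names; the real implementations in `peft.utils` and peft's testing utilities may differ):

```python
import unittest

import torch


def infer_device_sketch() -> str:
    # Prefer an available accelerator, falling back to CPU.
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    return "cpu"


def require_non_cpu_sketch(test_case):
    # Skip the decorated test unless some accelerator is present.
    return unittest.skipUnless(
        infer_device_sketch() != "cpu", "test requires a non-CPU accelerator"
    )(test_case)
```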
@@ -1379,7 +1380,7 @@ def test_non_default_adapter_name(self):
         assert n_total_default == n_total_other
 
 
-@require_torch_gpu
+@require_non_cpu
 class OffloadSaveTests(unittest.TestCase):
     def setUp(self):
         self.causal_lm_model_id = "gpt2"
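
Applied at class level, such a decorator gates every test method in the class, which is why the per-method `@require_torch_gpu` removed in the next hunk is redundant. A hypothetical usage sketch building on the helpers above:

```python
# Hypothetical usage: the class-level decorator skips all tests on CPU-only hosts.
@require_non_cpu_sketch
class OffloadSaveTestsSketch(unittest.TestCase):
    def test_accelerator_present(self):
        assert infer_device_sketch() != "cpu"
```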
@@ -1424,7 +1425,6 @@ def test_offload_load(self):
         assert torch.allclose(output, offloaded_output, atol=1e-5)
 
     @pytest.mark.single_gpu_tests
-    @require_torch_gpu
     def test_offload_merge(self):
         r"""
         Test merging, unmerging, and unloading of a model with CPU- and disk- offloaded modules.
@@ -2158,7 +2158,7 @@ def test_notebook_launcher(self):
         run_command(cmd, env=os.environ.copy())
 
 
-@require_torch_gpu
+@require_non_cpu
 class MixedPrecisionTests(unittest.TestCase):
     def setUp(self):
         self.causal_lm_model_id = "facebook/opt-125m"
@@ -3020,8 +3020,10 @@ def forward(self, input_ids):
         return conv_output
 
 
-@require_torch_gpu
+@require_non_cpu
 class TestAutoCast(unittest.TestCase):
+    device = infer_device()
+
     # This test makes sure, that Lora dtypes are consistent with the types
     # infered by torch.autocast under tested PRECISIONS
     @parameterized.expand(PRECISIONS)
@@ -3067,16 +3069,18 @@ def test_simple_lora_conv2d_model(self, *args, **kwargs):
 
     def _test_model(self, model, precision):
         # Move model to GPU
-        model = model.cuda()
+        model = model.to(self.device)
 
         # Prepare dummy inputs
-        input_ids = torch.randint(0, 1000, (2, 10)).cuda()
+        input_ids = torch.randint(0, 1000, (2, 10)).to(self.device)
         if precision == torch.bfloat16:
-            if not torch.cuda.is_bf16_supported():
+            is_xpu = self.device == "xpu"
+            is_cuda_bf16 = self.device == "cuda" and torch.cuda.is_bf16_supported()
+            if not (is_xpu or is_cuda_bf16):
                 self.skipTest("Bfloat16 not supported on this device")
 
         # Forward pass with test precision
-        with torch.autocast(enabled=True, dtype=precision, device_type="cuda"):
+        with torch.autocast(enabled=True, dtype=precision, device_type=self.device):
             outputs = model(input_ids)
             assert outputs.dtype == precision
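
Taken together, the pattern this commit adopts is: resolve the device once, move the model and inputs with `.to(device)`, and pass the same string to `torch.autocast` via `device_type`. A self-contained sketch of that pattern, with illustrative names that are not part of the peft test suite:

```python
import torch
import torch.nn as nn

# Stand-in for peft.utils.infer_device(); real detection covers more backends.
device = "cuda" if torch.cuda.is_available() else "cpu"

model = nn.Sequential(nn.Embedding(1000, 32), nn.Flatten(), nn.Linear(320, 2)).to(device)
input_ids = torch.randint(0, 1000, (2, 10)).to(device)

# torch.autocast takes the device *type* string, so one value serves
# "cuda", "xpu", and "cpu" alike (bfloat16 support permitting).
with torch.autocast(enabled=True, dtype=torch.bfloat16, device_type=device):
    outputs = model(input_ids)

assert outputs.dtype == torch.bfloat16
```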
