From 7643a9b79334df2332df3963ba8a5662f6cfa18c Mon Sep 17 00:00:00 2001 From: wangshuai09 <391746016@qq.com> Date: Thu, 28 Mar 2024 16:40:23 +0800 Subject: [PATCH] Add npu test --- pytest.ini | 1 + .../models/npu_tests/models_npu_test.py | 33 +++++++++++++++++++ 2 files changed, 34 insertions(+) create mode 100644 test/torchtext_unittest/models/npu_tests/models_npu_test.py diff --git a/pytest.ini b/pytest.ini index b9bb2d26ca..8d456e23ba 100644 --- a/pytest.ini +++ b/pytest.ini @@ -4,3 +4,4 @@ testpaths = test/ python_paths = ./ markers = gpu_test: marks cuda tests + npu_test: marks ascend npu tests diff --git a/test/torchtext_unittest/models/npu_tests/models_npu_test.py b/test/torchtext_unittest/models/npu_tests/models_npu_test.py new file mode 100644 index 0000000000..f3ff4919b2 --- /dev/null +++ b/test/torchtext_unittest/models/npu_tests/models_npu_test.py @@ -0,0 +1,33 @@ +import importlib +import unittest + +import pytest +import torch +from torchtext_unittest.common.torchtext_test_case import TorchtextTestCase +from torchtext_unittest.models.roberta_models_test_impl import RobertaBaseTestModels +from torchtext_unittest.models.t5_models_test_impl import T5BaseTestModels + + +def is_npu_available(check_device=False): + "Checks if `torch_npu` is installed and potentially if a NPU is in the environment" + if importlib.util.find_spec("torch") is None or importlib.util.find_spec("torch_npu") is None: + return False + + import torch + import torch_npu # noqa: F401 + + if check_device: + try: + # Will raise a RuntimeError if no NPU is found + _ = torch.npu.device_count() + return torch.npu.is_available() + except RuntimeError: + return False + return hasattr(torch, "npu") and torch.npu.is_available() + + +@pytest.mark.npu_test +@unittest.skipIf(not is_npu_available(), reason="Ascend NPU is not available") +class TestModels32NPU(RobertaBaseTestModels, T5BaseTestModels, TorchtextTestCase): + dtype = torch.float32 + device = torch.device("npu")