diff --git a/source/tests/consistent/common.py b/source/tests/consistent/common.py index 1070fe0f79..e3bf808978 100644 --- a/source/tests/consistent/common.py +++ b/source/tests/consistent/common.py @@ -3,6 +3,7 @@ import itertools import os import sys +import unittest from abc import ( ABC, abstractmethod, @@ -33,6 +34,11 @@ Backend, ) +from ..utils import ( + CI, + TEST_DEVICE, +) + INSTALLED_TF = Backend.get_backend("tensorflow")().is_available() INSTALLED_PT = Backend.get_backend("pytorch")().is_available() INSTALLED_JAX = Backend.get_backend("jax")().is_available() @@ -340,6 +346,7 @@ def test_tf_self_consistent(self): np.testing.assert_allclose(rr1, rr2, rtol=self.rtol, atol=self.atol) assert rr1.dtype == rr2.dtype, f"{rr1.dtype} != {rr2.dtype}" + @unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.") def test_dp_consistent_with_ref(self): """Test whether DP and reference are consistent.""" if self.skip_dp: @@ -358,6 +365,7 @@ def test_dp_consistent_with_ref(self): np.testing.assert_allclose(rr1, rr2, rtol=self.rtol, atol=self.atol) assert rr1.dtype == rr2.dtype, f"{rr1.dtype} != {rr2.dtype}" + @unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.") def test_dp_self_consistent(self): """Test whether DP is self consistent.""" if self.skip_dp: @@ -447,6 +455,7 @@ def test_jax_self_consistent(self): else: self.assertEqual(rr1, rr2) + @unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.") def test_array_api_strict_consistent_with_ref(self): """Test whether array_api_strict and reference are consistent.""" if self.skip_array_api_strict: @@ -465,6 +474,7 @@ def test_array_api_strict_consistent_with_ref(self): np.testing.assert_allclose(rr1, rr2, rtol=self.rtol, atol=self.atol) assert rr1.dtype == rr2.dtype, f"{rr1.dtype} != {rr2.dtype}" + @unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.") def test_array_api_strict_self_consistent(self): """Test whether array_api_strict is self consistent.""" if self.skip_array_api_strict: diff --git a/source/tests/universal/common/cases/model/utils.py b/source/tests/universal/common/cases/model/utils.py index d583d06b05..628c415eb2 100644 --- a/source/tests/universal/common/cases/model/utils.py +++ b/source/tests/universal/common/cases/model/utils.py @@ -22,6 +22,7 @@ GLOBAL_SEED, ) from .....utils import ( + CI, TEST_DEVICE, ) @@ -327,7 +328,7 @@ def test_zero_forward(self): continue np.testing.assert_allclose(rr1, rr2, atol=aprec) - @unittest.skipIf(TEST_DEVICE != "cpu", "Only test on CPU.") + @unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.") def test_permutation(self): """Test permutation.""" if getattr(self, "skip_test_permutation", False): @@ -413,7 +414,7 @@ def test_permutation(self): else: raise RuntimeError(f"Unknown output key: {kk}") - @unittest.skipIf(TEST_DEVICE != "cpu", "Only test on CPU.") + @unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.") def test_trans(self): """Test translation.""" if getattr(self, "skip_test_trans", False): @@ -482,7 +483,7 @@ def test_trans(self): else: raise RuntimeError(f"Unknown output key: {kk}") - @unittest.skipIf(TEST_DEVICE != "cpu", "Only test on CPU.") + @unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.") def test_rot(self): """Test rotation.""" if getattr(self, "skip_test_rot", False): @@ -672,7 +673,7 @@ def test_rot(self): else: raise RuntimeError(f"Unknown output key: {kk}") - @unittest.skipIf(TEST_DEVICE != "cpu", "Only test on CPU.") + @unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.") def test_smooth(self): """Test smooth.""" if getattr(self, "skip_test_smooth", False): @@ -779,7 +780,7 @@ def test_smooth(self): else: raise RuntimeError(f"Unknown output key: {kk}") - @unittest.skipIf(TEST_DEVICE != "cpu", "Only test on CPU.") + @unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.") def test_autodiff(self): """Test autodiff.""" if getattr(self, "skip_test_autodiff", False): @@ -919,7 +920,7 @@ def ff_cell(bb): # not support virial by far pass - @unittest.skipIf(TEST_DEVICE == "cpu", "Skip test on CPU.") + @unittest.skipIf(TEST_DEVICE == "cpu" and CI, "Skip test on CPU.") def test_device_consistence(self): """Test forward consistency between devices.""" test_spin = getattr(self, "test_spin", False) diff --git a/source/tests/universal/dpmodel/atomc_model/test_atomic_model.py b/source/tests/universal/dpmodel/atomc_model/test_atomic_model.py index 4c5a2b291b..8e7324e2bc 100644 --- a/source/tests/universal/dpmodel/atomc_model/test_atomic_model.py +++ b/source/tests/universal/dpmodel/atomc_model/test_atomic_model.py @@ -26,6 +26,7 @@ parameterized, ) from ....utils import ( + CI, TEST_DEVICE, ) from ...common.cases.atomic_model.atomic_model import ( @@ -98,7 +99,7 @@ ), # fitting_class_param & class ), ) -@unittest.skipIf(TEST_DEVICE != "cpu", "Only test on CPU.") +@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.") class TestEnergyAtomicModelDP(unittest.TestCase, EnerAtomicModelTest, DPTestCase): @classmethod def setUpClass(cls): @@ -165,7 +166,7 @@ def setUpClass(cls): ), # fitting_class_param & class ), ) -@unittest.skipIf(TEST_DEVICE != "cpu", "Only test on CPU.") +@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.") class TestDosAtomicModelDP(unittest.TestCase, DosAtomicModelTest, DPTestCase): @classmethod def setUpClass(cls): @@ -227,7 +228,7 @@ def setUpClass(cls): ), # fitting_class_param & class ), ) -@unittest.skipIf(TEST_DEVICE != "cpu", "Only test on CPU.") +@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.") class TestDipoleAtomicModelDP(unittest.TestCase, DipoleAtomicModelTest, DPTestCase): @classmethod def setUpClass(cls): @@ -290,7 +291,7 @@ def setUpClass(cls): ), # fitting_class_param & class ), ) -@unittest.skipIf(TEST_DEVICE != "cpu", "Only test on CPU.") +@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.") class TestPolarAtomicModelDP(unittest.TestCase, PolarAtomicModelTest, DPTestCase): @classmethod def setUpClass(cls): @@ -351,7 +352,7 @@ def setUpClass(cls): ), # fitting_class_param & class ), ) -@unittest.skipIf(TEST_DEVICE != "cpu", "Only test on CPU.") +@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.") class TestZBLAtomicModelDP(unittest.TestCase, ZBLAtomicModelTest, DPTestCase): @classmethod def setUpClass(cls): @@ -429,7 +430,7 @@ def setUpClass(cls): ), # fitting_class_param & class ), ) -@unittest.skipIf(TEST_DEVICE != "cpu", "Only test on CPU.") +@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.") class TestPropertyAtomicModelDP(unittest.TestCase, PropertyAtomicModelTest, DPTestCase): @classmethod def setUpClass(cls): diff --git a/source/tests/universal/dpmodel/descriptor/test_descriptor.py b/source/tests/universal/dpmodel/descriptor/test_descriptor.py index 256bea74f8..fc7ee8b075 100644 --- a/source/tests/universal/dpmodel/descriptor/test_descriptor.py +++ b/source/tests/universal/dpmodel/descriptor/test_descriptor.py @@ -26,6 +26,7 @@ GLOBAL_SEED, ) from ....utils import ( + CI, TEST_DEVICE, ) from ...common.cases.descriptor.descriptor import ( @@ -519,7 +520,7 @@ def DescriptorParamHybridMixedTTebd(ntypes, rcut, rcut_smth, sel, type_map, **kw (DescriptorParamHybridMixedTTebd, DescrptHybrid), ) # class_param & class ) -@unittest.skipIf(TEST_DEVICE != "cpu", "Only test on CPU.") +@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.") class TestDescriptorDP(unittest.TestCase, DescriptorTest, DPTestCase): def setUp(self): DescriptorTest.setUp(self) diff --git a/source/tests/universal/dpmodel/fitting/test_fitting.py b/source/tests/universal/dpmodel/fitting/test_fitting.py index 393bab1707..f64faee76f 100644 --- a/source/tests/universal/dpmodel/fitting/test_fitting.py +++ b/source/tests/universal/dpmodel/fitting/test_fitting.py @@ -20,6 +20,7 @@ GLOBAL_SEED, ) from ....utils import ( + CI, TEST_DEVICE, ) from ...common.cases.fitting.fitting import ( @@ -236,7 +237,7 @@ def FittingParamProperty( ), # class_param & class (True, False), # mixed_types ) -@unittest.skipIf(TEST_DEVICE != "cpu", "Only test on CPU.") +@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.") class TestFittingDP(unittest.TestCase, FittingTest, DPTestCase): def setUp(self): ((FittingParam, Fitting), self.mixed_types) = self.param diff --git a/source/tests/universal/dpmodel/model/test_model.py b/source/tests/universal/dpmodel/model/test_model.py index 66edc2d50e..265dc43c6c 100644 --- a/source/tests/universal/dpmodel/model/test_model.py +++ b/source/tests/universal/dpmodel/model/test_model.py @@ -25,6 +25,7 @@ parameterized, ) from ....utils import ( + CI, TEST_DEVICE, ) from ...common.cases.model.model import ( @@ -112,7 +113,7 @@ def skip_model_tests(test_obj): ), # fitting_class_param & class ), ) -@unittest.skipIf(TEST_DEVICE != "cpu", "Only test on CPU.") +@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.") class TestEnergyModelDP(unittest.TestCase, EnerModelTest, DPTestCase): @classmethod def setUpClass(cls): @@ -200,7 +201,7 @@ def setUpClass(cls): ), # fitting_class_param & class ), ) -@unittest.skipIf(TEST_DEVICE != "cpu", "Only test on CPU.") +@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.") class TestSpinEnergyModelDP(unittest.TestCase, SpinEnerModelTest, DPTestCase): @classmethod def setUpClass(cls): diff --git a/source/tests/universal/dpmodel/utils/test_type_embed.py b/source/tests/universal/dpmodel/utils/test_type_embed.py index 67faef0a8d..ee3063af7d 100644 --- a/source/tests/universal/dpmodel/utils/test_type_embed.py +++ b/source/tests/universal/dpmodel/utils/test_type_embed.py @@ -6,6 +6,7 @@ ) from ....utils import ( + CI, TEST_DEVICE, ) from ...common.cases.utils.type_embed import ( @@ -16,7 +17,7 @@ ) -@unittest.skipIf(TEST_DEVICE != "cpu", "Only test on CPU.") +@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.") class TestTypeEmbd(unittest.TestCase, TypeEmbdTest, DPTestCase): def setUp(self): TypeEmbdTest.setUp(self) diff --git a/source/tests/utils.py b/source/tests/utils.py index 694f55186e..bfb3d445af 100644 --- a/source/tests/utils.py +++ b/source/tests/utils.py @@ -5,3 +5,6 @@ TEST_DEVICE = "cpu" else: TEST_DEVICE = "cuda" + +# see https://docs.github.com/en/actions/writing-workflows/choosing-what-your-workflow-does/store-information-in-variables#default-environment-variables +CI = os.environ.get("CI") == "true"