Skip to content

Commit

Permalink
chore(ci): skip more tests on GPU CI (#4200)
Browse files Browse the repository at this point in the history
Also, only skip these GPU tests on the CI. When we test locally, it's
expected to run the tests.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Introduced a global variable `CI` to enhance test execution control
based on the continuous integration environment.
  
- **Bug Fixes**
- Updated test skipping conditions across multiple test classes to
ensure tests are only executed on CPU when the CI environment is active.

- **Documentation**
- Enhanced clarity on test conditions by including the `CI` variable in
relevant test decorators.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
njzjz authored Oct 11, 2024
1 parent 2ca1c06 commit 8174cf1
Show file tree
Hide file tree
Showing 8 changed files with 36 additions and 17 deletions.
10 changes: 10 additions & 0 deletions source/tests/consistent/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import itertools
import os
import sys
import unittest
from abc import (
ABC,
abstractmethod,
Expand Down Expand Up @@ -33,6 +34,11 @@
Backend,
)

from ..utils import (
CI,
TEST_DEVICE,
)

INSTALLED_TF = Backend.get_backend("tensorflow")().is_available()
INSTALLED_PT = Backend.get_backend("pytorch")().is_available()
INSTALLED_JAX = Backend.get_backend("jax")().is_available()
Expand Down Expand Up @@ -340,6 +346,7 @@ def test_tf_self_consistent(self):
np.testing.assert_allclose(rr1, rr2, rtol=self.rtol, atol=self.atol)
assert rr1.dtype == rr2.dtype, f"{rr1.dtype} != {rr2.dtype}"

@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.")
def test_dp_consistent_with_ref(self):
"""Test whether DP and reference are consistent."""
if self.skip_dp:
Expand All @@ -358,6 +365,7 @@ def test_dp_consistent_with_ref(self):
np.testing.assert_allclose(rr1, rr2, rtol=self.rtol, atol=self.atol)
assert rr1.dtype == rr2.dtype, f"{rr1.dtype} != {rr2.dtype}"

@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.")
def test_dp_self_consistent(self):
"""Test whether DP is self consistent."""
if self.skip_dp:
Expand Down Expand Up @@ -447,6 +455,7 @@ def test_jax_self_consistent(self):
else:
self.assertEqual(rr1, rr2)

@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.")
def test_array_api_strict_consistent_with_ref(self):
"""Test whether array_api_strict and reference are consistent."""
if self.skip_array_api_strict:
Expand All @@ -465,6 +474,7 @@ def test_array_api_strict_consistent_with_ref(self):
np.testing.assert_allclose(rr1, rr2, rtol=self.rtol, atol=self.atol)
assert rr1.dtype == rr2.dtype, f"{rr1.dtype} != {rr2.dtype}"

@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.")
def test_array_api_strict_self_consistent(self):
"""Test whether array_api_strict is self consistent."""
if self.skip_array_api_strict:
Expand Down
13 changes: 7 additions & 6 deletions source/tests/universal/common/cases/model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
GLOBAL_SEED,
)
from .....utils import (
CI,
TEST_DEVICE,
)

Expand Down Expand Up @@ -327,7 +328,7 @@ def test_zero_forward(self):
continue
np.testing.assert_allclose(rr1, rr2, atol=aprec)

@unittest.skipIf(TEST_DEVICE != "cpu", "Only test on CPU.")
@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.")
def test_permutation(self):
"""Test permutation."""
if getattr(self, "skip_test_permutation", False):
Expand Down Expand Up @@ -413,7 +414,7 @@ def test_permutation(self):
else:
raise RuntimeError(f"Unknown output key: {kk}")

@unittest.skipIf(TEST_DEVICE != "cpu", "Only test on CPU.")
@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.")
def test_trans(self):
"""Test translation."""
if getattr(self, "skip_test_trans", False):
Expand Down Expand Up @@ -482,7 +483,7 @@ def test_trans(self):
else:
raise RuntimeError(f"Unknown output key: {kk}")

@unittest.skipIf(TEST_DEVICE != "cpu", "Only test on CPU.")
@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.")
def test_rot(self):
"""Test rotation."""
if getattr(self, "skip_test_rot", False):
Expand Down Expand Up @@ -672,7 +673,7 @@ def test_rot(self):
else:
raise RuntimeError(f"Unknown output key: {kk}")

@unittest.skipIf(TEST_DEVICE != "cpu", "Only test on CPU.")
@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.")
def test_smooth(self):
"""Test smooth."""
if getattr(self, "skip_test_smooth", False):
Expand Down Expand Up @@ -779,7 +780,7 @@ def test_smooth(self):
else:
raise RuntimeError(f"Unknown output key: {kk}")

@unittest.skipIf(TEST_DEVICE != "cpu", "Only test on CPU.")
@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.")
def test_autodiff(self):
"""Test autodiff."""
if getattr(self, "skip_test_autodiff", False):
Expand Down Expand Up @@ -919,7 +920,7 @@ def ff_cell(bb):
# not support virial by far
pass

@unittest.skipIf(TEST_DEVICE == "cpu", "Skip test on CPU.")
@unittest.skipIf(TEST_DEVICE == "cpu" and CI, "Skip test on CPU.")
def test_device_consistence(self):
"""Test forward consistency between devices."""
test_spin = getattr(self, "test_spin", False)
Expand Down
13 changes: 7 additions & 6 deletions source/tests/universal/dpmodel/atomc_model/test_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
parameterized,
)
from ....utils import (
CI,
TEST_DEVICE,
)
from ...common.cases.atomic_model.atomic_model import (
Expand Down Expand Up @@ -98,7 +99,7 @@
), # fitting_class_param & class
),
)
@unittest.skipIf(TEST_DEVICE != "cpu", "Only test on CPU.")
@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.")
class TestEnergyAtomicModelDP(unittest.TestCase, EnerAtomicModelTest, DPTestCase):
@classmethod
def setUpClass(cls):
Expand Down Expand Up @@ -165,7 +166,7 @@ def setUpClass(cls):
), # fitting_class_param & class
),
)
@unittest.skipIf(TEST_DEVICE != "cpu", "Only test on CPU.")
@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.")
class TestDosAtomicModelDP(unittest.TestCase, DosAtomicModelTest, DPTestCase):
@classmethod
def setUpClass(cls):
Expand Down Expand Up @@ -227,7 +228,7 @@ def setUpClass(cls):
), # fitting_class_param & class
),
)
@unittest.skipIf(TEST_DEVICE != "cpu", "Only test on CPU.")
@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.")
class TestDipoleAtomicModelDP(unittest.TestCase, DipoleAtomicModelTest, DPTestCase):
@classmethod
def setUpClass(cls):
Expand Down Expand Up @@ -290,7 +291,7 @@ def setUpClass(cls):
), # fitting_class_param & class
),
)
@unittest.skipIf(TEST_DEVICE != "cpu", "Only test on CPU.")
@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.")
class TestPolarAtomicModelDP(unittest.TestCase, PolarAtomicModelTest, DPTestCase):
@classmethod
def setUpClass(cls):
Expand Down Expand Up @@ -351,7 +352,7 @@ def setUpClass(cls):
), # fitting_class_param & class
),
)
@unittest.skipIf(TEST_DEVICE != "cpu", "Only test on CPU.")
@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.")
class TestZBLAtomicModelDP(unittest.TestCase, ZBLAtomicModelTest, DPTestCase):
@classmethod
def setUpClass(cls):
Expand Down Expand Up @@ -429,7 +430,7 @@ def setUpClass(cls):
), # fitting_class_param & class
),
)
@unittest.skipIf(TEST_DEVICE != "cpu", "Only test on CPU.")
@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.")
class TestPropertyAtomicModelDP(unittest.TestCase, PropertyAtomicModelTest, DPTestCase):
@classmethod
def setUpClass(cls):
Expand Down
3 changes: 2 additions & 1 deletion source/tests/universal/dpmodel/descriptor/test_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
GLOBAL_SEED,
)
from ....utils import (
CI,
TEST_DEVICE,
)
from ...common.cases.descriptor.descriptor import (
Expand Down Expand Up @@ -519,7 +520,7 @@ def DescriptorParamHybridMixedTTebd(ntypes, rcut, rcut_smth, sel, type_map, **kw
(DescriptorParamHybridMixedTTebd, DescrptHybrid),
) # class_param & class
)
@unittest.skipIf(TEST_DEVICE != "cpu", "Only test on CPU.")
@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.")
class TestDescriptorDP(unittest.TestCase, DescriptorTest, DPTestCase):
def setUp(self):
DescriptorTest.setUp(self)
Expand Down
3 changes: 2 additions & 1 deletion source/tests/universal/dpmodel/fitting/test_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
GLOBAL_SEED,
)
from ....utils import (
CI,
TEST_DEVICE,
)
from ...common.cases.fitting.fitting import (
Expand Down Expand Up @@ -236,7 +237,7 @@ def FittingParamProperty(
), # class_param & class
(True, False), # mixed_types
)
@unittest.skipIf(TEST_DEVICE != "cpu", "Only test on CPU.")
@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.")
class TestFittingDP(unittest.TestCase, FittingTest, DPTestCase):
def setUp(self):
((FittingParam, Fitting), self.mixed_types) = self.param
Expand Down
5 changes: 3 additions & 2 deletions source/tests/universal/dpmodel/model/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
parameterized,
)
from ....utils import (
CI,
TEST_DEVICE,
)
from ...common.cases.model.model import (
Expand Down Expand Up @@ -112,7 +113,7 @@ def skip_model_tests(test_obj):
), # fitting_class_param & class
),
)
@unittest.skipIf(TEST_DEVICE != "cpu", "Only test on CPU.")
@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.")
class TestEnergyModelDP(unittest.TestCase, EnerModelTest, DPTestCase):
@classmethod
def setUpClass(cls):
Expand Down Expand Up @@ -200,7 +201,7 @@ def setUpClass(cls):
), # fitting_class_param & class
),
)
@unittest.skipIf(TEST_DEVICE != "cpu", "Only test on CPU.")
@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.")
class TestSpinEnergyModelDP(unittest.TestCase, SpinEnerModelTest, DPTestCase):
@classmethod
def setUpClass(cls):
Expand Down
3 changes: 2 additions & 1 deletion source/tests/universal/dpmodel/utils/test_type_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
)

from ....utils import (
CI,
TEST_DEVICE,
)
from ...common.cases.utils.type_embed import (
Expand All @@ -16,7 +17,7 @@
)


@unittest.skipIf(TEST_DEVICE != "cpu", "Only test on CPU.")
@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.")
class TestTypeEmbd(unittest.TestCase, TypeEmbdTest, DPTestCase):
def setUp(self):
TypeEmbdTest.setUp(self)
Expand Down
3 changes: 3 additions & 0 deletions source/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,6 @@
TEST_DEVICE = "cpu"
else:
TEST_DEVICE = "cuda"

# see https://docs.github.com/en/actions/writing-workflows/choosing-what-your-workflow-does/store-information-in-variables#default-environment-variables
CI = os.environ.get("CI") == "true"

0 comments on commit 8174cf1

Please sign in to comment.