diff --git a/python/cugraph/cugraph/tests/data_store/test_gnn_feat_storage_wholegraph.py b/python/cugraph/cugraph/tests/data_store/test_gnn_feat_storage_wholegraph.py
index f760ef3e1b..964449276a 100644
--- a/python/cugraph/cugraph/tests/data_store/test_gnn_feat_storage_wholegraph.py
+++ b/python/cugraph/cugraph/tests/data_store/test_gnn_feat_storage_wholegraph.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2023-2024, NVIDIA CORPORATION.
+# Copyright (c) 2023-2025, NVIDIA CORPORATION.
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -32,6 +32,20 @@ def get_cudart_version():
     return major * 1000 + minor * 10
 
 
+pytestmark = [
+    pytest.mark.skipif(
+        isinstance(torch, MissingModule) or not torch.cuda.is_available(),
+        reason="PyTorch with GPU support not available",
+    ),
+    pytest.mark.skipif(
+        isinstance(pylibwholegraph, MissingModule), reason="wholegraph not available"
+    ),
+    pytest.mark.skipif(
+        get_cudart_version() < 11080, reason="not compatible with CUDA < 11.8"
+    ),
+]
+
+
 def runtest(rank: int, world_size: int):
     torch.cuda.set_device(rank)
 
@@ -69,13 +83,6 @@ def runtest(rank: int, world_size: int):
 
 
 @pytest.mark.sg
-@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available")
-@pytest.mark.skipif(
-    isinstance(pylibwholegraph, MissingModule), reason="wholegraph not available"
-)
-@pytest.mark.skipif(
-    get_cudart_version() < 11080, reason="not compatible with CUDA < 11.8"
-)
 def test_feature_storage_wholegraph_backend():
     world_size = torch.cuda.device_count()
     print("gpu count:", world_size)
@@ -87,13 +94,6 @@ def test_feature_storage_wholegraph_backend():
 
 
 @pytest.mark.mg
-@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available")
-@pytest.mark.skipif(
-    isinstance(pylibwholegraph, MissingModule), reason="wholegraph not available"
-)
-@pytest.mark.skipif(
-    get_cudart_version() < 11080, reason="not compatible with CUDA < 11.8"
-)
 def test_feature_storage_wholegraph_backend_mg():
     world_size = torch.cuda.device_count()
     print("gpu count:", world_size)