Skip to content

Commit 530f772

Browse files
committed
[CI] Fix benchmarks for LLMs
ghstack-source-id: 1c27473 Pull-Request: #3212
1 parent 8d2ad89 commit 530f772

File tree

3 files changed

+7
-2
lines changed

3 files changed

+7
-2
lines changed

.github/workflows/benchmarks.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ jobs:
8484
python3.10 -m pip install ninja pytest pytest-benchmark mujoco dm_control "gym[accept-rom-license,atari]" transformers
8585
python -m pip install "pybind11[global]"
8686
python3.10 -m pip install git+https://github.com/pytorch/tensordict
87-
python3.10 -m pip install safetensors tqdm pandas numpy matplotlib
87+
python3.10 -m pip install safetensors tqdm pandas numpy matplotlib ray
8888
python3.10 setup.py develop
8989
9090
# test import

.github/workflows/benchmarks_pr.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ jobs:
7979
${{ matrix.device == 'CPU' && 'export CUDA_VISIBLE_DEVICES=' || '' }}
8080
8181
python3.10 -m pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu128 -U
82-
python3.10 -m pip install ninja pytest pytest-benchmark mujoco dm_control "gym[accept-rom-license,atari]" transformers
82+
python3.10 -m pip install ninja pytest pytest-benchmark mujoco dm_control "gym[accept-rom-license,atari]" transformers ray
8383
python3.10 -m pip install "pybind11[global]"
8484
python3.10 -m pip install git+https://github.com/pytorch/tensordict
8585
python3.10 -m pip install safetensors tqdm pandas numpy matplotlib

benchmarks/test_llm.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,11 @@
1616

1717
_has_transformers = importlib.import_module("transformers") is not None
1818

19+
# Skip all these tests if gpu is not available
20+
pytestmark = pytest.mark.skipif(
21+
not torch.cuda.is_available(), reason="GPU not available"
22+
)
23+
1924

2025
@pytest.fixture(scope="module")
2126
def transformers_wrapper():

0 commit comments

Comments
 (0)