Skip to content

Commit

Permalink
Add an annotation to tests that need kaggle auth (#1470)
Browse files Browse the repository at this point in the history
We can skip these by default, for users who have not yet set them up.
We will need to set them up for CI, see
keras-team/keras-hub#1459
  • Loading branch information
bestalternativereviews4 authored Feb 27, 2024
1 parent b3e2a27 commit 436bc5b
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 0 deletions.
18 changes: 18 additions & 0 deletions keras_nlp/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import pytest
import tensorflow as tf

Expand Down Expand Up @@ -83,6 +85,10 @@ def pytest_configure(config):
"markers",
"keras_3_only: mark test as a keras 3 only test",
)
config.addinivalue_line(
"markers",
"kaggle_key_required: mark test needing a kaggle key",
)


def pytest_collection_modifyitems(config, items):
Expand All @@ -107,6 +113,16 @@ def pytest_collection_modifyitems(config, items):
not backend_config.keras_3(),
reason="tests only run on with multi-backend keras",
)
found_kaggle_key = all(
[
os.environ.get("KAGGLE_USERNAME", None),
os.environ.get("KAGGLE_KEY", None),
]
)
kaggle_key_required = pytest.mark.skipif(
not found_kaggle_key,
reason="tests only run with a kaggle api key",
)
for item in items:
if "large" in item.keywords:
item.add_marker(skip_large)
Expand All @@ -116,6 +132,8 @@ def pytest_collection_modifyitems(config, items):
item.add_marker(tf_only)
if "keras_3_only" in item.keywords:
item.add_marker(keras_3_only)
if "kaggle_key_required" in item.keywords:
item.add_marker(kaggle_key_required)


# Disable traceback filtering for quicker debugging of tests failures.
Expand Down
2 changes: 2 additions & 0 deletions keras_nlp/models/gemma/gemma_backbone_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def test_saved_model(self):
input_data=self.input_data,
)

@pytest.mark.kaggle_key_required
@pytest.mark.large
def test_smallest_preset(self):
self.run_preset_test(
Expand All @@ -69,6 +70,7 @@ def test_smallest_preset(self):
),
)

@pytest.mark.kaggle_key_required
@pytest.mark.extra_large
def test_all_presets(self):
for preset in GemmaBackbone.presets:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def test_generate_postprocess(self):
x = preprocessor.generate_postprocess(input_data)
self.assertAllEqual(x, "the quick brown fox")

@pytest.mark.kaggle_key_required
@pytest.mark.extra_large
def test_all_presets(self):
for preset in GemmaCausalLMPreprocessor.presets:
Expand Down
2 changes: 2 additions & 0 deletions keras_nlp/models/gemma/gemma_causal_lm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ def test_generate_compilation(self):
causal_lm.compile(sampler="greedy")
self.assertIsNone(causal_lm.generate_function)

@pytest.mark.kaggle_key_required
@pytest.mark.large
def test_saved_model(self):
self.run_model_saving_test(
Expand All @@ -150,6 +151,7 @@ def test_saved_model(self):
input_data=self.input_data,
)

@pytest.mark.kaggle_key_required
@pytest.mark.extra_large
def test_all_presets(self):
for preset in GemmaCausalLM.presets:
Expand Down
1 change: 1 addition & 0 deletions keras_nlp/models/gemma/gemma_preprocessor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def test_sequence_length_override(self):
x = preprocessor(input_data, sequence_length=4)
self.assertAllEqual(x["token_ids"], [1, 4, 9, 2])

@pytest.mark.kaggle_key_required
@pytest.mark.extra_large
def test_all_presets(self):
for preset in GemmaPreprocessor.presets:
Expand Down
2 changes: 2 additions & 0 deletions keras_nlp/models/gemma/gemma_tokenizer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def test_errors_missing_special_tokens(self):
)
)

@pytest.mark.kaggle_key_required
@pytest.mark.large
def test_smallest_preset(self):
self.run_preset_test(
Expand All @@ -57,6 +58,7 @@ def test_smallest_preset(self):
expected_output=[[651, 4320, 8426, 25341, 235265]],
)

@pytest.mark.kaggle_key_required
@pytest.mark.extra_large
def test_all_presets(self):
for preset in GemmaTokenizer.presets:
Expand Down

0 comments on commit 436bc5b

Please sign in to comment.