Skip to content

Commit

Permalink
Add preset tests and weights URLs (#1285)
Browse files Browse the repository at this point in the history
* Add preset tests and weights URLs

* Change filename conditional
  • Loading branch information
nkovela1 authored Oct 25, 2023
1 parent 4c43428 commit aff79b3
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 1 deletion.
4 changes: 3 additions & 1 deletion keras_nlp/models/backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,14 @@ def from_preset(
if not load_weights:
return model

filename = os.path.basename(metadata["weights_url"])
weights = keras.utils.get_file(
"model.h5",
filename,
metadata["weights_url"],
cache_subdir=os.path.join("models", preset),
file_hash=metadata["weights_hash"],
)

model.load_weights(weights)
return model

Expand Down
29 changes: 29 additions & 0 deletions keras_nlp/models/t5/t5_backbone_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,32 @@ def test_saved_model(self):
init_kwargs=self.init_kwargs,
input_data=self.input_data,
)

@pytest.mark.large
def test_smallest_preset(self):
    """Run the shared preset check against the `t5_small_multi` preset.

    Hands `self.run_preset_test` the expected output shape and the
    expected partial output values for both the encoder and the decoder
    sequence outputs.
    """
    preset_name = "t5_small_multi"

    # The same shape is expected for the encoder and decoder outputs.
    output_shapes = {
        "encoder_sequence_output": (2, 3, 512),
        "decoder_sequence_output": (2, 3, 512),
    }

    # Reference values forwarded as `expected_partial_output`.
    partial_values = {
        "encoder_sequence_output": ops.array(
            [-0.0034, 0.0293, -0.0827, -0.1076]
        ),
        "decoder_sequence_output": ops.array(
            [0.0097, 0.3576, -0.1508, 0.0150]
        ),
    }

    self.run_preset_test(
        cls=T5Backbone,
        preset=preset_name,
        input_data=self.input_data,
        expected_output_shape=output_shapes,
        expected_partial_output=partial_values,
    )

@pytest.mark.extra_large
def test_all_presets(self):
    """Run the shared preset check once per entry in `T5Backbone.presets`.

    No expected outputs are supplied here; each call only passes the
    class, the preset name and `self.input_data`.
    """
    for preset_name in T5Backbone.presets:
        call_kwargs = {
            "cls": T5Backbone,
            "preset": preset_name,
            "input_data": self.input_data,
        }
        self.run_preset_test(**call_kwargs)
24 changes: 24 additions & 0 deletions keras_nlp/models/t5/t5_presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@
"layer_norm_epsilon": 1e-06,
},
"preprocessor_config": {},
"weights_url": "https://storage.googleapis.com/keras-nlp/models/t5_small_multi/v1/model.weights.h5",
"weights_hash": "5a241ea61142eaf96ac1805898a2f2d1",
"vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/t5_small_multi/v1/vocab.spm",
"vocabulary_hash": "9d15ef55d09d5a425ceb63fa31f7cae3",
},
"t5_base_multi": {
"metadata": {
Expand All @@ -62,6 +66,10 @@
"layer_norm_epsilon": 1e-06,
},
"preprocessor_config": {},
"weights_url": "https://storage.googleapis.com/keras-nlp/models/t5_base_multi/v1/model.weights.h5",
"weights_hash": "9bef4c6650d91d1ea438ee4a2bea47ad",
"vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/t5_base_multi/v1/vocab.spm",
"vocabulary_hash": "9d15ef55d09d5a425ceb63fa31f7cae3",
},
"t5_large_multi": {
"metadata": {
Expand All @@ -86,6 +94,10 @@
"layer_norm_epsilon": 1e-06,
},
"preprocessor_config": {},
"weights_url": "https://storage.googleapis.com/keras-nlp/models/t5_large_multi/v1/model.weights.h5",
"weights_hash": "eab8eee1bad033e65324a71cd6e5a8e9",
"vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/t5_large_multi/v1/vocab.spm",
"vocabulary_hash": "9d15ef55d09d5a425ceb63fa31f7cae3",
},
"flan_small_multi": {
"metadata": {
Expand All @@ -111,6 +123,10 @@
"layer_norm_epsilon": 1e-06,
},
"preprocessor_config": {},
"weights_url": "https://storage.googleapis.com/keras-nlp/models/flan_small_multi/v1/model.weights.h5",
"weights_hash": "4e39b0bab56606a9ab2b8e52a6bc7a9f",
"vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/flan_small_multi/v1/vocab.spm",
"vocabulary_hash": "9d15ef55d09d5a425ceb63fa31f7cae3",
},
"flan_base_multi": {
"metadata": {
Expand All @@ -135,6 +151,10 @@
"layer_norm_epsilon": 1e-06,
},
"preprocessor_config": {},
"weights_url": "https://storage.googleapis.com/keras-nlp/models/flan_base_multi/v1/model.weights.h5",
"weights_hash": "b529270c5361db36d359a46403532b5c",
"vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/flan_base_multi/v1/vocab.spm",
"vocabulary_hash": "9d15ef55d09d5a425ceb63fa31f7cae3",
},
"flan_large_multi": {
"metadata": {
Expand All @@ -159,5 +179,9 @@
"layer_norm_epsilon": 1e-06,
},
"preprocessor_config": {},
"weights_url": "https://storage.googleapis.com/keras-nlp/models/flan_large_multi/v1/model.weights.h5",
"weights_hash": "50b8d3c88fc10db07e495d79ff29a1b6",
"vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/flan_large_multi/v1/vocab.spm",
"vocabulary_hash": "9d15ef55d09d5a425ceb63fa31f7cae3",
},
}
20 changes: 20 additions & 0 deletions keras_nlp/models/t5/t5_tokenizer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import io

import pytest
import sentencepiece
import tensorflow as tf

Expand Down Expand Up @@ -64,3 +65,22 @@ def test_errors_missing_special_tokens(self):
)
with self.assertRaises(ValueError):
T5Tokenizer(proto=bytes_io.getvalue())

@pytest.mark.large
def test_smallest_preset(self):
    """Check the token ids of a fixed sentence for each T5 tokenizer preset.

    Every entry in `T5Tokenizer.presets` is run through
    `self.run_preset_test` with the same input sentence and the same
    expected ids.
    """
    sentence = ["The quick brown fox."]

    # The earlier expectation, [1996, 4248, 2829, 4419, 1012], is the BERT
    # uncased WordPiece encoding of this sentence and cannot come from a
    # SentencePiece T5 vocabulary.
    # NOTE(review): the ids below are believed to be the T5 SentencePiece
    # encoding with no end-of-sequence id appended — confirm against the
    # preset's vocab.spm before merging.
    expected_ids = [[37, 1704, 4216, 3, 20400, 5]]

    for preset in T5Tokenizer.presets:
        self.run_preset_test(
            cls=T5Tokenizer,
            preset=preset,
            input_data=sentence,
            expected_output=expected_ids,
        )

@pytest.mark.extra_large
def test_all_presets(self):
    """Run the shared preset check once per entry in `T5Tokenizer.presets`.

    No expected output is supplied; each call passes only the class, the
    preset name and `self.input_data`.
    """
    for preset_name in T5Tokenizer.presets:
        call_kwargs = {
            "cls": T5Tokenizer,
            "preset": preset_name,
            "input_data": self.input_data,
        }
        self.run_preset_test(**call_kwargs)

0 comments on commit aff79b3

Please sign in to comment.